Skip to main content

forge_dioxus/
auth.rs

1//! Built-in auth + viewer state management for forge-dioxus.
2//!
3//! Handles token storage, viewer persistence, refresh loops, and 401
4//! recovery. Apps get viewer access for free without writing their own
5//! storage layer.
6
7use dioxus::prelude::*;
8use serde::de::DeserializeOwned;
9use serde::{Deserialize, Serialize};
10
11use crate::{ConnectionState, ForgeClient, ForgeClientConfig};
12
13/// Persisted auth data: tokens + optional viewer.
14/// Backward compatible with old format (viewer is optional).
15#[derive(Debug, Clone, Serialize, Deserialize)]
16struct StoredAuth {
17    access_token: String,
18    refresh_token: String,
19    #[serde(default, skip_serializing_if = "Option::is_none")]
20    viewer: Option<serde_json::Value>,
21}
22
23/// Auth state tracked by the framework.
24#[derive(Debug, Clone)]
25pub enum ForgeAuthState {
26    Unauthenticated,
27    Authenticated {
28        access_token: String,
29        refresh_token: String,
30        viewer: Option<serde_json::Value>,
31    },
32}
33
34impl ForgeAuthState {
35    pub fn is_authenticated(&self) -> bool {
36        matches!(self, Self::Authenticated { .. })
37    }
38
39    pub fn access_token(&self) -> Option<String> {
40        match self {
41            Self::Authenticated { access_token, .. } => Some(access_token.clone()),
42            Self::Unauthenticated => None,
43        }
44    }
45
46    pub fn refresh_token(&self) -> Option<String> {
47        match self {
48            Self::Authenticated { refresh_token, .. } => Some(refresh_token.clone()),
49            Self::Unauthenticated => None,
50        }
51    }
52
53    fn viewer_json(&self) -> Option<&serde_json::Value> {
54        match self {
55            Self::Authenticated { viewer, .. } => viewer.as_ref(),
56            Self::Unauthenticated => None,
57        }
58    }
59}
60
61/// Auth handle provided to components via `use_forge_auth()`.
62#[derive(Clone, Copy)]
63pub struct ForgeAuth {
64    state: Signal<ForgeAuthState>,
65    app_name: Signal<String>,
66    generation: Signal<u64>,
67}
68
69impl ForgeAuth {
70    pub fn is_authenticated(&self) -> bool {
71        self.state.read().is_authenticated()
72    }
73
74    pub fn access_token(&self) -> Option<String> {
75        self.state.read().access_token()
76    }
77
78    pub fn refresh_token(&self) -> Option<String> {
79        self.state.read().refresh_token()
80    }
81
82    /// Read the stored viewer, deserialized into the app's type.
83    pub fn viewer<V: DeserializeOwned>(&self) -> Option<V> {
84        let state = self.state.read();
85        let json = state.viewer_json()?;
86        serde_json::from_value(json.clone()).ok()
87    }
88
89    /// Set tokens after login/register (no viewer).
90    pub fn login(&mut self, access_token: String, refresh_token: String) {
91        self.save_and_set(access_token, refresh_token, None);
92    }
93
94    /// Set tokens + viewer after login/register.
95    pub fn login_with_viewer<V: Serialize>(
96        &mut self,
97        access_token: String,
98        refresh_token: String,
99        viewer: &V,
100    ) {
101        let viewer_json = serde_json::to_value(viewer).ok();
102        self.save_and_set(access_token, refresh_token, viewer_json);
103    }
104
105    /// Update tokens (e.g., after a refresh). Preserves existing viewer.
106    pub fn update_tokens(&mut self, access_token: String, refresh_token: String) {
107        let existing_viewer = self.state.read().viewer_json().cloned();
108        self.save_and_set(access_token, refresh_token, existing_viewer);
109    }
110
111    /// Update just the viewer without touching tokens.
112    pub fn update_viewer<V: Serialize>(&mut self, viewer: &V) {
113        let state = self.state.read();
114        let (access_token, refresh_token) = match &*state {
115            ForgeAuthState::Authenticated {
116                access_token,
117                refresh_token,
118                ..
119            } => (access_token.clone(), refresh_token.clone()),
120            ForgeAuthState::Unauthenticated => return,
121        };
122        drop(state);
123        let viewer_json = serde_json::to_value(viewer).ok();
124        self.save_and_set(access_token, refresh_token, viewer_json);
125    }
126
127    /// Clear tokens, viewer, and log out.
128    pub fn logout(&mut self) {
129        storage::clear(&self.app_name.read());
130        self.state.set(ForgeAuthState::Unauthenticated);
131        self.generation.with_mut(|g| *g += 1);
132    }
133
134    fn save_and_set(
135        &mut self,
136        access_token: String,
137        refresh_token: String,
138        viewer: Option<serde_json::Value>,
139    ) {
140        let stored = StoredAuth {
141            access_token: access_token.clone(),
142            refresh_token: refresh_token.clone(),
143            viewer: viewer.clone(),
144        };
145        storage::save(&self.app_name.read(), &stored);
146        let was_authenticated = self.state.read().is_authenticated();
147        self.state.set(ForgeAuthState::Authenticated {
148            access_token,
149            refresh_token,
150            viewer,
151        });
152        if !was_authenticated {
153            self.generation.with_mut(|g| *g += 1);
154        }
155    }
156}
157
158/// Read the auth handle from context.
159pub fn use_forge_auth() -> ForgeAuth {
160    use_context::<ForgeAuth>()
161}
162
163/// Read the stored viewer, deserialized into the app's viewer type.
164/// Returns `None` when unauthenticated or if the viewer hasn't been set.
165pub fn use_viewer<V: DeserializeOwned + Clone + 'static>() -> Option<V> {
166    use_forge_auth().viewer::<V>()
167}
168
169/// Returns a string key that changes on login/logout transitions.
170/// Use this to key your router or main content area so SSE subscriptions
171/// reconnect with fresh auth state.
172///
173/// ```ignore
174/// let auth_key = use_auth_key();
175/// rsx! { main { key: "{auth_key}", Router::<Route> {} } }
176/// ```
177pub fn use_auth_key() -> String {
178    let auth = use_forge_auth();
179    let generation = auth.generation.read();
180    format!("forge-auth-{generation}")
181}
182
/// Guard hook: redirects to `redirect_path` when unauthenticated.
///
/// Returns `true` if authenticated and `false` while the redirect is pending.
///
/// ```ignore
/// fn ProtectedPage() -> Element {
///     if !use_require_auth("/login") { return rsx! {} }
///     // ... render protected content
/// }
/// ```
#[cfg(feature = "router")]
pub fn use_require_auth(redirect_path: &str) -> bool {
    let auth = use_forge_auth();
    let nav = use_navigator();
    let target = redirect_path.to_owned();

    use_effect(move || {
        if auth.is_authenticated() {
            return;
        }
        nav.replace(NavigationTarget::Internal(target.clone()));
    });

    auth.is_authenticated()
}
206
207/// Provider component that sets up auth state, ForgeClient with auto token wiring,
208/// 401 detection, and periodic refresh.
209///
210/// ```ignore
211/// ForgeAuthProvider {
212///     url: "http://localhost:9081",
213///     app_name: "my-app",
214///     children: rsx! { Router::<Route> {} }
215/// }
216/// ```
217/// `refresh_interval_secs`: How often to proactively refresh tokens (default: 2400 = 40 min).
218/// Set to roughly 2/3 of your `access_token_ttl` from forge.toml.
219#[component]
220pub fn ForgeAuthProvider(
221    url: String,
222    #[props(default = "forge_app".to_string())] app_name: String,
223    #[props(default = 2400)] refresh_interval_secs: u64,
224    children: Element,
225) -> Element {
226    let initial = match storage::load(&app_name) {
227        Some(stored) => ForgeAuthState::Authenticated {
228            access_token: stored.access_token,
229            refresh_token: stored.refresh_token,
230            viewer: stored.viewer,
231        },
232        None => ForgeAuthState::Unauthenticated,
233    };
234
235    let auth_state = use_context_provider(|| Signal::new(initial));
236    let app_name_signal = use_context_provider(|| Signal::new(app_name));
237    let generation = use_context_provider(|| Signal::new(0_u64));
238    let forge_auth = use_context_provider(|| ForgeAuth {
239        state: auth_state,
240        app_name: app_name_signal,
241        generation,
242    });
243
244    let connection_state = use_context_provider(|| Signal::new(ConnectionState::Disconnected));
245    let needs_refresh = use_signal(|| false);
246
247    // Build ForgeClient with auto token provider and auth error handler
248    let url_clone = url.clone();
249    use_context_provider(move || {
250        let auth_for_token = auth_state;
251        let needs_refresh_clone = needs_refresh;
252        let config = ForgeClientConfig::new(url_clone)
253            .with_connection_state(connection_state)
254            .with_token_provider(move || auth_for_token.read().access_token())
255            .with_auth_error_handler(move |_err| {
256                let mut sig = needs_refresh_clone;
257                sig.set(true);
258            });
259        ForgeClient::new(config)
260    });
261
262    // Handle 401 errors by attempting token refresh
263    let url_for_refresh = url.clone();
264    use_effect(move || {
265        if !*needs_refresh.read() {
266            return;
267        }
268        let url = url_for_refresh.clone();
269        let mut auth = forge_auth;
270        spawn(async move {
271            try_refresh_tokens(&url, &mut auth).await;
272        });
273    });
274
275    // Periodic refresh (default every 40 minutes, configurable via refresh_interval_secs)
276    let url_for_periodic = url;
277    use_future(move || {
278        let url = url_for_periodic.clone();
279        let mut auth = forge_auth;
280        async move {
281            loop {
282                sleep(refresh_interval_secs).await;
283                if auth.is_authenticated() {
284                    try_refresh_tokens(&url, &mut auth).await;
285                }
286            }
287        }
288    });
289
290    rsx! { {children} }
291}
292
293/// Attempt to refresh tokens using an anonymous client.
294///
295/// Only logs out on definitive auth failures (401/403). Network errors
296/// are silently ignored so transient connectivity issues in hospital
297/// networks don't force unnecessary logouts.
298async fn try_refresh_tokens(api_url: &str, auth: &mut ForgeAuth) {
299    let refresh_token = match auth.refresh_token() {
300        Some(t) => t,
301        None => return,
302    };
303
304    let anon_client = ForgeClient::new(ForgeClientConfig::new(api_url.to_string()));
305
306    #[derive(Serialize)]
307    struct RefreshArgs {
308        refresh_token: String,
309    }
310
311    #[derive(Deserialize)]
312    struct RefreshResponse {
313        access_token: String,
314        refresh_token: String,
315    }
316
317    match anon_client
318        .call::<_, RefreshResponse>(
319            "refresh",
320            RefreshArgs {
321                refresh_token,
322            },
323        )
324        .await
325    {
326        Ok(resp) => {
327            auth.update_tokens(resp.access_token, resp.refresh_token);
328        }
329        Err(ref e)
330            if e.code == "UNAUTHORIZED"
331                || e.code == "FORBIDDEN"
332                || e.code == "NOT_FOUND" =>
333        {
334            // Definitive auth failure: token is invalid/expired/revoked.
335            auth.logout();
336        }
337        Err(_) => {
338            // Network or transient error. Keep current tokens and retry
339            // on the next refresh cycle rather than forcing a logout.
340        }
341    }
342}
343
344/// Platform-specific sleep (works on both WASM and native).
345async fn sleep(secs: u64) {
346    #[cfg(target_arch = "wasm32")]
347    gloo_timers::future::TimeoutFuture::new((secs * 1000) as u32).await;
348
349    #[cfg(not(target_arch = "wasm32"))]
350    tokio::time::sleep(std::time::Duration::from_secs(secs)).await;
351}
352
// Platform-specific auth storage
#[cfg(target_arch = "wasm32")]
mod storage {
    //! Browser `localStorage` backend. Every operation is best-effort: missing
    //! storage, JS errors, and (de)serialization failures are silently ignored.
    use super::StoredAuth;

    /// `localStorage` key holding the auth blob for `app_name`.
    fn storage_key(app_name: &str) -> String {
        format!("{app_name}_auth")
    }

    /// The window's `localStorage`, or `None` when it is unavailable.
    fn local_storage() -> Option<web_sys::Storage> {
        web_sys::window()?.local_storage().ok().flatten()
    }

    pub fn save(app_name: &str, auth: &StoredAuth) {
        let Ok(json) = serde_json::to_string(auth) else {
            return;
        };
        let Some(store) = local_storage() else {
            return;
        };
        let _ = store.set_item(&storage_key(app_name), &json);
    }

    pub fn load(app_name: &str) -> Option<StoredAuth> {
        let json = local_storage()?
            .get_item(&storage_key(app_name))
            .ok()
            .flatten()?;
        serde_json::from_str(&json).ok()
    }

    pub fn clear(app_name: &str) {
        let Some(store) = local_storage() else {
            return;
        };
        let _ = store.remove_item(&storage_key(app_name));
    }
}
388
389#[cfg(not(target_arch = "wasm32"))]
390mod storage {
391    use super::StoredAuth;
392    use std::fs;
393    use std::path::PathBuf;
394
395    fn storage_path(app_name: &str) -> Option<PathBuf> {
396        dirs::data_local_dir().map(|base| base.join(app_name).join("auth.json"))
397    }
398
399    pub fn save(app_name: &str, auth: &StoredAuth) {
400        let Some(path) = storage_path(app_name) else {
401            return;
402        };
403        if let Some(parent) = path.parent() {
404            let _ = fs::create_dir_all(parent);
405        }
406        if let Ok(json) = serde_json::to_vec(auth) {
407            let tmp = path.with_extension("tmp");
408            let _ = fs::write(&tmp, json).and_then(|()| fs::rename(tmp, path));
409        }
410    }
411
412    pub fn load(app_name: &str) -> Option<StoredAuth> {
413        let path = storage_path(app_name)?;
414        let json = fs::read_to_string(path).ok()?;
415        serde_json::from_str(&json).ok()
416    }
417
418    pub fn clear(app_name: &str) {
419        if let Some(path) = storage_path(app_name) {
420            let _ = fs::remove_file(path);
421        }
422    }
423}