Skip to main content

forge_dioxus/
auth.rs

1//! Built-in auth + viewer state management for forge-dioxus.
2//!
3//! Handles token storage, viewer persistence, refresh loops, and 401
4//! recovery. Apps get viewer access for free without writing their own
5//! storage layer.
6
7use dioxus::prelude::*;
8use serde::de::DeserializeOwned;
9use serde::{Deserialize, Serialize};
10
11use crate::signals::{ForgeSignals, SignalsConfig, setup_auto_capture};
12use crate::{ConnectionState, ForgeClient, ForgeClientConfig};
13
14/// Persisted auth data: tokens + optional viewer.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16struct StoredAuth {
17    access_token: String,
18    refresh_token: String,
19    viewer: Option<serde_json::Value>,
20}
21
22/// Auth state tracked by the framework.
23#[derive(Debug, Clone)]
24pub enum ForgeAuthState {
25    Unauthenticated,
26    Authenticated {
27        access_token: String,
28        refresh_token: String,
29        viewer: Option<serde_json::Value>,
30    },
31}
32
33impl ForgeAuthState {
34    pub fn is_authenticated(&self) -> bool {
35        matches!(self, Self::Authenticated { .. })
36    }
37
38    pub fn access_token(&self) -> Option<String> {
39        match self {
40            Self::Authenticated { access_token, .. } => Some(access_token.clone()),
41            Self::Unauthenticated => None,
42        }
43    }
44
45    pub fn refresh_token(&self) -> Option<String> {
46        match self {
47            Self::Authenticated { refresh_token, .. } => Some(refresh_token.clone()),
48            Self::Unauthenticated => None,
49        }
50    }
51
52    fn viewer_json(&self) -> Option<&serde_json::Value> {
53        match self {
54            Self::Authenticated { viewer, .. } => viewer.as_ref(),
55            Self::Unauthenticated => None,
56        }
57    }
58}
59
60/// Auth handle provided to components via `use_forge_auth()`.
61#[derive(Clone, Copy)]
62pub struct ForgeAuth {
63    state: Signal<ForgeAuthState>,
64    app_name: Signal<String>,
65    generation: Signal<u64>,
66}
67
68impl ForgeAuth {
69    pub fn is_authenticated(&self) -> bool {
70        self.state.read().is_authenticated()
71    }
72
73    pub fn access_token(&self) -> Option<String> {
74        self.state.read().access_token()
75    }
76
77    pub fn refresh_token(&self) -> Option<String> {
78        self.state.read().refresh_token()
79    }
80
81    /// Read the stored viewer, deserialized into the app's type.
82    pub fn viewer<V: DeserializeOwned>(&self) -> Option<V> {
83        let state = self.state.read();
84        let json = state.viewer_json()?;
85        serde_json::from_value(json.clone()).ok()
86    }
87
88    /// Set tokens after login/register (no viewer).
89    pub fn login(&mut self, access_token: String, refresh_token: String) {
90        self.save_and_set(access_token, refresh_token, None);
91    }
92
93    /// Set tokens + viewer after login/register.
94    pub fn login_with_viewer<V: Serialize>(
95        &mut self,
96        access_token: String,
97        refresh_token: String,
98        viewer: &V,
99    ) {
100        let viewer_json = serde_json::to_value(viewer).ok();
101        self.save_and_set(access_token, refresh_token, viewer_json);
102    }
103
104    /// Update tokens (e.g., after a refresh). Preserves existing viewer.
105    pub fn update_tokens(&mut self, access_token: String, refresh_token: String) {
106        let existing_viewer = self.state.read().viewer_json().cloned();
107        self.save_and_set(access_token, refresh_token, existing_viewer);
108    }
109
110    /// Update just the viewer without touching tokens.
111    pub fn update_viewer<V: Serialize>(&mut self, viewer: &V) {
112        let state = self.state.read();
113        let (access_token, refresh_token) = match &*state {
114            ForgeAuthState::Authenticated {
115                access_token,
116                refresh_token,
117                ..
118            } => (access_token.clone(), refresh_token.clone()),
119            ForgeAuthState::Unauthenticated => return,
120        };
121        drop(state);
122        let viewer_json = serde_json::to_value(viewer).ok();
123        self.save_and_set(access_token, refresh_token, viewer_json);
124    }
125
126    /// Clear tokens, viewer, and log out.
127    pub fn logout(&mut self) {
128        storage::clear(&self.app_name.read());
129        self.state.set(ForgeAuthState::Unauthenticated);
130        self.generation.with_mut(|g| *g += 1);
131    }
132
133    fn save_and_set(
134        &mut self,
135        access_token: String,
136        refresh_token: String,
137        viewer: Option<serde_json::Value>,
138    ) {
139        let stored = StoredAuth {
140            access_token: access_token.clone(),
141            refresh_token: refresh_token.clone(),
142            viewer: viewer.clone(),
143        };
144        storage::save(&self.app_name.read(), &stored);
145        let was_authenticated = self.state.read().is_authenticated();
146        self.state.set(ForgeAuthState::Authenticated {
147            access_token,
148            refresh_token,
149            viewer,
150        });
151        if !was_authenticated {
152            self.generation.with_mut(|g| *g += 1);
153        }
154    }
155}
156
157/// Read the auth handle from context.
158pub fn use_forge_auth() -> ForgeAuth {
159    use_context::<ForgeAuth>()
160}
161
162/// Read the stored viewer, deserialized into the app's viewer type.
163/// Returns `None` when unauthenticated or if the viewer hasn't been set.
164pub fn use_viewer<V: DeserializeOwned + Clone + 'static>() -> Option<V> {
165    use_forge_auth().viewer::<V>()
166}
167
168/// Returns a string key that changes on login/logout transitions.
169/// Use this to key your router or main content area so SSE subscriptions
170/// reconnect with fresh auth state.
171///
172/// ```ignore
173/// let auth_key = use_auth_key();
174/// rsx! { main { key: "{auth_key}", Router::<Route> {} } }
175/// ```
176pub fn use_auth_key() -> String {
177    let auth = use_forge_auth();
178    let generation = auth.generation.read();
179    format!("forge-auth-{generation}")
180}
181
/// Guard hook that sends unauthenticated users to `redirect_path`.
///
/// Returns `true` when the session is authenticated. Otherwise it returns
/// `false` and an effect replaces the current route with `redirect_path`, so
/// the caller should render nothing.
///
/// ```ignore
/// fn ProtectedPage() -> Element {
///     if !use_require_auth("/login") { return rsx! {} }
///     // ... render protected content
/// }
/// ```
#[cfg(feature = "router")]
pub fn use_require_auth(redirect_path: &str) -> bool {
    let auth = use_forge_auth();
    let navigator = use_navigator();
    let login_route = redirect_path.to_owned();

    use_effect(move || {
        // Reading auth state inside the effect ties its re-runs to auth changes.
        if auth.is_authenticated() {
            return;
        }
        navigator.replace(NavigationTarget::Internal(login_route.clone()));
    });

    auth.is_authenticated()
}
205
206/// Provider component that sets up auth state, ForgeClient with auto token wiring,
207/// 401 detection, and periodic refresh.
208///
209/// ```ignore
210/// ForgeAuthProvider {
211///     url: "http://localhost:9081",
212///     app_name: "my-app",
213///     children: rsx! { Router::<Route> {} }
214/// }
215/// ```
216/// `refresh_interval_secs`: How often to proactively refresh tokens (default: 2400 = 40 min).
217/// Set to roughly 2/3 of your `access_token_ttl` from forge.toml.
218#[component]
219pub fn ForgeAuthProvider(
220    url: String,
221    #[props(default = "forge_app".to_string())] app_name: String,
222    #[props(default = 2400)] refresh_interval_secs: u64,
223    children: Element,
224) -> Element {
225    let initial = match storage::load(&app_name) {
226        Some(stored) => ForgeAuthState::Authenticated {
227            access_token: stored.access_token,
228            refresh_token: stored.refresh_token,
229            viewer: stored.viewer,
230        },
231        None => ForgeAuthState::Unauthenticated,
232    };
233
234    let auth_state = use_context_provider(|| Signal::new(initial));
235    let app_name_signal = use_context_provider(|| Signal::new(app_name));
236    let generation = use_context_provider(|| Signal::new(0_u64));
237    let forge_auth = use_context_provider(|| ForgeAuth {
238        state: auth_state,
239        app_name: app_name_signal,
240        generation,
241    });
242
243    let connection_state = use_context_provider(|| Signal::new(ConnectionState::Disconnected));
244    let needs_refresh = use_signal(|| false);
245
246    // Build ForgeClient with auto token provider and auth error handler
247    let url_clone = url.clone();
248    use_context_provider(move || {
249        let auth_for_token = auth_state;
250        let needs_refresh_clone = needs_refresh;
251        let config = ForgeClientConfig::new(url_clone)
252            .with_connection_state(connection_state)
253            .with_token_provider(move || auth_for_token.read().access_token())
254            .with_auth_error_handler(move |_err| {
255                let mut sig = needs_refresh_clone;
256                sig.set(true);
257            });
258        ForgeClient::new(config)
259    });
260
261    // Handle 401 errors by attempting token refresh
262    let url_for_refresh = url.clone();
263    use_effect(move || {
264        if !*needs_refresh.read() {
265            return;
266        }
267        let url = url_for_refresh.clone();
268        let mut auth = forge_auth;
269        spawn(async move {
270            try_refresh_tokens(&url, &mut auth).await;
271        });
272    });
273
274    // Periodic refresh (default every 40 minutes, configurable via refresh_interval_secs)
275    let url_for_periodic = url;
276    use_future(move || {
277        let url = url_for_periodic.clone();
278        let mut auth = forge_auth;
279        async move {
280            loop {
281                sleep(refresh_interval_secs).await;
282                if auth.is_authenticated() {
283                    try_refresh_tokens(&url, &mut auth).await;
284                }
285            }
286        }
287    });
288
289    // Initialize signals (must come after ForgeClient is provided)
290    let client: ForgeClient = use_context();
291    let signals_instance = use_context_provider(|| {
292        let s = ForgeSignals::new(client.clone(), SignalsConfig::default());
293        client.set_signals(s.clone());
294        s
295    });
296    use_hook(|| {
297        setup_auto_capture(signals_instance);
298    });
299
300    rsx! { {children} }
301}
302
303/// Attempt to refresh tokens using an anonymous client.
304///
305/// Only logs out on definitive auth failures (401/403). Network errors
306/// are silently ignored so transient connectivity issues in hospital
307/// networks don't force unnecessary logouts.
308async fn try_refresh_tokens(api_url: &str, auth: &mut ForgeAuth) {
309    let refresh_token = match auth.refresh_token() {
310        Some(t) => t,
311        None => return,
312    };
313
314    let anon_client = ForgeClient::new(ForgeClientConfig::new(api_url.to_string()));
315
316    #[derive(Serialize)]
317    struct RefreshArgs {
318        refresh_token: String,
319    }
320
321    #[derive(Deserialize)]
322    struct RefreshResponse {
323        access_token: String,
324        refresh_token: String,
325    }
326
327    match anon_client
328        .call::<_, RefreshResponse>(
329            "refresh",
330            RefreshArgs {
331                refresh_token,
332            },
333        )
334        .await
335    {
336        Ok(resp) => {
337            auth.update_tokens(resp.access_token, resp.refresh_token);
338        }
339        Err(ref e)
340            if e.code == "UNAUTHORIZED"
341                || e.code == "FORBIDDEN"
342                || e.code == "NOT_FOUND" =>
343        {
344            // Definitive auth failure: token is invalid/expired/revoked.
345            auth.logout();
346        }
347        Err(_) => {
348            // Network or transient error. Keep current tokens and retry
349            // on the next refresh cycle rather than forcing a logout.
350        }
351    }
352}
353
354/// Platform-specific sleep (works on both WASM and native).
355async fn sleep(secs: u64) {
356    #[cfg(target_arch = "wasm32")]
357    gloo_timers::future::TimeoutFuture::new((secs * 1000) as u32).await;
358
359    #[cfg(not(target_arch = "wasm32"))]
360    tokio::time::sleep(std::time::Duration::from_secs(secs)).await;
361}
362
// Platform-specific auth storage
#[cfg(target_arch = "wasm32")]
mod storage {
    //! Browser `localStorage` backend. All operations are best-effort: a
    //! missing window, blocked storage, or failed write is silently ignored.

    use super::StoredAuth;

    /// `localStorage` key for `app_name`, e.g. `my-app_auth`.
    fn key(app_name: &str) -> String {
        format!("{}_auth", app_name)
    }

    /// The page's `localStorage`, or `None` when it is unavailable.
    fn local_storage() -> Option<web_sys::Storage> {
        web_sys::window()?.local_storage().ok().flatten()
    }

    pub fn save(app_name: &str, auth: &StoredAuth) {
        let Ok(json) = serde_json::to_string(auth) else {
            return;
        };
        if let Some(store) = local_storage() {
            let _ = store.set_item(&key(app_name), &json);
        }
    }

    pub fn load(app_name: &str) -> Option<StoredAuth> {
        let json = local_storage()?.get_item(&key(app_name)).ok().flatten()?;
        serde_json::from_str(&json).ok()
    }

    pub fn clear(app_name: &str) {
        if let Some(store) = local_storage() {
            let _ = store.remove_item(&key(app_name));
        }
    }
}
398
399#[cfg(not(target_arch = "wasm32"))]
400mod storage {
401    use super::StoredAuth;
402    use std::fs;
403    use std::path::PathBuf;
404
405    fn storage_path(app_name: &str) -> Option<PathBuf> {
406        dirs::data_local_dir().map(|base| base.join(app_name).join("auth.json"))
407    }
408
409    pub fn save(app_name: &str, auth: &StoredAuth) {
410        let Some(path) = storage_path(app_name) else {
411            return;
412        };
413        if let Some(parent) = path.parent() {
414            let _ = fs::create_dir_all(parent);
415        }
416        if let Ok(json) = serde_json::to_vec(auth) {
417            let tmp = path.with_extension("tmp");
418            let _ = fs::write(&tmp, json).and_then(|()| fs::rename(tmp, path));
419        }
420    }
421
422    pub fn load(app_name: &str) -> Option<StoredAuth> {
423        let path = storage_path(app_name)?;
424        let json = fs::read_to_string(path).ok()?;
425        serde_json::from_str(&json).ok()
426    }
427
428    pub fn clear(app_name: &str) {
429        if let Some(path) = storage_path(app_name) {
430            let _ = fs::remove_file(path);
431        }
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_authenticated_state_exposes_tokens_and_viewer() {
        let viewer = serde_json::json!({"id": "user-1", "role": "admin"});
        let state = ForgeAuthState::Authenticated {
            access_token: "access-token".into(),
            refresh_token: "refresh-token".into(),
            viewer: Some(viewer.clone()),
        };

        assert!(state.is_authenticated());
        assert_eq!(state.access_token().as_deref(), Some("access-token"));
        assert_eq!(state.refresh_token().as_deref(), Some("refresh-token"));
        assert_eq!(state.viewer_json(), Some(&viewer));
    }

    #[test]
    fn test_unauthenticated_state_has_no_auth_material() {
        let state = ForgeAuthState::Unauthenticated;

        assert!(!state.is_authenticated());
        assert!(state.access_token().is_none());
        assert!(state.refresh_token().is_none());
        assert!(state.viewer_json().is_none());
    }

    #[test]
    fn test_stored_auth_round_trips_through_json() {
        let stored = StoredAuth {
            access_token: "access-token".into(),
            refresh_token: "refresh-token".into(),
            viewer: Some(serde_json::json!({"id": "user-1"})),
        };

        let json = serde_json::to_string(&stored).expect("StoredAuth serializes");
        let loaded: StoredAuth = serde_json::from_str(&json).expect("StoredAuth deserializes");

        assert_eq!(loaded.access_token, "access-token");
        assert_eq!(loaded.refresh_token, "refresh-token");
        assert_eq!(loaded.viewer, stored.viewer);
    }

    #[test]
    fn test_stored_auth_without_viewer_key_loads_as_none() {
        // The optional `viewer` key may be absent from a persisted payload.
        let loaded: StoredAuth =
            serde_json::from_str(r#"{"access_token":"a","refresh_token":"r"}"#)
                .expect("payload without viewer deserializes");

        assert_eq!(loaded.access_token, "a");
        assert_eq!(loaded.refresh_token, "r");
        assert!(loaded.viewer.is_none());
    }
}