Skip to main content

forge_dioxus/
auth.rs

1//! Built-in auth + viewer state management for forge-dioxus.
2//!
3//! Handles token storage, viewer persistence, refresh loops, and 401
4//! recovery. Apps get viewer access for free without writing their own
5//! storage layer.
6
7use dioxus::prelude::*;
8use serde::de::DeserializeOwned;
9use serde::{Deserialize, Serialize};
10
11use crate::signals::{ForgeSignals, SignalsConfig, setup_auto_capture};
12use crate::{ConnectionState, ForgeClient, ForgeClientConfig};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
15struct StoredAuth {
16    access_token: String,
17    refresh_token: String,
18    viewer: Option<serde_json::Value>,
19}
20
21#[derive(Debug, Clone)]
22pub enum ForgeAuthState {
23    Unauthenticated,
24    Authenticated {
25        access_token: String,
26        refresh_token: String,
27        viewer: Option<serde_json::Value>,
28    },
29}
30
31impl ForgeAuthState {
32    pub fn is_authenticated(&self) -> bool {
33        matches!(self, Self::Authenticated { .. })
34    }
35
36    pub fn access_token(&self) -> Option<String> {
37        match self {
38            Self::Authenticated { access_token, .. } => Some(access_token.clone()),
39            Self::Unauthenticated => None,
40        }
41    }
42
43    pub fn refresh_token(&self) -> Option<String> {
44        match self {
45            Self::Authenticated { refresh_token, .. } => Some(refresh_token.clone()),
46            Self::Unauthenticated => None,
47        }
48    }
49
50    fn viewer_json(&self) -> Option<&serde_json::Value> {
51        match self {
52            Self::Authenticated { viewer, .. } => viewer.as_ref(),
53            Self::Unauthenticated => None,
54        }
55    }
56}
57
58/// Auth handle provided to components via `use_forge_auth()`.
59#[derive(Clone, Copy)]
60pub struct ForgeAuth {
61    state: Signal<ForgeAuthState>,
62    app_name: Signal<String>,
63    generation: Signal<u64>,
64}
65
66impl ForgeAuth {
67    pub fn is_authenticated(&self) -> bool {
68        self.state.read().is_authenticated()
69    }
70
71    pub fn access_token(&self) -> Option<String> {
72        self.state.read().access_token()
73    }
74
75    pub fn refresh_token(&self) -> Option<String> {
76        self.state.read().refresh_token()
77    }
78
79    /// Read the stored viewer, deserialized into the app's type.
80    pub fn viewer<V: DeserializeOwned>(&self) -> Option<V> {
81        let state = self.state.read();
82        let json = state.viewer_json()?;
83        serde_json::from_value(json.clone()).ok()
84    }
85
86    /// Set tokens after login/register (no viewer).
87    pub fn login(&mut self, access_token: String, refresh_token: String) {
88        self.save_and_set(access_token, refresh_token, None);
89    }
90
91    /// Set tokens + viewer after login/register.
92    pub fn login_with_viewer<V: Serialize>(
93        &mut self,
94        access_token: String,
95        refresh_token: String,
96        viewer: &V,
97    ) {
98        let viewer_json = serde_json::to_value(viewer).ok();
99        self.save_and_set(access_token, refresh_token, viewer_json);
100    }
101
102    /// Update tokens (e.g., after a refresh). Preserves existing viewer.
103    pub fn update_tokens(&mut self, access_token: String, refresh_token: String) {
104        let existing_viewer = self.state.read().viewer_json().cloned();
105        self.save_and_set(access_token, refresh_token, existing_viewer);
106    }
107
108    /// Update just the viewer without touching tokens.
109    pub fn update_viewer<V: Serialize>(&mut self, viewer: &V) {
110        let state = self.state.read();
111        let (access_token, refresh_token) = match &*state {
112            ForgeAuthState::Authenticated {
113                access_token,
114                refresh_token,
115                ..
116            } => (access_token.clone(), refresh_token.clone()),
117            ForgeAuthState::Unauthenticated => return,
118        };
119        drop(state);
120        let viewer_json = serde_json::to_value(viewer).ok();
121        self.save_and_set(access_token, refresh_token, viewer_json);
122    }
123
124    /// Clear tokens, viewer, and log out.
125    pub fn logout(&mut self) {
126        storage::clear(&self.app_name.read());
127        self.state.set(ForgeAuthState::Unauthenticated);
128        self.generation.with_mut(|g| *g += 1);
129    }
130
131    fn save_and_set(
132        &mut self,
133        access_token: String,
134        refresh_token: String,
135        viewer: Option<serde_json::Value>,
136    ) {
137        let stored = StoredAuth {
138            access_token: access_token.clone(),
139            refresh_token: refresh_token.clone(),
140            viewer: viewer.clone(),
141        };
142        storage::save(&self.app_name.read(), &stored);
143        let was_authenticated = self.state.read().is_authenticated();
144        self.state.set(ForgeAuthState::Authenticated {
145            access_token,
146            refresh_token,
147            viewer,
148        });
149        if !was_authenticated {
150            self.generation.with_mut(|g| *g += 1);
151        }
152    }
153}
154
155/// Read the auth handle from context.
156pub fn use_forge_auth() -> ForgeAuth {
157    use_context::<ForgeAuth>()
158}
159
160/// Read the stored viewer, deserialized into the app's viewer type.
161/// Returns `None` when unauthenticated or if the viewer hasn't been set.
162pub fn use_viewer<V: DeserializeOwned + Clone + 'static>() -> Option<V> {
163    use_forge_auth().viewer::<V>()
164}
165
166/// Returns a string key that changes on login/logout transitions.
167/// Use this to key your router or main content area so SSE subscriptions
168/// reconnect with fresh auth state.
169///
170/// ```ignore
171/// let auth_key = use_auth_key();
172/// rsx! { main { key: "{auth_key}", Router::<Route> {} } }
173/// ```
174pub fn use_auth_key() -> String {
175    let auth = use_forge_auth();
176    let generation = auth.generation.read();
177    format!("forge-auth-{generation}")
178}
179
/// Redirect guard for protected pages.
///
/// Returns `true` when the user is authenticated. Otherwise an effect replaces
/// the current route with `redirect_path`, and the hook returns `false`. The
/// effect reads the auth state, so it also re-runs if the user logs out later.
///
/// ```ignore
/// fn ProtectedPage() -> Element {
///     if !use_require_auth("/login") { return rsx! {} }
///     // ... render protected content
/// }
/// ```
#[cfg(feature = "router")]
pub fn use_require_auth(redirect_path: &str) -> bool {
    let auth = use_forge_auth();
    let navigator = use_navigator();
    let target = redirect_path.to_owned();

    use_effect(move || {
        if auth.is_authenticated() {
            return;
        }
        navigator.replace(NavigationTarget::Internal(target.clone()));
    });

    auth.is_authenticated()
}
203
204/// Provider component that sets up auth state, ForgeClient with auto token wiring,
205/// 401 detection, and periodic refresh.
206///
207/// ```ignore
208/// ForgeAuthProvider {
209///     url: "http://localhost:9081",
210///     app_name: "my-app",
211///     children: rsx! { Router::<Route> {} }
212/// }
213/// ```
214/// `refresh_interval_secs`: How often to proactively refresh tokens (default: 2400 = 40 min).
215/// Set to roughly 2/3 of your `access_token_ttl` from forge.toml.
216#[component]
217pub fn ForgeAuthProvider(
218    url: String,
219    #[props(default = "forge_app".to_string())] app_name: String,
220    #[props(default = 2400)] refresh_interval_secs: u64,
221    #[props(default)] on_mutation_error: Option<EventHandler<crate::ForgeClientError>>,
222    children: Element,
223) -> Element {
224    let initial = match storage::load(&app_name) {
225        Some(stored) => ForgeAuthState::Authenticated {
226            access_token: stored.access_token,
227            refresh_token: stored.refresh_token,
228            viewer: stored.viewer,
229        },
230        None => ForgeAuthState::Unauthenticated,
231    };
232
233    let auth_state = use_context_provider(|| Signal::new(initial));
234    let app_name_signal = use_context_provider(|| Signal::new(app_name));
235    let generation = use_context_provider(|| Signal::new(0_u64));
236    let forge_auth = use_context_provider(|| ForgeAuth {
237        state: auth_state,
238        app_name: app_name_signal,
239        generation,
240    });
241
242    let connection_state = use_context_provider(|| Signal::new(ConnectionState::Disconnected));
243    let needs_refresh = use_signal(|| false);
244
245    let url_clone = url.clone();
246    use_context_provider(move || {
247        let auth_for_token = auth_state;
248        let needs_refresh_clone = needs_refresh;
249        let mut config = ForgeClientConfig::new(url_clone)
250            .with_connection_state(connection_state)
251            .with_token_provider(move || auth_for_token.read().access_token())
252            .with_auth_error_handler(move |_err| {
253                let mut sig = needs_refresh_clone;
254                sig.set(true);
255            });
256        if let Some(handler) = on_mutation_error {
257            config = config.with_mutation_error_handler(move |err| handler.call(err));
258        }
259        ForgeClient::new(config)
260    });
261
262    // Must follow use_context_provider above so ForgeClient is available
263    let client: ForgeClient = use_context();
264
265    let url_for_refresh = url.clone();
266    let client_for_refresh = client.clone();
267    use_effect(move || {
268        if !*needs_refresh.read() {
269            return;
270        }
271        let url = url_for_refresh.clone();
272        let mut auth = forge_auth;
273        let client = client_for_refresh.clone();
274        let mut needs_refresh_sig = needs_refresh;
275        spawn(async move {
276            try_refresh_tokens(&url, &mut auth, &client).await;
277            needs_refresh_sig.set(false);
278        });
279    });
280
281    // Short sleeps (~30 s) instead of one long sleep so browser background-tab
282    // throttling can't push the refresh past token expiry.
283    let url_for_periodic = url;
284    let client_for_periodic = client.clone();
285    use_future(move || {
286        let url = url_for_periodic.clone();
287        let client = client_for_periodic.clone();
288        let mut auth = forge_auth;
289        async move {
290            let poll_secs: u64 = 30;
291            let mut elapsed: u64 = 0;
292            loop {
293                sleep(poll_secs).await;
294                elapsed += poll_secs;
295                if elapsed < refresh_interval_secs {
296                    continue;
297                }
298                elapsed = 0;
299                if auth.is_authenticated() {
300                    try_refresh_tokens(&url, &mut auth, &client).await;
301                }
302            }
303        }
304    });
305    let signals_instance = use_context_provider(|| {
306        let s = ForgeSignals::new(client.clone(), SignalsConfig::default());
307        client.set_signals(s.clone());
308        s
309    });
310    use_hook(|| {
311        setup_auto_capture(signals_instance);
312    });
313
314    rsx! { {children} }
315}
316
317/// Ignores network errors so transient connectivity issues (hospital
318/// networks, flaky wifi) don't force unnecessary logouts. Only logs out
319/// on definitive 401/403. Reconnects SSE on success so the new token
320/// takes effect immediately.
321async fn try_refresh_tokens(api_url: &str, auth: &mut ForgeAuth, client: &ForgeClient) -> bool {
322    let refresh_token = match auth.refresh_token() {
323        Some(t) => t,
324        None => return false,
325    };
326
327    let anon_client = ForgeClient::new(ForgeClientConfig::new(api_url.to_string()));
328
329    #[derive(Serialize)]
330    struct RefreshArgs {
331        refresh_token: String,
332    }
333
334    #[derive(Deserialize)]
335    struct RefreshResponse {
336        access_token: String,
337        refresh_token: String,
338    }
339
340    match anon_client
341        .call::<_, RefreshResponse>(
342            "refresh",
343            RefreshArgs {
344                refresh_token,
345            },
346        )
347        .await
348    {
349        Ok(resp) => {
350            auth.update_tokens(resp.access_token, resp.refresh_token);
351            client.reconnect_sse();
352            true
353        }
354        Err(ref e)
355            if e.code == "UNAUTHORIZED"
356                || e.code == "FORBIDDEN"
357                || e.code == "NOT_FOUND" =>
358        {
359            // Definitive auth failure: token is invalid/expired/revoked.
360            auth.logout();
361            false
362        }
363        Err(_) => {
364            // Network or transient error. Keep current tokens and retry
365            // on the next refresh cycle rather than forcing a logout.
366            false
367        }
368    }
369}
370
371async fn sleep(secs: u64) {
372    #[cfg(target_arch = "wasm32")]
373    gloo_timers::future::TimeoutFuture::new((secs * 1000) as u32).await;
374
375    #[cfg(not(target_arch = "wasm32"))]
376    tokio::time::sleep(std::time::Duration::from_secs(secs)).await;
377}
378
#[cfg(target_arch = "wasm32")]
mod storage {
    //! Browser `localStorage` persistence for the auth record.
    //!
    //! Every failure (no window, storage unavailable, write errors, malformed
    //! JSON) is swallowed. Saves and clears become no-ops, and loads yield `None`.

    use super::StoredAuth;

    /// Storage key under which `app_name`'s record lives.
    fn key(app_name: &str) -> String {
        format!("{app_name}_auth")
    }

    /// The window's `localStorage`, when present and accessible.
    fn browser_storage() -> Option<web_sys::Storage> {
        web_sys::window()?.local_storage().ok().flatten()
    }

    /// Serialize `auth` to JSON and write it under `app_name`'s key.
    pub fn save(app_name: &str, auth: &StoredAuth) {
        let Ok(json) = serde_json::to_string(auth) else {
            return;
        };
        if let Some(store) = browser_storage() {
            let _ = store.set_item(&key(app_name), &json);
        }
    }

    /// Read and parse `app_name`'s record, if one exists and is valid JSON.
    pub fn load(app_name: &str) -> Option<StoredAuth> {
        let json = browser_storage()?.get_item(&key(app_name)).ok().flatten()?;
        serde_json::from_str(&json).ok()
    }

    /// Remove `app_name`'s record.
    pub fn clear(app_name: &str) {
        if let Some(store) = browser_storage() {
            let _ = store.remove_item(&key(app_name));
        }
    }
}
413
414#[cfg(not(target_arch = "wasm32"))]
415mod storage {
416    use super::StoredAuth;
417    use std::fs;
418    use std::path::PathBuf;
419
420    fn storage_path(app_name: &str) -> Option<PathBuf> {
421        dirs::data_local_dir().map(|base| base.join(app_name).join("auth.json"))
422    }
423
424    pub fn save(app_name: &str, auth: &StoredAuth) {
425        let Some(path) = storage_path(app_name) else {
426            return;
427        };
428        if let Some(parent) = path.parent() {
429            let _ = fs::create_dir_all(parent);
430        }
431        if let Ok(json) = serde_json::to_vec(auth) {
432            let tmp = path.with_extension("tmp");
433            let _ = fs::write(&tmp, json).and_then(|()| fs::rename(tmp, path));
434        }
435    }
436
437    pub fn load(app_name: &str) -> Option<StoredAuth> {
438        let path = storage_path(app_name)?;
439        let json = fs::read_to_string(path).ok()?;
440        serde_json::from_str(&json).ok()
441    }
442
443    pub fn clear(app_name: &str) {
444        if let Some(path) = storage_path(app_name) {
445            let _ = fs::remove_file(path);
446        }
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_authenticated_state_exposes_tokens_and_viewer() {
        let viewer = serde_json::json!({"id": "user-1", "role": "admin"});
        let state = ForgeAuthState::Authenticated {
            access_token: "access-token".into(),
            refresh_token: "refresh-token".into(),
            viewer: Some(viewer.clone()),
        };

        assert!(state.is_authenticated());
        assert_eq!(state.access_token().as_deref(), Some("access-token"));
        assert_eq!(state.refresh_token().as_deref(), Some("refresh-token"));
        assert_eq!(state.viewer_json(), Some(&viewer));
    }

    #[test]
    fn test_authenticated_state_without_viewer() {
        let state = ForgeAuthState::Authenticated {
            access_token: "access-token".into(),
            refresh_token: "refresh-token".into(),
            viewer: None,
        };

        assert!(state.is_authenticated());
        assert_eq!(state.access_token().as_deref(), Some("access-token"));
        assert_eq!(state.refresh_token().as_deref(), Some("refresh-token"));
        assert!(state.viewer_json().is_none());
    }

    #[test]
    fn test_unauthenticated_state_has_no_auth_material() {
        let state = ForgeAuthState::Unauthenticated;

        assert!(!state.is_authenticated());
        assert!(state.access_token().is_none());
        assert!(state.refresh_token().is_none());
        assert!(state.viewer_json().is_none());
    }

    #[test]
    fn test_stored_auth_round_trips_through_json() {
        let viewer = serde_json::json!({"id": "user-1"});
        let stored = StoredAuth {
            access_token: "a".into(),
            refresh_token: "r".into(),
            viewer: Some(viewer.clone()),
        };

        let encoded = serde_json::to_string(&stored).expect("serialize StoredAuth");
        let decoded: StoredAuth = serde_json::from_str(&encoded).expect("deserialize StoredAuth");

        assert_eq!(decoded.access_token, "a");
        assert_eq!(decoded.refresh_token, "r");
        assert_eq!(decoded.viewer, Some(viewer));
    }

    #[test]
    fn test_stored_auth_without_viewer_field_loads_as_none() {
        let decoded: StoredAuth =
            serde_json::from_str(r#"{"access_token":"a","refresh_token":"r"}"#)
                .expect("records without a viewer must still load");

        assert_eq!(decoded.access_token, "a");
        assert_eq!(decoded.refresh_token, "r");
        assert!(decoded.viewer.is_none());
    }
}