1use dioxus::prelude::*;
8use serde::de::DeserializeOwned;
9use serde::{Deserialize, Serialize};
10
11use crate::signals::{ForgeSignals, SignalsConfig, setup_auto_capture};
12use crate::{ConnectionState, ForgeClient, ForgeClientConfig};
13
/// On-disk / localStorage representation of an authenticated session.
///
/// This is the value serialized by `storage::save` and deserialized by
/// `storage::load`; it mirrors the fields of `ForgeAuthState::Authenticated`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct StoredAuth {
    /// Access token of the persisted session.
    access_token: String,
    /// Refresh token of the persisted session.
    refresh_token: String,
    /// Optional viewer payload, kept as untyped JSON.
    viewer: Option<serde_json::Value>,
}
20
/// Authentication state held by [`ForgeAuth`].
#[derive(Debug, Clone)]
pub enum ForgeAuthState {
    /// No session: there are no tokens and no viewer.
    Unauthenticated,
    /// An active session.
    Authenticated {
        /// Token returned by `access_token()`.
        access_token: String,
        /// Token returned by `refresh_token()`.
        refresh_token: String,
        /// Optional viewer payload stored as untyped JSON; `None` when the
        /// session was created without a viewer.
        viewer: Option<serde_json::Value>,
    },
}
30
31impl ForgeAuthState {
32 pub fn is_authenticated(&self) -> bool {
33 matches!(self, Self::Authenticated { .. })
34 }
35
36 pub fn access_token(&self) -> Option<String> {
37 match self {
38 Self::Authenticated { access_token, .. } => Some(access_token.clone()),
39 Self::Unauthenticated => None,
40 }
41 }
42
43 pub fn refresh_token(&self) -> Option<String> {
44 match self {
45 Self::Authenticated { refresh_token, .. } => Some(refresh_token.clone()),
46 Self::Unauthenticated => None,
47 }
48 }
49
50 fn viewer_json(&self) -> Option<&serde_json::Value> {
51 match self {
52 Self::Authenticated { viewer, .. } => viewer.as_ref(),
53 Self::Unauthenticated => None,
54 }
55 }
56}
57
/// Copyable handle to the shared authentication state.
///
/// All fields are signals, so copies of this handle refer to the same
/// underlying state. `ForgeAuthProvider` creates one and provides it through
/// context; `use_forge_auth` retrieves it.
#[derive(Clone, Copy)]
pub struct ForgeAuth {
    /// Current authentication state.
    state: Signal<ForgeAuthState>,
    /// Application name passed to the `storage` module to locate the persisted session.
    app_name: Signal<String>,
    /// Counter incremented on logout and on each unauthenticated -> authenticated
    /// transition; read by `use_auth_key`.
    generation: Signal<u64>,
}
65
66impl ForgeAuth {
67 pub fn is_authenticated(&self) -> bool {
68 self.state.read().is_authenticated()
69 }
70
71 pub fn access_token(&self) -> Option<String> {
72 self.state.read().access_token()
73 }
74
75 pub fn refresh_token(&self) -> Option<String> {
76 self.state.read().refresh_token()
77 }
78
79 pub fn viewer<V: DeserializeOwned>(&self) -> Option<V> {
81 let state = self.state.read();
82 let json = state.viewer_json()?;
83 serde_json::from_value(json.clone()).ok()
84 }
85
86 pub fn login(&mut self, access_token: String, refresh_token: String) {
88 self.save_and_set(access_token, refresh_token, None);
89 }
90
91 pub fn login_with_viewer<V: Serialize>(
93 &mut self,
94 access_token: String,
95 refresh_token: String,
96 viewer: &V,
97 ) {
98 let viewer_json = serde_json::to_value(viewer).ok();
99 self.save_and_set(access_token, refresh_token, viewer_json);
100 }
101
102 pub fn update_tokens(&mut self, access_token: String, refresh_token: String) {
104 let existing_viewer = self.state.read().viewer_json().cloned();
105 self.save_and_set(access_token, refresh_token, existing_viewer);
106 }
107
108 pub fn update_viewer<V: Serialize>(&mut self, viewer: &V) {
110 let state = self.state.read();
111 let (access_token, refresh_token) = match &*state {
112 ForgeAuthState::Authenticated {
113 access_token,
114 refresh_token,
115 ..
116 } => (access_token.clone(), refresh_token.clone()),
117 ForgeAuthState::Unauthenticated => return,
118 };
119 drop(state);
120 let viewer_json = serde_json::to_value(viewer).ok();
121 self.save_and_set(access_token, refresh_token, viewer_json);
122 }
123
124 pub fn logout(&mut self) {
126 storage::clear(&self.app_name.read());
127 self.state.set(ForgeAuthState::Unauthenticated);
128 self.generation.with_mut(|g| *g += 1);
129 }
130
131 fn save_and_set(
132 &mut self,
133 access_token: String,
134 refresh_token: String,
135 viewer: Option<serde_json::Value>,
136 ) {
137 let stored = StoredAuth {
138 access_token: access_token.clone(),
139 refresh_token: refresh_token.clone(),
140 viewer: viewer.clone(),
141 };
142 storage::save(&self.app_name.read(), &stored);
143 let was_authenticated = self.state.read().is_authenticated();
144 self.state.set(ForgeAuthState::Authenticated {
145 access_token,
146 refresh_token,
147 viewer,
148 });
149 if !was_authenticated {
150 self.generation.with_mut(|g| *g += 1);
151 }
152 }
153}
154
155pub fn use_forge_auth() -> ForgeAuth {
157 use_context::<ForgeAuth>()
158}
159
160pub fn use_viewer<V: DeserializeOwned + Clone + 'static>() -> Option<V> {
163 use_forge_auth().viewer::<V>()
164}
165
166pub fn use_auth_key() -> String {
175 let auth = use_forge_auth();
176 let generation = auth.generation.read();
177 format!("forge-auth-{generation}")
178}
179
/// Redirects to `redirect_path` whenever the user is not authenticated.
///
/// Registers an effect that replaces the current route with the internal
/// target `redirect_path` while there is no session, and returns whether the
/// user is authenticated at the time of the call.
#[cfg(feature = "router")]
pub fn use_require_auth(redirect_path: &str) -> bool {
    let auth = use_forge_auth();
    let nav = use_navigator();
    let target = redirect_path.to_owned();

    use_effect(move || {
        if auth.is_authenticated() {
            return;
        }
        nav.replace(NavigationTarget::Internal(target.clone()));
    });

    let authenticated = auth.is_authenticated();
    authenticated
}
203
204#[component]
217pub fn ForgeAuthProvider(
218 url: String,
219 #[props(default = "forge_app".to_string())] app_name: String,
220 #[props(default = 2400)] refresh_interval_secs: u64,
221 #[props(default)] on_mutation_error: Option<EventHandler<crate::ForgeClientError>>,
222 children: Element,
223) -> Element {
224 let initial = match storage::load(&app_name) {
225 Some(stored) => ForgeAuthState::Authenticated {
226 access_token: stored.access_token,
227 refresh_token: stored.refresh_token,
228 viewer: stored.viewer,
229 },
230 None => ForgeAuthState::Unauthenticated,
231 };
232
233 let auth_state = use_context_provider(|| Signal::new(initial));
234 let app_name_signal = use_context_provider(|| Signal::new(app_name));
235 let generation = use_context_provider(|| Signal::new(0_u64));
236 let forge_auth = use_context_provider(|| ForgeAuth {
237 state: auth_state,
238 app_name: app_name_signal,
239 generation,
240 });
241
242 let connection_state = use_context_provider(|| Signal::new(ConnectionState::Disconnected));
243 let needs_refresh = use_signal(|| false);
244
245 let url_clone = url.clone();
246 use_context_provider(move || {
247 let auth_for_token = auth_state;
248 let needs_refresh_clone = needs_refresh;
249 let mut config = ForgeClientConfig::new(url_clone)
250 .with_connection_state(connection_state)
251 .with_token_provider(move || auth_for_token.read().access_token())
252 .with_auth_error_handler(move |_err| {
253 let mut sig = needs_refresh_clone;
254 sig.set(true);
255 });
256 if let Some(handler) = on_mutation_error {
257 config = config.with_mutation_error_handler(move |err| handler.call(err));
258 }
259 ForgeClient::new(config)
260 });
261
262 let client: ForgeClient = use_context();
264
265 let url_for_refresh = url.clone();
266 let client_for_refresh = client.clone();
267 use_effect(move || {
268 if !*needs_refresh.read() {
269 return;
270 }
271 let url = url_for_refresh.clone();
272 let mut auth = forge_auth;
273 let client = client_for_refresh.clone();
274 let mut needs_refresh_sig = needs_refresh;
275 spawn(async move {
276 try_refresh_tokens(&url, &mut auth, &client).await;
277 needs_refresh_sig.set(false);
278 });
279 });
280
281 let url_for_periodic = url;
284 let client_for_periodic = client.clone();
285 use_future(move || {
286 let url = url_for_periodic.clone();
287 let client = client_for_periodic.clone();
288 let mut auth = forge_auth;
289 async move {
290 let poll_secs: u64 = 30;
291 let mut elapsed: u64 = 0;
292 loop {
293 sleep(poll_secs).await;
294 elapsed += poll_secs;
295 if elapsed < refresh_interval_secs {
296 continue;
297 }
298 elapsed = 0;
299 if auth.is_authenticated() {
300 try_refresh_tokens(&url, &mut auth, &client).await;
301 }
302 }
303 }
304 });
305 let signals_instance = use_context_provider(|| {
306 let s = ForgeSignals::new(client.clone(), SignalsConfig::default());
307 client.set_signals(s.clone());
308 s
309 });
310 use_hook(|| {
311 setup_auto_capture(signals_instance);
312 });
313
314 rsx! { {children} }
315}
316
317async fn try_refresh_tokens(api_url: &str, auth: &mut ForgeAuth, client: &ForgeClient) -> bool {
322 let refresh_token = match auth.refresh_token() {
323 Some(t) => t,
324 None => return false,
325 };
326
327 let anon_client = ForgeClient::new(ForgeClientConfig::new(api_url.to_string()));
328
329 #[derive(Serialize)]
330 struct RefreshArgs {
331 refresh_token: String,
332 }
333
334 #[derive(Deserialize)]
335 struct RefreshResponse {
336 access_token: String,
337 refresh_token: String,
338 }
339
340 match anon_client
341 .call::<_, RefreshResponse>(
342 "refresh",
343 RefreshArgs {
344 refresh_token,
345 },
346 )
347 .await
348 {
349 Ok(resp) => {
350 auth.update_tokens(resp.access_token, resp.refresh_token);
351 client.reconnect_sse();
352 true
353 }
354 Err(ref e)
355 if e.code == "UNAUTHORIZED"
356 || e.code == "FORBIDDEN"
357 || e.code == "NOT_FOUND" =>
358 {
359 auth.logout();
361 false
362 }
363 Err(_) => {
364 false
367 }
368 }
369}
370
371async fn sleep(secs: u64) {
372 #[cfg(target_arch = "wasm32")]
373 gloo_timers::future::TimeoutFuture::new((secs * 1000) as u32).await;
374
375 #[cfg(not(target_arch = "wasm32"))]
376 tokio::time::sleep(std::time::Duration::from_secs(secs)).await;
377}
378
/// Browser persistence for the session, backed by `window.localStorage`.
///
/// Every operation is best-effort: a missing window, unavailable storage,
/// or a failing storage call is silently ignored.
#[cfg(target_arch = "wasm32")]
mod storage {
    use super::StoredAuth;

    /// Storage key for `app_name`: the name followed by `_auth`.
    fn key(app_name: &str) -> String {
        [app_name, "_auth"].concat()
    }

    /// Returns the window's local storage, if it is available.
    fn local_storage() -> Option<web_sys::Storage> {
        let window = web_sys::window()?;
        window.local_storage().ok().flatten()
    }

    /// Serializes `auth` to JSON and stores it under the app's key.
    pub fn save(app_name: &str, auth: &StoredAuth) {
        let Ok(json) = serde_json::to_string(auth) else {
            return;
        };
        let Some(store) = local_storage() else {
            return;
        };
        let _ = store.set_item(&key(app_name), &json);
    }

    /// Loads and parses the stored session; `None` if absent or unparsable.
    pub fn load(app_name: &str) -> Option<StoredAuth> {
        let store = local_storage()?;
        let raw = store.get_item(&key(app_name)).ok().flatten()?;
        serde_json::from_str(&raw).ok()
    }

    /// Removes the stored session, if any.
    pub fn clear(app_name: &str) {
        let Some(store) = local_storage() else {
            return;
        };
        let _ = store.remove_item(&key(app_name));
    }
}
413
414#[cfg(not(target_arch = "wasm32"))]
415mod storage {
416 use super::StoredAuth;
417 use std::fs;
418 use std::path::PathBuf;
419
420 fn storage_path(app_name: &str) -> Option<PathBuf> {
421 dirs::data_local_dir().map(|base| base.join(app_name).join("auth.json"))
422 }
423
424 pub fn save(app_name: &str, auth: &StoredAuth) {
425 let Some(path) = storage_path(app_name) else {
426 return;
427 };
428 if let Some(parent) = path.parent() {
429 let _ = fs::create_dir_all(parent);
430 }
431 if let Ok(json) = serde_json::to_vec(auth) {
432 let tmp = path.with_extension("tmp");
433 let _ = fs::write(&tmp, json).and_then(|()| fs::rename(tmp, path));
434 }
435 }
436
437 pub fn load(app_name: &str) -> Option<StoredAuth> {
438 let path = storage_path(app_name)?;
439 let json = fs::read_to_string(path).ok()?;
440 serde_json::from_str(&json).ok()
441 }
442
443 pub fn clear(app_name: &str) {
444 if let Some(path) = storage_path(app_name) {
445 let _ = fs::remove_file(path);
446 }
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// An authenticated state hands back exactly the tokens and viewer it was built with.
    #[test]
    fn test_authenticated_state_exposes_tokens_and_viewer() {
        let expected_viewer = serde_json::json!({"id": "user-1", "role": "admin"});
        let state = ForgeAuthState::Authenticated {
            access_token: String::from("access-token"),
            refresh_token: String::from("refresh-token"),
            viewer: Some(expected_viewer.clone()),
        };

        assert!(state.is_authenticated());
        assert_eq!(state.access_token(), Some(String::from("access-token")));
        assert_eq!(state.refresh_token(), Some(String::from("refresh-token")));
        assert_eq!(state.viewer_json(), Some(&expected_viewer));
    }

    /// The unauthenticated state exposes no tokens and no viewer.
    #[test]
    fn test_unauthenticated_state_has_no_auth_material() {
        let state = ForgeAuthState::Unauthenticated;

        assert!(!state.is_authenticated());
        assert_eq!(state.access_token(), None);
        assert_eq!(state.refresh_token(), None);
        assert_eq!(state.viewer_json(), None);
    }
}