1use dioxus::prelude::*;
8use serde::de::DeserializeOwned;
9use serde::{Deserialize, Serialize};
10
11use crate::signals::{ForgeSignals, SignalsConfig, setup_auto_capture};
12use crate::{ConnectionState, ForgeClient, ForgeClientConfig};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16struct StoredAuth {
17 access_token: String,
18 refresh_token: String,
19 viewer: Option<serde_json::Value>,
20}
21
22#[derive(Debug, Clone)]
24pub enum ForgeAuthState {
25 Unauthenticated,
26 Authenticated {
27 access_token: String,
28 refresh_token: String,
29 viewer: Option<serde_json::Value>,
30 },
31}
32
33impl ForgeAuthState {
34 pub fn is_authenticated(&self) -> bool {
35 matches!(self, Self::Authenticated { .. })
36 }
37
38 pub fn access_token(&self) -> Option<String> {
39 match self {
40 Self::Authenticated { access_token, .. } => Some(access_token.clone()),
41 Self::Unauthenticated => None,
42 }
43 }
44
45 pub fn refresh_token(&self) -> Option<String> {
46 match self {
47 Self::Authenticated { refresh_token, .. } => Some(refresh_token.clone()),
48 Self::Unauthenticated => None,
49 }
50 }
51
52 fn viewer_json(&self) -> Option<&serde_json::Value> {
53 match self {
54 Self::Authenticated { viewer, .. } => viewer.as_ref(),
55 Self::Unauthenticated => None,
56 }
57 }
58}
59
60#[derive(Clone, Copy)]
62pub struct ForgeAuth {
63 state: Signal<ForgeAuthState>,
64 app_name: Signal<String>,
65 generation: Signal<u64>,
66}
67
68impl ForgeAuth {
69 pub fn is_authenticated(&self) -> bool {
70 self.state.read().is_authenticated()
71 }
72
73 pub fn access_token(&self) -> Option<String> {
74 self.state.read().access_token()
75 }
76
77 pub fn refresh_token(&self) -> Option<String> {
78 self.state.read().refresh_token()
79 }
80
81 pub fn viewer<V: DeserializeOwned>(&self) -> Option<V> {
83 let state = self.state.read();
84 let json = state.viewer_json()?;
85 serde_json::from_value(json.clone()).ok()
86 }
87
88 pub fn login(&mut self, access_token: String, refresh_token: String) {
90 self.save_and_set(access_token, refresh_token, None);
91 }
92
93 pub fn login_with_viewer<V: Serialize>(
95 &mut self,
96 access_token: String,
97 refresh_token: String,
98 viewer: &V,
99 ) {
100 let viewer_json = serde_json::to_value(viewer).ok();
101 self.save_and_set(access_token, refresh_token, viewer_json);
102 }
103
104 pub fn update_tokens(&mut self, access_token: String, refresh_token: String) {
106 let existing_viewer = self.state.read().viewer_json().cloned();
107 self.save_and_set(access_token, refresh_token, existing_viewer);
108 }
109
110 pub fn update_viewer<V: Serialize>(&mut self, viewer: &V) {
112 let state = self.state.read();
113 let (access_token, refresh_token) = match &*state {
114 ForgeAuthState::Authenticated {
115 access_token,
116 refresh_token,
117 ..
118 } => (access_token.clone(), refresh_token.clone()),
119 ForgeAuthState::Unauthenticated => return,
120 };
121 drop(state);
122 let viewer_json = serde_json::to_value(viewer).ok();
123 self.save_and_set(access_token, refresh_token, viewer_json);
124 }
125
126 pub fn logout(&mut self) {
128 storage::clear(&self.app_name.read());
129 self.state.set(ForgeAuthState::Unauthenticated);
130 self.generation.with_mut(|g| *g += 1);
131 }
132
133 fn save_and_set(
134 &mut self,
135 access_token: String,
136 refresh_token: String,
137 viewer: Option<serde_json::Value>,
138 ) {
139 let stored = StoredAuth {
140 access_token: access_token.clone(),
141 refresh_token: refresh_token.clone(),
142 viewer: viewer.clone(),
143 };
144 storage::save(&self.app_name.read(), &stored);
145 let was_authenticated = self.state.read().is_authenticated();
146 self.state.set(ForgeAuthState::Authenticated {
147 access_token,
148 refresh_token,
149 viewer,
150 });
151 if !was_authenticated {
152 self.generation.with_mut(|g| *g += 1);
153 }
154 }
155}
156
157pub fn use_forge_auth() -> ForgeAuth {
159 use_context::<ForgeAuth>()
160}
161
162pub fn use_viewer<V: DeserializeOwned + Clone + 'static>() -> Option<V> {
165 use_forge_auth().viewer::<V>()
166}
167
168pub fn use_auth_key() -> String {
177 let auth = use_forge_auth();
178 let generation = auth.generation.read();
179 format!("forge-auth-{generation}")
180}
181
/// Redirects (via `replace`) to `redirect_path` whenever the user is not authenticated.
///
/// Returns the current authentication status so callers can skip rendering
/// protected content while the redirect happens.
#[cfg(feature = "router")]
pub fn use_require_auth(redirect_path: &str) -> bool {
    let auth = use_forge_auth();
    let navigator = use_navigator();
    let target = redirect_path.to_owned();

    use_effect(move || {
        if auth.is_authenticated() {
            return;
        }
        navigator.replace(NavigationTarget::Internal(target.clone()));
    });

    auth.is_authenticated()
}
205
206#[component]
219pub fn ForgeAuthProvider(
220 url: String,
221 #[props(default = "forge_app".to_string())] app_name: String,
222 #[props(default = 2400)] refresh_interval_secs: u64,
223 children: Element,
224) -> Element {
225 let initial = match storage::load(&app_name) {
226 Some(stored) => ForgeAuthState::Authenticated {
227 access_token: stored.access_token,
228 refresh_token: stored.refresh_token,
229 viewer: stored.viewer,
230 },
231 None => ForgeAuthState::Unauthenticated,
232 };
233
234 let auth_state = use_context_provider(|| Signal::new(initial));
235 let app_name_signal = use_context_provider(|| Signal::new(app_name));
236 let generation = use_context_provider(|| Signal::new(0_u64));
237 let forge_auth = use_context_provider(|| ForgeAuth {
238 state: auth_state,
239 app_name: app_name_signal,
240 generation,
241 });
242
243 let connection_state = use_context_provider(|| Signal::new(ConnectionState::Disconnected));
244 let needs_refresh = use_signal(|| false);
245
246 let url_clone = url.clone();
248 use_context_provider(move || {
249 let auth_for_token = auth_state;
250 let needs_refresh_clone = needs_refresh;
251 let config = ForgeClientConfig::new(url_clone)
252 .with_connection_state(connection_state)
253 .with_token_provider(move || auth_for_token.read().access_token())
254 .with_auth_error_handler(move |_err| {
255 let mut sig = needs_refresh_clone;
256 sig.set(true);
257 });
258 ForgeClient::new(config)
259 });
260
261 let url_for_refresh = url.clone();
263 use_effect(move || {
264 if !*needs_refresh.read() {
265 return;
266 }
267 let url = url_for_refresh.clone();
268 let mut auth = forge_auth;
269 spawn(async move {
270 try_refresh_tokens(&url, &mut auth).await;
271 });
272 });
273
274 let url_for_periodic = url;
276 use_future(move || {
277 let url = url_for_periodic.clone();
278 let mut auth = forge_auth;
279 async move {
280 loop {
281 sleep(refresh_interval_secs).await;
282 if auth.is_authenticated() {
283 try_refresh_tokens(&url, &mut auth).await;
284 }
285 }
286 }
287 });
288
289 let client: ForgeClient = use_context();
291 let signals_instance = use_context_provider(|| {
292 let s = ForgeSignals::new(client.clone(), SignalsConfig::default());
293 client.set_signals(s.clone());
294 s
295 });
296 use_hook(|| {
297 setup_auto_capture(signals_instance);
298 });
299
300 rsx! { {children} }
301}
302
303async fn try_refresh_tokens(api_url: &str, auth: &mut ForgeAuth) {
309 let refresh_token = match auth.refresh_token() {
310 Some(t) => t,
311 None => return,
312 };
313
314 let anon_client = ForgeClient::new(ForgeClientConfig::new(api_url.to_string()));
315
316 #[derive(Serialize)]
317 struct RefreshArgs {
318 refresh_token: String,
319 }
320
321 #[derive(Deserialize)]
322 struct RefreshResponse {
323 access_token: String,
324 refresh_token: String,
325 }
326
327 match anon_client
328 .call::<_, RefreshResponse>(
329 "refresh",
330 RefreshArgs {
331 refresh_token,
332 },
333 )
334 .await
335 {
336 Ok(resp) => {
337 auth.update_tokens(resp.access_token, resp.refresh_token);
338 }
339 Err(ref e)
340 if e.code == "UNAUTHORIZED"
341 || e.code == "FORBIDDEN"
342 || e.code == "NOT_FOUND" =>
343 {
344 auth.logout();
346 }
347 Err(_) => {
348 }
351 }
352}
353
354async fn sleep(secs: u64) {
356 #[cfg(target_arch = "wasm32")]
357 gloo_timers::future::TimeoutFuture::new((secs * 1000) as u32).await;
358
359 #[cfg(not(target_arch = "wasm32"))]
360 tokio::time::sleep(std::time::Duration::from_secs(secs)).await;
361}
362
/// Browser persistence: the session is kept as JSON in `localStorage`.
#[cfg(target_arch = "wasm32")]
mod storage {
    use super::StoredAuth;

    /// localStorage key holding the session for `app_name`.
    fn key(app_name: &str) -> String {
        format!("{app_name}_auth")
    }

    /// The window's `localStorage`, or `None` if there is no window or it is unavailable.
    fn local_storage() -> Option<web_sys::Storage> {
        web_sys::window()?.local_storage().ok().flatten()
    }

    /// Best-effort write of `auth`; serialization and storage errors are ignored.
    pub fn save(app_name: &str, auth: &StoredAuth) {
        let Ok(json) = serde_json::to_string(auth) else {
            return;
        };
        if let Some(store) = local_storage() {
            let _ = store.set_item(&key(app_name), &json);
        }
    }

    /// Reads the stored session; `None` if absent, unreadable or not valid JSON.
    pub fn load(app_name: &str) -> Option<StoredAuth> {
        let json = local_storage()?.get_item(&key(app_name)).ok().flatten()?;
        serde_json::from_str(&json).ok()
    }

    /// Best-effort removal of the stored session.
    pub fn clear(app_name: &str) {
        if let Some(store) = local_storage() {
            let _ = store.remove_item(&key(app_name));
        }
    }
}
398
399#[cfg(not(target_arch = "wasm32"))]
400mod storage {
401 use super::StoredAuth;
402 use std::fs;
403 use std::path::PathBuf;
404
405 fn storage_path(app_name: &str) -> Option<PathBuf> {
406 dirs::data_local_dir().map(|base| base.join(app_name).join("auth.json"))
407 }
408
409 pub fn save(app_name: &str, auth: &StoredAuth) {
410 let Some(path) = storage_path(app_name) else {
411 return;
412 };
413 if let Some(parent) = path.parent() {
414 let _ = fs::create_dir_all(parent);
415 }
416 if let Ok(json) = serde_json::to_vec(auth) {
417 let tmp = path.with_extension("tmp");
418 let _ = fs::write(&tmp, json).and_then(|()| fs::rename(tmp, path));
419 }
420 }
421
422 pub fn load(app_name: &str) -> Option<StoredAuth> {
423 let path = storage_path(app_name)?;
424 let json = fs::read_to_string(path).ok()?;
425 serde_json::from_str(&json).ok()
426 }
427
428 pub fn clear(app_name: &str) {
429 if let Some(path) = storage_path(app_name) {
430 let _ = fs::remove_file(path);
431 }
432 }
433}
434
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_authenticated_state_exposes_tokens_and_viewer() {
        let expected_viewer = serde_json::json!({"id": "user-1", "role": "admin"});
        let state = ForgeAuthState::Authenticated {
            access_token: String::from("access-token"),
            refresh_token: String::from("refresh-token"),
            viewer: Some(expected_viewer.clone()),
        };

        assert!(state.is_authenticated());
        assert_eq!(state.access_token(), Some("access-token".to_owned()));
        assert_eq!(state.refresh_token(), Some("refresh-token".to_owned()));
        assert_eq!(state.viewer_json(), Some(&expected_viewer));
    }

    #[test]
    fn test_unauthenticated_state_has_no_auth_material() {
        let state = ForgeAuthState::Unauthenticated;

        assert!(!state.is_authenticated());
        assert_eq!(state.access_token(), None);
        assert_eq!(state.refresh_token(), None);
        assert_eq!(state.viewer_json(), None);
    }
}