1use dioxus::prelude::*;
8use serde::de::DeserializeOwned;
9use serde::{Deserialize, Serialize};
10
11use crate::{ConnectionState, ForgeClient, ForgeClientConfig};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
16struct StoredAuth {
17 access_token: String,
18 refresh_token: String,
19 #[serde(default, skip_serializing_if = "Option::is_none")]
20 viewer: Option<serde_json::Value>,
21}
22
23#[derive(Debug, Clone)]
25pub enum ForgeAuthState {
26 Unauthenticated,
27 Authenticated {
28 access_token: String,
29 refresh_token: String,
30 viewer: Option<serde_json::Value>,
31 },
32}
33
34impl ForgeAuthState {
35 pub fn is_authenticated(&self) -> bool {
36 matches!(self, Self::Authenticated { .. })
37 }
38
39 pub fn access_token(&self) -> Option<String> {
40 match self {
41 Self::Authenticated { access_token, .. } => Some(access_token.clone()),
42 Self::Unauthenticated => None,
43 }
44 }
45
46 pub fn refresh_token(&self) -> Option<String> {
47 match self {
48 Self::Authenticated { refresh_token, .. } => Some(refresh_token.clone()),
49 Self::Unauthenticated => None,
50 }
51 }
52
53 fn viewer_json(&self) -> Option<&serde_json::Value> {
54 match self {
55 Self::Authenticated { viewer, .. } => viewer.as_ref(),
56 Self::Unauthenticated => None,
57 }
58 }
59}
60
61#[derive(Clone, Copy)]
63pub struct ForgeAuth {
64 state: Signal<ForgeAuthState>,
65 app_name: Signal<String>,
66 generation: Signal<u64>,
67}
68
69impl ForgeAuth {
70 pub fn is_authenticated(&self) -> bool {
71 self.state.read().is_authenticated()
72 }
73
74 pub fn access_token(&self) -> Option<String> {
75 self.state.read().access_token()
76 }
77
78 pub fn refresh_token(&self) -> Option<String> {
79 self.state.read().refresh_token()
80 }
81
82 pub fn viewer<V: DeserializeOwned>(&self) -> Option<V> {
84 let state = self.state.read();
85 let json = state.viewer_json()?;
86 serde_json::from_value(json.clone()).ok()
87 }
88
89 pub fn login(&mut self, access_token: String, refresh_token: String) {
91 self.save_and_set(access_token, refresh_token, None);
92 }
93
94 pub fn login_with_viewer<V: Serialize>(
96 &mut self,
97 access_token: String,
98 refresh_token: String,
99 viewer: &V,
100 ) {
101 let viewer_json = serde_json::to_value(viewer).ok();
102 self.save_and_set(access_token, refresh_token, viewer_json);
103 }
104
105 pub fn update_tokens(&mut self, access_token: String, refresh_token: String) {
107 let existing_viewer = self.state.read().viewer_json().cloned();
108 self.save_and_set(access_token, refresh_token, existing_viewer);
109 }
110
111 pub fn update_viewer<V: Serialize>(&mut self, viewer: &V) {
113 let state = self.state.read();
114 let (access_token, refresh_token) = match &*state {
115 ForgeAuthState::Authenticated {
116 access_token,
117 refresh_token,
118 ..
119 } => (access_token.clone(), refresh_token.clone()),
120 ForgeAuthState::Unauthenticated => return,
121 };
122 drop(state);
123 let viewer_json = serde_json::to_value(viewer).ok();
124 self.save_and_set(access_token, refresh_token, viewer_json);
125 }
126
127 pub fn logout(&mut self) {
129 storage::clear(&self.app_name.read());
130 self.state.set(ForgeAuthState::Unauthenticated);
131 self.generation.with_mut(|g| *g += 1);
132 }
133
134 fn save_and_set(
135 &mut self,
136 access_token: String,
137 refresh_token: String,
138 viewer: Option<serde_json::Value>,
139 ) {
140 let stored = StoredAuth {
141 access_token: access_token.clone(),
142 refresh_token: refresh_token.clone(),
143 viewer: viewer.clone(),
144 };
145 storage::save(&self.app_name.read(), &stored);
146 let was_authenticated = self.state.read().is_authenticated();
147 self.state.set(ForgeAuthState::Authenticated {
148 access_token,
149 refresh_token,
150 viewer,
151 });
152 if !was_authenticated {
153 self.generation.with_mut(|g| *g += 1);
154 }
155 }
156}
157
158pub fn use_forge_auth() -> ForgeAuth {
160 use_context::<ForgeAuth>()
161}
162
163pub fn use_viewer<V: DeserializeOwned + Clone + 'static>() -> Option<V> {
166 use_forge_auth().viewer::<V>()
167}
168
169pub fn use_auth_key() -> String {
178 let auth = use_forge_auth();
179 let generation = auth.generation.read();
180 format!("forge-auth-{generation}")
181}
182
/// Navigates to `redirect_path` (via `navigator.replace`) whenever the user is not
/// authenticated.
///
/// Returns whether the user is currently authenticated, so callers can skip
/// rendering protected content while the redirect is pending.
#[cfg(feature = "router")]
pub fn use_require_auth(redirect_path: &str) -> bool {
    let auth = use_forge_auth();
    let nav = use_navigator();
    let target = redirect_path.to_owned();

    use_effect(move || {
        if auth.is_authenticated() {
            return;
        }
        nav.replace(NavigationTarget::Internal(target.clone()));
    });

    auth.is_authenticated()
}
206
207#[component]
220pub fn ForgeAuthProvider(
221 url: String,
222 #[props(default = "forge_app".to_string())] app_name: String,
223 #[props(default = 2400)] refresh_interval_secs: u64,
224 children: Element,
225) -> Element {
226 let initial = match storage::load(&app_name) {
227 Some(stored) => ForgeAuthState::Authenticated {
228 access_token: stored.access_token,
229 refresh_token: stored.refresh_token,
230 viewer: stored.viewer,
231 },
232 None => ForgeAuthState::Unauthenticated,
233 };
234
235 let auth_state = use_context_provider(|| Signal::new(initial));
236 let app_name_signal = use_context_provider(|| Signal::new(app_name));
237 let generation = use_context_provider(|| Signal::new(0_u64));
238 let forge_auth = use_context_provider(|| ForgeAuth {
239 state: auth_state,
240 app_name: app_name_signal,
241 generation,
242 });
243
244 let connection_state = use_context_provider(|| Signal::new(ConnectionState::Disconnected));
245 let needs_refresh = use_signal(|| false);
246
247 let url_clone = url.clone();
249 use_context_provider(move || {
250 let auth_for_token = auth_state;
251 let needs_refresh_clone = needs_refresh;
252 let config = ForgeClientConfig::new(url_clone)
253 .with_connection_state(connection_state)
254 .with_token_provider(move || auth_for_token.read().access_token())
255 .with_auth_error_handler(move |_err| {
256 let mut sig = needs_refresh_clone;
257 sig.set(true);
258 });
259 ForgeClient::new(config)
260 });
261
262 let url_for_refresh = url.clone();
264 use_effect(move || {
265 if !*needs_refresh.read() {
266 return;
267 }
268 let url = url_for_refresh.clone();
269 let mut auth = forge_auth;
270 spawn(async move {
271 try_refresh_tokens(&url, &mut auth).await;
272 });
273 });
274
275 let url_for_periodic = url;
277 use_future(move || {
278 let url = url_for_periodic.clone();
279 let mut auth = forge_auth;
280 async move {
281 loop {
282 sleep(refresh_interval_secs).await;
283 if auth.is_authenticated() {
284 try_refresh_tokens(&url, &mut auth).await;
285 }
286 }
287 }
288 });
289
290 rsx! { {children} }
291}
292
293async fn try_refresh_tokens(api_url: &str, auth: &mut ForgeAuth) {
299 let refresh_token = match auth.refresh_token() {
300 Some(t) => t,
301 None => return,
302 };
303
304 let anon_client = ForgeClient::new(ForgeClientConfig::new(api_url.to_string()));
305
306 #[derive(Serialize)]
307 struct RefreshArgs {
308 refresh_token: String,
309 }
310
311 #[derive(Deserialize)]
312 struct RefreshResponse {
313 access_token: String,
314 refresh_token: String,
315 }
316
317 match anon_client
318 .call::<_, RefreshResponse>(
319 "refresh",
320 RefreshArgs {
321 refresh_token,
322 },
323 )
324 .await
325 {
326 Ok(resp) => {
327 auth.update_tokens(resp.access_token, resp.refresh_token);
328 }
329 Err(ref e)
330 if e.code == "UNAUTHORIZED"
331 || e.code == "FORBIDDEN"
332 || e.code == "NOT_FOUND" =>
333 {
334 auth.logout();
336 }
337 Err(_) => {
338 }
341 }
342}
343
344async fn sleep(secs: u64) {
346 #[cfg(target_arch = "wasm32")]
347 gloo_timers::future::TimeoutFuture::new((secs * 1000) as u32).await;
348
349 #[cfg(not(target_arch = "wasm32"))]
350 tokio::time::sleep(std::time::Duration::from_secs(secs)).await;
351}
352
/// Browser persistence: the session is stored as JSON in `localStorage` under
/// `"<app_name>_auth"`. All operations are best-effort and silently do nothing if
/// storage is unavailable.
#[cfg(target_arch = "wasm32")]
mod storage {
    use super::StoredAuth;

    /// Returns the window's `localStorage`, or `None` if it cannot be obtained.
    fn local_storage() -> Option<web_sys::Storage> {
        web_sys::window()?.local_storage().ok().flatten()
    }

    fn key(app_name: &str) -> String {
        format!("{app_name}_auth")
    }

    /// Serializes `auth` and writes it; write errors are ignored.
    pub fn save(app_name: &str, auth: &StoredAuth) {
        let Ok(json) = serde_json::to_string(auth) else {
            return;
        };
        if let Some(store) = local_storage() {
            let _ = store.set_item(&key(app_name), &json);
        }
    }

    /// Reads and parses the stored session; `None` if missing or unparsable.
    pub fn load(app_name: &str) -> Option<StoredAuth> {
        let raw = local_storage()?.get_item(&key(app_name)).ok().flatten()?;
        serde_json::from_str(&raw).ok()
    }

    /// Removes the stored session; errors are ignored.
    pub fn clear(app_name: &str) {
        if let Some(store) = local_storage() {
            let _ = store.remove_item(&key(app_name));
        }
    }
}
388
389#[cfg(not(target_arch = "wasm32"))]
390mod storage {
391 use super::StoredAuth;
392 use std::fs;
393 use std::path::PathBuf;
394
395 fn storage_path(app_name: &str) -> Option<PathBuf> {
396 dirs::data_local_dir().map(|base| base.join(app_name).join("auth.json"))
397 }
398
399 pub fn save(app_name: &str, auth: &StoredAuth) {
400 let Some(path) = storage_path(app_name) else {
401 return;
402 };
403 if let Some(parent) = path.parent() {
404 let _ = fs::create_dir_all(parent);
405 }
406 if let Ok(json) = serde_json::to_vec(auth) {
407 let tmp = path.with_extension("tmp");
408 let _ = fs::write(&tmp, json).and_then(|()| fs::rename(tmp, path));
409 }
410 }
411
412 pub fn load(app_name: &str) -> Option<StoredAuth> {
413 let path = storage_path(app_name)?;
414 let json = fs::read_to_string(path).ok()?;
415 serde_json::from_str(&json).ok()
416 }
417
418 pub fn clear(app_name: &str) {
419 if let Some(path) = storage_path(app_name) {
420 let _ = fs::remove_file(path);
421 }
422 }
423}