firebase_rs_sdk/auth/persistence/
file.rs

1use std::fs::{remove_file, File};
2use std::io::{Read, Write};
3use std::path::{Path, PathBuf};
4use std::sync::{Arc, Mutex};
5
6use crate::auth::error::{AuthError, AuthResult};
7use crate::auth::persistence::{
8    AuthPersistence, PersistedAuthState, PersistenceListener, PersistenceSubscription,
9};
10use serde_json::{from_str as deserialize_state, to_string as serialize_state};
11
12#[derive(Clone)]
13pub struct FilePersistence {
14    path: Arc<PathBuf>,
15    listeners: Arc<Mutex<Vec<PersistenceListener>>>,
16}
17
18impl std::fmt::Debug for FilePersistence {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        f.debug_struct("FilePersistence")
21            .field("path", &self.path)
22            .finish()
23    }
24}
25
26impl FilePersistence {
27    pub fn new(path: impl AsRef<Path>) -> Self {
28        Self {
29            path: Arc::new(path.as_ref().to_path_buf()),
30            listeners: Arc::new(Mutex::new(Vec::new())),
31        }
32    }
33
34    fn notify_listeners(&self, state: Option<PersistedAuthState>) {
35        let listeners = self.listeners.lock().unwrap().clone();
36        for listener in listeners {
37            listener(state.clone());
38        }
39    }
40}
41
42impl AuthPersistence for FilePersistence {
43    fn set(&self, state: Option<PersistedAuthState>) -> AuthResult<()> {
44        match &state {
45            Some(state) => {
46                let serialized = serialize_state(state).map_err(|err| {
47                    AuthError::InvalidCredential(format!(
48                        "Failed to serialize auth state for persistence: {err}"
49                    ))
50                })?;
51                if let Some(parent) = self.path.parent() {
52                    std::fs::create_dir_all(parent).map_err(|err| {
53                        AuthError::InvalidCredential(format!(
54                            "Failed to create persistence directory: {err}"
55                        ))
56                    })?;
57                }
58                let mut file = File::create(&*self.path).map_err(|err| {
59                    AuthError::InvalidCredential(format!(
60                        "Failed to create auth persistence file: {err}"
61                    ))
62                })?;
63                file.write_all(serialized.as_bytes()).map_err(|err| {
64                    AuthError::InvalidCredential(format!(
65                        "Failed to write auth persistence file: {err}"
66                    ))
67                })?;
68            }
69            None => {
70                if self.path.exists() {
71                    remove_file(&*self.path).map_err(|err| {
72                        AuthError::InvalidCredential(format!(
73                            "Failed to remove auth persistence file: {err}"
74                        ))
75                    })?;
76                }
77            }
78        }
79
80        self.notify_listeners(state.clone());
81        Ok(())
82    }
83
84    fn get(&self) -> AuthResult<Option<PersistedAuthState>> {
85        if !self.path.exists() {
86            return Ok(None);
87        }
88
89        let mut file = File::open(&*self.path).map_err(|err| {
90            AuthError::InvalidCredential(format!("Failed to open auth persistence file: {err}"))
91        })?;
92        let mut buffer = String::new();
93        file.read_to_string(&mut buffer).map_err(|err| {
94            AuthError::InvalidCredential(format!("Failed to read auth persistence file: {err}"))
95        })?;
96
97        if buffer.is_empty() {
98            return Ok(None);
99        }
100
101        let state = deserialize_state(&buffer).map_err(|err| {
102            AuthError::InvalidCredential(format!("Failed to parse auth persistence payload: {err}"))
103        })?;
104        Ok(Some(state))
105    }
106
107    fn subscribe(&self, listener: PersistenceListener) -> AuthResult<PersistenceSubscription> {
108        let listener_arc = listener.clone();
109        let mut listeners = self.listeners.lock().unwrap();
110        listeners.push(listener_arc.clone());
111        drop(listeners);
112
113        let listeners = Arc::downgrade(&self.listeners);
114        Ok(PersistenceSubscription::new(move || {
115            if let Some(listeners) = listeners.upgrade() {
116                let mut guard = listeners.lock().unwrap();
117                guard.retain(|existing| !Arc::ptr_eq(existing, &listener_arc));
118            }
119        }))
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a per-process path under the system temp directory for the test called `name`.
    fn temp_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!(
            "firebase-auth-test-{}-{}.json",
            name,
            std::process::id()
        ))
    }

    /// Deletes the wrapped file or directory on drop, so a failing assertion does not leak it.
    struct Cleanup(PathBuf);

    impl Drop for Cleanup {
        fn drop(&mut self) {
            if self.0.is_dir() {
                let _ = std::fs::remove_dir_all(&self.0);
            } else {
                let _ = remove_file(&self.0);
            }
        }
    }

    fn sample_state() -> PersistedAuthState {
        PersistedAuthState {
            user_id: "user".into(),
            email: Some("user@example.com".into()),
            refresh_token: Some("refresh".into()),
            access_token: Some("access".into()),
            expires_at: Some(1234),
        }
    }

    #[test]
    fn roundtrip_persistence() {
        let path = temp_path("roundtrip");
        let _cleanup = Cleanup(path.clone());
        let persistence = FilePersistence::new(&path);
        let state = sample_state();

        persistence.set(Some(state.clone())).unwrap();
        assert_eq!(persistence.get().unwrap(), Some(state));

        persistence.set(None).unwrap();
        assert!(persistence.get().unwrap().is_none());
    }

    #[test]
    fn missing_file_reads_as_none_and_clears_cleanly() {
        let path = temp_path("missing");
        let _cleanup = Cleanup(path.clone());
        let _ = remove_file(&path);
        let persistence = FilePersistence::new(&path);

        assert!(persistence.get().unwrap().is_none());
        // Clearing state that was never written must succeed too.
        persistence.set(None).unwrap();
        assert!(!path.exists());
    }

    #[test]
    fn set_creates_missing_parent_directories() {
        let root = temp_path("nested-dir");
        let _cleanup = Cleanup(root.clone());
        let path = root.join("inner").join("state.json");
        let persistence = FilePersistence::new(&path);
        let state = sample_state();

        persistence.set(Some(state.clone())).unwrap();
        assert_eq!(persistence.get().unwrap(), Some(state));
    }
}