// firebase_rs_sdk/auth/persistence/file.rs
use std::fs::{remove_file, File};
use std::io::{ErrorKind, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};

use serde_json::{from_str as deserialize_state, to_string as serialize_state};

use crate::auth::error::{AuthError, AuthResult};
use crate::auth::persistence::{
    AuthPersistence, PersistedAuthState, PersistenceListener, PersistenceSubscription,
};

12#[derive(Clone)]
13pub struct FilePersistence {
14 path: Arc<PathBuf>,
15 listeners: Arc<Mutex<Vec<PersistenceListener>>>,
16}
17
18impl std::fmt::Debug for FilePersistence {
19 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20 f.debug_struct("FilePersistence")
21 .field("path", &self.path)
22 .finish()
23 }
24}
25
26impl FilePersistence {
27 pub fn new(path: impl AsRef<Path>) -> Self {
28 Self {
29 path: Arc::new(path.as_ref().to_path_buf()),
30 listeners: Arc::new(Mutex::new(Vec::new())),
31 }
32 }
33
34 fn notify_listeners(&self, state: Option<PersistedAuthState>) {
35 let listeners = self.listeners.lock().unwrap().clone();
36 for listener in listeners {
37 listener(state.clone());
38 }
39 }
40}
41
42impl AuthPersistence for FilePersistence {
43 fn set(&self, state: Option<PersistedAuthState>) -> AuthResult<()> {
44 match &state {
45 Some(state) => {
46 let serialized = serialize_state(state).map_err(|err| {
47 AuthError::InvalidCredential(format!(
48 "Failed to serialize auth state for persistence: {err}"
49 ))
50 })?;
51 if let Some(parent) = self.path.parent() {
52 std::fs::create_dir_all(parent).map_err(|err| {
53 AuthError::InvalidCredential(format!(
54 "Failed to create persistence directory: {err}"
55 ))
56 })?;
57 }
58 let mut file = File::create(&*self.path).map_err(|err| {
59 AuthError::InvalidCredential(format!(
60 "Failed to create auth persistence file: {err}"
61 ))
62 })?;
63 file.write_all(serialized.as_bytes()).map_err(|err| {
64 AuthError::InvalidCredential(format!(
65 "Failed to write auth persistence file: {err}"
66 ))
67 })?;
68 }
69 None => {
70 if self.path.exists() {
71 remove_file(&*self.path).map_err(|err| {
72 AuthError::InvalidCredential(format!(
73 "Failed to remove auth persistence file: {err}"
74 ))
75 })?;
76 }
77 }
78 }
79
80 self.notify_listeners(state.clone());
81 Ok(())
82 }
83
84 fn get(&self) -> AuthResult<Option<PersistedAuthState>> {
85 if !self.path.exists() {
86 return Ok(None);
87 }
88
89 let mut file = File::open(&*self.path).map_err(|err| {
90 AuthError::InvalidCredential(format!("Failed to open auth persistence file: {err}"))
91 })?;
92 let mut buffer = String::new();
93 file.read_to_string(&mut buffer).map_err(|err| {
94 AuthError::InvalidCredential(format!("Failed to read auth persistence file: {err}"))
95 })?;
96
97 if buffer.is_empty() {
98 return Ok(None);
99 }
100
101 let state = deserialize_state(&buffer).map_err(|err| {
102 AuthError::InvalidCredential(format!("Failed to parse auth persistence payload: {err}"))
103 })?;
104 Ok(Some(state))
105 }
106
107 fn subscribe(&self, listener: PersistenceListener) -> AuthResult<PersistenceSubscription> {
108 let listener_arc = listener.clone();
109 let mut listeners = self.listeners.lock().unwrap();
110 listeners.push(listener_arc.clone());
111 drop(listeners);
112
113 let listeners = Arc::downgrade(&self.listeners);
114 Ok(PersistenceSubscription::new(move || {
115 if let Some(listeners) = listeners.upgrade() {
116 let mut guard = listeners.lock().unwrap();
117 guard.retain(|existing| !Arc::ptr_eq(existing, &listener_arc));
118 }
119 }))
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a process-unique path in the system temp directory for test
    /// `name`, first deleting any file an earlier aborted run left there.
    fn temp_path(name: &str) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "firebase-auth-test-{}-{}.json",
            name,
            std::process::id()
        ));
        let _ = remove_file(&path);
        path
    }

    /// Builds a fully populated state whose only varying field is `user_id`.
    fn sample_state(user_id: &str) -> PersistedAuthState {
        PersistedAuthState {
            user_id: user_id.into(),
            email: Some("user@example.com".into()),
            refresh_token: Some("refresh".into()),
            access_token: Some("access".into()),
            expires_at: Some(1234),
        }
    }

    #[test]
    fn roundtrip_persistence() {
        let path = temp_path("roundtrip");
        let persistence = FilePersistence::new(&path);
        let state = sample_state("user");

        persistence.set(Some(state.clone())).unwrap();
        assert_eq!(persistence.get().unwrap(), Some(state));

        persistence.set(None).unwrap();
        assert!(persistence.get().unwrap().is_none());

        let _ = remove_file(path);
    }

    #[test]
    fn get_without_file_returns_none() {
        let path = temp_path("missing");
        let persistence = FilePersistence::new(&path);

        assert!(persistence.get().unwrap().is_none());
    }

    #[test]
    fn clearing_missing_file_succeeds() {
        let path = temp_path("clear");
        let persistence = FilePersistence::new(&path);

        persistence.set(None).unwrap();
        persistence.set(None).unwrap();
        assert!(persistence.get().unwrap().is_none());
    }

    #[test]
    fn set_replaces_previous_state() {
        let path = temp_path("overwrite");
        let persistence = FilePersistence::new(&path);

        // Longer payload first: an update that failed to truncate would leave
        // trailing bytes behind and break parsing of the shorter replacement.
        persistence
            .set(Some(sample_state("a-considerably-longer-user-id")))
            .unwrap();
        let replacement = sample_state("u2");
        persistence.set(Some(replacement.clone())).unwrap();
        assert_eq!(persistence.get().unwrap(), Some(replacement));

        let _ = remove_file(path);
    }
}