fire_auth_token/
lib.rs

1pub mod structs;
2
3use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
4use reqwest;
5use std::{collections::HashSet, sync::Arc};
6use structs::*;
7use time::{Duration, OffsetDateTime};
8use tokio::sync::RwLock;
9
10impl FirebaseTokenPayload {
11    fn verify(&self, project_id: &str, current_time: OffsetDateTime) -> FirebaseAuthResult<()> {
12        // Verify expiration time
13        if self.exp <= current_time.unix_timestamp() {
14            return Err(FirebaseAuthError::TokenExpired);
15        }
16
17        // Verify issued at time
18        if self.iat >= current_time.unix_timestamp() {
19            return Err(FirebaseAuthError::InvalidTokenFormat);
20        }
21
22        // Verify authentication time
23        if self.auth_time >= current_time.unix_timestamp() {
24            return Err(FirebaseAuthError::InvalidAuthTime);
25        }
26
27        // Verify audience
28        if self.aud != project_id {
29            return Err(FirebaseAuthError::InvalidAudience);
30        }
31
32        // Verify issuer
33        let expected_issuer = format!("https://securetoken.google.com/{}", project_id);
34        if self.iss != expected_issuer {
35            return Err(FirebaseAuthError::InvalidIssuer);
36        }
37
38        // Verify subject
39        if self.sub.is_empty() {
40            return Err(FirebaseAuthError::InvalidSubject);
41        }
42
43        Ok(())
44    }
45
46    fn to_auth_user(&self) -> FirebaseAuthUser {
47        FirebaseAuthUser {
48            uid: self.sub.clone(),
49            issued_at: OffsetDateTime::from_unix_timestamp(self.iat)
50                .unwrap_or_else(|_| OffsetDateTime::now_utc()),
51            expires_at: OffsetDateTime::from_unix_timestamp(self.exp)
52                .unwrap_or_else(|_| OffsetDateTime::now_utc()),
53            auth_time: OffsetDateTime::from_unix_timestamp(self.auth_time)
54                .unwrap_or_else(|_| OffsetDateTime::now_utc()),
55        }
56    }
57}
58
59impl FirebaseAuth {
60    pub async fn new(project_id: String) -> Self {
61        let auth = FirebaseAuth {
62            config: FirebaseAuthConfig {
63                project_id,
64                public_keys_url: String::from(
65                    "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com",
66                ),
67            },
68            cached_public_keys: Arc::new(RwLock::new(None)),
69        };
70
71        // Initialize the keys
72        auth.update_public_keys()
73            .await
74            .expect("Initial key fetch failed");
75
76        // Start the background refresh task
77        auth.start_key_refresh_task();
78
79        auth
80    }
81
82    fn start_key_refresh_task(&self) {
83        let cached_keys = self.cached_public_keys.clone();
84        let config = self.config.clone();
85
86        tokio::spawn(async move {
87            loop {
88                // Read current state
89                let next_update = {
90                    let keys = cached_keys.read().await;
91                    keys.as_ref()
92                        .map(|state| state.expiry)
93                        .unwrap_or_else(|| OffsetDateTime::now_utc())
94                };
95
96                // Calculate sleep duration
97                let now = OffsetDateTime::now_utc();
98                let sleep_duration = if next_update > now {
99                    // Refresh slightly before expiry (90% of the remaining time)
100                    let total_duration = (next_update - now).whole_seconds();
101                    Duration::seconds((total_duration as f64 * 0.9) as i64)
102                } else {
103                    Duration::seconds(0)
104                };
105
106                // Sleep until next refresh
107                tokio::time::sleep(tokio::time::Duration::from_secs(
108                    sleep_duration.whole_seconds() as u64,
109                ))
110                .await;
111
112                // Create new client for each request
113                let client = reqwest::Client::new();
114
115                // Fetch new keys
116                match Self::fetch_public_keys(&config, &client).await {
117                    Ok((keys, expiry)) => {
118                        let mut cached = cached_keys.write().await;
119                        *cached = Some(SharedState { keys, expiry });
120                        println!(
121                            "Successfully updated public keys. Next update in {} seconds",
122                            (expiry - OffsetDateTime::now_utc()).whole_seconds()
123                        );
124                    }
125                    Err(e) => {
126                        eprintln!("Failed to update public keys: {:?}", e);
127                        // On error, retry after 1 minute
128                        tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
129                    }
130                }
131            }
132        });
133    }
134
135    async fn update_public_keys(&self) -> FirebaseAuthResult<()> {
136        println!("Updating public keys...");
137        let client = reqwest::Client::new();
138        let (keys, expiry) = Self::fetch_public_keys(&self.config, &client).await?;
139        let mut cached = self.cached_public_keys.write().await;
140        *cached = Some(SharedState { keys, expiry });
141        println!("Public keys updated successfully with expiry: {}", expiry);
142        Ok(())
143    }
144
145    async fn fetch_public_keys(
146        config: &FirebaseAuthConfig,
147        client: &reqwest::Client,
148    ) -> FirebaseAuthResult<(PublicKeysResponse, OffsetDateTime)> {
149        println!("Fetching public keys from URL: {}", config.public_keys_url);
150        let response = client
151            .get(&config.public_keys_url)
152            .send()
153            .await
154            .map_err(|e| FirebaseAuthError::HttpError(e.to_string()))?;
155
156        println!("Received response with status: {}", response.status());
157        // Get cache control header
158        let cache_control = response
159            .headers()
160            .get("Cache-Control")
161            .and_then(|h| h.to_str().ok())
162            .unwrap_or("max-age=3600");
163        println!("Cache-Control header value: {}", cache_control);
164
165        // Parse max age
166        let max_age = cache_control
167            .split(',')
168            .find(|&s| s.trim().starts_with("max-age="))
169            .and_then(|s| s.trim().strip_prefix("max-age="))
170            .and_then(|s| s.parse::<i64>().ok())
171            .unwrap_or(3600);
172
173        let keys: PublicKeysResponse = response
174            .json()
175            .await
176            .map_err(|e| FirebaseAuthError::HttpError(e.to_string()))?;
177
178        // Calculate expiry time
179        let expiry = OffsetDateTime::now_utc() + Duration::seconds(max_age);
180
181        Ok((keys, expiry))
182    }
183
184    pub async fn verify_token<T>(&self, token: &str) -> FirebaseAuthResult<FirebaseAuthUser>
185    where
186        T: TokenVerifier + serde::de::DeserializeOwned,
187    {
188        // Decode header without verification
189        let header = decode_header(token).map_err(|e| FirebaseAuthError::JwtError(e.to_string()))?;
190
191        // Verify algorithm
192        if header.alg != Algorithm::RS256 {
193            return Err(FirebaseAuthError::InvalidTokenFormat);
194        }
195
196        // Get key ID
197        let kid = header.kid.ok_or(FirebaseAuthError::InvalidTokenFormat)?;
198
199        // Get public keys
200        let cached_keys = self.cached_public_keys.read().await;
201        let state = cached_keys
202            .as_ref()
203            .ok_or(FirebaseAuthError::InvalidTokenFormat)?;
204
205        // Find matching key using the updated structure
206        let public_key = state
207            .keys
208            .keys
209            .get(&kid)
210            .ok_or(FirebaseAuthError::InvalidSignature)?;
211
212        // Set up validation parameters
213        let mut validation = Validation::new(Algorithm::RS256);
214
215        // Configure validation parameters using HashSet
216        let mut iss_set = HashSet::new();
217        iss_set.insert(format!(
218            "https://securetoken.google.com/{}",
219            self.config.project_id
220        ));
221        validation.iss = Some(iss_set);
222
223        let mut aud_set = HashSet::new();
224        aud_set.insert(self.config.project_id.clone());
225        validation.aud = Some(aud_set);
226
227        validation.validate_exp = true;
228        validation.validate_nbf = false;
229        validation.set_required_spec_claims(&["sub"]);
230
231        // Decode and verify token
232        let token_data = decode::<T>(
233            token,
234            &DecodingKey::from_rsa_pem(public_key.as_bytes())
235                .map_err(|e| FirebaseAuthError::JwtError(e.to_string()))?,
236            &validation,
237        )
238        .map_err(|e| FirebaseAuthError::JwtError(e.to_string()))?;
239
240        // Verify additional Firebase-specific claims
241        token_data
242            .claims
243            .verify(&self.config.project_id, OffsetDateTime::now_utc())?;
244
245        Ok(token_data.claims.to_auth_user())
246    }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Fetches the keys from the configured URL and expects a non-empty key set.
    #[tokio::test]
    async fn test_public_key_fetch() {
        println!("Starting public key fetch test");

        let auth = FirebaseAuth::new("oyetime-test".to_string()).await;
        let http = reqwest::Client::new();

        println!("Making request to fetch public keys...");
        let (keys, expiry) = match FirebaseAuth::fetch_public_keys(&auth.config, &http).await {
            Ok(fetched) => fetched,
            Err(err) => {
                println!("❌ Failed to fetch public keys:");
                println!("Error: {:?}", err);
                panic!("Public key fetch failed");
            }
        };

        println!("✅ Successfully fetched public keys:");
        println!("Keys: {:#?}", keys);
        println!("Expiry: {}", expiry);
        assert!(!keys.keys.is_empty(), "Keys should not be empty");
    }

    /// Forces a manual refresh and expects the cache to be populated afterwards.
    #[tokio::test]
    async fn test_key_refresh() {
        println!("Starting key refresh test");

        let auth = FirebaseAuth::new("test-project".to_string()).await;
        {
            // Scoped so the read guard is released before the refresh needs the write lock
            let before = auth.cached_public_keys.read().await;
            println!("Initial cached keys: {:#?}", before);
        }

        auth.update_public_keys().await.expect("Key refresh failed");

        let after = auth.cached_public_keys.read().await;
        println!("Updated cached keys: {:#?}", after);
        assert!(
            after.is_some(),
            "Cached keys should be present after refresh"
        );
    }
}