fire_auth_token/
lib.rs

1pub mod structs;
2
3use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
4use reqwest;
5use std::{
6    collections::HashSet,
7    sync::Arc,
8    time::{SystemTime, UNIX_EPOCH},
9};
10use structs::*;
11use tokio::sync::RwLock;
12
13impl FirebaseAuth {
14    pub async fn new(project_id: String) -> Self {
15        let auth = FirebaseAuth {
16            config: FirebaseAuthConfig {
17                project_id,
18                public_keys_url: String::from(
19                    "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com",
20                ),
21            },
22            cached_public_keys: Arc::new(RwLock::new(None)),
23        };
24
25        auth.update_public_keys()
26            .await
27            .expect("Initial key fetch failed");
28        auth.start_key_refresh_task();
29        auth
30    }
31
32    fn start_key_refresh_task(&self) {
33        let cached_keys = self.cached_public_keys.clone();
34        let config = self.config.clone();
35
36        tokio::spawn(async move {
37            loop {
38                let next_update = {
39                    let keys = cached_keys.read().await;
40                    keys.as_ref().map(|state| state.expiry).unwrap_or_else(|| {
41                        SystemTime::now()
42                            .duration_since(UNIX_EPOCH)
43                            .unwrap()
44                            .as_secs() as i64
45                    })
46                };
47
48                let current_time = SystemTime::now()
49                    .duration_since(UNIX_EPOCH)
50                    .unwrap()
51                    .as_secs() as i64;
52
53                let sleep_duration = if next_update > current_time {
54                    ((next_update - current_time) as f64 * 0.9) as u64
55                } else {
56                    0
57                };
58
59                tokio::time::sleep(tokio::time::Duration::from_secs(sleep_duration)).await;
60
61                let client = reqwest::Client::new();
62                match Self::fetch_public_keys(&config, &client).await {
63                    Ok((keys, expiry)) => {
64                        let mut cached = cached_keys.write().await;
65                        *cached = Some(SharedState { keys, expiry });
66                    }
67                    Err(e) => {
68                        eprintln!("Failed to update public keys: {:?}", e);
69                        tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
70                    }
71                }
72            }
73        });
74    }
75
76    async fn fetch_public_keys(
77        config: &FirebaseAuthConfig,
78        client: &reqwest::Client,
79    ) -> FirebaseAuthResult<(PublicKeysResponse, i64)> {
80        let response = client
81            .get(&config.public_keys_url)
82            .send()
83            .await
84            .map_err(|e| FirebaseAuthError::HttpError(e.to_string()))?;
85
86        let cache_control = response
87            .headers()
88            .get("Cache-Control")
89            .and_then(|h| h.to_str().ok())
90            .unwrap_or("max-age=3600");
91
92        let max_age = cache_control
93            .split(',')
94            .find(|&s| s.trim().starts_with("max-age="))
95            .and_then(|s| s.trim().strip_prefix("max-age="))
96            .and_then(|s| s.parse::<i64>().ok())
97            .unwrap_or(3600);
98
99        let keys: PublicKeysResponse = response
100            .json()
101            .await
102            .map_err(|e| FirebaseAuthError::HttpError(e.to_string()))?;
103
104        let expiry = SystemTime::now()
105            .duration_since(UNIX_EPOCH)
106            .unwrap()
107            .as_secs() as i64
108            + max_age;
109
110        Ok((keys, expiry))
111    }
112
113    async fn update_public_keys(&self) -> FirebaseAuthResult<()> {
114        let client = reqwest::Client::new();
115        let (keys, expiry) = Self::fetch_public_keys(&self.config, &client).await?;
116        let mut cached = self.cached_public_keys.write().await;
117        *cached = Some(SharedState { keys, expiry });
118        Ok(())
119    }
120
121    pub async fn verify_token<T>(&self, token: &str) -> FirebaseAuthResult<T>
122    where
123        T: TokenVerifier + serde::de::DeserializeOwned,
124    {
125        let header =
126            decode_header(token).map_err(|e| FirebaseAuthError::JwtError(e.to_string()))?;
127        if header.alg != Algorithm::RS256 {
128            return Err(FirebaseAuthError::InvalidTokenFormat);
129        }
130
131        let kid = header.kid.ok_or(FirebaseAuthError::InvalidTokenFormat)?;
132        let cached_keys = self.cached_public_keys.read().await;
133        let state = cached_keys
134            .as_ref()
135            .ok_or(FirebaseAuthError::InvalidTokenFormat)?;
136
137        let public_key = state
138            .keys
139            .keys
140            .get(&kid)
141            .ok_or(FirebaseAuthError::InvalidSignature)?;
142
143        let mut validation = Validation::new(Algorithm::RS256);
144        let mut iss_set = HashSet::new();
145        iss_set.insert(format!(
146            "https://securetoken.google.com/{}",
147            self.config.project_id
148        ));
149        validation.iss = Some(iss_set);
150
151        let mut aud_set = HashSet::new();
152        aud_set.insert(self.config.project_id.clone());
153        validation.aud = Some(aud_set);
154
155        validation.validate_exp = true;
156        validation.validate_nbf = false;
157        validation.set_required_spec_claims(&["sub"]);
158
159        let token_data = decode::<T>(
160            token,
161            &DecodingKey::from_rsa_pem(public_key.as_bytes())
162                .map_err(|e| FirebaseAuthError::JwtError(e.to_string()))?,
163            &validation,
164        )
165        .map_err(|e| FirebaseAuthError::JwtError(e.to_string()))?;
166
167        let current_time = SystemTime::now()
168            .duration_since(UNIX_EPOCH)
169            .unwrap()
170            .as_secs() as i64;
171
172        token_data
173            .claims
174            .verify(&self.config.project_id, current_time)?;
175
176        Ok(token_data.claims)
177    }
178}
179
#[cfg(test)]
mod tests {
    // Every test calls `FirebaseAuth::new`, which downloads Google's public signing keys,
    // so these tests need network access.
    use super::*;

    #[tokio::test]
    async fn test_public_key_fetch() {
        let auth = FirebaseAuth::new("test-project".to_string()).await;
        let client = reqwest::Client::new();

        let (keys, expiry) = FirebaseAuth::fetch_public_keys(&auth.config, &client)
            .await
            .expect("fetching Google's public keys should succeed");
        assert!(!keys.keys.is_empty());

        // The expiry is computed as "now + max-age", so it must not lie in the past.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64;
        assert!(expiry >= now);
    }

    #[tokio::test]
    async fn test_verify_normal_token() {
        let auth = FirebaseAuth::new("test-project".to_string()).await;

        // Placeholder, not a real JWT: it must be rejected while its header is parsed. To test
        // the success path, use a real ID token minted for the project passed to `new` and
        // assert `is_ok()` instead.
        let test_token = "your.test.token";

        let result: FirebaseAuthResult<FirebaseAuthUser> = auth.verify_token(test_token).await;
        assert!(
            matches!(result, Err(FirebaseAuthError::JwtError(_))),
            "expected a JwtError, got: {:?}",
            result.as_ref().err()
        );
    }

    #[tokio::test]
    async fn test_verify_google_token() {
        let auth = FirebaseAuth::new("test-project".to_string()).await;

        // Placeholder, not a real JWT: it must be rejected while its header is parsed. To test
        // the success path, use a real ID token from a Google-provider sign-in for the project
        // passed to `new` and assert `is_ok()` instead.
        let test_token = "your.google.test.token";

        let result: FirebaseAuthResult<FirebaseAuthGoogleUser> =
            auth.verify_token(test_token).await;
        assert!(
            matches!(result, Err(FirebaseAuthError::JwtError(_))),
            "expected a JwtError, got: {:?}",
            result.as_ref().err()
        );
    }
}
218}