1pub mod structs;
2
3use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
4use reqwest;
5use std::{
6 collections::HashSet,
7 sync::Arc,
8 time::{SystemTime, UNIX_EPOCH},
9};
10use structs::*;
11use tokio::sync::RwLock;
12
13impl FirebaseAuth {
14 pub async fn new(project_id: String) -> Self {
15 let auth = FirebaseAuth {
16 config: FirebaseAuthConfig {
17 project_id,
18 public_keys_url: String::from(
19 "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com",
20 ),
21 },
22 cached_public_keys: Arc::new(RwLock::new(None)),
23 };
24
25 auth.update_public_keys()
26 .await
27 .expect("Initial key fetch failed");
28 auth.start_key_refresh_task();
29 auth
30 }
31
32 fn start_key_refresh_task(&self) {
33 let cached_keys = self.cached_public_keys.clone();
34 let config = self.config.clone();
35
36 tokio::spawn(async move {
37 loop {
38 let next_update = {
39 let keys = cached_keys.read().await;
40 keys.as_ref().map(|state| state.expiry).unwrap_or_else(|| {
41 SystemTime::now()
42 .duration_since(UNIX_EPOCH)
43 .unwrap()
44 .as_secs() as i64
45 })
46 };
47
48 let current_time = SystemTime::now()
49 .duration_since(UNIX_EPOCH)
50 .unwrap()
51 .as_secs() as i64;
52
53 let sleep_duration = if next_update > current_time {
54 ((next_update - current_time) as f64 * 0.9) as u64
55 } else {
56 0
57 };
58
59 tokio::time::sleep(tokio::time::Duration::from_secs(sleep_duration)).await;
60
61 let client = reqwest::Client::new();
62 match Self::fetch_public_keys(&config, &client).await {
63 Ok((keys, expiry)) => {
64 let mut cached = cached_keys.write().await;
65 *cached = Some(SharedState { keys, expiry });
66 }
67 Err(e) => {
68 eprintln!("Failed to update public keys: {:?}", e);
69 tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
70 }
71 }
72 }
73 });
74 }
75
76 async fn fetch_public_keys(
77 config: &FirebaseAuthConfig,
78 client: &reqwest::Client,
79 ) -> FirebaseAuthResult<(PublicKeysResponse, i64)> {
80 let response = client
81 .get(&config.public_keys_url)
82 .send()
83 .await
84 .map_err(|e| FirebaseAuthError::HttpError(e.to_string()))?;
85
86 let cache_control = response
87 .headers()
88 .get("Cache-Control")
89 .and_then(|h| h.to_str().ok())
90 .unwrap_or("max-age=3600");
91
92 let max_age = cache_control
93 .split(',')
94 .find(|&s| s.trim().starts_with("max-age="))
95 .and_then(|s| s.trim().strip_prefix("max-age="))
96 .and_then(|s| s.parse::<i64>().ok())
97 .unwrap_or(3600);
98
99 let keys: PublicKeysResponse = response
100 .json()
101 .await
102 .map_err(|e| FirebaseAuthError::HttpError(e.to_string()))?;
103
104 let expiry = SystemTime::now()
105 .duration_since(UNIX_EPOCH)
106 .unwrap()
107 .as_secs() as i64
108 + max_age;
109
110 Ok((keys, expiry))
111 }
112
113 async fn update_public_keys(&self) -> FirebaseAuthResult<()> {
114 let client = reqwest::Client::new();
115 let (keys, expiry) = Self::fetch_public_keys(&self.config, &client).await?;
116 let mut cached = self.cached_public_keys.write().await;
117 *cached = Some(SharedState { keys, expiry });
118 Ok(())
119 }
120
121 pub async fn verify_token<T>(&self, token: &str) -> FirebaseAuthResult<T>
122 where
123 T: TokenVerifier + serde::de::DeserializeOwned,
124 {
125 let header =
126 decode_header(token).map_err(|e| FirebaseAuthError::JwtError(e.to_string()))?;
127 if header.alg != Algorithm::RS256 {
128 return Err(FirebaseAuthError::InvalidTokenFormat);
129 }
130
131 let kid = header.kid.ok_or(FirebaseAuthError::InvalidTokenFormat)?;
132 let cached_keys = self.cached_public_keys.read().await;
133 let state = cached_keys
134 .as_ref()
135 .ok_or(FirebaseAuthError::InvalidTokenFormat)?;
136
137 let public_key = state
138 .keys
139 .keys
140 .get(&kid)
141 .ok_or(FirebaseAuthError::InvalidSignature)?;
142
143 let mut validation = Validation::new(Algorithm::RS256);
144 let mut iss_set = HashSet::new();
145 iss_set.insert(format!(
146 "https://securetoken.google.com/{}",
147 self.config.project_id
148 ));
149 validation.iss = Some(iss_set);
150
151 let mut aud_set = HashSet::new();
152 aud_set.insert(self.config.project_id.clone());
153 validation.aud = Some(aud_set);
154
155 validation.validate_exp = true;
156 validation.validate_nbf = false;
157 validation.set_required_spec_claims(&["sub"]);
158
159 let token_data = decode::<T>(
160 token,
161 &DecodingKey::from_rsa_pem(public_key.as_bytes())
162 .map_err(|e| FirebaseAuthError::JwtError(e.to_string()))?,
163 &validation,
164 )
165 .map_err(|e| FirebaseAuthError::JwtError(e.to_string()))?;
166
167 let current_time = SystemTime::now()
168 .duration_since(UNIX_EPOCH)
169 .unwrap()
170 .as_secs() as i64;
171
172 token_data
173 .claims
174 .verify(&self.config.project_id, current_time)?;
175
176 Ok(token_data.claims)
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a verifier with an empty key cache.
    ///
    /// Unlike `FirebaseAuth::new`, this performs no network request and spawns no background
    /// refresh task, so tests that only exercise early rejection paths stay offline.
    fn offline_auth() -> FirebaseAuth {
        FirebaseAuth {
            config: FirebaseAuthConfig {
                project_id: "test-project".to_string(),
                public_keys_url: String::from(
                    "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com",
                ),
            },
            cached_public_keys: Arc::new(RwLock::new(None)),
        }
    }

    /// Requires network access to googleapis.com.
    #[tokio::test]
    async fn test_public_key_fetch() {
        let auth = offline_auth();
        let client = reqwest::Client::new();

        let result = FirebaseAuth::fetch_public_keys(&auth.config, &client).await;
        assert!(result.is_ok());

        let (keys, _) = result.unwrap();
        assert!(!keys.keys.is_empty());
    }

    #[tokio::test]
    async fn test_verify_normal_token() {
        let auth = offline_auth();
        let test_token = "your.test.token";

        let result: FirebaseAuthResult<FirebaseAuthUser> = auth.verify_token(test_token).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_verify_google_token() {
        let auth = offline_auth();
        let test_token = "your.google.test.token";

        let result: FirebaseAuthResult<FirebaseAuthGoogleUser> =
            auth.verify_token(test_token).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_rejects_non_rs256_algorithm() {
        let auth = offline_auth();
        // Header is {"alg":"HS256","typ":"JWT"}; payload and signature are never inspected.
        let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.c2ln";

        let result: FirebaseAuthResult<FirebaseAuthUser> = auth.verify_token(token).await;
        assert!(matches!(result, Err(FirebaseAuthError::InvalidTokenFormat)));
    }

    #[tokio::test]
    async fn test_rejects_missing_kid() {
        let auth = offline_auth();
        // Header is {"alg":"RS256","typ":"JWT"}, which carries no `kid`.
        let token = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.c2ln";

        let result: FirebaseAuthResult<FirebaseAuthUser> = auth.verify_token(token).await;
        assert!(matches!(result, Err(FirebaseAuthError::InvalidTokenFormat)));
    }

    #[tokio::test]
    async fn test_rejects_when_no_keys_cached() {
        let auth = offline_auth();
        // Header is {"alg":"RS256","kid":"abc"}; the offline cache holds no key set.
        let token = "eyJhbGciOiJSUzI1NiIsImtpZCI6ImFiYyJ9.e30.c2ln";

        let result: FirebaseAuthResult<FirebaseAuthUser> = auth.verify_token(token).await;
        assert!(matches!(result, Err(FirebaseAuthError::InvalidTokenFormat)));
    }
}