actix_firebase_auth/
client.rs1use serde::de::DeserializeOwned;
2use std::{
3 sync::{Arc, Mutex, RwLock},
4 time::Duration,
5};
6use tokio::{task::JoinHandle, time::sleep};
7use tracing::*;
8
9use crate::jwk::{JwkConfig, JwkKeys, JwkVerifier, KeyResponse, PublicKeysError};
10
11const FALLBACK_TIMEOUT: Duration = Duration::from_secs(60);
13
14#[derive(Clone)]
20pub struct FirebaseAuth {
21 verifier: Arc<RwLock<JwkVerifier>>,
22 handler: Arc<Mutex<Box<JoinHandle<()>>>>,
23}
24
25impl Drop for FirebaseAuth {
26 fn drop(&mut self) {
27 let handler = self.handler.lock().unwrap();
29 handler.abort();
30 }
31}
32
33impl FirebaseAuth {
34 pub async fn new(project_id: impl AsRef<str>) -> Self {
38 let jwk_keys = match Self::get_public_keys().await {
39 Ok(keys) => keys,
40 Err(e) => {
41 eprintln!("Error fetching initial public JWK keys: {:?}", e);
42 panic!("Failed to fetch initial public JWK keys. Cannot verify Firebase tokens.");
43 }
44 };
45
46 let verifier = Arc::new(RwLock::new(JwkVerifier::new(project_id, jwk_keys)));
47 let handler = Arc::new(Mutex::new(Box::new(tokio::spawn(async {})))); let mut instance = Self { verifier, handler };
50 instance.start_key_update();
51 instance
52 }
53
54 pub fn verify<T: DeserializeOwned>(&self, token: &str) -> crate::Result<T> {
56 let verifier = self.verifier.read().unwrap();
57 verifier
58 .verify(token)
59 .map_err(crate::Error::VerificationError)
60 }
61
62 fn start_key_update(&mut self) {
66 let verifier_ref = Arc::clone(&self.verifier);
67
68 let task = tokio::spawn(async move {
69 loop {
70 let delay = match Self::get_public_keys().await {
71 Ok(jwk_keys) => {
72 let mut verifier = verifier_ref.write().unwrap();
73 verifier.set_keys(jwk_keys.clone());
74 debug!("Updated JWK keys. Next refresh in {:?}", jwk_keys.max_age);
75 jwk_keys.max_age
76 }
77 Err(err) => {
78 warn!("Failed to refresh public JWK keys: {:?}", err);
79 warn!("Retrying in 10 seconds...");
80 Duration::from_secs(10)
81 }
82 };
83 sleep(delay).await;
84 }
85 });
86
87 let mut handler = self.handler.lock().unwrap();
88 *handler = Box::new(task);
89 }
90
91 pub(crate) async fn get_public_keys() -> crate::Result<JwkKeys> {
93 let response = reqwest::get(JwkConfig::JWK_URL)
94 .await
95 .map_err(PublicKeysError::FetchPublicKeys)?;
96
97 let cache_control = response
98 .headers()
99 .get("Cache-Control")
100 .ok_or(PublicKeysError::MissingCacheControlHeader)?
101 .to_str()
102 .map_err(|_| PublicKeysError::EmptyMaxAgeDirective)?;
103
104 let max_age = Self::parse_max_age_value(cache_control).unwrap_or(FALLBACK_TIMEOUT);
105
106 let public_keys = response
107 .json::<KeyResponse>()
108 .await
109 .map_err(PublicKeysError::PublicKeyParseError)?;
110
111 Ok(JwkKeys {
112 keys: public_keys.keys,
113 max_age,
114 })
115 }
116
117 pub(crate) fn parse_max_age_value(value: &str) -> Result<Duration, PublicKeysError> {
119 for directive in value.split(',') {
120 let mut parts = directive.trim().splitn(2, '=');
121 let key = parts.next().unwrap_or("").trim();
122 let val = parts.next().unwrap_or("").trim();
123
124 if key.eq_ignore_ascii_case("max-age") {
125 let secs = val
126 .parse::<u64>()
127 .map_err(|_| PublicKeysError::InvalidMaxAgeValue)?;
128 return Ok(Duration::from_secs(secs));
129 }
130 }
131
132 Err(PublicKeysError::MissingMaxAgeDirective)
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use actix_rt::test;
    use httpmock::Method::GET;
    use httpmock::MockServer;
    use serde_json::json;
    use std::time::Duration;

    use crate::jwk::{JwkKeys, KeyResponse, PublicKeysError};

    /// Test double for `FirebaseAuth::get_public_keys` that fetches from an
    /// arbitrary `url`, so the header and body handling can be exercised
    /// against a local mock server.
    async fn get_public_keys_from_url(url: &str) -> crate::Result<JwkKeys> {
        let response = reqwest::get(url)
            .await
            .map_err(PublicKeysError::FetchPublicKeys)?;

        let cache_control = response
            .headers()
            .get("Cache-Control")
            .ok_or(PublicKeysError::MissingCacheControlHeader)?
            .to_str()
            .map_err(|_| PublicKeysError::EmptyMaxAgeDirective)?;

        let max_age = FirebaseAuth::parse_max_age_value(cache_control).unwrap_or(FALLBACK_TIMEOUT);

        let public_keys = response
            .json::<KeyResponse>()
            .await
            .map_err(PublicKeysError::PublicKeyParseError)?;

        Ok(JwkKeys {
            keys: public_keys.keys,
            max_age,
        })
    }

    #[test]
    async fn parses_max_age_correctly() {
        let input = "public, max-age=3600, must-revalidate";
        let duration = FirebaseAuth::parse_max_age_value(input).unwrap();
        assert_eq!(duration, Duration::from_secs(3600));
    }

    #[test]
    async fn returns_error_for_missing_max_age() {
        let input = "public, no-cache";
        let err = FirebaseAuth::parse_max_age_value(input).unwrap_err();
        // `matches!` only yields a bool; without `assert!` the test could
        // never fail.
        assert!(
            matches!(err, PublicKeysError::MissingMaxAgeDirective),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    async fn returns_error_for_invalid_max_age() {
        let input = "max-age=not_a_number";
        let err = FirebaseAuth::parse_max_age_value(input).unwrap_err();
        assert!(
            matches!(err, PublicKeysError::InvalidMaxAgeValue),
            "unexpected error: {:?}",
            err
        );
    }

    #[test]
    async fn get_public_keys_successfully_parses_keys() {
        let server = MockServer::start();

        let body = json!({
            "keys": [
                {
                    "kty": "RSA",
                    "alg": "RS256",
                    "use": "sig",
                    "kid": "1234",
                    "n": "modulus",
                    "e": "AQAB"
                }
            ]
        });

        let _mock = server.mock(|when, then| {
            when.method(GET).path("/keys");
            then.status(200)
                .header("Cache-Control", "public, max-age=120")
                .json_body(body.clone());
        });

        let keys = get_public_keys_from_url(&server.url("/keys"))
            .await
            .unwrap();
        assert_eq!(keys.max_age, Duration::from_secs(120));
        assert_eq!(keys.keys.len(), 1);
    }

    /// Smoke test: the refresh task is running after construction and
    /// dropping the instance does not panic.
    ///
    /// NOTE(review): `FirebaseAuth::new` fetches from `JwkConfig::JWK_URL`, so
    /// this test depends on that endpoint being reachable. It also does not
    /// observe the task after the drop, so the abort itself is not asserted.
    #[test]
    async fn background_task_aborts_on_drop() {
        let auth = FirebaseAuth::new("dummy-project").await;

        {
            let handler_guard = auth.handler.lock().unwrap();
            assert!(!handler_guard.is_finished(), "Task should be running");
        }

        drop(auth);

        // Give the runtime a moment to process the abort.
        actix_rt::time::sleep(Duration::from_millis(100)).await;
    }
}