// actix_firebase_auth/client.rs
use actix_web::rt as actix_rt;
use serde::de::DeserializeOwned;
use std::{
    sync::{Arc, Mutex, PoisonError, RwLock},
    time::Duration,
};
use tracing::{debug, warn};

use crate::jwk::{JwkConfig, JwkKeys, JwkVerifier, KeyResponse, PublicKeysError};

/// Key lifetime (one minute) assumed when the key endpoint's `Cache-Control`
/// header carries no parseable `max-age` directive.
const FALLBACK_TIMEOUT: Duration = Duration::from_millis(60 * 1_000);

16#[derive(Clone)]
22pub struct FirebaseAuth {
23 verifier: Arc<RwLock<JwkVerifier>>,
24 handler: Arc<Mutex<Box<actix_rt::task::JoinHandle<()>>>>,
25}
26
27impl Drop for FirebaseAuth {
28 fn drop(&mut self) {
29 let handler = self.handler.lock().unwrap();
31 handler.abort();
32 }
33}
34
35impl FirebaseAuth {
36 pub async fn new(project_id: impl AsRef<str>) -> crate::Result<Self> {
38 let jwk_keys = Self::get_public_keys().await?;
40
41 let verifier =
42 Arc::new(RwLock::new(JwkVerifier::new(project_id, jwk_keys)));
43 let handler = Arc::new(Mutex::new(Box::new(actix_rt::spawn(async {})))); let mut instance = Self { verifier, handler };
46 instance.start_key_update();
47
48 Ok(instance)
49 }
50
51 pub fn verify<T: DeserializeOwned>(&self, token: &str) -> crate::Result<T> {
53 let verifier = self.verifier.read().unwrap();
54 verifier
55 .verify(token)
56 .map_err(crate::Error::VerificationError)
57 }
58
59 fn start_key_update(&mut self) {
63 let verifier_ref = Arc::clone(&self.verifier);
64
65 let task = actix_rt::spawn(async move {
66 loop {
67 let delay = match Self::get_public_keys().await {
68 Ok(jwk_keys) => {
69 let mut verifier = verifier_ref.write().unwrap();
70 verifier.set_keys(jwk_keys.clone());
71 debug!(
72 "Updated JWK keys. Next refresh in {:?}",
73 jwk_keys.max_age
74 );
75 jwk_keys.max_age
76 }
77 Err(err) => {
78 warn!("Failed to refresh public JWK keys: {:?}", err);
79 warn!("Retrying in 10 seconds...");
80 Duration::from_secs(10)
81 }
82 };
83 actix_rt::time::sleep(delay).await;
84 }
85 });
86
87 let mut handler = self.handler.lock().unwrap();
88 *handler = Box::new(task);
89 }
90
91 pub(crate) async fn get_public_keys() -> crate::Result<JwkKeys> {
93 let response = reqwest::get(JwkConfig::JWK_URL)
94 .await
95 .map_err(PublicKeysError::FetchPublicKeys)?;
96
97 let cache_control = response
98 .headers()
99 .get("Cache-Control")
100 .ok_or(PublicKeysError::MissingCacheControlHeader)?
101 .to_str()
102 .map_err(|_| PublicKeysError::EmptyMaxAgeDirective)?;
103
104 let max_age = Self::parse_max_age_value(cache_control)
105 .unwrap_or(FALLBACK_TIMEOUT);
106
107 let public_keys = response
108 .json::<KeyResponse>()
109 .await
110 .map_err(PublicKeysError::PublicKeyParseError)?;
111
112 Ok(JwkKeys {
113 keys: public_keys.keys,
114 max_age,
115 })
116 }
117
118 pub(crate) fn parse_max_age_value(
120 value: &str,
121 ) -> Result<Duration, PublicKeysError> {
122 for directive in value.split(',') {
123 let mut parts = directive.trim().splitn(2, '=');
124 let key = parts.next().unwrap_or("").trim();
125 let val = parts.next().unwrap_or("").trim();
126
127 if key.eq_ignore_ascii_case("max-age") {
128 let secs = val
129 .parse::<u64>()
130 .map_err(|_| PublicKeysError::InvalidMaxAgeValue)?;
131 return Ok(Duration::from_secs(secs));
132 }
133 }
134
135 Err(PublicKeysError::MissingMaxAgeDirective)
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::{FirebaseAuth, FALLBACK_TIMEOUT};
    use crate::jwk::{JwkKeys, KeyResponse, PublicKeysError};
    use actix_rt::test;
    use httpmock::Method::GET;
    use httpmock::MockServer;
    use serde_json::json;
    use std::{sync::Arc, time::Duration};

    /// Test copy of the fetch-and-parse logic in `FirebaseAuth::get_public_keys`,
    /// pointed at an arbitrary URL so it can be run against a local mock server.
    async fn get_public_keys_from_url(url: &str) -> crate::Result<JwkKeys> {
        let response = reqwest::get(url)
            .await
            .and_then(reqwest::Response::error_for_status)
            .map_err(PublicKeysError::FetchPublicKeys)?;

        let cache_control = response
            .headers()
            .get("Cache-Control")
            .ok_or(PublicKeysError::MissingCacheControlHeader)?
            .to_str()
            .map_err(|_| PublicKeysError::EmptyMaxAgeDirective)?;

        let max_age = FirebaseAuth::parse_max_age_value(cache_control).unwrap_or(FALLBACK_TIMEOUT);

        let public_keys = response
            .json::<KeyResponse>()
            .await
            .map_err(PublicKeysError::PublicKeyParseError)?;

        Ok(JwkKeys {
            keys: public_keys.keys,
            max_age,
        })
    }

    #[test]
    async fn parses_max_age_correctly() {
        let input = "public, max-age=3600, must-revalidate";
        let duration = FirebaseAuth::parse_max_age_value(input).unwrap();
        assert_eq!(duration, Duration::from_secs(3600));
    }

    #[test]
    async fn returns_error_for_missing_max_age() {
        let err = FirebaseAuth::parse_max_age_value("public, no-cache").unwrap_err();
        // `matches!` only yields a bool; it must be asserted to test anything.
        assert!(matches!(err, PublicKeysError::MissingMaxAgeDirective));
    }

    #[test]
    async fn returns_error_for_invalid_max_age() {
        let err = FirebaseAuth::parse_max_age_value("max-age=not_a_number").unwrap_err();
        assert!(matches!(err, PublicKeysError::InvalidMaxAgeValue));
    }

    #[test]
    async fn get_public_keys_successfully_parses_keys() {
        let server = MockServer::start();

        let body = json!({
            "keys": [
                {
                    "kty": "RSA",
                    "alg": "RS256",
                    "use": "sig",
                    "kid": "1234",
                    "n": "modulus",
                    "e": "AQAB"
                }
            ]
        });

        let _mock = server.mock(|when, then| {
            when.method(GET).path("/keys");
            then.status(200)
                .header("Cache-Control", "public, max-age=120")
                .json_body(body);
        });

        let keys = get_public_keys_from_url(&server.url("/keys")).await.unwrap();
        assert_eq!(keys.max_age, Duration::from_secs(120));
        assert_eq!(keys.keys.len(), 1);
    }

    /// Needs network access: `FirebaseAuth::new` downloads the real public keys.
    #[test]
    async fn background_task_aborts_on_drop() {
        let auth = FirebaseAuth::new("dummy-project")
            .await
            .expect("FirebaseAuth failed to build");

        {
            let handler_guard = auth.handler.lock().unwrap();
            assert!(!handler_guard.is_finished(), "Task should be running");
        }

        // The refresh task owns its own clone of `verifier`. Once the task is
        // aborted and its future dropped, no strong reference is left.
        let verifier = Arc::downgrade(&auth.verifier);
        drop(auth);
        actix_web::rt::time::sleep(Duration::from_millis(100)).await;
        assert!(
            verifier.upgrade().is_none(),
            "background task should have been aborted on drop"
        );
    }
}