actix_firebase_auth/
client.rs

1use actix_web::rt as actix_rt;
2use serde::de::DeserializeOwned;
3use std::{
4    sync::{Arc, Mutex, RwLock},
5    time::Duration,
6};
7use tracing::*;
8
9use crate::jwk::{JwkConfig, JwkKeys, JwkVerifier, KeyResponse, PublicKeysError};
10
/// Refresh interval used when the key endpoint's `Cache-Control` header has no
/// usable `max-age` directive (missing or not a valid number of seconds).
const FALLBACK_TIMEOUT: Duration = Duration::from_secs(60);
13
14/// FirebaseAuth is responsible for verifying Firebase JWT tokens and keeping the
15/// Google public keys up to date by periodically fetching them.
16///
17/// It uses the Cache-Control `max-age` directive to schedule the next refresh.
18/// If fetching fails, it retries every 10 seconds until successful.
19#[derive(Clone)]
20pub struct FirebaseAuth {
21    verifier: Arc<RwLock<JwkVerifier>>,
22    handler: Arc<Mutex<Box<actix_rt::task::JoinHandle<()>>>>,
23}
24
25impl Drop for FirebaseAuth {
26    fn drop(&mut self) {
27        // Abort the background task on drop.
28        let handler = self.handler.lock().unwrap();
29        handler.abort();
30    }
31}
32
33impl FirebaseAuth {
34    /// Create a new FirebaseAuth instance with an initial key fetch.
35    pub async fn new(project_id: impl AsRef<str>) -> crate::Result<Self> {
36        // Fetch the initial set of public keys
37        let jwk_keys = Self::get_public_keys().await?;
38
39        let verifier = Arc::new(RwLock::new(JwkVerifier::new(project_id, jwk_keys)));
40        let handler = Arc::new(Mutex::new(Box::new(actix_rt::spawn(async {})))); // placeholder
41
42        let mut instance = Self { verifier, handler };
43        instance.start_key_update();
44
45        Ok(instance)
46    }
47
48    /// Verifies a Firebase JWT token and deserializes the payload into type `T`.
49    pub fn verify<T: DeserializeOwned>(&self, token: &str) -> crate::Result<T> {
50        let verifier = self.verifier.read().unwrap();
51        verifier
52            .verify(token)
53            .map_err(crate::Error::VerificationError)
54    }
55
56    /// Spawns a background task to periodically refresh the JWK keys.
57    ///
58    /// If the fetch fails, retries every 10 seconds.
59    fn start_key_update(&mut self) {
60        let verifier_ref = Arc::clone(&self.verifier);
61
62        let task = actix_rt::spawn(async move {
63            loop {
64                let delay = match Self::get_public_keys().await {
65                    Ok(jwk_keys) => {
66                        let mut verifier = verifier_ref.write().unwrap();
67                        verifier.set_keys(jwk_keys.clone());
68                        debug!("Updated JWK keys. Next refresh in {:?}", jwk_keys.max_age);
69                        jwk_keys.max_age
70                    }
71                    Err(err) => {
72                        warn!("Failed to refresh public JWK keys: {:?}", err);
73                        warn!("Retrying in 10 seconds...");
74                        Duration::from_secs(10)
75                    }
76                };
77                actix_rt::time::sleep(delay).await;
78            }
79        });
80
81        let mut handler = self.handler.lock().unwrap();
82        *handler = Box::new(task);
83    }
84
85    /// Fetches the latest public keys from the identity provider and parses cache-control headers.
86    pub(crate) async fn get_public_keys() -> crate::Result<JwkKeys> {
87        let response = reqwest::get(JwkConfig::JWK_URL)
88            .await
89            .map_err(PublicKeysError::FetchPublicKeys)?;
90
91        let cache_control = response
92            .headers()
93            .get("Cache-Control")
94            .ok_or(PublicKeysError::MissingCacheControlHeader)?
95            .to_str()
96            .map_err(|_| PublicKeysError::EmptyMaxAgeDirective)?;
97
98        let max_age = Self::parse_max_age_value(cache_control).unwrap_or(FALLBACK_TIMEOUT);
99
100        let public_keys = response
101            .json::<KeyResponse>()
102            .await
103            .map_err(PublicKeysError::PublicKeyParseError)?;
104
105        Ok(JwkKeys {
106            keys: public_keys.keys,
107            max_age,
108        })
109    }
110
111    /// Parses the `max-age` directive from a Cache-Control header string.
112    pub(crate) fn parse_max_age_value(value: &str) -> Result<Duration, PublicKeysError> {
113        for directive in value.split(',') {
114            let mut parts = directive.trim().splitn(2, '=');
115            let key = parts.next().unwrap_or("").trim();
116            let val = parts.next().unwrap_or("").trim();
117
118            if key.eq_ignore_ascii_case("max-age") {
119                let secs = val
120                    .parse::<u64>()
121                    .map_err(|_| PublicKeysError::InvalidMaxAgeValue)?;
122                return Ok(Duration::from_secs(secs));
123            }
124        }
125
126        Err(PublicKeysError::MissingMaxAgeDirective)
127    }
128}
129
#[cfg(test)]
mod tests {
    use super::{FALLBACK_TIMEOUT, FirebaseAuth};
    use actix_rt::test;
    use httpmock::Method::GET;
    use httpmock::MockServer;
    use jwk::{JwkKeys, KeyResponse, PublicKeysError};
    use serde_json::json;
    use std::sync::Arc;
    use std::time::Duration;

    use crate::jwk;

    /// Test-local copy of the key-fetching logic that targets an arbitrary URL,
    /// so it can be exercised against a mock server instead of Google.
    async fn get_public_keys_from_url(url: &str) -> crate::Result<JwkKeys> {
        let response = reqwest::get(url)
            .await
            .map_err(PublicKeysError::FetchPublicKeys)?;

        let cache_control = response
            .headers()
            .get("Cache-Control")
            .ok_or(PublicKeysError::MissingCacheControlHeader)?
            .to_str()
            .map_err(|_| PublicKeysError::EmptyMaxAgeDirective)?;

        let max_age = FirebaseAuth::parse_max_age_value(cache_control).unwrap_or(FALLBACK_TIMEOUT);

        let public_keys = response
            .json::<KeyResponse>()
            .await
            .map_err(PublicKeysError::PublicKeyParseError)?;

        Ok(JwkKeys {
            keys: public_keys.keys,
            max_age,
        })
    }

    #[test]
    async fn parses_max_age_correctly() {
        let input = "public, max-age=3600, must-revalidate";
        let duration = FirebaseAuth::parse_max_age_value(input).unwrap();
        assert_eq!(duration, Duration::from_secs(3600));
    }

    #[test]
    async fn parses_max_age_case_insensitively_with_spaces() {
        let input = "Public, MAX-AGE = 45";
        let duration = FirebaseAuth::parse_max_age_value(input).unwrap();
        assert_eq!(duration, Duration::from_secs(45));
    }

    #[test]
    async fn returns_error_for_missing_max_age() {
        let input = "public, no-cache";
        let err = FirebaseAuth::parse_max_age_value(input).unwrap_err();
        // `matches!` only yields a bool; it must be asserted to have any effect.
        assert!(matches!(err, PublicKeysError::MissingMaxAgeDirective));
    }

    #[test]
    async fn returns_error_for_invalid_max_age() {
        let input = "max-age=not_a_number";
        let err = FirebaseAuth::parse_max_age_value(input).unwrap_err();
        assert!(matches!(err, PublicKeysError::InvalidMaxAgeValue));
    }

    #[test]
    async fn get_public_keys_successfully_parses_keys() {
        let server = MockServer::start();

        let body = json!({
            "keys": [
                {
                    "kty": "RSA",
                    "alg": "RS256",
                    "use": "sig",
                    "kid": "1234",
                    "n": "modulus",
                    "e": "AQAB"
                }
            ]
        });

        let _mock = server.mock(|when, then| {
            when.method(GET).path("/keys");
            then.status(200)
                .header("Cache-Control", "public, max-age=120")
                .json_body(body.clone());
        });

        let keys = get_public_keys_from_url(&server.url("/keys"))
            .await
            .unwrap();
        assert_eq!(keys.max_age, Duration::from_secs(120));
        assert_eq!(keys.keys.len(), 1);
    }

    /// Requires network access: `FirebaseAuth::new` fetches the real public keys.
    #[test]
    async fn background_task_aborts_on_drop() {
        let auth = FirebaseAuth::new("dummy-project")
            .await
            .expect("FirebaseAuth failed to build");

        {
            let handler_guard = auth.handler.lock().unwrap();
            assert!(!handler_guard.is_finished(), "Task should be running");
        }

        // The refresh task owns a clone of `verifier`. Once the task is aborted its
        // future is dropped, so no strong references to the verifier remain.
        let verifier = Arc::downgrade(&auth.verifier);

        drop(auth); // Triggers Drop, which aborts the task.

        // Give the runtime a moment to process the abort and drop the task's future.
        actix_web::rt::time::sleep(Duration::from_millis(100)).await;

        assert!(
            verifier.upgrade().is_none(),
            "background task still holds the verifier after drop"
        );
    }
}