actix_firebase_auth/
client.rs

1use actix_web::rt as actix_rt;
2use serde::de::DeserializeOwned;
3use std::{
4    sync::{Arc, Mutex, RwLock},
5    time::Duration,
6};
7use tracing::{debug, warn};
8
9use crate::jwk::{
10    JwkConfig, JwkKeys, JwkVerifier, KeyResponse, PublicKeysError,
11};
12
/// Refresh interval used when the Cache-Control header yields no usable
/// `max-age` value (the directive is missing or cannot be parsed).
const FALLBACK_TIMEOUT: Duration = Duration::from_secs(60);
15
/// `FirebaseAuth` is responsible for verifying Firebase JWT tokens and keeping the
/// Google public keys up to date by periodically fetching them.
///
/// It uses the Cache-Control `max-age` directive to schedule the next refresh.
/// If fetching fails, it retries every 10 seconds until successful.
///
/// The type is `Clone`; both fields are reference-counted, so every clone
/// shares the same verifier and the same background task handle.
#[derive(Clone)]
pub struct FirebaseAuth {
    /// Verifier holding the current key set; read by `verify` and rewritten
    /// by the background refresh task.
    verifier: Arc<RwLock<JwkVerifier>>,
    /// Join handle of the background refresh task, kept so the task can be
    /// aborted from the `Drop` implementation.
    handler: Arc<Mutex<Box<actix_rt::task::JoinHandle<()>>>>,
}
26
27impl Drop for FirebaseAuth {
28    fn drop(&mut self) {
29        // Abort the background task on drop.
30        let handler = self.handler.lock().unwrap();
31        handler.abort();
32    }
33}
34
35impl FirebaseAuth {
36    /// Create a new `FirebaseAuth` instance with an initial key fetch.
37    pub async fn new(project_id: impl AsRef<str>) -> crate::Result<Self> {
38        // Fetch the initial set of public keys
39        let jwk_keys = Self::get_public_keys().await?;
40
41        let verifier =
42            Arc::new(RwLock::new(JwkVerifier::new(project_id, jwk_keys)));
43        let handler = Arc::new(Mutex::new(Box::new(actix_rt::spawn(async {})))); // placeholder
44
45        let mut instance = Self { verifier, handler };
46        instance.start_key_update();
47
48        Ok(instance)
49    }
50
51    /// Verifies a Firebase JWT token and deserializes the payload into type `T`.
52    pub fn verify<T: DeserializeOwned>(&self, token: &str) -> crate::Result<T> {
53        let verifier = self.verifier.read().unwrap();
54        verifier
55            .verify(token)
56            .map_err(crate::Error::VerificationError)
57    }
58
59    /// Spawns a background task to periodically refresh the JWK keys.
60    ///
61    /// If the fetch fails, retries every 10 seconds.
62    fn start_key_update(&mut self) {
63        let verifier_ref = Arc::clone(&self.verifier);
64
65        let task = actix_rt::spawn(async move {
66            loop {
67                let delay = match Self::get_public_keys().await {
68                    Ok(jwk_keys) => {
69                        let mut verifier = verifier_ref.write().unwrap();
70                        verifier.set_keys(jwk_keys.clone());
71                        debug!(
72                            "Updated JWK keys. Next refresh in {:?}",
73                            jwk_keys.max_age
74                        );
75                        jwk_keys.max_age
76                    }
77                    Err(err) => {
78                        warn!("Failed to refresh public JWK keys: {:?}", err);
79                        warn!("Retrying in 10 seconds...");
80                        Duration::from_secs(10)
81                    }
82                };
83                actix_rt::time::sleep(delay).await;
84            }
85        });
86
87        let mut handler = self.handler.lock().unwrap();
88        *handler = Box::new(task);
89    }
90
91    /// Fetches the latest public keys from the identity provider and parses cache-control headers.
92    pub(crate) async fn get_public_keys() -> crate::Result<JwkKeys> {
93        let response = reqwest::get(JwkConfig::JWK_URL)
94            .await
95            .map_err(PublicKeysError::FetchPublicKeys)?;
96
97        let cache_control = response
98            .headers()
99            .get("Cache-Control")
100            .ok_or(PublicKeysError::MissingCacheControlHeader)?
101            .to_str()
102            .map_err(|_| PublicKeysError::EmptyMaxAgeDirective)?;
103
104        let max_age = Self::parse_max_age_value(cache_control)
105            .unwrap_or(FALLBACK_TIMEOUT);
106
107        let public_keys = response
108            .json::<KeyResponse>()
109            .await
110            .map_err(PublicKeysError::PublicKeyParseError)?;
111
112        Ok(JwkKeys {
113            keys: public_keys.keys,
114            max_age,
115        })
116    }
117
118    /// Parses the `max-age` directive from a Cache-Control header string.
119    pub(crate) fn parse_max_age_value(
120        value: &str,
121    ) -> Result<Duration, PublicKeysError> {
122        for directive in value.split(',') {
123            let mut parts = directive.trim().splitn(2, '=');
124            let key = parts.next().unwrap_or("").trim();
125            let val = parts.next().unwrap_or("").trim();
126
127            if key.eq_ignore_ascii_case("max-age") {
128                let secs = val
129                    .parse::<u64>()
130                    .map_err(|_| PublicKeysError::InvalidMaxAgeValue)?;
131                return Ok(Duration::from_secs(secs));
132            }
133        }
134
135        Err(PublicKeysError::MissingMaxAgeDirective)
136    }
137}
138
139#[cfg(test)]
140mod tests {
141    use super::{FirebaseAuth, FALLBACK_TIMEOUT};
142    use actix_rt::test;
143    use httpmock::Method::GET;
144    use httpmock::MockServer;
145    use jwk::{JwkKeys, KeyResponse, PublicKeysError};
146    use serde_json::json;
147    use std::time::Duration;
148
149    use crate::jwk;
150
151    async fn get_public_keys_from_url(url: &str) -> crate::Result<JwkKeys> {
152        let response = reqwest::get(url)
153            .await
154            .map_err(PublicKeysError::FetchPublicKeys)?;
155
156        let cache_control = response
157            .headers()
158            .get("Cache-Control")
159            .ok_or(PublicKeysError::MissingCacheControlHeader)?
160            .to_str()
161            .map_err(|_| PublicKeysError::EmptyMaxAgeDirective)?;
162
163        let max_age = FirebaseAuth::parse_max_age_value(cache_control)
164            .unwrap_or(FALLBACK_TIMEOUT);
165
166        let public_keys = response
167            .json::<KeyResponse>()
168            .await
169            .map_err(PublicKeysError::PublicKeyParseError)?;
170
171        Ok(JwkKeys {
172            keys: public_keys.keys,
173            max_age,
174        })
175    }
176
177    #[test]
178    async fn parses_max_age_correctly() {
179        let input = "public, max-age=3600, must-revalidate";
180        let duration = FirebaseAuth::parse_max_age_value(input).unwrap();
181        assert_eq!(duration, Duration::from_secs(3600));
182    }
183
184    #[test]
185    async fn returns_error_for_missing_max_age() {
186        let input = "public, no-cache";
187        let err = FirebaseAuth::parse_max_age_value(input).unwrap_err();
188        matches!(err, PublicKeysError::MissingMaxAgeDirective);
189    }
190
191    #[test]
192    async fn returns_error_for_invalid_max_age() {
193        let input = "max-age=not_a_number";
194        let err = FirebaseAuth::parse_max_age_value(input).unwrap_err();
195        matches!(err, PublicKeysError::InvalidMaxAgeValue);
196    }
197
198    #[test]
199    async fn get_public_keys_successfully_parses_keys() {
200        let server = MockServer::start();
201
202        let body = json!({
203            "keys": [
204                {
205                    "kty": "RSA",
206                    "alg": "RS256",
207                    "use": "sig",
208                    "kid": "1234",
209                    "n": "modulus",
210                    "e": "AQAB"
211                }
212            ]
213        });
214
215        let _mock = server.mock(|when, then| {
216            when.method(GET).path("/keys");
217            then.status(200)
218                .header("Cache-Control", "public, max-age=120")
219                .json_body(body.clone());
220        });
221
222        let keys = get_public_keys_from_url(&server.url("/keys"))
223            .await
224            .unwrap();
225        assert_eq!(keys.max_age, Duration::from_secs(120));
226        assert_eq!(keys.keys.len(), 1);
227    }
228
229    #[test]
230    async fn background_task_aborts_on_drop() {
231        let auth = FirebaseAuth::new("dummy-project").await;
232        assert!(auth.is_ok(), "FirebaseAuth failed to build");
233        let auth = auth.unwrap();
234
235        {
236            let handler_guard = auth.handler.lock().unwrap();
237            assert!(!handler_guard.is_finished(), "Task should be running");
238        }
239
240        drop(auth); // Triggers Drop which aborts task
241
242        // Give a moment for task abort to propagate
243        actix_web::rt::time::sleep(Duration::from_millis(100)).await;
244    }
245}