actix_firebase_auth/
client.rs

1use serde::de::DeserializeOwned;
2use std::{
3    sync::{Arc, Mutex, RwLock},
4    time::Duration,
5};
6use tokio::{task::JoinHandle, time::sleep};
7use tracing::*;
8
9use crate::jwk::{JwkConfig, JwkKeys, JwkVerifier, KeyResponse, PublicKeysError};
10
/// Cache lifetime (one minute) assumed for freshly fetched public keys when the
/// `Cache-Control` response header carries no parseable `max-age` directive.
const FALLBACK_TIMEOUT: Duration = Duration::from_secs(60);
13
14/// FirebaseAuth is responsible for verifying Firebase JWT tokens and keeping the
15/// Google public keys up to date by periodically fetching them.
16///
17/// It uses the Cache-Control `max-age` directive to schedule the next refresh.
18/// If fetching fails, it retries every 10 seconds until successful.
19#[derive(Clone)]
20pub struct FirebaseAuth {
21    verifier: Arc<RwLock<JwkVerifier>>,
22    handler: Arc<Mutex<Box<JoinHandle<()>>>>,
23}
24
25impl Drop for FirebaseAuth {
26    fn drop(&mut self) {
27        // Abort the background task on drop.
28        let handler = self.handler.lock().unwrap();
29        handler.abort();
30    }
31}
32
33impl FirebaseAuth {
34    /// Create a new FirebaseAuth instance with an initial key fetch.
35    ///
36    /// Panics if the initial fetch fails, since the app cannot verify any tokens.
37    pub async fn new(project_id: impl AsRef<str>) -> Self {
38        let jwk_keys = match Self::get_public_keys().await {
39            Ok(keys) => keys,
40            Err(e) => {
41                eprintln!("Error fetching initial public JWK keys: {:?}", e);
42                panic!("Failed to fetch initial public JWK keys. Cannot verify Firebase tokens.");
43            }
44        };
45
46        let verifier = Arc::new(RwLock::new(JwkVerifier::new(project_id, jwk_keys)));
47        let handler = Arc::new(Mutex::new(Box::new(tokio::spawn(async {})))); // placeholder
48
49        let mut instance = Self { verifier, handler };
50        instance.start_key_update();
51        instance
52    }
53
54    /// Verifies a Firebase JWT token and deserializes the payload into type `T`.
55    pub fn verify<T: DeserializeOwned>(&self, token: &str) -> crate::Result<T> {
56        let verifier = self.verifier.read().unwrap();
57        verifier
58            .verify(token)
59            .map_err(crate::Error::VerificationError)
60    }
61
62    /// Spawns a background task to periodically refresh the JWK keys.
63    ///
64    /// If the fetch fails, retries every 10 seconds.
65    fn start_key_update(&mut self) {
66        let verifier_ref = Arc::clone(&self.verifier);
67
68        let task = tokio::spawn(async move {
69            loop {
70                let delay = match Self::get_public_keys().await {
71                    Ok(jwk_keys) => {
72                        let mut verifier = verifier_ref.write().unwrap();
73                        verifier.set_keys(jwk_keys.clone());
74                        debug!("Updated JWK keys. Next refresh in {:?}", jwk_keys.max_age);
75                        jwk_keys.max_age
76                    }
77                    Err(err) => {
78                        warn!("Failed to refresh public JWK keys: {:?}", err);
79                        warn!("Retrying in 10 seconds...");
80                        Duration::from_secs(10)
81                    }
82                };
83                sleep(delay).await;
84            }
85        });
86
87        let mut handler = self.handler.lock().unwrap();
88        *handler = Box::new(task);
89    }
90
91    /// Fetches the latest public keys from the identity provider and parses cache-control headers.
92    pub(crate) async fn get_public_keys() -> crate::Result<JwkKeys> {
93        let response = reqwest::get(JwkConfig::JWK_URL)
94            .await
95            .map_err(PublicKeysError::FetchPublicKeys)?;
96
97        let cache_control = response
98            .headers()
99            .get("Cache-Control")
100            .ok_or(PublicKeysError::MissingCacheControlHeader)?
101            .to_str()
102            .map_err(|_| PublicKeysError::EmptyMaxAgeDirective)?;
103
104        let max_age = Self::parse_max_age_value(cache_control).unwrap_or(FALLBACK_TIMEOUT);
105
106        let public_keys = response
107            .json::<KeyResponse>()
108            .await
109            .map_err(PublicKeysError::PublicKeyParseError)?;
110
111        Ok(JwkKeys {
112            keys: public_keys.keys,
113            max_age,
114        })
115    }
116
117    /// Parses the `max-age` directive from a Cache-Control header string.
118    pub(crate) fn parse_max_age_value(value: &str) -> Result<Duration, PublicKeysError> {
119        for directive in value.split(',') {
120            let mut parts = directive.trim().splitn(2, '=');
121            let key = parts.next().unwrap_or("").trim();
122            let val = parts.next().unwrap_or("").trim();
123
124            if key.eq_ignore_ascii_case("max-age") {
125                let secs = val
126                    .parse::<u64>()
127                    .map_err(|_| PublicKeysError::InvalidMaxAgeValue)?;
128                return Ok(Duration::from_secs(secs));
129            }
130        }
131
132        Err(PublicKeysError::MissingMaxAgeDirective)
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use actix_rt::test;
    use httpmock::Method::GET;
    use httpmock::MockServer;
    use jwk::{JwkKeys, KeyResponse, PublicKeysError};
    use serde_json::json;
    use std::time::Duration;

    use crate::jwk;

    /// Mirrors `FirebaseAuth::get_public_keys` but targets an arbitrary URL, so header
    /// and body parsing can be exercised against a mock server.
    async fn get_public_keys_from_url(url: &str) -> crate::Result<JwkKeys> {
        let response = reqwest::get(url)
            .await
            .map_err(PublicKeysError::FetchPublicKeys)?;

        let cache_control = response
            .headers()
            .get("Cache-Control")
            .ok_or(PublicKeysError::MissingCacheControlHeader)?
            .to_str()
            .map_err(|_| PublicKeysError::EmptyMaxAgeDirective)?;

        let max_age = FirebaseAuth::parse_max_age_value(cache_control).unwrap_or(FALLBACK_TIMEOUT);

        let public_keys = response
            .json::<KeyResponse>()
            .await
            .map_err(PublicKeysError::PublicKeyParseError)?;

        Ok(JwkKeys {
            keys: public_keys.keys,
            max_age,
        })
    }

    #[test]
    async fn parses_max_age_correctly() {
        let input = "public, max-age=3600, must-revalidate";
        let duration = FirebaseAuth::parse_max_age_value(input).unwrap();
        assert_eq!(duration, Duration::from_secs(3600));
    }

    #[test]
    async fn parses_max_age_case_insensitively_and_ignores_whitespace() {
        let input = "public ,  MAX-AGE = 42 ";
        let duration = FirebaseAuth::parse_max_age_value(input).unwrap();
        assert_eq!(duration, Duration::from_secs(42));
    }

    #[test]
    async fn returns_error_for_missing_max_age() {
        let input = "public, no-cache";
        let err = FirebaseAuth::parse_max_age_value(input).unwrap_err();
        // `matches!` only evaluates to a bool; it must be asserted to fail the test.
        assert!(matches!(err, PublicKeysError::MissingMaxAgeDirective));
    }

    #[test]
    async fn returns_error_for_invalid_max_age() {
        let input = "max-age=not_a_number";
        let err = FirebaseAuth::parse_max_age_value(input).unwrap_err();
        assert!(matches!(err, PublicKeysError::InvalidMaxAgeValue));
    }

    #[test]
    async fn get_public_keys_successfully_parses_keys() {
        let server = MockServer::start();

        let body = json!({
            "keys": [
                {
                    "kty": "RSA",
                    "alg": "RS256",
                    "use": "sig",
                    "kid": "1234",
                    "n": "modulus",
                    "e": "AQAB"
                }
            ]
        });

        let _mock = server.mock(|when, then| {
            when.method(GET).path("/keys");
            then.status(200)
                .header("Cache-Control", "public, max-age=120")
                .json_body(body.clone());
        });

        let keys = get_public_keys_from_url(&server.url("/keys"))
            .await
            .unwrap();
        assert_eq!(keys.max_age, Duration::from_secs(120));
        assert_eq!(keys.keys.len(), 1);
    }

    /// Requires network access: `FirebaseAuth::new` fetches the real key set from
    /// `JwkConfig::JWK_URL`.
    #[test]
    async fn background_task_aborts_on_drop() {
        let auth = FirebaseAuth::new("dummy-project").await;

        {
            let handler_guard = auth.handler.lock().unwrap();
            assert!(!handler_guard.is_finished(), "Task should be running");
        }

        // Observe the verifier without keeping it alive ourselves.
        let verifier = std::sync::Arc::downgrade(&auth.verifier);

        drop(auth); // Triggers Drop which aborts the task

        // Give the runtime a moment to cancel the task and drop its future.
        actix_rt::time::sleep(Duration::from_millis(100)).await;

        assert!(
            verifier.upgrade().is_none(),
            "background task should no longer keep the verifier alive after drop"
        );
    }
}