1use std::{sync::Arc, time::Duration};
2
3use serde::de::DeserializeOwned;
4use tokio::{sync::RwLock, task::JoinHandle};
5
6use crate::{
7 jwk::{JwkKeys, JwkKeysError},
8 verifier::{JwkVerifier, VerificationError},
9};
10
/// Delay before the next key fetch whenever no `max_age` is known: either the
/// caller supplied none, the last fetch reported none, or the last fetch failed.
const FALLBACK_TIMEOUT: Duration = Duration::from_millis(10_000);
12
13pub struct FirebaseAuth {
14 verifier: Arc<RwLock<JwkVerifier>>,
15 handler: JoinHandle<()>,
16}
17
18impl FirebaseAuth {
19 pub async fn new(project_id: &str) -> Result<Self, JwkKeysError> {
20 let jwk_keys = JwkKeys::fetch_keys().await?;
21
22 let verifier = Arc::new(RwLock::new(JwkVerifier::new(project_id, jwk_keys.keys)));
23
24 let handler = keep_key_updated(verifier.clone(), jwk_keys.max_age);
25
26 Ok(FirebaseAuth { verifier, handler })
27 }
28
29 pub async fn verify<T: DeserializeOwned>(&self, token: &str) -> Result<T, VerificationError> {
30 let verifier = self.verifier.read().await;
31 verifier.verify(token)
32 }
33}
34
35impl Drop for FirebaseAuth {
36 fn drop(&mut self) {
37 self.handler.abort();
38 }
39}
40
41fn keep_key_updated(
42 verifier: Arc<RwLock<JwkVerifier>>,
43 mut delay: Option<Duration>,
44) -> JoinHandle<()> {
45 tokio::spawn(async move {
46 loop {
47 let sleep = delay.unwrap_or(FALLBACK_TIMEOUT);
48 tracing::debug!("Fetcher sleeps {:?}", sleep);
49 tokio::time::sleep(sleep).await;
50
51 delay = match JwkKeys::fetch_keys().await {
52 Ok(jwk_keys) => {
53 let mut verifier = verifier.write().await;
54 verifier.set_keys(jwk_keys.keys);
55 tracing::debug!(
56 "Updated JWK keys. Next refresh will be in {:?}",
57 jwk_keys.max_age
58 );
59 jwk_keys.max_age
60 }
61 Err(err) => {
62 tracing::error!("Update JWK Keys Error {:?}", err);
63 None
64 }
65 };
66 }
67 })
68}