1use std::{sync::Arc, time::Duration};
2
3use serde::Deserialize;
4use tokio::{sync::RwLock, task::JoinHandle};
5
6use crate::{
7 jwk::{JwkKeys, JwkKeysError},
8 verifier::{JwkVerifier, VerificationError},
9};
10
/// Refresh delay used when no server-provided delay is available
/// (e.g. the last key fetch failed or did not report a `max_age`).
const FALLBACK_TIMEOUT: Duration = Duration::from_millis(10_000);
12
13#[derive(Debug, Deserialize)]
14pub struct Claims {
15 pub aud: String,
16 pub exp: i64,
17 pub iss: String,
18 pub sub: String,
19 pub iat: i64,
20}
21
22pub struct FirebaseAuth {
23 verifier: Arc<RwLock<JwkVerifier>>,
24 handler: JoinHandle<()>,
25}
26
27impl FirebaseAuth {
28 pub async fn new(project_id: &str) -> Result<Self, JwkKeysError> {
29 let jwk_keys = JwkKeys::fetch_keys().await?;
30
31 let verifier = Arc::new(RwLock::new(JwkVerifier::new(project_id, jwk_keys.keys)));
32
33 let handler = keep_key_updated(verifier.clone(), jwk_keys.max_age);
34
35 Ok(FirebaseAuth { verifier, handler })
36 }
37
38 pub async fn verify(&self, token: &str) -> Result<Claims, VerificationError> {
39 let verifier = self.verifier.read().await;
40 verifier.verify(token)
41 }
42}
43
44impl Drop for FirebaseAuth {
45 fn drop(&mut self) {
46 self.handler.abort();
47 }
48}
49
50fn keep_key_updated(
51 verifier: Arc<RwLock<JwkVerifier>>,
52 mut delay: Option<Duration>,
53) -> JoinHandle<()> {
54 tokio::spawn(async move {
55 loop {
56 let sleep = delay.unwrap_or(FALLBACK_TIMEOUT);
57 tracing::debug!("Fetcher sleeps {:?}", sleep);
58 tokio::time::sleep(sleep).await;
59
60 delay = match JwkKeys::fetch_keys().await {
61 Ok(jwk_keys) => {
62 let mut verifier = verifier.write().await;
63 verifier.set_keys(jwk_keys.keys);
64 tracing::debug!(
65 "Updated JWK keys. Next refresh will be in {:?}",
66 jwk_keys.max_age
67 );
68 jwk_keys.max_age
69 }
70 Err(err) => {
71 tracing::error!("Update JWK Keys Error {:?}", err);
72 None
73 }
74 };
75 }
76 })
77}