firebase_verifyid/jwk_cache/mod.rs

1use super::{Error, Settings, TokenVerifier};
2use jwt_simple::algorithms::RS256PublicKey;
3use std::future::Future;
4use std::{collections::HashMap, time::Duration};
5use tokio::sync::watch;
6
7mod base64_serde;
8mod jwk_set;
9
10pub struct JwkCache {
11    client: reqwest::Client,
12    jwks: watch::Sender<HashMap<String, RS256PublicKey>>,
13    url: String,
14    cache_duration: Duration,
15}
16
17impl JwkCache {
18    pub async fn new(settings: Settings) -> Result<(TokenVerifier, Self), Error> {
19        let client = reqwest::Client::new();
20        let url = settings.url.clone();
21        let (jwks, max_age) = jwk_set::fetch_key_set(&client, &url).await?;
22        let (sender, receiver) = watch::channel(jwks);
23        let cache = Self {
24            client,
25            jwks: sender,
26            url,
27            cache_duration: max_age,
28        };
29        let verifier = TokenVerifier::new(receiver, settings)?;
30        Ok((verifier, cache))
31    }
32
33    pub async fn run<F>(mut self, mut shutdown: F)
34    where
35        F: Future<Output = ()> + Send + Unpin + 'static,
36    {
37        tracing::info!("starting firebase auth id token jwk cache");
38
39        loop {
40            tokio::select! {
41                _ = &mut shutdown => break,
42                _ = tokio::time::sleep(self.cache_duration) => {
43                    let new_cache_duration = self
44                        .refresh_key_set()
45                        .await
46                        .inspect_err(|err| tracing::error!(?err, "failure to refresh firebase auth token verifying public keys"))
47                        .unwrap_or(Duration::from_secs(60));
48                    self.cache_duration = new_cache_duration
49                }
50            }
51        }
52
53        tracing::info!("stopping firebase auth id token jwk cache");
54    }
55
56    async fn refresh_key_set(&mut self) -> Result<Duration, Error> {
57        let (new_jwks, new_max_age) = jwk_set::fetch_key_set(&self.client, &self.url).await?;
58        self.jwks.send_replace(new_jwks);
59        Ok(new_max_age)
60    }
61}