Skip to main content

firebase_admin_sdk/auth/
keys.rs

1use jsonwebtoken::DecodingKey;
2use reqwest::Client;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6use thiserror::Error;
7use tokio::sync::RwLock;
8
/// Endpoint from which [`PublicKeyManager`] downloads its JSON map of
/// key ID (`kid`) → PEM-encoded public key.
const GOOGLE_PUBLIC_KEYS_URL: &str = concat!(
    "https://www.googleapis.com/robot/v1/metadata/x509/",
    "securetoken@system.gserviceaccount.com",
);
11
12#[derive(Error, Debug)]
13pub enum KeyFetchError {
14    #[error("Network error: {0}")]
15    NetworkError(#[from] reqwest::Error),
16    #[error("Failed to parse keys")]
17    ParseError,
18}
19
20#[derive(Clone)]
21struct CachedKeys {
22    keys: HashMap<String, DecodingKey>,
23    expires_at: Instant,
24}
25
26pub struct PublicKeyManager {
27    client: Client,
28    cache: Arc<RwLock<Option<CachedKeys>>>,
29}
30
31impl PublicKeyManager {
32    pub fn new() -> Self {
33        Self {
34            client: Client::new(),
35            cache: Arc::new(RwLock::new(None)),
36        }
37    }
38
39    pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, KeyFetchError> {
40        // Check cache first
41        {
42            let cache = self.cache.read().await;
43            if let Some(cached) = &*cache {
44                if Instant::now() < cached.expires_at {
45                    if let Some(key) = cached.keys.get(kid) {
46                        return Ok(key.clone());
47                    }
48                }
49            }
50        }
51
52        // Fetch new keys
53        self.refresh_keys().await?;
54
55        // Check cache again
56        let cache = self.cache.read().await;
57        if let Some(cached) = &*cache {
58            cached
59                .keys
60                .get(kid)
61                .cloned()
62                .ok_or(KeyFetchError::ParseError)
63        } else {
64            Err(KeyFetchError::ParseError)
65        }
66    }
67
68    async fn refresh_keys(&self) -> Result<(), KeyFetchError> {
69        let response = self.client.get(GOOGLE_PUBLIC_KEYS_URL).send().await?;
70
71        // Parse Cache-Control header
72        let max_age = response
73            .headers()
74            .get(reqwest::header::CACHE_CONTROL)
75            .and_then(|h| h.to_str().ok())
76            .and_then(|s| {
77                s.split(',').find_map(|part| {
78                    let part = part.trim();
79                    if part.starts_with("max-age=") {
80                        part.trim_start_matches("max-age=").parse::<u64>().ok()
81                    } else {
82                        None
83                    }
84                })
85            })
86            .unwrap_or(3600); // Default to 1 hour if missing
87
88        let keys_json: HashMap<String, String> = response.json().await?;
89
90        let mut parsed_keys = HashMap::new();
91        for (kid, pem) in keys_json {
92            let decoding_key =
93                DecodingKey::from_rsa_pem(pem.as_bytes()).map_err(|_| KeyFetchError::ParseError)?;
94            parsed_keys.insert(kid, decoding_key);
95        }
96
97        let mut cache = self.cache.write().await;
98        *cache = Some(CachedKeys {
99            keys: parsed_keys,
100            expires_at: Instant::now() + Duration::from_secs(max_age),
101        });
102
103        Ok(())
104    }
105}