// firebase_admin_sdk/auth/keys.rs

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use reqwest::Client;
use thiserror::Error;
use tokio::sync::RwLock;

/// Endpoint serving Google's public signing keys (X.509 certificates, per the URL path).
///
/// The response body is a JSON object mapping key IDs (`kid`) to certificate strings;
/// `PublicKeyManager::refresh_keys` decodes it as `HashMap<String, String>`.
const GOOGLE_PUBLIC_KEYS_URL: &str = concat!(
    "https://www.googleapis.com/robot/v1/metadata/x509/",
    "securetoken@system.gserviceaccount.com"
);

10#[derive(Error, Debug)]
11pub enum KeyFetchError {
12    #[error("Network error: {0}")]
13    NetworkError(#[from] reqwest::Error),
14    #[error("Failed to parse keys")]
15    ParseError,
16}
17
/// A downloaded key set together with the instant after which it is considered stale.
#[derive(Clone, Debug)]
struct CachedKeys {
    /// Key ID (`kid`) -> key/certificate string, exactly as returned by the key endpoint.
    keys: HashMap<String, String>,
    /// Lookups at or after this instant treat the cache as expired and trigger a refresh.
    expires_at: Instant,
}

24pub struct PublicKeyManager {
25    client: Client,
26    cache: Arc<RwLock<Option<CachedKeys>>>,
27}
28
29impl PublicKeyManager {
30    pub fn new() -> Self {
31        Self {
32            client: Client::new(),
33            cache: Arc::new(RwLock::new(None)),
34        }
35    }
36
37    pub async fn get_key(&self, kid: &str) -> Result<String, KeyFetchError> {
38        // Check cache first
39        {
40            let cache = self.cache.read().await;
41            if let Some(cached) = &*cache {
42                if Instant::now() < cached.expires_at {
43                    if let Some(key) = cached.keys.get(kid) {
44                        return Ok(key.clone());
45                    }
46                }
47            }
48        }
49
50        // Fetch new keys
51        self.refresh_keys().await?;
52
53        // Check cache again
54        let cache = self.cache.read().await;
55        if let Some(cached) = &*cache {
56            cached.keys.get(kid).cloned().ok_or(KeyFetchError::ParseError)
57        } else {
58            Err(KeyFetchError::ParseError)
59        }
60    }
61
62    async fn refresh_keys(&self) -> Result<(), KeyFetchError> {
63        let response = self.client.get(GOOGLE_PUBLIC_KEYS_URL).send().await?;
64
65        // Parse Cache-Control header
66        let max_age = response.headers()
67            .get(reqwest::header::CACHE_CONTROL)
68            .and_then(|h| h.to_str().ok())
69            .and_then(|s| {
70                s.split(',')
71                    .find_map(|part| {
72                        let part = part.trim();
73                        if part.starts_with("max-age=") {
74                            part.trim_start_matches("max-age=").parse::<u64>().ok()
75                        } else {
76                            None
77                        }
78                    })
79            })
80            .unwrap_or(3600); // Default to 1 hour if missing
81
82        let keys_json: HashMap<String, String> = response.json().await?;
83
84        let mut cache = self.cache.write().await;
85        *cache = Some(CachedKeys {
86            keys: keys_json,
87            expires_at: Instant::now() + Duration::from_secs(max_age),
88        });
89
90        Ok(())
91    }
92}