// firebase_admin_sdk/auth/keys.rs
use jsonwebtoken::DecodingKey;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::RwLock;

/// Google endpoint publishing the `securetoken` service account's X.509 signing
/// certificates as a JSON object mapping key id (`kid`) to PEM text.
const GOOGLE_PUBLIC_KEYS_URL: &str =
    "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com";

12#[derive(Error, Debug)]
13pub enum KeyFetchError {
14 #[error("Network error: {0}")]
15 NetworkError(#[from] reqwest::Error),
16 #[error("Failed to parse keys")]
17 ParseError,
18}
19
20#[derive(Clone)]
21struct CachedKeys {
22 keys: HashMap<String, DecodingKey>,
23 expires_at: Instant,
24}
25
26pub struct PublicKeyManager {
27 client: Client,
28 cache: Arc<RwLock<Option<CachedKeys>>>,
29}
30
31impl PublicKeyManager {
32 pub fn new() -> Self {
33 Self {
34 client: Client::new(),
35 cache: Arc::new(RwLock::new(None)),
36 }
37 }
38
39 pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, KeyFetchError> {
40 {
42 let cache = self.cache.read().await;
43 if let Some(cached) = &*cache {
44 if Instant::now() < cached.expires_at {
45 if let Some(key) = cached.keys.get(kid) {
46 return Ok(key.clone());
47 }
48 }
49 }
50 }
51
52 self.refresh_keys().await?;
54
55 let cache = self.cache.read().await;
57 if let Some(cached) = &*cache {
58 cached
59 .keys
60 .get(kid)
61 .cloned()
62 .ok_or(KeyFetchError::ParseError)
63 } else {
64 Err(KeyFetchError::ParseError)
65 }
66 }
67
68 async fn refresh_keys(&self) -> Result<(), KeyFetchError> {
69 let response = self.client.get(GOOGLE_PUBLIC_KEYS_URL).send().await?;
70
71 let max_age = response
73 .headers()
74 .get(reqwest::header::CACHE_CONTROL)
75 .and_then(|h| h.to_str().ok())
76 .and_then(|s| {
77 s.split(',').find_map(|part| {
78 let part = part.trim();
79 if part.starts_with("max-age=") {
80 part.trim_start_matches("max-age=").parse::<u64>().ok()
81 } else {
82 None
83 }
84 })
85 })
86 .unwrap_or(3600); let keys_json: HashMap<String, String> = response.json().await?;
89
90 let mut parsed_keys = HashMap::new();
91 for (kid, pem) in keys_json {
92 let decoding_key =
93 DecodingKey::from_rsa_pem(pem.as_bytes()).map_err(|_| KeyFetchError::ParseError)?;
94 parsed_keys.insert(kid, decoding_key);
95 }
96
97 let mut cache = self.cache.write().await;
98 *cache = Some(CachedKeys {
99 keys: parsed_keys,
100 expires_at: Instant::now() + Duration::from_secs(max_age),
101 });
102
103 Ok(())
104 }
105}