use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use jsonwebtoken::DecodingKey;
use serde::Deserialize;
use tokio::sync::RwLock;
use crate::config::SecurityConfig;
use crate::error::SecurityError;
/// One key entry of a JWKS document, as deserialized from the endpoint's JSON.
///
/// Only the members this module uses are declared; serde ignores any other
/// members present in the JSON.
#[derive(Debug, Clone, Deserialize)]
struct Jwk {
    /// Key ID. Entries without one are skipped when the cache is populated.
    kid: Option<String>,
    /// Key type (e.g. `"RSA"`).
    kty: String,
    /// Declared algorithm for this key.
    // Deserialized but never read by this module (derived Debug/Clone do not
    // count as reads), so the lint is silenced on this field only rather than
    // on the whole struct. Kept for future algorithm pinning.
    #[serde(default)]
    #[allow(dead_code)]
    alg: Option<String>,
    /// RSA modulus component, as supplied by the JWKS endpoint.
    #[serde(default)]
    n: Option<String>,
    /// RSA public exponent component, as supplied by the JWKS endpoint.
    #[serde(default)]
    e: Option<String>,
}
/// Top-level JWKS document: a JSON object whose `keys` member is an array of
/// [`Jwk`] entries. A document without `keys` fails to deserialize.
#[derive(Debug, Deserialize)]
struct JwksResponse {
    /// Every key published by the endpoint, in document order.
    keys: Vec<Jwk>,
}
/// The subset of a [`Jwk`] retained in the cache: just enough to rebuild a
/// verification key on demand.
#[derive(Debug, Clone)]
struct CachedJwk {
    /// Key type copied from the JWK (only `"RSA"` is usable).
    kty: String,
    /// RSA modulus component, if the JWK carried one.
    n: Option<String>,
    /// RSA public exponent component, if the JWK carried one.
    e: Option<String>,
}
impl CachedJwk {
    /// Builds a `jsonwebtoken` [`DecodingKey`] from the cached components.
    ///
    /// Only RSA keys are supported. A new key is constructed on every call;
    /// nothing is memoized on `self`.
    ///
    /// # Errors
    ///
    /// Returns `SecurityError::ValidationFailed` when:
    /// - `kty` is anything other than `"RSA"`;
    /// - the modulus `n` is missing (checked before the exponent);
    /// - the exponent `e` is missing;
    /// - `DecodingKey::from_rsa_components` rejects the components.
    fn to_decoding_key(&self) -> Result<DecodingKey, SecurityError> {
        if self.kty != "RSA" {
            return Err(SecurityError::ValidationFailed(format!(
                "Unsupported key type: {}",
                self.kty
            )));
        }

        // Arm order mirrors the required check order: a missing modulus is
        // reported even when the exponent is also absent.
        let (modulus, exponent) = match (self.n.as_deref(), self.e.as_deref()) {
            (Some(modulus), Some(exponent)) => (modulus, exponent),
            (None, _) => {
                return Err(SecurityError::ValidationFailed(
                    "RSA key missing 'n' component".into(),
                ));
            }
            (Some(_), None) => {
                return Err(SecurityError::ValidationFailed(
                    "RSA key missing 'e' component".into(),
                ));
            }
        };

        DecodingKey::from_rsa_components(modulus, exponent).map_err(|cause| {
            SecurityError::ValidationFailed(format!(
                "Failed to construct RSA decoding key: {cause}"
            ))
        })
    }
}
/// Mutable state guarded by the cache's lock.
struct CacheInner {
    /// Cached keys indexed by their `kid`.
    keys: HashMap<String, CachedJwk>,
    /// When `keys` was last successfully replaced; `None` until the first
    /// successful fetch.
    last_refresh: Option<Instant>,
}
/// Cache of JSON Web Keys fetched from `SecurityConfig::jwks_url`, looked up
/// by key ID (`kid`).
///
/// The key set is loaded when the cache is constructed and may be re-fetched
/// when a lookup misses.
pub struct JwksCache {
    /// Key map plus refresh timestamp, behind a `tokio` read/write lock.
    inner: Arc<RwLock<CacheInner>>,
    /// Configuration supplying the JWKS endpoint URL.
    config: SecurityConfig,
    /// HTTP client reused for every JWKS fetch.
    client: reqwest::Client,
}
impl JwksCache {
pub async fn new(config: SecurityConfig) -> Result<Self, SecurityError> {
let client = reqwest::Client::new();
let cache = Self {
inner: Arc::new(RwLock::new(CacheInner {
keys: HashMap::new(),
last_refresh: None,
})),
config,
client,
};
cache.refresh().await?;
Ok(cache)
}
pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, SecurityError> {
{
let cache = self.inner.read().await;
if let Some(jwk) = cache.keys.get(kid) {
return jwk.to_decoding_key();
}
}
self.refresh().await?;
let cache = self.inner.read().await;
cache
.keys
.get(kid)
.ok_or_else(|| SecurityError::UnknownKeyId(kid.to_string()))?
.to_decoding_key()
}
async fn refresh(&self) -> Result<(), SecurityError> {
let response = self
.client
.get(&self.config.jwks_url)
.send()
.await
.map_err(|e| SecurityError::JwksFetchError(e.to_string()))?;
let jwks: JwksResponse = response
.json()
.await
.map_err(|e| SecurityError::JwksFetchError(format!("Failed to parse JWKS: {e}")))?;
let mut keys = HashMap::new();
for jwk in jwks.keys {
if let Some(kid) = &jwk.kid {
let cached = CachedJwk {
kty: jwk.kty.clone(),
n: jwk.n.clone(),
e: jwk.e.clone(),
};
keys.insert(kid.clone(), cached);
}
}
let mut cache = self.inner.write().await;
cache.keys = keys;
cache.last_refresh = Some(Instant::now());
Ok(())
}
}