use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use jsonwebtoken::DecodingKey;
use tokio::sync::RwLock;
use crate::error::IdentityError;
/// One verification key taken from the JWKS, paired with the algorithm it advertised.
#[derive(Clone)]
struct CachedKey {
    /// Decoding key built from the JWK's public key components.
    decoding_key: DecodingKey,
    /// The JWK's `alg` member; an empty string when the JWK did not specify one.
    alg: String,
}
/// Contents of the key cache, guarded by the store's `RwLock`.
struct CacheState {
    /// When the last successful fetch completed; `None` until the first one succeeds.
    fetched_at: Option<Instant>,
    /// Cached keys indexed by their JWK `kid`.
    keys: HashMap<String, CachedKey>,
}
/// Caching client for a JSON Web Key Set endpoint.
///
/// Keys are fetched on demand and reused until `ttl` has elapsed since the last
/// successful fetch.
pub struct JwksKeyStore {
    /// Shared, lock-protected cache of parsed keys.
    cache: Arc<RwLock<CacheState>>,
    /// How long a successful fetch is treated as fresh.
    ttl: Duration,
    /// Address of the JWKS document.
    url: String,
}
impl JwksKeyStore {
#[must_use]
pub fn new(url: &str, ttl: Duration) -> Self {
Self {
url: url.to_owned(),
ttl,
cache: Arc::new(RwLock::new(CacheState {
keys: HashMap::new(),
fetched_at: None,
})),
}
}
pub async fn fetch(&self) -> Result<(), IdentityError> {
let jwks = fetch_jwks_http(&self.url).await?;
let keys = parse_jwks(&jwks)?;
let mut cache = self.cache.write().await;
cache.keys = keys;
cache.fetched_at = Some(Instant::now());
Ok(())
}
pub async fn get_key(&self, kid: &str) -> Option<DecodingKey> {
{
let cache = self.cache.read().await;
if let Some(fetched_at) = cache.fetched_at {
if fetched_at.elapsed() < self.ttl {
return cache.keys.get(kid).map(|k| k.decoding_key.clone());
}
}
}
if let Err(e) = self.fetch().await {
tracing::warn!("JWKS refresh failed: {e}, using stale cache if available");
let cache = self.cache.read().await;
return cache.keys.get(kid).map(|k| k.decoding_key.clone());
}
let cache = self.cache.read().await;
cache.keys.get(kid).map(|k| k.decoding_key.clone())
}
pub async fn get_algorithm(&self, kid: &str) -> Option<String> {
let cache = self.cache.read().await;
cache.keys.get(kid).map(|k| k.alg.clone())
}
pub async fn is_cache_valid(&self) -> bool {
let cache = self.cache.read().await;
if let Some(fetched_at) = cache.fetched_at {
fetched_at.elapsed() < self.ttl && !cache.keys.is_empty()
} else {
false
}
}
}
/// Top-level JWKS document: the JSON object whose `keys` array holds the individual JWKs.
#[derive(serde::Deserialize)]
struct JwksDocument {
/// Every JWK entry listed by the provider, in document order.
keys: Vec<JwkKey>,
}
/// A single JSON Web Key entry from the JWKS `keys` array.
///
/// Only `kty` is required; absent optional members deserialize to `None`. Which of the
/// key-material members are used depends on `kty`: `n`/`e` for `RSA`, `x`/`y` for `EC`.
#[derive(serde::Deserialize)]
struct JwkKey {
/// Key ID; entries without one cannot be looked up and are not cached.
#[serde(default)]
kid: Option<String>,
/// Key type (e.g. `RSA`, `EC`); the only member that must be present.
kty: String,
/// Advertised signing algorithm, if any.
#[serde(default)]
alg: Option<String>,
/// RSA modulus component (`n`).
#[serde(default)]
n: Option<String>,
/// RSA public exponent component (`e`).
#[serde(default)]
e: Option<String>,
/// EC public point x coordinate.
#[serde(default)]
x: Option<String>,
/// EC public point y coordinate.
#[serde(default)]
y: Option<String>,
/// EC curve name. It is deserialized but not read by the code in this module,
/// hence the `dead_code` allowance.
#[serde(default)]
#[allow(dead_code)]
crv: Option<String>,
}
/// Fetches the raw JWKS document body from `url` with a minimal HTTP/1.0 GET over TCP.
///
/// Only the response body is returned, and only when the status line reports `200`.
///
/// NOTE(review): the connection is plain TCP. An `https` URL selects port 443, but no
/// TLS handshake is performed, so real HTTPS endpoints will not work. Plain-HTTP JWKS
/// retrieval is also open to key substitution in transit. There is no connect/read
/// timeout either, so a stalled provider blocks callers of `get_key`. Consider wrapping
/// the I/O in `tokio::time::timeout` if the crate enables tokio's `time` feature.
///
/// # Errors
///
/// Returns [`IdentityError::ProviderUnavailable`] in any of these cases:
/// - the URL cannot be parsed or has no host;
/// - connecting, writing, or reading fails;
/// - the response exceeds the size cap or has no header terminator;
/// - the status is not `200`.
async fn fetch_jwks_http(url: &str) -> Result<String, IdentityError> {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // Upper bound on bytes read from the provider (headers plus body), so a misbehaving
    // endpoint cannot make us buffer an unbounded response.
    const MAX_RESPONSE_BYTES: usize = 1024 * 1024;

    let parsed = url::Url::parse(url).map_err(|_| IdentityError::ProviderUnavailable)?;
    let host = parsed
        .host_str()
        .ok_or(IdentityError::ProviderUnavailable)?;
    let port = parsed.port().unwrap_or(match parsed.scheme() {
        "https" => 443,
        _ => 80,
    });

    // The request target must carry the query string too; `path()` alone drops it.
    let target = match parsed.query() {
        Some(query) => format!("{}?{query}", parsed.path()),
        None => parsed.path().to_owned(),
    };
    // `Url::port()` is `None` when the port is the scheme default, so an explicit
    // non-default port is echoed in the Host header as HTTP requires.
    let host_header = match parsed.port() {
        Some(explicit) => format!("{host}:{explicit}"),
        None => host.to_owned(),
    };

    let addr = format!("{host}:{port}");
    let mut stream = tokio::net::TcpStream::connect(&addr)
        .await
        .map_err(|_| IdentityError::ProviderUnavailable)?;

    // HTTP/1.0 guarantees the server will not answer with `Transfer-Encoding: chunked`,
    // which this reader does not decode. The body then simply runs until EOF.
    let request =
        format!("GET {target} HTTP/1.0\r\nHost: {host_header}\r\nConnection: close\r\n\r\n");
    stream
        .write_all(request.as_bytes())
        .await
        .map_err(|_| IdentityError::ProviderUnavailable)?;

    let mut response = Vec::new();
    // Read one byte past the cap so an oversized response is detected rather than
    // silently truncated into something that might still parse.
    stream
        .take(MAX_RESPONSE_BYTES as u64 + 1)
        .read_to_end(&mut response)
        .await
        .map_err(|_| IdentityError::ProviderUnavailable)?;
    if response.len() > MAX_RESPONSE_BYTES {
        return Err(IdentityError::ProviderUnavailable);
    }

    let header_end = response
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .ok_or(IdentityError::ProviderUnavailable)?;

    // Reject error pages and redirects instead of handing their bodies to the JWKS parser.
    let head = String::from_utf8_lossy(&response[..header_end]);
    let status = head
        .lines()
        .next()
        .and_then(|status_line| status_line.split_whitespace().nth(1));
    if status != Some("200") {
        return Err(IdentityError::ProviderUnavailable);
    }

    Ok(String::from_utf8_lossy(&response[header_end + 4..]).into_owned())
}
fn parse_jwks(json: &str) -> Result<HashMap<String, CachedKey>, IdentityError> {
let doc: JwksDocument =
serde_json::from_str(json).map_err(|_| IdentityError::TokenMalformed)?;
let mut keys = HashMap::new();
for jwk in &doc.keys {
let kid = match &jwk.kid {
Some(k) => k.clone(),
None => continue, };
let alg = jwk.alg.clone().unwrap_or_default();
let decoding_key = match jwk.kty.as_str() {
"RSA" => {
let n = jwk.n.as_deref().ok_or(IdentityError::TokenMalformed)?;
let e = jwk.e.as_deref().ok_or(IdentityError::TokenMalformed)?;
DecodingKey::from_rsa_components(n, e).map_err(|_| IdentityError::TokenMalformed)?
}
"EC" => {
let x = jwk.x.as_deref().ok_or(IdentityError::TokenMalformed)?;
let y = jwk.y.as_deref().ok_or(IdentityError::TokenMalformed)?;
DecodingKey::from_ec_components(x, y).map_err(|_| IdentityError::TokenMalformed)?
}
_ => continue, };
keys.insert(kid, CachedKey { alg, decoding_key });
}
Ok(keys)
}