use std::sync::{Arc, RwLock};
use std::time::Instant;
use jsonwebtoken::{
DecodingKey, TokenData, Validation, decode, decode_header,
jwk::{Jwk, JwkSet},
};
use typesec_core::typestate::{AgentError, Authenticator, Credentials};
use crate::http::{HttpClient, ReqwestHttpClient};
use super::claims::{JwtClaims, VerifiedSubject};
use super::config::OidcConfig;
/// Verifies JWTs against an OIDC issuer's JSON Web Key Set (JWKS).
///
/// Signing keys are downloaded through `http` and cached in `jwks`, where they
/// are reused for up to `config.jwks_ttl`.
pub struct JwtAuthenticator {
/// Issuer, audience, accepted algorithms, JWKS URL and cache TTL used during verification.
config: OidcConfig,
/// HTTP client used to download the JWKS document.
http: Arc<dyn HttpClient>,
/// Last successfully fetched key set plus its fetch time; `None` until the first fetch.
jwks: RwLock<Option<CachedJwks>>,
}
/// A downloaded JWKS snapshot paired with the moment it was fetched, so the
/// authenticator can tell whether the cache is still fresh.
#[derive(Clone)]
struct CachedJwks {
    /// When `keys` was fetched; freshness is judged with `Instant::elapsed`.
    fetched_at: Instant,
    /// The parsed key set returned by the JWKS endpoint.
    keys: JwkSet,
}
impl JwtAuthenticator {
    /// Minimum age the cached JWKS must reach before an unknown `kid` may force a re-fetch.
    ///
    /// Without this bound, every token carrying a fabricated `kid` would trigger an
    /// outbound JWKS request, letting unauthenticated callers amplify load onto this
    /// service and the identity provider. Genuine key rotations are still picked up
    /// within this interval.
    const MIN_FORCED_REFRESH_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10);

    /// Creates an authenticator that downloads the JWKS with a default [`ReqwestHttpClient`].
    pub fn new(config: OidcConfig) -> Self {
        Self::with_http(config, Arc::new(ReqwestHttpClient::new()))
    }

    /// Creates an authenticator that downloads the JWKS through `http`.
    ///
    /// The key cache starts empty; the first verification populates it.
    pub fn with_http(config: OidcConfig, http: Arc<dyn HttpClient>) -> Self {
        Self {
            config,
            http,
            jwks: RwLock::new(None),
        }
    }

    /// Verifies `token` and returns the [`VerifiedSubject`] built from its claims.
    ///
    /// The signature is checked with a key from the JWKS. The algorithm, issuer and
    /// audience are validated against the configured [`OidcConfig`].
    ///
    /// # Errors
    ///
    /// - [`JwtAuthError::Jwt`] if header decoding, key conversion or claim/signature
    ///   validation fails inside `jsonwebtoken`.
    /// - [`JwtAuthError::Http`] / [`JwtAuthError::Json`] if the JWKS cannot be fetched or parsed.
    /// - [`JwtAuthError::MissingKey`] / [`JwtAuthError::MissingKid`] if no signing key can be selected.
    /// - [`JwtAuthError::InvalidAudience`] if the decoded `aud` claim lacks the configured audience.
    pub fn verify(&self, token: &str) -> Result<VerifiedSubject, JwtAuthError> {
        let data = self.decode_claims(token)?;
        // Defence in depth: `decode_claims` already restricts the audience via
        // `Validation::set_audience`; re-check it on the decoded claims as well.
        if !data.claims.aud.contains(&self.config.audience) {
            return Err(JwtAuthError::InvalidAudience);
        }
        Ok(data.claims.into())
    }

    /// Decodes `token` and validates its signature, algorithm, issuer and audience.
    fn decode_claims(&self, token: &str) -> Result<TokenData<JwtClaims>, JwtAuthError> {
        let header = decode_header(token)?;
        let key = self.resolve_key(header.kid.as_deref())?;
        let mut validation = Validation::new(header.alg);
        // Accept only the configured algorithms, whatever the (untrusted) header says.
        validation.algorithms = self.config.algorithms.clone();
        validation.set_issuer(&[self.config.issuer.as_str()]);
        validation.set_audience(&[self.config.audience.as_str()]);
        Ok(decode::<JwtClaims>(
            token,
            &DecodingKey::from_jwk(&key)?,
            &validation,
        )?)
    }

    /// Selects the JWKS key that should verify a token with the given `kid`.
    ///
    /// With a `kid`, the cached set is searched first. On a miss, the set is
    /// re-fetched in case the issuer rotated keys; the re-fetch is throttled by
    /// [`Self::MIN_FORCED_REFRESH_INTERVAL`]. Without a `kid`, the set must hold
    /// exactly one key.
    fn resolve_key(&self, kid: Option<&str>) -> Result<Jwk, JwtAuthError> {
        let jwks = self.jwks(false)?;
        match kid {
            Some(kid) => {
                if let Some(key) = jwks.find(kid) {
                    return Ok(key.clone());
                }
                let jwks = self.jwks(true)?;
                jwks.find(kid).cloned().ok_or(JwtAuthError::MissingKey)
            }
            None => match jwks.keys.as_slice() {
                [only] => Ok(only.clone()),
                [] => Err(JwtAuthError::MissingKey),
                _ => Err(JwtAuthError::MissingKid),
            },
        }
    }

    /// Returns the JWKS, from the cache when fresh enough and otherwise from `jwks_url`.
    ///
    /// A normal lookup reuses the cache for `jwks_ttl`. A forced refresh (unknown `kid`)
    /// still reuses a cache younger than `MIN_FORCED_REFRESH_INTERVAL`, capped at
    /// `jwks_ttl`. This bounds how often untrusted tokens can trigger a network fetch.
    fn jwks(&self, force_refresh: bool) -> Result<JwkSet, JwtAuthError> {
        let ttl = self.config.jwks_ttl;
        let max_age = if force_refresh {
            ttl.min(Self::MIN_FORCED_REFRESH_INTERVAL)
        } else {
            ttl
        };
        if let Some(keys) = self.cached_jwks(max_age) {
            return Ok(keys);
        }
        self.fetch_jwks()
    }

    /// Returns a copy of the cached JWKS if it was fetched less than `max_age` ago.
    fn cached_jwks(&self, max_age: std::time::Duration) -> Option<JwkSet> {
        let guard = self.jwks.read().expect("jwks lock poisoned");
        guard
            .as_ref()
            .filter(|cached| cached.fetched_at.elapsed() < max_age)
            .map(|cached| cached.keys.clone())
    }

    /// Downloads and parses the JWKS from `jwks_url`, then replaces the cache with it.
    ///
    /// On failure the previous cache entry is left untouched.
    fn fetch_jwks(&self) -> Result<JwkSet, JwtAuthError> {
        let value = self.http.get_json(&self.config.jwks_url, &[])?;
        let keys: JwkSet = serde_json::from_value(value)?;
        *self.jwks.write().expect("jwks lock poisoned") = Some(CachedJwks {
            keys: keys.clone(),
            fetched_at: Instant::now(),
        });
        Ok(keys)
    }
}
impl Authenticator for JwtAuthenticator {
    /// Verifies the bearer token in `credentials` and returns the token's subject.
    ///
    /// A non-empty `credentials.subject` is treated as a claim that must equal the
    /// subject in the verified token. An empty one is accepted, and the token's
    /// subject is returned.
    fn verify_credentials(&self, credentials: &Credentials) -> Result<String, AgentError> {
        let token = credentials.token.expose();
        let subject = match self.verify(token) {
            Ok(verified) => verified.subject,
            Err(e) => {
                return Err(AgentError::AuthFailed {
                    reason: format!("jwt verification failed: {e}"),
                });
            }
        };

        let claimed = &credentials.subject;
        if claimed.is_empty() || *claimed == subject {
            return Ok(subject);
        }

        Err(AgentError::AuthFailed {
            reason: format!(
                "claimed subject '{}' does not match verified token subject '{}'",
                claimed, subject
            ),
        })
    }
}
#[derive(Debug, thiserror::Error)]
pub enum JwtAuthError {
#[error("jwt validation failed: {0}")]
Jwt(#[from] jsonwebtoken::errors::Error),
#[error("jwks fetch failed: {0}")]
Http(#[from] Box<dyn std::error::Error + Send + Sync>),
#[error("jwks parse failed: {0}")]
Json(#[from] serde_json::Error),
#[error("no matching signing key found in JWKS")]
MissingKey,
#[error("token has no kid but JWKS is ambiguous (multiple keys)")]
MissingKid,
#[error("token audience did not match expected audience")]
InvalidAudience,
}