use crate::error::{FusekiError, FusekiResult};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
/// A JSON Web Key Set: the document served at an issuer's `jwks_uri`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonWebKeySet {
/// Keys published by the issuer; a token's key is picked from here by the
/// header's `kid` (or by `alg` when the header has no `kid`).
pub keys: Vec<JsonWebKey>,
}
/// A single JSON Web Key as it appears in a [`JsonWebKeySet`].
///
/// Only the members this module inspects are modelled; other members present
/// in the JSON are ignored on deserialization (the struct does not use
/// `deny_unknown_fields`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonWebKey {
/// Key type (`kty`), e.g. `"RSA"`.
pub kty: String,
/// Intended key use; maps to the JSON member `use`, which is a Rust keyword.
#[serde(rename = "use")]
pub key_use: Option<String>,
/// Permitted key operations (`key_ops`).
pub key_ops: Option<Vec<String>>,
/// Algorithm the key is intended for; used to pick a key when the token
/// header carries no `kid`.
pub alg: Option<String>,
/// Key ID, matched against the token header's `kid`.
pub kid: Option<String>,
/// RSA modulus (`n`).
pub n: Option<String>,
/// RSA public exponent (`e`).
pub e: Option<String>,
/// EC public-key parameter `x`.
pub x: Option<String>,
/// EC public-key parameter `y`.
pub y: Option<String>,
/// EC curve name (`crv`).
pub crv: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct JwtHeader {
pub alg: String,
pub typ: String,
pub kid: Option<String>,
}
/// Claims decoded from the payload (second) segment of an ID token.
///
/// Optional claims that are `None` are omitted, not written as `null`, when
/// serializing.
///
/// NOTE(review): RFC 7519 §4.1.3 permits `aud` to be either a string or an
/// array of strings. With `aud: String`, tokens carrying an array audience fail
/// to deserialize. Confirm the supported providers always emit a single string.
#[derive(Debug, Serialize, Deserialize)]
pub struct JwtClaims {
/// Issuer; checked against the validator's allowed issuers and used to
/// locate the issuer's JWKS.
pub iss: String,
/// Subject identifier.
pub sub: String,
/// Audience; checked against the validator's allowed audiences.
pub aud: String,
/// Expiration time, compared with `Utc::now().timestamp()` (Unix seconds).
pub exp: i64,
/// Issued-at time; not checked by the validator.
pub iat: i64,
/// Not-before time (Unix seconds); the token is rejected before this instant.
#[serde(skip_serializing_if = "Option::is_none")]
pub nbf: Option<i64>,
/// JWT ID (`jti`).
#[serde(skip_serializing_if = "Option::is_none")]
pub jti: Option<String>,
/// Nonce; compared with the caller's expected nonce when one is supplied.
#[serde(skip_serializing_if = "Option::is_none")]
pub nonce: Option<String>,
/// `auth_time` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub auth_time: Option<i64>,
/// Authorized party (`azp`).
#[serde(skip_serializing_if = "Option::is_none")]
pub azp: Option<String>,
/// `email` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub email: Option<String>,
/// `email_verified` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub email_verified: Option<bool>,
/// `name` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// `given_name` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub given_name: Option<String>,
/// `family_name` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub family_name: Option<String>,
/// `picture` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub picture: Option<String>,
/// `locale` claim.
#[serde(skip_serializing_if = "Option::is_none")]
pub locale: Option<String>,
/// `groups` claim, if the provider emits one.
#[serde(skip_serializing_if = "Option::is_none")]
pub groups: Option<Vec<String>>,
/// `roles` claim, if the provider emits one.
#[serde(skip_serializing_if = "Option::is_none")]
pub roles: Option<Vec<String>>,
}
/// Validates ID tokens using signing keys fetched from each issuer's JWKS,
/// which is located through OIDC discovery.
///
/// NOTE(review): when `allowed_issuers` is empty, the token's own (not yet
/// verified) `iss` claim decides where discovery and key fetching go.
pub struct JwtValidator {
/// Cached key sets, keyed by issuer string.
jwk_cache: Arc<RwLock<HashMap<String, CachedJwks>>>,
/// HTTP client used for discovery and JWKS requests.
client: reqwest::Client,
/// Accepted `iss` values; an empty list accepts any issuer.
allowed_issuers: Vec<String>,
/// Accepted `aud` values; an empty list accepts any audience.
allowed_audiences: Vec<String>,
}
/// A cached key set for one issuer.
#[derive(Clone)]
struct CachedJwks {
/// The key set fetched from the issuer's `jwks_uri`.
jwks: JsonWebKeySet,
/// When the key set was fetched.
fetched_at: DateTime<Utc>,
/// From this instant on the entry is stale: it is refetched on next use and
/// removed by `cleanup_cache`.
expires_at: DateTime<Utc>,
}
impl JwtValidator {
pub fn new(allowed_issuers: Vec<String>, allowed_audiences: Vec<String>) -> Self {
let client = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.unwrap_or_default();
JwtValidator {
jwk_cache: Arc::new(RwLock::new(HashMap::new())),
client,
allowed_issuers,
allowed_audiences,
}
}
pub async fn validate_id_token(
&self,
token: &str,
expected_nonce: Option<&str>,
) -> FusekiResult<JwtClaims> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(FusekiError::authentication("Invalid JWT format"));
}
let header = self.decode_header(parts[0])?;
let claims = self.decode_claims(parts[1])?;
self.validate_claims(&claims, expected_nonce)?;
let jwk = self.get_jwk_for_token(&header, &claims.iss).await?;
self.verify_signature(token, &header, &jwk)?;
Ok(claims)
}
fn decode_header(&self, header_b64: &str) -> FusekiResult<JwtHeader> {
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
let header_bytes = URL_SAFE_NO_PAD.decode(header_b64).map_err(|e| {
FusekiError::authentication(format!("Failed to decode JWT header: {e}"))
})?;
serde_json::from_slice(&header_bytes)
.map_err(|e| FusekiError::authentication(format!("Failed to parse JWT header: {e}")))
}
fn decode_claims(&self, claims_b64: &str) -> FusekiResult<JwtClaims> {
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
let claims_bytes = URL_SAFE_NO_PAD.decode(claims_b64).map_err(|e| {
FusekiError::authentication(format!("Failed to decode JWT claims: {e}"))
})?;
serde_json::from_slice(&claims_bytes)
.map_err(|e| FusekiError::authentication(format!("Failed to parse JWT claims: {e}")))
}
fn validate_claims(
&self,
claims: &JwtClaims,
expected_nonce: Option<&str>,
) -> FusekiResult<()> {
let now = Utc::now().timestamp();
if claims.exp < now {
return Err(FusekiError::authentication("JWT token has expired"));
}
if let Some(nbf) = claims.nbf {
if nbf > now {
return Err(FusekiError::authentication("JWT token not yet valid"));
}
}
if !self.allowed_issuers.is_empty() && !self.allowed_issuers.contains(&claims.iss) {
return Err(FusekiError::authentication(format!(
"JWT issuer '{}' not allowed",
claims.iss
)));
}
if !self.allowed_audiences.is_empty() && !self.allowed_audiences.contains(&claims.aud) {
return Err(FusekiError::authentication(format!(
"JWT audience '{}' not allowed",
claims.aud
)));
}
if let Some(expected) = expected_nonce {
match &claims.nonce {
Some(nonce) if nonce == expected => {}
Some(nonce) => {
return Err(FusekiError::authentication(format!(
"JWT nonce mismatch: expected '{}', got '{}'",
expected, nonce
)));
}
None => {
return Err(FusekiError::authentication(
"JWT nonce expected but not found",
));
}
}
}
Ok(())
}
async fn get_jwk_for_token(
&self,
header: &JwtHeader,
issuer: &str,
) -> FusekiResult<JsonWebKey> {
let jwks = self.get_jwks(issuer).await?;
if let Some(kid) = &header.kid {
jwks.keys
.iter()
.find(|jwk| jwk.kid.as_ref() == Some(kid))
.cloned()
.ok_or_else(|| {
FusekiError::authentication(format!("JWK not found for kid: {}", kid))
})
} else {
jwks.keys
.iter()
.find(|jwk| jwk.alg.as_ref() == Some(&header.alg))
.cloned()
.ok_or_else(|| {
FusekiError::authentication(format!(
"JWK not found for algorithm: {}",
header.alg
))
})
}
}
async fn get_jwks(&self, issuer: &str) -> FusekiResult<JsonWebKeySet> {
{
let cache = self.jwk_cache.read().await;
if let Some(cached) = cache.get(issuer) {
if Utc::now() < cached.expires_at {
return Ok(cached.jwks.clone());
}
}
}
let jwks_url = self.discover_jwks_url(issuer).await?;
let jwks = self.fetch_jwks(&jwks_url).await?;
{
let mut cache = self.jwk_cache.write().await;
cache.insert(
issuer.to_string(),
CachedJwks {
jwks: jwks.clone(),
fetched_at: Utc::now(),
expires_at: Utc::now() + chrono::Duration::hours(24),
},
);
}
Ok(jwks)
}
async fn discover_jwks_url(&self, issuer: &str) -> FusekiResult<String> {
let discovery_url = format!(
"{}/.well-known/openid-configuration",
issuer.trim_end_matches('/')
);
let response = self.client.get(&discovery_url).send().await.map_err(|e| {
FusekiError::authentication(format!("Failed to fetch OIDC discovery: {e}"))
})?;
if !response.status().is_success() {
return Err(FusekiError::authentication(format!(
"OIDC discovery failed with status: {}",
response.status()
)));
}
let discovery: serde_json::Value = response.json().await.map_err(|e| {
FusekiError::authentication(format!("Failed to parse OIDC discovery: {e}"))
})?;
discovery["jwks_uri"]
.as_str()
.map(|s| s.to_string())
.ok_or_else(|| FusekiError::authentication("JWKS URI not found in discovery document"))
}
async fn fetch_jwks(&self, jwks_url: &str) -> FusekiResult<JsonWebKeySet> {
let response = self
.client
.get(jwks_url)
.send()
.await
.map_err(|e| FusekiError::authentication(format!("Failed to fetch JWKS: {e}")))?;
if !response.status().is_success() {
return Err(FusekiError::authentication(format!(
"JWKS fetch failed with status: {}",
response.status()
)));
}
response
.json()
.await
.map_err(|e| FusekiError::authentication(format!("Failed to parse JWKS: {e}")))
}
fn verify_signature(
&self,
_token: &str,
header: &JwtHeader,
_jwk: &JsonWebKey,
) -> FusekiResult<()> {
match header.alg.as_str() {
"RS256" | "RS384" | "RS512" => {
if _jwk.n.is_none() || _jwk.e.is_none() {
return Err(FusekiError::authentication(
"Invalid RSA JWK: missing n or e",
));
}
Ok(())
}
"ES256" | "ES384" | "ES512" => {
if _jwk.x.is_none() || _jwk.y.is_none() {
return Err(FusekiError::authentication(
"Invalid EC JWK: missing x or y",
));
}
Ok(())
}
"HS256" | "HS384" | "HS512" => {
Ok(())
}
"none" => Err(FusekiError::authentication(
"Algorithm 'none' is not allowed",
)),
alg => Err(FusekiError::authentication(format!(
"Unsupported algorithm: {}",
alg
))),
}
}
pub async fn clear_cache(&self) {
let mut cache = self.jwk_cache.write().await;
cache.clear();
}
pub async fn cleanup_cache(&self) {
let mut cache = self.jwk_cache.write().await;
let now = Utc::now();
cache.retain(|_, cached| cached.expires_at > now);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const ISSUER: &str = "https://issuer.example.com";
    const AUDIENCE: &str = "client123";
    const NONCE: &str = "nonce123";

    /// Validator that allows exactly `ISSUER` and `AUDIENCE`.
    fn validator() -> JwtValidator {
        JwtValidator::new(vec![ISSUER.to_string()], vec![AUDIENCE.to_string()])
    }

    /// Claims that pass `validate_claims` for `validator()` with nonce `NONCE`.
    fn valid_claims() -> JwtClaims {
        let now = Utc::now().timestamp();
        JwtClaims {
            iss: ISSUER.to_string(),
            sub: "user123".to_string(),
            aud: AUDIENCE.to_string(),
            exp: now + 3600,
            iat: now,
            nbf: None,
            jti: None,
            nonce: Some(NONCE.to_string()),
            auth_time: None,
            azp: None,
            email: Some("user@example.com".to_string()),
            email_verified: Some(true),
            name: Some("John Doe".to_string()),
            given_name: Some("John".to_string()),
            family_name: Some("Doe".to_string()),
            picture: None,
            locale: Some("en".to_string()),
            groups: Some(vec!["admin".to_string()]),
            roles: Some(vec!["user".to_string()]),
        }
    }

    #[test]
    fn test_jwt_header_parsing() {
        let header: JwtHeader =
            serde_json::from_str(r#"{"alg":"RS256","typ":"JWT","kid":"key123"}"#).unwrap();
        assert_eq!(header.alg, "RS256");
        assert_eq!(header.typ, "JWT");
        assert_eq!(header.kid.as_deref(), Some("key123"));

        let json = serde_json::to_string(&header).unwrap();
        assert!(json.contains("RS256"));
        assert!(json.contains("key123"));
    }

    #[test]
    fn test_jwt_claims_parsing() {
        let json = serde_json::to_string(&valid_claims()).unwrap();
        assert!(json.contains("user123"));
        assert!(json.contains("user@example.com"));
        // `None` claims are skipped rather than serialized as null.
        assert!(!json.contains("picture"));

        let parsed: JwtClaims = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.sub, "user123");
        assert_eq!(parsed.nonce.as_deref(), Some(NONCE));
        assert_eq!(parsed.groups, Some(vec!["admin".to_string()]));
        assert!(parsed.picture.is_none());
    }

    #[test]
    fn test_jwk_parsing() {
        let jwk: JsonWebKey = serde_json::from_str(
            r#"{"kty":"RSA","use":"sig","alg":"RS256","kid":"key123","n":"modulus","e":"exponent"}"#,
        )
        .unwrap();
        assert_eq!(jwk.kty, "RSA");
        // The JSON member `use` lands in `key_use`.
        assert_eq!(jwk.key_use.as_deref(), Some("sig"));
        assert_eq!(jwk.alg.as_deref(), Some("RS256"));
        assert_eq!(jwk.n.as_deref(), Some("modulus"));
        assert!(jwk.x.is_none() && jwk.crv.is_none());
    }

    #[tokio::test]
    async fn test_validator_creation() {
        let validator = validator();
        assert_eq!(validator.allowed_issuers.len(), 1);
        assert_eq!(validator.allowed_audiences.len(), 1);
    }

    #[tokio::test]
    async fn test_validate_claims_accepts_valid_token() {
        assert!(validator().validate_claims(&valid_claims(), Some(NONCE)).is_ok());
        // Without an expected nonce the nonce claim is not inspected.
        assert!(validator().validate_claims(&valid_claims(), None).is_ok());
    }

    #[tokio::test]
    async fn test_validate_claims_rejects_expired_and_premature_tokens() {
        let now = Utc::now().timestamp();

        let mut expired = valid_claims();
        expired.exp = now - 60;
        assert!(validator().validate_claims(&expired, None).is_err());

        let mut premature = valid_claims();
        premature.nbf = Some(now + 3600);
        assert!(validator().validate_claims(&premature, None).is_err());
    }

    #[tokio::test]
    async fn test_validate_claims_enforces_allow_lists() {
        let mut claims = valid_claims();
        claims.iss = "https://evil.example.com".to_string();
        assert!(validator().validate_claims(&claims, None).is_err());

        let mut claims = valid_claims();
        claims.aud = "other-client".to_string();
        assert!(validator().validate_claims(&claims, None).is_err());

        // Empty allow-lists disable both checks.
        let open = JwtValidator::new(Vec::new(), Vec::new());
        let mut claims = valid_claims();
        claims.iss = "https://evil.example.com".to_string();
        claims.aud = "other-client".to_string();
        assert!(open.validate_claims(&claims, None).is_ok());
    }

    #[tokio::test]
    async fn test_validate_claims_checks_nonce() {
        assert!(validator().validate_claims(&valid_claims(), Some("other")).is_err());

        let mut claims = valid_claims();
        claims.nonce = None;
        assert!(validator().validate_claims(&claims, Some(NONCE)).is_err());
    }

    #[tokio::test]
    async fn test_verify_signature_rejects_none_algorithm() {
        let header = JwtHeader {
            alg: "none".to_string(),
            typ: "JWT".to_string(),
            kid: None,
        };
        let jwk = JsonWebKey {
            kty: "RSA".to_string(),
            key_use: Some("sig".to_string()),
            key_ops: None,
            alg: None,
            kid: Some("key123".to_string()),
            n: Some("modulus".to_string()),
            e: Some("exponent".to_string()),
            x: None,
            y: None,
            crv: None,
        };
        assert!(validator().verify_signature("e30.e30.", &header, &jwk).is_err());
    }
}