use jsonwebtoken::{Algorithm, DecodingKey, Validation};
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::jwt::{AuthError, JwtConfig, JwtManager};
/// One JSON Web Key as it appears in the `keys` array of a JWKS document.
///
/// Only the members this module reads are modelled. Other JSON members are ignored during
/// deserialization (no `deny_unknown_fields`). Every member except `kty` is optional.
#[derive(Debug, Clone, Deserialize)]
pub struct Jwk {
    /// Key type. `parse_jwk` builds decoding keys for `"RSA"` and `"EC"` only.
    pub kty: String,
    /// Key ID. Keys with a `kid` are looked up by the token header's `kid`. A key without
    /// one can become the provider's default key.
    #[serde(default)]
    pub kid: Option<String>,
    /// Declared signing algorithm (e.g. `"RS256"`). When absent, `parse_jwk` infers one
    /// from `kty`/`crv`.
    #[serde(default)]
    pub alg: Option<String>,
    /// The JWK `use` member (renamed because `use` is a Rust keyword). Keys marked
    /// `"enc"` are skipped when a provider is built.
    #[serde(rename = "use", default)]
    pub key_use: Option<String>,
    /// RSA modulus. Required for RSA keys and passed to `DecodingKey::from_rsa_components`.
    #[serde(default)]
    pub n: Option<String>,
    /// RSA public exponent. Required for RSA keys.
    #[serde(default)]
    pub e: Option<String>,
    /// EC curve name. `parse_jwk` uses it to choose between ES256 and ES384.
    #[serde(default)]
    pub crv: Option<String>,
    /// EC public-point x coordinate. Required for EC keys and passed to
    /// `DecodingKey::from_ec_components`.
    #[serde(default)]
    pub x: Option<String>,
    /// EC public-point y coordinate. Required for EC keys.
    #[serde(default)]
    pub y: Option<String>,
}
/// A JWKS document: the top-level `{"keys": [...]}` object.
#[derive(Debug, Clone, Deserialize)]
pub struct JwkSet {
    /// All keys in the document, including ones this module later skips.
    pub keys: Vec<Jwk>,
}
/// Output of `parse_jwk`: a verification key together with the algorithm it is used with.
struct ParsedKey {
    /// Public key material ready for `jsonwebtoken::decode`.
    decoding_key: DecodingKey,
    /// Algorithm that tokens verified with `decoding_key` must use.
    algorithm: Algorithm,
}
/// Maps a JWA `alg` name to one of the asymmetric algorithms this module supports.
///
/// Recognised names are `RS256`/`RS384`/`RS512`, `PS256`/`PS384`/`PS512`, `ES256` and
/// `ES384`. Anything else, including the symmetric `HS*` family, yields `None`, so a
/// JWKS entry can never select a shared-secret algorithm.
fn algorithm_from_str(alg: &str) -> Option<Algorithm> {
    const SUPPORTED: [(&str, Algorithm); 8] = [
        ("RS256", Algorithm::RS256),
        ("RS384", Algorithm::RS384),
        ("RS512", Algorithm::RS512),
        ("ES256", Algorithm::ES256),
        ("ES384", Algorithm::ES384),
        ("PS256", Algorithm::PS256),
        ("PS384", Algorithm::PS384),
        ("PS512", Algorithm::PS512),
    ];

    SUPPORTED
        .iter()
        .find(|(name, _)| *name == alg)
        .map(|&(_, algorithm)| algorithm)
}
fn parse_jwk(jwk: &Jwk) -> Result<ParsedKey, AuthError> {
#[allow(clippy::unnecessary_lazy_evaluations)]
let algorithm = jwk
.alg
.as_deref()
.and_then(algorithm_from_str)
.or_else(|| {
match jwk.kty.as_str() {
"RSA" => Some(Algorithm::RS256),
"EC" => match jwk.crv.as_deref() {
Some("P-384") => Some(Algorithm::ES384),
_ => Some(Algorithm::ES256), },
_ => None,
}
})
.ok_or_else(|| AuthError::Internal("unsupported key type/algorithm".into()))?;
let decoding_key = match jwk.kty.as_str() {
"RSA" => {
let n = jwk
.n
.as_deref()
.ok_or_else(|| AuthError::Internal("RSA JWK missing 'n' parameter".into()))?;
let e = jwk
.e
.as_deref()
.ok_or_else(|| AuthError::Internal("RSA JWK missing 'e' parameter".into()))?;
DecodingKey::from_rsa_components(n, e)
.map_err(|err| AuthError::Internal(format!("RSA JWK parse error: {err}")))?
}
"EC" => {
let x = jwk
.x
.as_deref()
.ok_or_else(|| AuthError::Internal("EC JWK missing 'x' parameter".into()))?;
let y = jwk
.y
.as_deref()
.ok_or_else(|| AuthError::Internal("EC JWK missing 'y' parameter".into()))?;
DecodingKey::from_ec_components(x, y)
.map_err(|err| AuthError::Internal(format!("EC JWK parse error: {err}")))?
}
other => {
return Err(AuthError::Internal(format!(
"unsupported JWK key type: {other}"
)));
}
};
Ok(ParsedKey {
decoding_key,
algorithm,
})
}
/// Store of JWT verification keys loaded from a JWKS document.
///
/// The keys live behind `Arc<RwLock<_>>`, so cloning a provider is cheap and every clone
/// shares the same key set. A `refresh` through any clone is visible to all of them.
#[derive(Clone)]
pub struct JwksProvider {
    inner: Arc<RwLock<JwksInner>>,
}
/// Lock-protected key state of a [`JwksProvider`].
struct JwksInner {
    /// Usable keys that carried a `kid`, indexed by that `kid`.
    keys: HashMap<String, (DecodingKey, Algorithm)>,
    /// The first usable key without a `kid`. It is used for tokens whose header has no `kid`.
    default_key: Option<(DecodingKey, Algorithm)>,
}
impl JwksProvider {
pub fn from_json(jwks_json: &str) -> Result<Self, AuthError> {
let jwk_set: JwkSet = serde_json::from_str(jwks_json)
.map_err(|e| AuthError::Internal(format!("JWKS parse error: {e}")))?;
let inner = Self::build_inner(&jwk_set)?;
Ok(Self {
inner: Arc::new(RwLock::new(inner)),
})
}
pub fn refresh(&self, jwks_json: &str) -> Result<(), AuthError> {
let jwk_set: JwkSet = serde_json::from_str(jwks_json)
.map_err(|e| AuthError::Internal(format!("JWKS parse error: {e}")))?;
let new_inner = Self::build_inner(&jwk_set)?;
let mut guard = self
.inner
.write()
.map_err(|_| AuthError::Internal("JWKS lock poisoned".into()))?;
*guard = new_inner;
Ok(())
}
pub fn manager_for_kid(&self, kid: &str) -> Option<JwtManager> {
let guard = self.inner.read().ok()?;
let (dk, alg) = guard.keys.get(kid)?;
let mut validation = Validation::new(*alg);
validation.validate_exp = true;
validation.leeway = 60;
Some(JwtManager::with_config(JwtConfig {
decoding_key: dk.clone(),
encoding_key: None,
validation,
}))
}
pub fn decode<T>(&self, token: &str) -> Result<T, AuthError>
where
T: for<'de> serde::de::Deserialize<'de> + crate::jwt::HasJti,
{
let kid = extract_kid_from_header(token);
let guard = self
.inner
.read()
.map_err(|_| AuthError::Internal("JWKS lock poisoned".into()))?;
let (dk, alg) = if let Some(kid) = &kid {
guard
.keys
.get(kid.as_str())
.ok_or_else(|| AuthError::InvalidToken(format!("unknown kid: {kid}")))?
} else {
guard.default_key.as_ref().ok_or(AuthError::InvalidToken(
"no kid in token and no default key".into(),
))?
};
let mut validation = Validation::new(*alg);
validation.validate_exp = true;
validation.leeway = 60;
let token_data =
jsonwebtoken::decode::<T>(token, dk, &validation).map_err(|e| match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => AuthError::Expired,
_ => AuthError::InvalidToken(e.to_string()),
})?;
Ok(token_data.claims)
}
pub fn key_count(&self) -> usize {
let guard = self.inner.read().unwrap_or_else(|e| e.into_inner());
guard.keys.len() + if guard.default_key.is_some() { 1 } else { 0 }
}
fn build_inner(jwk_set: &JwkSet) -> Result<JwksInner, AuthError> {
let mut keys = HashMap::new();
let mut default_key = None;
for jwk in &jwk_set.keys {
if jwk.key_use.as_deref() == Some("enc") {
continue;
}
match parse_jwk(jwk) {
Ok(parsed) => {
if let Some(kid) = &jwk.kid {
keys.insert(kid.clone(), (parsed.decoding_key, parsed.algorithm));
} else if default_key.is_none() {
default_key = Some((parsed.decoding_key, parsed.algorithm));
}
}
Err(_) => {
continue;
}
}
}
Ok(JwksInner { keys, default_key })
}
}
/// Reads the `kid` member from a JWT's header without verifying the token.
///
/// The header is everything before the first `.`, or the whole string if there is no
/// `.`. It is base64url-decoded and parsed as JSON. Returns `None` if any step fails or
/// if `kid` is missing or not a JSON string.
fn extract_kid_from_header(token: &str) -> Option<String> {
    let header_segment = match token.find('.') {
        Some(dot) => &token[..dot],
        None => token,
    };
    let raw_header = base64url_decode(header_segment)?;
    let parsed = serde_json::from_slice::<serde_json::Value>(&raw_header).ok()?;
    parsed
        .get("kid")
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned)
}
/// Decodes base64url text, as used for JWT segments, with or without `=` padding.
///
/// Both alphabets are accepted: `-` and `+` decode to 62, and `_` and `/` decode to 63.
/// Trailing `=` characters are stripped before decoding. Leftover bits after the last
/// full byte are discarded without checking that they are zero.
///
/// Returns `None` when:
/// - the input contains a character outside those alphabets, including any non-ASCII
///   byte (this input comes from untrusted tokens, so it must never panic);
/// - the length leaves a single character after the last full 4-character group, which
///   no valid encoding produces.
fn base64url_decode(input: &str) -> Option<Vec<u8>> {
    /// 6-bit value of one base64/base64url character, or `None` if it is in neither alphabet.
    fn sextet(symbol: u8) -> Option<u32> {
        let value = match symbol {
            b'A'..=b'Z' => symbol - b'A',
            b'a'..=b'z' => symbol - b'a' + 26,
            b'0'..=b'9' => symbol - b'0' + 52,
            b'-' | b'+' => 62,
            b'_' | b'/' => 63,
            _ => return None,
        };
        Some(u32::from(value))
    }

    let symbols = input.trim_end_matches('=').as_bytes();
    // A lone trailing character carries only 6 bits, which cannot form an output byte.
    if symbols.len() % 4 == 1 {
        return None;
    }

    let mut out = Vec::with_capacity(symbols.len() * 3 / 4);
    // `acc` holds the `bits` not-yet-emitted low-order bits. Fewer than 8 remain between
    // iterations, so `acc` never exceeds 12 bits.
    let mut acc: u32 = 0;
    let mut bits: u32 = 0;
    for &symbol in symbols {
        acc = (acc << 6) | sextet(symbol)?;
        bits += 6;
        if bits >= 8 {
            bits -= 8;
            // Exactly 8 bits sit above the `bits` leftover ones, so the cast is lossless.
            out.push((acc >> bits) as u8);
            acc &= (1 << bits) - 1;
        }
    }
    Some(out)
}
// Unit tests for JWKS parsing, kid lookup, refresh, and the header/base64url helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Fixture: two signing keys (`key-1`, `key-2`) and one key marked `"use": "enc"`
    // (`enc-key`) that the provider must skip. All three reuse the same RSA modulus.
    const SAMPLE_JWKS: &str = r#"{
"keys": [
{
"kty": "RSA",
"kid": "key-1",
"use": "sig",
"alg": "RS256",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e": "AQAB"
},
{
"kty": "RSA",
"kid": "key-2",
"use": "sig",
"alg": "RS256",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e": "AQAB"
},
{
"kty": "RSA",
"kid": "enc-key",
"use": "enc",
"alg": "RS256",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e": "AQAB"
}
]
}"#;

    // The `enc` key is excluded, so only the two signing keys are counted.
    #[test]
    fn test_parse_jwks_json() {
        let provider = JwksProvider::from_json(SAMPLE_JWKS).unwrap();
        assert_eq!(provider.key_count(), 2);
    }

    // Signing keys resolve by kid. Unknown kids and the skipped `enc` key do not.
    #[test]
    fn test_manager_for_kid() {
        let provider = JwksProvider::from_json(SAMPLE_JWKS).unwrap();
        assert!(provider.manager_for_kid("key-1").is_some());
        assert!(provider.manager_for_kid("key-2").is_some());
        assert!(provider.manager_for_kid("nonexistent").is_none());
        assert!(provider.manager_for_kid("enc-key").is_none());
    }

    // `refresh` replaces the whole key set rather than merging into it.
    #[test]
    fn test_refresh_replaces_keys() {
        let provider = JwksProvider::from_json(SAMPLE_JWKS).unwrap();
        assert_eq!(provider.key_count(), 2);
        let single = r#"{"keys": [{"kty": "RSA", "kid": "new-key", "use": "sig", "alg": "RS256", "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw", "e": "AQAB"}]}"#;
        provider.refresh(single).unwrap();
        assert_eq!(provider.key_count(), 1);
        assert!(provider.manager_for_kid("new-key").is_some());
        assert!(provider.manager_for_kid("key-1").is_none());
    }

    // Input that is not JSON at all is a hard error.
    #[test]
    fn test_parse_invalid_json() {
        assert!(JwksProvider::from_json("not json").is_err());
    }

    // An empty `keys` array is accepted and yields an empty provider.
    #[test]
    fn test_empty_keyset() {
        let provider = JwksProvider::from_json(r#"{"keys": []}"#).unwrap();
        assert_eq!(provider.key_count(), 0);
    }

    // A key without `kid` becomes the default key and is still counted.
    #[test]
    fn test_default_key_no_kid() {
        let jwks = r#"{"keys": [{"kty": "RSA", "use": "sig", "alg": "RS256", "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw", "e": "AQAB"}]}"#;
        let provider = JwksProvider::from_json(jwks).unwrap();
        assert_eq!(provider.key_count(), 1);
    }

    // Unpadded inputs whose lengths leave 2 and 3 trailing characters.
    #[test]
    fn test_base64url_decode() {
        assert_eq!(base64url_decode("SGVsbG8").unwrap(), b"Hello");
        assert_eq!(base64url_decode("SGVsbG8gV29ybGQ").unwrap(), b"Hello World");
    }

    // Only the header segment is decoded. The payload and signature are placeholders.
    #[test]
    fn test_extract_kid_from_header() {
        let token = "eyJhbGciOiJSUzI1NiIsImtpZCI6Im15LWtleSIsInR5cCI6IkpXVCJ9.e30.sig";
        let kid = extract_kid_from_header(token);
        assert_eq!(kid.as_deref(), Some("my-key"));
    }

    // A header without a `kid` member yields `None`.
    #[test]
    fn test_extract_kid_no_kid() {
        let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30.sig";
        let kid = extract_kid_from_header(token);
        assert!(kid.is_none());
    }

    // Symmetric (HS*) and unknown algorithm names are not accepted from a JWKS.
    #[test]
    fn test_algorithm_from_str() {
        assert_eq!(algorithm_from_str("RS256"), Some(Algorithm::RS256));
        assert_eq!(algorithm_from_str("ES256"), Some(Algorithm::ES256));
        assert_eq!(algorithm_from_str("PS512"), Some(Algorithm::PS512));
        assert_eq!(algorithm_from_str("HS256"), None);
        assert_eq!(algorithm_from_str("unknown"), None);
    }
}