use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
use hmac::{Hmac, KeyInit, Mac};
use thiserror::Error;
use sha2::Sha256;
use super::scopes::Scope;
/// Every way [`JwtVerifier::verify`] (and its signature helpers) can reject a token.
#[derive(Debug, Error)]
pub enum JwtError {
/// The token did not split into exactly three `.`-separated segments.
#[error("JWT is malformed (expected header.payload.signature)")]
Malformed,
/// A header, payload, or signature segment was not valid unpadded base64url.
#[error("JWT base64url decode failed: {0}")]
Base64(#[from] base64::DecodeError),
/// The decoded header or payload was not JSON of the expected shape.
#[error("JWT JSON parse failed: {0}")]
Json(#[from] serde_json::Error),
/// The header `alg` did not match the configured algorithm; carries the header's value.
#[error("unsupported JWT algorithm: {0}")]
UnsupportedAlg(String),
/// The header `alg` was `none` (compared case-insensitively).
#[error("JWT algorithm 'none' is not allowed")]
AlgNone,
/// The HMAC or RSA signature did not verify against the configured key.
#[error("JWT signature verification failed")]
BadSignature,
/// `exp` is in the past, after applying the configured clock skew.
#[error("JWT has expired")]
Expired,
/// `nbf` is still in the future, after applying the configured clock skew.
#[error("JWT is not yet valid (nbf)")]
NotYetValid,
/// `aud` did not contain the configured audience; `got` is the claim rendered as JSON.
#[error("JWT audience mismatch (expected {expected:?}, got {got:?})")]
WrongAudience { expected: String, got: String },
/// `iss` differed from the configured issuer.
#[error("JWT issuer mismatch (expected {expected:?}, got {got:?})")]
WrongIssuer { expected: String, got: String },
/// A claim the configuration requires (`aud` or `iss`) was absent; carries the claim name.
#[error("JWT missing required claim: {0}")]
MissingClaim(String),
/// The configured RSA public key could not be decoded; carries the decoder's message.
#[error("RSA public key decode failed: {0}")]
RsaKeyDecode(String),
/// The signature segment could not be converted into an RSA signature value.
#[error("RSA signature bytes invalid")]
RsaSignatureInvalid,
/// The system clock reported a time before the UNIX epoch.
#[error("system clock unavailable")]
Clock,
}
/// Result alias used throughout JWT verification; the error type is always [`JwtError`].
pub type JwtResult<T> = Result<T, JwtError>;
/// The single signature algorithm, with its key material, that a verifier accepts.
///
/// Tokens whose header `alg` names a different algorithm are rejected during verification.
#[derive(Clone)]
pub enum JwtAlgorithm {
    /// HMAC-SHA256 with a shared secret key.
    Hs256 { secret: Vec<u8> },
    /// RSASSA-PKCS1-v1_5 with SHA-256; the key is a DER-encoded SubjectPublicKeyInfo.
    Rs256 {
        public_key_der: Vec<u8>,
    },
}

// `Debug` is written by hand so that `{:?}` on this type (or on a config that derives `Debug`
// and embeds it) never prints the HMAC secret into logs; only its length is shown. The RSA
// public key is not sensitive and is printed exactly as the derived impl would.
impl std::fmt::Debug for JwtAlgorithm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Hs256 { secret } => f
                .debug_struct("Hs256")
                .field("secret", &format_args!("<redacted; {} bytes>", secret.len()))
                .finish(),
            Self::Rs256 { public_key_der } => f
                .debug_struct("Rs256")
                .field("public_key_der", public_key_der)
                .finish(),
        }
    }
}
/// Settings that control how a [`JwtVerifier`] validates tokens.
#[derive(Debug, Clone)]
pub struct JwtConfig {
/// The only signature algorithm (and key) accepted.
pub algorithm: JwtAlgorithm,
/// When `Some`, the token's `aud` (string or array) must contain this value.
pub audience: Option<String>,
/// When `Some`, the token's `iss` must equal this value.
pub issuer: Option<String>,
/// Scopes required per route, keyed by exact request path (see `required_scopes_for_path`).
pub required_scopes_per_route: HashMap<String, Vec<Scope>>,
/// Leeway in seconds applied to both the `exp` and `nbf` checks.
pub clock_skew_secs: u64,
}
impl Default for JwtConfig {
    /// Builds a baseline configuration with the following settings:
    /// - HS256 with an **empty** secret;
    /// - no audience or issuer requirement;
    /// - no per-route scope requirements;
    /// - 60 seconds of clock skew.
    ///
    /// The empty secret is only a placeholder. Replace `algorithm` with real key material before
    /// verifying tokens.
    fn default() -> Self {
        const DEFAULT_CLOCK_SKEW_SECS: u64 = 60;
        let placeholder_key = JwtAlgorithm::Hs256 { secret: Vec::new() };
        Self {
            algorithm: placeholder_key,
            required_scopes_per_route: HashMap::new(),
            audience: None,
            issuer: None,
            clock_skew_secs: DEFAULT_CLOCK_SKEW_SECS,
        }
    }
}
/// Claims decoded from a verified token's payload. Field names match the JSON keys.
#[derive(Debug, serde::Deserialize)]
pub struct JwtClaims {
/// Subject the token was issued for.
pub sub: Option<String>,
/// Expiry, compared against whole seconds since the UNIX epoch.
/// NOTE(review): RFC 7519 NumericDate may be fractional; a non-integer value presumably fails
/// to deserialize into `u64` — confirm whether any issuer emits fractional times.
pub exp: Option<u64>,
/// Not-before time, compared against whole seconds since the UNIX epoch.
pub nbf: Option<u64>,
/// Issued-at time; not checked by the verifier.
pub iat: Option<u64>,
/// Audience, kept as raw JSON; verification accepts a string or an array of strings.
pub aud: Option<serde_json::Value>,
/// Issuer.
pub iss: Option<String>,
/// Space-delimited scope list (OAuth-style `scope` claim).
pub scope: Option<String>,
/// Scope list as a JSON array of strings.
pub scopes: Option<Vec<String>>,
}
/// The subset of the JOSE header that verification reads.
#[derive(serde::Deserialize)]
struct JwtHeader {
/// Signature algorithm named by the token; checked against the configured algorithm.
alg: String,
// `typ` is deserialized but never read, so the field would otherwise trip the dead_code lint.
#[allow(dead_code)]
typ: Option<String>,
}
/// Verifies an HS256 (HMAC-SHA256) signature over `signing_input`.
///
/// An empty `secret` is rejected outright. A MAC keyed with zero bytes is computable by anyone.
/// `JwtConfig::default()` ships exactly such a key, so accepting it would let any caller forge
/// tokens against a verifier that was never given a real secret.
///
/// The tag comparison is delegated to `Mac::verify_slice`. That call also rejects signatures
/// whose length differs from the SHA-256 output.
///
/// # Errors
/// Returns [`JwtError::BadSignature`] when the secret is empty or the MAC does not match.
fn verify_hs256(secret: &[u8], signing_input: &[u8], signature: &[u8]) -> JwtResult<()> {
    if secret.is_empty() {
        return Err(JwtError::BadSignature);
    }
    // HMAC accepts keys of any length, so this construction only fails on an invalid length.
    let mut mac = Hmac::<Sha256>::new_from_slice(secret).map_err(|_| JwtError::BadSignature)?;
    mac.update(signing_input);
    mac.verify_slice(signature)
        .map_err(|_| JwtError::BadSignature)
}
/// Verifies an RS256 (RSASSA-PKCS1-v1_5 with SHA-256) signature over `signing_input`.
///
/// `public_key_der` is decoded as a DER SubjectPublicKeyInfo through `DecodePublicKey`. The key
/// is re-parsed on every call.
///
/// # Errors
/// - [`JwtError::RsaKeyDecode`] if the key bytes cannot be decoded (carries the decoder message).
/// - [`JwtError::RsaSignatureInvalid`] if `signature` cannot form an RSA signature value.
/// - [`JwtError::BadSignature`] if the signature does not verify.
fn verify_rs256(public_key_der: &[u8], signing_input: &[u8], signature: &[u8]) -> JwtResult<()> {
    use rsa::pkcs1v15::VerifyingKey;
    use rsa::pkcs8::DecodePublicKey;
    use rsa::sha2::Sha256 as RsaSha256;
    use rsa::signature::Verifier;

    let checker = match rsa::RsaPublicKey::from_public_key_der(public_key_der) {
        Ok(key) => VerifyingKey::<RsaSha256>::new(key),
        Err(decode_err) => return Err(JwtError::RsaKeyDecode(decode_err.to_string())),
    };
    let Ok(rsa_sig) = rsa::pkcs1v15::Signature::try_from(signature) else {
        return Err(JwtError::RsaSignatureInvalid);
    };
    match checker.verify(signing_input, &rsa_sig) {
        Ok(()) => Ok(()),
        Err(_) => Err(JwtError::BadSignature),
    }
}
/// Validates JWTs against a fixed [`JwtConfig`]; construct it with `JwtVerifier::new`.
pub struct JwtVerifier {
config: JwtConfig,
}
impl JwtVerifier {
    /// Creates a verifier that validates tokens against `config`.
    pub fn new(config: JwtConfig) -> Self {
        Self { config }
    }

    /// Verifies a compact-serialized JWT (`header.payload.signature`) and returns its claims.
    ///
    /// Checks run in this order:
    /// 1. Segment structure, and base64url decoding of all three segments.
    /// 2. The header `alg`: `none` is refused in any casing, and any algorithm other than the
    ///    configured one is refused.
    /// 3. The signature over `header.payload`.
    /// 4. `exp` and `nbf`, each with `clock_skew_secs` of leeway.
    /// 5. `aud` and `iss`, when the config requires them.
    ///
    /// A token without `exp` or `nbf` is not time-restricted on that side.
    ///
    /// # Errors
    /// - [`JwtError::Malformed`], [`JwtError::Base64`], [`JwtError::Json`]: the token is not a
    ///   well-formed JWT.
    /// - [`JwtError::AlgNone`], [`JwtError::UnsupportedAlg`]: the header algorithm is refused.
    /// - [`JwtError::BadSignature`], [`JwtError::RsaKeyDecode`],
    ///   [`JwtError::RsaSignatureInvalid`]: the signature cannot be verified.
    /// - [`JwtError::Clock`]: the system clock is before the UNIX epoch.
    /// - [`JwtError::Expired`], [`JwtError::NotYetValid`]: the token is outside its validity
    ///   window.
    /// - [`JwtError::MissingClaim`], [`JwtError::WrongAudience`], [`JwtError::WrongIssuer`]:
    ///   audience or issuer requirements are not met.
    pub fn verify(&self, token: &str) -> JwtResult<JwtClaims> {
        let (header_b64, payload_b64, sig_b64) = Self::split_compact(token)?;
        // Decode every segment up front so encoding errors surface before any JSON parsing.
        let header_bytes = URL_SAFE_NO_PAD.decode(header_b64)?;
        let payload_bytes = URL_SAFE_NO_PAD.decode(payload_b64)?;
        let signature = URL_SAFE_NO_PAD.decode(sig_b64)?;

        let header: JwtHeader = serde_json::from_slice(&header_bytes)?;
        // `none` is refused in any casing, independent of the configured algorithm.
        if header.alg.eq_ignore_ascii_case("none") {
            return Err(JwtError::AlgNone);
        }
        // The signature covers the raw, still-encoded header and payload segments.
        let signing_input = format!("{header_b64}.{payload_b64}");
        self.check_signature(header.alg, signing_input.as_bytes(), &signature)?;

        // Claims are parsed only once the signature is known to be authentic.
        let claims: JwtClaims = serde_json::from_slice(&payload_bytes)?;
        self.check_time_claims(&claims)?;
        self.check_audience(&claims)?;
        self.check_issuer(&claims)?;
        Ok(claims)
    }

    /// Splits `token` into exactly three `.`-separated segments (header, payload, signature).
    fn split_compact(token: &str) -> JwtResult<(&str, &str, &str)> {
        // `splitn(4, ..)` makes a fourth segment observable, so extra dots are rejected.
        let mut parts = token.splitn(4, '.');
        match (parts.next(), parts.next(), parts.next(), parts.next()) {
            (Some(header), Some(payload), Some(sig), None) => Ok((header, payload, sig)),
            _ => Err(JwtError::Malformed),
        }
    }

    /// Confirms the header `alg` matches the configured algorithm, then verifies the signature.
    fn check_signature(
        &self,
        alg: String,
        signing_input: &[u8],
        signature: &[u8],
    ) -> JwtResult<()> {
        match &self.config.algorithm {
            JwtAlgorithm::Hs256 { secret } => {
                // Pinning `alg` to the configured key type guards against algorithm confusion
                // (e.g. an attacker choosing how the key bytes are interpreted).
                if !alg.eq_ignore_ascii_case("HS256") {
                    return Err(JwtError::UnsupportedAlg(alg));
                }
                verify_hs256(secret, signing_input, signature)
            }
            JwtAlgorithm::Rs256 { public_key_der } => {
                if !alg.eq_ignore_ascii_case("RS256") {
                    return Err(JwtError::UnsupportedAlg(alg));
                }
                verify_rs256(public_key_der, signing_input, signature)
            }
        }
    }

    /// Rejects a token whose `exp` has passed or whose `nbf` is still in the future.
    ///
    /// Both checks allow `clock_skew_secs` of leeway. Absent claims are not checked.
    fn check_time_claims(&self, claims: &JwtClaims) -> JwtResult<()> {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_err(|_| JwtError::Clock)?
            .as_secs();
        let skew = self.config.clock_skew_secs;
        // RFC 7519 §4.1.4: the current time MUST be *before* `exp`, so reaching `exp + skew`
        // is already expired. Saturating adds keep huge claim values from overflowing.
        if claims.exp.is_some_and(|exp| now >= exp.saturating_add(skew)) {
            return Err(JwtError::Expired);
        }
        // RFC 7519 §4.1.5: the current time MUST be at or after `nbf`.
        if claims.nbf.is_some_and(|nbf| now.saturating_add(skew) < nbf) {
            return Err(JwtError::NotYetValid);
        }
        Ok(())
    }

    /// Enforces the configured audience, if any, against the `aud` claim.
    fn check_audience(&self, claims: &JwtClaims) -> JwtResult<()> {
        let Some(expected) = &self.config.audience else {
            return Ok(());
        };
        let Some(aud) = &claims.aud else {
            return Err(JwtError::MissingClaim("aud".to_string()));
        };
        // RFC 7519 §4.1.3: `aud` is either a single string or an array of strings.
        let matches = match aud {
            serde_json::Value::String(s) => s == expected,
            serde_json::Value::Array(items) => items
                .iter()
                .any(|v| v.as_str().is_some_and(|s| s == expected)),
            _ => false,
        };
        if matches {
            Ok(())
        } else {
            Err(JwtError::WrongAudience {
                expected: expected.clone(),
                got: aud.to_string(),
            })
        }
    }

    /// Enforces the configured issuer, if any, against the `iss` claim.
    fn check_issuer(&self, claims: &JwtClaims) -> JwtResult<()> {
        let Some(expected) = &self.config.issuer else {
            return Ok(());
        };
        match &claims.iss {
            None => Err(JwtError::MissingClaim("iss".to_string())),
            Some(iss) if iss != expected => Err(JwtError::WrongIssuer {
                expected: expected.clone(),
                got: iss.clone(),
            }),
            Some(_) => Ok(()),
        }
    }

    /// Collects the recognised scopes granted by `claims`.
    ///
    /// Reads the space-delimited `scope` string first, then the `scopes` array. Entries that do
    /// not parse as a [`Scope`] are skipped. Each scope appears at most once, in order of first
    /// occurrence, including repeats within the `scope` string itself.
    pub fn scopes_from_claims(&self, claims: &JwtClaims) -> Vec<Scope> {
        let from_string = claims.scope.iter().flat_map(|s| s.split_whitespace());
        let from_array = claims.scopes.iter().flatten().map(String::as_str);
        let mut out: Vec<Scope> = Vec::new();
        for scope in from_string
            .chain(from_array)
            .filter_map(|raw| raw.parse::<Scope>().ok())
        {
            if !out.contains(&scope) {
                out.push(scope);
            }
        }
        out
    }

    /// Returns the scopes configured for `path` (exact match), or an empty slice if none.
    pub fn required_scopes_for_path(&self, path: &str) -> &[Scope] {
        self.config
            .required_scopes_per_route
            .get(path)
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a compact JWT over `claims`, HMAC-SHA256-signed with `secret`.
    ///
    /// `alg_override` replaces the header `alg`. The token is still HMAC-signed unless the
    /// override is `none` (any casing), in which case the signature segment is left empty.
    fn make_hs256_token(
        secret: &[u8],
        claims: &serde_json::Value,
        alg_override: Option<&str>,
    ) -> Result<String, Box<dyn std::error::Error>> {
        let alg = alg_override.unwrap_or("HS256");
        let header = serde_json::json!({ "alg": alg, "typ": "JWT" });
        let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header)?);
        let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(claims)?);
        let signing_input = format!("{header_b64}.{payload_b64}");
        let sig_b64 = if alg.eq_ignore_ascii_case("none") {
            String::new()
        } else {
            let mut mac =
                Hmac::<Sha256>::new_from_slice(secret).expect("test fixture: HMAC construction");
            mac.update(signing_input.as_bytes());
            URL_SAFE_NO_PAD.encode(mac.finalize().into_bytes())
        };
        Ok(format!("{signing_input}.{sig_b64}"))
    }

    /// Current time in whole seconds since the UNIX epoch.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("test fixture: system time")
            .as_secs()
    }

    /// HS256 verifier keyed with `secret`; every other setting comes from `JwtConfig::default`.
    fn hs256_verifier(secret: &[u8]) -> JwtVerifier {
        JwtVerifier::new(JwtConfig {
            algorithm: JwtAlgorithm::Hs256 {
                secret: secret.to_vec(),
            },
            ..Default::default()
        })
    }

    #[test]
    fn jwt_valid_hs256_verifies() {
        let secret = b"super-secret-key";
        let claims = serde_json::json!({
            "sub": "user-1",
            "exp": unix_now() + 3600,
        });
        let token =
            make_hs256_token(secret, &claims, None).expect("test fixture: token construction");
        let verifier = hs256_verifier(secret);
        let result = verifier.verify(&token);
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn jwt_wrong_secret_rejected() {
        let claims = serde_json::json!({ "sub": "user-1", "exp": unix_now() + 3600 });
        let token = make_hs256_token(b"correct-secret", &claims, None).expect("test fixture");
        let verifier = hs256_verifier(b"wrong-secret");
        assert!(matches!(
            verifier.verify(&token),
            Err(JwtError::BadSignature)
        ));
    }

    #[test]
    fn jwt_tampered_payload_rejected() {
        let secret = b"my-secret";
        let honest = make_hs256_token(
            secret,
            &serde_json::json!({ "sub": "user-1", "exp": unix_now() + 3600 }),
            None,
        )
        .expect("test fixture");
        let forged = make_hs256_token(
            secret,
            &serde_json::json!({ "sub": "admin", "exp": unix_now() + 3600 }),
            None,
        )
        .expect("test fixture");
        let honest_parts: Vec<&str> = honest.split('.').collect();
        let forged_parts: Vec<&str> = forged.split('.').collect();
        // Splice the forged payload under the honest token's header and signature.
        let tampered = format!(
            "{}.{}.{}",
            honest_parts[0], forged_parts[1], honest_parts[2]
        );
        let verifier = hs256_verifier(secret);
        assert!(matches!(
            verifier.verify(&tampered),
            Err(JwtError::BadSignature)
        ));
    }

    #[test]
    fn jwt_alg_mismatch_rejected() {
        // A header claiming RS256 must be refused by an HS256 verifier even though the token
        // carries a valid HMAC (algorithm-confusion guard).
        let secret = b"my-secret";
        let claims = serde_json::json!({ "sub": "user-1", "exp": unix_now() + 3600 });
        let token = make_hs256_token(secret, &claims, Some("RS256")).expect("test fixture");
        let verifier = hs256_verifier(secret);
        assert!(matches!(
            verifier.verify(&token),
            Err(JwtError::UnsupportedAlg(alg)) if alg == "RS256"
        ));
    }

    #[test]
    fn jwt_expired_token_rejected() {
        let secret = b"my-secret";
        let claims = serde_json::json!({
            "sub": "user-1",
            "exp": unix_now() - 1000,
        });
        let token = make_hs256_token(secret, &claims, None).expect("test fixture");
        let verifier = hs256_verifier(secret);
        assert!(
            matches!(verifier.verify(&token), Err(JwtError::Expired)),
            "expected Expired"
        );
    }

    #[test]
    fn jwt_nbf_in_future_rejected() {
        let secret = b"my-secret";
        let claims = serde_json::json!({
            "sub": "user-1",
            "exp": unix_now() + 3600,
            "nbf": unix_now() + 10_000,
        });
        let token = make_hs256_token(secret, &claims, None).expect("test fixture");
        let verifier = hs256_verifier(secret);
        assert!(
            matches!(verifier.verify(&token), Err(JwtError::NotYetValid)),
            "expected NotYetValid"
        );
    }

    #[test]
    fn jwt_wrong_audience_rejected() {
        let secret = b"my-secret";
        let claims = serde_json::json!({
            "sub": "user-1",
            "exp": unix_now() + 3600,
            "aud": "other-service",
        });
        let token = make_hs256_token(secret, &claims, None).expect("test fixture");
        let verifier = JwtVerifier::new(JwtConfig {
            algorithm: JwtAlgorithm::Hs256 {
                secret: secret.to_vec(),
            },
            audience: Some("my-service".to_string()),
            ..Default::default()
        });
        assert!(
            matches!(verifier.verify(&token), Err(JwtError::WrongAudience { .. })),
            "expected WrongAudience"
        );
    }

    #[test]
    fn jwt_missing_audience_rejected() {
        let secret = b"my-secret";
        let claims = serde_json::json!({ "sub": "user-1", "exp": unix_now() + 3600 });
        let token = make_hs256_token(secret, &claims, None).expect("test fixture");
        let verifier = JwtVerifier::new(JwtConfig {
            algorithm: JwtAlgorithm::Hs256 {
                secret: secret.to_vec(),
            },
            audience: Some("my-service".to_string()),
            ..Default::default()
        });
        assert!(matches!(
            verifier.verify(&token),
            Err(JwtError::MissingClaim(claim)) if claim == "aud"
        ));
    }

    #[test]
    fn jwt_correct_audience_in_array_accepted() {
        let secret = b"my-secret";
        let claims = serde_json::json!({
            "sub": "user-1",
            "exp": unix_now() + 3600,
            "aud": ["other-service", "my-service"],
        });
        let token = make_hs256_token(secret, &claims, None).expect("test fixture");
        let verifier = JwtVerifier::new(JwtConfig {
            algorithm: JwtAlgorithm::Hs256 {
                secret: secret.to_vec(),
            },
            audience: Some("my-service".to_string()),
            ..Default::default()
        });
        assert!(verifier.verify(&token).is_ok());
    }

    #[test]
    fn jwt_wrong_issuer_rejected() {
        let secret = b"my-secret";
        let claims = serde_json::json!({
            "sub": "user-1",
            "exp": unix_now() + 3600,
            "iss": "bad-issuer",
        });
        let token = make_hs256_token(secret, &claims, None).expect("test fixture");
        let verifier = JwtVerifier::new(JwtConfig {
            algorithm: JwtAlgorithm::Hs256 {
                secret: secret.to_vec(),
            },
            issuer: Some("trusted-issuer".to_string()),
            ..Default::default()
        });
        assert!(
            matches!(verifier.verify(&token), Err(JwtError::WrongIssuer { .. })),
            "expected WrongIssuer"
        );
    }

    #[test]
    fn jwt_correct_issuer_accepted() {
        let secret = b"my-secret";
        let claims = serde_json::json!({
            "sub": "user-1",
            "exp": unix_now() + 3600,
            "iss": "trusted-issuer",
        });
        let token = make_hs256_token(secret, &claims, None).expect("test fixture");
        let verifier = JwtVerifier::new(JwtConfig {
            algorithm: JwtAlgorithm::Hs256 {
                secret: secret.to_vec(),
            },
            issuer: Some("trusted-issuer".to_string()),
            ..Default::default()
        });
        let result = verifier.verify(&token);
        assert!(result.is_ok(), "expected Ok, got: {:?}", result.err());
    }

    #[test]
    fn jwt_malformed_token_rejected() {
        let verifier = hs256_verifier(b"secret");
        assert!(
            matches!(
                verifier.verify("not.a.valid.token"),
                Err(JwtError::Malformed)
            ),
            "expected Malformed for 4-part token"
        );
        assert!(
            matches!(verifier.verify("only.two"), Err(JwtError::Malformed)),
            "expected Malformed for 2-part token"
        );
        assert!(
            matches!(verifier.verify("onepart"), Err(JwtError::Malformed)),
            "expected Malformed for 1-part token"
        );
    }

    #[test]
    fn jwt_alg_none_rejected() {
        let secret = b"my-secret";
        let claims = serde_json::json!({ "sub": "user-1", "exp": unix_now() + 3600 });
        let token = make_hs256_token(secret, &claims, Some("none")).expect("test fixture");
        let verifier = hs256_verifier(secret);
        assert!(
            matches!(verifier.verify(&token), Err(JwtError::AlgNone)),
            "expected AlgNone"
        );
    }

    #[test]
    fn jwt_alg_none_uppercase_rejected() {
        let secret = b"my-secret";
        let claims = serde_json::json!({ "sub": "user-1", "exp": unix_now() + 3600 });
        let token = make_hs256_token(secret, &claims, Some("NONE")).expect("test fixture");
        let verifier = hs256_verifier(secret);
        assert!(matches!(verifier.verify(&token), Err(JwtError::AlgNone)));
    }

    #[test]
    fn jwt_scope_string_parsed() {
        let secret = b"my-secret";
        let claims = serde_json::json!({
            "sub": "user-1",
            "exp": unix_now() + 3600,
            "scope": "chat:read embed:read",
        });
        let token = make_hs256_token(secret, &claims, None).expect("test fixture");
        let verifier = hs256_verifier(secret);
        let decoded = verifier.verify(&token).expect("valid token");
        let scopes = verifier.scopes_from_claims(&decoded);
        assert!(scopes.contains(&Scope::ChatRead));
        assert!(scopes.contains(&Scope::EmbedRead));
        assert!(!scopes.contains(&Scope::ChatWrite));
    }

    #[test]
    fn jwt_scopes_array_parsed() {
        let secret = b"my-secret";
        let claims = serde_json::json!({
            "sub": "user-1",
            "exp": unix_now() + 3600,
            "scopes": ["chat:read", "admin:read"],
        });
        let token = make_hs256_token(secret, &claims, None).expect("test fixture");
        let verifier = hs256_verifier(secret);
        let decoded = verifier.verify(&token).expect("valid token");
        let scopes = verifier.scopes_from_claims(&decoded);
        assert!(scopes.contains(&Scope::ChatRead));
        assert!(scopes.contains(&Scope::AdminRead));
    }

    #[test]
    fn jwt_scope_string_and_array_merged() {
        let secret = b"my-secret";
        let claims = serde_json::json!({
            "sub": "user-1",
            "exp": unix_now() + 3600,
            "scope": "chat:read",
            "scopes": ["chat:read", "admin:read"],
        });
        let token = make_hs256_token(secret, &claims, None).expect("test fixture");
        let verifier = hs256_verifier(secret);
        let decoded = verifier.verify(&token).expect("valid token");
        let scopes = verifier.scopes_from_claims(&decoded);
        // A scope granted by both claims must appear only once in the merged list.
        assert_eq!(scopes.iter().filter(|s| **s == Scope::ChatRead).count(), 1);
        assert!(scopes.contains(&Scope::AdminRead));
    }

    #[test]
    fn jwt_required_scopes_for_path() {
        let mut route_scopes = HashMap::new();
        route_scopes.insert("/v1/chat/completions".to_string(), vec![Scope::ChatWrite]);
        let verifier = JwtVerifier::new(JwtConfig {
            algorithm: JwtAlgorithm::Hs256 {
                secret: b"s".to_vec(),
            },
            required_scopes_per_route: route_scopes,
            ..Default::default()
        });
        let required = verifier.required_scopes_for_path("/v1/chat/completions");
        assert_eq!(required, [Scope::ChatWrite]);
        let unrestricted = verifier.required_scopes_for_path("/health");
        assert!(unrestricted.is_empty());
    }

    #[test]
    #[ignore = "slow: generates a fresh 2048-bit RSA key"]
    fn jwt_rs256_with_generated_key() {
        use rsa::pkcs1v15::SigningKey;
        use rsa::pkcs8::EncodePublicKey;
        use rsa::sha2::Sha256 as RsaSha256;
        use rsa::signature::RandomizedSigner;
        use rsa::signature::SignatureEncoding;
        let mut rng = rand_core::OsRng;
        let private_key =
            rsa::RsaPrivateKey::new(&mut rng, 2048).expect("test fixture: RSA key generation");
        let public_key = rsa::RsaPublicKey::from(&private_key);
        let public_key_der = public_key
            .to_public_key_der()
            .expect("test fixture: DER encode")
            .to_vec();
        let header = serde_json::json!({ "alg": "RS256", "typ": "JWT" });
        let claims_json = serde_json::json!({
            "sub": "rs256-user",
            "exp": unix_now() + 3600,
        });
        let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_vec(&header).expect("test fixture"));
        let payload_b64 =
            URL_SAFE_NO_PAD.encode(serde_json::to_vec(&claims_json).expect("test fixture"));
        let signing_input = format!("{header_b64}.{payload_b64}");
        let signing_key = SigningKey::<RsaSha256>::new(private_key);
        let signature = signing_key.sign_with_rng(&mut rng, signing_input.as_bytes());
        let sig_b64 = URL_SAFE_NO_PAD.encode(signature.to_bytes());
        let token = format!("{signing_input}.{sig_b64}");
        let verifier = JwtVerifier::new(JwtConfig {
            algorithm: JwtAlgorithm::Rs256 { public_key_der },
            ..Default::default()
        });
        let result = verifier.verify(&token);
        assert!(result.is_ok(), "RS256 verify failed: {:?}", result.err());
    }
}