use std::time::{SystemTime, UNIX_EPOCH};
use tracing::debug;
use crate::control::security::util::base64_url_decode;
use crate::types::TenantId;
use super::identity::{AuthMethod, AuthenticatedIdentity, Role};
/// Configuration for JWT validation.
///
/// An empty key or expected value disables the corresponding feature or
/// check, so a default-constructed config accepts no algorithm at all.
#[derive(Clone)]
pub struct JwtConfig {
    /// Shared secret for `HS256` tokens. When empty, `HS256` tokens are
    /// rejected as unsupported.
    pub hmac_secret: Vec<u8>,
    /// DER-encoded RSA public key for `RS256` tokens. When empty, `RS256`
    /// tokens are rejected as unsupported.
    pub rsa_public_key_der: Vec<u8>,
    /// Required value of the `iss` claim. When empty, the issuer is not checked.
    pub expected_issuer: String,
    /// Required value of the `aud` claim. When empty, the audience is not checked.
    pub expected_audience: String,
    /// Tolerance, in seconds, applied to the `exp` and `nbf` comparisons.
    pub clock_skew_seconds: u64,
}

// `Debug` is implemented by hand so that the HMAC secret never ends up in log
// output or panic messages; only its length is reported.
impl std::fmt::Debug for JwtConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JwtConfig")
            .field(
                "hmac_secret",
                &format_args!("<redacted, {} bytes>", self.hmac_secret.len()),
            )
            .field("rsa_public_key_der", &self.rsa_public_key_der)
            .field("expected_issuer", &self.expected_issuer)
            .field("expected_audience", &self.expected_audience)
            .field("clock_skew_seconds", &self.clock_skew_seconds)
            .finish()
    }
}

impl Default for JwtConfig {
    /// Returns a config with no keys, no issuer/audience restrictions and a
    /// 60-second clock-skew tolerance.
    fn default() -> Self {
        Self {
            hmac_secret: Vec::new(),
            rsa_public_key_der: Vec::new(),
            expected_issuer: String::new(),
            expected_audience: String::new(),
            clock_skew_seconds: 60,
        }
    }
}
/// The part of the token header that validation reads.
///
/// Only `alg` is declared; any other header members present in the JSON
/// (for example `typ`) are not captured by this struct.
#[derive(Debug, serde::Deserialize)]
struct JwtHeader {
/// Signature algorithm name taken from the token header, e.g. `"HS256"` or `"RS256"`.
alg: String,
}
/// Claims deserialized from the token payload.
///
/// `sub` is the only member without `#[serde(default)]`, so a payload that
/// omits it fails to deserialize. Every other declared member falls back to
/// its type's default (zero, empty string, empty list, `false`) when absent.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct JwtClaims {
/// Subject. Required by deserialization, but may be an empty string.
pub sub: String,
/// Numeric tenant identifier; `0` when the claim is absent.
#[serde(default)]
pub tenant_id: u32,
/// Role names as they appear in the token; empty when the claim is absent.
#[serde(default)]
pub roles: Vec<String>,
/// Expiry time in seconds since the Unix epoch; `0` when the claim is absent.
#[serde(default)]
pub exp: u64,
/// Not-before time in seconds since the Unix epoch; `0` when the claim is absent.
#[serde(default)]
pub nbf: u64,
/// Issued-at time in seconds since the Unix epoch; `0` when the claim is absent.
#[serde(default)]
pub iat: u64,
/// Issuer; empty when the claim is absent.
#[serde(default)]
pub iss: String,
/// Audience as a single string; empty when the claim is absent.
// NOTE(review): declared as `String`, so a payload whose `aud` is a JSON array
// presumably fails to deserialize — confirm whether array audiences must be supported.
#[serde(default)]
pub aud: String,
/// Numeric user identifier; `0` when the claim is absent.
#[serde(default)]
pub user_id: u64,
/// Superuser flag; `false` when the claim is absent.
#[serde(default)]
pub is_superuser: bool,
/// Every payload member not matched by one of the fields above, keyed by claim name.
#[serde(flatten)]
pub extra: std::collections::HashMap<String, serde_json::Value>,
}
/// Validates JWTs against a fixed [`JwtConfig`].
pub struct JwtValidator {
/// Keys, expected issuer/audience and clock-skew tolerance used by `validate`.
config: JwtConfig,
}
impl JwtValidator {
pub fn new(config: JwtConfig) -> Self {
Self { config }
}
pub fn validate(&self, token: &str) -> Result<AuthenticatedIdentity, JwtError> {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() != 3 {
return Err(JwtError::MalformedToken);
}
let header_bytes = base64_url_decode(parts[0]).ok_or(JwtError::DecodingError)?;
let header: JwtHeader =
serde_json::from_slice(&header_bytes).map_err(|_| JwtError::InvalidClaims)?;
let payload_bytes = base64_url_decode(parts[1]).ok_or(JwtError::DecodingError)?;
let claims: JwtClaims =
serde_json::from_slice(&payload_bytes).map_err(|_| JwtError::InvalidClaims)?;
let signing_input = format!("{}.{}", parts[0], parts[1]);
let signature_bytes = base64_url_decode(parts[2]).ok_or(JwtError::DecodingError)?;
match header.alg.as_str() {
"HS256" => {
if self.config.hmac_secret.is_empty() {
return Err(JwtError::UnsupportedAlgorithm);
}
if !verify_hmac_sha256(
&self.config.hmac_secret,
signing_input.as_bytes(),
&signature_bytes,
) {
return Err(JwtError::InvalidSignature);
}
}
"RS256" => {
if self.config.rsa_public_key_der.is_empty() {
return Err(JwtError::UnsupportedAlgorithm);
}
if !verify_rsa_sha256(
&self.config.rsa_public_key_der,
signing_input.as_bytes(),
&signature_bytes,
) {
return Err(JwtError::InvalidSignature);
}
}
_ => return Err(JwtError::UnsupportedAlgorithm),
}
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_secs();
if claims.exp > 0 && now > claims.exp + self.config.clock_skew_seconds {
return Err(JwtError::Expired);
}
if claims.nbf > 0 && now + self.config.clock_skew_seconds < claims.nbf {
return Err(JwtError::NotYetValid);
}
if !self.config.expected_issuer.is_empty() && claims.iss != self.config.expected_issuer {
return Err(JwtError::InvalidIssuer);
}
if !self.config.expected_audience.is_empty() && claims.aud != self.config.expected_audience
{
return Err(JwtError::InvalidAudience);
}
let roles: Vec<Role> = claims
.roles
.iter()
.map(|r| r.parse::<Role>().unwrap_or(Role::Custom(r.clone())))
.collect();
let username = if claims.sub.is_empty() {
format!("jwt_user_{}", claims.user_id)
} else {
claims.sub.clone()
};
debug!(
username = %username,
tenant_id = claims.tenant_id,
roles = ?roles,
"JWT validated"
);
Ok(AuthenticatedIdentity {
user_id: claims.user_id,
username,
tenant_id: TenantId::new(claims.tenant_id),
auth_method: AuthMethod::ApiKey, roles,
is_superuser: claims.is_superuser,
})
}
pub fn is_configured(&self) -> bool {
!self.config.hmac_secret.is_empty() || !self.config.rsa_public_key_der.is_empty()
}
}
/// Reasons a token (or a key file) is rejected by this module.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum JwtError {
/// The token does not split into exactly three dot-separated segments.
#[error("malformed JWT token")]
MalformedToken,
/// The header or payload JSON could not be deserialized.
#[error("invalid JWT claims")]
InvalidClaims,
/// The signature did not verify against the configured key.
#[error("JWT signature verification failed")]
InvalidSignature,
/// The current time is past `exp` plus the configured clock skew.
#[error("JWT token expired")]
Expired,
/// The current time plus the configured clock skew is before `nbf`.
#[error("JWT token not yet valid")]
NotYetValid,
/// The `iss` claim differs from the configured expected issuer.
#[error("JWT issuer mismatch")]
InvalidIssuer,
/// The `aud` claim differs from the configured expected audience.
#[error("JWT audience mismatch")]
InvalidAudience,
/// A token segment failed base64url decoding; also returned by
/// `load_rsa_public_key_pem` when the file cannot be read, parsed, or has an
/// unexpected PEM tag.
#[error("JWT base64 decoding error")]
DecodingError,
/// The header's `alg` is neither `HS256` nor `RS256`, or no key is configured for it.
#[error("JWT algorithm not supported or not configured")]
UnsupportedAlgorithm,
}
/// Returns `true` when `expected_signature` is the HMAC-SHA256 tag of
/// `message` under `secret`.
///
/// Returns `false` if the MAC cannot be constructed from `secret` or if the
/// tag does not verify.
fn verify_hmac_sha256(secret: &[u8], message: &[u8], expected_signature: &[u8]) -> bool {
    use hmac::{Hmac, Mac};
    use sha2::Sha256;

    Hmac::<Sha256>::new_from_slice(secret)
        .map(|mut keyed| {
            keyed.update(message);
            // Comparison is delegated to the MAC implementation's `verify_slice`.
            keyed.verify_slice(expected_signature).is_ok()
        })
        .unwrap_or(false)
}
/// Returns `true` when `signature` is a valid PKCS#1 v1.5 / SHA-256 signature
/// of `message` under the RSA public key in `public_key_der`.
///
/// The key is first parsed via `from_public_key_der`; if that fails,
/// `from_pkcs1_der` is tried. Returns `false` when neither parse succeeds or
/// the signature does not verify.
fn verify_rsa_sha256(public_key_der: &[u8], message: &[u8], signature: &[u8]) -> bool {
    use rsa::pkcs1::DecodeRsaPublicKey;
    use rsa::pkcs8::DecodePublicKey;
    use rsa::{Pkcs1v15Sign, RsaPublicKey};
    use sha2::{Digest, Sha256};

    // The second decoder is only attempted when the first one fails.
    let parsed_key = RsaPublicKey::from_public_key_der(public_key_der)
        .ok()
        .or_else(|| RsaPublicKey::from_pkcs1_der(public_key_der).ok());

    match parsed_key {
        Some(key) => {
            let hashed = Sha256::digest(message);
            key.verify(Pkcs1v15Sign::new::<Sha256>(), &hashed, signature)
                .is_ok()
        }
        None => false,
    }
}
/// Reads a PEM file and returns the decoded contents of its public-key block.
///
/// Accepts blocks tagged `PUBLIC KEY` or `RSA PUBLIC KEY`.
///
/// # Errors
///
/// Returns [`JwtError::DecodingError`] when the file cannot be read, is not
/// valid PEM, or carries any other tag.
pub fn load_rsa_public_key_pem(pem_path: &std::path::Path) -> Result<Vec<u8>, JwtError> {
    let raw = std::fs::read(pem_path).map_err(|_| JwtError::DecodingError)?;
    let block = pem::parse(&raw).map_err(|_| JwtError::DecodingError)?;

    if matches!(block.tag(), "PUBLIC KEY" | "RSA PUBLIC KEY") {
        Ok(block.into_contents())
    } else {
        Err(JwtError::DecodingError)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `data` as unpadded base64url, the encoding used for JWT segments.
    fn base64_url_encode(data: &[u8]) -> String {
        use base64::Engine;
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
    }

    /// Generates a fresh 2048-bit RSA private key.
    fn fresh_rsa_key() -> rsa::RsaPrivateKey {
        let mut rng = rand::thread_rng();
        rsa::RsaPrivateKey::new(&mut rng, 2048).unwrap()
    }

    /// Returns the DER produced by `to_public_key_der` for the public half of `private_key`.
    fn public_key_der(private_key: &rsa::RsaPrivateKey) -> Vec<u8> {
        use rsa::pkcs8::EncodePublicKey;
        let public_key = rsa::RsaPublicKey::from(private_key);
        public_key.to_public_key_der().unwrap().as_ref().to_vec()
    }

    /// Builds an RS256 token over `payload_json`, signed with `signer`.
    fn rs256_token(signer: rsa::RsaPrivateKey, payload_json: &[u8]) -> String {
        use rsa::pkcs1v15::SigningKey;
        use rsa::signature::{SignatureEncoding, Signer};

        let header = base64_url_encode(br#"{"alg":"RS256","typ":"JWT"}"#);
        let payload = base64_url_encode(payload_json);
        let signing_input = format!("{header}.{payload}");
        let signing_key = SigningKey::<sha2::Sha256>::new(signer);
        let signature: rsa::pkcs1v15::Signature = signing_key.sign(signing_input.as_bytes());
        let signature_b64 = base64_url_encode(&signature.to_bytes());
        format!("{signing_input}.{signature_b64}")
    }

    /// Builds a validator that only has an RSA public key configured.
    fn rsa_only_validator(rsa_public_key_der: Vec<u8>) -> JwtValidator {
        JwtValidator::new(JwtConfig {
            rsa_public_key_der,
            ..Default::default()
        })
    }

    #[test]
    fn decode_claims() {
        let payload =
            r#"{"sub":"alice","tenant_id":1,"roles":["readwrite"],"exp":9999999999,"user_id":42}"#;
        let claims: JwtClaims = serde_json::from_str(payload).unwrap();
        assert_eq!(claims.sub, "alice");
        assert_eq!(claims.tenant_id, 1);
        assert_eq!(claims.user_id, 42);
        assert_eq!(claims.roles, vec!["readwrite"]);
    }

    #[test]
    fn malformed_token_rejected() {
        let validator = JwtValidator::new(JwtConfig::default());
        assert_eq!(
            validator.validate("not-a-jwt").err(),
            Some(JwtError::MalformedToken)
        );
    }

    #[test]
    fn base64url_decode_works() {
        let original: &[u8] = b"hello world";
        let round_tripped = base64_url_decode(&base64_url_encode(original)).unwrap();
        assert_eq!(round_tripped, original);
    }

    #[test]
    fn rs256_roundtrip() {
        let private_key = fresh_rsa_key();
        let validator = rsa_only_validator(public_key_der(&private_key));
        let token = rs256_token(
            private_key,
            br#"{"sub":"bob","tenant_id":2,"roles":["admin"],"exp":9999999999,"user_id":99}"#,
        );

        let identity = validator.validate(&token).unwrap();
        assert_eq!(identity.username, "bob");
        assert_eq!(identity.tenant_id, TenantId::new(2));
        assert_eq!(identity.user_id, 99);
    }

    #[test]
    fn rs256_wrong_key_rejected() {
        // Sign with one key, verify against the public half of another.
        let signing_key = fresh_rsa_key();
        let unrelated_key = fresh_rsa_key();
        let validator = rsa_only_validator(public_key_der(&unrelated_key));
        let token = rs256_token(signing_key, br#"{"sub":"x","exp":9999999999}"#);

        assert_eq!(
            validator.validate(&token).err(),
            Some(JwtError::InvalidSignature)
        );
    }

    #[test]
    fn unsupported_algorithm_rejected() {
        let token = format!(
            "{}.{}.{}",
            base64_url_encode(br#"{"alg":"ES256"}"#),
            base64_url_encode(br#"{"sub":"x","exp":9999999999}"#),
            base64_url_encode(b"fakesig"),
        );
        let validator = JwtValidator::new(JwtConfig::default());
        assert_eq!(
            validator.validate(&token).err(),
            Some(JwtError::UnsupportedAlgorithm)
        );
    }
}