use crate::claims::NucleusClaims;
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use serde::Deserialize;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Errors produced while fetching signing keys or verifying Nucleus tokens.
#[derive(Debug, thiserror::Error)]
pub enum NucleusError {
    /// The JWKS response had an error status or a body that could not be
    /// parsed; the payload is the underlying error's text.
    #[error("nucleus: failed to fetch JWKS: {0}")]
    JwksFetch(String),

    /// A freshly fetched JWKS contained no usable RSA key for this `kid`.
    #[error("nucleus: no matching key found in JWKS for kid `{0}`")]
    KeyNotFound(String),

    /// The token could not be decoded, lacked a `kid`, or failed validation.
    #[error("nucleus: invalid token: {0}")]
    InvalidToken(String),

    /// Not constructed in this module; presumably raised by API client code
    /// elsewhere in the crate (NOTE(review): confirm `status` is an HTTP code).
    #[error("nucleus: API error {status}: {message}")]
    Api {
        /// Status reported for the failed call.
        status: u16,
        /// Error text reported for the failed call.
        message: String,
    },

    /// Transport-level `reqwest` failure, converted automatically via `From`
    /// (in this module: sending the JWKS request failed).
    #[error("nucleus: HTTP request failed: {0}")]
    Http(#[from] reqwest::Error),
}
/// Body of the JWKS endpoint: a JSON object holding a `keys` array.
#[derive(Debug, Deserialize)]
struct JwksResponse {
    /// Every entry of the published key set, in document order.
    keys: Vec<Jwk>,
}
/// One JSON Web Key from the JWKS `keys` array.
///
/// Only the members needed to build an RSA verification key are read; serde
/// ignores any other members. `kty` is required, so an entry without it makes
/// deserialization of the whole response fail.
#[derive(Debug, Clone, Deserialize)]
struct Jwk {
    /// Key ID, matched against the token header's `kid`.
    kid: Option<String>,
    /// Key type; only `"RSA"` entries are turned into decoding keys.
    kty: String,
    /// RSA modulus, base64url-encoded (RFC 7518 §6.3.1).
    n: Option<String>,
    /// RSA public exponent, base64url-encoded (RFC 7518 §6.3.1).
    e: Option<String>,
}
/// Decoding keys from one JWKS fetch, plus when they were stored.
struct CachedKeys {
    /// Verification keys indexed by their JWK `kid`.
    keys: HashMap<String, DecodingKey>,
    /// Moment the key set was stored; it counts as fresh while
    /// `fetched_at.elapsed()` is below the verifier's TTL.
    fetched_at: Instant,
}
/// Verifies RS256-signed JWTs using public keys from a JWKS endpoint.
///
/// The key set is fetched lazily and cached for `ttl`. A lookup that misses
/// (unknown `kid`) or finds the cache expired triggers a refetch.
pub struct JwksVerifier {
    /// Full URL of the JWKS document (`{base_url}/.well-known/jwks.json`).
    jwks_url: String,
    /// How long a stored key set is considered fresh.
    ttl: Duration,
    /// Most recently stored key set; `None` until one has been stored.
    cache: RwLock<Option<CachedKeys>>,
    /// HTTP client used for JWKS requests.
    http: reqwest::Client,
}
impl JwksVerifier {
pub fn new(base_url: &str, ttl_secs: u64) -> Self {
let jwks_url = format!("{}/.well-known/jwks.json", base_url.trim_end_matches('/'));
Self {
jwks_url,
ttl: Duration::from_secs(ttl_secs),
cache: RwLock::new(None),
http: reqwest::Client::new(),
}
}
pub async fn verify(&self, token: &str) -> Result<NucleusClaims, NucleusError> {
let header = decode_header(token)
.map_err(|e| NucleusError::InvalidToken(e.to_string()))?;
let kid = header
.kid
.ok_or_else(|| NucleusError::InvalidToken("token header missing `kid`".into()))?;
let key = self.get_key(&kid).await?;
let mut validation = Validation::new(Algorithm::RS256);
validation.validate_aud = false;
let data = decode::<NucleusClaims>(token, &key, &validation)
.map_err(|e| NucleusError::InvalidToken(e.to_string()))?;
Ok(data.claims)
}
async fn get_key(&self, kid: &str) -> Result<DecodingKey, NucleusError> {
{
let guard = self.cache.read().await;
if let Some(cached) = guard.as_ref() {
if cached.fetched_at.elapsed() < self.ttl {
if let Some(key) = cached.keys.get(kid) {
return Ok(key.clone());
}
}
}
}
self.refresh_and_get(kid).await
}
async fn refresh_and_get(&self, kid: &str) -> Result<DecodingKey, NucleusError> {
let mut guard = self.cache.write().await;
if let Some(cached) = guard.as_ref() {
if cached.fetched_at.elapsed() < self.ttl {
if let Some(key) = cached.keys.get(kid) {
return Ok(key.clone());
}
}
}
let resp: JwksResponse = self
.http
.get(&self.jwks_url)
.send()
.await?
.error_for_status()
.map_err(|e| NucleusError::JwksFetch(e.to_string()))?
.json()
.await
.map_err(|e| NucleusError::JwksFetch(e.to_string()))?;
let mut keys = HashMap::new();
for jwk in &resp.keys {
if jwk.kty != "RSA" {
continue;
}
if let (Some(n), Some(e), Some(k)) = (&jwk.n, &jwk.e, &jwk.kid) {
if let Ok(dk) = DecodingKey::from_rsa_components(n, e) {
keys.insert(k.clone(), dk);
}
}
}
let key = keys
.get(kid)
.ok_or_else(|| NucleusError::KeyNotFound(kid.to_string()))?
.clone();
*guard = Some(CachedKeys {
keys,
fetched_at: Instant::now(),
});
Ok(key)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::claims::NucleusClaims;

    /// Token whose header is base64url(`{"alg":"RS256","typ":"JWT"}`): well-formed,
    /// but with no `kid`, so verification must fail before any network access.
    const TOKEN_WITHOUT_KID: &str = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.c2ln";

    #[test]
    fn verifier_creation_stores_jwks_url() {
        let verifier = JwksVerifier::new("https://api.test.com", 3600);
        assert_eq!(verifier.jwks_url, "https://api.test.com/.well-known/jwks.json");
    }

    #[test]
    fn verifier_trims_trailing_slash() {
        let verifier = JwksVerifier::new("https://api.test.com/", 3600);
        assert_eq!(verifier.jwks_url, "https://api.test.com/.well-known/jwks.json");
    }

    #[test]
    fn verifier_stores_ttl() {
        let verifier = JwksVerifier::new("https://api.test.com", 7200);
        assert_eq!(verifier.ttl, Duration::from_secs(7200));
    }

    #[test]
    fn verifier_starts_with_empty_cache() {
        // Construction must not fetch keys eagerly.
        let verifier = JwksVerifier::new("https://api.test.com", 3600);
        let cache = verifier
            .cache
            .try_read()
            .expect("no other task holds the lock");
        assert!(cache.is_none());
    }

    #[tokio::test]
    async fn verify_invalid_token_returns_error() {
        let verifier = JwksVerifier::new("https://api.test.com", 3600);
        let result = verifier.verify("not.a.valid.token").await;
        assert!(
            matches!(result, Err(NucleusError::InvalidToken(_))),
            "expected InvalidToken error, got: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn verify_empty_token_returns_error() {
        let verifier = JwksVerifier::new("https://api.test.com", 3600);
        let result = verifier.verify("").await;
        assert!(
            matches!(result, Err(NucleusError::InvalidToken(_))),
            "expected InvalidToken error, got: {:?}",
            result
        );
    }

    #[tokio::test]
    async fn verify_token_without_kid_returns_error() {
        let verifier = JwksVerifier::new("https://api.test.com", 3600);
        match verifier.verify(TOKEN_WITHOUT_KID).await {
            Err(NucleusError::InvalidToken(msg)) => {
                assert!(msg.contains("kid"), "unexpected message: {}", msg)
            }
            other => panic!("expected InvalidToken error, got: {:?}", other),
        }
    }

    #[test]
    fn claims_serde_roundtrip() {
        let json = serde_json::json!({
            "sub": "user_123",
            "iss": "https://api.test.com",
            "aud": "project_456",
            "exp": 1700000000u64,
            "iat": 1699996400u64,
            "jti": "jwt_abc",
            "email": "test@example.com",
            "first_name": "Test",
            "last_name": "User",
            "org_id": "org_1",
            "org_role": "admin",
            "org_permissions": ["read", "write"]
        });
        let claims: NucleusClaims = serde_json::from_value(json).unwrap();
        assert_eq!(claims.user_id(), "user_123");
        assert_eq!(claims.sub, "user_123");
        assert_eq!(claims.aud, "project_456");
        assert_eq!(claims.email, Some("test@example.com".to_string()));
        assert_eq!(claims.first_name, Some("Test".to_string()));
        assert_eq!(claims.org_id, Some("org_1".to_string()));
        assert_eq!(claims.org_role, Some("admin".to_string()));
        assert_eq!(
            claims.org_permissions,
            Some(vec!["read".to_string(), "write".to_string()])
        );
        let serialized = serde_json::to_value(&claims).unwrap();
        assert_eq!(serialized["sub"], "user_123");
    }

    #[test]
    fn claims_missing_optional_fields() {
        let json = serde_json::json!({
            "sub": "user_1",
            "iss": "https://test.com",
            "aud": "proj_1",
            "exp": 1700000000u64,
            "iat": 1699996400u64
        });
        let claims: NucleusClaims = serde_json::from_value(json).unwrap();
        assert_eq!(claims.user_id(), "user_1");
        assert!(claims.email.is_none());
        assert!(claims.org_id.is_none());
        assert!(claims.org_permissions.is_none());
        assert!(claims.metadata.is_none());
    }

    #[test]
    fn nucleus_error_display() {
        assert_eq!(
            NucleusError::InvalidToken("test error".into()).to_string(),
            "nucleus: invalid token: test error"
        );
        assert_eq!(
            NucleusError::KeyNotFound("kid-123".into()).to_string(),
            "nucleus: no matching key found in JWKS for kid `kid-123`"
        );
        assert_eq!(
            NucleusError::JwksFetch("timeout".into()).to_string(),
            "nucleus: failed to fetch JWKS: timeout"
        );
        assert_eq!(
            NucleusError::Api {
                status: 401,
                message: "unauthorized".into(),
            }
            .to_string(),
            "nucleus: API error 401: unauthorized"
        );
    }
}