use anyhow::Result;
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::sync::OnceLock;
/// Issuer (`iss`) used when the `JWT_ISSUER` environment variable provides no value.
const DEFAULT_JWT_ISSUER: &str = "mockforge-registry";
/// Lazily resolved issuer, cached for the lifetime of the process by `get_jwt_issuer`.
static JWT_ISSUER: OnceLock<String> = OnceLock::new();

/// Audience (`aud`) used when the `JWT_AUDIENCE` environment variable provides no value.
const DEFAULT_JWT_AUDIENCE: &str = "mockforge-api";
/// Lazily resolved audience, cached for the lifetime of the process by `get_jwt_audience`.
static JWT_AUDIENCE: OnceLock<String> = OnceLock::new();
/// Derives the `kid` (key id) header value for a signing secret.
///
/// The id is the hex encoding of the first four bytes of the secret's SHA-256 digest
/// (8 hex characters). It is deterministic, so every token signed with the same secret
/// carries the same `kid`.
fn derive_kid(secret: &str) -> String {
    let digest = Sha256::digest(secret.as_bytes());
    let prefix = &digest[..4];
    hex::encode(prefix)
}
/// Builds the JWS header for an HS256 token signed with `secret`.
///
/// The header carries a `kid` derived from the secret (see `derive_kid`).
fn build_header(secret: &str) -> Header {
    let key_id = derive_kid(secret);
    let mut jws_header = Header::new(Algorithm::HS256);
    jws_header.kid = Some(key_id);
    jws_header
}
/// Reads environment variable `var`, falling back to `default` when it provides no value.
///
/// The fallback applies when the variable is unset, not valid Unicode, or blank. Blank values
/// count as unset so that an empty `JWT_ISSUER=` / `JWT_AUDIENCE=` entry, e.g. produced by an
/// empty variable substitution in a deployment config, cannot silently mint and accept tokens
/// with an empty `iss`/`aud`.
fn env_or_default(var: &str, default: &str) -> String {
    std::env::var(var)
        .ok()
        .filter(|value| !value.trim().is_empty())
        .unwrap_or_else(|| default.to_string())
}

/// Returns the issuer (`iss`) used both when signing and when validating tokens.
///
/// Resolved from `JWT_ISSUER` (falling back to `DEFAULT_JWT_ISSUER`) on first call and cached
/// for the rest of the process, so later changes to the environment are not observed.
fn get_jwt_issuer() -> &'static str {
    JWT_ISSUER.get_or_init(|| env_or_default("JWT_ISSUER", DEFAULT_JWT_ISSUER))
}

/// Returns the audience (`aud`) used both when signing and when validating tokens.
///
/// Resolved from `JWT_AUDIENCE` (falling back to `DEFAULT_JWT_AUDIENCE`) on first call and
/// cached for the rest of the process, so later changes to the environment are not observed.
fn get_jwt_audience() -> &'static str {
    JWT_AUDIENCE.get_or_init(|| env_or_default("JWT_AUDIENCE", DEFAULT_JWT_AUDIENCE))
}
/// Kind of token carried in the `token_type` claim.
///
/// Serialized in snake_case, i.e. as `"access"` or `"refresh"`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TokenType {
    /// Token accepted by `verify_access_token`; issued for `ACCESS_TOKEN_EXPIRY_HOURS`.
    Access,
    /// Token accepted by `verify_refresh_token`, which also requires it to carry a `jti`;
    /// issued for `REFRESH_TOKEN_EXPIRY_DAYS`.
    Refresh,
}
/// How long a freshly issued access token stays valid, in hours.
pub const ACCESS_TOKEN_EXPIRY_HOURS: i64 = 24;

/// How long a freshly issued refresh token stays valid, in days.
pub const REFRESH_TOKEN_EXPIRY_DAYS: i64 = 7;
/// Claim set embedded in every token issued by this module.
///
/// Fields serialize in declaration order, so their order determines the payload bytes of
/// newly issued tokens.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    /// Subject: a user id, or `deployment:<uuid>` for deployment ingest tokens.
    pub sub: String,
    /// Expiry time as a Unix timestamp in seconds.
    pub exp: usize,
    /// Issued-at time as a Unix timestamp in seconds.
    pub iat: usize,
    /// Issuer; must match the configured issuer when the token is verified.
    pub iss: String,
    /// Audience; must match the configured audience when the token is verified.
    pub aud: String,
    /// Access vs. refresh. A payload without this field deserializes as
    /// [`TokenType::Access`] (presumably for tokens minted before the field existed).
    #[serde(default = "default_token_type")]
    pub token_type: TokenType,
    /// Unique token id. This module sets it only on refresh tokens; it is left out of the
    /// serialized payload when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jti: Option<String>,
}

/// Serde default for [`Claims::token_type`]: a payload that omits the field is an access token.
fn default_token_type() -> TokenType {
    TokenType::Access
}
/// Access and refresh tokens issued together, plus their expiry times.
///
/// Both `*_expires_at` fields are Unix timestamps in seconds.
#[derive(Debug, Serialize, Deserialize)]
pub struct TokenPair {
    /// Signed access token.
    pub access_token: String,
    /// Signed refresh token.
    pub refresh_token: String,
    /// When `access_token` expires (Unix seconds).
    pub access_token_expires_at: i64,
    /// When `refresh_token` expires (Unix seconds).
    pub refresh_token_expires_at: i64,
}
pub fn create_access_token(user_id: &str, secret: &str) -> Result<String> {
let now = Utc::now();
let expiration = now
.checked_add_signed(Duration::hours(ACCESS_TOKEN_EXPIRY_HOURS))
.ok_or_else(|| anyhow::anyhow!("Failed to calculate token expiration"))?
.timestamp();
let claims = Claims {
sub: user_id.to_string(),
exp: expiration as usize,
iat: now.timestamp() as usize,
iss: get_jwt_issuer().to_string(),
aud: get_jwt_audience().to_string(),
token_type: TokenType::Access,
jti: None,
};
let header = build_header(secret);
let token = encode(&header, &claims, &EncodingKey::from_secret(secret.as_bytes()))?;
Ok(token)
}
pub fn create_refresh_token(user_id: &str, secret: &str) -> Result<(String, String, i64)> {
let now = Utc::now();
let expiration = now
.checked_add_signed(Duration::days(REFRESH_TOKEN_EXPIRY_DAYS))
.ok_or_else(|| anyhow::anyhow!("Failed to calculate refresh token expiration"))?
.timestamp();
let jti = uuid::Uuid::new_v4().to_string();
let claims = Claims {
sub: user_id.to_string(),
exp: expiration as usize,
iat: now.timestamp() as usize,
iss: get_jwt_issuer().to_string(),
aud: get_jwt_audience().to_string(),
token_type: TokenType::Refresh,
jti: Some(jti.clone()),
};
let header = build_header(secret);
let token = encode(&header, &claims, &EncodingKey::from_secret(secret.as_bytes()))?;
Ok((token, jti, expiration))
}
pub fn create_token_pair(user_id: &str, secret: &str) -> Result<(TokenPair, String)> {
let access_token = create_access_token(user_id, secret)?;
let (refresh_token, jti, refresh_exp) = create_refresh_token(user_id, secret)?;
let now = Utc::now();
let access_exp = now
.checked_add_signed(Duration::hours(ACCESS_TOKEN_EXPIRY_HOURS))
.ok_or_else(|| anyhow::anyhow!("Failed to calculate access token expiration"))?
.timestamp();
Ok((
TokenPair {
access_token,
refresh_token,
access_token_expires_at: access_exp,
refresh_token_expires_at: refresh_exp,
},
jti,
))
}
pub fn create_token(user_id: &str, secret: &str) -> Result<String> {
create_access_token(user_id, secret)
}
pub fn create_deployment_ingest_token(
deployment_id: uuid::Uuid,
secret: &str,
ttl_days: i64,
) -> Result<String> {
let now = Utc::now();
let expiration = now
.checked_add_signed(Duration::days(ttl_days))
.ok_or_else(|| anyhow::anyhow!("Failed to calculate deployment token expiration"))?
.timestamp();
let claims = Claims {
sub: format!("deployment:{}", deployment_id),
exp: expiration as usize,
iat: now.timestamp() as usize,
iss: get_jwt_issuer().to_string(),
aud: get_jwt_audience().to_string(),
token_type: TokenType::Access,
jti: None,
};
let header = build_header(secret);
let token = encode(&header, &claims, &EncodingKey::from_secret(secret.as_bytes()))?;
Ok(token)
}
pub fn verify_deployment_ingest_token(token: &str, secret: &str) -> Result<uuid::Uuid> {
let claims = verify_token(token, secret)?;
let id_str = claims
.sub
.strip_prefix("deployment:")
.ok_or_else(|| anyhow::anyhow!("Token sub is not a deployment subject"))?;
uuid::Uuid::parse_str(id_str)
.map_err(|_| anyhow::anyhow!("Token sub deployment id is not a valid UUID"))
}
/// Verifies `token` against `secret`, falling back to the previous secret during key rotation.
///
/// If verification with `secret` fails and `JWT_SECRET_PREVIOUS` is set to a non-empty value,
/// the token is retried with that secret. A fallback success is logged at info level. When both
/// attempts fail, the error from the primary secret is returned.
///
/// Token type is not checked here; see `verify_access_token` / `verify_refresh_token`.
pub fn verify_token(token: &str, secret: &str) -> Result<Claims> {
    let primary_err = match verify_token_with_secret(token, secret) {
        Ok(claims) => return Ok(claims),
        Err(err) => err,
    };

    let rotated = std::env::var("JWT_SECRET_PREVIOUS")
        .ok()
        .filter(|previous| !previous.is_empty())
        .and_then(|previous| verify_token_with_secret(token, &previous).ok());

    match rotated {
        Some(claims) => {
            tracing::info!("Token verified with previous JWT secret (key rotation in progress)");
            Ok(claims)
        }
        None => Err(primary_err),
    }
}
fn verify_token_with_secret(token: &str, secret: &str) -> Result<Claims> {
let mut validation = Validation::default();
validation.set_audience(&[get_jwt_audience()]);
validation.set_issuer(&[get_jwt_issuer()]);
let token_data =
decode::<Claims>(token, &DecodingKey::from_secret(secret.as_bytes()), &validation)?;
Ok(token_data.claims)
}
/// Verifies a refresh token and returns its claims together with its `jti`.
///
/// # Errors
/// Fails if [`verify_token`] fails, the token is not a [`TokenType::Refresh`] token, or it
/// carries no `jti`.
pub fn verify_refresh_token(token: &str, secret: &str) -> Result<(Claims, String)> {
    let claims = verify_token(token, secret)?;
    anyhow::ensure!(
        matches!(claims.token_type, TokenType::Refresh),
        "Expected refresh token, got access token"
    );
    match claims.jti.clone() {
        Some(jti) => Ok((claims, jti)),
        None => Err(anyhow::anyhow!("Refresh token missing JTI")),
    }
}
/// Verifies an access token and returns its claims.
///
/// # Errors
/// Fails if [`verify_token`] fails or the token is a [`TokenType::Refresh`] token.
pub fn verify_access_token(token: &str, secret: &str) -> Result<Claims> {
    verify_token(token, secret).and_then(|claims| match claims.token_type {
        TokenType::Access => Ok(claims),
        TokenType::Refresh => Err(anyhow::anyhow!("Expected access token, got refresh token")),
    })
}
pub fn hash_password(password: &str) -> Result<String> {
let hash = bcrypt::hash(password, bcrypt::DEFAULT_COST)?;
Ok(hash)
}
/// Checks `password` against a bcrypt `hash`.
///
/// Returns `Ok(false)` on a mismatch. Returns an error if `hash` is not a valid bcrypt hash.
pub fn verify_password(password: &str, hash: &str) -> Result<bool> {
    Ok(bcrypt::verify(password, hash)?)
}
// Unit tests for JWT issuance/verification and bcrypt password hashing.
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed HS256 signing secret shared by the tests below.
    const TEST_SECRET: &str = "test-secret-key-for-jwt-signing-minimum-32-chars";

    // An issued token is a compact JWS: three dot-separated segments.
    #[test]
    fn test_create_token() {
        let user_id = "user-123";
        let token = create_token(user_id, TEST_SECRET).unwrap();
        assert!(!token.is_empty());
        let parts: Vec<&str> = token.split('.').collect();
        assert_eq!(parts.len(), 3);
    }

    // Round trip: the subject survives, and the lifetime is ACCESS_TOKEN_EXPIRY_HOURS
    // within a ±60 s tolerance.
    #[test]
    fn test_verify_token_valid() {
        let user_id = "user-456";
        let token = create_token(user_id, TEST_SECRET).unwrap();
        let claims = verify_token(&token, TEST_SECRET).unwrap();
        assert_eq!(claims.sub, user_id);
        assert!(claims.exp > claims.iat);
        let duration = claims.exp - claims.iat;
        let expected_secs = ACCESS_TOKEN_EXPIRY_HOURS as usize * 3600;
        assert!(duration >= expected_secs - 60, "Token should be valid for at least 23h59m");
        assert!(duration <= expected_secs + 60, "Token should be valid for at most 24h1m");
        assert_eq!(claims.token_type, TokenType::Access);
    }

    // Access tokens pass the access-specific check and carry no jti.
    #[test]
    fn test_access_token() {
        let user_id = "user-access";
        let token = create_access_token(user_id, TEST_SECRET).unwrap();
        let claims = verify_access_token(&token, TEST_SECRET).unwrap();
        assert_eq!(claims.sub, user_id);
        assert_eq!(claims.token_type, TokenType::Access);
        assert!(claims.jti.is_none());
    }

    // Refresh tokens round-trip their jti, with a lifetime of about REFRESH_TOKEN_EXPIRY_DAYS
    // (checked loosely as 6..=8 days).
    #[test]
    fn test_refresh_token() {
        let user_id = "user-refresh";
        let (token, jti, _expires) = create_refresh_token(user_id, TEST_SECRET).unwrap();
        let (claims, verified_jti) = verify_refresh_token(&token, TEST_SECRET).unwrap();
        assert_eq!(claims.sub, user_id);
        assert_eq!(claims.token_type, TokenType::Refresh);
        assert_eq!(verified_jti, jti);
        let duration = claims.exp - claims.iat;
        assert!(
            duration >= 6 * 24 * 60 * 60,
            "Refresh token should be valid for at least 6 days"
        );
        assert!(duration <= 8 * 24 * 60 * 60, "Refresh token should be valid for at most 8 days");
    }

    // Both halves of a pair verify with their respective checks, and the returned jti
    // matches the refresh token's.
    #[test]
    fn test_token_pair() {
        let user_id = "user-pair";
        let (pair, jti) = create_token_pair(user_id, TEST_SECRET).unwrap();
        let access_claims = verify_access_token(&pair.access_token, TEST_SECRET).unwrap();
        assert_eq!(access_claims.sub, user_id);
        assert_eq!(access_claims.token_type, TokenType::Access);
        let (refresh_claims, verified_jti) =
            verify_refresh_token(&pair.refresh_token, TEST_SECRET).unwrap();
        assert_eq!(refresh_claims.sub, user_id);
        assert_eq!(refresh_claims.token_type, TokenType::Refresh);
        assert_eq!(verified_jti, jti);
        assert!(pair.access_token_expires_at < pair.refresh_token_expires_at);
    }

    // Token-type confusion: a refresh token must not be accepted as an access token.
    #[test]
    fn test_refresh_token_rejected_as_access() {
        let user_id = "user-reject";
        let (token, _, _) = create_refresh_token(user_id, TEST_SECRET).unwrap();
        let result = verify_access_token(&token, TEST_SECRET);
        assert!(result.is_err());
    }

    // Token-type confusion: an access token must not be accepted as a refresh token.
    #[test]
    fn test_access_token_rejected_as_refresh() {
        let user_id = "user-reject2";
        let token = create_access_token(user_id, TEST_SECRET).unwrap();
        let result = verify_refresh_token(&token, TEST_SECRET);
        assert!(result.is_err());
    }

    // A token signed with TEST_SECRET fails under a different secret.
    // NOTE(review): verify_token falls back to JWT_SECRET_PREVIOUS, so this assumes that
    // variable is not set to TEST_SECRET in the test environment.
    #[test]
    fn test_verify_token_wrong_secret() {
        let user_id = "user-789";
        let token = create_token(user_id, TEST_SECRET).unwrap();
        let wrong_secret = "wrong-secret-key-for-jwt-signing";
        let result = verify_token(&token, wrong_secret);
        assert!(result.is_err());
    }

    // Three segments that are not valid encoded JWT parts.
    #[test]
    fn test_verify_token_invalid_format() {
        let invalid_token = "invalid.token.format";
        let result = verify_token(invalid_token, TEST_SECRET);
        assert!(result.is_err());
    }

    // No segment separators at all.
    #[test]
    fn test_verify_token_malformed() {
        let malformed_token = "not-a-jwt-token";
        let result = verify_token(malformed_token, TEST_SECRET);
        assert!(result.is_err());
    }

    // Hashes use the bcrypt "$2..." modular-crypt prefix and never equal the plaintext.
    #[test]
    fn test_hash_password() {
        let password = "my-secure-password";
        let hash = hash_password(password).unwrap();
        assert!(!hash.is_empty());
        assert!(hash.starts_with("$2"));
        assert_ne!(hash, password);
    }

    // Hashing the same password twice yields different strings (random salt).
    #[test]
    fn test_hash_password_different_hashes() {
        let password = "same-password";
        let hash1 = hash_password(password).unwrap();
        let hash2 = hash_password(password).unwrap();
        assert_ne!(hash1, hash2);
    }

    #[test]
    fn test_verify_password_valid() {
        let password = "correct-password";
        let hash = hash_password(password).unwrap();
        let valid = verify_password(password, &hash).unwrap();
        assert!(valid);
    }

    // A mismatch is Ok(false), not an error.
    #[test]
    fn test_verify_password_invalid() {
        let password = "correct-password";
        let hash = hash_password(password).unwrap();
        let valid = verify_password("wrong-password", &hash).unwrap();
        assert!(!valid);
    }

    #[test]
    fn test_verify_password_empty() {
        let password = "password";
        let hash = hash_password(password).unwrap();
        let valid = verify_password("", &hash).unwrap();
        assert!(!valid);
    }

    // A malformed stored hash surfaces as an error rather than Ok(false).
    #[test]
    fn test_verify_password_invalid_hash() {
        let password = "password";
        let invalid_hash = "not-a-valid-bcrypt-hash";
        let result = verify_password(password, invalid_hash);
        assert!(result.is_err());
    }

    // Empty passwords are hashed, not rejected.
    #[test]
    fn test_hash_password_empty() {
        let password = "";
        let hash = hash_password(password).unwrap();
        assert!(!hash.is_empty());
        assert!(hash.starts_with("$2"));
    }

    #[test]
    fn test_hash_password_special_chars() {
        let password = "p@ssw0rd!#$%^&*()";
        let hash = hash_password(password).unwrap();
        assert!(!hash.is_empty());
        let valid = verify_password(password, &hash).unwrap();
        assert!(valid);
    }

    // Multi-byte UTF-8 input round-trips.
    #[test]
    fn test_hash_password_unicode() {
        let password = "пароль密码🔒";
        let hash = hash_password(password).unwrap();
        assert!(!hash.is_empty());
        let valid = verify_password(password, &hash).unwrap();
        assert!(valid);
    }

    // JSON round trip of Claims; token_type serializes in snake_case ("access"), and a
    // `None` jti is omitted.
    #[test]
    fn test_claims_serialization() {
        let claims = Claims {
            sub: "user-123".to_string(),
            exp: 1234567890,
            iat: 1234567800,
            iss: "mockforge-registry".to_string(),
            aud: "mockforge-api".to_string(),
            token_type: TokenType::Access,
            jti: None,
        };
        let json = serde_json::to_string(&claims).unwrap();
        assert!(json.contains("user-123"));
        assert!(json.contains("1234567890"));
        assert!(json.contains("access"));
        assert!(json.contains("mockforge-registry"));
        assert!(json.contains("mockforge-api"));
        let deserialized: Claims = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.sub, claims.sub);
        assert_eq!(deserialized.exp, claims.exp);
        assert_eq!(deserialized.iat, claims.iat);
        assert_eq!(deserialized.iss, claims.iss);
        assert_eq!(deserialized.aud, claims.aud);
        assert_eq!(deserialized.token_type, TokenType::Access);
    }

    #[test]
    fn test_token_contains_user_id() {
        let user_id = "unique-user-id-12345";
        let token = create_token(user_id, TEST_SECRET).unwrap();
        let claims = verify_token(&token, TEST_SECRET).unwrap();
        assert_eq!(claims.sub, user_id);
    }

    // Access tokens carry no jti and timestamps have whole-second resolution, so two tokens
    // minted for the same user within one second would be identical. The >1 s sleep forces
    // distinct iat/exp values. NOTE(review): this makes the test take over a second.
    #[test]
    fn test_multiple_tokens_same_user() {
        let user_id = "user-123";
        let token1 = create_token(user_id, TEST_SECRET).unwrap();
        std::thread::sleep(std::time::Duration::from_millis(1100));
        let token2 = create_token(user_id, TEST_SECRET).unwrap();
        assert_ne!(token1, token2);
        let claims1 = verify_token(&token1, TEST_SECRET).unwrap();
        let claims2 = verify_token(&token2, TEST_SECRET).unwrap();
        assert_eq!(claims1.sub, user_id);
        assert_eq!(claims2.sub, user_id);
    }

    #[test]
    fn test_token_includes_issuer_and_audience() {
        let user_id = "user-iss-aud";
        let token = create_access_token(user_id, TEST_SECRET).unwrap();
        let claims = verify_token(&token, TEST_SECRET).unwrap();
        assert!(!claims.iss.is_empty());
        assert!(!claims.aud.is_empty());
    }

    #[test]
    fn test_refresh_token_includes_issuer_and_audience() {
        let user_id = "user-refresh-iss";
        let (token, _, _) = create_refresh_token(user_id, TEST_SECRET).unwrap();
        let (claims, _) = verify_refresh_token(&token, TEST_SECRET).unwrap();
        assert!(!claims.iss.is_empty());
        assert!(!claims.aud.is_empty());
    }

    // The JWS header's kid is the value derived from the signing secret.
    #[test]
    fn test_token_includes_kid_header() {
        let user_id = "user-kid";
        let token = create_access_token(user_id, TEST_SECRET).unwrap();
        let header = jsonwebtoken::decode_header(&token).unwrap();
        assert!(header.kid.is_some(), "Token should include kid header");
        assert_eq!(header.kid.unwrap(), derive_kid(TEST_SECRET));
    }

    // A token signed with the old secret is rejected under the new one and accepted under
    // the old one.
    // NOTE(review): this calls verify_token_with_secret directly; the JWT_SECRET_PREVIOUS
    // fallback inside verify_token is not exercised here.
    #[test]
    fn test_key_rotation_with_previous_secret() {
        let old_secret = "old-secret-key-for-jwt-minimum-32-characters";
        let new_secret = "new-secret-key-for-jwt-minimum-32-characters";
        let user_id = "user-rotation";
        let token = create_access_token(user_id, old_secret).unwrap();
        assert!(verify_token_with_secret(&token, new_secret).is_err());
        let claims = verify_token_with_secret(&token, old_secret).unwrap();
        assert_eq!(claims.sub, user_id);
    }

    #[test]
    fn test_derive_kid_deterministic() {
        let kid1 = derive_kid(TEST_SECRET);
        let kid2 = derive_kid(TEST_SECRET);
        assert_eq!(kid1, kid2, "derive_kid should be deterministic");
    }

    #[test]
    fn test_derive_kid_different_for_different_secrets() {
        let kid1 = derive_kid("secret-one-minimum-32-characters-long");
        let kid2 = derive_kid("secret-two-minimum-32-characters-long");
        assert_ne!(kid1, kid2, "Different secrets should produce different kids");
    }
}