use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use zeroize::Zeroizing;
use crate::error::ApiError;
/// JWT claim set issued and accepted by this API.
///
/// `exp` and `iat` hold Unix timestamps in whole seconds (as produced by
/// [`Claims::for_user`]). Claims that do not match one of the named fields are
/// collected into `extra` via `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
    /// Subject: a user id, or an agent UUID rendered as a string.
    pub sub: String,
    /// Expiration time (Unix seconds).
    pub exp: i64,
    /// Issued-at time (Unix seconds).
    pub iat: i64,
    /// Issuer; tokens built by [`Claims::for_user`] use `"vex-api"`.
    pub iss: String,
    /// Role name consulted by [`Claims::has_role`].
    pub role: String,
    /// Optional tenant identifier; left out of the serialized token when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant_id: Option<String>,
    /// Any additional claims not captured by the fields above.
    #[serde(flatten)]
    pub extra: std::collections::HashMap<String, serde_json::Value>,
}
impl Claims {
pub fn for_user(user_id: &str, role: &str, expires_in: Duration) -> Self {
let now = Utc::now();
Self {
sub: user_id.to_string(),
exp: (now + expires_in).timestamp(),
iat: now.timestamp(),
iss: "vex-api".to_string(),
role: role.to_string(),
tenant_id: None,
extra: std::collections::HashMap::new(),
}
}
pub fn for_agent(agent_id: Uuid, expires_in: Duration) -> Self {
Self::for_user(&agent_id.to_string(), "agent", expires_in)
}
pub fn is_expired(&self) -> bool {
Utc::now().timestamp() > self.exp
}
pub fn has_role(&self, role: &str) -> bool {
self.role == role || self.role == "admin"
}
}
/// JWT signer/verifier built from a single shared HMAC secret.
///
/// Construct it with `JwtAuth::new` (no secret-length check) or
/// `JwtAuth::from_env` (reads `VEX_JWT_SECRET` and enforces a minimum length).
#[derive(Clone)]
pub struct JwtAuth {
    // Key used when signing tokens in `encode`.
    encoding_key: EncodingKey,
    // Key used when verifying signatures in `decode`.
    decoding_key: DecodingKey,
    // Rules applied on decode (algorithm, issuer, expiry) as configured in `new`.
    validation: Validation,
}
impl JwtAuth {
    /// Creates a signer/verifier from a shared HMAC secret.
    ///
    /// Decoding accepts only HS256 tokens whose `iss` is `"vex-api"` and
    /// enforces the `exp` claim. No minimum secret length is enforced here;
    /// use [`JwtAuth::from_env`] for the length-checked production path.
    pub fn new(secret: &str) -> Self {
        let encoding_key = EncodingKey::from_secret(secret.as_bytes());
        let decoding_key = DecodingKey::from_secret(secret.as_bytes());
        let mut validation = Validation::new(jsonwebtoken::Algorithm::HS256);
        validation.set_issuer(&["vex-api"]);
        validation.validate_exp = true;
        Self {
            encoding_key,
            decoding_key,
            validation,
        }
    }

    /// Builds a `JwtAuth` from the `VEX_JWT_SECRET` environment variable.
    ///
    /// The secret read from the environment is held in a `Zeroizing` buffer,
    /// so this function's copy is wiped when it is dropped.
    ///
    /// # Errors
    /// Returns `ApiError::Internal` if the variable is unset or not valid
    /// Unicode, or if the secret is shorter than 32 bytes.
    pub fn from_env() -> Result<Self, ApiError> {
        let secret: Zeroizing<String> =
            Zeroizing::new(std::env::var("VEX_JWT_SECRET").map_err(|_| {
                ApiError::Internal(
                    "VEX_JWT_SECRET environment variable is required. \
                     Generate with: openssl rand -base64 32"
                        .to_string(),
                )
            })?);
        // `len()` counts bytes, which is what matters for HMAC key strength.
        if secret.len() < 32 {
            return Err(ApiError::Internal(
                "VEX_JWT_SECRET must be at least 32 characters for security".to_string(),
            ));
        }
        Ok(Self::new(&secret))
    }

    /// Signs `claims` into a compact JWT string.
    ///
    /// The header algorithm is set explicitly to HS256 so it always matches
    /// the algorithm `decode` validates against.
    ///
    /// # Errors
    /// Returns `ApiError::Internal` if serialization or signing fails.
    pub fn encode(&self, claims: &Claims) -> Result<String, ApiError> {
        encode(
            &Header::new(jsonwebtoken::Algorithm::HS256),
            claims,
            &self.encoding_key,
        )
        .map_err(|e| ApiError::Internal(format!("JWT encoding error: {}", e)))
    }

    /// Verifies `token` (signature, issuer, expiry) and returns its claims.
    ///
    /// # Errors
    /// Always `ApiError::Unauthorized`: "Token expired" for an expired token,
    /// "Invalid token" for `ErrorKind::InvalidToken`, and otherwise a message
    /// that embeds the underlying jsonwebtoken error.
    pub fn decode(&self, token: &str) -> Result<Claims, ApiError> {
        decode::<Claims>(token, &self.decoding_key, &self.validation)
            .map(|data| data.claims)
            .map_err(|e| match e.kind() {
                jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
                    ApiError::Unauthorized("Token expired".to_string())
                }
                jsonwebtoken::errors::ErrorKind::InvalidToken => {
                    ApiError::Unauthorized("Invalid token".to_string())
                }
                // NOTE(review): this embeds the library's error text (e.g. claim
                // parse details); confirm it is not echoed verbatim to clients.
                _ => ApiError::Unauthorized(format!("Token validation failed: {}", e)),
            })
    }

    /// Extracts the token from an `Authorization: Bearer <token>` header value.
    ///
    /// The `Bearer` scheme name is matched case-insensitively (auth schemes
    /// are case-insensitive per RFC 7235 §2.1), and whitespace around the
    /// token is ignored.
    ///
    /// # Errors
    /// Returns `ApiError::Unauthorized` if the value does not use the Bearer
    /// scheme or carries an empty token.
    pub fn extract_from_header(header: &str) -> Result<&str, ApiError> {
        const SCHEME: &str = "Bearer ";
        header
            // `get` yields None (instead of panicking) when the header is too
            // short or the cut would split a multi-byte character; on success
            // the slice below is guaranteed to start on a char boundary.
            .get(..SCHEME.len())
            .filter(|prefix| prefix.eq_ignore_ascii_case(SCHEME))
            .map(|_| header[SCHEME.len()..].trim())
            .filter(|token| !token.is_empty())
            .ok_or_else(|| {
                ApiError::Unauthorized("Invalid Authorization header format".to_string())
            })
    }
}
/// An API key that has passed validation, as returned by `ApiKey::validate`.
#[derive(Debug, Clone)]
pub struct ApiKey {
    /// Id of the stored key record.
    pub key_id: uuid::Uuid,
    /// User id copied from the stored key record.
    pub user_id: String,
    /// Key name copied from the stored key record.
    pub name: String,
    /// Granted scopes; a `"*"` entry satisfies every `has_scope` check.
    pub scopes: Vec<String>,
    /// Limit derived from the key's tier scope by `ApiKey::validate`.
    /// NOTE(review): the unit/window is defined by whatever consumes this; confirm.
    pub rate_limit: Option<u32>,
}
impl ApiKey {
    /// Validates a raw API key against `store` and returns its metadata.
    ///
    /// The rate limit is derived from the key's tier scope: `"enterprise"` →
    /// 10000, `"pro"` → 1000, anything else → 100. If a key carries both tier
    /// scopes, enterprise wins.
    ///
    /// # Errors
    /// `ApiError::Unauthorized` if the key is unknown, expired, revoked or
    /// malformed; `ApiError::Internal` if the underlying store fails.
    pub async fn validate<S: vex_persist::ApiKeyStore>(
        key: &str,
        store: &S,
    ) -> Result<Self, ApiError> {
        let record = vex_persist::validate_api_key(store, key)
            .await
            .map_err(|e| match e {
                vex_persist::ApiKeyError::NotFound => {
                    ApiError::Unauthorized("Invalid API key".to_string())
                }
                vex_persist::ApiKeyError::Expired => {
                    ApiError::Unauthorized("API key expired".to_string())
                }
                vex_persist::ApiKeyError::Revoked => {
                    ApiError::Unauthorized("API key revoked".to_string())
                }
                vex_persist::ApiKeyError::InvalidFormat => {
                    ApiError::Unauthorized("Invalid API key format".to_string())
                }
                vex_persist::ApiKeyError::Storage(msg) => {
                    ApiError::Internal(format!("Key validation error: {}", msg))
                }
            })?;
        let rate_limit = Self::rate_limit_for_scopes(&record.scopes);
        Ok(ApiKey {
            key_id: record.id,
            user_id: record.user_id,
            name: record.name,
            scopes: record.scopes,
            rate_limit,
        })
    }

    /// Maps tier scopes to a per-key rate limit; always returns `Some`.
    fn rate_limit_for_scopes(scopes: &[String]) -> Option<u32> {
        // Compare as `&str` so a membership check never allocates a `String`.
        let has_tier = |tier: &str| scopes.iter().any(|s| s == tier);
        let limit = if has_tier("enterprise") {
            10_000
        } else if has_tier("pro") {
            1_000
        } else {
            100
        };
        Some(limit)
    }

    /// Returns `true` if the key was granted `scope` or the wildcard `"*"`.
    pub fn has_scope(&self, scope: &str) -> bool {
        self.scopes.iter().any(|s| s == scope || s == "*")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared 32-byte secret (the minimum length `JwtAuth::from_env` accepts).
    const TEST_SECRET: &str = "test-secret-key-32-bytes-long!!!";

    #[test]
    fn test_jwt_encode_decode() {
        let auth = JwtAuth::new(TEST_SECRET);
        let claims = Claims::for_user("user123", "user", Duration::hours(1));
        let token = auth.encode(&claims).unwrap();
        let decoded = auth.decode(&token).unwrap();
        assert_eq!(decoded.sub, "user123");
        assert_eq!(decoded.role, "user");
        assert!(!decoded.is_expired());
    }

    #[test]
    fn test_expired_token() {
        let auth = JwtAuth::new(TEST_SECRET);
        // Five minutes in the past.
        let claims = Claims::for_user("user123", "user", Duration::seconds(-300));
        let token = auth.encode(&claims).unwrap();
        let result = auth.decode(&token);
        assert!(
            matches!(result, Err(ApiError::Unauthorized(_))),
            "Expected Unauthorized error, got: {:?}",
            result
        );
    }

    #[test]
    fn test_wrong_secret_rejected() {
        let claims = Claims::for_user("user123", "user", Duration::hours(1));
        let token = JwtAuth::new(TEST_SECRET).encode(&claims).unwrap();
        let result = JwtAuth::new("a-different-secret-of-32-bytes!!").decode(&token);
        assert!(
            matches!(result, Err(ApiError::Unauthorized(_))),
            "Expected Unauthorized error, got: {:?}",
            result
        );
    }

    #[test]
    fn test_role_check() {
        let admin = Claims::for_user("user123", "admin", Duration::hours(1));
        assert!(admin.has_role("admin"));
        assert!(admin.has_role("user"));

        let user = Claims::for_user("user456", "user", Duration::hours(1));
        assert!(user.has_role("user"));
        assert!(!user.has_role("admin"));
    }

    #[test]
    fn test_extract_from_header() {
        assert_eq!(
            JwtAuth::extract_from_header("Bearer abc.def.ghi").unwrap(),
            "abc.def.ghi"
        );
        assert!(matches!(
            JwtAuth::extract_from_header("Basic dXNlcjpwYXNz"),
            Err(ApiError::Unauthorized(_))
        ));
    }

    #[test]
    fn test_has_scope() {
        let key = ApiKey {
            key_id: Uuid::nil(),
            user_id: "user123".to_string(),
            name: "test".to_string(),
            scopes: vec!["read".to_string()],
            rate_limit: None,
        };
        assert!(key.has_scope("read"));
        assert!(!key.has_scope("write"));

        let wildcard = ApiKey {
            scopes: vec!["*".to_string()],
            ..key
        };
        assert!(wildcard.has_scope("write"));
    }
}