use anyhow::{anyhow, Result};
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use std::sync::Arc;
use thiserror::Error;
/// Errors produced while extracting a JWT from a request header and validating it.
#[derive(Debug, Error)]
pub enum AuthError {
/// Neither a `Proxy-Authorization` nor an `Authorization` header was present.
#[error("Missing authorization header")]
MissingHeader,
/// The header value is not `Bearer <token>` or a decodable `Basic` credential,
/// or it is not a valid visible-ASCII header string.
#[error("Invalid authorization format (expected: Bearer <token>)")]
InvalidFormat,
/// Signature, issuer, audience, or another decode/claim check failed; the
/// payload describes which.
#[error("Token validation failed: {0}")]
ValidationFailed(String),
/// The token's `exp` claim has passed.
#[error("Token expired")]
TokenExpired,
/// The token's `allowed_regions` does not cover the validator's region; the
/// payload is that region.
#[error("Region not allowed: {0}")]
RegionNotAllowed(String),
}
/// Claims carried by an access token.
///
/// `E` holds additional application-specific claims. Because the `extra` field is
/// `#[serde(flatten)]`, those claims live in the same top-level JSON object as the
/// standard fields below. The default `E = ()` means "no extra claims".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims<E = ()>
where
E: Default,
{
/// Identifier of this particular token.
pub token_id: String,
/// Identifier of the user the token was issued to.
pub user_id: i32,
/// Regions the token may be used in. `JwtValidator::validate` treats an empty
/// list, or a list containing `"*"`, as unrestricted.
pub allowed_regions: Vec<String>,
/// Expiry timestamp (`exp`), checked by `JwtValidator::validate`. The tests in
/// this file produce it with `chrono::Utc::now().timestamp()` (Unix seconds).
pub exp: i64,
/// Issued-at timestamp (`iat`); not referenced by `JwtValidator`.
pub iat: i64,
/// Issuer (`iss`). `#[serde(default)]` makes a missing claim deserialize as `None`.
#[serde(default)]
pub iss: Option<String>,
/// Audience (`aud`). `#[serde(default)]` makes a missing claim deserialize as `None`.
#[serde(default)]
pub aud: Option<String>,
/// Application-specific extension claims, flattened into the top-level object.
#[serde(flatten, default)]
pub extra: E,
}
impl JwtClaims<()> {
    /// Builds a claims set without extension claims (`extra` is `()`).
    ///
    /// This is shorthand for [`JwtClaims::with_extra`] with `()` as the extra
    /// payload. All values are stored as given; nothing is validated here.
    pub fn new(
        token_id: String,
        user_id: i32,
        allowed_regions: Vec<String>,
        exp: i64,
        iat: i64,
        iss: Option<String>,
        aud: Option<String>,
    ) -> Self {
        Self::with_extra(token_id, user_id, allowed_regions, exp, iat, iss, aud, ())
    }
}
impl<E: Default> JwtClaims<E> {
    /// Builds a claims set that carries the extension claims `extra` alongside
    /// the standard fields.
    ///
    /// All values are stored as given; nothing is validated here.
    // Each positional argument maps one-to-one onto a struct field, so the long
    // parameter list is intentional rather than a design smell.
    #[allow(clippy::too_many_arguments)]
    pub fn with_extra(
        token_id: String,
        user_id: i32,
        allowed_regions: Vec<String>,
        exp: i64,
        iat: i64,
        iss: Option<String>,
        aud: Option<String>,
        extra: E,
    ) -> Self {
        Self {
            extra,
            aud,
            iss,
            iat,
            exp,
            allowed_regions,
            user_id,
            token_id,
        }
    }
}
/// Validates HMAC-signed JWTs taken from `Proxy-Authorization` / `Authorization`
/// headers, enforcing optional issuer/audience checks and a per-token region list.
///
/// `E` is the type of the extension claims decoded into [`JwtClaims::extra`].
pub struct JwtValidator<E = ()> {
    /// Shared HMAC key. Deliberately omitted from the `Debug` output.
    secret: String,
    /// Algorithm tokens must be signed with.
    algorithm: Algorithm,
    /// Region served by this instance, compared against `allowed_regions`.
    current_region: String,
    /// Required `iss` value, when issuer checking is enabled.
    expected_issuer: Option<String>,
    /// Required `aud` value, when audience checking is enabled.
    expected_audience: Option<String>,
    /// Ties the validator to its extension-claims type without storing one.
    _phantom: PhantomData<E>,
}

// Written by hand instead of derived: a derived `Debug` would print the HMAC
// secret verbatim into any log line or panic message that formats the validator.
impl<E> std::fmt::Debug for JwtValidator<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("JwtValidator")
            .field("secret", &"<redacted>")
            .field("algorithm", &self.algorithm)
            .field("current_region", &self.current_region)
            .field("expected_issuer", &self.expected_issuer)
            .field("expected_audience", &self.expected_audience)
            .finish()
    }
}
impl<E> JwtValidator<E>
where
E: Default + DeserializeOwned + Clone + std::fmt::Debug,
{
pub fn new(
secret: String,
algorithm: String,
current_region: String,
expected_issuer: Option<String>,
expected_audience: Option<String>,
) -> Result<Self> {
let algo = match algorithm.to_uppercase().as_str() {
"HS256" => Algorithm::HS256,
"HS384" => Algorithm::HS384,
"HS512" => Algorithm::HS512,
_ => return Err(anyhow!("Unsupported JWT algorithm: {}", algorithm)),
};
Ok(Self {
secret,
algorithm: algo,
current_region,
expected_issuer,
expected_audience,
_phantom: PhantomData,
})
}
pub fn expected_issuer(&self) -> Option<&str> {
self.expected_issuer.as_deref()
}
pub fn expected_audience(&self) -> Option<&str> {
self.expected_audience.as_deref()
}
pub fn has_strict_validation(&self) -> bool {
self.expected_issuer.is_some() && self.expected_audience.is_some()
}
fn extract_token(auth_header: &str) -> Result<String, AuthError> {
let parts: Vec<&str> = auth_header.splitn(2, ' ').collect();
if parts.len() != 2 {
return Err(AuthError::InvalidFormat);
}
let auth_type = parts[0].to_lowercase();
let credentials = parts[1].trim();
if credentials.is_empty() {
return Err(AuthError::InvalidFormat);
}
if auth_type == "bearer" {
return Ok(credentials.to_string());
}
if auth_type == "basic" {
use base64::{engine::general_purpose, Engine as _};
let decoded = general_purpose::STANDARD
.decode(credentials)
.map_err(|_| AuthError::InvalidFormat)?;
let decoded_str = String::from_utf8(decoded).map_err(|_| AuthError::InvalidFormat)?;
let user_pass: Vec<&str> = decoded_str.splitn(2, ':').collect();
if user_pass.is_empty() {
return Err(AuthError::InvalidFormat);
}
let username = user_pass[0];
if let Some(token) = username.strip_prefix("Bearer ") {
let token = token.trim();
if !token.is_empty() {
return Ok(token.to_string());
}
}
if !username.is_empty() {
return Ok(username.to_string());
}
return Err(AuthError::InvalidFormat);
}
Err(AuthError::InvalidFormat)
}
pub fn validate(&self, auth_header: &str) -> Result<JwtClaims<E>, AuthError> {
let token = Self::extract_token(auth_header)?;
let mut validation = Validation::new(self.algorithm);
validation.validate_exp = true;
validation.validate_nbf = false;
if let Some(ref iss) = self.expected_issuer {
validation.set_issuer(&[iss]);
}
if let Some(ref aud) = self.expected_audience {
validation.set_audience(&[aud]);
}
let decoding_key = DecodingKey::from_secret(self.secret.as_bytes());
let token_data =
decode::<JwtClaims<E>>(&token, &decoding_key, &validation).map_err(|e| {
use jsonwebtoken::errors::ErrorKind;
match e.kind() {
ErrorKind::ExpiredSignature => AuthError::TokenExpired,
ErrorKind::InvalidIssuer => {
AuthError::ValidationFailed("Invalid issuer".to_string())
}
ErrorKind::InvalidAudience => {
AuthError::ValidationFailed("Invalid audience".to_string())
}
ErrorKind::InvalidSignature => {
AuthError::ValidationFailed("Invalid signature".to_string())
}
_ => AuthError::ValidationFailed(e.to_string()),
}
})?;
let claims = token_data.claims;
if !claims.allowed_regions.is_empty()
&& !claims.allowed_regions.contains(&"*".to_string())
&& !claims.allowed_regions.contains(&self.current_region)
{
return Err(AuthError::RegionNotAllowed(self.current_region.clone()));
}
Ok(claims)
}
pub fn validate_request<T>(&self, request: &http::Request<T>) -> Result<JwtClaims<E>, AuthError> {
let auth_header = request
.headers()
.get("proxy-authorization")
.or_else(|| request.headers().get("authorization"))
.ok_or(AuthError::MissingHeader)?;
let auth_str = auth_header.to_str().map_err(|_| AuthError::InvalidFormat)?;
self.validate(auth_str)
}
}
/// Reference-counted [`JwtValidator`], for sharing one validator through an `Arc`.
pub type SharedJwtValidator<E = ()> = Arc<JwtValidator<E>>;
#[cfg(test)]
mod tests {
    use super::*;
    use jsonwebtoken::{encode, EncodingKey, Header};

    /// HMAC key shared by the validators and the token signer in these tests.
    const SECRET: &str = "test_secret_key_probeops_2025";
    /// Issuer enforced by "strict" validators.
    const ISSUER: &str = "probeops";
    /// Audience enforced by "strict" validators.
    const AUDIENCE: &str = "forward-proxy";

    /// Builds an HS256 validator for `region`; when `strict`, it also enforces
    /// `ISSUER` and `AUDIENCE`.
    fn validator_for<E>(region: &str, strict: bool) -> JwtValidator<E>
    where
        E: Default + DeserializeOwned + Clone + std::fmt::Debug,
    {
        let (issuer, audience) = if strict {
            (Some(ISSUER.to_string()), Some(AUDIENCE.to_string()))
        } else {
            (None, None)
        };
        JwtValidator::new(
            SECRET.to_string(),
            "HS256".to_string(),
            region.to_string(),
            issuer,
            audience,
        )
        .expect("HS256 validator should build")
    }

    /// Claims for user 42 that expire in one hour and name `ISSUER` / `AUDIENCE`.
    fn claims_for(token_id: &str, regions: &[&str]) -> JwtClaims<()> {
        let now = chrono::Utc::now();
        JwtClaims {
            token_id: token_id.to_string(),
            user_id: 42,
            allowed_regions: regions.iter().map(|&r| r.to_owned()).collect(),
            exp: (now + chrono::Duration::hours(1)).timestamp(),
            iat: now.timestamp(),
            iss: Some(ISSUER.to_string()),
            aud: Some(AUDIENCE.to_string()),
            extra: (),
        }
    }

    /// Signs `claims` with HS256 under `secret`.
    fn sign<E: Default + Serialize>(claims: &JwtClaims<E>, secret: &str) -> String {
        encode(
            &Header::new(Algorithm::HS256),
            claims,
            &EncodingKey::from_secret(secret.as_bytes()),
        )
        .expect("test claims should encode")
    }

    /// Formats `token` as a `Bearer` authorization header value.
    fn bearer(token: &str) -> String {
        format!("Bearer {}", token)
    }

    #[test]
    fn test_extract_bearer_token() {
        assert_eq!(JwtValidator::<()>::extract_token("Bearer abc123").unwrap(), "abc123");
        assert_eq!(JwtValidator::<()>::extract_token("bearer xyz789").unwrap(), "xyz789");
        assert!(JwtValidator::<()>::extract_token("abc123").is_err());
    }

    #[test]
    fn test_extract_basic_auth_token() {
        use base64::{engine::general_purpose, Engine as _};
        let basic = |credentials: &str| {
            format!("Basic {}", general_purpose::STANDARD.encode(credentials.as_bytes()))
        };

        let token = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.test";
        let result = JwtValidator::<()>::extract_token(&basic(&format!("{}:", token)));
        assert_eq!(result.expect("Should extract token from Basic Auth"), token);

        let token2 = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.test2";
        let result2 = JwtValidator::<()>::extract_token(&basic(&format!("Bearer {}:", token2)));
        assert_eq!(
            result2.expect("Should extract token from 'Bearer token:' format"),
            token2
        );

        assert!(JwtValidator::<()>::extract_token("Basic ").is_err());
        assert!(JwtValidator::<()>::extract_token("Basic invalid!!!base64").is_err());
    }

    #[test]
    fn test_jwt_validator_creation() {
        let validator: JwtValidator<()> = JwtValidator::new(
            "test_secret".to_string(),
            "HS256".to_string(),
            "us-east".to_string(),
            Some(ISSUER.to_string()),
            Some(AUDIENCE.to_string()),
        )
        .expect("HS256 validator should build");
        assert_eq!(validator.expected_issuer, Some("probeops".to_string()));
        assert_eq!(validator.expected_audience, Some("forward-proxy".to_string()));

        let unsupported: Result<JwtValidator<()>, _> = JwtValidator::new(
            "test_secret_long_enough_for_testing".to_string(),
            "INVALID".to_string(),
            "us-east".to_string(),
            None,
            None,
        );
        assert!(unsupported.is_err());
    }

    #[test]
    fn test_valid_token_validation() {
        let validator: JwtValidator<()> = validator_for("us-east", true);
        let token = sign(&claims_for("test_token_123", &["us-east", "eu-west"]), SECRET);
        let claims = validator
            .validate(&bearer(&token))
            .expect("valid token should be accepted");
        assert_eq!(claims.token_id, "test_token_123");
        assert_eq!(claims.user_id, 42);
    }

    #[test]
    fn test_expired_token() {
        let validator: JwtValidator<()> = validator_for("us-east", true);
        let now = chrono::Utc::now();
        let claims = JwtClaims {
            exp: (now - chrono::Duration::hours(1)).timestamp(),
            iat: (now - chrono::Duration::hours(2)).timestamp(),
            ..claims_for("expired_token", &["us-east"])
        };
        match validator.validate(&bearer(&sign(&claims, SECRET))) {
            Err(AuthError::TokenExpired) => {}
            other => panic!("Expected TokenExpired, got {:?}", other),
        }
    }

    #[test]
    fn test_invalid_signature() {
        let validator: JwtValidator<()> = validator_for("us-east", true);
        let token = sign(&claims_for("test_token", &["us-east"]), "wrong_secret");
        match validator.validate(&bearer(&token)) {
            Err(AuthError::ValidationFailed(msg)) => assert!(msg.contains("Invalid signature")),
            other => panic!("Expected ValidationFailed, got {:?}", other),
        }
    }

    #[test]
    fn test_region_not_allowed() {
        let validator: JwtValidator<()> = validator_for("ap-south", true);
        let token = sign(&claims_for("test_token", &["us-east", "eu-west"]), SECRET);
        match validator.validate(&bearer(&token)) {
            Err(AuthError::RegionNotAllowed(region)) => assert_eq!(region, "ap-south"),
            other => panic!("Expected RegionNotAllowed, got {:?}", other),
        }
    }

    #[test]
    fn test_wildcard_region_allowed() {
        let validator: JwtValidator<()> = validator_for("ap-south", true);
        let token = sign(&claims_for("test_token_wildcard", &["*"]), SECRET);
        let claims = validator
            .validate(&bearer(&token))
            .expect("Wildcard region should grant access to any region");
        assert_eq!(claims.allowed_regions, vec!["*".to_string()]);
    }

    #[test]
    fn test_empty_region_list_is_unrestricted() {
        let validator: JwtValidator<()> = validator_for("ap-south", true);
        let token = sign(&claims_for("test_token_unrestricted", &[]), SECRET);
        assert!(
            validator.validate(&bearer(&token)).is_ok(),
            "An empty allowed_regions list should not restrict the region"
        );
    }

    #[test]
    fn test_invalid_issuer() {
        let validator: JwtValidator<()> = validator_for("us-east", true);
        let claims = JwtClaims {
            iss: Some("malicious_issuer".to_string()),
            ..claims_for("test_token", &["us-east"])
        };
        match validator.validate(&bearer(&sign(&claims, SECRET))) {
            Err(AuthError::ValidationFailed(msg)) => assert!(msg.contains("Invalid issuer")),
            other => panic!("Expected ValidationFailed with issuer error, got {:?}", other),
        }
    }

    #[test]
    fn test_invalid_audience() {
        let validator: JwtValidator<()> = validator_for("us-east", true);
        let claims = JwtClaims {
            aud: Some("wrong_service".to_string()),
            ..claims_for("test_token", &["us-east"])
        };
        match validator.validate(&bearer(&sign(&claims, SECRET))) {
            Err(AuthError::ValidationFailed(msg)) => assert!(msg.contains("Invalid audience")),
            other => panic!("Expected ValidationFailed with audience error, got {:?}", other),
        }
    }

    #[test]
    fn test_configurable_issuer_audience_disabled() {
        let lenient: JwtValidator<()> = validator_for("us-east", false);
        assert_eq!(lenient.expected_issuer, None);
        assert_eq!(lenient.expected_audience, None);

        let claims = JwtClaims {
            iss: None,
            aud: None,
            ..claims_for("test_token", &["us-east"])
        };
        let auth_header = bearer(&sign(&claims, SECRET));
        assert!(
            lenient.validate(&auth_header).is_ok(),
            "Token without issuer/audience should succeed when validation disabled"
        );

        let strict: JwtValidator<()> = validator_for("us-east", true);
        assert_eq!(strict.expected_issuer, Some("probeops".to_string()));
        assert_eq!(strict.expected_audience, Some("forward-proxy".to_string()));
        // If a strict validator rejects a token that lacks iss/aud, the rejection
        // must surface as a validation failure rather than any other error kind.
        if let Err(err) = strict.validate(&auth_header) {
            assert!(
                matches!(err, AuthError::ValidationFailed(_)),
                "Should be ValidationFailed error"
            );
        }
    }

    #[test]
    fn test_proxy_authorization_header() {
        let validator: JwtValidator<()> = validator_for("us-east", true);
        let token = sign(&claims_for("test_token", &["us-east"]), SECRET);

        let request = http::Request::builder()
            .header("Proxy-Authorization", bearer(&token))
            .body(())
            .unwrap();
        assert!(validator.validate_request(&request).is_ok());

        let request = http::Request::builder()
            .header("Authorization", bearer(&token))
            .body(())
            .unwrap();
        assert!(validator.validate_request(&request).is_ok());

        let request = http::Request::builder().body(()).unwrap();
        match validator.validate_request(&request) {
            Err(AuthError::MissingHeader) => {}
            other => panic!("Expected MissingHeader, got {:?}", other),
        }
    }

    #[test]
    fn test_extended_claims() {
        #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
        struct CustomClaims {
            rate_limit_per_hour: Option<usize>,
            tier: Option<String>,
        }

        let validator: JwtValidator<CustomClaims> = validator_for("us-east", true);
        let base = claims_for("extended_token", &["us-east"]);
        let claims = JwtClaims::with_extra(
            base.token_id,
            base.user_id,
            base.allowed_regions,
            base.exp,
            base.iat,
            base.iss,
            base.aud,
            CustomClaims {
                rate_limit_per_hour: Some(10000),
                tier: Some("pro".to_string()),
            },
        );
        let validated = validator
            .validate(&bearer(&sign(&claims, SECRET)))
            .expect("Extended claims token should be valid");
        assert_eq!(validated.token_id, "extended_token");
        assert_eq!(validated.extra.rate_limit_per_hour, Some(10000));
        assert_eq!(validated.extra.tier, Some("pro".to_string()));
    }

    #[test]
    fn test_backwards_compatibility_default_claims() {
        let validator: JwtValidator<()> = validator_for("us-east", false);
        let claims = JwtClaims {
            user_id: 1,
            iss: None,
            aud: None,
            ..claims_for("standard_token", &["us-east"])
        };
        assert!(
            validator.validate(&bearer(&sign(&claims, SECRET))).is_ok(),
            "Standard claims should work"
        );
    }

    #[test]
    fn test_jwt_claims_new_constructor() {
        // The explicit type pins `extra` to `()` at compile time; comparing unit
        // values at runtime would be vacuous (and is rejected by clippy::unit_cmp).
        let claims: JwtClaims<()> = JwtClaims::new(
            "token_123".to_string(),
            42,
            vec!["us-east".to_string()],
            chrono::Utc::now().timestamp() + 3600,
            chrono::Utc::now().timestamp(),
            Some("issuer".to_string()),
            Some("audience".to_string()),
        );
        assert_eq!(claims.token_id, "token_123");
        assert_eq!(claims.user_id, 42);
        assert_eq!(claims.allowed_regions, vec!["us-east".to_string()]);
        assert_eq!(claims.iss, Some("issuer".to_string()));
        assert_eq!(claims.aud, Some("audience".to_string()));
    }

    #[test]
    fn test_jwt_claims_with_extra_constructor() {
        #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq)]
        struct CustomClaims {
            tier: Option<String>,
            rate_limit: Option<usize>,
        }

        let custom = CustomClaims {
            tier: Some("pro".to_string()),
            rate_limit: Some(10000),
        };
        let claims = JwtClaims::with_extra(
            "token_456".to_string(),
            99,
            vec!["*".to_string()],
            chrono::Utc::now().timestamp() + 7200,
            chrono::Utc::now().timestamp(),
            None,
            None,
            custom.clone(),
        );
        assert_eq!(claims.token_id, "token_456");
        assert_eq!(claims.user_id, 99);
        assert_eq!(claims.extra.tier, Some("pro".to_string()));
        assert_eq!(claims.extra.rate_limit, Some(10000));
        assert_eq!(claims.extra, custom);
    }
}