use super::test_helpers::*;
use crate::{claims, jwks, parser, AuthError};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::task::JoinSet;
// Integration tests exercising the parser, JWKS cache and claims modules
// together against a wiremock-backed JWKS endpoint.
#[cfg(test)]
mod tests {
    use super::*;
    use base64::{engine::general_purpose, Engine as _};
    use jsonwebtoken::{encode, Algorithm, Header};
    use p256::{ecdsa::SigningKey, pkcs8::EncodePrivateKey};
    use rand::rngs::OsRng;

    /// Signs test claims with a freshly generated P-256 key (ES256), serves the
    /// matching public key as an EC JWK from a mock `/jwks` endpoint, and checks
    /// that `claims::Claims::from_token` accepts the token, round-trips every
    /// claim, and populates `kid` from the JWT header.
    #[tokio::test]
    async fn test_end_to_end_jwt_validation_with_valid_signature() {
        let signing_key = SigningKey::random(&mut OsRng);
        let private_key_pem = signing_key.to_pkcs8_pem(Default::default()).unwrap();
        let encoding_key =
            jsonwebtoken::EncodingKey::from_ec_pem(private_key_pem.as_bytes()).unwrap();

        // Export the public point as base64url (no padding) x/y coordinates for the JWK.
        let public_key = signing_key.verifying_key();
        let point = public_key.to_encoded_point(false);
        let x = general_purpose::URL_SAFE_NO_PAD.encode(point.x().unwrap());
        let y = general_purpose::URL_SAFE_NO_PAD.encode(point.y().unwrap());
        let kid = "dynamic-e2e-test-kid";
        let jwk = jwks::Jwk {
            kid: kid.to_string(),
            kty: "EC".to_string(),
            alg: Some("ES256".to_string()),
            key_use: Some("sig".to_string()),
            key_ops: Some(vec!["verify".to_string()]),
            crv: Some("P-256".to_string()),
            x: Some(x),
            y: Some(y),
            n: None,
            e: None,
            ext: Some(true),
        };

        let mock_server = MockServer::start().await;
        let jwks_response = jwks::JwksResponse { keys: vec![jwk] };
        Mock::given(wiremock::matchers::method("GET"))
            .and(wiremock::matchers::path("/jwks"))
            .respond_with(ResponseTemplate::new(200).set_body_json(&jwks_response))
            .mount(&mock_server)
            .await;
        let jwks_url = format!("{}/jwks", mock_server.uri());
        let jwks_cache = jwks::JwksCache::new(&jwks_url);

        // `kid` is cleared in the payload so the assertion at the end proves it
        // was taken from the JWT header during validation.
        let mut original_claims = create_test_claims();
        original_claims.kid = None;
        let mut header = Header::new(Algorithm::ES256);
        header.kid = Some(kid.to_string());
        let token = encode(&header, &original_claims, &encoding_key)
            .expect("Failed to sign JWT with dynamic key");

        let validated_claims = claims::Claims::from_token(&token, &jwks_cache)
            .await
            .expect("JWT validation should succeed for a token signed with a valid key");

        assert_eq!(
            validated_claims.sub, original_claims.sub,
            "User ID (sub) should match"
        );
        assert_eq!(
            validated_claims.exp, original_claims.exp,
            "Expiration time (exp) should match"
        );
        assert_eq!(
            validated_claims.iat, original_claims.iat,
            "Issued at (iat) should match"
        );
        assert_eq!(
            validated_claims.jti, original_claims.jti,
            "JWT ID (jti) should match"
        );
        assert_eq!(
            validated_claims.email, original_claims.email,
            "Email should match"
        );
        assert_eq!(
            validated_claims.phone, original_claims.phone,
            "Phone should match"
        );
        assert_eq!(
            validated_claims.role, original_claims.role,
            "Role should match"
        );
        assert_eq!(
            validated_claims.app_metadata, original_claims.app_metadata,
            "App metadata should match"
        );
        assert_eq!(
            validated_claims.user_metadata, original_claims.user_metadata,
            "User metadata should match"
        );
        assert_eq!(
            validated_claims.aud, original_claims.aud,
            "Audience (aud) should match"
        );
        assert_eq!(
            validated_claims.iss, original_claims.iss,
            "Issuer (iss) should match"
        );
        assert_eq!(
            validated_claims.aal, original_claims.aal,
            "Authentication Assurance Level (aal) should match"
        );
        assert_eq!(
            validated_claims.amr, original_claims.amr,
            "Authentication Methods References (amr) should match"
        );
        assert_eq!(
            validated_claims.session_id, original_claims.session_id,
            "Session ID should match"
        );
        assert_eq!(
            validated_claims.is_anonymous, original_claims.is_anonymous,
            "Anonymous status should match"
        );
        assert_eq!(
            validated_claims.kid.as_deref(),
            Some(kid),
            "Key ID (kid) should be correctly populated from the JWT header"
        );
    }

    /// Drives the parser pipeline step by step on the mock token.
    ///
    /// The steps are: decode the header, look up the JWK, build the decoding
    /// key, parse the algorithm, then verify. The mock token is expected to
    /// carry an invalid signature, so verification must fail with exactly
    /// `AuthError::Verification`.
    #[tokio::test]
    async fn test_end_to_end_validation_with_invalid_signature() {
        let mock_server = create_mock_jwks_server().await;
        let jwks_url = format!("{}/jwks", mock_server.uri());
        let jwks_cache = jwks::JwksCache::new(&jwks_url);
        let token = create_mock_jwt_token();

        let header =
            parser::JwtParser::decode_header(&token).expect("Should decode header successfully");
        let kid = header
            .kid
            .as_ref()
            .expect("Mock token header should contain a kid");
        let jwk = jwks_cache
            .find_key(kid)
            .await
            .expect("Should find key in JWKS");
        let decoding_key =
            parser::JwtParser::create_decoding_key(&jwk).expect("Should create decoding key");
        let algorithm =
            parser::JwtParser::parse_algorithm(&header.alg).expect("Should parse algorithm");

        // The match is exhaustive, so it checks both "it failed" and "it failed
        // for the right reason" in one place.
        match parser::JwtParser::verify_and_decode(&token, &decoding_key, algorithm) {
            Err(AuthError::Verification) => {}
            Err(e) => {
                panic!("Expected a verification error due to invalid signature, but got: {e:?}");
            }
            Ok(_) => {
                panic!("Validation should have failed for an invalid signature, but it passed");
            }
        }
    }

    /// A syntactically malformed token must be rejected at header decoding with
    /// `AuthError::DecodeHeader`.
    #[tokio::test]
    async fn test_error_handling_integration() {
        let result = parser::JwtParser::decode_header("invalid.jwt.token");
        assert!(
            matches!(result, Err(AuthError::DecodeHeader)),
            "Malformed token should fail with AuthError::DecodeHeader"
        );
    }

    /// Several tasks share one `JwksCache` through an `Arc`. Each task fetches
    /// the key set and looks up `test-key-id`; every fetch and lookup must
    /// succeed.
    #[tokio::test]
    async fn test_jwks_concurrent_access() {
        let mock_server = create_mock_jwks_server().await;
        let jwks_url = format!("{}/jwks", mock_server.uri());
        let jwks_cache = Arc::new(jwks::JwksCache::new(&jwks_url));

        let mut join_set = JoinSet::new();
        for _ in 0..5 {
            let cache = Arc::clone(&jwks_cache);
            join_set.spawn(async move {
                let jwks_ok = cache.get_jwks().await.is_ok();
                assert!(jwks_ok, "JWKS fetch should succeed");
                let key_ok = cache.find_key("test-key-id").await.is_ok();
                assert!(key_ok, "Key lookup should succeed");
                jwks_ok
            });
        }

        let mut results = Vec::new();
        while let Some(joined) = join_set.join_next().await {
            results.push(joined.expect("JWKS task should not panic"));
        }
        assert_eq!(results.len(), 5);
        assert!(
            results.into_iter().all(|jwks_ok| jwks_ok),
            "Every task should fetch the JWKS successfully"
        );
    }

    /// Checks the claims accessors on the standard test claims and verifies
    /// that a JSON serialize/deserialize round trip preserves them.
    #[tokio::test]
    async fn test_claims_module_integration() {
        let claims = create_test_claims();
        assert_eq!(claims.user_id(), TEST_USER_ID);
        assert_eq!(claims.email(), Some(TEST_EMAIL));
        assert_eq!(claims.role(), "authenticated");
        assert!(!claims.is_anonymous());
        assert!(claims.validate_security().is_ok());

        let name: Option<String> = claims.get_user_metadata("name");
        assert_eq!(name, Some("Test User".to_string()));
        let provider: Option<String> = claims.get_app_metadata("provider");
        assert_eq!(provider, Some("email".to_string()));

        let serialized = serde_json::to_string(&claims).expect("Should serialize successfully");
        let deserialized: claims::Claims =
            serde_json::from_str(&serialized).expect("Should deserialize successfully");
        assert_eq!(deserialized.user_id(), claims.user_id());
        assert_eq!(deserialized.email(), claims.email());
        assert_eq!(deserialized.role(), claims.role());
        assert_eq!(deserialized.is_anonymous(), claims.is_anonymous());
    }

    /// Reads the same `Arc`-shared claims from ten spawned tasks and checks
    /// that every accessor returns the expected value in every task.
    #[tokio::test]
    async fn test_claims_concurrent_safety() {
        let claims = Arc::new(create_test_claims());

        let mut join_set = JoinSet::new();
        for _ in 0..10 {
            let claims_clone = Arc::clone(&claims);
            join_set.spawn(async move {
                let user_id = claims_clone.user_id();
                let email = claims_clone.email();
                let role = claims_clone.role();
                let is_anonymous = claims_clone.is_anonymous();
                let name: Option<String> = claims_clone.get_user_metadata("name");
                let provider: Option<String> = claims_clone.get_app_metadata("provider");
                let security_ok = claims_clone.validate_security().is_ok();
                (
                    user_id == TEST_USER_ID,
                    email == Some(TEST_EMAIL),
                    role == "authenticated",
                    !is_anonymous,
                    name == Some("Test User".to_string()),
                    provider == Some("email".to_string()),
                    security_ok,
                )
            });
        }

        let mut results = Vec::new();
        while let Some(joined) = join_set.join_next().await {
            results.push(joined.expect("Claims task should not panic"));
        }
        assert_eq!(results.len(), 10);
        for (user_id_ok, email_ok, role_ok, not_anonymous, name_ok, provider_ok, security_ok) in
            results
        {
            assert!(user_id_ok, "User ID should be consistent across threads");
            assert!(email_ok, "Email should be consistent across threads");
            assert!(role_ok, "Role should be consistent across threads");
            assert!(
                not_anonymous,
                "Anonymous status should be consistent across threads"
            );
            assert!(name_ok, "User metadata should be consistent across threads");
            assert!(
                provider_ok,
                "App metadata should be consistent across threads"
            );
            assert!(
                security_ok,
                "Security validation should be consistent across threads"
            );
        }
    }

    /// Times 1000 JSON serializations and 1000 deserializations of the test
    /// claims and asserts that the per-operation averages stay under 1 ms and
    /// 2 ms respectively.
    ///
    /// NOTE(review): these are wall-clock thresholds, so results depend on
    /// machine load and build profile.
    #[tokio::test]
    async fn test_claims_serialization_performance() {
        let claims = create_test_claims();
        let iterations: u32 = 1000;

        let start = Instant::now();
        for _ in 0..iterations {
            let _serialized =
                serde_json::to_string(&claims).expect("Serialization should not fail");
        }
        let serialization_duration = start.elapsed();
        let avg_serialization_time = serialization_duration / iterations;
        assert!(
            avg_serialization_time < Duration::from_millis(1),
            "Serialization too slow: {avg_serialization_time:?} per operation"
        );

        let serialized = serde_json::to_string(&claims).unwrap();
        let start = Instant::now();
        for _ in 0..iterations {
            let _deserialized: claims::Claims =
                serde_json::from_str(&serialized).expect("Deserialization should not fail");
        }
        let deserialization_duration = start.elapsed();
        let avg_deserialization_time = deserialization_duration / iterations;
        assert!(
            avg_deserialization_time < Duration::from_millis(2),
            "Deserialization too slow: {avg_deserialization_time:?} per operation"
        );

        println!(
            "Performance metrics - Serialization: {avg_serialization_time:?}/op, Deserialization: {avg_deserialization_time:?}/op"
        );
    }

    /// Warms the JWKS cache, then times 100 rounds of header decoding plus key
    /// lookup and asserts the average stays under 10 ms.
    ///
    /// Every step must succeed. Otherwise a failing decode or lookup would
    /// make the loop do no real work and the timing assertion would pass
    /// vacuously.
    #[tokio::test]
    async fn test_end_to_end_performance() {
        let mock_server = create_mock_jwks_server().await;
        let jwks_url = format!("{}/jwks", mock_server.uri());
        let jwks_cache = jwks::JwksCache::new(&jwks_url);
        assert!(
            jwks_cache.get_jwks().await.is_ok(),
            "Warm-up JWKS fetch should succeed"
        );
        let token = create_mock_jwt_token();
        let iterations: u32 = 100;

        let start = Instant::now();
        for _ in 0..iterations {
            let header = parser::JwtParser::decode_header(&token)
                .expect("Mock token header should decode");
            let kid = header.kid.expect("Mock token header should contain a kid");
            jwks_cache
                .find_key(&kid)
                .await
                .expect("Key should be found in the warmed JWKS cache");
        }
        let total_duration = start.elapsed();
        let avg_operation_time = total_duration / iterations;
        assert!(
            avg_operation_time < Duration::from_millis(10),
            "End-to-end operation too slow: {avg_operation_time:?} per operation"
        );
        println!("End-to-end performance: {avg_operation_time:?}/op for {iterations} iterations");
    }

    /// Exercises parser, JWKS cache and claims together.
    ///
    /// First it runs them sequentially, including rejection of a malformed
    /// token. Then five tasks, sharing one cache instance, repeat the
    /// decode → lookup → security-validation chain concurrently.
    #[tokio::test]
    async fn test_complete_module_integration() {
        let mock_server = create_mock_jwks_server().await;
        let jwks_url = format!("{}/jwks", mock_server.uri());
        let jwks_cache = jwks::JwksCache::new(&jwks_url);
        let token = create_mock_jwt_token();
        let test_claims = create_test_claims();

        let header = parser::JwtParser::decode_header(&token)
            .expect("Parser should decode header successfully");
        let kid = header.kid.as_ref().expect("Header should contain kid");
        assert!(
            jwks_cache.find_key(kid).await.is_ok(),
            "JWKS should find key successfully"
        );
        assert!(
            test_claims.validate_security().is_ok(),
            "Claims should pass security validation"
        );
        assert_eq!(
            test_claims.user_id(),
            TEST_USER_ID,
            "Claims should provide correct user ID"
        );

        let invalid_header_result = parser::JwtParser::decode_header("invalid.jwt.token");
        assert!(
            matches!(invalid_header_result, Err(AuthError::DecodeHeader)),
            "Parser should reject invalid token with AuthError::DecodeHeader"
        );

        // One `Arc` shared by all tasks, instead of cloning the cache and
        // wrapping each copy in its own fresh `Arc` per task.
        let shared_cache = Arc::new(jwks_cache);
        let mut join_set = JoinSet::new();
        for _ in 0..5 {
            let cache = Arc::clone(&shared_cache);
            let claims_clone = test_claims.clone();
            let token_clone = token.clone();
            join_set.spawn(async move {
                let header = parser::JwtParser::decode_header(&token_clone)?;
                let kid = header.kid.as_ref().expect("Header should contain kid");
                cache.find_key(kid).await?;
                claims_clone.validate_security()?;
                Ok::<(), AuthError>(())
            });
        }

        let mut concurrent_results = Vec::new();
        while let Some(joined) = join_set.join_next().await {
            concurrent_results.push(joined.expect("Integration task should not panic"));
        }
        assert_eq!(concurrent_results.len(), 5);
        for result in concurrent_results {
            assert!(
                result.is_ok(),
                "Concurrent module integration should succeed: {result:?}"
            );
        }
    }
}