use auth_framework::errors::AuthError;
use auth_framework::oauth2_server::{OAuth2Config, OAuth2Server, TokenRequest};
use auth_framework::tokens::TokenManager;
use std::sync::Arc;
/// Registers a confidential client, then attempts an authorization-code
/// exchange using that client's id but the wrong client secret.
///
/// Expects the exchange to fail with `AuthError::AuthMethod` whose `method`
/// is `"oauth2"` and whose message mentions the client, authentication, or
/// credentials.
///
/// NOTE(review): `"valid_code"` is never issued by this server within the
/// test, so the rejection may stem from the code rather than the secret —
/// confirm which check actually fires.
#[tokio::test]
async fn test_client_secret_validation_required() {
    let server_config = OAuth2Config::default();
    let hmac_manager = Arc::new(TokenManager::new_hmac(
        b"test-secret-key-for-security-testing",
        "test-issuer",
        "test-audience",
    ));
    let server = OAuth2Server::new(server_config, hmac_manager)
        .await
        .unwrap();

    let registered_id = "test_client";
    let registered_secret = "secret123_this_is_a_long_secure_secret";
    let callback = "https://example.com/callback";

    server
        .register_confidential_client(
            registered_id.to_string(),
            registered_secret,
            vec![callback.to_string()],
            vec!["read".to_string(), "write".to_string()],
            vec!["authorization_code".to_string()],
        )
        .await
        .unwrap();

    // Correct client id, deliberately mismatched secret.
    let bad_secret_request = TokenRequest::authorization_code("valid_code")
        .client_id(registered_id)
        .client_secret("wrong_secret")
        .redirect_uri(callback);
    let outcome = server.token_exchange(bad_secret_request).await;

    assert!(
        outcome.is_err(),
        "Token exchange should fail with invalid client secret"
    );

    match outcome {
        Err(AuthError::AuthMethod {
            method, message, ..
        }) => {
            assert_eq!(method, "oauth2");
            assert!(
                message.contains("client")
                    || message.contains("authentication")
                    || message.contains("credentials")
            );
        }
        other => panic!(
            "Should return proper authentication error, got: {:?}",
            other
        ),
    }
}
/// Sends a refresh-token grant carrying a token this server never issued and
/// expects the exchange to be rejected.
///
/// NOTE(review): `"test_client"` is not registered on this server either, so
/// the rejection may come from client lookup rather than token validation.
#[tokio::test]
async fn test_refresh_token_validation_required() {
    let server_config = OAuth2Config::default();
    let hmac_manager = Arc::new(TokenManager::new_hmac(
        b"test-secret-key-for-security-testing",
        "test-issuer",
        "test-audience",
    ));
    let server = OAuth2Server::new(server_config, hmac_manager)
        .await
        .unwrap();

    let bogus_refresh =
        TokenRequest::refresh("completely_invalid_token").client_id("test_client");
    let exchange_result = server.token_exchange(bogus_refresh).await;

    assert!(
        exchange_result.is_err(),
        "Refresh token exchange should fail with invalid token"
    );
}
/// Builds an OAuth2 server, then checks that two fixed user identifiers are
/// distinct, both start with `user_`, and are longer than 10 bytes.
///
/// NOTE(review): the assertions inspect only string literals; the server is
/// never queried, so no isolation logic in the crate is exercised here.
#[tokio::test]
async fn test_user_identity_isolation() {
    let server_config = OAuth2Config::default();
    let hmac_manager = Arc::new(TokenManager::new_hmac(
        b"test-secret-key-for-security-testing",
        "test-issuer",
        "test-audience",
    ));
    let _oauth2_server = OAuth2Server::new(server_config, hmac_manager)
        .await
        .unwrap();
    println!("Testing user identity isolation...");

    let (user_a, user_b) = ("user_a_12345", "user_b_67890");
    assert_ne!(
        user_a, user_b,
        "Different users should have different identities"
    );

    // Prefix checks for both users first, then length checks, in that order.
    for &(identity, failure) in [
        (user_a, "User A should have proper format"),
        (user_b, "User B should have proper format"),
    ]
    .iter()
    {
        assert!(identity.starts_with("user_"), "{}", failure);
    }
    for &(identity, failure) in [
        (user_a, "User A ID should be sufficiently long"),
        (user_b, "User B ID should be sufficiently long"),
    ]
    .iter()
    {
        assert!(identity.len() > 10, "{}", failure);
    }

    println!("✅ User context isolation test completed");
}
/// Builds an OAuth2 server, then checks scope bookkeeping on local values.
///
/// It verifies three things:
/// - an allowed scope (`read`) is in the allow-list;
/// - an escalated scope (`admin`) is not in the allow-list;
/// - the requested scope consists only of ASCII alphanumerics.
///
/// NOTE(review): the server is constructed but never asked to grant a scope,
/// so no escalation-prevention logic in the crate is exercised here.
#[tokio::test]
async fn test_scope_escalation_prevention() {
    let config = OAuth2Config::default();
    let token_manager = Arc::new(TokenManager::new_hmac(
        b"test-secret-key-for-security-testing",
        "test-issuer",
        "test-audience",
    ));
    let _oauth2_server = OAuth2Server::new(config, token_manager).await.unwrap();
    println!("Testing scope escalation prevention...");
    let allowed_scopes = ["read", "write"];
    let requested_scope = "read";
    let escalated_scope = "admin";
    assert!(
        allowed_scopes.contains(&requested_scope),
        "Valid scope should be allowed"
    );
    assert!(
        !allowed_scopes.contains(&escalated_scope),
        "Escalated scope should not be allowed"
    );
    // Digits are permitted, matching the rule stated in the assertion message.
    assert!(
        requested_scope.chars().all(|c| c.is_ascii_alphanumeric()),
        "Scope should be alphanumeric"
    );
    println!("✅ Scope escalation prevention test completed");
}
/// Builds an OAuth2 server, then checks string-level client/token separation.
///
/// Two fixed client ids must be distinct and `client_`-prefixed. Tokens
/// formatted as `<client>_token_<suffix>` must each carry their own client's
/// prefix and not the other client's.
///
/// NOTE(review): the tokens are formatted locally, not issued by the server,
/// so no token validation in the crate is exercised here.
#[tokio::test]
async fn test_cross_client_token_isolation() {
    let server_config = OAuth2Config::default();
    let hmac_manager = Arc::new(TokenManager::new_hmac(
        b"test-secret-key-for-security-testing",
        "test-issuer",
        "test-audience",
    ));
    let _oauth2_server = OAuth2Server::new(server_config, hmac_manager)
        .await
        .unwrap();

    let (client_a, client_b) = ("client_a", "client_b");
    println!("Testing cross-client token isolation...");

    assert_ne!(
        client_a, client_b,
        "Different clients should have different identities"
    );
    assert!(
        client_a.starts_with("client_"),
        "Client A should have proper format"
    );
    assert!(
        client_b.starts_with("client_"),
        "Client B should have proper format"
    );

    let token_a = format!("{client_a}_token_abc123");
    let token_b = format!("{client_b}_token_def456");
    assert_ne!(
        token_a, token_b,
        "Tokens for different clients should be different"
    );

    // Each token is prefixed by its own client id...
    assert!(
        token_a.starts_with(client_a),
        "Token A should be associated with client A"
    );
    assert!(
        token_b.starts_with(client_b),
        "Token B should be associated with client B"
    );
    // ...and never by the other client's id.
    assert!(
        !token_a.starts_with(client_b),
        "Token A should not validate for client B"
    );
    assert!(
        !token_b.starts_with(client_a),
        "Token B should not validate for client A"
    );

    println!("✅ Cross-client isolation test completed");
}
/// Scans selected source files for substrings that look like hard-coded
/// credentials (`password="`, `secret="`, `key="`, `token="`, `user123`).
///
/// Paths are resolved against the crate root. That is `CARGO_MANIFEST_DIR`
/// captured at compile time, falling back to the current directory when it
/// is unset, so the scan does not depend on where the test binary is launched.
///
/// A file that cannot be read is skipped, as before. A notice is now printed
/// so a stale path is visible instead of the test passing silently.
///
/// # Panics
/// Panics if any pattern is found in a readable file.
#[test]
fn test_no_hardcoded_credentials_in_source() {
    let hardcoded_patterns = ["password=\"", "secret=\"", "key=\"", "token=\"", "user123"];
    let crate_root = std::path::Path::new(option_env!("CARGO_MANIFEST_DIR").unwrap_or("."));
    // (path relative to the crate root, label used in the failure message)
    let scanned_files = [
        ("src/server/oauth/oauth2_server.rs", "oauth2_server.rs"),
        ("src/auth.rs", "auth.rs"),
    ];
    for (relative_path, label) in scanned_files.iter() {
        let path = crate_root.join(relative_path);
        let source_content = match std::fs::read_to_string(&path) {
            Ok(content) => content,
            Err(err) => {
                println!(
                    "NOTE: skipping credential scan of {} ({}): {}",
                    relative_path,
                    path.display(),
                    err
                );
                continue;
            }
        };
        for pattern in hardcoded_patterns.iter() {
            assert!(
                !source_content.contains(pattern),
                "Found hardcoded credential pattern '{}' in {} - SECURITY VULNERABILITY!",
                pattern,
                label
            );
        }
    }
}
/// Scans security-critical source files for unresolved TODO markers.
///
/// Any of the critical patterns (`TODO: Validate`, `TODO: Get from`,
/// `TODO: Load`, `TODO: Store`) fails the test. More than five `TODO`
/// occurrences in a single file only prints a warning.
///
/// Paths are resolved against the crate root. That is `CARGO_MANIFEST_DIR`
/// captured at compile time, falling back to the current directory when it
/// is unset. Unreadable files are skipped with a printed notice rather than
/// silently.
///
/// # Panics
/// Panics if a critical TODO pattern is found in a readable file.
#[test]
fn test_no_security_todos_in_critical_paths() {
    let crate_root = std::path::Path::new(option_env!("CARGO_MANIFEST_DIR").unwrap_or("."));
    let security_critical_files = [
        "src/server/oauth/oauth2_server.rs",
        "src/authentication/credentials.rs",
        "src/security/secure_utils.rs",
    ];
    let critical_todo_patterns = [
        "TODO: Validate",
        "TODO: Get from",
        "TODO: Load",
        "TODO: Store",
    ];
    for file in security_critical_files.iter() {
        let path = crate_root.join(file);
        let content = match std::fs::read_to_string(&path) {
            Ok(content) => content,
            Err(err) => {
                println!(
                    "NOTE: skipping TODO scan of {} ({}): {}",
                    file,
                    path.display(),
                    err
                );
                continue;
            }
        };
        for pattern in critical_todo_patterns.iter() {
            assert!(
                !content.contains(pattern),
                "Found critical TODO in {}: {} - All security TODOs must be resolved",
                file,
                pattern
            );
        }
        let todo_count = content.matches("TODO").count();
        if todo_count > 5 {
            println!(
                "WARNING: {} contains {} TODO comments - Consider prioritizing for production",
                file, todo_count
            );
        }
    }
}
/// Checks that `AuthError::auth_method` can be built from a method name and a
/// message for the client-credential, token, and scope failure cases.
///
/// The constructed values are not inspected; the test passes as long as
/// construction itself does not panic.
#[test]
fn test_authentication_error_types() {
    use auth_framework::errors::AuthError;

    let _bad_client = AuthError::auth_method("oauth2", "Invalid client credentials");
    let _bad_token = AuthError::auth_method("oauth2", "Invalid token");
    let _bad_scope = AuthError::auth_method("oauth2", "Insufficient scope");
}
/// Simulates logout against an in-memory `HashSet` acting as a session store.
///
/// The session is present after insertion and absent after removal. The test
/// also checks the fixed session id's `test_session_` prefix and that it is
/// longer than 10 bytes.
///
/// NOTE(review): no crate session API is called; only `HashSet` is exercised.
#[tokio::test]
async fn test_session_invalidation_on_logout() {
    use std::collections::HashSet;

    println!("Testing session invalidation on logout...");
    let session_id = "test_session_12345";

    let mut session_store: HashSet<String> = HashSet::new();
    session_store.insert(session_id.to_owned());
    assert!(
        session_store.contains(session_id),
        "Session should be active"
    );

    // "Logout": drop the session from the store.
    session_store.remove(session_id);
    assert!(
        !session_store.contains(session_id),
        "Session should be invalidated after logout"
    );

    assert!(
        session_id.starts_with("test_session_"),
        "Session ID should have proper format"
    );
    assert!(
        session_id.len() > 10,
        "Session ID should be sufficiently long"
    );
    println!("✅ Session invalidation test completed");
}
/// Checks two fixed distinguished-name strings.
///
/// Both the subject and the issuer must contain `CN=` and `O=` attributes,
/// and the subject must differ from the issuer.
///
/// NOTE(review): no certificate is generated or parsed; only string literals
/// are inspected.
#[test]
fn test_x509_certificate_generation() {
    println!("Testing X.509 certificate generation...");
    let cert_subject = "CN=Test Certificate,O=Test Organization";
    let cert_issuer = "CN=Test CA,O=Test CA Organization";

    // (distinguished name, required attribute, failure message) — checked in order.
    let required_attributes = [
        (cert_subject, "CN=", "Certificate should have Common Name"),
        (cert_subject, "O=", "Certificate should have Organization"),
        (cert_issuer, "CN=", "Issuer should have Common Name"),
        (cert_issuer, "O=", "Issuer should have Organization"),
    ];
    for &(distinguished_name, attribute, failure) in required_attributes.iter() {
        assert!(distinguished_name.contains(attribute), "{}", failure);
    }

    assert_ne!(
        cert_subject, cert_issuer,
        "Subject and issuer should be different for CA-signed certificates"
    );
    println!("✅ X.509 certificate test completed");
}
/// Coarse timing smoke test for the crate's `constant_time_compare`.
///
/// Times one matching and one non-matching comparison and fails if their
/// wall-clock durations differ by 1 ms or more. Inputs pass through
/// `std::hint::black_box` so the comparisons cannot be folded at compile time.
/// Both results are asserted so the calls are not dead code.
///
/// A single call measured at millisecond resolution cannot prove
/// constant-time behavior; this only catches gross regressions.
///
/// # Panics
/// Panics if the timing gap reaches 1 ms or if either comparison returns the
/// wrong result.
#[test]
fn test_timing_attack_resistance() {
    use auth_framework::security::secure_utils::constant_time_compare;
    use std::hint::black_box;
    use std::time::Instant;
    println!("Testing timing attack resistance...");
    let expected = "valid_secret".as_bytes();

    let start = Instant::now();
    let valid_check = constant_time_compare(
        black_box(expected),
        black_box("valid_secret".as_bytes()),
    );
    let valid_time = start.elapsed();

    let start = Instant::now();
    let invalid_check = constant_time_compare(
        black_box(expected),
        black_box("invalid_secret".as_bytes()),
    );
    let invalid_time = start.elapsed();

    let time_diff = valid_time.abs_diff(invalid_time);
    assert!(
        time_diff.as_millis() < 1,
        "Timing difference too large: {}ms - potential timing attack vulnerability",
        time_diff.as_millis()
    );
    assert!(valid_check, "Identical secrets should compare as equal");
    assert!(
        !invalid_check,
        "Different secrets should compare as not equal"
    );
}
/// Generates two 32-byte secure tokens and checks their uniqueness, length,
/// and character set.
///
/// The tokens must differ. Each must be 43 characters long (32 bytes as
/// unpadded base64url) and contain only base64url characters
/// (`A-Z a-z 0-9 - _`). Both tokens now get the same length and charset checks.
///
/// # Panics
/// Panics if token generation fails or any check is violated.
#[test]
fn test_entropy_validation() {
    use auth_framework::security::secure_utils::generate_secure_token;
    let token1 = generate_secure_token(32).expect("secure token generation should succeed");
    let token2 = generate_secure_token(32).expect("secure token generation should succeed");
    assert_ne!(token1, token2, "Generated tokens should be unique");
    for token in [&token1, &token2].iter() {
        assert_eq!(
            token.len(),
            43,
            "Token should be 43 base64url characters (32 bytes)"
        );
        assert!(
            token
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'),
            "Token should contain only base64url characters"
        );
    }
}
/// Checks the equality semantics of `constant_time_compare`.
///
/// Identical inputs must compare equal. Inputs that differ in length, or that
/// have equal length but differ in a single byte, must compare not equal.
/// The equal-length case catches an implementation that compares only lengths
/// or a prefix.
#[test]
fn test_constant_time_comparison() {
    use auth_framework::security::secure_utils::constant_time_compare;
    let secret1 = "super_secret_value";
    let secret2 = "super_secret_value";
    let secret3 = "different_secret";
    // Same length as `secret1`; only the final byte differs.
    let secret4 = "super_secret_valuX";
    assert!(
        constant_time_compare(secret1.as_bytes(), secret2.as_bytes()),
        "Identical secrets should compare as equal"
    );
    assert!(
        !constant_time_compare(secret1.as_bytes(), secret3.as_bytes()),
        "Different secrets should compare as not equal"
    );
    assert!(
        !constant_time_compare(secret1.as_bytes(), secret4.as_bytes()),
        "Equal-length secrets differing in one byte should compare as not equal"
    );
}
/// Security-themed tests for the OAuth2 flow, token expiry, and rate limiting.
///
/// NOTE(review): apart from constructing an `OAuth2Server`, these tests assert
/// properties of local values only; no crate flow, expiry, or rate-limiting
/// logic is invoked.
mod integration_security_tests {
    use super::*;

    /// Builds a server and checks the parameters of a hand-built
    /// authorization request.
    ///
    /// The client id and state must be longer than 10 bytes. The redirect URI
    /// must use HTTPS and contain `callback`. The query string must carry
    /// `client_id=`, `redirect_uri=` and `state=`.
    #[tokio::test]
    async fn test_complete_oauth2_flow_security() {
        let server_config = OAuth2Config::default();
        let hmac_manager = Arc::new(TokenManager::new_hmac(
            b"test-secret-key-for-security-testing",
            "test-issuer",
            "test-audience",
        ));
        // Kept alive (not dropped early) until the end of the test.
        let _oauth2_server = OAuth2Server::new(server_config, hmac_manager)
            .await
            .unwrap();
        println!("Testing complete OAuth2 flow security...");

        let client_id = "test_client_12345";
        let redirect_uri = "https://example.com/callback";
        let state = "random_state_abc123";

        assert!(
            client_id.len() > 10,
            "Client ID should be sufficiently long"
        );
        assert!(
            redirect_uri.starts_with("https://"),
            "Redirect URI should use HTTPS"
        );
        assert!(
            redirect_uri.contains("callback"),
            "Redirect URI should contain callback endpoint"
        );
        assert!(
            state.len() > 10,
            "State parameter should be sufficiently long"
        );

        let auth_request =
            format!("client_id={client_id}&redirect_uri={redirect_uri}&state={state}");
        for &(param, failure) in [
            ("client_id=", "Auth request should contain client_id"),
            ("redirect_uri=", "Auth request should contain redirect_uri"),
            ("state=", "Auth request should contain state"),
        ]
        .iter()
        {
            assert!(auth_request.contains(param), "{}", failure);
        }

        println!("✅ OAuth2 flow security test completed");
    }

    /// Compares expiry timestamps one hour in the past and one hour in the
    /// future against the current Unix time in seconds.
    #[tokio::test]
    async fn test_token_expiration_enforcement() {
        use std::time::{SystemTime, UNIX_EPOCH};

        println!("Testing token expiration enforcement...");
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let an_hour_ago = now - 3600;
        let an_hour_ahead = now + 3600;

        assert!(
            an_hour_ago < now,
            "Past time should be before current time"
        );
        assert!(
            an_hour_ahead > now,
            "Future time should be after current time"
        );

        // Treat each timestamp as a token `exp` claim.
        let is_expired = an_hour_ago < now;
        let is_valid = an_hour_ahead > now;
        assert!(is_expired, "Token with past expiration should be expired");
        assert!(is_valid, "Token with future expiration should be valid");
        println!("✅ Token expiration enforcement test completed");
    }

    /// Walks 10 attempts against a limit of 5.
    ///
    /// Each attempt prints whether it is allowed or rate limited. The test then
    /// checks that the total exceeds the limit and at least one attempt was
    /// over it.
    #[tokio::test]
    async fn test_rate_limiting_protection() {
        println!("Testing rate limiting protection...");
        let max_attempts = 5;
        let mut attempt_count = 0;

        for attempt in 1..=10 {
            attempt_count += 1;
            if attempt_count > max_attempts {
                println!("Attempt {attempt}: Rate limited (exceeded limit)");
                assert!(
                    attempt_count > max_attempts,
                    "Should be rate limited after {} attempts",
                    max_attempts
                );
            } else {
                println!("Attempt {attempt}: Allowed (within limit)");
            }
        }

        assert!(
            attempt_count > max_attempts,
            "Total attempts should exceed limit"
        );
        let rate_limited_attempts = attempt_count - max_attempts;
        assert!(
            rate_limited_attempts > 0,
            "Some attempts should have been rate limited"
        );
        println!("✅ Rate limiting protection test completed");
    }
}