use crate::{AuthError, AuthFramework, providers::ProviderProfile};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Settings that drive [`UnifiedAuthValidator`]: which paths bypass auth, which
/// roles/permissions are required, and where tokens may be looked for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthMiddlewareConfig {
    /// Paths that bypass authentication. `UnifiedAuthValidator::should_skip_path`
    /// matches a request path that equals an entry, or starts with the entry
    /// followed by `/`.
    pub skip_paths: Vec<String>,
    /// Roles the user must hold (every one of them) for `validate_access` to succeed.
    pub required_roles: Vec<String>,
    /// Permissions that must all be present in the user's effective permissions
    /// for `validate_access` to succeed.
    pub required_permissions: Vec<String>,
    /// Name of the cookie expected to carry the token.
    /// NOTE(review): not read by the extractors in this file; presumably consumed
    /// by framework-specific middleware — confirm.
    pub cookie_name: String,
    /// Name of the header expected to carry the token.
    /// NOTE(review): not read in this file; presumably consumed elsewhere — confirm.
    pub header_name: String,
    /// When `false`, `extract_token_from_query` always returns `None`.
    pub allow_query_param: bool,
    /// Name of the query parameter that may carry the token.
    /// NOTE(review): not read in this file; presumably consumed elsewhere — confirm.
    pub query_param_name: String,
}
impl Default for AuthMiddlewareConfig {
    /// Stock configuration.
    ///
    /// The health check and the login/register endpoints are reachable without a
    /// token. No roles or permissions are required. The names used are the
    /// `auth_token` cookie, the `Authorization` header and the `token` query
    /// parameter. Query-string tokens are disabled.
    fn default() -> Self {
        // Endpoints that must stay reachable before a caller has a token.
        let public_endpoints = ["/health", "/api/v1/auth/login", "/api/v1/auth/register"];
        Self {
            skip_paths: public_endpoints.iter().map(|p| String::from(*p)).collect(),
            required_roles: vec![],
            required_permissions: vec![],
            cookie_name: String::from("auth_token"),
            header_name: String::from("Authorization"),
            allow_query_param: false,
            query_param_name: String::from("token"),
        }
    }
}
/// Framework-agnostic request authenticator.
///
/// It pulls a token out of request parts (header, cookie or query value), validates
/// it through [`AuthFramework`], and enforces the role/permission requirements in
/// its [`AuthMiddlewareConfig`].
pub struct UnifiedAuthValidator {
    /// Framework used for JWT validation, profile lookup and role/permission queries.
    auth_framework: Arc<AuthFramework>,
    /// Paths to skip, access requirements and token-source settings.
    config: AuthMiddlewareConfig,
}
impl UnifiedAuthValidator {
    /// Creates a validator that checks tokens against `auth_framework` and enforces
    /// the requirements in `config`.
    pub fn new(auth_framework: Arc<AuthFramework>, config: AuthMiddlewareConfig) -> Self {
        Self {
            auth_framework,
            config,
        }
    }

    /// Creates a validator using [`AuthMiddlewareConfig::default`] settings.
    pub fn with_defaults(auth_framework: Arc<AuthFramework>) -> Self {
        Self::new(auth_framework, AuthMiddlewareConfig::default())
    }

    /// Trims `raw` and returns it as an owned token, or `None` if nothing is left.
    ///
    /// An empty credential counts as "no credential". Callers then fall through to
    /// their missing-token handling instead of sending `""` to JWT validation.
    fn non_empty_token(raw: &str) -> Option<String> {
        let token = raw.trim();
        (!token.is_empty()).then(|| token.to_string())
    }

    /// Extracts a token from an `Authorization`-style header value.
    ///
    /// * `Bearer <token>` yields `<token>`, with surrounding whitespace trimmed.
    ///   The `Bearer ` prefix is matched case-sensitively.
    /// * A value containing no spaces is treated as a bare token.
    /// * The result is `None` for any other value (e.g. `Basic ...` or
    ///   `bearer ...`), for a missing header, and for an empty or whitespace-only
    ///   token.
    pub fn extract_token_from_header(&self, auth_header: Option<&str>) -> Option<String> {
        let header_value = auth_header?;
        if let Some(token) = header_value.strip_prefix("Bearer ") {
            return Self::non_empty_token(token);
        }
        if header_value.contains(' ') {
            // Some other scheme whose credentials this validator does not interpret.
            return None;
        }
        Self::non_empty_token(header_value)
    }

    /// Returns the (already extracted) cookie value as a token.
    ///
    /// The value is trimmed. A missing, empty or whitespace-only value yields `None`.
    pub fn extract_token_from_cookie(&self, cookie_value: Option<&str>) -> Option<String> {
        cookie_value.and_then(Self::non_empty_token)
    }

    /// Returns the (already extracted) query-parameter value as a token.
    ///
    /// The result is `None` when `allow_query_param` is disabled, and also for a
    /// missing, empty or whitespace-only value.
    pub fn extract_token_from_query(&self, query_value: Option<&str>) -> Option<String> {
        if !self.config.allow_query_param {
            // Query-string tokens are opt-in only.
            return None;
        }
        query_value.and_then(Self::non_empty_token)
    }

    /// Returns `true` if `path` should bypass authentication.
    ///
    /// A path matches a `skip_paths` entry when it equals the entry, or when it
    /// starts with the entry followed by `/`. For example, `/health` matches
    /// `/health` and `/health/live` but not `/healthcheck`. Empty entries are
    /// ignored.
    pub fn should_skip_path(&self, path: &str) -> bool {
        self.config.skip_paths.iter().any(|skip_path| {
            // An empty entry would act as the prefix "/" and match every absolute
            // path, silently disabling auth. Treat it as matching nothing instead.
            if skip_path.is_empty() {
                return false;
            }
            match path.strip_prefix(skip_path.as_str()) {
                Some(rest) => rest.is_empty() || rest.starts_with('/'),
                None => false,
            }
        })
    }

    /// Validates `token` as a JWT and loads the profile of its `sub` claim.
    ///
    /// # Errors
    /// Propagates any error from JWT validation or from the profile lookup.
    pub async fn validate_token(&self, token: &str) -> Result<ProviderProfile, AuthError> {
        let jwt_claims = self
            .auth_framework
            .token_manager()
            .validate_jwt_token(token)?;
        let user_profile = self
            .auth_framework
            .get_user_profile(&jwt_claims.sub)
            .await?;
        Ok(user_profile)
    }

    /// Checks that `user_id` holds every configured role and permission.
    ///
    /// # Errors
    /// * Any error from the user-profile lookup is propagated unchanged.
    /// * `AuthError::Permission(InsufficientPermissions)` names the first missing
    ///   role (as `role:<name>`) or the first missing permission.
    ///
    /// Failures while looking up roles or permissions count as "not granted". An
    /// infrastructure error therefore denies access (fails closed) instead of being
    /// returned as-is.
    pub async fn validate_access(&self, user_id: &str) -> Result<(), AuthError> {
        // Looked up only so that a lookup error (presumably including an unknown
        // user) aborts before any role/permission checks.
        let _user_profile = self.auth_framework.get_user_profile(user_id).await?;

        for role in &self.config.required_roles {
            let has_role = self
                .auth_framework
                .user_has_role(user_id, role)
                .await
                .unwrap_or(false);
            if !has_role {
                return Err(AuthError::Permission(
                    crate::errors::PermissionError::InsufficientPermissions {
                        required: format!("role:{}", role),
                        // NOTE(review): the user's actual roles are not queried, so
                        // this reports "none" even if the user holds other roles.
                        actual: "none".to_string(),
                    },
                ));
            }
        }

        // Guarded so the permission lookup is skipped when nothing is required.
        if !self.config.required_permissions.is_empty() {
            let effective_perms = self
                .auth_framework
                .get_effective_permissions(user_id)
                .await
                .unwrap_or_default();
            for required_perm in &self.config.required_permissions {
                if !effective_perms.contains(required_perm) {
                    return Err(AuthError::Permission(
                        crate::errors::PermissionError::InsufficientPermissions {
                            required: required_perm.clone(),
                            // NOTE(review): this exposes the user's full permission
                            // list. Confirm this error is not returned to clients
                            // verbatim.
                            actual: effective_perms.join(", "),
                        },
                    ));
                }
            }
        }
        Ok(())
    }
}
/// Fluent builder for [`UnifiedAuthValidator`].
///
/// It starts from [`AuthMiddlewareConfig::default`]; each setter adjusts one
/// aspect of that configuration.
pub struct UnifiedAuthBuilder {
    /// Framework handed to the validator on `build`.
    auth_framework: Arc<AuthFramework>,
    /// Configuration being assembled.
    config: AuthMiddlewareConfig,
}
impl UnifiedAuthBuilder {
pub fn new(auth_framework: Arc<AuthFramework>) -> Self {
Self {
auth_framework,
config: AuthMiddlewareConfig::default(),
}
}
pub fn skip_paths(mut self, paths: Vec<String>) -> Self {
self.config.skip_paths.extend(paths);
self
}
pub fn require_roles(mut self, roles: Vec<String>) -> Self {
self.config.required_roles = roles;
self
}
pub fn require_permissions(mut self, permissions: Vec<String>) -> Self {
self.config.required_permissions = permissions;
self
}
pub fn cookie_name(mut self, name: String) -> Self {
self.config.cookie_name = name;
self
}
pub fn header_name(mut self, name: String) -> Self {
self.config.header_name = name;
self
}
pub fn allow_query_param(mut self, param_name: String) -> Self {
self.config.allow_query_param = true;
self.config.query_param_name = param_name;
self
}
pub fn build(self) -> UnifiedAuthValidator {
UnifiedAuthValidator::new(self.auth_framework, self.config)
}
}
pub fn create_auth_validator(auth_framework: Arc<AuthFramework>) -> UnifiedAuthValidator {
UnifiedAuthValidator::with_defaults(auth_framework)
}
pub fn auth_validator_builder(auth_framework: Arc<AuthFramework>) -> UnifiedAuthBuilder {
UnifiedAuthBuilder::new(auth_framework)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a validator around a throwaway framework instance. The synchronous
    /// helpers can then be exercised through their real `&self` API.
    fn validator_with(config: AuthMiddlewareConfig) -> UnifiedAuthValidator {
        UnifiedAuthValidator::new(
            Arc::new(AuthFramework::new(
                crate::config::AuthConfig::default().secret("test-secret-for-unified-tests"),
            )),
            config,
        )
    }

    /// Validator with the stock configuration.
    fn default_validator() -> UnifiedAuthValidator {
        validator_with(AuthMiddlewareConfig::default())
    }

    #[test]
    fn test_default_config() {
        let config = AuthMiddlewareConfig::default();
        assert_eq!(config.cookie_name, "auth_token");
        assert_eq!(config.header_name, "Authorization");
        assert!(!config.allow_query_param);
        assert!(config.skip_paths.contains(&"/health".to_string()));
    }

    #[test]
    fn test_config_builder() {
        let config = AuthMiddlewareConfig {
            skip_paths: vec!["/api/public".to_string()],
            required_roles: vec!["admin".to_string()],
            cookie_name: "session_token".to_string(),
            ..Default::default()
        };
        assert_eq!(config.cookie_name, "session_token");
        assert!(config.required_roles.contains(&"admin".to_string()));
        assert!(config.skip_paths.contains(&"/api/public".to_string()));
    }

    #[test]
    fn test_extract_bearer_token_from_header() {
        let v = default_validator();
        assert_eq!(
            v.extract_token_from_header(Some("Bearer abc123")),
            Some("abc123".to_string())
        );
    }

    #[test]
    fn test_extract_token_bare_token_without_scheme() {
        // A header value with no scheme and no spaces is accepted as a raw token.
        let v = default_validator();
        assert_eq!(
            v.extract_token_from_header(Some("abc123")),
            Some("abc123".to_string())
        );
    }

    #[test]
    fn test_extract_token_other_scheme_rejected() {
        let v = default_validator();
        assert_eq!(v.extract_token_from_header(Some("Basic dXNlcjpwYXNz")), None);
    }

    #[test]
    fn test_extract_token_missing_header() {
        assert_eq!(default_validator().extract_token_from_header(None), None);
    }

    #[test]
    fn test_extract_token_case_sensitive_bearer() {
        let v = default_validator();
        assert_eq!(v.extract_token_from_header(Some("bearer abc123")), None);
    }

    #[test]
    fn test_extract_token_from_cookie_present() {
        let v = default_validator();
        assert_eq!(
            v.extract_token_from_cookie(Some("mytoken")),
            Some("mytoken".to_string())
        );
    }

    #[test]
    fn test_extract_token_from_cookie_missing() {
        assert_eq!(default_validator().extract_token_from_cookie(None), None);
    }

    #[test]
    fn test_extract_token_from_query_disabled_by_default() {
        let v = default_validator();
        assert_eq!(v.extract_token_from_query(Some("mytoken")), None);
    }

    #[test]
    fn test_extract_token_from_query_enabled() {
        let v = validator_with(AuthMiddlewareConfig {
            allow_query_param: true,
            ..Default::default()
        });
        assert_eq!(
            v.extract_token_from_query(Some("mytoken")),
            Some("mytoken".to_string())
        );
        assert_eq!(v.extract_token_from_query(None), None);
    }

    #[test]
    fn test_should_skip_path_exact_match() {
        let validator = validator_with(AuthMiddlewareConfig {
            skip_paths: vec!["/health".to_string(), "/ready".to_string()],
            ..Default::default()
        });
        assert!(validator.should_skip_path("/health"));
        assert!(validator.should_skip_path("/ready"));
        assert!(!validator.should_skip_path("/api/users"));
    }

    #[test]
    fn test_should_skip_path_prefix_match() {
        let validator = validator_with(AuthMiddlewareConfig {
            skip_paths: vec!["/health".to_string()],
            ..Default::default()
        });
        assert!(validator.should_skip_path("/health/live"));
        // A shared string prefix without a path separator must not match.
        assert!(!validator.should_skip_path("/healthcheck"));
    }

    #[test]
    fn test_should_skip_path_empty_list() {
        let validator = validator_with(AuthMiddlewareConfig {
            skip_paths: vec![],
            ..Default::default()
        });
        assert!(!validator.should_skip_path("/anything"));
    }

    #[test]
    fn test_builder_chain() {
        let _env = crate::testing::test_infrastructure::TestEnvironmentGuard::new()
            .with_jwt_secret("test-secret-builder");
        let fw = Arc::new(AuthFramework::new(
            crate::config::AuthConfig::default().secret("test-secret-builder"),
        ));
        let validator = UnifiedAuthBuilder::new(fw)
            .skip_paths(vec!["/public".into()])
            .require_roles(vec!["admin".into()])
            .require_permissions(vec!["read:all".into()])
            .cookie_name("my_cookie".into())
            .header_name("X-Token".into())
            .allow_query_param("api_key".into())
            .build();
        assert!(validator.config.skip_paths.contains(&"/public".to_string()));
        assert!(validator.config.required_roles.contains(&"admin".to_string()));
        assert!(validator
            .config
            .required_permissions
            .contains(&"read:all".to_string()));
        assert_eq!(validator.config.cookie_name, "my_cookie");
        assert_eq!(validator.config.header_name, "X-Token");
        assert!(validator.config.allow_query_param);
        assert_eq!(validator.config.query_param_name, "api_key");
    }

    #[test]
    fn test_convenience_functions() {
        let _env = crate::testing::test_infrastructure::TestEnvironmentGuard::new()
            .with_jwt_secret("test-secret-convenience");
        let fw = Arc::new(AuthFramework::new(
            crate::config::AuthConfig::default().secret("test-secret-convenience"),
        ));
        let v = create_auth_validator(fw.clone());
        assert_eq!(v.config.cookie_name, "auth_token");
        let b = auth_validator_builder(fw);
        let v2 = b.build();
        assert_eq!(v2.config.header_name, "Authorization");
    }
}