use crate::error::{ServerError, ServerResult};
use crate::models::{AuthRequest, AuthResponse, RefreshTokenRequest, RefreshTokenResponse, UserInfo};
use sha2::Digest;
use axum::{
extract::{Request, State, FromRequestParts},
http::{header, StatusCode, request::Parts},
middleware::Next,
response::Response,
async_trait,
};
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::collections::HashMap;
use std::sync::Arc;
use argon2::{
Argon2, PasswordHash, PasswordHasher, PasswordVerifier,
password_hash::{rand_core::OsRng, SaltString}
};
use uuid::Uuid;
use base64::{Engine as _, engine::general_purpose};
use percent_encoding;
use sha2;
/// Claims carried in access tokens issued by `AuthManager` (and produced by
/// `OidcUserStore::validate_id_token`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenClaims {
/// Subject: the user's id.
pub sub: String,
/// Display/login name of the user.
pub username: String,
/// User's email address, if known.
pub email: Option<String>,
/// Role names used by the `has_role` / `require_role` style checks.
pub roles: Vec<String>,
/// Tenant the user belongs to; `None` means no tenant (tenant checks then fail).
pub tenant_id: Option<String>,
/// Issued-at time, Unix timestamp in seconds.
pub iat: i64,
/// Expiry time, Unix timestamp in seconds.
pub exp: i64,
/// Unique token id (a random UUID v4 string when generated in this module).
pub jti: String,
}
/// Axum extractor yielding the request's [`TokenClaims`] when an upstream middleware
/// stored them in the request extensions, or `None` otherwise. Never rejects.
#[derive(Debug, Clone)]
pub struct OptionalTokenClaims(pub Option<TokenClaims>);

#[async_trait]
impl<S> FromRequestParts<S> for OptionalTokenClaims
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        // Absence of claims is not an error for this extractor.
        Ok(Self(parts.extensions.get::<TokenClaims>().cloned()))
    }
}
/// Axum extractor that requires [`TokenClaims`] in the request extensions.
///
/// Rejects with `401 Unauthorized` when no claims were stored by an upstream middleware.
#[derive(Debug, Clone)]
pub struct RequiredTokenClaims(pub TokenClaims);

#[async_trait]
impl<S> FromRequestParts<S> for RequiredTokenClaims
where
    S: Send + Sync,
{
    type Rejection = StatusCode;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        match parts.extensions.get::<TokenClaims>() {
            Some(claims) => Ok(Self(claims.clone())),
            None => Err(StatusCode::UNAUTHORIZED),
        }
    }
}
/// Issues and validates HMAC-signed access tokens and delegates credential and
/// refresh-token bookkeeping to a [`UserStore`].
#[derive(Clone)]
pub struct AuthManager {
/// Signing key derived from the shared JWT secret.
encoding_key: EncodingKey,
/// Verification key derived from the same shared JWT secret.
decoding_key: DecodingKey,
/// Lifetime added to `iat` to compute each access token's `exp`.
token_expiration: Duration,
/// Backend used for authentication and refresh-token storage.
user_store: Arc<dyn UserStore>,
}
/// Backend abstraction for user lookup, credential checks and refresh-token storage.
#[async_trait::async_trait]
pub trait UserStore: Send + Sync {
/// Checks the credentials in `request` and returns the matching user on success.
async fn authenticate(&self, request: AuthRequest) -> ServerResult<UserInfo>;
/// Looks up a user by id; `Ok(None)` when the store does not know the user.
async fn get_user(&self, user_id: &str) -> ServerResult<Option<UserInfo>>;
/// Resolves a previously stored refresh token to its user, or errors if it is unknown.
async fn validate_refresh_token(&self, refresh_token: &str) -> ServerResult<UserInfo>;
/// Associates `refresh_token` with `user_id` so it can later be validated.
async fn store_refresh_token(&self, user_id: &str, refresh_token: &str) -> ServerResult<()>;
/// Forgets `refresh_token`; implementations in this file treat unknown tokens as a no-op.
async fn revoke_refresh_token(&self, refresh_token: &str) -> ServerResult<()>;
}
/// [`UserStore`] backed by an external OpenID Connect provider.
#[derive(Clone)]
pub struct OidcUserStore {
/// Provider endpoints and client credentials.
provider_config: OidcProviderConfig,
/// Shared HTTP client used for all provider requests.
client: reqwest::Client,
/// Cache of resolved users, keyed by a prefixed lookup string (see `get_user_info` / `get_user`).
user_cache: Arc<parking_lot::RwLock<HashMap<String, CachedUser>>>,
/// Maps refresh token -> user id.
refresh_tokens: Arc<parking_lot::RwLock<HashMap<String, String>>>, }
/// Static configuration for one OIDC provider / client registration.
#[derive(Debug, Clone)]
pub struct OidcProviderConfig {
/// Issuer base URL; the discovery document URL is derived from it.
pub issuer_url: String,
/// OAuth client id sent on authorization and token requests.
pub client_id: String,
/// OAuth client secret, sent in the token-endpoint form body.
pub client_secret: String,
/// Redirect URI placed in the authorization URL.
pub redirect_uri: String,
/// Scope values configured for this provider.
pub scopes: Vec<String>,
/// When true and a code verifier is supplied, a PKCE `S256` challenge is added.
pub enable_pkce: bool,
/// Token endpoint; code exchange and refresh fail with an internal error when unset.
pub token_endpoint: Option<String>,
/// Authorization endpoint; building the authorization URL fails when unset.
pub authorization_endpoint: Option<String>,
/// UserInfo endpoint; fetching user info fails when unset.
pub userinfo_endpoint: Option<String>,
/// JWKS URI. NOTE(review): not read anywhere in this file, so ID-token signatures are not verified.
pub jwks_uri: Option<String>,
}
/// A user entry in `OidcUserStore::user_cache`.
#[derive(Clone)]
struct CachedUser {
/// The cached user.
user_info: UserInfo,
/// Entries at or past this instant are treated as misses.
expires_at: chrono::DateTime<chrono::Utc>,
}
/// JSON body returned by the provider's token endpoint (code exchange and refresh).
#[derive(Debug, Deserialize)]
pub struct OidcTokenResponse {
/// Access token used as a Bearer credential against the UserInfo endpoint.
pub access_token: String,
/// Token type reported by the provider (not inspected in this file).
pub token_type: String,
/// Provider-reported lifetime; presumably seconds per OAuth 2.0 — TODO confirm.
pub expires_in: Option<u64>,
/// Refresh token, if the provider issued one.
pub refresh_token: Option<String>,
/// OIDC ID token (JWT), if issued.
pub id_token: Option<String>,
/// Granted scope string, if reported.
pub scope: Option<String>,
}
/// JSON body returned by the provider's UserInfo endpoint.
#[derive(Debug, Deserialize)]
pub struct OidcUserInfo {
/// Stable subject identifier; used as the local user id.
pub sub: String,
/// Display name, if provided.
pub name: Option<String>,
/// Email address, if provided.
pub email: Option<String>,
/// Preferred username; falls back to `sub` when mapped to a local user.
pub preferred_username: Option<String>,
/// Group claims, if provided (not mapped to local roles here).
pub groups: Option<Vec<String>>,
/// Role claims, mapped to local roles.
pub roles: Option<Vec<String>>,
}
/// Subset of the provider's OpenID discovery document.
///
/// Returned by `OidcUserStore::discover_endpoints`; its values are not copied into the
/// store's configuration automatically.
#[derive(Debug, Deserialize)]
pub struct OidcDiscoveryDocument {
pub issuer: String,
pub authorization_endpoint: String,
pub token_endpoint: String,
pub userinfo_endpoint: Option<String>,
pub jwks_uri: String,
pub scopes_supported: Option<Vec<String>>,
pub response_types_supported: Option<Vec<String>>,
pub grant_types_supported: Option<Vec<String>>,
pub token_endpoint_auth_methods_supported: Option<Vec<String>>,
}
/// Data received on the OIDC redirect callback, used to exchange the code for tokens.
#[derive(Debug, Clone)]
pub struct OidcAuthRequest {
/// Authorization code returned by the provider.
pub code: String,
/// PKCE verifier matching the challenge sent earlier, if PKCE was used.
pub code_verifier: Option<String>,
/// `state` echoed by the provider. Not read by the token exchange; presumably the caller
/// compares it with the value it issued — confirm at call sites.
pub state: String,
/// Redirect URI sent again to the token endpoint.
pub redirect_uri: String,
}
/// Outcome of a successful authorization-code login.
#[derive(Debug, Clone)]
pub struct OidcAuthResult {
/// Local user mapped from the provider's UserInfo response.
pub user_info: UserInfo,
/// Provider access token.
pub access_token: String,
/// Provider refresh token, if issued (also recorded in the store).
pub refresh_token: Option<String>,
/// Provider-reported access-token lifetime.
pub expires_in: Option<u64>,
/// Provider ID token, if issued.
pub id_token: Option<String>,
}
impl OidcUserStore {
pub fn new(config: OidcProviderConfig) -> Self {
Self {
provider_config: config.clone(),
client: reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("Failed to create HTTP client"),
user_cache: Arc::new(parking_lot::RwLock::new(HashMap::new())),
refresh_tokens: Arc::new(parking_lot::RwLock::new(HashMap::new())),
}
}
pub async fn discover_endpoints(&self) -> ServerResult<OidcDiscoveryDocument> {
let discovery_url = format!("{}/.well-known/openid_configuration", self.provider_config.issuer_url);
let response = self.client
.get(&discovery_url)
.send()
.await
.map_err(|e| ServerError::internal(format!("Failed to fetch discovery document: {}", e)))?;
if !response.status().is_success() {
return Err(ServerError::internal(format!(
"Discovery request failed with status: {}",
response.status()
)));
}
let discovery: OidcDiscoveryDocument = response
.json()
.await
.map_err(|e| ServerError::internal(format!("Failed to parse discovery document: {}", e)))?;
tracing::info!("Discovered OIDC endpoints for issuer: {}", discovery.issuer);
Ok(discovery)
}
pub fn get_authorization_url(&self, _state: &str, code_verifier: Option<&str>) -> ServerResult<String> {
let auth_endpoint = self.provider_config.authorization_endpoint.as_ref()
.ok_or_else(|| ServerError::internal("Authorization endpoint not configured"))?;
let mut params: HashMap<String, String> = std::collections::HashMap::new();
params.insert("response_type".to_string(), "code".to_string());
params.insert("client_id".to_string(), self.provider_config.client_id.clone());
params.insert("redirect_uri".to_string(), self.provider_config.redirect_uri.clone());
if let Some(verifier) = code_verifier {
if self.provider_config.enable_pkce {
let challenge_bytes = sha2::Sha256::digest(verifier.as_bytes());
let challenge = general_purpose::STANDARD.encode(challenge_bytes);
params.insert("code_challenge".to_string(), challenge);
params.insert("code_challenge_method".to_string(), "S256".to_string());
}
}
let query_string = params.iter()
.map(|(k, v)| format!("{}={}", percent_encoding::utf8_percent_encode(k, &percent_encoding::NON_ALPHANUMERIC), percent_encoding::utf8_percent_encode(v, &percent_encoding::NON_ALPHANUMERIC)))
.collect::<Vec<_>>()
.join("&");
Ok(format!("{}?{}", auth_endpoint, query_string))
}
pub async fn exchange_code_for_tokens(&self, request: OidcAuthRequest) -> ServerResult<OidcTokenResponse> {
let token_endpoint = self.provider_config.token_endpoint.as_ref()
.ok_or_else(|| ServerError::internal("Token endpoint not configured"))?;
let mut params = std::collections::HashMap::new();
params.insert("grant_type", "authorization_code");
params.insert("code", &request.code);
params.insert("redirect_uri", &request.redirect_uri);
params.insert("client_id", &self.provider_config.client_id);
params.insert("client_secret", &self.provider_config.client_secret);
if let Some(verifier) = &request.code_verifier {
params.insert("code_verifier", verifier);
}
let response = self.client
.post(token_endpoint)
.form(¶ms)
.send()
.await
.map_err(|e| ServerError::internal(format!("Failed to exchange code for tokens: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let error_text = response.text().await.unwrap_or_default();
return Err(ServerError::internal(format!(
"Token exchange failed: {} - {}",
status,
error_text
)));
}
let token_response: OidcTokenResponse = response
.json()
.await
.map_err(|e| ServerError::internal(format!("Failed to parse token response: {}", e)))?;
Ok(token_response)
}
pub async fn get_user_info(&self, access_token: &str) -> ServerResult<OidcUserInfo> {
let cache_key = format!("userinfo:{}", access_token);
{
let cache = self.user_cache.read();
if let Some(cached) = cache.get(&cache_key) {
if cached.expires_at > chrono::Utc::now() {
return Ok(OidcUserInfo {
sub: cached.user_info.id.clone(),
name: Some(cached.user_info.username.clone()),
email: cached.user_info.email.clone(),
preferred_username: Some(cached.user_info.username.clone()),
groups: None,
roles: Some(cached.user_info.roles.clone()),
});
}
}
}
let userinfo_endpoint = self.provider_config.userinfo_endpoint.as_ref()
.ok_or_else(|| ServerError::internal("UserInfo endpoint not configured"))?;
let response = self.client
.get(userinfo_endpoint)
.header("Authorization", format!("Bearer {}", access_token))
.send()
.await
.map_err(|e| ServerError::internal(format!("Failed to fetch user info: {}", e)))?;
if !response.status().is_success() {
return Err(ServerError::internal(format!(
"User info request failed: {}",
response.status()
)));
}
let user_info: OidcUserInfo = response
.json()
.await
.map_err(|e| ServerError::internal(format!("Failed to parse user info: {}", e)))?;
let expires_at = chrono::Utc::now() + chrono::Duration::minutes(15);
let cached_user = CachedUser {
user_info: UserInfo {
id: user_info.sub.clone(),
username: user_info.preferred_username.clone().unwrap_or_else(|| user_info.sub.clone()),
email: user_info.email.clone(),
roles: user_info.roles.clone().unwrap_or_default(),
tenant_id: None,
},
expires_at,
};
{
let mut cache = self.user_cache.write();
cache.insert(cache_key, cached_user);
}
Ok(user_info)
}
pub fn validate_id_token(&self, id_token: &str) -> ServerResult<TokenClaims> {
let parts: Vec<&str> = id_token.split('.').collect();
if parts.len() != 3 {
return Err(ServerError::auth("Invalid ID token format"));
}
let payload = general_purpose::STANDARD.decode(parts[1])
.map_err(|_| ServerError::auth("Failed to decode ID token payload"))?;
let claims: serde_json::Value = serde_json::from_slice(&payload)
.map_err(|_| ServerError::auth("Failed to parse ID token claims"))?;
let sub = claims.get("sub")
.and_then(|v| v.as_str())
.ok_or_else(|| ServerError::auth("Missing subject claim"))?;
let exp = claims.get("exp")
.and_then(|v| v.as_i64())
.ok_or_else(|| ServerError::auth("Missing expiration claim"))?;
let iat = claims.get("iat")
.and_then(|v| v.as_i64())
.ok_or_else(|| ServerError::auth("Missing issued at claim"))?;
let now = chrono::Utc::now().timestamp();
if exp <= now {
return Err(ServerError::auth("ID token has expired"));
}
Ok(TokenClaims {
sub: sub.to_string(),
username: sub.to_string(),
email: claims.get("email").and_then(|v| v.as_str()).map(|s| s.to_string()),
roles: claims.get("roles")
.and_then(|v| v.as_array())
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(|s| s.to_string())).collect())
.unwrap_or_default(),
tenant_id: None,
iat,
exp,
jti: uuid::Uuid::new_v4().to_string(),
})
}
pub async fn authenticate_with_code(&self, request: OidcAuthRequest) -> ServerResult<OidcAuthResult> {
let token_response = self.exchange_code_for_tokens(request.clone()).await?;
let user_info = self.get_user_info(&token_response.access_token).await?;
let user = UserInfo {
id: user_info.sub.clone(),
username: user_info.preferred_username.clone().unwrap_or_else(|| user_info.sub.clone()),
email: user_info.email.clone(),
roles: user_info.roles.clone().unwrap_or_default(),
tenant_id: None,
};
if let Some(refresh_token) = &token_response.refresh_token {
self.store_refresh_token(&user.id, refresh_token).await?;
}
Ok(OidcAuthResult {
user_info: user,
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
expires_in: token_response.expires_in,
id_token: token_response.id_token,
})
}
pub async fn refresh_access_token(&self, refresh_token: &str) -> ServerResult<OidcTokenResponse> {
let token_endpoint = self.provider_config.token_endpoint.as_ref()
.ok_or_else(|| ServerError::internal("Token endpoint not configured"))?;
let mut params = std::collections::HashMap::new();
params.insert("grant_type", "refresh_token");
params.insert("refresh_token", refresh_token);
params.insert("client_id", &self.provider_config.client_id);
params.insert("client_secret", &self.provider_config.client_secret);
let response = self.client
.post(token_endpoint)
.form(¶ms)
.send()
.await
.map_err(|e| ServerError::internal(format!("Failed to refresh token: {}", e)))?;
if !response.status().is_success() {
return Err(ServerError::internal(format!(
"Token refresh failed: {}",
response.status()
)));
}
let token_response: OidcTokenResponse = response
.json()
.await
.map_err(|e| ServerError::internal(format!("Failed to parse refresh response: {}", e)))?;
Ok(token_response)
}
pub async fn revoke_token(&self, token: &str) -> ServerResult<()> {
self.refresh_tokens.write().remove(token);
Ok(())
}
}
/// Process-local [`UserStore`] holding users and refresh tokens in memory (lost on restart).
pub struct InMemoryUserStore {
/// User records keyed by username.
users: Arc<parking_lot::RwLock<HashMap<String, UserRecord>>>,
/// Maps refresh token -> user id.
refresh_tokens: Arc<parking_lot::RwLock<HashMap<String, String>>>, }
/// A user stored in [`InMemoryUserStore`], including lockout state.
///
/// NOTE(review): all fields are private and no constructor is visible, so records can only
/// be built inside this module even though `add_user` is public.
#[derive(Clone)]
pub struct UserRecord {
/// Stable user id (returned as `UserInfo::id`).
id: String,
/// Login name; also the key in the store's user map.
username: String,
/// Argon2 PHC hash string checked by `verify_password_secure`.
password_hash: String,
email: Option<String>,
roles: Vec<String>,
tenant_id: Option<String>,
/// Consecutive failed logins; reset on success or when a lock lapses.
failed_login_attempts: u32,
/// When set and in the future, login attempts are rejected.
locked_until: Option<chrono::DateTime<chrono::Utc>>,
}
#[async_trait::async_trait]
impl UserStore for OidcUserStore {
    /// Password login is not possible against an OIDC provider; always an auth error.
    async fn authenticate(&self, _request: AuthRequest) -> ServerResult<UserInfo> {
        Err(ServerError::auth("OIDC user store doesn't support username/password authentication"))
    }

    /// Looks the user up in the local cache only (no provider call).
    ///
    /// An unexpired `user:{id}` entry is tried first. Otherwise any unexpired cache entry
    /// whose user id matches is used, so users cached under another key are still found.
    /// Returns `Ok(None)` when nothing live is cached.
    async fn get_user(&self, user_id: &str) -> ServerResult<Option<UserInfo>> {
        let now = chrono::Utc::now();
        let found = {
            let cache = self.user_cache.read();
            let direct = cache
                .get(&format!("user:{}", user_id))
                .filter(|entry| entry.expires_at > now);
            direct
                .or_else(|| {
                    cache
                        .values()
                        .find(|entry| entry.expires_at > now && entry.user_info.id == user_id)
                })
                .map(|entry| entry.user_info.clone())
        };
        Ok(found)
    }

    /// Validates a stored provider refresh token by refreshing it upstream.
    ///
    /// The refreshed identity must match the user id the token was stored for; otherwise an
    /// auth error is returned.
    ///
    /// NOTE(review): `AuthManager::refresh_token` hands clients a locally generated token, not
    /// the provider's, so a second refresh through it will fail upstream. Confirm the flow.
    async fn validate_refresh_token(&self, refresh_token: &str) -> ServerResult<UserInfo> {
        let expected_user_id = {
            let refresh_tokens = self.refresh_tokens.read();
            refresh_tokens
                .get(refresh_token)
                .ok_or_else(|| ServerError::auth("Invalid refresh token"))?
                .clone()
        };
        let token_response = self.refresh_access_token(refresh_token).await?;
        let user_info = self.get_user_info(&token_response.access_token).await?;
        if user_info.sub != expected_user_id {
            return Err(ServerError::auth("Refresh token does not belong to this user"));
        }
        Ok(UserInfo {
            id: user_info.sub.clone(),
            username: user_info.preferred_username.clone().unwrap_or_else(|| user_info.sub.clone()),
            email: user_info.email.clone(),
            roles: user_info.roles.clone().unwrap_or_default(),
            tenant_id: None,
        })
    }

    /// Records `refresh_token` as belonging to `user_id`.
    async fn store_refresh_token(&self, user_id: &str, refresh_token: &str) -> ServerResult<()> {
        self.refresh_tokens
            .write()
            .insert(refresh_token.to_string(), user_id.to_string());
        Ok(())
    }

    /// Forgets `refresh_token`; unknown tokens are ignored.
    async fn revoke_refresh_token(&self, refresh_token: &str) -> ServerResult<()> {
        self.refresh_tokens.write().remove(refresh_token);
        Ok(())
    }
}
impl AuthManager {
    /// Creates a manager that signs and verifies access tokens with the HMAC secret `jwt_secret`.
    ///
    /// `token_expiration` is the lifetime of every issued access token. `user_store` performs
    /// credential checks and refresh-token bookkeeping.
    pub fn new(jwt_secret: &str, token_expiration: Duration, user_store: Arc<dyn UserStore>) -> Self {
        Self {
            encoding_key: EncodingKey::from_secret(jwt_secret.as_bytes()),
            decoding_key: DecodingKey::from_secret(jwt_secret.as_bytes()),
            token_expiration,
            user_store,
        }
    }

    /// Authenticates `request` and issues an access token plus a fresh refresh token.
    ///
    /// # Errors
    ///
    /// Propagates store errors (bad credentials, lockout, storage failures) and token-encoding
    /// failures.
    pub async fn authenticate(&self, request: AuthRequest) -> ServerResult<AuthResponse> {
        let user = self.user_store.authenticate(request).await?;
        let access_token = self.generate_access_token(&user)?;
        let refresh_token = self.generate_refresh_token();
        self.user_store.store_refresh_token(&user.id, &refresh_token).await?;
        Ok(AuthResponse {
            access_token,
            token_type: "Bearer".to_string(),
            expires_in: self.expires_in_secs(),
            refresh_token: Some(refresh_token),
            user,
        })
    }

    /// Rotates a refresh token.
    ///
    /// Validates the old token, issues a new access/refresh pair, stores the new refresh
    /// token, then revokes the old one. The new token is stored before the old is revoked, so
    /// a failure part-way never leaves the user with no valid token.
    pub async fn refresh_token(&self, request: RefreshTokenRequest) -> ServerResult<RefreshTokenResponse> {
        let user = self.user_store.validate_refresh_token(&request.refresh_token).await?;
        let access_token = self.generate_access_token(&user)?;
        let new_refresh_token = self.generate_refresh_token();
        self.user_store.store_refresh_token(&user.id, &new_refresh_token).await?;
        self.user_store.revoke_refresh_token(&request.refresh_token).await?;
        Ok(RefreshTokenResponse {
            access_token,
            token_type: "Bearer".to_string(),
            expires_in: self.expires_in_secs(),
            refresh_token: Some(new_refresh_token),
        })
    }

    /// Verifies an HS256 access token and returns its claims.
    ///
    /// `exp` is enforced with zero leeway, and `nbf` is checked as well.
    ///
    /// # Errors
    ///
    /// Auth error describing why decoding or validation failed.
    pub fn validate_token(&self, token: &str) -> ServerResult<TokenClaims> {
        let mut validation = Validation::default();
        // Pin the algorithm so a token signed with another algorithm is rejected.
        validation.algorithms = vec![jsonwebtoken::Algorithm::HS256];
        validation.validate_exp = true;
        validation.validate_nbf = true;
        validation.leeway = 0;
        let token_data = decode::<TokenClaims>(token, &self.decoding_key, &validation)
            .map_err(|e| ServerError::auth(format!("Invalid token: {}", e)))?;
        Ok(token_data.claims)
    }

    /// Signs an access token for `user`, valid from now for `token_expiration`.
    fn generate_access_token(&self, user: &UserInfo) -> ServerResult<String> {
        let now = Utc::now();
        let claims = TokenClaims {
            sub: user.id.clone(),
            username: user.username.clone(),
            email: user.email.clone(),
            roles: user.roles.clone(),
            tenant_id: user.tenant_id.clone(),
            iat: now.timestamp(),
            exp: (now + self.token_expiration).timestamp(),
            jti: Uuid::new_v4().to_string(),
        };
        // `Header::default()` is HS256, matching the algorithm pinned in `validate_token`.
        encode(&Header::default(), &claims, &self.encoding_key)
            .map_err(|e| ServerError::internal(format!("Failed to generate token: {}", e)))
    }

    /// Returns a 64-character random alphanumeric refresh token drawn from `rand::thread_rng`.
    fn generate_refresh_token(&self) -> String {
        use rand::Rng;
        const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
        const TOKEN_LEN: usize = 64;
        let mut rng = rand::thread_rng();
        (0..TOKEN_LEN)
            .map(|_| CHARSET[rng.gen_range(0..CHARSET.len())] as char)
            .collect()
    }

    /// Access-token lifetime in whole seconds, clamped at zero.
    ///
    /// A plain `as u64` cast would wrap a negative duration to a huge value.
    fn expires_in_secs(&self) -> u64 {
        self.token_expiration.num_seconds().max(0) as u64
    }

    /// True if `claims` grants `required_role` (exact, case-sensitive match).
    pub fn has_role(&self, claims: &TokenClaims, required_role: &str) -> bool {
        claims.roles.iter().any(|role| role == required_role)
    }

    /// True if `claims` grants at least one of `required_roles`; false for an empty list.
    pub fn has_any_role(&self, claims: &TokenClaims, required_roles: &[&str]) -> bool {
        required_roles
            .iter()
            .any(|required| claims.roles.iter().any(|role| role.as_str() == *required))
    }

    /// True only if the claims name exactly `tenant_id`; claims without a tenant never match.
    pub fn has_tenant_access(&self, claims: &TokenClaims, tenant_id: &str) -> bool {
        claims.tenant_id.as_deref() == Some(tenant_id)
    }
}
impl InMemoryUserStore {
pub fn new() -> Self {
Self {
users: Arc::new(parking_lot::RwLock::new(HashMap::new())),
refresh_tokens: Arc::new(parking_lot::RwLock::new(HashMap::new())),
}
}
pub fn with_default_admin() -> Self {
let store = Self::new();
let admin_password = "admin123"; let admin_user = UserRecord {
id: "admin".to_string(),
username: "admin".to_string(),
password_hash: hash_password_secure(admin_password).expect("Failed to hash admin password"),
email: Some("admin@fortress-db.com".to_string()),
roles: vec!["admin".to_string(), "user".to_string()],
tenant_id: None,
failed_login_attempts: 0,
locked_until: None,
};
{
let mut users = store.users.write();
users.insert("admin".to_string(), admin_user);
}
store
}
pub fn add_user(&self, user: UserRecord) {
let mut users = self.users.write();
users.insert(user.username.clone(), user);
}
}
#[async_trait::async_trait]
impl UserStore for InMemoryUserStore {
async fn authenticate(&self, request: AuthRequest) -> ServerResult<UserInfo> {
let mut users = self.users.write();
let user_record = users.get_mut(&request.username)
.ok_or_else(|| ServerError::auth("Invalid username or password"))?;
if let Some(locked_until) = user_record.locked_until {
if chrono::Utc::now() < locked_until {
return Err(ServerError::auth("Account is temporarily locked due to multiple failed login attempts"));
} else {
user_record.locked_until = None;
user_record.failed_login_attempts = 0;
}
}
match verify_password_secure(&request.password, &user_record.password_hash) {
Ok(true) => {
user_record.failed_login_attempts = 0;
user_record.locked_until = None;
Ok(UserInfo {
id: user_record.id.clone(),
username: user_record.username.clone(),
email: user_record.email.clone(),
roles: user_record.roles.clone(),
tenant_id: user_record.tenant_id.clone(),
})
}
Ok(false) => {
user_record.failed_login_attempts += 1;
if user_record.failed_login_attempts >= 5 {
user_record.locked_until = Some(chrono::Utc::now() + chrono::Duration::minutes(30));
}
Err(ServerError::auth("Invalid username or password"))
}
Err(e) => {
tracing::error!("Password verification error: {}", e);
Err(ServerError::auth("Authentication service error"))
}
}
}
async fn get_user(&self, user_id: &str) -> ServerResult<Option<UserInfo>> {
let users = self.users.read();
for user_record in users.values() {
if user_record.id == user_id {
return Ok(Some(UserInfo {
id: user_record.id.clone(),
username: user_record.username.clone(),
email: user_record.email.clone(),
roles: user_record.roles.clone(),
tenant_id: user_record.tenant_id.clone(),
}));
}
}
Ok(None)
}
async fn validate_refresh_token(&self, refresh_token: &str) -> ServerResult<UserInfo> {
let user_id = {
let refresh_tokens = self.refresh_tokens.read();
refresh_tokens.get(refresh_token)
.ok_or_else(|| ServerError::auth("Invalid refresh token"))?
.clone()
};
self.get_user(&user_id).await
.map_err(|_| ServerError::auth("User not found"))?
.ok_or_else(|| ServerError::auth("User not found"))
}
async fn store_refresh_token(&self, user_id: &str, refresh_token: &str) -> ServerResult<()> {
let mut refresh_tokens = self.refresh_tokens.write();
refresh_tokens.insert(refresh_token.to_string(), user_id.to_string());
Ok(())
}
async fn revoke_refresh_token(&self, refresh_token: &str) -> ServerResult<()> {
let mut refresh_tokens = self.refresh_tokens.write();
refresh_tokens.remove(refresh_token);
Ok(())
}
}
/// Hashes `password` with Argon2 default parameters and a fresh random salt from `OsRng`.
///
/// Returns the PHC-format string (algorithm, parameters, salt and hash) for storage.
fn hash_password_secure(password: &str) -> Result<String, argon2::password_hash::Error> {
    let salt = SaltString::generate(&mut OsRng);
    Argon2::default()
        .hash_password(password.as_bytes(), &salt)
        .map(|phc| phc.to_string())
}
/// Checks `password` against a stored PHC hash string.
///
/// `Err` means the stored hash could not be parsed; `Ok(false)` means the password is wrong.
fn verify_password_secure(password: &str, hash: &str) -> Result<bool, argon2::password_hash::Error> {
    let stored = PasswordHash::new(hash)?;
    let outcome = Argon2::default().verify_password(password.as_bytes(), &stored);
    Ok(outcome.is_ok())
}
/// Legacy unsalted SHA-256 of `password`, as a lowercase hex string.
///
/// Kept only to recognise old records; it is not a safe password hash.
#[deprecated(note = "Use hash_password_secure instead")]
fn hash_password_legacy(password: &str) -> String {
    use sha2::{Digest, Sha256};
    let digest = Sha256::digest(password.as_bytes());
    format!("{:x}", digest)
}
/// Checks `password` against a legacy unsalted SHA-256 hex digest.
///
/// Retained only for migrating old records; new code must use `verify_password_secure`.
/// The byte comparison is constant-time: an early-exiting `==` would leak, through timing,
/// how many leading characters of the stored digest matched.
#[deprecated(note = "Use verify_password_secure instead")]
#[allow(deprecated)]
fn verify_password_legacy(password: &str, hash: &str) -> bool {
    let computed = hash_password_legacy(password);
    let (actual, expected) = (computed.as_bytes(), hash.as_bytes());
    if actual.len() != expected.len() {
        return false;
    }
    // OR together all byte differences; the result is zero only if every byte matched.
    actual
        .iter()
        .zip(expected)
        .fold(0u8, |diff, (a, b)| diff | (a ^ b))
        == 0
}
/// Test-only alias that forwards to `hash_password_legacy` and returns its hex digest unchanged.
/// NOTE(review): nothing in this file calls it — candidate for removal.
#[deprecated(note = "Use hash_password_secure instead")]
#[allow(deprecated)]
fn hash_password_legacy_test(password: &str) -> String {
hash_password_legacy(password)
}
/// Axum middleware that requires a valid Bearer access token.
///
/// On success the decoded [`TokenClaims`] are inserted into the request extensions (where
/// `RequiredTokenClaims` / `OptionalTokenClaims` read them) and the request continues.
/// A missing, malformed, non-Bearer, or invalid token yields `401 Unauthorized`.
pub async fn auth_middleware(
    State(auth_manager): State<Arc<AuthManager>>,
    mut request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let token = request
        .headers()
        .get(header::AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(extract_bearer_token)
        .ok_or(StatusCode::UNAUTHORIZED)?;
    let claims = auth_manager
        .validate_token(token)
        .map_err(|_| StatusCode::UNAUTHORIZED)?;
    request.extensions_mut().insert(claims);
    Ok(next.run(request).await)
}

/// Extracts the credential from an `Authorization` header value using the Bearer scheme.
///
/// Auth-scheme names are case-insensitive (RFC 7235 §2.1), so `bearer x` and `BEARER x`
/// are accepted. Surrounding whitespace is ignored. Returns `None` for other schemes or
/// an empty credential.
fn extract_bearer_token(value: &str) -> Option<&str> {
    let (scheme, credential) = value.trim().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("bearer") {
        return None;
    }
    let credential = credential.trim_start();
    if credential.is_empty() {
        None
    } else {
        Some(credential)
    }
}
/// Returns a predicate that is true when the request carries [`TokenClaims`] granting `role`.
///
/// Requests without claims in their extensions never match.
pub fn require_role(role: &'static str) -> impl Fn(&Request) -> bool {
    move |request: &Request| {
        request
            .extensions()
            .get::<TokenClaims>()
            .map_or(false, |claims| claims.roles.iter().any(|granted| granted == role))
    }
}
/// Returns a predicate that is true when the request's [`TokenClaims`] grant at least one of
/// `roles`.
///
/// Requests without claims, or an empty `roles` list, never match.
pub fn require_any_role(roles: &'static [&'static str]) -> impl Fn(&Request) -> bool {
    move |request: &Request| {
        let Some(claims) = request.extensions().get::<TokenClaims>() else {
            return false;
        };
        roles
            .iter()
            .any(|wanted| claims.roles.iter().any(|granted| granted.as_str() == *wanted))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Argon2 hashes verify the right password, reject a wrong one, and are salted.
    #[test]
    fn test_secure_password_hashing() {
        let password = "test123";
        let hash = hash_password_secure(password).unwrap();
        assert!(verify_password_secure(password, &hash).unwrap());
        assert!(!verify_password_secure("wrong", &hash).unwrap());
        let hash2 = hash_password_secure(password).unwrap();
        assert_ne!(hash, hash2);
    }

    /// The deprecated SHA-256 helpers still round-trip. They are exercised deliberately,
    /// hence the lint allowance.
    #[test]
    #[allow(deprecated)]
    fn test_legacy_password_hashing() {
        let password = "test123";
        let hash = hash_password_legacy(password);
        assert!(verify_password_legacy(password, &hash));
        assert!(!verify_password_legacy("wrong", &hash));
    }

    /// The seeded admin account can log in with the default credentials.
    #[tokio::test]
    async fn test_in_memory_user_store() {
        let store = InMemoryUserStore::with_default_admin();
        let auth_request = AuthRequest {
            username: "admin".to_string(),
            password: "admin123".to_string(),
            tenant_id: None,
        };
        let user = store.authenticate(auth_request).await.unwrap();
        assert_eq!(user.username, "admin");
        assert!(user.roles.contains(&"admin".to_string()));
    }

    /// End to end: log in as the seeded admin, then validate the issued JWT.
    #[tokio::test]
    async fn test_token_generation() {
        // `InMemoryUserStore::new()` has no users, so logging in as admin would always fail.
        // Seed the default admin so the credentials below exist.
        let store = Arc::new(InMemoryUserStore::with_default_admin());
        let auth_manager = AuthManager::new("test_secret", Duration::hours(1), store);
        let auth_request = AuthRequest {
            username: "admin".to_string(),
            password: "admin123".to_string(),
            tenant_id: None,
        };
        let auth_response = auth_manager.authenticate(auth_request).await.unwrap();
        assert!(!auth_response.access_token.is_empty());
        assert_eq!(auth_response.token_type, "Bearer");
        let claims = auth_manager.validate_token(&auth_response.access_token).unwrap();
        assert_eq!(claims.username, "admin");
    }
}
#[cfg(test)]
mod auth_security_tests {
    use super::*;

    /// Hashing the same password twice yields distinct, individually verifiable PHC strings.
    #[test]
    fn test_secure_password_hashing() {
        let secret = "test123";
        let first = hash_password_secure(secret).unwrap();

        let accepts_correct = verify_password_secure(secret, &first).unwrap();
        assert!(accepts_correct);
        let accepts_wrong = verify_password_secure("wrong", &first).unwrap();
        assert!(!accepts_wrong);

        let second = hash_password_secure(secret).unwrap();
        assert_ne!(first, second);
        println!("✅ Secure password hashing test passed");
    }

    /// The default Argon2 configuration produces an Argon2id PHC string.
    #[test]
    fn test_argon2id_security() {
        let phc = hash_password_secure("secure_password_123!").unwrap();
        assert!(phc.starts_with("$argon2id$"));
        assert!(phc.len() > 50);
        assert!(phc.contains('$'));
        println!("✅ Argon2id security test passed");
    }
}