use crate::{AuthError, AuthResult};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A principal that can be identified and have its credentials checked.
///
/// Implementors must provide [`id`](Authenticatable::id),
/// [`username`](Authenticatable::username) and
/// [`verify_credentials`](Authenticatable::verify_credentials); every other
/// method ships with a default that may be overridden.
#[async_trait]
pub trait Authenticatable: Send + Sync + Clone {
    /// Type of the identifier handed out by [`id`](Authenticatable::id).
    type Id: Clone + Send + Sync + std::fmt::Debug + PartialEq;

    /// Type of the credentials accepted by
    /// [`verify_credentials`](Authenticatable::verify_credentials).
    type Credentials: Send + Sync;

    /// Borrows this principal's identifier.
    fn id(&self) -> &Self::Id;

    /// Borrows this principal's username.
    fn username(&self) -> &str;

    /// Reports whether the account is active.
    ///
    /// The default implementation always answers `true`.
    fn is_active(&self) -> bool {
        true
    }

    /// Reports whether the account is locked.
    ///
    /// The default implementation always answers `false`.
    fn is_locked(&self) -> bool {
        false
    }

    /// Lists the role names attached to this principal.
    ///
    /// The default implementation returns an empty list.
    fn roles(&self) -> Vec<String> {
        Vec::new()
    }

    /// Lists the permission names attached to this principal.
    ///
    /// The default implementation returns an empty list.
    fn permissions(&self) -> Vec<String> {
        Vec::new()
    }

    /// Checks `credentials` against this principal.
    ///
    /// Presumably `Ok(true)` means the credentials match and `Ok(false)` means
    /// they do not, with `Err` reserved for failures of the check itself —
    /// confirm against implementors.
    async fn verify_credentials(&self, credentials: &Self::Credentials) -> AuthResult<bool>;

    /// Returns arbitrary extra key/value data associated with this principal.
    ///
    /// The default implementation returns an empty map.
    fn additional_data(&self) -> HashMap<String, serde_json::Value> {
        HashMap::new()
    }
}
/// A backend that turns credentials into a token and a token back into a user.
///
/// `User` is the principal type this provider produces.
#[async_trait]
pub trait AuthProvider<User>: Send + Sync
where
    User: Authenticatable,
{
    /// Token type issued and consumed by this provider.
    type Token: Clone + Send + Sync + std::fmt::Debug;

    /// Credential type accepted by [`authenticate`](AuthProvider::authenticate).
    type Credentials: Send + Sync;

    /// Authenticates `credentials`, yielding the user together with the issued
    /// token and related data on success.
    async fn authenticate(
        &self,
        credentials: &Self::Credentials,
    ) -> AuthResult<AuthenticationResult<User, Self::Token>>;

    /// Resolves `token` to the user it belongs to.
    async fn validate_token(&self, token: &Self::Token) -> AuthResult<User>;

    /// Exchanges `_token` for a new token.
    ///
    /// The default implementation does not support refreshing and always
    /// returns the `AuthError::token_error` built below.
    async fn refresh_token(&self, _token: &Self::Token) -> AuthResult<Self::Token> {
        let unsupported = AuthError::token_error("Token refresh not supported");
        Err(unsupported)
    }

    /// Revokes `_token`.
    ///
    /// The default implementation does nothing and reports success.
    // NOTE(review): a provider that keeps the default will report `Ok(())`
    // without invalidating anything — confirm callers do not rely on it.
    async fn revoke_token(&self, _token: &Self::Token) -> AuthResult<()> {
        Ok(())
    }

    /// Returns the name of this provider.
    fn provider_name(&self) -> &str;
}
/// Payload carried in the `Ok` variant of `AuthProvider::authenticate`.
///
/// Generic over the user type and the token type so that each provider can
/// plug in its own.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthenticationResult<User, Token> {
/// The user the credentials resolved to.
pub user: User,
/// The token issued for this authentication.
pub token: Token,
/// Optional second token, of the same type as `token`; presumably used to
/// obtain a fresh `token` later — confirm against providers.
pub refresh_token: Option<Token>,
/// Whether a multi-factor step is required.
// NOTE(review): whether `token` is usable while this is `true` is not
// visible here — confirm with the provider implementations.
pub requires_mfa: bool,
/// MFA enrolment data, when the provider supplies any.
pub mfa_setup: Option<MfaSetup>,
/// Expiry instant in UTC, if any; presumably applies to `token` — confirm.
pub expires_at: Option<DateTime<Utc>>,
/// Free-form provider-specific key/value data.
pub metadata: HashMap<String, serde_json::Value>,
}
/// Data produced when multi-factor authentication is set up for a user.
// NOTE(review): the derived `Debug` and `Serialize` impls include `secret`
// and `backup_codes` verbatim, so formatting or serializing this value
// exposes them — confirm that this never reaches logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MfaSetup {
/// The MFA secret, as a string; encoding is not specified here.
pub secret: String,
/// URL for a QR code; presumably encodes `secret` for authenticator apps —
/// confirm against the `MfaProvider` implementation.
pub qr_code_url: String,
/// Backup codes issued alongside the secret.
pub backup_codes: Vec<String>,
}
/// Answers role and permission questions about a user.
///
/// Implementors supply the single-item checks and the listing methods; the
/// `has_any_*` / `has_all_*` helpers are built on top of them and query one
/// item at a time, stopping at the first decisive answer or the first error.
#[async_trait]
pub trait AuthorizationProvider: Send + Sync {
    /// User type the checks are performed against.
    type User: Authenticatable;

    /// Role representation returned by
    /// [`get_user_roles`](AuthorizationProvider::get_user_roles).
    type Role: Send + Sync + Clone;

    /// Permission representation returned by
    /// [`get_user_permissions`](AuthorizationProvider::get_user_permissions).
    type Permission: Send + Sync + Clone;

    /// Reports whether `user` holds the role named `role`.
    async fn has_role(&self, user: &Self::User, role: &str) -> AuthResult<bool>;

    /// Reports whether `user` holds the permission named `permission`.
    async fn has_permission(&self, user: &Self::User, permission: &str) -> AuthResult<bool>;

    /// Reports whether `user` may perform `action` on `resource`, optionally
    /// taking extra `context` key/value data into account.
    async fn has_permission_with_context(
        &self,
        user: &Self::User,
        resource: &str,
        action: &str,
        context: Option<&HashMap<String, serde_json::Value>>,
    ) -> AuthResult<bool>;

    /// Reports whether `user` holds at least one of `roles`.
    ///
    /// Roles are checked in order; an empty slice yields `Ok(false)`.
    async fn has_any_role(&self, user: &Self::User, roles: &[String]) -> AuthResult<bool> {
        for candidate in roles.iter() {
            let held = self.has_role(user, candidate).await?;
            if held {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Reports whether `user` holds every one of `roles`.
    ///
    /// Roles are checked in order; an empty slice yields `Ok(true)`.
    async fn has_all_roles(&self, user: &Self::User, roles: &[String]) -> AuthResult<bool> {
        for required in roles.iter() {
            let held = self.has_role(user, required).await?;
            if !held {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Reports whether `user` holds at least one of `permissions`.
    ///
    /// Permissions are checked in order; an empty slice yields `Ok(false)`.
    async fn has_any_permission(
        &self,
        user: &Self::User,
        permissions: &[String],
    ) -> AuthResult<bool> {
        for candidate in permissions.iter() {
            let granted = self.has_permission(user, candidate).await?;
            if granted {
                return Ok(true);
            }
        }
        Ok(false)
    }

    /// Reports whether `user` holds every one of `permissions`.
    ///
    /// Permissions are checked in order; an empty slice yields `Ok(true)`.
    async fn has_all_permissions(
        &self,
        user: &Self::User,
        permissions: &[String],
    ) -> AuthResult<bool> {
        for required in permissions.iter() {
            let granted = self.has_permission(user, required).await?;
            if !granted {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Lists the roles of `user`.
    async fn get_user_roles(&self, user: &Self::User) -> AuthResult<Vec<Self::Role>>;

    /// Lists the permissions of `user`.
    async fn get_user_permissions(&self, user: &Self::User) -> AuthResult<Vec<Self::Permission>>;
}
/// Contract for a backend that stores sessions together with an expiry time.
///
/// Only signatures are declared here; the notes below describe what the
/// signatures show, and anything beyond that is marked as an assumption.
#[async_trait]
pub trait SessionStorage: Send + Sync {
/// Identifier type returned when a session is created.
type SessionId: Clone + Send + Sync + std::fmt::Debug + PartialEq;
/// Payload type stored for each session.
type SessionData: Clone + Send + Sync;
/// Stores `data` with the given expiry and returns the new session's id.
async fn create_session(
&self,
data: Self::SessionData,
expires_at: DateTime<Utc>,
) -> AuthResult<Self::SessionId>;
/// Looks up the data stored under `id`.
///
/// Presumably `Ok(None)` means no such session exists (or it has expired) —
/// confirm against implementors.
async fn get_session(&self, id: &Self::SessionId) -> AuthResult<Option<Self::SessionData>>;
/// Replaces both the data and the expiry of the session `id`.
async fn update_session(
&self,
id: &Self::SessionId,
data: Self::SessionData,
expires_at: DateTime<Utc>,
) -> AuthResult<()>;
/// Removes the session `id`.
async fn delete_session(&self, id: &Self::SessionId) -> AuthResult<()>;
/// Removes expired sessions.
///
/// The returned `u64` is presumably the number of sessions removed —
/// confirm against implementors.
async fn cleanup_expired_sessions(&self) -> AuthResult<u64>;
/// Returns the expiry recorded for the session `id`, if any.
async fn get_session_expiry(&self, id: &Self::SessionId) -> AuthResult<Option<DateTime<Utc>>>;
/// Sets a new expiry for the session `id` without supplying new data.
async fn extend_session(
&self,
id: &Self::SessionId,
expires_at: DateTime<Utc>,
) -> AuthResult<()>;
}
/// Contract for a multi-factor authentication backend.
///
/// Only signatures are declared here; semantics beyond what the signatures
/// show are marked as assumptions.
#[async_trait]
pub trait MfaProvider: Send + Sync {
/// User type MFA is managed for.
type User: Authenticatable;
/// Secret type consumed by `verify_code`.
type Secret: Clone + Send + Sync;
/// Code type consumed by `verify_code`.
type Code: Send + Sync;
/// Starts MFA enrolment for `user` and returns the resulting `MfaSetup`.
async fn setup_mfa(&self, user: &Self::User) -> AuthResult<MfaSetup>;
/// Checks `code` for `user` against `secret`.
///
/// Presumably `Ok(true)` means the code is valid — confirm against
/// implementors.
async fn verify_code(
&self,
user: &Self::User,
code: &Self::Code,
secret: &Self::Secret,
) -> AuthResult<bool>;
/// Produces backup codes for `user`.
// NOTE(review): whether previously issued codes are invalidated is not
// visible from this signature — confirm.
async fn generate_backup_codes(&self, user: &Self::User) -> AuthResult<Vec<String>>;
/// Checks a backup `code` for `user`.
// NOTE(review): whether a matching code is consumed (single use) is not
// visible from this signature — confirm.
async fn verify_backup_code(&self, user: &Self::User, code: &str) -> AuthResult<bool>;
/// Reports whether MFA is enabled for `user`.
async fn is_mfa_enabled(&self, user: &Self::User) -> AuthResult<bool>;
/// Turns MFA off for `user`.
async fn disable_mfa(&self, user: &Self::User) -> AuthResult<()>;
}
/// Contract for a password hashing scheme.
///
/// Unlike the other provider traits in this file, the methods here are
/// synchronous.
pub trait PasswordHasher: Send + Sync {
/// Hashes `password` and returns the hash as a string.
fn hash_password(&self, password: &str) -> AuthResult<String>;
/// Checks `password` against a previously produced `hash`.
///
/// Presumably `Ok(true)` means the password matches — confirm against
/// implementors.
fn verify_password(&self, password: &str, hash: &str) -> AuthResult<bool>;
/// Returns the name of this hasher.
fn hasher_name(&self) -> &str;
}
/// Serializable snapshot of an authenticated user, with every identifier,
/// role and permission held as a plain string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserContext {
/// Identifier of the user, as a string.
pub user_id: String,
/// Username of the user.
pub username: String,
/// Role names held by the user; `has_role` matches them by exact string
/// equality.
pub roles: Vec<String>,
/// Permission names held by the user; `has_permission` matches them by
/// exact string equality.
pub permissions: Vec<String>,
/// Name of the provider that authenticated the user.
pub auth_provider: String,
/// Instant of authentication in UTC; `UserContext::new` sets it to the
/// current time.
pub authenticated_at: DateTime<Utc>,
/// Expiry instant in UTC; `None` means `is_expired` never reports expiry.
pub expires_at: Option<DateTime<Utc>>,
/// Free-form extra key/value data.
pub additional_data: HashMap<String, serde_json::Value>,
}
impl UserContext {
    /// Creates a context for a freshly authenticated user.
    ///
    /// The context starts with no roles, no permissions, no extra data and no
    /// expiry; `authenticated_at` is set to the current UTC time.
    pub fn new(user_id: String, username: String, auth_provider: String) -> Self {
        Self {
            user_id,
            username,
            roles: Vec::new(),
            permissions: Vec::new(),
            auth_provider,
            authenticated_at: Utc::now(),
            expires_at: None,
            additional_data: HashMap::new(),
        }
    }

    /// Returns `true` when `role` is one of this user's roles.
    ///
    /// The comparison is an exact, case-sensitive string match.
    pub fn has_role(&self, role: &str) -> bool {
        // Compare as `&str` so no temporary `String` is allocated per lookup.
        self.roles.iter().any(|held| held.as_str() == role)
    }

    /// Returns `true` when `permission` is one of this user's permissions.
    ///
    /// The comparison is an exact, case-sensitive string match.
    pub fn has_permission(&self, permission: &str) -> bool {
        // Compare as `&str` so no temporary `String` is allocated per lookup.
        self.permissions
            .iter()
            .any(|held| held.as_str() == permission)
    }

    /// Returns `true` when the user holds at least one of `roles`.
    ///
    /// An empty slice yields `false`.
    pub fn has_any_role(&self, roles: &[String]) -> bool {
        roles.iter().any(|role| self.has_role(role))
    }

    /// Returns `true` when the user holds every one of `roles`.
    ///
    /// An empty slice yields `true`.
    pub fn has_all_roles(&self, roles: &[String]) -> bool {
        roles.iter().all(|role| self.has_role(role))
    }

    /// Returns `true` when an expiry is set and the current UTC time is
    /// strictly later than it; a context without an expiry never expires.
    pub fn is_expired(&self) -> bool {
        self.expires_at.is_some_and(|exp| Utc::now() > exp)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the context every test in this module starts from.
    fn sample_context() -> UserContext {
        UserContext::new(
            String::from("123"),
            String::from("user@example.com"),
            String::from("jwt"),
        )
    }

    /// Turns string literals into the owned `String`s the role helpers expect.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_user_context_creation() {
        let ctx = sample_context();

        assert_eq!(ctx.user_id, "123");
        assert_eq!(ctx.username, "user@example.com");
        assert_eq!(ctx.auth_provider, "jwt");
        assert!(ctx.roles.is_empty());
        assert!(ctx.permissions.is_empty());
    }

    #[test]
    fn test_user_context_role_checking() {
        let mut ctx = sample_context();
        ctx.roles = owned(&["admin", "editor"]);

        // Single-role lookups.
        assert!(ctx.has_role("admin"));
        assert!(ctx.has_role("editor"));
        assert!(!ctx.has_role("viewer"));

        // "Any" needs one hit; "all" needs every entry to hit.
        assert!(ctx.has_any_role(&owned(&["admin", "viewer"])));
        assert!(!ctx.has_any_role(&owned(&["viewer", "guest"])));
        assert!(ctx.has_all_roles(&owned(&["admin", "editor"])));
        assert!(!ctx.has_all_roles(&owned(&["admin", "viewer"])));
    }

    #[test]
    fn test_user_context_permission_checking() {
        let mut ctx = sample_context();
        ctx.permissions = owned(&["read", "write"]);

        assert!(ctx.has_permission("read"));
        assert!(ctx.has_permission("write"));
        assert!(!ctx.has_permission("delete"));
    }

    #[test]
    fn test_user_context_expiration() {
        let mut ctx = sample_context();

        // No expiry set: never expired.
        assert!(!ctx.is_expired());

        // Expiry one hour in the past: expired.
        ctx.expires_at = Some(Utc::now() - chrono::Duration::hours(1));
        assert!(ctx.is_expired());

        // Expiry one hour in the future: still valid.
        ctx.expires_at = Some(Utc::now() + chrono::Duration::hours(1));
        assert!(!ctx.is_expired());
    }
}