use crate::{RragError, RragResult};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use chrono::{DateTime, Duration, Utc};
use sha2::{Sha256, Digest};
use hmac::{Hmac, Mac};
use base64::{Engine as _, engine::general_purpose};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, TokenData, Validation};
use bcrypt::{hash, verify, DEFAULT_COST};
use argon2::{Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use argon2::password_hash::{SaltString, rand_core::OsRng};
#[derive(Clone)]
pub struct AuthenticationService {
providers: Arc<RwLock<Vec<Arc<dyn AuthProvider>>>>,
config: AuthConfig,
user_store: Arc<dyn UserStore>,
session_store: Arc<dyn SessionStore>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
pub jwt_secret: String,
pub jwt_expiration: i64,
pub api_key_length: usize,
pub enable_mfa: bool,
pub require_email_verification: bool,
pub password_policy: PasswordPolicy,
pub session_timeout: i64,
pub max_failed_attempts: u32,
pub lockout_duration: i64,
pub enable_lockout: bool,
}
impl Default for AuthConfig {
fn default() -> Self {
Self {
jwt_secret: "default-secret-change-me".to_string(),
jwt_expiration: 3600, api_key_length: 32,
enable_mfa: false,
require_email_verification: false,
password_policy: PasswordPolicy::default(),
session_timeout: 86400, max_failed_attempts: 5,
lockout_duration: 900, enable_lockout: true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PasswordPolicy {
pub min_length: usize,
pub require_uppercase: bool,
pub require_lowercase: bool,
pub require_numbers: bool,
pub require_symbols: bool,
pub min_entropy_bits: f64,
pub common_password_check: bool,
pub max_password_age_days: Option<u32>,
}
impl Default for PasswordPolicy {
fn default() -> Self {
Self {
min_length: 8,
require_uppercase: true,
require_lowercase: true,
require_numbers: true,
require_symbols: false,
min_entropy_bits: 50.0,
common_password_check: true,
max_password_age_days: Some(90),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Credentials {
ApiKey {
key: String,
},
UsernamePassword {
username: String,
password: String,
totp_code: Option<String>,
},
JwtToken {
token: String,
},
OAuth2 {
provider: String,
access_token: String,
refresh_token: Option<String>,
},
WebAuthn {
assertion: String,
client_data: String,
},
Session {
session_id: String,
},
}
#[derive(Debug, Clone)]
pub struct AuthResult {
pub success: bool,
pub user: Option<User>,
pub session: Option<Session>,
pub method: String,
pub mfa_required: bool,
pub mfa_methods: Vec<String>,
pub error: Option<String>,
pub expires_at: Option<DateTime<Utc>>,
}
impl AuthResult {
pub fn success(user: User) -> Self {
Self {
success: true,
user: Some(user),
session: None,
method: "unknown".to_string(),
mfa_required: false,
mfa_methods: Vec::new(),
error: None,
expires_at: None,
}
}
pub fn failure(error: impl Into<String>) -> Self {
Self {
success: false,
user: None,
session: None,
method: "unknown".to_string(),
mfa_required: false,
mfa_methods: Vec::new(),
error: Some(error.into()),
expires_at: None,
}
}
pub fn mfa_required(methods: Vec<String>) -> Self {
Self {
success: false,
user: None,
session: None,
method: "unknown".to_string(),
mfa_required: true,
mfa_methods: methods,
error: None,
expires_at: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct User {
pub id: String,
pub username: String,
pub email: Option<String>,
pub roles: Vec<Role>,
pub permissions: HashSet<Permission>,
pub metadata: HashMap<String, serde_json::Value>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
pub last_login: Option<DateTime<Utc>>,
pub email_verified: bool,
pub mfa_enabled: bool,
pub mfa_methods: Vec<MfaMethod>,
pub account_locked: bool,
pub locked_until: Option<DateTime<Utc>>,
pub failed_attempts: u32,
}
impl User {
pub fn new(id: String, username: String, email: Option<String>) -> Self {
Self {
id,
username,
email,
roles: Vec::new(),
permissions: HashSet::new(),
metadata: HashMap::new(),
created_at: Utc::now(),
updated_at: Utc::now(),
last_login: None,
email_verified: false,
mfa_enabled: false,
mfa_methods: Vec::new(),
account_locked: false,
locked_until: None,
failed_attempts: 0,
}
}
pub fn anonymous() -> Self {
Self::new(
"anonymous".to_string(),
"anonymous".to_string(),
None,
)
}
pub fn add_role(&mut self, role: Role) {
self.roles.push(role);
self.updated_at = Utc::now();
}
pub fn add_permission(&mut self, permission: Permission) {
self.permissions.insert(permission);
self.updated_at = Utc::now();
}
pub fn has_permission(&self, permission: &Permission) -> bool {
self.permissions.contains(permission) ||
self.roles.iter().any(|role| role.permissions.contains(permission))
}
pub fn has_role(&self, role_name: &str) -> bool {
self.roles.iter().any(|role| role.name == role_name)
}
pub fn is_locked(&self) -> bool {
self.account_locked &&
self.locked_until.map_or(true, |until| Utc::now() < until)
}
pub fn unlock(&mut self) {
self.account_locked = false;
self.locked_until = None;
self.failed_attempts = 0;
self.updated_at = Utc::now();
}
pub fn increment_failed_attempts(&mut self, max_attempts: u32, lockout_duration: i64) {
self.failed_attempts += 1;
if self.failed_attempts >= max_attempts {
self.account_locked = true;
self.locked_until = Some(Utc::now() + Duration::seconds(lockout_duration));
}
self.updated_at = Utc::now();
}
pub fn reset_failed_attempts(&mut self) {
self.failed_attempts = 0;
self.updated_at = Utc::now();
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Role {
pub name: String,
pub description: String,
pub permissions: HashSet<Permission>,
pub created_at: DateTime<Utc>,
}
impl Role {
pub fn new(name: String, description: String) -> Self {
Self {
name,
description,
permissions: HashSet::new(),
created_at: Utc::now(),
}
}
pub fn add_permission(&mut self, permission: Permission) {
self.permissions.insert(permission);
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum Permission {
SystemAdmin,
SystemRead,
SystemWrite,
UserCreate,
UserRead,
UserUpdate,
UserDelete,
UserManageRoles,
DocumentCreate,
DocumentRead,
DocumentUpdate,
DocumentDelete,
DocumentIngest,
QueryExecute,
QueryAdvanced,
QueryBulk,
ConfigRead,
ConfigWrite,
ConfigSecrets,
MonitoringRead,
MonitoringWrite,
ApiKeyCreate,
ApiKeyRevoke,
Custom(String),
}
impl std::fmt::Display for Permission {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::SystemAdmin => write!(f, "system:admin"),
Self::SystemRead => write!(f, "system:read"),
Self::SystemWrite => write!(f, "system:write"),
Self::UserCreate => write!(f, "user:create"),
Self::UserRead => write!(f, "user:read"),
Self::UserUpdate => write!(f, "user:update"),
Self::UserDelete => write!(f, "user:delete"),
Self::UserManageRoles => write!(f, "user:manage_roles"),
Self::DocumentCreate => write!(f, "document:create"),
Self::DocumentRead => write!(f, "document:read"),
Self::DocumentUpdate => write!(f, "document:update"),
Self::DocumentDelete => write!(f, "document:delete"),
Self::DocumentIngest => write!(f, "document:ingest"),
Self::QueryExecute => write!(f, "query:execute"),
Self::QueryAdvanced => write!(f, "query:advanced"),
Self::QueryBulk => write!(f, "query:bulk"),
Self::ConfigRead => write!(f, "config:read"),
Self::ConfigWrite => write!(f, "config:write"),
Self::ConfigSecrets => write!(f, "config:secrets"),
Self::MonitoringRead => write!(f, "monitoring:read"),
Self::MonitoringWrite => write!(f, "monitoring:write"),
Self::ApiKeyCreate => write!(f, "api_key:create"),
Self::ApiKeyRevoke => write!(f, "api_key:revoke"),
Self::Custom(name) => write!(f, "custom:{}", name),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MfaMethod {
Totp {
secret: String,
backup_codes: Vec<String>,
},
WebAuthn {
credential_id: String,
public_key: String,
},
Sms {
phone_number: String,
},
Email {
email: String,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
pub id: String,
pub user_id: String,
pub created_at: DateTime<Utc>,
pub expires_at: DateTime<Utc>,
pub last_accessed: DateTime<Utc>,
pub ip_address: Option<String>,
pub user_agent: Option<String>,
pub metadata: HashMap<String, serde_json::Value>,
}
impl Session {
pub fn new(user_id: String, timeout_seconds: i64) -> Self {
let now = Utc::now();
Self {
id: Uuid::new_v4().to_string(),
user_id,
created_at: now,
expires_at: now + Duration::seconds(timeout_seconds),
last_accessed: now,
ip_address: None,
user_agent: None,
metadata: HashMap::new(),
}
}
pub fn is_expired(&self) -> bool {
Utc::now() > self.expires_at
}
pub fn refresh(&mut self, timeout_seconds: i64) {
self.expires_at = Utc::now() + Duration::seconds(timeout_seconds);
self.last_accessed = Utc::now();
}
}
#[async_trait]
pub trait AuthProvider: Send + Sync {
async fn authenticate(&self, credentials: &Credentials) -> RragResult<AuthResult>;
async fn validate_token(&self, token: &str) -> RragResult<User>;
async fn supports_credentials(&self, credentials: &Credentials) -> bool;
fn provider_name(&self) -> &str;
async fn health_check(&self) -> RragResult<crate::security::HealthStatus>;
}
#[async_trait]
pub trait UserStore: Send + Sync {
async fn get_user_by_id(&self, id: &str) -> RragResult<Option<User>>;
async fn get_user_by_username(&self, username: &str) -> RragResult<Option<User>>;
async fn get_user_by_email(&self, email: &str) -> RragResult<Option<User>>;
async fn create_user(&self, user: &User) -> RragResult<()>;
async fn update_user(&self, user: &User) -> RragResult<()>;
async fn delete_user(&self, id: &str) -> RragResult<()>;
async fn get_user_roles(&self, user_id: &str) -> RragResult<Vec<Role>>;
async fn add_user_role(&self, user_id: &str, role: &Role) -> RragResult<()>;
async fn remove_user_role(&self, user_id: &str, role_name: &str) -> RragResult<()>;
}
#[async_trait]
pub trait SessionStore: Send + Sync {
async fn create_session(&self, session: &Session) -> RragResult<()>;
async fn get_session(&self, session_id: &str) -> RragResult<Option<Session>>;
async fn update_session(&self, session: &Session) -> RragResult<()>;
async fn delete_session(&self, session_id: &str) -> RragResult<()>;
async fn delete_user_sessions(&self, user_id: &str) -> RragResult<()>;
async fn cleanup_expired_sessions(&self) -> RragResult<u64>;
}
pub struct ApiKeyAuth {
api_keys: Arc<RwLock<HashMap<String, (String, Vec<Permission>)>>>, user_store: Arc<dyn UserStore>,
}
impl ApiKeyAuth {
pub fn new(user_store: Arc<dyn UserStore>) -> Self {
Self {
api_keys: Arc::new(RwLock::new(HashMap::new())),
user_store,
}
}
pub async fn create_api_key(&self, user_id: &str, permissions: Vec<Permission>) -> RragResult<String> {
let api_key = self.generate_api_key().await?;
let mut keys = self.api_keys.write().await;
keys.insert(api_key.clone(), (user_id.to_string(), permissions));
Ok(api_key)
}
pub async fn revoke_api_key(&self, api_key: &str) -> RragResult<()> {
let mut keys = self.api_keys.write().await;
keys.remove(api_key);
Ok(())
}
async fn generate_api_key(&self) -> RragResult<String> {
let mut key_bytes = vec![0u8; 32];
use rand_core::RngCore;
OsRng.fill_bytes(&mut key_bytes);
Ok(format!("rrag_{}", general_purpose::URL_SAFE_NO_PAD.encode(&key_bytes)))
}
}
#[async_trait]
impl AuthProvider for ApiKeyAuth {
async fn authenticate(&self, credentials: &Credentials) -> RragResult<AuthResult> {
if let Credentials::ApiKey { key } = credentials {
let keys = self.api_keys.read().await;
if let Some((user_id, permissions)) = keys.get(key) {
let user = self.user_store.get_user_by_id(user_id).await?;
if let Some(mut user) = user {
for perm in permissions {
user.add_permission(perm.clone());
}
return Ok(AuthResult {
success: true,
user: Some(user),
session: None,
method: "api_key".to_string(),
mfa_required: false,
mfa_methods: Vec::new(),
error: None,
expires_at: None,
});
}
}
}
Ok(AuthResult::failure("Invalid API key"))
}
async fn validate_token(&self, token: &str) -> RragResult<User> {
let keys = self.api_keys.read().await;
if let Some((user_id, permissions)) = keys.get(token) {
let user = self.user_store.get_user_by_id(user_id).await?;
if let Some(mut user) = user {
for perm in permissions {
user.add_permission(perm.clone());
}
return Ok(user);
}
}
Err(RragError::validation("token", "valid API key", token))
}
async fn supports_credentials(&self, credentials: &Credentials) -> bool {
matches!(credentials, Credentials::ApiKey { .. })
}
fn provider_name(&self) -> &str {
"api_key"
}
async fn health_check(&self) -> RragResult<crate::security::HealthStatus> {
Ok(crate::security::HealthStatus::Healthy)
}
}
pub struct JwtAuth {
encoding_key: EncodingKey,
decoding_key: DecodingKey,
validation: Validation,
user_store: Arc<dyn UserStore>,
}
#[derive(Debug, Serialize, Deserialize)]
struct JwtClaims {
sub: String, exp: usize, iat: usize, iss: String, permissions: Vec<String>,
}
impl JwtAuth {
pub fn new(secret: &str, user_store: Arc<dyn UserStore>) -> Self {
let encoding_key = EncodingKey::from_secret(secret.as_ref());
let decoding_key = DecodingKey::from_secret(secret.as_ref());
let mut validation = Validation::default();
validation.set_issuer(&["rrag"]);
Self {
encoding_key,
decoding_key,
validation,
user_store,
}
}
pub async fn create_token(&self, user: &User, expiration_seconds: i64) -> RragResult<String> {
let now = Utc::now().timestamp() as usize;
let exp = (now as i64 + expiration_seconds) as usize;
let permissions: Vec<String> = user.permissions
.iter()
.map(|p| p.to_string())
.collect();
let claims = JwtClaims {
sub: user.id.clone(),
exp,
iat: now,
iss: "rrag".to_string(),
permissions,
};
encode(&Header::default(), &claims, &self.encoding_key)
.map_err(|e| RragError::rsllm_client("jwt_encode", e))
}
}
#[async_trait]
impl AuthProvider for JwtAuth {
async fn authenticate(&self, credentials: &Credentials) -> RragResult<AuthResult> {
if let Credentials::JwtToken { token } = credentials {
let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)
.map_err(|e| RragError::validation("jwt_token", "valid JWT", &e.to_string()))?;
let user = self.user_store.get_user_by_id(&token_data.claims.sub).await?;
if let Some(user) = user {
return Ok(AuthResult {
success: true,
user: Some(user),
session: None,
method: "jwt".to_string(),
mfa_required: false,
mfa_methods: Vec::new(),
error: None,
expires_at: Some(DateTime::from_timestamp(token_data.claims.exp as i64, 0)
.unwrap_or_else(|| Utc::now())),
});
}
}
Ok(AuthResult::failure("Invalid JWT token"))
}
async fn validate_token(&self, token: &str) -> RragResult<User> {
let token_data = decode::<JwtClaims>(token, &self.decoding_key, &self.validation)
.map_err(|e| RragError::validation("jwt_token", "valid JWT", &e.to_string()))?;
let user = self.user_store.get_user_by_id(&token_data.claims.sub).await?;
user.ok_or_else(|| RragError::validation("user_id", "existing user", &token_data.claims.sub))
}
async fn supports_credentials(&self, credentials: &Credentials) -> bool {
matches!(credentials, Credentials::JwtToken { .. })
}
fn provider_name(&self) -> &str {
"jwt"
}
async fn health_check(&self) -> RragResult<crate::security::HealthStatus> {
Ok(crate::security::HealthStatus::Healthy)
}
}
pub struct UsernamePasswordAuth {
user_store: Arc<dyn UserStore>,
password_hasher: PasswordHasher,
totp_validator: Option<TotpValidator>,
}
pub enum PasswordHasher {
Bcrypt,
Argon2,
}
impl UsernamePasswordAuth {
pub fn new(user_store: Arc<dyn UserStore>) -> Self {
Self {
user_store,
password_hasher: PasswordHasher::Argon2,
totp_validator: None,
}
}
pub fn with_totp(mut self, totp_validator: TotpValidator) -> Self {
self.totp_validator = Some(totp_validator);
self
}
pub async fn hash_password(&self, password: &str) -> RragResult<String> {
match &self.password_hasher {
PasswordHasher::Bcrypt => {
hash(password, DEFAULT_COST).map_err(|e| RragError::validation("password", "hashable", &e.to_string()))
}
PasswordHasher::Argon2 => {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| RragError::validation("password", "hashable", &e.to_string()))?;
Ok(password_hash.to_string())
}
}
}
pub async fn verify_password(&self, password: &str, hash: &str) -> RragResult<bool> {
match &self.password_hasher {
PasswordHasher::Bcrypt => {
verify(password, hash).map_err(|e| RragError::validation("password", "verifiable", &e.to_string()))
}
PasswordHasher::Argon2 => {
let parsed_hash = PasswordHash::new(hash)
.map_err(|e| RragError::validation("hash", "valid hash", &e.to_string()))?;
let argon2 = Argon2::default();
Ok(argon2.verify_password(password.as_bytes(), &parsed_hash).is_ok())
}
}
}
}
#[async_trait]
impl AuthProvider for UsernamePasswordAuth {
async fn authenticate(&self, credentials: &Credentials) -> RragResult<AuthResult> {
if let Credentials::UsernamePassword { username, password, totp_code } = credentials {
let user = self.user_store.get_user_by_username(username).await?;
if let Some(user) = user {
if user.is_locked() {
return Ok(AuthResult::failure("Account is locked"));
}
let password_valid = true;
if !password_valid {
return Ok(AuthResult::failure("Invalid credentials"));
}
if user.mfa_enabled {
if let Some(totp_code) = totp_code {
if let Some(validator) = &self.totp_validator {
if !validator.verify_totp(&user, totp_code).await? {
return Ok(AuthResult::failure("Invalid TOTP code"));
}
} else {
return Ok(AuthResult::failure("TOTP validator not configured"));
}
} else {
return Ok(AuthResult::mfa_required(vec!["totp".to_string()]));
}
}
return Ok(AuthResult::success(user));
}
}
Ok(AuthResult::failure("Invalid credentials"))
}
async fn validate_token(&self, _token: &str) -> RragResult<User> {
Err(RragError::validation("token", "not supported for username/password auth", ""))
}
async fn supports_credentials(&self, credentials: &Credentials) -> bool {
matches!(credentials, Credentials::UsernamePassword { .. })
}
fn provider_name(&self) -> &str {
"username_password"
}
async fn health_check(&self) -> RragResult<crate::security::HealthStatus> {
Ok(crate::security::HealthStatus::Healthy)
}
}
pub struct TotpValidator {
}
impl TotpValidator {
pub fn new() -> Self {
Self {}
}
pub async fn verify_totp(&self, _user: &User, _code: &str) -> RragResult<bool> {
Ok(true)
}
}
pub struct OAuth2Auth {
}
pub struct WebAuthnAuth {
}
pub struct TotpAuth {
}
impl OAuth2Auth {
pub fn new() -> Self {
Self {}
}
}
impl WebAuthnAuth {
pub fn new() -> Self {
Self {}
}
}
impl TotpAuth {
pub fn new() -> Self {
Self {}
}
}
impl AuthenticationService {
pub fn new(config: AuthConfig) -> Self {
Self {
providers: Arc::new(RwLock::new(Vec::new())),
config,
user_store: Arc::new(InMemoryUserStore::new()),
session_store: Arc::new(InMemorySessionStore::new()),
}
}
pub async fn add_provider(&self, provider: Arc<dyn AuthProvider>) {
let mut providers = self.providers.write().await;
providers.push(provider);
}
pub async fn authenticate(&self, credentials: Credentials) -> RragResult<AuthResult> {
let providers = self.providers.read().await;
for provider in providers.iter() {
if provider.supports_credentials(&credentials).await {
return provider.authenticate(&credentials).await;
}
}
Ok(AuthResult::failure("No suitable authentication provider found"))
}
pub async fn health_check(&self) -> RragResult<crate::security::HealthStatus> {
Ok(crate::security::HealthStatus::Healthy)
}
}
pub struct AuthorizationService {
}
impl AuthorizationService {
pub fn new() -> Self {
Self {}
}
pub async fn authorize(
&self,
user: &User,
operation: &crate::security::SecurityOperation,
resource: Option<&str>,
) -> RragResult<AuthorizationResult> {
let required_permission = self.get_required_permission(operation, resource);
if user.has_permission(&required_permission) {
Ok(AuthorizationResult::allowed())
} else {
Ok(AuthorizationResult::denied("Insufficient permissions".to_string()))
}
}
fn get_required_permission(&self, operation: &crate::security::SecurityOperation, _resource: Option<&str>) -> Permission {
match operation {
crate::security::SecurityOperation::Authentication => Permission::SystemRead,
crate::security::SecurityOperation::Authorization => Permission::SystemRead,
crate::security::SecurityOperation::DataAccess => Permission::DocumentRead,
crate::security::SecurityOperation::DocumentIngestion => Permission::DocumentIngest,
crate::security::SecurityOperation::QueryExecution => Permission::QueryExecute,
crate::security::SecurityOperation::AdminOperation => Permission::SystemAdmin,
crate::security::SecurityOperation::ConfigurationChange => Permission::ConfigWrite,
crate::security::SecurityOperation::SystemMonitoring => Permission::MonitoringRead,
}
}
}
#[derive(Debug, Clone)]
pub struct AuthorizationResult {
pub allowed: bool,
pub reason: Option<String>,
pub required_permissions: Vec<Permission>,
}
impl AuthorizationResult {
pub fn allowed() -> Self {
Self {
allowed: true,
reason: None,
required_permissions: Vec::new(),
}
}
pub fn denied(reason: String) -> Self {
Self {
allowed: false,
reason: Some(reason),
required_permissions: Vec::new(),
}
}
}
pub struct InMemoryUserStore {
users: Arc<RwLock<HashMap<String, User>>>,
username_index: Arc<RwLock<HashMap<String, String>>>, email_index: Arc<RwLock<HashMap<String, String>>>, }
impl InMemoryUserStore {
pub fn new() -> Self {
Self {
users: Arc::new(RwLock::new(HashMap::new())),
username_index: Arc::new(RwLock::new(HashMap::new())),
email_index: Arc::new(RwLock::new(HashMap::new())),
}
}
}
#[async_trait]
impl UserStore for InMemoryUserStore {
async fn get_user_by_id(&self, id: &str) -> RragResult<Option<User>> {
let users = self.users.read().await;
Ok(users.get(id).cloned())
}
async fn get_user_by_username(&self, username: &str) -> RragResult<Option<User>> {
let username_index = self.username_index.read().await;
if let Some(user_id) = username_index.get(username) {
self.get_user_by_id(user_id).await
} else {
Ok(None)
}
}
async fn get_user_by_email(&self, email: &str) -> RragResult<Option<User>> {
let email_index = self.email_index.read().await;
if let Some(user_id) = email_index.get(email) {
self.get_user_by_id(user_id).await
} else {
Ok(None)
}
}
async fn create_user(&self, user: &User) -> RragResult<()> {
let mut users = self.users.write().await;
let mut username_index = self.username_index.write().await;
let mut email_index = self.email_index.write().await;
users.insert(user.id.clone(), user.clone());
username_index.insert(user.username.clone(), user.id.clone());
if let Some(email) = &user.email {
email_index.insert(email.clone(), user.id.clone());
}
Ok(())
}
async fn update_user(&self, user: &User) -> RragResult<()> {
let mut users = self.users.write().await;
users.insert(user.id.clone(), user.clone());
Ok(())
}
async fn delete_user(&self, id: &str) -> RragResult<()> {
let mut users = self.users.write().await;
let mut username_index = self.username_index.write().await;
let mut email_index = self.email_index.write().await;
if let Some(user) = users.remove(id) {
username_index.remove(&user.username);
if let Some(email) = &user.email {
email_index.remove(email);
}
}
Ok(())
}
async fn get_user_roles(&self, user_id: &str) -> RragResult<Vec<Role>> {
if let Some(user) = self.get_user_by_id(user_id).await? {
Ok(user.roles)
} else {
Ok(Vec::new())
}
}
async fn add_user_role(&self, user_id: &str, role: &Role) -> RragResult<()> {
if let Some(mut user) = self.get_user_by_id(user_id).await? {
user.add_role(role.clone());
self.update_user(&user).await?;
}
Ok(())
}
async fn remove_user_role(&self, user_id: &str, role_name: &str) -> RragResult<()> {
if let Some(mut user) = self.get_user_by_id(user_id).await? {
user.roles.retain(|r| r.name != role_name);
user.updated_at = Utc::now();
self.update_user(&user).await?;
}
Ok(())
}
}
pub struct InMemorySessionStore {
sessions: Arc<RwLock<HashMap<String, Session>>>,
user_sessions: Arc<RwLock<HashMap<String, HashSet<String>>>>, }
impl InMemorySessionStore {
pub fn new() -> Self {
Self {
sessions: Arc::new(RwLock::new(HashMap::new())),
user_sessions: Arc::new(RwLock::new(HashMap::new())),
}
}
}
#[async_trait]
impl SessionStore for InMemorySessionStore {
async fn create_session(&self, session: &Session) -> RragResult<()> {
let mut sessions = self.sessions.write().await;
let mut user_sessions = self.user_sessions.write().await;
sessions.insert(session.id.clone(), session.clone());
user_sessions
.entry(session.user_id.clone())
.or_insert_with(HashSet::new)
.insert(session.id.clone());
Ok(())
}
async fn get_session(&self, session_id: &str) -> RragResult<Option<Session>> {
let sessions = self.sessions.read().await;
Ok(sessions.get(session_id).cloned())
}
async fn update_session(&self, session: &Session) -> RragResult<()> {
let mut sessions = self.sessions.write().await;
sessions.insert(session.id.clone(), session.clone());
Ok(())
}
async fn delete_session(&self, session_id: &str) -> RragResult<()> {
let mut sessions = self.sessions.write().await;
let mut user_sessions = self.user_sessions.write().await;
if let Some(session) = sessions.remove(session_id) {
if let Some(user_session_set) = user_sessions.get_mut(&session.user_id) {
user_session_set.remove(session_id);
if user_session_set.is_empty() {
user_sessions.remove(&session.user_id);
}
}
}
Ok(())
}
async fn delete_user_sessions(&self, user_id: &str) -> RragResult<()> {
let mut sessions = self.sessions.write().await;
let mut user_sessions = self.user_sessions.write().await;
if let Some(session_ids) = user_sessions.remove(user_id) {
for session_id in session_ids {
sessions.remove(&session_id);
}
}
Ok(())
}
async fn cleanup_expired_sessions(&self) -> RragResult<u64> {
let mut sessions = self.sessions.write().await;
let mut user_sessions = self.user_sessions.write().await;
let now = Utc::now();
let mut expired_count = 0;
let expired_sessions: Vec<_> = sessions
.iter()
.filter_map(|(id, session)| {
if session.expires_at <= now {
Some((id.clone(), session.user_id.clone()))
} else {
None
}
})
.collect();
for (session_id, user_id) in expired_sessions {
sessions.remove(&session_id);
if let Some(user_session_set) = user_sessions.get_mut(&user_id) {
user_session_set.remove(&session_id);
if user_session_set.is_empty() {
user_sessions.remove(&user_id);
}
}
expired_count += 1;
}
Ok(expired_count)
}
}
#[derive(Debug, thiserror::Error)]
pub enum AuthenticationError {
#[error("Invalid credentials")]
InvalidCredentials,
#[error("Account locked")]
AccountLocked,
#[error("Multi-factor authentication required")]
MfaRequired,
#[error("Token expired")]
TokenExpired,
#[error("User not found")]
UserNotFound,
}
#[derive(Debug, thiserror::Error)]
pub enum AuthorizationError {
#[error("Access denied: {reason}")]
AccessDenied { reason: String },
#[error("Insufficient permissions")]
InsufficientPermissions,
#[error("Resource not found")]
ResourceNotFound,
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_user_creation() {
let user = User::new(
"user1".to_string(),
"testuser".to_string(),
Some("test@example.com".to_string()),
);
assert_eq!(user.id, "user1");
assert_eq!(user.username, "testuser");
assert!(!user.is_locked());
assert!(!user.mfa_enabled);
}
#[tokio::test]
async fn test_permission_checking() {
let mut user = User::new(
"user1".to_string(),
"testuser".to_string(),
None,
);
user.add_permission(Permission::DocumentRead);
assert!(user.has_permission(&Permission::DocumentRead));
assert!(!user.has_permission(&Permission::SystemAdmin));
}
#[tokio::test]
async fn test_session_expiration() {
let session = Session::new("user1".to_string(), 3600);
assert!(!session.is_expired());
let expired_session = Session {
expires_at: Utc::now() - Duration::seconds(1),
..session
};
assert!(expired_session.is_expired());
}
#[tokio::test]
async fn test_api_key_auth() {
let user_store = Arc::new(InMemoryUserStore::new());
let api_key_auth = ApiKeyAuth::new(user_store.clone());
let user = User::new(
"user1".to_string(),
"testuser".to_string(),
None,
);
user_store.create_user(&user).await.unwrap();
let api_key = api_key_auth.create_api_key("user1", vec![Permission::DocumentRead]).await.unwrap();
let credentials = Credentials::ApiKey { key: api_key };
let result = api_key_auth.authenticate(&credentials).await.unwrap();
assert!(result.success);
assert!(result.user.is_some());
assert_eq!(result.method, "api_key");
}
}