use anyhow::Result;
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::sync::RwLock;
use uuid::Uuid;
use std::sync::Arc;
use crate::config::Config;
use crate::{
utils::crypto,
};
/// Authentication service: checks credentials, issues and validates HS256
/// JWTs, and keeps an in-memory table of active login sessions.
pub struct AuthService {
    /// Token/session settings plus the JWT signing and verification keys.
    auth_config: LocalAuthConfig,
    /// Active sessions keyed by user id; a new login for the same user
    /// replaces the previous entry.
    sessions: RwLock<HashMap<String, Session>>,
}
/// JWT payload produced when a token is issued and returned by
/// `AuthService::validate_token` after verification.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims {
    /// Subject: the authenticated user's id.
    pub sub: String,
    /// Expiry time as a Unix timestamp in seconds.
    pub exp: i64,
    /// Issue time as a Unix timestamp in seconds.
    pub iat: i64,
    /// Roles granted to the subject; consulted by `AuthService::has_role`.
    pub roles: Vec<String>,
}
/// Server-side record of a logged-in user, created by `AuthService::login`.
#[derive(Debug, Clone)]
pub struct Session {
    /// Id of the owning user; also the key under which the session is stored.
    pub user_id: String,
    /// Username of the owning user at login time.
    pub username: String,
    /// Roles copied from the user at login time.
    pub roles: Vec<String>,
    /// When the session was created.
    pub created_at: chrono::DateTime<Utc>,
    /// End of validity; `AuthService::cleanup_expired_sessions` removes
    /// sessions whose expiry is not in the future.
    pub expires_at: chrono::DateTime<Utc>,
}
/// A stored user account.
///
/// NOTE(review): `Serialize` still emits `password_hash`; hand clients a
/// `UserInfo` rather than serializing a `User` directly.
#[derive(Clone, Serialize, Deserialize)]
pub struct User {
    /// Unique user id.
    pub id: String,
    /// Login name.
    pub username: String,
    /// Contact email address.
    pub email: String,
    /// Password hash; redacted from `Debug` output.
    pub password_hash: String,
    /// Roles granted to the user.
    pub roles: Vec<String>,
    /// When the account was created.
    pub created_at: chrono::DateTime<Utc>,
    /// Whether the account may log in.
    pub is_active: bool,
}

// Hand-written instead of derived so `{:?}` logging of a user never exposes
// the password hash.
impl std::fmt::Debug for User {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("User")
            .field("id", &self.id)
            .field("username", &self.username)
            .field("email", &self.email)
            .field("password_hash", &"<redacted>")
            .field("roles", &self.roles)
            .field("created_at", &self.created_at)
            .field("is_active", &self.is_active)
            .finish()
    }
}
/// Credentials submitted to `AuthService::login`.
#[derive(Deserialize)]
pub struct LoginRequest {
    /// Login name to authenticate.
    pub username: String,
    /// Plaintext password; redacted from `Debug` output.
    pub password: String,
}

// Hand-written instead of derived so `{:?}` logging of a login request cannot
// leak the plaintext password.
impl std::fmt::Debug for LoginRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginRequest")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Successful result of `AuthService::login`.
#[derive(Debug, Serialize)]
pub struct LoginResponse {
    /// Signed JWT to present on subsequent requests.
    pub token: String,
    /// Token/session lifetime in seconds.
    pub expires_in: i64,
    /// Public profile of the authenticated user.
    pub user: UserInfo,
}
/// Client-facing view of a `User`: the same identity fields without the
/// password hash, creation time or active flag.
#[derive(Debug, Serialize)]
pub struct UserInfo {
    /// Unique user id.
    pub id: String,
    /// Login name.
    pub username: String,
    /// Contact email address.
    pub email: String,
    /// Roles granted to the user.
    pub roles: Vec<String>,
}
/// Sign-up data submitted to `AuthService::register`.
#[derive(Deserialize)]
pub struct RegisterRequest {
    /// Desired login name.
    pub username: String,
    /// Contact email address; format-checked during registration.
    pub email: String,
    /// Plaintext password; redacted from `Debug` output.
    pub password: String,
}

// Hand-written instead of derived so `{:?}` logging of a registration request
// cannot leak the plaintext password.
impl std::fmt::Debug for RegisterRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RegisterRequest")
            .field("username", &self.username)
            .field("email", &self.email)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Authentication settings together with the JWT keys derived from the secret.
///
/// `Debug` is implemented by hand and omits the secret and both keys.
struct LocalAuthConfig {
    // Stored alongside the derived keys but never read after construction;
    // the allow silences the resulting dead-code warning.
    #[allow(dead_code)]
    jwt_secret: String,
    /// Lifetime of issued tokens and sessions, in seconds.
    session_timeout: u64,
    /// Whether `AuthService::register` accepts new accounts.
    allow_registration: bool,
    /// HMAC key used to sign tokens (built from `jwt_secret`).
    encoding_key: EncodingKey,
    /// HMAC key used to verify tokens (built from `jwt_secret`).
    decoding_key: DecodingKey,
}
impl std::fmt::Debug for LocalAuthConfig {
    // Prints only the non-secret settings; the JWT secret and the derived
    // keys are skipped so `{:?}` output never carries key material.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            session_timeout,
            allow_registration,
            ..
        } = self;
        let mut builder = f.debug_struct("LocalAuthConfig");
        builder.field("session_timeout", session_timeout);
        builder.field("allow_registration", allow_registration);
        builder.finish_non_exhaustive()
    }
}
impl AuthService {
    /// Creates the service from the application configuration.
    ///
    /// The JWT signing and verification keys are derived once from
    /// `config.auth.jwt_secret`. `config.auth.token_expiration` is used as the
    /// token and session lifetime in seconds, and
    /// `config.auth.allow_registration` gates [`AuthService::register`]. The
    /// session store starts empty.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub async fn new(config: Arc<Config>) -> Result<Self> {
        let secret = &config.auth.jwt_secret;
        Ok(Self {
            auth_config: LocalAuthConfig {
                jwt_secret: secret.clone(),
                session_timeout: config.auth.token_expiration,
                allow_registration: config.auth.allow_registration,
                encoding_key: EncodingKey::from_secret(secret.as_bytes()),
                decoding_key: DecodingKey::from_secret(secret.as_bytes()),
            },
            sessions: RwLock::new(HashMap::new()),
        })
    }

    /// Authenticates `request` and opens a session for the user.
    ///
    /// On success, returns a signed JWT together with its lifetime in seconds.
    /// It also stores a [`Session`] under the user's id, replacing any earlier
    /// session for that user.
    ///
    /// # Errors
    ///
    /// * `"Invalid credentials"` — unknown username or wrong password. The two
    ///   cases are deliberately indistinguishable to prevent username
    ///   enumeration.
    /// * `"User account is disabled"` — correct password for an inactive
    ///   account.
    /// * Errors propagated from the user lookup, password verification,
    ///   expiry computation or token encoding.
    pub async fn login(&self, request: LoginRequest) -> Result<LoginResponse> {
        // Unknown users and wrong passwords share one error message so the
        // endpoint does not reveal which usernames exist.
        let user = self
            .find_user_by_username(&request.username)
            .await?
            .ok_or_else(|| anyhow::anyhow!("Invalid credentials"))?;
        if !crypto::verify_password(&request.password, &user.password_hash)? {
            return Err(anyhow::anyhow!("Invalid credentials"));
        }
        // Checked only after the password, so the disabled state is revealed
        // solely to someone who already holds the account's credentials.
        if !user.is_active {
            return Err(anyhow::anyhow!("User account is disabled"));
        }

        let token = self.generate_token(&user)?;
        let expires_in = i64::try_from(self.auth_config.session_timeout)
            .map_err(|_| anyhow::anyhow!("Session timeout is out of range"))?;
        // One timestamp for both fields so the session lasts exactly the timeout.
        let created_at = Utc::now();
        let session = Session {
            user_id: user.id.clone(),
            username: user.username.clone(),
            roles: user.roles.clone(),
            created_at,
            expires_at: self.expiry_after(created_at)?,
        };
        self.sessions.write().await.insert(user.id.clone(), session);

        Ok(LoginResponse {
            token,
            expires_in,
            user: UserInfo {
                id: user.id,
                username: user.username,
                email: user.email,
                roles: user.roles,
            },
        })
    }

    /// Creates a new active account with the `user` role.
    ///
    /// # Errors
    ///
    /// * `"Registration is disabled"` when registration is turned off.
    /// * `"Username must not be empty"` for a blank username.
    /// * `"Invalid email format"` when the email fails validation.
    /// * `"Password must be at least 8 characters"` for passwords shorter than
    ///   eight Unicode scalar values.
    /// * `"Username already exists"` when the username is taken.
    /// * Errors propagated from the user lookup, hashing or persistence.
    pub async fn register(&self, request: RegisterRequest) -> Result<UserInfo> {
        if !self.auth_config.allow_registration {
            return Err(anyhow::anyhow!("Registration is disabled"));
        }
        if request.username.trim().is_empty() {
            return Err(anyhow::anyhow!("Username must not be empty"));
        }
        if !crate::utils::validation::is_valid_email(&request.email) {
            return Err(anyhow::anyhow!("Invalid email format"));
        }
        // Count characters, not UTF-8 bytes, to match the error message:
        // a 5-character non-ASCII password can easily exceed 8 bytes.
        if request.password.chars().count() < 8 {
            return Err(anyhow::anyhow!("Password must be at least 8 characters"));
        }
        // Lookup failures propagate instead of being mistaken for "not found".
        if self
            .find_user_by_username(&request.username)
            .await?
            .is_some()
        {
            return Err(anyhow::anyhow!("Username already exists"));
        }

        let password_hash = crypto::hash_password(&request.password)?;
        let user = User {
            id: Uuid::new_v4().to_string(),
            username: request.username,
            email: request.email,
            password_hash,
            roles: vec!["user".to_string()],
            created_at: Utc::now(),
            is_active: true,
        };
        self.save_user(&user).await?;

        Ok(UserInfo {
            id: user.id,
            username: user.username,
            email: user.email,
            roles: user.roles,
        })
    }

    /// Verifies an HS256 token's signature and expiry and returns its claims.
    ///
    /// # Errors
    ///
    /// Fails if the token is malformed, is signed with a different key or
    /// algorithm, or has expired (`"Token has expired"` or a jsonwebtoken error).
    pub fn validate_token(&self, token: &str) -> Result<Claims> {
        // HS256 is the algorithm of `Header::default()`, used when signing.
        let validation = Validation::new(Algorithm::HS256);
        let token_data = decode::<Claims>(token, &self.auth_config.decoding_key, &validation)?;
        // jsonwebtoken may already reject expired tokens, subject to its
        // `Validation::leeway`. This explicit check applies no leeway at all.
        let now = Utc::now().timestamp();
        if token_data.claims.exp < now {
            return Err(anyhow::anyhow!("Token has expired"));
        }
        Ok(token_data.claims)
    }

    /// Returns a copy of the user's session if one exists and has not expired.
    ///
    /// Expiry uses the same rule as [`AuthService::cleanup_expired_sessions`],
    /// so a stale entry is never handed out while it awaits cleanup.
    pub async fn get_session(&self, user_id: &str) -> Option<Session> {
        let now = Utc::now();
        let sessions = self.sessions.read().await;
        sessions
            .get(user_id)
            .filter(|session| session.expires_at > now)
            .cloned()
    }

    /// Removes the user's session, if any. Succeeds even if none existed.
    pub async fn logout(&self, user_id: &str) -> Result<()> {
        self.sessions.write().await.remove(user_id);
        Ok(())
    }

    /// Returns `true` if `claims` carries `required_role` or the `admin` role.
    /// Admin satisfies every role check.
    pub fn has_role(&self, claims: &Claims, required_role: &str) -> bool {
        claims
            .roles
            .iter()
            .any(|role| role.as_str() == required_role || role.as_str() == "admin")
    }

    /// Drops every session whose expiry is not in the future.
    pub async fn cleanup_expired_sessions(&self) {
        let now = Utc::now();
        let mut sessions = self.sessions.write().await;
        sessions.retain(|_, session| session.expires_at > now);
    }

    /// Signs a JWT for `user`, valid from now for the configured timeout.
    fn generate_token(&self, user: &User) -> Result<String> {
        let now = Utc::now();
        let claims = Claims {
            sub: user.id.clone(),
            exp: self.expiry_after(now)?.timestamp(),
            iat: now.timestamp(),
            roles: user.roles.clone(),
        };
        // `Header::default()` selects HS256, matching `validate_token`.
        let token = encode(&Header::default(), &claims, &self.auth_config.encoding_key)?;
        Ok(token)
    }

    /// Returns `start` plus the configured session timeout.
    ///
    /// # Errors
    ///
    /// Returns an error instead of panicking when the timeout does not fit in a
    /// chrono duration, or when the sum overflows the representable date range.
    fn expiry_after(&self, start: chrono::DateTime<Utc>) -> Result<chrono::DateTime<Utc>> {
        let ttl = Duration::from_std(std::time::Duration::from_secs(
            self.auth_config.session_timeout,
        ))
        .map_err(|_| anyhow::anyhow!("Session timeout is out of range"))?;
        start
            .checked_add_signed(ttl)
            .ok_or_else(|| anyhow::anyhow!("Session expiry is out of range"))
    }

    /// Looks up a user by username.
    ///
    /// Returns `Ok(None)` when no such user exists, so callers can tell a
    /// missing user apart from a failed lookup.
    ///
    /// SECURITY(review): placeholder. The only known account is a hard-coded
    /// `admin` user with password `admin123` and role `admin`. Its hash is
    /// recomputed on every call. This must be replaced by a real user store
    /// before deployment.
    async fn find_user_by_username(&self, username: &str) -> Result<Option<User>> {
        if username != "admin" {
            return Ok(None);
        }
        Ok(Some(User {
            id: "admin-id".to_string(),
            username: "admin".to_string(),
            email: "admin@example.com".to_string(),
            password_hash: crypto::hash_password("admin123")?,
            roles: vec!["admin".to_string()],
            created_at: Utc::now(),
            is_active: true,
        }))
    }

    /// Persists a user. Placeholder: currently a no-op that always succeeds.
    async fn save_user(&self, _user: &User) -> Result<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// HMAC secret shared by every test service.
    const TEST_SECRET: &str = "test_secret_key_that_is_long_enough";

    /// Builds a service with registration enabled and a one-hour timeout.
    fn create_test_auth_service() -> AuthService {
        create_test_auth_service_with_registration(true)
    }

    /// Builds a service directly from `LocalAuthConfig`, bypassing
    /// `AuthService::new` so no application `Config` is required.
    fn create_test_auth_service_with_registration(allow_registration: bool) -> AuthService {
        AuthService {
            auth_config: LocalAuthConfig {
                jwt_secret: TEST_SECRET.to_string(),
                session_timeout: 3600,
                allow_registration,
                encoding_key: EncodingKey::from_secret(TEST_SECRET.as_bytes()),
                decoding_key: DecodingKey::from_secret(TEST_SECRET.as_bytes()),
            },
            sessions: RwLock::new(HashMap::new()),
        }
    }

    /// Credentials accepted by the stub user lookup.
    fn admin_login() -> LoginRequest {
        LoginRequest {
            username: "admin".to_string(),
            password: "admin123".to_string(),
        }
    }

    /// Claims for a plain `user` with the given expiry timestamp.
    fn claims_expiring_at(exp: i64) -> Claims {
        Claims {
            sub: "user-id".to_string(),
            exp,
            iat: Utc::now().timestamp(),
            roles: vec!["user".to_string()],
        }
    }

    /// A session for `user_id` that ends at `expires_at`.
    fn session_ending_at(user_id: &str, expires_at: chrono::DateTime<Utc>) -> Session {
        Session {
            user_id: user_id.to_string(),
            username: user_id.to_string(),
            roles: vec!["user".to_string()],
            created_at: expires_at - Duration::hours(2),
            expires_at,
        }
    }

    #[tokio::test]
    async fn test_login_success() {
        let auth_service = create_test_auth_service();
        let response = auth_service
            .login(admin_login())
            .await
            .expect("admin login should succeed");
        assert!(!response.token.is_empty());
        assert_eq!(response.user.username, "admin");
        assert_eq!(response.expires_in, 3600);
    }

    #[tokio::test]
    async fn test_login_rejects_wrong_password() {
        let auth_service = create_test_auth_service();
        let request = LoginRequest {
            username: "admin".to_string(),
            password: "not-the-password".to_string(),
        };
        assert!(auth_service.login(request).await.is_err());
        assert!(auth_service.get_session("admin-id").await.is_none());
    }

    #[tokio::test]
    async fn test_login_rejects_unknown_user() {
        let auth_service = create_test_auth_service();
        let request = LoginRequest {
            username: "nobody".to_string(),
            password: "whatever123".to_string(),
        };
        assert!(auth_service.login(request).await.is_err());
    }

    #[tokio::test]
    async fn test_login_creates_session_and_logout_removes_it() {
        let auth_service = create_test_auth_service();
        auth_service
            .login(admin_login())
            .await
            .expect("admin login should succeed");

        let session = auth_service
            .get_session("admin-id")
            .await
            .expect("login should create a session");
        assert_eq!(session.username, "admin");
        assert!(session.expires_at > session.created_at);

        auth_service
            .logout("admin-id")
            .await
            .expect("logout should succeed");
        assert!(auth_service.get_session("admin-id").await.is_none());
    }

    #[tokio::test]
    async fn test_cleanup_removes_only_expired_sessions() {
        let auth_service = create_test_auth_service();
        let now = Utc::now();
        {
            // Scoped so the write lock is released before cleanup takes it.
            let mut sessions = auth_service.sessions.write().await;
            sessions.insert(
                "old".to_string(),
                session_ending_at("old", now - Duration::hours(1)),
            );
            sessions.insert(
                "fresh".to_string(),
                session_ending_at("fresh", now + Duration::hours(1)),
            );
        }

        auth_service.cleanup_expired_sessions().await;

        assert!(auth_service.get_session("old").await.is_none());
        assert!(auth_service.get_session("fresh").await.is_some());
    }

    #[tokio::test]
    async fn test_register_disabled() {
        let auth_service = create_test_auth_service_with_registration(false);
        let request = RegisterRequest {
            username: "new-user".to_string(),
            email: "new-user@example.com".to_string(),
            password: "long-enough-password".to_string(),
        };
        let err = auth_service
            .register(request)
            .await
            .expect_err("registration should be rejected");
        assert_eq!(err.to_string(), "Registration is disabled");
    }

    #[tokio::test]
    async fn test_register_rejects_short_password() {
        let auth_service = create_test_auth_service();
        let request = RegisterRequest {
            username: "new-user".to_string(),
            email: "new-user@example.com".to_string(),
            password: "short".to_string(),
        };
        assert!(auth_service.register(request).await.is_err());
    }

    #[test]
    fn test_token_validation() {
        let auth_service = create_test_auth_service();
        let user = User {
            id: "test-id".to_string(),
            username: "test".to_string(),
            email: "test@example.com".to_string(),
            password_hash: "hash".to_string(),
            roles: vec!["user".to_string()],
            created_at: Utc::now(),
            is_active: true,
        };
        let token = auth_service
            .generate_token(&user)
            .expect("token generation should succeed");
        let claims = auth_service
            .validate_token(&token)
            .expect("freshly issued token should validate");
        assert_eq!(claims.sub, "test-id");
        assert_eq!(claims.roles, vec!["user"]);
    }

    #[test]
    fn test_validate_rejects_token_signed_with_other_key() {
        let auth_service = create_test_auth_service();
        let claims = claims_expiring_at(Utc::now().timestamp() + 3600);
        let token = encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(b"a_completely_different_secret_key"),
        )
        .expect("encoding should succeed");
        assert!(auth_service.validate_token(&token).is_err());
    }

    #[test]
    fn test_validate_rejects_expired_token() {
        let auth_service = create_test_auth_service();
        let claims = claims_expiring_at(Utc::now().timestamp() - 3600);
        let token = encode(
            &Header::default(),
            &claims,
            &EncodingKey::from_secret(TEST_SECRET.as_bytes()),
        )
        .expect("encoding should succeed");
        assert!(auth_service.validate_token(&token).is_err());
    }

    #[test]
    fn test_validate_rejects_garbage() {
        let auth_service = create_test_auth_service();
        assert!(auth_service.validate_token("not.a.jwt").is_err());
        assert!(auth_service.validate_token("").is_err());
    }

    #[test]
    fn test_role_checking() {
        let auth_service = create_test_auth_service();
        let claims = claims_expiring_at(Utc::now().timestamp() + 3600);
        assert!(auth_service.has_role(&claims, "user"));
        assert!(!auth_service.has_role(&claims, "admin"));
    }

    #[test]
    fn test_admin_role_satisfies_any_role() {
        let auth_service = create_test_auth_service();
        let claims = Claims {
            roles: vec!["admin".to_string()],
            ..claims_expiring_at(Utc::now().timestamp() + 3600)
        };
        assert!(auth_service.has_role(&claims, "user"));
        assert!(auth_service.has_role(&claims, "anything-else"));
    }
}