use crate::server::middleware::AuthRateLimiter;
use crate::server::routes::ApiResponse;
use crate::server::state::AppState;
use crate::utils::auth::crypto::password::verify_password;
use actix_web::{HttpRequest, HttpResponse, Result as ActixResult, web};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tracing::{error, info, warn};
use super::models::{LoginRequest, LoginResponse, UserInfo};
/// Lazily created limiter shared by every caller of `get_login_rate_limiter`.
static LOGIN_RATE_LIMITER: std::sync::OnceLock<Arc<AuthRateLimiter>> =
    std::sync::OnceLock::new();

/// Returns a handle to the shared login rate limiter.
///
/// The limiter is built on the first call with `AuthRateLimiter::new(5, 60, 60)`;
/// later calls hand out clones of the same `Arc`.
// NOTE(review): the meaning of the three constructor arguments (presumably an attempt
// limit plus two durations in seconds) is defined by `AuthRateLimiter` — confirm there.
fn get_login_rate_limiter() -> Arc<AuthRateLimiter> {
    let shared = LOGIN_RATE_LIMITER.get_or_init(|| Arc::new(AuthRateLimiter::new(5, 60, 60)));
    Arc::clone(shared)
}
/// Parses the leftmost entry of an `X-Forwarded-For` value as an IP address.
///
/// Surrounding whitespace is ignored. Returns the address in canonical textual
/// form, or `None` when that first entry is empty or is not a bare IPv4/IPv6
/// address (for example, when it carries a port). Only the first entry is
/// examined; everything after the first comma is ignored.
fn client_ip_from_xff(xff: &str) -> Option<String> {
    let first_entry = xff.split(',').next()?.trim();
    if first_entry.is_empty() {
        return None;
    }
    let address: std::net::IpAddr = first_entry.parse().ok()?;
    Some(address.to_string())
}
/// Determines the client IP address used for rate limiting and logging.
///
/// The socket peer address (port stripped) is authoritative unless the peer is
/// listed in `trusted_proxies`. Only then is `X-Forwarded-For` consulted, and
/// it is read right to left: each trusted proxy appends the address it received
/// the request from, so the rightmost entry that is not itself a trusted proxy
/// is the real client. Entries further left are supplied by the client and can
/// be forged at will, so they are never used while an untrusted hop exists to
/// their right (otherwise a caller could rotate fake addresses to evade the
/// per-IP rate limit).
///
/// All `X-Forwarded-For` header lines are combined in order, as if they had
/// been sent as a single comma-separated value.
///
/// Falls back to the peer IP (or `"unknown"` when the request has no peer
/// address) when the peer is not trusted, no usable header is present, a
/// header value cannot be read as text, an entry that must be inspected is
/// not a bare IP address, or every listed hop is a trusted proxy.
fn extract_client_ip(req: &HttpRequest, trusted_proxies: &[String]) -> String {
    let peer_ip = req
        .peer_addr()
        .map(|addr| addr.ip().to_string())
        .unwrap_or_else(|| "unknown".to_string());
    if !is_trusted_proxy(&peer_ip, trusted_proxies) {
        return peer_ip;
    }

    let mut chain = String::new();
    for value in req.headers().get_all("x-forwarded-for") {
        // A value that is not readable text cannot be split into hops reliably,
        // so fall back to the one address we observed directly.
        let Ok(text) = value.to_str() else {
            return peer_ip;
        };
        if !chain.is_empty() {
            chain.push(',');
        }
        chain.push_str(text);
    }
    client_ip_from_forwarded_chain(&chain, trusted_proxies).unwrap_or(peer_ip)
}

/// Walks a combined `X-Forwarded-For` chain from right to left and returns the
/// first hop that is not a trusted proxy.
///
/// Returns `None` if an inspected hop is empty or not a bare IP address, or if
/// every hop is a trusted proxy.
fn client_ip_from_forwarded_chain(chain: &str, trusted_proxies: &[String]) -> Option<String> {
    for hop in chain.rsplit(',') {
        // On a comma-free string, `client_ip_from_xff` parses exactly that one entry.
        let ip = client_ip_from_xff(hop)?;
        if !is_trusted_proxy(&ip, trusted_proxies) {
            return Some(ip);
        }
    }
    None
}

/// Returns `true` if `ip` matches an entry of `trusted_proxies`, either as the
/// exact same string or as the same parsed IP address (so `"::1"` matches a
/// configured `"0:0:0:0:0:0:0:1"`).
fn is_trusted_proxy(ip: &str, trusted_proxies: &[String]) -> bool {
    let parsed = ip.parse::<std::net::IpAddr>().ok();
    trusted_proxies.iter().any(|proxy| {
        proxy == ip
            || (parsed.is_some() && proxy.trim().parse::<std::net::IpAddr>().ok() == parsed)
    })
}
/// Number of `login` calls so far; used only to schedule limiter cleanup.
static REQUEST_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Authenticates a username/password pair and, on success, issues an access
/// token and a refresh token.
///
/// Requests are throttled per client IP (as resolved by `extract_client_ip`)
/// through the shared login rate limiter. Every rejected credential attempt is
/// recorded as a failure against that IP; a successful login does not clear
/// earlier failures.
///
/// Responses (errors are always reported in the response body; the returned
/// `ActixResult` is always `Ok`):
/// - `200` with `ApiResponse::success(LoginResponse)` on success.
/// - `401 "Invalid credentials"` for an unknown username or a wrong password.
/// - `403 "Account is disabled"` for an inactive account, but only after the
///   correct password was supplied, so account status is not revealed to
///   callers who cannot authenticate.
/// - `429` with a `Retry-After` header when the client IP is rate limited.
/// - `500` on database, password-verification, or token-generation errors.
pub async fn login(
    req: HttpRequest,
    state: web::Data<AppState>,
    request: web::Json<LoginRequest>,
) -> ActixResult<HttpResponse> {
    // The value returned by `config.load()` is only needed to read the proxy
    // list, so it is dropped at the end of this statement instead of living
    // across the `.await`s below.
    let client_ip = extract_client_ip(&req, &state.config.load().gateway.server.trusted_proxies);
    let limiter = get_login_rate_limiter();

    // Prune old limiter entries on every 100th request rather than on each call.
    if REQUEST_COUNTER
        .fetch_add(1, Ordering::Relaxed)
        .is_multiple_of(100)
    {
        limiter.cleanup_old_entries();
    }

    if let Err(retry_after) = limiter.check_allowed(&client_ip) {
        warn!(
            "Login rate limit exceeded for IP {}: retry after {}s",
            client_ip, retry_after
        );
        return Ok(HttpResponse::TooManyRequests()
            .insert_header(("Retry-After", retry_after.to_string()))
            .json(ApiResponse::<()>::error(
                "Too many login attempts. Please try again later.".to_string(),
            )));
    }
    info!("User login attempt from IP {}", client_ip);

    let user = match state
        .storage
        .database
        .find_user_by_username(&request.username)
        .await
    {
        Ok(Some(user)) => user,
        Ok(None) => {
            // NOTE(review): this path skips password hashing, so response time may
            // reveal whether a username exists. Consider verifying against a dummy
            // hash here — confirm the cost of `verify_password` first.
            warn!("Login attempt with invalid username from IP {}", client_ip);
            limiter.record_failure(&client_ip);
            return Ok(HttpResponse::Unauthorized()
                .json(ApiResponse::<()>::error("Invalid credentials".to_string())));
        }
        Err(e) => {
            error!("Database error during login: {}", e);
            return Ok(HttpResponse::InternalServerError()
                .json(ApiResponse::<()>::error("Database error".to_string())));
        }
    };

    let password_valid = match verify_password(&request.password, &user.password_hash) {
        Ok(valid) => valid,
        Err(e) => {
            error!("Password verification error: {}", e);
            return Ok(HttpResponse::InternalServerError()
                .json(ApiResponse::<()>::error("Authentication error".to_string())));
        }
    };
    if !password_valid {
        warn!("Login attempt with invalid password from IP {}", client_ip);
        limiter.record_failure(&client_ip);
        return Ok(HttpResponse::Unauthorized()
            .json(ApiResponse::<()>::error("Invalid credentials".to_string())));
    }

    // Checked only after the password: answering 403 earlier would confirm to
    // anyone that the username exists and is disabled.
    if !user.is_active() {
        warn!("Login attempt for inactive user from IP {}", client_ip);
        limiter.record_failure(&client_ip);
        return Ok(HttpResponse::Forbidden()
            .json(ApiResponse::<()>::error("Account is disabled".to_string())));
    }

    let access_token = match state
        .auth
        .jwt()
        .create_access_token(user.id(), user.role.to_string(), vec![], None, None)
        .await
    {
        Ok(token) => token,
        Err(e) => {
            error!("Failed to generate access token: {}", e);
            return Ok(
                HttpResponse::InternalServerError().json(ApiResponse::<()>::error(
                    "Token generation failed".to_string(),
                )),
            );
        }
    };
    let refresh_token = match state.auth.jwt().create_refresh_token(user.id(), None).await {
        Ok(token) => token,
        Err(e) => {
            error!("Failed to generate refresh token: {}", e);
            return Ok(
                HttpResponse::InternalServerError().json(ApiResponse::<()>::error(
                    "Token generation failed".to_string(),
                )),
            );
        }
    };

    // Record the login only once tokens have actually been issued. Failing to
    // persist the timestamp is logged but does not fail the request.
    if let Err(e) = state
        .storage
        .database
        .update_user_last_login(user.id())
        .await
    {
        warn!("Failed to update last login time: {}", e);
    }

    info!("User logged in successfully from IP {}", client_ip);
    let response = LoginResponse {
        access_token,
        refresh_token,
        token_type: "Bearer".to_string(),
        // NOTE(review): hard-coded; presumably must match the access-token lifetime
        // configured in the JWT service (expiry argument passed as `None` above) — confirm.
        expires_in: 3600,
        user: UserInfo {
            id: user.id(),
            username: user.username,
            email: user.email,
            full_name: user.display_name,
            role: user.role.to_string(),
            email_verified: user.email_verified,
        },
    };
    Ok(HttpResponse::Ok().json(ApiResponse::success(response)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::test::TestRequest;
    use uuid::Uuid;

    /// Builds an `HttpRequest` with the given socket peer and an optional
    /// `X-Forwarded-For` header value.
    fn request_from(peer: &str, xff: Option<&str>) -> HttpRequest {
        let mut builder =
            TestRequest::default().peer_addr(peer.parse().expect("valid socket address"));
        if let Some(value) = xff {
            builder = builder.insert_header(("x-forwarded-for", value));
        }
        builder.to_http_request()
    }

    #[test]
    fn test_login_request_deserialization() {
        let json = r#"{"username": "testuser", "password": "pass123"}"#;
        let request: LoginRequest = serde_json::from_str(json).expect("Failed to deserialize");
        assert_eq!(request.username, "testuser");
        assert_eq!(request.password, "pass123");
    }

    #[test]
    fn test_login_request_missing_fields() {
        let json = r#"{"username": "testuser"}"#;
        let result: Result<LoginRequest, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_login_response_serialization() {
        let response = LoginResponse {
            access_token: "access_token_here".to_string(),
            refresh_token: "refresh_token_here".to_string(),
            token_type: "Bearer".to_string(),
            expires_in: 3600,
            user: UserInfo {
                id: Uuid::new_v4(),
                username: "testuser".to_string(),
                email: "test@example.com".to_string(),
                full_name: Some("Test User".to_string()),
                role: "User".to_string(),
                email_verified: true,
            },
        };
        let json = serde_json::to_string(&response).expect("Failed to serialize");
        assert!(json.contains("access_token"));
        assert!(json.contains("refresh_token"));
        assert!(json.contains("Bearer"));
        assert!(json.contains("testuser"));
    }

    #[test]
    fn test_user_info_structure() {
        let user_info = UserInfo {
            id: Uuid::new_v4(),
            username: "john_doe".to_string(),
            email: "john@example.com".to_string(),
            full_name: Some("John Doe".to_string()),
            role: "Admin".to_string(),
            email_verified: true,
        };
        assert_eq!(user_info.username, "john_doe");
        assert_eq!(user_info.role, "Admin");
        assert!(user_info.email_verified);
        assert!(user_info.full_name.is_some());
    }

    #[test]
    fn test_login_rate_limiter_blocks_after_limit() {
        let limiter = AuthRateLimiter::new(5, 60, 60);
        let ip = "192.0.2.1";
        for _ in 0..5 {
            assert!(limiter.check_allowed(ip).is_ok());
            limiter.record_failure(ip);
        }
        assert!(limiter.check_allowed(ip).is_err());
    }

    #[test]
    fn test_login_rate_limiter_different_ips_independent() {
        let limiter = AuthRateLimiter::new(5, 60, 60);
        let ip1 = "192.0.2.1";
        let ip2 = "192.0.2.2";
        for _ in 0..5 {
            assert!(limiter.check_allowed(ip1).is_ok());
            limiter.record_failure(ip1);
        }
        assert!(limiter.check_allowed(ip1).is_err());
        assert!(limiter.check_allowed(ip2).is_ok());
    }

    /// The handler never clears the limiter after a successful login, so the
    /// counter is driven only by `record_failure`: the limit is reached on the
    /// fifth failure regardless of any `check_allowed` calls in between.
    #[test]
    fn test_success_does_not_reset_rate_limit_counter() {
        let limiter = AuthRateLimiter::new(5, 60, 60);
        let ip = "192.0.2.1";
        for _ in 0..4 {
            assert!(limiter.check_allowed(ip).is_ok());
            limiter.record_failure(ip);
        }
        assert!(limiter.check_allowed(ip).is_ok());
        limiter.record_failure(ip);
        assert!(limiter.check_allowed(ip).is_err());
    }

    #[test]
    fn test_extract_client_ip_strips_port() {
        let v4 = request_from("192.0.2.1:54321", None);
        assert_eq!(extract_client_ip(&v4, &[]), "192.0.2.1");
        let v6 = request_from("[::1]:54321", None);
        assert_eq!(extract_client_ip(&v6, &[]), "::1");
    }

    #[test]
    fn test_extract_client_ip_ignores_xff_from_untrusted_peer() {
        let req = request_from("192.0.2.1:54321", Some("203.0.113.5"));
        assert_eq!(extract_client_ip(&req, &[]), "192.0.2.1");
        let trusted = ["10.0.0.1".to_string()];
        assert_eq!(extract_client_ip(&req, &trusted), "192.0.2.1");
    }

    #[test]
    fn test_extract_client_ip_uses_xff_from_trusted_peer() {
        let req = request_from("10.0.0.1:443", Some("203.0.113.5"));
        let trusted = ["10.0.0.1".to_string()];
        assert_eq!(extract_client_ip(&req, &trusted), "203.0.113.5");
    }

    #[test]
    fn test_extract_client_ip_invalid_xff_falls_back_to_peer() {
        let req = request_from("10.0.0.1:443", Some("not-an-ip"));
        let trusted = ["10.0.0.1".to_string()];
        assert_eq!(extract_client_ip(&req, &trusted), "10.0.0.1");
    }

    #[test]
    fn test_client_ip_from_xff_single() {
        assert_eq!(
            client_ip_from_xff("203.0.113.5"),
            Some("203.0.113.5".to_string())
        );
    }

    #[test]
    fn test_client_ip_from_xff_chain() {
        assert_eq!(
            client_ip_from_xff("203.0.113.5, 10.0.0.1, 10.0.0.2"),
            Some("203.0.113.5".to_string())
        );
    }

    #[test]
    fn test_client_ip_from_xff_trims_whitespace() {
        assert_eq!(
            client_ip_from_xff("  203.0.113.5  , 10.0.0.1"),
            Some("203.0.113.5".to_string())
        );
    }

    #[test]
    fn test_client_ip_from_xff_invalid_returns_none() {
        assert_eq!(client_ip_from_xff("not-an-ip, 10.0.0.1"), None);
    }

    #[test]
    fn test_client_ip_from_xff_with_port_returns_none() {
        assert_eq!(client_ip_from_xff("203.0.113.5:8080"), None);
    }

    #[test]
    fn test_client_ip_from_xff_empty_returns_none() {
        assert_eq!(client_ip_from_xff(""), None);
    }

    #[test]
    fn test_client_ip_from_xff_ipv6() {
        assert_eq!(
            client_ip_from_xff("2001:db8::1, 10.0.0.1"),
            Some("2001:db8::1".to_string())
        );
    }
}