use crate::api::{ApiResponse, ApiState, extract_bearer_token};
use axum::{Json, extract::State, http::HeaderMap};
use serde::{Deserialize, Serialize};
/// Request body for the login endpoint.
///
/// `challenge_id` and `mfa_code` are only used for the second step of an MFA
/// login and must be sent together. `remember_me` is accepted but is not read
/// by the handlers in this file.
///
/// `Debug` is implemented by hand so the password and MFA code are never
/// printed (e.g. into logs).
#[derive(Deserialize)]
pub struct LoginRequest {
    pub username: String,
    pub password: String,
    #[serde(default)]
    pub challenge_id: Option<String>,
    #[serde(default)]
    pub mfa_code: Option<String>,
    #[serde(default)]
    pub remember_me: bool,
}

impl std::fmt::Debug for LoginRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginRequest")
            .field("username", &self.username)
            .field("password", &"<redacted>")
            .field("challenge_id", &self.challenge_id)
            .field("mfa_code", &self.mfa_code.as_ref().map(|_| "<redacted>"))
            .field("remember_me", &self.remember_me)
            .finish()
    }
}

/// Successful login payload.
///
/// `expires_in` is the access-token lifetime in seconds. `login_risk_level`
/// and `security_warnings` carry the header-based risk assessment made at
/// login time.
///
/// `Debug` is implemented by hand so the issued tokens are never printed.
#[derive(Serialize)]
pub struct LoginResponse {
    pub access_token: String,
    pub refresh_token: String,
    pub token_type: String,
    pub expires_in: u64,
    pub user: LoginUserInfo,
    pub login_risk_level: String,
    pub security_warnings: Vec<String>,
}

impl std::fmt::Debug for LoginResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginResponse")
            .field("access_token", &"<redacted>")
            .field("refresh_token", &"<redacted>")
            .field("token_type", &self.token_type)
            .field("expires_in", &self.expires_in)
            .field("user", &self.user)
            .field("login_risk_level", &self.login_risk_level)
            .field("security_warnings", &self.security_warnings)
            .finish()
    }
}

/// Public identity information about the authenticated user.
#[derive(Debug, Serialize)]
pub struct LoginUserInfo {
    pub id: String,
    pub username: String,
    pub roles: Vec<String>,
    pub permissions: Vec<String>,
}
async fn build_login_response(
state: &ApiState,
user_id: &str,
username: String,
permissions: Vec<String>,
) -> ApiResponse<LoginResponse> {
let user_key = format!("user:{}", user_id);
let roles: Vec<String> = match state.auth_framework.storage().get_kv(&user_key).await {
Ok(Some(bytes)) => {
let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap_or_default();
json["roles"]
.as_array()
.map(|arr| {
arr.iter()
.filter_map(|value| value.as_str())
.map(|value| value.to_string())
.collect()
})
.unwrap_or_default()
}
_ => vec![],
};
let user_info = LoginUserInfo {
id: user_id.to_string(),
username,
roles: roles.clone(),
permissions,
};
let token_lifetime = state.auth_framework.config().token_lifetime;
let access_token = match state.auth_framework.token_manager().create_jwt_token(
user_id,
roles,
Some(token_lifetime),
) {
Ok(jwt) => jwt,
Err(e) => {
tracing::error!("Failed to create JWT token: {}", e);
return ApiResponse::error_typed(
"TOKEN_CREATION_FAILED",
"Failed to create access token",
);
}
};
let refresh_token_lifetime = state.auth_framework.config().refresh_token_lifetime;
let refresh_token = match state.auth_framework.token_manager().create_jwt_token(
user_id,
vec!["refresh".to_string()],
Some(refresh_token_lifetime),
) {
Ok(jwt) => jwt,
Err(e) => {
tracing::error!("Failed to create refresh token: {}", e);
return ApiResponse::error_typed(
"TOKEN_CREATION_FAILED",
"Failed to create refresh token",
);
}
};
ApiResponse::success(LoginResponse {
access_token,
refresh_token,
token_type: "Bearer".to_string(),
expires_in: token_lifetime.as_secs(),
user: user_info,
login_risk_level: "low".to_string(), security_warnings: Vec::new(), })
}
/// Request body for exchanging a refresh token for a new access token.
///
/// `Debug` is implemented by hand so the token is never printed.
#[derive(Deserialize)]
pub struct RefreshRequest {
    pub refresh_token: String,
}

impl std::fmt::Debug for RefreshRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshRequest")
            .field("refresh_token", &"<redacted>")
            .finish()
    }
}

/// Response body of a successful token refresh.
///
/// `expires_in` is the new access token's lifetime in seconds. `Debug` is
/// implemented by hand so the access token is never printed.
#[derive(Serialize)]
pub struct RefreshResponse {
    pub access_token: String,
    pub token_type: String,
    pub expires_in: u64,
}

impl std::fmt::Debug for RefreshResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshResponse")
            .field("access_token", &"<redacted>")
            .field("token_type", &self.token_type)
            .field("expires_in", &self.expires_in)
            .finish()
    }
}

/// Request body for logout; the refresh token, if present, is revoked as well.
///
/// `Debug` is implemented by hand so the token is never printed.
#[derive(Deserialize)]
pub struct LogoutRequest {
    #[serde(default)]
    pub refresh_token: Option<String>,
}

impl std::fmt::Debug for LogoutRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LogoutRequest")
            .field(
                "refresh_token",
                &self.refresh_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
/// Heuristically grades how risky a login request looks, from its headers alone.
///
/// Points are accumulated (saturating) from three signals:
/// * no usable `User-Agent` (absent, unreadable as text, or blank): 30 points;
/// * a `User-Agent` containing "tor browser" (case-insensitive): 40 points;
/// * two or more non-empty `X-Forwarded-For` entries: 15 points.
///
/// The total maps to `"low"` (0–9), `"medium"` (10–29), `"high"` (30–59) or
/// `"critical"` (60+). Each signal that fired contributes a warning string.
///
/// These headers are supplied by the client, so the result is advisory only.
pub(crate) fn login_risk_level(headers: &HeaderMap) -> (&'static str, Vec<String>) {
    let user_agent = headers
        .get("user-agent")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("")
        .trim();
    let forwarded_for = headers
        .get("x-forwarded-for")
        .and_then(|v| v.to_str().ok())
        .unwrap_or("");
    let mut risk_points: u8 = 0;
    let mut warnings: Vec<String> = Vec::new();
    if user_agent.is_empty() {
        risk_points = risk_points.saturating_add(30);
        warnings.push(
            "No browser User-Agent detected; this request may originate from an automated script."
                .to_string(),
        );
    }
    if user_agent.to_lowercase().contains("tor browser") {
        risk_points = risk_points.saturating_add(40);
        warnings.push("Login originated from the Tor Browser.".to_string());
    }
    // Skip empty segments (e.g. a trailing comma) so they are not counted as hops.
    let hop_count = forwarded_for
        .split(',')
        .filter(|hop| !hop.trim().is_empty())
        .count();
    if hop_count >= 2 {
        risk_points = risk_points.saturating_add(15);
        warnings.push(format!(
            "Request passed through {} proxy hops (X-Forwarded-For).",
            hop_count
        ));
    }
    let level = match risk_points {
        0..=9 => "low",
        10..=29 => "medium",
        30..=59 => "high",
        _ => "critical",
    };
    (level, warnings)
}
async fn increment_login_failure(state: &ApiState, lockout_key: &str, window_secs: u64) {
let current: u64 = match state.auth_framework.storage().get_kv(lockout_key).await {
Ok(Some(bytes)) => std::str::from_utf8(&bytes)
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(0),
_ => 0,
};
let new_count = current.saturating_add(1);
let _ = state
.auth_framework
.storage()
.store_kv(
lockout_key,
new_count.to_string().as_bytes(),
Some(std::time::Duration::from_secs(window_secs)),
)
.await;
}
/// Password login endpoint, with an optional second step that completes a
/// pending MFA challenge.
///
/// Flow:
/// 1. `username` and `password` must be non-empty, and `challenge_id` /
///    `mfa_code` must be supplied together or not at all.
/// 2. If both MFA fields are present, the challenge is completed via
///    `complete_mfa_by_id` and tokens are issued for the challenge's user.
///    This branch returns before the lockout check, does not re-check the
///    password, and uses the request's `username` as-is in the response.
/// 3. Otherwise the header-based risk level is computed, the per-username
///    failure counter (`login_failures:{username}`) is compared against
///    `MAX_FAILED_ATTEMPTS`, and the credentials are authenticated.
///
/// Outcomes of step 3:
/// * `Success`: the failure counter is deleted, and the risk level and
///   warnings are attached to the response.
/// * `MfaRequired`: returned as an `MFA_REQUIRED` error whose details carry
///   the challenge id, type, expiry and message.
/// * `Failure` or an authentication error: the failure counter is incremented.
pub async fn login(
    State(state): State<ApiState>,
    headers: HeaderMap,
    Json(req): Json<LoginRequest>,
) -> ApiResponse<LoginResponse> {
    if req.username.is_empty() || req.password.is_empty() {
        return ApiResponse::validation_error_typed("Username and password are required");
    }
    // XOR: exactly one of the two MFA fields being present is a malformed request.
    if req.challenge_id.is_some() ^ req.mfa_code.is_some() {
        return ApiResponse::validation_error_typed(
            "challenge_id and mfa_code must be provided together",
        );
    }
    // Second step of an MFA login: the challenge id stands in for the first factor.
    if let (Some(challenge_id), Some(mfa_code)) =
        (req.challenge_id.clone(), req.mfa_code.as_deref())
    {
        return match state
            .auth_framework
            .complete_mfa_by_id(&challenge_id, mfa_code)
            .await
        {
            Ok(token) => {
                let mut response = build_login_response(
                    &state,
                    &token.user_id,
                    req.username,
                    token.permissions.to_vec(),
                )
                .await;
                // build_login_response already defaults to "low"; kept explicit here.
                if let Some(data) = response.data.as_mut() {
                    data.login_risk_level = "low".to_string();
                }
                response
            }
            Err(e) => {
                // Details stay in debug logs; the client gets a generic message.
                tracing::debug!("MFA completion failed during login: {}", e);
                ApiResponse::error_typed(
                    "MFA_INVALID_CODE",
                    "Invalid or expired MFA challenge or code",
                )
            }
        };
    }
    let (risk_level, mut risk_warnings) = login_risk_level(&headers);
    // NOTE(review): the lockout key uses the username exactly as submitted;
    // confirm whether usernames are case-insensitive elsewhere.
    let lockout_key = format!("login_failures:{}", req.username);
    const MAX_FAILED_ATTEMPTS: u64 = 5;
    const LOCKOUT_WINDOW_SECS: u64 = 900;
    // Storage errors or an unparsable counter are treated as "not locked".
    if let Ok(Some(count_bytes)) = state.auth_framework.storage().get_kv(&lockout_key).await {
        if let Ok(count_str) = std::str::from_utf8(&count_bytes) {
            if let Ok(count) = count_str.parse::<u64>() {
                if count >= MAX_FAILED_ATTEMPTS {
                    tracing::warn!(
                        username = %req.username,
                        failed_attempts = count,
                        "Login rejected — account temporarily locked due to repeated failures"
                    );
                    return ApiResponse::error_typed(
                        "ACCOUNT_LOCKED",
                        "Too many failed login attempts. Please try again later.",
                    );
                }
            }
        }
    }
    let credential = crate::authentication::credentials::Credential::Password {
        username: req.username.clone(),
        password: req.password.clone(),
    };
    match state
        .auth_framework
        .authenticate("password", credential)
        .await
    {
        Ok(auth_result) => match auth_result {
            crate::auth::AuthResult::Success(token) => {
                let mfa_enrolled =
                    crate::api::mfa::check_user_mfa_status(&state.auth_framework, &token.user_id)
                        .await;
                // Risky context without MFA: warn the user, but still let the login through.
                if !mfa_enrolled && matches!(risk_level, "high" | "critical") {
                    risk_warnings.push(
                        "Your account does not have multi-factor authentication enabled. \
                         Enable MFA to protect this account from high-risk login contexts."
                            .to_string(),
                    );
                    tracing::warn!(
                        user_id = %token.user_id,
                        risk_level = %risk_level,
                        "High-risk login without MFA enrolled"
                    );
                } else {
                    tracing::info!(
                        user_id = %token.user_id,
                        risk_level = %risk_level,
                        mfa_enrolled = %mfa_enrolled,
                        "Successful login"
                    );
                }
                let mut response = build_login_response(
                    &state,
                    &token.user_id,
                    req.username,
                    token.permissions.to_vec(),
                )
                .await;
                // The password was correct, so the counter is cleared even if
                // token creation inside build_login_response failed.
                let _ = state.auth_framework.storage().delete_kv(&lockout_key).await;
                if let Some(data) = response.data.as_mut() {
                    data.login_risk_level = risk_level.to_string();
                    data.security_warnings = risk_warnings;
                }
                response
            }
            crate::auth::AuthResult::MfaRequired(challenge) => {
                // Wire names for each MFA type, sent back so the client knows what to prompt for.
                let mfa_type_str = match &challenge.mfa_type {
                    crate::methods::MfaType::Totp => "totp",
                    crate::methods::MfaType::Sms { .. } => "sms",
                    crate::methods::MfaType::Email { .. } => "email",
                    crate::methods::MfaType::Push { .. } => "push",
                    crate::methods::MfaType::SecurityKey => "security_key",
                    crate::methods::MfaType::BackupCode => "backup_code",
                    crate::methods::MfaType::MultiMethod => "totp_or_backup_code",
                };
                ApiResponse::<()>::error_with_details(
                    "MFA_REQUIRED",
                    "Multi-factor authentication required",
                    serde_json::json!({
                        "challenge_id": challenge.id,
                        "mfa_type": mfa_type_str,
                        "expires_at": challenge.expires_at.to_rfc3339(),
                        "message": challenge.message,
                    }),
                )
                .cast()
            }
            crate::auth::AuthResult::Failure(reason) => {
                increment_login_failure(&state, &lockout_key, LOCKOUT_WINDOW_SECS).await;
                // NOTE(review): `reason` is forwarded to the client verbatim, unlike the
                // generic message in the Err arm below; confirm it cannot reveal whether
                // the username exists.
                ApiResponse::error_typed("AUTHENTICATION_FAILED", reason)
            }
        },
        Err(e) => {
            increment_login_failure(&state, &lockout_key, LOCKOUT_WINDOW_SECS).await;
            tracing::debug!(
                "Authentication error (reported as INVALID_CREDENTIALS): {}",
                e
            );
            ApiResponse::error_typed("INVALID_CREDENTIALS", "Invalid username or password")
        }
    }
}
pub async fn refresh_token(
State(state): State<ApiState>,
Json(req): Json<RefreshRequest>,
) -> ApiResponse<RefreshResponse> {
if req.refresh_token.is_empty() {
return ApiResponse::validation_error_typed("Invalid request");
}
match state
.auth_framework
.token_manager()
.validate_jwt_token(&req.refresh_token)
{
Ok(claims) => {
if !claims.scope.contains("refresh") {
return ApiResponse::error_typed(
"INVALID_TOKEN",
"Expected a refresh token, but received an access token",
);
}
let revocation_key = format!("revoked_token:{}", claims.jti);
match state.auth_framework.storage().get_kv(&revocation_key).await {
Ok(Some(_)) => {
return ApiResponse::error_typed(
"INVALID_TOKEN",
"Refresh token has been revoked",
);
}
Ok(None) => {} Err(e) => {
tracing::error!("Refresh token revocation check failed: {}", e);
return ApiResponse::error_typed(
"INTERNAL_ERROR",
"Unable to verify token status",
);
}
}
{
let user_key = format!("user:{}", claims.sub);
if let Ok(Some(user_bytes)) = state.auth_framework.storage().get_kv(&user_key).await
{
let user_json: serde_json::Value =
serde_json::from_slice(&user_bytes).unwrap_or_default();
let active = user_json["active"].as_bool().unwrap_or(true);
if !active {
return ApiResponse::error_typed(
"ACCOUNT_DEACTIVATED",
"Account has been deactivated",
);
}
}
}
let permissions: Vec<String> = match state
.auth_framework
.storage()
.get_kv(&format!("user_permissions:{}", claims.sub))
.await
{
Ok(Some(data)) => serde_json::from_slice(&data).unwrap_or_default(),
_ => vec![],
};
let token_lifetime = state.auth_framework.config().token_lifetime;
let new_access_token = match state.auth_framework.token_manager().create_jwt_token(
&claims.sub,
permissions,
Some(token_lifetime),
) {
Ok(jwt) => jwt,
Err(e) => {
tracing::error!("Failed to create new access token: {}", e);
return ApiResponse::error_typed(
"TOKEN_CREATION_FAILED",
"Failed to create new access token",
);
}
};
let response = RefreshResponse {
access_token: new_access_token,
token_type: "Bearer".to_string(),
expires_in: token_lifetime.as_secs(),
};
ApiResponse::success(response)
}
Err(e) => {
tracing::warn!("Invalid refresh token: {}", e);
ApiResponse::error_typed("INVALID_TOKEN", "Invalid or expired refresh token")
}
}
}
pub async fn logout(
State(state): State<ApiState>,
headers: HeaderMap,
Json(req): Json<LogoutRequest>,
) -> ApiResponse<()> {
if let Some(token) = extract_bearer_token(&headers) {
match state
.auth_framework
.token_manager()
.validate_jwt_token(&token)
{
Ok(claims) => {
let revocation_key = format!("revoked_token:{}", claims.jti);
let ttl = std::time::Duration::from_secs(7 * 86400);
if let Err(e) = state
.auth_framework
.storage()
.store_kv(revocation_key.as_str(), b"revoked", Some(ttl))
.await
{
tracing::error!("Failed to revoke access token JTI {}: {}", claims.jti, e);
} else {
tracing::info!("Access token revoked (JTI: {})", claims.jti);
}
}
Err(_) => {
tracing::debug!("Logout called with invalid/expired access token");
}
}
}
if let Some(ref refresh_token) = req.refresh_token {
match state
.auth_framework
.token_manager()
.validate_jwt_token(refresh_token)
{
Ok(claims) => {
let revocation_key = format!("revoked_token:{}", claims.jti);
let ttl = std::time::Duration::from_secs(7 * 86400);
if let Err(e) = state
.auth_framework
.storage()
.store_kv(revocation_key.as_str(), b"revoked", Some(ttl))
.await
{
tracing::error!("Failed to revoke refresh token JTI {}: {}", claims.jti, e);
} else {
tracing::info!("Refresh token revoked (JTI: {})", claims.jti);
}
}
Err(_) => {
tracing::debug!("Logout called with invalid/expired refresh token");
}
}
}
ApiResponse::<()>::ok_with_message("Successfully logged out")
}
/// Reports the identity behind the request's bearer token.
///
/// * No bearer token: responds `unauthorized`.
/// * A token rejected by `validate_api_token`: responds `AUTH_ERROR`.
///
/// The username comes from the user's stored profile, falling back to
/// `user_{id}` when the profile has no username or cannot be loaded.
pub async fn validate_token(
    State(state): State<ApiState>,
    headers: HeaderMap,
) -> ApiResponse<LoginUserInfo> {
    let token = match extract_bearer_token(&headers) {
        Some(bearer) => bearer,
        None => return ApiResponse::unauthorized_typed(),
    };
    let auth_token = match crate::api::validate_api_token(&state.auth_framework, &token).await {
        Ok(validated) => validated,
        Err(_) => return ApiResponse::error_typed("AUTH_ERROR", "Token validation failed"),
    };
    let profile_name = match state
        .auth_framework
        .get_user_profile(&auth_token.user_id)
        .await
    {
        Ok(profile) => profile.username,
        Err(_) => None,
    };
    let username = profile_name.unwrap_or_else(|| format!("user_{}", auth_token.user_id));
    ApiResponse::success(LoginUserInfo {
        roles: auth_token.roles.to_vec(),
        permissions: auth_token.permissions.to_vec(),
        username,
        id: auth_token.user_id,
    })
}
/// Lists the external OAuth providers advertised to clients.
///
/// The list is fixed (Google, GitHub, Microsoft); each provider's `auth_url`
/// is `/oauth/{name}`. The application state is currently not consulted.
pub async fn list_providers(State(_state): State<ApiState>) -> ApiResponse<Vec<ProviderInfo>> {
    const PROVIDERS: [(&str, &str); 3] = [
        ("google", "Google"),
        ("github", "GitHub"),
        ("microsoft", "Microsoft"),
    ];
    let providers: Vec<ProviderInfo> = PROVIDERS
        .iter()
        .map(|&(name, display_name)| ProviderInfo {
            name: name.to_string(),
            display_name: display_name.to_string(),
            auth_url: format!("/oauth/{}", name),
        })
        .collect();
    ApiResponse::success(providers)
}
/// An external login provider entry returned by the provider listing.
#[derive(Debug, Serialize)]
pub struct ProviderInfo {
    pub name: String,
    pub display_name: String,
    pub auth_url: String,
}

/// Request body for account registration.
///
/// `Debug` is implemented by hand so the plaintext password is never printed.
#[derive(Deserialize)]
pub struct RegisterRequest {
    pub username: String,
    pub email: String,
    pub password: String,
}

impl std::fmt::Debug for RegisterRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RegisterRequest")
            .field("username", &self.username)
            .field("email", &self.email)
            .field("password", &"<redacted>")
            .finish()
    }
}

/// Response body of a successful registration.
#[derive(Debug, Serialize)]
pub struct RegisterResponse {
    pub user_id: String,
    pub username: String,
    pub email: String,
}
/// Registers a new user account.
///
/// Validation runs in this order:
/// 1. required fields;
/// 2. username, password and email format (all local checks);
/// 3. the password breach lookup.
///
/// The breach lookup fails open: if it errors, registration continues and the
/// error is logged.
///
/// Uniqueness is checked against `user:credentials:{username}` and
/// `user:email:{email}`. The checks and the writes are separate storage calls,
/// so two concurrent registrations can race.
///
/// Records are written in this order; if a later write fails, the earlier
/// keys are deleted (best-effort):
/// * `user:credentials:{username}`: credential record with the password hash;
/// * `user:email:{email}`: email → user id;
/// * `user:{user_id}`: canonical record with roles `["user"]` and `active: true`;
/// * `user:username:{username}`: username → user id.
///
/// Finally the id is appended to the `users:index` JSON list. This step is
/// best-effort: failures are only logged.
pub async fn register(
    State(state): State<ApiState>,
    Json(req): Json<RegisterRequest>,
) -> ApiResponse<RegisterResponse> {
    if req.username.is_empty() || req.password.is_empty() || req.email.is_empty() {
        return ApiResponse::validation_error_typed("Username, password, and email are required");
    }
    if let Err(e) = crate::utils::validation::validate_username(&req.username) {
        return ApiResponse::validation_error_typed(format!("{e}"));
    }
    if let Err(e) = crate::utils::validation::validate_password(&req.password) {
        return ApiResponse::validation_error_typed(format!("{e}"));
    }
    // Cheap local validation first, so malformed requests never trigger the
    // (presumably remote) breach lookup below.
    if let Err(e) = crate::utils::validation::validate_email(&req.email) {
        return ApiResponse::validation_error_typed(format!("{e}"));
    }
    match crate::utils::breach_check::is_password_breached(&req.password).await {
        Ok(true) => {
            return ApiResponse::validation_error_typed(
                "This password has appeared in a known data breach. Please choose a different password.",
            );
        }
        Ok(false) => {}
        Err(e) => {
            // Fail open: an unavailable breach service must not block sign-up,
            // but the skipped check should be visible in logs.
            tracing::warn!(
                "Password breach check failed; continuing registration: {:?}",
                e
            );
        }
    }

    let username_key = format!("user:credentials:{}", req.username);
    match state.auth_framework.storage().get_kv(&username_key).await {
        Ok(Some(_)) => {
            return ApiResponse::error_typed(
                "CONFLICT",
                "An account with the provided details already exists",
            );
        }
        Err(e) => {
            tracing::error!("Storage error checking username: {}", e);
            return ApiResponse::internal_error_typed();
        }
        Ok(None) => {}
    }
    let email_key = format!("user:email:{}", req.email);
    match state.auth_framework.storage().get_kv(&email_key).await {
        Ok(Some(_)) => {
            return ApiResponse::error_typed(
                "CONFLICT",
                "An account with the provided details already exists",
            );
        }
        Err(e) => {
            tracing::error!("Storage error checking email: {}", e);
            return ApiResponse::internal_error_typed();
        }
        Ok(None) => {}
    }

    let password_hash = match crate::utils::password::hash_password(&req.password) {
        Ok(hash) => hash,
        Err(e) => {
            tracing::error!("Password hashing failed: {:?}", e);
            return ApiResponse::error_typed("REGISTRATION_FAILED", "Failed to process password");
        }
    };
    let user_id = format!("user_{}", uuid::Uuid::new_v4().as_simple());
    let created_at = chrono::Utc::now().to_rfc3339();

    let user_data = serde_json::json!({
        "user_id": user_id,
        "username": req.username,
        "email": req.email,
        "password_hash": password_hash,
        "created_at": created_at,
    });
    let user_data_bytes = user_data.to_string().into_bytes();
    if let Err(e) = state
        .auth_framework
        .storage()
        .store_kv(&username_key, &user_data_bytes, None)
        .await
    {
        tracing::error!("User registration storage failed: {:?}", e);
        return ApiResponse::error_typed("REGISTRATION_FAILED", "Failed to create user account");
    }

    if let Err(e) = state
        .auth_framework
        .storage()
        .store_kv(&email_key, user_id.as_bytes(), None)
        .await
    {
        tracing::error!("Email mapping storage failed: {:?}", e);
        rollback_registration_keys(&state, &[username_key.as_str()]).await;
        return ApiResponse::error_typed("REGISTRATION_FAILED", "Failed to create user account");
    }

    let canonical_user_data = serde_json::json!({
        "user_id": user_id,
        "username": req.username,
        "email": req.email,
        "password_hash": password_hash,
        "roles": ["user"],
        "active": true,
        "created_at": created_at,
    });
    let canonical_key = format!("user:{}", user_id);
    if let Err(e) = state
        .auth_framework
        .storage()
        .store_kv(
            &canonical_key,
            canonical_user_data.to_string().as_bytes(),
            None,
        )
        .await
    {
        tracing::error!("Canonical user record storage failed: {:?}", e);
        rollback_registration_keys(&state, &[username_key.as_str(), email_key.as_str()]).await;
        return ApiResponse::error_typed("REGISTRATION_FAILED", "Failed to create user account");
    }

    let username_id_key = format!("user:username:{}", req.username);
    if let Err(e) = state
        .auth_framework
        .storage()
        .store_kv(&username_id_key, user_id.as_bytes(), None)
        .await
    {
        tracing::error!("Username-id mapping storage failed: {:?}", e);
        rollback_registration_keys(
            &state,
            &[
                username_key.as_str(),
                email_key.as_str(),
                canonical_key.as_str(),
            ],
        )
        .await;
        return ApiResponse::error_typed("REGISTRATION_FAILED", "Failed to create user account");
    }

    // Best-effort, non-atomic append to the global user index.
    let index_key = "users:index";
    let mut ids: Vec<String> = match state.auth_framework.storage().get_kv(index_key).await {
        Ok(Some(bytes)) => serde_json::from_slice(&bytes).unwrap_or_default(),
        _ => vec![],
    };
    ids.push(user_id.clone());
    if let Ok(idx_json) = serde_json::to_vec(&ids) {
        if let Err(e) = state
            .auth_framework
            .storage()
            .store_kv(index_key, &idx_json, None)
            .await
        {
            tracing::warn!("Failed to update user index after registration: {}", e);
        }
    }

    tracing::info!("New user registered: {} ({})", req.username, user_id);
    ApiResponse::success(RegisterResponse {
        user_id,
        username: req.username,
        email: req.email,
    })
}

/// Deletes keys written by a registration that failed part-way, in order.
///
/// Deletion is best-effort: failures are logged and otherwise ignored.
async fn rollback_registration_keys(state: &ApiState, keys: &[&str]) {
    for &key in keys {
        if let Err(e) = state.auth_framework.storage().delete_kv(key).await {
            tracing::warn!("Failed to roll back registration key {}: {:?}", key, e);
        }
    }
}
/// Response body carrying a freshly issued API key.
///
/// `Debug` is implemented by hand so the key itself is never printed.
#[derive(Serialize)]
pub struct CreateApiKeyResponse {
    pub api_key: String,
    pub token_type: String,
}

impl std::fmt::Debug for CreateApiKeyResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CreateApiKeyResponse")
            .field("api_key", &"<redacted>")
            .field("token_type", &self.token_type)
            .finish()
    }
}
pub async fn create_api_key(
State(state): State<ApiState>,
headers: HeaderMap,
) -> ApiResponse<CreateApiKeyResponse> {
let token = match crate::api::extract_bearer_token(&headers) {
Some(t) => t,
None => return ApiResponse::unauthorized_typed(),
};
let auth_token = match crate::api::validate_api_token(&state.auth_framework, &token).await {
Ok(t) => t,
Err(_) => return ApiResponse::unauthorized_typed(),
};
match state
.auth_framework
.create_api_key(&auth_token.user_id, None)
.await
{
Ok(api_key) => ApiResponse::success(CreateApiKeyResponse {
api_key,
token_type: "ApiKey".to_string(),
}),
Err(e) => {
tracing::error!(
"Failed to create API key for user {}: {}",
auth_token.user_id,
e
);
ApiResponse::internal_error_typed()
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderMap;

    /// Builds a header map from `(name, value)` pairs, panicking on invalid values.
    fn header_map(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    #[test]
    fn test_login_risk_low_normal_request() {
        let headers = header_map(&[("user-agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64)")]);
        let (level, warnings) = login_risk_level(&headers);
        assert_eq!(level, "low");
        assert!(warnings.is_empty());
    }

    #[test]
    fn test_login_risk_high_no_user_agent() {
        let (level, warnings) = login_risk_level(&header_map(&[]));
        assert_eq!(level, "high");
        assert!(!warnings.is_empty());
    }

    #[test]
    fn test_login_risk_tor_browser() {
        let headers = header_map(&[("user-agent", "Mozilla/5.0 (Tor Browser)")]);
        let (level, warnings) = login_risk_level(&headers);
        assert!(matches!(level, "high" | "critical"));
        assert!(warnings.iter().any(|w| w.contains("Tor")));
    }

    #[test]
    fn test_login_risk_proxy_hops() {
        let headers = header_map(&[
            ("user-agent", "Mozilla/5.0"),
            ("x-forwarded-for", "1.2.3.4, 5.6.7.8"),
        ]);
        let (level, warnings) = login_risk_level(&headers);
        assert_eq!(level, "medium");
        assert!(warnings.iter().any(|w| w.contains("proxy")));
    }

    #[test]
    fn test_login_request_deserialization() {
        let parsed: LoginRequest =
            serde_json::from_str(r#"{"username":"alice","password":"secret"}"#).unwrap();
        assert_eq!(parsed.username, "alice");
        assert_eq!(parsed.password, "secret");
        assert!(!parsed.remember_me);
        assert!(parsed.challenge_id.is_none());
        assert!(parsed.mfa_code.is_none());
    }

    #[test]
    fn test_login_response_serialization() {
        let user = LoginUserInfo {
            id: "uid".into(),
            username: "alice".into(),
            roles: vec!["user".into()],
            permissions: vec![],
        };
        let response = LoginResponse {
            access_token: "at".into(),
            refresh_token: "rt".into(),
            token_type: "Bearer".into(),
            expires_in: 3600,
            user,
            login_risk_level: "low".into(),
            security_warnings: vec![],
        };
        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(value["token_type"], "Bearer");
        assert_eq!(value["expires_in"], 3600);
        assert_eq!(value["user"]["username"], "alice");
    }

    #[test]
    fn test_register_request_deserialization() {
        let parsed: RegisterRequest = serde_json::from_str(
            r#"{"username":"bob","password":"StrongP@ss1","email":"bob@example.com"}"#,
        )
        .unwrap();
        assert_eq!(parsed.username, "bob");
        assert_eq!(parsed.email, "bob@example.com");
    }

    #[test]
    fn test_refresh_request_deserialization() {
        let parsed: RefreshRequest =
            serde_json::from_str(r#"{"refresh_token":"some_token"}"#).unwrap();
        assert_eq!(parsed.refresh_token, "some_token");
    }
}