use axum::{
extract::State,
http::{header, HeaderMap, HeaderValue},
response::IntoResponse,
Json,
};
use serde::Deserialize;
use std::sync::Arc;
use super::call_logout_callback_with_timeout;
use crate::callback::AuthCallback;
use crate::errors::AppError;
use axum::http::StatusCode;
use crate::models::{MessageResponse, UserResponse};
use crate::repositories::AuditEventType;
use crate::services::EmailService;
use crate::utils::{
authenticate, build_logout_cookies, extract_access_token, extract_cookie, hash_refresh_token,
user_entity_to_auth_user,
};
use crate::AppState;
pub async fn logout<C: AuthCallback, E: EmailService>(
State(state): State<Arc<AppState<C, E>>>,
headers: HeaderMap,
) -> Result<impl IntoResponse, AppError> {
let token = extract_access_token(&headers, &state.config.cookie.access_cookie_name);
let claims = token.and_then(|token| state.jwt_service.validate_access_token(&token).ok());
if claims.is_none() && !state.config.cookie.enabled {
return Err(AppError::InvalidToken);
}
if let Some(claims) = claims {
if let Some(session) = state.session_repo.find_by_id(claims.sid).await? {
if session.user_id == claims.sub {
state
.session_repo
.revoke_with_reason(claims.sid, "logout")
.await?;
call_logout_callback_with_timeout(&state.callback, &claims.sub.to_string()).await;
if let Err(e) = state
.audit_service
.log_user_event(AuditEventType::UserLogout, claims.sub, Some(&headers))
.await
{
tracing::warn!(error = %e, user_id = %claims.sub, "Failed to log user logout audit event");
}
}
}
} else if state.config.cookie.enabled {
if let Some(refresh_token) =
extract_cookie(&headers, &state.config.cookie.refresh_cookie_name)
{
if let Some(session) = state
.session_repo
.find_by_refresh_token(&hash_refresh_token(
&refresh_token,
&state.config.jwt.secret,
))
.await?
{
state
.session_repo
.revoke_with_reason(session.id, "logout")
.await?;
}
}
}
let message = MessageResponse {
message: "Logged out successfully".to_string(),
};
if state.config.cookie.enabled {
let cookies = build_logout_cookies(&state.config.cookie);
let mut resp = Json(message).into_response();
let headers_mut = resp.headers_mut();
for cookie in cookies {
if let Ok(value) = HeaderValue::from_str(&cookie) {
headers_mut.append(header::SET_COOKIE, value);
}
}
Ok(resp)
} else {
Ok(Json(message).into_response())
}
}
pub async fn logout_all<C: AuthCallback, E: EmailService>(
State(state): State<Arc<AppState<C, E>>>,
headers: HeaderMap,
) -> Result<impl IntoResponse, AppError> {
let token = extract_access_token(&headers, &state.config.cookie.access_cookie_name);
let claims = token.and_then(|token| state.jwt_service.validate_access_token(&token).ok());
if claims.is_none() && !state.config.cookie.enabled {
return Err(AppError::InvalidToken);
}
if let Some(claims) = claims {
if let Some(session) = state.session_repo.find_by_id(claims.sid).await? {
if session.user_id == claims.sub {
state
.session_repo
.revoke_all_for_user_with_reason(claims.sub, "logout_all")
.await?;
call_logout_callback_with_timeout(&state.callback, &claims.sub.to_string()).await;
if let Err(e) = state
.audit_service
.log_user_event(AuditEventType::UserLogoutAll, claims.sub, Some(&headers))
.await
{
tracing::warn!(error = %e, user_id = %claims.sub, "Failed to log user logout-all audit event");
}
}
}
} else if state.config.cookie.enabled {
if let Some(refresh_token) =
extract_cookie(&headers, &state.config.cookie.refresh_cookie_name)
{
if let Some(session) = state
.session_repo
.find_by_refresh_token(&hash_refresh_token(
&refresh_token,
&state.config.jwt.secret,
))
.await?
{
state
.session_repo
.revoke_all_for_user_with_reason(session.user_id, "logout_all")
.await?;
}
}
}
let message = MessageResponse {
message: "Logged out from all devices successfully".to_string(),
};
if state.config.cookie.enabled {
let cookies = build_logout_cookies(&state.config.cookie);
let mut resp = Json(message).into_response();
let headers_mut = resp.headers_mut();
for cookie in cookies {
if let Ok(value) = HeaderValue::from_str(&cookie) {
headers_mut.append(header::SET_COOKIE, value);
}
}
Ok(resp)
} else {
Ok(Json(message).into_response())
}
}
/// Returns the profile of the currently authenticated user.
///
/// # Errors
///
/// Propagates failures from `authenticate` and from the user repository. Returns
/// [`AppError::InvalidToken`] when no user record matches the authenticated id.
pub async fn get_user<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
) -> Result<impl IntoResponse, AppError> {
    let auth = authenticate(&state, &headers).await?;
    let user = match state.user_repo.find_by_id(auth.user_id).await? {
        Some(found) => found,
        // The credentials were accepted but no matching user record exists.
        None => return Err(AppError::InvalidToken),
    };
    let body = UserResponse {
        user: user_entity_to_auth_user(&user),
    };
    Ok(Json(body))
}
/// JSON request body for `update_profile`.
///
/// Keys are deserialized using camelCase (`rename_all = "camelCase"`). Every field is
/// optional: a missing field (or JSON `null`) leaves the corresponding profile attribute
/// unchanged.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UpdateProfileRequest {
    /// New display name; trimmed and validated by `update_profile`.
    pub name: Option<String>,
    /// New username; trimmed and lower-cased before validation by `update_profile`.
    pub username: Option<String>,
    /// New picture URL; an empty (or whitespace-only) string clears the picture.
    pub picture: Option<String>,
}
pub async fn update_profile<C: AuthCallback, E: EmailService>(
State(state): State<Arc<AppState<C, E>>>,
headers: HeaderMap,
Json(req): Json<UpdateProfileRequest>,
) -> Result<impl IntoResponse, AppError> {
let auth = authenticate(&state, &headers).await?;
let mut user = state
.user_repo
.find_by_id(auth.user_id)
.await?
.ok_or(AppError::InvalidToken)?;
if let Some(name) = req.name {
let trimmed = name.trim();
if trimmed.is_empty() {
return Err(AppError::Validation("Name cannot be empty".into()));
}
if trimmed.len() > 100 {
return Err(AppError::Validation(
"Name must be 100 characters or less".into(),
));
}
user.name = Some(trimmed.to_string());
}
if let Some(username) = req.username {
let trimmed = username.trim().to_lowercase();
validate_username(&trimmed)?;
if state.user_repo.username_exists(&trimmed).await? {
return Err(AppError::Validation("Username is already taken".into()));
}
user.username = Some(trimmed);
}
if let Some(picture) = req.picture {
let trimmed = picture.trim();
if !trimmed.is_empty() {
if !trimmed.starts_with("https://") && !trimmed.starts_with("http://") {
return Err(AppError::Validation("Picture must be a valid URL".into()));
}
if trimmed.len() > 2048 {
return Err(AppError::Validation(
"Picture URL must be 2048 characters or less".into(),
));
}
user.picture = Some(trimmed.to_string());
} else {
user.picture = None;
}
}
let updated_user = state.user_repo.update(user).await?;
let _ = state
.audit_service
.log_user_event(
AuditEventType::UserProfileUpdated,
auth.user_id,
Some(&headers),
)
.await;
Ok(Json(UserResponse {
user: user_entity_to_auth_user(&updated_user),
}))
}
/// Usernames that regular accounts may not claim.
///
/// Compared exactly against the candidate, which `update_profile` has already lower-cased
/// before calling `validate_username`.
const RESERVED_USERNAMES: &[&str] = &[
    // Role-like names that could be mistaken for official accounts.
    "admin",
    "system",
    "support",
    "help",
    "root",
    "moderator",
    "mod",
    "staff",
    // NOTE(review): presumably the product/brand name — confirm.
    "cedros",
];
/// Checks that `username` is acceptable as an account handle.
///
/// The caller passes an already trimmed, lower-cased value. The rules are checked in
/// order, and the first violation is reported:
/// 1. 3 to 30 characters long;
/// 2. only ASCII lowercase letters, ASCII digits and `_`;
/// 3. must not start or end with `_`;
/// 4. must not be listed in [`RESERVED_USERNAMES`].
///
/// # Errors
///
/// Returns [`AppError::Validation`] with a message describing the first failed rule.
fn validate_username(username: &str) -> Result<(), AppError> {
    // Count characters rather than bytes so the length messages are accurate. Non-ASCII
    // input then gets the charset error below instead of a misleading "too long" error.
    // Valid usernames are ASCII, so byte and character counts agree for them.
    let char_count = username.chars().count();
    if char_count < 3 {
        return Err(AppError::Validation(
            "Username must be at least 3 characters".into(),
        ));
    }
    if char_count > 30 {
        return Err(AppError::Validation(
            "Username must be 30 characters or less".into(),
        ));
    }
    if !username
        .chars()
        .all(|c| c.is_ascii_lowercase() || c.is_ascii_digit() || c == '_')
    {
        return Err(AppError::Validation(
            "Username may only contain lowercase letters, numbers, and underscores".into(),
        ));
    }
    if username.starts_with('_') || username.ends_with('_') {
        return Err(AppError::Validation(
            "Username cannot start or end with an underscore".into(),
        ));
    }
    if RESERVED_USERNAMES.contains(&username) {
        return Err(AppError::Validation("This username is reserved".into()));
    }
    Ok(())
}
/// Records that the authenticated user has finished the welcome flow.
///
/// Responds with `204 No Content` and an empty body on success.
///
/// # Errors
///
/// Propagates failures from `authenticate` and from the user repository.
pub async fn welcome_completed<C: AuthCallback, E: EmailService>(
    State(state): State<Arc<AppState<C, E>>>,
    headers: HeaderMap,
) -> Result<impl IntoResponse, AppError> {
    let user_id = authenticate(&state, &headers).await?.user_id;
    state.user_repo.set_welcome_completed(user_id).await?;
    Ok(StatusCode::NO_CONTENT)
}