use axum::{
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use crate::auth::{user::User, Authenticated};
/// JSON request body accepted by [`assign_role`].
#[derive(Debug, Serialize, Deserialize)]
pub struct AssignRoleRequest {
/// Name of the role to grant. [`assign_role`] only accepts `"user"`,
/// `"moderator"` or `"admin"`; any other value is rejected there.
pub role: String,
}
/// JSON response body returned by the role handlers in this file
/// ([`get_user_roles`], [`assign_role`] and [`remove_role`]).
#[derive(Debug, Serialize, Deserialize)]
pub struct RoleResponse {
/// Id of the user whose roles are being reported.
pub user_id: i64,
/// The user's role names as held by the handler when it built the response.
pub roles: Vec<String>,
/// Human-readable description of the outcome.
pub message: String,
}
/// Returns the role list of the user identified by `user_id`.
///
/// The authenticated caller must hold the `"admin"` role. On success the
/// handler answers `200 OK` with a [`RoleResponse`] JSON body.
///
/// # Errors
///
/// * `403 FORBIDDEN` - the caller's own roles do not include `"admin"`.
/// * `404 NOT_FOUND` - `User::find_by_id` returned an error. Every error from
///   the lookup is mapped to this status and logged at error level.
pub async fn get_user_roles(
    State(db): State<PgPool>,
    Authenticated(admin): Authenticated<User>,
    Path(user_id): Path<i64>,
) -> Result<Response, StatusCode> {
    // Authorisation gate: only admins may inspect another user's roles.
    let caller_is_admin = admin.roles.iter().any(|r| r == "admin");
    if !caller_is_admin {
        tracing::warn!(
            admin_id = admin.id,
            user_id = user_id,
            "Non-admin attempted to view user roles"
        );
        return Err(StatusCode::FORBIDDEN);
    }

    let target = User::find_by_id(user_id, &db).await.map_err(|e| {
        tracing::error!(error = ?e, user_id = user_id, "Failed to fetch user");
        StatusCode::NOT_FOUND
    })?;

    let body = Json(RoleResponse {
        user_id: target.id,
        roles: target.roles,
        message: "Roles retrieved successfully".to_string(),
    });
    Ok((StatusCode::OK, body).into_response())
}
#[allow(clippy::cognitive_complexity)] pub async fn assign_role(
State(db): State<PgPool>,
Authenticated(admin): Authenticated<User>,
Path(user_id): Path<i64>,
Json(request): Json<AssignRoleRequest>,
) -> Result<Response, StatusCode> {
if !admin.roles.contains(&"admin".to_string()) {
tracing::warn!(
admin_id = admin.id,
user_id = user_id,
role = %request.role,
"Non-admin attempted to assign role"
);
return Err(StatusCode::FORBIDDEN);
}
let valid_roles = ["user", "moderator", "admin"];
if !valid_roles.contains(&request.role.as_str()) {
tracing::warn!(role = %request.role, "Invalid role name");
return Err(StatusCode::BAD_REQUEST);
}
let mut user = match User::find_by_id(user_id, &db).await {
Ok(user) => user,
Err(e) => {
tracing::error!(error = ?e, user_id = user_id, "Failed to fetch user");
return Err(StatusCode::NOT_FOUND);
}
};
if !user.roles.contains(&request.role) {
user.roles.push(request.role.clone());
match sqlx::query(
r"UPDATE users SET roles = $1 WHERE id = $2"
)
.bind(&user.roles)
.bind(user.id)
.execute(&db)
.await
{
Ok(_) => {
tracing::info!(
admin_id = admin.id,
user_id = user.id,
role = %request.role,
"Role assigned successfully"
);
}
Err(e) => {
tracing::error!(error = ?e, user_id = user.id, "Failed to update user roles");
return Err(StatusCode::INTERNAL_SERVER_ERROR);
}
}
}
let response = RoleResponse {
user_id: user.id,
roles: user.roles.clone(),
message: format!("Role '{}' assigned successfully", request.role),
};
Ok((StatusCode::OK, Json(response)).into_response())
}
/// Revokes `role` from the user identified by `user_id`.
///
/// The authenticated caller must hold the `"admin"` role. The `"user"` role
/// is treated as mandatory and cannot be removed. If the target user does not
/// currently hold `role`, nothing is written to the database and no
/// "removed" audit event is logged; the handler still answers with the same
/// success response, mirroring how [`assign_role`] treats an already-present
/// role. On success the handler answers `200 OK` with a [`RoleResponse`] JSON
/// body listing the remaining roles.
///
/// # Errors
///
/// * `403 FORBIDDEN` - the caller's own roles do not include `"admin"`.
/// * `404 NOT_FOUND` - `User::find_by_id` returned an error.
/// * `400 BAD_REQUEST` - `role` is `"user"`. This is checked after the
///   lookup, so an unknown user is reported as `404` first.
/// * `500 INTERNAL_SERVER_ERROR` - the `UPDATE` statement failed.
// NOTE(review): presumably this lint is silenced because the `tracing` macros
// expand into branch-heavy code; confirm it still fires before removing.
#[allow(clippy::cognitive_complexity)]
pub async fn remove_role(
    State(db): State<PgPool>,
    Authenticated(admin): Authenticated<User>,
    Path((user_id, role)): Path<(i64, String)>,
) -> Result<Response, StatusCode> {
    // Authorisation gate: only admins may revoke roles.
    if !admin.roles.iter().any(|r| r == "admin") {
        tracing::warn!(
            admin_id = admin.id,
            user_id = user_id,
            role = %role,
            "Non-admin attempted to remove role"
        );
        return Err(StatusCode::FORBIDDEN);
    }

    let mut user = User::find_by_id(user_id, &db).await.map_err(|e| {
        tracing::error!(error = ?e, user_id = user_id, "Failed to fetch user");
        StatusCode::NOT_FOUND
    })?;

    if role == "user" {
        tracing::warn!(
            admin_id = admin.id,
            user_id = user.id,
            "Attempted to remove required 'user' role"
        );
        return Err(StatusCode::BAD_REQUEST);
    }

    // Only touch the database (and emit the audit event) when the role list
    // actually changes; removing a role the user never had is a no-op.
    if user.roles.contains(&role) {
        user.roles.retain(|r| r != &role);

        sqlx::query(r"UPDATE users SET roles = $1 WHERE id = $2")
            .bind(&user.roles)
            .bind(user.id)
            .execute(&db)
            .await
            .map_err(|e| {
                tracing::error!(error = ?e, user_id = user.id, "Failed to update user roles");
                StatusCode::INTERNAL_SERVER_ERROR
            })?;

        tracing::info!(
            admin_id = admin.id,
            user_id = user.id,
            role = %role,
            "Role removed successfully"
        );
    }

    let response = RoleResponse {
        user_id: user.id,
        // `user` is not used afterwards, so the list is moved rather than cloned.
        roles: user.roles,
        message: format!("Role '{role}' removed successfully"),
    };
    Ok((StatusCode::OK, Json(response)).into_response())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The serialized request must carry the role name.
    #[test]
    fn test_assign_role_request_serialization() {
        let payload = AssignRoleRequest {
            role: String::from("moderator"),
        };

        let encoded = serde_json::to_string(&payload).unwrap();

        assert!(encoded.contains("moderator"));
    }

    /// The serialized response must expose the numeric user id and the roles.
    #[test]
    fn test_role_response_serialization() {
        let roles = vec![String::from("user"), String::from("moderator")];
        let payload = RoleResponse {
            user_id: 123,
            roles,
            message: String::from("Success"),
        };

        let encoded = serde_json::to_string(&payload).unwrap();

        assert!(encoded.contains("\"user_id\":123"));
        assert!(encoded.contains("moderator"));
    }
}