#[cfg(feature = "portal")]
use axum::{
extract::{Json, State},
http::StatusCode,
response::IntoResponse,
};
#[cfg(feature = "portal")]
use chrono::{Duration, Utc};
#[cfg(feature = "portal")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "portal")]
use sha2::{Digest, Sha256};
#[cfg(feature = "portal")]
use std::sync::Arc;
#[cfg(feature = "portal")]
use uuid::Uuid;
#[cfg(feature = "portal")]
use crate::portal::auth::{AuthConfig, AuthError, AuthService, TokenPair};
#[cfg(feature = "portal")]
use crate::portal::db::{
models::{CreateSession, CreateUser},
pool::DatabasePool,
queries::{SessionRepository, UserRepository},
DbError,
};
/// Shared state handed to every portal handler through axum's `State` extractor.
///
/// The auth service sits behind an `Arc`, so cloning the state only bumps a
/// reference count for it. The cost of cloning `db` is whatever
/// `DatabasePool`'s own `Clone` impl does.
#[cfg(feature = "portal")]
#[derive(Clone)]
pub struct PortalState {
    /// Pool from which the user and session repositories are built.
    pub db: DatabasePool,
    /// Service used to issue and validate tokens.
    pub auth: Arc<AuthService>,
}

#[cfg(feature = "portal")]
impl PortalState {
    /// Creates the state from a database pool and the configuration used to
    /// construct the shared [`AuthService`].
    pub fn new(db: DatabasePool, auth_config: AuthConfig) -> Self {
        let auth = Arc::new(AuthService::new(auth_config));
        Self { db, auth }
    }
}
/// JSON body accepted by [`register`].
///
/// `Debug` is implemented by hand rather than derived so the plaintext
/// password can never end up in logs or panic messages via `{:?}`.
#[cfg(feature = "portal")]
#[derive(Deserialize)]
pub struct RegisterRequest {
    /// Address to register; the handler lowercases it before storing.
    pub email: String,
    /// Plaintext password; checked against the password policy, then hashed.
    pub password: String,
    /// Must be `true`; registration is refused otherwise.
    pub gdpr_consent: bool,
}

#[cfg(feature = "portal")]
impl std::fmt::Debug for RegisterRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RegisterRequest")
            .field("email", &self.email)
            .field("password", &"<redacted>")
            .field("gdpr_consent", &self.gdpr_consent)
            .finish()
    }
}
/// JSON body returned by [`register`] on both success and failure.
///
/// Field order is the serialized key order.
#[cfg(feature = "portal")]
#[derive(Serialize, Debug)]
pub struct RegisterResponse {
    /// `true` only when the account was created.
    pub success: bool,
    /// Id of the newly created user; `None` on any failure.
    pub user_id: Option<String>,
    /// Human-readable outcome suitable for showing to the client.
    pub message: String,
}
/// JSON body accepted by [`login`].
///
/// `Debug` is implemented by hand rather than derived so the plaintext
/// password can never end up in logs or panic messages via `{:?}`.
#[cfg(feature = "portal")]
#[derive(Deserialize)]
pub struct LoginRequest {
    /// Account email; the handler lowercases it before the lookup.
    pub email: String,
    /// Plaintext password to verify against the stored hash.
    pub password: String,
}

#[cfg(feature = "portal")]
impl std::fmt::Debug for LoginRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginRequest")
            .field("email", &self.email)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// JSON body returned by [`login`] and [`refresh_token`].
///
/// `Debug` is implemented by hand rather than derived so the bearer tokens are
/// never printed via `{:?}`; only whether tokens are present is shown.
#[cfg(feature = "portal")]
#[derive(Serialize)]
pub struct LoginResponse {
    /// `true` only when tokens were issued.
    pub success: bool,
    /// Newly issued access/refresh tokens; `None` on failure.
    pub tokens: Option<TokenPair>,
    /// Id of the authenticated user; `None` on failure.
    pub user_id: Option<String>,
    /// Human-readable outcome suitable for showing to the client.
    pub message: String,
}

#[cfg(feature = "portal")]
impl std::fmt::Debug for LoginResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginResponse")
            .field("success", &self.success)
            .field("tokens", &self.tokens.as_ref().map(|_| "<redacted>"))
            .field("user_id", &self.user_id)
            .field("message", &self.message)
            .finish()
    }
}
/// JSON body carrying a refresh token; accepted by [`refresh_token`] and [`logout`].
///
/// `Debug` is implemented by hand rather than derived so the token itself is
/// never printed via `{:?}`.
#[cfg(feature = "portal")]
#[derive(Deserialize)]
pub struct RefreshRequest {
    /// The raw refresh token previously issued to the client.
    pub refresh_token: String,
}

#[cfg(feature = "portal")]
impl std::fmt::Debug for RefreshRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshRequest")
            .field("refresh_token", &"<redacted>")
            .finish()
    }
}
/// Returns the SHA-256 digest of `token` as a lowercase hex string.
///
/// Sessions are created with, and looked up by, this digest rather than the
/// raw refresh token.
#[cfg(feature = "portal")]
fn hash_refresh_token(token: &str) -> String {
    let digest = Sha256::digest(token.as_bytes());
    format!("{:x}", digest)
}
/// Checks a candidate password against the portal's password policy.
///
/// The policy requires:
/// - at least 12 characters, counted as Unicode scalar values rather than bytes;
/// - at least one ASCII uppercase letter;
/// - at least one ASCII lowercase letter;
/// - at least one ASCII digit.
///
/// # Errors
/// Returns a user-facing message naming the first rule that fails.
#[cfg(feature = "portal")]
fn validate_password(password: &str) -> Result<(), &'static str> {
    const MIN_CHARS: usize = 12;

    // `str::len` counts UTF-8 bytes, so a 10-character password of accented
    // letters would pass a byte check; count characters instead.
    if password.chars().count() < MIN_CHARS {
        return Err("Password must be at least 12 characters");
    }
    if !password.chars().any(|c| c.is_ascii_uppercase()) {
        return Err("Password must contain at least one uppercase letter");
    }
    if !password.chars().any(|c| c.is_ascii_lowercase()) {
        return Err("Password must contain at least one lowercase letter");
    }
    if !password.chars().any(|c| c.is_ascii_digit()) {
        return Err("Password must contain at least one digit");
    }
    Ok(())
}
/// Performs a lightweight syntactic check of an email address.
///
/// An address is accepted when it meets all of these conditions:
/// - it has the form `local@domain`, where the text after the *last* `@` is
///   the domain;
/// - the local part is non-empty;
/// - the domain contains a `.` and neither starts nor ends with one;
/// - the address contains no whitespace;
/// - it is at most 255 bytes long.
///
/// This is deliberately not full RFC 5322 validation.
///
/// # Errors
/// - `"Invalid email format"` if the shape check fails.
/// - `"Email too long"` if the address exceeds 255 bytes.
#[cfg(feature = "portal")]
fn validate_email(email: &str) -> Result<(), &'static str> {
    const MAX_EMAIL_BYTES: usize = 255;

    let well_formed = match email.rsplit_once('@') {
        Some((local, domain)) => {
            !local.is_empty()
                && domain.contains('.')
                && !domain.starts_with('.')
                && !domain.ends_with('.')
                && !email.chars().any(char::is_whitespace)
        }
        None => false,
    };
    if !well_formed {
        return Err("Invalid email format");
    }
    if email.len() > MAX_EMAIL_BYTES {
        return Err("Email too long");
    }
    Ok(())
}
/// Creates a new portal account.
///
/// The request is validated before anything is written:
/// - GDPR consent must be given;
/// - the email must pass [`validate_email`];
/// - the password must pass [`validate_password`].
///
/// The email is lowercased once, and that normalized form is used for
/// validation, the duplicate check and the stored record. The password is
/// hashed with [`AuthService::hash_password`] before the user row is created.
///
/// # Responses
/// - `201 Created`: account created; the body carries the new user id.
/// - `400 Bad Request`: missing consent or invalid email/password.
/// - `409 Conflict`: the (lowercased) email is already registered.
/// - `500 Internal Server Error`: database or hashing failure.
#[cfg(feature = "portal")]
pub async fn register(
    State(state): State<PortalState>,
    Json(req): Json<RegisterRequest>,
) -> impl IntoResponse {
    /// Consent-text version recorded with every new account.
    const GDPR_CONSENT_VERSION: &str = "1.0";

    /// Builds a failed `RegisterResponse` with the given status and message.
    fn failure(status: StatusCode, message: &str) -> (StatusCode, Json<RegisterResponse>) {
        (
            status,
            Json(RegisterResponse {
                success: false,
                user_id: None,
                message: message.to_string(),
            }),
        )
    }

    if !req.gdpr_consent {
        return failure(StatusCode::BAD_REQUEST, "GDPR consent is required");
    }

    // Normalize once. The duplicate check must compare the same form that is
    // stored, otherwise "User@x.com" slips past `email_exists` when
    // "user@x.com" already exists.
    let email = req.email.to_lowercase();
    if let Err(e) = validate_email(&email) {
        return failure(StatusCode::BAD_REQUEST, e);
    }
    if let Err(e) = validate_password(&req.password) {
        return failure(StatusCode::BAD_REQUEST, e);
    }

    let user_repo = UserRepository::new(state.db.pool());
    match user_repo.email_exists(&email).await {
        Ok(true) => return failure(StatusCode::CONFLICT, "Email already registered"),
        Ok(false) => {}
        Err(e) => {
            tracing::error!("Database error checking email: {}", e);
            return failure(StatusCode::INTERNAL_SERVER_ERROR, "Registration failed");
        }
    }

    let password_hash = match AuthService::hash_password(&req.password) {
        Ok(hash) => hash,
        Err(e) => {
            tracing::error!("Password hashing error: {}", e);
            return failure(StatusCode::INTERNAL_SERVER_ERROR, "Registration failed");
        }
    };

    let create_user = CreateUser {
        email,
        password_hash,
        gdpr_consent_version: Some(GDPR_CONSENT_VERSION.to_string()),
    };

    match user_repo.create(create_user).await {
        Ok(user) => {
            // Log the opaque id, not the email address, to keep PII out of logs.
            tracing::info!("User registered: {}", user.id);
            (
                StatusCode::CREATED,
                Json(RegisterResponse {
                    success: true,
                    user_id: Some(user.id.to_string()),
                    message: "Registration successful. Please verify your email.".to_string(),
                }),
            )
        }
        // A concurrent registration can win the race between `email_exists`
        // and `create`; the insert-time duplicate error covers that window.
        Err(DbError::DuplicateEntry(_)) => {
            failure(StatusCode::CONFLICT, "Email already registered")
        }
        Err(e) => {
            tracing::error!("Database error creating user: {}", e);
            failure(StatusCode::INTERNAL_SERVER_ERROR, "Registration failed")
        }
    }
}
/// Authenticates a user by email and password and opens a new session.
///
/// On success a token pair is generated and the SHA-256 digest of its refresh
/// token is stored as a session that expires after 7 days. Scopes are
/// `read` + `write` for users with a verified email and `read` otherwise.
///
/// Unknown emails, soft-deleted accounts and wrong passwords all get the same
/// `401` body, so the endpoint does not reveal which addresses have (or had)
/// an account.
///
/// # Responses
/// - `200 OK`: credentials valid and session persisted.
/// - `401 Unauthorized`: unknown email, deleted account, or wrong password.
/// - `500 Internal Server Error`: database, password-verification or
///   session-persistence failure.
#[cfg(feature = "portal")]
pub async fn login(
    State(state): State<PortalState>,
    Json(req): Json<LoginRequest>,
) -> impl IntoResponse {
    const INVALID_CREDENTIALS: &str = "Invalid email or password";
    const SESSION_TTL_DAYS: i64 = 7;

    /// Builds a failed `LoginResponse` with the given status and message.
    fn failure(status: StatusCode, message: &str) -> (StatusCode, Json<LoginResponse>) {
        (
            status,
            Json(LoginResponse {
                success: false,
                tokens: None,
                user_id: None,
                message: message.to_string(),
            }),
        )
    }

    let user_repo = UserRepository::new(state.db.pool());
    let session_repo = SessionRepository::new(state.db.pool());

    let user = match user_repo.find_by_email(&req.email.to_lowercase()).await {
        Ok(user) => user,
        Err(DbError::NotFound) => {
            // Burn a password hash so this path costs roughly as much as a
            // real verification. This is presumably to keep response timing
            // from revealing which emails exist; it assumes hash_password and
            // verify_password have similar cost.
            let _ = AuthService::hash_password("dummy_password");
            return failure(StatusCode::UNAUTHORIZED, INVALID_CREDENTIALS);
        }
        Err(e) => {
            tracing::error!("Database error finding user: {}", e);
            return failure(StatusCode::INTERNAL_SERVER_ERROR, "Login failed");
        }
    };

    // Treat soft-deleted accounts exactly like unknown ones. A distinct message
    // here would let anyone, without knowing the password, learn that an
    // address once had an account.
    if user.deleted_at.is_some() {
        let _ = AuthService::hash_password("dummy_password");
        return failure(StatusCode::UNAUTHORIZED, INVALID_CREDENTIALS);
    }

    match AuthService::verify_password(&req.password, &user.password_hash) {
        Ok(true) => {}
        Ok(false) => return failure(StatusCode::UNAUTHORIZED, INVALID_CREDENTIALS),
        Err(e) => {
            tracing::error!("Password verification error: {}", e);
            return failure(StatusCode::INTERNAL_SERVER_ERROR, "Login failed");
        }
    }

    // Only users with a verified email address are granted write access.
    let scopes = if user.email_verified_at.is_some() {
        vec!["read".to_string(), "write".to_string()]
    } else {
        vec!["read".to_string()]
    };
    let tokens = state
        .auth
        .generate_token_pair(&user.id.to_string(), &user.email, scopes);

    let create_session = CreateSession {
        user_id: user.id,
        refresh_token_hash: hash_refresh_token(&tokens.refresh_token),
        device_fingerprint: None,
        ip_address: None,
        user_agent: None,
        expires_at: Utc::now() + Duration::days(SESSION_TTL_DAYS),
    };
    // The refresh endpoint rejects tokens with no stored session. Returning
    // tokens after a failed insert would only defer the failure to the
    // client's first refresh.
    if let Err(e) = session_repo.create(create_session).await {
        tracing::error!("Failed to create session: {}", e);
        return failure(StatusCode::INTERNAL_SERVER_ERROR, "Login failed");
    }

    // Log the opaque id, not the email address, to keep PII out of logs.
    tracing::info!("User logged in: {}", user.id);
    (
        StatusCode::OK,
        Json(LoginResponse {
            success: true,
            tokens: Some(tokens),
            user_id: Some(user.id.to_string()),
            message: "Login successful".to_string(),
        }),
    )
}
/// Exchanges a refresh token for a new token pair, rotating the session.
///
/// The presented token must:
/// - validate with [`AuthService::validate_token`];
/// - carry the `Refresh` token type;
/// - match a stored session by its SHA-256 digest.
///
/// That session is revoked before anything is issued. If the revoke fails,
/// the request fails rather than leaving the old token usable alongside the
/// new one. The owning user must still exist and must not be soft-deleted.
/// A fresh session, inheriting the old one's device fingerprint, IP and user
/// agent, is then stored for the new refresh token, with a 7-day expiry.
///
/// NOTE(review): whether a revoked or expired session can still be found by
/// `find_by_token_hash` is decided by the repository, not here. Two
/// concurrent refreshes with the same token may both pass the lookup unless
/// `revoke` is conditional on the session still being active. Confirm.
///
/// # Responses
/// - `200 OK`: new token pair issued.
/// - `401 Unauthorized`: invalid/wrong-type token, unknown session, or a
///   missing/deleted user.
/// - `500 Internal Server Error`: database failure while looking up,
///   revoking or creating sessions, or while loading the user.
#[cfg(feature = "portal")]
pub async fn refresh_token(
    State(state): State<PortalState>,
    Json(req): Json<RefreshRequest>,
) -> impl IntoResponse {
    const SESSION_TTL_DAYS: i64 = 7;

    /// Builds a failed `LoginResponse` with the given status and message.
    fn failure(status: StatusCode, message: &str) -> (StatusCode, Json<LoginResponse>) {
        (
            status,
            Json(LoginResponse {
                success: false,
                tokens: None,
                user_id: None,
                message: message.to_string(),
            }),
        )
    }

    let session_repo = SessionRepository::new(state.db.pool());
    let user_repo = UserRepository::new(state.db.pool());

    let claims = match state.auth.validate_token(&req.refresh_token) {
        Ok(claims) => claims,
        Err(_) => return failure(StatusCode::UNAUTHORIZED, "Invalid refresh token"),
    };
    // An access token must not be usable to mint new tokens.
    if claims.token_type != crate::portal::auth::TokenType::Refresh {
        return failure(StatusCode::UNAUTHORIZED, "Invalid token type");
    }

    let token_hash = hash_refresh_token(&req.refresh_token);
    let session = match session_repo.find_by_token_hash(&token_hash).await {
        Ok(session) => session,
        Err(DbError::NotFound) => {
            return failure(StatusCode::UNAUTHORIZED, "Session not found or expired");
        }
        Err(e) => {
            tracing::error!("Database error finding session: {}", e);
            return failure(StatusCode::INTERNAL_SERVER_ERROR, "Token refresh failed");
        }
    };

    // Rotation: if the old session cannot be revoked, issuing new tokens would
    // leave both the old and new refresh tokens valid.
    if let Err(e) = session_repo.revoke(session.id).await {
        tracing::error!("Failed to revoke old session: {}", e);
        return failure(StatusCode::INTERNAL_SERVER_ERROR, "Token refresh failed");
    }

    let user = match user_repo.find_by_id(session.user_id).await {
        Ok(user) => user,
        Err(DbError::NotFound) => return failure(StatusCode::UNAUTHORIZED, "User not found"),
        Err(e) => {
            // A transient DB error is not evidence that the user is gone.
            tracing::error!("Database error finding user: {}", e);
            return failure(StatusCode::INTERNAL_SERVER_ERROR, "Token refresh failed");
        }
    };
    // Login refuses soft-deleted accounts; refresh must not keep them alive.
    if user.deleted_at.is_some() {
        return failure(StatusCode::UNAUTHORIZED, "User not found");
    }

    // Only users with a verified email address are granted write access.
    let scopes = if user.email_verified_at.is_some() {
        vec!["read".to_string(), "write".to_string()]
    } else {
        vec!["read".to_string()]
    };
    let tokens = state
        .auth
        .generate_token_pair(&user.id.to_string(), &user.email, scopes);

    let create_session = CreateSession {
        user_id: user.id,
        refresh_token_hash: hash_refresh_token(&tokens.refresh_token),
        device_fingerprint: session.device_fingerprint,
        ip_address: session.ip_address,
        user_agent: session.user_agent,
        expires_at: Utc::now() + Duration::days(SESSION_TTL_DAYS),
    };
    // Without a stored session the new refresh token could never be redeemed.
    // The old session is already revoked, so the client has to log in again
    // either way; report that now.
    if let Err(e) = session_repo.create(create_session).await {
        tracing::error!("Failed to create new session: {}", e);
        return failure(StatusCode::INTERNAL_SERVER_ERROR, "Token refresh failed");
    }

    (
        StatusCode::OK,
        Json(LoginResponse {
            success: true,
            tokens: Some(tokens),
            user_id: Some(user.id.to_string()),
            message: "Token refreshed".to_string(),
        }),
    )
}
/// Revokes the session belonging to the presented refresh token.
///
/// Logout is idempotent. A token with no matching session (`NotFound`) still
/// yields `200 OK`, so the endpoint does not reveal whether a token was ever
/// valid. A real database failure during the lookup or the revoke returns
/// `500` instead of reporting success: in that case the session may still be
/// live, and the client must not be told it has logged out.
///
/// # Responses
/// - `200 OK` with `{"success": true, "message": "Logged out successfully"}`.
/// - `500 Internal Server Error` with `{"success": false, "message": "Logout failed"}`.
#[cfg(feature = "portal")]
pub async fn logout(
    State(state): State<PortalState>,
    Json(req): Json<RefreshRequest>,
) -> impl IntoResponse {
    /// Builds the JSON body shared by every logout response.
    fn reply(
        status: StatusCode,
        success: bool,
        message: &str,
    ) -> (StatusCode, Json<serde_json::Value>) {
        (
            status,
            Json(serde_json::json!({
                "success": success,
                "message": message
            })),
        )
    }

    let session_repo = SessionRepository::new(state.db.pool());
    let token_hash = hash_refresh_token(&req.refresh_token);

    match session_repo.find_by_token_hash(&token_hash).await {
        Ok(session) => {
            if let Err(e) = session_repo.revoke(session.id).await {
                tracing::error!("Failed to revoke session: {}", e);
                return reply(StatusCode::INTERNAL_SERVER_ERROR, false, "Logout failed");
            }
        }
        // Nothing to revoke; succeed so the response does not leak token validity.
        Err(DbError::NotFound) => {}
        Err(e) => {
            tracing::error!("Database error finding session: {}", e);
            return reply(StatusCode::INTERNAL_SERVER_ERROR, false, "Logout failed");
        }
    }

    reply(StatusCode::OK, true, "Logged out successfully")
}