use std::{net::SocketAddr, sync::Arc};
use axum::{
Json,
extract::{ConnectInfo, Query, State},
http::StatusCode,
response::IntoResponse,
};
use serde::{Deserialize, Serialize};
use crate::{
audit_logger::{AuditEventType, SecretType, get_audit_logger},
error::{AuthError, Result},
provider::OAuthProvider,
rate_limiting::RateLimiters,
session::SessionStore,
state_store::StateStore,
};
/// Shared state handed to every auth handler through axum's `State` extractor.
///
/// Every field is an `Arc`, so cloning an `AuthState` only bumps reference counts and all
/// clones share the same provider, stores, and rate limiters.
#[derive(Clone)]
pub struct AuthState {
    /// OAuth provider used to build authorization URLs, exchange codes, and fetch user info.
    pub oauth_provider: Arc<dyn OAuthProvider>,
    /// Session store: creates sessions and looks them up / revokes them by refresh-token hash.
    pub session_store: Arc<dyn SessionStore>,
    /// Store of pending CSRF `state` values, each saved with a provider name and expiry.
    pub state_store: Arc<dyn StateStore>,
    /// Per-endpoint rate limiters (`auth_start`, `auth_callback`, `auth_refresh`, `auth_logout`).
    pub rate_limiters: Arc<RateLimiters>,
}
/// JSON body accepted by `auth_start`.
#[derive(Debug, Deserialize)]
pub struct AuthStartRequest {
    /// Provider name stored alongside the CSRF state; `auth_start` substitutes `"default"`
    /// when it is omitted.
    pub provider: Option<String>,
}
/// JSON body returned by `auth_start`.
#[derive(Debug, Serialize)]
pub struct AuthStartResponse {
    /// Provider authorization URL, built by the OAuth provider from the newly issued CSRF state.
    pub authorization_url: String,
}
#[derive(Debug, Deserialize)]
pub struct AuthCallbackQuery {
pub code: String,
pub state: String,
pub error: Option<String>,
pub error_description: Option<String>,
}
/// JSON body returned by `auth_callback` once the login flow completes.
#[derive(Debug, Serialize)]
pub struct AuthCallbackResponse {
    /// Access token of the newly created session.
    pub access_token: String,
    /// Refresh token of the new session; `auth_callback` always populates it.
    pub refresh_token: Option<String>,
    /// Token scheme; `auth_callback` always sends `"Bearer"`.
    pub token_type: String,
    /// Access-token lifetime as reported by the session store (presumably seconds — confirm).
    pub expires_in: u64,
}
/// JSON body accepted by `auth_refresh`.
#[derive(Debug, Deserialize)]
pub struct AuthRefreshRequest {
    /// Refresh token of an existing session; it is hashed and used to look the session up.
    pub refresh_token: String,
}
/// JSON body returned by `auth_refresh`.
#[derive(Debug, Serialize)]
pub struct AuthRefreshResponse {
    /// Newly issued access token.
    pub access_token: String,
    /// Token scheme; `auth_refresh` always sends `"Bearer"`.
    pub token_type: String,
    /// Access-token lifetime; `auth_refresh` currently sends a fixed 3600 (presumably seconds).
    pub expires_in: u64,
}
/// JSON body accepted by `auth_logout`.
#[derive(Debug, Deserialize)]
pub struct AuthLogoutRequest {
    /// Refresh token of the session to revoke; when absent, nothing is revoked and the
    /// handler still answers `204 No Content`.
    pub refresh_token: Option<String>,
}
/// Starts an OAuth authorization-code flow and returns the provider URL to visit.
///
/// Processing steps:
/// 1. Check the client IP against the `auth_start` rate limiter.
/// 2. Generate a fresh CSRF state with [`generate_secure_state`].
/// 3. Save that state in the state store, together with the requested provider name
///    (`"default"` when none is given) and an expiry ten minutes from now.
/// 4. Return the provider's authorization URL, built from that state.
///
/// # Errors
/// - [`AuthError::RateLimited`] when the client IP has exhausted its budget.
/// - [`AuthError::SystemTimeError`] if the system clock reads earlier than the Unix epoch.
/// - Any error propagated from the state store.
pub async fn auth_start(
    State(state): State<AuthState>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    Json(req): Json<AuthStartRequest>,
) -> Result<Json<AuthStartResponse>> {
    // How long a pending CSRF state remains valid: ten minutes.
    const STATE_TTL_SECS: u64 = 600;

    let limiter = &state.rate_limiters.auth_start;
    if limiter.check(&addr.ip().to_string()).is_err() {
        return Err(AuthError::RateLimited {
            retry_after_secs: limiter.clone_config().window_secs,
        });
    }

    let state_value = generate_secure_state();
    let expiry = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|since_epoch| since_epoch.as_secs() + STATE_TTL_SECS)
        .map_err(|_| AuthError::SystemTimeError {
            message: "Failed to get current system time".to_string(),
        })?;

    let provider = match req.provider {
        Some(name) => name,
        None => String::from("default"),
    };
    state
        .state_store
        .store(state_value.clone(), provider, expiry)
        .await?;

    let authorization_url = state.oauth_provider.authorization_url(&state_value);
    Ok(Json(AuthStartResponse { authorization_url }))
}
/// Handles the provider's redirect back to the application and completes the OAuth login.
///
/// Steps, in order:
/// 1. Throttle by client IP with the `auth_callback` rate limiter.
/// 2. If the provider reported an `error`, audit the failure and return it.
/// 3. Reject a callback that carries no authorization `code`.
/// 4. Retrieve the CSRF `state` from the state store and reject it once expired.
/// 5. Exchange the code for provider tokens, fetch the user's info, and create a session
///    expiring seven days from now.
///
/// On success, responds with an [`AuthCallbackResponse`] holding the session's access and
/// refresh tokens, with `token_type` set to `"Bearer"`.
///
/// # Errors
/// - [`AuthError::RateLimited`] when the client IP has exhausted its budget.
/// - [`AuthError::OAuthError`] when the provider returned an error or no code was supplied.
/// - [`AuthError::InvalidState`] when the stored state has expired.
/// - [`AuthError::SystemTimeError`] if the system clock reads earlier than the Unix epoch.
/// - Any error propagated from the state store, OAuth provider, or session store.
pub async fn auth_callback(
    State(state): State<AuthState>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    Query(query): Query<AuthCallbackQuery>,
) -> Result<impl IntoResponse> {
    // Lifetime of the session created by a successful callback: seven days.
    const SESSION_TTL_SECS: u64 = 7 * 24 * 60 * 60;

    let client_ip = addr.ip().to_string();
    if state.rate_limiters.auth_callback.check(&client_ip).is_err() {
        return Err(AuthError::RateLimited {
            retry_after_secs: state.rate_limiters.auth_callback.clone_config().window_secs,
        });
    }

    // Provider-side failure: surface it before touching the state store.
    if let Some(error) = query.error {
        get_audit_logger().log_failure(
            AuditEventType::OauthCallback,
            SecretType::AuthorizationCode,
            None,
            "exchange",
            &error,
        );
        return Err(AuthError::OAuthError {
            message: format!("{}: {}", error, query.error_description.unwrap_or_default()),
        });
    }

    // Without an `error`, the redirect must carry a code; never forward an empty code
    // to the provider's token endpoint.
    if query.code.is_empty() {
        get_audit_logger().log_failure(
            AuditEventType::OauthCallback,
            SecretType::AuthorizationCode,
            None,
            "exchange",
            "Missing authorization code",
        );
        return Err(AuthError::OAuthError {
            message: "missing authorization code".to_string(),
        });
    }

    let (_provider_name, expiry) = state.state_store.retrieve(&query.state).await?;
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_err(|_| AuthError::SystemTimeError {
            message: "Failed to get current system time".to_string(),
        })?
        .as_secs();
    if now > expiry {
        get_audit_logger().log_failure(
            AuditEventType::CsrfStateValidated,
            SecretType::StateToken,
            None,
            "validate",
            "State token expired",
        );
        return Err(AuthError::InvalidState);
    }
    get_audit_logger().log_success(
        AuditEventType::CsrfStateValidated,
        SecretType::StateToken,
        None,
        "validate",
    );

    let token_response = state.oauth_provider.exchange_code(&query.code).await?;
    get_audit_logger().log_success(
        AuditEventType::OauthCallback,
        SecretType::AuthorizationCode,
        None,
        "exchange",
    );

    let user_info = state.oauth_provider.user_info(&token_response.access_token).await?;
    let session_tokens = state
        .session_store
        .create_session(&user_info.id, now + SESSION_TTL_SECS)
        .await?;
    get_audit_logger().log_success(
        AuditEventType::SessionTokenCreated,
        SecretType::SessionToken,
        Some(user_info.id.clone()),
        "create",
    );
    get_audit_logger().log_success(
        AuditEventType::AuthSuccess,
        SecretType::SessionToken,
        Some(user_info.id),
        "oauth_flow",
    );

    Ok(Json(AuthCallbackResponse {
        access_token: session_tokens.access_token,
        refresh_token: Some(session_tokens.refresh_token),
        token_type: "Bearer".to_string(),
        expires_in: session_tokens.expires_in,
    }))
}
/// Issues a new access token for the session identified by a refresh token.
///
/// The refresh token is hashed with `hash_token` and resolved to a session through the
/// session store. The caller is then throttled per `user_id` by the `auth_refresh` rate
/// limiter. Both the token validation and the refresh are recorded in the audit log.
///
/// NOTE(review): the returned `access_token` is a `new_access_token_<uuid>` placeholder that
/// is neither signed nor stored here, and `expires_in` is a fixed 3600.
/// NOTE(review): throttling keys on the session's `user_id`, so requests whose refresh token
/// fails the session lookup return before any rate limit is applied.
///
/// # Errors
/// - Any error from the session-store lookup (presumably including unknown tokens).
/// - [`AuthError::RateLimited`] when the user's refresh budget is exhausted.
pub async fn auth_refresh(
    State(state): State<AuthState>,
    Json(req): Json<AuthRefreshRequest>,
) -> Result<Json<AuthRefreshResponse>> {
    use crate::session::hash_token;

    let session = state
        .session_store
        .get_session(&hash_token(&req.refresh_token))
        .await?;
    let user_id = session.user_id;

    let limiter = &state.rate_limiters.auth_refresh;
    if limiter.check(&user_id).is_err() {
        let retry_after_secs = limiter.clone_config().window_secs;
        return Err(AuthError::RateLimited { retry_after_secs });
    }

    get_audit_logger().log_success(
        AuditEventType::SessionTokenValidation,
        SecretType::RefreshToken,
        Some(user_id.clone()),
        "validate",
    );

    let access_token = format!("new_access_token_{}", uuid::Uuid::new_v4());

    get_audit_logger().log_success(
        AuditEventType::JwtRefresh,
        SecretType::JwtToken,
        Some(user_id),
        "refresh",
    );

    let response = AuthRefreshResponse {
        access_token,
        token_type: "Bearer".to_string(),
        expires_in: 3600,
    };
    Ok(Json(response))
}
/// Ends a session by revoking the supplied refresh token.
///
/// Processing steps:
/// 1. Throttle every request by client IP with the `auth_logout` rate limiter.
/// 2. If a refresh token is present:
///    - hash it and resolve it to its session;
///    - throttle again per `user_id`;
///    - revoke the session and audit the revocation.
/// 3. A request without a token revokes nothing.
///
/// Returns `204 No Content` on success.
///
/// # Errors
/// - [`AuthError::RateLimited`] when either the IP budget or the user's budget is exhausted.
/// - Any error propagated from the session-store lookup or revocation.
pub async fn auth_logout(
    State(state): State<AuthState>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    Json(req): Json<AuthLogoutRequest>,
) -> Result<StatusCode> {
    use crate::session::hash_token;

    let limiter = &state.rate_limiters.auth_logout;
    let client_ip = addr.ip().to_string();
    // Throttle by IP before touching the session store. The per-user check below only runs
    // once a token has resolved to a session, so on its own it never limits requests that
    // carry guessed or invalid refresh tokens.
    if limiter.check(&client_ip).is_err() {
        return Err(AuthError::RateLimited {
            retry_after_secs: limiter.clone_config().window_secs,
        });
    }

    if let Some(refresh_token) = req.refresh_token {
        let token_hash = hash_token(&refresh_token);
        let session = state.session_store.get_session(&token_hash).await?;
        if limiter.check(&session.user_id).is_err() {
            return Err(AuthError::RateLimited {
                retry_after_secs: limiter.clone_config().window_secs,
            });
        }
        state.session_store.revoke_session(&token_hash).await?;
        get_audit_logger().log_success(
            AuditEventType::SessionTokenRevoked,
            SecretType::RefreshToken,
            Some(session.user_id),
            "revoke",
        );
    }

    Ok(StatusCode::NO_CONTENT)
}
/// Produces a fresh random value for the OAuth `state` (CSRF) parameter.
///
/// Fills 32 bytes from `rand::rngs::OsRng` and returns them hex-encoded, which yields a
/// 64-character string.
pub fn generate_secure_state() -> String {
    use rand::RngCore;

    const STATE_LEN_BYTES: usize = 32;
    let mut entropy = [0u8; STATE_LEN_BYTES];
    rand::rngs::OsRng.fill_bytes(&mut entropy);
    hex::encode(entropy)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two generated states must differ, and each must be 64 valid hex characters.
    #[test]
    fn test_generate_secure_state() {
        let (first, second) = (generate_secure_state(), generate_secure_state());
        assert_ne!(first, second);
        for candidate in [first.as_str(), second.as_str()] {
            assert!(!candidate.is_empty());
            assert_eq!(candidate.len(), 64);
            assert!(hex::decode(candidate).is_ok());
        }
    }
}