use crate::{auth::AuthResult, error::FusekiError, server::AppState};
use axum::{
extract::{Query, State},
http::{header::SET_COOKIE, HeaderMap, StatusCode},
response::{IntoResponse, Json, Response},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::{error, info, instrument, warn};
/// Query parameters accepted by [`initiate_oauth2_flow`].
#[derive(Debug, Deserialize)]
pub struct OAuth2AuthParams {
    /// Callback URL to hand to the provider. When omitted, the handler builds
    /// `{http|https}://{host}:{port}/auth/oauth2/callback` from the server configuration.
    pub redirect_uri: Option<String>,
    /// Space-separated scope list. When omitted, the handler requests `openid profile email`.
    pub scope: Option<String>,
    /// Client-supplied state value.
    /// NOTE(review): `initiate_oauth2_flow` never reads this field; the state returned to the
    /// client comes from the auth service. Confirm whether it should be honoured or removed.
    pub state: Option<String>,
    /// Whether to use PKCE. The handler treats a missing value as `true`.
    pub use_pkce: Option<bool>,
}
/// Query parameters the OAuth2 provider sends back to [`handle_oauth2_callback`].
///
/// Either `code` + `state` (success) or `error` (+ optional `error_description`) is expected.
///
/// `Debug` is implemented by hand rather than derived so the one-time authorization `code`
/// is never written to logs or tracing spans.
#[derive(Deserialize)]
pub struct OAuth2CallbackParams {
    /// Authorization code to exchange for tokens (sensitive).
    pub code: Option<String>,
    /// State value echoed back by the provider.
    pub state: Option<String>,
    /// Provider error code, present when authorization failed.
    pub error: Option<String>,
    /// Human-readable detail accompanying `error`.
    pub error_description: Option<String>,
}

impl std::fmt::Debug for OAuth2CallbackParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2CallbackParams")
            // Only reveal whether a code was present, never its value.
            .field("code", &self.code.as_ref().map(|_| "<redacted>"))
            .field("state", &self.state)
            .field("error", &self.error)
            .field("error_description", &self.error_description)
            .finish()
    }
}
/// JSON request body for [`refresh_oauth2_token`].
///
/// `Debug` is implemented by hand rather than derived so the refresh token can never be
/// written to logs or tracing output, even if the value is formatted with `{:?}`.
#[derive(Deserialize)]
pub struct OAuth2RefreshRequest {
    /// Refresh token to exchange via the auth service (sensitive).
    pub refresh_token: String,
}

impl std::fmt::Debug for OAuth2RefreshRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2RefreshRequest")
            .field("refresh_token", &"<redacted>")
            .finish()
    }
}
/// JSON response of [`initiate_oauth2_flow`].
#[derive(Debug, Serialize)]
pub struct OAuth2AuthResponse {
    /// `true` when an authorization URL was generated.
    pub success: bool,
    /// Provider URL the client should navigate to; `None` when `success` is `false`.
    pub authorization_url: Option<String>,
    /// State value produced together with the authorization URL; `None` on failure.
    pub state: Option<String>,
    /// Human-readable outcome description.
    pub message: String,
}
/// JSON response shared by [`handle_oauth2_callback`] and [`refresh_oauth2_token`].
#[derive(Debug, Serialize)]
pub struct OAuth2TokenResponse {
    /// `true` when a token was issued.
    pub success: bool,
    /// Issued token. From the callback this is the server session id; from a refresh it is the
    /// provider's new access token.
    pub access_token: Option<String>,
    /// Token type (the callback always reports `"Bearer"`).
    pub token_type: Option<String>,
    /// Lifetime of `access_token` in seconds (the callback uses the session timeout).
    pub expires_in: Option<u64>,
    /// New refresh token, only ever populated by the refresh endpoint.
    pub refresh_token: Option<String>,
    /// Authenticated user, only populated by the callback.
    pub user: Option<crate::auth::User>,
    /// Human-readable outcome description.
    pub message: String,
}
/// JSON response of [`get_oauth2_user_info`].
#[derive(Debug, Serialize)]
pub struct OAuth2UserInfoResponse {
    /// `true` when user info was retrieved from the provider.
    pub success: bool,
    /// OIDC user info; `None` when `success` is `false`.
    pub user_info: Option<crate::auth::oauth::OIDCUserInfo>,
    /// Human-readable outcome; on failure it includes the underlying error text.
    pub message: String,
}
/// Starts an OAuth2 authorization flow and returns the provider URL the client should visit.
///
/// Query parameters:
/// * `redirect_uri` defaults to `{scheme}://{host}:{port}/auth/oauth2/callback`, using `https`
///   when TLS is configured and `http` otherwise.
/// * `scope` is whitespace-separated. Repeated or trailing whitespace is ignored, and a missing
///   or blank value falls back to `openid profile email`.
/// * `use_pkce` defaults to `true`.
///
/// Returns `success: false` (still HTTP 200) when OAuth2 is not enabled on the auth service.
///
/// # Errors
///
/// * `service_unavailable` when no authentication service is configured.
/// * Any error from `generate_oauth2_auth_url`, which is logged and propagated.
#[instrument(skip(state))]
pub async fn initiate_oauth2_flow(
    State(state): State<Arc<AppState>>,
    Query(params): Query<OAuth2AuthParams>,
) -> Result<Json<OAuth2AuthResponse>, FusekiError> {
    let auth_service = state
        .auth_service
        .as_ref()
        .ok_or_else(|| FusekiError::service_unavailable("Authentication service not available"))?;
    if !auth_service.is_oauth2_enabled() {
        return Ok(Json(OAuth2AuthResponse {
            success: false,
            authorization_url: None,
            state: None,
            message: "OAuth2 authentication not configured".to_string(),
        }));
    }
    let redirect_uri = params.redirect_uri.unwrap_or_else(|| {
        format!(
            "{}://{}:{}/auth/oauth2/callback",
            if state.config.server.tls.is_some() {
                "https"
            } else {
                "http"
            },
            state.config.server.host,
            state.config.server.port
        )
    });
    // `split_whitespace` (rather than `split(' ')`) never yields empty scope entries for doubled
    // or trailing spaces. An empty result (e.g. `scope=`) falls back to the defaults instead of
    // requesting no usable scopes.
    let scopes: Vec<String> = params
        .scope
        .as_deref()
        .map(|raw| raw.split_whitespace().map(str::to_owned).collect::<Vec<_>>())
        .filter(|parsed| !parsed.is_empty())
        .unwrap_or_else(|| {
            vec![
                "openid".to_string(),
                "profile".to_string(),
                "email".to_string(),
            ]
        });
    let use_pkce = params.use_pkce.unwrap_or(true);
    match auth_service
        .generate_oauth2_auth_url(&redirect_uri, &scopes, use_pkce)
        .await
    {
        Ok((authorization_url, state_param)) => {
            info!(
                "Generated OAuth2 authorization URL with state: {}",
                state_param
            );
            Ok(Json(OAuth2AuthResponse {
                success: true,
                authorization_url: Some(authorization_url),
                state: Some(state_param),
                message: "OAuth2 authorization URL generated successfully".to_string(),
            }))
        }
        Err(e) => {
            error!("Failed to generate OAuth2 authorization URL: {}", e);
            Err(e)
        }
    }
}
/// Handles the provider redirect that ends the OAuth2 authorization-code flow.
///
/// On success, a server-side session is created for the authenticated user. Its id is returned
/// as `access_token` in the JSON body and also set as the `session_id` cookie.
///
/// Responses:
/// * `400` with `success: false` when the provider redirected back with an `error` parameter.
/// * `200` with the session token and user when authentication succeeds.
/// * `401` with `success: false` when the auth service returns a non-authenticated result.
///
/// # Errors
///
/// * `service_unavailable` when no authentication service is configured.
/// * `bad_request` when `code` or `state` is missing.
/// * Errors from `complete_oauth2_authentication` (logged) or `create_session` are propagated.
// `params` is skipped from the tracing span because it carries the one-time authorization code.
#[instrument(skip(state, params))]
pub async fn handle_oauth2_callback(
    State(state): State<Arc<AppState>>,
    Query(params): Query<OAuth2CallbackParams>,
) -> Result<Response, FusekiError> {
    let auth_service = state
        .auth_service
        .as_ref()
        .ok_or_else(|| FusekiError::service_unavailable("Authentication service not available"))?;
    if let Some(error) = params.error {
        warn!(
            "OAuth2 authorization error: {} - {}",
            error,
            params.error_description.unwrap_or_default()
        );
        let error_response = OAuth2TokenResponse {
            success: false,
            access_token: None,
            token_type: None,
            expires_in: None,
            refresh_token: None,
            user: None,
            message: format!("OAuth2 authorization failed: {error}"),
        };
        return Ok((StatusCode::BAD_REQUEST, Json(error_response)).into_response());
    }
    let code = params
        .code
        .ok_or_else(|| FusekiError::bad_request("Missing authorization code"))?;
    let state_param = params
        .state
        .ok_or_else(|| FusekiError::bad_request("Missing state parameter"))?;
    // NOTE(review): this always rebuilds the default callback URL. If the flow was started with
    // a custom `redirect_uri` (see `initiate_oauth2_flow`), the token exchange uses a different
    // value than the authorization request did. Confirm the auth service handles that case.
    let redirect_uri = format!(
        "{}://{}:{}/auth/oauth2/callback",
        if state.config.server.tls.is_some() {
            "https"
        } else {
            "http"
        },
        state.config.server.host,
        state.config.server.port
    );
    match auth_service
        .complete_oauth2_authentication(&code, &state_param, &redirect_uri)
        .await
    {
        Ok(AuthResult::Authenticated(user)) => {
            info!(
                "OAuth2 authentication successful for user: {}",
                user.username
            );
            let session_id = auth_service.create_session(user.clone()).await?;
            // `Path=/` makes the session cookie apply to the whole server. Without it, the
            // browser's default path is the callback's directory (presumably `/auth/oauth2`,
            // given the redirect URI above), so the cookie would not be sent to other endpoints.
            // NOTE(review): `Secure` means browsers only send the cookie over HTTPS, even when
            // this server runs without TLS.
            let cookie_value = format!(
                "session_id={}; Path=/; HttpOnly; Secure; SameSite=Strict; Max-Age={}",
                session_id, state.config.security.session.timeout_secs
            );
            let response = OAuth2TokenResponse {
                success: true,
                access_token: Some(session_id),
                token_type: Some("Bearer".to_string()),
                expires_in: Some(state.config.security.session.timeout_secs),
                refresh_token: None,
                user: Some(user),
                message: "OAuth2 authentication successful".to_string(),
            };
            let mut resp = (StatusCode::OK, Json(response)).into_response();
            resp.headers_mut().insert(
                SET_COOKIE,
                cookie_value
                    .parse()
                    .expect("cookie value should be valid header"),
            );
            Ok(resp)
        }
        Ok(_) => {
            warn!("OAuth2 authentication failed");
            let response = OAuth2TokenResponse {
                success: false,
                access_token: None,
                token_type: None,
                expires_in: None,
                refresh_token: None,
                user: None,
                message: "OAuth2 authentication failed".to_string(),
            };
            Ok((StatusCode::UNAUTHORIZED, Json(response)).into_response())
        }
        Err(e) => {
            error!("OAuth2 callback processing failed: {}", e);
            Err(e)
        }
    }
}
/// Exchanges a refresh token for a new OAuth2 access token.
///
/// The request body is excluded from tracing because it carries the refresh token.
/// Returns `success: false` (HTTP 200) when OAuth2 is not enabled on the auth service.
///
/// # Errors
///
/// * `service_unavailable` when no authentication service is configured.
/// * Any error from the auth service's `refresh_oauth2_token`, logged at `warn` and propagated.
#[instrument(skip(state, request))]
pub async fn refresh_oauth2_token(
    State(state): State<Arc<AppState>>,
    Json(request): Json<OAuth2RefreshRequest>,
) -> Result<Json<OAuth2TokenResponse>, FusekiError> {
    let auth_service = match state.auth_service.as_ref() {
        Some(service) => service,
        None => {
            return Err(FusekiError::service_unavailable(
                "Authentication service not available",
            ))
        }
    };
    if !auth_service.is_oauth2_enabled() {
        return Ok(Json(OAuth2TokenResponse {
            success: false,
            access_token: None,
            token_type: None,
            expires_in: None,
            refresh_token: None,
            user: None,
            message: "OAuth2 authentication not configured".to_string(),
        }));
    }

    let refreshed = auth_service
        .refresh_oauth2_token(&request.refresh_token)
        .await
        .map_err(|err| {
            warn!("Failed to refresh OAuth2 token: {}", err);
            err
        })?;

    info!("Successfully refreshed OAuth2 access token");
    Ok(Json(OAuth2TokenResponse {
        success: true,
        access_token: Some(refreshed.access_token),
        token_type: Some(refreshed.token_type),
        expires_in: Some(refreshed.expires_in),
        refresh_token: refreshed.refresh_token,
        user: None,
        message: "Token refreshed successfully".to_string(),
    }))
}
/// Looks up OIDC user info for the bearer token in the `Authorization` header.
///
/// A provider-side failure is reported in the body as `success: false` with HTTP 200, not
/// as an error. Returns `success: false` when OAuth2 is not enabled.
///
/// # Errors
///
/// * `service_unavailable` when no authentication service is configured.
/// * `authentication` when the request carries no usable bearer token.
#[instrument(skip(state, headers))]
pub async fn get_oauth2_user_info(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
) -> Result<Json<OAuth2UserInfoResponse>, FusekiError> {
    let auth_service = match state.auth_service.as_ref() {
        Some(service) => service,
        None => {
            return Err(FusekiError::service_unavailable(
                "Authentication service not available",
            ))
        }
    };
    if !auth_service.is_oauth2_enabled() {
        return Ok(Json(OAuth2UserInfoResponse {
            success: false,
            user_info: None,
            message: "OAuth2 authentication not configured".to_string(),
        }));
    }
    let access_token = match extract_bearer_token(&headers) {
        Some(token) => token,
        None => {
            return Err(FusekiError::authentication(
                "Missing or invalid authorization header",
            ))
        }
    };

    let body = match auth_service.get_oauth2_user_info(&access_token).await {
        Ok(user_info) => {
            info!(
                "Successfully retrieved OAuth2 user info for user: {}",
                user_info.sub
            );
            OAuth2UserInfoResponse {
                success: true,
                user_info: Some(user_info),
                message: "User information retrieved successfully".to_string(),
            }
        }
        Err(e) => {
            warn!("Failed to retrieve OAuth2 user info: {}", e);
            OAuth2UserInfoResponse {
                success: false,
                user_info: None,
                message: format!("Failed to retrieve user information: {e}"),
            }
        }
    };
    Ok(Json(body))
}
/// Reports whether the bearer token in the `Authorization` header is currently valid.
///
/// Responds with `{"valid": bool, "message": ...}`.
///
/// # Errors
///
/// * `service_unavailable` when no authentication service is configured or OAuth2 is disabled.
/// * `authentication` when the request carries no usable bearer token.
/// * Any error from `validate_access_token`, logged at `warn` and propagated.
#[instrument(skip(state, headers))]
pub async fn validate_oauth2_token(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
) -> Result<Json<serde_json::Value>, FusekiError> {
    let auth_service = match state.auth_service.as_ref() {
        Some(service) => service,
        None => {
            return Err(FusekiError::service_unavailable(
                "Authentication service not available",
            ))
        }
    };
    if !auth_service.is_oauth2_enabled() {
        return Err(FusekiError::service_unavailable(
            "OAuth2 authentication not configured",
        ));
    }
    let access_token = match extract_bearer_token(&headers) {
        Some(token) => token,
        None => {
            return Err(FusekiError::authentication(
                "Missing or invalid authorization header",
            ))
        }
    };

    let is_valid = auth_service
        .validate_access_token(&access_token)
        .await
        .map_err(|err| {
            warn!("OAuth2 token validation failed: {}", err);
            err
        })?;

    let message = if is_valid {
        "Token is valid"
    } else {
        "Token is invalid or expired"
    };
    Ok(Json(serde_json::json!({
        "valid": is_valid,
        "message": message
    })))
}
/// Returns the public part of the OAuth2 configuration for clients.
///
/// Reports `{"enabled": false, ...}` when the auth service has OAuth2 disabled or no OAuth
/// section is configured. Otherwise it exposes the provider, endpoints, scopes and client id.
///
/// # Errors
///
/// * `service_unavailable` when no authentication service is configured.
#[instrument(skip(state))]
pub async fn get_oauth2_config(
    State(state): State<Arc<AppState>>,
) -> Result<Json<serde_json::Value>, FusekiError> {
    let auth_service = match state.auth_service.as_ref() {
        Some(service) => service,
        None => {
            return Err(FusekiError::service_unavailable(
                "Authentication service not available",
            ))
        }
    };
    if !auth_service.is_oauth2_enabled() {
        return Ok(Json(serde_json::json!({
            "enabled": false,
            "message": "OAuth2 authentication not configured"
        })));
    }

    let payload = match &state.config.security.oauth {
        Some(cfg) => serde_json::json!({
            "enabled": true,
            "provider": cfg.provider,
            "authorization_endpoint": cfg.auth_url,
            "supported_scopes": cfg.scopes,
            "client_id": cfg.client_id,
            "token_endpoint": cfg.token_url,
            "user_info_endpoint": cfg.user_info_url,
        }),
        None => serde_json::json!({
            "enabled": false,
            "message": "OAuth2 not configured"
        }),
    };
    Ok(Json(payload))
}
/// Serves an OpenID-Connect-style discovery document describing this server's OAuth2 setup.
///
/// Provider endpoints come from the auth service's OAuth config. The issuer and the local
/// endpoints (`jwks`, `revoke`, `introspect`, `logout`) are built from the server's own host and
/// port, using `https` when TLS is configured. Previously this was hard-coded to `http`, which
/// was inconsistent with the callback URLs built elsewhere in this module.
///
/// # Errors
///
/// * `service_unavailable` when no authentication service is configured, or it has no OAuth
///   configuration.
#[instrument(skip(state))]
pub async fn oauth2_discovery(
    State(state): State<Arc<AppState>>,
) -> Result<Json<serde_json::Value>, FusekiError> {
    let auth_service = state
        .auth_service
        .as_ref()
        .ok_or_else(|| FusekiError::service_unavailable("Authentication service not available"))?;
    let oauth_config = auth_service
        .get_oauth_config()
        .ok_or_else(|| FusekiError::service_unavailable("OAuth2 authentication not configured"))?;
    let scheme = if state.config.server.tls.is_some() {
        "https"
    } else {
        "http"
    };
    let base_url = format!(
        "{}://{}:{}",
        scheme, state.config.server.host, state.config.server.port
    );
    let discovery_doc = serde_json::json!({
        "issuer": base_url,
        "authorization_endpoint": oauth_config.auth_url,
        "token_endpoint": oauth_config.token_url,
        "userinfo_endpoint": oauth_config.user_info_url,
        "jwks_uri": format!("{}/auth/oauth2/jwks", base_url),
        "scopes_supported": oauth_config.scopes,
        "response_types_supported": ["code", "token", "id_token", "code id_token", "code token", "id_token token", "code id_token token"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "subject_types_supported": ["public"],
        "id_token_signing_alg_values_supported": ["RS256"],
        "code_challenge_methods_supported": ["S256"],
        "claims_supported": ["sub", "iss", "aud", "exp", "iat", "auth_time", "nonce", "email", "email_verified", "name", "given_name", "family_name"],
        "token_endpoint_auth_methods_supported": ["client_secret_basic", "client_secret_post"],
        "revocation_endpoint": format!("{}/auth/oauth2/revoke", base_url),
        "introspection_endpoint": format!("{}/auth/oauth2/introspect", base_url),
        "end_session_endpoint": format!("{}/auth/oauth2/logout", base_url)
    });
    Ok(Json(discovery_doc))
}
/// Extracts the bearer token from the `Authorization` header.
///
/// The auth-scheme is matched case-insensitively (`Bearer`, `bearer`, ...), as HTTP
/// authentication schemes are case-insensitive. Whitespace around the token is trimmed.
///
/// Returns `None` in these cases:
/// * the header is missing or not valid visible ASCII;
/// * the header uses another scheme;
/// * the header has no token after the scheme (e.g. `"Bearer "`).
fn extract_bearer_token(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get("authorization")?.to_str().ok()?;
    let (scheme, credentials) = raw.trim_start().split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = credentials.trim();
    // An empty token would otherwise reach the auth service as `Some("")`.
    (!token.is_empty()).then(|| token.to_string())
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;

    /// Builds a header map containing only an `authorization` header with `value`.
    fn with_authorization(value: &'static str) -> HeaderMap {
        let mut map = HeaderMap::new();
        map.insert("authorization", HeaderValue::from_static(value));
        map
    }

    #[test]
    fn test_bearer_token_extraction() {
        // A bearer credential yields the text after the scheme.
        assert_eq!(
            extract_bearer_token(&with_authorization("Bearer test_token_123")).as_deref(),
            Some("test_token_123")
        );
        // A different scheme is not treated as a bearer token.
        assert!(extract_bearer_token(&with_authorization("Basic dGVzdA==")).is_none());
        // No authorization header at all.
        assert!(extract_bearer_token(&HeaderMap::new()).is_none());
    }

    #[test]
    fn test_oauth2_response_serialization() {
        let serialized = serde_json::to_string(&OAuth2AuthResponse {
            success: true,
            authorization_url: Some("https://provider.example.com/auth".to_string()),
            state: Some("state123".to_string()),
            message: "Success".to_string(),
        })
        .unwrap();
        for needle in ["authorization_url", "state123"] {
            assert!(serialized.contains(needle));
        }
    }
}