use super::client::OAuthClient;
use super::config::OAuthGatewayConfig;
use super::session::{OAuthSession, SessionStore};
use super::types::{CallbackParams, LogoutRequest, RefreshRequest, UserInfo};
use actix_web::{HttpRequest, HttpResponse, Result as ActixResult, web};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, error, info, warn};
/// Shared application state for the OAuth gateway routes.
///
/// Cloning is cheap: every field is an `Arc`, so clones share the same configuration,
/// client map and session store.
#[derive(Clone)]
pub struct OAuthState {
    /// Gateway-wide OAuth configuration (providers, session TTL, redirect allow-list, ...).
    pub config: Arc<OAuthGatewayConfig>,
    /// One OAuth client per *enabled* provider, keyed by provider name.
    pub clients: Arc<HashMap<String, OAuthClient>>,
    /// Store used for pending OAuth `state` entries and for established sessions.
    pub session_store: Arc<dyn SessionStore>,
}

impl std::fmt::Debug for OAuthState {
    /// Shows only the configured provider names, sorted so the output is stable across runs.
    /// The config, clients and session store are omitted (the config may carry client secrets).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `HashMap` iteration order is randomised; sort for deterministic diagnostics.
        let mut providers: Vec<&String> = self.clients.keys().collect();
        providers.sort_unstable();
        f.debug_struct("OAuthState")
            .field("providers", &providers)
            .finish()
    }
}

impl OAuthState {
    /// Builds the state, creating an [`OAuthClient`] for every enabled provider in `config`.
    ///
    /// Disabled providers are skipped and are therefore not reachable via [`Self::get_client`].
    ///
    /// # Errors
    ///
    /// Returns a message naming the provider whose client could not be created.
    pub fn new(
        config: OAuthGatewayConfig,
        session_store: Arc<dyn SessionStore>,
    ) -> Result<Self, String> {
        let mut clients = HashMap::new();
        for (name, provider_config) in &config.providers {
            if !provider_config.enabled {
                continue;
            }
            let client = OAuthClient::new(provider_config.clone())
                .map_err(|e| format!("Failed to create OAuth client for {}: {}", name, e))?;
            clients.insert(name.clone(), client);
        }
        Ok(Self {
            config: Arc::new(config),
            clients: Arc::new(clients),
            session_store,
        })
    }

    /// Looks up the client for an enabled provider by name.
    pub fn get_client(&self, provider: &str) -> Option<&OAuthClient> {
        self.clients.get(provider)
    }
}
/// JSON body describing where to send the user to begin an OAuth login.
///
/// NOTE(review): `oauth_login` in this module answers with a `302` redirect rather than this
/// body; presumably this type is kept for API clients — confirm before removing.
#[derive(Serialize, Deserialize, Debug)]
pub struct OAuthLoginResponse {
    /// Provider authorization URL the user agent should visit.
    pub authorization_url: String,
    /// The OAuth `state` value for this login attempt.
    pub state: String,
}
/// Token payload returned as JSON after a successful callback or refresh.
///
/// `refresh_token` is left out of the serialized output when it is `None`. On the refresh
/// path (`oauth_refresh`) `session_id` is empty and `user` is a placeholder, because no
/// session is looked up there.
#[derive(Serialize, Deserialize, Debug)]
pub struct AuthResponse {
    /// Identifier of the server-side session (empty when produced by `oauth_refresh`).
    pub session_id: String,
    /// Access token issued by the provider.
    pub access_token: String,
    /// Refresh token issued by the provider, if one was returned.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Token type as reported by the provider.
    pub token_type: String,
    /// Access-token lifetime as reported by the provider (OAuth `expires_in`, seconds per RFC 6749).
    pub expires_in: u64,
    /// Profile of the authenticated user.
    pub user: UserInfo,
}
/// JSON error body shaped like an OAuth 2.0 error response (`error` + `error_description`).
#[derive(Serialize, Deserialize, Debug)]
pub struct OAuthErrorResponse {
    /// Machine-readable error code, e.g. `provider_not_found` or `invalid_state`.
    pub error: String,
    /// Human-readable explanation of the error.
    pub error_description: String,
}

impl OAuthErrorResponse {
    /// Creates an error body from an error code and a human-readable description.
    pub fn new(error: impl Into<String>, description: impl Into<String>) -> Self {
        let error = error.into();
        let error_description = description.into();
        Self {
            error,
            error_description,
        }
    }
}
/// Optional query parameters accepted by `GET /oauth/{provider}/login`.
#[derive(Deserialize, Debug)]
pub struct LoginQuery {
    /// Where to send the browser after a successful login. It is stored with the OAuth state
    /// and only checked against `allowed_redirect_origins` when the callback completes.
    #[serde(default)]
    pub redirect_uri: Option<String>,
    /// Forwarded to the provider's authorization URL as `login_hint`.
    #[serde(default)]
    pub login_hint: Option<String>,
    /// Forwarded to the provider's authorization URL as `prompt`.
    #[serde(default)]
    pub prompt: Option<String>,
}
/// Returns `true` when `redirect_uri` points at one of the origins in `allowed_origins`.
///
/// The redirect target and every allow-list entry are parsed with [`url::Url`]. They are then
/// compared as `(scheme, host, port)` triples, where a missing port falls back to the scheme's
/// well-known default. So `https://app.example.com`, `https://app.example.com/` and
/// `https://app.example.com:443` all name the same origin. The parser's case normalisation of
/// schemes and hosts applies to both sides. Custom schemes (e.g. `myapp://host`) are compared
/// the same way.
///
/// The check fails closed:
/// * an empty allow-list rejects everything (and logs a warning);
/// * a `redirect_uri` that does not parse, or has no host (e.g. `javascript:` / `data:`
///   URLs), is rejected;
/// * allow-list entries that are not bare origins (unparseable, host-less, or carrying a
///   path, query or fragment) never match anything.
fn is_redirect_origin_allowed(redirect_uri: &str, allowed_origins: &[String]) -> bool {
    if allowed_origins.is_empty() {
        warn!(
            "OAuth allowed_redirect_origins is empty; rejecting redirect to '{}'. \
             Configure allowed_redirect_origins explicitly to permit client redirects.",
            redirect_uri
        );
        return false;
    }
    let Ok(parsed) = url::Url::parse(redirect_uri) else {
        return false;
    };
    let Some(target) = origin_key(&parsed) else {
        return false;
    };
    allowed_origins.iter().any(|allowed| {
        url::Url::parse(allowed.trim())
            .ok()
            .is_some_and(|entry| is_bare_origin(&entry) && origin_key(&entry) == Some(target))
    })
}

/// `(scheme, host, effective port)` used for origin comparison; `None` when there is no host.
fn origin_key(parsed: &url::Url) -> Option<(&str, &str, Option<u16>)> {
    let host = parsed.host_str().filter(|host| !host.is_empty())?;
    Some((parsed.scheme(), host, parsed.port_or_known_default()))
}

/// Whether an allow-list entry names only an origin (no path beyond `/`, no query, no fragment).
fn is_bare_origin(parsed: &url::Url) -> bool {
    matches!(parsed.path(), "" | "/") && parsed.query().is_none() && parsed.fragment().is_none()
}
/// Registers the OAuth endpoints under the `/oauth` scope:
///
/// | Method | Path                     | Handler          |
/// |--------|--------------------------|------------------|
/// | GET    | `/{provider}/login`      | `oauth_login`    |
/// | GET    | `/{provider}/callback`   | `oauth_callback` |
/// | POST   | `/refresh`               | `oauth_refresh`  |
/// | POST   | `/logout`                | `oauth_logout`   |
/// | GET    | `/userinfo`              | `oauth_userinfo` |
pub fn configure_routes(cfg: &mut web::ServiceConfig) {
    let oauth_scope = web::scope("/oauth")
        .route("/{provider}/login", web::get().to(oauth_login))
        .route("/{provider}/callback", web::get().to(oauth_callback))
        .route("/refresh", web::post().to(oauth_refresh))
        .route("/logout", web::post().to(oauth_logout))
        .route("/userinfo", web::get().to(oauth_userinfo));
    cfg.service(oauth_scope);
}
/// `GET /oauth/{provider}/login` — starts the authorization-code flow for `provider`.
///
/// The handler:
/// 1. builds the provider's authorization URL;
/// 2. forwards the optional `login_hint` / `prompt` query parameters unless the URL already
///    carries them;
/// 3. records the optional client redirect, the peer address and the `User-Agent` on the
///    OAuth state;
/// 4. persists that state and answers with `302 Found` pointing at the provider.
///
/// The client-supplied `redirect_uri` is only stored here. It is checked against
/// `allowed_redirect_origins` when the callback completes.
///
/// Responds with `404` for an unknown provider and `500` if the state cannot be stored.
pub async fn oauth_login(
    oauth: web::Data<OAuthState>,
    path: web::Path<String>,
    query: web::Query<LoginQuery>,
    req: HttpRequest,
) -> ActixResult<HttpResponse> {
    let provider = path.into_inner();
    info!("OAuth login request for provider: {}", provider);
    let client = match oauth.get_client(&provider) {
        Some(c) => c,
        None => {
            warn!("Unknown OAuth provider: {}", provider);
            return Ok(HttpResponse::NotFound().json(OAuthErrorResponse::new(
                "provider_not_found",
                format!("OAuth provider '{}' is not configured", provider),
            )));
        }
    };
    let (mut auth_url, mut state) = client.get_authorization_url();
    if let Some(redirect) = &query.redirect_uri {
        state = state.with_data("client_redirect", redirect.clone());
    }
    if let Some(hint) = &query.login_hint {
        auth_url = append_query_param_if_absent(auth_url, "login_hint", hint);
    }
    if let Some(prompt) = &query.prompt {
        auth_url = append_query_param_if_absent(auth_url, "prompt", prompt);
    }
    if let Some(ip) = req.connection_info().peer_addr() {
        state = state.with_data("client_ip", ip.to_string());
    }
    if let Some(ua) = req
        .headers()
        .get("User-Agent")
        .and_then(|h| h.to_str().ok())
    {
        state = state.with_data("user_agent", ua.to_string());
    }
    if let Err(e) = oauth.session_store.set_state(state.clone()).await {
        error!("Failed to store OAuth state: {:?}", e);
        return Ok(
            HttpResponse::InternalServerError().json(OAuthErrorResponse::new(
                "state_storage_error",
                "Failed to initialize OAuth flow",
            )),
        );
    }
    debug!("Redirecting to OAuth provider: {}", auth_url);
    Ok(HttpResponse::Found()
        .insert_header(("Location", auth_url))
        .finish())
}

/// Adds `key=value` (form-urlencoded) to the query of `auth_url`, unless a parameter named
/// `key` is already present.
///
/// The URL is edited structurally, so a URL without a query gains `?`, and a fragment stays
/// at the end. If `auth_url` cannot be parsed as an absolute URL, the previous best-effort
/// handling applies: a substring check followed by an `&`-append.
fn append_query_param_if_absent(auth_url: String, key: &str, value: &str) -> String {
    let Ok(mut parsed) = url::Url::parse(&auth_url) else {
        if auth_url.contains(&format!("{key}=")) {
            return auth_url;
        }
        let encoded: String = url::form_urlencoded::byte_serialize(value.as_bytes()).collect();
        return format!("{auth_url}&{key}={encoded}");
    };
    if parsed.query_pairs().any(|(name, _)| name == key) {
        // Leave the provider-generated URL untouched when it already sets this parameter.
        return auth_url;
    }
    parsed.query_pairs_mut().append_pair(key, value);
    parsed.to_string()
}
pub async fn oauth_callback(
oauth: web::Data<OAuthState>,
path: web::Path<String>,
query: web::Query<CallbackParams>,
) -> ActixResult<HttpResponse> {
let provider = path.into_inner();
info!("OAuth callback for provider: {}", provider);
if let Some(error) = &query.error {
warn!(
"OAuth error from provider: {} - {}",
error,
query
.error_description
.as_deref()
.unwrap_or("No description")
);
return Ok(HttpResponse::BadRequest().json(OAuthErrorResponse::new(
error.clone(),
query.error_description.clone().unwrap_or_default(),
)));
}
let state_id = match &query.state {
Some(s) => s,
None => {
return Ok(HttpResponse::BadRequest().json(OAuthErrorResponse::new(
"missing_state",
"State parameter is required",
)));
}
};
let stored_state = match oauth.session_store.get_and_delete_state(state_id).await {
Ok(Some(s)) => s,
Ok(None) => {
warn!("OAuth state not found or expired: {}", state_id);
return Ok(HttpResponse::BadRequest().json(OAuthErrorResponse::new(
"invalid_state",
"OAuth state is invalid or has expired",
)));
}
Err(e) => {
error!("Failed to retrieve OAuth state: {:?}", e);
return Ok(
HttpResponse::InternalServerError().json(OAuthErrorResponse::new(
"state_retrieval_error",
"Failed to validate OAuth state",
)),
);
}
};
let client = match oauth.get_client(&provider) {
Some(c) => c,
None => {
return Ok(HttpResponse::NotFound().json(OAuthErrorResponse::new(
"provider_not_found",
format!("OAuth provider '{}' is not configured", provider),
)));
}
};
if let Err(e) = client.validate_callback(&query, &stored_state) {
warn!("OAuth callback validation failed: {}", e);
return Ok(HttpResponse::BadRequest()
.json(OAuthErrorResponse::new("validation_error", e.to_string())));
}
let code = match query.code.as_ref() {
Some(code) => code,
None => {
return Ok(HttpResponse::BadRequest().json(OAuthErrorResponse::new(
"missing_code",
"Authorization code is required",
)));
}
};
let token_response = match client.exchange_code(code, &stored_state).await {
Ok(t) => t,
Err(e) => {
error!("Token exchange failed: {}", e);
return Ok(
HttpResponse::InternalServerError().json(OAuthErrorResponse::new(
"token_exchange_error",
"Failed to exchange authorization code for tokens",
)),
);
}
};
let user_info = match client.get_user_info(&token_response.access_token).await {
Ok(u) => u,
Err(e) => {
error!("Failed to get user info: {}", e);
return Ok(
HttpResponse::InternalServerError().json(OAuthErrorResponse::new(
"userinfo_error",
"Failed to retrieve user information",
)),
);
}
};
info!(
"OAuth authentication successful for user: {}",
user_info.email
);
let mut session = OAuthSession::new(
user_info.clone(),
token_response.access_token.clone(),
token_response.expires_in,
oauth.config.session_ttl_seconds,
);
if let Some(rt) = &token_response.refresh_token {
session = session.with_refresh_token(rt.clone());
}
if let Some(it) = &token_response.id_token {
session = session.with_id_token(it.clone());
}
session = session.with_client_info(
stored_state.extra_data.get("client_ip").cloned(),
stored_state.extra_data.get("user_agent").cloned(),
);
if !oauth.config.default_role.is_empty() {
session = session.with_role(&oauth.config.default_role);
}
let session_id = session.session_id.clone();
if let Err(e) = oauth.session_store.set(session).await {
error!("Failed to store session: {:?}", e);
return Ok(
HttpResponse::InternalServerError().json(OAuthErrorResponse::new(
"session_storage_error",
"Failed to create session",
)),
);
}
if let Some(redirect) = stored_state.extra_data.get("client_redirect") {
let allowed = &oauth.config.allowed_redirect_origins;
if is_redirect_origin_allowed(redirect, allowed) {
let redirect_url = if redirect.contains('?') {
format!("{}&session_id={}", redirect, session_id)
} else {
format!("{}?session_id={}", redirect, session_id)
};
return Ok(HttpResponse::Found()
.insert_header(("Location", redirect_url))
.finish());
} else {
warn!(
"OAuth callback: client_redirect '{}' not in allowed_redirect_origins; falling back to /",
redirect
);
return Ok(HttpResponse::Found()
.insert_header(("Location", "/"))
.finish());
}
}
Ok(HttpResponse::Ok().json(AuthResponse {
session_id,
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
token_type: token_response.token_type,
expires_in: token_response.expires_in,
user: user_info,
}))
}
/// `POST /oauth/refresh` — exchanges a refresh token for a fresh access token.
///
/// The issuing provider is not known here, so each configured provider client is tried in
/// turn, in the client map's iteration order, until one accepts the token. The first success
/// is returned as an [`AuthResponse`]. Its `session_id` is empty and its `user` is a
/// placeholder built from empty strings plus the provider name; no session is read or
/// updated. If every provider rejects the token, the response is `401 refresh_failed`.
///
/// NOTE(review): the refresh token is offered to every provider until one accepts it, so a
/// token issued by one identity provider can be sent to others. Consider letting the caller
/// name the provider (or deriving it from the session) so the token only reaches its issuer.
pub async fn oauth_refresh(
    oauth: web::Data<OAuthState>,
    body: web::Json<RefreshRequest>,
) -> ActixResult<HttpResponse> {
    debug!("OAuth token refresh request");
    let presented = &body.refresh_token;
    for (provider_name, client) in oauth.clients.iter() {
        let tokens = match client.refresh_token(presented).await {
            Ok(tokens) => tokens,
            Err(e) => {
                debug!("Refresh failed with provider {}: {}", provider_name, e);
                continue;
            }
        };
        info!("Token refresh successful with provider: {}", provider_name);
        let payload = AuthResponse {
            session_id: String::new(),
            access_token: tokens.access_token,
            refresh_token: tokens.refresh_token,
            token_type: tokens.token_type,
            expires_in: tokens.expires_in,
            user: UserInfo::new("", "", provider_name),
        };
        return Ok(HttpResponse::Ok().json(payload));
    }
    Ok(HttpResponse::Unauthorized().json(OAuthErrorResponse::new(
        "refresh_failed",
        "Failed to refresh token with any provider",
    )))
}
/// `POST /oauth/logout` — ends the caller's session, if one is identified.
///
/// The session ID is read from the `X-Session-ID` header, falling back to an
/// `Authorization: Bearer <session_id>` header. A deletion failure is logged but does not
/// fail the request. A request without a session ID also succeeds, so the response is
/// always `200`. The body's `post_logout_redirect_uri` is echoed back as `redirect_url`
/// without validation.
///
/// The session ID is a bearer credential, so it is never written to the logs.
pub async fn oauth_logout(
    oauth: web::Data<OAuthState>,
    body: web::Json<LogoutRequest>,
    req: HttpRequest,
) -> ActixResult<HttpResponse> {
    info!("OAuth logout request");
    let session_id = req
        .headers()
        .get("X-Session-ID")
        .and_then(|h| h.to_str().ok())
        .or_else(|| {
            req.headers()
                .get("Authorization")
                .and_then(|h| h.to_str().ok())
                .and_then(|h| h.strip_prefix("Bearer "))
        })
        .map(String::from);
    if let Some(sid) = session_id {
        match oauth.session_store.delete(&sid).await {
            Ok(_) => debug!("Session deleted during logout"),
            Err(e) => warn!("Failed to delete session during logout: {:?}", e),
        }
    }
    let logout_redirect = body.post_logout_redirect_uri.clone();
    Ok(HttpResponse::Ok().json(serde_json::json!({
        "success": true,
        "message": "Logged out successfully",
        "redirect_url": logout_redirect
    })))
}
/// `GET /oauth/userinfo` — returns the profile stored on the caller's session.
///
/// The session ID comes from the `X-Session-ID` header, or else from an
/// `Authorization: Bearer <session_id>` header. Responses:
/// * `401 missing_session` — no session ID was supplied;
/// * `401 invalid_session` — the store has no such session;
/// * `500 session_error` — the store lookup failed;
/// * `200` — the session's `user_info` as JSON.
pub async fn oauth_userinfo(
    oauth: web::Data<OAuthState>,
    req: HttpRequest,
) -> ActixResult<HttpResponse> {
    /// Returns the named header as UTF-8 text, or `None` if absent or not valid UTF-8.
    fn header_text<'a>(req: &'a HttpRequest, name: &str) -> Option<&'a str> {
        req.headers().get(name).and_then(|value| value.to_str().ok())
    }

    debug!("OAuth userinfo request");
    let Some(session_id) = header_text(&req, "X-Session-ID").or_else(|| {
        header_text(&req, "Authorization").and_then(|auth| auth.strip_prefix("Bearer "))
    }) else {
        return Ok(HttpResponse::Unauthorized().json(OAuthErrorResponse::new(
            "missing_session",
            "Session ID is required",
        )));
    };
    let lookup = oauth.session_store.get(session_id).await;
    let session = match lookup {
        Ok(Some(found)) => found,
        Ok(None) => {
            return Ok(HttpResponse::Unauthorized().json(OAuthErrorResponse::new(
                "invalid_session",
                "Session not found or expired",
            )));
        }
        Err(e) => {
            error!("Failed to retrieve session: {:?}", e);
            return Ok(
                HttpResponse::InternalServerError().json(OAuthErrorResponse::new(
                    "session_error",
                    "Failed to retrieve session",
                )),
            );
        }
    };
    Ok(HttpResponse::Ok().json(session.user_info))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::oauth::config::OAuthConfig;
    use crate::auth::oauth::session::InMemorySessionStore;

    /// Gateway config with a single Google provider.
    fn create_test_config() -> OAuthGatewayConfig {
        let mut config = OAuthGatewayConfig::default();
        let google = OAuthConfig::google("test_client_id", "https://app.example.com/callback")
            .with_client_secret("test_secret");
        config.add_provider("google", google);
        config
    }

    /// `OAuthState` built from `create_test_config()` and a fresh in-memory session store.
    fn build_state() -> Result<OAuthState, String> {
        let config = create_test_config();
        let session_store = Arc::new(InMemorySessionStore::new());
        OAuthState::new(config, session_store)
    }

    #[test]
    fn test_oauth_state_creation() {
        let state = build_state().expect("test config should produce a valid OAuthState");
        assert!(state.get_client("google").is_some());
        assert!(state.get_client("unknown").is_none());
    }

    #[test]
    fn test_login_response_serialization() {
        let json = serde_json::to_string(&OAuthLoginResponse {
            authorization_url: "https://auth.example.com".to_string(),
            state: "state123".to_string(),
        })
        .unwrap();
        for needle in ["authorization_url", "state123"] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn test_auth_response_serialization() {
        let json = serde_json::to_string(&AuthResponse {
            session_id: "session123".to_string(),
            access_token: "access_token".to_string(),
            refresh_token: Some("refresh_token".to_string()),
            token_type: "Bearer".to_string(),
            expires_in: 3600,
            user: UserInfo::new("123", "test@example.com", "google"),
        })
        .unwrap();
        for needle in ["session123", "access_token", "test@example.com"] {
            assert!(json.contains(needle), "missing {needle} in {json}");
        }
    }

    #[test]
    fn test_error_response() {
        let err = OAuthErrorResponse::new("invalid_grant", "The authorization code has expired");
        assert_eq!(err.error, "invalid_grant");
        assert!(err.error_description.contains("expired"));
    }

    #[test]
    fn test_login_query_deserialization() {
        let query: LoginQuery = serde_json::from_str(
            r#"{"redirect_uri": "https://app.example.com", "login_hint": "user@example.com"}"#,
        )
        .unwrap();
        assert_eq!(query.redirect_uri.as_deref(), Some("https://app.example.com"));
        assert_eq!(query.login_hint.as_deref(), Some("user@example.com"));
    }

    #[test]
    fn test_oauth_state_debug() {
        let rendered = format!("{:?}", build_state().unwrap());
        assert!(rendered.contains("OAuthState"));
        assert!(rendered.contains("google"));
    }

    #[test]
    fn test_empty_allowed_origins_rejects_all() {
        let empty: [String; 0] = [];
        for candidate in ["https://evil.com/callback", "https://app.example.com"] {
            assert!(
                !is_redirect_origin_allowed(candidate, &empty),
                "{candidate} must be rejected by an empty allow-list"
            );
        }
    }

    #[test]
    fn test_allowed_origins_permits_matching() {
        let origins = ["https://app.example.com".to_string()];
        assert!(is_redirect_origin_allowed(
            "https://app.example.com/callback?foo=bar",
            &origins
        ));
    }

    #[test]
    fn test_allowed_origins_rejects_non_matching() {
        let origins = ["https://app.example.com".to_string()];
        assert!(!is_redirect_origin_allowed("https://evil.com/callback", &origins));
    }

    #[test]
    fn test_allowed_origins_with_port() {
        let origins = ["http://localhost:3000".to_string()];
        assert!(is_redirect_origin_allowed("http://localhost:3000/cb", &origins));
        assert!(!is_redirect_origin_allowed("http://localhost:4000/cb", &origins));
    }

    #[test]
    fn test_invalid_redirect_uri_rejected() {
        let origins = ["https://app.example.com".to_string()];
        assert!(!is_redirect_origin_allowed("not-a-url", &origins));
    }
}