use axum::{
extract::{Query, State},
http::StatusCode,
response::{Json, Redirect},
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::auth::oidc::{generate_oidc_token, OidcState, TenantContext};
use crate::auth::token_lifecycle::{extract_token_id, TokenLifecycleManager};
use chrono::Utc;
use hex;
use rand::Rng;
use serde_json::json;
use uuid;
/// Shared state for the OAuth 2.0 authorization-server handlers in this module.
///
/// Cloning is cheap: every field is an `Arc`, so all clones share the same stores.
#[derive(Clone)]
pub struct OAuth2ServerState {
    /// OIDC state used to mint tokens; when it is `None`, the token handlers respond with
    /// `500 Internal Server Error`.
    pub oidc_state: Arc<RwLock<Option<OidcState>>>,
    /// Token lifecycle manager; the authorization-code grant consults its revocation list.
    pub lifecycle_manager: Arc<TokenLifecycleManager>,
    /// Outstanding authorization codes, keyed by the code string handed out by `authorize`.
    pub auth_codes: Arc<RwLock<HashMap<String, AuthorizationCodeInfo>>>,
    /// Active refresh tokens, keyed by the opaque token string returned to clients.
    pub refresh_tokens: Arc<RwLock<HashMap<String, RefreshTokenInfo>>>,
}
/// Server-side record backing an issued refresh token.
#[derive(Debug, Clone)]
pub struct RefreshTokenInfo {
    /// Client the token was issued to; a refresh request that supplies a different
    /// `client_id` is rejected.
    pub client_id: String,
    /// Scopes granted with the original authorization.
    pub scopes: Vec<String>,
    /// Subject that refreshed access tokens are issued for.
    pub user_id: String,
    /// Expiry as a Unix timestamp in seconds (compared with `Utc::now().timestamp()`).
    pub expires_at: i64,
}
/// Server-side record for an authorization code issued by `authorize`.
///
/// The record is removed from the store when the code is redeemed, which makes each
/// code single-use.
#[derive(Debug, Clone)]
pub struct AuthorizationCodeInfo {
    /// Client the code was issued to.
    pub client_id: String,
    /// Redirect URI from the authorization request; the token request must present the same value.
    pub redirect_uri: String,
    /// Scopes requested during authorization.
    pub scopes: Vec<String>,
    /// Subject that tokens minted from this code are issued for.
    pub user_id: String,
    /// Opaque `state` value from the authorization request.
    pub state: Option<String>,
    /// Expiry as a Unix timestamp in seconds.
    pub expires_at: i64,
    /// Optional tenant context forwarded to token generation.
    pub tenant_context: Option<TenantContext>,
}
/// Query parameters accepted by `GET /oauth2/authorize`.
#[derive(Debug, Deserialize)]
pub struct AuthorizationRequest {
    /// Identifier of the requesting client.
    pub client_id: String,
    /// Must be `"code"`; anything else is rejected with `400 Bad Request`.
    pub response_type: String,
    /// Where the user agent is sent back with the authorization code.
    pub redirect_uri: String,
    /// Optional space-delimited list of requested scopes.
    pub scope: Option<String>,
    /// Opaque client value echoed back on the redirect.
    pub state: Option<String>,
    /// OIDC nonce. NOTE(review): accepted but not used or stored by `authorize`.
    pub nonce: Option<String>,
    /// When `"consent"`, the user agent is redirected to `/consent` instead of receiving a code.
    pub prompt: Option<String>,
}
/// Form body accepted by `POST /oauth2/token`.
///
/// Which of the optional fields are required depends on `grant_type`.
/// NOTE(review): the derived `Debug` includes `client_secret`; avoid logging this struct.
#[derive(Debug, Deserialize)]
pub struct TokenRequest {
    /// One of `authorization_code`, `client_credentials` or `refresh_token`.
    pub grant_type: String,
    /// Authorization code to redeem (`authorization_code` grant).
    pub code: Option<String>,
    /// Must equal the redirect URI used at authorization (`authorization_code` grant).
    pub redirect_uri: Option<String>,
    /// Client identifier. Required for `client_credentials`; for `refresh_token`, it is
    /// compared with the stored client when present.
    pub client_id: Option<String>,
    /// Client secret; must be present for `client_credentials`.
    pub client_secret: Option<String>,
    /// Requested scope (`client_credentials` and `refresh_token` grants).
    pub scope: Option<String>,
    /// Nonce copied into the access-token claims (`authorization_code` grant).
    pub nonce: Option<String>,
    /// Refresh token to exchange (`refresh_token` grant).
    pub refresh_token: Option<String>,
}
/// JSON body returned by the token endpoint; `None` fields are omitted from the output.
#[derive(Debug, Serialize)]
pub struct TokenResponse {
    /// Access token produced by `generate_oidc_token`.
    pub access_token: String,
    /// Token type; every handler in this module sets `"Bearer"`.
    pub token_type: String,
    /// Access-token lifetime in seconds.
    pub expires_in: i64,
    /// Refresh token, when the grant issues one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_token: Option<String>,
    /// Space-delimited scopes granted to the access token.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// ID token; none of the handlers in this module populate it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id_token: Option<String>,
}
/// `GET /oauth2/authorize`: the OAuth 2.0 authorization endpoint (authorization-code flow only).
///
/// - Any `response_type` other than `"code"` is rejected with `400 Bad Request`.
/// - When `prompt=consent`, the user agent is redirected to the local `/consent` page. The
///   redirect forwards `client_id` and `redirect_uri`, plus `scope` and `state` when present.
/// - Otherwise the handler issues a random 256-bit, hex-encoded authorization code that is
///   valid for 10 minutes. The code is bound to the client, redirect URI and requested scopes.
///   The user agent is then redirected back to `redirect_uri` with `code` (and `state`, if
///   supplied) appended.
///
/// No real user authentication happens here: every code is issued to the fixed subject
/// `"user-default"`.
///
/// NOTE(review): `redirect_uri` is not checked against a per-client allow-list, so this
/// endpoint acts as an open redirector. Add registration checks before production use.
///
/// # Errors
/// - `400 Bad Request`: unsupported `response_type`, or a `redirect_uri` that is not a valid
///   absolute URL.
/// - `500 Internal Server Error`: the consent URL could not be built.
pub async fn authorize(
    State(state): State<OAuth2ServerState>,
    Query(params): Query<AuthorizationRequest>,
) -> Result<Redirect, StatusCode> {
    // Lifetime of an authorization code, in seconds.
    const AUTH_CODE_TTL_SECS: i64 = 600;

    if params.response_type != "code" {
        return Err(StatusCode::BAD_REQUEST);
    }

    if params.prompt.as_deref() == Some("consent") {
        // The dummy base URL only exists so `url` can percent-encode the query string.
        // Only the encoded query is kept; it is appended to the relative `/consent` path.
        let mut consent_url = url::Url::parse("http://localhost/consent")
            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
        {
            let mut pairs = consent_url.query_pairs_mut();
            pairs
                .append_pair("client_id", &params.client_id)
                .append_pair("redirect_uri", &params.redirect_uri);
            if let Some(scope) = params.scope.as_deref() {
                pairs.append_pair("scope", scope);
            }
            if let Some(client_state) = params.state.as_deref() {
                pairs.append_pair("state", client_state);
            }
        }
        let redirect_target =
            format!("/consent{}", consent_url.query().map(|q| format!("?{q}")).unwrap_or_default());
        return Ok(Redirect::to(&redirect_target));
    }

    // Validate the redirect target before minting a code. Otherwise a malformed URI would
    // leave a stored code that can never be delivered or redeemed.
    let mut redirect_url =
        url::Url::parse(&params.redirect_uri).map_err(|_| StatusCode::BAD_REQUEST)?;

    let auth_code = {
        let mut rng = rand::rng();
        let code_bytes: [u8; 32] = rng.random();
        hex::encode(code_bytes)
    };

    // `split_whitespace` avoids empty scope entries for `""` or repeated spaces.
    let scopes: Vec<String> = params
        .scope
        .as_deref()
        .map(|s| s.split_whitespace().map(str::to_string).collect())
        .unwrap_or_default();

    let now = Utc::now().timestamp();
    let code_info = AuthorizationCodeInfo {
        client_id: params.client_id.clone(),
        redirect_uri: params.redirect_uri.clone(),
        scopes,
        user_id: "user-default".to_string(),
        state: params.state.clone(),
        expires_at: now + AUTH_CODE_TTL_SECS,
        tenant_context: None,
    };
    {
        let mut codes = state.auth_codes.write().await;
        // Expired codes are rejected at redemption anyway. Dropping them here stops abandoned
        // flows from growing the store without bound.
        codes.retain(|_, info| info.expires_at >= now);
        codes.insert(auth_code.clone(), code_info);
    }

    {
        let mut pairs = redirect_url.query_pairs_mut();
        pairs.append_pair("code", &auth_code);
        if let Some(client_state) = params.state.as_deref() {
            pairs.append_pair("state", client_state);
        }
    }
    Ok(Redirect::to(redirect_url.as_str()))
}
/// `POST /oauth2/token`: the form-encoded token endpoint.
///
/// Routes the request to the handler for its `grant_type` (`authorization_code`,
/// `client_credentials` or `refresh_token`). Any other grant type is answered with
/// `400 Bad Request`.
pub async fn token(
    State(state): State<OAuth2ServerState>,
    axum::extract::Form(request): axum::extract::Form<TokenRequest>,
) -> Result<Json<TokenResponse>, StatusCode> {
    if request.grant_type == "authorization_code" {
        handle_authorization_code_grant(state, request).await
    } else if request.grant_type == "client_credentials" {
        handle_client_credentials_grant(state, request).await
    } else if request.grant_type == "refresh_token" {
        handle_refresh_token_grant(state, request).await
    } else {
        Err(StatusCode::BAD_REQUEST)
    }
}
/// Redeems an authorization code (`grant_type=authorization_code`).
///
/// The code is removed from the store before any validation, so it can be redeemed at most
/// once even if this request is rejected. On success the response contains:
/// - a one-hour access token from [`generate_oidc_token`], carrying the granted `scope` and,
///   if supplied, the request's `nonce`;
/// - a new 24-hour refresh token bound to the same client, user and scopes.
///
/// NOTE(review): the `nonce` is taken from the token request rather than the original
/// authorization request, which differs from OIDC Core's expectation.
///
/// # Errors
/// - `400 Bad Request` in any of these cases:
///   - `code` or `redirect_uri` is missing;
///   - the code is unknown or has expired;
///   - `redirect_uri` differs from the one used at authorization;
///   - a supplied `client_id` is not the client the code was issued to.
/// - `500 Internal Server Error` in any of these cases:
///   - OIDC state is not configured;
///   - token generation fails;
///   - the new token's ID is already on the revocation list.
async fn handle_authorization_code_grant(
    state: OAuth2ServerState,
    request: TokenRequest,
) -> Result<Json<TokenResponse>, StatusCode> {
    let code = request.code.ok_or(StatusCode::BAD_REQUEST)?;
    let redirect_uri = request.redirect_uri.ok_or(StatusCode::BAD_REQUEST)?;
    // Remove the code rather than look it up, so it is single-use even on failure below.
    let code_info = {
        let mut codes = state.auth_codes.write().await;
        codes.remove(&code).ok_or(StatusCode::BAD_REQUEST)?
    };
    if code_info.redirect_uri != redirect_uri {
        return Err(StatusCode::BAD_REQUEST);
    }
    // RFC 6749 §4.1.3: the code must have been issued to the requesting client. `client_id`
    // is optional in this form, so the check only applies when the client sends it.
    if let Some(ref client_id) = request.client_id {
        if *client_id != code_info.client_id {
            return Err(StatusCode::BAD_REQUEST);
        }
    }
    if code_info.expires_at < Utc::now().timestamp() {
        return Err(StatusCode::BAD_REQUEST);
    }

    let scope = code_info.scopes.join(" ");
    let mut additional_claims = HashMap::new();
    additional_claims.insert("scope".to_string(), json!(&scope));
    if let Some(nonce) = request.nonce {
        additional_claims.insert("nonce".to_string(), json!(nonce));
    }

    // Hold the OIDC read guard only while signing, not across the later awaits.
    let access_token = {
        let oidc_state_guard = state.oidc_state.read().await;
        let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
        generate_oidc_token(
            oidc_state,
            code_info.user_id.clone(),
            Some(additional_claims),
            Some(3600),
            code_info.tenant_context.clone(),
        )
        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
    };

    // Defensive: refuse to hand out a token whose ID is already recorded as revoked.
    let token_id = extract_token_id(&access_token);
    if state.lifecycle_manager.revocation.is_revoked(&token_id).await.is_some() {
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    }

    let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
    state.refresh_tokens.write().await.insert(
        refresh_token.clone(),
        RefreshTokenInfo {
            client_id: code_info.client_id,
            scopes: code_info.scopes,
            user_id: code_info.user_id,
            expires_at: Utc::now().timestamp() + 86400,
        },
    );

    Ok(Json(TokenResponse {
        access_token,
        token_type: "Bearer".to_string(),
        expires_in: 3600,
        refresh_token: Some(refresh_token),
        scope: Some(scope),
        id_token: None,
    }))
}
async fn handle_client_credentials_grant(
state: OAuth2ServerState,
request: TokenRequest,
) -> Result<Json<TokenResponse>, StatusCode> {
let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
let _client_secret = request.client_secret.ok_or(StatusCode::BAD_REQUEST)?;
let oidc_state_guard = state.oidc_state.read().await;
let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
let mut additional_claims = HashMap::new();
additional_claims.insert("client_id".to_string(), serde_json::json!(client_id));
let scope_clone = request.scope.clone();
if let Some(ref scope) = request.scope {
additional_claims.insert("scope".to_string(), serde_json::json!(scope));
}
let access_token = generate_oidc_token(
oidc_state,
format!("client_{}", client_id),
Some(additional_claims),
Some(3600),
None,
)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(TokenResponse {
access_token,
token_type: "Bearer".to_string(),
expires_in: 3600,
refresh_token: None,
scope: scope_clone,
id_token: None,
}))
}
async fn handle_refresh_token_grant(
state: OAuth2ServerState,
request: TokenRequest,
) -> Result<Json<TokenResponse>, StatusCode> {
let refresh_token_value = request.refresh_token.ok_or(StatusCode::BAD_REQUEST)?;
let token_info = {
let mut tokens = state.refresh_tokens.write().await;
tokens.remove(&refresh_token_value).ok_or(StatusCode::UNAUTHORIZED)?
};
if token_info.expires_at < Utc::now().timestamp() {
return Err(StatusCode::UNAUTHORIZED);
}
if let Some(ref client_id) = request.client_id {
if *client_id != token_info.client_id {
return Err(StatusCode::UNAUTHORIZED);
}
}
let oidc_state_guard = state.oidc_state.read().await;
let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
let mut additional_claims = HashMap::new();
additional_claims.insert("client_id".to_string(), json!(token_info.client_id.clone()));
let scopes = if let Some(ref scope) = request.scope {
additional_claims.insert("scope".to_string(), json!(scope));
scope.clone()
} else {
let scope_str = token_info.scopes.join(" ");
additional_claims.insert("scope".to_string(), json!(scope_str));
scope_str
};
let access_token = generate_oidc_token(
oidc_state,
token_info.user_id.clone(),
Some(additional_claims),
Some(3600),
None,
)
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
let new_refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
{
let mut tokens = state.refresh_tokens.write().await;
tokens.insert(
new_refresh_token.clone(),
RefreshTokenInfo {
client_id: token_info.client_id,
scopes: token_info.scopes,
user_id: token_info.user_id,
expires_at: Utc::now().timestamp() + 86400, },
);
}
Ok(Json(TokenResponse {
access_token,
token_type: "Bearer".to_string(),
expires_in: 3600,
refresh_token: Some(new_refresh_token),
scope: Some(scopes),
id_token: None,
}))
}
/// Builds the router for the OAuth 2.0 endpoints.
///
/// `GET /oauth2/authorize` is served by [`authorize`] and `POST /oauth2/token` by [`token`].
/// Both handlers share the supplied `state`.
pub fn oauth2_server_router(state: OAuth2ServerState) -> axum::Router {
    let authorize_route = axum::routing::get(authorize);
    let token_route = axum::routing::post(token);
    axum::Router::new()
        .route("/oauth2/authorize", authorize_route)
        .route("/oauth2/token", token_route)
        .with_state(state)
}