1use axum::{
7 extract::{Query, State},
8 http::StatusCode,
9 response::{Json, Redirect},
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16use crate::auth::oidc::{generate_oidc_token, OidcState, TenantContext};
17use crate::auth::token_lifecycle::{extract_token_id, TokenLifecycleManager};
18use chrono::Utc;
19use hex;
20use rand::Rng;
21use serde_json::json;
22use uuid;
23
24#[derive(Clone)]
26pub struct OAuth2ServerState {
27 pub oidc_state: Arc<RwLock<Option<OidcState>>>,
29 pub lifecycle_manager: Arc<TokenLifecycleManager>,
31 pub auth_codes: Arc<RwLock<HashMap<String, AuthorizationCodeInfo>>>,
33 pub refresh_tokens: Arc<RwLock<HashMap<String, RefreshTokenInfo>>>,
35}
36
/// Server-side record for an issued refresh token, stored in
/// `OAuth2ServerState::refresh_tokens` under the opaque token string.
#[derive(Clone, Debug)]
pub struct RefreshTokenInfo {
    /// Client the refresh token was issued to.
    pub client_id: String,
    /// Scopes granted when the refresh token was issued.
    pub scopes: Vec<String>,
    /// Subject (user id) that refreshed access tokens are minted for.
    pub user_id: String,
    /// Expiry as a Unix timestamp in seconds; compared against
    /// `Utc::now().timestamp()` when the token is presented.
    pub expires_at: i64,
}
49
50#[derive(Debug, Clone)]
52pub struct AuthorizationCodeInfo {
53 pub client_id: String,
55 pub redirect_uri: String,
57 pub scopes: Vec<String>,
59 pub user_id: String,
61 pub state: Option<String>,
63 pub expires_at: i64,
65 pub tenant_context: Option<TenantContext>,
67}
68
69#[derive(Debug, Deserialize)]
71pub struct AuthorizationRequest {
72 pub client_id: String,
74 pub response_type: String,
76 pub redirect_uri: String,
78 pub scope: Option<String>,
80 pub state: Option<String>,
82 pub nonce: Option<String>,
84 pub prompt: Option<String>,
86}
87
88#[derive(Debug, Deserialize)]
90pub struct TokenRequest {
91 pub grant_type: String,
93 pub code: Option<String>,
95 pub redirect_uri: Option<String>,
97 pub client_id: Option<String>,
99 pub client_secret: Option<String>,
101 pub scope: Option<String>,
103 pub nonce: Option<String>,
105 pub refresh_token: Option<String>,
107}
108
109#[derive(Debug, Serialize)]
111pub struct TokenResponse {
112 pub access_token: String,
114 pub token_type: String,
116 pub expires_in: i64,
118 #[serde(skip_serializing_if = "Option::is_none")]
120 pub refresh_token: Option<String>,
121 #[serde(skip_serializing_if = "Option::is_none")]
123 pub scope: Option<String>,
124 #[serde(skip_serializing_if = "Option::is_none")]
126 pub id_token: Option<String>,
127}
128
129pub async fn authorize(
131 State(state): State<OAuth2ServerState>,
132 Query(params): Query<AuthorizationRequest>,
133) -> Result<Redirect, StatusCode> {
134 if params.response_type != "code" {
136 return Err(StatusCode::BAD_REQUEST);
137 }
138
139 if params.prompt.as_deref() == Some("consent") {
141 let mut consent_url = url::Url::parse("http://localhost/consent")
142 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
143 consent_url
144 .query_pairs_mut()
145 .append_pair("client_id", ¶ms.client_id)
146 .append_pair("redirect_uri", ¶ms.redirect_uri);
147 if let Some(ref scope) = params.scope {
148 consent_url.query_pairs_mut().append_pair("scope", scope);
149 }
150 if let Some(ref state) = params.state {
151 consent_url.query_pairs_mut().append_pair("state", state);
152 }
153 let redirect_target =
155 format!("/consent{}", consent_url.query().map(|q| format!("?{q}")).unwrap_or_default());
156 return Ok(Redirect::to(&redirect_target));
157 }
158
159 let auth_code = {
163 let mut rng = rand::rng();
164 let code_bytes: [u8; 32] = rng.random();
165 hex::encode(code_bytes)
166 };
167
168 let scopes = params
170 .scope
171 .as_ref()
172 .map(|s| s.split(' ').map(|s| s.to_string()).collect())
173 .unwrap_or_else(Vec::new);
174
175 let code_info = AuthorizationCodeInfo {
177 client_id: params.client_id.clone(),
178 redirect_uri: params.redirect_uri.clone(),
179 scopes,
180 user_id: "user-default".to_string(),
183 state: params.state.clone(),
184 expires_at: Utc::now().timestamp() + 600, tenant_context: None,
187 };
188
189 {
190 let mut codes = state.auth_codes.write().await;
191 codes.insert(auth_code.clone(), code_info);
192 }
193
194 let mut redirect_url =
196 url::Url::parse(¶ms.redirect_uri).map_err(|_| StatusCode::BAD_REQUEST)?;
197 redirect_url.query_pairs_mut().append_pair("code", &auth_code);
198 if let Some(state) = params.state {
199 redirect_url.query_pairs_mut().append_pair("state", &state);
200 }
201
202 Ok(Redirect::to(redirect_url.as_str()))
203}
204
205pub async fn token(
207 State(state): State<OAuth2ServerState>,
208 axum::extract::Form(request): axum::extract::Form<TokenRequest>,
209) -> Result<Json<TokenResponse>, StatusCode> {
210 match request.grant_type.as_str() {
211 "authorization_code" => handle_authorization_code_grant(state, request).await,
212 "client_credentials" => handle_client_credentials_grant(state, request).await,
213 "refresh_token" => handle_refresh_token_grant(state, request).await,
214 _ => Err(StatusCode::BAD_REQUEST),
215 }
216}
217
218async fn handle_authorization_code_grant(
220 state: OAuth2ServerState,
221 request: TokenRequest,
222) -> Result<Json<TokenResponse>, StatusCode> {
223 let code = request.code.ok_or(StatusCode::BAD_REQUEST)?;
224 let redirect_uri = request.redirect_uri.ok_or(StatusCode::BAD_REQUEST)?;
225
226 let code_info = {
228 let mut codes = state.auth_codes.write().await;
229 codes.remove(&code).ok_or(StatusCode::BAD_REQUEST)?
230 };
231
232 if code_info.redirect_uri != redirect_uri {
234 return Err(StatusCode::BAD_REQUEST);
235 }
236
237 if code_info.expires_at < Utc::now().timestamp() {
239 return Err(StatusCode::BAD_REQUEST);
240 }
241
242 let oidc_state_guard = state.oidc_state.read().await;
244 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
245
246 let mut additional_claims = HashMap::new();
248 additional_claims.insert("scope".to_string(), json!(code_info.scopes.join(" ")));
249 if let Some(nonce) = request.nonce {
250 additional_claims.insert("nonce".to_string(), json!(nonce));
251 }
252
253 let access_token = generate_oidc_token(
254 oidc_state,
255 code_info.user_id.clone(),
256 Some(additional_claims),
257 Some(3600), code_info.tenant_context.clone(),
259 )
260 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
261
262 let token_id = extract_token_id(&access_token);
264 if state.lifecycle_manager.revocation.is_revoked(&token_id).await.is_some() {
265 return Err(StatusCode::INTERNAL_SERVER_ERROR);
266 }
267
268 let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
270 {
271 let mut tokens = state.refresh_tokens.write().await;
272 tokens.insert(
273 refresh_token.clone(),
274 RefreshTokenInfo {
275 client_id: code_info.client_id.clone(),
276 scopes: code_info.scopes.clone(),
277 user_id: code_info.user_id.clone(),
278 expires_at: Utc::now().timestamp() + 86400, },
280 );
281 }
282
283 Ok(Json(TokenResponse {
284 access_token,
285 token_type: "Bearer".to_string(),
286 expires_in: 3600,
287 refresh_token: Some(refresh_token),
288 scope: Some(code_info.scopes.join(" ")),
289 id_token: None,
290 }))
291}
292
293async fn handle_client_credentials_grant(
295 state: OAuth2ServerState,
296 request: TokenRequest,
297) -> Result<Json<TokenResponse>, StatusCode> {
298 let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
299 let _client_secret = request.client_secret.ok_or(StatusCode::BAD_REQUEST)?;
300
301 let oidc_state_guard = state.oidc_state.read().await;
305 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
306
307 let mut additional_claims = HashMap::new();
308 additional_claims.insert("client_id".to_string(), serde_json::json!(client_id));
309 let scope_clone = request.scope.clone();
310 if let Some(ref scope) = request.scope {
311 additional_claims.insert("scope".to_string(), serde_json::json!(scope));
312 }
313
314 let access_token = generate_oidc_token(
315 oidc_state,
316 format!("client_{}", client_id),
317 Some(additional_claims),
318 Some(3600),
319 None,
320 )
321 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
322
323 Ok(Json(TokenResponse {
324 access_token,
325 token_type: "Bearer".to_string(),
326 expires_in: 3600,
327 refresh_token: None,
328 scope: scope_clone,
329 id_token: None,
330 }))
331}
332
333async fn handle_refresh_token_grant(
335 state: OAuth2ServerState,
336 request: TokenRequest,
337) -> Result<Json<TokenResponse>, StatusCode> {
338 let refresh_token_value = request.refresh_token.ok_or(StatusCode::BAD_REQUEST)?;
340
341 let token_info = {
343 let mut tokens = state.refresh_tokens.write().await;
344 tokens.remove(&refresh_token_value).ok_or(StatusCode::UNAUTHORIZED)?
345 };
346
347 if token_info.expires_at < Utc::now().timestamp() {
349 return Err(StatusCode::UNAUTHORIZED);
350 }
351
352 if let Some(ref client_id) = request.client_id {
354 if *client_id != token_info.client_id {
355 return Err(StatusCode::UNAUTHORIZED);
356 }
357 }
358
359 let oidc_state_guard = state.oidc_state.read().await;
361 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
362
363 let mut additional_claims = HashMap::new();
364 additional_claims.insert("client_id".to_string(), json!(token_info.client_id.clone()));
365
366 let scopes = if let Some(ref scope) = request.scope {
368 additional_claims.insert("scope".to_string(), json!(scope));
369 scope.clone()
370 } else {
371 let scope_str = token_info.scopes.join(" ");
372 additional_claims.insert("scope".to_string(), json!(scope_str));
373 scope_str
374 };
375
376 let access_token = generate_oidc_token(
377 oidc_state,
378 token_info.user_id.clone(),
379 Some(additional_claims),
380 Some(3600),
381 None,
382 )
383 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
384
385 let new_refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
387 {
388 let mut tokens = state.refresh_tokens.write().await;
389 tokens.insert(
390 new_refresh_token.clone(),
391 RefreshTokenInfo {
392 client_id: token_info.client_id,
393 scopes: token_info.scopes,
394 user_id: token_info.user_id,
395 expires_at: Utc::now().timestamp() + 86400, },
397 );
398 }
399
400 Ok(Json(TokenResponse {
401 access_token,
402 token_type: "Bearer".to_string(),
403 expires_in: 3600,
404 refresh_token: Some(new_refresh_token),
405 scope: Some(scopes),
406 id_token: None,
407 }))
408}
409
410pub fn oauth2_server_router(state: OAuth2ServerState) -> axum::Router {
412 use axum::routing::{get, post};
413
414 axum::Router::new()
415 .route("/oauth2/authorize", get(authorize))
416 .route("/oauth2/token", post(token))
417 .with_state(state)
418}