1use axum::{
7 extract::{Query, State},
8 http::StatusCode,
9 response::{Json, Redirect},
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16use crate::auth::oidc::{generate_oidc_token, OidcState, TenantContext};
17use crate::auth::token_lifecycle::{extract_token_id, TokenLifecycleManager};
18use chrono::Utc;
19use hex;
20use rand::Rng;
21use serde_json::json;
22use uuid;
23
24#[derive(Clone)]
26pub struct OAuth2ServerState {
27 pub oidc_state: Arc<RwLock<Option<OidcState>>>,
29 pub lifecycle_manager: Arc<TokenLifecycleManager>,
31 pub auth_codes: Arc<RwLock<HashMap<String, AuthorizationCodeInfo>>>,
33 pub refresh_tokens: Arc<RwLock<HashMap<String, RefreshTokenInfo>>>,
35}
36
/// Server-side bookkeeping kept for each outstanding refresh token.
#[derive(Clone, Debug)]
pub struct RefreshTokenInfo {
    /// Client the refresh token was issued to.
    pub client_id: String,
    /// Scopes granted with the original authorization.
    pub scopes: Vec<String>,
    /// User identifier that refreshed access tokens are issued for.
    pub user_id: String,
    /// Expiry as a Unix timestamp, in seconds.
    pub expires_at: i64,
}
49
50#[derive(Debug, Clone)]
52pub struct AuthorizationCodeInfo {
53 pub client_id: String,
55 pub redirect_uri: String,
57 pub scopes: Vec<String>,
59 pub user_id: String,
61 pub state: Option<String>,
63 pub expires_at: i64,
65 pub tenant_context: Option<TenantContext>,
67}
68
69#[derive(Debug, Deserialize)]
71pub struct AuthorizationRequest {
72 pub client_id: String,
74 pub response_type: String,
76 pub redirect_uri: String,
78 pub scope: Option<String>,
80 pub state: Option<String>,
82 pub nonce: Option<String>,
84}
85
86#[derive(Debug, Deserialize)]
88pub struct TokenRequest {
89 pub grant_type: String,
91 pub code: Option<String>,
93 pub redirect_uri: Option<String>,
95 pub client_id: Option<String>,
97 pub client_secret: Option<String>,
99 pub scope: Option<String>,
101 pub nonce: Option<String>,
103 pub refresh_token: Option<String>,
105}
106
107#[derive(Debug, Serialize)]
109pub struct TokenResponse {
110 pub access_token: String,
112 pub token_type: String,
114 pub expires_in: i64,
116 #[serde(skip_serializing_if = "Option::is_none")]
118 pub refresh_token: Option<String>,
119 #[serde(skip_serializing_if = "Option::is_none")]
121 pub scope: Option<String>,
122 #[serde(skip_serializing_if = "Option::is_none")]
124 pub id_token: Option<String>,
125}
126
127pub async fn authorize(
129 State(state): State<OAuth2ServerState>,
130 Query(params): Query<AuthorizationRequest>,
131) -> Result<Redirect, StatusCode> {
132 if params.response_type != "code" {
134 return Err(StatusCode::BAD_REQUEST);
135 }
136
137 let auth_code = {
142 let mut rng = rand::rng();
143 let code_bytes: [u8; 32] = rng.random();
144 hex::encode(code_bytes)
145 };
146
147 let scopes = params
149 .scope
150 .as_ref()
151 .map(|s| s.split(' ').map(|s| s.to_string()).collect())
152 .unwrap_or_else(Vec::new);
153
154 let code_info = AuthorizationCodeInfo {
156 client_id: params.client_id.clone(),
157 redirect_uri: params.redirect_uri.clone(),
158 scopes,
159 user_id: "user-default".to_string(),
162 state: params.state.clone(),
163 expires_at: Utc::now().timestamp() + 600, tenant_context: None,
166 };
167
168 {
169 let mut codes = state.auth_codes.write().await;
170 codes.insert(auth_code.clone(), code_info);
171 }
172
173 let mut redirect_url =
175 url::Url::parse(¶ms.redirect_uri).map_err(|_| StatusCode::BAD_REQUEST)?;
176 redirect_url.query_pairs_mut().append_pair("code", &auth_code);
177 if let Some(state) = params.state {
178 redirect_url.query_pairs_mut().append_pair("state", &state);
179 }
180
181 Ok(Redirect::to(redirect_url.as_str()))
182}
183
184pub async fn token(
186 State(state): State<OAuth2ServerState>,
187 axum::extract::Form(request): axum::extract::Form<TokenRequest>,
188) -> Result<Json<TokenResponse>, StatusCode> {
189 match request.grant_type.as_str() {
190 "authorization_code" => handle_authorization_code_grant(state, request).await,
191 "client_credentials" => handle_client_credentials_grant(state, request).await,
192 "refresh_token" => handle_refresh_token_grant(state, request).await,
193 _ => Err(StatusCode::BAD_REQUEST),
194 }
195}
196
197async fn handle_authorization_code_grant(
199 state: OAuth2ServerState,
200 request: TokenRequest,
201) -> Result<Json<TokenResponse>, StatusCode> {
202 let code = request.code.ok_or(StatusCode::BAD_REQUEST)?;
203 let redirect_uri = request.redirect_uri.ok_or(StatusCode::BAD_REQUEST)?;
204
205 let code_info = {
207 let mut codes = state.auth_codes.write().await;
208 codes.remove(&code).ok_or(StatusCode::BAD_REQUEST)?
209 };
210
211 if code_info.redirect_uri != redirect_uri {
213 return Err(StatusCode::BAD_REQUEST);
214 }
215
216 if code_info.expires_at < Utc::now().timestamp() {
218 return Err(StatusCode::BAD_REQUEST);
219 }
220
221 let oidc_state_guard = state.oidc_state.read().await;
223 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
224
225 let mut additional_claims = HashMap::new();
227 additional_claims.insert("scope".to_string(), json!(code_info.scopes.join(" ")));
228 if let Some(nonce) = request.nonce {
229 additional_claims.insert("nonce".to_string(), json!(nonce));
230 }
231
232 let access_token = generate_oidc_token(
233 oidc_state,
234 code_info.user_id.clone(),
235 Some(additional_claims),
236 Some(3600), code_info.tenant_context.clone(),
238 )
239 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
240
241 let token_id = extract_token_id(&access_token);
243 if state.lifecycle_manager.revocation.is_revoked(&token_id).await.is_some() {
244 return Err(StatusCode::INTERNAL_SERVER_ERROR);
245 }
246
247 let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
249 {
250 let mut tokens = state.refresh_tokens.write().await;
251 tokens.insert(
252 refresh_token.clone(),
253 RefreshTokenInfo {
254 client_id: code_info.client_id.clone(),
255 scopes: code_info.scopes.clone(),
256 user_id: code_info.user_id.clone(),
257 expires_at: Utc::now().timestamp() + 86400, },
259 );
260 }
261
262 Ok(Json(TokenResponse {
263 access_token,
264 token_type: "Bearer".to_string(),
265 expires_in: 3600,
266 refresh_token: Some(refresh_token),
267 scope: Some(code_info.scopes.join(" ")),
268 id_token: None,
269 }))
270}
271
272async fn handle_client_credentials_grant(
274 state: OAuth2ServerState,
275 request: TokenRequest,
276) -> Result<Json<TokenResponse>, StatusCode> {
277 let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
278 let _client_secret = request.client_secret.ok_or(StatusCode::BAD_REQUEST)?;
279
280 let oidc_state_guard = state.oidc_state.read().await;
284 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
285
286 let mut additional_claims = HashMap::new();
287 additional_claims.insert("client_id".to_string(), serde_json::json!(client_id));
288 let scope_clone = request.scope.clone();
289 if let Some(ref scope) = request.scope {
290 additional_claims.insert("scope".to_string(), serde_json::json!(scope));
291 }
292
293 let access_token = generate_oidc_token(
294 oidc_state,
295 format!("client_{}", client_id),
296 Some(additional_claims),
297 Some(3600),
298 None,
299 )
300 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
301
302 Ok(Json(TokenResponse {
303 access_token,
304 token_type: "Bearer".to_string(),
305 expires_in: 3600,
306 refresh_token: None,
307 scope: scope_clone,
308 id_token: None,
309 }))
310}
311
312async fn handle_refresh_token_grant(
314 state: OAuth2ServerState,
315 request: TokenRequest,
316) -> Result<Json<TokenResponse>, StatusCode> {
317 let refresh_token_value = request.refresh_token.ok_or(StatusCode::BAD_REQUEST)?;
319
320 let token_info = {
322 let mut tokens = state.refresh_tokens.write().await;
323 tokens.remove(&refresh_token_value).ok_or(StatusCode::UNAUTHORIZED)?
324 };
325
326 if token_info.expires_at < Utc::now().timestamp() {
328 return Err(StatusCode::UNAUTHORIZED);
329 }
330
331 if let Some(ref client_id) = request.client_id {
333 if *client_id != token_info.client_id {
334 return Err(StatusCode::UNAUTHORIZED);
335 }
336 }
337
338 let oidc_state_guard = state.oidc_state.read().await;
340 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
341
342 let mut additional_claims = HashMap::new();
343 additional_claims.insert("client_id".to_string(), json!(token_info.client_id.clone()));
344
345 let scopes = if let Some(ref scope) = request.scope {
347 additional_claims.insert("scope".to_string(), json!(scope));
348 scope.clone()
349 } else {
350 let scope_str = token_info.scopes.join(" ");
351 additional_claims.insert("scope".to_string(), json!(scope_str));
352 scope_str
353 };
354
355 let access_token = generate_oidc_token(
356 oidc_state,
357 token_info.user_id.clone(),
358 Some(additional_claims),
359 Some(3600),
360 None,
361 )
362 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
363
364 let new_refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
366 {
367 let mut tokens = state.refresh_tokens.write().await;
368 tokens.insert(
369 new_refresh_token.clone(),
370 RefreshTokenInfo {
371 client_id: token_info.client_id,
372 scopes: token_info.scopes,
373 user_id: token_info.user_id,
374 expires_at: Utc::now().timestamp() + 86400, },
376 );
377 }
378
379 Ok(Json(TokenResponse {
380 access_token,
381 token_type: "Bearer".to_string(),
382 expires_in: 3600,
383 refresh_token: Some(new_refresh_token),
384 scope: Some(scopes),
385 id_token: None,
386 }))
387}
388
389pub fn oauth2_server_router(state: OAuth2ServerState) -> axum::Router {
391 use axum::routing::{get, post};
392
393 axum::Router::new()
394 .route("/oauth2/authorize", get(authorize))
395 .route("/oauth2/token", post(token))
396 .with_state(state)
397}