1use axum::{
7 extract::{Query, State},
8 http::StatusCode,
9 response::{Json, Redirect},
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16use crate::auth::oidc::{generate_oidc_token, OidcState, TenantContext};
17use crate::auth::token_lifecycle::{extract_token_id, TokenLifecycleManager};
18use chrono::Utc;
19use hex;
20use rand::Rng;
21use serde_json::json;
22use uuid;
23
24#[derive(Clone)]
26pub struct OAuth2ServerState {
27 pub oidc_state: Arc<RwLock<Option<OidcState>>>,
29 pub lifecycle_manager: Arc<TokenLifecycleManager>,
31 pub auth_codes: Arc<RwLock<HashMap<String, AuthorizationCodeInfo>>>,
33}
34
35#[derive(Debug, Clone)]
37pub struct AuthorizationCodeInfo {
38 pub client_id: String,
40 pub redirect_uri: String,
42 pub scopes: Vec<String>,
44 pub user_id: String,
46 pub state: Option<String>,
48 pub expires_at: i64,
50 pub tenant_context: Option<TenantContext>,
52}
53
54#[derive(Debug, Deserialize)]
56pub struct AuthorizationRequest {
57 pub client_id: String,
59 pub response_type: String,
61 pub redirect_uri: String,
63 pub scope: Option<String>,
65 pub state: Option<String>,
67 pub nonce: Option<String>,
69}
70
71#[derive(Debug, Deserialize)]
73pub struct TokenRequest {
74 pub grant_type: String,
76 pub code: Option<String>,
78 pub redirect_uri: Option<String>,
80 pub client_id: Option<String>,
82 pub client_secret: Option<String>,
84 pub scope: Option<String>,
86 pub nonce: Option<String>,
88}
89
90#[derive(Debug, Serialize)]
92pub struct TokenResponse {
93 pub access_token: String,
95 pub token_type: String,
97 pub expires_in: i64,
99 #[serde(skip_serializing_if = "Option::is_none")]
101 pub refresh_token: Option<String>,
102 #[serde(skip_serializing_if = "Option::is_none")]
104 pub scope: Option<String>,
105 #[serde(skip_serializing_if = "Option::is_none")]
107 pub id_token: Option<String>,
108}
109
110pub async fn authorize(
112 State(state): State<OAuth2ServerState>,
113 Query(params): Query<AuthorizationRequest>,
114) -> Result<Redirect, StatusCode> {
115
116 if params.response_type != "code" {
118 return Err(StatusCode::BAD_REQUEST);
119 }
120
121 let mut rng = rand::thread_rng();
126 let code_bytes: [u8; 32] = rng.gen();
127 let auth_code = hex::encode(code_bytes);
128
129 let scopes = params
131 .scope
132 .as_ref()
133 .map(|s| s.split(' ').map(|s| s.to_string()).collect())
134 .unwrap_or_else(Vec::new);
135
136 let code_info = AuthorizationCodeInfo {
138 client_id: params.client_id.clone(),
139 redirect_uri: params.redirect_uri.clone(),
140 scopes,
141 user_id: "user-default".to_string(),
144 state: params.state.clone(),
145 expires_at: Utc::now().timestamp() + 600, tenant_context: None,
148 };
149
150 {
151 let mut codes = state.auth_codes.write().await;
152 codes.insert(auth_code.clone(), code_info);
153 }
154
155 let mut redirect_url = url::Url::parse(¶ms.redirect_uri)
157 .map_err(|_| StatusCode::BAD_REQUEST)?;
158 redirect_url
159 .query_pairs_mut()
160 .append_pair("code", &auth_code);
161 if let Some(state) = params.state {
162 redirect_url.query_pairs_mut().append_pair("state", &state);
163 }
164
165 Ok(Redirect::to(redirect_url.as_str()))
166}
167
168pub async fn token(
170 State(state): State<OAuth2ServerState>,
171 axum::extract::Form(request): axum::extract::Form<TokenRequest>,
172) -> Result<Json<TokenResponse>, StatusCode> {
173 use chrono::Utc;
174
175 match request.grant_type.as_str() {
176 "authorization_code" => {
177 handle_authorization_code_grant(state, request).await
178 }
179 "client_credentials" => {
180 handle_client_credentials_grant(state, request).await
181 }
182 "refresh_token" => {
183 handle_refresh_token_grant(state, request).await
184 }
185 _ => Err(StatusCode::BAD_REQUEST),
186 }
187}
188
189async fn handle_authorization_code_grant(
191 state: OAuth2ServerState,
192 request: TokenRequest,
193) -> Result<Json<TokenResponse>, StatusCode> {
194
195 let code = request.code.ok_or(StatusCode::BAD_REQUEST)?;
196 let redirect_uri = request.redirect_uri.ok_or(StatusCode::BAD_REQUEST)?;
197
198 let code_info = {
200 let mut codes = state.auth_codes.write().await;
201 codes.remove(&code).ok_or(StatusCode::BAD_REQUEST)?
202 };
203
204 if code_info.redirect_uri != redirect_uri {
206 return Err(StatusCode::BAD_REQUEST);
207 }
208
209 if code_info.expires_at < Utc::now().timestamp() {
211 return Err(StatusCode::BAD_REQUEST);
212 }
213
214 let oidc_state_guard = state.oidc_state.read().await;
216 let oidc_state = oidc_state_guard
217 .as_ref()
218 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
219
220 let mut additional_claims = HashMap::new();
222 additional_claims.insert("scope".to_string(), json!(code_info.scopes.join(" ")));
223 if let Some(nonce) = request.nonce {
224 additional_claims.insert("nonce".to_string(), json!(nonce));
225 }
226
227 let access_token = generate_oidc_token(
228 oidc_state,
229 code_info.user_id.clone(),
230 Some(additional_claims),
231 Some(3600), code_info.tenant_context.clone(),
233 )
234 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
235
236 let token_id = extract_token_id(&access_token);
238 if state.lifecycle_manager.revocation.is_revoked(&token_id).await.is_some() {
239 return Err(StatusCode::INTERNAL_SERVER_ERROR);
240 }
241
242 let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
244
245 Ok(Json(TokenResponse {
246 access_token,
247 token_type: "Bearer".to_string(),
248 expires_in: 3600,
249 refresh_token: Some(refresh_token),
250 scope: Some(code_info.scopes.join(" ")),
251 id_token: None,
254 }))
255}
256
257async fn handle_client_credentials_grant(
259 state: OAuth2ServerState,
260 request: TokenRequest,
261) -> Result<Json<TokenResponse>, StatusCode> {
262 let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
263 let _client_secret = request.client_secret.ok_or(StatusCode::BAD_REQUEST)?;
264
265 let oidc_state_guard = state.oidc_state.read().await;
269 let oidc_state = oidc_state_guard
270 .as_ref()
271 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
272
273 let mut additional_claims = HashMap::new();
274 additional_claims.insert("client_id".to_string(), serde_json::json!(client_id));
275 if let Some(scope) = request.scope {
276 additional_claims.insert("scope".to_string(), serde_json::json!(scope));
277 }
278
279 let access_token = generate_oidc_token(
280 oidc_state,
281 format!("client_{}", client_id),
282 Some(additional_claims),
283 Some(3600),
284 None,
285 )
286 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
287
288 Ok(Json(TokenResponse {
289 access_token,
290 token_type: "Bearer".to_string(),
291 expires_in: 3600,
292 refresh_token: None,
293 scope: request.scope,
294 id_token: None,
295 }))
296}
297
298async fn handle_refresh_token_grant(
300 state: OAuth2ServerState,
301 request: TokenRequest,
302) -> Result<Json<TokenResponse>, StatusCode> {
303 let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
311
312 let oidc_state_guard = state.oidc_state.read().await;
314 let oidc_state = oidc_state_guard
315 .as_ref()
316 .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
317
318 let mut additional_claims = HashMap::new();
319 additional_claims.insert("client_id".to_string(), json!(client_id));
320 if let Some(scope) = request.scope {
321 additional_claims.insert("scope".to_string(), json!(scope));
322 }
323
324 let access_token = generate_oidc_token(
325 oidc_state,
326 format!("client_{}", client_id),
327 Some(additional_claims),
328 Some(3600),
329 None,
330 )
331 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
332
333 let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
335
336 Ok(Json(TokenResponse {
337 access_token,
338 token_type: "Bearer".to_string(),
339 expires_in: 3600,
340 refresh_token: Some(refresh_token),
341 scope: request.scope,
342 id_token: None,
343 }))
344}
345
346pub fn oauth2_server_router(state: OAuth2ServerState) -> axum::Router {
348 use axum::routing::{get, post};
349
350 axum::Router::new()
351 .route("/oauth2/authorize", get(authorize))
352 .route("/oauth2/token", post(token))
353 .with_state(state)
354}
355