1use axum::{
7 extract::{Query, State},
8 http::StatusCode,
9 response::{Json, Redirect},
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16use crate::auth::oidc::{generate_oidc_token, OidcState, TenantContext};
17use crate::auth::token_lifecycle::{extract_token_id, TokenLifecycleManager};
18use chrono::Utc;
19use hex;
20use rand::Rng;
21use serde_json::json;
22use uuid;
23
24#[derive(Clone)]
26pub struct OAuth2ServerState {
27 pub oidc_state: Arc<RwLock<Option<OidcState>>>,
29 pub lifecycle_manager: Arc<TokenLifecycleManager>,
31 pub auth_codes: Arc<RwLock<HashMap<String, AuthorizationCodeInfo>>>,
33}
34
35#[derive(Debug, Clone)]
37pub struct AuthorizationCodeInfo {
38 pub client_id: String,
40 pub redirect_uri: String,
42 pub scopes: Vec<String>,
44 pub user_id: String,
46 pub state: Option<String>,
48 pub expires_at: i64,
50 pub tenant_context: Option<TenantContext>,
52}
53
54#[derive(Debug, Deserialize)]
56pub struct AuthorizationRequest {
57 pub client_id: String,
59 pub response_type: String,
61 pub redirect_uri: String,
63 pub scope: Option<String>,
65 pub state: Option<String>,
67 pub nonce: Option<String>,
69}
70
71#[derive(Debug, Deserialize)]
73pub struct TokenRequest {
74 pub grant_type: String,
76 pub code: Option<String>,
78 pub redirect_uri: Option<String>,
80 pub client_id: Option<String>,
82 pub client_secret: Option<String>,
84 pub scope: Option<String>,
86 pub nonce: Option<String>,
88}
89
90#[derive(Debug, Serialize)]
92pub struct TokenResponse {
93 pub access_token: String,
95 pub token_type: String,
97 pub expires_in: i64,
99 #[serde(skip_serializing_if = "Option::is_none")]
101 pub refresh_token: Option<String>,
102 #[serde(skip_serializing_if = "Option::is_none")]
104 pub scope: Option<String>,
105 #[serde(skip_serializing_if = "Option::is_none")]
107 pub id_token: Option<String>,
108}
109
110pub async fn authorize(
112 State(state): State<OAuth2ServerState>,
113 Query(params): Query<AuthorizationRequest>,
114) -> Result<Redirect, StatusCode> {
115 if params.response_type != "code" {
117 return Err(StatusCode::BAD_REQUEST);
118 }
119
120 let auth_code = {
125 let mut rng = rand::thread_rng();
126 let code_bytes: [u8; 32] = rng.gen();
127 hex::encode(code_bytes)
128 };
129
130 let scopes = params
132 .scope
133 .as_ref()
134 .map(|s| s.split(' ').map(|s| s.to_string()).collect())
135 .unwrap_or_else(Vec::new);
136
137 let code_info = AuthorizationCodeInfo {
139 client_id: params.client_id.clone(),
140 redirect_uri: params.redirect_uri.clone(),
141 scopes,
142 user_id: "user-default".to_string(),
145 state: params.state.clone(),
146 expires_at: Utc::now().timestamp() + 600, tenant_context: None,
149 };
150
151 {
152 let mut codes = state.auth_codes.write().await;
153 codes.insert(auth_code.clone(), code_info);
154 }
155
156 let mut redirect_url =
158 url::Url::parse(¶ms.redirect_uri).map_err(|_| StatusCode::BAD_REQUEST)?;
159 redirect_url.query_pairs_mut().append_pair("code", &auth_code);
160 if let Some(state) = params.state {
161 redirect_url.query_pairs_mut().append_pair("state", &state);
162 }
163
164 Ok(Redirect::to(redirect_url.as_str()))
165}
166
167pub async fn token(
169 State(state): State<OAuth2ServerState>,
170 axum::extract::Form(request): axum::extract::Form<TokenRequest>,
171) -> Result<Json<TokenResponse>, StatusCode> {
172 use chrono::Utc;
173
174 match request.grant_type.as_str() {
175 "authorization_code" => handle_authorization_code_grant(state, request).await,
176 "client_credentials" => handle_client_credentials_grant(state, request).await,
177 "refresh_token" => handle_refresh_token_grant(state, request).await,
178 _ => Err(StatusCode::BAD_REQUEST),
179 }
180}
181
182async fn handle_authorization_code_grant(
184 state: OAuth2ServerState,
185 request: TokenRequest,
186) -> Result<Json<TokenResponse>, StatusCode> {
187 let code = request.code.ok_or(StatusCode::BAD_REQUEST)?;
188 let redirect_uri = request.redirect_uri.ok_or(StatusCode::BAD_REQUEST)?;
189
190 let code_info = {
192 let mut codes = state.auth_codes.write().await;
193 codes.remove(&code).ok_or(StatusCode::BAD_REQUEST)?
194 };
195
196 if code_info.redirect_uri != redirect_uri {
198 return Err(StatusCode::BAD_REQUEST);
199 }
200
201 if code_info.expires_at < Utc::now().timestamp() {
203 return Err(StatusCode::BAD_REQUEST);
204 }
205
206 let oidc_state_guard = state.oidc_state.read().await;
208 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
209
210 let mut additional_claims = HashMap::new();
212 additional_claims.insert("scope".to_string(), json!(code_info.scopes.join(" ")));
213 if let Some(nonce) = request.nonce {
214 additional_claims.insert("nonce".to_string(), json!(nonce));
215 }
216
217 let access_token = generate_oidc_token(
218 oidc_state,
219 code_info.user_id.clone(),
220 Some(additional_claims),
221 Some(3600), code_info.tenant_context.clone(),
223 )
224 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
225
226 let token_id = extract_token_id(&access_token);
228 if state.lifecycle_manager.revocation.is_revoked(&token_id).await.is_some() {
229 return Err(StatusCode::INTERNAL_SERVER_ERROR);
230 }
231
232 let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
234
235 Ok(Json(TokenResponse {
236 access_token,
237 token_type: "Bearer".to_string(),
238 expires_in: 3600,
239 refresh_token: Some(refresh_token),
240 scope: Some(code_info.scopes.join(" ")),
241 id_token: None,
244 }))
245}
246
247async fn handle_client_credentials_grant(
249 state: OAuth2ServerState,
250 request: TokenRequest,
251) -> Result<Json<TokenResponse>, StatusCode> {
252 let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
253 let _client_secret = request.client_secret.ok_or(StatusCode::BAD_REQUEST)?;
254
255 let oidc_state_guard = state.oidc_state.read().await;
259 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
260
261 let mut additional_claims = HashMap::new();
262 additional_claims.insert("client_id".to_string(), serde_json::json!(client_id));
263 let scope_clone = request.scope.clone();
264 if let Some(ref scope) = request.scope {
265 additional_claims.insert("scope".to_string(), serde_json::json!(scope));
266 }
267
268 let access_token = generate_oidc_token(
269 oidc_state,
270 format!("client_{}", client_id),
271 Some(additional_claims),
272 Some(3600),
273 None,
274 )
275 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
276
277 Ok(Json(TokenResponse {
278 access_token,
279 token_type: "Bearer".to_string(),
280 expires_in: 3600,
281 refresh_token: None,
282 scope: scope_clone,
283 id_token: None,
284 }))
285}
286
287async fn handle_refresh_token_grant(
289 state: OAuth2ServerState,
290 request: TokenRequest,
291) -> Result<Json<TokenResponse>, StatusCode> {
292 let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
300
301 let oidc_state_guard = state.oidc_state.read().await;
303 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
304
305 let mut additional_claims = HashMap::new();
306 additional_claims.insert("client_id".to_string(), json!(client_id));
307 let scope_clone = request.scope.clone();
308 if let Some(ref scope) = request.scope {
309 additional_claims.insert("scope".to_string(), json!(scope));
310 }
311
312 let access_token = generate_oidc_token(
313 oidc_state,
314 format!("client_{}", client_id),
315 Some(additional_claims),
316 Some(3600),
317 None,
318 )
319 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
320
321 let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
323
324 Ok(Json(TokenResponse {
325 access_token,
326 token_type: "Bearer".to_string(),
327 expires_in: 3600,
328 refresh_token: Some(refresh_token),
329 scope: scope_clone,
330 id_token: None,
331 }))
332}
333
334pub fn oauth2_server_router(state: OAuth2ServerState) -> axum::Router {
336 use axum::routing::{get, post};
337
338 axum::Router::new()
339 .route("/oauth2/authorize", get(authorize))
340 .route("/oauth2/token", post(token))
341 .with_state(state)
342}