1use axum::{
7 extract::{Query, State},
8 http::StatusCode,
9 response::{Json, Redirect},
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16use crate::auth::oidc::{generate_oidc_token, OidcState, TenantContext};
17use crate::auth::token_lifecycle::{extract_token_id, TokenLifecycleManager};
18use chrono::Utc;
19use hex;
20use rand::Rng;
21use serde_json::json;
22use uuid;
23
24#[derive(Clone)]
26pub struct OAuth2ServerState {
27 pub oidc_state: Arc<RwLock<Option<OidcState>>>,
29 pub lifecycle_manager: Arc<TokenLifecycleManager>,
31 pub auth_codes: Arc<RwLock<HashMap<String, AuthorizationCodeInfo>>>,
33}
34
35#[derive(Debug, Clone)]
37pub struct AuthorizationCodeInfo {
38 pub client_id: String,
40 pub redirect_uri: String,
42 pub scopes: Vec<String>,
44 pub user_id: String,
46 pub state: Option<String>,
48 pub expires_at: i64,
50 pub tenant_context: Option<TenantContext>,
52}
53
54#[derive(Debug, Deserialize)]
56pub struct AuthorizationRequest {
57 pub client_id: String,
59 pub response_type: String,
61 pub redirect_uri: String,
63 pub scope: Option<String>,
65 pub state: Option<String>,
67 pub nonce: Option<String>,
69}
70
71#[derive(Debug, Deserialize)]
73pub struct TokenRequest {
74 pub grant_type: String,
76 pub code: Option<String>,
78 pub redirect_uri: Option<String>,
80 pub client_id: Option<String>,
82 pub client_secret: Option<String>,
84 pub scope: Option<String>,
86 pub nonce: Option<String>,
88}
89
90#[derive(Debug, Serialize)]
92pub struct TokenResponse {
93 pub access_token: String,
95 pub token_type: String,
97 pub expires_in: i64,
99 #[serde(skip_serializing_if = "Option::is_none")]
101 pub refresh_token: Option<String>,
102 #[serde(skip_serializing_if = "Option::is_none")]
104 pub scope: Option<String>,
105 #[serde(skip_serializing_if = "Option::is_none")]
107 pub id_token: Option<String>,
108}
109
110pub async fn authorize(
112 State(state): State<OAuth2ServerState>,
113 Query(params): Query<AuthorizationRequest>,
114) -> Result<Redirect, StatusCode> {
115 if params.response_type != "code" {
117 return Err(StatusCode::BAD_REQUEST);
118 }
119
120 let auth_code = {
125 let mut rng = rand::thread_rng();
126 let code_bytes: [u8; 32] = rng.gen();
127 hex::encode(code_bytes)
128 };
129
130 let scopes = params
132 .scope
133 .as_ref()
134 .map(|s| s.split(' ').map(|s| s.to_string()).collect())
135 .unwrap_or_else(Vec::new);
136
137 let code_info = AuthorizationCodeInfo {
139 client_id: params.client_id.clone(),
140 redirect_uri: params.redirect_uri.clone(),
141 scopes,
142 user_id: "user-default".to_string(),
145 state: params.state.clone(),
146 expires_at: Utc::now().timestamp() + 600, tenant_context: None,
149 };
150
151 {
152 let mut codes = state.auth_codes.write().await;
153 codes.insert(auth_code.clone(), code_info);
154 }
155
156 let mut redirect_url =
158 url::Url::parse(¶ms.redirect_uri).map_err(|_| StatusCode::BAD_REQUEST)?;
159 redirect_url.query_pairs_mut().append_pair("code", &auth_code);
160 if let Some(state) = params.state {
161 redirect_url.query_pairs_mut().append_pair("state", &state);
162 }
163
164 Ok(Redirect::to(redirect_url.as_str()))
165}
166
167pub async fn token(
169 State(state): State<OAuth2ServerState>,
170 axum::extract::Form(request): axum::extract::Form<TokenRequest>,
171) -> Result<Json<TokenResponse>, StatusCode> {
172 match request.grant_type.as_str() {
173 "authorization_code" => handle_authorization_code_grant(state, request).await,
174 "client_credentials" => handle_client_credentials_grant(state, request).await,
175 "refresh_token" => handle_refresh_token_grant(state, request).await,
176 _ => Err(StatusCode::BAD_REQUEST),
177 }
178}
179
180async fn handle_authorization_code_grant(
182 state: OAuth2ServerState,
183 request: TokenRequest,
184) -> Result<Json<TokenResponse>, StatusCode> {
185 let code = request.code.ok_or(StatusCode::BAD_REQUEST)?;
186 let redirect_uri = request.redirect_uri.ok_or(StatusCode::BAD_REQUEST)?;
187
188 let code_info = {
190 let mut codes = state.auth_codes.write().await;
191 codes.remove(&code).ok_or(StatusCode::BAD_REQUEST)?
192 };
193
194 if code_info.redirect_uri != redirect_uri {
196 return Err(StatusCode::BAD_REQUEST);
197 }
198
199 if code_info.expires_at < Utc::now().timestamp() {
201 return Err(StatusCode::BAD_REQUEST);
202 }
203
204 let oidc_state_guard = state.oidc_state.read().await;
206 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
207
208 let mut additional_claims = HashMap::new();
210 additional_claims.insert("scope".to_string(), json!(code_info.scopes.join(" ")));
211 if let Some(nonce) = request.nonce {
212 additional_claims.insert("nonce".to_string(), json!(nonce));
213 }
214
215 let access_token = generate_oidc_token(
216 oidc_state,
217 code_info.user_id.clone(),
218 Some(additional_claims),
219 Some(3600), code_info.tenant_context.clone(),
221 )
222 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
223
224 let token_id = extract_token_id(&access_token);
226 if state.lifecycle_manager.revocation.is_revoked(&token_id).await.is_some() {
227 return Err(StatusCode::INTERNAL_SERVER_ERROR);
228 }
229
230 let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
232
233 Ok(Json(TokenResponse {
234 access_token,
235 token_type: "Bearer".to_string(),
236 expires_in: 3600,
237 refresh_token: Some(refresh_token),
238 scope: Some(code_info.scopes.join(" ")),
239 id_token: None,
242 }))
243}
244
245async fn handle_client_credentials_grant(
247 state: OAuth2ServerState,
248 request: TokenRequest,
249) -> Result<Json<TokenResponse>, StatusCode> {
250 let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
251 let _client_secret = request.client_secret.ok_or(StatusCode::BAD_REQUEST)?;
252
253 let oidc_state_guard = state.oidc_state.read().await;
257 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
258
259 let mut additional_claims = HashMap::new();
260 additional_claims.insert("client_id".to_string(), serde_json::json!(client_id));
261 let scope_clone = request.scope.clone();
262 if let Some(ref scope) = request.scope {
263 additional_claims.insert("scope".to_string(), serde_json::json!(scope));
264 }
265
266 let access_token = generate_oidc_token(
267 oidc_state,
268 format!("client_{}", client_id),
269 Some(additional_claims),
270 Some(3600),
271 None,
272 )
273 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
274
275 Ok(Json(TokenResponse {
276 access_token,
277 token_type: "Bearer".to_string(),
278 expires_in: 3600,
279 refresh_token: None,
280 scope: scope_clone,
281 id_token: None,
282 }))
283}
284
285async fn handle_refresh_token_grant(
287 state: OAuth2ServerState,
288 request: TokenRequest,
289) -> Result<Json<TokenResponse>, StatusCode> {
290 let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
298
299 let oidc_state_guard = state.oidc_state.read().await;
301 let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
302
303 let mut additional_claims = HashMap::new();
304 additional_claims.insert("client_id".to_string(), json!(client_id));
305 let scope_clone = request.scope.clone();
306 if let Some(ref scope) = request.scope {
307 additional_claims.insert("scope".to_string(), json!(scope));
308 }
309
310 let access_token = generate_oidc_token(
311 oidc_state,
312 format!("client_{}", client_id),
313 Some(additional_claims),
314 Some(3600),
315 None,
316 )
317 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
318
319 let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
321
322 Ok(Json(TokenResponse {
323 access_token,
324 token_type: "Bearer".to_string(),
325 expires_in: 3600,
326 refresh_token: Some(refresh_token),
327 scope: scope_clone,
328 id_token: None,
329 }))
330}
331
332pub fn oauth2_server_router(state: OAuth2ServerState) -> axum::Router {
334 use axum::routing::{get, post};
335
336 axum::Router::new()
337 .route("/oauth2/authorize", get(authorize))
338 .route("/oauth2/token", post(token))
339 .with_state(state)
340}