1use crate::api::{ApiResponse, ApiState, extract_bearer_token};
13use axum::{Json, extract::Query, extract::State, http::HeaderMap};
14use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
15use serde::{Deserialize, Serialize};
17use sha2::{Digest, Sha256};
18#[derive(Debug, Deserialize)]
22pub struct AuthorizeRequest {
23 pub response_type: String,
24 pub client_id: String,
25 pub redirect_uri: String,
26 #[serde(default)]
27 pub scope: Option<String>,
28 #[serde(default)]
29 pub state: Option<String>,
30 #[serde(default)]
31 pub code_challenge: Option<String>, #[serde(default)]
33 pub code_challenge_method: Option<String>, }
35
36#[derive(Debug, Serialize)]
38pub struct AuthorizeResponse {
39 pub authorization_url: String,
40 pub state: Option<String>,
41}
42
43#[derive(Debug, Deserialize)]
45pub struct TokenRequest {
46 pub grant_type: String,
47 #[serde(default)]
48 pub code: Option<String>,
49 #[serde(default)]
50 pub redirect_uri: Option<String>,
51 #[serde(default)]
52 pub client_id: Option<String>,
53 #[serde(default)]
54 pub client_secret: Option<String>,
55 #[serde(default)]
56 pub code_verifier: Option<String>, #[serde(default)]
58 pub refresh_token: Option<String>,
59}
60
61#[derive(Debug, Serialize)]
63pub struct TokenResponse {
64 pub access_token: String,
65 pub token_type: String,
66 pub expires_in: u64,
67 #[serde(skip_serializing_if = "Option::is_none")]
68 pub refresh_token: Option<String>,
69 #[serde(skip_serializing_if = "Option::is_none")]
70 pub scope: Option<String>,
71}
72
73#[derive(Debug, Deserialize)]
75pub struct RevokeRequest {
76 pub token: String,
77 #[serde(default)]
78 pub token_type_hint: Option<String>, }
80
81#[derive(Debug, Serialize)]
83pub struct UserInfoResponse {
84 pub sub: String,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 pub name: Option<String>,
87 #[serde(skip_serializing_if = "Option::is_none")]
88 pub email: Option<String>,
89 #[serde(skip_serializing_if = "Option::is_none")]
90 pub picture: Option<String>,
91 #[serde(skip_serializing_if = "Option::is_none")]
92 pub updated_at: Option<i64>,
93}
94
95pub async fn authorize(
97 State(state): State<ApiState>,
98 Query(req): Query<AuthorizeRequest>,
99) -> ApiResponse<AuthorizeResponse> {
100 if req.response_type != "code" {
102 return ApiResponse::error_typed(
103 "unsupported_response_type",
104 "Only 'code' response type is supported",
105 );
106 }
107
108 if let Some(scope_str) = &req.scope {
110 let requested_scopes: Vec<&str> = scope_str.split_whitespace().collect();
111 let allowed_scopes = [
112 "openid",
113 "profile",
114 "email",
115 "address",
116 "phone",
117 "offline_access",
118 "read",
119 "write",
120 "admin",
121 ];
122
123 for scope in &requested_scopes {
124 if !allowed_scopes.contains(scope) {
125 return ApiResponse::error_typed(
126 "invalid_scope",
127 format!("Requested scope '{}' is not supported", scope),
128 );
129 }
130 }
131
132 for scope in &requested_scopes {
134 if scope.is_empty()
135 || !scope.chars().all(|c| {
136 c.is_alphanumeric() || c == ':' || c == '/' || c == '.' || c == '-' || c == '_'
137 })
138 {
139 return ApiResponse::error_typed(
140 "invalid_scope",
141 format!("Invalid scope format: '{}'", scope),
142 );
143 }
144 }
145 }
146
147 if req.client_id.is_empty() {
149 return ApiResponse::validation_error_typed("client_id is required");
150 }
151
152 if req.redirect_uri.is_empty() {
154 return ApiResponse::validation_error_typed("redirect_uri is required");
155 }
156
157 tracing::info!(
159 "OAuth2 authorization request from client: {}",
160 req.client_id
161 );
162
163 let auth_code = format!("ac_{}", uuid::Uuid::new_v4().to_string().replace("-", ""));
165
166 let code_data = serde_json::json!({
168 "client_id": req.client_id,
169 "redirect_uri": req.redirect_uri,
170 "scope": req.scope.clone().unwrap_or_else(|| "openid profile email".to_string()),
171 "state": req.state.clone(),
172 "code_challenge": req.code_challenge,
173 "code_challenge_method": req.code_challenge_method,
174 "created_at": chrono::Utc::now().to_rfc3339(),
175 "expires_at": (chrono::Utc::now() + chrono::Duration::minutes(10)).to_rfc3339(),
176 "used": false,
177 });
178
179 let storage_key = format!("oauth2_code:{}", auth_code);
180 let code_data_str = serde_json::to_string(&code_data).unwrap();
181
182 match state
184 .auth_framework
185 .storage()
186 .store_kv(
187 &storage_key,
188 code_data_str.as_bytes(),
189 Some(std::time::Duration::from_secs(600)),
190 )
191 .await
192 {
193 Ok(_) => {
194 let mut auth_url = format!("{}?code={}", req.redirect_uri, auth_code);
196 if let Some(state_param) = &req.state {
197 auth_url = format!("{}&state={}", auth_url, state_param);
198 }
199
200 let response = AuthorizeResponse {
201 authorization_url: auth_url,
202 state: req.state,
203 };
204
205 tracing::info!("Authorization code generated for client: {}", req.client_id);
206 ApiResponse::success(response)
207 }
208 Err(e) => {
209 tracing::error!("Failed to store authorization code: {:?}", e);
210 ApiResponse::error_typed(
211 "AUTHORIZATION_FAILED",
212 "Failed to generate authorization code",
213 )
214 }
215 }
216}
217
218pub async fn token(
220 State(state): State<ApiState>,
221 Json(req): Json<TokenRequest>,
222) -> ApiResponse<TokenResponse> {
223 match req.grant_type.as_str() {
224 "authorization_code" => handle_authorization_code_grant(state, req).await,
225 "refresh_token" => handle_refresh_token_grant(state, req).await,
226 _ => ApiResponse::error_typed(
227 "unsupported_grant_type",
228 "Supported grant types: authorization_code, refresh_token",
229 ),
230 }
231}
232
233async fn handle_authorization_code_grant(
234 state: ApiState,
235 req: TokenRequest,
236) -> ApiResponse<TokenResponse> {
237 let code = match req.code {
238 Some(c) => c,
239 None => {
240 return ApiResponse::validation_error_typed(
241 "code is required for authorization_code grant",
242 );
243 }
244 };
245
246 let client_id = match req.client_id {
247 Some(c) => c,
248 None => return ApiResponse::validation_error_typed("client_id is required"),
249 };
250
251 let storage_key = format!("oauth2_code:{}", code);
253 let code_data = match state.auth_framework.storage().get_kv(&storage_key).await {
254 Ok(Some(data)) => match serde_json::from_slice::<serde_json::Value>(&data) {
255 Ok(json) => json,
256 Err(e) => {
257 tracing::error!("Failed to parse stored authorization code data: {:?}", e);
258 return ApiResponse::error_typed("invalid_grant", "Invalid authorization code");
259 }
260 },
261 Ok(None) => {
262 return ApiResponse::error_typed(
263 "invalid_grant",
264 "Authorization code not found or expired",
265 );
266 }
267 Err(e) => {
268 tracing::error!("Failed to retrieve authorization code: {:?}", e);
269 return ApiResponse::error_typed(
270 "server_error",
271 "Failed to validate authorization code",
272 );
273 }
274 };
275
276 if code_data["used"].as_bool().unwrap_or(false) {
278 return ApiResponse::error_typed(
279 "invalid_grant",
280 "Authorization code has already been used",
281 );
282 }
283
284 if code_data["client_id"].as_str() != Some(&client_id) {
286 return ApiResponse::error_typed("invalid_grant", "client_id mismatch");
287 }
288
289 if let Some(redirect_uri) = &req.redirect_uri
291 && code_data["redirect_uri"].as_str() != Some(redirect_uri)
292 {
293 return ApiResponse::error_typed("invalid_grant", "redirect_uri mismatch");
294 }
295
296 let stored_challenge = code_data["code_challenge"].as_str();
298 let challenge_method = code_data["code_challenge_method"]
299 .as_str()
300 .unwrap_or("plain");
301
302 if let Some(stored) = stored_challenge {
303 let code_verifier = match &req.code_verifier {
305 Some(verifier) => verifier,
306 None => {
307 return ApiResponse::error_typed(
308 "invalid_request",
309 "code_verifier is required when PKCE challenge was provided",
310 );
311 }
312 };
313
314 let computed_challenge = match challenge_method {
315 "S256" => {
316 let mut hasher = Sha256::new();
317 hasher.update(code_verifier.as_bytes());
318 URL_SAFE_NO_PAD.encode(hasher.finalize())
319 }
320 "plain" => code_verifier.clone(),
321 _ => {
322 return ApiResponse::error_typed(
323 "invalid_request",
324 "Unsupported code_challenge_method",
325 );
326 }
327 };
328
329 if computed_challenge != stored {
330 return ApiResponse::error_typed("invalid_grant", "PKCE verification failed");
331 }
332 } else if req.code_verifier.is_some() {
333 return ApiResponse::error_typed(
335 "invalid_request",
336 "code_verifier provided but no PKCE challenge was used in authorization",
337 );
338 }
339
340 let mut updated_code_data = code_data.clone();
342 updated_code_data["used"] = serde_json::Value::Bool(true);
343 let updated_data_str = serde_json::to_string(&updated_code_data).unwrap();
344
345 if let Err(e) = state
346 .auth_framework
347 .storage()
348 .store_kv(
349 &storage_key,
350 updated_data_str.as_bytes(),
351 Some(std::time::Duration::from_secs(600)),
352 )
353 .await
354 {
355 tracing::error!("Failed to mark authorization code as used: {:?}", e);
356 }
357
358 let scope = code_data["scope"]
360 .as_str()
361 .unwrap_or("openid profile email");
362 let scopes: Vec<String> = scope.split_whitespace().map(|s| s.to_string()).collect();
363
364 let user_id = format!("oauth2_user_{}", client_id);
366
367 let token = match state.auth_framework.token_manager().create_auth_token(
368 &user_id,
369 scopes.clone(),
370 "oauth2",
371 None,
372 ) {
373 Ok(token) => token,
374 Err(e) => {
375 tracing::error!("Failed to create access token: {:?}", e);
376 return ApiResponse::error_typed("server_error", "Failed to create access token");
377 }
378 };
379
380 let response = TokenResponse {
381 access_token: token.access_token,
382 token_type: "Bearer".to_string(),
383 expires_in: 3600,
384 refresh_token: token.refresh_token,
385 scope: Some(scope.to_string()),
386 };
387
388 tracing::info!("OAuth2 tokens issued for client: {}", client_id);
389 ApiResponse::success(response)
390}
391
392async fn handle_refresh_token_grant(
393 state: ApiState,
394 req: TokenRequest,
395) -> ApiResponse<TokenResponse> {
396 let _refresh_token = match req.refresh_token {
397 Some(token) => token,
398 None => return ApiResponse::validation_error_typed("refresh_token is required"),
399 };
400
401 let client_id = req
408 .client_id
409 .unwrap_or_else(|| "unknown_client".to_string());
410 let user_id = format!("oauth2_user_{}", client_id);
411
412 let token = match state.auth_framework.token_manager().create_auth_token(
413 &user_id,
414 vec!["openid".to_string(), "profile".to_string()],
415 "oauth2",
416 None,
417 ) {
418 Ok(token) => token,
419 Err(e) => {
420 tracing::error!("Failed to refresh token: {:?}", e);
421 return ApiResponse::error_typed("invalid_grant", "Failed to refresh token");
422 }
423 };
424
425 let response = TokenResponse {
426 access_token: token.access_token,
427 token_type: "Bearer".to_string(),
428 expires_in: 3600,
429 refresh_token: token.refresh_token,
430 scope: Some("openid profile email".to_string()),
431 };
432
433 tracing::info!("OAuth2 token refreshed for client: {}", client_id);
434 ApiResponse::success(response)
435}
436
437pub async fn revoke(
439 State(state): State<ApiState>,
440 Json(req): Json<RevokeRequest>,
441) -> ApiResponse<serde_json::Value> {
442 let revoked_token_key = format!("oauth2_revoked_token:{}", req.token);
444 let revoked_data = serde_json::json!({
445 "token": req.token,
446 "revoked_at": chrono::Utc::now().to_rfc3339(),
447 "token_type_hint": req.token_type_hint
448 });
449
450 if let Err(e) = state
451 .auth_framework
452 .storage()
453 .store_kv(
454 &revoked_token_key,
455 serde_json::to_string(&revoked_data).unwrap().as_bytes(),
456 Some(std::time::Duration::from_secs(86400 * 7)), )
458 .await
459 {
460 tracing::error!("Failed to store revoked token: {:?}", e);
461 return ApiResponse::error_typed("server_error", "Failed to revoke token");
462 }
463
464 tracing::info!(
465 "OAuth2 token revoked: {}",
466 &req.token[..10.min(req.token.len())]
467 );
468
469 ApiResponse::success(serde_json::json!({
470 "message": "Token revoked successfully"
471 }))
472}
473
474pub async fn userinfo(
476 State(state): State<ApiState>,
477 headers: HeaderMap,
478) -> ApiResponse<UserInfoResponse> {
479 let token = match extract_bearer_token(&headers) {
481 Some(t) => t,
482 None => {
483 return ApiResponse::error_typed("invalid_token", "Authorization header required");
484 }
485 };
486
487 let revoked_token_key = format!("oauth2_revoked_token:{}", token);
489 if let Ok(Some(_)) = state
490 .auth_framework
491 .storage()
492 .get_kv(&revoked_token_key)
493 .await
494 {
495 return ApiResponse::error_typed("invalid_token", "Token has been revoked");
496 }
497
498 let claims = match state
500 .auth_framework
501 .token_manager()
502 .validate_jwt_token(&token)
503 {
504 Ok(c) => c,
505 Err(_) => {
506 return ApiResponse::error_typed("invalid_token", "Access token is invalid");
507 }
508 };
509
510 let user_profile = match state.auth_framework.get_user_profile(&claims.sub).await {
512 Ok(profile) => profile,
513 Err(e) => {
514 tracing::error!("Failed to get user profile: {:?}", e);
515 return ApiResponse::error_typed("server_error", "Failed to retrieve user information");
516 }
517 };
518
519 let userinfo = UserInfoResponse {
520 sub: claims.sub.clone(),
521 name: user_profile.username.clone(),
522 email: user_profile.email.clone(),
523 picture: user_profile.picture.clone(),
524 updated_at: Some(chrono::Utc::now().timestamp()),
525 };
526
527 tracing::info!("OAuth2 UserInfo requested for user: {}", claims.sub);
528 ApiResponse::success(userinfo)
529}