mockforge_http/handlers/oauth2_server.rs

//! OAuth2 server endpoints
//!
//! This module provides OAuth2 authorization server endpoints that integrate
//! with OIDC, token lifecycle, consent, and risk simulation.
5
6use axum::{
7    extract::{Query, State},
8    http::StatusCode,
9    response::{Json, Redirect},
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16use crate::auth::oidc::{generate_oidc_token, OidcState, TenantContext};
17use crate::auth::token_lifecycle::{extract_token_id, TokenLifecycleManager};
18use chrono::Utc;
19use hex;
20use rand::Rng;
21use serde_json::json;
22use uuid;
23
24/// OAuth2 server state
25#[derive(Clone)]
26pub struct OAuth2ServerState {
27    /// OIDC state for token generation
28    pub oidc_state: Arc<RwLock<Option<OidcState>>>,
29    /// Token lifecycle manager
30    pub lifecycle_manager: Arc<TokenLifecycleManager>,
31    /// Authorization codes (code -> authorization info)
32    pub auth_codes: Arc<RwLock<HashMap<String, AuthorizationCodeInfo>>>,
33}
34
35/// Authorization code information
36#[derive(Debug, Clone)]
37pub struct AuthorizationCodeInfo {
38    /// Client ID
39    pub client_id: String,
40    /// Redirect URI
41    pub redirect_uri: String,
42    /// Scopes requested
43    pub scopes: Vec<String>,
44    /// User ID (subject)
45    pub user_id: String,
46    /// State parameter (CSRF protection)
47    pub state: Option<String>,
48    /// Expiration time
49    pub expires_at: i64,
50    /// Tenant context
51    pub tenant_context: Option<TenantContext>,
52}
53
54/// OAuth2 authorization request parameters
55#[derive(Debug, Deserialize)]
56pub struct AuthorizationRequest {
57    /// Client ID
58    pub client_id: String,
59    /// Response type (code, token, id_token)
60    pub response_type: String,
61    /// Redirect URI
62    pub redirect_uri: String,
63    /// Scopes (space-separated)
64    pub scope: Option<String>,
65    /// State parameter (CSRF protection)
66    pub state: Option<String>,
67    /// Nonce (for OpenID Connect)
68    pub nonce: Option<String>,
69}
70
71/// OAuth2 token request
72#[derive(Debug, Deserialize)]
73pub struct TokenRequest {
74    /// Grant type
75    pub grant_type: String,
76    /// Authorization code (for authorization_code grant)
77    pub code: Option<String>,
78    /// Redirect URI (must match authorization request)
79    pub redirect_uri: Option<String>,
80    /// Client ID
81    pub client_id: Option<String>,
82    /// Client secret
83    pub client_secret: Option<String>,
84    /// Scope (for client_credentials grant)
85    pub scope: Option<String>,
86    /// Nonce (for OpenID Connect)
87    pub nonce: Option<String>,
88}
89
90/// OAuth2 token response
91#[derive(Debug, Serialize)]
92pub struct TokenResponse {
93    /// Access token
94    pub access_token: String,
95    /// Token type (usually "Bearer")
96    pub token_type: String,
97    /// Expires in (seconds)
98    pub expires_in: i64,
99    /// Refresh token (optional)
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub refresh_token: Option<String>,
102    /// Scope (optional)
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub scope: Option<String>,
105    /// ID token (for OpenID Connect)
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub id_token: Option<String>,
108}
109
110/// OAuth2 authorization endpoint
111pub async fn authorize(
112    State(state): State<OAuth2ServerState>,
113    Query(params): Query<AuthorizationRequest>,
114) -> Result<Redirect, StatusCode> {
115    
116    // Validate response_type
117    if params.response_type != "code" {
118        return Err(StatusCode::BAD_REQUEST);
119    }
120    
121    // Check if consent is required (simplified - in production, check user consent)
122    // For now, auto-approve and generate authorization code
123    
124    // Generate authorization code
125    let mut rng = rand::thread_rng();
126    let code_bytes: [u8; 32] = rng.gen();
127    let auth_code = hex::encode(code_bytes);
128    
129    // Parse scopes
130    let scopes = params
131        .scope
132        .as_ref()
133        .map(|s| s.split(' ').map(|s| s.to_string()).collect())
134        .unwrap_or_else(Vec::new);
135    
136    // Store authorization code (expires in 10 minutes)
137    let code_info = AuthorizationCodeInfo {
138        client_id: params.client_id.clone(),
139        redirect_uri: params.redirect_uri.clone(),
140        scopes,
141        // For mock server, use default user ID
142        // In production, extract from authenticated session
143        user_id: "user-default".to_string(),
144        state: params.state.clone(),
145        expires_at: Utc::now().timestamp() + 600, // 10 minutes
146        // Tenant context can be extracted from request headers or session
147        tenant_context: None,
148    };
149    
150    {
151        let mut codes = state.auth_codes.write().await;
152        codes.insert(auth_code.clone(), code_info);
153    }
154    
155    // Build redirect URL with authorization code
156    let mut redirect_url = url::Url::parse(&params.redirect_uri)
157        .map_err(|_| StatusCode::BAD_REQUEST)?;
158    redirect_url
159        .query_pairs_mut()
160        .append_pair("code", &auth_code);
161    if let Some(state) = params.state {
162        redirect_url.query_pairs_mut().append_pair("state", &state);
163    }
164    
165    Ok(Redirect::to(redirect_url.as_str()))
166}
167
168/// OAuth2 token endpoint
169pub async fn token(
170    State(state): State<OAuth2ServerState>,
171    axum::extract::Form(request): axum::extract::Form<TokenRequest>,
172) -> Result<Json<TokenResponse>, StatusCode> {
173    use chrono::Utc;
174    
175    match request.grant_type.as_str() {
176        "authorization_code" => {
177            handle_authorization_code_grant(state, request).await
178        }
179        "client_credentials" => {
180            handle_client_credentials_grant(state, request).await
181        }
182        "refresh_token" => {
183            handle_refresh_token_grant(state, request).await
184        }
185        _ => Err(StatusCode::BAD_REQUEST),
186    }
187}
188
189/// Handle authorization_code grant type
190async fn handle_authorization_code_grant(
191    state: OAuth2ServerState,
192    request: TokenRequest,
193) -> Result<Json<TokenResponse>, StatusCode> {
194    
195    let code = request.code.ok_or(StatusCode::BAD_REQUEST)?;
196    let redirect_uri = request.redirect_uri.ok_or(StatusCode::BAD_REQUEST)?;
197    
198    // Look up authorization code
199    let code_info = {
200        let mut codes = state.auth_codes.write().await;
201        codes.remove(&code).ok_or(StatusCode::BAD_REQUEST)?
202    };
203    
204    // Validate redirect URI
205    if code_info.redirect_uri != redirect_uri {
206        return Err(StatusCode::BAD_REQUEST);
207    }
208    
209    // Check expiration
210    if code_info.expires_at < Utc::now().timestamp() {
211        return Err(StatusCode::BAD_REQUEST);
212    }
213    
214    // Generate access token using OIDC
215    let oidc_state_guard = state.oidc_state.read().await;
216    let oidc_state = oidc_state_guard
217        .as_ref()
218        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
219    
220    // Build claims
221    let mut additional_claims = HashMap::new();
222    additional_claims.insert("scope".to_string(), json!(code_info.scopes.join(" ")));
223    if let Some(nonce) = request.nonce {
224        additional_claims.insert("nonce".to_string(), json!(nonce));
225    }
226    
227    let access_token = generate_oidc_token(
228        oidc_state,
229        code_info.user_id.clone(),
230        Some(additional_claims),
231        Some(3600), // 1 hour expiration
232        code_info.tenant_context.clone(),
233    )
234    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
235    
236    // Check if token is revoked (shouldn't be, but check anyway)
237    let token_id = extract_token_id(&access_token);
238    if state.lifecycle_manager.revocation.is_revoked(&token_id).await.is_some() {
239        return Err(StatusCode::INTERNAL_SERVER_ERROR);
240    }
241    
242    // Generate refresh token (simplified)
243    let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
244    
245    Ok(Json(TokenResponse {
246        access_token,
247        token_type: "Bearer".to_string(),
248        expires_in: 3600,
249        refresh_token: Some(refresh_token),
250        scope: Some(code_info.scopes.join(" ")),
251        // ID token generation for OpenID Connect can be added by calling generate_oidc_token
252        // with appropriate OpenID Connect claims (sub, iss, aud, exp, iat, nonce, etc.)
253        id_token: None,
254    }))
255}
256
257/// Handle client_credentials grant type
258async fn handle_client_credentials_grant(
259    state: OAuth2ServerState,
260    request: TokenRequest,
261) -> Result<Json<TokenResponse>, StatusCode> {
262    let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
263    let _client_secret = request.client_secret.ok_or(StatusCode::BAD_REQUEST)?;
264    
265    // Validate client credentials (simplified - in production, check against database)
266    
267    // Generate access token
268    let oidc_state_guard = state.oidc_state.read().await;
269    let oidc_state = oidc_state_guard
270        .as_ref()
271        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
272    
273    let mut additional_claims = HashMap::new();
274    additional_claims.insert("client_id".to_string(), serde_json::json!(client_id));
275    if let Some(scope) = request.scope {
276        additional_claims.insert("scope".to_string(), serde_json::json!(scope));
277    }
278    
279    let access_token = generate_oidc_token(
280        oidc_state,
281        format!("client_{}", client_id),
282        Some(additional_claims),
283        Some(3600),
284        None,
285    )
286    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
287    
288    Ok(Json(TokenResponse {
289        access_token,
290        token_type: "Bearer".to_string(),
291        expires_in: 3600,
292        refresh_token: None,
293        scope: request.scope,
294        id_token: None,
295    }))
296}
297
298/// Handle refresh_token grant type
299async fn handle_refresh_token_grant(
300    state: OAuth2ServerState,
301    request: TokenRequest,
302) -> Result<Json<TokenResponse>, StatusCode> {
303    // For refresh token grant, we would:
304    // 1. Validate the refresh token
305    // 2. Check if it's revoked
306    // 3. Generate a new access token
307    // 4. Optionally generate a new refresh token
308    
309    // Simplified implementation - in production, validate refresh token from storage
310    let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
311    
312    // Generate new access token
313    let oidc_state_guard = state.oidc_state.read().await;
314    let oidc_state = oidc_state_guard
315        .as_ref()
316        .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
317    
318    let mut additional_claims = HashMap::new();
319    additional_claims.insert("client_id".to_string(), json!(client_id));
320    if let Some(scope) = request.scope {
321        additional_claims.insert("scope".to_string(), json!(scope));
322    }
323    
324    let access_token = generate_oidc_token(
325        oidc_state,
326        format!("client_{}", client_id),
327        Some(additional_claims),
328        Some(3600),
329        None,
330    )
331    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
332    
333    // Generate new refresh token
334    let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
335    
336    Ok(Json(TokenResponse {
337        access_token,
338        token_type: "Bearer".to_string(),
339        expires_in: 3600,
340        refresh_token: Some(refresh_token),
341        scope: request.scope,
342        id_token: None,
343    }))
344}
345
346/// Create OAuth2 server router
347pub fn oauth2_server_router(state: OAuth2ServerState) -> axum::Router {
348    use axum::routing::{get, post};
349    
350    axum::Router::new()
351        .route("/oauth2/authorize", get(authorize))
352        .route("/oauth2/token", post(token))
353        .with_state(state)
354}
355