mockforge_http/handlers/oauth2_server.rs

//! OAuth2 server endpoints
//!
//! This module provides OAuth2 authorization server endpoints that integrate
//! with OIDC, token lifecycle, consent, and risk simulation.
5
6use axum::{
7    extract::{Query, State},
8    http::StatusCode,
9    response::{Json, Redirect},
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16use crate::auth::oidc::{generate_oidc_token, OidcState, TenantContext};
17use crate::auth::token_lifecycle::{extract_token_id, TokenLifecycleManager};
18use chrono::Utc;
19use hex;
20use rand::Rng;
21use serde_json::json;
22use uuid;
23
24/// OAuth2 server state
25#[derive(Clone)]
26pub struct OAuth2ServerState {
27    /// OIDC state for token generation
28    pub oidc_state: Arc<RwLock<Option<OidcState>>>,
29    /// Token lifecycle manager
30    pub lifecycle_manager: Arc<TokenLifecycleManager>,
31    /// Authorization codes (code -> authorization info)
32    pub auth_codes: Arc<RwLock<HashMap<String, AuthorizationCodeInfo>>>,
33}
34
35/// Authorization code information
36#[derive(Debug, Clone)]
37pub struct AuthorizationCodeInfo {
38    /// Client ID
39    pub client_id: String,
40    /// Redirect URI
41    pub redirect_uri: String,
42    /// Scopes requested
43    pub scopes: Vec<String>,
44    /// User ID (subject)
45    pub user_id: String,
46    /// State parameter (CSRF protection)
47    pub state: Option<String>,
48    /// Expiration time
49    pub expires_at: i64,
50    /// Tenant context
51    pub tenant_context: Option<TenantContext>,
52}
53
54/// OAuth2 authorization request parameters
55#[derive(Debug, Deserialize)]
56pub struct AuthorizationRequest {
57    /// Client ID
58    pub client_id: String,
59    /// Response type (code, token, id_token)
60    pub response_type: String,
61    /// Redirect URI
62    pub redirect_uri: String,
63    /// Scopes (space-separated)
64    pub scope: Option<String>,
65    /// State parameter (CSRF protection)
66    pub state: Option<String>,
67    /// Nonce (for OpenID Connect)
68    pub nonce: Option<String>,
69}
70
71/// OAuth2 token request
72#[derive(Debug, Deserialize)]
73pub struct TokenRequest {
74    /// Grant type
75    pub grant_type: String,
76    /// Authorization code (for authorization_code grant)
77    pub code: Option<String>,
78    /// Redirect URI (must match authorization request)
79    pub redirect_uri: Option<String>,
80    /// Client ID
81    pub client_id: Option<String>,
82    /// Client secret
83    pub client_secret: Option<String>,
84    /// Scope (for client_credentials grant)
85    pub scope: Option<String>,
86    /// Nonce (for OpenID Connect)
87    pub nonce: Option<String>,
88}
89
90/// OAuth2 token response
91#[derive(Debug, Serialize)]
92pub struct TokenResponse {
93    /// Access token
94    pub access_token: String,
95    /// Token type (usually "Bearer")
96    pub token_type: String,
97    /// Expires in (seconds)
98    pub expires_in: i64,
99    /// Refresh token (optional)
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub refresh_token: Option<String>,
102    /// Scope (optional)
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub scope: Option<String>,
105    /// ID token (for OpenID Connect)
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub id_token: Option<String>,
108}
109
110/// OAuth2 authorization endpoint
111pub async fn authorize(
112    State(state): State<OAuth2ServerState>,
113    Query(params): Query<AuthorizationRequest>,
114) -> Result<Redirect, StatusCode> {
115    // Validate response_type
116    if params.response_type != "code" {
117        return Err(StatusCode::BAD_REQUEST);
118    }
119
120    // Check if consent is required (simplified - in production, check user consent)
121    // For now, auto-approve and generate authorization code
122
123    // Generate authorization code before any await points (ThreadRng is not Send)
124    let auth_code = {
125        let mut rng = rand::thread_rng();
126        let code_bytes: [u8; 32] = rng.gen();
127        hex::encode(code_bytes)
128    };
129
130    // Parse scopes
131    let scopes = params
132        .scope
133        .as_ref()
134        .map(|s| s.split(' ').map(|s| s.to_string()).collect())
135        .unwrap_or_else(Vec::new);
136
137    // Store authorization code (expires in 10 minutes)
138    let code_info = AuthorizationCodeInfo {
139        client_id: params.client_id.clone(),
140        redirect_uri: params.redirect_uri.clone(),
141        scopes,
142        // For mock server, use default user ID
143        // In production, extract from authenticated session
144        user_id: "user-default".to_string(),
145        state: params.state.clone(),
146        expires_at: Utc::now().timestamp() + 600, // 10 minutes
147        // Tenant context can be extracted from request headers or session
148        tenant_context: None,
149    };
150
151    {
152        let mut codes = state.auth_codes.write().await;
153        codes.insert(auth_code.clone(), code_info);
154    }
155
156    // Build redirect URL with authorization code
157    let mut redirect_url =
158        url::Url::parse(&params.redirect_uri).map_err(|_| StatusCode::BAD_REQUEST)?;
159    redirect_url.query_pairs_mut().append_pair("code", &auth_code);
160    if let Some(state) = params.state {
161        redirect_url.query_pairs_mut().append_pair("state", &state);
162    }
163
164    Ok(Redirect::to(redirect_url.as_str()))
165}
166
167/// OAuth2 token endpoint
168pub async fn token(
169    State(state): State<OAuth2ServerState>,
170    axum::extract::Form(request): axum::extract::Form<TokenRequest>,
171) -> Result<Json<TokenResponse>, StatusCode> {
172    match request.grant_type.as_str() {
173        "authorization_code" => handle_authorization_code_grant(state, request).await,
174        "client_credentials" => handle_client_credentials_grant(state, request).await,
175        "refresh_token" => handle_refresh_token_grant(state, request).await,
176        _ => Err(StatusCode::BAD_REQUEST),
177    }
178}
179
180/// Handle authorization_code grant type
181async fn handle_authorization_code_grant(
182    state: OAuth2ServerState,
183    request: TokenRequest,
184) -> Result<Json<TokenResponse>, StatusCode> {
185    let code = request.code.ok_or(StatusCode::BAD_REQUEST)?;
186    let redirect_uri = request.redirect_uri.ok_or(StatusCode::BAD_REQUEST)?;
187
188    // Look up authorization code
189    let code_info = {
190        let mut codes = state.auth_codes.write().await;
191        codes.remove(&code).ok_or(StatusCode::BAD_REQUEST)?
192    };
193
194    // Validate redirect URI
195    if code_info.redirect_uri != redirect_uri {
196        return Err(StatusCode::BAD_REQUEST);
197    }
198
199    // Check expiration
200    if code_info.expires_at < Utc::now().timestamp() {
201        return Err(StatusCode::BAD_REQUEST);
202    }
203
204    // Generate access token using OIDC
205    let oidc_state_guard = state.oidc_state.read().await;
206    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
207
208    // Build claims
209    let mut additional_claims = HashMap::new();
210    additional_claims.insert("scope".to_string(), json!(code_info.scopes.join(" ")));
211    if let Some(nonce) = request.nonce {
212        additional_claims.insert("nonce".to_string(), json!(nonce));
213    }
214
215    let access_token = generate_oidc_token(
216        oidc_state,
217        code_info.user_id.clone(),
218        Some(additional_claims),
219        Some(3600), // 1 hour expiration
220        code_info.tenant_context.clone(),
221    )
222    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
223
224    // Check if token is revoked (shouldn't be, but check anyway)
225    let token_id = extract_token_id(&access_token);
226    if state.lifecycle_manager.revocation.is_revoked(&token_id).await.is_some() {
227        return Err(StatusCode::INTERNAL_SERVER_ERROR);
228    }
229
230    // Generate refresh token (simplified)
231    let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
232
233    Ok(Json(TokenResponse {
234        access_token,
235        token_type: "Bearer".to_string(),
236        expires_in: 3600,
237        refresh_token: Some(refresh_token),
238        scope: Some(code_info.scopes.join(" ")),
239        // ID token generation for OpenID Connect can be added by calling generate_oidc_token
240        // with appropriate OpenID Connect claims (sub, iss, aud, exp, iat, nonce, etc.)
241        id_token: None,
242    }))
243}
244
245/// Handle client_credentials grant type
246async fn handle_client_credentials_grant(
247    state: OAuth2ServerState,
248    request: TokenRequest,
249) -> Result<Json<TokenResponse>, StatusCode> {
250    let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
251    let _client_secret = request.client_secret.ok_or(StatusCode::BAD_REQUEST)?;
252
253    // Validate client credentials (simplified - in production, check against database)
254
255    // Generate access token
256    let oidc_state_guard = state.oidc_state.read().await;
257    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
258
259    let mut additional_claims = HashMap::new();
260    additional_claims.insert("client_id".to_string(), serde_json::json!(client_id));
261    let scope_clone = request.scope.clone();
262    if let Some(ref scope) = request.scope {
263        additional_claims.insert("scope".to_string(), serde_json::json!(scope));
264    }
265
266    let access_token = generate_oidc_token(
267        oidc_state,
268        format!("client_{}", client_id),
269        Some(additional_claims),
270        Some(3600),
271        None,
272    )
273    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
274
275    Ok(Json(TokenResponse {
276        access_token,
277        token_type: "Bearer".to_string(),
278        expires_in: 3600,
279        refresh_token: None,
280        scope: scope_clone,
281        id_token: None,
282    }))
283}
284
285/// Handle refresh_token grant type
286async fn handle_refresh_token_grant(
287    state: OAuth2ServerState,
288    request: TokenRequest,
289) -> Result<Json<TokenResponse>, StatusCode> {
290    // For refresh token grant, we would:
291    // 1. Validate the refresh token
292    // 2. Check if it's revoked
293    // 3. Generate a new access token
294    // 4. Optionally generate a new refresh token
295
296    // Simplified implementation - in production, validate refresh token from storage
297    let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
298
299    // Generate new access token
300    let oidc_state_guard = state.oidc_state.read().await;
301    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
302
303    let mut additional_claims = HashMap::new();
304    additional_claims.insert("client_id".to_string(), json!(client_id));
305    let scope_clone = request.scope.clone();
306    if let Some(ref scope) = request.scope {
307        additional_claims.insert("scope".to_string(), json!(scope));
308    }
309
310    let access_token = generate_oidc_token(
311        oidc_state,
312        format!("client_{}", client_id),
313        Some(additional_claims),
314        Some(3600),
315        None,
316    )
317    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
318
319    // Generate new refresh token
320    let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
321
322    Ok(Json(TokenResponse {
323        access_token,
324        token_type: "Bearer".to_string(),
325        expires_in: 3600,
326        refresh_token: Some(refresh_token),
327        scope: scope_clone,
328        id_token: None,
329    }))
330}
331
332/// Create OAuth2 server router
333pub fn oauth2_server_router(state: OAuth2ServerState) -> axum::Router {
334    use axum::routing::{get, post};
335
336    axum::Router::new()
337        .route("/oauth2/authorize", get(authorize))
338        .route("/oauth2/token", post(token))
339        .with_state(state)
340}