mockforge_http/handlers/
oauth2_server.rs

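//! OAuth2 server endpoints
//!
//! This module provides mock OAuth2 authorization server endpoints
//! (`/oauth2/authorize` and `/oauth2/token`) that issue OIDC-signed tokens and
//! consult the token lifecycle manager for revocation. It is meant to integrate
//! with consent and risk simulation; currently consent is auto-approved here and
//! no risk checks are performed in this module.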
5
6use axum::{
7    extract::{Query, State},
8    http::StatusCode,
9    response::{Json, Redirect},
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16use crate::auth::oidc::{generate_oidc_token, OidcState, TenantContext};
17use crate::auth::token_lifecycle::{extract_token_id, TokenLifecycleManager};
18use chrono::Utc;
19use hex;
20use rand::Rng;
21use serde_json::json;
22use uuid;
23
24/// OAuth2 server state
25#[derive(Clone)]
26pub struct OAuth2ServerState {
27    /// OIDC state for token generation
28    pub oidc_state: Arc<RwLock<Option<OidcState>>>,
29    /// Token lifecycle manager
30    pub lifecycle_manager: Arc<TokenLifecycleManager>,
31    /// Authorization codes (code -> authorization info)
32    pub auth_codes: Arc<RwLock<HashMap<String, AuthorizationCodeInfo>>>,
33}
34
35/// Authorization code information
36#[derive(Debug, Clone)]
37pub struct AuthorizationCodeInfo {
38    /// Client ID
39    pub client_id: String,
40    /// Redirect URI
41    pub redirect_uri: String,
42    /// Scopes requested
43    pub scopes: Vec<String>,
44    /// User ID (subject)
45    pub user_id: String,
46    /// State parameter (CSRF protection)
47    pub state: Option<String>,
48    /// Expiration time
49    pub expires_at: i64,
50    /// Tenant context
51    pub tenant_context: Option<TenantContext>,
52}
53
54/// OAuth2 authorization request parameters
55#[derive(Debug, Deserialize)]
56pub struct AuthorizationRequest {
57    /// Client ID
58    pub client_id: String,
59    /// Response type (code, token, id_token)
60    pub response_type: String,
61    /// Redirect URI
62    pub redirect_uri: String,
63    /// Scopes (space-separated)
64    pub scope: Option<String>,
65    /// State parameter (CSRF protection)
66    pub state: Option<String>,
67    /// Nonce (for OpenID Connect)
68    pub nonce: Option<String>,
69}
70
71/// OAuth2 token request
72#[derive(Debug, Deserialize)]
73pub struct TokenRequest {
74    /// Grant type
75    pub grant_type: String,
76    /// Authorization code (for authorization_code grant)
77    pub code: Option<String>,
78    /// Redirect URI (must match authorization request)
79    pub redirect_uri: Option<String>,
80    /// Client ID
81    pub client_id: Option<String>,
82    /// Client secret
83    pub client_secret: Option<String>,
84    /// Scope (for client_credentials grant)
85    pub scope: Option<String>,
86    /// Nonce (for OpenID Connect)
87    pub nonce: Option<String>,
88}
89
90/// OAuth2 token response
91#[derive(Debug, Serialize)]
92pub struct TokenResponse {
93    /// Access token
94    pub access_token: String,
95    /// Token type (usually "Bearer")
96    pub token_type: String,
97    /// Expires in (seconds)
98    pub expires_in: i64,
99    /// Refresh token (optional)
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub refresh_token: Option<String>,
102    /// Scope (optional)
103    #[serde(skip_serializing_if = "Option::is_none")]
104    pub scope: Option<String>,
105    /// ID token (for OpenID Connect)
106    #[serde(skip_serializing_if = "Option::is_none")]
107    pub id_token: Option<String>,
108}
109
110/// OAuth2 authorization endpoint
111pub async fn authorize(
112    State(state): State<OAuth2ServerState>,
113    Query(params): Query<AuthorizationRequest>,
114) -> Result<Redirect, StatusCode> {
115    // Validate response_type
116    if params.response_type != "code" {
117        return Err(StatusCode::BAD_REQUEST);
118    }
119
120    // Check if consent is required (simplified - in production, check user consent)
121    // For now, auto-approve and generate authorization code
122
123    // Generate authorization code before any await points (ThreadRng is not Send)
124    let auth_code = {
125        let mut rng = rand::thread_rng();
126        let code_bytes: [u8; 32] = rng.gen();
127        hex::encode(code_bytes)
128    };
129
130    // Parse scopes
131    let scopes = params
132        .scope
133        .as_ref()
134        .map(|s| s.split(' ').map(|s| s.to_string()).collect())
135        .unwrap_or_else(Vec::new);
136
137    // Store authorization code (expires in 10 minutes)
138    let code_info = AuthorizationCodeInfo {
139        client_id: params.client_id.clone(),
140        redirect_uri: params.redirect_uri.clone(),
141        scopes,
142        // For mock server, use default user ID
143        // In production, extract from authenticated session
144        user_id: "user-default".to_string(),
145        state: params.state.clone(),
146        expires_at: Utc::now().timestamp() + 600, // 10 minutes
147        // Tenant context can be extracted from request headers or session
148        tenant_context: None,
149    };
150
151    {
152        let mut codes = state.auth_codes.write().await;
153        codes.insert(auth_code.clone(), code_info);
154    }
155
156    // Build redirect URL with authorization code
157    let mut redirect_url =
158        url::Url::parse(&params.redirect_uri).map_err(|_| StatusCode::BAD_REQUEST)?;
159    redirect_url.query_pairs_mut().append_pair("code", &auth_code);
160    if let Some(state) = params.state {
161        redirect_url.query_pairs_mut().append_pair("state", &state);
162    }
163
164    Ok(Redirect::to(redirect_url.as_str()))
165}
166
167/// OAuth2 token endpoint
168pub async fn token(
169    State(state): State<OAuth2ServerState>,
170    axum::extract::Form(request): axum::extract::Form<TokenRequest>,
171) -> Result<Json<TokenResponse>, StatusCode> {
172    use chrono::Utc;
173
174    match request.grant_type.as_str() {
175        "authorization_code" => handle_authorization_code_grant(state, request).await,
176        "client_credentials" => handle_client_credentials_grant(state, request).await,
177        "refresh_token" => handle_refresh_token_grant(state, request).await,
178        _ => Err(StatusCode::BAD_REQUEST),
179    }
180}
181
182/// Handle authorization_code grant type
183async fn handle_authorization_code_grant(
184    state: OAuth2ServerState,
185    request: TokenRequest,
186) -> Result<Json<TokenResponse>, StatusCode> {
187    let code = request.code.ok_or(StatusCode::BAD_REQUEST)?;
188    let redirect_uri = request.redirect_uri.ok_or(StatusCode::BAD_REQUEST)?;
189
190    // Look up authorization code
191    let code_info = {
192        let mut codes = state.auth_codes.write().await;
193        codes.remove(&code).ok_or(StatusCode::BAD_REQUEST)?
194    };
195
196    // Validate redirect URI
197    if code_info.redirect_uri != redirect_uri {
198        return Err(StatusCode::BAD_REQUEST);
199    }
200
201    // Check expiration
202    if code_info.expires_at < Utc::now().timestamp() {
203        return Err(StatusCode::BAD_REQUEST);
204    }
205
206    // Generate access token using OIDC
207    let oidc_state_guard = state.oidc_state.read().await;
208    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
209
210    // Build claims
211    let mut additional_claims = HashMap::new();
212    additional_claims.insert("scope".to_string(), json!(code_info.scopes.join(" ")));
213    if let Some(nonce) = request.nonce {
214        additional_claims.insert("nonce".to_string(), json!(nonce));
215    }
216
217    let access_token = generate_oidc_token(
218        oidc_state,
219        code_info.user_id.clone(),
220        Some(additional_claims),
221        Some(3600), // 1 hour expiration
222        code_info.tenant_context.clone(),
223    )
224    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
225
226    // Check if token is revoked (shouldn't be, but check anyway)
227    let token_id = extract_token_id(&access_token);
228    if state.lifecycle_manager.revocation.is_revoked(&token_id).await.is_some() {
229        return Err(StatusCode::INTERNAL_SERVER_ERROR);
230    }
231
232    // Generate refresh token (simplified)
233    let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
234
235    Ok(Json(TokenResponse {
236        access_token,
237        token_type: "Bearer".to_string(),
238        expires_in: 3600,
239        refresh_token: Some(refresh_token),
240        scope: Some(code_info.scopes.join(" ")),
241        // ID token generation for OpenID Connect can be added by calling generate_oidc_token
242        // with appropriate OpenID Connect claims (sub, iss, aud, exp, iat, nonce, etc.)
243        id_token: None,
244    }))
245}
246
247/// Handle client_credentials grant type
248async fn handle_client_credentials_grant(
249    state: OAuth2ServerState,
250    request: TokenRequest,
251) -> Result<Json<TokenResponse>, StatusCode> {
252    let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
253    let _client_secret = request.client_secret.ok_or(StatusCode::BAD_REQUEST)?;
254
255    // Validate client credentials (simplified - in production, check against database)
256
257    // Generate access token
258    let oidc_state_guard = state.oidc_state.read().await;
259    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
260
261    let mut additional_claims = HashMap::new();
262    additional_claims.insert("client_id".to_string(), serde_json::json!(client_id));
263    let scope_clone = request.scope.clone();
264    if let Some(ref scope) = request.scope {
265        additional_claims.insert("scope".to_string(), serde_json::json!(scope));
266    }
267
268    let access_token = generate_oidc_token(
269        oidc_state,
270        format!("client_{}", client_id),
271        Some(additional_claims),
272        Some(3600),
273        None,
274    )
275    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
276
277    Ok(Json(TokenResponse {
278        access_token,
279        token_type: "Bearer".to_string(),
280        expires_in: 3600,
281        refresh_token: None,
282        scope: scope_clone,
283        id_token: None,
284    }))
285}
286
287/// Handle refresh_token grant type
288async fn handle_refresh_token_grant(
289    state: OAuth2ServerState,
290    request: TokenRequest,
291) -> Result<Json<TokenResponse>, StatusCode> {
292    // For refresh token grant, we would:
293    // 1. Validate the refresh token
294    // 2. Check if it's revoked
295    // 3. Generate a new access token
296    // 4. Optionally generate a new refresh token
297
298    // Simplified implementation - in production, validate refresh token from storage
299    let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
300
301    // Generate new access token
302    let oidc_state_guard = state.oidc_state.read().await;
303    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
304
305    let mut additional_claims = HashMap::new();
306    additional_claims.insert("client_id".to_string(), json!(client_id));
307    let scope_clone = request.scope.clone();
308    if let Some(ref scope) = request.scope {
309        additional_claims.insert("scope".to_string(), json!(scope));
310    }
311
312    let access_token = generate_oidc_token(
313        oidc_state,
314        format!("client_{}", client_id),
315        Some(additional_claims),
316        Some(3600),
317        None,
318    )
319    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
320
321    // Generate new refresh token
322    let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
323
324    Ok(Json(TokenResponse {
325        access_token,
326        token_type: "Bearer".to_string(),
327        expires_in: 3600,
328        refresh_token: Some(refresh_token),
329        scope: scope_clone,
330        id_token: None,
331    }))
332}
333
334/// Create OAuth2 server router
335pub fn oauth2_server_router(state: OAuth2ServerState) -> axum::Router {
336    use axum::routing::{get, post};
337
338    axum::Router::new()
339        .route("/oauth2/authorize", get(authorize))
340        .route("/oauth2/token", post(token))
341        .with_state(state)
342}