Skip to main content

mockforge_http/handlers/
oauth2_server.rs

1//! OAuth2 server endpoints
2//!
3//! This module provides OAuth2 authorization server endpoints that integrate
4//! with OIDC, token lifecycle, consent, and risk simulation.
5
6use axum::{
7    extract::{Query, State},
8    http::StatusCode,
9    response::{Json, Redirect},
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16use crate::auth::oidc::{generate_oidc_token, OidcState, TenantContext};
17use crate::auth::token_lifecycle::{extract_token_id, TokenLifecycleManager};
18use chrono::Utc;
19use hex;
20use rand::Rng;
21use serde_json::json;
22use uuid;
23
24/// OAuth2 server state
25#[derive(Clone)]
26pub struct OAuth2ServerState {
27    /// OIDC state for token generation
28    pub oidc_state: Arc<RwLock<Option<OidcState>>>,
29    /// Token lifecycle manager
30    pub lifecycle_manager: Arc<TokenLifecycleManager>,
31    /// Authorization codes (code -> authorization info)
32    pub auth_codes: Arc<RwLock<HashMap<String, AuthorizationCodeInfo>>>,
33    /// Refresh tokens (token -> refresh token info)
34    pub refresh_tokens: Arc<RwLock<HashMap<String, RefreshTokenInfo>>>,
35}
36
/// Server-side record kept for every refresh token that has been issued.
///
/// Stored in `OAuth2ServerState::refresh_tokens`, keyed by the token string.
#[derive(Debug, Clone)]
pub struct RefreshTokenInfo {
    /// Identifier of the client the refresh token was issued to.
    pub client_id: String,
    /// Scopes granted when the refresh token was issued.
    pub scopes: Vec<String>,
    /// Subject (user) the refresh token acts for.
    pub user_id: String,
    /// Unix timestamp (seconds) after which the refresh token is rejected.
    pub expires_at: i64,
}
49
50/// Authorization code information
51#[derive(Debug, Clone)]
52pub struct AuthorizationCodeInfo {
53    /// Client ID
54    pub client_id: String,
55    /// Redirect URI
56    pub redirect_uri: String,
57    /// Scopes requested
58    pub scopes: Vec<String>,
59    /// User ID (subject)
60    pub user_id: String,
61    /// State parameter (CSRF protection)
62    pub state: Option<String>,
63    /// Expiration time
64    pub expires_at: i64,
65    /// Tenant context
66    pub tenant_context: Option<TenantContext>,
67}
68
69/// OAuth2 authorization request parameters
70#[derive(Debug, Deserialize)]
71pub struct AuthorizationRequest {
72    /// Client ID
73    pub client_id: String,
74    /// Response type (code, token, id_token)
75    pub response_type: String,
76    /// Redirect URI
77    pub redirect_uri: String,
78    /// Scopes (space-separated)
79    pub scope: Option<String>,
80    /// State parameter (CSRF protection)
81    pub state: Option<String>,
82    /// Nonce (for OpenID Connect)
83    pub nonce: Option<String>,
84    /// Prompt parameter (OpenID Connect). When set to "consent", shows the consent screen.
85    pub prompt: Option<String>,
86}
87
88/// OAuth2 token request
89#[derive(Debug, Deserialize)]
90pub struct TokenRequest {
91    /// Grant type
92    pub grant_type: String,
93    /// Authorization code (for authorization_code grant)
94    pub code: Option<String>,
95    /// Redirect URI (must match authorization request)
96    pub redirect_uri: Option<String>,
97    /// Client ID
98    pub client_id: Option<String>,
99    /// Client secret
100    pub client_secret: Option<String>,
101    /// Scope (for client_credentials grant)
102    pub scope: Option<String>,
103    /// Nonce (for OpenID Connect)
104    pub nonce: Option<String>,
105    /// Refresh token (for refresh_token grant)
106    pub refresh_token: Option<String>,
107}
108
109/// OAuth2 token response
110#[derive(Debug, Serialize)]
111pub struct TokenResponse {
112    /// Access token
113    pub access_token: String,
114    /// Token type (usually "Bearer")
115    pub token_type: String,
116    /// Expires in (seconds)
117    pub expires_in: i64,
118    /// Refresh token (optional)
119    #[serde(skip_serializing_if = "Option::is_none")]
120    pub refresh_token: Option<String>,
121    /// Scope (optional)
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub scope: Option<String>,
124    /// ID token (for OpenID Connect)
125    #[serde(skip_serializing_if = "Option::is_none")]
126    pub id_token: Option<String>,
127}
128
129/// OAuth2 authorization endpoint
130pub async fn authorize(
131    State(state): State<OAuth2ServerState>,
132    Query(params): Query<AuthorizationRequest>,
133) -> Result<Redirect, StatusCode> {
134    // Validate response_type
135    if params.response_type != "code" {
136        return Err(StatusCode::BAD_REQUEST);
137    }
138
139    // If prompt=consent, redirect to the consent screen instead of auto-approving
140    if params.prompt.as_deref() == Some("consent") {
141        let mut consent_url = url::Url::parse("http://localhost/consent")
142            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
143        consent_url
144            .query_pairs_mut()
145            .append_pair("client_id", &params.client_id)
146            .append_pair("redirect_uri", &params.redirect_uri);
147        if let Some(ref scope) = params.scope {
148            consent_url.query_pairs_mut().append_pair("scope", scope);
149        }
150        if let Some(ref state) = params.state {
151            consent_url.query_pairs_mut().append_pair("state", state);
152        }
153        // Use only the path and query, ignoring the dummy host
154        let redirect_target =
155            format!("/consent{}", consent_url.query().map(|q| format!("?{q}")).unwrap_or_default());
156        return Ok(Redirect::to(&redirect_target));
157    }
158
159    // Auto-approve flow: generate authorization code directly (default for mock server)
160
161    // Generate authorization code before any await points (ThreadRng is not Send)
162    let auth_code = {
163        let mut rng = rand::rng();
164        let code_bytes: [u8; 32] = rng.random();
165        hex::encode(code_bytes)
166    };
167
168    // Parse scopes
169    let scopes = params
170        .scope
171        .as_ref()
172        .map(|s| s.split(' ').map(|s| s.to_string()).collect())
173        .unwrap_or_else(Vec::new);
174
175    // Store authorization code (expires in 10 minutes)
176    let code_info = AuthorizationCodeInfo {
177        client_id: params.client_id.clone(),
178        redirect_uri: params.redirect_uri.clone(),
179        scopes,
180        // For mock server, use default user ID
181        // In production, extract from authenticated session
182        user_id: "user-default".to_string(),
183        state: params.state.clone(),
184        expires_at: Utc::now().timestamp() + 600, // 10 minutes
185        // Tenant context can be extracted from request headers or session
186        tenant_context: None,
187    };
188
189    {
190        let mut codes = state.auth_codes.write().await;
191        codes.insert(auth_code.clone(), code_info);
192    }
193
194    // Build redirect URL with authorization code
195    let mut redirect_url =
196        url::Url::parse(&params.redirect_uri).map_err(|_| StatusCode::BAD_REQUEST)?;
197    redirect_url.query_pairs_mut().append_pair("code", &auth_code);
198    if let Some(state) = params.state {
199        redirect_url.query_pairs_mut().append_pair("state", &state);
200    }
201
202    Ok(Redirect::to(redirect_url.as_str()))
203}
204
205/// OAuth2 token endpoint
206pub async fn token(
207    State(state): State<OAuth2ServerState>,
208    axum::extract::Form(request): axum::extract::Form<TokenRequest>,
209) -> Result<Json<TokenResponse>, StatusCode> {
210    match request.grant_type.as_str() {
211        "authorization_code" => handle_authorization_code_grant(state, request).await,
212        "client_credentials" => handle_client_credentials_grant(state, request).await,
213        "refresh_token" => handle_refresh_token_grant(state, request).await,
214        _ => Err(StatusCode::BAD_REQUEST),
215    }
216}
217
218/// Handle authorization_code grant type
219async fn handle_authorization_code_grant(
220    state: OAuth2ServerState,
221    request: TokenRequest,
222) -> Result<Json<TokenResponse>, StatusCode> {
223    let code = request.code.ok_or(StatusCode::BAD_REQUEST)?;
224    let redirect_uri = request.redirect_uri.ok_or(StatusCode::BAD_REQUEST)?;
225
226    // Look up authorization code
227    let code_info = {
228        let mut codes = state.auth_codes.write().await;
229        codes.remove(&code).ok_or(StatusCode::BAD_REQUEST)?
230    };
231
232    // Validate redirect URI
233    if code_info.redirect_uri != redirect_uri {
234        return Err(StatusCode::BAD_REQUEST);
235    }
236
237    // Check expiration
238    if code_info.expires_at < Utc::now().timestamp() {
239        return Err(StatusCode::BAD_REQUEST);
240    }
241
242    // Generate access token using OIDC
243    let oidc_state_guard = state.oidc_state.read().await;
244    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
245
246    // Build claims
247    let mut additional_claims = HashMap::new();
248    additional_claims.insert("scope".to_string(), json!(code_info.scopes.join(" ")));
249    if let Some(nonce) = request.nonce {
250        additional_claims.insert("nonce".to_string(), json!(nonce));
251    }
252
253    let access_token = generate_oidc_token(
254        oidc_state,
255        code_info.user_id.clone(),
256        Some(additional_claims),
257        Some(3600), // 1 hour expiration
258        code_info.tenant_context.clone(),
259    )
260    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
261
262    // Check if token is revoked (shouldn't be, but check anyway)
263    let token_id = extract_token_id(&access_token);
264    if state.lifecycle_manager.revocation.is_revoked(&token_id).await.is_some() {
265        return Err(StatusCode::INTERNAL_SERVER_ERROR);
266    }
267
268    // Generate refresh token and store it
269    let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
270    {
271        let mut tokens = state.refresh_tokens.write().await;
272        tokens.insert(
273            refresh_token.clone(),
274            RefreshTokenInfo {
275                client_id: code_info.client_id.clone(),
276                scopes: code_info.scopes.clone(),
277                user_id: code_info.user_id.clone(),
278                expires_at: Utc::now().timestamp() + 86400, // 24 hours
279            },
280        );
281    }
282
283    Ok(Json(TokenResponse {
284        access_token,
285        token_type: "Bearer".to_string(),
286        expires_in: 3600,
287        refresh_token: Some(refresh_token),
288        scope: Some(code_info.scopes.join(" ")),
289        id_token: None,
290    }))
291}
292
293/// Handle client_credentials grant type
294async fn handle_client_credentials_grant(
295    state: OAuth2ServerState,
296    request: TokenRequest,
297) -> Result<Json<TokenResponse>, StatusCode> {
298    let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
299    let _client_secret = request.client_secret.ok_or(StatusCode::BAD_REQUEST)?;
300
301    // Validate client credentials (simplified - in production, check against database)
302
303    // Generate access token
304    let oidc_state_guard = state.oidc_state.read().await;
305    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
306
307    let mut additional_claims = HashMap::new();
308    additional_claims.insert("client_id".to_string(), serde_json::json!(client_id));
309    let scope_clone = request.scope.clone();
310    if let Some(ref scope) = request.scope {
311        additional_claims.insert("scope".to_string(), serde_json::json!(scope));
312    }
313
314    let access_token = generate_oidc_token(
315        oidc_state,
316        format!("client_{}", client_id),
317        Some(additional_claims),
318        Some(3600),
319        None,
320    )
321    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
322
323    Ok(Json(TokenResponse {
324        access_token,
325        token_type: "Bearer".to_string(),
326        expires_in: 3600,
327        refresh_token: None,
328        scope: scope_clone,
329        id_token: None,
330    }))
331}
332
333/// Handle refresh_token grant type
334async fn handle_refresh_token_grant(
335    state: OAuth2ServerState,
336    request: TokenRequest,
337) -> Result<Json<TokenResponse>, StatusCode> {
338    // Extract and validate the refresh token from the request
339    let refresh_token_value = request.refresh_token.ok_or(StatusCode::BAD_REQUEST)?;
340
341    // Look up and remove the old refresh token (single-use rotation)
342    let token_info = {
343        let mut tokens = state.refresh_tokens.write().await;
344        tokens.remove(&refresh_token_value).ok_or(StatusCode::UNAUTHORIZED)?
345    };
346
347    // Check expiration
348    if token_info.expires_at < Utc::now().timestamp() {
349        return Err(StatusCode::UNAUTHORIZED);
350    }
351
352    // Validate client_id matches if provided
353    if let Some(ref client_id) = request.client_id {
354        if *client_id != token_info.client_id {
355            return Err(StatusCode::UNAUTHORIZED);
356        }
357    }
358
359    // Generate new access token
360    let oidc_state_guard = state.oidc_state.read().await;
361    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
362
363    let mut additional_claims = HashMap::new();
364    additional_claims.insert("client_id".to_string(), json!(token_info.client_id.clone()));
365
366    // Use scopes from stored token, or override with request scope if provided
367    let scopes = if let Some(ref scope) = request.scope {
368        additional_claims.insert("scope".to_string(), json!(scope));
369        scope.clone()
370    } else {
371        let scope_str = token_info.scopes.join(" ");
372        additional_claims.insert("scope".to_string(), json!(scope_str));
373        scope_str
374    };
375
376    let access_token = generate_oidc_token(
377        oidc_state,
378        token_info.user_id.clone(),
379        Some(additional_claims),
380        Some(3600),
381        None,
382    )
383    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
384
385    // Generate and store new refresh token (rotation)
386    let new_refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
387    {
388        let mut tokens = state.refresh_tokens.write().await;
389        tokens.insert(
390            new_refresh_token.clone(),
391            RefreshTokenInfo {
392                client_id: token_info.client_id,
393                scopes: token_info.scopes,
394                user_id: token_info.user_id,
395                expires_at: Utc::now().timestamp() + 86400, // 24 hours
396            },
397        );
398    }
399
400    Ok(Json(TokenResponse {
401        access_token,
402        token_type: "Bearer".to_string(),
403        expires_in: 3600,
404        refresh_token: Some(new_refresh_token),
405        scope: Some(scopes),
406        id_token: None,
407    }))
408}
409
410/// Create OAuth2 server router
411pub fn oauth2_server_router(state: OAuth2ServerState) -> axum::Router {
412    use axum::routing::{get, post};
413
414    axum::Router::new()
415        .route("/oauth2/authorize", get(authorize))
416        .route("/oauth2/token", post(token))
417        .with_state(state)
418}