Skip to main content

mockforge_http/handlers/
oauth2_server.rs

//! OAuth2 server endpoints
//!
//! This module provides OAuth2 authorization server endpoints that integrate
//! with OIDC, token lifecycle, consent, and risk simulation.
5
6use axum::{
7    extract::{Query, State},
8    http::StatusCode,
9    response::{Json, Redirect},
10};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::RwLock;
15
16use crate::auth::oidc::{generate_oidc_token, OidcState, TenantContext};
17use crate::auth::token_lifecycle::{extract_token_id, TokenLifecycleManager};
18use chrono::Utc;
19use hex;
20use rand::Rng;
21use serde_json::json;
22use uuid;
23
24/// OAuth2 server state
25#[derive(Clone)]
26pub struct OAuth2ServerState {
27    /// OIDC state for token generation
28    pub oidc_state: Arc<RwLock<Option<OidcState>>>,
29    /// Token lifecycle manager
30    pub lifecycle_manager: Arc<TokenLifecycleManager>,
31    /// Authorization codes (code -> authorization info)
32    pub auth_codes: Arc<RwLock<HashMap<String, AuthorizationCodeInfo>>>,
33    /// Refresh tokens (token -> refresh token info)
34    pub refresh_tokens: Arc<RwLock<HashMap<String, RefreshTokenInfo>>>,
35}
36
/// Bookkeeping stored alongside each issued refresh token.
#[derive(Clone, Debug)]
pub struct RefreshTokenInfo {
    /// Client ID the refresh token was issued to.
    pub client_id: String,
    /// Scopes granted with the original authorization.
    pub scopes: Vec<String>,
    /// Subject (user) the token represents.
    pub user_id: String,
    /// Expiry as a Unix timestamp in seconds.
    pub expires_at: i64,
}
49
50/// Authorization code information
51#[derive(Debug, Clone)]
52pub struct AuthorizationCodeInfo {
53    /// Client ID
54    pub client_id: String,
55    /// Redirect URI
56    pub redirect_uri: String,
57    /// Scopes requested
58    pub scopes: Vec<String>,
59    /// User ID (subject)
60    pub user_id: String,
61    /// State parameter (CSRF protection)
62    pub state: Option<String>,
63    /// Expiration time
64    pub expires_at: i64,
65    /// Tenant context
66    pub tenant_context: Option<TenantContext>,
67}
68
69/// OAuth2 authorization request parameters
70#[derive(Debug, Deserialize)]
71pub struct AuthorizationRequest {
72    /// Client ID
73    pub client_id: String,
74    /// Response type (code, token, id_token)
75    pub response_type: String,
76    /// Redirect URI
77    pub redirect_uri: String,
78    /// Scopes (space-separated)
79    pub scope: Option<String>,
80    /// State parameter (CSRF protection)
81    pub state: Option<String>,
82    /// Nonce (for OpenID Connect)
83    pub nonce: Option<String>,
84}
85
86/// OAuth2 token request
87#[derive(Debug, Deserialize)]
88pub struct TokenRequest {
89    /// Grant type
90    pub grant_type: String,
91    /// Authorization code (for authorization_code grant)
92    pub code: Option<String>,
93    /// Redirect URI (must match authorization request)
94    pub redirect_uri: Option<String>,
95    /// Client ID
96    pub client_id: Option<String>,
97    /// Client secret
98    pub client_secret: Option<String>,
99    /// Scope (for client_credentials grant)
100    pub scope: Option<String>,
101    /// Nonce (for OpenID Connect)
102    pub nonce: Option<String>,
103    /// Refresh token (for refresh_token grant)
104    pub refresh_token: Option<String>,
105}
106
107/// OAuth2 token response
108#[derive(Debug, Serialize)]
109pub struct TokenResponse {
110    /// Access token
111    pub access_token: String,
112    /// Token type (usually "Bearer")
113    pub token_type: String,
114    /// Expires in (seconds)
115    pub expires_in: i64,
116    /// Refresh token (optional)
117    #[serde(skip_serializing_if = "Option::is_none")]
118    pub refresh_token: Option<String>,
119    /// Scope (optional)
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub scope: Option<String>,
122    /// ID token (for OpenID Connect)
123    #[serde(skip_serializing_if = "Option::is_none")]
124    pub id_token: Option<String>,
125}
126
127/// OAuth2 authorization endpoint
128pub async fn authorize(
129    State(state): State<OAuth2ServerState>,
130    Query(params): Query<AuthorizationRequest>,
131) -> Result<Redirect, StatusCode> {
132    // Validate response_type
133    if params.response_type != "code" {
134        return Err(StatusCode::BAD_REQUEST);
135    }
136
137    // Check if consent is required (simplified - in production, check user consent)
138    // For now, auto-approve and generate authorization code
139
140    // Generate authorization code before any await points (ThreadRng is not Send)
141    let auth_code = {
142        let mut rng = rand::rng();
143        let code_bytes: [u8; 32] = rng.random();
144        hex::encode(code_bytes)
145    };
146
147    // Parse scopes
148    let scopes = params
149        .scope
150        .as_ref()
151        .map(|s| s.split(' ').map(|s| s.to_string()).collect())
152        .unwrap_or_else(Vec::new);
153
154    // Store authorization code (expires in 10 minutes)
155    let code_info = AuthorizationCodeInfo {
156        client_id: params.client_id.clone(),
157        redirect_uri: params.redirect_uri.clone(),
158        scopes,
159        // For mock server, use default user ID
160        // In production, extract from authenticated session
161        user_id: "user-default".to_string(),
162        state: params.state.clone(),
163        expires_at: Utc::now().timestamp() + 600, // 10 minutes
164        // Tenant context can be extracted from request headers or session
165        tenant_context: None,
166    };
167
168    {
169        let mut codes = state.auth_codes.write().await;
170        codes.insert(auth_code.clone(), code_info);
171    }
172
173    // Build redirect URL with authorization code
174    let mut redirect_url =
175        url::Url::parse(&params.redirect_uri).map_err(|_| StatusCode::BAD_REQUEST)?;
176    redirect_url.query_pairs_mut().append_pair("code", &auth_code);
177    if let Some(state) = params.state {
178        redirect_url.query_pairs_mut().append_pair("state", &state);
179    }
180
181    Ok(Redirect::to(redirect_url.as_str()))
182}
183
184/// OAuth2 token endpoint
185pub async fn token(
186    State(state): State<OAuth2ServerState>,
187    axum::extract::Form(request): axum::extract::Form<TokenRequest>,
188) -> Result<Json<TokenResponse>, StatusCode> {
189    match request.grant_type.as_str() {
190        "authorization_code" => handle_authorization_code_grant(state, request).await,
191        "client_credentials" => handle_client_credentials_grant(state, request).await,
192        "refresh_token" => handle_refresh_token_grant(state, request).await,
193        _ => Err(StatusCode::BAD_REQUEST),
194    }
195}
196
197/// Handle authorization_code grant type
198async fn handle_authorization_code_grant(
199    state: OAuth2ServerState,
200    request: TokenRequest,
201) -> Result<Json<TokenResponse>, StatusCode> {
202    let code = request.code.ok_or(StatusCode::BAD_REQUEST)?;
203    let redirect_uri = request.redirect_uri.ok_or(StatusCode::BAD_REQUEST)?;
204
205    // Look up authorization code
206    let code_info = {
207        let mut codes = state.auth_codes.write().await;
208        codes.remove(&code).ok_or(StatusCode::BAD_REQUEST)?
209    };
210
211    // Validate redirect URI
212    if code_info.redirect_uri != redirect_uri {
213        return Err(StatusCode::BAD_REQUEST);
214    }
215
216    // Check expiration
217    if code_info.expires_at < Utc::now().timestamp() {
218        return Err(StatusCode::BAD_REQUEST);
219    }
220
221    // Generate access token using OIDC
222    let oidc_state_guard = state.oidc_state.read().await;
223    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
224
225    // Build claims
226    let mut additional_claims = HashMap::new();
227    additional_claims.insert("scope".to_string(), json!(code_info.scopes.join(" ")));
228    if let Some(nonce) = request.nonce {
229        additional_claims.insert("nonce".to_string(), json!(nonce));
230    }
231
232    let access_token = generate_oidc_token(
233        oidc_state,
234        code_info.user_id.clone(),
235        Some(additional_claims),
236        Some(3600), // 1 hour expiration
237        code_info.tenant_context.clone(),
238    )
239    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
240
241    // Check if token is revoked (shouldn't be, but check anyway)
242    let token_id = extract_token_id(&access_token);
243    if state.lifecycle_manager.revocation.is_revoked(&token_id).await.is_some() {
244        return Err(StatusCode::INTERNAL_SERVER_ERROR);
245    }
246
247    // Generate refresh token and store it
248    let refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
249    {
250        let mut tokens = state.refresh_tokens.write().await;
251        tokens.insert(
252            refresh_token.clone(),
253            RefreshTokenInfo {
254                client_id: code_info.client_id.clone(),
255                scopes: code_info.scopes.clone(),
256                user_id: code_info.user_id.clone(),
257                expires_at: Utc::now().timestamp() + 86400, // 24 hours
258            },
259        );
260    }
261
262    Ok(Json(TokenResponse {
263        access_token,
264        token_type: "Bearer".to_string(),
265        expires_in: 3600,
266        refresh_token: Some(refresh_token),
267        scope: Some(code_info.scopes.join(" ")),
268        id_token: None,
269    }))
270}
271
272/// Handle client_credentials grant type
273async fn handle_client_credentials_grant(
274    state: OAuth2ServerState,
275    request: TokenRequest,
276) -> Result<Json<TokenResponse>, StatusCode> {
277    let client_id = request.client_id.ok_or(StatusCode::BAD_REQUEST)?;
278    let _client_secret = request.client_secret.ok_or(StatusCode::BAD_REQUEST)?;
279
280    // Validate client credentials (simplified - in production, check against database)
281
282    // Generate access token
283    let oidc_state_guard = state.oidc_state.read().await;
284    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
285
286    let mut additional_claims = HashMap::new();
287    additional_claims.insert("client_id".to_string(), serde_json::json!(client_id));
288    let scope_clone = request.scope.clone();
289    if let Some(ref scope) = request.scope {
290        additional_claims.insert("scope".to_string(), serde_json::json!(scope));
291    }
292
293    let access_token = generate_oidc_token(
294        oidc_state,
295        format!("client_{}", client_id),
296        Some(additional_claims),
297        Some(3600),
298        None,
299    )
300    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
301
302    Ok(Json(TokenResponse {
303        access_token,
304        token_type: "Bearer".to_string(),
305        expires_in: 3600,
306        refresh_token: None,
307        scope: scope_clone,
308        id_token: None,
309    }))
310}
311
312/// Handle refresh_token grant type
313async fn handle_refresh_token_grant(
314    state: OAuth2ServerState,
315    request: TokenRequest,
316) -> Result<Json<TokenResponse>, StatusCode> {
317    // Extract and validate the refresh token from the request
318    let refresh_token_value = request.refresh_token.ok_or(StatusCode::BAD_REQUEST)?;
319
320    // Look up and remove the old refresh token (single-use rotation)
321    let token_info = {
322        let mut tokens = state.refresh_tokens.write().await;
323        tokens.remove(&refresh_token_value).ok_or(StatusCode::UNAUTHORIZED)?
324    };
325
326    // Check expiration
327    if token_info.expires_at < Utc::now().timestamp() {
328        return Err(StatusCode::UNAUTHORIZED);
329    }
330
331    // Validate client_id matches if provided
332    if let Some(ref client_id) = request.client_id {
333        if *client_id != token_info.client_id {
334            return Err(StatusCode::UNAUTHORIZED);
335        }
336    }
337
338    // Generate new access token
339    let oidc_state_guard = state.oidc_state.read().await;
340    let oidc_state = oidc_state_guard.as_ref().ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;
341
342    let mut additional_claims = HashMap::new();
343    additional_claims.insert("client_id".to_string(), json!(token_info.client_id.clone()));
344
345    // Use scopes from stored token, or override with request scope if provided
346    let scopes = if let Some(ref scope) = request.scope {
347        additional_claims.insert("scope".to_string(), json!(scope));
348        scope.clone()
349    } else {
350        let scope_str = token_info.scopes.join(" ");
351        additional_claims.insert("scope".to_string(), json!(scope_str));
352        scope_str
353    };
354
355    let access_token = generate_oidc_token(
356        oidc_state,
357        token_info.user_id.clone(),
358        Some(additional_claims),
359        Some(3600),
360        None,
361    )
362    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
363
364    // Generate and store new refresh token (rotation)
365    let new_refresh_token = format!("refresh_{}", uuid::Uuid::new_v4());
366    {
367        let mut tokens = state.refresh_tokens.write().await;
368        tokens.insert(
369            new_refresh_token.clone(),
370            RefreshTokenInfo {
371                client_id: token_info.client_id,
372                scopes: token_info.scopes,
373                user_id: token_info.user_id,
374                expires_at: Utc::now().timestamp() + 86400, // 24 hours
375            },
376        );
377    }
378
379    Ok(Json(TokenResponse {
380        access_token,
381        token_type: "Bearer".to_string(),
382        expires_in: 3600,
383        refresh_token: Some(new_refresh_token),
384        scope: Some(scopes),
385        id_token: None,
386    }))
387}
388
389/// Create OAuth2 server router
390pub fn oauth2_server_router(state: OAuth2ServerState) -> axum::Router {
391    use axum::routing::{get, post};
392
393    axum::Router::new()
394        .route("/oauth2/authorize", get(authorize))
395        .route("/oauth2/token", post(token))
396        .with_state(state)
397}