Skip to main content

fraiseql_auth/
handlers.rs

//! HTTP handlers for the built-in authentication endpoints (`/auth/start`,
//! `/auth/callback`, `/auth/refresh`, `/auth/logout`).
3use std::{net::SocketAddr, sync::Arc};
4
5use axum::{
6    Json,
7    extract::{ConnectInfo, Query, State},
8    http::StatusCode,
9    response::IntoResponse,
10};
11use serde::{Deserialize, Serialize};
12
13use crate::{
14    audit::logger::{AuditEventType, SecretType, get_audit_logger},
15    error::{AuthError, Result},
16    provider::OAuthProvider,
17    rate_limiting::RateLimiters,
18    session::SessionStore,
19    state_store::StateStore,
20};
21
22/// AuthState holds the auth configuration and backends
23#[derive(Clone)]
24pub struct AuthState {
25    /// OAuth provider
26    pub oauth_provider: Arc<dyn OAuthProvider>,
27    /// Session store backend
28    pub session_store:  Arc<dyn SessionStore>,
29    /// CSRF state store backend (in-memory for single-instance, Redis for distributed)
30    pub state_store:    Arc<dyn StateStore>,
31    /// Rate limiters for auth endpoints (per-IP based)
32    pub rate_limiters:  Arc<RateLimiters>,
33}
34
35/// Request body for auth/start endpoint
36#[derive(Debug, Deserialize)]
37pub struct AuthStartRequest {
38    /// Optional provider name (for multi-provider setups)
39    pub provider: Option<String>,
40}
41
42/// Response body for the `POST /auth/start` endpoint.
43#[derive(Debug, Serialize)]
44pub struct AuthStartResponse {
45    /// Authorization URL to redirect user to
46    pub authorization_url: String,
47}
48
49/// Query parameters for auth/callback endpoint
50#[derive(Debug, Deserialize)]
51pub struct AuthCallbackQuery {
52    /// Authorization code from provider
53    pub code:              String,
54    /// State parameter for CSRF protection
55    pub state:             String,
56    /// Error from provider if present
57    pub error:             Option<String>,
58    /// Error description from provider
59    pub error_description: Option<String>,
60}
61
62/// Response body for the `GET /auth/callback` endpoint.
63///
64/// Returned after a successful OAuth authorization-code exchange.
65/// In a production browser-facing flow, the server would instead redirect
66/// the user agent to the frontend application with tokens in a URL fragment;
67/// this JSON form is suitable for API clients and testing.
68#[derive(Debug, Serialize)]
69pub struct AuthCallbackResponse {
70    /// Access token for API requests
71    pub access_token:  String,
72    /// Optional refresh token
73    pub refresh_token: Option<String>,
74    /// Token type (usually "Bearer")
75    pub token_type:    String,
76    /// Time in seconds until token expires
77    pub expires_in:    u64,
78}
79
80/// Request body for auth/refresh endpoint
81#[derive(Debug, Deserialize)]
82pub struct AuthRefreshRequest {
83    /// Refresh token to exchange for new access token
84    pub refresh_token: String,
85}
86
87/// Response body for the `POST /auth/refresh` endpoint.
88#[derive(Debug, Serialize)]
89pub struct AuthRefreshResponse {
90    /// New access token
91    pub access_token: String,
92    /// Token type
93    pub token_type:   String,
94    /// Time in seconds until token expires
95    pub expires_in:   u64,
96}
97
98/// Request body for auth/logout endpoint
99#[derive(Debug, Deserialize)]
100pub struct AuthLogoutRequest {
101    /// Refresh token to revoke
102    pub refresh_token: Option<String>,
103}
104
105/// POST /auth/start - Initiate OAuth flow
106///
107/// Returns an authorization URL that the client should redirect the user to.
108///
109/// # Rate Limiting
110///
111/// This endpoint is rate-limited per IP address to prevent brute-force attacks.
112/// The limit is configurable via FRAISEQL_AUTH_START_MAX_REQUESTS and
113/// FRAISEQL_AUTH_START_WINDOW_SECS environment variables.
114///
115/// # Errors
116///
117/// Returns `AuthError::RateLimited` if the per-IP rate limit is exceeded.
118/// Returns `AuthError::SystemTimeError` if the system clock is unavailable.
119/// Returns `AuthError` if the state store write fails.
120pub async fn auth_start(
121    State(state): State<AuthState>,
122    ConnectInfo(addr): ConnectInfo<SocketAddr>,
123    Json(req): Json<AuthStartRequest>,
124) -> Result<Json<AuthStartResponse>> {
125    // SECURITY: Check rate limiting for auth/start endpoint (per IP)
126    let client_ip = addr.ip().to_string();
127    if state.rate_limiters.auth_start.check(&client_ip).is_err() {
128        return Err(AuthError::RateLimited {
129            retry_after_secs: state.rate_limiters.auth_start.clone_config().window_secs,
130        });
131    }
132
133    // Generate random state for CSRF protection using cryptographically secure RNG
134    let state_value = generate_secure_state();
135
136    // Get current time with explicit error handling (not unwrap_or_default)
137    let now = std::time::SystemTime::now()
138        .duration_since(std::time::UNIX_EPOCH)
139        .map_err(|_| AuthError::SystemTimeError {
140            message: "Failed to get current system time".to_string(),
141        })?
142        .as_secs();
143
144    // Store state with expiry (10 minutes)
145    let expiry = now + 600;
146
147    // SECURITY: Store state using configurable backend (in-memory or distributed)
148    let provider = req.provider.unwrap_or_else(|| "default".to_string());
149    state.state_store.store(state_value.clone(), provider, expiry).await?;
150
151    // Generate authorization URL
152    let authorization_url = state.oauth_provider.authorization_url(&state_value);
153
154    Ok(Json(AuthStartResponse { authorization_url }))
155}
156
157/// GET /auth/callback - OAuth provider redirects here
158///
159/// Exchanges the authorization code for tokens and creates a session.
160///
161/// # Rate Limiting
162///
163/// This endpoint is rate-limited per IP address to prevent brute-force attacks.
164/// The limit is configurable via FRAISEQL_AUTH_CALLBACK_MAX_REQUESTS and
165/// FRAISEQL_AUTH_CALLBACK_WINDOW_SECS environment variables.
166///
167/// # Errors
168///
169/// Returns `AuthError::RateLimited` if the per-IP rate limit is exceeded.
170/// Returns `AuthError::OAuthError` if the provider returned an error.
171/// Returns `AuthError::InvalidState` if the CSRF state token is expired or invalid.
172/// Returns `AuthError` if the token exchange or session creation fails.
173pub async fn auth_callback(
174    State(state): State<AuthState>,
175    ConnectInfo(addr): ConnectInfo<SocketAddr>,
176    Query(query): Query<AuthCallbackQuery>,
177) -> Result<impl IntoResponse> {
178    // SECURITY: Check rate limiting for auth/callback endpoint (per IP)
179    let client_ip = addr.ip().to_string();
180    if state.rate_limiters.auth_callback.check(&client_ip).is_err() {
181        return Err(AuthError::RateLimited {
182            retry_after_secs: state.rate_limiters.auth_callback.clone_config().window_secs,
183        });
184    }
185
186    // Check for provider error
187    if let Some(error) = query.error {
188        let audit_logger = get_audit_logger();
189        audit_logger.log_failure(
190            AuditEventType::OauthCallback,
191            SecretType::AuthorizationCode,
192            None,
193            "exchange",
194            &error,
195        );
196        return Err(AuthError::OAuthError {
197            message: format!("{}: {}", error, query.error_description.unwrap_or_default()),
198        });
199    }
200
201    // SECURITY: Validate state using configurable backend (distributed-safe)
202    let (_provider_name, expiry) = state.state_store.retrieve(&query.state).await?;
203
204    // Check state expiry with explicit error handling
205    let now = std::time::SystemTime::now()
206        .duration_since(std::time::UNIX_EPOCH)
207        .map_err(|_| AuthError::SystemTimeError {
208            message: "Failed to get current system time".to_string(),
209        })?
210        .as_secs();
211
212    if now > expiry {
213        let audit_logger = get_audit_logger();
214        audit_logger.log_failure(
215            AuditEventType::CsrfStateValidated,
216            SecretType::StateToken,
217            None,
218            "validate",
219            "State token expired",
220        );
221        return Err(AuthError::InvalidState);
222    }
223
224    // Audit log: CSRF state validation success
225    let audit_logger = get_audit_logger();
226    audit_logger.log_success(
227        AuditEventType::CsrfStateValidated,
228        SecretType::StateToken,
229        None,
230        "validate",
231    );
232
233    // Exchange code for tokens
234    let token_response = state.oauth_provider.exchange_code(&query.code).await?;
235
236    // Audit log: Token exchange success
237    let audit_logger = get_audit_logger();
238    audit_logger.log_success(
239        AuditEventType::OauthCallback,
240        SecretType::AuthorizationCode,
241        None,
242        "exchange",
243    );
244
245    // Get user info
246    let user_info = state.oauth_provider.user_info(&token_response.access_token).await?;
247
248    // Create session (expires in 7 days)
249    let expires_at = now + (7 * 24 * 60 * 60);
250    let session_tokens = state.session_store.create_session(&user_info.id, expires_at).await?;
251
252    // Audit log: Session token created
253    let audit_logger = get_audit_logger();
254    audit_logger.log_success(
255        AuditEventType::SessionTokenCreated,
256        SecretType::SessionToken,
257        Some(user_info.id.clone()),
258        "create",
259    );
260
261    // Audit log: Auth success
262    let audit_logger = get_audit_logger();
263    audit_logger.log_success(
264        AuditEventType::AuthSuccess,
265        SecretType::SessionToken,
266        Some(user_info.id),
267        "oauth_flow",
268    );
269
270    let response = AuthCallbackResponse {
271        access_token:  session_tokens.access_token,
272        refresh_token: Some(session_tokens.refresh_token),
273        token_type:    "Bearer".to_string(),
274        expires_in:    session_tokens.expires_in,
275    };
276
277    // In a real app, would redirect to frontend with tokens in URL fragment
278    // For now, return JSON
279    Ok(Json(response))
280}
281
282/// POST /auth/refresh - Refresh access token
283///
284/// Uses refresh token to obtain a new access token.
285///
286/// # Rate Limiting
287///
288/// This endpoint is rate-limited per user ID to prevent token refresh attacks.
289/// The limit is configurable via FRAISEQL_AUTH_REFRESH_MAX_REQUESTS and
290/// FRAISEQL_AUTH_REFRESH_WINDOW_SECS environment variables.
291///
292/// # Errors
293///
294/// Returns `AuthError::TokenExpired` if the session has expired.
295/// Returns `AuthError::RateLimited` if the per-user rate limit is exceeded.
296/// Returns `AuthError::Internal` if JWT signing is not yet configured.
297pub async fn auth_refresh(
298    State(state): State<AuthState>,
299    Json(req): Json<AuthRefreshRequest>,
300) -> Result<Json<AuthRefreshResponse>> {
301    // Validate refresh token exists in session store
302    use crate::session::hash_token;
303    let token_hash = hash_token(&req.refresh_token);
304    let session = state.session_store.get_session(&token_hash).await?;
305
306    // SECURITY: Reject expired sessions before any further processing.
307    // Without this check, a stolen refresh token from an expired session
308    // could be used indefinitely to mint new access tokens.
309    if session.is_expired() {
310        let audit_logger = get_audit_logger();
311        audit_logger.log_failure(
312            AuditEventType::JwtRefresh,
313            SecretType::RefreshToken,
314            Some(session.user_id),
315            "refresh",
316            "Session expired",
317        );
318        return Err(AuthError::TokenExpired);
319    }
320
321    // SECURITY: Check rate limiting for auth/refresh endpoint (per user)
322    if state.rate_limiters.auth_refresh.check(&session.user_id).is_err() {
323        return Err(AuthError::RateLimited {
324            retry_after_secs: state.rate_limiters.auth_refresh.clone_config().window_secs,
325        });
326    }
327
328    // Audit log: Refresh token validation success
329    let audit_logger = get_audit_logger();
330    audit_logger.log_success(
331        AuditEventType::SessionTokenValidation,
332        SecretType::RefreshToken,
333        Some(session.user_id),
334        "validate",
335    );
336
337    // JWT signing requires an RSA/EC private key, which is not yet wired
338    // into the auth state. Return an explicit error rather than a fake token.
339    Err(AuthError::Internal {
340        message: "JWT signing not yet implemented — configure an OIDC provider for token issuance"
341            .to_string(),
342    })
343}
344
345/// POST /auth/logout - Logout and revoke session
346///
347/// Revokes the refresh token, effectively logging out the user.
348///
349/// # Rate Limiting
350///
351/// This endpoint is rate-limited per user ID to prevent logout token exhaustion attacks.
352/// The limit is configurable via FRAISEQL_AUTH_LOGOUT_MAX_REQUESTS and
353/// FRAISEQL_AUTH_LOGOUT_WINDOW_SECS environment variables.
354///
355/// # Errors
356///
357/// Returns `AuthError::RateLimited` if the per-user rate limit is exceeded.
358/// Returns `AuthError` if the session lookup or deletion fails.
359pub async fn auth_logout(
360    State(state): State<AuthState>,
361    ConnectInfo(addr): ConnectInfo<SocketAddr>,
362    Json(req): Json<AuthLogoutRequest>,
363) -> Result<StatusCode> {
364    let client_ip = addr.ip().to_string();
365
366    if let Some(refresh_token) = req.refresh_token {
367        use crate::session::hash_token;
368        let token_hash = hash_token(&refresh_token);
369
370        // Get session to extract user ID for per-user rate limiting
371        let session = state.session_store.get_session(&token_hash).await?;
372
373        // SECURITY: Check rate limiting for auth/logout endpoint (per user)
374        if state.rate_limiters.auth_logout.check(&session.user_id).is_err() {
375            return Err(AuthError::RateLimited {
376                retry_after_secs: state.rate_limiters.auth_logout.clone_config().window_secs,
377            });
378        }
379
380        state.session_store.revoke_session(&token_hash).await?;
381
382        // Audit log: Session revoked
383        let audit_logger = get_audit_logger();
384        audit_logger.log_success(
385            AuditEventType::SessionTokenRevoked,
386            SecretType::RefreshToken,
387            Some(session.user_id),
388            "revoke",
389        );
390    } else {
391        // No refresh token - use IP-based rate limiting as fallback
392        if state.rate_limiters.auth_logout.check(&client_ip).is_err() {
393            return Err(AuthError::RateLimited {
394                retry_after_secs: state.rate_limiters.auth_logout.clone_config().window_secs,
395            });
396        }
397    }
398
399    Ok(StatusCode::NO_CONTENT)
400}
401
402/// Generate a cryptographically random state for CSRF protection
403/// Uses OsRng for cryptographically secure randomness
404pub fn generate_secure_state() -> String {
405    use rand::RngCore;
406
407    // Generate 32 random bytes for 256 bits of entropy
408    let mut bytes = [0u8; 32];
409    rand::rngs::OsRng.fill_bytes(&mut bytes);
410
411    // Encode as hex string for safe transmission in URLs/headers
412    hex::encode(bytes)
413}
414
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    // Reason: test module — wildcard keeps test boilerplate minimal
    use super::*;

    /// Two freshly generated CSRF states must differ. Each must be the
    /// 64-character hex encoding of 32 random bytes.
    #[test]
    fn test_generate_secure_state() {
        let (state1, state2) = (generate_secure_state(), generate_secure_state());

        // Independent CSPRNG draws should not collide.
        assert_ne!(state1, state2);

        for token in [&state1, &state2] {
            // Never empty, and exactly 32 bytes x 2 hex digits long.
            assert!(!token.is_empty());
            assert_eq!(token.len(), 64);
        }

        // Both must round-trip through the hex decoder.
        hex::decode(&state1).unwrap_or_else(|e| panic!("state1 should be valid hex: {e}"));
        hex::decode(&state2).unwrap_or_else(|e| panic!("state2 should be valid hex: {e}"));
    }
}