Skip to main content

fraiseql_server/auth/
handlers.rs

// HTTP handlers for authentication endpoints
2use std::sync::Arc;
3
4use axum::{
5    Json,
6    extract::{Query, State},
7    http::StatusCode,
8    response::IntoResponse,
9};
10use serde::{Deserialize, Serialize};
11
12use crate::auth::{
13    audit_logger::{AuditEventType, SecretType, get_audit_logger},
14    error::{AuthError, Result},
15    provider::OAuthProvider,
16    session::SessionStore,
17    state_store::StateStore,
18};
19
20/// AuthState holds the auth configuration and backends
21#[derive(Clone)]
22pub struct AuthState {
23    /// OAuth provider
24    pub oauth_provider: Arc<dyn OAuthProvider>,
25    /// Session store backend
26    pub session_store:  Arc<dyn SessionStore>,
27    /// CSRF state store backend (in-memory for single-instance, Redis for distributed)
28    pub state_store:    Arc<dyn StateStore>,
29}
30
31/// Request body for auth/start endpoint
32#[derive(Debug, Deserialize)]
33pub struct AuthStartRequest {
34    /// Optional provider name (for multi-provider setups)
35    pub provider: Option<String>,
36}
37
38/// Response for auth/start endpoint
39#[derive(Debug, Serialize)]
40pub struct AuthStartResponse {
41    /// Authorization URL to redirect user to
42    pub authorization_url: String,
43}
44
45/// Query parameters for auth/callback endpoint
46#[derive(Debug, Deserialize)]
47pub struct AuthCallbackQuery {
48    /// Authorization code from provider
49    pub code:              String,
50    /// State parameter for CSRF protection
51    pub state:             String,
52    /// Error from provider if present
53    pub error:             Option<String>,
54    /// Error description from provider
55    pub error_description: Option<String>,
56}
57
58/// Response for auth/callback endpoint
59#[derive(Debug, Serialize)]
60pub struct AuthCallbackResponse {
61    /// Access token for API requests
62    pub access_token:  String,
63    /// Optional refresh token
64    pub refresh_token: Option<String>,
65    /// Token type (usually "Bearer")
66    pub token_type:    String,
67    /// Time in seconds until token expires
68    pub expires_in:    u64,
69}
70
71/// Request body for auth/refresh endpoint
72#[derive(Debug, Deserialize)]
73pub struct AuthRefreshRequest {
74    /// Refresh token to exchange for new access token
75    pub refresh_token: String,
76}
77
78/// Response for auth/refresh endpoint
79#[derive(Debug, Serialize)]
80pub struct AuthRefreshResponse {
81    /// New access token
82    pub access_token: String,
83    /// Token type
84    pub token_type:   String,
85    /// Time in seconds until token expires
86    pub expires_in:   u64,
87}
88
89/// Request body for auth/logout endpoint
90#[derive(Debug, Deserialize)]
91pub struct AuthLogoutRequest {
92    /// Refresh token to revoke
93    pub refresh_token: Option<String>,
94}
95
96/// POST /auth/start - Initiate OAuth flow
97///
98/// Returns an authorization URL that the client should redirect the user to.
99pub async fn auth_start(
100    State(state): State<AuthState>,
101    Json(req): Json<AuthStartRequest>,
102) -> Result<Json<AuthStartResponse>> {
103    // Generate random state for CSRF protection using cryptographically secure RNG
104    let state_value = generate_secure_state();
105
106    // Get current time with explicit error handling (not unwrap_or_default)
107    let now = std::time::SystemTime::now()
108        .duration_since(std::time::UNIX_EPOCH)
109        .map_err(|_| AuthError::SystemTimeError {
110            message: "Failed to get current system time".to_string(),
111        })?
112        .as_secs();
113
114    // Store state with expiry (10 minutes)
115    let expiry = now + 600;
116
117    // SECURITY: Store state using configurable backend (in-memory or distributed)
118    let provider = req.provider.unwrap_or_else(|| "default".to_string());
119    state.state_store.store(state_value.clone(), provider, expiry).await?;
120
121    // Generate authorization URL
122    let authorization_url = state.oauth_provider.authorization_url(&state_value);
123
124    Ok(Json(AuthStartResponse { authorization_url }))
125}
126
127/// GET /auth/callback - OAuth provider redirects here
128///
129/// Exchanges the authorization code for tokens and creates a session.
130pub async fn auth_callback(
131    State(state): State<AuthState>,
132    Query(query): Query<AuthCallbackQuery>,
133) -> Result<impl IntoResponse> {
134    // Check for provider error
135    if let Some(error) = query.error {
136        let audit_logger = get_audit_logger();
137        audit_logger.log_failure(
138            AuditEventType::OauthCallback,
139            SecretType::AuthorizationCode,
140            None,
141            "exchange",
142            &error,
143        );
144        return Err(AuthError::OAuthError {
145            message: format!("{}: {}", error, query.error_description.unwrap_or_default()),
146        });
147    }
148
149    // SECURITY: Validate state using configurable backend (distributed-safe)
150    let (_provider_name, expiry) = state.state_store.retrieve(&query.state).await?;
151
152    // Check state expiry with explicit error handling
153    let now = std::time::SystemTime::now()
154        .duration_since(std::time::UNIX_EPOCH)
155        .map_err(|_| AuthError::SystemTimeError {
156            message: "Failed to get current system time".to_string(),
157        })?
158        .as_secs();
159
160    if now > expiry {
161        let audit_logger = get_audit_logger();
162        audit_logger.log_failure(
163            AuditEventType::CsrfStateValidated,
164            SecretType::StateToken,
165            None,
166            "validate",
167            "State token expired",
168        );
169        return Err(AuthError::InvalidState);
170    }
171
172    // Audit log: CSRF state validation success
173    let audit_logger = get_audit_logger();
174    audit_logger.log_success(
175        AuditEventType::CsrfStateValidated,
176        SecretType::StateToken,
177        None,
178        "validate",
179    );
180
181    // Exchange code for tokens
182    let token_response = state.oauth_provider.exchange_code(&query.code).await?;
183
184    // Audit log: Token exchange success
185    let audit_logger = get_audit_logger();
186    audit_logger.log_success(
187        AuditEventType::OauthCallback,
188        SecretType::AuthorizationCode,
189        None,
190        "exchange",
191    );
192
193    // Get user info
194    let user_info = state.oauth_provider.user_info(&token_response.access_token).await?;
195
196    // Create session (expires in 7 days)
197    let expires_at = now + (7 * 24 * 60 * 60);
198    let session_tokens = state.session_store.create_session(&user_info.id, expires_at).await?;
199
200    // Audit log: Session token created
201    let audit_logger = get_audit_logger();
202    audit_logger.log_success(
203        AuditEventType::SessionTokenCreated,
204        SecretType::SessionToken,
205        Some(user_info.id.clone()),
206        "create",
207    );
208
209    // Audit log: Auth success
210    let audit_logger = get_audit_logger();
211    audit_logger.log_success(
212        AuditEventType::AuthSuccess,
213        SecretType::SessionToken,
214        Some(user_info.id),
215        "oauth_flow",
216    );
217
218    let response = AuthCallbackResponse {
219        access_token:  session_tokens.access_token,
220        refresh_token: Some(session_tokens.refresh_token),
221        token_type:    "Bearer".to_string(),
222        expires_in:    session_tokens.expires_in,
223    };
224
225    // In a real app, would redirect to frontend with tokens in URL fragment
226    // For now, return JSON
227    Ok(Json(response))
228}
229
230/// POST /auth/refresh - Refresh access token
231///
232/// Uses refresh token to obtain a new access token.
233pub async fn auth_refresh(
234    State(state): State<AuthState>,
235    Json(req): Json<AuthRefreshRequest>,
236) -> Result<Json<AuthRefreshResponse>> {
237    // Validate refresh token exists in session store
238    use crate::auth::session::hash_token;
239    let token_hash = hash_token(&req.refresh_token);
240    let session = state.session_store.get_session(&token_hash).await?;
241
242    // Audit log: Refresh token validation success
243    let audit_logger = get_audit_logger();
244    audit_logger.log_success(
245        AuditEventType::SessionTokenValidation,
246        SecretType::RefreshToken,
247        Some(session.user_id.clone()),
248        "validate",
249    );
250
251    // In a real implementation, would generate new JWT here
252    // For now, return a simple response
253    let access_token = format!("new_access_token_{}", uuid::Uuid::new_v4());
254
255    // Audit log: JWT refresh success
256    let audit_logger = get_audit_logger();
257    audit_logger.log_success(
258        AuditEventType::JwtRefresh,
259        SecretType::JwtToken,
260        Some(session.user_id),
261        "refresh",
262    );
263
264    Ok(Json(AuthRefreshResponse {
265        access_token,
266        token_type: "Bearer".to_string(),
267        expires_in: 3600,
268    }))
269}
270
271/// POST /auth/logout - Logout and revoke session
272///
273/// Revokes the refresh token, effectively logging out the user.
274pub async fn auth_logout(
275    State(state): State<AuthState>,
276    Json(req): Json<AuthLogoutRequest>,
277) -> Result<StatusCode> {
278    if let Some(refresh_token) = req.refresh_token {
279        use crate::auth::session::hash_token;
280        let token_hash = hash_token(&refresh_token);
281        state.session_store.revoke_session(&token_hash).await?;
282    }
283
284    Ok(StatusCode::NO_CONTENT)
285}
286
287/// Generate a cryptographically random state for CSRF protection
288/// Uses OsRng for cryptographically secure randomness
289pub fn generate_secure_state() -> String {
290    use rand::RngCore;
291
292    // Generate 32 random bytes for 256 bits of entropy
293    let mut bytes = [0u8; 32];
294    rand::rngs::OsRng.fill_bytes(&mut bytes);
295
296    // Encode as hex string for safe transmission in URLs/headers
297    hex::encode(bytes)
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Two freshly generated states must differ, and each must be 32 bytes of valid hex.
    #[test]
    fn test_generate_secure_state() {
        let states = [generate_secure_state(), generate_secure_state()];

        // Independent CSPRNG draws should never collide.
        assert_ne!(states[0], states[1]);

        for value in &states {
            assert!(!value.is_empty());
            // 32 bytes hex-encode to 64 characters.
            assert_eq!(value.len(), 64);
            assert!(hex::decode(value).is_ok());
        }
    }
}