1use std::{net::SocketAddr, sync::Arc};
4
5use axum::{
6 Json,
7 extract::{ConnectInfo, Query, State},
8 http::StatusCode,
9 response::IntoResponse,
10};
11use serde::{Deserialize, Serialize};
12
13use crate::{
14 audit::logger::{AuditEventType, SecretType, get_audit_logger},
15 error::{AuthError, Result},
16 provider::OAuthProvider,
17 rate_limiting::RateLimiters,
18 session::SessionStore,
19 state_store::StateStore,
20};
21
/// Shared state handed to every auth handler through axum's `State` extractor.
///
/// Every field is an `Arc`, so cloning the state only bumps reference counts.
#[derive(Clone)]
pub struct AuthState {
    /// Upstream OAuth provider: builds authorization URLs, exchanges codes, fetches user info.
    pub oauth_provider: Arc<dyn OAuthProvider>,
    /// Session persistence: creates sessions and looks up / revokes them by refresh-token hash.
    pub session_store: Arc<dyn SessionStore>,
    /// Storage for CSRF `state` values issued by `auth_start` and checked by `auth_callback`.
    pub state_store: Arc<dyn StateStore>,
    /// Per-endpoint rate limiters (`auth_start`, `auth_callback`, `auth_refresh`, `auth_logout`).
    pub rate_limiters: Arc<RateLimiters>,
}
34
/// JSON body accepted by [`auth_start`].
#[derive(Debug, Deserialize)]
pub struct AuthStartRequest {
    /// Provider name stored alongside the issued state value; `auth_start` substitutes
    /// `"default"` when it is absent.
    pub provider: Option<String>,
}
41
/// JSON response returned by [`auth_start`].
#[derive(Debug, Serialize)]
pub struct AuthStartResponse {
    /// URL produced by `OAuthProvider::authorization_url` for the freshly generated
    /// CSRF `state` value.
    pub authorization_url: String,
}
48
49#[derive(Debug, Deserialize)]
51pub struct AuthCallbackQuery {
52 pub code: String,
54 pub state: String,
56 pub error: Option<String>,
58 pub error_description: Option<String>,
60}
61
62#[derive(Debug, Serialize)]
69pub struct AuthCallbackResponse {
70 pub access_token: String,
72 pub refresh_token: Option<String>,
74 pub token_type: String,
76 pub expires_in: u64,
78}
79
80#[derive(Debug, Deserialize)]
82pub struct AuthRefreshRequest {
83 pub refresh_token: String,
85}
86
87#[derive(Debug, Serialize)]
89pub struct AuthRefreshResponse {
90 pub access_token: String,
92 pub token_type: String,
94 pub expires_in: u64,
96}
97
98#[derive(Debug, Deserialize)]
100pub struct AuthLogoutRequest {
101 pub refresh_token: Option<String>,
103}
104
105pub async fn auth_start(
121 State(state): State<AuthState>,
122 ConnectInfo(addr): ConnectInfo<SocketAddr>,
123 Json(req): Json<AuthStartRequest>,
124) -> Result<Json<AuthStartResponse>> {
125 let client_ip = addr.ip().to_string();
127 if state.rate_limiters.auth_start.check(&client_ip).is_err() {
128 return Err(AuthError::RateLimited {
129 retry_after_secs: state.rate_limiters.auth_start.clone_config().window_secs,
130 });
131 }
132
133 let state_value = generate_secure_state();
135
136 let now = std::time::SystemTime::now()
138 .duration_since(std::time::UNIX_EPOCH)
139 .map_err(|_| AuthError::SystemTimeError {
140 message: "Failed to get current system time".to_string(),
141 })?
142 .as_secs();
143
144 let expiry = now + 600;
146
147 let provider = req.provider.unwrap_or_else(|| "default".to_string());
149 state.state_store.store(state_value.clone(), provider, expiry).await?;
150
151 let authorization_url = state.oauth_provider.authorization_url(&state_value);
153
154 Ok(Json(AuthStartResponse { authorization_url }))
155}
156
157pub async fn auth_callback(
174 State(state): State<AuthState>,
175 ConnectInfo(addr): ConnectInfo<SocketAddr>,
176 Query(query): Query<AuthCallbackQuery>,
177) -> Result<impl IntoResponse> {
178 let client_ip = addr.ip().to_string();
180 if state.rate_limiters.auth_callback.check(&client_ip).is_err() {
181 return Err(AuthError::RateLimited {
182 retry_after_secs: state.rate_limiters.auth_callback.clone_config().window_secs,
183 });
184 }
185
186 if let Some(error) = query.error {
188 let audit_logger = get_audit_logger();
189 audit_logger.log_failure(
190 AuditEventType::OauthCallback,
191 SecretType::AuthorizationCode,
192 None,
193 "exchange",
194 &error,
195 );
196 return Err(AuthError::OAuthError {
197 message: format!("{}: {}", error, query.error_description.unwrap_or_default()),
198 });
199 }
200
201 let (_provider_name, expiry) = state.state_store.retrieve(&query.state).await?;
203
204 let now = std::time::SystemTime::now()
206 .duration_since(std::time::UNIX_EPOCH)
207 .map_err(|_| AuthError::SystemTimeError {
208 message: "Failed to get current system time".to_string(),
209 })?
210 .as_secs();
211
212 if now > expiry {
213 let audit_logger = get_audit_logger();
214 audit_logger.log_failure(
215 AuditEventType::CsrfStateValidated,
216 SecretType::StateToken,
217 None,
218 "validate",
219 "State token expired",
220 );
221 return Err(AuthError::InvalidState);
222 }
223
224 let audit_logger = get_audit_logger();
226 audit_logger.log_success(
227 AuditEventType::CsrfStateValidated,
228 SecretType::StateToken,
229 None,
230 "validate",
231 );
232
233 let token_response = state.oauth_provider.exchange_code(&query.code).await?;
235
236 let audit_logger = get_audit_logger();
238 audit_logger.log_success(
239 AuditEventType::OauthCallback,
240 SecretType::AuthorizationCode,
241 None,
242 "exchange",
243 );
244
245 let user_info = state.oauth_provider.user_info(&token_response.access_token).await?;
247
248 let expires_at = now + (7 * 24 * 60 * 60);
250 let session_tokens = state.session_store.create_session(&user_info.id, expires_at).await?;
251
252 let audit_logger = get_audit_logger();
254 audit_logger.log_success(
255 AuditEventType::SessionTokenCreated,
256 SecretType::SessionToken,
257 Some(user_info.id.clone()),
258 "create",
259 );
260
261 let audit_logger = get_audit_logger();
263 audit_logger.log_success(
264 AuditEventType::AuthSuccess,
265 SecretType::SessionToken,
266 Some(user_info.id),
267 "oauth_flow",
268 );
269
270 let response = AuthCallbackResponse {
271 access_token: session_tokens.access_token,
272 refresh_token: Some(session_tokens.refresh_token),
273 token_type: "Bearer".to_string(),
274 expires_in: session_tokens.expires_in,
275 };
276
277 Ok(Json(response))
280}
281
282pub async fn auth_refresh(
298 State(state): State<AuthState>,
299 Json(req): Json<AuthRefreshRequest>,
300) -> Result<Json<AuthRefreshResponse>> {
301 use crate::session::hash_token;
303 let token_hash = hash_token(&req.refresh_token);
304 let session = state.session_store.get_session(&token_hash).await?;
305
306 if session.is_expired() {
310 let audit_logger = get_audit_logger();
311 audit_logger.log_failure(
312 AuditEventType::JwtRefresh,
313 SecretType::RefreshToken,
314 Some(session.user_id),
315 "refresh",
316 "Session expired",
317 );
318 return Err(AuthError::TokenExpired);
319 }
320
321 if state.rate_limiters.auth_refresh.check(&session.user_id).is_err() {
323 return Err(AuthError::RateLimited {
324 retry_after_secs: state.rate_limiters.auth_refresh.clone_config().window_secs,
325 });
326 }
327
328 let audit_logger = get_audit_logger();
330 audit_logger.log_success(
331 AuditEventType::SessionTokenValidation,
332 SecretType::RefreshToken,
333 Some(session.user_id),
334 "validate",
335 );
336
337 Err(AuthError::Internal {
340 message: "JWT signing not yet implemented — configure an OIDC provider for token issuance"
341 .to_string(),
342 })
343}
344
345pub async fn auth_logout(
360 State(state): State<AuthState>,
361 ConnectInfo(addr): ConnectInfo<SocketAddr>,
362 Json(req): Json<AuthLogoutRequest>,
363) -> Result<StatusCode> {
364 let client_ip = addr.ip().to_string();
365
366 if let Some(refresh_token) = req.refresh_token {
367 use crate::session::hash_token;
368 let token_hash = hash_token(&refresh_token);
369
370 let session = state.session_store.get_session(&token_hash).await?;
372
373 if state.rate_limiters.auth_logout.check(&session.user_id).is_err() {
375 return Err(AuthError::RateLimited {
376 retry_after_secs: state.rate_limiters.auth_logout.clone_config().window_secs,
377 });
378 }
379
380 state.session_store.revoke_session(&token_hash).await?;
381
382 let audit_logger = get_audit_logger();
384 audit_logger.log_success(
385 AuditEventType::SessionTokenRevoked,
386 SecretType::RefreshToken,
387 Some(session.user_id),
388 "revoke",
389 );
390 } else {
391 if state.rate_limiters.auth_logout.check(&client_ip).is_err() {
393 return Err(AuthError::RateLimited {
394 retry_after_secs: state.rate_limiters.auth_logout.clone_config().window_secs,
395 });
396 }
397 }
398
399 Ok(StatusCode::NO_CONTENT)
400}
401
402pub fn generate_secure_state() -> String {
405 use rand::RngCore;
406
407 let mut bytes = [0u8; 32];
409 rand::rngs::OsRng.fill_bytes(&mut bytes);
410
411 hex::encode(bytes)
413}
414
#[cfg(test)]
mod tests {
    // Tests exercise the parent module's items directly; a glob import keeps that concise.
    #[allow(clippy::wildcard_imports)]
    use super::*;

    /// Two generated states must differ, and each must be 64 valid hex characters.
    #[test]
    fn test_generate_secure_state() {
        let first = generate_secure_state();
        let second = generate_secure_state();

        assert_ne!(first, second);

        for (label, value) in [("state1", &first), ("state2", &second)] {
            assert!(!value.is_empty());
            assert_eq!(value.len(), 64);
            hex::decode(value).unwrap_or_else(|e| panic!("{label} should be valid hex: {e}"));
        }
    }
}