1use std::sync::Arc;
3
4use axum::{
5 Json,
6 extract::{Query, State},
7 http::StatusCode,
8 response::IntoResponse,
9};
10use serde::{Deserialize, Serialize};
11
12use crate::auth::{
13 audit_logger::{AuditEventType, SecretType, get_audit_logger},
14 error::{AuthError, Result},
15 provider::OAuthProvider,
16 session::SessionStore,
17 state_store::StateStore,
18};
19
20#[derive(Clone)]
22pub struct AuthState {
23 pub oauth_provider: Arc<dyn OAuthProvider>,
25 pub session_store: Arc<dyn SessionStore>,
27 pub state_store: Arc<dyn StateStore>,
29}
30
31#[derive(Debug, Deserialize)]
33pub struct AuthStartRequest {
34 pub provider: Option<String>,
36}
37
38#[derive(Debug, Serialize)]
40pub struct AuthStartResponse {
41 pub authorization_url: String,
43}
44
45#[derive(Debug, Deserialize)]
47pub struct AuthCallbackQuery {
48 pub code: String,
50 pub state: String,
52 pub error: Option<String>,
54 pub error_description: Option<String>,
56}
57
58#[derive(Debug, Serialize)]
60pub struct AuthCallbackResponse {
61 pub access_token: String,
63 pub refresh_token: Option<String>,
65 pub token_type: String,
67 pub expires_in: u64,
69}
70
71#[derive(Debug, Deserialize)]
73pub struct AuthRefreshRequest {
74 pub refresh_token: String,
76}
77
78#[derive(Debug, Serialize)]
80pub struct AuthRefreshResponse {
81 pub access_token: String,
83 pub token_type: String,
85 pub expires_in: u64,
87}
88
89#[derive(Debug, Deserialize)]
91pub struct AuthLogoutRequest {
92 pub refresh_token: Option<String>,
94}
95
96pub async fn auth_start(
100 State(state): State<AuthState>,
101 Json(req): Json<AuthStartRequest>,
102) -> Result<Json<AuthStartResponse>> {
103 let state_value = generate_secure_state();
105
106 let now = std::time::SystemTime::now()
108 .duration_since(std::time::UNIX_EPOCH)
109 .map_err(|_| AuthError::SystemTimeError {
110 message: "Failed to get current system time".to_string(),
111 })?
112 .as_secs();
113
114 let expiry = now + 600;
116
117 let provider = req.provider.unwrap_or_else(|| "default".to_string());
119 state.state_store.store(state_value.clone(), provider, expiry).await?;
120
121 let authorization_url = state.oauth_provider.authorization_url(&state_value);
123
124 Ok(Json(AuthStartResponse { authorization_url }))
125}
126
127pub async fn auth_callback(
131 State(state): State<AuthState>,
132 Query(query): Query<AuthCallbackQuery>,
133) -> Result<impl IntoResponse> {
134 if let Some(error) = query.error {
136 let audit_logger = get_audit_logger();
137 audit_logger.log_failure(
138 AuditEventType::OauthCallback,
139 SecretType::AuthorizationCode,
140 None,
141 "exchange",
142 &error,
143 );
144 return Err(AuthError::OAuthError {
145 message: format!("{}: {}", error, query.error_description.unwrap_or_default()),
146 });
147 }
148
149 let (_provider_name, expiry) = state.state_store.retrieve(&query.state).await?;
151
152 let now = std::time::SystemTime::now()
154 .duration_since(std::time::UNIX_EPOCH)
155 .map_err(|_| AuthError::SystemTimeError {
156 message: "Failed to get current system time".to_string(),
157 })?
158 .as_secs();
159
160 if now > expiry {
161 let audit_logger = get_audit_logger();
162 audit_logger.log_failure(
163 AuditEventType::CsrfStateValidated,
164 SecretType::StateToken,
165 None,
166 "validate",
167 "State token expired",
168 );
169 return Err(AuthError::InvalidState);
170 }
171
172 let audit_logger = get_audit_logger();
174 audit_logger.log_success(
175 AuditEventType::CsrfStateValidated,
176 SecretType::StateToken,
177 None,
178 "validate",
179 );
180
181 let token_response = state.oauth_provider.exchange_code(&query.code).await?;
183
184 let audit_logger = get_audit_logger();
186 audit_logger.log_success(
187 AuditEventType::OauthCallback,
188 SecretType::AuthorizationCode,
189 None,
190 "exchange",
191 );
192
193 let user_info = state.oauth_provider.user_info(&token_response.access_token).await?;
195
196 let expires_at = now + (7 * 24 * 60 * 60);
198 let session_tokens = state.session_store.create_session(&user_info.id, expires_at).await?;
199
200 let audit_logger = get_audit_logger();
202 audit_logger.log_success(
203 AuditEventType::SessionTokenCreated,
204 SecretType::SessionToken,
205 Some(user_info.id.clone()),
206 "create",
207 );
208
209 let audit_logger = get_audit_logger();
211 audit_logger.log_success(
212 AuditEventType::AuthSuccess,
213 SecretType::SessionToken,
214 Some(user_info.id),
215 "oauth_flow",
216 );
217
218 let response = AuthCallbackResponse {
219 access_token: session_tokens.access_token,
220 refresh_token: Some(session_tokens.refresh_token),
221 token_type: "Bearer".to_string(),
222 expires_in: session_tokens.expires_in,
223 };
224
225 Ok(Json(response))
228}
229
230pub async fn auth_refresh(
234 State(state): State<AuthState>,
235 Json(req): Json<AuthRefreshRequest>,
236) -> Result<Json<AuthRefreshResponse>> {
237 use crate::auth::session::hash_token;
239 let token_hash = hash_token(&req.refresh_token);
240 let session = state.session_store.get_session(&token_hash).await?;
241
242 let audit_logger = get_audit_logger();
244 audit_logger.log_success(
245 AuditEventType::SessionTokenValidation,
246 SecretType::RefreshToken,
247 Some(session.user_id.clone()),
248 "validate",
249 );
250
251 let access_token = format!("new_access_token_{}", uuid::Uuid::new_v4());
254
255 let audit_logger = get_audit_logger();
257 audit_logger.log_success(
258 AuditEventType::JwtRefresh,
259 SecretType::JwtToken,
260 Some(session.user_id),
261 "refresh",
262 );
263
264 Ok(Json(AuthRefreshResponse {
265 access_token,
266 token_type: "Bearer".to_string(),
267 expires_in: 3600,
268 }))
269}
270
271pub async fn auth_logout(
275 State(state): State<AuthState>,
276 Json(req): Json<AuthLogoutRequest>,
277) -> Result<StatusCode> {
278 if let Some(refresh_token) = req.refresh_token {
279 use crate::auth::session::hash_token;
280 let token_hash = hash_token(&refresh_token);
281 state.session_store.revoke_session(&token_hash).await?;
282 }
283
284 Ok(StatusCode::NO_CONTENT)
285}
286
287pub fn generate_secure_state() -> String {
290 use rand::RngCore;
291
292 let mut bytes = [0u8; 32];
294 rand::rngs::OsRng.fill_bytes(&mut bytes);
295
296 hex::encode(bytes)
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generate_secure_state() {
        let first = generate_secure_state();
        let second = generate_secure_state();

        // Two independent draws must not collide.
        assert_ne!(first, second);

        for value in [&first, &second] {
            assert!(!value.is_empty());
            // 32 random bytes encode to 64 hex characters.
            assert_eq!(value.len(), 64);
            assert!(hex::decode(value).is_ok());
        }
    }
}