1use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fs;
9use std::path::Path;
10use std::sync::RwLock;
11use std::time::{Duration, SystemTime};
12
/// Every failure the Google sign-in and session helpers in this module can report.
///
/// The `String` payloads carry the underlying error's text for diagnostics.
#[derive(Debug, Clone)]
pub enum AuthError {
    /// The token or session ID was rejected or is unknown/expired.
    InvalidToken,
    /// The identity is valid but is not on the authorization allow-list.
    UserNotFound,
    /// Reading or writing the users file failed.
    FileError(String),
    /// Some data could not be (de)serialized; also used for session-lock failures.
    ParseError(String),
    /// The request to Google failed at the transport level.
    NetworkError(String),
}

impl std::fmt::Display for AuthError {
    /// Renders a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidToken => f.write_str("Invalid Google token"),
            Self::UserNotFound => f.write_str("User not authorized"),
            Self::FileError(detail) => write!(f, "File error: {}", detail),
            Self::ParseError(detail) => write!(f, "Parse error: {}", detail),
            Self::NetworkError(detail) => write!(f, "Network error: {}", detail),
        }
    }
}

impl std::error::Error for AuthError {}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct GoogleUser {
46 pub google_id: String,
47 pub email: String,
48 pub name: Option<String>,
49 pub picture: Option<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct AuthorizedUsers {
54 pub users: Vec<GoogleUser>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct GoogleTokenInfo {
59 pub sub: String, pub email: String,
61 pub email_verified: bool,
62 pub name: Option<String>,
63 pub picture: Option<String>,
64}
65
/// Settings needed to authenticate Google sign-ins.
pub struct GoogleAuthConfig {
    /// Filesystem path of the JSON allow-list read by `load_authorized_users`.
    pub users_file_path: String,
    /// OAuth client ID of this application; `authenticate_user` forwards it to
    /// `verify_google_token`.
    pub google_client_id: String,
}

impl GoogleAuthConfig {
    /// Bundles the allow-list path and the OAuth client ID into a config.
    pub fn new(users_file_path: String, google_client_id: String) -> Self {
        GoogleAuthConfig {
            google_client_id,
            users_file_path,
        }
    }
}
83
84pub fn load_authorized_users(path: &str) -> Result<AuthorizedUsers, AuthError> {
90 let file_path = Path::new(path);
91
92 if !file_path.exists() {
93 return Err(AuthError::FileError(format!("File not found: {}", path)));
94 }
95
96 let content = fs::read_to_string(file_path)
97 .map_err(|e| AuthError::FileError(e.to_string()))?;
98
99 let users: AuthorizedUsers = serde_json::from_str(&content)
100 .map_err(|e| AuthError::ParseError(e.to_string()))?;
101
102 Ok(users)
103}
104
105pub fn is_user_authorized(
107 google_id: &str,
108 email: &str,
109 authorized_users: &AuthorizedUsers
110) -> bool {
111 authorized_users.users.iter().any(|user| {
112 user.google_id == google_id || user.email == email
113 })
114}
115
116pub async fn verify_google_token(
118 id_token: &str,
119 _client_id: &str,
120) -> Result<GoogleTokenInfo, AuthError> {
121 let url = format!(
122 "https://oauth2.googleapis.com/tokeninfo?id_token={}",
123 id_token
124 );
125
126 let response = reqwest::get(&url)
127 .await
128 .map_err(|e| AuthError::NetworkError(e.to_string()))?;
129
130 if !response.status().is_success() {
131 return Err(AuthError::InvalidToken);
132 }
133
134 let token_info: GoogleTokenInfo = response
135 .json()
136 .await
137 .map_err(|e| AuthError::ParseError(e.to_string()))?;
138
139 Ok(token_info)
140}
141
142pub async fn authenticate_user(
144 id_token: &str,
145 config: &GoogleAuthConfig,
146) -> Result<GoogleUser, AuthError> {
147 let token_info = verify_google_token(id_token, &config.google_client_id).await?;
149
150 let authorized_users = load_authorized_users(&config.users_file_path)?;
152
153 if !is_user_authorized(&token_info.sub, &token_info.email, &authorized_users) {
155 return Err(AuthError::UserNotFound);
156 }
157
158 Ok(GoogleUser {
160 google_id: token_info.sub,
161 email: token_info.email,
162 name: token_info.name,
163 picture: token_info.picture,
164 })
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct Session {
173 pub session_id: String,
174 pub user: GoogleUser,
175 pub created_at: SystemTime,
176 pub expires_at: SystemTime,
177 pub last_accessed: SystemTime,
178}
179
180impl Session {
181 pub fn new(session_id: String, user: GoogleUser, ttl_seconds: u64) -> Self {
182 let now = SystemTime::now();
183 Self {
184 session_id,
185 user,
186 created_at: now,
187 expires_at: now + Duration::from_secs(ttl_seconds),
188 last_accessed: now,
189 }
190 }
191
192 pub fn is_expired(&self) -> bool {
193 SystemTime::now() > self.expires_at
194 }
195
196 pub fn refresh(&mut self, ttl_seconds: u64) {
197 let now = SystemTime::now();
198 self.last_accessed = now;
199 self.expires_at = now + Duration::from_secs(ttl_seconds);
200 }
201}
202
203pub struct SessionStore {
204 sessions: RwLock<HashMap<String, Session>>,
205 ttl_seconds: u64,
206}
207
208impl SessionStore {
209 pub fn new(ttl_seconds: u64) -> Self {
210 Self {
211 sessions: RwLock::new(HashMap::new()),
212 ttl_seconds,
213 }
214 }
215
216 pub fn generate_session_id() -> String {
218 use std::time::UNIX_EPOCH;
219 let timestamp = SystemTime::now()
220 .duration_since(UNIX_EPOCH)
221 .unwrap()
222 .as_nanos();
223 format!("{:x}", timestamp)
224 }
225
226 pub fn create_session(&self, user: GoogleUser) -> Result<Session, AuthError> {
228 let session_id = Self::generate_session_id();
229 let session = Session::new(session_id.clone(), user, self.ttl_seconds);
230
231 let mut sessions = self.sessions.write()
232 .map_err(|_| AuthError::ParseError("Failed to acquire write lock".to_string()))?;
233
234 sessions.insert(session_id.clone(), session.clone());
235
236 Ok(session)
237 }
238
239 pub fn get_session(&self, session_id: &str) -> Result<Session, AuthError> {
241 let sessions = self.sessions.read()
242 .map_err(|_| AuthError::ParseError("Failed to acquire read lock".to_string()))?;
243
244 let session = sessions.get(session_id)
245 .ok_or(AuthError::InvalidToken)?;
246
247 if session.is_expired() {
248 drop(sessions);
249 self.remove_session(session_id)?;
250 return Err(AuthError::InvalidToken);
251 }
252
253 Ok(session.clone())
254 }
255
256 pub fn refresh_session(&self, session_id: &str) -> Result<Session, AuthError> {
258 let mut sessions = self.sessions.write()
259 .map_err(|_| AuthError::ParseError("Failed to acquire write lock".to_string()))?;
260
261 let session = sessions.get_mut(session_id)
262 .ok_or(AuthError::InvalidToken)?;
263
264 if session.is_expired() {
265 return Err(AuthError::InvalidToken);
266 }
267
268 session.refresh(self.ttl_seconds);
269 Ok(session.clone())
270 }
271
272 pub fn remove_session(&self, session_id: &str) -> Result<(), AuthError> {
274 let mut sessions = self.sessions.write()
275 .map_err(|_| AuthError::ParseError("Failed to acquire write lock".to_string()))?;
276
277 sessions.remove(session_id);
278 Ok(())
279 }
280
281 pub fn cleanup_expired(&self) -> Result<usize, AuthError> {
283 let mut sessions = self.sessions.write()
284 .map_err(|_| AuthError::ParseError("Failed to acquire write lock".to_string()))?;
285
286 let before_count = sessions.len();
287 sessions.retain(|_, session| !session.is_expired());
288 let removed = before_count - sessions.len();
289
290 Ok(removed)
291 }
292
293 pub fn get_user_sessions(&self, google_id: &str) -> Result<Vec<Session>, AuthError> {
295 let sessions = self.sessions.read()
296 .map_err(|_| AuthError::ParseError("Failed to acquire read lock".to_string()))?;
297
298 let user_sessions: Vec<Session> = sessions.values()
299 .filter(|s| s.user.google_id == google_id && !s.is_expired())
300 .cloned()
301 .collect();
302
303 Ok(user_sessions)
304 }
305
306 pub fn session_count(&self) -> usize {
308 self.sessions.read()
309 .map(|s| s.len())
310 .unwrap_or(0)
311 }
312}
313
#[cfg(feature = "axum")]
pub mod axum_middleware {
    use super::*;
    use axum::{
        extract::{Request, State},
        http::StatusCode,
        middleware::Next,
        response::{IntoResponse, Response},
    };
    use std::sync::Arc;

    /// State handed to [`google_auth_middleware`] through axum's `State` extractor.
    #[derive(Clone)]
    pub struct AuthState {
        pub config: Arc<GoogleAuthConfig>,
    }

    /// axum middleware that authenticates each request with a Google ID token sent as
    /// `Authorization: Bearer <token>`.
    ///
    /// On success, the [`GoogleUser`] is inserted into the request extensions before the
    /// request is passed on. Every failure produces `401 Unauthorized`.
    pub async fn google_auth_middleware(
        State(auth_state): State<AuthState>,
        mut request: Request,
        next: Next,
    ) -> Result<Response, AuthResponse> {
        let auth_header = request
            .headers()
            .get("Authorization")
            .and_then(|v| v.to_str().ok())
            .ok_or_else(|| AuthResponse::Unauthorized("Missing Authorization header".into()))?;

        let token = bearer_token(auth_header)
            .ok_or_else(|| AuthResponse::Unauthorized("Invalid Authorization format".into()))?;

        match authenticate_user(token, &auth_state.config).await {
            Ok(user) => {
                request.extensions_mut().insert(user);
                Ok(next.run(request).await)
            }
            Err(e) => Err(AuthResponse::Unauthorized(client_message(&e))),
        }
    }

    /// Extracts the credentials from a `Bearer` Authorization header value.
    ///
    /// Auth-scheme names are case-insensitive (RFC 7235 §2.1), so `bearer` is accepted too.
    /// Returns `None` for any other scheme or an empty token, which avoids a pointless call
    /// to Google.
    fn bearer_token(header_value: &str) -> Option<&str> {
        let (scheme, token) = header_value.trim().split_once(' ')?;
        if !scheme.eq_ignore_ascii_case("Bearer") {
            return None;
        }
        let token = token.trim();
        if token.is_empty() {
            None
        } else {
            Some(token)
        }
    }

    /// Chooses the message returned to the client for an authentication failure.
    ///
    /// Token and allow-list rejections are reported as-is. Server-side failures (users
    /// file, parsing, upstream network) get a generic message, so file paths and internal
    /// error text are not disclosed to unauthenticated callers.
    fn client_message(err: &AuthError) -> String {
        match err {
            AuthError::InvalidToken | AuthError::UserNotFound => err.to_string(),
            AuthError::FileError(_) | AuthError::ParseError(_) | AuthError::NetworkError(_) => {
                "Authentication failed".to_string()
            }
        }
    }

    /// Rejection produced by [`google_auth_middleware`].
    pub enum AuthResponse {
        Unauthorized(String),
    }

    impl IntoResponse for AuthResponse {
        fn into_response(self) -> Response {
            match self {
                AuthResponse::Unauthorized(msg) => {
                    (StatusCode::UNAUTHORIZED, msg).into_response()
                }
            }
        }
    }
}
375
#[cfg(feature = "axum")]
pub mod session_middleware {
    use super::*;
    use axum::{
        extract::{Request, State},
        http::{header, StatusCode},
        middleware::Next,
        response::{IntoResponse, Response},
    };
    use std::sync::Arc;

    /// State handed to [`session_auth_middleware`] through axum's `State` extractor.
    #[derive(Clone)]
    pub struct SessionState {
        pub store: Arc<SessionStore>,
        pub config: Arc<GoogleAuthConfig>,
    }

    /// Returns the value of the first `session_id` pair in a `Cookie` header value.
    fn find_session_cookie(cookie_header: &str) -> Option<String> {
        for pair in cookie_header.split(';') {
            if let Some((name, value)) = pair.trim().split_once('=') {
                if name == "session_id" {
                    return Some(value.to_string());
                }
            }
        }
        None
    }

    /// axum middleware that authenticates requests by their `session_id` cookie.
    ///
    /// A valid session is refreshed (sliding expiry). Then both its user and the session
    /// itself are inserted into the request extensions before the request is passed on.
    /// A missing cookie or an unknown/expired session produces `401 Unauthorized`.
    pub async fn session_auth_middleware(
        State(session_state): State<SessionState>,
        mut request: Request,
        next: Next,
    ) -> Result<Response, SessionAuthResponse> {
        let cookie_header = request
            .headers()
            .get(header::COOKIE)
            .and_then(|v| v.to_str().ok());

        let session_id = match cookie_header.and_then(find_session_cookie) {
            Some(id) => id,
            None => return Err(SessionAuthResponse::Unauthorized("No session cookie".into())),
        };

        let session = session_state
            .store
            .refresh_session(&session_id)
            .map_err(|_| SessionAuthResponse::Unauthorized("Invalid or expired session".into()))?;

        let user = session.user.clone();
        let extensions = request.extensions_mut();
        extensions.insert(user);
        extensions.insert(session);

        Ok(next.run(request).await)
    }

    /// Rejection produced by [`session_auth_middleware`].
    pub enum SessionAuthResponse {
        Unauthorized(String),
    }

    impl IntoResponse for SessionAuthResponse {
        fn into_response(self) -> Response {
            let SessionAuthResponse::Unauthorized(msg) = self;
            (StatusCode::UNAUTHORIZED, msg).into_response()
        }
    }
}
447
448pub fn create_sample_users_file(path: &str) -> Result<(), AuthError> {
454 let sample = AuthorizedUsers {
455 users: vec![
456 GoogleUser {
457 google_id: "123456789012345678901".to_string(),
458 email: "user@example.com".to_string(),
459 name: Some("Sample User".to_string()),
460 picture: None,
461 },
462 ],
463 };
464
465 let json = serde_json::to_string_pretty(&sample)
466 .map_err(|e| AuthError::ParseError(e.to_string()))?;
467
468 fs::write(path, json)
469 .map_err(|e| AuthError::FileError(e.to_string()))?;
470
471 Ok(())
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// A path in the system temp directory whose file is deleted on drop. A failing
    /// assertion therefore cannot leave fixtures behind, and the working directory is
    /// never touched.
    struct TempFile(PathBuf);

    impl TempFile {
        /// Builds a per-process path so concurrent test runs do not trample each other.
        fn new(name: &str) -> Self {
            TempFile(
                std::env::temp_dir()
                    .join(format!("google_auth_{}_{}", std::process::id(), name)),
            )
        }

        fn as_str(&self) -> &str {
            self.0.to_str().expect("temp dir path should be valid UTF-8")
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best effort: the file may legitimately not exist (see the missing-file test).
            let _ = fs::remove_file(&self.0);
        }
    }

    fn sample_user() -> GoogleUser {
        GoogleUser {
            google_id: "123".to_string(),
            email: "test@example.com".to_string(),
            name: None,
            picture: None,
        }
    }

    #[test]
    fn test_load_authorized_users() {
        let file = TempFile::new("test_users.json");
        create_sample_users_file(file.as_str()).unwrap();

        let users = load_authorized_users(file.as_str()).unwrap();
        assert_eq!(users.users.len(), 1);
        assert_eq!(users.users[0].email, "user@example.com");
    }

    #[test]
    fn test_load_authorized_users_missing_file() {
        let file = TempFile::new("missing_users.json");
        assert!(matches!(
            load_authorized_users(file.as_str()),
            Err(AuthError::FileError(_))
        ));
    }

    #[test]
    fn test_load_authorized_users_invalid_json() {
        let file = TempFile::new("invalid_users.json");
        fs::write(&file.0, "not json").unwrap();
        assert!(matches!(
            load_authorized_users(file.as_str()),
            Err(AuthError::ParseError(_))
        ));
    }

    #[test]
    fn test_is_user_authorized() {
        let users = AuthorizedUsers {
            users: vec![sample_user()],
        };

        assert!(is_user_authorized("123", "test@example.com", &users));
        assert!(is_user_authorized("123", "other@example.com", &users));
        assert!(is_user_authorized("999", "test@example.com", &users));
        assert!(!is_user_authorized("999", "other@example.com", &users));
    }

    #[test]
    fn test_session_lifecycle() {
        let store = SessionStore::new(60);

        let session = store.create_session(sample_user()).unwrap();
        assert!(!session.is_expired());
        assert_eq!(store.session_count(), 1);

        let retrieved = store.get_session(&session.session_id).unwrap();
        assert_eq!(retrieved.user.email, "test@example.com");

        let refreshed = store.refresh_session(&session.session_id).unwrap();
        assert_eq!(refreshed.session_id, session.session_id);

        assert_eq!(store.get_user_sessions("123").unwrap().len(), 1);
        assert_eq!(store.cleanup_expired().unwrap(), 0);

        store.remove_session(&session.session_id).unwrap();
        assert!(store.get_session(&session.session_id).is_err());
        assert_eq!(store.session_count(), 0);
    }
}
537}