google_auth_middleware/
lib.rs

// ============================================================
// Google Auth Middleware Crate
// lib.rs - Complete and Ready to Compile
// ============================================================
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fs;
9use std::path::Path;
10use std::sync::RwLock;
11use std::time::{Duration, SystemTime};
12
// ============================================================
// Error Types
// ============================================================
16
/// Failures reported by the authentication and session APIs in this crate.
#[derive(Debug, Clone)]
pub enum AuthError {
    /// The Google token was rejected, or a session ID is unknown or expired.
    InvalidToken,
    /// The token is valid but the user is not in the authorized-users list.
    UserNotFound,
    /// A file could not be read or written; carries a description.
    FileError(String),
    /// Data could not be parsed or serialized; carries a description.
    /// The session store also uses this variant for lock-acquisition failures.
    ParseError(String),
    /// The HTTP request to Google failed; carries a description.
    NetworkError(String),
}

impl std::fmt::Display for AuthError {
    /// Writes the human-readable message for this error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use AuthError::*;

        // Variants without a payload are fixed strings; the rest share one
        // "<Category> error: <detail>" layout.
        let (category, detail) = match self {
            InvalidToken => return f.write_str("Invalid Google token"),
            UserNotFound => return f.write_str("User not authorized"),
            FileError(detail) => ("File", detail),
            ParseError(detail) => ("Parse", detail),
            NetworkError(detail) => ("Network", detail),
        };
        write!(f, "{} error: {}", category, detail)
    }
}

impl std::error::Error for AuthError {}
39
// ============================================================
// Data Structures
// ============================================================
43
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct GoogleUser {
46    pub google_id: String,
47    pub email: String,
48    pub name: Option<String>,
49    pub picture: Option<String>,
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct AuthorizedUsers {
54    pub users: Vec<GoogleUser>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct GoogleTokenInfo {
59    pub sub: String,           // Google ID
60    pub email: String,
61    pub email_verified: bool,
62    pub name: Option<String>,
63    pub picture: Option<String>,
64}
65
// ============================================================
// Auth Config
// ============================================================
69
/// Settings needed to authenticate a user against Google and the local
/// authorized-users file.
///
/// Derives `Debug` and `Clone` so the configuration can be logged and shared
/// by value as well as behind an `Arc`.
#[derive(Debug, Clone)]
pub struct GoogleAuthConfig {
    /// Path of the JSON file holding the authorized users.
    pub users_file_path: String,
    /// OAuth client ID of this application.
    pub google_client_id: String,
}

impl GoogleAuthConfig {
    /// Creates a configuration from the users-file path and the Google client ID.
    pub fn new(users_file_path: String, google_client_id: String) -> Self {
        Self {
            users_file_path,
            google_client_id,
        }
    }
}
83
// ============================================================
// Core Auth Functions
// ============================================================
87
88/// Load authorized users from JSON file
89pub fn load_authorized_users(path: &str) -> Result<AuthorizedUsers, AuthError> {
90    let file_path = Path::new(path);
91    
92    if !file_path.exists() {
93        return Err(AuthError::FileError(format!("File not found: {}", path)));
94    }
95
96    let content = fs::read_to_string(file_path)
97        .map_err(|e| AuthError::FileError(e.to_string()))?;
98
99    let users: AuthorizedUsers = serde_json::from_str(&content)
100        .map_err(|e| AuthError::ParseError(e.to_string()))?;
101
102    Ok(users)
103}
104
105/// Check if user is authorized
106pub fn is_user_authorized(
107    google_id: &str,
108    email: &str,
109    authorized_users: &AuthorizedUsers
110) -> bool {
111    authorized_users.users.iter().any(|user| {
112        user.google_id == google_id || user.email == email
113    })
114}
115
116/// Verify Google ID token and get user info
117pub async fn verify_google_token(
118    id_token: &str,
119    _client_id: &str,
120) -> Result<GoogleTokenInfo, AuthError> {
121    let url = format!(
122        "https://oauth2.googleapis.com/tokeninfo?id_token={}",
123        id_token
124    );
125
126    let response = reqwest::get(&url)
127        .await
128        .map_err(|e| AuthError::NetworkError(e.to_string()))?;
129
130    if !response.status().is_success() {
131        return Err(AuthError::InvalidToken);
132    }
133
134    let token_info: GoogleTokenInfo = response
135        .json()
136        .await
137        .map_err(|e| AuthError::ParseError(e.to_string()))?;
138
139    Ok(token_info)
140}
141
142/// Main authentication function
143pub async fn authenticate_user(
144    id_token: &str,
145    config: &GoogleAuthConfig,
146) -> Result<GoogleUser, AuthError> {
147    // 1. Verify token with Google
148    let token_info = verify_google_token(id_token, &config.google_client_id).await?;
149
150    // 2. Load authorized users
151    let authorized_users = load_authorized_users(&config.users_file_path)?;
152
153    // 3. Check if user is authorized
154    if !is_user_authorized(&token_info.sub, &token_info.email, &authorized_users) {
155        return Err(AuthError::UserNotFound);
156    }
157
158    // 4. Return authorized user info
159    Ok(GoogleUser {
160        google_id: token_info.sub,
161        email: token_info.email,
162        name: token_info.name,
163        picture: token_info.picture,
164    })
165}
166
// ============================================================
// Session Management
// ============================================================
170
171#[derive(Debug, Clone, Serialize, Deserialize)]
172pub struct Session {
173    pub session_id: String,
174    pub user: GoogleUser,
175    pub created_at: SystemTime,
176    pub expires_at: SystemTime,
177    pub last_accessed: SystemTime,
178}
179
180impl Session {
181    pub fn new(session_id: String, user: GoogleUser, ttl_seconds: u64) -> Self {
182        let now = SystemTime::now();
183        Self {
184            session_id,
185            user,
186            created_at: now,
187            expires_at: now + Duration::from_secs(ttl_seconds),
188            last_accessed: now,
189        }
190    }
191
192    pub fn is_expired(&self) -> bool {
193        SystemTime::now() > self.expires_at
194    }
195
196    pub fn refresh(&mut self, ttl_seconds: u64) {
197        let now = SystemTime::now();
198        self.last_accessed = now;
199        self.expires_at = now + Duration::from_secs(ttl_seconds);
200    }
201}
202
203pub struct SessionStore {
204    sessions: RwLock<HashMap<String, Session>>,
205    ttl_seconds: u64,
206}
207
208impl SessionStore {
209    pub fn new(ttl_seconds: u64) -> Self {
210        Self {
211            sessions: RwLock::new(HashMap::new()),
212            ttl_seconds,
213        }
214    }
215
216    /// Generate a unique session ID
217    pub fn generate_session_id() -> String {
218        use std::time::UNIX_EPOCH;
219        let timestamp = SystemTime::now()
220            .duration_since(UNIX_EPOCH)
221            .unwrap()
222            .as_nanos();
223        format!("{:x}", timestamp)
224    }
225
226    /// Create a new session
227    pub fn create_session(&self, user: GoogleUser) -> Result<Session, AuthError> {
228        let session_id = Self::generate_session_id();
229        let session = Session::new(session_id.clone(), user, self.ttl_seconds);
230
231        let mut sessions = self.sessions.write()
232            .map_err(|_| AuthError::ParseError("Failed to acquire write lock".to_string()))?;
233        
234        sessions.insert(session_id.clone(), session.clone());
235        
236        Ok(session)
237    }
238
239    /// Get session by ID
240    pub fn get_session(&self, session_id: &str) -> Result<Session, AuthError> {
241        let sessions = self.sessions.read()
242            .map_err(|_| AuthError::ParseError("Failed to acquire read lock".to_string()))?;
243
244        let session = sessions.get(session_id)
245            .ok_or(AuthError::InvalidToken)?;
246
247        if session.is_expired() {
248            drop(sessions);
249            self.remove_session(session_id)?;
250            return Err(AuthError::InvalidToken);
251        }
252
253        Ok(session.clone())
254    }
255
256    /// Refresh session (extend expiration)
257    pub fn refresh_session(&self, session_id: &str) -> Result<Session, AuthError> {
258        let mut sessions = self.sessions.write()
259            .map_err(|_| AuthError::ParseError("Failed to acquire write lock".to_string()))?;
260
261        let session = sessions.get_mut(session_id)
262            .ok_or(AuthError::InvalidToken)?;
263
264        if session.is_expired() {
265            return Err(AuthError::InvalidToken);
266        }
267
268        session.refresh(self.ttl_seconds);
269        Ok(session.clone())
270    }
271
272    /// Remove session (logout)
273    pub fn remove_session(&self, session_id: &str) -> Result<(), AuthError> {
274        let mut sessions = self.sessions.write()
275            .map_err(|_| AuthError::ParseError("Failed to acquire write lock".to_string()))?;
276
277        sessions.remove(session_id);
278        Ok(())
279    }
280
281    /// Clean up expired sessions
282    pub fn cleanup_expired(&self) -> Result<usize, AuthError> {
283        let mut sessions = self.sessions.write()
284            .map_err(|_| AuthError::ParseError("Failed to acquire write lock".to_string()))?;
285
286        let before_count = sessions.len();
287        sessions.retain(|_, session| !session.is_expired());
288        let removed = before_count - sessions.len();
289
290        Ok(removed)
291    }
292
293    /// Get all active sessions for a user
294    pub fn get_user_sessions(&self, google_id: &str) -> Result<Vec<Session>, AuthError> {
295        let sessions = self.sessions.read()
296            .map_err(|_| AuthError::ParseError("Failed to acquire read lock".to_string()))?;
297
298        let user_sessions: Vec<Session> = sessions.values()
299            .filter(|s| s.user.google_id == google_id && !s.is_expired())
300            .cloned()
301            .collect();
302
303        Ok(user_sessions)
304    }
305
306    /// Get session count
307    pub fn session_count(&self) -> usize {
308        self.sessions.read()
309            .map(|s| s.len())
310            .unwrap_or(0)
311    }
312}
313
// ============================================================
// Axum Middleware (Optional)
// ============================================================
317
/// Axum middleware that authenticates each request from a Google ID token
/// sent as a bearer token. Compiled only with the `axum` feature.
#[cfg(feature = "axum")]
pub mod axum_middleware {
    use super::*;
    use axum::{
        extract::{Request, State},
        http::StatusCode,
        middleware::Next,
        response::{IntoResponse, Response},
    };
    use std::sync::Arc;

    /// Shared state handed to [`google_auth_middleware`].
    #[derive(Clone)]
    pub struct AuthState {
        /// Authentication settings used for every request.
        pub config: Arc<GoogleAuthConfig>,
    }

    /// Authenticates the request from its `Authorization: Bearer <token>`
    /// header.
    ///
    /// On success the authenticated [`GoogleUser`] is inserted into the
    /// request extensions and the request is passed to `next`. Any failure
    /// yields a `401 Unauthorized` response.
    ///
    /// Only the "invalid token" and "user not authorized" messages are sent
    /// to the client. File, parse and network failures get a generic message
    /// so that paths and upstream error details are not exposed.
    pub async fn google_auth_middleware(
        State(auth_state): State<AuthState>,
        mut request: Request,
        next: Next,
    ) -> Result<Response, AuthResponse> {
        // Extract token from Authorization header
        let auth_header = request
            .headers()
            .get("Authorization")
            .and_then(|v| v.to_str().ok())
            .ok_or_else(|| AuthResponse::Unauthorized("Missing Authorization header".into()))?;

        let token = auth_header
            .strip_prefix("Bearer ")
            .ok_or_else(|| AuthResponse::Unauthorized("Invalid Authorization format".into()))?;

        // Authenticate user
        match authenticate_user(token, &auth_state.config).await {
            Ok(user) => {
                // Add user info to request extensions
                request.extensions_mut().insert(user);
                Ok(next.run(request).await)
            }
            Err(err @ (AuthError::InvalidToken | AuthError::UserNotFound)) => {
                Err(AuthResponse::Unauthorized(err.to_string()))
            }
            // Internal failures: keep the details out of the response body.
            Err(
                AuthError::FileError(_) | AuthError::ParseError(_) | AuthError::NetworkError(_),
            ) => Err(AuthResponse::Unauthorized("Authentication failed".into())),
        }
    }

    /// Rejection returned by [`google_auth_middleware`].
    pub enum AuthResponse {
        /// Respond with `401 Unauthorized` and the given message as the body.
        Unauthorized(String),
    }

    impl IntoResponse for AuthResponse {
        fn into_response(self) -> Response {
            match self {
                AuthResponse::Unauthorized(msg) => {
                    (StatusCode::UNAUTHORIZED, msg).into_response()
                }
            }
        }
    }
}
375
// ============================================================
// Axum Session Middleware
// ============================================================
379
/// Axum middleware that authenticates requests from a `session_id` cookie.
/// Compiled only with the `axum` feature.
#[cfg(feature = "axum")]
pub mod session_middleware {
    use super::*;
    use axum::{
        extract::{Request, State},
        http::{header, StatusCode},
        middleware::Next,
        response::{IntoResponse, Response},
    };
    use std::sync::Arc;

    /// Shared state handed to [`session_auth_middleware`].
    #[derive(Clone)]
    pub struct SessionState {
        /// Store in which sessions are looked up and refreshed.
        pub store: Arc<SessionStore>,
        /// Authentication settings (not read by the middleware itself).
        pub config: Arc<GoogleAuthConfig>,
    }

    /// Requires a valid `session_id` cookie on the request.
    ///
    /// The session is refreshed in the store. The session and its user are
    /// then inserted into the request extensions before the request is passed
    /// to `next`. A missing cookie, or an unknown or expired session, yields
    /// `401 Unauthorized`.
    pub async fn session_auth_middleware(
        State(session_state): State<SessionState>,
        mut request: Request,
        next: Next,
    ) -> Result<Response, SessionAuthResponse> {
        // Scan the Cookie header for the first `session_id=<value>` pair.
        let cookie_header = request
            .headers()
            .get(header::COOKIE)
            .and_then(|raw| raw.to_str().ok());
        let session_id = cookie_header
            .and_then(|cookies| {
                cookies
                    .split(';')
                    .find_map(|pair| pair.trim().strip_prefix("session_id="))
            })
            .map(str::to_string)
            .ok_or_else(|| SessionAuthResponse::Unauthorized("No session cookie".into()))?;

        // Refreshing both validates the session and extends its expiry.
        let session = session_state
            .store
            .refresh_session(&session_id)
            .map_err(|_| SessionAuthResponse::Unauthorized("Invalid or expired session".into()))?;

        // Expose the user and the session to downstream handlers.
        let extensions = request.extensions_mut();
        extensions.insert(session.user.clone());
        extensions.insert(session);

        Ok(next.run(request).await)
    }

    /// Rejection returned by [`session_auth_middleware`].
    pub enum SessionAuthResponse {
        /// Respond with `401 Unauthorized` and the given message as the body.
        Unauthorized(String),
    }

    impl IntoResponse for SessionAuthResponse {
        fn into_response(self) -> Response {
            let SessionAuthResponse::Unauthorized(msg) = self;
            (StatusCode::UNAUTHORIZED, msg).into_response()
        }
    }
}
447
// ============================================================
// Helper Functions
// ============================================================
451
452/// Create a sample users.json file
453pub fn create_sample_users_file(path: &str) -> Result<(), AuthError> {
454    let sample = AuthorizedUsers {
455        users: vec![
456            GoogleUser {
457                google_id: "123456789012345678901".to_string(),
458                email: "user@example.com".to_string(),
459                name: Some("Sample User".to_string()),
460                picture: None,
461            },
462        ],
463    };
464
465    let json = serde_json::to_string_pretty(&sample)
466        .map_err(|e| AuthError::ParseError(e.to_string()))?;
467
468    fs::write(path, json)
469        .map_err(|e| AuthError::FileError(e.to_string()))?;
470
471    Ok(())
472}
473
// ============================================================
// Tests
// ============================================================
477
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips the sample users file through write and load.
    #[test]
    fn test_load_authorized_users() {
        // Write into the system temp directory under a per-process name so
        // the test does not touch the working directory.
        let test_path = std::env::temp_dir().join(format!(
            "google_auth_test_users_{}.json",
            std::process::id()
        ));
        let test_path = test_path.to_string_lossy().into_owned();
        create_sample_users_file(&test_path).unwrap();

        // Load, then clean up before asserting so a failed assertion does
        // not leave the file behind.
        let loaded = load_authorized_users(&test_path);
        let _ = fs::remove_file(&test_path);

        let users = loaded.unwrap();
        assert_eq!(users.users.len(), 1);
        assert_eq!(users.users[0].email, "user@example.com");
    }

    /// A user matches by ID or email; an unrelated identity does not.
    #[test]
    fn test_is_user_authorized() {
        let users = AuthorizedUsers {
            users: vec![GoogleUser {
                google_id: "123".to_string(),
                email: "test@example.com".to_string(),
                name: None,
                picture: None,
            }],
        };

        assert!(is_user_authorized("123", "test@example.com", &users));
        assert!(!is_user_authorized("999", "other@example.com", &users));
    }

    /// Create, fetch and remove a session.
    #[test]
    fn test_session_lifecycle() {
        let store = SessionStore::new(60);

        let user = GoogleUser {
            google_id: "123".to_string(),
            email: "test@example.com".to_string(),
            name: None,
            picture: None,
        };

        // Create session
        let session = store.create_session(user).unwrap();
        assert!(!session.is_expired());

        // Get session
        let retrieved = store.get_session(&session.session_id).unwrap();
        assert_eq!(retrieved.user.email, "test@example.com");

        // Remove session
        store.remove_session(&session.session_id).unwrap();
        assert!(store.get_session(&session.session_id).is_err());
    }
}