Skip to main content

corevpn_auth/
session.rs

1//! Authentication Session Management
2
use std::collections::HashMap;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};

use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
use uuid::Uuid;

use crate::{AuthError, AuthState, Result, TokenSet, UserInfo};
13
14/// Rate limiter entry
15#[derive(Clone)]
16struct RateLimitEntry {
17    count: u32,
18    reset_at: SystemTime,
19}
20
21/// In-memory rate limiter for brute force protection
22pub struct RateLimiter {
23    /// Attempts per window
24    max_attempts: u32,
25    /// Window duration
26    window: Duration,
27    /// Entries by key (IP address, email, etc.)
28    entries: RwLock<HashMap<String, RateLimitEntry>>,
29}
30
31impl RateLimiter {
32    /// Create a new rate limiter
33    pub fn new(max_attempts: u32, window: Duration) -> Self {
34        Self {
35            max_attempts,
36            window,
37            entries: RwLock::new(HashMap::new()),
38        }
39    }
40
41    /// Check if key is rate limited
42    pub fn check(&self, key: &str) -> bool {
43        let mut entries = self.entries.write();
44        let now = SystemTime::now();
45
46        // Clean up expired entries
47        entries.retain(|_, entry| entry.reset_at > now);
48
49        let entry = entries.entry(key.to_string()).or_insert_with(|| {
50            RateLimitEntry {
51                count: 0,
52                reset_at: now + self.window,
53            }
54        });
55
56        // Check if reset time has passed
57        if now >= entry.reset_at {
58            entry.count = 0;
59            entry.reset_at = now + self.window;
60        }
61
62        entry.count += 1;
63        let allowed = entry.count <= self.max_attempts;
64
65        if !allowed {
66            warn!("Rate limit exceeded for key: {}", key);
67        }
68
69        allowed
70    }
71
72    /// Reset rate limit for a key
73    pub fn reset(&self, key: &str) {
74        self.entries.write().remove(key);
75    }
76
77    /// Clean up expired entries
78    pub fn cleanup(&self) {
79        let now = SystemTime::now();
80        self.entries.write().retain(|_, entry| entry.reset_at > now);
81    }
82}
83
84/// Authentication session
85#[derive(Debug, Clone, Serialize, Deserialize)]
86pub struct AuthSession {
87    /// Session ID
88    pub id: String,
89    /// Authentication state (for OAuth2 flow)
90    pub auth_state: Option<AuthState>,
91    /// Token set (after successful authentication)
92    pub tokens: Option<TokenSet>,
93    /// User information
94    pub user_info: Option<UserInfo>,
95    /// Provider name
96    pub provider: String,
97    /// Session creation time
98    pub created_at: DateTime<Utc>,
99    /// Session expiration time
100    pub expires_at: DateTime<Utc>,
101    /// Last activity time
102    pub last_activity: DateTime<Utc>,
103    /// Associated VPN session ID (if connected)
104    pub vpn_session_id: Option<String>,
105    /// IP address of the client
106    pub client_ip: Option<String>,
107    /// Additional metadata
108    pub metadata: HashMap<String, String>,
109}
110
111impl AuthSession {
112    /// Create a new auth session
113    pub fn new(provider: &str, lifetime: Duration) -> Self {
114        let now = Utc::now();
115        let auth_state = AuthState::new(Duration::from_secs(600)); // 10 min for OAuth flow
116
117        Self {
118            id: Uuid::new_v4().to_string(),
119            auth_state: Some(auth_state),
120            tokens: None,
121            user_info: None,
122            provider: provider.to_string(),
123            created_at: now,
124            expires_at: now + chrono::Duration::from_std(lifetime)
125                .unwrap_or_else(|_| chrono::Duration::seconds(86400)), // Fallback to 24 hours
126            last_activity: now,
127            vpn_session_id: None,
128            client_ip: None,
129            metadata: HashMap::new(),
130        }
131    }
132
133    /// Check if session is expired
134    pub fn is_expired(&self) -> bool {
135        Utc::now() > self.expires_at
136    }
137
138    /// Check if session is authenticated
139    pub fn is_authenticated(&self) -> bool {
140        self.tokens.is_some() && self.user_info.is_some()
141    }
142
143    /// Check if tokens need refresh
144    pub fn needs_token_refresh(&self) -> bool {
145        if let Some(tokens) = &self.tokens {
146            tokens.expires_within(chrono::Duration::minutes(5))
147        } else {
148            false
149        }
150    }
151
152    /// Update tokens
153    pub fn update_tokens(&mut self, tokens: TokenSet) {
154        self.tokens = Some(tokens);
155        self.last_activity = Utc::now();
156    }
157
158    /// Update user info
159    pub fn update_user_info(&mut self, user_info: UserInfo) {
160        self.user_info = Some(user_info);
161        self.last_activity = Utc::now();
162    }
163
164    /// Mark authentication complete
165    pub fn complete_auth(&mut self, tokens: TokenSet, user_info: UserInfo) {
166        self.tokens = Some(tokens);
167        self.user_info = Some(user_info);
168        self.auth_state = None; // Clear auth state after successful auth
169        self.last_activity = Utc::now();
170    }
171
172    /// Associate with VPN session
173    pub fn associate_vpn_session(&mut self, vpn_session_id: &str) {
174        self.vpn_session_id = Some(vpn_session_id.to_string());
175        self.last_activity = Utc::now();
176    }
177
178    /// Extend session lifetime
179    pub fn extend(&mut self, duration: Duration) {
180        self.expires_at = Utc::now() + chrono::Duration::from_std(duration).unwrap();
181    }
182
183    /// Touch session (update last activity)
184    pub fn touch(&mut self) {
185        self.last_activity = Utc::now();
186    }
187
188    /// Get session duration
189    pub fn duration(&self) -> chrono::Duration {
190        Utc::now() - self.created_at
191    }
192
193    /// Get idle time
194    pub fn idle_time(&self) -> chrono::Duration {
195        Utc::now() - self.last_activity
196    }
197
198    /// Get the OAuth2 state value
199    pub fn state(&self) -> Option<&str> {
200        self.auth_state.as_ref().map(|s| s.state.as_str())
201    }
202
203    /// Get user email
204    pub fn email(&self) -> Option<&str> {
205        self.user_info.as_ref().and_then(|u| u.email.as_deref())
206    }
207
208    /// Get user display name
209    pub fn display_name(&self) -> Option<&str> {
210        self.user_info.as_ref().and_then(|u| u.name.as_deref())
211    }
212}
213
214/// Authentication session manager
215pub struct AuthSessionManager {
216    /// Sessions by ID
217    sessions: RwLock<HashMap<String, AuthSession>>,
218    /// Sessions by OAuth2 state
219    sessions_by_state: RwLock<HashMap<String, String>>,
220    /// Default session lifetime
221    default_lifetime: Duration,
222    /// Maximum sessions per user
223    max_sessions_per_user: usize,
224    /// Rate limiter for session lookups
225    lookup_rate_limiter: RateLimiter,
226}
227
228impl AuthSessionManager {
229    /// Create a new session manager
230    pub fn new(default_lifetime: Duration, max_sessions_per_user: usize) -> Self {
231        Self {
232            sessions: RwLock::new(HashMap::new()),
233            sessions_by_state: RwLock::new(HashMap::new()),
234            default_lifetime,
235            max_sessions_per_user,
236            lookup_rate_limiter: RateLimiter::new(100, Duration::from_secs(60)), // 100 lookups per minute
237        }
238    }
239
240    /// Create a new session
241    pub fn create_session(&self, provider: &str) -> AuthSession {
242        let session = AuthSession::new(provider, self.default_lifetime);
243
244        // Store session
245        let mut sessions = self.sessions.write();
246        let mut by_state = self.sessions_by_state.write();
247
248        if let Some(state) = session.state() {
249            by_state.insert(state.to_string(), session.id.clone());
250        }
251        sessions.insert(session.id.clone(), session.clone());
252
253        session
254    }
255
256    /// Get session by ID (with rate limiting)
257    pub fn get_session(&self, id: &str, client_ip: Option<&str>) -> Option<AuthSession> {
258        // Rate limit lookups by IP or session ID
259        let rate_limit_key = client_ip.unwrap_or(id);
260        if !self.lookup_rate_limiter.check(rate_limit_key) {
261            warn!("Rate limit exceeded for session lookup: {}", rate_limit_key);
262            return None;
263        }
264
265        // Verify session ID is a valid UUID v4
266        if Uuid::parse_str(id).map(|u| u.get_version() != Some(uuid::Version::Random))
267            .unwrap_or(true) {
268            warn!("Invalid session ID format: {}", id);
269            return None;
270        }
271
272        self.sessions.read().get(id).cloned()
273    }
274
275    /// Get session by OAuth2 state (with rate limiting)
276    pub fn get_session_by_state(&self, state: &str, client_ip: Option<&str>) -> Option<AuthSession> {
277        // Rate limit lookups
278        let rate_limit_key = client_ip.unwrap_or(state);
279        if !self.lookup_rate_limiter.check(rate_limit_key) {
280            warn!("Rate limit exceeded for state lookup: {}", rate_limit_key);
281            return None;
282        }
283
284        let session_id = self.sessions_by_state.read().get(state)?.clone();
285        self.get_session(&session_id, client_ip)
286    }
287
288    /// Update session
289    pub fn update_session(&self, session: &AuthSession) -> Result<()> {
290        let mut sessions = self.sessions.write();
291        if sessions.contains_key(&session.id) {
292            sessions.insert(session.id.clone(), session.clone());
293            Ok(())
294        } else {
295            Err(AuthError::SessionNotFound)
296        }
297    }
298
299    /// Remove session
300    pub fn remove_session(&self, id: &str) -> Option<AuthSession> {
301        let mut sessions = self.sessions.write();
302        let mut by_state = self.sessions_by_state.write();
303
304        if let Some(session) = sessions.remove(id) {
305            if let Some(state) = session.state() {
306                by_state.remove(state);
307            }
308            Some(session)
309        } else {
310            None
311        }
312    }
313
314    /// Get all sessions for a user (by email) - with rate limiting
315    pub fn get_user_sessions(&self, email: &str, client_ip: Option<&str>) -> Vec<AuthSession> {
316        // Rate limit user session lookups
317        let rate_limit_key = client_ip.unwrap_or(email);
318        if !self.lookup_rate_limiter.check(rate_limit_key) {
319            warn!("Rate limit exceeded for user session lookup: {}", rate_limit_key);
320            return Vec::new();
321        }
322
323        self.sessions
324            .read()
325            .values()
326            .filter(|s| s.email() == Some(email))
327            .cloned()
328            .collect()
329    }
330
331    /// Remove all sessions for a user
332    pub fn remove_user_sessions(&self, email: &str) -> usize {
333        let mut sessions = self.sessions.write();
334        let mut by_state = self.sessions_by_state.write();
335
336        let to_remove: Vec<_> = sessions
337            .iter()
338            .filter(|(_, s)| s.email() == Some(email))
339            .map(|(id, s)| (id.clone(), s.state().map(String::from)))
340            .collect();
341
342        for (id, state) in &to_remove {
343            sessions.remove(id);
344            if let Some(s) = state {
345                by_state.remove(s);
346            }
347        }
348
349        to_remove.len()
350    }
351
352    /// Cleanup expired sessions
353    pub fn cleanup_expired(&self) -> usize {
354        let mut sessions = self.sessions.write();
355        let mut by_state = self.sessions_by_state.write();
356
357        let before = sessions.len();
358
359        let expired: Vec<_> = sessions
360            .iter()
361            .filter(|(_, s)| s.is_expired())
362            .map(|(id, s)| (id.clone(), s.state().map(String::from)))
363            .collect();
364
365        for (id, state) in &expired {
366            sessions.remove(id);
367            if let Some(s) = state {
368                by_state.remove(s);
369            }
370        }
371
372        before - sessions.len()
373    }
374
375    /// Get session count
376    pub fn session_count(&self) -> usize {
377        self.sessions.read().len()
378    }
379
380    /// Get all active sessions
381    pub fn active_sessions(&self) -> Vec<AuthSession> {
382        self.sessions
383            .read()
384            .values()
385            .filter(|s| !s.is_expired() && s.is_authenticated())
386            .cloned()
387            .collect()
388    }
389}
390
391impl Default for AuthSessionManager {
392    fn default() -> Self {
393        Self::new(Duration::from_secs(86400), 5) // 24 hours, 5 sessions per user
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auth_session() {
        let session = AuthSession::new("google", Duration::from_secs(3600));

        assert!(!session.is_expired());
        assert!(!session.is_authenticated());
        assert!(session.state().is_some());
    }

    #[test]
    fn test_session_manager() {
        let manager = AuthSessionManager::default();

        let session = manager.create_session("google");
        let state = session.state().unwrap().to_string();

        // Get by ID
        let found = manager.get_session(&session.id, None);
        assert!(found.is_some());

        // Get by state
        let found = manager.get_session_by_state(&state, None);
        assert!(found.is_some());

        // Remove
        manager.remove_session(&session.id);
        assert!(manager.get_session(&session.id, None).is_none());
    }

    #[test]
    fn test_get_session_rejects_non_v4_ids() {
        let manager = AuthSessionManager::default();

        assert!(manager.get_session("not-a-uuid", None).is_none());
        // The nil UUID parses but is not version 4.
        assert!(manager
            .get_session("00000000-0000-0000-0000-000000000000", None)
            .is_none());
    }

    #[test]
    fn test_update_unknown_session_fails() {
        let manager = AuthSessionManager::default();
        let orphan = AuthSession::new("google", Duration::from_secs(3600));

        assert!(manager.update_session(&orphan).is_err());
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_rate_limiter_blocks_after_max_attempts() {
        let limiter = RateLimiter::new(3, Duration::from_secs(60));

        assert!(limiter.check("10.0.0.1"));
        assert!(limiter.check("10.0.0.1"));
        assert!(limiter.check("10.0.0.1"));
        assert!(!limiter.check("10.0.0.1"));

        // Keys are tracked independently.
        assert!(limiter.check("10.0.0.2"));

        // Resetting a key restores its budget.
        limiter.reset("10.0.0.1");
        assert!(limiter.check("10.0.0.1"));
    }

    #[test]
    fn test_rate_limiter_zero_attempts_rejects_everything() {
        let limiter = RateLimiter::new(0, Duration::from_secs(60));
        assert!(!limiter.check("10.0.0.1"));
    }

    #[test]
    fn test_session_lifecycle() {
        let manager = AuthSessionManager::default();

        let mut session = manager.create_session("google");

        // Simulate successful auth
        let tokens = TokenSet {
            access_token: "test-token".to_string(),
            refresh_token: Some("refresh-token".to_string()),
            id_token: None,
            expires_at: Utc::now() + chrono::Duration::hours(1),
            token_type: "Bearer".to_string(),
            scopes: vec![],
        };

        let user_info = UserInfo {
            sub: "user123".to_string(),
            email: Some("user@example.com".to_string()),
            email_verified: true,
            name: Some("Test User".to_string()),
            given_name: None,
            family_name: None,
            picture: None,
            groups: vec![],
            provider: "google".to_string(),
        };

        session.complete_auth(tokens, user_info);

        assert!(session.is_authenticated());
        assert!(session.auth_state.is_none()); // Cleared after auth
        assert_eq!(session.email(), Some("user@example.com"));
    }
}