//! Authentication Session Management
//!
//! Source file: `corevpn_auth/session.rs`.
2
3use std::collections::HashMap;
4use std::time::Duration;
5
6use chrono::{DateTime, Utc};
7use parking_lot::RwLock;
8use serde::{Deserialize, Serialize};
9use uuid::Uuid;
10
11use crate::{AuthError, AuthState, Result, TokenSet, UserInfo};
12
13/// Authentication session
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AuthSession {
16    /// Session ID
17    pub id: String,
18    /// Authentication state (for OAuth2 flow)
19    pub auth_state: Option<AuthState>,
20    /// Token set (after successful authentication)
21    pub tokens: Option<TokenSet>,
22    /// User information
23    pub user_info: Option<UserInfo>,
24    /// Provider name
25    pub provider: String,
26    /// Session creation time
27    pub created_at: DateTime<Utc>,
28    /// Session expiration time
29    pub expires_at: DateTime<Utc>,
30    /// Last activity time
31    pub last_activity: DateTime<Utc>,
32    /// Associated VPN session ID (if connected)
33    pub vpn_session_id: Option<String>,
34    /// IP address of the client
35    pub client_ip: Option<String>,
36    /// Additional metadata
37    pub metadata: HashMap<String, String>,
38}
39
40impl AuthSession {
41    /// Create a new auth session
42    pub fn new(provider: &str, lifetime: Duration) -> Self {
43        let now = Utc::now();
44        let auth_state = AuthState::new(Duration::from_secs(600)); // 10 min for OAuth flow
45
46        Self {
47            id: Uuid::new_v4().to_string(),
48            auth_state: Some(auth_state),
49            tokens: None,
50            user_info: None,
51            provider: provider.to_string(),
52            created_at: now,
53            expires_at: now + chrono::Duration::from_std(lifetime).unwrap(),
54            last_activity: now,
55            vpn_session_id: None,
56            client_ip: None,
57            metadata: HashMap::new(),
58        }
59    }
60
61    /// Check if session is expired
62    pub fn is_expired(&self) -> bool {
63        Utc::now() > self.expires_at
64    }
65
66    /// Check if session is authenticated
67    pub fn is_authenticated(&self) -> bool {
68        self.tokens.is_some() && self.user_info.is_some()
69    }
70
71    /// Check if tokens need refresh
72    pub fn needs_token_refresh(&self) -> bool {
73        if let Some(tokens) = &self.tokens {
74            tokens.expires_within(chrono::Duration::minutes(5))
75        } else {
76            false
77        }
78    }
79
80    /// Update tokens
81    pub fn update_tokens(&mut self, tokens: TokenSet) {
82        self.tokens = Some(tokens);
83        self.last_activity = Utc::now();
84    }
85
86    /// Update user info
87    pub fn update_user_info(&mut self, user_info: UserInfo) {
88        self.user_info = Some(user_info);
89        self.last_activity = Utc::now();
90    }
91
92    /// Mark authentication complete
93    pub fn complete_auth(&mut self, tokens: TokenSet, user_info: UserInfo) {
94        self.tokens = Some(tokens);
95        self.user_info = Some(user_info);
96        self.auth_state = None; // Clear auth state after successful auth
97        self.last_activity = Utc::now();
98    }
99
100    /// Associate with VPN session
101    pub fn associate_vpn_session(&mut self, vpn_session_id: &str) {
102        self.vpn_session_id = Some(vpn_session_id.to_string());
103        self.last_activity = Utc::now();
104    }
105
106    /// Extend session lifetime
107    pub fn extend(&mut self, duration: Duration) {
108        self.expires_at = Utc::now() + chrono::Duration::from_std(duration).unwrap();
109    }
110
111    /// Touch session (update last activity)
112    pub fn touch(&mut self) {
113        self.last_activity = Utc::now();
114    }
115
116    /// Get session duration
117    pub fn duration(&self) -> chrono::Duration {
118        Utc::now() - self.created_at
119    }
120
121    /// Get idle time
122    pub fn idle_time(&self) -> chrono::Duration {
123        Utc::now() - self.last_activity
124    }
125
126    /// Get the OAuth2 state value
127    pub fn state(&self) -> Option<&str> {
128        self.auth_state.as_ref().map(|s| s.state.as_str())
129    }
130
131    /// Get user email
132    pub fn email(&self) -> Option<&str> {
133        self.user_info.as_ref().and_then(|u| u.email.as_deref())
134    }
135
136    /// Get user display name
137    pub fn display_name(&self) -> Option<&str> {
138        self.user_info.as_ref().and_then(|u| u.name.as_deref())
139    }
140}
141
142/// Authentication session manager
143pub struct AuthSessionManager {
144    /// Sessions by ID
145    sessions: RwLock<HashMap<String, AuthSession>>,
146    /// Sessions by OAuth2 state
147    sessions_by_state: RwLock<HashMap<String, String>>,
148    /// Default session lifetime
149    default_lifetime: Duration,
150    /// Maximum sessions per user
151    max_sessions_per_user: usize,
152}
153
154impl AuthSessionManager {
155    /// Create a new session manager
156    pub fn new(default_lifetime: Duration, max_sessions_per_user: usize) -> Self {
157        Self {
158            sessions: RwLock::new(HashMap::new()),
159            sessions_by_state: RwLock::new(HashMap::new()),
160            default_lifetime,
161            max_sessions_per_user,
162        }
163    }
164
165    /// Create a new session
166    pub fn create_session(&self, provider: &str) -> AuthSession {
167        let session = AuthSession::new(provider, self.default_lifetime);
168
169        // Store session
170        let mut sessions = self.sessions.write();
171        let mut by_state = self.sessions_by_state.write();
172
173        if let Some(state) = session.state() {
174            by_state.insert(state.to_string(), session.id.clone());
175        }
176        sessions.insert(session.id.clone(), session.clone());
177
178        session
179    }
180
181    /// Get session by ID
182    pub fn get_session(&self, id: &str) -> Option<AuthSession> {
183        self.sessions.read().get(id).cloned()
184    }
185
186    /// Get session by OAuth2 state
187    pub fn get_session_by_state(&self, state: &str) -> Option<AuthSession> {
188        let session_id = self.sessions_by_state.read().get(state)?.clone();
189        self.get_session(&session_id)
190    }
191
192    /// Update session
193    pub fn update_session(&self, session: &AuthSession) -> Result<()> {
194        let mut sessions = self.sessions.write();
195        if sessions.contains_key(&session.id) {
196            sessions.insert(session.id.clone(), session.clone());
197            Ok(())
198        } else {
199            Err(AuthError::SessionNotFound)
200        }
201    }
202
203    /// Remove session
204    pub fn remove_session(&self, id: &str) -> Option<AuthSession> {
205        let mut sessions = self.sessions.write();
206        let mut by_state = self.sessions_by_state.write();
207
208        if let Some(session) = sessions.remove(id) {
209            if let Some(state) = session.state() {
210                by_state.remove(state);
211            }
212            Some(session)
213        } else {
214            None
215        }
216    }
217
218    /// Get all sessions for a user (by email)
219    pub fn get_user_sessions(&self, email: &str) -> Vec<AuthSession> {
220        self.sessions
221            .read()
222            .values()
223            .filter(|s| s.email() == Some(email))
224            .cloned()
225            .collect()
226    }
227
228    /// Remove all sessions for a user
229    pub fn remove_user_sessions(&self, email: &str) -> usize {
230        let mut sessions = self.sessions.write();
231        let mut by_state = self.sessions_by_state.write();
232
233        let to_remove: Vec<_> = sessions
234            .iter()
235            .filter(|(_, s)| s.email() == Some(email))
236            .map(|(id, s)| (id.clone(), s.state().map(String::from)))
237            .collect();
238
239        for (id, state) in &to_remove {
240            sessions.remove(id);
241            if let Some(s) = state {
242                by_state.remove(s);
243            }
244        }
245
246        to_remove.len()
247    }
248
249    /// Cleanup expired sessions
250    pub fn cleanup_expired(&self) -> usize {
251        let mut sessions = self.sessions.write();
252        let mut by_state = self.sessions_by_state.write();
253
254        let before = sessions.len();
255
256        let expired: Vec<_> = sessions
257            .iter()
258            .filter(|(_, s)| s.is_expired())
259            .map(|(id, s)| (id.clone(), s.state().map(String::from)))
260            .collect();
261
262        for (id, state) in &expired {
263            sessions.remove(id);
264            if let Some(s) = state {
265                by_state.remove(s);
266            }
267        }
268
269        before - sessions.len()
270    }
271
272    /// Get session count
273    pub fn session_count(&self) -> usize {
274        self.sessions.read().len()
275    }
276
277    /// Get all active sessions
278    pub fn active_sessions(&self) -> Vec<AuthSession> {
279        self.sessions
280            .read()
281            .values()
282            .filter(|s| !s.is_expired() && s.is_authenticated())
283            .cloned()
284            .collect()
285    }
286}
287
288impl Default for AuthSessionManager {
289    fn default() -> Self {
290        Self::new(Duration::from_secs(86400), 5) // 24 hours, 5 sessions per user
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Token set that stays valid for one hour.
    fn sample_tokens() -> TokenSet {
        TokenSet {
            access_token: "test-token".to_string(),
            refresh_token: Some("refresh-token".to_string()),
            id_token: None,
            expires_at: Utc::now() + chrono::Duration::hours(1),
            token_type: "Bearer".to_string(),
            scopes: vec![],
        }
    }

    /// Google user profile with the given email.
    fn sample_user(email: &str) -> UserInfo {
        UserInfo {
            sub: "user123".to_string(),
            email: Some(email.to_string()),
            email_verified: true,
            name: Some("Test User".to_string()),
            given_name: None,
            family_name: None,
            picture: None,
            groups: vec![],
            provider: "google".to_string(),
        }
    }

    #[test]
    fn test_auth_session() {
        let session = AuthSession::new("google", Duration::from_secs(3600));

        assert!(!session.is_expired());
        assert!(!session.is_authenticated());
        assert!(session.state().is_some());
        // No tokens yet, so nothing to refresh.
        assert!(!session.needs_token_refresh());
    }

    #[test]
    fn test_session_manager() {
        let manager = AuthSessionManager::default();

        let session = manager.create_session("google");
        let state = session.state().unwrap().to_string();

        // Get by ID
        assert!(manager.get_session(&session.id).is_some());

        // Get by state
        assert!(manager.get_session_by_state(&state).is_some());

        // Remove
        assert!(manager.remove_session(&session.id).is_some());
        assert!(manager.get_session(&session.id).is_none());
        assert!(manager.remove_session(&session.id).is_none());
    }

    #[test]
    fn test_session_lifecycle() {
        let manager = AuthSessionManager::default();
        let mut session = manager.create_session("google");

        // Simulate successful auth
        session.complete_auth(sample_tokens(), sample_user("user@example.com"));

        assert!(session.is_authenticated());
        assert!(session.auth_state.is_none()); // Cleared after auth
        assert_eq!(session.email(), Some("user@example.com"));
        assert_eq!(session.display_name(), Some("Test User"));

        // Persist and observe it through the manager.
        manager.update_session(&session).unwrap();
        assert_eq!(manager.active_sessions().len(), 1);
        assert_eq!(manager.get_user_sessions("user@example.com").len(), 1);

        assert_eq!(manager.remove_user_sessions("user@example.com"), 1);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_update_unknown_session_fails() {
        let manager = AuthSessionManager::default();
        let detached = AuthSession::new("google", Duration::from_secs(60));
        assert!(manager.update_session(&detached).is_err());
    }

    #[test]
    fn test_cleanup_expired() {
        let manager = AuthSessionManager::default();
        let mut session = manager.create_session("google");

        session.expires_at = Utc::now() - chrono::Duration::seconds(1);
        manager.update_session(&session).unwrap();

        assert_eq!(manager.cleanup_expired(), 1);
        assert_eq!(manager.session_count(), 0);
    }
}