corevpn_core/
session.rs

1//! VPN Session management
2
3use chrono::{DateTime, Duration, Utc};
4use serde::{Deserialize, Serialize};
5use uuid::Uuid;
6
7use crate::{CoreError, Result, UserId, VpnAddress};
8
9/// Unique session identifier
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct SessionId(Uuid);
12
13impl SessionId {
14    /// Generate a new random session ID
15    pub fn new() -> Self {
16        Self(Uuid::new_v4())
17    }
18
19    /// Create from raw bytes
20    pub fn from_bytes(bytes: [u8; 16]) -> Self {
21        Self(Uuid::from_bytes(bytes))
22    }
23
24    /// Get the raw bytes
25    pub fn as_bytes(&self) -> &[u8; 16] {
26        self.0.as_bytes()
27    }
28}
29
30impl Default for SessionId {
31    fn default() -> Self {
32        Self::new()
33    }
34}
35
36impl std::fmt::Display for SessionId {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "{}", self.0)
39    }
40}
41
/// Lifecycle state of a [`Session`].
///
/// `Session::transition` accepts only these moves:
/// `Connecting -> Handshaking -> Authenticating -> Active -> Disconnecting`,
/// plus a move from any state to `Terminated`. The `state` field is public, so
/// code that assigns it directly bypasses that check.
///
/// NOTE(review): variant order is significant for serde formats that encode
/// enum variants by index; do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SessionState {
    /// Initial state of a new session; the only forward move is to `Handshaking`.
    Connecting,
    /// TLS handshake in progress
    Handshaking,
    /// Authenticating (OAuth2 flow or certificate)
    Authenticating,
    /// Fully established; `Session::is_active` additionally requires the session to be unexpired
    Active,
    /// Graceful disconnection in progress; the only remaining move is to `Terminated`
    Disconnecting,
    /// Session terminated; `Session::transition` allows no move out of this state
    Terminated,
}
58
59/// VPN Session
60#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct Session {
62    /// Unique session ID
63    pub id: SessionId,
64    /// Associated user (if authenticated)
65    pub user_id: Option<UserId>,
66    /// Session state
67    pub state: SessionState,
68    /// Assigned VPN IP address
69    pub vpn_address: Option<VpnAddress>,
70    /// Client's real IP address
71    pub client_ip: std::net::IpAddr,
72    /// Client's port
73    pub client_port: u16,
74    /// Session creation time
75    pub created_at: DateTime<Utc>,
76    /// Last activity time
77    pub last_activity: DateTime<Utc>,
78    /// Session expiration time
79    pub expires_at: DateTime<Utc>,
80    /// Bytes received from client
81    pub bytes_rx: u64,
82    /// Bytes sent to client
83    pub bytes_tx: u64,
84    /// Packets received from client
85    pub packets_rx: u64,
86    /// Packets sent to client
87    pub packets_tx: u64,
88    /// Client user agent / version
89    pub client_version: Option<String>,
90    /// OAuth2 access token (if using OAuth2 auth)
91    #[serde(skip)]
92    pub oauth_token: Option<String>,
93}
94
95impl Session {
96    /// Create a new session
97    pub fn new(client_ip: std::net::IpAddr, client_port: u16, lifetime: Duration) -> Self {
98        let now = Utc::now();
99        Self {
100            id: SessionId::new(),
101            user_id: None,
102            state: SessionState::Connecting,
103            vpn_address: None,
104            client_ip,
105            client_port,
106            created_at: now,
107            last_activity: now,
108            expires_at: now + lifetime,
109            bytes_rx: 0,
110            bytes_tx: 0,
111            packets_rx: 0,
112            packets_tx: 0,
113            client_version: None,
114            oauth_token: None,
115        }
116    }
117
118    /// Check if session is expired
119    pub fn is_expired(&self) -> bool {
120        Utc::now() > self.expires_at
121    }
122
123    /// Check if session is active
124    pub fn is_active(&self) -> bool {
125        self.state == SessionState::Active && !self.is_expired()
126    }
127
128    /// Update last activity timestamp
129    pub fn touch(&mut self) {
130        self.last_activity = Utc::now();
131    }
132
133    /// Record received data
134    pub fn record_rx(&mut self, bytes: u64) {
135        self.bytes_rx += bytes;
136        self.packets_rx += 1;
137        self.touch();
138    }
139
140    /// Record sent data
141    pub fn record_tx(&mut self, bytes: u64) {
142        self.bytes_tx += bytes;
143        self.packets_tx += 1;
144        self.touch();
145    }
146
147    /// Transition to a new state
148    pub fn transition(&mut self, new_state: SessionState) -> Result<()> {
149        use SessionState::*;
150
151        // Validate state transitions
152        let valid = match (self.state, new_state) {
153            (Connecting, Handshaking) => true,
154            (Handshaking, Authenticating) => true,
155            (Authenticating, Active) => true,
156            (Active, Disconnecting) => true,
157            (_, Terminated) => true, // Can always terminate
158            _ => false,
159        };
160
161        if valid {
162            self.state = new_state;
163            Ok(())
164        } else {
165            Err(CoreError::Internal(format!(
166                "Invalid state transition: {:?} -> {:?}",
167                self.state, new_state
168            )))
169        }
170    }
171
172    /// Get session duration
173    pub fn duration(&self) -> Duration {
174        Utc::now() - self.created_at
175    }
176
177    /// Get idle time since last activity
178    pub fn idle_time(&self) -> Duration {
179        Utc::now() - self.last_activity
180    }
181
182    /// Extend session expiration
183    pub fn extend(&mut self, duration: Duration) {
184        self.expires_at = Utc::now() + duration;
185    }
186}
187
188/// Session manager for tracking active sessions
189pub struct SessionManager {
190    sessions: parking_lot::RwLock<std::collections::HashMap<SessionId, Session>>,
191    max_sessions: usize,
192    default_lifetime: Duration,
193}
194
195impl SessionManager {
196    /// Create a new session manager
197    pub fn new(max_sessions: usize, default_lifetime: Duration) -> Self {
198        Self {
199            sessions: parking_lot::RwLock::new(std::collections::HashMap::new()),
200            max_sessions,
201            default_lifetime,
202        }
203    }
204
205    /// Create a new session
206    pub fn create_session(
207        &self,
208        client_ip: std::net::IpAddr,
209        client_port: u16,
210    ) -> Result<Session> {
211        let mut sessions = self.sessions.write();
212
213        // Check capacity
214        if sessions.len() >= self.max_sessions {
215            // Try to clean up expired sessions
216            sessions.retain(|_, s| !s.is_expired());
217
218            if sessions.len() >= self.max_sessions {
219                return Err(CoreError::Internal("Maximum sessions reached".into()));
220            }
221        }
222
223        let session = Session::new(client_ip, client_port, self.default_lifetime);
224        sessions.insert(session.id, session.clone());
225
226        Ok(session)
227    }
228
229    /// Get a session by ID
230    pub fn get_session(&self, id: &SessionId) -> Option<Session> {
231        self.sessions.read().get(id).cloned()
232    }
233
234    /// Update a session
235    pub fn update_session(&self, session: Session) -> Result<()> {
236        use std::collections::hash_map::Entry;
237        let mut sessions = self.sessions.write();
238        match sessions.entry(session.id) {
239            Entry::Occupied(mut e) => {
240                e.insert(session);
241                Ok(())
242            }
243            Entry::Vacant(_) => Err(CoreError::SessionNotFound(session.id.to_string())),
244        }
245    }
246
247    /// Remove a session
248    pub fn remove_session(&self, id: &SessionId) -> Option<Session> {
249        self.sessions.write().remove(id)
250    }
251
252    /// Get all active sessions
253    pub fn active_sessions(&self) -> Vec<Session> {
254        self.sessions
255            .read()
256            .values()
257            .filter(|s| s.is_active())
258            .cloned()
259            .collect()
260    }
261
262    /// Get session count
263    pub fn session_count(&self) -> usize {
264        self.sessions.read().len()
265    }
266
267    /// Clean up expired sessions
268    pub fn cleanup_expired(&self) -> usize {
269        let mut sessions = self.sessions.write();
270        let before = sessions.len();
271        sessions.retain(|_, s| !s.is_expired());
272        before - sessions.len()
273    }
274
275    /// Get sessions by user ID
276    pub fn get_user_sessions(&self, user_id: &UserId) -> Vec<Session> {
277        self.sessions
278            .read()
279            .values()
280            .filter(|s| s.user_id.as_ref() == Some(user_id))
281            .cloned()
282            .collect()
283    }
284
285    /// Terminate all sessions for a user
286    pub fn terminate_user_sessions(&self, user_id: &UserId) -> usize {
287        let mut sessions = self.sessions.write();
288        let mut count = 0;
289
290        for session in sessions.values_mut() {
291            if session.user_id.as_ref() == Some(user_id) {
292                session.state = SessionState::Terminated;
293                count += 1;
294            }
295        }
296
297        count
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_lifecycle() {
        let mut session = Session::new(
            "192.168.1.1".parse().unwrap(),
            12345,
            Duration::hours(1),
        );

        assert_eq!(session.state, SessionState::Connecting);
        assert!(!session.is_expired());

        session.transition(SessionState::Handshaking).unwrap();
        session.transition(SessionState::Authenticating).unwrap();
        session.transition(SessionState::Active).unwrap();

        assert!(session.is_active());

        session.record_rx(1000);
        session.record_tx(500);

        assert_eq!(session.bytes_rx, 1000);
        assert_eq!(session.bytes_tx, 500);
        assert_eq!(session.packets_rx, 1);
        assert_eq!(session.packets_tx, 1);
    }

    #[test]
    fn test_invalid_transition_rejected() {
        let mut session = Session::new("10.0.0.1".parse().unwrap(), 1194, Duration::hours(1));

        // Skipping the handshake and authentication steps is not allowed,
        // and a rejected transition leaves the state untouched.
        assert!(session.transition(SessionState::Active).is_err());
        assert_eq!(session.state, SessionState::Connecting);

        // Termination is always allowed and cannot be undone.
        session.transition(SessionState::Terminated).unwrap();
        assert!(session.transition(SessionState::Connecting).is_err());
        assert!(!session.is_active());
    }

    #[test]
    fn test_session_manager() {
        let manager = SessionManager::new(100, Duration::hours(1));

        let session = manager
            .create_session("192.168.1.1".parse().unwrap(), 12345)
            .unwrap();

        assert!(manager.get_session(&session.id).is_some());
        assert_eq!(manager.session_count(), 1);

        manager.remove_session(&session.id);
        assert!(manager.get_session(&session.id).is_none());
    }

    #[test]
    fn test_session_manager_capacity_limit() {
        let manager = SessionManager::new(1, Duration::hours(1));

        manager
            .create_session("192.168.1.1".parse().unwrap(), 1000)
            .unwrap();
        // The only slot is held by a live session, so a second one is refused.
        assert!(manager
            .create_session("192.168.1.2".parse().unwrap(), 1001)
            .is_err());
        assert_eq!(manager.session_count(), 1);
    }

    #[test]
    fn test_expired_sessions_are_reclaimed() {
        // A negative lifetime produces sessions that are already expired.
        let manager = SessionManager::new(1, Duration::seconds(-1));

        manager
            .create_session("192.168.1.1".parse().unwrap(), 1000)
            .unwrap();
        // Under capacity pressure the expired slot is reclaimed instead of failing.
        manager
            .create_session("192.168.1.2".parse().unwrap(), 1001)
            .unwrap();
        assert_eq!(manager.session_count(), 1);

        assert_eq!(manager.cleanup_expired(), 1);
        assert_eq!(manager.session_count(), 0);
    }
}