Skip to main content

corevpn_core/
session.rs

//! VPN session management
2
3use chrono::{DateTime, Duration, Utc};
4use secrecy::SecretString;
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use crate::{CoreError, Result, UserId, VpnAddress};
9
10/// Unique session identifier
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub struct SessionId(Uuid);
13
14impl SessionId {
15    /// Generate a new random session ID
16    pub fn new() -> Self {
17        Self(Uuid::new_v4())
18    }
19
20    /// Create from raw bytes
21    pub fn from_bytes(bytes: [u8; 16]) -> Self {
22        Self(Uuid::from_bytes(bytes))
23    }
24
25    /// Get the raw bytes
26    pub fn as_bytes(&self) -> &[u8; 16] {
27        self.0.as_bytes()
28    }
29}
30
31impl Default for SessionId {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl std::fmt::Display for SessionId {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        write!(f, "{}", self.0)
40    }
41}
42
43/// Session state
44#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
45pub enum SessionState {
46    /// Initial connection, awaiting authentication
47    Connecting,
48    /// TLS handshake in progress
49    Handshaking,
50    /// Authenticating (OAuth2 flow or certificate)
51    Authenticating,
52    /// Fully established and active
53    Active,
54    /// Graceful disconnection in progress
55    Disconnecting,
56    /// Session terminated
57    Terminated,
58}
59
60/// VPN Session
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct Session {
63    /// Unique session ID
64    pub id: SessionId,
65    /// Associated user (if authenticated)
66    pub user_id: Option<UserId>,
67    /// Session state
68    pub state: SessionState,
69    /// Assigned VPN IP address
70    pub vpn_address: Option<VpnAddress>,
71    /// Client's real IP address
72    pub client_ip: std::net::IpAddr,
73    /// Client's port
74    pub client_port: u16,
75    /// Session creation time
76    pub created_at: DateTime<Utc>,
77    /// Last activity time
78    pub last_activity: DateTime<Utc>,
79    /// Session expiration time
80    pub expires_at: DateTime<Utc>,
81    /// Bytes received from client
82    pub bytes_rx: u64,
83    /// Bytes sent to client
84    pub bytes_tx: u64,
85    /// Packets received from client
86    pub packets_rx: u64,
87    /// Packets sent to client
88    pub packets_tx: u64,
89    /// Client user agent / version
90    pub client_version: Option<String>,
91    /// OAuth2 access token (if using OAuth2 auth)
92    #[serde(skip)]
93    pub oauth_token: Option<SecretString>,
94}
95
96impl Session {
97    /// Create a new session
98    pub fn new(client_ip: std::net::IpAddr, client_port: u16, lifetime: Duration) -> Self {
99        let now = Utc::now();
100        Self {
101            id: SessionId::new(),
102            user_id: None,
103            state: SessionState::Connecting,
104            vpn_address: None,
105            client_ip,
106            client_port,
107            created_at: now,
108            last_activity: now,
109            expires_at: now + lifetime,
110            bytes_rx: 0,
111            bytes_tx: 0,
112            packets_rx: 0,
113            packets_tx: 0,
114            client_version: None,
115            oauth_token: None,
116        }
117    }
118
119    /// Check if session is expired
120    pub fn is_expired(&self) -> bool {
121        Utc::now() > self.expires_at
122    }
123
124    /// Check if session is active
125    pub fn is_active(&self) -> bool {
126        self.state == SessionState::Active && !self.is_expired()
127    }
128
129    /// Update last activity timestamp
130    pub fn touch(&mut self) {
131        self.last_activity = Utc::now();
132    }
133
134    /// Record received data
135    pub fn record_rx(&mut self, bytes: u64) {
136        self.bytes_rx += bytes;
137        self.packets_rx += 1;
138        self.touch();
139    }
140
141    /// Record sent data
142    pub fn record_tx(&mut self, bytes: u64) {
143        self.bytes_tx += bytes;
144        self.packets_tx += 1;
145        self.touch();
146    }
147
148    /// Transition to a new state
149    pub fn transition(&mut self, new_state: SessionState) -> Result<()> {
150        use SessionState::*;
151
152        // Validate state transitions
153        let valid = match (self.state, new_state) {
154            (Connecting, Handshaking) => true,
155            (Handshaking, Authenticating) => true,
156            (Authenticating, Active) => true,
157            (Active, Disconnecting) => true,
158            (_, Terminated) => true, // Can always terminate
159            _ => false,
160        };
161
162        if valid {
163            self.state = new_state;
164            Ok(())
165        } else {
166            Err(CoreError::Internal(format!(
167                "Invalid state transition: {:?} -> {:?}",
168                self.state, new_state
169            )))
170        }
171    }
172
173    /// Get session duration
174    pub fn duration(&self) -> Duration {
175        Utc::now() - self.created_at
176    }
177
178    /// Get idle time since last activity
179    pub fn idle_time(&self) -> Duration {
180        Utc::now() - self.last_activity
181    }
182
183    /// Extend session expiration
184    pub fn extend(&mut self, duration: Duration) {
185        self.expires_at = Utc::now() + duration;
186    }
187}
188
189/// Session manager for tracking active sessions
190pub struct SessionManager {
191    sessions: parking_lot::RwLock<std::collections::HashMap<SessionId, Session>>,
192    max_sessions: usize,
193    default_lifetime: Duration,
194}
195
196impl SessionManager {
197    /// Create a new session manager
198    pub fn new(max_sessions: usize, default_lifetime: Duration) -> Self {
199        Self {
200            sessions: parking_lot::RwLock::new(std::collections::HashMap::new()),
201            max_sessions,
202            default_lifetime,
203        }
204    }
205
206    /// Create a new session
207    pub fn create_session(
208        &self,
209        client_ip: std::net::IpAddr,
210        client_port: u16,
211    ) -> Result<Session> {
212        let mut sessions = self.sessions.write();
213
214        // Check capacity
215        if sessions.len() >= self.max_sessions {
216            // Try to clean up expired sessions
217            sessions.retain(|_, s| !s.is_expired());
218
219            if sessions.len() >= self.max_sessions {
220                return Err(CoreError::Internal("Maximum sessions reached".into()));
221            }
222        }
223
224        let session = Session::new(client_ip, client_port, self.default_lifetime);
225        sessions.insert(session.id, session.clone());
226
227        Ok(session)
228    }
229
230    /// Get a session by ID
231    pub fn get_session(&self, id: &SessionId) -> Option<Session> {
232        self.sessions.read().get(id).cloned()
233    }
234
235    /// Update a session
236    pub fn update_session(&self, session: Session) -> Result<()> {
237        use std::collections::hash_map::Entry;
238        let mut sessions = self.sessions.write();
239        match sessions.entry(session.id) {
240            Entry::Occupied(mut e) => {
241                e.insert(session);
242                Ok(())
243            }
244            Entry::Vacant(_) => Err(CoreError::SessionNotFound(session.id.to_string())),
245        }
246    }
247
248    /// Remove a session
249    pub fn remove_session(&self, id: &SessionId) -> Option<Session> {
250        self.sessions.write().remove(id)
251    }
252
253    /// Get all active sessions
254    pub fn active_sessions(&self) -> Vec<Session> {
255        self.sessions
256            .read()
257            .values()
258            .filter(|s| s.is_active())
259            .cloned()
260            .collect()
261    }
262
263    /// Get session count
264    pub fn session_count(&self) -> usize {
265        self.sessions.read().len()
266    }
267
268    /// Clean up expired sessions
269    pub fn cleanup_expired(&self) -> usize {
270        let mut sessions = self.sessions.write();
271        let before = sessions.len();
272        sessions.retain(|_, s| !s.is_expired());
273        before - sessions.len()
274    }
275
276    /// Get sessions by user ID
277    pub fn get_user_sessions(&self, user_id: &UserId) -> Vec<Session> {
278        self.sessions
279            .read()
280            .values()
281            .filter(|s| s.user_id.as_ref() == Some(user_id))
282            .cloned()
283            .collect()
284    }
285
286    /// Terminate all sessions for a user
287    pub fn terminate_user_sessions(&self, user_id: &UserId) -> usize {
288        let mut sessions = self.sessions.write();
289        let mut count = 0;
290
291        for session in sessions.values_mut() {
292            if session.user_id.as_ref() == Some(user_id) {
293                session.state = SessionState::Terminated;
294                count += 1;
295            }
296        }
297
298        count
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Arbitrary client address shared by the tests.
    fn client_ip() -> std::net::IpAddr {
        "192.168.1.1".parse().unwrap()
    }

    #[test]
    fn test_session_lifecycle() {
        let mut session = Session::new(client_ip(), 12345, Duration::hours(1));

        assert_eq!(session.state, SessionState::Connecting);
        assert!(!session.is_expired());

        session.transition(SessionState::Handshaking).unwrap();
        session.transition(SessionState::Authenticating).unwrap();
        session.transition(SessionState::Active).unwrap();

        assert!(session.is_active());

        session.record_rx(1000);
        session.record_tx(500);

        assert_eq!(session.bytes_rx, 1000);
        assert_eq!(session.bytes_tx, 500);
        assert_eq!(session.packets_rx, 1);
        assert_eq!(session.packets_tx, 1);
    }

    #[test]
    fn test_invalid_transition_is_rejected() {
        let mut session = Session::new(client_ip(), 12345, Duration::hours(1));

        // Skipping the handshake and authentication phases is not allowed...
        assert!(session.transition(SessionState::Active).is_err());
        // ...and a rejected transition leaves the state untouched.
        assert_eq!(session.state, SessionState::Connecting);
        assert!(!session.is_active());
    }

    #[test]
    fn test_can_terminate_from_any_state() {
        let mut session = Session::new(client_ip(), 12345, Duration::hours(1));

        session.transition(SessionState::Terminated).unwrap();
        assert_eq!(session.state, SessionState::Terminated);
    }

    #[test]
    fn test_expired_session_is_not_active() {
        // A negative lifetime yields a session that is expired on creation.
        let mut session = Session::new(client_ip(), 12345, Duration::seconds(-1));
        session.transition(SessionState::Handshaking).unwrap();
        session.transition(SessionState::Authenticating).unwrap();
        session.transition(SessionState::Active).unwrap();

        assert!(session.is_expired());
        assert!(!session.is_active());
    }

    #[test]
    fn test_session_manager() {
        let manager = SessionManager::new(100, Duration::hours(1));

        let session = manager.create_session(client_ip(), 12345).unwrap();

        assert!(manager.get_session(&session.id).is_some());
        assert_eq!(manager.session_count(), 1);

        manager.remove_session(&session.id);
        assert!(manager.get_session(&session.id).is_none());
    }

    #[test]
    fn test_session_manager_capacity() {
        let manager = SessionManager::new(1, Duration::hours(1));

        manager.create_session(client_ip(), 1).unwrap();
        assert!(manager.create_session(client_ip(), 2).is_err());
        assert_eq!(manager.session_count(), 1);
    }

    #[test]
    fn test_session_manager_evicts_expired_when_full() {
        let manager = SessionManager::new(1, Duration::seconds(-1));

        let first = manager.create_session(client_ip(), 1).unwrap();
        let second = manager.create_session(client_ip(), 2).unwrap();

        assert!(manager.get_session(&first.id).is_none());
        assert!(manager.get_session(&second.id).is_some());
    }

    #[test]
    fn test_cleanup_expired() {
        let manager = SessionManager::new(10, Duration::seconds(-1));

        manager.create_session(client_ip(), 1).unwrap();
        manager.create_session(client_ip(), 2).unwrap();

        assert_eq!(manager.cleanup_expired(), 2);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_update_session() {
        let manager = SessionManager::new(10, Duration::hours(1));

        let mut session = manager.create_session(client_ip(), 1).unwrap();
        session.transition(SessionState::Handshaking).unwrap();
        manager.update_session(session.clone()).unwrap();
        assert_eq!(
            manager.get_session(&session.id).unwrap().state,
            SessionState::Handshaking
        );

        // Unknown sessions are rejected rather than silently inserted.
        let stranger = Session::new(client_ip(), 2, Duration::hours(1));
        assert!(manager.update_session(stranger).is_err());
        assert_eq!(manager.session_count(), 1);
    }

    #[test]
    fn test_new_sessions_are_not_active() {
        let manager = SessionManager::new(10, Duration::hours(1));

        manager.create_session(client_ip(), 1).unwrap();
        assert!(manager.active_sessions().is_empty());
    }
}