1use chrono::{DateTime, Duration, Utc};
4use secrecy::SecretString;
5use serde::{Deserialize, Serialize};
6use uuid::Uuid;
7
8use crate::{CoreError, Result, UserId, VpnAddress};
9
/// Unique identifier of a [`Session`], wrapping a [`Uuid`].
///
/// `Copy` + `Hash` + `Eq` so it can be used directly as a map key; the raw
/// 16 UUID bytes can be round-tripped with `from_bytes` / `as_bytes`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SessionId(Uuid);
13
14impl SessionId {
15 pub fn new() -> Self {
17 Self(Uuid::new_v4())
18 }
19
20 pub fn from_bytes(bytes: [u8; 16]) -> Self {
22 Self(Uuid::from_bytes(bytes))
23 }
24
25 pub fn as_bytes(&self) -> &[u8; 16] {
27 self.0.as_bytes()
28 }
29}
30
31impl Default for SessionId {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl std::fmt::Display for SessionId {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 write!(f, "{}", self.0)
40 }
41}
42
/// Lifecycle state of a [`Session`].
///
/// `Session::transition` only permits the forward path
/// `Connecting -> Handshaking -> Authenticating -> Active -> Disconnecting`,
/// plus a move from any state to `Terminated`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SessionState {
    /// Initial state of every session built by `Session::new`.
    Connecting,
    /// Follows `Connecting`. NOTE(review): presumably the protocol handshake
    /// is in progress here — confirm.
    Handshaking,
    /// Follows `Handshaking`. NOTE(review): presumably credentials are being
    /// verified here — confirm.
    Authenticating,
    /// The only state in which `Session::is_active` can return `true`.
    Active,
    /// Reachable only from `Active`; can then only move to `Terminated`.
    Disconnecting,
    /// Final state: reachable from any state, and can only "move" to itself.
    Terminated,
}
59
/// State and traffic accounting for a single client connection.
///
/// Built with `Session::new`. The type is `Clone`, and `SessionManager`
/// hands out clones, so a `Session` value is a snapshot rather than a live
/// view of the stored entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Unique identifier; also the key under which `SessionManager` stores it.
    pub id: SessionId,
    /// Owning user; `None` at creation. NOTE(review): presumably filled in
    /// once authentication succeeds — confirm with callers.
    pub user_id: Option<UserId>,
    /// Current lifecycle state. Changing it via `transition` enforces the
    /// allowed state machine; direct assignment bypasses that check.
    pub state: SessionState,
    /// VPN address for this client; `None` at creation. NOTE(review):
    /// presumably assigned by the server later — confirm.
    pub vpn_address: Option<VpnAddress>,
    /// IP address the client connected from.
    pub client_ip: std::net::IpAddr,
    /// Port the client connected from.
    pub client_port: u16,
    /// Time the session was created.
    pub created_at: DateTime<Utc>,
    /// Last time `touch` ran (directly, or via `record_rx` / `record_tx`).
    pub last_activity: DateTime<Utc>,
    /// Deadline after which `is_expired` returns `true`; moved by `extend`.
    pub expires_at: DateTime<Utc>,
    /// Sum of the byte counts passed to `record_rx`.
    pub bytes_rx: u64,
    /// Sum of the byte counts passed to `record_tx`.
    pub bytes_tx: u64,
    /// Incremented by one on every `record_rx` call.
    pub packets_rx: u64,
    /// Incremented by one on every `record_tx` call.
    pub packets_tx: u64,
    /// Client version string, if known; `None` at creation.
    pub client_version: Option<String>,
    /// OAuth token, kept in a `SecretString`. Skipped by serde: it is never
    /// serialized and is `None` after deserialization.
    #[serde(skip)]
    pub oauth_token: Option<SecretString>,
}
95
96impl Session {
97 pub fn new(client_ip: std::net::IpAddr, client_port: u16, lifetime: Duration) -> Self {
99 let now = Utc::now();
100 Self {
101 id: SessionId::new(),
102 user_id: None,
103 state: SessionState::Connecting,
104 vpn_address: None,
105 client_ip,
106 client_port,
107 created_at: now,
108 last_activity: now,
109 expires_at: now + lifetime,
110 bytes_rx: 0,
111 bytes_tx: 0,
112 packets_rx: 0,
113 packets_tx: 0,
114 client_version: None,
115 oauth_token: None,
116 }
117 }
118
119 pub fn is_expired(&self) -> bool {
121 Utc::now() > self.expires_at
122 }
123
124 pub fn is_active(&self) -> bool {
126 self.state == SessionState::Active && !self.is_expired()
127 }
128
129 pub fn touch(&mut self) {
131 self.last_activity = Utc::now();
132 }
133
134 pub fn record_rx(&mut self, bytes: u64) {
136 self.bytes_rx += bytes;
137 self.packets_rx += 1;
138 self.touch();
139 }
140
141 pub fn record_tx(&mut self, bytes: u64) {
143 self.bytes_tx += bytes;
144 self.packets_tx += 1;
145 self.touch();
146 }
147
148 pub fn transition(&mut self, new_state: SessionState) -> Result<()> {
150 use SessionState::*;
151
152 let valid = match (self.state, new_state) {
154 (Connecting, Handshaking) => true,
155 (Handshaking, Authenticating) => true,
156 (Authenticating, Active) => true,
157 (Active, Disconnecting) => true,
158 (_, Terminated) => true, _ => false,
160 };
161
162 if valid {
163 self.state = new_state;
164 Ok(())
165 } else {
166 Err(CoreError::Internal(format!(
167 "Invalid state transition: {:?} -> {:?}",
168 self.state, new_state
169 )))
170 }
171 }
172
173 pub fn duration(&self) -> Duration {
175 Utc::now() - self.created_at
176 }
177
178 pub fn idle_time(&self) -> Duration {
180 Utc::now() - self.last_activity
181 }
182
183 pub fn extend(&mut self, duration: Duration) {
185 self.expires_at = Utc::now() + duration;
186 }
187}
188
/// In-memory table of [`Session`]s keyed by [`SessionId`], capped at
/// `max_sessions` entries.
///
/// The table is guarded by a `parking_lot::RwLock`, so every method takes
/// `&self`. Lookups return cloned `Session` values rather than references.
pub struct SessionManager {
    /// Every tracked session, including expired or terminated ones that have
    /// not been removed yet.
    sessions: parking_lot::RwLock<std::collections::HashMap<SessionId, Session>>,
    /// Capacity limit enforced by `create_session`.
    max_sessions: usize,
    /// Lifetime given to each session built by `create_session`.
    default_lifetime: Duration,
}
195
196impl SessionManager {
197 pub fn new(max_sessions: usize, default_lifetime: Duration) -> Self {
199 Self {
200 sessions: parking_lot::RwLock::new(std::collections::HashMap::new()),
201 max_sessions,
202 default_lifetime,
203 }
204 }
205
206 pub fn create_session(
208 &self,
209 client_ip: std::net::IpAddr,
210 client_port: u16,
211 ) -> Result<Session> {
212 let mut sessions = self.sessions.write();
213
214 if sessions.len() >= self.max_sessions {
216 sessions.retain(|_, s| !s.is_expired());
218
219 if sessions.len() >= self.max_sessions {
220 return Err(CoreError::Internal("Maximum sessions reached".into()));
221 }
222 }
223
224 let session = Session::new(client_ip, client_port, self.default_lifetime);
225 sessions.insert(session.id, session.clone());
226
227 Ok(session)
228 }
229
230 pub fn get_session(&self, id: &SessionId) -> Option<Session> {
232 self.sessions.read().get(id).cloned()
233 }
234
235 pub fn update_session(&self, session: Session) -> Result<()> {
237 use std::collections::hash_map::Entry;
238 let mut sessions = self.sessions.write();
239 match sessions.entry(session.id) {
240 Entry::Occupied(mut e) => {
241 e.insert(session);
242 Ok(())
243 }
244 Entry::Vacant(_) => Err(CoreError::SessionNotFound(session.id.to_string())),
245 }
246 }
247
248 pub fn remove_session(&self, id: &SessionId) -> Option<Session> {
250 self.sessions.write().remove(id)
251 }
252
253 pub fn active_sessions(&self) -> Vec<Session> {
255 self.sessions
256 .read()
257 .values()
258 .filter(|s| s.is_active())
259 .cloned()
260 .collect()
261 }
262
263 pub fn session_count(&self) -> usize {
265 self.sessions.read().len()
266 }
267
268 pub fn cleanup_expired(&self) -> usize {
270 let mut sessions = self.sessions.write();
271 let before = sessions.len();
272 sessions.retain(|_, s| !s.is_expired());
273 before - sessions.len()
274 }
275
276 pub fn get_user_sessions(&self, user_id: &UserId) -> Vec<Session> {
278 self.sessions
279 .read()
280 .values()
281 .filter(|s| s.user_id.as_ref() == Some(user_id))
282 .cloned()
283 .collect()
284 }
285
286 pub fn terminate_user_sessions(&self, user_id: &UserId) -> usize {
288 let mut sessions = self.sessions.write();
289 let mut count = 0;
290
291 for session in sessions.values_mut() {
292 if session.user_id.as_ref() == Some(user_id) {
293 session.state = SessionState::Terminated;
294 count += 1;
295 }
296 }
297
298 count
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_lifecycle() {
        let mut session = Session::new(
            "192.168.1.1".parse().unwrap(),
            12345,
            Duration::hours(1),
        );

        assert_eq!(session.state, SessionState::Connecting);
        assert!(!session.is_expired());

        session.transition(SessionState::Handshaking).unwrap();
        session.transition(SessionState::Authenticating).unwrap();
        session.transition(SessionState::Active).unwrap();

        assert!(session.is_active());

        session.record_rx(1000);
        session.record_tx(500);

        assert_eq!(session.bytes_rx, 1000);
        assert_eq!(session.bytes_tx, 500);
        assert_eq!(session.packets_rx, 1);
        assert_eq!(session.packets_tx, 1);
    }

    #[test]
    fn test_invalid_transition_is_rejected() {
        let mut session = Session::new(
            "192.168.1.1".parse().unwrap(),
            12345,
            Duration::hours(1),
        );

        // Skipping handshake/authentication must fail and leave the state as-is.
        assert!(session.transition(SessionState::Active).is_err());
        assert_eq!(session.state, SessionState::Connecting);

        // Termination is allowed from any state.
        session.transition(SessionState::Terminated).unwrap();
        assert_eq!(session.state, SessionState::Terminated);
        assert!(!session.is_active());
    }

    #[test]
    fn test_negative_lifetime_is_expired() {
        let session = Session::new(
            "192.168.1.1".parse().unwrap(),
            12345,
            Duration::seconds(-1),
        );
        assert!(session.is_expired());
    }

    #[test]
    fn test_session_manager() {
        let manager = SessionManager::new(100, Duration::hours(1));

        let session = manager
            .create_session("192.168.1.1".parse().unwrap(), 12345)
            .unwrap();

        assert!(manager.get_session(&session.id).is_some());
        assert_eq!(manager.session_count(), 1);

        manager.remove_session(&session.id);
        assert!(manager.get_session(&session.id).is_none());
    }

    #[test]
    fn test_session_manager_update_round_trip() {
        let manager = SessionManager::new(100, Duration::hours(1));
        let mut session = manager
            .create_session("192.168.1.1".parse().unwrap(), 12345)
            .unwrap();

        session.record_rx(42);
        manager.update_session(session.clone()).unwrap();

        assert_eq!(manager.get_session(&session.id).unwrap().bytes_rx, 42);
    }

    #[test]
    fn test_update_unknown_session_fails() {
        let manager = SessionManager::new(100, Duration::hours(1));
        let stray = Session::new("192.168.1.1".parse().unwrap(), 12345, Duration::hours(1));

        assert!(manager.update_session(stray).is_err());
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_session_manager_enforces_capacity() {
        let manager = SessionManager::new(1, Duration::hours(1));

        manager
            .create_session("192.168.1.1".parse().unwrap(), 1)
            .unwrap();
        assert!(manager
            .create_session("192.168.1.2".parse().unwrap(), 2)
            .is_err());
        assert_eq!(manager.session_count(), 1);
    }

    #[test]
    fn test_expired_sessions_free_capacity() {
        // Sessions created with a negative lifetime are expired immediately.
        let manager = SessionManager::new(1, Duration::seconds(-1));

        manager
            .create_session("192.168.1.1".parse().unwrap(), 1)
            .unwrap();
        // The table is full, but the existing entry is expired and is evicted.
        manager
            .create_session("192.168.1.2".parse().unwrap(), 2)
            .unwrap();
        assert_eq!(manager.session_count(), 1);

        assert_eq!(manager.cleanup_expired(), 1);
        assert_eq!(manager.session_count(), 0);
    }
}