1use chrono::{DateTime, Duration, Utc};
4use serde::{Deserialize, Serialize};
5use uuid::Uuid;
6
7use crate::{CoreError, Result, UserId, VpnAddress};
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
11pub struct SessionId(Uuid);
12
13impl SessionId {
14 pub fn new() -> Self {
16 Self(Uuid::new_v4())
17 }
18
19 pub fn from_bytes(bytes: [u8; 16]) -> Self {
21 Self(Uuid::from_bytes(bytes))
22 }
23
24 pub fn as_bytes(&self) -> &[u8; 16] {
26 self.0.as_bytes()
27 }
28}
29
30impl Default for SessionId {
31 fn default() -> Self {
32 Self::new()
33 }
34}
35
36impl std::fmt::Display for SessionId {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 write!(f, "{}", self.0)
39 }
40}
41
42#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
44pub enum SessionState {
45 Connecting,
47 Handshaking,
49 Authenticating,
51 Active,
53 Disconnecting,
55 Terminated,
57}
58
59#[derive(Debug, Clone, Serialize, Deserialize)]
61pub struct Session {
62 pub id: SessionId,
64 pub user_id: Option<UserId>,
66 pub state: SessionState,
68 pub vpn_address: Option<VpnAddress>,
70 pub client_ip: std::net::IpAddr,
72 pub client_port: u16,
74 pub created_at: DateTime<Utc>,
76 pub last_activity: DateTime<Utc>,
78 pub expires_at: DateTime<Utc>,
80 pub bytes_rx: u64,
82 pub bytes_tx: u64,
84 pub packets_rx: u64,
86 pub packets_tx: u64,
88 pub client_version: Option<String>,
90 #[serde(skip)]
92 pub oauth_token: Option<String>,
93}
94
95impl Session {
96 pub fn new(client_ip: std::net::IpAddr, client_port: u16, lifetime: Duration) -> Self {
98 let now = Utc::now();
99 Self {
100 id: SessionId::new(),
101 user_id: None,
102 state: SessionState::Connecting,
103 vpn_address: None,
104 client_ip,
105 client_port,
106 created_at: now,
107 last_activity: now,
108 expires_at: now + lifetime,
109 bytes_rx: 0,
110 bytes_tx: 0,
111 packets_rx: 0,
112 packets_tx: 0,
113 client_version: None,
114 oauth_token: None,
115 }
116 }
117
118 pub fn is_expired(&self) -> bool {
120 Utc::now() > self.expires_at
121 }
122
123 pub fn is_active(&self) -> bool {
125 self.state == SessionState::Active && !self.is_expired()
126 }
127
128 pub fn touch(&mut self) {
130 self.last_activity = Utc::now();
131 }
132
133 pub fn record_rx(&mut self, bytes: u64) {
135 self.bytes_rx += bytes;
136 self.packets_rx += 1;
137 self.touch();
138 }
139
140 pub fn record_tx(&mut self, bytes: u64) {
142 self.bytes_tx += bytes;
143 self.packets_tx += 1;
144 self.touch();
145 }
146
147 pub fn transition(&mut self, new_state: SessionState) -> Result<()> {
149 use SessionState::*;
150
151 let valid = match (self.state, new_state) {
153 (Connecting, Handshaking) => true,
154 (Handshaking, Authenticating) => true,
155 (Authenticating, Active) => true,
156 (Active, Disconnecting) => true,
157 (_, Terminated) => true, _ => false,
159 };
160
161 if valid {
162 self.state = new_state;
163 Ok(())
164 } else {
165 Err(CoreError::Internal(format!(
166 "Invalid state transition: {:?} -> {:?}",
167 self.state, new_state
168 )))
169 }
170 }
171
172 pub fn duration(&self) -> Duration {
174 Utc::now() - self.created_at
175 }
176
177 pub fn idle_time(&self) -> Duration {
179 Utc::now() - self.last_activity
180 }
181
182 pub fn extend(&mut self, duration: Duration) {
184 self.expires_at = Utc::now() + duration;
185 }
186}
187
188pub struct SessionManager {
190 sessions: parking_lot::RwLock<std::collections::HashMap<SessionId, Session>>,
191 max_sessions: usize,
192 default_lifetime: Duration,
193}
194
195impl SessionManager {
196 pub fn new(max_sessions: usize, default_lifetime: Duration) -> Self {
198 Self {
199 sessions: parking_lot::RwLock::new(std::collections::HashMap::new()),
200 max_sessions,
201 default_lifetime,
202 }
203 }
204
205 pub fn create_session(
207 &self,
208 client_ip: std::net::IpAddr,
209 client_port: u16,
210 ) -> Result<Session> {
211 let mut sessions = self.sessions.write();
212
213 if sessions.len() >= self.max_sessions {
215 sessions.retain(|_, s| !s.is_expired());
217
218 if sessions.len() >= self.max_sessions {
219 return Err(CoreError::Internal("Maximum sessions reached".into()));
220 }
221 }
222
223 let session = Session::new(client_ip, client_port, self.default_lifetime);
224 sessions.insert(session.id, session.clone());
225
226 Ok(session)
227 }
228
229 pub fn get_session(&self, id: &SessionId) -> Option<Session> {
231 self.sessions.read().get(id).cloned()
232 }
233
234 pub fn update_session(&self, session: Session) -> Result<()> {
236 use std::collections::hash_map::Entry;
237 let mut sessions = self.sessions.write();
238 match sessions.entry(session.id) {
239 Entry::Occupied(mut e) => {
240 e.insert(session);
241 Ok(())
242 }
243 Entry::Vacant(_) => Err(CoreError::SessionNotFound(session.id.to_string())),
244 }
245 }
246
247 pub fn remove_session(&self, id: &SessionId) -> Option<Session> {
249 self.sessions.write().remove(id)
250 }
251
252 pub fn active_sessions(&self) -> Vec<Session> {
254 self.sessions
255 .read()
256 .values()
257 .filter(|s| s.is_active())
258 .cloned()
259 .collect()
260 }
261
262 pub fn session_count(&self) -> usize {
264 self.sessions.read().len()
265 }
266
267 pub fn cleanup_expired(&self) -> usize {
269 let mut sessions = self.sessions.write();
270 let before = sessions.len();
271 sessions.retain(|_, s| !s.is_expired());
272 before - sessions.len()
273 }
274
275 pub fn get_user_sessions(&self, user_id: &UserId) -> Vec<Session> {
277 self.sessions
278 .read()
279 .values()
280 .filter(|s| s.user_id.as_ref() == Some(user_id))
281 .cloned()
282 .collect()
283 }
284
285 pub fn terminate_user_sessions(&self, user_id: &UserId) -> usize {
287 let mut sessions = self.sessions.write();
288 let mut count = 0;
289
290 for session in sessions.values_mut() {
291 if session.user_id.as_ref() == Some(user_id) {
292 session.state = SessionState::Terminated;
293 count += 1;
294 }
295 }
296
297 count
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Client address shared by the tests below.
    fn client_addr() -> std::net::IpAddr {
        "192.168.1.1".parse().unwrap()
    }

    #[test]
    fn test_session_lifecycle() {
        let mut session = Session::new(client_addr(), 12345, Duration::hours(1));

        assert_eq!(session.state, SessionState::Connecting);
        assert!(!session.is_expired());

        // Walk the happy path up to Active.
        for &next in &[
            SessionState::Handshaking,
            SessionState::Authenticating,
            SessionState::Active,
        ] {
            session.transition(next).unwrap();
        }
        assert!(session.is_active());

        session.record_rx(1000);
        session.record_tx(500);
        assert_eq!((session.bytes_rx, session.bytes_tx), (1000, 500));
    }

    #[test]
    fn test_session_manager() {
        let manager = SessionManager::new(100, Duration::hours(1));
        let id = manager.create_session(client_addr(), 12345).unwrap().id;

        assert!(manager.get_session(&id).is_some());
        assert_eq!(manager.session_count(), 1);

        manager.remove_session(&id);
        assert!(manager.get_session(&id).is_none());
    }
}