1use crate::relay::{
4 AuthToken, RelayAuthenticator, RelayConnection, RelayConnectionConfig, RelayError, RelayResult,
5};
6use ed25519_dalek::VerifyingKey;
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11use tokio::sync::mpsc;
12
/// Identifier assigned to a relay session by [`SessionManager`].
///
/// IDs are handed out from an incrementing counter starting at 1; the value 0 is never issued.
pub type SessionId = u32;
15
/// Tunables for a [`SessionManager`].
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Maximum number of entries in the session table; `request_session` rejects new
    /// requests with `ResourceExhausted` once it is reached. Entries stay in the table (and
    /// count against this limit) until `cleanup_expired_sessions` removes them, even after
    /// they have been terminated.
    pub max_sessions: usize,
    /// NOTE(review): not read by `SessionManager` in this module — each session's timeout
    /// is taken from its `AuthToken` instead. Confirm whether it should act as a fallback.
    pub default_timeout: Duration,
    /// Minimum time between two cleanup passes; `cleanup_expired_sessions` returns `Ok(0)`
    /// without scanning when called sooner.
    pub cleanup_interval: Duration,
    /// NOTE(review): not read by `SessionManager` in this module — each session's bandwidth
    /// limit is taken from its `AuthToken` instead. Units presumably bytes/second; confirm.
    pub default_bandwidth_limit: u64,
}
28
29impl Default for SessionConfig {
30 fn default() -> Self {
31 Self {
32 max_sessions: 100,
33 default_timeout: Duration::from_secs(300), cleanup_interval: Duration::from_secs(30), default_bandwidth_limit: 1048576, }
37 }
38}
39
/// Lifecycle state of a relay session tracked by [`SessionManager`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// Requested and authenticated, but not yet established (set by `request_session`).
    Pending,
    /// Established by `establish_session`, which also creates the session's `RelayConnection`.
    Active,
    /// Shutdown in progress.
    /// NOTE(review): never assigned by `SessionManager` in this module.
    Terminating,
    /// Terminated via `terminate_session`; the entry lingers until cleanup removes it.
    Terminated,
    /// The session failed for the given human-readable reason.
    /// NOTE(review): never assigned by `SessionManager` in this module.
    Failed { reason: String },
}
54
/// Snapshot of one session's metadata and traffic counters, as stored by [`SessionManager`].
#[derive(Debug, Clone)]
pub struct RelaySessionInfo {
    /// Identifier assigned when the session was requested.
    pub session_id: SessionId,
    /// Address of the requesting client; also the key used to look up its trusted key.
    pub client_addr: SocketAddr,
    /// Opaque bytes supplied by the caller of `request_session`; not interpreted here.
    pub peer_connection_id: Vec<u8>,
    /// Current lifecycle state.
    pub state: SessionState,
    /// When the session was requested.
    pub created_at: Instant,
    /// Last time the session was established, forwarded data, or was terminated; the
    /// session expires once `timeout` has elapsed since this instant.
    pub last_activity: Instant,
    /// Bandwidth limit copied from the session's `AuthToken` and passed to its
    /// `RelayConnectionConfig`. Units as defined by `AuthToken` (presumably bytes/s — confirm).
    pub bandwidth_limit: u64,
    /// Inactivity timeout, taken from the `AuthToken`'s `timeout_seconds`.
    pub timeout: Duration,
    /// Total payload bytes forwarded in the `ClientToPeer` direction.
    pub bytes_sent: u64,
    /// Total payload bytes forwarded in the `PeerToClient` direction.
    pub bytes_received: u64,
}
78
/// Tracks relay sessions from request to termination.
///
/// It authenticates session requests against per-address trusted keys, keeps per-session
/// metadata and traffic counters, owns the `RelayConnection` of each established session,
/// and publishes [`SessionEvent`]s on an unbounded channel. All mutable state sits behind
/// `Arc<Mutex<..>>`, so every operation takes `&self`.
#[derive(Debug)]
pub struct SessionManager {
    /// Limits and intervals supplied at construction.
    config: SessionConfig,
    /// Session metadata keyed by session ID (includes terminated sessions until cleanup).
    sessions: Arc<Mutex<HashMap<SessionId, RelaySessionInfo>>>,
    /// Connections created by `establish_session`, keyed by session ID.
    connections: Arc<Mutex<HashMap<SessionId, Arc<RelayConnection>>>>,
    /// Verifies the `AuthToken` presented with each session request.
    authenticator: RelayAuthenticator,
    /// Client address -> key used to verify that client's auth tokens.
    trusted_keys: Arc<Mutex<HashMap<SocketAddr, VerifyingKey>>>,
    /// Next session ID to hand out; starts at 1 and skips 0 on wrap-around.
    next_session_id: Arc<Mutex<SessionId>>,
    /// Sending half of the event channel; send failures (receiver dropped) are ignored.
    event_sender: mpsc::UnboundedSender<SessionEvent>,
    /// When the last cleanup pass ran; used to rate-limit `cleanup_expired_sessions`.
    last_cleanup: Arc<Mutex<Instant>>,
}
99
/// Notifications emitted by [`SessionManager`] on its event channel.
#[derive(Debug, Clone)]
pub enum SessionEvent {
    /// A session request was authenticated and recorded in the `Pending` state.
    SessionRequested {
        /// ID assigned to the new session.
        session_id: SessionId,
        client_addr: SocketAddr,
        peer_connection_id: Vec<u8>,
        auth_token: AuthToken,
    },
    /// A pending session became `Active` and its connection was created.
    SessionEstablished {
        /// The established session.
        session_id: SessionId,
        client_addr: SocketAddr,
    },
    /// A session was terminated, either explicitly or because it expired.
    SessionTerminated {
        /// The terminated session.
        session_id: SessionId,
        reason: String,
    },
    /// A session failed with the given error.
    /// NOTE(review): not emitted by `SessionManager` in this module.
    SessionFailed {
        /// The failed session.
        session_id: SessionId,
        error: RelayError,
    },
    /// A payload was forwarded through a session's connection.
    DataForwarded {
        /// The session the data went through.
        session_id: SessionId,
        bytes: usize,
        direction: ForwardDirection,
    },
}
132
/// Which way a payload travels through a relay session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ForwardDirection {
    /// Handed to the connection's `send_data`; counted in `bytes_sent`.
    ClientToPeer,
    /// Handed to the connection's `receive_data`; counted in `bytes_received`.
    PeerToClient,
}
141
142impl SessionManager {
143 pub fn new(config: SessionConfig) -> (Self, mpsc::UnboundedReceiver<SessionEvent>) {
145 let (event_sender, event_receiver) = mpsc::unbounded_channel();
146
147 let manager = Self {
148 config,
149 sessions: Arc::new(Mutex::new(HashMap::new())),
150 connections: Arc::new(Mutex::new(HashMap::new())),
151 authenticator: RelayAuthenticator::new(),
152 trusted_keys: Arc::new(Mutex::new(HashMap::new())),
153 next_session_id: Arc::new(Mutex::new(1)),
154 event_sender,
155 last_cleanup: Arc::new(Mutex::new(Instant::now())),
156 };
157
158 (manager, event_receiver)
159 }
160
161 pub fn add_trusted_key(&self, addr: SocketAddr, key: VerifyingKey) {
163 let mut trusted_keys = self.trusted_keys.lock().unwrap();
164 trusted_keys.insert(addr, key);
165 }
166
167 pub fn remove_trusted_key(&self, addr: &SocketAddr) {
169 let mut trusted_keys = self.trusted_keys.lock().unwrap();
170 trusted_keys.remove(addr);
171 }
172
173 fn next_session_id(&self) -> SessionId {
175 let mut next_id = self.next_session_id.lock().unwrap();
176 let id = *next_id;
177 *next_id = next_id.wrapping_add(1);
178 if *next_id == 0 {
179 *next_id = 1; }
181 id
182 }
183
184 pub fn request_session(
186 &self,
187 client_addr: SocketAddr,
188 peer_connection_id: Vec<u8>,
189 auth_token: AuthToken,
190 ) -> RelayResult<SessionId> {
191 {
193 let sessions = self.sessions.lock().unwrap();
194 if sessions.len() >= self.config.max_sessions {
195 return Err(RelayError::ResourceExhausted {
196 resource_type: "sessions".to_string(),
197 current_usage: sessions.len() as u64,
198 limit: self.config.max_sessions as u64,
199 });
200 }
201 }
202
203 let trusted_keys = self.trusted_keys.lock().unwrap();
205 let peer_key =
206 trusted_keys
207 .get(&client_addr)
208 .ok_or_else(|| RelayError::AuthenticationFailed {
209 reason: format!("No trusted key for address {}", client_addr),
210 })?;
211
212 self.authenticator.verify_token(&auth_token, peer_key)?;
213
214 let session_id = self.next_session_id();
216
217 let now = Instant::now();
219 let session_info = RelaySessionInfo {
220 session_id,
221 client_addr,
222 peer_connection_id: peer_connection_id.clone(),
223 state: SessionState::Pending,
224 created_at: now,
225 last_activity: now,
226 bandwidth_limit: auth_token.bandwidth_limit as u64,
227 timeout: Duration::from_secs(auth_token.timeout_seconds as u64),
228 bytes_sent: 0,
229 bytes_received: 0,
230 };
231
232 {
234 let mut sessions = self.sessions.lock().unwrap();
235 sessions.insert(session_id, session_info);
236 }
237
238 let _ = self.event_sender.send(SessionEvent::SessionRequested {
240 session_id,
241 client_addr,
242 peer_connection_id,
243 auth_token,
244 });
245
246 Ok(session_id)
247 }
248
249 pub fn establish_session(&self, session_id: SessionId) -> RelayResult<()> {
251 let (client_addr, bandwidth_limit) = {
252 let mut sessions = self.sessions.lock().unwrap();
253 let session = sessions
254 .get_mut(&session_id)
255 .ok_or(RelayError::SessionError {
256 session_id: Some(session_id),
257 kind: crate::relay::error::SessionErrorKind::NotFound,
258 })?;
259
260 if session.state != SessionState::Pending {
261 return Err(RelayError::SessionError {
262 session_id: Some(session_id),
263 kind: crate::relay::error::SessionErrorKind::InvalidState {
264 current_state: format!("{:?}", session.state),
265 expected_state: "Pending".to_string(),
266 },
267 });
268 }
269
270 session.state = SessionState::Active;
271 session.last_activity = Instant::now();
272
273 (session.client_addr, session.bandwidth_limit)
274 };
275
276 let (event_tx, _event_rx) = mpsc::unbounded_channel();
278 let (_action_tx, action_rx) = mpsc::unbounded_channel();
279
280 let mut connection_config = RelayConnectionConfig::default();
281 connection_config.bandwidth_limit = bandwidth_limit;
282
283 let connection = RelayConnection::new(
284 session_id,
285 client_addr,
286 connection_config,
287 event_tx,
288 action_rx,
289 );
290
291 {
293 let mut connections = self.connections.lock().unwrap();
294 connections.insert(session_id, Arc::new(connection));
295 }
296
297 let _ = self.event_sender.send(SessionEvent::SessionEstablished {
299 session_id,
300 client_addr,
301 });
302
303 Ok(())
304 }
305
306 pub fn terminate_session(&self, session_id: SessionId, reason: String) -> RelayResult<()> {
308 {
310 let mut sessions = self.sessions.lock().unwrap();
311 if let Some(session) = sessions.get_mut(&session_id) {
312 session.state = SessionState::Terminated;
313 session.last_activity = Instant::now();
314 }
315 }
316
317 {
319 let mut connections = self.connections.lock().unwrap();
320 if let Some(connection) = connections.remove(&session_id) {
321 let _ = connection.terminate(reason.clone());
322 }
323 }
324
325 let _ = self
327 .event_sender
328 .send(SessionEvent::SessionTerminated { session_id, reason });
329
330 Ok(())
331 }
332
333 pub fn forward_data(
335 &self,
336 session_id: SessionId,
337 data: Vec<u8>,
338 direction: ForwardDirection,
339 ) -> RelayResult<()> {
340 let connection = {
341 let connections = self.connections.lock().unwrap();
342 connections
343 .get(&session_id)
344 .cloned()
345 .ok_or(RelayError::SessionError {
346 session_id: Some(session_id),
347 kind: crate::relay::error::SessionErrorKind::NotFound,
348 })?
349 };
350
351 match direction {
353 ForwardDirection::ClientToPeer => {
354 connection.send_data(data.clone())?;
355 }
356 ForwardDirection::PeerToClient => {
357 connection.receive_data(data.clone())?;
358 }
359 }
360
361 {
363 let mut sessions = self.sessions.lock().unwrap();
364 if let Some(session) = sessions.get_mut(&session_id) {
365 session.last_activity = Instant::now();
366 match direction {
367 ForwardDirection::ClientToPeer => {
368 session.bytes_sent += data.len() as u64;
369 }
370 ForwardDirection::PeerToClient => {
371 session.bytes_received += data.len() as u64;
372 }
373 }
374 }
375 }
376
377 let _ = self.event_sender.send(SessionEvent::DataForwarded {
379 session_id,
380 bytes: data.len(),
381 direction,
382 });
383
384 Ok(())
385 }
386
387 pub fn get_session(&self, session_id: SessionId) -> Option<RelaySessionInfo> {
389 let sessions = self.sessions.lock().unwrap();
390 sessions.get(&session_id).cloned()
391 }
392
393 pub fn list_sessions(&self) -> Vec<RelaySessionInfo> {
395 let sessions = self.sessions.lock().unwrap();
396 sessions.values().cloned().collect()
397 }
398
399 pub fn session_count(&self) -> usize {
401 let sessions = self.sessions.lock().unwrap();
402 sessions.len()
403 }
404
405 pub fn cleanup_expired_sessions(&self) -> RelayResult<usize> {
407 let mut last_cleanup = self.last_cleanup.lock().unwrap();
408 let now = Instant::now();
409
410 if now.duration_since(*last_cleanup) < self.config.cleanup_interval {
412 return Ok(0);
413 }
414
415 *last_cleanup = now;
416 drop(last_cleanup);
417
418 let mut expired_sessions = Vec::new();
419
420 {
422 let sessions = self.sessions.lock().unwrap();
423 for (session_id, session_info) in sessions.iter() {
424 let age = now.duration_since(session_info.last_activity);
425 if age > session_info.timeout {
426 expired_sessions.push(*session_id);
427 }
428 }
429 }
430
431 let cleanup_count = expired_sessions.len();
433 for session_id in expired_sessions {
434 let _ = self.terminate_session(session_id, "Session expired".to_string());
435
436 let mut sessions = self.sessions.lock().unwrap();
438 sessions.remove(&session_id);
439 }
440
441 Ok(cleanup_count)
442 }
443
444 pub fn get_statistics(&self) -> SessionManagerStats {
446 let sessions = self.sessions.lock().unwrap();
447 let connections = self.connections.lock().unwrap();
448
449 let mut active_count = 0;
450 let mut pending_count = 0;
451 let mut total_bytes_sent = 0;
452 let mut total_bytes_received = 0;
453
454 for session in sessions.values() {
455 match session.state {
456 SessionState::Active => active_count += 1,
457 SessionState::Pending => pending_count += 1,
458 _ => {}
459 }
460 total_bytes_sent += session.bytes_sent;
461 total_bytes_received += session.bytes_received;
462 }
463
464 SessionManagerStats {
465 total_sessions: sessions.len(),
466 active_sessions: active_count,
467 pending_sessions: pending_count,
468 total_connections: connections.len(),
469 total_bytes_sent,
470 total_bytes_received,
471 }
472 }
473}
474
/// Point-in-time aggregate returned by `SessionManager::get_statistics`.
#[derive(Debug, Clone)]
pub struct SessionManagerStats {
    /// Entries in the session table, in any state (terminated ones included until cleanup).
    pub total_sessions: usize,
    /// Sessions in the `Active` state.
    pub active_sessions: usize,
    /// Sessions in the `Pending` state.
    pub pending_sessions: usize,
    /// Entries in the connection table.
    pub total_connections: usize,
    /// Sum of `bytes_sent` over all tracked sessions.
    pub total_bytes_sent: u64,
    /// Sum of `bytes_received` over all tracked sessions.
    pub total_bytes_received: u64,
}
485
#[cfg(test)]
mod tests {
    use super::*;
    use crate::relay::AuthToken;
    use ed25519_dalek::SigningKey;
    use rand::rngs::OsRng;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback address used as the client address throughout these tests.
    fn test_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
    }

    #[test]
    fn test_session_manager_creation() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config);

        let stats = manager.get_statistics();
        assert_eq!(stats.total_sessions, 0);
        assert_eq!(stats.active_sessions, 0);
    }

    #[test]
    fn test_trusted_key_management() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config);

        let signing_key = SigningKey::generate(&mut OsRng);
        let verifying_key = signing_key.verifying_key();
        let addr = test_addr();

        manager.add_trusted_key(addr, verifying_key);

        // With a trusted key registered, the request is accepted.
        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
        let result = manager.request_session(addr, vec![1, 2, 3], auth_token);
        assert!(result.is_ok());

        manager.remove_trusted_key(&addr);

        // After the key is removed, the same client can no longer authenticate.
        let auth_token2 = AuthToken::new(1024, 300, &signing_key).unwrap();
        let result2 = manager.request_session(addr, vec![4, 5, 6], auth_token2);
        assert!(result2.is_err());
    }

    #[test]
    fn test_session_request_and_establishment() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config);

        let signing_key = SigningKey::generate(&mut OsRng);
        let verifying_key = signing_key.verifying_key();
        let addr = test_addr();

        manager.add_trusted_key(addr, verifying_key);

        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
        let session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Pending);
        assert_eq!(session.client_addr, addr);

        assert!(manager.establish_session(session_id).is_ok());

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Active);
    }

    #[test]
    fn test_session_limit() {
        let mut config = SessionConfig::default();
        config.max_sessions = 2;
        let (manager, _event_rx) = SessionManager::new(config);

        let signing_key = SigningKey::generate(&mut OsRng);
        let verifying_key = signing_key.verifying_key();
        let addr = test_addr();

        manager.add_trusted_key(addr, verifying_key);

        // Fill the table up to the limit.
        for i in 0..2 {
            let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
            let result = manager.request_session(addr, vec![i], auth_token);
            assert!(result.is_ok());
        }

        // One more request must be rejected.
        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
        let result = manager.request_session(addr, vec![3], auth_token);
        assert!(result.is_err());
    }

    #[test]
    fn test_session_termination() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config);

        let signing_key = SigningKey::generate(&mut OsRng);
        let verifying_key = signing_key.verifying_key();
        let addr = test_addr();

        manager.add_trusted_key(addr, verifying_key);

        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
        let session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();
        manager.establish_session(session_id).unwrap();

        let reason = "Test termination".to_string();
        assert!(manager.terminate_session(session_id, reason).is_ok());

        // The entry remains queryable after termination.
        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Terminated);
    }

    #[test]
    fn test_data_forwarding() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config);

        let signing_key = SigningKey::generate(&mut OsRng);
        let verifying_key = signing_key.verifying_key();
        let addr = test_addr();

        manager.add_trusted_key(addr, verifying_key);

        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
        let session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();
        manager.establish_session(session_id).unwrap();

        let data = vec![1, 2, 3, 4, 5];
        assert!(
            manager
                .forward_data(session_id, data.clone(), ForwardDirection::ClientToPeer)
                .is_ok()
        );
        assert!(
            manager
                .forward_data(session_id, data, ForwardDirection::PeerToClient)
                .is_ok()
        );

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.bytes_sent, 5);
        assert_eq!(session.bytes_received, 5);
    }

    #[test]
    fn test_session_cleanup() {
        let mut config = SessionConfig::default();
        config.cleanup_interval = Duration::from_millis(1);
        let (manager, _event_rx) = SessionManager::new(config);

        let signing_key = SigningKey::generate(&mut OsRng);
        let verifying_key = signing_key.verifying_key();
        let addr = test_addr();

        manager.add_trusted_key(addr, verifying_key);

        let auth_token = AuthToken::new(1024, 1, &signing_key).unwrap();
        let session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();

        assert_eq!(manager.session_count(), 1);

        // The token's smallest useful timeout is one second, far longer than this test
        // should sleep. Shorten the stored timeout directly so the session expires after a
        // few milliseconds.
        manager
            .sessions
            .lock()
            .unwrap()
            .get_mut(&session_id)
            .unwrap()
            .timeout = Duration::from_millis(1);

        // Wait past both the session timeout and the cleanup interval.
        std::thread::sleep(Duration::from_millis(5));

        let cleanup_count = manager.cleanup_expired_sessions().unwrap();
        assert_eq!(cleanup_count, 1);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_session_id_generation() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config);

        let signing_key = SigningKey::generate(&mut OsRng);
        let verifying_key = signing_key.verifying_key();
        let addr = test_addr();

        manager.add_trusted_key(addr, verifying_key);

        let mut session_ids = Vec::new();
        for i in 0..10 {
            let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
            let session_id = manager.request_session(addr, vec![i], auth_token).unwrap();
            session_ids.push(session_id);
        }

        // Zero is reserved and never handed out.
        for id in &session_ids {
            assert!(*id != 0);
        }

        let unique_ids: std::collections::HashSet<_> = session_ids.iter().collect();
        assert_eq!(unique_ids.len(), session_ids.len());
    }
}
707}