1use crate::relay::{
11 AuthToken, RelayAuthenticator, RelayConnection, RelayConnectionConfig, RelayError, RelayResult,
12};
13use ed25519_dalek::VerifyingKey;
14use std::collections::HashMap;
15use std::net::SocketAddr;
16use std::sync::{Arc, Mutex};
17use std::time::{Duration, Instant};
18use tokio::sync::mpsc;
19
/// Identifier assigned to a relay session by [`SessionManager`]; zero is never handed out.
pub type SessionId = u32;
22
/// Tunable limits and intervals for a [`SessionManager`].
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Maximum number of sessions tracked at the same time, in any state, until cleanup
    /// removes them.
    pub max_sessions: usize,
    /// Fallback session timeout.
    ///
    /// NOTE(review): `SessionManager` in this module takes each session's timeout from its
    /// auth token and does not read this field — confirm it is used elsewhere.
    pub default_timeout: Duration,
    /// Minimum time between two cleanup passes of `cleanup_expired_sessions`.
    pub cleanup_interval: Duration,
    /// Fallback bandwidth limit (units presumably bytes/s — confirm).
    ///
    /// NOTE(review): `SessionManager` in this module takes each session's bandwidth limit
    /// from its auth token and does not read this field — confirm it is used elsewhere.
    pub default_bandwidth_limit: u64,
}

impl Default for SessionConfig {
    /// 100 sessions, a 5-minute timeout, a 30-second cleanup interval and a
    /// `1024 * 1024` bandwidth limit.
    fn default() -> Self {
        const DEFAULT_BANDWIDTH_LIMIT: u64 = 1024 * 1024;

        let default_timeout = Duration::from_secs(5 * 60);
        let cleanup_interval = Duration::from_secs(30);

        Self {
            max_sessions: 100,
            default_timeout,
            cleanup_interval,
            default_bandwidth_limit: DEFAULT_BANDWIDTH_LIMIT,
        }
    }
}
46
/// Lifecycle stage of a relay session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// Requested and authenticated, waiting to be established.
    Pending,
    /// Established; a relay connection exists for the session.
    Active,
    /// Shutdown in progress. `SessionManager` in this module never assigns this state.
    Terminating,
    /// Terminated explicitly or by expiry; kept until cleanup removes the record.
    Terminated,
    /// The session failed. `SessionManager` in this module never assigns this state.
    Failed {
        /// Human-readable description of the failure.
        reason: String,
    },
}
64
65#[derive(Debug, Clone)]
67pub struct RelaySessionInfo {
68 pub session_id: SessionId,
70 pub client_addr: SocketAddr,
72 pub peer_connection_id: Vec<u8>,
74 pub state: SessionState,
76 pub created_at: Instant,
78 pub last_activity: Instant,
80 pub bandwidth_limit: u64,
82 pub timeout: Duration,
84 pub bytes_sent: u64,
86 pub bytes_received: u64,
88}
89
90#[derive(Debug)]
92pub struct SessionManager {
93 config: SessionConfig,
95 sessions: Arc<Mutex<HashMap<SessionId, RelaySessionInfo>>>,
97 connections: Arc<Mutex<HashMap<SessionId, Arc<RelayConnection>>>>,
99 authenticator: RelayAuthenticator,
101 trusted_keys: Arc<Mutex<HashMap<SocketAddr, VerifyingKey>>>,
103 next_session_id: Arc<Mutex<SessionId>>,
105 event_sender: mpsc::UnboundedSender<SessionEvent>,
107 last_cleanup: Arc<Mutex<Instant>>,
109}
110
111#[derive(Debug, Clone)]
113pub enum SessionEvent {
114 SessionRequested {
116 session_id: SessionId,
118 client_addr: SocketAddr,
120 peer_connection_id: Vec<u8>,
122 auth_token: AuthToken,
124 },
125 SessionEstablished {
127 session_id: SessionId,
129 client_addr: SocketAddr,
131 },
132 SessionTerminated {
134 session_id: SessionId,
136 reason: String,
138 },
139 SessionFailed {
141 session_id: SessionId,
143 error: RelayError,
145 },
146 DataForwarded {
148 session_id: SessionId,
150 bytes: usize,
152 direction: ForwardDirection,
154 },
155}
156
/// Direction in which a payload travels through a relay session.
///
/// The enum carries no data, so it derives `Copy` and `Hash`. It can be passed by value
/// freely and used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ForwardDirection {
    /// From the client towards its peer; tallied in a session's `bytes_sent`.
    ClientToPeer,
    /// From the peer back to the client; tallied in a session's `bytes_received`.
    PeerToClient,
}
165
166impl SessionManager {
167 pub fn new(config: SessionConfig) -> (Self, mpsc::UnboundedReceiver<SessionEvent>) {
169 let (event_sender, event_receiver) = mpsc::unbounded_channel();
170
171 let manager = Self {
172 config,
173 sessions: Arc::new(Mutex::new(HashMap::new())),
174 connections: Arc::new(Mutex::new(HashMap::new())),
175 authenticator: RelayAuthenticator::new(),
176 trusted_keys: Arc::new(Mutex::new(HashMap::new())),
177 next_session_id: Arc::new(Mutex::new(1)),
178 event_sender,
179 last_cleanup: Arc::new(Mutex::new(Instant::now())),
180 };
181
182 (manager, event_receiver)
183 }
184
185 #[allow(clippy::unwrap_used)]
187 pub fn add_trusted_key(&self, addr: SocketAddr, key: VerifyingKey) {
188 let mut trusted_keys = self.trusted_keys.lock().unwrap();
189 trusted_keys.insert(addr, key);
190 }
191
192 #[allow(clippy::unwrap_used)]
194 pub fn remove_trusted_key(&self, addr: &SocketAddr) {
195 let mut trusted_keys = self.trusted_keys.lock().unwrap();
196 trusted_keys.remove(addr);
197 }
198
199 #[allow(clippy::unwrap_used)]
201 fn next_session_id(&self) -> SessionId {
202 let mut next_id = self.next_session_id.lock().unwrap();
203 let id = *next_id;
204 *next_id = next_id.wrapping_add(1);
205 if *next_id == 0 {
206 *next_id = 1; }
208 id
209 }
210
211 #[allow(clippy::unwrap_used)]
213 pub fn request_session(
214 &self,
215 client_addr: SocketAddr,
216 peer_connection_id: Vec<u8>,
217 auth_token: AuthToken,
218 ) -> RelayResult<SessionId> {
219 {
221 let sessions = self.sessions.lock().unwrap();
222 if sessions.len() >= self.config.max_sessions {
223 return Err(RelayError::ResourceExhausted {
224 resource_type: "sessions".to_string(),
225 current_usage: sessions.len() as u64,
226 limit: self.config.max_sessions as u64,
227 });
228 }
229 }
230
231 let trusted_keys = self.trusted_keys.lock().unwrap();
233 let peer_key =
234 trusted_keys
235 .get(&client_addr)
236 .ok_or_else(|| RelayError::AuthenticationFailed {
237 reason: format!("No trusted key for address {}", client_addr),
238 })?;
239
240 self.authenticator.verify_token(&auth_token, peer_key)?;
241
242 let session_id = self.next_session_id();
244
245 let now = Instant::now();
247 let session_info = RelaySessionInfo {
248 session_id,
249 client_addr,
250 peer_connection_id: peer_connection_id.clone(),
251 state: SessionState::Pending,
252 created_at: now,
253 last_activity: now,
254 bandwidth_limit: auth_token.bandwidth_limit as u64,
255 timeout: Duration::from_secs(auth_token.timeout_seconds as u64),
256 bytes_sent: 0,
257 bytes_received: 0,
258 };
259
260 {
262 let mut sessions = self.sessions.lock().unwrap();
263 sessions.insert(session_id, session_info);
264 }
265
266 let _ = self.event_sender.send(SessionEvent::SessionRequested {
268 session_id,
269 client_addr,
270 peer_connection_id,
271 auth_token,
272 });
273
274 Ok(session_id)
275 }
276
277 #[allow(clippy::unwrap_used)]
279 pub fn establish_session(&self, session_id: SessionId) -> RelayResult<()> {
280 let (client_addr, bandwidth_limit) = {
281 let mut sessions = self.sessions.lock().unwrap();
282 let session = sessions
283 .get_mut(&session_id)
284 .ok_or(RelayError::SessionError {
285 session_id: Some(session_id),
286 kind: crate::relay::error::SessionErrorKind::NotFound,
287 })?;
288
289 if session.state != SessionState::Pending {
290 return Err(RelayError::SessionError {
291 session_id: Some(session_id),
292 kind: crate::relay::error::SessionErrorKind::InvalidState {
293 current_state: format!("{:?}", session.state),
294 expected_state: "Pending".to_string(),
295 },
296 });
297 }
298
299 session.state = SessionState::Active;
300 session.last_activity = Instant::now();
301
302 (session.client_addr, session.bandwidth_limit)
303 };
304
305 let (event_tx, _event_rx) = mpsc::unbounded_channel();
307 let (_action_tx, action_rx) = mpsc::unbounded_channel();
308
309 let mut connection_config = RelayConnectionConfig::default();
310 connection_config.bandwidth_limit = bandwidth_limit;
311
312 let connection = RelayConnection::new(
313 session_id,
314 client_addr,
315 connection_config,
316 event_tx,
317 action_rx,
318 );
319
320 {
322 let mut connections = self.connections.lock().unwrap();
323 connections.insert(session_id, Arc::new(connection));
324 }
325
326 let _ = self.event_sender.send(SessionEvent::SessionEstablished {
328 session_id,
329 client_addr,
330 });
331
332 Ok(())
333 }
334
335 #[allow(clippy::unwrap_used)]
337 pub fn terminate_session(&self, session_id: SessionId, reason: String) -> RelayResult<()> {
338 {
340 let mut sessions = self.sessions.lock().unwrap();
341 if let Some(session) = sessions.get_mut(&session_id) {
342 session.state = SessionState::Terminated;
343 session.last_activity = Instant::now();
344 }
345 }
346
347 {
349 let mut connections = self.connections.lock().unwrap();
350 if let Some(connection) = connections.remove(&session_id) {
351 let _ = connection.terminate(reason.clone());
352 }
353 }
354
355 let _ = self
357 .event_sender
358 .send(SessionEvent::SessionTerminated { session_id, reason });
359
360 Ok(())
361 }
362
363 #[allow(clippy::unwrap_used)]
365 pub fn forward_data(
366 &self,
367 session_id: SessionId,
368 data: Vec<u8>,
369 direction: ForwardDirection,
370 ) -> RelayResult<()> {
371 let connection = {
372 let connections = self.connections.lock().unwrap();
373 connections
374 .get(&session_id)
375 .cloned()
376 .ok_or(RelayError::SessionError {
377 session_id: Some(session_id),
378 kind: crate::relay::error::SessionErrorKind::NotFound,
379 })?
380 };
381
382 match direction {
384 ForwardDirection::ClientToPeer => {
385 connection.send_data(data.clone())?;
386 }
387 ForwardDirection::PeerToClient => {
388 connection.receive_data(data.clone())?;
389 }
390 }
391
392 {
394 let mut sessions = self.sessions.lock().unwrap();
395 if let Some(session) = sessions.get_mut(&session_id) {
396 session.last_activity = Instant::now();
397 match direction {
398 ForwardDirection::ClientToPeer => {
399 session.bytes_sent += data.len() as u64;
400 }
401 ForwardDirection::PeerToClient => {
402 session.bytes_received += data.len() as u64;
403 }
404 }
405 }
406 }
407
408 let _ = self.event_sender.send(SessionEvent::DataForwarded {
410 session_id,
411 bytes: data.len(),
412 direction,
413 });
414
415 Ok(())
416 }
417
418 #[allow(clippy::unwrap_used)]
420 pub fn get_session(&self, session_id: SessionId) -> Option<RelaySessionInfo> {
421 let sessions = self.sessions.lock().unwrap();
422 sessions.get(&session_id).cloned()
423 }
424
425 #[allow(clippy::unwrap_used)]
427 pub fn list_sessions(&self) -> Vec<RelaySessionInfo> {
428 let sessions = self.sessions.lock().unwrap();
429 sessions.values().cloned().collect()
430 }
431
432 #[allow(clippy::unwrap_used)]
434 pub fn session_count(&self) -> usize {
435 let sessions = self.sessions.lock().unwrap();
436 sessions.len()
437 }
438
439 #[allow(clippy::unwrap_used)]
441 pub fn cleanup_expired_sessions(&self) -> RelayResult<usize> {
442 let mut last_cleanup = self.last_cleanup.lock().unwrap();
443 let now = Instant::now();
444
445 if now.duration_since(*last_cleanup) < self.config.cleanup_interval {
447 return Ok(0);
448 }
449
450 *last_cleanup = now;
451 drop(last_cleanup);
452
453 let mut expired_sessions = Vec::new();
454
455 {
457 let sessions = self.sessions.lock().unwrap();
458 for (session_id, session_info) in sessions.iter() {
459 let age = now.duration_since(session_info.last_activity);
460 if age > session_info.timeout {
461 expired_sessions.push(*session_id);
462 }
463 }
464 }
465
466 let cleanup_count = expired_sessions.len();
468 for session_id in expired_sessions {
469 let _ = self.terminate_session(session_id, "Session expired".to_string());
470
471 let mut sessions = self.sessions.lock().unwrap();
473 sessions.remove(&session_id);
474 }
475
476 Ok(cleanup_count)
477 }
478
479 #[allow(clippy::unwrap_used)]
481 pub fn get_statistics(&self) -> SessionManagerStats {
482 let sessions = self.sessions.lock().unwrap();
483 let connections = self.connections.lock().unwrap();
484
485 let mut active_count = 0;
486 let mut pending_count = 0;
487 let mut total_bytes_sent = 0;
488 let mut total_bytes_received = 0;
489
490 for session in sessions.values() {
491 match session.state {
492 SessionState::Active => active_count += 1,
493 SessionState::Pending => pending_count += 1,
494 _ => {}
495 }
496 total_bytes_sent += session.bytes_sent;
497 total_bytes_received += session.bytes_received;
498 }
499
500 SessionManagerStats {
501 total_sessions: sessions.len(),
502 active_sessions: active_count,
503 pending_sessions: pending_count,
504 total_connections: connections.len(),
505 total_bytes_sent,
506 total_bytes_received,
507 }
508 }
509}
510
/// Point-in-time counters produced by [`SessionManager::get_statistics`].
#[derive(Debug, Clone)]
pub struct SessionManagerStats {
    /// Sessions currently tracked, in any state.
    pub total_sessions: usize,
    /// Tracked sessions in the `Active` state.
    pub active_sessions: usize,
    /// Tracked sessions in the `Pending` state.
    pub pending_sessions: usize,
    /// Relay connections currently held (established and not yet terminated).
    pub total_connections: usize,
    /// Sum of `bytes_sent` across tracked sessions.
    pub total_bytes_sent: u64,
    /// Sum of `bytes_received` across tracked sessions.
    pub total_bytes_received: u64,
}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::relay::AuthToken;
    use ed25519_dalek::SigningKey;
    use rand::rngs::OsRng;
    use std::net::Ipv4Addr;

    /// Loopback client address (127.0.0.1:8080) used by every test.
    fn test_addr() -> SocketAddr {
        SocketAddr::from((Ipv4Addr::LOCALHOST, 8080))
    }

    /// A manager that trusts a freshly generated key for `test_addr()`.
    struct Fixture {
        manager: SessionManager,
        // Held only so the event channel stays open for the duration of the test.
        _events: mpsc::UnboundedReceiver<SessionEvent>,
        signing_key: SigningKey,
        addr: SocketAddr,
    }

    impl Fixture {
        /// Builds the manager from `config` and registers the client's verifying key.
        fn new(config: SessionConfig) -> Self {
            let (manager, events) = SessionManager::new(config);
            let signing_key = SigningKey::generate(&mut OsRng);
            let addr = test_addr();
            manager.add_trusted_key(addr, signing_key.verifying_key());
            Self {
                manager,
                _events: events,
                signing_key,
                addr,
            }
        }

        /// Mints a token with the `(1024, 300)` parameters used throughout these tests.
        fn token(&self) -> AuthToken {
            AuthToken::new(1024, 300, &self.signing_key).unwrap()
        }

        /// Requests a session from the trusted address with a freshly minted token.
        fn request(&self, peer_connection_id: Vec<u8>) -> RelayResult<SessionId> {
            self.manager
                .request_session(self.addr, peer_connection_id, self.token())
        }

        /// Requests and establishes a session, returning its id.
        fn established(&self) -> SessionId {
            let session_id = self.request(vec![1, 2, 3]).unwrap();
            self.manager.establish_session(session_id).unwrap();
            session_id
        }
    }

    #[test]
    fn test_session_manager_creation() {
        let (manager, _events) = SessionManager::new(SessionConfig::default());

        let stats = manager.get_statistics();
        assert_eq!(stats.total_sessions, 0);
        assert_eq!(stats.active_sessions, 0);
    }

    #[test]
    fn test_trusted_key_management() {
        let fx = Fixture::new(SessionConfig::default());

        assert!(fx.request(vec![1, 2, 3]).is_ok());

        // Once the key is gone, the same client must be rejected.
        fx.manager.remove_trusted_key(&fx.addr);
        assert!(fx.request(vec![4, 5, 6]).is_err());
    }

    #[test]
    fn test_session_request_and_establishment() {
        let fx = Fixture::new(SessionConfig::default());

        let session_id = fx.request(vec![1, 2, 3]).unwrap();

        let pending = fx.manager.get_session(session_id).unwrap();
        assert_eq!(pending.state, SessionState::Pending);
        assert_eq!(pending.client_addr, fx.addr);

        assert!(fx.manager.establish_session(session_id).is_ok());

        let active = fx.manager.get_session(session_id).unwrap();
        assert_eq!(active.state, SessionState::Active);
    }

    #[test]
    fn test_session_limit() {
        let fx = Fixture::new(SessionConfig {
            max_sessions: 2,
            ..SessionConfig::default()
        });

        for i in 0..2 {
            assert!(fx.request(vec![i]).is_ok());
        }

        // A third concurrent session exceeds `max_sessions`.
        assert!(fx.request(vec![3]).is_err());
    }

    #[test]
    fn test_session_termination() {
        let fx = Fixture::new(SessionConfig::default());
        let session_id = fx.established();

        let reason = "Test termination".to_string();
        assert!(fx.manager.terminate_session(session_id, reason).is_ok());

        // The record remains queryable until cleanup removes it.
        let session = fx.manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Terminated);
    }

    #[test]
    fn test_data_forwarding() {
        let fx = Fixture::new(SessionConfig::default());
        let session_id = fx.established();

        let payload = vec![1, 2, 3, 4, 5];
        let outbound =
            fx.manager
                .forward_data(session_id, payload.clone(), ForwardDirection::ClientToPeer);
        assert!(outbound.is_ok());
        let inbound = fx
            .manager
            .forward_data(session_id, payload, ForwardDirection::PeerToClient);
        assert!(inbound.is_ok());

        let session = fx.manager.get_session(session_id).unwrap();
        assert_eq!(session.bytes_sent, 5);
        assert_eq!(session.bytes_received, 5);
    }

    #[test]
    fn test_session_cleanup() {
        let fx = Fixture::new(SessionConfig {
            cleanup_interval: Duration::from_millis(1),
            ..SessionConfig::default()
        });

        // A zero timeout lets the session expire as soon as any time has passed.
        let short_lived = AuthToken::new(1024, 0, &fx.signing_key).unwrap();
        fx.manager
            .request_session(fx.addr, vec![1, 2, 3], short_lived)
            .unwrap();

        assert_eq!(fx.manager.session_count(), 1);

        std::thread::sleep(Duration::from_millis(10));

        let cleanup_count = fx.manager.cleanup_expired_sessions().unwrap();
        assert!(cleanup_count > 0);
    }

    #[test]
    fn test_session_id_generation() {
        let fx = Fixture::new(SessionConfig::default());

        let session_ids: Vec<SessionId> = (0..10).map(|i| fx.request(vec![i]).unwrap()).collect();

        assert!(session_ids.iter().all(|&id| id != 0));

        let unique_ids: std::collections::HashSet<_> = session_ids.iter().collect();
        assert_eq!(unique_ids.len(), session_ids.len());
    }
}
749}