1use crate::crypto::pqc::types::MlDsaPublicKey;
11use crate::relay::{
12 AuthToken, RelayAuthenticator, RelayConnection, RelayConnectionConfig, RelayError, RelayResult,
13};
14use std::collections::HashMap;
15use std::net::SocketAddr;
16use std::sync::{Arc, Mutex};
17use std::time::{Duration, Instant};
18use tokio::sync::mpsc;
19
/// Numeric identifier that [`SessionManager`] assigns to each relay session.
///
/// Ids come from a wrapping counter that starts at 1; the value 0 is never issued.
pub type SessionId = u32;
22
/// Tunable limits and intervals for a [`SessionManager`].
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Maximum number of sessions tracked at once; `request_session` rejects
    /// requests once this many sessions are present.
    pub max_sessions: usize,
    /// Default idle timeout.
    ///
    /// NOTE(review): `SessionManager` in this module does not read this value;
    /// each session's timeout is taken from its auth token instead.
    pub default_timeout: Duration,
    /// Minimum time between two sweeps performed by `cleanup_expired_sessions`.
    pub cleanup_interval: Duration,
    /// Default bandwidth limit (units not stated here — presumably bytes per
    /// second; confirm).
    ///
    /// NOTE(review): not read by `SessionManager` in this module; each session's
    /// limit is taken from its auth token instead.
    pub default_bandwidth_limit: u64,
}

impl Default for SessionConfig {
    /// 100 sessions, a five-minute default timeout, a thirty-second cleanup
    /// interval and a default bandwidth limit of 1024 * 1024.
    fn default() -> Self {
        const SESSION_CAP: usize = 100;
        const BANDWIDTH: u64 = 1024 * 1024;
        let idle_timeout = Duration::from_secs(5 * 60);
        let sweep_every = Duration::from_secs(30);

        SessionConfig {
            max_sessions: SESSION_CAP,
            default_timeout: idle_timeout,
            cleanup_interval: sweep_every,
            default_bandwidth_limit: BANDWIDTH,
        }
    }
}
46
/// Lifecycle state of a relay session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// Requested and authenticated, but not yet established.
    Pending,
    /// Established; a relay connection has been created for the session.
    Active,
    /// Termination in progress.
    ///
    /// NOTE(review): not assigned anywhere in this module.
    Terminating,
    /// Terminated; the session's relay connection has been removed.
    Terminated,
    /// The session failed.
    ///
    /// NOTE(review): not assigned anywhere in this module.
    Failed {
        /// Human-readable failure reason.
        reason: String,
    },
}
64
/// Bookkeeping for a single relay session, as returned by
/// `SessionManager::get_session` and `SessionManager::list_sessions`.
#[derive(Debug, Clone)]
pub struct RelaySessionInfo {
    /// Identifier assigned by the manager.
    pub session_id: SessionId,
    /// Address of the client that requested the session.
    pub client_addr: SocketAddr,
    /// Opaque peer connection identifier supplied with the session request.
    pub peer_connection_id: Vec<u8>,
    /// Current lifecycle state.
    pub state: SessionState,
    /// When the session was requested.
    pub created_at: Instant,
    /// Last time the session was requested, established, terminated, or forwarded data.
    pub last_activity: Instant,
    /// Bandwidth limit copied from the auth token; also passed to the relay
    /// connection's config when the session is established.
    pub bandwidth_limit: u64,
    /// Idle timeout taken from the auth token; `cleanup_expired_sessions`
    /// removes the session once it has been idle for longer than this.
    pub timeout: Duration,
    /// Bytes forwarded in the `ClientToPeer` direction.
    pub bytes_sent: u64,
    /// Bytes forwarded in the `PeerToClient` direction.
    pub bytes_received: u64,
}
89
/// Tracks relay sessions, their relay connections, and the keys trusted to
/// authorize them.
///
/// Each piece of shared state sits behind its own `Arc<Mutex<_>>`. Lifecycle
/// changes are reported as [`SessionEvent`]s on an unbounded channel whose
/// receiver is returned from [`SessionManager::new`].
#[derive(Debug)]
pub struct SessionManager {
    /// Limits and intervals this manager was created with.
    config: SessionConfig,
    /// Session bookkeeping keyed by session id.
    sessions: Arc<Mutex<HashMap<SessionId, RelaySessionInfo>>>,
    /// Relay connections of established sessions; entries are added by
    /// `establish_session` and removed by `terminate_session`.
    connections: Arc<Mutex<HashMap<SessionId, Arc<RelayConnection>>>>,
    /// Verifies auth tokens presented to `request_session`.
    authenticator: RelayAuthenticator,
    /// ML-DSA public key trusted for each client address.
    trusted_keys: Arc<Mutex<HashMap<SocketAddr, MlDsaPublicKey>>>,
    /// Next candidate session id; 0 is skipped when the counter wraps.
    next_session_id: Arc<Mutex<SessionId>>,
    /// Sending half of the session event channel.
    event_sender: mpsc::UnboundedSender<SessionEvent>,
    /// When `cleanup_expired_sessions` last performed a sweep.
    last_cleanup: Arc<Mutex<Instant>>,
}
110
/// Notifications emitted by [`SessionManager`] on its event channel.
#[derive(Debug, Clone)]
pub enum SessionEvent {
    /// A session request passed authentication and was recorded as `Pending`.
    SessionRequested {
        /// Id assigned to the new session.
        session_id: SessionId,
        /// Address of the requesting client.
        client_addr: SocketAddr,
        /// Opaque peer connection identifier supplied with the request.
        peer_connection_id: Vec<u8>,
        /// The auth token that was verified for this request.
        auth_token: AuthToken,
    },
    /// A pending session was moved to `Active` and its relay connection created.
    SessionEstablished {
        /// Id of the established session.
        session_id: SessionId,
        /// Address of the session's client.
        client_addr: SocketAddr,
    },
    /// Emitted by `terminate_session`.
    SessionTerminated {
        /// Id of the terminated session.
        session_id: SessionId,
        /// Reason passed to `terminate_session` (e.g. "Session expired" from cleanup).
        reason: String,
    },
    /// A session failed.
    ///
    /// NOTE(review): not emitted anywhere in this module.
    SessionFailed {
        /// Id of the failed session.
        session_id: SessionId,
        /// The error that caused the failure.
        error: RelayError,
    },
    /// Data was forwarded through a session's relay connection.
    DataForwarded {
        /// Id of the session the data went through.
        session_id: SessionId,
        /// Payload length in bytes.
        bytes: usize,
        /// Direction the payload travelled.
        direction: ForwardDirection,
    },
}
156
/// Direction in which data is forwarded through a relay session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ForwardDirection {
    /// Handed to the connection's `send_data`; counted in `bytes_sent`.
    ClientToPeer,
    /// Handed to the connection's `receive_data`; counted in `bytes_received`.
    PeerToClient,
}
165
166impl SessionManager {
167 pub fn new(
172 config: SessionConfig,
173 ) -> RelayResult<(Self, mpsc::UnboundedReceiver<SessionEvent>)> {
174 let (event_sender, event_receiver) = mpsc::unbounded_channel();
175
176 let manager = Self {
177 config,
178 sessions: Arc::new(Mutex::new(HashMap::new())),
179 connections: Arc::new(Mutex::new(HashMap::new())),
180 authenticator: RelayAuthenticator::new()?,
181 trusted_keys: Arc::new(Mutex::new(HashMap::new())),
182 next_session_id: Arc::new(Mutex::new(1)),
183 event_sender,
184 last_cleanup: Arc::new(Mutex::new(Instant::now())),
185 };
186
187 Ok((manager, event_receiver))
188 }
189
190 #[allow(clippy::unwrap_used)]
192 pub fn add_trusted_key(&self, addr: SocketAddr, key: MlDsaPublicKey) {
193 let mut trusted_keys = self.trusted_keys.lock().unwrap();
194 trusted_keys.insert(addr, key);
195 }
196
197 #[allow(clippy::unwrap_used)]
199 pub fn remove_trusted_key(&self, addr: &SocketAddr) {
200 let mut trusted_keys = self.trusted_keys.lock().unwrap();
201 trusted_keys.remove(addr);
202 }
203
204 #[allow(clippy::unwrap_used)]
206 fn next_session_id(&self) -> SessionId {
207 let mut next_id = self.next_session_id.lock().unwrap();
208 let id = *next_id;
209 *next_id = next_id.wrapping_add(1);
210 if *next_id == 0 {
211 *next_id = 1; }
213 id
214 }
215
216 #[allow(clippy::unwrap_used)]
218 pub fn request_session(
219 &self,
220 client_addr: SocketAddr,
221 peer_connection_id: Vec<u8>,
222 auth_token: AuthToken,
223 ) -> RelayResult<SessionId> {
224 {
226 let sessions = self.sessions.lock().unwrap();
227 if sessions.len() >= self.config.max_sessions {
228 return Err(RelayError::ResourceExhausted {
229 resource_type: "sessions".to_string(),
230 current_usage: sessions.len() as u64,
231 limit: self.config.max_sessions as u64,
232 });
233 }
234 }
235
236 let trusted_keys = self.trusted_keys.lock().unwrap();
238 let peer_key =
239 trusted_keys
240 .get(&client_addr)
241 .ok_or_else(|| RelayError::AuthenticationFailed {
242 reason: format!("No trusted key for address {}", client_addr),
243 })?;
244
245 self.authenticator.verify_token(&auth_token, peer_key)?;
246
247 let session_id = self.next_session_id();
249
250 let now = Instant::now();
252 let session_info = RelaySessionInfo {
253 session_id,
254 client_addr,
255 peer_connection_id: peer_connection_id.clone(),
256 state: SessionState::Pending,
257 created_at: now,
258 last_activity: now,
259 bandwidth_limit: auth_token.bandwidth_limit as u64,
260 timeout: Duration::from_secs(auth_token.timeout_seconds as u64),
261 bytes_sent: 0,
262 bytes_received: 0,
263 };
264
265 {
267 let mut sessions = self.sessions.lock().unwrap();
268 sessions.insert(session_id, session_info);
269 }
270
271 let _ = self.event_sender.send(SessionEvent::SessionRequested {
273 session_id,
274 client_addr,
275 peer_connection_id,
276 auth_token,
277 });
278
279 Ok(session_id)
280 }
281
282 #[allow(clippy::unwrap_used)]
284 pub fn establish_session(&self, session_id: SessionId) -> RelayResult<()> {
285 let (client_addr, bandwidth_limit) = {
286 let mut sessions = self.sessions.lock().unwrap();
287 let session = sessions
288 .get_mut(&session_id)
289 .ok_or(RelayError::SessionError {
290 session_id: Some(session_id),
291 kind: crate::relay::error::SessionErrorKind::NotFound,
292 })?;
293
294 if session.state != SessionState::Pending {
295 return Err(RelayError::SessionError {
296 session_id: Some(session_id),
297 kind: crate::relay::error::SessionErrorKind::InvalidState {
298 current_state: format!("{:?}", session.state),
299 expected_state: "Pending".to_string(),
300 },
301 });
302 }
303
304 session.state = SessionState::Active;
305 session.last_activity = Instant::now();
306
307 (session.client_addr, session.bandwidth_limit)
308 };
309
310 let (event_tx, _event_rx) = mpsc::unbounded_channel();
312 let (_action_tx, action_rx) = mpsc::unbounded_channel();
313
314 let mut connection_config = RelayConnectionConfig::default();
315 connection_config.bandwidth_limit = bandwidth_limit;
316
317 let connection = RelayConnection::new(
318 session_id,
319 client_addr,
320 connection_config,
321 event_tx,
322 action_rx,
323 );
324
325 {
327 let mut connections = self.connections.lock().unwrap();
328 connections.insert(session_id, Arc::new(connection));
329 }
330
331 let _ = self.event_sender.send(SessionEvent::SessionEstablished {
333 session_id,
334 client_addr,
335 });
336
337 Ok(())
338 }
339
340 #[allow(clippy::unwrap_used)]
342 pub fn terminate_session(&self, session_id: SessionId, reason: String) -> RelayResult<()> {
343 {
345 let mut sessions = self.sessions.lock().unwrap();
346 if let Some(session) = sessions.get_mut(&session_id) {
347 session.state = SessionState::Terminated;
348 session.last_activity = Instant::now();
349 }
350 }
351
352 {
354 let mut connections = self.connections.lock().unwrap();
355 if let Some(connection) = connections.remove(&session_id) {
356 let _ = connection.terminate(reason.clone());
357 }
358 }
359
360 let _ = self
362 .event_sender
363 .send(SessionEvent::SessionTerminated { session_id, reason });
364
365 Ok(())
366 }
367
368 #[allow(clippy::unwrap_used)]
370 pub fn forward_data(
371 &self,
372 session_id: SessionId,
373 data: Vec<u8>,
374 direction: ForwardDirection,
375 ) -> RelayResult<()> {
376 let connection = {
377 let connections = self.connections.lock().unwrap();
378 connections
379 .get(&session_id)
380 .cloned()
381 .ok_or(RelayError::SessionError {
382 session_id: Some(session_id),
383 kind: crate::relay::error::SessionErrorKind::NotFound,
384 })?
385 };
386
387 match direction {
389 ForwardDirection::ClientToPeer => {
390 connection.send_data(data.clone())?;
391 }
392 ForwardDirection::PeerToClient => {
393 connection.receive_data(data.clone())?;
394 }
395 }
396
397 {
399 let mut sessions = self.sessions.lock().unwrap();
400 if let Some(session) = sessions.get_mut(&session_id) {
401 session.last_activity = Instant::now();
402 match direction {
403 ForwardDirection::ClientToPeer => {
404 session.bytes_sent += data.len() as u64;
405 }
406 ForwardDirection::PeerToClient => {
407 session.bytes_received += data.len() as u64;
408 }
409 }
410 }
411 }
412
413 let _ = self.event_sender.send(SessionEvent::DataForwarded {
415 session_id,
416 bytes: data.len(),
417 direction,
418 });
419
420 Ok(())
421 }
422
423 #[allow(clippy::unwrap_used)]
425 pub fn get_session(&self, session_id: SessionId) -> Option<RelaySessionInfo> {
426 let sessions = self.sessions.lock().unwrap();
427 sessions.get(&session_id).cloned()
428 }
429
430 #[allow(clippy::unwrap_used)]
432 pub fn list_sessions(&self) -> Vec<RelaySessionInfo> {
433 let sessions = self.sessions.lock().unwrap();
434 sessions.values().cloned().collect()
435 }
436
437 #[allow(clippy::unwrap_used)]
439 pub fn session_count(&self) -> usize {
440 let sessions = self.sessions.lock().unwrap();
441 sessions.len()
442 }
443
444 #[allow(clippy::unwrap_used)]
446 pub fn cleanup_expired_sessions(&self) -> RelayResult<usize> {
447 let mut last_cleanup = self.last_cleanup.lock().unwrap();
448 let now = Instant::now();
449
450 if now.duration_since(*last_cleanup) < self.config.cleanup_interval {
452 return Ok(0);
453 }
454
455 *last_cleanup = now;
456 drop(last_cleanup);
457
458 let mut expired_sessions = Vec::new();
459
460 {
462 let sessions = self.sessions.lock().unwrap();
463 for (session_id, session_info) in sessions.iter() {
464 let age = now.duration_since(session_info.last_activity);
465 if age > session_info.timeout {
466 expired_sessions.push(*session_id);
467 }
468 }
469 }
470
471 let cleanup_count = expired_sessions.len();
473 for session_id in expired_sessions {
474 let _ = self.terminate_session(session_id, "Session expired".to_string());
475
476 let mut sessions = self.sessions.lock().unwrap();
478 sessions.remove(&session_id);
479 }
480
481 Ok(cleanup_count)
482 }
483
484 #[allow(clippy::unwrap_used)]
486 pub fn get_statistics(&self) -> SessionManagerStats {
487 let sessions = self.sessions.lock().unwrap();
488 let connections = self.connections.lock().unwrap();
489
490 let mut active_count = 0;
491 let mut pending_count = 0;
492 let mut total_bytes_sent = 0;
493 let mut total_bytes_received = 0;
494
495 for session in sessions.values() {
496 match session.state {
497 SessionState::Active => active_count += 1,
498 SessionState::Pending => pending_count += 1,
499 _ => {}
500 }
501 total_bytes_sent += session.bytes_sent;
502 total_bytes_received += session.bytes_received;
503 }
504
505 SessionManagerStats {
506 total_sessions: sessions.len(),
507 active_sessions: active_count,
508 pending_sessions: pending_count,
509 total_connections: connections.len(),
510 total_bytes_sent,
511 total_bytes_received,
512 }
513 }
514}
515
/// Aggregate snapshot produced by `SessionManager::get_statistics`.
#[derive(Debug, Clone)]
pub struct SessionManagerStats {
    /// Number of tracked sessions in any state.
    pub total_sessions: usize,
    /// Number of sessions in the `Active` state.
    pub active_sessions: usize,
    /// Number of sessions in the `Pending` state.
    pub pending_sessions: usize,
    /// Number of relay connections currently held by the manager.
    pub total_connections: usize,
    /// Sum of `bytes_sent` over the tracked sessions.
    pub total_bytes_sent: u64,
    /// Sum of `bytes_received` over the tracked sessions.
    pub total_bytes_received: u64,
}
532
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::raw_public_keys::key_utils::generate_ml_dsa_keypair;
    use crate::relay::AuthToken;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback address used as the client address in every test.
    fn test_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
    }

    #[test]
    fn test_session_manager_creation() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let stats = manager.get_statistics();
        assert_eq!(stats.total_sessions, 0);
        assert_eq!(stats.active_sessions, 0);
        assert_eq!(stats.pending_sessions, 0);
        assert_eq!(stats.total_connections, 0);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_trusted_key_management() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key.clone());

        // A trusted key allows the request.
        let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
        let result = manager.request_session(addr, vec![1, 2, 3], auth_token);
        assert!(result.is_ok());

        manager.remove_trusted_key(&addr);

        // Once the key is removed, authentication must fail and nothing is recorded.
        let auth_token2 = AuthToken::new(1024, 300, &secret_key).unwrap();
        let result2 = manager.request_session(addr, vec![4, 5, 6], auth_token2);
        assert!(
            matches!(result2, Err(RelayError::AuthenticationFailed { .. })),
            "unexpected result: {:?}",
            result2
        );
        assert_eq!(manager.session_count(), 1);
    }

    #[test]
    fn test_session_request_and_establishment() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
        let session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Pending);
        assert_eq!(session.client_addr, addr);
        assert_eq!(session.peer_connection_id, vec![1, 2, 3]);

        assert!(manager.establish_session(session_id).is_ok());

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Active);

        let stats = manager.get_statistics();
        assert_eq!(stats.active_sessions, 1);
        assert_eq!(stats.pending_sessions, 0);
        assert_eq!(stats.total_connections, 1);

        // Establishing twice is rejected because the session is no longer Pending.
        let again = manager.establish_session(session_id);
        assert!(
            matches!(again, Err(RelayError::SessionError { .. })),
            "unexpected result: {:?}",
            again
        );
    }

    #[test]
    fn test_session_limit() {
        let mut config = SessionConfig::default();
        config.max_sessions = 2;
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        for i in 0..2 {
            let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
            let result = manager.request_session(addr, vec![i], auth_token);
            assert!(result.is_ok());
        }

        let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
        let result = manager.request_session(addr, vec![3], auth_token);
        assert!(
            matches!(result, Err(RelayError::ResourceExhausted { .. })),
            "unexpected result: {:?}",
            result
        );
        assert_eq!(manager.session_count(), 2);
    }

    #[test]
    fn test_session_termination() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
        let session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();
        manager.establish_session(session_id).unwrap();

        let reason = "Test termination".to_string();
        assert!(manager.terminate_session(session_id, reason).is_ok());

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Terminated);

        // The connection is gone, so forwarding must fail.
        assert_eq!(manager.get_statistics().total_connections, 0);
        assert!(manager
            .forward_data(session_id, vec![1], ForwardDirection::ClientToPeer)
            .is_err());

        // Terminating an unknown session is not an error.
        assert!(manager
            .terminate_session(session_id.wrapping_add(1000), "unknown".to_string())
            .is_ok());
    }

    #[test]
    fn test_data_forwarding() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
        let session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();
        manager.establish_session(session_id).unwrap();

        let data = vec![1, 2, 3, 4, 5];
        assert!(manager
            .forward_data(session_id, data.clone(), ForwardDirection::ClientToPeer)
            .is_ok());
        assert!(manager
            .forward_data(session_id, data, ForwardDirection::PeerToClient)
            .is_ok());

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.bytes_sent, 5);
        assert_eq!(session.bytes_received, 5);

        let stats = manager.get_statistics();
        assert_eq!(stats.total_bytes_sent, 5);
        assert_eq!(stats.total_bytes_received, 5);

        // Forwarding through an unknown session fails.
        assert!(manager
            .forward_data(session_id.wrapping_add(1000), vec![1], ForwardDirection::ClientToPeer)
            .is_err());
    }

    #[test]
    fn test_session_cleanup() {
        let mut config = SessionConfig::default();
        config.cleanup_interval = Duration::from_millis(1);
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        // A zero-second timeout makes the session expire as soon as any time passes.
        let auth_token = AuthToken::new(1024, 0, &secret_key).unwrap();
        let session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();

        assert_eq!(manager.session_count(), 1);

        std::thread::sleep(Duration::from_millis(10));

        let cleanup_count = manager.cleanup_expired_sessions().unwrap();
        assert_eq!(cleanup_count, 1);
        assert_eq!(manager.session_count(), 0);
        assert!(manager.get_session(session_id).is_none());
    }

    #[test]
    fn test_session_id_generation() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        let mut session_ids = Vec::new();
        for i in 0..10 {
            let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
            let session_id = manager.request_session(addr, vec![i], auth_token).unwrap();
            session_ids.push(session_id);
        }

        // 0 is never issued as a session id.
        assert!(session_ids.iter().all(|&id| id != 0));

        let unique_ids: std::collections::HashSet<_> = session_ids.iter().collect();
        assert_eq!(unique_ids.len(), session_ids.len());
    }
}