1use crate::relay::{
4 RelayConnection, RelayConnectionConfig, RelayError, RelayResult,
5 AuthToken, RelayAuthenticator,
6};
7use ed25519_dalek::VerifyingKey;
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12use tokio::sync::mpsc;
13
/// Numeric identifier handed out by the session manager for each relay session.
///
/// Ids are drawn from a wrapping counter that never yields `0`.
pub type SessionId = u32;
16
/// Tunables for a `SessionManager`.
///
/// Derives `PartialEq`/`Eq` so configurations can be compared directly
/// (e.g. in tests or when detecting configuration changes).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionConfig {
    /// Maximum number of concurrent sessions; further requests fail with `ResourceExhausted`.
    pub max_sessions: usize,
    /// Default session timeout.
    ///
    /// NOTE(review): not read by `SessionManager` in this module — per-session timeouts
    /// come from the auth token. Confirm whether it should be used as a fallback.
    pub default_timeout: Duration,
    /// Minimum time between two scans performed by `cleanup_expired_sessions`.
    pub cleanup_interval: Duration,
    /// Default bandwidth limit.
    ///
    /// NOTE(review): not read by `SessionManager` in this module — per-session limits
    /// come from the auth token. Confirm whether it should be used as a fallback.
    pub default_bandwidth_limit: u64,
}
29
30impl Default for SessionConfig {
31 fn default() -> Self {
32 Self {
33 max_sessions: 100,
34 default_timeout: Duration::from_secs(300), cleanup_interval: Duration::from_secs(30), default_bandwidth_limit: 1048576, }
38 }
39}
40
/// Lifecycle stage of a relay session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// Request authenticated and recorded; no relay connection exists yet.
    Pending,
    /// Session established; a relay connection is registered for it.
    Active,
    /// Shutdown has begun but is not complete.
    Terminating,
    /// Session has been shut down.
    Terminated,
    /// Session ended because of an error.
    Failed {
        /// Human-readable description of the failure.
        reason: String,
    },
}
55
56#[derive(Debug, Clone)]
58pub struct RelaySessionInfo {
59 pub session_id: SessionId,
61 pub client_addr: SocketAddr,
63 pub peer_connection_id: Vec<u8>,
65 pub state: SessionState,
67 pub created_at: Instant,
69 pub last_activity: Instant,
71 pub bandwidth_limit: u64,
73 pub timeout: Duration,
75 pub bytes_sent: u64,
77 pub bytes_received: u64,
78}
79
80#[derive(Debug)]
82pub struct SessionManager {
83 config: SessionConfig,
85 sessions: Arc<Mutex<HashMap<SessionId, RelaySessionInfo>>>,
87 connections: Arc<Mutex<HashMap<SessionId, Arc<RelayConnection>>>>,
89 authenticator: RelayAuthenticator,
91 trusted_keys: Arc<Mutex<HashMap<SocketAddr, VerifyingKey>>>,
93 next_session_id: Arc<Mutex<SessionId>>,
95 event_sender: mpsc::UnboundedSender<SessionEvent>,
97 last_cleanup: Arc<Mutex<Instant>>,
99}
100
101#[derive(Debug, Clone)]
103pub enum SessionEvent {
104 SessionRequested {
106 session_id: SessionId,
107 client_addr: SocketAddr,
108 peer_connection_id: Vec<u8>,
109 auth_token: AuthToken,
110 },
111 SessionEstablished {
113 session_id: SessionId,
114 client_addr: SocketAddr,
115 },
116 SessionTerminated {
118 session_id: SessionId,
119 reason: String,
120 },
121 SessionFailed {
123 session_id: SessionId,
124 error: RelayError,
125 },
126 DataForwarded {
128 session_id: SessionId,
129 bytes: usize,
130 direction: ForwardDirection,
131 },
132}
133
/// Direction in which a payload is relayed.
///
/// Fieldless, so it is `Copy` (passed by value without clones) and `Hash`
/// (usable as a map/set key, e.g. for per-direction counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ForwardDirection {
    /// From the relay client towards the remote peer.
    ClientToPeer,
    /// From the remote peer back to the relay client.
    PeerToClient,
}
142
143impl SessionManager {
144 pub fn new(config: SessionConfig) -> (Self, mpsc::UnboundedReceiver<SessionEvent>) {
146 let (event_sender, event_receiver) = mpsc::unbounded_channel();
147
148 let manager = Self {
149 config,
150 sessions: Arc::new(Mutex::new(HashMap::new())),
151 connections: Arc::new(Mutex::new(HashMap::new())),
152 authenticator: RelayAuthenticator::new(),
153 trusted_keys: Arc::new(Mutex::new(HashMap::new())),
154 next_session_id: Arc::new(Mutex::new(1)),
155 event_sender,
156 last_cleanup: Arc::new(Mutex::new(Instant::now())),
157 };
158
159 (manager, event_receiver)
160 }
161
162 pub fn add_trusted_key(&self, addr: SocketAddr, key: VerifyingKey) {
164 let mut trusted_keys = self.trusted_keys.lock().unwrap();
165 trusted_keys.insert(addr, key);
166 }
167
168 pub fn remove_trusted_key(&self, addr: &SocketAddr) {
170 let mut trusted_keys = self.trusted_keys.lock().unwrap();
171 trusted_keys.remove(addr);
172 }
173
174 fn next_session_id(&self) -> SessionId {
176 let mut next_id = self.next_session_id.lock().unwrap();
177 let id = *next_id;
178 *next_id = next_id.wrapping_add(1);
179 if *next_id == 0 {
180 *next_id = 1; }
182 id
183 }
184
185 pub fn request_session(
187 &self,
188 client_addr: SocketAddr,
189 peer_connection_id: Vec<u8>,
190 auth_token: AuthToken,
191 ) -> RelayResult<SessionId> {
192 {
194 let sessions = self.sessions.lock().unwrap();
195 if sessions.len() >= self.config.max_sessions {
196 return Err(RelayError::ResourceExhausted {
197 resource_type: "sessions".to_string(),
198 current_usage: sessions.len() as u64,
199 limit: self.config.max_sessions as u64,
200 });
201 }
202 }
203
204 let trusted_keys = self.trusted_keys.lock().unwrap();
206 let peer_key = trusted_keys.get(&client_addr)
207 .ok_or_else(|| RelayError::AuthenticationFailed {
208 reason: format!("No trusted key for address {}", client_addr),
209 })?;
210
211 self.authenticator.verify_token(&auth_token, peer_key)?;
212
213 let session_id = self.next_session_id();
215
216 let now = Instant::now();
218 let session_info = RelaySessionInfo {
219 session_id,
220 client_addr,
221 peer_connection_id: peer_connection_id.clone(),
222 state: SessionState::Pending,
223 created_at: now,
224 last_activity: now,
225 bandwidth_limit: auth_token.bandwidth_limit as u64,
226 timeout: Duration::from_secs(auth_token.timeout_seconds as u64),
227 bytes_sent: 0,
228 bytes_received: 0,
229 };
230
231 {
233 let mut sessions = self.sessions.lock().unwrap();
234 sessions.insert(session_id, session_info);
235 }
236
237 let _ = self.event_sender.send(SessionEvent::SessionRequested {
239 session_id,
240 client_addr,
241 peer_connection_id,
242 auth_token,
243 });
244
245 Ok(session_id)
246 }
247
248 pub fn establish_session(&self, session_id: SessionId) -> RelayResult<()> {
250 let (client_addr, bandwidth_limit) = {
251 let mut sessions = self.sessions.lock().unwrap();
252 let session = sessions.get_mut(&session_id)
253 .ok_or_else(|| RelayError::SessionError {
254 session_id: Some(session_id),
255 kind: crate::relay::error::SessionErrorKind::NotFound,
256 })?;
257
258 if session.state != SessionState::Pending {
259 return Err(RelayError::SessionError {
260 session_id: Some(session_id),
261 kind: crate::relay::error::SessionErrorKind::InvalidState {
262 current_state: format!("{:?}", session.state),
263 expected_state: "Pending".to_string(),
264 },
265 });
266 }
267
268 session.state = SessionState::Active;
269 session.last_activity = Instant::now();
270
271 (session.client_addr, session.bandwidth_limit)
272 };
273
274 let (event_tx, _event_rx) = mpsc::unbounded_channel();
276 let (_action_tx, action_rx) = mpsc::unbounded_channel();
277
278 let mut connection_config = RelayConnectionConfig::default();
279 connection_config.bandwidth_limit = bandwidth_limit;
280
281 let connection = RelayConnection::new(
282 session_id,
283 client_addr,
284 connection_config,
285 event_tx,
286 action_rx,
287 );
288
289 {
291 let mut connections = self.connections.lock().unwrap();
292 connections.insert(session_id, Arc::new(connection));
293 }
294
295 let _ = self.event_sender.send(SessionEvent::SessionEstablished {
297 session_id,
298 client_addr,
299 });
300
301 Ok(())
302 }
303
304 pub fn terminate_session(&self, session_id: SessionId, reason: String) -> RelayResult<()> {
306 {
308 let mut sessions = self.sessions.lock().unwrap();
309 if let Some(session) = sessions.get_mut(&session_id) {
310 session.state = SessionState::Terminated;
311 session.last_activity = Instant::now();
312 }
313 }
314
315 {
317 let mut connections = self.connections.lock().unwrap();
318 if let Some(connection) = connections.remove(&session_id) {
319 let _ = connection.terminate(reason.clone());
320 }
321 }
322
323 let _ = self.event_sender.send(SessionEvent::SessionTerminated {
325 session_id,
326 reason,
327 });
328
329 Ok(())
330 }
331
332 pub fn forward_data(
334 &self,
335 session_id: SessionId,
336 data: Vec<u8>,
337 direction: ForwardDirection,
338 ) -> RelayResult<()> {
339 let connection = {
340 let connections = self.connections.lock().unwrap();
341 connections.get(&session_id).cloned()
342 .ok_or_else(|| RelayError::SessionError {
343 session_id: Some(session_id),
344 kind: crate::relay::error::SessionErrorKind::NotFound,
345 })?
346 };
347
348 match direction {
350 ForwardDirection::ClientToPeer => {
351 connection.send_data(data.clone())?;
352 }
353 ForwardDirection::PeerToClient => {
354 connection.receive_data(data.clone())?;
355 }
356 }
357
358 {
360 let mut sessions = self.sessions.lock().unwrap();
361 if let Some(session) = sessions.get_mut(&session_id) {
362 session.last_activity = Instant::now();
363 match direction {
364 ForwardDirection::ClientToPeer => {
365 session.bytes_sent += data.len() as u64;
366 }
367 ForwardDirection::PeerToClient => {
368 session.bytes_received += data.len() as u64;
369 }
370 }
371 }
372 }
373
374 let _ = self.event_sender.send(SessionEvent::DataForwarded {
376 session_id,
377 bytes: data.len(),
378 direction,
379 });
380
381 Ok(())
382 }
383
384 pub fn get_session(&self, session_id: SessionId) -> Option<RelaySessionInfo> {
386 let sessions = self.sessions.lock().unwrap();
387 sessions.get(&session_id).cloned()
388 }
389
390 pub fn list_sessions(&self) -> Vec<RelaySessionInfo> {
392 let sessions = self.sessions.lock().unwrap();
393 sessions.values().cloned().collect()
394 }
395
396 pub fn session_count(&self) -> usize {
398 let sessions = self.sessions.lock().unwrap();
399 sessions.len()
400 }
401
402 pub fn cleanup_expired_sessions(&self) -> RelayResult<usize> {
404 let mut last_cleanup = self.last_cleanup.lock().unwrap();
405 let now = Instant::now();
406
407 if now.duration_since(*last_cleanup) < self.config.cleanup_interval {
409 return Ok(0);
410 }
411
412 *last_cleanup = now;
413 drop(last_cleanup);
414
415 let mut expired_sessions = Vec::new();
416
417 {
419 let sessions = self.sessions.lock().unwrap();
420 for (session_id, session_info) in sessions.iter() {
421 let age = now.duration_since(session_info.last_activity);
422 if age > session_info.timeout {
423 expired_sessions.push(*session_id);
424 }
425 }
426 }
427
428 let cleanup_count = expired_sessions.len();
430 for session_id in expired_sessions {
431 let _ = self.terminate_session(session_id, "Session expired".to_string());
432
433 let mut sessions = self.sessions.lock().unwrap();
435 sessions.remove(&session_id);
436 }
437
438 Ok(cleanup_count)
439 }
440
441 pub fn get_statistics(&self) -> SessionManagerStats {
443 let sessions = self.sessions.lock().unwrap();
444 let connections = self.connections.lock().unwrap();
445
446 let mut active_count = 0;
447 let mut pending_count = 0;
448 let mut total_bytes_sent = 0;
449 let mut total_bytes_received = 0;
450
451 for session in sessions.values() {
452 match session.state {
453 SessionState::Active => active_count += 1,
454 SessionState::Pending => pending_count += 1,
455 _ => {}
456 }
457 total_bytes_sent += session.bytes_sent;
458 total_bytes_received += session.bytes_received;
459 }
460
461 SessionManagerStats {
462 total_sessions: sessions.len(),
463 active_sessions: active_count,
464 pending_sessions: pending_count,
465 total_connections: connections.len(),
466 total_bytes_sent,
467 total_bytes_received,
468 }
469 }
470}
471
/// Point-in-time aggregate view of a session manager.
///
/// Derives `Default` (all zeroes) and `PartialEq`/`Eq` so snapshots can be built and
/// compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SessionManagerStats {
    /// Number of session records held, including finished ones not yet reaped.
    pub total_sessions: usize,
    /// Sessions in the `Active` state.
    pub active_sessions: usize,
    /// Sessions in the `Pending` state.
    pub pending_sessions: usize,
    /// Number of registered relay connections.
    pub total_connections: usize,
    /// Sum of `bytes_sent` over all held sessions.
    pub total_bytes_sent: u64,
    /// Sum of `bytes_received` over all held sessions.
    pub total_bytes_received: u64,
}
482
#[cfg(test)]
mod tests {
    use super::*;
    use crate::relay::AuthToken;
    use ed25519_dalek::SigningKey;
    use rand::rngs::OsRng;
    use std::collections::HashSet;
    use std::net::{IpAddr, Ipv4Addr};

    fn test_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
    }

    /// Creates a manager that trusts a freshly generated signing key for `test_addr()`.
    fn trusted_manager(
        config: SessionConfig,
    ) -> (SessionManager, mpsc::UnboundedReceiver<SessionEvent>, SigningKey) {
        let (manager, event_rx) = SessionManager::new(config);
        let signing_key = SigningKey::generate(&mut OsRng);
        manager.add_trusted_key(test_addr(), signing_key.verifying_key());
        (manager, event_rx, signing_key)
    }

    #[test]
    fn test_session_manager_creation() {
        let (manager, _event_rx) = SessionManager::new(SessionConfig::default());

        let stats = manager.get_statistics();
        assert_eq!(stats.total_sessions, 0);
        assert_eq!(stats.active_sessions, 0);
    }

    #[test]
    fn test_trusted_key_management() {
        let (manager, _event_rx, signing_key) = trusted_manager(SessionConfig::default());
        let addr = test_addr();

        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
        assert!(manager.request_session(addr, vec![1, 2, 3], auth_token).is_ok());

        // Once the key is removed, tokens from the same signer are no longer accepted.
        manager.remove_trusted_key(&addr);
        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
        assert!(manager.request_session(addr, vec![4, 5, 6], auth_token).is_err());
    }

    #[test]
    fn test_session_request_and_establishment() {
        let (manager, mut event_rx, signing_key) = trusted_manager(SessionConfig::default());
        let addr = test_addr();

        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
        let session_id = manager.request_session(addr, vec![1, 2, 3], auth_token).unwrap();

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Pending);
        assert_eq!(session.client_addr, addr);
        assert!(matches!(
            event_rx.try_recv(),
            Ok(SessionEvent::SessionRequested { session_id: id, .. }) if id == session_id
        ));

        assert!(manager.establish_session(session_id).is_ok());

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Active);
        assert!(matches!(
            event_rx.try_recv(),
            Ok(SessionEvent::SessionEstablished { session_id: id, .. }) if id == session_id
        ));

        // Only a pending session can be established.
        assert!(manager.establish_session(session_id).is_err());
    }

    #[test]
    fn test_session_limit() {
        let mut config = SessionConfig::default();
        config.max_sessions = 2;
        let (manager, _event_rx, signing_key) = trusted_manager(config);
        let addr = test_addr();

        for i in 0..2 {
            let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
            assert!(manager.request_session(addr, vec![i], auth_token).is_ok());
        }

        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
        assert!(manager.request_session(addr, vec![3], auth_token).is_err());
    }

    #[test]
    fn test_session_termination() {
        let (manager, mut event_rx, signing_key) = trusted_manager(SessionConfig::default());

        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
        let session_id = manager
            .request_session(test_addr(), vec![1, 2, 3], auth_token)
            .unwrap();
        manager.establish_session(session_id).unwrap();

        let reason = "Test termination".to_string();
        assert!(manager.terminate_session(session_id, reason).is_ok());

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Terminated);
        // The relay connection is torn down together with the session.
        assert_eq!(manager.get_statistics().total_connections, 0);

        let mut last_event = None;
        while let Ok(event) = event_rx.try_recv() {
            last_event = Some(event);
        }
        assert!(matches!(
            last_event,
            Some(SessionEvent::SessionTerminated { session_id: id, ref reason })
                if id == session_id && reason.as_str() == "Test termination"
        ));
    }

    #[test]
    fn test_data_forwarding() {
        let (manager, _event_rx, signing_key) = trusted_manager(SessionConfig::default());

        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
        let session_id = manager
            .request_session(test_addr(), vec![1, 2, 3], auth_token)
            .unwrap();
        manager.establish_session(session_id).unwrap();

        let data = vec![1, 2, 3, 4, 5];
        assert!(manager
            .forward_data(session_id, data.clone(), ForwardDirection::ClientToPeer)
            .is_ok());
        assert!(manager
            .forward_data(session_id, data, ForwardDirection::PeerToClient)
            .is_ok());

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.bytes_sent, 5);
        assert_eq!(session.bytes_received, 5);
    }

    #[test]
    fn test_session_cleanup() {
        let mut config = SessionConfig::default();
        config.cleanup_interval = Duration::from_millis(1);
        let (manager, _event_rx, signing_key) = trusted_manager(config);

        // The token's 1-second timeout becomes the session's idle timeout.
        let auth_token = AuthToken::new(1024, 1, &signing_key).unwrap();
        let session_id = manager
            .request_session(test_addr(), vec![1, 2, 3], auth_token)
            .unwrap();
        assert_eq!(manager.session_count(), 1);

        // Passing the 1 ms cleanup interval is not enough: the session expires only after it
        // has been idle for longer than its 1 s timeout.
        std::thread::sleep(Duration::from_millis(1_100));

        assert_eq!(manager.cleanup_expired_sessions().unwrap(), 1);
        assert_eq!(manager.session_count(), 0);
        assert!(manager.get_session(session_id).is_none());
    }

    #[test]
    fn test_session_id_generation() {
        let (manager, _event_rx, signing_key) = trusted_manager(SessionConfig::default());
        let addr = test_addr();

        let session_ids: Vec<SessionId> = (0..10)
            .map(|i| {
                let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
                manager.request_session(addr, vec![i], auth_token).unwrap()
            })
            .collect();

        assert!(session_ids.iter().all(|&id| id != 0));

        let unique_ids: HashSet<_> = session_ids.iter().collect();
        assert_eq!(unique_ids.len(), session_ids.len());
    }
}