ant_quic/relay/
session_manager.rs

1//! Session management for relay connections with complete state machine.
2
3use crate::relay::{
4    AuthToken, RelayAuthenticator, RelayConnection, RelayConnectionConfig, RelayError, RelayResult,
5};
6use ed25519_dalek::VerifyingKey;
7use std::collections::HashMap;
8use std::net::SocketAddr;
9use std::sync::{Arc, Mutex};
10use std::time::{Duration, Instant};
11use tokio::sync::mpsc;
12
/// Identifier the session manager assigns to each relay session.
///
/// The manager in this file starts numbering at 1 and never hands out 0.
pub type SessionId = u32;
15
/// Tunable limits and intervals for a session manager.
#[derive(Clone, Debug)]
pub struct SessionConfig {
    /// Upper bound on the number of sessions tracked at the same time.
    pub max_sessions: usize,
    /// Timeout applied to a session when nothing more specific is given.
    pub default_timeout: Duration,
    /// Minimum time between two expired-session sweeps.
    pub cleanup_interval: Duration,
    /// Bandwidth limit applied to a session when nothing more specific is given.
    pub default_bandwidth_limit: u64,
}

impl Default for SessionConfig {
    /// 100 sessions, a 5 minute timeout, a 30 second cleanup interval and a
    /// bandwidth limit of 1 MB/s.
    fn default() -> Self {
        let default_timeout = Duration::from_secs(5 * 60);
        let cleanup_interval = Duration::from_secs(30);

        Self {
            max_sessions: 100,
            default_timeout,
            cleanup_interval,
            default_bandwidth_limit: 1024 * 1024, // 1 MB/s
        }
    }
}
39
/// Lifecycle states a relay session can be in.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum SessionState {
    /// Requested, but not established yet.
    Pending,
    /// Established and forwarding data.
    Active,
    /// Graceful shutdown in progress.
    Terminating,
    /// Shut down.
    Terminated,
    /// Ended because of an error.
    Failed {
        /// Description of what went wrong.
        reason: String,
    },
}
54
55/// Information about a relay session
56#[derive(Debug, Clone)]
57pub struct RelaySessionInfo {
58    /// Session identifier
59    pub session_id: SessionId,
60    /// Client address
61    pub client_addr: SocketAddr,
62    /// Target peer connection ID
63    pub peer_connection_id: Vec<u8>,
64    /// Current session state
65    pub state: SessionState,
66    /// Session creation time
67    pub created_at: Instant,
68    /// Last activity time
69    pub last_activity: Instant,
70    /// Bandwidth limit
71    pub bandwidth_limit: u64,
72    /// Session timeout
73    pub timeout: Duration,
74    /// Bytes transferred
75    pub bytes_sent: u64,
76    pub bytes_received: u64,
77}
78
79/// Session manager for handling relay connections
80#[derive(Debug)]
81pub struct SessionManager {
82    /// Configuration
83    config: SessionConfig,
84    /// Active sessions
85    sessions: Arc<Mutex<HashMap<SessionId, RelaySessionInfo>>>,
86    /// Active connections
87    connections: Arc<Mutex<HashMap<SessionId, Arc<RelayConnection>>>>,
88    /// Authenticator for token verification
89    authenticator: RelayAuthenticator,
90    /// Trusted peer keys for authentication
91    trusted_keys: Arc<Mutex<HashMap<SocketAddr, VerifyingKey>>>,
92    /// Next session ID
93    next_session_id: Arc<Mutex<SessionId>>,
94    /// Event channels
95    event_sender: mpsc::UnboundedSender<SessionEvent>,
96    /// Last cleanup time
97    last_cleanup: Arc<Mutex<Instant>>,
98}
99
100/// Events generated by session management
101#[derive(Debug, Clone)]
102pub enum SessionEvent {
103    /// New session requested
104    SessionRequested {
105        session_id: SessionId,
106        client_addr: SocketAddr,
107        peer_connection_id: Vec<u8>,
108        auth_token: AuthToken,
109    },
110    /// Session established successfully
111    SessionEstablished {
112        session_id: SessionId,
113        client_addr: SocketAddr,
114    },
115    /// Session terminated
116    SessionTerminated {
117        session_id: SessionId,
118        reason: String,
119    },
120    /// Session failed
121    SessionFailed {
122        session_id: SessionId,
123        error: RelayError,
124    },
125    /// Data forwarded through session
126    DataForwarded {
127        session_id: SessionId,
128        bytes: usize,
129        direction: ForwardDirection,
130    },
131}
132
/// Which way a chunk of relayed data is travelling.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ForwardDirection {
    /// Client towards peer.
    ClientToPeer,
    /// Peer towards client.
    PeerToClient,
}
141
142impl SessionManager {
143    /// Create a new session manager
144    pub fn new(config: SessionConfig) -> (Self, mpsc::UnboundedReceiver<SessionEvent>) {
145        let (event_sender, event_receiver) = mpsc::unbounded_channel();
146
147        let manager = Self {
148            config,
149            sessions: Arc::new(Mutex::new(HashMap::new())),
150            connections: Arc::new(Mutex::new(HashMap::new())),
151            authenticator: RelayAuthenticator::new(),
152            trusted_keys: Arc::new(Mutex::new(HashMap::new())),
153            next_session_id: Arc::new(Mutex::new(1)),
154            event_sender,
155            last_cleanup: Arc::new(Mutex::new(Instant::now())),
156        };
157
158        (manager, event_receiver)
159    }
160
161    /// Add a trusted peer key for authentication
162    pub fn add_trusted_key(&self, addr: SocketAddr, key: VerifyingKey) {
163        let mut trusted_keys = self.trusted_keys.lock().unwrap();
164        trusted_keys.insert(addr, key);
165    }
166
167    /// Remove a trusted peer key
168    pub fn remove_trusted_key(&self, addr: &SocketAddr) {
169        let mut trusted_keys = self.trusted_keys.lock().unwrap();
170        trusted_keys.remove(addr);
171    }
172
173    /// Generate next session ID
174    fn next_session_id(&self) -> SessionId {
175        let mut next_id = self.next_session_id.lock().unwrap();
176        let id = *next_id;
177        *next_id = next_id.wrapping_add(1);
178        if *next_id == 0 {
179            *next_id = 1; // Skip 0 as invalid session ID
180        }
181        id
182    }
183
184    /// Request a new relay session
185    pub fn request_session(
186        &self,
187        client_addr: SocketAddr,
188        peer_connection_id: Vec<u8>,
189        auth_token: AuthToken,
190    ) -> RelayResult<SessionId> {
191        // Check session limit
192        {
193            let sessions = self.sessions.lock().unwrap();
194            if sessions.len() >= self.config.max_sessions {
195                return Err(RelayError::ResourceExhausted {
196                    resource_type: "sessions".to_string(),
197                    current_usage: sessions.len() as u64,
198                    limit: self.config.max_sessions as u64,
199                });
200            }
201        }
202
203        // Verify authentication token
204        let trusted_keys = self.trusted_keys.lock().unwrap();
205        let peer_key =
206            trusted_keys
207                .get(&client_addr)
208                .ok_or_else(|| RelayError::AuthenticationFailed {
209                    reason: format!("No trusted key for address {}", client_addr),
210                })?;
211
212        self.authenticator.verify_token(&auth_token, peer_key)?;
213
214        // Generate session ID
215        let session_id = self.next_session_id();
216
217        // Create session info
218        let now = Instant::now();
219        let session_info = RelaySessionInfo {
220            session_id,
221            client_addr,
222            peer_connection_id: peer_connection_id.clone(),
223            state: SessionState::Pending,
224            created_at: now,
225            last_activity: now,
226            bandwidth_limit: auth_token.bandwidth_limit as u64,
227            timeout: Duration::from_secs(auth_token.timeout_seconds as u64),
228            bytes_sent: 0,
229            bytes_received: 0,
230        };
231
232        // Store session
233        {
234            let mut sessions = self.sessions.lock().unwrap();
235            sessions.insert(session_id, session_info);
236        }
237
238        // Send event
239        let _ = self.event_sender.send(SessionEvent::SessionRequested {
240            session_id,
241            client_addr,
242            peer_connection_id,
243            auth_token,
244        });
245
246        Ok(session_id)
247    }
248
249    /// Establish a relay session
250    pub fn establish_session(&self, session_id: SessionId) -> RelayResult<()> {
251        let (client_addr, bandwidth_limit) = {
252            let mut sessions = self.sessions.lock().unwrap();
253            let session = sessions
254                .get_mut(&session_id)
255                .ok_or(RelayError::SessionError {
256                    session_id: Some(session_id),
257                    kind: crate::relay::error::SessionErrorKind::NotFound,
258                })?;
259
260            if session.state != SessionState::Pending {
261                return Err(RelayError::SessionError {
262                    session_id: Some(session_id),
263                    kind: crate::relay::error::SessionErrorKind::InvalidState {
264                        current_state: format!("{:?}", session.state),
265                        expected_state: "Pending".to_string(),
266                    },
267                });
268            }
269
270            session.state = SessionState::Active;
271            session.last_activity = Instant::now();
272
273            (session.client_addr, session.bandwidth_limit)
274        };
275
276        // Create relay connection
277        let (event_tx, _event_rx) = mpsc::unbounded_channel();
278        let (_action_tx, action_rx) = mpsc::unbounded_channel();
279
280        let mut connection_config = RelayConnectionConfig::default();
281        connection_config.bandwidth_limit = bandwidth_limit;
282
283        let connection = RelayConnection::new(
284            session_id,
285            client_addr,
286            connection_config,
287            event_tx,
288            action_rx,
289        );
290
291        // Store connection
292        {
293            let mut connections = self.connections.lock().unwrap();
294            connections.insert(session_id, Arc::new(connection));
295        }
296
297        // Send event
298        let _ = self.event_sender.send(SessionEvent::SessionEstablished {
299            session_id,
300            client_addr,
301        });
302
303        Ok(())
304    }
305
306    /// Terminate a relay session
307    pub fn terminate_session(&self, session_id: SessionId, reason: String) -> RelayResult<()> {
308        // Update session state
309        {
310            let mut sessions = self.sessions.lock().unwrap();
311            if let Some(session) = sessions.get_mut(&session_id) {
312                session.state = SessionState::Terminated;
313                session.last_activity = Instant::now();
314            }
315        }
316
317        // Terminate connection
318        {
319            let mut connections = self.connections.lock().unwrap();
320            if let Some(connection) = connections.remove(&session_id) {
321                let _ = connection.terminate(reason.clone());
322            }
323        }
324
325        // Send event
326        let _ = self
327            .event_sender
328            .send(SessionEvent::SessionTerminated { session_id, reason });
329
330        Ok(())
331    }
332
333    /// Forward data through a relay session
334    pub fn forward_data(
335        &self,
336        session_id: SessionId,
337        data: Vec<u8>,
338        direction: ForwardDirection,
339    ) -> RelayResult<()> {
340        let connection = {
341            let connections = self.connections.lock().unwrap();
342            connections
343                .get(&session_id)
344                .cloned()
345                .ok_or(RelayError::SessionError {
346                    session_id: Some(session_id),
347                    kind: crate::relay::error::SessionErrorKind::NotFound,
348                })?
349        };
350
351        // Forward data based on direction
352        match direction {
353            ForwardDirection::ClientToPeer => {
354                connection.send_data(data.clone())?;
355            }
356            ForwardDirection::PeerToClient => {
357                connection.receive_data(data.clone())?;
358            }
359        }
360
361        // Update session statistics
362        {
363            let mut sessions = self.sessions.lock().unwrap();
364            if let Some(session) = sessions.get_mut(&session_id) {
365                session.last_activity = Instant::now();
366                match direction {
367                    ForwardDirection::ClientToPeer => {
368                        session.bytes_sent += data.len() as u64;
369                    }
370                    ForwardDirection::PeerToClient => {
371                        session.bytes_received += data.len() as u64;
372                    }
373                }
374            }
375        }
376
377        // Send event
378        let _ = self.event_sender.send(SessionEvent::DataForwarded {
379            session_id,
380            bytes: data.len(),
381            direction,
382        });
383
384        Ok(())
385    }
386
387    /// Get session information
388    pub fn get_session(&self, session_id: SessionId) -> Option<RelaySessionInfo> {
389        let sessions = self.sessions.lock().unwrap();
390        sessions.get(&session_id).cloned()
391    }
392
393    /// List all active sessions
394    pub fn list_sessions(&self) -> Vec<RelaySessionInfo> {
395        let sessions = self.sessions.lock().unwrap();
396        sessions.values().cloned().collect()
397    }
398
399    /// Get session count
400    pub fn session_count(&self) -> usize {
401        let sessions = self.sessions.lock().unwrap();
402        sessions.len()
403    }
404
405    /// Clean up expired sessions
406    pub fn cleanup_expired_sessions(&self) -> RelayResult<usize> {
407        let mut last_cleanup = self.last_cleanup.lock().unwrap();
408        let now = Instant::now();
409
410        // Only cleanup if enough time has passed
411        if now.duration_since(*last_cleanup) < self.config.cleanup_interval {
412            return Ok(0);
413        }
414
415        *last_cleanup = now;
416        drop(last_cleanup);
417
418        let mut expired_sessions = Vec::new();
419
420        // Find expired sessions
421        {
422            let sessions = self.sessions.lock().unwrap();
423            for (session_id, session_info) in sessions.iter() {
424                let age = now.duration_since(session_info.last_activity);
425                if age > session_info.timeout {
426                    expired_sessions.push(*session_id);
427                }
428            }
429        }
430
431        // Remove expired sessions
432        let cleanup_count = expired_sessions.len();
433        for session_id in expired_sessions {
434            let _ = self.terminate_session(session_id, "Session expired".to_string());
435
436            // Remove from sessions map
437            let mut sessions = self.sessions.lock().unwrap();
438            sessions.remove(&session_id);
439        }
440
441        Ok(cleanup_count)
442    }
443
444    /// Get session manager statistics
445    pub fn get_statistics(&self) -> SessionManagerStats {
446        let sessions = self.sessions.lock().unwrap();
447        let connections = self.connections.lock().unwrap();
448
449        let mut active_count = 0;
450        let mut pending_count = 0;
451        let mut total_bytes_sent = 0;
452        let mut total_bytes_received = 0;
453
454        for session in sessions.values() {
455            match session.state {
456                SessionState::Active => active_count += 1,
457                SessionState::Pending => pending_count += 1,
458                _ => {}
459            }
460            total_bytes_sent += session.bytes_sent;
461            total_bytes_received += session.bytes_received;
462        }
463
464        SessionManagerStats {
465            total_sessions: sessions.len(),
466            active_sessions: active_count,
467            pending_sessions: pending_count,
468            total_connections: connections.len(),
469            total_bytes_sent,
470            total_bytes_received,
471        }
472    }
473}
474
/// Aggregate counters reported by the session manager.
#[derive(Clone, Debug)]
pub struct SessionManagerStats {
    /// Number of tracked sessions, whatever their state.
    pub total_sessions: usize,
    /// Number of sessions in the `Active` state.
    pub active_sessions: usize,
    /// Number of sessions in the `Pending` state.
    pub pending_sessions: usize,
    /// Number of connections currently held.
    pub total_connections: usize,
    /// Sum of `bytes_sent` over all tracked sessions.
    pub total_bytes_sent: u64,
    /// Sum of `bytes_received` over all tracked sessions.
    pub total_bytes_received: u64,
}
485
#[cfg(test)]
mod tests {
    use super::*;
    use crate::relay::AuthToken;
    use ed25519_dalek::SigningKey;
    use rand::rngs::OsRng;
    use std::net::{IpAddr, Ipv4Addr};

    fn test_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
    }

    /// A manager whose only trusted client is `addr`, plus the key needed to
    /// mint tokens that client can present.
    struct Fixture {
        manager: SessionManager,
        /// Held so the event channel stays open for the whole test.
        _events: mpsc::UnboundedReceiver<SessionEvent>,
        signing_key: SigningKey,
        addr: SocketAddr,
    }

    impl Fixture {
        fn new(config: SessionConfig) -> Self {
            let (manager, events) = SessionManager::new(config);
            let signing_key = SigningKey::generate(&mut OsRng);
            let addr = test_addr();
            manager.add_trusted_key(addr, signing_key.verifying_key());

            Self {
                manager,
                _events: events,
                signing_key,
                addr,
            }
        }

        /// Token with a 1024 bandwidth limit and a 300 second timeout.
        fn token(&self) -> AuthToken {
            AuthToken::new(1024, 300, &self.signing_key).unwrap()
        }

        /// Requests a session from the trusted address with a fresh token.
        fn request(&self, peer_connection_id: Vec<u8>) -> RelayResult<SessionId> {
            self.manager
                .request_session(self.addr, peer_connection_id, self.token())
        }

        /// Requests a session and establishes it.
        fn established(&self) -> SessionId {
            let session_id = self.request(vec![1, 2, 3]).unwrap();
            self.manager.establish_session(session_id).unwrap();
            session_id
        }
    }

    #[test]
    fn test_session_manager_creation() {
        let (manager, _event_rx) = SessionManager::new(SessionConfig::default());

        let stats = manager.get_statistics();
        assert_eq!(stats.total_sessions, 0);
        assert_eq!(stats.active_sessions, 0);
    }

    #[test]
    fn test_trusted_key_management() {
        let fx = Fixture::new(SessionConfig::default());

        // Should be able to create a session with trusted key
        assert!(fx.request(vec![1, 2, 3]).is_ok());

        // Remove trusted key
        fx.manager.remove_trusted_key(&fx.addr);

        // Should fail without trusted key
        let result = fx.request(vec![4, 5, 6]);
        assert!(matches!(
            result,
            Err(RelayError::AuthenticationFailed { .. })
        ));
    }

    #[test]
    fn test_session_request_and_establishment() {
        let fx = Fixture::new(SessionConfig::default());

        // Request session
        let session_id = fx.request(vec![1, 2, 3]).unwrap();

        // Check session exists and is pending
        let session = fx.manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Pending);
        assert_eq!(session.client_addr, fx.addr);

        // Establish session
        assert!(fx.manager.establish_session(session_id).is_ok());

        // Check session is now active
        let session = fx.manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Active);
    }

    #[test]
    fn test_establish_requires_pending_state() {
        let fx = Fixture::new(SessionConfig::default());
        let session_id = fx.established();

        // A session that is already active cannot be established again
        let result = fx.manager.establish_session(session_id);
        assert!(matches!(result, Err(RelayError::SessionError { .. })));

        // The failed attempt leaves the session active
        let session = fx.manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Active);
    }

    #[test]
    fn test_session_limit() {
        let fx = Fixture::new(SessionConfig {
            max_sessions: 2,
            ..SessionConfig::default()
        });

        // Create maximum sessions
        for i in 0..2 {
            assert!(fx.request(vec![i]).is_ok());
        }

        // Third session should fail
        let result = fx.request(vec![3]);
        assert!(matches!(result, Err(RelayError::ResourceExhausted { .. })));
        assert_eq!(fx.manager.session_count(), 2);
    }

    #[test]
    fn test_session_termination() {
        let fx = Fixture::new(SessionConfig::default());
        let session_id = fx.established();

        // Terminate session
        let reason = "Test termination".to_string();
        assert!(fx.manager.terminate_session(session_id, reason).is_ok());

        // Check session is terminated
        let session = fx.manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Terminated);
    }

    #[test]
    fn test_data_forwarding() {
        let fx = Fixture::new(SessionConfig::default());
        let session_id = fx.established();

        // Forward data
        let data = vec![1, 2, 3, 4, 5];
        assert!(fx
            .manager
            .forward_data(session_id, data.clone(), ForwardDirection::ClientToPeer)
            .is_ok());
        assert!(fx
            .manager
            .forward_data(session_id, data, ForwardDirection::PeerToClient)
            .is_ok());

        // Check statistics updated
        let session = fx.manager.get_session(session_id).unwrap();
        assert_eq!(session.bytes_sent, 5);
        assert_eq!(session.bytes_received, 5);
    }

    #[test]
    fn test_forward_data_unknown_session() {
        let fx = Fixture::new(SessionConfig::default());

        // A pending session has no connection yet, so nothing can be forwarded
        let session_id = fx.request(vec![1, 2, 3]).unwrap();
        let result = fx
            .manager
            .forward_data(session_id, vec![1], ForwardDirection::ClientToPeer);
        assert!(matches!(result, Err(RelayError::SessionError { .. })));
    }

    #[test]
    fn test_session_cleanup() {
        let fx = Fixture::new(SessionConfig {
            cleanup_interval: Duration::from_millis(1),
            ..SessionConfig::default()
        });

        // Create session with very short timeout
        let auth_token = AuthToken::new(1024, 1, &fx.signing_key).unwrap(); // 1 second timeout
        let _session_id = fx
            .manager
            .request_session(fx.addr, vec![1, 2, 3], auth_token)
            .unwrap();

        assert_eq!(fx.manager.session_count(), 1);

        // The cleanup interval has passed but the session timeout has not:
        // nothing may be removed yet.
        std::thread::sleep(Duration::from_millis(5));
        assert_eq!(fx.manager.cleanup_expired_sessions().unwrap(), 0);
        assert_eq!(fx.manager.session_count(), 1);

        // Wait for the session to expire. A session only counts as expired
        // once it has been idle for strictly longer than its timeout, so the
        // wait has to outlast the full second.
        std::thread::sleep(Duration::from_millis(1100));

        // Cleanup should remove expired session
        let cleanup_count = fx.manager.cleanup_expired_sessions().unwrap();
        assert_eq!(cleanup_count, 1);
        assert_eq!(fx.manager.session_count(), 0);
    }

    #[test]
    fn test_session_id_generation() {
        let fx = Fixture::new(SessionConfig::default());

        // Generate multiple session IDs
        let session_ids: Vec<SessionId> = (0..10).map(|i| fx.request(vec![i]).unwrap()).collect();

        // All IDs should be unique and non-zero
        assert!(session_ids.iter().all(|&id| id != 0));

        let unique_ids: std::collections::HashSet<_> = session_ids.iter().collect();
        assert_eq!(unique_ids.len(), session_ids.len());
    }
}