ant_quic/relay/
session_manager.rs

1//! Session management for relay connections with complete state machine.
2
3use crate::relay::{
4    RelayConnection, RelayConnectionConfig, RelayError, RelayResult,
5    AuthToken, RelayAuthenticator,
6};
7use ed25519_dalek::VerifyingKey;
8use std::collections::HashMap;
9use std::net::SocketAddr;
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12use tokio::sync::mpsc;
13
/// Unique session identifier.
///
/// `SessionManager` starts its counter at 1 and skips 0 when the counter
/// wraps, so 0 is never issued as a session ID.
pub type SessionId = u32;
16
/// Tunable limits and timings for a session manager.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Upper bound on the number of sessions tracked at the same time
    pub max_sessions: usize,
    /// Timeout applied to a session by default
    pub default_timeout: Duration,
    /// Minimum time between two expired-session sweeps
    pub cleanup_interval: Duration,
    /// Bandwidth limit applied to a session by default
    pub default_bandwidth_limit: u64,
}

impl Default for SessionConfig {
    /// Defaults: 100 sessions, 5-minute timeout, 30-second cleanup interval
    /// and a 1 MB/s bandwidth limit.
    fn default() -> Self {
        let default_timeout = Duration::from_secs(5 * 60);
        let cleanup_interval = Duration::from_secs(30);

        SessionConfig {
            max_sessions: 100,
            default_timeout,
            cleanup_interval,
            default_bandwidth_limit: 1024 * 1024,
        }
    }
}
40
/// Session state in the relay state machine.
///
/// Within this file a session is created as `Pending`, moved to `Active` by
/// `establish_session` and to `Terminated` by `terminate_session`.
/// `Terminating` and `Failed` are not assigned anywhere in the code visible
/// here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// Session requested but not yet established
    Pending,
    /// Session active and forwarding data
    Active,
    /// Session terminating gracefully
    Terminating,
    /// Session terminated
    Terminated,
    /// Session failed due to error; `reason` carries a description of the failure
    Failed { reason: String },
}
55
/// Information about a relay session.
///
/// Callers of `get_session` / `list_sessions` receive clones of this record,
/// so a returned value is a snapshot and does not track later changes.
#[derive(Debug, Clone)]
pub struct RelaySessionInfo {
    /// Session identifier
    pub session_id: SessionId,
    /// Client address
    pub client_addr: SocketAddr,
    /// Target peer connection ID
    pub peer_connection_id: Vec<u8>,
    /// Current session state
    pub state: SessionState,
    /// Session creation time
    pub created_at: Instant,
    /// Last activity time; refreshed on establish, terminate and every forwarded payload
    pub last_activity: Instant,
    /// Bandwidth limit, taken from the auth token when the session is requested
    pub bandwidth_limit: u64,
    /// Session timeout, taken from the auth token's `timeout_seconds`; a session
    /// idle for longer than this is removed by `cleanup_expired_sessions`
    pub timeout: Duration,
    /// Bytes forwarded in the client-to-peer direction
    pub bytes_sent: u64,
    /// Bytes forwarded in the peer-to-client direction
    pub bytes_received: u64,
}
79
/// Session manager for handling relay connections.
///
/// All mutable state sits behind `Arc<Mutex<..>>`, so the public methods take
/// `&self`.
#[derive(Debug)]
pub struct SessionManager {
    /// Configuration
    config: SessionConfig,
    /// Every tracked session keyed by ID, in any state; entries are removed
    /// only by `cleanup_expired_sessions`
    sessions: Arc<Mutex<HashMap<SessionId, RelaySessionInfo>>>,
    /// Connections of established sessions; inserted by `establish_session`
    /// and removed by `terminate_session`
    connections: Arc<Mutex<HashMap<SessionId, Arc<RelayConnection>>>>,
    /// Authenticator for token verification
    authenticator: RelayAuthenticator,
    /// Trusted peer keys for authentication, keyed by client address
    trusted_keys: Arc<Mutex<HashMap<SocketAddr, VerifyingKey>>>,
    /// Next session ID to hand out (starts at 1, never 0)
    next_session_id: Arc<Mutex<SessionId>>,
    /// Sending half of the event channel whose receiver is returned by `new`
    event_sender: mpsc::UnboundedSender<SessionEvent>,
    /// Time of the most recent cleanup sweep that actually ran
    last_cleanup: Arc<Mutex<Instant>>,
}
100
/// Events generated by session management.
///
/// Events are pushed on the unbounded channel created in
/// `SessionManager::new`; send failures are ignored by the manager.
/// `SessionFailed` is not emitted anywhere in the code visible in this file.
#[derive(Debug, Clone)]
pub enum SessionEvent {
    /// New session requested (emitted by `request_session` after the session is stored)
    SessionRequested {
        session_id: SessionId,
        client_addr: SocketAddr,
        peer_connection_id: Vec<u8>,
        auth_token: AuthToken,
    },
    /// Session established successfully
    SessionEstablished {
        session_id: SessionId,
        client_addr: SocketAddr,
    },
    /// Session terminated; `reason` is the text passed to `terminate_session`
    SessionTerminated {
        session_id: SessionId,
        reason: String,
    },
    /// Session failed
    SessionFailed {
        session_id: SessionId,
        error: RelayError,
    },
    /// Data forwarded through session; `bytes` is the payload length
    DataForwarded {
        session_id: SessionId,
        bytes: usize,
        direction: ForwardDirection,
    },
}
133
/// Direction of data forwarding.
///
/// The enum has no payload, so it is `Copy`: callers can pass it by value and
/// keep using it without an explicit `clone()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ForwardDirection {
    /// From client to peer
    ClientToPeer,
    /// From peer to client
    PeerToClient,
}
142
143impl SessionManager {
144    /// Create a new session manager
145    pub fn new(config: SessionConfig) -> (Self, mpsc::UnboundedReceiver<SessionEvent>) {
146        let (event_sender, event_receiver) = mpsc::unbounded_channel();
147        
148        let manager = Self {
149            config,
150            sessions: Arc::new(Mutex::new(HashMap::new())),
151            connections: Arc::new(Mutex::new(HashMap::new())),
152            authenticator: RelayAuthenticator::new(),
153            trusted_keys: Arc::new(Mutex::new(HashMap::new())),
154            next_session_id: Arc::new(Mutex::new(1)),
155            event_sender,
156            last_cleanup: Arc::new(Mutex::new(Instant::now())),
157        };
158        
159        (manager, event_receiver)
160    }
161
162    /// Add a trusted peer key for authentication
163    pub fn add_trusted_key(&self, addr: SocketAddr, key: VerifyingKey) {
164        let mut trusted_keys = self.trusted_keys.lock().unwrap();
165        trusted_keys.insert(addr, key);
166    }
167
168    /// Remove a trusted peer key
169    pub fn remove_trusted_key(&self, addr: &SocketAddr) {
170        let mut trusted_keys = self.trusted_keys.lock().unwrap();
171        trusted_keys.remove(addr);
172    }
173
174    /// Generate next session ID
175    fn next_session_id(&self) -> SessionId {
176        let mut next_id = self.next_session_id.lock().unwrap();
177        let id = *next_id;
178        *next_id = next_id.wrapping_add(1);
179        if *next_id == 0 {
180            *next_id = 1; // Skip 0 as invalid session ID
181        }
182        id
183    }
184
185    /// Request a new relay session
186    pub fn request_session(
187        &self,
188        client_addr: SocketAddr,
189        peer_connection_id: Vec<u8>,
190        auth_token: AuthToken,
191    ) -> RelayResult<SessionId> {
192        // Check session limit
193        {
194            let sessions = self.sessions.lock().unwrap();
195            if sessions.len() >= self.config.max_sessions {
196                return Err(RelayError::ResourceExhausted {
197                    resource_type: "sessions".to_string(),
198                    current_usage: sessions.len() as u64,
199                    limit: self.config.max_sessions as u64,
200                });
201            }
202        }
203
204        // Verify authentication token
205        let trusted_keys = self.trusted_keys.lock().unwrap();
206        let peer_key = trusted_keys.get(&client_addr)
207            .ok_or_else(|| RelayError::AuthenticationFailed {
208                reason: format!("No trusted key for address {}", client_addr),
209            })?;
210
211        self.authenticator.verify_token(&auth_token, peer_key)?;
212
213        // Generate session ID
214        let session_id = self.next_session_id();
215
216        // Create session info
217        let now = Instant::now();
218        let session_info = RelaySessionInfo {
219            session_id,
220            client_addr,
221            peer_connection_id: peer_connection_id.clone(),
222            state: SessionState::Pending,
223            created_at: now,
224            last_activity: now,
225            bandwidth_limit: auth_token.bandwidth_limit as u64,
226            timeout: Duration::from_secs(auth_token.timeout_seconds as u64),
227            bytes_sent: 0,
228            bytes_received: 0,
229        };
230
231        // Store session
232        {
233            let mut sessions = self.sessions.lock().unwrap();
234            sessions.insert(session_id, session_info);
235        }
236
237        // Send event
238        let _ = self.event_sender.send(SessionEvent::SessionRequested {
239            session_id,
240            client_addr,
241            peer_connection_id,
242            auth_token,
243        });
244
245        Ok(session_id)
246    }
247
248    /// Establish a relay session
249    pub fn establish_session(&self, session_id: SessionId) -> RelayResult<()> {
250        let (client_addr, bandwidth_limit) = {
251            let mut sessions = self.sessions.lock().unwrap();
252            let session = sessions.get_mut(&session_id)
253                .ok_or_else(|| RelayError::SessionError {
254                    session_id: Some(session_id),
255                    kind: crate::relay::error::SessionErrorKind::NotFound,
256                })?;
257
258            if session.state != SessionState::Pending {
259                return Err(RelayError::SessionError {
260                    session_id: Some(session_id),
261                    kind: crate::relay::error::SessionErrorKind::InvalidState {
262                        current_state: format!("{:?}", session.state),
263                        expected_state: "Pending".to_string(),
264                    },
265                });
266            }
267
268            session.state = SessionState::Active;
269            session.last_activity = Instant::now();
270            
271            (session.client_addr, session.bandwidth_limit)
272        };
273
274        // Create relay connection
275        let (event_tx, _event_rx) = mpsc::unbounded_channel();
276        let (_action_tx, action_rx) = mpsc::unbounded_channel();
277        
278        let mut connection_config = RelayConnectionConfig::default();
279        connection_config.bandwidth_limit = bandwidth_limit;
280        
281        let connection = RelayConnection::new(
282            session_id,
283            client_addr,
284            connection_config,
285            event_tx,
286            action_rx,
287        );
288
289        // Store connection
290        {
291            let mut connections = self.connections.lock().unwrap();
292            connections.insert(session_id, Arc::new(connection));
293        }
294
295        // Send event
296        let _ = self.event_sender.send(SessionEvent::SessionEstablished {
297            session_id,
298            client_addr,
299        });
300
301        Ok(())
302    }
303
304    /// Terminate a relay session
305    pub fn terminate_session(&self, session_id: SessionId, reason: String) -> RelayResult<()> {
306        // Update session state
307        {
308            let mut sessions = self.sessions.lock().unwrap();
309            if let Some(session) = sessions.get_mut(&session_id) {
310                session.state = SessionState::Terminated;
311                session.last_activity = Instant::now();
312            }
313        }
314
315        // Terminate connection
316        {
317            let mut connections = self.connections.lock().unwrap();
318            if let Some(connection) = connections.remove(&session_id) {
319                let _ = connection.terminate(reason.clone());
320            }
321        }
322
323        // Send event
324        let _ = self.event_sender.send(SessionEvent::SessionTerminated {
325            session_id,
326            reason,
327        });
328
329        Ok(())
330    }
331
332    /// Forward data through a relay session
333    pub fn forward_data(
334        &self,
335        session_id: SessionId,
336        data: Vec<u8>,
337        direction: ForwardDirection,
338    ) -> RelayResult<()> {
339        let connection = {
340            let connections = self.connections.lock().unwrap();
341            connections.get(&session_id).cloned()
342                .ok_or_else(|| RelayError::SessionError {
343                    session_id: Some(session_id),
344                    kind: crate::relay::error::SessionErrorKind::NotFound,
345                })?
346        };
347
348        // Forward data based on direction
349        match direction {
350            ForwardDirection::ClientToPeer => {
351                connection.send_data(data.clone())?;
352            }
353            ForwardDirection::PeerToClient => {
354                connection.receive_data(data.clone())?;
355            }
356        }
357
358        // Update session statistics
359        {
360            let mut sessions = self.sessions.lock().unwrap();
361            if let Some(session) = sessions.get_mut(&session_id) {
362                session.last_activity = Instant::now();
363                match direction {
364                    ForwardDirection::ClientToPeer => {
365                        session.bytes_sent += data.len() as u64;
366                    }
367                    ForwardDirection::PeerToClient => {
368                        session.bytes_received += data.len() as u64;
369                    }
370                }
371            }
372        }
373
374        // Send event
375        let _ = self.event_sender.send(SessionEvent::DataForwarded {
376            session_id,
377            bytes: data.len(),
378            direction,
379        });
380
381        Ok(())
382    }
383
384    /// Get session information
385    pub fn get_session(&self, session_id: SessionId) -> Option<RelaySessionInfo> {
386        let sessions = self.sessions.lock().unwrap();
387        sessions.get(&session_id).cloned()
388    }
389
390    /// List all active sessions
391    pub fn list_sessions(&self) -> Vec<RelaySessionInfo> {
392        let sessions = self.sessions.lock().unwrap();
393        sessions.values().cloned().collect()
394    }
395
396    /// Get session count
397    pub fn session_count(&self) -> usize {
398        let sessions = self.sessions.lock().unwrap();
399        sessions.len()
400    }
401
402    /// Clean up expired sessions
403    pub fn cleanup_expired_sessions(&self) -> RelayResult<usize> {
404        let mut last_cleanup = self.last_cleanup.lock().unwrap();
405        let now = Instant::now();
406        
407        // Only cleanup if enough time has passed
408        if now.duration_since(*last_cleanup) < self.config.cleanup_interval {
409            return Ok(0);
410        }
411        
412        *last_cleanup = now;
413        drop(last_cleanup);
414
415        let mut expired_sessions = Vec::new();
416        
417        // Find expired sessions
418        {
419            let sessions = self.sessions.lock().unwrap();
420            for (session_id, session_info) in sessions.iter() {
421                let age = now.duration_since(session_info.last_activity);
422                if age > session_info.timeout {
423                    expired_sessions.push(*session_id);
424                }
425            }
426        }
427
428        // Remove expired sessions
429        let cleanup_count = expired_sessions.len();
430        for session_id in expired_sessions {
431            let _ = self.terminate_session(session_id, "Session expired".to_string());
432            
433            // Remove from sessions map
434            let mut sessions = self.sessions.lock().unwrap();
435            sessions.remove(&session_id);
436        }
437
438        Ok(cleanup_count)
439    }
440
441    /// Get session manager statistics
442    pub fn get_statistics(&self) -> SessionManagerStats {
443        let sessions = self.sessions.lock().unwrap();
444        let connections = self.connections.lock().unwrap();
445        
446        let mut active_count = 0;
447        let mut pending_count = 0;
448        let mut total_bytes_sent = 0;
449        let mut total_bytes_received = 0;
450        
451        for session in sessions.values() {
452            match session.state {
453                SessionState::Active => active_count += 1,
454                SessionState::Pending => pending_count += 1,
455                _ => {}
456            }
457            total_bytes_sent += session.bytes_sent;
458            total_bytes_received += session.bytes_received;
459        }
460        
461        SessionManagerStats {
462            total_sessions: sessions.len(),
463            active_sessions: active_count,
464            pending_sessions: pending_count,
465            total_connections: connections.len(),
466            total_bytes_sent,
467            total_bytes_received,
468        }
469    }
470}
471
/// Session manager statistics, as produced by `SessionManager::get_statistics`.
#[derive(Debug, Clone)]
pub struct SessionManagerStats {
    /// Number of tracked sessions in any state
    pub total_sessions: usize,
    /// Number of sessions in the `Active` state
    pub active_sessions: usize,
    /// Number of sessions in the `Pending` state
    pub pending_sessions: usize,
    /// Number of registered relay connections
    pub total_connections: usize,
    /// Sum of `bytes_sent` over all tracked sessions
    pub total_bytes_sent: u64,
    /// Sum of `bytes_received` over all tracked sessions
    pub total_bytes_received: u64,
}
482
483#[cfg(test)]
484mod tests {
485    use super::*;
486    use crate::relay::AuthToken;
487    use ed25519_dalek::SigningKey;
488    use rand::rngs::OsRng;
489    use std::net::{IpAddr, Ipv4Addr};
490
491    fn test_addr() -> SocketAddr {
492        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
493    }
494
495    #[test]
496    fn test_session_manager_creation() {
497        let config = SessionConfig::default();
498        let (manager, _event_rx) = SessionManager::new(config);
499        
500        let stats = manager.get_statistics();
501        assert_eq!(stats.total_sessions, 0);
502        assert_eq!(stats.active_sessions, 0);
503    }
504
505    #[test]
506    fn test_trusted_key_management() {
507        let config = SessionConfig::default();
508        let (manager, _event_rx) = SessionManager::new(config);
509        
510        let signing_key = SigningKey::generate(&mut OsRng);
511        let verifying_key = signing_key.verifying_key();
512        let addr = test_addr();
513        
514        manager.add_trusted_key(addr, verifying_key);
515        
516        // Should be able to create a session with trusted key
517        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
518        let result = manager.request_session(addr, vec![1, 2, 3], auth_token);
519        assert!(result.is_ok());
520        
521        // Remove trusted key
522        manager.remove_trusted_key(&addr);
523        
524        // Should fail without trusted key
525        let auth_token2 = AuthToken::new(1024, 300, &signing_key).unwrap();
526        let result2 = manager.request_session(addr, vec![4, 5, 6], auth_token2);
527        assert!(result2.is_err());
528    }
529
530    #[test]
531    fn test_session_request_and_establishment() {
532        let config = SessionConfig::default();
533        let (manager, mut event_rx) = SessionManager::new(config);
534        
535        let signing_key = SigningKey::generate(&mut OsRng);
536        let verifying_key = signing_key.verifying_key();
537        let addr = test_addr();
538        
539        manager.add_trusted_key(addr, verifying_key);
540        
541        // Request session
542        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
543        let session_id = manager.request_session(addr, vec![1, 2, 3], auth_token).unwrap();
544        
545        // Check session exists and is pending
546        let session = manager.get_session(session_id).unwrap();
547        assert_eq!(session.state, SessionState::Pending);
548        assert_eq!(session.client_addr, addr);
549        
550        // Establish session
551        assert!(manager.establish_session(session_id).is_ok());
552        
553        // Check session is now active
554        let session = manager.get_session(session_id).unwrap();
555        assert_eq!(session.state, SessionState::Active);
556    }
557
558    #[test]
559    fn test_session_limit() {
560        let mut config = SessionConfig::default();
561        config.max_sessions = 2;
562        let (manager, _event_rx) = SessionManager::new(config);
563        
564        let signing_key = SigningKey::generate(&mut OsRng);
565        let verifying_key = signing_key.verifying_key();
566        let addr = test_addr();
567        
568        manager.add_trusted_key(addr, verifying_key);
569        
570        // Create maximum sessions
571        for i in 0..2 {
572            let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
573            let result = manager.request_session(addr, vec![i], auth_token);
574            assert!(result.is_ok());
575        }
576        
577        // Third session should fail
578        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
579        let result = manager.request_session(addr, vec![3], auth_token);
580        assert!(result.is_err());
581    }
582
583    #[test]
584    fn test_session_termination() {
585        let config = SessionConfig::default();
586        let (manager, mut event_rx) = SessionManager::new(config);
587        
588        let signing_key = SigningKey::generate(&mut OsRng);
589        let verifying_key = signing_key.verifying_key();
590        let addr = test_addr();
591        
592        manager.add_trusted_key(addr, verifying_key);
593        
594        // Create and establish session
595        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
596        let session_id = manager.request_session(addr, vec![1, 2, 3], auth_token).unwrap();
597        manager.establish_session(session_id).unwrap();
598        
599        // Terminate session
600        let reason = "Test termination".to_string();
601        assert!(manager.terminate_session(session_id, reason).is_ok());
602        
603        // Check session is terminated
604        let session = manager.get_session(session_id).unwrap();
605        assert_eq!(session.state, SessionState::Terminated);
606    }
607
608    #[test]
609    fn test_data_forwarding() {
610        let config = SessionConfig::default();
611        let (manager, _event_rx) = SessionManager::new(config);
612        
613        let signing_key = SigningKey::generate(&mut OsRng);
614        let verifying_key = signing_key.verifying_key();
615        let addr = test_addr();
616        
617        manager.add_trusted_key(addr, verifying_key);
618        
619        // Create and establish session
620        let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
621        let session_id = manager.request_session(addr, vec![1, 2, 3], auth_token).unwrap();
622        manager.establish_session(session_id).unwrap();
623        
624        // Forward data
625        let data = vec![1, 2, 3, 4, 5];
626        assert!(manager.forward_data(session_id, data.clone(), ForwardDirection::ClientToPeer).is_ok());
627        assert!(manager.forward_data(session_id, data, ForwardDirection::PeerToClient).is_ok());
628        
629        // Check statistics updated
630        let session = manager.get_session(session_id).unwrap();
631        assert_eq!(session.bytes_sent, 5);
632        assert_eq!(session.bytes_received, 5);
633    }
634
635    #[test]
636    fn test_session_cleanup() {
637        let mut config = SessionConfig::default();
638        config.cleanup_interval = Duration::from_millis(1);
639        let (manager, _event_rx) = SessionManager::new(config);
640        
641        let signing_key = SigningKey::generate(&mut OsRng);
642        let verifying_key = signing_key.verifying_key();
643        let addr = test_addr();
644        
645        manager.add_trusted_key(addr, verifying_key);
646        
647        // Create session with very short timeout
648        let mut auth_token = AuthToken::new(1024, 1, &signing_key).unwrap(); // 1 second timeout
649        let session_id = manager.request_session(addr, vec![1, 2, 3], auth_token).unwrap();
650        
651        assert_eq!(manager.session_count(), 1);
652        
653        // Wait for session to expire
654        std::thread::sleep(Duration::from_millis(2));
655        
656        // Cleanup should remove expired session
657        let cleanup_count = manager.cleanup_expired_sessions().unwrap();
658        assert!(cleanup_count > 0);
659    }
660
661    #[test]
662    fn test_session_id_generation() {
663        let config = SessionConfig::default();
664        let (manager, _event_rx) = SessionManager::new(config);
665        
666        let signing_key = SigningKey::generate(&mut OsRng);
667        let verifying_key = signing_key.verifying_key();
668        let addr = test_addr();
669        
670        manager.add_trusted_key(addr, verifying_key);
671        
672        // Generate multiple session IDs
673        let mut session_ids = Vec::new();
674        for i in 0..10 {
675            let auth_token = AuthToken::new(1024, 300, &signing_key).unwrap();
676            let session_id = manager.request_session(addr, vec![i], auth_token).unwrap();
677            session_ids.push(session_id);
678        }
679        
680        // All IDs should be unique and non-zero
681        for id in &session_ids {
682            assert!(*id != 0);
683        }
684        
685        let mut unique_ids: std::collections::HashSet<_> = session_ids.iter().collect();
686        assert_eq!(unique_ids.len(), session_ids.len());
687    }
688}