ant_quic/relay/
session_manager.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8//! Session management for relay connections with complete state machine.
9
10use crate::crypto::pqc::types::MlDsaPublicKey;
11use crate::relay::{
12    AuthToken, RelayAuthenticator, RelayConnection, RelayConnectionConfig, RelayError, RelayResult,
13};
14use std::collections::HashMap;
15use std::net::SocketAddr;
16use std::sync::{Arc, Mutex};
17use std::time::{Duration, Instant};
18use tokio::sync::mpsc;
19
20/// Unique session identifier
21pub type SessionId = u32;
22
/// Tunable limits for a [`SessionManager`].
///
/// Start from [`SessionConfig::default`] and override individual fields as needed.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Upper bound on the number of sessions tracked at once.
    pub max_sessions: usize,
    /// Fallback session timeout.
    ///
    /// NOTE(review): `SessionManager` in this file takes per-session timeouts from the
    /// client's auth token and does not read this field — confirm whether it is still needed.
    pub default_timeout: Duration,
    /// Minimum time between two sweeps of expired sessions.
    pub cleanup_interval: Duration,
    /// Fallback per-session bandwidth limit (the default is documented as bytes per second).
    ///
    /// NOTE(review): like `default_timeout`, this is not read by `SessionManager` in this
    /// file; per-session limits come from the auth token.
    pub default_bandwidth_limit: u64,
}

impl Default for SessionConfig {
    /// 100 sessions, a 5-minute timeout, a 30-second cleanup interval and 1 MiB/s bandwidth.
    fn default() -> Self {
        const ONE_MIB_PER_SECOND: u64 = 1024 * 1024;

        Self {
            max_sessions: 100,
            default_timeout: Duration::from_secs(5 * 60),
            cleanup_interval: Duration::from_secs(30),
            default_bandwidth_limit: ONE_MIB_PER_SECOND,
        }
    }
}
46
/// Lifecycle state of a relay session.
///
/// The session manager creates sessions as `Pending`, promotes them to `Active` when they are
/// established, and marks them `Terminated` when they are torn down.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SessionState {
    /// Requested (and authenticated) but not yet established.
    Pending,
    /// Established; data may be forwarded.
    Active,
    /// Being shut down gracefully.
    Terminating,
    /// Shut down.
    Terminated,
    /// Ended because of an error.
    Failed {
        /// Human-readable description of what went wrong.
        reason: String,
    },
}
64
65/// Information about a relay session
66#[derive(Debug, Clone)]
67pub struct RelaySessionInfo {
68    /// Session identifier
69    pub session_id: SessionId,
70    /// Client address
71    pub client_addr: SocketAddr,
72    /// Target peer connection ID
73    pub peer_connection_id: Vec<u8>,
74    /// Current session state
75    pub state: SessionState,
76    /// Session creation time
77    pub created_at: Instant,
78    /// Last activity time
79    pub last_activity: Instant,
80    /// Bandwidth limit
81    pub bandwidth_limit: u64,
82    /// Session timeout
83    pub timeout: Duration,
84    /// Bytes transferred
85    pub bytes_sent: u64,
86    /// Total bytes received by the relay session
87    pub bytes_received: u64,
88}
89
90/// Session manager for handling relay connections
91#[derive(Debug)]
92pub struct SessionManager {
93    /// Configuration
94    config: SessionConfig,
95    /// Active sessions
96    sessions: Arc<Mutex<HashMap<SessionId, RelaySessionInfo>>>,
97    /// Active connections
98    connections: Arc<Mutex<HashMap<SessionId, Arc<RelayConnection>>>>,
99    /// Authenticator for token verification
100    authenticator: RelayAuthenticator,
101    /// Trusted peer keys for authentication (ML-DSA-65)
102    trusted_keys: Arc<Mutex<HashMap<SocketAddr, MlDsaPublicKey>>>,
103    /// Next session ID
104    next_session_id: Arc<Mutex<SessionId>>,
105    /// Event channels
106    event_sender: mpsc::UnboundedSender<SessionEvent>,
107    /// Last cleanup time
108    last_cleanup: Arc<Mutex<Instant>>,
109}
110
111/// Events generated by session management
112#[derive(Debug, Clone)]
113pub enum SessionEvent {
114    /// New session requested
115    SessionRequested {
116        /// Identifier of the newly requested session
117        session_id: SessionId,
118        /// Address of the requesting client
119        client_addr: SocketAddr,
120        /// Connection ID of the target peer
121        peer_connection_id: Vec<u8>,
122        /// Authentication token used for the request
123        auth_token: AuthToken,
124    },
125    /// Session established successfully
126    SessionEstablished {
127        /// Identifier of the established session
128        session_id: SessionId,
129        /// Address of the client for the session
130        client_addr: SocketAddr,
131    },
132    /// Session terminated
133    SessionTerminated {
134        /// Identifier of the terminated session
135        session_id: SessionId,
136        /// Reason for termination
137        reason: String,
138    },
139    /// Session failed
140    SessionFailed {
141        /// Identifier of the failed session
142        session_id: SessionId,
143        /// The error that caused the failure
144        error: RelayError,
145    },
146    /// Data forwarded through session
147    DataForwarded {
148        /// Identifier of the session that forwarded data
149        session_id: SessionId,
150        /// Number of bytes forwarded
151        bytes: usize,
152        /// Direction of forwarding
153        direction: ForwardDirection,
154    },
155}
156
/// Which way a payload travels through a relay session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ForwardDirection {
    /// Client -> peer; counted in `bytes_sent`.
    ClientToPeer,
    /// Peer -> client; counted in `bytes_received`.
    PeerToClient,
}
165
166impl SessionManager {
167    /// Create a new session manager
168    ///
169    /// # Errors
170    /// Returns an error if the internal authenticator fails to initialize.
171    pub fn new(
172        config: SessionConfig,
173    ) -> RelayResult<(Self, mpsc::UnboundedReceiver<SessionEvent>)> {
174        let (event_sender, event_receiver) = mpsc::unbounded_channel();
175
176        let manager = Self {
177            config,
178            sessions: Arc::new(Mutex::new(HashMap::new())),
179            connections: Arc::new(Mutex::new(HashMap::new())),
180            authenticator: RelayAuthenticator::new()?,
181            trusted_keys: Arc::new(Mutex::new(HashMap::new())),
182            next_session_id: Arc::new(Mutex::new(1)),
183            event_sender,
184            last_cleanup: Arc::new(Mutex::new(Instant::now())),
185        };
186
187        Ok((manager, event_receiver))
188    }
189
190    /// Add a trusted peer key for authentication (ML-DSA-65)
191    #[allow(clippy::unwrap_used)]
192    pub fn add_trusted_key(&self, addr: SocketAddr, key: MlDsaPublicKey) {
193        let mut trusted_keys = self.trusted_keys.lock().unwrap();
194        trusted_keys.insert(addr, key);
195    }
196
197    /// Remove a trusted peer key
198    #[allow(clippy::unwrap_used)]
199    pub fn remove_trusted_key(&self, addr: &SocketAddr) {
200        let mut trusted_keys = self.trusted_keys.lock().unwrap();
201        trusted_keys.remove(addr);
202    }
203
204    /// Generate next session ID
205    #[allow(clippy::unwrap_used)]
206    fn next_session_id(&self) -> SessionId {
207        let mut next_id = self.next_session_id.lock().unwrap();
208        let id = *next_id;
209        *next_id = next_id.wrapping_add(1);
210        if *next_id == 0 {
211            *next_id = 1; // Skip 0 as invalid session ID
212        }
213        id
214    }
215
216    /// Request a new relay session
217    #[allow(clippy::unwrap_used)]
218    pub fn request_session(
219        &self,
220        client_addr: SocketAddr,
221        peer_connection_id: Vec<u8>,
222        auth_token: AuthToken,
223    ) -> RelayResult<SessionId> {
224        // Check session limit
225        {
226            let sessions = self.sessions.lock().unwrap();
227            if sessions.len() >= self.config.max_sessions {
228                return Err(RelayError::ResourceExhausted {
229                    resource_type: "sessions".to_string(),
230                    current_usage: sessions.len() as u64,
231                    limit: self.config.max_sessions as u64,
232                });
233            }
234        }
235
236        // Verify authentication token
237        let trusted_keys = self.trusted_keys.lock().unwrap();
238        let peer_key =
239            trusted_keys
240                .get(&client_addr)
241                .ok_or_else(|| RelayError::AuthenticationFailed {
242                    reason: format!("No trusted key for address {}", client_addr),
243                })?;
244
245        self.authenticator.verify_token(&auth_token, peer_key)?;
246
247        // Generate session ID
248        let session_id = self.next_session_id();
249
250        // Create session info
251        let now = Instant::now();
252        let session_info = RelaySessionInfo {
253            session_id,
254            client_addr,
255            peer_connection_id: peer_connection_id.clone(),
256            state: SessionState::Pending,
257            created_at: now,
258            last_activity: now,
259            bandwidth_limit: auth_token.bandwidth_limit as u64,
260            timeout: Duration::from_secs(auth_token.timeout_seconds as u64),
261            bytes_sent: 0,
262            bytes_received: 0,
263        };
264
265        // Store session
266        {
267            let mut sessions = self.sessions.lock().unwrap();
268            sessions.insert(session_id, session_info);
269        }
270
271        // Send event
272        let _ = self.event_sender.send(SessionEvent::SessionRequested {
273            session_id,
274            client_addr,
275            peer_connection_id,
276            auth_token,
277        });
278
279        Ok(session_id)
280    }
281
282    /// Establish a relay session
283    #[allow(clippy::unwrap_used)]
284    pub fn establish_session(&self, session_id: SessionId) -> RelayResult<()> {
285        let (client_addr, bandwidth_limit) = {
286            let mut sessions = self.sessions.lock().unwrap();
287            let session = sessions
288                .get_mut(&session_id)
289                .ok_or(RelayError::SessionError {
290                    session_id: Some(session_id),
291                    kind: crate::relay::error::SessionErrorKind::NotFound,
292                })?;
293
294            if session.state != SessionState::Pending {
295                return Err(RelayError::SessionError {
296                    session_id: Some(session_id),
297                    kind: crate::relay::error::SessionErrorKind::InvalidState {
298                        current_state: format!("{:?}", session.state),
299                        expected_state: "Pending".to_string(),
300                    },
301                });
302            }
303
304            session.state = SessionState::Active;
305            session.last_activity = Instant::now();
306
307            (session.client_addr, session.bandwidth_limit)
308        };
309
310        // Create relay connection
311        let (event_tx, _event_rx) = mpsc::unbounded_channel();
312        let (_action_tx, action_rx) = mpsc::unbounded_channel();
313
314        let mut connection_config = RelayConnectionConfig::default();
315        connection_config.bandwidth_limit = bandwidth_limit;
316
317        let connection = RelayConnection::new(
318            session_id,
319            client_addr,
320            connection_config,
321            event_tx,
322            action_rx,
323        );
324
325        // Store connection
326        {
327            let mut connections = self.connections.lock().unwrap();
328            connections.insert(session_id, Arc::new(connection));
329        }
330
331        // Send event
332        let _ = self.event_sender.send(SessionEvent::SessionEstablished {
333            session_id,
334            client_addr,
335        });
336
337        Ok(())
338    }
339
340    /// Terminate a relay session
341    #[allow(clippy::unwrap_used)]
342    pub fn terminate_session(&self, session_id: SessionId, reason: String) -> RelayResult<()> {
343        // Update session state
344        {
345            let mut sessions = self.sessions.lock().unwrap();
346            if let Some(session) = sessions.get_mut(&session_id) {
347                session.state = SessionState::Terminated;
348                session.last_activity = Instant::now();
349            }
350        }
351
352        // Terminate connection
353        {
354            let mut connections = self.connections.lock().unwrap();
355            if let Some(connection) = connections.remove(&session_id) {
356                let _ = connection.terminate(reason.clone());
357            }
358        }
359
360        // Send event
361        let _ = self
362            .event_sender
363            .send(SessionEvent::SessionTerminated { session_id, reason });
364
365        Ok(())
366    }
367
368    /// Forward data through a relay session
369    #[allow(clippy::unwrap_used)]
370    pub fn forward_data(
371        &self,
372        session_id: SessionId,
373        data: Vec<u8>,
374        direction: ForwardDirection,
375    ) -> RelayResult<()> {
376        let connection = {
377            let connections = self.connections.lock().unwrap();
378            connections
379                .get(&session_id)
380                .cloned()
381                .ok_or(RelayError::SessionError {
382                    session_id: Some(session_id),
383                    kind: crate::relay::error::SessionErrorKind::NotFound,
384                })?
385        };
386
387        // Forward data based on direction
388        match direction {
389            ForwardDirection::ClientToPeer => {
390                connection.send_data(data.clone())?;
391            }
392            ForwardDirection::PeerToClient => {
393                connection.receive_data(data.clone())?;
394            }
395        }
396
397        // Update session statistics
398        {
399            let mut sessions = self.sessions.lock().unwrap();
400            if let Some(session) = sessions.get_mut(&session_id) {
401                session.last_activity = Instant::now();
402                match direction {
403                    ForwardDirection::ClientToPeer => {
404                        session.bytes_sent += data.len() as u64;
405                    }
406                    ForwardDirection::PeerToClient => {
407                        session.bytes_received += data.len() as u64;
408                    }
409                }
410            }
411        }
412
413        // Send event
414        let _ = self.event_sender.send(SessionEvent::DataForwarded {
415            session_id,
416            bytes: data.len(),
417            direction,
418        });
419
420        Ok(())
421    }
422
423    /// Get session information
424    #[allow(clippy::unwrap_used)]
425    pub fn get_session(&self, session_id: SessionId) -> Option<RelaySessionInfo> {
426        let sessions = self.sessions.lock().unwrap();
427        sessions.get(&session_id).cloned()
428    }
429
430    /// List all active sessions
431    #[allow(clippy::unwrap_used)]
432    pub fn list_sessions(&self) -> Vec<RelaySessionInfo> {
433        let sessions = self.sessions.lock().unwrap();
434        sessions.values().cloned().collect()
435    }
436
437    /// Get session count
438    #[allow(clippy::unwrap_used)]
439    pub fn session_count(&self) -> usize {
440        let sessions = self.sessions.lock().unwrap();
441        sessions.len()
442    }
443
444    /// Clean up expired sessions
445    #[allow(clippy::unwrap_used)]
446    pub fn cleanup_expired_sessions(&self) -> RelayResult<usize> {
447        let mut last_cleanup = self.last_cleanup.lock().unwrap();
448        let now = Instant::now();
449
450        // Only cleanup if enough time has passed
451        if now.duration_since(*last_cleanup) < self.config.cleanup_interval {
452            return Ok(0);
453        }
454
455        *last_cleanup = now;
456        drop(last_cleanup);
457
458        let mut expired_sessions = Vec::new();
459
460        // Find expired sessions
461        {
462            let sessions = self.sessions.lock().unwrap();
463            for (session_id, session_info) in sessions.iter() {
464                let age = now.duration_since(session_info.last_activity);
465                if age > session_info.timeout {
466                    expired_sessions.push(*session_id);
467                }
468            }
469        }
470
471        // Remove expired sessions
472        let cleanup_count = expired_sessions.len();
473        for session_id in expired_sessions {
474            let _ = self.terminate_session(session_id, "Session expired".to_string());
475
476            // Remove from sessions map
477            let mut sessions = self.sessions.lock().unwrap();
478            sessions.remove(&session_id);
479        }
480
481        Ok(cleanup_count)
482    }
483
484    /// Get session manager statistics
485    #[allow(clippy::unwrap_used)]
486    pub fn get_statistics(&self) -> SessionManagerStats {
487        let sessions = self.sessions.lock().unwrap();
488        let connections = self.connections.lock().unwrap();
489
490        let mut active_count = 0;
491        let mut pending_count = 0;
492        let mut total_bytes_sent = 0;
493        let mut total_bytes_received = 0;
494
495        for session in sessions.values() {
496            match session.state {
497                SessionState::Active => active_count += 1,
498                SessionState::Pending => pending_count += 1,
499                _ => {}
500            }
501            total_bytes_sent += session.bytes_sent;
502            total_bytes_received += session.bytes_received;
503        }
504
505        SessionManagerStats {
506            total_sessions: sessions.len(),
507            active_sessions: active_count,
508            pending_sessions: pending_count,
509            total_connections: connections.len(),
510            total_bytes_sent,
511            total_bytes_received,
512        }
513    }
514}
515
/// Point-in-time aggregate figures reported by [`SessionManager::get_statistics`].
#[derive(Debug, Clone)]
pub struct SessionManagerStats {
    /// Sessions currently tracked, in any state.
    pub total_sessions: usize,
    /// Tracked sessions in the `Active` state.
    pub active_sessions: usize,
    /// Tracked sessions in the `Pending` state.
    pub pending_sessions: usize,
    /// Live relay connections (established sessions not yet torn down).
    pub total_connections: usize,
    /// Sum of client-to-peer bytes over all tracked sessions.
    pub total_bytes_sent: u64,
    /// Sum of peer-to-client bytes over all tracked sessions.
    pub total_bytes_received: u64,
}
532
#[cfg(test)]
mod tests {
    //! Unit tests for `SessionManager`: trust management, the session lifecycle,
    //! capacity limits, forwarding statistics, and expiry cleanup.

    use super::*;
    use crate::crypto::raw_public_keys::key_utils::generate_ml_dsa_keypair;
    use crate::relay::AuthToken;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback address used as the client address in every test.
    fn test_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
    }

    #[test]
    fn test_session_manager_creation() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let stats = manager.get_statistics();
        assert_eq!(stats.total_sessions, 0);
        assert_eq!(stats.active_sessions, 0);
        assert_eq!(stats.pending_sessions, 0);
        assert_eq!(stats.total_connections, 0);
        assert_eq!(manager.session_count(), 0);
        assert!(manager.list_sessions().is_empty());
    }

    #[test]
    fn test_trusted_key_management() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key.clone());

        // A trusted key allows a session to be requested.
        let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
        let result = manager.request_session(addr, vec![1, 2, 3], auth_token);
        assert!(result.is_ok(), "expected success, got {:?}", result);

        manager.remove_trusted_key(&addr);

        // Without a trusted key the request must be rejected as an authentication failure.
        let auth_token2 = AuthToken::new(1024, 300, &secret_key).unwrap();
        let result2 = manager.request_session(addr, vec![4, 5, 6], auth_token2);
        assert!(
            matches!(result2, Err(RelayError::AuthenticationFailed { .. })),
            "expected AuthenticationFailed, got {:?}",
            result2
        );

        // Removing the key does not affect the session that already exists.
        assert_eq!(manager.session_count(), 1);
    }

    #[test]
    fn test_session_request_and_establishment() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
        let session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Pending);
        assert_eq!(session.client_addr, addr);
        assert_eq!(session.peer_connection_id, vec![1, 2, 3]);
        assert_eq!(manager.get_statistics().pending_sessions, 1);

        assert!(manager.establish_session(session_id).is_ok());

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Active);

        let stats = manager.get_statistics();
        assert_eq!(stats.active_sessions, 1);
        assert_eq!(stats.pending_sessions, 0);
        assert_eq!(stats.total_connections, 1);

        // Establishing again is an invalid state transition.
        let again = manager.establish_session(session_id);
        assert!(
            matches!(again, Err(RelayError::SessionError { .. })),
            "expected SessionError, got {:?}",
            again
        );
    }

    #[test]
    fn test_establish_unknown_session_fails() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let result = manager.establish_session(42);
        assert!(
            matches!(result, Err(RelayError::SessionError { .. })),
            "expected SessionError, got {:?}",
            result
        );
        assert_eq!(manager.get_statistics().total_connections, 0);
    }

    #[test]
    fn test_session_limit() {
        let mut config = SessionConfig::default();
        config.max_sessions = 2;
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        for i in 0..2 {
            let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
            let result = manager.request_session(addr, vec![i], auth_token);
            assert!(result.is_ok(), "expected success, got {:?}", result);
        }

        // The third request exceeds the limit.
        let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
        let result = manager.request_session(addr, vec![3], auth_token);
        assert!(
            matches!(result, Err(RelayError::ResourceExhausted { .. })),
            "expected ResourceExhausted, got {:?}",
            result
        );
        assert_eq!(manager.session_count(), 2);
    }

    #[test]
    fn test_session_termination() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
        let session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();
        manager.establish_session(session_id).unwrap();

        let reason = "Test termination".to_string();
        assert!(manager.terminate_session(session_id, reason).is_ok());

        // The record is kept (now Terminated) but the relay connection is gone.
        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.state, SessionState::Terminated);
        assert_eq!(manager.get_statistics().total_connections, 0);

        // Forwarding through a terminated session must fail.
        let forwarded = manager.forward_data(session_id, vec![1], ForwardDirection::ClientToPeer);
        assert!(
            matches!(forwarded, Err(RelayError::SessionError { .. })),
            "expected SessionError, got {:?}",
            forwarded
        );
    }

    #[test]
    fn test_data_forwarding() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
        let session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();
        manager.establish_session(session_id).unwrap();

        let data = vec![1, 2, 3, 4, 5];
        assert!(manager
            .forward_data(session_id, data.clone(), ForwardDirection::ClientToPeer)
            .is_ok());
        assert!(manager
            .forward_data(session_id, data, ForwardDirection::PeerToClient)
            .is_ok());

        // Each direction is counted separately, per session and in the aggregate.
        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.bytes_sent, 5);
        assert_eq!(session.bytes_received, 5);

        let stats = manager.get_statistics();
        assert_eq!(stats.total_bytes_sent, 5);
        assert_eq!(stats.total_bytes_received, 5);
    }

    #[test]
    fn test_session_cleanup() {
        let mut config = SessionConfig::default();
        config.cleanup_interval = Duration::from_millis(1);
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        // A zero-second timeout makes the session expire as soon as any time has passed.
        let auth_token = AuthToken::new(1024, 0, &secret_key).unwrap();
        let _session_id = manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();

        assert_eq!(manager.session_count(), 1);

        // Let both the session timeout and the cleanup interval elapse.
        std::thread::sleep(Duration::from_millis(10));

        let cleanup_count = manager.cleanup_expired_sessions().unwrap();
        assert_eq!(cleanup_count, 1);
        assert_eq!(manager.session_count(), 0);
    }

    #[test]
    fn test_cleanup_respects_interval() {
        // The default cleanup interval is far longer than this test runs.
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        let auth_token = AuthToken::new(1024, 0, &secret_key).unwrap();
        manager
            .request_session(addr, vec![1, 2, 3], auth_token)
            .unwrap();

        // Too soon after construction: the sweep is skipped even though the session expired.
        assert_eq!(manager.cleanup_expired_sessions().unwrap(), 0);
        assert_eq!(manager.session_count(), 1);
    }

    #[test]
    fn test_session_id_generation() {
        let config = SessionConfig::default();
        let (manager, _event_rx) = SessionManager::new(config).unwrap();

        let (public_key, secret_key) = generate_ml_dsa_keypair().unwrap();
        let addr = test_addr();

        manager.add_trusted_key(addr, public_key);

        let mut session_ids = Vec::new();
        for i in 0..10 {
            let auth_token = AuthToken::new(1024, 300, &secret_key).unwrap();
            let session_id = manager.request_session(addr, vec![i], auth_token).unwrap();
            session_ids.push(session_id);
        }

        // All IDs must be non-zero and unique.
        for id in &session_ids {
            assert_ne!(*id, 0);
        }

        let unique_ids: std::collections::HashSet<_> = session_ids.iter().collect();
        assert_eq!(unique_ids.len(), session_ids.len());
    }
}