ant_quic/relay/
session_manager.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8//! Session management for relay connections with complete state machine.
9
10use crate::relay::{
11    AuthToken, RelayAuthenticator, RelayConnection, RelayConnectionConfig, RelayError, RelayResult,
12};
13use ed25519_dalek::VerifyingKey;
14use std::collections::HashMap;
15use std::net::SocketAddr;
16use std::sync::{Arc, Mutex};
17use std::time::{Duration, Instant};
18use tokio::sync::mpsc;
19
/// Identifier the [`SessionManager`] assigns to each relay session.
///
/// IDs come from a wrapping counter; the value `0` is never handed out.
pub type SessionId = u32;
22
/// Tunables for a [`SessionManager`].
#[derive(Clone, Debug)]
pub struct SessionConfig {
    /// Upper bound on the number of sessions tracked at once. Further
    /// requests are rejected with `ResourceExhausted`.
    pub max_sessions: usize,
    /// Fallback session timeout.
    ///
    /// NOTE(review): not read by `SessionManager` in this module; per-session
    /// timeouts come from the request's auth token. Confirm other users.
    pub default_timeout: Duration,
    /// Minimum time between two effective runs of expired-session cleanup.
    pub cleanup_interval: Duration,
    /// Fallback per-session bandwidth limit (presumably bytes per second;
    /// the default is described as 1 MB/s).
    ///
    /// NOTE(review): not read by `SessionManager` in this module; per-session
    /// limits come from the request's auth token. Confirm other users.
    pub default_bandwidth_limit: u64,
}

impl Default for SessionConfig {
    /// Defaults: 100 sessions, a 5-minute timeout, cleanup at most every
    /// 30 seconds, and a 1 MiB (1_048_576) bandwidth limit.
    fn default() -> Self {
        const SECS_PER_MINUTE: u64 = 60;
        const MIB: u64 = 1024 * 1024;

        Self {
            max_sessions: 100,
            default_timeout: Duration::from_secs(5 * SECS_PER_MINUTE),
            cleanup_interval: Duration::from_secs(30),
            default_bandwidth_limit: MIB,
        }
    }
}
46
/// Lifecycle state of a relay session.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SessionState {
    /// Requested and authenticated, but not yet established.
    Pending,
    /// Established; data may be forwarded through the session.
    Active,
    /// Shutting down gracefully.
    ///
    /// NOTE(review): no code in this module moves a session into this state.
    Terminating,
    /// Shut down; the record lingers until expired-session cleanup removes it.
    Terminated,
    /// Ended because of an error.
    ///
    /// NOTE(review): no code in this module moves a session into this state.
    Failed {
        /// Human-readable explanation of the failure.
        reason: String,
    },
}
64
65/// Information about a relay session
66#[derive(Debug, Clone)]
67pub struct RelaySessionInfo {
68    /// Session identifier
69    pub session_id: SessionId,
70    /// Client address
71    pub client_addr: SocketAddr,
72    /// Target peer connection ID
73    pub peer_connection_id: Vec<u8>,
74    /// Current session state
75    pub state: SessionState,
76    /// Session creation time
77    pub created_at: Instant,
78    /// Last activity time
79    pub last_activity: Instant,
80    /// Bandwidth limit
81    pub bandwidth_limit: u64,
82    /// Session timeout
83    pub timeout: Duration,
84    /// Bytes transferred
85    pub bytes_sent: u64,
86    /// Total bytes received by the relay session
87    pub bytes_received: u64,
88}
89
90/// Session manager for handling relay connections
91#[derive(Debug)]
92pub struct SessionManager {
93    /// Configuration
94    config: SessionConfig,
95    /// Active sessions
96    sessions: Arc<Mutex<HashMap<SessionId, RelaySessionInfo>>>,
97    /// Active connections
98    connections: Arc<Mutex<HashMap<SessionId, Arc<RelayConnection>>>>,
99    /// Authenticator for token verification
100    authenticator: RelayAuthenticator,
101    /// Trusted peer keys for authentication
102    trusted_keys: Arc<Mutex<HashMap<SocketAddr, VerifyingKey>>>,
103    /// Next session ID
104    next_session_id: Arc<Mutex<SessionId>>,
105    /// Event channels
106    event_sender: mpsc::UnboundedSender<SessionEvent>,
107    /// Last cleanup time
108    last_cleanup: Arc<Mutex<Instant>>,
109}
110
111/// Events generated by session management
112#[derive(Debug, Clone)]
113pub enum SessionEvent {
114    /// New session requested
115    SessionRequested {
116        /// Identifier of the newly requested session
117        session_id: SessionId,
118        /// Address of the requesting client
119        client_addr: SocketAddr,
120        /// Connection ID of the target peer
121        peer_connection_id: Vec<u8>,
122        /// Authentication token used for the request
123        auth_token: AuthToken,
124    },
125    /// Session established successfully
126    SessionEstablished {
127        /// Identifier of the established session
128        session_id: SessionId,
129        /// Address of the client for the session
130        client_addr: SocketAddr,
131    },
132    /// Session terminated
133    SessionTerminated {
134        /// Identifier of the terminated session
135        session_id: SessionId,
136        /// Reason for termination
137        reason: String,
138    },
139    /// Session failed
140    SessionFailed {
141        /// Identifier of the failed session
142        session_id: SessionId,
143        /// The error that caused the failure
144        error: RelayError,
145    },
146    /// Data forwarded through session
147    DataForwarded {
148        /// Identifier of the session that forwarded data
149        session_id: SessionId,
150        /// Number of bytes forwarded
151        bytes: usize,
152        /// Direction of forwarding
153        direction: ForwardDirection,
154    },
155}
156
/// Which way a payload travels through a relay session.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ForwardDirection {
    /// Client to peer; counted in `RelaySessionInfo::bytes_sent`.
    ClientToPeer,
    /// Peer to client; counted in `RelaySessionInfo::bytes_received`.
    PeerToClient,
}
165
166impl SessionManager {
167    /// Create a new session manager
168    pub fn new(config: SessionConfig) -> (Self, mpsc::UnboundedReceiver<SessionEvent>) {
169        let (event_sender, event_receiver) = mpsc::unbounded_channel();
170
171        let manager = Self {
172            config,
173            sessions: Arc::new(Mutex::new(HashMap::new())),
174            connections: Arc::new(Mutex::new(HashMap::new())),
175            authenticator: RelayAuthenticator::new(),
176            trusted_keys: Arc::new(Mutex::new(HashMap::new())),
177            next_session_id: Arc::new(Mutex::new(1)),
178            event_sender,
179            last_cleanup: Arc::new(Mutex::new(Instant::now())),
180        };
181
182        (manager, event_receiver)
183    }
184
185    /// Add a trusted peer key for authentication
186    #[allow(clippy::unwrap_used)]
187    pub fn add_trusted_key(&self, addr: SocketAddr, key: VerifyingKey) {
188        let mut trusted_keys = self.trusted_keys.lock().unwrap();
189        trusted_keys.insert(addr, key);
190    }
191
192    /// Remove a trusted peer key
193    #[allow(clippy::unwrap_used)]
194    pub fn remove_trusted_key(&self, addr: &SocketAddr) {
195        let mut trusted_keys = self.trusted_keys.lock().unwrap();
196        trusted_keys.remove(addr);
197    }
198
199    /// Generate next session ID
200    #[allow(clippy::unwrap_used)]
201    fn next_session_id(&self) -> SessionId {
202        let mut next_id = self.next_session_id.lock().unwrap();
203        let id = *next_id;
204        *next_id = next_id.wrapping_add(1);
205        if *next_id == 0 {
206            *next_id = 1; // Skip 0 as invalid session ID
207        }
208        id
209    }
210
211    /// Request a new relay session
212    #[allow(clippy::unwrap_used)]
213    pub fn request_session(
214        &self,
215        client_addr: SocketAddr,
216        peer_connection_id: Vec<u8>,
217        auth_token: AuthToken,
218    ) -> RelayResult<SessionId> {
219        // Check session limit
220        {
221            let sessions = self.sessions.lock().unwrap();
222            if sessions.len() >= self.config.max_sessions {
223                return Err(RelayError::ResourceExhausted {
224                    resource_type: "sessions".to_string(),
225                    current_usage: sessions.len() as u64,
226                    limit: self.config.max_sessions as u64,
227                });
228            }
229        }
230
231        // Verify authentication token
232        let trusted_keys = self.trusted_keys.lock().unwrap();
233        let peer_key =
234            trusted_keys
235                .get(&client_addr)
236                .ok_or_else(|| RelayError::AuthenticationFailed {
237                    reason: format!("No trusted key for address {}", client_addr),
238                })?;
239
240        self.authenticator.verify_token(&auth_token, peer_key)?;
241
242        // Generate session ID
243        let session_id = self.next_session_id();
244
245        // Create session info
246        let now = Instant::now();
247        let session_info = RelaySessionInfo {
248            session_id,
249            client_addr,
250            peer_connection_id: peer_connection_id.clone(),
251            state: SessionState::Pending,
252            created_at: now,
253            last_activity: now,
254            bandwidth_limit: auth_token.bandwidth_limit as u64,
255            timeout: Duration::from_secs(auth_token.timeout_seconds as u64),
256            bytes_sent: 0,
257            bytes_received: 0,
258        };
259
260        // Store session
261        {
262            let mut sessions = self.sessions.lock().unwrap();
263            sessions.insert(session_id, session_info);
264        }
265
266        // Send event
267        let _ = self.event_sender.send(SessionEvent::SessionRequested {
268            session_id,
269            client_addr,
270            peer_connection_id,
271            auth_token,
272        });
273
274        Ok(session_id)
275    }
276
277    /// Establish a relay session
278    #[allow(clippy::unwrap_used)]
279    pub fn establish_session(&self, session_id: SessionId) -> RelayResult<()> {
280        let (client_addr, bandwidth_limit) = {
281            let mut sessions = self.sessions.lock().unwrap();
282            let session = sessions
283                .get_mut(&session_id)
284                .ok_or(RelayError::SessionError {
285                    session_id: Some(session_id),
286                    kind: crate::relay::error::SessionErrorKind::NotFound,
287                })?;
288
289            if session.state != SessionState::Pending {
290                return Err(RelayError::SessionError {
291                    session_id: Some(session_id),
292                    kind: crate::relay::error::SessionErrorKind::InvalidState {
293                        current_state: format!("{:?}", session.state),
294                        expected_state: "Pending".to_string(),
295                    },
296                });
297            }
298
299            session.state = SessionState::Active;
300            session.last_activity = Instant::now();
301
302            (session.client_addr, session.bandwidth_limit)
303        };
304
305        // Create relay connection
306        let (event_tx, _event_rx) = mpsc::unbounded_channel();
307        let (_action_tx, action_rx) = mpsc::unbounded_channel();
308
309        let mut connection_config = RelayConnectionConfig::default();
310        connection_config.bandwidth_limit = bandwidth_limit;
311
312        let connection = RelayConnection::new(
313            session_id,
314            client_addr,
315            connection_config,
316            event_tx,
317            action_rx,
318        );
319
320        // Store connection
321        {
322            let mut connections = self.connections.lock().unwrap();
323            connections.insert(session_id, Arc::new(connection));
324        }
325
326        // Send event
327        let _ = self.event_sender.send(SessionEvent::SessionEstablished {
328            session_id,
329            client_addr,
330        });
331
332        Ok(())
333    }
334
335    /// Terminate a relay session
336    #[allow(clippy::unwrap_used)]
337    pub fn terminate_session(&self, session_id: SessionId, reason: String) -> RelayResult<()> {
338        // Update session state
339        {
340            let mut sessions = self.sessions.lock().unwrap();
341            if let Some(session) = sessions.get_mut(&session_id) {
342                session.state = SessionState::Terminated;
343                session.last_activity = Instant::now();
344            }
345        }
346
347        // Terminate connection
348        {
349            let mut connections = self.connections.lock().unwrap();
350            if let Some(connection) = connections.remove(&session_id) {
351                let _ = connection.terminate(reason.clone());
352            }
353        }
354
355        // Send event
356        let _ = self
357            .event_sender
358            .send(SessionEvent::SessionTerminated { session_id, reason });
359
360        Ok(())
361    }
362
363    /// Forward data through a relay session
364    #[allow(clippy::unwrap_used)]
365    pub fn forward_data(
366        &self,
367        session_id: SessionId,
368        data: Vec<u8>,
369        direction: ForwardDirection,
370    ) -> RelayResult<()> {
371        let connection = {
372            let connections = self.connections.lock().unwrap();
373            connections
374                .get(&session_id)
375                .cloned()
376                .ok_or(RelayError::SessionError {
377                    session_id: Some(session_id),
378                    kind: crate::relay::error::SessionErrorKind::NotFound,
379                })?
380        };
381
382        // Forward data based on direction
383        match direction {
384            ForwardDirection::ClientToPeer => {
385                connection.send_data(data.clone())?;
386            }
387            ForwardDirection::PeerToClient => {
388                connection.receive_data(data.clone())?;
389            }
390        }
391
392        // Update session statistics
393        {
394            let mut sessions = self.sessions.lock().unwrap();
395            if let Some(session) = sessions.get_mut(&session_id) {
396                session.last_activity = Instant::now();
397                match direction {
398                    ForwardDirection::ClientToPeer => {
399                        session.bytes_sent += data.len() as u64;
400                    }
401                    ForwardDirection::PeerToClient => {
402                        session.bytes_received += data.len() as u64;
403                    }
404                }
405            }
406        }
407
408        // Send event
409        let _ = self.event_sender.send(SessionEvent::DataForwarded {
410            session_id,
411            bytes: data.len(),
412            direction,
413        });
414
415        Ok(())
416    }
417
418    /// Get session information
419    #[allow(clippy::unwrap_used)]
420    pub fn get_session(&self, session_id: SessionId) -> Option<RelaySessionInfo> {
421        let sessions = self.sessions.lock().unwrap();
422        sessions.get(&session_id).cloned()
423    }
424
425    /// List all active sessions
426    #[allow(clippy::unwrap_used)]
427    pub fn list_sessions(&self) -> Vec<RelaySessionInfo> {
428        let sessions = self.sessions.lock().unwrap();
429        sessions.values().cloned().collect()
430    }
431
432    /// Get session count
433    #[allow(clippy::unwrap_used)]
434    pub fn session_count(&self) -> usize {
435        let sessions = self.sessions.lock().unwrap();
436        sessions.len()
437    }
438
439    /// Clean up expired sessions
440    #[allow(clippy::unwrap_used)]
441    pub fn cleanup_expired_sessions(&self) -> RelayResult<usize> {
442        let mut last_cleanup = self.last_cleanup.lock().unwrap();
443        let now = Instant::now();
444
445        // Only cleanup if enough time has passed
446        if now.duration_since(*last_cleanup) < self.config.cleanup_interval {
447            return Ok(0);
448        }
449
450        *last_cleanup = now;
451        drop(last_cleanup);
452
453        let mut expired_sessions = Vec::new();
454
455        // Find expired sessions
456        {
457            let sessions = self.sessions.lock().unwrap();
458            for (session_id, session_info) in sessions.iter() {
459                let age = now.duration_since(session_info.last_activity);
460                if age > session_info.timeout {
461                    expired_sessions.push(*session_id);
462                }
463            }
464        }
465
466        // Remove expired sessions
467        let cleanup_count = expired_sessions.len();
468        for session_id in expired_sessions {
469            let _ = self.terminate_session(session_id, "Session expired".to_string());
470
471            // Remove from sessions map
472            let mut sessions = self.sessions.lock().unwrap();
473            sessions.remove(&session_id);
474        }
475
476        Ok(cleanup_count)
477    }
478
479    /// Get session manager statistics
480    #[allow(clippy::unwrap_used)]
481    pub fn get_statistics(&self) -> SessionManagerStats {
482        let sessions = self.sessions.lock().unwrap();
483        let connections = self.connections.lock().unwrap();
484
485        let mut active_count = 0;
486        let mut pending_count = 0;
487        let mut total_bytes_sent = 0;
488        let mut total_bytes_received = 0;
489
490        for session in sessions.values() {
491            match session.state {
492                SessionState::Active => active_count += 1,
493                SessionState::Pending => pending_count += 1,
494                _ => {}
495            }
496            total_bytes_sent += session.bytes_sent;
497            total_bytes_received += session.bytes_received;
498        }
499
500        SessionManagerStats {
501            total_sessions: sessions.len(),
502            active_sessions: active_count,
503            pending_sessions: pending_count,
504            total_connections: connections.len(),
505            total_bytes_sent,
506            total_bytes_received,
507        }
508    }
509}
510
/// Point-in-time summary produced by `SessionManager::get_statistics`.
#[derive(Clone, Debug)]
pub struct SessionManagerStats {
    /// Sessions currently tracked, in any state.
    pub total_sessions: usize,
    /// Tracked sessions in the `Active` state.
    pub active_sessions: usize,
    /// Tracked sessions in the `Pending` state.
    pub pending_sessions: usize,
    /// Sessions that currently have a live relay connection.
    pub total_connections: usize,
    /// Sum of client-to-peer bytes over all tracked sessions.
    pub total_bytes_sent: u64,
    /// Sum of peer-to-client bytes over all tracked sessions.
    pub total_bytes_received: u64,
}
527
#[cfg(test)]
mod tests {
    use super::*;
    use crate::relay::AuthToken;
    use ed25519_dalek::SigningKey;
    use rand::rngs::OsRng;
    use std::net::{IpAddr, Ipv4Addr};

    /// Loopback client address shared by all tests.
    fn test_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8080)
    }

    /// Build a manager and trust a freshly generated key for `test_addr()`.
    ///
    /// The event receiver is returned so the caller can keep it alive.
    fn manager_with_trusted_key(
        config: SessionConfig,
    ) -> (SessionManager, mpsc::UnboundedReceiver<SessionEvent>, SigningKey) {
        let (manager, events) = SessionManager::new(config);
        let signing_key = SigningKey::generate(&mut OsRng);
        manager.add_trusted_key(test_addr(), signing_key.verifying_key());
        (manager, events, signing_key)
    }

    /// Token with a 1024 bandwidth limit and a 300-second timeout.
    fn standard_token(signing_key: &SigningKey) -> AuthToken {
        AuthToken::new(1024, 300, signing_key).unwrap()
    }

    #[test]
    fn test_session_manager_creation() {
        let (manager, _events) = SessionManager::new(SessionConfig::default());

        let stats = manager.get_statistics();
        assert_eq!(stats.total_sessions, 0);
        assert_eq!(stats.active_sessions, 0);
    }

    #[test]
    fn test_trusted_key_management() {
        let (manager, _events, signing_key) = manager_with_trusted_key(SessionConfig::default());
        let addr = test_addr();

        // With the key registered, the request is accepted.
        let accepted = manager.request_session(addr, vec![1, 2, 3], standard_token(&signing_key));
        assert!(accepted.is_ok());

        manager.remove_trusted_key(&addr);

        // Same signer, key no longer trusted: rejected.
        let rejected = manager.request_session(addr, vec![4, 5, 6], standard_token(&signing_key));
        assert!(rejected.is_err());
    }

    #[test]
    fn test_session_request_and_establishment() {
        let (manager, _events, signing_key) = manager_with_trusted_key(SessionConfig::default());

        let session_id = manager
            .request_session(test_addr(), vec![1, 2, 3], standard_token(&signing_key))
            .unwrap();

        let pending = manager.get_session(session_id).unwrap();
        assert_eq!(pending.state, SessionState::Pending);
        assert_eq!(pending.client_addr, test_addr());

        assert!(manager.establish_session(session_id).is_ok());

        let active = manager.get_session(session_id).unwrap();
        assert_eq!(active.state, SessionState::Active);
    }

    #[test]
    fn test_session_limit() {
        let config = SessionConfig {
            max_sessions: 2,
            ..SessionConfig::default()
        };
        let (manager, _events, signing_key) = manager_with_trusted_key(config);

        // Fill the table up to the limit.
        for i in 0u8..2 {
            let outcome = manager.request_session(test_addr(), vec![i], standard_token(&signing_key));
            assert!(outcome.is_ok());
        }

        // One more is refused.
        let overflow = manager.request_session(test_addr(), vec![3], standard_token(&signing_key));
        assert!(overflow.is_err());
    }

    #[test]
    fn test_session_termination() {
        let (manager, _events, signing_key) = manager_with_trusted_key(SessionConfig::default());

        let session_id = manager
            .request_session(test_addr(), vec![1, 2, 3], standard_token(&signing_key))
            .unwrap();
        manager.establish_session(session_id).unwrap();

        assert!(manager
            .terminate_session(session_id, "Test termination".to_string())
            .is_ok());

        let ended = manager.get_session(session_id).unwrap();
        assert_eq!(ended.state, SessionState::Terminated);
    }

    #[test]
    fn test_data_forwarding() {
        let (manager, _events, signing_key) = manager_with_trusted_key(SessionConfig::default());

        let session_id = manager
            .request_session(test_addr(), vec![1, 2, 3], standard_token(&signing_key))
            .unwrap();
        manager.establish_session(session_id).unwrap();

        // Five bytes in each direction.
        let payload = vec![1, 2, 3, 4, 5];
        for direction in [ForwardDirection::ClientToPeer, ForwardDirection::PeerToClient] {
            assert!(manager
                .forward_data(session_id, payload.clone(), direction)
                .is_ok());
        }

        let session = manager.get_session(session_id).unwrap();
        assert_eq!(session.bytes_sent, 5);
        assert_eq!(session.bytes_received, 5);
    }

    #[test]
    fn test_session_cleanup() {
        let config = SessionConfig {
            cleanup_interval: Duration::from_millis(1),
            ..SessionConfig::default()
        };
        let (manager, _events, signing_key) = manager_with_trusted_key(config);

        // A zero-second timeout makes the session eligible for cleanup almost immediately.
        let zero_timeout_token = AuthToken::new(1024, 0, &signing_key).unwrap();
        manager
            .request_session(test_addr(), vec![1, 2, 3], zero_timeout_token)
            .unwrap();

        assert_eq!(manager.session_count(), 1);

        // Let both the session timeout and the cleanup interval elapse.
        std::thread::sleep(Duration::from_millis(10));

        let removed = manager.cleanup_expired_sessions().unwrap();
        assert!(removed > 0);
    }

    #[test]
    fn test_session_id_generation() {
        let (manager, _events, signing_key) = manager_with_trusted_key(SessionConfig::default());

        let ids: Vec<SessionId> = (0u8..10)
            .map(|i| {
                manager
                    .request_session(test_addr(), vec![i], standard_token(&signing_key))
                    .unwrap()
            })
            .collect();

        // No zero IDs and no duplicates.
        assert!(ids.iter().all(|&id| id != 0));

        let distinct: std::collections::HashSet<SessionId> = ids.iter().copied().collect();
        assert_eq!(distinct.len(), ids.len());
    }
}