ant_quic/relay/connection.rs

1//! Relay connection implementation for bidirectional data forwarding.
2
3use crate::relay::{RelayError, RelayResult};
4use std::collections::VecDeque;
5use std::net::SocketAddr;
6use std::sync::{Arc, Mutex};
7use std::time::{Duration, Instant};
8use tokio::sync::mpsc;
9
/// Tunable limits and timers applied to a single relay connection.
#[derive(Debug, Clone)]
pub struct RelayConnectionConfig {
    /// Largest payload, in bytes, that a single `send_data` call accepts.
    pub max_frame_size: usize,
    /// Upper bound, in bytes, on data held in the outgoing and incoming queues combined.
    pub buffer_size: usize,
    /// Idle period after which `check_timeout` reports the session as expired.
    pub connection_timeout: Duration,
    /// Spacing between keep-alive deadlines.
    pub keep_alive_interval: Duration,
    /// Per-connection rate cap in bytes per second (checked against outgoing bytes).
    pub bandwidth_limit: u64,
}
24
25impl Default for RelayConnectionConfig {
26    fn default() -> Self {
27        Self {
28            max_frame_size: 65536,                        // 64 KB
29            buffer_size: 1048576,                         // 1 MB
30            connection_timeout: Duration::from_secs(300), // 5 minutes
31            keep_alive_interval: Duration::from_secs(30), // 30 seconds
32            bandwidth_limit: 1048576,                     // 1 MB/s
33        }
34    }
35}
36
37/// Events that can occur during relay operation
38#[derive(Debug, Clone)]
39pub enum RelayEvent {
40    /// Connection established successfully
41    ConnectionEstablished {
42        session_id: u32,
43        peer_addr: SocketAddr,
44    },
45    /// Data received from peer
46    DataReceived { session_id: u32, data: Vec<u8> },
47    /// Connection terminated
48    ConnectionTerminated { session_id: u32, reason: String },
49    /// Error occurred during relay operation
50    Error {
51        session_id: Option<u32>,
52        error: RelayError,
53    },
54    /// Bandwidth limit exceeded
55    BandwidthLimitExceeded {
56        session_id: u32,
57        current_usage: u64,
58        limit: u64,
59    },
60    /// Keep-alive signal
61    KeepAlive { session_id: u32 },
62}
63
/// Commands that can be issued to a relay connection in response to events.
#[derive(Debug, Clone)]
pub enum RelayAction {
    /// Deliver `data` to the peer of `session_id`.
    SendData { session_id: u32, data: Vec<u8> },
    /// Close the session, recording `reason`.
    TerminateConnection { session_id: u32, reason: String },
    /// Replace the session's bandwidth cap with `new_limit`.
    UpdateBandwidthLimit { session_id: u32, new_limit: u64 },
    /// Emit a keep-alive for the session.
    SendKeepAlive { session_id: u32 },
}
76
77/// Relay connection for bidirectional data forwarding
78#[derive(Debug)]
79pub struct RelayConnection {
80    /// Unique session identifier
81    session_id: u32,
82    /// Peer address
83    peer_addr: SocketAddr,
84    /// Configuration
85    config: RelayConnectionConfig,
86    /// Connection state
87    state: Arc<Mutex<ConnectionState>>,
88    /// Event sender
89    event_sender: mpsc::UnboundedSender<RelayEvent>,
90    /// Action receiver
91    action_receiver: mpsc::UnboundedReceiver<RelayAction>,
92}
93
94/// Internal connection state
95#[derive(Debug)]
96struct ConnectionState {
97    /// Whether connection is active
98    is_active: bool,
99    /// Data queue for outgoing packets
100    outgoing_queue: VecDeque<Vec<u8>>,
101    /// Data queue for incoming packets
102    incoming_queue: VecDeque<Vec<u8>>,
103    /// Current buffer usage
104    buffer_usage: usize,
105    /// Bandwidth tracking
106    bandwidth_tracker: BandwidthTracker,
107    /// Last activity timestamp
108    last_activity: Instant,
109    /// Keep-alive timer
110    next_keep_alive: Instant,
111}
112
/// Fixed-length window counter used to rate-limit a connection.
#[derive(Debug)]
struct BandwidthTracker {
    /// Outgoing bytes counted in the current window.
    bytes_sent: u64,
    /// Incoming bytes counted in the current window.
    bytes_received: u64,
    /// Instant at which the current window began.
    window_start: Instant,
    /// Length of one accounting window (set to one second by `new`).
    window_duration: Duration,
    /// Maximum outgoing bytes permitted per window.
    limit: u64,
}
127
128impl BandwidthTracker {
129    fn new(limit: u64) -> Self {
130        Self {
131            bytes_sent: 0,
132            bytes_received: 0,
133            window_start: Instant::now(),
134            window_duration: Duration::from_secs(1),
135            limit,
136        }
137    }
138
139    fn reset_if_needed(&mut self) {
140        let now = Instant::now();
141        if now.duration_since(self.window_start) >= self.window_duration {
142            self.bytes_sent = 0;
143            self.bytes_received = 0;
144            self.window_start = now;
145        }
146    }
147
148    fn can_send(&mut self, bytes: u64) -> bool {
149        self.reset_if_needed();
150        self.bytes_sent + bytes <= self.limit
151    }
152
153    fn record_sent(&mut self, bytes: u64) {
154        self.reset_if_needed();
155        self.bytes_sent += bytes;
156    }
157
158    fn record_received(&mut self, bytes: u64) {
159        self.reset_if_needed();
160        self.bytes_received += bytes;
161    }
162
163    fn current_usage(&mut self) -> u64 {
164        self.reset_if_needed();
165        self.bytes_sent + self.bytes_received
166    }
167}
168
169impl RelayConnection {
170    /// Create a new relay connection
171    pub fn new(
172        session_id: u32,
173        peer_addr: SocketAddr,
174        config: RelayConnectionConfig,
175        event_sender: mpsc::UnboundedSender<RelayEvent>,
176        action_receiver: mpsc::UnboundedReceiver<RelayAction>,
177    ) -> Self {
178        let now = Instant::now();
179        let state = ConnectionState {
180            is_active: true,
181            outgoing_queue: VecDeque::new(),
182            incoming_queue: VecDeque::new(),
183            buffer_usage: 0,
184            bandwidth_tracker: BandwidthTracker::new(config.bandwidth_limit),
185            last_activity: now,
186            next_keep_alive: now + config.keep_alive_interval,
187        };
188
189        Self {
190            session_id,
191            peer_addr,
192            config,
193            state: Arc::new(Mutex::new(state)),
194            event_sender,
195            action_receiver,
196        }
197    }
198
199    /// Get session ID
200    pub fn session_id(&self) -> u32 {
201        self.session_id
202    }
203
204    /// Get peer address
205    pub fn peer_addr(&self) -> SocketAddr {
206        self.peer_addr
207    }
208
209    /// Check if connection is active
210    pub fn is_active(&self) -> bool {
211        let state = self.state.lock().unwrap();
212        state.is_active
213    }
214
215    /// Send data through the relay
216    pub fn send_data(&self, data: Vec<u8>) -> RelayResult<()> {
217        if data.len() > self.config.max_frame_size {
218            return Err(RelayError::ProtocolError {
219                frame_type: 0x46, // RELAY_DATA
220                reason: format!(
221                    "Data size {} exceeds maximum {}",
222                    data.len(),
223                    self.config.max_frame_size
224                ),
225            });
226        }
227
228        let mut state = self.state.lock().unwrap();
229
230        if !state.is_active {
231            return Err(RelayError::SessionError {
232                session_id: Some(self.session_id),
233                kind: crate::relay::error::SessionErrorKind::Terminated,
234            });
235        }
236
237        // Check bandwidth limit
238        if !state.bandwidth_tracker.can_send(data.len() as u64) {
239            let current_usage = state.bandwidth_tracker.current_usage();
240            return Err(RelayError::SessionError {
241                session_id: Some(self.session_id),
242                kind: crate::relay::error::SessionErrorKind::BandwidthExceeded {
243                    used: current_usage,
244                    limit: self.config.bandwidth_limit,
245                },
246            });
247        }
248
249        // Check buffer space
250        if state.buffer_usage + data.len() > self.config.buffer_size {
251            return Err(RelayError::ResourceExhausted {
252                resource_type: "buffer".to_string(),
253                current_usage: state.buffer_usage as u64,
254                limit: self.config.buffer_size as u64,
255            });
256        }
257
258        // Queue data and update tracking
259        state.bandwidth_tracker.record_sent(data.len() as u64);
260        state.buffer_usage += data.len();
261        state.outgoing_queue.push_back(data.clone());
262        state.last_activity = Instant::now();
263
264        // Send event
265        let _ = self.event_sender.send(RelayEvent::DataReceived {
266            session_id: self.session_id,
267            data,
268        });
269
270        Ok(())
271    }
272
273    /// Receive data from the relay
274    pub fn receive_data(&self, data: Vec<u8>) -> RelayResult<()> {
275        let mut state = self.state.lock().unwrap();
276
277        if !state.is_active {
278            return Err(RelayError::SessionError {
279                session_id: Some(self.session_id),
280                kind: crate::relay::error::SessionErrorKind::Terminated,
281            });
282        }
283
284        // Check buffer space
285        if state.buffer_usage + data.len() > self.config.buffer_size {
286            return Err(RelayError::ResourceExhausted {
287                resource_type: "buffer".to_string(),
288                current_usage: state.buffer_usage as u64,
289                limit: self.config.buffer_size as u64,
290            });
291        }
292
293        // Queue data and update tracking
294        state.bandwidth_tracker.record_received(data.len() as u64);
295        state.buffer_usage += data.len();
296        state.incoming_queue.push_back(data.clone());
297        state.last_activity = Instant::now();
298
299        // Send event
300        let _ = self.event_sender.send(RelayEvent::DataReceived {
301            session_id: self.session_id,
302            data,
303        });
304
305        Ok(())
306    }
307
308    /// Get next outgoing data packet
309    pub fn next_outgoing(&self) -> Option<Vec<u8>> {
310        let mut state = self.state.lock().unwrap();
311        if let Some(data) = state.outgoing_queue.pop_front() {
312            state.buffer_usage = state.buffer_usage.saturating_sub(data.len());
313            Some(data)
314        } else {
315            None
316        }
317    }
318
319    /// Get next incoming data packet
320    pub fn next_incoming(&self) -> Option<Vec<u8>> {
321        let mut state = self.state.lock().unwrap();
322        if let Some(data) = state.incoming_queue.pop_front() {
323            state.buffer_usage = state.buffer_usage.saturating_sub(data.len());
324            Some(data)
325        } else {
326            None
327        }
328    }
329
330    /// Check if connection has timed out
331    pub fn check_timeout(&self) -> RelayResult<()> {
332        let state = self.state.lock().unwrap();
333        let now = Instant::now();
334
335        if now.duration_since(state.last_activity) > self.config.connection_timeout {
336            return Err(RelayError::SessionError {
337                session_id: Some(self.session_id),
338                kind: crate::relay::error::SessionErrorKind::Expired,
339            });
340        }
341
342        Ok(())
343    }
344
345    /// Check if keep-alive should be sent
346    pub fn should_send_keep_alive(&self) -> bool {
347        let state = self.state.lock().unwrap();
348        Instant::now() >= state.next_keep_alive
349    }
350
351    /// Send keep-alive
352    pub fn send_keep_alive(&self) -> RelayResult<()> {
353        let mut state = self.state.lock().unwrap();
354        state.next_keep_alive = Instant::now() + self.config.keep_alive_interval;
355
356        let _ = self.event_sender.send(RelayEvent::KeepAlive {
357            session_id: self.session_id,
358        });
359
360        Ok(())
361    }
362
363    /// Terminate the connection
364    pub fn terminate(&self, reason: String) -> RelayResult<()> {
365        let mut state = self.state.lock().unwrap();
366        state.is_active = false;
367
368        let _ = self.event_sender.send(RelayEvent::ConnectionTerminated {
369            session_id: self.session_id,
370            reason: reason.clone(),
371        });
372
373        Ok(())
374    }
375
376    /// Get connection statistics
377    pub fn get_stats(&self) -> ConnectionStats {
378        let state = self.state.lock().unwrap();
379        ConnectionStats {
380            session_id: self.session_id,
381            peer_addr: self.peer_addr,
382            is_active: state.is_active,
383            bytes_sent: state.bandwidth_tracker.bytes_sent,
384            bytes_received: state.bandwidth_tracker.bytes_received,
385            buffer_usage: state.buffer_usage,
386            outgoing_queue_size: state.outgoing_queue.len(),
387            incoming_queue_size: state.incoming_queue.len(),
388            last_activity: state.last_activity,
389        }
390    }
391}
392
/// Point-in-time snapshot returned by `RelayConnection::get_stats`.
#[derive(Debug, Clone)]
pub struct ConnectionStats {
    /// Session identifier of the connection.
    pub session_id: u32,
    /// Address of the remote peer.
    pub peer_addr: SocketAddr,
    /// Whether the connection had not yet been terminated.
    pub is_active: bool,
    /// Outgoing bytes counted in the bandwidth tracker's current window.
    pub bytes_sent: u64,
    /// Incoming bytes counted in the bandwidth tracker's current window.
    pub bytes_received: u64,
    /// Bytes held across the outgoing and incoming queues.
    pub buffer_usage: usize,
    /// Number of payloads waiting in the outgoing queue.
    pub outgoing_queue_size: usize,
    /// Number of payloads waiting in the incoming queue.
    pub incoming_queue_size: usize,
    /// Time of the most recent accepted send or receive.
    pub last_activity: Instant,
}
406
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{IpAddr, Ipv4Addr};
    use tokio::sync::mpsc;

    fn test_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080)
    }

    /// Build a connection for session 123 together with the far ends of its
    /// channels. Callers bind the returned channel ends so they stay alive for
    /// the whole test.
    fn connection_with(
        config: RelayConnectionConfig,
    ) -> (
        RelayConnection,
        mpsc::UnboundedReceiver<RelayEvent>,
        mpsc::UnboundedSender<RelayAction>,
    ) {
        let (event_tx, event_rx) = mpsc::unbounded_channel();
        let (action_tx, action_rx) = mpsc::unbounded_channel();
        let connection = RelayConnection::new(123, test_addr(), config, event_tx, action_rx);
        (connection, event_rx, action_tx)
    }

    #[test]
    fn test_relay_connection_creation() {
        let (connection, _events, _actions) = connection_with(RelayConnectionConfig::default());

        assert_eq!(connection.session_id(), 123);
        assert_eq!(connection.peer_addr(), test_addr());
        assert!(connection.is_active());
    }

    #[test]
    fn test_send_data_within_limits() {
        let (connection, _events, _actions) = connection_with(RelayConnectionConfig::default());

        let data = vec![1, 2, 3, 4];
        assert!(connection.send_data(data.clone()).is_ok());

        // Check that data is queued
        assert_eq!(connection.next_outgoing(), Some(data));
    }

    #[test]
    fn test_send_data_exceeds_frame_size() {
        let config = RelayConnectionConfig {
            max_frame_size: 10,
            ..Default::default()
        };
        let (connection, _events, _actions) = connection_with(config);

        let large_data = vec![0; 20];
        assert!(connection.send_data(large_data).is_err());
    }

    #[test]
    fn test_bandwidth_limiting() {
        let config = RelayConnectionConfig {
            bandwidth_limit: 100, // Very low limit
            ..Default::default()
        };
        let (connection, _events, _actions) = connection_with(config);

        // First small packet should succeed
        assert!(connection.send_data(vec![0; 50]).is_ok());

        // Second packet should exceed bandwidth limit
        assert!(connection.send_data(vec![0; 60]).is_err());
    }

    #[test]
    fn test_buffer_size_limiting() {
        let config = RelayConnectionConfig {
            buffer_size: 100, // Very small buffer
            ..Default::default()
        };
        let (connection, _events, _actions) = connection_with(config);

        // Fill buffer
        assert!(connection.send_data(vec![0; 80]).is_ok());

        // Should fail to add more data
        assert!(connection.send_data(vec![0; 30]).is_err());
    }

    #[test]
    fn test_connection_termination() {
        let (connection, _events, _actions) = connection_with(RelayConnectionConfig::default());

        assert!(connection.is_active());

        let reason = "Test termination".to_string();
        assert!(connection.terminate(reason).is_ok());

        assert!(!connection.is_active());

        // Should not be able to send data after termination
        assert!(connection.send_data(vec![1, 2, 3]).is_err());
    }

    #[test]
    fn test_keep_alive() {
        // A long interval keeps the "not yet due" assertions independent of
        // scheduler latency. The deadline is moved by hand instead of sleeping,
        // which made the previous 1 ms version flaky on loaded machines.
        let config = RelayConnectionConfig {
            keep_alive_interval: Duration::from_secs(60),
            ..Default::default()
        };
        let (connection, _events, _actions) = connection_with(config);

        // Freshly created: the first keep-alive is a full interval away.
        assert!(!connection.should_send_keep_alive());

        // Pretend the interval has elapsed.
        connection.state.lock().unwrap().next_keep_alive = Instant::now();
        assert!(connection.should_send_keep_alive());

        // Send keep-alive
        assert!(connection.send_keep_alive().is_ok());

        // Sending re-arms the deadline one interval into the future.
        assert!(!connection.should_send_keep_alive());
    }

    #[test]
    fn test_connection_stats() {
        let (connection, _events, _actions) = connection_with(RelayConnectionConfig::default());

        // Send some data
        connection.send_data(vec![1, 2, 3]).unwrap();
        connection.receive_data(vec![4, 5, 6, 7]).unwrap();

        let stats = connection.get_stats();
        assert_eq!(stats.session_id, 123);
        assert_eq!(stats.peer_addr, test_addr());
        assert!(stats.is_active);
        assert_eq!(stats.bytes_sent, 3);
        assert_eq!(stats.bytes_received, 4);
        assert_eq!(stats.buffer_usage, 7);
        assert_eq!(stats.outgoing_queue_size, 1);
        assert_eq!(stats.incoming_queue_size, 1);
    }
}