Skip to main content

corevpn_protocol/
reliable.rs

//! Reliable Transport Layer for Control Channel
//!
//! Implements packet acknowledgment and retransmission for the control channel.
4
5use std::collections::{BTreeMap, VecDeque};
6use std::time::{Duration, Instant};
7
8use bytes::Bytes;
9
10use crate::{ProtocolError, Result};
11
/// Tuning parameters for the reliable control-channel transport.
///
/// Timing knobs are expressed as [`Duration`]s; counts are plain integers.
#[derive(Clone, Debug)]
pub struct ReliableConfig {
    /// Retransmit timeout used before any RTT sample exists (and as the
    /// lower bound of the computed timeout afterwards).
    pub initial_rto: Duration,
    /// Upper bound on the retransmit timeout, including after backoff.
    pub max_rto: Duration,
    /// Factor the per-packet timeout is multiplied by after each retransmit.
    pub rto_backoff: f64,
    /// How many times a single packet may be retransmitted.
    pub max_retransmits: u32,
    /// Maximum number of sent packets that may await acknowledgment.
    pub window_size: u32,
    /// Minimum spacing between standalone ACKs.
    pub ack_delay: Duration,
}
28
29impl Default for ReliableConfig {
30    fn default() -> Self {
31        Self {
32            initial_rto: Duration::from_secs(2),
33            max_rto: Duration::from_secs(60),
34            rto_backoff: 2.0,
35            max_retransmits: 10,
36            window_size: 8,
37            ack_delay: Duration::from_millis(100),
38        }
39    }
40}
41
42/// Outgoing packet awaiting acknowledgment
43#[derive(Debug)]
44struct PendingPacket {
45    /// Packet data
46    data: Bytes,
47    /// Time sent
48    sent_at: Instant,
49    /// Next retransmit time
50    next_retransmit: Instant,
51    /// Current RTO
52    rto: Duration,
53    /// Retransmit count
54    retransmits: u32,
55}
56
57/// Reliable transport layer
58pub struct ReliableTransport {
59    /// Configuration
60    config: ReliableConfig,
61    /// Next packet ID to send
62    next_send_id: u32,
63    /// Next expected packet ID to receive
64    next_recv_id: u32,
65    /// Packets awaiting ACK
66    pending: BTreeMap<u32, PendingPacket>,
67    /// ACKs to send
68    pending_acks: VecDeque<u32>,
69    /// Out-of-order received packets
70    out_of_order: BTreeMap<u32, Bytes>,
71    /// Time of last ACK sent
72    last_ack_sent: Option<Instant>,
73    /// Smoothed RTT (for RTO calculation)
74    srtt: Option<Duration>,
75    /// RTT variation
76    rttvar: Duration,
77}
78
79impl ReliableTransport {
80    /// Create a new reliable transport
81    pub fn new(config: ReliableConfig) -> Self {
82        Self {
83            config,
84            next_send_id: 0,
85            next_recv_id: 0,
86            pending: BTreeMap::new(),
87            pending_acks: VecDeque::new(),
88            out_of_order: BTreeMap::new(),
89            last_ack_sent: None,
90            srtt: None,
91            rttvar: Duration::from_millis(500),
92        }
93    }
94
95    /// Queue a packet for sending
96    ///
97    /// Returns the packet ID and the data to send
98    pub fn send(&mut self, data: Bytes) -> Result<(u32, Bytes)> {
99        // Check window
100        if self.pending.len() >= self.config.window_size as usize {
101            return Err(ProtocolError::InvalidPacket("send window full".into()));
102        }
103
104        let packet_id = self.next_send_id;
105        self.next_send_id = self.next_send_id.wrapping_add(1);
106
107        let now = Instant::now();
108        let rto = self.calculate_rto();
109
110        self.pending.insert(
111            packet_id,
112            PendingPacket {
113                data: data.clone(),
114                sent_at: now,
115                next_retransmit: now + rto,
116                rto,
117                retransmits: 0,
118            },
119        );
120
121        Ok((packet_id, data))
122    }
123
124    /// Maximum number of out-of-order packets to buffer
125    const MAX_OUT_OF_ORDER: usize = 100;
126
127    /// Process received packet
128    ///
129    /// Returns the payload if this is the next expected packet,
130    /// otherwise buffers it for later delivery.
131    pub fn receive(&mut self, packet_id: u32, data: Bytes) -> Result<Option<Bytes>> {
132        // Queue ACK
133        self.pending_acks.push_back(packet_id);
134
135        if packet_id == self.next_recv_id {
136            // In order - deliver immediately
137            self.next_recv_id = self.next_recv_id.wrapping_add(1);
138
139            // Check for buffered packets that can now be delivered
140            while let Some(_buffered) = self.out_of_order.remove(&self.next_recv_id) {
141                self.next_recv_id = self.next_recv_id.wrapping_add(1);
142                // Note: in a real implementation, we'd queue these for delivery
143            }
144
145            Ok(Some(data))
146        } else if packet_id > self.next_recv_id {
147            // Out of order - buffer
148            // Security: Limit buffer size to prevent DoS
149            if self.out_of_order.len() >= Self::MAX_OUT_OF_ORDER {
150                return Err(ProtocolError::InvalidPacket(
151                    "too many out-of-order packets".into(),
152                ));
153            }
154            self.out_of_order.insert(packet_id, data);
155            Ok(None)
156        } else {
157            // Duplicate - ignore
158            Ok(None)
159        }
160    }
161
162    /// Process received ACKs
163    pub fn process_acks(&mut self, acks: &[u32]) {
164        let now = Instant::now();
165
166        for &ack_id in acks {
167            if let Some(pending) = self.pending.remove(&ack_id) {
168                // Update RTT estimate
169                if pending.retransmits == 0 {
170                    let rtt = now.duration_since(pending.sent_at);
171                    self.update_rtt(rtt);
172                }
173            }
174        }
175    }
176
177    /// Get ACKs to send
178    pub fn get_acks(&mut self) -> Vec<u32> {
179        self.pending_acks.drain(..).collect()
180    }
181
182    /// Check if we should send a standalone ACK
183    pub fn should_send_ack(&self) -> bool {
184        if self.pending_acks.is_empty() {
185            return false;
186        }
187
188        match self.last_ack_sent {
189            Some(last) => last.elapsed() >= self.config.ack_delay,
190            None => true,
191        }
192    }
193
194    /// Mark ACK as sent
195    pub fn ack_sent(&mut self) {
196        self.last_ack_sent = Some(Instant::now());
197    }
198
199    /// Get packets that need retransmission
200    pub fn get_retransmits(&mut self) -> Vec<(u32, Bytes)> {
201        let now = Instant::now();
202        let mut retransmits = Vec::new();
203
204        for (id, pending) in self.pending.iter_mut() {
205            if now >= pending.next_retransmit {
206                if pending.retransmits >= self.config.max_retransmits {
207                    // TODO: Signal connection failure
208                    continue;
209                }
210
211                retransmits.push((*id, pending.data.clone()));
212
213                // Update for next retransmit
214                pending.retransmits += 1;
215                pending.rto = Duration::from_secs_f64(
216                    (pending.rto.as_secs_f64() * self.config.rto_backoff)
217                        .min(self.config.max_rto.as_secs_f64()),
218                );
219                pending.next_retransmit = now + pending.rto;
220            }
221        }
222
223        retransmits
224    }
225
226    /// Check if there are pending packets
227    pub fn has_pending(&self) -> bool {
228        !self.pending.is_empty()
229    }
230
231    /// Get next timeout (when we need to check for retransmits)
232    pub fn next_timeout(&self) -> Option<Duration> {
233        self.pending
234            .values()
235            .map(|p| p.next_retransmit)
236            .min()
237            .map(|t| t.saturating_duration_since(Instant::now()))
238    }
239
240    fn calculate_rto(&self) -> Duration {
241        match self.srtt {
242            Some(srtt) => {
243                // RTO = SRTT + 4 * RTTVAR (RFC 6298)
244                let rto = srtt + self.rttvar * 4;
245                rto.max(self.config.initial_rto)
246                    .min(self.config.max_rto)
247            }
248            None => self.config.initial_rto,
249        }
250    }
251
252    fn update_rtt(&mut self, rtt: Duration) {
253        match self.srtt {
254            Some(srtt) => {
255                // RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R|
256                // SRTT = (1 - alpha) * SRTT + alpha * R
257                // where alpha = 1/8, beta = 1/4
258                let diff = rtt.abs_diff(srtt);
259                self.rttvar = Duration::from_secs_f64(
260                    0.75 * self.rttvar.as_secs_f64() + 0.25 * diff.as_secs_f64(),
261                );
262                self.srtt = Some(Duration::from_secs_f64(
263                    0.875 * srtt.as_secs_f64() + 0.125 * rtt.as_secs_f64(),
264                ));
265            }
266            None => {
267                // First RTT measurement
268                self.srtt = Some(rtt);
269                self.rttvar = rtt / 2;
270            }
271        }
272    }
273}
274
/// Accumulates byte chunks and splits them into complete TLS records.
pub struct TlsRecordReassembler {
    /// Bytes received so far that have not yet been returned as a record.
    buffer: Vec<u8>,
    /// Largest length `buffer` may reach; enforced when data is added.
    max_size: usize,
}
282
283impl TlsRecordReassembler {
284    /// Create a new reassembler
285    pub fn new(max_size: usize) -> Self {
286        Self {
287            buffer: Vec::new(),
288            max_size,
289        }
290    }
291
292    /// Add data to the buffer
293    pub fn add(&mut self, data: &[u8]) -> Result<()> {
294        if self.buffer.len() + data.len() > self.max_size {
295            return Err(ProtocolError::InvalidPacket("TLS record too large".into()));
296        }
297        self.buffer.extend_from_slice(data);
298        Ok(())
299    }
300
301    /// Try to extract complete TLS records
302    pub fn extract_records(&mut self) -> Vec<Bytes> {
303        let mut records = Vec::new();
304
305        while self.buffer.len() >= 5 {
306            // TLS record header: type (1) + version (2) + length (2)
307            let length = u16::from_be_bytes([self.buffer[3], self.buffer[4]]) as usize;
308
309            if self.buffer.len() < 5 + length {
310                break; // Incomplete record
311            }
312
313            let record = self.buffer.drain(..5 + length).collect::<Vec<_>>();
314            records.push(Bytes::from(record));
315        }
316
317        records
318    }
319
320    /// Get buffer length
321    pub fn len(&self) -> usize {
322        self.buffer.len()
323    }
324
325    /// Check if buffer is empty
326    pub fn is_empty(&self) -> bool {
327        self.buffer.is_empty()
328    }
329
330    /// Clear the buffer
331    pub fn clear(&mut self) {
332        self.buffer.clear();
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reliable_basic() {
        let mut transport = ReliableTransport::new(ReliableConfig::default());

        // Send a packet
        let (id, _) = transport.send(Bytes::from_static(b"hello")).unwrap();
        assert_eq!(id, 0);
        assert!(transport.has_pending());

        // ACK it
        transport.process_acks(&[0]);
        assert!(!transport.has_pending());
    }

    #[test]
    fn test_reliable_send_window_full() {
        let mut transport = ReliableTransport::new(ReliableConfig {
            window_size: 2,
            ..ReliableConfig::default()
        });

        transport.send(Bytes::from_static(b"a")).unwrap();
        transport.send(Bytes::from_static(b"b")).unwrap();
        // Window is full until something is acknowledged.
        assert!(transport.send(Bytes::from_static(b"c")).is_err());

        transport.process_acks(&[0]);
        // The rejected send did not consume an ID.
        let (id, _) = transport.send(Bytes::from_static(b"c")).unwrap();
        assert_eq!(id, 2);
    }

    #[test]
    fn test_reliable_receive() {
        let mut transport = ReliableTransport::new(ReliableConfig::default());

        // Receive packet 0 (in order)
        let data = transport.receive(0, Bytes::from_static(b"first")).unwrap();
        assert_eq!(data.as_deref(), Some(&b"first"[..]));

        // Receive packet 2 (out of order) - buffered
        let data = transport.receive(2, Bytes::from_static(b"third")).unwrap();
        assert!(data.is_none());

        // Receive packet 1 - fills the gap
        let data = transport.receive(1, Bytes::from_static(b"second")).unwrap();
        assert_eq!(data.as_deref(), Some(&b"second"[..]));

        // Duplicate of a delivered packet is not delivered again
        let data = transport.receive(0, Bytes::from_static(b"first")).unwrap();
        assert!(data.is_none());

        // Filling the gap at 1 also consumed buffered packet 2, so 3 is next
        let data = transport.receive(3, Bytes::from_static(b"fourth")).unwrap();
        assert_eq!(data.as_deref(), Some(&b"fourth"[..]));

        // Every accepted packet, including the duplicate, is acknowledged
        assert_eq!(transport.get_acks(), vec![0, 2, 1, 0, 3]);
        assert!(transport.get_acks().is_empty());
    }

    #[test]
    fn test_tls_reassembler() {
        let mut reassembler = TlsRecordReassembler::new(16384);

        // Add partial TLS record header
        reassembler.add(&[0x17, 0x03, 0x03, 0x00, 0x05]).unwrap();
        assert!(reassembler.extract_records().is_empty());

        // Add the rest
        reassembler.add(&[1, 2, 3, 4, 5]).unwrap();
        let records = reassembler.extract_records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].len(), 10); // 5 header + 5 payload
        assert!(reassembler.is_empty());
    }

    #[test]
    fn test_tls_reassembler_multiple_records() {
        let mut reassembler = TlsRecordReassembler::new(16384);

        // Two complete records followed by one byte of a third.
        reassembler
            .add(&[
                0x16, 0x03, 0x03, 0x00, 0x01, 0xAA, //
                0x17, 0x03, 0x03, 0x00, 0x02, 0xBB, 0xCC, //
                0x15,
            ])
            .unwrap();

        let records = reassembler.extract_records();
        assert_eq!(records.len(), 2);
        assert_eq!(&records[0][..], &[0x16, 0x03, 0x03, 0x00, 0x01, 0xAA]);
        assert_eq!(&records[1][..], &[0x17, 0x03, 0x03, 0x00, 0x02, 0xBB, 0xCC]);
        assert_eq!(reassembler.len(), 1);
    }

    #[test]
    fn test_tls_reassembler_rejects_oversize() {
        let mut reassembler = TlsRecordReassembler::new(4);
        assert!(reassembler.add(&[0; 5]).is_err());
        assert!(reassembler.is_empty());
    }
}