corevpn_protocol/
reliable.rs

//! Reliable Transport Layer for Control Channel
//!
//! Implements packet acknowledgment and retransmission for the control channel,
//! along with reassembly of fragmented TLS records.

5use std::collections::{BTreeMap, VecDeque};
6use std::time::{Duration, Instant};
7
8use bytes::Bytes;
9
10use crate::{ProtocolError, Result};
11
/// Tuning parameters for [`ReliableTransport`] acknowledgment and retransmission.
#[derive(Debug, Clone)]
pub struct ReliableConfig {
    /// Retransmit timeout used before any RTT sample exists; computed timeouts
    /// are never allowed to drop below it
    pub initial_rto: Duration,
    /// Upper bound for any retransmit timeout
    pub max_rto: Duration,
    /// Factor the retransmit timeout is multiplied by after each retransmission
    pub rto_backoff: f64,
    /// Number of retransmissions allowed per packet
    pub max_retransmits: u32,
    /// Maximum number of sent packets that may await acknowledgment at once
    pub window_size: u32,
    /// Minimum time since the previous ACK before a standalone ACK is due
    pub ack_delay: Duration,
}

impl Default for ReliableConfig {
    /// 2 s initial RTO capped at 60 s, doubling backoff, 10 retransmits,
    /// a window of 8 packets and a 100 ms ACK delay.
    fn default() -> Self {
        let initial_rto = Duration::from_millis(2_000);
        let max_rto = Duration::from_millis(60_000);
        let ack_delay = Duration::from_millis(100);

        Self {
            initial_rto,
            max_rto,
            rto_backoff: 2.0,
            max_retransmits: 10,
            window_size: 8,
            ack_delay,
        }
    }
}
41
42/// Outgoing packet awaiting acknowledgment
43#[derive(Debug)]
44struct PendingPacket {
45    /// Packet data
46    data: Bytes,
47    /// Time sent
48    sent_at: Instant,
49    /// Next retransmit time
50    next_retransmit: Instant,
51    /// Current RTO
52    rto: Duration,
53    /// Retransmit count
54    retransmits: u32,
55}
56
57/// Reliable transport layer
58pub struct ReliableTransport {
59    /// Configuration
60    config: ReliableConfig,
61    /// Next packet ID to send
62    next_send_id: u32,
63    /// Next expected packet ID to receive
64    next_recv_id: u32,
65    /// Packets awaiting ACK
66    pending: BTreeMap<u32, PendingPacket>,
67    /// ACKs to send
68    pending_acks: VecDeque<u32>,
69    /// Out-of-order received packets
70    out_of_order: BTreeMap<u32, Bytes>,
71    /// Time of last ACK sent
72    last_ack_sent: Option<Instant>,
73    /// Smoothed RTT (for RTO calculation)
74    srtt: Option<Duration>,
75    /// RTT variation
76    rttvar: Duration,
77}
78
79impl ReliableTransport {
80    /// Create a new reliable transport
81    pub fn new(config: ReliableConfig) -> Self {
82        Self {
83            config,
84            next_send_id: 0,
85            next_recv_id: 0,
86            pending: BTreeMap::new(),
87            pending_acks: VecDeque::new(),
88            out_of_order: BTreeMap::new(),
89            last_ack_sent: None,
90            srtt: None,
91            rttvar: Duration::from_millis(500),
92        }
93    }
94
95    /// Queue a packet for sending
96    ///
97    /// Returns the packet ID and the data to send
98    pub fn send(&mut self, data: Bytes) -> Result<(u32, Bytes)> {
99        // Check window
100        if self.pending.len() >= self.config.window_size as usize {
101            return Err(ProtocolError::InvalidPacket("send window full".into()));
102        }
103
104        let packet_id = self.next_send_id;
105        self.next_send_id = self.next_send_id.wrapping_add(1);
106
107        let now = Instant::now();
108        let rto = self.calculate_rto();
109
110        self.pending.insert(
111            packet_id,
112            PendingPacket {
113                data: data.clone(),
114                sent_at: now,
115                next_retransmit: now + rto,
116                rto,
117                retransmits: 0,
118            },
119        );
120
121        Ok((packet_id, data))
122    }
123
124    /// Process received packet
125    ///
126    /// Returns the payload if this is the next expected packet,
127    /// otherwise buffers it for later delivery.
128    pub fn receive(&mut self, packet_id: u32, data: Bytes) -> Option<Bytes> {
129        // Queue ACK
130        self.pending_acks.push_back(packet_id);
131
132        if packet_id == self.next_recv_id {
133            // In order - deliver immediately
134            self.next_recv_id = self.next_recv_id.wrapping_add(1);
135
136            // Check for buffered packets that can now be delivered
137            while let Some(_buffered) = self.out_of_order.remove(&self.next_recv_id) {
138                self.next_recv_id = self.next_recv_id.wrapping_add(1);
139                // Note: in a real implementation, we'd queue these for delivery
140            }
141
142            Some(data)
143        } else if packet_id > self.next_recv_id {
144            // Out of order - buffer
145            self.out_of_order.insert(packet_id, data);
146            None
147        } else {
148            // Duplicate - ignore
149            None
150        }
151    }
152
153    /// Process received ACKs
154    pub fn process_acks(&mut self, acks: &[u32]) {
155        let now = Instant::now();
156
157        for &ack_id in acks {
158            if let Some(pending) = self.pending.remove(&ack_id) {
159                // Update RTT estimate
160                if pending.retransmits == 0 {
161                    let rtt = now.duration_since(pending.sent_at);
162                    self.update_rtt(rtt);
163                }
164            }
165        }
166    }
167
168    /// Get ACKs to send
169    pub fn get_acks(&mut self) -> Vec<u32> {
170        self.pending_acks.drain(..).collect()
171    }
172
173    /// Check if we should send a standalone ACK
174    pub fn should_send_ack(&self) -> bool {
175        if self.pending_acks.is_empty() {
176            return false;
177        }
178
179        match self.last_ack_sent {
180            Some(last) => last.elapsed() >= self.config.ack_delay,
181            None => true,
182        }
183    }
184
185    /// Mark ACK as sent
186    pub fn ack_sent(&mut self) {
187        self.last_ack_sent = Some(Instant::now());
188    }
189
190    /// Get packets that need retransmission
191    pub fn get_retransmits(&mut self) -> Vec<(u32, Bytes)> {
192        let now = Instant::now();
193        let mut retransmits = Vec::new();
194
195        for (id, pending) in self.pending.iter_mut() {
196            if now >= pending.next_retransmit {
197                if pending.retransmits >= self.config.max_retransmits {
198                    // TODO: Signal connection failure
199                    continue;
200                }
201
202                retransmits.push((*id, pending.data.clone()));
203
204                // Update for next retransmit
205                pending.retransmits += 1;
206                pending.rto = Duration::from_secs_f64(
207                    (pending.rto.as_secs_f64() * self.config.rto_backoff)
208                        .min(self.config.max_rto.as_secs_f64()),
209                );
210                pending.next_retransmit = now + pending.rto;
211            }
212        }
213
214        retransmits
215    }
216
217    /// Check if there are pending packets
218    pub fn has_pending(&self) -> bool {
219        !self.pending.is_empty()
220    }
221
222    /// Get next timeout (when we need to check for retransmits)
223    pub fn next_timeout(&self) -> Option<Duration> {
224        self.pending
225            .values()
226            .map(|p| p.next_retransmit)
227            .min()
228            .map(|t| t.saturating_duration_since(Instant::now()))
229    }
230
231    fn calculate_rto(&self) -> Duration {
232        match self.srtt {
233            Some(srtt) => {
234                // RTO = SRTT + 4 * RTTVAR (RFC 6298)
235                let rto = srtt + self.rttvar * 4;
236                rto.max(self.config.initial_rto)
237                    .min(self.config.max_rto)
238            }
239            None => self.config.initial_rto,
240        }
241    }
242
243    fn update_rtt(&mut self, rtt: Duration) {
244        match self.srtt {
245            Some(srtt) => {
246                // RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R|
247                // SRTT = (1 - alpha) * SRTT + alpha * R
248                // where alpha = 1/8, beta = 1/4
249                let diff = rtt.abs_diff(srtt);
250                self.rttvar = Duration::from_secs_f64(
251                    0.75 * self.rttvar.as_secs_f64() + 0.25 * diff.as_secs_f64(),
252                );
253                self.srtt = Some(Duration::from_secs_f64(
254                    0.875 * srtt.as_secs_f64() + 0.125 * rtt.as_secs_f64(),
255                ));
256            }
257            None => {
258                // First RTT measurement
259                self.srtt = Some(rtt);
260                self.rttvar = rtt / 2;
261            }
262        }
263    }
264}
265
266/// Reassembles fragmented TLS records
267pub struct TlsRecordReassembler {
268    /// Buffer for partial records
269    buffer: Vec<u8>,
270    /// Maximum buffer size
271    max_size: usize,
272}
273
274impl TlsRecordReassembler {
275    /// Create a new reassembler
276    pub fn new(max_size: usize) -> Self {
277        Self {
278            buffer: Vec::new(),
279            max_size,
280        }
281    }
282
283    /// Add data to the buffer
284    pub fn add(&mut self, data: &[u8]) -> Result<()> {
285        if self.buffer.len() + data.len() > self.max_size {
286            return Err(ProtocolError::InvalidPacket("TLS record too large".into()));
287        }
288        self.buffer.extend_from_slice(data);
289        Ok(())
290    }
291
292    /// Try to extract complete TLS records
293    pub fn extract_records(&mut self) -> Vec<Bytes> {
294        let mut records = Vec::new();
295
296        while self.buffer.len() >= 5 {
297            // TLS record header: type (1) + version (2) + length (2)
298            let length = u16::from_be_bytes([self.buffer[3], self.buffer[4]]) as usize;
299
300            if self.buffer.len() < 5 + length {
301                break; // Incomplete record
302            }
303
304            let record = self.buffer.drain(..5 + length).collect::<Vec<_>>();
305            records.push(Bytes::from(record));
306        }
307
308        records
309    }
310
311    /// Get buffer length
312    pub fn len(&self) -> usize {
313        self.buffer.len()
314    }
315
316    /// Check if buffer is empty
317    pub fn is_empty(&self) -> bool {
318        self.buffer.is_empty()
319    }
320
321    /// Clear the buffer
322    pub fn clear(&mut self) {
323        self.buffer.clear();
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reliable_basic() {
        let mut transport = ReliableTransport::new(ReliableConfig::default());

        // The first packet gets ID 0, is echoed back, and stays pending until ACKed.
        let (id, data) = transport.send(Bytes::from_static(b"hello")).unwrap();
        assert_eq!(id, 0);
        assert_eq!(data, Bytes::from_static(b"hello"));
        assert!(transport.has_pending());

        transport.process_acks(&[0]);
        assert!(!transport.has_pending());
    }

    #[test]
    fn test_send_window_full() {
        let config = ReliableConfig::default();
        let window = config.window_size;
        let mut transport = ReliableTransport::new(config);

        for expected_id in 0..window {
            let (id, _) = transport.send(Bytes::from_static(b"x")).unwrap();
            assert_eq!(id, expected_id);
        }
        assert!(transport.send(Bytes::from_static(b"overflow")).is_err());

        // ACKing one packet frees a slot in the window.
        transport.process_acks(&[0]);
        assert!(transport.send(Bytes::from_static(b"x")).is_ok());
    }

    #[test]
    fn test_reliable_receive() {
        let mut transport = ReliableTransport::new(ReliableConfig::default());

        // Packet 0 is in order and delivered immediately.
        assert_eq!(
            transport.receive(0, Bytes::from_static(b"first")),
            Some(Bytes::from_static(b"first"))
        );

        // Packet 2 arrives early and is buffered.
        assert_eq!(transport.receive(2, Bytes::from_static(b"third")), None);

        // Packet 1 fills the gap and is delivered.
        assert_eq!(
            transport.receive(1, Bytes::from_static(b"second")),
            Some(Bytes::from_static(b"second"))
        );

        // A retransmitted duplicate is not delivered again.
        assert_eq!(transport.receive(0, Bytes::from_static(b"first")), None);

        // Every accepted packet, including the duplicate, is acknowledged once drained.
        assert_eq!(transport.get_acks(), vec![0, 2, 1, 0]);
        assert!(transport.get_acks().is_empty());
    }

    #[test]
    fn test_tls_reassembler() {
        let mut reassembler = TlsRecordReassembler::new(16384);

        // A header alone is not a complete record.
        reassembler.add(&[0x17, 0x03, 0x03, 0x00, 0x05]).unwrap();
        assert!(reassembler.extract_records().is_empty());

        // Add the rest
        reassembler.add(&[1, 2, 3, 4, 5]).unwrap();
        let records = reassembler.extract_records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].len(), 10); // 5 header + 5 payload
        assert!(reassembler.is_empty());
    }

    #[test]
    fn test_tls_reassembler_multiple_records() {
        let mut reassembler = TlsRecordReassembler::new(16384);

        // Two complete records followed by the first two bytes of a third header.
        let stream: [u8; 15] = [
            0x16, 0x03, 0x03, 0x00, 0x01, 0xAA, // record 1 (1-byte payload)
            0x17, 0x03, 0x03, 0x00, 0x02, 0xBB, 0xCC, // record 2 (2-byte payload)
            0x17, 0x03, // partial header
        ];
        reassembler.add(&stream).unwrap();

        let records = reassembler.extract_records();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].to_vec(), vec![0x16, 0x03, 0x03, 0x00, 0x01, 0xAA]);
        assert_eq!(
            records[1].to_vec(),
            vec![0x17, 0x03, 0x03, 0x00, 0x02, 0xBB, 0xCC]
        );
        assert_eq!(reassembler.len(), 2);
    }

    #[test]
    fn test_tls_reassembler_rejects_oversize() {
        let mut reassembler = TlsRecordReassembler::new(4);
        assert!(reassembler.add(&[0; 5]).is_err());
        assert!(reassembler.is_empty());
    }
}