1use std::collections::{BTreeMap, VecDeque};
6use std::time::{Duration, Instant};
7
8use bytes::Bytes;
9
10use crate::{ProtocolError, Result};
11
/// Tuning parameters for [`ReliableTransport`]: retransmission timing,
/// send-window size and acknowledgement pacing.
#[derive(Debug, Clone)]
pub struct ReliableConfig {
    /// Retransmission timeout used before any RTT sample exists; it is also
    /// the lower bound applied to every RTO computed from RTT samples.
    pub initial_rto: Duration,
    /// Upper bound for both computed and backed-off RTOs.
    pub max_rto: Duration,
    /// Factor a packet's RTO is multiplied by each time it is retransmitted.
    pub rto_backoff: f64,
    /// How many times a single packet may be retransmitted.
    pub max_retransmits: u32,
    /// Maximum number of sent-but-unacknowledged packets.
    pub window_size: u32,
    /// Minimum spacing between consecutive acknowledgement batches.
    pub ack_delay: Duration,
}
28
29impl Default for ReliableConfig {
30 fn default() -> Self {
31 Self {
32 initial_rto: Duration::from_secs(2),
33 max_rto: Duration::from_secs(60),
34 rto_backoff: 2.0,
35 max_retransmits: 10,
36 window_size: 8,
37 ack_delay: Duration::from_millis(100),
38 }
39 }
40}
41
42#[derive(Debug)]
44struct PendingPacket {
45 data: Bytes,
47 sent_at: Instant,
49 next_retransmit: Instant,
51 rto: Duration,
53 retransmits: u32,
55}
56
57pub struct ReliableTransport {
59 config: ReliableConfig,
61 next_send_id: u32,
63 next_recv_id: u32,
65 pending: BTreeMap<u32, PendingPacket>,
67 pending_acks: VecDeque<u32>,
69 out_of_order: BTreeMap<u32, Bytes>,
71 last_ack_sent: Option<Instant>,
73 srtt: Option<Duration>,
75 rttvar: Duration,
77}
78
79impl ReliableTransport {
80 pub fn new(config: ReliableConfig) -> Self {
82 Self {
83 config,
84 next_send_id: 0,
85 next_recv_id: 0,
86 pending: BTreeMap::new(),
87 pending_acks: VecDeque::new(),
88 out_of_order: BTreeMap::new(),
89 last_ack_sent: None,
90 srtt: None,
91 rttvar: Duration::from_millis(500),
92 }
93 }
94
95 pub fn send(&mut self, data: Bytes) -> Result<(u32, Bytes)> {
99 if self.pending.len() >= self.config.window_size as usize {
101 return Err(ProtocolError::InvalidPacket("send window full".into()));
102 }
103
104 let packet_id = self.next_send_id;
105 self.next_send_id = self.next_send_id.wrapping_add(1);
106
107 let now = Instant::now();
108 let rto = self.calculate_rto();
109
110 self.pending.insert(
111 packet_id,
112 PendingPacket {
113 data: data.clone(),
114 sent_at: now,
115 next_retransmit: now + rto,
116 rto,
117 retransmits: 0,
118 },
119 );
120
121 Ok((packet_id, data))
122 }
123
124 const MAX_OUT_OF_ORDER: usize = 100;
126
127 pub fn receive(&mut self, packet_id: u32, data: Bytes) -> Result<Option<Bytes>> {
132 self.pending_acks.push_back(packet_id);
134
135 if packet_id == self.next_recv_id {
136 self.next_recv_id = self.next_recv_id.wrapping_add(1);
138
139 while let Some(_buffered) = self.out_of_order.remove(&self.next_recv_id) {
141 self.next_recv_id = self.next_recv_id.wrapping_add(1);
142 }
144
145 Ok(Some(data))
146 } else if packet_id > self.next_recv_id {
147 if self.out_of_order.len() >= Self::MAX_OUT_OF_ORDER {
150 return Err(ProtocolError::InvalidPacket(
151 "too many out-of-order packets".into(),
152 ));
153 }
154 self.out_of_order.insert(packet_id, data);
155 Ok(None)
156 } else {
157 Ok(None)
159 }
160 }
161
162 pub fn process_acks(&mut self, acks: &[u32]) {
164 let now = Instant::now();
165
166 for &ack_id in acks {
167 if let Some(pending) = self.pending.remove(&ack_id) {
168 if pending.retransmits == 0 {
170 let rtt = now.duration_since(pending.sent_at);
171 self.update_rtt(rtt);
172 }
173 }
174 }
175 }
176
177 pub fn get_acks(&mut self) -> Vec<u32> {
179 self.pending_acks.drain(..).collect()
180 }
181
182 pub fn should_send_ack(&self) -> bool {
184 if self.pending_acks.is_empty() {
185 return false;
186 }
187
188 match self.last_ack_sent {
189 Some(last) => last.elapsed() >= self.config.ack_delay,
190 None => true,
191 }
192 }
193
194 pub fn ack_sent(&mut self) {
196 self.last_ack_sent = Some(Instant::now());
197 }
198
199 pub fn get_retransmits(&mut self) -> Vec<(u32, Bytes)> {
201 let now = Instant::now();
202 let mut retransmits = Vec::new();
203
204 for (id, pending) in self.pending.iter_mut() {
205 if now >= pending.next_retransmit {
206 if pending.retransmits >= self.config.max_retransmits {
207 continue;
209 }
210
211 retransmits.push((*id, pending.data.clone()));
212
213 pending.retransmits += 1;
215 pending.rto = Duration::from_secs_f64(
216 (pending.rto.as_secs_f64() * self.config.rto_backoff)
217 .min(self.config.max_rto.as_secs_f64()),
218 );
219 pending.next_retransmit = now + pending.rto;
220 }
221 }
222
223 retransmits
224 }
225
226 pub fn has_pending(&self) -> bool {
228 !self.pending.is_empty()
229 }
230
231 pub fn next_timeout(&self) -> Option<Duration> {
233 self.pending
234 .values()
235 .map(|p| p.next_retransmit)
236 .min()
237 .map(|t| t.saturating_duration_since(Instant::now()))
238 }
239
240 fn calculate_rto(&self) -> Duration {
241 match self.srtt {
242 Some(srtt) => {
243 let rto = srtt + self.rttvar * 4;
245 rto.max(self.config.initial_rto)
246 .min(self.config.max_rto)
247 }
248 None => self.config.initial_rto,
249 }
250 }
251
252 fn update_rtt(&mut self, rtt: Duration) {
253 match self.srtt {
254 Some(srtt) => {
255 let diff = rtt.abs_diff(srtt);
259 self.rttvar = Duration::from_secs_f64(
260 0.75 * self.rttvar.as_secs_f64() + 0.25 * diff.as_secs_f64(),
261 );
262 self.srtt = Some(Duration::from_secs_f64(
263 0.875 * srtt.as_secs_f64() + 0.125 * rtt.as_secs_f64(),
264 ));
265 }
266 None => {
267 self.srtt = Some(rtt);
269 self.rttvar = rtt / 2;
270 }
271 }
272 }
273}
274
/// Accumulates an incoming byte stream and splits it into complete TLS
/// records (5-byte header carrying a big-endian body length).
pub struct TlsRecordReassembler {
    /// Bytes received but not yet returned as complete records.
    buffer: Vec<u8>,
    /// Maximum number of bytes `buffer` is allowed to hold.
    max_size: usize,
}
282
283impl TlsRecordReassembler {
284 pub fn new(max_size: usize) -> Self {
286 Self {
287 buffer: Vec::new(),
288 max_size,
289 }
290 }
291
292 pub fn add(&mut self, data: &[u8]) -> Result<()> {
294 if self.buffer.len() + data.len() > self.max_size {
295 return Err(ProtocolError::InvalidPacket("TLS record too large".into()));
296 }
297 self.buffer.extend_from_slice(data);
298 Ok(())
299 }
300
301 pub fn extract_records(&mut self) -> Vec<Bytes> {
303 let mut records = Vec::new();
304
305 while self.buffer.len() >= 5 {
306 let length = u16::from_be_bytes([self.buffer[3], self.buffer[4]]) as usize;
308
309 if self.buffer.len() < 5 + length {
310 break; }
312
313 let record = self.buffer.drain(..5 + length).collect::<Vec<_>>();
314 records.push(Bytes::from(record));
315 }
316
317 records
318 }
319
320 pub fn len(&self) -> usize {
322 self.buffer.len()
323 }
324
325 pub fn is_empty(&self) -> bool {
327 self.buffer.is_empty()
328 }
329
330 pub fn clear(&mut self) {
332 self.buffer.clear();
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reliable_basic() {
        let mut transport = ReliableTransport::new(ReliableConfig::default());

        let (id, _) = transport.send(Bytes::from_static(b"hello")).unwrap();
        assert_eq!(id, 0);
        assert!(transport.has_pending());

        transport.process_acks(&[0]);
        assert!(!transport.has_pending());
    }

    #[test]
    fn test_send_window_full() {
        let config = ReliableConfig {
            window_size: 2,
            ..ReliableConfig::default()
        };
        let mut transport = ReliableTransport::new(config);

        transport.send(Bytes::from_static(b"a")).unwrap();
        transport.send(Bytes::from_static(b"b")).unwrap();
        assert!(transport.send(Bytes::from_static(b"c")).is_err());

        // Acknowledging one packet frees a slot in the window.
        transport.process_acks(&[0]);
        assert!(transport.send(Bytes::from_static(b"c")).is_ok());
    }

    #[test]
    fn test_reliable_receive() {
        let mut transport = ReliableTransport::new(ReliableConfig::default());

        // `receive` returns `Result<Option<Bytes>>`: unwrap the Result first.
        let data = transport.receive(0, Bytes::from_static(b"first")).unwrap();
        assert_eq!(data.as_deref(), Some(&b"first"[..]));

        // Packet 2 arrives before 1: buffered, nothing deliverable yet.
        let data = transport.receive(2, Bytes::from_static(b"third")).unwrap();
        assert!(data.is_none());

        // Filling the gap delivers packet 1's payload.
        let data = transport.receive(1, Bytes::from_static(b"second")).unwrap();
        assert!(data.is_some_and(|d| d.starts_with(b"second")));

        assert_eq!(transport.get_acks(), vec![0, 2, 1]);

        // A duplicate of an already-delivered packet yields nothing.
        let data = transport.receive(0, Bytes::from_static(b"first")).unwrap();
        assert!(data.is_none());
    }

    #[test]
    fn test_tls_reassembler() {
        let mut reassembler = TlsRecordReassembler::new(16384);

        // Header alone (announces a 5-byte body): not yet a complete record.
        reassembler.add(&[0x17, 0x03, 0x03, 0x00, 0x05]).unwrap();
        assert!(reassembler.extract_records().is_empty());

        reassembler.add(&[1, 2, 3, 4, 5]).unwrap();
        let records = reassembler.extract_records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].len(), 10);
    }

    #[test]
    fn test_tls_reassembler_multiple_records() {
        let mut reassembler = TlsRecordReassembler::new(16384);

        // One complete 1-byte record followed by a partial 2-byte record.
        reassembler
            .add(&[0x17, 0x03, 0x03, 0x00, 0x01, 0xAA, 0x17, 0x03, 0x03, 0x00, 0x02, 0xBB])
            .unwrap();
        let records = reassembler.extract_records();
        assert_eq!(records.len(), 1);
        assert_eq!(&records[0][..], &[0x17, 0x03, 0x03, 0x00, 0x01, 0xAA][..]);
        assert_eq!(reassembler.len(), 6);

        reassembler.add(&[0xCC]).unwrap();
        let records = reassembler.extract_records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].len(), 7);
        assert!(reassembler.is_empty());
    }
}