1use std::collections::{BTreeMap, VecDeque};
6use std::time::{Duration, Instant};
7
8use bytes::Bytes;
9
10use crate::{ProtocolError, Result};
11
/// Tuning parameters for [`ReliableTransport`].
#[derive(Debug, Clone)]
pub struct ReliableConfig {
    /// Retransmission timeout used before any RTT sample exists; also the
    /// lower bound applied to timeouts computed from measured RTTs.
    pub initial_rto: Duration,
    /// Upper bound on any retransmission timeout, including after backoff.
    pub max_rto: Duration,
    /// Factor a packet's timeout is multiplied by after each retransmission.
    pub rto_backoff: f64,
    /// Number of retransmissions attempted per packet; once reached, the
    /// packet is no longer resent.
    pub max_retransmits: u32,
    /// Maximum number of unacknowledged packets; `send` fails once this many
    /// are outstanding.
    pub window_size: u32,
    /// Minimum interval between acknowledgements, as checked by
    /// `should_send_ack`.
    pub ack_delay: Duration,
}
28
29impl Default for ReliableConfig {
30 fn default() -> Self {
31 Self {
32 initial_rto: Duration::from_secs(2),
33 max_rto: Duration::from_secs(60),
34 rto_backoff: 2.0,
35 max_retransmits: 10,
36 window_size: 8,
37 ack_delay: Duration::from_millis(100),
38 }
39 }
40}
41
/// Bookkeeping for a sent packet that has not been acknowledged yet.
#[derive(Debug)]
struct PendingPacket {
    /// Payload, kept so the packet can be retransmitted.
    data: Bytes,
    /// When the packet was first sent; used as the RTT sample start if it is
    /// acknowledged without ever having been retransmitted.
    sent_at: Instant,
    /// Deadline after which `get_retransmits` resends the packet.
    next_retransmit: Instant,
    /// Current retransmission timeout; multiplied by `rto_backoff` (capped at
    /// `max_rto`) after each retransmission.
    rto: Duration,
    /// How many times this packet has been retransmitted.
    retransmits: u32,
}
56
/// Reliability layer over an unreliable packet channel.
///
/// It tracks sent packets until they are acknowledged and retransmits them
/// on timeout with multiplicative backoff. Received packets are reordered
/// by id before delivery.
pub struct ReliableTransport {
    /// Timeout, window and acknowledgement tuning.
    config: ReliableConfig,
    /// Id assigned to the next outgoing packet (wraps on overflow).
    next_send_id: u32,
    /// Next id the receiver must see before it can deliver more data in order.
    next_recv_id: u32,
    /// Sent but not yet acknowledged packets, keyed by id.
    pending: BTreeMap<u32, PendingPacket>,
    /// Ids of received packets whose acknowledgement has not yet been
    /// collected by `get_acks`.
    pending_acks: VecDeque<u32>,
    /// Packets received ahead of `next_recv_id`, held until the gap is filled.
    out_of_order: BTreeMap<u32, Bytes>,
    /// When `ack_sent` was last called; `None` if it never was.
    last_ack_sent: Option<Instant>,
    /// Smoothed round-trip time; `None` until the first RTT sample.
    srtt: Option<Duration>,
    /// Round-trip time variation estimate used in the timeout calculation.
    rttvar: Duration,
}
78
79impl ReliableTransport {
80 pub fn new(config: ReliableConfig) -> Self {
82 Self {
83 config,
84 next_send_id: 0,
85 next_recv_id: 0,
86 pending: BTreeMap::new(),
87 pending_acks: VecDeque::new(),
88 out_of_order: BTreeMap::new(),
89 last_ack_sent: None,
90 srtt: None,
91 rttvar: Duration::from_millis(500),
92 }
93 }
94
95 pub fn send(&mut self, data: Bytes) -> Result<(u32, Bytes)> {
99 if self.pending.len() >= self.config.window_size as usize {
101 return Err(ProtocolError::InvalidPacket("send window full".into()));
102 }
103
104 let packet_id = self.next_send_id;
105 self.next_send_id = self.next_send_id.wrapping_add(1);
106
107 let now = Instant::now();
108 let rto = self.calculate_rto();
109
110 self.pending.insert(
111 packet_id,
112 PendingPacket {
113 data: data.clone(),
114 sent_at: now,
115 next_retransmit: now + rto,
116 rto,
117 retransmits: 0,
118 },
119 );
120
121 Ok((packet_id, data))
122 }
123
124 pub fn receive(&mut self, packet_id: u32, data: Bytes) -> Option<Bytes> {
129 self.pending_acks.push_back(packet_id);
131
132 if packet_id == self.next_recv_id {
133 self.next_recv_id = self.next_recv_id.wrapping_add(1);
135
136 while let Some(_buffered) = self.out_of_order.remove(&self.next_recv_id) {
138 self.next_recv_id = self.next_recv_id.wrapping_add(1);
139 }
141
142 Some(data)
143 } else if packet_id > self.next_recv_id {
144 self.out_of_order.insert(packet_id, data);
146 None
147 } else {
148 None
150 }
151 }
152
153 pub fn process_acks(&mut self, acks: &[u32]) {
155 let now = Instant::now();
156
157 for &ack_id in acks {
158 if let Some(pending) = self.pending.remove(&ack_id) {
159 if pending.retransmits == 0 {
161 let rtt = now.duration_since(pending.sent_at);
162 self.update_rtt(rtt);
163 }
164 }
165 }
166 }
167
168 pub fn get_acks(&mut self) -> Vec<u32> {
170 self.pending_acks.drain(..).collect()
171 }
172
173 pub fn should_send_ack(&self) -> bool {
175 if self.pending_acks.is_empty() {
176 return false;
177 }
178
179 match self.last_ack_sent {
180 Some(last) => last.elapsed() >= self.config.ack_delay,
181 None => true,
182 }
183 }
184
185 pub fn ack_sent(&mut self) {
187 self.last_ack_sent = Some(Instant::now());
188 }
189
190 pub fn get_retransmits(&mut self) -> Vec<(u32, Bytes)> {
192 let now = Instant::now();
193 let mut retransmits = Vec::new();
194
195 for (id, pending) in self.pending.iter_mut() {
196 if now >= pending.next_retransmit {
197 if pending.retransmits >= self.config.max_retransmits {
198 continue;
200 }
201
202 retransmits.push((*id, pending.data.clone()));
203
204 pending.retransmits += 1;
206 pending.rto = Duration::from_secs_f64(
207 (pending.rto.as_secs_f64() * self.config.rto_backoff)
208 .min(self.config.max_rto.as_secs_f64()),
209 );
210 pending.next_retransmit = now + pending.rto;
211 }
212 }
213
214 retransmits
215 }
216
217 pub fn has_pending(&self) -> bool {
219 !self.pending.is_empty()
220 }
221
222 pub fn next_timeout(&self) -> Option<Duration> {
224 self.pending
225 .values()
226 .map(|p| p.next_retransmit)
227 .min()
228 .map(|t| t.saturating_duration_since(Instant::now()))
229 }
230
231 fn calculate_rto(&self) -> Duration {
232 match self.srtt {
233 Some(srtt) => {
234 let rto = srtt + self.rttvar * 4;
236 rto.max(self.config.initial_rto)
237 .min(self.config.max_rto)
238 }
239 None => self.config.initial_rto,
240 }
241 }
242
243 fn update_rtt(&mut self, rtt: Duration) {
244 match self.srtt {
245 Some(srtt) => {
246 let diff = rtt.abs_diff(srtt);
250 self.rttvar = Duration::from_secs_f64(
251 0.75 * self.rttvar.as_secs_f64() + 0.25 * diff.as_secs_f64(),
252 );
253 self.srtt = Some(Duration::from_secs_f64(
254 0.875 * srtt.as_secs_f64() + 0.125 * rtt.as_secs_f64(),
255 ));
256 }
257 None => {
258 self.srtt = Some(rtt);
260 self.rttvar = rtt / 2;
261 }
262 }
263 }
264}
265
/// Accumulates a byte stream and splits it into complete TLS records.
///
/// Each record has a 5-byte header whose bytes 3..5 hold the big-endian
/// payload length.
pub struct TlsRecordReassembler {
    /// Bytes received but not yet returned as part of a complete record.
    buffer: Vec<u8>,
    /// Upper bound on the number of buffered bytes, enforced by `add`.
    max_size: usize,
}
273
274impl TlsRecordReassembler {
275 pub fn new(max_size: usize) -> Self {
277 Self {
278 buffer: Vec::new(),
279 max_size,
280 }
281 }
282
283 pub fn add(&mut self, data: &[u8]) -> Result<()> {
285 if self.buffer.len() + data.len() > self.max_size {
286 return Err(ProtocolError::InvalidPacket("TLS record too large".into()));
287 }
288 self.buffer.extend_from_slice(data);
289 Ok(())
290 }
291
292 pub fn extract_records(&mut self) -> Vec<Bytes> {
294 let mut records = Vec::new();
295
296 while self.buffer.len() >= 5 {
297 let length = u16::from_be_bytes([self.buffer[3], self.buffer[4]]) as usize;
299
300 if self.buffer.len() < 5 + length {
301 break; }
303
304 let record = self.buffer.drain(..5 + length).collect::<Vec<_>>();
305 records.push(Bytes::from(record));
306 }
307
308 records
309 }
310
311 pub fn len(&self) -> usize {
313 self.buffer.len()
314 }
315
316 pub fn is_empty(&self) -> bool {
318 self.buffer.is_empty()
319 }
320
321 pub fn clear(&mut self) {
323 self.buffer.clear();
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reliable_basic() {
        let mut transport = ReliableTransport::new(ReliableConfig::default());

        let (id, data) = transport.send(Bytes::from_static(b"hello")).unwrap();
        assert_eq!(id, 0);
        assert_eq!(&data[..], &b"hello"[..]);
        assert!(transport.has_pending());

        transport.process_acks(&[0]);
        assert!(!transport.has_pending());
    }

    #[test]
    fn test_reliable_send_window_full() {
        let config = ReliableConfig {
            window_size: 2,
            ..ReliableConfig::default()
        };
        let mut transport = ReliableTransport::new(config);

        assert!(transport.send(Bytes::from_static(b"a")).is_ok());
        assert!(transport.send(Bytes::from_static(b"b")).is_ok());
        // A third packet exceeds the window and must not consume an id.
        assert!(transport.send(Bytes::from_static(b"c")).is_err());

        transport.process_acks(&[0]);
        let (id, _) = transport.send(Bytes::from_static(b"c")).unwrap();
        assert_eq!(id, 2);
    }

    #[test]
    fn test_reliable_receive() {
        let mut transport = ReliableTransport::new(ReliableConfig::default());

        let data = transport.receive(0, Bytes::from_static(b"first"));
        assert_eq!(data.as_deref(), Some(&b"first"[..]));

        // Packet 2 arrives before packet 1 and is held back.
        let data = transport.receive(2, Bytes::from_static(b"third"));
        assert!(data.is_none());

        // Packet 1 fills the gap and is delivered.
        let data = transport.receive(1, Bytes::from_static(b"second"));
        assert!(data.unwrap().starts_with(b"second"));

        // Re-sent packets that were already delivered yield nothing new.
        assert!(transport.receive(0, Bytes::from_static(b"first")).is_none());
        assert!(transport.receive(2, Bytes::from_static(b"third")).is_none());

        // Every arrival, including duplicates, is queued for acknowledgement.
        assert_eq!(transport.get_acks(), vec![0, 2, 1, 0, 2]);
        assert!(transport.get_acks().is_empty());
    }

    #[test]
    fn test_should_send_ack() {
        let config = ReliableConfig {
            ack_delay: Duration::from_secs(3600),
            ..ReliableConfig::default()
        };
        let mut transport = ReliableTransport::new(config);
        assert!(!transport.should_send_ack());

        transport.receive(0, Bytes::from_static(b"x"));
        assert!(transport.should_send_ack());

        // Right after an ack, further acks wait for `ack_delay`.
        transport.ack_sent();
        assert!(!transport.should_send_ack());
    }

    #[test]
    fn test_tls_reassembler() {
        let mut reassembler = TlsRecordReassembler::new(16384);

        // A header alone is not a complete record.
        reassembler.add(&[0x17, 0x03, 0x03, 0x00, 0x05]).unwrap();
        assert!(reassembler.extract_records().is_empty());
        assert_eq!(reassembler.len(), 5);

        reassembler.add(&[1, 2, 3, 4, 5]).unwrap();
        let records = reassembler.extract_records();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0].len(), 10);
        assert!(reassembler.is_empty());
    }

    #[test]
    fn test_tls_reassembler_multiple_records() {
        let mut reassembler = TlsRecordReassembler::new(16384);

        // Two complete records (payload lengths 1 and 0), then a partial header.
        reassembler
            .add(&[
                0x17, 0x03, 0x03, 0x00, 0x01, 0xAA, 0x15, 0x03, 0x03, 0x00, 0x00, 0x16, 0x03,
            ])
            .unwrap();
        let records = reassembler.extract_records();
        assert_eq!(records.len(), 2);
        assert_eq!(&records[0][..], &[0x17u8, 0x03, 0x03, 0x00, 0x01, 0xAA][..]);
        assert_eq!(&records[1][..], &[0x15u8, 0x03, 0x03, 0x00, 0x00][..]);
        assert_eq!(reassembler.len(), 2);
    }

    #[test]
    fn test_tls_reassembler_rejects_oversized() {
        let mut reassembler = TlsRecordReassembler::new(4);

        assert!(reassembler.add(&[0; 5]).is_err());
        assert!(reassembler.is_empty());

        reassembler.add(&[0; 4]).unwrap();
        assert_eq!(reassembler.len(), 4);
        reassembler.clear();
        assert!(reassembler.is_empty());
    }
}