1use std::collections::BTreeSet;
39use std::time::{Duration, Instant};
40use tracing::{debug, warn};
41
/// Retransmission timeout, in milliseconds, used until an RTT sample is available.
const DEFAULT_RTO_MS: u64 = 200;

/// Book-keeping for one packet that has been sent but not yet acknowledged.
#[derive(Debug, Clone)]
pub struct InFlightPacket {
    /// Sequence number the sender assigned to this packet.
    pub sequence: u64,
    /// Time of the most recent (re)transmission of this packet.
    pub sent_at: Instant,
    /// Number of times this packet has been retransmitted after a timeout.
    pub retransmit_count: u32,
}

impl InFlightPacket {
    /// Records `sequence` as transmitted at this instant, with no retransmissions yet.
    fn new(sequence: u64) -> Self {
        let sent_at = Instant::now();
        Self {
            sequence,
            sent_at,
            retransmit_count: 0,
        }
    }

    /// Returns `true` once strictly more than `rto` has elapsed since `sent_at`.
    pub fn is_timed_out(&self, rto: Duration) -> bool {
        let waited = self.sent_at.elapsed();
        rto < waited
    }
}
69
70pub struct FlowController {
75 window_size: usize,
76 in_flight: Vec<InFlightPacket>,
77 acked: BTreeSet<u64>,
78 rto: Duration,
79 srtt: Option<Duration>,
80 total_sent: u64,
81 total_acked: u64,
82 total_lost: u64,
83}
84
85impl FlowController {
86 pub fn new(window_size: usize) -> Self {
88 debug!(window_size, "FlowController created");
89 FlowController {
90 window_size,
91 in_flight: Vec::new(),
92 acked: BTreeSet::new(),
93 rto: Duration::from_millis(DEFAULT_RTO_MS),
94 srtt: None,
95 total_sent: 0,
96 total_acked: 0,
97 total_lost: 0,
98 }
99 }
100
101 pub fn with_rto(window_size: usize, rto_ms: u64) -> Self {
103 let mut fc = Self::new(window_size);
104 fc.rto = Duration::from_millis(rto_ms);
105 fc
106 }
107
108 pub fn can_send(&self) -> bool {
112 self.in_flight.len() < self.window_size
113 }
114
115 pub fn available_slots(&self) -> usize {
117 self.window_size.saturating_sub(self.in_flight.len())
118 }
119
120 pub fn window_size(&self) -> usize {
122 self.window_size
123 }
124
125 pub fn set_window_size(&mut self, size: usize) {
127 debug!(old = self.window_size, new = size, "Window size updated");
128 self.window_size = size;
129 }
130
131 pub fn in_flight_count(&self) -> usize {
133 self.in_flight.len()
134 }
135
136 pub fn oldest_unacked_sequence(&self) -> Option<u64> {
139 self.in_flight.first().map(|p| p.sequence)
140 }
141
142 pub fn on_send(&mut self, sequence: u64) -> bool {
148 if !self.can_send() {
149 warn!(sequence, "on_send() called but window is full");
150 return false;
151 }
152 self.in_flight.push(InFlightPacket::new(sequence));
153 self.total_sent += 1;
154 debug!(
155 sequence,
156 in_flight = self.in_flight.len(),
157 window = self.window_size,
158 "Packet sent"
159 );
160 true
161 }
162
163 pub fn on_ack(&mut self, sequence: u64) -> bool {
167 if let Some(pos) = self.in_flight.iter().position(|p| p.sequence == sequence) {
168 let packet = self.in_flight.remove(pos);
169 let rtt = packet.sent_at.elapsed();
170
171 self.srtt = Some(match self.srtt {
172 None => rtt,
173 Some(srtt) => {
174 let srtt_ns = srtt.as_nanos() as u64;
175 let rtt_ns = rtt.as_nanos() as u64;
176 Duration::from_nanos(srtt_ns / 8 * 7 + rtt_ns / 8)
177 }
178 });
179
180 if let Some(srtt) = self.srtt {
181 self.rto = (srtt * 2).max(Duration::from_millis(50));
182 }
183
184 self.acked.insert(sequence);
185 self.total_acked += 1;
186 debug!(sequence, rtt_ms = rtt.as_millis(), in_flight = self.in_flight.len(), "Packet acked");
187 true
188 } else {
189 warn!(sequence, "on_ack() for unknown or duplicate sequence");
190 false
191 }
192 }
193
194 pub fn timed_out_packets(&mut self) -> Vec<u64> {
198 let rto = self.rto;
199 let mut timed_out = Vec::new();
200 for packet in self.in_flight.iter_mut() {
201 if packet.is_timed_out(rto) {
202 warn!(
203 sequence = packet.sequence,
204 retransmit_count = packet.retransmit_count,
205 rto_ms = rto.as_millis(),
206 "Packet timed out"
207 );
208 timed_out.push(packet.sequence);
209 packet.retransmit_count += 1;
210 packet.sent_at = Instant::now();
211 self.total_lost += 1;
212 }
213 }
214 timed_out
215 }
216
217 pub fn srtt(&self) -> Option<Duration> {
221 self.srtt
222 }
223
224 pub fn rto(&self) -> Duration {
226 self.rto
227 }
228
229 pub fn total_sent(&self) -> u64 {
231 self.total_sent
232 }
233
234 pub fn total_acked(&self) -> u64 {
236 self.total_acked
237 }
238
239 pub fn total_lost(&self) -> u64 {
241 self.total_lost
242 }
243
244 pub fn loss_rate(&self) -> f64 {
246 if self.total_sent == 0 { return 0.0; }
247 self.total_lost as f64 / self.total_sent as f64
248 }
249
250 pub fn is_acked(&self, sequence: u64) -> bool {
252 self.acked.contains(&sequence)
253 }
254
255 pub fn reset(&mut self) {
257 debug!("FlowController reset");
258 self.in_flight.clear();
259 self.acked.clear();
260 self.srtt = None;
261 self.total_sent = 0;
262 self.total_acked = 0;
263 self.total_lost = 0;
264 }
265}
266
267impl Default for FlowController {
268 fn default() -> Self {
269 Self::new(64)
270 }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a controller with a `window`-packet window and sends each of `seqs`.
    fn sent(window: usize, seqs: &[u64]) -> FlowController {
        let mut ctrl = FlowController::new(window);
        for &seq in seqs {
            ctrl.on_send(seq);
        }
        ctrl
    }

    #[test]
    fn test_new() {
        let ctrl = FlowController::new(4);
        assert!(ctrl.can_send());
        assert_eq!(ctrl.window_size(), 4);
        assert_eq!(ctrl.available_slots(), 4);
        assert_eq!(ctrl.in_flight_count(), 0);
    }

    #[test]
    fn test_window_full() {
        let mut ctrl = FlowController::new(3);
        for seq in 0..3 {
            assert!(ctrl.on_send(seq));
        }
        assert!(!ctrl.can_send());
        assert_eq!(ctrl.available_slots(), 0);
        assert_eq!(ctrl.in_flight_count(), 3);
    }

    #[test]
    fn test_ack_opens_window() {
        let mut ctrl = sent(2, &[0, 1]);
        assert!(!ctrl.can_send());
        ctrl.on_ack(0);
        assert!(ctrl.can_send());
        assert_eq!(ctrl.available_slots(), 1);
    }

    #[test]
    fn test_ack_unknown_sequence() {
        let mut ctrl = sent(4, &[0]);
        assert!(!ctrl.on_ack(99));
        assert_eq!(ctrl.in_flight_count(), 1);
    }

    #[test]
    fn test_is_acked() {
        let mut ctrl = sent(4, &[0]);
        assert!(!ctrl.is_acked(0));
        ctrl.on_ack(0);
        assert!(ctrl.is_acked(0));
    }

    #[test]
    fn test_stats() {
        let mut ctrl = sent(10, &[0, 1, 2]);
        for seq in 0..2 {
            ctrl.on_ack(seq);
        }
        assert_eq!(ctrl.total_sent(), 3);
        assert_eq!(ctrl.total_acked(), 2);
        assert_eq!(ctrl.in_flight_count(), 1);
    }

    #[test]
    fn test_loss_rate_zero() {
        assert_eq!(FlowController::new(4).loss_rate(), 0.0);
    }

    #[test]
    fn test_set_window_size() {
        let mut ctrl = FlowController::new(4);
        ctrl.set_window_size(8);
        assert_eq!(ctrl.window_size(), 8);
        assert_eq!(ctrl.available_slots(), 8);
    }

    #[test]
    fn test_reset() {
        let mut ctrl = sent(4, &[0, 1]);
        ctrl.on_ack(0);
        ctrl.reset();
        assert_eq!(ctrl.in_flight_count(), 0);
        assert_eq!(ctrl.total_sent(), 0);
        assert_eq!(ctrl.total_acked(), 0);
        assert!(ctrl.srtt().is_none());
    }

    #[test]
    fn test_timed_out_packets() {
        let mut ctrl = FlowController::with_rto(4, 1);
        ctrl.on_send(0);
        ctrl.on_send(1);
        std::thread::sleep(Duration::from_millis(5));
        let mut expired = ctrl.timed_out_packets();
        expired.sort_unstable();
        assert_eq!(expired, vec![0, 1]);
        assert_eq!(ctrl.total_lost(), 2);
    }

    #[test]
    fn test_srtt_updated_on_ack() {
        let mut ctrl = sent(4, &[0]);
        assert!(ctrl.srtt().is_none());
        ctrl.on_ack(0);
        assert!(ctrl.srtt().is_some());
    }

    #[test]
    fn test_default() {
        assert_eq!(FlowController::default().window_size(), 64);
    }

    #[test]
    fn test_on_send_full_window_returns_false() {
        let mut ctrl = FlowController::new(1);
        assert!(ctrl.on_send(0));
        assert!(!ctrl.on_send(1));
    }

    #[test]
    fn test_multiple_acks() {
        let mut ctrl = FlowController::new(10);
        for seq in 0..10 {
            ctrl.on_send(seq);
        }
        for seq in 0..10 {
            assert!(ctrl.on_ack(seq));
        }
        assert_eq!(ctrl.total_acked(), 10);
        assert_eq!(ctrl.in_flight_count(), 0);
        assert_eq!(ctrl.available_slots(), 10);
    }

    #[test]
    fn test_oldest_unacked_sequence() {
        let mut ctrl = FlowController::new(4);
        assert!(ctrl.oldest_unacked_sequence().is_none());
        ctrl.on_send(5);
        ctrl.on_send(6);
        assert_eq!(ctrl.oldest_unacked_sequence(), Some(5));
        ctrl.on_ack(5);
        assert_eq!(ctrl.oldest_unacked_sequence(), Some(6));
    }
}