1use std::time::{Duration, Instant};
6
/// Bookkeeping for a single sent version that has not been acknowledged yet.
#[derive(Debug, Clone)]
pub struct PendingAck {
    /// Version number that was sent.
    pub version: u64,
    /// When this version was most recently (re)transmitted.
    pub sent_at: Instant,
    /// Number of retransmissions so far (0 means only the original send happened).
    pub retransmit_count: u32,
    /// Retransmission timeout, measured from `sent_at`.
    pub rto: Duration,
}

impl PendingAck {
    /// Creates a record for `version`, sent now, with retransmission timeout `rto`.
    pub fn new(version: u64, rto: Duration) -> Self {
        Self {
            version,
            sent_at: Instant::now(),
            retransmit_count: 0,
            rto,
        }
    }

    /// Returns `true` once at least `rto` has elapsed since the last transmission.
    pub fn needs_retransmit(&self) -> bool {
        self.sent_at.elapsed() >= self.rto
    }

    /// Records a retransmission: restarts the timer, bumps the counter and backs off the RTO.
    ///
    /// The new RTO is `rto * backoff_multiplier`, capped at `max_rto`. If the
    /// multiplication would overflow `Duration`, the cap is used instead of panicking.
    pub fn retransmit(&mut self, backoff_multiplier: u32, max_rto: Duration) {
        self.sent_at = Instant::now();
        self.retransmit_count = self.retransmit_count.saturating_add(1);
        // `Duration * u32` panics on overflow; an overflowing product is "above the cap" anyway.
        self.rto = self
            .rto
            .checked_mul(backoff_multiplier)
            .map_or(max_rto, |backed_off| backed_off.min(max_rto));
    }

    /// Time remaining until this entry is due for retransmission; zero if it is already due.
    pub fn time_until_retransmit(&self) -> Duration {
        self.rto.saturating_sub(self.sent_at.elapsed())
    }
}
54
/// RTO used before any round-trip-time sample has been taken (1 s).
pub const DEFAULT_INITIAL_RTO: Duration = Duration::from_secs(1);

/// Lower bound applied to the RTO computed from RTT samples (100 ms).
pub const DEFAULT_MIN_RTO: Duration = Duration::from_millis(100);

/// Upper bound for both the computed RTO and the backed-off RTO (60 s).
pub const DEFAULT_MAX_RTO: Duration = Duration::from_millis(60_000);

/// Factor by which a pending entry's RTO grows on every retransmission.
pub const DEFAULT_BACKOFF_MULTIPLIER: u32 = 2;

/// Number of retransmissions after which a version is reported as failed.
pub const DEFAULT_MAX_RETRANSMITS: u32 = 10;
73
74#[derive(Debug)]
78pub struct AckTracker {
79 pending: Vec<PendingAck>,
81
82 highest_acked: u64,
84
85 initial_rto: Duration,
87 min_rto: Duration,
88 max_rto: Duration,
89 backoff_multiplier: u32,
90 max_retransmits: u32,
91
92 srtt: Option<Duration>,
94 rttvar: Option<Duration>,
95}
96
97impl AckTracker {
98 pub fn new() -> Self {
100 Self {
101 pending: Vec::new(),
102 highest_acked: 0,
103 initial_rto: DEFAULT_INITIAL_RTO,
104 min_rto: DEFAULT_MIN_RTO,
105 max_rto: DEFAULT_MAX_RTO,
106 backoff_multiplier: DEFAULT_BACKOFF_MULTIPLIER,
107 max_retransmits: DEFAULT_MAX_RETRANSMITS,
108 srtt: None,
109 rttvar: None,
110 }
111 }
112
113 pub fn with_rto(
115 initial_rto: Duration,
116 min_rto: Duration,
117 max_rto: Duration,
118 backoff_multiplier: u32,
119 max_retransmits: u32,
120 ) -> Self {
121 Self {
122 pending: Vec::new(),
123 highest_acked: 0,
124 initial_rto,
125 min_rto,
126 max_rto,
127 backoff_multiplier,
128 max_retransmits,
129 srtt: None,
130 rttvar: None,
131 }
132 }
133
134 pub fn register_sent(&mut self, version: u64) {
136 if self.pending.iter().any(|p| p.version == version) {
138 return;
139 }
140
141 let rto = self.current_rto();
142 self.pending.push(PendingAck::new(version, rto));
143 }
144
145 pub fn process_ack(&mut self, acked_version: u64) -> Option<Duration> {
149 if acked_version <= self.highest_acked {
150 return None;
151 }
152
153 self.highest_acked = acked_version;
154
155 let mut rtt_sample = None;
157
158 self.pending.retain(|pending| {
159 if pending.version <= acked_version {
160 if pending.retransmit_count == 0 && rtt_sample.is_none() {
162 rtt_sample = Some(pending.sent_at.elapsed());
163 }
164 false } else {
166 true }
168 });
169
170 if let Some(rtt) = rtt_sample {
172 self.update_rtt(rtt);
173 }
174
175 rtt_sample
176 }
177
178 fn update_rtt(&mut self, rtt: Duration) {
180 let rtt_secs = rtt.as_secs_f64();
181
182 match (self.srtt, self.rttvar) {
183 (None, None) => {
184 self.srtt = Some(rtt);
186 self.rttvar = Some(rtt / 2);
187 }
188 (Some(srtt), Some(rttvar)) => {
189 let srtt_secs = srtt.as_secs_f64();
191 let rttvar_secs = rttvar.as_secs_f64();
192
193 let new_rttvar =
196 0.75 * rttvar_secs + 0.25 * (srtt_secs - rtt_secs).abs();
197
198 let new_srtt = 0.875 * srtt_secs + 0.125 * rtt_secs;
201
202 self.srtt = Some(Duration::from_secs_f64(new_srtt));
203 self.rttvar = Some(Duration::from_secs_f64(new_rttvar));
204 }
205 _ => {}
206 }
207 }
208
209 pub fn current_rto(&self) -> Duration {
211 match (self.srtt, self.rttvar) {
212 (Some(srtt), Some(rttvar)) => {
213 let k = 4;
216 let g = Duration::from_millis(1);
217 let rto = srtt + (g.max(rttvar * k));
218 rto.clamp(self.min_rto, self.max_rto)
219 }
220 _ => self.initial_rto,
221 }
222 }
223
224 pub fn srtt(&self) -> Option<Duration> {
226 self.srtt
227 }
228
229 pub fn rttvar(&self) -> Option<Duration> {
231 self.rttvar
232 }
233
234 pub fn needs_retransmit(&self) -> impl Iterator<Item = u64> + '_ {
236 self.pending
237 .iter()
238 .filter(|p| p.needs_retransmit() && p.retransmit_count < self.max_retransmits)
239 .map(|p| p.version)
240 }
241
242 pub fn failed_versions(&self) -> impl Iterator<Item = u64> + '_ {
244 self.pending
245 .iter()
246 .filter(|p| p.retransmit_count >= self.max_retransmits)
247 .map(|p| p.version)
248 }
249
250 pub fn mark_retransmitted(&mut self, version: u64) {
252 if let Some(pending) = self.pending.iter_mut().find(|p| p.version == version) {
253 pending.retransmit(self.backoff_multiplier, self.max_rto);
254 }
255 }
256
257 pub fn has_pending(&self) -> bool {
259 !self.pending.is_empty()
260 }
261
262 pub fn pending_count(&self) -> usize {
264 self.pending.len()
265 }
266
267 pub fn highest_acked(&self) -> u64 {
269 self.highest_acked
270 }
271
272 pub fn time_until_retransmit(&self) -> Option<Duration> {
274 self.pending
275 .iter()
276 .filter(|p| p.retransmit_count < self.max_retransmits)
277 .map(|p| p.time_until_retransmit())
278 .min()
279 }
280
281 pub fn cancel(&mut self, version: u64) {
283 self.pending.retain(|p| p.version != version);
284 }
285
286 pub fn cancel_all(&mut self) {
288 self.pending.clear();
289 }
290
291 pub fn reset(&mut self) {
293 self.pending.clear();
294 self.highest_acked = 0;
295 self.srtt = None;
296 self.rttvar = None;
297 }
298}
299
300impl Default for AckTracker {
301 fn default() -> Self {
302 Self::new()
303 }
304}
305
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_tracker() {
        let tracker = AckTracker::new();
        assert!(!tracker.has_pending());
        assert_eq!(tracker.pending_count(), 0);
        assert_eq!(tracker.highest_acked(), 0);
        assert_eq!(tracker.current_rto(), DEFAULT_INITIAL_RTO);
        assert!(tracker.srtt().is_none());
        assert!(tracker.rttvar().is_none());
    }

    #[test]
    fn test_default_matches_new() {
        let tracker = AckTracker::default();
        assert!(!tracker.has_pending());
        assert_eq!(tracker.highest_acked(), 0);
        assert_eq!(tracker.current_rto(), DEFAULT_INITIAL_RTO);
    }

    #[test]
    fn test_register_sent() {
        let mut tracker = AckTracker::new();

        tracker.register_sent(1);
        assert!(tracker.has_pending());
        assert_eq!(tracker.pending_count(), 1);

        // Registering the same version again is a no-op.
        tracker.register_sent(1);
        assert_eq!(tracker.pending_count(), 1);

        tracker.register_sent(2);
        assert_eq!(tracker.pending_count(), 2);
    }

    #[test]
    fn test_process_ack() {
        let mut tracker = AckTracker::new();

        tracker.register_sent(1);
        tracker.register_sent(2);
        tracker.register_sent(3);

        // Acks are cumulative: acking 2 clears both 1 and 2.
        tracker.process_ack(2);
        assert_eq!(tracker.highest_acked(), 2);
        assert_eq!(tracker.pending_count(), 1);

        // A stale ack is ignored and yields no RTT sample.
        assert_eq!(tracker.process_ack(1), None);
        assert_eq!(tracker.highest_acked(), 2);
        assert_eq!(tracker.pending_count(), 1);

        tracker.process_ack(3);
        assert!(!tracker.has_pending());
    }

    #[test]
    fn test_ack_without_pending() {
        let mut tracker = AckTracker::new();
        assert_eq!(tracker.process_ack(5), None);
        assert_eq!(tracker.highest_acked(), 5);
        assert!(tracker.srtt().is_none());
    }

    #[test]
    fn test_rtt_sample() {
        let mut tracker = AckTracker::new();

        tracker.register_sent(1);
        thread::sleep(Duration::from_millis(10));

        let rtt = tracker.process_ack(1);
        assert!(rtt.is_some());
        assert!(rtt.unwrap() >= Duration::from_millis(10));

        assert!(tracker.srtt().is_some());
        assert!(tracker.rttvar().is_some());

        // Once sampled, the RTO is bounded by the configured limits.
        let rto = tracker.current_rto();
        assert!(rto >= DEFAULT_MIN_RTO);
        assert!(rto <= DEFAULT_MAX_RTO);
    }

    #[test]
    fn test_retransmit() {
        let mut tracker = AckTracker::with_rto(
            Duration::from_millis(10),
            Duration::from_millis(10),
            Duration::from_secs(1),
            2,
            3,
        );

        tracker.register_sent(1);

        // Not yet due.
        assert_eq!(tracker.needs_retransmit().count(), 0);

        thread::sleep(Duration::from_millis(15));

        let versions: Vec<_> = tracker.needs_retransmit().collect();
        assert_eq!(versions, vec![1]);

        tracker.mark_retransmitted(1);

        // The timer restarted with a backed-off RTO.
        assert_eq!(tracker.needs_retransmit().count(), 0);
    }

    #[test]
    fn test_max_retransmits() {
        let mut tracker = AckTracker::with_rto(
            Duration::from_millis(1),
            Duration::from_millis(1),
            Duration::from_millis(10),
            1,
            2,
        );

        tracker.register_sent(1);
        thread::sleep(Duration::from_millis(5));

        tracker.mark_retransmitted(1);
        thread::sleep(Duration::from_millis(5));

        tracker.mark_retransmitted(1);
        thread::sleep(Duration::from_millis(5));

        let failed: Vec<_> = tracker.failed_versions().collect();
        assert_eq!(failed, vec![1]);

        // Failed versions are no longer offered for retransmission.
        assert_eq!(tracker.needs_retransmit().count(), 0);
    }

    #[test]
    fn test_cancel() {
        let mut tracker = AckTracker::new();

        tracker.register_sent(1);
        tracker.register_sent(2);
        tracker.register_sent(3);

        tracker.cancel(2);
        assert_eq!(tracker.pending_count(), 2);

        // Cancelling an unknown version is a no-op.
        tracker.cancel(42);
        assert_eq!(tracker.pending_count(), 2);

        tracker.cancel_all();
        assert!(!tracker.has_pending());
    }

    #[test]
    fn test_reset() {
        let mut tracker = AckTracker::new();

        tracker.register_sent(1);
        tracker.process_ack(1);

        tracker.reset();

        assert!(!tracker.has_pending());
        assert_eq!(tracker.highest_acked(), 0);
        assert!(tracker.srtt().is_none());
        assert!(tracker.rttvar().is_none());
        assert_eq!(tracker.current_rto(), DEFAULT_INITIAL_RTO);
    }

    #[test]
    fn test_time_until_retransmit() {
        let mut tracker = AckTracker::with_rto(
            Duration::from_millis(100),
            Duration::from_millis(100),
            Duration::from_secs(1),
            2,
            10,
        );

        assert!(tracker.time_until_retransmit().is_none());

        tracker.register_sent(1);
        let time = tracker.time_until_retransmit();
        assert!(time.is_some());
        assert!(time.unwrap() <= Duration::from_millis(100));

        // Once acknowledged, nothing is left to wait for.
        tracker.process_ack(1);
        assert!(tracker.time_until_retransmit().is_none());
    }
}