// cyfs_bdt/cc/bbr.rs

1use std::{
2    fmt::Debug, 
3    time::Duration
4};
5
6use cyfs_base::*;
7use crate::types::*;
8use super::cc_impl::CcImpl;
9
10
/// One candidate held by [`MinMax`]: a measured value together with the
/// "time" at which it was observed (callers in this file pass a round count).
#[derive(Clone, Copy, Debug, Default)]
struct MinMaxSample {
    time: u64,
    val: u64,
}

/// Windowed maximum filter.
///
/// Keeps three candidates ordered best / second-best / third-best, each
/// younger than the previous, so that when the best one ages out of the
/// `window` the next candidate can take over without rescanning history.
#[derive(Clone, Copy, Debug)]
struct MinMax {
    /// Width of the window, in the same unit as the `time` passed to `update_max`.
    window: u64,
    /// `samples[0]` is the current maximum.
    samples: [MinMaxSample; 3],
}

impl MinMax {
    /// Builds an empty filter (all candidates zeroed) with the given window width.
    fn new(window: u64) -> Self {
        let empty = MinMaxSample::default();
        MinMax {
            window,
            samples: [empty; 3],
        }
    }

    /// Returns the current windowed maximum; 0 while the filter is empty.
    fn get(&self) -> u64 {
        let best = &self.samples[0];
        best.val
    }

    /// Drops every candidate, returning the filter to its empty state.
    fn reset(&mut self) {
        for slot in self.samples.iter_mut() {
            *slot = MinMaxSample::default();
        }
    }

    /// Feeds measurement `meas` taken at `time` into the filter.
    fn update_max(&mut self, time: u64, meas: u64) {
        let candidate = MinMaxSample { time, val: meas };
        let [best, second, third] = self.samples;

        // Empty filter, a new overall maximum, or every candidate has aged out:
        // the new sample replaces all three. The `||` chain keeps the
        // subtraction from being evaluated unless the first two tests fail.
        if best.val == 0
            || candidate.val >= best.val
            || candidate.time - third.time > self.window
        {
            self.samples = [candidate; 3];
            return;
        }

        if candidate.val >= second.val {
            self.samples[1] = candidate;
            self.samples[2] = candidate;
        } else if candidate.val >= third.val {
            self.samples[2] = candidate;
        }

        self.subwin_update(candidate);
    }

    /// Ages the candidates as time advances, promoting younger ones when the
    /// best candidate has left the window and refreshing the runners-up after a
    /// quarter / half window without a new best.
    fn subwin_update(&mut self, sample: MinMaxSample) {
        let window = self.window;
        let elapsed = sample.time - self.samples[0].time;

        if elapsed > window {
            // The best candidate expired; the second may have expired too.
            self.shift_in(sample);
            if sample.time - self.samples[0].time > window {
                self.shift_in(sample);
            }
            return;
        }

        if self.samples[1].time == self.samples[0].time && elapsed > window / 4 {
            self.samples[1] = sample;
            self.samples[2] = sample;
        } else if self.samples[2].time == self.samples[1].time && elapsed > window / 2 {
            self.samples[2] = sample;
        }
    }

    /// Moves every candidate one slot toward the front and puts `sample` last.
    fn shift_in(&mut self, sample: MinMaxSample) {
        self.samples.rotate_left(1);
        self.samples[2] = sample;
    }
}
82
83
84
85#[derive(Clone)]
86struct BandwidthEstimation {
87    total_acked: u64,
88    prev_total_acked: u64,
89    acked_time: Timestamp,
90    prev_acked_time: Timestamp,
91    total_sent: u64,
92    prev_total_sent: u64,
93    sent_time: Timestamp,
94    prev_sent_time: Timestamp,
95    max_filter: MinMax,
96    acked_at_last_window: u64,
97
98    bw_info_show_time: Timestamp,
99}
100
101impl Default for BandwidthEstimation {
102    fn default() -> Self {
103        BandwidthEstimation {
104            total_acked: 0,
105            prev_total_acked: 0,
106            acked_time: 0,
107            prev_acked_time: 0,
108            total_sent: 0,
109            prev_total_sent: 0,
110            sent_time: 0,
111            prev_sent_time: 0,
112            max_filter: MinMax::new(10),
113            acked_at_last_window: 0,
114            bw_info_show_time: bucky_time_now(),
115        }
116    }
117}
118
119impl Debug for BandwidthEstimation {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        write!(
122            f,
123            "{:.2} KB/s",
124            self.get_estimate() as f32 / 1024 as f32
125        )
126    }
127}
128
129impl BandwidthEstimation {
130    // pub fn bw_info(&mut self) -> bool {
131    //     let now = SystemTime::now();
132    //     if system_time_to_bucky_time(&now) > self.bw_info_show_time {
133    //         info!("bbr-BandwidthEstimation-BW: {:?}", self);
134    //         self.bw_info_show_time = system_time_to_bucky_time(&(now + Duration::from_secs(5)));
135    //         true
136    //     } else {
137    //         false
138    //     }
139    // }
140
141    fn on_sent(&mut self, now: Timestamp, bytes: u64) {
142        self.prev_total_sent = self.total_sent;
143        self.total_sent += bytes;
144        self.prev_sent_time = self.sent_time;
145        self.sent_time = now;
146    }
147
148    fn on_ack(
149        &mut self,
150        now: Timestamp,
151        _sent: Timestamp,
152        bytes: u64,
153        round: u64,
154        app_limited: bool,
155    ) {
156        self.prev_total_acked = self.total_acked;
157        self.total_acked += bytes;
158        self.prev_acked_time = self.acked_time;
159        self.acked_time = now;
160
161        if self.prev_sent_time == 0 {
162            return;
163        }
164
165        let send_rate = if self.sent_time > self.prev_sent_time {
166            Self::bw_from_delta(
167                self.total_sent - self.prev_total_sent,
168                Duration::from_micros(self.sent_time - self.prev_sent_time)
169            )
170        } else {
171            u64::MAX
172        };
173
174        let ack_rate= if self.prev_acked_time == 0 {
175            0
176        } else {
177            Self::bw_from_delta(
178                self.total_acked - self.prev_total_acked,
179                Duration::from_micros(self.acked_time - self.prev_acked_time)
180            )
181        };
182
183        let bandwidth = send_rate.min(ack_rate);
184        if !app_limited && self.max_filter.get() < bandwidth {
185            self.max_filter.update_max(round, bandwidth);
186        }
187    }
188
189    fn bytes_acked_this_window(&self) -> u64 {
190        self.total_acked - self.acked_at_last_window
191    }
192
193    fn end_acks(&mut self, _current_round: u64, _app_limited: bool) {
194        self.acked_at_last_window = self.total_acked;
195    }
196
197    fn get_estimate(&self) -> u64 {
198        self.max_filter.get()
199    }
200
201    fn bw_from_delta(bytes: u64, delta: Duration) -> u64 {
202        let window_duration_ns = delta.as_nanos();
203        if window_duration_ns == 0 {
204            return 0;
205        }
206        let b_ns = bytes * 1_000_000_000;
207        let bytes_per_second = b_ns / (window_duration_ns as u64);
208        bytes_per_second
209    }
210}
211
212
213
214
215
/// Tunable parameters for the `Bbr` congestion controller.
#[derive(Debug, Clone)]
pub struct Config {
    /// Minimum congestion window, in packets (`Bbr::new` multiplies it by the MSS).
    pub min_cwnd: u64,
    /// Initial congestion window, in packets (`Bbr::new` multiplies it by the MSS).
    pub init_cwnd: u64,
    /// How long to stay in ProbeRtt once bytes in flight have dropped low enough.
    pub probe_rtt_time: Duration,
    /// If true, the ProbeRtt window is the target window scaled by
    /// `mode_rate_probe_rtt_multiplier`; otherwise it is `min_cwnd`.
    pub probe_rtt_based_on_bdp: bool,
    /// If true, a below-1.0 ProbeBw phase is not left for a 1.0 phase until
    /// bytes in flight have fallen to the target window.
    pub drain_to_target: bool,
    /// Per-round growth factor the bandwidth estimate must reach to count as
    /// growth during startup.
    pub startup_growth_target: f32,
    /// Pacing and cwnd gain used in startup; its reciprocal is the drain pacing gain.
    pub default_high_gain: f32,
    /// Cwnd gain used in ProbeBw.
    pub derived_high_cwnd_gain: f32,
    /// Pacing-gain cycle stepped through in ProbeBw.
    pub pacing_gain: [f32; 8],
    /// Time since ProbeRtt was last entered after which the min RTT is stale.
    pub min_rtt_expire_time: Duration,
    /// Gain applied to the target window to size the ProbeRtt window
    /// (only when `probe_rtt_based_on_bdp` is set).
    pub mode_rate_probe_rtt_multiplier: f32,
    /// Number of consecutive rounds *without* sufficient growth after which
    /// startup is considered to have reached full bandwidth.
    pub round_trips_with_growth_before_exiting_startup: u8,
}

impl Default for Config {
    fn default() -> Self {
        // ProbeBw cycle: one phase above 1.0, one below 1.0, six at 1.0.
        let mut pacing_gain = [1.0f32; 8];
        pacing_gain[0] = 1.25;
        pacing_gain[1] = 0.75;

        Config {
            pacing_gain,
            min_cwnd: 2,
            init_cwnd: 10,
            probe_rtt_time: Duration::from_millis(200),
            probe_rtt_based_on_bdp: true,
            drain_to_target: true,
            startup_growth_target: 1.25,
            default_high_gain: 2.885,
            derived_high_cwnd_gain: 2.0,
            min_rtt_expire_time: Duration::from_secs(10),
            mode_rate_probe_rtt_multiplier: 0.75,
            round_trips_with_growth_before_exiting_startup: 3,
        }
    }
}
251
252
253
/// Phase of the BBR state machine.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Mode {
    /// Initial ramp-up using the high pacing and cwnd gains.
    Startup,
    /// Pacing at the reciprocal of the high gain until in-flight reaches the target window.
    Drain,
    /// Steady state: cycling through the configured pacing gains.
    ProbeBw,
    /// Window reduced to the ProbeRtt window to re-measure the minimum RTT.
    ProbeRtt,
}
261
/// Loss-recovery sub-state.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RecoveryState {
    /// No recent losses.
    NotInRecovery,
    /// First round after a loss: the recovery window only shrinks by lost bytes.
    Conservation,
    /// Later rounds of recovery: the recovery window also grows by acked bytes.
    Growth,
}

impl RecoveryState {
    /// True for every state other than `NotInRecovery`.
    fn in_recovery(&self) -> bool {
        match self {
            RecoveryState::NotInRecovery => false,
            RecoveryState::Conservation | RecoveryState::Growth => true,
        }
    }
}
274
275#[derive(Debug, Copy, Clone)]
276struct AckAggregationState {
277    max_ack_height: MinMax,
278    aggregation_epoch_start_time: Timestamp,
279    aggregation_epoch_bytes: u64,
280}
281
282
283impl AckAggregationState {
284    fn new() -> Self {
285        Self {
286            max_ack_height: MinMax::new(10),
287            aggregation_epoch_start_time: bucky_time_now(),
288            aggregation_epoch_bytes: 0,
289        }
290    }
291
292    fn update_ack_aggregation_bytes(
293        &mut self,
294        newly_acked_bytes: u64,
295        now: Timestamp,
296        round: u64,
297        max_bandwidth: u64,
298    ) -> u64 {
299        let expected_bytes_acked = if now > self.aggregation_epoch_start_time {
300            max_bandwidth * (now - self.aggregation_epoch_start_time) / 1_000_000
301        } else {
302            0
303        };
304            
305        if self.aggregation_epoch_bytes <= expected_bytes_acked {
306            self.aggregation_epoch_bytes = newly_acked_bytes;
307            self.aggregation_epoch_start_time = now;
308            return 0;
309        }
310
311        self.aggregation_epoch_bytes += newly_acked_bytes;
312        let diff = self.aggregation_epoch_bytes - expected_bytes_acked;
313        self.max_ack_height.update_max(round, diff);
314        diff
315    }
316}
317
/// Bytes reported lost since the last time the controller consumed them.
#[derive(Clone, Debug, Default)]
struct LossState {
    lost_bytes: u64,
}

impl LossState {
    /// Clears the accumulated loss count.
    fn reset(&mut self) {
        *self = LossState::default();
    }

    /// Whether any bytes were reported lost since the last `reset`.
    fn has_losses(&self) -> bool {
        self.lost_bytes > 0
    }
}
332
333#[derive(Debug)]
334pub struct Bbr {
335    config: Config,
336    rtt: Duration,
337    mss: u64, 
338
339    cwnd: u64,
340    max_bandwidth: BandwidthEstimation,
341    acked_bytes: u64,
342    mode: Mode,
343    loss_state: LossState,
344    recovery_state: RecoveryState,
345    recovery_window: u64,
346    is_at_full_bandwidth: bool,
347    last_cycle_start: Option<Timestamp>,
348    current_cycle_offset: u8,
349    prev_in_flight_count: u64,
350    exit_probe_rtt_at: Option<Timestamp>,
351    probe_rtt_last_started_at: Option<Timestamp>,
352    min_rtt: Duration,
353    exiting_quiescence: bool,
354    pacing_rate: u64,
355    max_acked_packet_number: u64,
356    max_sent_packet_number: u64,
357    end_recovery_at_packet_number: u64,
358    current_round_trip_end_packet_number: u64,
359    round_count: u64,
360    bw_at_last_round: u64,
361    ack_aggregation: AckAggregationState,
362    pacing_gain: f32,
363    high_gain: f32,
364    drain_gain: f32,
365    cwnd_gain: f32,
366    high_cwnd_gain: f32,
367    round_wo_bw_gain: u64,
368}
369
370impl Bbr {
371    pub fn new(mss: usize, config: &Config) -> Self {
372        let mut config = config.clone();
373        config.min_cwnd = config.min_cwnd * mss as u64;
374        config.init_cwnd = config.init_cwnd * mss as u64;
375
376
377        Self {
378            cwnd: config.init_cwnd,
379            max_bandwidth: BandwidthEstimation::default(),
380
381            acked_bytes: 0,
382            mode: Mode::Startup,
383            loss_state: Default::default(),
384            recovery_state: RecoveryState::NotInRecovery,
385            recovery_window: 0,
386            is_at_full_bandwidth: false,
387            pacing_gain: config.default_high_gain,
388            high_gain: config.default_high_gain,
389            drain_gain: 1.0 / config.default_high_gain,
390            cwnd_gain: config.default_high_gain,
391            high_cwnd_gain: config.default_high_gain,
392            last_cycle_start: None,
393            current_cycle_offset: 0,
394            prev_in_flight_count: 0,
395            exit_probe_rtt_at: None,
396            probe_rtt_last_started_at: None,
397            min_rtt: Default::default(),
398            exiting_quiescence: false,
399            pacing_rate: 0,
400            max_acked_packet_number: 0,
401            max_sent_packet_number: 0,
402            end_recovery_at_packet_number: 0,
403            current_round_trip_end_packet_number: 0,
404            round_count: 0,
405            bw_at_last_round: 0,
406            round_wo_bw_gain: 0,
407            ack_aggregation: AckAggregationState::new(),
408
409            mss: mss as u64, 
410            rtt: Duration::from_secs(0),
411            config, 
412        }
413    }
414
415    fn enter_startup_mode(&mut self) {
416        self.mode = Mode::Startup;
417        self.pacing_gain = self.high_gain;
418        self.cwnd_gain = self.high_cwnd_gain;
419    }
420
421    fn enter_probe_bandwidth_mode(&mut self, now: Timestamp) {
422        self.mode = Mode::ProbeBw;
423        self.cwnd_gain = self.config.derived_high_cwnd_gain;
424        self.last_cycle_start = Some(now);
425        
426        let mut rand_index = rand::random::<u8>() % (self.config.pacing_gain.len() as u8 - 1);
427        if rand_index >= 1 {
428            rand_index += 1;
429        }
430        self.current_cycle_offset = rand_index;
431        self.pacing_gain = self.config.pacing_gain[rand_index as usize];
432    }
433
434    fn update_recovery_state(&mut self, is_round_start: bool) {
435        if self.loss_state.has_losses() {
436            self.end_recovery_at_packet_number = self.max_sent_packet_number;
437        }
438        match self.recovery_state {
439            RecoveryState::NotInRecovery if self.loss_state.has_losses() => {
440                self.recovery_state = RecoveryState::Conservation;
441                self.recovery_window = 0;
442                self.current_round_trip_end_packet_number = self.max_sent_packet_number;
443            }
444            RecoveryState::Growth | RecoveryState::Conservation => {
445                if self.recovery_state == RecoveryState::Conservation && is_round_start {
446                    self.recovery_state = RecoveryState::Growth;
447                }
448                if !self.loss_state.has_losses()
449                    && self.max_acked_packet_number > self.end_recovery_at_packet_number
450                {
451                    self.recovery_state = RecoveryState::NotInRecovery;
452                }
453            }
454            _ => {}
455        }
456    }
457
458    fn update_gain_cycle_phase(&mut self, now: Timestamp, in_flight: u64) {
459        let mut should_advance_gain_cycling = self
460            .last_cycle_start
461            .map(|last_cycle_start| Duration::from_micros(now - last_cycle_start) > self.min_rtt)
462            .unwrap_or(false);
463        if self.pacing_gain > 1.0
464            && !self.loss_state.has_losses()
465            && self.prev_in_flight_count < self.get_target_cwnd(self.pacing_gain)
466        {
467            should_advance_gain_cycling = false;
468        }
469
470        if self.pacing_gain < 1.0 && in_flight <= self.get_target_cwnd(1.0) {
471            should_advance_gain_cycling = true;
472        }
473
474        if should_advance_gain_cycling {
475            self.current_cycle_offset = (self.current_cycle_offset + 1) % self.config.pacing_gain.len() as u8;
476            self.last_cycle_start = Some(now);
477            
478            if self.config.drain_to_target
479                && self.pacing_gain < 1.0
480                && (self.config.pacing_gain[self.current_cycle_offset as usize] - 1.0).abs() < f32::EPSILON
481                && in_flight > self.get_target_cwnd(1.0)
482            {
483                return;
484            }
485            self.pacing_gain = self.config.pacing_gain[self.current_cycle_offset as usize];
486        }
487    }
488
489    fn maybe_exit_startup_or_drain(&mut self, now: Timestamp, in_flight: u64) {
490        if self.mode == Mode::Startup && self.is_at_full_bandwidth {
491            self.mode = Mode::Drain;
492            self.pacing_gain = self.drain_gain;
493            self.cwnd_gain = self.high_cwnd_gain;
494        }
495        if self.mode == Mode::Drain && in_flight <= self.get_target_cwnd(1.0) {
496            self.enter_probe_bandwidth_mode(now);
497        }
498    }
499
500    fn is_min_rtt_expired(&self, now: Timestamp, app_limited: bool) -> bool {
501        !app_limited
502            && self
503                .probe_rtt_last_started_at
504                .map(|last| if now > last { Duration::from_micros(now - last) > self.config.min_rtt_expire_time } else { false })
505                .unwrap_or(true)
506    }
507
508    fn maybe_enter_or_exit_probe_rtt(
509        &mut self,
510        now: Timestamp,
511        is_round_start: bool,
512        bytes_in_flight: u64,
513        app_limited: bool,
514    ) {
515        let min_rtt_expired = self.is_min_rtt_expired(now, app_limited);
516        if min_rtt_expired && !self.exiting_quiescence && self.mode != Mode::ProbeRtt {
517            self.mode = Mode::ProbeRtt;
518            self.pacing_gain = 1.0;
519            self.exit_probe_rtt_at = None;
520            self.probe_rtt_last_started_at = Some(now);
521        }
522
523        if self.mode == Mode::ProbeRtt {
524            if self.exit_probe_rtt_at.is_none() {
525                if bytes_in_flight < self.get_probe_rtt_cwnd() + self.mss {
526                    self.exit_probe_rtt_at = Some(now + self.config.probe_rtt_time.as_micros() as u64);
527                }
528            } else if is_round_start && now >= self.exit_probe_rtt_at.unwrap() {
529                if !self.is_at_full_bandwidth {
530                    self.enter_startup_mode();
531                } else {
532                    self.enter_probe_bandwidth_mode(now);
533                }
534            }
535        }
536
537        self.exiting_quiescence = false;
538    }
539
540    fn get_target_cwnd(&self, gain: f32) -> u64 {
541        let bw = self.max_bandwidth.get_estimate();
542        let bdp = self.min_rtt.as_micros() as u64 * bw;
543        let bdpf = bdp as f64;
544        let cwnd = ((gain as f64 * bdpf) / 1_000_000f64) as u64;
545        if cwnd == 0 {
546            self.config.init_cwnd
547        } else {
548            cwnd.max(self.config.min_cwnd)
549        }
550        
551    }
552
553    fn get_probe_rtt_cwnd(&self) -> u64 {
554        if self.config.probe_rtt_based_on_bdp {
555            self.get_target_cwnd(self.config.mode_rate_probe_rtt_multiplier)
556        } else {
557            self.config.min_cwnd
558        }
559    }
560
561    fn calculate_pacing_rate(&mut self) {
562        let bw = self.max_bandwidth.get_estimate();
563        if bw == 0 {
564            return;
565        }
566        let target_rate = (bw as f64 * self.pacing_gain as f64) as u64;
567        if self.is_at_full_bandwidth {
568            self.pacing_rate = target_rate;
569            return;
570        }
571
572        if self.pacing_rate == 0 && self.min_rtt.as_nanos() != 0 {
573            self.pacing_rate = BandwidthEstimation::bw_from_delta(self.config.init_cwnd, self.min_rtt);
574            return;
575        }
576
577        if self.pacing_rate < target_rate {
578            self.pacing_rate = target_rate;
579        }
580    }
581
582    fn calculate_cwnd(&mut self, bytes_acked: u64, excess_acked: u64) {
583        if self.mode == Mode::ProbeRtt {
584            return;
585        }
586        let mut target_window = self.get_target_cwnd(self.cwnd_gain);
587        if self.is_at_full_bandwidth {
588            target_window += self.ack_aggregation.max_ack_height.get();
589        } else {
590            target_window += excess_acked;
591        }
592        
593        if self.is_at_full_bandwidth {
594            self.cwnd = target_window.min(self.cwnd + bytes_acked);
595        } else if (self.cwnd_gain < target_window as f32) || (self.acked_bytes < self.config.init_cwnd) {
596            self.cwnd += bytes_acked;
597        }
598
599        self.cwnd = self.cwnd.max(self.config.min_cwnd);
600    }
601
602    fn calculate_recovery_window(&mut self, bytes_acked: u64, bytes_lost: u64, in_flight: u64) {
603        if !self.recovery_state.in_recovery() {
604            return;
605        }
606        
607        if self.recovery_window == 0 {
608            self.recovery_window = self.config.min_cwnd.max(in_flight + bytes_acked);
609            return;
610        }
611
612        if self.recovery_window >= bytes_lost {
613            self.recovery_window -= bytes_lost;
614        } else {
615            self.recovery_window = self.mss;
616        }
617        
618        if self.recovery_state == RecoveryState::Growth {
619            self.recovery_window += bytes_acked;
620        }
621
622        self.recovery_window = self.recovery_window.max(in_flight + bytes_acked).max(self.config.min_cwnd);
623    }
624
625    fn check_if_full_bw_reached(&mut self, app_limited: bool) {
626        if app_limited {
627            return;
628        }
629        let target = (self.bw_at_last_round as f64 * self.config.startup_growth_target as f64) as u64;
630        let bw = self.max_bandwidth.get_estimate();
631        if bw >= target {
632            self.bw_at_last_round = bw;
633            self.round_wo_bw_gain = 0;
634            self.ack_aggregation.max_ack_height.reset();
635            return;
636        }
637
638        self.round_wo_bw_gain += 1;
639        if self.round_wo_bw_gain >= self.config.round_trips_with_growth_before_exiting_startup as u64
640            || (self.recovery_state.in_recovery())
641        {
642            self.is_at_full_bandwidth = true;
643        }
644    }
645}
646
647
648
649
650impl CcImpl for Bbr {
651    fn on_sent(&mut self, now: Timestamp, bytes: u64, last_packet_number: u64) {
652        self.max_sent_packet_number = last_packet_number;
653        self.max_bandwidth.on_sent(now, bytes);
654    }
655
656    fn cwnd(&self) -> u64 {
657        if self.mode == Mode::ProbeRtt {
658            self.get_probe_rtt_cwnd()
659        } else if self.recovery_state.in_recovery()
660            && self.mode != Mode::Startup {
661            self.cwnd.min(self.recovery_window)
662        } else {
663            self.cwnd
664        }
665    }
666
667    fn on_estimate(&mut self, rtt: Duration, _rto: Duration, _delay: Duration, app_limited: bool) {
668        let now = bucky_time_now();
669
670        if self.is_min_rtt_expired(now, app_limited) || self.min_rtt > rtt {
671            self.min_rtt = rtt;
672        }
673    }
674
675    fn on_ack(&mut self, flight: u64, ack: u64, largest_packet_num_acked: Option<u64>, sent_time: Timestamp, app_limited: bool) { //ret cwnd
676        let now = bucky_time_now();
677
678        self.max_bandwidth.on_ack(
679            now,
680            sent_time,
681            ack,
682            self.round_count,
683            app_limited
684        );
685        self.acked_bytes += ack;
686
687        let ack_in_wnd = self.max_bandwidth.bytes_acked_this_window();
688        let excess_acked = self.ack_aggregation.update_ack_aggregation_bytes(
689            ack_in_wnd,
690            now,
691            self.round_count,
692            self.max_bandwidth.get_estimate(),
693        );
694        self.max_bandwidth.end_acks(self.round_count, app_limited);
695        if let Some(largest_acked_packet) = largest_packet_num_acked {
696            self.max_acked_packet_number = largest_acked_packet;
697        }
698
699        let mut is_round_start = false;
700        if ack_in_wnd > 0 {
701            is_round_start = self.max_acked_packet_number
702                > self.current_round_trip_end_packet_number;
703            if is_round_start {
704                self.current_round_trip_end_packet_number =
705                    self.max_sent_packet_number;
706                self.round_count += 1;
707            }
708        }
709
710        self.update_recovery_state(is_round_start);
711
712        if self.mode == Mode::ProbeBw {
713            self.update_gain_cycle_phase(now, flight);
714        }
715
716        if is_round_start && !self.is_at_full_bandwidth {
717            self.check_if_full_bw_reached(app_limited);
718        }
719
720        self.maybe_exit_startup_or_drain(now, flight);
721
722        self.maybe_enter_or_exit_probe_rtt(now, is_round_start, flight, app_limited);
723
724        self.calculate_pacing_rate();
725        self.calculate_cwnd(ack_in_wnd, excess_acked);
726        self.calculate_recovery_window(
727            ack_in_wnd,
728            self.loss_state.lost_bytes,
729            flight,
730        );
731        self.prev_in_flight_count = flight;
732        self.loss_state.reset();
733    }
734
735    fn on_loss(&mut self, lost: u64) {
736        self.loss_state.lost_bytes += lost;
737    }
738
739    fn on_no_resp(&mut self, rto: Duration, lost: u64) -> Duration {
740        self.loss_state.lost_bytes += lost;
741        rto * 2
742    }
743
744    fn on_time_escape(&mut self, _: Timestamp) {
745    }
746
747    fn rate(&self) -> u64 {
748        self.max_bandwidth.get_estimate()
749    }
750}