ruknet 0.1.2

The perfect reimplementation of the RakNet protocol in Rust
Documentation
use std::{cmp::Ordering, collections::VecDeque};

use super::CongestionController;

const SYN: u64 = 10;

const MIN_CWND: usize = 2;
const MAX_CWND: usize = 512;

const DEFAULT_TRANSFER_RATE: f32 = 0.0036;
const DEFAULT_BYTE_INTERVAL: f32 = 1.0 / DEFAULT_TRANSFER_RATE;

const PACKET_ARRIVAL_HISTORY_SIZE: usize = 64;

const INTERVAL_SIZE: usize = 33;

#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
pub struct UDT {
    cwnd: usize,
    mtu_size_excluding_header: usize,
    next_seq: u32,
    arrival_rate: Option<f32>,
    snd: f32,
    bytes_can_send_this_tick: isize,
    packet_arrival_history_index: usize,
    packet_arrival_history: [f32; PACKET_ARRIVAL_HISTORY_SIZE],
    pings_last_interval: VecDeque<u64>,
    last_packet_arrival_time: u64,
    last_rtt_on_increase_send_rate: u64,
    oldest_unsent_ack: u64,
    next_congestion_control_block: u32,
    had_packet_loss_this_block: bool,
    is_in_slow_start: bool,
}

impl CongestionController for UDT {
    fn new(mtu_size_excluding_header: usize) -> Self {
        Self {
            cwnd: MIN_CWND,
            mtu_size_excluding_header,
            next_seq: 0,
            arrival_rate: Some(DEFAULT_TRANSFER_RATE),
            snd: DEFAULT_BYTE_INTERVAL,
            bytes_can_send_this_tick: 0,
            packet_arrival_history_index: 0,
            packet_arrival_history: [0.0; PACKET_ARRIVAL_HISTORY_SIZE],
            pings_last_interval: VecDeque::new(),
            last_packet_arrival_time: 0,
            last_rtt_on_increase_send_rate: 1000,
            oldest_unsent_ack: 0,
            next_congestion_control_block: 0,
            had_packet_loss_this_block: false,
            is_in_slow_start: true,
        }
    }

    fn should_send_ack(&self, now: u64, time_since_last_update: u64) -> bool {
        now - self.oldest_unsent_ack >= SYN
            || now + time_since_last_update < self.oldest_unsent_ack + SYN
    }

    fn is_in_slow_start(&self) -> bool {
        self.is_in_slow_start
    }

    fn get_seq_and_increment(&mut self) -> u32 {
        let current = self.next_seq;
        self.next_seq += 1;
        current
    }

    fn get_transmission_bandwidth(
        &mut self,
        _now: u64,
        mut time_since_last_update: u64,
        is_continuous_send: bool,
    ) -> usize {
        if self.is_in_slow_start {
            self.cwnd * self.mtu_size_excluding_header
        } else {
            if self.bytes_can_send_this_tick < 0 {
                self.bytes_can_send_this_tick = 0;
            }

            if !is_continuous_send && time_since_last_update > 100 {
                time_since_last_update = 100;
            }

            self.bytes_can_send_this_tick +=
                time_since_last_update as isize * (1.0 / self.snd) as isize;

            if self.bytes_can_send_this_tick < 0 {
                self.bytes_can_send_this_tick as usize
            } else {
                0
            }
        }
    }

    fn get_rto(&self) -> u64 {
        (self.last_rtt_on_increase_send_rate * 2).clamp(100, 10000)
    }

    fn on_resend(&mut self, _now: u64) {
        if self.is_in_slow_start && self.arrival_rate.is_some() {
            self.end_slow_start();
        } else if !self.had_packet_loss_this_block {
            self.increase_time_between_sends();
            self.had_packet_loss_this_block = true;
        }
    }

    fn on_recv_ack(
        &mut self,
        now: u64,
        rtt: u64,
        arrival_rate: Option<f32>,
        total_bytes_acked: usize,
        seq: u32,
        is_continuous_send: bool,
    ) {
        if let Some(arrival_rate) = arrival_rate {
            self.arrival_rate = Some(arrival_rate);
        }

        if self.oldest_unsent_ack == 0 {
            self.oldest_unsent_ack = now;
        }

        if self.is_in_slow_start {
            self.next_congestion_control_block = seq;
            self.last_rtt_on_increase_send_rate = rtt;
            self.update_window_size_slow_start(total_bytes_acked);
        } else {
            self.update_window_size_syn(rtt, seq, is_continuous_send);
        }
    }

    fn on_recv_nak(&mut self) {
        if self.is_in_slow_start && self.arrival_rate.is_some() {
            self.end_slow_start();
        } else if !self.had_packet_loss_this_block {
            self.increase_time_between_sends();
            self.had_packet_loss_this_block = true;
        }
    }

    fn on_recv_packet(&mut self, now: u64, size: usize, is_continuous_send: bool) {
        if now > self.last_packet_arrival_time {
            let interval = now - self.last_packet_arrival_time;

            if is_continuous_send {
                self.packet_arrival_history
                    [self.packet_arrival_history_index & (PACKET_ARRIVAL_HISTORY_SIZE - 1)] =
                    size as f32 / interval as f32;
                self.packet_arrival_history_index += 1;
            }

            self.last_packet_arrival_time = now;
        }
    }

    fn on_send_ack(&mut self, needs_arrival_rate: bool, arrival_rate: &mut Option<f32>) {
        self.oldest_unsent_ack = 0;

        if needs_arrival_rate {
            *arrival_rate = self.calc_arrival_rate();
        }
    }

    fn on_send_packet(&mut self, size: usize) {
        if !self.is_in_slow_start {
            self.bytes_can_send_this_tick -= size as isize;
        }
    }
}

impl UDT {
    fn end_slow_start(&mut self) {
        self.is_in_slow_start = false;
        self.snd = 1.0 / self.arrival_rate.unwrap();
        self.cap_min_snd();
    }

    fn increase_time_between_sends(&mut self) {
        let increment = 0.02 * (self.snd + 1.0).powf(2.0) / 501.0f32.powf(2.0);
        self.snd *= 1.02 - increment;
        self.cap_min_snd();
    }

    fn decrease_time_between_sends(&mut self) {
        let increment = 0.01 * (self.snd + 1.0).powf(2.0) / 501.0f32.powf(2.0);
        self.snd *= 0.99 + increment;
    }

    fn cap_min_snd(&mut self) {
        if self.snd > 500.0 {
            self.snd = 500.0;
        }
    }

    fn calc_arrival_rate(&self) -> Option<f32> {
        if self.packet_arrival_history_index < PACKET_ARRIVAL_HISTORY_SIZE {
            return None;
        }

        let median = self.find_median_arrival_rate();
        let one_eighth_median = median / 8.0;
        let eight_times_median = median * 8.0;
        let mut sum = 0.0;
        let mut count = 0;

        for i in 0..PACKET_ARRIVAL_HISTORY_SIZE {
            if self.packet_arrival_history[i] >= one_eighth_median
                && self.packet_arrival_history[i] < eight_times_median
            {
                sum += self.packet_arrival_history[i];
                count += 1;
            }
        }

        if count > 0 {
            Some(sum / count as f32)
        } else {
            None
        }
    }

    fn find_median_arrival_rate(&self) -> f32 {
        let mid_index = self.packet_arrival_history.len() / 2;
        Self::find_kth_element_recursive(&self.packet_arrival_history, mid_index)
    }

    fn find_kth_element_recursive(list: &[f32], k: usize) -> f32 {
        if list.len() == 1 {
            return list[0];
        }

        let pivot_index = list.len() / 2;
        let pivot = list[pivot_index];

        let mut less = Vec::new();
        let mut equal = Vec::new();
        let mut greater = Vec::new();

        for &e in list {
            match e.partial_cmp(&pivot).unwrap() {
                Ordering::Less => less.push(e),
                Ordering::Equal => equal.push(e),
                Ordering::Greater => greater.push(e),
            }
        }

        if k < less.len() {
            Self::find_kth_element_recursive(&less, k)
        } else if k < less.len() + equal.len() {
            pivot
        } else {
            Self::find_kth_element_recursive(&greater, k - less.len() - equal.len())
        }
    }

    fn update_window_size_slow_start(&mut self, total_bytes_acked: usize) {
        self.cwnd = total_bytes_acked / self.mtu_size_excluding_header;

        if self.cwnd < MIN_CWND {
            self.cwnd = MIN_CWND;
        } else if self.cwnd > MAX_CWND {
            self.cwnd = MAX_CWND;

            if self.arrival_rate.is_some() {
                self.end_slow_start();
            }
        }
    }

    fn update_window_size_syn(&mut self, rtt: u64, seq: u32, is_continuous_send: bool) {
        if !is_continuous_send {
            self.next_congestion_control_block = self.next_seq;
            self.pings_last_interval.clear();
            return;
        }

        self.pings_last_interval.push_back(rtt);

        if self.pings_last_interval.len() > INTERVAL_SIZE {
            self.pings_last_interval.pop_front();
        }

        if seq > self.next_congestion_control_block
            && seq - self.next_congestion_control_block >= INTERVAL_SIZE as u32
            && self.pings_last_interval.len() == INTERVAL_SIZE
        {
            let mut slope_sum = 0;
            let mut average = self.pings_last_interval[0];
            let size = self.pings_last_interval.len();

            for i in 1..size {
                slope_sum += self.pings_last_interval[i] - self.pings_last_interval[i - 1];
                average += self.pings_last_interval[i];
            }

            average /= size as u64;

            if slope_sum > average / 10 {
                self.increase_time_between_sends();
            } else {
                self.last_rtt_on_increase_send_rate = rtt;
                self.decrease_time_between_sends();
            }

            self.pings_last_interval.clear();
            self.had_packet_loss_this_block = false;
            self.next_congestion_control_block = seq;
        }
    }
}