use std::collections::{BTreeMap, HashMap, VecDeque};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::{Mutex, RwLock};
use serde::{Deserialize, Serialize};
use super::Uplink;
use crate::types::Bandwidth;
/// Tunables for bandwidth aggregation (packet striping across uplinks
/// plus reorder/reassembly of the received stream).
///
/// Every field has a serde default, so partially-specified configs
/// deserialize cleanly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AggregatorConfig {
    // Master switch; when false `is_active()` reports the aggregator off.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    // Maximum packets parked while waiting to restore sequence order.
    #[serde(default = "default_reorder_buffer_size")]
    pub reorder_buffer_size: usize,
    // How long a sequence gap may stall delivery before being skipped.
    #[serde(default = "default_reorder_timeout", with = "humantime_serde")]
    pub reorder_timeout: Duration,
    // NOTE(review): not referenced by the code visible in this file —
    // presumably consumed elsewhere; confirm before removing.
    #[serde(default = "default_min_buffer_packets")]
    pub min_buffer_packets: usize,
    // Minimum interval between per-uplink weight recomputations.
    #[serde(default = "default_weight_update_interval", with = "humantime_serde")]
    pub weight_update_interval: Duration,
    // Uplinks whose estimated bandwidth (bytes/sec) falls below this
    // floor are excluded from striping.
    #[serde(default = "default_min_uplink_bandwidth")]
    pub min_uplink_bandwidth: f64,
    // NOTE(review): not referenced by the code visible in this file;
    // verify usage before relying on it.
    #[serde(default = "default_latency_compensation")]
    pub latency_compensation: bool,
}
// --- serde default providers for `AggregatorConfig` ---

// Aggregation is on unless explicitly disabled.
fn default_enabled() -> bool {
    true
}
// Reorder buffer capacity, in packets.
fn default_reorder_buffer_size() -> usize {
    4096
}
// Wait up to 500 ms for a missing sequence number before skipping it.
fn default_reorder_timeout() -> Duration {
    Duration::from_millis(500)
}
// NOTE(review): the corresponding config field appears unused in this
// file; see the struct definition.
fn default_min_buffer_packets() -> usize {
    4
}
// Recompute uplink weights at most every 100 ms.
fn default_weight_update_interval() -> Duration {
    Duration::from_millis(100)
}
// 10 KB/s floor — units are bytes/sec, matching `Bandwidth::bytes_per_sec`.
fn default_min_uplink_bandwidth() -> f64 {
    10_000.0
}
fn default_latency_compensation() -> bool {
    true
}
impl Default for AggregatorConfig {
    /// Mirrors the serde field defaults so programmatic construction and
    /// deserializing an empty config produce the same values.
    fn default() -> Self {
        Self {
            enabled: default_enabled(),
            reorder_buffer_size: default_reorder_buffer_size(),
            reorder_timeout: default_reorder_timeout(),
            min_buffer_packets: default_min_buffer_packets(),
            weight_update_interval: default_weight_update_interval(),
            min_uplink_bandwidth: default_min_uplink_bandwidth(),
            latency_compensation: default_latency_compensation(),
        }
    }
}
/// Scheduling entry for one usable uplink, rebuilt by `update_weights`.
#[derive(Debug, Clone, Copy)]
struct UplinkWeight {
    uplink_id: u16,
    // Estimated bandwidth in bytes/sec (from `Uplink::bandwidth`).
    bandwidth: f64,
    // Integer share of total bandwidth in percent (floored at 1),
    // consumed by `weighted_select`.
    weight: u32,
    // NOTE(review): always initialized to 0 and never read — looks
    // reserved for a deficit-round-robin scheduler; confirm before use.
    deficit: i32,
    rtt: Duration,
}
/// A packet parked in the reorder buffer while earlier sequence numbers
/// are still outstanding.
#[derive(Debug)]
struct BufferedPacket {
    // Duplicates the BTreeMap key; retained so Debug output is
    // self-contained.
    sequence: u64,
    // Raw payload handed back to the caller once the packet is in order.
    data: Vec<u8>,
    // Arrival time; drives the gap-timeout logic in `collect_ready`.
    received_at: Instant,
    // Numeric id of the uplink this packet arrived on (diagnostic only).
    from_uplink: u16,
}

/// Re-sequences packets that were striped across multiple uplinks.
///
/// Delivery is strictly in sequence order: a missing sequence number
/// stalls the stream until the packet arrives, the gap times out
/// (`timeout`), or the buffer overflows (`max_size`).
#[derive(Debug)]
pub struct ReorderBuffer {
    // Pending packets keyed by sequence number (iterates ascending, so
    // `iter().next()` is always the lowest buffered sequence).
    buffer: BTreeMap<u64, BufferedPacket>,
    // Next sequence number expected for in-order delivery; 0 means
    // "unset" until the first packet adopts its sequence as the start.
    next_seq: u64,
    // Capacity limit; the lowest-sequence packet is evicted on overflow.
    max_size: usize,
    // How long a sequence gap may stall delivery before it is skipped.
    timeout: Duration,
    // Counters exposed via `stats()`.
    delivered: u64,
    dropped: u64,
    reordered: u64,
}

impl ReorderBuffer {
    /// Creates an empty buffer with the given capacity and gap timeout.
    pub fn new(max_size: usize, timeout: Duration) -> Self {
        Self {
            buffer: BTreeMap::new(),
            next_seq: 0,
            max_size,
            timeout,
            delivered: 0,
            dropped: 0,
            reordered: 0,
        }
    }

    /// Inserts a packet and returns every payload that is now
    /// deliverable in order (possibly empty). Packets older than
    /// `next_seq` — late duplicates or already-skipped sequences — are
    /// silently discarded.
    pub fn insert(&mut self, sequence: u64, data: Vec<u8>, from_uplink: u16) -> Vec<Vec<u8>> {
        let now = Instant::now();
        // First packet ever seen: adopt its sequence number as the
        // stream start instead of waiting forever for sequence 0.
        if self.next_seq == 0 && self.buffer.is_empty() {
            self.next_seq = sequence;
        }
        if sequence < self.next_seq {
            // Already delivered or given up on; drop the stale packet.
            return vec![];
        }
        if sequence != self.next_seq {
            // Arrived ahead of the packet we are waiting for.
            self.reordered += 1;
        }
        self.buffer.insert(
            sequence,
            BufferedPacket {
                sequence,
                data,
                received_at: now,
                from_uplink,
            },
        );
        // Enforce the capacity limit by evicting the lowest sequence;
        // if that was the delivery head, move the cursor past it so the
        // stream does not deadlock on the evicted packet.
        while self.buffer.len() > self.max_size {
            if let Some((&oldest_seq, _)) = self.buffer.iter().next() {
                self.buffer.remove(&oldest_seq);
                self.dropped += 1;
                if oldest_seq == self.next_seq {
                    self.next_seq = oldest_seq + 1;
                }
            }
        }
        self.collect_ready(now)
    }

    /// Drains every packet deliverable in order. If a leading gap
    /// remains and the lowest-sequence buffered packet has waited past
    /// `timeout`, the gap is abandoned: the missing sequences are
    /// counted as dropped and delivery resumes from the buffered head.
    fn collect_ready(&mut self, now: Instant) -> Vec<Vec<u8>> {
        let mut ready = Vec::new();
        while let Some(packet) = self.buffer.remove(&self.next_seq) {
            ready.push(packet.data);
            self.next_seq += 1;
            self.delivered += 1;
        }
        if !self.buffer.is_empty() {
            if let Some((&min_seq, oldest)) = self.buffer.iter().next() {
                if now.duration_since(oldest.received_at) > self.timeout {
                    // Give up on the hole: every sequence between the
                    // cursor and the buffered head is lost.
                    let gap = min_seq - self.next_seq;
                    self.dropped += gap;
                    self.next_seq = min_seq;
                    while let Some(packet) = self.buffer.remove(&self.next_seq) {
                        ready.push(packet.data);
                        self.next_seq += 1;
                        self.delivered += 1;
                    }
                }
            }
        }
        ready
    }

    /// Empties the buffer, returning all payloads in ascending sequence
    /// order regardless of gaps, and advances `next_seq` past everything
    /// flushed so delivery resumes cleanly afterwards.
    pub fn flush(&mut self) -> Vec<Vec<u8>> {
        let mut ready = Vec::new();
        // BTreeMap yields keys in ascending order, so `next_seq` ends up
        // one past the highest flushed sequence. (Previously `next_seq`
        // was left stale here, which made packets arriving after a flush
        // buffer needlessly until the gap timeout fired.)
        for (seq, packet) in std::mem::take(&mut self.buffer) {
            ready.push(packet.data);
            self.delivered += 1;
            self.next_seq = seq + 1;
        }
        ready
    }

    /// Snapshot of the buffer's counters.
    pub fn stats(&self) -> ReorderStats {
        ReorderStats {
            buffered: self.buffer.len(),
            next_seq: self.next_seq,
            delivered: self.delivered,
            dropped: self.dropped,
            reordered: self.reordered,
        }
    }

    /// True when at least one packet is waiting on a sequence gap.
    pub fn has_pending(&self) -> bool {
        !self.buffer.is_empty()
    }

    /// Age of the lowest-sequence buffered packet (zero when empty) —
    /// an estimate of the delay the current head-of-line gap has caused.
    pub fn current_delay(&self) -> Duration {
        self.buffer
            .iter()
            .next()
            .map(|(_, p)| p.received_at.elapsed())
            .unwrap_or(Duration::ZERO)
    }

    /// Re-runs gap-timeout collection without inserting anything. Call
    /// periodically so a timed-out gap is skipped even when no new
    /// packets arrive to trigger `insert`.
    pub fn poll_timeout(&mut self) -> Vec<Vec<u8>> {
        self.collect_ready(Instant::now())
    }
}

/// Counters describing `ReorderBuffer` activity.
#[derive(Debug, Clone, Copy)]
pub struct ReorderStats {
    pub buffered: usize,
    pub next_seq: u64,
    pub delivered: u64,
    pub dropped: u64,
    pub reordered: u64,
}
/// Stripes outgoing packets across multiple uplinks in proportion to
/// their estimated bandwidth, and reassembles incoming packets into
/// sequence order via a `ReorderBuffer`.
pub struct BandwidthAggregator {
    config: AggregatorConfig,
    // Current per-uplink scheduling weights (rebuilt periodically).
    weights: RwLock<Vec<UplinkWeight>>,
    // Monotonic counter driving the deterministic weighted round-robin.
    stripe_index: AtomicU64,
    // Packets scheduled per uplink id, for `bandwidth_distribution`.
    uplink_counters: RwLock<HashMap<u16, u64>>,
    // When `update_weights` last ran; gates lazy refresh in `next_stripe`.
    last_weight_update: RwLock<Instant>,
    reorder_buffer: Mutex<ReorderBuffer>,
    // Next sequence number assigned to an outgoing packet (starts at 1).
    send_seq: AtomicU64,
    stats: RwLock<AggregatorStats>,
}
/// Aggregate send/receive counters, snapshotted via `stats()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct AggregatorStats {
    pub packets_sent: u64,
    pub bytes_sent: u64,
    pub packets_received: u64,
    pub bytes_received: u64,
    // Indexed by numeric uplink id; ids >= 16 are not tracked here
    // (fixed-size array, see the bounds check at the increment site).
    pub packets_per_uplink: [u64; 16],
}
impl BandwidthAggregator {
    /// Creates an aggregator with an empty weight table and a reorder
    /// buffer sized from `config`.
    pub fn new(config: AggregatorConfig) -> Self {
        let reorder_buffer = ReorderBuffer::new(config.reorder_buffer_size, config.reorder_timeout);
        Self {
            config,
            weights: RwLock::new(Vec::new()),
            stripe_index: AtomicU64::new(0),
            uplink_counters: RwLock::new(HashMap::new()),
            last_weight_update: RwLock::new(Instant::now()),
            reorder_buffer: Mutex::new(reorder_buffer),
            // Sequence numbers start at 1: the reorder buffer treats its
            // initial `next_seq == 0` as "unset".
            send_seq: AtomicU64::new(1),
            stats: RwLock::new(AggregatorStats::default()),
        }
    }

    /// Rebuilds the per-uplink weight table from current bandwidth
    /// estimates. Uplinks that are unusable or below
    /// `min_uplink_bandwidth` are excluded.
    pub fn update_weights(&self, uplinks: &[Arc<Uplink>]) {
        let mut weights = Vec::new();
        let mut total_bw = 0.0f64;
        for uplink in uplinks {
            if !uplink.is_usable() {
                continue;
            }
            let bw = uplink.bandwidth().bytes_per_sec;
            if bw < self.config.min_uplink_bandwidth {
                continue;
            }
            total_bw += bw;
            weights.push(UplinkWeight {
                uplink_id: uplink.numeric_id(),
                bandwidth: bw,
                // Weight is filled in below; `deficit` is reserved for a
                // DRR-style scheduler and currently unused.
                weight: 0,
                deficit: 0,
                rtt: uplink.rtt(),
            });
        }
        if weights.is_empty() || total_bw == 0.0 {
            *self.weights.write() = weights;
            // Record the attempt even when nothing qualified; otherwise
            // `next_stripe` would re-scan every uplink on every single
            // packet while the link set is down.
            *self.last_weight_update.write() = Instant::now();
            return;
        }
        // Weight is the uplink's share of total bandwidth in percent,
        // floored at 1 so no qualifying uplink is starved entirely.
        for w in &mut weights {
            w.weight = ((w.bandwidth / total_bw) * 100.0).max(1.0) as u32;
        }
        // Fastest uplink first (NaN-safe: ties on incomparable values).
        weights.sort_by(|a, b| {
            b.bandwidth
                .partial_cmp(&a.bandwidth)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        *self.weights.write() = weights;
        *self.last_weight_update.write() = Instant::now();
    }

    /// Picks the uplink for the next outgoing packet and allocates its
    /// sequence number. Returns `None` when no uplink qualifies.
    pub fn next_stripe(&self, uplinks: &[Arc<Uplink>]) -> Option<(u16, u64)> {
        // Refresh the weight table lazily, at most once per interval.
        if self.last_weight_update.read().elapsed() > self.config.weight_update_interval {
            self.update_weights(uplinks);
        }
        let uplink_id = {
            let weights = self.weights.read();
            if weights.is_empty() {
                return None;
            }
            let idx = self.stripe_index.fetch_add(1, Ordering::Relaxed) as usize;
            // Selection only reads the table, so no per-packet clone of
            // the weight vector is needed (the old code cloned it).
            self.weighted_select(&weights, idx)
        };
        {
            let mut counters = self.uplink_counters.write();
            *counters.entry(uplink_id).or_insert(0) += 1;
        }
        let seq = self.send_seq.fetch_add(1, Ordering::SeqCst);
        {
            let mut stats = self.stats.write();
            stats.packets_sent += 1;
            // The per-uplink array only covers ids 0..16.
            if (uplink_id as usize) < 16 {
                stats.packets_per_uplink[uplink_id as usize] += 1;
            }
        }
        Some((uplink_id, seq))
    }

    /// Deterministic weighted round-robin: maps a monotonically
    /// increasing iteration counter onto uplink ids in proportion to
    /// their weights. Pure read — takes the table by shared reference.
    fn weighted_select(&self, weights: &[UplinkWeight], iteration: usize) -> u16 {
        if weights.is_empty() {
            return 0;
        }
        if weights.len() == 1 {
            return weights[0].uplink_id;
        }
        let total_weight: u32 = weights.iter().map(|w| w.weight).sum();
        if total_weight == 0 {
            return weights[0].uplink_id;
        }
        let pos = (iteration as u32) % total_weight;
        let mut cumulative = 0u32;
        for w in weights.iter() {
            cumulative += w.weight;
            if pos < cumulative {
                return w.uplink_id;
            }
        }
        weights.last().map(|w| w.uplink_id).unwrap_or(0)
    }

    /// Convenience wrapper around `next_stripe` that also accounts the
    /// payload size. Returns at most one `(uplink_id, sequence)` pair;
    /// the `Vec` return leaves room for future multi-fragment striping.
    pub fn stripe_packet(&self, uplinks: &[Arc<Uplink>], data_len: usize) -> Vec<(u16, u64)> {
        match self.next_stripe(uplinks) {
            Some(stripe) => {
                // Account outgoing bytes. (Previously `data_len` was
                // ignored, so `bytes_sent` never advanced even though
                // `bytes_received` did.)
                self.stats.write().bytes_sent += data_len as u64;
                vec![stripe]
            }
            None => vec![],
        }
    }

    /// Feeds a received packet into the reorder buffer; returns every
    /// payload now deliverable in order.
    pub fn receive(&self, sequence: u64, data: Vec<u8>, from_uplink: u16) -> Vec<Vec<u8>> {
        {
            let mut stats = self.stats.write();
            stats.packets_received += 1;
            stats.bytes_received += data.len() as u64;
        }
        self.reorder_buffer
            .lock()
            .insert(sequence, data, from_uplink)
    }

    /// Releases packets stuck behind a gap whose timeout has expired.
    /// Call periodically — gaps only resolve on insert or on this poll.
    pub fn poll_timeout(&self) -> Vec<Vec<u8>> {
        self.reorder_buffer.lock().poll_timeout()
    }

    /// Drains the reorder buffer unconditionally (e.g. on shutdown).
    pub fn flush(&self) -> Vec<Vec<u8>> {
        self.reorder_buffer.lock().flush()
    }

    /// Next outgoing sequence number to be assigned.
    pub fn current_seq(&self) -> u64 {
        self.send_seq.load(Ordering::SeqCst)
    }

    /// Snapshot of the aggregate counters.
    pub fn stats(&self) -> AggregatorStats {
        *self.stats.read()
    }

    /// Snapshot of the reorder buffer's counters.
    pub fn reorder_stats(&self) -> ReorderStats {
        self.reorder_buffer.lock().stats()
    }

    /// Per-uplink `(id, bandwidth, packets_scheduled)` triples for the
    /// currently-weighted uplinks.
    pub fn bandwidth_distribution(&self) -> Vec<(u16, f64, u64)> {
        let weights = self.weights.read();
        let counters = self.uplink_counters.read();
        weights
            .iter()
            .map(|w| {
                let count = counters.get(&w.uplink_id).copied().unwrap_or(0);
                (w.uplink_id, w.bandwidth, count)
            })
            .collect()
    }

    /// True when aggregation is enabled and at least one uplink has a
    /// scheduling weight.
    pub fn is_active(&self) -> bool {
        self.config.enabled && !self.weights.read().is_empty()
    }

    /// Difference between the slowest and fastest weighted uplink RTTs —
    /// the spread that drives reorder-buffer sizing.
    pub fn latency_spread(&self) -> Duration {
        let weights = self.weights.read();
        if weights.is_empty() {
            return Duration::ZERO;
        }
        let min_rtt = weights
            .iter()
            .map(|w| w.rtt)
            .min()
            .unwrap_or(Duration::ZERO);
        let max_rtt = weights
            .iter()
            .map(|w| w.rtt)
            .max()
            .unwrap_or(Duration::ZERO);
        max_rtt.saturating_sub(min_rtt)
    }

    /// Suggested reorder timeout: 2.5× the RTT spread, clamped to
    /// [50 ms, 2 s].
    pub fn optimal_timeout(&self) -> Duration {
        self.latency_spread()
            .mul_f64(2.5)
            .clamp(Duration::from_millis(50), Duration::from_secs(2))
    }
}
impl std::fmt::Debug for BandwidthAggregator {
    /// Compact diagnostic view: uplink count plus the headline counters.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Take short-lived snapshots so no lock guard lives across the
        // whole formatting call.
        let uplink_count = self.weights.read().len();
        let snapshot = *self.stats.read();
        let reorder = self.reorder_buffer.lock().stats();
        f.debug_struct("BandwidthAggregator")
            .field("uplinks", &uplink_count)
            .field("packets_sent", &snapshot.packets_sent)
            .field("packets_received", &snapshot.packets_received)
            .field("reorder_stats", &reorder)
            .finish()
    }
}
/// How traffic is spread across uplinks.
/// Serialized in lowercase (`none` / `full` / `adaptive`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AggregationMode {
    /// No aggregation (default).
    #[default]
    None,
    /// Stripe traffic across all usable uplinks.
    Full,
    /// NOTE(review): semantics are not visible in this file — presumably
    /// toggles striping based on runtime conditions; confirm at call site.
    Adaptive,
}
#[cfg(test)]
mod tests {
    use super::*;

    // In-order packets pass straight through, one delivery per insert.
    #[test]
    fn test_reorder_buffer_in_order() {
        let mut buffer = ReorderBuffer::new(100, Duration::from_secs(1));
        let r1 = buffer.insert(1, vec![1], 0);
        assert_eq!(r1.len(), 1);
        let r2 = buffer.insert(2, vec![2], 0);
        assert_eq!(r2.len(), 1);
        let r3 = buffer.insert(3, vec![3], 0);
        assert_eq!(r3.len(), 1);
        let stats = buffer.stats();
        assert_eq!(stats.delivered, 3);
        assert_eq!(stats.reordered, 0);
    }

    // A packet arriving ahead of a gap is held until the gap fills, then
    // both are released together.
    #[test]
    fn test_reorder_buffer_out_of_order() {
        let mut buffer = ReorderBuffer::new(100, Duration::from_secs(1));
        buffer.insert(1, vec![1], 0);
        let r3 = buffer.insert(3, vec![3], 0);
        assert_eq!(r3.len(), 0);
        let r2 = buffer.insert(2, vec![2], 0);
        assert_eq!(r2.len(), 2);
        let stats = buffer.stats();
        assert_eq!(stats.delivered, 3);
        assert_eq!(stats.reordered, 1);
    }

    // A gap that outlives the timeout is skipped: the missing sequence is
    // counted as dropped and the buffered packets are released.
    // (Previously the poll/insert results were ignored — unused-variable
    // warnings and a weak `>= 1` assertion; now the exact counts are
    // pinned.)
    #[test]
    fn test_reorder_buffer_gap() {
        let mut buffer = ReorderBuffer::new(100, Duration::from_millis(10));
        buffer.insert(1, vec![1], 0);
        buffer.insert(3, vec![3], 0);
        buffer.insert(4, vec![4], 0);
        std::thread::sleep(Duration::from_millis(20));
        // Seq 2 is abandoned; 3 and 4 come out together.
        let ready = buffer.poll_timeout();
        assert_eq!(ready.len(), 2);
        // Delivery has resumed, so seq 5 flows straight through.
        let ready2 = buffer.insert(5, vec![5], 0);
        assert_eq!(ready2.len(), 1);
        let stats = buffer.stats();
        assert_eq!(stats.dropped, 1);
        assert_eq!(stats.delivered, 4);
    }

    // With a 2:1 bandwidth ratio the deterministic weighted round-robin
    // should schedule roughly twice as many packets on the faster uplink.
    #[test]
    fn test_weighted_selection() {
        let config = AggregatorConfig::default();
        let agg = BandwidthAggregator::new(config);
        {
            let mut weights = agg.weights.write();
            weights.push(UplinkWeight {
                uplink_id: 1,
                bandwidth: 100_000.0,
                weight: 67,
                deficit: 0,
                rtt: Duration::from_millis(10),
            });
            weights.push(UplinkWeight {
                uplink_id: 2,
                bandwidth: 50_000.0,
                weight: 33,
                deficit: 0,
                rtt: Duration::from_millis(20),
            });
        }
        let mut counts: HashMap<u16, u32> = HashMap::new();
        for i in 0..100 {
            let mut weights = agg.weights.read().clone();
            let id = agg.weighted_select(&mut weights, i);
            *counts.entry(id).or_insert(0) += 1;
        }
        let c1 = *counts.get(&1).unwrap_or(&0);
        let c2 = *counts.get(&2).unwrap_or(&0);
        assert!(
            c1 > c2,
            "Uplink 1 should get more packets: {} vs {}",
            c1,
            c2
        );
        let ratio = c1 as f64 / c2.max(1) as f64;
        assert!(
            ratio > 1.5 && ratio < 2.5,
            "Ratio should be ~2:1, got {}",
            ratio
        );
    }
}