use crate::transport::TransportKind;
/// Policy a `BandwidthAggregator` uses to spread fragments over its healthy
/// (positive-score) transports.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregationStrategy {
    /// Rotate through the transports one fragment at a time, resuming where the
    /// previous call stopped.
    RoundRobin,
    /// Give each transport a number of fragments proportional to its score.
    Weighted,
    /// Send every fragment over the single highest-scoring transport.
    LatencyOptimal,
}
/// One fragment of a payload together with the transport chosen to carry it.
#[derive(Debug, Clone)]
pub struct FragmentAssignment {
    /// Transport selected for this fragment.
    pub transport: TransportKind,
    /// The fragment's bytes.
    pub data: Vec<u8>,
    /// Zero-based position of this fragment within the original payload.
    pub fragment_index: u16,
    /// Number of fragments the payload was split into.
    pub total_fragments: u16,
}
/// Splits payloads into fixed-size fragments and assigns each fragment to a
/// transport according to an [`AggregationStrategy`].
#[derive(Debug, Clone)]
pub struct BandwidthAggregator {
    /// Strategy used by the next call to `distribute`.
    strategy: AggregationStrategy,
    /// Maximum fragment size in bytes; the constructor enforces at least 64.
    fragment_size: usize,
    /// Index (modulo the healthy-transport count) of the transport that receives
    /// the next round-robin fragment; persists across calls.
    rr_counter: usize,
}
impl BandwidthAggregator {
    /// Largest number of fragments one payload may be split into, because both
    /// `fragment_index` and `total_fragments` are `u16`.
    const MAX_FRAGMENTS: usize = u16::MAX as usize;

    /// Creates an aggregator that uses `strategy` and splits payloads into
    /// fragments of at most `fragment_size` bytes.
    ///
    /// A `fragment_size` below 64 is raised to 64, which also keeps it non-zero.
    pub fn new(strategy: AggregationStrategy, fragment_size: usize) -> Self {
        Self {
            strategy,
            fragment_size: fragment_size.max(64),
            rr_counter: 0,
        }
    }

    /// Returns the strategy used by [`distribute`](Self::distribute).
    pub fn strategy(&self) -> AggregationStrategy {
        self.strategy
    }

    /// Replaces the strategy used by subsequent [`distribute`](Self::distribute)
    /// calls. The round-robin position is left untouched.
    pub fn set_strategy(&mut self, strategy: AggregationStrategy) {
        self.strategy = strategy;
    }

    /// Splits `data` into fragments and assigns each one to a transport.
    ///
    /// Only transports whose score is strictly greater than zero are considered.
    /// Fragments are at most `fragment_size` bytes (the last may be shorter), and
    /// each carries its index and the total fragment count so the receiver can
    /// rebuild the payload with [`reassemble`](Self::reassemble).
    ///
    /// Returns an empty `Vec` when `data` is empty, when no transport has a
    /// positive score, or when `data` would need more than `u16::MAX` fragments
    /// (the `u16` fragment fields could not represent them).
    pub fn distribute(
        &mut self,
        data: &[u8],
        transport_scores: &[(TransportKind, f64)],
    ) -> Vec<FragmentAssignment> {
        // `> 0.0` also rejects NaN scores.
        let healthy: Vec<(TransportKind, f64)> = transport_scores
            .iter()
            .filter(|(_, s)| *s > 0.0)
            .copied()
            .collect();
        if healthy.is_empty() || data.is_empty() {
            return Vec::new();
        }
        // Reject up front rather than letting the `u16` index/total wrap, which
        // would yield fragments that can never be reassembled correctly.
        if data.chunks(self.fragment_size).len() > Self::MAX_FRAGMENTS {
            return Vec::new();
        }
        match self.strategy {
            AggregationStrategy::RoundRobin => self.distribute_round_robin(data, &healthy),
            AggregationStrategy::Weighted => self.distribute_weighted(data, &healthy),
            AggregationStrategy::LatencyOptimal => self.distribute_latency_optimal(data, &healthy),
        }
    }

    /// Cuts `data` into `fragment_size`-byte chunks and builds one assignment per
    /// chunk, asking `pick` for the transport of the chunk at each index.
    ///
    /// Callers must already have checked the chunk count against `MAX_FRAGMENTS`.
    fn build_assignments(
        &self,
        data: &[u8],
        mut pick: impl FnMut(usize) -> TransportKind,
    ) -> Vec<FragmentAssignment> {
        let chunks = data.chunks(self.fragment_size);
        debug_assert!(chunks.len() <= Self::MAX_FRAGMENTS);
        // Cannot truncate: `distribute` caps the count at `u16::MAX`.
        let total = chunks.len() as u16;
        chunks
            .enumerate()
            .map(|(i, chunk)| FragmentAssignment {
                transport: pick(i),
                data: chunk.to_vec(),
                fragment_index: i as u16, // i < total <= u16::MAX
                total_fragments: total,
            })
            .collect()
    }

    /// Assigns fragments to `transports` in rotation, starting from the stored
    /// round-robin position and saving the next position for the following call.
    fn distribute_round_robin(
        &mut self,
        data: &[u8],
        transports: &[(TransportKind, f64)],
    ) -> Vec<FragmentAssignment> {
        let start = self.rr_counter;
        let count = transports.len();
        let assignments = self.build_assignments(data, |i| transports[(start + i) % count].0);
        self.rr_counter = (start + assignments.len()) % count;
        assignments
    }

    /// Gives each transport a contiguous run of fragments whose length is
    /// proportional to its score; runs follow the order of `transports`.
    fn distribute_weighted(
        &self,
        data: &[u8],
        transports: &[(TransportKind, f64)],
    ) -> Vec<FragmentAssignment> {
        let fragment_count = data.chunks(self.fragment_size).len();
        // owners[i] is the index (into `transports`) of fragment i's transport.
        let mut owners: Vec<usize> = Vec::with_capacity(fragment_count);
        for (t, run) in Self::apportion(fragment_count, transports)
            .into_iter()
            .enumerate()
        {
            owners.resize(owners.len() + run, t);
        }
        self.build_assignments(data, |i| transports[owners[i]].0)
    }

    /// Splits `n` fragments among `transports` in proportion to their scores
    /// using the largest-remainder method.
    ///
    /// Each transport first receives the whole part of its exact share; the
    /// fragments left over go one each to the transports with the largest
    /// fractional parts (ties favour the earlier transport). The counts always
    /// sum to `n` when `transports` is non-empty.
    fn apportion(n: usize, transports: &[(TransportKind, f64)]) -> Vec<usize> {
        let total_score: f64 = transports.iter().map(|(_, s)| s).sum();
        let quotas: Vec<f64> = transports
            .iter()
            .map(|(_, s)| n as f64 * (s / total_score))
            .collect();

        let mut remaining = n;
        let mut counts: Vec<usize> = Vec::with_capacity(quotas.len());
        for quota in &quotas {
            // Float-to-int `as` saturates (NaN -> 0); `min` keeps the running sum
            // at or below `n` even if rounding nudges the quotas above it.
            let whole = (quota.floor() as usize).min(remaining);
            counts.push(whole);
            remaining -= whole;
        }

        let mut order: Vec<usize> = (0..quotas.len()).collect();
        // Stable sort, largest fractional part first.
        order.sort_by(|&a, &b| {
            let frac = |q: f64| q - q.floor();
            frac(quotas[b]).total_cmp(&frac(quotas[a]))
        });
        // Normally fewer fragments remain than there are transports; `cycle` only
        // matters for degenerate (non-finite) scores.
        for &t in order.iter().cycle().take(remaining) {
            counts[t] += 1;
        }
        counts
    }

    /// Sends every fragment over the highest-scoring transport; when several
    /// share the top score, the last one listed wins.
    fn distribute_latency_optimal(
        &self,
        data: &[u8],
        transports: &[(TransportKind, f64)],
    ) -> Vec<FragmentAssignment> {
        let best = match transports
            .iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
        {
            Some(&(kind, _)) => kind,
            None => return Vec::new(),
        };
        self.build_assignments(data, |_| best)
    }

    /// Rebuilds a payload from `(fragment_index, bytes)` pairs received in any
    /// order.
    ///
    /// Returns `None` unless `fragments` contains every index in
    /// `0..expected_total` exactly once; missing, duplicated, or out-of-range
    /// indices are all rejected, as is `expected_total == 0`.
    pub fn reassemble(fragments: &[(u16, &[u8])], expected_total: u16) -> Option<Vec<u8>> {
        if expected_total == 0 || fragments.len() != usize::from(expected_total) {
            return None;
        }
        let mut sorted: Vec<(u16, &[u8])> = fragments.to_vec();
        sorted.sort_unstable_by_key(|&(idx, _)| idx);
        // With exactly `expected_total` entries, the set is complete iff the
        // sorted indices read 0, 1, 2, ... This catches duplicates (which the
        // length check alone lets through) as well as out-of-range indices.
        let complete = sorted
            .iter()
            .enumerate()
            .all(|(pos, &(idx, _))| usize::from(idx) == pos);
        if !complete {
            return None;
        }
        let total_len: usize = sorted.iter().map(|(_, bytes)| bytes.len()).sum();
        let mut payload = Vec::with_capacity(total_len);
        for (_, bytes) in sorted {
            payload.extend_from_slice(bytes);
        }
        Some(payload)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the transport chosen for each fragment, in fragment order.
    fn transports_of(frags: &[FragmentAssignment]) -> Vec<TransportKind> {
        frags.iter().map(|f| f.transport).collect()
    }

    #[test]
    fn single_transport_passthrough() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::RoundRobin, 1024);
        let data = vec![0u8; 100];
        let scores = vec![(TransportKind::Tcp, 1.0)];
        let frags = agg.distribute(&data, &scores);
        assert_eq!(frags.len(), 1);
        assert_eq!(frags[0].transport, TransportKind::Tcp);
        assert_eq!(frags[0].data.len(), 100);
        assert_eq!(frags[0].fragment_index, 0);
        assert_eq!(frags[0].total_fragments, 1);
    }

    #[test]
    fn round_robin_distributes_evenly() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::RoundRobin, 100);
        let data = vec![0u8; 300];
        let scores = vec![
            (TransportKind::Tcp, 1.0),
            (TransportKind::Udp, 1.0),
            (TransportKind::Quic, 1.0),
        ];
        let frags = agg.distribute(&data, &scores);
        assert_eq!(
            transports_of(&frags),
            vec![TransportKind::Tcp, TransportKind::Udp, TransportKind::Quic]
        );
    }

    #[test]
    fn round_robin_resumes_across_calls() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::RoundRobin, 64);
        let scores = vec![(TransportKind::Tcp, 1.0), (TransportKind::Udp, 1.0)];
        let first = agg.distribute(&[1u8; 64], &scores);
        let second = agg.distribute(&[2u8; 64], &scores);
        assert_eq!(transports_of(&first), vec![TransportKind::Tcp]);
        assert_eq!(transports_of(&second), vec![TransportKind::Udp]);
    }

    #[test]
    fn weighted_distribution_favors_healthy() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::Weighted, 100);
        let data = vec![0u8; 1000];
        let scores = vec![(TransportKind::Tcp, 0.9), (TransportKind::Udp, 0.1)];
        let frags = agg.distribute(&data, &scores);
        assert_eq!(frags.len(), 10);
        let tcp_count = frags
            .iter()
            .filter(|f| f.transport == TransportKind::Tcp)
            .count();
        let udp_count = frags
            .iter()
            .filter(|f| f.transport == TransportKind::Udp)
            .count();
        assert_eq!(tcp_count + udp_count, 10);
        assert!(
            tcp_count > udp_count,
            "TCP={}, UDP={}",
            tcp_count,
            udp_count
        );
    }

    #[test]
    fn latency_optimal_uses_best() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::LatencyOptimal, 100);
        let data = vec![0u8; 300];
        let scores = vec![
            (TransportKind::Tcp, 0.5),
            (TransportKind::Udp, 0.9),
            (TransportKind::Quic, 0.3),
        ];
        let frags = agg.distribute(&data, &scores);
        // Guard against a vacuous pass on an empty result.
        assert_eq!(frags.len(), 3);
        for f in &frags {
            assert_eq!(f.transport, TransportKind::Udp);
        }
    }

    #[test]
    fn zero_score_transports_excluded() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::RoundRobin, 100);
        let data = vec![0u8; 200];
        let scores = vec![(TransportKind::Tcp, 0.0), (TransportKind::Udp, 1.0)];
        let frags = agg.distribute(&data, &scores);
        // Guard against a vacuous pass on an empty result.
        assert_eq!(frags.len(), 2);
        for f in &frags {
            assert_eq!(f.transport, TransportKind::Udp);
        }
    }

    #[test]
    fn negative_and_nan_scores_excluded() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::RoundRobin, 64);
        let scores = vec![
            (TransportKind::Tcp, -1.0),
            (TransportKind::Udp, f64::NAN),
            (TransportKind::Quic, 1.0),
        ];
        let frags = agg.distribute(&[0u8; 128], &scores);
        assert_eq!(
            transports_of(&frags),
            vec![TransportKind::Quic, TransportKind::Quic]
        );
    }

    #[test]
    fn empty_data_returns_empty() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::RoundRobin, 100);
        let frags = agg.distribute(&[], &[(TransportKind::Tcp, 1.0)]);
        assert!(frags.is_empty());
    }

    #[test]
    fn no_healthy_transports_returns_empty() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::RoundRobin, 100);
        let frags = agg.distribute(&[1, 2, 3], &[(TransportKind::Tcp, 0.0)]);
        assert!(frags.is_empty());
    }

    #[test]
    fn fragment_indices_correct() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::RoundRobin, 64);
        let data = vec![0u8; 200];
        let scores = vec![(TransportKind::Tcp, 1.0)];
        let frags = agg.distribute(&data, &scores);
        assert_eq!(frags.len(), 4);
        for (i, f) in frags.iter().enumerate() {
            assert_eq!(f.fragment_index, i as u16);
            assert_eq!(f.total_fragments, 4);
        }
    }

    #[test]
    fn fragment_size_has_minimum_of_64() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::RoundRobin, 0);
        let frags = agg.distribute(&[0u8; 100], &[(TransportKind::Tcp, 1.0)]);
        let sizes: Vec<usize> = frags.iter().map(|f| f.data.len()).collect();
        assert_eq!(sizes, vec![64, 36]);
    }

    #[test]
    fn set_strategy_takes_effect() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::RoundRobin, 64);
        agg.set_strategy(AggregationStrategy::LatencyOptimal);
        assert_eq!(agg.strategy(), AggregationStrategy::LatencyOptimal);
        let frags = agg.distribute(
            &[0u8; 128],
            &[(TransportKind::Tcp, 0.2), (TransportKind::Udp, 0.8)],
        );
        assert_eq!(
            transports_of(&frags),
            vec![TransportKind::Udp, TransportKind::Udp]
        );
    }

    #[test]
    fn reassemble_roundtrip() {
        let mut agg = BandwidthAggregator::new(AggregationStrategy::RoundRobin, 100);
        let original = vec![0xABu8; 300];
        let scores = vec![
            (TransportKind::Tcp, 1.0),
            (TransportKind::Udp, 1.0),
            (TransportKind::Quic, 1.0),
        ];
        let frags = agg.distribute(&original, &scores);
        let total = frags[0].total_fragments;
        let fragments: Vec<(u16, &[u8])> = frags
            .iter()
            .map(|f| (f.fragment_index, f.data.as_slice()))
            .collect();
        let reassembled = BandwidthAggregator::reassemble(&fragments, total).unwrap();
        assert_eq!(reassembled, original);
    }

    #[test]
    fn reassemble_out_of_order() {
        let fragments: Vec<(u16, &[u8])> = vec![(2, b"o"), (0, b"he"), (1, b"ll")];
        let result = BandwidthAggregator::reassemble(&fragments, 3).unwrap();
        assert_eq!(result, b"hello");
    }

    #[test]
    fn reassemble_incomplete_returns_none() {
        let fragments: Vec<(u16, &[u8])> = vec![(0, b"he"), (2, b"o")];
        assert!(BandwidthAggregator::reassemble(&fragments, 3).is_none());
    }

    #[test]
    fn reassemble_invalid_index_returns_none() {
        let fragments: Vec<(u16, &[u8])> = vec![(0, b"a"), (5, b"b")];
        assert!(BandwidthAggregator::reassemble(&fragments, 3).is_none());
    }
}