use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use super::throughput::{EffectiveThroughput, ThroughputConfig};
use super::Uplink;
/// Path-selection strategy used by the [`Scheduler`] to decide which
/// uplink(s) should carry a packet. Serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SchedulingStrategy {
/// Weighted round-robin over each uplink's configured `weight`.
WeightedRoundRobin,
/// Always pick the sendable uplink with the lowest RTT.
LowestLatency,
/// Always pick the sendable uplink with the lowest loss ratio.
LowestLoss,
/// Composite score of RTT, loss, bandwidth, and NAT penalty (plus an
/// effective-throughput term when `throughput_aware` is enabled).
#[default]
Adaptive,
/// Duplicate traffic across every sendable uplink.
Redundant,
/// Top two uplinks by priority score: a primary plus one backup.
PrimaryBackup,
/// Random pick with probability proportional to uplink bandwidth.
BandwidthProportional,
/// Deterministic per-flow pick (flow id modulo uplink count), so ECMP
/// middleboxes see a stable path per flow.
EcmpAware,
/// Pick the uplink with the best effective-throughput score.
EffectiveThroughput,
/// Pick the uplink with the shortest estimated transfer time for the
/// payload size.
LatencyAware,
/// Small payloads go to the lowest-latency path, large ones to the best
/// effective throughput (split at `size_threshold_bytes`).
SizeBased,
}
/// Tunable parameters for the multipath [`Scheduler`].
///
/// Every field has a serde default, so a partial (or empty) configuration
/// deserializes into a complete, working one.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SchedulerConfig {
/// Path-selection strategy (default: `Adaptive`).
#[serde(default)]
pub strategy: SchedulingStrategy,
/// RTT threshold in milliseconds (default: 10). Not referenced in this
/// module — presumably consumed by callers; verify before relying on it.
#[serde(default = "default_rtt_threshold")]
pub rtt_threshold_ms: u32,
/// Loss threshold in percent (default: 2.0). Not referenced in this
/// module — presumably consumed by callers.
#[serde(default = "default_loss_threshold")]
pub loss_threshold_percent: f32,
/// Weight of the RTT term in the adaptive score (default: 0.25).
#[serde(default = "default_rtt_weight")]
pub rtt_weight: f32,
/// Weight of the loss term in the adaptive score (default: 0.25).
#[serde(default = "default_loss_weight")]
pub loss_weight: f32,
/// Weight of the bandwidth term in the adaptive score (default: 0.35).
#[serde(default = "default_bw_weight")]
pub bandwidth_weight: f32,
/// Weight of the NAT-penalty term in the adaptive score (default: 0.05).
#[serde(default = "default_nat_weight")]
pub nat_penalty_weight: f32,
/// Keep each flow pinned to its previously selected uplink (default: true).
#[serde(default = "default_sticky")]
pub sticky_paths: bool,
/// How long a sticky flow→uplink binding stays valid (default: 5s).
#[serde(default = "default_sticky_timeout", with = "humantime_serde")]
pub sticky_timeout: Duration,
/// Whether backup paths should be probed periodically (default: true).
#[serde(default = "default_probe")]
pub probe_backup_paths: bool,
/// Minimum interval between probes of the same uplink (default: 1s).
#[serde(default = "default_probe_interval", with = "humantime_serde")]
pub probe_interval: Duration,
/// Parameters for the effective-throughput model.
#[serde(default)]
pub throughput: ThroughputConfig,
/// Maximum acceptable latency in milliseconds (default: 500). Not
/// referenced in this module — presumably consumed by callers.
#[serde(default = "default_max_latency")]
pub max_acceptable_latency_ms: u32,
/// Payload size splitting "small" from "large" for `SizeBased`
/// selection (default: 64 KiB).
#[serde(default = "default_size_threshold")]
pub size_threshold_bytes: u64,
/// Add an effective-throughput term to the adaptive score (default: true).
#[serde(default = "default_throughput_aware")]
pub throughput_aware: bool,
/// Weight of that effective-throughput term (default: 0.10).
#[serde(default = "default_effective_throughput_weight")]
pub effective_throughput_weight: f32,
/// Exclude paths whose RTT dwarfs the best path's RTT (default: true).
#[serde(default = "default_prevent_latency_blocking")]
pub prevent_latency_blocking: bool,
/// A path is excluded when its RTT exceeds this multiple of the minimum
/// RTT across usable paths (default: 10.0).
#[serde(default = "default_latency_blocking_ratio")]
pub latency_blocking_ratio: f32,
}
// Serde default providers for `SchedulerConfig`. Kept as free functions
// because `#[serde(default = "...")]` requires a function path; values must
// stay in sync with the documented defaults on the struct fields.
fn default_rtt_threshold() -> u32 { 10 }
fn default_loss_threshold() -> f32 { 2.0 }
fn default_rtt_weight() -> f32 { 0.25 }
fn default_loss_weight() -> f32 { 0.25 }
fn default_bw_weight() -> f32 { 0.35 }
fn default_nat_weight() -> f32 { 0.05 }
fn default_sticky() -> bool { true }
fn default_sticky_timeout() -> Duration { Duration::from_secs(5) }
fn default_probe() -> bool { true }
fn default_probe_interval() -> Duration { Duration::from_secs(1) }
fn default_max_latency() -> u32 { 500 }
fn default_size_threshold() -> u64 { 64 * 1024 }
fn default_throughput_aware() -> bool { true }
fn default_effective_throughput_weight() -> f32 { 0.10 }
fn default_prevent_latency_blocking() -> bool { true }
fn default_latency_blocking_ratio() -> f32 { 10.0 }
impl Default for SchedulerConfig {
/// Builds the configuration from the same `default_*` providers that
/// serde uses, so hand-constructed and deserialized defaults agree.
fn default() -> Self {
Self {
strategy: SchedulingStrategy::default(),
rtt_threshold_ms: default_rtt_threshold(),
loss_threshold_percent: default_loss_threshold(),
rtt_weight: default_rtt_weight(),
loss_weight: default_loss_weight(),
bandwidth_weight: default_bw_weight(),
nat_penalty_weight: default_nat_weight(),
sticky_paths: default_sticky(),
sticky_timeout: default_sticky_timeout(),
probe_backup_paths: default_probe(),
probe_interval: default_probe_interval(),
throughput: ThroughputConfig::default(),
max_acceptable_latency_ms: default_max_latency(),
size_threshold_bytes: default_size_threshold(),
throughput_aware: default_throughput_aware(),
effective_throughput_weight: default_effective_throughput_weight(),
prevent_latency_blocking: default_prevent_latency_blocking(),
latency_blocking_ratio: default_latency_blocking_ratio(),
}
}
}
/// Mutable cursor for the weighted-round-robin strategy
/// (see `Scheduler::select_wrr`).
#[derive(Debug, Default)]
struct WrrState {
// Index of the most recently visited uplink in the candidate list.
current_index: usize,
// Current weight threshold; an uplink is eligible when its configured
// weight is >= this value. Decreases each full pass, resets to the max.
current_weight: u32,
}
/// Flow-to-uplink affinity table: remembers which uplink each flow last
/// used so subsequent packets of the same flow keep taking the same path.
#[derive(Debug)]
struct PathStickiness {
    // flow id -> (uplink id, moment the binding was last refreshed)
    flows: HashMap<u64, (u16, Instant)>,
}
impl PathStickiness {
    /// Creates an empty affinity table.
    fn new() -> Self {
        PathStickiness {
            flows: HashMap::default(),
        }
    }
    /// Returns the uplink bound to `flow_id`, unless the binding is older
    /// than `timeout` (stale bindings are ignored, not removed here).
    fn get(&self, flow_id: u64, timeout: Duration) -> Option<u16> {
        match self.flows.get(&flow_id) {
            Some((uplink_id, stamp)) if stamp.elapsed() < timeout => Some(*uplink_id),
            _ => None,
        }
    }
    /// Binds `flow_id` to `uplink_id`, refreshing its timestamp.
    fn set(&mut self, flow_id: u64, uplink_id: u16) {
        self.flows.insert(flow_id, (uplink_id, Instant::now()));
    }
    /// Drops every binding that has not been refreshed within `timeout`.
    fn cleanup(&mut self, timeout: Duration) {
        self.flows.retain(|_, (_, stamp)| stamp.elapsed() < timeout);
    }
}
/// Multipath packet scheduler: selects which uplink(s) carry traffic
/// according to the configured [`SchedulingStrategy`].
pub struct Scheduler {
// Configuration captured at construction; never mutated afterwards.
config: SchedulerConfig,
// Cursor for the weighted-round-robin strategy.
wrr_state: RwLock<WrrState>,
// Per-flow path affinity, consulted when `sticky_paths` is enabled.
stickiness: RwLock<PathStickiness>,
// Last probe time per uplink id, for probe-interval throttling.
last_probe: RwLock<HashMap<u16, Instant>>,
// Cached effective-throughput results per uplink id (insert time, value).
throughput_cache: RwLock<HashMap<u16, (Instant, EffectiveThroughput)>>,
// Validity window for throughput cache entries.
cache_ttl: Duration,
}
impl Scheduler {
/// Creates a scheduler with the given configuration and fresh runtime state.
pub fn new(config: SchedulerConfig) -> Self {
Self {
config,
wrr_state: RwLock::new(WrrState::default()),
stickiness: RwLock::new(PathStickiness::new()),
last_probe: RwLock::new(HashMap::new()),
throughput_cache: RwLock::new(HashMap::new()),
// Effective-throughput results are reused for 100 ms before being
// recomputed (see `calculate_effective_throughput`).
cache_ttl: Duration::from_millis(100),
}
}
/// Returns the active configuration.
pub fn config(&self) -> &SchedulerConfig {
&self.config
}
/// Selects uplink id(s) for a packet of unspecified size.
/// Shorthand for [`Self::select_for_size`] with `size_bytes = None`.
pub fn select(&self, uplinks: &[Arc<Uplink>], flow_id: Option<u64>) -> Vec<u16> {
self.select_for_size(uplinks, flow_id, None)
}
/// Selects the uplink id(s) that should carry the next packet.
///
/// Pipeline: (1) narrow to usable uplinks, optionally dropping
/// latency-blocked paths; (2) if stickiness is on and the flow has a
/// fresh binding to a still-usable uplink, reuse it; (3) otherwise
/// dispatch to the configured strategy; (4) record the primary choice
/// for future stickiness. Returns an empty vec when nothing is usable.
/// `Redundant` may return many ids, `PrimaryBackup` up to two, the rest
/// at most one.
pub fn select_for_size(
&self,
uplinks: &[Arc<Uplink>],
flow_id: Option<u64>,
size_bytes: Option<u64>,
) -> Vec<u16> {
let usable: Vec<_> = if self.config.prevent_latency_blocking {
self.filter_latency_blocked(uplinks)
} else {
uplinks.iter().filter(|u| u.is_usable()).collect()
};
if usable.is_empty() {
return vec![];
}
// Sticky fast path: keep the flow on its previous uplink while the
// binding is within `sticky_timeout` and the uplink is still usable.
if self.config.sticky_paths {
if let Some(flow) = flow_id {
let sticky = self.stickiness.read().get(flow, self.config.sticky_timeout);
if let Some(sticky_uplink) = sticky {
if usable.iter().any(|u| u.numeric_id() == sticky_uplink) {
return vec![sticky_uplink];
}
}
}
}
let selected = match self.config.strategy {
SchedulingStrategy::WeightedRoundRobin => self.select_wrr(&usable),
SchedulingStrategy::LowestLatency => Self::select_lowest_latency(&usable),
SchedulingStrategy::LowestLoss => Self::select_lowest_loss(&usable),
SchedulingStrategy::Adaptive => self.select_adaptive(&usable),
SchedulingStrategy::Redundant => Self::select_redundant(&usable),
SchedulingStrategy::PrimaryBackup => Self::select_primary_backup(&usable),
SchedulingStrategy::BandwidthProportional => {
self.select_bandwidth_proportional(&usable)
}
SchedulingStrategy::EcmpAware => Self::select_ecmp_aware(&usable, flow_id),
SchedulingStrategy::EffectiveThroughput => self.select_effective_throughput(&usable),
SchedulingStrategy::LatencyAware => self.select_latency_aware(&usable, size_bytes),
SchedulingStrategy::SizeBased => self.select_size_based(&usable, size_bytes),
};
// Remember the primary pick so the next packets of this flow stick to it.
if self.config.sticky_paths && !selected.is_empty() {
if let Some(flow) = flow_id {
self.stickiness.write().set(flow, selected[0]);
}
}
selected
}
/// Returns the usable uplinks, excluding any whose RTT exceeds
/// `latency_blocking_ratio` times the minimum RTT of the set. Filtering
/// is skipped when the minimum RTT is zero (metrics presumably not yet
/// measured).
fn filter_latency_blocked<'a>(&self, uplinks: &'a [Arc<Uplink>]) -> Vec<&'a Arc<Uplink>> {
let usable: Vec<_> = uplinks.iter().filter(|u| u.is_usable()).collect();
if usable.is_empty() {
return usable;
}
let min_rtt = usable
.iter()
.map(|u| u.rtt())
.min()
.unwrap_or(Duration::ZERO);
if min_rtt == Duration::ZERO {
return usable;
}
let threshold = Duration::from_secs_f64(
min_rtt.as_secs_f64() * self.config.latency_blocking_ratio as f64,
);
usable
.into_iter()
.filter(|u| u.rtt() <= threshold)
.collect()
}
/// Weighted round-robin: sweeps the candidate list with a decreasing
/// weight threshold and returns the first sendable uplink whose
/// configured weight meets the current threshold. Falls back to the
/// first candidate when a full sweep finds nothing sendable.
fn select_wrr(&self, uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
if uplinks.is_empty() {
return vec![];
}
// Holds the write lock for the whole scan so the cursor is advanced
// atomically per selection.
let mut state = self.wrr_state.write();
let max_weight: u32 = uplinks.iter().map(|u| u.config().weight).max().unwrap_or(1);
loop {
state.current_index = (state.current_index + 1) % uplinks.len();
// Each completed pass lowers the threshold; from 0 it resets to max,
// which is what gives higher-weight uplinks more turns.
if state.current_index == 0 {
if state.current_weight == 0 {
state.current_weight = max_weight;
} else {
state.current_weight -= 1;
}
}
let uplink = &uplinks[state.current_index];
if uplink.config().weight >= state.current_weight && uplink.can_send() {
return vec![uplink.numeric_id()];
}
// Threshold exhausted at the end of a pass: nothing was sendable.
if state.current_weight == 0 && state.current_index == 0 {
break;
}
}
uplinks
.first()
.map(|u| vec![u.numeric_id()])
.unwrap_or_default()
}
/// Picks the single sendable uplink with the lowest RTT.
fn select_lowest_latency(uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
uplinks
.iter()
.filter(|u| u.can_send())
.min_by_key(|u| u.rtt())
.map(|u| vec![u.numeric_id()])
.unwrap_or_default()
}
/// Picks the single sendable uplink with the lowest loss ratio.
/// `partial_cmp` is used because loss ratios are floats; incomparable
/// (NaN) pairs are treated as equal.
fn select_lowest_loss(uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
uplinks
.iter()
.filter(|u| u.can_send())
.min_by(|a, b| {
a.loss_ratio()
.partial_cmp(&b.loss_ratio())
.unwrap_or(std::cmp::Ordering::Equal)
})
.map(|u| vec![u.numeric_id()])
.unwrap_or_default()
}
/// Scores every sendable uplink as a weighted sum of RTT, loss,
/// bandwidth, and NAT terms — plus an effective-throughput term when
/// `throughput_aware` is set — and returns the single highest scorer.
fn select_adaptive(&self, uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
let mut scored: Vec<_> = uplinks
.iter()
.filter(|u| u.can_send())
.map(|u| {
let rtt_score = Self::rtt_score(u);
let loss_score = Self::loss_score(u);
let bw_score = Self::bandwidth_score(u, uplinks);
let nat_score = Self::nat_score(u);
let mut total_score = rtt_score * self.config.rtt_weight
+ loss_score * self.config.loss_weight
+ bw_score * self.config.bandwidth_weight
+ nat_score * self.config.nat_penalty_weight;
if self.config.throughput_aware {
let eff_throughput = self.calculate_effective_throughput(u);
total_score +=
eff_throughput.score as f32 * self.config.effective_throughput_weight;
}
(u.numeric_id(), total_score)
})
.collect();
// Descending by score; NaN comparisons fall back to Equal.
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.into_iter().map(|(id, _)| id).take(1).collect()
}
/// Picks the single sendable uplink with the best effective-throughput
/// score (cached per uplink, see `calculate_effective_throughput`).
fn select_effective_throughput(&self, uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
let mut scored: Vec<_> = uplinks
.iter()
.filter(|u| u.can_send())
.map(|u| {
let throughput = self.calculate_effective_throughput(u);
(u.numeric_id(), throughput.score)
})
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.into_iter().map(|(id, _)| id).take(1).collect()
}
/// Picks the uplink with the shortest estimated time to transfer
/// `size_bytes` (defaulting to the configured size threshold).
fn select_latency_aware(&self, uplinks: &[&Arc<Uplink>], size_bytes: Option<u64>) -> Vec<u16> {
let size = size_bytes.unwrap_or(self.config.size_threshold_bytes);
let mut scored: Vec<_> = uplinks
.iter()
.filter(|u| u.can_send())
.map(|u| {
let throughput = self.calculate_effective_throughput(u);
let transfer_time = throughput.transfer_time(size);
(u.numeric_id(), transfer_time)
})
.collect();
// Ascending: shortest transfer time wins.
scored.sort_by(|a, b| a.1.cmp(&b.1));
scored.into_iter().map(|(id, _)| id).take(1).collect()
}
/// Routes small payloads (below `size_threshold_bytes`) to the
/// lowest-latency path and larger ones to the best effective throughput.
fn select_size_based(&self, uplinks: &[&Arc<Uplink>], size_bytes: Option<u64>) -> Vec<u16> {
let size = size_bytes.unwrap_or(self.config.size_threshold_bytes);
if size < self.config.size_threshold_bytes {
Self::select_lowest_latency(uplinks)
} else {
self.select_effective_throughput(uplinks)
}
}
/// Returns the effective throughput for an uplink, reusing a cached
/// value younger than `cache_ttl`. The read lock is dropped before the
/// recomputation; two threads may race and both compute, which is
/// harmless (last write wins).
fn calculate_effective_throughput(&self, uplink: &Uplink) -> EffectiveThroughput {
let uplink_id = uplink.numeric_id();
{
let cache = self.throughput_cache.read();
if let Some((cached_at, throughput)) = cache.get(&uplink_id) {
if cached_at.elapsed() < self.cache_ttl {
return *throughput;
}
}
}
let metrics = uplink.quality_metrics();
let throughput = EffectiveThroughput::calculate(
uplink.bandwidth().bytes_per_sec,
uplink.rtt(),
uplink.loss_ratio(),
metrics.jitter,
&self.config.throughput,
);
{
let mut cache = self.throughput_cache.write();
cache.insert(uplink_id, (Instant::now(), throughput));
}
throughput
}
/// Returns every sendable uplink, for redundant duplication.
fn select_redundant(uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
uplinks
.iter()
.filter(|u| u.can_send())
.map(|u| u.numeric_id())
.collect()
}
/// Returns up to two sendable uplinks, ordered by descending priority
/// score: the primary and one backup.
fn select_primary_backup(uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
let mut sorted: Vec<_> = uplinks.iter().collect();
sorted.sort_by_key(|u| std::cmp::Reverse(u.priority_score()));
let mut result = Vec::new();
for uplink in sorted {
if uplink.can_send() {
result.push(uplink.numeric_id());
if result.len() >= 2 {
break;
}
}
}
result
}
/// Roulette-wheel selection: picks a sendable uplink with probability
/// proportional to its bandwidth share. Falls back to WRR when total
/// bandwidth is zero, and to the first candidate if rounding leaves the
/// random draw unmatched.
fn select_bandwidth_proportional(&self, uplinks: &[&Arc<Uplink>]) -> Vec<u16> {
let total_bw: f64 = uplinks
.iter()
.filter(|u| u.can_send())
.map(|u| u.bandwidth().bytes_per_sec)
.sum();
if total_bw == 0.0 {
return self.select_wrr(uplinks);
}
// Uniform draw in [0, 1); walk the cumulative bandwidth distribution.
let r: f64 = rand::random();
let mut cumulative = 0.0;
for uplink in uplinks.iter().filter(|u| u.can_send()) {
cumulative += uplink.bandwidth().bytes_per_sec / total_bw;
if r <= cumulative {
return vec![uplink.numeric_id()];
}
}
uplinks
.first()
.map(|u| vec![u.numeric_id()])
.unwrap_or_default()
}
/// With a flow id: deterministic flow-to-uplink mapping (id modulo the
/// sendable count) so a given flow always maps to the same path.
/// Without one: the sendable uplink with the highest priority score.
fn select_ecmp_aware(uplinks: &[&Arc<Uplink>], flow_id: Option<u64>) -> Vec<u16> {
let sendable: Vec<_> = uplinks.iter().filter(|u| u.can_send()).collect();
if sendable.is_empty() {
return vec![];
}
if let Some(flow) = flow_id {
let index = (flow as usize) % sendable.len();
return vec![sendable[index].numeric_id()];
}
sendable
.iter()
.max_by(|a, b| {
let score_a = a.priority_score();
let score_b = b.priority_score();
score_a.cmp(&score_b)
})
.map(|u| vec![u.numeric_id()])
.unwrap_or_default()
}
// Maps RTT (in ms) into (0, 1]: 1.0 at zero RTT, 0.5 at 50 ms,
// decaying toward 0 as RTT grows.
fn rtt_score(uplink: &Uplink) -> f32 {
let rtt = uplink.rtt().as_secs_f32() * 1000.0; 1.0 / (1.0 + rtt / 50.0)
}
// Maps loss ratio into [0, 1]: 1.0 for lossless, 0.0 at 100% loss.
fn loss_score(uplink: &Uplink) -> f32 {
let loss = uplink.loss_ratio() as f32;
1.0 - loss.min(1.0)
}
// Bandwidth relative to the best candidate in the set; 0.5 (neutral)
// when every candidate reports zero bandwidth.
fn bandwidth_score(uplink: &Uplink, all: &[&Arc<Uplink>]) -> f32 {
let bw = uplink.bandwidth().bytes_per_sec;
let max_bw: f64 = all
.iter()
.map(|u| u.bandwidth().bytes_per_sec)
.max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.unwrap_or(1.0);
if max_bw == 0.0 {
0.5
} else {
(bw / max_bw) as f32
}
}
// Penalty factor by NAT type: 1.0 for no NAT, shrinking as the NAT gets
// more restrictive (symmetric NAT scores lowest).
fn nat_score(uplink: &Uplink) -> f32 {
if !uplink.is_natted() {
return 1.0;
}
match uplink.nat_type() {
super::nat::NatType::None => 1.0,
super::nat::NatType::Unknown => 0.7, super::nat::NatType::FullCone => 0.8, super::nat::NatType::RestrictedCone => 0.6,
super::nat::NatType::PortRestrictedCone => 0.4,
super::nat::NatType::Symmetric => 0.2, }
}
/// Whether `uplink` is due for a probe: probing must be enabled and the
/// last recorded probe (if any) older than `probe_interval`.
pub fn needs_probe(&self, uplink: &Uplink) -> bool {
if !self.config.probe_backup_paths {
return false;
}
let probes = self.last_probe.read();
match probes.get(&uplink.numeric_id()) {
Some(last) => last.elapsed() >= self.config.probe_interval,
None => true,
}
}
/// Records that `uplink_id` was probed just now.
pub fn record_probe(&self, uplink_id: u16) {
self.last_probe.write().insert(uplink_id, Instant::now());
}
/// Evicts stale runtime state: expired sticky bindings, probe records
/// older than 10x the probe interval, and throughput cache entries older
/// than 10x the cache TTL.
pub fn cleanup(&self) {
self.stickiness.write().cleanup(self.config.sticky_timeout);
let timeout = self.config.probe_interval * 10;
self.last_probe
.write()
.retain(|_, last| last.elapsed() < timeout);
self.throughput_cache
.write()
.retain(|_, (cached_at, _)| cached_at.elapsed() < self.cache_ttl * 10);
}
/// Returns the ids of usable uplinks that are due for a probe.
pub fn uplinks_to_probe(&self, uplinks: &[Arc<Uplink>]) -> Vec<u16> {
uplinks
.iter()
.filter(|u| u.is_usable() && self.needs_probe(u))
.map(|u| u.numeric_id())
.collect()
}
}
#[allow(clippy::missing_fields_in_debug)]
impl std::fmt::Debug for Scheduler {
    /// Compact debug form: lock-guarded runtime state is deliberately
    /// omitted; only the configured strategy is shown.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Scheduler");
        builder.field("strategy", &self.config.strategy);
        builder.finish()
    }
}
#[cfg(test)]
mod tests {
use super::*;
// A freshly built scheduler with default config must use the Adaptive
// strategy (the enum's #[default] variant).
#[test]
fn test_scheduler_creation() {
let scheduler = Scheduler::new(SchedulerConfig::default());
assert_eq!(scheduler.config.strategy, SchedulingStrategy::Adaptive);
}
// Selection over an empty uplink list must yield no ids, not panic.
#[test]
fn test_empty_uplinks() {
let scheduler = Scheduler::new(SchedulerConfig::default());
let result = scheduler.select(&[], None);
assert!(result.is_empty());
}
}