use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::Mutex;
use super::{SelectionContext, Strategy};
use crate::{Node, Weighted};
/// Selection strategy that picks among weighted candidates by smooth weighted
/// round-robin.
///
/// Per-candidate running weights persist between calls inside a [`Mutex`],
/// which lets `select` take `&self` while still updating that state.
pub struct WeightedRoundRobin {
/// Running weights plus the fingerprint of the candidate list they belong to.
state: Mutex<WrrState>,
}
impl std::fmt::Debug for WeightedRoundRobin {
    /// Renders as `WeightedRoundRobin { .. }`.
    ///
    /// The mutex-guarded running weights are not printed (presumably so that
    /// formatting never has to take the lock).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("WeightedRoundRobin");
        builder.finish_non_exhaustive()
    }
}
/// Mutable bookkeeping for [`WeightedRoundRobin`], kept behind its `Mutex`.
struct WrrState {
/// Fingerprint (from `wrr_fingerprint`) of the candidate list that `weights` was built for.
fingerprint: u64,
/// Running weight per candidate index; `select` resets it to zeros when the fingerprint changes.
weights: Vec<i64>,
}
impl WeightedRoundRobin {
    /// Creates a balancer with no accumulated running weights.
    ///
    /// The internal state starts empty (fingerprint `0`, no weights) and is
    /// rebuilt by `select` whenever it does not match the candidate list it is
    /// given. This is a `const fn`, so a balancer can also be stored in a
    /// `static`.
    pub const fn new() -> Self {
        Self {
            state: Mutex::new(WrrState {
                fingerprint: 0,
                weights: Vec::new(),
            }),
        }
    }
}
impl Default for WeightedRoundRobin {
fn default() -> Self {
Self::new()
}
}
/// Hashes the candidate list so that `select` can tell when it has changed.
///
/// The hash covers the list length followed, in order, by every candidate's
/// id and weight. Reordering, adding, removing or re-weighting candidates
/// therefore yields a different 64-bit value, barring a hash collision.
fn wrr_fingerprint<N: Weighted + Node>(candidates: &[N]) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(&candidates.len(), &mut state);
    candidates.iter().for_each(|candidate| {
        Hash::hash(candidate.id(), &mut state);
        Hash::hash(&candidate.weight(), &mut state);
    });
    state.finish()
}
impl<N: Weighted + Node> Strategy<N> for WeightedRoundRobin {
    /// Picks the next candidate by smooth weighted round-robin.
    ///
    /// On every call each eligible candidate's running weight grows by its
    /// configured weight. The candidate with the highest running weight wins
    /// (ties go to the lowest index) and has the total eligible weight
    /// subtracted from it. Over repeated calls with an unchanged candidate list
    /// and context, picks are interleaved in proportion to the weights.
    ///
    /// A candidate is eligible when `ctx` does not exclude its index and its
    /// weight is non-zero, so a zero weight means "never pick".
    ///
    /// Running weights persist between calls. They are reset to zero whenever
    /// the candidate list (length, order, ids or weights) changes, as detected
    /// by a hash fingerprint.
    ///
    /// Returns `None` when `candidates` is empty or no candidate is eligible.
    fn select(&self, candidates: &[N], ctx: &SelectionContext) -> Option<usize> {
        if candidates.is_empty() {
            return None;
        }

        let fingerprint = wrr_fingerprint(candidates);
        // A panic while the lock was held can at worst leave the running
        // weights skewed, which only perturbs the rotation. Recover the guard
        // instead of turning every later call into a panic.
        let mut state = self
            .state
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // The length check protects the indexing below against a fingerprint
        // collision, including one with the initial fingerprint of 0.
        if state.fingerprint != fingerprint || state.weights.len() != candidates.len() {
            state.fingerprint = fingerprint;
            state.weights = vec![0; candidates.len()];
        }

        // Zero-weight candidates must be skipped explicitly. Once exclusions
        // vary between calls, stale running weights can let a zero-weight
        // candidate tie for the maximum, and the lowest-index tie-break would
        // then pick it.
        let is_eligible = |i: usize, node: &N| !ctx.is_excluded(i) && node.weight() > 0;

        let total_weight: i64 = candidates
            .iter()
            .enumerate()
            .filter(|&(i, node)| is_eligible(i, node))
            .map(|(_, node)| i64::from(node.weight()))
            .sum();
        if total_weight == 0 {
            return None;
        }

        let mut best: Option<(usize, i64)> = None;
        for (i, node) in candidates.iter().enumerate() {
            if !is_eligible(i, node) {
                continue;
            }
            state.weights[i] += i64::from(node.weight());
            let current = state.weights[i];
            // Strict `>` keeps the first (lowest-index) maximum.
            let beats_best = match best {
                Some((_, best_weight)) => current > best_weight,
                None => true,
            };
            if beats_best {
                best = Some((i, current));
            }
        }

        let (winner, _) = best?;
        state.weights[winner] -= total_weight;
        Some(winner)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test node: a static id paired with a fixed weight.
    struct W {
        id: &'static str,
        weight: u32,
    }

    impl W {
        fn new(id: &'static str, weight: u32) -> Self {
            W { id, weight }
        }
    }

    impl Node for W {
        type Id = &'static str;
        fn id(&self) -> &&'static str {
            &self.id
        }
    }

    impl Weighted for W {
        fn weight(&self) -> u32 {
            self.weight
        }
    }

    /// Ten full cycles of total weight 7 must split exactly 5:1:1.
    #[test]
    fn respects_weights() {
        let balancer = WeightedRoundRobin::new();
        let pool = [W::new("a", 5), W::new("b", 1), W::new("c", 1)];
        let ctx = SelectionContext::default();
        let tally = (0..70).fold([0u32; 3], |mut acc, _| {
            let picked = balancer.select(&pool, &ctx).unwrap();
            acc[picked] += 1;
            acc
        });
        assert_eq!(tally, [50, 10, 10]);
    }

    /// Weights 2:1 must interleave as x, y, x, x, y, x.
    #[test]
    fn smooth_distribution() {
        let balancer = WeightedRoundRobin::new();
        let pool = [W::new("x", 2), W::new("y", 1)];
        let ctx = SelectionContext::default();
        let sequence: Vec<usize> =
            std::iter::repeat_with(|| balancer.select(&pool, &ctx).unwrap())
                .take(6)
                .collect();
        assert_eq!(sequence, vec![0, 1, 0, 0, 1, 0]);
    }

    #[test]
    fn skips_excluded() {
        let balancer = WeightedRoundRobin::new();
        let pool = [W::new("a", 3), W::new("b", 1)];
        let only_b = SelectionContext::builder().exclude(vec![0]).build();
        assert_eq!(balancer.select(&pool, &only_b), Some(1));
    }

    #[test]
    fn all_excluded_returns_none() {
        let balancer = WeightedRoundRobin::new();
        let pool = [W::new("a", 1), W::new("b", 1)];
        let nobody = SelectionContext::builder().exclude(vec![0, 1]).build();
        assert_eq!(balancer.select(&pool, &nobody), None);
    }

    /// Switching to a different candidate list discards the old running weights.
    #[test]
    fn resets_on_candidate_change() {
        let balancer = WeightedRoundRobin::new();
        let ctx = SelectionContext::default();
        let before = [W::new("a", 2), W::new("b", 1)];
        for _ in 0..2 {
            let _ = balancer.select(&before, &ctx);
        }
        let after = [W::new("b", 1), W::new("c", 3)];
        assert_eq!(balancer.select(&after, &ctx), Some(1));
    }
}