use super::Weight;
use std::{collections::HashMap, hash::Hash};
#[derive(Copy, Clone, Debug)]
struct SmoothWeightItem<T> {
item: T,
weight: isize,
current_weight: isize,
effective_weight: isize,
}
#[derive(Default)]
pub struct SmoothWeight<T> {
items: Vec<SmoothWeightItem<T>>,
n: isize,
}
impl<T: Copy + PartialEq + Eq + Hash> SmoothWeight<T> {
pub fn new() -> Self {
SmoothWeight {
items: Vec::new(),
n: 0,
}
}
fn next_smooth_weighted(&mut self) -> Option<SmoothWeightItem<T>> {
let mut total = 0;
let mut best = self.items[0];
let mut best_index = 0;
let mut found = false;
let items_len = self.items.len();
for i in 0..items_len {
self.items[i].current_weight += self.items[i].effective_weight;
total += self.items[i].effective_weight;
if self.items[i].effective_weight < self.items[i].weight {
self.items[i].effective_weight += 1;
}
if !found || self.items[i].current_weight > best.current_weight {
best = self.items[i];
found = true;
best_index = i;
}
}
if !found {
return None;
}
self.items[best_index].current_weight -= total;
Some(best)
}
}
impl<T: Copy + PartialEq + Eq + Hash> Weight for SmoothWeight<T> {
type Item = T;
fn next(&mut self) -> Option<T> {
if self.n == 0 {
return None;
}
if self.n == 1 {
return Some(self.items[0].item);
}
let rt = self.next_smooth_weighted()?;
Some(rt.item)
}
fn add(&mut self, item: T, weight: isize) {
let weight_item = SmoothWeightItem {
item,
weight,
current_weight: 0,
effective_weight: weight,
};
self.items.push(weight_item);
self.n += 1;
}
fn all(&self) -> HashMap<T, isize> {
let mut rt: HashMap<T, isize> = HashMap::new();
for w in &self.items {
rt.insert(w.item, w.weight);
}
rt
}
fn remove_all(&mut self) {
self.items.clear();
self.n = 0;
}
fn reset(&mut self) {
for w in &mut self.items {
w.current_weight = 0;
w.effective_weight = w.weight;
}
}
}
#[cfg(test)]
mod tests {
use crate::{SmoothWeight, Weight};
use std::collections::HashMap;
#[test]
fn test_smooth_weight() {
let mut sw: SmoothWeight<&str> = SmoothWeight::new();
sw.add("server1", 5);
sw.add("server2", 2);
sw.add("server3", 3);
let mut results: HashMap<&str, usize> = HashMap::new();
for _ in 0..100 {
let s = sw.next().unwrap();
*results.entry(s).or_insert(0) += 1;
}
assert_eq!(results["server1"], 50);
assert_eq!(results["server2"], 20);
assert_eq!(results["server3"], 30);
}
}