use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use super::{Accumulator, Collector};
/// Outcome of a load-balance collection.
///
/// Holds the final load of every key together with a single "unfairness" figure that
/// summarises how unevenly that load is spread across the keys.
#[derive(Debug, Clone)]
pub struct LoadBalance<K> {
    /// Accumulated metric total for each key.
    loads: HashMap<K, i64>,
    /// Spread of the per-key loads, as produced by the accumulator in this module.
    unfairness: i64,
}

impl<K> LoadBalance<K> {
    /// Borrows the map from each key to its accumulated load.
    pub fn loads(&self) -> &HashMap<K, i64> {
        let LoadBalance { loads, .. } = self;
        loads
    }

    /// Returns the unfairness figure.
    ///
    /// As computed by this module's accumulator, this is the square root of the sum of
    /// squared deviations of the per-key loads from their mean, rounded to the nearest
    /// integer (0 when there are no keys).
    #[inline]
    pub fn unfairness(&self) -> i64 {
        let value = self.unfairness;
        value
    }
}
/// Builds a collector that groups inputs by the key from `key_fn` and accumulates the
/// metric from `metric_fn` into that key's load.
///
/// `K` is not determined by the two closures here (they carry no `Fn` bound at this
/// point), so it is normally inferred from how the collector is used, or supplied
/// explicitly, e.g. `load_balance::<u32, _, _>(key_fn, metric_fn)`.
pub fn load_balance<K, F, M>(key_fn: F, metric_fn: M) -> LoadBalanceCollector<K, F, M>
where
    K: Clone + Eq + Hash + Send + Sync,
    F: Send + Sync,
    M: Send + Sync,
{
    let _phantom: PhantomData<fn() -> K> = PhantomData;
    LoadBalanceCollector {
        key_fn,
        metric_fn,
        _phantom,
    }
}

/// Collector returned by `load_balance`: a key extractor paired with a metric extractor.
pub struct LoadBalanceCollector<K, F, M> {
    /// Maps an input to the key whose load it contributes to.
    key_fn: F,
    /// Maps an input to the amount it contributes to that key's load.
    metric_fn: M,
    /// Records the key type without storing a `K`; the `fn() -> K` form keeps the
    /// struct's auto traits (`Send`/`Sync`) independent of `K`.
    _phantom: PhantomData<fn() -> K>,
}
/// Wires the key and metric extractors into the generic `Collector` machinery: every
/// input becomes a `(key, metric)` pair fed into a `LoadBalanceAccumulator`.
impl<Input, K, F, M> Collector<Input> for LoadBalanceCollector<K, F, M>
where
    Input: Copy + Send + Sync,
    K: Clone + Eq + Hash + Send + Sync,
    F: Fn(Input) -> K + Send + Sync,
    M: Fn(Input) -> i64 + Send + Sync,
{
    type Value = (K, i64);
    type Result = LoadBalance<K>;
    type Accumulator = LoadBalanceAccumulator<K>;

    /// Applies the key extractor, then the metric extractor, to the same input.
    #[inline]
    fn extract(&self, input: Input) -> Self::Value {
        let key = (self.key_fn)(input);
        let metric = (self.metric_fn)(input);
        (key, metric)
    }

    /// Starts a fresh, empty accumulator.
    fn create_accumulator(&self) -> Self::Accumulator {
        let fresh = LoadBalanceAccumulator::new();
        fresh
    }
}
/// Incremental state for a load-balance collection.
///
/// Tracks each key's load plus running aggregates, so that the unfairness figure
/// (square root of the sum of squared deviations of the loads from their mean) can be
/// produced without rescanning every key.
///
/// The aggregates are kept in `i128`. Squaring any single `i64` load therefore cannot
/// overflow. The final variance is formed as an exact integer before converting to
/// floating point, which avoids the cancellation of computing Σx² − S²/n in `f64`.
pub struct LoadBalanceAccumulator<K> {
    /// Number of accumulated items (with non-zero metric) per key.
    item_counts: HashMap<K, usize>,
    /// Current load per key.
    loads: HashMap<K, i64>,
    /// S = Σ loads.
    sum: i128,
    /// Σ load².
    squared_deviation_integral: i128,
    /// −S², i.e. minus the square of `sum`.
    squared_deviation_fraction_numerator: i128,
}

impl<K: Clone + Eq + Hash> LoadBalanceAccumulator<K> {
    /// Creates an accumulator with no keys and all aggregates at zero.
    fn new() -> Self {
        Self {
            item_counts: HashMap::new(),
            loads: HashMap::new(),
            sum: 0,
            squared_deviation_integral: 0,
            squared_deviation_fraction_numerator: 0,
        }
    }

    /// Adds `diff` to `key`'s load, inserting the key with load `diff` if absent.
    ///
    /// A zero `diff` is a no-op: an absent key is not inserted.
    fn add_to_metric(&mut self, key: &K, diff: i64) {
        if diff == 0 {
            return;
        }
        // Update in place when the key exists, so the key is only cloned on insertion.
        let old_value = match self.loads.get_mut(key) {
            Some(slot) => {
                let old = *slot;
                *slot = old + diff;
                old
            }
            None => {
                self.loads.insert(key.clone(), diff);
                0
            }
        };
        self.apply_load_change(old_value, old_value + diff);
    }

    /// Removes `key`'s load entirely, withdrawing its contribution from the aggregates.
    fn reset_metric(&mut self, key: &K) {
        if let Some(old_value) = self.loads.remove(key) {
            if old_value != 0 {
                self.apply_load_change(old_value, 0);
            }
        }
    }

    /// Folds one key's load change (`old_value` -> `new_value`) into `sum`, Σ load², and −S².
    fn apply_load_change(&mut self, old_value: i64, new_value: i64) {
        let old_value = i128::from(old_value);
        let new_value = i128::from(new_value);
        let old_sum = self.sum;
        let new_sum = old_sum - old_value + new_value;
        self.squared_deviation_integral += new_value * new_value - old_value * old_value;
        self.squared_deviation_fraction_numerator -= new_sum * new_sum - old_sum * old_sum;
        self.sum = new_sum;
    }

    /// Returns round(sqrt(Σ (load − mean)²)) over the keys in `item_counts`, or 0 if none.
    fn compute_unfairness(&self) -> i64 {
        let n = self.item_counts.len();
        if n == 0 {
            return 0;
        }
        // `usize` is at most 64 bits, so widening to `i128` is lossless.
        let n_wide = n as i128;
        // n·Σx² − S² is exact in integers and mathematically non-negative
        // (Cauchy–Schwarz, since every key with a non-zero load is counted in n).
        // Dividing by n only after this subtraction avoids float cancellation.
        // If even i128 overflows, fall back to the floating-point form.
        let squared_deviation = n_wide
            .checked_mul(self.squared_deviation_integral)
            .and_then(|scaled| scaled.checked_add(self.squared_deviation_fraction_numerator))
            .map(|numerator| numerator as f64 / n as f64)
            .unwrap_or_else(|| {
                self.squared_deviation_integral as f64
                    + self.squared_deviation_fraction_numerator as f64 / n as f64
            });
        // Clamp so a tiny negative from the fallback path cannot turn into NaN.
        squared_deviation.max(0.0).sqrt().round() as i64
    }
}
/// Hooks `LoadBalanceAccumulator` into the generic accumulate / retract protocol.
impl<K: Clone + Eq + Hash + Send + Sync> Accumulator<(K, i64), LoadBalance<K>>
    for LoadBalanceAccumulator<K>
{
    /// The retraction token is simply the `(key, metric)` pair that was accumulated.
    type Retraction = (K, i64);

    /// Records one `(key, metric)` item and hands it back as its retraction token.
    ///
    /// Items whose metric is zero pass straight through: they neither bump the key's
    /// item count nor touch its load.
    #[inline]
    fn accumulate(&mut self, value: (K, i64)) -> Self::Retraction {
        let (key, metric) = value;
        if metric != 0 {
            *self.item_counts.entry(key.clone()).or_insert(0) += 1;
            self.add_to_metric(&key, metric);
        }
        (key, metric)
    }

    /// Undoes an earlier `accumulate` of the same `(key, metric)` pair.
    ///
    /// Zero metrics and keys with no recorded items are ignored. Retracting a key's last
    /// item drops the key from both the counts and the loads. Otherwise `metric` is
    /// subtracted from the key's load.
    #[inline]
    fn retract(&mut self, value: Self::Retraction) {
        let (key, metric) = value;
        if metric == 0 {
            return;
        }
        let count = match self.item_counts.get_mut(&key) {
            Some(count) => count,
            None => return,
        };
        if *count == 0 {
            return;
        }
        *count -= 1;
        let last_item_removed = *count == 0;
        if last_item_removed {
            self.item_counts.remove(&key);
            self.reset_metric(&key);
        } else {
            self.add_to_metric(&key, -metric);
        }
    }

    /// Builds a `LoadBalance` snapshot (cloned loads plus current unfairness) and passes
    /// it to `f`.
    fn with_result<T>(&self, f: impl FnOnce(&LoadBalance<K>) -> T) -> T {
        let unfairness = self.compute_unfairness();
        let loads = self.loads.clone();
        let snapshot = LoadBalance { loads, unfairness };
        f(&snapshot)
    }

    /// Returns to the empty state; `clear` keeps the maps' allocated capacity.
    fn reset(&mut self) {
        self.squared_deviation_fraction_numerator = 0;
        self.squared_deviation_integral = 0;
        self.sum = 0;
        self.loads.clear();
        self.item_counts.clear();
    }
}