use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use super::{Accumulator, UniCollector};
/// Result of the load-balance collector: the load per key plus an unfairness score.
#[derive(Debug, Clone)]
pub struct LoadBalance<K> {
    /// Load accumulated for each key.
    loads: HashMap<K, i64>,
    /// Rounded square root of the sum of squared deviations of `loads` from their mean.
    unfairness: i64,
}

impl<K> LoadBalance<K> {
    /// Borrows the per-key load table.
    #[inline]
    pub fn loads(&self) -> &HashMap<K, i64> {
        &self.loads
    }

    /// Returns the unfairness score; it is 0 whenever all loads are equal.
    #[inline]
    pub fn unfairness(&self) -> i64 {
        self.unfairness
    }
}
/// Collector returned by [`load_balance`].
///
/// Only the two closures are stored; the entity type `A` and key type `K` appear solely in
/// the phantom marker.
pub struct LoadBalanceCollector<A, K, F, M> {
    /// Maps an entity to the key whose load it contributes to.
    key_fn: F,
    /// Maps an entity to the amount it contributes to its key's load.
    metric_fn: M,
    /// `fn(&A) -> K` records `A` and `K` as type parameters without storing a value of either.
    _phantom: PhantomData<fn(&A) -> K>,
}

/// Builds a load-balance collector.
///
/// `key_fn` picks the key each entity is charged to and `metric_fn` gives the amount it
/// contributes; the finished result reports the load per key and how unevenly it is spread.
pub fn load_balance<A, K, F, M>(key_fn: F, metric_fn: M) -> LoadBalanceCollector<A, K, F, M>
where
    K: Clone + Eq + Hash + Send + Sync,
    F: Fn(&A) -> K + Send + Sync,
    M: Fn(&A) -> i64 + Send + Sync,
{
    LoadBalanceCollector {
        _phantom: PhantomData,
        metric_fn,
        key_fn,
    }
}
impl<A, K, F, M> UniCollector<A> for LoadBalanceCollector<A, K, F, M>
where
    A: Send + Sync,
    K: Clone + Eq + Hash + Send + Sync,
    F: Fn(&A) -> K + Send + Sync,
    M: Fn(&A) -> i64 + Send + Sync,
{
    /// Every entity is reduced to a `(key, metric)` pair.
    type Value = (K, i64);
    type Result = LoadBalance<K>;
    type Accumulator = LoadBalanceAccumulator<K>;

    /// Applies the key function and then the metric function to `entity`.
    #[inline]
    fn extract(&self, entity: &A) -> Self::Value {
        let key = (self.key_fn)(entity);
        let metric = (self.metric_fn)(entity);
        (key, metric)
    }

    /// Starts a fresh, empty accumulator.
    fn create_accumulator(&self) -> Self::Accumulator {
        LoadBalanceAccumulator::<K>::new()
    }
}
/// Incremental state behind [`LoadBalanceCollector`].
///
/// Keeps the current load of every key together with two running terms. From these, the sum
/// of squared deviations of the loads from their mean can be read without rescanning `loads`:
///
/// `Σ(x - mean)² = Σx² - s²/n`
///
/// Here `x` ranges over the per-key loads, `s` is their sum and `n` is the number of keys.
pub struct LoadBalanceAccumulator<K> {
    /// How many non-zero values are currently accumulated per key; its `len()` is `n`.
    item_counts: HashMap<K, usize>,
    /// Current load per key.
    loads: HashMap<K, i64>,
    /// Sum of all loads, `s`.
    sum: i64,
    /// Sum of squared loads, `Σx²`.
    squared_deviation_integral: i64,
    /// Numerator of the `/ n` part of the squared deviation; always equals `-s²`.
    squared_deviation_fraction_numerator: i64,
}

impl<K: Clone + Eq + Hash> LoadBalanceAccumulator<K> {
    /// Creates an accumulator with no keys and all running terms at zero.
    fn new() -> Self {
        Self {
            item_counts: HashMap::new(),
            loads: HashMap::new(),
            sum: 0,
            squared_deviation_integral: 0,
            squared_deviation_fraction_numerator: 0,
        }
    }

    /// Adds `diff` to the load of `key`, keeping `sum` and the running terms in sync.
    ///
    /// If the load does not change (`diff == 0`), nothing is written, so no entry is created.
    fn add_to_metric(&mut self, key: &K, diff: i64) {
        let old_value = self.loads.get(key).copied().unwrap_or(0);
        let new_value = old_value + diff;
        if old_value != new_value {
            self.loads.insert(key.clone(), new_value);
            // Must run before `sum` changes: it reads the pre-update sum.
            self.update_squared_deviation(old_value, new_value);
            self.sum += diff;
        }
    }

    /// Removes `key`'s load entirely and takes its contribution out of the running terms.
    fn reset_metric(&mut self, key: &K) {
        if let Some(old_value) = self.loads.remove(key) {
            if old_value != 0 {
                self.update_squared_deviation(old_value, 0);
                self.sum -= old_value;
            }
        }
    }

    /// Updates the running terms for one key's load changing from `old_value` to `new_value`.
    ///
    /// Call this while `self.sum` still holds the sum *before* the change. Afterwards
    /// `squared_deviation_integral == Σx²` and `squared_deviation_fraction_numerator == -s²`
    /// hold for the updated loads.
    fn update_squared_deviation(&mut self, old_value: i64, new_value: i64) {
        let new_sum = self.sum - old_value + new_value;
        self.squared_deviation_integral += new_value * new_value - old_value * old_value;
        self.squared_deviation_fraction_numerator -= new_sum * new_sum - self.sum * self.sum;
    }

    /// Returns `sqrt(Σ(x - mean)²)` over the current loads, rounded to the nearest integer.
    ///
    /// The result is 0 when there are no keys, or exactly one key.
    ///
    /// `n·Σx² − s²` is formed exactly in `i128` before any floating-point work. Converting
    /// `Σx²` and `s²` to `f64` separately and then subtracting them loses the whole result to
    /// cancellation once loads reach roughly 1e8. For example, loads `1e8` and `1e8 + 1` would
    /// report 0 instead of 1.
    fn compute_unfairness(&self) -> i64 {
        let n = self.item_counts.len();
        if n == 0 {
            return 0;
        }
        // n · (Σx² − s²/n) = n·Σx² − s², which is never negative (Cauchy–Schwarz).
        let scaled = n as i128 * i128::from(self.squared_deviation_integral)
            + i128::from(self.squared_deviation_fraction_numerator);
        let squared_deviation = scaled as f64 / n as f64;
        // Clamp so an inconsistent state can never feed a negative value (NaN) into sqrt.
        squared_deviation.max(0.0).sqrt().round() as i64
    }
}
impl<K: Clone + Eq + Hash + Send + Sync> Accumulator<(K, i64), LoadBalance<K>>
    for LoadBalanceAccumulator<K>
{
    /// Registers one `(key, metric)` value.
    ///
    /// A zero metric is ignored entirely: it neither creates the key nor counts towards it.
    #[inline]
    fn accumulate(&mut self, value: &(K, i64)) {
        let key = &value.0;
        let metric = value.1;
        if metric != 0 {
            *self.item_counts.entry(key.clone()).or_insert(0) += 1;
            self.add_to_metric(key, metric);
        }
    }

    /// Unregisters one `(key, metric)` value.
    ///
    /// Retracting the last value of a key drops that key's load outright. Otherwise `metric`
    /// is subtracted from the load. Zero metrics and keys that are not tracked are ignored.
    #[inline]
    fn retract(&mut self, value: &(K, i64)) {
        let key = &value.0;
        let metric = value.1;
        if metric == 0 {
            return;
        }
        let remaining = match self.item_counts.get_mut(key) {
            Some(count) if *count > 0 => {
                *count -= 1;
                *count
            }
            _ => return,
        };
        if remaining > 0 {
            self.add_to_metric(key, -metric);
        } else {
            self.item_counts.remove(key);
            self.reset_metric(key);
        }
    }

    /// Snapshots the current per-key loads together with the unfairness score.
    fn finish(&self) -> LoadBalance<K> {
        let unfairness = self.compute_unfairness();
        LoadBalance {
            loads: self.loads.clone(),
            unfairness,
        }
    }

    /// Returns the accumulator to its empty state; `clear` keeps the maps' allocations.
    fn reset(&mut self) {
        self.sum = 0;
        self.squared_deviation_integral = 0;
        self.squared_deviation_fraction_numerator = 0;
        self.loads.clear();
        self.item_counts.clear();
    }
}