#![cfg_attr(docsrs, feature(doc_cfg))]
use std::{
marker::PhantomData,
ops::Deref,
sync::atomic::{AtomicU32, Ordering},
};
/// Policy that turns a latency sample (in microseconds) into a selection weight.
///
/// Both methods are associated functions (no `self`), so a `Rank` is chosen purely
/// at the type level, e.g. via the `M` parameter of [`PickFast`].
pub trait Rank {
    /// Target weight for a node whose latest latency sample is `latency_us`.
    fn calc(latency_us: u32) -> u32;

    /// Starting weight given to every node before any sample arrives; `1` by default.
    fn init() -> u32 {
        1
    }
}
/// Default [`Rank`]: the weight is inversely proportional to latency.
///
/// The weight is `2^22 / latency_us`, with a latency of `0` ranked like `1`.
/// The largest weight is therefore `4_194_304`, and any latency above that maps to `0`.
pub struct Inverse;

impl Rank for Inverse {
    #[inline(always)]
    fn calc(latency_us: u32) -> u32 {
        const BASE: u32 = 1 << 22;
        // Dividing by zero is impossible here: a zero latency is treated as 1 µs,
        // which yields BASE, exactly what BASE / 1 would give.
        BASE.checked_div(latency_us).unwrap_or(BASE)
    }
}
/// A picked node: its position in the picker's list plus a shared borrow of it.
#[derive(Debug)]
pub struct Handle<'a, T> {
    /// Position of the node in the picker's list (the same index `PickFast::set` takes).
    pub index: usize,
    /// Shared borrow of the picked node.
    pub data: &'a T,
}

// Implemented by hand instead of derived: `#[derive(Clone, Copy)]` would add
// `T: Clone` / `T: Copy` bounds, yet the handle only holds an index and a `&T`,
// both of which are always `Copy`.
impl<T> Clone for Handle<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Handle<'_, T> {}
/// Lets a [`Handle`] be used in place of the node it borrows.
impl<T> Deref for Handle<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        // `data` is a `&T`, which is `Copy`, so it can be copied out of `*self`.
        let Self { data, .. } = *self;
        data
    }
}
/// Weighted random picker whose weights are updated through atomics (no locks).
///
/// Entry `i` of `li` is chosen with probability proportional to `weight_li[i]`.
/// `total` caches the sum of the weights. The rank policy `M` (default [`Inverse`])
/// converts latency samples into weights.
pub struct PickFast<T, M = Inverse> {
    /// The nodes to choose from. Must stay non-empty and as long as `weight_li`.
    pub li: Vec<T>,
    /// Current weight of each node, indexed like `li`.
    pub weight_li: Vec<AtomicU32>,
    /// Cached sum of `weight_li`. It is updated separately from the weights, so it may
    /// briefly disagree with them while updates are in flight.
    pub total: AtomicU32,
    /// Carries the rank policy `M`, which is only used at the type level.
    _marker: PhantomData<M>,
}
// SAFETY: sharing `&PickFast` shares only `Vec<T>` (which is `Sync` when `T: Sync`)
// and atomics (always `Sync`). `M` is never stored, because `PhantomData<M>` holds no
// value, so `M: Sync` is deliberately not required.
unsafe impl<T, M> Sync for PickFast<T, M>
where
    T: Sync,
{
}

// SAFETY: moving a `PickFast` to another thread moves only `Vec<T>` (which is `Send`
// when `T: Send`) and atomics. No `M` value exists to be moved, so `M: Send` is not
// required.
unsafe impl<T, M> Send for PickFast<T, M>
where
    T: Send,
{
}
impl<T, M: Rank> PickFast<T, M> {
    /// Creates a picker over `li`, starting every node at weight `M::init()`.
    ///
    /// Logs a warning when more than 256 nodes are given, since the summed weights
    /// must fit in a `u32`.
    ///
    /// # Panics
    ///
    /// Panics if `li` is empty, or if the initial total `M::init() * len` does not fit
    /// in a `u32`. A wrapped total would silently skew every later pick.
    pub fn new(li: impl Into<Vec<T>>) -> Self {
        let li = li.into();
        let n = li.len();
        assert!(n > 0, "PickFast: node count must be > 0");
        if n > 256 {
            log::warn!("PickFast n={n} is large, ensure Rank won't overflow u32");
        }
        let init_val = M::init();
        // Widen before multiplying so an oversized total is detected instead of wrapping.
        let total = (n as u64)
            .checked_mul(u64::from(init_val))
            .filter(|&t| t <= u64::from(u32::MAX))
            .expect("PickFast: initial total weight overflows u32") as u32;
        let weight_li: Vec<AtomicU32> = (0..n).map(|_| AtomicU32::new(init_val)).collect();
        Self {
            li,
            weight_li,
            total: AtomicU32::new(total),
            _marker: PhantomData,
        }
    }

    /// Number of nodes in the picker.
    #[inline]
    pub fn len(&self) -> usize {
        self.li.len()
    }

    /// Whether the picker has no nodes. `new` rejects empty input, so this is only
    /// `true` if `li` was emptied through the public field.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.li.is_empty()
    }

    /// Picks a node at random, with probability proportional to its current weight.
    ///
    /// If the cached total is `0`, node 0 is returned. Weights are read with `Relaxed`
    /// ordering while other threads may be calling [`set`](Self::set), so the total can
    /// briefly disagree with the weights. When the scan reaches no node, the last node
    /// is returned.
    #[inline]
    pub fn pick(&self) -> Handle<'_, T> {
        let total_w = self.total.load(Ordering::Relaxed);
        if total_w == 0 {
            return Handle {
                index: 0,
                data: &self.li[0],
            };
        }
        let target = u64::from(fastrand::u32(0..total_w));
        // Accumulate in u64: the true weight sum may exceed u32::MAX (large node counts,
        // or a total that is momentarily out of step), and a u32 sum would overflow.
        let mut sum = 0u64;
        for (index, (data, w)) in self.li.iter().zip(&self.weight_li).enumerate() {
            sum += u64::from(w.load(Ordering::Relaxed));
            if sum > target {
                return Handle { index, data };
            }
        }
        let last = self.li.len() - 1;
        Handle {
            index: last,
            data: &self.li[last],
        }
    }

    /// Records a latency sample `val` for node `index`. Out-of-range indices are ignored.
    ///
    /// The node's weight moves halfway toward `M::calc(val)`, rounding down, and `total`
    /// is adjusted by the same delta.
    ///
    /// NOTE: if `M::calc` keeps returning `0`, the weight decays to `0`. `pick` then only
    /// returns that node through its fallbacks.
    #[inline]
    pub fn set(&self, index: usize, val: u32) {
        /// Floor of the mean of `a` and `b`, computed without overflowing `u32`.
        fn halfway(a: u32, b: u32) -> u32 {
            // The mean of two u32 values always fits in a u32, so the cast is lossless.
            ((u64::from(a) + u64::from(b)) >> 1) as u32
        }

        if index >= self.li.len() {
            return;
        }
        let target_w = M::calc(val);
        // The closure always returns `Some`, so `fetch_update` never fails. Either arm
        // carries the weight as it was just before this update.
        let prev = match self.weight_li[index].fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |old| Some(halfway(old, target_w)),
        ) {
            Ok(prev) | Err(prev) => prev,
        };
        let new_w = halfway(prev, target_w);
        // Keep `total` in step with the weight. The two writes are not a single atomic
        // step, so readers may briefly see a stale total. Atomic add/sub wrap on
        // overflow, so the deltas still net out exactly once every update has landed.
        if new_w > prev {
            self.total.fetch_add(new_w - prev, Ordering::Relaxed);
        } else if new_w < prev {
            self.total.fetch_sub(prev - new_w, Ordering::Relaxed);
        }
    }
}