#![cfg_attr(docsrs, feature(doc_cfg))]
use std::{
marker::PhantomData,
ops::Deref,
sync::atomic::{AtomicU32, Ordering},
};
/// Strategy that turns a raw measurement into a pick weight.
///
/// Used as the `M` type parameter of `PickFast`; only associated functions are
/// involved, so implementors are typically zero-sized marker types.
pub trait Rank {
    /// Weight every entry starts with before any measurement has been fed in.
    ///
    /// Defaults to `1`.
    fn init() -> u32 {
        1
    }

    /// Maps one measurement to the weight an entry should move towards.
    ///
    /// `latency_us` is presumably a latency in microseconds (judging by the
    /// name) — TODO confirm with callers.
    fn calc(latency_us: u32) -> u32;
}

/// Default [`Rank`]: the weight is inversely proportional to the latency,
/// `2^22 / latency`.
///
/// A latency of `0` is treated as `1`, so the largest possible weight is
/// `2^22`; any latency greater than `2^22` yields weight `0`.
pub struct Inverse;

impl Rank for Inverse {
    #[inline(always)]
    fn calc(latency: u32) -> u32 {
        const SCALE: u32 = 1 << 22;
        // Guard against division by zero: a zero latency counts as one.
        match latency {
            0 => SCALE,
            divisor => SCALE / divisor,
        }
    }
}
/// One picked item together with its position.
///
/// Dereferences to the item itself, so methods of `T` can be called on the
/// handle directly. `index` identifies the item's slot in the picker it came
/// from (e.g. to pass to `PickFast::set` later — presumably the intended use).
#[derive(Debug)]
pub struct Handle<'a, T> {
    /// Position of `data` in the picker's item array.
    pub index: usize,
    /// The picked item.
    pub data: &'a T,
}

// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add
// `T: Clone` / `T: Copy` bounds, but a handle only holds a `&T` and a `usize`,
// both of which are always `Copy`.
impl<T> Clone for Handle<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Handle<'_, T> {}

impl<T> Deref for Handle<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        self.data
    }
}
/// Lock-free weighted random picker over a fixed array of `N` items.
///
/// Every item `li[i]` has a weight `weight_li[i]`; `total` caches the sum of
/// the weights so a pick needs one random number and a linear scan. `M`
/// (default [`Inverse`]) is the [`Rank`] strategy that turns measurements into
/// weights. All counters are atomics accessed with `Relaxed` ordering, and the
/// picker is `Sync` whenever `T: Sync` (and `Send` whenever `T: Send`), so it
/// can be shared between threads by reference.
// NOTE(review): the 64-byte alignment presumably keeps the picker on its own
// cache line(s) — confirm intent.
#[repr(align(64))]
pub struct PickFast<T, const N: usize, M = Inverse> {
    /// The items to pick from.
    pub li: [T; N],
    /// Current weight of each item in `li` (same index). Writing these
    /// directly bypasses `total`, which is only maintained by the picker's own
    /// update path.
    pub weight_li: [AtomicU32; N],
    /// Cached sum of `weight_li`, maintained incrementally.
    total: AtomicU32,
    // `fn() -> M` rather than `M`: the rank strategy is a pure type-level tag
    // (no `M` value is ever stored), so it must not influence `Send`/`Sync`.
    // With this marker the auto traits come out as `Send` iff `T: Send` and
    // `Sync` iff `T: Sync` without any hand-written `unsafe impl`.
    _marker: PhantomData<fn() -> M>,
}
impl<T, const N: usize, M: Rank> PickFast<T, N, M> {
    /// Lower bound for every weight.
    ///
    /// An entry whose weight reached 0 could never be picked again, so it
    /// would never receive another `set` and would stay excluded forever.
    /// Clamping to 1 keeps it reachable (with low probability) and keeps
    /// `total >= N`.
    const MIN_WEIGHT: u32 = 1;

    /// Builds a picker over `li`, starting every entry at weight `M::init()`
    /// (raised to 1 if it is 0).
    ///
    /// Logs a warning when `N > 256`, since a large `N` makes it easier for the
    /// summed weights to exceed `u32::MAX`.
    ///
    /// # Panics
    ///
    /// Panics if `N == 0`, or if the initial total weight (`init * N`) does not
    /// fit in a `u32`.
    pub fn new(li: [T; N]) -> Self {
        assert!(N > 0, "PickFast: N must be > 0");
        if N > 256 {
            log::warn!("PickFast N={N} is large, ensure Rank won't overflow u32");
        }
        let init_val = M::init().max(Self::MIN_WEIGHT);
        // Computed in u64 so an oversized product is reported instead of
        // panicking with a bare overflow (debug) or silently wrapping (release)
        // and leaving `total` inconsistent with the weights.
        let total = u64::from(init_val)
            .checked_mul(N as u64)
            .filter(|&t| t <= u64::from(u32::MAX))
            .expect("PickFast: M::init() * N overflows u32") as u32;
        Self {
            li,
            weight_li: std::array::from_fn(|_| AtomicU32::new(init_val)),
            total: AtomicU32::new(total),
            _marker: PhantomData,
        }
    }

    /// Picks one entry at random, with probability proportional to its weight.
    ///
    /// Draws `target` uniformly from `0..total` and returns the first entry
    /// whose running weight sum exceeds it. Weights are read with `Relaxed`
    /// loads and may be updated concurrently, so the distribution is only
    /// approximate. If the sum never exceeds `target` (for example, `total`
    /// was momentarily larger than the weights' sum), the last entry is
    /// returned. If `total` reads as 0, the first entry is returned.
    #[inline(always)]
    pub fn pick(&self) -> Handle<'_, T> {
        let total_w = self.total.load(Ordering::Relaxed);
        if total_w == 0 {
            return Handle {
                index: 0,
                data: &self.li[0],
            };
        }
        let target = u64::from(fastrand::u32(0..total_w));
        // Accumulate in u64: the sum of up to N u32 weights can exceed
        // u32::MAX, which would panic in debug builds.
        let mut sum: u64 = 0;
        for (i, w) in self.weight_li.iter().enumerate() {
            sum += u64::from(w.load(Ordering::Relaxed));
            if sum > target {
                return Handle {
                    index: i,
                    data: &self.li[i],
                };
            }
        }
        Handle {
            index: N - 1,
            data: &self.li[N - 1],
        }
    }

    /// Feeds a new measurement `val` for entry `index` into that entry's weight.
    ///
    /// The weight moves halfway towards `M::calc(val)` (rounded down and
    /// clamped to at least 1), and `total` is adjusted by the same delta.
    /// Indices `>= N` are silently ignored.
    #[inline(always)]
    pub fn set(&self, index: usize, val: u32) {
        let Some(slot) = self.weight_li.get(index) else {
            return;
        };
        let target_w = M::calc(val);
        // The closure always returns `Some`, so `fetch_update` always
        // succeeds. `prev` is the exact value that was replaced, so recomputing
        // the blend from it yields the exact weight that was stored.
        if let Ok(prev) = slot.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |old| {
            Some(Self::blend(old, target_w))
        }) {
            let new_w = Self::blend(prev, target_w);
            if new_w > prev {
                self.total.fetch_add(new_w - prev, Ordering::Relaxed);
            } else if new_w < prev {
                self.total.fetch_sub(prev - new_w, Ordering::Relaxed);
            }
        }
    }

    /// Computes `floor((old + target) / 2)` without overflow, clamped to
    /// `MIN_WEIGHT`.
    ///
    /// `old + target` itself can overflow `u32` for `Rank` impls that return
    /// large weights. Since `a + b == 2 * (a & b) + (a ^ b)`, the midpoint can
    /// be formed without the addition.
    #[inline(always)]
    fn blend(old: u32, target: u32) -> u32 {
        ((old & target) + ((old ^ target) >> 1)).max(Self::MIN_WEIGHT)
    }
}