#![cfg_attr(docsrs, feature(doc_cfg))]
use std::{
marker::PhantomData,
ops::Deref,
sync::atomic::{AtomicU32, Ordering},
};
/// Strategy that turns an observed value (a latency, judging by the parameter
/// name — presumably microseconds, confirm with callers) into a node weight
/// for [`PickFast`].
///
/// Larger weights make a node proportionally more likely to be returned by
/// `PickFast::pick`.
pub trait Rank {
    /// Maps an observed latency to the weight the node's current weight is
    /// moved toward by `PickFast::set`.
    fn calc(latency_us: u32) -> u32;

    /// Weight every node starts with; defaults to `1`.
    fn init() -> u32 {
        1
    }
}
/// [`Rank`] strategy whose weight is inversely proportional to latency:
/// `2^22 / latency`, where a latency of `0` is treated like `1`.
pub struct Inverse;

impl Rank for Inverse {
    #[inline(always)]
    fn calc(latency: u32) -> u32 {
        const NUMERATOR: u32 = 1 << 22;
        // `checked_div` yields `None` only for a zero divisor; mapping that case
        // to NUMERATOR gives the same result as dividing by 1.
        NUMERATOR.checked_div(latency).unwrap_or(NUMERATOR)
    }
}
/// A candidate entry: the caller's payload plus its current selection weight.
pub struct Node<T> {
    /// The wrapped payload; also reachable through `Deref`.
    pub data: T,
    /// Current selection weight; atomic so it can be adjusted through `&self`.
    pub weight: AtomicU32,
}

impl<T> Node<T> {
    /// Wraps `data` with the given initial `weight`.
    pub fn new(data: T, weight: u32) -> Self {
        let weight = AtomicU32::new(weight);
        Node { data, weight }
    }
}

impl<T> Deref for Node<T> {
    type Target = T;

    /// Borrows the wrapped payload.
    fn deref(&self) -> &T {
        let Node { data, .. } = self;
        data
    }
}
pub struct Handle<'a, T> {
pub index: usize,
pub node: &'a Node<T>,
}
impl<'a, T> Deref for Handle<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.node.data
}
}
/// Weighted random picker over a fixed set of nodes.
///
/// `M` is the [`Rank`] strategy that maps observations to weights. It is only
/// a type-level parameter; no value of type `M` is ever stored.
pub struct PickFast<T, M = Inverse> {
    /// The candidate nodes, scanned in order when picking.
    pub li: Vec<Node<T>>,
    /// Running total of all node weights; the exclusive upper bound of the
    /// random draw used when picking.
    pub total: AtomicU32,
    /// `fn() -> M` rather than `M`: the picker never owns an `M`. With this form
    /// `Send`/`Sync` are derived automatically from `T` alone (the same bounds
    /// the former hand-written `unsafe impl`s stated), so no `unsafe` is needed.
    _marker: PhantomData<fn() -> M>,
}
impl<T, M: Rank> PickFast<T, M> {
    /// Builds a picker over `data`, giving every node the initial weight
    /// `M::init()`; `total` starts at `M::init() * n`.
    ///
    /// # Panics
    ///
    /// Panics if `data` yields no items, or if the initial total weight
    /// (`M::init() * n`) does not fit in a `u32`.
    pub fn new(data: impl IntoIterator<Item = T>) -> Self {
        let init_val = M::init();
        let li: Vec<Node<T>> = data
            .into_iter()
            .map(|d| Node::new(d, init_val))
            .collect();
        let n = li.len();
        assert!(n > 0, "PickFast: node count must be > 0");
        if n > 256 {
            log::warn!("PickFast n={n} is large, ensure Rank won't overflow u32");
        }
        // Checked instead of `init_val * (n as u32)`: that truncates `n` and
        // wraps the product in release builds, which would leave `total` out of
        // sync with the per-node weights from the very start.
        let total = (n as u64)
            .checked_mul(u64::from(init_val))
            .filter(|&t| t <= u64::from(u32::MAX))
            .map(|t| t as u32)
            .expect("PickFast: initial total weight overflows u32");
        Self {
            li,
            total: AtomicU32::new(total),
            _marker: PhantomData,
        }
    }

    /// Number of nodes.
    #[inline]
    pub fn len(&self) -> usize {
        self.li.len()
    }

    /// Whether there are no nodes. `new` rejects empty input, so this is only
    /// `true` if `li` was emptied afterwards.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.li.is_empty()
    }

    /// Picks a node with probability proportional to its weight.
    ///
    /// A uniform target in `0..total` is drawn, and the first node whose
    /// cumulative weight exceeds it is returned. All loads are `Relaxed`, so
    /// under concurrent `set` calls the result follows an approximate snapshot.
    /// If `total` is 0, node 0 is returned. If the scan never passes the target
    /// (possible while `total` and the weights are briefly out of sync), the
    /// last node is returned.
    ///
    /// # Panics
    ///
    /// Panics if `li` is empty.
    #[inline]
    pub fn pick(&self) -> Handle<'_, T> {
        let total_w = self.total.load(Ordering::Relaxed);
        if total_w == 0 {
            return Handle {
                index: 0,
                node: &self.li[0],
            };
        }
        let target = u64::from(fastrand::u32(0..total_w));
        // Accumulate in u64: concurrently updated weights may momentarily sum
        // past u32::MAX, which would panic (debug) or wrap (release) in u32.
        let mut sum: u64 = 0;
        for (i, node) in self.li.iter().enumerate() {
            sum += u64::from(node.weight.load(Ordering::Relaxed));
            if sum > target {
                return Handle { index: i, node };
            }
        }
        let last = self.li.len() - 1;
        Handle {
            index: last,
            node: &self.li[last],
        }
    }

    /// Records an observation `val` for node `index`. The node's weight moves
    /// halfway toward `M::calc(val)`, and `total` is adjusted by the same delta.
    /// An out-of-range `index` is ignored.
    #[inline]
    pub fn set(&self, index: usize, val: u32) {
        /// Floor of `(a + b) / 2`, computed without overflowing `u32`.
        fn midpoint(a: u32, b: u32) -> u32 {
            // The average never exceeds max(a, b), so narrowing back is lossless.
            ((u64::from(a) + u64::from(b)) >> 1) as u32
        }

        let node = match self.li.get(index) {
            Some(node) => node,
            None => return,
        };
        let target_w = M::calc(val);
        let updated = node
            .weight
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |old| {
                Some(midpoint(old, target_w))
            });
        // The closure always returns `Some`, so `fetch_update` only fails in theory.
        if let Ok(prev) = updated {
            let new_w = midpoint(prev, target_w);
            // Mirror the change into `total`. The weight CAS and this update are
            // separate atomic steps, so `total` may lag briefly under concurrent
            // `set` calls. Atomic add/sub wrap on overflow, so `total` converges
            // back to the sum of weights once those calls complete.
            if new_w > prev {
                self.total.fetch_add(new_w - prev, Ordering::Relaxed);
            } else if new_w < prev {
                self.total.fetch_sub(prev - new_w, Ordering::Relaxed);
            }
        }
    }
}