#![cfg_attr(docsrs, feature(doc_cfg))]
use std::{
marker::PhantomData,
ops::Deref,
sync::atomic::{AtomicU32, Ordering},
};
/// Strategy that converts an observed measurement into a selection weight.
///
/// `PickFast::set` calls `calc(base, value)` with its internal weight scale as
/// `base` and the caller-supplied value (raised to at least 1) as the second
/// argument.
pub trait Rank {
    /// Returns the target weight for a node.
    ///
    /// * `base` - the weight scale supplied by the caller.
    /// * `latency_us` - the observed measurement. The name suggests a latency
    ///   in microseconds, but nothing in this trait enforces a unit.
    fn calc(base: u32, latency_us: u32) -> u32;
}

/// [`Rank`] whose weight is inversely proportional to the latency:
/// `base / (latency + 1)`.
pub struct Inverse;

impl Rank for Inverse {
    /// Computes `base / (latency + 1)` with integer division.
    fn calc(base: u32, latency: u32) -> u32 {
        // `1 + latency` overflows when `latency == u32::MAX` (a panic in debug
        // builds, a zero divisor in release builds). Saturating keeps the
        // divisor in `1..=u32::MAX` for every input.
        base / latency.saturating_add(1)
    }
}
/// A payload paired with a selection weight that can be changed through a
/// shared reference.
pub struct Node<T> {
    /// The wrapped payload; also reachable by dereferencing the node.
    pub data: T,
    /// Current selection weight, held in an atomic so it can be updated
    /// without exclusive access to the node.
    pub weight: AtomicU32,
}

impl<T> Node<T> {
    /// Wraps `data` in a node whose weight starts at `weight`.
    pub fn new(data: T, weight: u32) -> Self {
        let weight = AtomicU32::new(weight);
        Node { data, weight }
    }
}

impl<T> Deref for Node<T> {
    type Target = T;

    /// Borrows the payload, so a `Node<T>` can be used wherever `&T` is expected.
    fn deref(&self) -> &T {
        let Node { data, .. } = self;
        data
    }
}
pub struct Handle<'a, T> {
pub index: usize,
pub node: &'a Node<T>,
}
impl<'a, T> Deref for Handle<'a, T> {
type Target = T;
fn deref(&self) -> &Self::Target {
&self.node.data
}
}
/// Weighted random picker over a fixed list of nodes.
///
/// `T` is the payload stored in each node. `M` is the ranking strategy type;
/// it is used only at the type level and no value of `M` is ever stored.
///
/// The type is `Send` when `T: Send` and `Sync` when `T: Sync`, regardless of
/// `M`.
pub struct PickFast<T, M = Inverse> {
    /// The nodes, each carrying its own atomic weight.
    pub li: Vec<Node<T>>,
    /// Running sum of the node weights, adjusted whenever a weight changes.
    pub total: AtomicU32,
    /// Weight scale computed at construction time and handed to the ranking
    /// strategy.
    base: u32,
    // `fn() -> M` rather than `M`: a function pointer type is always
    // `Send + Sync`, so the auto traits depend only on `T` and no
    // hand-written `unsafe impl Send/Sync` is needed.
    _marker: PhantomData<fn() -> M>,
}
impl<T, M: Rank> PickFast<T, M> {
    /// Builds a picker over `data`, giving every node the same starting weight.
    ///
    /// With `n` nodes the weight scale is `base = u32::MAX / n / 32` and each
    /// node starts at `max(base / n, 1)`.
    ///
    /// # Panics
    ///
    /// Panics if `data` yields no items, or more items than fit in a `u32`.
    pub fn new(data: impl IntoIterator<Item = T>) -> Self {
        let li: Vec<Node<T>> = data.into_iter().map(|d| Node::new(d, 0)).collect();
        let n = li.len();
        assert!(n > 0, "PickFast: node count must be > 0");
        assert!(
            n <= u32::MAX as usize,
            "PickFast: node count must fit in u32"
        );
        // Lossless: the assertion above guarantees `n` fits in `u32`.
        let count = n as u32;
        let base = u32::MAX / count / 32;
        let init_val = (base / count).max(1);
        for node in &li {
            node.weight.store(init_val, Ordering::Relaxed);
        }
        // Cannot overflow: either `init_val * count <= base`, or `init_val == 1`
        // and the product is `count` itself.
        let total = AtomicU32::new(init_val * count);
        Self {
            li,
            total,
            base,
            _marker: PhantomData,
        }
    }

    /// Returns the number of nodes.
    pub fn len(&self) -> usize {
        self.li.len()
    }

    /// Returns `true` when the node list is empty.
    pub fn is_empty(&self) -> bool {
        self.li.is_empty()
    }

    /// Picks a node at random, with probability proportional to its weight.
    ///
    /// Weights and the running total are read with separate relaxed loads, so
    /// a concurrent update can make the scan finish without reaching the
    /// random target; the last node is returned in that case. If the total is
    /// zero the first node is returned.
    ///
    /// # Panics
    ///
    /// Panics if the node list is empty.
    pub fn pick(&self) -> Handle<'_, T> {
        let index = self.weighted_index(self.li.len().saturating_sub(1));
        Handle {
            index,
            node: &self.li[index],
        }
    }

    /// Records the measurement `val` for the node at `index`.
    ///
    /// The target weight is `M::calc(base, max(val, 1))`, floored at 1. The
    /// node's weight moves towards it as `(31 * old + target) / 32`, also
    /// floored at 1, and the running total is adjusted by the difference.
    /// An out-of-range `index` is ignored.
    pub fn set(&self, index: usize, val: u32) {
        let node = match self.li.get(index) {
            Some(node) => node,
            None => return,
        };
        let target_w = M::calc(self.base, val.max(1)).max(1);
        self.reweigh(node, |old| {
            // 64-bit intermediate: `old * 31 + target_w` can exceed `u32::MAX`
            // when a `Rank` returns a large value or the weight was stored
            // externally.
            let blended = (u64::from(old) * 31 + u64::from(target_w)) >> 5;
            // Lossless: a weighted mean never exceeds the larger input.
            (blended as u32).max(1)
        });
    }

    /// Penalises the node at `index` by halving its weight (never below 1)
    /// and adjusting the running total accordingly.
    ///
    /// An out-of-range `index` is ignored.
    pub fn failed(&self, index: usize) {
        if let Some(node) = self.li.get(index) {
            self.reweigh(node, |old| (old >> 1).max(1));
        }
    }

    /// Builds a `citer::CIter` over all nodes with a weighted-random start
    /// position, chosen the same way as in `pick`.
    ///
    /// If the scan does not reach the random target, or the total is zero,
    /// the start position is 0.
    // NOTE(review): `CIter` presumably iterates from `start_pos` and wraps
    // around; confirm against the `citer` crate.
    #[cfg(feature = "iter")]
    pub fn iter(&self) -> citer::CIter<'_, Node<T>> {
        let start_pos = self.weighted_index(0);
        citer::CIter::new(&self.li[..], start_pos)
    }

    /// Draws a random target in `0..total` and returns the first index whose
    /// cumulative weight exceeds it.
    ///
    /// Returns 0 when the total is zero, and `fallback` when the scan ends
    /// without exceeding the target.
    fn weighted_index(&self, fallback: usize) -> usize {
        let total_w = self.total.load(Ordering::Relaxed);
        if total_w == 0 {
            return 0;
        }
        let target = u64::from(fastrand::u32(0..total_w));
        // Accumulate in 64 bits so the running sum cannot overflow even if the
        // weights are momentarily out of step with `total`.
        let mut sum = 0u64;
        self.li
            .iter()
            .position(|node| {
                sum += u64::from(node.weight.load(Ordering::Relaxed));
                sum > target
            })
            .unwrap_or(fallback)
    }

    /// Atomically replaces `node`'s weight with `next(old)` and applies the
    /// same delta to the running total.
    fn reweigh(&self, node: &Node<T>, next: impl Fn(u32) -> u32) {
        let prev = match node
            .weight
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |old| Some(next(old)))
        {
            // The closure always returns `Some`, so `Err` is not expected;
            // both arms carry the previous value.
            Ok(prev) | Err(prev) => prev,
        };
        // `next` is pure, so this is the value that was stored.
        let new_w = next(prev);
        if new_w > prev {
            self.total.fetch_add(new_w - prev, Ordering::Relaxed);
        } else if new_w < prev {
            self.total.fetch_sub(prev - new_w, Ordering::Relaxed);
        }
    }
}