use std::{
collections::{BTreeMap, HashMap, VecDeque},
fmt::{Debug, Display},
hash::Hash,
num::NonZeroUsize,
sync::atomic::{AtomicBool, AtomicUsize, Ordering},
};
use enum_iterator::IntoEnumIterator;
use serde::Serialize;
use tokio::sync::{Mutex, MutexGuard, Semaphore};
use tracing::{debug, warn};
/// A scheduler that hands out items from several FIFO queues in weighted round-robin order.
///
/// Each queue is identified by a key of type `K` and holds items of type `I`.
#[derive(Debug)]
pub struct WeightedRoundRobin<I, K> {
    /// Position of the round-robin iteration: the slot currently being served and its index.
    state: Mutex<IterationState<K>>,
    /// All slots in visiting order, each carrying its full ticket count.
    slots: Vec<Slot<K>>,
    /// One queue per key.
    queues: HashMap<K, QueueState<I>>,
    /// Counts items available to be popped: `push` adds a permit, `pop` consumes one.
    total: Semaphore,
    /// Once set, `push` drops incoming items instead of queueing them.
    sealed: AtomicBool,
    /// Most recently recorded total event count, used by `push` to decide when to warn about
    /// queue growth; `None` disables the check.
    recent_event_count_peak: Option<AtomicUsize>,
}
/// A single FIFO queue of items together with a counter of how many items it holds.
///
/// The counter allows the queue length to be read without taking the queue lock.
#[derive(Debug)]
struct QueueState<I> {
    /// Number of items in `queue`.
    event_count: AtomicUsize,
    /// The queued items, oldest at the front.
    queue: Mutex<VecDeque<I>>,
}

impl<I> QueueState<I> {
    /// Creates an empty queue with an event count of zero.
    fn new() -> Self {
        QueueState {
            event_count: AtomicUsize::new(0),
            queue: Mutex::new(VecDeque::new()),
        }
    }

    /// Removes and returns every queued item, lowering the event count by the number removed.
    #[cfg(test)]
    async fn drain(&self) -> Vec<I> {
        let mut guard = self.queue.lock().await;
        let events: Vec<I> = guard.drain(..).collect();
        self.event_count.fetch_sub(events.len(), Ordering::SeqCst);
        events
    }

    /// Appends `element` to the back of the queue and increments the event count.
    #[inline]
    async fn push_back(&self, element: I) {
        let mut guard = self.queue.lock().await;
        guard.push_back(element);
        // Increment before the lock is released. Otherwise `drain`, which subtracts the drained
        // length while holding the lock, could remove this element before it was counted and
        // make the counter wrap around below zero.
        self.event_count.fetch_add(1, Ordering::SeqCst);
    }

    /// Decrements the event count by one.
    ///
    /// Expected to be called right after one item has been removed from `queue`.
    #[inline]
    fn dec_count(&self) {
        self.event_count.fetch_sub(1, Ordering::SeqCst);
    }

    /// Returns the current event count without locking the queue.
    #[inline]
    fn event_count(&self) -> usize {
        self.event_count.load(Ordering::SeqCst)
    }
}
/// Current position of the round-robin iteration.
#[derive(Copy, Clone, Debug)]
struct IterationState<K> {
    /// Copy of the slot being served; its `tickets` are decremented as items are popped.
    active_slot: Slot<K>,
    /// Index of `active_slot` within the scheduler's `slots`.
    active_slot_idx: usize,
}
/// A queue key together with the number of consecutive items it may yield per turn.
#[derive(Copy, Clone, Debug)]
struct Slot<K> {
    /// Key of the queue this slot serves.
    key: K,
    /// How many items the slot may yield before the scheduler moves on (the slot's weight).
    tickets: usize,
}
/// A serializable snapshot of every queue's contents, ordered by key.
#[derive(Debug, Serialize)]
pub struct QueueDump<'a, K, I>
where
    K: Ord + Eq,
{
    /// Borrowed contents of each queue, keyed by queue key.
    queues: BTreeMap<K, &'a VecDeque<I>>,
}
impl<I, K> WeightedRoundRobin<I, K>
where
I: Debug,
K: Copy + Clone + Eq + Hash + IntoEnumIterator + Debug,
{
pub(crate) fn new(
weights: Vec<(K, NonZeroUsize)>,
initial_event_count_threshold: Option<usize>,
) -> Self {
assert!(!weights.is_empty(), "must provide at least one slot");
let queues = weights
.iter()
.map(|(idx, _)| (*idx, QueueState::new()))
.collect();
let slots: Vec<Slot<K>> = weights
.into_iter()
.map(|(key, tickets)| Slot {
key,
tickets: tickets.get(),
})
.collect();
let active_slot = slots[0];
WeightedRoundRobin {
state: Mutex::new(IterationState {
active_slot,
active_slot_idx: 0,
}),
slots,
queues,
total: Semaphore::new(0),
sealed: AtomicBool::new(false),
recent_event_count_peak: initial_event_count_threshold.map(AtomicUsize::new),
}
}
pub async fn dump<F: FnOnce(&QueueDump<K, I>)>(&self, dumper: F)
where
K: Ord,
{
let locks = self.lock_queues().await;
let mut queues = BTreeMap::new();
for (kind, guard) in &locks {
let queue = &**guard;
queues.insert(*kind, queue);
}
let queue_dump = QueueDump { queues };
dumper(&queue_dump);
}
async fn lock_queues(&self) -> Vec<(K, MutexGuard<'_, VecDeque<I>>)> {
let mut locks = Vec::new();
for kind in K::into_enum_iter() {
let queue_guard = self
.queues
.get(&kind)
.expect("missing queue while locking")
.queue
.lock()
.await;
locks.push((kind, queue_guard));
}
locks
}
}
/// Returns `true` when `total` exceeds `recent_threshold` by more than 10%.
///
/// The margin uses integer division, so the limit is `recent_threshold * 11 / 10` rounded
/// down. It is computed so that it saturates instead of overflowing. Thresholds close to
/// `usize::MAX` are therefore handled without panicking (debug) or wrapping (release).
fn should_dump_queues(total: usize, recent_threshold: usize) -> bool {
    // `t + t / 10` equals `(11 * t) / 10` for every `t` where the multiplication does not
    // overflow. The multiplication form overflows once `t > usize::MAX / 11`.
    total > recent_threshold.saturating_add(recent_threshold / 10)
}
impl<I, K> WeightedRoundRobin<I, K>
where
    K: Copy + Clone + Eq + Hash + Display,
{
    /// Enqueues `item` on the queue identified by `queue`.
    ///
    /// If the scheduler has been sealed, the item is dropped and only a debug message is
    /// logged.
    ///
    /// If an event count threshold was configured in `new`, the total number of queued events
    /// is compared with the most recently recorded peak after the item is added. If the total
    /// exceeds that peak by more than 10% (see `should_dump_queues`), the peak is raised to the
    /// new total. A warning listing every non-empty queue and its size is then logged.
    ///
    /// # Panics
    ///
    /// Panics if `queue` is not one of the keys the scheduler was created with.
    pub(crate) async fn push(&self, item: I, queue: K) {
        if self.sealed.load(Ordering::SeqCst) {
            debug!("queue sealed, dropping item");
            return;
        }

        self.queues
            .get(&queue)
            .expect("tried to push to non-existent queue")
            .push_back(item)
            .await;

        if let Some(recent_event_count_peak) = &self.recent_event_count_peak {
            let total = self.queues.iter().map(|q| q.1.event_count()).sum::<usize>();
            let recent_threshold = recent_event_count_peak.load(Ordering::SeqCst);
            // Raise the peak with a compare-and-swap rather than a plain store. When several
            // pushes race past the same threshold, only the one that actually moves the peak
            // logs. A smaller total also cannot overwrite a larger peak recorded concurrently.
            if should_dump_queues(total, recent_threshold)
                && recent_event_count_peak
                    .compare_exchange(
                        recent_threshold,
                        total,
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    )
                    .is_ok()
            {
                let info: Vec<_> = self
                    .queues
                    .iter()
                    .map(|q| (q.0.to_string(), q.1.event_count()))
                    .filter(|(_, count)| count > &0)
                    .collect();
                warn!("Current event queue size ({total}) is above the threshold ({recent_threshold}): details {info:?}");
            }
        }

        // Publish the item to `pop` only once it is actually in its queue, so every permit
        // `pop` acquires is backed by a queued item.
        self.total.add_permits(1);
    }

    /// Waits until at least one item is queued, then removes the next item according to the
    /// weighted round-robin schedule.
    ///
    /// Returns the item together with the key of the queue it came from. The active slot
    /// yields up to its ticket count of items in a row. Once its tickets are used up or its
    /// queue is empty, the scheduler advances to the next slot (wrapping around), which starts
    /// with its full ticket count.
    pub(crate) async fn pop(&self) -> (I, K) {
        self.total.acquire().await.expect("should acquire").forget();
        let mut inner = self.state.lock().await;
        loop {
            let queue_state = self
                .queues
                .get(&inner.active_slot.key)
                .expect("the queue disappeared. this should not happen");
            let mut current_queue = queue_state.queue.lock().await;
            if inner.active_slot.tickets == 0 || current_queue.is_empty() {
                // Move on to the next slot, restoring its full ticket count.
                inner.active_slot_idx = (inner.active_slot_idx + 1) % self.slots.len();
                inner.active_slot = self.slots[inner.active_slot_idx];
                continue;
            }
            inner.active_slot.tickets -= 1;
            let item = current_queue
                .pop_front()
                .expect("item disappeared. this should not happen");
            queue_state.dec_count();
            break (item, inner.active_slot.key);
        }
    }

    /// Removes and returns every item currently in `queue`. Also consumes one permit per
    /// removed item, so the permit count keeps matching the number of queued items.
    #[cfg(test)]
    pub(crate) async fn drain_queue(&self, queue: K) -> Vec<I> {
        let events = self
            .queues
            .get(&queue)
            .expect("queue to be drained disappeared")
            .drain()
            .await;
        self.total
            .acquire_many(events.len() as u32)
            .await
            .expect("could not acquire tickets during drain")
            .forget();
        events
    }

    /// Drains every queue and returns all removed items.
    #[cfg(test)]
    pub async fn drain_queues(&self) -> Vec<I> {
        let mut events = Vec::new();
        let keys: Vec<K> = self.queues.keys().cloned().collect();
        for kind in keys {
            events.extend(self.drain_queue(kind).await);
        }
        events
    }

    /// Seals the scheduler: subsequent `push` calls drop their items.
    #[cfg(test)]
    pub fn seal(&self) {
        self.sealed.store(true, Ordering::SeqCst);
    }

    /// Number of items currently available to `pop`, as tracked by the permit count.
    #[cfg(test)]
    pub(crate) fn item_count(&self) -> usize {
        self.total.available_permits()
    }

    /// Returns a snapshot of the event count of every queue.
    pub(crate) fn event_queues_counts(&self) -> HashMap<K, usize> {
        self.queues
            .iter()
            .map(|(key, queue)| (*key, queue.event_count()))
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use futures::{future::FutureExt, join};

    use super::*;

    /// Queue keys used by the tests.
    #[repr(usize)]
    #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, IntoEnumIterator)]
    enum QueueKind {
        One = 1,
        Two,
    }

    /// Test weights: `One` may yield one item per turn, `Two` may yield two.
    fn weights() -> Vec<(QueueKind, NonZeroUsize)> {
        // The literals are obviously non-zero, so the safe constructor is sufficient and no
        // `unsafe` block is needed.
        vec![
            (QueueKind::One, NonZeroUsize::new(1).expect("1 is non-zero")),
            (QueueKind::Two, NonZeroUsize::new(2).expect("2 is non-zero")),
        ]
    }

    impl Display for QueueKind {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                QueueKind::One => write!(f, "One"),
                QueueKind::Two => write!(f, "Two"),
            }
        }
    }

    #[tokio::test]
    async fn should_respect_weighting() {
        let scheduler = WeightedRoundRobin::<char, QueueKind>::new(weights(), None);
        let future1 = scheduler
            .push('a', QueueKind::One)
            .then(|_| scheduler.push('b', QueueKind::One))
            .then(|_| scheduler.push('c', QueueKind::One));
        let future2 = scheduler
            .push('d', QueueKind::Two)
            .then(|_| scheduler.push('e', QueueKind::Two))
            .then(|_| scheduler.push('f', QueueKind::Two));
        join!(future2, future1);
        // One item from `One`, then two from `Two`, repeating; an exhausted queue is skipped.
        assert_eq!(('a', QueueKind::One), scheduler.pop().await);
        assert_eq!(('d', QueueKind::Two), scheduler.pop().await);
        assert_eq!(('e', QueueKind::Two), scheduler.pop().await);
        assert_eq!(('b', QueueKind::One), scheduler.pop().await);
        assert_eq!(('f', QueueKind::Two), scheduler.pop().await);
        assert_eq!(('c', QueueKind::One), scheduler.pop().await);
    }

    #[tokio::test]
    async fn can_seal_queue() {
        let scheduler = WeightedRoundRobin::<char, QueueKind>::new(weights(), None);
        assert_eq!(scheduler.item_count(), 0);
        scheduler.push('a', QueueKind::One).await;
        assert_eq!(scheduler.item_count(), 1);
        scheduler.push('b', QueueKind::Two).await;
        assert_eq!(scheduler.item_count(), 2);
        scheduler.seal();
        assert_eq!(scheduler.item_count(), 2);
        // Pushes after sealing are dropped and must not change the item count.
        scheduler.push('c', QueueKind::One).await;
        assert_eq!(scheduler.item_count(), 2);
        scheduler.push('d', QueueKind::One).await;
        assert_eq!(scheduler.item_count(), 2);
        assert_eq!(('a', QueueKind::One), scheduler.pop().await);
        assert_eq!(scheduler.item_count(), 1);
        assert_eq!(('b', QueueKind::Two), scheduler.pop().await);
        assert_eq!(scheduler.item_count(), 0);
        assert!(scheduler.drain_queues().await.is_empty());
    }

    #[test]
    fn should_calculate_dump_threshold() {
        let total = 0;
        let recent_threshold = 100;
        assert!(!should_dump_queues(total, recent_threshold));
        let total = 100;
        let recent_threshold = 100;
        assert!(!should_dump_queues(total, recent_threshold));
        let total = 109;
        let recent_threshold = 100;
        assert!(!should_dump_queues(total, recent_threshold));
        let total = 110;
        let recent_threshold = 100;
        assert!(!should_dump_queues(total, recent_threshold));
        let total = 111;
        let recent_threshold = 100;
        assert!(should_dump_queues(total, recent_threshold));
        let total = 112;
        let recent_threshold = 100;
        assert!(should_dump_queues(total, recent_threshold));
        let total = 1_000_000;
        let recent_threshold = 100;
        assert!(should_dump_queues(total, recent_threshold));
    }
}