use std::{
collections::{HashMap, VecDeque},
fmt::Debug,
fs::File,
hash::Hash,
io::{self, BufWriter, Write},
num::NonZeroUsize,
sync::atomic::{AtomicUsize, Ordering},
};
use enum_iterator::IntoEnumIterator;
use serde::{ser::SerializeMap, Serialize, Serializer};
use tokio::sync::{Mutex, Semaphore};
/// A weighted round-robin scheduler over a fixed set of item queues, one per key `K`.
///
/// Each key is assigned a slot holding a number of tickets (its weight). `pop` takes up to
/// that many items in a row from the active slot's queue before advancing to the next slot.
#[derive(Debug)]
pub struct WeightedRoundRobin<I, K> {
/// Current position in the slot rotation; held locked for the duration of a `pop`.
state: Mutex<IterationState<K>>,
/// Slots in rotation order, each carrying the full ticket count for its key.
slots: Vec<Slot<K>>,
/// The per-key item queues.
queues: HashMap<K, QueueState<I>>,
/// One permit per item that has been pushed but not yet claimed by a `pop`.
total: Semaphore,
}
/// A single FIFO item queue plus a counter tracking how many items it holds.
#[derive(Debug)]
struct QueueState<I> {
/// Tracks the number of items in `queue`; can be read without taking the queue lock.
event_count: AtomicUsize,
/// The queued items; appended at the back and removed from the front.
queue: Mutex<VecDeque<I>>,
}
impl<I> QueueState<I> {
    /// Creates an empty queue with its event counter at zero.
    fn new() -> Self {
        QueueState {
            event_count: AtomicUsize::new(0),
            queue: Mutex::new(VecDeque::new()),
        }
    }

    /// Appends `element` to the back of the queue and increments the event counter.
    ///
    /// The counter is incremented while the queue lock is still held. Anyone removing
    /// `element` must first acquire that same lock, so the matching `dec_count` can never
    /// run before this increment. Otherwise `fetch_sub` on a zero counter would wrap it
    /// around to `usize::MAX` until the late increment arrived.
    #[inline]
    async fn push_back(&self, element: I) {
        let mut queue = self.queue.lock().await;
        queue.push_back(element);
        self.event_count.fetch_add(1, Ordering::SeqCst);
    }

    /// Decrements the event counter after an element was removed from the queue.
    ///
    /// Call this while still holding the queue lock used to remove the element, so the
    /// counter stays in step with the queue's contents.
    #[inline]
    fn dec_count(&self) {
        self.event_count.fetch_sub(1, Ordering::SeqCst);
    }

    /// Returns the current value of the event counter without locking the queue.
    #[inline]
    fn event_count(&self) -> usize {
        self.event_count.load(Ordering::SeqCst)
    }
}
/// Rotation state for `pop`: the slot currently being served and its index in `slots`.
#[derive(Copy, Clone, Debug)]
struct IterationState<K> {
/// Copy of the active slot; its `tickets` count down as items are taken from it.
active_slot: Slot<K>,
/// Index of `active_slot` within the scheduler's `slots`.
active_slot_idx: usize,
}
/// A queue key paired with a ticket count.
#[derive(Copy, Clone, Debug)]
struct Slot<K> {
/// Key of the queue this slot serves.
key: K,
/// How many items this slot may yield in a row before the rotation moves on.
tickets: usize,
}
impl<I, K> WeightedRoundRobin<I, K>
where
    I: Serialize,
    K: Copy + Clone + Eq + Hash + IntoEnumIterator + Serialize,
{
    /// Serializes every queue's contents as a map from queue key to its items.
    ///
    /// All queue locks are acquired, in `K::into_enum_iter()` order, before anything is
    /// written. No queue can change while the map is being built.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `serializer`.
    ///
    /// # Panics
    ///
    /// Panics if a variant of `K` has no corresponding queue.
    pub async fn snapshot<S: Serializer>(&self, serializer: S) -> Result<(), S::Error> {
        let mut guards = Vec::new();
        for key in K::into_enum_iter() {
            let state = self
                .queues
                .get(&key)
                .expect("missing queue while snapshotting");
            let guard = state.queue.lock().await;
            guards.push((key, guard));
        }

        let entry_count = guards.len();
        let mut map = serializer.serialize_map(Some(entry_count))?;
        for (key, guard) in guards {
            // Consuming iteration releases each queue's lock once it has been written.
            map.serialize_key(&key)?;
            map.serialize_value(&*guard)?;
        }
        map.end().map(|_ok| ())
    }
}
impl<I, K> WeightedRoundRobin<I, K>
where
    I: Debug,
    K: Copy + Clone + Eq + Hash + IntoEnumIterator + Debug,
{
    /// Writes a human-readable dump of every queue's contents to `file`.
    ///
    /// All queue locks are acquired, in `K::into_enum_iter()` order, before writing.
    /// Each queue is rendered in three parts:
    /// - a header line `Queue: <key> (<len>) [`
    /// - one tab-indented `Debug` line per item
    /// - a closing `]` line
    ///
    /// NOTE(review): the writes are synchronous file I/O performed inside an async fn while
    /// every queue lock is held. Pushes and pops block until the dump has finished.
    ///
    /// # Errors
    ///
    /// Returns any I/O error raised while writing or flushing.
    ///
    /// # Panics
    ///
    /// Panics if a variant of `K` has no corresponding queue.
    pub async fn debug_dump(&self, file: &mut File) -> Result<(), io::Error> {
        let mut locks = Vec::new();
        for kind in K::into_enum_iter() {
            let queue_guard = self
                .queues
                .get(&kind)
                .expect("missing queue while dumping")
                .queue
                .lock()
                .await;
            locks.push((kind, queue_guard));
        }

        let mut writer = BufWriter::new(file);
        for (kind, guard) in locks {
            let queue = &*guard;
            // `writeln!` formats straight into the buffered writer rather than allocating an
            // intermediate `String` per line with `format!`; the bytes written are identical.
            writeln!(writer, "Queue: {:?} ({}) [", kind, queue.len())?;
            for event in queue.iter() {
                writeln!(writer, "\t{:?}", event)?;
            }
            writeln!(writer, "]")?;
        }
        writer.flush()
    }
}
impl<I, K> WeightedRoundRobin<I, K>
where
    K: Copy + Clone + Eq + Hash,
{
    /// Creates a scheduler with one empty queue per entry in `weights`.
    ///
    /// Slots are visited in the order given. Each slot may yield up to its weight in
    /// consecutive items before the rotation advances. The rotation starts at the first slot.
    ///
    /// # Panics
    ///
    /// Panics if `weights` is empty.
    pub(crate) fn new(weights: Vec<(K, NonZeroUsize)>) -> Self {
        assert!(!weights.is_empty(), "must provide at least one slot");

        let mut queues = HashMap::new();
        let mut slots = Vec::with_capacity(weights.len());
        for (key, weight) in weights {
            queues.insert(key, QueueState::new());
            slots.push(Slot {
                key,
                tickets: weight.get(),
            });
        }

        let initial = IterationState {
            active_slot: slots[0],
            active_slot_idx: 0,
        };
        WeightedRoundRobin {
            state: Mutex::new(initial),
            slots,
            queues,
            total: Semaphore::new(0),
        }
    }

    /// Appends `item` to the queue identified by `queue` and makes it available to `pop`.
    ///
    /// # Panics
    ///
    /// Panics if no queue exists for `queue`.
    pub(crate) async fn push(&self, item: I, queue: K) {
        let target = self
            .queues
            .get(&queue)
            .expect("tried to push to non-existent queue");
        target.push_back(item).await;
        // The permit is published only after the item is enqueued. A `pop` that claims it
        // is therefore guaranteed to find an item.
        self.total.add_permits(1);
    }

    /// Waits for an item, then removes and returns the next one in weighted order.
    ///
    /// The item is returned together with the key of the queue it came from.
    ///
    /// The active slot keeps serving while it has tickets left and its queue is non-empty.
    /// Otherwise the rotation moves to the next slot, whose tickets start again at its
    /// configured weight. Concurrent callers are serialized on the rotation state.
    pub(crate) async fn pop(&self) -> (I, K) {
        // Claim one item up front. Permits exist only for already-enqueued items, so the
        // scan below always finds something to take.
        self.total.acquire().await.forget();

        let mut rotation = self.state.lock().await;
        loop {
            let key = rotation.active_slot.key;
            let queue_state = self
                .queues
                .get(&key)
                .expect("the queue disappeared. this should not happen");
            let mut items = queue_state.queue.lock().await;

            let can_serve = rotation.active_slot.tickets > 0 && !items.is_empty();
            if can_serve {
                rotation.active_slot.tickets -= 1;
                let item = items
                    .pop_front()
                    .expect("item disappeared. this should not happen");
                queue_state.dec_count();
                return (item, key);
            }

            // Advance (wrapping around) to the next slot with its full ticket allowance.
            let next_idx = (rotation.active_slot_idx + 1) % self.slots.len();
            rotation.active_slot_idx = next_idx;
            rotation.active_slot = self.slots[next_idx];
        }
    }

    /// Returns how many pushed items have not yet been claimed by a `pop`.
    pub(crate) fn item_count(&self) -> usize {
        self.total.available_permits()
    }

    /// Returns the current event counter of every queue, keyed by queue.
    pub(crate) fn event_queues_counts(&self) -> HashMap<K, usize> {
        let mut counts = HashMap::with_capacity(self.queues.len());
        for (key, state) in self.queues.iter() {
            counts.insert(*key, state.event_count());
        }
        counts
    }
}
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use futures::{future::FutureExt, join};

    use super::*;

    #[repr(usize)]
    #[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
    enum QueueKind {
        One = 1,
        Two,
    }

    /// Converts a literal test weight into a `NonZeroUsize`.
    ///
    /// Uses the checked constructor instead of `new_unchecked`. A zero weight then fails the
    /// test with a clear message instead of being undefined behavior.
    fn weight(value: usize) -> NonZeroUsize {
        NonZeroUsize::new(value).expect("test weights must be non-zero")
    }

    /// Weights used by the tests: `Two` gets two tickets per round, `One` gets one.
    fn weights() -> Vec<(QueueKind, NonZeroUsize)> {
        vec![(QueueKind::One, weight(1)), (QueueKind::Two, weight(2))]
    }

    #[tokio::test]
    async fn should_respect_weighting() {
        let scheduler = WeightedRoundRobin::<char, QueueKind>::new(weights());
        let future1 = scheduler
            .push('a', QueueKind::One)
            .then(|_| scheduler.push('b', QueueKind::One))
            .then(|_| scheduler.push('c', QueueKind::One));
        let future2 = scheduler
            .push('d', QueueKind::Two)
            .then(|_| scheduler.push('e', QueueKind::Two))
            .then(|_| scheduler.push('f', QueueKind::Two));
        join!(future2, future1);
        assert_eq!(('a', QueueKind::One), scheduler.pop().await);
        assert_eq!(('d', QueueKind::Two), scheduler.pop().await);
        assert_eq!(('e', QueueKind::Two), scheduler.pop().await);
        assert_eq!(('b', QueueKind::One), scheduler.pop().await);
        assert_eq!(('f', QueueKind::Two), scheduler.pop().await);
        assert_eq!(('c', QueueKind::One), scheduler.pop().await);
    }

    #[tokio::test]
    async fn should_track_item_and_event_counts() {
        let scheduler = WeightedRoundRobin::<char, QueueKind>::new(weights());
        scheduler.push('a', QueueKind::One).await;
        scheduler.push('b', QueueKind::Two).await;
        scheduler.push('c', QueueKind::Two).await;

        assert_eq!(3, scheduler.item_count());
        let counts = scheduler.event_queues_counts();
        assert_eq!(Some(&1), counts.get(&QueueKind::One));
        assert_eq!(Some(&2), counts.get(&QueueKind::Two));

        assert_eq!(('a', QueueKind::One), scheduler.pop().await);

        assert_eq!(2, scheduler.item_count());
        let counts = scheduler.event_queues_counts();
        assert_eq!(Some(&0), counts.get(&QueueKind::One));
        assert_eq!(Some(&2), counts.get(&QueueKind::Two));
    }
}