use super::{IDENTIFY_DELAY, LIMIT_PERIOD, Queue};
use std::{collections::VecDeque, fmt::Debug, iter};
use tokio::{
sync::{mpsc, oneshot},
task::yield_now,
time::{Duration, Instant, sleep_until},
};
/// Command sent from [`InMemoryQueue`] handles to the background runner task.
#[derive(Debug)]
enum Message {
    /// Ask for a permit for `shard`; `tx` is signalled with `()` once the
    /// runner grants it.
    Request {
        shard: u32,
        tx: oneshot::Sender<()>,
    },
    /// Replace the runner's current settings.
    Update(Settings),
}
/// Rate-limit configuration consumed by the runner task.
#[derive(Debug)]
struct Settings {
    /// Number of buckets pending requests are spread across
    /// (by `shard % max_concurrency`).
    max_concurrency: u16,
    /// Permits still available before the next refill.
    remaining: u32,
    /// Delay until `remaining` is refilled to `total`; the runner only arms
    /// this deadline when `remaining != total`.
    reset_after: Duration,
    /// Number of permits restored by each refill.
    total: u32,
}
async fn runner(
mut rx: mpsc::UnboundedReceiver<Message>,
Settings {
max_concurrency,
mut remaining,
reset_after,
mut total,
}: Settings,
) {
let (interval, reset_at) = {
let now = Instant::now();
(sleep_until(now), sleep_until(now + reset_after))
};
tokio::pin!(interval, reset_at);
let mut queues = iter::repeat_with(VecDeque::new)
.take(max_concurrency.into())
.collect::<Box<_>>();
#[allow(clippy::ignored_unit_patterns)]
loop {
tokio::select! {
biased;
_ = &mut reset_at, if remaining != total => {
remaining = total;
}
message = rx.recv() => {
match message {
Some(Message::Request { shard, tx }) => {
if queues.is_empty() {
_ = tx.send(());
} else {
let key = shard as usize % queues.len();
queues[key].push_back((shard, tx));
}
}
Some(Message::Update(update)) => {
let (max_concurrency, reset_after);
Settings {
max_concurrency,
remaining,
reset_after,
total,
} = update;
if remaining != total {
reset_at.as_mut().reset(Instant::now() + reset_after);
}
if max_concurrency as usize != queues.len() {
let unbalanced = queues.into_vec().into_iter().flatten();
queues = iter::repeat_with(VecDeque::new)
.take(max_concurrency.into())
.collect();
for (shard, tx) in unbalanced {
let key = (shard % u32::from(max_concurrency)) as usize;
queues[key].push_back((shard, tx));
}
}
}
None => break,
}
}
_ = &mut interval, if queues.iter().any(|queue| !queue.is_empty()) => {
let now = Instant::now();
let span = tracing::info_span!("bucket", moment = ?now);
interval.as_mut().reset(now + IDENTIFY_DELAY);
if remaining == total {
reset_at.as_mut().reset(now + LIMIT_PERIOD);
}
for (key, queue) in queues.iter_mut().enumerate() {
if remaining == 0 {
tracing::debug!(
refill_delay = ?reset_at.deadline().saturating_duration_since(now),
"exhausted available permits"
);
(&mut reset_at).await;
remaining = total;
break;
}
while let Some((shard, tx)) = queue.pop_front() {
if tx.send(()).is_err() {
continue;
}
tracing::debug!(parent: &span, key, shard);
remaining -= 1;
yield_now().await;
break;
}
}
}
}
}
}
/// In-memory [`Queue`] that hands out permits to shards at a limited rate.
///
/// All clones send to the same background task, which stops once every clone
/// has been dropped.
#[derive(Clone, Debug)]
pub struct InMemoryQueue {
    /// Channel to the background runner task.
    tx: mpsc::UnboundedSender<Message>,
}
impl InMemoryQueue {
    /// Creates a queue and spawns the background task that grants permits.
    ///
    /// * `max_concurrency` — number of buckets pending requests are spread
    ///   across (by `shard % max_concurrency`); at most one request per bucket
    ///   is approved per tick, and with 0 buckets requests are approved at once.
    /// * `remaining` — permits available right now.
    /// * `reset_after` — delay until `remaining` is refilled to `total`; only
    ///   relevant when `remaining != total`.
    /// * `total` — permits restored by each refill.
    ///
    /// # Panics
    ///
    /// Panics if `total` is less than `remaining`, or if called outside of a
    /// Tokio runtime (the background task is started with [`tokio::spawn`]).
    pub fn new(max_concurrency: u16, remaining: u32, reset_after: Duration, total: u32) -> Self {
        assert!(total >= remaining);

        let initial = Settings {
            max_concurrency,
            remaining,
            reset_after,
            total,
        };
        let (sender, receiver) = mpsc::unbounded_channel();
        tokio::spawn(runner(receiver, initial));

        Self { tx: sender }
    }

    /// Replaces the queue's settings.
    ///
    /// Parameters have the same meaning as in [`InMemoryQueue::new`]. When
    /// `max_concurrency` changes, pending requests are redistributed across
    /// the new buckets.
    ///
    /// # Panics
    ///
    /// Panics if `total` is less than `remaining`, or if the background task
    /// is no longer running.
    pub fn update(&self, max_concurrency: u16, remaining: u32, reset_after: Duration, total: u32) {
        assert!(total >= remaining);

        let replacement = Settings {
            max_concurrency,
            remaining,
            reset_after,
            total,
        };
        let message = Message::Update(replacement);

        self.tx
            .send(message)
            .expect("receiver dropped after sender");
    }
}
impl Default for InMemoryQueue {
    /// Creates a queue with a `max_concurrency` of 1 and a full budget of 1000
    /// permits, using [`LIMIT_PERIOD`] as `reset_after`.
    fn default() -> Self {
        let permits = 1000;
        Self::new(1, permits, LIMIT_PERIOD, permits)
    }
}
impl Queue for InMemoryQueue {
    /// Queues a permit request for `shard`.
    ///
    /// The returned receiver completes with `Ok(())` once the background task
    /// grants the permit.
    ///
    /// # Panics
    ///
    /// Panics if the background task is no longer running.
    fn enqueue(&self, shard: u32) -> oneshot::Receiver<()> {
        let (granted, waiter) = oneshot::channel();
        let request = Message::Request { shard, tx: granted };

        self.tx
            .send(request)
            .expect("receiver dropped after sender");

        waiter
    }
}
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use static_assertions::assert_impl_all;

    use super::InMemoryQueue;
    use crate::Queue;

    // Compile-time check that the queue is cloneable, shareable across threads
    // and usable wherever a `Queue` implementation is expected.
    assert_impl_all!(InMemoryQueue: Clone, Debug, Default, Send, Sync, Queue);
}