#![recursion_limit = "256"]
#![deny(missing_docs)]
use futures_channel::mpsc::{self, Receiver, Sender, UnboundedReceiver, UnboundedSender};
use futures_util::{
ready, select,
stream::{FuturesUnordered, StreamExt as _},
};
use std::{
collections::VecDeque,
error, fmt,
future::Future,
pin::Pin,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, Mutex,
},
task::{Context, Poll, Waker},
time::Duration,
};
/// Errors returned by the rate limiter API in this module.
#[derive(Debug)]
pub enum Error {
/// Returned by `LeakyBuckets::coordinate` when the coordinator's receiver
/// has already been taken by an earlier call.
AlreadyStarted,
/// Sending a task over the channel to the coordinator failed; wraps the
/// channel's send error.
TaskSendError(mpsc::SendError),
/// Returned by `Builder::build` when the newly built rate limiter could not
/// be sent to the coordinator.
NewTaskError,
}
impl fmt::Display for Error {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
use self::Error::*;
match self {
AlreadyStarted => write!(fmt, "already started"),
TaskSendError(e) => write!(fmt, "failed to send task to coordinator: {}", e),
NewTaskError => write!(fmt, "failed to queue up new task coordinator"),
}
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
use self::Error::*;
match self {
TaskSendError(e) => Some(e),
_ => None,
}
}
}
/// Message sent from `Builder::build` to `LeakyBuckets::coordinate` describing
/// a newly built rate limiter.
struct NewTask {
/// Shared state of the rate limiter; the coordinator calls
/// `Inner::coordinate` on it.
inner: Arc<Inner>,
/// Receiving end of the rate limiter's task channel, handed to
/// `Inner::coordinate`.
rx: Receiver<Task>,
}
/// State shared between all clones of a `LeakyBuckets`.
struct LeakyBucketsInner {
/// Sender cloned into every `Builder` so that `build` can register new rate
/// limiters with the coordinator.
tx: UnboundedSender<NewTask>,
/// Receiver for those registrations; taken (leaving `None`) by the first
/// call to `LeakyBuckets::coordinate`.
rx: Mutex<Option<UnboundedReceiver<NewTask>>>,
}
/// Handle used to build rate limiters and to run the coordinator that drives
/// them.
///
/// Clones share the same inner state through an `Arc`.
#[derive(Clone)]
pub struct LeakyBuckets {
/// Channel endpoints shared by all clones of this handle.
inner: Arc<LeakyBucketsInner>,
}
impl Default for LeakyBuckets {
fn default() -> Self {
Self::new()
}
}
impl LeakyBuckets {
pub fn new() -> Self {
let (tx, rx) = mpsc::unbounded();
let inner = Arc::new(LeakyBucketsInner {
tx,
rx: Mutex::new(Some(rx)),
});
LeakyBuckets { inner }
}
pub async fn coordinate(self) -> Result<(), Error> {
let mut rx = match self.inner.rx.lock().expect("ok mutex").take() {
Some(rx) => rx,
None => return Err(Error::AlreadyStarted),
};
let mut futures = FuturesUnordered::new();
loop {
while futures.is_empty() {
select! {
NewTask { inner, rx } = rx.select_next_some() => {
futures.push(inner.coordinate(rx));
}
}
}
select! {
_ = futures.next() => {
panic!("coordinator task exited unexpectedly");
}
NewTask { inner, rx } = rx.select_next_some() => {
futures.push(inner.coordinate(rx));
}
}
}
}
pub fn rate_limiter(&self) -> Builder {
Builder {
new_task_tx: self.inner.tx.clone(),
tokens: None,
max: None,
refill_interval: None,
refill_amount: None,
}
}
}
/// Builder for a `LeakyBucket`, created by `LeakyBuckets::rate_limiter`.
pub struct Builder {
/// Sender used by `build` to register the new rate limiter with the
/// coordinator.
new_task_tx: UnboundedSender<NewTask>,
/// Value set through `tokens`; `build` initializes the shared counter to
/// `max` minus this value (saturating at zero).
tokens: Option<usize>,
/// Value set through `max`; the bound the shared counter is compared against
/// when acquiring.
max: Option<usize>,
/// Value set through `refill_interval`; period of the coordinator's refill
/// timer.
refill_interval: Option<Duration>,
/// Value set through `refill_amount`; amount added on every refill tick.
refill_amount: Option<usize>,
}
impl Builder {
#[inline(always)]
pub fn max(mut self, max: usize) -> Self {
self.max = Some(max);
self
}
#[inline(always)]
pub fn tokens(mut self, tokens: usize) -> Self {
self.tokens = Some(tokens);
self
}
#[inline(always)]
pub fn refill_interval(mut self, refill_interval: Duration) -> Self {
self.refill_interval = Some(refill_interval);
self
}
#[inline(always)]
pub fn refill_amount(mut self, refill_amount: usize) -> Self {
self.refill_amount = Some(refill_amount);
self
}
pub fn build(self) -> Result<LeakyBucket, Error> {
const DEFAULT_MAX: usize = 120;
const DEFAULT_TOKENS: usize = 0;
const DEFAULT_REFILL_INTERVAL: Duration = Duration::from_secs(1);
const DEFAULT_REFILL_AMOUNT: usize = 1;
let max = self.max.unwrap_or(DEFAULT_MAX);
let tokens = max.saturating_sub(self.tokens.unwrap_or(DEFAULT_TOKENS));
let refill_interval = self.refill_interval.unwrap_or(DEFAULT_REFILL_INTERVAL);
let refill_amount = self.refill_amount.unwrap_or(DEFAULT_REFILL_AMOUNT);
let tokens = AtomicUsize::new(tokens);
let (tx, rx) = mpsc::channel(1);
let inner = Arc::new(Inner {
tokens,
max,
refill_interval,
refill_amount,
tx,
});
self.new_task_tx
.unbounded_send(NewTask {
inner: inner.clone(),
rx,
})
.map_err(|_| Error::NewTaskError)?;
Ok(LeakyBucket { inner })
}
}
/// A queued acquisition waiting for the coordinator to hand it tokens.
struct Task {
/// Number of tokens the coordinator must have accumulated before it
/// completes this task.
required: usize,
/// Waker cloned from the polling context when the task was queued.
waker: Waker,
/// Set to `true` by the coordinator once the task is satisfied; read by
/// `Acquire::poll`.
complete: Arc<AtomicBool>,
}
/// Shared state of a single rate limiter.
struct Inner {
/// Shared counter: `Acquire::poll` adds the requested amount to it and
/// `balance_tokens` subtracts refilled amounts from it.
tokens: AtomicUsize,
/// Bound the counter is compared against in `Acquire::poll`.
max: usize,
/// Period of the refill timer created in `coordinate`.
refill_interval: Duration,
/// Amount added to the coordinator's local budget on every refill tick.
refill_amount: usize,
/// Sender side of the task channel; cloned into every `Acquire`.
tx: Sender<Task>,
}
impl Inner {
    /// Drives a single rate limiter.
    ///
    /// Incoming [`Task`]s are queued in arrival order. On every tick of the
    /// refill timer `refill_amount` is added to a local budget, which is then
    /// handed to waiting tasks front to back. A task is completed (flag set,
    /// waker woken) once the budget covers its `required` amount. When no task
    /// is waiting, the budget is credited to the shared counter through
    /// `balance_tokens` and reset.
    ///
    /// The loop contains no `return`, so the `Result` is never produced by
    /// this body.
    async fn coordinate(self: Arc<Inner>, mut rx: Receiver<Task>) -> Result<(), Error> {
        // Tasks waiting behind `current`, in arrival order.
        let mut tasks = VecDeque::new();
        let mut interval = tokio_timer::Interval::new_interval(self.refill_interval);
        // Refilled tokens accumulated locally and not yet handed out.
        let mut amount = 0;
        // Task at the head of the line that the budget could not cover yet.
        let mut current = None;

        'outer: loop {
            select! {
                waker = rx.select_next_some() => {
                    tasks.push_back(waker);
                },
                _ = interval.select_next_some() => {
                    amount += self.refill_amount;

                    let mut task = match current.take().or_else(|| tasks.pop_front()) {
                        Some(task) => task,
                        None => {
                            // Nobody is waiting: credit the whole budget to the
                            // shared counter.
                            self.balance_tokens(amount);
                            amount = 0;
                            continue;
                        }
                    };

                    while amount > 0 && amount >= task.required {
                        amount -= task.required;
                        // NOTE(review): completing a task does not lower
                        // `self.tokens`, so the amount `Acquire::poll` added
                        // for it stays in the counter — confirm this is the
                        // intended accounting.
                        task.complete.store(true, Ordering::Release);
                        task.waker.wake();

                        task = match tasks.pop_front() {
                            Some(task) => task,
                            None => {
                                // Queue drained: credit any leftover budget.
                                if amount > 0 {
                                    self.balance_tokens(amount);
                                    amount = 0;
                                }
                                continue 'outer;
                            },
                        };
                    }

                    // The budget does not cover this task yet. Keep it at the
                    // head of the line and keep the partial budget for it. The
                    // budget must NOT also be credited to the shared counter
                    // here: it would be counted again on every later tick and
                    // let the limiter hand out more than was refilled.
                    current = Some(task);
                },
            }
        }
    }

    /// Subtracts `amount` from the shared counter, saturating at zero.
    ///
    /// Retries the compare-and-swap until it succeeds; does nothing if the
    /// counter is observed to be zero.
    fn balance_tokens(&self, amount: usize) {
        let mut current = self.tokens.load(Ordering::Acquire);

        while current > 0 {
            let new = current.saturating_sub(amount);

            match self.tokens.compare_exchange_weak(
                current,
                new,
                Ordering::SeqCst,
                Ordering::Acquire,
            ) {
                Ok(_) => break,
                // Lost a race (or spurious failure): retry with the fresh value.
                Err(x) => current = x,
            }
        }
    }
}
/// A rate limiter produced by `Builder::build`.
///
/// Clones share the same state through an `Arc`.
#[derive(Clone)]
pub struct LeakyBucket {
/// Shared counter, configuration and task sender of this rate limiter.
inner: Arc<Inner>,
}
impl LeakyBucket {
    /// Acquires a single token; shorthand for `acquire(1)`.
    pub fn acquire_one(&self) -> Acquire<'_> {
        self.acquire(1)
    }

    /// Creates a future that acquires `amount` tokens from this rate limiter.
    ///
    /// Nothing is acquired until the returned future is polled.
    pub fn acquire(&self, amount: usize) -> Acquire<'_> {
        let inner = &*self.inner;
        let tx = inner.tx.clone();

        Acquire {
            tokens: &inner.tokens,
            max: inner.max,
            amount,
            tx,
            queued: None,
        }
    }
}
/// Future returned by `LeakyBucket::acquire` and `LeakyBucket::acquire_one`.
pub struct Acquire<'a> {
/// Borrowed shared counter of the rate limiter.
tokens: &'a AtomicUsize,
/// Copy of the rate limiter's `max`.
max: usize,
/// Number of tokens requested.
amount: usize,
/// Sender clone used to queue a `Task` with the coordinator.
tx: Sender<Task>,
/// Completion flag of the queued task; `None` until a task has been queued.
queued: Option<Arc<AtomicBool>>,
}
impl Future for Acquire<'_> {
    /// `Ok(())` once the tokens were acquired, or `Error::TaskSendError` if
    /// the task could not be handed to the coordinator.
    type Output = Result<(), Error>;

    /// Tries to acquire the requested tokens.
    ///
    /// The first poll adds the requested amount to the shared counter. If the
    /// counter stays within `max` the future resolves immediately. Otherwise
    /// the part that does not fit is queued as a [`Task`] for the coordinator,
    /// and later polls only check that task's completion flag.
    ///
    /// The waker is captured once, when the task is queued; later polls do not
    /// refresh it.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Already queued: the coordinator flips the flag when we are served.
        if let Some(complete) = &self.queued {
            return if complete.load(Ordering::Acquire) {
                Poll::Ready(Ok(()))
            } else {
                Poll::Pending
            };
        }

        let mut required = self.amount;
        let current = self.tokens.fetch_add(required, Ordering::AcqRel);

        // Fast path: the request fits, including the case where it fills the
        // bucket exactly. A strict `<` here would queue a zero-token task and
        // wait a full refill tick for nothing.
        if current + required <= self.max {
            return Poll::Ready(Ok(()));
        }

        // Part of the request fit before `max` was reached; only the
        // remainder has to be waited for.
        if current < self.max {
            required -= self.max - current;
        }

        match ready!(self.tx.poll_ready(cx)) {
            Ok(()) => (),
            Err(e) => return Poll::Ready(Err(Error::TaskSendError(e))),
        }

        let waker = cx.waker().clone();
        let complete = Arc::new(AtomicBool::new(false));
        self.queued = Some(complete.clone());

        if let Err(e) = self.tx.start_send(Task {
            required,
            waker,
            complete,
        }) {
            return Poll::Ready(Err(Error::TaskSendError(e)));
        }

        Poll::Pending
    }
}
// Tests that run a rate limiter against its coordinator on a single-threaded
// runtime.
#[cfg(test)]
mod tests {
use super::{Error, LeakyBuckets};
use futures::prelude::*;
use std::time::{Duration, Instant};
use tokio::{runtime::current_thread::Runtime, timer};
/// Three sequential `acquire(10)` calls on a limiter built with `tokens(0)`,
/// `max(10)` and a refill of 10 every 20ms must all complete and take longer
/// than two refill intervals in total.
#[test]
fn test_leaky_bucket() {
let mut rt = Runtime::new().expect("working runtime");
rt.block_on(async move {
let interval = Duration::from_millis(20);
let buckets = LeakyBuckets::new();
let leaky = buckets
.rate_limiter()
.tokens(0)
.max(10)
.refill_amount(10)
.refill_interval(interval)
.build()
.expect("build rate limiter");
// Counts how many acquisitions completed.
let mut wakeups = 0;
// Total time taken by the three acquisitions, set when the last one completes.
let mut duration = None;
let test = async {
let start = Instant::now();
leaky.acquire(10).await?;
wakeups += 1;
leaky.acquire(10).await?;
wakeups += 1;
leaky.acquire(10).await?;
wakeups += 1;
duration = Some(Instant::now().duration_since(start));
Ok::<_, Error>(())
};
// Race the test body against the coordinator; whichever finishes first ends the select.
futures::future::select(test.boxed(), buckets.coordinate().boxed()).await;
assert_eq!(3, wakeups);
assert!(duration.expect("expected measured duration") > interval * 2);
});
}
/// Two loops acquiring one token each compete for a limiter refilled by 1
/// every 20ms; after a 200ms delay the combined number of completed
/// acquisitions must be strictly between 5 and 15.
#[test]
fn test_concurrent_rate_limited() {
let mut rt = Runtime::new().expect("working runtime");
rt.block_on(async move {
let interval = Duration::from_millis(20);
let buckets = LeakyBuckets::new();
let leaky = buckets
.rate_limiter()
.tokens(0)
.max(10)
.refill_amount(1)
.refill_interval(interval)
.build()
.expect("build rate limiter");
// Completed acquisitions of the first loop.
let mut one_wakeups = 0;
let one = async {
loop {
leaky.acquire(1).await?;
one_wakeups += 1;
}
// Unreachable; present only to fix the async block's output type.
#[allow(unreachable_code)]
Ok::<_, Error>(())
};
// Completed acquisitions of the second loop.
let mut two_wakeups = 0;
let two = async {
loop {
leaky.acquire(1).await?;
two_wakeups += 1;
}
// Unreachable; present only to fix the async block's output type.
#[allow(unreachable_code)]
Ok::<_, Error>(())
};
// The 200ms delay bounds the test, since neither loop finishes on its own unless an acquire fails.
let delay = timer::delay(Instant::now() + Duration::from_millis(200));
let task = future::select(one.boxed(), two.boxed());
let task = future::select(task, delay);
future::select(task, buckets.coordinate().boxed()).await;
let total = one_wakeups + two_wakeups;
assert!(total > 5 && total < 15);
});
}
}