pub mod builder;
use crate::{
BoxFuture,
blocks::Blocks,
ingress::Ingress,
intervals::Intervals,
task::Task,
traits::{Key, Priority, TaskResult},
worker::Worker,
};
use std::{collections::BinaryHeap, sync::Arc, time::Duration};
use tokio::{
sync::{Mutex, RwLock, oneshot},
time::Instant,
};
/// Priority-aware task limiter, generic over a key type `K`, a priority type
/// `P` and a task result type `T`.
///
/// The limiter owns a heap of pending [`Task`]s, the [`Ingress`] that was
/// spawned with a handle to that heap, and the [`Worker`]s spawned from the
/// ingress. Every worker is handed the heap together with the shared
/// [`Blocks`] and [`Intervals`] settings.
///
/// Fields are dropped in declaration order, so keep the order stable unless
/// the teardown sequence is meant to change.
#[derive(Debug)]
pub struct Limiter<K, P, T>
where
    K: Key,
    P: Priority,
    T: TaskResult,
{
    /// Heap of pending tasks; handles to it are given to the ingress and to
    /// each worker.
    tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>>,
    /// Entry point that newly queued tasks are sent to; also spawns workers.
    ingress: Ingress<K, P, T>,
    /// Currently held worker handles; resized by `set_concurrent_tasks`.
    workers: Mutex<Vec<Worker>>,
    /// Block settings (a default plus per-key entries), shared with workers.
    blocks: Arc<RwLock<Blocks<K>>>,
    /// Interval settings (a default plus per-key entries), shared with workers.
    intervals: Arc<RwLock<Intervals<K>>>,
}
/// Identity conversion, so that APIs written against `impl AsRef<Limiter<..>>`
/// accept a plain limiter reference as well.
impl<K, P, T> AsRef<Limiter<K, P, T>> for Limiter<K, P, T>
where
    K: Key,
    P: Priority,
    T: TaskResult,
{
    /// Returns `self` unchanged.
    fn as_ref(&self) -> &Self {
        self
    }
}
/// Constructors.
///
/// The impl target is `Limiter<String, P, T>`, but each constructor carries
/// its own key parameter `K` and returns a `Limiter<K, P, T>`; the `String`
/// in the header therefore does not restrict the key type of the limiter
/// that is built (callers write e.g. `Limiter::new::<String>(1)`).
impl<P, T> Limiter<String, P, T>
where
    P: Priority,
    T: TaskResult,
{
    /// Creates a limiter with `concurrent_tasks` workers and default
    /// [`Blocks`] and [`Intervals`] settings.
    pub fn new<K: Key>(concurrent_tasks: usize) -> Limiter<K, P, T> {
        let blocks: Blocks<K> = Default::default();
        let intervals: Intervals<K> = Default::default();
        Self::new_with(concurrent_tasks, blocks, intervals)
    }

    /// Creates a limiter with `concurrent_tasks` workers and the given
    /// initial `blocks` and `intervals` settings.
    ///
    /// The ingress is spawned first, then one worker per requested slot; each
    /// worker receives its own handle to the task heap and to both settings.
    /// Passing `0` spawns no workers at all.
    pub fn new_with<K: Key>(
        concurrent_tasks: usize,
        blocks: Blocks<K>,
        intervals: Intervals<K>,
    ) -> Limiter<K, P, T> {
        let tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>> =
            Arc::new(Mutex::new(BinaryHeap::new()));
        let blocks = Arc::new(RwLock::new(blocks));
        let intervals = Arc::new(RwLock::new(intervals));

        let ingress = Ingress::spawn(Arc::clone(&tasks));

        let mut spawned = Vec::with_capacity(concurrent_tasks);
        for _ in 0..concurrent_tasks {
            spawned.push(ingress.spawn_worker(
                Arc::clone(&tasks),
                Arc::clone(&blocks),
                Arc::clone(&intervals),
            ));
        }

        Limiter {
            tasks,
            ingress,
            workers: Mutex::new(spawned),
            blocks,
            intervals,
        }
    }
}
/// Runtime configuration and task submission.
///
/// The block/interval accessors are thin wrappers: each takes the matching
/// `RwLock` (read for getters, write for setters) and forwards to the method
/// of the same purpose on `Blocks` / `Intervals`. The exact semantics of those
/// methods are defined there, not here.
impl<K, P, T> Limiter<K, P, T>
where
    K: Key,
    P: Priority,
    T: TaskResult,
{
    /// Returns what `Blocks::get_default` reports, read under the read lock.
    ///
    /// Presumably the time left on the default block, `None` when none is
    /// set — confirm against `Blocks`.
    pub async fn get_default_block_duration(&self) -> Option<Duration> {
        let blocks = self.blocks.read().await;
        blocks.get_default()
    }

    /// Returns what `Blocks::get_by_key` reports for `key`, read under the
    /// read lock.
    pub async fn get_block_duration_by_key(&self, key: &K) -> Option<Duration> {
        let blocks = self.blocks.read().await;
        blocks.get_by_key(key)
    }

    /// Forwards `instant` to `Blocks::set_default_at_least` under the write
    /// lock.
    ///
    /// Judging by the name this only ever extends the default block — confirm
    /// against `Blocks`.
    pub async fn set_default_block_until_at_least(&self, instant: Instant) {
        let mut blocks = self.blocks.write().await;
        blocks.set_default_at_least(instant);
    }

    /// Forwards `instant` and `key` to `Blocks::set_at_least_by_key` under the
    /// write lock.
    pub async fn set_block_by_key_until_at_least(&self, instant: Instant, key: K) {
        let mut blocks = self.blocks.write().await;
        blocks.set_at_least_by_key(instant, key);
    }

    /// Forwards `instant` to `Blocks::set_default` under the write lock.
    ///
    /// Presumably `None` clears the default block — confirm against `Blocks`.
    pub async fn set_default_block_until(&self, instant: Option<Instant>) {
        let mut blocks = self.blocks.write().await;
        blocks.set_default(instant);
    }

    /// Forwards `instant` and `key` to `Blocks::set_by_key` under the write
    /// lock.
    pub async fn set_block_by_key_until(&self, instant: Option<Instant>, key: K) {
        let mut blocks = self.blocks.write().await;
        blocks.set_by_key(instant, key);
    }

    /// Forwards `interval` to `Intervals::set_default_at_least` under the
    /// write lock.
    pub async fn set_default_interval_at_least(&self, interval: Duration) {
        let mut intervals = self.intervals.write().await;
        intervals.set_default_at_least(interval);
    }

    /// Forwards `interval` and `key` to `Intervals::set_at_least_by_key` under
    /// the write lock.
    pub async fn set_interval_by_key_at_least(&self, interval: Duration, key: K) {
        let mut intervals = self.intervals.write().await;
        intervals.set_at_least_by_key(interval, key);
    }

    /// Forwards `interval` to `Intervals::set_default` under the write lock.
    pub async fn set_default_interval(&self, interval: Option<Duration>) {
        let mut intervals = self.intervals.write().await;
        intervals.set_default(interval);
    }

    /// Forwards `interval` and `key` to `Intervals::set_by_key` under the
    /// write lock.
    pub async fn set_interval_by_key(&self, interval: Option<Duration>, key: K) {
        let mut intervals = self.intervals.write().await;
        intervals.set_by_key(interval, key);
    }

    /// Resizes the worker pool to exactly `concurrent_tasks` workers.
    ///
    /// Growing spawns the missing workers through the ingress; shrinking drops
    /// the surplus `Worker` handles from the end of the list. What dropping a
    /// handle does to an in-flight task is decided by `Worker` — confirm there.
    pub async fn set_concurrent_tasks(&self, concurrent_tasks: usize) {
        let mut workers = self.workers.lock().await;
        let running = workers.len();

        if concurrent_tasks < running {
            workers.truncate(concurrent_tasks);
        } else {
            // The range is empty when the pool already has the requested size.
            workers.extend((running..concurrent_tasks).map(|_| {
                self.ingress.spawn_worker(
                    self.tasks.clone(),
                    self.blocks.clone(),
                    self.intervals.clone(),
                )
            }));
        }
    }

    /// Submits `job` with the given `priority` and no key.
    ///
    /// Awaiting this method only waits until the ingress has accepted the
    /// task; the returned future is what resolves to the job's output.
    ///
    /// # Panics
    ///
    /// The returned future panics with `"reply_sender should not drop"` if the
    /// reply channel is closed without a result, e.g. when the job itself
    /// panics.
    pub async fn queue<J: Future<Output = T> + Send + 'static>(
        &self,
        job: J,
        priority: P,
    ) -> BoxFuture<T> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let task = Task::new(job, priority, reply_tx);
        self.ingress.send(task).await;
        Box::pin(async move { reply_rx.await.expect("reply_sender should not drop") })
    }

    /// Submits `job` with the given `priority`, associated with `key`.
    ///
    /// Awaiting this method only waits until the ingress has accepted the
    /// task; the returned future is what resolves to the job's output.
    ///
    /// # Panics
    ///
    /// The returned future panics with `"reply_sender should not drop"` if the
    /// reply channel is closed without a result, e.g. when the job itself
    /// panics.
    pub async fn queue_by_key<J: Future<Output = T> + Send + 'static>(
        &self,
        job: J,
        priority: P,
        key: K,
    ) -> BoxFuture<T> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let task = Task::new_with_key(job, priority, reply_tx, key);
        self.ingress.send(task).await;
        Box::pin(async move { reply_rx.await.expect("reply_sender should not drop") })
    }
}
#[cfg(test)]
mod tests {
    use crate::limiter::builder::LimiterBuilder;

    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    // Queues nine jobs of mixed priority before any worker exists, then turns
    // workers on. Results must come back in submission order, while the jobs
    // themselves must have executed in priority order.
    #[tokio::test]
    async fn it_should_work() {
        use Prio::*;

        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
        enum Prio {
            Low,
            Mid,
            High,
        }

        // (priority, value returned by the job), in submission order.
        let submissions = [
            (High, 1),
            (Mid, 2),
            (Low, 3),
            (Low, 4),
            (Mid, 5),
            (High, 6),
            (Mid, 7),
            (Low, 8),
            (High, 9),
        ];

        // Presumably the `0` is the initial worker count, so nothing runs
        // until `set_concurrent_tasks` is called below.
        let limiter = LimiterBuilder::new::<String>(0).build();
        let acc: Arc<Mutex<Vec<Prio>>> = Default::default();

        let mut futures = Vec::with_capacity(submissions.len());
        for (prio, value) in submissions {
            let results = acc.clone();
            let job = async move {
                results.lock().await.push(prio);
                value
            };
            futures.push(limiter.queue(job, prio).await);
        }

        limiter
            .set_default_interval(Some(Duration::from_millis(100)))
            .await;
        limiter.set_concurrent_tasks(2).await;

        let returned = join_all(futures).await;
        let executed = acc.lock().await.clone();

        assert_eq!(returned, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
        assert_eq!(executed, [High, High, High, Mid, Mid, Mid, Low, Low, Low]);
    }

    // A job that panics never sends its reply, so awaiting the reply future
    // panics with the limiter's own message.
    #[tokio::test]
    #[should_panic(expected = "reply_sender should not drop")]
    async fn panic_in_task_causes_cascading_panic() {
        let limiter = Limiter::new::<String>(1);

        let reply = limiter
            .queue(
                async {
                    panic!("Intentional panic in task");
                },
                1,
            )
            .await;

        reply.await;
    }

    // With zero workers a job can still be queued; it completes only once a
    // worker is added afterwards.
    #[tokio::test]
    async fn zero_concurrent_tasks_still_queues() {
        let pause = Duration::from_millis(10);
        let limiter = Limiter::new::<String>(0);
        let started = Instant::now();

        let reply = limiter.queue(async { 42 }, 1).await;
        tokio::time::sleep(pause).await;
        limiter.set_concurrent_tasks(1).await;

        assert_eq!(reply.await, 42);
        assert!(started.elapsed() >= Duration::from_millis(10));
    }

    // Ten 50 ms jobs start on a single worker; after 100 ms the pool grows to
    // five workers and everything must finish within 500 ms of `started`.
    #[tokio::test]
    async fn dynamic_worker_scaling() {
        let limiter = Limiter::new::<String>(1);
        let counter = Arc::new(Mutex::new(0));

        // The sends are created first and then awaited together.
        let sends: Vec<_> = (0..10)
            .map(|_| {
                let hits = counter.clone();
                limiter.queue(
                    async move {
                        *hits.lock().await += 1;
                        tokio::time::sleep(Duration::from_millis(50)).await;
                    },
                    1,
                )
            })
            .collect();
        let replies = join_all(sends).await;

        let started = Instant::now();
        tokio::time::sleep(Duration::from_millis(100)).await;
        limiter.set_concurrent_tasks(5).await;
        join_all(replies).await;
        let took = started.elapsed();

        assert_eq!(*counter.lock().await, 10);
        assert!(took < Duration::from_millis(500));
    }

    // Three concurrent getter calls must all see the block that was set and
    // finish within 10 ms.
    #[tokio::test]
    async fn read_locks_dont_block_reads() {
        let limiter: Limiter<String, i32, ()> = Limiter::new(1);
        let blocked_until = Instant::now() + Duration::from_secs(10);
        limiter.set_default_block_until(Some(blocked_until)).await;

        let started = Instant::now();
        let (first, second, third) = tokio::join!(
            limiter.get_default_block_duration(),
            limiter.get_default_block_duration(),
            limiter.get_default_block_duration(),
        );
        let took = started.elapsed();

        assert!(first.is_some());
        assert!(second.is_some());
        assert!(third.is_some());
        assert!(took < Duration::from_millis(10));
    }
}