async-priority-limiter 0.4.4

Throttles prioritised tasks by limiting the maximum number of concurrent tasks and the minimum time between tasks, with up to two levels of limits: a default level and a per-key level.
Documentation
pub mod builder;

use crate::{
    BoxFuture,
    blocks::Blocks,
    ingress::Ingress,
    intervals::Intervals,
    task::Task,
    traits::{Key, Priority, TaskResult},
    worker::Worker,
};

use std::{collections::BinaryHeap, sync::Arc, time::Duration};
use tokio::{
    sync::{Mutex, RwLock, oneshot},
    time::Instant,
};

/// Priority-ordered task limiter.
///
/// Submitted jobs are sent to an `Ingress`, which shares the `tasks` heap with a pool of
/// workers. Every worker is handed the shared block and interval settings.
/// NOTE(review): the ingress presumably pushes received tasks onto `tasks`, and workers
/// presumably honour `blocks`/`intervals` before running a task — confirm in `ingress`/`worker`.
#[derive(Debug)]
pub struct Limiter<K: Key, P: Priority, T: TaskResult> {
    // Pending tasks. `BinaryHeap` pops its greatest element first; the ordering comes from
    // `Task`'s `Ord` impl (presumably by priority — confirm in `task`).
    tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>>,
    // Entry point for submitted tasks; also used to spawn new workers.
    ingress: Ingress<K, P, T>,
    // The live worker pool; its length is the current concurrency limit.
    workers: Mutex<Vec<Worker>>,
    // Default and per-key block deadlines, shared (read-mostly) with every worker.
    blocks: Arc<RwLock<Blocks<K>>>,
    // Default and per-key minimum intervals between tasks, shared with every worker.
    intervals: Arc<RwLock<Intervals<K>>>,
}

/// Identity `AsRef`: a `Limiter` is trivially a reference to itself.
///
/// This lets code that is generic over `impl AsRef<Limiter<K, P, T>>` accept `&Limiter` as
/// well as wrappers such as `Arc<Limiter<K, P, T>>`.
impl<K: Key, P: Priority, T: TaskResult> AsRef<Limiter<K, P, T>> for Limiter<K, P, T> {
    fn as_ref(&self) -> &Self {
        // Nothing to project: the limiter itself is the referent.
        self
    }
}

impl<P: Priority, T: TaskResult> Limiter<String, P, T> {
    pub fn new<K: Key>(concurrent_tasks: usize) -> Limiter<K, P, T> {
        Limiter::new_with(concurrent_tasks, Default::default(), Default::default())
    }

    pub fn new_with<K: Key>(
        concurrent_tasks: usize,
        blocks: Blocks<K>,
        intervals: Intervals<K>,
    ) -> Limiter<K, P, T> {
        let tasks: Arc<Mutex<BinaryHeap<Task<K, P, T>>>> = Default::default();
        let blocks: Arc<RwLock<Blocks<K>>> = Arc::new(RwLock::new(blocks));
        let intervals: Arc<RwLock<Intervals<K>>> = Arc::new(RwLock::new(intervals));
        let ingress = Ingress::spawn(tasks.clone());
        let workers = Mutex::new(
            (0..concurrent_tasks)
                .map(|_| ingress.spawn_worker(tasks.clone(), blocks.clone(), intervals.clone()))
                .collect(),
        );

        Limiter {
            tasks,
            blocks,
            intervals,
            ingress,
            workers,
        }
    }
}

impl<K: Key, P: Priority, T: TaskResult> Limiter<K, P, T> {
    pub async fn get_default_block_duration(&self) -> Option<Duration> {
        self.blocks.read().await.get_default()
    }

    pub async fn get_block_duration_by_key(&self, key: &K) -> Option<Duration> {
        self.blocks.read().await.get_by_key(key)
    }

    pub async fn set_default_block_until_at_least(&self, instant: Instant) {
        self.blocks.write().await.set_default_at_least(instant);
    }

    pub async fn set_block_by_key_until_at_least(&self, instant: Instant, key: K) {
        self.blocks.write().await.set_at_least_by_key(instant, key);
    }

    pub async fn set_default_block_until(&self, instant: Option<Instant>) {
        self.blocks.write().await.set_default(instant);
    }

    pub async fn set_block_by_key_until(&self, instant: Option<Instant>, key: K) {
        self.blocks.write().await.set_by_key(instant, key);
    }

    pub async fn set_default_interval_at_least(&self, interval: Duration) {
        self.intervals.write().await.set_default_at_least(interval);
    }

    pub async fn set_interval_by_key_at_least(&self, interval: Duration, key: K) {
        self.intervals
            .write()
            .await
            .set_at_least_by_key(interval, key);
    }

    pub async fn set_default_interval(&self, interval: Option<Duration>) {
        self.intervals.write().await.set_default(interval);
    }

    pub async fn set_interval_by_key(&self, interval: Option<Duration>, key: K) {
        self.intervals.write().await.set_by_key(interval, key);
    }

    pub async fn set_concurrent_tasks(&self, concurrent_tasks: usize) {
        let mut guard = self.workers.lock().await;
        let len = guard.len();

        match len.cmp(&concurrent_tasks) {
            std::cmp::Ordering::Less => {
                for _ in len..concurrent_tasks {
                    guard.push(self.ingress.spawn_worker(
                        self.tasks.clone(),
                        self.blocks.clone(),
                        self.intervals.clone(),
                    ));
                }
            }
            std::cmp::Ordering::Equal => {}
            std::cmp::Ordering::Greater => {
                guard.drain(concurrent_tasks..);
            }
        }
    }

    pub async fn queue<J: Future<Output = T> + Send + 'static>(
        &self,
        job: J,
        priority: P,
    ) -> BoxFuture<T> {
        let (reply_sender, reply_receiver) = oneshot::channel();

        self.ingress
            .send(Task::new(job, priority, reply_sender))
            .await;

        Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
    }

    pub async fn queue_by_key<J: Future<Output = T> + Send + 'static>(
        &self,
        job: J,
        priority: P,
        key: K,
    ) -> BoxFuture<T> {
        let (reply_sender, reply_receiver) = oneshot::channel();

        self.ingress
            .send(Task::new_with_key(job, priority, reply_sender, key))
            .await;

        Box::pin(async move { reply_receiver.await.expect("reply_sender should not drop") })
    }
}

#[cfg(test)]
mod tests {
    use crate::limiter::builder::LimiterBuilder;

    use super::*;
    use futures::future::join_all;
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Tasks queued while no worker exists run in priority order once workers are added.
    /// Each reply future still resolves to its own job's output.
    #[tokio::test]
    async fn it_should_work() {
        use Prio::*;

        #[derive(Debug, Clone, Copy, PartialEq, PartialOrd, Eq, Ord)]
        enum Prio {
            Low,
            Mid,
            High,
        }

        let limiter = LimiterBuilder::new::<String>(0).build();

        let acc: Arc<Mutex<Vec<Prio>>> = Default::default();

        // (priority, value returned by the job), in submission order.
        let submissions = [
            (High, 1),
            (Mid, 2),
            (Low, 3),
            (Low, 4),
            (Mid, 5),
            (High, 6),
            (Mid, 7),
            (Low, 8),
            (High, 9),
        ];

        let mut futures = Vec::with_capacity(submissions.len());
        for (prio, value) in submissions {
            let results = acc.clone();
            let job = async move {
                results.lock().await.push(prio);
                value
            };
            futures.push(limiter.queue(job, prio).await);
        }

        limiter
            .set_default_interval(Some(Duration::from_millis(100)))
            .await;

        limiter.set_concurrent_tasks(2).await;

        let order = join_all(futures).await;
        let acc = acc.lock().await.clone();

        assert_eq!(order, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);

        assert_eq!(acc, [High, High, High, Mid, Mid, Mid, Low, Low, Low]);
    }

    #[tokio::test]
    #[should_panic(expected = "reply_sender should not drop")]
    async fn panic_in_task_causes_cascading_panic() {
        let limiter = Limiter::new::<String>(1);

        let future = limiter
            .queue(
                async {
                    panic!("Intentional panic in task");
                },
                1,
            )
            .await;

        // This should panic when the worker panics and drops the reply sender
        future.await;
    }

    #[tokio::test]
    async fn zero_concurrent_tasks_still_queues() {
        let limiter = Limiter::new::<String>(0);
        let ran = Arc::new(Mutex::new(false));

        let fut = limiter
            .queue(
                {
                    let ran = ran.clone();
                    async move {
                        *ran.lock().await = true;
                        42
                    }
                },
                1,
            )
            .await;

        // With zero workers, the task should queue but not execute yet
        tokio::time::sleep(Duration::from_millis(10)).await;
        assert!(!*ran.lock().await, "task ran before any worker existed");

        // Now add a worker
        limiter.set_concurrent_tasks(1).await;

        let result = fut.await;
        assert_eq!(result, 42);
        assert!(*ran.lock().await);
    }

    #[tokio::test]
    async fn dynamic_worker_scaling() {
        let limiter = Limiter::new::<String>(1);
        let counter = Arc::new(Mutex::new(0));

        let futures: Vec<_> = (0..10)
            .map(|_| {
                let counter = counter.clone();
                limiter.queue(
                    async move {
                        *counter.lock().await += 1;
                        tokio::time::sleep(Duration::from_millis(50)).await;
                    },
                    1,
                )
            })
            .collect();

        let futures = join_all(futures).await;

        // Start with 1 worker - should take ~500ms
        let start = Instant::now();

        // Scale up to 5 workers after 100ms
        tokio::time::sleep(Duration::from_millis(100)).await;
        limiter.set_concurrent_tasks(5).await;

        join_all(futures).await;
        let elapsed = start.elapsed();

        assert_eq!(*counter.lock().await, 10);
        // Should complete faster than 10 * 50ms due to scaling
        assert!(elapsed < Duration::from_millis(500));
    }

    #[tokio::test]
    async fn read_locks_dont_block_reads() {
        let limiter: Limiter<String, i32, ()> = Limiter::new(1);
        limiter
            .set_default_block_until(Some(Instant::now() + Duration::from_secs(10)))
            .await;

        // Multiple concurrent reads should not block each other
        let start = Instant::now();
        let (r1, r2, r3) = tokio::join!(
            limiter.get_default_block_duration(),
            limiter.get_default_block_duration(),
            limiter.get_default_block_duration(),
        );

        let elapsed = start.elapsed();

        assert!(r1.is_some());
        assert!(r2.is_some());
        assert!(r3.is_some());
        // Should complete nearly instantly, not sequentially blocked
        assert!(elapsed < Duration::from_millis(10));
    }
}