use crate::error::PoolError;
use crate::capacity_gate::CapacityGate;
use crate::task::ManagedTaskInternal;
use fibre::mpsc::{self, UnboundedAsyncReceiver, UnboundedAsyncSender, RecvError};
use std::fmt;
use std::sync::Arc;
use tokio_util::sync::CancellationToken;
/// A task wrapped together with the capacity permit it consumed on entry to
/// the queue. The permit rides along with the task so that queue capacity is
/// only returned to the gate when this message is dropped (i.e. after the
/// consumer has taken the task out — see `QueueConsumer::recv`).
pub(crate) struct QueueMessage<R: Send + 'static> {
    /// The task payload handed to the consumer.
    pub(crate) task: ManagedTaskInternal<R>,
    // Never read; held purely for its `Drop` impl, which releases one
    // capacity permit back to the shared gate.
    _permit: Permit,
}
impl<R: Send + 'static> fmt::Debug for QueueMessage<R> {
    /// Show only the task id; the future and permit are not debuggable.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("QueueMessage");
        out.field("task_id", &self.task.task_id);
        out.finish_non_exhaustive()
    }
}
/// RAII handle representing one unit of queue capacity taken from the shared
/// `CapacityGate`; returns the permit via `Drop` (see the impl below).
#[derive(Debug)]
pub(crate) struct Permit {
    // Shared gate this permit was drawn from; `release()` is called on drop.
    gate: Arc<CapacityGate>,
}
impl Drop for Permit {
    /// Hand the capacity unit back to the gate, unblocking a waiting sender.
    fn drop(&mut self) {
        self.gate.release();
    }
}
/// A bounded-admission task queue: an unbounded async channel whose effective
/// capacity is enforced by a `CapacityGate` rather than by the channel itself.
/// Constructed whole, then `split()` into producer and consumer halves.
#[derive(Debug)]
pub(crate) struct TaskQueue<R: Send + 'static> {
    tx: UnboundedAsyncSender<QueueMessage<R>>,
    rx: UnboundedAsyncReceiver<QueueMessage<R>>,
    // Shared with every producer clone; limits in-flight queued tasks.
    gate: Arc<CapacityGate>,
}
impl<R: Send + 'static> TaskQueue<R> {
    /// Build a queue whose admission is limited to `capacity` queued tasks.
    ///
    /// The channel itself is unbounded; the `CapacityGate` is what makes
    /// senders wait once `capacity` permits are outstanding.
    pub(crate) fn new(capacity: usize) -> Self {
        let gate = Arc::new(CapacityGate::new(capacity));
        let (tx, rx) = mpsc::unbounded_async();
        Self { tx, rx, gate }
    }

    /// Consume the queue and hand back its two halves: a cloneable producer
    /// (which carries the gate) and the single consumer.
    pub(crate) fn split(self) -> (QueueProducer<R>, QueueConsumer<R>) {
        let Self { tx, rx, gate } = self;
        let producer = QueueProducer { tx, gate };
        let consumer = QueueConsumer { rx };
        (producer, consumer)
    }
}
/// Sending half of a `TaskQueue`. Cloneable: every clone shares the same
/// channel sender and the same capacity gate.
#[derive(Clone)]
pub(crate) struct QueueProducer<R: Send + 'static> {
    tx: UnboundedAsyncSender<QueueMessage<R>>,
    // Acquired before each send; enforces the queue's effective capacity.
    gate: Arc<CapacityGate>,
}
/// Receiving half of a `TaskQueue`; produced once by `TaskQueue::split`.
#[derive(Debug)]
pub(crate) struct QueueConsumer<R: Send + 'static> {
    rx: UnboundedAsyncReceiver<QueueMessage<R>>,
}
impl<R: Send + 'static> fmt::Debug for QueueProducer<R> {
    /// Report queue depth and remaining gate permits for diagnostics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("QueueProducer");
        out.field("len", &self.len());
        out.field("gate_permits", &self.gate.get_permits());
        out.finish_non_exhaustive()
    }
}
impl<R: Send + 'static> QueueProducer<R> {
    /// Enqueue `task`, first waiting until the capacity gate grants a permit.
    ///
    /// # Errors
    /// - `PoolError::PoolShuttingDown` if `shutdown_token` is already (or
    ///   becomes) cancelled before a permit is acquired, or the channel is
    ///   already closed.
    /// - `PoolError::QueueSendChannelClosed` if the channel closes after the
    ///   permit was acquired.
    pub(crate) async fn send(
        &self,
        task: ManagedTaskInternal<R>,
        shutdown_token: &CancellationToken,
    ) -> Result<(), PoolError> {
        // Fast-path rejection: don't wait on the gate for a task that will
        // be refused anyway.
        if shutdown_token.is_cancelled() || self.tx.is_closed() {
            return Err(PoolError::PoolShuttingDown);
        }
        let temp_permit_guard;
        tokio::select! {
            // `biased` lets shutdown win when both branches are ready, so we
            // never consume a permit for a doomed task.
            biased;
            _ = shutdown_token.cancelled() => return Err(PoolError::PoolShuttingDown),
            guard = self.gate.acquire() => {
                temp_permit_guard = guard;
            },
        };
        let message = QueueMessage {
            task,
            _permit: Permit {
                gate: self.gate.clone(),
            },
        };
        // Transfer ownership of the single acquired permit from the
        // temporary RAII guard to the `Permit` inside `message` *before*
        // awaiting the send. Previously the guard was only forgotten on a
        // successful send, so a failed send — or this future being dropped
        // mid-await — dropped both the guard and the message's `Permit`,
        // releasing the same permit twice and inflating gate capacity.
        // From here, exactly one release-on-drop owner exists.
        std::mem::forget(temp_permit_guard);
        if self.tx.send(message).await.is_ok() {
            Ok(())
        } else {
            // Dropping the rejected message releases the permit exactly once.
            Err(PoolError::QueueSendChannelClosed)
        }
    }

    /// Close the sending side; buffered messages remain receivable, after
    /// which the consumer sees a disconnect error.
    pub(crate) fn close(&self) {
        let _ = self.tx.close();
    }

    /// Whether the underlying channel has been closed.
    pub(crate) fn is_closed(&self) -> bool {
        self.tx.is_closed()
    }

    /// Number of messages currently buffered in the channel.
    pub(crate) fn len(&self) -> usize {
        self.tx.len()
    }
}
impl<R: Send + 'static> QueueConsumer<R> {
    /// Await the next queued task. Unwrapping the message drops its `Permit`,
    /// returning one unit of capacity to the gate.
    pub(crate) async fn recv(&self) -> Result<ManagedTaskInternal<R>, RecvError> {
        self.rx.recv().await.map(|message| message.task)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::task::TaskToExecute;
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;
    use fibre::oneshot::oneshot;

    /// Build a minimal task whose future immediately resolves to "done".
    /// The oneshot receiver is intentionally discarded — results are not
    /// exercised by these queue tests.
    fn dummy_task(id: u64) -> ManagedTaskInternal<String> {
        let future: TaskToExecute<String> = Box::pin(async move { "done".to_string() });
        let (tx, _) = oneshot();
        ManagedTaskInternal {
            task_id: id,
            labels: HashSet::new(),
            future,
            token: CancellationToken::new(),
            result_sender: Some(tx),
        }
    }

    /// A send consumes one gate permit; receiving the task returns it.
    #[tokio::test]
    async fn test_queue_send_recv() {
        let queue = TaskQueue::<String>::new(5);
        let (producer, consumer) = queue.split();
        let shutdown_token = CancellationToken::new();
        assert_eq!(producer.gate.get_permits(), 5);
        producer.send(dummy_task(1), &shutdown_token).await.unwrap();
        assert_eq!(producer.gate.get_permits(), 4);
        let received_task = consumer.recv().await.unwrap();
        assert_eq!(received_task.task_id, 1);
        // The message's Permit dropped inside recv(), so capacity is back.
        assert_eq!(producer.gate.get_permits(), 5);
    }

    /// With capacity 1, a second send must block until the first task is
    /// drained by the consumer.
    #[tokio::test]
    async fn test_queue_capacity_blocks_send() {
        let queue = TaskQueue::<String>::new(1);
        let (producer, consumer) = queue.split();
        let shutdown_token = CancellationToken::new();
        producer.send(dummy_task(1), &shutdown_token).await.unwrap();
        assert_eq!(producer.gate.get_permits(), 0);
        let send_future = producer.send(dummy_task(2), &shutdown_token);
        tokio::pin!(send_future);
        // Race the pending send against a short sleep: the sleep must win,
        // proving the send is blocked on the gate.
        tokio::select! {
            _ = &mut send_future => {
                panic!("Send should have blocked because the queue is full.");
            },
            _ = tokio::time::sleep(Duration::from_millis(50)) => {
            }
        }
        assert_eq!(producer.gate.get_permits(), 0);
        let received_task = consumer.recv().await.unwrap();
        assert_eq!(received_task.task_id, 1);
        // Draining task 1 freed a permit...
        assert_eq!(producer.gate.get_permits(), 1);
        // ...which the pinned send should now claim promptly.
        tokio::time::timeout(Duration::from_millis(50), send_future)
            .await
            .expect("Send did not complete after queue was drained.")
            .unwrap();
        assert_eq!(producer.gate.get_permits(), 0);
    }

    /// 20 concurrent sends through a capacity-4 gate all complete, and all
    /// permits are returned once the consumer drains everything.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_queue_concurrent_sends() {
        let queue = TaskQueue::<String>::new(4);
        let (producer, consumer) = queue.split();
        let shutdown_token = CancellationToken::new();
        let num_tasks: u64 = 20;
        let received_count = Arc::new(AtomicUsize::new(0));
        let producer_handle = {
            let producer = producer.clone();
            tokio::spawn(async move {
                // One spawned task per send so they contend on the gate.
                let mut handles = Vec::new();
                for i in 0..num_tasks {
                    let p = producer.clone();
                    let s = shutdown_token.clone();
                    handles.push(tokio::spawn(async move {
                        p.send(dummy_task(i), &s).await.unwrap();
                    }));
                }
                for handle in handles {
                    handle.await.unwrap();
                }
            })
        };
        let consumer_handle = {
            let received_count = received_count.clone();
            tokio::spawn(async move {
                for _ in 0..num_tasks {
                    if consumer.recv().await.is_ok() {
                        received_count.fetch_add(1, Ordering::SeqCst);
                    }
                }
            })
        };
        producer_handle.await.unwrap();
        consumer_handle.await.unwrap();
        assert_eq!(received_count.load(Ordering::SeqCst), num_tasks as usize);
        // All permits back after the queue is fully drained.
        assert_eq!(producer.gate.get_permits(), 4);
    }

    /// A cancelled shutdown token makes send fail fast without touching the
    /// permit still held by the already-queued task.
    #[tokio::test]
    async fn test_send_respects_shutdown_token() {
        let queue = TaskQueue::<String>::new(1);
        let (producer, _consumer) = queue.split();
        let shutdown_token = CancellationToken::new();
        producer.send(dummy_task(1), &shutdown_token).await.unwrap();
        shutdown_token.cancel();
        let result = producer.send(dummy_task(2), &shutdown_token).await;
        assert!(matches!(result, Err(PoolError::PoolShuttingDown)));
        assert_eq!(
            producer.gate.get_permits(),
            0,
            "Permit should still be held by first task"
        );
    }

    /// Closing the producer lets buffered messages drain, after which the
    /// consumer observes a disconnect.
    #[tokio::test]
    async fn test_close_sender_stops_consumer() {
        let queue = TaskQueue::<String>::new(2);
        let (producer, consumer) = queue.split();
        let shutdown_token = CancellationToken::new();
        producer.send(dummy_task(1), &shutdown_token).await.unwrap();
        producer.close();
        assert_eq!(consumer.recv().await.unwrap().task_id, 1);
        assert_eq!(producer.gate.get_permits(), 2);
        let result = consumer.recv().await;
        assert!(matches!(result, Err(RecvError::Disconnected)));
    }
}