use tokio::sync::oneshot;
use uuid::Uuid;
use crate::{types::internal::SolveTask, Order, SingleOrderQuote, SolveError};
/// Settings used when building a `TaskQueue`.
#[derive(Clone, Debug)]
pub struct TaskQueueConfig {
    /// Capacity handed to `async_channel::bounded` when the queue is created.
    pub capacity: usize,
}

impl Default for TaskQueueConfig {
    /// Returns a configuration with a capacity of 1,000 tasks.
    fn default() -> Self {
        TaskQueueConfig { capacity: 1_000 }
    }
}
/// Cloneable producer side of the solve task queue.
///
/// All clones share the same underlying channel, so several callers can
/// submit orders at once.
#[derive(Clone)]
pub struct TaskQueueHandle {
    /// Channel into which new `SolveTask`s are pushed for workers to receive.
    sender: async_channel::Sender<SolveTask>,
}

impl TaskQueueHandle {
    /// Wraps an existing channel sender in a handle.
    pub fn from_sender(sender: async_channel::Sender<SolveTask>) -> Self {
        TaskQueueHandle { sender }
    }

    /// Submits `order` as a new task (with a fresh random UUID) and waits for
    /// the worker's reply.
    ///
    /// # Errors
    ///
    /// - `SolveError::QueueFull` if the task could not be sent.
    ///   NOTE(review): `async_channel::Sender::send` waits while the channel is
    ///   full and only fails once the channel is closed (all receivers dropped),
    ///   so this variant is currently reported for a *closed* queue, not a full one.
    /// - `SolveError::Internal` if the task is dropped without a reply being sent.
    /// - Whatever error the worker sends back through the reply channel.
    pub async fn enqueue(&self, order: Order) -> Result<SingleOrderQuote, SolveError> {
        let (reply_tx, reply_rx) = oneshot::channel();
        let solve_task = SolveTask::new(Uuid::new_v4(), order, reply_tx);

        if self.sender.send(solve_task).await.is_err() {
            return Err(SolveError::QueueFull);
        }

        // The worker's reply is itself a `Result`; forward it as-is.
        match reply_rx.await {
            Ok(outcome) => outcome,
            Err(_) => Err(SolveError::Internal(
                "worker dropped response channel".to_string(),
            )),
        }
    }

    /// Number of tasks currently buffered in the channel (test-only).
    #[cfg(test)]
    pub fn approximate_depth(&self) -> usize {
        self.sender.len()
    }

    /// Whether the channel is at capacity (test-only).
    #[cfg(test)]
    pub fn is_full(&self) -> bool {
        self.sender.is_full()
    }
}
/// Bounded queue of `SolveTask`s, holding the receiving end together with a
/// producer `TaskQueueHandle`.
pub struct TaskQueue {
    /// Consumer side of the channel, handed out by [`TaskQueue::split`].
    receiver: async_channel::Receiver<SolveTask>,
    /// Producer side of the same channel.
    handle: TaskQueueHandle,
}

impl TaskQueue {
    /// Creates a bounded queue sized by `config.capacity`.
    ///
    /// A capacity of `0` is treated as `1`. `async_channel::bounded` asserts a
    /// non-zero capacity and would otherwise panic on a zero config value.
    pub fn new(config: TaskQueueConfig) -> Self {
        // Clamp so a zero capacity cannot trigger the panic inside `bounded`.
        let capacity = config.capacity.max(1);
        let (sender, receiver) = async_channel::bounded(capacity);
        let handle = TaskQueueHandle { sender };
        Self { receiver, handle }
    }

    /// Consumes the queue, returning the producer handle and the receiver.
    pub fn split(self) -> (TaskQueueHandle, async_channel::Receiver<SolveTask>) {
        (self.handle, self.receiver)
    }

    /// Returns a clone of the producer handle (test-only).
    #[cfg(test)]
    pub fn handle(&self) -> TaskQueueHandle {
        self.handle.clone()
    }

    /// Consumes the queue and returns only the receiver (test-only).
    #[cfg(test)]
    pub fn into_receiver(self) -> async_channel::Receiver<SolveTask> {
        self.receiver
    }
}
// Unit tests for TaskQueueConfig, TaskQueue, TaskQueueHandle and SolveTask.
#[cfg(test)]
mod tests {
use num_bigint::BigUint;
use rstest::rstest;
use tycho_simulation::tycho_core::{models::Address, Bytes};
use super::*;
use crate::{BlockInfo, Order, OrderQuote, OrderSide, QuoteStatus, SingleOrderQuote};
// Builds a 20-byte address with every byte set to `byte`.
fn make_address(byte: u8) -> Address {
Address::from([byte; 20])
}
// Builds a fixed Sell order with id "test-order".
// The meaning of each positional argument follows `Order::new` (not visible here).
fn make_order() -> Order {
Order::new(
make_address(0x01),
make_address(0x02),
BigUint::from(1000u64),
OrderSide::Sell,
make_address(0xAA),
)
.with_id("test-order".to_string())
}
// Builds a successful quote for "test-order".
// The trailing `5` is read back as `solve_time_ms()` in test_enqueue_and_receive_response.
fn make_single_quote() -> SingleOrderQuote {
SingleOrderQuote::new(
OrderQuote::new(
"test-order".to_string(),
QuoteStatus::Success,
BigUint::from(1000u64),
BigUint::from(990u64),
BigUint::from(100_000u64),
BigUint::from(990u64),
BlockInfo::new(1, "0x123".to_string(), 1000),
"test".to_string(),
Bytes::from(make_address(0xAA).as_ref()),
Bytes::from(make_address(0xAA).as_ref()),
),
5,
)
}
#[test]
fn test_config_default() {
let config = TaskQueueConfig::default();
assert_eq!(config.capacity, 1000);
}
#[rstest]
#[case::small(1)]
#[case::medium(100)]
#[case::large(10_000)]
fn test_config_custom_capacity(#[case] capacity: usize) {
let config = TaskQueueConfig { capacity };
assert_eq!(config.capacity, capacity);
}
// A freshly created queue is empty and not full.
#[rstest]
#[case::capacity_1(1)]
#[case::capacity_10(10)]
#[case::capacity_100(100)]
fn test_queue_creation(#[case] capacity: usize) {
let config = TaskQueueConfig { capacity };
let queue = TaskQueue::new(config);
let handle = queue.handle();
assert!(!handle.is_full());
assert_eq!(handle.approximate_depth(), 0);
}
// Clones share the same channel, so they report identical depth and fullness.
#[test]
fn test_queue_handle_is_cloneable() {
let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
let handle1 = queue.handle();
let handle2 = handle1.clone();
assert_eq!(handle1.approximate_depth(), handle2.approximate_depth());
assert_eq!(handle1.is_full(), handle2.is_full());
}
// Smoke test: a handle can be taken before the queue is consumed into its receiver.
#[test]
fn test_queue_into_receiver() {
let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
let _handle = queue.handle();
let _receiver = queue.into_receiver();
}
#[test]
fn test_queue_split() {
let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
let (handle, _receiver) = queue.split();
assert!(!handle.is_full());
assert_eq!(handle.approximate_depth(), 0);
}
// Happy path: a spawned worker receives the task and replies with a quote.
#[tokio::test]
async fn test_enqueue_and_receive_response() {
let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
let handle = queue.handle();
let receiver = queue.into_receiver();
let worker = tokio::spawn(async move {
let task = receiver
.recv()
.await
.expect("should receive task");
assert_eq!(task.order().id(), "test-order");
task.respond(Ok(make_single_quote()));
});
let result = handle.enqueue(make_order()).await;
worker
.await
.expect("worker should complete");
let quote = result.expect("should get quote");
assert_eq!(quote.solve_time_ms(), 5);
}
// An error sent by the worker is passed back unchanged to the enqueue caller.
#[tokio::test]
async fn test_enqueue_receives_error_response() {
let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
let handle = queue.handle();
let receiver = queue.into_receiver();
let worker = tokio::spawn(async move {
let task = receiver
.recv()
.await
.expect("should receive task");
task.respond(Err(SolveError::NoRouteFound { order_id: "test".to_string() }));
});
let result = handle.enqueue(make_order()).await;
worker
.await
.expect("worker should complete");
assert!(matches!(result, Err(SolveError::NoRouteFound { .. })));
}
// Dropping the task without responding drops its oneshot sender, so enqueue
// reports SolveError::Internal.
#[tokio::test]
async fn test_enqueue_error_when_receiver_dropped() {
let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
let handle = queue.handle();
let receiver = queue.into_receiver();
let worker = tokio::spawn(async move {
let task = receiver
.recv()
.await
.expect("should receive task");
drop(task); });
let result = handle.enqueue(make_order()).await;
worker
.await
.expect("worker should complete");
assert!(matches!(result, Err(SolveError::Internal(_))));
}
// NOTE(review): despite the name, this test exercises the *closed* channel case
// (receiver dropped), not a full queue. `enqueue` maps any send failure to
// SolveError::QueueFull.
#[tokio::test]
async fn test_enqueue_queue_full_error() {
let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
let handle = queue.handle();
let receiver = queue.into_receiver();
drop(receiver);
let result = handle.enqueue(make_order()).await;
assert!(matches!(result, Err(SolveError::QueueFull)));
}
// Tasks are pushed through `handle.sender` directly so the test does not block
// waiting for a reply. `_receiver` is bound (not `_`) so the channel stays open.
#[tokio::test]
async fn test_approximate_depth_increases_with_pending_tasks() {
let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
let handle = queue.handle();
let _receiver = queue.into_receiver();
let (response_tx, _response_rx) = oneshot::channel();
let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
handle
.sender
.send(task)
.await
.expect("should send");
assert_eq!(handle.approximate_depth(), 1);
let (response_tx2, _response_rx2) = oneshot::channel();
let task2 = SolveTask::new(Uuid::new_v4(), make_order(), response_tx2);
handle
.sender
.send(task2)
.await
.expect("should send");
assert_eq!(handle.approximate_depth(), 2);
}
// Filling the channel with exactly `capacity` tasks makes it report full.
#[rstest]
#[case::capacity_1(1)]
#[case::capacity_5(5)]
#[case::capacity_10(10)]
#[tokio::test]
async fn test_is_full_when_at_capacity(#[case] capacity: usize) {
let queue = TaskQueue::new(TaskQueueConfig { capacity });
let handle = queue.handle();
let _receiver = queue.into_receiver();
for _ in 0..capacity {
let (response_tx, _response_rx) = oneshot::channel();
let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
handle
.sender
.send(task)
.await
.expect("should send");
}
assert!(handle.is_full());
assert_eq!(handle.approximate_depth(), capacity);
}
// Receiving one task from a full capacity-2 channel frees a slot.
#[tokio::test]
async fn test_is_full_becomes_false_after_task_consumed() {
let queue = TaskQueue::new(TaskQueueConfig { capacity: 2 });
let handle = queue.handle();
let receiver = queue.into_receiver();
let (tx1, _rx1) = oneshot::channel();
let (tx2, _rx2) = oneshot::channel();
handle
.sender
.send(SolveTask::new(Uuid::new_v4(), make_order(), tx1))
.await
.unwrap();
handle
.sender
.send(SolveTask::new(Uuid::new_v4(), make_order(), tx2))
.await
.unwrap();
assert!(handle.is_full());
let _task = receiver.recv().await.unwrap();
assert!(!handle.is_full());
assert_eq!(handle.approximate_depth(), 1);
}
// Two handles enqueue concurrently; a single worker answers both tasks.
#[tokio::test]
async fn test_multiple_handles_can_enqueue_concurrently() {
let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
let handle1 = queue.handle();
let handle2 = queue.handle();
let receiver = queue.into_receiver();
let worker = tokio::spawn(async move {
for _ in 0..2 {
let task = receiver
.recv()
.await
.expect("should receive task");
task.respond(Ok(make_single_quote()));
}
});
let (result1, result2) =
tokio::join!(handle1.enqueue(make_order()), handle2.enqueue(make_order()),);
worker
.await
.expect("worker should complete");
assert!(result1.is_ok());
assert!(result2.is_ok());
}
// Two sequential enqueues produce tasks with different UUIDs. The collector
// replies to each task so every `enqueue` call can finish.
#[tokio::test]
async fn test_task_id_is_unique_per_enqueue() {
let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
let handle = queue.handle();
let receiver = queue.into_receiver();
let collector = tokio::spawn(async move {
let task1 = receiver.recv().await.unwrap();
let id1 = task1.id();
task1.respond(Ok(make_single_quote()));
let task2 = receiver.recv().await.unwrap();
let id2 = task2.id();
task2.respond(Ok(make_single_quote()));
(id1, id2)
});
let _ = handle.enqueue(make_order()).await;
let _ = handle.enqueue(make_order()).await;
let (id1, id2): (Uuid, Uuid) = collector
.await
.expect("collector should complete");
assert_ne!(id1, id2, "Task IDs should be unique");
}
// Synchronous test, so a blocking sleep is fine here. Presumably wait_time()
// measures elapsed time since SolveTask creation; confirm against SolveTask.
#[test]
fn test_solve_task_wait_time_increases() {
let (response_tx, _response_rx) = oneshot::channel();
let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
let wait1 = task.wait_time();
std::thread::sleep(std::time::Duration::from_millis(10));
let wait2 = task.wait_time();
assert!(wait2 > wait1);
}
// `respond(Ok(..))` delivers the quote on the paired oneshot receiver.
#[tokio::test]
async fn test_solve_task_respond_delivers_result() {
let (response_tx, response_rx) = oneshot::channel();
let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
task.respond(Ok(make_single_quote()));
let result = response_rx
.await
.expect("should receive response");
assert!(result.is_ok());
}
// `respond(Err(..))` delivers the error on the paired oneshot receiver.
#[tokio::test]
async fn test_solve_task_respond_delivers_error() {
let (response_tx, response_rx) = oneshot::channel();
let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
task.respond(Err(SolveError::Timeout { elapsed_ms: 100 }));
let result = response_rx
.await
.expect("should receive response");
assert!(matches!(result, Err(SolveError::Timeout { elapsed_ms: 100 })));
}
}