1use tokio::sync::oneshot;
8use uuid::Uuid;
9
10use crate::{types::internal::SolveTask, Order, SingleOrderQuote, SolveError};
11
/// Configuration for a [`TaskQueue`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskQueueConfig {
    /// Maximum number of tasks that may wait in the queue at the same time.
    ///
    /// This value is handed to `async_channel::bounded` by [`TaskQueue::new`], which rejects
    /// a capacity of zero, so it must be at least 1.
    pub capacity: usize,
}

impl TaskQueueConfig {
    /// Capacity used by [`TaskQueueConfig::default`].
    pub const DEFAULT_CAPACITY: usize = 1000;
}

impl Default for TaskQueueConfig {
    /// Returns a configuration with [`TaskQueueConfig::DEFAULT_CAPACITY`] slots.
    fn default() -> Self {
        Self { capacity: Self::DEFAULT_CAPACITY }
    }
}
24
25#[derive(Clone)]
29pub struct TaskQueueHandle {
30 sender: async_channel::Sender<SolveTask>,
31}
32
33impl TaskQueueHandle {
34 pub async fn enqueue(&self, order: Order) -> Result<SingleOrderQuote, SolveError> {
38 let (response_tx, response_rx) = oneshot::channel();
40
41 let task_id = Uuid::new_v4();
43
44 let task = SolveTask::new(task_id, order, response_tx);
46
47 self.sender
49 .send(task)
50 .await
51 .map_err(|_| SolveError::QueueFull)?;
52
53 response_rx
55 .await
56 .map_err(|_| SolveError::Internal("worker dropped response channel".to_string()))?
57 }
58
59 #[cfg(test)]
63 pub fn approximate_depth(&self) -> usize {
64 self.sender.len()
65 }
66
67 #[cfg(test)]
69 pub fn is_full(&self) -> bool {
70 self.sender.is_full()
71 }
72
73 pub fn from_sender(sender: async_channel::Sender<SolveTask>) -> Self {
77 Self { sender }
78 }
79}
80
81pub struct TaskQueue {
85 receiver: async_channel::Receiver<SolveTask>,
86 handle: TaskQueueHandle,
87}
88
89impl TaskQueue {
90 pub fn new(config: TaskQueueConfig) -> Self {
92 let (sender, receiver) = async_channel::bounded(config.capacity);
93 let handle = TaskQueueHandle { sender };
94
95 Self { receiver, handle }
96 }
97
98 pub fn split(self) -> (TaskQueueHandle, async_channel::Receiver<SolveTask>) {
100 (self.handle, self.receiver)
101 }
102
103 #[cfg(test)]
105 pub fn handle(&self) -> TaskQueueHandle {
106 self.handle.clone()
107 }
108
109 #[cfg(test)]
113 pub fn into_receiver(self) -> async_channel::Receiver<SolveTask> {
114 self.receiver
115 }
116}
117
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;
    use rstest::rstest;
    use tycho_simulation::tycho_core::{models::Address, Bytes};

    use super::*;
    use crate::{BlockInfo, Order, OrderQuote, OrderSide, QuoteStatus, SingleOrderQuote};

    /// Payload a worker sends back through a task's response channel.
    type TaskResponse = Result<SingleOrderQuote, SolveError>;

    /// Builds a 20-byte address with every byte set to `byte`.
    fn make_address(byte: u8) -> Address {
        Address::from([byte; 20])
    }

    /// Builds a sell order whose id is `"test-order"`.
    fn make_order() -> Order {
        Order::new(
            make_address(0x01),
            make_address(0x02),
            BigUint::from(1000u64),
            OrderSide::Sell,
            make_address(0xAA),
        )
        .with_id("test-order".to_string())
    }

    /// Builds a successful quote for `"test-order"` whose `solve_time_ms()` is 5.
    fn make_single_quote() -> SingleOrderQuote {
        SingleOrderQuote::new(
            OrderQuote::new(
                "test-order".to_string(),
                QuoteStatus::Success,
                BigUint::from(1000u64),
                BigUint::from(990u64),
                BigUint::from(100_000u64),
                BigUint::from(990u64),
                BlockInfo::new(1, "0x123".to_string(), 1000),
                "test".to_string(),
                Bytes::from(make_address(0xAA).as_ref()),
                Bytes::from(make_address(0xAA).as_ref()),
            ),
            5,
        )
    }

    /// Builds a task for `make_order()` with a fresh id, plus the receiving end of its
    /// response channel. Keep the receiver bound if the channel must stay open.
    fn make_task() -> (SolveTask, oneshot::Receiver<TaskResponse>) {
        let (response_tx, response_rx) = oneshot::channel();
        (SolveTask::new(Uuid::new_v4(), make_order(), response_tx), response_rx)
    }

    #[test]
    fn test_config_default() {
        let config = TaskQueueConfig::default();
        assert_eq!(config.capacity, 1000);
    }

    #[rstest]
    #[case::small(1)]
    #[case::medium(100)]
    #[case::large(10_000)]
    fn test_config_custom_capacity(#[case] capacity: usize) {
        let config = TaskQueueConfig { capacity };
        assert_eq!(config.capacity, capacity);
    }

    #[rstest]
    #[case::capacity_1(1)]
    #[case::capacity_10(10)]
    #[case::capacity_100(100)]
    fn test_queue_creation(#[case] capacity: usize) {
        let queue = TaskQueue::new(TaskQueueConfig { capacity });
        let handle = queue.handle();

        assert!(!handle.is_full());
        assert_eq!(handle.approximate_depth(), 0);
    }

    #[test]
    fn test_queue_handle_is_cloneable() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle1 = queue.handle();
        let handle2 = handle1.clone();

        assert_eq!(handle1.approximate_depth(), handle2.approximate_depth());
        assert_eq!(handle1.is_full(), handle2.is_full());
    }

    #[test]
    fn test_queue_into_receiver() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        // Keep a handle so the channel stays open after `into_receiver` drops the queue's own.
        let _handle = queue.handle();
        let receiver = queue.into_receiver();

        assert!(receiver.is_empty());
        assert!(!receiver.is_closed());
    }

    #[test]
    fn test_queue_split() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let (handle, _receiver) = queue.split();

        assert!(!handle.is_full());
        assert_eq!(handle.approximate_depth(), 0);
    }

    #[tokio::test]
    async fn test_enqueue_and_receive_response() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        let worker = tokio::spawn(async move {
            let task = receiver
                .recv()
                .await
                .expect("should receive task");
            assert_eq!(task.order().id(), "test-order");
            task.respond(Ok(make_single_quote()));
        });

        let result = handle.enqueue(make_order()).await;

        worker
            .await
            .expect("worker should complete");
        let quote = result.expect("should get quote");
        assert_eq!(quote.solve_time_ms(), 5);
    }

    #[tokio::test]
    async fn test_enqueue_receives_error_response() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        let worker = tokio::spawn(async move {
            let task = receiver
                .recv()
                .await
                .expect("should receive task");
            task.respond(Err(SolveError::NoRouteFound { order_id: "test".to_string() }));
        });

        let result = handle.enqueue(make_order()).await;

        worker
            .await
            .expect("worker should complete");
        assert!(matches!(result, Err(SolveError::NoRouteFound { .. })));
    }

    #[tokio::test]
    async fn test_enqueue_error_when_task_dropped_without_response() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        let worker = tokio::spawn(async move {
            let task = receiver
                .recv()
                .await
                .expect("should receive task");
            // Dropping the task drops its response sender without answering.
            drop(task);
        });

        let result = handle.enqueue(make_order()).await;

        worker
            .await
            .expect("worker should complete");
        assert!(matches!(result, Err(SolveError::Internal(_))));
    }

    #[tokio::test]
    async fn test_enqueue_returns_queue_full_when_receiver_dropped() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        drop(receiver);

        let result = handle.enqueue(make_order()).await;
        assert!(matches!(result, Err(SolveError::QueueFull)));
    }

    #[tokio::test]
    async fn test_approximate_depth_increases_with_pending_tasks() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let _receiver = queue.into_receiver();

        let (task, _response_rx) = make_task();
        handle
            .sender
            .send(task)
            .await
            .expect("should send");
        assert_eq!(handle.approximate_depth(), 1);

        let (task2, _response_rx2) = make_task();
        handle
            .sender
            .send(task2)
            .await
            .expect("should send");
        assert_eq!(handle.approximate_depth(), 2);
    }

    #[rstest]
    #[case::capacity_1(1)]
    #[case::capacity_5(5)]
    #[case::capacity_10(10)]
    #[tokio::test]
    async fn test_is_full_when_at_capacity(#[case] capacity: usize) {
        let queue = TaskQueue::new(TaskQueueConfig { capacity });
        let handle = queue.handle();
        let _receiver = queue.into_receiver();

        for _ in 0..capacity {
            let (task, _response_rx) = make_task();
            handle
                .sender
                .send(task)
                .await
                .expect("should send");
        }

        assert!(handle.is_full());
        assert_eq!(handle.approximate_depth(), capacity);
    }

    #[tokio::test]
    async fn test_is_full_becomes_false_after_task_consumed() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 2 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        let (task1, _rx1) = make_task();
        let (task2, _rx2) = make_task();
        handle.sender.send(task1).await.unwrap();
        handle.sender.send(task2).await.unwrap();

        assert!(handle.is_full());

        let _task = receiver.recv().await.unwrap();

        assert!(!handle.is_full());
        assert_eq!(handle.approximate_depth(), 1);
    }

    #[tokio::test]
    async fn test_multiple_handles_can_enqueue_concurrently() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle1 = queue.handle();
        let handle2 = queue.handle();
        let receiver = queue.into_receiver();

        let worker = tokio::spawn(async move {
            for _ in 0..2 {
                let task = receiver
                    .recv()
                    .await
                    .expect("should receive task");
                task.respond(Ok(make_single_quote()));
            }
        });

        let (result1, result2) =
            tokio::join!(handle1.enqueue(make_order()), handle2.enqueue(make_order()));

        worker
            .await
            .expect("worker should complete");

        assert!(result1.is_ok());
        assert!(result2.is_ok());
    }

    #[tokio::test]
    async fn test_task_id_is_unique_per_enqueue() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        let collector = tokio::spawn(async move {
            let task1 = receiver.recv().await.unwrap();
            let id1 = task1.id();
            task1.respond(Ok(make_single_quote()));

            let task2 = receiver.recv().await.unwrap();
            let id2 = task2.id();
            task2.respond(Ok(make_single_quote()));

            (id1, id2)
        });

        // Fail loudly if either round trip breaks instead of silently discarding the result.
        handle
            .enqueue(make_order())
            .await
            .expect("first enqueue should succeed");
        handle
            .enqueue(make_order())
            .await
            .expect("second enqueue should succeed");

        let (id1, id2): (Uuid, Uuid) = collector
            .await
            .expect("collector should complete");
        assert_ne!(id1, id2, "Task IDs should be unique");
    }

    #[test]
    fn test_solve_task_wait_time_increases() {
        let (task, _response_rx) = make_task();

        let wait1 = task.wait_time();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let wait2 = task.wait_time();

        assert!(wait2 > wait1);
    }

    #[tokio::test]
    async fn test_solve_task_respond_delivers_result() {
        let (task, response_rx) = make_task();

        task.respond(Ok(make_single_quote()));

        let result = response_rx
            .await
            .expect("should receive response");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_solve_task_respond_delivers_error() {
        let (task, response_rx) = make_task();

        task.respond(Err(SolveError::Timeout { elapsed_ms: 100 }));

        let result = response_rx
            .await
            .expect("should receive response");
        assert!(matches!(result, Err(SolveError::Timeout { elapsed_ms: 100 })));
    }
}