1use tokio::sync::oneshot;
8use uuid::Uuid;
9
10use crate::{types::internal::SolveTask, Order, SingleOrderQuote, SolveError, SolveParams};
11
/// Settings used when building a `TaskQueue`.
#[derive(Debug, Clone)]
pub struct TaskQueueConfig {
    /// Maximum number of tasks that may sit in the queue waiting for a worker.
    pub capacity: usize,
}

impl Default for TaskQueueConfig {
    /// Returns a configuration with room for 1000 pending tasks.
    fn default() -> Self {
        TaskQueueConfig { capacity: 1_000 }
    }
}
24
25#[derive(Clone)]
29pub struct TaskQueueHandle {
30 sender: async_channel::Sender<SolveTask>,
31}
32
33impl TaskQueueHandle {
34 pub async fn enqueue(
38 &self,
39 order: Order,
40 params: SolveParams,
41 ) -> Result<SingleOrderQuote, SolveError> {
42 let (response_tx, response_rx) = oneshot::channel();
44
45 let task_id = Uuid::new_v4();
47
48 let task = SolveTask::new(task_id, order, response_tx).with_params(params);
50
51 self.sender
53 .send(task)
54 .await
55 .map_err(|_| SolveError::QueueFull)?;
56
57 response_rx
59 .await
60 .map_err(|_| SolveError::Internal("worker dropped response channel".to_string()))?
61 }
62
63 #[cfg(test)]
67 pub fn approximate_depth(&self) -> usize {
68 self.sender.len()
69 }
70
71 #[cfg(test)]
73 pub fn is_full(&self) -> bool {
74 self.sender.is_full()
75 }
76
77 pub fn from_sender(sender: async_channel::Sender<SolveTask>) -> Self {
81 Self { sender }
82 }
83}
84
85pub struct TaskQueue {
89 receiver: async_channel::Receiver<SolveTask>,
90 handle: TaskQueueHandle,
91}
92
93impl TaskQueue {
94 pub fn new(config: TaskQueueConfig) -> Self {
96 let (sender, receiver) = async_channel::bounded(config.capacity);
97 let handle = TaskQueueHandle { sender };
98
99 Self { receiver, handle }
100 }
101
102 pub fn split(self) -> (TaskQueueHandle, async_channel::Receiver<SolveTask>) {
104 (self.handle, self.receiver)
105 }
106
107 #[cfg(test)]
109 pub fn handle(&self) -> TaskQueueHandle {
110 self.handle.clone()
111 }
112
113 #[cfg(test)]
117 pub fn into_receiver(self) -> async_channel::Receiver<SolveTask> {
118 self.receiver
119 }
120}
121
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;
    use rstest::rstest;
    use tycho_simulation::tycho_core::{models::Address, Bytes};

    use super::*;
    use crate::{
        BlockInfo, Order, OrderQuote, OrderSide, QuoteStatus, SingleOrderQuote, SolveParams,
    };

    /// Builds an address whose 20 bytes are all `byte`.
    fn make_address(byte: u8) -> Address {
        Address::from([byte; 20])
    }

    /// Builds a sell order with the fixed id `"test-order"`.
    fn make_order() -> Order {
        Order::new(
            make_address(0x01),
            make_address(0x02),
            BigUint::from(1000u64),
            OrderSide::Sell,
            make_address(0xAA),
        )
        .with_id("test-order".to_string())
    }

    /// Builds a successful quote for `"test-order"` whose `solve_time_ms` is 5.
    fn make_single_quote() -> SingleOrderQuote {
        SingleOrderQuote::new(
            OrderQuote::new(
                "test-order".to_string(),
                QuoteStatus::Success,
                BigUint::from(1000u64),
                BigUint::from(990u64),
                BigUint::from(100_000u64),
                BigUint::from(990u64),
                BlockInfo::new(1, "0x123".to_string(), 1000),
                "test".to_string(),
                Bytes::from(make_address(0xAA).as_ref()),
                Bytes::from(make_address(0xAA).as_ref()),
                "1".to_string(),
            ),
            5,
        )
    }

    #[test]
    fn test_config_default() {
        let config = TaskQueueConfig::default();
        assert_eq!(config.capacity, 1000);
    }

    #[rstest]
    #[case::small(1)]
    #[case::medium(100)]
    #[case::large(10_000)]
    fn test_config_custom_capacity(#[case] capacity: usize) {
        let config = TaskQueueConfig { capacity };
        assert_eq!(config.capacity, capacity);
    }

    #[rstest]
    #[case::capacity_1(1)]
    #[case::capacity_10(10)]
    #[case::capacity_100(100)]
    fn test_queue_creation(#[case] capacity: usize) {
        let config = TaskQueueConfig { capacity };
        let queue = TaskQueue::new(config);
        let handle = queue.handle();

        // A freshly created queue starts empty.
        assert!(!handle.is_full());
        assert_eq!(handle.approximate_depth(), 0);
    }

    #[test]
    fn test_queue_handle_is_cloneable() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle1 = queue.handle();
        let handle2 = handle1.clone();

        // Clones observe the same underlying channel.
        assert_eq!(handle1.approximate_depth(), handle2.approximate_depth());
        assert_eq!(handle1.is_full(), handle2.is_full());
    }

    #[test]
    fn test_queue_into_receiver() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let _handle = queue.handle();
        let receiver = queue.into_receiver();

        // The outstanding handle still holds a sender, so the channel must stay open
        // after the queue (and its internal handle) is consumed.
        assert!(!receiver.is_closed());
        assert!(receiver.is_empty());
        assert_eq!(receiver.capacity(), Some(10));
    }

    #[test]
    fn test_queue_split() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let (handle, _receiver) = queue.split();

        assert!(!handle.is_full());
        assert_eq!(handle.approximate_depth(), 0);
    }

    #[tokio::test]
    async fn test_enqueue_and_receive_response() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        // Simulated worker: take one task and answer it successfully.
        let worker = tokio::spawn(async move {
            let task = receiver
                .recv()
                .await
                .expect("should receive task");
            assert_eq!(task.order().id(), "test-order");
            task.respond(Ok(make_single_quote()));
        });

        let result = handle
            .enqueue(make_order(), SolveParams::default())
            .await;

        worker
            .await
            .expect("worker should complete");
        let quote = result.expect("should get quote");
        assert_eq!(quote.solve_time_ms(), 5);
    }

    #[tokio::test]
    async fn test_enqueue_receives_error_response() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        // Simulated worker: take one task and answer it with an error.
        let worker = tokio::spawn(async move {
            let task = receiver
                .recv()
                .await
                .expect("should receive task");
            task.respond(Err(SolveError::NoRouteFound { order_id: "test".to_string() }));
        });

        let result = handle
            .enqueue(make_order(), SolveParams::default())
            .await;

        worker
            .await
            .expect("worker should complete");
        assert!(matches!(result, Err(SolveError::NoRouteFound { .. })));
    }

    #[tokio::test]
    async fn test_enqueue_error_when_receiver_dropped() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        // Simulated worker: take the task and drop it without replying, which closes the
        // task's response channel.
        let worker = tokio::spawn(async move {
            let task = receiver
                .recv()
                .await
                .expect("should receive task");
            drop(task);
        });

        let result = handle
            .enqueue(make_order(), SolveParams::default())
            .await;

        worker
            .await
            .expect("worker should complete");
        assert!(matches!(result, Err(SolveError::Internal(_))));
    }

    #[tokio::test]
    async fn test_enqueue_queue_full_error() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        // `send` on the channel only fails once every receiver is gone; `enqueue` reports
        // that failure as `QueueFull`, so dropping the receiver exercises this path.
        drop(receiver);

        let result = handle
            .enqueue(make_order(), SolveParams::default())
            .await;
        assert!(matches!(result, Err(SolveError::QueueFull)));
    }

    #[tokio::test]
    async fn test_approximate_depth_increases_with_pending_tasks() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        // Keep the receiver alive so sends succeed, but never read from it.
        let _receiver = queue.into_receiver();

        let (response_tx, _response_rx) = oneshot::channel();
        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);

        handle
            .sender
            .send(task)
            .await
            .expect("should send");

        assert_eq!(handle.approximate_depth(), 1);

        let (response_tx2, _response_rx2) = oneshot::channel();
        let task2 = SolveTask::new(Uuid::new_v4(), make_order(), response_tx2);
        handle
            .sender
            .send(task2)
            .await
            .expect("should send");

        assert_eq!(handle.approximate_depth(), 2);
    }

    #[rstest]
    #[case::capacity_1(1)]
    #[case::capacity_5(5)]
    #[case::capacity_10(10)]
    #[tokio::test]
    async fn test_is_full_when_at_capacity(#[case] capacity: usize) {
        let queue = TaskQueue::new(TaskQueueConfig { capacity });
        let handle = queue.handle();
        let _receiver = queue.into_receiver();

        // Fill every slot without consuming anything.
        for _ in 0..capacity {
            let (response_tx, _response_rx) = oneshot::channel();
            let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
            handle
                .sender
                .send(task)
                .await
                .expect("should send");
        }

        assert!(handle.is_full());
        assert_eq!(handle.approximate_depth(), capacity);
    }

    #[tokio::test]
    async fn test_is_full_becomes_false_after_task_consumed() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 2 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        let (tx1, _rx1) = oneshot::channel();
        let (tx2, _rx2) = oneshot::channel();
        handle
            .sender
            .send(SolveTask::new(Uuid::new_v4(), make_order(), tx1))
            .await
            .unwrap();
        handle
            .sender
            .send(SolveTask::new(Uuid::new_v4(), make_order(), tx2))
            .await
            .unwrap();

        assert!(handle.is_full());

        // Consuming one task frees exactly one slot.
        let _task = receiver.recv().await.unwrap();

        assert!(!handle.is_full());
        assert_eq!(handle.approximate_depth(), 1);
    }

    #[tokio::test]
    async fn test_multiple_handles_can_enqueue_concurrently() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle1 = queue.handle();
        let handle2 = queue.handle();
        let receiver = queue.into_receiver();

        // Simulated worker: answer both submissions successfully.
        let worker = tokio::spawn(async move {
            for _ in 0..2 {
                let task = receiver
                    .recv()
                    .await
                    .expect("should receive task");
                task.respond(Ok(make_single_quote()));
            }
        });

        let (result1, result2) = tokio::join!(
            handle1.enqueue(make_order(), SolveParams::default()),
            handle2.enqueue(make_order(), SolveParams::default()),
        );

        worker
            .await
            .expect("worker should complete");

        assert!(result1.is_ok());
        assert!(result2.is_ok());
    }

    #[tokio::test]
    async fn test_task_id_is_unique_per_enqueue() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        // Record the ids of two consecutive tasks while still answering them so the
        // `enqueue` calls below complete.
        let collector = tokio::spawn(async move {
            let task1 = receiver.recv().await.unwrap();
            let id1 = task1.id();
            task1.respond(Ok(make_single_quote()));

            let task2 = receiver.recv().await.unwrap();
            let id2 = task2.id();
            task2.respond(Ok(make_single_quote()));

            (id1, id2)
        });

        let _ = handle
            .enqueue(make_order(), SolveParams::default())
            .await;
        let _ = handle
            .enqueue(make_order(), SolveParams::default())
            .await;

        let (id1, id2): (Uuid, Uuid) = collector
            .await
            .expect("collector should complete");
        assert_ne!(id1, id2, "Task IDs should be unique");
    }

    #[test]
    fn test_solve_task_wait_time_increases() {
        let (response_tx, _response_rx) = oneshot::channel();
        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);

        let wait1 = task.wait_time();
        // Plain (non-async) test, so a short blocking sleep is acceptable here.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let wait2 = task.wait_time();

        assert!(wait2 > wait1);
    }

    #[tokio::test]
    async fn test_solve_task_respond_delivers_result() {
        let (response_tx, response_rx) = oneshot::channel();
        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);

        task.respond(Ok(make_single_quote()));

        let result = response_rx
            .await
            .expect("should receive response");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_solve_task_respond_delivers_error() {
        let (response_tx, response_rx) = oneshot::channel();
        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);

        task.respond(Err(SolveError::Timeout { elapsed_ms: 100 }));

        let result = response_rx
            .await
            .expect("should receive response");
        assert!(matches!(result, Err(SolveError::Timeout { elapsed_ms: 100 })));
    }
}