Skip to main content

fynd_core/worker_pool/
task_queue.rs

//! Task queue for distributing solve requests to workers.
//!
//! The queue sits between the HTTP handlers and the worker pool.
//! It provides backpressure and allows the HTTP layer to remain
//! responsive even when workers are busy.
6
7use tokio::sync::oneshot;
8use uuid::Uuid;
9
10use crate::{types::internal::SolveTask, Order, SingleOrderQuote, SolveError, SolveParams};
11
/// Settings used when building a [`TaskQueue`].
#[derive(Debug, Clone)]
pub struct TaskQueueConfig {
    /// Upper bound on how many tasks may sit in the queue at once.
    pub capacity: usize,
}

impl Default for TaskQueueConfig {
    /// Returns a configuration with room for 1000 pending tasks.
    fn default() -> Self {
        let capacity = 1000;
        TaskQueueConfig { capacity }
    }
}
24
25/// Handle for enqueueing tasks.
26///
27/// This is cloned and shared with HTTP handlers.
28#[derive(Clone)]
29pub struct TaskQueueHandle {
30    sender: async_channel::Sender<SolveTask>,
31}
32
33impl TaskQueueHandle {
34    /// Enqueues a solve request and returns a future that resolves to the result.
35    ///
36    /// Returns an error if the queue is full.
37    pub async fn enqueue(
38        &self,
39        order: Order,
40        params: SolveParams,
41    ) -> Result<SingleOrderQuote, SolveError> {
42        // Create response channel
43        let (response_tx, response_rx) = oneshot::channel();
44
45        // Generate task ID
46        let task_id = Uuid::new_v4();
47
48        // Create task
49        let task = SolveTask::new(task_id, order, response_tx).with_params(params);
50
51        // Try to send
52        self.sender
53            .send(task)
54            .await
55            .map_err(|_| SolveError::QueueFull)?;
56
57        // Wait for response
58        response_rx
59            .await
60            .map_err(|_| SolveError::Internal("worker dropped response channel".to_string()))?
61    }
62
63    /// Returns the current approximate queue depth.
64    ///
65    /// Note: This is not exact due to the async nature of the queue.
66    #[cfg(test)]
67    pub fn approximate_depth(&self) -> usize {
68        self.sender.len()
69    }
70
71    /// Returns true if the queue is likely full.
72    #[cfg(test)]
73    pub fn is_full(&self) -> bool {
74        self.sender.is_full()
75    }
76
77    /// Creates a TaskQueueHandle from an existing sender.
78    ///
79    /// This is primarily useful for testing with mock channels.
80    pub fn from_sender(sender: async_channel::Sender<SolveTask>) -> Self {
81        Self { sender }
82    }
83}
84
85/// The task queue itself.
86///
87/// This is consumed when creating the worker pool.
88pub struct TaskQueue {
89    receiver: async_channel::Receiver<SolveTask>,
90    handle: TaskQueueHandle,
91}
92
93impl TaskQueue {
94    /// Creates a new task queue with the given configuration.
95    pub fn new(config: TaskQueueConfig) -> Self {
96        let (sender, receiver) = async_channel::bounded(config.capacity);
97        let handle = TaskQueueHandle { sender };
98
99        Self { receiver, handle }
100    }
101
102    /// Splits the queue into handle and receiver.
103    pub fn split(self) -> (TaskQueueHandle, async_channel::Receiver<SolveTask>) {
104        (self.handle, self.receiver)
105    }
106
107    /// Returns a handle for enqueueing tasks.
108    #[cfg(test)]
109    pub fn handle(&self) -> TaskQueueHandle {
110        self.handle.clone()
111    }
112
113    /// Consumes the queue and returns the receiver.
114    ///
115    /// This is called when setting up the worker pool.
116    #[cfg(test)]
117    pub fn into_receiver(self) -> async_channel::Receiver<SolveTask> {
118        self.receiver
119    }
120}
121
122#[cfg(test)]
123mod tests {
124    use num_bigint::BigUint;
125    use rstest::rstest;
126    use tycho_simulation::tycho_core::{models::Address, Bytes};
127
128    use super::*;
129    use crate::{
130        BlockInfo, Order, OrderQuote, OrderSide, QuoteStatus, SingleOrderQuote, SolveParams,
131    };
132
133    // -------------------------------------------------------------------------
134    // Test Helpers
135    // -------------------------------------------------------------------------
136
137    fn make_address(byte: u8) -> Address {
138        Address::from([byte; 20])
139    }
140
141    fn make_order() -> Order {
142        Order::new(
143            make_address(0x01),
144            make_address(0x02),
145            BigUint::from(1000u64),
146            OrderSide::Sell,
147            make_address(0xAA),
148        )
149        .with_id("test-order".to_string())
150    }
151
152    fn make_single_quote() -> SingleOrderQuote {
153        SingleOrderQuote::new(
154            OrderQuote::new(
155                "test-order".to_string(),
156                QuoteStatus::Success,
157                BigUint::from(1000u64),
158                BigUint::from(990u64),
159                BigUint::from(100_000u64),
160                BigUint::from(990u64),
161                BlockInfo::new(1, "0x123".to_string(), 1000),
162                "test".to_string(),
163                Bytes::from(make_address(0xAA).as_ref()),
164                Bytes::from(make_address(0xAA).as_ref()),
165                "1".to_string(),
166            ),
167            5,
168        )
169    }
170
171    // -------------------------------------------------------------------------
172    // TaskQueueConfig Tests
173    // -------------------------------------------------------------------------
174
175    #[test]
176    fn test_config_default() {
177        let config = TaskQueueConfig::default();
178        assert_eq!(config.capacity, 1000);
179    }
180
181    #[rstest]
182    #[case::small(1)]
183    #[case::medium(100)]
184    #[case::large(10_000)]
185    fn test_config_custom_capacity(#[case] capacity: usize) {
186        let config = TaskQueueConfig { capacity };
187        assert_eq!(config.capacity, capacity);
188    }
189
190    // -------------------------------------------------------------------------
191    // TaskQueue Creation Tests
192    // -------------------------------------------------------------------------
193
194    #[rstest]
195    #[case::capacity_1(1)]
196    #[case::capacity_10(10)]
197    #[case::capacity_100(100)]
198    fn test_queue_creation(#[case] capacity: usize) {
199        let config = TaskQueueConfig { capacity };
200        let queue = TaskQueue::new(config);
201        let handle = queue.handle();
202
203        assert!(!handle.is_full());
204        assert_eq!(handle.approximate_depth(), 0);
205    }
206
207    #[test]
208    fn test_queue_handle_is_cloneable() {
209        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
210        let handle1 = queue.handle();
211        let handle2 = handle1.clone();
212
213        // Both handles should report same state
214        assert_eq!(handle1.approximate_depth(), handle2.approximate_depth());
215        assert_eq!(handle1.is_full(), handle2.is_full());
216    }
217
218    #[test]
219    fn test_queue_into_receiver() {
220        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
221        let _handle = queue.handle();
222        let _receiver = queue.into_receiver();
223        // Queue is consumed - receiver is ready for worker pool
224    }
225
226    #[test]
227    fn test_queue_split() {
228        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
229        let (handle, _receiver) = queue.split();
230
231        assert!(!handle.is_full());
232        assert_eq!(handle.approximate_depth(), 0);
233    }
234
235    // -------------------------------------------------------------------------
236    // TaskQueueHandle Tests
237    // -------------------------------------------------------------------------
238
239    #[tokio::test]
240    async fn test_enqueue_and_receive_response() {
241        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
242        let handle = queue.handle();
243        let receiver = queue.into_receiver();
244
245        // Spawn a "worker" that responds to the task
246        let worker = tokio::spawn(async move {
247            let task = receiver
248                .recv()
249                .await
250                .expect("should receive task");
251            assert_eq!(task.order().id(), "test-order");
252            task.respond(Ok(make_single_quote()));
253        });
254
255        // Enqueue an order
256        let result = handle
257            .enqueue(make_order(), SolveParams::default())
258            .await;
259
260        worker
261            .await
262            .expect("worker should complete");
263        let quote = result.expect("should get quote");
264        assert_eq!(quote.solve_time_ms(), 5);
265    }
266
267    #[tokio::test]
268    async fn test_enqueue_receives_error_response() {
269        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
270        let handle = queue.handle();
271        let receiver = queue.into_receiver();
272
273        let worker = tokio::spawn(async move {
274            let task = receiver
275                .recv()
276                .await
277                .expect("should receive task");
278            task.respond(Err(SolveError::NoRouteFound { order_id: "test".to_string() }));
279        });
280
281        let result = handle
282            .enqueue(make_order(), SolveParams::default())
283            .await;
284
285        worker
286            .await
287            .expect("worker should complete");
288        assert!(matches!(result, Err(SolveError::NoRouteFound { .. })));
289    }
290
291    #[tokio::test]
292    async fn test_enqueue_error_when_receiver_dropped() {
293        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
294        let handle = queue.handle();
295        let receiver = queue.into_receiver();
296
297        // Worker receives task but drops it without responding
298        let worker = tokio::spawn(async move {
299            let task = receiver
300                .recv()
301                .await
302                .expect("should receive task");
303            drop(task); // Drop without responding
304        });
305
306        let result = handle
307            .enqueue(make_order(), SolveParams::default())
308            .await;
309
310        worker
311            .await
312            .expect("worker should complete");
313        assert!(matches!(result, Err(SolveError::Internal(_))));
314    }
315
316    #[tokio::test]
317    async fn test_enqueue_queue_full_error() {
318        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
319        let handle = queue.handle();
320        let receiver = queue.into_receiver();
321
322        // Drop receiver to close channel
323        drop(receiver);
324
325        let result = handle
326            .enqueue(make_order(), SolveParams::default())
327            .await;
328        assert!(matches!(result, Err(SolveError::QueueFull)));
329    }
330
331    // -------------------------------------------------------------------------
332    // Queue Depth and Full Detection Tests
333    // -------------------------------------------------------------------------
334
335    #[tokio::test]
336    async fn test_approximate_depth_increases_with_pending_tasks() {
337        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
338        let handle = queue.handle();
339        let _receiver = queue.into_receiver(); // Keep receiver alive but don't consume
340
341        // Create a oneshot and send a task
342        let (response_tx, _response_rx) = oneshot::channel();
343        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
344
345        handle
346            .sender
347            .send(task)
348            .await
349            .expect("should send");
350
351        assert_eq!(handle.approximate_depth(), 1);
352
353        // Send another
354        let (response_tx2, _response_rx2) = oneshot::channel();
355        let task2 = SolveTask::new(Uuid::new_v4(), make_order(), response_tx2);
356        handle
357            .sender
358            .send(task2)
359            .await
360            .expect("should send");
361
362        assert_eq!(handle.approximate_depth(), 2);
363    }
364
365    #[rstest]
366    #[case::capacity_1(1)]
367    #[case::capacity_5(5)]
368    #[case::capacity_10(10)]
369    #[tokio::test]
370    async fn test_is_full_when_at_capacity(#[case] capacity: usize) {
371        let queue = TaskQueue::new(TaskQueueConfig { capacity });
372        let handle = queue.handle();
373        let _receiver = queue.into_receiver();
374
375        // Fill the queue
376        for _ in 0..capacity {
377            let (response_tx, _response_rx) = oneshot::channel();
378            let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
379            handle
380                .sender
381                .send(task)
382                .await
383                .expect("should send");
384        }
385
386        assert!(handle.is_full());
387        assert_eq!(handle.approximate_depth(), capacity);
388    }
389
390    #[tokio::test]
391    async fn test_is_full_becomes_false_after_task_consumed() {
392        let queue = TaskQueue::new(TaskQueueConfig { capacity: 2 });
393        let handle = queue.handle();
394        let receiver = queue.into_receiver();
395
396        // Fill queue
397        let (tx1, _rx1) = oneshot::channel();
398        let (tx2, _rx2) = oneshot::channel();
399        handle
400            .sender
401            .send(SolveTask::new(Uuid::new_v4(), make_order(), tx1))
402            .await
403            .unwrap();
404        handle
405            .sender
406            .send(SolveTask::new(Uuid::new_v4(), make_order(), tx2))
407            .await
408            .unwrap();
409
410        assert!(handle.is_full());
411
412        // Consume one task
413        let _task = receiver.recv().await.unwrap();
414
415        // Queue should no longer be full
416        assert!(!handle.is_full());
417        assert_eq!(handle.approximate_depth(), 1);
418    }
419
420    // -------------------------------------------------------------------------
421    // Concurrent Operation Tests
422    // -------------------------------------------------------------------------
423
424    #[tokio::test]
425    async fn test_multiple_handles_can_enqueue_concurrently() {
426        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
427        let handle1 = queue.handle();
428        let handle2 = queue.handle();
429        let receiver = queue.into_receiver();
430
431        // Spawn worker that processes multiple tasks
432        let worker = tokio::spawn(async move {
433            for _ in 0..2 {
434                let task = receiver
435                    .recv()
436                    .await
437                    .expect("should receive task");
438                task.respond(Ok(make_single_quote()));
439            }
440        });
441
442        // Enqueue from both handles concurrently
443        let (result1, result2) = tokio::join!(
444            handle1.enqueue(make_order(), SolveParams::default()),
445            handle2.enqueue(make_order(), SolveParams::default()),
446        );
447
448        worker
449            .await
450            .expect("worker should complete");
451
452        assert!(result1.is_ok());
453        assert!(result2.is_ok());
454    }
455
456    #[tokio::test]
457    async fn test_task_id_is_unique_per_enqueue() {
458        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
459        let handle = queue.handle();
460        let receiver = queue.into_receiver();
461
462        // Spawn workers to collect task IDs
463        let collector = tokio::spawn(async move {
464            let task1 = receiver.recv().await.unwrap();
465            let id1 = task1.id();
466            task1.respond(Ok(make_single_quote()));
467
468            let task2 = receiver.recv().await.unwrap();
469            let id2 = task2.id();
470            task2.respond(Ok(make_single_quote()));
471
472            (id1, id2)
473        });
474
475        // Enqueue two orders
476        let _ = handle
477            .enqueue(make_order(), SolveParams::default())
478            .await;
479        let _ = handle
480            .enqueue(make_order(), SolveParams::default())
481            .await;
482
483        let (id1, id2): (Uuid, Uuid) = collector
484            .await
485            .expect("collector should complete");
486        assert_ne!(id1, id2, "Task IDs should be unique");
487    }
488
489    // -------------------------------------------------------------------------
490    // SolveTask Tests (internal type used by queue)
491    // -------------------------------------------------------------------------
492
493    #[test]
494    fn test_solve_task_wait_time_increases() {
495        let (response_tx, _response_rx) = oneshot::channel();
496        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
497
498        let wait1 = task.wait_time();
499        std::thread::sleep(std::time::Duration::from_millis(10));
500        let wait2 = task.wait_time();
501
502        assert!(wait2 > wait1);
503    }
504
505    #[tokio::test]
506    async fn test_solve_task_respond_delivers_result() {
507        let (response_tx, response_rx) = oneshot::channel();
508        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
509
510        task.respond(Ok(make_single_quote()));
511
512        let result = response_rx
513            .await
514            .expect("should receive response");
515        assert!(result.is_ok());
516    }
517
518    #[tokio::test]
519    async fn test_solve_task_respond_delivers_error() {
520        let (response_tx, response_rx) = oneshot::channel();
521        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
522
523        task.respond(Err(SolveError::Timeout { elapsed_ms: 100 }));
524
525        let result = response_rx
526            .await
527            .expect("should receive response");
528        assert!(matches!(result, Err(SolveError::Timeout { elapsed_ms: 100 })));
529    }
530}