// fynd_core/worker_pool/task_queue.rs

1//! Task queue for distributing solve requests to workers.
2//!
3//! The queue sits between the HTTP handlers and the worker pool.
4//! It provides backpressure and allows the HTTP layer to remain
5//! responsive even when workers are busy.
6
7use tokio::sync::oneshot;
8use uuid::Uuid;
9
10use crate::{types::internal::SolveTask, Order, SingleOrderQuote, SolveError};
11
/// Configuration for the task queue.
#[derive(Debug, Clone)]
pub struct TaskQueueConfig {
    /// Maximum number of pending tasks.
    ///
    /// Used as the bound of the underlying `async_channel`; once this many
    /// tasks are queued, further `enqueue` calls wait for a worker to drain
    /// a slot rather than failing immediately.
    pub capacity: usize,
}
18
19impl Default for TaskQueueConfig {
20    fn default() -> Self {
21        Self { capacity: 1000 }
22    }
23}
24
/// Handle for enqueueing tasks.
///
/// This is cloned and shared with HTTP handlers.
#[derive(Clone)]
pub struct TaskQueueHandle {
    // Cloning the handle clones the sender; every clone feeds the same
    // bounded queue consumed by the worker pool.
    sender: async_channel::Sender<SolveTask>,
}
32
33impl TaskQueueHandle {
34    /// Enqueues a solve request and returns a future that resolves to the result.
35    ///
36    /// Returns an error if the queue is full.
37    pub async fn enqueue(&self, order: Order) -> Result<SingleOrderQuote, SolveError> {
38        // Create response channel
39        let (response_tx, response_rx) = oneshot::channel();
40
41        // Generate task ID
42        let task_id = Uuid::new_v4();
43
44        // Create task
45        let task = SolveTask::new(task_id, order, response_tx);
46
47        // Try to send
48        self.sender
49            .send(task)
50            .await
51            .map_err(|_| SolveError::QueueFull)?;
52
53        // Wait for response
54        response_rx
55            .await
56            .map_err(|_| SolveError::Internal("worker dropped response channel".to_string()))?
57    }
58
59    /// Returns the current approximate queue depth.
60    ///
61    /// Note: This is not exact due to the async nature of the queue.
62    #[cfg(test)]
63    pub fn approximate_depth(&self) -> usize {
64        self.sender.len()
65    }
66
67    /// Returns true if the queue is likely full.
68    #[cfg(test)]
69    pub fn is_full(&self) -> bool {
70        self.sender.is_full()
71    }
72
73    /// Creates a TaskQueueHandle from an existing sender.
74    ///
75    /// This is primarily useful for testing with mock channels.
76    pub fn from_sender(sender: async_channel::Sender<SolveTask>) -> Self {
77        Self { sender }
78    }
79}
80
/// The task queue itself.
///
/// This is consumed when creating the worker pool.
pub struct TaskQueue {
    // Receiving half, handed to the worker pool via `split`/`into_receiver`.
    receiver: async_channel::Receiver<SolveTask>,
    // Pre-built sending handle so `handle()`/`split()` can hand out clones.
    handle: TaskQueueHandle,
}
88
89impl TaskQueue {
90    /// Creates a new task queue with the given configuration.
91    pub fn new(config: TaskQueueConfig) -> Self {
92        let (sender, receiver) = async_channel::bounded(config.capacity);
93        let handle = TaskQueueHandle { sender };
94
95        Self { receiver, handle }
96    }
97
98    /// Splits the queue into handle and receiver.
99    pub fn split(self) -> (TaskQueueHandle, async_channel::Receiver<SolveTask>) {
100        (self.handle, self.receiver)
101    }
102
103    /// Returns a handle for enqueueing tasks.
104    #[cfg(test)]
105    pub fn handle(&self) -> TaskQueueHandle {
106        self.handle.clone()
107    }
108
109    /// Consumes the queue and returns the receiver.
110    ///
111    /// This is called when setting up the worker pool.
112    #[cfg(test)]
113    pub fn into_receiver(self) -> async_channel::Receiver<SolveTask> {
114        self.receiver
115    }
116}
117
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;
    use rstest::rstest;
    use tycho_simulation::tycho_core::{models::Address, Bytes};

    use super::*;
    use crate::{BlockInfo, Order, OrderQuote, OrderSide, QuoteStatus, SingleOrderQuote};

    // -------------------------------------------------------------------------
    // Test Helpers
    // -------------------------------------------------------------------------

    // Deterministic 20-byte address filled with `byte`.
    fn make_address(byte: u8) -> Address {
        Address::from([byte; 20])
    }

    // Minimal sell order with a fixed id, used by every queue test.
    fn make_order() -> Order {
        Order::new(
            make_address(0x01),
            make_address(0x02),
            BigUint::from(1000u64),
            OrderSide::Sell,
            make_address(0xAA),
        )
        .with_id("test-order".to_string())
    }

    // Canned successful quote with solve_time_ms == 5 (asserted below).
    fn make_single_quote() -> SingleOrderQuote {
        SingleOrderQuote::new(
            OrderQuote::new(
                "test-order".to_string(),
                QuoteStatus::Success,
                BigUint::from(1000u64),
                BigUint::from(990u64),
                BigUint::from(100_000u64),
                BigUint::from(990u64),
                BlockInfo::new(1, "0x123".to_string(), 1000),
                "test".to_string(),
                Bytes::from(make_address(0xAA).as_ref()),
                Bytes::from(make_address(0xAA).as_ref()),
            ),
            5,
        )
    }

    // -------------------------------------------------------------------------
    // TaskQueueConfig Tests
    // -------------------------------------------------------------------------

    #[test]
    fn test_config_default() {
        let config = TaskQueueConfig::default();
        assert_eq!(config.capacity, 1000);
    }

    #[rstest]
    #[case::small(1)]
    #[case::medium(100)]
    #[case::large(10_000)]
    fn test_config_custom_capacity(#[case] capacity: usize) {
        let config = TaskQueueConfig { capacity };
        assert_eq!(config.capacity, capacity);
    }

    // -------------------------------------------------------------------------
    // TaskQueue Creation Tests
    // -------------------------------------------------------------------------

    #[rstest]
    #[case::capacity_1(1)]
    #[case::capacity_10(10)]
    #[case::capacity_100(100)]
    fn test_queue_creation(#[case] capacity: usize) {
        let config = TaskQueueConfig { capacity };
        let queue = TaskQueue::new(config);
        let handle = queue.handle();

        assert!(!handle.is_full());
        assert_eq!(handle.approximate_depth(), 0);
    }

    #[test]
    fn test_queue_handle_is_cloneable() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle1 = queue.handle();
        let handle2 = handle1.clone();

        // Both handles should report same state
        assert_eq!(handle1.approximate_depth(), handle2.approximate_depth());
        assert_eq!(handle1.is_full(), handle2.is_full());
    }

    #[test]
    fn test_queue_into_receiver() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let _handle = queue.handle();
        let _receiver = queue.into_receiver();
        // Queue is consumed - receiver is ready for worker pool
    }

    #[test]
    fn test_queue_split() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let (handle, _receiver) = queue.split();

        assert!(!handle.is_full());
        assert_eq!(handle.approximate_depth(), 0);
    }

    // -------------------------------------------------------------------------
    // TaskQueueHandle Tests
    // -------------------------------------------------------------------------

    #[tokio::test]
    async fn test_enqueue_and_receive_response() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        // Spawn a "worker" that responds to the task
        let worker = tokio::spawn(async move {
            let task = receiver
                .recv()
                .await
                .expect("should receive task");
            assert_eq!(task.order().id(), "test-order");
            task.respond(Ok(make_single_quote()));
        });

        // Enqueue an order
        let result = handle.enqueue(make_order()).await;

        worker
            .await
            .expect("worker should complete");
        let quote = result.expect("should get quote");
        assert_eq!(quote.solve_time_ms(), 5);
    }

    #[tokio::test]
    async fn test_enqueue_receives_error_response() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        let worker = tokio::spawn(async move {
            let task = receiver
                .recv()
                .await
                .expect("should receive task");
            task.respond(Err(SolveError::NoRouteFound { order_id: "test".to_string() }));
        });

        let result = handle.enqueue(make_order()).await;

        worker
            .await
            .expect("worker should complete");
        assert!(matches!(result, Err(SolveError::NoRouteFound { .. })));
    }

    #[tokio::test]
    async fn test_enqueue_error_when_receiver_dropped() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        // Worker receives task but drops it without responding
        let worker = tokio::spawn(async move {
            let task = receiver
                .recv()
                .await
                .expect("should receive task");
            drop(task); // Drop without responding
        });

        let result = handle.enqueue(make_order()).await;

        worker
            .await
            .expect("worker should complete");
        assert!(matches!(result, Err(SolveError::Internal(_))));
    }

    // NOTE(review): despite the name, this exercises the *closed-channel*
    // path: `async_channel::Sender::send` only errors after the receiver is
    // dropped (it awaits while merely full), so `QueueFull` here really
    // signals "queue shut down" — consider renaming the variant or the test.
    #[tokio::test]
    async fn test_enqueue_queue_full_error() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        // Drop receiver to close channel
        drop(receiver);

        let result = handle.enqueue(make_order()).await;
        assert!(matches!(result, Err(SolveError::QueueFull)));
    }

    // -------------------------------------------------------------------------
    // Queue Depth and Full Detection Tests
    // -------------------------------------------------------------------------

    #[tokio::test]
    async fn test_approximate_depth_increases_with_pending_tasks() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let _receiver = queue.into_receiver(); // Keep receiver alive but don't consume

        // Create a oneshot and send a task
        let (response_tx, _response_rx) = oneshot::channel();
        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);

        handle
            .sender
            .send(task)
            .await
            .expect("should send");

        assert_eq!(handle.approximate_depth(), 1);

        // Send another
        let (response_tx2, _response_rx2) = oneshot::channel();
        let task2 = SolveTask::new(Uuid::new_v4(), make_order(), response_tx2);
        handle
            .sender
            .send(task2)
            .await
            .expect("should send");

        assert_eq!(handle.approximate_depth(), 2);
    }

    #[rstest]
    #[case::capacity_1(1)]
    #[case::capacity_5(5)]
    #[case::capacity_10(10)]
    #[tokio::test]
    async fn test_is_full_when_at_capacity(#[case] capacity: usize) {
        let queue = TaskQueue::new(TaskQueueConfig { capacity });
        let handle = queue.handle();
        let _receiver = queue.into_receiver();

        // Fill the queue
        for _ in 0..capacity {
            let (response_tx, _response_rx) = oneshot::channel();
            let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);
            handle
                .sender
                .send(task)
                .await
                .expect("should send");
        }

        assert!(handle.is_full());
        assert_eq!(handle.approximate_depth(), capacity);
    }

    #[tokio::test]
    async fn test_is_full_becomes_false_after_task_consumed() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 2 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        // Fill queue
        let (tx1, _rx1) = oneshot::channel();
        let (tx2, _rx2) = oneshot::channel();
        handle
            .sender
            .send(SolveTask::new(Uuid::new_v4(), make_order(), tx1))
            .await
            .unwrap();
        handle
            .sender
            .send(SolveTask::new(Uuid::new_v4(), make_order(), tx2))
            .await
            .unwrap();

        assert!(handle.is_full());

        // Consume one task
        let _task = receiver.recv().await.unwrap();

        // Queue should no longer be full
        assert!(!handle.is_full());
        assert_eq!(handle.approximate_depth(), 1);
    }

    // -------------------------------------------------------------------------
    // Concurrent Operation Tests
    // -------------------------------------------------------------------------

    #[tokio::test]
    async fn test_multiple_handles_can_enqueue_concurrently() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle1 = queue.handle();
        let handle2 = queue.handle();
        let receiver = queue.into_receiver();

        // Spawn worker that processes multiple tasks
        let worker = tokio::spawn(async move {
            for _ in 0..2 {
                let task = receiver
                    .recv()
                    .await
                    .expect("should receive task");
                task.respond(Ok(make_single_quote()));
            }
        });

        // Enqueue from both handles concurrently
        let (result1, result2) =
            tokio::join!(handle1.enqueue(make_order()), handle2.enqueue(make_order()),);

        worker
            .await
            .expect("worker should complete");

        assert!(result1.is_ok());
        assert!(result2.is_ok());
    }

    #[tokio::test]
    async fn test_task_id_is_unique_per_enqueue() {
        let queue = TaskQueue::new(TaskQueueConfig { capacity: 10 });
        let handle = queue.handle();
        let receiver = queue.into_receiver();

        // Spawn workers to collect task IDs
        let collector = tokio::spawn(async move {
            let task1 = receiver.recv().await.unwrap();
            let id1 = task1.id();
            task1.respond(Ok(make_single_quote()));

            let task2 = receiver.recv().await.unwrap();
            let id2 = task2.id();
            task2.respond(Ok(make_single_quote()));

            (id1, id2)
        });

        // Enqueue two orders
        let _ = handle.enqueue(make_order()).await;
        let _ = handle.enqueue(make_order()).await;

        let (id1, id2): (Uuid, Uuid) = collector
            .await
            .expect("collector should complete");
        assert_ne!(id1, id2, "Task IDs should be unique");
    }

    // -------------------------------------------------------------------------
    // SolveTask Tests (internal type used by queue)
    // -------------------------------------------------------------------------

    #[test]
    fn test_solve_task_wait_time_increases() {
        let (response_tx, _response_rx) = oneshot::channel();
        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);

        let wait1 = task.wait_time();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let wait2 = task.wait_time();

        assert!(wait2 > wait1);
    }

    #[tokio::test]
    async fn test_solve_task_respond_delivers_result() {
        let (response_tx, response_rx) = oneshot::channel();
        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);

        task.respond(Ok(make_single_quote()));

        let result = response_rx
            .await
            .expect("should receive response");
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_solve_task_respond_delivers_error() {
        let (response_tx, response_rx) = oneshot::channel();
        let task = SolveTask::new(Uuid::new_v4(), make_order(), response_tx);

        task.respond(Err(SolveError::Timeout { elapsed_ms: 100 }));

        let result = response_rx
            .await
            .expect("should receive response");
        assert!(matches!(result, Err(SolveError::Timeout { elapsed_ms: 100 })));
    }
}