1use std::collections::HashSet;
2use std::collections::VecDeque;
3use std::sync::atomic::{AtomicU64, Ordering};
4use std::sync::Mutex;
5use std::time::{Duration, Instant};
6
7use bytes::Bytes;
8use http::{HeaderMap, StatusCode};
9use tokio::sync::Mutex as AsyncMutex;
10use tokio::sync::{oneshot, Notify, Semaphore};
11
12use crate::error::ScatterProxyError;
13
14#[derive(Debug)]
16pub struct ScatterResponse {
17 pub status: StatusCode,
18 pub headers: HeaderMap,
19 pub body: Bytes,
20}
21
22#[derive(Debug)]
30pub struct TaskHandle {
31 rx: AsyncMutex<oneshot::Receiver<ScatterResponse>>,
32}
33
34impl TaskHandle {
35 pub async fn with_timeout(
41 &self,
42 duration: Duration,
43 ) -> Result<Option<ScatterResponse>, ScatterProxyError> {
44 let mut rx = self.rx.lock().await;
45 match tokio::time::timeout(duration, &mut *rx).await {
46 Ok(Ok(resp)) => Ok(Some(resp)),
47 Ok(Err(_)) => Ok(Some(ScatterResponse {
48 status: StatusCode::BAD_GATEWAY,
49 headers: HeaderMap::new(),
50 body: Bytes::from_static(b"scatter-proxy: internal error - task channel closed"),
51 })),
52 Err(_) => Ok(None),
53 }
54 }
55}
56
57#[derive(Debug)]
59pub(crate) struct TaskEntry {
60 #[allow(dead_code)]
61 pub id: u64,
62 pub request: reqwest::Request,
63 pub host: String,
64 pub attempts: usize,
66 pub result_tx: Option<oneshot::Sender<ScatterResponse>>,
68 pub last_error: String,
70}
71
72#[derive(Debug)]
78struct DelayedTask {
79 ready_at: Instant,
80 entry: TaskEntry,
81}
82
83impl PartialEq for DelayedTask {
84 fn eq(&self, other: &Self) -> bool {
85 self.ready_at.eq(&other.ready_at)
86 }
87}
88
89impl Eq for DelayedTask {}
90
91impl PartialOrd for DelayedTask {
92 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
93 Some(self.cmp(other))
94 }
95}
96
97impl Ord for DelayedTask {
98 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
99 self.ready_at.cmp(&other.ready_at)
100 }
101}
102
103pub struct TaskPool {
104 queue: Mutex<VecDeque<TaskEntry>>,
105 delayed: Mutex<std::collections::BinaryHeap<std::cmp::Reverse<DelayedTask>>>,
106 capacity: usize,
107 capacity_sem: Semaphore,
111 next_id: AtomicU64,
112 notify: Notify,
114 completed: AtomicU64,
115 failed: AtomicU64,
116 requeued: AtomicU64,
117 zero_available: AtomicU64,
118 skipped_no_permit: AtomicU64,
119 skipped_rate_limit: AtomicU64,
120 skipped_cooldown: AtomicU64,
121 dispatches: AtomicU64,
122}
123
124impl TaskPool {
125 pub fn new(capacity: usize) -> Self {
127 Self {
128 queue: Mutex::new(VecDeque::new()),
129 delayed: Mutex::new(std::collections::BinaryHeap::new()),
130 capacity,
131 capacity_sem: Semaphore::new(capacity),
132 next_id: AtomicU64::new(1),
133 notify: Notify::new(),
134 completed: AtomicU64::new(0),
135 failed: AtomicU64::new(0),
136 requeued: AtomicU64::new(0),
137 zero_available: AtomicU64::new(0),
138 skipped_no_permit: AtomicU64::new(0),
139 skipped_rate_limit: AtomicU64::new(0),
140 skipped_cooldown: AtomicU64::new(0),
141 dispatches: AtomicU64::new(0),
142 }
143 }
144
145 pub async fn submit(&self, request: reqwest::Request) -> TaskHandle {
152 let permit = self
154 .capacity_sem
155 .acquire()
156 .await
157 .expect("capacity semaphore closed");
158 permit.forget(); self.enqueue(request)
161 }
162
163 pub fn try_submit(&self, request: reqwest::Request) -> Result<TaskHandle, ScatterProxyError> {
165 let permit = self
166 .capacity_sem
167 .try_acquire()
168 .map_err(|_| ScatterProxyError::PoolFull {
169 capacity: self.capacity,
170 })?;
171 permit.forget();
172 Ok(self.enqueue(request))
173 }
174
175 pub async fn submit_timeout(
180 &self,
181 request: reqwest::Request,
182 timeout: Duration,
183 ) -> Result<TaskHandle, ScatterProxyError> {
184 match tokio::time::timeout(timeout, self.submit(request)).await {
185 Ok(handle) => Ok(handle),
186 Err(_) => Err(ScatterProxyError::Timeout { elapsed: timeout }),
187 }
188 }
189
190 pub async fn submit_batch(&self, requests: Vec<reqwest::Request>) -> Vec<TaskHandle> {
195 let mut handles = Vec::with_capacity(requests.len());
196 for req in requests {
197 handles.push(self.submit(req).await);
198 }
199 handles
200 }
201
202 pub fn try_submit_batch(
205 &self,
206 requests: Vec<reqwest::Request>,
207 ) -> Result<Vec<TaskHandle>, ScatterProxyError> {
208 let count = requests.len();
209 if count == 0 {
210 return Ok(Vec::new());
211 }
212
213 let permit = self
215 .capacity_sem
216 .try_acquire_many(count as u32)
217 .map_err(|_| ScatterProxyError::PoolFull {
218 capacity: self.capacity,
219 })?;
220 permit.forget();
221
222 let mut handles = Vec::with_capacity(count);
223 for req in requests {
224 handles.push(self.enqueue(req));
225 }
226 Ok(handles)
227 }
228
229 fn enqueue(&self, request: reqwest::Request) -> TaskHandle {
235 let host = request.url().host_str().unwrap_or("unknown").to_string();
236 let (tx, rx) = oneshot::channel();
237 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
238
239 let entry = TaskEntry {
240 id,
241 request,
242 host,
243 attempts: 0,
244 result_tx: Some(tx),
245 last_error: String::new(),
246 };
247
248 {
249 self.promote_ready_delayed();
250 let mut queue = self.queue.lock().unwrap();
251 queue.push_back(entry);
252 }
253
254 self.notify.notify_one();
255 TaskHandle {
256 rx: AsyncMutex::new(rx),
257 }
258 }
259
260 pub(crate) fn promote_ready_delayed(&self) -> usize {
267 let now = Instant::now();
268 let mut delayed = self.delayed.lock().unwrap();
269 if delayed.is_empty() {
270 return 0;
271 }
272 let mut ready = Vec::new();
273 while let Some(std::cmp::Reverse(item)) = delayed.peek() {
274 if item.ready_at <= now {
275 let std::cmp::Reverse(item) = delayed.pop().expect("heap peeked item must pop");
276 ready.push(item.entry);
277 } else {
278 break;
279 }
280 }
281 drop(delayed);
282 if ready.is_empty() {
283 return 0;
284 }
285 let count = ready.len();
286 let mut queue = self.queue.lock().unwrap();
287 for entry in ready {
288 queue.push_back(entry);
289 }
290 count
291 }
292
293 pub(crate) fn next_delayed_ready_in(&self) -> Option<Duration> {
294 let delayed = self.delayed.lock().unwrap();
295 let now = Instant::now();
296 delayed
297 .peek()
298 .map(|d| d.0.ready_at.saturating_duration_since(now))
299 }
300
301 pub(crate) fn pick_next(&self, skip_hosts: &HashSet<String>) -> Option<TaskEntry> {
302 let mut queue = self.queue.lock().unwrap();
303 if skip_hosts.is_empty() {
304 return queue.pop_front();
305 }
306
307 let len = queue.len();
308 for _ in 0..len {
309 let entry = queue.pop_front()?;
310 if !skip_hosts.contains(&entry.host) {
311 return Some(entry);
312 }
313 queue.push_back(entry);
314 }
315
316 None
317 }
318
319 pub(crate) fn push_back(&self, entry: TaskEntry) {
321 self.requeued.fetch_add(1, Ordering::Relaxed);
322 {
323 let mut queue = self.queue.lock().unwrap();
324 queue.push_back(entry);
325 }
326 self.notify.notify_one();
327 }
328
329 pub(crate) fn push_delayed(&self, entry: TaskEntry, delay: Duration) {
330 self.requeued.fetch_add(1, Ordering::Relaxed);
331 {
332 let mut delayed = self.delayed.lock().unwrap();
333 delayed.push(std::cmp::Reverse(DelayedTask {
334 ready_at: Instant::now() + delay,
335 entry,
336 }));
337 }
338 self.notify.notify_one();
339 }
340
341 pub fn pending_count(&self) -> usize {
343 let queue = self.queue.lock().unwrap();
344 queue.len()
345 }
346
347 pub fn delayed_count(&self) -> usize {
348 let delayed = self.delayed.lock().unwrap();
349 delayed.len()
350 }
351
352 pub fn completed_count(&self) -> u64 {
354 self.completed.load(Ordering::Relaxed)
355 }
356
357 pub(crate) fn mark_completed(&self) {
360 self.completed.fetch_add(1, Ordering::Relaxed);
361 self.capacity_sem.add_permits(1);
362 }
363
364 pub(crate) fn mark_failed(&self) {
367 self.failed.fetch_add(1, Ordering::Relaxed);
368 self.capacity_sem.add_permits(1);
369 }
370
371 pub fn failed_count(&self) -> u64 {
373 self.failed.load(Ordering::Relaxed)
374 }
375 pub fn requeued_count(&self) -> u64 {
376 self.requeued.load(Ordering::Relaxed)
377 }
378
379 pub(crate) fn mark_zero_available(&self) {
380 self.zero_available.fetch_add(1, Ordering::Relaxed);
381 }
382
383 pub fn zero_available_count(&self) -> u64 {
384 self.zero_available.load(Ordering::Relaxed)
385 }
386
387 pub(crate) fn mark_skipped_no_permit(&self) {
388 self.skipped_no_permit.fetch_add(1, Ordering::Relaxed);
389 }
390
391 pub fn skipped_no_permit_count(&self) -> u64 {
392 self.skipped_no_permit.load(Ordering::Relaxed)
393 }
394
395 pub(crate) fn mark_skipped_rate_limit(&self) {
396 self.skipped_rate_limit.fetch_add(1, Ordering::Relaxed);
397 }
398
399 pub fn skipped_rate_limit_count(&self) -> u64 {
400 self.skipped_rate_limit.load(Ordering::Relaxed)
401 }
402
403 pub(crate) fn mark_skipped_cooldown(&self) {
404 self.skipped_cooldown.fetch_add(1, Ordering::Relaxed);
405 }
406
407 pub fn skipped_cooldown_count(&self) -> u64 {
408 self.skipped_cooldown.load(Ordering::Relaxed)
409 }
410
411 pub(crate) fn mark_dispatch(&self) {
412 self.dispatches.fetch_add(1, Ordering::Relaxed);
413 }
414
415 pub fn dispatch_count(&self) -> u64 {
416 self.dispatches.load(Ordering::Relaxed)
417 }
418
419 #[allow(dead_code)]
421 pub(crate) async fn notified(&self) {
422 self.notify.notified().await;
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a GET request for an arbitrary URL.
    fn request_to(url: &str) -> reqwest::Request {
        reqwest::Client::new().get(url).build().unwrap()
    }

    /// Default request used by most tests; its host is `example.com`.
    fn test_request() -> reqwest::Request {
        request_to("http://example.com/test")
    }

    /// A skip set that excludes nothing.
    fn no_skip() -> HashSet<String> {
        HashSet::new()
    }

    /// A skip set containing only `example.com`.
    fn skip_example() -> HashSet<String> {
        let mut hosts = HashSet::new();
        hosts.insert(String::from("example.com"));
        hosts
    }

    /// A `200 OK` response with the given static body.
    fn ok_response(body: &'static [u8]) -> ScatterResponse {
        ScatterResponse {
            status: StatusCode::OK,
            headers: HeaderMap::new(),
            body: Bytes::from_static(body),
        }
    }

    // ---- construction ----------------------------------------------------

    #[test]
    fn new_pool_has_zero_pending() {
        let pool = TaskPool::new(10);
        assert_eq!(pool.pending_count(), 0);
        assert_eq!(pool.delayed_count(), 0);
        assert_eq!(pool.completed_count(), 0);
    }

    // ---- try_submit ------------------------------------------------------

    #[test]
    fn try_submit_increments_pending_count() {
        let pool = TaskPool::new(10);
        let _first = pool.try_submit(test_request()).unwrap();
        let _second = pool.try_submit(test_request()).unwrap();
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn try_submit_returns_pool_full_when_at_capacity() {
        let pool = TaskPool::new(2);
        let _first = pool.try_submit(test_request()).unwrap();
        let _second = pool.try_submit(test_request()).unwrap();

        let outcome = pool.try_submit(test_request());
        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            ScatterProxyError::PoolFull { capacity } => assert_eq!(capacity, 2),
            other => panic!("expected PoolFull, got {other:?}"),
        }
    }

    #[test]
    fn try_submit_assigns_incrementing_ids() {
        let pool = TaskPool::new(10);
        let _first = pool.try_submit(test_request()).unwrap();
        let _second = pool.try_submit(test_request()).unwrap();

        let skip = no_skip();
        let earlier = pool.pick_next(&skip).unwrap();
        let later = pool.pick_next(&skip).unwrap();
        assert!(later.id > earlier.id);
    }

    #[test]
    fn try_submit_extracts_host_from_url() {
        let pool = TaskPool::new(10);
        let _handle = pool.try_submit(test_request()).unwrap();
        let task = pool.pick_next(&no_skip()).unwrap();
        assert_eq!(task.host, "example.com");
    }

    // ---- try_submit_batch ------------------------------------------------

    #[test]
    fn try_submit_batch_adds_all_tasks() {
        let pool = TaskPool::new(10);
        let batch = vec![test_request(), test_request(), test_request()];
        let handles = pool.try_submit_batch(batch).unwrap();
        assert_eq!(handles.len(), 3);
        assert_eq!(pool.pending_count(), 3);
    }

    #[test]
    fn try_submit_batch_atomic_rejection_when_pool_full() {
        let pool = TaskPool::new(2);
        let batch = vec![test_request(), test_request(), test_request()];
        assert!(pool.try_submit_batch(batch).is_err());
        // Nothing from the rejected batch may have been enqueued.
        assert_eq!(pool.pending_count(), 0);
    }

    #[test]
    fn try_submit_batch_empty_vec_is_ok() {
        let pool = TaskPool::new(10);
        let handles = pool.try_submit_batch(Vec::new()).unwrap();
        assert!(handles.is_empty());
    }

    // ---- async submit ----------------------------------------------------

    #[tokio::test]
    async fn submit_blocks_then_proceeds_after_mark_completed() {
        let pool = std::sync::Arc::new(TaskPool::new(1));
        let _first = pool.try_submit(test_request()).unwrap();

        let background = {
            let pool = std::sync::Arc::clone(&pool);
            tokio::spawn(async move {
                let _handle = pool.submit(test_request()).await;
            })
        };

        // The spawned submit must still be waiting for a slot.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(pool.pending_count(), 1);

        // Finish the first task to free the slot.
        let finished = pool.pick_next(&no_skip()).unwrap();
        pool.mark_completed();
        drop(finished);

        background.await.unwrap();
        assert_eq!(pool.pending_count(), 1);
    }

    #[tokio::test]
    async fn submit_timeout_returns_err_on_expiry() {
        let pool = TaskPool::new(1);
        let _first = pool.try_submit(test_request()).unwrap();

        let limit = Duration::from_millis(50);
        let outcome = pool.submit_timeout(test_request(), limit).await;
        assert!(outcome.is_err());
        match outcome.unwrap_err() {
            ScatterProxyError::Timeout { elapsed } => {
                assert_eq!(elapsed, Duration::from_millis(50));
            }
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn submit_batch_processes_all() {
        let pool = TaskPool::new(10);
        let handles = pool.submit_batch(vec![test_request(), test_request()]).await;
        assert_eq!(handles.len(), 2);
        assert_eq!(pool.pending_count(), 2);
    }

    // ---- pick_next -------------------------------------------------------

    #[test]
    fn pick_next_returns_fifo_order() {
        let pool = TaskPool::new(10);
        let _first = pool.try_submit(test_request()).unwrap();
        let _second = pool.try_submit(test_request()).unwrap();

        let skip = no_skip();
        let head = pool.pick_next(&skip).unwrap();
        let next = pool.pick_next(&skip).unwrap();
        assert!(head.id < next.id);
    }

    #[test]
    fn pick_next_skips_circuit_broken_hosts() {
        let pool = TaskPool::new(10);
        let _handle = pool.try_submit(test_request()).unwrap();
        assert!(pool.pick_next(&skip_example()).is_none());
    }

    #[test]
    fn pick_next_returns_none_when_all_hosts_skipped() {
        let pool = TaskPool::new(10);
        let _first = pool.try_submit(test_request()).unwrap();
        let _second = pool.try_submit(test_request()).unwrap();

        assert!(pool.pick_next(&skip_example()).is_none());
        // Skipped tasks must remain queued.
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn pick_next_returns_none_when_empty() {
        let pool = TaskPool::new(10);
        assert!(pool.pick_next(&no_skip()).is_none());
    }

    #[test]
    fn pick_next_selects_first_non_skipped_preserves_order() {
        let pool = TaskPool::new(10);
        let _first = pool.try_submit(test_request()).unwrap();
        let _second = pool
            .try_submit(request_to("http://other.com/path"))
            .unwrap();
        let _third = pool.try_submit(test_request()).unwrap();

        let picked = pool.pick_next(&skip_example()).unwrap();
        assert_eq!(picked.host, "other.com");
        assert_eq!(pool.pending_count(), 2);
    }

    // ---- re-queueing -----------------------------------------------------

    #[test]
    fn push_back_requeues_to_tail() {
        let pool = TaskPool::new(10);
        let _first = pool.try_submit(test_request()).unwrap();
        let _second = pool.try_submit(test_request()).unwrap();

        let skip = no_skip();
        let head = pool.pick_next(&skip).unwrap();
        let head_id = head.id;
        pool.push_back(head);
        assert_eq!(pool.requeued_count(), 1);

        // The untouched task is served first, the re-queued one after it.
        let served_first = pool.pick_next(&skip).unwrap();
        let served_second = pool.pick_next(&skip).unwrap();
        assert!(served_first.id > head_id);
        assert_eq!(served_second.id, head_id);
    }

    #[test]
    fn delayed_task_promotes_when_ready() {
        let pool = TaskPool::new(10);
        let _ = pool.try_submit(test_request()).unwrap();
        let task = pool.pick_next(&no_skip()).unwrap();

        pool.push_delayed(task, Duration::from_millis(10));
        assert_eq!(pool.delayed_count(), 1);

        std::thread::sleep(Duration::from_millis(20));
        assert_eq!(pool.promote_ready_delayed(), 1);
        assert_eq!(pool.delayed_count(), 0);
        assert_eq!(pool.pending_count(), 1);
    }

    // ---- counters --------------------------------------------------------

    #[test]
    fn mark_completed_increments_counter() {
        let pool = TaskPool::new(10);
        pool.mark_completed();
        assert_eq!(pool.completed_count(), 1);
    }

    #[test]
    fn mark_failed_increments_counter() {
        let pool = TaskPool::new(10);
        pool.mark_failed();
        assert_eq!(pool.failed_count(), 1);
    }

    // ---- TaskHandle ------------------------------------------------------

    #[tokio::test]
    async fn task_handle_receives_success() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let mut task = pool.pick_next(&no_skip()).unwrap();
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(ok_response(b"hello"));
        }

        let resp = handle
            .with_timeout(Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.body.as_ref(), b"hello");
    }

    #[tokio::test]
    async fn task_handle_returns_502_when_sender_dropped() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        // Dropping the entry drops the sender without a response.
        drop(pool.pick_next(&no_skip()).unwrap());

        let resp = handle
            .with_timeout(Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(resp.status, StatusCode::BAD_GATEWAY);
    }

    #[tokio::test]
    async fn task_handle_with_timeout_ok() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let mut task = pool.pick_next(&no_skip()).unwrap();
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(ok_response(b"ok"));
        }

        let resp = handle
            .with_timeout(Duration::from_secs(5))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(resp.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn task_handle_with_timeout_expires() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let outcome = handle
            .with_timeout(Duration::from_millis(50))
            .await
            .unwrap();
        assert!(outcome.is_none());
    }

    // ---- notifications ---------------------------------------------------

    #[tokio::test]
    async fn notified_wakes_on_try_submit() {
        let pool = std::sync::Arc::new(TaskPool::new(10));

        let waiter = {
            let pool = std::sync::Arc::clone(&pool);
            tokio::spawn(async move {
                pool.notified().await;
                true
            })
        };

        tokio::time::sleep(Duration::from_millis(20)).await;
        let _handle = pool.try_submit(test_request()).unwrap();

        assert!(waiter.await.unwrap());
    }

    #[tokio::test]
    async fn notified_wakes_on_push_back() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let _handle = pool.try_submit(test_request()).unwrap();
        let task = pool.pick_next(&no_skip()).unwrap();

        let waiter = {
            let pool = std::sync::Arc::clone(&pool);
            tokio::spawn(async move {
                pool.notified().await;
                true
            })
        };

        tokio::time::sleep(Duration::from_millis(20)).await;
        pool.push_back(task);

        assert!(waiter.await.unwrap());
    }

    // ---- capacity edge cases ---------------------------------------------

    #[test]
    fn pool_with_zero_capacity_rejects_everything() {
        let pool = TaskPool::new(0);
        assert!(pool.try_submit(test_request()).is_err());
    }

    #[test]
    fn pool_allows_try_submit_after_mark_completed_frees_space() {
        let pool = TaskPool::new(1);
        let _first = pool.try_submit(test_request()).unwrap();
        assert!(pool.try_submit(test_request()).is_err());

        let _task = pool.pick_next(&no_skip()).unwrap();
        pool.mark_completed();

        let _second = pool.try_submit(test_request()).unwrap();
    }

    #[test]
    fn task_entry_has_correct_defaults_on_try_submit() {
        let pool = TaskPool::new(10);
        let _handle = pool.try_submit(test_request()).unwrap();

        let task = pool.pick_next(&no_skip()).unwrap();
        assert_eq!(task.attempts, 0);
        assert!(task.last_error.is_empty());
        assert!(task.result_tx.is_some());
    }

    #[test]
    fn scatter_response_debug() {
        let rendered = format!("{:?}", ok_response(b"test"));
        assert!(rendered.contains("200"));
    }
}