1use std::collections::HashSet;
2use std::collections::VecDeque;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::atomic::{AtomicU64, Ordering};
6use std::sync::Mutex;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use bytes::Bytes;
11use http::{HeaderMap, StatusCode};
12use tokio::sync::{oneshot, Notify, Semaphore};
13
14use crate::error::ScatterProxyError;
15
16#[derive(Debug)]
18pub struct ScatterResponse {
19 pub status: StatusCode,
20 pub headers: HeaderMap,
21 pub body: Bytes,
22}
23
24#[derive(Debug)]
32pub struct TaskHandle {
33 rx: oneshot::Receiver<ScatterResponse>,
34}
35
36impl TaskHandle {
37 pub async fn with_timeout(
46 self,
47 duration: Duration,
48 ) -> Result<ScatterResponse, ScatterProxyError> {
49 match tokio::time::timeout(duration, self).await {
50 Ok(resp) => Ok(resp),
51 Err(_) => Err(ScatterProxyError::Timeout { elapsed: duration }),
52 }
53 }
54}
55
56impl Future for TaskHandle {
57 type Output = ScatterResponse;
58
59 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
60 match Pin::new(&mut self.rx).poll(cx) {
61 Poll::Ready(Ok(resp)) => Poll::Ready(resp),
62 Poll::Ready(Err(_)) => {
63 Poll::Ready(ScatterResponse {
66 status: StatusCode::BAD_GATEWAY,
67 headers: HeaderMap::new(),
68 body: Bytes::from_static(
69 b"scatter-proxy: internal error - task channel closed",
70 ),
71 })
72 }
73 Poll::Pending => Poll::Pending,
74 }
75 }
76}
77
78pub(crate) struct TaskEntry {
80 #[allow(dead_code)]
81 pub id: u64,
82 pub request: reqwest::Request,
83 pub host: String,
84 pub attempts: usize,
86 pub result_tx: Option<oneshot::Sender<ScatterResponse>>,
88 pub last_error: String,
90}
91
92pub struct TaskPool {
98 queue: Mutex<VecDeque<TaskEntry>>,
99 capacity: usize,
100 capacity_sem: Semaphore,
104 next_id: AtomicU64,
105 notify: Notify,
107 completed: AtomicU64,
108 failed: AtomicU64,
109}
110
111impl TaskPool {
112 pub fn new(capacity: usize) -> Self {
114 Self {
115 queue: Mutex::new(VecDeque::new()),
116 capacity,
117 capacity_sem: Semaphore::new(capacity),
118 next_id: AtomicU64::new(1),
119 notify: Notify::new(),
120 completed: AtomicU64::new(0),
121 failed: AtomicU64::new(0),
122 }
123 }
124
125 pub async fn submit(&self, request: reqwest::Request) -> TaskHandle {
132 let permit = self
134 .capacity_sem
135 .acquire()
136 .await
137 .expect("capacity semaphore closed");
138 permit.forget(); self.enqueue(request)
141 }
142
143 pub fn try_submit(&self, request: reqwest::Request) -> Result<TaskHandle, ScatterProxyError> {
145 let permit = self
146 .capacity_sem
147 .try_acquire()
148 .map_err(|_| ScatterProxyError::PoolFull {
149 capacity: self.capacity,
150 })?;
151 permit.forget();
152 Ok(self.enqueue(request))
153 }
154
155 pub async fn submit_timeout(
160 &self,
161 request: reqwest::Request,
162 timeout: Duration,
163 ) -> Result<TaskHandle, ScatterProxyError> {
164 match tokio::time::timeout(timeout, self.submit(request)).await {
165 Ok(handle) => Ok(handle),
166 Err(_) => Err(ScatterProxyError::Timeout { elapsed: timeout }),
167 }
168 }
169
170 pub async fn submit_batch(&self, requests: Vec<reqwest::Request>) -> Vec<TaskHandle> {
175 let mut handles = Vec::with_capacity(requests.len());
176 for req in requests {
177 handles.push(self.submit(req).await);
178 }
179 handles
180 }
181
182 pub fn try_submit_batch(
185 &self,
186 requests: Vec<reqwest::Request>,
187 ) -> Result<Vec<TaskHandle>, ScatterProxyError> {
188 let count = requests.len();
189 if count == 0 {
190 return Ok(Vec::new());
191 }
192
193 let permit = self
195 .capacity_sem
196 .try_acquire_many(count as u32)
197 .map_err(|_| ScatterProxyError::PoolFull {
198 capacity: self.capacity,
199 })?;
200 permit.forget();
201
202 let mut handles = Vec::with_capacity(count);
203 for req in requests {
204 handles.push(self.enqueue(req));
205 }
206 Ok(handles)
207 }
208
209 fn enqueue(&self, request: reqwest::Request) -> TaskHandle {
215 let host = request.url().host_str().unwrap_or("unknown").to_string();
216 let (tx, rx) = oneshot::channel();
217 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
218
219 let entry = TaskEntry {
220 id,
221 request,
222 host,
223 attempts: 0,
224 result_tx: Some(tx),
225 last_error: String::new(),
226 };
227
228 {
229 let mut queue = self.queue.lock().unwrap();
230 queue.push_back(entry);
231 }
232
233 self.notify.notify_one();
234 TaskHandle { rx }
235 }
236
237 pub(crate) fn pick_next(&self, skip_hosts: &HashSet<String>) -> Option<TaskEntry> {
244 let mut queue = self.queue.lock().unwrap();
245 let len = queue.len();
246
247 for i in 0..len {
248 if let Some(entry) = queue.get(i) {
249 if !skip_hosts.contains(&entry.host) {
250 return queue.remove(i);
251 }
252 }
253 }
254
255 None
256 }
257
258 pub(crate) fn push_back(&self, entry: TaskEntry) {
260 {
261 let mut queue = self.queue.lock().unwrap();
262 queue.push_back(entry);
263 }
264 self.notify.notify_one();
265 }
266
267 pub fn pending_count(&self) -> usize {
269 let queue = self.queue.lock().unwrap();
270 queue.len()
271 }
272
273 pub fn completed_count(&self) -> u64 {
275 self.completed.load(Ordering::Relaxed)
276 }
277
278 pub(crate) fn mark_completed(&self) {
281 self.completed.fetch_add(1, Ordering::Relaxed);
282 self.capacity_sem.add_permits(1);
283 }
284
285 pub(crate) fn mark_failed(&self) {
288 self.failed.fetch_add(1, Ordering::Relaxed);
289 self.capacity_sem.add_permits(1);
290 }
291
292 pub fn failed_count(&self) -> u64 {
294 self.failed.load(Ordering::Relaxed)
295 }
296
297 #[allow(dead_code)]
299 pub(crate) async fn notified(&self) {
300 self.notify.notified().await;
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a simple GET request to `example.com`.
    fn test_request() -> reqwest::Request {
        reqwest::Client::new()
            .get("http://example.com/test")
            .build()
            .unwrap()
    }

    #[test]
    fn new_pool_has_zero_pending() {
        let pool = TaskPool::new(10);
        assert_eq!(pool.pending_count(), 0);
        assert_eq!(pool.completed_count(), 0);
        assert_eq!(pool.failed_count(), 0);
    }

    #[test]
    fn try_submit_increments_pending_count() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn try_submit_returns_pool_full_when_at_capacity() {
        let pool = TaskPool::new(2);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let result = pool.try_submit(test_request());
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::PoolFull { capacity } => assert_eq!(capacity, 2),
            other => panic!("expected PoolFull, got {other:?}"),
        }
    }

    #[test]
    fn try_submit_assigns_incrementing_ids() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t2.id > t1.id);
    }

    #[test]
    fn try_submit_extracts_host_from_url() {
        let pool = TaskPool::new(10);
        let _h = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        assert_eq!(task.host, "example.com");
    }

    #[test]
    fn try_submit_batch_adds_all_tasks() {
        let pool = TaskPool::new(10);
        let reqs = vec![test_request(), test_request(), test_request()];
        let handles = pool.try_submit_batch(reqs).unwrap();
        assert_eq!(handles.len(), 3);
        assert_eq!(pool.pending_count(), 3);
    }

    #[test]
    fn try_submit_batch_atomic_rejection_when_pool_full() {
        let pool = TaskPool::new(2);
        let reqs = vec![test_request(), test_request(), test_request()];
        let result = pool.try_submit_batch(reqs);
        assert!(result.is_err());
        assert_eq!(pool.pending_count(), 0);
    }

    #[test]
    fn try_submit_batch_rejects_when_remaining_capacity_too_small() {
        let pool = TaskPool::new(3);
        let _h1 = pool.try_submit(test_request()).unwrap();

        let reqs = vec![test_request(), test_request(), test_request()];
        match pool.try_submit_batch(reqs) {
            Err(ScatterProxyError::PoolFull { capacity }) => assert_eq!(capacity, 3),
            other => panic!("expected PoolFull, got {other:?}"),
        }
        assert_eq!(pool.pending_count(), 1);

        // The rejected batch must not have consumed any of the remaining permits.
        let handles = pool
            .try_submit_batch(vec![test_request(), test_request()])
            .unwrap();
        assert_eq!(handles.len(), 2);
        assert_eq!(pool.pending_count(), 3);
    }

    #[test]
    fn try_submit_batch_empty_vec_is_ok() {
        let pool = TaskPool::new(10);
        let handles = pool.try_submit_batch(vec![]).unwrap();
        assert!(handles.is_empty());
    }

    #[tokio::test]
    async fn submit_blocks_then_proceeds_after_mark_completed() {
        let pool = std::sync::Arc::new(TaskPool::new(1));
        let _h1 = pool.try_submit(test_request()).unwrap();

        let pool2 = std::sync::Arc::clone(&pool);
        let join = tokio::spawn(async move {
            let _handle = pool2.submit(test_request()).await;
        });

        // Give the spawned submit a chance to run; it must still be blocked on capacity.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(pool.pending_count(), 1);

        {
            let skip = HashSet::new();
            let _task = pool.pick_next(&skip).unwrap();
            pool.mark_completed();
        }

        join.await.unwrap();
        assert_eq!(pool.pending_count(), 1);
    }

    #[tokio::test]
    async fn submit_timeout_returns_err_on_expiry() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();

        let result = pool
            .submit_timeout(test_request(), Duration::from_millis(50))
            .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::Timeout { elapsed } => {
                assert_eq!(elapsed, Duration::from_millis(50));
            }
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn submit_batch_processes_all() {
        let pool = TaskPool::new(10);
        let reqs = vec![test_request(), test_request()];
        let handles = pool.submit_batch(reqs).await;
        assert_eq!(handles.len(), 2);
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn pick_next_returns_fifo_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t1.id < t2.id);
    }

    #[test]
    fn pick_next_skips_circuit_broken_hosts() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();

        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn pick_next_returns_none_when_all_hosts_skipped() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        assert!(pool.pick_next(&skip).is_none());
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn pick_next_returns_none_when_empty() {
        let pool = TaskPool::new(10);
        let skip = HashSet::new();
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn pick_next_selects_first_non_skipped_preserves_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let req2 = reqwest::Client::new()
            .get("http://other.com/path")
            .build()
            .unwrap();
        let _h2 = pool.try_submit(req2).unwrap();
        let _h3 = pool.try_submit(test_request()).unwrap();

        let mut skip = HashSet::new();
        skip.insert("example.com".into());

        let picked = pool.pick_next(&skip).unwrap();
        assert_eq!(picked.host, "other.com");
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn push_back_requeues_to_tail() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let id1 = t1.id;
        pool.push_back(t1);

        // After re-queuing, the second task is now at the head and t1 at the tail.
        let t2 = pool.pick_next(&skip).unwrap();
        let re_t1 = pool.pick_next(&skip).unwrap();
        assert!(t2.id > id1);
        assert_eq!(re_t1.id, id1);
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn mark_completed_increments_counter() {
        let pool = TaskPool::new(10);
        pool.mark_completed();
        assert_eq!(pool.completed_count(), 1);
    }

    #[test]
    fn mark_failed_increments_counter_and_frees_capacity() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();
        assert!(pool.try_submit(test_request()).is_err());

        let skip = HashSet::new();
        let _task = pool.pick_next(&skip).unwrap();
        pool.mark_failed();

        assert_eq!(pool.failed_count(), 1);
        assert_eq!(pool.completed_count(), 0);
        let _h2 = pool.try_submit(test_request()).unwrap();
    }

    #[tokio::test]
    async fn task_handle_receives_success() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let mut task = pool.pick_next(&skip).unwrap();
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(ScatterResponse {
                status: StatusCode::OK,
                headers: HeaderMap::new(),
                body: Bytes::from_static(b"hello"),
            });
        }

        let resp = handle.await;
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.body.as_ref(), b"hello");
    }

    #[tokio::test]
    async fn task_handle_returns_502_when_sender_dropped() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        // Dropping the entry drops its result sender without sending.
        drop(task);

        let resp = handle.await;
        assert_eq!(resp.status, StatusCode::BAD_GATEWAY);
    }

    #[tokio::test]
    async fn task_handle_with_timeout_ok() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let mut task = pool.pick_next(&skip).unwrap();
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(ScatterResponse {
                status: StatusCode::OK,
                headers: HeaderMap::new(),
                body: Bytes::from_static(b"ok"),
            });
        }

        let resp = handle.with_timeout(Duration::from_secs(5)).await.unwrap();
        assert_eq!(resp.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn task_handle_with_timeout_expires() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();

        let result = handle.with_timeout(Duration::from_millis(50)).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn notified_wakes_on_try_submit() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let pool2 = std::sync::Arc::clone(&pool);

        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });

        tokio::time::sleep(Duration::from_millis(20)).await;
        let _h = pool.try_submit(test_request()).unwrap();

        assert!(waiter.await.unwrap());
    }

    #[tokio::test]
    async fn notified_wakes_on_push_back() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let _h = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();

        let pool2 = std::sync::Arc::clone(&pool);
        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });

        tokio::time::sleep(Duration::from_millis(20)).await;
        pool.push_back(task);

        assert!(waiter.await.unwrap());
    }

    #[test]
    fn pool_with_zero_capacity_rejects_everything() {
        let pool = TaskPool::new(0);
        let result = pool.try_submit(test_request());
        assert!(result.is_err());
    }

    #[test]
    fn pool_allows_try_submit_after_mark_completed_frees_space() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();
        assert!(pool.try_submit(test_request()).is_err());

        let skip = HashSet::new();
        let _task = pool.pick_next(&skip).unwrap();
        pool.mark_completed();

        let _h2 = pool.try_submit(test_request()).unwrap();
    }

    #[test]
    fn task_entry_has_correct_defaults_on_try_submit() {
        let pool = TaskPool::new(10);
        let _h = pool.try_submit(test_request()).unwrap();

        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        assert_eq!(task.attempts, 0);
        assert!(task.last_error.is_empty());
        assert!(task.result_tx.is_some());
    }

    #[test]
    fn scatter_response_debug() {
        let resp = ScatterResponse {
            status: StatusCode::OK,
            headers: HeaderMap::new(),
            body: Bytes::from_static(b"test"),
        };
        let dbg = format!("{resp:?}");
        assert!(dbg.contains("200"));
    }
}