use std::collections::HashSet;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
use std::time::{Duration, Instant};
use bytes::Bytes;
use http::{HeaderMap, StatusCode};
use tokio::sync::Mutex as AsyncMutex;
use tokio::sync::{oneshot, Notify, Semaphore};
use crate::error::ScatterProxyError;
/// Response delivered to the submitter of a task through its [`TaskHandle`].
#[derive(Debug)]
pub struct ScatterResponse {
/// HTTP status code of the response.
pub status: StatusCode,
/// HTTP headers of the response.
pub headers: HeaderMap,
/// Raw response body bytes.
pub body: Bytes,
}
/// Handle returned on submission; used to wait for the task's [`ScatterResponse`].
#[derive(Debug)]
pub struct TaskHandle {
// Receiving half of the per-task oneshot channel. It is wrapped in an async
// mutex so that `with_timeout` can poll it through a shared `&self` reference.
rx: AsyncMutex<oneshot::Receiver<ScatterResponse>>,
}
impl TaskHandle {
pub async fn with_timeout(
&self,
duration: Duration,
) -> Result<Option<ScatterResponse>, ScatterProxyError> {
let mut rx = self.rx.lock().await;
match tokio::time::timeout(duration, &mut *rx).await {
Ok(Ok(resp)) => Ok(Some(resp)),
Ok(Err(_)) => Ok(Some(ScatterResponse {
status: StatusCode::BAD_GATEWAY,
headers: HeaderMap::new(),
body: Bytes::from_static(b"scatter-proxy: internal error - task channel closed"),
})),
Err(_) => Ok(None),
}
}
}
/// A queued request together with the bookkeeping needed to dispatch it and
/// report its result back to the submitter.
#[derive(Debug)]
pub(crate) struct TaskEntry {
/// Identifier assigned from the pool's `next_id` counter at enqueue time.
#[allow(dead_code)]
pub id: u64,
/// The HTTP request to execute.
pub request: reqwest::Request,
/// Host taken from the request URL at enqueue time (`"unknown"` if the URL has none);
/// `TaskPool::pick_next` matches it against the set of hosts to skip.
pub host: String,
/// Starts at 0 on enqueue. Presumably the number of dispatch attempts so far;
/// it is updated outside this file — confirm against the dispatcher.
pub attempts: usize,
/// Sending half of the result channel; `Option` so the sender can be taken
/// out of the entry and consumed when the response is delivered.
pub result_tx: Option<oneshot::Sender<ScatterResponse>>,
/// Starts empty on enqueue. Presumably a description of the most recent
/// failure; it is written outside this file — confirm against the dispatcher.
pub last_error: String,
}
/// A task parked in the delayed heap until `ready_at` has passed.
#[derive(Debug)]
struct DelayedTask {
/// Earliest instant at which the task may be moved back to the ready queue.
ready_at: Instant,
/// The parked task itself.
entry: TaskEntry,
}
/// Equality looks only at `ready_at`; the wrapped entry is ignored so that
/// equality agrees with the `Ord` implementation used by the delayed heap.
impl PartialEq for DelayedTask {
    fn eq(&self, other: &Self) -> bool {
        self.ready_at == other.ready_at
    }
}
// Marker impl: `Eq` is a supertrait of `Ord`, which `BinaryHeap` requires.
impl Eq for DelayedTask {}
/// Partial ordering simply delegates to the total order defined by `Ord`.
impl PartialOrd for DelayedTask {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Tasks are ordered by `ready_at` alone: an earlier deadline compares as smaller.
impl Ord for DelayedTask {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.ready_at, &other.ready_at)
    }
}
/// Bounded pool of pending HTTP tasks with a ready queue, a delayed-retry heap
/// and a set of statistics counters.
pub struct TaskPool {
// FIFO of tasks that are ready to be picked for dispatch.
queue: Mutex<VecDeque<TaskEntry>>,
// Tasks waiting for their `ready_at` deadline; wrapping in `Reverse` turns the
// max-heap into a min-heap so the earliest deadline is at the top.
delayed: Mutex<std::collections::BinaryHeap<std::cmp::Reverse<DelayedTask>>>,
// Configured capacity; reported back in `PoolFull` errors.
capacity: usize,
// One permit per free slot. Submission acquires and forgets a permit;
// `mark_completed` / `mark_failed` add one back.
capacity_sem: Semaphore,
// Source of task ids; the first id handed out is 1.
next_id: AtomicU64,
// Signalled whenever a task is enqueued, pushed back or pushed as delayed.
notify: Notify,
// Number of `mark_completed` calls.
completed: AtomicU64,
// Number of `mark_failed` calls.
failed: AtomicU64,
// Number of `push_back` / `push_delayed` calls.
requeued: AtomicU64,
// The counters below are only incremented by their `mark_*` method and read by
// the matching `*_count` getter; when they are bumped is decided by callers
// outside this file.
zero_available: AtomicU64,
skipped_no_permit: AtomicU64,
skipped_rate_limit: AtomicU64,
skipped_cooldown: AtomicU64,
dispatches: AtomicU64,
}
/// Submission, scheduling and statistics API of the pool.
///
/// All methods that touch the ready queue or the delayed heap lock a
/// `std::sync::Mutex` and `unwrap()` the result, so they panic if a previous
/// holder of that lock panicked (poisoned mutex).
impl TaskPool {
    /// Creates an empty pool with `capacity` free slots.
    ///
    /// A slot is taken by every successful submission and given back by
    /// `mark_completed` or `mark_failed`.
    pub fn new(capacity: usize) -> Self {
        Self {
            queue: Mutex::new(VecDeque::new()),
            delayed: Mutex::new(std::collections::BinaryHeap::new()),
            capacity,
            capacity_sem: Semaphore::new(capacity),
            next_id: AtomicU64::new(1),
            notify: Notify::new(),
            completed: AtomicU64::new(0),
            failed: AtomicU64::new(0),
            requeued: AtomicU64::new(0),
            zero_available: AtomicU64::new(0),
            skipped_no_permit: AtomicU64::new(0),
            skipped_rate_limit: AtomicU64::new(0),
            skipped_cooldown: AtomicU64::new(0),
            dispatches: AtomicU64::new(0),
        }
    }

    /// Submits `request`, waiting until a slot is free.
    ///
    /// # Panics
    ///
    /// Panics if the capacity semaphore has been closed (this file never closes it).
    pub async fn submit(&self, request: reqwest::Request) -> TaskHandle {
        let permit = self
            .capacity_sem
            .acquire()
            .await
            .expect("capacity semaphore closed");
        // The slot is handed back explicitly by `mark_completed` / `mark_failed`,
        // not when the permit guard is dropped.
        permit.forget();
        self.enqueue(request)
    }

    /// Submits `request` without waiting.
    ///
    /// # Errors
    ///
    /// Returns [`ScatterProxyError::PoolFull`] when no slot is currently free.
    pub fn try_submit(&self, request: reqwest::Request) -> Result<TaskHandle, ScatterProxyError> {
        let permit = self
            .capacity_sem
            .try_acquire()
            .map_err(|_| ScatterProxyError::PoolFull {
                capacity: self.capacity,
            })?;
        permit.forget();
        Ok(self.enqueue(request))
    }

    /// Like [`submit`](Self::submit), but gives up after `timeout`.
    ///
    /// # Errors
    ///
    /// Returns [`ScatterProxyError::Timeout`] (carrying `timeout`) if no slot
    /// became free in time; the request is dropped in that case.
    pub async fn submit_timeout(
        &self,
        request: reqwest::Request,
        timeout: Duration,
    ) -> Result<TaskHandle, ScatterProxyError> {
        match tokio::time::timeout(timeout, self.submit(request)).await {
            Ok(handle) => Ok(handle),
            Err(_) => Err(ScatterProxyError::Timeout { elapsed: timeout }),
        }
    }

    /// Submits every request in order, waiting for a free slot before each one.
    ///
    /// Handles are returned in the same order as `requests`. Because slots are
    /// acquired one at a time, a batch larger than the number of free slots only
    /// finishes once other tasks are marked completed or failed in the meantime.
    pub async fn submit_batch(&self, requests: Vec<reqwest::Request>) -> Vec<TaskHandle> {
        let mut handles = Vec::with_capacity(requests.len());
        for req in requests {
            handles.push(self.submit(req).await);
        }
        handles
    }

    /// Submits all `requests` atomically: either every request gets a slot or
    /// none is enqueued.
    ///
    /// An empty batch succeeds without touching the pool.
    ///
    /// # Errors
    ///
    /// Returns [`ScatterProxyError::PoolFull`] when fewer than `requests.len()`
    /// slots are free, or when the batch is too large to be expressed as a
    /// permit count.
    pub fn try_submit_batch(
        &self,
        requests: Vec<reqwest::Request>,
    ) -> Result<Vec<TaskHandle>, ScatterProxyError> {
        let count = requests.len();
        if count == 0 {
            return Ok(Vec::new());
        }
        // A checked conversion: a truncating `as u32` cast would acquire fewer
        // permits than the number of tasks enqueued below.
        let permits = u32::try_from(count).map_err(|_| ScatterProxyError::PoolFull {
            capacity: self.capacity,
        })?;
        let permit = self
            .capacity_sem
            .try_acquire_many(permits)
            .map_err(|_| ScatterProxyError::PoolFull {
                capacity: self.capacity,
            })?;
        permit.forget();
        Ok(requests.into_iter().map(|req| self.enqueue(req)).collect())
    }

    /// Wraps `request` in a fresh [`TaskEntry`], appends it to the ready queue
    /// and wakes one waiter. The caller must already hold a slot.
    fn enqueue(&self, request: reqwest::Request) -> TaskHandle {
        let host = request.url().host_str().unwrap_or("unknown").to_string();
        let (tx, rx) = oneshot::channel();
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let entry = TaskEntry {
            id,
            request,
            host,
            attempts: 0,
            result_tx: Some(tx),
            last_error: String::new(),
        };
        {
            // Move due delayed tasks into the queue first so they end up ahead
            // of the new entry.
            self.promote_ready_delayed();
            let mut queue = self.queue.lock().unwrap();
            queue.push_back(entry);
        }
        self.notify.notify_one();
        TaskHandle {
            rx: AsyncMutex::new(rx),
        }
    }

    /// Moves every delayed task whose `ready_at` has passed to the back of the
    /// ready queue (earliest deadline first) and returns how many were moved.
    ///
    /// Does not signal `notify`.
    pub(crate) fn promote_ready_delayed(&self) -> usize {
        let now = Instant::now();
        let mut delayed = self.delayed.lock().unwrap();
        if delayed.is_empty() {
            return 0;
        }
        let mut ready = Vec::new();
        while let Some(std::cmp::Reverse(item)) = delayed.peek() {
            if item.ready_at > now {
                break;
            }
            let std::cmp::Reverse(item) = delayed.pop().expect("heap peeked item must pop");
            ready.push(item.entry);
        }
        // Release the heap lock before taking the queue lock so the two locks
        // are never held at the same time.
        drop(delayed);
        if ready.is_empty() {
            return 0;
        }
        let count = ready.len();
        let mut queue = self.queue.lock().unwrap();
        queue.extend(ready);
        count
    }

    /// Time until the earliest delayed task becomes ready.
    ///
    /// Returns `None` when nothing is delayed and a zero duration when the
    /// earliest task is already due.
    pub(crate) fn next_delayed_ready_in(&self) -> Option<Duration> {
        let delayed = self.delayed.lock().unwrap();
        let now = Instant::now();
        delayed
            .peek()
            .map(|d| d.0.ready_at.saturating_duration_since(now))
    }

    /// Removes and returns the first queued task whose host is not in
    /// `skip_hosts`, or `None` if there is no such task.
    ///
    /// The relative order of the tasks left in the queue is preserved.
    pub(crate) fn pick_next(&self, skip_hosts: &HashSet<String>) -> Option<TaskEntry> {
        let mut queue = self.queue.lock().unwrap();
        if skip_hosts.is_empty() {
            return queue.pop_front();
        }
        // Locate the match and remove it in place. Rotating skipped entries to
        // the back while searching would reorder the tasks that stay queued.
        let idx = queue
            .iter()
            .position(|entry| !skip_hosts.contains(&entry.host))?;
        queue.remove(idx)
    }

    /// Puts `entry` back at the tail of the ready queue, counts it as requeued
    /// and wakes one waiter. The task keeps the slot it already holds.
    pub(crate) fn push_back(&self, entry: TaskEntry) {
        self.requeued.fetch_add(1, Ordering::Relaxed);
        {
            let mut queue = self.queue.lock().unwrap();
            queue.push_back(entry);
        }
        self.notify.notify_one();
    }

    /// Parks `entry` until `delay` has elapsed, counts it as requeued and wakes
    /// one waiter. The task keeps the slot it already holds.
    ///
    /// # Panics
    ///
    /// `Instant::now() + delay` may panic if the resulting instant cannot be
    /// represented (extremely large `delay`).
    pub(crate) fn push_delayed(&self, entry: TaskEntry, delay: Duration) {
        self.requeued.fetch_add(1, Ordering::Relaxed);
        {
            let mut delayed = self.delayed.lock().unwrap();
            delayed.push(std::cmp::Reverse(DelayedTask {
                ready_at: Instant::now() + delay,
                entry,
            }));
        }
        self.notify.notify_one();
    }

    /// Number of tasks currently in the ready queue (delayed tasks excluded).
    pub fn pending_count(&self) -> usize {
        let queue = self.queue.lock().unwrap();
        queue.len()
    }

    /// Number of tasks currently parked in the delayed heap.
    pub fn delayed_count(&self) -> usize {
        let delayed = self.delayed.lock().unwrap();
        delayed.len()
    }

    /// Number of times `mark_completed` has been called.
    pub fn completed_count(&self) -> u64 {
        self.completed.load(Ordering::Relaxed)
    }

    /// Records a completed task and frees its slot.
    ///
    /// One permit is added unconditionally, so calling this without a matching
    /// submission raises the number of free slots above the configured capacity.
    pub(crate) fn mark_completed(&self) {
        self.completed.fetch_add(1, Ordering::Relaxed);
        self.capacity_sem.add_permits(1);
    }

    /// Records a failed task and frees its slot.
    ///
    /// One permit is added unconditionally, so calling this without a matching
    /// submission raises the number of free slots above the configured capacity.
    pub(crate) fn mark_failed(&self) {
        self.failed.fetch_add(1, Ordering::Relaxed);
        self.capacity_sem.add_permits(1);
    }

    /// Number of times `mark_failed` has been called.
    pub fn failed_count(&self) -> u64 {
        self.failed.load(Ordering::Relaxed)
    }

    /// Number of `push_back` and `push_delayed` calls combined.
    pub fn requeued_count(&self) -> u64 {
        self.requeued.load(Ordering::Relaxed)
    }

    /// Increments the "zero available" statistic.
    pub(crate) fn mark_zero_available(&self) {
        self.zero_available.fetch_add(1, Ordering::Relaxed);
    }

    /// Number of times `mark_zero_available` has been called.
    pub fn zero_available_count(&self) -> u64 {
        self.zero_available.load(Ordering::Relaxed)
    }

    /// Increments the "skipped: no permit" statistic.
    pub(crate) fn mark_skipped_no_permit(&self) {
        self.skipped_no_permit.fetch_add(1, Ordering::Relaxed);
    }

    /// Number of times `mark_skipped_no_permit` has been called.
    pub fn skipped_no_permit_count(&self) -> u64 {
        self.skipped_no_permit.load(Ordering::Relaxed)
    }

    /// Increments the "skipped: rate limit" statistic.
    pub(crate) fn mark_skipped_rate_limit(&self) {
        self.skipped_rate_limit.fetch_add(1, Ordering::Relaxed);
    }

    /// Number of times `mark_skipped_rate_limit` has been called.
    pub fn skipped_rate_limit_count(&self) -> u64 {
        self.skipped_rate_limit.load(Ordering::Relaxed)
    }

    /// Increments the "skipped: cooldown" statistic.
    pub(crate) fn mark_skipped_cooldown(&self) {
        self.skipped_cooldown.fetch_add(1, Ordering::Relaxed);
    }

    /// Number of times `mark_skipped_cooldown` has been called.
    pub fn skipped_cooldown_count(&self) -> u64 {
        self.skipped_cooldown.load(Ordering::Relaxed)
    }

    /// Increments the dispatch statistic.
    pub(crate) fn mark_dispatch(&self) {
        self.dispatches.fetch_add(1, Ordering::Relaxed);
    }

    /// Number of times `mark_dispatch` has been called.
    pub fn dispatch_count(&self) -> u64 {
        self.dispatches.load(Ordering::Relaxed)
    }

    /// Waits until the pool signals that a task was enqueued, pushed back or
    /// pushed as delayed.
    // Within this file it is only called from the test module, hence the
    // dead-code allowance.
    #[allow(dead_code)]
    pub(crate) async fn notified(&self) {
        self.notify.notified().await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a GET request for `http://example.com/test`.
    fn test_request() -> reqwest::Request {
        reqwest::Client::new()
            .get("http://example.com/test")
            .build()
            .unwrap()
    }

    // --- construction -------------------------------------------------------

    #[test]
    fn new_pool_has_zero_pending() {
        let pool = TaskPool::new(10);
        assert_eq!(pool.pending_count(), 0);
        assert_eq!(pool.delayed_count(), 0);
        assert_eq!(pool.completed_count(), 0);
    }

    // --- try_submit ---------------------------------------------------------

    #[test]
    fn try_submit_increments_pending_count() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn try_submit_returns_pool_full_when_at_capacity() {
        let pool = TaskPool::new(2);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let result = pool.try_submit(test_request());
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::PoolFull { capacity } => assert_eq!(capacity, 2),
            other => panic!("expected PoolFull, got {other:?}"),
        }
    }

    #[test]
    fn try_submit_assigns_incrementing_ids() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t2.id > t1.id);
    }

    #[test]
    fn try_submit_extracts_host_from_url() {
        let pool = TaskPool::new(10);
        let _h = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        assert_eq!(task.host, "example.com");
    }

    // --- try_submit_batch ---------------------------------------------------

    #[test]
    fn try_submit_batch_adds_all_tasks() {
        let pool = TaskPool::new(10);
        let reqs = vec![test_request(), test_request(), test_request()];
        let handles = pool.try_submit_batch(reqs).unwrap();
        assert_eq!(handles.len(), 3);
        assert_eq!(pool.pending_count(), 3);
    }

    #[test]
    fn try_submit_batch_atomic_rejection_when_pool_full() {
        let pool = TaskPool::new(2);
        let reqs = vec![test_request(), test_request(), test_request()];
        let result = pool.try_submit_batch(reqs);
        assert!(result.is_err());
        // Nothing may have been enqueued when the batch is rejected.
        assert_eq!(pool.pending_count(), 0);
    }

    #[test]
    fn try_submit_batch_empty_vec_is_ok() {
        let pool = TaskPool::new(10);
        let handles = pool.try_submit_batch(vec![]).unwrap();
        assert!(handles.is_empty());
    }

    // --- async submit variants ----------------------------------------------

    #[tokio::test]
    async fn submit_blocks_then_proceeds_after_mark_completed() {
        let pool = std::sync::Arc::new(TaskPool::new(1));
        let _h1 = pool.try_submit(test_request()).unwrap();
        let pool2 = pool.clone();
        let join = tokio::spawn(async move {
            let _handle = pool2.submit(test_request()).await;
        });
        // Give the spawned submit a chance to run; it must still be blocked.
        tokio::time::sleep(Duration::from_millis(50)).await;
        assert_eq!(pool.pending_count(), 1);
        {
            let skip = HashSet::new();
            let _task = pool.pick_next(&skip).unwrap();
            pool.mark_completed();
        }
        join.await.unwrap();
        assert_eq!(pool.pending_count(), 1);
    }

    #[tokio::test]
    async fn submit_timeout_returns_err_on_expiry() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let result = pool
            .submit_timeout(test_request(), Duration::from_millis(50))
            .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::Timeout { elapsed } => {
                assert_eq!(elapsed, Duration::from_millis(50));
            }
            other => panic!("expected Timeout, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn submit_batch_processes_all() {
        let pool = TaskPool::new(10);
        let reqs = vec![test_request(), test_request()];
        let handles = pool.submit_batch(reqs).await;
        assert_eq!(handles.len(), 2);
        assert_eq!(pool.pending_count(), 2);
    }

    // --- pick_next ----------------------------------------------------------

    #[test]
    fn pick_next_returns_fifo_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t1.id < t2.id);
    }

    #[test]
    fn pick_next_skips_circuit_broken_hosts() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn pick_next_returns_none_when_all_hosts_skipped() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        assert!(pool.pick_next(&skip).is_none());
        // Skipped tasks stay queued.
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn pick_next_returns_none_when_empty() {
        let pool = TaskPool::new(10);
        let skip = HashSet::new();
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn pick_next_selects_first_non_skipped_preserves_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let req2 = reqwest::Client::new()
            .get("http://other.com/path")
            .build()
            .unwrap();
        let _h2 = pool.try_submit(req2).unwrap();
        let _h3 = pool.try_submit(test_request()).unwrap();
        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        let picked = pool.pick_next(&skip).unwrap();
        assert_eq!(picked.host, "other.com");
        assert_eq!(pool.pending_count(), 2);
    }

    // --- requeueing ---------------------------------------------------------

    #[test]
    fn push_back_requeues_to_tail() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let id1 = t1.id;
        pool.push_back(t1);
        assert_eq!(pool.requeued_count(), 1);
        // The task that was never picked must now come out first, followed by
        // the requeued one.
        let t2 = pool.pick_next(&skip).unwrap();
        let re_t1 = pool.pick_next(&skip).unwrap();
        assert_ne!(t2.id, id1);
        assert_eq!(re_t1.id, id1);
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn delayed_task_promotes_when_ready() {
        let pool = TaskPool::new(10);
        let _ = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        pool.push_delayed(task, Duration::from_millis(10));
        assert_eq!(pool.delayed_count(), 1);
        std::thread::sleep(Duration::from_millis(20));
        let promoted = pool.promote_ready_delayed();
        assert_eq!(promoted, 1);
        assert_eq!(pool.delayed_count(), 0);
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn mark_completed_increments_counter() {
        let pool = TaskPool::new(10);
        pool.mark_completed();
        assert_eq!(pool.completed_count(), 1);
    }

    // --- TaskHandle ---------------------------------------------------------

    #[tokio::test]
    async fn task_handle_receives_success() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let mut task = pool.pick_next(&skip).unwrap();
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(ScatterResponse {
                status: StatusCode::OK,
                headers: HeaderMap::new(),
                body: Bytes::from_static(b"hello"),
            });
        }
        let resp = handle
            .with_timeout(Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.body.as_ref(), b"hello");
    }

    #[tokio::test]
    async fn task_handle_returns_502_when_sender_dropped() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        // Dropping the entry drops its sender without sending a response.
        drop(task);
        let resp = handle
            .with_timeout(Duration::from_secs(1))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(resp.status, StatusCode::BAD_GATEWAY);
    }

    #[tokio::test]
    async fn task_handle_with_timeout_ok() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let mut task = pool.pick_next(&skip).unwrap();
        if let Some(tx) = task.result_tx.take() {
            let _ = tx.send(ScatterResponse {
                status: StatusCode::OK,
                headers: HeaderMap::new(),
                body: Bytes::from_static(b"ok"),
            });
        }
        let resp = handle
            .with_timeout(Duration::from_secs(5))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(resp.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn task_handle_with_timeout_expires() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();
        let result = handle
            .with_timeout(Duration::from_millis(50))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    // --- notifications ------------------------------------------------------

    #[tokio::test]
    async fn notified_wakes_on_try_submit() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let pool2 = pool.clone();
        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });
        tokio::time::sleep(Duration::from_millis(20)).await;
        let _h = pool.try_submit(test_request()).unwrap();
        assert!(waiter.await.unwrap());
    }

    #[tokio::test]
    async fn notified_wakes_on_push_back() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let _h = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        let pool2 = pool.clone();
        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });
        tokio::time::sleep(Duration::from_millis(20)).await;
        pool.push_back(task);
        assert!(waiter.await.unwrap());
    }

    // --- capacity edge cases ------------------------------------------------

    #[test]
    fn pool_with_zero_capacity_rejects_everything() {
        let pool = TaskPool::new(0);
        let result = pool.try_submit(test_request());
        assert!(result.is_err());
    }

    #[test]
    fn pool_allows_try_submit_after_mark_completed_frees_space() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();
        assert!(pool.try_submit(test_request()).is_err());
        let skip = HashSet::new();
        let _task = pool.pick_next(&skip).unwrap();
        pool.mark_completed();
        let _h2 = pool.try_submit(test_request()).unwrap();
    }

    // --- misc ---------------------------------------------------------------

    #[test]
    fn task_entry_has_correct_defaults_on_try_submit() {
        let pool = TaskPool::new(10);
        let _h = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        assert_eq!(task.attempts, 0);
        assert!(task.last_error.is_empty());
        assert!(task.result_tx.is_some());
    }

    #[test]
    fn scatter_response_debug() {
        let resp = ScatterResponse {
            status: StatusCode::OK,
            headers: HeaderMap::new(),
            body: Bytes::from_static(b"test"),
        };
        let dbg = format!("{resp:?}");
        assert!(dbg.contains("200"));
    }
}