use std::collections::HashSet;
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
use std::task::{Context, Poll};
use std::time::Duration;
use bytes::Bytes;
use http::{HeaderMap, StatusCode};
use tokio::sync::{oneshot, Notify, Semaphore};
use crate::error::ScatterProxyError;
/// Final HTTP response delivered to whoever submitted a task.
///
/// Normally this is the value sent on the task's result channel. If that channel
/// is closed without a value, awaiting the `TaskHandle` synthesizes one of these
/// with status `502 Bad Gateway` instead.
#[derive(Debug)]
pub struct ScatterResponse {
/// HTTP status code of the response.
pub status: StatusCode,
/// Response headers (empty for the synthesized 502).
pub headers: HeaderMap,
/// Complete response body.
pub body: Bytes,
}
/// Awaitable handle to the eventual [`ScatterResponse`] of a submitted task.
///
/// Returned by the `TaskPool::submit*` / `try_submit*` methods. Awaiting it yields
/// the response sent for the task; `with_timeout` bounds that wait.
#[derive(Debug)]
pub struct TaskHandle {
// Receiving half of the task's oneshot result channel; the matching sender is
// stored in the queued `TaskEntry::result_tx`.
rx: oneshot::Receiver<ScatterResponse>,
}
impl TaskHandle {
    /// Awaits the task's response, giving up after `duration`.
    ///
    /// # Errors
    ///
    /// Returns [`ScatterProxyError::Timeout`] carrying `duration` when no response
    /// (not even the synthesized 502 for a closed channel) arrives in time.
    pub async fn with_timeout(
        self,
        duration: Duration,
    ) -> Result<ScatterResponse, ScatterProxyError> {
        tokio::time::timeout(duration, self)
            .await
            .map_err(|_elapsed| ScatterProxyError::Timeout { elapsed: duration })
    }
}
impl Future for TaskHandle {
type Output = ScatterResponse;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.rx).poll(cx) {
Poll::Ready(Ok(resp)) => Poll::Ready(resp),
Poll::Ready(Err(_)) => {
Poll::Ready(ScatterResponse {
status: StatusCode::BAD_GATEWAY,
headers: HeaderMap::new(),
body: Bytes::from_static(
b"scatter-proxy: internal error - task channel closed",
),
})
}
Poll::Pending => Poll::Pending,
}
}
}
/// One queued unit of work: the request to send plus the bookkeeping that travels
/// with it through the pool.
pub(crate) struct TaskEntry {
/// Identifier taken from the pool's counter at enqueue time (first id is 1).
// NOTE(review): presumably only read by tests at present, hence the allow.
#[allow(dead_code)]
pub id: u64,
/// The request to execute.
pub request: reqwest::Request,
/// Host taken from the request URL, or `"unknown"` when the URL has none; matched
/// against the skip set in `TaskPool::pick_next`.
pub host: String,
/// Starts at 0 on submission.
// NOTE(review): presumably incremented by workers on each retry — confirm.
pub attempts: usize,
/// Sender for the submitter's `TaskHandle`. An `Option` so it can be `take()`n
/// exactly once; dropping it unsent makes the handle resolve to a 502.
pub result_tx: Option<oneshot::Sender<ScatterResponse>>,
/// Starts empty on submission.
// NOTE(review): presumably records the most recent failure reason — confirm.
pub last_error: String,
}
/// Bounded FIFO of pending [`TaskEntry`] items shared between submitters and the
/// code that pulls tasks out via `pick_next`.
///
/// Capacity is enforced by a semaphore. A permit is taken (and forgotten) on
/// submit, and one permit is added back by `mark_completed` / `mark_failed`.
/// Removing a task from the queue does not free its slot, so the limit counts
/// queued tasks plus tasks that have been picked but not yet marked.
pub struct TaskPool {
/// Pending tasks; new and re-queued tasks are appended at the back.
queue: Mutex<VecDeque<TaskEntry>>,
/// Configured slot count, reported in `PoolFull` errors.
capacity: usize,
/// One permit per currently free slot.
capacity_sem: Semaphore,
/// Next id to assign to a submitted task (starts at 1).
next_id: AtomicU64,
/// Signalled once per enqueue / `push_back`; awaited through `notified()`.
notify: Notify,
/// Number of `mark_completed` calls so far.
completed: AtomicU64,
/// Number of `mark_failed` calls so far.
failed: AtomicU64,
}
impl TaskPool {
    /// Creates an empty pool that admits at most `capacity` outstanding tasks.
    ///
    /// A task holds a slot from submission until [`TaskPool::mark_completed`] or
    /// [`TaskPool::mark_failed`] is called for it. The limit therefore covers
    /// queued tasks and tasks that have been picked but not yet marked.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` exceeds [`Semaphore::MAX_PERMITS`].
    pub fn new(capacity: usize) -> Self {
        Self {
            queue: Mutex::new(VecDeque::new()),
            capacity,
            capacity_sem: Semaphore::new(capacity),
            next_id: AtomicU64::new(1),
            notify: Notify::new(),
            completed: AtomicU64::new(0),
            failed: AtomicU64::new(0),
        }
    }

    /// Enqueues `request`, waiting asynchronously for a free slot if the pool is full.
    ///
    /// There is no await point between acquiring the slot and enqueueing. Dropping
    /// this future before it completes therefore neither enqueues the request nor
    /// consumes capacity.
    pub async fn submit(&self, request: reqwest::Request) -> TaskHandle {
        let permit = self
            .capacity_sem
            .acquire()
            .await
            .expect("capacity semaphore closed");
        // The slot is handed back explicitly by `mark_completed`/`mark_failed`,
        // not by dropping the permit, so detach it from RAII here.
        permit.forget();
        self.enqueue(request)
    }

    /// Enqueues `request` only if a slot is free right now.
    ///
    /// # Errors
    ///
    /// Returns [`ScatterProxyError::PoolFull`] when the pool is at capacity.
    pub fn try_submit(&self, request: reqwest::Request) -> Result<TaskHandle, ScatterProxyError> {
        self.capacity_sem
            .try_acquire()
            .map_err(|_| self.pool_full())?
            // Released later by `mark_completed`/`mark_failed`, not on drop.
            .forget();
        Ok(self.enqueue(request))
    }

    /// Like [`TaskPool::submit`], but gives up if no slot frees up within `timeout`.
    ///
    /// # Errors
    ///
    /// Returns [`ScatterProxyError::Timeout`] carrying `timeout` if the wait expires.
    /// In that case nothing is enqueued.
    pub async fn submit_timeout(
        &self,
        request: reqwest::Request,
        timeout: Duration,
    ) -> Result<TaskHandle, ScatterProxyError> {
        tokio::time::timeout(timeout, self.submit(request))
            .await
            .map_err(|_| ScatterProxyError::Timeout { elapsed: timeout })
    }

    /// Submits every request in order, waiting for capacity as needed.
    ///
    /// Slots are acquired one request at a time, so earlier requests may already
    /// be queued while later ones are still waiting. This is not all-or-nothing;
    /// use [`TaskPool::try_submit_batch`] for that.
    pub async fn submit_batch(&self, requests: Vec<reqwest::Request>) -> Vec<TaskHandle> {
        let mut handles = Vec::with_capacity(requests.len());
        for req in requests {
            handles.push(self.submit(req).await);
        }
        handles
    }

    /// Enqueues all of `requests` atomically, or none of them.
    ///
    /// An empty batch always succeeds and returns no handles.
    ///
    /// # Errors
    ///
    /// Returns [`ScatterProxyError::PoolFull`] if there are not enough free slots
    /// for the whole batch right now. In that case nothing is enqueued and no
    /// capacity is consumed.
    pub fn try_submit_batch(
        &self,
        requests: Vec<reqwest::Request>,
    ) -> Result<Vec<TaskHandle>, ScatterProxyError> {
        let count = requests.len();
        if count == 0 {
            return Ok(Vec::new());
        }
        // A batch larger than the whole pool can never be admitted. Rejecting it
        // here also keeps oversized counts away from the semaphore's limits.
        if count > self.capacity {
            return Err(self.pool_full());
        }
        // `try_acquire_many` takes a `u32`. A plain `as` cast would silently
        // truncate and admit more tasks than permits were acquired.
        let permits = u32::try_from(count).map_err(|_| self.pool_full())?;
        self.capacity_sem
            .try_acquire_many(permits)
            .map_err(|_| self.pool_full())?
            // Released later, one per task, by `mark_completed`/`mark_failed`.
            .forget();
        Ok(requests.into_iter().map(|req| self.enqueue(req)).collect())
    }

    /// Wraps `request` in a fresh [`TaskEntry`] and appends it to the queue.
    ///
    /// The caller must already have taken (and forgotten) a capacity permit for it.
    fn enqueue(&self, request: reqwest::Request) -> TaskHandle {
        // Requests whose URL has no host share one "unknown" bucket for host skipping.
        let host = request.url().host_str().unwrap_or("unknown").to_string();
        let (tx, rx) = oneshot::channel();
        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
        self.push_back(TaskEntry {
            id,
            request,
            host,
            attempts: 0,
            result_tx: Some(tx),
            last_error: String::new(),
        });
        TaskHandle { rx }
    }

    /// Removes and returns the oldest task whose host is not in `skip_hosts`.
    ///
    /// Skipped tasks ahead of it keep their positions. Returns `None` if the queue
    /// is empty or every queued task's host is skipped.
    pub(crate) fn pick_next(&self, skip_hosts: &HashSet<String>) -> Option<TaskEntry> {
        let mut queue = self.queue.lock().unwrap();
        let index = queue
            .iter()
            .position(|entry| !skip_hosts.contains(&entry.host))?;
        queue.remove(index)
    }

    /// Appends `entry` to the tail of the queue and wakes one waiter.
    ///
    /// Used both for new submissions and for re-queueing a previously picked task.
    /// Re-queueing does not touch capacity, because the task still holds its slot.
    pub(crate) fn push_back(&self, entry: TaskEntry) {
        // The temporary guard is released at the end of this statement, before notifying.
        self.queue.lock().unwrap().push_back(entry);
        self.notify.notify_one();
    }

    /// Number of tasks currently waiting in the queue (picked tasks are not counted).
    pub fn pending_count(&self) -> usize {
        self.queue.lock().unwrap().len()
    }

    /// Number of tasks reported via [`TaskPool::mark_completed`].
    pub fn completed_count(&self) -> u64 {
        self.completed.load(Ordering::Relaxed)
    }

    /// Records a successful task and returns its slot to the pool.
    ///
    /// Call exactly once per submitted task (or [`TaskPool::mark_failed`] instead).
    /// Each call adds a permit unconditionally, so extra calls raise the effective
    /// capacity above the configured limit.
    pub(crate) fn mark_completed(&self) {
        self.completed.fetch_add(1, Ordering::Relaxed);
        self.capacity_sem.add_permits(1);
    }

    /// Records a failed task and returns its slot to the pool.
    ///
    /// The same once-per-task contract as [`TaskPool::mark_completed`] applies.
    pub(crate) fn mark_failed(&self) {
        self.failed.fetch_add(1, Ordering::Relaxed);
        self.capacity_sem.add_permits(1);
    }

    /// Number of tasks reported via [`TaskPool::mark_failed`].
    pub fn failed_count(&self) -> u64 {
        self.failed.load(Ordering::Relaxed)
    }

    /// Waits for the next enqueue / `push_back` notification.
    ///
    /// Notifications use [`Notify::notify_one`]. A notification sent while nobody
    /// is waiting is stored and consumed by the next call.
    // NOTE(review): presumably only exercised by tests at present, hence the allow.
    #[allow(dead_code)]
    pub(crate) async fn notified(&self) {
        self.notify.notified().await;
    }

    /// Error returned whenever a submission cannot be admitted right now.
    fn pool_full(&self) -> ScatterProxyError {
        ScatterProxyError::PoolFull {
            capacity: self.capacity,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a GET request to `url`.
    fn request_to(url: &str) -> reqwest::Request {
        reqwest::Client::new().get(url).build().unwrap()
    }

    /// Builds a GET request to `example.com`, the default host in these tests.
    fn test_request() -> reqwest::Request {
        request_to("http://example.com/test")
    }

    /// Sends a 200 response carrying `body` on the task's result channel.
    fn respond_ok(task: &mut TaskEntry, body: &'static [u8]) {
        let tx = task.result_tx.take().expect("result sender already taken");
        let _ = tx.send(ScatterResponse {
            status: StatusCode::OK,
            headers: HeaderMap::new(),
            body: Bytes::from_static(body),
        });
    }

    #[test]
    fn new_pool_has_zero_pending() {
        let pool = TaskPool::new(10);
        assert_eq!(pool.pending_count(), 0);
        assert_eq!(pool.completed_count(), 0);
        assert_eq!(pool.failed_count(), 0);
    }

    #[test]
    fn try_submit_increments_pending_count() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn try_submit_returns_pool_full_when_at_capacity() {
        let pool = TaskPool::new(2);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let result = pool.try_submit(test_request());
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::PoolFull { capacity } => assert_eq!(capacity, 2),
            other => panic!("expected PoolFull, got {other:?}"),
        }
    }

    #[test]
    fn try_submit_assigns_incrementing_ids() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t2.id > t1.id);
    }

    #[test]
    fn try_submit_extracts_host_from_url() {
        let pool = TaskPool::new(10);
        let _h = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        assert_eq!(task.host, "example.com");
    }

    #[test]
    fn try_submit_batch_adds_all_tasks() {
        let pool = TaskPool::new(10);
        let reqs = vec![test_request(), test_request(), test_request()];
        let handles = pool.try_submit_batch(reqs).unwrap();
        assert_eq!(handles.len(), 3);
        assert_eq!(pool.pending_count(), 3);
    }

    #[test]
    fn try_submit_batch_atomic_rejection_when_pool_full() {
        let pool = TaskPool::new(2);
        let reqs = vec![test_request(), test_request(), test_request()];
        let result = pool.try_submit_batch(reqs);
        assert!(result.is_err());
        assert_eq!(pool.pending_count(), 0);
        // The rejected batch must not have consumed any capacity.
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
    }

    #[test]
    fn try_submit_batch_empty_vec_is_ok() {
        let pool = TaskPool::new(10);
        let handles = pool.try_submit_batch(vec![]).unwrap();
        assert!(handles.is_empty());
    }

    #[tokio::test]
    async fn submit_blocks_then_proceeds_after_mark_completed() {
        let pool = std::sync::Arc::new(TaskPool::new(1));
        let _h1 = pool.try_submit(test_request()).unwrap();
        let pool2 = pool.clone();
        let join = tokio::spawn(async move {
            let _handle = pool2.submit(test_request()).await;
        });
        tokio::time::sleep(Duration::from_millis(50)).await;
        // The second submission is still parked waiting for capacity.
        assert_eq!(pool.pending_count(), 1);
        {
            let skip = HashSet::new();
            let _task = pool.pick_next(&skip).unwrap();
            pool.mark_completed();
        }
        join.await.unwrap();
        assert_eq!(pool.pending_count(), 1);
    }

    #[tokio::test]
    async fn submit_timeout_returns_err_on_expiry() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let result = pool
            .submit_timeout(test_request(), Duration::from_millis(50))
            .await;
        assert!(result.is_err());
        match result.unwrap_err() {
            ScatterProxyError::Timeout { elapsed } => {
                assert_eq!(elapsed, Duration::from_millis(50));
            }
            other => panic!("expected Timeout, got {other:?}"),
        }
        // The timed-out submission must not have been enqueued.
        assert_eq!(pool.pending_count(), 1);
    }

    #[tokio::test]
    async fn submit_batch_processes_all() {
        let pool = TaskPool::new(10);
        let reqs = vec![test_request(), test_request()];
        let handles = pool.submit_batch(reqs).await;
        assert_eq!(handles.len(), 2);
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn pick_next_returns_fifo_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let t1 = pool.pick_next(&skip).unwrap();
        let t2 = pool.pick_next(&skip).unwrap();
        assert!(t1.id < t2.id);
    }

    #[test]
    fn pick_next_skips_circuit_broken_hosts() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        assert!(pool.pick_next(&skip).is_none());
        // Skipped tasks stay queued.
        assert_eq!(pool.pending_count(), 1);
    }

    #[test]
    fn pick_next_returns_none_when_all_hosts_skipped() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        assert!(pool.pick_next(&skip).is_none());
        assert_eq!(pool.pending_count(), 2);
    }

    #[test]
    fn pick_next_returns_none_when_empty() {
        let pool = TaskPool::new(10);
        let skip = HashSet::new();
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn pick_next_selects_first_non_skipped_preserves_order() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(request_to("http://other.com/path")).unwrap();
        let _h3 = pool.try_submit(test_request()).unwrap();
        let mut skip = HashSet::new();
        skip.insert("example.com".into());
        let picked = pool.pick_next(&skip).unwrap();
        assert_eq!(picked.host, "other.com");
        assert_eq!(pool.pending_count(), 2);
        // The skipped tasks keep their original relative order.
        let skip_none = HashSet::new();
        let first = pool.pick_next(&skip_none).unwrap();
        let second = pool.pick_next(&skip_none).unwrap();
        assert_eq!(first.host, "example.com");
        assert_eq!(second.host, "example.com");
        assert!(first.id < picked.id && picked.id < second.id);
    }

    #[test]
    fn push_back_requeues_to_tail() {
        let pool = TaskPool::new(10);
        let _h1 = pool.try_submit(test_request()).unwrap();
        let _h2 = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let first = pool.pick_next(&skip).unwrap();
        let first_id = first.id;
        pool.push_back(first);
        // The untouched second task must now come out ahead of the re-queued one.
        let next = pool.pick_next(&skip).unwrap();
        let requeued = pool.pick_next(&skip).unwrap();
        assert_ne!(next.id, first_id);
        assert_eq!(requeued.id, first_id);
        assert!(pool.pick_next(&skip).is_none());
    }

    #[test]
    fn mark_completed_increments_counter() {
        let pool = TaskPool::new(10);
        pool.mark_completed();
        assert_eq!(pool.completed_count(), 1);
    }

    #[test]
    fn mark_failed_increments_counter_and_frees_space() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();
        assert!(pool.try_submit(test_request()).is_err());
        let skip = HashSet::new();
        let _task = pool.pick_next(&skip).unwrap();
        pool.mark_failed();
        assert_eq!(pool.failed_count(), 1);
        assert_eq!(pool.completed_count(), 0);
        let _h2 = pool.try_submit(test_request()).unwrap();
    }

    #[tokio::test]
    async fn task_handle_receives_success() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let mut task = pool.pick_next(&skip).unwrap();
        respond_ok(&mut task, b"hello");
        let resp = handle.await;
        assert_eq!(resp.status, StatusCode::OK);
        assert_eq!(resp.body.as_ref(), b"hello");
    }

    #[tokio::test]
    async fn task_handle_returns_502_when_sender_dropped() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        drop(task);
        let resp = handle.await;
        assert_eq!(resp.status, StatusCode::BAD_GATEWAY);
    }

    #[tokio::test]
    async fn task_handle_with_timeout_ok() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let mut task = pool.pick_next(&skip).unwrap();
        respond_ok(&mut task, b"ok");
        let resp = handle.with_timeout(Duration::from_secs(5)).await.unwrap();
        assert_eq!(resp.status, StatusCode::OK);
    }

    #[tokio::test]
    async fn task_handle_with_timeout_expires() {
        let pool = TaskPool::new(10);
        let handle = pool.try_submit(test_request()).unwrap();
        let result = handle.with_timeout(Duration::from_millis(50)).await;
        match result {
            Err(ScatterProxyError::Timeout { elapsed }) => {
                assert_eq!(elapsed, Duration::from_millis(50));
            }
            Err(other) => panic!("expected Timeout, got {other:?}"),
            Ok(resp) => panic!("expected Timeout, got response {resp:?}"),
        }
    }

    #[tokio::test]
    async fn notified_wakes_on_try_submit() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let pool2 = pool.clone();
        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });
        tokio::time::sleep(Duration::from_millis(20)).await;
        let _h = pool.try_submit(test_request()).unwrap();
        assert!(waiter.await.unwrap());
    }

    #[tokio::test]
    async fn notified_wakes_on_push_back() {
        let pool = std::sync::Arc::new(TaskPool::new(10));
        let _h = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        let pool2 = pool.clone();
        let waiter = tokio::spawn(async move {
            pool2.notified().await;
            true
        });
        tokio::time::sleep(Duration::from_millis(20)).await;
        pool.push_back(task);
        assert!(waiter.await.unwrap());
    }

    #[test]
    fn pool_with_zero_capacity_rejects_everything() {
        let pool = TaskPool::new(0);
        let result = pool.try_submit(test_request());
        assert!(result.is_err());
    }

    #[test]
    fn pool_allows_try_submit_after_mark_completed_frees_space() {
        let pool = TaskPool::new(1);
        let _h1 = pool.try_submit(test_request()).unwrap();
        assert!(pool.try_submit(test_request()).is_err());
        let skip = HashSet::new();
        let _task = pool.pick_next(&skip).unwrap();
        pool.mark_completed();
        let _h2 = pool.try_submit(test_request()).unwrap();
    }

    #[test]
    fn task_entry_has_correct_defaults_on_try_submit() {
        let pool = TaskPool::new(10);
        let _h = pool.try_submit(test_request()).unwrap();
        let skip = HashSet::new();
        let task = pool.pick_next(&skip).unwrap();
        assert_eq!(task.attempts, 0);
        assert!(task.last_error.is_empty());
        assert!(task.result_tx.is_some());
    }

    #[test]
    fn scatter_response_debug() {
        let resp = ScatterResponse {
            status: StatusCode::OK,
            headers: HeaderMap::new(),
            body: Bytes::from_static(b"test"),
        };
        let dbg = format!("{resp:?}");
        assert!(dbg.contains("200"));
    }
}