use std::collections::HashSet;
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use bytes::Bytes;
use http::{HeaderMap, StatusCode};
use tokio::sync::{oneshot, Notify};
use crate::error::ScatterProxyError;
/// A finished HTTP response, delivered to the submitter through a [`TaskHandle`].
#[derive(Debug)]
pub struct ScatterResponse {
/// HTTP status code of the response.
pub status: StatusCode,
/// Response headers.
pub headers: HeaderMap,
/// Response body bytes.
pub body: Bytes,
}
/// Future returned by [`TaskPool::submit`] and [`TaskPool::submit_batch`] that
/// resolves to the task's final outcome.
///
/// If the sending side is dropped without producing a result, the handle
/// resolves to `ScatterProxyError::Init("task channel closed")` (see the
/// `Future` impl for this type).
#[derive(Debug)]
pub struct TaskHandle {
/// Receiving half of the one-shot channel whose sender is stored in the
/// task's `TaskEntry::result_tx`.
rx: oneshot::Receiver<Result<ScatterResponse, ScatterProxyError>>,
}
impl Future for TaskHandle {
    type Output = Result<ScatterResponse, ScatterProxyError>;

    /// Polls the underlying one-shot receiver.
    ///
    /// Resolves to whatever result was sent for the task. If the sender was
    /// dropped without sending, resolves to
    /// `ScatterProxyError::Init("task channel closed")` instead.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `oneshot::Receiver` is `Unpin`, so `TaskHandle` is too and can be
        // unpinned directly.
        let this = self.get_mut();
        match Pin::new(&mut this.rx).poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Ok(outcome)) => Poll::Ready(outcome),
            Poll::Ready(Err(_closed)) => Poll::Ready(Err(ScatterProxyError::Init(
                "task channel closed".into(),
            ))),
        }
    }
}
/// A queued unit of work: one HTTP request plus its retry bookkeeping and the
/// channel used to deliver its final result.
pub(crate) struct TaskEntry {
/// Id taken from `TaskPool::next_id` at submission; unique within a pool.
// Only read by tests in this file, hence `allow(dead_code)`.
// NOTE(review): presumably kept for logging/diagnostics — confirm before removing.
#[allow(dead_code)]
pub id: u64,
/// The request to execute.
pub request: reqwest::Request,
/// Host taken from the request URL, or `"unknown"` when the URL has no host.
/// `TaskPool::pick_next` matches this against its skip set.
pub host: String,
/// Attempts made so far; `0` on submission.
// NOTE(review): presumably incremented by the worker before a requeue — confirm.
pub attempts: usize,
/// Attempt budget supplied by the submitter.
pub max_attempts: usize,
/// When the task was submitted to the pool.
pub submitted_at: Instant,
/// Timeout supplied by the submitter.
// NOTE(review): whether this bounds each attempt or the whole task (measured
// from `submitted_at`) is decided by the consumer — confirm.
pub task_timeout: Duration,
/// Sender for the task's final result; the matching receiver is held by the
/// `TaskHandle` returned to the submitter. `Some` on submission.
pub result_tx: Option<oneshot::Sender<Result<ScatterResponse, ScatterProxyError>>>,
/// Most recent error message; empty on submission.
pub last_error: String,
}
/// Bounded FIFO queue of pending HTTP tasks.
///
/// Submitters enqueue work with `submit` / `submit_batch`. Consumers
/// (presumably worker tasks elsewhere in the crate — confirm) dequeue with
/// `pick_next`, requeue with `push_back`, and wait for new work with `notified`.
pub struct TaskPool {
/// Pending tasks in FIFO order, guarded by a blocking `std::sync::Mutex`.
queue: Mutex<VecDeque<TaskEntry>>,
/// Maximum number of pending tasks accepted by `submit` / `submit_batch`.
/// `push_back` does not enforce it.
capacity: usize,
/// Source of task ids; starts at 1.
next_id: AtomicU64,
/// Signalled once per enqueued task so waiters in `notified` can wake.
notify: Notify,
/// Count of tasks recorded through `mark_completed`.
completed: AtomicU64,
/// Count of tasks recorded through `mark_failed`.
failed: AtomicU64,
}
impl TaskPool {
    /// Creates an empty pool that accepts at most `capacity` pending tasks
    /// through [`TaskPool::submit`] and [`TaskPool::submit_batch`].
    ///
    /// Task ids start at 1; the completed/failed counters start at 0.
    pub fn new(capacity: usize) -> Self {
        Self {
            queue: Mutex::new(VecDeque::new()),
            capacity,
            next_id: AtomicU64::new(1),
            notify: Notify::new(),
            completed: AtomicU64::new(0),
            failed: AtomicU64::new(0),
        }
    }

    /// Locks the pending-task queue.
    ///
    /// # Panics
    ///
    /// Panics if the mutex is poisoned. Every critical section in this impl is a
    /// short `VecDeque` operation, so poisoning means a bug elsewhere.
    fn lock_queue(&self) -> std::sync::MutexGuard<'_, VecDeque<TaskEntry>> {
        self.queue.lock().expect("task queue mutex poisoned")
    }

    /// Builds a fresh [`TaskEntry`] for `request`, together with the
    /// [`TaskHandle`] that receives the result sent on the entry's `result_tx`.
    ///
    /// Consumes one id from `next_id`. The host is taken from the request URL,
    /// falling back to `"unknown"` when the URL has no host.
    fn new_entry(
        &self,
        request: reqwest::Request,
        max_attempts: usize,
        task_timeout: Duration,
    ) -> (TaskEntry, TaskHandle) {
        let host = request.url().host_str().unwrap_or("unknown").to_string();
        let (tx, rx) = oneshot::channel();
        let entry = TaskEntry {
            id: self.next_id.fetch_add(1, Ordering::Relaxed),
            request,
            host,
            attempts: 0,
            max_attempts,
            submitted_at: Instant::now(),
            task_timeout,
            result_tx: Some(tx),
            last_error: String::new(),
        };
        (entry, TaskHandle { rx })
    }

    /// Enqueues a single request and returns a handle resolving to its result.
    ///
    /// The entry is built before taking the lock to keep the critical section
    /// short. A rejected submission still consumes an id; ids only need to be
    /// unique, so this is harmless.
    ///
    /// # Errors
    ///
    /// Returns [`ScatterProxyError::PoolFull`] when `capacity` tasks are already
    /// pending.
    pub fn submit(
        &self,
        request: reqwest::Request,
        max_attempts: usize,
        task_timeout: Duration,
    ) -> Result<TaskHandle, ScatterProxyError> {
        let (entry, handle) = self.new_entry(request, max_attempts, task_timeout);
        {
            let mut queue = self.lock_queue();
            if queue.len() >= self.capacity {
                return Err(ScatterProxyError::PoolFull {
                    capacity: self.capacity,
                });
            }
            queue.push_back(entry);
        }
        // Wake one waiter, or store a permit if nobody is waiting yet.
        self.notify.notify_one();
        Ok(handle)
    }

    /// Enqueues all `requests` atomically.
    ///
    /// Either every request is enqueued, or none is. On success, returns one
    /// handle per request, in input order. An empty batch succeeds as long as
    /// the queue does not already exceed `capacity`.
    ///
    /// # Errors
    ///
    /// Returns [`ScatterProxyError::PoolFull`] when the pending tasks plus the
    /// batch would exceed `capacity`. In that case nothing is enqueued.
    pub fn submit_batch(
        &self,
        requests: Vec<reqwest::Request>,
        max_attempts: usize,
        task_timeout: Duration,
    ) -> Result<Vec<TaskHandle>, ScatterProxyError> {
        let count = requests.len();
        // Build everything outside the lock. The capacity check and the
        // insertion then share a single lock acquisition, which is what makes
        // the batch all-or-nothing.
        let (entries, handles): (Vec<TaskEntry>, Vec<TaskHandle>) = requests
            .into_iter()
            .map(|request| self.new_entry(request, max_attempts, task_timeout))
            .unzip();
        {
            let mut queue = self.lock_queue();
            if queue.len() + count > self.capacity {
                return Err(ScatterProxyError::PoolFull {
                    capacity: self.capacity,
                });
            }
            queue.extend(entries);
        }
        // One notification per task, so up to `count` waiters can be woken.
        for _ in 0..count {
            self.notify.notify_one();
        }
        Ok(handles)
    }

    /// Removes and returns the oldest pending task whose host is not in
    /// `skip_hosts`.
    ///
    /// Tasks for skipped hosts stay in place, so relative order is preserved.
    /// Returns `None` if the queue is empty or every pending task is for a
    /// skipped host.
    pub(crate) fn pick_next(&self, skip_hosts: &HashSet<String>) -> Option<TaskEntry> {
        let mut queue = self.lock_queue();
        let index = queue
            .iter()
            .position(|entry| !skip_hosts.contains(&entry.host))?;
        queue.remove(index)
    }

    /// Re-enqueues `entry` at the tail of the queue and wakes one waiter.
    ///
    /// No capacity check is performed, so a requeue never fails. As a result,
    /// the queue may temporarily hold more than `capacity` tasks.
    pub(crate) fn push_back(&self, entry: TaskEntry) {
        self.lock_queue().push_back(entry);
        self.notify.notify_one();
    }

    /// Number of tasks currently waiting in the queue.
    pub fn pending_count(&self) -> usize {
        self.lock_queue().len()
    }

    /// Number of times [`TaskPool::mark_completed`] has been called.
    pub fn completed_count(&self) -> u64 {
        // Relaxed: the counters are standalone statistics and are not used to
        // synchronize any other memory.
        self.completed.load(Ordering::Relaxed)
    }

    /// Number of times [`TaskPool::mark_failed`] has been called.
    pub fn failed_count(&self) -> u64 {
        self.failed.load(Ordering::Relaxed)
    }

    /// Increments the completed-task counter.
    pub(crate) fn mark_completed(&self) {
        self.completed.fetch_add(1, Ordering::Relaxed);
    }

    /// Increments the failed-task counter.
    pub(crate) fn mark_failed(&self) {
        self.failed.fetch_add(1, Ordering::Relaxed);
    }

    /// Waits until a task is enqueued by `submit`, `submit_batch` or `push_back`.
    ///
    /// A notification sent while nobody is waiting is stored as a single
    /// permit, so a wake-up issued between a failed `pick_next` and this call
    /// is not lost.
    // NOTE(review): `allow(dead_code)` suggests no non-test caller exists in
    // this build; confirm before removing.
    #[allow(dead_code)]
    pub(crate) async fn notified(&self) {
        self.notify.notified().await;
    }
}
// Unit tests for `TaskPool`: capacity limits (single and batch submission),
// FIFO and host-skipping selection in `pick_next`, requeueing via `push_back`,
// the completed/failed counters, `TaskHandle` result delivery, and
// notification wake-ups.
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashSet;
use std::time::Duration;
// Builds (but never sends) a GET request for `url`.
fn test_request(url: &str) -> reqwest::Request {
reqwest::Client::new().get(url).build().unwrap()
}
#[test]
fn new_pool_has_zero_pending() {
let pool = TaskPool::new(100);
assert_eq!(pool.pending_count(), 0);
assert_eq!(pool.completed_count(), 0);
assert_eq!(pool.failed_count(), 0);
}
#[test]
fn submit_increments_pending_count() {
let pool = TaskPool::new(10);
let _h = pool
.submit(
test_request("http://example.com"),
3,
Duration::from_secs(10),
)
.unwrap();
assert_eq!(pool.pending_count(), 1);
}
#[test]
fn submit_returns_pool_full_when_at_capacity() {
let pool = TaskPool::new(1);
let _h1 = pool
.submit(
test_request("http://example.com"),
3,
Duration::from_secs(10),
)
.unwrap();
let result = pool.submit(
test_request("http://example.com/2"),
3,
Duration::from_secs(10),
);
assert!(result.is_err());
match result.unwrap_err() {
ScatterProxyError::PoolFull { capacity } => assert_eq!(capacity, 1),
other => panic!("expected PoolFull, got: {other:?}"),
}
}
#[test]
fn submit_assigns_incrementing_ids() {
let pool = TaskPool::new(10);
let _h1 = pool
.submit(test_request("http://a.com"), 3, Duration::from_secs(10))
.unwrap();
let _h2 = pool
.submit(test_request("http://b.com"), 3, Duration::from_secs(10))
.unwrap();
let skip = HashSet::new();
let t1 = pool.pick_next(&skip).unwrap();
let t2 = pool.pick_next(&skip).unwrap();
assert!(t2.id > t1.id);
}
#[test]
fn submit_extracts_host_from_url() {
let pool = TaskPool::new(10);
let _h = pool
.submit(
test_request("http://myhost.example.com/path?q=1"),
3,
Duration::from_secs(10),
)
.unwrap();
let skip = HashSet::new();
let entry = pool.pick_next(&skip).unwrap();
assert_eq!(entry.host, "myhost.example.com");
}
#[test]
fn submit_batch_adds_all_tasks() {
let pool = TaskPool::new(10);
let reqs = vec![
test_request("http://a.com"),
test_request("http://b.com"),
test_request("http://c.com"),
];
let handles = pool.submit_batch(reqs, 3, Duration::from_secs(10)).unwrap();
assert_eq!(handles.len(), 3);
assert_eq!(pool.pending_count(), 3);
}
#[test]
fn submit_batch_atomic_rejection_when_pool_full() {
let pool = TaskPool::new(2);
let _h = pool
.submit(test_request("http://x.com"), 3, Duration::from_secs(10))
.unwrap();
// 1 pending + 2 requested exceeds capacity 2, so no part of the batch may
// be enqueued.
let reqs = vec![test_request("http://a.com"), test_request("http://b.com")];
let result = pool.submit_batch(reqs, 3, Duration::from_secs(10));
assert!(result.is_err());
assert_eq!(pool.pending_count(), 1);
}
#[test]
fn submit_batch_empty_vec_is_ok() {
let pool = TaskPool::new(10);
let handles = pool
.submit_batch(vec![], 3, Duration::from_secs(10))
.unwrap();
assert!(handles.is_empty());
assert_eq!(pool.pending_count(), 0);
}
#[test]
fn pick_next_returns_fifo_order() {
let pool = TaskPool::new(10);
let _h1 = pool
.submit(test_request("http://first.com"), 3, Duration::from_secs(10))
.unwrap();
let _h2 = pool
.submit(
test_request("http://second.com"),
3,
Duration::from_secs(10),
)
.unwrap();
let skip = HashSet::new();
let t = pool.pick_next(&skip).unwrap();
assert_eq!(t.host, "first.com");
assert_eq!(pool.pending_count(), 1);
}
#[test]
fn pick_next_skips_circuit_broken_hosts() {
let pool = TaskPool::new(10);
let _h1 = pool
.submit(
test_request("http://broken.com/a"),
3,
Duration::from_secs(10),
)
.unwrap();
let _h2 = pool
.submit(test_request("http://ok.com/b"), 3, Duration::from_secs(10))
.unwrap();
let mut skip = HashSet::new();
skip.insert("broken.com".to_string());
let t = pool.pick_next(&skip).unwrap();
assert_eq!(t.host, "ok.com");
assert_eq!(pool.pending_count(), 1);
}
#[test]
fn pick_next_returns_none_when_all_hosts_skipped() {
let pool = TaskPool::new(10);
let _h = pool
.submit(
test_request("http://broken.com"),
3,
Duration::from_secs(10),
)
.unwrap();
let mut skip = HashSet::new();
skip.insert("broken.com".to_string());
// The skipped task must stay queued, not be dropped.
assert!(pool.pick_next(&skip).is_none());
assert_eq!(pool.pending_count(), 1);
}
#[test]
fn pick_next_returns_none_when_empty() {
let pool = TaskPool::new(10);
let skip = HashSet::new();
assert!(pool.pick_next(&skip).is_none());
}
#[test]
fn push_back_requeues_to_tail() {
let pool = TaskPool::new(10);
let _h1 = pool
.submit(test_request("http://first.com"), 3, Duration::from_secs(10))
.unwrap();
let _h2 = pool
.submit(
test_request("http://second.com"),
3,
Duration::from_secs(10),
)
.unwrap();
let skip = HashSet::new();
let mut entry = pool.pick_next(&skip).unwrap();
assert_eq!(entry.host, "first.com");
// Simulate a failed attempt before requeueing; the bookkeeping must
// survive the round trip through the queue.
entry.attempts += 1;
entry.last_error = "connection refused".into();
pool.push_back(entry);
assert_eq!(pool.pending_count(), 2);
let t = pool.pick_next(&skip).unwrap();
assert_eq!(t.host, "second.com");
let t = pool.pick_next(&skip).unwrap();
assert_eq!(t.host, "first.com");
assert_eq!(t.attempts, 1);
assert_eq!(t.last_error, "connection refused");
}
#[test]
fn mark_completed_increments_counter() {
let pool = TaskPool::new(10);
pool.mark_completed();
pool.mark_completed();
assert_eq!(pool.completed_count(), 2);
}
#[test]
fn mark_failed_increments_counter() {
let pool = TaskPool::new(10);
pool.mark_failed();
assert_eq!(pool.failed_count(), 1);
}
#[tokio::test]
async fn task_handle_receives_success() {
let pool = TaskPool::new(10);
let handle = pool
.submit(
test_request("http://example.com"),
3,
Duration::from_secs(10),
)
.unwrap();
let skip = HashSet::new();
let entry = pool.pick_next(&skip).unwrap();
let response = ScatterResponse {
status: StatusCode::OK,
headers: HeaderMap::new(),
body: Bytes::from("hello"),
};
// Play the worker's role: deliver the result through the entry's sender.
entry.result_tx.unwrap().send(Ok(response)).unwrap();
let result = handle.await;
assert!(result.is_ok());
let resp = result.unwrap();
assert_eq!(resp.status, StatusCode::OK);
assert_eq!(resp.body, Bytes::from("hello"));
}
#[tokio::test]
async fn task_handle_receives_error() {
let pool = TaskPool::new(10);
let handle = pool
.submit(
test_request("http://example.com"),
3,
Duration::from_secs(10),
)
.unwrap();
let skip = HashSet::new();
let entry = pool.pick_next(&skip).unwrap();
entry
.result_tx
.unwrap()
.send(Err(ScatterProxyError::MaxAttemptsExhausted {
host: "example.com".into(),
attempts: 3,
last_error: "timeout".into(),
}))
.unwrap();
let result = handle.await;
assert!(result.is_err());
}
#[tokio::test]
async fn task_handle_returns_error_when_sender_dropped() {
let pool = TaskPool::new(10);
let handle = pool
.submit(
test_request("http://example.com"),
3,
Duration::from_secs(10),
)
.unwrap();
let skip = HashSet::new();
let entry = pool.pick_next(&skip).unwrap();
// Dropping the entry drops its `result_tx` without sending, which closes
// the channel.
drop(entry);
let result = handle.await;
assert!(result.is_err());
match result.unwrap_err() {
ScatterProxyError::Init(msg) => assert!(msg.contains("channel closed")),
other => panic!("expected Init, got: {other:?}"),
}
}
#[tokio::test]
async fn notified_wakes_on_submit() {
let pool = std::sync::Arc::new(TaskPool::new(10));
let pool2 = pool.clone();
let waiter = tokio::spawn(async move {
pool2.notified().await;
true
});
// Give the waiter a chance to register first. The test does not depend on
// this ordering: tokio's `Notify::notify_one` stores a permit when no task
// is waiting yet.
tokio::time::sleep(Duration::from_millis(10)).await;
let _h = pool
.submit(
test_request("http://example.com"),
3,
Duration::from_secs(10),
)
.unwrap();
let woke = tokio::time::timeout(Duration::from_secs(1), waiter)
.await
.unwrap()
.unwrap();
assert!(woke);
}
#[tokio::test]
async fn notified_wakes_on_push_back() {
let pool = std::sync::Arc::new(TaskPool::new(10));
let _h = pool
.submit(
test_request("http://example.com"),
3,
Duration::from_secs(10),
)
.unwrap();
let skip = HashSet::new();
let entry = pool.pick_next(&skip).unwrap();
let pool2 = pool.clone();
// NOTE(review): the `submit` above already left a stored `Notify` permit
// (nothing was waiting), so this waiter may wake from that permit rather
// than from `push_back` — the test does not isolate push_back's
// notification.
let waiter = tokio::spawn(async move {
pool2.notified().await;
true
});
tokio::time::sleep(Duration::from_millis(10)).await;
pool.push_back(entry);
let woke = tokio::time::timeout(Duration::from_secs(1), waiter)
.await
.unwrap()
.unwrap();
assert!(woke);
}
#[test]
fn pool_with_zero_capacity_rejects_everything() {
let pool = TaskPool::new(0);
let result = pool.submit(test_request("http://a.com"), 1, Duration::from_secs(5));
assert!(matches!(
result,
Err(ScatterProxyError::PoolFull { capacity: 0 })
));
}
#[test]
fn pool_allows_submit_after_pick_frees_space() {
let pool = TaskPool::new(1);
let _h1 = pool
.submit(test_request("http://a.com"), 1, Duration::from_secs(5))
.unwrap();
assert!(pool
.submit(test_request("http://b.com"), 1, Duration::from_secs(5))
.is_err());
let skip = HashSet::new();
// Picking a task removes it from the queue, so its slot becomes free even
// though the task itself is still in flight.
let _entry = pool.pick_next(&skip).unwrap();
let _h2 = pool
.submit(test_request("http://c.com"), 1, Duration::from_secs(5))
.unwrap();
assert_eq!(pool.pending_count(), 1);
}
#[test]
fn task_entry_has_correct_defaults_on_submit() {
let pool = TaskPool::new(10);
let timeout = Duration::from_secs(42);
let _h = pool
.submit(test_request("http://host.example.com/path"), 7, timeout)
.unwrap();
let skip = HashSet::new();
let entry = pool.pick_next(&skip).unwrap();
assert_eq!(entry.host, "host.example.com");
assert_eq!(entry.attempts, 0);
assert_eq!(entry.max_attempts, 7);
assert_eq!(entry.task_timeout, timeout);
assert!(entry.last_error.is_empty());
assert!(entry.result_tx.is_some());
}
#[test]
fn scatter_response_debug() {
let resp = ScatterResponse {
status: StatusCode::NOT_FOUND,
headers: HeaderMap::new(),
body: Bytes::from("not found"),
};
let dbg = format!("{resp:?}");
assert!(dbg.contains("404"));
}
#[test]
fn pick_next_selects_first_non_skipped_preserves_order() {
let pool = TaskPool::new(10);
let _h1 = pool
.submit(test_request("http://a.com/1"), 1, Duration::from_secs(5))
.unwrap();
let _h2 = pool
.submit(test_request("http://b.com/2"), 1, Duration::from_secs(5))
.unwrap();
let _h3 = pool
.submit(test_request("http://a.com/3"), 1, Duration::from_secs(5))
.unwrap();
let _h4 = pool
.submit(test_request("http://c.com/4"), 1, Duration::from_secs(5))
.unwrap();
let mut skip = HashSet::new();
skip.insert("a.com".to_string());
let t1 = pool.pick_next(&skip).unwrap();
assert_eq!(t1.host, "b.com");
let t2 = pool.pick_next(&skip).unwrap();
assert_eq!(t2.host, "c.com");
// Only the two a.com tasks remain, and they are all skipped.
assert_eq!(pool.pending_count(), 2);
assert!(pool.pick_next(&skip).is_none());
}
}