use crate::io::{HttpClient, HttpRequest, HttpResponse};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
use std::sync::Mutex;
/// A queued request together with the keys that order it inside the pending heap.
#[derive(Debug)]
struct PrioritizedRequest {
// The request handed to the HTTP client once this entry is dispatched.
request: HttpRequest,
// Ordering key: the `Ord` impl makes entries with a lower value pop first.
priority: f64,
// Enqueue counter; breaks ties so that equal priorities pop in enqueue order.
sequence: u64,
}
impl PartialEq for PrioritizedRequest {
fn eq(&self, other: &Self) -> bool {
self.request.url == other.request.url
}
}
// Marker impl with no methods; `Ord` requires `Eq` as a supertrait.
impl Eq for PrioritizedRequest {}
/// Partial ordering simply defers to the total ordering defined by `Ord`.
impl PartialOrd for PrioritizedRequest {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, rhs))
    }
}
/// Heap ordering for pending requests.
///
/// `BinaryHeap` pops its greatest element, so both keys are compared in
/// reverse: the entry with the lowest `priority` value is the greatest, and
/// among equal priorities the entry with the lowest `sequence` (enqueued
/// first) wins.
impl Ord for PrioritizedRequest {
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .priority
            .partial_cmp(&self.priority)
            // `partial_cmp` is `None` only when a NaN is involved. Treating that
            // as `Equal` would make NaN equal to every value and break
            // transitivity, so NaN is instead ranked above every real priority
            // value (i.e. dispatched last) and all NaNs tie with each other.
            .unwrap_or_else(|| other.priority.is_nan().cmp(&self.priority.is_nan()))
            .then_with(|| other.sequence.cmp(&self.sequence))
    }
}
/// Bounded-concurrency dispatcher that queues HTTP requests by priority and
/// hands them to an `HttpClient` as in-flight slots become available.
pub struct FetchPool {
// Client used to send requests and to poll for completed responses.
client: Box<dyn HttpClient>,
// Upper bound on the size of `in_flight_urls`; the constructor clamps it to >= 1.
max_concurrent: usize,
// Requests accepted by `enqueue` that have not been dispatched yet.
queue: Mutex<BinaryHeap<PrioritizedRequest>>,
// URLs accepted by `enqueue` and not yet completed, cancelled or cleared;
// used to ignore duplicate enqueues.
known_urls: Mutex<HashSet<String>>,
// URLs dispatched by `flush` that have not completed or been force-cancelled.
in_flight_urls: Mutex<HashSet<String>>,
// Markers left by `force_cancel` for URLs that were not queued; `poll` consumes
// a marker when the matching response arrives and `enqueue` clears it.
cancelled_in_flight: Mutex<HashSet<String>>,
// Source of the per-request `sequence` tie-breaker.
sequence: AtomicU64,
}
impl FetchPool {
pub fn new(client: Box<dyn HttpClient>, max_concurrent: usize) -> Self {
Self {
client,
max_concurrent: max_concurrent.max(1),
queue: Mutex::new(BinaryHeap::new()),
known_urls: Mutex::new(HashSet::new()),
in_flight_urls: Mutex::new(HashSet::new()),
cancelled_in_flight: Mutex::new(HashSet::new()),
sequence: AtomicU64::new(0),
}
}
/// Queues `request` with the given `priority` (lower values dispatch first).
///
/// The call is ignored when the request's URL is already known to the pool,
/// or when the known-URL lock is poisoned. Any force-cancel marker for the URL
/// is cleared so that a later response is treated as a normal completion.
pub fn enqueue(&self, request: HttpRequest, priority: f64) {
    // Claim the URL first; the guard is released at the end of this scope.
    {
        let Ok(mut known) = self.known_urls.lock() else {
            return;
        };
        if !known.insert(request.url.clone()) {
            return;
        }
    }
    if let Ok(mut ghosts) = self.cancelled_in_flight.lock() {
        ghosts.remove(&request.url);
    }
    if let Ok(mut pending) = self.queue.lock() {
        let sequence = self.sequence.fetch_add(1, AtomicOrdering::Relaxed);
        pending.push(PrioritizedRequest {
            request,
            priority,
            sequence,
        });
    }
}
/// Dispatches queued requests, best priority first, until the in-flight set
/// reaches `max_concurrent` or the queue is empty.
///
/// Both the in-flight and queue locks are held for the whole call, including
/// the calls into the client. Does nothing if either lock is poisoned.
pub fn flush(&self) {
    let Ok(mut active) = self.in_flight_urls.lock() else {
        return;
    };
    let Ok(mut pending) = self.queue.lock() else {
        return;
    };
    while active.len() < self.max_concurrent {
        let Some(next) = pending.pop() else {
            break;
        };
        // `send` consumes the request, so keep a copy of the URL for tracking.
        let url = next.request.url.clone();
        self.client.send(next.request);
        active.insert(url);
    }
}
/// Removes the queued (not yet dispatched) request for `url`.
///
/// Returns `true` when an entry was removed from the queue. Returns `false`
/// when the URL is not queued (this includes URLs that are already in flight)
/// or when the queue lock is poisoned.
pub fn cancel(&self, url: &str) -> bool {
    // Check and removal happen under a single lock acquisition, so another
    // thread cannot slip in between "is it queued?" and "remove it".
    let Ok(mut queue) = self.queue.lock() else {
        return false;
    };
    let before = queue.len();
    // In-place O(n) removal instead of draining and rebuilding the heap.
    queue.retain(|entry| entry.request.url != url);
    if queue.len() == before {
        return false;
    }
    // Forget the URL while the queue lock is still held (queue -> known, the
    // same nesting order as `clear_queue`) so a concurrent `enqueue` of the
    // same URL either sees it as still known or is pushed after this removal.
    if let Ok(mut known) = self.known_urls.lock() {
        known.remove(url);
    }
    true
}
/// Cancels `url` whether it is still queued or already in flight.
///
/// The URL is always forgotten from the known set, so it can be enqueued
/// again. A queued request is simply removed from the queue. An in-flight
/// request has its slot released immediately and a marker is recorded so
/// that `poll` does not release the slot a second time when the response
/// eventually arrives. A URL that is neither queued nor in flight leaves no
/// marker behind.
pub fn force_cancel(&self, url: &str) {
    let was_queued = {
        // One lock acquisition covers both the membership check and the removal.
        let mut queue = self.queue.lock().ok();
        let removed = queue.as_mut().is_some_and(|q| {
            let before = q.len();
            q.retain(|entry| entry.request.url != url);
            q.len() != before
        });
        // Nested queue -> known, the same order `clear_queue` uses.
        if let Ok(mut known) = self.known_urls.lock() {
            known.remove(url);
        }
        removed
    };
    if was_queued {
        return;
    }
    // Take the marker lock before the in-flight lock (the order `poll` uses) so
    // a concurrent `poll` sees the slot release and the marker together.
    let mut ghosts = self.cancelled_in_flight.lock().ok();
    let was_in_flight = self
        .in_flight_urls
        .lock()
        .is_ok_and(|mut active| active.remove(url));
    // Only a request that really occupied a slot can produce a late response;
    // marking anything else would leave a marker that is never consumed.
    if was_in_flight {
        if let Some(ghosts) = ghosts.as_mut() {
            ghosts.insert(url.to_owned());
        }
    }
}
/// Discards every queued request and forgets their URLs so they can be
/// enqueued again. In-flight requests are not touched.
///
/// The queue is emptied even if the known-URL lock is poisoned; nothing
/// happens if the queue lock itself is poisoned.
pub fn clear_queue(&self) {
    let Ok(mut pending) = self.queue.lock() else {
        return;
    };
    if let Ok(mut known) = self.known_urls.lock() {
        pending.iter().for_each(|entry| {
            known.remove(&entry.request.url);
        });
    }
    pending.clear();
}
/// Number of requests waiting in the queue (`0` if the lock is poisoned).
pub fn queued_count(&self) -> usize {
    self.queue.lock().map_or(0, |pending| pending.len())
}
/// Returns the concurrency limit stored at construction (the constructor
/// clamps it to at least 1).
#[inline]
pub fn max_concurrent(&self) -> usize {
self.max_concurrent
}
/// Number of URLs currently tracked as in flight (`0` if the lock is poisoned).
pub fn in_flight_count(&self) -> usize {
    match self.in_flight_urls.lock() {
        Ok(active) => active.len(),
        Err(_) => 0,
    }
}
/// Number of URLs in the duplicate-suppression set (`0` if the lock is poisoned).
pub fn known_count(&self) -> usize {
    match self.known_urls.lock() {
        Ok(known) => known.len(),
        Err(_) => 0,
    }
}
/// Number of force-cancel markers still waiting for their response
/// (`0` if the lock is poisoned).
pub fn cancelled_in_flight_count(&self) -> usize {
    self.cancelled_in_flight
        .lock()
        .map_or(0, |ghosts| ghosts.len())
}
/// Reports whether `url` is in the duplicate-suppression set
/// (`false` if the lock is poisoned).
pub fn is_known(&self, url: &str) -> bool {
    match self.known_urls.lock() {
        Ok(known) => known.contains(url),
        Err(_) => false,
    }
}
pub fn poll(&self) -> Vec<(String, Result<HttpResponse, String>)> {
let results = self.client.poll();
if !results.is_empty() {
let mut cancelled = self.cancelled_in_flight.lock().ok();
let mut in_flight = self.in_flight_urls.lock().ok();
for (url, _) in &results {
if let Some(ref mut set) = cancelled {
if set.remove(url.as_str()) {
continue;
}
}
if let Some(ref mut urls) = in_flight {
urls.remove(url.as_str());
}
}
drop(cancelled);
drop(in_flight);
if let Ok(mut known) = self.known_urls.lock() {
for (url, _) in &results {
known.remove(url);
}
}
}
self.flush();
results
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::io::shared_http_client::SharedHttpClient;
use std::sync::Arc;
/// Mock client whose `poll` reports a 200 response for every URL sent since
/// the previous poll.
struct InstantMockClient {
    sent: Mutex<Vec<String>>,
}
impl InstantMockClient {
    fn new() -> Self {
        Self {
            sent: Mutex::default(),
        }
    }
}
impl HttpClient for InstantMockClient {
    fn send(&self, request: HttpRequest) {
        let mut log = self.sent.lock().unwrap();
        log.push(request.url);
    }
    fn poll(&self) -> Vec<(String, Result<HttpResponse, String>)> {
        // Take the recorded URLs, leaving an empty list for the next round.
        let drained = std::mem::take(&mut *self.sent.lock().unwrap());
        let mut responses = Vec::with_capacity(drained.len());
        for url in drained {
            let response = HttpResponse {
                status: 200,
                body: Vec::new(),
                headers: Vec::new(),
            };
            responses.push((url, Ok(response)));
        }
        responses
    }
}
/// Mock client that records sent URLs and never reports a response.
struct DeferredMockClient {
    sent: Mutex<Vec<String>>,
}
impl DeferredMockClient {
    fn new() -> Self {
        Self {
            sent: Mutex::default(),
        }
    }
}
impl HttpClient for DeferredMockClient {
    fn send(&self, request: HttpRequest) {
        let mut log = self.sent.lock().unwrap();
        log.push(request.url);
    }
    fn poll(&self) -> Vec<(String, Result<HttpResponse, String>)> {
        vec![]
    }
}
/// Mock client that appends sent URLs to a list shared with the test and
/// never reports a response.
struct SharedMock(Arc<Mutex<Vec<String>>>);
impl HttpClient for SharedMock {
    fn send(&self, request: HttpRequest) {
        let Self(log) = self;
        log.lock().unwrap().push(request.url);
    }
    fn poll(&self) -> Vec<(String, Result<HttpResponse, String>)> {
        vec![]
    }
}
/// Cloneable mock client: clones share the same sent-URL log and the same
/// list of responses, which a test fills explicitly via `complete`.
#[derive(Clone, Default)]
#[allow(clippy::type_complexity)]
struct RecordingClient {
    sent: Arc<Mutex<Vec<String>>>,
    responses: Arc<Mutex<Vec<(String, Result<HttpResponse, String>)>>>,
}
impl RecordingClient {
    /// Snapshot of every URL passed to `send` so far.
    fn sent_urls(&self) -> Vec<String> {
        let log = self.sent.lock().unwrap();
        log.to_vec()
    }
    /// Makes a 200 response for `url` available to the next `poll`.
    fn complete(&self, url: &str) {
        let response = HttpResponse {
            status: 200,
            body: Vec::new(),
            headers: Vec::new(),
        };
        let mut ready = self.responses.lock().unwrap();
        ready.push((url.to_owned(), Ok(response)));
    }
}
impl HttpClient for RecordingClient {
    fn send(&self, request: HttpRequest) {
        let mut log = self.sent.lock().unwrap();
        log.push(request.url);
    }
    fn poll(&self) -> Vec<(String, Result<HttpResponse, String>)> {
        let mut ready = self.responses.lock().unwrap();
        std::mem::take(&mut *ready)
    }
}
#[test]
fn respects_concurrency_limit() {
    // Three requests but only two slots: one must stay queued after dispatch.
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 2);
    for (url, priority) in [("a", 1.0), ("b", 2.0), ("c", 3.0)] {
        pool.enqueue(HttpRequest::get(url), priority);
    }
    pool.flush();
    let active = pool.in_flight_count();
    let waiting = pool.queued_count();
    assert_eq!(active, 2);
    assert_eq!(waiting, 1);
}
#[test]
fn freed_slots_dispatch_queued() {
    let pool = FetchPool::new(Box::new(InstantMockClient::new()), 1);
    pool.enqueue(HttpRequest::get("a"), 1.0);
    pool.enqueue(HttpRequest::get("b"), 2.0);
    pool.flush();
    assert_eq!((pool.in_flight_count(), pool.queued_count()), (1, 1));
    // Polling completes "a"; the freed slot is handed to "b" right away.
    let completed = pool.poll();
    let urls: Vec<&str> = completed.iter().map(|(url, _)| url.as_str()).collect();
    assert_eq!(urls, ["a"]);
    assert_eq!((pool.queued_count(), pool.in_flight_count()), (0, 1));
}
#[test]
fn priority_order_nearest_first() {
    let sent = Arc::new(Mutex::new(Vec::new()));
    let pool = FetchPool::new(Box::new(SharedMock(Arc::clone(&sent))), 10);
    for (url, priority) in [("far", 100.0), ("near", 1.0), ("mid", 50.0)] {
        pool.enqueue(HttpRequest::get(url), priority);
    }
    pool.flush();
    // Lowest priority value must have been sent first.
    let order = sent.lock().unwrap().clone();
    assert_eq!(order, ["near", "mid", "far"]);
}
#[test]
fn equal_priority_preserves_enqueue_order() {
    let sent = Arc::new(Mutex::new(Vec::new()));
    let pool = FetchPool::new(Box::new(SharedMock(Arc::clone(&sent))), 10);
    for url in ["a", "b", "c"] {
        pool.enqueue(HttpRequest::get(url), 1.0);
    }
    pool.flush();
    // With identical priorities the enqueue sequence decides the order.
    let order = sent.lock().unwrap().clone();
    assert_eq!(order, ["a", "b", "c"]);
}
#[test]
fn duplicate_enqueue_ignored() {
    let sent = Arc::new(Mutex::new(Vec::new()));
    let pool = FetchPool::new(Box::new(SharedMock(Arc::clone(&sent))), 10);
    // The second enqueue repeats the URL with another priority and must be dropped.
    for priority in [1.0, 2.0] {
        pool.enqueue(HttpRequest::get("a"), priority);
    }
    pool.flush();
    assert_eq!(pool.in_flight_count(), 1);
    let dispatched = sent.lock().unwrap().len();
    assert_eq!(dispatched, 1);
}
#[test]
fn duplicate_suppressed_while_in_flight() {
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 10);
    pool.enqueue(HttpRequest::get("a"), 1.0);
    pool.flush();
    // "a" is now in flight; enqueueing it again must not add a queue entry.
    pool.enqueue(HttpRequest::get("a"), 2.0);
    let waiting = pool.queued_count();
    assert_eq!(waiting, 0, "should not re-queue in-flight URL");
}
#[test]
fn can_re_enqueue_after_completion() {
    let pool = FetchPool::new(Box::new(InstantMockClient::new()), 10);
    pool.enqueue(HttpRequest::get("a"), 1.0);
    pool.flush();
    let completed = pool.poll();
    assert_eq!(completed.len(), 1);
    // Once completed, the same URL is accepted and dispatched again.
    pool.enqueue(HttpRequest::get("a"), 1.0);
    pool.flush();
    let active = pool.in_flight_count();
    assert_eq!(active, 1);
}
#[test]
fn cancel_removes_queued_request() {
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 1);
    for (url, priority) in [("a", 1.0), ("b", 2.0), ("c", 3.0)] {
        pool.enqueue(HttpRequest::get(url), priority);
    }
    pool.flush();
    assert_eq!(pool.queued_count(), 2);
    // "b" is still waiting, so it can be cancelled out of the queue.
    let removed = pool.cancel("b");
    assert!(removed);
    assert_eq!(pool.queued_count(), 1);
}
#[test]
fn cancel_nonexistent_returns_false() {
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 10);
    let removed = pool.cancel("nope");
    assert!(!removed);
}
#[test]
fn cancel_in_flight_returns_false() {
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 10);
    pool.enqueue(HttpRequest::get("a"), 1.0);
    pool.flush();
    // "a" has left the queue, so the queue-only cancel has nothing to remove.
    let removed = pool.cancel("a");
    assert!(!removed);
}
#[test]
fn clear_queue_discards_pending() {
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 1);
    for (url, priority) in [("a", 1.0), ("b", 2.0), ("c", 3.0)] {
        pool.enqueue(HttpRequest::get(url), priority);
    }
    pool.flush();
    assert_eq!(pool.queued_count(), 2);
    pool.clear_queue();
    // Only the queue is emptied; the dispatched request keeps its slot.
    let waiting = pool.queued_count();
    let active = pool.in_flight_count();
    assert_eq!(waiting, 0);
    assert_eq!(active, 1);
}
#[test]
fn cleared_urls_can_be_re_enqueued() {
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 1);
    pool.enqueue(HttpRequest::get("a"), 1.0);
    pool.enqueue(HttpRequest::get("b"), 2.0);
    pool.flush();
    pool.clear_queue();
    // "b" was discarded by the clear, so it must be accepted a second time.
    pool.enqueue(HttpRequest::get("b"), 2.0);
    let waiting = pool.queued_count();
    assert_eq!(waiting, 1);
}
#[test]
fn zero_concurrency_clamped_to_one() {
    // A limit of zero would never dispatch anything; the pool raises it to one.
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 0);
    pool.enqueue(HttpRequest::get("a"), 1.0);
    pool.flush();
    let active = pool.in_flight_count();
    assert_eq!(active, 1);
}
#[test]
fn poll_empty_returns_empty() {
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 10);
    let completed = pool.poll();
    assert!(completed.is_empty());
    let active = pool.in_flight_count();
    let waiting = pool.queued_count();
    assert_eq!(active, 0);
    assert_eq!(waiting, 0);
}
#[test]
fn flush_empty_is_noop() {
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 10);
    // Nothing is queued, so flushing must not create in-flight entries.
    pool.flush();
    let active = pool.in_flight_count();
    assert_eq!(active, 0);
}
#[test]
fn force_cancel_immediately_reclaims_in_flight_slot() {
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 2);
    for (url, priority) in [("a", 1.0), ("b", 2.0), ("c", 3.0)] {
        pool.enqueue(HttpRequest::get(url), priority);
    }
    pool.flush();
    assert_eq!(pool.in_flight_count(), 2);
    assert_eq!(pool.queued_count(), 1);
    // Cancelling the in-flight "a" frees its slot without waiting for a response.
    pool.force_cancel("a");
    let active_after_cancel = pool.in_flight_count();
    assert_eq!(
        active_after_cancel, 1,
        "ghost slot must be reclaimed immediately"
    );
    assert_eq!(pool.cancelled_in_flight_count(), 1);
    pool.flush();
    let active_after_flush = pool.in_flight_count();
    assert_eq!(
        active_after_flush, 2,
        "freed slot should accept queued request"
    );
    assert_eq!(pool.queued_count(), 0);
}
#[test]
fn ghost_response_does_not_double_decrement_in_flight() {
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 2);
    for (url, priority) in [("a", 1.0), ("b", 2.0)] {
        pool.enqueue(HttpRequest::get(url), priority);
    }
    pool.flush();
    assert_eq!(pool.in_flight_count(), 2);
    pool.force_cancel("a");
    assert_eq!(pool.in_flight_count(), 1);
    assert_eq!(pool.cancelled_in_flight_count(), 1);
    // The deferred client never answers, so polling must change nothing.
    let _ = pool.poll();
    assert_eq!(pool.in_flight_count(), 1);
    let ghosts = pool.cancelled_in_flight_count();
    assert_eq!(ghosts, 1, "ghost persists until response arrives");
}
#[test]
fn force_cancel_allows_re_enqueue_of_in_flight_url() {
    let pool = FetchPool::new(Box::new(DeferredMockClient::new()), 10);
    let dispatch_a = || {
        pool.enqueue(HttpRequest::get("a"), 1.0);
        pool.flush();
    };
    dispatch_a();
    assert_eq!(pool.in_flight_count(), 1);
    pool.force_cancel("a");
    assert_eq!(pool.in_flight_count(), 0);
    // After the force-cancel the same URL must be accepted and dispatched again.
    dispatch_a();
    assert_eq!(pool.in_flight_count(), 1);
}
#[test]
fn reenqueue_after_force_cancel_with_shared_dedup_completes_cleanly() {
    let inner = RecordingClient::default();
    let pool = FetchPool::new(
        Box::new(SharedHttpClient::new(Box::new(inner.clone()))),
        2,
    );
    let expected_sends = vec!["a".to_string(), "b".to_string()];

    pool.enqueue(HttpRequest::get("a"), 1.0);
    pool.enqueue(HttpRequest::get("b"), 2.0);
    pool.flush();
    assert_eq!(pool.in_flight_count(), 2);
    assert_eq!(inner.sent_urls(), expected_sends);

    pool.force_cancel("a");
    assert_eq!(pool.in_flight_count(), 1);
    assert_eq!(pool.cancelled_in_flight_count(), 1);

    // Re-enqueue "a" while its first request is still outstanding; the inner
    // client must not see a second send for it.
    pool.enqueue(HttpRequest::get("a"), 0.5);
    pool.flush();
    assert_eq!(pool.in_flight_count(), 2);
    assert_eq!(inner.sent_urls(), expected_sends);
    let ghosts = pool.cancelled_in_flight_count();
    assert_eq!(ghosts, 0, "re-enqueue should clear ghost marker");

    inner.complete("a");
    let completed = pool.poll();
    let urls: Vec<&str> = completed.iter().map(|(url, _)| url.as_str()).collect();
    assert_eq!(urls, ["a"]);
    let active = pool.in_flight_count();
    assert_eq!(active, 1, "completion should retire the re-enqueued URL");
}
}