use super::{HttpClient, HttpRequest, HttpResponse};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
/// Identifier of one `SharedHttpClient` handle: the handle returned by `new` is `0`, and
/// each clone takes the next unused value.
type SubscriberId = u64;
/// Bookkeeping for one URL whose request has been forwarded to the inner client but whose
/// completion has not yet been collected by `poll`.
struct InflightEntry {
    // Handles that should receive the response; `send` adds each id at most once.
    subscribers: Vec<SubscriberId>,
    // Same value as this entry's key in `SharedState::inflight`. Nothing in this file reads
    // it, hence the `dead_code` allow.
    #[allow(dead_code)]
    url: String,
}
/// State shared, behind a single mutex, by every handle cloned from one `SharedHttpClient`.
struct SharedState {
    // Wrapped client: requests are forwarded to it and completions are polled from it.
    client: Box<dyn HttpClient>,
    // Id handed to the next clone; it only ever increments, so ids are not reused.
    next_subscriber_id: SubscriberId,
    // Requests forwarded to `client` whose completion has not been polled yet, keyed by URL.
    inflight: HashMap<String, InflightEntry>,
    // Per-handle queue of (url, result) pairs waiting to be returned by that handle's `poll`.
    // The nested type is spelled out inline rather than behind an alias; the allow silences
    // clippy's complexity lint for it.
    #[allow(clippy::type_complexity)]
    outboxes: HashMap<SubscriberId, VecDeque<(String, Result<HttpResponse, String>)>>,
}
/// A cloneable `HttpClient` that deduplicates concurrent requests for the same URL.
///
/// All clones share one inner client. While a URL is in flight, further `send`s for it from
/// any clone are not forwarded to the inner client. When the response arrives it is copied
/// into the outbox of every clone that asked for it. Each clone then receives its copy from
/// its own `poll`.
pub struct SharedHttpClient {
    // Identifies this handle's outbox and its entries in in-flight subscriber lists.
    subscriber_id: SubscriberId,
    // State shared with every clone of this handle.
    state: Arc<Mutex<SharedState>>,
}
impl SharedHttpClient {
pub fn new(client: Box<dyn HttpClient>) -> Self {
let mut outboxes = HashMap::new();
outboxes.insert(0, VecDeque::new());
Self {
subscriber_id: 0,
state: Arc::new(Mutex::new(SharedState {
client,
next_subscriber_id: 1,
inflight: HashMap::new(),
outboxes,
})),
}
}
}
impl Clone for SharedHttpClient {
fn clone(&self) -> Self {
let mut state = self.state.lock().expect("SharedHttpClient lock poisoned");
let id = state.next_subscriber_id;
state.next_subscriber_id += 1;
state.outboxes.insert(id, VecDeque::new());
Self {
subscriber_id: id,
state: Arc::clone(&self.state),
}
}
}
impl std::fmt::Debug for SharedHttpClient {
    /// Shows this handle's id plus the number of in-flight URLs and registered outboxes in
    /// the shared state. Both counts are reported as `0` if the lock is poisoned.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let snapshot = match self.state.lock() {
            Ok(guard) => (guard.inflight.len(), guard.outboxes.len()),
            Err(_) => (0, 0),
        };
        let (inflight_urls, total_subscribers) = snapshot;
        let mut out = f.debug_struct("SharedHttpClient");
        out.field("subscriber_id", &self.subscriber_id);
        out.field("inflight_urls", &inflight_urls);
        out.field("total_subscribers", &total_subscribers);
        out.finish()
    }
}
impl HttpClient for SharedHttpClient {
    /// Forwards `request` to the inner client unless a request for the same URL is already
    /// in flight. In that case this handle is only added to the URL's subscriber list.
    ///
    /// Deduplication is keyed on `request.url` alone: the rest of a deduplicated request is
    /// discarded. A handle that sends the same URL again while it is in flight stays
    /// subscribed once, so it receives a single response. If the shared lock is poisoned,
    /// the request is silently dropped.
    fn send(&self, request: HttpRequest) {
        let mut state = match self.state.lock() {
            Ok(s) => s,
            Err(_) => return,
        };
        if let Some(entry) = state.inflight.get_mut(&request.url) {
            if !entry.subscribers.contains(&self.subscriber_id) {
                entry.subscribers.push(self.subscriber_id);
            }
            return;
        }
        let url = request.url.clone();
        state.inflight.insert(
            url.clone(),
            InflightEntry {
                subscribers: vec![self.subscriber_id],
                url,
            },
        );
        state.client.send(request);
    }

    /// Polls the inner client and queues each completed response for every handle
    /// subscribed to its URL. Then drains and returns this handle's queue.
    ///
    /// Completions for URLs without an in-flight entry are discarded. Subscribers whose
    /// outbox no longer exists are skipped. Returns an empty `Vec` if the shared lock is
    /// poisoned.
    fn poll(&self) -> Vec<(String, Result<HttpResponse, String>)> {
        let mut state = match self.state.lock() {
            Ok(s) => s,
            Err(_) => return Vec::new(),
        };
        let completed = state.client.poll();
        for (url, result) in completed {
            if let Some(entry) = state.inflight.remove(&url) {
                // Every subscriber except the last gets a deep copy. The last one takes the
                // original, so an N-way fan-out costs N-1 clones. Copies are made only for
                // subscribers that still have an outbox.
                if let Some((&last, rest)) = entry.subscribers.split_last() {
                    for sub_id in rest {
                        if let Some(outbox) = state.outboxes.get_mut(sub_id) {
                            outbox.push_back((url.clone(), clone_result(&result)));
                        }
                    }
                    if let Some(outbox) = state.outboxes.get_mut(&last) {
                        outbox.push_back((url, result));
                    }
                }
            }
        }
        match state.outboxes.get_mut(&self.subscriber_id) {
            Some(outbox) => outbox.drain(..).collect(),
            None => Vec::new(),
        }
    }
}
/// Deep-copies a completion so that it can be delivered to more than one subscriber.
///
/// Needed because `HttpResponse` is rebuilt field by field here rather than via `Clone`.
fn clone_result(result: &Result<HttpResponse, String>) -> Result<HttpResponse, String> {
    result
        .as_ref()
        .map(|resp| HttpResponse {
            status: resp.status,
            body: resp.body.clone(),
            headers: resp.headers.clone(),
        })
        .map_err(String::clone)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scripted inner client. It records the URL of every request it is given and returns
    /// whatever responses the test has queued, in queue order.
    #[derive(Default)]
    struct MockHttpClient {
        sent: Mutex<Vec<String>>,
        responses: Mutex<Vec<(String, Result<HttpResponse, String>)>>,
    }

    impl MockHttpClient {
        /// URLs forwarded to this client so far, oldest first.
        fn sent_urls(&self) -> Vec<String> {
            self.sent.lock().unwrap().to_vec()
        }

        /// Queues a successful, header-less response for `url`.
        fn inject_response(&self, url: &str, status: u16, body: &[u8]) {
            let response = HttpResponse {
                status,
                body: body.to_vec(),
                headers: Vec::new(),
            };
            self.responses
                .lock()
                .unwrap()
                .push((url.to_owned(), Ok(response)));
        }

        /// Queues a failed response for `url`.
        fn inject_error(&self, url: &str, message: &str) {
            self.responses
                .lock()
                .unwrap()
                .push((url.to_owned(), Err(message.to_owned())));
        }
    }

    impl HttpClient for MockHttpClient {
        fn send(&self, request: HttpRequest) {
            self.sent.lock().unwrap().push(request.url);
        }

        fn poll(&self) -> Vec<(String, Result<HttpResponse, String>)> {
            std::mem::take(&mut *self.responses.lock().unwrap())
        }
    }

    // `SharedHttpClient::new` takes ownership of a boxed client. The tests box one `Arc`
    // and keep another to inspect the mock afterwards.
    impl HttpClient for Arc<MockHttpClient> {
        fn send(&self, request: HttpRequest) {
            <MockHttpClient as HttpClient>::send(self, request);
        }

        fn poll(&self) -> Vec<(String, Result<HttpResponse, String>)> {
            <MockHttpClient as HttpClient>::poll(self)
        }
    }

    /// A fresh mock plus a `SharedHttpClient` wrapping it.
    fn fixture() -> (Arc<MockHttpClient>, SharedHttpClient) {
        let mock = Arc::new(MockHttpClient::default());
        let inner: Box<dyn HttpClient> = Box::new(Arc::clone(&mock));
        (mock, SharedHttpClient::new(inner))
    }

    const TILE: &str = "https://tiles.example.com/5/10/12.pbf";

    #[test]
    fn dedup_same_url_across_subscribers() {
        let (mock, shared) = fixture();
        let sub_a = shared.clone();
        let sub_b = shared.clone();
        sub_a.send(HttpRequest::get(TILE));
        sub_b.send(HttpRequest::get(TILE));
        assert_eq!(mock.sent_urls().len(), 1);
    }

    #[test]
    fn different_urls_are_not_deduped() {
        let (mock, shared) = fixture();
        let sub_a = shared.clone();
        sub_a.send(HttpRequest::get("https://example.com/a.png"));
        sub_a.send(HttpRequest::get("https://example.com/b.png"));
        assert_eq!(mock.sent_urls().len(), 2);
    }

    #[test]
    fn response_fanned_out_to_all_subscribers() {
        let (mock, shared) = fixture();
        let sub_a = shared.clone();
        let sub_b = shared.clone();
        sub_a.send(HttpRequest::get(TILE));
        sub_b.send(HttpRequest::get(TILE));
        mock.inject_response(TILE, 200, b"tile-data");

        let results_a = sub_a.poll();
        assert_eq!(results_a.len(), 1);
        assert_eq!(results_a[0].0, TILE);
        assert_eq!(results_a[0].1.as_ref().unwrap().body, b"tile-data");

        let results_b = sub_b.poll();
        assert_eq!(results_b.len(), 1);
        assert_eq!(results_b[0].0, TILE);
        assert_eq!(results_b[0].1.as_ref().unwrap().body, b"tile-data");
    }

    #[test]
    fn subscriber_only_sees_own_responses() {
        let (mock, shared) = fixture();
        let sub_a = shared.clone();
        let sub_b = shared.clone();
        sub_a.send(HttpRequest::get("https://example.com/a.png"));
        sub_b.send(HttpRequest::get("https://example.com/b.png"));
        mock.inject_response("https://example.com/a.png", 200, b"data-a");
        mock.inject_response("https://example.com/b.png", 200, b"data-b");

        let results_a = sub_a.poll();
        assert_eq!(results_a.len(), 1);
        assert_eq!(results_a[0].0, "https://example.com/a.png");

        let results_b = sub_b.poll();
        assert_eq!(results_b.len(), 1);
        assert_eq!(results_b[0].0, "https://example.com/b.png");
    }

    #[test]
    fn error_response_fanned_out() {
        let (mock, shared) = fixture();
        let sub_a = shared.clone();
        let sub_b = shared.clone();
        let url = "https://tiles.example.com/err";
        sub_a.send(HttpRequest::get(url));
        sub_b.send(HttpRequest::get(url));
        mock.inject_error(url, "connection refused");

        let results_a = sub_a.poll();
        assert_eq!(results_a.len(), 1);
        assert!(results_a[0].1.is_err());

        let results_b = sub_b.poll();
        assert_eq!(results_b.len(), 1);
        assert!(results_b[0].1.is_err());
    }

    #[test]
    fn second_request_after_completion_issues_new_fetch() {
        let (mock, shared) = fixture();
        let sub_a = shared.clone();
        sub_a.send(HttpRequest::get(TILE));
        mock.inject_response(TILE, 200, b"first");
        let _ = sub_a.poll();
        sub_a.send(HttpRequest::get(TILE));
        assert_eq!(mock.sent_urls().len(), 2);
    }
}