use crate::error::*;
use crate::lock;
#[cfg(debug_assertions)]
use crate::http::recording::{RecordedRequest, RecordedResponse, Recording, RecordingEntry};
use crate::http::types::ResponseMetadata;
use futures::executor::block_on;
use rand::RngExt;
use reqwest::Client as InnerClient;
use reqwest::header::HeaderMap;
use reqwest::{Method, Request, RequestBuilder, Url};
#[cfg(debug_assertions)]
use std::path::{Path, PathBuf};
#[cfg(debug_assertions)]
use std::sync::Mutex;
use std::time::Duration;
use tracing::{debug, error, info};
/// Blocking HTTP client interface with built-in retry/back-off support.
///
/// Implementors provide [`AbstractClient::execute`] and the request-builder
/// constructors; the retry helpers are supplied as default methods.
pub trait AbstractClient {
    /// Sends `request` once and returns the response metadata together with the
    /// full response body.
    fn execute(&self, request: Request) -> Result<(ResponseMetadata, Vec<u8>)>;

    /// Waits for `duration` by calling the supplied `sleep` function.
    ///
    /// Kept as an overridable trait method so implementors can intercept the
    /// back-off waits performed by the retry helpers.
    fn sleep(&self, sleep: fn(Duration), duration: Duration) {
        sleep(duration)
    }

    /// Same as [`AbstractClient::execute_with_retries_custom_sleep`], using
    /// [`std::thread::sleep`] for the back-off waits.
    fn execute_with_retries(
        &self,
        max_retries: usize,
        add_jitter: bool,
        method: Method,
        url: Url,
        headers: Option<&HeaderMap>,
        body: Option<&[u8]>,
    ) -> Result<(ResponseMetadata, Vec<u8>)> {
        self.execute_with_retries_custom_sleep(
            std::thread::sleep,
            max_retries,
            add_jitter,
            method,
            url,
            headers,
            body,
        )
    }

    /// Sends a `method` request to `url`, retrying on server errors and
    /// transport failures.
    ///
    /// At most `max_retries + 1` attempts are made. Before retry number `k`
    /// (1-based) the call waits `100 * 2^(k-1)` ms, plus a random 0..10 ms when
    /// `add_jitter` is set. Waits go through [`AbstractClient::sleep`] with the
    /// given `sleep` function.
    ///
    /// Any response that is not a 5xx (including 4xx) is returned as `Ok`.
    ///
    /// # Errors
    /// - `Error::InvalidArgument` if `max_retries > 58`.
    /// - `Error::HttpRetry` if every attempt got a 5xx or an `Error::Http`.
    /// - Any other error from `execute`, or from reading the response status,
    ///   is returned immediately without retrying.
    fn execute_with_retries_custom_sleep(
        &self,
        sleep: fn(Duration),
        max_retries: usize,
        add_jitter: bool,
        method: Method,
        url: Url,
        headers: Option<&HeaderMap>,
        body: Option<&[u8]>,
    ) -> Result<(ResponseMetadata, Vec<u8>)> {
        // Largest retry count for which the back-off
        // `(1 << (retry - 1)) * BASE_DELAY_MS + jitter` still fits in a u64.
        const MAX_SUPPORTED_RETRIES: usize = 58;
        // Delay before the first retry; doubled for every subsequent retry.
        const BASE_DELAY_MS: u64 = 100;
        // Exclusive upper bound of the random jitter added to each wait.
        const MAX_JITTER_MS: u64 = 10;

        if max_retries > MAX_SUPPORTED_RETRIES {
            return Err(Error::InvalidArgument(
                "max_retries must be <= 58".to_string(),
            ));
        }
        let mut rng = rand::rng();
        for retry in 0..=max_retries {
            // `execute` takes the request by value, so rebuild it for every attempt.
            let mut request = Request::new(method.clone(), url.clone());
            if let Some(headers) = headers {
                *request.headers_mut() = headers.clone();
            }
            if let Some(body) = body {
                *request.body_mut() = Some(body.to_vec().into());
            }
            if retry > 0 {
                let jitter: u64 = if add_jitter {
                    rng.random_range(0..MAX_JITTER_MS)
                } else {
                    0
                };
                // Exponential back-off: 100ms, 200ms, 400ms, ... (plus jitter).
                let wait: u64 = (1_u64 << (retry - 1)) * BASE_DELAY_MS + jitter;
                info!("Sleep for {}ms before retrying {} {}", wait, method, url);
                self.sleep(sleep, Duration::from_millis(wait));
            }
            match self.execute(request) {
                Ok((res_metadata, res_body)) => {
                    let status = res_metadata.get_status()?;
                    if status.is_server_error() {
                        info!("{} {} returned {}, retrying...", method, url, status);
                        continue;
                    }
                    return Ok((res_metadata, res_body));
                }
                Err(Error::Http(e)) => {
                    info!(
                        "{} {} failed with transport error {}, retrying...",
                        method, url, e
                    );
                    continue;
                }
                Err(e) => return Err(e),
            }
        }
        Err(Error::HttpRetry(format!(
            "failed to get a success response after {} retries.",
            max_retries
        )))
    }

    /// Starts a builder for a GET request to `url`.
    fn get(&self, url: Url) -> RequestBuilder;
    /// Starts a builder for a POST request to `url`.
    fn post(&self, url: Url) -> RequestBuilder;
    /// Starts a builder for a PUT request to `url`.
    fn put(&self, url: Url) -> RequestBuilder;
    /// Starts a builder for a PATCH request to `url`.
    fn patch(&self, url: Url) -> RequestBuilder;
    /// Starts a builder for a DELETE request to `url`.
    fn delete(&self, url: Url) -> RequestBuilder;
    /// Starts a builder for a HEAD request to `url`.
    fn head(&self, url: Url) -> RequestBuilder;
}
/// Blocking wrapper around a `reqwest` client.
///
/// In debug builds the client can optionally record every successfully executed
/// request/response pair and write the recording to a file when it is dropped.
pub struct Client {
    /// Underlying `reqwest::Client` that performs the actual requests.
    inner: InnerClient,
    /// Debug builds only: in-memory recording, `Some` only when the client was
    /// created with `new_with_recording`.
    #[cfg(debug_assertions)]
    recording: Option<Mutex<Recording>>,
    /// Debug builds only: path the recording is flushed to when the client is dropped.
    #[cfg(debug_assertions)]
    recording_output: Option<PathBuf>,
}
impl Client {
    /// Creates a client that sends requests without recording them.
    pub fn new() -> Self {
        Client {
            inner: InnerClient::new(),
            #[cfg(debug_assertions)]
            recording: None,
            #[cfg(debug_assertions)]
            recording_output: None,
        }
    }

    /// Creates a client that records every successfully executed
    /// request/response pair.
    ///
    /// The recording is written to `recording_output` when the client is
    /// dropped, unless the thread is panicking. Only available in debug builds.
    #[cfg(debug_assertions)]
    pub fn new_with_recording<P: AsRef<Path>>(recording_output: P) -> Self {
        Client {
            inner: InnerClient::new(),
            recording: Some(Mutex::new(Recording::default())),
            recording_output: Some(recording_output.as_ref().to_path_buf()),
        }
    }

    /// Sends `request` and blocks the current thread until the whole response
    /// body has been read.
    ///
    /// In debug builds the method, URL and resulting status are logged at debug level.
    ///
    /// # Errors
    /// Propagates (via `?`) failures from sending the request or reading the body.
    fn execute_impl(&self, request: Request) -> Result<(ResponseMetadata, Vec<u8>)> {
        #[cfg(debug_assertions)]
        let method = request.method().clone();
        #[cfg(debug_assertions)]
        let url = request.url().clone();
        let res = block_on(self.inner.execute(request))?;
        let metadata = ResponseMetadata::from(&res);
        // `Bytes` derefs to `[u8]`: one bulk copy instead of a per-byte iterator collect.
        let body: Vec<u8> = block_on(res.bytes())?.to_vec();
        #[cfg(debug_assertions)]
        match metadata.get_status() {
            Ok(status) => debug!("{} {} => {}", method, url, status),
            Err(e) => debug!("{} {} => (invalid status: {})", method, url, e),
        }
        Ok((metadata, body))
    }
}

impl Default for Client {
    /// Equivalent to [`Client::new`]: a client that records nothing.
    fn default() -> Self {
        Self::new()
    }
}
impl AbstractClient for Client {
#[cfg(not(debug_assertions))]
fn execute(&self, request: Request) -> Result<(ResponseMetadata, Vec<u8>)> {
self.execute_impl(request)
}
#[cfg(debug_assertions)]
fn execute(&self, request: Request) -> Result<(ResponseMetadata, Vec<u8>)> {
let recorded_req = RecordedRequest::from(&request);
let res = self.execute_impl(request)?;
if let Some(recording) = self.recording.as_ref() {
let recorded_res = RecordedResponse::from(&res);
let mut guard = lock(recording);
guard.0.push_back(RecordingEntry {
req: recorded_req,
res: recorded_res,
});
}
Ok(res)
}
fn get(&self, url: Url) -> RequestBuilder {
self.inner.get(url)
}
fn post(&self, url: Url) -> RequestBuilder {
self.inner.post(url)
}
fn put(&self, url: Url) -> RequestBuilder {
self.inner.put(url)
}
fn patch(&self, url: Url) -> RequestBuilder {
self.inner.patch(url)
}
fn delete(&self, url: Url) -> RequestBuilder {
self.inner.delete(url)
}
fn head(&self, url: Url) -> RequestBuilder {
self.inner.head(url)
}
}
#[cfg(debug_assertions)]
impl Drop for Client {
fn drop(&mut self) {
let (Some(recording), Some(recording_output)) =
(self.recording.as_ref(), self.recording_output.as_ref())
else {
return;
};
if std::thread::panicking() {
return;
}
let guard = lock(&recording);
if let Err(e) = guard.flush(recording_output) {
error!(
"failed to write HTTP client recording to {}: {}",
recording_output.display(),
e
);
} else {
debug!(
"Wrote HTTP client recording to: {}",
recording_output.display()
);
}
}
}