use std::collections::HashMap;
use std::time::Duration;
use async_trait::async_trait;
use reqwest::header::HeaderMap;
use reqwest_middleware::ClientBuilder;
use reqwest_middleware::ClientWithMiddleware;
use reqwest_retry::DefaultRetryableStrategy;
use reqwest_retry::RetryTransientMiddleware;
use reqwest_retry::Retryable;
use reqwest_retry::RetryableStrategy;
use reqwest_retry::policies::ExponentialBackoff;
use serde::de::DeserializeOwned;
use tracing::warn;
use crate::Error;
use crate::Res;
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
const POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(90);
const MAX_RETRIES: u32 = 2;
#[async_trait]
pub trait HttpClient: Send + Sync {
async fn get<T: DeserializeOwned>(&self, url: &str, auth_token: Option<&str>) -> Res<T>;
async fn head(&self, url: &str) -> Res<HeaderMap>;
async fn post<T: DeserializeOwned>(
&self,
url: &str,
form_data: &HashMap<String, String>,
) -> Res<T>;
async fn post_json<T: DeserializeOwned, B: serde::Serialize + Send + Sync>(
&self,
url: &str,
body: &B,
) -> Res<T>;
}
#[derive(Clone, Debug)]
pub struct ReqwestClient {
client: ClientWithMiddleware,
}
impl Default for ReqwestClient {
fn default() -> Self {
Self::new()
}
}
impl ReqwestClient {
pub fn new() -> Self {
let inner = reqwest::Client::builder()
.timeout(REQUEST_TIMEOUT)
.connect_timeout(CONNECT_TIMEOUT)
.pool_idle_timeout(POOL_IDLE_TIMEOUT)
.build()
.expect("reqwest client build should not fail with default TLS config");
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(MAX_RETRIES);
let retry_middleware =
RetryTransientMiddleware::new_with_policy_and_strategy(retry_policy, LoggingStrategy);
let client = ClientBuilder::new(inner).with(retry_middleware).build();
Self { client }
}
}
struct LoggingStrategy;
impl RetryableStrategy for LoggingStrategy {
fn handle(
&self,
res: &Result<reqwest::Response, reqwest_middleware::Error>,
) -> Option<Retryable> {
let decision = DefaultRetryableStrategy.handle(res);
if matches!(decision, Some(Retryable::Transient)) {
match res {
Ok(resp) => warn!(
status = resp.status().as_u16(),
url = %resp.url(),
"🔁 transient HTTP response — may retry"
),
Err(e) => warn!(
error = %e,
"🔁 transient HTTP error — may retry"
),
}
}
decision
}
}
impl From<reqwest_middleware::Error> for Error {
fn from(err: reqwest_middleware::Error) -> Self {
match err {
reqwest_middleware::Error::Reqwest(e) => Error::Reqwest(e),
reqwest_middleware::Error::Middleware(e) => {
Error::Io(std::io::Error::other(e.to_string()))
}
}
}
}
const USER_AGENT: &str =
"Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; .NET CLR 1.0.3705; .NET CLR 1.1.4322)";
const ERROR_BODY_LOG_LIMIT: usize = 500;
async fn ensure_success(response: reqwest::Response) -> Res<reqwest::Response> {
if response.status().is_success() {
return Ok(response);
}
let status = response.status();
let url = response.url().clone();
let err = response
.error_for_status_ref()
.expect_err("status is non-success");
let body = response.text().await.unwrap_or_default();
warn!(
status = status.as_u16(),
url = %url,
body = %truncate_for_log(&body),
"❌ HTTP error response"
);
Err(err.into())
}
fn truncate_for_log(s: &str) -> String {
if s.len() <= ERROR_BODY_LOG_LIMIT {
return s.to_string();
}
let mut end = ERROR_BODY_LOG_LIMIT;
while end > 0 && !s.is_char_boundary(end) {
end -= 1;
}
format!("{}…[{} bytes total]", &s[..end], s.len())
}
#[async_trait]
impl HttpClient for ReqwestClient {
async fn get<T: DeserializeOwned>(&self, url: &str, auth_token: Option<&str>) -> Res<T> {
let mut request = self.client.get(url).header("User-Agent", USER_AGENT);
if let Some(token) = auth_token {
request = request.bearer_auth(token);
}
let response = ensure_success(request.send().await?).await?;
Ok(response.json().await?)
}
async fn head(&self, url: &str) -> Res<HeaderMap> {
let response = self
.client
.head(url)
.header("User-Agent", USER_AGENT)
.send()
.await?;
Ok(response.headers().clone())
}
async fn post<T: DeserializeOwned>(
&self,
url: &str,
form_data: &HashMap<String, String>,
) -> Res<T> {
let response = ensure_success(
self.client
.post(url)
.header("User-Agent", USER_AGENT)
.form(form_data)
.send()
.await?,
)
.await?;
Ok(response.json().await?)
}
async fn post_json<T: DeserializeOwned, B: serde::Serialize + Send + Sync>(
&self,
url: &str,
body: &B,
) -> Res<T> {
let response = ensure_success(
self.client
.post(url)
.header("User-Agent", USER_AGENT)
.json(body)
.send()
.await?,
)
.await?;
Ok(response.json().await?)
}
}
#[cfg(test)]
mod tests {
use super::*;
use test_log::test;
use serde::Deserialize;
use serde::Serialize;
#[test(tokio::test)]
async fn test_get_config() -> Res {
let client = ReqwestClient::new();
#[derive(Deserialize, Serialize)]
struct Config {
mode: String,
}
let response: Config = client
.get("https://open.quiltdata.com/config.json", None)
.await?;
assert_eq!(response.mode, "OPEN");
Ok(())
}
#[test]
fn truncate_short_body_is_unchanged() {
assert_eq!(truncate_for_log("hello"), "hello");
}
#[test]
fn truncate_long_body_is_cut_with_total_length() {
let s = "x".repeat(ERROR_BODY_LOG_LIMIT + 10);
let got = truncate_for_log(&s);
assert!(got.starts_with(&"x".repeat(ERROR_BODY_LOG_LIMIT)));
assert!(got.contains(&format!("[{} bytes total]", s.len())));
}
#[test]
fn truncate_never_splits_multibyte_chars() {
let prefix = "a".repeat(ERROR_BODY_LOG_LIMIT - 2);
let s = format!("{prefix}💥trailing");
let got = truncate_for_log(&s);
assert!(got.contains(&prefix));
assert!(got.contains("bytes total"));
}
}