Skip to main content

quilt_rs/io/remote/
client.rs

1use std::collections::HashMap;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use reqwest::header::HeaderMap;
6use reqwest_middleware::ClientBuilder;
7use reqwest_middleware::ClientWithMiddleware;
8use reqwest_retry::DefaultRetryableStrategy;
9use reqwest_retry::RetryTransientMiddleware;
10use reqwest_retry::Retryable;
11use reqwest_retry::RetryableStrategy;
12use reqwest_retry::policies::ExponentialBackoff;
13use serde::de::DeserializeOwned;
14use tracing::warn;
15
16use crate::Error;
17use crate::Res;
18
/// Overall per-request timeout handed to `reqwest::ClientBuilder::timeout`.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// Connection-establishment timeout handed to `reqwest::ClientBuilder::connect_timeout`.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Idle-connection lifetime handed to `reqwest::ClientBuilder::pool_idle_timeout`.
const POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(90);
/// Retry budget handed to the exponential-backoff policy of the retry middleware.
const MAX_RETRIES: u32 = 2;
23
24#[async_trait]
25pub trait HttpClient: Send + Sync {
26    async fn get<T: DeserializeOwned>(&self, url: &str, auth_token: Option<&str>) -> Res<T>;
27    async fn head(&self, url: &str) -> Res<HeaderMap>;
28    async fn post<T: DeserializeOwned>(
29        &self,
30        url: &str,
31        form_data: &HashMap<String, String>,
32    ) -> Res<T>;
33    async fn post_json<T: DeserializeOwned, B: serde::Serialize + Send + Sync>(
34        &self,
35        url: &str,
36        body: &B,
37    ) -> Res<T>;
38}
39
40#[derive(Clone, Debug)]
41pub struct ReqwestClient {
42    client: ClientWithMiddleware,
43}
44
45impl Default for ReqwestClient {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51impl ReqwestClient {
52    pub fn new() -> Self {
53        let inner = reqwest::Client::builder()
54            .timeout(REQUEST_TIMEOUT)
55            .connect_timeout(CONNECT_TIMEOUT)
56            .pool_idle_timeout(POOL_IDLE_TIMEOUT)
57            .build()
58            .expect("reqwest client build should not fail with default TLS config");
59
60        let retry_policy = ExponentialBackoff::builder().build_with_max_retries(MAX_RETRIES);
61        let retry_middleware =
62            RetryTransientMiddleware::new_with_policy_and_strategy(retry_policy, LoggingStrategy);
63
64        let client = ClientBuilder::new(inner).with(retry_middleware).build();
65
66        Self { client }
67    }
68}
69
/// Wraps [`DefaultRetryableStrategy`] with a `warn!` on every attempt the retry
/// middleware classifies as transient. Gives us a flakiness signal in logs
/// without standing up dedicated telemetry.
///
/// Fires on the *final* attempt too — reqwest-retry asks the strategy before
/// checking whether any attempts remain, so "may retry" is honest: retry
/// happens only if the attempt count hasn't been exhausted.
struct LoggingStrategy;
78
79impl RetryableStrategy for LoggingStrategy {
80    fn handle(
81        &self,
82        res: &Result<reqwest::Response, reqwest_middleware::Error>,
83    ) -> Option<Retryable> {
84        let decision = DefaultRetryableStrategy.handle(res);
85        if matches!(decision, Some(Retryable::Transient)) {
86            match res {
87                Ok(resp) => warn!(
88                    status = resp.status().as_u16(),
89                    url = %resp.url(),
90                    "πŸ” transient HTTP response β€” may retry"
91                ),
92                Err(e) => warn!(
93                    error = %e,
94                    "πŸ” transient HTTP error β€” may retry"
95                ),
96            }
97        }
98        decision
99    }
100}
101
102impl From<reqwest_middleware::Error> for Error {
103    fn from(err: reqwest_middleware::Error) -> Self {
104        match err {
105            reqwest_middleware::Error::Reqwest(e) => Error::Reqwest(e),
106            // `Middleware(anyhow::Error)` is only produced if a middleware
107            // layer itself fails (not the HTTP exchange). Our only middleware
108            // is the retry layer, which doesn't surface errors this way; fold
109            // into `Error::Io` so callers don't need a new match arm.
110            reqwest_middleware::Error::Middleware(e) => {
111                Error::Io(std::io::Error::other(e.to_string()))
112            }
113        }
114    }
115}
116
/// Value of the `User-Agent` header attached to every request issued by
/// [`ReqwestClient`].
const USER_AGENT: &str =
    "Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; .NET CLR 1.0.3705; .NET CLR 1.1.4322)";
119
/// Max bytes of response body to include in error log lines. Enough for an
/// RFC 6749 §5.2 error payload (`{"error":"invalid_grant",...}`) or a short
/// server error page, without flooding logs with a full HTML response.
const ERROR_BODY_LOG_LIMIT: usize = 500;
124
125/// On non-2xx responses, reads and logs the status/url/body before returning
126/// the reqwest error. Keeps the response body β€” which `error_for_status`
127/// would otherwise discard β€” available for diagnostics.
128async fn ensure_success(response: reqwest::Response) -> Res<reqwest::Response> {
129    if response.status().is_success() {
130        return Ok(response);
131    }
132    let status = response.status();
133    let url = response.url().clone();
134    // Take the error via the non-consuming variant, then consume the response
135    // for its body.
136    let err = response
137        .error_for_status_ref()
138        .expect_err("status is non-success");
139    let body = response.text().await.unwrap_or_default();
140    warn!(
141        status = status.as_u16(),
142        url = %url,
143        body = %truncate_for_log(&body),
144        "❌ HTTP error response"
145    );
146    Err(err.into())
147}
148
149fn truncate_for_log(s: &str) -> String {
150    if s.len() <= ERROR_BODY_LOG_LIMIT {
151        return s.to_string();
152    }
153    let mut end = ERROR_BODY_LOG_LIMIT;
154    while end > 0 && !s.is_char_boundary(end) {
155        end -= 1;
156    }
157    format!("{}…[{} bytes total]", &s[..end], s.len())
158}
159
160#[async_trait]
161impl HttpClient for ReqwestClient {
162    async fn get<T: DeserializeOwned>(&self, url: &str, auth_token: Option<&str>) -> Res<T> {
163        let mut request = self.client.get(url).header("User-Agent", USER_AGENT);
164
165        if let Some(token) = auth_token {
166            request = request.bearer_auth(token);
167        }
168
169        let response = ensure_success(request.send().await?).await?;
170        Ok(response.json().await?)
171    }
172
173    // TODO: wire through `ensure_success` so non-2xx HEAD responses surface as
174    // errors instead of empty-header `Ok`.
175    async fn head(&self, url: &str) -> Res<HeaderMap> {
176        let response = self
177            .client
178            .head(url)
179            .header("User-Agent", USER_AGENT)
180            .send()
181            .await?;
182        Ok(response.headers().clone())
183    }
184
185    async fn post<T: DeserializeOwned>(
186        &self,
187        url: &str,
188        form_data: &HashMap<String, String>,
189    ) -> Res<T> {
190        let response = ensure_success(
191            self.client
192                .post(url)
193                .header("User-Agent", USER_AGENT)
194                .form(form_data)
195                .send()
196                .await?,
197        )
198        .await?;
199        Ok(response.json().await?)
200    }
201
202    async fn post_json<T: DeserializeOwned, B: serde::Serialize + Send + Sync>(
203        &self,
204        url: &str,
205        body: &B,
206    ) -> Res<T> {
207        let response = ensure_success(
208            self.client
209                .post(url)
210                .header("User-Agent", USER_AGENT)
211                .json(body)
212                .send()
213                .await?,
214        )
215        .await?;
216        Ok(response.json().await?)
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use test_log::test;

    use serde::Deserialize;
    use serde::Serialize;

    // Live-network smoke test: fetches the public catalog config and checks
    // that it deserializes through `HttpClient::get`.
    #[test(tokio::test)]
    async fn test_get_config() -> Res {
        let client = ReqwestClient::new();

        #[derive(Deserialize, Serialize)]
        struct Config {
            mode: String,
        }

        // Fetch the config and deserialize it straight into `Config`.
        let response: Config = client
            .get("https://open.quiltdata.com/config.json", None)
            .await?;

        // The open catalog advertises itself with mode "OPEN".
        assert_eq!(response.mode, "OPEN");

        Ok(())
    }

    #[test]
    fn truncate_short_body_is_unchanged() {
        assert_eq!(truncate_for_log("hello"), "hello");
    }

    #[test]
    fn truncate_long_body_is_cut_with_total_length() {
        let s = "x".repeat(ERROR_BODY_LOG_LIMIT + 10);
        let got = truncate_for_log(&s);
        assert!(got.starts_with(&"x".repeat(ERROR_BODY_LOG_LIMIT)));
        assert!(got.contains(&format!("[{} bytes total]", s.len())));
    }

    // `str` slicing must land on a char boundary — a multi-byte glyph at the
    // cutoff would otherwise panic.
    #[test]
    fn truncate_never_splits_multibyte_chars() {
        // "💥" is 4 bytes; put one straddling the limit.
        let prefix = "a".repeat(ERROR_BODY_LOG_LIMIT - 2);
        let s = format!("{prefix}💥trailing");
        let got = truncate_for_log(&s);
        assert!(got.contains(&prefix));
        assert!(got.contains("bytes total"));
        // The straddling glyph is dropped whole rather than split.
        assert!(!got.contains('💥'));
    }
}