quilt_rs/io/remote/
client.rs1use std::collections::HashMap;
2use std::time::Duration;
3
4use async_trait::async_trait;
5use reqwest::header::HeaderMap;
6use reqwest_middleware::ClientBuilder;
7use reqwest_middleware::ClientWithMiddleware;
8use reqwest_retry::DefaultRetryableStrategy;
9use reqwest_retry::RetryTransientMiddleware;
10use reqwest_retry::Retryable;
11use reqwest_retry::RetryableStrategy;
12use reqwest_retry::policies::ExponentialBackoff;
13use serde::de::DeserializeOwned;
14use tracing::warn;
15
16use crate::Error;
17use crate::Res;
18
/// How many times a request classified as transiently failed is retried (exponential backoff).
const MAX_RETRIES: u32 = 2;
/// Time allowed to establish a TCP/TLS connection (`ClientBuilder::connect_timeout`).
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);
/// Overall timeout handed to `reqwest::ClientBuilder::timeout`.
const REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
/// How long an idle pooled connection is kept before being closed (`pool_idle_timeout`).
const POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(90);
23
24#[async_trait]
25pub trait HttpClient: Send + Sync {
26 async fn get<T: DeserializeOwned>(&self, url: &str, auth_token: Option<&str>) -> Res<T>;
27 async fn head(&self, url: &str) -> Res<HeaderMap>;
28 async fn post<T: DeserializeOwned>(
29 &self,
30 url: &str,
31 form_data: &HashMap<String, String>,
32 ) -> Res<T>;
33 async fn post_json<T: DeserializeOwned, B: serde::Serialize + Send + Sync>(
34 &self,
35 url: &str,
36 body: &B,
37 ) -> Res<T>;
38}
39
40#[derive(Clone, Debug)]
41pub struct ReqwestClient {
42 client: ClientWithMiddleware,
43}
44
45impl Default for ReqwestClient {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl ReqwestClient {
52 pub fn new() -> Self {
53 let inner = reqwest::Client::builder()
54 .timeout(REQUEST_TIMEOUT)
55 .connect_timeout(CONNECT_TIMEOUT)
56 .pool_idle_timeout(POOL_IDLE_TIMEOUT)
57 .build()
58 .expect("reqwest client build should not fail with default TLS config");
59
60 let retry_policy = ExponentialBackoff::builder().build_with_max_retries(MAX_RETRIES);
61 let retry_middleware =
62 RetryTransientMiddleware::new_with_policy_and_strategy(retry_policy, LoggingStrategy);
63
64 let client = ClientBuilder::new(inner).with(retry_middleware).build();
65
66 Self { client }
67 }
68}
69
70struct LoggingStrategy;
78
79impl RetryableStrategy for LoggingStrategy {
80 fn handle(
81 &self,
82 res: &Result<reqwest::Response, reqwest_middleware::Error>,
83 ) -> Option<Retryable> {
84 let decision = DefaultRetryableStrategy.handle(res);
85 if matches!(decision, Some(Retryable::Transient)) {
86 match res {
87 Ok(resp) => warn!(
88 status = resp.status().as_u16(),
89 url = %resp.url(),
90 "π transient HTTP response β may retry"
91 ),
92 Err(e) => warn!(
93 error = %e,
94 "π transient HTTP error β may retry"
95 ),
96 }
97 }
98 decision
99 }
100}
101
102impl From<reqwest_middleware::Error> for Error {
103 fn from(err: reqwest_middleware::Error) -> Self {
104 match err {
105 reqwest_middleware::Error::Reqwest(e) => Error::Reqwest(e),
106 reqwest_middleware::Error::Middleware(e) => {
111 Error::Io(std::io::Error::other(e.to_string()))
112 }
113 }
114 }
115}
116
/// Maximum number of bytes of an error response body kept by `truncate_for_log`.
const ERROR_BODY_LOG_LIMIT: usize = 500;

/// `User-Agent` header value set on every request issued by [`ReqwestClient`].
///
/// NOTE(review): this impersonates Internet Explorer 6 on Windows XP; presumably deliberate
/// (e.g. an endpoint that rejects non-browser agents) — confirm before changing it.
const USER_AGENT: &str =
    "Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1; .NET CLR 1.0.3705; .NET CLR 1.1.4322)";
124
125async fn ensure_success(response: reqwest::Response) -> Res<reqwest::Response> {
129 if response.status().is_success() {
130 return Ok(response);
131 }
132 let status = response.status();
133 let url = response.url().clone();
134 let err = response
137 .error_for_status_ref()
138 .expect_err("status is non-success");
139 let body = response.text().await.unwrap_or_default();
140 warn!(
141 status = status.as_u16(),
142 url = %url,
143 body = %truncate_for_log(&body),
144 "β HTTP error response"
145 );
146 Err(err.into())
147}
148
149fn truncate_for_log(s: &str) -> String {
150 if s.len() <= ERROR_BODY_LOG_LIMIT {
151 return s.to_string();
152 }
153 let mut end = ERROR_BODY_LOG_LIMIT;
154 while end > 0 && !s.is_char_boundary(end) {
155 end -= 1;
156 }
157 format!("{}β¦[{} bytes total]", &s[..end], s.len())
158}
159
160#[async_trait]
161impl HttpClient for ReqwestClient {
162 async fn get<T: DeserializeOwned>(&self, url: &str, auth_token: Option<&str>) -> Res<T> {
163 let mut request = self.client.get(url).header("User-Agent", USER_AGENT);
164
165 if let Some(token) = auth_token {
166 request = request.bearer_auth(token);
167 }
168
169 let response = ensure_success(request.send().await?).await?;
170 Ok(response.json().await?)
171 }
172
173 async fn head(&self, url: &str) -> Res<HeaderMap> {
176 let response = self
177 .client
178 .head(url)
179 .header("User-Agent", USER_AGENT)
180 .send()
181 .await?;
182 Ok(response.headers().clone())
183 }
184
185 async fn post<T: DeserializeOwned>(
186 &self,
187 url: &str,
188 form_data: &HashMap<String, String>,
189 ) -> Res<T> {
190 let response = ensure_success(
191 self.client
192 .post(url)
193 .header("User-Agent", USER_AGENT)
194 .form(form_data)
195 .send()
196 .await?,
197 )
198 .await?;
199 Ok(response.json().await?)
200 }
201
202 async fn post_json<T: DeserializeOwned, B: serde::Serialize + Send + Sync>(
203 &self,
204 url: &str,
205 body: &B,
206 ) -> Res<T> {
207 let response = ensure_success(
208 self.client
209 .post(url)
210 .header("User-Agent", USER_AGENT)
211 .json(body)
212 .send()
213 .await?,
214 )
215 .await?;
216 Ok(response.json().await?)
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use test_log::test;

    use serde::Deserialize;
    use serde::Serialize;

    /// Live smoke test: fetch and decode the public Quilt catalog config.
    ///
    /// NOTE(review): performs a real network request, so it fails when offline.
    #[test(tokio::test)]
    async fn test_get_config() -> Res {
        let client = ReqwestClient::new();

        #[derive(Deserialize, Serialize)]
        struct Config {
            mode: String,
        }

        let response: Config = client
            .get("https://open.quiltdata.com/config.json", None)
            .await?;

        assert_eq!(response.mode, "OPEN");

        Ok(())
    }

    #[test]
    fn truncate_short_body_is_unchanged() {
        assert_eq!(truncate_for_log("hello"), "hello");
    }

    #[test]
    fn truncate_long_body_is_cut_with_total_length() {
        let s = "x".repeat(ERROR_BODY_LOG_LIMIT + 10);
        let got = truncate_for_log(&s);
        assert!(got.starts_with(&"x".repeat(ERROR_BODY_LOG_LIMIT)));
        assert!(got.contains(&format!("[{} bytes total]", s.len())));
    }

    /// A 4-byte character straddling `ERROR_BODY_LOG_LIMIT` must be dropped whole, not split.
    #[test]
    fn truncate_never_splits_multibyte_chars() {
        // The prefix stops two bytes short of the limit, so the 4-byte U+1F525 occupies bytes
        // LIMIT-2..LIMIT+2 and the naive cut point lands inside it. The `\u{..}` escape keeps
        // the fixture independent of how this source file happens to be encoded.
        let prefix = "a".repeat(ERROR_BODY_LOG_LIMIT - 2);
        let s = format!("{prefix}\u{1F525}trailing");
        assert!(
            !s.is_char_boundary(ERROR_BODY_LOG_LIMIT),
            "fixture must place a multibyte char across the limit"
        );
        let got = truncate_for_log(&s);
        assert!(got.starts_with(&prefix));
        assert!(!got.contains('\u{1F525}'));
        assert!(got.contains(&format!("[{} bytes total]", s.len())));
    }
}