Skip to main content

better_fetch/
client.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use indexmap::IndexMap;
6use tokio::sync::Semaphore;
7
8use http::Method;
9use reqwest::Client as ReqwestClient;
10use url::Url;
11
12use crate::auth::Auth;
13use crate::backend::{HttpBackend, HttpBody, HttpRequest, ReqwestBackend};
14use crate::cancel::execute_or_cancel;
15use crate::endpoint::{Endpoint, EndpointRequestBuilder};
16use crate::error::Error;
17use crate::hooks::{ErrorContext, Hooks, RequestContext, ResponseContext, SuccessContext};
18use crate::plugin::{PluginRegistry, PreparedRequest};
19use crate::request::RequestBuilder;
20use crate::response::Response;
21use crate::retry::{sleep_or_cancel, RetryPolicy};
22use crate::url_build::build_url;
23use crate::Result;
24
25#[cfg(feature = "tower")]
26use crate::backend::HttpResponse;
27
28#[cfg(feature = "json")]
29use crate::json_parser::JsonParserFn;
30
31#[cfg(feature = "schema")]
32use crate::schema::SchemaRegistry;
33
34fn body_for_context(body: &HttpBody) -> Option<bytes::Bytes> {
35    match body {
36        HttpBody::Empty => None,
37        HttpBody::Bytes(b) => Some(b.clone()),
38    }
39}
40
41/// Shared client configuration.
42#[derive(Clone)]
43pub struct ClientConfig {
44    pub base_url: Url,
45    pub timeout: Option<Duration>,
46    pub retry: Option<RetryPolicy>,
47    pub auth: Option<Auth>,
48    pub default_headers: http::HeaderMap,
49    pub hooks: Hooks,
50    pub(crate) merged_hooks: Hooks,
51    pub plugins: Arc<PluginRegistry>,
52    /// Limits concurrent in-flight requests for this client (including retries).
53    ///
54    /// This is separate from Tower's [`ConcurrencyLimitLayer`](crate::tower::stack::ConcurrencyLimitLayer):
55    /// the client semaphore applies to the full request lifecycle (hooks + retries), while Tower
56    /// limits only transport-layer concurrency. Avoid stacking both without accounting for that.
57    pub max_in_flight: Option<Arc<Semaphore>>,
58    #[cfg(feature = "schema")]
59    pub schema_registry: Option<Arc<SchemaRegistry>>,
60    #[cfg(feature = "json")]
61    pub json_parser: Option<JsonParserFn>,
62}
63
64/// Typed HTTP client built on reqwest.
65#[derive(Clone)]
66pub struct Client {
67    config: Arc<ClientConfig>,
68    backend: Arc<dyn HttpBackend>,
69}
70
71impl Client {
72    pub fn new(base_url: impl AsRef<str>) -> Result<Self> {
73        ClientBuilder::new().base_url(base_url)?.build()
74    }
75
76    pub fn builder() -> ClientBuilder {
77        ClientBuilder::new()
78    }
79
80    /// Builds a client with a custom reqwest instance. [`ClientBuilder::base_url`] is required.
81    pub fn with_http_client(reqwest_client: ReqwestClient, base_url: impl AsRef<str>) -> Result<Self> {
82        ClientBuilder::new()
83            .reqwest_client(reqwest_client)
84            .base_url(base_url)?
85            .build()
86    }
87
88    /// Start a typed request for [`Endpoint`] `E`.
89    pub fn call<E: Endpoint>(&self) -> EndpointRequestBuilder<'_, E> {
90        EndpointRequestBuilder::new(self.request(E::METHOD, E::PATH))
91    }
92
93    pub fn config(&self) -> &ClientConfig {
94        &self.config
95    }
96
97    pub fn get(&self, path: impl Into<String>) -> RequestBuilder<'_> {
98        self.request(Method::GET, path)
99    }
100
101    pub fn post(&self, path: impl Into<String>) -> RequestBuilder<'_> {
102        self.request(Method::POST, path)
103    }
104
105    pub fn put(&self, path: impl Into<String>) -> RequestBuilder<'_> {
106        self.request(Method::PUT, path)
107    }
108
109    pub fn patch(&self, path: impl Into<String>) -> RequestBuilder<'_> {
110        self.request(Method::PATCH, path)
111    }
112
113    pub fn delete(&self, path: impl Into<String>) -> RequestBuilder<'_> {
114        self.request(Method::DELETE, path)
115    }
116
117    pub fn head(&self, path: impl Into<String>) -> RequestBuilder<'_> {
118        self.request(Method::HEAD, path)
119    }
120
121    pub fn request(&self, method: Method, path: impl Into<String>) -> RequestBuilder<'_> {
122        RequestBuilder {
123            client: self,
124            method,
125            path: path.into(),
126            params: HashMap::new(),
127            query: IndexMap::new(),
128            headers: self.config.default_headers.clone(),
129            body: HttpBody::Empty,
130            #[cfg(feature = "multipart")]
131            multipart: None,
132            timeout: self.config.timeout,
133            retry: self.config.retry.clone(),
134            auth: self.config.auth.clone(),
135            cancellation: None,
136            throw_on_error: false,
137            #[cfg(feature = "json")]
138            json_parser: None,
139            #[cfg(feature = "validate")]
140            validate_response: true,
141        }
142    }
143
144    pub(crate) async fn execute(&self, builder: RequestBuilder<'_>) -> Result<Response> {
145        #[cfg(feature = "json")]
146        let json_parser = builder
147            .json_parser
148            .clone()
149            .or_else(|| self.config.json_parser.clone());
150
151        let built = build_url(
152            &self.config.base_url,
153            &builder.path,
154            &builder.params,
155            &builder.query,
156        )?;
157
158        let mut method = builder.method;
159        if let Some(override_method) = built.method_override {
160            method = override_method;
161        }
162
163        #[cfg(feature = "schema")]
164        if let Some(registry) = &self.config.schema_registry {
165            registry.ensure_route(&builder.path, &method)?;
166        }
167
168        let mut url = built.url;
169        let mut headers = builder.headers;
170        let auth = builder.auth.or_else(|| self.config.auth.clone());
171        if let Some(auth) = auth {
172            auth.apply(&mut headers).await?;
173        }
174
175        let mut prepared = PreparedRequest {
176            url: url.clone(),
177            path: builder.path.clone(),
178            method: method.clone(),
179            headers: headers.clone(),
180        };
181        self.config.plugins.run_init_all(&mut prepared).await?;
182        url = prepared.url;
183        headers = prepared.headers;
184        method = prepared.method;
185
186        let mut req_ctx = RequestContext {
187            url: url.clone(),
188            method: method.clone(),
189            headers: headers.clone(),
190            body: body_for_context(&builder.body),
191            retry_attempt: 0,
192        };
193
194        let merged_hooks = &self.config.merged_hooks;
195        req_ctx = merged_hooks.run_on_request(req_ctx).await?;
196        url = req_ctx.url.clone();
197        headers = req_ctx.headers.clone();
198        method = req_ctx.method.clone();
199
200        let timeout = builder.timeout;
201        let retry_policy = builder.retry.or_else(|| self.config.retry.clone());
202        let throw_on_error = builder.throw_on_error;
203        let cancel = builder.cancellation;
204
205        let backend = self.backend.clone();
206
207        let _in_flight_permit = match &self.config.max_in_flight {
208            Some(sem) => Some(
209                sem.acquire()
210                    .await
211                    .map_err(|_| Error::Other("max_in_flight semaphore closed".into()))?,
212            ),
213            None => None,
214        };
215
216        let mut attempt = 0u32;
217        let max_attempts = retry_policy.as_ref().map(|p| p.max_attempts()).unwrap_or(0);
218
219        let request_body = builder.body;
220        #[cfg(feature = "multipart")]
221        let mut multipart_body = builder.multipart;
222        #[cfg(feature = "multipart")]
223        let had_multipart = multipart_body.is_some();
224
225        let cancel_ref = cancel.as_ref();
226
227        loop {
228            req_ctx.retry_attempt = attempt;
229
230            #[cfg(feature = "multipart")]
231            if attempt > 0 && had_multipart {
232                return Err(Error::Other(
233                    "automatic retry is not supported with multipart request bodies".into(),
234                ));
235            }
236
237            let http_req = HttpRequest {
238                method: method.clone(),
239                url: url.clone(),
240                headers: headers.clone(),
241                body: request_body.clone(),
242                timeout,
243                cancellation: cancel.clone(),
244                #[cfg(feature = "multipart")]
245                multipart: multipart_body.take(),
246            };
247            let request_url = http_req.url.clone();
248
249            let result = execute_or_cancel(cancel_ref, backend.execute(http_req)).await;
250
251            match result {
252                Ok(http_res) => {
253                    let response = Response::new(
254                        http_res.status,
255                        http_res.headers.clone(),
256                        http_res.body,
257                        Some(request_url.clone()),
258                        #[cfg(feature = "json")]
259                        json_parser.clone(),
260                    );
261
262                    let response = merged_hooks
263                        .run_on_response(ResponseContext {
264                            request: req_ctx.clone(),
265                            response,
266                        })
267                        .await?;
268
269                    let should_retry = retry_policy
270                        .as_ref()
271                        .map(|p| p.should_retry_response(&response, false))
272                        .unwrap_or(false);
273
274                    if should_retry && attempt < max_attempts {
275                        merged_hooks
276                            .run_on_retry(ResponseContext {
277                                request: req_ctx.clone(),
278                                response: response.clone(),
279                            })
280                            .await;
281                        let delay = retry_policy
282                            .as_ref()
283                            .map(|p| p.delay_after_response(attempt, response.headers()))
284                            .unwrap_or(Duration::from_secs(1));
285                        attempt += 1;
286                        sleep_or_cancel(delay, cancel_ref).await?;
287                        continue;
288                    }
289
290                    if response.is_success() {
291                        merged_hooks
292                            .run_on_success(SuccessContext {
293                                request: req_ctx.clone(),
294                                response: response.clone(),
295                            })
296                            .await;
297                        return Ok(response);
298                    }
299
300                    let status = response.status();
301                    let http_err = Error::http_with_status_text(
302                        status,
303                        status.canonical_reason().unwrap_or("request failed"),
304                        status.canonical_reason().unwrap_or("request failed"),
305                        Some(response.bytes().clone()),
306                    );
307                    merged_hooks
308                        .run_on_error(ErrorContext {
309                            request: req_ctx.clone(),
310                            response: Some(response.clone()),
311                            error: http_err.clone(),
312                        })
313                        .await;
314
315                    if throw_on_error {
316                        return Err(http_err);
317                    }
318                    return Ok(response);
319                }
320                Err(err) => {
321                    if err.is_cancelled() {
322                        merged_hooks
323                            .run_on_error(ErrorContext {
324                                request: req_ctx.clone(),
325                                response: None,
326                                error: err.clone(),
327                            })
328                            .await;
329                        return Err(err);
330                    }
331
332                    let retry_transport = matches!(&err, Error::Transport(_) | Error::Timeout);
333                    if retry_transport && retry_policy.is_some() && attempt < max_attempts {
334                        merged_hooks
335                            .run_on_retry(ResponseContext {
336                                request: req_ctx.clone(),
337                                response: Response::new(
338                                    http::StatusCode::SERVICE_UNAVAILABLE,
339                                    http::HeaderMap::new(),
340                                    bytes::Bytes::new(),
341                                    Some(request_url.clone()),
342                                    #[cfg(feature = "json")]
343                                    None,
344                                ),
345                            })
346                            .await;
347                        let delay = retry_policy
348                            .as_ref()
349                            .map(|p| p.delay_after_response(attempt, &http::HeaderMap::new()))
350                            .unwrap_or(Duration::from_secs(1));
351                        attempt += 1;
352                        sleep_or_cancel(delay, cancel_ref).await?;
353                        continue;
354                    }
355
356                    merged_hooks
357                        .run_on_error(ErrorContext {
358                            request: req_ctx.clone(),
359                            response: None,
360                            error: err.clone(),
361                        })
362                        .await;
363
364                    if retry_transport && retry_policy.is_some() {
365                        return Err(Error::retry_exhausted(attempt + 1, err));
366                    }
367
368                    return Err(err);
369                }
370            }
371        }
372    }
373}
374
375/// Builder for [`Client`].
376pub struct ClientBuilder {
377    base_url: Option<Url>,
378    timeout: Option<Duration>,
379    retry: Option<RetryPolicy>,
380    auth: Option<Auth>,
381    default_headers: http::HeaderMap,
382    hooks: Hooks,
383    plugins: PluginRegistry,
384    reqwest_client: Option<ReqwestClient>,
385    custom_backend: Option<Arc<dyn HttpBackend>>,
386    max_in_flight: Option<usize>,
387    #[cfg(feature = "schema")]
388    schema_registry: Option<Arc<SchemaRegistry>>,
389    #[cfg(feature = "json")]
390    json_parser: Option<JsonParserFn>,
391}
392
393impl ClientBuilder {
394    pub fn new() -> Self {
395        Self {
396            base_url: None,
397            timeout: None,
398            retry: None,
399            auth: None,
400            default_headers: http::HeaderMap::new(),
401            hooks: Hooks::default(),
402            plugins: PluginRegistry::new(),
403            reqwest_client: None,
404            custom_backend: None,
405            max_in_flight: None,
406            #[cfg(feature = "schema")]
407            schema_registry: None,
408            #[cfg(feature = "json")]
409            json_parser: None,
410        }
411    }
412
413    pub fn base_url(mut self, base_url: impl AsRef<str>) -> Result<Self> {
414        self.base_url = Some(Url::parse(base_url.as_ref()).map_err(Error::InvalidBaseUrl)?);
415        Ok(self)
416    }
417
418    pub fn timeout(mut self, timeout: Duration) -> Self {
419        self.timeout = Some(timeout);
420        self
421    }
422
423    pub fn retry(mut self, policy: RetryPolicy) -> Self {
424        self.retry = Some(policy);
425        self
426    }
427
428    pub fn auth(mut self, auth: Auth) -> Self {
429        self.auth = Some(auth);
430        self
431    }
432
433    pub fn default_header(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
434        let name = http::HeaderName::from_bytes(key.as_ref().as_bytes())
435            .map_err(|e| Error::Other(format!("invalid header name: {e}")))?;
436        let value = http::HeaderValue::from_str(value.as_ref())
437            .map_err(|e| Error::Other(format!("invalid header value: {e}")))?;
438        self.default_headers.insert(name, value);
439        Ok(self)
440    }
441
442    pub fn hooks(mut self, hooks: Hooks) -> Self {
443        self.hooks = hooks;
444        self
445    }
446
447    pub fn plugin<P: crate::plugin::Plugin + 'static>(mut self, plugin: P) -> Self {
448        self.plugins.push(Box::new(plugin));
449        self
450    }
451
452    pub fn reqwest_client(mut self, client: ReqwestClient) -> Self {
453        self.reqwest_client = Some(client);
454        self
455    }
456
457    /// Use a custom HTTP backend (for testing or alternate transports).
458    pub fn backend(mut self, backend: Arc<dyn HttpBackend>) -> Self {
459        self.custom_backend = Some(backend);
460        self
461    }
462
463    /// Limits how many requests this client may have in flight at once (including retries).
464    ///
465    /// Implemented with a tokio semaphore in the core client. This counts the full request
466    /// lifecycle (hooks and retries), not just the transport hop. For wire-level limits only,
467    /// use [`Self::transport_stack`] with Tower's [`ConcurrencyLimitLayer`](crate::tower::stack::ConcurrencyLimitLayer)
468    /// (feature `tower`) instead of—or deliberately alongside—this setting.
469    pub fn max_in_flight(mut self, limit: usize) -> Self {
470        self.max_in_flight = Some(limit);
471        self
472    }
473
474    /// Attach a [`SchemaRegistry`] for strict route validation (feature `schema`).
475    #[cfg(feature = "schema")]
476    pub fn schema_registry(mut self, registry: Arc<SchemaRegistry>) -> Self {
477        self.schema_registry = Some(registry);
478        self
479    }
480
481    /// Use a Tower [`Service`](tower::Service) as the HTTP transport (feature `tower`).
482    #[cfg(feature = "tower")]
483    pub fn http_service<S>(mut self, service: S) -> Self
484    where
485        S: tower::Service<HttpRequest, Response = HttpResponse, Error = Error>
486            + Clone
487            + Send
488            + 'static,
489        S::Future: Send + 'static,
490    {
491        use crate::tower::ServiceBackend;
492
493        self.custom_backend = Some(Arc::new(ServiceBackend::new(service)));
494        self
495    }
496
497    /// Use a boxed Tower transport stack (feature `tower`).
498    #[cfg(feature = "tower")]
499    pub fn http_service_boxed(mut self, service: crate::tower::BoxHttpService) -> Self {
500        use crate::tower::ServiceBackend;
501
502        self.custom_backend = Some(Arc::new(ServiceBackend::from_box(service)));
503        self
504    }
505
506    /// Build a Tower transport stack on top of the configured (or default) reqwest client.
507    ///
508    /// Application hooks and [`RetryPolicy`](crate::RetryPolicy) remain in the core client;
509    /// only wire-level behavior is configured here.
510    #[cfg(feature = "tower")]
511    pub fn transport_stack<F>(mut self, configure: F) -> Self
512    where
513        F: FnOnce(crate::tower::ReqwestHttpService) -> crate::tower::BoxHttpService,
514    {
515        use crate::tower::ServiceBackend;
516
517        let client = self.reqwest_client.clone().unwrap_or_default();
518        let stacked = configure(crate::tower::ReqwestHttpService::new(client));
519        self.custom_backend = Some(Arc::new(ServiceBackend::from_box(stacked)));
520        self
521    }
522
523    /// Sets a custom JSON parser for all responses from this client.
524    #[cfg(feature = "json")]
525    pub fn json_parser<F>(mut self, f: F) -> Self
526    where
527        F: Fn(&bytes::Bytes) -> std::result::Result<serde_json::Value, String>
528            + Send
529            + Sync
530            + 'static,
531    {
532        self.json_parser = Some(crate::json_parser::json_parser(f));
533        self
534    }
535
536    /// Sets a custom JSON parser from an existing [`JsonParserFn`].
537    #[cfg(feature = "json")]
538    pub fn json_parser_fn(mut self, parser: JsonParserFn) -> Self {
539        self.json_parser = Some(parser);
540        self
541    }
542
543    pub fn build(self) -> Result<Client> {
544        let base_url = self
545            .base_url
546            .ok_or(Error::MissingBaseUrl)?;
547
548        let backend: Arc<dyn HttpBackend> = if let Some(b) = self.custom_backend {
549            b
550        } else {
551            let reqwest_client = self.reqwest_client.unwrap_or_default();
552            Arc::new(ReqwestBackend::new(reqwest_client))
553        };
554
555        let plugins = Arc::new(self.plugins);
556        let merged_hooks = self.hooks.clone().merge(plugins.merged_hooks());
557
558        Ok(Client {
559            config: Arc::new(ClientConfig {
560                base_url,
561                timeout: self.timeout,
562                retry: self.retry,
563                auth: self.auth,
564                default_headers: self.default_headers,
565                hooks: self.hooks,
566                merged_hooks,
567                plugins,
568                max_in_flight: self.max_in_flight.map(|n| Arc::new(Semaphore::new(n))),
569                #[cfg(feature = "schema")]
570                schema_registry: self.schema_registry,
571                #[cfg(feature = "json")]
572                json_parser: self.json_parser,
573            }),
574            backend,
575        })
576    }
577}
578
579impl Default for ClientBuilder {
580    fn default() -> Self {
581        Self::new()
582    }
583}