Skip to main content

better_fetch/
client.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::Semaphore;
6
7use http::Method;
8use reqwest::Client as ReqwestClient;
9use url::Url;
10
11use crate::auth::Auth;
12use crate::backend::{HttpBackend, HttpRequest, HttpResponse, ReqwestBackend};
13use crate::error::Error;
14use crate::hooks::{ErrorContext, Hooks, RequestContext, ResponseContext, SuccessContext};
15use crate::plugin::{PluginRegistry, PreparedRequest};
16use crate::request::RequestBuilder;
17use crate::response::Response;
18use crate::retry::{sleep_before_retry, RetryPolicy};
19use crate::url_build::build_url;
20use crate::Result;
21
22#[cfg(feature = "json")]
23use crate::json_parser::JsonParserFn;
24
25#[cfg(feature = "schema")]
26use crate::schema::SchemaRegistry;
27
28/// Shared client configuration.
29#[derive(Clone)]
30pub struct ClientConfig {
31    pub base_url: Url,
32    pub timeout: Option<Duration>,
33    pub retry: Option<RetryPolicy>,
34    pub auth: Option<Auth>,
35    pub default_headers: http::HeaderMap,
36    pub hooks: Hooks,
37    pub plugins: Arc<PluginRegistry>,
38    /// Limits concurrent in-flight requests for this client (core transport guard, no Tower dep).
39    pub max_in_flight: Option<Arc<Semaphore>>,
40    #[cfg(feature = "schema")]
41    pub schema_registry: Option<Arc<SchemaRegistry>>,
42    #[cfg(feature = "json")]
43    pub json_parser: Option<JsonParserFn>,
44}
45
46/// Typed HTTP client built on reqwest.
47#[derive(Clone)]
48pub struct Client {
49    config: Arc<ClientConfig>,
50    backend: Arc<dyn HttpBackend>,
51}
52
53impl Client {
54    pub fn new(base_url: impl AsRef<str>) -> Result<Self> {
55        ClientBuilder::new().base_url(base_url)?.build()
56    }
57
58    pub fn builder() -> ClientBuilder {
59        ClientBuilder::new()
60    }
61
62    pub fn with_http_client(reqwest_client: ReqwestClient) -> Result<Self> {
63        ClientBuilder::new().reqwest_client(reqwest_client).build()
64    }
65
66    pub fn config(&self) -> &ClientConfig {
67        &self.config
68    }
69
70    pub fn get(&self, path: impl Into<String>) -> RequestBuilder<'_> {
71        self.request(Method::GET, path)
72    }
73
74    pub fn post(&self, path: impl Into<String>) -> RequestBuilder<'_> {
75        self.request(Method::POST, path)
76    }
77
78    pub fn put(&self, path: impl Into<String>) -> RequestBuilder<'_> {
79        self.request(Method::PUT, path)
80    }
81
82    pub fn patch(&self, path: impl Into<String>) -> RequestBuilder<'_> {
83        self.request(Method::PATCH, path)
84    }
85
86    pub fn delete(&self, path: impl Into<String>) -> RequestBuilder<'_> {
87        self.request(Method::DELETE, path)
88    }
89
90    pub fn head(&self, path: impl Into<String>) -> RequestBuilder<'_> {
91        self.request(Method::HEAD, path)
92    }
93
94    pub fn request(&self, method: Method, path: impl Into<String>) -> RequestBuilder<'_> {
95        RequestBuilder {
96            client: self,
97            method,
98            path: path.into(),
99            params: HashMap::new(),
100            query: HashMap::new(),
101            headers: self.config.default_headers.clone(),
102            body: None,
103            timeout: self.config.timeout,
104            retry: self.config.retry.clone(),
105            auth: self.config.auth.clone(),
106            #[cfg(feature = "json")]
107            json_parser: None,
108            #[cfg(feature = "validate")]
109            validate_response: true,
110        }
111    }
112
113    pub(crate) async fn execute(&self, builder: RequestBuilder<'_>) -> Result<Response> {
114        #[cfg(feature = "json")]
115        let json_parser = builder
116            .json_parser
117            .clone()
118            .or_else(|| self.config.json_parser.clone());
119        let built = build_url(
120            &self.config.base_url,
121            &builder.path,
122            &builder.params,
123            &builder.query,
124        )?;
125
126        let mut method = builder.method;
127        if let Some(override_method) = built.method_override {
128            method = override_method;
129        }
130
131        #[cfg(feature = "schema")]
132        if let Some(registry) = &self.config.schema_registry {
133            registry.ensure_route(&builder.path, &method)?;
134        }
135
136        let mut url = built.url;
137
138        let mut prepared = PreparedRequest {
139            url: url.clone(),
140            path: builder.path.clone(),
141        };
142        self.config.plugins.run_init_all(&mut prepared).await?;
143        url = prepared.url;
144
145        let mut headers = builder.headers;
146        let auth = builder.auth.or_else(|| self.config.auth.clone());
147        if let Some(auth) = auth {
148            auth.apply(&mut headers).await?;
149        }
150
151        let mut req_ctx = RequestContext {
152            url: url.clone(),
153            method: method.clone(),
154            headers: headers.clone(),
155            body: builder.body.clone(),
156            retry_attempt: 0,
157        };
158
159        let merged_hooks = self
160            .config
161            .hooks
162            .clone()
163            .merge(self.config.plugins.merged_hooks());
164
165        req_ctx = merged_hooks.run_on_request(req_ctx).await?;
166        url = req_ctx.url.clone();
167        headers = req_ctx.headers.clone();
168        method = req_ctx.method.clone();
169
170        let timeout = builder.timeout;
171        let retry_policy = builder.retry.or_else(|| self.config.retry.clone());
172
173        let backend = self.backend.clone();
174        let body = req_ctx.body.clone();
175
176        let _in_flight_permit = match &self.config.max_in_flight {
177            Some(sem) => Some(
178                sem.acquire()
179                    .await
180                    .map_err(|_| Error::Other("max_in_flight semaphore closed".into()))?,
181            ),
182            None => None,
183        };
184
185        let mut attempt = 0u32;
186        let max_attempts = retry_policy.as_ref().map(|p| p.max_attempts()).unwrap_or(0);
187
188        let http_req = HttpRequest {
189            method,
190            url,
191            headers,
192            body,
193            timeout,
194        };
195
196        loop {
197            req_ctx.retry_attempt = attempt;
198
199            let result = backend.execute(http_req.clone()).await;
200
201            match result {
202                Ok(http_res) => {
203                    let response = Response::new(
204                        http_res.status,
205                        http_res.headers,
206                        http_res.body,
207                        Some(http_req.url.clone()),
208                        #[cfg(feature = "json")]
209                        json_parser.clone(),
210                    );
211
212                    let response = merged_hooks
213                        .run_on_response(ResponseContext {
214                            request: req_ctx.clone(),
215                            response,
216                        })
217                        .await?;
218
219                    let should_retry = retry_policy
220                        .as_ref()
221                        .map(|p| p.should_retry_response(&response, false))
222                        .unwrap_or(false);
223
224                    if should_retry && attempt < max_attempts {
225                        merged_hooks
226                            .run_on_retry(ResponseContext {
227                                request: req_ctx.clone(),
228                                response: response.clone(),
229                            })
230                            .await;
231                        let delay = retry_policy
232                            .as_ref()
233                            .map(|p| p.delay_before_attempt(attempt))
234                            .unwrap_or(Duration::from_secs(1));
235                        attempt += 1;
236                        sleep_before_retry(delay).await;
237                        continue;
238                    }
239
240                    if response.is_success() {
241                        merged_hooks
242                            .run_on_success(SuccessContext {
243                                request: req_ctx.clone(),
244                                response: response.clone(),
245                            })
246                            .await;
247                    } else {
248                        let status = response.status();
249                        merged_hooks
250                            .run_on_error(ErrorContext {
251                                request: req_ctx.clone(),
252                                response: Some(response.clone()),
253                                error: Error::http_with_status_text(
254                                    status,
255                                    status.canonical_reason().unwrap_or("request failed"),
256                                    status.canonical_reason().unwrap_or("request failed"),
257                                    Some(response.bytes().clone()),
258                                ),
259                            })
260                            .await;
261                    }
262
263                    return Ok(response);
264                }
265                Err(err) => {
266                    let retry_transport = matches!(&err, Error::Transport(_) | Error::Timeout);
267                    if retry_transport && retry_policy.is_some() && attempt < max_attempts {
268                        merged_hooks
269                            .run_on_retry(ResponseContext {
270                                request: req_ctx.clone(),
271                                response: Response::new(
272                                    http::StatusCode::SERVICE_UNAVAILABLE,
273                                    http::HeaderMap::new(),
274                                    bytes::Bytes::new(),
275                                    Some(http_req.url.clone()),
276                                    #[cfg(feature = "json")]
277                                    None,
278                                ),
279                            })
280                            .await;
281                        let delay = retry_policy
282                            .as_ref()
283                            .map(|p| p.delay_before_attempt(attempt))
284                            .unwrap_or(Duration::from_secs(1));
285                        attempt += 1;
286                        sleep_before_retry(delay).await;
287                        continue;
288                    }
289
290                    merged_hooks
291                        .run_on_error(ErrorContext {
292                            request: req_ctx.clone(),
293                            response: None,
294                            error: err.clone(),
295                        })
296                        .await;
297
298                    if retry_transport && retry_policy.is_some() {
299                        return Err(Error::retry_exhausted(attempt + 1, err));
300                    }
301
302                    return Err(err);
303                }
304            }
305        }
306    }
307}
308
309/// Builder for [`Client`].
310pub struct ClientBuilder {
311    base_url: Option<Url>,
312    timeout: Option<Duration>,
313    retry: Option<RetryPolicy>,
314    auth: Option<Auth>,
315    default_headers: http::HeaderMap,
316    hooks: Hooks,
317    plugins: PluginRegistry,
318    reqwest_client: Option<ReqwestClient>,
319    custom_backend: Option<Arc<dyn HttpBackend>>,
320    max_in_flight: Option<usize>,
321    #[cfg(feature = "schema")]
322    schema_registry: Option<Arc<SchemaRegistry>>,
323    #[cfg(feature = "json")]
324    json_parser: Option<JsonParserFn>,
325}
326
327impl ClientBuilder {
328    pub fn new() -> Self {
329        Self {
330            base_url: None,
331            timeout: None,
332            retry: None,
333            auth: None,
334            default_headers: http::HeaderMap::new(),
335            hooks: Hooks::default(),
336            plugins: PluginRegistry::new(),
337            reqwest_client: None,
338            custom_backend: None,
339            max_in_flight: None,
340            #[cfg(feature = "schema")]
341            schema_registry: None,
342            #[cfg(feature = "json")]
343            json_parser: None,
344        }
345    }
346
347    pub fn base_url(mut self, base_url: impl AsRef<str>) -> Result<Self> {
348        self.base_url = Some(Url::parse(base_url.as_ref()).map_err(Error::InvalidBaseUrl)?);
349        Ok(self)
350    }
351
352    pub fn timeout(mut self, timeout: Duration) -> Self {
353        self.timeout = Some(timeout);
354        self
355    }
356
357    pub fn retry(mut self, policy: RetryPolicy) -> Self {
358        self.retry = Some(policy);
359        self
360    }
361
362    pub fn auth(mut self, auth: Auth) -> Self {
363        self.auth = Some(auth);
364        self
365    }
366
367    pub fn default_header(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
368        let name = http::HeaderName::from_bytes(key.as_ref().as_bytes())
369            .map_err(|e| Error::Other(format!("invalid header name: {e}")))?;
370        let value = http::HeaderValue::from_str(value.as_ref())
371            .map_err(|e| Error::Other(format!("invalid header value: {e}")))?;
372        self.default_headers.insert(name, value);
373        Ok(self)
374    }
375
376    pub fn hooks(mut self, hooks: Hooks) -> Self {
377        self.hooks = hooks;
378        self
379    }
380
381    pub fn plugin<P: crate::plugin::Plugin + 'static>(mut self, plugin: P) -> Self {
382        self.plugins.push(Box::new(plugin));
383        self
384    }
385
386    pub fn reqwest_client(mut self, client: ReqwestClient) -> Self {
387        self.reqwest_client = Some(client);
388        self
389    }
390
391    /// Use a custom HTTP backend (for testing or alternate transports).
392    pub fn backend(mut self, backend: Arc<dyn HttpBackend>) -> Self {
393        self.custom_backend = Some(backend);
394        self
395    }
396
397    /// Limits how many requests this client may have in flight at once (including retries).
398    ///
399    /// Implemented with a tokio semaphore in the core client; does not require the `tower` feature.
400    /// For token-bucket rate limiting or richer policies, use [`Self::transport_stack`] with
401    /// Tower layers (feature `tower`).
402    pub fn max_in_flight(mut self, limit: usize) -> Self {
403        self.max_in_flight = Some(limit);
404        self
405    }
406
407    /// Attach a [`SchemaRegistry`] for strict route validation (feature `schema`).
408    #[cfg(feature = "schema")]
409    pub fn schema_registry(mut self, registry: Arc<SchemaRegistry>) -> Self {
410        self.schema_registry = Some(registry);
411        self
412    }
413
414    /// Use a Tower [`Service`](tower::Service) as the HTTP transport (feature `tower`).
415    #[cfg(feature = "tower")]
416    pub fn http_service<S>(mut self, service: S) -> Self
417    where
418        S: tower::Service<HttpRequest, Response = HttpResponse, Error = Error>
419            + Clone
420            + Send
421            + 'static,
422        S::Future: Send + 'static,
423    {
424        use crate::tower::ServiceBackend;
425
426        self.custom_backend = Some(Arc::new(ServiceBackend::new(service)));
427        self
428    }
429
430    /// Use a boxed Tower transport stack (feature `tower`).
431    #[cfg(feature = "tower")]
432    pub fn http_service_boxed(mut self, service: crate::tower::BoxHttpService) -> Self {
433        use crate::tower::ServiceBackend;
434
435        self.custom_backend = Some(Arc::new(ServiceBackend::from_box(service)));
436        self
437    }
438
439    /// Build a Tower transport stack on top of the configured (or default) reqwest client.
440    ///
441    /// Application hooks and [`RetryPolicy`](crate::RetryPolicy) remain in the core client;
442    /// only wire-level behavior is configured here.
443    #[cfg(feature = "tower")]
444    pub fn transport_stack<F>(mut self, configure: F) -> Self
445    where
446        F: FnOnce(crate::tower::ReqwestHttpService) -> crate::tower::BoxHttpService,
447    {
448        use crate::tower::ServiceBackend;
449
450        let client = self.reqwest_client.clone().unwrap_or_default();
451        let stacked = configure(crate::tower::ReqwestHttpService::new(client));
452        self.custom_backend = Some(Arc::new(ServiceBackend::from_box(stacked)));
453        self
454    }
455
456    /// Sets a custom JSON parser for all responses from this client.
457    ///
458    /// The parser receives raw response bytes and must return a [`serde_json::Value`].
459    /// Typed deserialization (`json`, `send_json`) then uses serde to map that value to `T`.
460    #[cfg(feature = "json")]
461    pub fn json_parser<F>(mut self, f: F) -> Self
462    where
463        F: Fn(&bytes::Bytes) -> std::result::Result<serde_json::Value, String>
464            + Send
465            + Sync
466            + 'static,
467    {
468        self.json_parser = Some(crate::json_parser::json_parser(f));
469        self
470    }
471
472    /// Sets a custom JSON parser from an existing [`JsonParserFn`].
473    #[cfg(feature = "json")]
474    pub fn json_parser_fn(mut self, parser: JsonParserFn) -> Self {
475        self.json_parser = Some(parser);
476        self
477    }
478
479    pub fn build(self) -> Result<Client> {
480        let base_url = match self.base_url {
481            Some(url) => url,
482            None => Url::parse("http://localhost")
483                .map_err(|e| Error::Other(format!("invalid default base URL: {e}")))?,
484        };
485
486        let backend: Arc<dyn HttpBackend> = if let Some(b) = self.custom_backend {
487            b
488        } else {
489            let reqwest_client = self.reqwest_client.unwrap_or_default();
490            Arc::new(ReqwestBackend::new(reqwest_client))
491        };
492
493        Ok(Client {
494            config: Arc::new(ClientConfig {
495                base_url,
496                timeout: self.timeout,
497                retry: self.retry,
498                auth: self.auth,
499                default_headers: self.default_headers,
500                hooks: self.hooks,
501                plugins: Arc::new(self.plugins),
502                max_in_flight: self.max_in_flight.map(|n| Arc::new(Semaphore::new(n))),
503                #[cfg(feature = "schema")]
504                schema_registry: self.schema_registry,
505                #[cfg(feature = "json")]
506                json_parser: self.json_parser,
507            }),
508            backend,
509        })
510    }
511}
512
513impl Default for ClientBuilder {
514    fn default() -> Self {
515        Self::new()
516    }
517}