// better_fetch/client.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::Semaphore;
6
7use http::Method;
8use reqwest::Client as ReqwestClient;
9use url::Url;
10
11use crate::auth::Auth;
12use crate::backend::{HttpBackend, HttpRequest, ReqwestBackend};
13#[cfg(feature = "tower")]
14use crate::backend::HttpResponse;
15use crate::error::Error;
16use crate::hooks::{ErrorContext, Hooks, RequestContext, ResponseContext, SuccessContext};
17use crate::plugin::{PluginRegistry, PreparedRequest};
18use crate::request::RequestBuilder;
19use crate::response::Response;
20use crate::retry::{sleep_before_retry, RetryPolicy};
21use crate::url_build::build_url;
22use crate::Result;
23
24#[cfg(feature = "json")]
25use crate::json_parser::JsonParserFn;
26
27#[cfg(feature = "schema")]
28use crate::schema::SchemaRegistry;
29
30/// Shared client configuration.
31#[derive(Clone)]
32pub struct ClientConfig {
33    pub base_url: Url,
34    pub timeout: Option<Duration>,
35    pub retry: Option<RetryPolicy>,
36    pub auth: Option<Auth>,
37    pub default_headers: http::HeaderMap,
38    pub hooks: Hooks,
39    pub plugins: Arc<PluginRegistry>,
40    /// Limits concurrent in-flight requests for this client (core transport guard, no Tower dep).
41    pub max_in_flight: Option<Arc<Semaphore>>,
42    #[cfg(feature = "schema")]
43    pub schema_registry: Option<Arc<SchemaRegistry>>,
44    #[cfg(feature = "json")]
45    pub json_parser: Option<JsonParserFn>,
46}
47
48/// Typed HTTP client built on reqwest.
49#[derive(Clone)]
50pub struct Client {
51    config: Arc<ClientConfig>,
52    backend: Arc<dyn HttpBackend>,
53}
54
55impl Client {
56    pub fn new(base_url: impl AsRef<str>) -> Result<Self> {
57        ClientBuilder::new().base_url(base_url)?.build()
58    }
59
60    pub fn builder() -> ClientBuilder {
61        ClientBuilder::new()
62    }
63
64    pub fn with_http_client(reqwest_client: ReqwestClient) -> Result<Self> {
65        ClientBuilder::new().reqwest_client(reqwest_client).build()
66    }
67
68    pub fn config(&self) -> &ClientConfig {
69        &self.config
70    }
71
72    pub fn get(&self, path: impl Into<String>) -> RequestBuilder<'_> {
73        self.request(Method::GET, path)
74    }
75
76    pub fn post(&self, path: impl Into<String>) -> RequestBuilder<'_> {
77        self.request(Method::POST, path)
78    }
79
80    pub fn put(&self, path: impl Into<String>) -> RequestBuilder<'_> {
81        self.request(Method::PUT, path)
82    }
83
84    pub fn patch(&self, path: impl Into<String>) -> RequestBuilder<'_> {
85        self.request(Method::PATCH, path)
86    }
87
88    pub fn delete(&self, path: impl Into<String>) -> RequestBuilder<'_> {
89        self.request(Method::DELETE, path)
90    }
91
92    pub fn head(&self, path: impl Into<String>) -> RequestBuilder<'_> {
93        self.request(Method::HEAD, path)
94    }
95
96    pub fn request(&self, method: Method, path: impl Into<String>) -> RequestBuilder<'_> {
97        RequestBuilder {
98            client: self,
99            method,
100            path: path.into(),
101            params: HashMap::new(),
102            query: HashMap::new(),
103            headers: self.config.default_headers.clone(),
104            body: None,
105            timeout: self.config.timeout,
106            retry: self.config.retry.clone(),
107            auth: self.config.auth.clone(),
108            #[cfg(feature = "json")]
109            json_parser: None,
110            #[cfg(feature = "validate")]
111            validate_response: true,
112        }
113    }
114
115    pub(crate) async fn execute(&self, builder: RequestBuilder<'_>) -> Result<Response> {
116        #[cfg(feature = "json")]
117        let json_parser = builder
118            .json_parser
119            .clone()
120            .or_else(|| self.config.json_parser.clone());
121        let built = build_url(
122            &self.config.base_url,
123            &builder.path,
124            &builder.params,
125            &builder.query,
126        )?;
127
128        let mut method = builder.method;
129        if let Some(override_method) = built.method_override {
130            method = override_method;
131        }
132
133        #[cfg(feature = "schema")]
134        if let Some(registry) = &self.config.schema_registry {
135            registry.ensure_route(&builder.path, &method)?;
136        }
137
138        let mut url = built.url;
139
140        let mut prepared = PreparedRequest {
141            url: url.clone(),
142            path: builder.path.clone(),
143        };
144        self.config.plugins.run_init_all(&mut prepared).await?;
145        url = prepared.url;
146
147        let mut headers = builder.headers;
148        let auth = builder.auth.or_else(|| self.config.auth.clone());
149        if let Some(auth) = auth {
150            auth.apply(&mut headers).await?;
151        }
152
153        let mut req_ctx = RequestContext {
154            url: url.clone(),
155            method: method.clone(),
156            headers: headers.clone(),
157            body: builder.body.clone(),
158            retry_attempt: 0,
159        };
160
161        let merged_hooks = self
162            .config
163            .hooks
164            .clone()
165            .merge(self.config.plugins.merged_hooks());
166
167        req_ctx = merged_hooks.run_on_request(req_ctx).await?;
168        url = req_ctx.url.clone();
169        headers = req_ctx.headers.clone();
170        method = req_ctx.method.clone();
171
172        let timeout = builder.timeout;
173        let retry_policy = builder.retry.or_else(|| self.config.retry.clone());
174
175        let backend = self.backend.clone();
176        let body = req_ctx.body.clone();
177
178        let _in_flight_permit = match &self.config.max_in_flight {
179            Some(sem) => Some(
180                sem.acquire()
181                    .await
182                    .map_err(|_| Error::Other("max_in_flight semaphore closed".into()))?,
183            ),
184            None => None,
185        };
186
187        let mut attempt = 0u32;
188        let max_attempts = retry_policy.as_ref().map(|p| p.max_attempts()).unwrap_or(0);
189
190        let http_req = HttpRequest {
191            method,
192            url,
193            headers,
194            body,
195            timeout,
196        };
197
198        loop {
199            req_ctx.retry_attempt = attempt;
200
201            let result = backend.execute(http_req.clone()).await;
202
203            match result {
204                Ok(http_res) => {
205                    let response = Response::new(
206                        http_res.status,
207                        http_res.headers,
208                        http_res.body,
209                        Some(http_req.url.clone()),
210                        #[cfg(feature = "json")]
211                        json_parser.clone(),
212                    );
213
214                    let response = merged_hooks
215                        .run_on_response(ResponseContext {
216                            request: req_ctx.clone(),
217                            response,
218                        })
219                        .await?;
220
221                    let should_retry = retry_policy
222                        .as_ref()
223                        .map(|p| p.should_retry_response(&response, false))
224                        .unwrap_or(false);
225
226                    if should_retry && attempt < max_attempts {
227                        merged_hooks
228                            .run_on_retry(ResponseContext {
229                                request: req_ctx.clone(),
230                                response: response.clone(),
231                            })
232                            .await;
233                        let delay = retry_policy
234                            .as_ref()
235                            .map(|p| p.delay_before_attempt(attempt))
236                            .unwrap_or(Duration::from_secs(1));
237                        attempt += 1;
238                        sleep_before_retry(delay).await;
239                        continue;
240                    }
241
242                    if response.is_success() {
243                        merged_hooks
244                            .run_on_success(SuccessContext {
245                                request: req_ctx.clone(),
246                                response: response.clone(),
247                            })
248                            .await;
249                    } else {
250                        let status = response.status();
251                        merged_hooks
252                            .run_on_error(ErrorContext {
253                                request: req_ctx.clone(),
254                                response: Some(response.clone()),
255                                error: Error::http_with_status_text(
256                                    status,
257                                    status.canonical_reason().unwrap_or("request failed"),
258                                    status.canonical_reason().unwrap_or("request failed"),
259                                    Some(response.bytes().clone()),
260                                ),
261                            })
262                            .await;
263                    }
264
265                    return Ok(response);
266                }
267                Err(err) => {
268                    let retry_transport = matches!(&err, Error::Transport(_) | Error::Timeout);
269                    if retry_transport && retry_policy.is_some() && attempt < max_attempts {
270                        merged_hooks
271                            .run_on_retry(ResponseContext {
272                                request: req_ctx.clone(),
273                                response: Response::new(
274                                    http::StatusCode::SERVICE_UNAVAILABLE,
275                                    http::HeaderMap::new(),
276                                    bytes::Bytes::new(),
277                                    Some(http_req.url.clone()),
278                                    #[cfg(feature = "json")]
279                                    None,
280                                ),
281                            })
282                            .await;
283                        let delay = retry_policy
284                            .as_ref()
285                            .map(|p| p.delay_before_attempt(attempt))
286                            .unwrap_or(Duration::from_secs(1));
287                        attempt += 1;
288                        sleep_before_retry(delay).await;
289                        continue;
290                    }
291
292                    merged_hooks
293                        .run_on_error(ErrorContext {
294                            request: req_ctx.clone(),
295                            response: None,
296                            error: err.clone(),
297                        })
298                        .await;
299
300                    if retry_transport && retry_policy.is_some() {
301                        return Err(Error::retry_exhausted(attempt + 1, err));
302                    }
303
304                    return Err(err);
305                }
306            }
307        }
308    }
309}
310
311/// Builder for [`Client`].
312pub struct ClientBuilder {
313    base_url: Option<Url>,
314    timeout: Option<Duration>,
315    retry: Option<RetryPolicy>,
316    auth: Option<Auth>,
317    default_headers: http::HeaderMap,
318    hooks: Hooks,
319    plugins: PluginRegistry,
320    reqwest_client: Option<ReqwestClient>,
321    custom_backend: Option<Arc<dyn HttpBackend>>,
322    max_in_flight: Option<usize>,
323    #[cfg(feature = "schema")]
324    schema_registry: Option<Arc<SchemaRegistry>>,
325    #[cfg(feature = "json")]
326    json_parser: Option<JsonParserFn>,
327}
328
329impl ClientBuilder {
330    pub fn new() -> Self {
331        Self {
332            base_url: None,
333            timeout: None,
334            retry: None,
335            auth: None,
336            default_headers: http::HeaderMap::new(),
337            hooks: Hooks::default(),
338            plugins: PluginRegistry::new(),
339            reqwest_client: None,
340            custom_backend: None,
341            max_in_flight: None,
342            #[cfg(feature = "schema")]
343            schema_registry: None,
344            #[cfg(feature = "json")]
345            json_parser: None,
346        }
347    }
348
349    pub fn base_url(mut self, base_url: impl AsRef<str>) -> Result<Self> {
350        self.base_url = Some(Url::parse(base_url.as_ref()).map_err(Error::InvalidBaseUrl)?);
351        Ok(self)
352    }
353
354    pub fn timeout(mut self, timeout: Duration) -> Self {
355        self.timeout = Some(timeout);
356        self
357    }
358
359    pub fn retry(mut self, policy: RetryPolicy) -> Self {
360        self.retry = Some(policy);
361        self
362    }
363
364    pub fn auth(mut self, auth: Auth) -> Self {
365        self.auth = Some(auth);
366        self
367    }
368
369    pub fn default_header(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
370        let name = http::HeaderName::from_bytes(key.as_ref().as_bytes())
371            .map_err(|e| Error::Other(format!("invalid header name: {e}")))?;
372        let value = http::HeaderValue::from_str(value.as_ref())
373            .map_err(|e| Error::Other(format!("invalid header value: {e}")))?;
374        self.default_headers.insert(name, value);
375        Ok(self)
376    }
377
378    pub fn hooks(mut self, hooks: Hooks) -> Self {
379        self.hooks = hooks;
380        self
381    }
382
383    pub fn plugin<P: crate::plugin::Plugin + 'static>(mut self, plugin: P) -> Self {
384        self.plugins.push(Box::new(plugin));
385        self
386    }
387
388    pub fn reqwest_client(mut self, client: ReqwestClient) -> Self {
389        self.reqwest_client = Some(client);
390        self
391    }
392
393    /// Use a custom HTTP backend (for testing or alternate transports).
394    pub fn backend(mut self, backend: Arc<dyn HttpBackend>) -> Self {
395        self.custom_backend = Some(backend);
396        self
397    }
398
399    /// Limits how many requests this client may have in flight at once (including retries).
400    ///
401    /// Implemented with a tokio semaphore in the core client; does not require the `tower` feature.
402    /// For token-bucket rate limiting or richer policies, use [`Self::transport_stack`] with
403    /// Tower layers (feature `tower`).
404    pub fn max_in_flight(mut self, limit: usize) -> Self {
405        self.max_in_flight = Some(limit);
406        self
407    }
408
409    /// Attach a [`SchemaRegistry`] for strict route validation (feature `schema`).
410    #[cfg(feature = "schema")]
411    pub fn schema_registry(mut self, registry: Arc<SchemaRegistry>) -> Self {
412        self.schema_registry = Some(registry);
413        self
414    }
415
416    /// Use a Tower [`Service`](tower::Service) as the HTTP transport (feature `tower`).
417    #[cfg(feature = "tower")]
418    pub fn http_service<S>(mut self, service: S) -> Self
419    where
420        S: tower::Service<HttpRequest, Response = HttpResponse, Error = Error>
421            + Clone
422            + Send
423            + 'static,
424        S::Future: Send + 'static,
425    {
426        use crate::tower::ServiceBackend;
427
428        self.custom_backend = Some(Arc::new(ServiceBackend::new(service)));
429        self
430    }
431
432    /// Use a boxed Tower transport stack (feature `tower`).
433    #[cfg(feature = "tower")]
434    pub fn http_service_boxed(mut self, service: crate::tower::BoxHttpService) -> Self {
435        use crate::tower::ServiceBackend;
436
437        self.custom_backend = Some(Arc::new(ServiceBackend::from_box(service)));
438        self
439    }
440
441    /// Build a Tower transport stack on top of the configured (or default) reqwest client.
442    ///
443    /// Application hooks and [`RetryPolicy`](crate::RetryPolicy) remain in the core client;
444    /// only wire-level behavior is configured here.
445    #[cfg(feature = "tower")]
446    pub fn transport_stack<F>(mut self, configure: F) -> Self
447    where
448        F: FnOnce(crate::tower::ReqwestHttpService) -> crate::tower::BoxHttpService,
449    {
450        use crate::tower::ServiceBackend;
451
452        let client = self.reqwest_client.clone().unwrap_or_default();
453        let stacked = configure(crate::tower::ReqwestHttpService::new(client));
454        self.custom_backend = Some(Arc::new(ServiceBackend::from_box(stacked)));
455        self
456    }
457
458    /// Sets a custom JSON parser for all responses from this client.
459    ///
460    /// The parser receives raw response bytes and must return a [`serde_json::Value`].
461    /// Typed deserialization (`json`, `send_json`) then uses serde to map that value to `T`.
462    #[cfg(feature = "json")]
463    pub fn json_parser<F>(mut self, f: F) -> Self
464    where
465        F: Fn(&bytes::Bytes) -> std::result::Result<serde_json::Value, String>
466            + Send
467            + Sync
468            + 'static,
469    {
470        self.json_parser = Some(crate::json_parser::json_parser(f));
471        self
472    }
473
474    /// Sets a custom JSON parser from an existing [`JsonParserFn`].
475    #[cfg(feature = "json")]
476    pub fn json_parser_fn(mut self, parser: JsonParserFn) -> Self {
477        self.json_parser = Some(parser);
478        self
479    }
480
481    pub fn build(self) -> Result<Client> {
482        let base_url = match self.base_url {
483            Some(url) => url,
484            None => Url::parse("http://localhost")
485                .map_err(|e| Error::Other(format!("invalid default base URL: {e}")))?,
486        };
487
488        let backend: Arc<dyn HttpBackend> = if let Some(b) = self.custom_backend {
489            b
490        } else {
491            let reqwest_client = self.reqwest_client.unwrap_or_default();
492            Arc::new(ReqwestBackend::new(reqwest_client))
493        };
494
495        Ok(Client {
496            config: Arc::new(ClientConfig {
497                base_url,
498                timeout: self.timeout,
499                retry: self.retry,
500                auth: self.auth,
501                default_headers: self.default_headers,
502                hooks: self.hooks,
503                plugins: Arc::new(self.plugins),
504                max_in_flight: self.max_in_flight.map(|n| Arc::new(Semaphore::new(n))),
505                #[cfg(feature = "schema")]
506                schema_registry: self.schema_registry,
507                #[cfg(feature = "json")]
508                json_parser: self.json_parser,
509            }),
510            backend,
511        })
512    }
513}
514
515impl Default for ClientBuilder {
516    fn default() -> Self {
517        Self::new()
518    }
519}