1use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9
10use indexmap::IndexMap;
11use tokio::sync::Semaphore;
12
13use http::Method;
14use reqwest::Client as ReqwestClient;
15use url::Url;
16
17use crate::auth::Auth;
18use crate::backend::{HttpBackend, HttpBody, HttpRequest, ReqwestBackend};
19use crate::cancel::execute_or_cancel;
20use crate::endpoint::{Endpoint, EndpointRequestBuilder};
21use crate::error::Error;
22use crate::hooks::{ErrorContext, Hooks, RequestContext, ResponseContext, SuccessContext};
23use crate::plugin::{PluginRegistry, PreparedRequest};
24use crate::request::RequestBuilder;
25use crate::response::Response;
26use crate::retry::{sleep_or_cancel, RetryPolicy};
27use crate::url_build::build_url;
28use crate::Result;
29
30#[cfg(feature = "tower")]
31use crate::backend::HttpResponse;
32
33#[cfg(feature = "json")]
34use crate::json_parser::JsonParserFn;
35
36#[cfg(feature = "schema")]
37use crate::schema::SchemaRegistry;
38
39fn body_for_context(body: &HttpBody) -> Option<bytes::Bytes> {
40 match body {
41 HttpBody::Empty => None,
42 HttpBody::Bytes(b) => Some(b.clone()),
43 }
44}
45
/// Immutable configuration shared (behind an `Arc`) by a [`Client`] and all of
/// its clones.
#[derive(Clone)]
pub struct ClientConfig {
    /// Base URL that request paths (plus path/query parameters) are combined
    /// with when a request is executed.
    pub base_url: Url,
    /// Default timeout copied into every new request builder.
    pub timeout: Option<Duration>,
    /// Default retry policy copied into every new request builder. When neither
    /// the request nor the client has a policy, no automatic retries happen.
    pub retry: Option<RetryPolicy>,
    /// Default authentication applied to outgoing request headers.
    pub auth: Option<Auth>,
    /// Headers every new request builder starts out with.
    pub default_headers: http::HeaderMap,
    /// Hooks exactly as supplied to the builder, before plugin hooks are merged.
    pub hooks: Hooks,
    /// `hooks` merged with every plugin's hooks; this is the set the request
    /// pipeline actually runs.
    pub(crate) merged_hooks: Hooks,
    /// Registered plugins; their init step runs on every request before the
    /// `on_request` hooks.
    pub plugins: Arc<PluginRegistry>,
    /// Optional limit on concurrently executing requests. The semaphore is
    /// shared by all clones of the client, and a request holds its permit for
    /// all of its retry attempts.
    pub max_in_flight: Option<Arc<Semaphore>>,
    /// Optional registry every executed request's path and method are checked
    /// against.
    #[cfg(feature = "schema")]
    pub schema_registry: Option<Arc<SchemaRegistry>>,
    /// Client-wide JSON parser handed to responses unless a request sets its own.
    #[cfg(feature = "json")]
    pub json_parser: Option<JsonParserFn>,
}
77
/// HTTP client entry point.
///
/// Cloning is cheap: clones share the same configuration and backend through
/// `Arc`, including any `max_in_flight` limit.
#[derive(Clone)]
pub struct Client {
    /// Shared, immutable client configuration.
    config: Arc<ClientConfig>,
    /// Transport used to actually send requests.
    backend: Arc<dyn HttpBackend>,
}
84
85impl Client {
86 pub fn new(base_url: impl AsRef<str>) -> Result<Self> {
100 ClientBuilder::new().base_url(base_url)?.build()
101 }
102
103 pub fn builder() -> ClientBuilder {
105 ClientBuilder::new()
106 }
107
108 pub fn with_http_client(reqwest_client: ReqwestClient, base_url: impl AsRef<str>) -> Result<Self> {
110 ClientBuilder::new()
111 .reqwest_client(reqwest_client)
112 .base_url(base_url)?
113 .build()
114 }
115
116 pub fn call<E: Endpoint>(&self) -> EndpointRequestBuilder<'_, E> {
120 EndpointRequestBuilder::new(self.request(E::METHOD, E::PATH))
121 }
122
123 pub fn config(&self) -> &ClientConfig {
125 &self.config
126 }
127
128 pub fn get(&self, path: impl Into<String>) -> RequestBuilder<'_> {
130 self.request(Method::GET, path)
131 }
132
133 pub fn post(&self, path: impl Into<String>) -> RequestBuilder<'_> {
135 self.request(Method::POST, path)
136 }
137
138 pub fn put(&self, path: impl Into<String>) -> RequestBuilder<'_> {
140 self.request(Method::PUT, path)
141 }
142
143 pub fn patch(&self, path: impl Into<String>) -> RequestBuilder<'_> {
145 self.request(Method::PATCH, path)
146 }
147
148 pub fn delete(&self, path: impl Into<String>) -> RequestBuilder<'_> {
150 self.request(Method::DELETE, path)
151 }
152
153 pub fn head(&self, path: impl Into<String>) -> RequestBuilder<'_> {
155 self.request(Method::HEAD, path)
156 }
157
158 pub fn request(&self, method: Method, path: impl Into<String>) -> RequestBuilder<'_> {
160 RequestBuilder {
161 client: self,
162 method,
163 path: path.into(),
164 params: HashMap::new(),
165 query: IndexMap::new(),
166 headers: self.config.default_headers.clone(),
167 body: HttpBody::Empty,
168 #[cfg(feature = "multipart")]
169 multipart: None,
170 timeout: self.config.timeout,
171 retry: self.config.retry.clone(),
172 auth: self.config.auth.clone(),
173 cancellation: None,
174 throw_on_error: false,
175 #[cfg(feature = "json")]
176 json_parser: None,
177 #[cfg(feature = "validate")]
178 validate_response: true,
179 }
180 }
181
182 pub(crate) async fn execute(&self, builder: RequestBuilder<'_>) -> Result<Response> {
183 #[cfg(feature = "json")]
184 let json_parser = builder
185 .json_parser
186 .clone()
187 .or_else(|| self.config.json_parser.clone());
188
189 let built = build_url(
190 &self.config.base_url,
191 &builder.path,
192 &builder.params,
193 &builder.query,
194 )?;
195
196 let mut method = builder.method;
197 if let Some(override_method) = built.method_override {
198 method = override_method;
199 }
200
201 #[cfg(feature = "schema")]
202 if let Some(registry) = &self.config.schema_registry {
203 registry.ensure_route(&builder.path, &method)?;
204 }
205
206 let mut url = built.url;
207 let mut headers = builder.headers;
208 let auth = builder.auth.or_else(|| self.config.auth.clone());
209 if let Some(auth) = auth {
210 auth.apply(&mut headers).await?;
211 }
212
213 let mut prepared = PreparedRequest {
214 url: url.clone(),
215 path: builder.path.clone(),
216 method: method.clone(),
217 headers: headers.clone(),
218 };
219 self.config.plugins.run_init_all(&mut prepared).await?;
220 url = prepared.url;
221 headers = prepared.headers;
222 method = prepared.method;
223
224 let mut req_ctx = RequestContext {
225 url: url.clone(),
226 method: method.clone(),
227 headers: headers.clone(),
228 body: body_for_context(&builder.body),
229 retry_attempt: 0,
230 };
231
232 let merged_hooks = &self.config.merged_hooks;
233 req_ctx = merged_hooks.run_on_request(req_ctx).await?;
234 url = req_ctx.url.clone();
235 headers = req_ctx.headers.clone();
236 method = req_ctx.method.clone();
237
238 let timeout = builder.timeout;
239 let retry_policy = builder.retry.or_else(|| self.config.retry.clone());
240 let throw_on_error = builder.throw_on_error;
241 let cancel = builder.cancellation;
242
243 let backend = self.backend.clone();
244
245 let _in_flight_permit = match &self.config.max_in_flight {
246 Some(sem) => Some(
247 sem.acquire()
248 .await
249 .map_err(|_| Error::Other("max_in_flight semaphore closed".into()))?,
250 ),
251 None => None,
252 };
253
254 let mut attempt = 0u32;
255 let max_attempts = retry_policy.as_ref().map(|p| p.max_attempts()).unwrap_or(0);
256
257 let request_body = builder.body;
258 #[cfg(feature = "multipart")]
259 let mut multipart_body = builder.multipart;
260 #[cfg(feature = "multipart")]
261 let had_multipart = multipart_body.is_some();
262
263 let cancel_ref = cancel.as_ref();
264
265 loop {
266 req_ctx.retry_attempt = attempt;
267
268 #[cfg(feature = "multipart")]
269 if attempt > 0 && had_multipart {
270 return Err(Error::Other(
271 "automatic retry is not supported with multipart request bodies".into(),
272 ));
273 }
274
275 let http_req = HttpRequest {
276 method: method.clone(),
277 url: url.clone(),
278 headers: headers.clone(),
279 body: request_body.clone(),
280 timeout,
281 cancellation: cancel.clone(),
282 #[cfg(feature = "multipart")]
283 multipart: multipart_body.take(),
284 };
285 let request_url = http_req.url.clone();
286
287 let result = execute_or_cancel(cancel_ref, backend.execute(http_req)).await;
288
289 match result {
290 Ok(http_res) => {
291 let response = Response::new(
292 http_res.status,
293 http_res.headers.clone(),
294 http_res.body,
295 Some(request_url.clone()),
296 #[cfg(feature = "json")]
297 json_parser.clone(),
298 );
299
300 let response = merged_hooks
301 .run_on_response(ResponseContext {
302 request: req_ctx.clone(),
303 response,
304 })
305 .await?;
306
307 let should_retry = retry_policy
308 .as_ref()
309 .map(|p| p.should_retry_response(&response, false))
310 .unwrap_or(false);
311
312 if should_retry && attempt < max_attempts {
313 merged_hooks
314 .run_on_retry(ResponseContext {
315 request: req_ctx.clone(),
316 response: response.clone(),
317 })
318 .await;
319 let delay = retry_policy
320 .as_ref()
321 .map(|p| p.delay_after_response(attempt, response.headers()))
322 .unwrap_or(Duration::from_secs(1));
323 attempt += 1;
324 sleep_or_cancel(delay, cancel_ref).await?;
325 continue;
326 }
327
328 if response.is_success() {
329 merged_hooks
330 .run_on_success(SuccessContext {
331 request: req_ctx.clone(),
332 response: response.clone(),
333 })
334 .await;
335 return Ok(response);
336 }
337
338 let status = response.status();
339 let http_err = Error::http_with_status_text(
340 status,
341 status.canonical_reason().unwrap_or("request failed"),
342 status.canonical_reason().unwrap_or("request failed"),
343 Some(response.bytes().clone()),
344 );
345 merged_hooks
346 .run_on_error(ErrorContext {
347 request: req_ctx.clone(),
348 response: Some(response.clone()),
349 error: http_err.clone(),
350 })
351 .await;
352
353 if throw_on_error {
354 return Err(http_err);
355 }
356 return Ok(response);
357 }
358 Err(err) => {
359 if err.is_cancelled() {
360 merged_hooks
361 .run_on_error(ErrorContext {
362 request: req_ctx.clone(),
363 response: None,
364 error: err.clone(),
365 })
366 .await;
367 return Err(err);
368 }
369
370 let retry_transport = matches!(&err, Error::Transport(_) | Error::Timeout);
371 if retry_transport && retry_policy.is_some() && attempt < max_attempts {
372 merged_hooks
373 .run_on_retry(ResponseContext {
374 request: req_ctx.clone(),
375 response: Response::new(
376 http::StatusCode::SERVICE_UNAVAILABLE,
377 http::HeaderMap::new(),
378 bytes::Bytes::new(),
379 Some(request_url.clone()),
380 #[cfg(feature = "json")]
381 None,
382 ),
383 })
384 .await;
385 let delay = retry_policy
386 .as_ref()
387 .map(|p| p.delay_after_response(attempt, &http::HeaderMap::new()))
388 .unwrap_or(Duration::from_secs(1));
389 attempt += 1;
390 sleep_or_cancel(delay, cancel_ref).await?;
391 continue;
392 }
393
394 merged_hooks
395 .run_on_error(ErrorContext {
396 request: req_ctx.clone(),
397 response: None,
398 error: err.clone(),
399 })
400 .await;
401
402 if retry_transport && retry_policy.is_some() {
403 return Err(Error::retry_exhausted(attempt + 1, err));
404 }
405
406 return Err(err);
407 }
408 }
409 }
410 }
411}
412
/// Step-by-step configuration for a [`Client`]; finish with `build`.
pub struct ClientBuilder {
    /// Required base URL; `build` fails without it.
    base_url: Option<Url>,
    /// Default per-request timeout.
    timeout: Option<Duration>,
    /// Default retry policy.
    retry: Option<RetryPolicy>,
    /// Default authentication.
    auth: Option<Auth>,
    /// Headers every request starts out with.
    default_headers: http::HeaderMap,
    /// Client hooks; merged with plugin hooks in `build`.
    hooks: Hooks,
    /// Plugins in the order they were added.
    plugins: PluginRegistry,
    /// reqwest client for the default backend. Ignored by `build` when
    /// `custom_backend` is set.
    reqwest_client: Option<ReqwestClient>,
    /// Backend that replaces the default reqwest backend when set.
    custom_backend: Option<Arc<dyn HttpBackend>>,
    /// Requested concurrency limit; turned into a shared semaphore by `build`.
    max_in_flight: Option<usize>,
    /// Optional route registry checked on every request.
    #[cfg(feature = "schema")]
    schema_registry: Option<Arc<SchemaRegistry>>,
    /// Optional client-wide JSON parser.
    #[cfg(feature = "json")]
    json_parser: Option<JsonParserFn>,
}
430
431impl ClientBuilder {
432 pub fn new() -> Self {
434 Self {
435 base_url: None,
436 timeout: None,
437 retry: None,
438 auth: None,
439 default_headers: http::HeaderMap::new(),
440 hooks: Hooks::default(),
441 plugins: PluginRegistry::new(),
442 reqwest_client: None,
443 custom_backend: None,
444 max_in_flight: None,
445 #[cfg(feature = "schema")]
446 schema_registry: None,
447 #[cfg(feature = "json")]
448 json_parser: None,
449 }
450 }
451
452 pub fn base_url(mut self, base_url: impl AsRef<str>) -> Result<Self> {
454 self.base_url = Some(Url::parse(base_url.as_ref()).map_err(Error::InvalidBaseUrl)?);
455 Ok(self)
456 }
457
458 pub fn timeout(mut self, timeout: Duration) -> Self {
460 self.timeout = Some(timeout);
461 self
462 }
463
464 pub fn retry(mut self, policy: RetryPolicy) -> Self {
466 self.retry = Some(policy);
467 self
468 }
469
470 pub fn auth(mut self, auth: Auth) -> Self {
472 self.auth = Some(auth);
473 self
474 }
475
476 pub fn default_header(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
478 let name = http::HeaderName::from_bytes(key.as_ref().as_bytes())
479 .map_err(|e| Error::Other(format!("invalid header name: {e}")))?;
480 let value = http::HeaderValue::from_str(value.as_ref())
481 .map_err(|e| Error::Other(format!("invalid header value: {e}")))?;
482 self.default_headers.insert(name, value);
483 Ok(self)
484 }
485
486 pub fn hooks(mut self, hooks: Hooks) -> Self {
488 self.hooks = hooks;
489 self
490 }
491
492 pub fn plugin<P: crate::plugin::Plugin + 'static>(mut self, plugin: P) -> Self {
494 self.plugins.push(Box::new(plugin));
495 self
496 }
497
498 pub fn reqwest_client(mut self, client: ReqwestClient) -> Self {
500 self.reqwest_client = Some(client);
501 self
502 }
503
504 pub fn backend(mut self, backend: Arc<dyn HttpBackend>) -> Self {
534 self.custom_backend = Some(backend);
535 self
536 }
537
538 pub fn max_in_flight(mut self, limit: usize) -> Self {
545 self.max_in_flight = Some(limit);
546 self
547 }
548
549 #[cfg(feature = "schema")]
551 pub fn schema_registry(mut self, registry: Arc<SchemaRegistry>) -> Self {
552 self.schema_registry = Some(registry);
553 self
554 }
555
556 #[cfg(feature = "tower")]
558 pub fn http_service<S>(mut self, service: S) -> Self
559 where
560 S: tower::Service<HttpRequest, Response = HttpResponse, Error = Error>
561 + Clone
562 + Send
563 + 'static,
564 S::Future: Send + 'static,
565 {
566 use crate::tower::ServiceBackend;
567
568 self.custom_backend = Some(Arc::new(ServiceBackend::new(service)));
569 self
570 }
571
572 #[cfg(feature = "tower")]
574 pub fn http_service_boxed(mut self, service: crate::tower::BoxHttpService) -> Self {
575 use crate::tower::ServiceBackend;
576
577 self.custom_backend = Some(Arc::new(ServiceBackend::from_box(service)));
578 self
579 }
580
581 #[cfg(feature = "tower")]
603 pub fn transport_stack<F>(mut self, configure: F) -> Self
604 where
605 F: FnOnce(crate::tower::ReqwestHttpService) -> crate::tower::BoxHttpService,
606 {
607 use crate::tower::ServiceBackend;
608
609 let client = self.reqwest_client.clone().unwrap_or_default();
610 let stacked = configure(crate::tower::ReqwestHttpService::new(client));
611 self.custom_backend = Some(Arc::new(ServiceBackend::from_box(stacked)));
612 self
613 }
614
615 #[cfg(feature = "json")]
636 pub fn json_parser<F>(mut self, f: F) -> Self
637 where
638 F: Fn(&bytes::Bytes) -> std::result::Result<serde_json::Value, String>
639 + Send
640 + Sync
641 + 'static,
642 {
643 self.json_parser = Some(crate::json_parser::json_parser(f));
644 self
645 }
646
647 #[cfg(feature = "json")]
649 pub fn json_parser_fn(mut self, parser: JsonParserFn) -> Self {
650 self.json_parser = Some(parser);
651 self
652 }
653
654 pub fn build(self) -> Result<Client> {
666 let base_url = self
667 .base_url
668 .ok_or(Error::MissingBaseUrl)?;
669
670 let backend: Arc<dyn HttpBackend> = if let Some(b) = self.custom_backend {
671 b
672 } else {
673 let reqwest_client = self.reqwest_client.unwrap_or_default();
674 Arc::new(ReqwestBackend::new(reqwest_client))
675 };
676
677 let plugins = Arc::new(self.plugins);
678 let merged_hooks = self.hooks.clone().merge(plugins.merged_hooks());
679
680 Ok(Client {
681 config: Arc::new(ClientConfig {
682 base_url,
683 timeout: self.timeout,
684 retry: self.retry,
685 auth: self.auth,
686 default_headers: self.default_headers,
687 hooks: self.hooks,
688 merged_hooks,
689 plugins,
690 max_in_flight: self.max_in_flight.map(|n| Arc::new(Semaphore::new(n))),
691 #[cfg(feature = "schema")]
692 schema_registry: self.schema_registry,
693 #[cfg(feature = "json")]
694 json_parser: self.json_parser,
695 }),
696 backend,
697 })
698 }
699}
700
701impl Default for ClientBuilder {
702 fn default() -> Self {
703 Self::new()
704 }
705}