1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use indexmap::IndexMap;
6use tokio::sync::Semaphore;
7
8use http::Method;
9use reqwest::Client as ReqwestClient;
10use url::Url;
11
12use crate::auth::Auth;
13use crate::backend::{HttpBackend, HttpBody, HttpRequest, ReqwestBackend};
14use crate::cancel::execute_or_cancel;
15use crate::endpoint::{Endpoint, EndpointRequestBuilder};
16use crate::error::Error;
17use crate::hooks::{ErrorContext, Hooks, RequestContext, ResponseContext, SuccessContext};
18use crate::plugin::{PluginRegistry, PreparedRequest};
19use crate::request::RequestBuilder;
20use crate::response::Response;
21use crate::retry::{sleep_or_cancel, RetryPolicy};
22use crate::url_build::build_url;
23use crate::Result;
24
25#[cfg(feature = "tower")]
26use crate::backend::HttpResponse;
27
28#[cfg(feature = "json")]
29use crate::json_parser::JsonParserFn;
30
31#[cfg(feature = "schema")]
32use crate::schema::SchemaRegistry;
33
34fn body_for_context(body: &HttpBody) -> Option<bytes::Bytes> {
35 match body {
36 HttpBody::Empty => None,
37 HttpBody::Bytes(b) => Some(b.clone()),
38 }
39}
40
41#[derive(Clone)]
43pub struct ClientConfig {
44 pub base_url: Url,
45 pub timeout: Option<Duration>,
46 pub retry: Option<RetryPolicy>,
47 pub auth: Option<Auth>,
48 pub default_headers: http::HeaderMap,
49 pub hooks: Hooks,
50 pub(crate) merged_hooks: Hooks,
51 pub plugins: Arc<PluginRegistry>,
52 pub max_in_flight: Option<Arc<Semaphore>>,
58 #[cfg(feature = "schema")]
59 pub schema_registry: Option<Arc<SchemaRegistry>>,
60 #[cfg(feature = "json")]
61 pub json_parser: Option<JsonParserFn>,
62}
63
64#[derive(Clone)]
66pub struct Client {
67 config: Arc<ClientConfig>,
68 backend: Arc<dyn HttpBackend>,
69}
70
71impl Client {
72 pub fn new(base_url: impl AsRef<str>) -> Result<Self> {
73 ClientBuilder::new().base_url(base_url)?.build()
74 }
75
76 pub fn builder() -> ClientBuilder {
77 ClientBuilder::new()
78 }
79
80 pub fn with_http_client(reqwest_client: ReqwestClient, base_url: impl AsRef<str>) -> Result<Self> {
82 ClientBuilder::new()
83 .reqwest_client(reqwest_client)
84 .base_url(base_url)?
85 .build()
86 }
87
88 pub fn call<E: Endpoint>(&self) -> EndpointRequestBuilder<'_, E> {
90 EndpointRequestBuilder::new(self.request(E::METHOD, E::PATH))
91 }
92
93 pub fn config(&self) -> &ClientConfig {
94 &self.config
95 }
96
97 pub fn get(&self, path: impl Into<String>) -> RequestBuilder<'_> {
98 self.request(Method::GET, path)
99 }
100
101 pub fn post(&self, path: impl Into<String>) -> RequestBuilder<'_> {
102 self.request(Method::POST, path)
103 }
104
105 pub fn put(&self, path: impl Into<String>) -> RequestBuilder<'_> {
106 self.request(Method::PUT, path)
107 }
108
109 pub fn patch(&self, path: impl Into<String>) -> RequestBuilder<'_> {
110 self.request(Method::PATCH, path)
111 }
112
113 pub fn delete(&self, path: impl Into<String>) -> RequestBuilder<'_> {
114 self.request(Method::DELETE, path)
115 }
116
117 pub fn head(&self, path: impl Into<String>) -> RequestBuilder<'_> {
118 self.request(Method::HEAD, path)
119 }
120
121 pub fn request(&self, method: Method, path: impl Into<String>) -> RequestBuilder<'_> {
122 RequestBuilder {
123 client: self,
124 method,
125 path: path.into(),
126 params: HashMap::new(),
127 query: IndexMap::new(),
128 headers: self.config.default_headers.clone(),
129 body: HttpBody::Empty,
130 #[cfg(feature = "multipart")]
131 multipart: None,
132 timeout: self.config.timeout,
133 retry: self.config.retry.clone(),
134 auth: self.config.auth.clone(),
135 cancellation: None,
136 throw_on_error: false,
137 #[cfg(feature = "json")]
138 json_parser: None,
139 #[cfg(feature = "validate")]
140 validate_response: true,
141 }
142 }
143
144 pub(crate) async fn execute(&self, builder: RequestBuilder<'_>) -> Result<Response> {
145 #[cfg(feature = "json")]
146 let json_parser = builder
147 .json_parser
148 .clone()
149 .or_else(|| self.config.json_parser.clone());
150
151 let built = build_url(
152 &self.config.base_url,
153 &builder.path,
154 &builder.params,
155 &builder.query,
156 )?;
157
158 let mut method = builder.method;
159 if let Some(override_method) = built.method_override {
160 method = override_method;
161 }
162
163 #[cfg(feature = "schema")]
164 if let Some(registry) = &self.config.schema_registry {
165 registry.ensure_route(&builder.path, &method)?;
166 }
167
168 let mut url = built.url;
169 let mut headers = builder.headers;
170 let auth = builder.auth.or_else(|| self.config.auth.clone());
171 if let Some(auth) = auth {
172 auth.apply(&mut headers).await?;
173 }
174
175 let mut prepared = PreparedRequest {
176 url: url.clone(),
177 path: builder.path.clone(),
178 method: method.clone(),
179 headers: headers.clone(),
180 };
181 self.config.plugins.run_init_all(&mut prepared).await?;
182 url = prepared.url;
183 headers = prepared.headers;
184 method = prepared.method;
185
186 let mut req_ctx = RequestContext {
187 url: url.clone(),
188 method: method.clone(),
189 headers: headers.clone(),
190 body: body_for_context(&builder.body),
191 retry_attempt: 0,
192 };
193
194 let merged_hooks = &self.config.merged_hooks;
195 req_ctx = merged_hooks.run_on_request(req_ctx).await?;
196 url = req_ctx.url.clone();
197 headers = req_ctx.headers.clone();
198 method = req_ctx.method.clone();
199
200 let timeout = builder.timeout;
201 let retry_policy = builder.retry.or_else(|| self.config.retry.clone());
202 let throw_on_error = builder.throw_on_error;
203 let cancel = builder.cancellation;
204
205 let backend = self.backend.clone();
206
207 let _in_flight_permit = match &self.config.max_in_flight {
208 Some(sem) => Some(
209 sem.acquire()
210 .await
211 .map_err(|_| Error::Other("max_in_flight semaphore closed".into()))?,
212 ),
213 None => None,
214 };
215
216 let mut attempt = 0u32;
217 let max_attempts = retry_policy.as_ref().map(|p| p.max_attempts()).unwrap_or(0);
218
219 let request_body = builder.body;
220 #[cfg(feature = "multipart")]
221 let mut multipart_body = builder.multipart;
222 #[cfg(feature = "multipart")]
223 let had_multipart = multipart_body.is_some();
224
225 let cancel_ref = cancel.as_ref();
226
227 loop {
228 req_ctx.retry_attempt = attempt;
229
230 #[cfg(feature = "multipart")]
231 if attempt > 0 && had_multipart {
232 return Err(Error::Other(
233 "automatic retry is not supported with multipart request bodies".into(),
234 ));
235 }
236
237 let http_req = HttpRequest {
238 method: method.clone(),
239 url: url.clone(),
240 headers: headers.clone(),
241 body: request_body.clone(),
242 timeout,
243 cancellation: cancel.clone(),
244 #[cfg(feature = "multipart")]
245 multipart: multipart_body.take(),
246 };
247 let request_url = http_req.url.clone();
248
249 let result = execute_or_cancel(cancel_ref, backend.execute(http_req)).await;
250
251 match result {
252 Ok(http_res) => {
253 let response = Response::new(
254 http_res.status,
255 http_res.headers.clone(),
256 http_res.body,
257 Some(request_url.clone()),
258 #[cfg(feature = "json")]
259 json_parser.clone(),
260 );
261
262 let response = merged_hooks
263 .run_on_response(ResponseContext {
264 request: req_ctx.clone(),
265 response,
266 })
267 .await?;
268
269 let should_retry = retry_policy
270 .as_ref()
271 .map(|p| p.should_retry_response(&response, false))
272 .unwrap_or(false);
273
274 if should_retry && attempt < max_attempts {
275 merged_hooks
276 .run_on_retry(ResponseContext {
277 request: req_ctx.clone(),
278 response: response.clone(),
279 })
280 .await;
281 let delay = retry_policy
282 .as_ref()
283 .map(|p| p.delay_after_response(attempt, response.headers()))
284 .unwrap_or(Duration::from_secs(1));
285 attempt += 1;
286 sleep_or_cancel(delay, cancel_ref).await?;
287 continue;
288 }
289
290 if response.is_success() {
291 merged_hooks
292 .run_on_success(SuccessContext {
293 request: req_ctx.clone(),
294 response: response.clone(),
295 })
296 .await;
297 return Ok(response);
298 }
299
300 let status = response.status();
301 let http_err = Error::http_with_status_text(
302 status,
303 status.canonical_reason().unwrap_or("request failed"),
304 status.canonical_reason().unwrap_or("request failed"),
305 Some(response.bytes().clone()),
306 );
307 merged_hooks
308 .run_on_error(ErrorContext {
309 request: req_ctx.clone(),
310 response: Some(response.clone()),
311 error: http_err.clone(),
312 })
313 .await;
314
315 if throw_on_error {
316 return Err(http_err);
317 }
318 return Ok(response);
319 }
320 Err(err) => {
321 if err.is_cancelled() {
322 merged_hooks
323 .run_on_error(ErrorContext {
324 request: req_ctx.clone(),
325 response: None,
326 error: err.clone(),
327 })
328 .await;
329 return Err(err);
330 }
331
332 let retry_transport = matches!(&err, Error::Transport(_) | Error::Timeout);
333 if retry_transport && retry_policy.is_some() && attempt < max_attempts {
334 merged_hooks
335 .run_on_retry(ResponseContext {
336 request: req_ctx.clone(),
337 response: Response::new(
338 http::StatusCode::SERVICE_UNAVAILABLE,
339 http::HeaderMap::new(),
340 bytes::Bytes::new(),
341 Some(request_url.clone()),
342 #[cfg(feature = "json")]
343 None,
344 ),
345 })
346 .await;
347 let delay = retry_policy
348 .as_ref()
349 .map(|p| p.delay_after_response(attempt, &http::HeaderMap::new()))
350 .unwrap_or(Duration::from_secs(1));
351 attempt += 1;
352 sleep_or_cancel(delay, cancel_ref).await?;
353 continue;
354 }
355
356 merged_hooks
357 .run_on_error(ErrorContext {
358 request: req_ctx.clone(),
359 response: None,
360 error: err.clone(),
361 })
362 .await;
363
364 if retry_transport && retry_policy.is_some() {
365 return Err(Error::retry_exhausted(attempt + 1, err));
366 }
367
368 return Err(err);
369 }
370 }
371 }
372 }
373}
374
375pub struct ClientBuilder {
377 base_url: Option<Url>,
378 timeout: Option<Duration>,
379 retry: Option<RetryPolicy>,
380 auth: Option<Auth>,
381 default_headers: http::HeaderMap,
382 hooks: Hooks,
383 plugins: PluginRegistry,
384 reqwest_client: Option<ReqwestClient>,
385 custom_backend: Option<Arc<dyn HttpBackend>>,
386 max_in_flight: Option<usize>,
387 #[cfg(feature = "schema")]
388 schema_registry: Option<Arc<SchemaRegistry>>,
389 #[cfg(feature = "json")]
390 json_parser: Option<JsonParserFn>,
391}
392
393impl ClientBuilder {
394 pub fn new() -> Self {
395 Self {
396 base_url: None,
397 timeout: None,
398 retry: None,
399 auth: None,
400 default_headers: http::HeaderMap::new(),
401 hooks: Hooks::default(),
402 plugins: PluginRegistry::new(),
403 reqwest_client: None,
404 custom_backend: None,
405 max_in_flight: None,
406 #[cfg(feature = "schema")]
407 schema_registry: None,
408 #[cfg(feature = "json")]
409 json_parser: None,
410 }
411 }
412
413 pub fn base_url(mut self, base_url: impl AsRef<str>) -> Result<Self> {
414 self.base_url = Some(Url::parse(base_url.as_ref()).map_err(Error::InvalidBaseUrl)?);
415 Ok(self)
416 }
417
418 pub fn timeout(mut self, timeout: Duration) -> Self {
419 self.timeout = Some(timeout);
420 self
421 }
422
423 pub fn retry(mut self, policy: RetryPolicy) -> Self {
424 self.retry = Some(policy);
425 self
426 }
427
428 pub fn auth(mut self, auth: Auth) -> Self {
429 self.auth = Some(auth);
430 self
431 }
432
433 pub fn default_header(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
434 let name = http::HeaderName::from_bytes(key.as_ref().as_bytes())
435 .map_err(|e| Error::Other(format!("invalid header name: {e}")))?;
436 let value = http::HeaderValue::from_str(value.as_ref())
437 .map_err(|e| Error::Other(format!("invalid header value: {e}")))?;
438 self.default_headers.insert(name, value);
439 Ok(self)
440 }
441
442 pub fn hooks(mut self, hooks: Hooks) -> Self {
443 self.hooks = hooks;
444 self
445 }
446
447 pub fn plugin<P: crate::plugin::Plugin + 'static>(mut self, plugin: P) -> Self {
448 self.plugins.push(Box::new(plugin));
449 self
450 }
451
452 pub fn reqwest_client(mut self, client: ReqwestClient) -> Self {
453 self.reqwest_client = Some(client);
454 self
455 }
456
457 pub fn backend(mut self, backend: Arc<dyn HttpBackend>) -> Self {
459 self.custom_backend = Some(backend);
460 self
461 }
462
463 pub fn max_in_flight(mut self, limit: usize) -> Self {
470 self.max_in_flight = Some(limit);
471 self
472 }
473
474 #[cfg(feature = "schema")]
476 pub fn schema_registry(mut self, registry: Arc<SchemaRegistry>) -> Self {
477 self.schema_registry = Some(registry);
478 self
479 }
480
481 #[cfg(feature = "tower")]
483 pub fn http_service<S>(mut self, service: S) -> Self
484 where
485 S: tower::Service<HttpRequest, Response = HttpResponse, Error = Error>
486 + Clone
487 + Send
488 + 'static,
489 S::Future: Send + 'static,
490 {
491 use crate::tower::ServiceBackend;
492
493 self.custom_backend = Some(Arc::new(ServiceBackend::new(service)));
494 self
495 }
496
497 #[cfg(feature = "tower")]
499 pub fn http_service_boxed(mut self, service: crate::tower::BoxHttpService) -> Self {
500 use crate::tower::ServiceBackend;
501
502 self.custom_backend = Some(Arc::new(ServiceBackend::from_box(service)));
503 self
504 }
505
506 #[cfg(feature = "tower")]
511 pub fn transport_stack<F>(mut self, configure: F) -> Self
512 where
513 F: FnOnce(crate::tower::ReqwestHttpService) -> crate::tower::BoxHttpService,
514 {
515 use crate::tower::ServiceBackend;
516
517 let client = self.reqwest_client.clone().unwrap_or_default();
518 let stacked = configure(crate::tower::ReqwestHttpService::new(client));
519 self.custom_backend = Some(Arc::new(ServiceBackend::from_box(stacked)));
520 self
521 }
522
523 #[cfg(feature = "json")]
529 pub fn json_parser<F>(mut self, f: F) -> Self
530 where
531 F: Fn(&bytes::Bytes) -> std::result::Result<serde_json::Value, String>
532 + Send
533 + Sync
534 + 'static,
535 {
536 self.json_parser = Some(crate::json_parser::json_parser(f));
537 self
538 }
539
540 #[cfg(feature = "json")]
542 pub fn json_parser_fn(mut self, parser: JsonParserFn) -> Self {
543 self.json_parser = Some(parser);
544 self
545 }
546
547 pub fn build(self) -> Result<Client> {
548 let base_url = self
549 .base_url
550 .ok_or(Error::MissingBaseUrl)?;
551
552 let backend: Arc<dyn HttpBackend> = if let Some(b) = self.custom_backend {
553 b
554 } else {
555 let reqwest_client = self.reqwest_client.unwrap_or_default();
556 Arc::new(ReqwestBackend::new(reqwest_client))
557 };
558
559 let plugins = Arc::new(self.plugins);
560 let merged_hooks = self.hooks.clone().merge(plugins.merged_hooks());
561
562 Ok(Client {
563 config: Arc::new(ClientConfig {
564 base_url,
565 timeout: self.timeout,
566 retry: self.retry,
567 auth: self.auth,
568 default_headers: self.default_headers,
569 hooks: self.hooks,
570 merged_hooks,
571 plugins,
572 max_in_flight: self.max_in_flight.map(|n| Arc::new(Semaphore::new(n))),
573 #[cfg(feature = "schema")]
574 schema_registry: self.schema_registry,
575 #[cfg(feature = "json")]
576 json_parser: self.json_parser,
577 }),
578 backend,
579 })
580 }
581}
582
583impl Default for ClientBuilder {
584 fn default() -> Self {
585 Self::new()
586 }
587}