1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::Semaphore;
6
7use http::Method;
8use reqwest::Client as ReqwestClient;
9use url::Url;
10
11use crate::auth::Auth;
12use crate::backend::{HttpBackend, HttpRequest, HttpResponse, ReqwestBackend};
13use crate::error::Error;
14use crate::hooks::{ErrorContext, Hooks, RequestContext, ResponseContext, SuccessContext};
15use crate::plugin::{PluginRegistry, PreparedRequest};
16use crate::request::RequestBuilder;
17use crate::response::Response;
18use crate::retry::{sleep_before_retry, RetryPolicy};
19use crate::url_build::build_url;
20use crate::Result;
21
22#[cfg(feature = "json")]
23use crate::json_parser::JsonParserFn;
24
25#[cfg(feature = "schema")]
26use crate::schema::SchemaRegistry;
27
28#[derive(Clone)]
30pub struct ClientConfig {
31 pub base_url: Url,
32 pub timeout: Option<Duration>,
33 pub retry: Option<RetryPolicy>,
34 pub auth: Option<Auth>,
35 pub default_headers: http::HeaderMap,
36 pub hooks: Hooks,
37 pub plugins: Arc<PluginRegistry>,
38 pub max_in_flight: Option<Arc<Semaphore>>,
40 #[cfg(feature = "schema")]
41 pub schema_registry: Option<Arc<SchemaRegistry>>,
42 #[cfg(feature = "json")]
43 pub json_parser: Option<JsonParserFn>,
44}
45
46#[derive(Clone)]
48pub struct Client {
49 config: Arc<ClientConfig>,
50 backend: Arc<dyn HttpBackend>,
51}
52
impl Client {
    /// Creates a client for `base_url`, with every other setting left at the
    /// [`ClientBuilder`] defaults.
    ///
    /// # Errors
    /// Returns an error if `base_url` cannot be parsed, or if
    /// [`ClientBuilder::build`] fails.
    pub fn new(base_url: impl AsRef<str>) -> Result<Self> {
        ClientBuilder::new().base_url(base_url)?.build()
    }

    /// Returns a fresh [`ClientBuilder`] for configuring a client.
    pub fn builder() -> ClientBuilder {
        ClientBuilder::new()
    }

    /// Creates a client whose default backend uses `reqwest_client`.
    ///
    /// No base URL is set, so [`ClientBuilder::build`] uses its fallback base
    /// URL.
    ///
    /// # Errors
    /// Propagates any error from [`ClientBuilder::build`].
    pub fn with_http_client(reqwest_client: ReqwestClient) -> Result<Self> {
        ClientBuilder::new().reqwest_client(reqwest_client).build()
    }

    /// Returns the shared configuration this client was built with.
    pub fn config(&self) -> &ClientConfig {
        &self.config
    }

    /// Starts a `GET` request to `path`.
    pub fn get(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::GET, path)
    }

    /// Starts a `POST` request to `path`.
    pub fn post(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::POST, path)
    }

    /// Starts a `PUT` request to `path`.
    pub fn put(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::PUT, path)
    }

    /// Starts a `PATCH` request to `path`.
    pub fn patch(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::PATCH, path)
    }

    /// Starts a `DELETE` request to `path`.
    pub fn delete(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::DELETE, path)
    }

    /// Starts a `HEAD` request to `path`.
    pub fn head(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::HEAD, path)
    }

    /// Starts a request with an arbitrary HTTP `method`.
    ///
    /// The builder is seeded from the client configuration: default headers,
    /// timeout, retry policy and auth are copied so they can be overridden per
    /// request. The `params`/`query` maps and the body start empty.
    pub fn request(&self, method: Method, path: impl Into<String>) -> RequestBuilder<'_> {
        RequestBuilder {
            client: self,
            method,
            path: path.into(),
            params: HashMap::new(),
            query: HashMap::new(),
            headers: self.config.default_headers.clone(),
            body: None,
            timeout: self.config.timeout,
            retry: self.config.retry.clone(),
            auth: self.config.auth.clone(),
            // `None` means "use the client-level parser"; resolved in `execute`.
            #[cfg(feature = "json")]
            json_parser: None,
            #[cfg(feature = "validate")]
            validate_response: true,
        }
    }

    /// Sends the request described by `builder` through the client pipeline.
    ///
    /// Steps, in order:
    /// 1. Resolve the JSON parser (request-level first, then client-level).
    /// 2. Build the final URL with `build_url`, which may also replace the method.
    /// 3. With the `schema` feature, check the route against the registry.
    /// 4. Let plugins rewrite the URL (`run_init_all`).
    /// 5. Apply auth to the headers.
    /// 6. Run `on_request` hooks (client hooks merged with plugin hooks). The
    ///    URL, method, headers and body they return are what gets sent.
    /// 7. Acquire a `max_in_flight` permit if configured. It is held until this
    ///    function returns, i.e. across all attempts and backoff sleeps.
    /// 8. Send, retrying retryable responses and transport/timeout errors while
    ///    `attempt < max_attempts`.
    ///
    /// A non-success response that is not retried (or whose retries are used
    /// up) is still returned as `Ok`. The `on_error` hooks are notified, but the
    /// caller must inspect the status.
    ///
    /// # Errors
    /// Errors from URL building, schema checks, plugins, auth, the semaphore,
    /// and the `on_request` / `on_response` hooks are propagated with `?`.
    /// When a retry policy is configured and a transport error or timeout is
    /// not retried further, the error is wrapped with `Error::retry_exhausted`.
    /// Any other backend error is returned unchanged.
    pub(crate) async fn execute(&self, builder: RequestBuilder<'_>) -> Result<Response> {
        // Request-level parser wins; otherwise fall back to the client's.
        #[cfg(feature = "json")]
        let json_parser = builder
            .json_parser
            .clone()
            .or_else(|| self.config.json_parser.clone());
        let built = build_url(
            &self.config.base_url,
            &builder.path,
            &builder.params,
            &builder.query,
        )?;

        // `build_url` can ask for a different method; when it does, it wins.
        let mut method = builder.method;
        if let Some(override_method) = built.method_override {
            method = override_method;
        }

        #[cfg(feature = "schema")]
        if let Some(registry) = &self.config.schema_registry {
            registry.ensure_route(&builder.path, &method)?;
        }

        let mut url = built.url;

        // Plugins may rewrite the URL; only `prepared.url` is read back.
        let mut prepared = PreparedRequest {
            url: url.clone(),
            path: builder.path.clone(),
        };
        self.config.plugins.run_init_all(&mut prepared).await?;
        url = prepared.url;

        let mut headers = builder.headers;
        // NOTE(review): `request()` already seeds `auth` from the config, so this
        // fallback only changes anything when a caller explicitly cleared it to
        // `None` -- in which case client auth is re-applied. Confirm intended.
        let auth = builder.auth.or_else(|| self.config.auth.clone());
        if let Some(auth) = auth {
            auth.apply(&mut headers).await?;
        }

        let mut req_ctx = RequestContext {
            url: url.clone(),
            method: method.clone(),
            headers: headers.clone(),
            body: builder.body.clone(),
            retry_attempt: 0,
        };

        let merged_hooks = self
            .config
            .hooks
            .clone()
            .merge(self.config.plugins.merged_hooks());

        // `on_request` hooks may rewrite the request; adopt their version.
        // They run once, not again on each retry.
        req_ctx = merged_hooks.run_on_request(req_ctx).await?;
        url = req_ctx.url.clone();
        headers = req_ctx.headers.clone();
        method = req_ctx.method.clone();

        let timeout = builder.timeout;
        // NOTE(review): same fallback pattern as `auth` above -- an explicit
        // per-request `None` re-enables the client retry policy. Confirm intended.
        let retry_policy = builder.retry.or_else(|| self.config.retry.clone());

        let backend = self.backend.clone();
        let body = req_ctx.body.clone();

        // Bound to a named `_`-prefixed variable (not `_`) so the permit is kept
        // alive until this function returns, covering every retry attempt.
        let _in_flight_permit = match &self.config.max_in_flight {
            Some(sem) => Some(
                sem.acquire()
                    .await
                    .map_err(|_| Error::Other("max_in_flight semaphore closed".into()))?,
            ),
            None => None,
        };

        // `attempt` counts retries already made; with no policy no retries happen.
        let mut attempt = 0u32;
        let max_attempts = retry_policy.as_ref().map(|p| p.max_attempts()).unwrap_or(0);

        // The same request is re-sent (cloned) on every attempt.
        let http_req = HttpRequest {
            method,
            url,
            headers,
            body,
            timeout,
        };

        loop {
            // Hooks see the zero-based index of the current attempt.
            req_ctx.retry_attempt = attempt;

            let result = backend.execute(http_req.clone()).await;

            match result {
                Ok(http_res) => {
                    let response = Response::new(
                        http_res.status,
                        http_res.headers,
                        http_res.body,
                        Some(http_req.url.clone()),
                        #[cfg(feature = "json")]
                        json_parser.clone(),
                    );

                    // `on_response` hooks may replace the response, and may fail
                    // the whole request.
                    let response = merged_hooks
                        .run_on_response(ResponseContext {
                            request: req_ctx.clone(),
                            response,
                        })
                        .await?;

                    let should_retry = retry_policy
                        .as_ref()
                        .map(|p| p.should_retry_response(&response, false))
                        .unwrap_or(false);

                    if should_retry && attempt < max_attempts {
                        merged_hooks
                            .run_on_retry(ResponseContext {
                                request: req_ctx.clone(),
                                response: response.clone(),
                            })
                            .await;
                        // `should_retry` implies a policy exists, so the 1s
                        // fallback is not reached here.
                        let delay = retry_policy
                            .as_ref()
                            .map(|p| p.delay_before_attempt(attempt))
                            .unwrap_or(Duration::from_secs(1));
                        attempt += 1;
                        sleep_before_retry(delay).await;
                        continue;
                    }

                    // Final response: notify success/error hooks, then return it
                    // as `Ok` regardless of status.
                    if response.is_success() {
                        merged_hooks
                            .run_on_success(SuccessContext {
                                request: req_ctx.clone(),
                                response: response.clone(),
                            })
                            .await;
                    } else {
                        let status = response.status();
                        merged_hooks
                            .run_on_error(ErrorContext {
                                request: req_ctx.clone(),
                                response: Some(response.clone()),
                                error: Error::http_with_status_text(
                                    status,
                                    status.canonical_reason().unwrap_or("request failed"),
                                    status.canonical_reason().unwrap_or("request failed"),
                                    Some(response.bytes().clone()),
                                ),
                            })
                            .await;
                    }

                    return Ok(response);
                }
                Err(err) => {
                    // Only transport failures and timeouts are retried; other
                    // backend errors fail immediately.
                    let retry_transport = matches!(&err, Error::Transport(_) | Error::Timeout);
                    if retry_transport && retry_policy.is_some() && attempt < max_attempts {
                        // There is no real response here, so retry hooks receive a
                        // synthetic 503 with empty headers and body.
                        merged_hooks
                            .run_on_retry(ResponseContext {
                                request: req_ctx.clone(),
                                response: Response::new(
                                    http::StatusCode::SERVICE_UNAVAILABLE,
                                    http::HeaderMap::new(),
                                    bytes::Bytes::new(),
                                    Some(http_req.url.clone()),
                                    #[cfg(feature = "json")]
                                    None,
                                ),
                            })
                            .await;
                        // Guarded by `retry_policy.is_some()`; the 1s fallback is
                        // not reached here.
                        let delay = retry_policy
                            .as_ref()
                            .map(|p| p.delay_before_attempt(attempt))
                            .unwrap_or(Duration::from_secs(1));
                        attempt += 1;
                        sleep_before_retry(delay).await;
                        continue;
                    }

                    merged_hooks
                        .run_on_error(ErrorContext {
                            request: req_ctx.clone(),
                            response: None,
                            error: err.clone(),
                        })
                        .await;

                    // `attempt + 1` is the total number of sends made.
                    if retry_transport && retry_policy.is_some() {
                        return Err(Error::retry_exhausted(attempt + 1, err));
                    }

                    return Err(err);
                }
            }
        }
    }
}
308
309pub struct ClientBuilder {
311 base_url: Option<Url>,
312 timeout: Option<Duration>,
313 retry: Option<RetryPolicy>,
314 auth: Option<Auth>,
315 default_headers: http::HeaderMap,
316 hooks: Hooks,
317 plugins: PluginRegistry,
318 reqwest_client: Option<ReqwestClient>,
319 custom_backend: Option<Arc<dyn HttpBackend>>,
320 max_in_flight: Option<usize>,
321 #[cfg(feature = "schema")]
322 schema_registry: Option<Arc<SchemaRegistry>>,
323 #[cfg(feature = "json")]
324 json_parser: Option<JsonParserFn>,
325}
326
327impl ClientBuilder {
328 pub fn new() -> Self {
329 Self {
330 base_url: None,
331 timeout: None,
332 retry: None,
333 auth: None,
334 default_headers: http::HeaderMap::new(),
335 hooks: Hooks::default(),
336 plugins: PluginRegistry::new(),
337 reqwest_client: None,
338 custom_backend: None,
339 max_in_flight: None,
340 #[cfg(feature = "schema")]
341 schema_registry: None,
342 #[cfg(feature = "json")]
343 json_parser: None,
344 }
345 }
346
347 pub fn base_url(mut self, base_url: impl AsRef<str>) -> Result<Self> {
348 self.base_url = Some(Url::parse(base_url.as_ref()).map_err(Error::InvalidBaseUrl)?);
349 Ok(self)
350 }
351
352 pub fn timeout(mut self, timeout: Duration) -> Self {
353 self.timeout = Some(timeout);
354 self
355 }
356
357 pub fn retry(mut self, policy: RetryPolicy) -> Self {
358 self.retry = Some(policy);
359 self
360 }
361
362 pub fn auth(mut self, auth: Auth) -> Self {
363 self.auth = Some(auth);
364 self
365 }
366
367 pub fn default_header(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
368 let name = http::HeaderName::from_bytes(key.as_ref().as_bytes())
369 .map_err(|e| Error::Other(format!("invalid header name: {e}")))?;
370 let value = http::HeaderValue::from_str(value.as_ref())
371 .map_err(|e| Error::Other(format!("invalid header value: {e}")))?;
372 self.default_headers.insert(name, value);
373 Ok(self)
374 }
375
376 pub fn hooks(mut self, hooks: Hooks) -> Self {
377 self.hooks = hooks;
378 self
379 }
380
381 pub fn plugin<P: crate::plugin::Plugin + 'static>(mut self, plugin: P) -> Self {
382 self.plugins.push(Box::new(plugin));
383 self
384 }
385
386 pub fn reqwest_client(mut self, client: ReqwestClient) -> Self {
387 self.reqwest_client = Some(client);
388 self
389 }
390
391 pub fn backend(mut self, backend: Arc<dyn HttpBackend>) -> Self {
393 self.custom_backend = Some(backend);
394 self
395 }
396
397 pub fn max_in_flight(mut self, limit: usize) -> Self {
403 self.max_in_flight = Some(limit);
404 self
405 }
406
407 #[cfg(feature = "schema")]
409 pub fn schema_registry(mut self, registry: Arc<SchemaRegistry>) -> Self {
410 self.schema_registry = Some(registry);
411 self
412 }
413
414 #[cfg(feature = "tower")]
416 pub fn http_service<S>(mut self, service: S) -> Self
417 where
418 S: tower::Service<HttpRequest, Response = HttpResponse, Error = Error>
419 + Clone
420 + Send
421 + 'static,
422 S::Future: Send + 'static,
423 {
424 use crate::tower::ServiceBackend;
425
426 self.custom_backend = Some(Arc::new(ServiceBackend::new(service)));
427 self
428 }
429
430 #[cfg(feature = "tower")]
432 pub fn http_service_boxed(mut self, service: crate::tower::BoxHttpService) -> Self {
433 use crate::tower::ServiceBackend;
434
435 self.custom_backend = Some(Arc::new(ServiceBackend::from_box(service)));
436 self
437 }
438
439 #[cfg(feature = "tower")]
444 pub fn transport_stack<F>(mut self, configure: F) -> Self
445 where
446 F: FnOnce(crate::tower::ReqwestHttpService) -> crate::tower::BoxHttpService,
447 {
448 use crate::tower::ServiceBackend;
449
450 let client = self.reqwest_client.clone().unwrap_or_default();
451 let stacked = configure(crate::tower::ReqwestHttpService::new(client));
452 self.custom_backend = Some(Arc::new(ServiceBackend::from_box(stacked)));
453 self
454 }
455
456 #[cfg(feature = "json")]
461 pub fn json_parser<F>(mut self, f: F) -> Self
462 where
463 F: Fn(&bytes::Bytes) -> std::result::Result<serde_json::Value, String>
464 + Send
465 + Sync
466 + 'static,
467 {
468 self.json_parser = Some(crate::json_parser::json_parser(f));
469 self
470 }
471
472 #[cfg(feature = "json")]
474 pub fn json_parser_fn(mut self, parser: JsonParserFn) -> Self {
475 self.json_parser = Some(parser);
476 self
477 }
478
479 pub fn build(self) -> Result<Client> {
480 let base_url = match self.base_url {
481 Some(url) => url,
482 None => Url::parse("http://localhost")
483 .map_err(|e| Error::Other(format!("invalid default base URL: {e}")))?,
484 };
485
486 let backend: Arc<dyn HttpBackend> = if let Some(b) = self.custom_backend {
487 b
488 } else {
489 let reqwest_client = self.reqwest_client.unwrap_or_default();
490 Arc::new(ReqwestBackend::new(reqwest_client))
491 };
492
493 Ok(Client {
494 config: Arc::new(ClientConfig {
495 base_url,
496 timeout: self.timeout,
497 retry: self.retry,
498 auth: self.auth,
499 default_headers: self.default_headers,
500 hooks: self.hooks,
501 plugins: Arc::new(self.plugins),
502 max_in_flight: self.max_in_flight.map(|n| Arc::new(Semaphore::new(n))),
503 #[cfg(feature = "schema")]
504 schema_registry: self.schema_registry,
505 #[cfg(feature = "json")]
506 json_parser: self.json_parser,
507 }),
508 backend,
509 })
510 }
511}
512
513impl Default for ClientBuilder {
514 fn default() -> Self {
515 Self::new()
516 }
517}