1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use tokio::sync::Semaphore;
6
7use http::Method;
8use reqwest::Client as ReqwestClient;
9use url::Url;
10
11use crate::auth::Auth;
12use crate::backend::{HttpBackend, HttpRequest, ReqwestBackend};
13#[cfg(feature = "tower")]
14use crate::backend::HttpResponse;
15use crate::error::Error;
16use crate::hooks::{ErrorContext, Hooks, RequestContext, ResponseContext, SuccessContext};
17use crate::plugin::{PluginRegistry, PreparedRequest};
18use crate::request::RequestBuilder;
19use crate::response::Response;
20use crate::retry::{sleep_before_retry, RetryPolicy};
21use crate::url_build::build_url;
22use crate::Result;
23
24#[cfg(feature = "json")]
25use crate::json_parser::JsonParserFn;
26
27#[cfg(feature = "schema")]
28use crate::schema::SchemaRegistry;
29
/// Client-wide configuration shared (behind an `Arc`) by every request a
/// [`Client`] issues.
#[derive(Clone)]
pub struct ClientConfig {
    /// URL that request paths are resolved against (passed to `build_url`).
    pub base_url: Url,
    /// Default timeout; each new request builder starts with a copy of it.
    pub timeout: Option<Duration>,
    /// Default retry policy; each new request builder starts with a clone of it.
    pub retry: Option<RetryPolicy>,
    /// Default authentication; applied to request headers before hooks run.
    pub auth: Option<Auth>,
    /// Headers every request starts with (cloned into each request builder).
    pub default_headers: http::HeaderMap,
    /// Client-level hooks; merged with the plugins' hooks on every request.
    pub hooks: Hooks,
    /// Registered plugins; their init step may rewrite the request URL.
    pub plugins: Arc<PluginRegistry>,
    /// Optional semaphore bounding the number of requests executing at once;
    /// `None` means no limit.
    pub max_in_flight: Option<Arc<Semaphore>>,
    /// When set, every request's path and method are checked with
    /// `ensure_route` before anything is sent.
    #[cfg(feature = "schema")]
    pub schema_registry: Option<Arc<SchemaRegistry>>,
    /// Client-wide JSON parser; a parser set on an individual request wins.
    #[cfg(feature = "json")]
    pub json_parser: Option<JsonParserFn>,
}
47
/// HTTP client handle.
///
/// Both fields are `Arc`s, so cloning a `Client` is cheap and the clones share
/// the same configuration and backend.
#[derive(Clone)]
pub struct Client {
    /// Immutable client-wide settings.
    config: Arc<ClientConfig>,
    /// Transport that actually sends each `HttpRequest`.
    backend: Arc<dyn HttpBackend>,
}
54
impl Client {
    /// Creates a client for `base_url` with every other setting left at the
    /// [`ClientBuilder`] defaults.
    ///
    /// # Errors
    /// Returns an error if `base_url` cannot be parsed or if building the
    /// client fails.
    pub fn new(base_url: impl AsRef<str>) -> Result<Self> {
        ClientBuilder::new().base_url(base_url)?.build()
    }

    /// Starts a [`ClientBuilder`] for step-by-step configuration.
    pub fn builder() -> ClientBuilder {
        ClientBuilder::new()
    }

    /// Creates a client that sends requests through the given `reqwest` client.
    ///
    /// No base URL is set here, so the builder's default base URL is used.
    ///
    /// # Errors
    /// Returns an error if building the client fails.
    pub fn with_http_client(reqwest_client: ReqwestClient) -> Result<Self> {
        ClientBuilder::new().reqwest_client(reqwest_client).build()
    }

    /// Returns the client-wide configuration.
    pub fn config(&self) -> &ClientConfig {
        &self.config
    }

    /// Starts a `GET` request for `path`.
    pub fn get(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::GET, path)
    }

    /// Starts a `POST` request for `path`.
    pub fn post(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::POST, path)
    }

    /// Starts a `PUT` request for `path`.
    pub fn put(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::PUT, path)
    }

    /// Starts a `PATCH` request for `path`.
    pub fn patch(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::PATCH, path)
    }

    /// Starts a `DELETE` request for `path`.
    pub fn delete(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::DELETE, path)
    }

    /// Starts a `HEAD` request for `path`.
    pub fn head(&self, path: impl Into<String>) -> RequestBuilder<'_> {
        self.request(Method::HEAD, path)
    }

    /// Starts a request with an arbitrary `method` for `path`.
    ///
    /// The builder is seeded from the client configuration: the default headers,
    /// timeout, retry policy and auth are copied in; path params and query are
    /// empty. The per-request JSON parser starts as `None`, so `execute` falls
    /// back to the client-wide parser unless one is set on the request.
    pub fn request(&self, method: Method, path: impl Into<String>) -> RequestBuilder<'_> {
        RequestBuilder {
            client: self,
            method,
            path: path.into(),
            params: HashMap::new(),
            query: HashMap::new(),
            headers: self.config.default_headers.clone(),
            body: None,
            timeout: self.config.timeout,
            retry: self.config.retry.clone(),
            auth: self.config.auth.clone(),
            #[cfg(feature = "json")]
            json_parser: None,
            #[cfg(feature = "validate")]
            validate_response: true,
        }
    }

    /// Sends the request described by `builder`, applying plugins, auth, hooks,
    /// the concurrency limit and the retry policy.
    ///
    /// A response whose `is_success()` is false is still returned as `Ok`; the
    /// `on_error` hooks are notified, but the caller decides how to treat it.
    ///
    /// # Errors
    /// Propagates errors from the following steps:
    /// - URL building.
    /// - The schema route check (with the `schema` feature).
    /// - Plugin init.
    /// - Auth application.
    /// - The `on_request` and `on_response` hooks.
    /// - A closed `max_in_flight` semaphore.
    /// - The backend itself.
    ///
    /// A `Transport`/`Timeout` error that occurs while a retry policy is present
    /// is wrapped with `Error::retry_exhausted`.
    pub(crate) async fn execute(&self, builder: RequestBuilder<'_>) -> Result<Response> {
        // A parser set on this request takes precedence over the client-wide one.
        #[cfg(feature = "json")]
        let json_parser = builder
            .json_parser
            .clone()
            .or_else(|| self.config.json_parser.clone());
        let built = build_url(
            &self.config.base_url,
            &builder.path,
            &builder.params,
            &builder.query,
        )?;

        // `build_url` may dictate a different method; it replaces the requested one.
        let mut method = builder.method;
        if let Some(override_method) = built.method_override {
            method = override_method;
        }

        // The route check uses the unexpanded path and the (possibly overridden) method.
        #[cfg(feature = "schema")]
        if let Some(registry) = &self.config.schema_registry {
            registry.ensure_route(&builder.path, &method)?;
        }

        let mut url = built.url;

        // Plugins get a chance to rewrite the URL before auth and hooks run.
        let mut prepared = PreparedRequest {
            url: url.clone(),
            path: builder.path.clone(),
        };
        self.config.plugins.run_init_all(&mut prepared).await?;
        url = prepared.url;

        // NOTE(review): `request()` already seeds `builder.auth` (and `builder.retry`
        // below) from the config, so this fallback only matters if the builder
        // cleared the value — confirm whether clearing is meant to disable it.
        // Auth is applied once here; retries below reuse the same headers.
        let mut headers = builder.headers;
        let auth = builder.auth.or_else(|| self.config.auth.clone());
        if let Some(auth) = auth {
            auth.apply(&mut headers).await?;
        }

        let mut req_ctx = RequestContext {
            url: url.clone(),
            method: method.clone(),
            headers: headers.clone(),
            body: builder.body.clone(),
            retry_attempt: 0,
        };

        // Client hooks and plugin hooks are combined per request.
        let merged_hooks = self
            .config
            .hooks
            .clone()
            .merge(self.config.plugins.merged_hooks());

        // `on_request` hooks may rewrite URL, headers, method and body; the
        // values actually sent are taken from the returned context.
        req_ctx = merged_hooks.run_on_request(req_ctx).await?;
        url = req_ctx.url.clone();
        headers = req_ctx.headers.clone();
        method = req_ctx.method.clone();

        let timeout = builder.timeout;
        let retry_policy = builder.retry.or_else(|| self.config.retry.clone());

        let backend = self.backend.clone();
        let body = req_ctx.body.clone();

        // The permit is held until this function returns, i.e. across every
        // attempt and every backoff sleep between attempts.
        // NOTE(review): a request waiting to retry therefore still occupies a slot.
        let _in_flight_permit = match &self.config.max_in_flight {
            Some(sem) => Some(
                sem.acquire()
                    .await
                    .map_err(|_| Error::Other("max_in_flight semaphore closed".into()))?,
            ),
            None => None,
        };

        // Retries happen while `attempt < max_attempts`, i.e. up to `max_attempts`
        // retries after the initial send.
        // NOTE(review): confirm this matches `RetryPolicy::max_attempts` semantics
        // (retry count vs. total attempts).
        let mut attempt = 0u32;
        let max_attempts = retry_policy.as_ref().map(|p| p.max_attempts()).unwrap_or(0);

        // Built once; every attempt sends an identical clone.
        let http_req = HttpRequest {
            method,
            url,
            headers,
            body,
            timeout,
        };

        loop {
            req_ctx.retry_attempt = attempt;

            let result = backend.execute(http_req.clone()).await;

            match result {
                Ok(http_res) => {
                    let response = Response::new(
                        http_res.status,
                        http_res.headers,
                        http_res.body,
                        Some(http_req.url.clone()),
                        #[cfg(feature = "json")]
                        json_parser.clone(),
                    );

                    // `on_response` hooks may replace the response or abort with an error.
                    let response = merged_hooks
                        .run_on_response(ResponseContext {
                            request: req_ctx.clone(),
                            response,
                        })
                        .await?;

                    let should_retry = retry_policy
                        .as_ref()
                        .map(|p| p.should_retry_response(&response, false))
                        .unwrap_or(false);

                    if should_retry && attempt < max_attempts {
                        merged_hooks
                            .run_on_retry(ResponseContext {
                                request: req_ctx.clone(),
                                response: response.clone(),
                            })
                            .await;
                        // The 1s fallback is effectively unreachable: `should_retry`
                        // is only true when a policy is present.
                        let delay = retry_policy
                            .as_ref()
                            .map(|p| p.delay_before_attempt(attempt))
                            .unwrap_or(Duration::from_secs(1));
                        attempt += 1;
                        sleep_before_retry(delay).await;
                        continue;
                    }

                    if response.is_success() {
                        merged_hooks
                            .run_on_success(SuccessContext {
                                request: req_ctx.clone(),
                                response: response.clone(),
                            })
                            .await;
                    } else {
                        // Error hooks get a synthesized HTTP error, but the response
                        // itself is still returned as `Ok` below.
                        let status = response.status();
                        merged_hooks
                            .run_on_error(ErrorContext {
                                request: req_ctx.clone(),
                                response: Some(response.clone()),
                                error: Error::http_with_status_text(
                                    status,
                                    status.canonical_reason().unwrap_or("request failed"),
                                    status.canonical_reason().unwrap_or("request failed"),
                                    Some(response.bytes().clone()),
                                ),
                            })
                            .await;
                    }

                    return Ok(response);
                }
                Err(err) => {
                    // Only transport failures and timeouts are retried.
                    let retry_transport = matches!(&err, Error::Transport(_) | Error::Timeout);
                    if retry_transport && retry_policy.is_some() && attempt < max_attempts {
                        // There is no real response, so retry hooks receive a
                        // synthetic empty 503 for this URL.
                        merged_hooks
                            .run_on_retry(ResponseContext {
                                request: req_ctx.clone(),
                                response: Response::new(
                                    http::StatusCode::SERVICE_UNAVAILABLE,
                                    http::HeaderMap::new(),
                                    bytes::Bytes::new(),
                                    Some(http_req.url.clone()),
                                    #[cfg(feature = "json")]
                                    None,
                                ),
                            })
                            .await;
                        let delay = retry_policy
                            .as_ref()
                            .map(|p| p.delay_before_attempt(attempt))
                            .unwrap_or(Duration::from_secs(1));
                        attempt += 1;
                        sleep_before_retry(delay).await;
                        continue;
                    }

                    merged_hooks
                        .run_on_error(ErrorContext {
                            request: req_ctx.clone(),
                            response: None,
                            error: err.clone(),
                        })
                        .await;

                    // With a policy present, a retryable failure is reported together
                    // with the number of attempts made (`attempt + 1`).
                    if retry_transport && retry_policy.is_some() {
                        return Err(Error::retry_exhausted(attempt + 1, err));
                    }

                    return Err(err);
                }
            }
        }
    }
}
310
/// Step-by-step builder for [`Client`]; finish with `build()`.
pub struct ClientBuilder {
    /// Base URL; `build()` falls back to `http://localhost` when unset.
    base_url: Option<Url>,
    /// Default per-request timeout.
    timeout: Option<Duration>,
    /// Default retry policy.
    retry: Option<RetryPolicy>,
    /// Default authentication.
    auth: Option<Auth>,
    /// Headers every request starts with.
    default_headers: http::HeaderMap,
    /// Client-level hooks.
    hooks: Hooks,
    /// Plugins, in registration order.
    plugins: PluginRegistry,
    /// `reqwest` client used for the default backend (ignored if a custom
    /// backend is set).
    reqwest_client: Option<ReqwestClient>,
    /// Custom transport; takes precedence over `reqwest_client`.
    custom_backend: Option<Arc<dyn HttpBackend>>,
    /// Maximum number of concurrently executing requests, if limited.
    max_in_flight: Option<usize>,
    /// Optional schema registry used to validate routes.
    #[cfg(feature = "schema")]
    schema_registry: Option<Arc<SchemaRegistry>>,
    /// Client-wide JSON parser.
    #[cfg(feature = "json")]
    json_parser: Option<JsonParserFn>,
}
328
329impl ClientBuilder {
330 pub fn new() -> Self {
331 Self {
332 base_url: None,
333 timeout: None,
334 retry: None,
335 auth: None,
336 default_headers: http::HeaderMap::new(),
337 hooks: Hooks::default(),
338 plugins: PluginRegistry::new(),
339 reqwest_client: None,
340 custom_backend: None,
341 max_in_flight: None,
342 #[cfg(feature = "schema")]
343 schema_registry: None,
344 #[cfg(feature = "json")]
345 json_parser: None,
346 }
347 }
348
349 pub fn base_url(mut self, base_url: impl AsRef<str>) -> Result<Self> {
350 self.base_url = Some(Url::parse(base_url.as_ref()).map_err(Error::InvalidBaseUrl)?);
351 Ok(self)
352 }
353
354 pub fn timeout(mut self, timeout: Duration) -> Self {
355 self.timeout = Some(timeout);
356 self
357 }
358
359 pub fn retry(mut self, policy: RetryPolicy) -> Self {
360 self.retry = Some(policy);
361 self
362 }
363
364 pub fn auth(mut self, auth: Auth) -> Self {
365 self.auth = Some(auth);
366 self
367 }
368
369 pub fn default_header(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Result<Self> {
370 let name = http::HeaderName::from_bytes(key.as_ref().as_bytes())
371 .map_err(|e| Error::Other(format!("invalid header name: {e}")))?;
372 let value = http::HeaderValue::from_str(value.as_ref())
373 .map_err(|e| Error::Other(format!("invalid header value: {e}")))?;
374 self.default_headers.insert(name, value);
375 Ok(self)
376 }
377
378 pub fn hooks(mut self, hooks: Hooks) -> Self {
379 self.hooks = hooks;
380 self
381 }
382
383 pub fn plugin<P: crate::plugin::Plugin + 'static>(mut self, plugin: P) -> Self {
384 self.plugins.push(Box::new(plugin));
385 self
386 }
387
388 pub fn reqwest_client(mut self, client: ReqwestClient) -> Self {
389 self.reqwest_client = Some(client);
390 self
391 }
392
393 pub fn backend(mut self, backend: Arc<dyn HttpBackend>) -> Self {
395 self.custom_backend = Some(backend);
396 self
397 }
398
399 pub fn max_in_flight(mut self, limit: usize) -> Self {
405 self.max_in_flight = Some(limit);
406 self
407 }
408
409 #[cfg(feature = "schema")]
411 pub fn schema_registry(mut self, registry: Arc<SchemaRegistry>) -> Self {
412 self.schema_registry = Some(registry);
413 self
414 }
415
416 #[cfg(feature = "tower")]
418 pub fn http_service<S>(mut self, service: S) -> Self
419 where
420 S: tower::Service<HttpRequest, Response = HttpResponse, Error = Error>
421 + Clone
422 + Send
423 + 'static,
424 S::Future: Send + 'static,
425 {
426 use crate::tower::ServiceBackend;
427
428 self.custom_backend = Some(Arc::new(ServiceBackend::new(service)));
429 self
430 }
431
432 #[cfg(feature = "tower")]
434 pub fn http_service_boxed(mut self, service: crate::tower::BoxHttpService) -> Self {
435 use crate::tower::ServiceBackend;
436
437 self.custom_backend = Some(Arc::new(ServiceBackend::from_box(service)));
438 self
439 }
440
441 #[cfg(feature = "tower")]
446 pub fn transport_stack<F>(mut self, configure: F) -> Self
447 where
448 F: FnOnce(crate::tower::ReqwestHttpService) -> crate::tower::BoxHttpService,
449 {
450 use crate::tower::ServiceBackend;
451
452 let client = self.reqwest_client.clone().unwrap_or_default();
453 let stacked = configure(crate::tower::ReqwestHttpService::new(client));
454 self.custom_backend = Some(Arc::new(ServiceBackend::from_box(stacked)));
455 self
456 }
457
458 #[cfg(feature = "json")]
463 pub fn json_parser<F>(mut self, f: F) -> Self
464 where
465 F: Fn(&bytes::Bytes) -> std::result::Result<serde_json::Value, String>
466 + Send
467 + Sync
468 + 'static,
469 {
470 self.json_parser = Some(crate::json_parser::json_parser(f));
471 self
472 }
473
474 #[cfg(feature = "json")]
476 pub fn json_parser_fn(mut self, parser: JsonParserFn) -> Self {
477 self.json_parser = Some(parser);
478 self
479 }
480
481 pub fn build(self) -> Result<Client> {
482 let base_url = match self.base_url {
483 Some(url) => url,
484 None => Url::parse("http://localhost")
485 .map_err(|e| Error::Other(format!("invalid default base URL: {e}")))?,
486 };
487
488 let backend: Arc<dyn HttpBackend> = if let Some(b) = self.custom_backend {
489 b
490 } else {
491 let reqwest_client = self.reqwest_client.unwrap_or_default();
492 Arc::new(ReqwestBackend::new(reqwest_client))
493 };
494
495 Ok(Client {
496 config: Arc::new(ClientConfig {
497 base_url,
498 timeout: self.timeout,
499 retry: self.retry,
500 auth: self.auth,
501 default_headers: self.default_headers,
502 hooks: self.hooks,
503 plugins: Arc::new(self.plugins),
504 max_in_flight: self.max_in_flight.map(|n| Arc::new(Semaphore::new(n))),
505 #[cfg(feature = "schema")]
506 schema_registry: self.schema_registry,
507 #[cfg(feature = "json")]
508 json_parser: self.json_parser,
509 }),
510 backend,
511 })
512 }
513}
514
515impl Default for ClientBuilder {
516 fn default() -> Self {
517 Self::new()
518 }
519}