1use std::borrow::Cow;
2
3use crate::{Method, Request, Response, Uri};
4use rama_core::{
5 Context, Service,
6 error::{BoxError, ErrorExt, OpaqueError},
7};
8
9pub trait HttpClientExt<State>:
13 private::HttpClientExtSealed<State> + Sized + Send + Sync + 'static
14{
15 type ExecuteResponse;
17 type ExecuteError;
19
20 fn get(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
28
29 fn post(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
37
38 fn put(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
46
47 fn patch(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
55
56 fn delete(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
64
65 fn head(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
71
72 fn request(
85 &self,
86 method: Method,
87 url: impl IntoUrl,
88 ) -> RequestBuilder<Self, State, Self::ExecuteResponse>;
89
90 fn execute(
96 &self,
97 ctx: Context<State>,
98 request: Request,
99 ) -> impl Future<Output = Result<Self::ExecuteResponse, Self::ExecuteError>>;
100}
101
102impl<State, S, Body> HttpClientExt<State> for S
103where
104 S: Service<State, Request, Response = Response<Body>, Error: Into<BoxError>>,
105{
106 type ExecuteResponse = Response<Body>;
107 type ExecuteError = S::Error;
108
109 fn get(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
110 self.request(Method::GET, url)
111 }
112
113 fn post(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
114 self.request(Method::POST, url)
115 }
116
117 fn put(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
118 self.request(Method::PUT, url)
119 }
120
121 fn patch(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
122 self.request(Method::PATCH, url)
123 }
124
125 fn delete(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
126 self.request(Method::DELETE, url)
127 }
128
129 fn head(&self, url: impl IntoUrl) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
130 self.request(Method::HEAD, url)
131 }
132
133 fn request(
134 &self,
135 method: Method,
136 url: impl IntoUrl,
137 ) -> RequestBuilder<Self, State, Self::ExecuteResponse> {
138 let uri = match url.into_url() {
139 Ok(uri) => uri,
140 Err(err) => {
141 return RequestBuilder {
142 http_client_service: self,
143 state: RequestBuilderState::Error(err),
144 _phantom: std::marker::PhantomData,
145 };
146 }
147 };
148
149 let builder = crate::dep::http::request::Builder::new()
150 .method(method)
151 .uri(uri);
152
153 RequestBuilder {
154 http_client_service: self,
155 state: RequestBuilderState::PreBody(builder),
156 _phantom: std::marker::PhantomData,
157 }
158 }
159
160 fn execute(
161 &self,
162 ctx: Context<State>,
163 request: Request,
164 ) -> impl Future<Output = Result<Self::ExecuteResponse, Self::ExecuteError>> {
165 Service::serve(self, ctx, request)
166 }
167}
168
169pub trait IntoUrl: private::IntoUrlSealed {}
175
176impl IntoUrl for Uri {}
177impl IntoUrl for &str {}
178impl IntoUrl for String {}
179impl IntoUrl for &String {}
180
181pub trait IntoHeaderName: private::IntoHeaderNameSealed {}
187
188impl IntoHeaderName for crate::HeaderName {}
189impl IntoHeaderName for Option<crate::HeaderName> {}
190impl IntoHeaderName for &str {}
191impl IntoHeaderName for String {}
192impl IntoHeaderName for &String {}
193impl IntoHeaderName for &[u8] {}
194
195pub trait IntoHeaderValue: private::IntoHeaderValueSealed {}
201
202impl IntoHeaderValue for crate::HeaderValue {}
203impl IntoHeaderValue for &str {}
204impl IntoHeaderValue for String {}
205impl IntoHeaderValue for &String {}
206impl IntoHeaderValue for &[u8] {}
207
208mod private {
209 use rama_http_types::HeaderName;
210 use rama_net::Protocol;
211
212 use super::*;
213
214 pub trait IntoUrlSealed {
215 fn into_url(self) -> Result<Uri, OpaqueError>;
216 }
217
218 impl IntoUrlSealed for Uri {
219 fn into_url(self) -> Result<Uri, OpaqueError> {
220 let protocol: Option<Protocol> = self.scheme().map(Into::into);
221 match protocol {
222 Some(protocol) => {
223 if protocol.is_http() {
224 Ok(self)
225 } else {
226 Err(OpaqueError::from_display(format!(
227 "Unsupported protocol: {protocol}"
228 )))
229 }
230 }
231 None => Err(OpaqueError::from_display("Missing scheme in URI")),
232 }
233 }
234 }
235
236 impl IntoUrlSealed for &str {
237 fn into_url(self) -> Result<Uri, OpaqueError> {
238 match self.parse::<Uri>() {
239 Ok(uri) => uri.into_url(),
240 Err(_) => Err(OpaqueError::from_display(format!("Invalid URL: {}", self))),
241 }
242 }
243 }
244
245 impl IntoUrlSealed for String {
246 fn into_url(self) -> Result<Uri, OpaqueError> {
247 self.as_str().into_url()
248 }
249 }
250
251 impl IntoUrlSealed for &String {
252 fn into_url(self) -> Result<Uri, OpaqueError> {
253 self.as_str().into_url()
254 }
255 }
256
257 pub trait IntoHeaderNameSealed {
258 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError>;
259 }
260
261 impl IntoHeaderNameSealed for HeaderName {
262 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
263 Ok(self)
264 }
265 }
266
267 impl IntoHeaderNameSealed for Option<HeaderName> {
268 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
269 match self {
270 Some(name) => Ok(name),
271 None => Err(OpaqueError::from_display("Header name is required")),
272 }
273 }
274 }
275
276 impl IntoHeaderNameSealed for &str {
277 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
278 let name = self
279 .parse::<crate::HeaderName>()
280 .map_err(OpaqueError::from_std)?;
281 Ok(name)
282 }
283 }
284
285 impl IntoHeaderNameSealed for String {
286 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
287 self.as_str().into_header_name()
288 }
289 }
290
291 impl IntoHeaderNameSealed for &String {
292 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
293 self.as_str().into_header_name()
294 }
295 }
296
297 impl IntoHeaderNameSealed for &[u8] {
298 fn into_header_name(self) -> Result<crate::HeaderName, OpaqueError> {
299 let name = crate::HeaderName::from_bytes(self).map_err(OpaqueError::from_std)?;
300 Ok(name)
301 }
302 }
303
304 pub trait IntoHeaderValueSealed {
305 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError>;
306 }
307
308 impl IntoHeaderValueSealed for crate::HeaderValue {
309 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError> {
310 Ok(self)
311 }
312 }
313
314 impl IntoHeaderValueSealed for &str {
315 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError> {
316 let value = self
317 .parse::<crate::HeaderValue>()
318 .map_err(OpaqueError::from_std)?;
319 Ok(value)
320 }
321 }
322
323 impl IntoHeaderValueSealed for String {
324 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError> {
325 self.as_str().into_header_value()
326 }
327 }
328
329 impl IntoHeaderValueSealed for &String {
330 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError> {
331 self.as_str().into_header_value()
332 }
333 }
334
335 impl IntoHeaderValueSealed for &[u8] {
336 fn into_header_value(self) -> Result<crate::HeaderValue, OpaqueError> {
337 let value = crate::HeaderValue::from_bytes(self).map_err(OpaqueError::from_std)?;
338 Ok(value)
339 }
340 }
341
342 pub trait HttpClientExtSealed<State> {}
343
344 impl<State, S, Body> HttpClientExtSealed<State> for S where
345 S: Service<State, Request, Response = Response<Body>, Error: Into<BoxError>>
346 {
347 }
348}
349
350pub struct RequestBuilder<'a, S, State, Response> {
354 http_client_service: &'a S,
355 state: RequestBuilderState,
356 _phantom: std::marker::PhantomData<fn(State, Response) -> ()>,
357}
358
359impl<S, State, Response> std::fmt::Debug for RequestBuilder<'_, S, State, Response>
360where
361 S: std::fmt::Debug,
362{
363 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
364 f.debug_struct("RequestBuilder")
365 .field("http_client_service", &self.http_client_service)
366 .field("state", &self.state)
367 .finish()
368 }
369}
370
371#[derive(Debug)]
372enum RequestBuilderState {
373 PreBody(crate::dep::http::request::Builder),
374 PostBody(crate::Request),
375 Error(OpaqueError),
376}
377
378impl<S, State, Body> RequestBuilder<'_, S, State, Response<Body>>
379where
380 S: Service<State, Request, Response = Response<Body>, Error: Into<BoxError>>,
381{
382 pub fn header<K, V>(mut self, key: K, value: V) -> Self
384 where
385 K: IntoHeaderName,
386 V: IntoHeaderValue,
387 {
388 match self.state {
389 RequestBuilderState::PreBody(builder) => {
390 let key = match key.into_header_name() {
391 Ok(key) => key,
392 Err(err) => {
393 self.state = RequestBuilderState::Error(err);
394 return self;
395 }
396 };
397 let value = match value.into_header_value() {
398 Ok(value) => value,
399 Err(err) => {
400 self.state = RequestBuilderState::Error(err);
401 return self;
402 }
403 };
404 self.state = RequestBuilderState::PreBody(builder.header(key, value));
405 self
406 }
407 RequestBuilderState::PostBody(mut request) => {
408 let key = match key.into_header_name() {
409 Ok(key) => key,
410 Err(err) => {
411 self.state = RequestBuilderState::Error(err);
412 return self;
413 }
414 };
415 let value = match value.into_header_value() {
416 Ok(value) => value,
417 Err(err) => {
418 self.state = RequestBuilderState::Error(err);
419 return self;
420 }
421 };
422 request.headers_mut().append(key, value);
423 self.state = RequestBuilderState::PostBody(request);
424 self
425 }
426 RequestBuilderState::Error(err) => {
427 self.state = RequestBuilderState::Error(err);
428 self
429 }
430 }
431 }
432
433 pub fn typed_header<H>(self, header: H) -> Self
437 where
438 H: crate::headers::Header,
439 {
440 self.header(H::name().clone(), header.encode_to_value())
441 }
442
443 pub fn headers(mut self, headers: crate::HeaderMap) -> Self {
447 for (key, value) in headers.into_iter() {
448 self = self.header(key, value);
449 }
450 self
451 }
452
453 pub fn basic_auth<U, P>(self, username: U, password: P) -> Self
455 where
456 U: Into<Cow<'static, str>>,
457 P: Into<Cow<'static, str>>,
458 {
459 let header = crate::headers::Authorization::basic(username, password);
460 self.typed_header(header)
461 }
462
463 pub fn bearer_auth<T>(mut self, token: T) -> Self
465 where
466 T: Into<Cow<'static, str>>,
467 {
468 let header = match crate::headers::Authorization::bearer(token) {
469 Ok(header) => header,
470 Err(err) => {
471 self.state = match self.state {
472 RequestBuilderState::Error(original_err) => {
473 RequestBuilderState::Error(original_err)
474 }
475 _ => RequestBuilderState::Error(OpaqueError::from_std(err)),
476 };
477 return self;
478 }
479 };
480
481 self.typed_header(header)
482 }
483
484 pub fn body<T>(mut self, body: T) -> Self
488 where
489 T: TryInto<crate::Body, Error: Into<BoxError>>,
490 {
491 self.state = match self.state {
492 RequestBuilderState::PreBody(builder) => match body.try_into() {
493 Ok(body) => match builder.body(body) {
494 Ok(req) => RequestBuilderState::PostBody(req),
495 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
496 },
497 Err(err) => RequestBuilderState::Error(OpaqueError::from_boxed(err.into())),
498 },
499 RequestBuilderState::PostBody(mut req) => match body.try_into() {
500 Ok(body) => {
501 *req.body_mut() = body;
502 RequestBuilderState::PostBody(req)
503 }
504 Err(err) => RequestBuilderState::Error(OpaqueError::from_boxed(err.into())),
505 },
506 RequestBuilderState::Error(err) => RequestBuilderState::Error(err),
507 };
508 self
509 }
510
511 pub fn form<T: serde::Serialize + ?Sized>(mut self, form: &T) -> Self {
515 self.state = match self.state {
516 RequestBuilderState::PreBody(mut builder) => match serde_html_form::to_string(form) {
517 Ok(body) => {
518 let builder = match builder.headers_mut() {
519 Some(headers) => {
520 if !headers.contains_key(crate::header::CONTENT_TYPE) {
521 headers.insert(
522 crate::header::CONTENT_TYPE,
523 crate::HeaderValue::from_static(
524 "application/x-www-form-urlencoded",
525 ),
526 );
527 }
528 builder
529 }
530 None => builder.header(
531 crate::header::CONTENT_TYPE,
532 crate::HeaderValue::from_static("application/x-www-form-urlencoded"),
533 ),
534 };
535 match builder.body(body.into()) {
536 Ok(req) => RequestBuilderState::PostBody(req),
537 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
538 }
539 }
540 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
541 },
542 RequestBuilderState::PostBody(mut req) => match serde_html_form::to_string(form) {
543 Ok(body) => {
544 if !req.headers().contains_key(crate::header::CONTENT_TYPE) {
545 req.headers_mut().insert(
546 crate::header::CONTENT_TYPE,
547 crate::HeaderValue::from_static("application/x-www-form-urlencoded"),
548 );
549 }
550 *req.body_mut() = body.into();
551 RequestBuilderState::PostBody(req)
552 }
553 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
554 },
555 RequestBuilderState::Error(err) => RequestBuilderState::Error(err),
556 };
557 self
558 }
559
560 pub fn json<T: serde::Serialize + ?Sized>(mut self, json: &T) -> Self {
564 self.state = match self.state {
565 RequestBuilderState::PreBody(mut builder) => match serde_json::to_vec(json) {
566 Ok(body) => {
567 let builder = match builder.headers_mut() {
568 Some(headers) => {
569 if !headers.contains_key(crate::header::CONTENT_TYPE) {
570 headers.insert(
571 crate::header::CONTENT_TYPE,
572 crate::HeaderValue::from_static("application/json"),
573 );
574 }
575 builder
576 }
577 None => builder.header(
578 crate::header::CONTENT_TYPE,
579 crate::HeaderValue::from_static("application/json"),
580 ),
581 };
582 match builder.body(body.into()) {
583 Ok(req) => RequestBuilderState::PostBody(req),
584 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
585 }
586 }
587 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
588 },
589 RequestBuilderState::PostBody(mut req) => match serde_json::to_vec(json) {
590 Ok(body) => {
591 if !req.headers().contains_key(crate::header::CONTENT_TYPE) {
592 req.headers_mut().insert(
593 crate::header::CONTENT_TYPE,
594 crate::HeaderValue::from_static("application/json"),
595 );
596 }
597 *req.body_mut() = body.into();
598 RequestBuilderState::PostBody(req)
599 }
600 Err(err) => RequestBuilderState::Error(OpaqueError::from_std(err)),
601 },
602 RequestBuilderState::Error(err) => RequestBuilderState::Error(err),
603 };
604 self
605 }
606
607 pub fn version(mut self, version: crate::Version) -> Self {
611 match self.state {
612 RequestBuilderState::PreBody(builder) => {
613 self.state = RequestBuilderState::PreBody(builder.version(version));
614 self
615 }
616 RequestBuilderState::PostBody(mut request) => {
617 *request.version_mut() = version;
618 self.state = RequestBuilderState::PostBody(request);
619 self
620 }
621 RequestBuilderState::Error(err) => {
622 self.state = RequestBuilderState::Error(err);
623 self
624 }
625 }
626 }
627
628 pub async fn send(self, ctx: Context<State>) -> Result<Response<Body>, OpaqueError> {
634 let request = match self.state {
635 RequestBuilderState::PreBody(builder) => builder
636 .body(crate::Body::empty())
637 .map_err(OpaqueError::from_std)?,
638 RequestBuilderState::PostBody(request) => request,
639 RequestBuilderState::Error(err) => return Err(err),
640 };
641
642 let uri = request.uri().clone();
643 match self.http_client_service.serve(ctx, request).await {
644 Ok(response) => Ok(response),
645 Err(err) => Err(OpaqueError::from_boxed(err.into()).context(uri.to_string())),
646 }
647 }
648}
649
#[cfg(test)]
mod test {
    use rama_http_types::StatusCode;

    use super::*;
    use crate::{
        layer::{
            required_header::AddRequiredRequestHeadersLayer,
            retry::{ManagedPolicy, RetryLayer},
            trace::TraceLayer,
        },
        service::web::response::IntoResponse,
    };
    use rama_core::{
        layer::{Layer, MapResultLayer},
        service::{BoxService, service_fn},
    };
    use rama_utils::backoff::ExponentialBackoff;
    use std::convert::Infallible;

    /// Innermost fake service: checks the `User-Agent` header of the incoming
    /// request and answers with `200 OK`.
    async fn fake_client_fn<S, Body>(
        _ctx: Context<S>,
        request: Request<Body>,
    ) -> Result<Response, Infallible>
    where
        S: Clone + Send + Sync + 'static,
        Body: crate::dep::http_body::Body<Data: Send + 'static, Error: Send + 'static>
            + Send
            + 'static,
    {
        let user_agent = request.headers().get(crate::header::USER_AGENT).unwrap();
        let expected = format!("{}/{}", rama_utils::info::NAME, rama_utils::info::VERSION);
        assert_eq!(user_agent.to_str().unwrap(), expected);

        Ok(StatusCode::OK.into_response())
    }

    /// Normalizes the result of the layered client: the response body is
    /// wrapped in `crate::Body` and the error is boxed.
    fn map_internal_client_error<E, Body>(
        result: Result<Response<Body>, E>,
    ) -> Result<Response, rama_core::error::BoxError>
    where
        E: Into<rama_core::error::BoxError>,
        Body: crate::dep::http_body::Body<Data = bytes::Bytes, Error: Into<BoxError>>
            + Send
            + Sync
            + 'static,
    {
        result
            .map(|response| response.map(crate::Body::new))
            .map_err(Into::into)
    }

    /// Boxed client used by the tests below.
    type HttpClient<S> = BoxService<S, Request, Response, BoxError>;

    /// Builds the layered test client on top of [`fake_client_fn`].
    fn client<S: Clone + Send + Sync + 'static>() -> HttpClient<S> {
        let outer_layers = (
            MapResultLayer::new(map_internal_client_error),
            TraceLayer::new_for_http(),
        );

        #[cfg(feature = "compression")]
        let outer_layers = (
            outer_layers,
            crate::layer::decompression::DecompressionLayer::new(),
        );

        let retry = RetryLayer::new(
            ManagedPolicy::default().with_backoff(ExponentialBackoff::default()),
        );

        (
            outer_layers,
            retry,
            AddRequiredRequestHeadersLayer::default(),
        )
            .into_layer(service_fn(fake_client_fn))
            .boxed()
    }

    #[tokio::test]
    async fn test_client_happy_path() {
        let http_client = client();
        let response = http_client
            .get("http://127.0.0.1:8080")
            .send(Context::default())
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}
738}