1use std::any::Any;
2
3use bytes::Bytes;
4use http::{HeaderName, HeaderValue, Method, Request, Uri, Version};
5use http::header::{CONNECTION, HOST, TE, TRANSFER_ENCODING, UPGRADE};
6use http::request::Builder;
7#[cfg(not(feature = "hyper-tls"))]
8use monoio_http::common::body::HttpBody;
9
10#[cfg(any(feature = "hyper", feature = "pool-hyper", feature = "hyper-tls"))]
11use crate::hyper::client::MonoioHyperClient;
12#[cfg(any(feature = "hyper", feature = "pool-hyper", feature = "hyper-tls"))]
13use crate::hyper::hyper_body::HyperBody;
14#[cfg(not(feature = "hyper-tls"))]
15use super::http::{client::MonoioClient, monoio_body::MonoioBody};
16use super::{response::HttpResponse, error::Error};
17
18const PROHIBITED_HEADERS: [HeaderName; 5] = [
19 CONNECTION,
20 HeaderName::from_static("keep-alive"),
21 TE,
22 TRANSFER_ENCODING,
23 UPGRADE,
24];
25
26pub trait RequestBody {
27 type Body;
28
29 fn create_body(bytes: Option<Bytes>) -> Self::Body;
30}
31
32pub struct HttpRequest<C> {
33 client: C,
34 builder: Builder,
35}
36
37impl<C> HttpRequest<C> {
38 pub(crate) fn new(client: C) -> HttpRequest<C> {
39 HttpRequest {
40 client,
41 builder: Builder::default(),
42 }
43 }
44
45 pub fn set_uri<T>(mut self, uri: T) -> Self
53 where
54 Uri: TryFrom<T>,
55 <Uri as TryFrom<T>>::Error: Into<http::Error>,
56 {
57 self.builder = self.builder.uri(uri);
58 self
59 }
60
61 pub fn set_method<T>(mut self, method: T) -> Self
69 where
70 Method: TryFrom<T>,
71 <Method as TryFrom<T>>::Error: Into<http::Error>,
72 {
73 self.builder = self.builder.method(method);
74 self
75 }
76
77 pub fn set_header<K, T>(mut self, key: K, value: T) -> Self
86 where
87 HeaderName: TryFrom<K>,
88 <HeaderName as TryFrom<K>>::Error: Into<http::Error>,
89 HeaderValue: TryFrom<T>,
90 <HeaderValue as TryFrom<T>>::Error: Into<http::Error>,
91 {
92 self.builder = self.builder.header(key, value);
93 self
94 }
95
96 pub fn set_version(mut self, version: Version) -> Self {
104 self.builder = self.builder.version(version);
105 self
106 }
107
108 pub fn set_extension<T>(mut self, extension: T) -> Self
112 where
113 T: Clone + Any + Send + Sync + 'static,
114 {
115 self.builder = self.builder.extension(extension);
116 self
117 }
118
119 fn build_request<B: RequestBody>(
120 builder: Builder,
121 body: Option<Bytes>,
122 ) -> Result<(Request<B::Body>, Uri), Error> {
123 let mut request = builder
124 .body(B::create_body(body))
125 .map_err(Error::HttpRequestBuilder)?;
126
127 let uri = request.uri().clone();
128
129 match request.version() {
133 Version::HTTP_2 | Version::HTTP_3 => {
134 let headers = request.headers_mut();
135 for header in PROHIBITED_HEADERS.iter() {
136 headers.remove(header);
137 }
138 }
139 _ => {
140 if let Some(host) = uri.host() {
141 let host = HeaderValue::try_from(host).map_err(Error::InvalidHeaderValue)?;
142 let headers = request.headers_mut();
143 if !headers.contains_key(HOST) {
144 headers.insert(HOST, host);
145 }
146 }
147 }
148 }
149
150 Ok((request, uri))
151 }
152}
153
154#[cfg(not(feature = "hyper-tls"))]
155impl HttpRequest<MonoioClient> {
156 pub async fn send(self) -> Result<HttpResponse<HttpBody>, Error> {
159 self.send_body(None).await
160 }
161
162 pub async fn send_body(self, body: impl Into<Option<Bytes>>) -> Result<HttpResponse<HttpBody>, Error> {
171 let (req, uri) = Self::build_request::<MonoioBody>(self.builder, body.into())?;
172 let response = self.client.send_request(req, uri).await?;
173 Ok(HttpResponse::new(response))
174 }
175}
176
#[cfg(any(feature = "hyper", feature = "pool-hyper", feature = "hyper-tls"))]
impl HttpRequest<MonoioHyperClient> {
    /// Sends the request with no body through the hyper-backed client.
    pub async fn send(self) -> Result<HttpResponse<Bytes>, Error> {
        self.send_body(None).await
    }

    /// Sends the request with an optional body through the hyper-backed client and
    /// converts the reply with `HttpResponse::hyper_new`.
    ///
    /// # Errors
    ///
    /// Returns an error if the request cannot be built, the client fails to send it, or
    /// the response conversion fails.
    pub async fn send_body(
        self,
        body: impl Into<Option<Bytes>>,
    ) -> Result<HttpResponse<Bytes>, Error> {
        let HttpRequest { client, builder } = self;
        let payload = body.into();
        let (request, target) = Self::build_request::<HyperBody>(builder, payload)?;
        let raw_response = client.send_request(request, target).await?;
        HttpResponse::hyper_new(raw_response).await
    }
}
196}