1use std::fmt;
2use std::collections::HashMap;
3use std::collections::hash_map::RandomState;
4use std::io::{Error, ErrorKind};
5use std::str::FromStr;
6use url::{Url, Position};
7use async_std::io::{Read, Write};
8use async_uninet::{SocketAddr, Stream};
9use async_httplib::{read_first_line, parse_version, parse_status, read_header_line,
10 write_slice, write_all, write_exact, write_chunks, flush_write};
11use crate::{Method, Version, Response, read_content_length};
12
/// An outgoing HTTP request: target URL, method, protocol version and headers,
/// plus connection options (relay address, body limit) used when sending it.
#[derive(Debug)]
pub struct Request {
    /// Target URL. Its scheme selects plain HTTP or HTTPS when sending; its host,
    /// port and path supply the connection address and the request target.
    url: Url,
    /// Method written on the request line (ignored for HTTP/0.9, which always
    /// writes `GET`).
    method: Method,
    /// Protocol version; decides the request-line format and whether headers
    /// are written at all.
    version: Version,
    /// Request headers. They are written in `HashMap` iteration order, so the
    /// on-wire order is unspecified. Lookups by name are case-sensitive.
    headers: HashMap<String, String>,
    /// Optional address string to connect to instead of the URL's `host:port`.
    relay: Option<String>,
    /// Optional limit forwarded to the body writers and `read_content_length`;
    /// presumably a maximum body size in bytes — TODO confirm against async_httplib.
    body_limit: Option<usize>,
}
22
23impl Request {
24
25 pub fn default() -> Self {
26 Self {
27 url: Url::parse("http://localhost").unwrap(),
28 method: Method::Get,
29 version: Version::Http1_1,
30 headers: HashMap::with_hasher(RandomState::new()),
31 relay: None,
32 body_limit: None,
33 }
34 }
35
36 pub fn parse_url<U>(url: U) -> Result<Self, Error>
37 where
38 U: Into<String>,
39 {
40 let mut req = Request::default();
41 req.set_url_str(url.into())?;
42 Ok(req)
43 }
44
45 pub fn url(&self) -> &Url {
46 &self.url
47 }
48
49 fn scheme(&self) -> &str {
50 self.url.scheme()
51 }
52
53 fn host(&self) -> &str {
54 match self.url.host_str() {
55 Some(host) => host,
56 None => "localhost",
57 }
58 }
59
60 fn port(&self) -> u16 {
61 match self.url.port_or_known_default() {
62 Some(port) => port,
63 None => 80,
64 }
65 }
66
67 fn host_with_port(&self) -> String {
68 format!("{}:{}", self.host(), self.port())
69 }
70
71 fn socket_address(&self) -> String {
72 match &self.relay {
73 Some(relay) => relay.to_string(),
74 None => self.host_with_port(),
75 }
76 }
77
78 fn uri(&self) -> &str {
79 &self.url[Position::BeforePath..]
80 }
81
82 pub fn method(&self) -> &Method {
83 &self.method
84 }
85
86 pub fn version(&self) -> &Version {
87 &self.version
88 }
89
90 pub fn headers(&self) -> &HashMap<String, String> {
91 &self.headers
92 }
93
94 pub fn header<N: Into<String>>(&self, name: N) -> Option<&String> {
95 self.headers.get(&name.into())
96 }
97
98 pub fn relay(&self) -> &Option<String> {
99 &self.relay
100 }
101
102 pub fn body_limit(&self) -> &Option<usize> {
103 &self.body_limit
104 }
105
106 pub fn headers_mut(&mut self) -> &mut HashMap<String, String> {
107 &mut self.headers
108 }
109
110 pub fn has_method(&self, value: Method) -> bool {
111 self.method == value
112 }
113
114 pub fn has_version(&self, value: Version) -> bool {
115 self.version == value
116 }
117
118 pub fn has_header<N: Into<String>>(&self, name: N) -> bool {
119 self.headers.contains_key(&name.into())
120 }
121
122 pub fn has_body_limit(&self) -> bool {
123 self.body_limit.is_some()
124 }
125
126 pub fn set_url(&mut self, value: Url) {
127 self.url = value;
128 }
129
130 pub fn set_url_str<V: Into<String>>(&mut self, value: V) -> Result<(), Error> {
131 self.url = match Url::parse(&value.into()) {
132 Ok(url) => url,
133 Err(e) => return Err(Error::new(ErrorKind::InvalidInput, e.to_string())),
134 };
135 Ok(())
136 }
137
138 pub fn set_method(&mut self, value: Method) {
139 self.method = value;
140 }
141
142 pub fn set_method_str(&mut self, value: &str) -> Result<(), Error> {
143 self.method = Method::from_str(value)?;
144 Ok(())
145 }
146
147 pub fn set_version(&mut self, value: Version) {
148 self.version = value;
149 }
150
151 pub fn set_version_str(&mut self, value: &str) -> Result<(), Error> {
152 self.version = Version::from_str(value)?;
153 Ok(())
154 }
155
156 pub fn set_header<N: Into<String>, V: Into<String>>(&mut self, name: N, value: V) {
157 self.headers.insert(name.into(), value.into());
158 }
159
160 pub fn set_relay<V: Into<String>>(&mut self, value: V) {
161 self.relay = Some(value.into());
162 }
163
164 pub fn set_body_limit(&mut self, length: usize) {
165 self.body_limit = Some(length);
166 }
167
168 pub fn remove_header<N: Into<String>>(&mut self, name: N) {
169 self.headers.remove(&name.into());
170 }
171
172 pub fn remove_relay(&mut self) {
173 self.relay = None;
174 }
175
176 pub fn clear_headers(&mut self) {
177 self.headers.clear();
178 }
179
180 pub fn to_proto_string(&self) -> String {
181 let mut output = String::new();
182
183 match self.version {
184 Version::Http0_9 => {
185 output.push_str(&format!("GET {}\r\n", self.uri()));
186 },
187 _ => {
188 output.push_str(&format!("{} {} {}\r\n", self.method(), self.uri(), self.version()));
189 for (name, value) in self.headers.iter() {
190 output.push_str(&format!("{}: {}\r\n", name, value));
191 }
192 output.push_str("\r\n");
193 },
194 };
195
196 output
197 }
198
199 pub async fn send<'a>(&mut self) -> Result<Response<'a>, Error> {
200 self.update_host_header();
201
202 match self.scheme() {
203 "http" => self.send_http(&mut "".as_bytes()).await,
204 "https" => self.send_https(&mut "".as_bytes()).await,
205 s => Err(Error::new(ErrorKind::InvalidInput, format!("The URL scheme `{}` is invalid.", s))),
206 }
207 }
208
209 pub async fn send_stream<'a, R>(&mut self, body: &mut R) -> Result<Response<'a>, Error>
210 where
211 R: Read + Send + Unpin,
212 {
213 self.update_host_header();
214 self.update_body_headers();
215
216 match self.scheme() {
217 "http" => self.send_http(body).await,
218 "https" => self.send_https(body).await,
219 s => Err(Error::new(ErrorKind::InvalidInput, format!("The URL scheme `{}` is invalid.", s))),
220 }
221 }
222
223 pub async fn send_slice<'a>(&mut self, body: &[u8]) -> Result<Response<'a>, Error> {
224 self.set_header("Content-Length", body.len().to_string());
225 self.send_stream(&mut body.clone()).await
226 }
227
228 pub async fn send_str<'a>(&mut self, body: &str) -> Result<Response<'a>, Error> {
229 self.set_header("Content-Length", body.len().to_string());
230 self.send_stream(&mut body.as_bytes()).await
231 }
232
233 #[cfg(feature = "json")]
234 pub async fn send_json<'a>(&mut self, body: &serde_json::Value) -> Result<Response<'a>, Error> {
235 let body = body.to_string();
236 self.set_header("Content-Length", body.len().to_string());
237 self.send_stream(&mut body.as_bytes()).await
238 }
239
240 pub async fn send_http<'a, R>(&mut self, body: &mut R) -> Result<Response<'a>, Error>
241 where
242 R: Read + Send + Unpin,
243 {
244 let mut stream = self.build_conn().await?;
245 self.write_request(&mut stream, body).await?;
246 self.build_response(stream).await
247 }
248
249 pub async fn send_https<'a, R>(&mut self, body: &mut R) -> Result<Response<'a>, Error>
250 where
251 R: Read + Send + Unpin,
252 {
253 let stream = self.build_conn().await?;
254
255 let mut stream = match async_native_tls::connect(self.host(), stream).await {
256 Ok(stream) => stream,
257 Err(e) => return Err(Error::new(ErrorKind::Interrupted, e.to_string())),
258 };
259
260 self.write_request(&mut stream, body).await?;
261 self.build_response(stream).await
262 }
263
264 fn update_host_header(&mut self) {
265 if self.version >= Version::Http1_1 && !self.has_header("Host") {
266 self.set_header("Host", self.host_with_port());
267 }
268 }
269
270 fn update_body_headers(&mut self) {
271 if self.version >= Version::Http0_9 && self.method.has_body() && !self.has_header("Content-Length") {
272 self.set_header("Transfer-Encoding", "chunked");
273 }
274 }
275
276 async fn write_request<S, R>(&self, stream: &mut S, body: &mut R) -> Result<(), Error>
277 where
278 S: Write + Unpin,
279 R: Read + Send + Unpin,
280 {
281 self.write_proto(stream).await?;
282 self.write_body(stream, body).await
283 }
284
285 async fn write_proto<S>(&self, stream: &mut S) -> Result<(), Error>
286 where
287 S: Write + Unpin,
288 {
289 write_slice(stream, self.to_string().as_bytes()).await?;
290 flush_write(stream).await
291 }
292
293 async fn write_body<S, R>(&self, stream: &mut S, body: &mut R) -> Result<(), Error>
294 where
295 S: Write + Unpin,
296 R: Read + Send + Unpin,
297 {
298 if self.has_version(Version::Http0_9) {
299 write_all(stream, body, self.body_limit).await?;
300 } else if self.has_header("Content-Length") { write_exact(stream, body, read_content_length(&self.headers, self.body_limit)?).await?;
302 } else { write_chunks(stream, body, (Some(1024), self.body_limit)).await?;
304 }
305 flush_write(stream).await
306 }
307
308 async fn build_conn(&mut self) -> Result<Stream, Error> {
309 let addr = self.socket_address();
310
311 match SocketAddr::from_str(&addr).await {
312 Ok(addr) => Stream::connect(&addr).await,
313 Err(_) => Err(Error::new(ErrorKind::AddrNotAvailable, format!("The address `{}` is invalid.", addr))),
314 }
315 }
316
317 async fn build_response<'a, S>(&mut self, mut stream: S) -> Result<Response<'a>, Error>
318 where
319 S: Read + Send + Unpin + 'a,
320 {
321 let mut res: Response<'a> = Response::default();
322
323 let (mut version, mut status, mut message) = (vec![], vec![], vec![]);
324 read_first_line(&mut stream, (&mut version, &mut status, &mut message), None).await?;
325 res.set_version(parse_version(version)?);
326 res.set_status(parse_status(status)?);
327
328 loop {
329 let (mut name, mut value) = (vec![], vec![]);
330 read_header_line(&mut stream, (&mut name, &mut value), None).await?;
331
332 if name.is_empty() {
333 break;
334 }
335
336 res.set_header(
337 match String::from_utf8(name) {
338 Ok(name) => name,
339 Err(_) => return Err(Error::new(ErrorKind::InvalidData, format!("The response header `#{}` is invalid.", res.headers().len()))),
340 },
341 match String::from_utf8(value) {
342 Ok(value) => value,
343 Err(_) => return Err(Error::new(ErrorKind::InvalidData, format!("The response header `#{}` is invalid.", res.headers().len()))),
344 },
345 );
346 }
347
348 res.set_reader(stream);
349 Ok(res)
350 }
351}
352
353impl fmt::Display for Request {
354 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
355 write!(fmt, "{}", self.to_proto_string())
356 }
357}
358
359impl From<Request> for String {
360 fn from(item: Request) -> String {
361 item.to_string()
362 }
363}