// async_fetch/request.rs

1use std::fmt;
2use std::collections::HashMap;
3use std::collections::hash_map::RandomState;
4use std::io::{Error, ErrorKind};
5use std::str::FromStr;
6use url::{Url, Position};
7use async_std::io::{Read, Write};
8use async_uninet::{SocketAddr, Stream};
9use async_httplib::{read_first_line, parse_version, parse_status, read_header_line,
10    write_slice, write_all, write_exact, write_chunks, flush_write};
11use crate::{Method, Version, Response, read_content_length};
12
13#[derive(Debug)]
14pub struct Request {
15    url: Url,
16    method: Method,
17    version: Version,
18    headers: HashMap<String, String>,
19    relay: Option<String>,
20    body_limit: Option<usize>,
21}
22
23impl Request {
24
25    pub fn default() -> Self {
26        Self {
27            url: Url::parse("http://localhost").unwrap(),
28            method: Method::Get,
29            version: Version::Http1_1,
30            headers: HashMap::with_hasher(RandomState::new()),
31            relay: None,
32            body_limit: None,
33        }
34    }
35
36    pub fn parse_url<U>(url: U) -> Result<Self, Error>
37        where
38        U: Into<String>,
39    {
40        let mut req = Request::default();
41        req.set_url_str(url.into())?;
42        Ok(req)
43    }
44
45    pub fn url(&self) -> &Url {
46        &self.url
47    }
48
49    fn scheme(&self) -> &str {
50        self.url.scheme()
51    }
52
53    fn host(&self) -> &str {
54        match self.url.host_str() {
55            Some(host) => host,
56            None => "localhost",
57        }
58    }
59
60    fn port(&self) -> u16 {
61        match self.url.port_or_known_default() {
62            Some(port) => port,
63            None => 80,
64        }
65    }
66
67    fn host_with_port(&self) -> String {
68        format!("{}:{}", self.host(), self.port())
69    }
70
71    fn socket_address(&self) -> String {
72        match &self.relay {
73            Some(relay) => relay.to_string(),
74            None => self.host_with_port(),
75        }
76    }
77
78    fn uri(&self) -> &str {
79        &self.url[Position::BeforePath..]
80    }
81
82    pub fn method(&self) -> &Method {
83        &self.method
84    }
85
86    pub fn version(&self) -> &Version {
87        &self.version
88    }
89
90    pub fn headers(&self) -> &HashMap<String, String> {
91        &self.headers
92    }
93
94    pub fn header<N: Into<String>>(&self, name: N) -> Option<&String> {
95        self.headers.get(&name.into())
96    }
97
98    pub fn relay(&self) -> &Option<String> {
99        &self.relay
100    }
101
102    pub fn body_limit(&self) -> &Option<usize> {
103        &self.body_limit
104    }
105
106    pub fn headers_mut(&mut self) -> &mut HashMap<String, String> {
107        &mut self.headers
108    }
109
110    pub fn has_method(&self, value: Method) -> bool {
111        self.method == value
112    }
113
114    pub fn has_version(&self, value: Version) -> bool {
115        self.version == value
116    }
117
118    pub fn has_header<N: Into<String>>(&self, name: N) -> bool {
119        self.headers.contains_key(&name.into())
120    }
121
122    pub fn has_body_limit(&self) -> bool {
123        self.body_limit.is_some()
124    }
125
126    pub fn set_url(&mut self, value: Url) {
127        self.url = value;
128    }
129
130    pub fn set_url_str<V: Into<String>>(&mut self, value: V) -> Result<(), Error> {
131        self.url = match Url::parse(&value.into()) {
132            Ok(url) => url,
133            Err(e) => return Err(Error::new(ErrorKind::InvalidInput, e.to_string())),
134        };
135        Ok(())
136    }
137
138    pub fn set_method(&mut self, value: Method) {
139        self.method = value;
140    }
141
142    pub fn set_method_str(&mut self, value: &str) -> Result<(), Error> {
143        self.method = Method::from_str(value)?;
144        Ok(())
145    }
146
147    pub fn set_version(&mut self, value: Version) {
148        self.version = value;
149    }
150
151    pub fn set_version_str(&mut self, value: &str) -> Result<(), Error> {
152        self.version = Version::from_str(value)?;
153        Ok(())
154    }
155
156    pub fn set_header<N: Into<String>, V: Into<String>>(&mut self, name: N, value: V) {
157        self.headers.insert(name.into(), value.into());
158    }
159
160    pub fn set_relay<V: Into<String>>(&mut self, value: V) {
161        self.relay = Some(value.into());
162    }
163
164    pub fn set_body_limit(&mut self, length: usize) {
165        self.body_limit = Some(length);
166    }
167
168    pub fn remove_header<N: Into<String>>(&mut self, name: N) {
169        self.headers.remove(&name.into());
170    }
171
172    pub fn remove_relay(&mut self) {
173        self.relay = None;
174    }
175
176    pub fn clear_headers(&mut self) {
177        self.headers.clear();
178    }
179
180    pub fn to_proto_string(&self) -> String {
181        let mut output = String::new();
182
183        match self.version {
184            Version::Http0_9 => {
185                output.push_str(&format!("GET {}\r\n", self.uri()));
186            },
187            _ => {
188                output.push_str(&format!("{} {} {}\r\n", self.method(), self.uri(), self.version()));
189                for (name, value) in self.headers.iter() {
190                    output.push_str(&format!("{}: {}\r\n", name, value));
191                }
192                output.push_str("\r\n");
193            },
194        };
195
196        output
197    }
198
199    pub async fn send<'a>(&mut self) -> Result<Response<'a>, Error> {
200        self.update_host_header();
201
202        match self.scheme() {
203            "http" => self.send_http(&mut "".as_bytes()).await,
204            "https" => self.send_https(&mut "".as_bytes()).await,
205            s => Err(Error::new(ErrorKind::InvalidInput, format!("The URL scheme `{}` is invalid.", s))),
206        }
207    }
208
209    pub async fn send_stream<'a, R>(&mut self, body: &mut R) -> Result<Response<'a>, Error>
210        where
211        R: Read + Send + Unpin,
212    {
213        self.update_host_header();
214        self.update_body_headers();
215        
216        match self.scheme() {
217            "http" => self.send_http(body).await,
218            "https" => self.send_https(body).await,
219            s => Err(Error::new(ErrorKind::InvalidInput, format!("The URL scheme `{}` is invalid.", s))),
220        }
221    }
222
223    pub async fn send_slice<'a>(&mut self, body: &[u8]) -> Result<Response<'a>, Error> {
224        self.set_header("Content-Length", body.len().to_string());
225        self.send_stream(&mut body.clone()).await
226    }
227
228    pub async fn send_str<'a>(&mut self, body: &str) -> Result<Response<'a>, Error> {
229        self.set_header("Content-Length", body.len().to_string());
230        self.send_stream(&mut body.as_bytes()).await
231    }
232
233    #[cfg(feature = "json")]
234    pub async fn send_json<'a>(&mut self, body: &serde_json::Value) -> Result<Response<'a>, Error> {
235        let body = body.to_string();
236        self.set_header("Content-Length", body.len().to_string());
237        self.send_stream(&mut body.as_bytes()).await
238    }
239
240    pub async fn send_http<'a, R>(&mut self, body: &mut R) -> Result<Response<'a>, Error>
241        where
242        R: Read + Send + Unpin,
243    {
244        let mut stream = self.build_conn().await?;
245        self.write_request(&mut stream, body).await?;
246        self.build_response(stream).await
247    }
248
249    pub async fn send_https<'a, R>(&mut self, body: &mut R) -> Result<Response<'a>, Error>
250        where
251        R: Read + Send + Unpin,
252    {
253        let stream = self.build_conn().await?;
254
255        let mut stream = match async_native_tls::connect(self.host(), stream).await {
256            Ok(stream) => stream,
257            Err(e) => return Err(Error::new(ErrorKind::Interrupted, e.to_string())),
258        };
259
260        self.write_request(&mut stream, body).await?;
261        self.build_response(stream).await
262    }
263
264    fn update_host_header(&mut self) {
265        if self.version >= Version::Http1_1 && !self.has_header("Host") {
266            self.set_header("Host", self.host_with_port());
267        }
268    }
269
270    fn update_body_headers(&mut self) {
271        if self.version >= Version::Http0_9 && self.method.has_body() && !self.has_header("Content-Length") {
272            self.set_header("Transfer-Encoding", "chunked");
273        }
274    }
275
276    async fn write_request<S, R>(&self, stream: &mut S, body: &mut R) -> Result<(), Error>
277        where
278        S: Write + Unpin,
279        R: Read + Send + Unpin,
280    {
281        self.write_proto(stream).await?;
282        self.write_body(stream, body).await
283    }
284
285    async fn write_proto<S>(&self, stream: &mut S) -> Result<(), Error>
286        where
287        S: Write + Unpin,
288    {
289        write_slice(stream, self.to_string().as_bytes()).await?;
290        flush_write(stream).await
291    }
292
293    async fn write_body<S, R>(&self, stream: &mut S, body: &mut R) -> Result<(), Error>
294        where
295        S: Write + Unpin,
296        R: Read + Send + Unpin,
297    {
298        if self.has_version(Version::Http0_9) {
299            write_all(stream, body, self.body_limit).await?;
300        } else if self.has_header("Content-Length") { // exact
301            write_exact(stream, body, read_content_length(&self.headers, self.body_limit)?).await?;
302        } else { // chunked
303            write_chunks(stream, body, (Some(1024), self.body_limit)).await?;
304        }
305        flush_write(stream).await
306    }
307
308    async fn build_conn(&mut self) -> Result<Stream, Error> {
309        let addr = self.socket_address();
310
311        match SocketAddr::from_str(&addr).await {
312            Ok(addr) => Stream::connect(&addr).await,
313            Err(_) => Err(Error::new(ErrorKind::AddrNotAvailable, format!("The address `{}` is invalid.", addr))),
314        }
315    }
316
317    async fn build_response<'a, S>(&mut self, mut stream: S) -> Result<Response<'a>, Error>
318        where
319        S: Read + Send + Unpin + 'a,
320    {
321        let mut res: Response<'a> = Response::default();
322
323        let (mut version, mut status, mut message) = (vec![], vec![], vec![]);
324        read_first_line(&mut stream, (&mut version, &mut status, &mut message), None).await?;
325        res.set_version(parse_version(version)?);
326        res.set_status(parse_status(status)?);
327    
328        loop {
329            let (mut name, mut value) = (vec![], vec![]);
330            read_header_line(&mut stream, (&mut name, &mut value), None).await?;
331            
332            if name.is_empty() {
333                break;
334            }
335
336            res.set_header(
337                match String::from_utf8(name) {
338                    Ok(name) => name,
339                    Err(_) => return Err(Error::new(ErrorKind::InvalidData, format!("The response header `#{}` is invalid.", res.headers().len()))),
340                },
341                match String::from_utf8(value) {
342                    Ok(value) => value,
343                    Err(_) => return Err(Error::new(ErrorKind::InvalidData, format!("The response header `#{}` is invalid.", res.headers().len()))),
344                },
345            );
346        }
347
348        res.set_reader(stream);
349        Ok(res)
350    }
351}
352
353impl fmt::Display for Request {
354    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
355        write!(fmt, "{}", self.to_proto_string())
356    }
357}
358
359impl From<Request> for String {
360    fn from(item: Request) -> String {
361        item.to_string()
362    }
363}