a2httpc/request/mod.rs

1use std::convert::{From, TryInto};
2use std::io::{prelude::*, BufWriter};
3use std::str;
4use std::sync::Arc;
5use std::time::Instant;
6
7#[cfg(feature = "flate2")]
8use http::header::ACCEPT_ENCODING;
9use http::{
10    header::{HeaderValue, IntoHeaderName, HOST},
11    HeaderMap, Method, StatusCode, Version,
12};
13use url::Url;
14
15use crate::error::{Error, ErrorKind, InvalidResponseKind, Result};
16use crate::parsing::{parse_response, Response};
17use crate::streams::{BaseStream, ConnectInfo};
18
19/// Contains types to describe request bodies
20pub mod body;
21mod builder;
22pub mod proxy;
23mod session;
24mod settings;
25
26use body::{Body, BodyKind};
27pub use builder::{RequestBuilder, RequestInspector};
28pub use session::Session;
29pub(crate) use settings::BaseSettings;
30
31fn header_insert<H, V>(headers: &mut HeaderMap, header: H, value: V) -> Result
32where
33    H: IntoHeaderName,
34    V: TryInto<HeaderValue>,
35    Error: From<V::Error>,
36{
37    let value = value.try_into()?;
38    headers.insert(header, value);
39    Ok(())
40}
41
42fn header_insert_if_missing<H, V>(headers: &mut HeaderMap, header: H, value: V) -> Result
43where
44    H: IntoHeaderName,
45    V: TryInto<HeaderValue>,
46    Error: From<V::Error>,
47{
48    let value = value.try_into()?;
49    headers.entry(header).or_insert(value);
50    Ok(())
51}
52
53fn header_append<H, V>(headers: &mut HeaderMap, header: H, value: V) -> Result
54where
55    H: IntoHeaderName,
56    V: TryInto<HeaderValue>,
57    Error: From<V::Error>,
58{
59    let value = value.try_into()?;
60    headers.append(header, value);
61    Ok(())
62}
63
64/// Represents a request that's ready to be sent. You can inspect this object for information about the request.
65#[derive(Debug)]
66pub struct PreparedRequest<B> {
67    url: Url,
68    method: Method,
69    body: B,
70    headers: HeaderMap,
71    pub(crate) base_settings: Arc<BaseSettings>,
72}
73
#[cfg(test)]
impl PreparedRequest<body::Empty> {
    /// Test-only constructor: an empty-bodied request with no headers and default settings.
    ///
    /// # Panics
    ///
    /// Panics when `base_url` cannot be parsed as an absolute URL.
    pub(crate) fn new<U>(method: Method, base_url: U) -> Self
    where
        U: AsRef<str>,
    {
        let url = Url::parse(base_url.as_ref()).unwrap();
        let settings = Arc::new(BaseSettings::default());
        Self {
            url,
            method,
            body: body::Empty,
            headers: HeaderMap::new(),
            base_settings: settings,
        }
    }
}
89
90impl<B> PreparedRequest<B> {
91    #[cfg(not(feature = "flate2"))]
92    fn set_compression(&mut self) -> Result {
93        Ok(())
94    }
95
96    #[cfg(feature = "flate2")]
97    fn set_compression(&mut self) -> Result {
98        if self.base_settings.allow_compression {
99            header_insert(&mut self.headers, ACCEPT_ENCODING, "gzip, deflate")?;
100        }
101        Ok(())
102    }
103
104    fn base_redirect_url(&self, location: &str, previous_url: &Url) -> Result<Url> {
105        match Url::parse(location) {
106            Ok(url) => Ok(url),
107            Err(url::ParseError::RelativeUrlWithoutBase) => {
108                let joined_url = previous_url
109                    .join(location)
110                    .map_err(|_| InvalidResponseKind::RedirectionUrl)?;
111
112                Ok(joined_url)
113            }
114            Err(_) => Err(InvalidResponseKind::RedirectionUrl.into()),
115        }
116    }
117
118    fn write_headers<W>(&self, writer: &mut W) -> Result
119    where
120        W: Write,
121    {
122        for (key, value) in self.headers.iter() {
123            write!(writer, "{}: ", key.as_str())?;
124            writer.write_all(value.as_bytes())?;
125            write!(writer, "\r\n")?;
126        }
127        write!(writer, "\r\n")?;
128        Ok(())
129    }
130
131    /// Get the URL of this request.
132    pub fn url(&self) -> &Url {
133        &self.url
134    }
135
136    /// Get the method of this request.
137    pub fn method(&self) -> &Method {
138        &self.method
139    }
140
141    /// Get the body of the request.
142    pub fn body(&self) -> &B {
143        &self.body
144    }
145
146    /// Get the headers of this request.
147    pub fn headers(&self) -> &HeaderMap {
148        &self.headers
149    }
150}
151
152impl<B: Body> PreparedRequest<B> {
153    fn write_request<W>(&mut self, writer: W, url: &Url, proxy: Option<&Url>) -> Result
154    where
155        W: Write,
156    {
157        let mut writer = BufWriter::new(writer);
158        let version = Version::HTTP_11;
159
160        if proxy.is_some() && url.scheme() == "http" {
161            debug!("{} {} {:?}", self.method.as_str(), url, version);
162
163            write!(writer, "{} {} {:?}\r\n", self.method.as_str(), url, version)?;
164        } else if let Some(query) = url.query() {
165            debug!("{} {}?{} {:?}", self.method.as_str(), url.path(), query, version);
166
167            write!(
168                writer,
169                "{} {}?{} {:?}\r\n",
170                self.method.as_str(),
171                url.path(),
172                query,
173                version,
174            )?;
175        } else {
176            debug!("{} {} {:?}", self.method.as_str(), url.path(), version);
177
178            write!(writer, "{} {} {:?}\r\n", self.method.as_str(), url.path(), version)?;
179        }
180
181        self.write_headers(&mut writer)?;
182
183        match self.body.kind()? {
184            BodyKind::Empty => (),
185            BodyKind::KnownLength(len) => {
186                debug!("writing out body of length {}", len);
187                self.body.write(&mut writer)?;
188            }
189            BodyKind::Chunked => {
190                debug!("writing out chunked body");
191                let mut writer = body::ChunkedWriter(&mut writer);
192                self.body.write(&mut writer)?;
193                writer.close()?;
194            }
195        }
196
197        writer.flush()?;
198
199        Ok(())
200    }
201
202    /// Send this request and wait for the result.
203    pub fn send(&mut self) -> Result<Response> {
204        let mut url = self.url.clone();
205
206        let deadline = self.base_settings.timeout.map(|timeout| Instant::now() + timeout);
207        let mut redirections = 0;
208
209        loop {
210            // If a proxy is set and the url is using http, we must connect to the proxy and send
211            // a request with an authority instead of a path.
212            //
213            // If a proxy is set and the url is using https, we must connect to the proxy using
214            // the CONNECT method, and then send https traffic on the socket after the CONNECT
215            // handshake.
216
217            let proxy = self.base_settings.proxy_settings.for_url(&url).cloned();
218
219            // If there is a proxy and the protocol is HTTP, the Host header will be the proxy's host name.
220            match (url.scheme(), &proxy) {
221                ("http", Some(proxy)) => set_host(&mut self.headers, proxy)?,
222                _ => set_host(&mut self.headers, &url)?,
223            };
224
225            let info = ConnectInfo {
226                url: &url,
227                proxy: proxy.as_ref(),
228                base_settings: &self.base_settings,
229                deadline,
230            };
231            let mut stream = BaseStream::connect(&info)?;
232
233            self.write_request(&mut stream, &url, proxy.as_ref())?;
234            let resp = parse_response(stream, self, &url)?;
235
236            debug!("status code {}", resp.status().as_u16());
237
238            let is_redirect = matches!(
239                resp.status(),
240                StatusCode::MOVED_PERMANENTLY
241                    | StatusCode::FOUND
242                    | StatusCode::SEE_OTHER
243                    | StatusCode::TEMPORARY_REDIRECT
244                    | StatusCode::PERMANENT_REDIRECT
245            );
246            if !self.base_settings.follow_redirects || !is_redirect {
247                return Ok(resp);
248            }
249
250            redirections += 1;
251            if redirections > self.base_settings.max_redirections {
252                return Err(ErrorKind::TooManyRedirections.into());
253            }
254
255            // Handle redirect
256            let location = resp
257                .headers()
258                .get(http::header::LOCATION)
259                .ok_or(InvalidResponseKind::LocationHeader)?;
260
261            let location = String::from_utf8_lossy(location.as_bytes());
262
263            url = self.base_redirect_url(&location, &url)?;
264
265            debug!("redirected to {} giving url {}", location, url);
266        }
267    }
268}
269
270fn set_host(headers: &mut HeaderMap, url: &Url) -> Result {
271    let host = url.host_str().ok_or(ErrorKind::InvalidUrlHost)?;
272    if let Some(port) = url.port() {
273        header_insert(headers, HOST, format!("{host}:{port}"))?;
274    } else {
275        header_insert(headers, HOST, host)?;
276    }
277    Ok(())
278}
279
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use http::header::{HeaderMap, HeaderValue, USER_AGENT};
    use http::Method;
    use url::Url;

    use super::BaseSettings;
    use super::{header_append, header_insert, header_insert_if_missing, PreparedRequest};
    use crate::body::Empty;

    /// Builds an empty-bodied GET request for `url` with default settings.
    fn get_request(url: &str) -> PreparedRequest<Empty> {
        PreparedRequest {
            method: Method::GET,
            url: Url::parse(url).unwrap(),
            body: Empty,
            headers: HeaderMap::new(),
            base_settings: Arc::new(BaseSettings::default()),
        }
    }

    /// Serializes `req` (optionally as if routed through `proxy`) and returns its request line.
    fn request_line(req: &mut PreparedRequest<Empty>, proxy: Option<&str>) -> String {
        let proxy = proxy.map(|p| Url::parse(p).unwrap());
        let url = req.url.clone();
        let mut buf: Vec<u8> = vec![];
        req.write_request(&mut buf, &url, proxy.as_ref()).unwrap();

        let text = std::str::from_utf8(&buf).unwrap();
        text.split("\r\n").next().unwrap().to_owned()
    }

    #[test]
    fn test_header_insert_exists() {
        let mut headers = HeaderMap::new();
        headers.insert(USER_AGENT, HeaderValue::from_static("hello"));
        header_insert(&mut headers, USER_AGENT, "world").unwrap();
        assert_eq!(headers[USER_AGENT], "world");
    }

    #[test]
    fn test_header_insert_missing() {
        let mut headers = HeaderMap::new();
        header_insert(&mut headers, USER_AGENT, "world").unwrap();
        assert_eq!(headers[USER_AGENT], "world");
    }

    #[test]
    fn test_header_insert_if_missing_exists() {
        let mut headers = HeaderMap::new();
        headers.insert(USER_AGENT, HeaderValue::from_static("hello"));
        header_insert_if_missing(&mut headers, USER_AGENT, "world").unwrap();
        assert_eq!(headers[USER_AGENT], "hello");
    }

    #[test]
    fn test_header_insert_if_missing_missing() {
        let mut headers = HeaderMap::new();
        header_insert_if_missing(&mut headers, USER_AGENT, "world").unwrap();
        assert_eq!(headers[USER_AGENT], "world");
    }

    #[test]
    fn test_header_append() {
        let mut headers = HeaderMap::new();
        header_append(&mut headers, USER_AGENT, "hello").unwrap();
        header_append(&mut headers, USER_AGENT, "world").unwrap();

        let vals: Vec<_> = headers.get_all(USER_AGENT).into_iter().collect();
        assert_eq!(vals.len(), 2);
        for val in vals {
            assert!(val == "hello" || val == "world");
        }
    }

    #[test]
    fn test_http_url_with_http_proxy() {
        let mut req = get_request("http://reddit.com/r/rust");
        assert_eq!(
            request_line(&mut req, Some("http://proxy:3128")),
            "GET http://reddit.com/r/rust HTTP/1.1"
        );
    }

    #[test]
    fn test_http_url_with_https_proxy() {
        // Previously this used an `http://` proxy, duplicating the test above.
        let mut req = get_request("http://reddit.com/r/rust");
        assert_eq!(
            request_line(&mut req, Some("https://proxy:3128")),
            "GET http://reddit.com/r/rust HTTP/1.1"
        );
    }

    #[test]
    fn test_https_url_with_proxy_uses_origin_form() {
        // https traffic is tunnelled, so the request-target stays in origin form.
        let mut req = get_request("https://reddit.com/r/rust");
        assert_eq!(
            request_line(&mut req, Some("http://proxy:3128")),
            "GET /r/rust HTTP/1.1"
        );
    }

    #[test]
    fn test_http_url_without_proxy_keeps_query_drops_fragment() {
        let mut req = get_request("http://reddit.com/r/rust?sort=new#top");
        assert_eq!(request_line(&mut req, None), "GET /r/rust?sort=new HTTP/1.1");
    }
}