1use std::convert::{From, TryInto};
2use std::io::{prelude::*, BufWriter};
3use std::str;
4use std::sync::Arc;
5use std::time::Instant;
6
7#[cfg(feature = "flate2")]
8use http::header::ACCEPT_ENCODING;
9use http::{
10 header::{HeaderValue, IntoHeaderName, HOST},
11 HeaderMap, Method, StatusCode, Version,
12};
13use url::Url;
14
15use crate::error::{Error, ErrorKind, InvalidResponseKind, Result};
16use crate::parsing::{parse_response, Response};
17use crate::streams::{BaseStream, ConnectInfo};
18
19pub mod body;
21mod builder;
22pub mod proxy;
23mod session;
24mod settings;
25
26use body::{Body, BodyKind};
27pub use builder::{RequestBuilder, RequestInspector};
28pub use session::Session;
29pub(crate) use settings::BaseSettings;
30
31fn header_insert<H, V>(headers: &mut HeaderMap, header: H, value: V) -> Result
32where
33 H: IntoHeaderName,
34 V: TryInto<HeaderValue>,
35 Error: From<V::Error>,
36{
37 let value = value.try_into()?;
38 headers.insert(header, value);
39 Ok(())
40}
41
42fn header_insert_if_missing<H, V>(headers: &mut HeaderMap, header: H, value: V) -> Result
43where
44 H: IntoHeaderName,
45 V: TryInto<HeaderValue>,
46 Error: From<V::Error>,
47{
48 let value = value.try_into()?;
49 headers.entry(header).or_insert(value);
50 Ok(())
51}
52
53fn header_append<H, V>(headers: &mut HeaderMap, header: H, value: V) -> Result
54where
55 H: IntoHeaderName,
56 V: TryInto<HeaderValue>,
57 Error: From<V::Error>,
58{
59 let value = value.try_into()?;
60 headers.append(header, value);
61 Ok(())
62}
63
/// A request whose URL, method, headers and body are fixed and which can be sent with `send`
/// (available when `B: Body`).
///
/// `B` is the type of the request body.
#[derive(Debug)]
pub struct PreparedRequest<B> {
    /// URL the request was prepared for. `send` follows redirects on a local copy, so this
    /// field is not modified by redirections.
    url: Url,
    /// Method written on the request line.
    method: Method,
    /// Body written after the headers.
    body: B,
    /// Headers written with the request. `send` (re)sets `Host` before every hop.
    headers: HeaderMap,
    /// Settings consulted while preparing and sending (timeout, proxy settings, redirect
    /// policy, compression).
    pub(crate) base_settings: Arc<BaseSettings>,
}
73
#[cfg(test)]
impl PreparedRequest<body::Empty> {
    /// Test-only constructor: a request with an empty body, no headers and default settings.
    ///
    /// # Panics
    ///
    /// Panics if `base_url` cannot be parsed as an absolute URL.
    pub(crate) fn new<U>(method: Method, base_url: U) -> Self
    where
        U: AsRef<str>,
    {
        let url = Url::parse(base_url.as_ref()).unwrap();
        let base_settings = Arc::new(BaseSettings::default());
        Self {
            url,
            method,
            body: body::Empty,
            headers: HeaderMap::new(),
            base_settings,
        }
    }
}
89
90impl<B> PreparedRequest<B> {
91 #[cfg(not(feature = "flate2"))]
92 fn set_compression(&mut self) -> Result {
93 Ok(())
94 }
95
96 #[cfg(feature = "flate2")]
97 fn set_compression(&mut self) -> Result {
98 if self.base_settings.allow_compression {
99 header_insert(&mut self.headers, ACCEPT_ENCODING, "gzip, deflate")?;
100 }
101 Ok(())
102 }
103
104 fn base_redirect_url(&self, location: &str, previous_url: &Url) -> Result<Url> {
105 match Url::parse(location) {
106 Ok(url) => Ok(url),
107 Err(url::ParseError::RelativeUrlWithoutBase) => {
108 let joined_url = previous_url
109 .join(location)
110 .map_err(|_| InvalidResponseKind::RedirectionUrl)?;
111
112 Ok(joined_url)
113 }
114 Err(_) => Err(InvalidResponseKind::RedirectionUrl.into()),
115 }
116 }
117
118 fn write_headers<W>(&self, writer: &mut W) -> Result
119 where
120 W: Write,
121 {
122 for (key, value) in self.headers.iter() {
123 write!(writer, "{}: ", key.as_str())?;
124 writer.write_all(value.as_bytes())?;
125 write!(writer, "\r\n")?;
126 }
127 write!(writer, "\r\n")?;
128 Ok(())
129 }
130
131 pub fn url(&self) -> &Url {
133 &self.url
134 }
135
136 pub fn method(&self) -> &Method {
138 &self.method
139 }
140
141 pub fn body(&self) -> &B {
143 &self.body
144 }
145
146 pub fn headers(&self) -> &HeaderMap {
148 &self.headers
149 }
150}
151
152impl<B: Body> PreparedRequest<B> {
153 fn write_request<W>(&mut self, writer: W, url: &Url, proxy: Option<&Url>) -> Result
154 where
155 W: Write,
156 {
157 let mut writer = BufWriter::new(writer);
158 let version = Version::HTTP_11;
159
160 if proxy.is_some() && url.scheme() == "http" {
161 debug!("{} {} {:?}", self.method.as_str(), url, version);
162
163 write!(writer, "{} {} {:?}\r\n", self.method.as_str(), url, version)?;
164 } else if let Some(query) = url.query() {
165 debug!("{} {}?{} {:?}", self.method.as_str(), url.path(), query, version);
166
167 write!(
168 writer,
169 "{} {}?{} {:?}\r\n",
170 self.method.as_str(),
171 url.path(),
172 query,
173 version,
174 )?;
175 } else {
176 debug!("{} {} {:?}", self.method.as_str(), url.path(), version);
177
178 write!(writer, "{} {} {:?}\r\n", self.method.as_str(), url.path(), version)?;
179 }
180
181 self.write_headers(&mut writer)?;
182
183 match self.body.kind()? {
184 BodyKind::Empty => (),
185 BodyKind::KnownLength(len) => {
186 debug!("writing out body of length {}", len);
187 self.body.write(&mut writer)?;
188 }
189 BodyKind::Chunked => {
190 debug!("writing out chunked body");
191 let mut writer = body::ChunkedWriter(&mut writer);
192 self.body.write(&mut writer)?;
193 writer.close()?;
194 }
195 }
196
197 writer.flush()?;
198
199 Ok(())
200 }
201
202 pub fn send(&mut self) -> Result<Response> {
204 let mut url = self.url.clone();
205
206 let deadline = self.base_settings.timeout.map(|timeout| Instant::now() + timeout);
207 let mut redirections = 0;
208
209 loop {
210 let proxy = self.base_settings.proxy_settings.for_url(&url).cloned();
218
219 match (url.scheme(), &proxy) {
221 ("http", Some(proxy)) => set_host(&mut self.headers, proxy)?,
222 _ => set_host(&mut self.headers, &url)?,
223 };
224
225 let info = ConnectInfo {
226 url: &url,
227 proxy: proxy.as_ref(),
228 base_settings: &self.base_settings,
229 deadline,
230 };
231 let mut stream = BaseStream::connect(&info)?;
232
233 self.write_request(&mut stream, &url, proxy.as_ref())?;
234 let resp = parse_response(stream, self, &url)?;
235
236 debug!("status code {}", resp.status().as_u16());
237
238 let is_redirect = matches!(
239 resp.status(),
240 StatusCode::MOVED_PERMANENTLY
241 | StatusCode::FOUND
242 | StatusCode::SEE_OTHER
243 | StatusCode::TEMPORARY_REDIRECT
244 | StatusCode::PERMANENT_REDIRECT
245 );
246 if !self.base_settings.follow_redirects || !is_redirect {
247 return Ok(resp);
248 }
249
250 redirections += 1;
251 if redirections > self.base_settings.max_redirections {
252 return Err(ErrorKind::TooManyRedirections.into());
253 }
254
255 let location = resp
257 .headers()
258 .get(http::header::LOCATION)
259 .ok_or(InvalidResponseKind::LocationHeader)?;
260
261 let location = String::from_utf8_lossy(location.as_bytes());
262
263 url = self.base_redirect_url(&location, &url)?;
264
265 debug!("redirected to {} giving url {}", location, url);
266 }
267 }
268}
269
270fn set_host(headers: &mut HeaderMap, url: &Url) -> Result {
271 let host = url.host_str().ok_or(ErrorKind::InvalidUrlHost)?;
272 if let Some(port) = url.port() {
273 header_insert(headers, HOST, format!("{host}:{port}"))?;
274 } else {
275 header_insert(headers, HOST, host)?;
276 }
277 Ok(())
278}
279
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use http::header::{HeaderMap, HeaderValue, HOST, USER_AGENT};
    use http::Method;
    use url::Url;

    use super::BaseSettings;
    use super::{header_append, header_insert, header_insert_if_missing, set_host, PreparedRequest};
    use crate::body::Empty;

    /// Builds a bodyless GET request for `url` with no headers and default settings.
    fn get_request(url: &str) -> PreparedRequest<Empty> {
        PreparedRequest {
            method: Method::GET,
            url: Url::parse(url).unwrap(),
            body: Empty,
            headers: HeaderMap::new(),
            base_settings: Arc::new(BaseSettings::default()),
        }
    }

    /// Serializes `req` for its own URL, optionally through `proxy`, and returns the request line.
    fn request_line(req: &mut PreparedRequest<Empty>, proxy: Option<&str>) -> String {
        let proxy = proxy.map(|p| Url::parse(p).unwrap());
        let target = req.url.clone();
        let mut buf: Vec<u8> = Vec::new();
        req.write_request(&mut buf, &target, proxy.as_ref()).unwrap();

        let text = String::from_utf8(buf).unwrap();
        text.split("\r\n").next().unwrap().to_owned()
    }

    #[test]
    fn test_header_insert_exists() {
        let mut headers = HeaderMap::new();
        headers.insert(USER_AGENT, HeaderValue::from_static("hello"));
        header_insert(&mut headers, USER_AGENT, "world").unwrap();
        assert_eq!(headers[USER_AGENT], "world");
    }

    #[test]
    fn test_header_insert_missing() {
        let mut headers = HeaderMap::new();
        header_insert(&mut headers, USER_AGENT, "world").unwrap();
        assert_eq!(headers[USER_AGENT], "world");
    }

    #[test]
    fn test_header_insert_if_missing_exists() {
        let mut headers = HeaderMap::new();
        headers.insert(USER_AGENT, HeaderValue::from_static("hello"));
        header_insert_if_missing(&mut headers, USER_AGENT, "world").unwrap();
        assert_eq!(headers[USER_AGENT], "hello");
    }

    #[test]
    fn test_header_insert_if_missing_missing() {
        let mut headers = HeaderMap::new();
        header_insert_if_missing(&mut headers, USER_AGENT, "world").unwrap();
        assert_eq!(headers[USER_AGENT], "world");
    }

    #[test]
    fn test_header_append() {
        let mut headers = HeaderMap::new();
        header_append(&mut headers, USER_AGENT, "hello").unwrap();
        header_append(&mut headers, USER_AGENT, "world").unwrap();

        let vals: Vec<_> = headers.get_all(USER_AGENT).into_iter().collect();
        assert_eq!(vals.len(), 2);
        for val in vals {
            assert!(val == "hello" || val == "world");
        }
    }

    #[test]
    fn test_http_url_with_http_proxy() {
        let mut req = get_request("http://reddit.com/r/rust");
        assert_eq!(
            request_line(&mut req, Some("http://proxy:3128")),
            "GET http://reddit.com/r/rust HTTP/1.1"
        );
    }

    #[test]
    fn test_http_url_with_https_proxy() {
        // The proxy's own scheme does not change the absolute-form used for an http URL.
        let mut req = get_request("http://reddit.com/r/rust");
        assert_eq!(
            request_line(&mut req, Some("https://proxy:3128")),
            "GET http://reddit.com/r/rust HTTP/1.1"
        );
    }

    #[test]
    fn test_https_url_with_proxy_uses_origin_form() {
        let mut req = get_request("https://reddit.com/r/rust?sort=new");
        assert_eq!(
            request_line(&mut req, Some("http://proxy:3128")),
            "GET /r/rust?sort=new HTTP/1.1"
        );
    }

    #[test]
    fn test_direct_request_uses_origin_form() {
        let mut req = get_request("http://reddit.com/r/rust");
        assert_eq!(request_line(&mut req, None), "GET /r/rust HTTP/1.1");
    }

    #[test]
    fn test_set_host_keeps_explicit_port() {
        let mut headers = HeaderMap::new();
        set_host(&mut headers, &Url::parse("http://example.com:8080/").unwrap()).unwrap();
        assert_eq!(headers[HOST], "example.com:8080");
    }

    #[test]
    fn test_set_host_omits_default_port() {
        let mut headers = HeaderMap::new();
        set_host(&mut headers, &Url::parse("https://example.com:443/").unwrap()).unwrap();
        assert_eq!(headers[HOST], "example.com");
    }

    #[test]
    fn test_base_redirect_url_relative() {
        let req = get_request("http://example.com/a/b");
        let next = req.base_redirect_url("c", &req.url).unwrap();
        assert_eq!(next.as_str(), "http://example.com/a/c");
    }

    #[test]
    fn test_base_redirect_url_absolute() {
        let req = get_request("http://example.com/a/b");
        let next = req.base_redirect_url("https://other.org/x", &req.url).unwrap();
        assert_eq!(next.as_str(), "https://other.org/x");
    }
}