use std::borrow::Borrow;
use std::convert::{From, TryInto};
use std::fs;
use std::str;
use std::time::Duration;
use http::{
header::{
HeaderMap, HeaderValue, IntoHeaderName, ACCEPT, CONNECTION, CONTENT_LENGTH, CONTENT_TYPE, TRANSFER_ENCODING,
USER_AGENT,
},
Method,
};
use url::Url;
#[cfg(feature = "charsets")]
use crate::charsets::Charset;
use crate::error::{Error, ErrorKind, Result};
use crate::parsing::Response;
use crate::request::{
body::{self, Body, BodyKind},
header_append, header_insert, header_insert_if_missing,
proxy::ProxySettings,
BaseSettings, PreparedRequest,
};
use crate::tls::Certificate;
const DEFAULT_USER_AGENT: &str = concat!("attohttpc/", env!("CARGO_PKG_VERSION"));
#[derive(Debug)]
pub struct RequestBuilder<B = body::Empty> {
url: Url,
method: Method,
body: B,
base_settings: BaseSettings,
}
impl RequestBuilder {
pub fn new<U>(method: Method, base_url: U) -> Self
where
U: AsRef<str>,
{
Self::try_new(method, base_url).expect("invalid url or method")
}
pub fn try_new<U>(method: Method, base_url: U) -> Result<Self>
where
U: AsRef<str>,
{
Self::try_with_settings(method, base_url, BaseSettings::default())
}
pub(crate) fn with_settings<U>(method: Method, base_url: U, base_settings: BaseSettings) -> Self
where
U: AsRef<str>,
{
Self::try_with_settings(method, base_url, base_settings).expect("invalid url or method")
}
pub(crate) fn try_with_settings<U>(method: Method, base_url: U, base_settings: BaseSettings) -> Result<Self>
where
U: AsRef<str>,
{
let url = Url::parse(base_url.as_ref()).map_err(|_| ErrorKind::InvalidBaseUrl)?;
if method == Method::CONNECT {
return Err(ErrorKind::ConnectNotSupported.into());
}
Ok(Self {
url,
method,
body: body::Empty,
base_settings,
})
}
}
impl<B> RequestBuilder<B> {
pub fn param<K, V>(mut self, key: K, value: V) -> Self
where
K: AsRef<str>,
V: ToString,
{
self.url.query_pairs_mut().append_pair(key.as_ref(), &value.to_string());
self
}
pub fn params<P, K, V>(mut self, pairs: P) -> Self
where
P: IntoIterator,
P::Item: Borrow<(K, V)>,
K: AsRef<str>,
V: ToString,
{
for pair in pairs.into_iter() {
let (key, value) = pair.borrow();
self.url.query_pairs_mut().append_pair(key.as_ref(), &value.to_string());
}
self
}
#[cfg(all(
feature = "tls",
not(any(target_os = "windows", target_os = "macos", target_os = "ios"))
))]
pub fn basic_auth(self, username: impl std::fmt::Display, password: Option<impl std::fmt::Display>) -> Self {
let auth = match password {
Some(password) => format!("{}:{}", username, password),
None => format!("{}:", username),
};
self.header(
http::header::AUTHORIZATION,
format!("Basic {}", openssl::base64::encode_block(auth.as_bytes())),
)
}
pub fn bearer_auth(self, token: impl Into<String>) -> Self {
self.header(http::header::AUTHORIZATION, format!("Bearer {}", token.into()))
}
pub fn body<B1: Body>(self, body: B1) -> RequestBuilder<B1> {
RequestBuilder {
url: self.url,
method: self.method,
body,
base_settings: self.base_settings,
}
}
pub fn text<B1: AsRef<str>>(mut self, body: B1) -> RequestBuilder<body::Text<B1>> {
self.base_settings
.headers
.entry(http::header::CONTENT_TYPE)
.or_insert(HeaderValue::from_static("text/plain; charset=utf-8"));
self.body(body::Text(body))
}
pub fn bytes<B1: AsRef<[u8]>>(mut self, body: B1) -> RequestBuilder<body::Bytes<B1>> {
self.base_settings
.headers
.entry(http::header::CONTENT_TYPE)
.or_insert(HeaderValue::from_static("application/octet-stream"));
self.body(body::Bytes(body))
}
pub fn file(mut self, body: fs::File) -> RequestBuilder<body::File> {
self.base_settings
.headers
.entry(http::header::CONTENT_TYPE)
.or_insert(HeaderValue::from_static("application/octet-stream"));
self.body(body::File(body))
}
#[cfg(feature = "json")]
pub fn json<T: serde::Serialize>(mut self, value: &T) -> Result<RequestBuilder<body::Bytes<Vec<u8>>>> {
let body = serde_json::to_vec(value)?;
self.base_settings
.headers
.entry(http::header::CONTENT_TYPE)
.or_insert(HeaderValue::from_static("application/json; charset=utf-8"));
Ok(self.body(body::Bytes(body)))
}
#[cfg(feature = "json")]
pub fn json_streaming<T: serde::Serialize>(mut self, value: T) -> RequestBuilder<body::Json<T>> {
self.base_settings
.headers
.entry(http::header::CONTENT_TYPE)
.or_insert(HeaderValue::from_static("application/json; charset=utf-8"));
self.body(body::Json(value))
}
#[cfg(feature = "form")]
pub fn form<T: serde::Serialize>(mut self, value: &T) -> Result<RequestBuilder<body::Bytes<Vec<u8>>>> {
let body = serde_urlencoded::to_string(value)?.into_bytes();
self.base_settings
.headers
.entry(http::header::CONTENT_TYPE)
.or_insert(HeaderValue::from_static("application/x-www-form-urlencoded"));
Ok(self.body(body::Bytes(body)))
}
pub fn header<H, V>(self, header: H, value: V) -> Self
where
H: IntoHeaderName,
V: TryInto<HeaderValue>,
Error: From<V::Error>,
{
self.try_header(header, value).expect("invalid header value")
}
pub fn header_append<H, V>(self, header: H, value: V) -> Self
where
H: IntoHeaderName,
V: TryInto<HeaderValue>,
Error: From<V::Error>,
{
self.try_header_append(header, value).expect("invalid header value")
}
pub fn try_header<H, V>(mut self, header: H, value: V) -> Result<Self>
where
H: IntoHeaderName,
V: TryInto<HeaderValue>,
Error: From<V::Error>,
{
header_insert(&mut self.base_settings.headers, header, value)?;
Ok(self)
}
pub fn try_header_append<H, V>(mut self, header: H, value: V) -> Result<Self>
where
H: IntoHeaderName,
V: TryInto<HeaderValue>,
Error: From<V::Error>,
{
header_append(&mut self.base_settings.headers, header, value)?;
Ok(self)
}
pub fn max_headers(mut self, max_headers: usize) -> Self {
self.base_settings.max_headers = max_headers;
self
}
pub fn max_redirections(mut self, max_redirections: u32) -> Self {
self.base_settings.max_redirections = max_redirections;
self
}
pub fn follow_redirects(mut self, follow_redirects: bool) -> Self {
self.base_settings.follow_redirects = follow_redirects;
self
}
pub fn connect_timeout(mut self, duration: Duration) -> Self {
self.base_settings.connect_timeout = duration;
self
}
pub fn read_timeout(mut self, duration: Duration) -> Self {
self.base_settings.read_timeout = duration;
self
}
pub fn timeout(mut self, duration: Duration) -> Self {
self.base_settings.timeout = Some(duration);
self
}
pub fn proxy_settings(mut self, settings: ProxySettings) -> Self {
self.base_settings.proxy_settings = settings;
self
}
#[cfg(feature = "charsets")]
pub fn default_charset(mut self, default_charset: Option<Charset>) -> Self {
self.base_settings.default_charset = default_charset;
self
}
#[cfg(feature = "compress")]
pub fn allow_compression(mut self, allow_compression: bool) -> Self {
self.base_settings.allow_compression = allow_compression;
self
}
pub fn danger_accept_invalid_certs(mut self, accept_invalid_certs: bool) -> Self {
self.base_settings.accept_invalid_certs = accept_invalid_certs;
self
}
pub fn danger_accept_invalid_hostnames(mut self, accept_invalid_hostnames: bool) -> Self {
self.base_settings.accept_invalid_hostnames = accept_invalid_hostnames;
self
}
pub fn add_root_certificate(mut self, cert: Certificate) -> Self {
self.base_settings.root_certificates.0.push(cert);
self
}
}
impl<B: Body> RequestBuilder<B> {
pub fn prepare(self) -> PreparedRequest<B> {
self.try_prepare().expect("failed to prepare request")
}
pub fn try_prepare(self) -> Result<PreparedRequest<B>> {
let mut prepped = PreparedRequest {
url: self.url,
method: self.method,
body: self.body,
base_settings: self.base_settings,
};
header_insert(&mut prepped.base_settings.headers, CONNECTION, "close")?;
prepped.set_compression()?;
match prepped.body.kind()? {
BodyKind::Empty => (),
BodyKind::KnownLength(len) => {
header_insert(&mut prepped.base_settings.headers, CONTENT_LENGTH, len)?;
}
BodyKind::Chunked => {
header_insert(&mut prepped.base_settings.headers, TRANSFER_ENCODING, "chunked")?;
}
}
if let Some(typ) = prepped.body.content_type()? {
header_insert(&mut prepped.base_settings.headers, CONTENT_TYPE, typ)?;
}
header_insert_if_missing(&mut prepped.base_settings.headers, ACCEPT, "*/*")?;
header_insert_if_missing(&mut prepped.base_settings.headers, USER_AGENT, DEFAULT_USER_AGENT)?;
Ok(prepped)
}
pub fn send(self) -> Result<Response> {
self.try_prepare()?.send()
}
}
impl<B> RequestBuilder<B> {
pub fn inspect(&mut self) -> RequestInspector<'_, B> {
RequestInspector(self)
}
}
#[derive(Debug)]
pub struct RequestInspector<'a, B>(&'a mut RequestBuilder<B>);
impl<B> RequestInspector<'_, B> {
pub fn url(&self) -> &Url {
&self.0.url
}
pub fn method(&self) -> &Method {
&self.0.method
}
pub fn body(&mut self) -> &mut B {
&mut self.0.body
}
pub fn headers(&self) -> &HeaderMap {
&self.0.base_settings.headers
}
}
#[test]
#[cfg(feature = "tls")]
fn test_accept_invalid_certs_disabled_by_default() {
let builder = RequestBuilder::new(Method::GET, "https://localhost:7900");
assert_eq!(builder.base_settings.accept_invalid_certs, false);
assert_eq!(builder.base_settings.accept_invalid_hostnames, false);
let prepped = builder.prepare();
assert_eq!(prepped.base_settings.accept_invalid_certs, false);
assert_eq!(prepped.base_settings.accept_invalid_hostnames, false);
}
#[cfg(test)]
mod tests {
use super::*;
use http::header::HeaderMap;
#[test]
fn test_header_insert_exists() {
let mut headers = HeaderMap::new();
headers.insert(USER_AGENT, HeaderValue::from_static("hello"));
header_insert(&mut headers, USER_AGENT, "world").unwrap();
assert_eq!(headers[USER_AGENT], "world");
}
#[test]
fn test_header_insert_missing() {
let mut headers = HeaderMap::new();
header_insert(&mut headers, USER_AGENT, "world").unwrap();
assert_eq!(headers[USER_AGENT], "world");
}
#[test]
fn test_header_insert_if_missing_exists() {
let mut headers = HeaderMap::new();
headers.insert(USER_AGENT, HeaderValue::from_static("hello"));
header_insert_if_missing(&mut headers, USER_AGENT, "world").unwrap();
assert_eq!(headers[USER_AGENT], "hello");
}
#[test]
fn test_header_insert_if_missing_missing() {
let mut headers = HeaderMap::new();
header_insert_if_missing(&mut headers, USER_AGENT, "world").unwrap();
assert_eq!(headers[USER_AGENT], "world");
}
#[test]
fn test_header_append() {
let mut headers = HeaderMap::new();
header_append(&mut headers, USER_AGENT, "hello").unwrap();
header_append(&mut headers, USER_AGENT, "world").unwrap();
let vals: Vec<_> = headers.get_all(USER_AGENT).into_iter().collect();
assert_eq!(vals.len(), 2);
for val in vals {
assert!(val == "hello" || val == "world");
}
}
#[test]
fn test_request_builder_param() {
let prepped = RequestBuilder::new(Method::GET, "http://localhost:1337/foo")
.param("qux", "baz")
.prepare();
assert_eq!(prepped.url().as_str(), "http://localhost:1337/foo?qux=baz");
}
#[test]
fn test_request_builder_params() {
let prepped = RequestBuilder::new(Method::GET, "http://localhost:1337/foo")
.params(&[("qux", "baz"), ("foo", "bar")])
.prepare();
assert_eq!(prepped.url().as_str(), "http://localhost:1337/foo?qux=baz&foo=bar");
}
#[test]
fn test_request_builder_header_insert() {
let prepped = RequestBuilder::new(Method::GET, "http://localhost:1337/foo")
.header("hello", "world")
.prepare();
assert_eq!(prepped.headers()["hello"], "world");
}
#[test]
fn test_request_builder_header_append() {
let prepped = RequestBuilder::new(Method::GET, "http://localhost:1337/foo")
.header_append("hello", "world")
.header_append("hello", "!!!")
.prepare();
let vals: Vec<_> = prepped.headers().get_all("hello").into_iter().collect();
assert_eq!(vals.len(), 2);
for val in vals {
assert!(val == "world" || val == "!!!");
}
}
#[cfg(feature = "compress")]
fn assert_request_content(
builder: RequestBuilder,
status_line: &str,
mut header_lines: Vec<&str>,
body_lines: &[&str],
) {
let mut buf = Vec::new();
let mut prepped = builder.prepare();
prepped
.write_request(&mut buf, &prepped.url().clone(), None)
.expect("error writing request");
let text = std::str::from_utf8(&buf).expect("cannot decode request as utf-8");
let lines: Vec<_> = text.lines().collect();
let req_status_line = lines[0];
let empty_line_pos = lines
.iter()
.position(|l| l.is_empty())
.expect("no empty line in request");
let mut req_header_lines = lines[1..empty_line_pos].to_vec();
let req_body_lines = &lines[empty_line_pos + 1..];
req_header_lines.sort_unstable();
header_lines.sort_unstable();
assert_eq!(req_status_line, status_line);
assert_eq!(req_header_lines, header_lines);
assert_eq!(req_body_lines, body_lines);
}
#[test]
#[cfg(feature = "compress")]
fn test_request_builder_write_request_no_query() {
assert_request_content(
RequestBuilder::new(Method::GET, "http://localhost:1337/foo"),
"GET /foo HTTP/1.1",
vec![
"connection: close",
"accept-encoding: gzip, deflate",
"accept: */*",
&format!("user-agent: {}", DEFAULT_USER_AGENT),
],
&[],
);
}
#[test]
#[cfg(feature = "compress")]
fn test_request_builder_write_request_with_query() {
assert_request_content(
RequestBuilder::new(Method::GET, "http://localhost:1337/foo").param("hello", "world"),
"GET /foo?hello=world HTTP/1.1",
vec![
"connection: close",
"accept-encoding: gzip, deflate",
"accept: */*",
&format!("user-agent: {}", DEFAULT_USER_AGENT),
],
&[],
);
}
#[test]
fn test_prepare_default_headers() {
let prepped = RequestBuilder::new(Method::GET, "http://localhost:1337/foo/qux/baz").prepare();
assert_eq!(prepped.headers()[ACCEPT], "*/*");
assert_eq!(prepped.headers()[USER_AGENT], DEFAULT_USER_AGENT);
}
#[test]
fn test_prepare_custom_headers() {
let prepped = RequestBuilder::new(Method::GET, "http://localhost:1337/foo/qux/baz")
.header(USER_AGENT, "foobaz")
.header("Accept", "nothing")
.prepare();
assert_eq!(prepped.headers()[ACCEPT], "nothing");
assert_eq!(prepped.headers()[USER_AGENT], "foobaz");
}
}