use alloc::collections::BTreeMap;
use core::fmt;
#[cfg(feature = "std")]
use core::fmt::Write;
use core::time::Duration;
#[cfg(feature = "std")]
use std::env;
#[cfg(feature = "std")]
use std::time::Instant;
#[cfg(feature = "async")]
use crate::connection::AsyncConnection;
#[cfg(feature = "std")]
use crate::connection::Connection;
#[cfg(feature = "proxy")]
use crate::proxy::Proxy;
#[cfg(feature = "std")]
use crate::url::Url;
#[cfg(feature = "std")]
use crate::{Error, Response, ResponseLazy};
/// A URL in its raw, unparsed string form.
///
/// It is only parsed into a `Url` when a `ParsedRequest` is built from the request.
pub type URL = String;
/// An HTTP request method.
///
/// The [`fmt::Display`] implementation yields the token that is written at the start of the
/// request line.
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum Method {
    /// The `GET` method.
    Get,
    /// The `HEAD` method.
    Head,
    /// The `POST` method.
    Post,
    /// The `PUT` method.
    Put,
    /// The `DELETE` method.
    Delete,
    /// The `CONNECT` method.
    Connect,
    /// The `OPTIONS` method.
    Options,
    /// The `TRACE` method.
    Trace,
    /// The `PATCH` method.
    Patch,
    /// Any other method; the contained string is emitted verbatim, without validation.
    Custom(String),
}

impl fmt::Display for Method {
    /// Writes the method token (`GET`, `POST`, ... or the custom string as-is).
    ///
    /// Width, fill and alignment flags of the formatter are not applied.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let token = match self {
            Method::Get => "GET",
            Method::Head => "HEAD",
            Method::Post => "POST",
            Method::Put => "PUT",
            Method::Delete => "DELETE",
            Method::Connect => "CONNECT",
            Method::Options => "OPTIONS",
            Method::Trace => "TRACE",
            Method::Patch => "PATCH",
            Method::Custom(custom) => custom.as_str(),
        };
        f.write_str(token)
    }
}
/// An HTTP request under construction.
///
/// Created with `Request::new` (or one of the free functions such as `get`/`post`), refined
/// with the `with_*` builder methods, and consumed by one of the `send*` methods.
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Request {
/// HTTP method written at the start of the request line.
pub(crate) method: Method,
/// Target URL exactly as supplied by the caller; parsed in `ParsedRequest::new`.
url: URL,
/// Query parameters, in insertion order, appended to the URL's query in `ParsedRequest::new`.
params: Vec<(String, String)>,
/// Header name/value pairs. Names are compared case-sensitively by the map, and headers are
/// written to the request head in the map's (sorted) iteration order.
headers: BTreeMap<String, String>,
/// Optional request body, appended verbatim after the request head.
body: Option<Vec<u8>>,
/// Timeout in seconds (converted with `Duration::from_secs` in `ParsedRequest::new`).
timeout: Option<u64>,
/// Set by `with_pipelining`; how it is honoured is decided outside this file.
pub(crate) pipelining: bool,
/// Limit on the size of the response headers (default `256 * 1024`).
/// NOTE(review): presumably a byte count and `None` means unlimited; enforced elsewhere.
pub(crate) max_headers_size: Option<usize>,
/// Limit on the length of the response status line (default `64 * 1024`).
/// NOTE(review): presumably a byte count and `None` means unlimited; enforced elsewhere.
pub(crate) max_status_line_len: Option<usize>,
/// Limit on the size of the response body (default `1024 * 1024 * 1024`); handed to
/// `Response::create` by `send`.
/// NOTE(review): presumably a byte count and `None` means unlimited; enforced elsewhere.
pub(crate) max_body_size: Option<usize>,
/// Maximum number of redirects; `ParsedRequest::redirect_to` fails with
/// `Error::TooManyRedirections` once more than this many have been recorded (default 100).
max_redirects: usize,
/// Proxy to connect through. When unset, `ParsedRequest::new` may fill it in from the
/// environment.
#[cfg(feature = "proxy")]
pub(crate) proxy: Option<Proxy>,
}
impl Request {
    /// Creates a request for `url` using `method`, with every option at its default.
    ///
    /// Defaults: no query parameters, headers, body, timeout or proxy; pipelining off; response
    /// limits of `256 * 1024` (headers), `64 * 1024` (status line) and `1024 * 1024 * 1024`
    /// (body); at most 100 redirects.
    pub fn new<T: Into<URL>>(method: Method, url: T) -> Request {
        let url = url.into();
        Request {
            method,
            url,
            params: Vec::new(),
            headers: BTreeMap::new(),
            body: None,
            timeout: None,
            pipelining: false,
            max_headers_size: Some(256 * 1024),
            max_status_line_len: Some(64 * 1024),
            max_body_size: Some(1024 * 1024 * 1024),
            max_redirects: 100,
            #[cfg(feature = "proxy")]
            proxy: None,
        }
    }

    /// Adds every `(name, value)` pair from `headers`.
    ///
    /// A pair whose name is already present (exact, case-sensitive match) replaces the earlier
    /// value.
    pub fn with_headers<T, K, V>(mut self, headers: T) -> Request
    where
        T: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<String>,
    {
        for (name, value) in headers {
            self.headers.insert(name.into(), value.into());
        }
        self
    }

    /// Sets the header `key` to `value`, replacing any value stored under exactly the same name.
    pub fn with_header<T: Into<String>, U: Into<String>>(mut self, key: T, value: U) -> Request {
        let (name, content) = (key.into(), value.into());
        self.headers.insert(name, content);
        self
    }

    /// Sets the request body and a matching `Content-Length` header.
    pub fn with_body<T: Into<Vec<u8>>>(self, body: T) -> Request {
        let payload = body.into();
        let length = format!("{}", payload.len());
        Request { body: Some(payload), ..self }.with_header("Content-Length", length)
    }

    /// Appends a query parameter; parameters are added to the URL in the order they were given.
    pub fn with_param<T: Into<String>, U: Into<String>>(mut self, key: T, value: U) -> Request {
        let pair = (key.into(), value.into());
        self.params.push(pair);
        self
    }

    /// Serializes `body` as JSON, uses it as the request body and sets
    /// `Content-Type: application/json; charset=UTF-8`.
    ///
    /// # Errors
    ///
    /// Returns `Error::SerdeJsonError` if serialization fails.
    #[cfg(feature = "json-using-serde")]
    pub fn with_json<T: serde::ser::Serialize>(mut self, body: &T) -> Result<Request, Error> {
        self.headers.insert(
            String::from("Content-Type"),
            String::from("application/json; charset=UTF-8"),
        );
        serde_json::to_vec(body)
            .map(|json| self.with_body(json))
            .map_err(Error::SerdeJsonError)
    }

    /// Sets the request timeout, in seconds.
    pub fn with_timeout(self, timeout: u64) -> Request {
        Request { timeout: Some(timeout), ..self }
    }

    /// Sets how many redirects may be followed before the request fails.
    pub fn with_max_redirects(self, max_redirects: usize) -> Request {
        Request { max_redirects, ..self }
    }

    /// Sets the limit on the response header size; accepts a plain `usize` or an
    /// `Option<usize>`.
    pub fn with_max_headers_size<S: Into<Option<usize>>>(self, max_headers_size: S) -> Request {
        Request { max_headers_size: max_headers_size.into(), ..self }
    }

    /// Sets the limit on the response status line length; accepts a plain `usize` or an
    /// `Option<usize>`.
    pub fn with_max_status_line_length<S: Into<Option<usize>>>(
        self,
        max_status_line_len: S,
    ) -> Request {
        Request { max_status_line_len: max_status_line_len.into(), ..self }
    }

    /// Sets the limit on the response body size; accepts a plain `usize` or an `Option<usize>`.
    pub fn with_max_body_size<S: Into<Option<usize>>>(self, max_body_size: S) -> Request {
        Request { max_body_size: max_body_size.into(), ..self }
    }

    /// Routes the request through `proxy`.
    #[cfg(feature = "proxy")]
    pub fn with_proxy(self, proxy: Proxy) -> Request {
        Request { proxy: Some(proxy), ..self }
    }

    /// Turns on the pipelining flag for this request.
    #[cfg(feature = "async")]
    pub fn with_pipelining(self) -> Request {
        Request { pipelining: true, ..self }
    }

    /// Sends the request and returns the response built by `Response::create`.
    ///
    /// # Errors
    ///
    /// Propagates any error from parsing the request, connecting, sending, or building the
    /// response.
    #[cfg(feature = "std")]
    pub fn send(self) -> Result<Response, Error> {
        let parsed_request = ParsedRequest::new(self)?;
        // These are read up front because `parsed_request` is moved into `send` below.
        let max_body_size = parsed_request.config.max_body_size;
        let is_head = matches!(parsed_request.config.method, Method::Head);
        let connection =
            Connection::new(parsed_request.connection_params(), parsed_request.timeout_at)?;
        let lazy_response = connection.send(parsed_request)?;
        Response::create(lazy_response, is_head, max_body_size)
    }

    /// Sends the request and returns the `ResponseLazy` produced by the connection.
    ///
    /// # Errors
    ///
    /// Propagates any error from parsing the request, connecting, or sending.
    #[cfg(feature = "std")]
    pub fn send_lazy(self) -> Result<ResponseLazy, Error> {
        let parsed_request = ParsedRequest::new(self)?;
        let connection =
            Connection::new(parsed_request.connection_params(), parsed_request.timeout_at)?;
        connection.send(parsed_request)
    }

    /// Sends the request over an `AsyncConnection` and awaits the response.
    ///
    /// # Errors
    ///
    /// Propagates any error from parsing the request, connecting, or sending.
    #[cfg(feature = "async")]
    pub async fn send_async(self) -> Result<Response, Error> {
        let parsed_request = ParsedRequest::new(self)?;
        AsyncConnection::new(parsed_request.connection_params(), parsed_request.timeout_at)
            .await?
            .send(parsed_request)
            .await
    }

    /// Sends the request with `send_async` and wraps the finished response with
    /// `ResponseLazy::dummy_from_response`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `send_async`.
    #[cfg(feature = "async")]
    pub async fn send_lazy_async(self) -> Result<ResponseLazy, Error> {
        self.send_async().await.map(ResponseLazy::dummy_from_response)
    }
}
/// A `Request` whose URL has been parsed and whose timeout has been turned into a deadline.
#[cfg(feature = "std")]
pub(crate) struct ParsedRequest {
/// The URL currently being requested; replaced by `redirect_to` when a redirect is followed.
pub(crate) url: Url,
/// URLs that were requested before a redirect moved on from them, oldest first.
pub(crate) redirects: Vec<Url>,
/// The request configuration this was built from.
pub(crate) config: Request,
/// Deadline computed in `ParsedRequest::new` from the timeout, if any.
pub(crate) timeout_at: Option<Instant>,
}
#[cfg(feature = "std")]
impl ParsedRequest {
    /// Parses the request's URL, appends its query parameters, and resolves the effective
    /// proxy and deadline.
    ///
    /// With the `proxy` feature enabled and no proxy configured on the request, a proxy is
    /// looked up in the environment following curl's conventions:
    ///
    /// * HTTPS URLs: `https_proxy`, then `HTTPS_PROXY`;
    /// * other URLs: `http_proxy` (lower-case only);
    /// * if the scheme-specific variable is unset: `all_proxy`, then `ALL_PROXY`.
    ///
    /// The first variable that is set is used. If its value is rejected by `Proxy::new_http`,
    /// no proxy is configured (best effort, no error).
    ///
    /// The timeout is taken from the request, falling back to the `BITREQ_TIMEOUT` environment
    /// variable (whole seconds). A timeout so large that the deadline cannot be represented as
    /// an `Instant` is treated as "no deadline" instead of panicking.
    ///
    /// # Errors
    ///
    /// Propagates the error from `Url::parse` when the request URL is invalid.
    // `config` is only mutated when the `proxy` feature is enabled.
    #[allow(unused_mut)]
    pub(crate) fn new(mut config: Request) -> Result<ParsedRequest, Error> {
        let mut url = Url::parse(&config.url)?;
        let params = config.params.iter().map(|(a, b)| (a.as_str(), b.as_str()));
        url.append_query_params(params);

        #[cfg(all(feature = "proxy", feature = "std"))]
        if config.proxy.is_none() {
            // `or_else` (not `map_err`) is required so that the fallback variable is actually
            // used when the first one is unset.
            let scheme_proxy = if url.is_https() {
                std::env::var("https_proxy").or_else(|_| std::env::var("HTTPS_PROXY"))
            } else {
                // curl deliberately honours only the lower-case spelling of `http_proxy`.
                std::env::var("http_proxy")
            };
            let env_proxy = scheme_proxy
                .or_else(|_| std::env::var("all_proxy"))
                .or_else(|_| std::env::var("ALL_PROXY"));
            if let Ok(address) = env_proxy {
                if let Ok(proxy) = Proxy::new_http(address) {
                    config.proxy = Some(proxy);
                }
            }
        }

        let timeout = config
            .timeout
            .or_else(|| env::var("BITREQ_TIMEOUT").ok().and_then(|t| t.parse::<u64>().ok()));
        // `Instant + Duration` panics on overflow; `checked_add` yields `None` instead, which
        // means "no deadline" for absurdly large timeouts.
        let timeout_at =
            timeout.and_then(|secs| Instant::now().checked_add(Duration::from_secs(secs)));

        Ok(ParsedRequest { url, redirects: Vec::new(), config, timeout_at })
    }

    /// Builds the request head: request line, `Host` header, the configured headers and the
    /// terminating blank line.
    ///
    /// The port is appended to the `Host` value only when the URL reports an explicit
    /// non-default port. For `POST`, `PUT` and `PATCH` requests that carry neither a
    /// `Content-Length` nor a `Transfer-Encoding` header (names compared case-insensitively),
    /// `Content-Length: 0` is added.
    fn get_http_head(&self) -> String {
        let mut http = String::with_capacity(32);
        // Writing into a `String` cannot fail, hence the `unwrap`s below.
        write!(
            http,
            "{} {} HTTP/1.1\r\nHost: {}",
            self.config.method,
            self.url.path_and_query(),
            self.url.base_url()
        )
        .unwrap();
        if self.url.has_explicit_non_default_port() {
            write!(http, ":{}", self.url.port()).unwrap();
        }
        http += "\r\n";

        for (k, v) in &self.config.headers {
            write!(http, "{}: {}\r\n", k, v).unwrap();
        }

        let expects_body =
            matches!(self.config.method, Method::Post | Method::Put | Method::Patch);
        if expects_body {
            let declares_length = self.config.headers.keys().any(|key| {
                key.eq_ignore_ascii_case("content-length")
                    || key.eq_ignore_ascii_case("transfer-encoding")
            });
            if !declares_length {
                http += "Content-Length: 0\r\n";
            }
        }

        http += "\r\n";
        http
    }

    /// Serializes the whole request: the head from `get_http_head` followed by the body bytes,
    /// if a body is set.
    pub(crate) fn as_bytes(&self) -> Vec<u8> {
        let mut bytes = self.get_http_head().into_bytes();
        if let Some(body) = &self.config.body {
            bytes.extend(body);
        }
        bytes
    }

    /// Points this request at the redirect target `url` and records the previous URL.
    ///
    /// A target containing `://` is parsed as an absolute URL; anything else is appended to
    /// the current URL's base. The fragment handling is delegated to
    /// `Url::preserve_fragment_from`.
    ///
    /// # Errors
    ///
    /// * `Error::IoError` if an absolute target cannot be parsed;
    /// * the error from `Url::parse` if a relative target cannot be parsed;
    /// * `Error::TooManyRedirections` if more than `max_redirects` redirects were followed;
    /// * `Error::InfiniteRedirectionLoop` if the new URL equals one that was already visited.
    ///
    /// Note that the URL has already been switched and recorded when the last two errors are
    /// returned.
    pub(crate) fn redirect_to(&mut self, url: &str) -> Result<(), Error> {
        let mut target = if url.contains("://") {
            Url::parse(url).map_err(|_| {
                Error::IoError(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "was redirected to an absolute url with an invalid protocol",
                ))
            })?
        } else {
            let mut absolute_url = String::new();
            // NOTE(review): presumably infallible when the sink is a `String`.
            self.url.write_base_url_to(&mut absolute_url).unwrap();
            absolute_url.push_str(url);
            Url::parse(&absolute_url)?
        };

        target.preserve_fragment_from(&self.url);
        let previous = std::mem::replace(&mut self.url, target);
        self.redirects.push(previous);

        if self.redirects.len() > self.config.max_redirects {
            Err(Error::TooManyRedirections)
        } else if self.redirects.contains(&self.url) {
            Err(Error::InfiniteRedirectionLoop)
        } else {
            Ok(())
        }
    }

    /// Returns the borrowed connection parameters (scheme, host, port, proxy) for the current
    /// URL.
    pub(crate) fn connection_params(&self) -> ConnectionParams<'_> {
        ConnectionParams::from_request(self)
    }
}
/// Borrowed description of where a request has to connect to.
///
/// Built from a `ParsedRequest` by `ConnectionParams::from_request`; see
/// `OwnedConnectionParams` for the owned counterpart.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
#[cfg(feature = "std")]
pub(crate) struct ConnectionParams<'a> {
/// Whether the request URL is HTTPS (`Url::is_https`).
pub(crate) https: bool,
/// Host taken from `Url::base_url`.
pub(crate) host: &'a str,
/// Port taken from `Url::port`.
pub(crate) port: u16,
/// Proxy configured on the request, if any.
#[cfg(feature = "proxy")]
pub(crate) proxy: Option<&'a Proxy>,
}
#[cfg(feature = "std")]
impl<'a> ConnectionParams<'a> {
    /// Borrows the connection-relevant parts (scheme, host, port and, with the `proxy`
    /// feature, the proxy) out of `request`.
    fn from_request(request: &'a ParsedRequest) -> Self {
        let target = &request.url;
        ConnectionParams {
            https: target.is_https(),
            host: target.base_url(),
            port: target.port(),
            #[cfg(feature = "proxy")]
            proxy: request.config.proxy.as_ref(),
        }
    }
}
/// Owned counterpart of `ConnectionParams`, holding the same fields without borrowing.
///
/// It can be created from a `ConnectionParams` via `From` and compared against one via
/// `PartialEq`.
/// NOTE(review): presumably used as a key for storing/reusing connections — confirm against
/// callers.
#[cfg(feature = "std")]
#[derive(Clone, Debug, Hash, PartialEq, Eq)]
pub(crate) struct OwnedConnectionParams {
/// Whether the connection is for an HTTPS URL.
pub(crate) https: bool,
/// Host to connect to.
pub(crate) host: String,
/// Port to connect to.
pub(crate) port: u16,
/// Proxy to connect through, if any.
#[cfg(feature = "proxy")]
pub(crate) proxy: Option<Proxy>,
}
#[cfg(feature = "std")]
impl PartialEq<ConnectionParams<'_>> for OwnedConnectionParams {
    /// Compares field by field: scheme, host and port must match, and with the `proxy` feature
    /// enabled the proxies must be equal as well.
    fn eq(&self, other: &ConnectionParams<'_>) -> bool {
        let same_endpoint =
            self.https == other.https && self.host == other.host && self.port == other.port;

        #[cfg(feature = "proxy")]
        let same_proxy = self.proxy.as_ref() == other.proxy;
        // Without proxy support there is nothing further to compare.
        #[cfg(not(feature = "proxy"))]
        let same_proxy = true;

        same_endpoint && same_proxy
    }
}
#[cfg(feature = "std")]
impl From<ConnectionParams<'_>> for OwnedConnectionParams {
    /// Copies the borrowed parameters into an owned value, allocating the host string and
    /// cloning the proxy (if any).
    fn from(other: ConnectionParams<'_>) -> Self {
        let ConnectionParams {
            https,
            host,
            port,
            #[cfg(feature = "proxy")]
            proxy,
        } = other;
        OwnedConnectionParams {
            https,
            host: String::from(host),
            port,
            #[cfg(feature = "proxy")]
            proxy: proxy.cloned(),
        }
    }
}
/// Starts building a `GET` request for `url`; shorthand for `Request::new(Method::Get, url)`.
pub fn get<T>(url: T) -> Request
where
    T: Into<URL>,
{
    Request::new(Method::Get, url)
}
/// Starts building a `HEAD` request for `url`; shorthand for `Request::new(Method::Head, url)`.
pub fn head<T>(url: T) -> Request
where
    T: Into<URL>,
{
    Request::new(Method::Head, url)
}
/// Starts building a `POST` request for `url`; shorthand for `Request::new(Method::Post, url)`.
pub fn post<T>(url: T) -> Request
where
    T: Into<URL>,
{
    Request::new(Method::Post, url)
}
/// Starts building a `PUT` request for `url`; shorthand for `Request::new(Method::Put, url)`.
pub fn put<T>(url: T) -> Request
where
    T: Into<URL>,
{
    Request::new(Method::Put, url)
}
/// Starts building a `DELETE` request for `url`; shorthand for
/// `Request::new(Method::Delete, url)`.
pub fn delete<T>(url: T) -> Request
where
    T: Into<URL>,
{
    Request::new(Method::Delete, url)
}
/// Starts building a `CONNECT` request for `url`; shorthand for
/// `Request::new(Method::Connect, url)`.
pub fn connect<T>(url: T) -> Request
where
    T: Into<URL>,
{
    Request::new(Method::Connect, url)
}
/// Starts building an `OPTIONS` request for `url`; shorthand for
/// `Request::new(Method::Options, url)`.
pub fn options<T>(url: T) -> Request
where
    T: Into<URL>,
{
    Request::new(Method::Options, url)
}
/// Starts building a `TRACE` request for `url`; shorthand for `Request::new(Method::Trace, url)`.
pub fn trace<T>(url: T) -> Request
where
    T: Into<URL>,
{
    Request::new(Method::Trace, url)
}
/// Starts building a `PATCH` request for `url`; shorthand for `Request::new(Method::Patch, url)`.
pub fn patch<T>(url: T) -> Request
where
    T: Into<URL>,
{
    Request::new(Method::Patch, url)
}
#[cfg(all(test, feature = "std"))]
mod parsing_tests {
    use alloc::collections::BTreeMap;

    use super::{get, ParsedRequest};

    /// Headers passed to `with_headers` end up unchanged in the request.
    #[test]
    fn test_headers() {
        let mut expected = BTreeMap::new();
        // The same key is inserted twice, so the map ends up holding only the second value.
        expected.insert(String::from("foo"), String::from("bar"));
        expected.insert(String::from("foo"), String::from("baz"));

        let request = get("http://www.example.org/test/res").with_headers(expected.clone());

        assert_eq!(request.headers, expected);
    }

    /// Query parameters are appended to the URL in the order they were added.
    #[test]
    fn test_multiple_params() {
        let request = get("http://www.example.org/test/res")
            .with_param("foo", "bar")
            .with_param("asd", "qwe");

        let parsed = ParsedRequest::new(request).unwrap();

        assert_eq!(parsed.url.path_and_query(), "/test/res?foo=bar&asd=qwe");
    }

    /// The base URL of the parsed request is the host name.
    #[test]
    fn test_domain() {
        let request = get("http://www.example.org/test/res").with_param("foo", "bar");

        let parsed = ParsedRequest::new(request).unwrap();

        assert_eq!(parsed.url.base_url(), "www.example.org");
    }

    /// The scheme decides whether the parsed URL counts as HTTPS.
    #[test]
    fn test_protocol() {
        let plain =
            ParsedRequest::new(get("http://www.example.org/").with_param("foo", "bar")).unwrap();
        assert!(!plain.url.is_https());

        let secure =
            ParsedRequest::new(get("https://www.example.org/").with_param("foo", "bar")).unwrap();
        assert!(secure.url.is_https());
    }
}
#[cfg(all(test, feature = "std"))]
mod encoding_tests {
    use super::{get, ParsedRequest};

    /// Parameters added with `with_param` are percent-encoded when appended to the URL.
    #[test]
    fn test_with_param() {
        let ascii = ParsedRequest::new(get("http://www.example.org").with_param("foo", "bar"))
            .unwrap();
        assert_eq!(ascii.url.path_and_query(), "/?foo=bar");

        let unicode = ParsedRequest::new(
            get("http://www.example.org").with_param("ówò", "what's this? 👀"),
        )
        .unwrap();
        assert_eq!(
            unicode.url.path_and_query(),
            "/?%C3%B3w%C3%B2=what%27s%20this%3F%20%F0%9F%91%80"
        );
    }

    /// A query already present in the URL is kept, while the fragment is not part of
    /// `path_and_query`.
    #[test]
    fn test_on_creation() {
        let parsed = ParsedRequest::new(get("http://www.example.org/?foo=bar#baz")).unwrap();
        assert_eq!(parsed.url.path_and_query(), "/?foo=bar");
    }
}