#![allow(missing_debug_implementations)]
use super::NetworkScheme;
use crate::error::BoxError;
use http::{
Error, HeaderMap, HeaderName, HeaderValue, Method, Request, Uri, Version,
header::CONTENT_LENGTH, request::Builder,
};
use http_body::Body;
use std::{any::Any, marker::PhantomData};
pub struct InnerRequest<B>
where
B: Body + Send + Unpin + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
request: Request<B>,
version: Option<Version>,
network_scheme: NetworkScheme,
}
impl<B> InnerRequest<B>
where
B: Body + Send + Unpin + 'static,
B::Data: Send,
B::Error: Into<BoxError>,
{
pub fn builder<'a>() -> InnerRequestBuilder<'a, B> {
InnerRequestBuilder {
builder: Request::builder(),
version: None,
network_scheme: Default::default(),
headers_order: None,
_body: PhantomData,
}
}
pub fn pieces(self) -> (Request<B>, Option<Version>, NetworkScheme) {
(self.request, self.version, self.network_scheme)
}
}
/// Builder for [`InnerRequest`].
///
/// Wraps an `http::request::Builder` and also records settings the plain builder
/// does not keep: the explicitly requested HTTP version, the network scheme, and
/// an optional header ordering that is applied when the body is attached.
pub struct InnerRequestBuilder<'a, B>
where
    B: Body + Send + Unpin + 'static,
    B::Data: Send,
    B::Error: Into<BoxError>,
{
    /// Underlying `http` request builder.
    builder: Builder,
    /// HTTP version set through [`InnerRequestBuilder::version`], if any.
    version: Option<Version>,
    /// Network scheme carried through to the built request.
    network_scheme: NetworkScheme,
    /// Optional header ordering applied in [`InnerRequestBuilder::body`].
    headers_order: Option<&'a [HeaderName]>,
    /// Ties the builder to the body type `B` without storing a body.
    _body: PhantomData<B>,
}

impl<'a, B> InnerRequestBuilder<'a, B>
where
    B: Body + Send + Unpin + 'static,
    B::Data: Send,
    B::Error: Into<BoxError>,
{
    /// Sets the request method.
    #[inline]
    pub fn method(self, method: Method) -> Self {
        Self {
            builder: self.builder.method(method),
            ..self
        }
    }

    /// Sets the request URI.
    #[inline]
    pub fn uri(self, uri: Uri) -> Self {
        Self {
            builder: self.builder.uri(uri),
            ..self
        }
    }

    /// Sets the HTTP version when `version` is `Some`.
    ///
    /// `None` leaves both the underlying builder and the recorded version as
    /// they were.
    #[inline]
    pub fn version(self, version: Option<Version>) -> Self {
        match version {
            Some(v) => Self {
                builder: self.builder.version(v),
                version: Some(v),
                ..self
            },
            None => self,
        }
    }

    /// Replaces the request's header map with `headers`.
    ///
    /// The previous map is discarded, not merged. If the underlying builder
    /// exposes no header map (it has already recorded an error), `headers` is
    /// simply dropped.
    #[inline]
    pub fn headers(mut self, headers: HeaderMap) -> Self {
        if let Some(current) = self.builder.headers_mut() {
            *current = headers;
        }
        self
    }

    /// Sets (or clears, with `None`) the header ordering used by
    /// [`InnerRequestBuilder::body`].
    #[inline]
    pub fn headers_order(self, order: Option<&'a [HeaderName]>) -> Self {
        Self {
            headers_order: order,
            ..self
        }
    }

    /// Adds `extension` to the request's extensions when it is `Some`.
    ///
    /// `None` is a no-op.
    #[inline]
    pub fn extension<T>(self, extension: Option<T>) -> Self
    where
        T: Clone + Any + Send + Sync + 'static,
    {
        match extension {
            Some(value) => Self {
                builder: self.builder.extension(value),
                ..self
            },
            None => self,
        }
    }

    /// Sets the network scheme carried into the built [`InnerRequest`].
    #[inline]
    pub fn network_scheme(self, network_scheme: NetworkScheme) -> Self {
        Self {
            network_scheme,
            ..self
        }
    }

    /// Attaches `body` and finishes the request.
    ///
    /// When a header order was supplied through
    /// [`InnerRequestBuilder::headers_order`], two things happen first:
    /// - a `Content-Length` header is added from the body's exact size hint,
    ///   if one is known and the header is not already set;
    /// - the headers are reordered according to that order.
    ///
    /// Without a header order the headers are passed through untouched.
    ///
    /// # Errors
    ///
    /// Returns any error recorded by the underlying `http` request builder.
    #[inline]
    pub fn body(mut self, body: B) -> Result<InnerRequest<B>, Error> {
        if let Some(order) = self.headers_order {
            // NOTE(review): Content-Length is only pre-filled on this path,
            // presumably so that it takes part in the ordering. Confirm that the
            // transport adds it otherwise.
            // The method is cloned because `headers_mut` needs a mutable borrow
            // of the same builder.
            let method = self.builder.method_ref().cloned();
            if let (Some(method), Some(headers)) = (method, self.builder.headers_mut()) {
                add_content_length_header(method, headers, &body);
                sort_headers(headers, order);
            }
        }

        let Self {
            builder,
            version,
            network_scheme,
            ..
        } = self;
        let request = builder.body(body)?;
        Ok(InnerRequest {
            request,
            version,
            network_scheme,
        })
    }
}
/// Inserts a `Content-Length` header derived from `body`'s size hint.
///
/// The header is never overwritten if it is already present. Nothing is added
/// when the body's size is not known exactly. A zero-length body only gets
/// `Content-Length: 0` when `method_has_defined_payload_semantics(method)` is
/// true, i.e. for methods other than GET, HEAD, DELETE and CONNECT.
#[inline]
fn add_content_length_header<B>(method: Method, headers: &mut HeaderMap, body: &B)
where
    B: Body,
{
    let wanted_len = body
        .size_hint()
        .exact()
        .filter(|&len| len > 0 || method_has_defined_payload_semantics(method));

    if let Some(len) = wanted_len {
        headers
            .entry(CONTENT_LENGTH)
            .or_insert_with(|| HeaderValue::from(len));
    }
}
/// Returns whether `method` is treated as having defined request-payload semantics.
///
/// Returns `false` for GET, HEAD, DELETE and CONNECT, and `true` for every other
/// method. `add_content_length_header` uses this to decide whether an empty
/// body still gets `Content-Length: 0`.
#[inline]
fn method_has_defined_payload_semantics(method: Method) -> bool {
    let payloadless = [Method::GET, Method::HEAD, Method::DELETE, Method::CONNECT];
    !payloadless.contains(&method)
}
/// Reorders `headers` in place.
///
/// Names listed in `headers_order` come first, in that order. They are followed
/// by every other header in the map's existing iteration order.
///
/// - All values of a multi-valued header are kept, whether or not the name is
///   listed.
/// - Names in `headers_order` that are absent from `headers` are ignored.
/// - A name repeated in `headers_order` is emitted only once.
/// - Maps holding zero or one value are returned unchanged.
#[inline]
fn sort_headers(headers: &mut HeaderMap, headers_order: &[HeaderName]) {
    if headers.len() <= 1 {
        return;
    }

    let mut sorted_headers = HeaderMap::with_capacity(headers.keys_len());

    // Listed names first, with all of their values. The source map is not
    // modified here: `HeaderMap::remove` may move other entries around, which
    // would scramble the order of the unlisted headers emitted below.
    for name in headers_order {
        if sorted_headers.contains_key(name) {
            // Duplicate entry in `headers_order`; its values are already placed.
            continue;
        }
        for value in headers.get_all(name) {
            sorted_headers.append(name.clone(), value.clone());
        }
    }

    // Then every header that was not listed. `drain` yields `Some(name)` only for
    // the first value of each header; further values of the same header arrive
    // with `None`. Remember the current name so those values are not lost.
    let mut current_name: Option<HeaderName> = None;
    let mut already_placed = false;
    for (name, value) in headers.drain() {
        if let Some(name) = name {
            // New header group: skip it entirely if it was placed above.
            already_placed = sorted_headers.contains_key(&name);
            current_name = Some(name);
        }
        if already_placed {
            continue;
        }
        if let Some(name) = &current_name {
            sorted_headers.append(name.clone(), value);
        }
    }

    *headers = sorted_headers;
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::header::{HeaderMap, HeaderName, HeaderValue};

    /// Builds a `HeaderMap` by inserting each `(name, value)` pair in turn.
    fn header_map(pairs: &[(&'static str, &'static str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, HeaderValue::from_static(value));
        }
        map
    }

    /// Turns static lowercase names into an ordering list for `sort_headers`.
    fn order(names: &[&'static str]) -> Vec<HeaderName> {
        names.iter().copied().map(HeaderName::from_static).collect()
    }

    /// Asserts that iterating `headers` yields exactly `expected`, in order.
    fn assert_order(headers: &HeaderMap, expected: &[(&'static str, &'static str)]) {
        let actual: Vec<(HeaderName, HeaderValue)> = headers
            .iter()
            .map(|(name, value)| (name.clone(), value.clone()))
            .collect();
        let wanted: Vec<(HeaderName, HeaderValue)> = expected
            .iter()
            .map(|&(name, value)| {
                (
                    HeaderName::from_static(name),
                    HeaderValue::from_static(value),
                )
            })
            .collect();
        assert_eq!(actual, wanted);
    }

    #[test]
    fn test_sort_headers() {
        let mut headers = header_map(&[
            ("b-header", "b"),
            ("a-header", "a"),
            ("c-header", "c"),
            ("extra-header", "extra"),
        ]);
        sort_headers(&mut headers, &order(&["a-header", "b-header", "c-header"]));
        assert_order(
            &headers,
            &[
                ("a-header", "a"),
                ("b-header", "b"),
                ("c-header", "c"),
                ("extra-header", "extra"),
            ],
        );
    }

    #[test]
    fn test_sort_headers_partial_match() {
        let mut headers = header_map(&[("x-header", "x"), ("y-header", "y")]);
        sort_headers(&mut headers, &order(&["y-header", "z-header"]));
        assert_order(&headers, &[("y-header", "y"), ("x-header", "x")]);
    }

    #[test]
    fn test_sort_headers_empty() {
        let mut headers = HeaderMap::new();
        sort_headers(&mut headers, &[]);
        assert!(headers.is_empty());
    }

    #[test]
    fn test_sort_headers_no_ordering() {
        let mut headers = header_map(&[("random-header", "random")]);
        sort_headers(&mut headers, &[]);
        assert_order(&headers, &[("random-header", "random")]);
    }
}