use std::mem;
use http::uri::PathAndQuery;
use http::{header, HeaderMap, HeaderName, HeaderValue, Method, Request, Uri, Version};
use crate::body::BodyWriter;
use crate::ext::MethodExt;
use crate::util::compare_lowercase_ascii;
use crate::Error;
/// A request plus a set of pending amendments (URI override, added headers,
/// removed headers) layered on top of it without mutating the original.
pub(crate) struct AmendedRequest<Body> {
    // The original request. The body is wrapped in `Option` so that
    // `take_request` can move it out, leaving `None` behind.
    request: Request<Option<Body>>,
    // Replacement URI set via `set_uri`; `None` means "use `request`'s URI".
    uri: Option<Uri>,
    // Headers added via `set_header`; yielded before the original request's headers.
    headers: Vec<(HeaderName, HeaderValue)>,
    // Header names removed via `unset_header`; filtered out of both the added
    // headers and the original request's headers.
    unset: Vec<HeaderName>,
}
impl<Body> AmendedRequest<Body> {
pub fn new(request: Request<Body>) -> Self {
let (parts, body) = request.into_parts();
AmendedRequest {
request: Request::from_parts(parts, Some(body)),
uri: None,
headers: vec![],
unset: vec![],
}
}
pub fn take_request(&mut self) -> Request<Body> {
let request = mem::replace(&mut self.request, Request::new(None));
let (parts, body) = request.into_parts();
Request::from_parts(parts, body.unwrap())
}
pub fn set_uri(&mut self, uri: Uri) {
self.uri = Some(uri);
}
pub fn uri(&self) -> &Uri {
if let Some(uri) = &self.uri {
uri
} else {
self.request.uri()
}
}
pub fn prelude(&self) -> (&Method, &str, Version) {
let r = &self.request;
(
r.method(),
self.uri()
.path_and_query()
.map(|p| p.as_str())
.unwrap_or("/"),
r.version(),
)
}
pub fn set_header<K, V>(&mut self, name: K, value: V) -> Result<(), Error>
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
let name = <HeaderName as TryFrom<K>>::try_from(name)
.map_err(Into::into)
.map_err(|e| Error::BadHeader(e.to_string()))?;
let value = <HeaderValue as TryFrom<V>>::try_from(value)
.map_err(Into::into)
.map_err(|e| Error::BadHeader(e.to_string()))?;
self.headers.push((name, value));
Ok(())
}
pub fn unset_header<K>(&mut self, name: K) -> Result<(), Error>
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
{
let name = <HeaderName as TryFrom<K>>::try_from(name)
.map_err(Into::into)
.map_err(|e| Error::BadHeader(e.to_string()))?;
self.unset.push(name);
Ok(())
}
pub fn original_request_headers(&self) -> &HeaderMap {
self.request.headers()
}
pub fn headers(&self) -> impl Iterator<Item = (&HeaderName, &HeaderValue)> {
self.headers
.iter()
.map(|v| (&v.0, &v.1))
.chain(self.request.headers().iter())
.filter(|v| !self.unset.iter().any(|x| x == v.0))
}
fn headers_get_all(&self, key: HeaderName) -> impl Iterator<Item = &HeaderValue> {
self.headers()
.filter(move |(k, _)| *k == key)
.map(|(_, v)| v)
}
fn headers_get(&self, key: HeaderName) -> Option<&HeaderValue> {
self.headers_get_all(key).next()
}
pub fn headers_len(&self) -> usize {
self.headers().count()
}
#[cfg(test)]
pub fn headers_vec(&self) -> Vec<(&str, &str)> {
self.headers()
.map(|(k, v)| (k.as_str(), v.to_str().unwrap()))
.collect()
}
pub fn method(&self) -> &Method {
self.request.method()
}
pub(crate) fn version(&self) -> Version {
self.request.version()
}
pub fn new_uri_from_location(&self, location: &str) -> Result<Uri, Error> {
let base = self.uri().clone();
join(base, location)
}
pub fn analyze(
&self,
wanted_mode: BodyWriter,
skip_method_body_check: bool,
allow_non_standard_methods: bool,
) -> Result<RequestInfo, Error> {
let v = self.request.version();
let m = self.method();
if !allow_non_standard_methods {
m.verify_version(v)?;
}
let count_host = self.headers_get_all(header::HOST).count();
if count_host > 1 {
return Err(Error::TooManyHostHeaders);
}
let count_len = self.headers_get_all(header::CONTENT_LENGTH).count();
if count_len > 1 {
return Err(Error::TooManyContentLengthHeaders);
}
let mut req_host_header = false;
if let Some(h) = self.headers_get(header::HOST) {
h.to_str().map_err(|_| Error::BadHostHeader)?;
req_host_header = true;
}
let mut req_auth_header = false;
if let Some(h) = self.headers_get(header::AUTHORIZATION) {
h.to_str().map_err(|_| Error::BadAuthorizationHeader)?;
req_auth_header = true;
}
let mut content_length: Option<u64> = None;
if let Some(h) = self.headers_get(header::CONTENT_LENGTH) {
let n = h
.to_str()
.ok()
.and_then(|s| s.parse::<u64>().ok())
.ok_or(Error::BadContentLengthHeader)?;
content_length = Some(n);
}
let has_chunked = self
.headers_get_all(header::TRANSFER_ENCODING)
.filter_map(|v| v.to_str().ok())
.any(|v| compare_lowercase_ascii(v, "chunked"));
let mut req_body_header = false;
let body_mode = if has_chunked {
req_body_header = true;
BodyWriter::new_chunked()
} else if let Some(n) = content_length {
req_body_header = true;
BodyWriter::new_sized(n)
} else {
wanted_mode
};
if !skip_method_body_check {
let need_body = self.method().need_request_body();
let has_body = body_mode.has_body();
if !need_body && has_body {
return Err(Error::MethodForbidsBody(self.method().clone()));
} else if need_body && !has_body {
return Err(Error::MethodRequiresBody(self.method().clone()));
}
}
Ok(RequestInfo {
body_mode,
req_host_header,
req_auth_header,
req_body_header,
})
}
}
fn join(base: Uri, location: &str) -> Result<Uri, Error> {
let mut parts = base.into_parts();
let maybe = location.parse::<Uri>();
let has_scheme = maybe
.as_ref()
.ok()
.map(|u| u.scheme().is_some())
.unwrap_or(false);
if has_scheme {
return Ok(maybe.unwrap());
}
if location.starts_with("/") {
let pq: PathAndQuery = location
.parse()
.map_err(|_| Error::BadLocationHeader(location.to_string()))?;
parts.path_and_query = Some(pq);
} else {
let base_path = parts
.path_and_query
.as_ref()
.map(|p| p.path())
.unwrap_or("/");
let total_path = join_relative(base_path, location)?;
let pq: PathAndQuery = total_path
.parse()
.map_err(|_| Error::BadLocationHeader(location.to_string()))?;
parts.path_and_query = Some(pq);
}
let uri = Uri::from_parts(parts).map_err(|_| Error::BadLocationHeader(location.to_string()))?;
Ok(uri)
}
/// Resolves the relative path reference `location` against `base_path`.
///
/// The last segment of `base_path` (the "file" part) is dropped. Each
/// `/`-separated segment of `location` is then applied in turn:
/// * `.` is skipped,
/// * `..` removes one directory level,
/// * anything else is appended.
///
/// When `location` ends in a `.` or `..` segment, the result keeps a trailing
/// `/`, as RFC 3986 §5.2.4 requires. For example, `..` against `/a/b/c` gives
/// `/a/`, not `/a`, and `../..` against `/a/b` gives `/`, not an empty string.
///
/// # Errors
///
/// Returns [`Error::BadLocationHeader`] if `..` would climb above the root.
///
/// # Panics
///
/// Panics if `base_path` is empty or `location` starts with `/`. Absolute
/// paths are handled by the caller.
fn join_relative(base_path: &str, location: &str) -> Result<String, Error> {
    assert!(!base_path.is_empty());
    assert!(!location.starts_with('/'));

    let mut segments: Vec<&str> = base_path.split('/').collect();
    if segments.len() > 1 {
        segments.pop();
    }

    // Tracks whether the most recent segment was `.` or `..`, which means the
    // result names a directory and needs a trailing slash.
    let mut ends_in_dot_segment = false;
    for segment in location.split('/') {
        ends_in_dot_segment = match segment {
            "." => true,
            ".." => {
                if segments.len() == 1 {
                    trace!("Location is relative above root");
                    return Err(Error::BadLocationHeader(location.to_string()));
                }
                segments.pop();
                true
            }
            _ => {
                segments.push(segment);
                false
            }
        };
    }

    if ends_in_dot_segment {
        segments.push("");
    }

    Ok(segments.join("/"))
}
/// Outcome of `AmendedRequest::analyze`: how the body will be written, and
/// which notable headers the amended request already carries.
pub(crate) struct RequestInfo {
    /// Body framing chosen from `Transfer-Encoding` / `Content-Length`, or the
    /// caller's `wanted_mode` when neither applies.
    pub body_mode: BodyWriter,
    /// `true` if the amended request has a `Host` header.
    pub req_host_header: bool,
    /// `true` if the amended request has an `Authorization` header.
    pub req_auth_header: bool,
    /// `true` if `body_mode` was derived from a `Transfer-Encoding` or
    /// `Content-Length` header rather than from `wanted_mode`.
    pub req_body_header: bool,
}
#[cfg(test)]
mod test {
    use super::*;

    /// Joins `location` onto `base` and renders the result, or `None` on error.
    fn joined(base: &str, location: &str) -> Option<String> {
        let base: Uri = base.parse().unwrap();
        join(base, location).ok().map(|uri| uri.to_string())
    }

    #[test]
    fn join_things() {
        // A relative reference parsed on its own has no scheme. That is why
        // `join` must resolve it against the base instead of returning it as-is.
        let uri: Uri = "foo.html".parse().unwrap();
        assert!(uri.scheme().is_none());
    }

    #[test]
    fn join_absolute_uri_replaces_base() {
        assert_eq!(
            joined("http://x.com/a/b", "https://y.com/z").as_deref(),
            Some("https://y.com/z")
        );
    }

    #[test]
    fn join_absolute_path_keeps_authority() {
        assert_eq!(
            joined("http://x.com/a/b", "/c?d=1").as_deref(),
            Some("http://x.com/c?d=1")
        );
    }

    #[test]
    fn join_relative_path_uses_base_directory() {
        assert_eq!(
            joined("http://x.com/a/b?q=1", "c").as_deref(),
            Some("http://x.com/a/c")
        );
        assert_eq!(
            joined("http://x.com/a/b", "../c").as_deref(),
            Some("http://x.com/c")
        );
    }

    #[test]
    fn join_relative_segments() {
        assert_eq!(join_relative("/a/b", "c").ok().as_deref(), Some("/a/c"));
        assert_eq!(join_relative("/", "c").ok().as_deref(), Some("/c"));
        assert_eq!(join_relative("/a/b", "./c").ok().as_deref(), Some("/a/c"));
        assert_eq!(
            join_relative("/a/b/c", "../d/").ok().as_deref(),
            Some("/a/d/")
        );
    }

    #[test]
    fn join_relative_rejects_climbing_above_root() {
        assert!(join_relative("/a", "..").is_err());
        assert!(joined("http://x.com/a", "../../c").is_none());
    }
}