use std::fmt;
use http::{HeaderName, HeaderValue, Response, StatusCode, Version, header};
use crate::Error;
use crate::body::BodyWriter;
use crate::util::compare_lowercase_ascii;
/// An `http::Response<()>` paired with a separate list of extra ("amended") headers.
///
/// Headers added through `set_header` are stored in `headers` rather than in the
/// wrapped response's own header map. `headers()` yields the amended headers
/// first and the wrapped response's headers after them.
pub(crate) struct AmendedResponse {
// The wrapped response. The methods visible in this file only read from it.
response: Response<()>,
// Amended headers in the order they were pushed by `set_header`.
headers: Vec<(HeaderName, HeaderValue)>,
}
impl AmendedResponse {
pub fn new(response: Response<()>) -> Self {
AmendedResponse {
response,
headers: vec![],
}
}
pub fn prelude(&self) -> (Version, StatusCode) {
let r = &self.response;
(r.version(), r.status())
}
pub fn set_header<K, V>(&mut self, name: K, value: V) -> Result<(), Error>
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
let name = <HeaderName as TryFrom<K>>::try_from(name)
.map_err(Into::into)
.map_err(|e| Error::BadHeader(e.to_string()))?;
let value = <HeaderValue as TryFrom<V>>::try_from(value)
.map_err(Into::into)
.map_err(|e| Error::BadHeader(e.to_string()))?;
self.headers.push((name, value));
Ok(())
}
pub fn headers(&self) -> impl Iterator<Item = (&HeaderName, &HeaderValue)> {
self.headers
.iter()
.map(|v| (&v.0, &v.1))
.chain(self.response.headers().iter())
}
fn headers_get_all(&self, key: HeaderName) -> impl Iterator<Item = &HeaderValue> {
self.headers()
.filter(move |(k, _)| *k == key)
.map(|(_, v)| v)
}
fn headers_get(&self, key: HeaderName) -> Option<&HeaderValue> {
self.headers_get_all(key).next()
}
pub fn headers_len(&self) -> usize {
self.headers().count()
}
pub fn analyze(&self, wanted_mode: BodyWriter) -> Result<ResponseInfo, Error> {
let count_len = self.headers_get_all(header::CONTENT_LENGTH).count();
if count_len > 1 {
return Err(Error::TooManyContentLengthHeaders);
}
let mut content_length: Option<u64> = None;
if let Some(h) = self.headers_get(header::CONTENT_LENGTH) {
let n = h
.to_str()
.ok()
.and_then(|s| s.parse::<u64>().ok())
.ok_or(Error::BadContentLengthHeader)?;
content_length = Some(n);
}
let has_chunked = self
.headers_get_all(header::TRANSFER_ENCODING)
.filter_map(|v| v.to_str().ok())
.any(|v| compare_lowercase_ascii(v, "chunked"));
let mut res_body_header = false;
let body_mode = if has_chunked {
res_body_header = true;
BodyWriter::new_chunked()
} else if let Some(n) = content_length {
res_body_header = true;
BodyWriter::new_sized(n)
} else {
wanted_mode
};
Ok(ResponseInfo {
body_mode,
res_body_header,
})
}
}
/// Outcome of `AmendedResponse::analyze`.
pub(crate) struct ResponseInfo {
// The body writer selected by `analyze`: chunked, sized, or the caller's
// `wanted_mode` when no framing header decided it.
pub body_mode: BodyWriter,
// `true` when `analyze` picked the body mode from a chunked
// `Transfer-Encoding` or a `Content-Length` header; `false` when it fell
// back to the caller's `wanted_mode`.
pub res_body_header: bool,
}
impl fmt::Debug for AmendedResponse {
    /// Formats as an `AmendedResponse` struct showing the wrapped response's
    /// status and the amended header list.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let status = self.response.status();

        let mut out = f.debug_struct("AmendedResponse");
        out.field("status", &status);
        out.field("headers", &self.headers);
        out.finish()
    }
}