use super::{
HTTP_HEADER_CONTENT_HTML, HTTP_HEADER_CONTENT_JSON,
HTTP_HEADER_CONTENT_TEXT, HTTP_HEADER_NO_CACHE, HTTP_HEADER_NO_STORE,
HTTP_HEADER_TRANSFER_CHUNKED, HttpHeader, LOG_TARGET, get_super_ts,
new_internal_error,
};
use bytes::{Bytes, BytesMut};
use http::StatusCode;
use http::header;
use http::{HeaderName, HeaderValue};
use pingora::http::ResponseHeader;
use pingora::proxy::Session;
use serde::Serialize;
use std::pin::Pin;
use tokio::io::AsyncReadExt;
use tracing::error;
/// Builds the `Cache-Control` header for a response.
///
/// A `max_age` of `None` or `Some(0)` yields a clone of the shared
/// `HTTP_HEADER_NO_CACHE` header. Otherwise the value is
/// `"<category>, max-age=<seconds>"`, where `<category>` is `private` only
/// when `cache_private` is `Some(true)` and `public` otherwise.
fn new_cache_control_header(
    max_age: Option<u32>,
    cache_private: Option<bool>,
) -> HttpHeader {
    const MAX_AGE_PREFIX: &[u8] = b", max-age=";
    // `u32::MAX` (4294967295) has 10 decimal digits.
    const MAX_U32_DIGITS: usize = 10;

    let max_age = match max_age {
        Some(0) | None => return HTTP_HEADER_NO_CACHE.clone(),
        Some(age) => age,
    };
    let category: &[u8] = if cache_private.unwrap_or_default() {
        b"private"
    } else {
        b"public"
    };

    // Sized exactly for the longest possible value, so no reallocation occurs.
    let mut buf = BytesMut::with_capacity(
        category.len() + MAX_AGE_PREFIX.len() + MAX_U32_DIGITS,
    );
    buf.extend_from_slice(category);
    buf.extend_from_slice(MAX_AGE_PREFIX);
    buf.extend_from_slice(itoa::Buffer::new().format(max_age).as_bytes());

    // Only ASCII letters, digits, ',', ' ', '-' and '=' are written, so this
    // should always succeed; fall back to no-cache defensively.
    match HeaderValue::from_bytes(&buf) {
        Ok(value) => (header::CACHE_CONTROL, value),
        Err(_) => HTTP_HEADER_NO_CACHE.clone(),
    }
}
/// Fluent builder for [`HttpResponse`].
///
/// Every setter consumes and returns the builder; call
/// [`HttpResponseBuilder::finish`] to obtain the response.
#[must_use = "a builder does nothing until `finish()` is called"]
#[derive(Default, Debug)]
pub struct HttpResponseBuilder {
    /// The response being assembled.
    response: HttpResponse,
}

impl HttpResponseBuilder {
    /// Creates a builder for a response with `status` and every other field
    /// set to its default (empty body, no caching metadata, no headers).
    pub fn new(status: StatusCode) -> Self {
        Self {
            response: HttpResponse {
                status,
                ..Default::default()
            },
        }
    }

    /// Sets the response body, replacing any previously set body.
    pub fn body(mut self, body: impl Into<Bytes>) -> Self {
        self.response.body = body.into();
        self
    }

    /// Appends a single header to the response's extra headers.
    pub fn header(mut self, header: HttpHeader) -> Self {
        self.response
            .headers
            .get_or_insert_with(Vec::new)
            .push(header);
        self
    }

    /// Appends all of `headers` to the response's extra headers, keeping
    /// their order.
    pub fn headers(mut self, headers: Vec<HttpHeader>) -> Self {
        self.response
            .headers
            .get_or_insert_with(Vec::new)
            .extend(headers);
        self
    }

    /// Sets the `max-age` (in seconds) and whether the `Cache-Control`
    /// directive should be `private` rather than `public`.
    pub fn max_age(mut self, seconds: u32, is_private: bool) -> Self {
        self.response.max_age = Some(seconds);
        self.response.cache_private = Some(is_private);
        self
    }

    /// Appends the crate's shared `HTTP_HEADER_NO_STORE` header.
    pub fn no_store(self) -> Self {
        self.header(HTTP_HEADER_NO_STORE.clone())
    }

    /// Consumes the builder and returns the assembled response.
    pub fn finish(self) -> HttpResponse {
        self.response
    }
}
#[derive(Default, Clone, Debug)]
pub struct HttpResponse {
pub status: StatusCode,
pub body: Bytes,
pub max_age: Option<u32>,
pub created_at: Option<u32>,
pub cache_private: Option<bool>,
pub headers: Option<Vec<HttpHeader>>,
}
impl HttpResponse {
pub fn builder(status: StatusCode) -> HttpResponseBuilder {
HttpResponseBuilder::new(status)
}
pub fn no_content() -> Self {
Self::builder(StatusCode::NO_CONTENT).no_store().finish()
}
pub fn bad_request(body: impl Into<Bytes>) -> Self {
Self::builder(StatusCode::BAD_REQUEST)
.body(body)
.header(HTTP_HEADER_CONTENT_TEXT.clone())
.no_store()
.finish()
}
pub fn not_found(body: impl Into<Bytes>) -> Self {
Self::builder(StatusCode::NOT_FOUND)
.body(body)
.header(HTTP_HEADER_CONTENT_TEXT.clone())
.no_store()
.finish()
}
pub fn unknown_error(body: impl Into<Bytes>) -> Self {
Self::builder(StatusCode::INTERNAL_SERVER_ERROR)
.body(body)
.header(HTTP_HEADER_CONTENT_TEXT.clone())
.no_store()
.finish()
}
pub fn html(body: impl Into<Bytes>) -> Self {
Self::builder(StatusCode::OK)
.body(body)
.header(HTTP_HEADER_CONTENT_HTML.clone())
.header(HTTP_HEADER_NO_CACHE.clone())
.finish()
}
pub fn redirect(location: &str) -> pingora::Result<Self> {
let value = HeaderValue::from_str(location).map_err(|e| {
error!(error = e.to_string(), "to header value fail");
new_internal_error(500, e)
})?;
Ok(Self::builder(StatusCode::TEMPORARY_REDIRECT)
.header((header::LOCATION, value))
.header(HTTP_HEADER_NO_CACHE.clone())
.finish())
}
pub fn text(body: impl Into<Bytes>) -> Self {
Self::builder(StatusCode::OK)
.body(body)
.header(HTTP_HEADER_CONTENT_TEXT.clone())
.header(HTTP_HEADER_NO_CACHE.clone())
.finish()
}
pub fn try_from_json<T>(value: &T) -> pingora::Result<Self>
where
T: ?Sized + Serialize,
{
let buf = serde_json::to_vec(value).map_err(|e| {
error!(target: LOG_TARGET, error = e.to_string(), "to json fail");
new_internal_error(400, e)
})?;
Ok(Self::builder(StatusCode::OK)
.body(buf)
.header(HTTP_HEADER_CONTENT_JSON.clone())
.finish())
}
pub fn try_from_json_status<T>(
value: &T,
status: StatusCode,
) -> pingora::Result<Self>
where
T: ?Sized + Serialize,
{
let mut resp = Self::try_from_json(value)?;
resp.status = status;
Ok(resp)
}
pub fn new_response_header(&self) -> pingora::Result<ResponseHeader> {
let mut resp = ResponseHeader::build(self.status, None)?;
let mut add_header =
|name: &HeaderName, value: &HeaderValue| -> pingora::Result<()> {
resp.insert_header(name, value)?;
Ok(())
};
add_header(
&header::CONTENT_LENGTH,
&HeaderValue::from(self.body.len()),
)?;
let (name, value) =
new_cache_control_header(self.max_age, self.cache_private);
add_header(&name, &value)?;
if let Some(created_at) = self.created_at {
let secs = get_super_ts().saturating_sub(created_at);
add_header(&header::AGE, &HeaderValue::from(secs))?;
}
if let Some(headers) = &self.headers {
for (name, value) in headers {
add_header(name, value)?;
}
}
Ok(resp)
}
pub async fn send(self, session: &mut Session) -> pingora::Result<usize> {
let header = self.new_response_header()?;
let size = self.body.len();
session
.write_response_header(Box::new(header), false)
.await?;
session.write_response_body(Some(self.body), true).await?;
session.finish_body().await?;
Ok(size)
}
}
pub struct HttpChunkResponse<'r, R> {
pub reader: Pin<&'r mut R>,
pub chunk_size: usize,
pub max_age: Option<u32>,
pub cache_private: Option<bool>,
pub headers: Option<Vec<HttpHeader>>,
}
const DEFAULT_BUF_SIZE: usize = 8 * 1024;
impl<'r, R> HttpChunkResponse<'r, R>
where
R: tokio::io::AsyncRead + std::marker::Unpin,
{
pub fn new(r: &'r mut R) -> Self {
Self {
reader: Pin::new(r),
chunk_size: DEFAULT_BUF_SIZE,
max_age: None,
headers: None,
cache_private: None,
}
}
pub fn get_response_header(&self) -> pingora::Result<ResponseHeader> {
let mut resp = ResponseHeader::build(StatusCode::OK, Some(4))?;
if let Some(headers) = &self.headers {
for (name, value) in headers {
resp.insert_header(name.to_owned(), value)?;
}
}
let chunked = HTTP_HEADER_TRANSFER_CHUNKED.clone();
resp.insert_header(chunked.0, chunked.1)?;
let cache_control =
new_cache_control_header(self.max_age, self.cache_private);
resp.insert_header(cache_control.0, cache_control.1)?;
Ok(resp)
}
pub async fn send(
mut self,
session: &mut Session,
) -> pingora::Result<usize> {
let header = self.get_response_header()?;
session
.write_response_header(Box::new(header), false)
.await?;
let mut sent = 0;
let chunk_size = self.chunk_size.max(512);
let mut buffer = vec![0; chunk_size];
loop {
let size = self.reader.read(&mut buffer).await.map_err(|e| {
error!(error = e.to_string(), "read data fail");
new_internal_error(400, e)
})?;
let end = size < chunk_size;
session
.write_response_body(
Some(Bytes::copy_from_slice(&buffer[..size])),
end,
)
.await?;
sent += size;
if end {
break;
}
}
session.finish_body().await?;
Ok(sent)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::convert_headers;
    use bytes::Bytes;
    use http::StatusCode;
    use pretty_assertions::assert_eq;
    use serde::Serialize;
    use std::io::Write;
    use tempfile::NamedTempFile;
    use tokio::fs;

    #[test]
    fn test_new_cache_control_header() {
        assert_eq!(
            r###"("cache-control", "private, max-age=3600")"###,
            format!("{:?}", new_cache_control_header(Some(3600), Some(true)))
        );
        assert_eq!(
            r###"("cache-control", "public, max-age=3600")"###,
            format!("{:?}", new_cache_control_header(Some(3600), None))
        );
        assert_eq!(
            r###"("cache-control", "private, no-cache")"###,
            format!("{:?}", new_cache_control_header(Some(0), Some(false)))
        );
        assert_eq!(
            r###"("cache-control", "private, no-cache")"###,
            format!("{:?}", new_cache_control_header(None, None))
        );
    }

    #[test]
    fn test_http_response() {
        assert_eq!(
            r###"HttpResponse { status: 204, body: b"", max_age: None, created_at: None, cache_private: None, headers: Some([("cache-control", "private, no-store")]) }"###,
            format!("{:?}", HttpResponse::no_content())
        );
        assert_eq!(
            r###"HttpResponse { status: 404, body: b"Not Found", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "text/plain; charset=utf-8"), ("cache-control", "private, no-store")]) }"###,
            format!("{:?}", HttpResponse::not_found("Not Found"))
        );
        assert_eq!(
            r###"HttpResponse { status: 500, body: b"Unknown Error", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "text/plain; charset=utf-8"), ("cache-control", "private, no-store")]) }"###,
            format!("{:?}", HttpResponse::unknown_error("Unknown Error"))
        );
        assert_eq!(
            r###"HttpResponse { status: 400, body: b"Bad Request", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "text/plain; charset=utf-8"), ("cache-control", "private, no-store")]) }"###,
            format!("{:?}", HttpResponse::bad_request("Bad Request"))
        );
        assert_eq!(
            r###"HttpResponse { status: 200, body: b"<p>Pingap</p>", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "text/html; charset=utf-8"), ("cache-control", "private, no-cache")]) }"###,
            format!("{:?}", HttpResponse::html("<p>Pingap</p>"))
        );
        assert_eq!(
            r###"HttpResponse { status: 307, body: b"", max_age: None, created_at: None, cache_private: None, headers: Some([("location", "http://example.com/"), ("cache-control", "private, no-cache")]) }"###,
            format!(
                "{:?}",
                HttpResponse::redirect("http://example.com/").unwrap()
            )
        );
        assert_eq!(
            r###"HttpResponse { status: 200, body: b"Hello World!", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "text/plain; charset=utf-8"), ("cache-control", "private, no-cache")]) }"###,
            format!("{:?}", HttpResponse::text("Hello World!"))
        );

        #[derive(Serialize)]
        struct Data {
            message: String,
        }
        let resp = HttpResponse::try_from_json_status(
            &Data {
                message: "Hello World!".to_string(),
            },
            StatusCode::BAD_REQUEST,
        )
        .unwrap();
        assert_eq!(
            r###"HttpResponse { status: 400, body: b"{\"message\":\"Hello World!\"}", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "application/json; charset=utf-8")]) }"###,
            format!("{resp:?}")
        );
        let resp = HttpResponse::try_from_json(&Data {
            message: "Hello World!".to_string(),
        })
        .unwrap();
        assert_eq!(
            r###"HttpResponse { status: 200, body: b"{\"message\":\"Hello World!\"}", max_age: None, created_at: None, cache_private: None, headers: Some([("content-type", "application/json; charset=utf-8")]) }"###,
            format!("{resp:?}")
        );

        let resp = HttpResponse {
            status: StatusCode::OK,
            body: Bytes::from("Hello world!"),
            max_age: Some(3600),
            created_at: Some(0),
            cache_private: Some(true),
            headers: Some(
                convert_headers(&[
                    "Contont-Type: application/json".to_string(),
                    "Content-Encoding: gzip".to_string(),
                ])
                .unwrap(),
            ),
        };
        let mut header = resp.new_response_header().unwrap();
        // `Age` depends on the current time, so only check it is present
        // before removing it from the snapshot comparison.
        assert!(!header.headers.get("Age").unwrap().is_empty());
        header.remove_header("Age").unwrap();
        assert_eq!(
            r###"ResponseHeader { base: Parts { status: 200, version: HTTP/1.1, headers: {"content-length": "12", "cache-control": "private, max-age=3600", "content-encoding": "gzip", "contont-type": "application/json"} }, header_name_map: Some({"content-length": CaseHeaderName(b"Content-Length"), "cache-control": CaseHeaderName(b"Cache-Control"), "content-encoding": CaseHeaderName(b"Content-Encoding"), "contont-type": CaseHeaderName(b"contont-type")}), reason_phrase: None }"###,
            format!("{header:?}")
        );
    }

    #[tokio::test]
    async fn test_http_chunk_response() {
        // The payload is irrelevant to the header assertions below. A fixed
        // literal keeps the test independent of this source file's name.
        let payload: &[u8] = b"<p>Pingap chunk response test</p>";
        let mut f = NamedTempFile::new().unwrap();
        f.write_all(payload).unwrap();
        let mut f = fs::OpenOptions::new().read(true).open(f).await.unwrap();
        let mut resp = HttpChunkResponse::new(&mut f);
        resp.max_age = Some(3600);
        resp.cache_private = Some(false);
        resp.headers = Some(
            convert_headers(&["Contont-Type: text/html".to_string()]).unwrap(),
        );
        let header = resp.get_response_header().unwrap();
        assert_eq!(
            r###"ResponseHeader { base: Parts { status: 200, version: HTTP/1.1, headers: {"contont-type": "text/html", "transfer-encoding": "chunked", "cache-control": "public, max-age=3600"} }, header_name_map: Some({"contont-type": CaseHeaderName(b"contont-type"), "transfer-encoding": CaseHeaderName(b"Transfer-Encoding"), "cache-control": CaseHeaderName(b"Cache-Control")}), reason_phrase: None }"###,
            format!("{header:?}")
        );
    }

    #[test]
    fn test_new_cache_control_header_logic() {
        let (name, value) = new_cache_control_header(Some(3600), Some(true));
        assert_eq!(name, header::CACHE_CONTROL);
        assert_eq!(value.to_str().unwrap(), "private, max-age=3600");

        let (name, value) = new_cache_control_header(Some(3600), Some(false));
        assert_eq!(name, header::CACHE_CONTROL);
        assert_eq!(value.to_str().unwrap(), "public, max-age=3600");

        let (name, value) = new_cache_control_header(Some(3600), None);
        assert_eq!(name, header::CACHE_CONTROL);
        assert_eq!(value.to_str().unwrap(), "public, max-age=3600");

        let (name, value) = new_cache_control_header(Some(0), Some(true));
        assert_eq!(name, header::CACHE_CONTROL);
        assert_eq!(value, HTTP_HEADER_NO_CACHE.clone().1);

        let (name, value) = new_cache_control_header(None, Some(false));
        assert_eq!(name, header::CACHE_CONTROL);
        assert_eq!(value, HTTP_HEADER_NO_CACHE.clone().1);
    }

    #[test]
    fn test_http_response_builder_pattern() {
        let etag_header = (header::ETAG, HeaderValue::from_static("\"12345\""));
        let server_header =
            (header::SERVER, HeaderValue::from_static("MyTestServer"));
        let response = HttpResponse::builder(StatusCode::OK)
            .body("Test Body")
            .header(etag_header.clone())
            .headers(vec![server_header.clone()])
            .max_age(60, true)
            .finish();
        assert_eq!(response.status, StatusCode::OK);
        assert_eq!(response.body, Bytes::from("Test Body"));
        assert_eq!(response.max_age, Some(60));
        assert_eq!(response.cache_private, Some(true));
        let headers = response.headers.unwrap();
        assert_eq!(headers.len(), 2);
        assert!(headers.contains(&etag_header));
        assert!(headers.contains(&server_header));

        let no_store_response = HttpResponse::builder(StatusCode::ACCEPTED)
            .no_store()
            .finish();
        assert_eq!(no_store_response.status, StatusCode::ACCEPTED);
        assert!(
            no_store_response
                .headers
                .unwrap()
                .contains(&HTTP_HEADER_NO_STORE.clone())
        );
    }

    #[test]
    fn test_http_response_error_cases() {
        let invalid_location = "http://example.com/\0";
        let result = HttpResponse::redirect(invalid_location);
        assert!(result.is_err());
    }

    #[test]
    fn test_new_response_header_generation() {
        let resp = HttpResponse {
            status: StatusCode::OK,
            body: Bytes::from("Hello world!"),
            max_age: Some(3600),
            created_at: Some(get_super_ts().saturating_sub(10)),
            cache_private: Some(true),
            headers: Some(vec![(
                header::CONTENT_ENCODING,
                HeaderValue::from_static("gzip"),
            )]),
        };
        let header = resp.new_response_header().unwrap();
        let headers_map: std::collections::HashMap<_, _> =
            header.headers.iter().collect();
        assert_eq!(header.status, StatusCode::OK);
        assert_eq!(
            headers_map
                .get(&header::CONTENT_LENGTH)
                .unwrap()
                .to_str()
                .unwrap(),
            "12"
        );
        assert_eq!(
            headers_map
                .get(&header::CACHE_CONTROL)
                .unwrap()
                .to_str()
                .unwrap(),
            "private, max-age=3600"
        );
        assert_eq!(
            headers_map
                .get(&header::CONTENT_ENCODING)
                .unwrap()
                .to_str()
                .unwrap(),
            "gzip"
        );
    }
}