#![warn(clippy::all)]
#![allow(clippy::new_without_default)]
#![allow(clippy::type_complexity)]
pub mod dict;
mod thread_zstd;
use bytes::BufMut;
use http::Version;
use pingora_error::{Error, ErrorType, ImmutStr, Result};
use pingora_http::ResponseHeader;
use std::cell::RefCell;
use std::ops::DerefMut;
use thread_local::ThreadLocal;
/// Serializes [`ResponseHeader`]s into a compact zstd-compressed form and back.
///
/// A header is first written in HTTP/1.x wire format into a per-thread scratch buffer and
/// then compressed; deserialization decompresses into that buffer and parses it.
pub struct HeaderSerde {
    // zstd codec, either plain or dictionary-based
    compression: ZstdCompression,
    // one scratch buffer per thread, cleared and reused across calls on that thread
    buf: ThreadLocal<RefCell<Vec<u8>>>,
}
/// Capacity (128 KiB) with which each thread's scratch buffer is first allocated.
const MAX_HEADER_BUF_SIZE: usize = 128 * 1024;
/// zstd compression level used for both the plain and the dictionary codec.
const COMPRESS_LEVEL: i32 = 3;
impl HeaderSerde {
pub fn new(dict: Option<Vec<u8>>) -> Self {
if let Some(dict) = dict {
HeaderSerde {
compression: ZstdCompression::WithDict(thread_zstd::CompressionWithDict::new(
&dict,
COMPRESS_LEVEL,
)),
buf: ThreadLocal::new(),
}
} else {
HeaderSerde {
compression: ZstdCompression::Default(
thread_zstd::Compression::new(),
COMPRESS_LEVEL,
),
buf: ThreadLocal::new(),
}
}
}
pub fn serialize(&self, header: &ResponseHeader) -> Result<Vec<u8>> {
let mut buf = self
.buf
.get_or(|| RefCell::new(Vec::with_capacity(MAX_HEADER_BUF_SIZE)))
.borrow_mut();
buf.clear(); resp_header_to_buf(header, &mut buf);
self.compression.compress(&buf)
}
pub fn deserialize(&self, data: &[u8]) -> Result<ResponseHeader> {
let mut buf = self
.buf
.get_or(|| RefCell::new(Vec::with_capacity(MAX_HEADER_BUF_SIZE)))
.borrow_mut();
buf.clear(); self.compression
.decompress_to_buffer(data, buf.deref_mut())?;
buf_to_http_header(&buf)
}
}
/// The zstd codec used by [`HeaderSerde`].
enum ZstdCompression {
    /// Plain zstd; the `i32` is the level passed to every `compress` call.
    Default(thread_zstd::Compression, i32),
    /// zstd with a caller-supplied dictionary; the level is fixed when the codec is built.
    WithDict(thread_zstd::CompressionWithDict),
}
#[inline]
fn into_error<S: Into<ImmutStr>>(e: &'static str, context: S) -> Box<Error> {
Error::because(ErrorType::InternalError, context, e)
}
impl ZstdCompression {
    /// Compresses `data` with this codec into a newly allocated buffer.
    ///
    /// # Errors
    /// A codec failure becomes an `InternalError` with context "compress header".
    fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
        let outcome = match self {
            Self::Default(codec, level) => codec.compress(data, *level),
            Self::WithDict(codec) => codec.compress(data),
        };
        outcome.map_err(|e| into_error(e, "compress header"))
    }

    /// Decompresses `source` into `destination`, returning the `usize` reported by the
    /// underlying codec (presumably the number of bytes written — confirm in `thread_zstd`).
    ///
    /// # Errors
    /// On failure, the error context includes the frame content size read from the zstd
    /// frame header of `source`. That size is computed only on the error path.
    fn decompress_to_buffer(&self, source: &[u8], destination: &mut Vec<u8>) -> Result<usize> {
        let outcome = match self {
            Self::Default(codec, _) => codec.decompress_to_buffer(source, destination),
            Self::WithDict(codec) => codec.decompress_to_buffer(source, destination),
        };
        outcome.map_err(|e| {
            into_error(
                e,
                format!(
                    "decompress header, frame_content_size: {}",
                    get_frame_content_size(source)
                ),
            )
        })
    }
}
/// Describes the content size declared in the zstd frame header of `source`, for use in
/// error messages.
///
/// Yields the decimal size when one is present. Otherwise yields "invalid" or "unknown"
/// for the zstd sentinel values, "none" when no size is available, and "failed" when the
/// lookup itself errors.
#[inline]
fn get_frame_content_size(source: &[u8]) -> ImmutStr {
    match zstd_safe::get_frame_content_size(source) {
        Ok(Some(zstd_safe::CONTENTSIZE_ERROR)) => "invalid".into(),
        Ok(Some(zstd_safe::CONTENTSIZE_UNKNOWN)) => "unknown".into(),
        Ok(Some(declared)) => declared.to_string().into(),
        Ok(None) => "none".into(),
        Err(_) => "failed".into(),
    }
}
/// HTTP/1.x line terminator.
const CRLF: &[u8; 2] = b"\r\n";
/// Appends `resp` to `buf` in HTTP/1.x wire format. It writes the status line (with the
/// canonical reason phrase, if the status has one), the header fields, and the blank line
/// that ends the head.
///
/// Only HTTP/1.0 is written verbatim. Every other version, including HTTP/2, is written
/// as `HTTP/1.1`.
///
/// Returns `buf.len()` after writing. This is the total length of `buf`, which equals the
/// number of bytes written only when `buf` starts out empty.
#[inline]
fn resp_header_to_buf(resp: &ResponseHeader, buf: &mut Vec<u8>) -> usize {
    let status = resp.status;
    let status_line_prefix: &[u8] = if resp.version == Version::HTTP_10 {
        b"HTTP/1.0 "
    } else {
        b"HTTP/1.1 "
    };

    buf.put_slice(status_line_prefix);
    buf.put_slice(status.as_str().as_bytes());
    buf.put_u8(b' ');
    if let Some(reason) = status.canonical_reason() {
        buf.put_slice(reason.as_bytes());
    }
    buf.put_slice(CRLF);

    resp.header_to_h1_wire(buf);
    // Empty line terminating the header block.
    buf.put_slice(CRLF);

    buf.len()
}
const MAX_HEADERS: usize = 256;
#[inline]
fn buf_to_http_header(buf: &[u8]) -> Result<ResponseHeader> {
let mut headers = vec![httparse::EMPTY_HEADER; MAX_HEADERS];
let mut resp = httparse::Response::new(&mut headers);
match resp.parse(buf) {
Ok(s) => match s {
httparse::Status::Complete(_size) => parsed_to_header(&resp),
_ => Error::e_explain(ErrorType::InternalError, "incomplete uncompressed header"),
},
Err(e) => Error::e_because(
ErrorType::InternalError,
format!(
"parsing failed on uncompressed header, len={}, content={:?}",
buf.len(),
String::from_utf8_lossy(buf)
),
e,
),
}
}
/// Builds a [`ResponseHeader`] from a completely parsed `httparse::Response`, appending
/// every field in its parsed order.
///
/// `parsed.version` is not read, so the result carries whatever version
/// `ResponseHeader::build` assigns by default.
///
/// # Panics
/// Panics if `parsed.code` is `None`. Callers only pass responses that `httparse` reported
/// as `Complete`.
///
/// # Errors
/// Propagates any error from `ResponseHeader::build` or `append_header`.
#[inline]
fn parsed_to_header(parsed: &httparse::Response) -> Result<ResponseHeader> {
    let code = parsed.code.unwrap();
    let field_count = parsed.headers.len();
    let mut rebuilt = ResponseHeader::build(code, Some(field_count))?;
    for field in parsed.headers.iter() {
        let name = field.name.to_string();
        rebuilt.append_header(name, field.value)?;
    }
    Ok(rebuilt)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `200` response with `fields` appended in order.
    fn response_with(fields: &[(&'static str, &'static str)]) -> ResponseHeader {
        let mut resp = ResponseHeader::build(200, None).unwrap();
        for &(name, value) in fields {
            resp.append_header(name, value).unwrap();
        }
        resp
    }

    #[test]
    fn test_ser_wo_dict() {
        let serde = HeaderSerde::new(None);
        let resp = response_with(&[
            ("foo", "bar"),
            ("foo", "barbar"),
            ("foo", "barbarbar"),
            ("Server", "Pingora"),
        ]);
        let compressed = serde.serialize(&resp).unwrap();
        let mut wire = Vec::new();
        let wire_len = resp_header_to_buf(&resp, &mut wire);
        assert!(compressed.len() < wire_len);
    }

    #[test]
    fn test_ser_de_no_dict() {
        let serde = HeaderSerde::new(None);
        let original = response_with(&[
            ("foo1", "bar1"),
            ("foo2", "barbar2"),
            ("foo3", "barbarbar3"),
            ("Server", "Pingora"),
        ]);
        let compressed = serde.serialize(&original).unwrap();
        let restored = serde.deserialize(&compressed).unwrap();
        assert_eq!(original.status, restored.status);
        assert_eq!(original.headers, restored.headers);
    }

    #[test]
    fn test_no_headers() {
        let serde = HeaderSerde::new(None);
        let original = response_with(&[]);
        let compressed = serde.serialize(&original).unwrap();
        let restored = serde.deserialize(&compressed).unwrap();
        assert_eq!(original.status, restored.status);
        assert_eq!(original.headers.len(), 0);
        assert_eq!(restored.headers.len(), 0);
    }

    #[test]
    fn test_empty_header_wire_format() {
        let resp = response_with(&[]);
        let mut wire = Vec::new();
        resp_header_to_buf(&resp, &mut wire);
        assert_eq!(wire.len(), 19);
        assert_eq!(wire, b"HTTP/1.1 200 OK\r\n\r\n");
        let parsed = buf_to_http_header(&wire).unwrap();
        assert_eq!(parsed.status.as_u16(), 200);
    }
}