1#![warn(clippy::all)]
21#![allow(clippy::new_without_default)]
22#![allow(clippy::type_complexity)]
23
24pub mod dict;
25mod thread_zstd;
26
27use bytes::BufMut;
28use http::Version;
29use pingora_error::{Error, ErrorType, ImmutStr, Result};
30use pingora_http::ResponseHeader;
31use std::cell::RefCell;
32use std::ops::DerefMut;
33use thread_local::ThreadLocal;
34
35pub struct HeaderSerde {
40 compression: ZstdCompression,
41 buf: ThreadLocal<RefCell<Vec<u8>>>,
43}
44
45const MAX_HEADER_BUF_SIZE: usize = 128 * 1024; const COMPRESS_LEVEL: i32 = 3;
48
49impl HeaderSerde {
50 pub fn new(dict: Option<Vec<u8>>) -> Self {
55 if let Some(dict) = dict {
56 HeaderSerde {
57 compression: ZstdCompression::WithDict(thread_zstd::CompressionWithDict::new(
58 &dict,
59 COMPRESS_LEVEL,
60 )),
61 buf: ThreadLocal::new(),
62 }
63 } else {
64 HeaderSerde {
65 compression: ZstdCompression::Default(
66 thread_zstd::Compression::new(),
67 COMPRESS_LEVEL,
68 ),
69 buf: ThreadLocal::new(),
70 }
71 }
72 }
73
74 pub fn serialize(&self, header: &ResponseHeader) -> Result<Vec<u8>> {
76 let mut buf = self
79 .buf
80 .get_or(|| RefCell::new(Vec::with_capacity(MAX_HEADER_BUF_SIZE)))
81 .borrow_mut();
82 buf.clear(); resp_header_to_buf(header, &mut buf);
84 self.compression.compress(&buf)
85 }
86
87 pub fn deserialize(&self, data: &[u8]) -> Result<ResponseHeader> {
89 let mut buf = self
90 .buf
91 .get_or(|| RefCell::new(Vec::with_capacity(MAX_HEADER_BUF_SIZE)))
92 .borrow_mut();
93 buf.clear(); self.compression
95 .decompress_to_buffer(data, buf.deref_mut())?;
96 buf_to_http_header(&buf)
97 }
98}
99
100enum ZstdCompression {
103 Default(thread_zstd::Compression, i32),
104 WithDict(thread_zstd::CompressionWithDict),
105}
106
107#[inline]
108fn into_error<S: Into<ImmutStr>>(e: &'static str, context: S) -> Box<Error> {
109 Error::because(ErrorType::InternalError, context, e)
110}
111
112impl ZstdCompression {
113 fn compress(&self, data: &[u8]) -> Result<Vec<u8>> {
114 match &self {
115 ZstdCompression::Default(c, level) => c
116 .compress(data, *level)
117 .map_err(|e| into_error(e, "compress header")),
118 ZstdCompression::WithDict(c) => c
119 .compress(data)
120 .map_err(|e| into_error(e, "compress header")),
121 }
122 }
123
124 fn decompress_to_buffer(&self, source: &[u8], destination: &mut Vec<u8>) -> Result<usize> {
125 match &self {
126 ZstdCompression::Default(c, _) => {
127 c.decompress_to_buffer(source, destination).map_err(|e| {
128 into_error(
129 e,
130 format!(
131 "decompress header, frame_content_size: {}",
132 get_frame_content_size(source)
133 ),
134 )
135 })
136 }
137 ZstdCompression::WithDict(c) => {
138 c.decompress_to_buffer(source, destination).map_err(|e| {
139 into_error(
140 e,
141 format!(
142 "decompress header, frame_content_size: {}",
143 get_frame_content_size(source)
144 ),
145 )
146 })
147 }
148 }
149 }
150}
151
152#[inline]
153fn get_frame_content_size(source: &[u8]) -> ImmutStr {
154 match zstd_safe::get_frame_content_size(source) {
155 Ok(Some(size)) => match size {
156 zstd_safe::CONTENTSIZE_ERROR => ImmutStr::from("invalid"),
157 zstd_safe::CONTENTSIZE_UNKNOWN => ImmutStr::from("unknown"),
158 _ => ImmutStr::from(size.to_string()),
159 },
160 Ok(None) => ImmutStr::from("none"),
161 Err(_e) => ImmutStr::from("failed"),
162 }
163}
164
165const CRLF: &[u8; 2] = b"\r\n";
166
167#[inline]
169fn resp_header_to_buf(resp: &ResponseHeader, buf: &mut Vec<u8>) -> usize {
170 let version = match resp.version {
172 Version::HTTP_10 => "HTTP/1.0 ",
173 Version::HTTP_11 => "HTTP/1.1 ",
174 _ => "HTTP/1.1 ", };
176 buf.put_slice(version.as_bytes());
177 let status = resp.status;
178 buf.put_slice(status.as_str().as_bytes());
179 buf.put_u8(b' ');
180 let reason = status.canonical_reason();
181 if let Some(reason_buf) = reason {
182 buf.put_slice(reason_buf.as_bytes());
183 }
184 buf.put_slice(CRLF);
185
186 resp.header_to_h1_wire(buf);
188
189 buf.put_slice(CRLF);
190
191 buf.len()
192}
193
194const MAX_HEADERS: usize = 256;
196
197#[inline]
198fn buf_to_http_header(buf: &[u8]) -> Result<ResponseHeader> {
199 let mut headers = vec![httparse::EMPTY_HEADER; MAX_HEADERS];
200 let mut resp = httparse::Response::new(&mut headers);
201
202 match resp.parse(buf) {
203 Ok(s) => match s {
204 httparse::Status::Complete(_size) => parsed_to_header(&resp),
205 _ => Error::e_explain(ErrorType::InternalError, "incomplete uncompressed header"),
207 },
208 Err(e) => Error::e_because(
209 ErrorType::InternalError,
210 format!(
211 "parsing failed on uncompressed header, len={}, content={:?}",
212 buf.len(),
213 String::from_utf8_lossy(buf)
214 ),
215 e,
216 ),
217 }
218}
219
220#[inline]
221fn parsed_to_header(parsed: &httparse::Response) -> Result<ResponseHeader> {
222 let mut resp = ResponseHeader::build(parsed.code.unwrap(), Some(parsed.headers.len()))?;
225
226 for header in parsed.headers.iter() {
227 resp.append_header(header.name.to_string(), header.value)?;
228 }
229
230 Ok(resp)
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ser_wo_dict() {
        let serde = HeaderSerde::new(None);
        let mut header = ResponseHeader::build(200, None).unwrap();
        header.append_header("foo", "bar").unwrap();
        header.append_header("foo", "barbar").unwrap();
        header.append_header("foo", "barbarbar").unwrap();
        header.append_header("Server", "Pingora").unwrap();

        let compressed = serde.serialize(&header).unwrap();
        let mut buf = vec![];
        let uncompressed = resp_header_to_buf(&header, &mut buf);
        assert!(compressed.len() < uncompressed);
    }

    #[test]
    fn test_ser_de_no_dict() {
        let serde = HeaderSerde::new(None);
        let mut header = ResponseHeader::build(200, None).unwrap();
        header.append_header("foo1", "bar1").unwrap();
        header.append_header("foo2", "barbar2").unwrap();
        header.append_header("foo3", "barbarbar3").unwrap();
        header.append_header("Server", "Pingora").unwrap();

        let compressed = serde.serialize(&header).unwrap();
        let header2 = serde.deserialize(&compressed).unwrap();
        assert_eq!(header.status, header2.status);
        assert_eq!(header.headers, header2.headers);
    }

    #[test]
    fn test_ser_de_with_dict() {
        // Exercise the `WithDict` path end to end. This uses arbitrary bytes without the
        // zstd dictionary magic, which zstd's default (auto) loading mode treats as a
        // raw-content dictionary. Assumes `CompressionWithDict` uses that default mode.
        let dict = b"HTTP/1.1 200 OK\r\nServer: Pingora\r\nfoo1: bar1\r\nfoo2: barbar2\r\n".repeat(4);
        let serde = HeaderSerde::new(Some(dict));
        let mut header = ResponseHeader::build(200, None).unwrap();
        header.append_header("foo1", "bar1").unwrap();
        header.append_header("foo2", "barbar2").unwrap();
        header.append_header("Server", "Pingora").unwrap();

        let compressed = serde.serialize(&header).unwrap();
        let header2 = serde.deserialize(&compressed).unwrap();
        assert_eq!(header.status, header2.status);
        assert_eq!(header.headers, header2.headers);
    }

    #[test]
    fn test_no_headers() {
        let serde = HeaderSerde::new(None);
        let header = ResponseHeader::build(200, None).unwrap();
        let compressed = serde.serialize(&header).unwrap();
        let header2 = serde.deserialize(&compressed).unwrap();

        assert_eq!(header.status, header2.status);
        assert_eq!(header.headers.len(), 0);
        assert_eq!(header2.headers.len(), 0);
    }

    #[test]
    fn test_empty_header_wire_format() {
        let header = ResponseHeader::build(200, None).unwrap();
        let mut buf = vec![];
        resp_header_to_buf(&header, &mut buf);

        assert_eq!(buf.len(), 19);
        assert_eq!(buf, b"HTTP/1.1 200 OK\r\n\r\n");

        let parsed = buf_to_http_header(&buf).unwrap();
        assert_eq!(parsed.status.as_u16(), 200);
    }
}