use super::{Encode, COMPRESSION_ERROR};
use bytes::Bytes;
use parking_lot::Mutex;
use pingora_error::{OrErr, Result};
use std::io::Write;
use std::time::{Duration, Instant};
use zstd::stream::write::Encoder;
/// Streaming zstd compressor that also keeps running statistics.
///
/// Compressed bytes are written by the encoder into an in-memory `Vec<u8>` sink.
pub struct Compressor {
    /// zstd stream encoder whose sink is an in-memory byte buffer.
    /// NOTE(review): the `Mutex` is presumably here so `Compressor` is `Sync`
    /// for the `Encode` trait; confirm against the trait bounds.
    compress: Mutex<Encoder<'static, Vec<u8>>>,
    /// Running count of uncompressed bytes fed to the encoder.
    total_in: usize,
    /// Running count of compressed bytes produced by the encoder.
    total_out: usize,
    /// Accumulated time spent compressing.
    duration: Duration,
}
impl Compressor {
    /// Creates a streaming zstd compressor at compression `level`.
    ///
    /// `level` values that do not fit in an `i32` are saturated to `i32::MAX`
    /// instead of wrapping around to a negative (fast-mode) zstd level; zstd
    /// itself clamps levels above its supported maximum.
    ///
    /// # Panics
    /// Panics if the underlying zstd encoder cannot be created.
    pub fn new(level: u32) -> Self {
        // A plain `as` cast would turn e.g. u32::MAX into -1, i.e. a *faster*,
        // weaker level than requested.
        let level = i32::try_from(level).unwrap_or(i32::MAX);
        let encoder = Encoder::new(vec![], level).expect("failed to create zstd encoder");
        Compressor {
            compress: Mutex::new(encoder),
            total_in: 0,
            total_out: 0,
            duration: Duration::ZERO,
        }
    }
}
impl Encode for Compressor {
    /// Feeds `input` into the zstd stream and returns the compressed bytes the
    /// encoder has emitted so far, draining its output buffer.
    ///
    /// When `end` is false the encoder may keep data buffered internally, so
    /// the returned `Bytes` can be empty. When `end` is true the zstd frame is
    /// finished and the remaining output is returned.
    ///
    /// # Errors
    /// Returns a `COMPRESSION_ERROR` if writing to or finishing the encoder fails.
    fn encode(&mut self, input: &[u8], end: bool) -> Result<Bytes> {
        // Cap the up-front reservation so a very large input chunk does not
        // force an equally large allocation for the output buffer.
        const MAX_INIT_COMPRESSED_BUF_SIZE: usize = 16 * 1024;
        let start = Instant::now();
        self.total_in += input.len();
        // `&mut self` already guarantees exclusive access, so borrow the
        // encoder directly instead of acquiring the lock on every chunk.
        let compress = self.compress.get_mut();
        compress
            .get_mut()
            .reserve(std::cmp::min(MAX_INIT_COMPRESSED_BUF_SIZE, input.len()));
        compress
            .write_all(input)
            .or_err(COMPRESSION_ERROR, "while compress zstd")?;
        if end {
            compress
                .do_finish()
                .or_err(COMPRESSION_ERROR, "while compress zstd")?;
        }
        // Move everything emitted so far out of the sink, leaving it empty for
        // the next call.
        let output = std::mem::take(compress.get_mut());
        self.total_out += output.len();
        self.duration += start.elapsed();
        Ok(output.into())
    }

    /// Returns `("zstd", total input bytes, total output bytes, time spent)`.
    fn stat(&self) -> (&'static str, usize, usize, Duration) {
        ("zstd", self.total_in, self.total_out, self.duration)
    }
}
#[cfg(test)]
mod tests_stream {
    use super::*;

    /// Magic number that begins every zstd frame (0xFD2FB528, little-endian).
    const ZSTD_MAGIC: [u8; 4] = [0x28, 0xB5, 0x2F, 0xFD];

    #[test]
    fn compress_zstd_data() {
        let mut compressor = Compressor::new(11);
        let input = b"adcdefgabcdefghadcdefgabcdefghadcdefgabcdefghadcdefgabcdefgh\n";

        // Without `end`, the encoder keeps the data buffered and emits nothing.
        let compressed = compressor.encode(&input[..], false).unwrap();
        assert!(compressed.is_empty());

        // Finishing the stream yields a complete zstd frame.
        let compressed = compressor.encode(&input[..], true).unwrap();
        assert_eq!(&compressed[..4], &ZSTD_MAGIC);
        assert!(compressed.len() < input.len());

        // The frame must decode back to both chunks, in order.
        let decoded = zstd::stream::decode_all(&compressed[..]).unwrap();
        assert_eq!(decoded, [&input[..], &input[..]].concat());

        // Statistics must account for every byte in and out.
        let (name, total_in, total_out, _) = compressor.stat();
        assert_eq!(name, "zstd");
        assert_eq!(total_in, input.len() * 2);
        assert_eq!(total_out, compressed.len());
    }
}