use crate::lz77::{HASH_SIZE, WINDOW_SIZE};
use std::sync::{
Arc, Mutex,
atomic::{AtomicUsize, Ordering},
};
/// Element count of every pooled window byte buffer: two `WINDOW_SIZE` spans.
const WINDOW_BUF_LEN: usize = 2 * WINDOW_SIZE;
/// Element count of every pooled hash-head `u16` buffer; equal to `HASH_SIZE`.
const HASH_HEAD_LEN: usize = HASH_SIZE;
/// Element count of every pooled hash-prev `u16` buffer; equal to `WINDOW_SIZE`.
const HASH_PREV_LEN: usize = WINDOW_SIZE;
/// Shared state behind every [`DeflatePool`] handle.
///
/// Each buffer kind has its own mutex-guarded free list, so checking out one kind never
/// waits on a lock held for another kind.
#[derive(Debug)]
struct PoolInner {
    /// Idle window byte buffers waiting to be reused.
    window: Mutex<Vec<Vec<u8>>>,
    /// Idle hash-head `u16` buffers waiting to be reused.
    hash_head: Mutex<Vec<Vec<u16>>>,
    /// Idle hash-prev `u16` buffers waiting to be reused.
    hash_prev: Mutex<Vec<Vec<u16>>>,
    /// Maximum number of idle buffers retained per kind; returns beyond this are freed.
    cap: usize,
    /// Number of window checkouts that had to allocate a fresh buffer.
    window_allocs: AtomicUsize,
    /// Number of window checkouts served from the idle list.
    window_hits: AtomicUsize,
}
/// A pool of reusable deflate working buffers (window, hash-head and hash-prev).
///
/// Cloning is cheap: every clone is a handle onto the same `Arc`-shared free lists.
#[derive(Debug, Clone)]
pub struct DeflatePool {
    inner: Arc<PoolInner>,
}
impl DeflatePool {
    /// Idle-buffer capacity per kind used by [`DeflatePool::new`].
    const DEFAULT_CAP: usize = 4;

    /// Creates a pool that retains at most four idle buffers of each kind.
    pub fn new() -> Self {
        Self::with_cap(Self::DEFAULT_CAP)
    }

    /// Creates a pool that retains at most `cap` idle buffers of each kind.
    ///
    /// With `cap == 0` nothing is ever retained, so every checkout allocates.
    pub fn with_cap(cap: usize) -> Self {
        Self {
            inner: Arc::new(PoolInner {
                window: Mutex::new(Vec::new()),
                hash_head: Mutex::new(Vec::new()),
                hash_prev: Mutex::new(Vec::new()),
                cap,
                window_allocs: AtomicUsize::new(0),
                window_hits: AtomicUsize::new(0),
            }),
        }
    }

    /// Pops an idle buffer from `bucket`, if any, and returns it zero-filled to exactly
    /// `len` elements.
    ///
    /// The lock guard is a temporary of the `let` statement, so it is released before the
    /// buffer is zeroed. Concurrent checkouts/returns therefore don't queue behind a
    /// potentially large memset. Clearing first also normalizes buffers handed back with a
    /// different length.
    fn checkout<T: Copy + Default>(bucket: &Mutex<Vec<Vec<T>>>, len: usize) -> Option<Vec<T>> {
        let recycled = bucket.lock().unwrap_or_else(|e| e.into_inner()).pop();
        recycled.map(|mut buf| {
            buf.clear();
            buf.resize(len, T::default());
            buf
        })
    }

    /// Pushes `buf` onto `bucket` unless the bucket already holds `cap` idle buffers.
    ///
    /// A rejected `buf` is a parameter, so it is dropped after the guard, i.e. freed
    /// outside the lock.
    fn checkin<T>(bucket: &Mutex<Vec<Vec<T>>>, cap: usize, buf: Vec<T>) {
        let mut idle = bucket.lock().unwrap_or_else(|e| e.into_inner());
        if idle.len() < cap {
            idle.push(buf);
        }
    }

    /// Checks out a zeroed window buffer of `WINDOW_BUF_LEN` bytes.
    ///
    /// Updates the hit/allocation counters reported by [`DeflatePool::stats`]. The buffer
    /// is offered back to the pool when the returned guard is dropped.
    pub(crate) fn get_window(&self) -> PooledBuf {
        let buf = match Self::checkout(&self.inner.window, WINDOW_BUF_LEN) {
            Some(buf) => {
                self.inner.window_hits.fetch_add(1, Ordering::Relaxed);
                buf
            }
            None => {
                // Allocate outside the lock; `checkout` has already released it.
                self.inner.window_allocs.fetch_add(1, Ordering::Relaxed);
                vec![0u8; WINDOW_BUF_LEN]
            }
        };
        PooledBuf {
            buf,
            pool: Arc::clone(&self.inner),
            kind: BufKind::Window,
        }
    }

    /// Checks out a zeroed hash-head buffer of `HASH_HEAD_LEN` elements.
    pub(crate) fn get_hash_head(&self) -> PooledU16Buf {
        let buf = Self::checkout(&self.inner.hash_head, HASH_HEAD_LEN)
            .unwrap_or_else(|| vec![0u16; HASH_HEAD_LEN]);
        PooledU16Buf {
            buf,
            pool: Arc::clone(&self.inner),
            kind: U16BufKind::HashHead,
        }
    }

    /// Checks out a zeroed hash-prev buffer of `HASH_PREV_LEN` elements.
    pub(crate) fn get_hash_prev(&self) -> PooledU16Buf {
        let buf = Self::checkout(&self.inner.hash_prev, HASH_PREV_LEN)
            .unwrap_or_else(|| vec![0u16; HASH_PREV_LEN]);
        PooledU16Buf {
            buf,
            pool: Arc::clone(&self.inner),
            kind: U16BufKind::HashPrev,
        }
    }

    /// Offers a window buffer back to the pool; it is dropped if the pool is full.
    pub(crate) fn return_window(&self, buf: Vec<u8>) {
        Self::checkin(&self.inner.window, self.inner.cap, buf);
    }

    /// Offers a hash-head buffer back to the pool; it is dropped if the pool is full.
    pub(crate) fn return_hash_head(&self, buf: Vec<u16>) {
        Self::checkin(&self.inner.hash_head, self.inner.cap, buf);
    }

    /// Offers a hash-prev buffer back to the pool; it is dropped if the pool is full.
    pub(crate) fn return_hash_prev(&self, buf: Vec<u16>) {
        Self::checkin(&self.inner.hash_prev, self.inner.cap, buf);
    }

    /// Returns a snapshot of the window-buffer counters.
    ///
    /// The two counters are loaded independently with `Relaxed` ordering. Under concurrent
    /// use they may not describe the same instant.
    pub fn stats(&self) -> PoolStats {
        PoolStats {
            window_allocations: self.inner.window_allocs.load(Ordering::Relaxed),
            window_hits: self.inner.window_hits.load(Ordering::Relaxed),
        }
    }
}
impl Default for DeflatePool {
fn default() -> Self {
Self::new()
}
}
unsafe impl Send for DeflatePool {}
unsafe impl Sync for DeflatePool {}
/// Point-in-time window-buffer counters reported by `DeflatePool::stats`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PoolStats {
    /// Window checkouts that had to allocate a new buffer.
    pub window_allocations: usize,
    /// Window checkouts satisfied by a recycled buffer.
    pub window_hits: usize,
}
/// Selects the byte-buffer free list a `PooledBuf` is returned to when dropped.
#[derive(Debug, Clone, Copy)]
enum BufKind {
    /// The window buffer free list.
    Window,
}
/// Selects the `u16` free list a `PooledU16Buf` is returned to when dropped.
#[derive(Debug, Clone, Copy)]
enum U16BufKind {
    /// The hash-head buffer free list.
    HashHead,
    /// The hash-prev buffer free list.
    HashPrev,
}
/// A byte buffer checked out of a `DeflatePool`.
///
/// On drop it is offered back to its free list, subject to the pool's cap.
pub(crate) struct PooledBuf {
    /// The checked-out storage, used directly by callers.
    pub buf: Vec<u8>,
    /// Shared pool state the buffer is returned to.
    pool: Arc<PoolInner>,
    /// Which free list receives the buffer on drop.
    kind: BufKind,
}
impl Drop for PooledBuf {
    /// Offers the byte buffer back to its free list. It is simply freed when that list
    /// already holds `cap` idle buffers.
    fn drop(&mut self) {
        let free_list = match self.kind {
            BufKind::Window => &self.pool.window,
        };
        let storage = std::mem::take(&mut self.buf);
        let mut idle = free_list
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if idle.len() < self.pool.cap {
            idle.push(storage);
        }
    }
}
/// A `u16` table (hash-head or hash-prev) checked out of a `DeflatePool`.
///
/// On drop it is offered back to the free list named by `kind`, subject to the pool's cap.
pub(crate) struct PooledU16Buf {
    /// The checked-out storage, used directly by callers.
    pub buf: Vec<u16>,
    /// Shared pool state the buffer is returned to.
    pool: Arc<PoolInner>,
    /// Which free list receives the buffer on drop.
    kind: U16BufKind,
}
impl Drop for PooledU16Buf {
    /// Offers the `u16` buffer back to the free list matching its kind. It is simply freed
    /// when that list already holds `cap` idle buffers.
    fn drop(&mut self) {
        let free_list = match self.kind {
            U16BufKind::HashHead => &self.pool.hash_head,
            U16BufKind::HashPrev => &self.pool.hash_prev,
        };
        let storage = std::mem::take(&mut self.buf);
        let mut idle = free_list
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if idle.len() < self.pool.cap {
            idle.push(storage);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Deflater, deflate::deflate, inflate::inflate};

    /// Compresses `input` through a `Deflater` that draws its buffers from `pool`.
    fn compress_with_pool(pool: &DeflatePool, input: &[u8], level: u8) -> Vec<u8> {
        let mut d = Deflater::new(level).with_pool(pool);
        d.compress_to_vec(input).expect("pool compress failed")
    }

    /// Compresses `input` with the pool-free one-shot API, as a reference output.
    fn compress_without_pool(input: &[u8], level: u8) -> Vec<u8> {
        deflate(input, level).expect("no-pool compress failed")
    }

    /// Builds a `len`-byte input by repeating `pattern` cyclically.
    fn repeated(pattern: &[u8], len: usize) -> Vec<u8> {
        pattern.iter().copied().cycle().take(len).collect()
    }

    #[test]
    fn test_pool_basic_window_reuse() {
        let pool = DeflatePool::new();
        let input = repeated(b"the quick brown fox jumps over the lazy dog ", 8_192);
        let out1 = compress_with_pool(&pool, &input, 6);
        let out2 = compress_with_pool(&pool, &input, 6);
        let out3 = compress_with_pool(&pool, &input, 6);
        let stats = pool.stats();
        assert!(
            stats.window_hits >= 2,
            "expected ≥ 2 window hits, got {} (allocs={})",
            stats.window_hits,
            stats.window_allocations,
        );
        for (i, compressed) in [&out1, &out2, &out3].iter().enumerate() {
            let decompressed = inflate(compressed)
                .unwrap_or_else(|e| panic!("inflate call {} failed: {}", i + 1, e));
            assert_eq!(decompressed, input, "roundtrip failed for call {}", i + 1);
        }
    }

    #[test]
    fn test_pool_roundtrip_equality() {
        let pool = DeflatePool::new();
        let input = repeated(b"abcdefghijklmnopqrstuvwxyz0123456789", 65_536);
        for level in [1u8, 6, 9] {
            let pooled = compress_with_pool(&pool, &input, level);
            let baseline = compress_without_pool(&input, level);
            assert_eq!(
                pooled, baseline,
                "pooled and non-pooled output differ at level {}",
                level
            );
        }
    }

    #[test]
    fn test_pool_concurrent() {
        use std::sync::Arc;
        use std::thread;
        // `DeflatePool` is already a cheap `Arc`-backed handle, so each thread gets a clone
        // instead of the pool being wrapped in a second `Arc`.
        let pool = DeflatePool::new();
        let input: Arc<Vec<u8>> = Arc::new(repeated(b"concurrent test data ", 262_144));
        let handles: Vec<_> = (0..4)
            .map(|_| {
                let p = pool.clone();
                let d = Arc::clone(&input);
                thread::spawn(move || {
                    let mut deflater = Deflater::new(6).with_pool(&p);
                    let compressed = deflater
                        .compress_to_vec(&d)
                        .expect("thread compress failed");
                    let decompressed = inflate(&compressed).expect("thread inflate failed");
                    assert_eq!(&decompressed, d.as_ref(), "thread roundtrip failed");
                })
            })
            .collect();
        for h in handles {
            h.join().expect("thread panicked");
        }
    }

    #[test]
    fn test_pool_boundary_cap_respected() {
        let pool = DeflatePool::with_cap(2);
        let input = repeated(b"hello boundary test", 16_384);
        let outputs: Vec<Vec<u8>> = (0..3)
            .map(|_| compress_with_pool(&pool, &input, 6))
            .collect();
        for (i, compressed) in outputs.iter().enumerate() {
            let decompressed = inflate(compressed)
                .unwrap_or_else(|e| panic!("inflate call {} failed: {}", i + 1, e));
            assert_eq!(decompressed, input, "roundtrip failed at call {}", i + 1);
        }
        let guard = pool.inner.window.lock().expect("lock");
        assert!(
            guard.len() <= 2,
            "pool window bucket should hold ≤ 2 buffers, holds {}",
            guard.len()
        );
    }

    #[test]
    fn test_pool_cap_zero() {
        let pool = DeflatePool::with_cap(0);
        let input = b"cap zero test data".to_vec();
        for _ in 0..3 {
            let _ = compress_with_pool(&pool, &input, 6);
        }
        let stats = pool.stats();
        assert_eq!(stats.window_hits, 0, "cap=0 must have no hits");
        assert_eq!(
            stats.window_allocations, 3,
            "cap=0 must allocate for every call"
        );
    }

    #[test]
    fn test_pool_default() {
        let pool = DeflatePool::default();
        let out = compress_with_pool(&pool, b"default test", 6);
        let dec = inflate(&out).expect("inflate default test");
        assert_eq!(dec, b"default test");
    }

    /// Recycled buffers must come back zeroed and at full length even if the previous
    /// holder dirtied, shortened, or grew them.
    #[test]
    fn test_pool_reused_buffers_are_zeroed() {
        let pool = DeflatePool::new();
        {
            let mut w = pool.get_window();
            w.buf.fill(0xAB);
            w.buf.truncate(7);
            let mut h = pool.get_hash_head();
            h.buf.fill(0xBEEF);
            let mut p = pool.get_hash_prev();
            p.buf.push(1);
        }
        let w = pool.get_window();
        assert_eq!(w.buf.len(), WINDOW_BUF_LEN);
        assert!(w.buf.iter().all(|&b| b == 0), "window not zeroed on reuse");
        let h = pool.get_hash_head();
        assert_eq!(h.buf.len(), HASH_HEAD_LEN);
        assert!(h.buf.iter().all(|&v| v == 0), "hash_head not zeroed on reuse");
        let p = pool.get_hash_prev();
        assert_eq!(p.buf.len(), HASH_PREV_LEN);
        assert!(p.buf.iter().all(|&v| v == 0), "hash_prev not zeroed on reuse");
        assert_eq!(
            pool.stats(),
            PoolStats {
                window_allocations: 1,
                window_hits: 1
            }
        );
    }

    /// Explicit returns beyond `cap` are discarded, and accepted returns are normalized on
    /// the next checkout.
    #[test]
    fn test_pool_return_respects_cap() {
        let pool = DeflatePool::with_cap(1);
        pool.return_hash_prev(vec![7u16; 3]);
        pool.return_hash_prev(vec![9u16; 3]);
        assert_eq!(pool.inner.hash_prev.lock().expect("lock").len(), 1);
        let p = pool.get_hash_prev();
        assert_eq!(p.buf.len(), HASH_PREV_LEN);
        assert!(p.buf.iter().all(|&v| v == 0));
    }
}