use std::collections::VecDeque;
use std::fs::{OpenOptions, remove_file, rename};
use std::io::Write;
use zstd::bulk::Compressor;
/// zstd compression level used both when scoring candidates and for the frozen compressor.
const ZSTD_LEVEL: i32 = 3;
/// Bodies shorter than this many bytes are always stored uncompressed.
const COMPRESS_MIN_BODY: usize = 24;
/// Maximum dictionary size handed to the zstd trainer (16 KiB).
const DICT_MAX_SIZE: usize = 16 << 10;
/// Minimum number of observed messages between two training attempts.
const DICT_EVAL_INTERVAL: u64 = 200;
/// Relative holdout-ratio improvement below which an evaluation counts as a plateau.
const DICT_PLATEAU_EPS: f64 = 0.02;
/// Consecutive plateau evaluations required before a candidate is frozen.
const DICT_PLATEAU_CONFIRMATIONS: u32 = 2;
/// Minimum number of reservoir samples before training is attempted.
const DICT_MIN_SAMPLES: usize = 64;
/// Minimum total reservoir size before training is attempted (96 KiB).
const DICT_TRAIN_MIN_BYTES: usize = 96 << 10;
/// Sample-count limit of the training reservoir.
const DICT_RESERVOIR_MAX_ITEMS: usize = 8192;
/// Byte budget for growing the training reservoir (4 MiB).
const DICT_RESERVOIR_MAX_BYTES: usize = 4 << 20;
/// Number of most recent bodies kept for scoring candidate dictionaries.
const DICT_HOLDOUT_ITEMS: usize = 256;
/// Hard cap on warmup length; at this many messages the trainer freezes a candidate or gives up.
const DICT_MAX_WARMUP_MSGS: u64 = 5_000_000;
/// Lifecycle stage of a partition's compression dictionary.
enum DictState {
    /// No dictionary; bodies are stored uncompressed.
    Disabled,
    /// Collecting samples and periodically training candidate dictionaries;
    /// bodies are stored uncompressed meanwhile.
    Warmup(Box<WarmupState>),
    /// A dictionary has been frozen and this compressor encodes eligible bodies.
    Ready(Box<Compressor<'static>>),
}
/// Bookkeeping for a partition whose dictionary is still being trained.
struct WarmupState {
    /// Number of bodies observed since warmup started.
    seen: u64,
    /// Value of `seen` at the most recent training attempt.
    last_eval_at: u64,
    /// Training reservoir handed to the zstd dictionary trainer.
    samples: Vec<Vec<u8>>,
    /// Total byte length of everything currently in `samples`.
    sample_bytes: usize,
    /// Recent bodies used to score candidate dictionaries.
    holdout: VecDeque<Vec<u8>>,
    /// Holdout compression ratio measured at the previous evaluation, if any.
    prev_ratio: Option<f64>,
    /// Number of consecutive evaluations that counted as a plateau.
    plateau_hits: u32,
    /// State of the pseudo-random generator used to pick reservoir slots.
    rng: u64,
}
impl WarmupState {
    /// Creates an empty warmup state.
    ///
    /// The PRNG seed is a fixed non-zero constant: zero is a fixed point of
    /// xorshift and would make every draw return 0.
    fn new() -> Self {
        WarmupState {
            samples: Vec::new(),
            sample_bytes: 0,
            seen: 0,
            rng: 0x9E37_79B9_7F4A_7C15,
            last_eval_at: 0,
            prev_ratio: None,
            plateau_hits: 0,
            holdout: VecDeque::new(),
        }
    }

    /// Advances the xorshift64 generator (shift triple 13/7/17) and returns
    /// the new state.
    fn next_rand(&mut self) -> u64 {
        let mut x = self.rng;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.rng = x;
        x
    }

    /// Records one observed body.
    ///
    /// The newest `DICT_HOLDOUT_ITEMS` bodies live only in `holdout`. A body is
    /// offered to the training reservoir only once it ages out of the holdout,
    /// so the two sets stay disjoint. A candidate dictionary is therefore
    /// scored on data it was not trained on. Scoring on training data
    /// overstates the ratio and skews plateau detection.
    fn add_sample(&mut self, body: &[u8]) {
        self.holdout.push_back(body.to_vec());
        if self.holdout.len() > DICT_HOLDOUT_ITEMS {
            if let Some(aged_out) = self.holdout.pop_front() {
                self.add_to_reservoir(aged_out);
            }
        }
    }

    /// Adds `body` to the training reservoir without exceeding its caps.
    ///
    /// While both the item cap and the byte budget have room, the body is
    /// appended. Otherwise it replaces a randomly chosen sample, but only if
    /// `sample_bytes` stays within `DICT_RESERVOIR_MAX_BYTES` afterwards.
    /// Bodies that cannot be placed are dropped.
    fn add_to_reservoir(&mut self, body: Vec<u8>) {
        let has_room = self.samples.len() < DICT_RESERVOIR_MAX_ITEMS
            && self.sample_bytes + body.len() <= DICT_RESERVOIR_MAX_BYTES;
        if has_room {
            self.sample_bytes += body.len();
            self.samples.push(body);
            return;
        }
        if self.samples.is_empty() {
            return;
        }
        let idx = (self.next_rand() as usize) % self.samples.len();
        // `sample_bytes` is the sum of all sample lengths, so it always includes
        // samples[idx] and the subtraction cannot underflow.
        let new_total = self.sample_bytes - self.samples[idx].len() + body.len();
        if new_total <= DICT_RESERVOIR_MAX_BYTES {
            self.sample_bytes = new_total;
            self.samples[idx] = body;
        }
    }
}
/// Outcome of one warmup step.
///
/// It is computed while the warmup state is borrowed and applied to the
/// trainer afterwards.
enum DictDecision {
    /// Keep collecting samples.
    Continue,
    /// Persist and adopt this trained dictionary.
    Freeze(Vec<u8>),
    /// Give up on compression for this part.
    Disable,
}
/// Trains a zstd dictionary of at most `DICT_MAX_SIZE` bytes from `samples`.
///
/// Returns `None` when the trainer reports an error or produces an empty
/// dictionary.
fn train_dict(samples: &[Vec<u8>]) -> Option<Vec<u8>> {
    zstd::dict::from_samples(samples, DICT_MAX_SIZE)
        .ok()
        .filter(|dict| !dict.is_empty())
}
fn measure_ratio(dict: &[u8], holdout: &VecDeque<Vec<u8>>) -> f64 {
if holdout.is_empty() {
return 1.0;
}
let mut compressor = match Compressor::with_dictionary(ZSTD_LEVEL, dict) {
Ok(c) => c,
Err(_) => return 1.0,
};
let mut total_orig = 0usize;
let mut total_comp = 0usize;
for sample in holdout {
total_orig += sample.len();
match compressor.compress(sample) {
Ok(c) => total_comp += c.len() + 4,
Err(_) => total_comp += sample.len(),
}
}
if total_orig == 0 {
1.0
} else {
total_comp as f64 / total_orig as f64
}
}
/// Identifies the partition whose dictionary is being persisted.
///
/// The fields are also used in log messages.
pub struct DictPersistContext<'a> {
    /// Root directory under which the per-partition directories live.
    pub base_path: &'a str,
    /// Queue name; part of both the partition directory and the dictionary file name.
    pub queue_name: &'a str,
    /// Partition number within the queue.
    pub part_id: u32,
}
/// Encoded form of one message body, as produced by `DictTrainer::encode_body`.
pub struct EncodedBody {
    /// `true` when `stored` is a 4-byte length header followed by a zstd payload.
    pub compressed: bool,
    /// Length of `stored`, as a `u32`.
    pub msg_length: u32,
    /// Original body length (native-endian) when `compressed`; all zeros otherwise.
    pub orig_len_bytes: [u8; 4],
    /// Bytes to persist: the raw body, or header plus compressed payload.
    pub stored: Vec<u8>,
}
/// Per-partition driver for dictionary warmup, freezing and body encoding.
pub struct DictTrainer {
    /// Current lifecycle stage of the partition's dictionary.
    state: DictState,
}
impl DictTrainer {
pub fn disabled() -> Self {
DictTrainer {
state: DictState::Disabled,
}
}
pub fn warmup() -> Self {
DictTrainer {
state: DictState::Warmup(Box::new(WarmupState::new())),
}
}
pub fn is_warmup(&self) -> bool {
matches!(self.state, DictState::Warmup(_))
}
pub fn encode_body(&mut self, data: &[u8]) -> EncodedBody {
let compressed_payload: Option<Vec<u8>> = match &mut self.state {
DictState::Ready(compressor) if data.len() >= COMPRESS_MIN_BODY => match compressor.compress(data) {
Ok(c) if c.len() + 4 < data.len() => Some(c),
_ => None,
},
_ => None,
};
if let Some(c) = compressed_payload {
let orig_len_bytes = u32::to_ne_bytes(data.len() as u32);
let mut stored = Vec::with_capacity(4 + c.len());
stored.extend_from_slice(&orig_len_bytes);
stored.extend_from_slice(&c);
EncodedBody {
compressed: true,
msg_length: stored.len() as u32,
orig_len_bytes,
stored,
}
} else {
EncodedBody {
compressed: false,
msg_length: data.len() as u32,
orig_len_bytes: [0; 4],
stored: data.to_vec(),
}
}
}
pub fn observe_stored_body(&mut self, ctx: &DictPersistContext<'_>, body: &[u8]) {
if !self.is_warmup() {
return;
}
let decision = {
let ws = match &mut self.state {
DictState::Warmup(ws) => ws,
_ => return,
};
ws.seen += 1;
ws.add_sample(body);
let interval_due = ws.seen.saturating_sub(ws.last_eval_at) >= DICT_EVAL_INTERVAL
&& ws.samples.len() >= DICT_MIN_SAMPLES
&& ws.sample_bytes >= DICT_TRAIN_MIN_BYTES;
let hard_cap = ws.seen >= DICT_MAX_WARMUP_MSGS;
if !interval_due && !hard_cap {
DictDecision::Continue
} else {
ws.last_eval_at = ws.seen;
match train_dict(&ws.samples) {
Some(candidate) => {
let ratio = measure_ratio(&candidate, &ws.holdout);
let plateau = match ws.prev_ratio {
Some(prev) if prev > 0.0 => (prev - ratio) / prev < DICT_PLATEAU_EPS,
Some(_) => true,
None => false,
};
if plateau {
ws.plateau_hits += 1;
} else {
ws.plateau_hits = 0;
}
ws.prev_ratio = Some(ratio);
if ws.plateau_hits >= DICT_PLATEAU_CONFIRMATIONS || hard_cap {
DictDecision::Freeze(candidate)
} else {
DictDecision::Continue
}
},
None => {
if hard_cap {
DictDecision::Disable
} else {
DictDecision::Continue
}
},
}
}
};
match decision {
DictDecision::Continue => {},
DictDecision::Disable => self.state = DictState::Disabled,
DictDecision::Freeze(dict) => self.freeze(ctx, dict),
}
}
fn freeze(&mut self, ctx: &DictPersistContext<'_>, dict: Vec<u8>) {
let dir = format!("{}/{}-{}", ctx.base_path, ctx.queue_name, ctx.part_id);
let path = format!("{}/{}_dict", dir, ctx.queue_name);
let tmp = format!("{}.tmp.{}", path, std::process::id());
let persisted = (|| -> std::io::Result<()> {
let mut f = OpenOptions::new().read(true).write(true).create(true).truncate(true).open(&tmp)?;
f.write_all(&dict)?;
f.sync_all()?;
drop(f);
rename(&tmp, &path)?;
Ok(())
})();
if let Err(e) = persisted {
warn!(
"queue:{}:{} freeze_dict, persist failed, err={}; keeping part uncompressed",
ctx.queue_name, ctx.part_id, e
);
let _ = remove_file(&tmp);
self.state = DictState::Disabled;
return;
}
match Compressor::with_dictionary(ZSTD_LEVEL, &dict) {
Ok(c) => {
debug!("queue:{}:{} dictionary frozen, {} bytes", ctx.queue_name, ctx.part_id, dict.len());
self.state = DictState::Ready(Box::new(c));
},
Err(e) => {
warn!(
"queue:{}:{} freeze_dict, build compressor failed, err={}; keeping part uncompressed",
ctx.queue_name, ctx.part_id, e
);
self.state = DictState::Disabled;
},
}
}
}
/// Returns the path of the persisted dictionary for one queue partition.
///
/// Layout: `<base_path>/<queue_name>-<part_id>/<queue_name>_dict`. This only
/// builds the string; no directory is created or checked.
pub fn dict_path(base_path: &str, queue_name: &str, part_id: u32) -> String {
    format!(
        "{}/{}-{}/{}_dict",
        base_path, queue_name, part_id, queue_name
    )
}