#[cfg(feature = "backend-zpaq")]
use std::f64::consts::LN_2;
/// Probability floor substituted when the caller passes a non-finite or
/// non-positive `min_prob`: exactly 2^-24, which is representable in an `f64`.
#[cfg(feature = "backend-zpaq")]
const DEFAULT_MIN_PROB: f64 = 1.0 / 16_777_216.0;
#[cfg(feature = "backend-zpaq")]
mod imp {
use super::{DEFAULT_MIN_PROB, LN_2};
use zpaq_rs::StreamingCompressor;
/// A ZPAQ streaming compressor paired with the `bits()` reading taken after the
/// most recent push made through the rate model.
struct ZpaqStreaming {
/// Compressor that bytes are pushed into one at a time.
compressor: StreamingCompressor,
/// Value of `compressor.bits()` recorded after the last tracked push; starts at 0.0.
last_bits: f64,
}
/// Byte-level rate model backed by a ZPAQ streaming compressor.
///
/// The cost of a byte is the growth of the compressor's `bits()` counter when
/// that byte is pushed. Log-probabilities are that cost times `-ln 2`, clamped
/// below at `ln(min_prob)`.
pub struct ZpaqRateModel {
/// Compressor fed with the committed `history`; `log_prob` additionally pushes
/// `pending_symbol` into it.
stream: ZpaqStreaming,
/// Bytes committed via `update` / `update_and_score`.
history: Vec<u8>,
/// Byte tentatively pushed by `log_prob` but not yet committed to `history`.
pending_symbol: Option<u8>,
/// Bit cost measured when `pending_symbol` was pushed.
pending_bits: f64,
/// Probability floor; returned log-probabilities never go below `min_prob.ln()`.
min_prob: f64,
/// ZPAQ method string, kept so the compressor can be recreated on reset/rebuild.
method: String,
}
impl ZpaqRateModel {
pub fn new(method: impl Into<String>, min_prob: f64) -> Self {
let method = method.into();
let min_prob = if min_prob.is_finite() && min_prob > 0.0 {
min_prob
} else {
DEFAULT_MIN_PROB
};
let compressor = StreamingCompressor::new(method.as_str()).unwrap_or_else(|e| {
panic!("ZPAQ rate backend requires a streamable method; got '{method}': {e}")
});
Self {
stream: ZpaqStreaming {
compressor,
last_bits: 0.0,
},
history: Vec::new(),
pending_symbol: None,
pending_bits: 0.0,
min_prob,
method,
}
}
pub fn reset(&mut self) {
let method = self.method.clone();
let compressor = StreamingCompressor::new(method.as_str()).unwrap_or_else(|e| {
panic!("ZPAQ rate backend requires a streamable method; got '{method}': {e}")
});
self.stream = ZpaqStreaming {
compressor,
last_bits: 0.0,
};
self.history.clear();
self.pending_symbol = None;
self.pending_bits = 0.0;
}
fn rebuild_stream_from_history(&mut self) {
let method = self.method.clone();
let compressor = StreamingCompressor::new(method.as_str()).unwrap_or_else(|e| {
panic!("ZPAQ rate backend requires a streamable method; got '{method}': {e}")
});
self.stream = ZpaqStreaming {
compressor,
last_bits: 0.0,
};
let history = self.history.clone();
for b in history {
let _ = self.encode_bits(b);
}
self.pending_symbol = None;
self.pending_bits = 0.0;
}
fn log_prob_from_history(&self, symbol: u8) -> f64 {
let mut compressor =
StreamingCompressor::new(self.method.as_str()).expect("zpaq streaming new failed");
for &b in &self.history {
compressor
.push(b)
.expect("zpaq streaming compression failed");
}
let before = compressor.bits();
compressor
.push(symbol)
.expect("zpaq streaming compression failed");
let bits = (compressor.bits() - before).max(0.0);
let logp = -(bits * LN_2);
logp.max(self.min_prob.ln())
}
fn encode_bits(&mut self, symbol: u8) -> f64 {
let before = self.stream.last_bits;
self.stream
.compressor
.push(symbol)
.expect("zpaq streaming compression failed");
let after = self.stream.compressor.bits();
self.stream.last_bits = after;
(after - before).max(0.0)
}
pub fn log_prob(&mut self, symbol: u8) -> f64 {
if let Some(pending) = self.pending_symbol {
if pending == symbol {
let logp = -(self.pending_bits * LN_2);
return logp.max(self.min_prob.ln());
}
self.rebuild_stream_from_history();
}
let bits = self.encode_bits(symbol);
self.pending_symbol = Some(symbol);
self.pending_bits = bits;
let logp = -(bits * LN_2);
logp.max(self.min_prob.ln())
}
pub fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
self.rebuild_stream_from_history();
for (sym, slot) in out.iter_mut().enumerate() {
*slot = self.log_prob_from_history(sym as u8);
}
}
pub fn update(&mut self, symbol: u8) {
if let Some(pending) = self.pending_symbol
&& pending == symbol
{
self.pending_symbol = None;
self.history.push(symbol);
return;
}
if self.pending_symbol.is_some() {
self.rebuild_stream_from_history();
}
let _ = self.encode_bits(symbol);
self.pending_symbol = None;
self.pending_bits = 0.0;
self.history.push(symbol);
}
pub fn update_and_score(&mut self, data: &[u8]) -> f64 {
if data.is_empty() {
return 0.0;
}
self.pending_symbol = None;
self.pending_bits = 0.0;
let mut bits = 0.0;
for &b in data {
bits += self.encode_bits(b);
self.history.push(b);
}
bits
}
}
/// Builds an independent copy by replaying the committed history into a fresh
/// compressor, then re-pushing any pending `log_prob` symbol so the copy's
/// prediction cache matches the original's.
impl Clone for ZpaqRateModel {
    fn clone(&self) -> Self {
        let mut copy = Self::new(self.method.clone(), self.min_prob);
        if !self.history.is_empty() {
            copy.update_and_score(&self.history);
        }
        copy.pending_symbol = self.pending_symbol;
        copy.pending_bits = match self.pending_symbol {
            Some(symbol) => copy.encode_bits(symbol),
            None => 0.0,
        };
        copy
    }
}
/// Checks whether `method` can drive a ZPAQ `StreamingCompressor`.
///
/// # Errors
///
/// Returns the compressor's error message when `method` is rejected.
pub fn validate_zpaq_rate_method(method: &str) -> Result<(), String> {
    match StreamingCompressor::new(method) {
        Ok(_) => Ok(()),
        Err(e) => Err(e.to_string()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Model shared by every test: ZPAQ method "1" with a 1e-9 probability floor.
    fn model() -> ZpaqRateModel {
        ZpaqRateModel::new("1", 1e-9)
    }

    /// Commits each byte of `bytes` through `update`, in order.
    fn feed(target: &mut ZpaqRateModel, bytes: &[u8]) {
        bytes.iter().for_each(|&byte| target.update(byte));
    }

    #[test]
    fn zpaq_log_prob_update_matches_update_and_score() {
        let data = b"the quick brown fox jumps over the lazy dog";

        let mut incremental = model();
        let mut bits_a = 0.0;
        for byte in data.iter().copied() {
            bits_a += -incremental.log_prob(byte) / LN_2;
            incremental.update(byte);
        }

        let bits_b = model().update_and_score(data);
        let diff = (bits_a - bits_b).abs();
        assert!(diff < 1e-6, "bits mismatch: {bits_a} vs {bits_b}");
    }

    #[test]
    fn zpaq_fill_log_probs_is_non_mutating() {
        let history = b"zpaq fill non mutating";
        let (mut model_a, mut model_b) = (model(), model());
        feed(&mut model_a, history);
        feed(&mut model_b, history);

        let mut row = [0.0f64; 256];
        model_b.fill_log_probs(&mut row);

        let sym = b'x';
        let lp_a = model_a.log_prob(sym);
        let lp_b = model_b.log_prob(sym);
        assert!((lp_a - lp_b).abs() < 1e-9, "lp_a={lp_a} lp_b={lp_b}");
        assert!((row[usize::from(sym)] - lp_a).abs() < 1e-9);

        model_a.update(sym);
        model_b.update(sym);

        let next_sym = b'y';
        let lp_a2 = model_a.log_prob(next_sym);
        let lp_b2 = model_b.log_prob(next_sym);
        assert!((lp_a2 - lp_b2).abs() < 1e-9, "lp_a2={lp_a2} lp_b2={lp_b2}");
    }

    #[test]
    fn zpaq_clone_preserves_pending_prediction_state() {
        let mut model_a = model();
        feed(&mut model_a, b"clone preserves pending state");

        let probe = b'x';
        let lp_a = model_a.log_prob(probe);
        let mut model_b = model_a.clone();
        let lp_b = model_b.log_prob(probe);
        assert!((lp_a - lp_b).abs() < 1e-9, "lp_a={lp_a} lp_b={lp_b}");

        model_a.update(probe);
        model_b.update(probe);

        let next = b'y';
        let lp_a2 = model_a.log_prob(next);
        let lp_b2 = model_b.log_prob(next);
        assert!((lp_a2 - lp_b2).abs() < 1e-9, "lp_a2={lp_a2} lp_b2={lp_b2}");
    }
}
}
#[cfg(not(feature = "backend-zpaq"))]
mod imp {
    use std::f64::consts::LN_2;

    /// Probability floor used when the caller's `min_prob` is non-finite or
    /// not strictly positive.
    const FALLBACK_MIN_PROB: f64 = 1e-12;

    /// Placeholder rate model compiled when the `backend-zpaq` feature is off.
    ///
    /// Every byte gets the same log-probability, `ln(min_prob)`, and no state
    /// is ever accumulated.
    #[derive(Clone)]
    pub struct ZpaqRateModel {
        /// Natural log of the probability reported for every symbol.
        min_log_prob: f64,
    }

    impl ZpaqRateModel {
        /// Builds the stub. `_method` is ignored; an invalid `min_prob` falls
        /// back to `1e-12`.
        pub fn new(_method: impl Into<String>, min_prob: f64) -> Self {
            let usable = min_prob.is_finite() && min_prob > 0.0;
            let floor = if usable { min_prob } else { FALLBACK_MIN_PROB };
            ZpaqRateModel {
                min_log_prob: floor.ln(),
            }
        }

        /// No-op: the stub holds no history.
        pub fn reset(&mut self) {}

        /// Returns the constant floor log-probability for any symbol.
        pub fn log_prob(&mut self, _symbol: u8) -> f64 {
            self.min_log_prob
        }

        /// Sets every entry of `out` to the constant floor log-probability.
        pub fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
            for slot in out.iter_mut() {
                *slot = self.min_log_prob;
            }
        }

        /// No-op: the stub holds no history.
        pub fn update(&mut self, _symbol: u8) {}

        /// Returns `data.len()` times the per-symbol cost `-ln(min_prob) / ln 2`.
        pub fn update_and_score(&mut self, data: &[u8]) -> f64 {
            let per_symbol = -self.min_log_prob / LN_2;
            let count = data.len() as f64;
            per_symbol * count
        }
    }

    /// Always fails: no ZPAQ method can be validated without the backend.
    pub fn validate_zpaq_rate_method(_method: &str) -> Result<(), String> {
        Err(String::from("zpaq backend disabled at compile time"))
    }
}
pub use imp::ZpaqRateModel;
pub use imp::validate_zpaq_rate_method;