#![allow(unsafe_op_in_unsafe_fn)]
pub mod aixi;
pub mod axioms;
pub mod backends;
pub mod coders;
pub mod compression;
pub mod datagen;
pub mod mixture;
pub(crate) mod neural_mix;
pub mod search;
pub(crate) mod simd_math;
pub use backends::ctw;
#[cfg(feature = "backend-mamba")]
pub use backends::mambazip;
pub use backends::match_model;
pub use backends::particle;
pub use backends::ppmd;
pub use backends::rosaplus;
#[cfg(feature = "backend-rwkv")]
pub use backends::rwkvzip;
pub use backends::sparse_match;
pub use backends::zpaq_rate;
use rayon::prelude::*;
use crate::coders::CoderType;
use std::cell::RefCell;
#[cfg(any(feature = "backend-rwkv", feature = "backend-mamba"))]
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::OnceLock;
/// Whether the model keeps learning while it generates.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GenerationUpdateMode {
    /// Each emitted byte also adapts the model.
    Adaptive,
    /// Emitted bytes condition the context but do not adapt the model.
    Frozen,
}

/// How the next byte is chosen from the predicted distribution.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GenerationStrategy {
    Greedy,
    Sample,
}

/// Knobs controlling byte generation: decoding strategy, adaptivity,
/// RNG seed, and the usual temperature / top-k / top-p sampling filters.
#[derive(Clone, Copy, Debug)]
pub struct GenerationConfig {
    pub strategy: GenerationStrategy,
    pub update_mode: GenerationUpdateMode,
    pub seed: u64,
    pub temperature: f64,
    pub top_k: usize,
    pub top_p: f64,
}

impl Default for GenerationConfig {
    /// Frozen stochastic sampling with a fixed seed of 42.
    fn default() -> Self {
        Self::sampled_frozen(42)
    }
}

impl GenerationConfig {
    /// Shared preset base: frozen updates, neutral temperature, and no
    /// top-k / top-p filtering; only strategy and seed vary between presets.
    const fn frozen_preset(strategy: GenerationStrategy, seed: u64) -> Self {
        Self {
            strategy,
            update_mode: GenerationUpdateMode::Frozen,
            seed,
            temperature: 1.0,
            top_k: 0,
            top_p: 1.0,
        }
    }
    /// Deterministic argmax decoding with frozen updates.
    pub const fn greedy_frozen() -> Self {
        Self::frozen_preset(GenerationStrategy::Greedy, 0xD00D_F00D_CAFE_BABEu64)
    }
    /// Seeded stochastic sampling with frozen updates.
    pub const fn sampled_frozen(seed: u64) -> Self {
        Self::frozen_preset(GenerationStrategy::Sample, seed)
    }
}
/// Minimal xorshift64 PRNG used only for generation sampling.
/// Deterministic for a given seed; not cryptographically secure.
struct GenerationRng {
    // Non-zero internal state (xorshift64 has a fixed point at 0).
    state: u64,
}
impl GenerationRng {
    /// Seeds the generator; a zero seed is remapped to a fixed non-zero
    /// constant because xorshift64 seeded with 0 would emit zeros forever.
    fn new(seed: u64) -> Self {
        Self {
            state: if seed == 0 {
                0xD00D_F00D_CAFE_BABEu64
            } else {
                seed
            },
        }
    }
    /// Advances the state with the classic xorshift64 (13, 7, 17) triple.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }
    /// Uniform sample in the half-open interval [0, 1).
    ///
    /// Fix: the previous implementation divided by `u64::MAX`, which could
    /// return exactly 1.0 and overrun cumulative-probability sampling at the
    /// top end. Taking the top 53 bits and scaling by 2^-53 yields the
    /// standard half-open unit interval at full double precision.
    fn next_f64(&mut self) -> f64 {
        ((self.next_u64() >> 11) as f64) * (1.0 / (1u64 << 53) as f64)
    }
}
/// Lazily-initialised worker-thread count. NOTE(review): not read anywhere in
/// this chunk — presumably consumed by parallel compression helpers defined
/// further down the file; confirm before removing.
static NUM_THREADS: OnceLock<usize> = OnceLock::new();
thread_local! {
    // Per-thread caches of backend compressors. Model-keyed maps use the
    // Arc's pointer identity (cast to usize) as the key; method-keyed maps
    // use the method string. Thread-local storage avoids locking and lets
    // each rayon worker keep its own warm compressor state.
    #[cfg(feature = "backend-mamba")]
    static MAMBA_TLS: RefCell<HashMap<usize, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-mamba")]
    static MAMBA_RATE_TLS: RefCell<HashMap<usize, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-mamba")]
    static MAMBA_METHOD_TLS: RefCell<HashMap<String, mambazip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_TLS: RefCell<HashMap<usize, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_RATE_TLS: RefCell<HashMap<usize, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
    #[cfg(feature = "backend-rwkv")]
    static RWKV_METHOD_TLS: RefCell<HashMap<String, rwkvzip::Compressor>> = RefCell::new(HashMap::new());
}
#[cfg(feature = "backend-zpaq")]
impl Default for CompressionBackend {
    /// Default when zpaq is compiled in: zpaq with method "5".
    fn default() -> Self {
        Self::Zpaq {
            method: String::from("5"),
        }
    }
}
#[cfg(not(feature = "backend-zpaq"))]
impl Default for CompressionBackend {
fn default() -> Self {
CompressionBackend::Rate {
rate_backend: RateBackend::default(),
coder: CoderType::AC,
framing: compression::FramingMode::Raw,
}
}
}
thread_local! {
    // Per-thread default context consumed by the free helper functions below.
    static DEFAULT_CTX: RefCell<InfotheoryCtx> = RefCell::new(InfotheoryCtx::default());
}
/// Returns a clone of this thread's default context.
pub fn get_default_ctx() -> InfotheoryCtx {
    DEFAULT_CTX.with(|cell| cell.borrow().to_owned())
}
/// Replaces this thread's default context, dropping the previous one.
pub fn set_default_ctx(ctx: InfotheoryCtx) {
    DEFAULT_CTX.with(|cell| {
        cell.replace(ctx);
    });
}
/// Runs `f` with a shared borrow of this thread's default context.
#[inline(always)]
fn with_default_ctx<R>(f: impl FnOnce(&InfotheoryCtx) -> R) -> R {
    DEFAULT_CTX.with(|cell| {
        let guard = cell.borrow();
        f(&guard)
    })
}
/// Mutual-information rate estimate I(X;Y) = H(X) + H(Y) - H(X,Y),
/// clamped to be non-negative. Inputs are first truncated to their common
/// prefix length; empty aligned input yields 0.
pub fn mutual_information_rate_backend(
    x: &[u8],
    y: &[u8],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    let (xs, ys) = aligned_prefix(x, y);
    if xs.is_empty() {
        return 0.0;
    }
    let h_x = entropy_rate_backend(xs, max_order, backend);
    let h_y = entropy_rate_backend(ys, max_order, backend);
    let h_joint = joint_entropy_rate_backend(xs, ys, max_order, backend);
    f64::max(h_x + h_y - h_joint, 0.0)
}
/// Normalized entropy distance: (H(X,Y) - min(H(X),H(Y))) / max(H(X),H(Y)),
/// clamped into [0, 1]. Returns 0 when both entropies are zero or the
/// aligned input is empty.
pub fn ned_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
    let (xs, ys) = aligned_prefix(x, y);
    if xs.is_empty() {
        return 0.0;
    }
    let h_x = entropy_rate_backend(xs, max_order, backend);
    let h_y = entropy_rate_backend(ys, max_order, backend);
    let h_joint = joint_entropy_rate_backend(xs, ys, max_order, backend);
    let lo = h_x.min(h_y);
    let hi = h_x.max(h_y);
    if hi == 0.0 {
        return 0.0;
    }
    ((h_joint - lo) / hi).clamp(0.0, 1.0)
}
/// Normalized variation-of-information style distance:
/// [(H(X,Y)-H(X))+ + (H(X,Y)-H(Y))+] / max(H(X),H(Y)), clamped into [0, 2].
/// Returns 0 when both entropies are zero or the aligned input is empty.
pub fn nte_rate_backend(x: &[u8], y: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
    let (xs, ys) = aligned_prefix(x, y);
    if xs.is_empty() {
        return 0.0;
    }
    let h_x = entropy_rate_backend(xs, max_order, backend);
    let h_y = entropy_rate_backend(ys, max_order, backend);
    let h_joint = joint_entropy_rate_backend(xs, ys, max_order, backend);
    let hi = h_x.max(h_y);
    if hi == 0.0 {
        return 0.0;
    }
    let variation = (h_joint - h_x).max(0.0) + (h_joint - h_y).max(0.0);
    (variation / hi).clamp(0.0, 2.0)
}
/// A model used for entropy-rate estimation / prequential scoring.
/// Feature-gated variants only exist when the matching backend is compiled in.
#[derive(Clone)]
pub enum RateBackend {
    /// ROSA+ backend (see `backends::rosaplus`).
    RosaPlus,
    /// Hash-based match model with mixing parameters.
    Match {
        hash_bits: usize,
        min_len: usize,
        max_len: usize,
        base_mix: f64,
        confidence_scale: f64,
    },
    /// Match model that also considers gapped (sparse) contexts.
    SparseMatch {
        hash_bits: usize,
        min_len: usize,
        max_len: usize,
        gap_min: usize,
        gap_max: usize,
        base_mix: f64,
        confidence_scale: f64,
    },
    /// PPMd with the given order and memory budget (MiB).
    Ppmd {
        order: usize,
        memory_mb: usize,
    },
    /// Mamba neural model shared via `Arc` (see `with_mamba_*_tls` helpers).
    #[cfg(feature = "backend-mamba")]
    Mamba {
        model: Arc<mambazip::Model>,
    },
    /// Mamba model constructed lazily from a method string.
    #[cfg(feature = "backend-mamba")]
    MambaMethod {
        method: String,
    },
    /// RWKV7 neural model shared via `Arc`.
    #[cfg(feature = "backend-rwkv")]
    Rwkv7 {
        model: Arc<rwkvzip::Model>,
    },
    /// RWKV7 model constructed lazily from a method string.
    #[cfg(feature = "backend-rwkv")]
    Rwkv7Method {
        method: String,
    },
    /// zpaq modelled as a rate estimator with the given method string.
    Zpaq {
        method: String,
    },
    /// Mixture of expert backends (may nest up to MAX_MIXTURE_NESTING).
    Mixture {
        spec: Arc<MixtureSpec>,
    },
    /// Particle-filter neural predictor.
    Particle {
        spec: Arc<ParticleSpec>,
    },
    /// Wraps another backend with an online probability calibrator.
    Calibrated {
        spec: Arc<CalibratedSpec>,
    },
    /// Plain context-tree weighting over bits, with the given depth.
    Ctw {
        depth: usize,
    },
    /// Factored CTW: one tree per encoded bit position.
    FacCtw {
        base_depth: usize,
        num_percept_bits: usize,
        encoding_bits: usize,
    },
}
#[allow(clippy::derivable_impls)]
impl Default for RateBackend {
    /// Picks a default by compile-time feature ladder: RosaPlus when the
    /// rosa backend is built, else zpaq method "1", else CTW depth 16.
    /// Exactly one of the three cfg blocks survives compilation.
    fn default() -> Self {
        #[cfg(feature = "backend-rosa")]
        {
            RateBackend::RosaPlus
        }
        #[cfg(all(not(feature = "backend-rosa"), feature = "backend-zpaq"))]
        {
            RateBackend::Zpaq {
                method: "1".to_string(),
            }
        }
        #[cfg(all(not(feature = "backend-rosa"), not(feature = "backend-zpaq")))]
        {
            RateBackend::Ctw { depth: 16 }
        }
    }
}
/// A backend that produces actual compressed output sizes (as opposed to
/// the pure rate estimators in [`RateBackend`]).
#[derive(Clone)]
pub enum CompressionBackend {
    /// zpaq with the given method string.
    Zpaq {
        method: String,
    },
    /// RWKV7 model driving an entropy coder.
    #[cfg(feature = "backend-rwkv")]
    Rwkv7 {
        model: Arc<rwkvzip::Model>,
        coder: CoderType,
    },
    /// Any rate backend driving an entropy coder with a framing mode.
    Rate {
        rate_backend: RateBackend,
        coder: CoderType,
        framing: compression::FramingMode,
    },
}
/// Maximum depth of nested Mixture/Calibrated specs accepted by validation.
pub const MAX_MIXTURE_NESTING: usize = 8;
/// Mixing strategies available to [`MixtureSpec`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MixtureKind {
    Bayes,
    FadingBayes,
    Switching,
    Convex,
    Mdl,
    Neural,
}

/// How mixture parameters are scheduled over the stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MixtureScheduleMode {
    Default,
    Theorem,
}

impl Default for MixtureScheduleMode {
    /// The constant (non-theorem) schedule is the default.
    fn default() -> Self {
        MixtureScheduleMode::Default
    }
}

/// Parses a user-supplied mixture-kind name. Matching is case-insensitive
/// and ignores surrounding whitespace; several spelling variants are
/// accepted per kind. Unknown names yield an `Err` with the offending name.
pub fn parse_mixture_kind_name(kind: &str) -> Result<MixtureKind, String> {
    let normalized = kind.trim().to_ascii_lowercase();
    match normalized.as_str() {
        "bayes" | "bayes-mix" | "bayes_mix" => Ok(MixtureKind::Bayes),
        "fading" | "fading-bayes" | "fading_bayes" => Ok(MixtureKind::FadingBayes),
        "switch" | "switching" | "switch-mix" | "switch_mix" => Ok(MixtureKind::Switching),
        "convex" | "convex-mix" | "convex_mix" => Ok(MixtureKind::Convex),
        "mdl" | "selector" | "mdr" => Ok(MixtureKind::Mdl),
        "neural" | "neural-mix" | "neural_mix" | "mix" | "mixture" | "fx2" | "fx2-cmix"
        | "fx2_cmix" => Ok(MixtureKind::Neural),
        other => Err(format!("unknown mixture kind '{other}'")),
    }
}

/// Parses a schedule name; the empty string maps to the default schedule.
pub fn parse_mixture_schedule_name(schedule: &str) -> Result<MixtureScheduleMode, String> {
    let normalized = schedule.trim().to_ascii_lowercase();
    match normalized.as_str() {
        "" | "default" | "constant" | "const" => Ok(MixtureScheduleMode::Default),
        "theorem" | "paper" | "paper-theorem" | "paper_theorem" => Ok(MixtureScheduleMode::Theorem),
        other => Err(format!("unknown mixture schedule '{other}'")),
    }
}
/// Context bucketing used by the probability calibrator wrapper.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CalibrationContextKind {
    Global,
    ByteClass,
    Text,
    Repeat,
    TextRepeat,
}
/// Configuration for wrapping a base backend with an online calibrator.
#[derive(Clone)]
pub struct CalibratedSpec {
    // Backend whose probabilities get calibrated (may itself be nested).
    pub base: RateBackend,
    // How predictions are bucketed into calibration contexts.
    pub context: CalibrationContextKind,
    // Number of probability bins per context.
    pub bins: usize,
    // Online adaptation rate of the calibrator.
    pub learning_rate: f64,
    // Clamp applied to the learned per-bin bias.
    pub bias_clip: f64,
}
/// One expert inside a [`MixtureSpec`].
#[derive(Clone)]
pub struct MixtureExpertSpec {
    // Optional display name used in validation error messages.
    pub name: Option<String>,
    // Log-domain prior weight; must be finite.
    pub log_prior: f64,
    // Context order passed to the expert's backend.
    pub max_order: i64,
    // Backend powering this expert (may be a nested mixture).
    pub backend: RateBackend,
}
/// Declarative description of a mixture-of-experts rate backend.
#[derive(Clone)]
pub struct MixtureSpec {
    // Mixing rule (Bayes, switching, convex, ...).
    pub kind: MixtureKind,
    // Parameter schedule; non-default only valid for Switching/Convex.
    pub schedule: MixtureScheduleMode,
    // Kind-dependent mixing parameter (see validate_mixture_spec_shallow).
    pub alpha: f64,
    // Fading factor in (0, 1); required by FadingBayes.
    pub decay: Option<f64>,
    // Member experts; must be non-empty.
    pub experts: Vec<MixtureExpertSpec>,
}
impl MixtureSpec {
    /// Creates a spec with library defaults: constant schedule,
    /// alpha = 0.01, and no decay.
    pub fn new(kind: MixtureKind, experts: Vec<MixtureExpertSpec>) -> Self {
        Self {
            kind,
            schedule: MixtureScheduleMode::Default,
            alpha: 0.01,
            decay: None,
            experts,
        }
    }
    /// Builder: replaces the weight schedule.
    pub fn with_schedule(self, schedule: MixtureScheduleMode) -> Self {
        Self { schedule, ..self }
    }
    /// Builder: replaces the alpha parameter.
    pub fn with_alpha(self, alpha: f64) -> Self {
        Self { alpha, ..self }
    }
    /// Builder: sets the fading decay factor.
    pub fn with_decay(self, decay: f64) -> Self {
        Self {
            decay: Some(decay),
            ..self
        }
    }
    /// Validates this spec, recursing into nested specs up to
    /// [`MAX_MIXTURE_NESTING`] levels deep.
    pub fn validate(&self) -> Result<(), String> {
        validate_mixture_spec_with_depth(self, MAX_MIXTURE_NESTING)
    }
    /// Materialises one runtime `ExpertConfig` per declared expert.
    pub fn build_experts(&self) -> Vec<crate::mixture::ExpertConfig> {
        let mut configs = Vec::with_capacity(self.experts.len());
        for expert in &self.experts {
            configs.push(crate::mixture::ExpertConfig::from_rate_backend(
                expert.name.clone(),
                expert.log_prior,
                expert.backend.clone(),
                expert.max_order,
            ));
        }
        configs
    }
}
/// Validates `spec` itself, then each expert's backend with one less unit of
/// remaining nesting budget. `depth == 0` means the budget is exhausted.
fn validate_mixture_spec_with_depth(spec: &MixtureSpec, depth: usize) -> Result<(), String> {
    if depth == 0 {
        return Err("mixture spec nesting too deep".to_string());
    }
    validate_mixture_spec_shallow(spec)?;
    for (index, expert) in spec.experts.iter().enumerate() {
        if let Err(err) = validate_rate_backend_with_depth(&expert.backend, depth - 1) {
            // Prefer the expert's name in the message; fall back to its
            // 1-based position.
            return Err(match expert.name.as_deref() {
                Some(name) => format!("mixture expert '{name}' invalid: {err}"),
                None => format!("mixture expert #{} invalid: {err}", index + 1),
            });
        }
    }
    Ok(())
}
/// Validates a [`MixtureSpec`] without descending into nested expert
/// backends: expert presence, finiteness of `alpha` and expert priors,
/// the decay range, and kind/schedule compatibility.
fn validate_mixture_spec_shallow(spec: &MixtureSpec) -> Result<(), String> {
    if spec.experts.is_empty() {
        return Err("mixture spec must include at least one expert".to_string());
    }
    if !spec.alpha.is_finite() {
        return Err("mixture alpha must be finite".to_string());
    }
    if spec
        .experts
        .iter()
        .any(|expert| !expert.log_prior.is_finite())
    {
        return Err("mixture expert log_prior must be finite".to_string());
    }
    if let Some(decay) = spec.decay {
        // Fix: the old check used `(0.0..1.0).contains(&decay)`, which
        // accepted decay == 0.0 even though the error message promises the
        // open interval (0, 1). Both endpoints are now rejected; the
        // comparison is also false for NaN and infinities.
        if !(decay > 0.0 && decay < 1.0) {
            return Err("mixture decay must be in (0, 1)".to_string());
        }
    }
    if matches!(spec.kind, MixtureKind::FadingBayes) && spec.decay.is_none() {
        return Err("fading Bayes mixture requires decay".to_string());
    }
    if spec.schedule != MixtureScheduleMode::Default
        && !matches!(spec.kind, MixtureKind::Switching | MixtureKind::Convex)
    {
        return Err(
            "mixture schedule is only supported for switching and convex mixtures".to_string(),
        );
    }
    match (spec.kind, spec.schedule) {
        (MixtureKind::Switching, MixtureScheduleMode::Default) => {
            if !(0.0..=1.0).contains(&spec.alpha) {
                return Err("switching mixture alpha must be in [0, 1]".to_string());
            }
        }
        (MixtureKind::Convex, MixtureScheduleMode::Default)
        | (MixtureKind::Neural, MixtureScheduleMode::Default) => {
            if spec.alpha <= 0.0 {
                return Err("mixture alpha must be > 0".to_string());
            }
        }
        // Unreachable: Neural with a non-default schedule is rejected by the
        // schedule-compatibility guard above.
        (MixtureKind::Neural, MixtureScheduleMode::Theorem) => unreachable!(),
        _ => {}
    }
    Ok(())
}
/// Validates a single backend, recursing through `Mixture` and `Calibrated`
/// wrappers while decrementing the remaining nesting budget. All leaf
/// variants are accepted as-is.
fn validate_rate_backend_with_depth(backend: &RateBackend, depth: usize) -> Result<(), String> {
    match backend {
        RateBackend::Mixture { spec } => validate_mixture_spec_with_depth(spec, depth),
        RateBackend::Particle { spec } => spec.validate(),
        RateBackend::Calibrated { spec } => {
            if depth == 0 {
                Err("calibrated spec nesting too deep".to_string())
            } else {
                validate_rate_backend_with_depth(&spec.base, depth - 1)
                    .map_err(|err| format!("calibrated base invalid: {err}"))
            }
        }
        _ => Ok(()),
    }
}
/// Public entry point: validates `backend` with the full nesting budget.
pub fn validate_rate_backend(backend: &RateBackend) -> Result<(), String> {
    validate_rate_backend_with_depth(backend, MAX_MIXTURE_NESTING)
}
/// Configuration of the particle-filter neural predictor.
#[derive(Clone, Debug)]
pub struct ParticleSpec {
    pub num_particles: usize,
    pub context_window: usize,
    pub unroll_steps: usize,
    pub num_cells: usize,
    pub cell_dim: usize,
    pub num_rules: usize,
    pub selector_hidden: usize,
    pub rule_hidden: usize,
    pub noise_dim: usize,
    pub deterministic: bool,
    pub enable_noise: bool,
    pub noise_scale: f64,
    pub noise_anneal_steps: usize,
    pub learning_rate_readout: f64,
    pub learning_rate_selector: f64,
    pub learning_rate_rule: f64,
    pub bptt_depth: usize,
    pub optimizer_momentum: f64,
    pub grad_clip: f64,
    pub state_clip: f64,
    pub forget_lambda: f64,
    pub resample_threshold: f64,
    pub mutate_fraction: f64,
    pub mutate_scale: f64,
    pub mutate_model_params: bool,
    pub diagnostics_interval: usize,
    pub min_prob: f64,
    pub seed: u64,
}

impl Default for ParticleSpec {
    /// Moderately sized defaults that pass `validate`.
    fn default() -> Self {
        Self {
            num_particles: 16,
            context_window: 32,
            unroll_steps: 2,
            num_cells: 8,
            cell_dim: 32,
            num_rules: 4,
            selector_hidden: 64,
            rule_hidden: 64,
            noise_dim: 8,
            deterministic: true,
            enable_noise: false,
            noise_scale: 0.10,
            noise_anneal_steps: 8192,
            learning_rate_readout: 0.01,
            learning_rate_selector: 1e-4,
            learning_rate_rule: 3e-4,
            bptt_depth: 3,
            optimizer_momentum: 0.05,
            grad_clip: 1.0,
            state_clip: 8.0,
            forget_lambda: 0.0,
            resample_threshold: 0.5,
            mutate_fraction: 0.1,
            mutate_scale: 0.01,
            mutate_model_params: false,
            diagnostics_interval: 0,
            min_prob: 2f64.powi(-24),
            seed: 42,
        }
    }
}

impl ParticleSpec {
    /// Checks structural and numeric invariants, returning the first
    /// violation found. Check order (and therefore which message wins when
    /// several fields are invalid) matches the field declaration order.
    pub fn validate(&self) -> Result<(), String> {
        // Sizing knobs must all be strictly positive.
        let required_nonzero = [
            ("num_particles", self.num_particles),
            ("context_window", self.context_window),
            ("unroll_steps", self.unroll_steps),
            ("num_cells", self.num_cells),
            ("cell_dim", self.cell_dim),
            ("num_rules", self.num_rules),
            ("selector_hidden", self.selector_hidden),
            ("rule_hidden", self.rule_hidden),
        ];
        for &(name, value) in required_nonzero.iter() {
            if value == 0 {
                return Err(format!("{name} must be > 0"));
            }
        }
        // Rates and scales must be finite and non-negative.
        let required_nonnegative = [
            ("learning_rate_readout", self.learning_rate_readout),
            ("learning_rate_selector", self.learning_rate_selector),
            ("learning_rate_rule", self.learning_rate_rule),
            ("noise_scale", self.noise_scale),
        ];
        for &(name, value) in required_nonnegative.iter() {
            if !value.is_finite() || value < 0.0 {
                return Err(format!("{name} must be finite and non-negative"));
            }
        }
        let momentum = self.optimizer_momentum;
        if !momentum.is_finite() || momentum < 0.0 || momentum >= 1.0 {
            return Err("optimizer_momentum must be finite and in [0, 1)".into());
        }
        if self.bptt_depth == 0 {
            return Err("bptt_depth must be > 0".into());
        }
        // The negated-conjunction forms below also reject NaN.
        if !(self.resample_threshold > 0.0 && self.resample_threshold <= 1.0) {
            return Err("resample_threshold must be in (0, 1]".into());
        }
        if !(self.mutate_fraction >= 0.0 && self.mutate_fraction <= 1.0) {
            return Err("mutate_fraction must be in [0, 1]".into());
        }
        if !(self.min_prob > 0.0 && self.min_prob < 0.5) {
            return Err("min_prob must be in (0, 0.5)".into());
        }
        Ok(())
    }
}
/// Bundles a rate backend (entropy estimation) with a compression backend
/// (actual compressed sizes); the facade for most public operations.
#[derive(Clone, Default)]
pub struct InfotheoryCtx {
    pub rate_backend: RateBackend,
    pub compression_backend: CompressionBackend,
}
/// Stateful, incremental wrapper around a single rate-backend predictor:
/// feed bytes (adaptively or frozen), read log-probabilities, generate.
pub struct RateBackendSession {
    predictor: crate::mixture::RateBackendPredictor,
}
impl RateBackendSession {
    /// Validates `backend`, builds its predictor, and opens the stream
    /// (optionally announcing `total_symbols` up front).
    pub fn from_backend(
        backend: RateBackend,
        max_order: i64,
        total_symbols: Option<u64>,
    ) -> Result<Self, String> {
        use crate::mixture::OnlineBytePredictor;
        validate_rate_backend(&backend)?;
        let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
            backend,
            max_order,
            crate::mixture::DEFAULT_MIN_PROB,
        );
        predictor.begin_stream(total_symbols)?;
        Ok(Self { predictor })
    }
    /// Adaptive pass: every byte of `data` updates the model.
    pub fn observe(&mut self, data: &[u8]) {
        use crate::mixture::OnlineBytePredictor;
        data.iter().for_each(|&symbol| self.predictor.update(symbol));
    }
    /// Frozen pass: conditions the context on `data` without adapting.
    pub fn condition(&mut self, data: &[u8]) {
        use crate::mixture::OnlineBytePredictor;
        data.iter()
            .for_each(|&symbol| self.predictor.update_frozen(symbol));
    }
    /// Resets frozen state, optionally re-announcing the stream length.
    pub fn reset_frozen(&mut self, total_symbols: Option<u64>) -> Result<(), String> {
        use crate::mixture::OnlineBytePredictor;
        self.predictor.reset_frozen(total_symbols)
    }
    /// Writes the current next-byte log-probability table into `out`.
    pub fn fill_log_probs(&mut self, out: &mut [f64; 256]) {
        use crate::mixture::OnlineBytePredictor;
        self.predictor.fill_log_probs(out);
    }
    /// Samples `bytes` bytes from the model according to `config`.
    pub fn generate_bytes(&mut self, bytes: usize, config: GenerationConfig) -> Vec<u8> {
        use crate::mixture::OnlineBytePredictor;
        if bytes == 0 {
            return Vec::new();
        }
        let mut rng = GenerationRng::new(config.seed);
        let mut log_probs = [0.0f64; 256];
        let mut generated = Vec::with_capacity(bytes);
        while generated.len() < bytes {
            // Rosa predictors are scored one symbol at a time; every other
            // backend fills the whole 256-entry table in a single call.
            if matches!(
                self.predictor,
                crate::mixture::RateBackendPredictor::Rosa { .. }
            ) {
                for (symbol, slot) in log_probs.iter_mut().enumerate() {
                    *slot = self.predictor.log_prob(symbol as u8);
                }
            } else {
                self.predictor.fill_log_probs(&mut log_probs);
            }
            let next = pick_generated_byte(&log_probs, config, &mut rng);
            if config.update_mode == GenerationUpdateMode::Adaptive {
                self.predictor.update(next);
            } else {
                self.predictor.update_frozen(next);
            }
            generated.push(next);
        }
        generated
    }
    /// Finalises the predictor stream via `finish_stream`.
    pub fn finish(&mut self) -> Result<(), String> {
        use crate::mixture::OnlineBytePredictor;
        self.predictor.finish_stream()
    }
}
impl InfotheoryCtx {
    /// Builds a context from explicit rate and compression backends.
    pub fn new(rate_backend: RateBackend, compression_backend: CompressionBackend) -> Self {
        Self {
            rate_backend,
            compression_backend,
        }
    }
    /// Convenience constructor: RosaPlus rate estimation plus zpaq
    /// compression with the given method string.
    pub fn with_zpaq(method: impl Into<String>) -> Self {
        Self {
            rate_backend: RateBackend::RosaPlus,
            compression_backend: CompressionBackend::Zpaq {
                method: method.into(),
            },
        }
    }
    /// Compressed size of `data` under this context's compression backend.
    pub fn compress_size(&self, data: &[u8]) -> u64 {
        compress_size_backend(data, &self.compression_backend)
    }
    /// Compressed size of the concatenation of `parts`.
    pub fn compress_size_chain(&self, parts: &[&[u8]]) -> u64 {
        compress_size_chain_backend(parts, &self.compression_backend)
    }
    /// Opens a stateful prediction session on a clone of the rate backend.
    pub fn rate_backend_session(
        &self,
        max_order: i64,
        total_symbols: Option<u64>,
    ) -> Result<RateBackendSession, String> {
        RateBackendSession::from_backend(self.rate_backend.clone(), max_order, total_symbols)
    }
    /// Entropy rate of `data` (bits per byte) under the rate backend.
    pub fn entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        entropy_rate_backend(data, max_order, &self.rate_backend)
    }
    /// Biased entropy-rate variant; semantics defined by
    /// `biased_entropy_rate_backend` (declared elsewhere in this file).
    pub fn biased_entropy_rate_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        biased_entropy_rate_backend(data, max_order, &self.rate_backend)
    }
    /// Cross-entropy rate of `test_data` under a model trained on
    /// `train_data`.
    pub fn cross_entropy_rate_bytes(
        &self,
        test_data: &[u8],
        train_data: &[u8],
        max_order: i64,
    ) -> f64 {
        cross_entropy_rate_backend(test_data, train_data, max_order, &self.rate_backend)
    }
    /// Cross-entropy in bits/byte. `max_order == 0` selects the order-0
    /// histogram formula; any other order defers to the rate backend.
    pub fn cross_entropy_bytes(&self, test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            if test_data.is_empty() {
                return 0.0;
            }
            let p_x = byte_histogram(test_data);
            let p_y = byte_histogram(train_data);
            let mut h = 0.0f64;
            for i in 0..256 {
                if p_x[i] > 0.0 {
                    // Floor q to avoid log2(0) when train lacks a symbol.
                    let q_y = p_y[i].max(1e-12);
                    h -= p_x[i] * q_y.log2();
                }
            }
            h
        } else {
            self.cross_entropy_rate_bytes(test_data, train_data, max_order)
        }
    }
    /// Joint entropy rate H(X,Y) over the aligned common prefix.
    pub fn joint_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if x.is_empty() {
            return 0.0;
        }
        joint_entropy_rate_backend(x, y, max_order, &self.rate_backend)
    }
    /// Conditional entropy rate H(X|Y) = H(X,Y) - H(Y), clamped at 0.
    pub fn conditional_entropy_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if x.is_empty() {
            return 0.0;
        }
        let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
        let h_y = self.entropy_rate_bytes(y, max_order);
        (h_xy - h_y).max(0.0)
    }
    /// Average bits/byte needed for `data` after conditioning the model on
    /// the concatenation of `prefix_parts`. Each backend family has its own
    /// conditioning mechanism, hence the large dispatch below.
    pub fn cross_entropy_conditional_chain(&self, prefix_parts: &[&[u8]], data: &[u8]) -> f64 {
        match &self.rate_backend {
            // Rosa has no incremental-conditioning API here: materialise the
            // prefix and use the train/test cross-entropy path.
            RateBackend::RosaPlus => {
                let mut prefix = Vec::new();
                let total: usize = prefix_parts.iter().map(|p| p.len()).sum();
                prefix.reserve(total);
                for p in prefix_parts {
                    prefix.extend_from_slice(p);
                }
                cross_entropy_rate_backend(data, &prefix, -1, &RateBackend::RosaPlus)
            }
            // These backends support prequential scoring directly.
            RateBackend::Match { .. }
            | RateBackend::SparseMatch { .. }
            | RateBackend::Ppmd { .. }
            | RateBackend::Calibrated { .. } => {
                prequential_rate_backend(data, prefix_parts, -1, &self.rate_backend)
            }
            // Neural backends delegate to their thread-local compressor.
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| panic!("rwkv conditional-chain scoring failed: {e:#}"))
            }),
            #[cfg(feature = "backend-rwkv")]
            RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| {
                        panic!("rwkv method conditional-chain scoring failed: {e:#}")
                    })
            }),
            #[cfg(feature = "backend-mamba")]
            RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| panic!("mamba conditional-chain scoring failed: {e:#}"))
            }),
            #[cfg(feature = "backend-mamba")]
            RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
                c.cross_entropy_conditional_chain(prefix_parts, data)
                    .unwrap_or_else(|e| {
                        panic!("mamba method conditional-chain scoring failed: {e:#}")
                    })
            }),
            // CTW: feed the prefix bit-by-bit (MSB first), snapshot the log
            // block probability, feed the data, and take the difference as
            // the conditional log-probability of `data`.
            RateBackend::Ctw { depth } => {
                if data.is_empty() {
                    return 0.0;
                }
                let mut tree = crate::ctw::ContextTree::new(*depth);
                for &part in prefix_parts {
                    for &b in part {
                        for i in (0..8).rev() {
                            tree.update(((b >> i) & 1) == 1);
                        }
                    }
                }
                let log_p_prefix = tree.get_log_block_probability();
                for &b in data {
                    for i in (0..8).rev() {
                        tree.update(((b >> i) & 1) == 1);
                    }
                }
                let log_p_joint = tree.get_log_block_probability();
                let log_p_cond = log_p_joint - log_p_prefix;
                // Natural log -> bits, then normalise per byte.
                let bits = -log_p_cond / std::f64::consts::LN_2;
                bits / (data.len() as f64)
            }
            // zpaq rate model: score the prefix parts (discarding the bits),
            // then score `data`, which is already conditional on them.
            RateBackend::Zpaq { method } => {
                if data.is_empty() {
                    return 0.0;
                }
                let mut model =
                    crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
                for &part in prefix_parts {
                    model.update_and_score(part);
                }
                let bits = model.update_and_score(data);
                bits / (data.len() as f64)
            }
            // Mixture: build the runtime, announce the total stream length,
            // step through the prefix, then accumulate bits for `data` only.
            RateBackend::Mixture { spec } => {
                if data.is_empty() {
                    return 0.0;
                }
                let experts = spec.build_experts();
                let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                    .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
                let total = prefix_parts
                    .iter()
                    .map(|p| p.len() as u64)
                    .sum::<u64>()
                    .saturating_add(data.len() as u64);
                mix.begin_stream(Some(total))
                    .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
                for &part in prefix_parts {
                    for &b in part {
                        mix.step(b);
                    }
                }
                let mut bits = 0.0;
                for &b in data {
                    // step() returns a natural-log probability term.
                    bits -= mix.step(b) / std::f64::consts::LN_2;
                }
                bits / (data.len() as f64)
            }
            // Particle filter: same step-through-prefix-then-score pattern.
            RateBackend::Particle { spec } => {
                if data.is_empty() {
                    return 0.0;
                }
                let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
                for &part in prefix_parts {
                    for &b in part {
                        runtime.step(b);
                    }
                }
                let mut bits = 0.0;
                for &b in data {
                    bits -= runtime.step(b) / std::f64::consts::LN_2;
                }
                bits / (data.len() as f64)
            }
            // Factored CTW: like Ctw, but one tree position per encoded bit.
            RateBackend::FacCtw {
                base_depth,
                num_percept_bits: _,
                encoding_bits,
            } => {
                if data.is_empty() {
                    return 0.0;
                }
                let bits_per_byte = (*encoding_bits).clamp(1, 8);
                let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
                // NOTE(review): bits are fed LSB-first here (0..bits_per_byte),
                // whereas the plain Ctw path above feeds MSB-first — confirm
                // this matches FacContextTree's expected bit order.
                for &part in prefix_parts {
                    for &b in part {
                        for i in 0..bits_per_byte {
                            let bit_idx = i;
                            fac.update(((b >> i) & 1) == 1, bit_idx);
                        }
                    }
                }
                let log_p_prefix = fac.get_log_block_probability();
                for &b in data {
                    for i in 0..bits_per_byte {
                        let bit_idx = i;
                        fac.update(((b >> i) & 1) == 1, bit_idx);
                    }
                }
                let log_p_joint = fac.get_log_block_probability();
                let log_p_cond = log_p_joint - log_p_prefix;
                let bits = -log_p_cond / std::f64::consts::LN_2;
                bits / (data.len() as f64)
            }
        }
    }
    /// Generates `bytes` bytes after conditioning on `prompt`, using the
    /// default generation config.
    pub fn generate_bytes(&self, prompt: &[u8], bytes: usize, max_order: i64) -> Vec<u8> {
        self.generate_bytes_with_config(prompt, bytes, max_order, GenerationConfig::default())
    }
    /// Generates `bytes` bytes after conditioning on `prompt` with an
    /// explicit config.
    pub fn generate_bytes_with_config(
        &self,
        prompt: &[u8],
        bytes: usize,
        max_order: i64,
        config: GenerationConfig,
    ) -> Vec<u8> {
        generate_rate_backend_chain(&[prompt], bytes, max_order, &self.rate_backend, config)
    }
    /// Generates bytes conditioned on a chain of prefix parts, default config.
    pub fn generate_bytes_conditional_chain(
        &self,
        prefix_parts: &[&[u8]],
        bytes: usize,
        max_order: i64,
    ) -> Vec<u8> {
        self.generate_bytes_conditional_chain_with_config(
            prefix_parts,
            bytes,
            max_order,
            GenerationConfig::default(),
        )
    }
    /// Generates bytes conditioned on a chain of prefix parts with an
    /// explicit config.
    pub fn generate_bytes_conditional_chain_with_config(
        &self,
        prefix_parts: &[&[u8]],
        bytes: usize,
        max_order: i64,
        config: GenerationConfig,
    ) -> Vec<u8> {
        generate_rate_backend_chain(prefix_parts, bytes, max_order, &self.rate_backend, config)
    }
    /// Normalized compression distance of `x` and `y` under the
    /// compression backend.
    pub fn ncd_bytes(&self, x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
        ncd_bytes_backend(x, y, &self.compression_backend, variant)
    }
    /// Rate-based mutual information (see the free function of the same name).
    pub fn mutual_information_rate_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        mutual_information_rate_backend(x, y, max_order, &self.rate_backend)
    }
    /// Mutual information; `max_order == 0` uses marginal (order-0)
    /// estimates, otherwise the rate backend.
    pub fn mutual_information_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            mutual_information_marg_bytes(x, y)
        } else {
            self.mutual_information_rate_bytes(x, y, max_order)
        }
    }
    /// Conditional entropy H(X|Y) = H(X,Y) - H(Y), order-0 or rate-based.
    pub fn conditional_entropy_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        if max_order == 0 {
            let h_xy = joint_marginal_entropy_bytes(x, y);
            let h_y = marginal_entropy_bytes(y);
            (h_xy - h_y).max(0.0)
        } else {
            let h_xy = self.joint_entropy_rate_bytes(x, y, max_order);
            let h_y = self.entropy_rate_bytes(y, max_order);
            (h_xy - h_y).max(0.0)
        }
    }
    /// Normalized entropy distance, order-0 or rate-based.
    pub fn ned_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            ned_marg_bytes(x, y)
        } else {
            ned_rate_backend(x, y, max_order, &self.rate_backend)
        }
    }
    /// Conservative NED variant: normalises by the joint entropy H(X,Y)
    /// rather than max(H(X),H(Y)) as `ned_rate_backend` does.
    /// NOTE(review): the different denominator appears deliberate ("cons" =
    /// conservative?) — confirm against the metric's definition.
    pub fn ned_cons_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        let (x, y) = aligned_prefix(x, y);
        let (h_x, h_y, h_xy) = if max_order == 0 {
            (
                marginal_entropy_bytes(x),
                marginal_entropy_bytes(y),
                joint_marginal_entropy_bytes(x, y),
            )
        } else {
            (
                self.entropy_rate_bytes(x, max_order),
                self.entropy_rate_bytes(y, max_order),
                self.joint_entropy_rate_bytes(x, y, max_order),
            )
        };
        let min_h = h_x.min(h_y);
        if h_xy == 0.0 {
            0.0
        } else {
            ((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
        }
    }
    /// Normalized variation-of-information distance, order-0 or rate-based.
    pub fn nte_bytes(&self, x: &[u8], y: &[u8], max_order: i64) -> f64 {
        if max_order == 0 {
            nte_marg_bytes(x, y)
        } else {
            nte_rate_backend(x, y, max_order, &self.rate_backend)
        }
    }
    /// Fraction of the marginal entropy explained by internal structure:
    /// (H_marginal - H_rate) / H_marginal, clamped into [0, 1]. Near-zero
    /// marginal entropy short-circuits to 0 to avoid division blow-up.
    pub fn intrinsic_dependence_bytes(&self, data: &[u8], max_order: i64) -> f64 {
        let h_marginal = marginal_entropy_bytes(data);
        if h_marginal < 1e-9 {
            return 0.0;
        }
        let h_rate = self.entropy_rate_bytes(data, max_order);
        ((h_marginal - h_rate) / h_marginal).clamp(0.0, 1.0)
    }
    /// How much information about `x` survives the transformation to `tx`:
    /// I(x; tx) / H(x), clamped into [0, 1].
    pub fn resistance_to_transformation_bytes(&self, x: &[u8], tx: &[u8], max_order: i64) -> f64 {
        let (x, tx) = aligned_prefix(x, tx);
        let h_x = if max_order == 0 {
            marginal_entropy_bytes(x)
        } else {
            self.entropy_rate_bytes(x, max_order)
        };
        if h_x < 1e-9 {
            return 0.0;
        }
        let mi = self.mutual_information_bytes(x, tx, max_order);
        (mi / h_x).clamp(0.0, 1.0)
    }
}
/// Loads an RWKV7 model from `path`, panicking on failure.
#[cfg(feature = "backend-rwkv")]
pub fn load_rwkv7_model_from_path(path: &str) -> Arc<rwkvzip::Model> {
    rwkvzip::Compressor::load_model(path).expect("failed to load RWKV7 model")
}
/// Loads a Mamba model from `path`, panicking on failure.
#[cfg(feature = "backend-mamba")]
pub fn load_mamba_model_from_path(path: &str) -> Arc<mambazip::Model> {
    mambazip::Compressor::load_model(path).expect("failed to load Mamba model")
}
/// Truncates both slices to their shared prefix length so that pairwise
/// metrics always compare sequences of equal length.
#[inline(always)]
fn aligned_prefix<'a>(x: &'a [u8], y: &'a [u8]) -> (&'a [u8], &'a [u8]) {
    let shared = usize::min(x.len(), y.len());
    let (x_head, _) = x.split_at(shared);
    let (y_head, _) = y.split_at(shared);
    (x_head, y_head)
}
/// Compressed size of `data` (bytes) via zpaq with `method`.
/// Compression errors collapse to a size of 0.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_bytes(data: &[u8], method: &str) -> u64 {
    zpaq_rs::compress_size(data, method).unwrap_or(0)
}
/// Stub when the zpaq backend is compiled out: panics on use.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_bytes(_data: &[u8], _method: &str) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}
/// Parallel variant of `zpaq_compress_size_bytes` using `threads` workers.
/// Errors collapse to a size of 0.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_parallel_bytes(data: &[u8], method: &str, threads: usize) -> u64 {
    zpaq_rs::compress_size_parallel(data, method, threads).unwrap_or(0)
}
/// Stub when the zpaq backend is compiled out: panics on use.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_parallel_bytes(_data: &[u8], _method: &str, _threads: usize) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}
/// Streaming compressed-size computation: consumes `reader` without
/// materialising the full input in memory. Errors collapse to 0.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_size_stream<R: std::io::Read + Send>(reader: R, method: &str) -> u64 {
    zpaq_rs::compress_size_stream(reader, method, None, None).unwrap_or(0)
}
/// Stub when the zpaq backend is compiled out: panics on use.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_size_stream<R: std::io::Read + Send>(_reader: R, _method: &str) -> u64 {
    panic!("CompressionBackend::Zpaq is unavailable: build with feature 'backend-zpaq'")
}
/// Compresses `data` into an owned buffer. Unlike the size-only helpers,
/// errors are propagated instead of swallowed.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_compress_to_vec(data: &[u8], method: &str) -> anyhow::Result<Vec<u8>> {
    Ok(zpaq_rs::compress_to_vec(data, method)?)
}
/// Stub when the zpaq backend is compiled out: returns an error.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_compress_to_vec(_data: &[u8], _method: &str) -> anyhow::Result<Vec<u8>> {
    anyhow::bail!("zpaq backend disabled at compile time (enable feature 'backend-zpaq')")
}
/// Decompresses a zpaq archive into an owned buffer; errors propagate.
#[cfg(feature = "backend-zpaq")]
#[inline(always)]
fn zpaq_decompress_to_vec(data: &[u8]) -> anyhow::Result<Vec<u8>> {
    Ok(zpaq_rs::decompress_to_vec(data)?)
}
/// Stub when the zpaq backend is compiled out: returns an error.
#[cfg(not(feature = "backend-zpaq"))]
#[inline(always)]
fn zpaq_decompress_to_vec(_data: &[u8]) -> anyhow::Result<Vec<u8>> {
    anyhow::bail!("zpaq backend disabled at compile time (enable feature 'backend-zpaq')")
}
/// Reads the file at `path` and returns its zpaq-compressed size in bytes.
///
/// # Panics
/// Panics if the file cannot be read (the message names the path), or —
/// when built without `backend-zpaq` — on any call, via the zpaq stub.
#[inline(always)]
pub fn get_compressed_size(path: &str, method: &str) -> u64 {
    // Fix: replace the bare `unwrap()` with a panic that names the offending
    // path, so I/O failures are diagnosable instead of "called `unwrap()` on
    // an `Err` value".
    let data = std::fs::read(path)
        .unwrap_or_else(|err| panic!("failed to read '{path}': {err}"));
    zpaq_compress_size_bytes(&data, method)
}
/// Validates a zpaq rate-method string. Always returns an error when the
/// zpaq backend is compiled out; exactly one cfg branch survives.
pub fn validate_zpaq_rate_method(method: &str) -> Result<(), String> {
    #[cfg(feature = "backend-zpaq")]
    {
        zpaq_rate::validate_zpaq_rate_method(method)
    }
    #[cfg(not(feature = "backend-zpaq"))]
    {
        let _ = method;
        Err("zpaq backend disabled at compile time".to_string())
    }
}
/// Runs `f` against this thread's cached compressor for `model`, creating it
/// on first use. The compressor lives in TLS, so mutations made by `f`
/// persist across calls on the same thread.
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_tls<R>(
    model: &Arc<rwkvzip::Model>,
    f: impl FnOnce(&mut rwkvzip::Compressor) -> R,
) -> R {
    // Key by the Arc's pointer identity: same model object => same slot.
    let key = Arc::as_ptr(model) as usize;
    RWKV_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let comp = map
            .entry(key)
            .or_insert_with(|| rwkvzip::Compressor::new_from_model(model.clone()));
        f(comp)
    })
}
/// Runs `f` on a FRESH clone of the cached template compressor for `method`.
/// Unlike `with_rwkv_tls`, mutations made by `f` are discarded afterwards:
/// the template stored in the map stays pristine, so every call starts from
/// a clean model state while method construction happens only once.
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_method_tls<R>(method: &str, f: impl FnOnce(&mut rwkvzip::Compressor) -> R) -> R {
    RWKV_METHOD_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(method) {
            template.clone()
        } else {
            let template = rwkvzip::Compressor::new_from_method(method).unwrap_or_else(|e| {
                panic!("invalid rwkv method '{method}': {e:#}");
            });
            map.insert(method.to_string(), template.clone());
            template
        };
        // Release the borrow before running `f`, which may re-enter TLS.
        drop(map);
        f(&mut comp)
    })
}
/// Like `with_rwkv_method_tls` but keyed by model pointer: `f` runs on a
/// fresh clone of a per-model template, so state never leaks between calls.
#[cfg(feature = "backend-rwkv")]
fn with_rwkv_rate_tls<R>(
    model: &Arc<rwkvzip::Model>,
    f: impl FnOnce(&mut rwkvzip::Compressor) -> R,
) -> R {
    let key = Arc::as_ptr(model) as usize;
    RWKV_RATE_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(&key) {
            template.clone()
        } else {
            let template = rwkvzip::Compressor::new_from_model(model.clone());
            map.insert(key, template.clone());
            template
        };
        // Release the borrow before running `f`, which may re-enter TLS.
        drop(map);
        f(&mut comp)
    })
}
/// Mamba analogue of `with_rwkv_tls`: runs `f` on this thread's persistent
/// cached compressor for `model`, creating it on first use.
#[cfg(feature = "backend-mamba")]
fn with_mamba_tls<R>(
    model: &Arc<mambazip::Model>,
    f: impl FnOnce(&mut mambazip::Compressor) -> R,
) -> R {
    // Key by the Arc's pointer identity: same model object => same slot.
    let key = Arc::as_ptr(model) as usize;
    MAMBA_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let comp = map
            .entry(key)
            .or_insert_with(|| mambazip::Compressor::new_from_model(model.clone()));
        f(comp)
    })
}
/// Mamba analogue of `with_rwkv_rate_tls`: `f` runs on a fresh clone of a
/// per-model template, so mutations are discarded after the call.
#[cfg(feature = "backend-mamba")]
fn with_mamba_rate_tls<R>(
    model: &Arc<mambazip::Model>,
    f: impl FnOnce(&mut mambazip::Compressor) -> R,
) -> R {
    let key = Arc::as_ptr(model) as usize;
    MAMBA_RATE_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(&key) {
            template.clone()
        } else {
            let template = mambazip::Compressor::new_from_model(model.clone());
            map.insert(key, template.clone());
            template
        };
        // Release the borrow before running `f`, which may re-enter TLS.
        drop(map);
        f(&mut comp)
    })
}
/// Mamba analogue of `with_rwkv_method_tls`: `f` runs on a fresh clone of
/// the per-method template; the cached template stays pristine.
#[cfg(feature = "backend-mamba")]
fn with_mamba_method_tls<R>(method: &str, f: impl FnOnce(&mut mambazip::Compressor) -> R) -> R {
    MAMBA_METHOD_TLS.with(|cell| {
        let mut map = cell.borrow_mut();
        let mut comp = if let Some(template) = map.get(method) {
            template.clone()
        } else {
            let template = mambazip::Compressor::new_from_method(method).unwrap_or_else(|e| {
                panic!("invalid mamba method '{method}': {e:#}");
            });
            map.insert(method.to_string(), template.clone());
            template
        };
        // Release the borrow before running `f`, which may re-enter TLS.
        drop(map);
        f(&mut comp)
    })
}
/// Zero-copy reader presenting several byte slices as one `Read` stream.
struct SliceChainReader<'a> {
    // Slices to stream through, in order.
    parts: &'a [&'a [u8]],
    // Index of the slice currently being consumed.
    i: usize,
    // Byte offset into `parts[i]`.
    off: usize,
}
impl<'a> SliceChainReader<'a> {
    /// Creates a reader positioned at the start of the first slice.
    fn new(parts: &'a [&'a [u8]]) -> Self {
        Self { parts, i: 0, off: 0 }
    }
}
impl<'a> std::io::Read for SliceChainReader<'a> {
    /// Fills `buf` from the remaining slices; exhausted (or empty) parts are
    /// skipped. Returns `Ok(0)` only when `buf` is empty or every part has
    /// been fully consumed. Never errors.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let mut written = 0usize;
        while written < buf.len() && self.i < self.parts.len() {
            let part = self.parts[self.i];
            if self.off >= part.len() {
                // Current slice is done; move on to the next one.
                self.i += 1;
                self.off = 0;
                continue;
            }
            let n = (part.len() - self.off).min(buf.len() - written);
            buf[written..written + n].copy_from_slice(&part[self.off..self.off + n]);
            self.off += n;
            written += n;
        }
        Ok(written)
    }
}
/// Compressed size in bytes of the concatenation of `parts` under `backend`,
/// without materializing the joined buffer on the zpaq path.
pub fn compress_size_chain_backend(parts: &[&[u8]], backend: &CompressionBackend) -> u64 {
    match backend {
        CompressionBackend::Zpaq { method } => {
            // Stream the parts so zpaq never needs an intermediate copy.
            let r = SliceChainReader::new(parts);
            zpaq_compress_size_stream(r, method.as_str())
        }
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, coder } => {
            with_rwkv_tls(model, |c| c.compress_size_chain(parts, *coder).unwrap_or(0))
        }
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => {
            // Compression errors are flattened to a size of 0.
            crate::compression::compress_rate_size_chain(parts, rate_backend, -1, *coder, *framing)
                .unwrap_or(0)
        }
    }
}
/// Compressed size in bytes of `data` under `backend`. Errors from the
/// rwkv/rate paths are flattened to 0.
pub fn compress_size_backend(data: &[u8], backend: &CompressionBackend) -> u64 {
    match backend {
        CompressionBackend::Zpaq { method } => zpaq_compress_size_bytes(data, method.as_str()),
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, coder } => {
            with_rwkv_tls(model, |c| c.compress_size(data, *coder).unwrap_or(0))
        }
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => crate::compression::compress_rate_size(data, rate_backend, -1, *coder, *framing)
            .unwrap_or(0),
    }
}
/// Compresses `data` with `backend` and returns the encoded bytes.
///
/// # Errors
/// Propagates any error from the underlying compressor.
pub fn compress_bytes_backend(
    data: &[u8],
    backend: &CompressionBackend,
) -> anyhow::Result<Vec<u8>> {
    match backend {
        CompressionBackend::Zpaq { method } => zpaq_compress_to_vec(data, method),
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, coder } => {
            with_rwkv_tls(model, |c| c.compress(data, *coder))
        }
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => crate::compression::compress_rate_bytes(data, rate_backend, -1, *coder, *framing),
    }
}
/// Decompresses `input` that was produced by `compress_bytes_backend` with
/// the same backend configuration.
///
/// # Errors
/// Propagates any error from the underlying decompressor.
pub fn decompress_bytes_backend(
    input: &[u8],
    backend: &CompressionBackend,
) -> anyhow::Result<Vec<u8>> {
    match backend {
        // Zpaq streams are self-describing, so the method is not needed here.
        CompressionBackend::Zpaq { .. } => zpaq_decompress_to_vec(input),
        #[cfg(feature = "backend-rwkv")]
        CompressionBackend::Rwkv7 { model, .. } => with_rwkv_tls(model, |c| c.decompress(input)),
        CompressionBackend::Rate {
            rate_backend,
            coder,
            framing,
        } => crate::compression::decompress_rate_bytes(input, rate_backend, -1, *coder, *framing),
    }
}
/// Prequential (online) code length of `data` in bits per byte under a rate
/// backend: each byte is scored against the model state accumulated so far,
/// then fed back as an update. `prefix_parts` are replayed first as
/// conditioning context — they update the model but contribute no bits.
///
/// Returns 0.0 for empty `data`. Panics if the backend's stream protocol
/// fails to initialize or finalize.
fn prequential_rate_backend(
    data: &[u8],
    prefix_parts: &[&[u8]],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    use crate::mixture::OnlineBytePredictor;
    if data.is_empty() {
        return 0.0;
    }
    // Total stream length (prefix + scored data) is passed to the backend as
    // a size hint for stream setup.
    let total = prefix_parts
        .iter()
        .map(|p| p.len() as u64)
        .sum::<u64>()
        .saturating_add(data.len() as u64);
    let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
        backend.clone(),
        max_order,
        crate::mixture::DEFAULT_MIN_PROB,
    );
    predictor
        .begin_stream(Some(total))
        .unwrap_or_else(|e| panic!("rate backend stream init failed: {e}"));
    // Warm up on the prefix without accumulating code length.
    for prefix in prefix_parts {
        for &b in *prefix {
            predictor.update(b);
        }
    }
    let mut bits = 0.0;
    for &b in data {
        // log_prob is natural log; divide by ln 2 to convert to bits.
        bits -= predictor.log_prob(b) / std::f64::consts::LN_2;
        predictor.update(b);
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend stream finalize failed: {e}"));
    bits / (data.len() as f64)
}
/// Frozen plug-in cross entropy in bits per byte: fit the backend on
/// `fit_parts`, freeze it, then score `score_data` without further
/// adaptation. RosaPlus and the neural backends use their native
/// fit-then-score APIs; everything else goes through the generic
/// `RateBackendPredictor` two-pass protocol.
///
/// Returns 0.0 for empty `score_data`. Panics on backend protocol failures.
fn frozen_plugin_rate_backend(
    score_data: &[u8],
    fit_parts: &[&[u8]],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    if score_data.is_empty() {
        return 0.0;
    }
    // RosaPlus exposes an explicit train / build / score API.
    if matches!(backend, RateBackend::RosaPlus) {
        let mut model = rosaplus::RosaPlus::new(max_order, false, 0, 42);
        for part in fit_parts {
            model.train_example(part);
        }
        model.build_lm();
        return model.cross_entropy(score_data);
    }
    // Neural backends handle the frozen-plugin protocol internally.
    #[cfg(feature = "backend-rwkv")]
    match backend {
        RateBackend::Rwkv7 { model } => {
            return with_rwkv_rate_tls(model, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("rwkv frozen-plugin scoring failed: {e:#}"))
            });
        }
        RateBackend::Rwkv7Method { method } => {
            return with_rwkv_method_tls(method, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("rwkv method frozen-plugin scoring failed: {e:#}"))
            });
        }
        _ => {}
    }
    #[cfg(feature = "backend-mamba")]
    match backend {
        RateBackend::Mamba { model } => {
            return with_mamba_rate_tls(model, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("mamba frozen-plugin scoring failed: {e:#}"))
            });
        }
        RateBackend::MambaMethod { method } => {
            return with_mamba_method_tls(method, |c| {
                c.cross_entropy_frozen_plugin_chain(fit_parts, score_data)
                    .unwrap_or_else(|e| panic!("mamba method frozen-plugin scoring failed: {e:#}"))
            });
        }
        _ => {}
    }
    use crate::mixture::OnlineBytePredictor;
    // Pass 1: adapt the predictor over all fit material.
    let fit_total = fit_parts.iter().map(|part| part.len() as u64).sum::<u64>();
    let mut predictor = crate::mixture::RateBackendPredictor::from_backend(
        backend.clone(),
        max_order,
        crate::mixture::DEFAULT_MIN_PROB,
    );
    predictor
        .begin_stream(Some(fit_total))
        .unwrap_or_else(|e| panic!("rate backend fit-pass init failed: {e}"));
    for part in fit_parts {
        for &byte in *part {
            predictor.update(byte);
        }
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend fit-pass finalize failed: {e}"));
    // Pass 2: reset into frozen mode and accumulate code length (bits) on
    // the score data.
    predictor
        .reset_frozen(Some(score_data.len() as u64))
        .unwrap_or_else(|e| panic!("rate backend frozen-score reset failed: {e}"));
    let mut bits = 0.0;
    for &byte in score_data {
        bits -= predictor.log_prob(byte) / std::f64::consts::LN_2;
        predictor.update_frozen(byte);
    }
    predictor
        .finish_stream()
        .unwrap_or_else(|e| panic!("rate backend frozen-score finalize failed: {e}"));
    bits / (score_data.len() as f64)
}
/// Byte with the highest finite log-probability; non-finite entries are
/// treated as negative infinity. Ties resolve to the lowest byte value, and
/// an all-non-finite distribution yields byte 0.
#[inline(always)]
fn argmax_log_prob_byte(logps: &[f64; 256]) -> u8 {
    // Fold with strict `>` keeps the FIRST maximum, preserving the
    // lowest-byte tie-break.
    let (winner, _) = logps
        .iter()
        .map(|&lp| if lp.is_finite() { lp } else { f64::NEG_INFINITY })
        .enumerate()
        .fold((0usize, f64::NEG_INFINITY), |acc, (idx, score)| {
            if score > acc.1 { (idx, score) } else { acc }
        });
    winner as u8
}
/// Selects the next generated byte from a 256-way log-probability vector.
///
/// Greedy strategy — or a non-finite / non-positive temperature — returns
/// the argmax. Otherwise the distribution is temperature-scaled, sorted
/// descending, truncated to top-k (0 means keep all), renormalized, further
/// truncated by nucleus (top-p) mass, and finally sampled with `rng`.
/// Every degenerate case (no finite mass, zero weight total) falls back to
/// the argmax so a byte is always produced.
fn pick_generated_byte(
    logps: &[f64; 256],
    config: GenerationConfig,
    rng: &mut GenerationRng,
) -> u8 {
    if matches!(config.strategy, GenerationStrategy::Greedy)
        || !config.temperature.is_finite()
        || config.temperature <= 0.0
    {
        return argmax_log_prob_byte(logps);
    }
    // Temperature-scale; non-finite entries are excluded via -inf.
    let mut entries = [(0u8, f64::NEG_INFINITY); 256];
    for (idx, &logp) in logps.iter().enumerate() {
        let scaled = if logp.is_finite() {
            logp / config.temperature
        } else {
            f64::NEG_INFINITY
        };
        entries[idx] = (idx as u8, scaled);
    }
    // Sort descending by scaled log-prob (total_cmp handles -inf safely).
    entries.sort_by(|a, b| b.1.total_cmp(&a.1));
    let keep_k = if config.top_k == 0 {
        entries.len()
    } else {
        config.top_k.min(entries.len())
    };
    let top_p = if config.top_p.is_finite() {
        config.top_p.clamp(0.0, 1.0)
    } else {
        1.0
    };
    // Max over the kept entries anchors the exp() for numerical stability.
    let mut max_logp = f64::NEG_INFINITY;
    for &(_, logp) in entries.iter().take(keep_k) {
        if logp.is_finite() {
            max_logp = max_logp.max(logp);
        }
    }
    if !max_logp.is_finite() {
        // No finite mass survived top-k truncation.
        return argmax_log_prob_byte(logps);
    }
    // Unnormalized weights relative to the max (so the largest weight is 1).
    let mut weights = [(0u8, 0.0f64); 256];
    let mut total = 0.0;
    for (idx, &(byte, logp)) in entries.iter().take(keep_k).enumerate() {
        let w = if logp.is_finite() {
            (logp - max_logp).exp()
        } else {
            0.0
        };
        weights[idx] = (byte, w);
        total += w;
    }
    if !(total.is_finite()) || total <= 0.0 {
        return argmax_log_prob_byte(logps);
    }
    // Nucleus truncation: keep the smallest prefix whose normalized mass
    // reaches top_p (always at least one entry).
    let cutoff_count = if top_p >= 1.0 {
        keep_k
    } else {
        let mut cumulative = 0.0;
        let mut keep = 0usize;
        for &(_, w) in weights.iter().take(keep_k) {
            cumulative += w / total;
            keep += 1;
            if cumulative >= top_p {
                break;
            }
        }
        keep.max(1)
    };
    let mut truncated_total = 0.0;
    for &(_, w) in weights.iter().take(cutoff_count) {
        truncated_total += w;
    }
    if !(truncated_total.is_finite()) || truncated_total <= 0.0 {
        return argmax_log_prob_byte(logps);
    }
    // Inverse-CDF sampling over the surviving weights.
    let target = rng.next_f64() * truncated_total;
    let mut cumulative = 0.0;
    let mut picked = weights[0].0;
    for &(byte, weight) in weights.iter().take(cutoff_count) {
        cumulative += weight;
        if cumulative >= target {
            picked = byte;
            break;
        }
    }
    picked
}
/// Generates `bytes` bytes from `backend` after conditioning on the
/// concatenation of `prefix_parts`. Returns an empty vector when zero
/// bytes are requested. Panics if the backend session fails to initialize
/// or finalize.
fn generate_rate_backend_chain(
    prefix_parts: &[&[u8]],
    bytes: usize,
    max_order: i64,
    backend: &RateBackend,
    config: GenerationConfig,
) -> Vec<u8> {
    if bytes == 0 {
        return Vec::new();
    }
    // Size hint covers the prefix plus the bytes to be generated.
    let total = prefix_parts
        .iter()
        .map(|p| p.len() as u64)
        .sum::<u64>()
        .saturating_add(bytes as u64);
    let mut session = RateBackendSession::from_backend(backend.clone(), max_order, Some(total))
        .unwrap_or_else(|e| panic!("rate backend generation init failed: {e}"));
    for &part in prefix_parts {
        session.observe(part);
    }
    let out = session.generate_bytes(bytes, config);
    session
        .finish()
        .unwrap_or_else(|e| panic!("rate backend generation finalize failed: {e}"));
    out
}
/// Predictive entropy rate of `data` in bits per byte under the chosen rate
/// backend. Each backend family uses its own estimation scheme; empty input
/// yields 0.0 on all paths that check it explicitly.
pub fn entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
    match backend {
        RateBackend::RosaPlus => {
            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
            m.predictive_entropy_rate(data)
        }
        // Generic online predictors: prequential scoring with no prefix.
        RateBackend::Match { .. }
        | RateBackend::SparseMatch { .. }
        | RateBackend::Ppmd { .. }
        | RateBackend::Calibrated { .. } => prequential_rate_backend(data, &[], max_order, backend),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("rwkv entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("rwkv method entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("mamba entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
            c.cross_entropy(data)
                .unwrap_or_else(|e| panic!("mamba method entropy scoring failed: {e:#}"))
        }),
        RateBackend::Zpaq { method } => {
            if data.is_empty() {
                return 0.0;
            }
            // 2^-24 is the probability floor handed to the zpaq rate model.
            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
            let bits = model.update_and_score(data);
            bits / (data.len() as f64)
        }
        RateBackend::Mixture { spec } => {
            if data.is_empty() {
                return 0.0;
            }
            let experts = spec.build_experts();
            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
            mix.begin_stream(Some(data.len() as u64))
                .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
            let mut bits = 0.0;
            for &b in data {
                // step() returns a natural-log score; convert to bits.
                bits -= mix.step(b) / std::f64::consts::LN_2;
            }
            mix.finish_stream()
                .unwrap_or_else(|e| panic!("Mixture stream finalize failed: {e}"));
            bits / (data.len() as f64)
        }
        RateBackend::Particle { spec } => {
            if data.is_empty() {
                return 0.0;
            }
            let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
            let mut bits = 0.0;
            for &b in data {
                bits -= runtime.step(b) / std::f64::consts::LN_2;
            }
            bits / (data.len() as f64)
        }
        RateBackend::Ctw { depth } => {
            if data.is_empty() {
                return 0.0;
            }
            // 8 bits per symbol, fed most-significant-bit first.
            let mut fac = crate::ctw::FacContextTree::new(*depth, 8);
            fac.reserve_for_symbols(data.len());
            for &b in data {
                fac.update_byte_msb(b);
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (data.len() as f64)
        }
        RateBackend::FacCtw {
            base_depth,
            num_percept_bits: _,
            encoding_bits,
        } => {
            if data.is_empty() {
                return 0.0;
            }
            // Only the low `encoding_bits` of each byte are modeled (LSB-first).
            let bits_per_byte = (*encoding_bits).clamp(1, 8);
            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte);
            fac.reserve_for_symbols(data.len());
            for &b in data {
                fac.update_byte_lsb(b);
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (data.len() as f64)
        }
    }
}
/// Plug-in ("biased") entropy estimate: fit the backend on `data`, then
/// score the same `data` with the fitted model frozen. Zpaq backends cannot
/// be frozen and are rejected up front.
///
/// # Panics
/// Panics for `RateBackend::Zpaq`.
pub fn biased_entropy_rate_backend(data: &[u8], max_order: i64, backend: &RateBackend) -> f64 {
    if matches!(backend, RateBackend::Zpaq { .. }) {
        panic!("biased/plugin entropy is not supported for zpaq rate backends in 1.1.1");
    }
    frozen_plugin_rate_backend(data, &[data], max_order, backend)
}
/// Cross entropy of `test_data` in bits per byte under a model fitted on
/// `train_data`. Returns 0.0 for empty `test_data` on the zpaq path; other
/// backends handle that inside `frozen_plugin_rate_backend`.
pub fn cross_entropy_rate_backend(
    test_data: &[u8],
    train_data: &[u8],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    match backend {
        RateBackend::Zpaq { method } => {
            if test_data.is_empty() {
                return 0.0;
            }
            // Zpaq cannot be frozen: the training bits are discarded and the
            // model keeps adapting while it scores the test data.
            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
            model.update_and_score(train_data);
            let bits = model.update_and_score(test_data);
            bits / (test_data.len() as f64)
        }
        _ => frozen_plugin_rate_backend(test_data, &[train_data], max_order, backend),
    }
}
/// Joint entropy rate of the aligned byte streams `x` and `y`, in bits per
/// aligned byte pair. Inputs are truncated to their common prefix; returns
/// 0.0 when that prefix is empty.
///
/// Most backends score the byte-interleaved stream x0 y0 x1 y1 …;
/// RosaPlus scores 16-bit joint symbols, neural backends use their own
/// aligned joint-scoring entry points, and the CTW paths interleave at the
/// bit level.
pub fn joint_entropy_rate_backend(
    x: &[u8],
    y: &[u8],
    max_order: i64,
    backend: &RateBackend,
) -> f64 {
    // Byte-interleaves two equal-length streams: x0 y0 x1 y1 ...
    // Extracted because four backend arms previously duplicated this loop.
    fn interleave(x: &[u8], y: &[u8]) -> Vec<u8> {
        let mut joint = Vec::with_capacity(x.len() * 2);
        for (&xb, &yb) in x.iter().zip(y.iter()) {
            joint.push(xb);
            joint.push(yb);
        }
        joint
    }
    let (x, y) = aligned_prefix(x, y);
    if x.is_empty() {
        return 0.0;
    }
    match backend {
        RateBackend::RosaPlus => {
            // One 16-bit symbol per aligned pair: x[i] * 256 + y[i].
            let joint_symbols: Vec<u32> = (0..x.len())
                .map(|i| (x[i] as u32) * 256 + (y[i] as u32))
                .collect();
            let mut m = rosaplus::RosaPlus::new(max_order, false, 0, 42);
            m.entropy_rate_cps(&joint_symbols)
        }
        RateBackend::Match { .. }
        | RateBackend::SparseMatch { .. }
        | RateBackend::Ppmd { .. }
        | RateBackend::Calibrated { .. } => {
            // entropy_rate_backend reports bits per interleaved byte; ×2
            // converts to bits per aligned pair.
            entropy_rate_backend(&interleave(x, y), max_order, backend) * 2.0
        }
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7 { model } => with_rwkv_tls(model, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("rwkv joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-rwkv")]
        RateBackend::Rwkv7Method { method } => with_rwkv_method_tls(method, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("rwkv method joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::Mamba { model } => with_mamba_tls(model, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("mamba joint-entropy scoring failed: {e:#}"))
        }),
        #[cfg(feature = "backend-mamba")]
        RateBackend::MambaMethod { method } => with_mamba_method_tls(method, |c| {
            c.joint_cross_entropy_aligned_min(x, y)
                .unwrap_or_else(|e| panic!("mamba method joint-entropy scoring failed: {e:#}"))
        }),
        RateBackend::Zpaq { method } => {
            let joint = interleave(x, y);
            let mut model = crate::zpaq_rate::ZpaqRateModel::new(method.clone(), 2f64.powi(-24));
            let bits = model.update_and_score(&joint);
            bits / (x.len() as f64)
        }
        RateBackend::Mixture { spec } => {
            let joint = interleave(x, y);
            let experts = spec.build_experts();
            let mut mix = crate::mixture::build_mixture_runtime(spec.as_ref(), &experts)
                .unwrap_or_else(|e| panic!("MixtureSpec invalid: {e}"));
            mix.begin_stream(Some(joint.len() as u64))
                .unwrap_or_else(|e| panic!("Mixture stream init failed: {e}"));
            let mut bits = 0.0;
            for &b in &joint {
                bits -= mix.step(b) / std::f64::consts::LN_2;
            }
            mix.finish_stream()
                .unwrap_or_else(|e| panic!("Mixture stream finalize failed: {e}"));
            bits / (x.len() as f64)
        }
        RateBackend::Particle { spec } => {
            let joint = interleave(x, y);
            let mut runtime = crate::particle::ParticleRuntime::new(spec.as_ref());
            let mut bits = 0.0;
            for &b in &joint {
                bits -= runtime.step(b) / std::f64::consts::LN_2;
            }
            bits / (x.len() as f64)
        }
        RateBackend::Ctw { depth } => {
            // 16 bits per aligned pair, MSB-first, x bit then y bit per plane.
            let mut fac = crate::ctw::FacContextTree::new(*depth, 16);
            for k in 0..x.len() {
                let bx = x[k];
                let by = y[k];
                for bit_idx in 0..8 {
                    let bit_x = ((bx >> (7 - bit_idx)) & 1) == 1;
                    let bit_y = ((by >> (7 - bit_idx)) & 1) == 1;
                    fac.update(bit_x, bit_idx);
                    fac.update(bit_y, bit_idx + 8);
                }
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (x.len() as f64)
        }
        RateBackend::FacCtw {
            base_depth,
            num_percept_bits: _,
            encoding_bits,
        } => {
            // Low `encoding_bits` of each byte, LSB-first, x/y bits paired.
            let bits_per_byte = (*encoding_bits).clamp(1, 8);
            let mut fac = crate::ctw::FacContextTree::new(*base_depth, bits_per_byte * 2);
            for k in 0..x.len() {
                let bx = x[k];
                let by = y[k];
                for i in 0..bits_per_byte {
                    let bit_idx_x = i * 2;
                    let bit_idx_y = bit_idx_x + 1;
                    fac.update(((bx >> i) & 1) == 1, bit_idx_x);
                    fac.update(((by >> i) & 1) == 1, bit_idx_y);
                }
            }
            let ln_p = fac.get_log_block_probability();
            let bits = -ln_p / std::f64::consts::LN_2;
            bits / (x.len() as f64)
        }
    }
}
/// Reads `path` and returns its parallel-zpaq compressed size in bytes.
///
/// # Panics
/// Panics if the file cannot be read. Uses `expect("failed to read file")`
/// to match the error style of `get_bytes_from_paths` instead of a bare
/// `unwrap()`.
#[inline(always)]
pub fn get_compressed_size_parallel(path: &str, method: &str, threads: usize) -> u64 {
    let data = std::fs::read(path).expect("failed to read file");
    zpaq_compress_size_parallel_bytes(&data, method, threads)
}
/// Reads every path in parallel (rayon) and returns the raw file contents
/// in the same order as `paths`.
///
/// # Panics
/// Panics if any file cannot be read.
#[inline(always)]
pub fn get_bytes_from_paths(paths: &[&str]) -> Vec<Vec<u8>> {
    paths
        .par_iter()
        .map(|&p| std::fs::read(p).expect("failed to read file"))
        .collect()
}
/// Reads all files (in parallel), then computes each sequential zpaq
/// compressed size with one rayon task per file.
#[inline(always)]
pub fn get_sequential_compressed_sizes_from_sequential_paths(
    paths: &[&str],
    method: &str,
) -> Vec<u64> {
    get_bytes_from_paths(paths)
        .par_iter()
        .map(|data| zpaq_compress_size_bytes(data, method))
        .collect()
}
/// Reads all files (in parallel), then compresses each with the
/// multi-threaded zpaq path using `threads` threads per file.
#[inline(always)]
pub fn get_parallel_compressed_sizes_from_sequential_paths(
    paths: &[&str],
    method: &str,
    threads: usize,
) -> Vec<u64> {
    get_bytes_from_paths(paths)
        .par_iter()
        .map(|data| zpaq_compress_size_parallel_bytes(data, method, threads))
        .collect()
}
/// One rayon task per path; each task reads its file and compresses it with
/// the single-threaded zpaq path.
#[inline(always)]
pub fn get_sequential_compressed_sizes_from_parallel_paths(
    paths: &[&str],
    method: &str,
) -> Vec<u64> {
    paths
        .par_iter()
        .map(|path| get_compressed_size(path, method))
        .collect()
}
/// One rayon task per path; each task reads its file and compresses it with
/// the multi-threaded zpaq path (`threads` threads per file).
#[inline(always)]
pub fn get_parallel_compressed_sizes_from_parallel_paths(
    paths: &[&str],
    method: &str,
    threads: usize,
) -> Vec<u64> {
    paths
        .par_iter()
        .map(|path| get_compressed_size_parallel(path, method, threads))
        .collect()
}
/// Compressed sizes for all `paths`, choosing the parallelization strategy
/// automatically: with fewer files than CPU threads, each file gets several
/// compressor threads; otherwise the file-level rayon parallelism is enough.
///
/// Fix: returns an empty vector for empty `paths` — previously
/// `num_threads.div_ceil(0)` panicked with a division by zero.
#[inline(always)]
pub fn get_compressed_sizes_from_paths(paths: &[&str], method: &str) -> Vec<u64> {
    let n: usize = paths.len();
    if n == 0 {
        return Vec::new();
    }
    let num_threads: usize = *NUM_THREADS.get_or_init(num_cpus::get);
    if n < num_threads {
        get_parallel_compressed_sizes_from_parallel_paths(paths, method, num_threads.div_ceil(n))
    } else {
        get_sequential_compressed_sizes_from_parallel_paths(paths, method)
    }
}
/// Which normalized compression distance formula to use.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NcdVariant {
    /// (C(xy) - min(C(x), C(y))) / max(C(x), C(y)).
    Vitanyi,
    /// Like `Vitanyi` but uses min(C(xy), C(yx)) in the numerator.
    SymVitanyi,
    /// (C(xy) - min(C(x), C(y))) / C(xy).
    Cons,
    /// Like `Cons` but uses min(C(xy), C(yx)) throughout.
    SymCons,
}
/// Sequential zpaq compressed size of `data` in bytes.
#[inline(always)]
fn compress_size_bytes(data: &[u8], method: &str) -> u64 {
    zpaq_compress_size_bytes(data, method)
}
/// Computes an NCD value from precomputed compressed sizes.
///
/// `cyx` is only consulted by the symmetric variants; passing `None` for
/// `SymVitanyi`/`SymCons` panics. A zero denominator always yields 0.0.
#[inline(always)]
fn ncd_from_sizes(cx: u64, cy: u64, cxy: u64, cyx: Option<u64>, variant: NcdVariant) -> f64 {
    let min_c = cx.min(cy) as f64;
    let max_c = cx.max(cy) as f64;
    // Shared guard against a zero denominator.
    let ratio = |num: f64, den: f64| if den == 0.0 { 0.0 } else { num / den };
    match variant {
        NcdVariant::Vitanyi => ratio(cxy as f64 - min_c, max_c),
        NcdVariant::SymVitanyi => {
            let m = cxy.min(cyx.expect("cyx required for SymVitanyi")) as f64;
            ratio(m - min_c, max_c)
        }
        NcdVariant::Cons => ratio(cxy as f64 - min_c, cxy as f64),
        NcdVariant::SymCons => {
            let m = cxy.min(cyx.expect("cyx required for SymCons")) as f64;
            ratio(m - min_c, m)
        }
    }
}
/// NCD between two byte strings using a zpaq compression backend described
/// by `method`.
#[inline(always)]
pub fn ncd_bytes(x: &[u8], y: &[u8], method: &str, variant: NcdVariant) -> f64 {
    let backend = CompressionBackend::Zpaq {
        method: method.to_string(),
    };
    ncd_bytes_backend(x, y, &backend, variant)
}
/// NCD between two byte strings using the crate's default context.
#[inline(always)]
pub fn ncd_bytes_default(x: &[u8], y: &[u8], variant: NcdVariant) -> f64 {
    with_default_ctx(|ctx| ctx.ncd_bytes(x, y, variant))
}
/// NCD between two byte strings under an arbitrary compression backend.
/// C(x) and C(y) are computed concurrently; C(yx) is only computed for the
/// symmetric variants.
pub fn ncd_bytes_backend(
    x: &[u8],
    y: &[u8],
    backend: &CompressionBackend,
    variant: NcdVariant,
) -> f64 {
    let (cx, cy) = rayon::join(
        || compress_size_backend(x, backend),
        || compress_size_backend(y, backend),
    );
    let cxy = compress_size_chain_backend(&[x, y], backend);
    let cyx = match variant {
        NcdVariant::SymVitanyi | NcdVariant::SymCons => {
            Some(compress_size_chain_backend(&[y, x], backend))
        }
        _ => None,
    };
    ncd_from_sizes(cx, cy, cxy, cyx, variant)
}
/// NCD between two files on disk (zpaq backend). Reads both files
/// concurrently; panics if either read fails.
#[inline(always)]
pub fn ncd_paths(x: &str, y: &str, method: &str, variant: NcdVariant) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    ncd_bytes(&bx, &by, method, variant)
}
/// NCD between two files on disk under an arbitrary compression backend.
/// Reads both files concurrently; panics if either read fails.
pub fn ncd_paths_backend(
    x: &str,
    y: &str,
    backend: &CompressionBackend,
    variant: NcdVariant,
) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    ncd_bytes_backend(&bx, &by, backend, variant)
}
/// NCD (Vitányi variant) between two files on disk.
#[inline(always)]
pub fn ncd_vitanyi(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::Vitanyi)
}
/// NCD (symmetric Vitányi variant) between two files on disk.
#[inline(always)]
pub fn ncd_sym_vitanyi(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::SymVitanyi)
}
/// NCD (consensus variant) between two files on disk.
#[inline(always)]
pub fn ncd_cons(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::Cons)
}
/// NCD (symmetric consensus variant) between two files on disk.
#[inline(always)]
pub fn ncd_sym_cons(x: &str, y: &str, method: &str) -> f64 {
    ncd_paths(x, y, method, NcdVariant::SymCons)
}
/// Full pairwise NCD matrix (row-major, n×n) over in-memory byte strings.
///
/// Per-item compressed sizes are computed once up front; each rayon worker
/// reuses a concatenation buffer via `for_each_init`. Results are written
/// through a raw pointer because each cell is written by exactly one task,
/// which rayon cannot express with a shared `&mut Vec`.
pub fn ncd_matrix_bytes(datas: &[Vec<u8>], method: &str, variant: NcdVariant) -> Vec<f64> {
    let n = datas.len();
    let cx: Vec<u64> = datas
        .par_iter()
        .map(|d| compress_size_bytes(d, method))
        .collect();
    let mut out = vec![0.0f64; n * n];
    let out_ptr = std::sync::atomic::AtomicPtr::new(out.as_mut_ptr());
    match variant {
        // Symmetric variants need both C(xy) and C(yx): visit each unordered
        // pair (i < j) once and mirror the result.
        NcdVariant::SymVitanyi | NcdVariant::SymCons => {
            (0..n)
                .into_par_iter()
                .flat_map_iter(|i| (i + 1..n).map(move |j| (i, j)))
                .for_each_init(Vec::<u8>::new, |buf, (i, j)| {
                    let x = &datas[i];
                    let y = &datas[j];
                    buf.clear();
                    buf.reserve(x.len() + y.len());
                    buf.extend_from_slice(x);
                    buf.extend_from_slice(y);
                    let cxy = compress_size_bytes(buf, method);
                    buf.clear();
                    buf.reserve(x.len() + y.len());
                    buf.extend_from_slice(y);
                    buf.extend_from_slice(x);
                    let cyx = compress_size_bytes(buf, method);
                    let d = ncd_from_sizes(cx[i], cx[j], cxy, Some(cyx), variant);
                    let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
                    // SAFETY: `p` points into `out`, which outlives this
                    // parallel region; i < j < n keeps both indices in
                    // bounds, and each (i, j) pair is handled by exactly one
                    // task, so no two tasks write the same cell.
                    unsafe {
                        *p.add(i * n + j) = d;
                        *p.add(j * n + i) = d;
                    }
                });
        }
        NcdVariant::Vitanyi | NcdVariant::Cons => {
            (0..n)
                .into_par_iter()
                .for_each_init(Vec::<u8>::new, |buf, i| {
                    let x = &datas[i];
                    for j in 0..n {
                        let d = if i == j {
                            0.0
                        } else {
                            let y = &datas[j];
                            buf.clear();
                            buf.reserve(x.len() + y.len());
                            buf.extend_from_slice(x);
                            buf.extend_from_slice(y);
                            let cxy = compress_size_bytes(buf, method);
                            ncd_from_sizes(cx[i], cx[j], cxy, None, variant)
                        };
                        let p = out_ptr.load(std::sync::atomic::Ordering::Relaxed);
                        // SAFETY: row i is written only by the task owning i
                        // and i * n + j < n * n, so the write is in bounds
                        // and unaliased.
                        unsafe {
                            *p.add(i * n + j) = d;
                        }
                    }
                });
        }
    }
    out
}
/// Pairwise NCD matrix over files on disk (reads all files first).
pub fn ncd_matrix_paths(paths: &[&str], method: &str, variant: NcdVariant) -> Vec<f64> {
    let datas = get_bytes_from_paths(paths);
    ncd_matrix_bytes(&datas, method, variant)
}
/// Shannon entropy (bits per byte) of the order-0 byte distribution of
/// `data`. Returns 0.0 for empty input.
#[inline(always)]
pub fn marginal_entropy_bytes(data: &[u8]) -> f64 {
    if data.is_empty() {
        return 0.0;
    }
    let mut counts = [0u64; 256];
    data.iter().for_each(|&b| counts[b as usize] += 1);
    let n = data.len() as f64;
    // H = -sum p log2 p over symbols that actually occur.
    counts
        .iter()
        .filter(|&&c| c > 0)
        .map(|&c| {
            let p = c as f64 / n;
            -(p * p.log2())
        })
        .sum()
}
/// Entropy rate of `data` (bits/byte) via the crate's default context.
#[inline(always)]
pub fn entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.entropy_rate_bytes(data, max_order))
}
/// Plug-in (biased) entropy rate of `data` via the crate's default context.
#[inline(always)]
pub fn biased_entropy_rate_bytes(data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.biased_entropy_rate_bytes(data, max_order))
}
/// Order-0 joint Shannon entropy H(X, Y) in bits over aligned byte pairs.
/// Inputs are truncated to their common prefix; returns 0.0 when empty.
#[inline(always)]
pub fn joint_marginal_entropy_bytes(x: &[u8], y: &[u8]) -> f64 {
    let (x, y) = aligned_prefix(x, y);
    if x.is_empty() {
        return 0.0;
    }
    // One counter per (x byte, y byte) pair: 256 * 256 cells.
    let mut counts = vec![0u64; 256 * 256];
    for (&a, &b) in x.iter().zip(y.iter()) {
        counts[((a as usize) << 8) | (b as usize)] += 1;
    }
    let n = x.len() as f64;
    counts
        .iter()
        .filter(|&&c| c > 0)
        .map(|&c| {
            let p = c as f64 / n;
            -(p * p.log2())
        })
        .sum()
}
/// Joint entropy rate of `x` and `y` via the crate's default context.
#[inline(always)]
pub fn joint_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.joint_entropy_rate_bytes(x, y, max_order))
}
/// Conditional entropy rate H(X|Y) via the crate's default context.
#[inline(always)]
pub fn conditional_entropy_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.conditional_entropy_rate_bytes(x, y, max_order))
}
/// Conditional entropy H(X|Y) via the crate's default context.
#[inline(always)]
pub fn conditional_entropy_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.conditional_entropy_bytes(x, y, max_order))
}
/// Mutual information I(X;Y) via the crate's default context.
#[inline(always)]
pub fn mutual_information_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.mutual_information_bytes(x, y, max_order))
}
/// Order-0 mutual information I(X;Y) = H(X) + H(Y) - H(X,Y) over the
/// aligned common prefix, clamped to be non-negative.
pub fn mutual_information_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
    let (x, y) = aligned_prefix(x, y);
    let h_x = marginal_entropy_bytes(x);
    let h_y = marginal_entropy_bytes(y);
    let h_xy = joint_marginal_entropy_bytes(x, y);
    (h_x + h_y - h_xy).max(0.0)
}
/// Mutual information rate via the crate's default context.
#[inline(always)]
pub fn mutual_information_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.mutual_information_rate_bytes(x, y, max_order))
}
/// Normalized entropy distance via the crate's default context.
#[inline(always)]
pub fn ned_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
}
/// Order-0 normalized entropy distance: (H(X,Y) - min(H)) / max(H) over the
/// aligned common prefix, clamped to [0, 1]; 0.0 when both entropies are 0.
pub fn ned_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
    let (x, y) = aligned_prefix(x, y);
    let h_x = marginal_entropy_bytes(x);
    let h_y = marginal_entropy_bytes(y);
    let h_xy = joint_marginal_entropy_bytes(x, y);
    let min_h = h_x.min(h_y);
    let max_h = h_x.max(h_y);
    if max_h == 0.0 {
        0.0
    } else {
        ((h_xy - min_h) / max_h).clamp(0.0, 1.0)
    }
}
/// Rate-based normalized entropy distance.
/// NOTE(review): delegates to `ctx.ned_bytes`, i.e. identical to
/// `ned_bytes` — confirm a distinct rate variant was not intended.
#[inline(always)]
pub fn ned_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_bytes(x, y, max_order))
}
/// Consensus normalized entropy distance via the crate's default context.
#[inline(always)]
pub fn ned_cons_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
}
pub fn ned_cons_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
let h_x = marginal_entropy_bytes(x);
let h_y = marginal_entropy_bytes(y);
let h_xy = joint_marginal_entropy_bytes(x, y);
let min_h = h_x.min(h_y);
if h_xy == 0.0 {
0.0
} else {
((h_xy - min_h) / h_xy).clamp(0.0, 1.0)
}
}
/// Rate-based consensus NED.
/// NOTE(review): delegates to `ctx.ned_cons_bytes`, i.e. identical to
/// `ned_cons_bytes` — confirm a distinct rate variant was not intended.
#[inline(always)]
pub fn ned_cons_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.ned_cons_bytes(x, y, max_order))
}
/// Normalized total entropy distance via the crate's default context.
#[inline(always)]
pub fn nte_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
}
/// Order-0 normalized total entropy distance based on variation of
/// information: (2 H(X,Y) - H(X) - H(Y)) / max(H) over the aligned common
/// prefix, clamped to [0, 2]; 0.0 when both entropies are 0.
pub fn nte_marg_bytes(x: &[u8], y: &[u8]) -> f64 {
    let (x, y) = aligned_prefix(x, y);
    let h_x = marginal_entropy_bytes(x);
    let h_y = marginal_entropy_bytes(y);
    let h_xy = joint_marginal_entropy_bytes(x, y);
    // Variation of information: VI = 2 H(X,Y) - H(X) - H(Y).
    let vi = 2.0 * h_xy - h_x - h_y;
    let max_h = h_x.max(h_y);
    if max_h == 0.0 {
        0.0
    } else {
        (vi / max_h).clamp(0.0, 2.0)
    }
}
/// Rate-based normalized total entropy distance.
/// NOTE(review): delegates to `ctx.nte_bytes`, i.e. identical to
/// `nte_bytes` — confirm a distinct rate variant was not intended.
#[inline(always)]
pub fn nte_rate_bytes(x: &[u8], y: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.nte_bytes(x, y, max_order))
}
/// Empirical byte distribution of `data` as 256 probabilities summing to 1.
/// All zeros when `data` is empty.
#[inline(always)]
fn byte_histogram(data: &[u8]) -> [f64; 256] {
    let mut probs = [0.0f64; 256];
    if data.is_empty() {
        return probs;
    }
    let mut counts = [0u64; 256];
    for &b in data {
        counts[b as usize] += 1;
    }
    let n = data.len() as f64;
    for (slot, &c) in probs.iter_mut().zip(counts.iter()) {
        *slot = c as f64 / n;
    }
    probs
}
/// Total variation distance between the order-0 byte distributions of `x`
/// and `y`, in [0, 1]. Returns 0.0 if either input is empty. `_max_order`
/// is unused and kept only for signature parity with the other metrics.
#[inline(always)]
pub fn tvd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
    if x.is_empty() || y.is_empty() {
        return 0.0;
    }
    let p_x = byte_histogram(x);
    let p_y = byte_histogram(y);
    // TVD = (1/2) * L1 distance between the distributions.
    let l1: f64 = p_x
        .iter()
        .zip(p_y.iter())
        .map(|(&a, &b)| (a - b).abs())
        .sum();
    (l1 / 2.0).clamp(0.0, 1.0)
}
/// Hellinger-style distance between the order-0 byte distributions:
/// sqrt(1 - BC) where BC is the Bhattacharyya coefficient. Returns 0.0 if
/// either input is empty. `_max_order` is unused (signature parity).
#[inline(always)]
pub fn nhd_bytes(x: &[u8], y: &[u8], _max_order: i64) -> f64 {
    if x.is_empty() || y.is_empty() {
        return 0.0;
    }
    let p_x = byte_histogram(x);
    let p_y = byte_histogram(y);
    let mut bc = 0.0f64;
    for (&a, &b) in p_x.iter().zip(p_y.iter()) {
        bc += (a * b).sqrt();
    }
    // Clamp at 0 before the sqrt in case rounding pushes BC above 1.
    (1.0 - bc).max(0.0).sqrt()
}
/// Cross entropy of `test_data` given `train_data` via the default context.
#[inline(always)]
pub fn cross_entropy_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.cross_entropy_bytes(test_data, train_data, max_order))
}
/// Cross entropy rate of `test_data` given `train_data` via the default
/// context.
#[inline(always)]
pub fn cross_entropy_rate_bytes(test_data: &[u8], train_data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.cross_entropy_rate_bytes(test_data, train_data, max_order))
}
/// Generates `bytes` bytes continuing `prompt` via the default context.
#[inline(always)]
pub fn generate_bytes(prompt: &[u8], bytes: usize, max_order: i64) -> Vec<u8> {
    with_default_ctx(|ctx| ctx.generate_bytes(prompt, bytes, max_order))
}
/// Generates `bytes` bytes continuing `prompt` with an explicit
/// `GenerationConfig`, via the default context.
#[inline(always)]
pub fn generate_bytes_with_config(
    prompt: &[u8],
    bytes: usize,
    max_order: i64,
    config: GenerationConfig,
) -> Vec<u8> {
    with_default_ctx(|ctx| ctx.generate_bytes_with_config(prompt, bytes, max_order, config))
}
/// Generates `bytes` bytes conditioned on the concatenation of
/// `prefix_parts`, via the default context.
#[inline(always)]
pub fn generate_bytes_conditional_chain(
    prefix_parts: &[&[u8]],
    bytes: usize,
    max_order: i64,
) -> Vec<u8> {
    with_default_ctx(|ctx| ctx.generate_bytes_conditional_chain(prefix_parts, bytes, max_order))
}
/// Generates `bytes` bytes conditioned on `prefix_parts` with an explicit
/// `GenerationConfig`, via the default context.
#[inline(always)]
pub fn generate_bytes_conditional_chain_with_config(
    prefix_parts: &[&[u8]],
    bytes: usize,
    max_order: i64,
    config: GenerationConfig,
) -> Vec<u8> {
    with_default_ctx(|ctx| {
        ctx.generate_bytes_conditional_chain_with_config(prefix_parts, bytes, max_order, config)
    })
}
/// KL divergence D(P‖Q) in bits between the order-0 byte distributions of
/// `x` (P) and `y` (Q). Returns 0.0 if either input is empty; the result is
/// clamped to be non-negative.
pub fn d_kl_bytes(x: &[u8], y: &[u8]) -> f64 {
    if x.is_empty() || y.is_empty() {
        return 0.0;
    }
    let p = byte_histogram(x);
    let q = byte_histogram(y);
    let mut total = 0.0f64;
    for (&pi, &qi) in p.iter().zip(q.iter()) {
        if pi > 0.0 {
            // Floor q at 1e-12 so symbols absent from y stay finite.
            total += pi * (pi / qi.max(1e-12)).log2();
        }
    }
    total.max(0.0)
}
/// Jensen–Shannon divergence in bits between the order-0 byte distributions
/// of `x` and `y`: the mean of KL(P‖M) and KL(Q‖M) with M = (P + Q) / 2.
/// Returns 0.0 if either input is empty; clamped to be non-negative.
pub fn js_div_bytes(x: &[u8], y: &[u8]) -> f64 {
    if x.is_empty() || y.is_empty() {
        return 0.0;
    }
    let p_x = byte_histogram(x);
    let p_y = byte_histogram(y);
    let mut kl_pm = 0.0f64;
    let mut kl_qm = 0.0f64;
    // Single pass: the mixture value is computed per symbol on the fly.
    for i in 0..256 {
        let mid = 0.5 * (p_x[i] + p_y[i]);
        if p_x[i] > 0.0 {
            kl_pm += p_x[i] * (p_x[i] / mid).log2();
        }
        if p_y[i] > 0.0 {
            kl_qm += p_y[i] * (p_y[i] / mid).log2();
        }
    }
    (0.5 * kl_pm + 0.5 * kl_qm).max(0.0)
}
/// NED between two files on disk; reads concurrently, panics on I/O error.
pub fn ned_paths(x: &str, y: &str, max_order: i64) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    ned_bytes(&bx, &by, max_order)
}
/// NTE between two files on disk; reads concurrently, panics on I/O error.
pub fn nte_paths(x: &str, y: &str, max_order: i64) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    nte_bytes(&bx, &by, max_order)
}
/// TVD between two files on disk; reads concurrently, panics on I/O error.
pub fn tvd_paths(x: &str, y: &str, max_order: i64) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    tvd_bytes(&bx, &by, max_order)
}
/// Hellinger distance between two files on disk; reads concurrently,
/// panics on I/O error.
pub fn nhd_paths(x: &str, y: &str, max_order: i64) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    nhd_bytes(&bx, &by, max_order)
}
/// Mutual information between two files on disk; reads concurrently,
/// panics on I/O error.
pub fn mutual_information_paths(x: &str, y: &str, max_order: i64) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    mutual_information_bytes(&bx, &by, max_order)
}
/// Conditional entropy between two files on disk; reads concurrently,
/// panics on I/O error.
pub fn conditional_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    conditional_entropy_bytes(&bx, &by, max_order)
}
/// Cross entropy of file `x` given file `y`; reads concurrently, panics on
/// I/O error.
pub fn cross_entropy_paths(x: &str, y: &str, max_order: i64) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    cross_entropy_bytes(&bx, &by, max_order)
}
/// Order-0 KL divergence between two files on disk; reads concurrently,
/// panics on I/O error.
pub fn kl_divergence_paths(x: &str, y: &str) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    d_kl_bytes(&bx, &by)
}
/// Order-0 Jensen–Shannon divergence between two files on disk; reads
/// concurrently, panics on I/O error.
pub fn js_divergence_paths(x: &str, y: &str) -> f64 {
    let (bx, by) = rayon::join(
        || std::fs::read(x).expect("failed to read x"),
        || std::fs::read(y).expect("failed to read y"),
    );
    js_div_bytes(&bx, &by)
}
/// Intrinsic dependence of `data` via the crate's default context.
#[inline(always)]
pub fn intrinsic_dependence_bytes(data: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.intrinsic_dependence_bytes(data, max_order))
}
/// Resistance of `x` to the transformation producing `tx`, via the crate's
/// default context.
#[inline(always)]
pub fn resistance_to_transformation_bytes(x: &[u8], tx: &[u8], max_order: i64) -> f64 {
    with_default_ctx(|ctx| ctx.resistance_to_transformation_bytes(x, tx, max_order))
}
// Unit tests for the public information-theory API: Shannon identities,
// backend switching via the global default context, per-backend entropy
// stability/regression checks, and deterministic byte generation.
//
// NOTE(review): several tests mutate the process-wide default context with
// `set_default_ctx` and restore it afterwards; they assume serial execution
// of those sections — confirm the test harness does not interleave them.
#[cfg(test)]
mod tests {
use super::*;
// Small Match-model backend configuration shared by many tests.
fn test_match_backend() -> RateBackend {
RateBackend::Match {
hash_bits: 12,
min_len: 2,
max_len: 16,
base_mix: 0.01,
confidence_scale: 1.0,
}
}
// Small PPMd backend configuration (order 4, 1 MB memory).
fn test_ppmd_backend() -> RateBackend {
RateBackend::Ppmd {
order: 4,
memory_mb: 1,
}
}
// Calibrated wrapper around the test Match backend (text context, 16 bins).
fn test_calibrated_backend() -> RateBackend {
RateBackend::Calibrated {
spec: Arc::new(CalibratedSpec {
base: test_match_backend(),
context: CalibrationContextKind::Text,
bins: 16,
learning_rate: 0.05,
bias_clip: 4.0,
}),
}
}
// Bayes mixture over the test Match and PPMd backends with equal log priors.
fn test_mixture_backend() -> RateBackend {
RateBackend::Mixture {
spec: Arc::new(MixtureSpec::new(
MixtureKind::Bayes,
vec![
MixtureExpertSpec {
name: Some("match".to_string()),
log_prior: 0.0,
max_order: -1,
backend: test_match_backend(),
},
MixtureExpertSpec {
name: Some("ppmd".to_string()),
log_prior: 0.0,
max_order: -1,
backend: test_ppmd_backend(),
},
],
)),
}
}
// Deliberately tiny Particle backend so tests stay fast; unspecified fields
// come from `ParticleSpec::default()`.
fn test_particle_backend() -> RateBackend {
RateBackend::Particle {
spec: Arc::new(ParticleSpec {
num_particles: 4,
num_cells: 4,
cell_dim: 8,
num_rules: 2,
selector_hidden: 16,
rule_hidden: 16,
context_window: 8,
unroll_steps: 1,
..ParticleSpec::default()
}),
}
}
// Highly patterned prompt whose final clause is left incomplete; a model
// that picked up the red/green alternation should continue with " green.\n"
// (asserted in `rosaplus_sampled_generation_predicts_green_continuation`).
fn continuation_prompt() -> &'static [u8] {
b"If a frog is green, dogs are red.\nIf a toad is green, cats are red.\nIf a dog is green, frogs are red.\nIf a cat is green, toads are red.\nIf a frog is red, dogs are green.\nIf a toad is red, cats are green.\nIf a dog is red, frogs are green.\nIf a cat is red, toads are \n"
}
// Generates twice with `GenerationConfig::default()` (sampled, seed 42 per
// the head of this file) and asserts identical output of the requested length.
fn assert_deterministic_generate_for_backend(
backend: RateBackend,
max_order: i64,
bytes: usize,
label: &str,
) {
let prompt = continuation_prompt();
let a = generate_rate_backend_chain(
&[prompt],
bytes,
max_order,
&backend,
GenerationConfig::default(),
);
let b = generate_rate_backend_chain(
&[prompt],
bytes,
max_order,
&backend,
GenerationConfig::default(),
);
assert_eq!(
a, b,
"{label} generation should be deterministic for identical input"
);
assert_eq!(
a.len(),
bytes,
"{label} generation should emit requested byte count"
);
}
// Same determinism check, but with an explicit fixed-seed sampled config.
fn assert_sampled_generate_for_backend(
backend: RateBackend,
max_order: i64,
bytes: usize,
label: &str,
) {
let prompt = continuation_prompt();
let config = GenerationConfig::sampled_frozen(42);
let a = generate_rate_backend_chain(&[prompt], bytes, max_order, &backend, config);
let b = generate_rate_backend_chain(&[prompt], bytes, max_order, &backend, config);
assert_eq!(
a, b,
"{label} sampled generation should be deterministic for a fixed seed"
);
assert_eq!(
a.len(),
bytes,
"{label} sampled generation should emit requested byte count"
);
}
// NCD of a string with itself should be ~0; allow tiny negative numerical slack.
#[cfg(feature = "backend-zpaq")]
#[test]
fn ncd_basic_identity_nonnegative() {
let x = b"abcdabcdabcd";
let d = ncd_bytes(x, x, "5", NcdVariant::Vitanyi);
assert!(d >= -1e-9);
}
// For identical inputs at order 0: H(X,Y) = H(X), H(X|Y) = 0, I(X;Y) = H(X),
// and both normalized distances (NED, NTE) vanish.
#[test]
fn shannon_identities_marginal_aligned() {
let x = b"abracadabra";
let y = b"abracadabra";
let h = marginal_entropy_bytes(x);
let mi = mutual_information_bytes(x, y, 0);
let h_xy = joint_marginal_entropy_bytes(x, y);
let h_x_given_y = conditional_entropy_bytes(x, y, 0);
let ned = ned_bytes(x, y, 0);
let nte = nte_bytes(x, y, 0);
assert!((h_xy - h).abs() < 1e-12);
assert!(h_x_given_y.abs() < 1e-12);
assert!((mi - h).abs() < 1e-12);
assert!(ned.abs() < 1e-12);
assert!(nte.abs() < 1e-12);
}
// Same identities under the RosaPlus rate backend, within a loose tolerance;
// saves and restores the global default context around the measurement.
#[test]
fn shannon_identities_rate_aligned_reasonable() {
let x = b"the quick brown fox jumps over the lazy dog";
let y = b"the quick brown fox jumps over the lazy dog";
let max_order = 8;
let prev = get_default_ctx();
set_default_ctx(InfotheoryCtx::new(
RateBackend::RosaPlus,
CompressionBackend::default(),
));
let h_x = entropy_rate_bytes(x, max_order);
let h_xy = joint_entropy_rate_bytes(x, y, max_order);
let h_x_given_y = conditional_entropy_rate_bytes(x, y, max_order);
let mi = mutual_information_bytes(x, y, max_order);
let ned = ned_bytes(x, y, max_order);
let tol = 0.2;
assert!((h_xy - h_x).abs() < tol);
assert!(h_x_given_y < tol);
assert!((mi - h_x).abs() < tol);
assert!(ned < tol);
set_default_ctx(prev);
}
// Resistance of x against itself is 1.0 at both order 0 and order 8.
#[test]
fn resistance_identity_is_one() {
let x = b"some repeated repeated repeated text";
let prev = get_default_ctx();
set_default_ctx(InfotheoryCtx::new(
RateBackend::RosaPlus,
CompressionBackend::default(),
));
let r0 = resistance_to_transformation_bytes(x, x, 0);
let r8 = resistance_to_transformation_bytes(x, x, 8);
assert!((r0 - 1.0).abs() < 1e-12);
assert!((r8 - 1.0).abs() < 1e-6);
set_default_ctx(prev);
}
// All marginal metrics return exactly 0.0 when either input is empty.
#[test]
fn marginal_metrics_empty_inputs_are_zero() {
let empty: &[u8] = &[];
let x = b"abc";
assert_eq!(tvd_bytes(empty, x, 0), 0.0);
assert_eq!(tvd_bytes(x, empty, 0), 0.0);
assert_eq!(nhd_bytes(empty, x, 0), 0.0);
assert_eq!(nhd_bytes(x, empty, 0), 0.0);
assert_eq!(d_kl_bytes(empty, x), 0.0);
assert_eq!(d_kl_bytes(x, empty), 0.0);
assert_eq!(js_div_bytes(empty, x), 0.0);
assert_eq!(js_div_bytes(x, empty), 0.0);
}
// Cross entropy with an empty test sequence is 0.0 (zpaq-backed context).
#[test]
fn marginal_cross_entropy_empty_test_is_zero() {
let empty: &[u8] = &[];
let y = b"abc";
let ctx = InfotheoryCtx::with_zpaq("5");
assert_eq!(ctx.cross_entropy_bytes(empty, y, 0), 0.0);
}
// Without the zpaq feature, explicitly requesting the Zpaq backend must panic
// with a clear message rather than silently falling back.
#[cfg(not(feature = "backend-zpaq"))]
#[test]
#[should_panic(expected = "CompressionBackend::Zpaq is unavailable")]
fn explicit_zpaq_backend_fails_loudly() {
let backend = CompressionBackend::Zpaq {
method: "5".to_string(),
};
let _ = compress_size_backend(b"abc", &backend);
}
// Without the zpaq feature, the default compression backend is rate coding
// (AC coder, raw framing) and still produces a nonzero compressed size.
#[cfg(not(feature = "backend-zpaq"))]
#[test]
fn default_compression_backend_falls_back_to_rate_coding() {
let backend = CompressionBackend::default();
assert!(matches!(
&backend,
CompressionBackend::Rate {
coder: crate::coders::CoderType::AC,
framing: crate::compression::FramingMode::Raw,
..
}
));
assert!(compress_size_backend(b"abc", &backend) > 0);
}
// Swapping the default context to CTW and back reproduces the original
// RosaPlus measurement exactly (no state leaks through the global ctx).
#[test]
fn backend_switching_test() {
let x = b"hello world context";
let h_rosa = entropy_rate_bytes(x, 8);
set_default_ctx(InfotheoryCtx::new(
RateBackend::Ctw { depth: 16 },
CompressionBackend::default(),
));
let h_ctw = entropy_rate_bytes(x, 8);
assert!(h_ctw > 0.0);
set_default_ctx(InfotheoryCtx::default());
let h_rosa_back = entropy_rate_bytes(x, 8);
assert!((h_rosa - h_rosa_back).abs() < 1e-12);
}
// A fresh CTW tree predicts 0.5/0.5 before any updates, and its block log
// probability is finite and negative after a few updates.
#[test]
fn ctw_early_updates_work() {
use crate::ctw::ContextTree;
let mut tree = ContextTree::new(16);
let p0 = tree.predict(false);
let p1 = tree.predict(true);
assert!((p0 - 0.5).abs() < 1e-10, "p0 should be ~0.5, got {}", p0);
assert!((p1 - 0.5).abs() < 1e-10, "p1 should be ~0.5, got {}", p1);
assert!((p0 + p1 - 1.0).abs() < 1e-10, "p0 + p1 should = 1.0");
for _ in 0..5 {
tree.update(true);
tree.update(false);
}
let log_prob = tree.get_log_block_probability();
assert!(
log_prob < 0.0,
"log_prob should be negative (< log 1), got {}",
log_prob
);
assert!(log_prob.is_finite(), "log_prob should be finite");
}
// NTE on anti-phase bit streams stays within its normalized range [0, 2]
// (the name records that values above 1 are legal, hence the upper bound 2).
#[test]
fn nte_can_exceed_one() {
set_default_ctx(InfotheoryCtx::new(
RateBackend::Ctw { depth: 8 },
CompressionBackend::default(),
));
let x: Vec<u8> = (0..200).map(|i| (i % 2) as u8).collect(); let y: Vec<u8> = (0..200).map(|i| ((i + 1) % 2) as u8).collect();
let nte_rate = nte_rate_backend(&x, &y, -1, &RateBackend::Ctw { depth: 8 });
assert!(
(0.0..=2.0 + 1e-9).contains(&nte_rate),
"NTE should be in [0, 2], got {}",
nte_rate
);
set_default_ctx(InfotheoryCtx::default());
}
// CTW entropy of empty input is exactly 0.0.
#[test]
fn ctw_empty_data_returns_zero() {
set_default_ctx(InfotheoryCtx::new(
RateBackend::Ctw { depth: 16 },
CompressionBackend::default(),
));
let empty: &[u8] = &[];
let h = entropy_rate_bytes(empty, -1);
assert_eq!(h, 0.0, "empty data should return 0.0 entropy");
set_default_ctx(InfotheoryCtx::default());
}
// Joint entropy rate truncates to the common prefix length: empty aligned
// pairs score 0.0, and extra trailing bytes on one side don't change the score.
#[test]
fn joint_entropy_rate_aligns_inputs_and_handles_empty_cases() {
let cases = vec![
("ctw", RateBackend::Ctw { depth: 8 }),
(
"fac-ctw",
RateBackend::FacCtw {
base_depth: 8,
num_percept_bits: 8,
encoding_bits: 8,
},
),
("match", test_match_backend()),
];
for (name, backend) in cases {
assert_eq!(
joint_entropy_rate_backend(b"", b"nonempty", -1, &backend),
0.0,
"{name} should return 0.0 for empty aligned pairs"
);
assert_eq!(
joint_entropy_rate_backend(b"nonempty", b"", -1, &backend),
0.0,
"{name} should return 0.0 when alignment truncates to empty"
);
let aligned = joint_entropy_rate_backend(b"abcd", b"wxyz", -1, &backend);
let truncated = joint_entropy_rate_backend(b"abcdextra", b"wxyz", -1, &backend);
assert!(
(aligned - truncated).abs() < 1e-12,
"{name} should score only the aligned prefix: aligned={aligned} truncated={truncated}"
);
}
}
// Calling biased entropy twice on the same data must give identical finite
// values for every backend family (guards against leaked mutable state).
#[test]
fn biased_entropy_is_repeatable_across_backend_families() {
let data = b"ABABABAABBABABABAABB";
let cases = vec![
("match", test_match_backend()),
("ppmd", test_ppmd_backend()),
("calibrated", test_calibrated_backend()),
("ctw", RateBackend::Ctw { depth: 8 }),
("mixture", test_mixture_backend()),
("particle", test_particle_backend()),
];
for (name, backend) in cases {
let h1 = biased_entropy_rate_backend(data, -1, &backend);
let h2 = biased_entropy_rate_backend(data, -1, &backend);
assert!(h1.is_finite(), "{name} biased entropy should be finite");
assert!(
(h1 - h2).abs() < 1e-12,
"{name} biased entropy leaked mutable state across calls: h1={h1} h2={h2}"
);
}
}
// Conditioning on a prompt split into two chain segments must generate the
// same bytes as conditioning on the flat concatenated prompt.
#[test]
fn generate_bytes_chain_matches_flat_prompt() {
let prompt = continuation_prompt();
let split_at = prompt.len() / 2;
let front = &prompt[..split_at];
let back = &prompt[split_at..];
let backend = RateBackend::Ctw { depth: 32 };
let bytes = 8usize;
let max_order = -1;
let flat = generate_rate_backend_chain(
&[prompt],
bytes,
max_order,
&backend,
GenerationConfig::default(),
);
let chained = generate_rate_backend_chain(
&[front, back],
bytes,
max_order,
&backend,
GenerationConfig::default(),
);
assert_eq!(
flat, chained,
"chain conditioning should match flat prompt conditioning"
);
}
// Default-config generation is deterministic for the four core backends.
#[test]
fn generate_bytes_api_is_deterministic_for_ctw_rosa_match_ppmd() {
assert_deterministic_generate_for_backend(RateBackend::Ctw { depth: 32 }, -1, 8, "ctw");
assert_deterministic_generate_for_backend(RateBackend::RosaPlus, -1, 8, "rosaplus");
assert_deterministic_generate_for_backend(test_match_backend(), -1, 8, "match");
assert_deterministic_generate_for_backend(test_ppmd_backend(), -1, 8, "ppmd");
}
// Same determinism check for an RWKV-7 method-string backend (train=none).
#[cfg(feature = "backend-rwkv")]
#[test]
fn generate_bytes_api_is_deterministic_for_rwkv_method() {
let backend = RateBackend::Rwkv7Method {
method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
};
assert_deterministic_generate_for_backend(backend, -1, 8, "rwkv7");
}
// Fixed-seed sampled generation is deterministic for the four core backends.
#[test]
fn sampled_generation_is_deterministic_for_ctw_rosa_match_ppmd() {
assert_sampled_generate_for_backend(RateBackend::Ctw { depth: 32 }, -1, 8, "ctw");
assert_sampled_generate_for_backend(RateBackend::RosaPlus, -1, 8, "rosaplus");
assert_sampled_generate_for_backend(test_match_backend(), -1, 8, "match");
assert_sampled_generate_for_backend(test_ppmd_backend(), -1, 8, "ppmd");
}
// Fixed-seed sampled generation is deterministic for the RWKV-7 backend too.
#[cfg(feature = "backend-rwkv")]
#[test]
fn sampled_generation_is_deterministic_for_rwkv_method() {
let backend = RateBackend::Rwkv7Method {
method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
};
assert_sampled_generate_for_backend(backend, -1, 8, "rwkv7");
}
// Exact-bytes regression: RosaPlus with seed 42 must complete the patterned
// prompt with " green.\n". A failure here means the model or sampler changed.
#[test]
fn rosaplus_sampled_generation_predicts_green_continuation() {
let out = generate_rate_backend_chain(
&[continuation_prompt()],
8,
-1,
&RateBackend::RosaPlus,
GenerationConfig::sampled_frozen(42),
);
assert_eq!(out, b" green.\n");
}
// The incremental session API (observe prompt, then generate) must produce
// the same bytes as the one-shot context API for the same backend and seed.
#[test]
fn rate_backend_session_matches_ctx_generation() {
let prompt = continuation_prompt();
let backend = RateBackend::Ppmd {
order: 12,
memory_mb: 8,
};
let mut session =
RateBackendSession::from_backend(backend.clone(), -1, Some((prompt.len() + 8) as u64))
.expect("session init");
session.observe(prompt);
let from_session = session.generate_bytes(8, GenerationConfig::sampled_frozen(42));
session.finish().expect("session finish");
let ctx = InfotheoryCtx::new(backend, CompressionBackend::default());
let from_ctx =
ctx.generate_bytes_with_config(prompt, 8, -1, GenerationConfig::sampled_frozen(42));
assert_eq!(from_session, from_ctx);
}
// On a fully repetitive input, frozen plugin scoring (train-then-score) must
// beat prequential (online) scoring for CTW.
#[test]
fn biased_entropy_ctw_uses_frozen_plugin_scoring() {
let backend = RateBackend::Ctw { depth: 8 };
let data = b"AAAAAAAA";
let plugin = biased_entropy_rate_backend(data, -1, &backend);
let prequential = entropy_rate_backend(data, -1, &backend);
assert!(
plugin + 1e-9 < prequential,
"expected plugin scoring to beat prequential scoring: plugin={plugin} prequential={prequential}"
);
}
// The RosaPlus plugin path must agree exactly with driving the model API
// directly (train on data, build LM, score the same data).
#[test]
fn rosa_plugin_entropy_matches_direct_model_api() {
let data = b"abracadabra";
let backend = RateBackend::RosaPlus;
let plugin = biased_entropy_rate_backend(data, 3, &backend);
let mut direct = rosaplus::RosaPlus::new(3, false, 0, 42);
direct.train_example(data);
direct.build_lm();
let expected = direct.cross_entropy(data);
assert!(
(plugin - expected).abs() < 1e-12,
"rosa plugin entropy must match direct model API: plugin={plugin} expected={expected}"
);
}
// Same equivalence for cross entropy: train on one corpus, score another.
#[test]
fn rosa_plugin_cross_entropy_matches_direct_model_api() {
let train = b"alakazam";
let test = b"abracadabra";
let backend = RateBackend::RosaPlus;
let plugin = cross_entropy_rate_backend(test, train, 3, &backend);
let mut direct = rosaplus::RosaPlus::new(3, false, 0, 42);
direct.train_example(train);
direct.build_lm();
let expected = direct.cross_entropy(test);
assert!(
(plugin - expected).abs() < 1e-12,
"rosa plugin cross entropy must match direct model API: plugin={plugin} expected={expected}"
);
}
// Fair-coin Bernoulli data should estimate ~1 bit of marginal entropy.
#[test]
fn datagen_bernoulli_entropy_estimate() {
let p = 0.5;
let theoretical_h = crate::datagen::bernoulli_entropy(p);
assert!((theoretical_h - 1.0).abs() < 1e-10);
let data = crate::datagen::bernoulli(10000, p, 42);
let estimated_h = marginal_entropy_bytes(&data);
assert!(
(estimated_h - theoretical_h).abs() < 0.1,
"estimated H={} should be close to theoretical H={}",
estimated_h,
theoretical_h
);
}
// RWKV method backend must give bit-identical entropy across repeated calls
// even with a training config (no state leaks between invocations).
#[cfg(feature = "backend-rwkv")]
#[test]
fn rwkv_method_entropy_is_stable_across_calls() {
let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=21,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:infer";
let backend = RateBackend::Rwkv7Method {
method: method.to_string(),
};
let data = b"rwkv method entropy stability regression sample";
let h1 = entropy_rate_backend(data, -1, &backend);
let h2 = entropy_rate_backend(data, -1, &backend);
assert!(
(h1 - h2).abs() < 1e-12,
"rwkv method entropy leaked mutable state across calls: h1={h1}, h2={h2}"
);
}
// A minimal method string with no policy section must still be accepted.
#[cfg(feature = "backend-rwkv")]
#[test]
fn rwkv_method_without_policy_is_accepted_by_public_api() {
let backend = RateBackend::Rwkv7Method {
method: "cfg:hidden=64,layers=1,intermediate=64".to_string(),
};
let data = b"rwkv method without policy";
let h1 = entropy_rate_backend(data, -1, &backend);
let h2 = biased_entropy_rate_backend(data, -1, &backend);
assert!(h1.is_finite());
assert!(h2.is_finite());
}
// With train=none the plugin (two-pass) estimate must equal the single-pass
// entropy exactly, since there is nothing to learn between passes.
#[cfg(feature = "backend-rwkv")]
#[test]
fn rwkv_infer_only_plugin_collapses_to_single_pass_entropy() {
let backend = RateBackend::Rwkv7Method {
method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=25,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
};
let data = b"rwkv infer-only plugin equality sample";
let h = entropy_rate_backend(data, -1, &backend);
let plugin = biased_entropy_rate_backend(data, -1, &backend);
assert!(
(h - plugin).abs() < 1e-12,
"infer-only rwkv plugin should equal single-pass entropy: h={h}, plugin={plugin}"
);
}
// Biased entropy must stay bit-stable across calls even with an active
// training policy in the method string.
#[cfg(feature = "backend-rwkv")]
#[test]
fn rwkv_method_biased_entropy_is_stable_across_calls_with_training_policy() {
let backend = RateBackend::Rwkv7Method {
method: "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=23,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.0)".to_string(),
};
let data = b"rwkv plugin stability sample";
let h1 = biased_entropy_rate_backend(data, -1, &backend);
let h2 = biased_entropy_rate_backend(data, -1, &backend);
assert!(
(h1 - h2).abs() < 1e-12,
"rwkv method biased entropy leaked mutable state across calls: h1={h1}, h2={h2}"
);
}
// Conditional chain scoring through the context must also be call-stable.
#[cfg(feature = "backend-rwkv")]
#[test]
fn rwkv_method_conditional_chain_is_stable_across_calls() {
let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=22,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:infer";
let ctx = InfotheoryCtx::new(
RateBackend::Rwkv7Method {
method: method.to_string(),
},
CompressionBackend::default(),
);
let prefix = b"universal prior slice";
let data = b"query payload";
let h1 = ctx.cross_entropy_conditional_chain(&[prefix.as_slice()], data);
let h2 = ctx.cross_entropy_conditional_chain(&[prefix.as_slice()], data);
assert!(
(h1 - h2).abs() < 1e-12,
"rwkv method conditional chain leaked mutable state across calls: h1={h1}, h2={h2}"
);
}
// Mamba mirror of the no-policy acceptance test.
#[cfg(feature = "backend-mamba")]
#[test]
fn mamba_method_without_policy_is_accepted_by_public_api() {
let backend = RateBackend::MambaMethod {
method: "cfg:hidden=64,layers=1,intermediate=96".to_string(),
};
let data = b"mamba method without policy";
let h1 = entropy_rate_backend(data, -1, &backend);
let h2 = biased_entropy_rate_backend(data, -1, &backend);
assert!(h1.is_finite());
assert!(h2.is_finite());
}
// Mamba mirror of the infer-only plugin/single-pass equality test.
#[cfg(feature = "backend-mamba")]
#[test]
fn mamba_infer_only_plugin_collapses_to_single_pass_entropy() {
let backend = RateBackend::MambaMethod {
method: "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=26,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer".to_string(),
};
let data = b"mamba infer-only plugin equality sample";
let h = entropy_rate_backend(data, -1, &backend);
let plugin = biased_entropy_rate_backend(data, -1, &backend);
assert!(
(h - plugin).abs() < 1e-12,
"infer-only mamba plugin should equal single-pass entropy: h={h}, plugin={plugin}"
);
}
// Mamba mirror of the training-policy stability test.
#[cfg(feature = "backend-mamba")]
#[test]
fn mamba_method_biased_entropy_is_stable_across_calls_with_training_policy() {
let backend = RateBackend::MambaMethod {
method: "cfg:hidden=64,layers=1,intermediate=96,state=16,conv=4,dt_rank=16,seed=24,train=sgd,lr=0.01,stride=1;policy:schedule=0..100:train(scope=head+bias,opt=sgd,lr=0.01,stride=1,bptt=1,clip=0,momentum=0.0)".to_string(),
};
let data = b"mamba plugin stability sample";
let h1 = biased_entropy_rate_backend(data, -1, &backend);
let h2 = biased_entropy_rate_backend(data, -1, &backend);
assert!(
(h1 - h2).abs() < 1e-12,
"mamba method biased entropy leaked mutable state across calls: h1={h1}, h2={h2}"
);
}
// Particle backend entropy rate must fall strictly inside (0, 8) bits/byte.
#[test]
fn particle_entropy_rate_in_valid_range() {
let rb = test_particle_backend();
let data = b"hello world particle backend test";
let rate = entropy_rate_backend(data, -1, &rb);
assert!(
rate > 0.0 && rate < 8.0,
"particle entropy rate out of (0, 8) range: {rate}"
);
}
// Particle cross entropy must be bit-identical across repeated calls.
#[test]
fn particle_cross_entropy_stability() {
let rb = test_particle_backend();
let train = b"ABCABC";
let test = b"ABC";
let h1 = cross_entropy_rate_backend(test, train, -1, &rb);
let h2 = cross_entropy_rate_backend(test, train, -1, &rb);
assert!(
(h1 - h2).abs() < 1e-12,
"particle cross entropy not deterministic: h1={h1}, h2={h2}"
);
}
// Particle backend must return exactly 0.0 for empty input.
#[test]
fn particle_empty_input() {
let rb = RateBackend::Particle {
spec: Arc::new(ParticleSpec::default()),
};
let rate = entropy_rate_backend(b"", -1, &rb);
assert!(
rate == 0.0,
"particle entropy rate for empty input should be 0.0, got {rate}"
);
}
// Particle joint entropy rate over a byte pair must stay inside (0, 16).
#[test]
fn particle_joint_entropy_rate() {
let rb = test_particle_backend();
let x = b"AAAA";
let y = b"BBBB";
let joint = joint_entropy_rate_backend(x, y, -1, &rb);
assert!(
joint > 0.0 && joint < 16.0,
"particle joint entropy rate out of range: {joint}"
);
}
}