use crate::RateBackend;
use crate::aixi::rate_backend::rate_backend_contains_zpaq;
use crate::ctw::{ContextTree, FacContextTree};
#[cfg(feature = "backend-mamba")]
use crate::mambazip::{Compressor as MambaCompressor, Model as MambaModel, State as MambaState};
use crate::mixture::{DEFAULT_MIN_PROB, OnlineBytePredictor, RateBackendPredictor};
use crate::rosaplus::{RosaPlus, RosaTx};
#[cfg(feature = "backend-rwkv")]
use crate::rwkvzip::{Compressor as RwkvCompressor, Model as RwkvModel, State as RwkvState};
use crate::zpaq_rate::ZpaqRateModel;
#[cfg(any(feature = "backend-mamba", feature = "backend-rwkv"))]
use std::sync::Arc;
/// A reversible single-bit predictor driving the arithmetic-coding layers.
///
/// Two update channels are exposed:
/// - `update` / `revert`: per-symbol reversible updates.
/// - `update_history` / `pop_history`: history ("frozen") updates; the
///   defaults simply forward to the reversible path.
///
/// The `commit_*` variants default to the reversible forms; implementations
/// such as `RateBackendBitPredictor` override them to skip rollback
/// bookkeeping when a symbol is known to be permanent.
/// `begin_rollback_scope` / `rollback_scope` are an optional bulk-rollback
/// mechanism; the default `rollback_scope` returns `false` to signal that
/// scopes are unsupported.
pub trait Predictor: Send {
    /// Feed one symbol through the reversible update path.
    fn update(&mut self, sym: bool);
    /// Commit a symbol permanently; default is a plain `update`.
    fn commit_update(&mut self, sym: bool) {
        self.update(sym);
    }
    /// Reversible history update; default forwards to `update`.
    fn update_history(&mut self, sym: bool) {
        self.update(sym);
    }
    /// Commit a history symbol permanently; default forwards to
    /// `update_history`.
    fn commit_update_history(&mut self, sym: bool) {
        self.update_history(sym);
    }
    /// Undo the most recent reversible `update`.
    fn revert(&mut self);
    /// Undo the most recent reversible `update_history`; default forwards
    /// to `revert`.
    fn pop_history(&mut self) {
        self.revert();
    }
    /// Open a bulk rollback scope (no-op by default).
    fn begin_rollback_scope(&mut self) {}
    /// Roll back to the most recent scope; `false` when scopes are
    /// unsupported or none is active.
    fn rollback_scope(&mut self) -> bool {
        false
    }
    /// Probability assigned to `sym` at the current state.
    fn predict_prob(&mut self, sym: bool) -> f64;
    /// Probability assigned to `true` (convenience wrapper).
    fn predict_one(&mut self) -> f64 {
        self.predict_prob(true)
    }
    /// Human-readable model identifier for logs and diagnostics.
    fn model_name(&self) -> String;
    /// Clone the predictor behind the trait object.
    fn boxed_clone(&self) -> Box<dyn Predictor>;
}
/// Sanitize a requested minimum probability into a usable floor for a
/// binary distribution: finite values are limited to `[1e-12, ~0.5)`,
/// anything non-finite (NaN / ±inf) falls back to `1e-12`.
#[inline]
fn binary_prob_floor(min_prob: f64) -> f64 {
    const LO: f64 = 1e-12;
    const HI: f64 = 0.499_999_999_999;
    if !min_prob.is_finite() {
        return LO;
    }
    min_prob.clamp(LO, HI)
}
#[inline]
fn normalized_binary_prob_pair_from_probs(p0: f64, p1: f64, min_prob: f64) -> (f64, f64) {
let p0 = if p0.is_finite() && p0 > 0.0 { p0 } else { 0.0 };
let p1 = if p1.is_finite() && p1 > 0.0 { p1 } else { 0.0 };
let sum = p0 + p1;
if !sum.is_finite() || sum <= 0.0 {
return (0.5, 0.5);
}
let floor = binary_prob_floor(min_prob);
let q1 = (p1 / sum).clamp(floor, 1.0 - floor);
(1.0 - q1, q1)
}
/// Convert two log-probabilities into a normalized, floored binary pair.
///
/// The larger log value is subtracted before exponentiating so that very
/// negative logs do not underflow both sides to zero at once; `-inf` maps to
/// zero mass. When neither log is finite the uniform pair `(0.5, 0.5)` is
/// returned.
#[inline]
fn normalized_binary_prob_pair_from_log_probs(logp0: f64, logp1: f64, min_prob: f64) -> (f64, f64) {
    let pivot = logp0.max(logp1);
    if !pivot.is_finite() {
        return (0.5, 0.5);
    }
    // Shift by the pivot before exponentiating; non-finite logs carry no mass.
    let to_mass = |logp: f64| {
        if logp.is_finite() {
            (logp - pivot).exp()
        } else {
            0.0
        }
    };
    normalized_binary_prob_pair_from_probs(to_mass(logp0), to_mass(logp1), min_prob)
}
/// Bit predictor backed by a single (non-factored) context-tree-weighting
/// tree.
pub struct CtwPredictor {
    tree: ContextTree,
}
impl CtwPredictor {
    /// Build a CTW predictor whose tree conditions on `depth` context bits.
    pub fn new(depth: usize) -> Self {
        let tree = ContextTree::new(depth);
        Self { tree }
    }
}
impl Predictor for CtwPredictor {
    fn update(&mut self, sym: bool) {
        self.tree.update(sym);
    }
    // History bits go through the tree's dedicated history channel instead
    // of the reversible update path.
    fn update_history(&mut self, sym: bool) {
        self.tree.update_history(&[sym]);
    }
    fn revert(&mut self) {
        self.tree.revert();
    }
    fn pop_history(&mut self) {
        self.tree.revert_history();
    }
    fn predict_prob(&mut self, sym: bool) -> f64 {
        self.tree.predict(sym)
    }
    fn model_name(&self) -> String {
        format!("AC-CTW(d={})", self.tree.depth())
    }
    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            tree: self.tree.clone(),
        })
    }
}
/// Bit predictor backed by a factored CTW tree: percepts are `num_bits`
/// wide and `current_bit` tracks which bit position within the current
/// percept the next reversible update applies to.
pub struct FacCtwPredictor {
    tree: FacContextTree,
    // Bit cursor in [0, num_bits), advanced by `update`, stepped back by
    // `revert`.
    current_bit: usize,
    num_bits: usize,
}
impl FacCtwPredictor {
    /// Build a factored CTW predictor with the given per-factor base depth
    /// and percept width in bits; the bit cursor starts at position 0.
    pub fn new(base_depth: usize, num_percept_bits: usize) -> Self {
        let tree = FacContextTree::new(base_depth, num_percept_bits);
        Self {
            tree,
            current_bit: 0,
            num_bits: num_percept_bits,
        }
    }
}
impl Predictor for FacCtwPredictor {
    fn update(&mut self, sym: bool) {
        self.tree.update(sym, self.current_bit);
        // Advance the bit cursor, wrapping at the percept width.
        self.current_bit = (self.current_bit + 1) % self.num_bits;
    }
    // NOTE(review): unlike `update`, this does not advance `current_bit`;
    // presumably `FacContextTree::update_history` tracks bit positions
    // internally — confirm against its implementation.
    fn update_history(&mut self, sym: bool) {
        self.tree.update_history(&[sym]);
    }
    fn revert(&mut self) {
        // Step the bit cursor back (wrapping), then revert at that position —
        // the exact inverse of `update`.
        self.current_bit = if self.current_bit == 0 {
            self.num_bits - 1
        } else {
            self.current_bit - 1
        };
        self.tree.revert(self.current_bit);
    }
    // NOTE(review): does not move `current_bit` either; assumed symmetric
    // with `update_history`.
    fn pop_history(&mut self) {
        self.tree.revert_history(1);
    }
    fn predict_prob(&mut self, sym: bool) -> f64 {
        self.tree.predict(sym, self.current_bit)
    }
    fn model_name(&self) -> String {
        format!("FAC-CTW(D={}, k={})", self.tree.base_depth(), self.num_bits)
    }
    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            tree: self.tree.clone(),
            current_bit: self.current_bit,
            num_bits: self.num_bits,
        })
    }
}
/// Bit predictor backed by a ROSA+ language model; each reversible update
/// is recorded as a transaction so it can be rolled back.
pub struct RosaPredictor {
    model: RosaPlus,
    // One transaction per reversible update, popped by `revert`.
    history: Vec<RosaTx>,
}
impl RosaPredictor {
    /// Build a ROSA+ predictor with the given maximum context order.
    pub fn new(max_order: i64) -> Self {
        // NOTE(review): the `false`, `0`, `42` arguments are flag/seed-like
        // parameters of `RosaPlus::new` — confirm their meaning there.
        let mut model = RosaPlus::new(max_order, false, 0, 42);
        model.build_lm_full_bytes_no_finalize_endpos();
        Self {
            model,
            history: Vec::new(),
        }
    }
}
impl Predictor for RosaPredictor {
    fn update(&mut self, sym: bool) {
        // Train on the bit as a single byte (0 or 1) inside a transaction so
        // the update can be undone later by `revert`.
        let mut tx = self.model.begin_tx();
        let byte = if sym { 1u8 } else { 0u8 };
        self.model.train_sequence_tx(&mut tx, &[byte]);
        self.history.push(tx);
    }
    fn revert(&mut self) {
        // Roll back the most recent transaction, if any.
        if let Some(tx) = self.history.pop() {
            self.model.rollback_tx(tx);
        }
    }
    fn predict_prob(&mut self, sym: bool) -> f64 {
        // Normalize the model's masses for bytes 0 and 1 into a floored
        // binary distribution.
        let (p0, p1) = normalized_binary_prob_pair_from_probs(
            self.model.prob_for_last(0),
            self.model.prob_for_last(1),
            DEFAULT_MIN_PROB,
        );
        if sym { p1 } else { p0 }
    }
    fn model_name(&self) -> String {
        "ROSA".to_string()
    }
    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            model: self.model.clone(),
            history: self.history.clone(),
        })
    }
}
/// Bit predictor backed by a ZPAQ rate model.
///
/// ZPAQ state is not rolled back incrementally here: the full byte history
/// is kept and the model is rebuilt from scratch on `revert`.
pub struct ZpaqPredictor {
    // Method string, kept so throwaway/rebuilt models can be constructed.
    method: String,
    min_prob: f64,
    model: ZpaqRateModel,
    // Every committed bit (as a 0/1 byte), used to rebuild after reverts.
    history: Vec<u8>,
    // (symbol, log-prob) last scored on the live model but not yet folded
    // in by `update`; cleared on update/revert.
    pending: Option<(u8, f64)>,
}
impl ZpaqPredictor {
    /// Create a predictor for the given ZPAQ `method` string with a minimum
    /// probability floor.
    pub fn new(method: String, min_prob: f64) -> Self {
        let model = ZpaqRateModel::new(method.clone(), min_prob);
        Self {
            method,
            min_prob,
            model,
            history: Vec::new(),
            pending: None,
        }
    }
    /// Reset the live model and replay the full byte history into it.
    fn rebuild_from_history(&mut self) {
        self.model.reset();
        if !self.history.is_empty() {
            self.model.update_and_score(&self.history);
        }
    }
    /// Score `symbol` on a throwaway model built from the history, leaving
    /// the live model untouched.
    fn log_prob_from_history(&self, symbol: u8) -> f64 {
        let mut tmp = ZpaqRateModel::new(self.method.clone(), self.min_prob);
        if !self.history.is_empty() {
            tmp.update_and_score(&self.history);
        }
        tmp.log_prob(symbol)
    }
    /// Return `(logp0, logp1)` for the next bit, scoring `preferred_symbol`
    /// on the live model (cached in `pending`) and the other symbol on a
    /// throwaway replica when needed.
    // NOTE(review): the live model is only ever asked for one symbol between
    // updates; a second distinct symbol is answered from a history replica —
    // presumably because `ZpaqRateModel::log_prob` mutates the live model
    // (which would also explain the rebuild in `update`). Confirm there.
    fn binary_log_prob_pair(&mut self, preferred_symbol: u8) -> (f64, f64) {
        let other_symbol = preferred_symbol ^ 1;
        let preferred_logp = match self.pending {
            // Cache hit: the live model already scored this symbol.
            Some((pending, logp)) if pending == preferred_symbol => logp,
            // Live model is committed to another symbol; use a replica.
            Some(_) => self.log_prob_from_history(preferred_symbol),
            // First query since the last update: score on the live model and
            // remember the result.
            None => {
                let logp = self.model.log_prob(preferred_symbol);
                self.pending = Some((preferred_symbol, logp));
                logp
            }
        };
        let other_logp = match self.pending {
            Some((pending, logp)) if pending == other_symbol => logp,
            _ => self.log_prob_from_history(other_symbol),
        };
        // Order the pair as (logp0, logp1) regardless of which was preferred.
        if preferred_symbol == 0 {
            (preferred_logp, other_logp)
        } else {
            (other_logp, preferred_logp)
        }
    }
}
impl Predictor for ZpaqPredictor {
    fn update(&mut self, sym: bool) {
        let byte = if sym { 1u8 } else { 0u8 };
        if let Some((pending, _)) = self.pending {
            if pending == byte {
                // The live model already scored exactly this symbol; fold the
                // update in directly.
                self.model.update(byte);
                self.pending = None;
                self.history.push(byte);
                return;
            }
            // The live model scored a different symbol; discard the cache
            // and rebuild from the recorded history before updating.
            self.pending = None;
            self.rebuild_from_history();
        }
        self.model.update(byte);
        self.history.push(byte);
    }
    fn revert(&mut self) {
        // No incremental undo: drop the last byte and replay everything.
        if self.history.pop().is_some() {
            self.pending = None;
            self.rebuild_from_history();
        }
    }
    fn predict_prob(&mut self, sym: bool) -> f64 {
        // Query the requested symbol first so it is the one scored (and
        // cached) on the live model.
        let preferred_symbol = if sym { 1u8 } else { 0u8 };
        let (logp0, logp1) = self.binary_log_prob_pair(preferred_symbol);
        let (p0, p1) = normalized_binary_prob_pair_from_log_probs(logp0, logp1, self.min_prob);
        if sym { p1 } else { p0 }
    }
    fn model_name(&self) -> String {
        format!("ZPAQ({})", self.method)
    }
    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            method: self.method.clone(),
            min_prob: self.min_prob,
            model: self.model.clone(),
            history: self.history.clone(),
            pending: self.pending,
        })
    }
}
/// Bit predictor over an arbitrary non-zpaq `RateBackend`.
///
/// Reversibility is implemented by snapshotting the whole inner predictor:
/// per-symbol snapshots go into `journal` (only while no rollback scope is
/// active) and bulk snapshots into `rollback_scopes`.
pub struct RateBackendBitPredictor {
    // Kept for cloning and for `model_name`.
    backend: RateBackend,
    max_order: i64,
    min_prob: f64,
    predictor: RateBackendPredictor,
    // Per-symbol snapshots for `revert` / `pop_history`.
    journal: Vec<RateBackendJournalEntry>,
    // Bulk snapshots for `begin_rollback_scope` / `rollback_scope`.
    rollback_scopes: Vec<RateBackendRollbackScope>,
}
/// Which update path produced a journal snapshot, so each rollback can be
/// checked against the matching update kind.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RateBackendJournalKind {
    /// Snapshot taken before a reversible `update`.
    Update,
    /// Snapshot taken before a reversible `update_history` (frozen update).
    FrozenUpdate,
}
/// One per-symbol rollback snapshot: the cloned inner predictor plus the
/// update kind it guards.
#[derive(Clone)]
struct RateBackendJournalEntry {
    kind: RateBackendJournalKind,
    predictor: RateBackendPredictor,
}
/// Bulk rollback snapshot taken by `begin_rollback_scope`: the cloned inner
/// predictor and the journal length to truncate back to on rollback.
#[derive(Clone)]
struct RateBackendRollbackScope {
    predictor: RateBackendPredictor,
    journal_len: usize,
}
impl RateBackendBitPredictor {
    /// Build with the default minimum-probability floor.
    pub fn new(backend: RateBackend, max_order: i64) -> Result<Self, String> {
        Self::new_with_min_prob(backend, max_order, DEFAULT_MIN_PROB)
    }
    /// Build a bit predictor over `backend` with an explicit probability
    /// floor.
    ///
    /// Returns an error if the backend contains a zpaq component (explicitly
    /// unsupported here) or if the underlying predictor stream fails to
    /// start.
    pub fn new_with_min_prob(
        backend: RateBackend,
        max_order: i64,
        min_prob: f64,
    ) -> Result<Self, String> {
        if rate_backend_contains_zpaq(&backend) {
            return Err(
                "RateBackendBitPredictor does not support zpaq backends; use a non-zpaq rate_backend"
                    .to_string(),
            );
        }
        let mut predictor =
            RateBackendPredictor::from_backend(backend.clone(), max_order, min_prob);
        predictor
            .begin_stream(None)
            .map_err(|err| format!("failed to start RateBackend predictor stream: {err}"))?;
        Ok(Self {
            backend,
            max_order,
            min_prob,
            predictor,
            journal: Vec::new(),
            rollback_scopes: Vec::new(),
        })
    }
    /// Map a bit onto the 0/1 byte alphabet fed to the backend.
    #[inline(always)]
    fn bit_to_byte(sym: bool) -> u8 {
        if sym { 1u8 } else { 0u8 }
    }
    /// Deep copy of the full predictor state, including any live journal
    /// entries and rollback scopes.
    fn clone_state(&self) -> Self {
        Self {
            backend: self.backend.clone(),
            max_order: self.max_order,
            min_prob: self.min_prob,
            predictor: self.predictor.clone(),
            journal: self.journal.clone(),
            rollback_scopes: self.rollback_scopes.clone(),
        }
    }
    /// Snapshot the inner predictor, tagged with the update kind that is
    /// about to be applied.
    fn checkpoint(&self, kind: RateBackendJournalKind) -> RateBackendJournalEntry {
        RateBackendJournalEntry {
            kind,
            predictor: self.predictor.clone(),
        }
    }
    /// Pop the latest journal snapshot and restore it, asserting that the
    /// rollback pairs with the expected update kind and that no bulk scope
    /// is active (per-symbol snapshots are not taken inside scopes).
    fn restore_last(&mut self, expected_kind: RateBackendJournalKind) {
        assert!(
            self.rollback_scopes.is_empty(),
            "RateBackendBitPredictor per-symbol rollback inside active scope is unsupported"
        );
        let entry = self
            .journal
            .pop()
            .expect("RateBackendBitPredictor rollback underflow");
        assert_eq!(
            entry.kind, expected_kind,
            "RateBackendBitPredictor rollback kind mismatch: expected {expected_kind:?}, got {:?}",
            entry.kind
        );
        self.predictor = entry.predictor;
    }
}
impl Predictor for RateBackendBitPredictor {
    fn update(&mut self, sym: bool) {
        // Only journal per-symbol snapshots outside bulk scopes; inside a
        // scope the scope snapshot already covers rollback.
        if self.rollback_scopes.is_empty() {
            self.journal
                .push(self.checkpoint(RateBackendJournalKind::Update));
        }
        self.predictor.update(Self::bit_to_byte(sym));
    }
    // Committed updates are permanent, so no snapshot is taken.
    fn commit_update(&mut self, sym: bool) {
        self.predictor.update(Self::bit_to_byte(sym));
    }
    fn update_history(&mut self, sym: bool) {
        if self.rollback_scopes.is_empty() {
            self.journal
                .push(self.checkpoint(RateBackendJournalKind::FrozenUpdate));
        }
        self.predictor.update_frozen(Self::bit_to_byte(sym));
    }
    fn commit_update_history(&mut self, sym: bool) {
        self.predictor.update_frozen(Self::bit_to_byte(sym));
    }
    fn revert(&mut self) {
        self.restore_last(RateBackendJournalKind::Update);
    }
    fn pop_history(&mut self) {
        self.restore_last(RateBackendJournalKind::FrozenUpdate);
    }
    fn begin_rollback_scope(&mut self) {
        self.rollback_scopes.push(RateBackendRollbackScope {
            predictor: self.predictor.clone(),
            journal_len: self.journal.len(),
        });
    }
    fn rollback_scope(&mut self) -> bool {
        let Some(scope) = self.rollback_scopes.pop() else {
            return false;
        };
        // Restore the bulk snapshot and drop any journal entries recorded
        // after the scope opened.
        self.predictor = scope.predictor;
        self.journal.truncate(scope.journal_len);
        true
    }
    fn predict_prob(&mut self, sym: bool) -> f64 {
        // Normalize the backend's log-probs for bytes 0 and 1 into a floored
        // binary distribution.
        let (p0, p1) = normalized_binary_prob_pair_from_log_probs(
            self.predictor.log_prob(0),
            self.predictor.log_prob(1),
            self.min_prob,
        );
        if sym { p1 } else { p0 }
    }
    fn model_name(&self) -> String {
        format!(
            "RateBackendBits({})",
            RateBackendPredictor::default_name(&self.backend, self.max_order)
        )
    }
    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(self.clone_state())
    }
}
#[cfg(feature = "backend-rwkv")]
use crate::coders::softmax_pdf_floor_inplace;
/// Bit predictor backed by an RWKV byte model; reversibility is implemented
/// by snapshotting `(state, pdf)` before each update.
#[cfg(feature = "backend-rwkv")]
pub struct RwkvPredictor {
    compressor: RwkvCompressor,
    // One (model state, pdf buffer) snapshot per reversible update.
    history: Vec<(RwkvState, Vec<f64>)>,
}
#[cfg(feature = "backend-rwkv")]
impl RwkvPredictor {
    /// Build from an already-loaded RWKV model and prime the pdf buffer with
    /// the distribution obtained after feeding token 0.
    pub fn new(model: Arc<RwkvModel>) -> Self {
        let mut compressor = RwkvCompressor::new_from_model(model);
        let vocab_size = compressor.vocab_size();
        // NOTE(review): token 0 is fed as the initial context here and in
        // `from_method` — presumably a start-of-stream convention; confirm.
        let logits = compressor
            .model
            .forward(&mut compressor.scratch, 0, &mut compressor.state);
        softmax_pdf_floor_inplace(logits, vocab_size, &mut compressor.pdf_buffer);
        Self {
            compressor,
            history: Vec::new(),
        }
    }
    /// Build from a textual method description; fails if the description
    /// cannot be parsed into a compressor.
    pub fn from_method(method: &str) -> Result<Self, String> {
        let mut compressor =
            RwkvCompressor::new_from_method(method).map_err(|err| err.to_string())?;
        compressor.forward_to_internal_pdf(0);
        Ok(Self {
            compressor,
            history: Vec::new(),
        })
    }
}
#[cfg(feature = "backend-rwkv")]
impl Predictor for RwkvPredictor {
    fn update(&mut self, sym: bool) {
        // Snapshot (state, pdf) so `revert` can restore them exactly.
        self.history.push((
            self.compressor.state.clone(),
            self.compressor.pdf_buffer.clone(),
        ));
        let byte = if sym { 1u32 } else { 0u32 };
        let vocab_size = self.compressor.vocab_size();
        let logits = self.compressor.model.forward(
            &mut self.compressor.scratch,
            byte,
            &mut self.compressor.state,
        );
        softmax_pdf_floor_inplace(logits, vocab_size, &mut self.compressor.pdf_buffer);
    }
    fn revert(&mut self) {
        if let Some((state, pdf)) = self.history.pop() {
            self.compressor.state = state;
            self.compressor.pdf_buffer = pdf;
        }
    }
    fn predict_prob(&mut self, sym: bool) -> f64 {
        // Only vocabulary entries 0 and 1 matter for bit prediction; the
        // remaining mass is renormalized away.
        let (p0, p1) = normalized_binary_prob_pair_from_probs(
            self.compressor.pdf_buffer[0],
            self.compressor.pdf_buffer[1],
            DEFAULT_MIN_PROB,
        );
        if sym { p1 } else { p0 }
    }
    fn model_name(&self) -> String {
        "RWKV".to_string()
    }
    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            compressor: self.compressor.clone(),
            history: self.history.clone(),
        })
    }
}
/// Bit predictor backed by a Mamba byte model; reversibility is implemented
/// by snapshotting `(state, pdf)` before each update.
#[cfg(feature = "backend-mamba")]
pub struct MambaPredictor {
    compressor: MambaCompressor,
    // One (model state, pdf buffer) snapshot per reversible update.
    history: Vec<(MambaState, Vec<f64>)>,
}
#[cfg(feature = "backend-mamba")]
impl MambaPredictor {
    /// Build from an already-loaded Mamba model and prime the pdf buffer
    /// with the distribution obtained after feeding token 0.
    pub fn new(model: Arc<MambaModel>) -> Self {
        let mut compressor = MambaCompressor::new_from_model(model);
        // NOTE(review): token 0 is used as the initial context, matching the
        // RWKV predictor — presumably a start-of-stream convention; confirm.
        let logits = compressor
            .model
            .forward(&mut compressor.scratch, 0, &mut compressor.state)
            .to_vec();
        let bias = compressor.online_bias_snapshot();
        MambaCompressor::logits_to_pdf(&logits, bias.as_deref(), &mut compressor.pdf_buffer);
        Self {
            compressor,
            history: Vec::new(),
        }
    }
    /// Build from a textual method description; fails if the description
    /// cannot be parsed into a compressor.
    pub fn from_method(method: &str) -> Result<Self, String> {
        let mut compressor =
            MambaCompressor::new_from_method(method).map_err(|err| err.to_string())?;
        let mut pdf = vec![0.0f64; compressor.vocab_size()];
        compressor.forward_to_pdf(0, &mut pdf);
        compressor.pdf_buffer.clone_from(&pdf);
        Ok(Self {
            compressor,
            history: Vec::new(),
        })
    }
}
#[cfg(feature = "backend-mamba")]
impl Predictor for MambaPredictor {
    fn update(&mut self, sym: bool) {
        // Snapshot (state, pdf) so `revert` can restore them exactly.
        self.history.push((
            self.compressor.state.clone(),
            self.compressor.pdf_buffer.clone(),
        ));
        let byte = if sym { 1u32 } else { 0u32 };
        let logits = self
            .compressor
            .model
            .forward(
                &mut self.compressor.scratch,
                byte,
                &mut self.compressor.state,
            )
            .to_vec();
        let bias = self.compressor.online_bias_snapshot();
        MambaCompressor::logits_to_pdf(&logits, bias.as_deref(), &mut self.compressor.pdf_buffer);
    }
    fn revert(&mut self) {
        if let Some((state, pdf)) = self.history.pop() {
            self.compressor.state = state;
            self.compressor.pdf_buffer = pdf;
        }
    }
    fn predict_prob(&mut self, sym: bool) -> f64 {
        // Only vocabulary entries 0 and 1 matter for bit prediction; the
        // remaining mass is renormalized away.
        let (p0, p1) = normalized_binary_prob_pair_from_probs(
            self.compressor.pdf_buffer[0],
            self.compressor.pdf_buffer[1],
            DEFAULT_MIN_PROB,
        );
        if sym { p1 } else { p0 }
    }
    fn model_name(&self) -> String {
        "Mamba".to_string()
    }
    fn boxed_clone(&self) -> Box<dyn Predictor> {
        Box::new(Self {
            compressor: self.compressor.clone(),
            history: self.history.clone(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Assert two probabilities agree to within 1e-12 absolute difference.
    fn approx_eq(a: f64, b: f64) {
        let diff = (a - b).abs();
        assert!(
            diff <= 1e-12,
            "expected probabilities to match exactly enough: left={a} right={b} diff={diff}"
        );
    }
    // Feed a fixed bit sequence and check that, before every commit, the
    // predictor's (p0, p1) sum to 1 and each lies in [0, 1].
    fn assert_binary_predictor_normalizes(mut predictor: Box<dyn Predictor>, label: &str) {
        for (step, &bit) in [false, true, true, false, true, false].iter().enumerate() {
            let p0 = predictor.predict_prob(false);
            let p1 = predictor.predict_prob(true);
            let sum = p0 + p1;
            assert!(
                (sum - 1.0).abs() < 1e-12,
                "{label}: probabilities must sum to 1 at step {step}, got p0={p0}, p1={p1}, sum={sum}",
            );
            assert!(
                (0.0..=1.0).contains(&p0) && (0.0..=1.0).contains(&p1),
                "{label}: probabilities must stay in [0,1] at step {step}, got p0={p0}, p1={p1}",
            );
            predictor.commit_update(bit);
        }
    }
    // Fingerprint predictive state: record (p0, p1) at every step while
    // committing `probe`. Equal signatures imply equal behavior along the
    // probe, so this is used to compare predictors after round-trips.
    fn predictor_signature(
        mut predictor: RateBackendBitPredictor,
        probe: &[bool],
    ) -> Vec<(f64, f64)> {
        let mut signature = Vec::with_capacity(probe.len());
        for &bit in probe {
            signature.push((predictor.predict_prob(false), predictor.predict_prob(true)));
            predictor.commit_update(bit);
        }
        signature
    }
    // Committed updates are permanent, so they must not accumulate journal
    // snapshots.
    #[test]
    fn committed_rate_backend_updates_do_not_grow_journal() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");
        for idx in 0..512usize {
            predictor.commit_update((idx & 1) == 0);
            predictor.commit_update_history((idx % 3) == 0);
        }
        assert!(
            predictor.journal.is_empty(),
            "committed history should not retain rollback snapshots"
        );
    }
    // update/revert and update_history/pop_history must each round-trip back
    // to the exact pre-update predictive state (compared via signatures).
    #[test]
    fn reversible_rate_backend_update_paths_round_trip_exactly() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");
        for &bit in &[true, false, true, true, false, false, true] {
            predictor.commit_update(bit);
        }
        let baseline_after_train = predictor.clone_state();
        predictor.update(true);
        predictor.update(false);
        predictor.revert();
        predictor.revert();
        assert_eq!(predictor.journal.len(), baseline_after_train.journal.len());
        let train_probe = [true, false, false, true, true, false];
        let got = predictor_signature(predictor.clone_state(), &train_probe);
        let want = predictor_signature(baseline_after_train.clone_state(), &train_probe);
        for ((got0, got1), (want0, want1)) in got.into_iter().zip(want.into_iter()) {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }
        // Repeat the round-trip check for the frozen-history update path.
        let baseline_after_history = baseline_after_train.clone_state();
        predictor.update_history(false);
        predictor.update_history(true);
        predictor.pop_history();
        predictor.pop_history();
        assert_eq!(
            predictor.journal.len(),
            baseline_after_history.journal.len()
        );
        let history_probe = [false, true, true, false, false, true];
        let got = predictor_signature(predictor.clone_state(), &history_probe);
        let want = predictor_signature(baseline_after_history, &history_probe);
        for ((got0, got1), (want0, want1)) in got.into_iter().zip(want.into_iter()) {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }
    }
    // Clones taken after a long committed run must carry no journal entries
    // and still behave identically to the source after reversible
    // round-trips.
    #[test]
    fn long_committed_history_does_not_contaminate_clone_rollback_state() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");
        for idx in 0..2048usize {
            predictor.commit_update((idx & 7) < 3);
            predictor.commit_update_history((idx % 5) < 2);
        }
        assert!(predictor.journal.is_empty());
        let mut cloned = predictor.clone_state();
        assert!(
            cloned.journal.is_empty(),
            "clone state should only carry active reversible rollback depth"
        );
        let baseline = predictor_signature(predictor.clone_state(), &[true, false, true, false]);
        cloned.update(true);
        cloned.revert();
        cloned.update_history(false);
        cloned.pop_history();
        assert!(cloned.journal.is_empty());
        let after_round_trip = predictor_signature(cloned, &[true, false, true, false]);
        for ((got0, got1), (want0, want1)) in after_round_trip.into_iter().zip(baseline.into_iter())
        {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }
    }
    // A bulk rollback scope must restore the pre-scope state exactly while
    // updates inside the scope skip per-bit journal snapshots.
    #[test]
    fn rollback_scope_restores_simulation_state_without_growing_journal() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");
        for &bit in &[true, false, true, false, true] {
            predictor.commit_update(bit);
        }
        let baseline = predictor_signature(predictor.clone_state(), &[true, true, false, false]);
        predictor.begin_rollback_scope();
        for idx in 0..512usize {
            predictor.update((idx & 1) == 0);
            predictor.update_history((idx % 3) == 0);
        }
        assert!(
            predictor.journal.is_empty(),
            "scoped reversible updates should not retain per-bit snapshots"
        );
        assert!(predictor.rollback_scope(), "scope rollback should succeed");
        assert!(predictor.journal.is_empty());
        let after = predictor_signature(predictor, &[true, true, false, false]);
        for ((got0, got1), (want0, want1)) in after.into_iter().zip(baseline.into_iter()) {
            approx_eq(got0, want0);
            approx_eq(got1, want1);
        }
    }
    // Clones carry active scope snapshots but never per-bit journal entries
    // created inside a scope.
    #[test]
    fn cloned_predictor_carries_only_active_scope_snapshots() {
        let mut predictor = RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
            .expect("rate backend predictor should initialize");
        for idx in 0..1024usize {
            predictor.commit_update((idx & 3) == 0);
        }
        predictor.begin_rollback_scope();
        for idx in 0..256usize {
            predictor.update((idx & 1) == 0);
        }
        let cloned = predictor.clone_state();
        assert!(
            cloned.journal.is_empty(),
            "scoped reversible updates should not leak per-bit journal state into clones"
        );
        assert_eq!(cloned.rollback_scopes.len(), 1);
    }
    // The normalization invariant must hold across the generic rate-backend
    // bit predictors (ROSA+, PPMd, Match).
    #[test]
    fn generic_rate_backend_bit_predictors_normalize_binary_mass() {
        assert_binary_predictor_normalizes(
            Box::new(
                RateBackendBitPredictor::new(RateBackend::RosaPlus, 8)
                    .expect("generic rosa predictor"),
            ),
            "generic-rosa",
        );
        assert_binary_predictor_normalizes(
            Box::new(
                RateBackendBitPredictor::new(
                    RateBackend::Ppmd {
                        order: 4,
                        memory_mb: 8,
                    },
                    8,
                )
                .expect("generic ppmd predictor"),
            ),
            "generic-ppmd",
        );
        assert_binary_predictor_normalizes(
            Box::new(
                RateBackendBitPredictor::new(
                    RateBackend::Match {
                        hash_bits: 16,
                        min_len: 2,
                        max_len: 32,
                        base_mix: 0.05,
                        confidence_scale: 1.0,
                    },
                    8,
                )
                .expect("generic match predictor"),
            ),
            "generic-match",
        );
    }
    // Same invariant for the ZPAQ-backed predictor (feature-gated).
    #[cfg(feature = "backend-zpaq")]
    #[test]
    fn zpaq_predictor_normalizes_binary_mass() {
        assert_binary_predictor_normalizes(
            Box::new(ZpaqPredictor::new("1".to_string(), DEFAULT_MIN_PROB)),
            "zpaq",
        );
    }
    // Same invariant for a tiny untrained RWKV model (feature-gated).
    #[cfg(feature = "backend-rwkv")]
    #[test]
    fn rwkv_predictor_normalizes_binary_mass() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,decay_rank=8,a_rank=8,v_rank=8,g_rank=8,seed=31,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
        let predictor = RwkvPredictor::from_method(method).expect("rwkv predictor");
        assert_binary_predictor_normalizes(Box::new(predictor), "rwkv");
    }
    // Same invariant for a tiny untrained Mamba model (feature-gated).
    #[cfg(feature = "backend-mamba")]
    #[test]
    fn mamba_predictor_normalizes_binary_mass() {
        let method = "cfg:hidden=64,layers=1,intermediate=64,state=8,conv=3,dt_rank=4,seed=7,train=none,lr=0.0,stride=1;policy:schedule=0..100:infer";
        let predictor = MambaPredictor::from_method(method).expect("mamba predictor");
        assert_binary_predictor_normalizes(Box::new(predictor), "mamba");
    }
}