use std::io::{self, Cursor, Read, Write};
use crate::dcx::{DcxHeader, FormatHint, Mode};
use crate::entropy::arithmetic::{ArithmeticDecoder, ArithmeticEncoder};
use crate::format::transform::TransformChain;
use crate::format::{detect_format, preprocess, reverse_preprocess};
use crate::mixer::MetaMixer;
use crate::model::gru_model::GruModel;
use crate::model::{CMConfig, CMEngine};
/// Picks the zstd level used by Fast mode for an input of `data_size` bytes.
///
/// An explicit `level_override` is returned unchanged. Otherwise inputs of up to 16 MiB use
/// level 19, inputs of up to 64 MiB use level 16, and anything larger uses level 9.
fn adaptive_fast_level(data_size: usize, level_override: Option<i32>) -> i32 {
    const MIB: usize = 1 << 20;
    level_override.unwrap_or_else(|| {
        if data_size <= 16 * MIB {
            19
        } else if data_size <= 64 * MIB {
            16
        } else {
            9
        }
    })
}
/// Minimum preprocessed input size (bytes) before Fast mode tries auto-trained dictionary
/// compression.
const DICT_MIN_DATA_SIZE: usize = 8 * 1024;
/// Returns the chunk size (bytes) used to split `data_len` bytes of input for
/// dictionary-based zstd compression.
///
/// Larger inputs use larger chunks: 16 KiB up to 256 KiB of input, 32 KiB up to 1 MiB,
/// 64 KiB up to 4 MiB, and 128 KiB beyond that.
fn dict_chunk_size(data_len: usize) -> usize {
    const KIB: usize = 1024;
    match data_len {
        n if n > 4096 * KIB => 128 * KIB,
        n if n > 1024 * KIB => 64 * KIB,
        n if n > 256 * KIB => 32 * KIB,
        _ => 16 * KIB,
    }
}
/// Returns the maximum trained-dictionary size (bytes) for `data_len` bytes of input.
///
/// The limit is 4 KiB up to 1 MiB of input, 8 KiB up to 4 MiB, and 16 KiB beyond that.
fn dict_max_size(data_len: usize) -> usize {
    const MIB: usize = 1 << 20;
    let kib = if data_len > 4 * MIB {
        16
    } else if data_len > MIB {
        8
    } else {
        4
    };
    kib * 1024
}
/// Produces the training samples used to build a zstd dictionary for `data`.
///
/// If `data` looks like NUL-separated records (at least five non-empty segments whose
/// average length is at least 8 bytes), those segments are returned as samples. Otherwise
/// `data` is cut into fixed `chunk_size` pieces with [`split_into_chunks`].
fn generate_training_samples(data: &[u8], chunk_size: usize) -> Vec<&[u8]> {
    // `try_dict_compress` rejects sample sets smaller than five, so the segment heuristic only
    // applies when it yields that many *non-empty* segments. Counting empty segments (e.g. from
    // runs of NUL bytes) could return an unusable sample set and skip the fixed-size fallback.
    const MIN_SEGMENTS: usize = 5;
    const MIN_AVG_SEGMENT_LEN: usize = 8;

    let segments: Vec<&[u8]> = data
        .split(|&b| b == 0x00)
        .filter(|segment| !segment.is_empty())
        .collect();
    if segments.len() >= MIN_SEGMENTS {
        let avg_len = segments.iter().map(|s| s.len()).sum::<usize>() / segments.len();
        if avg_len >= MIN_AVG_SEGMENT_LEN {
            return segments;
        }
    }
    split_into_chunks(data, chunk_size)
}
/// Splits `data` into consecutive, non-overlapping slices of `chunk_size` bytes.
///
/// The final slice holds the remainder and may be shorter. Empty input yields no chunks.
///
/// # Panics
/// Panics if `chunk_size` is zero, since no finite split exists.
fn split_into_chunks(data: &[u8], chunk_size: usize) -> Vec<&[u8]> {
    assert!(chunk_size > 0, "split_into_chunks: chunk_size must be non-zero");
    data.chunks(chunk_size).collect()
}
fn try_dict_compress(data: &[u8], level: i32, plain_size: usize) -> Option<Vec<u8>> {
let chunk_size = dict_chunk_size(data.len());
let training_samples = generate_training_samples(data, chunk_size);
if training_samples.len() < 5 {
return None;
}
let max_dict = dict_max_size(data.len());
let dict = zstd::dict::from_samples(&training_samples, max_dict).ok()?;
if dict.is_empty() {
return None;
}
let chunks = split_into_chunks(data, chunk_size);
let mut compressor = zstd::bulk::Compressor::with_dictionary(level, &dict).ok()?;
let mut compressed_chunks: Vec<Vec<u8>> = Vec::with_capacity(chunks.len());
for chunk in &chunks {
let cc = compressor.compress(chunk).ok()?;
compressed_chunks.push(cc);
}
let total_compressed: usize = compressed_chunks.iter().map(|c| 4 + c.len()).sum();
let payload_size = 4 + dict.len() + 4 + total_compressed;
if payload_size >= plain_size {
return None;
}
let mut payload = Vec::with_capacity(payload_size);
payload.extend_from_slice(&(dict.len() as u32).to_le_bytes());
payload.extend_from_slice(&dict);
payload.extend_from_slice(&(compressed_chunks.len() as u32).to_le_bytes());
for cc in &compressed_chunks {
payload.extend_from_slice(&(cc.len() as u32).to_le_bytes());
payload.extend_from_slice(cc);
}
Some(payload)
}
/// Decodes a dictionary payload produced by the Fast-mode dictionary paths.
///
/// Expected layout: `[dict_len: u32 LE][dict][chunk_count: u32 LE]`, then
/// `[chunk_len: u32 LE][zstd frame]` repeated `chunk_count` times. Each frame is decompressed
/// with the embedded dictionary and the results are concatenated. `capacity` bounds the total
/// output: each chunk may use at most what is left of it.
///
/// # Errors
/// Returns `InvalidData` if any length field or data section is truncated, the dictionary is
/// rejected by zstd, or a chunk fails to decompress.
fn decompress_with_dict(payload: &[u8], capacity: usize) -> std::io::Result<Vec<u8>> {
    /// Splits the next `len` bytes off the front of `rest`, or returns `None` if fewer remain.
    ///
    /// Uses `split_at_checked` instead of `pos + len` arithmetic, so an attacker-chosen length
    /// can never overflow `usize` (possible on 32-bit targets) and panic.
    fn take<'a>(rest: &mut &'a [u8], len: usize) -> Option<&'a [u8]> {
        let remaining: &'a [u8] = *rest;
        let (head, tail) = remaining.split_at_checked(len)?;
        *rest = tail;
        Some(head)
    }

    /// Reads a little-endian `u32` length field from the front of `rest`.
    fn take_u32(rest: &mut &[u8]) -> Option<usize> {
        let bytes = take(rest, 4)?;
        Some(u32::from_le_bytes(bytes.try_into().expect("4-byte slice")) as usize)
    }

    let mut rest = payload;
    let dict_size = take_u32(&mut rest).ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            "dict payload too short for dict_size",
        )
    })?;
    let dict_bytes = take(&mut rest, dict_size).ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            "dict payload truncated: dictionary bytes",
        )
    })?;
    let num_chunks = take_u32(&mut rest).ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            "dict payload truncated: num_chunks",
        )
    })?;

    let mut decompressor = zstd::bulk::Decompressor::with_dictionary(dict_bytes)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
    let mut output = Vec::with_capacity(capacity);
    for i in 0..num_chunks {
        let chunk_size = take_u32(&mut rest).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("dict payload truncated: chunk {i} size"),
            )
        })?;
        let chunk_data = take(&mut rest, chunk_size).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("dict payload truncated: chunk {i} data"),
            )
        })?;
        // Each chunk may only fill what is left of the overall output budget.
        let chunk_capacity = capacity.saturating_sub(output.len());
        let decompressed = decompressor
            .decompress(chunk_data, chunk_capacity)
            .map_err(|e| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("chunk {i} decompress failed: {e}"),
                )
            })?;
        output.extend_from_slice(&decompressed);
    }
    Ok(output)
}
/// `brotli_compress` mode selector mapping to `BrotliEncoderMode::BROTLI_MODE_GENERIC`.
const BROTLI_MODE_GENERIC: u32 = 0_u32;
/// `brotli_compress` mode selector mapping to `BrotliEncoderMode::BROTLI_MODE_TEXT`.
const BROTLI_MODE_TEXT: u32 = 1_u32;
fn brotli_compress(data: &[u8], quality: u32, mode: u32) -> io::Result<Vec<u8>> {
use brotli::enc::backward_references::BrotliEncoderMode;
let mut output = Vec::new();
let brotli_mode = match mode {
1 => BrotliEncoderMode::BROTLI_MODE_TEXT,
_ => BrotliEncoderMode::BROTLI_MODE_GENERIC,
};
let params = brotli::enc::BrotliEncoderParams {
quality: quality as i32,
mode: brotli_mode,
..Default::default()
};
brotli::BrotliCompress(&mut io::Cursor::new(data), &mut output, ¶ms)?;
Ok(output)
}
/// Decodes the complete Brotli stream held in `data` and returns the decompressed bytes.
///
/// # Errors
/// Propagates any I/O error reported by the Brotli decoder.
fn brotli_decompress(data: &[u8]) -> io::Result<Vec<u8>> {
    let mut source = io::Cursor::new(data);
    let mut decoded = Vec::new();
    brotli::BrotliDecompress(&mut source, &mut decoded)?;
    Ok(decoded)
}
/// Arithmetic-codes `data` bit by bit (most significant bit first) using a context-mixing
/// model built from `config`.
///
/// For every bit the model predicts first, the coder encodes the bit with that prediction,
/// and then the model is updated with the true bit. `cm_decompress` replays the same order.
fn cm_compress(data: &[u8], config: CMConfig) -> Vec<u8> {
    let mut model = CMEngine::with_config(config);
    let mut coder = ArithmeticEncoder::new();
    let bits = data
        .iter()
        .flat_map(|&byte| (0..8).rev().map(move |shift| (byte >> shift) & 1));
    for bit in bits {
        let prob = model.predict();
        coder.encode(bit, prob);
        model.update(bit);
    }
    coder.finish()
}
/// Inverse of `cm_compress`: decodes `original_size` bytes from `compressed` using a
/// context-mixing model built from the same `config`.
///
/// Bits are produced most significant first; the model predicts before each decode and is
/// updated with the decoded bit, mirroring the encoder exactly.
fn cm_decompress(compressed: &[u8], original_size: usize, config: CMConfig) -> Vec<u8> {
    let mut model = CMEngine::with_config(config);
    let mut coder = ArithmeticDecoder::new(compressed);
    let mut decoded = Vec::with_capacity(original_size);
    while decoded.len() < original_size {
        let mut acc = 0u8;
        for _ in 0..8 {
            let prob = model.predict();
            let bit = coder.decode(prob);
            model.update(bit);
            acc = (acc << 1) | bit;
        }
        decoded.push(acc);
    }
    decoded
}
/// Balanced-mode coder: blends a context-mixing model with a GRU byte model and
/// arithmetic-codes every bit of `data` (most significant bit first).
///
/// Per bit, the CM prediction and the GRU prediction (conditioned on the bits of the current
/// byte coded so far) are combined by a `MetaMixer`. After each full byte the GRU is trained on
/// it and then advanced. `gru_decompress` replays exactly the same sequence of model calls.
/// Inputs larger than 100 000 bytes report progress on stderr in 5% steps.
fn gru_compress(data: &[u8], config: CMConfig) -> Vec<u8> {
    let mut cm = CMEngine::with_config(config);
    let mut gru = GruModel::new();
    let mut mixer = MetaMixer::new(12);
    let mut coder = ArithmeticEncoder::new();

    let total = data.len();
    let verbose = total > 100_000;
    let step = if verbose { total / 20 } else { 0 };

    for (idx, &byte) in data.iter().enumerate() {
        for bpos in 0..8u8 {
            let bit = (byte >> (7 - bpos)) & 1;
            let p_cm = cm.predict();
            // A leading 1 marker followed by the `bpos` high bits of `byte` already coded.
            let partial = (0x100 | u32::from(byte)) >> (8 - bpos);
            let p_gru = gru.predict_bit(bpos, partial);
            coder.encode(bit, mixer.blend(p_cm, p_gru));
            cm.update(bit);
            mixer.update(bit);
        }
        gru.train(byte);
        gru.forward(byte);

        let done = idx + 1;
        if step > 0 && done % step == 0 {
            let pct = done * 100 / total;
            eprint!("\r[gru] compressing... {pct}%");
        }
    }
    if verbose {
        eprintln!("\r[gru] compressing... 100%");
    }
    coder.finish()
}
/// Inverse of `gru_compress`: decodes `original_size` bytes from `compressed` by replaying the
/// same CM + GRU + `MetaMixer` prediction sequence.
///
/// Inputs larger than 100 000 bytes report progress on stderr in 5% steps.
fn gru_decompress(compressed: &[u8], original_size: usize, config: CMConfig) -> Vec<u8> {
    let mut cm = CMEngine::with_config(config);
    let mut gru = GruModel::new();
    let mut mixer = MetaMixer::new(12);
    let mut coder = ArithmeticDecoder::new(compressed);
    let mut output = Vec::with_capacity(original_size);

    let verbose = original_size > 100_000;
    let step = if verbose { original_size / 20 } else { 0 };

    for idx in 0..original_size {
        // `byte_val` accumulates the decoded bits in its low bits, MSB-first.
        let mut byte_val: u8 = 0;
        for bpos in 0..8u8 {
            let p_cm = cm.predict();
            // A leading 1 marker followed by the `bpos` bits decoded so far.
            let partial = (1u32 << bpos) | u32::from(byte_val);
            let p_gru = gru.predict_bit(bpos, partial);
            let bit = coder.decode(mixer.blend(p_cm, p_gru));
            cm.update(bit);
            mixer.update(bit);
            byte_val = (byte_val << 1) | bit;
        }
        output.push(byte_val);
        gru.train(byte_val);
        gru.forward(byte_val);

        let done = idx + 1;
        if step > 0 && done % step == 0 {
            let pct = done * 100 / original_size;
            eprint!("\r[gru] decompressing... {pct}%");
        }
    }
    if verbose {
        eprintln!("\r[gru] decompressing... 100%");
    }
    output
}
/// Max-mode neural coder: blends a context-mixing model with an LLM bit predictor and
/// arithmetic-codes every bit of `data` (most significant bit first).
///
/// Per bit, the CM and LLM predictions are combined by `meta_mixer`. After each byte,
/// `llm.predict_byte_probs(byte)` is called to feed the byte to the LLM. Its errors are
/// tolerated, and only those at the first five byte positions are logged. `neural_decompress`
/// must replay the same call sequence.
/// Inputs larger than 1000 bytes report progress on stderr in 5% steps.
#[cfg(feature = "neural")]
fn neural_compress(
    data: &[u8],
    config: CMConfig,
    llm: &mut datacortex_neural::LlmPredictor,
    meta_mixer: &mut datacortex_neural::MetaMixer,
) -> Vec<u8> {
    let mut engine = CMEngine::with_config(config);
    let mut encoder = ArithmeticEncoder::new();
    let total_bytes = data.len();
    // Periodic updates and the final "100%" line share one threshold, so a `\r` progress line
    // is never left unterminated on stderr.
    let show_progress = total_bytes > 1000;
    let report_interval = if show_progress { total_bytes / 20 } else { 0 };

    for (byte_idx, &byte) in data.iter().enumerate() {
        // A leading 1 marker followed by the bits of `byte` coded so far.
        let mut partial = 1u32;
        for bpos in 0..8u8 {
            let bit = (byte >> (7 - bpos)) & 1;
            let p_cm = engine.predict();
            let p_llm = llm.predict_bit(bpos, partial);
            let p_final = meta_mixer.blend(p_cm, p_llm);
            encoder.encode(bit, p_final);
            engine.update(bit);
            meta_mixer.update(bit);
            partial = (partial << 1) | u32::from(bit);
        }
        if let Err(e) = llm.predict_byte_probs(byte) {
            if byte_idx < 5 {
                eprintln!("[neural] LLM predict error at byte {byte_idx}: {e}");
            }
        }
        let done = byte_idx + 1;
        if report_interval > 0 && done % report_interval == 0 {
            let pct = done * 100 / total_bytes;
            eprint!("\r[neural] compressing... {pct}%");
        }
    }
    if show_progress {
        eprintln!("\r[neural] compressing... 100%");
    }
    encoder.finish()
}
/// Inverse of `neural_compress`: decodes `original_size` bytes from `compressed` by replaying
/// the same CM + LLM + `meta_mixer` prediction sequence.
///
/// LLM byte-update errors are tolerated, and only those at the first five byte positions are
/// logged. Inputs larger than 1000 bytes report progress on stderr in 5% steps.
#[cfg(feature = "neural")]
fn neural_decompress(
    compressed: &[u8],
    original_size: usize,
    config: CMConfig,
    llm: &mut datacortex_neural::LlmPredictor,
    meta_mixer: &mut datacortex_neural::MetaMixer,
) -> Vec<u8> {
    let mut engine = CMEngine::with_config(config);
    let mut decoder = ArithmeticDecoder::new(compressed);
    let mut output = Vec::with_capacity(original_size);
    // Periodic updates and the final "100%" line share one threshold, so a `\r` progress line
    // is never left unterminated on stderr.
    let show_progress = original_size > 1000;
    let report_interval = if show_progress { original_size / 20 } else { 0 };

    for byte_idx in 0..original_size {
        let mut byte_val: u8 = 0;
        // A leading 1 marker followed by the bits decoded so far for this byte.
        let mut partial = 1u32;
        for bpos in 0..8u8 {
            let p_cm = engine.predict();
            let p_llm = llm.predict_bit(bpos, partial);
            let p_final = meta_mixer.blend(p_cm, p_llm);
            let bit = decoder.decode(p_final);
            engine.update(bit);
            meta_mixer.update(bit);
            byte_val |= bit << (7 - bpos);
            partial = (partial << 1) | u32::from(bit);
        }
        output.push(byte_val);
        if let Err(e) = llm.predict_byte_probs(byte_val) {
            if byte_idx < 5 {
                eprintln!("[neural] LLM predict error at byte {byte_idx}: {e}");
            }
        }
        let done = byte_idx + 1;
        if report_interval > 0 && done % report_interval == 0 {
            let pct = done * 100 / original_size;
            eprint!("\r[neural] decompressing... {pct}%");
        }
    }
    if show_progress {
        eprintln!("\r[neural] decompressing... 100%");
    }
    output
}
/// Returns the context-mixing configuration for `mode`.
///
/// Max uses `CMConfig::max()`. Balanced and Fast share `CMConfig::balanced()`; in this file
/// the CM coders are only reached from the Balanced and Max paths.
fn cm_config_for_mode(mode: Mode) -> CMConfig {
    match mode {
        Mode::Max => CMConfig::max(),
        Mode::Balanced | Mode::Fast => CMConfig::balanced(),
    }
}
/// Resolves the LLM model file to load, in priority order:
///
/// 1. `explicit`: used if the path exists; otherwise a warning is printed and `None` is
///    returned, with no fallback.
/// 2. The `DATACORTEX_MODEL` environment variable: an empty value disables the model, a
///    missing path prints a warning and returns `None`, with no fallback.
/// 3. `$HOME/.datacortex/models/SmolLM2-135M-Instruct-Q8_0.gguf`, if it exists.
#[cfg(feature = "neural")]
fn resolve_model_path(explicit: Option<&str>) -> Option<String> {
    use std::path::Path;

    if let Some(path) = explicit {
        return if Path::new(path).exists() {
            Some(path.to_string())
        } else {
            eprintln!("[neural] explicit model path not found: {path}");
            None
        };
    }

    match std::env::var("DATACORTEX_MODEL") {
        Ok(path) if path.is_empty() => return None,
        Ok(path) if Path::new(&path).exists() => return Some(path),
        Ok(path) => {
            eprintln!("[neural] DATACORTEX_MODEL path not found: {path}");
            return None;
        }
        // Unset or not valid Unicode: fall through to the default location.
        Err(_) => {}
    }

    let home = std::env::var_os("HOME")?;
    let default = format!(
        "{}/.datacortex/models/SmolLM2-135M-Instruct-Q8_0.gguf",
        home.to_string_lossy()
    );
    if Path::new(&default).exists() {
        Some(default)
    } else {
        None
    }
}
/// Trains a zstd dictionary of at most `max_dict_size` bytes from `samples`.
///
/// Each non-empty sample contributes its non-empty `\n`-separated lines when it has at least
/// five of them; otherwise it is cut into 4096-byte fragments. At least five fragments are
/// required in total.
///
/// # Errors
/// Returns an error if `samples` is empty, too few fragments are produced, zstd training
/// fails, or training yields an empty dictionary.
pub fn train_dict(samples: &[&[u8]], max_dict_size: usize) -> io::Result<Vec<u8>> {
    if samples.is_empty() {
        return Err(io::Error::other(
            "no samples provided for dictionary training",
        ));
    }

    let mut fragments: Vec<&[u8]> = Vec::new();
    for &sample in samples.iter().filter(|s| !s.is_empty()) {
        let lines: Vec<&[u8]> = sample
            .split(|&b| b == b'\n')
            .filter(|line| !line.is_empty())
            .collect();
        if lines.len() >= 5 {
            fragments.extend(lines);
        } else {
            // Line-poor samples (e.g. binary data) are split into fixed 4 KiB pieces instead.
            fragments.extend(sample.chunks(4096));
        }
    }

    if fragments.len() < 5 {
        return Err(io::Error::other(
            "not enough training data (need at least 5 fragments)",
        ));
    }

    let dict = zstd::dict::from_samples(&fragments, max_dict_size)
        .map_err(|e| io::Error::other(format!("dictionary training failed: {e}")))?;
    if dict.is_empty() {
        Err(io::Error::other(
            "dictionary training produced empty dictionary",
        ))
    } else {
        Ok(dict)
    }
}
pub fn compress<W: Write>(
data: &[u8],
mode: Mode,
format_override: Option<FormatHint>,
output: &mut W,
) -> io::Result<()> {
compress_with_model(data, mode, format_override, None, output)
}
pub fn compress_with_model<W: Write>(
data: &[u8],
mode: Mode,
format_override: Option<FormatHint>,
model_path: Option<&str>,
output: &mut W,
) -> io::Result<()> {
compress_with_options(data, mode, format_override, model_path, None, output)
}
/// Like `compress_with_model`, additionally forcing the Fast-mode zstd level via
/// `zstd_level_override`.
///
/// No external dictionary is used.
///
/// # Errors
/// Propagates compression and write errors.
pub fn compress_with_options<W: Write>(
    data: &[u8],
    mode: Mode,
    format_override: Option<FormatHint>,
    model_path: Option<&str>,
    zstd_level_override: Option<i32>,
    output: &mut W,
) -> io::Result<()> {
    let external_dict: Option<&[u8]> = None;
    compress_with_full_options(
        data,
        mode,
        format_override,
        model_path,
        zstd_level_override,
        external_dict,
        output,
    )
}
/// Compresses `data` with every option exposed and writes a `.dcx` stream to `output`.
///
/// Pipeline:
/// 1. Take the input format from `format_override`, or detect it.
/// 2. Compute the CRC-32 of the original bytes.
/// 3. Run the format-specific `preprocess` transforms.
/// 4. Encode according to `mode`:
///    - `Mode::Fast` runs several zstd/Brotli candidates in parallel (see the comments below)
///      and keeps the one with the smallest estimated output size.
///    - `Mode::Balanced` runs the CM+GRU coder (`gru_compress`).
///    - `Mode::Max` runs the CM coder. With the `neural` feature and a model resolved from
///      `model_path` or the environment, it runs CM+LLM instead.
/// 5. Write the `DcxHeader` (transform metadata is zstd-compressed when that is smaller),
///    followed by the payload, to `output`.
///
/// `zstd_level_override` replaces the size-based zstd levels in Fast mode. `external_dict`,
/// when given, is tried instead of an auto-trained dictionary on the preprocessed zstd path.
/// If it is used, it is embedded verbatim in the payload.
///
/// # Errors
/// Returns an error if every Fast-mode candidate fails, or if writing to `output` fails.
pub fn compress_with_full_options<W: Write>(
    data: &[u8],
    mode: Mode,
    format_override: Option<FormatHint>,
    model_path: Option<&str>,
    zstd_level_override: Option<i32>,
    external_dict: Option<&[u8]>,
    output: &mut W,
) -> io::Result<()> {
    let format_hint = format_override.unwrap_or_else(|| detect_format(data));
    // The CRC covers the ORIGINAL bytes; decompression checks it after undoing all transforms.
    let crc = crc32fast::hash(data);
    let (preprocessed, chain) = preprocess(data, format_hint, mode);
    let transform_metadata = if chain.is_empty() {
        vec![]
    } else {
        chain.serialize()
    };
    // Header flags; only the Fast branch can set them (from the winning candidate).
    let mut use_dict = false;
    let mut use_brotli = false;
    let mut use_raw_fallback = false;
    let mut use_meta_embedded = false;
    let compressed = match mode {
        Mode::Fast => {
            use std::sync::Mutex;
            // zstd levels scale with the size of whatever each candidate compresses.
            let level = adaptive_fast_level(preprocessed.len(), zstd_level_override);
            let raw_level = adaptive_fast_level(data.len(), zstd_level_override);
            // Estimated header cost of the transform metadata. This mirrors the header logic
            // at the end of this function: metadata above 64 bytes is zstd-compressed if
            // that is smaller.
            let meta_size_for_comparison = if transform_metadata.len() > 64 {
                let compressed_meta = zstd::bulk::compress(&transform_metadata, 19)
                    .unwrap_or_else(|_| transform_metadata.clone());
                compressed_meta.len().min(transform_metadata.len())
            } else {
                transform_metadata.len()
            };
            // Alternative layout that stores the metadata inside the compressed stream:
            // [meta_len: u32 LE][metadata][preprocessed bytes]. Only built when there is
            // metadata to embed.
            let embedded_payload = if !transform_metadata.is_empty() {
                let mut ep = Vec::with_capacity(4 + transform_metadata.len() + preprocessed.len());
                ep.extend_from_slice(&(transform_metadata.len() as u32).to_le_bytes());
                ep.extend_from_slice(&transform_metadata);
                ep.extend_from_slice(&preprocessed);
                Some(ep)
            } else {
                None
            };
            // Candidate tuple: (payload, estimated total size, has_dict, raw_fallback,
            // use_brotli, meta_embedded). Each estimate adds a constant 32 bytes, presumably
            // the fixed header size (TODO confirm against DcxHeader).
            type PathResult = (Vec<u8>, usize, bool, bool, bool, bool);
            let results = Mutex::new(Vec::<PathResult>::with_capacity(8));
            rayon::scope(|s| {
                // Candidate 1: zstd on the preprocessed bytes, optionally improved with a
                // dictionary (the external one if given, else an auto-trained one for inputs
                // of at least DICT_MIN_DATA_SIZE bytes).
                s.spawn(|_| {
                    if let Ok(plain) = zstd::bulk::compress(&preprocessed, level) {
                        let (compressed, is_dict) = if let Some(ext_dict) = external_dict {
                            let chunk_size = dict_chunk_size(preprocessed.len());
                            let chunks = split_into_chunks(&preprocessed, chunk_size);
                            if let Ok(mut compressor) =
                                zstd::bulk::Compressor::with_dictionary(level, ext_dict)
                            {
                                let mut ok = true;
                                let mut cc_list = Vec::with_capacity(chunks.len());
                                for chunk in &chunks {
                                    match compressor.compress(chunk) {
                                        Ok(cc) => cc_list.push(cc),
                                        Err(_) => {
                                            ok = false;
                                            break;
                                        }
                                    }
                                }
                                if ok {
                                    // Same payload layout as `try_dict_compress`. The full
                                    // external dictionary is embedded, so it is only used if
                                    // it still beats plain zstd.
                                    let total_cc: usize = cc_list.iter().map(|c| 4 + c.len()).sum();
                                    let payload_size = 4 + ext_dict.len() + 4 + total_cc;
                                    if payload_size < plain.len() {
                                        let mut payload = Vec::with_capacity(payload_size);
                                        payload.extend_from_slice(
                                            &(ext_dict.len() as u32).to_le_bytes(),
                                        );
                                        payload.extend_from_slice(ext_dict);
                                        payload.extend_from_slice(
                                            &(cc_list.len() as u32).to_le_bytes(),
                                        );
                                        for cc in &cc_list {
                                            payload.extend_from_slice(
                                                &(cc.len() as u32).to_le_bytes(),
                                            );
                                            payload.extend_from_slice(cc);
                                        }
                                        (payload, true)
                                    } else {
                                        (plain, false)
                                    }
                                } else {
                                    (plain, false)
                                }
                            } else {
                                (plain, false)
                            }
                        } else if preprocessed.len() >= DICT_MIN_DATA_SIZE {
                            if let Some(dict_payload) =
                                try_dict_compress(&preprocessed, level, plain.len())
                            {
                                (dict_payload, true)
                            } else {
                                (plain, false)
                            }
                        } else {
                            (plain, false)
                        };
                        let total = 32 + meta_size_for_comparison + compressed.len();
                        results
                            .lock()
                            .unwrap()
                            .push((compressed, total, is_dict, false, false, false));
                    }
                });
                // Candidate 2: zstd on the raw input. No transform metadata is stored for it.
                s.spawn(|_| {
                    if let Ok(compressed) = zstd::bulk::compress(data, raw_level) {
                        let total = 32 + compressed.len();
                        results
                            .lock()
                            .unwrap()
                            .push((compressed, total, false, true, false, false));
                    }
                });
                // Candidate 3: Brotli (text mode) on the raw input; quality 11 up to 1 MiB,
                // else 9.
                s.spawn(|_| {
                    let q = if data.len() <= 1_048_576 { 11 } else { 9 };
                    if let Ok(compressed) = brotli_compress(data, q, BROTLI_MODE_TEXT) {
                        let total = 32 + compressed.len();
                        results
                            .lock()
                            .unwrap()
                            .push((compressed, total, false, true, true, false));
                    }
                });
                // Candidate 4: Brotli (generic mode) on the preprocessed bytes. Up to 1 MiB,
                // qualities 11 and 10 are both tried and the smaller result kept.
                s.spawn(|_| {
                    let max_q = if preprocessed.len() <= 1_048_576 {
                        11
                    } else {
                        9
                    };
                    let qualities: &[u32] = if max_q == 11 {
                        &[11, 10]
                    } else {
                        &[max_q as u32]
                    };
                    let mut best: Option<PathResult> = None;
                    for &q in qualities {
                        if let Ok(compressed) =
                            brotli_compress(&preprocessed, q, BROTLI_MODE_GENERIC)
                        {
                            let total = 32 + meta_size_for_comparison + compressed.len();
                            if best.as_ref().is_none_or(|b| total < b.1) {
                                best = Some((compressed, total, false, false, true, false));
                            }
                        }
                    }
                    if let Some(r) = best {
                        results.lock().unwrap().push(r);
                    }
                });
                // Candidate 5: Brotli (generic mode) on the metadata-embedded payload.
                if let Some(ref ep) = embedded_payload {
                    s.spawn(|_| {
                        let max_q = if ep.len() <= 1_048_576 { 11 } else { 9 };
                        let qualities: &[u32] = if max_q == 11 {
                            &[11, 10]
                        } else {
                            &[max_q as u32]
                        };
                        let mut best: Option<PathResult> = None;
                        for &q in qualities {
                            if let Ok(compressed) = brotli_compress(ep, q, BROTLI_MODE_GENERIC) {
                                let total = 32 + compressed.len();
                                if best.as_ref().is_none_or(|b| total < b.1) {
                                    best = Some((compressed, total, false, false, true, true));
                                }
                            }
                        }
                        if let Some(r) = best {
                            results.lock().unwrap().push(r);
                        }
                    });
                }
                // Candidate 6: zstd on the metadata-embedded payload.
                if let Some(ref ep) = embedded_payload {
                    s.spawn(|_| {
                        let embed_level = adaptive_fast_level(ep.len(), zstd_level_override);
                        if let Ok(compressed) = zstd::bulk::compress(ep, embed_level) {
                            let total = 32 + compressed.len();
                            results
                                .lock()
                                .unwrap()
                                .push((compressed, total, false, false, false, true));
                        }
                    });
                }
            });
            // All spawned tasks have finished once `rayon::scope` returns. Keep the candidate
            // with the smallest estimated total size and adopt its flags for the header.
            let results = results.into_inner().unwrap();
            let best = results
                .into_iter()
                .min_by_key(|r| r.1)
                .ok_or_else(|| io::Error::other("all compression paths failed"))?;
            use_dict = best.2;
            use_raw_fallback = best.3;
            use_brotli = best.4;
            use_meta_embedded = best.5;
            best.0
        }
        Mode::Balanced => {
            // Payload: [preprocessed length: u64 LE][CM+GRU coded bits].
            let config = cm_config_for_mode(mode);
            let cm_data = gru_compress(&preprocessed, config);
            let mut payload = Vec::with_capacity(8 + cm_data.len());
            payload.extend_from_slice(&(preprocessed.len() as u64).to_le_bytes());
            payload.extend_from_slice(&cm_data);
            payload
        }
        Mode::Max => {
            // Payload: [preprocessed length: u64 LE][coded bits]. Bit 63 of the length is set
            // only on the CM+LLM path, so the decoder knows it must load the model.
            let config = cm_config_for_mode(mode);
            #[cfg(feature = "neural")]
            {
                if let Some(mpath) = resolve_model_path(model_path) {
                    match datacortex_neural::LlmPredictor::new(&mpath) {
                        Ok(mut llm) => {
                            let mut meta_mixer = datacortex_neural::MetaMixer::new(5);
                            eprintln!(
                                "[neural] Max mode: dual-path CM+LLM ({} bytes mapped)",
                                llm.mapped_bytes()
                            );
                            let cm_data =
                                neural_compress(&preprocessed, config, &mut llm, &mut meta_mixer);
                            let mut payload = Vec::with_capacity(8 + cm_data.len());
                            let size_with_flag = preprocessed.len() as u64 | (1u64 << 63);
                            payload.extend_from_slice(&size_with_flag.to_le_bytes());
                            payload.extend_from_slice(&cm_data);
                            payload
                        }
                        Err(e) => {
                            eprintln!("[neural] LLM init failed, falling back to CM-only: {e}");
                            let cm_data = cm_compress(&preprocessed, config);
                            let mut payload = Vec::with_capacity(8 + cm_data.len());
                            payload.extend_from_slice(&(preprocessed.len() as u64).to_le_bytes());
                            payload.extend_from_slice(&cm_data);
                            payload
                        }
                    }
                } else {
                    eprintln!(
                        "[neural] no model found, Max mode using CM-only. \
                         Set DATACORTEX_MODEL or use --model-path."
                    );
                    let cm_data = cm_compress(&preprocessed, config);
                    let mut payload = Vec::with_capacity(8 + cm_data.len());
                    payload.extend_from_slice(&(preprocessed.len() as u64).to_le_bytes());
                    payload.extend_from_slice(&cm_data);
                    payload
                }
            }
            #[cfg(not(feature = "neural"))]
            {
                // `model_path` is only meaningful in neural builds; mark it as used here.
                let _ = model_path;
                let cm_data = cm_compress(&preprocessed, config);
                let mut payload = Vec::with_capacity(8 + cm_data.len());
                payload.extend_from_slice(&(preprocessed.len() as u64).to_le_bytes());
                payload.extend_from_slice(&cm_data);
                payload
            }
        }
    };
    // Raw-fallback payloads need no transform reversal, and embedded payloads carry their own
    // metadata, so in both cases the header stores none.
    let final_metadata = if use_raw_fallback || use_meta_embedded {
        vec![]
    } else {
        transform_metadata
    };
    // Metadata above 64 bytes is zstd-compressed (level 19), but only kept if that is smaller.
    let (header_metadata, meta_compressed) = if final_metadata.len() > 64 {
        let compressed_meta =
            zstd::bulk::compress(&final_metadata, 19).unwrap_or_else(|_| final_metadata.clone());
        if compressed_meta.len() < final_metadata.len() {
            (compressed_meta, true)
        } else {
            (final_metadata, false)
        }
    } else {
        (final_metadata, false)
    };
    let header = DcxHeader {
        mode,
        format_hint,
        original_size: data.len() as u64,
        compressed_size: compressed.len() as u64,
        crc32: crc,
        transform_metadata: header_metadata,
        has_dict: use_dict,
        meta_compressed,
        use_brotli,
        meta_embedded: use_meta_embedded,
    };
    header.write_to(output)?;
    output.write_all(&compressed)?;
    Ok(())
}
/// Decompresses a `.dcx` stream from `input` without an explicit LLM model path.
///
/// Equivalent to `decompress_with_model(input, None)`.
///
/// # Errors
/// Propagates all errors from `decompress_with_model`.
pub fn decompress<R: Read>(input: &mut R) -> io::Result<Vec<u8>> {
    let no_model: Option<&str> = None;
    decompress_with_model(input, no_model)
}
/// Decompresses a `.dcx` stream from `input`, optionally using an explicit LLM model path.
///
/// Steps:
/// 1. Read the `DcxHeader`, then exactly `compressed_size` payload bytes.
/// 2. Decode the payload with the backend recorded in the header: zstd, zstd with an embedded
///    dictionary, or Brotli for Fast; CM+GRU for Balanced; CM or CM+LLM for Max.
/// 3. Recover the transform metadata (embedded in the stream, or from the header, possibly
///    zstd-compressed) and undo the preprocessing transforms.
/// 4. Verify the CRC-32 and original size stored in the header.
///
/// `model_path` is only consulted for Max-mode payloads flagged as neural, and only in builds
/// with the `neural` feature.
///
/// # Errors
/// - `UnexpectedEof` if the stream ends before the declared payload length.
/// - `InvalidData` for malformed payloads or metadata, an implausible size field, or a CRC or
///   size mismatch.
/// - Other errors if a neural payload cannot be decoded (missing feature, model, or LLM).
pub fn decompress_with_model<R: Read>(
    input: &mut R,
    model_path: Option<&str>,
) -> io::Result<Vec<u8>> {
    let header = DcxHeader::read_from(input)?;
    // `compressed_size` comes from an untrusted header. Read through `take` so the buffer only
    // grows with bytes that are actually present, instead of pre-allocating the declared size.
    // A crafted header could otherwise request an enormous allocation, and an `as usize` cast
    // would silently truncate the length on 32-bit targets.
    let mut compressed = Vec::new();
    input
        .by_ref()
        .take(header.compressed_size)
        .read_to_end(&mut compressed)?;
    if compressed.len() as u64 != header.compressed_size {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            format!(
                "compressed payload truncated: header says {} bytes, got {}",
                header.compressed_size,
                compressed.len()
            ),
        ));
    }
    let preprocessed = match header.mode {
        Mode::Fast => {
            if header.use_brotli {
                brotli_decompress(&compressed)?
            } else {
                // Output bound handed to zstd: twice the original size plus 64 KiB of slack.
                // Checked arithmetic turns an absurd header value into an error instead of an
                // overflow panic (debug) or a silently wrapped bound (release).
                // NOTE(review): a large but non-overflowing value still drives a large
                // up-front allocation inside zstd.
                let capacity = usize::try_from(header.original_size)
                    .ok()
                    .and_then(|n| n.checked_mul(2))
                    .and_then(|n| n.checked_add(65536))
                    .ok_or_else(|| {
                        io::Error::new(
                            io::ErrorKind::InvalidData,
                            "original size in header is too large",
                        )
                    })?;
                if header.has_dict {
                    decompress_with_dict(&compressed, capacity)?
                } else {
                    zstd::bulk::decompress(&compressed, capacity)
                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?
                }
            }
        }
        Mode::Balanced => {
            if compressed.len() < 8 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "CM mode compressed data too short",
                ));
            }
            // 8-byte LE length prefix of the preprocessed stream; bit 63 is masked off.
            let size_raw = u64::from_le_bytes(compressed[..8].try_into().expect("8-byte slice"));
            let preprocessed_size = (size_raw & !(1u64 << 63)) as usize;
            let config = cm_config_for_mode(header.mode);
            gru_decompress(&compressed[8..], preprocessed_size, config)
        }
        Mode::Max => {
            if compressed.len() < 8 {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    "CM mode compressed data too short",
                ));
            }
            // 8-byte LE length prefix; bit 63 flags a payload coded with CM+LLM.
            let size_raw = u64::from_le_bytes(compressed[..8].try_into().expect("8-byte slice"));
            let neural_flag = size_raw & (1u64 << 63) != 0;
            let preprocessed_size = (size_raw & !(1u64 << 63)) as usize;
            let config = cm_config_for_mode(header.mode);
            if neural_flag {
                #[cfg(feature = "neural")]
                {
                    if let Some(mpath) = resolve_model_path(model_path) {
                        match datacortex_neural::LlmPredictor::new(&mpath) {
                            Ok(mut llm) => {
                                let mut meta_mixer = datacortex_neural::MetaMixer::new(5);
                                eprintln!(
                                    "[neural] decompressing with dual-path CM+LLM ({} bytes mapped)",
                                    llm.mapped_bytes()
                                );
                                neural_decompress(
                                    &compressed[8..],
                                    preprocessed_size,
                                    config,
                                    &mut llm,
                                    &mut meta_mixer,
                                )
                            }
                            Err(e) => {
                                return Err(io::Error::other(format!(
                                    "file was compressed with neural mode but LLM failed to load: {e}"
                                )));
                            }
                        }
                    } else {
                        return Err(io::Error::other(
                            "file was compressed with neural mode but no model found. \
                             Set DATACORTEX_MODEL or use --model-path.",
                        ));
                    }
                }
                #[cfg(not(feature = "neural"))]
                {
                    let _ = model_path;
                    return Err(io::Error::other(
                        "file was compressed with neural mode but this build lacks the \
                         `neural` feature. Rebuild with --features neural.",
                    ));
                }
            } else {
                cm_decompress(&compressed[8..], preprocessed_size, config)
            }
        }
    };
    // Recover the transform metadata: either the prefix of the decoded stream
    // ([meta_len: u32 LE][metadata][preprocessed]) or the header field (possibly zstd-compressed).
    let (preprocessed, transform_metadata) = if header.meta_embedded {
        if preprocessed.len() < 4 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "embedded metadata: decompressed stream too short for meta_len",
            ));
        }
        let meta_len =
            u32::from_le_bytes(preprocessed[0..4].try_into().expect("4-byte slice")) as usize;
        if preprocessed.len() < 4 + meta_len {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "embedded metadata: stream too short for metadata ({} bytes needed, {} available)",
                    4 + meta_len,
                    preprocessed.len()
                ),
            ));
        }
        let metadata = preprocessed[4..4 + meta_len].to_vec();
        let actual_preprocessed = preprocessed[4 + meta_len..].to_vec();
        (actual_preprocessed, metadata)
    } else {
        let tm = if header.meta_compressed && !header.transform_metadata.is_empty() {
            let mut decoder =
                zstd::Decoder::new(Cursor::new(&header.transform_metadata)).map_err(|e| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        format!("failed to init metadata decompressor: {e}"),
                    )
                })?;
            let mut decompressed_meta = Vec::new();
            decoder.read_to_end(&mut decompressed_meta).map_err(|e| {
                io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("failed to decompress transform metadata: {e}"),
                )
            })?;
            decompressed_meta
        } else {
            header.transform_metadata.clone()
        };
        (preprocessed, tm)
    };
    // No metadata means no transforms were applied (or a raw fallback path was used).
    let data = if transform_metadata.is_empty() {
        preprocessed
    } else {
        let chain = TransformChain::deserialize(&transform_metadata)?;
        reverse_preprocess(&preprocessed, &chain)
    };
    let crc = crc32fast::hash(&data);
    if crc != header.crc32 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "CRC-32 mismatch: expected {:#010X}, got {:#010X}",
                header.crc32, crc
            ),
        ));
    }
    if data.len() as u64 != header.original_size {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "size mismatch: header says {} bytes, got {}",
                header.original_size,
                data.len()
            ),
        ));
    }
    Ok(data)
}
/// Compresses `data` in `mode` and returns the complete `.dcx` stream as a byte vector.
///
/// # Errors
/// Propagates all errors from `compress`.
pub fn compress_to_vec(
    data: &[u8],
    mode: Mode,
    format_override: Option<FormatHint>,
) -> io::Result<Vec<u8>> {
    let mut encoded = Vec::new();
    compress(data, mode, format_override, &mut encoded).map(|()| encoded)
}
/// In-memory variant of `compress_with_model`: returns the `.dcx` stream as a byte vector.
///
/// # Errors
/// Propagates all errors from `compress_with_model`.
pub fn compress_to_vec_with_model(
    data: &[u8],
    mode: Mode,
    format_override: Option<FormatHint>,
    model_path: Option<&str>,
) -> io::Result<Vec<u8>> {
    let mut encoded = Vec::new();
    compress_with_model(data, mode, format_override, model_path, &mut encoded).map(|()| encoded)
}
/// In-memory variant of `compress_with_options`: returns the `.dcx` stream as a byte vector.
///
/// # Errors
/// Propagates all errors from `compress_with_options`.
pub fn compress_to_vec_with_options(
    data: &[u8],
    mode: Mode,
    format_override: Option<FormatHint>,
    model_path: Option<&str>,
    zstd_level_override: Option<i32>,
) -> io::Result<Vec<u8>> {
    let mut encoded = Vec::new();
    compress_with_options(
        data,
        mode,
        format_override,
        model_path,
        zstd_level_override,
        &mut encoded,
    )
    .map(|()| encoded)
}
/// Decompresses an in-memory `.dcx` stream (header followed by payload).
///
/// # Errors
/// Propagates all errors from `decompress`.
pub fn decompress_from_slice(dcx_data: &[u8]) -> io::Result<Vec<u8>> {
    decompress(&mut Cursor::new(dcx_data))
}
/// Reads only the `DcxHeader` from the front of `input`, leaving the payload unread.
///
/// # Errors
/// Propagates any error from `DcxHeader::read_from`.
pub fn read_header<R: Read>(input: &mut R) -> io::Result<DcxHeader> {
    let header = DcxHeader::read_from(input)?;
    Ok(header)
}
/// Compresses `data` as a single plain zstd frame at `level`, with no DataCortex header.
///
/// # Errors
/// Any zstd failure is wrapped in an `io::Error` of kind `Other`.
pub fn raw_zstd_compress(data: &[u8], level: i32) -> io::Result<Vec<u8>> {
    match zstd::bulk::compress(data, level) {
        Ok(frame) => Ok(frame),
        Err(err) => Err(io::Error::other(err)),
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn fast_mode_roundtrip() {
let original = b"Hello, DataCortex! This is a test of Fast mode compression.";
let compressed = compress_to_vec(original, Mode::Fast, None).unwrap();
let decompressed = decompress_from_slice(&compressed).unwrap();
assert_eq!(decompressed, original);
}
#[test]
fn fast_mode_json_roundtrip() {
let data = br#"{"name":"Alice","age":30,"name":"Bob","age":25,"name":"Carol","age":35}"#;
let compressed = compress_to_vec(data, Mode::Fast, Some(FormatHint::Json)).unwrap();
let decompressed = decompress_from_slice(&compressed).unwrap();
assert_eq!(decompressed, data.to_vec());
}
#[test]
fn balanced_mode_roundtrip() {
let original = b"Balanced mode test data with some content.";
let compressed = compress_to_vec(original, Mode::Balanced, None).unwrap();
let decompressed = decompress_from_slice(&compressed).unwrap();
assert_eq!(decompressed, original);
}
#[test]
fn balanced_mode_longer_text() {
let original = b"The quick brown fox jumps over the lazy dog. This sentence contains every letter of the English alphabet at least once. We need enough data to properly exercise the arithmetic coder and order-0 model.";
let compressed = compress_to_vec(original, Mode::Balanced, None).unwrap();
let decompressed = decompress_from_slice(&compressed).unwrap();
assert_eq!(decompressed, original);
}
#[test]
fn balanced_mode_repetitive_data() {
let data = "hello world! ".repeat(100);
let compressed = compress_to_vec(data.as_bytes(), Mode::Balanced, None).unwrap();
let decompressed = decompress_from_slice(&compressed).unwrap();
assert_eq!(decompressed, data.as_bytes());
}
#[test]
fn balanced_mode_all_byte_values() {
let original: Vec<u8> = (0..=255).collect();
let compressed = compress_to_vec(&original, Mode::Balanced, None).unwrap();
let decompressed = decompress_from_slice(&compressed).unwrap();
assert_eq!(decompressed, original);
}
#[test]
fn balanced_mode_single_byte() {
let original = b"X";
let compressed = compress_to_vec(original, Mode::Balanced, None).unwrap();
let decompressed = decompress_from_slice(&compressed).unwrap();
assert_eq!(decompressed, original);
}
#[test]
fn balanced_mode_json_roundtrip() {
let data = br#"{"name":"Alice","age":30,"name":"Bob","age":25,"name":"Carol","age":35}"#;
let compressed = compress_to_vec(data, Mode::Balanced, Some(FormatHint::Json)).unwrap();
let decompressed = decompress_from_slice(&compressed).unwrap();
assert_eq!(decompressed, data.to_vec());
}
#[test]
fn empty_data_roundtrip() {
let original = b"";
for mode in [Mode::Fast, Mode::Balanced, Mode::Max] {
let compressed = compress_to_vec(original, mode, None).unwrap();
let decompressed = decompress_from_slice(&compressed).unwrap();
assert_eq!(decompressed, original, "failed for mode {mode}");
}
}
#[test]
fn crc_mismatch_detected() {
let original = b"test data for CRC check";
let mut compressed = compress_to_vec(original, Mode::Fast, None).unwrap();
let header_size = 32; if compressed.len() > header_size + 5 {
compressed[header_size + 3] ^= 0xFF;
}
assert!(decompress_from_slice(&compressed).is_err());
}
#[test]
fn fast_mode_actually_compresses() {
let data = "hello world. ".repeat(100);
let compressed = compress_to_vec(data.as_bytes(), Mode::Fast, None).unwrap();
assert!(
compressed.len() < data.len(),
"Fast mode should compress repetitive data: {} vs {}",
compressed.len(),
data.len()
);
}
#[test]
fn json_preprocessing_improves_fast_mode() {
let data = br#"[{"name":"Alice","score":95},{"name":"Bob","score":87},{"name":"Carol","score":92},{"name":"Dave","score":88},{"name":"Eve","score":91}]"#;
let with_preprocess = compress_to_vec(data, Mode::Fast, Some(FormatHint::Json)).unwrap();
let without_preprocess =
compress_to_vec(data, Mode::Fast, Some(FormatHint::Generic)).unwrap();
assert_eq!(
decompress_from_slice(&with_preprocess).unwrap(),
data.to_vec()
);
assert_eq!(
decompress_from_slice(&without_preprocess).unwrap(),
data.to_vec()
);
}
#[test]
fn all_modes_roundtrip() {
let data = b"test all modes with some more content to ensure decent compression";
for mode in [Mode::Max, Mode::Balanced, Mode::Fast] {
let compressed = compress_to_vec(data, mode, None).unwrap();
let decompressed = decompress_from_slice(&compressed).unwrap();
assert_eq!(decompressed, data, "failed for mode {mode}");
}
}
#[test]
fn cm_compress_decompress_direct() {
let data = b"Hello, World! This is a direct CM test.";
let compressed = cm_compress(data, CMConfig::balanced());
let decompressed = cm_decompress(&compressed, data.len(), CMConfig::balanced());
assert_eq!(decompressed, data.to_vec());
}
#[test]
fn cm_empty() {
let data: &[u8] = b"";
let compressed = cm_compress(data, CMConfig::balanced());
let decompressed = cm_decompress(&compressed, 0, CMConfig::balanced());
assert!(decompressed.is_empty());
}
/// Each of the 256 possible byte values must survive a one-byte CM round-trip.
#[test]
fn cm_single_byte() {
    (0..=255u8).for_each(|byte| {
        let single = [byte];
        let encoded = cm_compress(&single, CMConfig::balanced());
        let restored = cm_decompress(&encoded, 1, CMConfig::balanced());
        assert_eq!(
            restored, single,
            "CM roundtrip failed for byte {byte:#04X}"
        );
    });
}
/// 1000 copies of the same byte must CM-encode to under 200 bytes and round-trip.
#[test]
fn cm_repetitive_compresses() {
    let run = [b'A'; 1000].to_vec();
    let encoded = cm_compress(&run, CMConfig::balanced());
    let encoded_len = encoded.len();
    assert!(
        encoded_len < 200,
        "CM should compress 1000 identical bytes well: {} bytes",
        encoded_len
    );
    let restored = cm_decompress(&encoded, run.len(), CMConfig::balanced());
    assert_eq!(restored, run);
}
/// A short sentence must round-trip through Max mode.
#[test]
fn max_mode_roundtrip() {
    let original = b"Max mode test data with some content for compression.";
    let unpacked =
        decompress_from_slice(&compress_to_vec(original, Mode::Max, None).unwrap()).unwrap();
    assert_eq!(unpacked, original);
}
/// A paragraph-length input must round-trip through Max mode.
#[test]
fn max_mode_longer_text() {
    let original = b"The quick brown fox jumps over the lazy dog. Max mode uses 2x context maps for better predictions with fewer hash collisions. This should compress slightly better than balanced mode.";
    let round_tripped = {
        let packed = compress_to_vec(original, Mode::Max, None).unwrap();
        decompress_from_slice(&packed).unwrap()
    };
    assert_eq!(round_tripped, original);
}
/// NDJSON larger than `DICT_MIN_DATA_SIZE` must round-trip byte-exactly in Fast mode.
#[test]
fn test_dict_compress_roundtrip() {
    // 500 records, one per line, each terminated by '\n'.
    let ndjson: String = (0..500)
        .map(|i| {
            let mut line = format!(
                r#"{{"id":{},"name":"user_{}","status":"active","score":{}}}"#,
                i,
                i,
                i * 17 % 100
            );
            line.push('\n');
            line
        })
        .collect();
    let data = ndjson.as_bytes();
    let data_len = data.len();
    // Guard the fixture itself: it must be large enough to be dictionary-eligible.
    assert!(
        data_len > DICT_MIN_DATA_SIZE,
        "test data should exceed dict threshold: {} bytes",
        data_len
    );
    let compressed = compress_to_vec(data, Mode::Fast, Some(FormatHint::Ndjson)).unwrap();
    let decompressed = decompress_from_slice(&compressed).unwrap();
    assert_eq!(
        decompressed, data,
        "dict compress roundtrip: byte-exact mismatch"
    );
}
/// Inputs below `DICT_MIN_DATA_SIZE` round-trip and are written without the dictionary flag.
#[test]
fn test_dict_falls_back_on_small() {
    let data = b"small data that won't trigger dictionary training";
    assert!(data.len() < DICT_MIN_DATA_SIZE);
    let compressed = compress_to_vec(data, Mode::Fast, None).unwrap();
    assert_eq!(decompress_from_slice(&compressed).unwrap(), data.to_vec());
    let header = crate::dcx::DcxHeader::read_from(&mut Cursor::new(&compressed)).unwrap();
    assert!(!header.has_dict, "small data should not have dict flag set");
}
/// A small Fast-mode stream has no dictionary flag and still decodes to the original bytes.
#[test]
fn test_dict_backward_compat() {
    let original = b"backward compatibility test data for decompression";
    let compressed = compress_to_vec(original, Mode::Fast, None).unwrap();
    let header = {
        let mut reader = Cursor::new(&compressed);
        crate::dcx::DcxHeader::read_from(&mut reader).unwrap()
    };
    assert!(!header.has_dict);
    let decompressed = decompress_from_slice(&compressed).unwrap();
    assert_eq!(decompressed, original.to_vec());
}
/// 2000 log-style NDJSON records must round-trip in Fast mode with the NDJSON hint.
#[test]
fn test_dict_ndjson_large_roundtrip() {
    let ndjson: String = (0..2000)
        .map(|i| {
            format!(
                r#"{{"timestamp":"2025-01-{:02}T{:02}:{:02}:00Z","level":"info","message":"Request processed","request_id":"req_{}","duration_ms":{}}}"#,
                (i % 28) + 1,
                i % 24,
                i % 60,
                i,
                (i * 13) % 500
            ) + "\n"
        })
        .collect();
    let data = ndjson.as_bytes();
    let restored = decompress_from_slice(
        &compress_to_vec(data, Mode::Fast, Some(FormatHint::Ndjson)).unwrap(),
    )
    .unwrap();
    assert_eq!(restored, data, "large NDJSON roundtrip mismatch");
}
/// 3000 numbered text lines (above the dictionary threshold) must round-trip with the generic hint.
#[test]
fn test_dict_generic_data_roundtrip() {
    let data: Vec<u8> = (0..3000)
        .flat_map(|i| {
            format!("line {i}: the quick brown fox jumps over the lazy dog\n").into_bytes()
        })
        .collect();
    assert!(data.len() > DICT_MIN_DATA_SIZE);
    let packed = compress_to_vec(&data, Mode::Fast, Some(FormatHint::Generic)).unwrap();
    let unpacked = decompress_from_slice(&packed).unwrap();
    assert_eq!(unpacked, data, "generic data dict roundtrip mismatch");
}
/// Balanced and Max modes must never set the dictionary flag, and must still round-trip.
#[test]
fn test_dict_does_not_affect_other_modes() {
    let ndjson: String = (0..200)
        .map(|i| format!(r#"{{"id":{},"name":"user_{}","status":"active"}}"#, i, i) + "\n")
        .collect();
    let data = ndjson.as_bytes();
    for mode in [Mode::Balanced, Mode::Max] {
        let packed = compress_to_vec(data, mode, Some(FormatHint::Ndjson)).unwrap();
        // Check the header flag first, then the payload.
        let has_dict = crate::dcx::DcxHeader::read_from(&mut Cursor::new(&packed))
            .unwrap()
            .has_dict;
        assert!(!has_dict, "mode {mode} should never have dict flag");
        let unpacked = decompress_from_slice(&packed).unwrap();
        assert_eq!(unpacked, data, "roundtrip failed for mode {mode}");
    }
}
/// Passing an explicit level (`Some(19)`) as the last option must still round-trip.
#[test]
fn test_compress_with_level() {
    let text = "hello world, compressing with custom zstd level. ".repeat(50);
    let bytes = text.as_bytes();
    let packed = compress_to_vec_with_options(bytes, Mode::Fast, None, None, Some(19)).unwrap();
    assert_eq!(
        decompress_from_slice(&packed).unwrap(),
        bytes,
        "level 19 roundtrip failed"
    );
}
/// With no level given (`None` as the last option), compression must still round-trip.
#[test]
fn test_compress_with_level_default() {
    let text = "default level test data. ".repeat(50);
    let bytes = text.as_bytes();
    let packed = compress_to_vec_with_options(bytes, Mode::Fast, None, None, None).unwrap();
    let unpacked = decompress_from_slice(&packed).unwrap();
    assert_eq!(unpacked, bytes, "default level roundtrip failed");
}
/// Level 19 output must be no larger than level 1 output on repetitive JSON; both must round-trip.
#[test]
fn test_compress_with_level_higher_ratio() {
    let json = r#"{"name":"Alice","score":95}"#.repeat(200);
    let raw = json.as_bytes();
    let at_level =
        |level| compress_to_vec_with_options(raw, Mode::Fast, None, None, Some(level)).unwrap();
    let low = at_level(1);
    let high = at_level(19);
    assert_eq!(decompress_from_slice(&low).unwrap(), raw);
    assert_eq!(decompress_from_slice(&high).unwrap(), raw);
    let (high_len, low_len) = (high.len(), low.len());
    assert!(
        high_len <= low_len,
        "level 19 ({}) should be <= level 1 ({})",
        high_len,
        low_len
    );
}
/// citm_catalog.json must round-trip with the JSON hint and compress by more than 50x.
#[test]
fn test_auto_fallback_picks_smaller() {
    let corpus_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../corpus/json-bench/citm_catalog.json"
    );
    let data = std::fs::read(corpus_path).unwrap();
    let compressed = compress_to_vec(&data, Mode::Fast, Some(FormatHint::Json)).unwrap();
    assert_eq!(
        decompress_from_slice(&compressed).unwrap(),
        data,
        "citm_catalog roundtrip failed"
    );
    let (raw_len, packed_len) = (data.len() as f64, compressed.len() as f64);
    let ratio = raw_len / packed_len;
    assert!(
        ratio > 50.0,
        "citm_catalog should achieve >50x, got {ratio:.1}x"
    );
}
/// test-ndjson.ndjson must round-trip, and its header must show the preprocessed path.
///
/// "Preprocessed" means either non-empty transform metadata or the embedded-metadata flag.
#[test]
fn test_auto_fallback_preprocessed_wins_on_ndjson() {
    let corpus_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../corpus/test-ndjson.ndjson"
    );
    let data = std::fs::read(corpus_path).unwrap();
    let compressed = compress_to_vec(&data, Mode::Fast, Some(FormatHint::Ndjson)).unwrap();
    assert_eq!(
        decompress_from_slice(&compressed).unwrap(),
        data,
        "test-ndjson roundtrip failed"
    );
    let header = crate::dcx::DcxHeader::read_from(&mut Cursor::new(&compressed)).unwrap();
    let took_preprocessed_path =
        header.meta_embedded || !header.transform_metadata.is_empty();
    assert!(
        took_preprocessed_path,
        "test-ndjson should prefer preprocessed path (non-empty transform metadata or embedded)"
    );
}
/// Both corpus files must round-trip: citm_catalog with the JSON hint, test-ndjson with the NDJSON hint.
#[test]
fn test_auto_fallback_roundtrip() {
    const CITM_PATH: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../corpus/json-bench/citm_catalog.json"
    );
    const NDJSON_PATH: &str = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../corpus/test-ndjson.ndjson"
    );
    // Load both fixtures up front, then check each in turn.
    let citm = std::fs::read(CITM_PATH).unwrap();
    let ndjson = std::fs::read(NDJSON_PATH).unwrap();

    let citm_packed = compress_to_vec(&citm, Mode::Fast, Some(FormatHint::Json)).unwrap();
    let citm_restored = decompress_from_slice(&citm_packed).unwrap();
    assert_eq!(
        citm_restored, citm,
        "citm_catalog roundtrip (raw path) failed"
    );

    let ndjson_packed = compress_to_vec(&ndjson, Mode::Fast, Some(FormatHint::Ndjson)).unwrap();
    let ndjson_restored = decompress_from_slice(&ndjson_packed).unwrap();
    assert_eq!(
        ndjson_restored, ndjson,
        "test-ndjson roundtrip (preprocessed path) failed"
    );
}
/// Small inputs, including empty input, get level 19 when no override is given.
#[test]
fn test_adaptive_level_small_data() {
    for size in [100_000, 500_000, 1_048_576, 0] {
        assert_eq!(adaptive_fast_level(size, None), 19);
    }
}
/// Sizes just above 1 MiB up to exactly 16 MiB still get level 19.
#[test]
fn test_adaptive_level_medium_data() {
    for size in [1_048_577, 5_000_000, 10_485_760, 16_777_216] {
        assert_eq!(adaptive_fast_level(size, None), 19);
    }
}
/// Above 16 MiB the level drops to 16, and above 64 MiB it drops to 9.
#[test]
fn test_adaptive_level_large_data() {
    let cases = [
        (16_777_217, 16),
        (33_554_432, 16),
        (67_108_864, 16),
        (67_108_865, 9),
        (100_000_000, 9),
    ];
    for (size, level) in cases {
        assert_eq!(adaptive_fast_level(size, None), level);
    }
}
/// An explicit override is returned unchanged, whatever the data size.
#[test]
fn test_adaptive_level_override() {
    for (size, forced) in [(100, 3), (100_000_000, 22), (0, 1)] {
        assert_eq!(adaptive_fast_level(size, Some(forced)), forced);
    }
}
/// NDJSON compressed with the NDJSON hint must round-trip byte-exactly.
///
/// The resulting header must also parse and carry self-consistent metadata flags.
#[test]
fn test_compressed_metadata_roundtrip() {
    let mut ndjson = String::new();
    for i in 0..500 {
        ndjson.push_str(&format!(
            r#"{{"id":{},"name":"user_{}","status":"active","score":{}}}"#,
            i,
            i,
            i * 17 % 100
        ));
        ndjson.push('\n');
    }
    let data = ndjson.as_bytes();
    let compressed = compress_to_vec(data, Mode::Fast, Some(FormatHint::Ndjson)).unwrap();
    let decompressed = decompress_from_slice(&compressed).unwrap();
    assert_eq!(
        decompressed, data,
        "compressed metadata roundtrip: byte-exact mismatch"
    );

    let mut cursor = Cursor::new(&compressed);
    let header = crate::dcx::DcxHeader::read_from(&mut cursor).unwrap();
    // A set meta_compressed flag must come with actual metadata bytes.
    // test_compressed_metadata_backward_compat asserts the same invariant.
    assert!(
        !header.meta_compressed || !header.transform_metadata.is_empty(),
        "meta_compressed is set but transform_metadata is empty"
    );
    // When metadata is embedded in the payload, the header copy must be empty.
    // test_embedded_metadata_small_file_improvement checks the same thing.
    if header.meta_embedded {
        assert!(
            header.transform_metadata.is_empty(),
            "meta_embedded header should have empty transform_metadata"
        );
    }
}
/// A small generic input round-trips, and its header never sets meta_compressed with empty metadata.
#[test]
fn test_compressed_metadata_backward_compat() {
    let original = b"backward compatibility test data for metadata decompression";
    let compressed = compress_to_vec(original, Mode::Fast, None).unwrap();
    assert_eq!(
        decompress_from_slice(&compressed).unwrap(),
        original.to_vec()
    );
    let header = crate::dcx::DcxHeader::read_from(&mut Cursor::new(&compressed)).unwrap();
    assert!(!header.meta_compressed || !header.transform_metadata.is_empty());
}
/// Metadata of at most 64 bytes must be stored uncompressed.
#[test]
fn test_compressed_metadata_small_skipped() {
    let data = br#"{"name":"Alice","age":30}"#;
    let compressed = compress_to_vec(data, Mode::Fast, Some(FormatHint::Json)).unwrap();
    assert_eq!(decompress_from_slice(&compressed).unwrap(), data.to_vec());
    let header = crate::dcx::DcxHeader::read_from(&mut Cursor::new(&compressed)).unwrap();
    let meta_len = header.transform_metadata.len();
    if meta_len <= 64 {
        assert!(
            !header.meta_compressed,
            "metadata <= 64 bytes should not be compressed, but meta_compressed=true \
             for {} bytes of metadata",
            meta_len
        );
    }
}
/// twitter.json must round-trip, and its header must have the brotli flag set.
#[test]
fn test_twitter_json_brotli_wins() {
    let corpus_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../corpus/json-bench/twitter.json"
    );
    let data = std::fs::read(corpus_path).unwrap();
    let compressed = compress_to_vec(&data, Mode::Fast, Some(FormatHint::Json)).unwrap();
    let restored = decompress_from_slice(&compressed).unwrap();
    assert_eq!(restored, data, "twitter.json roundtrip failed");
    let used_brotli = crate::dcx::DcxHeader::read_from(&mut Cursor::new(&compressed))
        .unwrap()
        .use_brotli;
    assert!(
        used_brotli,
        "twitter.json should use brotli (FLAG_BROTLI set in header)"
    );
}
/// 200 NDJSON records with the NDJSON hint must round-trip in Fast, Balanced and Max modes.
#[test]
fn test_compressed_metadata_all_modes_roundtrip() {
    let ndjson: String = (0..200)
        .map(|i| format!(r#"{{"id":{},"name":"user_{}","status":"active"}}"#, i, i) + "\n")
        .collect();
    let data = ndjson.as_bytes();
    for mode in [Mode::Fast, Mode::Balanced, Mode::Max] {
        let restored = decompress_from_slice(
            &compress_to_vec(data, mode, Some(FormatHint::Ndjson)).unwrap(),
        )
        .unwrap();
        assert_eq!(
            restored, data,
            "compressed metadata roundtrip failed for mode {mode}"
        );
    }
}
/// The raw brotli helpers (quality 11, generic mode) must round-trip a short string.
#[test]
fn test_brotli_compress_roundtrip() {
    let data = b"Hello, brotli! This is a test of the brotli compression helpers.";
    let restored =
        brotli_decompress(&brotli_compress(data, 11, BROTLI_MODE_GENERIC).unwrap()).unwrap();
    assert_eq!(restored, data.to_vec());
}
/// With automatic codec selection, twitter.json must round-trip and end up brotli-encoded.
#[test]
fn test_brotli_auto_fallback_twitter() {
    let corpus_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../corpus/json-bench/twitter.json"
    );
    let data = std::fs::read(corpus_path).unwrap();
    let compressed = compress_to_vec(&data, Mode::Fast, Some(FormatHint::Json)).unwrap();
    assert_eq!(
        decompress_from_slice(&compressed).unwrap(),
        data,
        "twitter.json brotli roundtrip failed"
    );
    let mut reader = Cursor::new(&compressed);
    let header = crate::dcx::DcxHeader::read_from(&mut reader).unwrap();
    assert!(
        header.use_brotli,
        "twitter.json should use brotli in auto-fallback"
    );
}
/// test-ndjson.ndjson must round-trip in Fast mode with the NDJSON hint.
#[test]
fn test_brotli_ndjson_roundtrip() {
    let corpus_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../corpus/test-ndjson.ndjson"
    );
    let data = std::fs::read(corpus_path).unwrap();
    let restored = decompress_from_slice(
        &compress_to_vec(&data, Mode::Fast, Some(FormatHint::Ndjson)).unwrap(),
    )
    .unwrap();
    assert_eq!(restored, data, "ndjson roundtrip failed");
}
/// A hand-built zstd-only container (brotli flag clear) must still decode.
#[test]
fn test_brotli_backward_compat() {
    let original = b"backward compatibility test: this data was compressed without brotli";
    let checksum = crc32fast::hash(original);
    let payload = zstd::bulk::compress(original, 19).unwrap();
    // Every optional feature flag is off in this header.
    let header = crate::dcx::DcxHeader {
        mode: Mode::Fast,
        format_hint: crate::dcx::FormatHint::Generic,
        original_size: original.len() as u64,
        compressed_size: payload.len() as u64,
        crc32: checksum,
        transform_metadata: vec![],
        has_dict: false,
        meta_compressed: false,
        use_brotli: false,
        meta_embedded: false,
    };
    let mut buf = Vec::new();
    header.write_to(&mut buf).unwrap();
    buf.extend_from_slice(&payload);
    // These tests read the flag bits from byte offset 7 of the serialized header.
    assert_eq!(buf[7] & crate::dcx::FLAG_BROTLI, 0);
    assert_eq!(decompress_from_slice(&buf).unwrap(), original.to_vec());
}
/// test-api.json must round-trip byte-exactly in Fast mode with the JSON hint.
#[test]
fn test_embedded_metadata_roundtrip() {
    let corpus_path = concat!(env!("CARGO_MANIFEST_DIR"), "/../../corpus/test-api.json");
    let data = std::fs::read(corpus_path).unwrap();
    let packed = compress_to_vec(&data, Mode::Fast, Some(FormatHint::Json)).unwrap();
    let restored = decompress_from_slice(&packed).unwrap();
    assert_eq!(
        restored, data,
        "test-api.json embedded metadata roundtrip: byte-exact mismatch"
    );
}
/// A hand-built container without embedded metadata (flag clear) must still decode.
#[test]
fn test_embedded_metadata_backward_compat() {
    let original = b"backward compat: no embedded metadata in this old file format";
    let checksum = crc32fast::hash(original);
    let payload = zstd::bulk::compress(original, 19).unwrap();
    let header = crate::dcx::DcxHeader {
        mode: Mode::Fast,
        format_hint: crate::dcx::FormatHint::Generic,
        original_size: original.len() as u64,
        compressed_size: payload.len() as u64,
        crc32: checksum,
        transform_metadata: vec![],
        has_dict: false,
        meta_compressed: false,
        use_brotli: false,
        meta_embedded: false,
    };
    let mut buf = Vec::new();
    header.write_to(&mut buf).unwrap();
    buf.extend_from_slice(&payload);
    // The embedded-metadata bit in the flag byte (offset 7) must be clear.
    assert_eq!(buf[7] & crate::dcx::FLAG_META_EMBEDDED, 0);
    assert_eq!(decompress_from_slice(&buf).unwrap(), original.to_vec());
}
/// test-api.json must round-trip and compress by more than 5x.
///
/// If the header reports embedded metadata, it must also carry empty transform
/// metadata and the brotli flag.
#[test]
fn test_embedded_metadata_small_file_improvement() {
    let corpus_path = concat!(env!("CARGO_MANIFEST_DIR"), "/../../corpus/test-api.json");
    let data = std::fs::read(corpus_path).unwrap();
    let compressed = compress_to_vec(&data, Mode::Fast, Some(FormatHint::Json)).unwrap();
    assert_eq!(
        decompress_from_slice(&compressed).unwrap(),
        data,
        "roundtrip failed"
    );
    let ratio = data.len() as f64 / compressed.len() as f64;
    assert!(
        ratio > 5.0,
        "test-api.json should achieve >5x compression, got {ratio:.1}x"
    );
    let header = crate::dcx::DcxHeader::read_from(&mut Cursor::new(&compressed)).unwrap();
    if !header.meta_embedded {
        return;
    }
    assert!(
        header.transform_metadata.is_empty(),
        "meta_embedded header should have empty transform_metadata"
    );
    assert!(header.use_brotli, "meta_embedded should use brotli codec");
}
/// test-ndjson.ndjson must round-trip byte-exactly with the NDJSON hint.
#[test]
fn test_embedded_metadata_ndjson_roundtrip() {
    let corpus_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../corpus/test-ndjson.ndjson"
    );
    let data = std::fs::read(corpus_path).unwrap();
    let packed = compress_to_vec(&data, Mode::Fast, Some(FormatHint::Ndjson)).unwrap();
    let restored = decompress_from_slice(&packed).unwrap();
    assert_eq!(
        restored, data,
        "NDJSON embedded metadata roundtrip: byte-exact mismatch"
    );
}
/// A hand-built brotli container with embedded metadata must decode to the original.
///
/// The payload is `[u32 LE metadata length][metadata][original bytes]`, compressed with brotli.
#[test]
fn test_embedded_metadata_manual_roundtrip() {
    let original = b"Hello, embedded metadata world! This is a test.";
    let checksum = crc32fast::hash(original);
    let chain = TransformChain::new();
    let metadata = chain.serialize();
    let length_prefix = (metadata.len() as u32).to_le_bytes();
    let embedded = [&length_prefix[..], &metadata[..], &original[..]].concat();
    let payload = brotli_compress(&embedded, 11, BROTLI_MODE_GENERIC).unwrap();
    // Metadata lives inside the payload, so the header's own copy stays empty.
    let header = crate::dcx::DcxHeader {
        mode: Mode::Fast,
        format_hint: crate::dcx::FormatHint::Generic,
        original_size: original.len() as u64,
        compressed_size: payload.len() as u64,
        crc32: checksum,
        transform_metadata: vec![],
        has_dict: false,
        meta_compressed: false,
        use_brotli: true,
        meta_embedded: true,
    };
    let mut buf = Vec::new();
    header.write_to(&mut buf).unwrap();
    buf.extend_from_slice(&payload);
    // Both bits must be set in the flag byte at offset 7.
    assert_ne!(buf[7] & crate::dcx::FLAG_META_EMBEDDED, 0);
    assert_ne!(buf[7] & crate::dcx::FLAG_BROTLI, 0);
    assert_eq!(decompress_from_slice(&buf).unwrap(), original.to_vec());
}
/// The brotli helpers must round-trip a small JSON object in both TEXT and GENERIC modes.
#[test]
fn test_brotli_text_mode_on_raw() {
    let data = br#"{"name":"Alice","age":30,"city":"New York","active":true}"#;
    // Compress then decompress with the given brotli mode at quality 11.
    let roundtrip = |mode| {
        let packed = brotli_compress(data, 11, mode).unwrap();
        let unpacked = brotli_decompress(&packed).unwrap();
        (packed, unpacked)
    };

    let (text_packed, text_unpacked) = roundtrip(BROTLI_MODE_TEXT);
    assert_eq!(text_unpacked, data.to_vec(), "TEXT mode roundtrip failed");

    let (_, generic_unpacked) = roundtrip(BROTLI_MODE_GENERIC);
    assert_eq!(
        generic_unpacked,
        data.to_vec(),
        "GENERIC mode roundtrip failed"
    );

    assert!(
        !text_packed.is_empty(),
        "TEXT mode should produce non-empty output"
    );
}
/// A hand-built zstd container with embedded metadata (brotli flag clear) must decode.
///
/// The payload is `[u32 LE metadata length][metadata][original bytes]`, compressed with zstd.
#[test]
fn test_zstd_embedded_metadata_roundtrip() {
    let original = b"Hello, zstd embedded metadata! This is a test of the zstd path.";
    let checksum = crc32fast::hash(original);
    let chain = TransformChain::new();
    let metadata = chain.serialize();
    let length_prefix = (metadata.len() as u32).to_le_bytes();
    let embedded = [&length_prefix[..], &metadata[..], &original[..]].concat();
    let payload = zstd::bulk::compress(&embedded, 19).unwrap();
    let header = crate::dcx::DcxHeader {
        mode: Mode::Fast,
        format_hint: crate::dcx::FormatHint::Generic,
        original_size: original.len() as u64,
        compressed_size: payload.len() as u64,
        crc32: checksum,
        transform_metadata: vec![],
        has_dict: false,
        meta_compressed: false,
        use_brotli: false,
        meta_embedded: true,
    };
    let mut buf = Vec::new();
    header.write_to(&mut buf).unwrap();
    buf.extend_from_slice(&payload);
    // Embedded-metadata bit set, brotli bit clear (flag byte at offset 7).
    assert_ne!(buf[7] & crate::dcx::FLAG_META_EMBEDDED, 0);
    assert_eq!(buf[7] & crate::dcx::FLAG_BROTLI, 0);
    assert_eq!(decompress_from_slice(&buf).unwrap(), original.to_vec());
}
/// Brotli quality 10 and 11 must both round-trip.
///
/// The full Fast-mode pipeline must also round-trip two JSON corpus files.
#[test]
fn test_multi_quality_brotli() {
    let data = br#"{"items":[1,2,3,4,5],"nested":{"a":"hello","b":"world"}}"#;
    let q10 = brotli_compress(data, 10, BROTLI_MODE_GENERIC).unwrap();
    let q11 = brotli_compress(data, 11, BROTLI_MODE_GENERIC).unwrap();
    let (dec_q10, dec_q11) = (
        brotli_decompress(&q10).unwrap(),
        brotli_decompress(&q11).unwrap(),
    );
    let expected = data.to_vec();
    assert_eq!(dec_q10, expected, "quality 10 roundtrip failed");
    assert_eq!(dec_q11, expected, "quality 11 roundtrip failed");
    assert!(!q10.is_empty());
    assert!(!q11.is_empty());

    let corpus_files = [
        concat!(env!("CARGO_MANIFEST_DIR"), "/../../corpus/test-api.json"),
        concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../corpus/json-bench/twitter.json"
        ),
    ];
    for &path in &corpus_files {
        let file_data = std::fs::read(path).unwrap();
        let restored = decompress_from_slice(
            &compress_to_vec(&file_data, Mode::Fast, Some(FormatHint::Json)).unwrap(),
        )
        .unwrap();
        assert_eq!(
            restored, file_data,
            "multi-quality roundtrip failed for {path}"
        );
    }
}
/// 500 NDJSON rows, each with a one-element array, must round-trip in Fast mode.
#[test]
fn test_singleton_arrays_fast_roundtrip() {
    // Every row is newline-terminated, including the last one.
    let data: String = (0..500)
        .map(|i| format!("{{\"items\":[{{\"x\":{}}}],\"id\":{}}}\n", i, i))
        .collect();
    let bytes = data.as_bytes();
    let packed = compress_to_vec(bytes, Mode::Fast, Some(FormatHint::Ndjson)).unwrap();
    assert_eq!(
        decompress_from_slice(&packed).unwrap(),
        bytes,
        "singleton_arrays fast mode roundtrip failed"
    );
}
/// 50 NDJSON rows, each holding a 100,000-character string, must round-trip in Fast mode.
#[test]
fn test_very_long_lines_fast_roundtrip() {
    let filler = "X".repeat(100_000);
    let data: String = (0..50)
        .map(|i| format!("{{\"data\":\"{}\",\"id\":{}}}\n", filler, i))
        .collect();
    let bytes = data.as_bytes();
    let packed = compress_to_vec(bytes, Mode::Fast, Some(FormatHint::Ndjson)).unwrap();
    assert_eq!(
        decompress_from_slice(&packed).unwrap(),
        bytes,
        "very_long_lines fast mode roundtrip failed"
    );
}
/// 10 NDJSON rows, each holding a 100,000-character string, must round-trip in Balanced mode.
#[test]
fn test_very_long_lines_balanced_roundtrip() {
    let filler = "X".repeat(100_000);
    let data: String = (0..10)
        .map(|i| format!("{{\"data\":\"{}\",\"id\":{}}}\n", filler, i))
        .collect();
    let bytes = data.as_bytes();
    let packed = compress_to_vec(bytes, Mode::Balanced, Some(FormatHint::Ndjson)).unwrap();
    assert_eq!(
        decompress_from_slice(&packed).unwrap(),
        bytes,
        "very_long_lines balanced mode roundtrip failed"
    );
}
/// 10,000 identical one-field NDJSON rows must round-trip in Fast mode.
#[test]
fn test_all_same_value_fast_roundtrip() {
    // Each of the 10,000 rows is `{"x":1}` followed by a newline.
    let data = "{\"x\":1}\n".repeat(10_000);
    let bytes = data.as_bytes();
    let packed = compress_to_vec(bytes, Mode::Fast, Some(FormatHint::Ndjson)).unwrap();
    assert_eq!(
        decompress_from_slice(&packed).unwrap(),
        bytes,
        "all_same_value fast mode roundtrip failed"
    );
}
/// Input that is almost entirely 0x00 bytes must still yield usable training samples.
///
/// The input is one 0x02 byte followed by 9,999 zero bytes. Splitting on 0x00
/// leaves only a single 1-byte non-empty piece, so the zero-delimited sampling
/// is rejected and `generate_training_samples` falls back to fixed-size chunks.
#[test]
fn test_generate_training_samples_degenerate() {
    let mut data = vec![0x02u8];
    data.extend_from_slice(&[0x00; 9999]);
    let samples = generate_training_samples(&data, 1024);
    // Guard the division below: an empty result should fail with a clear
    // message rather than a divide-by-zero panic.
    assert!(
        !samples.is_empty(),
        "training samples should not be empty for {} bytes of input",
        data.len()
    );
    let avg_len = samples.iter().map(|s| s.len()).sum::<usize>() / samples.len();
    assert!(
        avg_len >= 8,
        "training samples average size should be >= 8, got {avg_len}"
    );
}
/// 30 NDJSON rows, each with a null field, must round-trip through the streaming API in Fast mode.
#[test]
fn null_heavy_codec_roundtrip_fast() {
    let data: Vec<u8> = (0..30)
        .flat_map(|i| format!("{{\"id\": {}, \"val\": null}}\n", i).into_bytes())
        .collect();
    let mut compressed = Vec::new();
    compress(&data, Mode::Fast, None, &mut compressed).unwrap();
    let mut reader = std::io::Cursor::new(&compressed);
    let decompressed = decompress(&mut reader).unwrap();
    assert_eq!(
        decompressed, data,
        "null-heavy 30-row fast mode roundtrip failed"
    );
}
/// 30 NDJSON rows, each with a null field, must round-trip through the streaming API in Balanced mode.
#[test]
fn null_heavy_codec_roundtrip_balanced() {
    let data: Vec<u8> = (0..30)
        .flat_map(|i| format!("{{\"id\": {}, \"val\": null}}\n", i).into_bytes())
        .collect();
    let mut compressed = Vec::new();
    compress(&data, Mode::Balanced, None, &mut compressed).unwrap();
    let mut reader = std::io::Cursor::new(&compressed);
    let decompressed = decompress(&mut reader).unwrap();
    assert_eq!(
        decompressed, data,
        "null-heavy 30-row balanced mode roundtrip failed"
    );
}
/// The GH Archive NDJSON corpus must round-trip through the streaming API in Fast mode.
///
/// The corpus file is optional: if it cannot be read, the test returns early and passes.
#[test]
fn gharchive_selective_roundtrip() {
    let path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../corpus/json-bench/gharchive-10mb.ndjson"
    );
    let data = if let Ok(bytes) = std::fs::read(path) {
        bytes
    } else {
        return;
    };
    let mut compressed = Vec::new();
    compress(
        &data,
        Mode::Fast,
        Some(crate::dcx::FormatHint::Ndjson),
        &mut compressed,
    )
    .unwrap();
    let mut reader = std::io::Cursor::new(&compressed);
    let decompressed = decompress(&mut reader).unwrap();
    assert_eq!(
        decompressed, data,
        "GH Archive selective columnar roundtrip failed"
    );
}
}