use super::schema::{Group, Schema};
use crate::results::analysis_results::compute_analysis_results;
use crate::results::analysis_results::AnalysisResults;
use crate::results::ComputeAnalysisResultsError;
use crate::schema::{BitOrder, Condition, FieldDefinition};
use crate::utils::analyze_utils::{
create_bit_reader, create_bit_writer, reverse_bits, size_estimate, BitReaderContainer,
BitWriterContainer,
};
use crate::utils::constants::CHILD_MARKER;
use ahash::{AHashMap, HashMapExt};
use bitstream_io::{BitRead, BitReader, BitWrite, Endianness};
use rustc_hash::FxHashMap;
use std::io::{Cursor, SeekFrom};
use thiserror::Error;
/// Collects per-field statistics for raw entries laid out according to a [`Schema`].
///
/// Feed entries with `add_entry`, then call `generate_results`.
pub struct SchemaAnalyzer<'a> {
    /// Schema describing the bit layout of every entry.
    pub schema: &'a Schema,
    /// Raw bytes appended by `add_entry`, concatenated in call order.
    pub entries: Vec<u8>,
    /// One state per schema field or group, keyed by the field's short name
    /// (not its full path). Fields that share a name in different groups therefore
    /// map to a single entry.
    pub field_states: AHashMap<String, AnalyzerFieldState>,
    /// Compression settings; presumably consumed by `compute_analysis_results`.
    pub compression_options: CompressionOptions,
}
/// Inputs handed to a [`SizeEstimatorFn`].
///
/// NOTE(review): the field meanings below are inferred from their names and from
/// `CompressionOptions`; confirm them against the `size_estimate` implementation.
#[derive(Debug, Clone, Copy)]
pub struct SizeEstimationParameters<'a> {
    /// Label of the data being estimated (presumably a field name or path).
    pub name: &'a str,
    /// Length of the data being estimated (presumably in bytes).
    pub data_len: usize,
    /// The raw data, when it is available to the estimator.
    pub data: Option<&'a [u8]>,
    /// Number of LZ matches found in the data (assumed from the name).
    pub num_lz_matches: usize,
    /// Entropy of the data (assumed from the name; units not visible here).
    pub entropy: f64,
    /// Weight for `num_lz_matches`; presumably copied from
    /// `CompressionOptions::lz_match_multiplier`.
    pub lz_match_multiplier: f64,
    /// Weight for `entropy`; presumably copied from
    /// `CompressionOptions::entropy_multiplier`.
    pub entropy_multiplier: f64,
}
/// Signature of a pluggable size estimator: maps [`SizeEstimationParameters`] to an estimated
/// size (presumably the compressed size in bytes; confirm against `size_estimate`).
pub type SizeEstimatorFn = fn(SizeEstimationParameters) -> usize;
/// Settings used when computing analysis results.
///
/// The `Default` implementation uses zstd level 16, the `size_estimate` estimator and `0.0`
/// for both multipliers.
#[derive(Debug, Clone, Copy)]
pub struct CompressionOptions {
    /// zstd compression level.
    pub zstd_compression_level: i32,
    /// Function used to estimate sizes.
    pub size_estimator_fn: SizeEstimatorFn,
    /// Multiplier for LZ matches (presumably forwarded to the size estimator).
    pub lz_match_multiplier: f64,
    /// Multiplier for entropy (presumably forwarded to the size estimator).
    pub entropy_multiplier: f64,
}
impl Default for CompressionOptions {
fn default() -> Self {
Self {
zstd_compression_level: 16,
size_estimator_fn: size_estimate,
lz_match_multiplier: 0.0,
entropy_multiplier: 0.0,
}
}
}
/// Builder-style setters. Each one consumes `self` and returns the updated options, so calls
/// can be chained: `CompressionOptions::default().with_zstd_compression_level(7)`.
impl CompressionOptions {
    /// Sets the zstd compression level.
    #[must_use]
    pub fn with_zstd_compression_level(mut self, level: i32) -> Self {
        self.zstd_compression_level = level;
        self
    }

    /// Replaces the function used to estimate sizes.
    #[must_use]
    pub fn with_size_estimator_fn(mut self, estimator_fn: SizeEstimatorFn) -> Self {
        self.size_estimator_fn = estimator_fn;
        self
    }

    /// Sets `lz_match_multiplier`.
    #[must_use]
    pub fn with_lz_match_multiplier(mut self, multiplier: f64) -> Self {
        self.lz_match_multiplier = multiplier;
        self
    }

    /// Sets `entropy_multiplier`.
    #[must_use]
    pub fn with_entropy_multiplier(mut self, multiplier: f64) -> Self {
        self.entropy_multiplier = multiplier;
        self
    }
}
/// Statistics accumulated for a single schema field or group.
pub struct AnalyzerFieldState {
    /// Short name of the field; also its key in `SchemaAnalyzer::field_states`.
    pub name: String,
    /// Ancestor names plus this name, joined with `CHILD_MARKER`.
    pub full_path: String,
    /// Nesting depth; fields directly under the root group have depth 0.
    pub depth: usize,
    /// Number of occurrences processed so far.
    pub count: u64,
    /// Declared width of the field/group in bits.
    pub lenbits: u32,
    /// Receives the raw bits of every processed occurrence; created in the file's bit order.
    pub writer: BitWriterContainer,
    /// Zero/one counts per bit position; empty for fields wider than 64 bits.
    pub bit_counts: Vec<BitStats>,
    /// Resolved bit order of this field; `Lsb` values are bit-reversed before counting.
    pub bit_order: BitOrder,
    /// Occurrences per value. Only populated for fields of at most 16 bits that do not set
    /// `skip_frequency_analysis`.
    pub value_counts: FxHashMap<u64, u64>,
}
/// How often a single bit position was observed as 0 and as 1.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct BitStats {
    /// Number of times the bit was 0.
    pub zeros: u64,
    /// Number of times the bit was 1.
    pub ones: u64,
}
/// Errors raised while feeding entries into a [`SchemaAnalyzer`].
#[derive(Debug, Error)]
pub enum AnalysisError {
    /// An I/O error from the bit reader or bit writer while processing an entry.
    #[error("I/O error in add_entry reader during analysis. This is indicative of a bug in schema parsing or sanitization; and should normally not happen. Details: {0}")]
    Io(#[from] std::io::Error),
    /// A schema field had no corresponding entry in `SchemaAnalyzer::field_states`.
    #[error(
        "Field '{0}' not found in Analyzer. This is indicative of a bug and should not happen."
    )]
    FieldNotFound(String),
    /// An entry was shorter than the schema's root group. Both values are bit counts.
    #[error("Invalid entry length: expected {expected}, got {found}")]
    InvalidEntryLength { expected: usize, found: usize },
}
impl<'a> SchemaAnalyzer<'a> {
pub fn new(schema: &'a Schema, options: CompressionOptions) -> Self {
Self {
schema,
entries: Vec::new(),
field_states: build_field_stats(&schema.root, "", 0, schema.bit_order),
compression_options: options,
}
}
pub fn add_entry(&mut self, entry: &[u8]) -> Result<(), AnalysisError> {
self.entries.extend_from_slice(entry);
if entry.len() * 8 < self.schema.root.bits as usize {
return Err(AnalysisError::InvalidEntryLength {
expected: self.schema.root.bits as usize,
found: self.entries.len() * 8,
});
}
let reader = create_bit_reader(entry, self.schema.bit_order);
match reader {
BitReaderContainer::Msb(mut bit_reader) => {
self.process_group(&self.schema.root, &mut bit_reader)
}
BitReaderContainer::Lsb(mut bit_reader) => {
self.process_group(&self.schema.root, &mut bit_reader)
}
}
}
fn process_group<TEndian: Endianness>(
&mut self,
group: &Group,
reader: &mut BitReader<Cursor<&[u8]>, TEndian>,
) -> Result<(), AnalysisError> {
if should_skip(reader, &group.skip_if_not)? {
return Ok(());
}
for (name, field_def) in &group.fields {
match field_def {
FieldDefinition::Field(field) => {
if should_skip(reader, &field.skip_if_not)? {
continue;
}
let bits_left = field.bits;
let field_stats = self
.field_states
.get_mut(name)
.ok_or_else(|| AnalysisError::FieldNotFound(name.clone()))?;
process_field_or_group(
reader,
bits_left,
field_stats,
field.skip_frequency_analysis,
)?;
}
FieldDefinition::Group(child_group) => {
let bits_left = child_group.bits;
let field_stats = self
.field_states
.get_mut(name)
.ok_or_else(|| AnalysisError::FieldNotFound(name.clone()))?;
let current_offset = reader.position_in_bits()?;
process_field_or_group(
reader,
bits_left,
field_stats,
child_group.skip_frequency_analysis,
)?;
reader.seek_bits(SeekFrom::Start(current_offset))?;
self.process_group(child_group, reader)?;
}
}
}
Ok(())
}
pub fn generate_results(&mut self) -> Result<AnalysisResults, ComputeAnalysisResultsError> {
compute_analysis_results(self)
}
}
/// Feeds the next `bit_count` bits from `reader` into `field_stats`.
///
/// * Bits are consumed in chunks of at most 64 and appended to the state's bit writer, which
///   is flushed once at the end.
/// * Value frequencies are recorded only when `bit_count <= 16` and `skip_frequency_analysis`
///   is false. For LSB-ordered fields the counted value is the bit-reversed chunk.
/// * Per-bit zero/one counts are recorded only when `bit_count <= 64`.
/// * `field_stats.count` is incremented once per call.
///
/// # Errors
/// Returns [`AnalysisError::Io`] if reading, writing or flushing fails.
fn process_field_or_group<TEndian: Endianness>(
    reader: &mut BitReader<Cursor<&[u8]>, TEndian>,
    mut bit_count: u32,
    field_stats: &mut AnalyzerFieldState,
    skip_frequency_analysis: bool,
) -> Result<(), AnalysisError> {
    let record_values = !skip_frequency_analysis && bit_count <= 16;
    let record_bit_stats = bit_count <= 64;
    field_stats.count += 1;

    while bit_count != 0 {
        let chunk_bits = bit_count.min(64);
        let chunk = reader.read::<u64>(chunk_bits)?;

        if record_values {
            let key = if field_stats.bit_order == BitOrder::Lsb {
                reverse_bits(chunk_bits, chunk)
            } else {
                chunk
            };
            *field_stats.value_counts.entry(key).or_default() += 1;
        }

        match &mut field_stats.writer {
            BitWriterContainer::Msb(out) => out.write(chunk_bits, chunk)?,
            BitWriterContainer::Lsb(out) => out.write(chunk_bits, chunk)?,
        }

        if record_bit_stats {
            for pos in 0..chunk_bits {
                // Slot 0 corresponds to the chunk value's most significant bit.
                let bit_is_set = (chunk >> (chunk_bits - 1 - pos)) & 1 != 0;
                let slot = &mut field_stats.bit_counts[pos as usize];
                if bit_is_set {
                    slot.ones += 1;
                } else {
                    slot.zeros += 1;
                }
            }
        }

        bit_count -= chunk_bits;
    }

    match &mut field_stats.writer {
        BitWriterContainer::Msb(out) => out.flush()?,
        BitWriterContainer::Lsb(out) => out.flush()?,
    }
    Ok(())
}
/// Recursively creates an empty [`AnalyzerFieldState`] for every field and group under `group`.
///
/// Map keys are the short field names. Each state's `full_path` joins the ancestor names and
/// its own name with `CHILD_MARKER`. Nested groups are recursed into with `depth + 1`, and
/// their entries are merged after the group's own entry, so a descendant with the same name
/// replaces it. Every state gets a fresh bit writer in `file_bit_order`.
fn build_field_stats<'a>(
    group: &'a Group,
    parent_path: &'a str,
    depth: usize,
    file_bit_order: BitOrder,
) -> AHashMap<String, AnalyzerFieldState> {
    let mut states = AHashMap::new();
    for (name, definition) in &group.fields {
        let full_path = match parent_path {
            "" => name.clone(),
            parent => format!("{parent}{CHILD_MARKER}{name}"),
        };

        let (lenbits, bit_order, child_group) = match definition {
            FieldDefinition::Field(field) => {
                (field.bits, field.bit_order.get_with_default_resolve(), None)
            }
            FieldDefinition::Group(child) => (
                child.bits,
                child.bit_order.get_with_default_resolve(),
                Some(child),
            ),
        };

        // Build descendants first so `full_path` can be moved into this entry below.
        // They are still merged after this entry is inserted, preserving override order.
        let descendants = child_group
            .map(|child| build_field_stats(child, &full_path, depth + 1, file_bit_order));

        states.insert(
            name.clone(),
            AnalyzerFieldState {
                name: name.clone(),
                full_path,
                depth,
                count: 0,
                lenbits,
                writer: create_bit_writer(file_bit_order),
                bit_counts: vec![BitStats::default(); clamp_bits(lenbits as usize)],
                bit_order,
                value_counts: FxHashMap::new(),
            },
        );

        if let Some(descendants) = descendants {
            states.extend(descendants);
        }
    }
    states
}
/// Returns `Ok(true)` when any condition in `conditions` does not hold at the reader's current
/// position, meaning the guarded group or field should be skipped.
///
/// For each condition, `bits` bits are read starting `byte_offset * 8 + bit_offset` bits past
/// the current position. If the condition is LSB-ordered they are bit-reversed first, then
/// they are compared with `value`. Checking stops at the first mismatch.
///
/// An empty list returns `Ok(false)` without touching the reader. Otherwise, on success, the
/// reader is restored to its starting position.
#[inline]
fn should_skip<TEndian: Endianness>(
    reader: &mut BitReader<Cursor<&[u8]>, TEndian>,
    conditions: &[Condition],
) -> Result<bool, AnalysisError> {
    if conditions.is_empty() {
        return Ok(false);
    }

    let start_bits = reader.position_in_bits()?;
    let mut mismatch = false;
    for cond in conditions {
        let relative_bits = cond.byte_offset * 8 + cond.bit_offset as u64;
        reader.seek_bits(SeekFrom::Start(start_bits.wrapping_add(relative_bits)))?;
        let raw = reader.read::<u64>(cond.bits as u32)?;
        let observed = if cond.bit_order == BitOrder::Lsb {
            reverse_bits(cond.bits as u32, raw)
        } else {
            raw
        };
        if observed != cond.value {
            mismatch = true;
            break;
        }
    }

    reader.seek_bits(SeekFrom::Start(start_bits))?;
    Ok(mismatch)
}
/// Number of per-bit stat slots to allocate for a field `bits` wide.
///
/// Returns `bits` when the field fits in a single `u64` read (at most 64 bits). Otherwise it
/// returns `0`, which disables per-bit statistics for that field.
fn clamp_bits(bits: usize) -> usize {
    match bits {
        0..=64 => bits,
        _ => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::Schema;

    /// Schema with a top-level 32-bit `id` field and an LSB `nested` group that holds a
    /// single 8-bit `value` field.
    fn create_test_schema() -> Schema {
        let yaml = r###"
version: '1.0'
root:
  type: group
  fields:
    id:
      type: field
      bits: 32
      description: "ID field"
    nested:
      type: group
      bit_order: lsb
      fields:
        value:
          type: field
          bits: 8
          description: "Nested value"
"###;
        Schema::from_yaml(yaml).expect("Failed to parse test schema")
    }

    #[test]
    fn test_analyzer_initialization() {
        let schema = create_test_schema();
        let options = CompressionOptions::default();
        let analyzer = SchemaAnalyzer::new(&schema, options);
        // The root group gets no state of its own: only `id`, `nested` and `nested.value` do.
        assert_eq!(
            analyzer.field_states.len(),
            3,
            "Should have stats for 2 fields + 1 nested group"
        );
    }

    /// MSB field: values are counted as read, and the raw bits are packed in entry order.
    #[test]
    fn test_big_endian_bitorder() -> Result<(), AnalysisError> {
        let yaml = r###"
version: '1.0'
root:
  type: group
  fields:
    flags:
      type: field
      bits: 2
      bit_order: msb
"###;
        let schema = Schema::from_yaml(yaml).expect("Failed to parse test schema");
        let options = CompressionOptions::default();
        let mut analyzer = SchemaAnalyzer::new(&schema, options);
        for entry in [0b11000000_u8, 0b00000000, 0b10000000, 0b01000000] {
            analyzer.add_entry(&[entry])?;
        }

        {
            let flags_field = analyzer
                .field_states
                .get_mut("flags")
                .ok_or_else(|| AnalysisError::FieldNotFound("flags".to_string()))?;
            assert_eq!(flags_field.count, 4, "Should process 4 entries");
            assert_eq!(
                flags_field.bit_counts.len(),
                2,
                "Should track 2 bits per field"
            );

            let writer = match &mut flags_field.writer {
                BitWriterContainer::Msb(value) => value,
                _ => panic!("Expected MSB variant"),
            };
            // Pad the trailing partial byte so the underlying buffer can be inspected.
            writer.byte_align()?;
            writer.flush()?;
            let inner_writer = writer.writer().unwrap();
            let data = inner_writer.get_ref();
            // 11 | 00 | 10 | 01 => 0b1100_1001
            assert_eq!(data[0], 0xC9_u8, "Combined bits should form 0xC9");

            let expected_counts =
                FxHashMap::from_iter([(0b11, 1), (0b00, 1), (0b10, 1), (0b01, 1)]);
            assert_eq!(
                flags_field.value_counts, expected_counts,
                "Value counts should match"
            );
            // Across the four entries each bit position saw two 0s and two 1s.
            for (x, stats) in flags_field.bit_counts.iter().enumerate() {
                assert_eq!(
                    stats.zeros, 2,
                    "Bit {} should have 2 zeros (actual: {})",
                    x, stats.zeros
                );
                assert_eq!(
                    stats.ones, 2,
                    "Bit {} should have 2 ones (actual: {})",
                    x, stats.ones
                );
            }
        }

        // A fifth entry only bumps the count of its own value.
        analyzer.add_entry(&[0b01000000])?;
        let flags_field = analyzer
            .field_states
            .get("flags")
            .ok_or_else(|| AnalysisError::FieldNotFound("flags".to_string()))?;
        let expected_counts = FxHashMap::from_iter([(0b11, 1), (0b00, 1), (0b10, 1), (0b01, 2)]);
        assert_eq!(
            flags_field.value_counts, expected_counts,
            "Value counts should match"
        );
        Ok(())
    }

    /// LSB field: values are bit-reversed before being counted, so `10` and `01` swap.
    #[test]
    fn test_little_endian_bitorder() {
        let yaml = r###"
version: '1.0'
root:
  type: group
  fields:
    flags:
      type: field
      bits: 2
      bit_order: lsb
"###;
        let schema = Schema::from_yaml(yaml).expect("Failed to parse test schema");
        let options = CompressionOptions::default();
        let mut analyzer = SchemaAnalyzer::new(&schema, options);
        for entry in [0b11000000_u8, 0b00000000, 0b10000000, 0b01000000, 0b10000000] {
            analyzer.add_entry(&[entry]).unwrap();
        }
        let flags_field = analyzer.field_states.get("flags").unwrap();
        let expected_counts = FxHashMap::from_iter([(0b11, 1), (0b00, 1), (0b10, 1), (0b01, 2)]);
        assert_eq!(
            flags_field.value_counts, expected_counts,
            "Value counts should match"
        );
    }

    #[test]
    fn test_field_stats_structure() {
        let schema = create_test_schema();
        let options = CompressionOptions::default();
        let analyzer = SchemaAnalyzer::new(&schema, options);

        let id_field = analyzer.field_states.get("id").unwrap();
        assert_eq!(id_field.name, "id");
        assert_eq!(id_field.full_path, "id");
        assert_eq!(id_field.depth, 0);
        assert_eq!(id_field.count, 0);
        assert_eq!(id_field.lenbits, 32);
        assert_eq!(id_field.bit_counts.len(), id_field.lenbits as usize);
        assert_eq!(id_field.bit_order, BitOrder::Msb);

        let nested_group = analyzer.field_states.get("nested").unwrap();
        assert_eq!(nested_group.full_path, "nested");
        assert_eq!(nested_group.name, "nested");
        assert_eq!(nested_group.depth, 0);
        assert_eq!(nested_group.count, 0);
        assert_eq!(nested_group.lenbits, 8);
        assert_eq!(nested_group.bit_counts.len(), nested_group.lenbits as usize);
        assert_eq!(nested_group.bit_order, BitOrder::Lsb);

        // `value` reports LSB, matching its enclosing `nested` group.
        let nested_value = analyzer.field_states.get("value").unwrap();
        assert_eq!(nested_value.full_path, "nested.value");
        assert_eq!(nested_value.name, "value");
        assert_eq!(nested_value.depth, 1);
        assert_eq!(nested_value.count, 0);
        assert_eq!(nested_value.lenbits, 8);
        assert_eq!(nested_value.bit_counts.len(), nested_value.lenbits as usize);
        assert_eq!(nested_value.bit_order, BitOrder::Lsb);
    }

    /// The root group is only analyzed when the entry's first byte equals 0x55.
    #[test]
    fn skips_group_based_on_conditions() {
        let yaml = r#"
version: '1.0'
root:
  type: group
  skip_if_not:
    - byte_offset: 0
      bit_offset: 0
      bits: 8
      value: 0x55
  fields:
    dummy: 8
"#;
        let schema = Schema::from_yaml(yaml).unwrap();
        let options = CompressionOptions::default();
        let mut analyzer = SchemaAnalyzer::new(&schema, options);

        analyzer.add_entry(&[0x55]).unwrap();
        assert_eq!(analyzer.field_states.get("dummy").unwrap().count, 1);

        analyzer.add_entry(&[0xAA]).unwrap();
        assert_eq!(analyzer.field_states.get("dummy").unwrap().count, 1);

        analyzer.add_entry(&[0x55]).unwrap();
        assert_eq!(analyzer.field_states.get("dummy").unwrap().count, 2);
    }

    /// `header` is only counted when the entry's first bit is 1.
    #[test]
    fn skips_field_based_on_conditions() {
        let yaml = r#"
version: '1.0'
root:
  type: group
  fields:
    header:
      type: field
      bits: 7
      skip_if_not:
        - byte_offset: 0
          bit_offset: 0
          bits: 1
          value: 1
"#;
        let schema = Schema::from_yaml(yaml).unwrap();
        let options = CompressionOptions::default();
        let mut analyzer = SchemaAnalyzer::new(&schema, options);

        analyzer.add_entry(&[0b10000000]).unwrap();
        assert_eq!(analyzer.field_states.get("header").unwrap().count, 1);

        analyzer.add_entry(&[0b00000000]).unwrap();
        assert_eq!(analyzer.field_states.get("header").unwrap().count, 1);
    }

    #[test]
    fn test_builder() {
        let options = CompressionOptions::default().with_zstd_compression_level(7);
        assert_eq!(options.zstd_compression_level, 7);

        // The builder returns a new value; a fresh default is unaffected.
        let options = CompressionOptions::default();
        assert_eq!(options.zstd_compression_level, 16);
    }
}