use crate::{ProofNode, ProofRule, ResolutionProof};
use std::collections::HashMap;
/// Proof compressor that looks for recurring structural patterns in a
/// resolution proof and records them in a dictionary.
pub struct StructuralCompressor {
    /// Every pattern seen so far, mapped to the identifier assigned to it.
    pattern_dict: HashMap<Pattern, PatternId>,
    /// Identifier that will be handed to the next previously unseen pattern.
    next_pattern_id: PatternId,
    /// Figures recorded by the most recent call to `compress`.
    stats: CompressionStats,
}
/// Structural shape that can be stored in the compression dictionary.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum Pattern {
    /// A single resolution step described by three sizes.
    ///
    /// NOTE(review): presumably the clause sizes of the two premises and of
    /// the resolvent; nothing in this file constructs this variant, so confirm
    /// against the producer.
    SingleResolution {
        left_size: usize,
        right_size: usize,
        result_size: usize,
    },
    /// A run of `length` resolution nodes linked through their `left` premise.
    ResolutionChain { length: usize },
    /// A tree of resolution nodes whose longest root-to-leaf path has `depth`
    /// resolution steps.
    BinaryTree { depth: usize },
    /// Opaque, caller-defined pattern payload.
    Custom(Vec<u8>),
}
/// Identifier assigned to a [`Pattern`] when it is first registered.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct PatternId(
    /// Raw numeric value of the identifier.
    pub usize,
);
/// Figures describing a compression run.
#[derive(Clone, Debug, Default)]
pub struct CompressionStats {
    /// Estimated size of the proof before compression, in bytes.
    pub original_bytes: usize,
    /// Length of the encoded byte stream.
    pub compressed_bytes: usize,
    /// Number of pattern occurrences detected in the proof.
    pub patterns_found: usize,
    /// Number of distinct patterns held in the compressor's dictionary.
    pub unique_patterns: usize,
}
impl StructuralCompressor {
pub fn new() -> Self {
Self {
pattern_dict: HashMap::new(),
next_pattern_id: PatternId(0),
stats: CompressionStats::default(),
}
}
pub fn compress(&mut self, proof: &ResolutionProof) -> CompressedProof {
self.stats.original_bytes = self.estimate_size(proof);
let patterns = self.identify_patterns(proof);
self.stats.patterns_found = patterns.len();
self.stats.unique_patterns = self.pattern_dict.len();
let dict = self.build_dictionary(&patterns);
let encoded = self.encode_with_dictionary(proof, &dict);
self.stats.compressed_bytes = encoded.bytes.len();
encoded
}
fn identify_patterns(&mut self, proof: &ResolutionProof) -> Vec<(PatternId, Vec<usize>)> {
let mut patterns = Vec::new();
for i in 0..proof.nodes.len() {
if let Some((pattern, nodes)) = self.find_resolution_chain(proof, i) {
let pattern_id = self.get_or_create_pattern(pattern);
patterns.push((pattern_id, nodes));
}
}
for i in 0..proof.nodes.len() {
if let Some((pattern, nodes)) = self.find_binary_tree(proof, i) {
let pattern_id = self.get_or_create_pattern(pattern);
patterns.push((pattern_id, nodes));
}
}
patterns
}
fn find_resolution_chain(
&self,
proof: &ResolutionProof,
start: usize,
) -> Option<(Pattern, Vec<usize>)> {
let mut nodes = vec![start];
let mut current = start;
let mut length = 1;
loop {
let node = proof.nodes.get(current)?;
match &node.rule {
ProofRule::Resolution { left, .. } => {
if let Some(left_node) = proof.nodes.get(*left) {
if matches!(left_node.rule, ProofRule::Resolution { .. }) {
nodes.push(*left);
current = *left;
length += 1;
continue;
}
}
break;
}
_ => break,
}
}
if length >= 3 {
Some((Pattern::ResolutionChain { length }, nodes))
} else {
None
}
}
fn find_binary_tree(
&self,
proof: &ResolutionProof,
start: usize,
) -> Option<(Pattern, Vec<usize>)> {
let depth = self.compute_tree_depth(proof, start);
if depth >= 2 {
let nodes = self.collect_tree_nodes(proof, start);
Some((Pattern::BinaryTree { depth }, nodes))
} else {
None
}
}
fn compute_tree_depth(&self, proof: &ResolutionProof, node_idx: usize) -> usize {
let node = match proof.nodes.get(node_idx) {
Some(n) => n,
None => return 0,
};
match &node.rule {
ProofRule::Resolution { left, right, .. } => {
let left_depth = self.compute_tree_depth(proof, *left);
let right_depth = self.compute_tree_depth(proof, *right);
1 + left_depth.max(right_depth)
}
_ => 0,
}
}
fn collect_tree_nodes(&self, proof: &ResolutionProof, node_idx: usize) -> Vec<usize> {
let mut nodes = vec![node_idx];
if let Some(node) = proof.nodes.get(node_idx) {
if let ProofRule::Resolution { left, right, .. } = &node.rule {
nodes.extend(self.collect_tree_nodes(proof, *left));
nodes.extend(self.collect_tree_nodes(proof, *right));
}
}
nodes
}
fn get_or_create_pattern(&mut self, pattern: Pattern) -> PatternId {
if let Some(&id) = self.pattern_dict.get(&pattern) {
return id;
}
let id = self.next_pattern_id;
self.next_pattern_id = PatternId(id.0 + 1);
self.pattern_dict.insert(pattern, id);
id
}
fn build_dictionary(&self, patterns: &[(PatternId, Vec<usize>)]) -> CompressionDictionary {
let mut dict = CompressionDictionary::new();
let mut frequencies: HashMap<PatternId, usize> = HashMap::new();
for (pattern_id, _) in patterns {
*frequencies.entry(*pattern_id).or_insert(0) += 1;
}
let mut sorted_patterns: Vec<_> = frequencies.into_iter().collect();
sorted_patterns.sort_by_key(|(_, freq)| std::cmp::Reverse(*freq));
for (pattern_id, _freq) in sorted_patterns.iter().take(256) {
if let Some(pattern) = self.pattern_dict.iter().find(|(_, &id)| id == *pattern_id) {
dict.add_entry(*pattern_id, pattern.0.clone());
}
}
dict
}
fn encode_with_dictionary(
&self,
proof: &ResolutionProof,
dict: &CompressionDictionary,
) -> CompressedProof {
let mut bytes = Vec::new();
dict.encode(&mut bytes);
for node in &proof.nodes {
self.encode_node(node, &mut bytes, dict);
}
CompressedProof {
bytes,
dictionary: dict.clone(),
root: proof.root,
}
}
fn encode_node(&self, node: &ProofNode, bytes: &mut Vec<u8>, dict: &CompressionDictionary) {
Self::encode_varint(node.id as u64, bytes);
Self::encode_varint(node.clause.len() as u64, bytes);
for &lit in &node.clause {
Self::encode_signed_varint(lit as i64, bytes);
}
match &node.rule {
ProofRule::Resolution { left, right, pivot } => {
bytes.push(1); Self::encode_varint(*left as u64, bytes);
Self::encode_varint(*right as u64, bytes);
Self::encode_signed_varint(*pivot as i64, bytes);
}
ProofRule::Input => {
bytes.push(2); }
ProofRule::Axiom => {
bytes.push(3); }
}
}
fn encode_varint(mut value: u64, bytes: &mut Vec<u8>) {
loop {
let mut byte = (value & 0x7F) as u8;
value >>= 7;
if value != 0 {
byte |= 0x80;
}
bytes.push(byte);
if value == 0 {
break;
}
}
}
fn encode_signed_varint(value: i64, bytes: &mut Vec<u8>) {
let zigzag = if value >= 0 {
(value as u64) << 1
} else {
((-value - 1) as u64) << 1 | 1
};
Self::encode_varint(zigzag, bytes);
}
fn estimate_size(&self, proof: &ResolutionProof) -> usize {
let mut size = 0;
for node in &proof.nodes {
size += 8; size += 4; size += node.clause.len() * 4; size += 1;
if let ProofRule::Resolution { .. } = node.rule {
size += 8 + 8 + 4; }
}
size
}
pub fn stats(&self) -> &CompressionStats {
&self.stats
}
pub fn compression_ratio(&self) -> f64 {
if self.stats.original_bytes == 0 {
return 1.0;
}
self.stats.compressed_bytes as f64 / self.stats.original_bytes as f64
}
}
impl Default for StructuralCompressor {
fn default() -> Self {
Self::new()
}
}
/// Set of patterns, keyed by identifier, that accompanies a compressed proof.
#[derive(Clone, Debug)]
pub struct CompressionDictionary {
    /// Pattern stored for each identifier.
    entries: HashMap<PatternId, Pattern>,
}
impl CompressionDictionary {
    /// Creates an empty dictionary.
    pub fn new() -> Self {
        Self {
            entries: HashMap::new(),
        }
    }

    /// Stores `pattern` under `id`, replacing any pattern already stored for
    /// that identifier.
    pub fn add_entry(&mut self, id: PatternId, pattern: Pattern) {
        self.entries.insert(id, pattern);
    }

    /// Appends the serialised dictionary to `bytes`.
    ///
    /// Layout: varint entry count, then for each entry a varint identifier
    /// followed by the encoded pattern. Entries are written in ascending
    /// identifier order so that equal dictionaries always yield equal bytes;
    /// iterating the hash map directly would make the output vary from run to
    /// run.
    pub fn encode(&self, bytes: &mut Vec<u8>) {
        StructuralCompressor::encode_varint(self.entries.len() as u64, bytes);

        let mut ordered: Vec<(&PatternId, &Pattern)> = self.entries.iter().collect();
        ordered.sort_unstable_by_key(|(id, _)| id.0);

        for (id, pattern) in ordered {
            StructuralCompressor::encode_varint(id.0 as u64, bytes);
            self.encode_pattern(pattern, bytes);
        }
    }

    /// Appends one pattern: a tag byte (`1` single resolution, `2` resolution
    /// chain, `3` binary tree, `4` custom) followed by the variant's fields as
    /// varints; a custom pattern writes its length and then its raw bytes.
    fn encode_pattern(&self, pattern: &Pattern, bytes: &mut Vec<u8>) {
        match pattern {
            Pattern::SingleResolution {
                left_size,
                right_size,
                result_size,
            } => {
                bytes.push(1);
                StructuralCompressor::encode_varint(*left_size as u64, bytes);
                StructuralCompressor::encode_varint(*right_size as u64, bytes);
                StructuralCompressor::encode_varint(*result_size as u64, bytes);
            }
            Pattern::ResolutionChain { length } => {
                bytes.push(2);
                StructuralCompressor::encode_varint(*length as u64, bytes);
            }
            Pattern::BinaryTree { depth } => {
                bytes.push(3);
                StructuralCompressor::encode_varint(*depth as u64, bytes);
            }
            Pattern::Custom(data) => {
                bytes.push(4);
                StructuralCompressor::encode_varint(data.len() as u64, bytes);
                bytes.extend_from_slice(data);
            }
        }
    }
}
impl Default for CompressionDictionary {
fn default() -> Self {
Self::new()
}
}
/// Output of a proof compressor.
#[derive(Clone, Debug)]
pub struct CompressedProof {
    /// Encoded byte stream.
    pub bytes: Vec<u8>,
    /// Dictionary produced alongside the byte stream (empty for the LZ
    /// compressor).
    pub dictionary: CompressionDictionary,
    /// Root index copied from the proof that was compressed.
    pub root: usize,
}
impl CompressedProof {
    /// Length of the encoded byte stream.
    pub fn size_bytes(&self) -> usize {
        let encoded: &[u8] = &self.bytes;
        encoded.len()
    }

    /// Reconstructs the original proof.
    ///
    /// # Errors
    ///
    /// Always fails: decompression has not been implemented.
    pub fn decompress(&self) -> Result<ResolutionProof, String> {
        let message = String::from("Decompression not yet implemented");
        Err(message)
    }
}
/// Byte-level compressor that replaces repeated byte runs with
/// backreferences into a sliding window.
pub struct LzProofCompressor {
    /// Maximum number of already processed bytes searched for a match.
    window_size: usize,
    /// Shortest match that is emitted as a backreference.
    min_match_length: usize,
}
impl LzProofCompressor {
    /// Creates a compressor with a 32768-byte window and a minimum match
    /// length of 3.
    pub fn new() -> Self {
        Self {
            window_size: 32768,
            min_match_length: 3,
        }
    }

    /// Serialises `proof` to bytes and compresses them.
    ///
    /// The returned proof carries an empty dictionary.
    pub fn compress(&self, proof: &ResolutionProof) -> CompressedProof {
        let raw = self.proof_to_bytes(proof);
        let mut bytes = Vec::new();
        self.lz_compress(&raw, &mut bytes);
        CompressedProof {
            bytes,
            dictionary: CompressionDictionary::new(),
            root: proof.root,
        }
    }

    /// Flattens the proof into bytes: per node, the id as little-endian
    /// `u32`, the clause length as little-endian `u32`, then each literal in
    /// little-endian byte order. The rule of a node is not written.
    fn proof_to_bytes(&self, proof: &ResolutionProof) -> Vec<u8> {
        let mut flat = Vec::new();
        for node in &proof.nodes {
            let id_field = (node.id as u32).to_le_bytes();
            let len_field = (node.clause.len() as u32).to_le_bytes();
            flat.extend_from_slice(&id_field);
            flat.extend_from_slice(&len_field);
            for &lit in &node.clause {
                flat.extend_from_slice(&lit.to_le_bytes());
            }
        }
        flat
    }

    /// Compresses `input` into `output`.
    ///
    /// At each position the longest match inside the preceding window is
    /// emitted as a backreference when it is long enough; otherwise the byte
    /// is emitted as a literal (tag `0` followed by the byte).
    fn lz_compress(&self, input: &[u8], output: &mut Vec<u8>) {
        let mut cursor = 0;
        while cursor < input.len() {
            let window_start = cursor.saturating_sub(self.window_size);
            let usable_match = self
                .find_longest_match(&input[window_start..cursor], &input[cursor..])
                .filter(|&(_, len)| len >= self.min_match_length);

            match usable_match {
                Some((offset, len)) => {
                    self.encode_backreference(offset, len, output);
                    cursor += len;
                }
                None => {
                    output.push(0);
                    output.push(input[cursor]);
                    cursor += 1;
                }
            }
        }
    }

    /// Finds the longest prefix of `lookahead` that occurs inside `window`.
    ///
    /// Returns `(offset within window, length)` of the earliest longest match
    /// whose length reaches the minimum match length, or `None` if there is
    /// none. A match never extends past the end of `window`.
    fn find_longest_match(&self, window: &[u8], lookahead: &[u8]) -> Option<(usize, usize)> {
        let mut best: Option<(usize, usize)> = None;
        for offset in 0..window.len() {
            let len = window[offset..]
                .iter()
                .zip(lookahead)
                .take_while(|(seen, upcoming)| seen == upcoming)
                .count();
            if len < self.min_match_length {
                continue;
            }
            match best {
                // Keep the earlier candidate unless this one is strictly longer.
                Some((_, best_len)) if best_len >= len => {}
                _ => best = Some((offset, len)),
            }
        }
        best
    }

    /// Appends a backreference: tag `1`, then `pos` and `len` each narrowed
    /// to `u16` and written little-endian.
    fn encode_backreference(&self, pos: usize, len: usize, output: &mut Vec<u8>) {
        let [pos_lo, pos_hi] = (pos as u16).to_le_bytes();
        let [len_lo, len_hi] = (len as u16).to_le_bytes();
        output.extend_from_slice(&[1, pos_lo, pos_hi, len_lo, len_hi]);
    }
}
impl Default for LzProofCompressor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A new structural compressor starts without any registered pattern.
    #[test]
    fn test_structural_compressor_creation() {
        let fresh = StructuralCompressor::new();
        assert!(fresh.pattern_dict.is_empty());
    }

    /// One-byte and two-byte varints are encoded as expected.
    #[test]
    fn test_varint_encoding() {
        let encode = |value: u64| {
            let mut out = Vec::new();
            StructuralCompressor::encode_varint(value, &mut out);
            out
        };
        assert_eq!(encode(127), vec![127]);
        assert_eq!(encode(300), vec![0xAC, 0x02]);
    }

    /// A new LZ compressor uses the documented window and match settings.
    #[test]
    fn test_lz_compressor_creation() {
        let LzProofCompressor {
            window_size,
            min_match_length,
        } = LzProofCompressor::new();
        assert_eq!(window_size, 32768);
        assert_eq!(min_match_length, 3);
    }
}