use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::{DomainInfo, PredicateInfo, StringInterner, SymbolTable};
/// Compact, string-pooled form of a symbol table.
///
/// Names and descriptions are stored once in `strings`; domains, predicates
/// and variable bindings refer to them by index. Built by
/// `from_symbol_table` and expanded again by `to_symbol_table`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct CompactSchema {
/// String pool. Every `*_id` value in this schema is an index into it.
strings: Vec<String>,
/// One entry per domain of the source table.
domains: Vec<CompactDomain>,
/// One entry per predicate of the source table.
predicates: Vec<CompactPredicate>,
/// Variable bindings as `(variable name id, domain name id)` pairs.
variables: Vec<(usize, usize)>,
}
/// A domain whose textual fields are replaced by string-pool ids.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct CompactDomain {
/// Index of the domain name in `CompactSchema::strings`.
name_id: usize,
/// Cardinality copied unchanged from the source domain.
cardinality: usize,
/// Index of the description in `CompactSchema::strings`, if the domain has one.
description_id: Option<usize>,
}
/// A predicate whose textual fields are replaced by string-pool ids.
#[derive(Clone, Debug, Serialize, Deserialize)]
struct CompactPredicate {
/// Index of the predicate name in `CompactSchema::strings`.
name_id: usize,
/// String-pool ids of the argument domain names, in argument order.
arg_domain_ids: Vec<usize>,
/// Index of the description in `CompactSchema::strings`, if the predicate has one.
description_id: Option<usize>,
}
impl CompactSchema {
pub fn from_symbol_table(table: &SymbolTable) -> Self {
let mut interner = StringInterner::new();
let mut string_to_id = HashMap::new();
let mut intern = |s: &str| -> usize {
if let Some(&id) = string_to_id.get(s) {
id
} else {
let id = interner.intern(s);
string_to_id.insert(s.to_string(), id);
id
}
};
let domains: Vec<_> = table
.domains
.values()
.map(|domain| {
let name_id = intern(&domain.name);
let description_id = domain.description.as_ref().map(|d| intern(d));
CompactDomain {
name_id,
cardinality: domain.cardinality,
description_id,
}
})
.collect();
let predicates: Vec<_> = table
.predicates
.values()
.map(|pred| {
let name_id = intern(&pred.name);
let arg_domain_ids: Vec<_> = pred.arg_domains.iter().map(|d| intern(d)).collect();
let description_id = pred.description.as_ref().map(|d| intern(d));
CompactPredicate {
name_id,
arg_domain_ids,
description_id,
}
})
.collect();
let variables: Vec<_> = table
.variables
.iter()
.map(|(var, domain)| {
let var_id = intern(var);
let domain_id = intern(domain);
(var_id, domain_id)
})
.collect();
let strings: Vec<_> = (0..interner.len())
.filter_map(|id| interner.resolve(id).map(|s| s.to_string()))
.collect();
CompactSchema {
strings,
domains,
predicates,
variables,
}
}
pub fn to_symbol_table(&self) -> Result<SymbolTable> {
let mut table = SymbolTable::new();
for compact in &self.domains {
let name = self.strings.get(compact.name_id).ok_or_else(|| {
anyhow::anyhow!("Invalid string ID {} for domain name", compact.name_id)
})?;
let mut domain = DomainInfo::new(name.clone(), compact.cardinality);
if let Some(desc_id) = compact.description_id {
let description = self.strings.get(desc_id).ok_or_else(|| {
anyhow::anyhow!("Invalid string ID {} for description", desc_id)
})?;
domain.description = Some(description.clone());
}
table.add_domain(domain)?;
}
for compact in &self.predicates {
let name = self.strings.get(compact.name_id).ok_or_else(|| {
anyhow::anyhow!("Invalid string ID {} for predicate name", compact.name_id)
})?;
let arg_domains: Result<Vec<_>> = compact
.arg_domain_ids
.iter()
.map(|&id| {
self.strings
.get(id)
.cloned()
.ok_or_else(|| anyhow::anyhow!("Invalid string ID {} for arg domain", id))
})
.collect();
let mut pred = PredicateInfo::new(name.clone(), arg_domains?);
if let Some(desc_id) = compact.description_id {
let description = self.strings.get(desc_id).ok_or_else(|| {
anyhow::anyhow!("Invalid string ID {} for description", desc_id)
})?;
pred.description = Some(description.clone());
}
table.add_predicate(pred)?;
}
for &(var_id, domain_id) in &self.variables {
let var = self
.strings
.get(var_id)
.ok_or_else(|| anyhow::anyhow!("Invalid string ID {} for variable", var_id))?;
let domain = self.strings.get(domain_id).ok_or_else(|| {
anyhow::anyhow!("Invalid string ID {} for variable domain", domain_id)
})?;
table.bind_variable(var, domain)?;
}
Ok(table)
}
pub fn to_binary(&self) -> Result<Vec<u8>> {
oxicode::serde::encode_to_vec(self, oxicode::config::standard())
.map_err(|e| anyhow::anyhow!("Bincode encode error: {}", e))
}
pub fn from_binary(data: &[u8]) -> Result<Self> {
let (result, _): (Self, usize) =
oxicode::serde::decode_from_slice(data, oxicode::config::standard())
.map_err(|e| anyhow::anyhow!("Bincode decode error: {}", e))?;
Ok(result)
}
pub fn string_count(&self) -> usize {
self.strings.len()
}
pub fn compression_stats(&self) -> CompressionStats {
let string_bytes: usize = self.strings.iter().map(|s| s.len()).sum();
let domain_count = self.domains.len();
let predicate_count = self.predicates.len();
let variable_count = self.variables.len();
let avg_string_len = if !self.strings.is_empty() {
string_bytes / self.strings.len()
} else {
0
};
let estimated_original_size = domain_count * (avg_string_len + 16) + predicate_count * (avg_string_len + 16) + variable_count * (avg_string_len * 2);
CompressionStats {
unique_strings: self.strings.len(),
total_string_bytes: string_bytes,
domain_count,
predicate_count,
variable_count,
estimated_original_size,
compact_size: string_bytes
+ domain_count * 24
+ predicate_count * 24
+ variable_count * 16,
}
}
}
/// Size figures reported by `CompactSchema::compression_stats`.
#[derive(Clone, Debug)]
pub struct CompressionStats {
    /// Number of entries in the schema's string pool.
    pub unique_strings: usize,
    /// Sum of the byte lengths of all pooled strings.
    pub total_string_bytes: usize,
    /// Number of domains in the schema.
    pub domain_count: usize,
    /// Number of predicates in the schema.
    pub predicate_count: usize,
    /// Number of variable bindings in the schema.
    pub variable_count: usize,
    /// Heuristic size estimate for the un-pooled representation.
    pub estimated_original_size: usize,
    /// Heuristic size estimate for the pooled representation.
    pub compact_size: usize,
}

impl CompressionStats {
    /// Ratio of `compact_size` to `estimated_original_size`.
    ///
    /// Values below `1.0` mean the compact form is estimated to be smaller.
    /// Returns `1.0` when the original estimate is zero, so the division is
    /// never performed with a zero denominator.
    pub fn compression_ratio(&self) -> f64 {
        match self.estimated_original_size {
            0 => 1.0,
            original => self.compact_size as f64 / original as f64,
        }
    }

    /// Estimated savings as a percentage of the original size.
    ///
    /// Negative when the compact form is estimated to be larger.
    pub fn space_savings(&self) -> f64 {
        let ratio = self.compression_ratio();
        (1.0 - ratio) * 100.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a symbol table holding one domain per `(name, cardinality)`
    /// pair, added in the given order.
    fn table_with_domains(specs: &[(&str, usize)]) -> SymbolTable {
        let mut table = SymbolTable::new();
        for &(name, cardinality) in specs {
            table
                .add_domain(DomainInfo::new(name, cardinality))
                .expect("unwrap");
        }
        table
    }

    /// Converting to the compact form and back keeps the entry counts.
    #[test]
    fn test_compact_round_trip() {
        let mut original = table_with_domains(&[("Person", 100), ("Location", 50)]);
        let at = PredicateInfo::new(
            "at",
            vec![String::from("Person"), String::from("Location")],
        );
        original.add_predicate(at).expect("unwrap");
        original.bind_variable("x", "Person").expect("unwrap");

        let restored = CompactSchema::from_symbol_table(&original)
            .to_symbol_table()
            .expect("unwrap");

        assert_eq!(original.domains.len(), restored.domains.len());
        assert_eq!(original.predicates.len(), restored.predicates.len());
        assert_eq!(original.variables.len(), restored.variables.len());
    }

    /// "Person" is shared by the domain and both predicates, so the pool
    /// holds only "Person", "knows" and "likes".
    #[test]
    fn test_string_deduplication() {
        let mut source = table_with_domains(&[("Person", 100)]);
        for predicate_name in ["knows", "likes"] {
            source
                .add_predicate(PredicateInfo::new(
                    predicate_name,
                    vec![String::from("Person")],
                ))
                .expect("unwrap");
        }

        let schema = CompactSchema::from_symbol_table(&source);

        assert_eq!(schema.string_count(), 3);
    }

    /// Encoding to bytes and decoding again keeps the domain count.
    #[test]
    fn test_binary_serialization() {
        let source = table_with_domains(&[("Person", 100)]);

        let bytes = CompactSchema::from_symbol_table(&source)
            .to_binary()
            .expect("unwrap");
        let decoded = CompactSchema::from_binary(&bytes).expect("unwrap");
        let rebuilt = decoded.to_symbol_table().expect("unwrap");

        assert_eq!(source.domains.len(), rebuilt.domains.len());
    }

    /// The statistics report the domain count and produce sane figures.
    #[test]
    fn test_compression_stats() {
        let source = table_with_domains(&[("Person", 100), ("Location", 50)]);

        let report = CompactSchema::from_symbol_table(&source).compression_stats();

        assert_eq!(report.domain_count, 2);
        assert!(report.compression_ratio() > 0.0);
        assert!(report.space_savings() > -200.0);
    }

    /// An empty table survives the round trip and stays empty.
    #[test]
    fn test_empty_table() {
        let empty = SymbolTable::new();

        let rebuilt = CompactSchema::from_symbol_table(&empty)
            .to_symbol_table()
            .expect("unwrap");

        assert_eq!(rebuilt.domains.len(), 0);
        assert_eq!(rebuilt.predicates.len(), 0);
    }
}