use std::cell::RefCell;
use merc_aterm::ATerm;
use merc_aterm::ATermRead;
use merc_collections::IndexedSet;
use merc_io::BitStreamRead;
use merc_io::BitStreamWrite;
use merc_number::bits_for_value;
use merc_utilities::MercError;
use crate::Data;
use crate::Ldd;
use crate::Storage;
use crate::iterators::iter_nodes;
/// Magic number identifying a binary LDD stream; it is stored in the first 16 bits of the stream.
const BLF_MAGIC: u64 = 0x8BAF;

/// Format version stored in the 16 bits following [`BLF_MAGIC`]; the reader rejects any other
/// value.
const BLF_VERSION: u64 = 0x8306;
/// Writes LDDs to a bit stream in the binary LDD format.
///
/// Every node is written at most once; later occurrences are referenced by their index in
/// `nodes`.
pub struct BinaryLddWriter<W: BitStreamWrite> {
/// The underlying bit stream that receives the encoded data.
writer: W,
/// The nodes written so far, in order of insertion. It sits in a `RefCell` because
/// `write_ldd` reads it from the traversal filter while the loop body inserts into it.
nodes: RefCell<IndexedSet<Ldd>>,
}
impl<W: BitStreamWrite> BinaryLddWriter<W> {
    /// Creates a writer on top of `writer` and emits the stream header.
    ///
    /// The header consists of [`BLF_MAGIC`] followed by [`BLF_VERSION`], 16 bits each. The node
    /// table is seeded with the empty set and the empty vector of `storage`, in that order, which
    /// mirrors the table that [`BinaryLddReader::new`] starts with.
    ///
    /// # Errors
    ///
    /// Returns an error when writing the header to `writer` fails.
    pub fn new(storage: &mut Storage, mut writer: W) -> Result<Self, MercError> {
        writer.write_bits(BLF_MAGIC, 16)?;
        writer.write_bits(BLF_VERSION, 16)?;

        let mut nodes = IndexedSet::new();
        nodes.insert(storage.empty_set().clone());
        nodes.insert(storage.empty_vector().clone());

        Ok(Self {
            writer,
            nodes: RefCell::new(nodes),
        })
    }

    /// Writes `ldd` to the stream.
    ///
    /// Every node reachable from `ldd` that is not yet in the node table is written as a node
    /// record: a `0` bit, the node value, and the table indices of its `down` and `right`
    /// children. The LDD itself is then announced by an output record: a `1` bit followed by the
    /// table index of its root.
    ///
    /// # Errors
    ///
    /// Returns an error when writing to the underlying stream fails, or when the root of `ldd`
    /// cannot be found in the node table after the traversal.
    ///
    /// # Panics
    ///
    /// Panics when a node is visited before one of its children has been written; the traversal
    /// is expected to yield children first.
    pub fn write_ldd(&mut self, ldd: &Ldd, storage: &Storage) -> Result<(), MercError> {
        // Becomes true once the output record for `ldd` has been emitted inside the traversal.
        let mut root_written = false;

        for (node, Data(value, down, right)) in iter_nodes(storage, ldd, |node| {
            // Nodes that are already part of the stream do not have to be visited again.
            !self.nodes.borrow().contains(node)
        }) {
            let mut nodes = self.nodes.borrow_mut();
            let (index, inserted) = nodes.insert(node.clone());

            // Indices are encoded with the width of the table after the insertion above; the
            // table does not change for the remainder of this iteration.
            let width = Self::ldd_index_width(&nodes);

            if inserted {
                self.writer.write_bits(0, 1)?;
                self.writer.write_integer(value as u64)?;
                self.writer.write_bits(
                    *nodes
                        .index(&down)
                        .expect("The down node must have already been written") as u64,
                    width,
                )?;
                self.writer.write_bits(
                    *nodes
                        .index(&right)
                        .expect("The right node must have already been written") as u64,
                    width,
                )?;
            }

            if node == *ldd {
                self.writer.write_bits(1, 1)?;
                self.writer.write_bits(*index as u64, width)?;
                root_written = true;
            }
        }

        if !root_written {
            // The traversal did not yield `ldd` itself. NOTE(review): this presumably happens when
            // the root is already in the table (the empty set, the empty vector, or an LDD that
            // was written before) and is therefore filtered out; confirm against `iter_nodes`.
            // Without an output record the reader would get out of sync with the writer.
            let nodes = self.nodes.borrow();
            let index = nodes
                .index(ldd)
                .ok_or_else(|| String::from("The LDD to write is not present in the node table"))?;
            self.writer.write_bits(1, 1)?;
            self.writer.write_bits(*index as u64, Self::ldd_index_width(&nodes))?;
        }

        Ok(())
    }

    /// Returns the number of bits used to encode an index into `nodes`, as computed by
    /// `bits_for_value` for the current size of the table.
    fn ldd_index_width(nodes: &IndexedSet<Ldd>) -> u8 {
        bits_for_value(nodes.len())
    }
}
/// Reads LDDs from a bit stream in the binary LDD format.
pub struct BinaryLddReader<R: BitStreamRead> {
/// The underlying bit stream from which the encoded data is read.
reader: R,
/// The nodes read so far; the indices stored in the stream refer to positions in this vector.
nodes: Vec<Ldd>,
}
impl<R: BitStreamRead> BinaryLddReader<R> {
    /// Creates a reader on top of `reader` and validates the stream header.
    ///
    /// The node table is seeded with the empty set and the empty vector of `storage`, in that
    /// order, matching the table that [`BinaryLddWriter::new`] starts with.
    ///
    /// # Errors
    ///
    /// Returns an error when the header cannot be read, when the first 16 bits are not
    /// [`BLF_MAGIC`], or when the following 16 bits are not [`BLF_VERSION`].
    pub fn new(storage: &mut Storage, mut reader: R) -> Result<Self, MercError> {
        let magic = reader.read_bits(16)?;
        if magic != BLF_MAGIC {
            return Err("Invalid magic number in binary LDD stream".into());
        }

        let version = reader.read_bits(16)?;
        if version != BLF_VERSION {
            return Err(format!("The BLF version ({version}) of the input file is incompatible with the version ({BLF_VERSION}) of this tool. The input file must be regenerated.").into());
        }

        let nodes = vec![storage.empty_set().clone(), storage.empty_vector().clone()];
        Ok(Self { reader, nodes })
    }

    /// Reads the next LDD from the stream.
    ///
    /// Node records (leading `0` bit) are inserted into `storage` and appended to the node table
    /// until an output record (leading `1` bit) is found; the node that the output record refers
    /// to is returned.
    ///
    /// # Errors
    ///
    /// Returns an error when reading from the underlying stream fails, when an index in the
    /// stream lies outside the node table, or when a node value does not fit in 32 bits.
    pub fn read_ldd(&mut self, storage: &mut Storage) -> Result<Ldd, MercError> {
        loop {
            let is_output = self.reader.read_bits(1)? == 1;
            if is_output {
                let width = self.ldd_index_width(false);
                let index = self.reader.read_bits(width)? as usize;
                return Ok(self
                    .nodes
                    .get(index)
                    // Build the message lazily, so nothing is allocated for a valid index.
                    .ok_or_else(|| format!("Read invalid ldd index {index}, length {}", self.nodes.len()))?
                    .clone());
            }

            let value = self.reader.read_integer()?;

            // The writer computes the width after inserting the node, so the node that is about
            // to be added is counted here as well. Both child indices use the same width.
            let width = self.ldd_index_width(true);
            let down_index = self.reader.read_bits(width)? as usize;
            let right_index = self.reader.read_bits(width)? as usize;

            // Reject values that do not fit instead of silently truncating them.
            let value = u32::try_from(value)
                .map_err(|_| format!("Read LDD value {value} that does not fit in 32 bits"))?;

            let down = self.nodes.get(down_index).ok_or_else(|| {
                format!(
                    "Read invalid down ldd index {down_index}, length {}",
                    self.nodes.len()
                )
            })?;
            let right = self.nodes.get(right_index).ok_or_else(|| {
                format!(
                    "Read invalid right ldd index {right_index}, length {}",
                    self.nodes.len()
                )
            })?;

            let ldd = storage.insert(value, down, right);
            self.nodes.push(ldd);
        }
    }

    /// Returns the number of bits used to encode an index into the node table.
    ///
    /// With `input` set, the table is treated as one element larger, which accounts for the node
    /// record that is being read and has not been appended yet.
    fn ldd_index_width(&self, input: bool) -> u8 {
        bits_for_value(self.nodes.len() + usize::from(input))
    }
}
impl<R: BitStreamRead + ATermRead> ATermRead for BinaryLddReader<R> {
delegate::delegate! {
to self.reader {
fn read_aterm(&mut self) -> Result<Option<ATerm>, MercError>;
fn read_aterm_iter(&mut self) -> Result<impl ExactSizeIterator<Item = Result<ATerm, MercError>>, MercError>;
}
}
}
impl<R: BitStreamRead> BitStreamRead for BinaryLddReader<R> {
delegate::delegate! {
to self.reader {
fn read_bits(&mut self, num_bits: u8) -> Result<u64, MercError>;
fn read_integer(&mut self) -> Result<u64, MercError>;
fn read_string(&mut self) -> Result<String, MercError>;
}
}
}
#[cfg(test)]
mod tests {
    use merc_io::BitStreamReader;
    use merc_io::BitStreamWriter;
    use merc_utilities::random_test;

    use crate::test_utility::from_iter;
    use crate::test_utility::random_vector_set;

    use super::*;

    /// Writes a sequence of random LDDs to an in-memory stream and checks that reading the
    /// stream back yields the same LDDs in the same order.
    #[test]
    #[cfg_attr(miri, ignore)]
    fn test_binary_ldd_stream() {
        random_test(100, |rng| {
            let mut storage = Storage::new();

            let input: Vec<_> = (0..20)
                .map(|_| {
                    let input = random_vector_set(rng, 32, 10, 10);
                    from_iter(&mut storage, input.iter())
                })
                .collect();

            let mut vector: Vec<u8> = Vec::new();
            let stream = BitStreamWriter::new(&mut vector);
            let mut output_stream = BinaryLddWriter::new(&mut storage, stream).unwrap();
            for term in &input {
                output_stream.write_ldd(term, &storage).unwrap();
            }

            // Dropping the writer releases its borrow of `vector` before the bytes are read back.
            drop(output_stream);

            let mut input_stream = BinaryLddReader::new(&mut storage, BitStreamReader::new(&vector[..])).unwrap();
            for term in &input {
                // A plain assertion, so the check also runs when debug assertions are disabled.
                assert_eq!(
                    *term,
                    input_stream.read_ldd(&mut storage).unwrap(),
                    "The read LDD must match the LDD that we have written"
                );
            }
        });
    }
}