use std::collections::BinaryHeap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
use flate2::Compression;
use flate2::read::GzDecoder;
use flate2::write::GzEncoder;
use serde::{Deserialize, Serialize};
use crate::rat_enum::stream::runs::list_run_files;
/// File name, joined onto the merge output directory, of the gzip-compressed
/// stream of de-duplicated records written by `merge_runs`.
pub const UNIQUE_FILENAME: &str = "unique.bin";
/// File name, joined onto the merge output directory, of the pretty-printed
/// JSON `Certificate` written by `merge_runs`.
pub const CERTIFICATE_FILENAME: &str = "certificate.json";
/// Capacity in bytes (64 KiB) of the `BufReader` wrapped around each gzip
/// decoder, both for run files and for reading `unique.bin` back.
const READER_BUFFER_BYTES: usize = 64 * 1024;
/// Value `merge_runs` stores in `Certificate::schema_version`.
pub const CERTIFICATE_SCHEMA_VERSION: u32 = 1;
/// Summary of one `merge_runs` invocation.
///
/// `merge_runs` returns this value and also serialises it as pretty-printed
/// JSON to `CERTIFICATE_FILENAME` in the output directory. Field order is the
/// order in which serde emits the JSON keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Certificate {
/// Certificate layout version; `merge_runs` sets it to
/// `CERTIFICATE_SCHEMA_VERSION`.
pub schema_version: u32,
/// Copied verbatim from the `ring` argument of `merge_runs`.
/// NOTE(review): its meaning is defined by the enumeration code, not here.
pub ring: u8,
/// Copied verbatim from the `max_steps` argument of `merge_runs`.
pub max_steps: usize,
/// Copied verbatim from the `step` argument of `merge_runs`.
pub step: i8,
/// Copied verbatim from the `free` argument of `merge_runs`.
pub free: bool,
/// Number of run files that were merged.
pub run_files: usize,
/// Number of records read across all run files, duplicates included.
pub total_input_records: u64,
/// Number of distinct records written to `UNIQUE_FILENAME`.
pub unique_records: u64,
/// Total uncompressed size in bytes of the written records, counting each
/// record's one-byte length prefix.
pub unique_bytes: u64,
/// Lower-case hex BLAKE3 digest of the uncompressed bytes of the written
/// records, in output order.
pub unique_blake3: String,
}
/// One pending record in the k-way merge, tagged with the run it came from.
///
/// The ordering is deliberately *reversed* so that `BinaryHeap`, which is a
/// max-heap, pops the smallest key first; equal keys pop in ascending
/// `source_index` order.
struct HeapEntry {
    /// Encoded record bytes, compared lexicographically.
    key: Vec<u8>,
    /// Index of the reader that produced `key`.
    source_index: usize,
}

impl HeapEntry {
    /// Borrowed `(key, source_index)` pair in natural (ascending) order.
    fn sort_key(&self) -> (&[u8], usize) {
        (self.key.as_slice(), self.source_index)
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Operands are swapped on purpose: smaller natural key => greater entry.
        other.sort_key().cmp(&self.sort_key())
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        // Field-wise equality coincides with `cmp == Equal` for this ordering.
        self.sort_key() == other.sort_key()
    }
}

impl Eq for HeapEntry {}
/// Sequential reader over one gzip-compressed run file.
///
/// A run is a concatenation of records, each a one-byte length `n` followed by
/// `n` payload bytes.
struct RunReader {
    /// Buffered, decompressing view of the run file.
    inner: BufReader<GzDecoder<File>>,
}

impl RunReader {
    /// Opens `path` for reading.
    ///
    /// # Errors
    /// Returns the error from `File::open` if the file cannot be opened.
    fn open(path: &Path) -> std::io::Result<Self> {
        let f = File::open(path)?;
        Ok(RunReader {
            inner: BufReader::with_capacity(READER_BUFFER_BYTES, GzDecoder::new(f)),
        })
    }

    /// Reads the next record, returned *with* its leading length byte.
    ///
    /// Returns `Ok(None)` only when the decompressed stream ends exactly on a
    /// record boundary.
    ///
    /// # Errors
    /// Any error reported by the decoder is propagated. That includes an
    /// `UnexpectedEof` raised by the decoder itself (e.g. for a truncated
    /// gzip stream), which must not be mistaken for a clean end of run.
    /// A record whose payload is cut short also yields `UnexpectedEof`.
    fn next_record(&mut self) -> std::io::Result<Option<Vec<u8>>> {
        let mut len_buf = [0u8; 1];
        // `read` (not `read_exact`) lets a genuine end of stream, `Ok(0)`, be
        // told apart from an `UnexpectedEof` *error* coming out of the decoder.
        loop {
            match self.inner.read(&mut len_buf) {
                Ok(0) => return Ok(None),
                Ok(_) => break,
                // `read` does not retry interrupted calls on its own.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e),
            }
        }
        let len = usize::from(len_buf[0]);
        // Allocate the record once at its final size; the length prefix is
        // kept as the first byte.
        let mut record = vec![0u8; 1 + len];
        record[0] = len_buf[0];
        self.inner.read_exact(&mut record[1..])?;
        Ok(Some(record))
    }
}
/// K-way merges every run file under `out_dir/RUNS_SUBDIR` into a single
/// de-duplicated, gzip-compressed `UNIQUE_FILENAME`, then writes a
/// `CERTIFICATE_FILENAME` JSON summary next to it.
///
/// Records are pulled from all runs through a min-ordered heap and written in
/// ascending byte order; a record is skipped when it equals the one written
/// immediately before it. This assumes each run file is already sorted in
/// ascending byte order — that is not verified here.
///
/// `ring`, `max_steps`, `step` and `free` are not interpreted; they are copied
/// into the returned [`Certificate`].
///
/// Progress is reported on stdout.
///
/// # Errors
/// * `ErrorKind::NotFound` if no run files are found.
/// * Any I/O error from listing, opening, reading or writing files, including
///   the final flush of both output files.
/// * A JSON serialisation error, converted into `std::io::Error`.
// NOTE(review): five parameters is below clippy's default
// `too_many_arguments` threshold; confirm against the project's clippy
// configuration before removing this allow.
#[allow(clippy::too_many_arguments)]
pub fn merge_runs(
    out_dir: &Path,
    ring: u8,
    max_steps: usize,
    step: i8,
    free: bool,
) -> std::io::Result<Certificate> {
    let runs_dir = out_dir.join(crate::rat_enum::stream::enumerate::RUNS_SUBDIR);
    let run_files = list_run_files(&runs_dir)?;
    if run_files.is_empty() {
        return Err(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            format!("no run files under {}", runs_dir.display()),
        ));
    }
    println!(
        "merge: consuming {} run file(s) from {}",
        run_files.len(),
        runs_dir.display()
    );

    let mut readers: Vec<RunReader> = run_files
        .iter()
        .map(|p| RunReader::open(p))
        .collect::<std::io::Result<_>>()?;

    // Seed the heap with the first record of every non-empty run.
    let mut heap = BinaryHeap::with_capacity(readers.len());
    for (i, r) in readers.iter_mut().enumerate() {
        if let Some(rec) = r.next_record()? {
            heap.push(HeapEntry {
                key: rec,
                source_index: i,
            });
        }
    }

    let unique_path = out_dir.join(UNIQUE_FILENAME);
    let unique_file = File::create(&unique_path)?;
    let mut writer = GzEncoder::new(
        BufWriter::with_capacity(1 << 20, unique_file),
        Compression::fast(),
    );
    // The digest covers the uncompressed record bytes, not the gzip output.
    let mut hasher = blake3::Hasher::new();
    let mut total_in: u64 = 0;
    let mut unique_out: u64 = 0;
    let mut unique_bytes: u64 = 0;
    let mut last: Option<Vec<u8>> = None;

    while let Some(top) = heap.pop() {
        total_in += 1;
        let HeapEntry {
            key,
            source_index: src,
        } = top;
        // Equal records pop consecutively, so comparing against the last
        // written record is enough to drop duplicates.
        if last.as_ref() != Some(&key) {
            writer.write_all(&key)?;
            hasher.update(&key);
            unique_out += 1;
            unique_bytes += key.len() as u64;
            last = Some(key);
        }
        // Refill from the run that just yielded a record.
        if let Some(next_rec) = readers[src].next_record()? {
            heap.push(HeapEntry {
                key: next_rec,
                source_index: src,
            });
        }
    }
    // `finish` completes the gzip stream and hands back the `BufWriter`,
    // which still has to be flushed explicitly to surface write errors.
    writer.finish()?.flush()?;

    let unique_blake3 = hasher.finalize().to_hex().to_string();
    let cert = Certificate {
        schema_version: CERTIFICATE_SCHEMA_VERSION,
        ring,
        max_steps,
        step,
        free,
        run_files: run_files.len(),
        total_input_records: total_in,
        unique_records: unique_out,
        unique_bytes,
        unique_blake3,
    };

    let cert_path = out_dir.join(CERTIFICATE_FILENAME);
    let mut cert_writer = BufWriter::new(File::create(&cert_path)?);
    serde_json::to_writer_pretty(&mut cert_writer, &cert)?;
    // Dropping a `BufWriter` discards flush errors; flush explicitly so a
    // failed write of the certificate is reported instead of returning `Ok`.
    cert_writer.flush()?;

    println!(
        "merge: {} input records -> {} unique ({} bytes); blake3={}",
        cert.total_input_records, cert.unique_records, cert.unique_bytes, cert.unique_blake3
    );
    println!(
        "merge: wrote {} + {}",
        unique_path.display(),
        cert_path.display()
    );
    Ok(cert)
}
/// Opens the gzip-compressed record file at `path` and returns an iterator
/// over its decoded records.
///
/// # Errors
/// Returns the error from `File::open` if the file cannot be opened. Errors
/// while decompressing or reading are reported by the iterator itself.
pub fn read_unique_records(path: &Path) -> std::io::Result<UniqueRecordIter> {
    let decoder = GzDecoder::new(File::open(path)?);
    let inner = BufReader::with_capacity(READER_BUFFER_BYTES, decoder);
    Ok(UniqueRecordIter { inner })
}
/// Iterator over the records of a gzip-compressed record file, as returned by
/// `read_unique_records`.
///
/// Each stored record is a one-byte length `n` followed by `n` bytes. Every
/// byte is decoded by subtracting `ANGLE_BIAS` and narrowing to `i8`.
#[derive(Debug)]
pub struct UniqueRecordIter {
    /// Buffered, decompressing view of the record file.
    inner: BufReader<GzDecoder<File>>,
}

impl Iterator for UniqueRecordIter {
    type Item = std::io::Result<Vec<i8>>;

    /// Yields the next decoded record.
    ///
    /// Returns `None` only when the decompressed stream ends exactly on a
    /// record boundary. Any error from the decoder is yielded as `Some(Err)`;
    /// that includes an `UnexpectedEof` raised by the decoder itself (e.g.
    /// for a truncated gzip stream), so such a file is not mistaken for a
    /// complete one. A record whose payload is cut short also yields an
    /// `UnexpectedEof` error.
    fn next(&mut self) -> Option<Self::Item> {
        let mut len_buf = [0u8; 1];
        // `read` (not `read_exact`) lets a genuine end of stream, `Ok(0)`, be
        // told apart from an `UnexpectedEof` *error* coming out of the decoder.
        loop {
            match self.inner.read(&mut len_buf) {
                Ok(0) => return None,
                Ok(_) => break,
                // `read` does not retry interrupted calls on its own.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Some(Err(e)),
            }
        }
        let len = usize::from(len_buf[0]);
        let mut bytes = vec![0u8; len];
        if let Err(e) = self.inner.read_exact(&mut bytes) {
            return Some(Err(e));
        }
        // Undo the bias applied when the record was encoded.
        Some(Ok(bytes
            .into_iter()
            .map(|b| (b as i16 - crate::rat_enum::stream::records::ANGLE_BIAS) as i8)
            .collect()))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::rat_enum::stream::runs::RunWriter;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Creates and returns a fresh directory under the system temp dir.
///
/// The name combines the process id with a per-process counter, so every
/// call within one test binary gets its own directory. Panics if the
/// directory cannot be created.
fn tempdir() -> PathBuf {
    static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
    let pid = std::process::id();
    let n = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    let dir = std::env::temp_dir().join(format!("rat_enum_merge_test_{pid}_{n}"));
    std::fs::create_dir_all(&dir).unwrap();
    dir
}
/// Two runs sharing one record merge into five unique records, read back in
/// the expected order, and the certificate digest matches a digest recomputed
/// from the re-encoded output.
#[test]
fn merge_dedupes_across_runs_and_preserves_order() {
    use crate::rat_enum::stream::records::encode_record;

    let dir = tempdir();
    let runs_dir = dir.join(super::super::enumerate::RUNS_SUBDIR);
    std::fs::create_dir_all(&runs_dir).unwrap();

    // Each writer is dropped before the merge starts, as in a scoped block.
    let mut first_run = RunWriter::with_threshold(&runs_dir, 1, 1000);
    first_run.record(&[]);
    first_run.record(&[1, 2, 3]);
    first_run.record(&[2]);
    drop(first_run);

    let mut second_run = RunWriter::with_threshold(&runs_dir, 2, 1000);
    second_run.record(&[1, 2, 3]);
    second_run.record(&[-1, 0]);
    second_run.record(&[1, 1]);
    drop(second_run);

    let cert = merge_runs(&dir, 12, 9, 1, false).unwrap();
    assert_eq!(cert.run_files, 2);
    assert_eq!(cert.total_input_records, 6);
    assert_eq!(cert.unique_records, 5);

    let unique_path = dir.join(UNIQUE_FILENAME);
    let recs = read_unique_records(&unique_path)
        .unwrap()
        .collect::<std::io::Result<Vec<Vec<i8>>>>()
        .unwrap();
    let expected: Vec<Vec<i8>> = vec![vec![], vec![2], vec![-1, 0], vec![1, 1], vec![1, 2, 3]];
    assert_eq!(recs, expected);

    // Re-encode what was read back and check it hashes to the certified value.
    let mut hasher = blake3::Hasher::new();
    for rec in read_unique_records(&unique_path).unwrap() {
        let decoded = rec.unwrap();
        let mut enc = Vec::new();
        encode_record(&decoded, &mut enc);
        hasher.update(&enc);
    }
    let recomputed = hasher.finalize().to_hex().to_string();
    assert_eq!(recomputed, cert.unique_blake3);
}
/// Merging an existing but empty runs directory fails with `NotFound`.
#[test]
fn merge_errors_on_empty_runs_dir() {
    let dir = tempdir();
    std::fs::create_dir_all(dir.join(super::super::enumerate::RUNS_SUBDIR)).unwrap();
    let kind = merge_runs(&dir, 12, 9, 1, false).unwrap_err().kind();
    assert_eq!(kind, std::io::ErrorKind::NotFound);
}
/// Running the merge twice over the same runs yields byte-identical
/// `unique.bin` and `certificate.json` files and matching certificate fields.
#[test]
fn merge_idempotent_on_rerun() {
    let dir = tempdir();
    let runs_dir = dir.join(super::super::enumerate::RUNS_SUBDIR);
    std::fs::create_dir_all(&runs_dir).unwrap();

    // Each writer is dropped before the merge starts, as in a scoped block.
    let mut first_run = RunWriter::with_threshold(&runs_dir, 1, 1000);
    first_run.record(&[1, 2, 3]);
    first_run.record(&[2]);
    first_run.record(&[]);
    drop(first_run);

    let mut second_run = RunWriter::with_threshold(&runs_dir, 2, 1000);
    second_run.record(&[1, 2, 3]);
    second_run.record(&[-1, 0]);
    drop(second_run);

    let unique_path = dir.join(UNIQUE_FILENAME);
    let cert_path = dir.join(CERTIFICATE_FILENAME);
    // Merge, then capture the raw bytes of both output files.
    let merge_and_snapshot = || {
        let cert = merge_runs(&dir, 12, 9, 1, false).unwrap();
        let unique = std::fs::read(&unique_path).unwrap();
        let cert_json = std::fs::read(&cert_path).unwrap();
        (cert, unique, cert_json)
    };
    let (cert1, unique_1, cert_json_1) = merge_and_snapshot();
    let (cert2, unique_2, cert_json_2) = merge_and_snapshot();

    assert_eq!(cert1.unique_blake3, cert2.unique_blake3);
    assert_eq!(cert1.unique_records, cert2.unique_records);
    assert_eq!(cert1.total_input_records, cert2.total_input_records);
    assert_eq!(
        unique_1, unique_2,
        "unique.bin not byte-identical across merges"
    );
    assert_eq!(
        cert_json_1, cert_json_2,
        "certificate.json not byte-identical across merges"
    );
}
}