use std::io::{self, Read, Write, Seek, SeekFrom};
use std::path::{Path, PathBuf};
/// Magic bytes that open every SSHash index file.
const MAGIC: &[u8; 8] = b"SSHIDX01";
/// Magic bytes that open every SSHash MPHF container file.
const MPHF_MAGIC: &[u8; 8] = b"SSHIMH02";
/// `(major, minor)` version stamped into index file headers.
const FORMAT_VERSION: (u32, u32) = (3, 0);
/// `(major, minor)` version stamped into MPHF container headers.
const MPHF_FORMAT_VERSION: (u32, u32) = (2, 0);

/// Fixed-size header at the start of a serialized dictionary (index file).
///
/// On-disk layout, all integers little-endian, 37 bytes in total:
/// magic (8) · major (u32) · minor (u32) · `k` (u64) · `m` (u64) ·
/// canonical flag (1 byte) · number of MPHF partitions (u32).
#[derive(Clone, Debug)]
pub struct DictionarySerializationHeader {
    /// File signature; equals [`MAGIC`] for headers built by [`Self::new`]
    /// or accepted by [`Self::read`].
    pub magic: [u8; 8],
    /// Major format version; `read` rejects any value other than the current one.
    pub version_major: u32,
    /// Minor format version; carried through but not validated by `read`.
    pub version_minor: u32,
    /// The dictionary's `k` parameter (presumably the k-mer length — confirm with the builder).
    pub k: usize,
    /// The dictionary's `m` parameter (presumably the minimizer length — confirm with the builder).
    pub m: usize,
    /// Canonical flag, stored on disk as a single byte (non-zero means `true`).
    pub canonical: bool,
    /// Number of MPHF partitions recorded for this dictionary.
    pub num_mphf_partitions: u32,
}

impl DictionarySerializationHeader {
    /// Creates a header stamped with the current magic and [`FORMAT_VERSION`].
    pub fn new(k: usize, m: usize, canonical: bool, num_mphf_partitions: u32) -> Self {
        let (version_major, version_minor) = FORMAT_VERSION;
        Self {
            magic: *MAGIC,
            version_major,
            version_minor,
            k,
            m,
            canonical,
            num_mphf_partitions,
        }
    }

    /// Serializes the header to `writer` in the layout described on the type.
    ///
    /// # Errors
    /// Propagates any I/O error reported by `writer`.
    pub fn write(&self, writer: &mut dyn Write) -> io::Result<()> {
        writer.write_all(&self.magic)?;
        writer.write_all(&self.version_major.to_le_bytes())?;
        writer.write_all(&self.version_minor.to_le_bytes())?;
        // `k` and `m` are always stored as 64-bit values so the file layout
        // does not depend on the pointer width of the writing machine.
        writer.write_all(&(self.k as u64).to_le_bytes())?;
        writer.write_all(&(self.m as u64).to_le_bytes())?;
        writer.write_all(&[self.canonical as u8])?;
        writer.write_all(&self.num_mphf_partitions.to_le_bytes())?;
        Ok(())
    }

    /// Parses a header from `reader`.
    ///
    /// # Errors
    /// * `UnexpectedEof` (from `read_exact`) if the stream ends early.
    /// * `InvalidData` if the magic does not match, if the major version
    ///   differs from the current one, or if `k`/`m` do not fit in `usize`
    ///   on this platform.
    pub fn read(reader: &mut dyn Read) -> io::Result<Self> {
        let mut magic = [0u8; 8];
        reader.read_exact(&mut magic)?;
        if &magic != MAGIC {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "Invalid magic number for SSHash index file",
            ));
        }

        // Scratch buffers reused for every field of the same width.
        let mut word = [0u8; 4];
        let mut quad = [0u8; 8];
        let mut flag = [0u8; 1];

        reader.read_exact(&mut word)?;
        let version_major = u32::from_le_bytes(word);
        reader.read_exact(&mut word)?;
        let version_minor = u32::from_le_bytes(word);
        reader.read_exact(&mut quad)?;
        let k_raw = u64::from_le_bytes(quad);
        reader.read_exact(&mut quad)?;
        let m_raw = u64::from_le_bytes(quad);
        reader.read_exact(&mut flag)?;
        reader.read_exact(&mut word)?;
        let num_mphf_partitions = u32::from_le_bytes(word);

        // The whole fixed-size header is consumed before the version is
        // judged, so a truncated file reports EOF rather than a version error.
        if version_major != FORMAT_VERSION.0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "Incompatible format version: {}.{}, expected {}.{}",
                    version_major, version_minor, FORMAT_VERSION.0, FORMAT_VERSION.1
                ),
            ));
        }

        Ok(Self {
            magic,
            version_major,
            version_minor,
            k: Self::checked_usize(k_raw, "k")?,
            m: Self::checked_usize(m_raw, "m")?,
            canonical: flag[0] != 0,
            num_mphf_partitions,
        })
    }

    /// Converts an on-disk 64-bit value to `usize`, failing with `InvalidData`
    /// instead of silently truncating on targets where `usize` is narrower.
    fn checked_usize(raw: u64, field: &str) -> io::Result<usize> {
        <usize as std::convert::TryFrom<u64>>::try_from(raw).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "Header field `{}` value {} does not fit in usize on this platform",
                    field, raw
                ),
            )
        })
    }
}
/// One row of the MPHF container's offset table.
///
/// Serialized as 20 bytes, little-endian: partition id (u32), byte offset
/// (u64), byte size (u64).
#[derive(Clone, Copy, Debug)]
pub struct MphfPartitionEntry {
    /// Index of the partition this row describes.
    pub partition_id: u32,
    /// Stream position at which the partition's payload starts.
    pub byte_offset: u64,
    /// Length of the payload in bytes; `0` marks a partition without data.
    pub byte_size: u64,
}

impl MphfPartitionEntry {
    /// Emits the three fields to `writer`, in declaration order.
    fn write(&self, writer: &mut dyn Write) -> io::Result<()> {
        let id_raw = self.partition_id.to_le_bytes();
        let offset_raw = self.byte_offset.to_le_bytes();
        let size_raw = self.byte_size.to_le_bytes();

        writer.write_all(&id_raw)?;
        writer.write_all(&offset_raw)?;
        writer.write_all(&size_raw)?;
        Ok(())
    }

    /// Reads one 20-byte row from `reader`.
    ///
    /// Fails with whatever `read_exact` reports (e.g. `UnexpectedEof`) when
    /// the stream ends before the row is complete.
    fn read(reader: &mut dyn Read) -> io::Result<Self> {
        let mut id_raw = [0u8; 4];
        reader.read_exact(&mut id_raw)?;

        // Offset and size are both 8 bytes wide, so one scratch buffer serves both.
        let mut wide = [0u8; 8];
        reader.read_exact(&mut wide)?;
        let byte_offset = u64::from_le_bytes(wide);
        reader.read_exact(&mut wide)?;
        let byte_size = u64::from_le_bytes(wide);

        Ok(Self {
            partition_id: u32::from_le_bytes(id_raw),
            byte_offset,
            byte_size,
        })
    }
}
/// Fixed-size preamble of an MPHF container file.
///
/// Serialized as 20 bytes, little-endian: magic (8), major version (u32),
/// minor version (u32), number of partitions (u32).
#[derive(Clone, Debug)]
pub struct MphfContainerHeader {
    /// File signature; equals `MPHF_MAGIC` for headers produced by `new` or accepted by `read`.
    pub magic: [u8; 8],
    /// Major format version; `read` only accepts the current one.
    pub version_major: u32,
    /// Minor format version; carried through without validation.
    pub version_minor: u32,
    /// Number of partitions recorded in the container.
    pub num_partitions: u32,
}

impl MphfContainerHeader {
    /// Creates a header for `num_partitions` partitions using the current
    /// magic and format version.
    pub fn new(num_partitions: u32) -> Self {
        let (version_major, version_minor) = MPHF_FORMAT_VERSION;
        Self {
            magic: *MPHF_MAGIC,
            version_major,
            version_minor,
            num_partitions,
        }
    }

    /// Writes the magic followed by the three `u32` fields, little-endian.
    pub fn write(&self, writer: &mut dyn Write) -> io::Result<()> {
        writer.write_all(&self.magic)?;
        let words = [self.version_major, self.version_minor, self.num_partitions];
        for word in words.iter() {
            writer.write_all(&word.to_le_bytes())?;
        }
        Ok(())
    }

    /// Parses a container header from `reader`.
    ///
    /// Returns `InvalidData` when the magic is unknown (with a dedicated
    /// message for the legacy `SSHIMH01` signature) or when the major version
    /// differs from the current one; `read_exact` errors are propagated.
    pub fn read(reader: &mut dyn Read) -> io::Result<Self> {
        let mut magic = [0u8; 8];
        reader.read_exact(&mut magic)?;

        if &magic != MPHF_MAGIC {
            // Legacy containers get an actionable message; anything else is
            // simply not an MPHF container.
            let reason = if &magic == b"SSHIMH01" {
                "MPHF container is v1 format (SSHIMH01). Please rebuild the index — v2 (PartitionedMphf) is required."
            } else {
                "Invalid magic number for SSHash MPHF container file"
            };
            return Err(io::Error::new(io::ErrorKind::InvalidData, reason));
        }

        // All three words are consumed before the version is checked.
        let mut words = [0u32; 3];
        for slot in words.iter_mut() {
            let mut raw = [0u8; 4];
            reader.read_exact(&mut raw)?;
            *slot = u32::from_le_bytes(raw);
        }
        let [version_major, version_minor, num_partitions] = words;

        if version_major != MPHF_FORMAT_VERSION.0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "Incompatible MPHF format version: {}.{}, expected {}.{}",
                    version_major, version_minor, MPHF_FORMAT_VERSION.0, MPHF_FORMAT_VERSION.1
                ),
            ));
        }

        Ok(Self {
            magic,
            version_major,
            version_minor,
            num_partitions,
        })
    }
}
/// Returns the path of the index file for `base`.
///
/// * A path whose extension is already `ssi` is returned unchanged.
/// * A path without an extension gets `.ssi` appended.
/// * Any other extension is kept and `.ssi` is stacked on top
///   (`genome.fa` becomes `genome.fa.ssi`).
///
/// The extension is handled as an `OsString`, so non-UTF-8 extensions are
/// preserved byte-for-byte instead of being replaced with U+FFFD.
/// A path with no final component (for example `""` or `"/"`) is returned
/// unchanged, because `set_extension` does nothing in that case.
pub fn index_file_path<P: AsRef<Path>>(base: P) -> PathBuf {
    let mut path = base.as_ref().to_path_buf();
    // Take an owned copy so `path` can be mutated below.
    let ext = path
        .extension()
        .map(|e| e.to_os_string())
        .unwrap_or_default();

    if ext == "ssi" {
        return path;
    }
    if ext.is_empty() {
        path.set_extension("ssi");
    } else {
        let mut stacked = ext;
        stacked.push(".ssi");
        path.set_extension(stacked);
    }
    path
}

/// Returns the path of the MPHF container that accompanies the index file
/// for `base`: the index file name with `.mphf` appended, in the same
/// directory (`/tmp/idx` becomes `/tmp/idx.ssi.mphf`).
///
/// The file name is handled as an `OsString`, so non-UTF-8 names survive
/// intact. If the index path has no final component (for example `""`, `"/"`
/// or `".."`), the function no longer panics; it returns `<path>/.mphf`.
pub fn mphf_container_path<P: AsRef<Path>>(base: P) -> PathBuf {
    let index_path = index_file_path(base);
    match index_path.file_name() {
        Some(name) => {
            let mut container_name = name.to_os_string();
            container_name.push(".mphf");
            index_path.with_file_name(container_name)
        }
        // No file name to extend: fall back to a bare `.mphf` entry under the path.
        None => index_path.join(".mphf"),
    }
}
/// Error type for serialization and deserialization routines.
#[derive(Debug)]
pub enum SerializationError {
    /// An underlying I/O operation failed.
    Io(io::Error),
    /// Any other failure, described by a free-form message.
    Other(String),
}

impl From<io::Error> for SerializationError {
    /// Wraps an `io::Error` so `?` can be used in functions returning
    /// [`SerializationResult`].
    fn from(err: io::Error) -> Self {
        SerializationError::Io(err)
    }
}

impl std::fmt::Display for SerializationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SerializationError::Io(e) => write!(f, "IO error: {}", e),
            SerializationError::Other(s) => write!(f, "{}", s),
        }
    }
}

impl std::error::Error for SerializationError {
    /// Exposes the wrapped `io::Error` so error-chain reporters can reach the
    /// root cause; `Other` carries only a message and has no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            SerializationError::Io(err) => Some(err),
            SerializationError::Other(_) => None,
        }
    }
}

/// Convenience alias for results whose error is [`SerializationError`].
pub type SerializationResult<T> = Result<T, SerializationError>;
/// Writes an MPHF container to `writer` and returns its offset table.
///
/// Layout produced, starting at the writer's current position:
/// 1. an [`MphfContainerHeader`] recording `mphfs.len()` partitions;
/// 2. one [`MphfPartitionEntry`] per partition (written as zeroed
///    placeholders first, then overwritten once the real offsets are known);
/// 3. the serialized payload of every `Some` partition, in order.
///
/// `byte_offset` values are the stream positions reported by
/// `stream_position()`. A `None` partition gets an entry with `byte_size == 0`.
/// On success the writer is left at the end of the stream.
///
/// # Errors
/// * `InvalidInput` if `mphfs.len()` does not fit in the `u32` partition
///   count of the container format (previously the count was silently
///   truncated, which produced a corrupt file).
/// * Any error raised by the writer or by serializing a partition.
pub fn write_mphf_container<W: Write + Seek>(
    writer: &mut W,
    mphfs: &[Option<&crate::partitioned_mphf::PartitionedMphf>],
) -> io::Result<Vec<MphfPartitionEntry>> {
    let num_partitions =
        <u32 as std::convert::TryFrom<usize>>::try_from(mphfs.len()).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!(
                    "Too many MPHF partitions for the container format: {}",
                    mphfs.len()
                ),
            )
        })?;

    MphfContainerHeader::new(num_partitions).write(writer)?;

    // Reserve space for the offset table; the real values are patched in below.
    let offset_table_start = writer.stream_position()?;
    for partition_id in 0..num_partitions {
        let placeholder = MphfPartitionEntry {
            partition_id,
            byte_offset: 0,
            byte_size: 0,
        };
        placeholder.write(writer)?;
    }

    let mut offset_table = Vec::with_capacity(mphfs.len());
    for (partition_id, mphf_opt) in (0..num_partitions).zip(mphfs) {
        let byte_offset = writer.stream_position()?;
        let byte_size = match mphf_opt {
            Some(pmphf) => {
                // Serialize into memory first so the payload size is known exactly.
                let mut mphf_buffer: Vec<u8> = Vec::new();
                pmphf.write_to(&mut mphf_buffer)?;
                writer.write_all(&mphf_buffer)?;
                mphf_buffer.len() as u64
            }
            None => 0,
        };
        offset_table.push(MphfPartitionEntry {
            partition_id,
            byte_offset,
            byte_size,
        });
    }

    // Go back and overwrite the placeholders with the real table.
    writer.seek(SeekFrom::Start(offset_table_start))?;
    for entry in &offset_table {
        entry.write(writer)?;
    }
    writer.seek(SeekFrom::End(0))?;
    Ok(offset_table)
}
/// Reads an MPHF container from `reader`.
///
/// Parses the [`MphfContainerHeader`], then one [`MphfPartitionEntry`] per
/// partition, then seeks to and deserializes the payload of every entry whose
/// `byte_size` is non-zero. The returned vector has one slot per partition,
/// indexed by `partition_id`; partitions without a payload stay `None`.
///
/// # Errors
/// * `InvalidData` if the header is rejected, or if an entry with a payload
///   names a `partition_id` outside `0..num_partitions` (previously this
///   panicked on an out-of-bounds index).
/// * Any I/O or deserialization error raised while reading.
pub fn read_mphf_container<R: Read + Seek>(
    reader: &mut R,
) -> io::Result<Vec<Option<crate::partitioned_mphf::PartitionedMphf>>> {
    // Upper bound on the up-front table allocation, so a corrupt partition
    // count cannot request an enormous buffer before any entry has been read.
    const MAX_PREALLOCATED_ENTRIES: usize = 1 << 16;

    let header = MphfContainerHeader::read(reader)?;
    let num_partitions = header.num_partitions as usize;

    let mut offset_table = Vec::with_capacity(num_partitions.min(MAX_PREALLOCATED_ENTRIES));
    for _ in 0..num_partitions {
        offset_table.push(MphfPartitionEntry::read(reader)?);
    }

    let mut mphfs: Vec<Option<crate::partitioned_mphf::PartitionedMphf>> = Vec::new();
    mphfs.resize_with(num_partitions, || None);

    for entry in offset_table {
        if entry.byte_size == 0 {
            // Partition without a payload: leave its slot as `None`.
            continue;
        }
        // Validate the id before touching the stream; the table comes from the file.
        let slot = mphfs.get_mut(entry.partition_id as usize).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "MPHF partition id {} is out of range (container declares {} partitions)",
                    entry.partition_id, num_partitions
                ),
            )
        })?;
        reader.seek(SeekFrom::Start(entry.byte_offset))?;
        *slot = Some(crate::partitioned_mphf::PartitionedMphf::read_from(reader)?);
    }
    Ok(mphfs)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A dictionary header survives a write/read cycle with every field intact.
    #[test]
    fn test_header_roundtrip() {
        let header = DictionarySerializationHeader::new(31, 13, true, 2);
        let mut buffer = Vec::new();
        header.write(&mut buffer).unwrap();
        let header2 = DictionarySerializationHeader::read(&mut buffer.as_slice()).unwrap();
        assert_eq!(header2.magic, *MAGIC);
        assert_eq!(header2.version_major, FORMAT_VERSION.0);
        assert_eq!(header2.version_minor, FORMAT_VERSION.1);
        assert_eq!(header.k, header2.k);
        assert_eq!(header.m, header2.m);
        assert_eq!(header.canonical, header2.canonical);
        assert_eq!(header.num_mphf_partitions, header2.num_mphf_partitions);
    }

    /// A dictionary header with a corrupted signature is rejected as invalid data.
    #[test]
    fn test_header_rejects_bad_magic() {
        let mut buffer = Vec::new();
        DictionarySerializationHeader::new(31, 13, true, 2)
            .write(&mut buffer)
            .unwrap();
        buffer[0] ^= 0xFF;
        let err = DictionarySerializationHeader::read(&mut buffer.as_slice()).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
    }

    /// An MPHF container header survives a write/read cycle.
    #[test]
    fn test_mphf_container_header_roundtrip() {
        let header = MphfContainerHeader::new(5);
        let mut buffer = Vec::new();
        header.write(&mut buffer).unwrap();
        let header2 = MphfContainerHeader::read(&mut buffer.as_slice()).unwrap();
        assert_eq!(header2.magic, *MPHF_MAGIC);
        assert_eq!(header2.version_major, MPHF_FORMAT_VERSION.0);
        assert_eq!(header2.version_minor, MPHF_FORMAT_VERSION.1);
        assert_eq!(header.num_partitions, header2.num_partitions);
    }

    /// The legacy `SSHIMH01` signature is rejected with the dedicated message.
    #[test]
    fn test_mphf_container_header_rejects_v1_magic() {
        let mut buffer = b"SSHIMH01".to_vec();
        buffer.extend_from_slice(&[0u8; 12]);
        let err = MphfContainerHeader::read(&mut buffer.as_slice()).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("v1 format"));
    }

    /// An offset-table entry survives a write/read cycle.
    #[test]
    fn test_mphf_partition_entry_roundtrip() {
        let entry = MphfPartitionEntry {
            partition_id: 3,
            byte_offset: 1024,
            byte_size: 512,
        };
        let mut buffer = Vec::new();
        entry.write(&mut buffer).unwrap();
        assert_eq!(buffer.len(), 20);
        let entry2 = MphfPartitionEntry::read(&mut buffer.as_slice()).unwrap();
        assert_eq!(entry.partition_id, entry2.partition_id);
        assert_eq!(entry.byte_offset, entry2.byte_offset);
        assert_eq!(entry.byte_size, entry2.byte_size);
    }

    /// Path helpers produce exactly the expected file names. Whole paths are
    /// compared, because substring checks would also accept wrong paths.
    #[test]
    fn test_file_path_construction() {
        let base = Path::new("/tmp/my_index");
        assert_eq!(index_file_path(base), PathBuf::from("/tmp/my_index.ssi"));
        assert_eq!(
            mphf_container_path(base),
            PathBuf::from("/tmp/my_index.ssi.mphf")
        );

        // An existing `.ssi` extension is kept; any other extension is stacked.
        assert_eq!(
            index_file_path("/tmp/my_index.ssi"),
            PathBuf::from("/tmp/my_index.ssi")
        );
        assert_eq!(
            index_file_path("/tmp/genome.fa"),
            PathBuf::from("/tmp/genome.fa.ssi")
        );
    }
}