use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
use crate::error::KmeRustError;
use crate::kmer::{unpack_to_string, KmerLength};
/// Four-byte signature that opens every index file (ASCII `KMIX`).
const MAGIC: &[u8; 4] = &[b'K', b'M', b'I', b'X'];
/// On-disk format version written by `write_index`; `read_index` rejects any other value.
const VERSION: u8 = 0x01;
/// In-memory table of counts keyed by packed k-mer, persisted with
/// `save_index` and restored with `load_index`.
#[derive(Debug, Clone)]
pub struct KmerIndex {
    /// Length of every key in `counts`; required to unpack keys back to strings.
    k: KmerLength,
    /// Packed k-mer bits → count.
    counts: HashMap<u64, u64>,
}
impl KmerIndex {
    /// Builds an index from a k-mer length and a packed-k-mer → count map.
    ///
    /// The map is stored as given; no check is made that its keys fit `k`.
    #[must_use]
    pub const fn new(k: KmerLength, counts: HashMap<u64, u64>) -> Self {
        Self { k, counts }
    }

    /// Returns the k-mer length the keys were packed with.
    #[must_use]
    pub const fn k(&self) -> KmerLength {
        self.k
    }

    /// Returns the number of distinct packed k-mers stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.counts().len()
    }

    /// Returns `true` when the index holds no k-mers.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.counts().is_empty()
    }

    /// Borrows the underlying packed-k-mer → count map.
    #[must_use]
    pub const fn counts(&self) -> &HashMap<u64, u64> {
        &self.counts
    }

    /// Consumes the index and returns its packed-k-mer → count map.
    #[must_use]
    pub fn into_counts(self) -> HashMap<u64, u64> {
        let Self { counts, .. } = self;
        counts
    }

    /// Looks up the count stored for `packed_bits`, or `None` if it is absent.
    #[must_use]
    pub fn get(&self, packed_bits: u64) -> Option<u64> {
        let count = self.counts.get(&packed_bits)?;
        Some(*count)
    }

    /// Returns a copy of the counts keyed by the decoded k-mer string.
    ///
    /// Every key is unpacked with `unpack_to_string` using this index's `k`.
    #[must_use]
    pub fn to_string_counts(&self) -> HashMap<String, u64> {
        let mut decoded = HashMap::with_capacity(self.counts.len());
        for (&packed, &count) in &self.counts {
            decoded.insert(unpack_to_string(packed, self.k), count);
        }
        decoded
    }
}
pub fn save_index<P: AsRef<Path>>(index: &KmerIndex, path: P) -> Result<(), KmeRustError> {
let path = path.as_ref();
#[cfg(feature = "gzip")]
if is_gzip_path(path) {
let file = File::create(path).map_err(|e| KmeRustError::IndexWrite {
source: e,
path: path.to_path_buf(),
})?;
let encoder = flate2::write::GzEncoder::new(file, flate2::Compression::default());
let writer = BufWriter::new(encoder);
return write_index(index, writer, path);
}
let file = File::create(path).map_err(|e| KmeRustError::IndexWrite {
source: e,
path: path.to_path_buf(),
})?;
let writer = BufWriter::new(file);
write_index(index, writer, path)
}
/// Loads an index previously written by `save_index`.
///
/// When the crate is built with the `gzip` feature, paths with a `.gz`
/// extension (case-insensitive) are decompressed transparently.
///
/// # Errors
///
/// Returns [`KmeRustError::IndexRead`] if the file cannot be opened or read
/// (including decompression failures), and [`KmeRustError::InvalidIndex`] if
/// its contents are not a well-formed index.
pub fn load_index<P: AsRef<Path>>(path: P) -> Result<KmerIndex, KmeRustError> {
    let path = path.as_ref();
    let file = File::open(path).map_err(|source| KmeRustError::IndexRead {
        source,
        path: path.to_path_buf(),
    })?;

    #[cfg(feature = "gzip")]
    if is_gzip_path(path) {
        // Decode every gzip member, not just the first. Files recompressed by
        // tools that emit multi-member streams (e.g. `bgzip`, or `cat a.gz b.gz`)
        // would otherwise be cut off after the first member and rejected.
        // Single-member files decode identically.
        let decoder = flate2::read::MultiGzDecoder::new(file);
        return read_index(BufReader::new(decoder), path);
    }

    read_index(BufReader::new(file), path)
}
/// Serializes `index` into `writer` using the versioned binary layout and
/// flushes it.
///
/// Layout (all integers little-endian):
///
/// | bytes  | field                                              |
/// |--------|----------------------------------------------------|
/// | 4      | `MAGIC`                                            |
/// | 1      | `VERSION`                                          |
/// | 1      | k-mer length                                       |
/// | 8      | number of entries `n` (`u64`)                      |
/// | 16 × n | entries: packed k-mer (`u64`) then count (`u64`)   |
/// | 4      | CRC-32 of every preceding byte                     |
///
/// Entries follow `HashMap` iteration order, which is unspecified.
/// `path` is used only to label errors.
///
/// # Errors
///
/// Any I/O failure is reported as [`KmeRustError::IndexWrite`].
fn write_index<W: Write, P: AsRef<Path>>(
    index: &KmerIndex,
    mut writer: W,
    path: P,
) -> Result<(), KmeRustError> {
    write_index_body(index, &mut writer).map_err(|source| KmeRustError::IndexWrite {
        source,
        path: path.as_ref().to_path_buf(),
    })
}

/// Writes the checksummed index payload plus trailing CRC to `writer`, then flushes.
fn write_index_body<W: Write>(index: &KmerIndex, writer: &mut W) -> std::io::Result<()> {
    let mut hashed = Crc32Writer::new(&mut *writer);
    hashed.write_all(MAGIC)?;
    hashed.write_all(&[VERSION])?;
    hashed.write_all(&[index.k.as_u8()])?;
    hashed.write_all(&(index.counts.len() as u64).to_le_bytes())?;
    for (packed, count) in &index.counts {
        hashed.write_all(&packed.to_le_bytes())?;
        hashed.write_all(&count.to_le_bytes())?;
    }
    // The checksum goes straight to `writer` so that it is not itself hashed.
    let checksum = hashed.finalize();
    writer.write_all(&checksum.to_le_bytes())?;
    writer.flush()
}
fn read_index<R: Read, P: AsRef<Path>>(reader: R, path: P) -> Result<KmerIndex, KmeRustError> {
let path = path.as_ref();
let mut data = Vec::new();
let mut reader = BufReader::new(reader);
reader
.read_to_end(&mut data)
.map_err(|e| KmeRustError::IndexRead {
source: e,
path: path.to_path_buf(),
})?;
if data.len() < 18 {
return Err(KmeRustError::InvalidIndex {
details: "file too small".into(),
path: path.to_path_buf(),
});
}
if &data[..4] != MAGIC {
return Err(KmeRustError::InvalidIndex {
details: "invalid magic bytes (not a kmerust index file)".into(),
path: path.to_path_buf(),
});
}
let (content, checksum_bytes) = data.split_at(data.len() - 4);
let Ok(checksum_array) = checksum_bytes.try_into() else {
unreachable!("split_at guarantees exactly 4 bytes");
};
let stored_checksum = u32::from_le_bytes(checksum_array);
let computed_checksum = crc32(content);
if computed_checksum != stored_checksum {
return Err(KmeRustError::InvalidIndex {
details: format!(
"checksum mismatch (expected {stored_checksum:#x}, got {computed_checksum:#x})"
),
path: path.to_path_buf(),
});
}
let mut cursor = &content[4..];
if cursor.is_empty() || cursor[0] != VERSION {
return Err(KmeRustError::InvalidIndex {
details: format!("unsupported version {}", cursor.first().unwrap_or(&0)),
path: path.to_path_buf(),
});
}
cursor = &cursor[1..];
if cursor.is_empty() {
return Err(KmeRustError::InvalidIndex {
details: "missing k-mer length".into(),
path: path.to_path_buf(),
});
}
let k_val = cursor[0];
let k = KmerLength::new(k_val as usize).map_err(|e| KmeRustError::InvalidIndex {
details: format!("invalid k-mer length: {e}"),
path: path.to_path_buf(),
})?;
cursor = &cursor[1..];
if cursor.len() < 8 {
return Err(KmeRustError::InvalidIndex {
details: "missing k-mer count".into(),
path: path.to_path_buf(),
});
}
let Ok(count_array) = cursor[..8].try_into() else {
unreachable!("length check guarantees 8 bytes");
};
let count = u64::from_le_bytes(count_array);
cursor = &cursor[8..];
#[allow(clippy::cast_possible_truncation)]
let expected_data_size = count as usize * 16; if cursor.len() != expected_data_size {
return Err(KmeRustError::InvalidIndex {
details: format!(
"data size mismatch (expected {expected_data_size} bytes, got {} bytes)",
cursor.len()
),
path: path.to_path_buf(),
});
}
#[allow(clippy::cast_possible_truncation)]
let mut counts = HashMap::with_capacity(count as usize);
for _ in 0..count {
let Ok(packed_array) = cursor[..8].try_into() else {
unreachable!("size validation guarantees 16 bytes per entry");
};
let Ok(count_array) = cursor[8..16].try_into() else {
unreachable!("size validation guarantees 16 bytes per entry");
};
let packed = u64::from_le_bytes(packed_array);
let kmer_count = u64::from_le_bytes(count_array);
counts.insert(packed, kmer_count);
cursor = &cursor[16..];
}
Ok(KmerIndex { k, counts })
}
/// Reflected CRC-32 polynomial (the standard one used by zlib/gzip).
const CRC32_POLYNOMIAL: u32 = 0xEDB8_8320;

/// Byte-indexed lookup table for [`crc32_update`]. It is built once at compile
/// time rather than on every checksum call.
const CRC32_TABLE: [u32; 256] = {
    let mut table = [0u32; 256];
    let mut index: u32 = 0;
    while index < 256 {
        let mut crc = index;
        let mut bit = 0;
        while bit < 8 {
            crc = if crc & 1 == 0 {
                crc >> 1
            } else {
                (crc >> 1) ^ CRC32_POLYNOMIAL
            };
            bit += 1;
        }
        table[index as usize] = crc;
        index += 1;
    }
    table
};

/// Feeds `data` into a running CRC-32 register and returns the updated register.
///
/// The register is the pre-inversion state: start from `!0` and invert the
/// final value to obtain the checksum (as [`crc32`] does).
fn crc32_update(mut crc: u32, data: &[u8]) -> u32 {
    for &byte in data {
        crc = CRC32_TABLE[((crc ^ u32::from(byte)) & 0xFF) as usize] ^ (crc >> 8);
    }
    crc
}

/// Computes the standard CRC-32 of `data` (e.g. `0xCBF4_3926` for `b"123456789"`).
fn crc32(data: &[u8]) -> u32 {
    !crc32_update(!0, data)
}

/// A [`Write`] adapter that forwards bytes to `inner` while maintaining a
/// running CRC-32 of every byte `inner` accepted.
///
/// The checksum is updated incrementally, so memory use is constant instead of
/// growing with the amount of data written.
struct Crc32Writer<W> {
    inner: W,
    /// Running (pre-inversion) CRC register.
    crc: u32,
}

impl<W: Write> Crc32Writer<W> {
    const fn new(inner: W) -> Self {
        Self { inner, crc: !0 }
    }

    /// Consumes the writer and returns the CRC-32 of everything written through it.
    fn finalize(self) -> u32 {
        !self.crc
    }
}

impl<W: Write> Write for Crc32Writer<W> {
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = self.inner.write(buf)?;
        // Only the bytes `inner` actually accepted are part of the output stream.
        self.crc = crc32_update(self.crc, &buf[..written]);
        Ok(written)
    }

    fn flush(&mut self) -> std::io::Result<()> {
        self.inner.flush()
    }
}
/// Returns `true` when the final extension of `path` is `gz`, ignoring ASCII
/// case (so both `index.kmix.gz` and `INDEX.GZ` match).
#[cfg(feature = "gzip")]
fn is_gzip_path(path: &Path) -> bool {
    matches!(path.extension(), Some(ext) if ext.eq_ignore_ascii_case("gz"))
}
// Unit tests for the index file format: round-trips through `save_index` /
// `load_index`, rejection of malformed files, and the CRC-32 helper.
#[cfg(test)]
// `unwrap` is acceptable here: any unexpected failure should fail the test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    // An index with no entries must still round-trip its k-mer length.
    #[test]
    fn roundtrip_empty_index() {
        let index = KmerIndex::new(KmerLength::new(21).unwrap(), HashMap::new());
        let tmp = NamedTempFile::with_suffix(".kmix").unwrap();
        save_index(&index, tmp.path()).unwrap();
        let loaded = load_index(tmp.path()).unwrap();
        assert_eq!(loaded.k(), index.k());
        assert!(loaded.is_empty());
    }

    // Every stored (packed k-mer, count) pair survives a save/load cycle,
    // including the all-zero key.
    #[test]
    fn roundtrip_with_data() {
        let mut counts = HashMap::new();
        counts.insert(0b00_01_10_11u64, 42u64); counts.insert(0b11_10_01_00u64, 17u64); counts.insert(0u64, 1u64);
        let index = KmerIndex::new(KmerLength::new(4).unwrap(), counts.clone());
        let tmp = NamedTempFile::with_suffix(".kmix").unwrap();
        save_index(&index, tmp.path()).unwrap();
        let loaded = load_index(tmp.path()).unwrap();
        assert_eq!(loaded.k(), index.k());
        assert_eq!(loaded.len(), 3);
        assert_eq!(loaded.get(0b00_01_10_11), Some(42));
        assert_eq!(loaded.get(0b11_10_01_00), Some(17));
        assert_eq!(loaded.get(0), Some(1));
    }

    // The single k byte in the header round-trips for a spread of k values
    // from 1 to 32. Only k is checked here, not the entry.
    #[test]
    fn roundtrip_various_k_lengths() {
        for k_val in [1, 5, 16, 21, 32] {
            let mut counts = HashMap::new();
            counts.insert(1u64, 100u64);
            let k = KmerLength::new(k_val).unwrap();
            let index = KmerIndex::new(k, counts);
            let tmp = NamedTempFile::with_suffix(".kmix").unwrap();
            save_index(&index, tmp.path()).unwrap();
            let loaded = load_index(tmp.path()).unwrap();
            assert_eq!(loaded.k().get(), k_val);
        }
    }

    // The garbage payload is exactly 18 bytes, the minimum accepted size, so
    // rejection must come from the magic-byte check rather than the size check.
    #[test]
    fn invalid_magic_rejected() {
        let tmp = NamedTempFile::with_suffix(".kmix").unwrap();
        std::fs::write(tmp.path(), b"GARBAGE_DATA_HERE_").unwrap();
        let result = load_index(tmp.path());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            err.to_string().contains("invalid magic"),
            "expected 'invalid magic' error, got: {err}"
        );
    }

    // Byte offset 10 lies inside the 8-byte entry-count field (offsets 6..14,
    // after magic, version and k). Flipping it must be caught by the checksum
    // before any field is interpreted.
    #[test]
    fn corrupted_checksum_rejected() {
        let mut counts = HashMap::new();
        counts.insert(1u64, 1u64);
        let index = KmerIndex::new(KmerLength::new(4).unwrap(), counts);
        let tmp = NamedTempFile::with_suffix(".kmix").unwrap();
        save_index(&index, tmp.path()).unwrap();
        let mut data = std::fs::read(tmp.path()).unwrap();
        if let Some(byte) = data.get_mut(10) {
            *byte ^= 0xFF;
        }
        std::fs::write(tmp.path(), data).unwrap();
        let result = load_index(tmp.path());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("checksum"));
    }

    // A file holding only the magic bytes is below the 18-byte minimum.
    #[test]
    fn file_too_small_rejected() {
        let tmp = NamedTempFile::with_suffix(".kmix").unwrap();
        std::fs::write(tmp.path(), b"KMIX").unwrap();
        let result = load_index(tmp.path());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.to_string().contains("too small"));
    }

    // With k = 4, packed value 0b00_01_10_11 is asserted to decode to "ACGT".
    // That implies A=00, C=01, G=10, T=11 with the first base in the highest
    // bit pair (per this assertion; the encoding itself lives in crate::kmer).
    #[test]
    fn to_string_counts() {
        let mut counts = HashMap::new();
        counts.insert(0b00_01_10_11u64, 42u64);
        let index = KmerIndex::new(KmerLength::new(4).unwrap(), counts);
        let string_counts = index.to_string_counts();
        assert_eq!(string_counts.len(), 1);
        assert_eq!(string_counts.get("ACGT"), Some(&42));
    }

    // 0xCBF4_3926 is the standard CRC-32 check value for ASCII "123456789".
    #[test]
    fn crc32_known_values() {
        assert_eq!(crc32(b""), 0x0000_0000);
        assert_eq!(crc32(b"123456789"), 0xCBF4_3926);
    }

    // A `.gz` suffix routes through the gzip encoder/decoder paths.
    #[cfg(feature = "gzip")]
    #[test]
    fn roundtrip_gzip() {
        let mut counts = HashMap::new();
        counts.insert(0b00_01_10_11u64, 42u64);
        let index = KmerIndex::new(KmerLength::new(4).unwrap(), counts);
        let tmp = NamedTempFile::with_suffix(".kmix.gz").unwrap();
        save_index(&index, tmp.path()).unwrap();
        let loaded = load_index(tmp.path()).unwrap();
        assert_eq!(loaded.k(), index.k());
        assert_eq!(loaded.get(0b00_01_10_11), Some(42));
    }
}