use std::io::{BufWriter, Write};
use std::path::Path;
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use super::inverted_index::{FrozenSegment, SparseInvertedIndex};
use super::types::{PostingEntry, SparseVector};
use crate::error::{Error, Result};
/// WAL op code for an upsert record (payload: `u64` point id, `u32` nnz, then `nnz`
/// pairs of `u32` term index and `f32` weight). This byte value is part of the on-disk format.
const WAL_OP_UPSERT: u8 = 0x01;
/// WAL op code for a delete record (payload: `u64` point id). On-disk format value.
const WAL_OP_DELETE: u8 = 0x02;
/// When a load replays at least this many WAL operations, the index is compacted
/// immediately so the next load has less to replay.
const COMPACTION_REPLAY_THRESHOLD: u64 = 10_000;
/// Contents of the `<prefix>.meta` file, serialized with `postcard`.
///
/// postcard encodes struct fields positionally, so field order is part of the on-disk
/// format: do not reorder or insert fields without bumping `version`.
#[derive(Debug, Serialize, Deserialize)]
pub(super) struct SparseMeta {
/// Format version; compaction writes `1` and loading rejects anything else.
pub(super) version: u32,
/// Document count of the compacted index, handed to `FrozenSegment::new` on load.
pub(super) doc_count: u64,
/// Number of terms written to the `.terms` file (currently not checked on load).
pub(super) term_count: u32,
}
/// One record of the `<prefix>.terms` file (a `postcard`-serialized `Vec<TermEntry>`),
/// locating a term's posting list inside `<prefix>.idx`.
///
/// Field order is part of the on-disk format; do not reorder.
#[derive(Debug, Serialize, Deserialize)]
struct TermEntry {
/// Term the posting list belongs to.
term_id: u32,
/// Byte offset of the term's first posting within the `.idx` file.
offset: u64,
/// Number of postings (not bytes); the list spans `len * POSTING_DISK_SIZE` bytes.
len: u32,
/// The `f32` stored with the term in the merged postings map (presumably the
/// maximum posting weight, judging by the name — confirm against the index code).
max_weight: f32,
}
/// Size in bytes of one posting in the `.idx` file: a little-endian `u64` doc id
/// immediately followed by a little-endian `f32` weight, with no padding.
const POSTING_DISK_SIZE: usize = 12;
// Compile-time guard that the constant agrees with the two field widths it encodes.
const _: () = assert!(
std::mem::size_of::<u64>() + std::mem::size_of::<f32>() == POSTING_DISK_SIZE,
"POSTING_DISK_SIZE must match u64 + f32 packed size"
);
#[inline]
fn read_le_u64(data: &[u8], pos: usize, context: &str) -> Result<u64> {
data[pos..pos + 8]
.try_into()
.map(u64::from_le_bytes)
.map_err(|_| Error::SparseIndexError(format!("{context} at offset {pos}")))
}
#[inline]
fn read_le_u32(data: &[u8], pos: usize, context: &str) -> Result<u32> {
data[pos..pos + 4]
.try_into()
.map(u32::from_le_bytes)
.map_err(|_| Error::SparseIndexError(format!("{context} at offset {pos}")))
}
#[inline]
fn read_le_f32(data: &[u8], pos: usize, context: &str) -> Result<f32> {
data[pos..pos + 4]
.try_into()
.map(f32::from_le_bytes)
.map_err(|_| Error::SparseIndexError(format!("{context} at offset {pos}")))
}
pub fn wal_append_upsert(wal_path: &Path, point_id: u64, vector: &SparseVector) -> Result<()> {
#[allow(clippy::cast_possible_truncation)] let nnz = vector.nnz() as u32;
let total_len = compute_upsert_entry_len(nnz)?;
let mut w = open_wal_writer(wal_path)?;
wal_write(&mut w, &total_len.to_le_bytes())?;
wal_write(&mut w, &[WAL_OP_UPSERT])?;
wal_write(&mut w, &point_id.to_le_bytes())?;
wal_write(&mut w, &nnz.to_le_bytes())?;
for (&idx, &val) in vector.indices.iter().zip(vector.values.iter()) {
wal_write(&mut w, &idx.to_le_bytes())?;
wal_write(&mut w, &val.to_le_bytes())?;
}
w.flush()
.map_err(|e| Error::SparseIndexError(format!("WAL flush failed: {e}")))?;
Ok(())
}
/// Computes the value stored in an upsert record's length prefix: the op byte, the
/// `u64` point id, the `u32` nnz field, and 8 bytes (`u32` + `f32`) per pair.
///
/// # Errors
///
/// Returns [`Error::SparseIndexError`] when that total does not fit in a `u32`.
fn compute_upsert_entry_len(nnz: u32) -> Result<u32> {
    // op (1) + point_id (8) + nnz (4)
    const HEADER_LEN: u32 = 1 + 8 + 4;
    let too_large = move || {
        Error::SparseIndexError(format!(
            "WAL entry too large: nnz={nnz} would overflow u32 length prefix"
        ))
    };
    let pairs_len = nnz.checked_mul(8).ok_or_else(too_large)?;
    pairs_len.checked_add(HEADER_LEN).ok_or_else(too_large)
}
fn open_wal_writer(wal_path: &Path) -> Result<BufWriter<std::fs::File>> {
let file = std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(wal_path)
.map_err(|e| Error::SparseIndexError(format!("WAL open failed: {e}")))?;
Ok(BufWriter::new(file))
}
fn wal_write(w: &mut BufWriter<std::fs::File>, bytes: &[u8]) -> Result<()> {
w.write_all(bytes)
.map_err(|e| Error::SparseIndexError(format!("WAL write failed: {e}")))
}
pub fn wal_append_delete(wal_path: &Path, point_id: u64) -> Result<()> {
let total_len: u32 = 1 + 8;
let file = std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(wal_path)
.map_err(|e| Error::SparseIndexError(format!("WAL open failed: {e}")))?;
let mut w = BufWriter::new(file);
w.write_all(&total_len.to_le_bytes())
.map_err(|e| Error::SparseIndexError(format!("WAL write failed: {e}")))?;
w.write_all(&[WAL_OP_DELETE])
.map_err(|e| Error::SparseIndexError(format!("WAL write failed: {e}")))?;
w.write_all(&point_id.to_le_bytes())
.map_err(|e| Error::SparseIndexError(format!("WAL write failed: {e}")))?;
w.flush()
.map_err(|e| Error::SparseIndexError(format!("WAL flush failed: {e}")))?;
Ok(())
}
/// Replays the sparse WAL at `wal_path` into `index`.
///
/// Each entry is a little-endian `u32` length prefix followed by that many bytes: a
/// one-byte op code and its payload (`u64` point id for deletes; `u64` point id, `u32`
/// nnz and `nnz` `(u32, f32)` pairs for upserts). Replay stops, logging a warning, at
/// the first entry that is truncated or too short for its op, keeping everything
/// applied before it. Entries with an unknown op code are skipped.
///
/// Returns the number of upsert/delete operations applied. A missing WAL file is not
/// an error and yields `Ok(0)`.
///
/// # Errors
///
/// Returns [`Error::SparseIndexError`] if the file cannot be read or an entry's fields
/// cannot be decoded.
pub fn wal_replay(wal_path: &Path, index: &SparseInvertedIndex) -> Result<u64> {
    let data = match std::fs::read(wal_path) {
        Ok(d) => d,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(0),
        Err(e) => return Err(Error::SparseIndexError(format!("WAL read failed: {e}"))),
    };
    let mut pos = 0usize;
    let mut count = 0u64;
    while pos < data.len() {
        let Some((entry_start, total_len)) = read_wal_entry_header(&data, pos) else {
            break;
        };
        pos = entry_start;
        // Compare against the remaining byte count instead of computing
        // `pos + total_len`, so a huge declared length cannot overflow.
        let remaining = data.len() - pos;
        if total_len > remaining {
            tracing::warn!(
                "Sparse WAL truncated at offset {entry_start}: declared {total_len} bytes but only {remaining} remain"
            );
            break;
        }
        // A zero-length entry has no op byte. Reading one would index past the end of
        // the buffer or into the next entry, so treat it as corruption.
        if total_len == 0 {
            tracing::warn!("Sparse WAL empty entry at offset {entry_start}");
            break;
        }
        let entry_end = entry_start + total_len;
        let op = data[pos];
        pos += 1;
        match op {
            WAL_OP_UPSERT => {
                if replay_upsert_entry(&data, pos, entry_start, total_len, index)?.is_none() {
                    break;
                }
                count += 1;
            }
            WAL_OP_DELETE => {
                // op byte + u64 point id; anything shorter would read past the entry.
                if total_len < 1 + 8 {
                    tracing::warn!("Sparse WAL delete entry too short at offset {entry_start}");
                    break;
                }
                let point_id = read_le_u64(&data, pos, "WAL entry corrupted: bad point_id bytes")?;
                index.delete(point_id);
                count += 1;
            }
            unknown => {
                tracing::warn!("Sparse WAL unknown op 0x{unknown:02x} at offset {entry_start}");
            }
        }
        // Every branch above reads only within the declared entry, so resume at its
        // end; this also skips any entry bytes the op did not consume.
        pos = entry_end;
    }
    Ok(count)
}
/// Reads the `u32` length prefix at `pos`.
///
/// Returns `(offset just past the prefix, declared entry length)`, or `None` (logging
/// a warning) when fewer than four bytes remain.
fn read_wal_entry_header(data: &[u8], pos: usize) -> Option<(usize, usize)> {
    let body_start = pos + 4;
    if body_start > data.len() {
        tracing::warn!("Sparse WAL truncated at offset {pos}: not enough bytes for length prefix");
        return None;
    }
    match read_le_u32(data, pos, "WAL entry corrupted: bad length-prefix bytes") {
        Ok(declared) => Some((body_start, declared as usize)),
        Err(_) => None,
    }
}
/// Decodes an upsert payload starting at `pos` (just past the op byte) and inserts it
/// into `index`.
///
/// `entry_start` is the offset of the op byte and `total_len` the declared entry
/// length, so the entry occupies `entry_start..entry_start + total_len`.
///
/// Returns `Ok(Some(end))` with the offset just past the last decoded pair (never
/// beyond the entry). Returns `Ok(None)`, after logging a warning, when the entry is
/// too short for its header or for the `nnz` pairs it declares.
///
/// # Errors
///
/// Propagates [`Error::SparseIndexError`] if a field cannot be decoded.
fn replay_upsert_entry(
    data: &[u8],
    mut pos: usize,
    entry_start: usize,
    total_len: usize,
    index: &SparseInvertedIndex,
) -> Result<Option<usize>> {
    // op byte + point_id (u64) + nnz (u32)
    const HEADER_LEN: usize = 1 + 8 + 4;
    // Each pair is a u32 term index followed by an f32 weight.
    const PAIR_LEN: usize = 8;
    if total_len < HEADER_LEN {
        tracing::warn!("Sparse WAL upsert entry too short at offset {entry_start}");
        return Ok(None);
    }
    let point_id = read_le_u64(data, pos, "WAL entry corrupted: bad point_id bytes")?;
    pos += 8;
    let nnz = read_le_u32(data, pos, "WAL entry corrupted: bad nnz bytes")? as usize;
    pos += 4;
    // `nnz` comes from disk: use checked arithmetic so an overflowing size (possible
    // where usize is 32 bits) is rejected rather than wrapping into a passing bound.
    let entry_end = entry_start.saturating_add(total_len);
    let pairs_end = nnz.checked_mul(PAIR_LEN).and_then(|n| pos.checked_add(n));
    if !matches!(pairs_end, Some(end) if end <= entry_end) {
        tracing::warn!("Sparse WAL upsert entry truncated at offset {entry_start}");
        return Ok(None);
    }
    let pairs = read_term_weight_pairs(data, &mut pos, nnz)?;
    let vector = SparseVector::new(pairs);
    index.insert(point_id, &vector);
    Ok(Some(pos))
}
/// Decodes `nnz` consecutive `(u32 term index, f32 weight)` pairs starting at `*pos`,
/// advancing `*pos` past each field as it is read.
///
/// # Errors
///
/// Propagates [`Error::SparseIndexError`] if a field cannot be decoded.
fn read_term_weight_pairs(data: &[u8], pos: &mut usize, nnz: usize) -> Result<Vec<(u32, f32)>> {
    let mut decoded = Vec::with_capacity(nnz);
    while decoded.len() < nnz {
        let term_index = read_le_u32(data, *pos, "WAL entry corrupted: bad term-index bytes")?;
        *pos += 4;
        let weight = read_le_f32(data, *pos, "WAL entry corrupted: bad weight bytes")?;
        *pos += 4;
        decoded.push((term_index, weight));
    }
    Ok(decoded)
}
/// File-name prefix for the sparse index called `name`: `"sparse"` for the unnamed
/// (empty-name) index, otherwise `"sparse-<name>"`.
fn sparse_file_prefix(name: &str) -> String {
    match name {
        "" => String::from("sparse"),
        suffix => format!("sparse-{suffix}"),
    }
}
/// Compacts the sparse index called `name` into `dir`, using the file prefix from
/// `sparse_file_prefix(name)`.
///
/// # Errors
///
/// Returns [`Error::SparseIndexError`] if any compaction step fails.
pub fn compact_named(dir: &Path, name: &str, index: &SparseInvertedIndex) -> Result<()> {
    compact_with_prefix(dir, &sparse_file_prefix(name), index)
}
/// Loads the sparse index called `name` from `dir` (file prefix from
/// `sparse_file_prefix(name)`), returning `Ok(None)` when nothing is stored for it.
///
/// # Errors
///
/// Returns [`Error::SparseIndexError`] if the stored files cannot be read or decoded.
pub fn load_named_from_disk(dir: &Path, name: &str) -> Result<Option<SparseInvertedIndex>> {
    load_from_disk_with_prefix(dir, &sparse_file_prefix(name))
}
/// Path of the WAL file for the sparse index called `name`: `<dir>/<prefix>.wal`,
/// where the prefix comes from `sparse_file_prefix(name)`.
#[must_use]
pub fn wal_path_for_name(dir: &Path, name: &str) -> std::path::PathBuf {
    let mut file_name = sparse_file_prefix(name);
    file_name.push_str(".wal");
    dir.join(file_name)
}
pub fn compact(dir: &Path, index: &SparseInvertedIndex) -> Result<()> {
compact_with_prefix(dir, "sparse", index)
}
/// Rewrites `<prefix>.{idx,terms,meta}` in `dir` from `index`, then empties
/// `<prefix>.wal`.
///
/// The steps run in order:
/// 1. Snapshot the merged postings.
/// 2. Write each file to a `.tmp` sibling, with postings in ascending term-id order.
/// 3. Rename the temporaries into place.
/// 4. Truncate the WAL.
///
/// NOTE(review): a WAL record appended after the postings snapshot but before the
/// truncation would be discarded. Confirm that callers keep writers out while
/// compaction runs.
///
/// # Errors
///
/// Propagates any [`Error::SparseIndexError`] from the individual steps.
fn compact_with_prefix(dir: &Path, prefix: &str, index: &SparseInvertedIndex) -> Result<()> {
    let merged = index.get_merged_postings_for_compaction();
    let sorted_terms = {
        let mut ids: Vec<u32> = merged.keys().copied().collect();
        ids.sort_unstable();
        ids
    };
    let entries = write_idx_tmp(dir, prefix, &sorted_terms, &merged)?;
    write_terms_tmp(dir, prefix, &entries)?;
    write_meta_tmp(dir, prefix, index.doc_count(), &sorted_terms)?;
    atomic_rename_compacted_files(dir, prefix)?;
    truncate_wal(dir, prefix)
}
fn write_idx_tmp(
dir: &Path,
prefix: &str,
term_ids: &[u32],
merged: &FxHashMap<u32, (Vec<PostingEntry>, f32)>,
) -> Result<Vec<TermEntry>> {
let idx_tmp = dir.join(format!("{prefix}.idx.tmp"));
let mut idx_file = BufWriter::new(
std::fs::File::create(&idx_tmp)
.map_err(|e| Error::SparseIndexError(format!("compact idx create: {e}")))?,
);
let mut term_entries: Vec<TermEntry> = Vec::with_capacity(term_ids.len());
let mut current_offset: u64 = 0;
for &term_id in term_ids {
let (postings, max_weight) = lookup_term(term_id, merged)?;
write_postings(&mut idx_file, postings)?;
let byte_len = (postings.len() * POSTING_DISK_SIZE) as u64;
term_entries.push(TermEntry {
term_id,
offset: current_offset,
#[allow(clippy::cast_possible_truncation)]
len: postings.len() as u32,
max_weight: *max_weight,
});
current_offset += byte_len;
}
idx_file
.flush()
.map_err(|e| Error::SparseIndexError(format!("compact idx flush: {e}")))?;
Ok(term_entries)
}
fn lookup_term(
term_id: u32,
merged: &FxHashMap<u32, (Vec<PostingEntry>, f32)>,
) -> Result<&(Vec<PostingEntry>, f32)> {
merged.get(&term_id).ok_or_else(|| {
Error::SparseIndexError(format!(
"compact: term_id {term_id} absent from merged postings map"
))
})
}
fn write_postings(w: &mut BufWriter<std::fs::File>, postings: &[PostingEntry]) -> Result<()> {
for entry in postings {
w.write_all(&entry.doc_id.to_le_bytes())
.map_err(|e| Error::SparseIndexError(format!("compact idx write: {e}")))?;
w.write_all(&entry.weight.to_le_bytes())
.map_err(|e| Error::SparseIndexError(format!("compact idx write: {e}")))?;
}
Ok(())
}
fn write_terms_tmp(dir: &Path, prefix: &str, term_entries: &[TermEntry]) -> Result<()> {
let terms_tmp = dir.join(format!("{prefix}.terms.tmp"));
let terms_data = postcard::to_allocvec(term_entries)
.map_err(|e| Error::SparseIndexError(format!("compact terms serialize: {e}")))?;
std::fs::write(&terms_tmp, &terms_data)
.map_err(|e| Error::SparseIndexError(format!("compact terms write: {e}")))
}
fn write_meta_tmp(dir: &Path, prefix: &str, doc_count: u64, term_ids: &[u32]) -> Result<()> {
let meta_tmp = dir.join(format!("{prefix}.meta.tmp"));
let meta = SparseMeta {
version: 1,
doc_count,
#[allow(clippy::cast_possible_truncation)]
term_count: term_ids.len() as u32,
};
let meta_data = postcard::to_allocvec(&meta)
.map_err(|e| Error::SparseIndexError(format!("compact meta serialize: {e}")))?;
std::fs::write(&meta_tmp, &meta_data)
.map_err(|e| Error::SparseIndexError(format!("compact meta write: {e}")))
}
fn atomic_rename_compacted_files(dir: &Path, prefix: &str) -> Result<()> {
for ext in &["idx", "terms", "meta"] {
std::fs::rename(
dir.join(format!("{prefix}.{ext}.tmp")),
dir.join(format!("{prefix}.{ext}")),
)
.map_err(|e| Error::SparseIndexError(format!("compact {ext} rename: {e}")))?;
}
Ok(())
}
fn truncate_wal(dir: &Path, prefix: &str) -> Result<()> {
let wal_path = dir.join(format!("{prefix}.wal"));
if wal_path.exists() {
let file = std::fs::OpenOptions::new()
.write(true)
.open(&wal_path)
.map_err(|e| Error::SparseIndexError(format!("compact wal truncate: {e}")))?;
file.set_len(0)
.map_err(|e| Error::SparseIndexError(format!("compact wal truncate: {e}")))?;
}
Ok(())
}
pub fn load_from_disk(dir: &Path) -> Result<Option<SparseInvertedIndex>> {
load_from_disk_with_prefix(dir, "sparse")
}
/// Loads the index stored under `prefix` in `dir`.
///
/// When `<prefix>.meta` exists, the compacted files are loaded and WAL records are
/// replayed on top. Otherwise only the WAL is replayed. If replay applied at least
/// `COMPACTION_REPLAY_THRESHOLD` operations, the result is compacted straight away.
fn load_from_disk_with_prefix(dir: &Path, prefix: &str) -> Result<Option<SparseInvertedIndex>> {
    let meta_path = dir.join(format!("{prefix}.meta"));
    if meta_path.exists() {
        let meta = load_and_validate_meta(&meta_path)?;
        let index = load_compacted_index(dir, prefix, &meta)?;
        let replayed = wal_replay(&dir.join(format!("{prefix}.wal")), &index)?;
        if replayed >= COMPACTION_REPLAY_THRESHOLD {
            compact_with_prefix(dir, prefix, &index)?;
        }
        Ok(Some(index))
    } else {
        load_wal_only(dir, prefix)
    }
}
/// Builds an index purely from `<prefix>.wal` in `dir`, for when no compacted files
/// exist.
///
/// Returns `Ok(None)` if the WAL is missing or replays no operations. Compacts
/// immediately once replay reaches `COMPACTION_REPLAY_THRESHOLD` operations.
fn load_wal_only(dir: &Path, prefix: &str) -> Result<Option<SparseInvertedIndex>> {
    let wal_path = dir.join(format!("{prefix}.wal"));
    if !wal_path.exists() {
        return Ok(None);
    }
    let index = SparseInvertedIndex::new();
    match wal_replay(&wal_path, &index)? {
        0 => Ok(None),
        applied => {
            if applied >= COMPACTION_REPLAY_THRESHOLD {
                compact_with_prefix(dir, prefix, &index)?;
            }
            Ok(Some(index))
        }
    }
}
fn load_and_validate_meta(meta_path: &Path) -> Result<SparseMeta> {
let meta_data = std::fs::read(meta_path)
.map_err(|e| Error::SparseIndexError(format!("load meta read: {e}")))?;
let meta: SparseMeta = postcard::from_bytes(&meta_data)
.map_err(|e| Error::SparseIndexError(format!("load meta deserialize: {e}")))?;
if meta.version != 1 {
return Err(Error::SparseIndexError(format!(
"unsupported sparse meta version: {}",
meta.version
)));
}
Ok(meta)
}
fn load_compacted_index(
dir: &Path,
prefix: &str,
meta: &SparseMeta,
) -> Result<SparseInvertedIndex> {
let terms_path = dir.join(format!("{prefix}.terms"));
let terms_data = std::fs::read(&terms_path)
.map_err(|e| Error::SparseIndexError(format!("load terms read: {e}")))?;
let term_entries: Vec<TermEntry> = postcard::from_bytes(&terms_data)
.map_err(|e| Error::SparseIndexError(format!("load terms deserialize: {e}")))?;
let idx_path = dir.join(format!("{prefix}.idx"));
let idx_data = std::fs::read(&idx_path)
.map_err(|e| Error::SparseIndexError(format!("load idx read: {e}")))?;
let postings = build_postings_from_idx(&idx_data, &term_entries)?;
#[allow(clippy::cast_possible_truncation)]
let frozen = FrozenSegment::new(postings, meta.doc_count as usize);
Ok(SparseInvertedIndex::from_frozen_segment(frozen))
}
/// Decodes the posting lists described by `term_entries` out of the raw `.idx` bytes.
///
/// Each entry's postings are `len` consecutive `POSTING_DISK_SIZE`-byte records
/// (little-endian `u64` doc id, then `f32` weight) starting at byte `offset`. If the
/// same `term_id` appears twice, the later entry overwrites the earlier one.
///
/// # Errors
///
/// Returns [`Error::SparseIndexError`] if any entry's byte range falls outside
/// `idx_data`, including ranges whose end would overflow.
fn build_postings_from_idx(
    idx_data: &[u8],
    term_entries: &[TermEntry],
) -> Result<FxHashMap<u32, (Vec<PostingEntry>, f32)>> {
    let mut postings: FxHashMap<u32, (Vec<PostingEntry>, f32)> = FxHashMap::default();
    postings.reserve(term_entries.len());
    let file_len = idx_data.len() as u64;
    for te in term_entries {
        // `offset` and `len` come straight from the `.terms` file and may be corrupt.
        // Compute in u64 with a checked add so the end cannot wrap and pass the check.
        let byte_count = u64::from(te.len) * POSTING_DISK_SIZE as u64;
        let end = match te.offset.checked_add(byte_count) {
            Some(end) if end <= file_len => end,
            _ => {
                return Err(Error::SparseIndexError(format!(
                    "load idx: term {} offset {}+{byte_count} exceeds file size {}",
                    te.term_id,
                    te.offset,
                    idx_data.len()
                )));
            }
        };
        // Both bounds are <= idx_data.len(), so these conversions cannot truncate.
        #[allow(clippy::cast_possible_truncation)]
        let (start, end) = (te.offset as usize, end as usize);
        let mut entries = Vec::with_capacity(te.len as usize);
        for pos in (start..end).step_by(POSTING_DISK_SIZE) {
            let doc_id = read_le_u64(idx_data, pos, "load idx: corrupt doc_id bytes")?;
            // The weight immediately follows the 8-byte doc id.
            let weight = read_le_f32(idx_data, pos + 8, "load idx: corrupt weight bytes")?;
            entries.push(PostingEntry { doc_id, weight });
        }
        postings.insert(te.term_id, (entries, te.max_weight));
    }
    Ok(postings)
}