use std::io::{BufWriter, Write};
use std::path::{Path, PathBuf};
use crate::error::{Error, Result};
use crate::index::bm25::Bm25Index;
/// Opcode byte that tags an "add document" entry in the WAL.
const WAL_OP_ADD: u8 = 0x01;
/// Opcode byte that tags a "remove document" entry in the WAL.
const WAL_OP_REMOVE: u8 = 0x02;
/// File name that `wal_path_for_bm25` appends to the directory it is given.
const BM25_WAL_FILENAME: &str = "bm25.wal";
/// Size in bytes of the fixed part of an add entry body:
/// opcode (`u8`) + document id (`u64`) + text length (`u32`).
const ADD_ENTRY_HEADER: usize =
    std::mem::size_of::<u8>() + std::mem::size_of::<u64>() + std::mem::size_of::<u32>();
/// Size in bytes of a remove entry body: opcode (`u8`) + document id (`u64`).
const REMOVE_ENTRY_HEADER: usize = std::mem::size_of::<u8>() + std::mem::size_of::<u64>();
/// Builds the path of the BM25 WAL file located directly inside `dir`.
///
/// No filesystem access happens here; the result is `dir` with
/// `BM25_WAL_FILENAME` pushed onto it.
#[must_use]
pub(crate) fn wal_path_for_bm25(dir: &Path) -> PathBuf {
    let mut wal_path = dir.to_path_buf();
    wal_path.push(BM25_WAL_FILENAME);
    wal_path
}
/// Appends one "add document" entry for `id`/`text` to the WAL at `wal_path`.
///
/// On disk the entry is a little-endian `u32` body length followed by the body
/// (opcode, id, text length, text bytes). The writer is flushed and synced
/// before returning.
///
/// # Errors
///
/// Returns `Error::Index` when the text or the whole entry does not fit the
/// `u32` length fields, or when opening, writing, flushing or syncing fails.
#[inline]
pub(crate) fn wal_append_add_document(wal_path: &Path, id: u64, text: &str) -> Result<()> {
    let payload = text.as_bytes();
    // Size validation happens before the file is opened, so an oversized
    // document never creates or touches the WAL.
    let payload_len = encode_text_len(payload)?;
    let frame_len = add_entry_body_len(payload_len)?;

    let mut writer = open_wal_writer(wal_path)?;
    wal_write(&mut writer, &frame_len.to_le_bytes())
        .and_then(|()| write_add_entry_body(&mut writer, id, payload_len, payload))
        .and_then(|()| flush_wal(&mut writer))
}
#[inline]
fn encode_text_len(text_bytes: &[u8]) -> Result<u32> {
u32::try_from(text_bytes.len()).map_err(|_| {
Error::Index(format!(
"BM25 WAL: text too large ({} bytes) to encode",
text_bytes.len()
))
})
}
#[inline]
fn add_entry_body_len(text_len: u32) -> Result<u32> {
let header =
u32::try_from(ADD_ENTRY_HEADER).expect("ADD_ENTRY_HEADER fits in u32 (compile-time)");
header.checked_add(text_len).ok_or_else(|| {
Error::Index(format!(
"BM25 WAL: entry too large (text_len={text_len}) to fit in u32 prefix"
))
})
}
/// Writes the body of an add entry to `w`: opcode byte, little-endian id,
/// little-endian text length, then the raw text bytes, in that order.
///
/// The length prefix that precedes the body is not written here.
///
/// # Errors
///
/// Stops at the first failing write and returns the error from `wal_write`.
#[inline]
fn write_add_entry_body(
    w: &mut std::io::BufWriter<std::fs::File>,
    id: u64,
    text_len: u32,
    text_bytes: &[u8],
) -> Result<()> {
    let opcode = [WAL_OP_ADD];
    let id_le = id.to_le_bytes();
    let len_le = text_len.to_le_bytes();
    // Field order defines the on-disk layout; keep it in sync with replay.
    let fields: [&[u8]; 4] = [&opcode, &id_le, &len_le, text_bytes];
    for field in fields.iter() {
        wal_write(w, field)?;
    }
    Ok(())
}
/// Appends one "remove document" entry for `id` to the WAL at `wal_path`.
///
/// On disk the entry is a little-endian `u32` body length
/// (`REMOVE_ENTRY_HEADER`), the remove opcode, then the little-endian id.
/// The writer is flushed and synced before returning.
///
/// # Errors
///
/// Returns `Error::Index` when opening, writing, flushing or syncing fails.
///
/// # Panics
///
/// Panics if `REMOVE_ENTRY_HEADER` does not fit in a `u32`.
#[inline]
pub(crate) fn wal_append_remove_document(wal_path: &Path, id: u64) -> Result<()> {
    let frame_len =
        u32::try_from(REMOVE_ENTRY_HEADER).expect("REMOVE_ENTRY_HEADER <= u32::MAX");
    let mut writer = open_wal_writer(wal_path)?;

    let prefix = frame_len.to_le_bytes();
    let opcode = [WAL_OP_REMOVE];
    let id_le = id.to_le_bytes();
    let fields: [&[u8]; 3] = [&prefix, &opcode, &id_le];
    for field in fields.iter() {
        wal_write(&mut writer, field)?;
    }
    flush_wal(&mut writer)
}
pub(crate) fn wal_truncate(wal_path: &Path) -> Result<()> {
if !wal_path.exists() {
return Ok(());
}
let file = std::fs::OpenOptions::new()
.write(true)
.open(wal_path)
.map_err(|e| Error::Index(format!("BM25 WAL truncate open: {e}")))?;
file.set_len(0)
.map_err(|e| Error::Index(format!("BM25 WAL truncate: {e}")))
}
/// Replays every entry of the WAL at `wal_path` into `index`.
///
/// The file is read fully into memory and walked entry by entry. Each entry
/// is a little-endian `u32` body length followed by that many body bytes,
/// the first of which is the opcode.
///
/// Returns the number of entries that were applied. A missing WAL file
/// yields `Ok(0)`.
///
/// Damaged input is tolerated where possible:
/// * an incomplete length prefix or a body that extends past the end of the
///   file stops the replay (the tail is assumed to be a torn write);
/// * an entry with a zero-length body carries no opcode and is skipped;
/// * after each entry the cursor is moved to the end of the declared body
///   regardless of how many bytes the handler consumed, so one malformed
///   entry cannot shift the framing of the entries that follow.
///
/// # Errors
///
/// Returns `Error::Index` when the file cannot be read (other than not
/// existing) or when an entry handler reports an error.
pub(crate) fn wal_replay(wal_path: &Path, index: &Bm25Index) -> Result<u64> {
    let data = match std::fs::read(wal_path) {
        Ok(d) => d,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(0),
        Err(e) => return Err(Error::Index(format!("BM25 WAL read: {e}"))),
    };
    let mut pos = 0usize;
    let mut count = 0u64;
    while pos < data.len() {
        let Some((body_start, body_len)) = read_entry_header(&data, pos) else {
            break;
        };
        // `read_entry_header` only succeeds when `body_start <= data.len()`,
        // so this subtraction cannot underflow; comparing against the
        // remainder avoids overflowing `body_start + body_len`.
        let remaining = data.len() - body_start;
        if body_len > remaining {
            tracing::warn!(
                "BM25 WAL truncated at offset {body_start}: declared {body_len} bytes but only {} remain",
                remaining
            );
            break;
        }
        let body_end = body_start + body_len;
        if body_len == 0 {
            // No opcode byte to read; indexing `data[body_start]` here would
            // either panic at end of file or misread the next length prefix.
            tracing::warn!("BM25 WAL empty entry at offset {body_start}: no opcode byte");
            pos = body_end;
            continue;
        }
        let op = data[body_start];
        pos = body_start + 1;
        count += replay_single_entry(&data, op, &mut pos, body_start, body_len, index)?;
        // The length prefix is authoritative for where the next entry begins.
        pos = body_end;
    }
    Ok(count)
}
/// Decodes the little-endian `u32` length prefix located at `pos`.
///
/// Returns `Some((body_start, body_len))`, where `body_start` is the offset
/// just past the prefix and `body_len` the declared body size. Returns `None`
/// (after logging a warning) when fewer than four bytes are available at `pos`.
fn read_entry_header(data: &[u8], pos: usize) -> Option<(usize, usize)> {
    let body_start = pos + 4;
    if body_start > data.len() {
        tracing::warn!("BM25 WAL truncated at offset {pos}: not enough bytes for length prefix");
        return None;
    }
    let mut prefix = [0u8; 4];
    prefix.copy_from_slice(&data[pos..body_start]);
    let body_len = u32::from_le_bytes(prefix) as usize;
    Some((body_start, body_len))
}
fn replay_single_entry(
data: &[u8],
op: u8,
pos: &mut usize,
body_start: usize,
body_len: usize,
index: &Bm25Index,
) -> Result<u64> {
match op {
WAL_OP_ADD => replay_add_entry(data, pos, body_start, body_len, index),
WAL_OP_REMOVE => replay_remove_entry(data, pos, index),
unknown => {
tracing::warn!("BM25 WAL unknown op 0x{unknown:02x} at offset {body_start}");
*pos = body_start + body_len;
Ok(0)
}
}
}
/// Applies one add entry to `index`.
///
/// `pos` must point just past the opcode byte. The handler reads the
/// little-endian `u64` id and `u32` text length, then the text itself, and
/// calls `index.add_document(id, text)`. Returns `Ok(1)` when the document
/// was added.
///
/// An entry whose body is shorter than `ADD_ENTRY_HEADER`, or whose declared
/// text would extend past the body or the buffer, is skipped with a warning:
/// `pos` is moved to the end of the body and `Ok(0)` is returned. The end of
/// the text is computed with checked arithmetic so a corrupt length cannot
/// wrap around and slip past the bounds check.
///
/// # Errors
///
/// Returns `Error::Index` when the text is not valid UTF-8, or when reading
/// the id or text length fails.
fn replay_add_entry(
    data: &[u8],
    pos: &mut usize,
    body_start: usize,
    body_len: usize,
    index: &Bm25Index,
) -> Result<u64> {
    let body_end = body_start + body_len;
    if body_len < ADD_ENTRY_HEADER {
        tracing::warn!("BM25 WAL add entry too short at offset {body_start}");
        *pos = body_end;
        return Ok(0);
    }
    let id = read_le_u64(data, *pos)?;
    *pos += 8;
    let text_len = read_le_u32(data, *pos)? as usize;
    *pos += 4;
    let text_end = match pos.checked_add(text_len) {
        Some(end) if end <= body_end && end <= data.len() => end,
        _ => {
            tracing::warn!("BM25 WAL add entry truncated at offset {body_start}");
            *pos = body_end;
            return Ok(0);
        }
    };
    let text = std::str::from_utf8(&data[*pos..text_end])
        .map_err(|e| Error::Index(format!("BM25 WAL add: invalid utf8 at {body_start}: {e}")))?;
    index.add_document(id, text);
    *pos = text_end;
    Ok(1)
}
fn replay_remove_entry(data: &[u8], pos: &mut usize, index: &Bm25Index) -> Result<u64> {
let id = read_le_u64(data, *pos)?;
*pos += 8;
index.remove_document(id);
Ok(1)
}
#[inline]
fn read_le_u64(data: &[u8], pos: usize) -> Result<u64> {
data[pos..pos + 8]
.try_into()
.map(u64::from_le_bytes)
.map_err(|_| Error::Index(format!("BM25 WAL: corrupt u64 at offset {pos}")))
}
#[inline]
fn read_le_u32(data: &[u8], pos: usize) -> Result<u32> {
data[pos..pos + 4]
.try_into()
.map(u32::from_le_bytes)
.map_err(|_| Error::Index(format!("BM25 WAL: corrupt u32 at offset {pos}")))
}
fn open_wal_writer(wal_path: &Path) -> Result<BufWriter<std::fs::File>> {
let file = std::fs::OpenOptions::new()
.create(true)
.append(true)
.open(wal_path)
.map_err(|e| Error::Index(format!("BM25 WAL open: {e}")))?;
Ok(BufWriter::new(file))
}
fn wal_write(w: &mut BufWriter<std::fs::File>, bytes: &[u8]) -> Result<()> {
w.write_all(bytes)
.map_err(|e| Error::Index(format!("BM25 WAL write: {e}")))
}
fn flush_wal(w: &mut BufWriter<std::fs::File>) -> Result<()> {
w.flush()
.map_err(|e| Error::Index(format!("BM25 WAL flush: {e}")))?;
w.get_ref()
.sync_all()
.map_err(|e| Error::Index(format!("BM25 WAL fsync: {e}")))
}