use std::path::{Path, PathBuf};
use crate::error::{Error, Result};
use crate::index::bm25::Bm25Index;
use crate::index::wal_framing;
/// Context label handed to every `wal_framing` call made by this module.
const CTX: &str = "BM25 WAL";
/// Op byte that opens the body of an "add document" entry.
const WAL_OP_ADD: u8 = 0x01;
/// Op byte that opens the body of a "remove document" entry.
const WAL_OP_REMOVE: u8 = 0x02;
/// File name that `wal_path_for_bm25` joins onto the directory it is given.
const BM25_WAL_FILENAME: &str = "bm25.wal";
/// Fixed bytes of an add-entry body that precede the text:
/// op byte + little-endian `u64` id + little-endian `u32` text length.
const ADD_ENTRY_HEADER: usize =
    std::mem::size_of::<u8>() + std::mem::size_of::<u64>() + std::mem::size_of::<u32>();
/// Complete size of a remove-entry body: op byte + little-endian `u64` id.
const REMOVE_ENTRY_HEADER: usize = std::mem::size_of::<u8>() + std::mem::size_of::<u64>();
/// Builds the path of the BM25 write-ahead log inside `dir`
/// (`dir` followed by [`BM25_WAL_FILENAME`]).
#[must_use]
pub(crate) fn wal_path_for_bm25(dir: &Path) -> PathBuf {
    let mut wal = dir.to_path_buf();
    wal.push(BM25_WAL_FILENAME);
    wal
}
/// Appends an "add document" entry for `id` / `text` to the WAL at `wal_path`.
///
/// The entry is a little-endian `u32` body length followed by the body
/// written by `write_add_entry_body`. The writer is flushed through
/// `wal_framing::flush_wal` before returning.
///
/// # Errors
///
/// Fails if the text length or the total body length does not fit in a `u32`.
/// Both checks run before the file is opened. Also fails if opening, writing,
/// or flushing the WAL fails.
#[inline]
pub(crate) fn wal_append_add_document(wal_path: &Path, id: u64, text: &str) -> Result<()> {
    let payload = text.as_bytes();
    let payload_len = encode_text_len(payload)?;
    let length_prefix = add_entry_body_len(payload_len)?.to_le_bytes();

    let mut writer = wal_framing::open_wal_writer(wal_path, CTX)?;
    wal_framing::wal_write(&mut writer, &length_prefix, CTX)?;
    write_add_entry_body(&mut writer, id, payload_len, payload)?;
    wal_framing::flush_wal(&mut writer, CTX)
}
#[inline]
fn encode_text_len(text_bytes: &[u8]) -> Result<u32> {
u32::try_from(text_bytes.len()).map_err(|_| {
Error::Index(format!(
"BM25 WAL: text too large ({} bytes) to encode",
text_bytes.len()
))
})
}
#[inline]
fn add_entry_body_len(text_len: u32) -> Result<u32> {
let header = u32::try_from(ADD_ENTRY_HEADER)
.map_err(|_| Error::Index("BM25 WAL: add header too large".to_string()))?;
header.checked_add(text_len).ok_or_else(|| {
Error::Index(format!(
"BM25 WAL: entry too large (text_len={text_len}) to fit in u32 prefix"
))
})
}
/// Writes the body of an add entry to `w`, in this order:
/// 1. the `WAL_OP_ADD` op byte;
/// 2. `id` as a little-endian `u64`;
/// 3. `text_len` as a little-endian `u32`;
/// 4. the raw `text_bytes`.
///
/// `wal_append_add_document` writes the length prefix before calling this.
///
/// # Errors
///
/// Returns the first error from `wal_framing::wal_write`. Fields after a
/// failed write are not attempted.
#[inline]
fn write_add_entry_body(
    w: &mut std::io::BufWriter<std::fs::File>,
    id: u64,
    text_len: u32,
    text_bytes: &[u8],
) -> Result<()> {
    let id_field = id.to_le_bytes();
    let len_field = text_len.to_le_bytes();
    let fields: [&[u8]; 4] = [&[WAL_OP_ADD], &id_field, &len_field, text_bytes];
    for field in fields {
        wal_framing::wal_write(w, field, CTX)?;
    }
    Ok(())
}
#[inline]
pub(crate) fn wal_append_remove_document(wal_path: &Path, id: u64) -> Result<()> {
let body_len = u32::try_from(REMOVE_ENTRY_HEADER)
.map_err(|_| Error::Index("BM25 WAL: remove header too large".to_string()))?;
let mut w = wal_framing::open_wal_writer(wal_path, CTX)?;
wal_framing::wal_write(&mut w, &body_len.to_le_bytes(), CTX)?;
wal_framing::wal_write(&mut w, &[WAL_OP_REMOVE], CTX)?;
wal_framing::wal_write(&mut w, &id.to_le_bytes(), CTX)?;
wal_framing::flush_wal(&mut w, CTX)
}
/// Truncates the BM25 WAL at `wal_path`.
///
/// This delegates to `wal_framing::wal_truncate`, passing this module's
/// context label.
///
/// # Errors
///
/// Propagates whatever error `wal_framing::wal_truncate` reports.
pub(crate) fn wal_truncate(wal_path: &Path) -> Result<()> {
    wal_framing::wal_truncate(wal_path, CTX)
}
/// Replays the BM25 WAL at `wal_path` into `index`, in file order.
///
/// Each entry is a length prefix (parsed by `wal_framing::read_entry_header`)
/// followed by a body whose first byte is the op code. Replay stops, with a
/// warning, at either of these:
/// - the first header that cannot be read;
/// - the first entry whose declared body runs past the end of the file
///   (a truncated tail).
///
/// Empty bodies are logged and skipped. After each entry the cursor is moved
/// to the framed end of that entry's body.
///
/// Returns the number of entries applied. A missing WAL file yields `Ok(0)`.
///
/// # Errors
///
/// Fails if the file exists but cannot be read, or if replaying an individual
/// entry returns an error.
pub(crate) fn wal_replay(wal_path: &Path, index: &Bm25Index) -> Result<u64> {
    let data = match std::fs::read(wal_path) {
        Ok(d) => d,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(0),
        Err(e) => return Err(Error::Index(format!("BM25 WAL read: {e}"))),
    };
    let mut pos = 0usize;
    let mut count = 0u64;
    while pos < data.len() {
        let Some((body_start, body_len)) = wal_framing::read_entry_header(&data, pos, CTX) else {
            break;
        };
        // Compare against the bytes actually left instead of computing
        // `body_start + body_len`, which could overflow on a corrupt prefix.
        let remaining = data.len().saturating_sub(body_start);
        if body_len > remaining {
            tracing::warn!(
                "BM25 WAL truncated at offset {body_start}: declared {body_len} bytes but only {remaining} remain"
            );
            break;
        }
        let body_end = body_start + body_len;
        // Every valid body begins with an op byte. A zero-length body has none.
        // Reading `data[body_start]` would panic at end of file, or would read
        // the next entry's length prefix as an op code.
        if body_len == 0 {
            tracing::warn!("BM25 WAL empty entry at offset {body_start}");
            pos = body_end;
            continue;
        }
        let op = data[body_start];
        // Handlers expect the cursor just past the op byte.
        let mut cursor = body_start + 1;
        count += replay_single_entry(&data, op, &mut cursor, body_start, body_len, index)?;
        // Always resume at the framed boundary, however much the handler
        // consumed, so one malformed entry cannot desynchronise the rest.
        pos = body_end;
    }
    Ok(count)
}
fn replay_single_entry(
data: &[u8],
op: u8,
pos: &mut usize,
body_start: usize,
body_len: usize,
index: &Bm25Index,
) -> Result<u64> {
match op {
WAL_OP_ADD => replay_add_entry(data, pos, body_start, body_len, index),
WAL_OP_REMOVE => replay_remove_entry(data, pos, index),
unknown => {
tracing::warn!("BM25 WAL unknown op 0x{unknown:02x} at offset {body_start}");
*pos = body_start + body_len;
Ok(0)
}
}
}
/// Applies an add entry to `index`.
///
/// After the op byte the body holds a little-endian `u64` document id, a
/// little-endian `u32` text length, and then the UTF-8 text. On entry, `*pos`
/// points just past the op byte.
///
/// Returns `Ok(1)` once the document has been added. Two kinds of entry are
/// logged, skipped with `*pos` moved to the end of the body, and return
/// `Ok(0)`:
/// - entries shorter than `ADD_ENTRY_HEADER`;
/// - entries whose text would run past the body or past `data`.
///
/// # Errors
///
/// Returns `Error::Index` if the text is not valid UTF-8. Errors from reading
/// the id or length fields are propagated.
fn replay_add_entry(
    data: &[u8],
    pos: &mut usize,
    body_start: usize,
    body_len: usize,
    index: &Bm25Index,
) -> Result<u64> {
    let body_end = body_start.saturating_add(body_len);
    if body_len < ADD_ENTRY_HEADER {
        tracing::warn!("BM25 WAL add entry too short at offset {body_start}");
        *pos = body_end;
        return Ok(0);
    }
    let id = read_le_u64(data, *pos)?;
    *pos += 8;
    let text_len = read_le_u32(data, *pos)? as usize;
    *pos += 4;
    // `checked_add` stops a corrupt length from wrapping `usize` (possible on
    // 32-bit targets). A wrapped value would pass the bounds checks and then
    // panic when slicing.
    let text_end = match pos.checked_add(text_len) {
        Some(end) if end <= body_end && end <= data.len() => end,
        _ => {
            tracing::warn!("BM25 WAL add entry truncated at offset {body_start}");
            *pos = body_end;
            return Ok(0);
        }
    };
    let text = std::str::from_utf8(&data[*pos..text_end])
        .map_err(|e| Error::Index(format!("BM25 WAL add: invalid utf8 at {body_start}: {e}")))?;
    index.add_document(id, text);
    *pos = text_end;
    Ok(1)
}
/// Applies a remove entry to `index`.
///
/// Reads the little-endian `u64` document id at `*pos`, advances `*pos` past
/// it, and removes that id from `index`. Returns `Ok(1)`.
///
/// This function does not check the entry's declared body length.
/// NOTE(review): the caller must ensure the body holds the full 8-byte id.
///
/// # Errors
///
/// Propagates errors from `read_le_u64`. In that case `*pos` is left unchanged.
fn replay_remove_entry(data: &[u8], pos: &mut usize, index: &Bm25Index) -> Result<u64> {
    let doc_id = read_le_u64(data, *pos)?;
    *pos += std::mem::size_of::<u64>();
    index.remove_document(doc_id);
    Ok(1)
}
#[inline]
fn read_le_u64(data: &[u8], pos: usize) -> Result<u64> {
data[pos..pos + 8]
.try_into()
.map(u64::from_le_bytes)
.map_err(|_| Error::Index(format!("BM25 WAL: corrupt u64 at offset {pos}")))
}
#[inline]
fn read_le_u32(data: &[u8], pos: usize) -> Result<u32> {
data[pos..pos + 4]
.try_into()
.map(u32::from_le_bytes)
.map_err(|_| Error::Index(format!("BM25 WAL: corrupt u32 at offset {pos}")))
}