use super::snapshot::crc32_hash;
use std::fs::{File, OpenOptions};
use std::io::{self, BufReader, BufWriter, Read, Write};
use std::path::Path;
/// Width of the little-endian CRC32 trailer that follows every record payload.
const CRC_SIZE: usize = std::mem::size_of::<u32>();
/// Edge record payload: op tag (1) + `from` (4) + `to` (4) + `layer` (1).
const EDGE_PAYLOAD_SIZE: usize = 1 + 2 * std::mem::size_of::<u32>() + 1;
/// Entry-point record payload: op tag (1) + `node` (4) + `max_layer` (1).
const ENTRY_PAYLOAD_SIZE: usize = 1 + std::mem::size_of::<u32>() + 1;
/// One-byte tag written as the first byte of every serialized delta record.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DeltaOp {
    /// Tag for an `HnswDelta::AddEdge` record.
    AddEdge = 0x01,
    /// Tag for an `HnswDelta::RemoveEdge` record.
    RemoveEdge = 0x02,
    /// Tag for an `HnswDelta::SetEntry` record.
    SetEntry = 0x03,
}
/// A single mutation of the HNSW graph, as recorded in the delta WAL.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum HnswDelta {
    /// Adds a directed edge `from -> to` on graph layer `layer`.
    AddEdge {
        from: u32,
        to: u32,
        layer: u8,
    },
    /// Removes the directed edge `from -> to` on graph layer `layer`.
    RemoveEdge {
        from: u32,
        to: u32,
        layer: u8,
    },
    /// Records `node` as the entry point, together with `max_layer`.
    SetEntry {
        node: u32,
        max_layer: u8,
    },
}
fn serialize_delta(delta: &HnswDelta, buf: &mut Vec<u8>) {
match *delta {
HnswDelta::AddEdge { from, to, layer } => {
buf.push(DeltaOp::AddEdge as u8);
buf.extend_from_slice(&from.to_le_bytes());
buf.extend_from_slice(&to.to_le_bytes());
buf.push(layer);
}
HnswDelta::RemoveEdge { from, to, layer } => {
buf.push(DeltaOp::RemoveEdge as u8);
buf.extend_from_slice(&from.to_le_bytes());
buf.extend_from_slice(&to.to_le_bytes());
buf.push(layer);
}
HnswDelta::SetEntry { node, max_layer } => {
buf.push(DeltaOp::SetEntry as u8);
buf.extend_from_slice(&node.to_le_bytes());
buf.push(max_layer);
}
}
}
/// Decodes one record payload (op tag included, CRC trailer excluded).
///
/// Returns `None` when `payload` is empty, when its op tag is not 1, 2 or 3, or when its
/// length is not exactly the fixed size for that op.
fn deserialize_delta(payload: &[u8]) -> Option<HnswDelta> {
    // Little-endian u32 starting at byte `at`; only called after the length check passed.
    let u32_at = |at: usize| {
        u32::from_le_bytes([payload[at], payload[at + 1], payload[at + 2], payload[at + 3]])
    };

    let (&op, _) = payload.split_first()?;
    let len = payload.len();

    match op {
        1 | 2 if len == EDGE_PAYLOAD_SIZE => {
            let (from, to, layer) = (u32_at(1), u32_at(5), payload[9]);
            let delta = if op == 1 {
                HnswDelta::AddEdge { from, to, layer }
            } else {
                HnswDelta::RemoveEdge { from, to, layer }
            };
            Some(delta)
        }
        3 if len == ENTRY_PAYLOAD_SIZE => Some(HnswDelta::SetEntry {
            node: u32_at(1),
            max_layer: payload[5],
        }),
        _ => None,
    }
}
pub struct HnswDeltaWriter {
writer: BufWriter<File>,
entry_count: u64,
}
impl HnswDeltaWriter {
pub fn open(path: &Path) -> io::Result<Self> {
let file = OpenOptions::new().create(true).append(true).open(path)?;
Ok(Self {
writer: BufWriter::new(file),
entry_count: 0,
})
}
pub fn append(&mut self, delta: &HnswDelta) -> io::Result<()> {
let mut buf = Vec::with_capacity(EDGE_PAYLOAD_SIZE + CRC_SIZE);
serialize_delta(delta, &mut buf);
let crc = crc32_hash(&buf);
buf.extend_from_slice(&crc.to_le_bytes());
self.writer.write_all(&buf)?;
self.entry_count += 1;
Ok(())
}
pub fn sync(&mut self) -> io::Result<()> {
self.writer.flush()?;
self.writer.get_ref().sync_all()
}
#[must_use]
pub fn entry_count(&self) -> u64 {
self.entry_count
}
}
/// Sequential, buffered reader over an HNSW delta WAL file.
pub struct HnswDeltaReader {
    reader: BufReader<File>,
}
impl HnswDeltaReader {
    /// Opens the existing WAL file at `path` for reading from its start.
    ///
    /// # Errors
    /// Returns the error from `File::open` (for example when the file does not exist).
    pub fn open(path: &Path) -> io::Result<Self> {
        File::open(path).map(|file| Self {
            reader: BufReader::new(file),
        })
    }

    /// Reads every intact record from the current position onward.
    ///
    /// Replay stops without error at the first record that is cut short (`UnexpectedEof`)
    /// or is corrupt or unrecognized (`InvalidData`). The records decoded before that point
    /// are returned. Any other I/O error is propagated.
    ///
    /// # Errors
    /// I/O errors other than `UnexpectedEof` and `InvalidData`.
    pub fn read_all(&mut self) -> io::Result<Vec<HnswDelta>> {
        let mut deltas = Vec::new();
        let outcome = loop {
            match self.read_one() {
                Ok(Some(delta)) => deltas.push(delta),
                Ok(None) => break Ok(()),
                Err(err) => match err.kind() {
                    // A torn or damaged tail ends replay instead of failing it.
                    io::ErrorKind::UnexpectedEof | io::ErrorKind::InvalidData => break Ok(()),
                    _ => break Err(err),
                },
            }
        };
        outcome.map(|()| deltas)
    }
}
impl HnswDeltaReader {
    /// Reads, CRC-checks and decodes the next WAL record.
    ///
    /// Returns `Ok(None)` when the file ends cleanly before the next op tag.
    ///
    /// # Errors
    /// - `InvalidData` if the op tag is unknown, the CRC32 trailer does not match the payload,
    ///   or the payload does not decode.
    /// - `UnexpectedEof` if the file ends partway through a record.
    /// - Any other I/O error from the underlying reader.
    fn read_one(&mut self) -> io::Result<Option<HnswDelta>> {
        let mut tag = [0u8; 1];
        if let Err(err) = self.reader.read_exact(&mut tag) {
            return if err.kind() == io::ErrorKind::UnexpectedEof {
                Ok(None)
            } else {
                Err(err)
            };
        }
        let [op] = tag;

        // The op tag is part of the CRC-covered payload, so it stays at index 0.
        let mut payload = vec![0u8; payload_size_for_op(op)?];
        payload[0] = op;
        self.reader.read_exact(&mut payload[1..])?;

        let mut trailer = [0u8; CRC_SIZE];
        self.reader.read_exact(&mut trailer)?;
        let crc_matches = u32::from_le_bytes(trailer) == crc32_hash(&payload);
        if !crc_matches {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "CRC32 mismatch in HNSW delta WAL entry",
            ));
        }

        match deserialize_delta(&payload) {
            Some(delta) => Ok(Some(delta)),
            None => Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "unrecognized HNSW delta op",
            )),
        }
    }
}
/// Maps an op tag to the fixed payload length (tag byte included, CRC excluded) of its record.
///
/// # Errors
/// `InvalidData` for any tag other than 1, 2 or 3.
fn payload_size_for_op(op: u8) -> io::Result<usize> {
    let known = match op {
        1 | 2 => Some(EDGE_PAYLOAD_SIZE),
        3 => Some(ENTRY_PAYLOAD_SIZE),
        _ => None,
    };
    known.ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            "unknown HNSW delta WAL op code",
        )
    })
}