use rustc_hash::FxHashMap;
use std::fs::File;
use std::io::{self, Write};
use std::path::Path;
pub(crate) const SNAPSHOT_MAGIC: &[u8; 4] = b"VSNP";
pub(crate) const SNAPSHOT_VERSION: u8 = 1;
pub(crate) const DEFAULT_SNAPSHOT_THRESHOLD: u64 = 10 * 1024 * 1024;
#[inline]
#[allow(clippy::cast_possible_truncation)] pub(crate) fn crc32_hash(data: &[u8]) -> u32 {
const CRC32_TABLE: [u32; 256] = {
let mut table = [0u32; 256];
let mut i = 0;
while i < 256 {
let mut crc = i as u32;
let mut j = 0;
while j < 8 {
if crc & 1 != 0 {
crc = (crc >> 1) ^ 0xEDB8_8320;
} else {
crc >>= 1;
}
j += 1;
}
table[i] = crc;
i += 1;
}
table
};
let mut crc = 0xFFFF_FFFF_u32;
for &byte in data {
let idx = ((crc ^ u32::from(byte)) & 0xFF) as usize;
crc = (crc >> 8) ^ CRC32_TABLE[idx];
}
!crc
}
fn validate_snapshot_format(data: &[u8]) -> io::Result<()> {
if data.len() < 25 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Snapshot too small",
));
}
if &data[0..4] != SNAPSHOT_MAGIC {
return Err(io::Error::new(io::ErrorKind::InvalidData, "Invalid magic"));
}
if data[4] != SNAPSHOT_VERSION {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Unsupported version",
));
}
Ok(())
}
fn validate_snapshot_header(data: &[u8]) -> io::Result<(u64, usize)> {
validate_snapshot_format(data)?;
let wal_pos = u64::from_le_bytes(
data[5..13]
.try_into()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid WAL position"))?,
);
let entry_count_u64 = u64::from_le_bytes(
data[13..21]
.try_into()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid entry count"))?,
);
let max_possible_entries = data.len().saturating_sub(25) / 16;
if entry_count_u64 > max_possible_entries as u64 {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"Entry count exceeds data size",
));
}
#[allow(clippy::cast_possible_truncation)] let entry_count = entry_count_u64 as usize;
if data.len() != 21 + entry_count * 16 + 4 {
return Err(io::Error::new(io::ErrorKind::InvalidData, "Size mismatch"));
}
let stored_crc = u32::from_le_bytes(
data[data.len() - 4..]
.try_into()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid CRC"))?,
);
if stored_crc != crc32_hash(&data[..data.len() - 4]) {
return Err(io::Error::new(io::ErrorKind::InvalidData, "CRC mismatch"));
}
Ok((wal_pos, entry_count))
}
pub(crate) fn load_snapshot(snapshot_path: &Path) -> io::Result<(FxHashMap<u64, u64>, u64)> {
if !snapshot_path.exists() {
return Err(io::Error::new(io::ErrorKind::NotFound, "No snapshot"));
}
let data = std::fs::read(snapshot_path)?;
let (wal_pos, entry_count) = validate_snapshot_header(&data)?;
let mut index = FxHashMap::default();
index.reserve(entry_count);
let entries_start = 21;
for i in 0..entry_count {
let offset = entries_start + i * 16;
let id = u64::from_le_bytes(
data[offset..offset + 8]
.try_into()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid entry ID"))?,
);
let wal_offset = u64::from_le_bytes(
data[offset + 8..offset + 16]
.try_into()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "Invalid entry offset"))?,
);
index.insert(id, wal_offset);
}
Ok((index, wal_pos))
}
pub(crate) fn create_snapshot_file(
dir: &Path,
index: &FxHashMap<u64, u64>,
wal_pos: u64,
) -> io::Result<()> {
let snapshot_path = dir.join("payloads.snapshot");
let entry_count = index.len();
let buf_size = 21 + entry_count * 16 + 4; let mut buf = Vec::with_capacity(buf_size);
buf.extend_from_slice(SNAPSHOT_MAGIC);
buf.push(SNAPSHOT_VERSION);
buf.extend_from_slice(&wal_pos.to_le_bytes());
buf.extend_from_slice(&(entry_count as u64).to_le_bytes());
for (&id, &offset) in index {
buf.extend_from_slice(&id.to_le_bytes());
buf.extend_from_slice(&offset.to_le_bytes());
}
let crc = crc32_hash(&buf);
buf.extend_from_slice(&crc.to_le_bytes());
let temp_path = dir.join("payloads.snapshot.tmp");
{
let file = File::create(&temp_path)?;
let mut writer = io::BufWriter::new(file);
writer.write_all(&buf)?;
writer.flush()?;
writer.get_ref().sync_all()?;
}
std::fs::rename(&temp_path, &snapshot_path)?;
Ok(())
}
#[must_use]
pub(crate) fn should_create_snapshot(last_snapshot_pos: u64, current_pos: u64) -> bool {
current_pos.saturating_sub(last_snapshot_pos) >= DEFAULT_SNAPSHOT_THRESHOLD
}