use crate::persistence::directory::Directory;
use crate::persistence::error::{PersistenceError, PersistenceResult};
use crate::persistence::format::CHECKPOINT_MAGIC;
use crc32fast::Hasher;
use std::io::{Read, Write};
use std::sync::Arc;
#[cfg(feature = "persistence")]
use postcard;
/// Fixed-size header stored at the start of every checkpoint file.
///
/// On disk the header occupies exactly [`CheckpointHeader::SIZE`] bytes: every field in
/// declaration order, integers little-endian. It is produced by [`CheckpointHeader::write`]
/// and parsed by [`CheckpointHeader::read`].
///
/// Note: the `#[repr(C)]` in-memory layout contains alignment padding (after
/// `segment_count` and after `checksum`), so it is *not* the on-disk layout. Never
/// transmute or byte-copy this struct to/from file bytes; always go through `write`/`read`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct CheckpointHeader {
    /// File signature; `read` rejects anything other than `CHECKPOINT_MAGIC`.
    pub magic: [u8; 4],
    /// On-disk format version (`CheckpointWriter` in this module writes `1`).
    pub format_version: u32,
    /// Entry id the checkpoint was created for; also used as the checkpoint id in file names.
    pub entry_id: u64,
    /// Number of entries in the serialized segment list.
    pub segment_count: u32,
    /// Absolute byte offset, from the start of the file, of the serialized segment list.
    pub segment_list_offset: u64,
    /// Total document count across all listed segments.
    pub doc_count: u64,
    /// Creation time in seconds since the Unix epoch.
    pub created_at: u64,
    /// CRC32 over all other header fields (little-endian, declaration order) followed by
    /// every byte from `segment_list_offset` to the end of the file.
    pub checksum: u32,
}

impl CheckpointHeader {
    /// Serialized size in bytes:
    /// magic(4) + format_version(4) + entry_id(8) + segment_count(4)
    /// + segment_list_offset(8) + doc_count(8) + created_at(8) + checksum(4).
    pub const SIZE: usize = 4 + 4 + 8 + 4 + 8 + 8 + 8 + 4;

    /// Writes the header as [`Self::SIZE`] little-endian bytes in field declaration order.
    ///
    /// `checksum` is written as-is; computing it is the caller's responsibility.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error from `writer`.
    pub fn write<W: Write>(&self, writer: &mut W) -> PersistenceResult<()> {
        use byteorder::{LittleEndian, WriteBytesExt};
        writer.write_all(&self.magic)?;
        writer.write_u32::<LittleEndian>(self.format_version)?;
        writer.write_u64::<LittleEndian>(self.entry_id)?;
        writer.write_u32::<LittleEndian>(self.segment_count)?;
        writer.write_u64::<LittleEndian>(self.segment_list_offset)?;
        writer.write_u64::<LittleEndian>(self.doc_count)?;
        writer.write_u64::<LittleEndian>(self.created_at)?;
        writer.write_u32::<LittleEndian>(self.checksum)?;
        Ok(())
    }

    /// Reads a header from `reader`, consuming exactly [`Self::SIZE`] bytes on success.
    ///
    /// Only the magic bytes are validated here; `format_version` is returned unchecked and
    /// the checksum is verified separately by [`Self::validate_checksum`].
    ///
    /// # Errors
    ///
    /// - `PersistenceError::Format` if the magic bytes are not `CHECKPOINT_MAGIC`.
    /// - Any I/O error from `reader` (including a short read on truncated input).
    pub fn read<R: Read>(reader: &mut R) -> PersistenceResult<Self> {
        use byteorder::{LittleEndian, ReadBytesExt};
        let mut magic = [0u8; 4];
        reader.read_exact(&mut magic)?;
        if magic != CHECKPOINT_MAGIC {
            return Err(PersistenceError::Format(format!(
                "Invalid checkpoint magic bytes (expected: {:?}, actual: {:?})",
                CHECKPOINT_MAGIC, magic
            )));
        }
        let format_version = reader.read_u32::<LittleEndian>()?;
        let entry_id = reader.read_u64::<LittleEndian>()?;
        let segment_count = reader.read_u32::<LittleEndian>()?;
        let segment_list_offset = reader.read_u64::<LittleEndian>()?;
        let doc_count = reader.read_u64::<LittleEndian>()?;
        let created_at = reader.read_u64::<LittleEndian>()?;
        let checksum = reader.read_u32::<LittleEndian>()?;
        Ok(Self {
            magic,
            format_version,
            entry_id,
            segment_count,
            segment_list_offset,
            doc_count,
            created_at,
            checksum,
        })
    }

    /// Verifies `self.checksum` against the checkpoint file readable from `reader`.
    ///
    /// `reader` must be positioned at the very start of the checkpoint file, because
    /// `segment_list_offset` is an absolute offset. The whole stream is read into memory, and
    /// every byte from `segment_list_offset` to EOF is hashed after the header fields.
    ///
    /// # Errors
    ///
    /// - `PersistenceError::Format` if `segment_list_offset` lies past the end of the data.
    /// - `PersistenceError::ChecksumMismatch` (`expected` = stored, `actual` = computed) if
    ///   the CRC32 does not match.
    /// - Any I/O error from `reader`.
    pub fn validate_checksum<R: Read>(&self, reader: &mut R) -> PersistenceResult<()> {
        let mut file_bytes = Vec::new();
        reader.read_to_end(&mut file_bytes)?;

        // Compare in u64: casting the offset to usize first could truncate it on 32-bit
        // targets and turn an out-of-range offset into an in-range one.
        if self.segment_list_offset > file_bytes.len() as u64 {
            return Err(PersistenceError::Format(format!(
                "Segment list offset beyond file size (expected <= {}, actual: {})",
                file_bytes.len(),
                self.segment_list_offset
            )));
        }
        // The offset is <= file_bytes.len() here, so it fits in usize.
        let segment_list_bytes = &file_bytes[self.segment_list_offset as usize..];

        let mut hasher = Hasher::new();
        hasher.update(&self.magic);
        hasher.update(&self.format_version.to_le_bytes());
        hasher.update(&self.entry_id.to_le_bytes());
        hasher.update(&self.segment_count.to_le_bytes());
        hasher.update(&self.segment_list_offset.to_le_bytes());
        hasher.update(&self.doc_count.to_le_bytes());
        hasher.update(&self.created_at.to_le_bytes());
        hasher.update(segment_list_bytes);
        let computed_checksum = hasher.finalize();

        if computed_checksum != self.checksum {
            return Err(PersistenceError::ChecksumMismatch {
                expected: self.checksum,
                actual: computed_checksum,
            });
        }
        Ok(())
    }
}
/// Describes one segment included in a checkpoint.
///
/// A list of these is serialized with postcard into the checkpoint file, directly after the
/// header. serde serializes struct fields in declaration order and postcard does not store
/// field names. Reordering, removing, or retyping fields therefore changes the on-disk
/// format and breaks reading existing checkpoints.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SegmentMetadata {
/// Segment id; the checkpoint copy of the segment lives in `segment_<segment_id>`.
pub segment_id: u64,
/// Directory-relative path of the segment's source directory; its files are copied into
/// the checkpoint (and the segment is skipped if the path does not exist).
pub path: String,
/// Number of documents in the segment; summed into the header's `doc_count`.
pub doc_count: u32,
/// NOTE(review): presumably the highest document id in the segment — not used in this module; confirm.
pub max_doc_id: u32,
/// NOTE(review): presumably the segment's on-disk size — not used in this module; confirm.
pub size_bytes: u64,
}
pub struct CheckpointWriter {
directory: Arc<dyn Directory>,
}
impl CheckpointWriter {
pub fn new(directory: Box<dyn Directory>) -> Self {
Self {
directory: Arc::<dyn Directory>::from(directory),
}
}
pub fn new_arc(directory: impl Into<Arc<dyn Directory>>) -> Self {
Self {
directory: directory.into(),
}
}
pub fn create_checkpoint(
&self,
entry_id: u64,
segments: &[SegmentMetadata],
) -> PersistenceResult<String> {
use std::time::{SystemTime, UNIX_EPOCH};
let checkpoint_id = entry_id;
let checkpoint_path = format!("checkpoints/checkpoint_{}.bin", checkpoint_id);
self.directory.create_dir_all("checkpoints")?;
let created_at = SystemTime::now()
.duration_since(UNIX_EPOCH)
.map_err(|e| PersistenceError::InvalidState(format!("system clock error: {e}")))?
.as_secs();
let segment_list_bytes = postcard::to_allocvec(segments).map_err(|e| {
PersistenceError::Serialization(format!(
"Failed to serialize segment list with postcard: {}",
e
))
})?;
let mut hasher = Hasher::new();
hasher.update(&CHECKPOINT_MAGIC);
hasher.update(&1u32.to_le_bytes()); hasher.update(&entry_id.to_le_bytes());
hasher.update(&(segments.len() as u32).to_le_bytes());
hasher.update(&(CheckpointHeader::SIZE as u64).to_le_bytes()); hasher.update(
&segments
.iter()
.map(|s| s.doc_count as u64)
.sum::<u64>()
.to_le_bytes(),
); hasher.update(&created_at.to_le_bytes()); hasher.update(&segment_list_bytes);
let checksum = hasher.finalize();
let header = CheckpointHeader {
magic: CHECKPOINT_MAGIC,
format_version: 1,
entry_id,
segment_count: segments.len() as u32,
segment_list_offset: CheckpointHeader::SIZE as u64, doc_count: segments.iter().map(|s| s.doc_count as u64).sum(),
created_at,
checksum,
};
let mut checkpoint_data = Vec::new();
header.write(&mut checkpoint_data)?;
checkpoint_data.extend_from_slice(&segment_list_bytes);
self.directory
.atomic_write(&checkpoint_path, &checkpoint_data)?;
let checkpoint_segments_dir = format!("checkpoints/checkpoint_{}/segments", checkpoint_id);
self.directory.create_dir_all(&checkpoint_segments_dir)?;
for segment in segments {
let source_segment_dir = &segment.path;
let dest_segment_dir =
format!("{}/segment_{}", checkpoint_segments_dir, segment.segment_id);
if self.directory.exists(source_segment_dir) {
self.directory.create_dir_all(&dest_segment_dir)?;
let files = self.directory.list_dir(source_segment_dir)?;
for file_name in files {
let source_file = format!("{}/{}", source_segment_dir, file_name);
let dest_file = format!("{}/{}", dest_segment_dir, file_name);
let mut source_reader = self.directory.open_file(&source_file)?;
let mut file_data = Vec::new();
source_reader.read_to_end(&mut file_data)?;
self.directory.atomic_write(&dest_file, &file_data)?;
}
}
}
Ok(checkpoint_path)
}
}
/// Reads and verifies checkpoint files stored in a [`Directory`].
pub struct CheckpointReader {
    directory: Arc<dyn Directory>,
}

impl CheckpointReader {
    /// Creates a reader that takes ownership of `directory`.
    pub fn new(directory: Box<dyn Directory>) -> Self {
        Self {
            directory: Arc::<dyn Directory>::from(directory),
        }
    }

    /// Creates a reader from anything convertible into a shared `Arc<dyn Directory>`.
    pub fn new_arc(directory: impl Into<Arc<dyn Directory>>) -> Self {
        Self {
            directory: directory.into(),
        }
    }

    /// Loads the header of the checkpoint at `checkpoint_path` and verifies the file checksum.
    ///
    /// # Errors
    ///
    /// Directory/I/O errors, `PersistenceError::Format` (bad magic or segment-list offset),
    /// or `PersistenceError::ChecksumMismatch`.
    pub fn load_checkpoint(&self, checkpoint_path: &str) -> PersistenceResult<CheckpointHeader> {
        let (header, _file_bytes) = self.read_verified_checkpoint(checkpoint_path)?;
        Ok(header)
    }

    /// Loads the header and the decoded segment list of the checkpoint at `checkpoint_path`.
    ///
    /// The checksum is verified *before* the segment list is decoded. Corruption is therefore
    /// reported as `ChecksumMismatch` rather than as a decoding failure.
    ///
    /// # Errors
    ///
    /// Everything [`Self::load_checkpoint`] can return, plus:
    /// - `PersistenceError::Deserialization` if the segment list cannot be decoded.
    /// - `PersistenceError::Format` if the decoded list length differs from `segment_count`.
    pub fn load_checkpoint_with_segments(
        &self,
        checkpoint_path: &str,
    ) -> PersistenceResult<(CheckpointHeader, Vec<SegmentMetadata>)> {
        let (header, file_bytes) = self.read_verified_checkpoint(checkpoint_path)?;

        // Checked in u64 so the later usize cast cannot truncate.
        if header.segment_list_offset > file_bytes.len() as u64 {
            return Err(PersistenceError::Format(format!(
                "Segment list offset beyond file size (expected <= {}, actual: {})",
                file_bytes.len(),
                header.segment_list_offset
            )));
        }
        let segment_list_bytes = &file_bytes[header.segment_list_offset as usize..];

        let segments: Vec<SegmentMetadata> =
            postcard::from_bytes(segment_list_bytes).map_err(|e| {
                PersistenceError::Deserialization(format!(
                    "Failed to deserialize segment list: {}",
                    e
                ))
            })?;

        if segments.len() as u64 != u64::from(header.segment_count) {
            return Err(PersistenceError::Format(format!(
                "Segment count mismatch (expected: {}, actual: {})",
                header.segment_count,
                segments.len()
            )));
        }
        Ok((header, segments))
    }

    /// Lists the `*.bin` entries of the `checkpoints` directory, as returned by `list_dir`.
    ///
    /// Entries named `checkpoint_<id>.bin` are ordered by ascending numeric `<id>`, so
    /// `checkpoint_9.bin` sorts before `checkpoint_10.bin`. Names that do not follow this
    /// pattern sort first, by name. Returns an empty list if `checkpoints` does not exist.
    pub fn list_checkpoints(&self) -> PersistenceResult<Vec<String>> {
        if !self.directory.exists("checkpoints") {
            return Ok(Vec::new());
        }
        let mut checkpoints = self.directory.list_dir("checkpoints")?;
        checkpoints.retain(|f| f.ends_with(".bin"));
        // A plain lexicographic sort would put "checkpoint_10" before "checkpoint_9".
        checkpoints.sort_by(|a, b| {
            (Self::checkpoint_id(a), a).cmp(&(Self::checkpoint_id(b), b))
        });
        Ok(checkpoints)
    }

    /// Reads the whole checkpoint file once, parses its header, and verifies the checksum
    /// against that same buffer. This avoids a second open, so the file cannot change
    /// between the parse and the verification.
    fn read_verified_checkpoint(
        &self,
        checkpoint_path: &str,
    ) -> PersistenceResult<(CheckpointHeader, Vec<u8>)> {
        let mut file = self.directory.open_file(checkpoint_path)?;
        let mut file_bytes = Vec::new();
        file.read_to_end(&mut file_bytes)?;

        let header = CheckpointHeader::read(&mut std::io::Cursor::new(file_bytes.as_slice()))?;
        // validate_checksum expects a reader positioned at the start of the file.
        header.validate_checksum(&mut std::io::Cursor::new(file_bytes.as_slice()))?;
        Ok((header, file_bytes))
    }

    /// Extracts `<id>` from a `checkpoint_<id>.bin` file name, if it has that shape.
    fn checkpoint_id(file_name: &str) -> Option<u64> {
        file_name
            .strip_prefix("checkpoint_")?
            .strip_suffix(".bin")?
            .parse()
            .ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A header with distinct, non-zero values in every field.
    fn sample_header() -> CheckpointHeader {
        CheckpointHeader {
            magic: CHECKPOINT_MAGIC,
            format_version: 1,
            entry_id: 100,
            segment_count: 5,
            segment_list_offset: 64,
            doc_count: 1000,
            created_at: 1234567890,
            checksum: 42,
        }
    }

    #[test]
    fn test_checkpoint_header_roundtrip() {
        let header = sample_header();
        let mut buffer = Vec::new();
        header.write(&mut buffer).unwrap();
        assert_eq!(buffer.len(), CheckpointHeader::SIZE);

        let read_header = CheckpointHeader::read(&mut std::io::Cursor::new(&buffer)).unwrap();
        assert_eq!(read_header.magic, header.magic);
        assert_eq!(read_header.format_version, header.format_version);
        assert_eq!(read_header.entry_id, header.entry_id);
        assert_eq!(read_header.segment_count, header.segment_count);
        assert_eq!(read_header.segment_list_offset, header.segment_list_offset);
        assert_eq!(read_header.doc_count, header.doc_count);
        assert_eq!(read_header.created_at, header.created_at);
        assert_eq!(read_header.checksum, header.checksum);
    }

    #[test]
    fn test_checkpoint_header_rejects_bad_magic() {
        let mut buffer = Vec::new();
        sample_header().write(&mut buffer).unwrap();
        // Flipping every bit of the first byte guarantees it differs from CHECKPOINT_MAGIC.
        buffer[0] ^= 0xFF;
        assert!(CheckpointHeader::read(&mut std::io::Cursor::new(&buffer)).is_err());
    }

    #[test]
    fn test_checkpoint_header_rejects_truncated_input() {
        let mut buffer = Vec::new();
        sample_header().write(&mut buffer).unwrap();
        buffer.truncate(CheckpointHeader::SIZE - 1);
        assert!(CheckpointHeader::read(&mut std::io::Cursor::new(&buffer)).is_err());
    }

    #[test]
    fn test_validate_checksum_detects_corruption() {
        let payload = b"segment-list-bytes";
        let mut header = CheckpointHeader {
            segment_list_offset: CheckpointHeader::SIZE as u64,
            ..sample_header()
        };
        // CRC32 over every header field except `checksum`, then the segment list.
        let mut hasher = Hasher::new();
        hasher.update(&header.magic);
        hasher.update(&header.format_version.to_le_bytes());
        hasher.update(&header.entry_id.to_le_bytes());
        hasher.update(&header.segment_count.to_le_bytes());
        hasher.update(&header.segment_list_offset.to_le_bytes());
        hasher.update(&header.doc_count.to_le_bytes());
        hasher.update(&header.created_at.to_le_bytes());
        hasher.update(payload);
        header.checksum = hasher.finalize();

        let mut file = Vec::new();
        header.write(&mut file).unwrap();
        file.extend_from_slice(payload);
        assert!(header
            .validate_checksum(&mut std::io::Cursor::new(&file))
            .is_ok());

        let last = file.len() - 1;
        file[last] ^= 0x01;
        let err = header
            .validate_checksum(&mut std::io::Cursor::new(&file))
            .unwrap_err();
        assert!(matches!(err, PersistenceError::ChecksumMismatch { .. }));
    }
}