use crate::error::{Error, Result};
use crate::sharded::checkpoint_log::CheckpointEntry;
use crate::sharded::head;
use crate::sharded::index::{CheckpointIndex, CheckpointMetadata};
use parking_lot::Mutex;
use std::fs::{File, OpenOptions};
use std::io::{BufWriter, Seek, SeekFrom, Write};
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
/// Checkpoint contents handed back to callers by `CheckpointManager::load`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointData {
/// Offsets stored in the checkpoint; `CheckpointManager::create` requires exactly one per shard.
pub offsets: Vec<u64>,
/// Timestamp copied from the stored entry; `create` records nanoseconds since the Unix epoch.
pub timestamp: u64,
/// Shard count copied from the stored entry.
pub shard_count: u16,
}
/// Counters describing the outcome of a prune operation.
///
/// NOTE(review): nothing in this part of the file constructs or reads this type; the field
/// descriptions below are inferred from the names only — confirm against the pruning code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PruneStats {
/// Presumably the number of shards that had data pruned — TODO confirm.
pub shards_pruned: u16,
/// Presumably the number of segment files removed — TODO confirm.
pub segments_deleted: usize,
}
/// Summary returned by `CheckpointManager::compact`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct CompactionStats {
/// Number of checkpoints in the index before compaction started.
pub checkpoints_before: usize,
/// Number of checkpoints that were kept and rewritten into the new log.
pub checkpoints_after: usize,
/// Old log size minus new log size in bytes, saturating at zero.
pub bytes_reclaimed: u64,
}
/// Stores and retrieves checkpoints using a log file (`checkpoints.log`) under `root`
/// together with an in-memory index of where each checkpoint lives in that file.
pub struct CheckpointManager {
/// Directory that holds `checkpoints.log`; also passed to the `head` read/write helpers.
root: PathBuf,
/// Number of offsets every checkpoint must contain.
shard_count: u16,
/// Lookup from checkpoint id to its metadata (file offset, timestamp, shard count).
index: CheckpointIndex,
/// Buffered handle to the log, opened in append mode and guarded by a mutex.
log_file: Mutex<BufWriter<File>>,
}
impl CheckpointManager {
/// Builds a manager rooted at `root` for checkpoints that carry `shard_count` offsets.
///
/// The index is obtained from `CheckpointIndex::rebuild_from_log` on
/// `<root>/checkpoints.log`, after which that file is opened for appending
/// (and created if it does not exist yet).
///
/// # Errors
/// Propagates any failure from rebuilding the index or from opening/seeking the log file.
pub fn new(root: PathBuf, shard_count: u16) -> Result<Self> {
    let log_path = root.join("checkpoints.log");

    // The index is rebuilt first; the log is only opened (and possibly created) afterwards.
    let index = CheckpointIndex::rebuild_from_log(&log_path)?;

    let mut options = OpenOptions::new();
    options.create(true).append(true);
    let mut writer_file = options.open(&log_path)?;

    // Park the cursor at the end so a later position query reports the current log length.
    writer_file.seek(SeekFrom::End(0))?;

    Ok(Self {
        root,
        shard_count,
        index,
        log_file: Mutex::new(BufWriter::new(writer_file)),
    })
}
/// Appends a checkpoint for `user_id` to the log, records it in the index, and then
/// writes `user_id` as the new head.
///
/// `offsets` must contain exactly as many entries as the manager's shard count. The
/// entry is flushed and `sync_data`-ed before it becomes visible through the index.
///
/// # Errors
/// - `Error::ShardCountMismatch` when `offsets.len()` differs from the shard count.
/// - `Error::Io` when the system clock reports a time before the Unix epoch.
/// - Any error from building or serializing the entry, from writing/flushing/syncing
///   the log, or from `head::write_head`.
pub fn create(&self, user_id: &[u8], offsets: Vec<u64>) -> Result<()> {
    if offsets.len() != usize::from(self.shard_count) {
        return Err(Error::ShardCountMismatch {
            expected: self.shard_count,
            // Clamp before narrowing: a bare `as u16` wraps, so e.g. 65_538 offsets
            // against 2 shards used to be reported as `found: 2`.
            found: offsets.len().min(usize::from(u16::MAX)) as u16,
        });
    }

    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).map_err(|e| {
        Error::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            format!("System time error: {}", e),
        ))
    })?;
    // Nanoseconds since the epoch fit in a u64 until roughly the year 2554.
    let timestamp = since_epoch.as_nanos() as u64;

    let entry = CheckpointEntry::new(user_id.to_vec(), timestamp, offsets)?;
    let serialized = entry.serialize()?;

    {
        let mut log = self.log_file.lock();
        // The position before the write is where this entry will start.
        let file_offset = log.stream_position()?;
        log.write_all(&serialized)?;
        log.flush()?;
        log.get_ref().sync_data()?;

        // Publish the index entry while the log lock is still held, so anything else
        // that serializes on this lock never sees the log and the index out of step.
        self.index.insert(
            user_id.to_vec(),
            CheckpointMetadata {
                file_offset,
                timestamp,
                shard_count: self.shard_count,
            },
        );
    }

    head::write_head(&self.root, user_id)?;
    Ok(())
}
pub fn load(&self, user_id: &[u8]) -> Result<CheckpointData> {
let metadata = self
.index
.get(user_id)
.ok_or_else(|| Error::CheckpointNotFound(String::from_utf8_lossy(user_id).to_string()))?;
if metadata.shard_count != self.shard_count {
return Err(Error::ShardCountMismatch {
expected: self.shard_count,
found: metadata.shard_count,
});
}
let log_path = self.root.join("checkpoints.log");
let mut file = File::open(&log_path)?;
file.seek(SeekFrom::Start(metadata.file_offset))?;
let (entry, _) = CheckpointEntry::deserialize(&mut file, metadata.file_offset)?;
if entry.user_id != user_id {
return Err(Error::CheckpointCorrupted {
offset: metadata.file_offset,
reason: "user_id mismatch in entry".into(),
});
}
Ok(CheckpointData {
offsets: entry.offsets,
timestamp: entry.timestamp,
shard_count: entry.shard_count,
})
}
pub fn load_latest(&self) -> Result<(Vec<u8>, CheckpointData)> {
match head::read_head(&self.root)? {
Some(user_id) => {
let data = self.load(&user_id)?;
Ok((user_id, data))
}
None => {
let sorted = self.index.all_sorted_by_time();
if sorted.is_empty() {
return Err(Error::NoCheckpoints);
}
let (user_id, _) = sorted.last().unwrap();
let data = self.load(user_id)?;
Ok((user_id.clone(), data))
}
}
}
/// Returns `(checkpoint id, timestamp)` pairs in the order produced by the index's
/// time-sorted listing.
pub fn list_checkpoints(&self) -> Vec<(Vec<u8>, u64)> {
    let ordered = self.index.all_sorted_by_time();
    let mut listing = Vec::with_capacity(ordered.len());
    listing.extend(
        ordered
            .into_iter()
            .map(|(checkpoint_id, metadata)| (checkpoint_id, metadata.timestamp)),
    );
    listing
}
pub fn compact(&self, keep_latest_n: usize) -> Result<CompactionStats> {
let log_path = self.root.join("checkpoints.log");
let temp_path = self.root.join("checkpoints.log.tmp");
let old_size = std::fs::metadata(&log_path).map(|m| m.len()).unwrap_or(0);
let mut all_checkpoints = self.index.all_sorted_by_time();
let checkpoints_before = all_checkpoints.len();
if all_checkpoints.len() > keep_latest_n {
all_checkpoints.drain(0..all_checkpoints.len() - keep_latest_n);
}
let checkpoints_after = all_checkpoints.len();
let mut temp_file = BufWriter::new(File::create(&temp_path)?);
let mut new_index = CheckpointIndex::new();
let mut new_offset = 0u64;
for (user_id, old_metadata) in &all_checkpoints {
let data = self.load(user_id)?;
let entry = CheckpointEntry::new(user_id.clone(), old_metadata.timestamp, data.offsets)?;
let serialized = entry.serialize()?;
temp_file.write_all(&serialized)?;
new_index.insert(
user_id.clone(),
CheckpointMetadata {
file_offset: new_offset,
timestamp: old_metadata.timestamp,
shard_count: old_metadata.shard_count,
},
);
new_offset += serialized.len() as u64;
}
temp_file.flush()?;
temp_file.get_ref().sync_all()?;
drop(temp_file);
drop(self.log_file.lock());
std::fs::rename(&temp_path, &log_path)?;
let dir = File::open(&self.root)?;
dir.sync_all()?;
let mut file = OpenOptions::new().append(true).open(&log_path)?;
file.seek(SeekFrom::End(0))?;
*self.log_file.lock() = BufWriter::new(file);
self.index.replace_with(new_index);
let new_size = std::fs::metadata(&log_path)?.len();
Ok(CompactionStats {
checkpoints_before,
checkpoints_after,
bytes_reclaimed: old_size.saturating_sub(new_size),
})
}
/// Returns the number of entries currently held by the index.
pub fn checkpoint_count(&self) -> usize {
self.index.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A checkpoint written with `create` comes back unchanged from `load`.
    #[test]
    fn test_create_and_load_checkpoint() {
        let dir = TempDir::new().unwrap();
        let manager = CheckpointManager::new(dir.path().to_path_buf(), 4).unwrap();
        let offsets = vec![100, 200, 300, 400];
        manager.create(b"ckpt_1", offsets.clone()).unwrap();
        let data = manager.load(b"ckpt_1").unwrap();
        assert_eq!(data.offsets, offsets);
        assert_eq!(data.shard_count, 4);
    }

    /// `load_latest` returns the checkpoint that was created last.
    #[test]
    fn test_load_latest() {
        let dir = TempDir::new().unwrap();
        let manager = CheckpointManager::new(dir.path().to_path_buf(), 2).unwrap();
        manager.create(b"ckpt_1", vec![10, 20]).unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        manager.create(b"ckpt_2", vec![30, 40]).unwrap();
        let (user_id, data) = manager.load_latest().unwrap();
        assert_eq!(user_id, b"ckpt_2");
        assert_eq!(data.offsets, vec![30, 40]);
    }

    /// `list_checkpoints` reports entries oldest first with increasing timestamps.
    #[test]
    fn test_list_checkpoints() {
        let dir = TempDir::new().unwrap();
        let manager = CheckpointManager::new(dir.path().to_path_buf(), 1).unwrap();
        manager.create(b"ckpt_1", vec![10]).unwrap();
        std::thread::sleep(std::time::Duration::from_millis(10));
        manager.create(b"ckpt_2", vec![20]).unwrap();
        let list = manager.list_checkpoints();
        assert_eq!(list.len(), 2);
        assert_eq!(list[0].0, b"ckpt_1");
        assert_eq!(list[1].0, b"ckpt_2");
        assert!(list[0].1 < list[1].1);
    }

    /// Loading an id that was never created yields `CheckpointNotFound`.
    #[test]
    fn test_load_nonexistent_returns_error() {
        let dir = TempDir::new().unwrap();
        let manager = CheckpointManager::new(dir.path().to_path_buf(), 2).unwrap();
        let result = manager.load(b"nonexistent");
        assert!(matches!(result, Err(Error::CheckpointNotFound(_))));
    }

    /// `create` rejects an offsets vector whose length differs from the shard count.
    #[test]
    fn test_shard_count_mismatch_on_create() {
        let dir = TempDir::new().unwrap();
        let manager = CheckpointManager::new(dir.path().to_path_buf(), 4).unwrap();
        let result = manager.create(b"ckpt", vec![1, 2]);
        assert!(matches!(result, Err(Error::ShardCountMismatch { .. })));
    }

    /// After `compact(2)` only the two most recent of five checkpoints remain loadable.
    #[test]
    fn test_compact_keeps_latest() {
        let dir = TempDir::new().unwrap();
        let manager = CheckpointManager::new(dir.path().to_path_buf(), 2).unwrap();
        for i in 0..5 {
            let user_id = format!("ckpt_{}", i);
            manager
                .create(user_id.as_bytes(), vec![i as u64 * 10, i as u64 * 20])
                .unwrap();
            // Timestamps are recorded in nanoseconds, so a short pause is enough to keep
            // them distinct; this matches the delay used by the other tests.
            std::thread::sleep(std::time::Duration::from_millis(10));
        }
        assert_eq!(manager.checkpoint_count(), 5);
        let stats = manager.compact(2).unwrap();
        assert_eq!(stats.checkpoints_before, 5);
        assert_eq!(stats.checkpoints_after, 2);
        assert!(stats.bytes_reclaimed > 0);
        assert_eq!(manager.checkpoint_count(), 2);
        let can_load_0 = manager.load(b"ckpt_0").is_ok();
        let can_load_4 = manager.load(b"ckpt_4").is_ok();
        assert!(can_load_4, "Latest checkpoint should be kept");
        assert!(!can_load_0, "Oldest checkpoint should be deleted");
    }
}