#![cfg(not(target_arch = "wasm32"))]
use crate::storage::compression::{self, Codec};
use crate::{Document, RagError, Result};
use serde::{Deserialize, Serialize};
use std::fs::{self, File, OpenOptions};
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
// Process-wide counter folded into temp-file names so that concurrent atomic
// writes within the same process never collide on the same `.tmp` path.
static ATOMIC_WRITE_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Tunables for the incremental (checkpoint + write-ahead-log) storage engine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IncrementalConfig {
/// Number of logged operations after which `needs_checkpoint` returns true.
pub checkpoint_threshold: usize,
/// Fsync the WAL after this many appended operations (batched durability).
pub wal_sync_interval: usize,
/// WAL size in bytes that also triggers `needs_checkpoint`.
pub max_wal_size: usize,
/// Compression codec applied to serialized checkpoint payloads.
pub checkpoint_codec: Codec,
/// If true, fsync the WAL after every single append (safest, slowest).
pub sync_on_write: bool,
/// Checkpoint generations to retain; 0 keeps only the current one.
pub keep_checkpoints: usize,
}
impl Default for IncrementalConfig {
fn default() -> Self {
Self {
checkpoint_threshold: 10_000,
wal_sync_interval: 100,
max_wal_size: 100 * 1024 * 1024, checkpoint_codec: Codec::Gzip,
sync_on_write: false,
keep_checkpoints: 2,
}
}
}
impl IncrementalConfig {
    /// Sets the operation count that triggers `needs_checkpoint`.
    #[must_use]
    pub fn with_checkpoint_threshold(mut self, threshold: usize) -> Self {
        self.checkpoint_threshold = threshold;
        self
    }
    /// Sets how many appends may elapse between WAL fsyncs.
    #[must_use]
    pub fn with_wal_sync_interval(mut self, interval: usize) -> Self {
        self.wal_sync_interval = interval;
        self
    }
    /// Sets the WAL byte size that also triggers `needs_checkpoint`.
    #[must_use]
    pub fn with_max_wal_size(mut self, size: usize) -> Self {
        self.max_wal_size = size;
        self
    }
    /// Sets the compression codec used for checkpoint payloads.
    #[must_use]
    pub fn with_checkpoint_codec(mut self, codec: Codec) -> Self {
        self.checkpoint_codec = codec;
        self
    }
    /// Enables or disables an fsync after every WAL append.
    #[must_use]
    pub fn with_sync_on_write(mut self, sync: bool) -> Self {
        self.sync_on_write = sync;
        self
    }
    /// Sets how many checkpoint generations to retain on disk.
    #[must_use]
    pub fn with_keep_checkpoints(mut self, count: usize) -> Self {
        self.keep_checkpoints = count;
        self
    }
}
/// A single logical mutation recorded in the write-ahead log.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WalOperation {
/// Insert a document, carried inline together with its embedding.
Add(Document),
/// Remove the document with the given id.
Remove(String),
/// Drop every document from the index.
Clear,
/// Marker appended right after checkpoint `checkpoint_id` was persisted.
Checkpoint { checkpoint_id: u64 },
}
/// One framed record in the WAL: a sequence number, wall-clock timestamp,
/// the operation payload, and a CRC32 over the other three fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WalEntry {
/// Monotonic sequence number (mirrors `Manifest::wal_seq` at append time).
pub seq: u64,
/// Milliseconds since the Unix epoch when the entry was created.
pub timestamp: u64,
/// The logged mutation.
pub operation: WalOperation,
/// CRC32 of the JSON encoding of `(seq, timestamp, operation)`.
pub checksum: u32,
}
impl WalEntry {
    /// Builds an entry for `operation` at sequence `seq`, stamping the current
    /// wall-clock time (millis since epoch) and filling in the checksum.
    fn new(seq: u64, operation: WalOperation) -> Self {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_millis() as u64;
        let mut entry = Self {
            seq,
            timestamp,
            operation,
            checksum: 0,
        };
        entry.checksum = entry.compute_checksum();
        entry
    }
    /// CRC32 over the JSON encoding of `(seq, timestamp, operation)`.
    /// The `checksum` field itself is deliberately excluded.
    fn compute_checksum(&self) -> u32 {
        let data = serde_json::to_vec(&(&self.seq, &self.timestamp, &self.operation))
            .expect("WAL entry fields are always JSON-serializable");
        crc32fast::hash(&data)
    }
    /// Returns true when the stored checksum matches a fresh recomputation,
    /// i.e. the entry was not corrupted on disk.
    pub fn verify(&self) -> bool {
        self.checksum == self.compute_checksum()
    }
}
/// Top-level storage state persisted as `manifest.json` in the base directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Manifest {
/// Id of the latest durable checkpoint, if any exists yet.
pub current_checkpoint: Option<u64>,
/// Last assigned WAL sequence number.
pub wal_seq: u64,
/// Operations logged since the last checkpoint (drives `needs_checkpoint`).
pub ops_since_checkpoint: usize,
/// Document count recorded at the last checkpoint.
pub total_documents: usize,
/// Embedding dimensionality recorded at the last checkpoint.
pub embedding_dim: usize,
/// Index type label recorded at the last checkpoint.
pub index_type: String,
/// Seconds since the Unix epoch of the last logged operation.
pub last_modified: u64,
}
impl Default for Manifest {
fn default() -> Self {
Self {
current_checkpoint: None,
wal_seq: 0,
ops_since_checkpoint: 0,
total_documents: 0,
embedding_dim: 0,
index_type: String::new(),
last_modified: 0,
}
}
}
/// Sidecar metadata persisted as `checkpoint_NNNNN.meta` next to each
/// compressed checkpoint payload (`checkpoint_NNNNN.bin`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMeta {
/// Checkpoint id (monotonically increasing, starts at 1).
pub id: u64,
/// WAL sequence number captured when the checkpoint was taken; replay
/// resumes from entries strictly after this value.
pub wal_seq: u64,
/// Number of documents contained in the checkpointed index.
pub document_count: usize,
/// Embedding dimensionality of the checkpointed index.
pub embedding_dim: usize,
/// Index type label supplied by the caller.
pub index_type: String,
/// Seconds since the Unix epoch when the checkpoint was written.
pub created_at: u64,
/// Serialized payload size before compression, in bytes.
pub original_size: usize,
/// Payload size after compression, in bytes.
pub compressed_size: usize,
/// Codec used to compress the payload.
pub codec: Codec,
}
/// Append-only, length-prefix-framed writer over a single WAL segment file.
struct WalWriter {
    /// Buffered handle; `sync` flushes the buffer and fsyncs the file.
    file: BufWriter<File>,
    #[allow(dead_code)]
    path: PathBuf,
    /// Bytes written so far, seeded from file metadata on open.
    current_size: usize,
    /// When true, every `append` is immediately followed by an fsync.
    sync_on_write: bool,
}
impl WalWriter {
/// Opens (creating if absent) the WAL segment at `path` in append mode and
/// seeds the running size counter from the file's current length.
fn open(path: &Path, sync_on_write: bool) -> Result<Self> {
let file = OpenOptions::new()
.create(true)
.append(true)
.open(path)
.map_err(|e| RagError::StorageError(format!("Failed to open WAL: {}", e)))?;
// Pre-existing bytes count toward the size-based checkpoint trigger.
let current_size = file.metadata().map(|m| m.len() as usize).unwrap_or(0);
Ok(Self {
file: BufWriter::new(file),
path: path.to_path_buf(),
current_size,
sync_on_write,
})
}
/// Appends one entry framed as a 4-byte little-endian length prefix followed
/// by its JSON payload; fsyncs immediately when `sync_on_write` is set.
fn append(&mut self, entry: &WalEntry) -> Result<()> {
let data = serde_json::to_vec(entry)
.map_err(|e| RagError::StorageError(format!("WAL serialize failed: {}", e)))?;
let len = data.len() as u32;
self.file
.write_all(&len.to_le_bytes())
.map_err(|e| RagError::StorageError(format!("WAL write failed: {}", e)))?;
self.file
.write_all(&data)
.map_err(|e| RagError::StorageError(format!("WAL write failed: {}", e)))?;
// 4-byte frame header plus payload.
self.current_size += 4 + data.len();
if self.sync_on_write {
self.sync()?;
}
Ok(())
}
/// Flushes the in-memory buffer, then fsyncs file contents and metadata.
fn sync(&mut self) -> Result<()> {
self.file
.flush()
.map_err(|e| RagError::StorageError(format!("WAL sync failed: {}", e)))?;
self.file
.get_ref()
.sync_all()
.map_err(|e| RagError::StorageError(format!("WAL sync failed: {}", e)))?;
Ok(())
}
/// Bytes appended so far, including any present when the file was opened.
fn size(&self) -> usize {
self.current_size
}
}
/// Sequential reader over a WAL segment written by `WalWriter`.
struct WalReader {
file: BufReader<File>,
}
impl WalReader {
/// Opens an existing WAL segment for reading.
fn open(path: &Path) -> Result<Self> {
let file = File::open(path)
.map_err(|e| RagError::StorageError(format!("Failed to open WAL: {}", e)))?;
Ok(Self {
file: BufReader::new(file),
})
}
/// Reads all entries with `seq > from_seq`, verifying each entry's CRC.
///
/// A clean EOF at a frame boundary ends the scan normally; an EOF in the
/// middle of a frame (a torn tail left by a crashed writer) also ends the
/// scan, silently dropping the incomplete final entry. Any other I/O,
/// deserialization, or checksum failure is returned as an error.
fn read_from(&mut self, from_seq: u64) -> Result<Vec<WalEntry>> {
let mut entries = Vec::new();
let mut len_buf = [0u8; 4];
loop {
// Read the 4-byte little-endian length prefix of the next frame.
match self.file.read_exact(&mut len_buf) {
Ok(()) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(RagError::StorageError(format!("WAL read failed: {}", e))),
}
let len = u32::from_le_bytes(len_buf) as usize;
let mut data = vec![0u8; len];
// Read the frame payload; a short read here means a torn tail.
match self.file.read_exact(&mut data) {
Ok(()) => {}
Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => break,
Err(e) => return Err(RagError::StorageError(format!("WAL read failed: {}", e))),
}
let entry: WalEntry = serde_json::from_slice(&data)
.map_err(|e| RagError::StorageError(format!("WAL deserialize failed: {}", e)))?;
// Reject silently-corrupted entries rather than replaying them.
if !entry.verify() {
return Err(RagError::StorageError(format!(
"WAL entry {} failed integrity check",
entry.seq
)));
}
// Entries at or before `from_seq` are already covered by a checkpoint.
if entry.seq > from_seq {
entries.push(entry);
}
}
Ok(entries)
}
}
/// Durable incremental storage: a compressed checkpoint of the full index plus
/// a write-ahead log of operations applied since that checkpoint.
pub struct IncrementalStorage {
/// Directory holding the manifest, checkpoints, and WAL segments.
base_path: PathBuf,
config: IncrementalConfig,
/// In-memory manifest; persisted on `sync` and `checkpoint`.
manifest: Manifest,
/// Writer for the current WAL segment (always Some after construction).
wal_writer: Option<WalWriter>,
/// Appends since the last WAL fsync (drives `wal_sync_interval` batching).
ops_since_sync: usize,
}
impl IncrementalStorage {
/// Opens (or creates) the storage directory, loading `manifest.json` when
/// present and opening the WAL segment named after the current checkpoint.
///
/// # Errors
/// Returns `RagError::StorageError` if the directory cannot be created,
/// the manifest cannot be read or parsed, or the WAL cannot be opened.
pub fn new<P: AsRef<Path>>(base_path: P, config: IncrementalConfig) -> Result<Self> {
let base_path = base_path.as_ref().to_path_buf();
fs::create_dir_all(&base_path)
.map_err(|e| RagError::StorageError(format!("Failed to create storage dir: {}", e)))?;
let manifest_path = base_path.join("manifest.json");
let manifest = if manifest_path.exists() {
let data = fs::read_to_string(&manifest_path)
.map_err(|e| RagError::StorageError(format!("Failed to read manifest: {}", e)))?;
serde_json::from_str(&data)
.map_err(|e| RagError::StorageError(format!("Failed to parse manifest: {}", e)))?
} else {
Manifest::default()
};
// WAL segments are named after the checkpoint they follow; before any
// checkpoint exists the segment is wal_00000.log.
let wal_path = base_path.join(format!(
"wal_{:05}.log",
manifest.current_checkpoint.unwrap_or(0)
));
let wal_writer = WalWriter::open(&wal_path, config.sync_on_write)?;
Ok(Self {
base_path,
config,
manifest,
wal_writer: Some(wal_writer),
ops_since_sync: 0,
})
}
/// Logs a document insertion to the WAL.
pub fn log_add(&mut self, doc: &Document) -> Result<()> {
self.log_operation(WalOperation::Add(doc.clone()))
}
/// Logs a document removal to the WAL.
pub fn log_remove(&mut self, id: &str) -> Result<()> {
self.log_operation(WalOperation::Remove(id.to_string()))
}
/// Logs a full index clear to the WAL.
pub fn log_clear(&mut self) -> Result<()> {
self.log_operation(WalOperation::Clear)
}
/// Assigns the next sequence number, appends the entry to the WAL, and
/// fsyncs every `wal_sync_interval` appends. Manifest counters are updated
/// in memory only; they are persisted by `sync` or `checkpoint`.
fn log_operation(&mut self, operation: WalOperation) -> Result<()> {
self.manifest.wal_seq += 1;
self.manifest.ops_since_checkpoint += 1;
let entry = WalEntry::new(self.manifest.wal_seq, operation);
if let Some(ref mut writer) = self.wal_writer {
writer.append(&entry)?;
self.ops_since_sync += 1;
if self.ops_since_sync >= self.config.wal_sync_interval {
writer.sync()?;
self.ops_since_sync = 0;
}
}
self.manifest.last_modified = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
Ok(())
}
/// True when either the operation count or the WAL byte size has crossed
/// its configured checkpoint threshold.
pub fn needs_checkpoint(&self) -> bool {
self.manifest.ops_since_checkpoint >= self.config.checkpoint_threshold
|| self.wal_writer.as_ref().map(|w| w.size()).unwrap_or(0) >= self.config.max_wal_size
}
/// Persists a full snapshot of `index` as a new checkpoint.
///
/// Sequence: fsync the WAL, serialize + compress the index, atomically
/// write `checkpoint_NNNNN.bin` and its `.meta` sidecar, append a
/// `Checkpoint` marker to the WAL, persist the manifest, rotate to a fresh
/// WAL segment, and prune checkpoints past the retention window.
///
/// # Errors
/// Returns `RagError::StorageError` on any serialization, compression, or
/// filesystem failure.
pub fn checkpoint<T: Serialize>(
&mut self,
index: &T,
meta: IndexMetadata,
) -> Result<CheckpointMeta> {
if let Some(ref mut writer) = self.wal_writer {
writer.sync()?;
}
// Checkpoint ids start at 1 and increase by one per checkpoint.
let checkpoint_id = self.manifest.current_checkpoint.map(|c| c + 1).unwrap_or(1);
let data = serde_json::to_vec(index)
.map_err(|e| RagError::StorageError(format!("Checkpoint serialize failed: {}", e)))?;
let original_size = data.len();
let (compressed, _stats) = compression::compress_with(&data, self.config.checkpoint_codec)?;
let compressed_size = compressed.len();
let checkpoint_path = self
.base_path
.join(format!("checkpoint_{:05}.bin", checkpoint_id));
Self::write_atomic_file(&checkpoint_path, &compressed)
.map_err(|e| RagError::StorageError(format!("Failed to write checkpoint: {}", e)))?;
let checkpoint_meta = CheckpointMeta {
id: checkpoint_id,
// Replay after load starts strictly after this sequence number.
wal_seq: self.manifest.wal_seq,
document_count: meta.document_count,
embedding_dim: meta.embedding_dim,
index_type: meta.index_type.clone(),
created_at: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
original_size,
compressed_size,
codec: self.config.checkpoint_codec,
};
let meta_path = self
.base_path
.join(format!("checkpoint_{:05}.meta", checkpoint_id));
let meta_json = serde_json::to_string_pretty(&checkpoint_meta)
.map_err(|e| RagError::StorageError(format!("Failed to serialize meta: {}", e)))?;
Self::write_atomic_file(&meta_path, meta_json.as_bytes()).map_err(|e| {
RagError::StorageError(format!("Failed to write checkpoint meta: {}", e))
})?;
// Record the checkpoint in the (old) WAL segment before rotating it.
self.manifest.wal_seq += 1;
let entry = WalEntry::new(
self.manifest.wal_seq,
WalOperation::Checkpoint { checkpoint_id },
);
if let Some(ref mut writer) = self.wal_writer {
writer.append(&entry)?;
writer.sync()?;
}
self.manifest.current_checkpoint = Some(checkpoint_id);
self.manifest.ops_since_checkpoint = 0;
self.manifest.total_documents = meta.document_count;
self.manifest.embedding_dim = meta.embedding_dim;
self.manifest.index_type = meta.index_type;
self.save_manifest()?;
self.rotate_wal(checkpoint_id)?;
self.cleanup_old_checkpoints(checkpoint_id)?;
Ok(checkpoint_meta)
}
/// Loads the current checkpoint, returning the decompressed, deserialized
/// index and its metadata, or `Ok(None)` if no checkpoint exists yet.
pub fn load_checkpoint<T: for<'de> Deserialize<'de>>(
&self,
) -> Result<Option<(T, CheckpointMeta)>> {
let checkpoint_id = match self.manifest.current_checkpoint {
Some(id) => id,
None => return Ok(None),
};
let meta_path = self
.base_path
.join(format!("checkpoint_{:05}.meta", checkpoint_id));
let meta_json = fs::read_to_string(&meta_path).map_err(|e| {
RagError::StorageError(format!("Failed to read checkpoint meta: {}", e))
})?;
let meta: CheckpointMeta = serde_json::from_str(&meta_json).map_err(|e| {
RagError::StorageError(format!("Failed to parse checkpoint meta: {}", e))
})?;
let checkpoint_path = self
.base_path
.join(format!("checkpoint_{:05}.bin", checkpoint_id));
let compressed = fs::read(&checkpoint_path)
.map_err(|e| RagError::StorageError(format!("Failed to read checkpoint: {}", e)))?;
let data = compression::decompress(&compressed)?;
let index: T = serde_json::from_slice(&data)
.map_err(|e| RagError::StorageError(format!("Checkpoint deserialize failed: {}", e)))?;
Ok(Some((index, meta)))
}
/// Returns all WAL entries logged after the current checkpoint (those a
/// recovering caller must replay). Entries already covered by the
/// checkpoint's recorded `wal_seq` are filtered out by the reader.
pub fn get_wal_entries(&self) -> Result<Vec<WalEntry>> {
let checkpoint_seq = if let Some(cp_id) = self.manifest.current_checkpoint {
let meta_path = self.base_path.join(format!("checkpoint_{:05}.meta", cp_id));
if meta_path.exists() {
let meta_json = fs::read_to_string(&meta_path)
.map_err(|e| RagError::StorageError(format!("Failed to read meta: {}", e)))?;
let meta: CheckpointMeta = serde_json::from_str(&meta_json)
.map_err(|e| RagError::StorageError(format!("Failed to parse meta: {}", e)))?;
meta.wal_seq
} else {
0
}
} else {
0
};
// Same naming convention as `new`: the segment is named after the
// checkpoint it follows.
let wal_path = self.base_path.join(format!(
"wal_{:05}.log",
self.manifest.current_checkpoint.unwrap_or(0)
));
if !wal_path.exists() {
return Ok(Vec::new());
}
let mut reader = WalReader::open(&wal_path)?;
reader.read_from(checkpoint_seq)
}
/// Read-only view of the in-memory manifest.
pub fn manifest(&self) -> &Manifest {
&self.manifest
}
/// Point-in-time size and progress statistics. `checkpoint_size` is read
/// from disk metadata and excludes the `.meta` sidecar.
pub fn stats(&self) -> StorageStats {
let wal_size = self.wal_writer.as_ref().map(|w| w.size()).unwrap_or(0);
let checkpoint_size = self
.manifest
.current_checkpoint
.map(|id| {
let path = self.base_path.join(format!("checkpoint_{:05}.bin", id));
fs::metadata(&path).map(|m| m.len() as usize).unwrap_or(0)
})
.unwrap_or(0);
StorageStats {
checkpoint_id: self.manifest.current_checkpoint,
wal_size,
checkpoint_size,
total_size: wal_size + checkpoint_size,
ops_since_checkpoint: self.manifest.ops_since_checkpoint,
total_documents: self.manifest.total_documents,
}
}
/// Forces WAL durability and persists the in-memory manifest.
pub fn sync(&mut self) -> Result<()> {
if let Some(ref mut writer) = self.wal_writer {
writer.sync()?;
}
self.save_manifest()?;
Ok(())
}
/// Atomically writes the manifest as pretty-printed JSON.
fn save_manifest(&self) -> Result<()> {
let manifest_path = self.base_path.join("manifest.json");
let json = serde_json::to_string_pretty(&self.manifest)
.map_err(|e| RagError::StorageError(format!("Failed to serialize manifest: {}", e)))?;
Self::write_atomic_file(&manifest_path, json.as_bytes())
.map_err(|e| RagError::StorageError(format!("Failed to write manifest: {}", e)))?;
Ok(())
}
/// Derives a temp-file path unique per process and per call, so concurrent
/// atomic writes to the same target never share a temp file.
fn atomic_tmp_path(path: &Path) -> PathBuf {
let file_name = path.file_name().and_then(|f| f.to_str()).unwrap_or("file");
let counter = ATOMIC_WRITE_COUNTER.fetch_add(1, Ordering::Relaxed);
path.with_file_name(format!(
"{}.{}.{}.tmp",
file_name,
std::process::id(),
counter
))
}
/// Writes `data` to a temp file, fsyncs it, then renames over `path` so
/// readers never observe a partially-written file. On rename failure the
/// temp file is cleaned up best-effort.
///
/// NOTE(review): the parent directory is not fsynced after the rename, so
/// the rename itself may not be durable across a crash on some
/// filesystems — confirm whether that guarantee is required here.
fn write_atomic_file(path: &Path, data: &[u8]) -> std::io::Result<()> {
let tmp_path = Self::atomic_tmp_path(path);
{
let mut file = File::create(&tmp_path)?;
file.write_all(data)?;
file.sync_all()?;
}
fs::rename(&tmp_path, path).inspect_err(|_| {
let _ = fs::remove_file(&tmp_path);
})?;
Ok(())
}
/// Switches to a fresh WAL segment named after `checkpoint_id` and deletes
/// the previous segment (its contents are covered by the checkpoint).
fn rotate_wal(&mut self, checkpoint_id: u64) -> Result<()> {
if let Some(ref mut writer) = self.wal_writer {
writer.sync()?;
}
let new_wal_path = self.base_path.join(format!("wal_{:05}.log", checkpoint_id));
let new_writer = WalWriter::open(&new_wal_path, self.config.sync_on_write)?;
self.wal_writer = Some(new_writer);
let old_wal = self
.base_path
.join(format!("wal_{:05}.log", checkpoint_id.saturating_sub(1)));
if old_wal.exists() && old_wal != new_wal_path {
let _ = fs::remove_file(&old_wal);
}
Ok(())
}
/// Deletes checkpoint files (`.bin` and `.meta`) whose id falls outside the
/// retention window. `keep_checkpoints == 0` is treated as "keep only the
/// current checkpoint". Individual delete failures are ignored.
fn cleanup_old_checkpoints(&self, current_id: u64) -> Result<()> {
let cutoff = if self.config.keep_checkpoints == 0 {
current_id.saturating_sub(1)
} else {
current_id.saturating_sub(self.config.keep_checkpoints as u64)
};
for entry in fs::read_dir(&self.base_path)
.map_err(|e| RagError::StorageError(format!("Failed to read dir: {}", e)))?
{
let entry =
entry.map_err(|e| RagError::StorageError(format!("Dir entry error: {}", e)))?;
let name = entry.file_name().to_string_lossy().to_string();
if name.starts_with("checkpoint_") {
// Parse the numeric id out of "checkpoint_NNNNN.bin" / ".meta".
if let Some(id_str) = name
.strip_prefix("checkpoint_")
.and_then(|s| s.split('.').next())
{
if let Ok(id) = id_str.parse::<u64>() {
if id <= cutoff {
let _ = fs::remove_file(entry.path());
}
}
}
}
}
Ok(())
}
}
/// Caller-supplied description of the index being checkpointed.
#[derive(Debug, Clone)]
pub struct IndexMetadata {
/// Number of documents in the index at checkpoint time.
pub document_count: usize,
/// Embedding dimensionality of the index.
pub embedding_dim: usize,
/// Free-form label identifying the index implementation.
pub index_type: String,
}
/// Snapshot of on-disk sizes and progress counters, produced by
/// `IncrementalStorage::stats`.
#[derive(Debug, Clone)]
pub struct StorageStats {
/// Current checkpoint id, if one exists.
pub checkpoint_id: Option<u64>,
/// Bytes in the active WAL segment.
pub wal_size: usize,
/// Bytes of the current compressed checkpoint payload.
pub checkpoint_size: usize,
/// `wal_size + checkpoint_size` (excludes manifest and `.meta` files).
pub total_size: usize,
/// Operations logged since the last checkpoint.
pub ops_since_checkpoint: usize,
/// Document count recorded at the last checkpoint.
pub total_documents: usize,
}
/// Read-only helper that replays post-checkpoint WAL entries during recovery.
pub struct RecoveryHelper<'a> {
storage: &'a IncrementalStorage,
}
impl<'a> RecoveryHelper<'a> {
pub fn new(storage: &'a IncrementalStorage) -> Self {
Self { storage }
}
pub fn replay_wal<F>(&self, mut apply_op: F) -> Result<usize>
where
F: FnMut(&WalOperation) -> Result<()>,
{
let entries = self.storage.get_wal_entries()?;
let count = entries.len();
for entry in entries {
match &entry.operation {
WalOperation::Checkpoint { .. } => {
continue;
}
op => apply_op(op)?,
}
}
Ok(count)
}
}
#[cfg(test)]
mod tests {
//! Unit tests covering config building, WAL integrity and logging,
//! checkpoint round-trips, retention, crash recovery, and atomic-write
//! hygiene. All tests run against a throwaway `TempDir`.
use super::*;
use tempfile::TempDir;
// Builds a minimal document with a constant embedding of the given dimension.
fn create_test_document(id: &str, dim: usize) -> Document {
Document {
id: id.to_string(),
content: format!("Content for {}", id),
embedding: vec![0.1; dim],
metadata: None,
}
}
// Builder methods should each set exactly their own field.
#[test]
fn test_config_builder() {
let config = IncrementalConfig::default()
.with_checkpoint_threshold(5000)
.with_wal_sync_interval(50)
.with_max_wal_size(50 * 1024 * 1024)
.with_sync_on_write(true)
.with_keep_checkpoints(3);
assert_eq!(config.checkpoint_threshold, 5000);
assert_eq!(config.wal_sync_interval, 50);
assert_eq!(config.max_wal_size, 50 * 1024 * 1024);
assert!(config.sync_on_write);
assert_eq!(config.keep_checkpoints, 3);
}
// Tampering with any checksummed field must invalidate the entry.
#[test]
fn test_wal_entry_integrity() {
let entry = WalEntry::new(1, WalOperation::Add(create_test_document("doc1", 128)));
assert!(entry.verify());
let mut tampered = entry.clone();
tampered.seq = 999;
assert!(!tampered.verify());
}
// A fresh directory starts with an empty default manifest.
#[test]
fn test_storage_creation() {
let dir = TempDir::new().unwrap();
let storage = IncrementalStorage::new(dir.path(), IncrementalConfig::default()).unwrap();
assert!(storage.manifest().current_checkpoint.is_none());
assert_eq!(storage.manifest().wal_seq, 0);
}
// Logged operations must come back in order with the right payloads.
#[test]
fn test_wal_logging() {
let dir = TempDir::new().unwrap();
let mut storage =
IncrementalStorage::new(dir.path(), IncrementalConfig::default()).unwrap();
storage.log_add(&create_test_document("doc1", 128)).unwrap();
storage.log_add(&create_test_document("doc2", 128)).unwrap();
storage.log_remove("doc1").unwrap();
assert_eq!(storage.manifest().wal_seq, 3);
assert_eq!(storage.manifest().ops_since_checkpoint, 3);
storage.sync().unwrap();
let entries = storage.get_wal_entries().unwrap();
assert_eq!(entries.len(), 3);
match &entries[0].operation {
WalOperation::Add(doc) => assert_eq!(doc.id, "doc1"),
_ => panic!("Expected Add operation"),
}
match &entries[2].operation {
WalOperation::Remove(id) => assert_eq!(id, "doc1"),
_ => panic!("Expected Remove operation"),
}
}
// A checkpoint snapshots state; only post-checkpoint ops remain in the WAL.
#[test]
fn test_checkpoint_and_recovery() {
let dir = TempDir::new().unwrap();
let mut storage = IncrementalStorage::new(
dir.path(),
IncrementalConfig::default().with_checkpoint_threshold(100),
)
.unwrap();
let test_data: Vec<String> =
vec!["doc1".to_string(), "doc2".to_string(), "doc3".to_string()];
for id in &test_data {
storage.log_add(&create_test_document(id, 128)).unwrap();
}
let meta = storage
.checkpoint(
&test_data,
IndexMetadata {
document_count: 3,
embedding_dim: 128,
index_type: "test".to_string(),
},
)
.unwrap();
assert_eq!(meta.id, 1);
assert_eq!(meta.document_count, 3);
storage.log_add(&create_test_document("doc4", 128)).unwrap();
storage.sync().unwrap();
let (loaded_data, loaded_meta): (Vec<String>, CheckpointMeta) =
storage.load_checkpoint().unwrap().unwrap();
assert_eq!(loaded_data, test_data);
assert_eq!(loaded_meta.id, 1);
// Only doc4, logged after the checkpoint, should need replay.
let entries = storage.get_wal_entries().unwrap();
assert_eq!(entries.len(), 1);
match &entries[0].operation {
WalOperation::Add(doc) => assert_eq!(doc.id, "doc4"),
_ => panic!("Expected Add operation"),
}
}
// The op-count threshold flips needs_checkpoint exactly at the limit.
#[test]
fn test_needs_checkpoint() {
let dir = TempDir::new().unwrap();
let mut storage = IncrementalStorage::new(
dir.path(),
IncrementalConfig::default().with_checkpoint_threshold(5),
)
.unwrap();
for i in 0..4 {
storage
.log_add(&create_test_document(&format!("doc{}", i), 128))
.unwrap();
}
assert!(!storage.needs_checkpoint());
storage.log_add(&create_test_document("doc5", 128)).unwrap();
assert!(storage.needs_checkpoint());
}
// Stats reflect WAL growth before any checkpoint exists.
#[test]
fn test_storage_stats() {
let dir = TempDir::new().unwrap();
let mut storage =
IncrementalStorage::new(dir.path(), IncrementalConfig::default()).unwrap();
for i in 0..10 {
storage
.log_add(&create_test_document(&format!("doc{}", i), 128))
.unwrap();
}
storage.sync().unwrap();
let stats = storage.stats();
assert!(stats.wal_size > 0);
assert_eq!(stats.ops_since_checkpoint, 10);
assert!(stats.checkpoint_id.is_none());
}
// replay_wal dispatches each non-marker operation to the callback.
#[test]
fn test_recovery_helper() {
let dir = TempDir::new().unwrap();
let mut storage =
IncrementalStorage::new(dir.path(), IncrementalConfig::default()).unwrap();
storage.log_add(&create_test_document("doc1", 128)).unwrap();
storage.log_add(&create_test_document("doc2", 128)).unwrap();
storage.log_remove("doc1").unwrap();
storage.sync().unwrap();
let helper = RecoveryHelper::new(&storage);
let mut adds = 0;
let mut removes = 0;
helper
.replay_wal(|op| {
match op {
WalOperation::Add(_) => adds += 1,
WalOperation::Remove(_) => removes += 1,
_ => {}
}
Ok(())
})
.unwrap();
assert_eq!(adds, 2);
assert_eq!(removes, 1);
}
// Manifest and WAL must survive a drop-and-reopen cycle.
#[test]
fn test_persistence_across_reopens() {
let dir = TempDir::new().unwrap();
{
let mut storage =
IncrementalStorage::new(dir.path(), IncrementalConfig::default()).unwrap();
storage.log_add(&create_test_document("doc1", 128)).unwrap();
storage.log_add(&create_test_document("doc2", 128)).unwrap();
storage.sync().unwrap();
}
{
let storage =
IncrementalStorage::new(dir.path(), IncrementalConfig::default()).unwrap();
assert_eq!(storage.manifest().wal_seq, 2);
let entries = storage.get_wal_entries().unwrap();
assert_eq!(entries.len(), 2);
}
}
// keep_checkpoints == 0 means "retain only the current checkpoint".
#[test]
fn test_keep_checkpoints_zero_prunes_old_checkpoints() {
let dir = TempDir::new().unwrap();
let mut storage = IncrementalStorage::new(
dir.path(),
IncrementalConfig::default().with_keep_checkpoints(0),
)
.unwrap();
let data = vec!["doc".to_string()];
for checkpoint_no in 0..3 {
storage
.checkpoint(
&data,
IndexMetadata {
document_count: 1,
embedding_dim: 128,
index_type: format!("test_{}", checkpoint_no),
},
)
.unwrap();
}
let checkpoint_bins: Vec<_> = fs::read_dir(dir.path())
.unwrap()
.filter_map(|e| e.ok())
.map(|e| e.file_name().to_string_lossy().to_string())
.filter(|name| name.starts_with("checkpoint_") && name.ends_with(".bin"))
.collect();
assert_eq!(
checkpoint_bins.len(),
1,
"keep_checkpoints=0 should keep only current checkpoint"
);
}
// Retention keeps exactly keep_checkpoints generations after many cycles.
#[test]
fn test_keep_checkpoints_exact_retention_count() {
let dir = TempDir::new().unwrap();
let mut storage = IncrementalStorage::new(
dir.path(),
IncrementalConfig::default().with_keep_checkpoints(2),
)
.unwrap();
let data = vec!["doc".to_string()];
for checkpoint_no in 0..5 {
storage
.checkpoint(
&data,
IndexMetadata {
document_count: 1,
embedding_dim: 128,
index_type: format!("test_{}", checkpoint_no),
},
)
.unwrap();
}
let checkpoint_bins: Vec<_> = fs::read_dir(dir.path())
.unwrap()
.filter_map(|e| e.ok())
.map(|e| e.file_name().to_string_lossy().to_string())
.filter(|name| name.starts_with("checkpoint_") && name.ends_with(".bin"))
.collect();
assert_eq!(
checkpoint_bins.len(),
2,
"retention should keep exactly keep_checkpoints checkpoint files"
);
}
// Document metadata (arbitrary JSON) must survive a WAL round trip.
#[test]
fn test_wal_roundtrip_with_metadata() {
let dir = TempDir::new().unwrap();
let mut storage =
IncrementalStorage::new(dir.path(), IncrementalConfig::default()).unwrap();
let doc = Document {
id: "meta-doc".to_string(),
content: "has metadata".to_string(),
embedding: vec![0.1; 4],
metadata: Some(serde_json::json!({
"scope": "workspace",
"tags": ["rust", "ai"],
"priority": 5
})),
};
storage.log_add(&doc).unwrap();
storage.sync().unwrap();
let entries = storage.get_wal_entries().unwrap();
assert_eq!(entries.len(), 1);
match &entries[0].operation {
WalOperation::Add(recovered) => {
assert_eq!(recovered.id, "meta-doc");
let meta = recovered.metadata.as_ref().unwrap();
assert_eq!(meta["scope"], "workspace");
assert_eq!(meta["priority"], 5);
assert_eq!(meta["tags"][0], "rust");
}
_ => panic!("Expected Add operation"),
}
}
// Document metadata must also survive a checkpoint round trip.
#[test]
fn test_checkpoint_roundtrip_with_metadata() {
let dir = TempDir::new().unwrap();
let mut storage =
IncrementalStorage::new(dir.path(), IncrementalConfig::default()).unwrap();
let docs = vec![
Document {
id: "d1".to_string(),
content: "first".to_string(),
embedding: vec![1.0, 0.0],
metadata: Some(serde_json::json!({"lang": "rust"})),
},
Document {
id: "d2".to_string(),
content: "second".to_string(),
embedding: vec![0.0, 1.0],
metadata: None,
},
];
storage
.checkpoint(
&docs,
IndexMetadata {
document_count: 2,
embedding_dim: 2,
index_type: "hnsw".to_string(),
},
)
.unwrap();
let (loaded, meta): (Vec<Document>, CheckpointMeta) =
storage.load_checkpoint().unwrap().unwrap();
assert_eq!(meta.document_count, 2);
assert_eq!(loaded.len(), 2);
assert_eq!(loaded[0].id, "d1");
assert_eq!(loaded[0].metadata.as_ref().unwrap()["lang"], "rust");
assert_eq!(loaded[1].id, "d2");
assert!(loaded[1].metadata.is_none());
}
// Simulates a crash mid-append: a half-written final frame is silently
// dropped during recovery instead of failing the whole read.
#[test]
fn test_recovery_ignores_truncated_tail_entry() {
let dir = TempDir::new().unwrap();
{
let mut storage =
IncrementalStorage::new(dir.path(), IncrementalConfig::default()).unwrap();
storage.log_add(&create_test_document("doc1", 4)).unwrap();
storage.sync().unwrap();
}
let wal_path = dir.path().join("wal_00000.log");
let torn_entry = WalEntry::new(2, WalOperation::Add(create_test_document("doc2", 4)));
let torn_payload = serde_json::to_vec(&torn_entry).unwrap();
let torn_len = torn_payload.len() as u32;
// Write a full length prefix but only half the payload.
let mut file = OpenOptions::new().append(true).open(&wal_path).unwrap();
file.write_all(&torn_len.to_le_bytes()).unwrap();
file.write_all(&torn_payload[..torn_payload.len() / 2])
.unwrap();
file.sync_all().unwrap();
let storage = IncrementalStorage::new(dir.path(), IncrementalConfig::default()).unwrap();
let entries = storage.get_wal_entries().unwrap();
assert_eq!(entries.len(), 1);
match &entries[0].operation {
WalOperation::Add(doc) => assert_eq!(doc.id, "doc1"),
_ => panic!("expected Add operation"),
}
}
// Every atomic write must either rename or clean up its temp file.
#[test]
fn test_atomic_writes_leave_no_tmp_files_after_repeated_checkpoints() {
let dir = TempDir::new().unwrap();
let mut storage =
IncrementalStorage::new(dir.path(), IncrementalConfig::default()).unwrap();
for i in 0..8 {
let payload = vec![format!("doc-{i}")];
storage
.checkpoint(
&payload,
IndexMetadata {
document_count: 1,
embedding_dim: 128,
index_type: "hnsw".to_string(),
},
)
.unwrap();
}
let has_tmp = fs::read_dir(dir.path())
.unwrap()
.filter_map(|e| e.ok())
.map(|e| e.file_name().to_string_lossy().to_string())
.any(|name| name.ends_with(".tmp"));
assert!(!has_tmp);
}
}