#![cfg(not(target_arch = "wasm32"))]
use crate::storage::compression::{self, Codec, CompressionStats};
use crate::{Document, RagError, Result};
use serde::{Deserialize, Serialize};
use std::fs::{self, File};
use std::io::{Read, Write};
use std::path::{Component, Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
const STORAGE_VERSION: u32 = 2;
const DATA_EXTENSION: &str = "data";
const META_EXTENSION: &str = "meta";
const TMP_EXTENSION: &str = "tmp";
static TMP_FILE_COUNTER: AtomicU64 = AtomicU64::new(0);
/// On-disk metadata written alongside each item's payload (the `.meta` file).
///
/// Tracks the storage format version, creation/update timestamps (seconds
/// since the UNIX epoch), a logical item type tag (e.g. "document",
/// "flat_index"), the codec used, and byte sizes before/after compression.
/// `compressed_size` doubles as a cheap integrity check when loading.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageMetadata {
/// Format version; loads refuse files whose version != STORAGE_VERSION.
pub version: u32,
/// Unix seconds when the item was first saved; preserved across overwrites.
pub created_at: u64,
/// Unix seconds of the most recent save.
pub updated_at: u64,
/// Logical kind of the stored item ("document", "flat_index", "hnsw_index").
pub item_type: String,
/// Codec the payload was written with.
pub compression: Codec,
/// Serialized (pre-compression) payload size in bytes.
pub original_size: usize,
/// On-disk payload size in bytes; checked against the file on load.
pub compressed_size: usize,
}
impl StorageMetadata {
    /// Current wall-clock time as seconds since the UNIX epoch.
    ///
    /// This was previously duplicated (with a bare `unwrap()`) in both `new`
    /// and `touch`; centralizing it keeps the two call sites in sync. The
    /// `expect` can only fire if the system clock is set before 1970, which
    /// is a broken environment rather than a recoverable error.
    fn unix_now() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the UNIX epoch")
            .as_secs()
    }
    /// Builds metadata for a newly stored item.
    ///
    /// `created_at` and `updated_at` are both set to the current time, and
    /// `version` is stamped with the writer's STORAGE_VERSION.
    fn new(
        item_type: String,
        compression: Codec,
        original_size: usize,
        compressed_size: usize,
    ) -> Self {
        let now = Self::unix_now();
        Self {
            version: STORAGE_VERSION,
            created_at: now,
            updated_at: now,
            item_type,
            compression,
            original_size,
            compressed_size,
        }
    }
    /// Refreshes `updated_at` to the current time; `created_at` is untouched.
    fn touch(&mut self) {
        self.updated_at = Self::unix_now();
    }
}
/// Filesystem-backed store: each item named `name` lives as a pair of files
/// under `base_path` — `<name>.data` (possibly compressed payload) and
/// `<name>.meta` (JSON `StorageMetadata`).
#[derive(Debug)]
pub struct FileStorage {
/// Directory holding every data/meta file; flat, no subdirectories.
base_path: PathBuf,
/// Codec applied when writing new payloads (reads honor the stored codec).
codec: Codec,
}
impl FileStorage {
/// Creates a store rooted at `base_path` with compression disabled.
pub fn new(base_path: impl AsRef<Path>) -> Result<Self> {
Self::with_codec(base_path, Codec::None)
}
/// Creates a store rooted at `base_path`, writing payloads with `codec`.
///
/// The directory is created (recursively) when missing. Fails if the path
/// cannot be created or exists but is not a directory.
pub fn with_codec(base_path: impl AsRef<Path>, codec: Codec) -> Result<Self> {
    let root = base_path.as_ref().to_path_buf();
    if !root.exists() {
        fs::create_dir_all(&root).map_err(|e| {
            RagError::StorageError(format!("Failed to create storage directory: {}", e))
        })?;
    }
    if root.is_dir() {
        Ok(Self {
            base_path: root,
            codec,
        })
    } else {
        Err(RagError::StorageError(format!(
            "Storage path is not a directory: {}",
            root.display()
        )))
    }
}
/// Persists `document` under `id`, compressing with the configured codec.
///
/// The previous body was a line-for-line duplicate of `save_with_metadata`
/// with the item type hard-coded to "document"; it now delegates so the
/// serialize/compress/atomic-write sequence exists in exactly one place.
///
/// # Errors
/// `RagError::StorageError` on an invalid name, or on serialization,
/// compression, or I/O failure.
pub fn save_document(&self, id: &str, document: &Document) -> Result<CompressionStats> {
    Self::validate_item_name(id)?;
    self.save_with_metadata(id, document, "document")
}
/// Loads and decompresses the document stored under `id`.
///
/// Keeps the historical "Document not found" error text for missing items,
/// then delegates to `load_with_metadata` — whose version check, size
/// check, decompression, and JSON decode (with identical error messages)
/// were previously duplicated here verbatim.
///
/// # Errors
/// `RagError::StorageError` on an invalid name, a missing item, a storage
/// version mismatch, a size mismatch (corruption), or decode failure.
pub fn load_document(&self, id: &str) -> Result<Document> {
    Self::validate_item_name(id)?;
    if !self.exists(id) {
        return Err(RagError::StorageError(format!(
            "Document not found: {}",
            id
        )));
    }
    self.load_with_metadata(id)
}
/// Persists a flat-index snapshot under `name` with item type "flat_index".
pub fn save_flat_index(
&self,
name: &str,
index: &FlatIndexWrapper,
) -> Result<CompressionStats> {
Self::validate_item_name(name)?;
self.save_with_metadata(name, index, "flat_index")
}
/// Loads a previously saved flat-index snapshot from `name`.
pub fn load_flat_index(&self, name: &str) -> Result<FlatIndexWrapper> {
Self::validate_item_name(name)?;
self.load_with_metadata(name)
}
/// Persists an HNSW-index snapshot under `name` with item type "hnsw_index".
pub fn save_hnsw_index(
&self,
name: &str,
index: &HNSWIndexWrapper,
) -> Result<CompressionStats> {
Self::validate_item_name(name)?;
self.save_with_metadata(name, index, "hnsw_index")
}
/// Loads a previously saved HNSW-index snapshot from `name`.
pub fn load_hnsw_index(&self, name: &str) -> Result<HNSWIndexWrapper> {
Self::validate_item_name(name)?;
self.load_with_metadata(name)
}
/// Removes the data and metadata files for `name`.
///
/// Files that are already absent are skipped, so deleting a nonexistent
/// item succeeds (idempotent delete).
pub fn delete(&self, name: &str) -> Result<()> {
    Self::validate_item_name(name)?;
    let targets = [
        (self.item_path(name), "data"),
        (self.metadata_path(name), "metadata"),
    ];
    for (path, kind) in targets {
        if path.exists() {
            fs::remove_file(&path).map_err(|e| {
                RagError::StorageError(format!("Failed to delete {} file: {}", kind, e))
            })?;
        }
    }
    Ok(())
}
/// Lists every item name in the store, sorted ascending.
///
/// An item is any file stem carrying a `.data` or `.meta` extension; stems
/// are deduplicated via a set so a complete pair still yields one name.
pub fn list(&self) -> Result<Vec<String>> {
    let entries = fs::read_dir(&self.base_path).map_err(|e| {
        RagError::StorageError(format!("Failed to read storage directory: {}", e))
    })?;
    let mut names = std::collections::HashSet::new();
    for entry in entries {
        let path = entry
            .map_err(|e| {
                RagError::StorageError(format!("Failed to read directory entry: {}", e))
            })?
            .path();
        if !path.is_file() {
            continue;
        }
        let is_store_file = path
            .extension()
            .map(|ext| ext == DATA_EXTENSION || ext == META_EXTENSION)
            .unwrap_or(false);
        if !is_store_file {
            continue;
        }
        if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
            names.insert(stem.to_string());
        }
    }
    let mut result: Vec<String> = names.into_iter().collect();
    result.sort();
    Ok(result)
}
pub fn get_metadata(&self, name: &str) -> Result<StorageMetadata> {
Self::validate_item_name(name)?;
let meta_path = self.metadata_path(name);
if !meta_path.exists() {
return Err(RagError::StorageError(format!(
"Metadata not found for item: {}",
name
)));
}
let mut file = File::open(&meta_path)?;
let mut contents = Vec::new();
file.read_to_end(&mut contents)?;
let metadata: StorageMetadata = serde_json::from_slice(&contents)
.or_else(|_| bincode::deserialize::<StorageMetadata>(&contents))
.map_err(|e| RagError::StorageError(format!("metadata deserialize failed: {}", e)))?;
Ok(metadata)
}
/// Sums the sizes (in bytes) of all regular files directly under the
/// storage directory — data, metadata, and any stray files alike.
pub fn total_size(&self) -> Result<u64> {
    let entries = fs::read_dir(&self.base_path).map_err(|e| {
        RagError::StorageError(format!("Failed to read storage directory: {}", e))
    })?;
    let mut total = 0u64;
    for entry in entries {
        let file_meta = entry
            .map_err(|e| {
                RagError::StorageError(format!("Failed to read directory entry: {}", e))
            })?
            .metadata()?;
        if file_meta.is_file() {
            total += file_meta.len();
        }
    }
    Ok(total)
}
/// Deletes every regular file directly under the storage directory.
/// Subdirectories (this store never creates any) are left untouched.
pub fn clear(&self) -> Result<()> {
    let entries = fs::read_dir(&self.base_path).map_err(|e| {
        RagError::StorageError(format!("Failed to read storage directory: {}", e))
    })?;
    for entry in entries {
        let path = entry
            .map_err(|e| {
                RagError::StorageError(format!("Failed to read directory entry: {}", e))
            })?
            .path();
        if !path.is_file() {
            continue;
        }
        fs::remove_file(&path)
            .map_err(|e| RagError::StorageError(format!("Failed to delete file: {}", e)))?;
    }
    Ok(())
}
/// True only when `name` is valid AND both the data file and its metadata
/// file are present; a half-written pair reports as absent.
pub fn exists(&self, name: &str) -> bool {
    !Self::is_invalid_item_name(name)
        && self.item_path(name).exists()
        && self.metadata_path(name).exists()
}
/// Returns `true` when `name` must not be used as an item name.
///
/// Rejected: empty names, embedded NULs, any path separator, absolute
/// paths, anything that is not exactly one normal path component (so `.`,
/// `..`, and multi-segment paths fail), and Windows reserved device names
/// (checked case-insensitively against the part before the first dot, so
/// `CON.txt` is rejected too — Windows treats it as the device).
fn is_invalid_item_name(name: &str) -> bool {
    // Windows reserved device names; invalid regardless of host OS so that
    // a store written on Unix stays portable.
    const RESERVED: [&str; 22] = [
        "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7",
        "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
    ];
    if name.is_empty() || name.contains('\0') || name.contains('/') || name.contains('\\') {
        return true;
    }
    let path = Path::new(name);
    if path.is_absolute() {
        return true;
    }
    // Exactly one Normal component: no `.`/`..`, no prefixes, no nesting.
    let mut parts = path.components();
    let single_segment =
        matches!(parts.next(), Some(Component::Normal(_))) && parts.next().is_none();
    if !single_segment {
        return true;
    }
    let stem = name.split('.').next().unwrap_or(name);
    RESERVED.iter().any(|r| stem.eq_ignore_ascii_case(r))
}
/// Converts the boolean name check into a `Result`, producing the standard
/// invalid-name storage error for rejected names.
fn validate_item_name(name: &str) -> Result<()> {
    if Self::is_invalid_item_name(name) {
        Err(RagError::StorageError(format!(
            "Invalid item name: '{}'. Names must be a single path segment",
            name
        )))
    } else {
        Ok(())
    }
}
/// Path of the payload file for `name`: `<base_path>/<name>.data`.
fn item_path(&self, name: &str) -> PathBuf {
self.base_path.join(format!("{}.{}", name, DATA_EXTENSION))
}
/// Path of the metadata file for `name`: `<base_path>/<name>.meta`.
fn metadata_path(&self, name: &str) -> PathBuf {
self.base_path.join(format!("{}.{}", name, META_EXTENSION))
}
/// Writes `data` to `path` atomically: write + fsync a uniquely named temp
/// file in the same directory, then rename it over the target.
///
/// The PID + process-wide atomic counter embedded in the temp name keeps
/// concurrent writers — even across processes — from clobbering each
/// other's in-flight temp files.
///
/// Fix: the temp file is now removed on *any* failure (create, write, or
/// fsync), not only when the rename fails; previously a failed write
/// leaked a `.tmp` file in the storage directory.
fn write_atomic(&self, path: &Path, data: &[u8]) -> Result<()> {
    let filename = path.file_name().and_then(|f| f.to_str()).unwrap_or("item");
    let counter = TMP_FILE_COUNTER.fetch_add(1, Ordering::Relaxed);
    let tmp_path = path.with_file_name(format!(
        "{}.{}.{}.{}",
        filename,
        std::process::id(),
        counter,
        TMP_EXTENSION
    ));
    // Run the fallible write phase as one unit so a single cleanup path can
    // handle every failure point.
    let write_result: std::io::Result<()> = (|| {
        let mut file = File::create(&tmp_path)?;
        file.write_all(data)?;
        file.sync_all()
    })();
    if let Err(e) = write_result {
        // Best-effort cleanup; the original I/O error is what gets reported.
        let _ = fs::remove_file(&tmp_path);
        return Err(RagError::IoError(e));
    }
    fs::rename(&tmp_path, path).map_err(|e| {
        let _ = fs::remove_file(&tmp_path);
        RagError::IoError(e)
    })?;
    Ok(())
}
/// Shared save path: serialize `item` to JSON, compress with the configured
/// codec, then atomically write `<name>.data` followed by `<name>.meta`.
///
/// Overwriting an existing item updates its metadata in place (preserving
/// `created_at`); otherwise fresh metadata tagged `item_type` is created.
/// Callers are expected to have validated `name` already.
///
/// NOTE(review): data is written before metadata, so a crash between the
/// two writes can leave stale metadata beside new data; the compressed-size
/// check in `load_with_metadata` would then surface this as corruption.
fn save_with_metadata<T: Serialize>(
&self,
name: &str,
item: &T,
item_type: &str,
) -> Result<CompressionStats> {
let serialized = serde_json::to_vec(item)
.map_err(|e| RagError::StorageError(format!("JSON serialization failed: {}", e)))?;
let (compressed, stats) = compression::compress_with(&serialized, self.codec)
.map_err(|e| RagError::StorageError(format!("Compression failed: {}", e)))?;
// Preserve created_at on overwrite; only sizes, codec, and updated_at change.
let metadata = if self.exists(name) {
let mut meta = self.get_metadata(name)?;
meta.touch();
meta.original_size = stats.original_size;
meta.compressed_size = stats.compressed_size;
meta.compression = stats.codec;
meta
} else {
StorageMetadata::new(
item_type.to_string(),
stats.codec,
stats.original_size,
stats.compressed_size,
)
};
let data_path = self.item_path(name);
self.write_atomic(&data_path, &compressed)?;
let meta_path = self.metadata_path(name);
let meta_bytes = serde_json::to_vec(&metadata)
.map_err(|e| RagError::StorageError(format!("metadata serialize failed: {}", e)))?;
self.write_atomic(&meta_path, &meta_bytes)?;
Ok(stats)
}
/// Shared load path: verify the item exists and its stored version matches,
/// read the payload, check its length against the recorded compressed size,
/// decompress, and decode the JSON into `T`.
///
/// Callers are expected to have validated `name` already.
fn load_with_metadata<T: for<'de> Deserialize<'de>>(&self, name: &str) -> Result<T> {
if !self.exists(name) {
return Err(RagError::StorageError(format!("Item not found: {}", name)));
}
let metadata = self.get_metadata(name)?;
// Refuse payloads written under a different storage format version.
if metadata.version != STORAGE_VERSION {
return Err(RagError::StorageError(format!(
"Incompatible storage version: expected {}, got {}",
STORAGE_VERSION, metadata.version
)));
}
let data_path = self.item_path(name);
let mut file = File::open(&data_path)?;
let mut compressed = Vec::new();
file.read_to_end(&mut compressed)?;
// Cheap integrity check: on-disk length must match what was recorded at
// save time (catches truncation and data/meta mismatches).
if compressed.len() != metadata.compressed_size {
return Err(RagError::StorageError(format!(
"Data corruption detected: size mismatch for {}",
name
)));
}
let decompressed = compression::decompress(&compressed)
.map_err(|e| RagError::StorageError(format!("Decompression failed: {}", e)))?;
let item: T = serde_json::from_slice::<T>(&decompressed)
.map_err(|e| RagError::StorageError(format!("JSON deserialization failed: {}", e)))?;
Ok(item)
}
}
/// Serializable snapshot of a `FlatIndex`: the embedding dimension plus
/// every stored document, from which an equivalent index can be rebuilt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlatIndexWrapper {
pub embedding_dim: usize,
pub documents: Vec<Document>,
}
impl FlatIndexWrapper {
/// Captures the index's dimension and full document set for persistence.
pub fn from_index(index: &crate::index::FlatIndex) -> Self {
Self {
embedding_dim: index.embedding_dim(),
documents: index.get_all_documents(),
}
}
/// Rebuilds a `FlatIndex` by re-inserting every stored document in bulk.
pub fn to_index(&self) -> Result<crate::index::FlatIndex> {
let mut index = crate::index::FlatIndex::new(self.embedding_dim);
index.add_batch(self.documents.clone())?;
Ok(index)
}
}
/// Serializable snapshot of an `HNSWIndex`: embedding dimension, documents,
/// and the tuning parameters needed to rebuild an equivalent graph.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HNSWIndexWrapper {
pub embedding_dim: usize,
pub documents: Vec<Document>,
pub config: HNSWConfigWrapper,
}
/// Persisted subset of `HNSWConfig`.
///
/// The `serde(default …)` attributes keep files written before the
/// heuristic-related fields existed deserializable: missing fields fall
/// back to the defaults below instead of failing the load.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HNSWConfigWrapper {
pub m: usize,
pub m0: usize,
pub ef_construction: usize,
pub ef_search: usize,
pub ml: f32,
#[serde(default = "default_use_heuristic")]
pub use_heuristic: bool,
#[serde(default)]
pub extend_candidates: bool,
#[serde(default = "default_keep_pruned")]
pub keep_pruned_connections: bool,
}
/// Serde default for `use_heuristic` when loading files that predate the field.
fn default_use_heuristic() -> bool {
true
}
/// Serde default for `keep_pruned_connections` on files that predate the field.
fn default_keep_pruned() -> bool {
true
}
/// Field-for-field capture of the live config for persistence.
impl From<&crate::index::HNSWConfig> for HNSWConfigWrapper {
fn from(config: &crate::index::HNSWConfig) -> Self {
Self {
m: config.m,
m0: config.m0,
ef_construction: config.ef_construction,
ef_search: config.ef_search,
ml: config.ml,
use_heuristic: config.use_heuristic,
extend_candidates: config.extend_candidates,
keep_pruned_connections: config.keep_pruned_connections,
}
}
}
/// Restores a live config from the persisted subset.
impl From<HNSWConfigWrapper> for crate::index::HNSWConfig {
fn from(wrapper: HNSWConfigWrapper) -> Self {
Self {
m: wrapper.m,
m0: wrapper.m0,
ef_construction: wrapper.ef_construction,
ef_search: wrapper.ef_search,
ml: wrapper.ml,
use_heuristic: wrapper.use_heuristic,
extend_candidates: wrapper.extend_candidates,
keep_pruned_connections: wrapper.keep_pruned_connections,
// build_strategy and seed are not persisted in the wrapper, so
// restored indexes fall back to the default strategy, unseeded.
build_strategy: crate::index::BuildStrategy::default(),
seed: None,
}
}
}
impl HNSWIndexWrapper {
/// Captures the index's dimension, documents, and config for persistence.
pub fn from_index(index: &crate::index::HNSWIndex) -> Self {
Self {
embedding_dim: index.embedding_dim(),
documents: index.get_all_documents(),
config: HNSWConfigWrapper::from(index.config()),
}
}
/// Rebuilds an `HNSWIndex` by re-inserting every document one at a time.
/// NOTE(review): the graph is reconstructed from scratch, so the restored
/// index may differ structurally from the saved one even though it holds
/// the same documents and config.
pub fn to_index(&self) -> Result<crate::index::HNSWIndex> {
let config: crate::index::HNSWConfig = self.config.clone().into();
let mut index = crate::index::HNSWIndex::new(self.embedding_dim, config);
for doc in &self.documents {
index.add(doc.clone())?;
}
Ok(index)
}
}
#[cfg(test)]
/// End-to-end tests for `FileStorage` against temporary directories:
/// round-trips, metadata handling, listing/deleting, atomic writes, name
/// validation, and the index wrapper conversions.
mod tests {
use super::*;
use std::sync::{Arc, Barrier};
use std::thread;
use tempfile::tempdir;
// --- fixtures: a tiny 5-dimensional document and prebuilt indexes ---
fn create_test_document(id: &str) -> Document {
Document {
id: id.to_string(),
content: format!("Test content for {}", id),
embedding: vec![0.1, 0.2, 0.3, 0.4, 0.5],
metadata: Some(serde_json::json!({"test": true})),
}
}
fn create_test_flat_index() -> crate::index::FlatIndex {
let mut index = crate::index::FlatIndex::new(5);
index.add(create_test_document("doc1")).unwrap();
index.add(create_test_document("doc2")).unwrap();
index
}
fn create_test_hnsw_index() -> crate::index::HNSWIndex {
let mut index = crate::index::HNSWIndex::with_defaults(5);
index.add(create_test_document("doc1")).unwrap();
index.add(create_test_document("doc2")).unwrap();
index
}
#[test]
fn test_new_storage() {
let dir = tempdir().unwrap();
let _storage = FileStorage::new(dir.path()).unwrap();
assert!(dir.path().exists());
assert!(dir.path().is_dir());
}
#[test]
fn test_new_storage_with_codec() {
let dir = tempdir().unwrap();
let _storage = FileStorage::with_codec(dir.path(), Codec::Gzip).unwrap();
assert!(dir.path().exists());
}
#[test]
fn test_invalid_storage_path() {
// Pointing the store at an existing regular file must fail.
let dir = tempdir().unwrap();
let file_path = dir.path().join("file.txt");
std::fs::write(&file_path, b"test").unwrap();
let result = FileStorage::new(&file_path);
assert!(result.is_err());
}
#[test]
fn test_document_save_load() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let doc = create_test_document("doc1");
let stats = storage.save_document("doc1", &doc).unwrap();
assert!(stats.original_size > 0);
assert_eq!(stats.codec, Codec::None);
let loaded = storage.load_document("doc1").unwrap();
assert_eq!(loaded.id, doc.id);
assert_eq!(loaded.content, doc.content);
assert_eq!(loaded.embedding, doc.embedding);
}
#[test]
fn test_document_not_found() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let result = storage.load_document("nonexistent");
assert!(result.is_err());
}
#[test]
fn test_flat_index_persistence() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let index = create_test_flat_index();
let wrapper = FlatIndexWrapper::from_index(&index);
let stats = storage.save_flat_index("index1", &wrapper).unwrap();
assert!(stats.original_size > 0);
let loaded_wrapper = storage.load_flat_index("index1").unwrap();
let loaded_index = loaded_wrapper.to_index().unwrap();
assert_eq!(loaded_index.len(), index.len());
assert_eq!(loaded_index.embedding_dim(), index.embedding_dim());
}
#[test]
fn test_hnsw_index_persistence() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let index = create_test_hnsw_index();
let wrapper = HNSWIndexWrapper::from_index(&index);
let stats = storage.save_hnsw_index("index1", &wrapper).unwrap();
assert!(stats.original_size > 0);
let loaded_wrapper = storage.load_hnsw_index("index1").unwrap();
let loaded_index = loaded_wrapper.to_index().unwrap();
assert_eq!(loaded_index.len(), index.len());
assert_eq!(loaded_index.embedding_dim(), index.embedding_dim());
}
#[test]
fn test_atomic_write() {
// After a successful write, the target exists and no .tmp residue remains.
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let path = dir.path().join("test.data");
let data = b"test data";
storage.write_atomic(&path, data).unwrap();
assert!(path.exists());
let read_data = std::fs::read(&path).unwrap();
assert_eq!(read_data, data);
let has_tmp = std::fs::read_dir(dir.path())
.unwrap()
.filter_map(|entry| entry.ok())
.map(|entry| entry.file_name().to_string_lossy().to_string())
.any(|name| name.ends_with(".tmp"));
assert!(!has_tmp);
}
#[test]
// Regression test: unique temp-file names must prevent two simultaneous
// writers to sibling paths from clobbering each other's temp files.
fn concurrent_atomic_writes_to_sibling_paths_do_not_cross_contaminate() {
let dir = tempdir().unwrap();
let storage = Arc::new(FileStorage::new(dir.path()).unwrap());
let data_path = dir.path().join("doc.data");
let meta_path = dir.path().join("doc.meta");
for _ in 0..128 {
let barrier = Arc::new(Barrier::new(3));
let s1 = Arc::clone(&storage);
let b1 = Arc::clone(&barrier);
let data_path_1 = data_path.clone();
let t1 = thread::spawn(move || {
b1.wait();
s1.write_atomic(&data_path_1, b"DATA").unwrap();
});
let s2 = Arc::clone(&storage);
let b2 = Arc::clone(&barrier);
let meta_path_1 = meta_path.clone();
let t2 = thread::spawn(move || {
b2.wait();
s2.write_atomic(&meta_path_1, b"META").unwrap();
});
barrier.wait();
t1.join().unwrap();
t2.join().unwrap();
assert_eq!(std::fs::read(&data_path).unwrap(), b"DATA");
assert_eq!(std::fs::read(&meta_path).unwrap(), b"META");
}
}
#[test]
fn test_metadata() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let doc = create_test_document("doc1");
storage.save_document("doc1", &doc).unwrap();
let metadata = storage.get_metadata("doc1").unwrap();
assert_eq!(metadata.version, STORAGE_VERSION);
assert_eq!(metadata.item_type, "document");
assert!(metadata.created_at > 0);
assert_eq!(metadata.created_at, metadata.updated_at);
assert_eq!(metadata.compression, Codec::None);
assert!(metadata.original_size > 0);
}
#[test]
fn test_metadata_update() {
// Overwriting an item must keep created_at and only advance updated_at.
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let doc = create_test_document("doc1");
storage.save_document("doc1", &doc).unwrap();
let meta1 = storage.get_metadata("doc1").unwrap();
std::thread::sleep(std::time::Duration::from_millis(10));
storage.save_document("doc1", &doc).unwrap();
let meta2 = storage.get_metadata("doc1").unwrap();
assert_eq!(meta2.created_at, meta1.created_at);
assert!(meta2.updated_at >= meta1.updated_at);
}
#[test]
fn test_list_storage() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
assert_eq!(storage.list().unwrap().len(), 0);
storage
.save_document("doc1", &create_test_document("doc1"))
.unwrap();
storage
.save_document("doc2", &create_test_document("doc2"))
.unwrap();
storage
.save_document("doc3", &create_test_document("doc3"))
.unwrap();
let items = storage.list().unwrap();
assert_eq!(items.len(), 3);
assert!(items.contains(&"doc1".to_string()));
assert!(items.contains(&"doc2".to_string()));
assert!(items.contains(&"doc3".to_string()));
}
#[test]
fn test_delete() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let doc = create_test_document("doc1");
storage.save_document("doc1", &doc).unwrap();
assert!(storage.exists("doc1"));
assert_eq!(storage.list().unwrap().len(), 1);
storage.delete("doc1").unwrap();
assert!(!storage.exists("doc1"));
assert_eq!(storage.list().unwrap().len(), 0);
}
#[test]
fn test_delete_nonexistent() {
// Deleting a missing item is a no-op, not an error.
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let result = storage.delete("nonexistent");
assert!(result.is_ok());
}
#[test]
fn test_clear() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
storage
.save_document("doc1", &create_test_document("doc1"))
.unwrap();
storage
.save_document("doc2", &create_test_document("doc2"))
.unwrap();
storage
.save_document("doc3", &create_test_document("doc3"))
.unwrap();
assert_eq!(storage.list().unwrap().len(), 3);
storage.clear().unwrap();
assert_eq!(storage.list().unwrap().len(), 0);
}
#[test]
fn test_storage_size() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
assert_eq!(storage.total_size().unwrap(), 0);
storage
.save_document("doc1", &create_test_document("doc1"))
.unwrap();
let size = storage.total_size().unwrap();
assert!(size > 0);
storage
.save_document("doc2", &create_test_document("doc2"))
.unwrap();
let size2 = storage.total_size().unwrap();
assert!(size2 > size);
}
#[test]
fn test_compression_codecs() {
// Round-trip through every codec compiled into this build.
let dir = tempdir().unwrap();
#[allow(unused_mut)]
let mut codecs = vec![Codec::None, Codec::Gzip];
#[cfg(feature = "zstd")]
codecs.push(Codec::Zstd);
#[cfg(feature = "lz4")]
codecs.push(Codec::Lz4);
for codec in codecs {
let storage = FileStorage::with_codec(dir.path(), codec).unwrap();
let doc = create_test_document("doc1");
let stats = storage.save_document("test", &doc).unwrap();
assert!(stats.original_size > 0);
let loaded = storage.load_document("test").unwrap();
assert_eq!(loaded.id, doc.id);
assert_eq!(loaded.content, doc.content);
storage.delete("test").unwrap();
}
}
#[test]
fn test_exists() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
assert!(!storage.exists("doc1"));
storage
.save_document("doc1", &create_test_document("doc1"))
.unwrap();
assert!(storage.exists("doc1"));
assert!(!storage.exists("doc2"));
}
#[test]
fn test_flat_index_wrapper_roundtrip() {
let index = create_test_flat_index();
let wrapper = FlatIndexWrapper::from_index(&index);
let restored = wrapper.to_index().unwrap();
assert_eq!(restored.len(), index.len());
assert_eq!(restored.embedding_dim(), index.embedding_dim());
let query = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let results = restored.search(&query, 2).unwrap();
assert_eq!(results.len(), 2);
}
#[test]
fn test_hnsw_index_wrapper_roundtrip() {
let index = create_test_hnsw_index();
let wrapper = HNSWIndexWrapper::from_index(&index);
let restored = wrapper.to_index().unwrap();
assert_eq!(restored.len(), index.len());
assert_eq!(restored.embedding_dim(), index.embedding_dim());
let query = vec![0.1, 0.2, 0.3, 0.4, 0.5];
let results = restored.search(&query, 2).unwrap();
assert_eq!(results.len(), 2);
}
#[test]
fn test_concurrent_writes() {
// Sequential overwrite/reload loop; exercises repeated atomic replacement.
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let doc = create_test_document("doc1");
for _ in 0..10 {
storage.save_document("doc1", &doc).unwrap();
let loaded = storage.load_document("doc1").unwrap();
assert_eq!(loaded.id, doc.id);
}
}
#[test]
fn test_large_document() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let mut large_doc = create_test_document("large");
large_doc.embedding = vec![0.5; 10000];
large_doc.content = "x".repeat(100000);
let stats = storage.save_document("large", &large_doc).unwrap();
assert!(stats.original_size > 100000);
let loaded = storage.load_document("large").unwrap();
assert_eq!(loaded.id, large_doc.id);
assert_eq!(loaded.embedding.len(), 10000);
assert_eq!(loaded.content.len(), 100000);
}
#[test]
fn test_rejects_path_traversal_item_names() {
let dir = tempdir().unwrap();
let storage = FileStorage::new(dir.path()).unwrap();
let doc = create_test_document("doc1");
let result = storage.save_document("../outside", &doc);
assert!(result.is_err(), "path traversal names should be rejected");
}
#[test]
// Exhaustive name-validation table, including Windows reserved device names.
fn test_is_invalid_item_name_comprehensive() {
assert!(!FileStorage::is_invalid_item_name("hello"));
assert!(!FileStorage::is_invalid_item_name("my_index"));
assert!(!FileStorage::is_invalid_item_name("data-2024"));
assert!(!FileStorage::is_invalid_item_name("file.txt"));
assert!(FileStorage::is_invalid_item_name(""));
assert!(FileStorage::is_invalid_item_name(".."));
assert!(FileStorage::is_invalid_item_name("."));
assert!(FileStorage::is_invalid_item_name("foo/bar"));
assert!(FileStorage::is_invalid_item_name("foo\\bar"));
assert!(FileStorage::is_invalid_item_name("../outside"));
assert!(FileStorage::is_invalid_item_name("/absolute"));
#[cfg(target_os = "windows")]
assert!(FileStorage::is_invalid_item_name("C:\\Windows\\System32"));
assert!(FileStorage::is_invalid_item_name("hello\0world"));
assert!(FileStorage::is_invalid_item_name("\0"));
assert!(FileStorage::is_invalid_item_name("CON"));
assert!(FileStorage::is_invalid_item_name("con"));
assert!(FileStorage::is_invalid_item_name("Con"));
assert!(FileStorage::is_invalid_item_name("PRN"));
assert!(FileStorage::is_invalid_item_name("AUX"));
assert!(FileStorage::is_invalid_item_name("NUL"));
assert!(FileStorage::is_invalid_item_name("nul"));
assert!(FileStorage::is_invalid_item_name("COM1"));
assert!(FileStorage::is_invalid_item_name("com1"));
assert!(FileStorage::is_invalid_item_name("COM9"));
assert!(FileStorage::is_invalid_item_name("LPT1"));
assert!(FileStorage::is_invalid_item_name("lpt1"));
assert!(FileStorage::is_invalid_item_name("LPT9"));
assert!(FileStorage::is_invalid_item_name("CON.txt"));
assert!(FileStorage::is_invalid_item_name("NUL.tar.gz"));
assert!(FileStorage::is_invalid_item_name("com1.data"));
assert!(FileStorage::is_invalid_item_name("lpt3.log"));
}
}