use crate::vector::{MmapVectorStorage, SegmentOrdinal, VectorDimension, VectorId};
use crate::{SymbolId, semantic::SemanticSearchError};
use std::path::Path;
/// Embedding store for semantic search, keyed by [`SymbolId`].
///
/// Wraps a single [`MmapVectorStorage`]; every constructor in this file opens
/// or creates it with segment ordinal 0.
#[derive(Debug)]
pub struct SemanticVectorStorage {
/// Underlying vector storage that performs the actual reads and writes.
storage: MmapVectorStorage,
/// Number of components every embedding must have; checked on each save.
dimension: VectorDimension,
}
impl SemanticVectorStorage {
pub fn new(path: &Path, dimension: VectorDimension) -> Result<Self, SemanticSearchError> {
let storage_path = path.join("segment_0.vec");
if storage_path.exists() {
std::fs::remove_file(&storage_path).map_err(|e| SemanticSearchError::StorageError {
message: format!("Failed to remove old storage: {e}"),
suggestion: "Check file permissions".to_string(),
})?;
}
let storage =
MmapVectorStorage::new(path, SegmentOrdinal::new(0), dimension).map_err(|e| {
SemanticSearchError::StorageError {
message: format!("Failed to create storage: {e}"),
suggestion: "Ensure the directory exists and you have write permissions"
.to_string(),
}
})?;
Ok(Self { storage, dimension })
}
pub fn open(path: &Path) -> Result<Self, SemanticSearchError> {
let storage = MmapVectorStorage::open(path, SegmentOrdinal::new(0)).map_err(|e| {
SemanticSearchError::StorageError {
message: format!("Failed to open storage: {e}"),
suggestion: "Check if semantic search data exists at the specified path"
.to_string(),
}
})?;
let dimension = storage.dimension();
Ok(Self { storage, dimension })
}
pub fn open_or_create(
path: &Path,
dimension: VectorDimension,
) -> Result<Self, SemanticSearchError> {
let storage = MmapVectorStorage::open_or_create(path, SegmentOrdinal::new(0), dimension)
.map_err(|e| SemanticSearchError::StorageError {
message: format!("Failed to open or create storage: {e}"),
suggestion: "Check path permissions and disk space".to_string(),
})?;
Ok(Self { storage, dimension })
}
pub fn save_embedding(
&mut self,
id: SymbolId,
embedding: &[f32],
) -> Result<(), SemanticSearchError> {
if embedding.len() != self.dimension.get() {
return Err(SemanticSearchError::DimensionMismatch {
expected: self.dimension.get(),
actual: embedding.len(),
suggestion: "Ensure all embeddings are generated with the same model".to_string(),
});
}
let vector_id =
VectorId::new(id.to_u32()).ok_or_else(|| SemanticSearchError::InvalidId {
id: id.to_u32(),
suggestion: "Symbol ID must be non-zero".to_string(),
})?;
self.storage
.write_batch(&[(vector_id, embedding)])
.map_err(|e| SemanticSearchError::StorageError {
message: format!("Failed to save embedding: {e}"),
suggestion: "Check disk space and file permissions".to_string(),
})
}
pub fn load_embedding(&mut self, id: SymbolId) -> Option<Vec<f32>> {
let vector_id = VectorId::new(id.to_u32())?;
self.storage.read_vector(vector_id)
}
pub fn load_all(&mut self) -> Result<Vec<(SymbolId, Vec<f32>)>, SemanticSearchError> {
let vectors =
self.storage
.read_all_vectors()
.map_err(|e| SemanticSearchError::StorageError {
message: format!("Failed to load embeddings: {e}"),
suggestion:
"The storage file may be corrupted. Try rebuilding the semantic index."
.to_string(),
})?;
let mut result = Vec::with_capacity(vectors.len());
for (vector_id, embedding) in vectors {
let symbol_id = SymbolId::new(vector_id.get()).unwrap();
result.push((symbol_id, embedding));
}
Ok(result)
}
pub fn save_batch(
&mut self,
embeddings: &[(SymbolId, Vec<f32>)],
) -> Result<(), SemanticSearchError> {
for (_, embedding) in embeddings {
if embedding.len() != self.dimension.get() {
return Err(SemanticSearchError::DimensionMismatch {
expected: self.dimension.get(),
actual: embedding.len(),
suggestion: "All embeddings must have the same dimension".to_string(),
});
}
}
let mut vector_batch = Vec::with_capacity(embeddings.len());
for (symbol_id, embedding) in embeddings {
let vector_id = VectorId::new(symbol_id.to_u32()).ok_or_else(|| {
SemanticSearchError::InvalidId {
id: symbol_id.to_u32(),
suggestion: "Symbol ID must be non-zero".to_string(),
}
})?;
vector_batch.push((vector_id, embedding.as_slice()));
}
self.storage
.write_batch(&vector_batch)
.map_err(|e| SemanticSearchError::StorageError {
message: format!("Failed to save batch: {e}"),
suggestion: "Check disk space and file permissions".to_string(),
})
}
pub fn embedding_count(&self) -> usize {
self.storage.vector_count()
}
pub fn dimension(&self) -> VectorDimension {
self.dimension
}
pub fn exists(&self) -> bool {
self.storage.exists()
}
pub fn file_size(&self) -> Result<u64, SemanticSearchError> {
self.storage
.file_size()
.map_err(|e| SemanticSearchError::StorageError {
message: format!("Failed to get file size: {e}"),
suggestion: "Check if the storage file exists".to_string(),
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A saved embedding can be read back; an unknown ID yields `None`.
    #[test]
    fn test_save_and_load_single_embedding() {
        let dir = TempDir::new().unwrap();
        let mut store =
            SemanticVectorStorage::new(dir.path(), VectorDimension::new(4).unwrap()).unwrap();

        let id = SymbolId::new(42).unwrap();
        let expected: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0];
        store.save_embedding(id, &expected).unwrap();

        assert_eq!(store.load_embedding(id), Some(expected));

        let unknown = SymbolId::new(999).unwrap();
        assert!(store.load_embedding(unknown).is_none());
    }

    /// `load_all` returns every embedding that was saved individually.
    #[test]
    fn test_load_all_embeddings() {
        let dir = TempDir::new().unwrap();
        let mut store =
            SemanticVectorStorage::open_or_create(dir.path(), VectorDimension::new(3).unwrap())
                .unwrap();

        let fixtures: Vec<(SymbolId, Vec<f32>)> = vec![
            (SymbolId::new(1).unwrap(), vec![1.0, 2.0, 3.0]),
            (SymbolId::new(2).unwrap(), vec![4.0, 5.0, 6.0]),
            (SymbolId::new(3).unwrap(), vec![7.0, 8.0, 9.0]),
        ];
        for (id, values) in fixtures.iter() {
            store.save_embedding(*id, values).unwrap();
        }

        let loaded = store.load_all().unwrap();
        assert_eq!(loaded.len(), 3);

        // Order of `loaded` is not assumed; look each fixture up by ID.
        for (expected_id, expected_values) in fixtures.iter() {
            let found = loaded
                .iter()
                .find(|entry| entry.0 == *expected_id)
                .map(|entry| &entry.1);
            assert_eq!(found, Some(expected_values));
        }
    }

    /// Saving an embedding of the wrong length reports both lengths.
    #[test]
    fn test_dimension_validation() {
        let dir = TempDir::new().unwrap();
        let mut store =
            SemanticVectorStorage::new(dir.path(), VectorDimension::new(3).unwrap()).unwrap();

        // Two components against a three-component storage.
        let outcome = store.save_embedding(SymbolId::new(1).unwrap(), &[1.0, 2.0]);
        assert!(outcome.is_err());

        if let SemanticSearchError::DimensionMismatch {
            expected, actual, ..
        } = outcome.unwrap_err()
        {
            assert_eq!(expected, 3);
            assert_eq!(actual, 2);
        } else {
            panic!("Expected DimensionMismatch error");
        }
    }

    /// A batch save stores every entry and updates the count.
    #[test]
    fn test_batch_operations() {
        let dir = TempDir::new().unwrap();
        let mut store =
            SemanticVectorStorage::new(dir.path(), VectorDimension::new(2).unwrap()).unwrap();

        let batch: Vec<(SymbolId, Vec<f32>)> = vec![
            (SymbolId::new(10).unwrap(), vec![1.0, 2.0]),
            (SymbolId::new(20).unwrap(), vec![3.0, 4.0]),
            (SymbolId::new(30).unwrap(), vec![5.0, 6.0]),
        ];
        store.save_batch(&batch).unwrap();

        assert_eq!(store.embedding_count(), 3);
        for (id, values) in batch.iter() {
            assert_eq!(store.load_embedding(*id), Some(values.clone()));
        }
    }

    /// Data written by one instance is visible to a later `open`.
    #[test]
    fn test_persistence_across_instances() {
        let dir = TempDir::new().unwrap();
        let dimension = VectorDimension::new(2).unwrap();
        let id = SymbolId::new(42).unwrap();

        let mut writer = SemanticVectorStorage::new(dir.path(), dimension).unwrap();
        writer.save_embedding(id, &[1.5, 2.5]).unwrap();
        // Release the first instance before reopening the same directory.
        drop(writer);

        let mut reader = SemanticVectorStorage::open(dir.path()).unwrap();
        assert_eq!(reader.dimension(), dimension);
        assert_eq!(reader.embedding_count(), 1);
        assert_eq!(reader.load_embedding(id), Some(vec![1.5, 2.5]));
    }
}