use crate::error::{MemoryError, MemoryResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use uuid::Uuid;
/// Collection of `f32` vectors keyed by episode id, optionally backed by a JSON file.
///
/// A store made with `new` has no backing file; one made with `open` rewrites its
/// file on every `store` and `remove` call.
pub struct VectorStore {
/// File the store is persisted to; `None` means nothing is ever written to disk.
path: Option<PathBuf>,
/// Stored vectors, keyed by episode id.
vectors: HashMap<Uuid, Vec<f32>>,
/// Length required of vectors passed to `store` and of queries passed to `search`.
dimension: usize,
}
/// Serialized (JSON) form of a `VectorStore`, as read by `open` and written by `persist`.
#[derive(Serialize, Deserialize)]
struct VectorStoreData {
/// Dimension the data was written with; `open` rejects the file if it differs
/// from the dimension requested by the caller.
dimension: usize,
/// `(id, vector)` pairs, with each id stored in its string form.
vectors: Vec<(String, Vec<f32>)>,
}
impl VectorStore {
pub fn new(dimension: usize) -> Self {
Self {
path: None,
vectors: HashMap::new(),
dimension,
}
}
pub fn open(path: &Path, dimension: usize) -> MemoryResult<Self> {
let mut store = Self {
path: Some(path.to_path_buf()),
vectors: HashMap::new(),
dimension,
};
if path.exists() {
let data = fs::read_to_string(path)?;
let stored: VectorStoreData = serde_json::from_str(&data)?;
if stored.dimension != dimension {
return Err(MemoryError::InvalidConfig(format!(
"Dimension mismatch: stored={}, expected={}",
stored.dimension, dimension
)));
}
for (id_str, vec) in stored.vectors {
if let Ok(id) = Uuid::parse_str(&id_str) {
store.vectors.insert(id, vec);
}
}
}
Ok(store)
}
pub fn store(&mut self, episode_id: Uuid, vector: Vec<f32>) -> MemoryResult<()> {
if vector.len() != self.dimension {
return Err(MemoryError::InvalidConfig(format!(
"Vector dimension mismatch: got={}, expected={}",
vector.len(),
self.dimension
)));
}
self.vectors.insert(episode_id, vector);
self.persist()?;
Ok(())
}
pub fn get(&self, episode_id: Uuid) -> Option<&Vec<f32>> {
self.vectors.get(&episode_id)
}
pub fn remove(&mut self, episode_id: Uuid) -> MemoryResult<()> {
self.vectors.remove(&episode_id);
self.persist()?;
Ok(())
}
pub fn search(&self, query: &[f32], limit: usize) -> Vec<(Uuid, f64)> {
if query.len() != self.dimension {
return Vec::new();
}
let mut results: Vec<(Uuid, f64)> = self
.vectors
.iter()
.map(|(id, vec)| (*id, cosine_similarity(query, vec)))
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results.truncate(limit);
results
}
pub fn len(&self) -> usize {
self.vectors.len()
}
pub fn is_empty(&self) -> bool {
self.vectors.is_empty()
}
pub fn episode_ids(&self) -> Vec<Uuid> {
self.vectors.keys().copied().collect()
}
fn persist(&self) -> MemoryResult<()> {
if let Some(ref path) = self.path {
let data = VectorStoreData {
dimension: self.dimension,
vectors: self
.vectors
.iter()
.map(|(id, vec)| (id.to_string(), vec.clone()))
.collect(),
};
let json = serde_json::to_string_pretty(&data)?;
if let Some(parent) = path.parent() {
fs::create_dir_all(parent)?;
}
fs::write(path, json)?;
}
Ok(())
}
}
/// Computes the cosine similarity of `a` and `b`, in the range `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either one has
/// zero magnitude. A NaN result is possible only when an input contains NaN or
/// an infinity.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    // Accumulate in f64: squaring in f32 overflows to infinity for large
    // components (producing NaN) and underflows to zero for tiny ones
    // (producing a spurious 0.0).
    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_norm_a == 0.0 || sq_norm_b == 0.0 {
        return 0.0;
    }

    // Rounding can push the quotient a hair outside the mathematical range.
    (dot / (sq_norm_a.sqrt() * sq_norm_b.sqrt())).clamp(-1.0, 1.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A stored vector is returned unchanged when looked up by its id.
    #[test]
    fn test_store_and_get() {
        let episode = Uuid::new_v4();
        let embedding = vec![1.0, 2.0, 3.0];

        let mut vectors = VectorStore::new(3);
        vectors.store(episode, embedding.clone()).unwrap();

        assert_eq!(vectors.get(episode).unwrap(), &embedding);
    }

    /// Search honours the limit and ranks the exact match first.
    #[test]
    fn test_search() {
        let exact = Uuid::new_v4();
        let near = Uuid::new_v4();
        let orthogonal = Uuid::new_v4();

        let mut vectors = VectorStore::new(3);
        let entries = vec![
            (exact, vec![1.0, 0.0, 0.0]),
            (near, vec![0.9, 0.1, 0.0]),
            (orthogonal, vec![0.0, 1.0, 0.0]),
        ];
        for (id, embedding) in entries {
            vectors.store(id, embedding).unwrap();
        }

        let hits = vectors.search(&[1.0, 0.0, 0.0], 2);

        assert_eq!(hits.len(), 2);
        let (best_id, best_score) = hits[0];
        assert_eq!(best_id, exact);
        assert!((best_score - 1.0).abs() < 0.001);
    }

    /// Data written through one handle is visible after reopening the file.
    #[test]
    fn test_persistence() {
        let scratch = tempdir().unwrap();
        let file = scratch.path().join("vectors.json");
        let episode = Uuid::new_v4();
        let embedding = vec![1.0, 2.0, 3.0];

        let mut writer = VectorStore::open(&file, 3).unwrap();
        writer.store(episode, embedding.clone()).unwrap();
        drop(writer);

        let reader = VectorStore::open(&file, 3).unwrap();
        assert_eq!(reader.get(episode).unwrap(), &embedding);
    }

    /// Identical, orthogonal and opposite vectors score 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let unit_x: [f32; 3] = [1.0, 0.0, 0.0];
        let cases: [([f32; 3], f64); 3] = [
            ([1.0, 0.0, 0.0], 1.0),
            ([0.0, 1.0, 0.0], 0.0),
            ([-1.0, 0.0, 0.0], -1.0),
        ];

        for (other, expected) in cases.iter() {
            let score = cosine_similarity(&unit_x, other);
            assert!((score - expected).abs() < 0.001);
        }
    }

    /// Storing a vector of the wrong length is rejected.
    #[test]
    fn test_dimension_mismatch() {
        let mut vectors = VectorStore::new(3);
        let outcome = vectors.store(Uuid::new_v4(), vec![1.0, 2.0]);
        assert!(outcome.is_err());
    }
}