use crate::{GraphRAGError, Result};
use std::collections::HashMap;
use voy::{Embeddings, Similarity};
#[cfg(target_arch = "wasm32")]
use wasm_bindgen::prelude::*;
/// In-memory vector store backed by a Voy [`Embeddings`] index.
///
/// Vectors are staged with `add_vector` and only become searchable after
/// `build_index` has been called; the index is built with cosine similarity.
pub struct VoyStore {
/// Number of components every stored or queried embedding must have.
dimension: usize,
/// The built Voy index; `None` until `build_index` succeeds.
index: Option<Embeddings>,
/// Maps a vector id to its position in `pending_embeddings` / `index_to_id`.
id_to_index: HashMap<String, usize>,
/// Maps a position back to the vector id stored there.
index_to_id: Vec<String>,
/// Raw embeddings in insertion order; the source data for `build_index`.
pending_embeddings: Vec<Vec<f32>>,
/// `true` only while the index reflects the current set of embeddings.
index_built: bool,
}
impl VoyStore {
    /// Creates an empty store whose vectors must all have `dimension` components.
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            index: None,
            id_to_index: HashMap::new(),
            index_to_id: Vec::new(),
            pending_embeddings: Vec::new(),
            index_built: false,
        }
    }

    /// Stages `embedding` under `id` and invalidates any previously built index.
    ///
    /// # Errors
    ///
    /// Returns `GraphRAGError::VectorSearch` when the embedding length differs
    /// from the store dimension or when `id` is already present.
    pub fn add_vector(&mut self, id: String, embedding: Vec<f32>) -> Result<()> {
        if embedding.len() != self.dimension {
            return Err(GraphRAGError::VectorSearch {
                message: format!(
                    "Embedding dimension mismatch: expected {}, got {}",
                    self.dimension,
                    embedding.len()
                ),
            });
        }
        if self.id_to_index.contains_key(&id) {
            return Err(GraphRAGError::VectorSearch {
                message: format!("Vector ID '{}' already exists", id),
            });
        }

        // The new vector goes at the end, so its position is the current length.
        let position = self.pending_embeddings.len();
        self.id_to_index.insert(id.clone(), position);
        self.index_to_id.push(id);
        self.pending_embeddings.push(embedding);
        self.invalidate_index();
        Ok(())
    }

    /// Builds the Voy index (cosine similarity) from every staged embedding.
    ///
    /// # Errors
    ///
    /// Returns `GraphRAGError::VectorSearch` when the store is empty or when
    /// the Voy builder reports a failure.
    pub fn build_index(&mut self) -> Result<()> {
        if self.pending_embeddings.is_empty() {
            return Err(GraphRAGError::VectorSearch {
                message: "No embeddings to build index from".to_string(),
            });
        }

        // The builder takes one contiguous buffer; `concat` sizes it exactly once.
        let flat_data: Vec<f32> = self.pending_embeddings.concat();
        let embeddings = Embeddings::builder(flat_data, self.pending_embeddings.len())
            .with_dimension(self.dimension)
            .with_similarity(Similarity::Cosine)
            .build()
            .map_err(|e| GraphRAGError::VectorSearch {
                message: format!("Failed to build Voy index: {}", e),
            })?;

        self.index = Some(embeddings);
        self.index_built = true;
        Ok(())
    }

    /// Asks the built index for the `top_k` nearest matches to `query_embedding`.
    ///
    /// Returns `(id, score)` pairs in the order the index reports them, with
    /// the score passed through unchanged.
    ///
    /// # Errors
    ///
    /// Returns `GraphRAGError::VectorSearch` when the index has not been built
    /// since the last modification, when the query length differs from the
    /// store dimension, or when the underlying search fails.
    pub fn search(&self, query_embedding: &[f32], top_k: usize) -> Result<Vec<(String, f32)>> {
        if !self.index_built {
            return Err(GraphRAGError::VectorSearch {
                message: "Index not built. Call build_index() first.".to_string(),
            });
        }
        if query_embedding.len() != self.dimension {
            return Err(GraphRAGError::VectorSearch {
                message: format!(
                    "Query dimension mismatch: expected {}, got {}",
                    self.dimension,
                    query_embedding.len()
                ),
            });
        }
        let index = self.index.as_ref().ok_or_else(|| GraphRAGError::VectorSearch {
            message: "Index not available".to_string(),
        })?;
        let results = index
            .search(query_embedding, top_k)
            .map_err(|e| GraphRAGError::VectorSearch {
                message: format!("Voy search failed: {}", e),
            })?;

        // Translate index positions back to ids; a position with no known id
        // is skipped rather than reported.
        Ok(results
            .iter()
            .filter_map(|(idx, similarity)| {
                self.index_to_id
                    .get(*idx)
                    .map(|id| (id.clone(), *similarity))
            })
            .collect())
    }

    /// Number of vectors currently stored.
    pub fn len(&self) -> usize {
        self.pending_embeddings.len()
    }

    /// `true` when no vectors are stored.
    pub fn is_empty(&self) -> bool {
        self.pending_embeddings.is_empty()
    }

    /// Number of components every embedding in this store must have.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// `true` when the index is built and reflects the current vectors.
    pub fn is_index_built(&self) -> bool {
        self.index_built
    }

    /// Returns the embedding stored under `id`, if any.
    pub fn get_vector(&self, id: &str) -> Option<&Vec<f32>> {
        self.id_to_index
            .get(id)
            .and_then(|&idx| self.pending_embeddings.get(idx))
    }

    /// `true` when a vector is stored under `id`.
    pub fn contains(&self, id: &str) -> bool {
        self.id_to_index.contains_key(id)
    }

    /// Returns a copy of all stored ids, ordered by their position in the store.
    pub fn ids(&self) -> Vec<String> {
        self.index_to_id.clone()
    }

    /// Removes the vector stored under `id` and invalidates the index.
    ///
    /// Every vector stored after the removed one shifts down by one position,
    /// so the id-to-position map is adjusted to match.
    ///
    /// # Errors
    ///
    /// Returns `GraphRAGError::VectorSearch` when `id` is not present.
    pub fn remove_vector(&mut self, id: &str) -> Result<()> {
        let idx = self
            .id_to_index
            .remove(id)
            .ok_or_else(|| GraphRAGError::VectorSearch {
                message: format!("Vector ID '{}' not found", id),
            })?;

        if idx < self.index_to_id.len() {
            self.index_to_id.remove(idx);
        }
        if idx < self.pending_embeddings.len() {
            self.pending_embeddings.remove(idx);
        }
        // Re-align positions of everything that sat behind the removed slot.
        for position in self.id_to_index.values_mut() {
            if *position > idx {
                *position -= 1;
            }
        }

        self.invalidate_index();
        Ok(())
    }

    /// Removes every vector and drops the index.
    pub fn clear(&mut self) {
        self.id_to_index.clear();
        self.index_to_id.clear();
        self.pending_embeddings.clear();
        self.invalidate_index();
    }

    /// Summarises the store: vector count, dimension, index state and the
    /// minimum / maximum / mean Euclidean norm of the stored embeddings.
    ///
    /// For an empty store all three norm fields are `0.0`.
    pub fn statistics(&self) -> VoyStoreStatistics {
        let mut min_norm = f32::INFINITY;
        let mut max_norm: f32 = 0.0;
        let mut sum_norm: f32 = 0.0;
        for embedding in &self.pending_embeddings {
            let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            min_norm = min_norm.min(norm);
            max_norm = max_norm.max(norm);
            sum_norm += norm;
        }

        // With no vectors the running minimum is still the INFINITY seed;
        // report 0.0 instead of leaking that sentinel to callers.
        let (min_norm, avg_norm) = if self.pending_embeddings.is_empty() {
            (0.0, 0.0)
        } else {
            (min_norm, sum_norm / self.pending_embeddings.len() as f32)
        };

        VoyStoreStatistics {
            vector_count: self.len(),
            dimension: self.dimension,
            index_built: self.index_built,
            min_norm,
            max_norm,
            avg_norm,
        }
    }

    /// Drops the built index and marks the store as needing a rebuild.
    fn invalidate_index(&mut self) {
        self.index = None;
        self.index_built = false;
    }
}
impl Default for VoyStore {
fn default() -> Self {
Self::new(384) }
}
/// Snapshot of a `VoyStore`'s size, index state and embedding norms,
/// as produced by `VoyStore::statistics`.
#[derive(Debug, Clone)]
pub struct VoyStoreStatistics {
/// Number of vectors stored when the snapshot was taken.
pub vector_count: usize,
/// Number of components per embedding.
pub dimension: usize,
/// Whether the index was built when the snapshot was taken.
pub index_built: bool,
/// Smallest Euclidean norm among the stored embeddings.
pub min_norm: f32,
/// Largest Euclidean norm among the stored embeddings.
pub max_norm: f32,
/// Mean Euclidean norm of the stored embeddings.
pub avg_norm: f32,
}
impl VoyStoreStatistics {
    /// Writes a human-readable summary of these statistics to standard output.
    ///
    /// The norm section is printed only when at least one vector is counted.
    pub fn print(&self) {
        let Self {
            vector_count,
            dimension,
            index_built,
            min_norm,
            max_norm,
            avg_norm,
        } = self;

        println!("Voy Vector Store Statistics:");
        println!(" Algorithm: k-d tree (Voy 0.6)");
        println!(" Vector count: {}", vector_count);
        println!(" Dimension: {}", dimension);
        println!(" Index built: {}", index_built);
        println!(" Bundle size: ~75KB (optimized for WASM)");

        // Norms are meaningless without vectors, so stop here.
        if *vector_count == 0 {
            return;
        }

        println!(" Vector norms:");
        println!(" Min: {:.4}", min_norm);
        println!(" Max: {:.4}", max_norm);
        println!(" Average: {:.4}", avg_norm);
    }
}
/// JavaScript-facing wrapper around `VoyStore`, compiled only for `wasm32`.
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
pub struct VoyStoreWasm {
// The wrapped native store; every binding delegates to it.
inner: VoyStore,
}
#[cfg(target_arch = "wasm32")]
#[wasm_bindgen]
impl VoyStoreWasm {
    /// Creates an empty store for embeddings with `dimension` components.
    #[wasm_bindgen(constructor)]
    pub fn new(dimension: usize) -> Self {
        Self {
            inner: VoyStore::new(dimension),
        }
    }

    /// Stages `embedding` under `id`; the error is the message as a JS string.
    #[wasm_bindgen(js_name = addVector)]
    pub fn add_vector(&mut self, id: String, embedding: Vec<f32>) -> std::result::Result<(), JsValue> {
        self.inner
            .add_vector(id, embedding)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }

    /// Builds the search index; the error is the message as a JS string.
    #[wasm_bindgen(js_name = buildIndex)]
    pub fn build_index(&mut self) -> std::result::Result<(), JsValue> {
        self.inner
            .build_index()
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }

    /// Searches the index and returns the `(id, score)` results serialised
    /// to a JS value via `serde_wasm_bindgen`.
    #[wasm_bindgen(js_name = search)]
    pub fn search(&self, query_embedding: Vec<f32>, top_k: usize) -> std::result::Result<JsValue, JsValue> {
        let results = self
            .inner
            .search(&query_embedding, top_k)
            .map_err(|e| JsValue::from_str(&e.to_string()))?;
        serde_wasm_bindgen::to_value(&results)
            .map_err(|e| JsValue::from_str(&e.to_string()))
    }

    /// Number of vectors currently stored.
    #[wasm_bindgen(js_name = len)]
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// `true` when no vectors are stored; the companion to `len`.
    #[wasm_bindgen(js_name = isEmpty)]
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    /// Number of components every embedding must have.
    #[wasm_bindgen(js_name = dimension)]
    pub fn dimension(&self) -> usize {
        self.inner.dimension()
    }

    /// Removes every vector and drops the index.
    #[wasm_bindgen(js_name = clear)]
    pub fn clear(&mut self) {
        self.inner.clear();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh store reports its dimension, is empty and has no index.
    #[test]
    fn test_voy_store_creation() {
        let store = VoyStore::new(384);
        assert_eq!(store.dimension(), 384);
        assert!(store.is_empty());
        assert!(!store.is_index_built());
    }

    /// Adding a correctly sized vector stores it under its id.
    #[test]
    fn test_add_vector() {
        let mut store = VoyStore::new(3);
        let embedding = vec![0.1, 0.2, 0.3];
        assert!(store.add_vector("doc1".to_string(), embedding).is_ok());
        assert_eq!(store.len(), 1);
        assert!(store.contains("doc1"));
    }

    /// A vector of the wrong length is rejected.
    #[test]
    fn test_dimension_mismatch() {
        let mut store = VoyStore::new(3);
        let wrong_embedding = vec![0.1, 0.2];
        assert!(store
            .add_vector("doc1".to_string(), wrong_embedding)
            .is_err());
    }

    /// Re-using an id is rejected.
    #[test]
    fn test_duplicate_id() {
        let mut store = VoyStore::new(3);
        let embedding = vec![0.1, 0.2, 0.3];
        store.add_vector("doc1".to_string(), embedding.clone()).unwrap();
        assert!(store.add_vector("doc1".to_string(), embedding).is_err());
    }

    /// After building, a query returns at most `top_k` nearby documents.
    #[test]
    fn test_build_and_search() {
        let mut store = VoyStore::new(3);
        store.add_vector("doc1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        store.add_vector("doc2".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
        store.add_vector("doc3".to_string(), vec![0.9, 0.1, 0.0]).unwrap();
        assert!(store.build_index().is_ok());
        assert!(store.is_index_built());
        let query = vec![1.0, 0.0, 0.0];
        let results = store.search(&query, 2).unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 2);
        let first_id = &results[0].0;
        assert!(first_id == "doc1" || first_id == "doc3");
    }

    /// Searching before any index exists is an error.
    #[test]
    fn test_search_without_index() {
        let store = VoyStore::new(3);
        let query = vec![1.0, 0.0, 0.0];
        assert!(store.search(&query, 5).is_err());
    }

    /// Removing a vector drops it and keeps the survivors correctly mapped.
    #[test]
    fn test_remove_vector() {
        let mut store = VoyStore::new(3);
        store.add_vector("doc1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        store.add_vector("doc2".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
        assert_eq!(store.len(), 2);
        store.remove_vector("doc1").unwrap();
        assert_eq!(store.len(), 1);
        assert!(!store.contains("doc1"));
        assert!(store.contains("doc2"));
        // The survivor shifted down one slot; it must still resolve to its
        // own embedding, not the removed one.
        assert_eq!(store.get_vector("doc2"), Some(&vec![0.0_f32, 1.0, 0.0]));
        assert_eq!(store.get_vector("doc1"), None);
        assert_eq!(store.ids(), vec!["doc2".to_string()]);
        assert!(!store.is_index_built());
    }

    /// Removing an unknown id is an error and leaves the store untouched.
    #[test]
    fn test_remove_missing_vector() {
        let mut store = VoyStore::new(3);
        store.add_vector("doc1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        assert!(store.remove_vector("missing").is_err());
        assert_eq!(store.len(), 1);
        assert!(store.contains("doc1"));
    }

    /// Statistics report counts, index state and norms of the stored vectors.
    #[test]
    fn test_statistics() {
        let mut store = VoyStore::new(3);
        store.add_vector("doc1".to_string(), vec![1.0, 0.0, 0.0]).unwrap();
        store.add_vector("doc2".to_string(), vec![0.0, 1.0, 0.0]).unwrap();
        let stats = store.statistics();
        assert_eq!(stats.vector_count, 2);
        assert_eq!(stats.dimension, 3);
        assert!(!stats.index_built);
        // Both inputs are unit vectors, so every norm statistic is 1.
        assert!((stats.min_norm - 1.0).abs() < f32::EPSILON);
        assert!((stats.max_norm - 1.0).abs() < f32::EPSILON);
        assert!((stats.avg_norm - 1.0).abs() < f32::EPSILON);
    }
}