use crate::store::types::{Distance, Id};
use crate::store::wasm_hnsw::WasmHnsw;
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WasmVectorBackend {
hnsw: WasmHnsw,
id_to_idx: HashMap<Id, usize>,
idx_to_id: HashMap<usize, Id>,
next_idx: usize,
}
impl WasmVectorBackend {
pub fn new(dimension: usize) -> Self {
Self::with_params(dimension, Distance::Cosine, 16, 200)
}
pub fn with_params(
dimension: usize,
metric: Distance,
m: usize,
ef_construction: usize,
) -> Self {
Self {
hnsw: WasmHnsw::with_params(dimension, metric, m, ef_construction),
id_to_idx: HashMap::new(),
idx_to_id: HashMap::new(),
next_idx: 0,
}
}
pub fn insert(&mut self, id: Id, vector: &[f32]) -> Result<()> {
self.hnsw.insert(id, vector.to_vec())
}
pub fn batch_insert(&mut self, items: Vec<(Id, Vec<f32>)>) -> Result<()> {
for (id, vector) in items {
self.insert(id, &vector)?;
}
Ok(())
}
pub fn optimize(&mut self, _vectors: &[(Id, Vec<f32>)]) -> Result<usize> {
Ok(self.hnsw.len())
}
pub fn remove(&mut self, id: &str) -> Result<()> {
self.hnsw.remove(id)
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(Id, f32)>> {
let ef_search = k.max(50);
self.search_with_ef(query, k, ef_search)
}
pub fn search_with_ef(
&self,
query: &[f32],
k: usize,
ef_search: usize,
) -> Result<Vec<(Id, f32)>> {
self.hnsw.search(query, k, ef_search)
}
pub fn save_index(&self, _path: &Path) -> Result<()> {
Ok(())
}
pub fn rebuild_from_vectors(&mut self, vectors: &[(Id, Vec<f32>)]) -> Result<()> {
self.hnsw.clear();
for (id, vector) in vectors {
self.insert(id.clone(), vector)?;
}
Ok(())
}
pub fn get_id_to_idx_map(&self) -> &HashMap<Id, usize> {
&self.id_to_idx
}
pub fn get_idx_to_id_map(&self) -> &HashMap<usize, Id> {
&self.idx_to_id
}
pub fn set_mappings(
&mut self,
id_to_idx: HashMap<Id, usize>,
idx_to_id: HashMap<usize, Id>,
next_idx: usize,
) {
self.id_to_idx = id_to_idx;
self.idx_to_id = idx_to_id;
self.next_idx = next_idx;
}
pub fn restore(
dimension: usize,
id_to_idx: HashMap<Id, usize>,
idx_to_id: HashMap<usize, Id>,
next_idx: usize,
) -> Result<Self> {
let mut backend = Self::new(dimension);
backend.set_mappings(id_to_idx, idx_to_id, next_idx);
Ok(backend)
}
pub fn len(&self) -> usize {
self.hnsw.len()
}
pub fn is_empty(&self) -> bool {
self.hnsw.is_empty()
}
pub fn dimension(&self) -> usize {
self.hnsw.stats().num_nodes
}
pub fn ids(&self) -> Vec<Id> {
self.hnsw.ids()
}
pub fn clear(&mut self) {
self.hnsw.clear();
}
pub fn stats(&self) -> String {
let stats = self.hnsw.stats();
format!(
"WASM HNSW Stats:\n\
- Nodes: {}\n\
- Edges: {}\n\
- Max Layer: {}\n\
- M: {}\n\
- ef_construction: {}\n\
- Layer distribution: {:?}",
stats.num_nodes,
stats.num_edges,
stats.max_layer,
stats.m,
stats.ef_construction,
stats.layer_distribution
)
}
pub fn to_visualizer(&self) -> Result<crate::graph_viz::HnswVisualizer> {
Ok(self.hnsw.to_visualizer())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts four 3-d vectors and checks the closest match for a query
    /// that points almost along the first axis.
    #[test]
    fn test_wasm_backend_hnsw() {
        let fixtures: [(&str, [f32; 3]); 4] = [
            ("v1", [1.0, 0.0, 0.0]),
            ("v2", [0.0, 1.0, 0.0]),
            ("v3", [1.0, 1.0, 0.0]),
            ("v4", [0.5, 0.5, 0.0]),
        ];

        let mut backend = WasmVectorBackend::new(3);
        for &(name, coords) in &fixtures {
            backend.insert(name.to_string(), &coords).unwrap();
        }
        assert_eq!(backend.len(), 4);

        let hits = backend.search(&[1.0, 0.1, 0.0], 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0, "v1");

        println!("{}", backend.stats());
    }

    /// Removing an entry shrinks the index and excludes it from results.
    #[test]
    fn test_wasm_backend_delete() {
        let fixtures: [(&str, [f32; 2]); 3] =
            [("v1", [1.0, 2.0]), ("v2", [3.0, 4.0]), ("v3", [5.0, 6.0])];

        let mut backend = WasmVectorBackend::new(2);
        for &(name, coords) in &fixtures {
            backend.insert(name.to_string(), &coords).unwrap();
        }
        assert_eq!(backend.len(), 3);

        backend.remove("v1").unwrap();
        assert_eq!(backend.len(), 2);

        // Asking for three neighbours can only yield the two survivors.
        let hits = backend.search(&[1.0, 2.0], 3).unwrap();
        assert_eq!(hits.len(), 2);
    }

    /// Batch insertion stores every item and keeps them searchable.
    #[test]
    fn test_wasm_backend_batch() {
        let rows: [(&str, [f32; 4]); 4] = [
            ("v1", [1.0, 0.0, 0.0, 0.0]),
            ("v2", [0.0, 1.0, 0.0, 0.0]),
            ("v3", [0.0, 0.0, 1.0, 0.0]),
            ("v4", [0.0, 0.0, 0.0, 1.0]),
        ];

        let mut backend = WasmVectorBackend::new(4);
        backend
            .batch_insert(
                rows.iter()
                    .map(|(name, coords)| (name.to_string(), coords.to_vec()))
                    .collect(),
            )
            .unwrap();
        assert_eq!(backend.len(), 4);

        let hits = backend.search(&[1.0, 0.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(hits[0].0, "v1");
    }

    /// A vector of the wrong length is rejected; a correct one is accepted.
    #[test]
    fn test_dimension_validation() {
        let mut backend = WasmVectorBackend::new(3);

        let too_short = backend.insert("v1".to_string(), &[1.0, 2.0]);
        assert!(too_short.is_err());

        let exact = backend.insert("v1".to_string(), &[1.0, 2.0, 3.0]);
        assert!(exact.is_ok());
    }

    /// Inserts 1000 random 128-d vectors and checks that a top-10 query
    /// comes back full and ordered by non-decreasing score.
    #[test]
    fn test_large_scale() {
        use rand::Rng;

        let mut rng = rand::thread_rng();
        let mut backend = WasmVectorBackend::with_params(128, Distance::Cosine, 16, 200);

        for n in 0..1000 {
            let sample: Vec<f32> = (0..128).map(|_| rng.gen()).collect();
            backend.insert(format!("v{}", n), &sample).unwrap();
        }
        assert_eq!(backend.len(), 1000);

        let probe: Vec<f32> = (0..128).map(|_| rng.gen()).collect();
        let hits = backend.search(&probe, 10).unwrap();
        assert_eq!(hits.len(), 10);

        // Each adjacent pair must be in non-decreasing score order.
        for pair in hits.windows(2) {
            assert!(pair[1].1 >= pair[0].1);
        }

        println!("{}", backend.stats());
    }
}