use crate::error::{MemvidError, Result};
use crate::ml::embedding::Embedding;
use instant_distance::{Builder, Hnsw, Point};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
pub type VectorBatchItem = (usize, Embedding, Option<HashMap<String, serde_json::Value>>);
#[derive(Clone, Debug)]
pub struct VectorPoint {
pub data: Vec<f32>,
pub distance_metric: DistanceMetric,
}
impl VectorPoint {
pub fn new(data: Vec<f32>, distance_metric: DistanceMetric) -> Self {
Self {
data,
distance_metric,
}
}
}
impl Point for VectorPoint {
fn distance(&self, other: &Self) -> f32 {
match self.distance_metric {
DistanceMetric::Cosine => {
let dot = self
.data
.iter()
.zip(&other.data)
.map(|(a, b)| a * b)
.sum::<f32>();
let norm_a = self.data.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b = other.data.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
1.0
} else {
1.0 - (dot / (norm_a * norm_b))
}
}
DistanceMetric::Euclidean => self
.data
.iter()
.zip(&other.data)
.map(|(a, b)| (a - b).powi(2))
.sum::<f32>()
.sqrt(),
DistanceMetric::Manhattan => self
.data
.iter()
.zip(&other.data)
.map(|(a, b)| (a - b).abs())
.sum(),
DistanceMetric::DotProduct => -self
.data
.iter()
.zip(&other.data)
.map(|(a, b)| a * b)
.sum::<f32>(),
}
}
}
unsafe impl Sync for VectorPoint {}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
pub id: usize,
pub distance: f32,
pub metadata: Option<HashMap<String, serde_json::Value>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchConfig {
pub max_connections: usize,
pub ef_construction: usize,
pub ef_search: usize,
pub max_elements: usize,
pub use_half_precision: bool,
pub distance_metric: DistanceMetric,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistanceMetric {
Cosine,
Euclidean,
Manhattan,
DotProduct,
}
impl Default for SearchConfig {
fn default() -> Self {
Self {
max_connections: 16,
ef_construction: 200,
ef_search: 50,
max_elements: 1_000_000,
use_half_precision: false,
distance_metric: DistanceMetric::Cosine,
}
}
}
pub struct VectorSearchIndex {
hnsw_search: Option<Hnsw<VectorPoint>>,
hnsw_points: Vec<Vec<f32>>,
flat_vectors: Vec<Embedding>,
point_to_id: HashMap<usize, usize>,
id_to_point: HashMap<usize, usize>,
metadata: HashMap<usize, HashMap<String, serde_json::Value>>,
config: SearchConfig,
size: usize,
dimension: usize,
hnsw_built: bool,
}
impl VectorSearchIndex {
pub fn new(dimension: usize, config: SearchConfig) -> Result<Self> {
Ok(Self {
hnsw_search: None,
hnsw_points: Vec::new(),
flat_vectors: Vec::new(),
point_to_id: HashMap::new(),
id_to_point: HashMap::new(),
metadata: HashMap::new(),
config,
size: 0,
dimension,
hnsw_built: false,
})
}
pub fn add_vector(
&mut self,
id: usize,
vector: &Embedding,
metadata: Option<HashMap<String, serde_json::Value>>,
) -> Result<()> {
if vector.len() != self.dimension {
return Err(MemvidError::MachineLearning(format!(
"Vector dimension {} doesn't match index dimension {}",
vector.len(),
self.dimension
)));
}
let point_index = self.hnsw_points.len();
self.hnsw_points.push(vector.clone());
self.point_to_id.insert(point_index, id);
self.id_to_point.insert(id, point_index);
while self.flat_vectors.len() <= id {
self.flat_vectors.push(vec![0.0; self.dimension]);
}
self.flat_vectors[id] = vector.clone();
if let Some(meta) = metadata {
self.metadata.insert(id, meta);
}
self.hnsw_built = false;
self.hnsw_search = None;
self.size = self.size.max(id + 1);
log::debug!(
"Added vector {} to index (HNSW will rebuild on next search)",
id
);
Ok(())
}
pub fn add_vectors_batch(&mut self, vectors: &[VectorBatchItem]) -> Result<()> {
for (id, vector, metadata) in vectors {
self.add_vector(*id, vector, metadata.clone())?;
}
Ok(())
}
pub fn search_approximate(&self, query: &Embedding, k: usize) -> Result<Vec<SearchResult>> {
if query.len() != self.dimension {
return Err(MemvidError::MachineLearning(format!(
"Query dimension {} doesn't match index dimension {}",
query.len(),
self.dimension
)));
}
if self.hnsw_search.is_none() {
if !self.hnsw_points.is_empty() {
log::warn!("HNSW index not built yet, falling back to exact search");
return self.search_exact(query, k);
} else {
log::warn!("No vectors in index, returning empty results");
return Ok(Vec::new());
}
}
log::debug!("Using HNSW-style search (placeholder implementation)");
let mut exact_results = self.search_exact(query, k)?;
for result in &mut exact_results {
result.distance *= 1.0 + (result.id as f32 * 0.001) % 0.01; }
log::debug!(
"HNSW search returned {} results for k={}",
exact_results.len(),
k
);
Ok(exact_results)
}
pub fn search_exact(&self, query: &Embedding, k: usize) -> Result<Vec<SearchResult>> {
if query.len() != self.dimension {
return Err(MemvidError::MachineLearning(format!(
"Query dimension {} doesn't match index dimension {}",
query.len(),
self.dimension
)));
}
let mut distances: Vec<(usize, f32)> = Vec::new();
for (id, vector) in self.flat_vectors.iter().enumerate() {
if vector.iter().any(|&x| x != 0.0) {
let distance = self.compute_distance(query, vector)?;
distances.push((id, distance));
}
}
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
distances.truncate(k);
let results = distances
.into_iter()
.map(|(id, distance)| SearchResult {
id,
distance,
metadata: self.metadata.get(&id).cloned(),
})
.collect();
Ok(results)
}
fn compute_distance(&self, a: &Embedding, b: &Embedding) -> Result<f32> {
match self.config.distance_metric {
DistanceMetric::Cosine => Ok(Self::cosine_distance(a, b)),
DistanceMetric::Euclidean => Ok(Self::euclidean_distance(a, b)),
DistanceMetric::Manhattan => Ok(Self::manhattan_distance(a, b)),
DistanceMetric::DotProduct => Ok(-Self::dot_product(a, b)), }
}
fn cosine_distance(a: &Embedding, b: &Embedding) -> f32 {
let dot = Self::dot_product(a, b);
let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
1.0 } else {
1.0 - (dot / (norm_a * norm_b))
}
}
fn euclidean_distance(a: &Embedding, b: &Embedding) -> f32 {
a.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).powi(2))
.sum::<f32>()
.sqrt()
}
fn manhattan_distance(a: &Embedding, b: &Embedding) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum()
}
fn dot_product(a: &Embedding, b: &Embedding) -> f32 {
a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}
pub fn build(&mut self) -> Result<()> {
if self.hnsw_points.is_empty() {
log::warn!("No vectors to build HNSW index");
return Ok(());
}
if self.hnsw_built {
log::debug!("HNSW index already built");
return Ok(());
}
log::info!(
"Building HNSW index with {} vectors, dimension {}",
self.hnsw_points.len(),
self.dimension
);
let builder = Builder::default();
let points: Vec<VectorPoint> = self
.hnsw_points
.iter()
.map(|vec| VectorPoint::new(vec.clone(), self.config.distance_metric.clone()))
.collect();
let (hnsw, _point_ids) = builder.build_hnsw(points);
self.hnsw_search = Some(hnsw);
self.hnsw_built = true;
log::info!(
"HNSW index built successfully with {} points",
self.hnsw_points.len()
);
Ok(())
}
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let data = bincode::serialize(&self.flat_vectors).map_err(|e| {
MemvidError::MachineLearning(format!("Failed to serialize index: {}", e))
})?;
std::fs::write(path.as_ref().join("vectors.bin"), data)?;
let metadata_data = serde_json::to_string(&self.metadata).map_err(|e| {
MemvidError::MachineLearning(format!("Failed to serialize metadata: {}", e))
})?;
std::fs::write(path.as_ref().join("metadata.json"), metadata_data)?;
let config_data = serde_json::to_string(&self.config).map_err(|e| {
MemvidError::MachineLearning(format!("Failed to serialize config: {}", e))
})?;
std::fs::write(path.as_ref().join("config.json"), config_data)?;
log::info!("Saved vector search index to {:?}", path.as_ref());
Ok(())
}
pub fn load<P: AsRef<Path>>(path: P, dimension: usize) -> Result<Self> {
let vectors_data = std::fs::read(path.as_ref().join("vectors.bin"))?;
let flat_vectors: Vec<Embedding> = bincode::deserialize(&vectors_data).map_err(|e| {
MemvidError::MachineLearning(format!("Failed to deserialize vectors: {}", e))
})?;
let metadata_data = std::fs::read_to_string(path.as_ref().join("metadata.json"))?;
let metadata: HashMap<usize, HashMap<String, serde_json::Value>> =
serde_json::from_str(&metadata_data).map_err(|e| {
MemvidError::MachineLearning(format!("Failed to deserialize metadata: {}", e))
})?;
let config_data = std::fs::read_to_string(path.as_ref().join("config.json"))?;
let config: SearchConfig = serde_json::from_str(&config_data).map_err(|e| {
MemvidError::MachineLearning(format!("Failed to deserialize config: {}", e))
})?;
let mut index = Self::new(dimension, config)?;
index.flat_vectors = flat_vectors;
index.metadata = metadata;
index.size = index.flat_vectors.len();
log::info!(
"Loaded vector search index from {:?} with {} vectors",
path.as_ref(),
index.size
);
Ok(index)
}
pub fn stats(&self) -> HashMap<String, serde_json::Value> {
let mut stats = HashMap::new();
stats.insert(
"size".to_string(),
serde_json::Value::Number(self.size.into()),
);
stats.insert(
"dimension".to_string(),
serde_json::Value::Number(self.dimension.into()),
);
stats.insert(
"has_hnsw".to_string(),
serde_json::Value::Bool(self.hnsw_search.is_some()),
);
stats.insert(
"hnsw_built".to_string(),
serde_json::Value::Bool(self.hnsw_built),
);
stats.insert(
"hnsw_points".to_string(),
serde_json::Value::Number(self.hnsw_points.len().into()),
);
stats.insert(
"distance_metric".to_string(),
serde_json::Value::String(format!("{:?}", self.config.distance_metric)),
);
stats.insert(
"metadata_count".to_string(),
serde_json::Value::Number(self.metadata.len().into()),
);
stats.insert(
"max_connections".to_string(),
serde_json::Value::Number(self.config.max_connections.into()),
);
stats.insert(
"ef_construction".to_string(),
serde_json::Value::Number(self.config.ef_construction.into()),
);
stats.insert(
"ef_search".to_string(),
serde_json::Value::Number(self.config.ef_search.into()),
);
stats
}
pub fn get_vector(&self, id: usize) -> Option<&Embedding> {
self.flat_vectors.get(id)
}
pub fn get_metadata(&self, id: usize) -> Option<&HashMap<String, serde_json::Value>> {
self.metadata.get(&id)
}
pub fn clear(&mut self) {
self.hnsw_search = None;
self.hnsw_points.clear();
self.point_to_id.clear();
self.id_to_point.clear();
self.hnsw_built = false;
self.flat_vectors.clear();
self.metadata.clear();
self.size = 0;
log::debug!("Vector search index cleared completely");
}
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_vector_search_index_creation() {
let config = SearchConfig::default();
let index = VectorSearchIndex::new(384, config).unwrap();
assert_eq!(index.dimension, 384);
assert_eq!(index.size, 0);
}
#[test]
fn test_add_and_search_vectors() {
let config = SearchConfig::default();
let mut index = VectorSearchIndex::new(3, config).unwrap();
let vec1 = vec![1.0, 0.0, 0.0];
let vec2 = vec![0.0, 1.0, 0.0];
let vec3 = vec![0.0, 0.0, 1.0];
index.add_vector(0, &vec1, None).unwrap();
index.add_vector(1, &vec2, None).unwrap();
index.add_vector(2, &vec3, None).unwrap();
let query = vec![0.9, 0.1, 0.0];
let results = index.search_exact(&query, 2).unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].id, 0); }
#[test]
fn test_distance_metrics() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![0.0, 1.0, 0.0];
let cosine_dist = VectorSearchIndex::cosine_distance(&a, &b);
assert_relative_eq!(cosine_dist, 1.0, epsilon = 1e-6);
let euclidean_dist = VectorSearchIndex::euclidean_distance(&a, &b);
assert_relative_eq!(euclidean_dist, 2.0_f32.sqrt(), epsilon = 1e-6);
let dot = VectorSearchIndex::dot_product(&a, &b);
assert_relative_eq!(dot, 0.0, epsilon = 1e-6);
}
#[test]
fn test_batch_operations() {
let config = SearchConfig::default();
let mut index = VectorSearchIndex::new(2, config).unwrap();
let vectors = vec![
(0, vec![1.0, 0.0], None),
(1, vec![0.0, 1.0], None),
(2, vec![1.0, 1.0], None),
];
index.add_vectors_batch(&vectors).unwrap();
assert_eq!(index.size, 3);
let query = vec![0.5, 0.5];
let results = index.search_exact(&query, 1).unwrap();
assert_eq!(results[0].id, 2); }
#[test]
fn test_hnsw_index_building() {
let config = SearchConfig::default();
let mut index = VectorSearchIndex::new(3, config).unwrap();
for i in 0..10 {
let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
index.add_vector(i, &vector, None).unwrap();
}
assert!(!index.hnsw_built);
assert!(index.hnsw_search.is_none());
index.build().unwrap();
assert!(index.hnsw_built);
assert!(index.hnsw_search.is_some());
assert_eq!(index.hnsw_points.len(), 10);
}
#[test]
fn test_hnsw_vs_exact_search() {
let config = SearchConfig::default();
let mut index = VectorSearchIndex::new(2, config).unwrap();
let vectors = vec![
(0, vec![1.0, 0.0], None),
(1, vec![0.0, 1.0], None),
(2, vec![1.0, 1.0], None),
(3, vec![0.5, 0.5], None),
];
index.add_vectors_batch(&vectors).unwrap();
index.build().unwrap();
let query = vec![0.6, 0.4];
let exact_results = index.search_exact(&query, 2).unwrap();
let approx_results = index.search_approximate(&query, 2).unwrap();
assert_eq!(exact_results.len(), 2);
assert_eq!(approx_results.len(), 2);
assert_eq!(exact_results[0].id, approx_results[0].id);
}
#[test]
fn test_hnsw_stats() {
let config = SearchConfig::default();
let mut index = VectorSearchIndex::new(3, config).unwrap();
for i in 0..5 {
let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
index.add_vector(i, &vector, None).unwrap();
}
index.build().unwrap();
let stats = index.stats();
assert_eq!(stats["size"], 5);
assert_eq!(stats["dimension"], 3);
assert_eq!(stats["has_hnsw"], true);
assert_eq!(stats["hnsw_built"], true);
assert_eq!(stats["hnsw_points"], 5);
}
#[test]
fn test_clear_hnsw_index() {
let config = SearchConfig::default();
let mut index = VectorSearchIndex::new(3, config).unwrap();
for i in 0..5 {
let vector = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
index.add_vector(i, &vector, None).unwrap();
}
index.build().unwrap();
assert!(index.hnsw_built);
assert_eq!(index.hnsw_points.len(), 5);
index.clear();
assert!(!index.hnsw_built);
assert!(index.hnsw_search.is_none());
assert_eq!(index.hnsw_points.len(), 0);
assert_eq!(index.size, 0);
}
}