use anyhow::{anyhow, Result};
use dashmap::DashMap;
use scirs2_core::metrics::{Counter, Histogram, MetricsRegistry, Timer};
use scirs2_core::ndarray_ext::ArrayView1;
use scirs2_core::random::Random;
use scirs2_core::rngs::StdRng;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Asynchronous key/vector storage with similarity search.
///
/// Implementations must be `Send + Sync` so one store can be shared across
/// async tasks.
#[async_trait::async_trait]
pub trait VectorStore: Send + Sync {
/// Insert a single vector under `id`; metadata is optional.
async fn insert(
&self,
id: String,
vector: Vec<f32>,
metadata: Option<HashMap<String, String>>,
) -> Result<()>;
/// Insert many vectors in one call.
async fn insert_batch(&self, vectors: Vec<VectorData>) -> Result<()>;
/// Return up to `query.k` `(id, similarity)` pairs, best first.
async fn search(&self, query: &VectorQuery) -> Result<Vec<(String, f32)>>;
/// Fetch a stored vector by id, if present.
async fn get(&self, id: &str) -> Result<Option<VectorData>>;
/// Remove a vector; returns whether anything was deleted.
async fn delete(&self, id: &str) -> Result<bool>;
/// Replace the vector (and metadata) stored under `id`.
async fn update(
&self,
id: String,
vector: Vec<f32>,
metadata: Option<HashMap<String, String>>,
) -> Result<()>;
/// Current number of stored vectors.
fn size(&self) -> usize;
/// (Re)build the configured index over the current contents.
async fn build_index(&self) -> Result<()>;
/// Snapshot of store-level statistics.
async fn get_statistics(&self) -> Result<VectorStoreStats>;
}
/// A stored vector together with its id, optional metadata, and the time it
/// was written.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorData {
pub id: String,
pub vector: Vec<f32>,
/// Free-form string key/value pairs used by metadata filters.
pub metadata: Option<HashMap<String, String>>,
/// Set to `SystemTime::now()` on insert/update.
pub timestamp: std::time::SystemTime,
}
/// Parameters for a similarity search.
#[derive(Debug, Clone)]
pub struct VectorQuery {
/// Query vector; must match the store's configured dimension.
pub vector: Vec<f32>,
/// Maximum number of results to return.
pub k: usize,
/// Similarity metric; falls back to the store's default when `None`.
pub metric: Option<SimilarityMetric>,
pub include_metadata: bool,
/// Optional metadata filters; all must match for a hit to be kept.
pub filters: Option<Vec<Filter>>,
/// Drop results scoring below this threshold.
pub min_similarity: Option<f32>,
}
/// A single metadata predicate: `field <operation> value`.
#[derive(Debug, Clone)]
pub struct Filter {
/// Metadata key to look up.
pub field: String,
pub operation: FilterOperation,
/// Right-hand side of the comparison.
pub value: String,
}
/// Comparison applied by a metadata [`Filter`].
#[derive(Debug, Clone)]
pub enum FilterOperation {
Equals,
NotEquals,
/// Substring match.
Contains,
StartsWith,
EndsWith,
/// Numeric when both sides parse as `f64`, lexicographic otherwise.
GreaterThan,
LessThan,
/// Value must be one of the listed strings.
In(Vec<String>),
/// Value must not be any of the listed strings.
NotIn(Vec<String>),
}
/// Supported similarity measures. Distance-based metrics (Euclidean,
/// Manhattan) are mapped to similarities in `compute_similarity`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum SimilarityMetric {
Cosine,
Euclidean,
Manhattan,
DotProduct,
/// Binary Jaccard over the sign pattern (x > 0).
Jaccard,
/// Binary Hamming similarity over the sign pattern (x > 0).
Hamming,
}
impl std::fmt::Display for SimilarityMetric {
    /// Render the metric as its lowercase snake_case name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            SimilarityMetric::Cosine => "cosine",
            SimilarityMetric::Euclidean => "euclidean",
            SimilarityMetric::Manhattan => "manhattan",
            SimilarityMetric::DotProduct => "dot_product",
            SimilarityMetric::Jaccard => "jaccard",
            SimilarityMetric::Hamming => "hamming",
        };
        write!(f, "{name}")
    }
}
/// Store-level statistics returned by `VectorStore::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreStats {
pub total_vectors: usize,
pub dimension: usize,
/// Debug-formatted name of the configured index type.
pub index_type: String,
pub index_build_time: std::time::Duration,
/// Approximate bytes used by the index (not the raw store).
pub memory_usage: usize,
pub avg_query_time: std::time::Duration,
/// hits / (hits + misses); 0.0 before any cached lookup.
pub cache_hit_rate: f32,
}
/// Thread-safe in-memory vector store with an optional ANN index and a
/// whole-result query cache.
pub struct InMemoryVectorStore {
/// Primary storage: id -> vector + metadata.
vectors: Arc<DashMap<String, VectorData>>,
/// Optional index; `None` until `build_index` is called.
index: Arc<RwLock<Option<Box<dyn VectorIndex>>>>,
config: VectorStoreConfig,
/// Full result lists keyed by a debug-formatted query string.
query_cache: Arc<DashMap<String, Vec<(String, f32)>>>,
stats: Arc<RwLock<VectorStoreStats>>,
cache_hits: Arc<std::sync::atomic::AtomicUsize>,
cache_misses: Arc<std::sync::atomic::AtomicUsize>,
insert_counter: Arc<Counter>,
search_counter: Arc<Counter>,
search_timer: Arc<Timer>,
index_build_timer: Arc<Timer>,
similarity_histogram: Arc<Histogram>,
metrics_registry: Arc<MetricsRegistry>,
}
/// Configuration for an `InMemoryVectorStore`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
/// Required length of every stored/query vector.
pub dimension: usize,
/// Metric used when a query does not specify one.
pub default_metric: SimilarityMetric,
pub index_type: IndexType,
pub enable_cache: bool,
/// Maximum number of cached query results.
pub cache_size: usize,
/// Cache TTL in seconds.
pub cache_ttl: u64,
pub batch_size: usize,
}
impl Default for VectorStoreConfig {
fn default() -> Self {
Self {
dimension: 128,
default_metric: SimilarityMetric::Cosine,
index_type: IndexType::Flat,
enable_cache: true,
cache_size: 10000,
cache_ttl: 3600,
batch_size: 1000,
}
}
}
/// Index algorithm selection. Only `Flat` and `HNSW` are currently
/// constructed by `build_index`; the others return an error there.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum IndexType {
/// Exact brute-force scan.
Flat,
/// Hierarchical Navigable Small World graph.
HNSW {
max_connections: usize,
ef_construction: usize,
ef_search: usize,
},
/// Inverted-file index (not yet implemented).
IVF {
num_clusters: usize,
num_probes: usize,
},
/// Locality-sensitive hashing (not yet implemented).
LSH {
num_tables: usize,
hash_length: usize,
},
/// Product quantization (not yet implemented).
PQ {
num_subquantizers: usize,
bits_per_subquantizer: usize,
},
}
/// Internal index abstraction used by the store for approximate/exact search.
#[async_trait::async_trait]
pub trait VectorIndex: Send + Sync {
/// Rebuild the index from the full contents of the store.
async fn build(&mut self, vectors: &DashMap<String, VectorData>) -> Result<()>;
/// Return up to `k` `(id, similarity)` pairs, best first.
async fn search(
&self,
query: &[f32],
k: usize,
metric: SimilarityMetric,
) -> Result<Vec<(String, f32)>>;
/// Incrementally add a single vector.
async fn add(&mut self, id: String, vector: Vec<f32>) -> Result<()>;
/// Remove a vector by id (no-op if absent).
async fn remove(&mut self, id: &str) -> Result<()>;
fn get_stats(&self) -> IndexStats;
}
/// Per-index bookkeeping exposed via `VectorIndex::get_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStats {
pub index_type: String,
pub num_vectors: usize,
pub build_time: std::time::Duration,
/// Approximate bytes used by the index's own copies of the data.
pub memory_usage: usize,
}
/// Exact brute-force index: keeps its own copy of every vector and scans
/// all of them per query.
pub struct FlatIndex {
vectors: HashMap<String, Vec<f32>>,
stats: IndexStats,
}
impl Default for FlatIndex {
fn default() -> Self {
Self::new()
}
}
impl FlatIndex {
pub fn new() -> Self {
Self {
vectors: HashMap::new(),
stats: IndexStats {
index_type: "Flat".to_string(),
num_vectors: 0,
build_time: std::time::Duration::from_secs(0),
memory_usage: 0,
},
}
}
}
#[async_trait::async_trait]
impl VectorIndex for FlatIndex {
    /// Snapshot every vector out of the shared map into the index's own
    /// storage and refresh the statistics.
    async fn build(&mut self, vectors: &DashMap<String, VectorData>) -> Result<()> {
        let started = std::time::Instant::now();
        self.vectors.clear();
        for entry in vectors.iter() {
            self.vectors
                .insert(entry.key().clone(), entry.value().vector.clone());
        }
        // Rough estimate: count * bytes-per-vector (4 bytes per f32),
        // derived from the first vector's length.
        let per_vector_bytes = self
            .vectors
            .values()
            .next()
            .map(|v| v.len() * 4)
            .unwrap_or(0);
        self.stats.num_vectors = self.vectors.len();
        self.stats.build_time = started.elapsed();
        self.stats.memory_usage = self.vectors.len() * per_vector_bytes;
        Ok(())
    }
    /// Exact search: score every stored vector, sort descending, keep `k`.
    async fn search(
        &self,
        query: &[f32],
        k: usize,
        metric: SimilarityMetric,
    ) -> Result<Vec<(String, f32)>> {
        let mut scored = self
            .vectors
            .iter()
            .map(|(id, vector)| {
                compute_similarity(query, vector, metric).map(|s| (id.clone(), s))
            })
            .collect::<Result<Vec<(String, f32)>>>()?;
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        scored.truncate(k);
        Ok(scored)
    }
    async fn add(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
        self.vectors.insert(id, vector);
        self.stats.num_vectors = self.vectors.len();
        Ok(())
    }
    async fn remove(&mut self, id: &str) -> Result<()> {
        self.vectors.remove(id);
        self.stats.num_vectors = self.vectors.len();
        Ok(())
    }
    fn get_stats(&self) -> IndexStats {
        self.stats.clone()
    }
}
/// Simplified HNSW (Hierarchical Navigable Small World) graph index.
/// Layer 0 contains every node; higher layers are sparser.
pub struct HNSWIndex {
#[allow(dead_code)]
max_connections: usize,
#[allow(dead_code)]
ef_construction: usize,
#[allow(dead_code)]
ef_search: usize,
/// layers[l] maps node id -> neighbor ids on layer l.
layers: Vec<HashMap<String, Vec<String>>>,
/// The index's own copy of the raw vectors.
vectors: HashMap<String, Vec<f32>>,
/// Start node for the greedy descent; a node on the top layer.
entry_point: Option<String>,
stats: IndexStats,
/// Seeded RNG for reproducible layer assignment.
rng: Random<StdRng>,
}
impl HNSWIndex {
pub fn new(max_connections: usize, ef_construction: usize, ef_search: usize) -> Self {
Self {
max_connections,
ef_construction,
ef_search,
layers: Vec::new(),
vectors: HashMap::new(),
entry_point: None,
stats: IndexStats {
index_type: "HNSW".to_string(),
num_vectors: 0,
build_time: std::time::Duration::from_secs(0),
memory_usage: 0,
},
rng: Random::seed(42), }
}
}
#[async_trait::async_trait]
impl VectorIndex for HNSWIndex {
/// Rebuild the graph from scratch: copy vectors in, assign each node a
/// random top layer, register it on every layer up to that, then wire
/// neighbor connections.
async fn build(&mut self, vectors: &DashMap<String, VectorData>) -> Result<()> {
let start = std::time::Instant::now();
self.layers.clear();
self.vectors.clear();
for entry in vectors.iter() {
let id = entry.key().clone();
let vector = entry.value().vector.clone();
self.vectors.insert(id.clone(), vector);
// Geometric layer assignment (p = 0.5), as in standard HNSW.
let layer = self.get_random_layer();
while self.layers.len() <= layer {
self.layers.push(HashMap::new());
}
for l in 0..=layer {
// Defensive: cannot trigger after the `while` above grew `layers`.
if l >= self.layers.len() {
self.layers.push(HashMap::new());
}
self.layers[l].insert(id.clone(), Vec::new());
}
// NOTE(review): after the growth loop, `layer >= layers.len() - 1`
// holds for every node that reaches the current top layer, so the
// entry point tracks the highest node seen so far — confirm intended.
if self.entry_point.is_none() || layer >= self.layers.len() - 1 {
self.entry_point = Some(id.clone());
}
}
// O(n^2)-per-layer neighbor wiring; see `build_connections`.
self.build_connections().await?;
self.stats.num_vectors = self.vectors.len();
self.stats.build_time = start.elapsed();
Ok(())
}
/// Approximate search via greedy descent plus bottom-layer expansion.
async fn search(
&self,
query: &[f32],
k: usize,
metric: SimilarityMetric,
) -> Result<Vec<(String, f32)>> {
self.beam_search(query, k, metric)
}
/// Incremental add: node joins layer 0 only; connections are not rewired
/// until the next full `build`.
async fn add(&mut self, id: String, vector: Vec<f32>) -> Result<()> {
self.vectors.insert(id.clone(), vector);
if self.layers.is_empty() {
self.layers.push(HashMap::new());
}
self.layers[0].insert(id.clone(), Vec::new());
if self.entry_point.is_none() {
self.entry_point = Some(id);
}
self.stats.num_vectors = self.vectors.len();
Ok(())
}
/// Remove the node from every layer. NOTE(review): a removed entry point
/// is not replaced here, and neighbors keep dangling references until the
/// next rebuild — searches skip missing ids, so this is safe but lossy.
async fn remove(&mut self, id: &str) -> Result<()> {
self.vectors.remove(id);
for layer in &mut self.layers {
layer.remove(id);
}
self.stats.num_vectors = self.vectors.len();
Ok(())
}
fn get_stats(&self) -> IndexStats {
self.stats.clone()
}
}
impl HNSWIndex {
    /// Draw a random top layer for a new node: geometric distribution with
    /// p = 0.5, capped at 16 layers.
    fn get_random_layer(&mut self) -> usize {
        let mut layer = 0;
        while (self.rng.random_f64() as f32) < 0.5 && layer < 16 {
            layer += 1;
        }
        layer
    }

    /// Wire up neighbor lists on every layer.
    ///
    /// For each node this scans every other node on the same layer and keeps
    /// the most cosine-similar `max_connections` (twice that on layer 0),
    /// then back-links neighbors that still have capacity. O(n^2) per layer,
    /// which is acceptable for the in-memory sizes this store targets.
    async fn build_connections(&mut self) -> Result<()> {
        let ids: Vec<String> = self.vectors.keys().cloned().collect();
        if ids.is_empty() {
            return Ok(());
        }
        for id in &ids {
            let vector = match self.vectors.get(id) {
                Some(v) => v.clone(),
                None => continue,
            };
            for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
                if !layer.contains_key(id) {
                    continue;
                }
                // Score every other node on this layer against `id`.
                let mut candidates: Vec<(String, f32)> = Vec::new();
                for (other_id, _) in layer.iter() {
                    if other_id == id {
                        continue;
                    }
                    if let Some(other_vector) = self.vectors.get(other_id) {
                        let similarity =
                            compute_similarity(&vector, other_vector, SimilarityMetric::Cosine)
                                .unwrap_or(0.0);
                        candidates.push((other_id.clone(), similarity));
                    }
                }
                candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
                // Layer 0 keeps twice as many links, per the HNSW heuristic.
                let max_conn = if layer_idx == 0 {
                    self.max_connections * 2
                } else {
                    self.max_connections
                };
                candidates.truncate(max_conn);
                let connections: Vec<String> =
                    candidates.into_iter().map(|(cid, _)| cid).collect();
                layer.insert(id.clone(), connections.clone());
                // Make links bidirectional while respecting capacity.
                for neighbor_id in connections {
                    if let Some(neighbor_connections) = layer.get_mut(&neighbor_id) {
                        if !neighbor_connections.contains(id)
                            && neighbor_connections.len() < max_conn
                        {
                            neighbor_connections.push(id.clone());
                        }
                    }
                }
            }
        }
        // Rough memory accounting: id bytes + f32 payloads + adjacency lists
        // (8 bytes per connection pointer-ish entry).
        let mut memory = 0;
        for (id, vec) in &self.vectors {
            memory += id.len() + vec.len() * 4;
        }
        for layer in &self.layers {
            for (id, connections) in layer {
                memory += id.len() + connections.len() * 8;
            }
        }
        self.stats.memory_usage = memory;
        Ok(())
    }

    /// HNSW query: greedy hill-climb from the entry point down to layer 1,
    /// then a bounded expansion (ef = max(k, ef_search)) on layer 0.
    ///
    /// Bug fix: the identifiers `&current_best` / `&current.id` had been
    /// corrupted into `¤t_best` / `¤t.id` (a mangled `&curren;` HTML
    /// entity), which broke compilation; they are restored here.
    fn beam_search(
        &self,
        query: &[f32],
        k: usize,
        metric: SimilarityMetric,
    ) -> Result<Vec<(String, f32)>> {
        if self.vectors.is_empty() {
            return Ok(Vec::new());
        }
        let entry = match &self.entry_point {
            Some(e) => e.clone(),
            None => return Ok(Vec::new()),
        };
        // Phase 1: greedy descent through the upper layers.
        let mut current_best = entry.clone();
        for layer_idx in (1..self.layers.len()).rev() {
            let layer = &self.layers[layer_idx];
            while let Some(current_vector) = self.vectors.get(&current_best) {
                let current_sim = compute_similarity(query, current_vector, metric)?;
                let mut improved = false;
                if let Some(neighbors) = layer.get(&current_best) {
                    for neighbor in neighbors {
                        if let Some(neighbor_vector) = self.vectors.get(neighbor) {
                            let neighbor_sim =
                                compute_similarity(query, neighbor_vector, metric)?;
                            if neighbor_sim > current_sim {
                                current_best = neighbor.clone();
                                improved = true;
                                break;
                            }
                        }
                    }
                }
                if !improved {
                    break;
                }
            }
        }
        if self.layers.is_empty() {
            return Ok(Vec::new());
        }
        // Phase 2: expand around the best entry on the bottom layer.
        let bottom_layer = &self.layers[0];
        let ef = std::cmp::max(k, self.ef_search);
        let mut candidates = BinaryHeap::new();
        let mut visited = HashSet::new();
        if let Some(entry_vector) = self.vectors.get(&current_best) {
            let sim = compute_similarity(query, entry_vector, metric)?;
            candidates.push(SimilarityItem {
                id: current_best.clone(),
                similarity: sim,
            });
            visited.insert(current_best);
        }
        // `results` is kept sorted best-first, so `last()` is the worst kept.
        let mut results: Vec<(String, f32)> = Vec::new();
        while let Some(current) = candidates.pop() {
            if results.len() < ef {
                results.push((current.id.clone(), current.similarity));
                results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
            } else if current.similarity
                > results.last().map(|r| r.1).unwrap_or(f32::NEG_INFINITY)
            {
                results.pop();
                results.push((current.id.clone(), current.similarity));
                results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
            }
            if let Some(neighbors) = bottom_layer.get(&current.id) {
                for neighbor in neighbors {
                    if !visited.contains(neighbor) {
                        visited.insert(neighbor.clone());
                        if let Some(neighbor_vector) = self.vectors.get(neighbor) {
                            let sim = compute_similarity(query, neighbor_vector, metric)?;
                            let worst_result =
                                results.last().map(|r| r.1).unwrap_or(f32::NEG_INFINITY);
                            // Only enqueue neighbors that could still make the cut.
                            if results.len() < ef || sim > worst_result {
                                candidates.push(SimilarityItem {
                                    id: neighbor.clone(),
                                    similarity: sim,
                                });
                            }
                        }
                    }
                }
            }
        }
        results.truncate(k);
        Ok(results)
    }
}
/// Heap entry pairing a vector id with its similarity to the query.
#[derive(Debug, Clone)]
struct SimilarityItem {
id: String,
similarity: f32,
}
// Equality and ordering consider only the score, never the id.
impl PartialEq for SimilarityItem {
fn eq(&self, other: &Self) -> bool {
self.similarity == other.similarity
}
}
impl Eq for SimilarityItem {}
impl PartialOrd for SimilarityItem {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for SimilarityItem {
// Reversed comparison: higher similarity sorts as "smaller", so a
// `BinaryHeap<SimilarityItem>` pops the LOWEST similarity first.
// NOTE(review): beam_search uses this heap as its frontier; confirm the
// min-first pop order (rather than best-first) is intentional there.
fn cmp(&self, other: &Self) -> Ordering {
other
.similarity
.partial_cmp(&self.similarity)
.unwrap_or(Ordering::Equal)
}
}
impl InMemoryVectorStore {
/// Construct an empty store (no index, empty cache) with zeroed stats and
/// freshly created metrics instruments.
///
/// NOTE(review): the counters/timers are created alongside the registry
/// but are not registered with it here — confirm registration happens
/// elsewhere (or is unnecessary for this metrics API).
pub fn new(config: VectorStoreConfig) -> Self {
let stats = VectorStoreStats {
total_vectors: 0,
dimension: config.dimension,
index_type: format!("{:?}", config.index_type),
index_build_time: std::time::Duration::from_secs(0),
memory_usage: 0,
avg_query_time: std::time::Duration::from_millis(0),
cache_hit_rate: 0.0,
};
let metrics_registry = Arc::new(MetricsRegistry::new());
let insert_counter = Arc::new(Counter::new("vector_inserts".to_string()));
let search_counter = Arc::new(Counter::new("vector_searches".to_string()));
let search_timer = Arc::new(Timer::new("search_latency".to_string()));
let index_build_timer = Arc::new(Timer::new("index_build_time".to_string()));
let similarity_histogram = Arc::new(Histogram::new("similarity_scores".to_string()));
Self {
vectors: Arc::new(DashMap::new()),
index: Arc::new(RwLock::new(None)),
config,
query_cache: Arc::new(DashMap::new()),
stats: Arc::new(RwLock::new(stats)),
cache_hits: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
cache_misses: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
insert_counter,
search_counter,
search_timer,
index_build_timer,
similarity_histogram,
metrics_registry,
}
}
fn apply_filters(&self, data: &VectorData, filters: &[Filter]) -> bool {
if let Some(metadata) = &data.metadata {
for filter in filters {
if let Some(value) = metadata.get(&filter.field) {
match &filter.operation {
FilterOperation::Equals => {
if value != &filter.value {
return false;
}
}
FilterOperation::NotEquals => {
if value == &filter.value {
return false;
}
}
FilterOperation::Contains => {
if !value.contains(&filter.value) {
return false;
}
}
FilterOperation::StartsWith => {
if !value.starts_with(&filter.value) {
return false;
}
}
FilterOperation::EndsWith => {
if !value.ends_with(&filter.value) {
return false;
}
}
FilterOperation::In(values) => {
if !values.contains(value) {
return false;
}
}
FilterOperation::NotIn(values) => {
if values.contains(value) {
return false;
}
}
FilterOperation::GreaterThan => {
if let (Ok(val_num), Ok(filter_num)) =
(value.parse::<f64>(), filter.value.parse::<f64>())
{
if val_num <= filter_num {
return false;
}
} else {
if value <= &filter.value {
return false;
}
}
}
FilterOperation::LessThan => {
if let (Ok(val_num), Ok(filter_num)) =
(value.parse::<f64>(), filter.value.parse::<f64>())
{
if val_num >= filter_num {
return false;
}
} else {
if value >= &filter.value {
return false;
}
}
}
}
} else {
return false; }
}
} else if !filters.is_empty() {
return false; }
true
}
}
#[async_trait::async_trait]
impl VectorStore for InMemoryVectorStore {
    /// Insert a vector, updating the live index (if built) and statistics.
    ///
    /// Fixes: the original re-read the map after inserting and `expect`ed
    /// the entry to exist — a concurrent `delete` racing between the two
    /// calls could panic. We now clone the vector up front instead. Also
    /// invalidates the query cache, which previously served stale results
    /// after inserts.
    async fn insert(
        &self,
        id: String,
        vector: Vec<f32>,
        metadata: Option<HashMap<String, String>>,
    ) -> Result<()> {
        if vector.len() != self.config.dimension {
            return Err(anyhow!(
                "Vector dimension mismatch: expected {}, got {}",
                self.config.dimension,
                vector.len()
            ));
        }
        // Keep a copy for the index so we never have to re-read the map
        // after inserting.
        let vector_for_index = vector.clone();
        let data = VectorData {
            id: id.clone(),
            vector,
            metadata,
            timestamp: std::time::SystemTime::now(),
        };
        self.vectors.insert(id.clone(), data);
        self.insert_counter.inc();
        if let Some(index) = self.index.write().await.as_mut() {
            index.add(id, vector_for_index).await?;
        }
        // Cached query results may now be stale.
        if self.config.enable_cache {
            self.query_cache.clear();
        }
        let mut stats = self.stats.write().await;
        stats.total_vectors = self.vectors.len();
        Ok(())
    }
async fn insert_batch(&self, vectors: Vec<VectorData>) -> Result<()> {
for data in vectors {
if data.vector.len() != self.config.dimension {
return Err(anyhow!("Vector dimension mismatch"));
}
self.vectors.insert(data.id.clone(), data);
}
if self.index.read().await.is_some() {
self.build_index().await?;
}
Ok(())
}
async fn search(&self, query: &VectorQuery) -> Result<Vec<(String, f32)>> {
self.search_counter.inc();
let metric = query.metric.unwrap_or(self.config.default_metric);
if self.config.enable_cache {
let cache_key = format!(
"{:?}_{}_{}_{:?}",
query.vector, query.k, metric, query.filters
);
if let Some(cached) = self.query_cache.get(&cache_key) {
self.cache_hits
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
return Ok(cached.clone());
} else {
self.cache_misses
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
}
let start = std::time::Instant::now();
let results = if let Some(index) = self.index.read().await.as_ref() {
index.search(&query.vector, query.k, metric).await?
} else {
let mut similarities = Vec::new();
for entry in self.vectors.iter() {
let data = entry.value();
if let Some(filters) = &query.filters {
if !self.apply_filters(data, filters) {
continue;
}
}
let similarity = compute_similarity(&query.vector, &data.vector, metric)?;
self.similarity_histogram.observe(similarity as f64);
if let Some(min_sim) = query.min_similarity {
if similarity < min_sim {
continue;
}
}
similarities.push((entry.key().clone(), similarity));
}
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
similarities.truncate(query.k);
similarities
};
let query_time = start.elapsed();
self.search_timer.observe(query_time);
let mut stats = self.stats.write().await;
stats.avg_query_time = (stats.avg_query_time + query_time) / 2;
if self.config.enable_cache {
let cache_key = format!(
"{:?}_{}_{}_{:?}",
query.vector, query.k, metric, query.filters
);
self.query_cache.insert(cache_key, results.clone());
}
Ok(results)
}
async fn get(&self, id: &str) -> Result<Option<VectorData>> {
Ok(self.vectors.get(id).map(|entry| entry.value().clone()))
}
async fn delete(&self, id: &str) -> Result<bool> {
let removed = self.vectors.remove(id).is_some();
if removed {
if let Some(index) = self.index.write().await.as_mut() {
index.remove(id).await?;
}
let mut stats = self.stats.write().await;
stats.total_vectors = self.vectors.len();
}
Ok(removed)
}
async fn update(
&self,
id: String,
vector: Vec<f32>,
metadata: Option<HashMap<String, String>>,
) -> Result<()> {
if vector.len() != self.config.dimension {
return Err(anyhow!("Vector dimension mismatch"));
}
let data = VectorData {
id: id.clone(),
vector: vector.clone(),
metadata,
timestamp: std::time::SystemTime::now(),
};
self.vectors.insert(id.clone(), data);
if let Some(index) = self.index.write().await.as_mut() {
index.remove(&id).await?;
index.add(id, vector).await?;
}
Ok(())
}
/// Number of vectors currently stored (DashMap tracks its own length).
fn size(&self) -> usize {
self.vectors.len()
}
async fn build_index(&self) -> Result<()> {
let start = std::time::Instant::now();
let mut new_index: Box<dyn VectorIndex> = match &self.config.index_type {
IndexType::Flat => Box::new(FlatIndex::new()),
IndexType::HNSW {
max_connections,
ef_construction,
ef_search,
} => Box::new(HNSWIndex::new(
*max_connections,
*ef_construction,
*ef_search,
)),
_ => return Err(anyhow!("Index type not yet implemented")),
};
new_index.build(&self.vectors).await?;
*self.index.write().await = Some(new_index);
let mut stats = self.stats.write().await;
stats.index_build_time = start.elapsed();
Ok(())
}
async fn get_statistics(&self) -> Result<VectorStoreStats> {
let mut stats = self.stats.read().await.clone();
let hits = self.cache_hits.load(std::sync::atomic::Ordering::Relaxed);
let misses = self.cache_misses.load(std::sync::atomic::Ordering::Relaxed);
let total = hits + misses;
stats.cache_hit_rate = if total > 0 {
hits as f32 / total as f32
} else {
0.0
};
Ok(stats)
}
}
impl InMemoryVectorStore {
/// Aggregate the store's counters, timers, and histogram into a
/// serializable snapshot.
pub fn get_performance_metrics(&self) -> VectorStorePerformanceMetrics {
let insert_count = self.insert_counter.get();
let search_count = self.search_counter.get();
let search_timer_stats = self.search_timer.get_stats();
let index_timer_stats = self.index_build_timer.get_stats();
let similarity_hist_stats = self.similarity_histogram.get_stats();
VectorStorePerformanceMetrics {
total_inserts: insert_count,
total_searches: search_count,
// The *1000.0 suggests timer means are in seconds — TODO confirm.
avg_search_latency_ms: search_timer_stats.mean * 1000.0,
// NOTE(review): min/max are hardcoded to 0.0 — presumably the timer
// stats type does not expose them; confirm or wire them through.
min_search_latency_ms: 0.0, max_search_latency_ms: 0.0, avg_index_build_time_ms: index_timer_stats.mean * 1000.0,
avg_similarity_score: similarity_hist_stats.mean,
similarity_count: similarity_hist_stats.count,
}
}
/// Shared metrics registry, e.g. for external exporters.
pub fn metrics_registry(&self) -> &Arc<MetricsRegistry> {
&self.metrics_registry
}
}
/// Snapshot of store performance, produced by `get_performance_metrics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStorePerformanceMetrics {
pub total_inserts: u64,
pub total_searches: u64,
pub avg_search_latency_ms: f64,
/// Currently always 0.0 (not wired through from the timer).
pub min_search_latency_ms: f64,
/// Currently always 0.0 (not wired through from the timer).
pub max_search_latency_ms: f64,
pub avg_index_build_time_ms: f64,
pub avg_similarity_score: f64,
/// Number of similarity observations recorded in the histogram.
pub similarity_count: u64,
}
impl std::fmt::Display for VectorStorePerformanceMetrics {
    /// Compact single-line summary suitable for logs.
    /// (`avg_index_build_time_ms` is intentionally not shown.)
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            total_inserts,
            total_searches,
            avg_search_latency_ms,
            min_search_latency_ms,
            max_search_latency_ms,
            avg_similarity_score,
            similarity_count,
            ..
        } = self;
        write!(
            f,
            "VectorPerf {{ inserts: {}, searches: {}, avg_latency: {:.2}ms, min: {:.2}ms, max: {:.2}ms, avg_similarity: {:.3}, computations: {} }}",
            total_inserts,
            total_searches,
            avg_search_latency_ms,
            min_search_latency_ms,
            max_search_latency_ms,
            avg_similarity_score,
            similarity_count
        )
    }
}
/// Compute the similarity between two equal-length vectors under `metric`.
///
/// Distance metrics (Euclidean, Manhattan) are mapped to similarities via
/// 1 / (1 + d). Jaccard and Hamming binarize each component with `x > 0`.
///
/// Fixes: Hamming on two empty vectors previously evaluated 0/0 and
/// returned NaN; empty inputs now count as identical (1.0).
///
/// # Errors
/// Returns an error when the vectors differ in length.
pub fn compute_similarity(a: &[f32], b: &[f32], metric: SimilarityMetric) -> Result<f32> {
    if a.len() != b.len() {
        return Err(anyhow!("Vector dimension mismatch"));
    }
    let a_arr = ArrayView1::from(a);
    let b_arr = ArrayView1::from(b);
    match metric {
        SimilarityMetric::Cosine => {
            let dot_product = a_arr.dot(&b_arr);
            let norm_a = a_arr.dot(&a_arr).sqrt();
            let norm_b = b_arr.dot(&b_arr).sqrt();
            // A zero vector has no direction; define its similarity as 0.
            if norm_a == 0.0 || norm_b == 0.0 {
                Ok(0.0)
            } else {
                Ok(dot_product / (norm_a * norm_b))
            }
        }
        SimilarityMetric::Euclidean => {
            let diff = &a_arr - &b_arr;
            let distance = diff.dot(&diff).sqrt();
            Ok(1.0 / (1.0 + distance))
        }
        SimilarityMetric::Manhattan => {
            let diff = &a_arr - &b_arr;
            let distance = diff.mapv(f32::abs).sum();
            Ok(1.0 / (1.0 + distance))
        }
        SimilarityMetric::DotProduct => Ok(a_arr.dot(&b_arr)),
        SimilarityMetric::Jaccard => {
            let a_binary = a_arr.mapv(|x| if x > 0.0 { 1u32 } else { 0 });
            let b_binary = b_arr.mapv(|x| if x > 0.0 { 1u32 } else { 0 });
            let intersection: u32 = (&a_binary * &b_binary).sum();
            let union: u32 = a_binary
                .iter()
                .zip(b_binary.iter())
                .map(|(x, y)| if *x > 0 || *y > 0 { 1 } else { 0 })
                .sum();
            if union == 0 {
                Ok(0.0)
            } else {
                Ok(intersection as f32 / union as f32)
            }
        }
        SimilarityMetric::Hamming => {
            // Two empty vectors are trivially identical; avoid 0/0 = NaN.
            if a.is_empty() {
                return Ok(1.0);
            }
            let a_binary = a_arr.mapv(|x| if x > 0.0 { 1u32 } else { 0 });
            let b_binary = b_arr.mapv(|x| if x > 0.0 { 1u32 } else { 0 });
            let differences: u32 = a_binary
                .iter()
                .zip(b_binary.iter())
                .map(|(x, y)| if x != y { 1 } else { 0 })
                .sum();
            Ok(1.0 - (differences as f32 / a.len() as f32))
        }
    }
}
/// Score one query against many candidate vectors.
///
/// For more than 100 candidates the common metrics are computed in parallel
/// with rayon (the query's cosine norm is hoisted out of the loop); smaller
/// batches fall back to the serial `compute_similarity` path.
///
/// # Errors
/// Returns an error if any candidate's length differs from the query's.
pub fn compute_similarities_batch(
query: &[f32],
candidates: &[&[f32]],
metric: SimilarityMetric,
) -> Result<Vec<f32>> {
if candidates.is_empty() {
return Ok(Vec::new());
}
// Validate all lengths up front so the parallel path cannot panic.
for candidate in candidates {
if candidate.len() != query.len() {
return Err(anyhow!("Vector dimension mismatch in batch"));
}
}
let query_arr = ArrayView1::from(query);
// Hoist the query norm for cosine; a zero query short-circuits to zeros.
let query_norm = match metric {
SimilarityMetric::Cosine => {
let norm = query_arr.dot(&query_arr).sqrt();
if norm == 0.0 {
return Ok(vec![0.0; candidates.len()]);
}
norm
}
_ => 1.0,
};
// Threshold of 100 chosen so scheduling overhead stays negligible.
if candidates.len() > 100 {
use rayon::prelude::*;
let results: Vec<f32> = candidates
.par_iter()
.map(|candidate| {
let c_arr = ArrayView1::from(*candidate);
match metric {
SimilarityMetric::Cosine => {
let dot = query_arr.dot(&c_arr);
let c_norm = c_arr.dot(&c_arr).sqrt();
if c_norm == 0.0 {
0.0
} else {
dot / (query_norm * c_norm)
}
}
SimilarityMetric::Euclidean => {
let diff = &query_arr - &c_arr;
let dist = diff.dot(&diff).sqrt();
1.0 / (1.0 + dist)
}
SimilarityMetric::Manhattan => {
let diff = &query_arr - &c_arr;
let dist = diff.mapv(f32::abs).sum();
1.0 / (1.0 + dist)
}
SimilarityMetric::DotProduct => query_arr.dot(&c_arr),
// Jaccard/Hamming fall back to the scalar helper; lengths were
// validated above, so the error arm cannot trigger here.
_ => compute_similarity(query, candidate, metric).unwrap_or(0.0),
}
})
.collect();
Ok(results)
} else {
candidates
.iter()
.map(|candidate| compute_similarity(query, candidate, metric))
.collect()
}
}
/// Factory: build the default in-memory store behind the `VectorStore` trait.
pub fn create_vector_store(config: &VectorStoreConfig) -> Result<Arc<dyn VectorStore>> {
    let store = InMemoryVectorStore::new(config.clone());
    Ok(Arc::new(store))
}
#[cfg(test)]
mod tests {
use super::*;
// A fresh store starts empty.
#[tokio::test]
async fn test_vector_store_creation() {
let config = VectorStoreConfig::default();
let store = InMemoryVectorStore::new(config);
assert_eq!(store.size(), 0);
}
// Insert then get round-trips both vector and metadata.
#[tokio::test]
async fn test_vector_insertion_and_retrieval() {
let config = VectorStoreConfig {
dimension: 3,
..Default::default()
};
let store = InMemoryVectorStore::new(config);
let vector = vec![1.0, 2.0, 3.0];
let metadata = Some(
[("type".to_string(), "test".to_string())]
.iter()
.cloned()
.collect(),
);
store
.insert("test1".to_string(), vector.clone(), metadata.clone())
.await
.expect("operation should succeed");
let retrieved = store
.get("test1")
.await
.expect("async operation should succeed")
.expect("operation should succeed");
assert_eq!(retrieved.vector, vector);
assert_eq!(retrieved.metadata, metadata);
}
// Brute-force (no index) cosine search returns the exact match first.
#[tokio::test]
async fn test_similarity_search() {
let config = VectorStoreConfig {
dimension: 3,
..Default::default()
};
let store = InMemoryVectorStore::new(config);
store
.insert("vec1".to_string(), vec![1.0, 0.0, 0.0], None)
.await
.expect("operation should succeed");
store
.insert("vec2".to_string(), vec![0.9, 0.1, 0.0], None)
.await
.expect("operation should succeed");
store
.insert("vec3".to_string(), vec![0.0, 1.0, 0.0], None)
.await
.expect("operation should succeed");
let query = VectorQuery {
vector: vec![1.0, 0.0, 0.0],
k: 2,
metric: Some(SimilarityMetric::Cosine),
include_metadata: false,
filters: None,
min_similarity: None,
};
let results = store
.search(&query)
.await
.expect("async operation should succeed");
assert_eq!(results.len(), 2);
assert_eq!(results[0].0, "vec1"); }
// Parallel vectors have cosine 1; dot product is the raw inner product.
#[test]
fn test_similarity_metrics() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![2.0, 4.0, 6.0];
let cosine = compute_similarity(&a, &b, SimilarityMetric::Cosine)
.expect("similarity computation should succeed");
assert!((cosine - 1.0).abs() < 1e-6);
let dot_product = compute_similarity(&a, &b, SimilarityMetric::DotProduct)
.expect("similarity computation should succeed");
assert_eq!(dot_product, 28.0); }
// Building a flat index leaves statistics consistent with the contents.
#[tokio::test]
async fn test_index_building() {
let config = VectorStoreConfig {
dimension: 3,
index_type: IndexType::Flat,
..Default::default()
};
let store = InMemoryVectorStore::new(config);
store
.insert("vec1".to_string(), vec![1.0, 0.0, 0.0], None)
.await
.expect("operation should succeed");
store
.insert("vec2".to_string(), vec![0.0, 1.0, 0.0], None)
.await
.expect("operation should succeed");
store
.build_index()
.await
.expect("async operation should succeed");
let stats = store
.get_statistics()
.await
.expect("async operation should succeed");
assert_eq!(stats.total_vectors, 2);
}
// HNSW build over points on the unit circle reports the right counts.
#[tokio::test]
async fn test_hnsw_index_building() {
let config = VectorStoreConfig {
dimension: 3,
index_type: IndexType::HNSW {
max_connections: 16,
ef_construction: 100,
ef_search: 50,
},
..Default::default()
};
let store = InMemoryVectorStore::new(config);
for i in 0..10 {
let angle = (i as f32) * std::f32::consts::PI / 5.0;
store
.insert(format!("vec{i}"), vec![angle.cos(), angle.sin(), 0.0], None)
.await
.expect("operation should succeed");
}
store
.build_index()
.await
.expect("async operation should succeed");
let stats = store
.get_statistics()
.await
.expect("async operation should succeed");
assert_eq!(stats.total_vectors, 10);
assert!(stats.index_type.contains("HNSW"));
}
// HNSW search finds the exact-match point (vec0 at angle 0) first.
#[tokio::test]
async fn test_hnsw_search() {
let config = VectorStoreConfig {
dimension: 3,
index_type: IndexType::HNSW {
max_connections: 16,
ef_construction: 100,
ef_search: 50,
},
..Default::default()
};
let store = InMemoryVectorStore::new(config);
for i in 0..20 {
let angle = (i as f32) * std::f32::consts::PI * 2.0 / 20.0;
store
.insert(format!("vec{i}"), vec![angle.cos(), angle.sin(), 0.0], None)
.await
.expect("operation should succeed");
}
store
.build_index()
.await
.expect("async operation should succeed");
let query = VectorQuery {
vector: vec![1.0, 0.0, 0.0],
k: 3,
metric: Some(SimilarityMetric::Cosine),
include_metadata: false,
filters: None,
min_similarity: None,
};
let results = store
.search(&query)
.await
.expect("async operation should succeed");
assert!(!results.is_empty());
assert_eq!(results[0].0, "vec0");
assert!((results[0].1 - 1.0).abs() < 0.01);
}
// On a larger deterministic dataset, results come back sorted descending.
#[tokio::test]
async fn test_hnsw_large_dataset() {
let config = VectorStoreConfig {
dimension: 10,
index_type: IndexType::HNSW {
max_connections: 16,
ef_construction: 100,
ef_search: 50,
},
..Default::default()
};
let store = InMemoryVectorStore::new(config);
for i in 0..100 {
let vec: Vec<f32> = (0..10)
.map(|j| ((i * 7 + j * 13) % 100) as f32 / 100.0)
.collect();
store
.insert(format!("vec{i}"), vec, None)
.await
.expect("async operation should succeed");
}
store
.build_index()
.await
.expect("async operation should succeed");
let query_vec = vec![0.5f32; 10];
let query = VectorQuery {
vector: query_vec,
k: 10,
metric: Some(SimilarityMetric::Cosine),
include_metadata: false,
filters: None,
min_similarity: None,
};
let results = store
.search(&query)
.await
.expect("async operation should succeed");
assert!(!results.is_empty());
assert!(results.len() <= 10);
for i in 1..results.len() {
assert!(results[i - 1].1 >= results[i].1);
}
}
// Small batch (< parallel threshold) exercises the serial path.
#[test]
fn test_batch_similarity_computation() {
let query = vec![1.0, 0.0, 0.0];
let candidates: Vec<&[f32]> =
vec![&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0], &[0.707, 0.707, 0.0]];
let similarities =
compute_similarities_batch(&query, &candidates, SimilarityMetric::Cosine)
.expect("batch similarity computation should succeed");
assert_eq!(similarities.len(), 3);
assert!((similarities[0] - 1.0).abs() < 0.01);
assert!(similarities[1].abs() < 0.01);
assert!((similarities[2] - 0.707).abs() < 0.01);
}
}