use std::collections::HashMap;
/// Similarity function used to rank stored vectors against a query.
///
/// `VectorStore::search` turns every metric into a "higher is better" score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Cosine of the angle between query and vector; a zero-length vector scores `0.0`.
    Cosine,
    /// Euclidean distance, reported negated so that nearer vectors score higher.
    L2,
    /// Raw inner product; not normalized, so vector magnitude affects the ranking.
    DotProduct,
}
/// A vector to insert into a `VectorStore`, together with its id and metadata.
#[derive(Clone, Debug)]
pub struct VectorEntry {
    /// Caller-chosen identifier; reported back in search results and used by `delete`.
    pub id: u64,
    /// Vector components; the length must equal the store's configured `dimensions`.
    pub vector: Vec<f32>,
    /// String key/value pairs; retained only when the store has `store_metadata` enabled.
    pub metadata: HashMap<String, String>,
}
/// One hit returned by `VectorStore::search`.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Id of the matched entry.
    pub id: u64,
    /// Ranking score under the store's metric; larger means a better match.
    pub score: f32,
    /// The entry's metadata, or `None` when the store does not keep metadata.
    pub metadata: Option<HashMap<String, String>>,
}
/// Construction-time settings for a `VectorStore`.
#[derive(Clone, Debug)]
pub struct VectorStoreConfig {
    /// Required length of every inserted vector and every search query.
    pub dimensions: u32,
    /// Scoring function used by `search`.
    pub metric: DistanceMetric,
    /// Maximum number of entries; `insert` fails with `CapacityExceeded` beyond it.
    pub capacity: usize,
    /// Whether `insert` keeps each entry's metadata so `search` can return it.
    pub store_metadata: bool,
}

impl Default for VectorStoreConfig {
    /// 768 dimensions, cosine similarity, up to 1,000,000 entries, metadata retained.
    fn default() -> Self {
        let dimensions = 768;
        let capacity = 1_000_000;
        VectorStoreConfig {
            metric: DistanceMetric::Cosine,
            store_metadata: true,
            dimensions,
            capacity,
        }
    }
}
/// In-memory vector index that answers top-k queries by scoring every stored vector.
///
/// Entries live in parallel arrays: entry `i` is `vectors[i * dim..(i + 1) * dim]`,
/// `ids[i]` and `metadata[i]`, where `dim` is `config.dimensions`. All vectors share one
/// contiguous `Vec<f32>`, which `flat_vectors` exposes as a single slice.
pub struct VectorStore {
    /// Settings fixed at construction (dimensionality, metric, capacity limit, metadata flag).
    config: VectorStoreConfig,
    /// Every stored vector, concatenated; length is always `count * dimensions`.
    vectors: Vec<f32>,
    /// Id of each entry in storage order; uniqueness is not enforced by `insert`.
    ids: Vec<u64>,
    /// Per-entry metadata; `None` for every entry when `config.store_metadata` is false.
    metadata: Vec<Option<HashMap<String, String>>>,
    /// Number of stored entries; always equal to `ids.len()`.
    count: usize,
}

impl VectorStore {
    /// Most entries `new` reserves space for up front.
    ///
    /// `capacity` is a hard limit enforced by `insert`, not a size hint. Reserving all of it
    /// eagerly made `VectorStoreConfig::default()` (1_000_000 entries x 768 dims) request
    /// roughly 3 GiB at construction. Beyond this many entries the buffers grow on demand.
    const INITIAL_RESERVE_ENTRIES: usize = 4096;

    /// Creates an empty store configured by `config`.
    pub fn new(config: VectorStoreConfig) -> Self {
        let dim = config.dimensions as usize;
        let reserve = config.capacity.min(Self::INITIAL_RESERVE_ENTRIES);
        Self {
            vectors: Vec::with_capacity(reserve.saturating_mul(dim)),
            ids: Vec::with_capacity(reserve),
            metadata: Vec::with_capacity(reserve),
            count: 0,
            config,
        }
    }

    /// Appends `entry` to the store.
    ///
    /// The entry's metadata is kept only when `config.store_metadata` is set. Ids are not
    /// checked for uniqueness.
    ///
    /// # Errors
    ///
    /// - [`VectorStoreError::DimensionMismatch`] if `entry.vector.len()` differs from the
    ///   configured dimensions (checked first).
    /// - [`VectorStoreError::CapacityExceeded`] if the store already holds `capacity`
    ///   entries.
    pub fn insert(&mut self, entry: VectorEntry) -> Result<(), VectorStoreError> {
        if entry.vector.len() != self.config.dimensions as usize {
            return Err(VectorStoreError::DimensionMismatch {
                expected: self.config.dimensions as usize,
                got: entry.vector.len(),
            });
        }
        if self.count >= self.config.capacity {
            return Err(VectorStoreError::CapacityExceeded {
                capacity: self.config.capacity,
            });
        }
        self.vectors.extend_from_slice(&entry.vector);
        self.ids.push(entry.id);
        self.metadata.push(if self.config.store_metadata {
            Some(entry.metadata)
        } else {
            None
        });
        self.count += 1;
        Ok(())
    }

    /// Returns up to `k` entries ranked best-first against `query`.
    ///
    /// Every metric is "higher is better": cosine similarity, dot product, or the *negated*
    /// Euclidean distance for `L2`. Entries with equal scores keep insertion order. A NaN
    /// score (e.g. from a NaN in `query`) ranks after every real score. Metadata is cloned
    /// only for the returned entries.
    ///
    /// # Errors
    ///
    /// [`VectorStoreError::DimensionMismatch`] if `query.len()` differs from the configured
    /// dimensions.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>, VectorStoreError> {
        let dim = self.config.dimensions as usize;
        if query.len() != dim {
            return Err(VectorStoreError::DimensionMismatch {
                expected: dim,
                got: query.len(),
            });
        }
        let mut scores: Vec<(usize, f32)> = (0..self.count)
            .map(|i| {
                let vec_start = i * dim;
                let vec_slice = &self.vectors[vec_start..vec_start + dim];
                let score = match self.config.metric {
                    DistanceMetric::Cosine => cosine_similarity(query, vec_slice),
                    // Negated so that, like the other metrics, larger means closer.
                    DistanceMetric::L2 => -l2_distance(query, vec_slice),
                    DistanceMetric::DotProduct => dot_product(query, vec_slice),
                };
                (i, score)
            })
            .collect();
        // Stable sort keeps insertion order among equal scores.
        scores.sort_by(|a, b| Self::compare_scores_desc(a.1, b.1));
        scores.truncate(k);
        Ok(scores
            .into_iter()
            .map(|(idx, score)| SearchResult {
                id: self.ids[idx],
                score,
                metadata: self.metadata.get(idx).and_then(|m| m.clone()),
            })
            .collect())
    }

    /// Orders scores best (largest) first and places NaN after every real score.
    ///
    /// `partial_cmp(..).unwrap_or(Equal)` treats NaN as equal to everything, which is not a
    /// total order. `sort_by` may panic on such comparators (Rust >= 1.81) and otherwise
    /// yields an arbitrary order. This comparator is a total order for all `f32` inputs.
    fn compare_scores_desc(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;
        match (a.is_nan(), b.is_nan()) {
            // Non-NaN floats always compare; the fallback is never taken.
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        }
    }

    /// Removes the first entry whose id equals `id`; returns whether one was found.
    ///
    /// Remaining entries keep their relative order, at the cost of shifting every later
    /// entry (O(n)).
    pub fn delete(&mut self, id: u64) -> bool {
        if let Some(idx) = self.ids.iter().position(|&i| i == id) {
            let dim = self.config.dimensions as usize;
            let vec_start = idx * dim;
            self.vectors.drain(vec_start..vec_start + dim);
            self.ids.remove(idx);
            self.metadata.remove(idx);
            self.count -= 1;
            true
        } else {
            false
        }
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.count
    }

    /// `true` when the store holds no entries.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// All stored vectors concatenated in storage order (`len() * dimensions()` floats).
    pub fn flat_vectors(&self) -> &[f32] {
        &self.vectors
    }

    /// Ids of the stored entries, in the same order as `flat_vectors`.
    pub fn ids(&self) -> &[u64] {
        &self.ids
    }

    /// Configured length of every stored vector.
    pub fn dimensions(&self) -> u32 {
        self.config.dimensions
    }
}
/// Cosine similarity of `a` and `b`: their dot product divided by the product of their
/// Euclidean norms.
///
/// Returns `0.0` when either input has zero norm instead of dividing by zero. Components
/// are paired with `zip`, so extra components in the longer slice are ignored for the dot
/// product.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        0.0
    } else {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        dot / (mag_a * mag_b)
    }
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired with `zip`, so any extra components in the longer slice are
/// ignored.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    squared_sum.sqrt()
}
/// Inner product of `a` and `b`, pairing components with `zip` (the longer slice's extra
/// components are ignored).
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| lhs * rhs)
        .sum::<f32>()
}
/// Errors reported by `VectorStore` operations.
#[derive(Clone, Debug)]
pub enum VectorStoreError {
    /// An input vector had `got` components where the store expects `expected`.
    DimensionMismatch { expected: usize, got: usize },
    /// An insert was attempted while the store already held `capacity` entries.
    CapacityExceeded { capacity: usize },
    /// Carries the id of a vector that could not be found.
    NotFound(u64),
}

impl std::fmt::Display for VectorStoreError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            VectorStoreError::DimensionMismatch { expected, got } => write!(
                f,
                "Dimension mismatch: expected {}, got {}",
                expected, got
            ),
            VectorStoreError::CapacityExceeded { capacity } => {
                write!(f, "Capacity exceeded: max {}", capacity)
            }
            VectorStoreError::NotFound(id) => write!(f, "Vector {} not found", id),
        }
    }
}

impl std::error::Error for VectorStoreError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store with an explicit, small capacity so tests never reserve the default
    /// configuration's large buffers.
    fn new_store(
        dimensions: u32,
        metric: DistanceMetric,
        capacity: usize,
        store_metadata: bool,
    ) -> VectorStore {
        VectorStore::new(VectorStoreConfig {
            dimensions,
            metric,
            capacity,
            store_metadata,
        })
    }

    /// Builds an entry with empty metadata.
    fn entry(id: u64, vector: Vec<f32>) -> VectorEntry {
        VectorEntry {
            id,
            vector,
            metadata: HashMap::new(),
        }
    }

    /// Ids of `results` in rank order.
    fn ids_of(results: &[SearchResult]) -> Vec<u64> {
        results.iter().map(|r| r.id).collect()
    }

    #[test]
    fn test_insert_and_search() {
        let mut store = new_store(4, DistanceMetric::Cosine, 100, false);
        store.insert(entry(1, vec![1.0, 0.0, 0.0, 0.0])).unwrap();
        store.insert(entry(2, vec![0.9, 0.1, 0.0, 0.0])).unwrap();
        store.insert(entry(3, vec![0.0, 1.0, 0.0, 0.0])).unwrap();
        let results = store.search(&[1.0, 0.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(ids_of(&results), vec![1, 2]);
    }

    #[test]
    fn test_l2_distance_search() {
        let mut store = new_store(3, DistanceMetric::L2, 100, false);
        store.insert(entry(1, vec![0.0, 0.0, 0.0])).unwrap();
        store.insert(entry(2, vec![1.0, 0.0, 0.0])).unwrap();
        store.insert(entry(3, vec![10.0, 10.0, 10.0])).unwrap();
        let results = store.search(&[0.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(ids_of(&results), vec![1, 2]);
        // L2 scores are negated distances, so nearer vectors score higher.
        assert_eq!(results[0].score, 0.0);
        assert_eq!(results[1].score, -1.0);
    }

    #[test]
    fn test_dot_product_search_prefers_magnitude() {
        let mut store = new_store(2, DistanceMetric::DotProduct, 10, false);
        store.insert(entry(1, vec![1.0, 0.0])).unwrap();
        store.insert(entry(2, vec![3.0, 0.0])).unwrap();
        let results = store.search(&[1.0, 0.0], 2).unwrap();
        assert_eq!(ids_of(&results), vec![2, 1]);
        assert_eq!(results[0].score, 3.0);
    }

    #[test]
    fn test_search_k_bounds() {
        let mut store = new_store(2, DistanceMetric::Cosine, 10, false);
        store.insert(entry(1, vec![1.0, 0.0])).unwrap();
        store.insert(entry(2, vec![0.0, 1.0])).unwrap();
        assert_eq!(store.search(&[1.0, 0.0], 10).unwrap().len(), 2);
        assert!(store.search(&[1.0, 0.0], 0).unwrap().is_empty());
    }

    #[test]
    fn test_search_on_empty_store() {
        let store = new_store(2, DistanceMetric::Cosine, 10, false);
        assert!(store.is_empty());
        assert!(store.search(&[1.0, 0.0], 5).unwrap().is_empty());
    }

    #[test]
    fn test_dimension_mismatch() {
        let mut store = new_store(4, DistanceMetric::Cosine, 10, true);
        let result = store.insert(entry(1, vec![1.0, 2.0]));
        assert!(matches!(
            result,
            Err(VectorStoreError::DimensionMismatch {
                expected: 4,
                got: 2
            })
        ));
        assert!(store.is_empty());
        let result = store.search(&[1.0], 1);
        assert!(matches!(
            result,
            Err(VectorStoreError::DimensionMismatch {
                expected: 4,
                got: 1
            })
        ));
    }

    #[test]
    fn test_capacity_exceeded() {
        let mut store = new_store(2, DistanceMetric::Cosine, 2, true);
        store.insert(entry(1, vec![1.0, 0.0])).unwrap();
        store.insert(entry(2, vec![0.0, 1.0])).unwrap();
        let result = store.insert(entry(3, vec![1.0, 1.0]));
        assert!(matches!(
            result,
            Err(VectorStoreError::CapacityExceeded { capacity: 2 })
        ));
        assert_eq!(store.len(), 2);
    }

    #[test]
    fn test_delete() {
        let mut store = new_store(2, DistanceMetric::Cosine, 10, true);
        store.insert(entry(1, vec![1.0, 0.0])).unwrap();
        store.insert(entry(2, vec![0.0, 1.0])).unwrap();
        assert_eq!(store.len(), 2);
        assert!(store.delete(1));
        assert_eq!(store.len(), 1);
        // A second delete of the same id finds nothing.
        assert!(!store.delete(1));
    }

    #[test]
    fn test_delete_keeps_remaining_entries_aligned() {
        let mut store = new_store(2, DistanceMetric::Cosine, 10, false);
        store.insert(entry(1, vec![1.0, 0.0])).unwrap();
        store.insert(entry(2, vec![0.0, 1.0])).unwrap();
        store.insert(entry(3, vec![1.0, 1.0])).unwrap();
        assert!(store.delete(2));
        assert_eq!(store.ids(), &[1, 3]);
        assert_eq!(store.flat_vectors(), &[1.0, 0.0, 1.0, 1.0]);
        let results = store.search(&[0.0, 1.0], 1).unwrap();
        assert_eq!(results[0].id, 3);
    }

    #[test]
    fn test_metadata_storage() {
        let mut store = new_store(2, DistanceMetric::Cosine, 10, true);
        let mut meta = HashMap::new();
        meta.insert("node_type".to_string(), "isa_standard".to_string());
        store
            .insert(VectorEntry {
                id: 42,
                vector: vec![1.0, 0.0],
                metadata: meta,
            })
            .unwrap();
        let results = store.search(&[1.0, 0.0], 1).unwrap();
        assert_eq!(results[0].id, 42);
        let meta = results[0].metadata.as_ref().unwrap();
        assert_eq!(meta.get("node_type").unwrap(), "isa_standard");
    }

    #[test]
    fn test_metadata_dropped_when_disabled() {
        let mut store = new_store(2, DistanceMetric::Cosine, 10, false);
        let mut meta = HashMap::new();
        meta.insert("k".to_string(), "v".to_string());
        store
            .insert(VectorEntry {
                id: 1,
                vector: vec![1.0, 0.0],
                metadata: meta,
            })
            .unwrap();
        let results = store.search(&[1.0, 0.0], 1).unwrap();
        assert!(results[0].metadata.is_none());
    }

    #[test]
    fn test_flat_vectors_for_gpu() {
        let mut store = new_store(3, DistanceMetric::Cosine, 10, true);
        store.insert(entry(1, vec![1.0, 2.0, 3.0])).unwrap();
        store.insert(entry(2, vec![4.0, 5.0, 6.0])).unwrap();
        let flat = store.flat_vectors();
        assert_eq!(flat.len(), 6);
        assert_eq!(flat, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert_eq!(store.dimensions(), 3);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = [1.0_f32, 2.0, 3.0];
        assert!((cosine_similarity(&a, &a) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = [1.0_f32, 0.0];
        let b = [0.0_f32, 1.0];
        assert!(cosine_similarity(&a, &b).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 1.0]), 0.0);
    }

    #[test]
    fn test_l2_distance_basic() {
        assert_eq!(l2_distance(&[0.0, 0.0], &[3.0, 4.0]), 5.0);
    }

    #[test]
    fn test_dot_product_basic() {
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        assert!((dot_product(&a, &b) - 32.0).abs() < 1e-6);
    }
}