use std::path::Path;
use std::sync::Arc;
use crate::error::MemoryError;
/// One result row returned by [`VectorBackend::search`].
#[derive(Debug, Clone, PartialEq)]
pub struct VectorHit {
    /// Composite key of the form `<domain>:<id>`; split by [`VectorHit::parse_key`].
    pub key: String,
    /// Backend-reported distance to the query; [`VectorHit::similarity`] maps it
    /// to `1 - distance`, floored at zero.
    pub distance: f32,
}
impl VectorHit {
    /// Converts the backend distance into a similarity score.
    ///
    /// Computed as `1.0 - distance` and floored at `0.0`, so any distance of
    /// `1.0` or more maps to zero. A NaN distance also yields `0.0`, because
    /// [`f32::max`] returns the non-NaN operand. There is no upper clamp: a
    /// negative distance produces a score above `1.0`.
    ///
    /// NOTE(review): the `1 - d` mapping presumes a cosine-style distance where
    /// `0.0` means identical — confirm against the active backend's metric.
    pub fn similarity(&self) -> f32 {
        let raw = 1.0 - self.distance;
        f32::max(raw, 0.0)
    }

    /// Splits [`key`](Self::key) at its first `':'` into `(domain, id)`.
    ///
    /// Only the first colon separates the parts, so `"chunk:a:b"` yields
    /// `("chunk", "a:b")`. Either half may be empty (e.g. `"fact:"`).
    ///
    /// # Errors
    ///
    /// Returns `MemoryError::InvalidKey` holding a copy of the whole key when
    /// it contains no `':'`.
    pub fn parse_key(&self) -> Result<(&str, &str), MemoryError> {
        match self.key.split_once(':') {
            Some(parts) => Ok(parts),
            None => Err(MemoryError::InvalidKey(self.key.clone())),
        }
    }
}
/// Construction and tuning parameters handed to the active vector backend.
///
/// The same value is passed to both [`VectorIndex::new`] and
/// [`VectorIndex::load`]. The field notes below follow HNSW terminology; the
/// authoritative interpretation of each knob lives in the individual backend.
#[derive(Debug, Clone, PartialEq)]
pub struct VectorIndexConfig {
    /// Graph connectivity parameter (HNSW `M`; presumably max links per node —
    /// confirm per backend).
    pub m: usize,
    /// Candidate-list size used while building the index (HNSW
    /// `efConstruction`; confirm per backend).
    pub ef_construction: usize,
    /// Candidate-list size used at query time (HNSW `efSearch`; confirm per
    /// backend).
    pub ef_search: usize,
    /// Expected length of every stored and query vector (presumably enforced
    /// by the backend — confirm).
    pub dimensions: usize,
    /// Capacity hint for the index. NOTE(review): whether this is a hard cap
    /// depends on the backend.
    pub max_elements: usize,
    /// Threshold that presumably triggers compaction (e.g. fraction of deleted
    /// entries) — TODO confirm against the backend that reads it.
    pub compaction_threshold: f32,
    /// Optional flush interval in seconds; `None` presumably disables periodic
    /// flushing — confirm.
    pub flush_interval_secs: Option<u64>,
}
impl Default for VectorIndexConfig {
fn default() -> Self {
Self {
m: 16,
ef_construction: 200,
ef_search: 50,
dimensions: 768,
max_elements: 100_000,
compaction_threshold: 0.3,
flush_interval_secs: None,
}
}
}
/// Storage and nearest-neighbour engine behind [`VectorIndex`].
///
/// Every method takes `&self`, including the mutating ones, so implementations
/// must provide their own interior mutability. The `Send + Sync` supertraits
/// allow a backend to be shared across threads behind an `Arc`.
pub trait VectorBackend: Send + Sync {
    // ----- mutation -------------------------------------------------------

    /// Stores `vector` under `key`. Behavior for an already-present key is
    /// backend-defined.
    fn insert(&self, key: String, vector: &[f32]) -> Result<(), MemoryError>;

    /// Replaces the vector stored under `key`. Behavior for a missing key is
    /// backend-defined.
    fn update(&self, key: String, vector: &[f32]) -> Result<(), MemoryError>;

    /// Removes the entry stored under `key`.
    fn delete(&self, key: &str) -> Result<(), MemoryError>;

    // ----- queries --------------------------------------------------------

    /// Returns at most `top_k` hits for `query` (presumably nearest first —
    /// confirm per backend).
    fn search(&self, query: &[f32], top_k: usize) -> Result<Vec<VectorHit>, MemoryError>;

    /// Number of entries currently held by the backend.
    fn len(&self) -> usize;

    /// `true` when [`len`](Self::len) reports zero. Backends may override this
    /// with a cheaper check.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    // ----- persistence & metadata ----------------------------------------

    /// Persists the index under `dir`; `basename` presumably prefixes the
    /// written file names (layout is backend-specific).
    fn save(&self, dir: &Path, basename: &str) -> Result<(), MemoryError>;

    /// Static identifier of the implementation, useful for diagnostics.
    fn backend_name(&self) -> &'static str;
}
/// Cheaply cloneable handle to the vector backend selected at compile time.
///
/// Cloning copies only the inner [`Arc`], so every clone shares one backend
/// instance and observes the same contents.
#[derive(Clone)]
pub struct VectorIndex {
    /// Type-erased backend shared by all clones of this handle.
    inner: Arc<dyn VectorBackend>,
}
impl VectorIndex {
    // ----- construction ---------------------------------------------------

    /// Creates an empty index on whichever backend feature is compiled in.
    ///
    /// # Errors
    ///
    /// Returns the backend constructor's error, or
    /// `MemoryError::NotImplemented` when no vector-backend feature is enabled.
    pub fn new(config: VectorIndexConfig) -> Result<Self, MemoryError> {
        build_active_backend(config).map(|inner| Self { inner })
    }

    /// Reopens an index persisted under `dir` / `basename` (presumably written
    /// earlier by [`save`](Self::save) with the same arguments — confirm).
    ///
    /// # Errors
    ///
    /// Returns the backend loader's error, or `MemoryError::NotImplemented`
    /// when no vector-backend feature is enabled.
    pub fn load(
        dir: &Path,
        basename: &str,
        config: VectorIndexConfig,
    ) -> Result<Self, MemoryError> {
        load_active_backend(dir, basename, config).map(|inner| Self { inner })
    }

    // ----- mutation (forwarded to the backend) ----------------------------

    /// Stores `vector` under `key` in the shared backend.
    pub fn insert(&self, key: String, vector: &[f32]) -> Result<(), MemoryError> {
        self.inner.insert(key, vector)
    }

    /// Replaces the vector stored under `key` in the shared backend.
    pub fn update(&self, key: String, vector: &[f32]) -> Result<(), MemoryError> {
        self.inner.update(key, vector)
    }

    /// Removes `key` from the shared backend.
    pub fn delete(&self, key: &str) -> Result<(), MemoryError> {
        self.inner.delete(key)
    }

    // ----- queries --------------------------------------------------------

    /// Asks the backend for at most `top_k` hits for `query`.
    pub fn search(&self, query: &[f32], top_k: usize) -> Result<Vec<VectorHit>, MemoryError> {
        self.inner.search(query, top_k)
    }

    /// Number of entries reported by the backend.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Whether the backend reports itself empty (uses the backend's own
    /// `is_empty`, which may be overridden).
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }

    // ----- persistence & metadata ----------------------------------------

    /// Persists the backend's state under `dir` using `basename`.
    pub fn save(&self, dir: &Path, basename: &str) -> Result<(), MemoryError> {
        self.inner.save(dir, basename)
    }

    /// Static name of the compiled-in backend.
    pub fn backend_name(&self) -> &'static str {
        self.inner.backend_name()
    }

    /// Does nothing; kept only because it is part of the public API.
    ///
    /// NOTE(review): looks like a leftover stub — confirm it has no callers
    /// before removing it.
    pub fn _placeholder(&self) {}
}
/// Constructs a fresh backend chosen at compile time by Cargo feature.
///
/// Selection order: `hnsw` wins whenever it is enabled (even if
/// `usearch-backend` is enabled too); otherwise `usearch-backend`; otherwise
/// no backend is available. The `cfg` predicates are mutually exclusive, so
/// exactly one branch is compiled and becomes the function's tail expression.
/// This means no build configuration contains unreachable code, and builds
/// with both features enabled stay warning-free.
///
/// # Errors
///
/// Propagates any error from the backend constructor (converted via `?`), or
/// returns `MemoryError::NotImplemented` when neither feature is enabled.
fn build_active_backend(
    config: VectorIndexConfig,
) -> Result<Arc<dyn VectorBackend>, MemoryError> {
    #[cfg(feature = "hnsw")]
    {
        Ok(Arc::new(super::hnsw_backend::HnswBackend::new(config)?))
    }
    #[cfg(all(feature = "usearch-backend", not(feature = "hnsw")))]
    {
        Ok(Arc::new(super::usearch_backend::UsearchBackend::new(config)?))
    }
    #[cfg(not(any(feature = "hnsw", feature = "usearch-backend")))]
    {
        // Mark `config` as used in builds without any backend feature.
        let _ = config;
        Err(MemoryError::NotImplemented(
            "no vector backend feature enabled (need `hnsw` or `usearch-backend`)".to_string(),
        ))
    }
}
/// Reopens a persisted backend chosen at compile time by Cargo feature.
///
/// Uses the same selection order as `build_active_backend`: `hnsw` wins
/// whenever it is enabled, otherwise `usearch-backend`, otherwise no backend
/// is available. The `cfg` predicates are mutually exclusive, so exactly one
/// branch is compiled as the tail expression and no configuration contains
/// unreachable code.
///
/// # Errors
///
/// Propagates any error from the backend loader (converted via `?`), or
/// returns `MemoryError::NotImplemented` when neither feature is enabled.
fn load_active_backend(
    dir: &Path,
    basename: &str,
    config: VectorIndexConfig,
) -> Result<Arc<dyn VectorBackend>, MemoryError> {
    #[cfg(feature = "hnsw")]
    {
        Ok(Arc::new(super::hnsw_backend::HnswBackend::load(
            dir, basename, config,
        )?))
    }
    #[cfg(all(feature = "usearch-backend", not(feature = "hnsw")))]
    {
        Ok(Arc::new(super::usearch_backend::UsearchBackend::load(
            dir, basename, config,
        )?))
    }
    #[cfg(not(any(feature = "hnsw", feature = "usearch-backend")))]
    {
        // Mark the inputs as used in builds without any backend feature.
        let _ = (dir, basename, config);
        Err(MemoryError::NotImplemented(
            "no vector backend feature enabled (need `hnsw` or `usearch-backend`)".to_string(),
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn vector_hit_similarity_below_zero_clamps_to_zero() {
        let h = VectorHit { key: "fact:1".to_string(), distance: 2.0 };
        assert_eq!(h.similarity(), 0.0);
    }

    #[test]
    fn vector_hit_similarity_normal() {
        let h = VectorHit { key: "fact:1".to_string(), distance: 0.3 };
        assert!((h.similarity() - 0.7).abs() < 1e-6);
    }

    #[test]
    fn vector_hit_similarity_zero_distance_is_one() {
        let h = VectorHit { key: "fact:1".to_string(), distance: 0.0 };
        assert_eq!(h.similarity(), 1.0);
    }

    #[test]
    fn vector_hit_parse_key_valid() {
        let h = VectorHit { key: "chunk:abc-123".to_string(), distance: 0.0 };
        let (domain, id) = h.parse_key().unwrap();
        assert_eq!(domain, "chunk");
        assert_eq!(id, "abc-123");
    }

    #[test]
    fn vector_hit_parse_key_splits_on_first_colon_only() {
        let h = VectorHit { key: "chunk:abc:def".to_string(), distance: 0.0 };
        assert_eq!(h.parse_key().unwrap(), ("chunk", "abc:def"));
    }

    #[test]
    fn vector_hit_parse_key_invalid() {
        let h = VectorHit { key: "no_colon".to_string(), distance: 0.0 };
        // The error must carry the offending key so callers can report it.
        assert!(matches!(
            h.parse_key(),
            Err(MemoryError::InvalidKey(ref k)) if k == "no_colon"
        ));
    }

    #[test]
    fn config_default_matches_hnsw_default() {
        let c = VectorIndexConfig::default();
        assert_eq!(c.m, 16);
        assert_eq!(c.ef_construction, 200);
        assert_eq!(c.ef_search, 50);
        assert_eq!(c.dimensions, 768);
        assert_eq!(c.max_elements, 100_000);
        assert!((c.compaction_threshold - 0.3).abs() < f32::EPSILON);
        assert_eq!(c.flush_interval_secs, None);
    }

    // Only meaningful when the crate is built without any backend feature.
    #[cfg(not(any(feature = "hnsw", feature = "usearch-backend")))]
    #[test]
    fn vector_index_without_backend_feature_is_not_implemented() {
        assert!(matches!(
            VectorIndex::new(VectorIndexConfig::default()),
            Err(MemoryError::NotImplemented(_))
        ));
        assert!(matches!(
            VectorIndex::load(
                std::path::Path::new("."),
                "index",
                VectorIndexConfig::default()
            ),
            Err(MemoryError::NotImplemented(_))
        ));
    }
}