#![warn(missing_docs)]
use std::collections::HashMap;
use std::fs;
use std::io;
use std::path::Path;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::RwLock;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use tensor_store::{
fields, hnsw::simd, SparseVector, TensorData, TensorStore, TensorStoreError, TensorValue,
};
use tracing::instrument;
pub use tensor_store::{HNSWConfig, HNSWIndex, ScalarQuantizedVector};
pub use tensor_store::WalConfig;
pub use tensor_store::{DistanceMetric as ExtendedDistanceMetric, GeometricConfig};
pub use tensor_store::{
BinaryThreshold, BinaryVector, IVFConfig, IVFIndex, IVFIndexState, IVFStorage, PQCodebook,
PQConfig, PQVector,
};
/// One cached HNSW index together with the key list used to translate the
/// positional ids returned by the index back into storage keys.
type HnswCacheEntry = (Arc<HNSWIndex>, Vec<String>);
/// Errors reported by the vector engine.
///
/// Every payload is plain data (`String`, `usize`, `u64`), which is what lets
/// the enum derive `Clone`, `Eq` and the serde traits.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum VectorError {
/// No embedding is stored under the given key.
NotFound(String),
/// A vector's length did not match (or exceeded) the expected dimension.
DimensionMismatch {
/// Dimension that was required, or the configured maximum.
expected: usize,
/// Dimension of the vector that was supplied.
got: usize,
},
/// An empty vector was supplied.
EmptyVector,
/// `top_k` was zero.
InvalidTopK,
/// The underlying tensor store reported a failure (rendered message).
StorageError(String),
/// Validation of a batch entry failed.
BatchValidationError {
/// Position of the offending entry in the batch.
index: usize,
/// Description of the validation failure.
cause: String,
},
/// Executing a batch entry failed.
BatchOperationError {
/// Position of the offending entry in the batch.
index: usize,
/// Description of the operation failure.
cause: String,
},
/// An engine configuration value is invalid.
ConfigurationError(String),
/// A collection with this name is already registered.
CollectionExists(String),
/// No collection with this name is registered.
CollectionNotFound(String),
/// An I/O failure (rendered message of the `io::Error`).
IoError(String),
/// Encoding or decoding failed.
SerializationError(String),
/// A search ran past its configured time budget.
SearchTimeout {
/// Name of the operation that timed out.
operation: String,
/// The configured budget in milliseconds.
timeout_ms: u64,
},
}
impl std::fmt::Display for VectorError {
    /// Writes the single-line, human-readable message for the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Payload-free variants are fixed text.
            Self::EmptyVector => f.write_str("Empty vector provided"),
            Self::InvalidTopK => f.write_str("Invalid top_k value (must be > 0)"),

            // Variants carrying a single string.
            Self::NotFound(key) => write!(f, "Embedding not found: {key}"),
            Self::StorageError(e) => write!(f, "Storage error: {e}"),
            Self::ConfigurationError(msg) => write!(f, "Configuration error: {msg}"),
            Self::CollectionExists(name) => write!(f, "Collection already exists: {name}"),
            Self::CollectionNotFound(name) => write!(f, "Collection not found: {name}"),
            Self::IoError(msg) => write!(f, "IO error: {msg}"),
            Self::SerializationError(msg) => write!(f, "Serialization error: {msg}"),

            // Struct-like variants.
            Self::DimensionMismatch { expected, got } => {
                write!(f, "Dimension mismatch: expected {expected}, got {got}")
            }
            Self::BatchValidationError { index, cause } => {
                write!(f, "Batch validation error at index {index}: {cause}")
            }
            Self::BatchOperationError { index, cause } => {
                write!(f, "Batch operation error at index {index}: {cause}")
            }
            Self::SearchTimeout {
                operation,
                timeout_ms,
            } => write!(f, "search timeout: {operation} exceeded {timeout_ms}ms"),
        }
    }
}
// Marker impl: all provided `Error` methods are used as-is, so no source
// error is exposed (the conversions in this file keep only rendered messages).
impl std::error::Error for VectorError {}
impl From<TensorStoreError> for VectorError {
    /// Wraps a tensor-store failure as [`VectorError::StorageError`], keeping
    /// only its rendered message.
    fn from(source: TensorStoreError) -> Self {
        let message = source.to_string();
        Self::StorageError(message)
    }
}
impl From<io::Error> for VectorError {
    /// Wraps an I/O failure as [`VectorError::IoError`], keeping only its
    /// rendered message.
    fn from(source: io::Error) -> Self {
        let message = source.to_string();
        Self::IoError(message)
    }
}
impl From<DistanceMetric> for ExtendedDistanceMetric {
    /// Translates the engine's metric into the tensor-store metric.
    fn from(metric: DistanceMetric) -> Self {
        match metric {
            DistanceMetric::Euclidean => Self::Euclidean,
            // Dot product is mapped onto cosine here.
            // NOTE(review): confirm whether the extended metric has a
            // dedicated dot-product variant that should be used instead.
            DistanceMetric::Cosine | DistanceMetric::DotProduct => Self::Cosine,
        }
    }
}
/// Result alias whose error type is fixed to [`VectorError`].
pub type Result<T> = std::result::Result<T, VectorError>;
/// Time budget for a single search call.
///
/// `deadline` is the instant after which the call counts as expired (`None`
/// means "never"); `timeout_ms` is the configured budget, kept only so it can
/// be reported in timeout errors.
#[derive(Debug, Clone, Copy)]
struct Deadline {
    deadline: Option<Instant>,
    timeout_ms: u64,
}

impl Deadline {
    /// Starts the clock now for an optional timeout.
    ///
    /// `None` yields a deadline that never expires. A timeout so large that
    /// `now + timeout` is not representable is treated the same way instead
    /// of panicking, and the reported millisecond value saturates at
    /// `u64::MAX` instead of being truncated.
    fn from_duration(timeout: Option<Duration>) -> Self {
        Self {
            // `Instant + Duration` panics on overflow; `checked_add` does not.
            deadline: timeout.and_then(|d| Instant::now().checked_add(d)),
            // `as_millis` is a `u128`; saturate rather than wrap.
            timeout_ms: timeout.map_or(0, |d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX)),
        }
    }

    /// A deadline that never expires (test helper).
    #[cfg(test)]
    const fn never() -> Self {
        Self {
            deadline: None,
            timeout_ms: 0,
        }
    }

    /// Returns `true` once the current time has reached the deadline.
    #[inline]
    fn is_expired(&self) -> bool {
        self.deadline.is_some_and(|d| Instant::now() >= d)
    }

    /// The configured budget in milliseconds (0 when no timeout was set).
    const fn timeout_ms(&self) -> u64 {
        self.timeout_ms
    }
}
/// A single search hit: the embedding key and the score assigned to it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SearchResult {
/// Key of the matched embedding (storage prefix already stripped).
pub key: String,
/// Score of the match; search results are sorted by descending score.
pub score: f32,
}
impl SearchResult {
#[allow(missing_docs)]
#[must_use]
pub const fn new(key: String, score: f32) -> Self {
Self { key, score }
}
}
/// Metric used to score a stored vector against a query.
///
/// `Cosine` is the default. Variant order is part of the derived encodings,
/// so it must not be changed casually.
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
Default,
Serialize,
Deserialize,
bitcode::Encode,
bitcode::Decode,
)]
pub enum DistanceMetric {
/// Cosine similarity (default).
#[default]
Cosine,
/// Euclidean distance.
Euclidean,
/// Dot product.
DotProduct,
}
/// Predicate applied to the metadata of a stored embedding during filtered
/// search.
///
/// The first `String` of a field-based variant names the metadata field.
/// The evaluator is not part of this excerpt, so the per-variant notes below
/// follow the variant names — NOTE(review): confirm against `evaluate_filter`.
#[derive(Debug, Clone, PartialEq)]
pub enum FilterCondition {
/// Field equals the value.
Eq(String, FilterValue),
/// Field differs from the value.
Ne(String, FilterValue),
/// Field is less than the value.
Lt(String, FilterValue),
/// Field is less than or equal to the value.
Le(String, FilterValue),
/// Field is greater than the value.
Gt(String, FilterValue),
/// Field is greater than or equal to the value.
Ge(String, FilterValue),
/// Conjunction of two conditions.
And(Box<Self>, Box<Self>),
/// Disjunction of two conditions.
Or(Box<Self>, Box<Self>),
/// Condition without any field; presumably always matches.
True,
/// The named field is present.
Exists(String),
/// Presumably: the field's string value contains the given substring.
Contains(String, String),
/// Presumably: the field's string value starts with the given prefix.
StartsWith(String, String),
/// Presumably: the field's value is one of the listed values.
In(String, Vec<FilterValue>),
}
impl FilterCondition {
    /// Combines `self` and `other` into an [`FilterCondition::And`] node
    /// (`self` on the left).
    #[must_use]
    pub fn and(self, other: Self) -> Self {
        let (lhs, rhs) = (Box::new(self), Box::new(other));
        Self::And(lhs, rhs)
    }

    /// Combines `self` and `other` into an [`FilterCondition::Or`] node
    /// (`self` on the left).
    #[must_use]
    pub fn or(self, other: Self) -> Self {
        let (lhs, rhs) = (Box::new(self), Box::new(other));
        Self::Or(lhs, rhs)
    }
}
/// Literal value a [`FilterCondition`] compares a metadata field against.
#[derive(Debug, Clone, PartialEq)]
pub enum FilterValue {
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit float.
    Float(f64),
    /// Owned string.
    String(String),
    /// Boolean.
    Bool(bool),
    /// Explicit null.
    Null,
}

impl From<i64> for FilterValue {
    /// Wraps an integer as [`FilterValue::Int`].
    fn from(number: i64) -> Self {
        FilterValue::Int(number)
    }
}

impl From<f64> for FilterValue {
    /// Wraps a float as [`FilterValue::Float`].
    fn from(number: f64) -> Self {
        FilterValue::Float(number)
    }
}

impl From<String> for FilterValue {
    /// Wraps an owned string as [`FilterValue::String`].
    fn from(text: String) -> Self {
        FilterValue::String(text)
    }
}

impl From<&str> for FilterValue {
    /// Copies a string slice into a [`FilterValue::String`].
    fn from(text: &str) -> Self {
        FilterValue::String(text.to_owned())
    }
}

impl From<bool> for FilterValue {
    /// Wraps a boolean as [`FilterValue::Bool`].
    fn from(flag: bool) -> Self {
        FilterValue::Bool(flag)
    }
}
/// How a filtered search combines the metadata filter with vector scoring.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum FilterStrategy {
    /// Let the engine choose between pre- and post-filtering (default).
    #[default]
    Auto,
    /// Apply the filter first, then score the entries that passed.
    PreFilter,
    /// Score first, then apply the filter to the candidates.
    PostFilter,
}

/// Tuning knobs for filtered search.
#[derive(Debug, Clone)]
pub struct FilteredSearchConfig {
    /// Strategy to use.
    pub strategy: FilterStrategy,
    /// With `Auto`, a sampled match ratio below this value selects
    /// pre-filtering.
    pub selectivity_threshold: f32,
    /// With post-filtering, the candidate count is `top_k` times this factor.
    pub oversample_factor: usize,
}

impl Default for FilteredSearchConfig {
    /// `Auto` strategy, selectivity threshold 0.1, oversample factor 3.
    fn default() -> Self {
        Self::from_strategy(FilterStrategy::Auto)
    }
}

impl FilteredSearchConfig {
    /// Shared constructor: the standard thresholds with the given strategy.
    const fn from_strategy(strategy: FilterStrategy) -> Self {
        Self {
            strategy,
            selectivity_threshold: 0.1,
            oversample_factor: 3,
        }
    }

    /// Default thresholds with the strategy forced to `PreFilter`.
    #[must_use]
    pub const fn pre_filter() -> Self {
        Self::from_strategy(FilterStrategy::PreFilter)
    }

    /// Default thresholds with the strategy forced to `PostFilter`.
    #[must_use]
    pub const fn post_filter() -> Self {
        Self::from_strategy(FilterStrategy::PostFilter)
    }

    /// Returns the config with `oversample_factor` replaced by `factor`.
    #[must_use]
    pub const fn with_oversample(self, factor: usize) -> Self {
        Self {
            oversample_factor: factor,
            ..self
        }
    }
}
/// Per-collection settings.
#[derive(Debug, Clone, Serialize, Deserialize, bitcode::Encode, bitcode::Decode)]
pub struct VectorCollectionConfig {
/// When set, stored vectors and queries must have exactly this length.
pub dimension: Option<usize>,
/// Metric used when scoring vectors of this collection.
pub distance_metric: DistanceMetric,
/// Whether automatic indexing is enabled (set by `with_auto_index`).
pub auto_index: bool,
/// Threshold associated with automatic indexing.
/// NOTE(review): the consumer is not in this excerpt; presumably an entry count.
pub auto_index_threshold: usize,
}
impl Default for VectorCollectionConfig {
    /// No fixed dimension, cosine metric, auto-indexing off (threshold 1000).
    fn default() -> Self {
        let auto_index_threshold = 1000;
        Self {
            auto_index: false,
            auto_index_threshold,
            distance_metric: DistanceMetric::Cosine,
            dimension: None,
        }
    }
}
impl VectorCollectionConfig {
    /// Returns the config with the required dimension set to `dim`.
    #[must_use]
    pub const fn with_dimension(self, dim: usize) -> Self {
        Self {
            dimension: Some(dim),
            ..self
        }
    }

    /// Returns the config with the distance metric replaced.
    #[must_use]
    pub const fn with_metric(self, metric: DistanceMetric) -> Self {
        Self {
            distance_metric: metric,
            ..self
        }
    }

    /// Returns the config with auto-indexing switched on and its threshold
    /// set to `threshold`.
    #[must_use]
    pub const fn with_auto_index(self, threshold: usize) -> Self {
        Self {
            auto_index: true,
            auto_index_threshold: threshold,
            ..self
        }
    }
}
/// Serializable snapshot of one collection: its configuration plus all of
/// its vector entries.
#[derive(Debug, Clone, Serialize, Deserialize, bitcode::Encode, bitcode::Decode)]
pub struct PersistentVectorIndex {
/// Name of the collection.
pub collection: String,
/// Configuration of the collection.
pub config: VectorCollectionConfig,
/// The stored entries.
pub vectors: Vec<VectorEntry>,
/// Creation time in seconds since the Unix epoch (0 if the clock was
/// before the epoch when the snapshot was created).
pub created_at: u64,
/// Format version; `new` writes `CURRENT_VERSION`.
pub version: u32,
}
/// One entry of a [`PersistentVectorIndex`].
#[derive(Debug, Clone, Serialize, Deserialize, bitcode::Encode, bitcode::Decode)]
pub struct VectorEntry {
/// Key of the embedding.
pub key: String,
/// The embedding as a dense vector.
pub vector: Vec<f32>,
/// Optional metadata, keyed by field name.
pub metadata: Option<HashMap<String, MetadataValue>>,
}
/// Scalar metadata value in serializable form; mirrors the scalar variants
/// of `TensorValue` that `from_tensor_value` accepts.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, bitcode::Encode, bitcode::Decode)]
pub enum MetadataValue {
/// Explicit null.
Null,
/// Boolean.
Bool(bool),
/// 64-bit signed integer.
Int(i64),
/// 64-bit float.
Float(f64),
/// Owned string.
String(String),
}
impl MetadataValue {
    /// Converts a tensor value into metadata.
    ///
    /// Only null, bool, int, float and string scalars are representable;
    /// every other tensor value yields `None`.
    #[must_use]
    pub fn from_tensor_value(tv: &TensorValue) -> Option<Self> {
        use tensor_store::ScalarValue;

        // Anything that is not a scalar has no metadata form.
        let TensorValue::Scalar(scalar) = tv else {
            return None;
        };
        let converted = match scalar {
            ScalarValue::Null => Self::Null,
            ScalarValue::Bool(flag) => Self::Bool(*flag),
            ScalarValue::Int(number) => Self::Int(*number),
            ScalarValue::Float(number) => Self::Float(*number),
            ScalarValue::String(text) => Self::String(text.clone()),
            // Any other scalar kind is not representable either.
            #[allow(unreachable_patterns)]
            _ => return None,
        };
        Some(converted)
    }
}
impl From<MetadataValue> for TensorValue {
    /// Converts metadata back into the matching scalar tensor value.
    fn from(mv: MetadataValue) -> Self {
        use tensor_store::ScalarValue;

        let scalar = match mv {
            MetadataValue::Null => ScalarValue::Null,
            MetadataValue::Bool(flag) => ScalarValue::Bool(flag),
            MetadataValue::Int(number) => ScalarValue::Int(number),
            MetadataValue::Float(number) => ScalarValue::Float(number),
            MetadataValue::String(text) => ScalarValue::String(text),
        };
        Self::Scalar(scalar)
    }
}
impl PersistentVectorIndex {
    /// Format version written by [`PersistentVectorIndex::new`].
    pub const CURRENT_VERSION: u32 = 1;

    /// Creates an empty snapshot for `collection`, stamped with the current
    /// time (seconds since the Unix epoch, or 0 if the clock is before it).
    #[must_use]
    pub fn new(collection: String, config: VectorCollectionConfig) -> Self {
        let now = std::time::SystemTime::now();
        let created_at = match now.duration_since(std::time::UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch.as_secs(),
            Err(_) => 0,
        };
        Self {
            collection,
            config,
            vectors: Vec::new(),
            created_at,
            version: Self::CURRENT_VERSION,
        }
    }

    /// Appends one entry to the snapshot.
    pub fn push(
        &mut self,
        key: String,
        vector: Vec<f32>,
        metadata: Option<HashMap<String, MetadataValue>>,
    ) {
        let entry = VectorEntry {
            key,
            vector,
            metadata,
        };
        self.vectors.push(entry);
    }

    /// Number of entries in the snapshot.
    #[must_use]
    pub const fn len(&self) -> usize {
        let entries = &self.vectors;
        entries.len()
    }

    /// `true` when the snapshot holds no entries.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        let entries = &self.vectors;
        entries.is_empty()
    }
}
/// Engine-wide settings. `validate` enforces the documented ranges.
#[derive(Debug, Clone)]
pub struct VectorEngineConfig {
/// Default vector dimension, if any.
/// NOTE(review): the consumer is not in this excerpt.
pub default_dimension: Option<usize>,
/// A vector is stored in sparse form when its fraction of near-zero
/// components is at least this value. Must lie in `0.0..=1.0`.
pub sparse_threshold: f32,
/// Threshold related to parallel execution; must be greater than 0.
/// NOTE(review): the consumer is not in this excerpt.
pub parallel_threshold: usize,
/// Default distance metric.
pub default_metric: DistanceMetric,
/// Upper bound on the length of stored vectors and queries; longer ones
/// are rejected with `DimensionMismatch`. Must be greater than 0 when set.
pub max_dimension: Option<usize>,
/// Limit on keys per scan; must be greater than 0 when set.
/// NOTE(review): the consumer is not in this excerpt.
pub max_keys_per_scan: Option<usize>,
/// Threshold related to parallel batch handling; must be greater than 0.
/// NOTE(review): the consumer is not in this excerpt.
pub batch_parallel_threshold: usize,
/// Time budget for a single search call; `None` disables the timeout.
pub search_timeout: Option<Duration>,
/// Limit on index file size in bytes; must be greater than 0 when set.
/// NOTE(review): the consumer is not in this excerpt.
pub max_index_file_bytes: Option<usize>,
/// Limit on index entries; must be greater than 0 when set.
/// NOTE(review): the consumer is not in this excerpt.
pub max_index_entries: Option<usize>,
}
impl Default for VectorEngineConfig {
    /// Baseline settings: no dimension limits, no search timeout, cosine
    /// metric, index files capped at 100 MiB and one million entries.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        Self {
            // Dimensions and scans are unbounded by default.
            default_dimension: None,
            max_dimension: None,
            max_keys_per_scan: None,
            // Storage and execution thresholds.
            sparse_threshold: 0.5,
            parallel_threshold: 5000,
            batch_parallel_threshold: 100,
            default_metric: DistanceMetric::Cosine,
            search_timeout: None,
            // Persistence limits.
            max_index_file_bytes: Some(100 * MIB),
            max_index_entries: Some(1_000_000),
        }
    }
}
impl VectorEngineConfig {
    /// Preset that differs from the default only by a lower
    /// `parallel_threshold` (1000 instead of 5000).
    #[must_use]
    pub const fn high_throughput() -> Self {
        Self {
            default_dimension: None,
            sparse_threshold: 0.5,
            parallel_threshold: 1000,
            default_metric: DistanceMetric::Cosine,
            max_dimension: None,
            max_keys_per_scan: None,
            batch_parallel_threshold: 100,
            search_timeout: None,
            max_index_file_bytes: Some(100 * 1024 * 1024),
            max_index_entries: Some(1_000_000),
        }
    }

    /// Preset with tighter limits: sparse threshold 0.3, at most 4096
    /// dimensions and 10 000 keys per scan, a 30 s search timeout, and index
    /// files capped at 10 MiB / 100 000 entries.
    #[must_use]
    pub const fn low_memory() -> Self {
        Self {
            default_dimension: None,
            sparse_threshold: 0.3,
            parallel_threshold: 5000,
            default_metric: DistanceMetric::Cosine,
            max_dimension: Some(4096),
            max_keys_per_scan: Some(10_000),
            batch_parallel_threshold: 100,
            search_timeout: Some(Duration::from_secs(30)),
            max_index_file_bytes: Some(10 * 1024 * 1024),
            max_index_entries: Some(100_000),
        }
    }

    /// Checks every setting against its allowed range.
    ///
    /// # Errors
    ///
    /// Returns [`VectorError::ConfigurationError`] for the first violation
    /// found: `sparse_threshold` outside `0.0..=1.0` (NaN included), or any
    /// threshold/limit that is zero.
    pub fn validate(&self) -> Result<()> {
        /// Builds the configuration error for `message`.
        fn invalid(message: &str) -> Result<()> {
            Err(VectorError::ConfigurationError(message.to_string()))
        }

        // `contains` is false for NaN, so NaN is rejected too; a pair of
        // `<` / `>` comparisons would let it through.
        if !(0.0..=1.0).contains(&self.sparse_threshold) {
            return invalid("sparse_threshold must be between 0.0 and 1.0");
        }
        if self.parallel_threshold == 0 {
            return invalid("parallel_threshold must be greater than 0");
        }
        if matches!(self.max_dimension, Some(0)) {
            return invalid("max_dimension must be greater than 0");
        }
        if matches!(self.max_keys_per_scan, Some(0)) {
            return invalid("max_keys_per_scan must be greater than 0");
        }
        if self.batch_parallel_threshold == 0 {
            return invalid("batch_parallel_threshold must be greater than 0");
        }
        if matches!(self.max_index_file_bytes, Some(0)) {
            return invalid("max_index_file_bytes must be greater than 0");
        }
        if matches!(self.max_index_entries, Some(0)) {
            return invalid("max_index_entries must be greater than 0");
        }
        Ok(())
    }

    /// Returns the config with `default_dimension` set to `dim`.
    #[must_use]
    pub const fn with_default_dimension(mut self, dim: usize) -> Self {
        self.default_dimension = Some(dim);
        self
    }

    /// Returns the config with `sparse_threshold` replaced (not validated
    /// here; see [`VectorEngineConfig::validate`]).
    #[must_use]
    pub const fn with_sparse_threshold(mut self, threshold: f32) -> Self {
        self.sparse_threshold = threshold;
        self
    }

    /// Returns the config with `parallel_threshold` replaced.
    #[must_use]
    pub const fn with_parallel_threshold(mut self, threshold: usize) -> Self {
        self.parallel_threshold = threshold;
        self
    }

    /// Returns the config with `default_metric` replaced.
    #[must_use]
    pub const fn with_default_metric(mut self, metric: DistanceMetric) -> Self {
        self.default_metric = metric;
        self
    }

    /// Returns the config with `max_dimension` set to `max`.
    #[must_use]
    pub const fn with_max_dimension(mut self, max: usize) -> Self {
        self.max_dimension = Some(max);
        self
    }

    /// Returns the config with `max_keys_per_scan` set to `max`.
    #[must_use]
    pub const fn with_max_keys_per_scan(mut self, max: usize) -> Self {
        self.max_keys_per_scan = Some(max);
        self
    }

    /// Returns the config with `batch_parallel_threshold` replaced.
    #[must_use]
    pub const fn with_batch_parallel_threshold(mut self, threshold: usize) -> Self {
        self.batch_parallel_threshold = threshold;
        self
    }

    /// Returns the config with `search_timeout` set to `timeout`.
    #[must_use]
    pub const fn with_search_timeout(mut self, timeout: Duration) -> Self {
        self.search_timeout = Some(timeout);
        self
    }

    /// Returns the config with `max_index_file_bytes` set to `max_bytes`.
    #[must_use]
    pub const fn with_max_index_file_bytes(mut self, max_bytes: usize) -> Self {
        self.max_index_file_bytes = Some(max_bytes);
        self
    }

    /// Returns the config with `max_index_entries` set to `max_entries`.
    #[must_use]
    pub const fn with_max_index_entries(mut self, max_entries: usize) -> Self {
        self.max_index_entries = Some(max_entries);
        self
    }
}
/// How vectors are held inside an HNSW index built by the engine.
/// NOTE(review): the build code is not in this excerpt; the variant notes
/// follow the variant names.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum HNSWStorageStrategy {
/// Dense storage (default).
#[default]
Dense,
/// Let the builder choose per vector.
Auto,
/// Quantized storage.
Quantized,
}
/// Options for building an HNSW index.
#[derive(Debug, Clone)]
pub struct HNSWBuildOptions {
/// Storage strategy for the indexed vectors.
pub storage: HNSWStorageStrategy,
/// Configuration handed to the HNSW implementation.
pub hnsw_config: HNSWConfig,
}
impl Default for HNSWBuildOptions {
    /// Dense storage with the default HNSW configuration.
    fn default() -> Self {
        let hnsw_config = HNSWConfig::default();
        Self {
            storage: HNSWStorageStrategy::Dense,
            hnsw_config,
        }
    }
}
impl HNSWBuildOptions {
    /// Same as [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Quantized storage combined with the `high_speed` HNSW preset.
    #[must_use]
    pub fn memory_optimized() -> Self {
        let hnsw_config = HNSWConfig::high_speed();
        Self {
            storage: HNSWStorageStrategy::Quantized,
            hnsw_config,
        }
    }

    /// Dense storage combined with the `high_recall` HNSW preset.
    #[must_use]
    pub fn high_recall() -> Self {
        let hnsw_config = HNSWConfig::high_recall();
        Self {
            storage: HNSWStorageStrategy::Dense,
            hnsw_config,
        }
    }

    /// `Auto` storage combined with the default HNSW configuration.
    #[must_use]
    pub fn sparse_optimized() -> Self {
        let hnsw_config = HNSWConfig::default();
        Self {
            storage: HNSWStorageStrategy::Auto,
            hnsw_config,
        }
    }

    /// Returns the options with the storage strategy replaced.
    #[must_use]
    pub const fn with_storage(self, storage: HNSWStorageStrategy) -> Self {
        let mut options = self;
        options.storage = storage;
        options
    }

    /// Returns the options with the HNSW configuration replaced.
    #[must_use]
    pub const fn with_hnsw_config(self, config: HNSWConfig) -> Self {
        let mut options = self;
        options.hnsw_config = config;
        options
    }

    /// Returns the options with `hnsw_config.sparsity_threshold` replaced.
    #[must_use]
    pub const fn with_sparsity_threshold(self, threshold: f32) -> Self {
        let mut options = self;
        options.hnsw_config.sparsity_threshold = threshold;
        options
    }
}
/// Options for building an IVF index; a thin wrapper around [`IVFConfig`].
#[derive(Debug, Clone, Default)]
pub struct IVFBuildOptions {
/// Configuration handed to the IVF implementation.
pub config: IVFConfig,
}
impl IVFBuildOptions {
    /// Same as [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Options built from `IVFConfig::flat` with `num_clusters` clusters.
    #[must_use]
    pub fn flat(num_clusters: usize) -> Self {
        let config = IVFConfig::flat(num_clusters);
        Self { config }
    }

    /// Options built from `IVFConfig::pq` with `num_clusters` clusters and
    /// the given product-quantization settings.
    #[must_use]
    pub fn pq(num_clusters: usize, pq_config: PQConfig) -> Self {
        let config = IVFConfig::pq(num_clusters, pq_config);
        Self { config }
    }

    /// Options built from `IVFConfig::binary` with `num_clusters` clusters
    /// and the `Sign` binarization threshold.
    #[must_use]
    pub fn binary(num_clusters: usize) -> Self {
        let config = IVFConfig::binary(num_clusters, BinaryThreshold::Sign);
        Self { config }
    }

    /// Returns the options with `config.nprobe` replaced.
    #[must_use]
    pub const fn with_nprobe(self, nprobe: usize) -> Self {
        let mut options = self;
        options.config.nprobe = nprobe;
        options
    }

    /// Returns the options with `config.num_clusters` replaced.
    #[must_use]
    pub const fn with_num_clusters(self, num_clusters: usize) -> Self {
        let mut options = self;
        options.config.num_clusters = num_clusters;
        options
    }

    /// Returns the options with `config.storage` replaced.
    #[must_use]
    pub const fn with_storage(self, storage: IVFStorage) -> Self {
        let mut options = self;
        options.config.storage = storage;
        options
    }
}
/// One key/vector pair supplied to the engine.
#[derive(Debug, Clone)]
pub struct EmbeddingInput {
    /// Key to store the vector under.
    pub key: String,
    /// The embedding itself.
    pub vector: Vec<f32>,
}

impl EmbeddingInput {
    /// Builds an input from anything convertible into a `String` key.
    #[must_use]
    pub fn new(key: impl Into<String>, vector: Vec<f32>) -> Self {
        let key: String = key.into();
        EmbeddingInput { key, vector }
    }
}
/// Outcome of a batch store operation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BatchResult {
/// Keys that were stored.
pub stored_keys: Vec<String>,
/// Number of stored keys; `new` sets it to `stored_keys.len()`.
pub count: usize,
}
impl BatchResult {
    /// Builds the result from the stored keys; `count` is derived from the
    /// number of keys.
    #[must_use]
    pub const fn new(stored_keys: Vec<String>) -> Self {
        Self {
            // Evaluated before `stored_keys` is moved into the struct.
            count: stored_keys.len(),
            stored_keys,
        }
    }
}
/// Paging parameters for list-style queries.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Pagination {
    /// Number of leading items to skip.
    pub skip: usize,
    /// Maximum number of items to return; `None` means no limit.
    pub limit: Option<usize>,
    /// Whether the total count should be computed as well.
    pub count_total: bool,
}

impl Pagination {
    /// Skips `skip` items and returns at most `limit`; no total count.
    #[must_use]
    pub const fn new(skip: usize, limit: usize) -> Self {
        Pagination {
            count_total: false,
            limit: Some(limit),
            skip,
        }
    }

    /// Returns the same paging with the total count requested.
    #[must_use]
    pub const fn with_total(self) -> Self {
        Pagination {
            count_total: true,
            ..self
        }
    }

    /// Skips `skip` items and returns everything after; no total count.
    #[must_use]
    pub const fn skip_only(skip: usize) -> Self {
        Pagination {
            count_total: false,
            limit: None,
            skip,
        }
    }
}
/// One page of results.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PagedResult<T> {
/// Items of this page.
pub items: Vec<T>,
/// Total number of items across all pages, when it was computed.
pub total_count: Option<usize>,
/// Whether further items exist after this page.
pub has_more: bool,
}
impl<T> PagedResult<T> {
#[must_use]
pub const fn new(items: Vec<T>, total_count: Option<usize>, has_more: bool) -> Self {
Self {
items,
total_count,
has_more,
}
}
#[must_use]
pub const fn empty() -> Self {
Self {
items: Vec::new(),
total_count: Some(0),
has_more: false,
}
}
}
/// Embedding storage and similarity search on top of a `TensorStore`.
pub struct VectorEngine {
// Backing key/value store that holds every embedding.
store: TensorStore,
// Engine-wide settings (validated by the fallible constructors).
config: VectorEngineConfig,
// Registry of named collections and their configuration.
collections: Arc<RwLock<HashMap<String, VectorCollectionConfig>>>,
// Taken in write mode by `delete_embedding` around its exists/delete pair.
delete_lock: RwLock<()>,
// Cached HNSW indexes keyed by collection name; the key "_default" is used
// for embeddings stored outside any collection.
hnsw_cache: Arc<RwLock<HashMap<String, HnswCacheEntry>>>,
}
impl VectorEngine {
#[must_use]
pub fn new() -> Self {
Self {
store: TensorStore::new(),
config: VectorEngineConfig::default(),
collections: Arc::new(RwLock::new(HashMap::new())),
delete_lock: RwLock::new(()),
hnsw_cache: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Creates an engine over an existing store with the default configuration,
/// no collections and an empty index cache.
#[must_use]
pub fn with_store(store: TensorStore) -> Self {
    let config = VectorEngineConfig::default();
    let collections = Arc::new(RwLock::new(HashMap::new()));
    let hnsw_cache = Arc::new(RwLock::new(HashMap::new()));
    Self {
        store,
        config,
        collections,
        delete_lock: RwLock::new(()),
        hnsw_cache,
    }
}
/// Creates an engine over a fresh in-memory `TensorStore` with a custom
/// configuration.
///
/// # Errors
///
/// Returns the error from `config.validate()` if the configuration is
/// invalid; the store is only created after validation succeeds.
pub fn with_config(config: VectorEngineConfig) -> Result<Self> {
    match config.validate() {
        Ok(()) => Ok(Self {
            store: TensorStore::new(),
            config,
            collections: Arc::new(RwLock::new(HashMap::new())),
            delete_lock: RwLock::new(()),
            hnsw_cache: Arc::new(RwLock::new(HashMap::new())),
        }),
        Err(invalid) => Err(invalid),
    }
}
/// Creates an engine over an existing store with a custom configuration.
///
/// # Errors
///
/// Returns the error from `config.validate()` if the configuration is
/// invalid.
pub fn with_store_and_config(store: TensorStore, config: VectorEngineConfig) -> Result<Self> {
    match config.validate() {
        Ok(()) => Ok(Self {
            store,
            config,
            collections: Arc::new(RwLock::new(HashMap::new())),
            delete_lock: RwLock::new(()),
            hnsw_cache: Arc::new(RwLock::new(HashMap::new())),
        }),
        Err(invalid) => Err(invalid),
    }
}
/// Opens a WAL-backed engine with the default engine configuration.
///
/// # Errors
///
/// See [`VectorEngine::open_durable_with_config`].
pub fn open_durable<P: AsRef<Path>>(wal_path: P, wal_config: WalConfig) -> Result<Self> {
    let engine_config = VectorEngineConfig::default();
    Self::open_durable_with_config(wal_path, wal_config, engine_config)
}
/// Opens a WAL-backed engine with a custom engine configuration.
///
/// # Errors
///
/// Returns the validation error for an invalid `config` (checked before the
/// store is opened), or [`VectorError::StorageError`] when the durable store
/// cannot be opened.
pub fn open_durable_with_config<P: AsRef<Path>>(
    wal_path: P,
    wal_config: WalConfig,
    config: VectorEngineConfig,
) -> Result<Self> {
    config.validate()?;
    let store = match TensorStore::open_durable(wal_path, wal_config) {
        Ok(opened) => opened,
        Err(e) => return Err(VectorError::StorageError(e.to_string())),
    };
    let collections = Arc::new(RwLock::new(HashMap::new()));
    let hnsw_cache = Arc::new(RwLock::new(HashMap::new()));
    Ok(Self {
        store,
        config,
        collections,
        delete_lock: RwLock::new(()),
        hnsw_cache,
    })
}
/// Recovers an engine from a WAL (and an optional snapshot) with the default
/// engine configuration.
///
/// # Errors
///
/// See [`VectorEngine::recover_with_config`].
pub fn recover<P: AsRef<Path>>(
    wal_path: P,
    wal_config: &WalConfig,
    snapshot_path: Option<&Path>,
) -> Result<Self> {
    let engine_config = VectorEngineConfig::default();
    Self::recover_with_config(wal_path, wal_config, snapshot_path, engine_config)
}
/// Recovers an engine from a WAL (and an optional snapshot) with a custom
/// engine configuration.
///
/// # Errors
///
/// Returns the validation error for an invalid `config` (checked before
/// recovery starts), or [`VectorError::StorageError`] when the store cannot
/// be recovered.
pub fn recover_with_config<P: AsRef<Path>>(
    wal_path: P,
    wal_config: &WalConfig,
    snapshot_path: Option<&Path>,
    config: VectorEngineConfig,
) -> Result<Self> {
    config.validate()?;
    let store = match TensorStore::recover(wal_path, wal_config, snapshot_path) {
        Ok(recovered) => recovered,
        Err(e) => return Err(VectorError::StorageError(e.to_string())),
    };
    let collections = Arc::new(RwLock::new(HashMap::new()));
    let hnsw_cache = Arc::new(RwLock::new(HashMap::new()));
    Ok(Self {
        store,
        config,
        collections,
        delete_lock: RwLock::new(()),
        hnsw_cache,
    })
}
/// Returns whether the backing store reports a write-ahead log
/// (`TensorStore::has_wal`).
#[must_use]
pub fn is_durable(&self) -> bool {
self.store.has_wal()
}
/// Borrows the backing `TensorStore`.
#[must_use]
pub const fn store(&self) -> &TensorStore {
&self.store
}
/// Borrows the engine configuration.
pub const fn config(&self) -> &VectorEngineConfig {
&self.config
}
pub fn cache_hnsw_index(&self, collection: &str, index: Arc<HNSWIndex>, keys: Vec<String>) {
self.hnsw_cache
.write()
.insert(collection.to_string(), (index, keys));
}
/// Drops the cached HNSW index of `collection`, if there is one.
pub fn invalidate_hnsw_cache(&self, collection: &str) {
    let mut cache = self.hnsw_cache.write();
    cache.remove(collection);
}
/// Builds an HNSW index via `build_hnsw_index` and caches it under the
/// `"_default"` cache key.
///
/// # Errors
///
/// Propagates any error from `build_hnsw_index`; nothing is cached then.
pub fn build_and_cache_index(&self, config: HNSWConfig) -> Result<()> {
    match self.build_hnsw_index(config) {
        Ok((index, keys)) => {
            self.cache_hnsw_index("_default", Arc::new(index), keys);
            Ok(())
        }
        Err(e) => Err(e.into()),
    }
}
/// Storage key of a non-collection embedding: `emb:` followed by `key`.
fn embedding_key(key: &str) -> String {
    ["emb:", key].concat()
}
/// Key prefix shared by all non-collection embeddings.
const fn embedding_prefix() -> &'static str {
    const PREFIX: &str = "emb:";
    PREFIX
}
/// Storage key of an embedding inside a collection:
/// `coll:<collection>:emb:<key>`.
fn collection_embedding_key(collection: &str, key: &str) -> String {
    ["coll:", collection, ":emb:", key].concat()
}
/// Key prefix shared by all embeddings of a collection:
/// `coll:<collection>:emb:`.
fn collection_embedding_prefix(collection: &str) -> String {
    ["coll:", collection, ":emb:"].concat()
}
/// Name of the default collection.
/// NOTE(review): this is distinct from the `"_default"` key used for the
/// HNSW cache of non-collection embeddings — confirm the difference is
/// intentional.
pub const DEFAULT_COLLECTION: &'static str = "default";
/// Registers a new collection under `name`.
///
/// # Errors
///
/// Returns [`VectorError::CollectionExists`] if the name is already taken;
/// the existing configuration is left untouched.
#[instrument(skip(self, config), fields(collection = %name))]
pub fn create_collection(&self, name: &str, config: VectorCollectionConfig) -> Result<()> {
    use std::collections::hash_map::Entry;

    let mut registry = self.collections.write();
    match registry.entry(name.to_string()) {
        Entry::Occupied(_) => {
            return Err(VectorError::CollectionExists(name.to_string()));
        }
        Entry::Vacant(slot) => {
            slot.insert(config);
        }
    }
    // Release the registry lock before returning.
    drop(registry);
    Ok(())
}
/// Unregisters the collection `name` and deletes all of its embeddings.
///
/// The collection's cached HNSW index is dropped as well; otherwise
/// `search_in_collection` would keep answering from the stale cache and
/// return keys that no longer exist.
///
/// # Errors
///
/// Returns [`VectorError::CollectionNotFound`] if no such collection is
/// registered. Failures to delete individual embeddings are ignored
/// (best effort).
#[instrument(skip(self), fields(collection = %name))]
pub fn delete_collection(&self, name: &str) -> Result<()> {
    // Single lookup: `remove` reports whether the collection existed. The
    // write guard is released as soon as the condition has been evaluated.
    if self.collections.write().remove(name).is_none() {
        return Err(VectorError::CollectionNotFound(name.to_string()));
    }

    let prefix = Self::collection_embedding_prefix(name);
    for key in self.store.scan(&prefix) {
        // Best effort: one failed delete must not abort the cleanup.
        let _ = self.store.delete(&key);
    }

    // The cached index (if any) refers to the keys that were just removed.
    self.invalidate_hnsw_cache(name);
    Ok(())
}
/// Names of all registered collections, in the registry's iteration order
/// (unspecified).
pub fn list_collections(&self) -> Vec<String> {
    let registry = self.collections.read();
    let mut names = Vec::with_capacity(registry.len());
    names.extend(registry.keys().cloned());
    names
}
/// Returns a copy of the configuration of collection `name`, or `None` if
/// the collection is not registered.
pub fn get_collection_config(&self, name: &str) -> Option<VectorCollectionConfig> {
    let registry = self.collections.read();
    let found = registry.get(name).cloned();
    found
}
/// Returns whether a collection named `name` is registered.
pub fn collection_exists(&self, name: &str) -> bool {
    let registry = self.collections.read();
    registry.get(name).is_some()
}
/// Number of embeddings stored under the key prefix of collection `name`.
/// The collection does not need to be registered.
pub fn collection_count(&self, name: &str) -> usize {
    let keys = self.store.scan(&Self::collection_embedding_prefix(name));
    keys.len()
}
/// Stores `vector` under `key` in `collection` without any metadata.
///
/// # Errors
///
/// See [`VectorEngine::store_in_collection_with_metadata`].
#[instrument(skip(self, vector), fields(collection = %collection, key = %key, vector_dim = vector.len()))]
pub fn store_in_collection(&self, collection: &str, key: &str, vector: Vec<f32>) -> Result<()> {
    let no_metadata = HashMap::new();
    self.store_in_collection_with_metadata(collection, key, vector, no_metadata)
}
/// Stores `vector` plus `metadata` under `key` in `collection` and drops the
/// collection's cached HNSW index.
///
/// The collection does not have to be registered; when it is and declares a
/// dimension, the vector must match it. The vector is stored in sparse form
/// when `should_use_sparse` says so, dense otherwise.
///
/// # Errors
///
/// * [`VectorError::EmptyVector`] for an empty vector.
/// * [`VectorError::DimensionMismatch`] if the length differs from the
///   collection's dimension or exceeds the engine's `max_dimension`.
/// * A storage error if the write fails.
#[instrument(skip(self, vector, metadata), fields(collection = %collection, key = %key, vector_dim = vector.len()))]
pub fn store_in_collection_with_metadata(
    &self,
    collection: &str,
    key: &str,
    vector: Vec<f32>,
    metadata: HashMap<String, TensorValue>,
) -> Result<()> {
    let dim = vector.len();
    if dim == 0 {
        return Err(VectorError::EmptyVector);
    }

    // The collection's own dimension is checked first, then the engine cap.
    let required_dim = self
        .collections
        .read()
        .get(collection)
        .and_then(|cfg| cfg.dimension);
    match required_dim {
        Some(expected_dim) if dim != expected_dim => {
            return Err(VectorError::DimensionMismatch {
                expected: expected_dim,
                got: dim,
            });
        }
        _ => {}
    }
    match self.config.max_dimension {
        Some(max_dim) if dim > max_dim => {
            return Err(VectorError::DimensionMismatch {
                expected: max_dim,
                got: dim,
            });
        }
        _ => {}
    }

    let storage_key = Self::collection_embedding_key(collection, key);
    let storage = if self.should_use_sparse(&vector) {
        TensorValue::Sparse(SparseVector::from_dense(&vector))
    } else {
        TensorValue::Vector(vector)
    };

    let mut tensor = TensorData::new();
    tensor.set("vector", storage);
    for (field, value) in metadata {
        tensor.set(Self::metadata_field_key(&field), value);
    }

    self.store.put(storage_key, tensor)?;
    // Any cached index no longer reflects the stored data.
    self.invalidate_hnsw_cache(collection);
    Ok(())
}
/// Loads the embedding stored under `key` in `collection` as a dense vector.
///
/// # Errors
///
/// Returns [`VectorError::NotFound`] (payload `"<collection>:<key>"`) if the
/// entry is missing, has no `vector` field, or the field cannot be turned
/// into a vector.
pub fn get_from_collection(&self, collection: &str, key: &str) -> Result<Vec<f32>> {
    let not_found = || VectorError::NotFound(format!("{collection}:{key}"));

    let storage_key = Self::collection_embedding_key(collection, key);
    let Ok(tensor) = self.store.get(&storage_key) else {
        return Err(not_found());
    };
    let Some(vector_value) = tensor.get("vector") else {
        return Err(not_found());
    };
    match Self::extract_vector(vector_value) {
        Some(vector) => Ok(vector),
        None => Err(not_found()),
    }
}
/// Deletes the embedding stored under `key` in `collection` and drops the
/// collection's cached HNSW index.
///
/// # Errors
///
/// Returns [`VectorError::NotFound`] (payload `"<collection>:<key>"`) if the
/// entry does not exist, or a storage error if the delete fails.
pub fn delete_from_collection(&self, collection: &str, key: &str) -> Result<()> {
    let storage_key = Self::collection_embedding_key(collection, key);
    if self.store.exists(&storage_key) {
        self.store.delete(&storage_key)?;
        self.invalidate_hnsw_cache(collection);
        Ok(())
    } else {
        Err(VectorError::NotFound(format!("{collection}:{key}")))
    }
}
/// Returns whether an embedding is stored under `key` in `collection`.
pub fn exists_in_collection(&self, collection: &str, key: &str) -> bool {
    self.store
        .exists(&Self::collection_embedding_key(collection, key))
}
/// Keys of all embeddings in `collection`, with the storage prefix removed.
pub fn list_collection_keys(&self, collection: &str) -> Vec<String> {
    let prefix = Self::collection_embedding_prefix(collection);
    let mut keys = Vec::new();
    for storage_key in self.store.scan(&prefix) {
        // Entries that do not carry the prefix are skipped.
        if let Some(short_key) = storage_key.strip_prefix(&prefix) {
            keys.push(short_key.to_string());
        }
    }
    keys
}
/// Returns the metadata stored with the embedding `key` in `collection`.
///
/// Only fields whose name starts with `METADATA_PREFIX` are returned, with
/// that prefix stripped.
///
/// # Errors
///
/// Returns [`VectorError::NotFound`] (payload: the bare `key`) if the entry
/// cannot be loaded.
pub fn get_collection_metadata(
    &self,
    collection: &str,
    key: &str,
) -> Result<HashMap<String, TensorValue>> {
    let storage_key = Self::collection_embedding_key(collection, key);
    let Ok(tensor) = self.store.get(&storage_key) else {
        return Err(VectorError::NotFound(key.to_string()));
    };

    let metadata: HashMap<String, TensorValue> = tensor
        .fields_iter()
        .filter_map(|(field, value)| {
            field
                .strip_prefix(Self::METADATA_PREFIX)
                .map(|name| (name.to_string(), value.clone()))
        })
        .collect();
    Ok(metadata)
}
/// Returns the `top_k` best-scoring embeddings of `collection` for `query`,
/// sorted by descending score.
///
/// When a non-empty HNSW index is cached for the collection it answers the
/// query directly. Otherwise every stored vector of matching length is
/// scored with the collection's metric (cosine if the collection is not
/// registered). A zero-magnitude query under the cosine metric yields an
/// empty result.
///
/// # Errors
///
/// * [`VectorError::EmptyVector`] for an empty query.
/// * [`VectorError::InvalidTopK`] if `top_k` is zero.
/// * [`VectorError::DimensionMismatch`] if the collection declares a
///   dimension and the query length differs.
/// * [`VectorError::SearchTimeout`] if the configured search timeout has
///   passed after the key scan or after scoring (brute-force path only).
#[instrument(skip(self, query), fields(collection = %collection, query_dim = query.len(), top_k))]
pub fn search_in_collection(
    &self,
    collection: &str,
    query: &[f32],
    top_k: usize,
) -> Result<Vec<SearchResult>> {
    /// Orders results from highest to lowest score; incomparable scores
    /// (NaN) are treated as equal.
    fn by_score_desc(a: &SearchResult, b: &SearchResult) -> std::cmp::Ordering {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    }

    let deadline = Deadline::from_duration(self.config.search_timeout);
    let timed_out = || VectorError::SearchTimeout {
        operation: "search_in_collection".to_string(),
        timeout_ms: deadline.timeout_ms(),
    };

    if query.is_empty() {
        return Err(VectorError::EmptyVector);
    }
    if top_k == 0 {
        return Err(VectorError::InvalidTopK);
    }

    let collection_config = self.collections.read().get(collection).cloned();
    let required_dim = collection_config.as_ref().and_then(|c| c.dimension);
    if let Some(expected_dim) = required_dim {
        if query.len() != expected_dim {
            return Err(VectorError::DimensionMismatch {
                expected: expected_dim,
                got: query.len(),
            });
        }
    }

    let prefix = Self::collection_embedding_prefix(collection);
    let metric = match collection_config.as_ref() {
        Some(c) => c.distance_metric,
        None => DistanceMetric::Cosine,
    };
    let query_magnitude = Self::magnitude(query);
    if metric == DistanceMetric::Cosine && query_magnitude == 0.0 {
        return Ok(Vec::new());
    }

    // Fast path: answer from the cached HNSW index when one is available.
    {
        let cache = self.hnsw_cache.read();
        let cached = cache
            .get(collection)
            .filter(|entry| !entry.1.is_empty());
        if let Some((index, mapping)) = cached {
            let mut hits: Vec<SearchResult> = Vec::new();
            for (idx, score) in index.search(query, top_k) {
                // Ids without a matching key are dropped.
                if let Some(key) = mapping.get(idx) {
                    hits.push(SearchResult {
                        key: key.strip_prefix(&prefix).unwrap_or(key).to_string(),
                        score,
                    });
                }
            }
            hits.sort_by(by_score_desc);
            hits.truncate(top_k);
            return Ok(hits);
        }
    }

    // Slow path: score every stored vector of the collection.
    let keys: Vec<_> = self.store.scan(&prefix);
    if deadline.is_expired() {
        return Err(timed_out());
    }

    let mut results: Vec<SearchResult> = Vec::new();
    for storage_key in keys {
        let Some(key) = storage_key.strip_prefix(&prefix) else {
            continue;
        };
        let Ok(tensor) = self.store.get(&storage_key) else {
            continue;
        };
        let Some(vector_value) = tensor.get("vector") else {
            continue;
        };
        let Some(vector) = Self::extract_vector(vector_value) else {
            continue;
        };
        // Vectors of a different length cannot be compared with the query.
        if vector.len() != query.len() {
            continue;
        }
        let score = Self::compute_score(query, &vector, query_magnitude, metric);
        results.push(SearchResult::new(key.to_string(), score));
    }

    if deadline.is_expired() {
        return Err(timed_out());
    }

    results.sort_by(by_score_desc);
    results.truncate(top_k);
    Ok(results)
}
#[allow(clippy::too_many_lines)]
#[instrument(skip(self, query, filter, config), fields(collection = %collection, query_dim = query.len(), top_k))]
pub fn search_filtered_in_collection(
&self,
collection: &str,
query: &[f32],
top_k: usize,
filter: &FilterCondition,
config: Option<FilteredSearchConfig>,
) -> Result<Vec<SearchResult>> {
let deadline = Deadline::from_duration(self.config.search_timeout);
if query.is_empty() {
return Err(VectorError::EmptyVector);
}
if top_k == 0 {
return Err(VectorError::InvalidTopK);
}
let collection_config_opt = self.collections.read().get(collection).cloned();
if let Some(ref coll_config) = collection_config_opt {
if let Some(expected_dim) = coll_config.dimension {
if query.len() != expected_dim {
return Err(VectorError::DimensionMismatch {
expected: expected_dim,
got: query.len(),
});
}
}
}
let prefix = Self::collection_embedding_prefix(collection);
let query_magnitude = Self::magnitude(query);
if query_magnitude == 0.0 {
return Ok(Vec::new());
}
let filter_config = config.unwrap_or_default();
let strategy = match filter_config.strategy {
FilterStrategy::Auto => {
let keys = self.store.scan(&prefix);
let sample_size = 100.min(keys.len());
if sample_size == 0 {
FilterStrategy::PostFilter
} else {
let matches = keys
.iter()
.take(sample_size)
.filter(|k| {
self.store
.get(k)
.map(|t| Self::evaluate_filter(&t, filter))
.unwrap_or(false)
})
.count();
#[allow(clippy::cast_precision_loss)]
let selectivity = matches as f32 / sample_size as f32;
if selectivity < filter_config.selectivity_threshold {
FilterStrategy::PreFilter
} else {
FilterStrategy::PostFilter
}
}
},
other => other,
};
if deadline.is_expired() {
return Err(VectorError::SearchTimeout {
operation: "search_filtered_in_collection".to_string(),
timeout_ms: deadline.timeout_ms(),
});
}
let mut results: Vec<SearchResult> = match strategy {
FilterStrategy::PreFilter | FilterStrategy::Auto => {
self.store
.scan(&prefix)
.into_iter()
.filter_map(|storage_key| {
let tensor = self.store.get(&storage_key).ok()?;
if !Self::evaluate_filter(&tensor, filter) {
return None;
}
let key = storage_key.strip_prefix(&prefix)?;
let vector_value = tensor.get("vector")?;
let vector = Self::extract_vector(vector_value)?;
if vector.len() != query.len() {
return None;
}
let score = Self::cosine_similarity(query, &vector, query_magnitude);
Some(SearchResult::new(key.to_string(), score))
})
.collect()
},
FilterStrategy::PostFilter => {
let oversample_k = top_k
.saturating_mul(filter_config.oversample_factor)
.max(top_k);
let candidates = self.search_in_collection(collection, query, oversample_k)?;
candidates
.into_iter()
.filter(|r| {
let storage_key = Self::collection_embedding_key(collection, &r.key);
self.store
.get(&storage_key)
.map(|t| Self::evaluate_filter(&t, filter))
.unwrap_or(false)
})
.collect()
},
};
if deadline.is_expired() {
return Err(VectorError::SearchTimeout {
operation: "search_filtered_in_collection".to_string(),
timeout_ms: deadline.timeout_ms(),
});
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(top_k);
Ok(results)
}
/// Stores `vector` under `key` outside any collection and drops the cached
/// `"_default"` HNSW index.
///
/// The vector is stored in sparse form when `should_use_sparse` says so,
/// dense otherwise.
///
/// # Errors
///
/// * [`VectorError::EmptyVector`] for an empty vector.
/// * [`VectorError::DimensionMismatch`] if the length exceeds the engine's
///   `max_dimension`.
/// * A storage error if the write fails.
#[instrument(skip(self, vector), fields(key = %key, vector_dim = vector.len()))]
pub fn store_embedding(&self, key: &str, vector: Vec<f32>) -> Result<()> {
    let dim = vector.len();
    if dim == 0 {
        return Err(VectorError::EmptyVector);
    }
    match self.config.max_dimension {
        Some(max_dim) if dim > max_dim => {
            return Err(VectorError::DimensionMismatch {
                expected: max_dim,
                got: dim,
            });
        }
        _ => {}
    }

    let storage_key = Self::embedding_key(key);
    let storage = if self.should_use_sparse(&vector) {
        TensorValue::Sparse(SparseVector::from_dense(&vector))
    } else {
        TensorValue::Vector(vector)
    };

    let mut tensor = TensorData::new();
    tensor.set("vector", storage);
    self.store.put(storage_key, tensor)?;

    // Any cached index no longer reflects the stored data.
    self.invalidate_hnsw_cache("_default");
    Ok(())
}
/// Decides between sparse and dense storage using the engine's configured
/// `sparse_threshold`.
fn should_use_sparse(&self, vector: &[f32]) -> bool {
    let threshold = self.config.sparse_threshold;
    Self::should_use_sparse_with_threshold(vector, threshold)
}
/// Returns `true` when the fraction of near-zero components of `vector`
/// (absolute value at most `1e-6`) is at least `threshold`.
///
/// An empty vector is never considered sparse.
fn should_use_sparse_with_threshold(vector: &[f32], threshold: f32) -> bool {
    let total = vector.len();
    if total == 0 {
        return false;
    }

    // Count the components that are clearly non-zero.
    let mut nnz = 0_usize;
    for component in vector {
        if component.abs() > 1e-6 {
            nnz += 1;
        }
    }

    #[allow(clippy::cast_precision_loss)]
    let zero_ratio = 1.0 - (nnz as f32 / total as f32);
    zero_ratio >= threshold
}
/// Loads the non-collection embedding stored under `key` as a dense vector
/// (sparse storage is expanded).
///
/// # Errors
///
/// Returns [`VectorError::NotFound`] if the entry is missing or its `vector`
/// field is neither a dense nor a sparse vector.
#[instrument(skip(self), fields(key = %key))]
pub fn get_embedding(&self, key: &str) -> Result<Vec<f32>> {
    let storage_key = Self::embedding_key(key);
    let Ok(tensor) = self.store.get(&storage_key) else {
        return Err(VectorError::NotFound(key.to_string()));
    };

    let dense = match tensor.get("vector") {
        Some(TensorValue::Sparse(sparse)) => sparse.to_dense(),
        Some(TensorValue::Vector(values)) => values.clone(),
        _ => return Err(VectorError::NotFound(key.to_string())),
    };
    Ok(dense)
}
/// Deletes the embedding stored under `key`.
///
/// The write side of `delete_lock` is held for the whole operation, and the
/// `_default` HNSW cache entry is invalidated after a successful delete.
///
/// # Errors
///
/// * `VectorError::NotFound` if no embedding exists under `key`.
/// * The converted store error if the underlying delete fails.
#[instrument(skip(self), fields(key = %key))]
pub fn delete_embedding(&self, key: &str) -> Result<()> {
    let _guard = self.delete_lock.write();
    let storage_key = Self::embedding_key(key);
    if self.store.exists(&storage_key) {
        self.store.delete(&storage_key)?;
        self.invalidate_hnsw_cache("_default");
        Ok(())
    } else {
        Err(VectorError::NotFound(key.to_string()))
    }
}
/// Reports whether the store holds an entry for `key` in the default
/// embedding namespace.
#[instrument(skip(self), fields(key = %key))]
pub fn exists(&self, key: &str) -> bool {
    self.store.exists(&Self::embedding_key(key))
}
/// Returns the number of store keys that start with the embedding prefix,
/// as reported by the store's `scan_count`.
#[instrument(skip(self))]
pub fn count(&self) -> usize {
self.store.scan_count(Self::embedding_prefix())
}
/// Returns up to `top_k` stored embeddings most similar to `query`, best first.
///
/// If the HNSW cache holds a `_default` entry with a non-empty key mapping, the
/// cached index answers the query. Otherwise every key under the embedding
/// prefix is scanned and scored with cosine similarity, in parallel once the
/// key count reaches `config.parallel_threshold`. A zero-magnitude query yields
/// an empty result.
///
/// # Errors
///
/// * `VectorError::EmptyVector` if `query` is empty.
/// * `VectorError::InvalidTopK` if `top_k` is zero.
/// * `VectorError::DimensionMismatch` if `query` is longer than the configured
///   `max_dimension`.
/// * `VectorError::SearchTimeout` if the deadline derived from
///   `config.search_timeout` has expired after the key scan or after scoring.
#[instrument(skip(self, query), fields(query_dim = query.len(), top_k = top_k))]
pub fn search_similar(&self, query: &[f32], top_k: usize) -> Result<Vec<SearchResult>> {
    use std::cmp::Ordering;

    let deadline = Deadline::from_duration(self.config.search_timeout);
    let timed_out = || VectorError::SearchTimeout {
        operation: "search_similar".to_string(),
        timeout_ms: deadline.timeout_ms(),
    };

    if query.is_empty() {
        return Err(VectorError::EmptyVector);
    }
    if top_k == 0 {
        return Err(VectorError::InvalidTopK);
    }
    match self.config.max_dimension {
        Some(limit) if query.len() > limit => {
            return Err(VectorError::DimensionMismatch {
                expected: limit,
                got: query.len(),
            });
        },
        _ => {},
    }
    let query_magnitude = Self::magnitude(query);
    if query_magnitude == 0.0 {
        return Ok(Vec::new());
    }

    // Fast path: answer from the cached index; the read guard is released at
    // the end of this scope.
    {
        let cache = self.hnsw_cache.read();
        if let Some((index, mapping)) = cache.get("_default") {
            if !mapping.is_empty() {
                let prefix = Self::embedding_prefix();
                let mut hits: Vec<SearchResult> = Vec::new();
                for (idx, score) in index.search(query, top_k) {
                    // Node ids without a mapping entry are dropped.
                    if let Some(storage_key) = mapping.get(idx) {
                        let key = storage_key
                            .strip_prefix(prefix)
                            .unwrap_or(storage_key)
                            .to_string();
                        hits.push(SearchResult { key, score });
                    }
                }
                hits.sort_by(|lhs, rhs| {
                    rhs.score.partial_cmp(&lhs.score).unwrap_or(Ordering::Equal)
                });
                hits.truncate(top_k);
                return Ok(hits);
            }
        }
    }

    // Slow path: brute-force scan of every stored embedding.
    let keys = self.store.scan(Self::embedding_prefix());
    if deadline.is_expired() {
        return Err(timed_out());
    }
    let use_parallel = keys.len() >= self.config.parallel_threshold;
    let mut ranked: Vec<SearchResult> = if use_parallel {
        Self::search_parallel(&self.store, &keys, query, query_magnitude)
    } else {
        Self::search_sequential(&self.store, &keys, query, query_magnitude)
    };
    if deadline.is_expired() {
        return Err(timed_out());
    }
    ranked.sort_by(|lhs, rhs| rhs.score.partial_cmp(&lhs.score).unwrap_or(Ordering::Equal));
    ranked.truncate(top_k);
    Ok(ranked)
}
/// Returns up to `top_k` stored embeddings ranked by `metric`, best first.
///
/// Every key under the embedding prefix is scanned and scored with
/// `compute_score`, in parallel once the key count reaches
/// `config.parallel_threshold`. A zero-magnitude query yields an empty result
/// unless the metric is `DistanceMetric::Euclidean`.
///
/// # Errors
///
/// * `VectorError::EmptyVector` if `query` is empty.
/// * `VectorError::InvalidTopK` if `top_k` is zero.
/// * `VectorError::DimensionMismatch` if `query` is longer than the configured
///   `max_dimension`.
/// * `VectorError::SearchTimeout` if the deadline derived from
///   `config.search_timeout` has expired after the key scan or after scoring.
#[instrument(skip(self, query), fields(query_dim = query.len(), top_k, metric = ?metric))]
pub fn search_similar_with_metric(
    &self,
    query: &[f32],
    top_k: usize,
    metric: DistanceMetric,
) -> Result<Vec<SearchResult>> {
    let deadline = Deadline::from_duration(self.config.search_timeout);
    if query.is_empty() {
        return Err(VectorError::EmptyVector);
    }
    if top_k == 0 {
        return Err(VectorError::InvalidTopK);
    }
    // Same guard as `search_similar`: reject an oversized query up front rather
    // than scanning the whole store only to skip every length-mismatched vector.
    if let Some(max_dim) = self.config.max_dimension {
        if query.len() > max_dim {
            return Err(VectorError::DimensionMismatch {
                expected: max_dim,
                got: query.len(),
            });
        }
    }
    let query_magnitude = Self::magnitude(query);
    if query_magnitude == 0.0 && !matches!(metric, DistanceMetric::Euclidean) {
        return Ok(Vec::new());
    }
    let keys = self.store.scan(Self::embedding_prefix());
    if deadline.is_expired() {
        return Err(VectorError::SearchTimeout {
            operation: "search_similar_with_metric".to_string(),
            timeout_ms: deadline.timeout_ms(),
        });
    }
    let mut results: Vec<SearchResult> = if keys.len() >= self.config.parallel_threshold {
        Self::search_parallel_with_metric(&self.store, &keys, query, query_magnitude, metric)
    } else {
        Self::search_sequential_with_metric(&self.store, &keys, query, query_magnitude, metric)
    };
    if deadline.is_expired() {
        return Err(VectorError::SearchTimeout {
            operation: "search_similar_with_metric".to_string(),
            timeout_ms: deadline.timeout_ms(),
        });
    }
    results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    results.truncate(top_k);
    Ok(results)
}
/// Converts a stored tensor value into an owned dense vector.
///
/// Dense vectors are cloned and sparse vectors are expanded with `to_dense`;
/// any other variant yields `None`.
fn extract_vector(value: &TensorValue) -> Option<Vec<f32>> {
    let dense = match value {
        TensorValue::Vector(components) => components.clone(),
        TensorValue::Sparse(sparse) => sparse.to_dense(),
        _ => return None,
    };
    Some(dense)
}
/// Scores every key in `keys` against `query` with cosine similarity, one key
/// at a time.
///
/// Keys whose record cannot be read, has no usable `"vector"` field, or holds
/// a vector whose length differs from `query` are skipped. Result keys have the
/// embedding prefix stripped when present. The output is not sorted.
#[allow(clippy::similar_names)]
fn search_sequential(
    store: &TensorStore,
    keys: &[String],
    query: &[f32],
    query_magnitude: f32,
) -> Vec<SearchResult> {
    let prefix = Self::embedding_prefix();
    let mut hits = Vec::new();
    for storage_key in keys {
        let Ok(record) = store.get(storage_key) else {
            continue;
        };
        let Some(candidate) = record.get("vector").and_then(Self::extract_vector) else {
            continue;
        };
        if candidate.len() != query.len() {
            continue;
        }
        let similarity = Self::cosine_similarity(query, &candidate, query_magnitude);
        let key = storage_key.strip_prefix(prefix).unwrap_or(storage_key);
        hits.push(SearchResult::new(key.to_string(), similarity));
    }
    hits
}
/// Scores every key in `keys` against `query` with cosine similarity, using a
/// rayon parallel iterator.
///
/// Keys whose record cannot be read, has no usable `"vector"` field, or holds
/// a vector whose length differs from `query` are skipped. Result keys have the
/// embedding prefix stripped when present. The output is not sorted.
#[allow(clippy::similar_names)]
fn search_parallel(
    store: &TensorStore,
    keys: &[String],
    query: &[f32],
    query_magnitude: f32,
) -> Vec<SearchResult> {
    keys.par_iter()
        .filter_map(|storage_key| {
            let record = store.get(storage_key).ok()?;
            let field = record.get("vector")?;
            let candidate =
                Self::extract_vector(field).filter(|stored| stored.len() == query.len())?;
            let similarity = Self::cosine_similarity(query, &candidate, query_magnitude);
            let key = storage_key
                .strip_prefix(Self::embedding_prefix())
                .unwrap_or(storage_key);
            Some(SearchResult::new(key.to_string(), similarity))
        })
        .collect()
}
/// Scores every key in `keys` against `query` with `metric`, one key at a time.
///
/// Keys whose record cannot be read, has no usable `"vector"` field, or holds
/// a vector whose length differs from `query` are skipped. Result keys have the
/// embedding prefix stripped when present. The output is not sorted.
#[allow(clippy::similar_names)]
fn search_sequential_with_metric(
    store: &TensorStore,
    keys: &[String],
    query: &[f32],
    query_magnitude: f32,
    metric: DistanceMetric,
) -> Vec<SearchResult> {
    let prefix = Self::embedding_prefix();
    let mut hits = Vec::new();
    for storage_key in keys {
        let Ok(record) = store.get(storage_key) else {
            continue;
        };
        let Some(candidate) = record.get("vector").and_then(Self::extract_vector) else {
            continue;
        };
        if candidate.len() != query.len() {
            continue;
        }
        let score = Self::compute_score(query, &candidate, query_magnitude, metric);
        let key = storage_key.strip_prefix(prefix).unwrap_or(storage_key);
        hits.push(SearchResult::new(key.to_string(), score));
    }
    hits
}
/// Scores every key in `keys` against `query` with `metric`, using a rayon
/// parallel iterator.
///
/// Keys whose record cannot be read, has no usable `"vector"` field, or holds
/// a vector whose length differs from `query` are skipped. Result keys have the
/// embedding prefix stripped when present. The output is not sorted.
#[allow(clippy::similar_names)]
fn search_parallel_with_metric(
    store: &TensorStore,
    keys: &[String],
    query: &[f32],
    query_magnitude: f32,
    metric: DistanceMetric,
) -> Vec<SearchResult> {
    keys.par_iter()
        .filter_map(|storage_key| {
            let record = store.get(storage_key).ok()?;
            let field = record.get("vector")?;
            let candidate =
                Self::extract_vector(field).filter(|stored| stored.len() == query.len())?;
            let score = Self::compute_score(query, &candidate, query_magnitude, metric);
            let key = storage_key
                .strip_prefix(Self::embedding_prefix())
                .unwrap_or(storage_key);
            Some(SearchResult::new(key.to_string(), score))
        })
        .collect()
}
/// Computes a "higher is better" score between `query` and `stored` for the
/// given `metric`.
///
/// * `Cosine`: cosine similarity, reusing the precomputed `query_magnitude`.
/// * `DotProduct`: the raw dot product.
/// * `Euclidean`: `1 / (1 + distance)`, so a distance of zero scores `1.0`.
fn compute_score(
    query: &[f32],
    stored: &[f32],
    query_magnitude: f32,
    metric: DistanceMetric,
) -> f32 {
    match metric {
        DistanceMetric::Euclidean => 1.0 / (1.0 + Self::euclidean_distance(query, stored)),
        DistanceMetric::DotProduct => simd::dot_product(query, stored),
        DistanceMetric::Cosine => Self::cosine_similarity(query, stored, query_magnitude),
    }
}
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired with `zip`, so if the slices differ in length only
/// the common prefix contributes.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f32>()
        .sqrt()
}
/// Cosine similarity between `a` and `b`, given the precomputed magnitude of `a`.
///
/// Returns `0.0` when either magnitude is zero, which avoids a division by zero.
fn cosine_similarity(a: &[f32], b: &[f32], a_magnitude: f32) -> f32 {
    let dot = simd::dot_product(a, b);
    let b_magnitude = simd::magnitude(b);
    let degenerate = a_magnitude == 0.0 || b_magnitude == 0.0;
    if degenerate {
        0.0
    } else {
        dot / (a_magnitude * b_magnitude)
    }
}
/// Returns the magnitude of `v` by delegating to `simd::magnitude`.
fn magnitude(v: &[f32]) -> f32 {
simd::magnitude(v)
}
/// Computes the cosine similarity of two vectors of equal length.
///
/// Returns `Ok(0.0)` when `a` has zero magnitude.
///
/// # Errors
///
/// * `VectorError::EmptyVector` if either input is empty.
/// * `VectorError::DimensionMismatch` if the lengths differ (`expected` is the
///   length of `a`, `got` the length of `b`).
pub fn compute_similarity(a: &[f32], b: &[f32]) -> Result<f32> {
    let (len_a, len_b) = (a.len(), b.len());
    if len_a == 0 || len_b == 0 {
        return Err(VectorError::EmptyVector);
    }
    if len_a != len_b {
        return Err(VectorError::DimensionMismatch {
            expected: len_a,
            got: len_b,
        });
    }
    let norm_a = Self::magnitude(a);
    let similarity = if norm_a == 0.0 {
        0.0
    } else {
        Self::cosine_similarity(a, b, norm_a)
    };
    Ok(similarity)
}
/// Returns the length of the first readable embedding found under the
/// embedding prefix, or `None` if there is none.
///
/// Records that cannot be read or that lack a usable `"vector"` field are
/// skipped.
pub fn dimension(&self) -> Option<usize> {
    self.store
        .scan(Self::embedding_prefix())
        .into_iter()
        .find_map(|storage_key| {
            let record = self.store.get(&storage_key).ok()?;
            let vector = record.get("vector").and_then(Self::extract_vector)?;
            Some(vector.len())
        })
}
/// Lists embedding keys by delegating to `list_keys_bounded`, so the result
/// is subject to the same `max_keys_per_scan` limit.
#[instrument(skip(self))]
pub fn list_keys(&self) -> Vec<String> {
self.list_keys_bounded()
}
/// Lists embedding keys with the embedding prefix stripped.
///
/// At most `config.max_keys_per_scan` scanned keys are considered (no limit
/// when it is `None`). Scanned keys that do not start with the prefix are
/// dropped.
#[instrument(skip(self))]
pub fn list_keys_bounded(&self) -> Vec<String> {
    let prefix = Self::embedding_prefix();
    let cap = self.config.max_keys_per_scan.unwrap_or(usize::MAX);
    let mut keys = Vec::new();
    for storage_key in self.store.scan(prefix).into_iter().take(cap) {
        if let Some(stripped) = storage_key.strip_prefix(prefix) {
            keys.push(stripped.to_string());
        }
    }
    keys
}
/// Deletes embeddings under the embedding prefix and returns how many keys
/// were selected for deletion.
///
/// At most `config.max_keys_per_scan` keys are removed per call (no limit when
/// it is `None`), so a capped configuration may need repeated calls to empty
/// the store.
///
/// If any key was selected, the `_default` HNSW cache entry is invalidated,
/// including when a delete fails part-way through. Otherwise `search_similar`
/// could keep answering from an index that still references deleted keys.
///
/// # Errors
///
/// Returns the converted store error of the first delete that fails; keys
/// deleted before the failure stay deleted.
#[instrument(skip(self))]
pub fn clear(&self) -> Result<usize> {
    let max_keys = self.config.max_keys_per_scan.unwrap_or(usize::MAX);
    let keys: Vec<_> = self
        .store
        .scan(Self::embedding_prefix())
        .into_iter()
        .take(max_keys)
        .collect();
    let count = keys.len();
    let mut outcome: Result<usize> = Ok(count);
    for key in &keys {
        if let Err(err) = self.store.delete(key) {
            outcome = Err(err.into());
            break;
        }
    }
    if count > 0 {
        // Mirror `delete_embedding`: the stored set changed, so drop the cached index.
        self.invalidate_hnsw_cache("_default");
    }
    outcome
}
/// Builds an HNSW index over the listed embeddings using `config` and the
/// `HNSWStorageStrategy::Dense` storage strategy.
///
/// Returns the index together with the key for each inserted vector, in
/// insertion order.
///
/// # Errors
///
/// Propagates any error from `build_hnsw_index_with_options`.
#[instrument(skip(self, config))]
pub fn build_hnsw_index(&self, config: HNSWConfig) -> Result<(HNSWIndex, Vec<String>)> {
    let options = HNSWBuildOptions {
        hnsw_config: config,
        storage: HNSWStorageStrategy::Dense,
    };
    self.build_hnsw_index_with_options(options)
}
#[instrument(skip(self, options))]
pub fn build_hnsw_index_with_options(
&self,
options: HNSWBuildOptions,
) -> Result<(HNSWIndex, Vec<String>)> {
let keys = self.list_keys();
if keys.is_empty() {
return Ok((HNSWIndex::with_config(options.hnsw_config), Vec::new()));
}
let first_key = &keys[0];
let first_vector = self.get_embedding(first_key)?;
let expected_dim = first_vector.len();
if let Some(max_dim) = self.config.max_dimension {
if expected_dim > max_dim {
return Err(VectorError::DimensionMismatch {
expected: max_dim,
got: expected_dim,
});
}
}
let index = HNSWIndex::with_config(options.hnsw_config);
let mut key_mapping = Vec::with_capacity(keys.len());
insert_with_strategy(&index, first_vector, options.storage);
key_mapping.push(first_key.clone());
for key in keys.into_iter().skip(1) {
let vector = self.get_embedding(&key)?;
if vector.len() != expected_dim {
return Err(VectorError::DimensionMismatch {
expected: expected_dim,
got: vector.len(),
});
}
insert_with_strategy(&index, vector, options.storage);
key_mapping.push(key);
}
Ok((index, key_mapping))
}
/// Builds an HNSW index with `HNSWConfig::default()`; see `build_hnsw_index`.
///
/// # Errors
///
/// Propagates any error from `build_hnsw_index`.
#[instrument(skip(self))]
pub fn build_hnsw_index_default(&self) -> Result<(HNSWIndex, Vec<String>)> {
self.build_hnsw_index(HNSWConfig::default())
}
/// Estimates, in bytes, the memory an HNSW index over the stored embeddings
/// would need.
///
/// The estimate is the sum of:
/// * vector data: `count * dim * 4` (four bytes per `f32` component),
/// * graph links: `count * 16 * 2 * 8` (presumably 16 neighbours x 2 x 8 bytes
///   per node; TODO confirm against the HNSW configuration),
/// * key strings: `count * 32` (a fixed per-key allowance).
///
/// `dim` is taken from the first key returned by `list_keys`. Returns `Ok(0)`
/// when there are no embeddings, or when no key can be listed (for example
/// because `max_keys_per_scan` is zero) and the dimension is therefore unknown.
/// All arithmetic saturates instead of overflowing.
///
/// # Errors
///
/// Propagates the error from `get_embedding` if the first listed key cannot
/// be read.
pub fn estimate_hnsw_memory(&self) -> Result<usize> {
    let count = self.count();
    if count == 0 {
        return Ok(0);
    }
    let keys = self.list_keys();
    // `count()` is unbounded while `list_keys()` honours `max_keys_per_scan`, so
    // the list can be empty here; indexing `keys[0]` would panic.
    let Some(first_key) = keys.first() else {
        return Ok(0);
    };
    let dim = self.get_embedding(first_key)?.len();
    let vector_bytes = count.saturating_mul(dim).saturating_mul(4);
    let graph_bytes = count.saturating_mul(16 * 2 * 8);
    let key_bytes = count.saturating_mul(32);
    Ok(vector_bytes
        .saturating_add(graph_bytes)
        .saturating_add(key_bytes))
}
pub fn search_with_hnsw(
&self,
index: &HNSWIndex,
key_mapping: &[String],
query: &[f32],
top_k: usize,
) -> Result<Vec<SearchResult>> {
let deadline = Deadline::from_duration(self.config.search_timeout);
if query.is_empty() {
return Err(VectorError::EmptyVector);
}
if top_k == 0 {
return Err(VectorError::InvalidTopK);
}
let results = index.search(query, top_k);
if deadline.is_expired() {
return Err(VectorError::SearchTimeout {
operation: "search_with_hnsw".to_string(),
timeout_ms: deadline.timeout_ms(),
});
}
Ok(results
.into_iter()
.filter_map(|(node_id, score)| {
key_mapping.get(node_id).map(|key| SearchResult {
key: key.clone(),
score,
})
})
.collect())
}
/// Uses a caller-supplied HNSW `index` to pick candidates, then re-scores them
/// with `metric`.
///
/// `max(2 * top_k, 10)` candidates are requested from the index. Each candidate
/// whose node id maps to a key and whose embedding can be loaded is converted
/// to a sparse vector, scored with `metric.compute`, and passed through
/// `metric.to_similarity`. The re-scored candidates are sorted best first and
/// truncated to `top_k`.
///
/// # Errors
///
/// * `VectorError::EmptyVector` if `query` is empty.
/// * `VectorError::InvalidTopK` if `top_k` is zero.
/// * `VectorError::SearchTimeout` if the deadline derived from
///   `config.search_timeout` has expired after the index search or after
///   re-scoring.
pub fn search_with_hnsw_and_metric(
    &self,
    index: &HNSWIndex,
    key_mapping: &[String],
    query: &[f32],
    top_k: usize,
    metric: &ExtendedDistanceMetric,
) -> Result<Vec<SearchResult>> {
    use std::cmp::Ordering;

    let deadline = Deadline::from_duration(self.config.search_timeout);
    let timed_out = || VectorError::SearchTimeout {
        operation: "search_with_hnsw_and_metric".to_string(),
        timeout_ms: deadline.timeout_ms(),
    };
    if query.is_empty() {
        return Err(VectorError::EmptyVector);
    }
    if top_k == 0 {
        return Err(VectorError::InvalidTopK);
    }
    // Over-fetch so that re-ranking under a different metric has room to reorder.
    let candidate_count = top_k.saturating_mul(2).max(10);
    let candidates = index.search(query, candidate_count);
    if deadline.is_expired() {
        return Err(timed_out());
    }
    let query_sparse = SparseVector::from_dense(query);
    let mut rescored: Vec<SearchResult> = Vec::new();
    for (node_id, _) in candidates.iter() {
        let Some(key) = key_mapping.get(*node_id) else {
            continue;
        };
        let Ok(vector) = self.get_embedding(key) else {
            continue;
        };
        let stored_sparse = SparseVector::from_dense(&vector);
        let raw_score = metric.compute(&query_sparse, &stored_sparse);
        rescored.push(SearchResult::new(key.clone(), metric.to_similarity(raw_score)));
    }
    if deadline.is_expired() {
        return Err(timed_out());
    }
    rescored.sort_by(|lhs, rhs| rhs.score.partial_cmp(&lhs.score).unwrap_or(Ordering::Equal));
    rescored.truncate(top_k);
    Ok(rescored)
}
/// Builds an IVF index over every key returned by `list_keys`.
///
/// All embeddings are loaded first and must share one dimension. The index is
/// then trained on the full set and each vector is added in key order. The
/// returned key list is the `list_keys` output, so position `i` is the key of
/// the `i`-th added vector. With no keys, an untrained index and an empty key
/// list are returned.
///
/// # Errors
///
/// * Any error from `get_embedding` for a listed key.
/// * `VectorError::DimensionMismatch` if an embedding's length differs from the
///   first one, or if the shared dimension exceeds the configured
///   `max_dimension`.
#[instrument(skip(self, options))]
pub fn build_ivf_index(&self, options: IVFBuildOptions) -> Result<(IVFIndex, Vec<String>)> {
    let keys = self.list_keys();
    if keys.is_empty() {
        return Ok((IVFIndex::new(options.config), Vec::new()));
    }
    let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(keys.len());
    let mut expected_dim: Option<usize> = None;
    for key in &keys {
        let vector = self.get_embedding(key)?;
        let got = vector.len();
        if let Some(expected) = expected_dim {
            if expected != got {
                return Err(VectorError::DimensionMismatch { expected, got });
            }
        } else {
            expected_dim = Some(got);
        }
        vectors.push(vector);
    }
    if let (Some(max_dim), Some(dim)) = (self.config.max_dimension, expected_dim) {
        if dim > max_dim {
            return Err(VectorError::DimensionMismatch {
                expected: max_dim,
                got: dim,
            });
        }
    }
    let mut index = IVFIndex::new(options.config);
    let training_set: Vec<&[f32]> = vectors.iter().map(Vec::as_slice).collect();
    index.train(&training_set);
    for vector in &vectors {
        index.add(vector);
    }
    Ok((index, keys))
}
/// Builds an IVF index with `IVFBuildOptions::default()`; see `build_ivf_index`.
///
/// # Errors
///
/// Propagates any error from `build_ivf_index`.
#[instrument(skip(self))]
pub fn build_ivf_index_default(&self) -> Result<(IVFIndex, Vec<String>)> {
self.build_ivf_index(IVFBuildOptions::default())
}
#[instrument(skip(self, index, key_mapping, query))]
pub fn search_with_ivf(
&self,
index: &IVFIndex,
key_mapping: &[String],
query: &[f32],
top_k: usize,
) -> Result<Vec<SearchResult>> {
let deadline = Deadline::from_duration(self.config.search_timeout);
if query.is_empty() {
return Err(VectorError::EmptyVector);
}
if top_k == 0 {
return Err(VectorError::InvalidTopK);
}
let results = index.search(query, top_k);
if deadline.is_expired() {
return Err(VectorError::SearchTimeout {
operation: "search_with_ivf".to_string(),
timeout_ms: deadline.timeout_ms(),
});
}
Ok(results
.into_iter()
.filter_map(|(vector_id, distance)| {
key_mapping.get(vector_id).map(|key| SearchResult {
key: key.clone(),
score: 1.0 / (1.0 + distance),
})
})
.collect())
}
#[instrument(skip(self, index, key_mapping, query))]
pub fn search_with_ivf_nprobe(
&self,
index: &IVFIndex,
key_mapping: &[String],
query: &[f32],
top_k: usize,
nprobe: usize,
) -> Result<Vec<SearchResult>> {
let deadline = Deadline::from_duration(self.config.search_timeout);
if query.is_empty() {
return Err(VectorError::EmptyVector);
}
if top_k == 0 {
return Err(VectorError::InvalidTopK);
}
let results = index.search_with_nprobe(query, top_k, nprobe);
if deadline.is_expired() {
return Err(VectorError::SearchTimeout {
operation: "search_with_ivf_nprobe".to_string(),
timeout_ms: deadline.timeout_ms(),
});
}
Ok(results
.into_iter()
.filter_map(|(vector_id, distance)| {
key_mapping.get(vector_id).map(|key| SearchResult {
key: key.clone(),
score: 1.0 / (1.0 + distance),
})
})
.collect())
}
/// Estimates, in bytes, the memory an IVF index built with `options` over the
/// stored embeddings would need.
///
/// The estimate is the sum of:
/// * centroids: `num_clusters * dim * 4` (four bytes per `f32` component),
/// * vector payload, depending on `options.config.storage`:
///   * `Flat`: `count * dim * 4`,
///   * `PQ`: `count * num_subspaces` (presumably one code byte per subspace;
///     TODO confirm against the PQ implementation),
///   * `Binary`: `count * ceil(dim / 64) * 8` (dimensions rounded up to whole
///     64-bit words),
/// * inverted-list overhead: `count * 8`.
///
/// `dim` is taken from the first key returned by `list_keys`. Returns `Ok(0)`
/// when there are no embeddings, or when no key can be listed (for example
/// because `max_keys_per_scan` is zero) and the dimension is therefore unknown.
/// All arithmetic saturates instead of overflowing.
///
/// # Errors
///
/// Propagates the error from `get_embedding` if the first listed key cannot
/// be read.
pub fn estimate_ivf_memory(&self, options: &IVFBuildOptions) -> Result<usize> {
    let count = self.count();
    if count == 0 {
        return Ok(0);
    }
    let keys = self.list_keys();
    // `count()` is unbounded while `list_keys()` honours `max_keys_per_scan`, so
    // the list can be empty here; indexing `keys[0]` would panic.
    let Some(first_key) = keys.first() else {
        return Ok(0);
    };
    let dim = self.get_embedding(first_key)?.len();
    let num_clusters = options.config.num_clusters;
    let centroid_bytes = num_clusters.saturating_mul(dim).saturating_mul(4);
    let vector_bytes = match &options.config.storage {
        IVFStorage::Flat => count.saturating_mul(dim).saturating_mul(4),
        IVFStorage::PQ(pq_config) => count.saturating_mul(pq_config.num_subspaces),
        IVFStorage::Binary(_) => count.saturating_mul(dim.div_ceil(64)).saturating_mul(8),
    };
    let list_overhead = count.saturating_mul(8);
    Ok(centroid_bytes
        .saturating_add(vector_bytes)
        .saturating_add(list_overhead))
}
/// Stores a batch of embeddings and returns the keys that were stored.
///
/// All inputs are checked for empty vectors before anything is written. Once
/// the batch size reaches `config.batch_parallel_threshold`, every input is
/// stored through a rayon parallel iterator and the first failure by input
/// index is reported afterwards. Smaller batches are stored one by one and
/// stop at the first failure. In both modes, inputs stored before a failure is
/// reported stay stored.
///
/// # Errors
///
/// * `VectorError::BatchValidationError` with the index of the first input
///   whose vector is empty.
/// * `VectorError::BatchOperationError` with the index and message of the
///   failing `store_embedding` call.
#[instrument(skip(self, inputs), fields(count = inputs.len()))]
#[allow(clippy::needless_pass_by_value)]
pub fn batch_store_embeddings(&self, inputs: Vec<EmbeddingInput>) -> Result<BatchResult> {
    if inputs.is_empty() {
        return Ok(BatchResult::new(Vec::new()));
    }
    if let Some(index) = inputs.iter().position(|input| input.vector.is_empty()) {
        return Err(VectorError::BatchValidationError {
            index,
            cause: "Empty vector provided".to_string(),
        });
    }
    let stored_keys = if inputs.len() >= self.config.batch_parallel_threshold {
        // Run every store first, then surface the lowest-index failure.
        let outcomes: Vec<Result<String>> = inputs
            .par_iter()
            .enumerate()
            .map(|(index, input)| {
                match self.store_embedding(&input.key, input.vector.clone()) {
                    Ok(()) => Ok(input.key.clone()),
                    Err(err) => Err(VectorError::BatchOperationError {
                        index,
                        cause: err.to_string(),
                    }),
                }
            })
            .collect();
        outcomes.into_iter().collect::<Result<Vec<String>>>()?
    } else {
        // Lazy sequential pass: collecting into `Result` stops at the first error.
        inputs
            .iter()
            .enumerate()
            .map(|(index, input)| {
                match self.store_embedding(&input.key, input.vector.clone()) {
                    Ok(()) => Ok(input.key.clone()),
                    Err(err) => Err(VectorError::BatchOperationError {
                        index,
                        cause: err.to_string(),
                    }),
                }
            })
            .collect::<Result<Vec<String>>>()?
    };
    Ok(BatchResult::new(stored_keys))
}
/// Deletes every listed key that exists and returns how many were deleted.
///
/// Keys that do not exist, or whose delete fails, are skipped silently and not
/// counted. The write side of `delete_lock` is held for the whole batch.
///
/// If at least one key was deleted, the `_default` HNSW cache entry is
/// invalidated, as `delete_embedding` does. Otherwise `search_similar` could
/// keep returning keys that no longer exist.
///
/// # Errors
///
/// Currently always returns `Ok`; per-key failures are reflected only in the
/// count.
#[instrument(skip(self, keys), fields(count = keys.len()))]
pub fn batch_delete_embeddings(&self, keys: Vec<String>) -> Result<usize> {
    let _guard = self.delete_lock.write();
    let deleted = keys
        .into_iter()
        .filter(|key| {
            let storage_key = Self::embedding_key(key);
            self.store.exists(&storage_key) && self.store.delete(&storage_key).is_ok()
        })
        .count();
    if deleted > 0 {
        self.invalidate_hnsw_cache("_default");
    }
    Ok(deleted)
}
/// Returns one page of embedding keys with the embedding prefix stripped.
///
/// At most `skip + limit` keys are pulled from the scan, capped by
/// `config.max_keys_per_scan`. The page is what remains after skipping
/// `pagination.skip` keys and taking up to `pagination.limit`.
///
/// `has_more` is computed as follows:
/// * with `count_total` set, it compares `skip + page length` against the
///   total from `count()`;
/// * otherwise it is `true` only when a limit was given and the page is
///   exactly that long. An unlimited page never reports more results.
#[instrument(skip(self, pagination))]
pub fn list_keys_paginated(&self, pagination: Pagination) -> PagedResult<String> {
    let max_scan = self.config.max_keys_per_scan.unwrap_or(usize::MAX);
    let fetch_limit = pagination
        .skip
        .saturating_add(pagination.limit.unwrap_or(max_scan))
        .min(max_scan);
    let items: Vec<String> = self
        .store
        .scan(Self::embedding_prefix())
        .into_iter()
        .take(fetch_limit)
        .filter_map(|k| k.strip_prefix(Self::embedding_prefix()).map(String::from))
        .skip(pagination.skip)
        .take(pagination.limit.unwrap_or(usize::MAX))
        .collect();
    let total_count = if pagination.count_total {
        Some(self.count())
    } else {
        None
    };
    let has_more = match total_count {
        Some(total) => pagination.skip.saturating_add(items.len()) < total,
        // Without a limit the page already holds everything that was scanned.
        // The previous `limit.unwrap_or(0)` reported `has_more` for an empty
        // unlimited page.
        None => pagination.limit.is_some_and(|limit| items.len() == limit),
    };
    PagedResult::new(items, total_count, has_more)
}
/// Runs `search_similar` and returns one page of its results.
///
/// The search is asked for `min(skip + limit, top_k)` results (`limit` defaults
/// to `top_k`). The page is what remains after skipping `pagination.skip`
/// results and taking up to `pagination.limit`. With `count_total` set, the
/// reported total is the number of results the search returned. `has_more` is
/// only ever `true` when both a limit and a total are available.
///
/// # Errors
///
/// Propagates any error from `search_similar`.
#[instrument(skip(self, query, pagination), fields(query_dim = query.len(), top_k))]
pub fn search_similar_paginated(
    &self,
    query: &[f32],
    top_k: usize,
    pagination: Pagination,
) -> Result<PagedResult<SearchResult>> {
    let page_size = pagination.limit.unwrap_or(top_k);
    let fetch = pagination.skip.saturating_add(page_size).min(top_k);
    let results = self.search_similar(query, fetch)?;
    let total_count = pagination.count_total.then_some(results.len());
    let items: Vec<SearchResult> = results
        .into_iter()
        .skip(pagination.skip)
        .take(pagination.limit.unwrap_or(usize::MAX))
        .collect();
    let has_more = match total_count {
        Some(total) if pagination.limit.is_some() => {
            pagination.skip.saturating_add(items.len()) < total
        },
        _ => false,
    };
    Ok(PagedResult::new(items, total_count, has_more))
}
/// Runs `search_entities` and returns one page of its results.
///
/// The search is asked for `min(skip + limit, top_k)` results (`limit` defaults
/// to `top_k`). The page is what remains after skipping `pagination.skip`
/// results and taking up to `pagination.limit`. With `count_total` set, the
/// reported total is the number of results the search returned. `has_more` is
/// only ever `true` when both a limit and a total are available.
///
/// # Errors
///
/// Propagates any error from `search_entities`.
#[instrument(skip(self, query, pagination), fields(query_dim = query.len(), top_k))]
pub fn search_entities_paginated(
    &self,
    query: &[f32],
    top_k: usize,
    pagination: Pagination,
) -> Result<PagedResult<SearchResult>> {
    let page_size = pagination.limit.unwrap_or(top_k);
    let fetch = pagination.skip.saturating_add(page_size).min(top_k);
    let results = self.search_entities(query, fetch)?;
    let total_count = pagination.count_total.then_some(results.len());
    let items: Vec<SearchResult> = results
        .into_iter()
        .skip(pagination.skip)
        .take(pagination.limit.unwrap_or(usize::MAX))
        .collect();
    let has_more = match total_count {
        Some(total) if pagination.limit.is_some() => {
            pagination.skip.saturating_add(items.len()) < total
        },
        _ => false,
    };
    Ok(PagedResult::new(items, total_count, has_more))
}
/// Attaches `vector` to the entity stored under `entity_key`.
///
/// The existing record is loaded, or a fresh `TensorData` is used when the
/// lookup fails. Its `fields::EMBEDDING` field is set to the vector (sparse
/// when `should_use_sparse` returns `true`, dense otherwise) and the record is
/// written back under the same key.
///
/// # Errors
///
/// * `VectorError::EmptyVector` if `vector` has no elements.
/// * `VectorError::DimensionMismatch` if `vector` is longer than the configured
///   `max_dimension`.
/// * The converted store error if the underlying `put` fails.
#[instrument(skip(self, vector), fields(key = %entity_key, vector_dim = vector.len()))]
pub fn set_entity_embedding(&self, entity_key: &str, vector: Vec<f32>) -> Result<()> {
    let dim = vector.len();
    if dim == 0 {
        return Err(VectorError::EmptyVector);
    }
    match self.config.max_dimension {
        Some(limit) if dim > limit => {
            return Err(VectorError::DimensionMismatch {
                expected: limit,
                got: dim,
            });
        },
        _ => {},
    }
    let mut record = match self.store.get(entity_key) {
        Ok(existing) => existing,
        Err(_) => TensorData::new(),
    };
    let value = if self.should_use_sparse(&vector) {
        TensorValue::Sparse(SparseVector::from_dense(&vector))
    } else {
        TensorValue::Vector(vector)
    };
    record.set(fields::EMBEDDING, value);
    self.store.put(entity_key, record)?;
    Ok(())
}
pub fn get_entity_embedding(&self, entity_key: &str) -> Result<Vec<f32>> {
let tensor = self
.store
.get(entity_key)
.map_err(|_| VectorError::NotFound(entity_key.to_string()))?;
match tensor.get(fields::EMBEDDING) {
Some(TensorValue::Vector(v)) => Ok(v.clone()),
Some(TensorValue::Sparse(s)) => Ok(s.to_dense()),
_ => Err(VectorError::NotFound(entity_key.to_string())),
}
}
/// Reports whether the record under `entity_key` can be read and has a
/// `fields::EMBEDDING` field.
pub fn entity_has_embedding(&self, entity_key: &str) -> bool {
    match self.store.get(entity_key) {
        Ok(record) => record.has(fields::EMBEDDING),
        Err(_) => false,
    }
}
/// Removes the `fields::EMBEDDING` field from the entity under `entity_key`
/// and writes the record back, leaving its other fields untouched.
///
/// # Errors
///
/// * `VectorError::NotFound` if the record cannot be read or has no embedding
///   field.
/// * The converted store error if the underlying `put` fails.
pub fn remove_entity_embedding(&self, entity_key: &str) -> Result<()> {
    let missing = || VectorError::NotFound(entity_key.to_string());
    let mut record = self.store.get(entity_key).map_err(|_| missing())?;
    let had_embedding = record.remove(fields::EMBEDDING).is_some();
    if !had_embedding {
        return Err(missing());
    }
    self.store.put(entity_key, record)?;
    Ok(())
}
/// Returns up to `top_k` entities whose attached embedding is most similar to
/// `query`, best first.
///
/// Every key in the store is scanned (empty prefix), capped at
/// `config.max_keys_per_scan`. Records without a usable `fields::EMBEDDING`
/// field, or whose embedding length differs from `query`, are skipped.
/// Similarity is cosine, and result keys are the full store keys. A
/// zero-magnitude query yields an empty result.
///
/// # Errors
///
/// * `VectorError::EmptyVector` if `query` is empty.
/// * `VectorError::InvalidTopK` if `top_k` is zero.
/// * `VectorError::DimensionMismatch` if `query` is longer than the configured
///   `max_dimension`.
/// * `VectorError::SearchTimeout` if the deadline derived from
///   `config.search_timeout` has expired after the key scan or after scoring.
#[instrument(skip(self, query), fields(query_dim = query.len(), top_k = top_k))]
pub fn search_entities(&self, query: &[f32], top_k: usize) -> Result<Vec<SearchResult>> {
    use std::cmp::Ordering;

    let deadline = Deadline::from_duration(self.config.search_timeout);
    let timed_out = || VectorError::SearchTimeout {
        operation: "search_entities".to_string(),
        timeout_ms: deadline.timeout_ms(),
    };
    if query.is_empty() {
        return Err(VectorError::EmptyVector);
    }
    if top_k == 0 {
        return Err(VectorError::InvalidTopK);
    }
    match self.config.max_dimension {
        Some(limit) if query.len() > limit => {
            return Err(VectorError::DimensionMismatch {
                expected: limit,
                got: query.len(),
            });
        },
        _ => {},
    }
    let query_magnitude = Self::magnitude(query);
    if query_magnitude == 0.0 {
        return Ok(Vec::new());
    }
    let max_scan = self.config.max_keys_per_scan.unwrap_or(usize::MAX);
    let keys: Vec<_> = self.store.scan("").into_iter().take(max_scan).collect();
    if deadline.is_expired() {
        return Err(timed_out());
    }
    let mut ranked: Vec<SearchResult> = Vec::new();
    for key in &keys {
        let Ok(record) = self.store.get(key) else {
            continue;
        };
        let Some(candidate) = record.get(fields::EMBEDDING).and_then(Self::extract_vector)
        else {
            continue;
        };
        if candidate.len() != query.len() {
            continue;
        }
        let similarity = Self::cosine_similarity(query, &candidate, query_magnitude);
        ranked.push(SearchResult::new(key.clone(), similarity));
    }
    if deadline.is_expired() {
        return Err(timed_out());
    }
    ranked.sort_by(|lhs, rhs| rhs.score.partial_cmp(&lhs.score).unwrap_or(Ordering::Equal));
    ranked.truncate(top_k);
    Ok(ranked)
}
/// Lists the store keys whose record has a `fields::EMBEDDING` field.
///
/// Only the first `config.max_keys_per_scan` scanned keys are examined (no
/// limit when it is `None`).
pub fn scan_entities_with_embeddings(&self) -> Vec<String> {
    let cap = self.config.max_keys_per_scan.unwrap_or(usize::MAX);
    let mut entities = Vec::new();
    for key in self.store.scan("").into_iter().take(cap) {
        if self.entity_has_embedding(&key) {
            entities.push(key);
        }
    }
    entities
}
/// Returns the number of keys reported by `scan_entities_with_embeddings`,
/// so it is subject to the same scan limit.
pub fn count_entities_with_embeddings(&self) -> usize {
self.scan_entities_with_embeddings().len()
}
/// Prefix put in front of a metadata field name to form the tensor field it
/// is stored under; this keeps metadata apart from the `"vector"` field.
const METADATA_PREFIX: &'static str = "meta:";
/// Builds the tensor field name for metadata field `field` by prepending
/// `METADATA_PREFIX`.
fn metadata_field_key(field: &str) -> String {
    [Self::METADATA_PREFIX, field].concat()
}
/// Stores `vector` under `key` together with `metadata`.
///
/// A fresh `TensorData` is written, replacing whatever was stored under the
/// key. The vector goes into the `"vector"` field (sparse when
/// `should_use_sparse` returns `true`, dense otherwise). Each metadata entry
/// goes into a field named by `metadata_field_key`.
///
/// After a successful write the `_default` HNSW cache entry is invalidated, as
/// `store_embedding` does. Without this, `search_similar` could keep answering
/// from an index built before this vector was written.
///
/// # Errors
///
/// * `VectorError::EmptyVector` if `vector` has no elements.
/// * `VectorError::DimensionMismatch` if `vector` is longer than the configured
///   `max_dimension`.
/// * The converted store error if the underlying `put` fails.
#[instrument(skip(self, vector, metadata), fields(key = %key, vector_dim = vector.len()))]
pub fn store_embedding_with_metadata(
    &self,
    key: &str,
    vector: Vec<f32>,
    metadata: HashMap<String, TensorValue>,
) -> Result<()> {
    if vector.is_empty() {
        return Err(VectorError::EmptyVector);
    }
    if let Some(max_dim) = self.config.max_dimension {
        if vector.len() > max_dim {
            return Err(VectorError::DimensionMismatch {
                expected: max_dim,
                got: vector.len(),
            });
        }
    }
    let storage_key = Self::embedding_key(key);
    let mut tensor = TensorData::new();
    let storage = if self.should_use_sparse(&vector) {
        TensorValue::Sparse(SparseVector::from_dense(&vector))
    } else {
        TensorValue::Vector(vector)
    };
    tensor.set("vector", storage);
    for (field, value) in metadata {
        tensor.set(Self::metadata_field_key(&field), value);
    }
    self.store.put(storage_key, tensor)?;
    // The vector under this key may have changed; a cached index would be stale.
    self.invalidate_hnsw_cache("_default");
    Ok(())
}
/// Returns all metadata stored with the embedding under `key`.
///
/// Only tensor fields whose name starts with `METADATA_PREFIX` are returned,
/// with the prefix stripped from the map keys.
///
/// # Errors
///
/// Returns `VectorError::NotFound` when the store lookup fails.
#[instrument(skip(self), fields(key = %key))]
pub fn get_metadata(&self, key: &str) -> Result<HashMap<String, TensorValue>> {
    let record = self
        .store
        .get(&Self::embedding_key(key))
        .map_err(|_| VectorError::NotFound(key.to_string()))?;
    let metadata: HashMap<String, TensorValue> = record
        .fields_iter()
        .filter_map(|(field, value)| {
            field
                .strip_prefix(Self::METADATA_PREFIX)
                .map(|name| (name.to_string(), value.clone()))
        })
        .collect();
    Ok(metadata)
}
/// Sets or overwrites metadata fields on the embedding stored under `key`.
///
/// Fields not mentioned in `metadata` are left as they are; the vector itself
/// is not touched.
///
/// # Errors
///
/// * `VectorError::NotFound` if the record cannot be read or has no `"vector"`
///   field.
/// * The converted store error if the underlying `put` fails.
#[instrument(skip(self, metadata), fields(key = %key))]
pub fn update_metadata(&self, key: &str, metadata: HashMap<String, TensorValue>) -> Result<()> {
    let missing = || VectorError::NotFound(key.to_string());
    let storage_key = Self::embedding_key(key);
    let mut record = self.store.get(&storage_key).map_err(|_| missing())?;
    if record.get("vector").is_none() {
        return Err(missing());
    }
    for (name, value) in metadata {
        let field_key = Self::metadata_field_key(&name);
        record.set(field_key, value);
    }
    self.store.put(storage_key, record)?;
    Ok(())
}
/// Removes metadata field `field` from the embedding stored under `key` and
/// writes the record back.
///
/// Removing a field that is not present is not an error; the record is written
/// back either way.
///
/// # Errors
///
/// * `VectorError::NotFound` if the record cannot be read.
/// * The converted store error if the underlying `put` fails.
#[instrument(skip(self), fields(key = %key, field = %field))]
pub fn remove_metadata_field(&self, key: &str, field: &str) -> Result<()> {
    let storage_key = Self::embedding_key(key);
    let mut record = match self.store.get(&storage_key) {
        Ok(found) => found,
        Err(_) => return Err(VectorError::NotFound(key.to_string())),
    };
    let field_key = Self::metadata_field_key(field);
    record.remove(&field_key);
    self.store.put(storage_key, record)?;
    Ok(())
}
/// Reports whether the embedding under `key` can be read and carries metadata
/// field `field`.
pub fn has_metadata_field(&self, key: &str, field: &str) -> bool {
    match self.store.get(&Self::embedding_key(key)) {
        Ok(record) => record.has(&Self::metadata_field_key(field)),
        Err(_) => false,
    }
}
/// Returns a clone of metadata field `field` of the embedding under `key`, or
/// `Ok(None)` when the record exists but has no such field.
///
/// # Errors
///
/// Returns `VectorError::NotFound` when the store lookup fails.
pub fn get_metadata_field(&self, key: &str, field: &str) -> Result<Option<TensorValue>> {
    let record = match self.store.get(&Self::embedding_key(key)) {
        Ok(found) => found,
        Err(_) => return Err(VectorError::NotFound(key.to_string())),
    };
    let field_key = Self::metadata_field_key(field);
    Ok(record.get(&field_key).cloned())
}
/// Similarity search restricted to embeddings whose metadata satisfies `filter`.
///
/// `config` defaults to `FilteredSearchConfig::default()`. A strategy of
/// `FilterStrategy::Auto` is resolved through `choose_filter_strategy`. The
/// search then either filters first and scores the survivors (pre-filter) or
/// runs an oversampled `search_similar` and filters its results (post-filter).
///
/// # Errors
///
/// * `VectorError::EmptyVector` if `query` is empty.
/// * `VectorError::InvalidTopK` if `top_k` is zero.
/// * `VectorError::DimensionMismatch` if `query` is longer than the configured
///   `max_dimension`.
/// * `VectorError::SearchTimeout` if the deadline derived from
///   `config.search_timeout` has expired once the strategy has been chosen.
/// * Any error from `search_similar` on the post-filter path.
#[instrument(skip(self, query, filter, config), fields(query_dim = query.len(), top_k))]
pub fn search_similar_filtered(
    &self,
    query: &[f32],
    top_k: usize,
    filter: &FilterCondition,
    config: Option<FilteredSearchConfig>,
) -> Result<Vec<SearchResult>> {
    let deadline = Deadline::from_duration(self.config.search_timeout);
    if query.is_empty() {
        return Err(VectorError::EmptyVector);
    }
    if top_k == 0 {
        return Err(VectorError::InvalidTopK);
    }
    match self.config.max_dimension {
        Some(limit) if query.len() > limit => {
            return Err(VectorError::DimensionMismatch {
                expected: limit,
                got: query.len(),
            });
        },
        _ => {},
    }
    let config = config.unwrap_or_default();
    let strategy = if matches!(config.strategy, FilterStrategy::Auto) {
        self.choose_filter_strategy(filter, &config)
    } else {
        config.strategy
    };
    if deadline.is_expired() {
        return Err(VectorError::SearchTimeout {
            operation: "search_similar_filtered".to_string(),
            timeout_ms: deadline.timeout_ms(),
        });
    }
    match strategy {
        FilterStrategy::PostFilter => self.search_with_post_filter(query, top_k, filter, &config),
        // `Auto` is kept here only so the match stays exhaustive.
        FilterStrategy::PreFilter | FilterStrategy::Auto => {
            Ok(self.search_with_pre_filter(query, top_k, filter))
        },
    }
}
/// Picks a filter strategy by estimating how selective `filter` is.
///
/// The filter is evaluated on a sample of up to 100 keys from `list_keys`.
/// If the matching fraction is below `config.selectivity_threshold`, the
/// result is `PreFilter`; otherwise it is `PostFilter`. `FilterCondition::True`
/// and an empty sample both yield `PostFilter`.
///
/// The fraction is computed over the number of keys actually sampled. That
/// number can be smaller than `min(100, count())` when `max_keys_per_scan`
/// caps `list_keys`; dividing by the larger figure would make the filter look
/// more selective than it is.
fn choose_filter_strategy(
    &self,
    filter: &FilterCondition,
    config: &FilteredSearchConfig,
) -> FilterStrategy {
    if matches!(filter, FilterCondition::True) {
        return FilterStrategy::PostFilter;
    }
    let keys = self.list_keys();
    let sample_size = keys.len().min(100);
    if sample_size == 0 {
        return FilterStrategy::PostFilter;
    }
    let matches = keys
        .iter()
        .take(sample_size)
        .filter(|key| self.evaluate_filter_for_key(key, filter))
        .count();
    #[allow(clippy::cast_precision_loss)]
    let selectivity = matches as f32 / sample_size as f32;
    if selectivity < config.selectivity_threshold {
        FilterStrategy::PreFilter
    } else {
        FilterStrategy::PostFilter
    }
}
/// Filter-first search: keeps only the listed keys whose metadata satisfies
/// `filter`, then scores those with cosine similarity.
///
/// Keys whose embedding cannot be loaded or whose length differs from `query`
/// are skipped. The result is sorted best first and truncated to `top_k`. A
/// zero-magnitude query or an empty filtered set yields an empty result.
fn search_with_pre_filter(
    &self,
    query: &[f32],
    top_k: usize,
    filter: &FilterCondition,
) -> Vec<SearchResult> {
    use std::cmp::Ordering;

    let query_magnitude = Self::magnitude(query);
    if query_magnitude == 0.0 {
        return Vec::new();
    }
    // Phase 1: narrow the key list down to the keys that pass the filter.
    let mut matching_keys = self.list_keys();
    matching_keys.retain(|key| self.evaluate_filter_for_key(key, filter));
    if matching_keys.is_empty() {
        return Vec::new();
    }
    // Phase 2: score the survivors.
    let mut ranked: Vec<SearchResult> = Vec::new();
    for key in matching_keys {
        let Ok(vector) = self.get_embedding(&key) else {
            continue;
        };
        if vector.len() != query.len() {
            continue;
        }
        let similarity = Self::cosine_similarity(query, &vector, query_magnitude);
        ranked.push(SearchResult::new(key, similarity));
    }
    ranked.sort_by(|lhs, rhs| rhs.score.partial_cmp(&lhs.score).unwrap_or(Ordering::Equal));
    ranked.truncate(top_k);
    ranked
}
/// Search-first strategy: runs `search_similar` for
/// `max(top_k * oversample_factor, top_k)` candidates, then keeps the first
/// `top_k` whose metadata satisfies `filter`, in the order the search returned
/// them.
///
/// # Errors
///
/// Propagates any error from `search_similar`.
fn search_with_post_filter(
    &self,
    query: &[f32],
    top_k: usize,
    filter: &FilterCondition,
    config: &FilteredSearchConfig,
) -> Result<Vec<SearchResult>> {
    let oversample_k = top_k.saturating_mul(config.oversample_factor).max(top_k);
    let mut kept: Vec<SearchResult> = Vec::new();
    for candidate in self.search_similar(query, oversample_k)? {
        // Stop before touching the store again once the page is full.
        if kept.len() == top_k {
            break;
        }
        if self.evaluate_filter_for_key(&candidate.key, filter) {
            kept.push(candidate);
        }
    }
    Ok(kept)
}
/// Loads the embedding record for `key` and evaluates `filter` against it.
///
/// Returns `false` when the record cannot be read.
fn evaluate_filter_for_key(&self, key: &str, filter: &FilterCondition) -> bool {
    match self.store.get(&Self::embedding_key(key)) {
        Ok(record) => Self::evaluate_filter(&record, filter),
        Err(_) => false,
    }
}
/// Evaluates `filter` against the metadata fields of `tensor`.
///
/// `And` and `Or` recurse with short-circuiting. `Exists` checks for the
/// metadata field. The comparison variants go through `compare_field`, which
/// returns `false` when the field is missing or its type cannot be compared
/// with the filter value. This means `Ne` is also `false` for a missing field.
/// `In` is true when the field equals any of the listed values.
fn evaluate_filter(tensor: &TensorData, filter: &FilterCondition) -> bool {
    use std::cmp::Ordering;

    match filter {
        FilterCondition::True => true,
        FilterCondition::And(lhs, rhs) => {
            Self::evaluate_filter(tensor, lhs) && Self::evaluate_filter(tensor, rhs)
        },
        FilterCondition::Or(lhs, rhs) => {
            Self::evaluate_filter(tensor, lhs) || Self::evaluate_filter(tensor, rhs)
        },
        FilterCondition::Exists(field) => tensor.has(&Self::metadata_field_key(field)),
        FilterCondition::Eq(field, val) => {
            Self::compare_field(tensor, field, val, Ordering::is_eq)
        },
        FilterCondition::Ne(field, val) => {
            Self::compare_field(tensor, field, val, Ordering::is_ne)
        },
        FilterCondition::Lt(field, val) => {
            Self::compare_field(tensor, field, val, Ordering::is_lt)
        },
        FilterCondition::Le(field, val) => {
            Self::compare_field(tensor, field, val, Ordering::is_le)
        },
        FilterCondition::Gt(field, val) => {
            Self::compare_field(tensor, field, val, Ordering::is_gt)
        },
        FilterCondition::Ge(field, val) => {
            Self::compare_field(tensor, field, val, Ordering::is_ge)
        },
        FilterCondition::Contains(field, substr) => {
            Self::string_contains(tensor, field, substr)
        },
        FilterCondition::StartsWith(field, prefix) => {
            Self::string_starts_with(tensor, field, prefix)
        },
        FilterCondition::In(field, values) => values
            .iter()
            .any(|candidate| Self::compare_field(tensor, field, candidate, Ordering::is_eq)),
    }
}
/// Compares metadata field `field` of `tensor` with `val` and applies `cmp` to
/// the resulting ordering.
///
/// Returns `false` when the field is missing or when the stored value and
/// `val` cannot be ordered against each other.
fn compare_field<F>(tensor: &TensorData, field: &str, val: &FilterValue, cmp: F) -> bool
where
    F: Fn(std::cmp::Ordering) -> bool,
{
    tensor
        .get(&Self::metadata_field_key(field))
        .and_then(|stored| Self::compare_tensor_value_to_filter(stored, val))
        .is_some_and(cmp)
}
/// Orders a stored tensor value against a filter value.
///
/// Only scalar tensor values can be compared. Integers and floats may be
/// compared with each other, with the integer side cast to `f64`. Strings,
/// booleans and nulls compare only with their own kind. Every other
/// combination yields `None`, as does a float comparison for which
/// `partial_cmp` has no answer.
#[allow(clippy::cast_precision_loss)]
fn compare_tensor_value_to_filter(
    tensor_val: &TensorValue,
    filter_val: &FilterValue,
) -> Option<std::cmp::Ordering> {
    use tensor_store::ScalarValue;

    let TensorValue::Scalar(scalar) = tensor_val else {
        return None;
    };
    match (scalar, filter_val) {
        (ScalarValue::Int(stored), FilterValue::Int(wanted)) => Some(stored.cmp(wanted)),
        (ScalarValue::Float(stored), FilterValue::Float(wanted)) => stored.partial_cmp(wanted),
        (ScalarValue::Float(stored), FilterValue::Int(wanted)) => {
            stored.partial_cmp(&(*wanted as f64))
        },
        (ScalarValue::Int(stored), FilterValue::Float(wanted)) => {
            (*stored as f64).partial_cmp(wanted)
        },
        (ScalarValue::String(stored), FilterValue::String(wanted)) => Some(stored.cmp(wanted)),
        (ScalarValue::Bool(stored), FilterValue::Bool(wanted)) => Some(stored.cmp(wanted)),
        (ScalarValue::Null, FilterValue::Null) => Some(std::cmp::Ordering::Equal),
        _ => None,
    }
}
/// Returns `true` when metadata `field` holds a string scalar that contains
/// `substr`; a missing field or a non-string value yields `false`.
fn string_contains(tensor: &TensorData, field: &str, substr: &str) -> bool {
    use tensor_store::ScalarValue;

    let meta_key = Self::metadata_field_key(field);
    matches!(
        tensor.get(&meta_key),
        Some(TensorValue::Scalar(ScalarValue::String(text))) if text.contains(substr)
    )
}
/// Returns `true` when metadata `field` holds a string scalar that begins
/// with `prefix`; a missing field or a non-string value yields `false`.
fn string_starts_with(tensor: &TensorData, field: &str, prefix: &str) -> bool {
    use tensor_store::ScalarValue;

    let meta_key = Self::metadata_field_key(field);
    matches!(
        tensor.get(&meta_key),
        Some(TensorValue::Scalar(ScalarValue::String(text))) if text.starts_with(prefix)
    )
}
/// Estimates the fraction of stored embeddings that satisfy `filter`.
///
/// The estimate is computed over at most the first 100 keys returned by
/// `list_keys()` (not a random sample), so it is only an approximation.
/// Returns `0.0` when there are no keys to sample; otherwise a value in
/// `[0.0, 1.0]`.
// Sample sizes are at most 100, so the usize -> f32 casts are exact.
#[allow(clippy::cast_precision_loss)]
pub fn estimate_filter_selectivity(&self, filter: &FilterCondition) -> f32 {
    const SAMPLE_LIMIT: usize = 100;

    let keys = self.list_keys();
    // Divide by the number of keys actually examined. Deriving the
    // denominator from a separate `count()` call can disagree with this
    // key snapshot (e.g. concurrent writes) and skew the ratio.
    let sampled = keys.len().min(SAMPLE_LIMIT);
    if sampled == 0 {
        return 0.0;
    }

    let matches = keys
        .iter()
        .take(sampled)
        .filter(|key| self.evaluate_filter_for_key(key, filter))
        .count();
    matches as f32 / sampled as f32
}
/// Counts the keys returned by `list_keys()` whose stored entry satisfies
/// `filter`.
pub fn count_matching(&self, filter: &FilterCondition) -> usize {
    let mut matching = 0;
    for key in self.list_keys() {
        if self.evaluate_filter_for_key(&key, filter) {
            matching += 1;
        }
    }
    matching
}
/// Returns the keys from `list_keys()` whose stored entry satisfies `filter`,
/// in the same relative order as `list_keys()` produced them.
pub fn list_keys_matching(&self, filter: &FilterCondition) -> Vec<String> {
    let mut keys = self.list_keys();
    keys.retain(|key| self.evaluate_filter_for_key(key, filter));
    keys
}
/// Builds a [`PersistentVectorIndex`] snapshot of `collection`.
///
/// Scans the store for keys under the collection's embedding prefix (the
/// plain embedding prefix for the default collection). For each key it reads
/// the tensor, extracts the `"vector"` field and gathers every field carrying
/// the metadata prefix. Entries that cannot be read or that have no
/// extractable vector are skipped silently. If the collection has no stored
/// configuration, the default configuration is used.
#[instrument(skip(self), fields(collection = %collection))]
pub fn snapshot_collection(&self, collection: &str) -> PersistentVectorIndex {
    let config = self.get_collection_config(collection).unwrap_or_default();
    let mut snapshot = PersistentVectorIndex::new(collection.to_string(), config);

    let prefix = if collection == Self::DEFAULT_COLLECTION {
        Self::embedding_prefix().to_string()
    } else {
        Self::collection_embedding_prefix(collection)
    };

    for storage_key in self.store.scan(&prefix) {
        // The user-facing key is whatever follows the scanned prefix.
        let Some(key) = storage_key.strip_prefix(prefix.as_str()) else {
            continue;
        };
        let Ok(tensor) = self.store.get(&storage_key) else {
            continue;
        };
        let Some(vector) = tensor.get("vector").and_then(Self::extract_vector) else {
            continue;
        };

        // Keep only metadata-prefixed fields that convert to `MetadataValue`.
        let metadata: HashMap<String, MetadataValue> = tensor
            .fields_iter()
            .filter_map(|(field, value)| {
                let meta_key = field.strip_prefix(Self::METADATA_PREFIX)?;
                let converted = MetadataValue::from_tensor_value(value)?;
                Some((meta_key.to_string(), converted))
            })
            .collect();
        // An entry without metadata is recorded as `None`, not an empty map.
        let metadata = (!metadata.is_empty()).then_some(metadata);

        snapshot.push(key.to_string(), vector, metadata);
    }

    snapshot
}
/// Snapshots `collection` and writes it to `path` as pretty-printed JSON.
///
/// # Errors
///
/// Returns `VectorError::SerializationError` when JSON encoding fails, or the
/// converted I/O error when the file cannot be written.
#[instrument(skip(self, path), fields(collection = %collection))]
pub fn save_index(&self, collection: &str, path: impl AsRef<Path>) -> Result<()> {
    let snapshot = self.snapshot_collection(collection);
    let serialized = match serde_json::to_string_pretty(&snapshot) {
        Ok(text) => text,
        Err(err) => return Err(VectorError::SerializationError(err.to_string())),
    };
    Ok(fs::write(path, serialized)?)
}
/// Snapshots `collection` and writes it to `path` in `bitcode` binary form.
///
/// # Errors
///
/// Returns the converted I/O error when the file cannot be written.
#[instrument(skip(self, path), fields(collection = %collection))]
pub fn save_index_binary(&self, collection: &str, path: impl AsRef<Path>) -> Result<()> {
    let snapshot = self.snapshot_collection(collection);
    let encoded = bitcode::encode(&snapshot);
    Ok(fs::write(path, encoded)?)
}
/// Loads a JSON index file from `path` and restores its entries into the
/// engine, returning the name of the restored collection.
///
/// When `max_index_file_bytes` is configured, the file size is checked before
/// the file is read. When `max_index_entries` is configured, the entry count
/// is checked after parsing and before anything is restored.
///
/// # Errors
///
/// * `VectorError::ConfigurationError` when either configured limit is exceeded.
/// * `VectorError::SerializationError` when the JSON cannot be parsed.
/// * The converted I/O error when the file cannot be inspected or read.
/// * Any error produced while restoring the entries.
#[instrument(skip(self, path))]
pub fn load_index(&self, path: impl AsRef<Path>) -> Result<String> {
    let path = path.as_ref();

    if let Some(max_bytes) = self.config.max_index_file_bytes {
        let file_len = fs::metadata(path)?.len();
        if file_len > max_bytes as u64 {
            return Err(VectorError::ConfigurationError(format!(
                "index file size {} exceeds limit {}",
                file_len, max_bytes
            )));
        }
    }

    let contents = fs::read_to_string(path)?;
    let index = serde_json::from_str::<PersistentVectorIndex>(&contents)
        .map_err(|err| VectorError::SerializationError(err.to_string()))?;

    let entry_count = index.vectors.len();
    match self.config.max_index_entries {
        Some(max_entries) if entry_count > max_entries => {
            Err(VectorError::ConfigurationError(format!(
                "index entry count {} exceeds limit {}",
                entry_count, max_entries
            )))
        },
        _ => self.restore_from_index(index),
    }
}
/// Loads a `bitcode`-encoded index file from `path` and restores its entries
/// into the engine, returning the name of the restored collection.
///
/// When `max_index_file_bytes` is configured, the file size is checked before
/// the file is read. When `max_index_entries` is configured, the entry count
/// is checked after decoding and before anything is restored.
///
/// # Errors
///
/// * `VectorError::ConfigurationError` when either configured limit is exceeded.
/// * `VectorError::SerializationError` when the bytes cannot be decoded.
/// * The converted I/O error when the file cannot be inspected or read.
/// * Any error produced while restoring the entries.
#[instrument(skip(self, path))]
pub fn load_index_binary(&self, path: impl AsRef<Path>) -> Result<String> {
    let path = path.as_ref();

    if let Some(max_bytes) = self.config.max_index_file_bytes {
        let file_len = fs::metadata(path)?.len();
        if file_len > max_bytes as u64 {
            return Err(VectorError::ConfigurationError(format!(
                "index file size {} exceeds limit {}",
                file_len, max_bytes
            )));
        }
    }

    let raw = fs::read(path)?;
    let index: PersistentVectorIndex = match bitcode::decode(&raw) {
        Ok(decoded) => decoded,
        Err(err) => return Err(VectorError::SerializationError(err.to_string())),
    };

    let entry_count = index.vectors.len();
    match self.config.max_index_entries {
        Some(max_entries) if entry_count > max_entries => {
            Err(VectorError::ConfigurationError(format!(
                "index entry count {} exceeds limit {}",
                entry_count, max_entries
            )))
        },
        _ => self.restore_from_index(index),
    }
}
fn restore_from_index(&self, index: PersistentVectorIndex) -> Result<String> {
let collection = index.collection.clone();
if collection != Self::DEFAULT_COLLECTION {
let mut collections = self.collections.write();
collections.insert(collection.clone(), index.config.clone());
drop(collections);
}
for entry in index.vectors {
let metadata: HashMap<String, TensorValue> = entry
.metadata
.unwrap_or_default()
.into_iter()
.map(|(k, v)| (k, TensorValue::from(v)))
.collect();
if collection == Self::DEFAULT_COLLECTION {
self.store_embedding_with_metadata(&entry.key, entry.vector, metadata)?;
} else {
self.store_in_collection_with_metadata(
&collection,
&entry.key,
entry.vector,
metadata,
)?;
}
}
Ok(collection)
}
/// Saves every non-empty collection as a JSON index file inside `dir`,
/// creating the directory (and its parents) if needed.
///
/// The default collection is written to `default.json`; each listed
/// collection is written to `<collection>.json`. Collections with a count of
/// zero are skipped. Returns the names of the collections that were saved.
///
/// # Errors
///
/// Stops at, and returns, the first directory-creation or save error.
#[instrument(skip(self, dir))]
pub fn save_all_indices(&self, dir: impl AsRef<Path>) -> Result<Vec<String>> {
    let dir = dir.as_ref();
    fs::create_dir_all(dir)?;

    let mut saved = Vec::new();

    if self.count() > 0 {
        self.save_index(Self::DEFAULT_COLLECTION, dir.join("default.json"))?;
        saved.push(Self::DEFAULT_COLLECTION.to_string());
    }

    for collection in self.list_collections() {
        if self.collection_count(&collection) > 0 {
            let target = dir.join(format!("{collection}.json"));
            self.save_index(&collection, &target)?;
            saved.push(collection);
        }
    }

    Ok(saved)
}
/// Loads every `*.json` index file found directly inside `dir` and returns
/// the names of the collections that were restored.
///
/// Loading is best-effort per file: a file that fails to load is logged with
/// `tracing::warn!` and skipped, and the remaining files are still processed.
///
/// # Errors
///
/// Returns the converted I/O error when the directory cannot be read or a
/// directory entry cannot be inspected.
#[instrument(skip(self, dir))]
pub fn load_all_indices(&self, dir: impl AsRef<Path>) -> Result<Vec<String>> {
    let mut loaded = Vec::new();

    for entry in fs::read_dir(dir.as_ref())? {
        let path = entry?.path();
        let is_json = path.extension().and_then(|ext| ext.to_str()) == Some("json");
        if !is_json {
            continue;
        }

        match self.load_index(&path) {
            Ok(collection) => loaded.push(collection),
            Err(e) => {
                tracing::warn!("Failed to load index from {:?}: {}", path, e);
            },
        }
    }

    Ok(loaded)
}
}
impl Default for VectorEngine {
fn default() -> Self {
Self::new()
}
}
/// Inserts `vector` into `index` using the insertion method that corresponds
/// to `strategy`: `insert` for `Dense`, `insert_auto` for `Auto`, and
/// `insert_quantized` for `Quantized`. Whatever the insert call returns is
/// discarded.
fn insert_with_strategy(index: &HNSWIndex, vector: Vec<f32>, strategy: HNSWStorageStrategy) {
    match strategy {
        HNSWStorageStrategy::Quantized => {
            // The quantized path only borrows the vector.
            index.insert_quantized(&vector);
        },
        HNSWStorageStrategy::Auto => {
            index.insert_auto(vector);
        },
        HNSWStorageStrategy::Dense => {
            index.insert(vector);
        },
    }
}
#[cfg(test)]
mod tests {
use super::*;
use tensor_store::ScalarValue;
fn create_test_vector(dim: usize, seed: usize) -> Vec<f32> {
(0..dim)
.map(|i| {
let x = (seed * 31 + i * 17) as f32;
(x * 0.0001).sin() * ((seed + i) as f32 * 0.001)
})
.collect()
}
fn normalize(v: &[f32]) -> Vec<f32> {
let mag: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if mag == 0.0 {
v.to_vec()
} else {
v.iter().map(|x| x / mag).collect()
}
}
#[test]
fn store_and_retrieve_embedding() {
let engine = VectorEngine::new();
let vector = vec![1.0, 2.0, 3.0];
engine.store_embedding("test", vector.clone()).unwrap();
let retrieved = engine.get_embedding("test").unwrap();
assert_eq!(retrieved, vector);
}
#[test]
fn store_overwrites_existing() {
let engine = VectorEngine::new();
engine.store_embedding("key", vec![1.0, 2.0]).unwrap();
engine.store_embedding("key", vec![3.0, 4.0]).unwrap();
let retrieved = engine.get_embedding("key").unwrap();
assert_eq!(retrieved, vec![3.0, 4.0]);
}
#[test]
fn delete_embedding() {
let engine = VectorEngine::new();
engine.store_embedding("key", vec![1.0, 2.0]).unwrap();
assert!(engine.exists("key"));
engine.delete_embedding("key").unwrap();
assert!(!engine.exists("key"));
}
#[test]
fn delete_nonexistent_returns_error() {
let engine = VectorEngine::new();
let result = engine.delete_embedding("nonexistent");
assert!(matches!(result, Err(VectorError::NotFound(_))));
}
#[test]
fn get_nonexistent_returns_error() {
let engine = VectorEngine::new();
let result = engine.get_embedding("nonexistent");
assert!(matches!(result, Err(VectorError::NotFound(_))));
}
#[test]
fn empty_vector_returns_error() {
let engine = VectorEngine::new();
let result = engine.store_embedding("key", vec![]);
assert!(matches!(result, Err(VectorError::EmptyVector)));
}
#[test]
fn count_embeddings() {
let engine = VectorEngine::new();
assert_eq!(engine.count(), 0);
engine.store_embedding("a", vec![1.0]).unwrap();
engine.store_embedding("b", vec![2.0]).unwrap();
engine.store_embedding("c", vec![3.0]).unwrap();
assert_eq!(engine.count(), 3);
}
#[test]
fn search_similar_basic() {
let engine = VectorEngine::new();
engine.store_embedding("a", vec![1.0, 0.0, 0.0]).unwrap();
engine.store_embedding("b", vec![0.0, 1.0, 0.0]).unwrap();
engine.store_embedding("c", vec![1.0, 1.0, 0.0]).unwrap();
let results = engine.search_similar(&[1.0, 0.0, 0.0], 3).unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].key, "a");
assert!((results[0].score - 1.0).abs() < 1e-6);
}
#[test]
fn search_similar_top_k() {
let engine = VectorEngine::new();
for i in 0..10 {
engine
.store_embedding(&format!("v{}", i), vec![i as f32, 0.0])
.unwrap();
}
let results = engine.search_similar(&[5.0, 0.0], 3).unwrap();
assert_eq!(results.len(), 3);
}
#[test]
fn search_similar_fewer_than_k() {
let engine = VectorEngine::new();
engine.store_embedding("a", vec![1.0, 0.0]).unwrap();
engine.store_embedding("b", vec![0.0, 1.0]).unwrap();
let results = engine.search_similar(&[1.0, 0.0], 10).unwrap();
assert_eq!(results.len(), 2);
}
#[test]
fn search_similar_empty_query_error() {
let engine = VectorEngine::new();
let result = engine.search_similar(&[], 5);
assert!(matches!(result, Err(VectorError::EmptyVector)));
}
#[test]
fn search_similar_zero_top_k_error() {
let engine = VectorEngine::new();
let result = engine.search_similar(&[1.0, 0.0], 0);
assert!(matches!(result, Err(VectorError::InvalidTopK)));
}
#[test]
fn cosine_similarity_identical_vectors() {
let v = vec![1.0, 2.0, 3.0];
let score = VectorEngine::compute_similarity(&v, &v).unwrap();
assert!((score - 1.0).abs() < 1e-6);
}
#[test]
fn cosine_similarity_orthogonal_vectors() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
let score = VectorEngine::compute_similarity(&a, &b).unwrap();
assert!(score.abs() < 1e-6);
}
#[test]
fn cosine_similarity_opposite_vectors() {
let a = vec![1.0, 0.0];
let b = vec![-1.0, 0.0];
let score = VectorEngine::compute_similarity(&a, &b).unwrap();
assert!((score - (-1.0)).abs() < 1e-6);
}
#[test]
fn cosine_similarity_normalized_vectors() {
let a = normalize(&[1.0, 0.0]);
let b = normalize(&[1.0, 1.0]);
let score = VectorEngine::compute_similarity(&a, &b).unwrap();
let expected = (2.0_f32).sqrt() / 2.0;
assert!((score - expected).abs() < 1e-6);
}
#[test]
fn cosine_similarity_dimension_mismatch() {
let a = vec![1.0, 2.0];
let b = vec![1.0, 2.0, 3.0];
let result = VectorEngine::compute_similarity(&a, &b);
assert!(matches!(result, Err(VectorError::DimensionMismatch { .. })));
}
#[test]
fn cosine_similarity_zero_vector() {
let a = vec![0.0, 0.0];
let b = vec![1.0, 0.0];
let score = VectorEngine::compute_similarity(&a, &b).unwrap();
assert_eq!(score, 0.0);
}
#[test]
fn cosine_similarity_both_zero_vectors() {
let a = vec![0.0, 0.0];
let b = vec![0.0, 0.0];
let score = VectorEngine::compute_similarity(&a, &b).unwrap();
assert_eq!(score, 0.0);
assert!(!score.is_nan());
}
#[test]
fn search_skips_dimension_mismatch() {
let engine = VectorEngine::new();
engine.store_embedding("2d", vec![1.0, 0.0]).unwrap();
engine.store_embedding("3d", vec![1.0, 0.0, 0.0]).unwrap();
let results = engine.search_similar(&[1.0, 0.0], 10).unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "2d");
}
#[test]
fn store_10000_vectors_search() {
let engine = VectorEngine::new();
let dim = 128;
for i in 0..10000 {
let vector = create_test_vector(dim, i);
engine.store_embedding(&format!("v{}", i), vector).unwrap();
}
assert_eq!(engine.count(), 10000);
let query = create_test_vector(dim, 5000);
let results = engine.search_similar(&query, 5).unwrap();
assert_eq!(results.len(), 5);
assert_eq!(results[0].key, "v5000");
assert!((results[0].score - 1.0).abs() < 1e-5);
}
#[test]
fn high_dimensional_768() {
let engine = VectorEngine::new();
let dim = 768;
for i in 0..100 {
let vector = create_test_vector(dim, i);
engine.store_embedding(&format!("v{}", i), vector).unwrap();
}
let query = create_test_vector(dim, 50);
let results = engine.search_similar(&query, 3).unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].key, "v50");
}
#[test]
fn high_dimensional_1536() {
let engine = VectorEngine::new();
let dim = 1536;
for i in 0..100 {
let vector = create_test_vector(dim, i);
engine.store_embedding(&format!("v{}", i), vector).unwrap();
}
let query = create_test_vector(dim, 75);
let results = engine.search_similar(&query, 5).unwrap();
assert_eq!(results.len(), 5);
assert_eq!(results[0].key, "v75");
}
#[test]
fn similarity_scores_mathematically_correct() {
let engine = VectorEngine::new();
engine
.store_embedding("unit_x", normalize(&[1.0, 0.0, 0.0]))
.unwrap();
engine
.store_embedding("unit_y", normalize(&[0.0, 1.0, 0.0]))
.unwrap();
engine
.store_embedding("unit_z", normalize(&[0.0, 0.0, 1.0]))
.unwrap();
engine
.store_embedding("diag_xy", normalize(&[1.0, 1.0, 0.0]))
.unwrap();
engine
.store_embedding("neg_x", normalize(&[-1.0, 0.0, 0.0]))
.unwrap();
let query = normalize(&[1.0, 0.0, 0.0]);
let results = engine.search_similar(&query, 5).unwrap();
for result in &results {
match result.key.as_str() {
"unit_x" => assert!((result.score - 1.0).abs() < 1e-6), "unit_y" | "unit_z" => assert!(result.score.abs() < 1e-6), "diag_xy" => {
let expected = (2.0_f32).sqrt() / 2.0; assert!((result.score - expected).abs() < 1e-6);
},
"neg_x" => assert!((result.score - (-1.0)).abs() < 1e-6), _ => panic!("Unexpected key: {}", result.key),
}
}
}
#[test]
fn list_keys() {
let engine = VectorEngine::new();
engine.store_embedding("alpha", vec![1.0]).unwrap();
engine.store_embedding("beta", vec![2.0]).unwrap();
engine.store_embedding("gamma", vec![3.0]).unwrap();
let mut keys = engine.list_keys();
keys.sort();
assert_eq!(keys, vec!["alpha", "beta", "gamma"]);
}
#[test]
fn clear_all() {
let engine = VectorEngine::new();
engine.store_embedding("a", vec![1.0]).unwrap();
engine.store_embedding("b", vec![2.0]).unwrap();
assert_eq!(engine.count(), 2);
let cleared = engine.clear().unwrap();
assert_eq!(cleared, 2);
assert_eq!(engine.count(), 0);
}
#[test]
fn dimension() {
let engine = VectorEngine::new();
assert_eq!(engine.dimension(), None);
engine.store_embedding("test", vec![1.0, 2.0, 3.0]).unwrap();
assert_eq!(engine.dimension(), Some(3));
}
#[test]
fn default_trait() {
let engine = VectorEngine::default();
assert_eq!(engine.count(), 0);
}
#[test]
fn with_store_constructor() {
let store = TensorStore::new();
let engine = VectorEngine::with_store(store);
assert_eq!(engine.count(), 0);
}
#[test]
fn error_display() {
assert_eq!(
format!("{}", VectorError::NotFound("test".into())),
"Embedding not found: test"
);
assert_eq!(
format!(
"{}",
VectorError::DimensionMismatch {
expected: 3,
got: 5
}
),
"Dimension mismatch: expected 3, got 5"
);
assert_eq!(
format!("{}", VectorError::EmptyVector),
"Empty vector provided"
);
assert_eq!(
format!("{}", VectorError::InvalidTopK),
"Invalid top_k value (must be > 0)"
);
assert_eq!(
format!("{}", VectorError::StorageError("test".into())),
"Storage error: test"
);
}
#[test]
fn error_clone_and_eq() {
let e1 = VectorError::NotFound("test".into());
let e2 = e1.clone();
assert_eq!(e1, e2);
}
#[test]
fn error_is_error_trait() {
let error: Box<dyn std::error::Error> = Box::new(VectorError::EmptyVector);
assert!(error.to_string().contains("Empty"));
}
#[test]
fn search_result_clone_and_eq() {
let r1 = SearchResult::new("key".into(), 0.5);
let r2 = r1.clone();
assert_eq!(r1, r2);
}
#[test]
fn search_zero_query_vector() {
let engine = VectorEngine::new();
engine.store_embedding("a", vec![1.0, 0.0]).unwrap();
let results = engine.search_similar(&[0.0, 0.0], 5).unwrap();
assert!(results.is_empty());
}
#[test]
fn search_no_embeddings() {
let engine = VectorEngine::new();
let results = engine.search_similar(&[1.0, 0.0], 5).unwrap();
assert!(results.is_empty());
}
#[test]
fn exists_check() {
let engine = VectorEngine::new();
assert!(!engine.exists("key"));
engine.store_embedding("key", vec![1.0]).unwrap();
assert!(engine.exists("key"));
}
#[test]
fn hnsw_basic_insert_and_search() {
let index = HNSWIndex::new();
index.insert(vec![1.0, 0.0, 0.0]);
index.insert(vec![0.0, 1.0, 0.0]);
index.insert(vec![0.0, 0.0, 1.0]);
assert_eq!(index.len(), 3);
let results = index.search(&[1.0, 0.0, 0.0], 3);
assert_eq!(results.len(), 3);
assert_eq!(results[0].0, 0); assert!((results[0].1 - 1.0).abs() < 1e-6); }
#[test]
fn hnsw_empty_search() {
let index = HNSWIndex::new();
let results = index.search(&[1.0, 0.0], 5);
assert!(results.is_empty());
}
#[test]
fn hnsw_many_vectors() {
let index = HNSWIndex::new();
let dim = 64;
for i in 0..1000 {
let vector = create_test_vector(dim, i);
index.insert(vector);
}
assert_eq!(index.len(), 1000);
let query = create_test_vector(dim, 500);
let results = index.search(&query, 10);
assert_eq!(results.len(), 10);
let exact_match = results.iter().find(|(id, _)| *id == 500);
assert!(
exact_match.is_some(),
"Expected to find node 500 in top 10 results"
);
assert!(exact_match.unwrap().1 > 0.99);
}
#[test]
fn hnsw_high_recall() {
let config = HNSWConfig::high_recall();
let index = HNSWIndex::with_config(config);
let dim = 32;
for i in 0..100 {
let vector = create_test_vector(dim, i);
index.insert(vector);
}
let query = create_test_vector(dim, 42);
let results = index.search(&query, 1);
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, 42);
}
#[test]
fn hnsw_high_speed() {
let config = HNSWConfig::high_speed();
let index = HNSWIndex::with_config(config);
let dim = 32;
for i in 0..100 {
let vector = create_test_vector(dim, i);
index.insert(vector);
}
assert_eq!(index.len(), 100);
}
#[test]
fn hnsw_get_vector() {
let index = HNSWIndex::new();
let original = vec![1.0, 2.0, 3.0];
let id = index.insert(original.clone());
let retrieved = index.get_vector(id);
assert_eq!(retrieved, Some(original));
}
#[test]
fn hnsw_search_with_ef() {
let index = HNSWIndex::new();
let dim = 32;
for i in 0..100 {
index.insert(create_test_vector(dim, i));
}
let query = create_test_vector(dim, 50);
let results_low = index.search_with_ef(&query, 5, 10);
let results_high = index.search_with_ef(&query, 5, 200);
assert_eq!(results_low.len(), 5);
assert_eq!(results_high.len(), 5);
assert!(results_high.iter().any(|(id, _)| *id == 50));
}
#[test]
fn hnsw_default_trait() {
let index = HNSWIndex::default();
assert!(index.is_empty());
}
#[test]
fn hnsw_config_default() {
let config = HNSWConfig::default();
assert_eq!(config.m, 16);
assert_eq!(config.m0, 32);
assert_eq!(config.ef_construction, 200);
assert_eq!(config.ef_search, 50);
}
#[test]
fn engine_build_hnsw_index() {
let engine = VectorEngine::new();
let dim = 32;
for i in 0..50 {
let vector = create_test_vector(dim, i);
engine.store_embedding(&format!("v{}", i), vector).unwrap();
}
let (index, key_mapping) = engine.build_hnsw_index_default().unwrap();
assert_eq!(index.len(), 50);
assert_eq!(key_mapping.len(), 50);
}
#[test]
fn engine_search_with_hnsw() {
let engine = VectorEngine::new();
let dim = 32;
for i in 0..100 {
let vector = create_test_vector(dim, i);
engine
.store_embedding(&format!("vec{}", i), vector)
.unwrap();
}
let (index, key_mapping) = engine.build_hnsw_index_default().unwrap();
let query = create_test_vector(dim, 42);
let results = engine
.search_with_hnsw(&index, &key_mapping, &query, 5)
.unwrap();
assert_eq!(results.len(), 5);
assert!(results.iter().any(|r| r.key.contains("42")));
}
#[test]
fn engine_hnsw_empty_query_error() {
let engine = VectorEngine::new();
let index = HNSWIndex::new();
let key_mapping: Vec<String> = vec![];
let result = engine.search_with_hnsw(&index, &key_mapping, &[], 5);
assert!(matches!(result, Err(VectorError::EmptyVector)));
}
#[test]
fn engine_hnsw_zero_top_k_error() {
let engine = VectorEngine::new();
let index = HNSWIndex::new();
let key_mapping: Vec<String> = vec![];
let result = engine.search_with_hnsw(&index, &key_mapping, &[1.0], 0);
assert!(matches!(result, Err(VectorError::InvalidTopK)));
}
#[test]
fn entity_embedding_store_and_retrieve() {
let engine = VectorEngine::new();
let vector = vec![1.0, 2.0, 3.0];
engine
.set_entity_embedding("user:1", vector.clone())
.unwrap();
let retrieved = engine.get_entity_embedding("user:1").unwrap();
assert_eq!(retrieved, vector);
}
#[test]
fn entity_embedding_preserves_other_fields() {
let store = TensorStore::new();
let mut data = TensorData::new();
data.set(
"name",
TensorValue::Scalar(tensor_store::ScalarValue::String("Alice".into())),
);
store.put("user:1", data).unwrap();
let engine = VectorEngine::with_store(store);
engine
.set_entity_embedding("user:1", vec![0.1, 0.2])
.unwrap();
let tensor = engine.store.get("user:1").unwrap();
assert!(tensor.has("name"));
assert!(tensor.has(fields::EMBEDDING));
}
#[test]
fn entity_has_embedding_check() {
let engine = VectorEngine::new();
assert!(!engine.entity_has_embedding("user:1"));
engine
.set_entity_embedding("user:1", vec![1.0, 2.0])
.unwrap();
assert!(engine.entity_has_embedding("user:1"));
}
#[test]
fn entity_embedding_remove() {
let engine = VectorEngine::new();
engine
.set_entity_embedding("user:1", vec![1.0, 2.0])
.unwrap();
assert!(engine.entity_has_embedding("user:1"));
engine.remove_entity_embedding("user:1").unwrap();
assert!(!engine.entity_has_embedding("user:1"));
}
#[test]
fn entity_embedding_remove_nonexistent_error() {
let engine = VectorEngine::new();
let result = engine.remove_entity_embedding("user:999");
assert!(matches!(result, Err(VectorError::NotFound(_))));
}
#[test]
fn entity_embedding_get_nonexistent_error() {
let engine = VectorEngine::new();
let result = engine.get_entity_embedding("user:999");
assert!(matches!(result, Err(VectorError::NotFound(_))));
}
#[test]
fn entity_embedding_empty_vector_error() {
let engine = VectorEngine::new();
let result = engine.set_entity_embedding("user:1", vec![]);
assert!(matches!(result, Err(VectorError::EmptyVector)));
}
#[test]
fn search_entities_basic() {
let engine = VectorEngine::new();
engine
.set_entity_embedding("user:1", vec![1.0, 0.0, 0.0])
.unwrap();
engine
.set_entity_embedding("user:2", vec![0.0, 1.0, 0.0])
.unwrap();
engine
.set_entity_embedding("user:3", vec![1.0, 1.0, 0.0])
.unwrap();
let results = engine.search_entities(&[1.0, 0.0, 0.0], 3).unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].key, "user:1");
assert!((results[0].score - 1.0).abs() < 1e-6);
}
#[test]
fn search_entities_filters_non_embeddings() {
let store = TensorStore::new();
let mut user1 = TensorData::new();
user1.set(fields::EMBEDDING, TensorValue::Vector(vec![1.0, 0.0]));
store.put("user:1", user1).unwrap();
let mut user2 = TensorData::new();
user2.set(
"name",
TensorValue::Scalar(tensor_store::ScalarValue::String("Bob".into())),
);
store.put("user:2", user2).unwrap();
let engine = VectorEngine::with_store(store);
let results = engine.search_entities(&[1.0, 0.0], 10).unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "user:1");
}
#[test]
fn scan_entities_with_embeddings() {
let engine = VectorEngine::new();
engine
.set_entity_embedding("user:1", vec![1.0, 2.0])
.unwrap();
engine
.set_entity_embedding("user:2", vec![3.0, 4.0])
.unwrap();
let keys = engine.scan_entities_with_embeddings();
assert_eq!(keys.len(), 2);
}
#[test]
fn count_entities_with_embeddings() {
let engine = VectorEngine::new();
assert_eq!(engine.count_entities_with_embeddings(), 0);
engine.set_entity_embedding("user:1", vec![1.0]).unwrap();
engine.set_entity_embedding("user:2", vec![2.0]).unwrap();
assert_eq!(engine.count_entities_with_embeddings(), 2);
}
#[test]
fn search_entities_empty_query_error() {
let engine = VectorEngine::new();
let result = engine.search_entities(&[], 5);
assert!(matches!(result, Err(VectorError::EmptyVector)));
}
#[test]
fn search_entities_zero_top_k_error() {
let engine = VectorEngine::new();
let result = engine.search_entities(&[1.0], 0);
assert!(matches!(result, Err(VectorError::InvalidTopK)));
}
#[test]
fn search_entities_zero_query_returns_empty() {
let engine = VectorEngine::new();
engine
.set_entity_embedding("user:1", vec![1.0, 0.0])
.unwrap();
let results = engine.search_entities(&[0.0, 0.0], 5).unwrap();
assert!(results.is_empty());
}
#[test]
fn distance_metric_default() {
assert_eq!(DistanceMetric::default(), DistanceMetric::Cosine);
}
#[test]
fn search_with_metric_cosine() {
let engine = VectorEngine::new();
engine.store_embedding("a", vec![1.0, 0.0]).unwrap();
engine.store_embedding("b", vec![0.707, 0.707]).unwrap();
engine.store_embedding("c", vec![0.0, 1.0]).unwrap();
let results = engine
.search_similar_with_metric(&[1.0, 0.0], 3, DistanceMetric::Cosine)
.unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].key, "a");
assert!((results[0].score - 1.0).abs() < 0.01);
}
#[test]
fn search_with_metric_dot_product() {
let engine = VectorEngine::new();
engine.store_embedding("a", vec![1.0, 0.0]).unwrap();
engine.store_embedding("b", vec![2.0, 0.0]).unwrap();
engine.store_embedding("c", vec![0.5, 0.0]).unwrap();
let results = engine
.search_similar_with_metric(&[1.0, 0.0], 3, DistanceMetric::DotProduct)
.unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].key, "b");
assert!((results[0].score - 2.0).abs() < 0.01);
}
#[test]
fn search_with_metric_euclidean() {
let engine = VectorEngine::new();
engine.store_embedding("a", vec![1.0, 0.0]).unwrap();
engine.store_embedding("b", vec![2.0, 0.0]).unwrap();
engine.store_embedding("c", vec![10.0, 0.0]).unwrap();
let results = engine
.search_similar_with_metric(&[1.0, 0.0], 3, DistanceMetric::Euclidean)
.unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].key, "a");
assert!((results[0].score - 1.0).abs() < 0.01);
assert_eq!(results[1].key, "b");
assert!((results[1].score - 0.5).abs() < 0.01);
}
#[test]
fn search_with_metric_empty_query_error() {
let engine = VectorEngine::new();
let result = engine.search_similar_with_metric(&[], 5, DistanceMetric::Cosine);
assert!(matches!(result, Err(VectorError::EmptyVector)));
}
#[test]
fn search_with_metric_zero_top_k_error() {
let engine = VectorEngine::new();
let result = engine.search_similar_with_metric(&[1.0], 0, DistanceMetric::Cosine);
assert!(matches!(result, Err(VectorError::InvalidTopK)));
}
#[test]
fn search_with_metric_zero_query() {
let engine = VectorEngine::new();
engine.store_embedding("a", vec![1.0, 0.0]).unwrap();
let results = engine
.search_similar_with_metric(&[0.0, 0.0], 5, DistanceMetric::Cosine)
.unwrap();
assert!(results.is_empty());
}
#[test]
fn search_with_metric_zero_query_euclidean() {
let engine = VectorEngine::new();
engine.store_embedding("origin", vec![0.0, 0.0]).unwrap();
engine.store_embedding("unit", vec![1.0, 0.0]).unwrap();
engine.store_embedding("far", vec![10.0, 0.0]).unwrap();
let results = engine
.search_similar_with_metric(&[0.0, 0.0], 3, DistanceMetric::Euclidean)
.unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].key, "origin");
assert!((results[0].score - 1.0).abs() < 0.01);
assert_eq!(results[1].key, "unit");
assert!((results[1].score - 0.5).abs() < 0.01);
}
#[test]
fn euclidean_distance_identical() {
let a = vec![1.0, 2.0, 3.0];
let dist = VectorEngine::euclidean_distance(&a, &a);
assert!(dist.abs() < 1e-6);
}
#[test]
fn euclidean_distance_unit() {
let a = vec![0.0, 0.0];
let b = vec![1.0, 0.0];
let dist = VectorEngine::euclidean_distance(&a, &b);
assert!((dist - 1.0).abs() < 1e-6);
}
#[test]
fn euclidean_distance_pythagoras() {
let a = vec![0.0, 0.0];
let b = vec![3.0, 4.0];
let dist = VectorEngine::euclidean_distance(&a, &b);
assert!((dist - 5.0).abs() < 1e-6);
}
#[test]
fn sparse_vector_storage_and_retrieval() {
let engine = VectorEngine::new();
let mut sparse = vec![0.0f32; 100];
sparse[0] = 1.0;
sparse[50] = 2.0;
sparse[99] = 3.0;
engine.store_embedding("sparse", sparse.clone()).unwrap();
let retrieved = engine.get_embedding("sparse").unwrap();
assert_eq!(retrieved.len(), sparse.len());
assert_eq!(retrieved[0], 1.0);
assert_eq!(retrieved[50], 2.0);
assert_eq!(retrieved[99], 3.0);
}
#[test]
fn sparse_vector_search() {
let engine = VectorEngine::new();
let mut v1 = vec![0.0f32; 100];
v1[0] = 1.0;
v1[1] = 0.0;
v1[2] = 0.0;
let mut v2 = vec![0.0f32; 100];
v2[0] = 0.707;
v2[1] = 0.707;
let mut v3 = vec![0.0f32; 100];
v3[0] = 0.0;
v3[1] = 1.0;
engine.store_embedding("v1", v1).unwrap();
engine.store_embedding("v2", v2).unwrap();
engine.store_embedding("v3", v3).unwrap();
let mut query = vec![0.0f32; 100];
query[0] = 1.0;
let results = engine.search_similar(&query, 3).unwrap();
assert_eq!(results.len(), 3);
assert_eq!(results[0].key, "v1");
assert!((results[0].score - 1.0).abs() < 0.01);
}
#[test]
fn sparse_entity_embedding() {
let engine = VectorEngine::new();
let mut sparse = vec![0.0f32; 100];
sparse[10] = 5.0;
sparse[20] = -3.0;
engine
.set_entity_embedding("entity:1", sparse.clone())
.unwrap();
let retrieved = engine.get_entity_embedding("entity:1").unwrap();
assert_eq!(retrieved.len(), 100);
assert_eq!(retrieved[10], 5.0);
assert_eq!(retrieved[20], -3.0);
assert_eq!(retrieved[0], 0.0);
}
#[test]
fn sparse_detection_threshold() {
let half_sparse: Vec<f32> = (0..100).map(|i| if i < 50 { 0.0 } else { 1.0 }).collect();
assert!(VectorEngine::should_use_sparse_with_threshold(
&half_sparse,
0.5
));
let mostly_dense: Vec<f32> = (0..100).map(|i| if i < 40 { 0.0 } else { 1.0 }).collect();
assert!(!VectorEngine::should_use_sparse_with_threshold(
&mostly_dense,
0.5
));
let very_sparse: Vec<f32> = (0..100).map(|i| if i < 3 { 1.0 } else { 0.0 }).collect();
assert!(VectorEngine::should_use_sparse_with_threshold(
&very_sparse,
0.5
));
}
#[test]
fn sparse_search_with_metric() {
let engine = VectorEngine::new();
let mut v1 = vec![0.0f32; 100];
v1[0] = 1.0;
let mut v2 = vec![0.0f32; 100];
v2[0] = 2.0;
engine.store_embedding("v1", v1).unwrap();
engine.store_embedding("v2", v2).unwrap();
let mut query = vec![0.0f32; 100];
query[0] = 1.0;
let results = engine
.search_similar_with_metric(&query, 2, DistanceMetric::Euclidean)
.unwrap();
assert_eq!(results.len(), 2);
assert_eq!(results[0].key, "v1");
}
/// Entity search over two one-hot 100-dimensional embeddings: the entity that
/// shares the query's hot slot ranks first with a score close to 1.0.
#[test]
fn search_entities_with_sparse() {
    // Builds a 100-element vector with a single 1.0 at `slot`.
    let one_hot = |slot: usize| {
        let mut v = vec![0.0f32; 100];
        v[slot] = 1.0;
        v
    };

    let engine = VectorEngine::new();
    engine.set_entity_embedding("user:1", one_hot(0)).unwrap();
    engine.set_entity_embedding("user:2", one_hot(1)).unwrap();

    let query = one_hot(0);
    let hits = engine.search_entities(&query, 2).unwrap();
    assert_eq!(hits.len(), 2);
    assert_eq!(hits[0].key, "user:1");
    assert!((hits[0].score - 1.0).abs() < 0.01);
}
/// Pins the field values produced by `VectorEngineConfig::default()`.
#[test]
fn config_default() {
    let defaults = VectorEngineConfig::default();
    assert_eq!(defaults.default_dimension, None);
    let threshold_delta = (defaults.sparse_threshold - 0.5).abs();
    assert!(threshold_delta < 1e-6);
    assert_eq!(defaults.parallel_threshold, 5000);
    assert_eq!(defaults.default_metric, DistanceMetric::Cosine);
}
/// The high-throughput preset uses a parallel threshold of 1000.
#[test]
fn config_high_throughput() {
    let preset = VectorEngineConfig::high_throughput();
    assert_eq!(preset.parallel_threshold, 1000);
}
/// The low-memory preset uses a sparse threshold of 0.3.
#[test]
fn config_low_memory() {
    let preset = VectorEngineConfig::low_memory();
    let delta = (preset.sparse_threshold - 0.3).abs();
    assert!(delta < 1e-6);
}
/// The default configuration passes validation.
#[test]
fn config_validate_valid() {
    let defaults = VectorEngineConfig::default();
    let outcome = defaults.validate();
    assert!(outcome.is_ok());
}
/// A sparse threshold of 1.5 is rejected with a `ConfigurationError`.
#[test]
fn config_validate_invalid_sparse_threshold() {
    let too_high = VectorEngineConfig {
        sparse_threshold: 1.5,
        ..Default::default()
    };
    let outcome = too_high.validate();
    assert!(matches!(outcome, Err(VectorError::ConfigurationError(_))));
}
/// A parallel threshold of zero is rejected with a `ConfigurationError`.
#[test]
fn config_validate_invalid_parallel_threshold() {
    let zero_threshold = VectorEngineConfig {
        parallel_threshold: 0,
        ..Default::default()
    };
    let outcome = zero_threshold.validate();
    assert!(matches!(outcome, Err(VectorError::ConfigurationError(_))));
}
/// An engine built with the high-throughput preset exposes that preset's
/// parallel threshold through `config()`.
#[test]
fn engine_with_config() {
    let engine = VectorEngine::with_config(VectorEngineConfig::high_throughput()).unwrap();
    assert_eq!(engine.config().parallel_threshold, 1000);
}
/// An engine built from an explicit store plus the low-memory preset exposes
/// that preset's sparse threshold through `config()`.
#[test]
fn engine_with_store_and_config() {
    let backing_store = TensorStore::new();
    let preset = VectorEngineConfig::low_memory();
    let engine = VectorEngine::with_store_and_config(backing_store, preset).unwrap();
    let delta = (engine.config().sparse_threshold - 0.3).abs();
    assert!(delta < 1e-6);
}
/// Storing a batch of three embeddings reports all three keys, in input
/// order, and raises the engine's count to three.
#[test]
fn batch_store_embeddings_basic() {
    let engine = VectorEngine::new();
    let batch = vec![
        EmbeddingInput::new("a", vec![1.0, 0.0]),
        EmbeddingInput::new("b", vec![0.0, 1.0]),
        EmbeddingInput::new("c", vec![1.0, 1.0]),
    ];

    let outcome = engine.batch_store_embeddings(batch).unwrap();

    assert_eq!(outcome.count, 3);
    assert_eq!(outcome.stored_keys, vec!["a", "b", "c"]);
    assert_eq!(engine.count(), 3);
}
/// An empty batch succeeds and reports nothing stored.
#[test]
fn batch_store_embeddings_empty() {
    let engine = VectorEngine::new();
    let outcome = engine.batch_store_embeddings(vec![]).unwrap();
    assert_eq!(outcome.count, 0);
    assert!(outcome.stored_keys.is_empty());
}
/// A batch whose second input has an empty vector fails with a
/// `BatchValidationError` that names index 1.
#[test]
fn batch_store_embeddings_validation_error() {
    let engine = VectorEngine::new();
    let batch = vec![
        EmbeddingInput::new("a", vec![1.0, 0.0]),
        // The empty vector at position 1 is the entry expected to be flagged.
        EmbeddingInput::new("b", vec![]),
    ];

    let outcome = engine.batch_store_embeddings(batch);

    assert!(matches!(
        outcome,
        Err(VectorError::BatchValidationError { index: 1, .. })
    ));
}
/// Deleting two of three stored keys in one batch reports two deletions and
/// leaves only the third key behind.
#[test]
fn batch_delete_embeddings_basic() {
    let engine = VectorEngine::new();
    for (key, value) in [("a", 1.0), ("b", 2.0), ("c", 3.0)] {
        engine.store_embedding(key, vec![value]).unwrap();
    }

    let doomed = vec![String::from("a"), String::from("b")];
    let removed = engine.batch_delete_embeddings(doomed).unwrap();

    assert_eq!(removed, 2);
    assert_eq!(engine.count(), 1);
    assert!(engine.exists("c"));
}
/// Deleting an empty key list succeeds and reports zero deletions.
#[test]
fn batch_delete_embeddings_empty() {
    let engine = VectorEngine::new();
    let removed = engine.batch_delete_embeddings(vec![]).unwrap();
    assert_eq!(removed, 0);
}
/// A batch delete that mixes an existing key with a missing one succeeds and
/// counts only the key that was actually present.
#[test]
fn batch_delete_embeddings_nonexistent() {
    let engine = VectorEngine::new();
    engine.store_embedding("a", vec![1.0]).unwrap();

    let requested = vec![String::from("a"), String::from("nonexistent")];
    let removed = engine.batch_delete_embeddings(requested).unwrap();

    assert_eq!(removed, 1);
}
/// `EmbeddingInput::new` stores the key and vector it was given.
#[test]
fn embedding_input_new() {
    let built = EmbeddingInput::new("test", vec![1.0, 2.0]);
    assert_eq!(built.key, "test");
    assert_eq!(built.vector, vec![1.0, 2.0]);
}
/// `BatchResult::new` keeps the keys and sets `count` to their number.
#[test]
fn batch_result_new() {
    let keys = vec![String::from("a"), String::from("b")];
    let built = BatchResult::new(keys);
    assert_eq!(built.count, 2);
    assert_eq!(built.stored_keys, vec!["a", "b"]);
}
/// `Pagination::new(skip, limit)` records both values and leaves total
/// counting switched off.
#[test]
fn pagination_new() {
    let page = Pagination::new(10, 20);
    assert_eq!(page.skip, 10);
    assert_eq!(page.limit, Some(20));
    assert!(!page.count_total);
}
/// `with_total()` switches total counting on.
#[test]
fn pagination_with_total() {
    let page = Pagination::new(0, 10).with_total();
    assert!(page.count_total);
}
/// `Pagination::skip_only` sets the skip and leaves the limit unset.
#[test]
fn pagination_skip_only() {
    let page = Pagination::skip_only(5);
    assert_eq!(page.skip, 5);
    assert_eq!(page.limit, None);
}
/// Listing the first 3 of 10 keys without requesting a total returns 3 items,
/// flags that more remain, and leaves `total_count` unset.
#[test]
fn list_keys_paginated_basic() {
    let engine = VectorEngine::new();
    (0..10).for_each(|i| {
        engine
            .store_embedding(&format!("v{i:02}"), vec![i as f32])
            .unwrap();
    });

    let page = engine.list_keys_paginated(Pagination::new(0, 3));

    assert_eq!(page.items.len(), 3);
    assert!(page.has_more);
    assert_eq!(page.total_count, None);
}
/// Listing the first 5 of 10 keys with `with_total()` returns 5 items, a total
/// of 10, and flags that more remain.
#[test]
fn list_keys_paginated_with_total() {
    let engine = VectorEngine::new();
    (0..10).for_each(|i| {
        engine
            .store_embedding(&format!("v{i:02}"), vec![i as f32])
            .unwrap();
    });

    let request = Pagination::new(0, 5).with_total();
    let page = engine.list_keys_paginated(request);

    assert_eq!(page.items.len(), 5);
    assert_eq!(page.total_count, Some(10));
    assert!(page.has_more);
}
/// Skipping 8 of 10 keys with a limit of 5 returns the remaining 2 items and
/// reports that nothing further is available.
#[test]
fn list_keys_paginated_skip() {
    let engine = VectorEngine::new();
    (0..10).for_each(|i| {
        engine
            .store_embedding(&format!("v{i:02}"), vec![i as f32])
            .unwrap();
    });

    let request = Pagination::new(8, 5).with_total();
    let page = engine.list_keys_paginated(request);

    assert_eq!(page.items.len(), 2);
    assert!(!page.has_more);
}
/// A paginated similarity search with `top_k` 10 and a page limit of 3 returns
/// exactly 3 items.
#[test]
fn search_similar_paginated_basic() {
    let engine = VectorEngine::new();
    (0..10).for_each(|i| {
        engine
            .store_embedding(&format!("v{i}"), vec![i as f32, 0.0])
            .unwrap();
    });

    let request = Pagination::new(0, 3).with_total();
    let page = engine
        .search_similar_paginated(&[5.0, 0.0], 10, request)
        .unwrap();

    assert_eq!(page.items.len(), 3);
}
/// A paginated entity search with `top_k` 5 and a page limit of 2 returns
/// exactly 2 items.
#[test]
fn search_entities_paginated_basic() {
    let engine = VectorEngine::new();
    (0..5).for_each(|i| {
        engine
            .set_entity_embedding(&format!("user:{i}"), vec![i as f32, 0.0])
            .unwrap();
    });

    let request = Pagination::new(0, 2).with_total();
    let page = engine
        .search_entities_paginated(&[2.0, 0.0], 5, request)
        .unwrap();

    assert_eq!(page.items.len(), 2);
}
/// `PagedResult::empty()` has no items, a total of zero, and no further pages.
#[test]
fn paged_result_empty() {
    let nothing = PagedResult::<String>::empty();
    assert!(nothing.items.is_empty());
    assert_eq!(nothing.total_count, Some(0));
    assert!(!nothing.has_more);
}
/// Pins the `Display` text of `VectorError::BatchValidationError`.
#[test]
fn error_batch_validation_display() {
    let rendered = VectorError::BatchValidationError {
        index: 5,
        cause: "test error".to_string(),
    }
    .to_string();
    assert_eq!(rendered, "Batch validation error at index 5: test error");
}
/// Pins the `Display` text of `VectorError::BatchOperationError`.
#[test]
fn error_batch_operation_display() {
    let rendered = VectorError::BatchOperationError {
        index: 3,
        cause: "op failed".to_string(),
    }
    .to_string();
    assert_eq!(rendered, "Batch operation error at index 3: op failed");
}
/// Pins the `Display` text of `VectorError::ConfigurationError`.
#[test]
fn error_configuration_display() {
    let rendered = VectorError::ConfigurationError("bad config".to_string()).to_string();
    assert_eq!(rendered, "Configuration error: bad config");
}
/// Cloned `VectorError` values compare equal to their originals, for both a
/// struct-like and a tuple-like variant.
#[test]
fn error_variants_clone_and_eq() {
    let validation = VectorError::BatchValidationError {
        index: 1,
        cause: "test".to_string(),
    };
    let validation_copy = validation.clone();
    assert_eq!(validation, validation_copy);

    let configuration = VectorError::ConfigurationError("test".to_string());
    let configuration_copy = configuration.clone();
    assert_eq!(configuration, configuration_copy);
}
/// Builds a default HNSW index over 50 generated 32-dimensional vectors and
/// checks that a cosine search for vector 25 returns 5 hits including `v25`.
#[test]
fn search_with_hnsw_and_metric_basic() {
    let engine = VectorEngine::new();
    let dim = 32;
    for i in 0..50 {
        let key = format!("v{i}");
        engine
            .store_embedding(&key, create_test_vector(dim, i))
            .unwrap();
    }

    let (index, key_mapping) = engine.build_hnsw_index_default().unwrap();
    let query = create_test_vector(dim, 25);
    let metric = ExtendedDistanceMetric::Cosine;
    let hits = engine
        .search_with_hnsw_and_metric(&index, &key_mapping, &query, 5, &metric)
        .unwrap();

    assert_eq!(hits.len(), 5);
    assert!(hits.iter().any(|hit| hit.key == "v25"));
}
/// HNSW search with the Euclidean metric ranks the stored vector identical to
/// the query first.
#[test]
fn search_with_hnsw_and_metric_euclidean() {
    let engine = VectorEngine::new();
    for (key, x) in [("a", 1.0), ("b", 2.0), ("c", 10.0)] {
        engine.store_embedding(key, vec![x, 0.0]).unwrap();
    }

    let (index, key_mapping) = engine.build_hnsw_index_default().unwrap();
    let metric = ExtendedDistanceMetric::Euclidean;
    let hits = engine
        .search_with_hnsw_and_metric(&index, &key_mapping, &[1.0, 0.0], 3, &metric)
        .unwrap();

    assert_eq!(hits.len(), 3);
    assert_eq!(hits[0].key, "a");
}
/// An empty query vector is rejected with `VectorError::EmptyVector`.
#[test]
fn search_with_hnsw_and_metric_empty_query() {
    let engine = VectorEngine::new();
    let index = HNSWIndex::new();
    let key_mapping = Vec::<String>::new();
    let metric = ExtendedDistanceMetric::Cosine;

    let outcome = engine.search_with_hnsw_and_metric(&index, &key_mapping, &[], 5, &metric);

    assert!(matches!(outcome, Err(VectorError::EmptyVector)));
}
/// A `top_k` of zero is rejected with `VectorError::InvalidTopK`.
#[test]
fn search_with_hnsw_and_metric_zero_top_k() {
    let engine = VectorEngine::new();
    let index = HNSWIndex::new();
    let key_mapping = Vec::<String>::new();
    let metric = ExtendedDistanceMetric::Cosine;

    let outcome = engine.search_with_hnsw_and_metric(&index, &key_mapping, &[1.0], 0, &metric);

    assert!(matches!(outcome, Err(VectorError::InvalidTopK)));
}
/// Ten threads store different vectors under the same key; every store must
/// succeed and the key must exist once all threads have finished.
#[test]
fn test_concurrent_store_embedding_same_key() {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };
    use std::thread;

    let engine = Arc::new(VectorEngine::new());
    let ok_count = Arc::new(AtomicUsize::new(0));

    let mut workers = Vec::with_capacity(10);
    for i in 0..10 {
        let eng = Arc::clone(&engine);
        let ok = Arc::clone(&ok_count);
        workers.push(thread::spawn(move || {
            let vector = vec![i as f32, i as f32];
            if eng.store_embedding("contested", vector).is_ok() {
                ok.fetch_add(1, Ordering::SeqCst);
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    assert_eq!(ok_count.load(Ordering::SeqCst), 10);
    assert!(engine.exists("contested"));
}
/// Ten threads race to delete the same key: exactly one delete must succeed
/// and the other nine must observe `NotFound`. Any other error panics.
#[test]
fn test_concurrent_delete_embedding_same_key() {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };
    use std::thread;

    let engine = Arc::new(VectorEngine::new());
    engine.store_embedding("to_delete", vec![1.0, 2.0]).unwrap();

    let deleted = Arc::new(AtomicUsize::new(0));
    let missing = Arc::new(AtomicUsize::new(0));

    let mut workers = Vec::with_capacity(10);
    for _ in 0..10 {
        let eng = Arc::clone(&engine);
        let deleted = Arc::clone(&deleted);
        let missing = Arc::clone(&missing);
        workers.push(thread::spawn(move || {
            match eng.delete_embedding("to_delete") {
                Ok(()) => {
                    deleted.fetch_add(1, Ordering::SeqCst);
                },
                Err(VectorError::NotFound(_)) => {
                    missing.fetch_add(1, Ordering::SeqCst);
                },
                Err(err) => panic!("unexpected error: {err:?}"),
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    assert_eq!(deleted.load(Ordering::SeqCst), 1);
    assert_eq!(missing.load(Ordering::SeqCst), 9);
}
/// Runs 20 threads that each issue 5 similarity searches against a shared
/// engine holding 100 embeddings, counting searches that return at least one
/// result.
///
/// The test passes when more than 50 of the 100 searches succeed. The failure
/// message now states that bound exactly; it previously said "at least 50",
/// which contradicted the `> 50` condition when the count was exactly 50.
#[test]
fn test_concurrent_search_similar() {
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };
    use std::thread;

    let engine = Arc::new(VectorEngine::new());
    for i in 0..100 {
        engine
            .store_embedding(&format!("v{}", i), vec![i as f32, 0.0])
            .unwrap();
    }
    assert_eq!(engine.count(), 100);

    let success_count = Arc::new(AtomicUsize::new(0));
    let handles: Vec<_> = (0..20)
        .map(|t| {
            let eng = Arc::clone(&engine);
            let counter = Arc::clone(&success_count);
            thread::spawn(move || {
                for _ in 0..5 {
                    let query = vec![t as f32, 0.0];
                    // Only searches that return a non-empty result are counted.
                    if let Ok(results) = eng.search_similar(&query, 5) {
                        if !results.is_empty() {
                            counter.fetch_add(1, Ordering::SeqCst);
                        }
                    }
                }
            })
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }

    let successes = success_count.load(Ordering::SeqCst);
    assert!(
        successes > 50,
        "Expected more than 50 successful searches, got {}",
        successes
    );
}
/// Mixes writers and readers on a shared engine: even-numbered threads store
/// ten new embeddings each while odd-numbered threads run ten searches each
/// (search results are ignored). Afterwards at least the 50 seed embeddings
/// must still be counted.
#[test]
fn test_concurrent_store_and_search() {
    use std::sync::Arc;
    use std::thread;

    let engine = Arc::new(VectorEngine::new());
    for i in 0..50 {
        engine
            .store_embedding(&format!("init{i}"), vec![i as f32, 0.0])
            .unwrap();
    }

    let mut workers = Vec::with_capacity(20);
    for t in 0..20 {
        let eng = Arc::clone(&engine);
        workers.push(thread::spawn(move || {
            if t % 2 != 0 {
                // Reader: outcome of each search is deliberately discarded.
                for _ in 0..10 {
                    let query = vec![t as f32, 0.0];
                    let _ = eng.search_similar(&query, 5);
                }
            } else {
                // Writer: keys are unique per thread and iteration.
                for i in 0..10 {
                    let key = format!("t{t}_v{i}");
                    eng.store_embedding(&key, vec![t as f32, i as f32]).unwrap();
                }
            }
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }

    assert!(engine.count() >= 50);
}
/// Five threads each batch-store 20 uniquely keyed embeddings; every batch
/// must report 20 stored and the engine must end with 100 embeddings.
#[test]
fn test_concurrent_batch_operations() {
    use std::sync::Arc;
    use std::thread;

    let engine = Arc::new(VectorEngine::new());

    let mut workers = Vec::with_capacity(5);
    for t in 0..5 {
        let eng = Arc::clone(&engine);
        workers.push(thread::spawn(move || {
            let mut inputs: Vec<EmbeddingInput> = Vec::with_capacity(20);
            for i in 0..20 {
                inputs.push(EmbeddingInput::new(
                    format!("t{t}_b{i}"),
                    vec![t as f32, i as f32],
                ));
            }
            eng.batch_store_embeddings(inputs).unwrap()
        }));
    }

    let outcomes: Vec<BatchResult> = workers
        .into_iter()
        .map(|worker| worker.join().unwrap())
        .collect();
    for outcome in &outcomes {
        assert_eq!(outcome.count, 20);
    }
    assert_eq!(engine.count(), 100);
}
/// The re-exported `ExtendedDistanceMetric` and `GeometricConfig` are usable
/// from this crate: Jaccard reports higher-is-better and the default cosine
/// weight is 0.5.
#[test]
fn extended_metric_re_export() {
    assert!(ExtendedDistanceMetric::Jaccard.higher_is_better());

    let geometric = GeometricConfig::default();
    let delta = (geometric.cosine_weight - 0.5).abs();
    assert!(delta < 1e-6);
}
/// With a sparse threshold of 0.8, a vector with 70 zeros out of 100 is not
/// treated as sparse while one with 90 zeros is.
#[test]
fn sparse_with_custom_threshold() {
    let config = VectorEngineConfig {
        sparse_threshold: 0.8,
        ..Default::default()
    };
    let engine = VectorEngine::with_config(config).unwrap();

    // 30 leading ones, 70 trailing zeros.
    let mut seventy_zero = vec![0.0f32; 100];
    seventy_zero[..30].fill(1.0);
    assert!(!engine.should_use_sparse(&seventy_zero));

    // 10 leading ones, 90 trailing zeros.
    let mut ninety_zero = vec![0.0f32; 100];
    ninety_zero[..10].fill(1.0);
    assert!(engine.should_use_sparse(&ninety_zero));
}
/// Batch-storing 150 embeddings reports 150 stored and leaves 150 in the
/// engine.
#[test]
fn batch_store_large_batch() {
    let engine = VectorEngine::new();

    let mut batch: Vec<EmbeddingInput> = Vec::with_capacity(150);
    for i in 0..150 {
        batch.push(EmbeddingInput::new(format!("v{i}"), vec![i as f32, 0.0]));
    }

    let outcome = engine.batch_store_embeddings(batch).unwrap();
    assert_eq!(outcome.count, 150);
    assert_eq!(engine.count(), 150);
}
/// Batch-deleting all 150 stored keys reports 150 deletions and empties the
/// engine.
#[test]
fn batch_delete_large_batch() {
    let engine = VectorEngine::new();
    (0..150).for_each(|i| {
        engine
            .store_embedding(&format!("v{i}"), vec![i as f32])
            .unwrap();
    });

    let all_keys: Vec<String> = (0..150).map(|i| format!("v{i}")).collect();
    let removed = engine.batch_delete_embeddings(all_keys).unwrap();

    assert_eq!(removed, 150);
    assert_eq!(engine.count(), 0);
}
/// Paginating an empty engine yields no items, a total of zero, and no
/// further pages.
#[test]
fn pagination_empty_result() {
    let engine = VectorEngine::new();
    let request = Pagination::new(0, 10).with_total();
    let page = engine.list_keys_paginated(request);
    assert!(page.items.is_empty());
    assert_eq!(page.total_count, Some(0));
    assert!(!page.has_more);
}
/// Skipping past the end of a 5-key listing yields no items but still reports
/// the full total of 5.
#[test]
fn pagination_skip_past_end() {
    let engine = VectorEngine::new();
    (0..5).for_each(|i| {
        engine
            .store_embedding(&format!("v{i}"), vec![i as f32])
            .unwrap();
    });

    let request = Pagination::new(10, 5).with_total();
    let page = engine.list_keys_paginated(request);

    assert!(page.items.is_empty());
    assert_eq!(page.total_count, Some(5));
    assert!(!page.has_more);
}
/// `DistanceMetric::Cosine` converts to `ExtendedDistanceMetric::Cosine`.
#[test]
fn distance_metric_conversion_cosine() {
    let converted = ExtendedDistanceMetric::from(DistanceMetric::Cosine);
    assert!(matches!(converted, ExtendedDistanceMetric::Cosine));
}
/// `DistanceMetric::Euclidean` converts to `ExtendedDistanceMetric::Euclidean`.
#[test]
fn distance_metric_conversion_euclidean() {
    let converted = ExtendedDistanceMetric::from(DistanceMetric::Euclidean);
    assert!(matches!(converted, ExtendedDistanceMetric::Euclidean));
}
/// Pins the current mapping: `DistanceMetric::DotProduct` converts to
/// `ExtendedDistanceMetric::Cosine`.
#[test]
fn distance_metric_conversion_dot_product() {
    let converted = ExtendedDistanceMetric::from(DistanceMetric::DotProduct);
    assert!(matches!(converted, ExtendedDistanceMetric::Cosine));
}
/// HNSW search with the Angular metric ranks the stored vector identical to
/// the query first.
#[test]
fn search_with_hnsw_and_metric_angular() {
    let engine = VectorEngine::new();
    for (key, vector) in [
        ("a", vec![1.0, 0.0]),
        ("b", vec![0.707, 0.707]),
        ("c", vec![0.0, 1.0]),
    ] {
        engine.store_embedding(key, vector).unwrap();
    }

    let (index, key_mapping) = engine.build_hnsw_index_default().unwrap();
    let metric = ExtendedDistanceMetric::Angular;
    let hits = engine
        .search_with_hnsw_and_metric(&index, &key_mapping, &[1.0, 0.0], 3, &metric)
        .unwrap();

    assert_eq!(hits.len(), 3);
    assert_eq!(hits[0].key, "a");
}
/// HNSW search with the Jaccard metric ranks the stored vector identical to
/// the query first.
#[test]
fn search_with_hnsw_and_metric_jaccard() {
    let engine = VectorEngine::new();
    for (key, vector) in [
        ("a", vec![1.0, 1.0, 0.0]),
        ("b", vec![1.0, 0.0, 0.0]),
        ("c", vec![0.0, 0.0, 1.0]),
    ] {
        engine.store_embedding(key, vector).unwrap();
    }

    let (index, key_mapping) = engine.build_hnsw_index_default().unwrap();
    let metric = ExtendedDistanceMetric::Jaccard;
    let hits = engine
        .search_with_hnsw_and_metric(&index, &key_mapping, &[1.0, 1.0, 0.0], 3, &metric)
        .unwrap();

    assert_eq!(hits.len(), 3);
    assert_eq!(hits[0].key, "a");
}
/// HNSW search with the Overlap metric returns both stored vectors; only the
/// result count is checked.
#[test]
fn search_with_hnsw_and_metric_overlap() {
    let engine = VectorEngine::new();
    for (key, vector) in [("a", vec![1.0, 1.0, 0.0]), ("b", vec![1.0, 0.0, 0.0])] {
        engine.store_embedding(key, vector).unwrap();
    }

    let (index, key_mapping) = engine.build_hnsw_index_default().unwrap();
    let metric = ExtendedDistanceMetric::Overlap;
    let hits = engine
        .search_with_hnsw_and_metric(&index, &key_mapping, &[1.0, 1.0, 0.0], 2, &metric)
        .unwrap();

    assert_eq!(hits.len(), 2);
}
/// HNSW search with the Manhattan metric, queried at the origin, ranks the
/// stored origin vector first.
#[test]
fn search_with_hnsw_and_metric_manhattan() {
    let engine = VectorEngine::new();
    for (key, x) in [("origin", 0.0), ("one", 1.0), ("two", 2.0)] {
        engine.store_embedding(key, vec![x, 0.0]).unwrap();
    }

    let (index, key_mapping) = engine.build_hnsw_index_default().unwrap();
    let metric = ExtendedDistanceMetric::Manhattan;
    let hits = engine
        .search_with_hnsw_and_metric(&index, &key_mapping, &[0.0, 0.0], 3, &metric)
        .unwrap();

    assert_eq!(hits.len(), 3);
    assert_eq!(hits[0].key, "origin");
}
/// A vector of 1e-18 components can be stored and is found when used as the
/// query.
#[test]
fn store_and_search_with_very_small_values() {
    let engine = VectorEngine::new();
    let minuscule = vec![1e-18_f32, 1e-18, 1e-18];
    engine.store_embedding("tiny", minuscule.clone()).unwrap();

    let hits = engine.search_similar(&minuscule, 1).unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "tiny");
}
/// A vector of 1e30 components can be stored and searched; only the result
/// count is checked.
#[test]
fn store_and_search_with_large_values() {
    let engine = VectorEngine::new();
    let huge = vec![1e30_f32, 1e30, 1e30];
    engine.store_embedding("large", huge.clone()).unwrap();

    let hits = engine.search_similar(&huge, 1).unwrap();

    assert_eq!(hits.len(), 1);
}
/// A vector containing a value below `f32::MIN_POSITIVE` can be stored and
/// searched; only the result count is checked.
#[test]
fn search_handles_denormalized_floats() {
    let engine = VectorEngine::new();
    let subnormal = vec![f32::MIN_POSITIVE / 2.0, 1.0, 0.0];
    engine.store_embedding("denorm", subnormal.clone()).unwrap();

    let hits = engine.search_similar(&subnormal, 1).unwrap();

    assert_eq!(hits.len(), 1);
}
/// A Euclidean search with an all-zero query ranks the stored all-zero vector
/// first.
#[test]
fn zero_vector_with_euclidean_metric() {
    let engine = VectorEngine::new();
    engine.store_embedding("a", vec![1.0, 0.0]).unwrap();
    engine.store_embedding("b", vec![0.0, 0.0]).unwrap();

    let origin = [0.0, 0.0];
    let hits = engine
        .search_similar_with_metric(&origin, 2, DistanceMetric::Euclidean)
        .unwrap();

    assert_eq!(hits.len(), 2);
    assert_eq!(hits[0].key, "b");
}
/// With one-dimensional vectors, either positive vector may rank first for a
/// positive query, and the negative vector must rank last.
#[test]
fn single_dimension_vector() {
    let engine = VectorEngine::new();
    for (key, value) in [("a", 1.0), ("b", 2.0), ("c", -1.0)] {
        engine.store_embedding(key, vec![value]).unwrap();
    }

    let hits = engine.search_similar(&[1.0], 3).unwrap();

    assert_eq!(hits.len(), 3);
    // "a" and "b" point the same way as the query, so either may come first.
    let first_is_positive = hits[0].key == "a" || hits[0].key == "b";
    assert!(
        first_is_positive,
        "Expected 'a' or 'b' first, got '{}'",
        hits[0].key
    );
    assert_eq!(hits[2].key, "c");
}
/// Two 4096-dimensional sine-generated vectors can be stored, and searching
/// with the first one ranks it first.
#[test]
fn high_dimension_4096() {
    let engine = VectorEngine::new();
    let dim = 4096;
    // Samples sin(i * step) for i in 0..dim.
    let sine_wave = |step: f32| -> Vec<f32> { (0..dim).map(|i| (i as f32 * step).sin()).collect() };

    let first = sine_wave(0.001);
    let second = sine_wave(0.002);
    engine.store_embedding("v1", first.clone()).unwrap();
    engine.store_embedding("v2", second).unwrap();

    let hits = engine.search_similar(&first, 2).unwrap();

    assert_eq!(hits.len(), 2);
    assert_eq!(hits[0].key, "v1");
}
/// When stored vectors have 2, 3 and 4 dimensions, a 2-dimensional query
/// returns only the 2-dimensional vector and no error.
#[test]
fn mismatched_dimensions_silently_skipped() {
    let engine = VectorEngine::new();
    for (key, vector) in [
        ("dim2", vec![1.0, 0.0]),
        ("dim3", vec![1.0, 0.0, 0.0]),
        ("dim4", vec![1.0, 0.0, 0.0, 0.0]),
    ] {
        engine.store_embedding(key, vector).unwrap();
    }

    let hits = engine.search_similar(&[1.0, 0.0], 10).unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "dim2");
}
/// A negative sparse threshold is rejected with a `ConfigurationError`.
#[test]
fn config_validate_negative_sparse_threshold() {
    let negative = VectorEngineConfig {
        sparse_threshold: -0.1,
        ..Default::default()
    };
    let outcome = negative.validate();
    assert!(matches!(outcome, Err(VectorError::ConfigurationError(_))));
}
/// Each built-in preset (default, high-throughput, low-memory) passes
/// validation.
#[test]
fn config_presets_are_valid() {
    let default_preset = VectorEngineConfig::default();
    assert!(default_preset.validate().is_ok());

    let throughput_preset = VectorEngineConfig::high_throughput();
    assert!(throughput_preset.validate().is_ok());

    let memory_preset = VectorEngineConfig::low_memory();
    assert!(memory_preset.validate().is_ok());
}
/// Scale test (ignored by default): stores 100,000 generated 128-dimensional
/// vectors and checks that searching with vector 50,000 ranks it first.
#[test]
#[ignore = "Run with: cargo test --release -- --ignored"]
fn scale_100k_vector_search() {
    let engine = VectorEngine::new();
    let dim = 128;
    for i in 0..100_000 {
        let key = format!("v{i}");
        engine
            .store_embedding(&key, create_test_vector(dim, i))
            .unwrap();
    }
    assert_eq!(engine.count(), 100_000);

    let query = create_test_vector(dim, 50_000);
    let hits = engine.search_similar(&query, 10).unwrap();

    assert_eq!(hits.len(), 10);
    assert_eq!(hits[0].key, "v50000");
}
/// Scale test (ignored by default): batch-stores 10,000 generated
/// 64-dimensional vectors and checks both the reported and the stored count.
#[test]
#[ignore]
fn scale_10k_batch_store() {
    let engine = VectorEngine::new();
    let dim = 64;

    let mut batch: Vec<EmbeddingInput> = Vec::with_capacity(10_000);
    for i in 0..10_000 {
        batch.push(EmbeddingInput::new(
            format!("v{i}"),
            create_test_vector(dim, i),
        ));
    }

    let outcome = engine.batch_store_embeddings(batch).unwrap();
    assert_eq!(outcome.count, 10_000);
    assert_eq!(engine.count(), 10_000);
}
/// `PagedResult::new` stores the items, total count and `has_more` flag it
/// was given.
#[test]
fn paged_result_new_with_data() {
    let expected_items = vec![String::from("a"), String::from("b")];
    let page = PagedResult::new(expected_items.clone(), Some(10), true);
    assert_eq!(page.items, expected_items);
    assert_eq!(page.total_count, Some(10));
    assert!(page.has_more);
}
/// The default `Pagination` skips nothing, has no limit, and does not count
/// totals.
#[test]
fn pagination_default() {
    let page = Pagination::default();
    assert_eq!(page.skip, 0);
    assert_eq!(page.limit, None);
    assert!(!page.count_total);
}
/// An empty vector is never classified as sparse.
#[test]
fn sparse_detection_empty_vector() {
    let is_sparse = VectorEngine::should_use_sparse_with_threshold(&[], 0.5);
    assert!(!is_sparse);
}
/// `extract_vector` returns `None` for a scalar integer value.
#[test]
fn extract_vector_non_vector_type() {
    let scalar = TensorValue::Scalar(tensor_store::ScalarValue::Int(42));
    let extracted = VectorEngine::extract_vector(&scalar);
    assert!(extracted.is_none());
}
/// After storing an embedding under `test`, the underlying store exposed by
/// `store()` contains the key `emb:test`.
#[test]
fn store_accessor() {
    let engine = VectorEngine::new();
    engine.store_embedding("test", vec![1.0, 2.0]).unwrap();
    assert!(engine.store().exists("emb:test"));
}
/// Reading the embedding of an entity record that only has a `name` field
/// fails with `NotFound`.
#[test]
fn entity_embedding_no_embedding_field() {
    let engine = VectorEngine::new();

    // Write an entity record directly to the store, without any embedding.
    let mut record = TensorData::new();
    let name = tensor_store::ScalarValue::String("test".into());
    record.set("name", TensorValue::Scalar(name));
    engine.store().put("entity:1", record).unwrap();

    let outcome = engine.get_entity_embedding("entity:1");
    assert!(matches!(outcome, Err(VectorError::NotFound(_))));
}
/// Removing the embedding of an entity record that only has a `name` field
/// fails with `NotFound`.
#[test]
fn remove_entity_embedding_no_field() {
    let engine = VectorEngine::new();

    // Write an entity record directly to the store, without any embedding.
    let mut record = TensorData::new();
    let name = tensor_store::ScalarValue::String("test".into());
    record.set("name", TensorValue::Scalar(name));
    engine.store().put("entity:1", record).unwrap();

    let outcome = engine.remove_entity_embedding("entity:1");
    assert!(matches!(outcome, Err(VectorError::NotFound(_))));
}
/// `max_dimension: Some(0)` fails validation with a `ConfigurationError`
/// whose message mentions `max_dimension`.
#[test]
fn config_validate_invalid_max_dimension_zero() {
    let invalid = VectorEngineConfig {
        max_dimension: Some(0),
        ..Default::default()
    };
    assert!(matches!(
        invalid.validate(),
        Err(VectorError::ConfigurationError(msg)) if msg.contains("max_dimension")
    ));
}
/// `max_keys_per_scan: Some(0)` fails validation with a `ConfigurationError`
/// whose message mentions `max_keys_per_scan`.
#[test]
fn config_validate_invalid_max_keys_per_scan_zero() {
    let invalid = VectorEngineConfig {
        max_keys_per_scan: Some(0),
        ..Default::default()
    };
    assert!(matches!(
        invalid.validate(),
        Err(VectorError::ConfigurationError(msg)) if msg.contains("max_keys_per_scan")
    ));
}
/// `batch_parallel_threshold: 0` fails validation with a `ConfigurationError`
/// whose message mentions `batch_parallel_threshold`.
#[test]
fn config_validate_invalid_batch_parallel_threshold_zero() {
    let invalid = VectorEngineConfig {
        batch_parallel_threshold: 0,
        ..Default::default()
    };
    assert!(matches!(
        invalid.validate(),
        Err(VectorError::ConfigurationError(msg)) if msg.contains("batch_parallel_threshold")
    ));
}
/// `max_index_file_bytes: Some(0)` fails validation with a
/// `ConfigurationError` whose message mentions `max_index_file_bytes`.
#[test]
fn config_validate_invalid_max_index_file_bytes_zero() {
    let invalid = VectorEngineConfig {
        max_index_file_bytes: Some(0),
        ..Default::default()
    };
    assert!(matches!(
        invalid.validate(),
        Err(VectorError::ConfigurationError(msg)) if msg.contains("max_index_file_bytes")
    ));
}
/// `max_index_entries: Some(0)` fails validation with a `ConfigurationError`
/// whose message mentions `max_index_entries`.
#[test]
fn config_validate_invalid_max_index_entries_zero() {
    let invalid = VectorEngineConfig {
        max_index_entries: Some(0),
        ..Default::default()
    };
    assert!(matches!(
        invalid.validate(),
        Err(VectorError::ConfigurationError(msg)) if msg.contains("max_index_entries")
    ));
}
/// `VectorEngine::with_config` returns a `ConfigurationError` for a
/// configuration with `max_dimension: Some(0)`.
#[test]
fn with_config_validates_and_returns_error() {
    let invalid = VectorEngineConfig {
        max_dimension: Some(0),
        ..Default::default()
    };
    let outcome = VectorEngine::with_config(invalid);
    assert!(matches!(outcome, Err(VectorError::ConfigurationError(_))));
}
/// `VectorEngine::with_store_and_config` returns a `ConfigurationError` for a
/// configuration with `batch_parallel_threshold: 0`.
#[test]
fn with_store_and_config_validates_and_returns_error() {
    let backing_store = TensorStore::new();
    let invalid = VectorEngineConfig {
        batch_parallel_threshold: 0,
        ..Default::default()
    };
    let outcome = VectorEngine::with_store_and_config(backing_store, invalid);
    assert!(matches!(outcome, Err(VectorError::ConfigurationError(_))));
}
/// `VectorEngine::with_config` succeeds when the optional bounds are set to
/// positive values.
#[test]
fn with_config_valid_succeeds() {
    let bounded = VectorEngineConfig {
        max_dimension: Some(1024),
        max_keys_per_scan: Some(1000),
        ..Default::default()
    };
    let outcome = VectorEngine::with_config(bounded);
    assert!(outcome.is_ok());
}
/// With `max_keys_per_scan` set to 3, `list_keys_bounded` returns 3 keys even
/// though 10 are stored.
#[test]
fn list_keys_bounded_respects_limit() {
    let capped = VectorEngineConfig {
        max_keys_per_scan: Some(3),
        ..Default::default()
    };
    let engine = VectorEngine::with_config(capped).unwrap();
    (0..10).for_each(|i| {
        engine
            .store_embedding(&format!("v{i}"), vec![i as f32])
            .unwrap();
    });

    let listed = engine.list_keys_bounded();
    assert_eq!(listed.len(), 3);
}
/// With the default configuration, `list_keys_bounded` returns all 10 stored
/// keys.
#[test]
fn list_keys_bounded_no_limit() {
    let engine = VectorEngine::new();
    (0..10).for_each(|i| {
        engine
            .store_embedding(&format!("v{i}"), vec![i as f32])
            .unwrap();
    });

    let listed = engine.list_keys_bounded();
    assert_eq!(listed.len(), 10);
}
/// With `max_dimension` 10, a 20-dimensional query to `search_similar` fails
/// with `DimensionMismatch { expected: 10, got: 20 }`.
#[test]
fn search_similar_rejects_oversized_dimension() {
    let capped = VectorEngineConfig {
        max_dimension: Some(10),
        ..Default::default()
    };
    let engine = VectorEngine::with_config(capped).unwrap();

    let outcome = engine.search_similar(&[0.0; 20], 5);

    assert!(matches!(
        outcome,
        Err(VectorError::DimensionMismatch {
            expected: 10,
            got: 20
        })
    ));
}
/// With `max_dimension` 5, storing a 10-dimensional vector fails with
/// `DimensionMismatch { expected: 5, got: 10 }`.
#[test]
fn store_embedding_rejects_oversized_dimension() {
    let capped = VectorEngineConfig {
        max_dimension: Some(5),
        ..Default::default()
    };
    let engine = VectorEngine::with_config(capped).unwrap();

    let outcome = engine.store_embedding("key", vec![0.0; 10]);

    assert!(matches!(
        outcome,
        Err(VectorError::DimensionMismatch {
            expected: 5,
            got: 10
        })
    ));
}
/// With `max_dimension` 5, setting a 10-dimensional entity embedding fails
/// with `DimensionMismatch { expected: 5, got: 10 }`.
#[test]
fn set_entity_embedding_rejects_oversized_dimension() {
    let capped = VectorEngineConfig {
        max_dimension: Some(5),
        ..Default::default()
    };
    let engine = VectorEngine::with_config(capped).unwrap();

    let outcome = engine.set_entity_embedding("entity:1", vec![0.0; 10]);

    assert!(matches!(
        outcome,
        Err(VectorError::DimensionMismatch {
            expected: 5,
            got: 10
        })
    ));
}
/// With `max_dimension` 5, a 10-dimensional query to `search_entities` fails
/// with `DimensionMismatch { expected: 5, got: 10 }`.
#[test]
fn search_entities_rejects_oversized_dimension() {
    let capped = VectorEngineConfig {
        max_dimension: Some(5),
        ..Default::default()
    };
    let engine = VectorEngine::with_config(capped).unwrap();

    let outcome = engine.search_entities(&[0.0; 10], 5);

    assert!(matches!(
        outcome,
        Err(VectorError::DimensionMismatch {
            expected: 5,
            got: 10
        })
    ));
}
/// Building the default HNSW index over two vectors of equal dimension
/// succeeds and yields two index entries and two keys.
#[test]
fn build_hnsw_index_validates_dimension_consistency() {
    let engine = VectorEngine::new();
    engine.store_embedding("a", vec![1.0, 2.0, 3.0]).unwrap();
    engine.store_embedding("b", vec![4.0, 5.0, 6.0]).unwrap();

    let built = engine.build_hnsw_index_default();
    assert!(built.is_ok());

    let (index, key_mapping) = built.unwrap();
    assert_eq!(index.len(), 2);
    assert_eq!(key_mapping.len(), 2);
}
/// Building the default HNSW index on an empty engine succeeds and yields an
/// empty index and an empty key list.
#[test]
fn build_hnsw_index_empty_store() {
    let engine = VectorEngine::new();

    let built = engine.build_hnsw_index_default();
    assert!(built.is_ok());

    let (index, key_mapping) = built.unwrap();
    assert!(index.is_empty());
    assert!(key_mapping.is_empty());
}
/// Building the default HNSW index succeeds when the only stored vector
/// (2 dimensions) is within the configured `max_dimension` of 5.
///
/// NOTE(review): despite "rejects" in the name, this test never stores an
/// over-limit vector and only asserts success — confirm whether a rejection
/// case was intended.
#[test]
fn build_hnsw_index_rejects_exceeding_max_dimension() {
    let capped = VectorEngineConfig {
        max_dimension: Some(5),
        ..Default::default()
    };
    let engine = VectorEngine::with_config(capped).unwrap();
    engine.store_embedding("a", vec![1.0, 2.0]).unwrap();

    let built = engine.build_hnsw_index_default();
    assert!(built.is_ok());
}
/// The HNSW memory estimate for an empty engine is zero.
#[test]
fn estimate_hnsw_memory_empty_store() {
    let engine = VectorEngine::new();
    let bytes = engine.estimate_hnsw_memory().unwrap();
    assert_eq!(bytes, 0);
}
/// Pins the HNSW memory estimate for 10 stored vectors of 128 dimensions.
#[test]
fn estimate_hnsw_memory_calculation() {
    let engine = VectorEngine::new();
    (0..10).for_each(|i| {
        engine
            .store_embedding(&format!("v{i}"), vec![1.0; 128])
            .unwrap();
    });

    let bytes = engine.estimate_hnsw_memory().unwrap();

    // NOTE(review): the three terms presumably mirror the formula inside
    // `estimate_hnsw_memory` (vector payload, graph links, per-node overhead)
    // — confirm against that implementation.
    let expected_bytes = 10 * 128 * 4 + 10 * 16 * 2 * 8 + 10 * 32;
    assert_eq!(bytes, expected_bytes);
}
/// With `batch_parallel_threshold` set to 5, a batch of 10 embeddings is
/// stored successfully and reports 10 stored.
#[test]
fn batch_uses_config_parallel_threshold() {
    let low_threshold = VectorEngineConfig {
        batch_parallel_threshold: 5,
        ..Default::default()
    };
    let engine = VectorEngine::with_config(low_threshold).unwrap();

    let mut batch: Vec<EmbeddingInput> = Vec::with_capacity(10);
    for i in 0..10 {
        batch.push(EmbeddingInput::new(format!("v{i}"), vec![i as f32, 0.0]));
    }

    let outcome = engine.batch_store_embeddings(batch).unwrap();
    assert_eq!(outcome.count, 10);
}
/// The low-memory preset caps dimensions at 4096 and keys per scan at 10,000.
#[test]
fn config_low_memory_has_bounds() {
    let preset = VectorEngineConfig::low_memory();
    assert_eq!(preset.max_dimension, Some(4096));
    assert_eq!(preset.max_keys_per_scan, Some(10_000));
}
/// The default configuration has no dimension or scan bounds and a batch
/// parallel threshold of 100.
#[test]
fn config_default_has_new_fields() {
    let defaults = VectorEngineConfig::default();
    assert_eq!(defaults.max_dimension, None);
    assert_eq!(defaults.max_keys_per_scan, None);
    assert_eq!(defaults.batch_parallel_threshold, 100);
}
/// The high-throughput preset has no dimension or scan bounds and a batch
/// parallel threshold of 100.
#[test]
fn config_high_throughput_has_new_fields() {
    let preset = VectorEngineConfig::high_throughput();
    assert_eq!(preset.max_dimension, None);
    assert_eq!(preset.max_keys_per_scan, None);
    assert_eq!(preset.batch_parallel_threshold, 100);
}
/// Each `with_*` builder method on `VectorEngineConfig` sets the field it is
/// named after.
#[test]
fn config_builder_methods() {
    let built = VectorEngineConfig::default()
        .with_default_dimension(128)
        .with_sparse_threshold(0.7)
        .with_parallel_threshold(1000)
        .with_default_metric(DistanceMetric::Euclidean)
        .with_max_dimension(4096)
        .with_max_keys_per_scan(50_000)
        .with_batch_parallel_threshold(200);

    assert_eq!(built.default_dimension, Some(128));
    let threshold_delta = (built.sparse_threshold - 0.7).abs();
    assert!(threshold_delta < f32::EPSILON);
    assert_eq!(built.parallel_threshold, 1000);
    assert_eq!(built.default_metric, DistanceMetric::Euclidean);
    assert_eq!(built.max_dimension, Some(4096));
    assert_eq!(built.max_keys_per_scan, Some(50_000));
    assert_eq!(built.batch_parallel_threshold, 200);
}
/// Scale test (ignored by default): inserts one million 128-dimensional
/// vectors with the high-throughput preset, runs 100 top-10 searches, and
/// checks that insertion finished in under 120 seconds and the searches in
/// under 5000 milliseconds.
#[test]
#[ignore]
fn large_scale_million_vectors() {
    use std::time::Instant;

    const VECTOR_COUNT: usize = 1_000_000;
    const DIMENSION: usize = 128;

    let engine = VectorEngine::with_config(VectorEngineConfig::high_throughput()).unwrap();

    // Insertion phase.
    let insert_clock = Instant::now();
    for i in 0..VECTOR_COUNT {
        let mut vector: Vec<f32> = Vec::with_capacity(DIMENSION);
        for j in 0..DIMENSION {
            vector.push(((i * DIMENSION + j) % 1000) as f32 / 1000.0);
        }
        engine.store_embedding(&format!("vec_{i}"), vector).unwrap();
    }
    let insert_time = insert_clock.elapsed();
    assert_eq!(engine.count(), VECTOR_COUNT);

    // Search phase.
    let query: Vec<f32> = (0..DIMENSION)
        .map(|i| i as f32 / DIMENSION as f32)
        .collect();
    let search_clock = Instant::now();
    for _ in 0..100 {
        let hits = engine.search_similar(&query, 10).unwrap();
        assert_eq!(hits.len(), 10);
    }
    let search_time = search_clock.elapsed();

    assert!(
        insert_time.as_secs() < 120,
        "Insert took too long: {:?}",
        insert_time
    );
    assert!(
        search_time.as_millis() < 5000,
        "100 searches took too long: {:?}",
        search_time
    );
}
/// A `TensorStoreError::NotFound` converts into `VectorError::StorageError`
/// and the rendered message still contains the original key.
#[test]
fn error_from_tensor_store_error() {
    let source = TensorStoreError::NotFound("test_key".to_string());
    let converted = VectorError::from(source);
    assert!(matches!(converted, VectorError::StorageError(_)));
    assert!(converted.to_string().contains("test_key"));
}
/// Compile-time check that `VectorError` can be used as a
/// `&dyn std::error::Error`.
#[test]
fn error_std_error_trait() {
    let not_found = VectorError::NotFound("test".to_string());
    let _as_dyn: &dyn std::error::Error = &not_found;
}
/// With `parallel_threshold` lowered to 5 and 10 stored vectors, a Euclidean
/// top-3 search returns 3 results.
///
/// NOTE(review): presumably 10 stored vectors exceed the threshold and so
/// exercise the parallel search path — confirm against the search
/// implementation.
#[test]
fn search_with_metric_parallel_path() {
    let low_threshold = VectorEngineConfig {
        parallel_threshold: 5,
        ..Default::default()
    };
    let engine = VectorEngine::with_config(low_threshold).unwrap();
    (0..10).for_each(|i| {
        engine
            .store_embedding(&format!("vec_{i}"), vec![i as f32, 0.0, 0.0])
            .unwrap();
    });

    let query = [5.0, 0.0, 0.0];
    let hits = engine
        .search_similar_with_metric(&query, 3, DistanceMetric::Euclidean)
        .unwrap();

    assert_eq!(hits.len(), 3);
}
/// `get_embedding` reports `NotFound` when the record stored under the
/// embedding key holds only a scalar field named `not_vector`.
#[test]
fn get_embedding_invalid_format() {
    let engine = VectorEngine::new();

    // Write a record directly under the `emb:` key without going through
    // `store_embedding`.
    let mut malformed = TensorData::new();
    malformed.set("not_vector", TensorValue::Scalar(ScalarValue::Int(42)));
    engine
        .store()
        .put("emb:invalid".to_string(), malformed)
        .unwrap();

    let outcome = engine.get_embedding("invalid");
    assert!(matches!(outcome, Err(VectorError::NotFound(_))));
}
/// Stores a vector with two metadata fields and reads both the vector and the
/// metadata back.
#[test]
fn store_embedding_with_metadata_basic() {
    let engine = VectorEngine::new();
    let metadata = HashMap::from([
        (
            "category".to_string(),
            TensorValue::Scalar(ScalarValue::String("electronics".to_string())),
        ),
        (
            "price".to_string(),
            TensorValue::Scalar(ScalarValue::Float(299.99)),
        ),
    ]);
    engine
        .store_embedding_with_metadata("product1", vec![0.1, 0.2, 0.3], metadata)
        .unwrap();

    assert_eq!(engine.get_embedding("product1").unwrap().len(), 3);

    let stored = engine.get_metadata("product1").unwrap();
    assert_eq!(stored.len(), 2);
    for field in ["category", "price"] {
        assert!(stored.contains_key(field));
    }
}
/// Storing with an empty metadata map keeps the vector intact and yields an
/// empty metadata map on read-back.
#[test]
fn store_embedding_with_metadata_empty_metadata() {
    let engine = VectorEngine::new();
    engine
        .store_embedding_with_metadata("key", vec![1.0, 2.0], HashMap::new())
        .unwrap();

    assert_eq!(engine.get_embedding("key").unwrap(), vec![1.0, 2.0]);
    assert!(engine.get_metadata("key").unwrap().is_empty());
}
/// An empty vector is rejected with `VectorError::EmptyVector`.
#[test]
fn store_embedding_with_metadata_empty_vector_error() {
    let engine = VectorEngine::new();
    let outcome = engine.store_embedding_with_metadata("key", vec![], HashMap::new());
    assert_eq!(outcome.err(), Some(VectorError::EmptyVector));
}
/// With `max_dimension` set to 5, storing a 10-dimensional vector fails with
/// `DimensionMismatch { expected: 5, got: 10 }`.
#[test]
fn store_embedding_with_metadata_dimension_limit() {
    let engine = VectorEngine::with_config(VectorEngineConfig {
        max_dimension: Some(5),
        ..Default::default()
    })
    .unwrap();

    let outcome = engine.store_embedding_with_metadata("key", vec![0.0; 10], HashMap::new());
    assert_eq!(
        outcome.err(),
        Some(VectorError::DimensionMismatch {
            expected: 5,
            got: 10
        })
    );
}
/// Reading metadata for a key that was never stored yields `NotFound`.
#[test]
fn get_metadata_nonexistent_key() {
    let engine = VectorEngine::new();
    assert!(matches!(
        engine.get_metadata("nonexistent"),
        Err(VectorError::NotFound(_))
    ));
}
/// `update_metadata` adds new fields and overwrites existing ones: after the
/// update, `color` is "blue" and `size` is present.
#[test]
fn update_metadata_basic() {
    // Local shorthand for a string-valued metadata entry.
    let text = |s: &str| TensorValue::Scalar(ScalarValue::String(s.to_string()));

    let engine = VectorEngine::new();
    let initial = HashMap::from([("color".to_string(), text("red"))]);
    engine
        .store_embedding_with_metadata("item", vec![1.0, 2.0], initial)
        .unwrap();

    let changes = HashMap::from([
        ("size".to_string(), text("large")),
        ("color".to_string(), text("blue")),
    ]);
    engine.update_metadata("item", changes).unwrap();

    let merged = engine.get_metadata("item").unwrap();
    assert_eq!(merged.len(), 2);
    if let Some(TensorValue::Scalar(ScalarValue::String(color))) = merged.get("color") {
        assert_eq!(color, "blue");
    } else {
        panic!("Expected color to be 'blue'");
    }
    assert!(merged.contains_key("size"));
}
/// Updating metadata for a key that was never stored yields `NotFound`.
#[test]
fn update_metadata_nonexistent_key() {
    let engine = VectorEngine::new();
    let outcome = engine.update_metadata("nonexistent", HashMap::new());
    assert!(matches!(outcome, Err(VectorError::NotFound(_))));
}
/// An `emb:orphan` entry that holds only a `meta:` field (no vector) makes
/// `update_metadata("orphan", ..)` fail with `NotFound`.
#[test]
fn update_metadata_no_vector_error() {
    let engine = VectorEngine::new();

    let mut metadata_only = TensorData::new();
    metadata_only.set(
        "meta:test",
        TensorValue::Scalar(ScalarValue::String("value".to_string())),
    );
    engine.store().put("emb:orphan", metadata_only).unwrap();

    assert!(matches!(
        engine.update_metadata("orphan", HashMap::new()),
        Err(VectorError::NotFound(_))
    ));
}
/// Removing field `a` leaves field `b` untouched.
#[test]
fn remove_metadata_field_basic() {
    let engine = VectorEngine::new();
    let initial = HashMap::from([
        ("a".to_string(), TensorValue::Scalar(ScalarValue::Int(1))),
        ("b".to_string(), TensorValue::Scalar(ScalarValue::Int(2))),
    ]);
    engine
        .store_embedding_with_metadata("key", vec![1.0], initial)
        .unwrap();

    engine.remove_metadata_field("key", "a").unwrap();

    let remaining = engine.get_metadata("key").unwrap();
    assert!(!remaining.contains_key("a"));
    assert!(remaining.contains_key("b"));
}
/// Removing a field from a key that was never stored yields `NotFound`.
#[test]
fn remove_metadata_field_nonexistent_key() {
    let engine = VectorEngine::new();
    assert!(matches!(
        engine.remove_metadata_field("nonexistent", "field"),
        Err(VectorError::NotFound(_))
    ));
}
/// `has_metadata_field` reports `true` for a field that was stored.
#[test]
fn has_metadata_field_true() {
    let engine = VectorEngine::new();
    let metadata = HashMap::from([(
        "field1".to_string(),
        TensorValue::Scalar(ScalarValue::Int(42)),
    )]);
    engine
        .store_embedding_with_metadata("key", vec![1.0], metadata)
        .unwrap();

    let present = engine.has_metadata_field("key", "field1");
    assert!(present);
}
/// `has_metadata_field` reports `false` for a field the embedding lacks.
#[test]
fn has_metadata_field_false() {
    let engine = VectorEngine::new();
    engine.store_embedding("key", vec![1.0]).unwrap();

    let present = engine.has_metadata_field("key", "nonexistent_field");
    assert!(!present);
}
/// `has_metadata_field` reports `false` when the key itself does not exist.
#[test]
fn has_metadata_field_nonexistent_key() {
    let engine = VectorEngine::new();
    let present = engine.has_metadata_field("nonexistent", "field");
    assert!(!present);
}
/// A float metadata field is returned by `get_metadata_field` with its stored
/// value.
#[test]
fn get_metadata_field_basic() {
    let engine = VectorEngine::new();
    let metadata = HashMap::from([(
        "score".to_string(),
        TensorValue::Scalar(ScalarValue::Float(0.95)),
    )]);
    engine
        .store_embedding_with_metadata("key", vec![1.0], metadata)
        .unwrap();

    let fetched = engine.get_metadata_field("key", "score").unwrap();
    if let Some(TensorValue::Scalar(ScalarValue::Float(score))) = fetched {
        assert!((score - 0.95).abs() < f64::EPSILON);
    } else {
        panic!("Expected float value");
    }
}
/// Asking for a field that an existing embedding lacks yields `Ok(None)`.
#[test]
fn get_metadata_field_not_present() {
    let engine = VectorEngine::new();
    engine.store_embedding("key", vec![1.0]).unwrap();

    let fetched = engine.get_metadata_field("key", "nonexistent").unwrap();
    assert!(fetched.is_none());
}
/// Asking for a field of a key that was never stored yields `NotFound`.
#[test]
fn get_metadata_field_nonexistent_key() {
    let engine = VectorEngine::new();
    assert!(matches!(
        engine.get_metadata_field("nonexistent", "field"),
        Err(VectorError::NotFound(_))
    ));
}
/// A mostly-zero 100-dimensional vector stored with metadata comes back with
/// its full length, its two non-zero entries, and its metadata.
#[test]
fn metadata_with_sparse_vector() {
    let engine = VectorEngine::new();

    // Only indices 0 and 50 are non-zero.
    let mostly_zero: Vec<f32> = (0..100)
        .map(|i| match i {
            0 => 1.0,
            50 => 2.0,
            _ => 0.0,
        })
        .collect();
    let metadata = HashMap::from([(
        "type".to_string(),
        TensorValue::Scalar(ScalarValue::String("sparse".to_string())),
    )]);
    engine
        .store_embedding_with_metadata("sparse_key", mostly_zero, metadata)
        .unwrap();

    let restored = engine.get_embedding("sparse_key").unwrap();
    assert_eq!(restored.len(), 100);
    assert_eq!(restored[0], 1.0);
    assert_eq!(restored[50], 2.0);

    let stored_meta = engine.get_metadata("sparse_key").unwrap();
    assert!(stored_meta.contains_key("type"));
}
/// `metadata_field_key` prefixes the field name with `meta:`.
#[test]
fn metadata_field_key_helper() {
    assert_eq!(
        VectorEngine::metadata_field_key("test_field"),
        "meta:test_field"
    );
}
/// Round-trips one metadata field of each scalar kind (int, float, string,
/// bool) and checks that every value comes back with its variant and value
/// intact.
#[test]
fn metadata_multiple_types() {
    // Deliberately not `3.14`: that literal triggers clippy's deny-by-default
    // `approx_constant` lint. 2.5 is exactly representable in binary.
    const FLOAT_SAMPLE: f64 = 2.5;

    let engine = VectorEngine::new();
    let metadata = HashMap::from([
        (
            "int_field".to_string(),
            TensorValue::Scalar(ScalarValue::Int(42)),
        ),
        (
            "float_field".to_string(),
            TensorValue::Scalar(ScalarValue::Float(FLOAT_SAMPLE)),
        ),
        (
            "string_field".to_string(),
            TensorValue::Scalar(ScalarValue::String("hello".to_string())),
        ),
        (
            "bool_field".to_string(),
            TensorValue::Scalar(ScalarValue::Bool(true)),
        ),
    ]);
    engine
        .store_embedding_with_metadata("multi_type", vec![1.0, 2.0], metadata)
        .unwrap();

    let retrieved = engine.get_metadata("multi_type").unwrap();
    assert_eq!(retrieved.len(), 4);

    match retrieved.get("int_field") {
        Some(TensorValue::Scalar(ScalarValue::Int(i))) => assert_eq!(*i, 42),
        _ => panic!("Expected int"),
    }
    match retrieved.get("float_field") {
        Some(TensorValue::Scalar(ScalarValue::Float(f))) => {
            assert!((*f - FLOAT_SAMPLE).abs() < f64::EPSILON);
        },
        _ => panic!("Expected float"),
    }
    match retrieved.get("string_field") {
        Some(TensorValue::Scalar(ScalarValue::String(s))) => assert_eq!(s, "hello"),
        _ => panic!("Expected string"),
    }
    match retrieved.get("bool_field") {
        Some(TensorValue::Scalar(ScalarValue::Bool(b))) => assert!(*b),
        _ => panic!("Expected bool"),
    }
}
/// Re-storing a key replaces its metadata wholesale: fields from the first
/// store (`a`, `b`) are gone and only the second store's field (`c`) remains.
#[test]
fn metadata_overwrites_on_store() {
    let int_value = |n: i64| TensorValue::Scalar(ScalarValue::Int(n));

    let engine = VectorEngine::new();
    let first = HashMap::from([
        ("a".to_string(), int_value(1)),
        ("b".to_string(), int_value(2)),
    ]);
    engine
        .store_embedding_with_metadata("key", vec![1.0], first)
        .unwrap();

    let second = HashMap::from([("c".to_string(), int_value(3))]);
    engine
        .store_embedding_with_metadata("key", vec![2.0], second)
        .unwrap();

    let current = engine.get_metadata("key").unwrap();
    assert_eq!(current.len(), 1);
    assert!(current.contains_key("c"));
    assert!(!current.contains_key("a"));
    assert!(!current.contains_key("b"));
}
/// Builds the shared fixture for the filtered-search tests.
///
/// Three embeddings are stored, each with `category`, `price` and `active`
/// metadata:
///
/// | key     | category      | price | active | vector            |
/// |---------|---------------|-------|--------|-------------------|
/// | `item0` | `electronics` | 100   | true   | `[1.0, 1.0, 1.0]` |
/// | `item1` | `clothing`    | 50    | false  | `[2.0, 1.0, 1.0]` |
/// | `item2` | `food`        | 25    | true   | `[3.0, 1.0, 1.0]` |
fn setup_filtered_search_engine() -> VectorEngine {
    let engine = VectorEngine::new();
    let catalog = [("electronics", 100), ("clothing", 50), ("food", 25)];

    for (i, &(category, price)) in catalog.iter().enumerate() {
        let metadata = HashMap::from([
            (
                "category".to_string(),
                TensorValue::Scalar(ScalarValue::String(category.to_string())),
            ),
            (
                "price".to_string(),
                TensorValue::Scalar(ScalarValue::Int(price)),
            ),
            (
                "active".to_string(),
                // Even-indexed items are the active ones.
                TensorValue::Scalar(ScalarValue::Bool(i % 2 == 0)),
            ),
        ]);
        let key = format!("item{i}");
        engine
            .store_embedding_with_metadata(&key, vec![(i + 1) as f32, 1.0, 1.0], metadata)
            .unwrap();
    }

    engine
}
/// `Eq` on a string field selects only the fixture's single electronics item.
#[test]
fn search_filtered_eq_string() {
    let engine = setup_filtered_search_engine();
    let wanted = FilterValue::String("electronics".to_string());
    let filter = FilterCondition::Eq("category".to_string(), wanted);

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "item0");
}
/// `Eq` on an integer field selects only the fixture item priced at 50.
#[test]
fn search_filtered_eq_int() {
    let engine = setup_filtered_search_engine();
    let filter = FilterCondition::Eq(String::from("price"), FilterValue::Int(50));

    let hits = engine
        .search_similar_filtered(&[1.0, 0.0, 0.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "item1");
}
/// `Gt(price, 30)` matches the two fixture items priced 100 and 50.
#[test]
fn search_filtered_gt() {
    let engine = setup_filtered_search_engine();
    let filter = FilterCondition::Gt(String::from("price"), FilterValue::Int(30));

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 2);
}
/// `Lt(price, 60)` matches the two fixture items priced 50 and 25.
#[test]
fn search_filtered_lt() {
    let engine = setup_filtered_search_engine();
    let filter = FilterCondition::Lt(String::from("price"), FilterValue::Int(60));

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 2);
}
/// `Le(price, 50)` is inclusive: it matches the fixture items priced 50 and 25.
#[test]
fn search_filtered_le() {
    let engine = setup_filtered_search_engine();
    let filter = FilterCondition::Le(String::from("price"), FilterValue::Int(50));

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 2);
}
/// `Ge(price, 50)` is inclusive: it matches the fixture items priced 100 and 50.
#[test]
fn search_filtered_ge() {
    let engine = setup_filtered_search_engine();
    let filter = FilterCondition::Ge(String::from("price"), FilterValue::Int(50));

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 2);
}
/// `price > 30 AND price < 80` leaves only the fixture item priced at 50.
#[test]
fn search_filtered_and() {
    let engine = setup_filtered_search_engine();
    let above = FilterCondition::Gt("price".to_string(), FilterValue::Int(30));
    let below = FilterCondition::Lt("price".to_string(), FilterValue::Int(80));
    let filter = above.and(below);

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "item1");
}
/// `category == electronics OR category == food` matches two fixture items.
#[test]
fn search_filtered_or() {
    let engine = setup_filtered_search_engine();
    let is_electronics = FilterCondition::Eq(
        "category".to_string(),
        FilterValue::String("electronics".to_string()),
    );
    let is_food = FilterCondition::Eq(
        "category".to_string(),
        FilterValue::String("food".to_string()),
    );
    let filter = is_electronics.or(is_food);

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 2);
}
/// `FilterCondition::True` lets all three fixture items through.
#[test]
fn search_filtered_true() {
    let engine = setup_filtered_search_engine();

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &FilterCondition::True, None)
        .unwrap();

    assert_eq!(hits.len(), 3);
}
/// `Exists("tag")` matches only the embedding that carries a `tag` field.
#[test]
fn search_filtered_exists() {
    let engine = VectorEngine::new();
    let tagged = HashMap::from([(
        "tag".to_string(),
        TensorValue::Scalar(ScalarValue::String("a".to_string())),
    )]);
    engine
        .store_embedding_with_metadata("with_tag", vec![1.0, 0.0], tagged)
        .unwrap();
    engine
        .store_embedding("without_tag", vec![0.0, 1.0])
        .unwrap();

    let filter = FilterCondition::Exists("tag".to_string());
    let hits = engine
        .search_similar_filtered(&[1.0, 0.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "with_tag");
}
/// `Contains(description, "shirt")` matches only the item whose description
/// includes that substring.
#[test]
fn search_filtered_contains() {
    let described = |text: &str| {
        HashMap::from([(
            "description".to_string(),
            TensorValue::Scalar(ScalarValue::String(text.to_string())),
        )])
    };

    let engine = VectorEngine::new();
    engine
        .store_embedding_with_metadata("item1", vec![1.0, 0.0], described("blue shirt"))
        .unwrap();
    engine
        .store_embedding_with_metadata("item2", vec![0.0, 1.0], described("red pants"))
        .unwrap();

    let filter = FilterCondition::Contains("description".to_string(), "shirt".to_string());
    let hits = engine
        .search_similar_filtered(&[1.0, 0.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "item1");
}
/// `Contains` applied to an integer field matches nothing, even though "4"
/// appears in the decimal rendering of 42.
#[test]
fn search_filtered_contains_on_non_string() {
    let engine = VectorEngine::new();
    let numeric = HashMap::from([(
        "count".to_string(),
        TensorValue::Scalar(ScalarValue::Int(42)),
    )]);
    engine
        .store_embedding_with_metadata("item", vec![1.0, 0.0], numeric)
        .unwrap();

    let filter = FilterCondition::Contains("count".to_string(), "4".to_string());
    let hits = engine
        .search_similar_filtered(&[1.0, 0.0], 10, &filter, None)
        .unwrap();

    assert!(hits.is_empty());
}
/// `StartsWith(sku, "ABC")` matches only the item whose SKU has that prefix.
#[test]
fn search_filtered_starts_with() {
    let with_sku = |sku: &str| {
        HashMap::from([(
            "sku".to_string(),
            TensorValue::Scalar(ScalarValue::String(sku.to_string())),
        )])
    };

    let engine = VectorEngine::new();
    engine
        .store_embedding_with_metadata("item1", vec![1.0, 0.0], with_sku("ABC123"))
        .unwrap();
    engine
        .store_embedding_with_metadata("item2", vec![0.0, 1.0], with_sku("XYZ789"))
        .unwrap();

    let filter = FilterCondition::StartsWith("sku".to_string(), "ABC".to_string());
    let hits = engine
        .search_similar_filtered(&[1.0, 0.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "item1");
}
/// `StartsWith` applied to an integer field matches nothing, even though the
/// decimal rendering of 123 begins with "1".
#[test]
fn search_filtered_starts_with_on_non_string() {
    let engine = VectorEngine::new();
    let numeric = HashMap::from([(
        "count".to_string(),
        TensorValue::Scalar(ScalarValue::Int(123)),
    )]);
    engine
        .store_embedding_with_metadata("item", vec![1.0, 0.0], numeric)
        .unwrap();

    let filter = FilterCondition::StartsWith("count".to_string(), "1".to_string());
    let hits = engine
        .search_similar_filtered(&[1.0, 0.0], 10, &filter, None)
        .unwrap();

    assert!(hits.is_empty());
}
/// An `Eq` filter on a field the embedding does not have matches nothing.
#[test]
fn search_filtered_missing_field() {
    let engine = VectorEngine::new();
    engine.store_embedding("item", vec![1.0, 0.0]).unwrap();

    let filter = FilterCondition::Eq(String::from("missing"), FilterValue::Int(42));
    let hits = engine
        .search_similar_filtered(&[1.0, 0.0], 10, &filter, None)
        .unwrap();

    assert!(hits.is_empty());
}
/// `In(category, [electronics, food])` matches two of the three fixture items.
#[test]
fn search_filtered_in() {
    let engine = setup_filtered_search_engine();
    let accepted = vec![
        FilterValue::String("electronics".to_string()),
        FilterValue::String("food".to_string()),
    ];
    let filter = FilterCondition::In("category".to_string(), accepted);

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 2);
}
/// `Ne(category, electronics)` matches the clothing and food fixture items.
#[test]
fn search_filtered_ne() {
    let engine = setup_filtered_search_engine();
    let excluded = FilterValue::String("electronics".to_string());
    let filter = FilterCondition::Ne("category".to_string(), excluded);

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 2);
}
/// `Eq(active, true)` matches the two even-indexed fixture items (`item0` and
/// `item2`), which are the ones stored with `active = true`.
#[test]
fn search_filtered_bool() {
    let engine = setup_filtered_search_engine();
    let filter = FilterCondition::Eq(String::from("active"), FilterValue::Bool(true));

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 2);
}
/// A filter on a category that no fixture item has yields an empty result
/// list rather than an error.
#[test]
fn search_filtered_empty_result() {
    let engine = setup_filtered_search_engine();
    let unknown = FilterValue::String("nonexistent".to_string());
    let filter = FilterCondition::Eq("category".to_string(), unknown);

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 0);
}
/// With an explicit pre-filter configuration, the electronics filter still
/// returns exactly one result.
#[test]
fn search_filtered_pre_filter_strategy() {
    let engine = setup_filtered_search_engine();
    let wanted = FilterValue::String("electronics".to_string());
    let filter = FilterCondition::Eq("category".to_string(), wanted);
    let strategy = Some(FilteredSearchConfig::pre_filter());

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, strategy)
        .unwrap();

    assert_eq!(hits.len(), 1);
}
/// With an explicit post-filter configuration, the electronics filter still
/// returns exactly one result.
#[test]
fn search_filtered_post_filter_strategy() {
    let engine = setup_filtered_search_engine();
    let wanted = FilterValue::String("electronics".to_string());
    let filter = FilterCondition::Eq("category".to_string(), wanted);
    let strategy = Some(FilteredSearchConfig::post_filter());

    let hits = engine
        .search_similar_filtered(&[1.0, 1.0, 1.0], 10, &filter, strategy)
        .unwrap();

    assert_eq!(hits.len(), 1);
}
/// An empty query vector is rejected with `VectorError::EmptyVector`.
#[test]
fn search_filtered_empty_vector_error() {
    let engine = VectorEngine::new();
    let outcome = engine.search_similar_filtered(&[], 5, &FilterCondition::True, None);
    assert!(matches!(outcome, Err(VectorError::EmptyVector)));
}
/// A `top_k` of zero is rejected with `VectorError::InvalidTopK`.
#[test]
fn search_filtered_zero_top_k_error() {
    let engine = VectorEngine::new();
    let outcome = engine.search_similar_filtered(&[1.0], 0, &FilterCondition::True, None);
    assert!(matches!(outcome, Err(VectorError::InvalidTopK)));
}
/// With `max_dimension` set to 5, a 10-dimensional query fails with
/// `DimensionMismatch { expected: 5, got: 10 }`.
#[test]
fn search_filtered_dimension_limit() {
    let engine = VectorEngine::with_config(VectorEngineConfig {
        max_dimension: Some(5),
        ..Default::default()
    })
    .unwrap();

    let outcome = engine.search_similar_filtered(&[0.0; 10], 5, &FilterCondition::True, None);

    assert!(matches!(
        outcome,
        Err(VectorError::DimensionMismatch {
            expected: 5,
            got: 10
        })
    ));
}
/// `count_matching` with `price > 30` counts the two fixture items priced 100
/// and 50.
#[test]
fn count_matching_basic() {
    let engine = setup_filtered_search_engine();
    let filter = FilterCondition::Gt(String::from("price"), FilterValue::Int(30));
    assert_eq!(engine.count_matching(&filter), 2);
}
/// `list_keys_matching` for the electronics category returns exactly `item0`.
#[test]
fn list_keys_matching_basic() {
    let engine = setup_filtered_search_engine();
    let wanted = FilterValue::String("electronics".to_string());
    let filter = FilterCondition::Eq("category".to_string(), wanted);

    let matching = engine.list_keys_matching(&filter);

    assert_eq!(matching.len(), 1);
    assert!(matching.iter().any(|key| key == "item0"));
}
/// Selectivity is approximately 1.0 for `FilterCondition::True` and strictly
/// between 0 and 1 for a filter that matches only part of the fixture.
#[test]
fn estimate_filter_selectivity_basic() {
    let engine = setup_filtered_search_engine();

    let match_all = engine.estimate_filter_selectivity(&FilterCondition::True);
    assert!((match_all - 1.0).abs() < 0.01);

    let electronics_only = FilterCondition::Eq(
        "category".to_string(),
        FilterValue::String("electronics".to_string()),
    );
    let partial = engine.estimate_filter_selectivity(&electronics_only);
    assert!(partial > 0.0);
    assert!(partial < 1.0);
}
/// The `and` / `or` builder methods produce `And` / `Or` variants.
#[test]
fn filter_condition_and_or_builders() {
    let lhs = FilterCondition::Eq("x".to_string(), FilterValue::Int(1));
    let rhs = FilterCondition::Eq("y".to_string(), FilterValue::Int(2));

    let conjunction = lhs.clone().and(rhs.clone());
    assert!(matches!(conjunction, FilterCondition::And(_, _)));

    let disjunction = lhs.or(rhs);
    assert!(matches!(disjunction, FilterCondition::Or(_, _)));
}
/// Each supported primitive (`i64`, `f64`, `&str`, `String`, `bool`) converts
/// into the matching `FilterValue` variant via `.into()`.
#[test]
fn filter_value_from_traits() {
    // Deliberately not `3.14`: that literal triggers clippy's deny-by-default
    // `approx_constant` lint. 2.5 is exactly representable in binary.
    const FLOAT_SAMPLE: f64 = 2.5;

    let v1: FilterValue = 42_i64.into();
    assert!(matches!(v1, FilterValue::Int(42)));

    let v2: FilterValue = FLOAT_SAMPLE.into();
    assert!(matches!(v2, FilterValue::Float(f) if (f - FLOAT_SAMPLE).abs() < f64::EPSILON));

    let v3: FilterValue = "hello".into();
    assert!(matches!(v3, FilterValue::String(s) if s == "hello"));

    let v4: FilterValue = "world".to_string().into();
    assert!(matches!(v4, FilterValue::String(s) if s == "world"));

    let v5: FilterValue = true.into();
    assert!(matches!(v5, FilterValue::Bool(true)));
}
/// The default filter strategy is `Auto`.
#[test]
fn filter_strategy_default() {
    let strategy = FilterStrategy::default();
    assert_eq!(strategy, FilterStrategy::Auto);
}
/// The `FilteredSearchConfig` constructors and `with_oversample` set the
/// fields they are named after.
#[test]
fn filtered_search_config_builders() {
    assert_eq!(
        FilteredSearchConfig::pre_filter().strategy,
        FilterStrategy::PreFilter
    );
    assert_eq!(
        FilteredSearchConfig::post_filter().strategy,
        FilterStrategy::PostFilter
    );

    let oversampled = FilteredSearchConfig::default().with_oversample(5);
    assert_eq!(oversampled.oversample_factor, 5);
}
/// `Gt(score, 0.8)` on float metadata keeps the 0.95 item and drops the 0.5
/// item.
#[test]
fn search_filtered_float_comparison() {
    let scored = |score: f64| {
        HashMap::from([(
            "score".to_string(),
            TensorValue::Scalar(ScalarValue::Float(score)),
        )])
    };

    let engine = VectorEngine::new();
    engine
        .store_embedding_with_metadata("high", vec![1.0, 0.0], scored(0.95))
        .unwrap();
    engine
        .store_embedding_with_metadata("low", vec![0.0, 1.0], scored(0.5))
        .unwrap();

    let filter = FilterCondition::Gt("score".to_string(), FilterValue::Float(0.8));
    let hits = engine
        .search_similar_filtered(&[1.0, 0.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "high");
}
/// A float metadata value (50.5) compared against an integer filter value
/// (50) with `Gt` matches.
#[test]
fn search_filtered_mixed_int_float_comparison() {
    let engine = VectorEngine::new();
    let metadata = HashMap::from([(
        "value".to_string(),
        TensorValue::Scalar(ScalarValue::Float(50.5)),
    )]);
    engine
        .store_embedding_with_metadata("item", vec![1.0, 0.0], metadata)
        .unwrap();

    let filter = FilterCondition::Gt("value".to_string(), FilterValue::Int(50));
    let hits = engine
        .search_similar_filtered(&[1.0, 0.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
}
/// An integer metadata value (100) compared against float filter values:
/// `> 50.5` matches, while `> 100.0` does not (the comparison is strict).
#[test]
fn search_filtered_int_vs_float_filter() {
    let engine = VectorEngine::new();
    let metadata = HashMap::from([(
        "count".to_string(),
        TensorValue::Scalar(ScalarValue::Int(100)),
    )]);
    engine
        .store_embedding_with_metadata("item", vec![1.0, 0.0], metadata)
        .unwrap();

    // (threshold, expected number of matches)
    let cases = [(50.5, 1), (100.0, 0)];
    for &(threshold, expected) in cases.iter() {
        let filter = FilterCondition::Gt("count".to_string(), FilterValue::Float(threshold));
        let hits = engine
            .search_similar_filtered(&[1.0, 0.0], 10, &filter, None)
            .unwrap();
        assert_eq!(hits.len(), expected);
    }
}
/// `Eq(optional, Null)` matches an embedding whose field is explicitly null,
/// but not one that lacks the field altogether.
#[test]
fn search_filtered_null_comparison() {
    let engine = VectorEngine::new();
    let explicit_null = HashMap::from([(
        "optional".to_string(),
        TensorValue::Scalar(ScalarValue::Null),
    )]);
    engine
        .store_embedding_with_metadata("with_null", vec![1.0, 0.0], explicit_null)
        .unwrap();
    engine
        .store_embedding("without_field", vec![0.0, 1.0])
        .unwrap();

    let filter = FilterCondition::Eq("optional".to_string(), FilterValue::Null);
    let hits = engine
        .search_similar_filtered(&[1.0, 0.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "with_null");
}
/// Ordering comparisons on string metadata: both "apple" and "banana" are
/// greater than "app", while only "apple" is less than or equal to "apple".
#[test]
fn search_filtered_string_comparison() {
    let named = |name: &str| {
        HashMap::from([(
            "name".to_string(),
            TensorValue::Scalar(ScalarValue::String(name.to_string())),
        )])
    };

    let engine = VectorEngine::new();
    engine
        .store_embedding_with_metadata("item1", vec![1.0, 0.0], named("apple"))
        .unwrap();
    engine
        .store_embedding_with_metadata("item2", vec![0.0, 1.0], named("banana"))
        .unwrap();

    let after_app =
        FilterCondition::Gt("name".to_string(), FilterValue::String("app".to_string()));
    let gt_hits = engine
        .search_similar_filtered(&[1.0, 1.0], 10, &after_app, None)
        .unwrap();
    assert_eq!(gt_hits.len(), 2);

    let up_to_apple =
        FilterCondition::Le("name".to_string(), FilterValue::String("apple".to_string()));
    let le_hits = engine
        .search_similar_filtered(&[1.0, 1.0], 10, &up_to_apple, None)
        .unwrap();
    assert_eq!(le_hits.len(), 1);
    assert_eq!(le_hits[0].key, "item1");
}
/// `Eq(active, false)` matches only the embedding stored with
/// `active = false`.
#[test]
fn search_filtered_bool_false() {
    let with_active = |flag: bool| {
        HashMap::from([(
            "active".to_string(),
            TensorValue::Scalar(ScalarValue::Bool(flag)),
        )])
    };

    let engine = VectorEngine::new();
    engine
        .store_embedding_with_metadata("active_item", vec![1.0, 0.0], with_active(true))
        .unwrap();
    engine
        .store_embedding_with_metadata("inactive_item", vec![0.0, 1.0], with_active(false))
        .unwrap();

    let filter = FilterCondition::Eq("active".to_string(), FilterValue::Bool(false));
    let hits = engine
        .search_similar_filtered(&[1.0, 1.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "inactive_item");
}
/// Comparing a string metadata value against an integer filter value with
/// `Eq` matches nothing.
#[test]
fn search_filtered_incompatible_types() {
    let engine = VectorEngine::new();
    let textual = HashMap::from([(
        "value".to_string(),
        TensorValue::Scalar(ScalarValue::String("text".to_string())),
    )]);
    engine
        .store_embedding_with_metadata("item", vec![1.0, 0.0], textual)
        .unwrap();

    let filter = FilterCondition::Eq("value".to_string(), FilterValue::Int(42));
    let hits = engine
        .search_similar_filtered(&[1.0, 0.0], 10, &filter, None)
        .unwrap();

    assert!(hits.is_empty());
}
/// With ten stored items and a filter that accepts everything, the result
/// list is truncated to the requested `top_k` of 3.
#[test]
fn search_filtered_respects_top_k() {
    let engine = VectorEngine::new();
    for i in 0..10 {
        let metadata =
            HashMap::from([("idx".to_string(), TensorValue::Scalar(ScalarValue::Int(i)))]);
        let key = format!("item{i}");
        engine
            .store_embedding_with_metadata(&key, vec![i as f32, 0.0], metadata)
            .unwrap();
    }

    let hits = engine
        .search_similar_filtered(&[5.0, 0.0], 3, &FilterCondition::True, None)
        .unwrap();

    assert_eq!(hits.len(), 3);
}
/// A freshly created collection is reported as existing and has a
/// retrievable configuration.
#[test]
fn create_collection_basic() {
    let engine = VectorEngine::new();
    engine
        .create_collection("test", VectorCollectionConfig::default())
        .unwrap();

    assert!(engine.collection_exists("test"));
    assert!(engine.get_collection_config("test").is_some());
}
/// Creating a collection under a name that is already taken fails with
/// `CollectionExists`.
#[test]
fn create_collection_already_exists() {
    let engine = VectorEngine::new();
    let config = VectorCollectionConfig::default();
    engine.create_collection("test", config.clone()).unwrap();

    let second_attempt = engine.create_collection("test", config);
    assert!(matches!(
        second_attempt,
        Err(VectorError::CollectionExists(_))
    ));
}
/// Deleting a collection removes both the collection itself and the vectors
/// stored in it.
#[test]
fn delete_collection_basic() {
    let engine = VectorEngine::new();
    engine
        .create_collection("test", VectorCollectionConfig::default())
        .unwrap();
    engine
        .store_in_collection("test", "key1", vec![1.0, 2.0])
        .unwrap();
    engine
        .store_in_collection("test", "key2", vec![3.0, 4.0])
        .unwrap();
    let before = engine.collection_count("test");
    assert_eq!(before, 2);

    engine.delete_collection("test").unwrap();

    assert!(!engine.collection_exists("test"));
    let after = engine.collection_count("test");
    assert_eq!(after, 0);
}
/// Deleting a collection that does not exist fails with `CollectionNotFound`.
#[test]
fn delete_collection_not_found() {
    let engine = VectorEngine::new();
    assert!(matches!(
        engine.delete_collection("nonexistent"),
        Err(VectorError::CollectionNotFound(_))
    ));
}
/// `list_collections` returns every created collection; the result is sorted
/// before comparison, so no particular order is assumed.
#[test]
fn list_collections_basic() {
    let engine = VectorEngine::new();
    for name in ["alpha", "beta"] {
        engine
            .create_collection(name, VectorCollectionConfig::default())
            .unwrap();
    }

    let mut names = engine.list_collections();
    names.sort();
    assert_eq!(names, vec!["alpha", "beta"]);
}
/// A vector stored in an explicitly created collection can be read back
/// unchanged.
#[test]
fn store_in_collection_basic() {
    let engine = VectorEngine::new();
    engine
        .create_collection("products", VectorCollectionConfig::default())
        .unwrap();
    engine
        .store_in_collection("products", "item1", vec![1.0, 2.0, 3.0])
        .unwrap();

    let restored = engine.get_from_collection("products", "item1").unwrap();
    assert_eq!(restored, vec![1.0, 2.0, 3.0]);
}
/// Storing into a collection that was never explicitly created succeeds, and
/// the vector can be read back.
#[test]
fn store_in_collection_without_prior_create() {
    let engine = VectorEngine::new();
    engine
        .store_in_collection("auto_created", "key", vec![1.0])
        .unwrap();

    let restored = engine.get_from_collection("auto_created", "key").unwrap();
    assert_eq!(restored, vec![1.0]);
}
/// A collection configured with dimension 3 accepts 3-dimensional vectors and
/// rejects a 2-dimensional one with `DimensionMismatch`.
#[test]
fn store_in_collection_dimension_constraint() {
    let engine = VectorEngine::new();
    engine
        .create_collection(
            "fixed_dim",
            VectorCollectionConfig::default().with_dimension(3),
        )
        .unwrap();

    engine
        .store_in_collection("fixed_dim", "good", vec![1.0, 2.0, 3.0])
        .unwrap();

    let rejected = engine.store_in_collection("fixed_dim", "bad", vec![1.0, 2.0]);
    assert!(matches!(
        rejected,
        Err(VectorError::DimensionMismatch {
            expected: 3,
            got: 2
        })
    ));
}
/// Reading a key that was never stored in a collection yields `NotFound`.
#[test]
fn get_from_collection_not_found() {
    let engine = VectorEngine::new();
    assert!(matches!(
        engine.get_from_collection("coll", "nonexistent"),
        Err(VectorError::NotFound(_))
    ));
}
/// A key exists in its collection after being stored and no longer exists
/// after being deleted.
#[test]
fn delete_from_collection_basic() {
    let engine = VectorEngine::new();
    engine
        .store_in_collection("test", "key", vec![1.0])
        .unwrap();
    let present_before = engine.exists_in_collection("test", "key");
    assert!(present_before);

    engine.delete_from_collection("test", "key").unwrap();

    let present_after = engine.exists_in_collection("test", "key");
    assert!(!present_after);
}
/// Deleting a key that was never stored in a collection yields `NotFound`.
#[test]
fn delete_from_collection_not_found() {
    let engine = VectorEngine::new();
    assert!(matches!(
        engine.delete_from_collection("coll", "nonexistent"),
        Err(VectorError::NotFound(_))
    ));
}
/// `list_collection_keys` returns every key stored in the collection; the
/// result is sorted before comparison, so no particular order is assumed.
#[test]
fn list_collection_keys_basic() {
    let engine = VectorEngine::new();
    engine
        .store_in_collection("test", "alpha", vec![1.0])
        .unwrap();
    engine
        .store_in_collection("test", "beta", vec![2.0])
        .unwrap();

    let mut stored_keys = engine.list_collection_keys("test");
    stored_keys.sort();
    assert_eq!(stored_keys, vec!["alpha", "beta"]);
}
/// `collection_count` is 0 for an unknown collection and reflects the number
/// of stored vectors otherwise.
#[test]
fn collection_count_basic() {
    let engine = VectorEngine::new();
    let unknown = engine.collection_count("empty");
    assert_eq!(unknown, 0);

    engine.store_in_collection("test", "a", vec![1.0]).unwrap();
    engine.store_in_collection("test", "b", vec![2.0]).unwrap();

    let populated = engine.collection_count("test");
    assert_eq!(populated, 2);
}
/// Searching a collection of three axis-aligned unit vectors with the first
/// axis as the query returns `top_k` results, with `p1` ranked first.
#[test]
fn search_in_collection_basic() {
    let engine = VectorEngine::new();
    engine
        .store_in_collection("products", "p1", vec![1.0, 0.0, 0.0])
        .unwrap();
    engine
        .store_in_collection("products", "p2", vec![0.0, 1.0, 0.0])
        .unwrap();
    engine
        .store_in_collection("products", "p3", vec![0.0, 0.0, 1.0])
        .unwrap();

    let ranked = engine
        .search_in_collection("products", &[1.0, 0.0, 0.0], 2)
        .unwrap();

    assert_eq!(ranked.len(), 2);
    assert_eq!(ranked[0].key, "p1");
}
/// Searching a collection that holds nothing returns an empty result list
/// rather than an error.
#[test]
fn search_in_collection_empty() {
    let engine = VectorEngine::new();
    let ranked = engine
        .search_in_collection("empty", &[1.0, 2.0], 5)
        .unwrap();
    assert_eq!(ranked.len(), 0);
}
/// A 2-dimensional query against a collection configured with dimension 3
/// fails with `DimensionMismatch`.
#[test]
fn search_in_collection_dimension_constraint() {
    let engine = VectorEngine::new();
    engine
        .create_collection("fixed", VectorCollectionConfig::default().with_dimension(3))
        .unwrap();

    let rejected = engine.search_in_collection("fixed", &[1.0, 2.0], 5);
    assert!(matches!(
        rejected,
        Err(VectorError::DimensionMismatch {
            expected: 3,
            got: 2
        })
    ));
}
/// Filtered search inside a collection returns only the item whose
/// `category` equals "A".
#[test]
fn search_filtered_in_collection_basic() {
    let in_category = |label: &str| {
        HashMap::from([(
            "category".to_string(),
            TensorValue::Scalar(ScalarValue::String(label.to_string())),
        )])
    };

    let engine = VectorEngine::new();
    engine
        .store_in_collection_with_metadata("test", "item1", vec![1.0, 0.0], in_category("A"))
        .unwrap();
    engine
        .store_in_collection_with_metadata("test", "item2", vec![0.0, 1.0], in_category("B"))
        .unwrap();

    let filter =
        FilterCondition::Eq("category".to_string(), FilterValue::String("A".to_string()));
    let hits = engine
        .search_filtered_in_collection("test", &[1.0, 0.0], 10, &filter, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "item1");
}
/// The same key stored in two collections keeps a separate vector in each, and a search
/// in either collection only sees that collection's single entry.
#[test]
fn collection_isolation() {
    let engine = VectorEngine::new();
    engine
        .store_in_collection("coll_a", "key1", vec![1.0])
        .unwrap();
    engine
        .store_in_collection("coll_b", "key1", vec![2.0])
        .unwrap();

    let from_a = engine.get_from_collection("coll_a", "key1").unwrap();
    let from_b = engine.get_from_collection("coll_b", "key1").unwrap();
    assert_eq!(from_a, vec![1.0]);
    assert_eq!(from_b, vec![2.0]);

    let hits_a = engine.search_in_collection("coll_a", &[1.0], 10).unwrap();
    let hits_b = engine.search_in_collection("coll_b", &[1.0], 10).unwrap();
    assert_eq!(hits_a.len(), 1);
    assert_eq!(hits_b.len(), 1);
}
/// A key in the default store and the same key in a named collection hold separate vectors.
#[test]
fn collection_and_default_isolation() {
    let engine = VectorEngine::new();
    engine.store_embedding("key1", vec![1.0]).unwrap();
    engine
        .store_in_collection("named", "key1", vec![2.0])
        .unwrap();

    assert_eq!(engine.get_embedding("key1").unwrap(), vec![1.0]);
    assert_eq!(
        engine.get_from_collection("named", "key1").unwrap(),
        vec![2.0]
    );
}
/// `with_dimension` records the requested dimension on the config.
#[test]
fn collection_config_with_dimension() {
    let sized = VectorCollectionConfig::default().with_dimension(128);
    assert_eq!(sized.dimension, Some(128));
}
/// `with_metric` replaces the config's distance metric.
#[test]
fn collection_config_with_metric() {
    let euclid = VectorCollectionConfig::default().with_metric(DistanceMetric::Euclidean);
    assert_eq!(euclid.distance_metric, DistanceMetric::Euclidean);
}
/// `with_auto_index` turns auto-indexing on and stores the supplied threshold.
#[test]
fn collection_config_with_auto_index() {
    let indexed = VectorCollectionConfig::default().with_auto_index(500);
    assert!(indexed.auto_index);
    assert_eq!(indexed.auto_index_threshold, 500);
}
/// Defaults: no fixed dimension, cosine metric, auto-index off, threshold 1000.
#[test]
fn collection_config_default() {
    let defaults = VectorCollectionConfig::default();
    assert_eq!(defaults.dimension, None);
    assert_eq!(defaults.distance_metric, DistanceMetric::Cosine);
    assert!(!defaults.auto_index);
    assert_eq!(defaults.auto_index_threshold, 1000);
}
/// Each scalar `TensorValue` maps to the matching `MetadataValue` variant, while a
/// vector value has no metadata representation.
#[test]
fn metadata_value_from_tensor_value() {
    use tensor_store::ScalarValue;

    assert!(matches!(
        MetadataValue::from_tensor_value(&TensorValue::Scalar(ScalarValue::Null)),
        Some(MetadataValue::Null)
    ));
    assert!(matches!(
        MetadataValue::from_tensor_value(&TensorValue::Scalar(ScalarValue::Bool(true))),
        Some(MetadataValue::Bool(true))
    ));
    assert!(matches!(
        MetadataValue::from_tensor_value(&TensorValue::Scalar(ScalarValue::Int(42))),
        Some(MetadataValue::Int(42))
    ));

    // Floats are compared with a tolerance rather than by pattern.
    let float_source = TensorValue::Scalar(ScalarValue::Float(3.14));
    match MetadataValue::from_tensor_value(&float_source) {
        Some(MetadataValue::Float(f)) => assert!((f - 3.14).abs() < 1e-10),
        _ => panic!("Expected Float"),
    }

    let text_source = TensorValue::Scalar(ScalarValue::String("hello".to_string()));
    assert!(matches!(
        MetadataValue::from_tensor_value(&text_source),
        Some(MetadataValue::String(s)) if s == "hello"
    ));

    let vector_source = TensorValue::Vector(vec![1.0, 2.0, 3.0]);
    assert!(MetadataValue::from_tensor_value(&vector_source).is_none());
}
/// Every `MetadataValue` variant converts into the corresponding scalar `TensorValue`.
#[test]
fn metadata_value_to_tensor_value() {
    use tensor_store::ScalarValue;

    assert!(matches!(
        TensorValue::from(MetadataValue::Null),
        TensorValue::Scalar(ScalarValue::Null)
    ));
    assert!(matches!(
        TensorValue::from(MetadataValue::Bool(true)),
        TensorValue::Scalar(ScalarValue::Bool(true))
    ));
    assert!(matches!(
        TensorValue::from(MetadataValue::Int(42)),
        TensorValue::Scalar(ScalarValue::Int(42))
    ));

    // Floats are compared with a tolerance rather than by pattern.
    match TensorValue::from(MetadataValue::Float(3.14)) {
        TensorValue::Scalar(ScalarValue::Float(f)) => assert!((f - 3.14).abs() < 1e-10),
        _ => panic!("Expected Float"),
    }

    assert!(matches!(
        TensorValue::from(MetadataValue::String("hello".to_string())),
        TensorValue::Scalar(ScalarValue::String(s)) if s == "hello"
    ));
}
/// A freshly constructed index carries its collection name, no vectors, and the
/// current format version.
#[test]
fn persistent_vector_index_new() {
    let fresh = PersistentVectorIndex::new(String::from("test"), VectorCollectionConfig::default());
    assert_eq!(fresh.collection, "test");
    assert!(fresh.vectors.is_empty());
    assert_eq!(fresh.version, PersistentVectorIndex::CURRENT_VERSION);
}
/// `push` grows the index by one entry, with or without metadata attached.
#[test]
fn persistent_vector_index_push() {
    let mut snapshot =
        PersistentVectorIndex::new(String::from("test"), VectorCollectionConfig::default());
    assert!(snapshot.is_empty());

    snapshot.push(String::from("key1"), vec![1.0, 2.0], None);
    assert_eq!(snapshot.len(), 1);

    let mut tags = HashMap::new();
    tags.insert(
        String::from("tag"),
        MetadataValue::String(String::from("value")),
    );
    snapshot.push(String::from("key2"), vec![3.0, 4.0], Some(tags));
    assert_eq!(snapshot.len(), 2);
}
/// Snapshotting the default collection captures its name and both stored vectors.
#[test]
fn snapshot_collection_default() {
    let engine = VectorEngine::new();
    for (key, embedding) in [("vec1", vec![1.0, 2.0]), ("vec2", vec![3.0, 4.0])] {
        engine.store_embedding(key, embedding).unwrap();
    }

    let snapshot = engine.snapshot_collection(VectorEngine::DEFAULT_COLLECTION);
    assert_eq!(snapshot.collection, VectorEngine::DEFAULT_COLLECTION);
    assert_eq!(snapshot.len(), 2);
}
/// Metadata stored with an embedding is carried into the snapshot entry.
#[test]
fn snapshot_collection_with_metadata() {
    let engine = VectorEngine::new();
    let mut attrs = HashMap::new();
    attrs.insert(
        String::from("category"),
        TensorValue::Scalar(ScalarValue::String(String::from("test"))),
    );
    engine
        .store_embedding_with_metadata("vec1", vec![1.0, 2.0], attrs)
        .unwrap();

    let snapshot = engine.snapshot_collection(VectorEngine::DEFAULT_COLLECTION);
    assert_eq!(snapshot.len(), 1);

    let first = &snapshot.vectors[0];
    assert_eq!(first.key, "vec1");
    assert!(first.metadata.is_some());

    let stored = first.metadata.as_ref().unwrap();
    assert!(matches!(
        stored.get("category"),
        Some(MetadataValue::String(s)) if s == "test"
    ));
}
/// Snapshotting a named collection reports that collection's name and entry count.
#[test]
fn snapshot_collection_named() {
    let engine = VectorEngine::new();
    engine
        .create_collection("mycoll", VectorCollectionConfig::default())
        .unwrap();
    engine
        .store_in_collection("mycoll", "vec1", vec![1.0, 2.0])
        .unwrap();

    let snapshot = engine.snapshot_collection("mycoll");
    assert_eq!(snapshot.collection, "mycoll");
    assert_eq!(snapshot.len(), 1);
}
/// A JSON index written by one engine loads into a second engine with both vectors intact.
#[test]
fn save_and_load_index_json() {
    let scratch = tempfile::tempdir().unwrap();
    let index_file = scratch.path().join("index.json");

    let writer = VectorEngine::new();
    writer.store_embedding("vec1", vec![1.0, 2.0, 3.0]).unwrap();
    writer.store_embedding("vec2", vec![4.0, 5.0, 6.0]).unwrap();
    writer
        .save_index(VectorEngine::DEFAULT_COLLECTION, &index_file)
        .unwrap();

    let reader = VectorEngine::new();
    let loaded_name = reader.load_index(&index_file).unwrap();
    assert_eq!(loaded_name, VectorEngine::DEFAULT_COLLECTION);
    assert_eq!(reader.count(), 2);
    assert_eq!(reader.get_embedding("vec1").unwrap(), vec![1.0, 2.0, 3.0]);
    assert_eq!(reader.get_embedding("vec2").unwrap(), vec![4.0, 5.0, 6.0]);
}
/// A binary index written by one engine loads into a second engine with its vector intact.
#[test]
fn save_and_load_index_binary() {
    let scratch = tempfile::tempdir().unwrap();
    let index_file = scratch.path().join("index.bin");

    let writer = VectorEngine::new();
    writer.store_embedding("vec1", vec![1.0, 2.0, 3.0]).unwrap();
    writer
        .save_index_binary(VectorEngine::DEFAULT_COLLECTION, &index_file)
        .unwrap();

    let reader = VectorEngine::new();
    let loaded_name = reader.load_index_binary(&index_file).unwrap();
    assert_eq!(loaded_name, VectorEngine::DEFAULT_COLLECTION);
    assert_eq!(reader.get_embedding("vec1").unwrap(), vec![1.0, 2.0, 3.0]);
}
/// String and integer metadata survive a JSON save/load round trip.
#[test]
fn save_and_load_index_with_metadata() {
    let scratch = tempfile::tempdir().unwrap();
    let index_file = scratch.path().join("index.json");

    let writer = VectorEngine::new();
    let mut attrs = HashMap::new();
    attrs.insert(
        String::from("name"),
        TensorValue::Scalar(ScalarValue::String(String::from("test"))),
    );
    attrs.insert(
        String::from("score"),
        TensorValue::Scalar(ScalarValue::Int(42)),
    );
    writer
        .store_embedding_with_metadata("vec1", vec![1.0, 2.0], attrs)
        .unwrap();
    writer
        .save_index(VectorEngine::DEFAULT_COLLECTION, &index_file)
        .unwrap();

    let reader = VectorEngine::new();
    reader.load_index(&index_file).unwrap();

    let restored = reader.get_metadata("vec1").unwrap();
    assert!(matches!(
        restored.get("name"),
        Some(TensorValue::Scalar(ScalarValue::String(s))) if s == "test"
    ));
    assert!(matches!(
        restored.get("score"),
        Some(TensorValue::Scalar(ScalarValue::Int(42)))
    ));
}
/// Loading a saved named collection recreates it, with its vector, in a fresh engine.
#[test]
fn save_and_load_named_collection() {
    let scratch = tempfile::tempdir().unwrap();
    let index_file = scratch.path().join("mycoll.json");

    let writer = VectorEngine::new();
    writer
        .create_collection("mycoll", VectorCollectionConfig::default().with_dimension(3))
        .unwrap();
    writer
        .store_in_collection("mycoll", "vec1", vec![1.0, 2.0, 3.0])
        .unwrap();
    writer.save_index("mycoll", &index_file).unwrap();

    let reader = VectorEngine::new();
    let loaded_name = reader.load_index(&index_file).unwrap();
    assert_eq!(loaded_name, "mycoll");
    assert!(reader.collection_exists("mycoll"));
    assert_eq!(
        reader.get_from_collection("mycoll", "vec1").unwrap(),
        vec![1.0, 2.0, 3.0]
    );
}
/// `save_all_indices` reports the default and both named collections and writes one
/// `<name>.json` file per collection into the target directory.
#[test]
fn save_all_indices_basic() {
    let scratch = tempfile::tempdir().unwrap();
    let engine = VectorEngine::new();
    engine
        .store_embedding("default_vec", vec![1.0, 2.0])
        .unwrap();

    let named = [
        ("coll_a", "vec_a", vec![3.0, 4.0]),
        ("coll_b", "vec_b", vec![5.0, 6.0]),
    ];
    for (collection, key, embedding) in named {
        engine
            .create_collection(collection, VectorCollectionConfig::default())
            .unwrap();
        engine
            .store_in_collection(collection, key, embedding)
            .unwrap();
    }

    let written = engine.save_all_indices(scratch.path()).unwrap();
    assert_eq!(written.len(), 3);
    assert!(written.contains(&VectorEngine::DEFAULT_COLLECTION.to_string()));
    assert!(written.contains(&"coll_a".to_string()));
    assert!(written.contains(&"coll_b".to_string()));

    for file_name in ["default.json", "coll_a.json", "coll_b.json"] {
        assert!(scratch.path().join(file_name).exists());
    }
}
/// Everything written by `save_all_indices` is restored by `load_all_indices` in a
/// fresh engine: two collections, each with its vector.
#[test]
fn load_all_indices_basic() {
    let scratch = tempfile::tempdir().unwrap();

    let writer = VectorEngine::new();
    writer
        .store_embedding("default_vec", vec![1.0, 2.0])
        .unwrap();
    writer
        .create_collection("coll_a", VectorCollectionConfig::default())
        .unwrap();
    writer
        .store_in_collection("coll_a", "vec_a", vec![3.0, 4.0])
        .unwrap();
    writer.save_all_indices(scratch.path()).unwrap();

    let reader = VectorEngine::new();
    let restored = reader.load_all_indices(scratch.path()).unwrap();
    assert_eq!(restored.len(), 2);
    assert_eq!(reader.get_embedding("default_vec").unwrap(), vec![1.0, 2.0]);
    assert_eq!(
        reader.get_from_collection("coll_a", "vec_a").unwrap(),
        vec![3.0, 4.0]
    );
}
/// A collection with no vectors is not reported as saved and produces no file.
#[test]
fn save_empty_collection_skipped() {
    let scratch = tempfile::tempdir().unwrap();
    let engine = VectorEngine::new();
    engine
        .create_collection("empty", VectorCollectionConfig::default())
        .unwrap();

    let written = engine.save_all_indices(scratch.path()).unwrap();
    assert!(written.is_empty());
    assert!(!scratch.path().join("empty.json").exists());
}
/// Dimension, metric, and auto-index settings of a collection survive save + load.
#[test]
fn index_roundtrip_preserves_collection_config() {
    let scratch = tempfile::tempdir().unwrap();
    let index_file = scratch.path().join("coll.json");

    let writer = VectorEngine::new();
    let custom = VectorCollectionConfig::default()
        .with_dimension(128)
        .with_metric(DistanceMetric::Euclidean)
        .with_auto_index(500);
    writer.create_collection("custom", custom).unwrap();
    writer
        .store_in_collection("custom", "vec1", vec![1.0; 128])
        .unwrap();
    writer.save_index("custom", &index_file).unwrap();

    let reader = VectorEngine::new();
    reader.load_index(&index_file).unwrap();

    let restored = reader.get_collection_config("custom").unwrap();
    assert_eq!(restored.dimension, Some(128));
    assert_eq!(restored.distance_metric, DistanceMetric::Euclidean);
    assert!(restored.auto_index);
    assert_eq!(restored.auto_index_threshold, 500);
}
/// Loading from a path that does not exist surfaces as `VectorError::IoError`.
#[test]
fn load_index_io_error() {
    let outcome = VectorEngine::new().load_index("/nonexistent/path/index.json");
    assert!(matches!(outcome, Err(VectorError::IoError(_))));
}
/// A file that is not valid JSON is reported as `VectorError::SerializationError`.
#[test]
fn load_index_invalid_json() {
    let scratch = tempfile::tempdir().unwrap();
    let bad_file = scratch.path().join("invalid.json");
    fs::write(&bad_file, "not valid json").unwrap();

    let outcome = VectorEngine::new().load_index(&bad_file);
    assert!(matches!(outcome, Err(VectorError::SerializationError(_))));
}
/// Three bytes of `0xFF` are rejected by the binary loader as a
/// `VectorError::SerializationError`.
#[test]
fn load_index_binary_invalid() {
    let scratch = tempfile::tempdir().unwrap();
    let bad_file = scratch.path().join("invalid.bin");
    fs::write(&bad_file, [0xFF, 0xFF, 0xFF]).unwrap();

    let outcome = VectorEngine::new().load_index_binary(&bad_file);
    assert!(matches!(outcome, Err(VectorError::SerializationError(_))));
}
/// A 200-byte JSON file is refused when `max_index_file_bytes` is 100.
#[test]
fn load_index_file_size_limit() {
    let scratch = tempfile::tempdir().unwrap();
    let oversized = scratch.path().join("large.json");
    fs::write(&oversized, vec![b'x'; 200]).unwrap();

    let limited = VectorEngineConfig::default().with_max_index_file_bytes(100);
    let engine = VectorEngine::with_config(limited).unwrap();

    let outcome = engine.load_index(&oversized);
    assert!(matches!(
        outcome,
        Err(VectorError::ConfigurationError(msg)) if msg.contains("exceeds limit")
    ));
}
/// A 200-byte binary file is refused when `max_index_file_bytes` is 100.
#[test]
fn load_index_binary_file_size_limit() {
    let scratch = tempfile::tempdir().unwrap();
    let oversized = scratch.path().join("large.bin");
    fs::write(&oversized, vec![0u8; 200]).unwrap();

    let limited = VectorEngineConfig::default().with_max_index_file_bytes(100);
    let engine = VectorEngine::with_config(limited).unwrap();

    let outcome = engine.load_index_binary(&oversized);
    assert!(matches!(
        outcome,
        Err(VectorError::ConfigurationError(msg)) if msg.contains("exceeds limit")
    ));
}
/// An index holding five entries is refused when `max_index_entries` is 2.
///
/// The fixture is stamped with `PersistentVectorIndex::CURRENT_VERSION` instead of a
/// hard-coded literal so that it follows the format version if that constant changes.
/// NOTE(review): presumably the loader also inspects `version`; confirm against
/// `load_index` that the entry-count check is what this fixture trips.
#[test]
fn load_index_entry_count_limit() {
    let dir = tempfile::tempdir().unwrap();
    let path = dir.path().join("index.json");

    // Build the on-disk payload by hand so the entry count is known exactly.
    let index = PersistentVectorIndex {
        collection: "test".to_string(),
        config: VectorCollectionConfig::default(),
        vectors: (0..5)
            .map(|i| VectorEntry {
                key: format!("key{i}"),
                vector: vec![1.0, 2.0, 3.0],
                metadata: None,
            })
            .collect(),
        created_at: 0,
        version: PersistentVectorIndex::CURRENT_VERSION,
    };
    let json = serde_json::to_string(&index).unwrap();
    fs::write(&path, json).unwrap();

    let config = VectorEngineConfig::default().with_max_index_entries(2);
    let engine = VectorEngine::with_config(config).unwrap();

    let result = engine.load_index(&path);
    assert!(matches!(
        result,
        Err(VectorError::ConfigurationError(msg))
            if msg.contains("entry count") && msg.contains("exceeds limit")
    ));
}
/// A `VectorEntry` without metadata round-trips through JSON unchanged.
#[test]
fn vector_entry_serialization() {
    let original = VectorEntry {
        key: String::from("test"),
        vector: vec![1.0, 2.0, 3.0],
        metadata: None,
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: VectorEntry = serde_json::from_str(&encoded).unwrap();

    assert_eq!(decoded.key, "test");
    assert_eq!(decoded.vector, vec![1.0, 2.0, 3.0]);
    assert!(decoded.metadata.is_none());
}
/// Each `MetadataValue` variant decodes from its own JSON encoding as the same variant
/// with an equal payload (floats compared with a tolerance).
#[test]
fn metadata_value_serialization() {
    let samples = vec![
        MetadataValue::Null,
        MetadataValue::Bool(true),
        MetadataValue::Int(42),
        MetadataValue::Float(3.14),
        MetadataValue::String("hello".to_string()),
    ];

    for sample in samples {
        let encoded = serde_json::to_string(&sample).unwrap();
        let decoded: MetadataValue = serde_json::from_str(&encoded).unwrap();

        match (&sample, &decoded) {
            (MetadataValue::Null, MetadataValue::Null) => {},
            (MetadataValue::Bool(before), MetadataValue::Bool(after)) => {
                assert_eq!(before, after);
            },
            (MetadataValue::Int(before), MetadataValue::Int(after)) => {
                assert_eq!(before, after);
            },
            (MetadataValue::Float(before), MetadataValue::Float(after)) => {
                assert!((before - after).abs() < 1e-10);
            },
            (MetadataValue::String(before), MetadataValue::String(after)) => {
                assert_eq!(before, after);
            },
            _ => panic!("Type mismatch"),
        }
    }
}
/// `with_search_timeout` stores the given duration on the engine config.
#[test]
fn config_with_search_timeout() {
    let five_seconds = Duration::from_secs(5);
    let timed = VectorEngineConfig::default().with_search_timeout(five_seconds);
    assert_eq!(timed.search_timeout, Some(Duration::from_secs(5)));
}
/// The rendered `SearchTimeout` error mentions both the operation and the timeout value.
#[test]
fn search_timeout_error_display() {
    let rendered = VectorError::SearchTimeout {
        operation: String::from("search_similar"),
        timeout_ms: 5000,
    }
    .to_string();

    assert!(rendered.contains("search_similar"));
    assert!(rendered.contains("5000"));
}
/// With a 1 ns search timeout and 1000 stored vectors, `search_similar` reports
/// `SearchTimeout`.
#[test]
fn search_similar_respects_timeout() {
    let tiny_budget = VectorEngineConfig::default().with_search_timeout(Duration::from_nanos(1));
    let engine = VectorEngine::with_config(tiny_budget).unwrap();

    for n in 0..1000 {
        let key = format!("v{n}");
        engine.store_embedding(&key, vec![n as f32; 128]).unwrap();
    }

    let outcome = engine.search_similar(&[0.5f32; 128], 10);
    assert!(matches!(outcome, Err(VectorError::SearchTimeout { .. })));
}
/// An engine built with `VectorEngine::new()` completes a search over 100 vectors.
#[test]
fn search_similar_no_timeout_when_none() {
    let engine = VectorEngine::new();
    for n in 0..100 {
        let key = format!("v{n}");
        engine.store_embedding(&key, vec![n as f32, 0.0]).unwrap();
    }

    let outcome = engine.search_similar(&[50.0, 0.0], 10);
    assert!(outcome.is_ok());
}
/// `Deadline::never()` is not expired and reports a timeout of 0 ms.
#[test]
fn deadline_never_does_not_expire() {
    let unbounded = Deadline::never();
    assert!(!unbounded.is_expired());
    assert_eq!(unbounded.timeout_ms(), 0);
}
/// A 1 ns deadline has expired once the thread has slept for a millisecond.
#[test]
fn deadline_from_duration_expires() {
    let short_lived = Deadline::from_duration(Some(Duration::from_nanos(1)));
    std::thread::sleep(Duration::from_millis(1));
    assert!(short_lived.is_expired());
}
/// A deadline built from `None` is not expired.
#[test]
fn deadline_none_duration_never_expires() {
    let unbounded = Deadline::from_duration(None);
    assert!(!unbounded.is_expired());
}
/// The low-memory preset carries a 30 second search timeout.
#[test]
fn low_memory_config_has_timeout() {
    let preset = VectorEngineConfig::low_memory();
    assert_eq!(preset.search_timeout, Some(Duration::from_secs(30)));
}
/// The high-throughput preset leaves the search timeout unset.
#[test]
fn high_throughput_config_has_no_timeout() {
    let preset = VectorEngineConfig::high_throughput();
    assert!(preset.search_timeout.is_none());
}
/// With a 1 ns search timeout and 1000 stored vectors, `search_similar_with_metric`
/// reports `SearchTimeout`.
#[test]
fn search_with_metric_respects_timeout() {
    let tiny_budget = VectorEngineConfig::default().with_search_timeout(Duration::from_nanos(1));
    let engine = VectorEngine::with_config(tiny_budget).unwrap();

    for n in 0..1000 {
        let key = format!("v{n}");
        engine.store_embedding(&key, vec![n as f32; 128]).unwrap();
    }

    let outcome = engine.search_similar_with_metric(&[0.5f32; 128], 10, DistanceMetric::Cosine);
    assert!(matches!(outcome, Err(VectorError::SearchTimeout { .. })));
}
/// With a 1 ns search timeout and 1000 entity embeddings, `search_entities` reports
/// `SearchTimeout`.
#[test]
fn search_entities_respects_timeout() {
    let tiny_budget = VectorEngineConfig::default().with_search_timeout(Duration::from_nanos(1));
    let engine = VectorEngine::with_config(tiny_budget).unwrap();

    for n in 0..1000 {
        let entity_key = format!("entity:{n}");
        engine
            .set_entity_embedding(&entity_key, vec![n as f32; 128])
            .unwrap();
    }

    let outcome = engine.search_entities(&[0.5f32; 128], 10);
    assert!(matches!(outcome, Err(VectorError::SearchTimeout { .. })));
}
/// Default HNSW build options use dense storage.
#[test]
fn hnsw_build_options_default() {
    let defaults = HNSWBuildOptions::default();
    assert_eq!(defaults.storage, HNSWStorageStrategy::Dense);
}
/// `HNSWBuildOptions::new()` uses dense storage.
#[test]
fn hnsw_build_options_new() {
    let fresh = HNSWBuildOptions::new();
    assert_eq!(fresh.storage, HNSWStorageStrategy::Dense);
}
/// The memory-optimized preset selects quantized storage.
#[test]
fn hnsw_build_options_memory_optimized() {
    let preset = HNSWBuildOptions::memory_optimized();
    assert_eq!(preset.storage, HNSWStorageStrategy::Quantized);
}
/// The high-recall preset selects dense storage.
#[test]
fn hnsw_build_options_high_recall() {
    let preset = HNSWBuildOptions::high_recall();
    assert_eq!(preset.storage, HNSWStorageStrategy::Dense);
}
/// The sparse-optimized preset selects the `Auto` storage strategy.
#[test]
fn hnsw_build_options_sparse_optimized() {
    let preset = HNSWBuildOptions::sparse_optimized();
    assert_eq!(preset.storage, HNSWStorageStrategy::Auto);
}
/// `with_storage` and `with_sparsity_threshold` both take effect on the built options.
#[test]
fn hnsw_build_options_builder_methods() {
    let built = HNSWBuildOptions::new()
        .with_storage(HNSWStorageStrategy::Quantized)
        .with_sparsity_threshold(0.7);

    assert_eq!(built.storage, HNSWStorageStrategy::Quantized);
    let threshold_delta = (built.hnsw_config.sparsity_threshold - 0.7).abs();
    assert!(threshold_delta < f32::EPSILON);
}
/// `with_hnsw_config` carries over `m` and `ef_construction` from the supplied config.
#[test]
fn hnsw_build_options_with_hnsw_config() {
    let recall_cfg = HNSWConfig::high_recall();
    let built = HNSWBuildOptions::new().with_hnsw_config(recall_cfg.clone());

    assert_eq!(built.hnsw_config.m, recall_cfg.m);
    assert_eq!(
        built.hnsw_config.ef_construction,
        recall_cfg.ef_construction
    );
}
/// Building with dense storage indexes all three vectors and the index answers a query.
#[test]
fn build_hnsw_index_with_options_dense() {
    let engine = VectorEngine::new();
    let axes = [
        ("a", vec![1.0, 0.0, 0.0]),
        ("b", vec![0.0, 1.0, 0.0]),
        ("c", vec![0.0, 0.0, 1.0]),
    ];
    for (key, embedding) in axes {
        engine.store_embedding(key, embedding).unwrap();
    }

    let dense = HNSWBuildOptions::new().with_storage(HNSWStorageStrategy::Dense);
    let (graph, key_map) = engine.build_hnsw_index_with_options(dense).unwrap();
    assert_eq!(key_map.len(), 3);
    assert_eq!(graph.len(), 3);

    let hits = graph.search(&[1.0, 0.0, 0.0], 1);
    assert!(!hits.is_empty());
}
/// The sparse-optimized options index two mostly-zero vectors.
#[test]
fn build_hnsw_index_with_options_auto() {
    let engine = VectorEngine::new();
    let mostly_zero = [
        (
            "sparse1",
            vec![1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
        ),
        (
            "sparse2",
            vec![0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0],
        ),
    ];
    for (key, embedding) in mostly_zero {
        engine.store_embedding(key, embedding).unwrap();
    }

    let (graph, key_map) = engine
        .build_hnsw_index_with_options(HNSWBuildOptions::sparse_optimized())
        .unwrap();
    assert_eq!(key_map.len(), 2);
    assert_eq!(graph.len(), 2);
}
/// Building with the memory-optimized options indexes all three vectors and the index
/// answers a query.
#[test]
fn build_hnsw_index_with_options_quantized() {
    let engine = VectorEngine::new();
    let axes = [
        ("a", vec![1.0, 0.0, 0.0]),
        ("b", vec![0.0, 1.0, 0.0]),
        ("c", vec![0.0, 0.0, 1.0]),
    ];
    for (key, embedding) in axes {
        engine.store_embedding(key, embedding).unwrap();
    }

    let (graph, key_map) = engine
        .build_hnsw_index_with_options(HNSWBuildOptions::memory_optimized())
        .unwrap();
    assert_eq!(key_map.len(), 3);
    assert_eq!(graph.len(), 3);

    let hits = graph.search(&[1.0, 0.0, 0.0], 1);
    assert!(!hits.is_empty());
}
/// Building from an empty engine succeeds with no keys and an empty index.
#[test]
fn build_hnsw_index_with_options_empty_store() {
    let engine = VectorEngine::new();
    let (graph, key_map) = engine
        .build_hnsw_index_with_options(HNSWBuildOptions::new())
        .unwrap();
    assert!(key_map.is_empty());
    assert_eq!(graph.len(), 0);
}
/// Stored vectors of differing lengths make the HNSW build fail with `DimensionMismatch`.
#[test]
fn build_hnsw_index_with_options_dimension_mismatch() {
    let engine = VectorEngine::new();
    engine.store_embedding("a", vec![1.0, 2.0, 3.0]).unwrap();
    engine.store_embedding("b", vec![1.0, 2.0]).unwrap();

    let outcome = engine.build_hnsw_index_with_options(HNSWBuildOptions::new());
    assert!(matches!(outcome, Err(VectorError::DimensionMismatch { .. })));
}
/// Checks scalar quantization in three stages:
/// 1. the per-component error of a quantize/dequantize round trip stays within
///    `range / 255 + 0.01`;
/// 2. the cosine distance between a quantized and a dense vector is within 0.05 of
///    the distance computed on the dense originals;
/// 3. an HNSW index built with the memory-optimized options returns `k` results whose
///    scores lie in `[-1, 1]` and are in non-increasing order.
#[test]
fn quantized_search_recall() {
use tensor_store::ScalarQuantizedVector;
// Stage 1: round-trip a 64-component sine wave and take the largest absolute error.
let original: Vec<f32> = (0..64).map(|i| (i as f32 * 0.1).sin() * 5.0).collect();
let quantized = ScalarQuantizedVector::from_dense(&original);
let dequantized = quantized.dequantize();
let max_error: f32 = original
.iter()
.zip(dequantized.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0_f32, f32::max);
// Value range (max - min) of the original data; the error bound is scaled by it.
let range = original.iter().copied().fold(f32::NEG_INFINITY, f32::max)
- original.iter().copied().fold(f32::INFINITY, f32::min);
assert!(
max_error <= range / 255.0 + 0.01,
"Quantization error ({max_error}) too large for range ({range})"
);
// Stage 2: compare a hand-computed dense cosine distance with the quantized one.
let a: Vec<f32> = (0..64).map(|i| (i as f32 * 0.2).sin() * 3.0).collect();
let b: Vec<f32> = (0..64).map(|i| (i as f32 * 0.3).cos() * 4.0).collect();
let qa = ScalarQuantizedVector::from_dense(&a);
let dense_dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let dense_mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let dense_mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
// Reference value: 1 - (a . b) / (|a| * |b|).
let dense_cosine = 1.0 - dense_dot / (dense_mag_a * dense_mag_b);
let quantized_cosine = qa.cosine_distance_dense(&b);
assert!(
(dense_cosine - quantized_cosine).abs() < 0.05,
"Cosine distance mismatch: dense={dense_cosine}, quantized={quantized_cosine}"
);
// Stage 3: index 50 deterministic 32-component vectors and query the index.
let engine = VectorEngine::new();
for i in 0..50 {
let v: Vec<f32> = (0..32)
.map(|d| ((i * 7 + d * 3) as f32 * 0.5).sin() * 4.0)
.collect();
engine.store_embedding(&format!("v{i}"), v).unwrap();
}
let (index, mapping) = engine
.build_hnsw_index_with_options(HNSWBuildOptions::memory_optimized())
.unwrap();
assert_eq!(mapping.len(), 50);
// The query uses the same generator with seed 999, which is outside the stored 0..50.
let query: Vec<f32> = (0..32)
.map(|d| ((999 * 7 + d * 3) as f32 * 0.5).sin() * 4.0)
.collect();
let results = index.search(&query, 5);
assert_eq!(results.len(), 5, "Should return requested k results");
// Every returned score must lie within [-1, 1].
for (_, sim) in &results {
assert!(*sim >= -1.0, "Similarity should be >= -1, got {sim}");
assert!(*sim <= 1.0, "Similarity should be <= 1, got {sim}");
}
// Adjacent results must be in non-increasing score order, allowing f32::EPSILON slack.
for pair in results.windows(2) {
assert!(
pair[0].1 >= pair[1].1 - f32::EPSILON,
"Results not sorted: {} < {}",
pair[0].1,
pair[1].1
);
}
}
/// The default storage strategy is `Dense`.
#[test]
fn hnsw_storage_strategy_default() {
    let fallback = HNSWStorageStrategy::default();
    assert_eq!(fallback, HNSWStorageStrategy::Dense);
}
/// An engine opened with a WAL path reports itself durable and supports store + get.
#[test]
fn test_open_durable() {
    let scratch = tempfile::tempdir().unwrap();
    let wal_file = scratch.path().join("vector.wal");

    let engine = VectorEngine::open_durable(&wal_file, WalConfig::default()).unwrap();
    assert!(engine.is_durable());

    engine
        .store_embedding("test_key", vec![1.0, 2.0, 3.0])
        .unwrap();
    assert!(engine.get_embedding("test_key").is_ok());
}
/// After a durable engine is opened and dropped, `recover` on the same WAL path
/// succeeds and yields a durable engine.
#[test]
fn test_recover_durable() {
    let scratch = tempfile::tempdir().unwrap();
    let wal_file = scratch.path().join("vector.wal");

    // Open and immediately drop the first engine before recovering from its WAL.
    drop(VectorEngine::open_durable(&wal_file, WalConfig::default()).unwrap());

    let outcome = VectorEngine::recover(&wal_file, &WalConfig::default(), None);
    assert!(outcome.is_ok());
    assert!(outcome.unwrap().is_durable());
}
/// An engine created with `VectorEngine::new()` is not durable.
#[test]
fn test_is_durable_false_for_in_memory() {
    let in_memory = VectorEngine::new();
    assert!(!in_memory.is_durable());
}
/// Default IVF build options use 100 clusters.
#[test]
fn ivf_build_options_default() {
    let defaults = IVFBuildOptions::default();
    assert_eq!(defaults.config.num_clusters, 100);
}
/// `IVFBuildOptions::new()` uses 100 clusters.
#[test]
fn ivf_build_options_new() {
    let fresh = IVFBuildOptions::new();
    assert_eq!(fresh.config.num_clusters, 100);
}
/// `flat(50)` sets 50 clusters and flat storage.
#[test]
fn ivf_build_options_flat() {
    let flat = IVFBuildOptions::flat(50);
    assert_eq!(flat.config.num_clusters, 50);
    assert!(matches!(flat.config.storage, IVFStorage::Flat));
}
/// `pq(50, ..)` sets 50 clusters and PQ storage.
#[test]
fn ivf_build_options_pq() {
    let product_quantized = IVFBuildOptions::pq(50, PQConfig::default());
    assert_eq!(product_quantized.config.num_clusters, 50);
    assert!(matches!(
        product_quantized.config.storage,
        IVFStorage::PQ(_)
    ));
}
/// `binary(50)` sets 50 clusters and binary storage.
#[test]
fn ivf_build_options_binary() {
    let binary = IVFBuildOptions::binary(50);
    assert_eq!(binary.config.num_clusters, 50);
    assert!(matches!(binary.config.storage, IVFStorage::Binary(_)));
}
/// `with_nprobe` stores the probe count on the IVF config.
#[test]
fn ivf_build_options_with_nprobe() {
    let probed = IVFBuildOptions::flat(50).with_nprobe(10);
    assert_eq!(probed.config.nprobe, 10);
}
/// `with_num_clusters` overrides the cluster count.
#[test]
fn ivf_build_options_with_num_clusters() {
    let resized = IVFBuildOptions::new().with_num_clusters(200);
    assert_eq!(resized.config.num_clusters, 200);
}
/// `with_storage` sets the IVF storage kind.
#[test]
fn ivf_build_options_with_storage() {
    let flat = IVFBuildOptions::new().with_storage(IVFStorage::Flat);
    assert!(matches!(flat.config.storage, IVFStorage::Flat));
}
/// `with_dimension(128)` yields a config whose dimension is `Some(128)`.
#[test]
fn vector_collection_config_with_dimension() {
    let sized = VectorCollectionConfig::default().with_dimension(128);
    assert_eq!(sized.dimension, Some(128));
}
/// `with_metric(Euclidean)` yields a config using the Euclidean metric.
#[test]
fn vector_collection_config_with_metric() {
    let euclid = VectorCollectionConfig::default().with_metric(DistanceMetric::Euclidean);
    assert_eq!(euclid.distance_metric, DistanceMetric::Euclidean);
}
/// `with_auto_index(500)` enables auto-indexing with a threshold of 500.
#[test]
fn vector_collection_config_with_auto_index() {
    let indexed = VectorCollectionConfig::default().with_auto_index(500);
    assert!(indexed.auto_index);
    assert_eq!(indexed.auto_index_threshold, 500);
}
/// A bytes scalar has no `MetadataValue` representation.
#[test]
fn metadata_value_from_tensor_value_bytes() {
    let raw_bytes = TensorValue::Scalar(ScalarValue::Bytes(vec![1, 2, 3]));
    let converted = MetadataValue::from_tensor_value(&raw_bytes);
    assert!(converted.is_none());
}
/// A vector tensor value has no `MetadataValue` representation.
#[test]
fn metadata_value_from_tensor_value_vector() {
    let dense = TensorValue::Vector(vec![1.0, 2.0, 3.0]);
    let converted = MetadataValue::from_tensor_value(&dense);
    assert!(converted.is_none());
}
/// `DistanceMetric::Euclidean` converts to the extended Euclidean metric.
#[test]
fn distance_metric_to_extended_euclidean() {
    let converted: ExtendedDistanceMetric = DistanceMetric::Euclidean.into();
    assert!(matches!(converted, ExtendedDistanceMetric::Euclidean));
}
/// `DistanceMetric::Cosine` converts to the extended Cosine metric.
#[test]
fn distance_metric_to_extended_cosine() {
    let converted: ExtendedDistanceMetric = DistanceMetric::Cosine.into();
    assert!(matches!(converted, ExtendedDistanceMetric::Cosine));
}
/// `DistanceMetric::DotProduct` is expected to convert to the extended Cosine metric.
#[test]
fn distance_metric_to_extended_dot_product() {
    let converted: ExtendedDistanceMetric = DistanceMetric::DotProduct.into();
    assert!(matches!(converted, ExtendedDistanceMetric::Cosine));
}
/// The `CollectionExists` message names the collection and says "already exists".
#[test]
fn vector_error_display_collection_exists() {
    let rendered = VectorError::CollectionExists(String::from("my_collection")).to_string();
    assert!(rendered.contains("my_collection"));
    assert!(rendered.contains("already exists"));
}
/// The `CollectionNotFound` message names the collection and says "not found".
#[test]
fn vector_error_display_collection_not_found() {
    let rendered =
        VectorError::CollectionNotFound(String::from("missing_collection")).to_string();
    assert!(rendered.contains("missing_collection"));
    assert!(rendered.contains("not found"));
}
/// The `IoError` message includes the cause and the "IO error" prefix.
#[test]
fn vector_error_display_io_error() {
    let rendered = VectorError::IoError(String::from("disk full")).to_string();
    assert!(rendered.contains("disk full"));
    assert!(rendered.contains("IO error"));
}
/// The `SerializationError` message includes the cause and the word "Serialization".
#[test]
fn vector_error_display_serialization_error() {
    let rendered = VectorError::SerializationError(String::from("invalid format")).to_string();
    assert!(rendered.contains("invalid format"));
    assert!(rendered.contains("Serialization"));
}
/// An `i64` converts into `FilterValue::Int`.
#[test]
fn filter_value_from_i64() {
    let converted: FilterValue = 42i64.into();
    assert!(matches!(converted, FilterValue::Int(42)));
}
/// An `f64` converts into `FilterValue::Float` carrying the same value.
#[test]
fn filter_value_from_f64() {
    let converted: FilterValue = 3.14f64.into();
    assert!(matches!(converted, FilterValue::Float(f) if (f - 3.14).abs() < f64::EPSILON));
}
/// An owned `String` converts into `FilterValue::String`.
#[test]
fn filter_value_from_string() {
    let converted: FilterValue = String::from("hello").into();
    assert!(matches!(converted, FilterValue::String(s) if s == "hello"));
}
/// A string slice converts into `FilterValue::String`.
#[test]
fn filter_value_from_str() {
    let converted: FilterValue = "world".into();
    assert!(matches!(converted, FilterValue::String(s) if s == "world"));
}
/// A `bool` converts into `FilterValue::Bool`.
#[test]
fn filter_value_from_bool() {
    let converted: FilterValue = true.into();
    assert!(matches!(converted, FilterValue::Bool(true)));
}
/// A `never` deadline is unexpired and has a 0 ms timeout.
#[test]
fn deadline_never_not_expired() {
    let unbounded = Deadline::never();
    assert!(!unbounded.is_expired());
    assert_eq!(unbounded.timeout_ms(), 0);
}
/// A 100 ms deadline is not expired right after creation and reports 100 ms.
#[test]
fn deadline_from_duration_some() {
    let hundred_ms = Deadline::from_duration(Some(Duration::from_millis(100)));
    assert!(!hundred_ms.is_expired());
    assert_eq!(hundred_ms.timeout_ms(), 100);
}
/// A deadline built from `None` is unexpired and reports a 0 ms timeout.
#[test]
fn deadline_from_duration_none() {
    let unbounded = Deadline::from_duration(None);
    assert!(!unbounded.is_expired());
    assert_eq!(unbounded.timeout_ms(), 0);
}
/// `SearchResult::new` stores the key and score it is given.
#[test]
fn search_result_new() {
    let hit = SearchResult::new(String::from("test_key"), 0.95);
    assert_eq!(hit.key, "test_key");
    let score_delta = (hit.score - 0.95).abs();
    assert!(score_delta < f32::EPSILON);
}
/// `EmbeddingInput::new` stores the key and vector it is given.
#[test]
fn embedding_input_constructor() {
    let entry = EmbeddingInput::new("my_key", vec![1.0, 2.0, 3.0]);
    assert_eq!(entry.key, "my_key");
    assert_eq!(entry.vector, vec![1.0, 2.0, 3.0]);
}
/// Deleting two of three stored keys reports 2 and leaves only the third behind.
#[test]
fn batch_delete_multiple_keys() {
    let engine = VectorEngine::new();
    let stored = [
        ("a", vec![1.0, 0.0]),
        ("b", vec![0.0, 1.0]),
        ("c", vec![1.0, 1.0]),
    ];
    for (key, embedding) in stored {
        engine.store_embedding(key, embedding).unwrap();
    }

    let removed = engine
        .batch_delete_embeddings(vec![String::from("a"), String::from("b")])
        .unwrap();
    assert_eq!(removed, 2);
    assert!(!engine.exists("a"));
    assert!(!engine.exists("b"));
    assert!(engine.exists("c"));
}
/// Keys that are not stored are not counted in the batch-delete total.
#[test]
fn batch_delete_partial_exists() {
    let engine = VectorEngine::new();
    engine.store_embedding("a", vec![1.0, 0.0]).unwrap();

    let targets = vec![String::from("a"), String::from("nonexistent")];
    let removed = engine.batch_delete_embeddings(targets).unwrap();
    assert_eq!(removed, 1);
}
/// Covers the `Pagination` constructors: `new` sets skip and limit without a total
/// count, `with_total` turns the count on, and `skip_only` leaves the limit unset.
#[test]
fn pagination_constructor_variants() {
    let bounded = Pagination::new(10, 20);
    assert_eq!(bounded.skip, 10);
    assert_eq!(bounded.limit, Some(20));
    assert!(!bounded.count_total);

    let counted = bounded.with_total();
    assert!(counted.count_total);

    let open_ended = Pagination::skip_only(5);
    assert_eq!(open_ended.skip, 5);
    assert!(open_ended.limit.is_none());
}
/// Skipping 2 of 5 keys with no limit returns the remaining 3.
#[test]
fn list_keys_paginated_no_limit_variant() {
    let engine = VectorEngine::new();
    for n in 0..5 {
        let key = format!("key{n}");
        engine.store_embedding(&key, vec![n as f32, 0.0]).unwrap();
    }

    let page = engine.list_keys_paginated(Pagination::skip_only(2));
    assert_eq!(page.items.len(), 3);
}
/// A page of 3 without `count_total` has three items, no total count, and
/// `has_more == false`.
#[test]
fn search_entities_paginated_no_count_variant() {
    let engine = VectorEngine::new();
    for n in 0..5 {
        let entity_key = format!("entity:{n}");
        engine
            .set_entity_embedding(&entity_key, vec![n as f32, 0.0])
            .unwrap();
    }

    let first_three = Pagination::new(0, 3);
    let page = engine
        .search_entities_paginated(&[2.0, 0.0], 5, first_three)
        .unwrap();
    assert_eq!(page.items.len(), 3);
    assert!(page.total_count.is_none());
    assert!(!page.has_more);
}
/// Scanning returns both entities that had an embedding set.
#[test]
fn scan_entities_with_embeddings_test() {
    let engine = VectorEngine::new();
    for (entity_key, embedding) in [("user:1", vec![1.0, 0.0]), ("user:2", vec![0.0, 1.0])] {
        engine.set_entity_embedding(entity_key, embedding).unwrap();
    }

    let found = engine.scan_entities_with_embeddings();
    assert_eq!(found.len(), 2);
}
/// After attaching embeddings to two entities, the entity-with-embedding
/// count is expected to be 2.
#[test]
fn count_entities_with_embeddings_test() {
    let store = VectorEngine::new();
    for (entity, vector) in [("user:1", vec![1.0, 0.0]), ("user:2", vec![0.0, 1.0])] {
        store.set_entity_embedding(entity, vector).unwrap();
    }

    assert_eq!(store.count_entities_with_embeddings(), 2);
}
/// With ten 4-dimensional embeddings stored, the IVF memory estimate for
/// `IVFBuildOptions::flat(5)` is expected to be strictly positive.
#[test]
fn estimate_ivf_memory_populated() {
    let store = VectorEngine::new();
    (0..10).for_each(|i| {
        store
            .store_embedding(&format!("key{i}"), vec![1.0, 2.0, 3.0, 4.0])
            .unwrap();
    });

    let estimate = store
        .estimate_ivf_memory(&IVFBuildOptions::flat(5))
        .unwrap();

    assert!(estimate > 0);
}
/// On an engine with nothing stored, the IVF memory estimate for
/// `IVFBuildOptions::flat(5)` is expected to be zero.
#[test]
fn estimate_ivf_memory_no_vectors() {
    let store = VectorEngine::new();
    let flat_options = IVFBuildOptions::flat(5);

    let estimate = store.estimate_ivf_memory(&flat_options).unwrap();

    assert_eq!(estimate, 0);
}
/// Builds a default HNSW index over three axis-aligned vectors and searches
/// it with `ExtendedDistanceMetric::Angular`, expecting all three results
/// back for `top_k = 3`.
#[test]
fn search_with_hnsw_angular_metric() {
    let store = VectorEngine::new();
    for (key, vector) in [
        ("a", vec![1.0, 0.0, 0.0]),
        ("b", vec![0.0, 1.0, 0.0]),
        ("c", vec![0.0, 0.0, 1.0]),
    ] {
        store.store_embedding(key, vector).unwrap();
    }
    let (graph, key_map) = store.build_hnsw_index_default().unwrap();

    let metric = ExtendedDistanceMetric::Angular;
    let query = [1.0, 0.0, 0.0];
    let hits = store
        .search_with_hnsw_and_metric(&graph, &key_map, &query, 3, &metric)
        .unwrap();

    assert_eq!(hits.len(), 3);
}
/// Stores an item with a `price` metadata entry in the `products`
/// collection and expects that key to be present in the metadata read back.
#[test]
fn collection_get_metadata_test() {
    let store = VectorEngine::new();
    store
        .create_collection("products", VectorCollectionConfig::default())
        .unwrap();

    let price = TensorValue::Scalar(ScalarValue::Int(100));
    let metadata = std::iter::once((String::from("price"), price)).collect::<HashMap<_, _>>();
    store
        .store_in_collection_with_metadata("products", "item1", vec![1.0, 2.0], metadata)
        .unwrap();

    let fetched = store.get_collection_metadata("products", "item1").unwrap();
    assert!(fetched.contains_key("price"));
}
/// Stores two items tagged with categories `A` and `B`, then runs a filtered
/// search for `category == "A"` and expects `item1` as the only hit.
#[test]
fn collection_search_filtered_test() {
    let store = VectorEngine::new();
    store
        .create_collection("items", VectorCollectionConfig::default())
        .unwrap();

    // (item key, category tag, embedding)
    for (item, category, vector) in [
        ("item1", "A", vec![1.0, 0.0]),
        ("item2", "B", vec![0.0, 1.0]),
    ] {
        let mut metadata = HashMap::new();
        metadata.insert(
            String::from("category"),
            TensorValue::Scalar(ScalarValue::String(String::from(category))),
        );
        store
            .store_in_collection_with_metadata("items", item, vector, metadata)
            .unwrap();
    }

    let only_category_a = FilterCondition::Eq(
        String::from("category"),
        FilterValue::String(String::from("A")),
    );
    let hits = store
        .search_filtered_in_collection("items", &[1.0, 0.0], 10, &only_category_a, None)
        .unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "item1");
}
/// A key stored in a collection is expected to exist until it is deleted
/// from that collection, and not afterwards.
#[test]
fn collection_delete_from_basic() {
    let store = VectorEngine::new();
    let (collection, key) = ("test", "key1");
    store
        .create_collection(collection, VectorCollectionConfig::default())
        .unwrap();
    store
        .store_in_collection(collection, key, vec![1.0, 2.0])
        .unwrap();
    assert!(store.exists_in_collection(collection, key));

    store.delete_from_collection(collection, key).unwrap();

    assert!(!store.exists_in_collection(collection, key));
}
/// In a freshly created, empty collection, `exists_in_collection` is
/// expected to be false for a key that was never stored.
#[test]
fn collection_exists_in_false() {
    let store = VectorEngine::new();
    store
        .create_collection("test", VectorCollectionConfig::default())
        .unwrap();

    let found = store.exists_in_collection("test", "nonexistent");

    assert!(!found);
}
/// After storing two keys in a collection, listing that collection's keys is
/// expected to return two entries.
#[test]
fn collection_list_keys_basic() {
    let store = VectorEngine::new();
    store
        .create_collection("test", VectorCollectionConfig::default())
        .unwrap();
    for (key, vector) in [("key1", vec![1.0, 2.0]), ("key2", vec![2.0, 3.0])] {
        store.store_in_collection("test", key, vector).unwrap();
    }

    let listed = store.list_collection_keys("test");

    assert_eq!(listed.len(), 2);
}
/// Listing the keys of a freshly created collection is expected to return
/// nothing.
#[test]
fn collection_list_keys_empty() {
    let store = VectorEngine::new();
    store
        .create_collection("test", VectorCollectionConfig::default())
        .unwrap();

    let listed = store.list_collection_keys("test");

    assert!(listed.is_empty());
}
/// On an empty engine, `entity_has_embedding` is expected to be false for an
/// entity key that was never given an embedding.
#[test]
fn entity_has_embedding_false() {
    let store = VectorEngine::new();
    let has_embedding = store.entity_has_embedding("nonexistent:key");
    assert!(!has_embedding);
}
/// An entity embedding that was set is expected to be reported as present,
/// and as absent once `remove_entity_embedding` has succeeded.
#[test]
fn remove_entity_embedding_success() {
    let store = VectorEngine::new();
    let entity = "user:1";
    store.set_entity_embedding(entity, vec![1.0, 2.0]).unwrap();
    assert!(store.entity_has_embedding(entity));

    store.remove_entity_embedding(entity).unwrap();

    assert!(!store.entity_has_embedding(entity));
}
/// Removing the embedding of an entity that has none is expected to fail
/// with `VectorError::NotFound`.
#[test]
fn remove_entity_embedding_not_found() {
    let store = VectorEngine::new();
    assert!(matches!(
        store.remove_entity_embedding("nonexistent:key"),
        Err(VectorError::NotFound(_))
    ));
}
/// Fetching the embedding of an entity that has none is expected to fail
/// with `VectorError::NotFound`.
#[test]
fn get_entity_embedding_not_found() {
    let store = VectorEngine::new();
    assert!(matches!(
        store.get_entity_embedding("nonexistent:key"),
        Err(VectorError::NotFound(_))
    ));
}
/// Searches three stored points from the origin using
/// `DistanceMetric::Euclidean`; expects all three back with `a` (stored at
/// the origin itself) ranked first.
#[test]
fn search_similar_with_metric_euclidean() {
    let store = VectorEngine::new();
    for (key, vector) in [
        ("a", vec![0.0, 0.0]),
        ("b", vec![1.0, 0.0]),
        ("c", vec![3.0, 4.0]),
    ] {
        store.store_embedding(key, vector).unwrap();
    }

    let origin = [0.0, 0.0];
    let hits = store
        .search_similar_with_metric(&origin, 3, DistanceMetric::Euclidean)
        .unwrap();

    assert_eq!(hits.len(), 3);
    assert_eq!(hits[0].key, "a");
}
/// Searching two stored unit vectors with `DistanceMetric::DotProduct` and
/// `top_k = 2` is expected to return both of them.
#[test]
fn search_similar_with_metric_dot_product() {
    let store = VectorEngine::new();
    for (key, vector) in [("a", vec![1.0, 0.0]), ("b", vec![0.0, 1.0])] {
        store.store_embedding(key, vector).unwrap();
    }

    let query = [1.0, 0.0];
    let hits = store
        .search_similar_with_metric(&query, 2, DistanceMetric::DotProduct)
        .unwrap();

    assert_eq!(hits.len(), 2);
}
/// `compute_similarity` of a vector with itself is expected to be within
/// 0.001 of 1.0.
#[test]
fn compute_similarity_identical_vectors() {
    let unit_x = [1.0, 0.0];
    let similarity = VectorEngine::compute_similarity(&unit_x, &unit_x).unwrap();
    let deviation = (similarity - 1.0).abs();
    assert!(deviation < 0.001);
}
/// `compute_similarity` of two orthogonal unit vectors is expected to be
/// within 0.001 of zero.
#[test]
fn compute_similarity_orthogonal_vectors() {
    let unit_x = [1.0, 0.0];
    let unit_y = [0.0, 1.0];
    let similarity = VectorEngine::compute_similarity(&unit_x, &unit_y).unwrap();
    assert!(similarity.abs() < 0.001);
}
/// With `max_keys_per_scan` configured to 100 and only ten keys stored,
/// `list_keys_bounded` is expected to return all ten. The configured limit
/// is never reached in this test.
#[test]
fn list_keys_bounded_with_limit() {
    let store = VectorEngine::with_config(VectorEngineConfig {
        max_keys_per_scan: Some(100),
        ..VectorEngineConfig::default()
    })
    .unwrap();
    (0..10).for_each(|i| {
        let key = format!("key{i}");
        store.store_embedding(&key, vec![i as f32]).unwrap();
    });

    let listed = store.list_keys_bounded();

    assert_eq!(listed.len(), 10);
}
/// After storing five embeddings, `clear` is expected to report five
/// removals and leave the engine's `count` at zero.
#[test]
fn clear_all_embeddings() {
    let store = VectorEngine::new();
    (0..5).for_each(|i| {
        let key = format!("key{i}");
        store.store_embedding(&key, vec![i as f32]).unwrap();
    });
    assert_eq!(store.count(), 5);

    let cleared = store.clear().unwrap();

    assert_eq!(cleared, 5);
    assert_eq!(store.count(), 0);
}
/// Clearing an engine that holds nothing is expected to succeed and report
/// zero removals.
#[test]
fn clear_empty_engine() {
    let store = VectorEngine::new();
    let cleared = store.clear().unwrap();
    assert_eq!(cleared, 0);
}
/// Builds an IVF index with `IVFBuildOptions::flat(5)` over twenty stored
/// 2-d vectors; expects twenty mapped keys and a non-empty index.
#[test]
fn build_ivf_index_basic() {
    let store = VectorEngine::new();
    for i in 0..20 {
        let coords = vec![i as f32, (i * 2) as f32];
        store.store_embedding(&format!("key{i}"), coords).unwrap();
    }

    let (ivf, key_map) = store.build_ivf_index(IVFBuildOptions::flat(5)).unwrap();

    assert_eq!(key_map.len(), 20);
    assert!(!ivf.is_empty());
}
/// Builds an IVF index with the default options over ten stored 2-d
/// vectors; expects ten mapped keys and a non-empty index.
#[test]
fn build_ivf_index_default_test() {
    let store = VectorEngine::new();
    for i in 0..10 {
        let coords = vec![i as f32, (i * 2) as f32];
        store.store_embedding(&format!("key{i}"), coords).unwrap();
    }

    let (ivf, key_map) = store.build_ivf_index_default().unwrap();

    assert_eq!(key_map.len(), 10);
    assert!(!ivf.is_empty());
}
/// Builds a flat IVF index over twenty stored 2-d vectors and searches it
/// with `top_k = 3`, expecting three results.
#[test]
fn search_with_ivf_basic() {
    let store = VectorEngine::new();
    for i in 0..20 {
        let coords = vec![i as f32, (i * 2) as f32];
        store.store_embedding(&format!("key{i}"), coords).unwrap();
    }
    let (ivf, key_map) = store.build_ivf_index(IVFBuildOptions::flat(5)).unwrap();

    let query = [5.0, 10.0];
    let hits = store.search_with_ivf(&ivf, &key_map, &query, 3).unwrap();

    assert_eq!(hits.len(), 3);
}
/// Builds a flat IVF index over twenty stored 2-d vectors and searches it
/// through `search_with_ivf_nprobe` with `top_k = 3` and a trailing argument
/// of 2, expecting three results.
#[test]
fn search_with_ivf_nprobe_basic() {
    let store = VectorEngine::new();
    for i in 0..20 {
        let coords = vec![i as f32, (i * 2) as f32];
        store.store_embedding(&format!("key{i}"), coords).unwrap();
    }
    let (ivf, key_map) = store.build_ivf_index(IVFBuildOptions::flat(5)).unwrap();

    let query = [5.0, 10.0];
    let hits = store
        .search_with_ivf_nprobe(&ivf, &key_map, &query, 3, 2)
        .unwrap();

    assert_eq!(hits.len(), 3);
}
/// After storing a single 3-element vector, `dimension` is expected to
/// report `Some(3)`.
#[test]
fn dimension_with_vectors() {
    let store = VectorEngine::new();
    store.store_embedding("a", vec![1.0, 2.0, 3.0]).unwrap();
    assert_eq!(store.dimension(), Some(3));
}
/// With nothing stored, `dimension` is expected to report `None`.
#[test]
fn dimension_empty_store() {
    let store = VectorEngine::new();
    assert_eq!(store.dimension(), None);
}
/// Builds a default HNSW index over three axis-aligned vectors and searches
/// it with the first axis as the query; expects two results with `a` first.
#[test]
fn search_with_hnsw_simple() {
    let store = VectorEngine::new();
    for (key, vector) in [
        ("a", vec![1.0, 0.0, 0.0]),
        ("b", vec![0.0, 1.0, 0.0]),
        ("c", vec![0.0, 0.0, 1.0]),
    ] {
        store.store_embedding(key, vector).unwrap();
    }
    let (graph, key_map) = store.build_hnsw_index_default().unwrap();

    let query = [1.0, 0.0, 0.0];
    let hits = store.search_with_hnsw(&graph, &key_map, &query, 2).unwrap();

    assert_eq!(hits.len(), 2);
    assert_eq!(hits[0].key, "a");
}
/// After `build_and_cache_index`, `search_similar` is expected to return two
/// hits with `a` first. The test checks results only; it does not measure
/// speed or confirm that the cached index was consulted.
#[test]
fn cache_hnsw_index_accelerates_search_similar() {
    let store = VectorEngine::new();
    for (key, vector) in [
        ("a", vec![1.0, 0.0, 0.0]),
        ("b", vec![0.0, 1.0, 0.0]),
        ("c", vec![0.0, 0.0, 1.0]),
    ] {
        store.store_embedding(key, vector).unwrap();
    }
    store.build_and_cache_index(HNSWConfig::default()).unwrap();

    let query = [1.0, 0.0, 0.0];
    let hits = store.search_similar(&query, 2).unwrap();

    assert_eq!(hits.len(), 2);
    assert_eq!(hits[0].key, "a");
}
/// With four vectors stored and an index built and cached, a `top_k = 1`
/// search for the first axis is expected to return exactly one hit, `a`.
#[test]
fn cache_hnsw_index_returns_correct_top_k() {
    let store = VectorEngine::new();
    for (key, vector) in [
        ("a", vec![1.0, 0.0, 0.0]),
        ("b", vec![0.9, 0.1, 0.0]),
        ("c", vec![0.0, 1.0, 0.0]),
        ("d", vec![0.0, 0.0, 1.0]),
    ] {
        store.store_embedding(key, vector).unwrap();
    }
    store.build_and_cache_index(HNSWConfig::default()).unwrap();

    let query = [1.0, 0.0, 0.0];
    let hits = store.search_similar(&query, 1).unwrap();

    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "a");
}
/// Storing a new embedding after `build_and_cache_index` is expected to
/// remove the `_default` entry from the HNSW cache.
#[test]
fn cache_invalidated_on_store_embedding() {
    let store = VectorEngine::new();
    // The read guard is a temporary inside the closure, so it is released
    // before any subsequent call on the engine.
    let default_is_cached = || store.hnsw_cache.read().contains_key("_default");

    for (key, vector) in [("a", vec![1.0, 0.0, 0.0]), ("b", vec![0.0, 1.0, 0.0])] {
        store.store_embedding(key, vector).unwrap();
    }
    store.build_and_cache_index(HNSWConfig::default()).unwrap();
    assert!(default_is_cached());

    store.store_embedding("c", vec![0.0, 0.0, 1.0]).unwrap();

    assert!(!default_is_cached());
}
/// Deleting an embedding after `build_and_cache_index` is expected to remove
/// the `_default` entry from the HNSW cache.
#[test]
fn cache_invalidated_on_delete_embedding() {
    let store = VectorEngine::new();
    // The read guard is a temporary inside the closure, so it is released
    // before any subsequent call on the engine.
    let default_is_cached = || store.hnsw_cache.read().contains_key("_default");

    for (key, vector) in [("a", vec![1.0, 0.0, 0.0]), ("b", vec![0.0, 1.0, 0.0])] {
        store.store_embedding(key, vector).unwrap();
    }
    store.build_and_cache_index(HNSWConfig::default()).unwrap();
    assert!(default_is_cached());

    store.delete_embedding("a").unwrap();

    assert!(!default_is_cached());
}
/// A manually cached HNSW index for `test_coll` is expected to be dropped
/// from the cache when another vector is stored in that collection.
#[test]
fn cache_invalidated_on_store_in_collection() {
    let store = VectorEngine::new();
    // The read guard is a temporary inside the closure, so it is released
    // before any subsequent call on the engine.
    let collection_is_cached = || store.hnsw_cache.read().contains_key("test_coll");

    store
        .create_collection("test_coll", VectorCollectionConfig::default())
        .unwrap();
    store
        .store_in_collection("test_coll", "a", vec![1.0, 0.0, 0.0])
        .unwrap();
    store.cache_hnsw_index(
        "test_coll",
        Arc::new(HNSWIndex::new()),
        vec![String::from("a")],
    );
    assert!(collection_is_cached());

    store
        .store_in_collection("test_coll", "b", vec![0.0, 1.0, 0.0])
        .unwrap();

    assert!(!collection_is_cached());
}
/// A manually cached HNSW index for `test_coll` is expected to be dropped
/// from the cache when a key is deleted from that collection.
#[test]
fn cache_invalidated_on_delete_from_collection() {
    let store = VectorEngine::new();
    // The read guard is a temporary inside the closure, so it is released
    // before any subsequent call on the engine.
    let collection_is_cached = || store.hnsw_cache.read().contains_key("test_coll");

    store
        .create_collection("test_coll", VectorCollectionConfig::default())
        .unwrap();
    store
        .store_in_collection("test_coll", "a", vec![1.0, 0.0, 0.0])
        .unwrap();
    store.cache_hnsw_index(
        "test_coll",
        Arc::new(HNSWIndex::new()),
        vec![String::from("a")],
    );
    assert!(collection_is_cached());

    store.delete_from_collection("test_coll", "a").unwrap();

    assert!(!collection_is_cached());
}
/// Invalidating the HNSW cache for a collection that was never cached is
/// expected to be a harmless no-op: it must not panic, and it must not leave
/// a cache entry behind for that collection.
#[test]
fn invalidate_hnsw_cache_nonexistent_collection() {
    let engine = VectorEngine::new();
    engine.invalidate_hnsw_cache("nonexistent");
    // Reaching this point proves there was no panic; additionally pin the
    // observable state so the test asserts something concrete.
    assert!(!engine.hnsw_cache.read().contains_key("nonexistent"));
}
/// Without any index having been built or cached in this test,
/// `search_similar` is still expected to return both stored vectors with `a`
/// ranked first.
#[test]
fn cache_search_similar_empty_cache_falls_through() {
    let store = VectorEngine::new();
    for (key, vector) in [("a", vec![1.0, 0.0, 0.0]), ("b", vec![0.0, 1.0, 0.0])] {
        store.store_embedding(key, vector).unwrap();
    }

    let query = [1.0, 0.0, 0.0];
    let hits = store.search_similar(&query, 2).unwrap();

    assert_eq!(hits.len(), 2);
    assert_eq!(hits[0].key, "a");
}
/// Hand-builds an HNSW index for the `docs` collection from the raw tensor
/// store, caches it, and expects `search_in_collection` to return two hits
/// with `d1` first.
///
/// The key mapping is filled with the storage keys returned by the prefix
/// scan, while the asserted result key is the bare `d1`.
#[test]
fn cache_search_in_collection_uses_cached_index() {
    let store = VectorEngine::new();
    store
        .create_collection("docs", VectorCollectionConfig::default())
        .unwrap();
    for (doc, vector) in [
        ("d1", vec![1.0, 0.0, 0.0]),
        ("d2", vec![0.0, 1.0, 0.0]),
        ("d3", vec![0.0, 0.0, 1.0]),
    ] {
        store.store_in_collection("docs", doc, vector).unwrap();
    }

    // Collect every storage key under the collection's embedding prefix.
    let docs_prefix = VectorEngine::collection_embedding_prefix("docs");
    let stored_keys: Vec<String> = store.store().scan(&docs_prefix);

    // Insert each stored vector into a fresh index, recording its storage key
    // in insertion order.
    let graph = HNSWIndex::new();
    let mut indexed_keys = Vec::new();
    for storage_key in &stored_keys {
        let record = store.store().get(storage_key).unwrap();
        if let Some(TensorValue::Vector(values)) = record.get("vector") {
            graph.insert(values.clone());
            indexed_keys.push(storage_key.clone());
        }
    }
    store.cache_hnsw_index("docs", Arc::new(graph), indexed_keys);

    let query = [1.0, 0.0, 0.0];
    let hits = store.search_in_collection("docs", &query, 2).unwrap();

    assert_eq!(hits.len(), 2);
    assert_eq!(hits[0].key, "d1");
}
/// `build_and_cache_index` on an empty engine is expected to succeed and
/// leave a `_default` cache entry whose key mapping is empty.
#[test]
fn build_and_cache_index_empty_store() {
    let store = VectorEngine::new();
    store.build_and_cache_index(HNSWConfig::default()).unwrap();

    let guard = store.hnsw_cache.read();
    let cached = guard.get("_default");
    assert!(cached.is_some());

    // Second tuple element of the cache entry is the key mapping.
    let key_mapping = &cached.unwrap().1;
    assert!(key_mapping.is_empty());
}
/// Runs the same `search_similar` query before and after
/// `build_and_cache_index` and compares the two result sets.
///
/// Only the result count and the top-ranked key (`a`) are compared; the
/// remaining ordering and the scores are not checked.
#[test]
fn build_and_cache_index_search_results_match_brute_force() {
    let store = VectorEngine::new();
    for (key, vector) in [
        ("a", vec![1.0, 0.0, 0.0]),
        ("b", vec![0.9, 0.1, 0.0]),
        ("c", vec![0.0, 1.0, 0.0]),
        ("d", vec![0.0, 0.0, 1.0]),
    ] {
        store.store_embedding(key, vector).unwrap();
    }
    let query = [1.0, 0.0, 0.0];

    let before_caching = store.search_similar(&query, 4).unwrap();
    store.build_and_cache_index(HNSWConfig::default()).unwrap();
    let after_caching = store.search_similar(&query, 4).unwrap();

    assert_eq!(before_caching.len(), after_caching.len());
    assert_eq!(before_caching[0].key, "a");
    assert_eq!(after_caching[0].key, "a");
}
/// A manually cached index for `my_coll` is expected to be visible in the
/// cache, and gone again after `invalidate_hnsw_cache("my_coll")`.
#[test]
fn cache_manual_insert_and_invalidate() {
    let store = VectorEngine::new();
    // The read guard is a temporary inside the closure, so it is released
    // before any subsequent call on the engine.
    let is_cached = || store.hnsw_cache.read().contains_key("my_coll");

    let key_mapping = vec![String::from("key1"), String::from("key2")];
    store.cache_hnsw_index("my_coll", Arc::new(HNSWIndex::new()), key_mapping);
    assert!(is_cached());

    store.invalidate_hnsw_cache("my_coll");

    assert!(!is_cached());
}
/// With cache entries for both `_default` and `other`, invalidating
/// `_default` is expected to remove only that entry and leave `other` cached.
#[test]
fn cache_does_not_cross_collections() {
    let store = VectorEngine::new();
    // The read guard is a temporary inside the closure, so it is released
    // before any subsequent call on the engine.
    let is_cached = |collection: &str| store.hnsw_cache.read().contains_key(collection);

    store.store_embedding("a", vec![1.0, 0.0, 0.0]).unwrap();
    store.build_and_cache_index(HNSWConfig::default()).unwrap();
    assert!(is_cached("_default"));

    store.cache_hnsw_index("other", Arc::new(HNSWIndex::new()), vec![]);
    assert!(is_cached("other"));

    store.invalidate_hnsw_cache("_default");

    assert!(!is_cached("_default"));
    assert!(is_cached("other"));
}
/// Caches an empty index with an empty key mapping under `_default` while
/// two vectors are stored; `search_similar` is still expected to return both
/// stored vectors with `a` first.
#[test]
fn cache_empty_mapping_falls_through_to_brute_force() {
    let store = VectorEngine::new();
    for (key, vector) in [("a", vec![1.0, 0.0, 0.0]), ("b", vec![0.0, 1.0, 0.0])] {
        store.store_embedding(key, vector).unwrap();
    }
    let empty_index = Arc::new(HNSWIndex::new());
    store.cache_hnsw_index("_default", empty_index, Vec::new());

    let query = [1.0, 0.0, 0.0];
    let hits = store.search_similar(&query, 2).unwrap();

    assert_eq!(hits.len(), 2);
    assert_eq!(hits[0].key, "a");
}
}