use crate::core::error::Result;
use crate::core::id::NodeId;
use crate::core::temporal::Timestamp;
use std::path::PathBuf;
use std::sync::Arc;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Quantization {
#[default]
F32,
F16,
I8,
}
impl Quantization {
pub fn to_u8(self) -> u8 {
match self {
Quantization::F32 => 0,
Quantization::F16 => 1,
Quantization::I8 => 2,
}
}
pub fn from_u8(value: u8) -> Result<Self> {
match value {
0 => Ok(Quantization::F32),
1 => Ok(Quantization::F16),
2 => Ok(Quantization::I8),
_ => Err(crate::core::error::StorageError::CorruptedData(format!(
"Invalid quantization encoding: {}",
value
))
.into()),
}
}
}
/// Selects how an index stores its data.
///
/// NOTE(review): this file only declares the variants; how each mode is
/// honoured is decided by the index implementations — confirm there.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum StorageMode {
/// The default mode; carries no configuration.
#[default]
InMemory,
/// Mode configured with a filesystem path.
MemoryMapped {
/// Path associated with this mode (presumably the backing file that is
/// memory-mapped — verify against the implementation that consumes it).
path: PathBuf,
},
}
/// A named, user-supplied distance function.
///
/// Cloning shares the same underlying closure through the [`Arc`]; only the
/// name is actually copied.
#[derive(Clone)]
pub struct CustomMetric {
    /// Name of the metric. This is the only field compared by `==`.
    pub name: String,
    /// Callable that takes two `f32` slices and returns an `f32`.
    // The trait-object type is written out in place, hence the lint allowance.
    #[allow(clippy::type_complexity)]
    pub distance_fn: Arc<dyn Fn(&[f32], &[f32]) -> f32 + Send + Sync>,
}

impl std::fmt::Debug for CustomMetric {
    /// Formats the metric as a struct, printing a fixed placeholder in place of
    /// the closure (closures have no `Debug` representation).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("CustomMetric");
        repr.field("name", &self.name);
        repr.field("distance_fn", &"<function>");
        repr.finish()
    }
}

impl PartialEq for CustomMetric {
    /// Two metrics are equal when their names match; the closures themselves
    /// are not (and cannot be) compared.
    fn eq(&self, other: &Self) -> bool {
        self.name.as_str() == other.name.as_str()
    }
}
/// A list of `(Timestamp, results)` pairs, where each `results` is a list of
/// `(NodeId, f32)` pairs — the same shape that `VectorIndex::search` returns.
pub type TemporalSearchResults = Vec<(Timestamp, Vec<(NodeId, f32)>)>;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
Cosine,
Euclidean,
DotProduct,
Haversine,
Hamming,
Tanimoto,
}
impl DistanceMetric {
pub fn to_u8(self) -> u8 {
match self {
DistanceMetric::Cosine => 0,
DistanceMetric::Euclidean => 1,
DistanceMetric::DotProduct => 2,
DistanceMetric::Haversine => 3,
DistanceMetric::Hamming => 4,
DistanceMetric::Tanimoto => 5,
}
}
pub fn from_u8(value: u8) -> Result<Self> {
match value {
0 => Ok(DistanceMetric::Cosine),
1 => Ok(DistanceMetric::Euclidean),
2 => Ok(DistanceMetric::DotProduct),
3 => Ok(DistanceMetric::Haversine),
4 => Ok(DistanceMetric::Hamming),
5 => Ok(DistanceMetric::Tanimoto),
_ => Err(crate::core::error::StorageError::CorruptedData(format!(
"Invalid distance metric encoding: {}",
value
))
.into()),
}
}
}
pub trait VectorIndex: Send + Sync {
fn add(&self, id: NodeId, vector: &[f32]) -> Result<()>;
fn remove(&self, id: NodeId) -> Result<()>;
fn search(&self, query: &[f32], k: usize) -> Result<Vec<(NodeId, f32)>>;
fn search_with_filter<F>(
&self,
query: &[f32],
k: usize,
predicate: F,
) -> Result<Vec<(NodeId, f32)>>
where
F: Fn(&NodeId) -> bool + Send + Sync;
#[must_use]
fn len(&self) -> usize;
#[must_use]
fn dimensions(&self) -> usize;
#[must_use]
fn distance_metric(&self) -> DistanceMetric;
#[must_use]
fn is_empty(&self) -> bool {
self.len() == 0
}
fn add_batch(&self, items: &[(NodeId, Vec<f32>)]) -> Result<()> {
for (id, vec) in items {
self.add(*id, vec)?;
}
Ok(())
}
fn add_batch_ref(&self, items: &[(NodeId, &[f32])]) -> Result<()> {
for (id, vec) in items {
self.add(*id, vec)?;
}
Ok(())
}
fn remove_batch(&self, ids: &[NodeId]) -> Result<()> {
for id in ids {
self.remove(*id)?;
}
Ok(())
}
fn save(&self, _path: &std::path::Path) -> Result<()> {
Err(crate::core::error::Error::Vector(
crate::core::error::VectorError::IndexError(
"save not supported by this index type".to_string(),
),
))
}
fn memory_usage(&self) -> usize {
0
}
fn quantization(&self) -> Quantization {
Quantization::F32
}
fn compact(&self) -> Result<()> {
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `DistanceMetric` formats with its variant name and compares by variant.
    #[test]
    fn test_distance_metric_debug() {
        let cosine = DistanceMetric::Cosine;
        assert_eq!(format!("{:?}", cosine), "Cosine");
        assert_eq!(cosine, DistanceMetric::Cosine);
        assert_ne!(cosine, DistanceMetric::Euclidean);
    }

    /// The default quantization is `F32`.
    #[test]
    fn test_quantization_default() {
        let default_quantization = Quantization::default();
        assert_eq!(default_quantization, Quantization::F32);
    }

    /// The default storage mode is `InMemory`.
    #[test]
    fn test_storage_mode_default() {
        let default_mode = StorageMode::default();
        assert!(matches!(default_mode, StorageMode::InMemory));
    }

    /// Encodings 3..=5 map to the Haversine/Hamming/Tanimoto variants in both
    /// directions.
    #[test]
    fn test_distance_metric_new_variants() {
        let expected = [
            (DistanceMetric::Haversine, 3u8),
            (DistanceMetric::Hamming, 4u8),
            (DistanceMetric::Tanimoto, 5u8),
        ];
        for (metric, code) in expected {
            assert_eq!(metric.to_u8(), code);
        }
        for (metric, code) in expected {
            assert_eq!(DistanceMetric::from_u8(code).unwrap(), metric);
        }
    }

    /// Every quantization round-trips through its byte encoding, and an
    /// unknown byte is rejected.
    #[test]
    fn test_quantization_variants() {
        let expected = [
            (Quantization::F32, 0u8),
            (Quantization::F16, 1u8),
            (Quantization::I8, 2u8),
        ];
        for (quantization, code) in expected {
            assert_eq!(quantization.to_u8(), code);
        }
        for (quantization, code) in expected {
            assert_eq!(Quantization::from_u8(code).unwrap(), quantization);
        }
        assert!(Quantization::from_u8(3).is_err());
    }
}
pub mod hnsw;
pub mod temporal;
pub use hnsw::{HnswConfig, HnswIndex, HnswIndexBuilder};
pub use temporal::{
DriftMetric, RetentionPolicy, SnapshotInfo, SnapshotStrategy, TemporalVectorConfig,
TemporalVectorIndex, VectorIndexObserver,
};
pub mod sparse;
pub use sparse::{
ScoringMethod, SparseIndexConfig, SparseIndexStats, SparseVectorIndex, hybrid_fusion,
reciprocal_rank_fusion,
};
pub mod sharded;
pub use sharded::{
RebalanceConfig as ShardedRebalanceConfig, ShardStats, ShardedVectorConfig, ShardedVectorIndex,
ShardingStrategy,
};
/// An `f32` wrapper with a total order, usable as a key in ordered
/// collections such as `BinaryHeap`.
///
/// Ordinary values compare as `f32` does (so `0.0` and `-0.0` are equal).
/// NaN is treated as equal to NaN and smaller than every other value.
///
/// Equality is defined through [`Ord::cmp`] rather than derived from `f32`:
/// a derived `PartialEq` would report `NaN != NaN` while `cmp` reports
/// `Equal`, breaking the consistency that `Eq` and `Ord` require.
#[derive(Debug, Clone, Copy)]
pub(crate) struct OrderedFloat(pub f32);

impl PartialEq for OrderedFloat {
    /// Equal exactly when [`Ord::cmp`] returns `Equal`.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    /// Always `Some`, delegating to the total order in [`Ord::cmp`].
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    /// Total order: regular `f32` comparison, with NaN equal to NaN and less
    /// than every non-NaN value.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        match self.0.partial_cmp(&other.0) {
            Some(ordering) => ordering,
            // `partial_cmp` is `None` only when at least one side is NaN.
            // Comparing the NaN flags in swapped order makes NaN the smallest
            // value: (NaN, x) -> Less, (x, NaN) -> Greater, (NaN, NaN) -> Equal.
            None => other.0.is_nan().cmp(&self.0.is_nan()),
        }
    }
}
/// Merges several result lists into the `k` entries with the largest scores.
///
/// # Parameters
///
/// * `all_results` - result lists to merge; each entry is a `(NodeId, score)`
///   pair.
/// * `k` - maximum number of entries to return. `0` yields an empty vector.
///
/// # Returns
///
/// At most `k` entries, sorted by score from largest to smallest. A candidate
/// only displaces a kept entry when its score is strictly larger, so entries
/// seen earlier win ties at the cut-off. NaN scores rank below every other
/// score (see [`OrderedFloat`]): they are kept only while fewer than `k`
/// better candidates exist, and they appear last in the output.
pub(crate) fn merge_top_k_results(
    all_results: Vec<Vec<(crate::core::id::NodeId, f32)>>,
    k: usize,
) -> Vec<(crate::core::id::NodeId, f32)> {
    use std::cmp::Reverse;
    use std::collections::BinaryHeap;

    if k == 0 {
        return Vec::new();
    }

    // The heap never holds more than `k` entries, nor more than the number of
    // candidates. Capping the reservation this way avoids overflowing or
    // over-allocating when a caller passes a huge `k` (e.g. `usize::MAX`).
    let candidate_count: usize = all_results.iter().map(Vec::len).sum();
    let mut heap: BinaryHeap<(Reverse<OrderedFloat>, crate::core::id::NodeId)> =
        BinaryHeap::with_capacity(k.min(candidate_count));

    // `Reverse` turns the max-heap into a min-heap on score, so `peek` yields
    // the weakest entry currently kept.
    for (id, score) in all_results.into_iter().flatten() {
        let candidate = OrderedFloat(score);
        if heap.len() < k {
            heap.push((Reverse(candidate), id));
            continue;
        }

        let replaces_weakest = match heap.peek() {
            Some((Reverse(weakest), _)) => candidate > *weakest,
            None => false,
        };
        if replaces_weakest {
            heap.pop();
            heap.push((Reverse(candidate), id));
        }
    }

    let mut results: Vec<(crate::core::id::NodeId, f32)> = heap
        .into_iter()
        .map(|(Reverse(score), id)| (id, score.0))
        .collect();

    // Sort descending through `OrderedFloat` so the comparator is a total
    // order even when NaN scores are present; mapping incomparable pairs to
    // `Equal` is not a total order, which `sort_by` requires.
    results.sort_by(|a, b| OrderedFloat(b.1).cmp(&OrderedFloat(a.1)));
    results
}
pub mod distributed;
pub use distributed::{
CircuitBreakerConfig as DistributedCircuitBreakerConfig, CircuitState, DistributedError,
DistributedIndexStats, DistributedVectorConfig, DistributedVectorIndex, MockVectorNodeClient,
NodeCircuitBreaker, NodeConnection, NodeConnectionStats, RECOMMENDED_IMBALANCE_THRESHOLD,
RebalanceStats, RoutingStrategy, VectorNodeClient, VectorNodeConfig,
};