use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::types::NodeId;
/// Distance metric used to compare two vectors.
///
/// `Cosine` is the default variant. The derives make the metric copyable,
/// comparable and (de)serializable through serde.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum DistanceMetric {
/// Cosine distance; converted to a similarity as `1.0 - distance`.
#[default]
Cosine,
/// Euclidean distance; converted to a similarity as `1.0 / (1.0 + distance)`.
Euclidean,
/// Dot product; the distance value is the negated dot product.
DotProduct,
}
impl DistanceMetric {
    /// Returns the function pointer that computes this metric's distance
    /// between two `f32` slices.
    ///
    /// `Cosine` and `Euclidean` forward to the helpers in `super::distance`.
    /// `DotProduct` returns the *negated* dot product (presumably so that a
    /// larger dot product ranks as a smaller distance — confirm with callers).
    pub fn distance_fn(&self) -> fn(&[f32], &[f32]) -> f32 {
        match *self {
            Self::Cosine => super::distance::cosine_distance,
            Self::Euclidean => super::distance::euclidean_distance,
            Self::DotProduct => |a, b| -super::distance::dot_product(a, b),
        }
    }

    /// Converts a distance produced under this metric into a similarity score.
    ///
    /// * `Cosine`: `1.0 - distance`
    /// * `Euclidean`: `1.0 / (1.0 + distance)`
    /// * `DotProduct`: `-distance` (undoes the negation applied by
    ///   [`distance_fn`](Self::distance_fn))
    pub fn distance_to_similarity(&self, distance: f32) -> f32 {
        match *self {
            Self::Cosine => 1.0 - distance,
            Self::Euclidean => {
                let shifted = 1.0 + distance;
                1.0 / shifted
            }
            Self::DotProduct => -distance,
        }
    }
}
/// Strategy for collapsing several distances into a single value
/// (see `MultiQueryAggregation::aggregate`).
///
/// `Min` is the default variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum MultiQueryAggregation {
/// Keep the smallest distance (the default).
#[default]
Min,
/// Keep the largest distance.
Max,
/// Arithmetic mean of the distances.
Avg,
/// Sum of the distances.
Sum,
}
impl MultiQueryAggregation {
pub fn aggregate(&self, distances: &[f32]) -> f32 {
if distances.is_empty() {
return f32::INFINITY;
}
match self {
MultiQueryAggregation::Min => distances.iter().cloned().fold(f32::INFINITY, f32::min),
MultiQueryAggregation::Max => distances.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
MultiQueryAggregation::Avg => distances.iter().sum::<f32>() / distances.len() as f32,
MultiQueryAggregation::Sum => distances.iter().sum(),
}
}
}
/// Configuration for a vector store.
///
/// Defaults (see the `Default` impl): 384 dimensions, cosine metric,
/// row groups of 1024, fragment target size of 100 000, normalization on.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
/// Number of `f32` components per vector.
pub dimensions: usize,
/// Distance metric used for comparisons.
pub metric: DistanceMetric,
/// Row-group size (presumably the number of vectors per `RowGroup` — confirm at call sites).
pub row_group_size: usize,
/// Vector count at which `Fragment::should_seal` starts returning `true`.
pub fragment_target_size: usize,
/// Whether vectors are normalized on insert (the normalization itself is not in this file).
pub normalize_on_insert: bool,
}
impl Default for VectorStoreConfig {
fn default() -> Self {
Self {
dimensions: 384,
metric: DistanceMetric::Cosine,
row_group_size: 1024,
fragment_target_size: 100_000,
normalize_on_insert: true,
}
}
}
impl VectorStoreConfig {
    /// Creates a configuration with the given `dimensions`; every other
    /// field takes its `Default` value.
    pub fn new(dimensions: usize) -> Self {
        let mut config = Self::default();
        config.dimensions = dimensions;
        config
    }

    /// Returns the configuration with `metric` replaced.
    pub fn with_metric(self, metric: DistanceMetric) -> Self {
        Self { metric, ..self }
    }

    /// Returns the configuration with `row_group_size` set to `size`.
    pub fn with_row_group_size(self, size: usize) -> Self {
        Self {
            row_group_size: size,
            ..self
        }
    }

    /// Returns the configuration with `fragment_target_size` set to `size`.
    pub fn with_fragment_target_size(self, size: usize) -> Self {
        Self {
            fragment_target_size: size,
            ..self
        }
    }

    /// Returns the configuration with `normalize_on_insert` set to `normalize`.
    pub fn with_normalize(self, normalize: bool) -> Self {
        Self {
            normalize_on_insert: normalize,
            ..self
        }
    }
}
/// A contiguous batch of vectors stored back to back in one flat `f32` buffer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RowGroup {
/// Identifier of this row group.
pub id: usize,
/// Number of vectors appended so far (incremented by `append`).
pub count: usize,
/// Flat storage: vector `i` starts at offset `i * dimensions` (see `get`).
pub data: Vec<f32>,
}
impl RowGroup {
    /// Creates an empty row group with room pre-allocated for `capacity`
    /// vectors of `dimensions` components each.
    pub fn new(id: usize, capacity: usize, dimensions: usize) -> Self {
        Self {
            id,
            count: 0,
            data: Vec::with_capacity(capacity * dimensions),
        }
    }

    /// Returns `true` once the group holds at least `row_group_size` vectors.
    pub fn is_full(&self, row_group_size: usize) -> bool {
        self.count >= row_group_size
    }

    /// Returns the vector stored at `index`, interpreting the buffer as rows
    /// of `dimensions` components.
    ///
    /// Returns `None` when `index` is not below `count`, and also when the
    /// requested row does not fit inside `data`. The latter can happen if
    /// `dimensions` differs from the length used on `append`, or if the
    /// public fields were set inconsistently (for example by deserializing
    /// corrupt data). Such a lookup used to panic on the slice index.
    pub fn get(&self, index: usize, dimensions: usize) -> Option<&[f32]> {
        if index >= self.count {
            return None;
        }
        // Checked arithmetic: an overflowing offset cannot address a real row.
        let start = index.checked_mul(dimensions)?;
        let end = start.checked_add(dimensions)?;
        self.data.get(start..end)
    }

    /// Appends `vector` to the buffer and returns the index it was stored at.
    ///
    /// The length of `vector` is not validated here; callers must pass
    /// vectors of one consistent length for `get` to address rows correctly.
    pub fn append(&mut self, vector: &[f32]) -> usize {
        let index = self.count;
        self.data.extend_from_slice(vector);
        self.count += 1;
        index
    }
}
/// Lifecycle state of a `Fragment`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum FragmentState {
/// Initial state of a newly created fragment (the default).
#[default]
Active,
/// State set by `Fragment::seal`.
Sealed,
}
/// A group of row groups plus deletion bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Fragment {
/// Identifier of this fragment.
pub id: usize,
/// Lifecycle state; starts as `Active`.
pub state: FragmentState,
/// Row groups belonging to this fragment.
pub row_groups: Vec<RowGroup>,
/// Number of vectors in the fragment; upper bound for deletable indices.
pub total_vectors: usize,
/// Deletion bitmap: bit `i % 32` of word `i / 32` is set when local index `i`
/// is deleted. Grown on demand by `delete`; missing words mean "not deleted".
pub deletion_bitmap: Vec<u32>,
/// Number of indices marked deleted.
pub deleted_count: usize,
}
impl Fragment {
pub fn new(id: usize) -> Self {
Self {
id,
state: FragmentState::Active,
row_groups: Vec::new(),
total_vectors: 0,
deletion_bitmap: Vec::new(),
deleted_count: 0,
}
}
pub fn is_deleted(&self, index: usize) -> bool {
let word_idx = index / 32;
let bit_idx = index % 32;
if word_idx >= self.deletion_bitmap.len() {
return false;
}
(self.deletion_bitmap[word_idx] & (1 << bit_idx)) != 0
}
pub fn delete(&mut self, index: usize) -> bool {
if index >= self.total_vectors || self.is_deleted(index) {
return false;
}
let word_idx = index / 32;
let bit_idx = index % 32;
while self.deletion_bitmap.len() <= word_idx {
self.deletion_bitmap.push(0);
}
self.deletion_bitmap[word_idx] |= 1 << bit_idx;
self.deleted_count += 1;
true
}
pub fn live_count(&self) -> usize {
self.total_vectors - self.deleted_count
}
pub fn should_seal(&self, config: &VectorStoreConfig) -> bool {
self.total_vectors >= config.fragment_target_size
}
pub fn seal(&mut self) {
self.state = FragmentState::Sealed;
}
}
/// Position of a vector: a fragment id plus an index local to that fragment.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct VectorLocation {
/// Id of the fragment holding the vector.
pub fragment_id: usize,
/// Index of the vector inside that fragment.
pub local_index: usize,
}
/// Top-level bookkeeping for a vector store: configuration, fragments,
/// counters and id mappings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorManifest {
/// Store configuration.
pub config: VectorStoreConfig,
/// All fragments; a new manifest starts with a single fragment with id 0.
pub fragments: Vec<Fragment>,
/// Id of the fragment returned by `active_fragment` / `active_fragment_mut`.
pub active_fragment_id: usize,
/// Total vector counter; `live_count` subtracts `total_deleted` from it.
pub total_vectors: usize,
/// Total deleted-vector counter.
pub total_deleted: usize,
/// Next vector id (presumably handed out on insert — confirm with the insert path).
pub next_vector_id: u64,
/// Mapping from node id to vector id.
pub node_to_vector: HashMap<NodeId, u64>,
/// Mapping from vector id to node id.
pub vector_to_node: HashMap<u64, NodeId>,
/// Mapping from vector id to its `VectorLocation`.
pub vector_locations: HashMap<u64, VectorLocation>,
}
impl VectorManifest {
pub fn new(config: VectorStoreConfig) -> Self {
let initial_fragment = Fragment::new(0);
Self {
config,
fragments: vec![initial_fragment],
active_fragment_id: 0,
total_vectors: 0,
total_deleted: 0,
next_vector_id: 0,
node_to_vector: HashMap::new(),
vector_to_node: HashMap::new(),
vector_locations: HashMap::new(),
}
}
pub fn active_fragment(&self) -> Option<&Fragment> {
self
.fragments
.iter()
.find(|f| f.id == self.active_fragment_id)
}
pub fn active_fragment_mut(&mut self) -> Option<&mut Fragment> {
let id = self.active_fragment_id;
self.fragments.iter_mut().find(|f| f.id == id)
}
pub fn live_count(&self) -> usize {
self.total_vectors - self.total_deleted
}
}
/// Configuration for an IVF index.
///
/// Defaults (see the `Default` impl): 100 clusters, 10 probes, cosine metric.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IvfConfig {
/// Number of clusters.
pub n_clusters: usize,
/// Probe count (presumably the number of clusters searched per query — confirm with the index code).
pub n_probe: usize,
/// Distance metric used by the index.
pub metric: DistanceMetric,
}
impl Default for IvfConfig {
fn default() -> Self {
Self {
n_clusters: 100,
n_probe: 10,
metric: DistanceMetric::Cosine,
}
}
}
impl IvfConfig {
    /// Creates a configuration with `n_clusters` clusters; the remaining
    /// fields take their `Default` values.
    pub fn new(n_clusters: usize) -> Self {
        let mut config = Self::default();
        config.n_clusters = n_clusters;
        config
    }

    /// Returns the configuration with `n_probe` replaced.
    pub fn with_n_probe(self, n_probe: usize) -> Self {
        Self { n_probe, ..self }
    }

    /// Returns the configuration with `metric` replaced.
    pub fn with_metric(self, metric: DistanceMetric) -> Self {
        Self { metric, ..self }
    }
}
/// Configuration for product quantization (going by the `Pq` name — confirm with the PQ code).
///
/// Defaults (see the `Default` impl): 48 subspaces, 256 centroids,
/// 20 iterations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PqConfig {
/// Number of subspaces.
pub num_subspaces: usize,
/// Number of centroids (presumably per subspace — confirm with the training code).
pub num_centroids: usize,
/// Iteration cap (presumably for centroid training — confirm with the training code).
pub max_iterations: usize,
}
impl Default for PqConfig {
fn default() -> Self {
Self {
num_subspaces: 48, num_centroids: 256, max_iterations: 20,
}
}
}
/// One hit returned by a vector search.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSearchResult {
/// Id of the matching vector.
pub vector_id: u64,
/// Node associated with the matching vector.
pub node_id: NodeId,
/// Distance value for this hit.
pub distance: f32,
/// Similarity score (presumably obtained via `DistanceMetric::distance_to_similarity` — confirm with the search code).
pub similarity: f32,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default metric is cosine.
    #[test]
    fn test_distance_metric_default() {
        let metric = DistanceMetric::default();
        assert_eq!(metric, DistanceMetric::Cosine);
    }

    /// The default store configuration uses 384 dimensions, cosine, normalization on.
    #[test]
    fn test_vector_store_config_default() {
        let VectorStoreConfig {
            dimensions,
            metric,
            normalize_on_insert,
            ..
        } = VectorStoreConfig::default();
        assert_eq!(dimensions, 384);
        assert_eq!(metric, DistanceMetric::Cosine);
        assert!(normalize_on_insert);
    }

    /// Builder methods override exactly the fields they name.
    #[test]
    fn test_vector_store_config_builder() {
        let built = VectorStoreConfig::new(768)
            .with_metric(DistanceMetric::Euclidean)
            .with_row_group_size(512)
            .with_normalize(false);
        assert_eq!(built.dimensions, 768);
        assert_eq!(built.metric, DistanceMetric::Euclidean);
        assert_eq!(built.row_group_size, 512);
        assert!(!built.normalize_on_insert);
    }

    /// A fresh row group is empty and not full.
    #[test]
    fn test_row_group_new() {
        let group = RowGroup::new(0, 100, 128);
        assert_eq!(group.id, 0);
        assert_eq!(group.count, 0);
        assert!(!group.is_full(100));
    }

    /// An appended vector can be read back unchanged.
    #[test]
    fn test_row_group_append() {
        let mut group = RowGroup::new(0, 10, 4);
        let row = [1.0, 2.0, 3.0, 4.0];
        let slot = group.append(&row);
        assert_eq!(slot, 0);
        assert_eq!(group.count, 1);
        let stored = group.get(0, 4).expect("expected value");
        assert_eq!(stored, &row);
    }

    /// A row group reports full once it holds `row_group_size` vectors.
    #[test]
    fn test_row_group_full() {
        let mut group = RowGroup::new(0, 10, 4);
        for value in 0..10 {
            group.append(&[value as f32; 4]);
        }
        assert!(group.is_full(10));
    }

    /// A fresh fragment is active and empty.
    #[test]
    fn test_fragment_new() {
        let fragment = Fragment::new(0);
        assert_eq!(fragment.id, 0);
        assert_eq!(fragment.state, FragmentState::Active);
        assert_eq!(fragment.total_vectors, 0);
        assert_eq!(fragment.live_count(), 0);
    }

    /// Deleting marks the index once; a second delete is a no-op.
    #[test]
    fn test_fragment_deletion() {
        let mut fragment = Fragment::new(0);
        fragment.total_vectors = 100;

        assert!(fragment.delete(5));
        assert!(fragment.is_deleted(5));
        assert!(!fragment.is_deleted(4));
        assert_eq!(fragment.deleted_count, 1);
        assert_eq!(fragment.live_count(), 99);

        assert!(!fragment.delete(5));
        assert_eq!(fragment.deleted_count, 1);
    }

    /// Sealing flips the state from active to sealed.
    #[test]
    fn test_fragment_seal() {
        let mut fragment = Fragment::new(0);
        assert_eq!(fragment.state, FragmentState::Active);
        fragment.seal();
        assert_eq!(fragment.state, FragmentState::Sealed);
    }

    /// A new manifest starts with one fragment and no vectors.
    #[test]
    fn test_vector_manifest_new() {
        let manifest = VectorManifest::new(VectorStoreConfig::new(128));
        assert_eq!(manifest.total_vectors, 0);
        assert_eq!(manifest.fragments.len(), 1);
        assert_eq!(manifest.active_fragment_id, 0);
    }

    /// Default IVF configuration values.
    #[test]
    fn test_ivf_config_default() {
        let ivf = IvfConfig::default();
        assert_eq!(ivf.n_clusters, 100);
        assert_eq!(ivf.n_probe, 10);
        assert_eq!(ivf.metric, DistanceMetric::Cosine);
    }

    /// Default PQ configuration values.
    #[test]
    fn test_pq_config_default() {
        let pq = PqConfig::default();
        assert_eq!(pq.num_subspaces, 48);
        assert_eq!(pq.num_centroids, 256);
        assert_eq!(pq.max_iterations, 20);
    }

    /// Distance-to-similarity conversion for cosine and Euclidean.
    #[test]
    fn test_distance_to_similarity() {
        let cosine = DistanceMetric::Cosine.distance_to_similarity(0.2);
        assert_eq!(cosine, 0.8);
        let euclidean = DistanceMetric::Euclidean.distance_to_similarity(1.0);
        assert_eq!(euclidean, 0.5);
    }
}