use std::collections::{BTreeMap, BTreeSet};
use serde::{Deserialize, Serialize};
use crate::{
cosine_similarity_bounded, dot_product, euclidean_similarity, manhattan_distance, LoraVector,
};
use super::hnsw::{seed_from_name, HnswBackend, HnswParams, HnswSnapshot};
use super::index_catalog::{IndexConfigValue, StoredIndexEntity};
/// Similarity metric used to score a stored vector against a query vector.
///
/// Names are mapped to variants by `VectorSimilarity::parse`; scoring is done by
/// `VectorSimilarity::score`.
///
/// NOTE(review): the type derives `Serialize`/`Deserialize`, so variant names (and
/// possibly their order) are presumably part of a persisted format — confirm before
/// renaming or reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum VectorSimilarity {
/// Scored with `cosine_similarity_bounded`.
Cosine,
/// Scored with `euclidean_similarity`.
Euclidean,
/// Scored with `dot_product`, returned as-is.
Dot,
/// Scored as `1 / (1 + d)`, where `d` is the `manhattan_distance`.
Manhattan,
}
impl VectorSimilarity {
    /// Parses a metric name, ignoring ASCII case.
    ///
    /// Accepted spellings are `cosine`, `euclidean`, `dot`, `dot_product` and
    /// `manhattan`. Any other input yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // Every accepted spelling and the metric it selects.
        const SPELLINGS: [(&str, VectorSimilarity); 5] = [
            ("cosine", VectorSimilarity::Cosine),
            ("euclidean", VectorSimilarity::Euclidean),
            ("dot", VectorSimilarity::Dot),
            ("dot_product", VectorSimilarity::Dot),
            ("manhattan", VectorSimilarity::Manhattan),
        ];
        SPELLINGS
            .iter()
            .find(|(spelling, _)| s.eq_ignore_ascii_case(spelling))
            .map(|&(_, metric)| metric)
    }

    /// Scores `a` against `b` with this metric.
    ///
    /// Returns `None` when the two vectors have different dimensions, or when the
    /// underlying metric function itself returns `None`. The Manhattan distance `d`
    /// is converted to `1 / (1 + d)`.
    pub fn score(self, a: &LoraVector, b: &LoraVector) -> Option<f64> {
        if a.dimension == b.dimension {
            match self {
                Self::Cosine => cosine_similarity_bounded(a, b),
                Self::Euclidean => euclidean_similarity(a, b),
                Self::Dot => dot_product(a, b),
                Self::Manhattan => {
                    let distance = manhattan_distance(a, b)?;
                    Some(1.0 / (1.0 + distance))
                }
            }
        } else {
            None
        }
    }

    /// Reads the metric from the `vector.similarity_function` index option.
    ///
    /// Returns `None` when the option is absent, is not a string value, or holds a
    /// name that `parse` does not recognise.
    pub(super) fn from_options(options: &BTreeMap<String, IndexConfigValue>) -> Option<Self> {
        let configured = options.get("vector.similarity_function")?;
        if let IndexConfigValue::String(name) = configured {
            Self::parse(name)
        } else {
            None
        }
    }
}
/// Backend that keeps each indexed vector in an ordered map and scores the stored
/// vectors directly at query time.
#[derive(Debug, Default, Clone)]
pub(super) struct FlatBackend {
// Stored vectors keyed by the `u64` id they were inserted under.
items: BTreeMap<u64, LoraVector>,
}
impl FlatBackend {
    /// Stores `vector` under `id`, replacing any vector already stored for that id.
    fn insert(&mut self, id: u64, vector: LoraVector) {
        self.items.insert(id, vector);
    }

    /// Drops the vector stored under `id`; does nothing when the id is unknown.
    fn remove(&mut self, id: u64) {
        self.items.remove(&id);
    }

    /// Scores stored vectors against `query` and returns `(id, score)` pairs.
    ///
    /// * `similarity` – metric used for scoring; entries for which
    ///   `similarity.score` returns `None` are left out.
    /// * `restrict_to` – when `Some`, only ids contained in the set are considered.
    ///
    /// The result is in ascending id order. It is neither sorted by score nor
    /// truncated, so any top-k selection is left to the caller.
    fn query(
        &self,
        query: &LoraVector,
        similarity: VectorSimilarity,
        restrict_to: Option<&BTreeSet<u64>>,
    ) -> Vec<(u64, f64)> {
        let score_entry = |id: u64, stored: &LoraVector| {
            similarity.score(stored, query).map(|score| (id, score))
        };
        match restrict_to {
            // A small allow-list: probe the map once per allowed id instead of
            // walking (and pre-allocating room for) every stored vector. Both
            // collections iterate in ascending id order, so the output order is the
            // same as for a full scan.
            Some(allowed) if allowed.len() < self.items.len() => allowed
                .iter()
                .filter_map(|&id| score_entry(id, self.items.get(&id)?))
                .collect(),
            // The allow-list is at least as large as the map: scan and filter.
            Some(allowed) => self
                .items
                .iter()
                .filter(|&(id, _)| allowed.contains(id))
                .filter_map(|(&id, stored)| score_entry(id, stored))
                .collect(),
            None => self
                .items
                .iter()
                .filter_map(|(&id, stored)| score_entry(id, stored))
                .collect(),
        }
    }

    /// Number of stored vectors (test-only helper).
    #[cfg(test)]
    fn len(&self) -> usize {
        self.items.len()
    }
}
/// Selects which backend implementation a vector index is created with.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VectorIndexProvider {
/// Use `FlatBackend`.
Flat,
/// Use `HnswBackend`.
Hnsw,
}
impl VectorIndexProvider {
    /// Parses a provider name, ignoring ASCII case.
    ///
    /// Recognises `flat` and `hnsw`; every other input yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        match s {
            _ if s.eq_ignore_ascii_case("flat") => Some(Self::Flat),
            _ if s.eq_ignore_ascii_case("hnsw") => Some(Self::Hnsw),
            _ => None,
        }
    }

    /// Reads the provider from the `vector.indexProvider` index option.
    ///
    /// Returns `None` when the option is absent, is not a string value, or holds a
    /// name that `parse` does not recognise.
    pub(super) fn from_options(options: &BTreeMap<String, IndexConfigValue>) -> Option<Self> {
        let configured = options.get("vector.indexProvider")?;
        if let IndexConfigValue::String(name) = configured {
            Self::parse(name)
        } else {
            None
        }
    }
}
/// The concrete storage/search structure behind one vector index.
#[derive(Debug, Clone)]
pub(super) enum VectorBackend {
/// Map-based backend; see `FlatBackend`.
Flat(FlatBackend),
/// Graph-based backend; see `HnswBackend`.
Hnsw(HnswBackend),
}
impl VectorBackend {
fn insert(&mut self, id: u64, vector: LoraVector) {
match self {
VectorBackend::Flat(b) => b.insert(id, vector),
VectorBackend::Hnsw(b) => b.insert(id, vector),
}
}
fn remove(&mut self, id: u64) {
match self {
VectorBackend::Flat(b) => b.remove(id),
VectorBackend::Hnsw(b) => b.remove(id),
}
}
fn query(
&self,
query: &LoraVector,
similarity: VectorSimilarity,
k: usize,
restrict_to: Option<&BTreeSet<u64>>,
) -> Vec<(u64, f64)> {
match self {
VectorBackend::Flat(b) => b.query(query, similarity, restrict_to),
VectorBackend::Hnsw(b) => b.query(query, k, restrict_to),
}
}
#[cfg(test)]
fn len(&self) -> usize {
match self {
VectorBackend::Flat(b) => b.len(),
VectorBackend::Hnsw(b) => b.len(),
}
}
}
/// A registered vector index: the `label`/`property` scope it covers, the
/// similarity metric it is associated with, and the backend holding its vectors.
#[derive(Debug, Clone)]
pub(super) struct VectorIndexEntry {
/// Label that `insert_for`/`remove_for` compare against.
pub label: String,
/// Property name that `insert_for`/`remove_for` compare against.
pub property: String,
/// Metric handed to the backend on query and to `to_snapshot` when snapshotting.
pub similarity: VectorSimilarity,
/// Storage/search structure for this index.
pub backend: VectorBackend,
}
/// Collection of vector indexes addressed by index name.
#[derive(Debug, Default, Clone)]
pub(super) struct VectorIndexRegistry {
// Index name -> entry; several entries may share the same label/property scope.
by_name: BTreeMap<String, VectorIndexEntry>,
}
impl VectorIndexRegistry {
    /// Registers an empty index called `name` covering `label`/`property`.
    ///
    /// `provider` selects the backend. The HNSW backend is constructed from
    /// `similarity`, the `hnsw` parameters and a seed obtained from
    /// `seed_from_name(&name)`; the flat backend does not use `hnsw`. An entry
    /// already registered under `name` is replaced.
    pub(super) fn register(
        &mut self,
        name: String,
        label: String,
        property: String,
        similarity: VectorSimilarity,
        provider: VectorIndexProvider,
        hnsw: HnswParams,
    ) {
        let backend = match provider {
            VectorIndexProvider::Hnsw => {
                VectorBackend::Hnsw(HnswBackend::new(similarity, hnsw, seed_from_name(&name)))
            }
            VectorIndexProvider::Flat => VectorBackend::Flat(FlatBackend::default()),
        };
        let entry = VectorIndexEntry {
            label,
            property,
            similarity,
            backend,
        };
        self.by_name.insert(name, entry);
    }

    /// Removes the index called `name`, if one is registered.
    pub(super) fn deregister(&mut self, name: &str) {
        self.by_name.remove(name);
    }

    /// Returns `true` when no index is registered.
    pub(super) fn is_empty(&self) -> bool {
        self.by_name.is_empty()
    }

    /// Inserts a copy of `vector` under `entity_id` into every index whose scope
    /// is exactly `label`/`property`.
    pub(super) fn insert_for(
        &mut self,
        label: &str,
        property: &str,
        entity_id: u64,
        vector: &LoraVector,
    ) {
        let in_scope = self
            .by_name
            .values_mut()
            .filter(|entry| entry.label == label && entry.property == property);
        for entry in in_scope {
            entry.backend.insert(entity_id, vector.clone());
        }
    }

    /// Removes `entity_id` from every index whose scope is exactly
    /// `label`/`property`.
    pub(super) fn remove_for(&mut self, label: &str, property: &str, entity_id: u64) {
        let in_scope = self
            .by_name
            .values_mut()
            .filter(|entry| entry.label == label && entry.property == property);
        for entry in in_scope {
            entry.backend.remove(entity_id);
        }
    }

    /// Queries the index called `name`.
    ///
    /// Returns `None` when no such index exists; otherwise the `(id, score)` pairs
    /// produced by its backend for `query`, `k` and `restrict_to`, using the
    /// entry's own similarity metric.
    pub(super) fn query(
        &self,
        name: &str,
        query: &LoraVector,
        k: usize,
        restrict_to: Option<&BTreeSet<u64>>,
    ) -> Option<Vec<(u64, f64)>> {
        self.by_name
            .get(name)
            .map(|entry| entry.backend.query(query, entry.similarity, k, restrict_to))
    }

    /// Produces one snapshot per HNSW-backed index, in index-name order.
    ///
    /// Flat-backed indexes are skipped. `entity` is copied into every snapshot.
    pub(super) fn to_snapshots(&self, entity: StoredIndexEntity) -> Vec<VectorIndexSnapshot> {
        self.by_name
            .iter()
            .filter_map(|(name, entry)| match &entry.backend {
                VectorBackend::Hnsw(graph) => Some(VectorIndexSnapshot {
                    name: name.clone(),
                    entity,
                    label: entry.label.clone(),
                    property: entry.property.clone(),
                    data: VectorBackendSnapshot::Hnsw(graph.to_snapshot(entry.similarity)),
                }),
                VectorBackend::Flat(_) => None,
            })
            .collect()
    }

    /// Replaces the backend of an already-registered index with `snapshot`'s data.
    ///
    /// Returns `false`, leaving the registry untouched, when no index named
    /// `snapshot.name` exists, when its label or property differs from the
    /// snapshot's, or when the registered backend is not HNSW. On success the
    /// entry's similarity is overwritten with the snapshot's and `true` is returned.
    pub(super) fn restore_snapshot(&mut self, snapshot: VectorIndexSnapshot) -> bool {
        let VectorIndexSnapshot {
            name,
            label,
            property,
            data,
            ..
        } = snapshot;
        let entry = match self.by_name.get_mut(&name) {
            Some(entry) if entry.label == label && entry.property == property => entry,
            _ => return false,
        };
        match data {
            VectorBackendSnapshot::Hnsw(graph) => match entry.backend {
                VectorBackend::Hnsw(_) => {
                    // NOTE(review): the snapshot's metric wins over the registered
                    // one here — confirm that a mismatch should not be rejected.
                    entry.similarity = graph.similarity;
                    entry.backend = VectorBackend::Hnsw(HnswBackend::from_snapshot(graph));
                    true
                }
                VectorBackend::Flat(_) => false,
            },
        }
    }
}
/// Serializable capture of one vector index, as produced by
/// `VectorIndexRegistry::to_snapshots` and consumed by
/// `VectorIndexRegistry::restore_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct VectorIndexSnapshot {
/// Name of the index; used by `restore_snapshot` to find the registered entry.
pub name: String,
/// Value supplied by the caller of `to_snapshots`; `restore_snapshot` does not
/// inspect it.
pub entity: StoredIndexEntity,
/// Label of the index; must equal the registered entry's label on restore.
pub label: String,
/// Property of the index; must equal the registered entry's property on restore.
pub property: String,
/// Backend-specific payload.
pub data: VectorBackendSnapshot,
}
/// Backend-specific snapshot payload.
///
/// Only an HNSW variant exists; `VectorIndexRegistry::to_snapshots` emits nothing
/// for flat-backed indexes.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum VectorBackendSnapshot {
/// Snapshot of an `HnswBackend`.
Hnsw(HnswSnapshot),
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{RawCoordinate, VectorCoordinateType};

    /// Builds a `Float32` vector whose dimension equals `values.len()`.
    fn vec(values: &[f32]) -> LoraVector {
        let dimension = values.len() as i64;
        let coords = values
            .iter()
            .map(|&value| RawCoordinate::Float(f64::from(value)))
            .collect::<Vec<RawCoordinate>>();
        LoraVector::try_new(coords, dimension, VectorCoordinateType::Float32).unwrap()
    }

    /// Registers a flat-backed index with default HNSW parameters.
    fn register_flat(
        reg: &mut VectorIndexRegistry,
        name: &str,
        label: &str,
        prop: &str,
        sim: VectorSimilarity,
    ) {
        let (name, label, prop) = (name.to_string(), label.to_string(), prop.to_string());
        reg.register(
            name,
            label,
            prop,
            sim,
            VectorIndexProvider::Flat,
            HnswParams::default(),
        );
    }

    #[test]
    fn register_and_query_returns_scores() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(&mut registry, "vidx", "V", "e", VectorSimilarity::Cosine);
        let x_axis = vec(&[1.0, 0.0, 0.0]);
        registry.insert_for("V", "e", 1, &x_axis);
        registry.insert_for("V", "e", 2, &vec(&[0.0, 1.0, 0.0]));

        let hits = registry.query("vidx", &x_axis, 10, None).unwrap();
        assert_eq!(hits.len(), 2);

        // The vector equal to the query must score 1.0 and beat the other one.
        let score_of: BTreeMap<u64, f64> = hits.into_iter().collect();
        assert!((score_of[&1] - 1.0).abs() < 1e-9);
        assert!(score_of[&2] < score_of[&1]);
    }

    #[test]
    fn remove_drops_from_backend() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(&mut registry, "vidx", "V", "e", VectorSimilarity::Cosine);
        let x_axis = vec(&[1.0, 0.0]);
        registry.insert_for("V", "e", 1, &x_axis);
        registry.insert_for("V", "e", 2, &vec(&[0.0, 1.0]));

        registry.remove_for("V", "e", 1);

        let hits = registry.query("vidx", &x_axis, 10, None).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, 2);
    }

    #[test]
    fn unrelated_scope_is_skipped() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(
            &mut registry,
            "movie_emb",
            "Movie",
            "embedding",
            VectorSimilarity::Cosine,
        );
        let probe = vec(&[1.0, 0.0]);

        // Same property name but a different label: must not reach the index.
        registry.insert_for("Other", "embedding", 99, &probe);

        let hits = registry.query("movie_emb", &probe, 10, None).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn two_indexes_on_same_scope_with_different_metrics() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(&mut registry, "by_cos", "V", "e", VectorSimilarity::Cosine);
        register_flat(&mut registry, "by_euc", "V", "e", VectorSimilarity::Euclidean);
        let probe = vec(&[1.0, 0.0]);
        registry.insert_for("V", "e", 1, &probe);
        registry.insert_for("V", "e", 2, &vec(&[0.0, 1.0]));

        let cosine_hits = registry.query("by_cos", &probe, 10, None).unwrap();
        let euclidean_hits = registry.query("by_euc", &probe, 10, None).unwrap();
        assert_eq!(cosine_hits.len(), 2);
        assert_eq!(euclidean_hits.len(), 2);

        // Both indexes cover the same scope, so each received both vectors.
        for entry in registry.by_name.values() {
            assert_eq!(entry.backend.len(), 2);
        }
    }

    #[test]
    fn hnsw_provider_returns_top_k() {
        let mut registry = VectorIndexRegistry::default();
        registry.register(
            "vh".to_string(),
            "V".to_string(),
            "e".to_string(),
            VectorSimilarity::Cosine,
            VectorIndexProvider::Hnsw,
            HnswParams::default(),
        );
        for i in 0..50u64 {
            let x = (i as f32) / 50.0;
            let y = 1.0 - (i as f32) / 50.0;
            registry.insert_for("V", "e", i, &vec(&[x, y]));
        }

        let hits = registry.query("vh", &vec(&[1.0, 0.0]), 5, None).unwrap();
        assert_eq!(hits.len(), 5);

        let ids: Vec<u64> = hits.iter().map(|&(id, _)| id).collect();
        assert!(ids.contains(&49) || ids.contains(&48), "got {ids:?}");
    }
}