1use std::collections::{BTreeMap, BTreeSet};
18
19use serde::{Deserialize, Serialize};
20
21use crate::{
22 cosine_similarity_bounded, dot_product, euclidean_similarity, manhattan_distance, LoraVector,
23};
24
25use super::hnsw::{seed_from_name, HnswBackend, HnswParams, HnswSnapshot};
26use super::index_catalog::{IndexConfigValue, StoredIndexEntity};
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
29pub enum VectorSimilarity {
30 Cosine,
31 Euclidean,
32 Dot,
37 Manhattan,
41}
42
43impl VectorSimilarity {
44 pub fn parse(s: &str) -> Option<Self> {
45 if s.eq_ignore_ascii_case("cosine") {
46 Some(VectorSimilarity::Cosine)
47 } else if s.eq_ignore_ascii_case("euclidean") {
48 Some(VectorSimilarity::Euclidean)
49 } else if s.eq_ignore_ascii_case("dot") || s.eq_ignore_ascii_case("dot_product") {
50 Some(VectorSimilarity::Dot)
51 } else if s.eq_ignore_ascii_case("manhattan") {
52 Some(VectorSimilarity::Manhattan)
53 } else {
54 None
55 }
56 }
57
58 pub fn score(self, a: &LoraVector, b: &LoraVector) -> Option<f64> {
59 if a.dimension != b.dimension {
60 return None;
61 }
62 match self {
63 VectorSimilarity::Cosine => cosine_similarity_bounded(a, b),
64 VectorSimilarity::Euclidean => euclidean_similarity(a, b),
65 VectorSimilarity::Dot => dot_product(a, b),
66 VectorSimilarity::Manhattan => manhattan_distance(a, b).map(|d| 1.0 / (1.0 + d)),
67 }
68 }
69
70 pub(super) fn from_options(options: &BTreeMap<String, IndexConfigValue>) -> Option<Self> {
76 match options.get("vector.similarity_function")? {
77 IndexConfigValue::String(s) => Self::parse(s),
78 _ => None,
79 }
80 }
81}
82
83#[derive(Debug, Default, Clone)]
88pub(super) struct FlatBackend {
89 items: BTreeMap<u64, LoraVector>,
90}
91
92impl FlatBackend {
93 fn insert(&mut self, id: u64, vector: LoraVector) {
94 self.items.insert(id, vector);
95 }
96
97 fn remove(&mut self, id: u64) {
98 self.items.remove(&id);
99 }
100
101 fn query(
102 &self,
103 query: &LoraVector,
104 similarity: VectorSimilarity,
105 restrict_to: Option<&BTreeSet<u64>>,
106 ) -> Vec<(u64, f64)> {
107 let mut out = Vec::with_capacity(self.items.len());
108 for (&id, v) in &self.items {
109 if let Some(set) = restrict_to {
110 if !set.contains(&id) {
111 continue;
112 }
113 }
114 if let Some(score) = similarity.score(v, query) {
115 out.push((id, score));
116 }
117 }
118 out
119 }
120
121 #[cfg(test)]
122 fn len(&self) -> usize {
123 self.items.len()
124 }
125}
126
/// Selects which backend implementation a vector index is built on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VectorIndexProvider {
    /// Index backed by `FlatBackend`.
    Flat,
    /// Index backed by `HnswBackend`.
    Hnsw,
}
134
135impl VectorIndexProvider {
136 pub fn parse(s: &str) -> Option<Self> {
137 if s.eq_ignore_ascii_case("flat") {
138 Some(VectorIndexProvider::Flat)
139 } else if s.eq_ignore_ascii_case("hnsw") {
140 Some(VectorIndexProvider::Hnsw)
141 } else {
142 None
143 }
144 }
145
146 pub(super) fn from_options(options: &BTreeMap<String, IndexConfigValue>) -> Option<Self> {
150 match options.get("vector.indexProvider")? {
151 IndexConfigValue::String(s) => Self::parse(s),
152 _ => None,
153 }
154 }
155}
156
157#[derive(Debug, Clone)]
162pub(super) enum VectorBackend {
163 Flat(FlatBackend),
164 Hnsw(HnswBackend),
165}
166
167impl VectorBackend {
168 fn insert(&mut self, id: u64, vector: LoraVector) {
169 match self {
170 VectorBackend::Flat(b) => b.insert(id, vector),
171 VectorBackend::Hnsw(b) => b.insert(id, vector),
172 }
173 }
174
175 fn remove(&mut self, id: u64) {
176 match self {
177 VectorBackend::Flat(b) => b.remove(id),
178 VectorBackend::Hnsw(b) => b.remove(id),
179 }
180 }
181
182 fn query(
194 &self,
195 query: &LoraVector,
196 similarity: VectorSimilarity,
197 k: usize,
198 restrict_to: Option<&BTreeSet<u64>>,
199 ) -> Vec<(u64, f64)> {
200 match self {
201 VectorBackend::Flat(b) => b.query(query, similarity, restrict_to),
202 VectorBackend::Hnsw(b) => b.query(query, k, restrict_to),
203 }
204 }
205
206 #[cfg(test)]
207 fn len(&self) -> usize {
208 match self {
209 VectorBackend::Flat(b) => b.len(),
210 VectorBackend::Hnsw(b) => b.len(),
211 }
212 }
213}
214
215#[derive(Debug, Clone)]
219pub(super) struct VectorIndexEntry {
220 pub label: String,
221 pub property: String,
222 pub similarity: VectorSimilarity,
223 pub backend: VectorBackend,
224}
225
226#[derive(Debug, Default, Clone)]
228pub(super) struct VectorIndexRegistry {
229 by_name: BTreeMap<String, VectorIndexEntry>,
230}
231
232impl VectorIndexRegistry {
233 pub(super) fn register(
234 &mut self,
235 name: String,
236 label: String,
237 property: String,
238 similarity: VectorSimilarity,
239 provider: VectorIndexProvider,
240 hnsw: HnswParams,
241 ) {
242 let backend = match provider {
243 VectorIndexProvider::Flat => VectorBackend::Flat(FlatBackend::default()),
244 VectorIndexProvider::Hnsw => {
245 let seed = seed_from_name(&name);
246 VectorBackend::Hnsw(HnswBackend::new(similarity, hnsw, seed))
247 }
248 };
249 self.by_name.insert(
250 name,
251 VectorIndexEntry {
252 label,
253 property,
254 similarity,
255 backend,
256 },
257 );
258 }
259
260 pub(super) fn deregister(&mut self, name: &str) {
261 self.by_name.remove(name);
262 }
263
264 pub(super) fn is_empty(&self) -> bool {
265 self.by_name.is_empty()
266 }
267
268 pub(super) fn insert_for(
272 &mut self,
273 label: &str,
274 property: &str,
275 entity_id: u64,
276 vector: &LoraVector,
277 ) {
278 for entry in self.by_name.values_mut() {
279 if entry.label == label && entry.property == property {
280 entry.backend.insert(entity_id, vector.clone());
281 }
282 }
283 }
284
285 pub(super) fn remove_for(&mut self, label: &str, property: &str, entity_id: u64) {
288 for entry in self.by_name.values_mut() {
289 if entry.label == label && entry.property == property {
290 entry.backend.remove(entity_id);
291 }
292 }
293 }
294
295 pub(super) fn query(
303 &self,
304 name: &str,
305 query: &LoraVector,
306 k: usize,
307 restrict_to: Option<&BTreeSet<u64>>,
308 ) -> Option<Vec<(u64, f64)>> {
309 let entry = self.by_name.get(name)?;
310 Some(entry.backend.query(query, entry.similarity, k, restrict_to))
311 }
312
313 pub(super) fn to_snapshots(&self, entity: StoredIndexEntity) -> Vec<VectorIndexSnapshot> {
319 let mut out = Vec::new();
320 for (name, entry) in &self.by_name {
321 if let VectorBackend::Hnsw(b) = &entry.backend {
322 out.push(VectorIndexSnapshot {
323 name: name.clone(),
324 entity,
325 label: entry.label.clone(),
326 property: entry.property.clone(),
327 data: VectorBackendSnapshot::Hnsw(b.to_snapshot(entry.similarity)),
328 });
329 }
330 }
331 out
332 }
333
334 pub(super) fn restore_snapshot(&mut self, snapshot: VectorIndexSnapshot) -> bool {
341 let Some(entry) = self.by_name.get_mut(&snapshot.name) else {
342 return false;
343 };
344 if entry.label != snapshot.label || entry.property != snapshot.property {
345 return false;
346 }
347 match snapshot.data {
348 VectorBackendSnapshot::Hnsw(snap) => {
349 if !matches!(entry.backend, VectorBackend::Hnsw(_)) {
350 return false;
351 }
352 entry.similarity = snap.similarity;
353 entry.backend = VectorBackend::Hnsw(HnswBackend::from_snapshot(snap));
354 true
355 }
356 }
357 }
358}
359
360#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
366pub struct VectorIndexSnapshot {
367 pub name: String,
368 pub entity: StoredIndexEntity,
369 pub label: String,
370 pub property: String,
371 pub data: VectorBackendSnapshot,
372}
373
374#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
375pub enum VectorBackendSnapshot {
376 Hnsw(HnswSnapshot),
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{RawCoordinate, VectorCoordinateType};

    /// Builds a `Float32` vector whose dimension equals `values.len()`.
    fn vec(values: &[f32]) -> LoraVector {
        let dimension = values.len() as i64;
        let coords: Vec<RawCoordinate> = values
            .iter()
            .map(|&component| RawCoordinate::Float(f64::from(component)))
            .collect();
        LoraVector::try_new(coords, dimension, VectorCoordinateType::Float32).unwrap()
    }

    /// Registers a flat-backed index with default HNSW parameters.
    fn register_flat(
        reg: &mut VectorIndexRegistry,
        name: &str,
        label: &str,
        prop: &str,
        sim: VectorSimilarity,
    ) {
        reg.register(
            name.to_string(),
            label.to_string(),
            prop.to_string(),
            sim,
            VectorIndexProvider::Flat,
            HnswParams::default(),
        );
    }

    #[test]
    fn register_and_query_returns_scores() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(&mut registry, "vidx", "V", "e", VectorSimilarity::Cosine);
        registry.insert_for("V", "e", 1, &vec(&[1.0, 0.0, 0.0]));
        registry.insert_for("V", "e", 2, &vec(&[0.0, 1.0, 0.0]));

        let probe = vec(&[1.0, 0.0, 0.0]);
        let hits = registry.query("vidx", &probe, 10, None).unwrap();
        assert_eq!(hits.len(), 2);

        let scores: BTreeMap<u64, f64> = hits.into_iter().collect();
        let (identical, orthogonal) = (scores[&1], scores[&2]);
        assert!((identical - 1.0).abs() < 1e-9);
        assert!(orthogonal < identical);
    }

    #[test]
    fn remove_drops_from_backend() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(&mut registry, "vidx", "V", "e", VectorSimilarity::Cosine);
        registry.insert_for("V", "e", 1, &vec(&[1.0, 0.0]));
        registry.insert_for("V", "e", 2, &vec(&[0.0, 1.0]));
        registry.remove_for("V", "e", 1);

        let probe = vec(&[1.0, 0.0]);
        let hits = registry.query("vidx", &probe, 10, None).unwrap();
        let remaining: Vec<u64> = hits.into_iter().map(|(id, _)| id).collect();
        assert_eq!(remaining, [2u64]);
    }

    #[test]
    fn unrelated_scope_is_skipped() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(
            &mut registry,
            "movie_emb",
            "Movie",
            "embedding",
            VectorSimilarity::Cosine,
        );
        // Same property, different label: the index must not pick this up.
        registry.insert_for("Other", "embedding", 99, &vec(&[1.0, 0.0]));

        let probe = vec(&[1.0, 0.0]);
        let hits = registry.query("movie_emb", &probe, 10, None).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn two_indexes_on_same_scope_with_different_metrics() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(&mut registry, "by_cos", "V", "e", VectorSimilarity::Cosine);
        register_flat(&mut registry, "by_euc", "V", "e", VectorSimilarity::Euclidean);
        registry.insert_for("V", "e", 1, &vec(&[1.0, 0.0]));
        registry.insert_for("V", "e", 2, &vec(&[0.0, 1.0]));

        let probe = vec(&[1.0, 0.0]);
        let cosine_hits = registry.query("by_cos", &probe, 10, None).unwrap();
        let euclidean_hits = registry.query("by_euc", &probe, 10, None).unwrap();
        assert_eq!(cosine_hits.len(), 2);
        assert_eq!(euclidean_hits.len(), 2);

        // Each insert must have reached both indexes on the shared scope.
        assert!(registry
            .by_name
            .values()
            .all(|entry| entry.backend.len() == 2));
    }

    #[test]
    fn hnsw_provider_returns_top_k() {
        let mut registry = VectorIndexRegistry::default();
        registry.register(
            "vh".to_string(),
            "V".to_string(),
            "e".to_string(),
            VectorSimilarity::Cosine,
            VectorIndexProvider::Hnsw,
            HnswParams::default(),
        );
        for i in 0..50u64 {
            let ratio = (i as f32) / 50.0;
            registry.insert_for("V", "e", i, &vec(&[ratio, 1.0 - ratio]));
        }

        let hits = registry.query("vh", &vec(&[1.0, 0.0]), 5, None).unwrap();
        assert_eq!(hits.len(), 5);

        let ids: Vec<u64> = hits.iter().map(|&(id, _)| id).collect();
        assert!(ids.contains(&49) || ids.contains(&48), "got {ids:?}");
    }
}