Skip to main content

lora_store/memory/
vector_index.rs

1//! VECTOR index storage and query backend.
2//!
//! Phase 1 of the vector-indexing extension: introduce a backend
//! abstraction so the procedure layer no longer scans the property
//! store directly. The [`VectorBackend`] enum now carries two arms —
//! the `Flat` brute-force backend and the HNSW backend — which share
//! the registry shape and the maintenance hooks.
9//!
10//! Unlike the TEXT / POINT / SORTED registries — which key by
11//! `(label, property)` because the underlying structure is shared
12//! across catalog entries with the same scope — vector indexes are
13//! keyed by **index name**. Two vector indexes on the same
14//! `(label, property)` can coexist with different similarity functions
15//! (e.g. one cosine, one euclidean), each owning its own backend.
16
17use std::collections::{BTreeMap, BTreeSet};
18
19use serde::{Deserialize, Serialize};
20
21use crate::{
22    cosine_similarity_bounded, dot_product, euclidean_similarity, manhattan_distance, LoraVector,
23};
24
25use super::hnsw::{seed_from_name, HnswBackend, HnswParams, HnswSnapshot};
26use super::index_catalog::{IndexConfigValue, StoredIndexEntity};
27
28#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
29pub enum VectorSimilarity {
30    Cosine,
31    Euclidean,
32    /// Raw dot product. Higher is more similar; unbounded above and
33    /// below. The right choice for embeddings already L2-normalized
34    /// (cosine reduces to dot in that case, and dot skips one
35    /// reciprocal-sqrt per pair).
36    Dot,
37    /// L1-derived: `1 / (1 + d_L1)`. Same higher-is-better shape as
38    /// `Euclidean`; useful for quantized vectors where L1 is the
39    /// natural metric.
40    Manhattan,
41}
42
43impl VectorSimilarity {
44    pub fn parse(s: &str) -> Option<Self> {
45        if s.eq_ignore_ascii_case("cosine") {
46            Some(VectorSimilarity::Cosine)
47        } else if s.eq_ignore_ascii_case("euclidean") {
48            Some(VectorSimilarity::Euclidean)
49        } else if s.eq_ignore_ascii_case("dot") || s.eq_ignore_ascii_case("dot_product") {
50            Some(VectorSimilarity::Dot)
51        } else if s.eq_ignore_ascii_case("manhattan") {
52            Some(VectorSimilarity::Manhattan)
53        } else {
54            None
55        }
56    }
57
58    pub fn score(self, a: &LoraVector, b: &LoraVector) -> Option<f64> {
59        if a.dimension != b.dimension {
60            return None;
61        }
62        match self {
63            VectorSimilarity::Cosine => cosine_similarity_bounded(a, b),
64            VectorSimilarity::Euclidean => euclidean_similarity(a, b),
65            VectorSimilarity::Dot => dot_product(a, b),
66            VectorSimilarity::Manhattan => manhattan_distance(a, b).map(|d| 1.0 / (1.0 + d)),
67        }
68    }
69
70    /// Resolve `vector.similarity_function` from a catalog `OPTIONS`
71    /// map. Returns `None` when the key is missing or unrecognised;
72    /// DDL validation has already rejected invalid values, so a
73    /// `None` here only occurs on a malformed snapshot/WAL payload —
74    /// the caller picks a default in that case.
75    pub(super) fn from_options(options: &BTreeMap<String, IndexConfigValue>) -> Option<Self> {
76        match options.get("vector.similarity_function")? {
77            IndexConfigValue::String(s) => Self::parse(s),
78            _ => None,
79        }
80    }
81}
82
83/// Brute-force backend: store every vector, score them all per query.
84/// `BTreeMap` keying gives deterministic iteration order, which keeps
85/// score-tie ordering stable across runs (matches the legacy
86/// `score_entities` contract).
87#[derive(Debug, Default, Clone)]
88pub(super) struct FlatBackend {
89    items: BTreeMap<u64, LoraVector>,
90}
91
92impl FlatBackend {
93    fn insert(&mut self, id: u64, vector: LoraVector) {
94        self.items.insert(id, vector);
95    }
96
97    fn remove(&mut self, id: u64) {
98        self.items.remove(&id);
99    }
100
101    fn query(
102        &self,
103        query: &LoraVector,
104        similarity: VectorSimilarity,
105        restrict_to: Option<&BTreeSet<u64>>,
106    ) -> Vec<(u64, f64)> {
107        let mut out = Vec::with_capacity(self.items.len());
108        for (&id, v) in &self.items {
109            if let Some(set) = restrict_to {
110                if !set.contains(&id) {
111                    continue;
112                }
113            }
114            if let Some(score) = similarity.score(v, query) {
115                out.push((id, score));
116            }
117        }
118        out
119    }
120
121    #[cfg(test)]
122    fn len(&self) -> usize {
123        self.items.len()
124    }
125}
126
127/// Selector for which backend powers a given index. Surfaced via the
128/// `vector.indexProvider` index option; defaults to `Flat`.
129#[derive(Debug, Clone, Copy, PartialEq, Eq)]
130pub enum VectorIndexProvider {
131    Flat,
132    Hnsw,
133}
134
135impl VectorIndexProvider {
136    pub fn parse(s: &str) -> Option<Self> {
137        if s.eq_ignore_ascii_case("flat") {
138            Some(VectorIndexProvider::Flat)
139        } else if s.eq_ignore_ascii_case("hnsw") {
140            Some(VectorIndexProvider::Hnsw)
141        } else {
142            None
143        }
144    }
145
146    /// Resolve `vector.indexProvider` from a catalog `OPTIONS` map.
147    /// `'flat'` and `'hnsw'` are accepted; anything else returns
148    /// `None` and the caller falls back to the safe default.
149    pub(super) fn from_options(options: &BTreeMap<String, IndexConfigValue>) -> Option<Self> {
150        match options.get("vector.indexProvider")? {
151            IndexConfigValue::String(s) => Self::parse(s),
152            _ => None,
153        }
154    }
155}
156
157/// Backend dispatch. The Hnsw arm owns its own similarity (it
158/// internalizes scoring during graph construction); the Flat arm
159/// takes similarity per-query because it has no precomputed work to
160/// pin to a single metric.
161#[derive(Debug, Clone)]
162pub(super) enum VectorBackend {
163    Flat(FlatBackend),
164    Hnsw(HnswBackend),
165}
166
167impl VectorBackend {
168    fn insert(&mut self, id: u64, vector: LoraVector) {
169        match self {
170            VectorBackend::Flat(b) => b.insert(id, vector),
171            VectorBackend::Hnsw(b) => b.insert(id, vector),
172        }
173    }
174
175    fn remove(&mut self, id: u64) {
176        match self {
177            VectorBackend::Flat(b) => b.remove(id),
178            VectorBackend::Hnsw(b) => b.remove(id),
179        }
180    }
181
182    /// `similarity` and `k` are only honored by some arms:
183    /// - Flat: scores every point with `similarity`, returns all
184    ///   matching (id, score). The caller sorts + truncates to k.
185    /// - Hnsw: ignores `similarity` (configured at construction),
186    ///   uses `k` to size the result set inside the graph walk.
187    ///
188    /// `restrict_to` is a hard filter: only ids in the set may
189    /// appear in the result. HNSW still traverses through other
190    /// nodes for routing — recall against a very selective filter
191    /// degrades; callers facing tight filters should raise
192    /// `vector.hnsw.ef_search`.
193    fn query(
194        &self,
195        query: &LoraVector,
196        similarity: VectorSimilarity,
197        k: usize,
198        restrict_to: Option<&BTreeSet<u64>>,
199    ) -> Vec<(u64, f64)> {
200        match self {
201            VectorBackend::Flat(b) => b.query(query, similarity, restrict_to),
202            VectorBackend::Hnsw(b) => b.query(query, k, restrict_to),
203        }
204    }
205
206    #[cfg(test)]
207    fn len(&self) -> usize {
208        match self {
209            VectorBackend::Flat(b) => b.len(),
210            VectorBackend::Hnsw(b) => b.len(),
211        }
212    }
213}
214
215/// One installed VECTOR index, plus the resolved metadata the
216/// maintenance hook needs to decide whether a given property change
217/// applies (without having to re-read the catalog).
218#[derive(Debug, Clone)]
219pub(super) struct VectorIndexEntry {
220    pub label: String,
221    pub property: String,
222    pub similarity: VectorSimilarity,
223    pub backend: VectorBackend,
224}
225
226/// Per-entity-kind registry of vector indexes. Keyed by index name.
227#[derive(Debug, Default, Clone)]
228pub(super) struct VectorIndexRegistry {
229    by_name: BTreeMap<String, VectorIndexEntry>,
230}
231
232impl VectorIndexRegistry {
233    pub(super) fn register(
234        &mut self,
235        name: String,
236        label: String,
237        property: String,
238        similarity: VectorSimilarity,
239        provider: VectorIndexProvider,
240        hnsw: HnswParams,
241    ) {
242        let backend = match provider {
243            VectorIndexProvider::Flat => VectorBackend::Flat(FlatBackend::default()),
244            VectorIndexProvider::Hnsw => {
245                let seed = seed_from_name(&name);
246                VectorBackend::Hnsw(HnswBackend::new(similarity, hnsw, seed))
247            }
248        };
249        self.by_name.insert(
250            name,
251            VectorIndexEntry {
252                label,
253                property,
254                similarity,
255                backend,
256            },
257        );
258    }
259
260    pub(super) fn deregister(&mut self, name: &str) {
261        self.by_name.remove(name);
262    }
263
264    pub(super) fn is_empty(&self) -> bool {
265        self.by_name.is_empty()
266    }
267
268    /// Insert `vector` for `entity_id` into every index whose
269    /// `(label, property)` matches. Used by both the initial backfill
270    /// from `activate_vector_index` and per-mutation maintenance.
271    pub(super) fn insert_for(
272        &mut self,
273        label: &str,
274        property: &str,
275        entity_id: u64,
276        vector: &LoraVector,
277    ) {
278        for entry in self.by_name.values_mut() {
279            if entry.label == label && entry.property == property {
280                entry.backend.insert(entity_id, vector.clone());
281            }
282        }
283    }
284
285    /// Drop `entity_id` from every index whose `(label, property)`
286    /// matches.
287    pub(super) fn remove_for(&mut self, label: &str, property: &str, entity_id: u64) {
288        for entry in self.by_name.values_mut() {
289            if entry.label == label && entry.property == property {
290                entry.backend.remove(entity_id);
291            }
292        }
293    }
294
295    /// Run a top-k scan against a named index, optionally
296    /// restricting results to the given id set. Returns `(id,
297    /// score)` pairs from the backend; the caller applies the
298    /// canonical `sort_by_desc(score) then asc(id)` + truncate
299    /// post-step that matches the legacy `scored_rows` contract.
300    /// The flat arm returns all matching entities; the HNSW arm
301    /// caps at k internally.
302    pub(super) fn query(
303        &self,
304        name: &str,
305        query: &LoraVector,
306        k: usize,
307        restrict_to: Option<&BTreeSet<u64>>,
308    ) -> Option<Vec<(u64, f64)>> {
309        let entry = self.by_name.get(name)?;
310        Some(entry.backend.query(query, entry.similarity, k, restrict_to))
311    }
312
313    /// Capture a serializable snapshot of every HNSW backend in this
314    /// registry. Flat backends are skipped because their state is
315    /// reconstructible from the property store at zero cost — only
316    /// HNSW pays the O(n log n) rebuild penalty that justifies
317    /// shipping graph topology through the snapshot pipeline.
318    pub(super) fn to_snapshots(&self, entity: StoredIndexEntity) -> Vec<VectorIndexSnapshot> {
319        let mut out = Vec::new();
320        for (name, entry) in &self.by_name {
321            if let VectorBackend::Hnsw(b) = &entry.backend {
322                out.push(VectorIndexSnapshot {
323                    name: name.clone(),
324                    entity,
325                    label: entry.label.clone(),
326                    property: entry.property.clone(),
327                    data: VectorBackendSnapshot::Hnsw(b.to_snapshot(entry.similarity)),
328                });
329            }
330        }
331        out
332    }
333
334    /// Replace the backend for `snapshot.name` with one rebuilt from
335    /// the snapshot data. No-op if the index isn't registered, the
336    /// snapshot kind doesn't match, or the scope (label/property)
337    /// diverges — all signals that the catalog and the snapshot are
338    /// out of step, in which case we fall back to the property-store
339    /// backfill.
340    pub(super) fn restore_snapshot(&mut self, snapshot: VectorIndexSnapshot) -> bool {
341        let Some(entry) = self.by_name.get_mut(&snapshot.name) else {
342            return false;
343        };
344        if entry.label != snapshot.label || entry.property != snapshot.property {
345            return false;
346        }
347        match snapshot.data {
348            VectorBackendSnapshot::Hnsw(snap) => {
349                if !matches!(entry.backend, VectorBackend::Hnsw(_)) {
350                    return false;
351                }
352                entry.similarity = snap.similarity;
353                entry.backend = VectorBackend::Hnsw(HnswBackend::from_snapshot(snap));
354                true
355            }
356        }
357    }
358}
359
360/// Snapshot of one vector index, carried through the snapshot
361/// pipeline. Only HNSW backends are persisted today (see
362/// [`VectorIndexRegistry::to_snapshots`]); the enum is open for a
363/// future Flat arm if pre-built flat backends become expensive to
364/// rebuild for some workload.
365#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
366pub struct VectorIndexSnapshot {
367    pub name: String,
368    pub entity: StoredIndexEntity,
369    pub label: String,
370    pub property: String,
371    pub data: VectorBackendSnapshot,
372}
373
374#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
375pub enum VectorBackendSnapshot {
376    Hnsw(HnswSnapshot),
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{RawCoordinate, VectorCoordinateType};

    /// Build a `Float32` vector whose dimension equals `values.len()`.
    fn float_vec(values: &[f32]) -> LoraVector {
        let mut coords: Vec<RawCoordinate> = Vec::with_capacity(values.len());
        for &value in values {
            coords.push(RawCoordinate::Float(value as f64));
        }
        LoraVector::try_new(coords, values.len() as i64, VectorCoordinateType::Float32).unwrap()
    }

    /// Register a flat-backed index with default HNSW parameters.
    fn register_flat(
        reg: &mut VectorIndexRegistry,
        name: &str,
        label: &str,
        prop: &str,
        sim: VectorSimilarity,
    ) {
        reg.register(
            name.to_string(),
            label.to_string(),
            prop.to_string(),
            sim,
            VectorIndexProvider::Flat,
            HnswParams::default(),
        );
    }

    #[test]
    fn register_and_query_returns_scores() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(&mut registry, "vidx", "V", "e", VectorSimilarity::Cosine);
        registry.insert_for("V", "e", 1, &float_vec(&[1.0, 0.0, 0.0]));
        registry.insert_for("V", "e", 2, &float_vec(&[0.0, 1.0, 0.0]));

        let probe = float_vec(&[1.0, 0.0, 0.0]);
        let hits = registry.query("vidx", &probe, 10, None).unwrap();

        // Both entities come back; entity 1 equals the probe, so it scores 1.0.
        assert_eq!(hits.len(), 2);
        let scores: BTreeMap<u64, f64> = hits.into_iter().collect();
        assert!((scores[&1] - 1.0).abs() < 1e-9);
        assert!(scores[&2] < scores[&1]);
    }

    #[test]
    fn remove_drops_from_backend() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(&mut registry, "vidx", "V", "e", VectorSimilarity::Cosine);
        registry.insert_for("V", "e", 1, &float_vec(&[1.0, 0.0]));
        registry.insert_for("V", "e", 2, &float_vec(&[0.0, 1.0]));

        registry.remove_for("V", "e", 1);

        let hits = registry
            .query("vidx", &float_vec(&[1.0, 0.0]), 10, None)
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, 2);
    }

    #[test]
    fn unrelated_scope_is_skipped() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(
            &mut registry,
            "movie_emb",
            "Movie",
            "embedding",
            VectorSimilarity::Cosine,
        );

        // Same property but a different label: the index must ignore it.
        registry.insert_for("Other", "embedding", 99, &float_vec(&[1.0, 0.0]));

        let hits = registry
            .query("movie_emb", &float_vec(&[1.0, 0.0]), 10, None)
            .unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn two_indexes_on_same_scope_with_different_metrics() {
        let mut registry = VectorIndexRegistry::default();
        register_flat(&mut registry, "by_cos", "V", "e", VectorSimilarity::Cosine);
        register_flat(&mut registry, "by_euc", "V", "e", VectorSimilarity::Euclidean);
        registry.insert_for("V", "e", 1, &float_vec(&[1.0, 0.0]));
        registry.insert_for("V", "e", 2, &float_vec(&[0.0, 1.0]));

        let probe = float_vec(&[1.0, 0.0]);
        let cosine_hits = registry.query("by_cos", &probe, 10, None).unwrap();
        let euclidean_hits = registry.query("by_euc", &probe, 10, None).unwrap();
        assert_eq!(cosine_hits.len(), 2);
        assert_eq!(euclidean_hits.len(), 2);

        // Each index owns its own backend, and each received both vectors.
        for entry in registry.by_name.values() {
            assert_eq!(entry.backend.len(), 2);
        }
    }

    #[test]
    fn hnsw_provider_returns_top_k() {
        let mut registry = VectorIndexRegistry::default();
        registry.register(
            "vh".to_string(),
            "V".to_string(),
            "e".to_string(),
            VectorSimilarity::Cosine,
            VectorIndexProvider::Hnsw,
            HnswParams::default(),
        );

        // Fifty points on the segment from [0, 1] towards [1, 0].
        for i in 0..50u64 {
            let t = (i as f32) / 50.0;
            registry.insert_for("V", "e", i, &float_vec(&[t, 1.0 - t]));
        }

        let hits = registry
            .query("vh", &float_vec(&[1.0, 0.0]), 5, None)
            .unwrap();
        assert_eq!(hits.len(), 5);

        // The high-i end of the segment lies nearest the probe, so at
        // least one of the two last ids must be among the hits.
        let ids: Vec<u64> = hits.iter().map(|&(id, _)| id).collect();
        assert!(ids.contains(&49) || ids.contains(&48), "got {ids:?}");
    }
}
489}