Skip to main content

oxibonsai_rag/
vector_store.rs

1//! In-memory flat vector store with configurable distance metric.
2//!
3//! The [`VectorStore`] holds a flat list of [`VectorEntry`] items.  Search
4//! is performed with a brute-force linear scan over all entries, evaluating
5//! the configured [`Distance`] metric against the query vector.  This is
6//! appropriate for corpora up to tens of thousands of chunks; larger
7//! corpora benefit from approximate-nearest-neighbour indices (out of
8//! scope for this crate).
9//!
10//! Scoring semantics are unified by [`Distance::to_score`]: similarity
11//! metrics (Cosine, DotProduct) use their raw value as the score, whereas
12//! true distances (Euclidean, Angular, Hamming) are negated so that
13//! "higher is better" sorting always yields the closest match first.
14//!
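//! NaN / Inf guards: any non-finite value in an inserted vector is rejected
//! with [`RagError::NonFinite`].  A query vector containing a non-finite value
//! (or of the wrong dimension) is not an error: it simply matches nothing and
//! the search returns an empty result set.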
17
18use serde::{Deserialize, Serialize};
19
20use crate::chunker::Chunk;
21use crate::distance::Distance;
22use crate::error::RagError;
23use crate::metadata_filter::MetadataFilter;
24
25// ─────────────────────────────────────────────────────────────────────────────
26// Math primitives
27// ─────────────────────────────────────────────────────────────────────────────
28
/// Inner product `Σ aᵢ·bᵢ` of two `f32` slices.
///
/// Slices whose lengths differ are treated as incompatible and produce `0.0`
/// rather than a panic; two empty slices likewise produce zero.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum()
    } else {
        0.0
    }
}
39
/// L2-normalise `v` in place.
///
/// The norm is accumulated in `f64`, so vectors with large-magnitude
/// components are still normalised correctly. With `f32` accumulation their
/// sum of squares would overflow to `+∞` and every component would collapse
/// to zero.
///
/// The vector is left unchanged (preventing NaN propagation) when:
/// - its Euclidean norm is not greater than `1e-10` (including the all-zero
///   and empty vectors), or
/// - the norm is not finite, i.e. the input already contains `NaN` / `±∞`.
#[inline]
pub fn l2_normalize(v: &mut [f32]) {
    // Every f32 square (at most ~1.2e77) fits easily in f64, even when summed
    // over any realistic dimension, so this cannot overflow for finite input.
    let norm: f64 = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    // A tiny norm would divide by (almost) zero; a non-finite norm means the
    // input holds NaN/±∞ and dividing would only spread NaNs.
    if norm > 1e-10 && norm.is_finite() {
        for x in v.iter_mut() {
            *x = (f64::from(*x) / norm) as f32;
        }
    }
}
53
54/// Cosine similarity between two equal-length vectors.
55///
56/// Both vectors are assumed to be *unit vectors* (L2-normalised).  Under
57/// that assumption, cosine similarity == dot product and the denominator
58/// can be skipped.
59///
60/// Returns a value in `[-1.0, 1.0]`.  Returns `0.0` for empty or mismatched
61/// inputs rather than panicking.
62#[inline]
63pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
64    if a.is_empty() || a.len() != b.len() {
65        return 0.0;
66    }
67    dot_product(a, b).clamp(-1.0, 1.0)
68}
69
70// ─────────────────────────────────────────────────────────────────────────────
71// VectorEntry & SearchResult
72// ─────────────────────────────────────────────────────────────────────────────
73
74/// A single indexed entry in the vector store.
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct VectorEntry {
77    /// Unique identifier assigned at insertion time.
78    pub id: usize,
79    /// Stored embedding vector.  For similarity metrics this is
80    /// L2-normalised; for distance metrics it is stored verbatim.
81    pub vector: Vec<f32>,
82    /// The chunk this entry was derived from.
83    pub chunk: Chunk,
84}
85
86/// A result returned by a similarity / distance search.
87#[derive(Debug, Clone)]
88pub struct SearchResult {
89    /// Unified "higher is better" score (see [`Distance::to_score`]).  For
90    /// similarity metrics this equals the raw similarity; for true
91    /// distances it equals the negative of the raw distance.
92    pub score: f32,
93    /// The chunk associated with this result.
94    pub chunk: Chunk,
95    /// The entry's unique identifier in the store.
96    pub id: usize,
97}
98
99// ─────────────────────────────────────────────────────────────────────────────
100// VectorStore
101// ─────────────────────────────────────────────────────────────────────────────
102
103/// In-memory flat vector store backed by a `Vec<VectorEntry>`.
104///
105/// The configured [`Distance`] controls both how vectors are *stored*
106/// (similarity metrics pre-normalise; distance metrics store verbatim) and
107/// how queries are scored.
108#[derive(Debug, Default, Serialize, Deserialize)]
109pub struct VectorStore {
110    entries: Vec<VectorEntry>,
111    dim: usize,
112    #[serde(default)]
113    distance: Distance,
114}
115
116impl VectorStore {
117    /// Create an empty cosine-similarity store for vectors of dim `dim`.
118    pub fn new(dim: usize) -> Self {
119        Self::new_with_distance(dim, Distance::default())
120    }
121
122    /// Create an empty store with a specific [`Distance`] metric.
123    pub fn new_with_distance(dim: usize, distance: Distance) -> Self {
124        Self {
125            entries: Vec::new(),
126            dim,
127            distance,
128        }
129    }
130
131    /// Insert a vector+chunk pair into the store.
132    ///
133    /// Behaviour depends on the store's [`Distance`]:
134    ///
135    /// - Similarity metrics (Cosine, DotProduct, Angular) L2-normalise the
136    ///   stored vector up-front so that scoring is cheap.
137    /// - True distance metrics (Euclidean, Hamming) preserve the vector
138    ///   verbatim.
139    ///
140    /// Returns the assigned entry id.  Errors:
141    ///
142    /// - [`RagError::DimensionMismatch`] for wrong-size vectors.
143    /// - [`RagError::NonFinite`] for `NaN` / `±∞` entries.
144    pub fn insert(&mut self, mut vector: Vec<f32>, chunk: Chunk) -> Result<usize, RagError> {
145        if vector.len() != self.dim {
146            return Err(RagError::DimensionMismatch {
147                expected: self.dim,
148                got: vector.len(),
149            });
150        }
151        if vector.iter().any(|x| !x.is_finite()) {
152            return Err(RagError::NonFinite);
153        }
154        if matches!(
155            self.distance,
156            Distance::Cosine | Distance::DotProduct | Distance::Angular
157        ) {
158            l2_normalize(&mut vector);
159        }
160        let id = self.entries.len();
161        self.entries.push(VectorEntry { id, vector, chunk });
162        Ok(id)
163    }
164
165    /// Return the top-`top_k` entries by score.
166    ///
167    /// The query vector is normalised internally when the metric is a
168    /// similarity; it is not mutated.  Results are returned in descending
169    /// score order (see [`SearchResult::score`] for polarity).
170    pub fn search(&self, query: &[f32], top_k: usize) -> Vec<SearchResult> {
171        self.search_with_threshold(query, top_k, f32::NEG_INFINITY)
172    }
173
174    /// Like [`Self::search`] but discards results whose score is below
175    /// `min_score`.
176    pub fn search_with_threshold(
177        &self,
178        query: &[f32],
179        top_k: usize,
180        min_score: f32,
181    ) -> Vec<SearchResult> {
182        self.scored(query, top_k, min_score, None)
183    }
184
185    /// Search filtered by a [`MetadataFilter`].
186    ///
187    /// Filter evaluation is post-scoring; the metric is evaluated against
188    /// every entry, then results that fail the filter are discarded.
189    pub fn search_filtered(
190        &self,
191        query: &[f32],
192        top_k: usize,
193        filter: &MetadataFilter,
194    ) -> Result<Vec<SearchResult>, RagError> {
195        filter.validate()?;
196        Ok(self.scored(query, top_k, f32::NEG_INFINITY, Some(filter)))
197    }
198
199    fn scored(
200        &self,
201        query: &[f32],
202        top_k: usize,
203        min_score: f32,
204        filter: Option<&MetadataFilter>,
205    ) -> Vec<SearchResult> {
206        if self.entries.is_empty() || top_k == 0 || query.len() != self.dim {
207            return Vec::new();
208        }
209        if query.iter().any(|x| !x.is_finite()) {
210            return Vec::new();
211        }
212
213        // Prepare the query according to metric semantics.
214        let prepared: Vec<f32> = if matches!(
215            self.distance,
216            Distance::Cosine | Distance::DotProduct | Distance::Angular
217        ) {
218            let mut q = query.to_vec();
219            l2_normalize(&mut q);
220            q
221        } else {
222            query.to_vec()
223        };
224
225        let mut scored: Vec<(f32, usize)> = Vec::with_capacity(self.entries.len());
226        for entry in &self.entries {
227            if let Some(f) = filter {
228                if !f.matches(&entry.chunk.metadata) {
229                    continue;
230                }
231            }
232            let raw = match self.distance.compute(&prepared, &entry.vector) {
233                Ok(v) => v,
234                Err(_) => continue,
235            };
236            let score = self.distance.to_score(raw);
237            if score >= min_score {
238                scored.push((score, entry.id));
239            }
240        }
241
242        scored.sort_unstable_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
243        scored.truncate(top_k);
244
245        scored
246            .into_iter()
247            .map(|(score, id)| SearchResult {
248                score,
249                chunk: self.entries[id].chunk.clone(),
250                id,
251            })
252            .collect()
253    }
254
255    /// Number of entries currently in the store.
256    pub fn len(&self) -> usize {
257        self.entries.len()
258    }
259
260    /// Returns `true` if the store contains no entries.
261    pub fn is_empty(&self) -> bool {
262        self.entries.is_empty()
263    }
264
265    /// Remove all entries from the store (preserves the configured
266    /// dimension and distance metric).
267    pub fn clear(&mut self) {
268        self.entries.clear();
269    }
270
271    /// Approximate heap memory used by the stored vectors and chunk
272    /// texts.  This is a lower-bound estimate: it counts vector bytes and
273    /// chunk-text bytes but ignores allocator overhead and struct
274    /// padding.
275    pub fn memory_usage_bytes(&self) -> usize {
276        self.entries.iter().fold(0usize, |acc, e| {
277            acc + e.vector.len() * std::mem::size_of::<f32>()
278                + e.chunk.text.len()
279                + std::mem::size_of::<VectorEntry>()
280        })
281    }
282
283    /// The embedding dimensionality this store was constructed with.
284    pub fn dim(&self) -> usize {
285        self.dim
286    }
287
288    /// The active distance metric.
289    pub fn distance(&self) -> Distance {
290        self.distance
291    }
292
293    /// Borrow the internal entries (used by the persistence layer).
294    pub(crate) fn entries(&self) -> &[VectorEntry] {
295        &self.entries
296    }
297
298    /// Replace the internal entries (used by the persistence layer).
299    pub(crate) fn set_entries(&mut self, entries: Vec<VectorEntry>) {
300        self.entries = entries;
301    }
302}