Skip to main content

reddb_server/storage/query/
similarity.rs

//! Similarity Search Integration for Query Engine
//!
//! Provides vector similarity search capabilities integrated with
//! the query engine for semantic search and nearest neighbor queries.

6use super::filter::Filter;
7use super::sort::QueryLimits;
8use crate::storage::engine::distance::DistanceMetric;
9use crate::storage::engine::vector_store::{SearchResult, VectorCollection, VectorId};
10use crate::storage::schema::Value;
11use std::collections::HashMap;
12
/// Owned, dense `f32` vector used as the probe of a similarity query.
///
/// The components are private; read them through [`DenseVector::as_slice`].
#[derive(Clone, Debug)]
pub struct DenseVector {
    // Vector components in order.
    values: Vec<f32>,
}
18
19impl DenseVector {
20    pub fn new(values: Vec<f32>) -> Self {
21        Self { values }
22    }
23
24    pub fn as_slice(&self) -> &[f32] {
25        &self.values
26    }
27}
28
29impl From<Vec<f32>> for DenseVector {
30    fn from(values: Vec<f32>) -> Self {
31        Self { values }
32    }
33}
34
35/// Similarity query parameters
36#[derive(Debug, Clone)]
37pub struct SimilarityQuery {
38    /// Query vector
39    pub vector: DenseVector,
40    /// Number of neighbors to find
41    pub k: usize,
42    /// Distance metric
43    pub distance: DistanceMetric,
44    /// Optional filter to apply before similarity search
45    pub filter: Option<Filter>,
46    /// Number of probes for IVF index (if applicable)
47    pub n_probes: Option<usize>,
48    /// Distance threshold (for range queries)
49    pub distance_threshold: Option<f32>,
50}
51
52impl SimilarityQuery {
53    /// Create a new similarity query
54    pub fn new(vector: DenseVector, k: usize) -> Self {
55        Self {
56            vector,
57            k,
58            distance: DistanceMetric::Cosine,
59            filter: None,
60            n_probes: None,
61            distance_threshold: None,
62        }
63    }
64
65    /// Set distance metric
66    pub fn with_distance(mut self, distance: DistanceMetric) -> Self {
67        self.distance = distance;
68        self
69    }
70
71    /// Set pre-filter
72    pub fn with_filter(mut self, filter: Filter) -> Self {
73        self.filter = Some(filter);
74        self
75    }
76
77    /// Set number of IVF probes
78    pub fn with_probes(mut self, n_probes: usize) -> Self {
79        self.n_probes = Some(n_probes);
80        self
81    }
82
83    /// Set distance threshold for range query
84    pub fn with_threshold(mut self, threshold: f32) -> Self {
85        self.distance_threshold = Some(threshold);
86        self
87    }
88}
89
90/// Similarity search result with metadata
91#[derive(Debug, Clone)]
92pub struct SimilarityResult {
93    /// Vector ID
94    pub id: VectorId,
95    /// Distance to query
96    pub distance: f32,
97    /// Similarity score (1 - normalized_distance for bounded metrics)
98    pub score: f32,
99    /// Associated metadata (optional)
100    pub metadata: Option<HashMap<String, Value>>,
101}
102
103impl SimilarityResult {
104    /// Create a new result
105    pub fn new(id: VectorId, distance: f32) -> Self {
106        Self {
107            id,
108            distance,
109            score: 1.0 / (1.0 + distance), // Simple similarity transform
110            metadata: None,
111        }
112    }
113
114    /// Create with score conversion based on distance metric
115    pub fn with_metric(id: VectorId, distance: f32, metric: DistanceMetric) -> Self {
116        let score = match metric {
117            DistanceMetric::Cosine => 1.0 - distance, // Cosine: 0 = identical, 2 = opposite
118            DistanceMetric::InnerProduct => -distance, // Negated dot product
119            DistanceMetric::L2 => 1.0 / (1.0 + distance),
120        };
121
122        Self {
123            id,
124            distance,
125            score: score.max(0.0),
126            metadata: None,
127        }
128    }
129
130    /// Add metadata
131    pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
132        self.metadata = Some(metadata);
133        self
134    }
135}
136
137/// Result set from similarity search
138#[derive(Debug, Clone)]
139pub struct SimilarityResultSet {
140    /// Results sorted by distance
141    pub results: Vec<SimilarityResult>,
142    /// Query vector dimension
143    pub dimension: usize,
144    /// Distance metric used
145    pub distance: DistanceMetric,
146    /// Total vectors searched (for approximate search)
147    pub vectors_searched: Option<usize>,
148    /// Search time in microseconds
149    pub search_time_us: u64,
150}
151
152impl SimilarityResultSet {
153    /// Create empty result set
154    pub fn empty(dimension: usize, distance: DistanceMetric) -> Self {
155        Self {
156            results: Vec::new(),
157            dimension,
158            distance,
159            vectors_searched: None,
160            search_time_us: 0,
161        }
162    }
163
164    /// Create from search results
165    pub fn from_results(
166        results: Vec<SearchResult>,
167        dimension: usize,
168        distance: DistanceMetric,
169    ) -> Self {
170        let similarity_results = results
171            .into_iter()
172            .map(|r| SimilarityResult::with_metric(r.id, r.distance, distance))
173            .collect();
174
175        Self {
176            results: similarity_results,
177            dimension,
178            distance,
179            vectors_searched: None,
180            search_time_us: 0,
181        }
182    }
183
184    /// Get number of results
185    pub fn len(&self) -> usize {
186        self.results.len()
187    }
188
189    /// Check if empty
190    pub fn is_empty(&self) -> bool {
191        self.results.is_empty()
192    }
193
194    /// Get top-k IDs
195    pub fn top_ids(&self, k: usize) -> Vec<VectorId> {
196        self.results.iter().take(k).map(|r| r.id).collect()
197    }
198
199    /// Get results above score threshold
200    pub fn above_score(&self, threshold: f32) -> Vec<&SimilarityResult> {
201        self.results
202            .iter()
203            .filter(|r| r.score >= threshold)
204            .collect()
205    }
206
207    /// Apply limits
208    pub fn apply_limits(mut self, limits: QueryLimits) -> Self {
209        self.results = limits.apply(self.results);
210        self
211    }
212}
213
214/// Trait for vector index that supports similarity search
215pub trait VectorIndex: Send + Sync {
216    /// Search for k nearest neighbors
217    fn search(&self, query: &DenseVector, k: usize) -> Vec<SearchResult>;
218
219    /// Search with optional parameters
220    fn search_with_params(
221        &self,
222        query: &DenseVector,
223        k: usize,
224        n_probes: Option<usize>,
225    ) -> Vec<SearchResult>;
226
227    /// Get vector by ID
228    fn get(&self, id: VectorId) -> Option<DenseVector>;
229
230    /// Get dimension
231    fn dimension(&self) -> usize;
232
233    /// Get distance metric
234    fn distance_metric(&self) -> DistanceMetric;
235
236    /// Get number of indexed vectors
237    fn len(&self) -> usize;
238
239    /// Check if empty
240    fn is_empty(&self) -> bool {
241        self.len() == 0
242    }
243}
244
245impl VectorIndex for VectorCollection {
246    fn search(&self, query: &DenseVector, k: usize) -> Vec<SearchResult> {
247        VectorCollection::search(self, query.as_slice(), k)
248    }
249
250    fn search_with_params(
251        &self,
252        query: &DenseVector,
253        k: usize,
254        _n_probes: Option<usize>,
255    ) -> Vec<SearchResult> {
256        VectorCollection::search(self, query.as_slice(), k)
257    }
258
259    fn get(&self, id: VectorId) -> Option<DenseVector> {
260        VectorCollection::get(self, id).map(|vec| DenseVector::new(vec.clone()))
261    }
262
263    fn dimension(&self) -> usize {
264        self.dimension
265    }
266
267    fn distance_metric(&self) -> DistanceMetric {
268        self.metric
269    }
270
271    fn len(&self) -> usize {
272        self.len()
273    }
274}
275
276/// Execute similarity search
277pub fn execute_similarity_search(
278    index: &dyn VectorIndex,
279    query: &SimilarityQuery,
280) -> SimilarityResultSet {
281    let start = std::time::Instant::now();
282
283    // Perform search
284    let results = if let Some(threshold) = query.distance_threshold {
285        // Range query: get more results then filter
286        let candidates = index.search_with_params(&query.vector, query.k * 10, query.n_probes);
287        candidates
288            .into_iter()
289            .filter(|r| r.distance <= threshold)
290            .take(query.k)
291            .collect()
292    } else {
293        index.search_with_params(&query.vector, query.k, query.n_probes)
294    };
295
296    let search_time = start.elapsed().as_micros() as u64;
297
298    let mut result_set =
299        SimilarityResultSet::from_results(results, index.dimension(), index.distance_metric());
300    result_set.search_time_us = search_time;
301    result_set.vectors_searched = Some(index.len());
302
303    result_set
304}
305
306/// Hybrid search combining filter and similarity
307pub fn execute_hybrid_search<F>(
308    index: &dyn VectorIndex,
309    query: &SimilarityQuery,
310    get_metadata: F,
311    filter_matches: impl Fn(VectorId, &Filter) -> bool,
312) -> SimilarityResultSet
313where
314    F: Fn(VectorId) -> Option<HashMap<String, Value>>,
315{
316    let start = std::time::Instant::now();
317
318    // Get more candidates than needed to account for filtering
319    let over_fetch = if query.filter.is_some() { 10 } else { 1 };
320    let candidates = index.search_with_params(&query.vector, query.k * over_fetch, query.n_probes);
321
322    // Apply filter and collect results
323    let results: Vec<SimilarityResult> = candidates
324        .into_iter()
325        .filter(|r| {
326            if let Some(filter) = &query.filter {
327                filter_matches(r.id, filter)
328            } else {
329                true
330            }
331        })
332        .take(query.k)
333        .map(|r| {
334            let mut result =
335                SimilarityResult::with_metric(r.id, r.distance, index.distance_metric());
336            if let Some(meta) = get_metadata(r.id) {
337                result = result.with_metadata(meta);
338            }
339            result
340        })
341        .collect();
342
343    let search_time = start.elapsed().as_micros() as u64;
344
345    SimilarityResultSet {
346        results,
347        dimension: index.dimension(),
348        distance: index.distance_metric(),
349        vectors_searched: Some(index.len()),
350        search_time_us: search_time,
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3-dimensional cosine collection holding five vectors.
    ///
    /// Ids follow insertion order starting at 0 (as `test_vector_index_trait` confirms):
    /// 0 = [1, 0, 0], 1 = [0, 1, 0], 2 = [0, 0, 1], 3 = [0.7, 0.7, 0], 4 = [0.5, 0.5, 0.7].
    fn create_test_index() -> VectorCollection {
        let mut collection = VectorCollection::new("test", 3).with_metric(DistanceMetric::Cosine);

        // Add test vectors; the insert results are not needed here.
        let _ = collection.insert(vec![1.0, 0.0, 0.0], None);
        let _ = collection.insert(vec![0.0, 1.0, 0.0], None);
        let _ = collection.insert(vec![0.0, 0.0, 1.0], None);
        let _ = collection.insert(vec![0.7, 0.7, 0.0], None);
        let _ = collection.insert(vec![0.5, 0.5, 0.7], None);

        collection
    }

    #[test]
    fn test_similarity_query_basic() {
        let index = create_test_index();

        let query = SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 3);
        let results = execute_similarity_search(&index, &query);

        assert_eq!(results.len(), 3);
        assert_eq!(results.results[0].id, 0); // Exact match
        assert!(results.results[0].distance < 0.01);
    }

    #[test]
    fn test_similarity_result_score() {
        // Cosine distance 0 = identical
        let result = SimilarityResult::with_metric(1, 0.0, DistanceMetric::Cosine);
        assert!((result.score - 1.0).abs() < 0.01);

        // Cosine distance 1 = orthogonal
        let result = SimilarityResult::with_metric(1, 1.0, DistanceMetric::Cosine);
        assert!(result.score < 0.01);

        // Cosine distance 2 = opposite; the negative raw score is clamped to zero.
        let result = SimilarityResult::with_metric(1, 2.0, DistanceMetric::Cosine);
        assert_eq!(result.score, 0.0);

        // L2 maps distance d to 1 / (1 + d).
        let result = SimilarityResult::with_metric(1, 1.0, DistanceMetric::L2);
        assert!((result.score - 0.5).abs() < 1e-6);

        // Inner product: the score is the negated distance, clamped at zero.
        let result = SimilarityResult::with_metric(1, -0.8, DistanceMetric::InnerProduct);
        assert!((result.score - 0.8).abs() < 1e-6);
        let result = SimilarityResult::with_metric(1, 0.5, DistanceMetric::InnerProduct);
        assert_eq!(result.score, 0.0);
    }

    #[test]
    fn test_similarity_result_set_top_ids() {
        let index = create_test_index();

        let query = SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 5);
        let results = execute_similarity_search(&index, &query);

        let top3 = results.top_ids(3);
        assert_eq!(top3.len(), 3);
        assert_eq!(top3[0], 0);
    }

    #[test]
    fn test_similarity_threshold() {
        let index = create_test_index();

        // Query with distance threshold
        let query =
            SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 10).with_threshold(0.5);

        let results = execute_similarity_search(&index, &query);

        // The exact match (distance ~0) must survive; otherwise the loop below is vacuous.
        assert!(results.results.iter().any(|r| r.id == 0));

        // Only vectors within threshold should be returned
        for result in &results.results {
            assert!(result.distance <= 0.5);
        }
    }

    #[test]
    fn test_vector_index_trait() {
        let index = create_test_index();

        let index_ref: &dyn VectorIndex = &index;

        assert_eq!(index_ref.dimension(), 3);
        assert_eq!(index_ref.len(), 5);
        assert!(!index_ref.is_empty());

        let vec = index_ref.get(0).unwrap();
        assert_eq!(vec.as_slice(), &[1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_above_score_filter() {
        let results = SimilarityResultSet {
            results: vec![
                SimilarityResult::new(1, 0.1), // score ~0.91
                SimilarityResult::new(2, 0.5), // score ~0.67
                SimilarityResult::new(3, 2.0), // score ~0.33
            ],
            dimension: 3,
            distance: DistanceMetric::L2,
            vectors_searched: Some(100),
            search_time_us: 100,
        };

        let above_05 = results.above_score(0.5);
        assert_eq!(above_05.len(), 2); // 0.91 and 0.67 are >= 0.5
    }

    #[test]
    fn test_similarity_query_builder() {
        let query = SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 10)
            .with_distance(DistanceMetric::L2)
            .with_probes(5)
            .with_threshold(1.0);

        assert_eq!(query.k, 10);
        assert_eq!(query.distance, DistanceMetric::L2);
        assert_eq!(query.n_probes, Some(5));
        assert_eq!(query.distance_threshold, Some(1.0));
    }

    #[test]
    fn test_hybrid_search_with_filter() {
        let index = create_test_index();

        // Mock metadata keyed by the ids `create_test_index` assigns (0..=4):
        // even ids are category "A", odd ids are category "B".
        let metadata: HashMap<VectorId, HashMap<String, Value>> = [
            (
                0,
                [("category".to_string(), Value::text("A".to_string()))]
                    .into_iter()
                    .collect(),
            ),
            (
                1,
                [("category".to_string(), Value::text("B".to_string()))]
                    .into_iter()
                    .collect(),
            ),
            (
                2,
                [("category".to_string(), Value::text("A".to_string()))]
                    .into_iter()
                    .collect(),
            ),
            (
                3,
                [("category".to_string(), Value::text("B".to_string()))]
                    .into_iter()
                    .collect(),
            ),
            (
                4,
                [("category".to_string(), Value::text("A".to_string()))]
                    .into_iter()
                    .collect(),
            ),
        ]
        .into_iter()
        .collect();

        let filter = Filter::eq("category", Value::text("A".to_string()));
        let query =
            SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 5).with_filter(filter);

        let results = execute_hybrid_search(
            &index,
            &query,
            |id| metadata.get(&id).cloned(),
            |id, filter| {
                if let Some(meta) = metadata.get(&id) {
                    filter.evaluate(&|col| meta.get(col).cloned())
                } else {
                    false
                }
            },
        );

        // Category "A" covers ids 0, 2 and 4, and k = 5 leaves room for all of them.
        assert_eq!(results.len(), 3);
        assert_eq!(results.results[0].id, 0); // Exact match ranks first
        for result in &results.results {
            let meta = result
                .metadata
                .as_ref()
                .expect("every indexed vector has metadata in this test");
            assert_eq!(meta.get("category"), Some(&Value::text("A".to_string())));
        }
    }

    #[test]
    fn test_apply_limits() {
        let results = SimilarityResultSet {
            results: (0..10)
                .map(|i| SimilarityResult::new(i, i as f32 * 0.1))
                .collect(),
            dimension: 3,
            distance: DistanceMetric::L2,
            vectors_searched: Some(100),
            search_time_us: 100,
        };

        let limited = results.apply_limits(QueryLimits::none().offset(2).limit(3));
        assert_eq!(limited.len(), 3);
        assert_eq!(limited.results[0].id, 2);
    }
}