Skip to main content

engine/
search.rs

1use common::{DistanceMetric, PaginationCursor, QueryResponse, SearchResult, Vector};
2
3use crate::distance::calculate_distance;
4
5/// Perform brute-force search over vectors
6/// Returns top-k most similar vectors sorted by score descending
7/// Supports cursor-based pagination for efficient result traversal
8pub fn brute_force_search(
9    query: &[f32],
10    vectors: &[Vector],
11    top_k: usize,
12    metric: DistanceMetric,
13    include_metadata: bool,
14    include_vectors: bool,
15    cursor: Option<&PaginationCursor>,
16) -> QueryResponse {
17    // Validate and clamp top_k to reasonable bounds
18    let top_k = if top_k == 0 {
19        tracing::warn!("top_k of 0 is invalid, using 1");
20        1
21    } else if top_k > 10_000 {
22        tracing::warn!("top_k {} exceeds maximum, clamping to 10000", top_k);
23        10_000
24    } else {
25        top_k
26    };
27
28    if vectors.is_empty() {
29        return QueryResponse {
30            results: vec![],
31            next_cursor: None,
32            has_more: Some(false),
33            search_time_ms: 0,
34        };
35    }
36
37    // Calculate scores for all vectors
38    let mut scored: Vec<(f32, &Vector)> = vectors
39        .iter()
40        .map(|v| (calculate_distance(query, &v.values, metric), v))
41        .collect();
42
43    // Sort by score descending (higher similarity = better), then by id for tie-breaking
44    scored.sort_by(
45        |a, b| match b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal) {
46            std::cmp::Ordering::Equal => a.1.id.cmp(&b.1.id),
47            other => other,
48        },
49    );
50
51    // Apply cursor filter if provided
52    // For descending score order, "after cursor" means:
53    // score < last_score OR (score == last_score AND id > last_id)
54    let filtered: Vec<_> = if let Some(cursor) = cursor {
55        scored
56            .into_iter()
57            .filter(|(score, vector)| {
58                *score < cursor.last_score
59                    || (*score == cursor.last_score && vector.id > cursor.last_id)
60            })
61            .collect()
62    } else {
63        scored
64    };
65
66    // Fetch top_k + 1 to determine if there are more results
67    let fetch_count = top_k + 1;
68    let fetched: Vec<_> = filtered.into_iter().take(fetch_count).collect();
69    let has_more = fetched.len() > top_k;
70
71    // Take only top_k for the actual results
72    let results_slice = if has_more {
73        &fetched[..top_k]
74    } else {
75        &fetched[..]
76    };
77
78    // Build results
79    let results: Vec<SearchResult> = results_slice
80        .iter()
81        .map(|(score, vector)| SearchResult {
82            id: vector.id.clone(),
83            score: *score,
84            metadata: if include_metadata {
85                vector.metadata.clone()
86            } else {
87                None
88            },
89            vector: if include_vectors {
90                Some(vector.values.clone())
91            } else {
92                None
93            },
94        })
95        .collect();
96
97    // Generate next cursor from the last result
98    let next_cursor = if has_more {
99        results.last().map(|last_result| {
100            PaginationCursor::new(last_result.score, last_result.id.clone()).encode()
101        })
102    } else {
103        None
104    };
105
106    QueryResponse {
107        results,
108        next_cursor,
109        has_more: Some(has_more),
110        search_time_ms: 0,
111    }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---- fixtures and helpers -------------------------------------------------

    /// Builds a `Vector` with the given id, values and optional metadata; the TTL
    /// related fields are left unset.
    fn build_vector(id: &str, values: Vec<f32>, metadata: Option<serde_json::Value>) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Vector without metadata.
    fn make_vector(id: &str, values: Vec<f32>) -> Vector {
        build_vector(id, values, None)
    }

    /// Vector carrying the supplied metadata.
    fn make_vector_with_metadata(
        id: &str,
        values: Vec<f32>,
        metadata: serde_json::Value,
    ) -> Vector {
        build_vector(id, values, Some(metadata))
    }

    /// Five 2-d vectors `v1`..`v5` that drift progressively away from `[1.0, 0.0]`.
    fn five_fanned_vectors() -> Vec<Vector> {
        [
            ("v1", [1.0_f32, 0.0]),
            ("v2", [0.9, 0.1]),
            ("v3", [0.8, 0.2]),
            ("v4", [0.7, 0.3]),
            ("v5", [0.6, 0.4]),
        ]
        .iter()
        .map(|(id, values)| make_vector(id, values.to_vec()))
        .collect()
    }

    /// Cosine search that requests neither metadata nor vector payloads.
    fn cosine_page(
        query: &[f32],
        vectors: &[Vector],
        top_k: usize,
        cursor: Option<&PaginationCursor>,
    ) -> QueryResponse {
        brute_force_search(
            query,
            vectors,
            top_k,
            DistanceMetric::Cosine,
            false,
            false,
            cursor,
        )
    }

    /// Ids of the returned results, in ranked order.
    fn ranked_ids(response: &QueryResponse) -> Vec<&str> {
        response.results.iter().map(|hit| hit.id.as_str()).collect()
    }

    /// Decodes the cursor a response hands out for its following page.
    fn cursor_after(response: &QueryResponse) -> PaginationCursor {
        PaginationCursor::decode(response.next_cursor.as_ref().unwrap()).unwrap()
    }

    // ---- basic search behaviour -----------------------------------------------

    #[test]
    fn test_brute_force_search_empty() {
        let query = [1.0_f32, 0.0, 0.0];
        let vectors: Vec<Vector> = Vec::new();

        let response = brute_force_search(
            &query,
            &vectors,
            5,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );

        assert!(response.results.is_empty());
        assert_eq!(response.has_more, Some(false));
        assert!(response.next_cursor.is_none());
    }

    #[test]
    fn test_brute_force_search_single_vector() {
        let query = [1.0_f32, 0.0, 0.0];
        let vectors = vec![make_vector("v1", vec![1.0, 0.0, 0.0])];

        let response = brute_force_search(
            &query,
            &vectors,
            5,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );

        assert_eq!(response.results.len(), 1);
        let only_hit = &response.results[0];
        assert_eq!(only_hit.id, "v1");
        assert!((only_hit.score - 1.0).abs() < 1e-6);
        assert_eq!(response.has_more, Some(false));
    }

    #[test]
    fn test_brute_force_search_ordering() {
        let query = [1.0_f32, 0.0, 0.0];
        let vectors = vec![
            // Perfect match
            make_vector("v1", vec![1.0, 0.0, 0.0]),
            // Orthogonal
            make_vector("v2", vec![0.0, 1.0, 0.0]),
            // 45 degrees
            make_vector("v3", vec![0.707, 0.707, 0.0]),
        ];

        let response = brute_force_search(
            &query,
            &vectors,
            3,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );

        // Best match first, worst match last.
        assert_eq!(ranked_ids(&response), ["v1", "v3", "v2"]);
    }

    #[test]
    fn test_brute_force_search_top_k() {
        let query = [1.0_f32, 0.0];
        let vectors = five_fanned_vectors();

        let response = brute_force_search(
            &query,
            &vectors,
            3,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );

        assert_eq!(response.results.len(), 3);
        assert_eq!(response.results[0].id, "v1");
        // Two vectors are left over, so another page must be advertised.
        assert_eq!(response.has_more, Some(true));
        assert!(response.next_cursor.is_some());
    }

    #[test]
    fn test_brute_force_search_include_metadata() {
        let query = [1.0_f32, 0.0];
        let vectors = vec![make_vector_with_metadata(
            "v1",
            vec![1.0, 0.0],
            json!({"key": "value"}),
        )];

        // Metadata requested.
        let with_metadata = brute_force_search(
            &query,
            &vectors,
            1,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );
        assert!(with_metadata.results[0].metadata.is_some());

        // Metadata not requested.
        let without_metadata = cosine_page(&query, &vectors, 1, None);
        assert!(without_metadata.results[0].metadata.is_none());
    }

    #[test]
    fn test_brute_force_search_include_vectors() {
        let query = [1.0_f32, 0.0];
        let vectors = vec![make_vector("v1", vec![1.0, 0.0])];

        // Vector payload requested.
        let with_values = brute_force_search(
            &query,
            &vectors,
            1,
            DistanceMetric::Cosine,
            false,
            true,
            None,
        );
        assert_eq!(
            with_values.results[0].vector.as_deref(),
            Some(&[1.0_f32, 0.0][..])
        );

        // Vector payload not requested.
        let without_values = cosine_page(&query, &vectors, 1, None);
        assert!(without_values.results[0].vector.is_none());
    }

    #[test]
    fn test_brute_force_search_euclidean() {
        let query = [0.0_f32, 0.0];
        let vectors = vec![
            make_vector("v1", vec![1.0, 0.0]), // Distance 1
            make_vector("v2", vec![3.0, 4.0]), // Distance 5
            make_vector("v3", vec![0.5, 0.0]), // Distance 0.5
        ];

        let response = brute_force_search(
            &query,
            &vectors,
            3,
            DistanceMetric::Euclidean,
            false,
            false,
            None,
        );

        // Closest first (negative euclidean, so -0.5 > -1.0 > -5.0)
        assert_eq!(&ranked_ids(&response)[..3], ["v3", "v1", "v2"]);
    }

    // ---- pagination -----------------------------------------------------------

    #[test]
    fn test_pagination_basic() {
        let query = [1.0_f32, 0.0];
        let vectors = five_fanned_vectors();

        // Page 1: no cursor.
        let first = cosine_page(&query, &vectors, 2, None);
        assert_eq!(ranked_ids(&first), ["v1", "v2"]);
        assert_eq!(first.has_more, Some(true));
        assert!(first.next_cursor.is_some());

        // Page 2: resume from the cursor handed out by page 1.
        let after_first = cursor_after(&first);
        let second = cosine_page(&query, &vectors, 2, Some(&after_first));
        assert_eq!(ranked_ids(&second), ["v3", "v4"]);
        assert_eq!(second.has_more, Some(true));

        // Page 3: the final, partially filled page.
        let after_second = cursor_after(&second);
        let third = cosine_page(&query, &vectors, 2, Some(&after_second));
        assert_eq!(ranked_ids(&third), ["v5"]);
        assert_eq!(third.has_more, Some(false));
        assert!(third.next_cursor.is_none());
    }

    #[test]
    fn test_pagination_cursor_encode_decode() {
        let original = PaginationCursor::new(0.95, "test_id".to_string());

        let round_tripped = PaginationCursor::decode(&original.encode()).unwrap();

        assert!((round_tripped.last_score - 0.95).abs() < 1e-6);
        assert_eq!(round_tripped.last_id, "test_id");
    }

    #[test]
    fn test_pagination_with_tie_scores() {
        let query = [1.0_f32, 0.0];
        // Identical values everywhere, so every score ties and only the id orders them.
        let vectors: Vec<Vector> = ["a", "b", "c", "d"]
            .iter()
            .map(|id| make_vector(id, vec![1.0, 0.0]))
            .collect();

        // Page 1: ties are resolved alphabetically by id.
        let first = cosine_page(&query, &vectors, 2, None);
        assert_eq!(ranked_ids(&first), ["a", "b"]);
        assert_eq!(first.has_more, Some(true));

        // Page 2: picks up after "b" even though the score is unchanged.
        let after_first = cursor_after(&first);
        let second = cosine_page(&query, &vectors, 2, Some(&after_first));
        assert_eq!(ranked_ids(&second), ["c", "d"]);
        assert_eq!(second.has_more, Some(false));
    }

    #[test]
    fn test_pagination_no_more_results() {
        let query = [1.0_f32, 0.0];
        let vectors = vec![
            make_vector("v1", vec![1.0, 0.0]),
            make_vector("v2", vec![0.9, 0.1]),
        ];

        let response = cosine_page(&query, &vectors, 5, None);

        assert_eq!(response.results.len(), 2);
        assert_eq!(response.has_more, Some(false));
        assert!(response.next_cursor.is_none());
    }
}