1use common::{DistanceMetric, PaginationCursor, QueryResponse, SearchResult, Vector};
2
3use crate::distance::calculate_distance;
4
5pub fn brute_force_search(
9 query: &[f32],
10 vectors: &[Vector],
11 top_k: usize,
12 metric: DistanceMetric,
13 include_metadata: bool,
14 include_vectors: bool,
15 cursor: Option<&PaginationCursor>,
16) -> QueryResponse {
17 let top_k = if top_k == 0 {
19 tracing::warn!("top_k of 0 is invalid, using 1");
20 1
21 } else if top_k > 10_000 {
22 tracing::warn!("top_k {} exceeds maximum, clamping to 10000", top_k);
23 10_000
24 } else {
25 top_k
26 };
27
28 if vectors.is_empty() {
29 return QueryResponse {
30 results: vec![],
31 next_cursor: None,
32 has_more: Some(false),
33 search_time_ms: 0,
34 };
35 }
36
37 let mut scored: Vec<(f32, &Vector)> = vectors
39 .iter()
40 .map(|v| (calculate_distance(query, &v.values, metric), v))
41 .collect();
42
43 scored.sort_by(
45 |a, b| match b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal) {
46 std::cmp::Ordering::Equal => a.1.id.cmp(&b.1.id),
47 other => other,
48 },
49 );
50
51 let filtered: Vec<_> = if let Some(cursor) = cursor {
55 scored
56 .into_iter()
57 .filter(|(score, vector)| {
58 *score < cursor.last_score
59 || (*score == cursor.last_score && vector.id > cursor.last_id)
60 })
61 .collect()
62 } else {
63 scored
64 };
65
66 let fetch_count = top_k + 1;
68 let fetched: Vec<_> = filtered.into_iter().take(fetch_count).collect();
69 let has_more = fetched.len() > top_k;
70
71 let results_slice = if has_more {
73 &fetched[..top_k]
74 } else {
75 &fetched[..]
76 };
77
78 let results: Vec<SearchResult> = results_slice
80 .iter()
81 .map(|(score, vector)| SearchResult {
82 id: vector.id.clone(),
83 score: *score,
84 metadata: if include_metadata {
85 vector.metadata.clone()
86 } else {
87 None
88 },
89 vector: if include_vectors {
90 Some(vector.values.clone())
91 } else {
92 None
93 },
94 })
95 .collect();
96
97 let next_cursor = if has_more {
99 results.last().map(|last_result| {
100 PaginationCursor::new(last_result.score, last_result.id.clone()).encode()
101 })
102 } else {
103 None
104 };
105
106 QueryResponse {
107 results,
108 next_cursor,
109 has_more: Some(has_more),
110 search_time_ms: 0,
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a vector with the given id and values, and no metadata or TTL.
    fn make_vector(id: &str, values: Vec<f32>) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata: None,
            ttl_seconds: None,
            expires_at: None,
        }
    }

    /// Builds a vector with the given id, values and JSON metadata, and no TTL.
    fn make_vector_with_metadata(
        id: &str,
        values: Vec<f32>,
        metadata: serde_json::Value,
    ) -> Vector {
        Vector {
            id: id.to_string(),
            values,
            metadata: Some(metadata),
            ttl_seconds: None,
            expires_at: None,
        }
    }

    #[test]
    fn test_brute_force_search_empty() {
        let query = vec![1.0, 0.0, 0.0];
        let vectors: Vec<Vector> = vec![];

        let result = brute_force_search(
            &query,
            &vectors,
            5,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );
        assert!(result.results.is_empty());
        assert_eq!(result.has_more, Some(false));
        assert!(result.next_cursor.is_none());
    }

    #[test]
    fn test_brute_force_search_single_vector() {
        let query = vec![1.0, 0.0, 0.0];
        let vectors = vec![make_vector("v1", vec![1.0, 0.0, 0.0])];

        let result = brute_force_search(
            &query,
            &vectors,
            5,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );
        assert_eq!(result.results.len(), 1);
        assert_eq!(result.results[0].id, "v1");
        assert!((result.results[0].score - 1.0).abs() < 1e-6);
        assert_eq!(result.has_more, Some(false));
    }

    #[test]
    fn test_brute_force_search_ordering() {
        let query = vec![1.0, 0.0, 0.0];
        let vectors = vec![
            make_vector("v1", vec![1.0, 0.0, 0.0]),     // same direction as the query
            make_vector("v2", vec![0.0, 1.0, 0.0]),     // orthogonal to the query
            make_vector("v3", vec![0.707, 0.707, 0.0]), // roughly 45 degrees off
        ];

        let result = brute_force_search(
            &query,
            &vectors,
            3,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );

        assert_eq!(result.results.len(), 3);
        assert_eq!(result.results[0].id, "v1");
        assert_eq!(result.results[1].id, "v3");
        assert_eq!(result.results[2].id, "v2");
    }

    #[test]
    fn test_brute_force_search_top_k() {
        let query = vec![1.0, 0.0];
        let vectors = vec![
            make_vector("v1", vec![1.0, 0.0]),
            make_vector("v2", vec![0.9, 0.1]),
            make_vector("v3", vec![0.8, 0.2]),
            make_vector("v4", vec![0.7, 0.3]),
            make_vector("v5", vec![0.6, 0.4]),
        ];

        let result = brute_force_search(
            &query,
            &vectors,
            3,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );

        assert_eq!(result.results.len(), 3);
        assert_eq!(result.results[0].id, "v1");
        assert_eq!(result.results[1].id, "v2");
        assert_eq!(result.results[2].id, "v3");
        // Two vectors were left out, so another page must be advertised.
        assert_eq!(result.has_more, Some(true));
        assert!(result.next_cursor.is_some());
    }

    #[test]
    fn test_top_k_clamping() {
        let query = vec![1.0, 0.0];
        let vectors = vec![
            make_vector("v1", vec![1.0, 0.0]),
            make_vector("v2", vec![0.9, 0.1]),
            make_vector("v3", vec![0.8, 0.2]),
        ];

        // A top_k of 0 is treated as 1.
        let result = brute_force_search(
            &query,
            &vectors,
            0,
            DistanceMetric::Cosine,
            false,
            false,
            None,
        );
        assert_eq!(result.results.len(), 1);
        assert_eq!(result.results[0].id, "v1");
        assert_eq!(result.has_more, Some(true));
        assert!(result.next_cursor.is_some());

        // An oversized top_k is clamped rather than overflowing, and everything fits
        // on one page.
        let result = brute_force_search(
            &query,
            &vectors,
            usize::MAX,
            DistanceMetric::Cosine,
            false,
            false,
            None,
        );
        assert_eq!(result.results.len(), 3);
        assert_eq!(result.has_more, Some(false));
        assert!(result.next_cursor.is_none());
    }

    #[test]
    fn test_brute_force_search_include_metadata() {
        let query = vec![1.0, 0.0];
        let vectors = vec![make_vector_with_metadata(
            "v1",
            vec![1.0, 0.0],
            json!({"key": "value"}),
        )];

        let result = brute_force_search(
            &query,
            &vectors,
            1,
            DistanceMetric::Cosine,
            true,
            false,
            None,
        );
        assert!(result.results[0].metadata.is_some());

        let result = brute_force_search(
            &query,
            &vectors,
            1,
            DistanceMetric::Cosine,
            false,
            false,
            None,
        );
        assert!(result.results[0].metadata.is_none());
    }

    #[test]
    fn test_brute_force_search_include_vectors() {
        let query = vec![1.0, 0.0];
        let vectors = vec![make_vector("v1", vec![1.0, 0.0])];

        let result = brute_force_search(
            &query,
            &vectors,
            1,
            DistanceMetric::Cosine,
            false,
            true,
            None,
        );
        assert!(result.results[0].vector.is_some());
        assert_eq!(result.results[0].vector.as_ref().unwrap(), &vec![1.0, 0.0]);

        let result = brute_force_search(
            &query,
            &vectors,
            1,
            DistanceMetric::Cosine,
            false,
            false,
            None,
        );
        assert!(result.results[0].vector.is_none());
    }

    #[test]
    fn test_brute_force_search_euclidean() {
        let query = vec![0.0, 0.0];
        let vectors = vec![
            make_vector("v1", vec![1.0, 0.0]), // distance 1.0 from the query
            make_vector("v2", vec![3.0, 4.0]), // distance 5.0 from the query
            make_vector("v3", vec![0.5, 0.0]), // distance 0.5 from the query
        ];

        let result = brute_force_search(
            &query,
            &vectors,
            3,
            DistanceMetric::Euclidean,
            false,
            false,
            None,
        );

        // Closest vectors rank first.
        assert_eq!(result.results.len(), 3);
        assert_eq!(result.results[0].id, "v3");
        assert_eq!(result.results[1].id, "v1");
        assert_eq!(result.results[2].id, "v2");
    }

    #[test]
    fn test_pagination_basic() {
        let query = vec![1.0, 0.0];
        let vectors = vec![
            make_vector("v1", vec![1.0, 0.0]),
            make_vector("v2", vec![0.9, 0.1]),
            make_vector("v3", vec![0.8, 0.2]),
            make_vector("v4", vec![0.7, 0.3]),
            make_vector("v5", vec![0.6, 0.4]),
        ];

        // Page 1
        let result1 = brute_force_search(
            &query,
            &vectors,
            2,
            DistanceMetric::Cosine,
            false,
            false,
            None,
        );
        assert_eq!(result1.results.len(), 2);
        assert_eq!(result1.results[0].id, "v1");
        assert_eq!(result1.results[1].id, "v2");
        assert_eq!(result1.has_more, Some(true));
        assert!(result1.next_cursor.is_some());

        // Page 2 resumes strictly after page 1's last result.
        let cursor = PaginationCursor::decode(result1.next_cursor.as_ref().unwrap()).unwrap();
        let result2 = brute_force_search(
            &query,
            &vectors,
            2,
            DistanceMetric::Cosine,
            false,
            false,
            Some(&cursor),
        );
        assert_eq!(result2.results.len(), 2);
        assert_eq!(result2.results[0].id, "v3");
        assert_eq!(result2.results[1].id, "v4");
        assert_eq!(result2.has_more, Some(true));

        // Page 3 holds the single remaining result and ends pagination.
        let cursor2 = PaginationCursor::decode(result2.next_cursor.as_ref().unwrap()).unwrap();
        let result3 = brute_force_search(
            &query,
            &vectors,
            2,
            DistanceMetric::Cosine,
            false,
            false,
            Some(&cursor2),
        );
        assert_eq!(result3.results.len(), 1);
        assert_eq!(result3.results[0].id, "v5");
        assert_eq!(result3.has_more, Some(false));
        assert!(result3.next_cursor.is_none());
    }

    #[test]
    fn test_pagination_cursor_encode_decode() {
        let cursor = PaginationCursor::new(0.95, "test_id".to_string());
        let encoded = cursor.encode();
        let decoded = PaginationCursor::decode(&encoded).unwrap();

        assert!((decoded.last_score - 0.95).abs() < 1e-6);
        assert_eq!(decoded.last_id, "test_id");
    }

    #[test]
    fn test_pagination_with_tie_scores() {
        let query = vec![1.0, 0.0];
        // All four vectors are identical, so they score the same and the ids decide
        // the order.
        let vectors = vec![
            make_vector("a", vec![1.0, 0.0]),
            make_vector("b", vec![1.0, 0.0]),
            make_vector("c", vec![1.0, 0.0]),
            make_vector("d", vec![1.0, 0.0]),
        ];

        let result1 = brute_force_search(
            &query,
            &vectors,
            2,
            DistanceMetric::Cosine,
            false,
            false,
            None,
        );
        assert_eq!(result1.results.len(), 2);
        assert_eq!(result1.results[0].id, "a");
        assert_eq!(result1.results[1].id, "b");
        assert_eq!(result1.has_more, Some(true));

        // The cursor's id, not just its score, must separate the pages.
        let cursor = PaginationCursor::decode(result1.next_cursor.as_ref().unwrap()).unwrap();
        let result2 = brute_force_search(
            &query,
            &vectors,
            2,
            DistanceMetric::Cosine,
            false,
            false,
            Some(&cursor),
        );
        assert_eq!(result2.results.len(), 2);
        assert_eq!(result2.results[0].id, "c");
        assert_eq!(result2.results[1].id, "d");
        assert_eq!(result2.has_more, Some(false));
    }

    #[test]
    fn test_pagination_no_more_results() {
        let query = vec![1.0, 0.0];
        let vectors = vec![
            make_vector("v1", vec![1.0, 0.0]),
            make_vector("v2", vec![0.9, 0.1]),
        ];

        let result = brute_force_search(
            &query,
            &vectors,
            5,
            DistanceMetric::Cosine,
            false,
            false,
            None,
        );
        assert_eq!(result.results.len(), 2);
        assert_eq!(result.has_more, Some(false));
        assert!(result.next_cursor.is_none());
    }
}