1use super::filter::Filter;
7use super::sort::QueryLimits;
8use crate::storage::engine::distance::DistanceMetric;
9use crate::storage::engine::vector_store::{SearchResult, VectorCollection, VectorId};
10use crate::storage::schema::Value;
11use std::collections::HashMap;
12
/// An owned, dense vector of `f32` components.
///
/// Used both as the probe in a [`SimilarityQuery`] and as the value returned by
/// `VectorIndex::get`.
#[derive(Debug, Clone)]
pub struct DenseVector {
    values: Vec<f32>,
}

impl DenseVector {
    /// Wraps `values` without copying them; identical to `DenseVector::from(values)`.
    pub fn new(values: Vec<f32>) -> Self {
        Self::from(values)
    }

    /// Borrows the components as a slice.
    pub fn as_slice(&self) -> &[f32] {
        self.values.as_slice()
    }
}

impl From<Vec<f32>> for DenseVector {
    /// Takes ownership of `values` as the vector's components.
    fn from(values: Vec<f32>) -> Self {
        DenseVector { values }
    }
}
34
35#[derive(Debug, Clone)]
37pub struct SimilarityQuery {
38 pub vector: DenseVector,
40 pub k: usize,
42 pub distance: DistanceMetric,
44 pub filter: Option<Filter>,
46 pub n_probes: Option<usize>,
48 pub distance_threshold: Option<f32>,
50}
51
52impl SimilarityQuery {
53 pub fn new(vector: DenseVector, k: usize) -> Self {
55 Self {
56 vector,
57 k,
58 distance: DistanceMetric::Cosine,
59 filter: None,
60 n_probes: None,
61 distance_threshold: None,
62 }
63 }
64
65 pub fn with_distance(mut self, distance: DistanceMetric) -> Self {
67 self.distance = distance;
68 self
69 }
70
71 pub fn with_filter(mut self, filter: Filter) -> Self {
73 self.filter = Some(filter);
74 self
75 }
76
77 pub fn with_probes(mut self, n_probes: usize) -> Self {
79 self.n_probes = Some(n_probes);
80 self
81 }
82
83 pub fn with_threshold(mut self, threshold: f32) -> Self {
85 self.distance_threshold = Some(threshold);
86 self
87 }
88}
89
90#[derive(Debug, Clone)]
92pub struct SimilarityResult {
93 pub id: VectorId,
95 pub distance: f32,
97 pub score: f32,
99 pub metadata: Option<HashMap<String, Value>>,
101}
102
103impl SimilarityResult {
104 pub fn new(id: VectorId, distance: f32) -> Self {
106 Self {
107 id,
108 distance,
109 score: 1.0 / (1.0 + distance), metadata: None,
111 }
112 }
113
114 pub fn with_metric(id: VectorId, distance: f32, metric: DistanceMetric) -> Self {
116 let score = match metric {
117 DistanceMetric::Cosine => 1.0 - distance, DistanceMetric::InnerProduct => -distance, DistanceMetric::L2 => 1.0 / (1.0 + distance),
120 };
121
122 Self {
123 id,
124 distance,
125 score: score.max(0.0),
126 metadata: None,
127 }
128 }
129
130 pub fn with_metadata(mut self, metadata: HashMap<String, Value>) -> Self {
132 self.metadata = Some(metadata);
133 self
134 }
135}
136
137#[derive(Debug, Clone)]
139pub struct SimilarityResultSet {
140 pub results: Vec<SimilarityResult>,
142 pub dimension: usize,
144 pub distance: DistanceMetric,
146 pub vectors_searched: Option<usize>,
148 pub search_time_us: u64,
150}
151
152impl SimilarityResultSet {
153 pub fn empty(dimension: usize, distance: DistanceMetric) -> Self {
155 Self {
156 results: Vec::new(),
157 dimension,
158 distance,
159 vectors_searched: None,
160 search_time_us: 0,
161 }
162 }
163
164 pub fn from_results(
166 results: Vec<SearchResult>,
167 dimension: usize,
168 distance: DistanceMetric,
169 ) -> Self {
170 let similarity_results = results
171 .into_iter()
172 .map(|r| SimilarityResult::with_metric(r.id, r.distance, distance))
173 .collect();
174
175 Self {
176 results: similarity_results,
177 dimension,
178 distance,
179 vectors_searched: None,
180 search_time_us: 0,
181 }
182 }
183
184 pub fn len(&self) -> usize {
186 self.results.len()
187 }
188
189 pub fn is_empty(&self) -> bool {
191 self.results.is_empty()
192 }
193
194 pub fn top_ids(&self, k: usize) -> Vec<VectorId> {
196 self.results.iter().take(k).map(|r| r.id).collect()
197 }
198
199 pub fn above_score(&self, threshold: f32) -> Vec<&SimilarityResult> {
201 self.results
202 .iter()
203 .filter(|r| r.score >= threshold)
204 .collect()
205 }
206
207 pub fn apply_limits(mut self, limits: QueryLimits) -> Self {
209 self.results = limits.apply(self.results);
210 self
211 }
212}
213
214pub trait VectorIndex: Send + Sync {
216 fn search(&self, query: &DenseVector, k: usize) -> Vec<SearchResult>;
218
219 fn search_with_params(
221 &self,
222 query: &DenseVector,
223 k: usize,
224 n_probes: Option<usize>,
225 ) -> Vec<SearchResult>;
226
227 fn get(&self, id: VectorId) -> Option<DenseVector>;
229
230 fn dimension(&self) -> usize;
232
233 fn distance_metric(&self) -> DistanceMetric;
235
236 fn len(&self) -> usize;
238
239 fn is_empty(&self) -> bool {
241 self.len() == 0
242 }
243}
244
245impl VectorIndex for VectorCollection {
246 fn search(&self, query: &DenseVector, k: usize) -> Vec<SearchResult> {
247 VectorCollection::search(self, query.as_slice(), k)
248 }
249
250 fn search_with_params(
251 &self,
252 query: &DenseVector,
253 k: usize,
254 _n_probes: Option<usize>,
255 ) -> Vec<SearchResult> {
256 VectorCollection::search(self, query.as_slice(), k)
257 }
258
259 fn get(&self, id: VectorId) -> Option<DenseVector> {
260 VectorCollection::get(self, id).map(|vec| DenseVector::new(vec.clone()))
261 }
262
263 fn dimension(&self) -> usize {
264 self.dimension
265 }
266
267 fn distance_metric(&self) -> DistanceMetric {
268 self.metric
269 }
270
271 fn len(&self) -> usize {
272 self.len()
273 }
274}
275
276pub fn execute_similarity_search(
278 index: &dyn VectorIndex,
279 query: &SimilarityQuery,
280) -> SimilarityResultSet {
281 let start = std::time::Instant::now();
282
283 let results = if let Some(threshold) = query.distance_threshold {
285 let candidates = index.search_with_params(&query.vector, query.k * 10, query.n_probes);
287 candidates
288 .into_iter()
289 .filter(|r| r.distance <= threshold)
290 .take(query.k)
291 .collect()
292 } else {
293 index.search_with_params(&query.vector, query.k, query.n_probes)
294 };
295
296 let search_time = start.elapsed().as_micros() as u64;
297
298 let mut result_set =
299 SimilarityResultSet::from_results(results, index.dimension(), index.distance_metric());
300 result_set.search_time_us = search_time;
301 result_set.vectors_searched = Some(index.len());
302
303 result_set
304}
305
306pub fn execute_hybrid_search<F>(
308 index: &dyn VectorIndex,
309 query: &SimilarityQuery,
310 get_metadata: F,
311 filter_matches: impl Fn(VectorId, &Filter) -> bool,
312) -> SimilarityResultSet
313where
314 F: Fn(VectorId) -> Option<HashMap<String, Value>>,
315{
316 let start = std::time::Instant::now();
317
318 let over_fetch = if query.filter.is_some() { 10 } else { 1 };
320 let candidates = index.search_with_params(&query.vector, query.k * over_fetch, query.n_probes);
321
322 let results: Vec<SimilarityResult> = candidates
324 .into_iter()
325 .filter(|r| {
326 if let Some(filter) = &query.filter {
327 filter_matches(r.id, filter)
328 } else {
329 true
330 }
331 })
332 .take(query.k)
333 .map(|r| {
334 let mut result =
335 SimilarityResult::with_metric(r.id, r.distance, index.distance_metric());
336 if let Some(meta) = get_metadata(r.id) {
337 result = result.with_metadata(meta);
338 }
339 result
340 })
341 .collect();
342
343 let search_time = start.elapsed().as_micros() as u64;
344
345 SimilarityResultSet {
346 results,
347 dimension: index.dimension(),
348 distance: index.distance_metric(),
349 vectors_searched: Some(index.len()),
350 search_time_us: search_time,
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 3-d cosine collection holding five vectors.
    ///
    /// The first insert gets id 0 (asserted in `test_vector_index_trait`); ids are assumed
    /// to increase by one per insert, giving ids 0..=4 in the order below.
    fn create_test_index() -> VectorCollection {
        let mut collection = VectorCollection::new("test", 3).with_metric(DistanceMetric::Cosine);

        let _ = collection.insert(vec![1.0, 0.0, 0.0], None); // id 0
        let _ = collection.insert(vec![0.0, 1.0, 0.0], None); // id 1
        let _ = collection.insert(vec![0.0, 0.0, 1.0], None); // id 2
        let _ = collection.insert(vec![0.7, 0.7, 0.0], None); // id 3
        let _ = collection.insert(vec![0.5, 0.5, 0.7], None); // id 4

        collection
    }

    #[test]
    fn test_similarity_query_basic() {
        let index = create_test_index();

        let query = SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 3);
        let results = execute_similarity_search(&index, &query);

        assert_eq!(results.len(), 3);
        // The query equals vector 0, so it is the nearest hit at ~zero distance.
        assert_eq!(results.results[0].id, 0);
        assert!(results.results[0].distance < 0.01);
    }

    #[test]
    fn test_similarity_result_score() {
        // Cosine: score = 1 - distance.
        let result = SimilarityResult::with_metric(1, 0.0, DistanceMetric::Cosine);
        assert!((result.score - 1.0).abs() < 0.01);

        let result = SimilarityResult::with_metric(1, 1.0, DistanceMetric::Cosine);
        assert!(result.score < 0.01);

        // Negative raw scores are clamped to zero.
        let result = SimilarityResult::with_metric(1, 1.5, DistanceMetric::Cosine);
        assert_eq!(result.score, 0.0);

        // L2: score = 1 / (1 + distance).
        let result = SimilarityResult::with_metric(1, 1.0, DistanceMetric::L2);
        assert!((result.score - 0.5).abs() < 1e-6);
    }

    #[test]
    fn test_similarity_result_set_top_ids() {
        let index = create_test_index();

        let query = SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 5);
        let results = execute_similarity_search(&index, &query);

        let top3 = results.top_ids(3);
        assert_eq!(top3.len(), 3);
        assert_eq!(top3[0], 0);
    }

    #[test]
    fn test_similarity_threshold() {
        let index = create_test_index();

        let query =
            SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 10).with_threshold(0.5);

        let results = execute_similarity_search(&index, &query);

        // Vector 0 is at distance ~0, so the threshold must not remove everything.
        assert!(!results.is_empty());
        assert_eq!(results.results[0].id, 0);
        for result in &results.results {
            assert!(result.distance <= 0.5);
        }
    }

    #[test]
    fn test_vector_index_trait() {
        let index = create_test_index();

        let index_ref: &dyn VectorIndex = &index;

        assert_eq!(index_ref.dimension(), 3);
        assert_eq!(index_ref.len(), 5);
        assert!(!index_ref.is_empty());

        let vec = index_ref.get(0).unwrap();
        assert_eq!(vec.as_slice(), &[1.0, 0.0, 0.0]);
    }

    #[test]
    fn test_above_score_filter() {
        // With `SimilarityResult::new`, score = 1 / (1 + distance): ~0.91, ~0.67, ~0.33.
        let results = SimilarityResultSet {
            results: vec![
                SimilarityResult::new(1, 0.1),
                SimilarityResult::new(2, 0.5),
                SimilarityResult::new(3, 2.0),
            ],
            dimension: 3,
            distance: DistanceMetric::L2,
            vectors_searched: Some(100),
            search_time_us: 100,
        };

        let above_05 = results.above_score(0.5);
        assert_eq!(above_05.len(), 2);
    }

    #[test]
    fn test_similarity_query_builder() {
        let query = SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 10)
            .with_distance(DistanceMetric::L2)
            .with_probes(5)
            .with_threshold(1.0);

        assert_eq!(query.k, 10);
        assert_eq!(query.distance, DistanceMetric::L2);
        assert_eq!(query.n_probes, Some(5));
        assert_eq!(query.distance_threshold, Some(1.0));
    }

    #[test]
    fn test_hybrid_search_with_filter() {
        let index = create_test_index();

        // One entry per stored vector (ids 0..=4), alternating categories A, B, A, B, A.
        // (Keying these on 1..=5 would leave vector 0 without metadata and reference a
        // non-existent id 5.)
        let category_of = |id: usize| if id % 2 == 0 { "A" } else { "B" };
        let metadata: HashMap<VectorId, HashMap<String, Value>> = [
            (0, category_of(0)),
            (1, category_of(1)),
            (2, category_of(2)),
            (3, category_of(3)),
            (4, category_of(4)),
        ]
        .into_iter()
        .map(|(id, category)| {
            let meta = [("category".to_string(), Value::text(category.to_string()))]
                .into_iter()
                .collect();
            (id, meta)
        })
        .collect();

        let filter = Filter::eq("category", Value::text("A".to_string()));
        let query =
            SimilarityQuery::new(DenseVector::new(vec![1.0, 0.0, 0.0]), 5).with_filter(filter);

        let results = execute_hybrid_search(
            &index,
            &query,
            |id| metadata.get(&id).cloned(),
            |id, filter| match metadata.get(&id) {
                Some(meta) => filter.evaluate(&|col| meta.get(col).cloned()),
                None => false,
            },
        );

        // Vectors 0, 2 and 4 are in category A.
        let mut ids: Vec<VectorId> = results.results.iter().map(|r| r.id).collect();
        ids.sort_unstable();
        assert_eq!(ids, vec![0, 2, 4]);

        for result in &results.results {
            let meta = result
                .metadata
                .as_ref()
                .expect("every stored vector has metadata");
            assert_eq!(meta.get("category"), Some(&Value::text("A".to_string())));
        }
    }

    #[test]
    fn test_apply_limits() {
        let results = SimilarityResultSet {
            results: (0..10)
                .map(|i| SimilarityResult::new(i, i as f32 * 0.1))
                .collect(),
            dimension: 3,
            distance: DistanceMetric::L2,
            vectors_searched: Some(100),
            search_time_us: 100,
        };

        let limited = results.apply_limits(QueryLimits::none().offset(2).limit(3));
        assert_eq!(limited.len(), 3);
        assert_eq!(limited.results[0].id, 2);
    }
}