1use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12
13use common::{DistanceMetric, VectorId};
14
/// Options for post-processing search candidates (MMR re-ranking and grouping).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedSearchConfig {
    /// When `true`, `AdvancedSearchExecutor::process_results` re-ranks with MMR
    /// instead of sorting by raw score. Default: `false`.
    pub enable_mmr: bool,
    /// MMR relevance/diversity trade-off passed to `MmrReranker::new`, which clamps
    /// it into `[0.0, 1.0]`; higher values weight relevance more. Default: `0.5`.
    pub mmr_lambda: f32,
    /// Default: `100`.
    /// NOTE(review): not read by `AdvancedSearchExecutor` in this file — presumably
    /// the candidate-pool size a caller fetches before MMR; confirm with callers.
    pub mmr_candidates: usize,
    /// Default: `false`.
    /// NOTE(review): not read by `AdvancedSearchExecutor` in this file; grouping is
    /// done separately via `ResultGrouper` — confirm where this flag is consumed.
    pub enable_grouping: bool,
    /// Metadata field to group by. Default: `None`.
    /// NOTE(review): not read in this file; `ResultGrouper::new` takes its own field.
    pub group_by_field: Option<String>,
    /// Maximum results kept per group. Default: `3`.
    /// NOTE(review): not read in this file; `ResultGrouper::new` takes its own limit.
    pub max_per_group: usize,
}

impl Default for AdvancedSearchConfig {
    /// MMR and grouping disabled; lambda 0.5, 100 MMR candidates, 3 per group.
    fn default() -> Self {
        Self {
            enable_mmr: false,
            mmr_lambda: 0.5,
            mmr_candidates: 100,
            enable_grouping: false,
            group_by_field: None,
            max_per_group: 3,
        }
    }
}
44
/// A query built from weighted "positive" vectors, optionally pushed away from
/// weighted "negative" vectors.
#[derive(Debug, Clone)]
pub struct MultiVectorQuery {
    /// Vectors the results should resemble.
    pub positive_vectors: Vec<Vec<f32>>,
    /// Weight of each positive vector, paired with `positive_vectors` by index.
    pub positive_weights: Vec<f32>,
    /// Vectors whose weighted contribution is subtracted from the query.
    pub negative_vectors: Vec<Vec<f32>>,
    /// Weight of each negative vector, paired with `negative_vectors` by index.
    pub negative_weights: Vec<f32>,
    /// Maximum number of results wanted.
    pub top_k: usize,
    /// Optional minimum score; `AdvancedSearchExecutor` drops candidates below it.
    pub distance_threshold: Option<f32>,
}

impl MultiVectorQuery {
    /// Query consisting of one vector with weight `1.0`.
    pub fn single(vector: Vec<f32>, top_k: usize) -> Self {
        // A one-vector `multi` query gets the weight 1.0 / 1 == 1.0.
        Self::multi(vec![vector], top_k)
    }

    /// Query that averages `vectors`: each gets weight `1 / vectors.len()`.
    pub fn multi(vectors: Vec<Vec<f32>>, top_k: usize) -> Self {
        let count = vectors.len();
        let share = 1.0 / count as f32;
        Self {
            positive_weights: vec![share; count],
            positive_vectors: vectors,
            negative_vectors: Vec::new(),
            negative_weights: Vec::new(),
            top_k,
            distance_threshold: None,
        }
    }

    /// Adds a negative vector whose `weight`-scaled value is subtracted.
    pub fn with_negative(mut self, vector: Vec<f32>, weight: f32) -> Self {
        self.negative_weights.push(weight);
        self.negative_vectors.push(vector);
        self
    }

    /// Sets the minimum score a result must reach.
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self {
            distance_threshold: Some(threshold),
            ..self
        }
    }

    /// Replaces the positive weights. Extra weights or vectors without a partner
    /// are ignored by `compute_query_vector`.
    pub fn with_weights(self, weights: Vec<f32>) -> Self {
        Self {
            positive_weights: weights,
            ..self
        }
    }

    /// Combines the query into one unit-length vector of `dimensions` entries.
    ///
    /// Sums `weight * vector` over positives, subtracts `weight * vector` over
    /// negatives, then L2-normalises (an all-zero result is returned as is).
    /// Inputs longer than `dimensions` are truncated; shorter ones leave the
    /// remaining entries untouched.
    pub fn compute_query_vector(&self, dimensions: usize) -> Vec<f32> {
        let mut combined = vec![0.0_f32; dimensions];

        let positives = self.positive_vectors.iter().zip(&self.positive_weights);
        for (vector, &weight) in positives {
            for (slot, &component) in combined.iter_mut().zip(vector) {
                *slot += component * weight;
            }
        }

        let negatives = self.negative_vectors.iter().zip(&self.negative_weights);
        for (vector, &weight) in negatives {
            for (slot, &component) in combined.iter_mut().zip(vector) {
                *slot -= component * weight;
            }
        }

        let magnitude = combined.iter().map(|c| c * c).sum::<f32>().sqrt();
        if magnitude > 0.0 {
            combined.iter_mut().for_each(|c| *c /= magnitude);
        }
        combined
    }
}
140
/// One post-processed search hit, as produced by `MmrReranker::rerank` and
/// `AdvancedSearchExecutor::process_results`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedSearchResult {
    /// Identifier of the matched vector.
    pub id: VectorId,
    /// The candidate's relevance score as supplied (never replaced by the MMR value).
    pub score: f32,
    /// MMR path: the candidate's index in the slice handed to the re-ranker.
    /// Plain path: its position after sorting, i.e. equal to `final_rank`.
    pub original_rank: usize,
    /// Zero-based position in the returned list.
    pub final_rank: usize,
    /// MMR objective value at the time of selection (the first MMR pick stores its
    /// raw score); `None` on the plain score-sort path.
    pub mmr_score: Option<f32>,
    /// Group label; every code path in this file sets it to `None`.
    pub group_key: Option<String>,
}
157
158pub struct MmrReranker {
160 lambda: f32,
162}
163
164impl MmrReranker {
165 pub fn new(lambda: f32) -> Self {
167 Self {
168 lambda: lambda.clamp(0.0, 1.0),
169 }
170 }
171
172 pub fn rerank(
177 &self,
178 candidates: &[(VectorId, f32, Vec<f32>)], top_k: usize,
180 ) -> Vec<AdvancedSearchResult> {
181 if candidates.is_empty() {
182 return Vec::new();
183 }
184
185 let mut selected: Vec<usize> = Vec::with_capacity(top_k);
186 let mut remaining: HashSet<usize> = (0..candidates.len()).collect();
187 let mut results = Vec::with_capacity(top_k);
188
189 let first_idx = candidates
191 .iter()
192 .enumerate()
193 .max_by(|a, b| {
194 a.1 .1
195 .partial_cmp(&b.1 .1)
196 .unwrap_or(std::cmp::Ordering::Equal)
197 })
198 .map(|(i, _)| i)
199 .unwrap_or(0);
200
201 selected.push(first_idx);
202 remaining.remove(&first_idx);
203 results.push(AdvancedSearchResult {
204 id: candidates[first_idx].0.clone(),
205 score: candidates[first_idx].1,
206 original_rank: first_idx,
207 final_rank: 0,
208 mmr_score: Some(candidates[first_idx].1),
209 group_key: None,
210 });
211
212 while results.len() < top_k && !remaining.is_empty() {
214 let mut best_idx = None;
215 let mut best_mmr = f32::NEG_INFINITY;
216
217 for &idx in &remaining {
218 let relevance = candidates[idx].1;
219
220 let max_sim = selected
222 .iter()
223 .map(|&sel_idx| {
224 self.cosine_similarity(&candidates[idx].2, &candidates[sel_idx].2)
225 })
226 .fold(f32::NEG_INFINITY, f32::max);
227
228 let mmr = self.lambda * relevance - (1.0 - self.lambda) * max_sim;
230
231 if mmr > best_mmr {
232 best_mmr = mmr;
233 best_idx = Some(idx);
234 }
235 }
236
237 if let Some(idx) = best_idx {
238 selected.push(idx);
239 remaining.remove(&idx);
240 results.push(AdvancedSearchResult {
241 id: candidates[idx].0.clone(),
242 score: candidates[idx].1,
243 original_rank: idx,
244 final_rank: results.len(),
245 mmr_score: Some(best_mmr),
246 group_key: None,
247 });
248 } else {
249 break;
250 }
251 }
252
253 results
254 }
255
256 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
258 let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
259 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
260 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
261
262 if norm_a > 0.0 && norm_b > 0.0 {
263 dot / (norm_a * norm_b)
264 } else {
265 0.0
266 }
267 }
268}
269
270pub struct RangeQuery {
272 metric: DistanceMetric,
274 threshold: f32,
276}
277
278impl RangeQuery {
279 pub fn new(metric: DistanceMetric, threshold: f32) -> Self {
281 Self { metric, threshold }
282 }
283
284 pub fn filter(&self, results: Vec<(VectorId, f32)>) -> Vec<(VectorId, f32)> {
286 results
287 .into_iter()
288 .filter(|(_, score)| self.passes_threshold(*score))
289 .collect()
290 }
291
292 fn passes_threshold(&self, score: f32) -> bool {
294 match self.metric {
295 DistanceMetric::Cosine | DistanceMetric::DotProduct => score >= self.threshold,
297 DistanceMetric::Euclidean => score >= -self.threshold,
299 }
300 }
301}
302
303pub struct ResultGrouper {
305 group_field: String,
307 max_per_group: usize,
309}
310
311impl ResultGrouper {
312 pub fn new(group_field: String, max_per_group: usize) -> Self {
314 Self {
315 group_field,
316 max_per_group,
317 }
318 }
319
320 pub fn group(
322 &self,
323 results: Vec<(VectorId, f32, Option<serde_json::Value>)>,
324 ) -> HashMap<String, Vec<(VectorId, f32)>> {
325 let mut groups: HashMap<String, Vec<(VectorId, f32)>> = HashMap::new();
326
327 for (id, score, metadata) in results {
328 let group_key = metadata
329 .and_then(|m| m.get(&self.group_field).cloned())
330 .and_then(|v| match v {
331 serde_json::Value::String(s) => Some(s),
332 serde_json::Value::Number(n) => Some(n.to_string()),
333 _ => None,
334 })
335 .unwrap_or_else(|| "_ungrouped".to_string());
336
337 let group = groups.entry(group_key).or_default();
338 if group.len() < self.max_per_group {
339 group.push((id, score));
340 }
341 }
342
343 groups
344 }
345}
346
347pub struct AdvancedSearchExecutor {
349 config: AdvancedSearchConfig,
350}
351
352impl AdvancedSearchExecutor {
353 pub fn new(config: AdvancedSearchConfig) -> Self {
355 Self { config }
356 }
357
358 pub fn process_results(
360 &self,
361 candidates: Vec<(VectorId, f32, Vec<f32>)>,
362 query: &MultiVectorQuery,
363 ) -> Vec<AdvancedSearchResult> {
364 let mut results = candidates;
365
366 if let Some(threshold) = query.distance_threshold {
368 results.retain(|(_, score, _)| *score >= threshold);
369 }
370
371 if self.config.enable_mmr {
373 let reranker = MmrReranker::new(self.config.mmr_lambda);
374 return reranker.rerank(&results, query.top_k);
375 }
376
377 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
379 results.truncate(query.top_k);
380
381 results
382 .into_iter()
383 .enumerate()
384 .map(|(rank, (id, score, _))| AdvancedSearchResult {
385 id,
386 score,
387 original_rank: rank,
388 final_rank: rank,
389 mmr_score: None,
390 group_key: None,
391 })
392 .collect()
393 }
394
395 pub fn apply_negative_penalty(
397 &self,
398 results: &mut [(VectorId, f32, Vec<f32>)],
399 negative_vectors: &[Vec<f32>],
400 negative_weights: &[f32],
401 ) {
402 for (_, score, vec) in results.iter_mut() {
403 for (neg_vec, &weight) in negative_vectors.iter().zip(negative_weights) {
404 let sim = self.cosine_similarity(vec, neg_vec);
406 *score -= sim * weight;
408 }
409 }
410 }
411
412 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
413 let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
414 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
415 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
416
417 if norm_a > 0.0 && norm_b > 0.0 {
418 dot / (norm_a * norm_b)
419 } else {
420 0.0
421 }
422 }
423}
424
/// Counters describing one advanced-search run.
///
/// NOTE(review): nothing in this file populates these fields; the meanings below
/// are inferred from the names and should be confirmed against the producer.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SearchStats {
    /// Presumably the number of raw candidates fed into post-processing.
    pub candidates_considered: usize,
    /// Presumably the number of candidates left after the score threshold.
    pub after_threshold: usize,
    /// Presumably the number of results produced by MMR re-ranking.
    pub after_mmr: usize,
    /// Presumably the number of distinct groups produced by `ResultGrouper`.
    pub num_groups: usize,
    /// Elapsed time, presumably in milliseconds (per the `_ms` suffix) — confirm.
    pub latency_ms: u64,
}
439
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_vector_query() {
        let query = MultiVectorQuery::multi(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]], 10);

        assert_eq!(query.positive_vectors.len(), 2);
        assert_eq!(query.positive_weights.len(), 2);
        assert_eq!(query.positive_weights[0], 0.5);
    }

    #[test]
    fn test_query_vector_computation() {
        let query = MultiVectorQuery::single(vec![1.0, 0.0, 0.0], 10);
        let computed = query.compute_query_vector(3);

        assert_eq!(computed.len(), 3);
        assert!((computed[0] - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_negative_vector() {
        let query = MultiVectorQuery::single(vec![1.0, 0.0, 0.0], 10)
            .with_negative(vec![0.0, 1.0, 0.0], 0.5);

        assert_eq!(query.negative_vectors.len(), 1);
        assert_eq!(query.negative_weights[0], 0.5);
    }

    #[test]
    fn test_mmr_reranker() {
        let reranker = MmrReranker::new(0.5);

        // "b" is a near-duplicate of "a"; "c" and "d" are orthogonal to "a" and
        // to each other.
        let candidates = vec![
            ("a".to_string(), 0.9, vec![1.0, 0.0, 0.0]),
            ("b".to_string(), 0.85, vec![0.95, 0.1, 0.0]),
            ("c".to_string(), 0.8, vec![0.0, 1.0, 0.0]),
            ("d".to_string(), 0.75, vec![0.0, 0.0, 1.0]),
        ];

        let results = reranker.rerank(&candidates, 3);

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, "a");
        // With lambda = 0.5 the near-duplicate "b" scores below the diverse
        // candidates, so MMR must skip it.
        assert_eq!(results[1].id, "c");
        assert_eq!(results[2].id, "d");
        assert_eq!(results[1].final_rank, 1);
        assert_eq!(results[1].original_rank, 2);
    }

    #[test]
    fn test_range_query() {
        let range = RangeQuery::new(DistanceMetric::Cosine, 0.8);

        let results = vec![
            ("a".to_string(), 0.95),
            ("b".to_string(), 0.75),
            ("c".to_string(), 0.85),
        ];

        let filtered = range.filter(results);
        assert_eq!(filtered.len(), 2);
        assert!(filtered.iter().all(|(_, s)| *s >= 0.8));
    }

    #[test]
    fn test_result_grouper() {
        let grouper = ResultGrouper::new("category".to_string(), 2);

        let results = vec![
            (
                "a".to_string(),
                0.9,
                Some(serde_json::json!({"category": "tech"})),
            ),
            (
                "b".to_string(),
                0.85,
                Some(serde_json::json!({"category": "tech"})),
            ),
            (
                "c".to_string(),
                0.8,
                Some(serde_json::json!({"category": "tech"})),
            ),
            (
                "d".to_string(),
                0.75,
                Some(serde_json::json!({"category": "science"})),
            ),
        ];

        let groups = grouper.group(results);

        assert_eq!(groups.len(), 2);
        assert_eq!(groups["tech"].len(), 2);
        // The cap keeps the first entries in input order; "c" is dropped.
        assert_eq!(groups["tech"][0].0, "a");
        assert_eq!(groups["tech"][1].0, "b");
        assert_eq!(groups["science"].len(), 1);
    }

    #[test]
    fn test_advanced_search_executor() {
        let config = AdvancedSearchConfig {
            enable_mmr: false,
            ..Default::default()
        };
        let executor = AdvancedSearchExecutor::new(config);

        let candidates = vec![
            ("a".to_string(), 0.9, vec![1.0, 0.0]),
            ("b".to_string(), 0.8, vec![0.0, 1.0]),
            ("c".to_string(), 0.7, vec![0.5, 0.5]),
        ];

        let query = MultiVectorQuery::single(vec![1.0, 0.0], 2);
        let results = executor.process_results(candidates, &query);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert_eq!(results[1].id, "b");
        assert_eq!(results[1].final_rank, 1);
        assert!(results.iter().all(|r| r.mmr_score.is_none()));
    }

    #[test]
    fn test_threshold_filtering() {
        let config = AdvancedSearchConfig::default();
        let executor = AdvancedSearchExecutor::new(config);

        let candidates = vec![
            ("a".to_string(), 0.9, vec![1.0, 0.0]),
            ("b".to_string(), 0.5, vec![0.0, 1.0]),
            ("c".to_string(), 0.85, vec![0.5, 0.5]),
        ];

        let query = MultiVectorQuery::single(vec![1.0, 0.0], 10).with_threshold(0.7);
        let results = executor.process_results(candidates, &query);

        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.score >= 0.7));
    }
}