Skip to main content

engine/
advanced_search.rs

1//! Advanced Search Features for Dakera
2//!
3//! Provides enhanced search capabilities:
4//! - Multi-vector queries (query by multiple vectors)
5//! - Negative vectors (exclusion/avoidance)
6//! - MMR (Maximal Marginal Relevance) for diversity
7//! - Range queries (distance threshold)
8//! - Result aggregation and grouping
9
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet};
12
13use common::{DistanceMetric, VectorId};
14
15/// Configuration for advanced search operations
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct AdvancedSearchConfig {
18    /// Enable MMR diversity ranking
19    pub enable_mmr: bool,
20    /// Lambda parameter for MMR (0 = max diversity, 1 = max relevance)
21    pub mmr_lambda: f32,
22    /// Maximum candidates to consider for MMR reranking
23    pub mmr_candidates: usize,
24    /// Enable result grouping
25    pub enable_grouping: bool,
26    /// Field to group results by
27    pub group_by_field: Option<String>,
28    /// Maximum results per group
29    pub max_per_group: usize,
30}
31
32impl Default for AdvancedSearchConfig {
33    fn default() -> Self {
34        Self {
35            enable_mmr: false,
36            mmr_lambda: 0.5,
37            mmr_candidates: 100,
38            enable_grouping: false,
39            group_by_field: None,
40            max_per_group: 3,
41        }
42    }
43}
44
/// Describes a search driven by several weighted vectors at once.
///
/// Positive vectors pull the effective query towards them, negative vectors push it away.
#[derive(Debug, Clone)]
pub struct MultiVectorQuery {
    /// Vectors the search should move towards.
    pub positive_vectors: Vec<Vec<f32>>,
    /// One weight per positive vector (paired by position).
    pub positive_weights: Vec<f32>,
    /// Vectors the search should move away from.
    pub negative_vectors: Vec<Vec<f32>>,
    /// One weight per negative vector (paired by position).
    pub negative_weights: Vec<f32>,
    /// How many results the caller wants back.
    pub top_k: usize,
    /// Optional distance threshold for range-style queries.
    pub distance_threshold: Option<f32>,
}

impl MultiVectorQuery {
    /// Builds a query around exactly one positive vector, which receives weight `1.0`.
    pub fn single(vector: Vec<f32>, top_k: usize) -> Self {
        // A one-element multi query yields the weight 1.0 / 1 == 1.0.
        Self::multi(vec![vector], top_k)
    }

    /// Builds a query from several positive vectors, weighting each one equally
    /// (`1 / number of vectors`). No negative vectors or threshold are set.
    pub fn multi(vectors: Vec<Vec<f32>>, top_k: usize) -> Self {
        let count = vectors.len();
        let uniform_weight = 1.0 / count as f32;

        Self {
            positive_weights: vec![uniform_weight; count],
            positive_vectors: vectors,
            negative_vectors: Vec::new(),
            negative_weights: Vec::new(),
            top_k,
            distance_threshold: None,
        }
    }

    /// Appends a vector to avoid, together with its weight.
    pub fn with_negative(mut self, vector: Vec<f32>, weight: f32) -> Self {
        self.negative_weights.push(weight);
        self.negative_vectors.push(vector);
        self
    }

    /// Stores `threshold` as the distance threshold of this query.
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self {
            distance_threshold: Some(threshold),
            ..self
        }
    }

    /// Replaces the weights of the positive vectors with `weights`.
    pub fn with_weights(self, weights: Vec<f32>) -> Self {
        Self {
            positive_weights: weights,
            ..self
        }
    }

    /// Collapses the query into a single vector of length `dimensions`.
    ///
    /// Each positive vector is added and each negative vector subtracted, scaled by the
    /// weight at the same position. Components beyond `dimensions` are ignored and vectors
    /// shorter than `dimensions` only touch their leading components. The sum is scaled to
    /// unit length unless its norm is not greater than zero.
    pub fn compute_query_vector(&self, dimensions: usize) -> Vec<f32> {
        let mut combined = vec![0.0; dimensions];

        // Pull towards the positive vectors. Zipping with `combined` stops at `dimensions`.
        let positives = self
            .positive_vectors
            .iter()
            .zip(self.positive_weights.iter().copied());
        for (vector, weight) in positives {
            for (slot, &component) in combined.iter_mut().zip(vector) {
                *slot += component * weight;
            }
        }

        // Push away from the negative vectors.
        let negatives = self
            .negative_vectors
            .iter()
            .zip(self.negative_weights.iter().copied());
        for (vector, weight) in negatives {
            for (slot, &component) in combined.iter_mut().zip(vector) {
                *slot -= component * weight;
            }
        }

        // Scale to unit length; a zero-length sum is left untouched to avoid dividing by zero.
        let length = combined
            .iter()
            .fold(0.0_f32, |acc, component| acc + component * component)
            .sqrt();
        if length > 0.0 {
            combined
                .iter_mut()
                .for_each(|component| *component /= length);
        }

        combined
    }
}
140
141/// Result from advanced search
142#[derive(Debug, Clone, Serialize, Deserialize)]
143pub struct AdvancedSearchResult {
144    /// Vector ID
145    pub id: VectorId,
146    /// Similarity/distance score
147    pub score: f32,
148    /// Original rank before reranking
149    pub original_rank: usize,
150    /// Final rank after reranking
151    pub final_rank: usize,
152    /// MMR score (if enabled)
153    pub mmr_score: Option<f32>,
154    /// Group key (if grouping enabled)
155    pub group_key: Option<String>,
156}
157
158/// MMR (Maximal Marginal Relevance) reranker for diversity
159pub struct MmrReranker {
160    /// Lambda parameter (0 = max diversity, 1 = max relevance)
161    lambda: f32,
162}
163
164impl MmrReranker {
165    /// Create a new MMR reranker
166    pub fn new(lambda: f32) -> Self {
167        Self {
168            lambda: lambda.clamp(0.0, 1.0),
169        }
170    }
171
172    /// Rerank results using MMR for diversity
173    ///
174    /// MMR = λ * Sim(d, q) - (1 - λ) * max(Sim(d, d_j))
175    /// where d_j are already selected documents
176    pub fn rerank(
177        &self,
178        candidates: &[(VectorId, f32, Vec<f32>)], // (id, score, vector)
179        top_k: usize,
180    ) -> Vec<AdvancedSearchResult> {
181        if candidates.is_empty() {
182            return Vec::new();
183        }
184
185        let mut selected: Vec<usize> = Vec::with_capacity(top_k);
186        let mut remaining: HashSet<usize> = (0..candidates.len()).collect();
187        let mut results = Vec::with_capacity(top_k);
188
189        // Select first item (highest relevance)
190        let first_idx = candidates
191            .iter()
192            .enumerate()
193            .max_by(|a, b| {
194                a.1 .1
195                    .partial_cmp(&b.1 .1)
196                    .unwrap_or(std::cmp::Ordering::Equal)
197            })
198            .map(|(i, _)| i)
199            .unwrap_or(0);
200
201        selected.push(first_idx);
202        remaining.remove(&first_idx);
203        results.push(AdvancedSearchResult {
204            id: candidates[first_idx].0.clone(),
205            score: candidates[first_idx].1,
206            original_rank: first_idx,
207            final_rank: 0,
208            mmr_score: Some(candidates[first_idx].1),
209            group_key: None,
210        });
211
212        // Iteratively select remaining items
213        while results.len() < top_k && !remaining.is_empty() {
214            let mut best_idx = None;
215            let mut best_mmr = f32::NEG_INFINITY;
216
217            for &idx in &remaining {
218                let relevance = candidates[idx].1;
219
220                // Compute max similarity to already selected items
221                let max_sim = selected
222                    .iter()
223                    .map(|&sel_idx| {
224                        self.cosine_similarity(&candidates[idx].2, &candidates[sel_idx].2)
225                    })
226                    .fold(f32::NEG_INFINITY, f32::max);
227
228                // MMR score
229                let mmr = self.lambda * relevance - (1.0 - self.lambda) * max_sim;
230
231                if mmr > best_mmr {
232                    best_mmr = mmr;
233                    best_idx = Some(idx);
234                }
235            }
236
237            if let Some(idx) = best_idx {
238                selected.push(idx);
239                remaining.remove(&idx);
240                results.push(AdvancedSearchResult {
241                    id: candidates[idx].0.clone(),
242                    score: candidates[idx].1,
243                    original_rank: idx,
244                    final_rank: results.len(),
245                    mmr_score: Some(best_mmr),
246                    group_key: None,
247                });
248            } else {
249                break;
250            }
251        }
252
253        results
254    }
255
256    /// Compute cosine similarity between two vectors
257    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
258        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
259        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
260        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
261
262        if norm_a > 0.0 && norm_b > 0.0 {
263            dot / (norm_a * norm_b)
264        } else {
265            0.0
266        }
267    }
268}
269
270/// Range query executor for distance-based filtering
271pub struct RangeQuery {
272    /// Distance metric to use
273    metric: DistanceMetric,
274    /// Maximum distance threshold
275    threshold: f32,
276}
277
278impl RangeQuery {
279    /// Create a new range query
280    pub fn new(metric: DistanceMetric, threshold: f32) -> Self {
281        Self { metric, threshold }
282    }
283
284    /// Filter results by distance threshold
285    pub fn filter(&self, results: Vec<(VectorId, f32)>) -> Vec<(VectorId, f32)> {
286        results
287            .into_iter()
288            .filter(|(_, score)| self.passes_threshold(*score))
289            .collect()
290    }
291
292    /// Check if a score passes the threshold
293    fn passes_threshold(&self, score: f32) -> bool {
294        match self.metric {
295            // For cosine/dot product, higher is better
296            DistanceMetric::Cosine | DistanceMetric::DotProduct => score >= self.threshold,
297            // For euclidean, lower is better (negative euclidean, so higher is better)
298            DistanceMetric::Euclidean => score >= -self.threshold,
299        }
300    }
301}
302
303/// Result grouper for organizing results by field
304pub struct ResultGrouper {
305    /// Field to group by
306    group_field: String,
307    /// Maximum results per group
308    max_per_group: usize,
309}
310
311impl ResultGrouper {
312    /// Create a new result grouper
313    pub fn new(group_field: String, max_per_group: usize) -> Self {
314        Self {
315            group_field,
316            max_per_group,
317        }
318    }
319
320    /// Group results by field value
321    pub fn group(
322        &self,
323        results: Vec<(VectorId, f32, Option<serde_json::Value>)>,
324    ) -> HashMap<String, Vec<(VectorId, f32)>> {
325        let mut groups: HashMap<String, Vec<(VectorId, f32)>> = HashMap::new();
326
327        for (id, score, metadata) in results {
328            let group_key = metadata
329                .and_then(|m| m.get(&self.group_field).cloned())
330                .and_then(|v| match v {
331                    serde_json::Value::String(s) => Some(s),
332                    serde_json::Value::Number(n) => Some(n.to_string()),
333                    _ => None,
334                })
335                .unwrap_or_else(|| "_ungrouped".to_string());
336
337            let group = groups.entry(group_key).or_default();
338            if group.len() < self.max_per_group {
339                group.push((id, score));
340            }
341        }
342
343        groups
344    }
345}
346
347/// Advanced search executor combining all features
348pub struct AdvancedSearchExecutor {
349    config: AdvancedSearchConfig,
350}
351
352impl AdvancedSearchExecutor {
353    /// Create a new advanced search executor
354    pub fn new(config: AdvancedSearchConfig) -> Self {
355        Self { config }
356    }
357
358    /// Process search results with advanced features
359    pub fn process_results(
360        &self,
361        candidates: Vec<(VectorId, f32, Vec<f32>)>,
362        query: &MultiVectorQuery,
363    ) -> Vec<AdvancedSearchResult> {
364        let mut results = candidates;
365
366        // Apply distance threshold if specified
367        if let Some(threshold) = query.distance_threshold {
368            results.retain(|(_, score, _)| *score >= threshold);
369        }
370
371        // Apply MMR if enabled
372        if self.config.enable_mmr {
373            let reranker = MmrReranker::new(self.config.mmr_lambda);
374            return reranker.rerank(&results, query.top_k);
375        }
376
377        // Otherwise, return top-k by score
378        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
379        results.truncate(query.top_k);
380
381        results
382            .into_iter()
383            .enumerate()
384            .map(|(rank, (id, score, _))| AdvancedSearchResult {
385                id,
386                score,
387                original_rank: rank,
388                final_rank: rank,
389                mmr_score: None,
390                group_key: None,
391            })
392            .collect()
393    }
394
395    /// Apply negative vector penalty to scores
396    pub fn apply_negative_penalty(
397        &self,
398        results: &mut [(VectorId, f32, Vec<f32>)],
399        negative_vectors: &[Vec<f32>],
400        negative_weights: &[f32],
401    ) {
402        for (_, score, vec) in results.iter_mut() {
403            for (neg_vec, &weight) in negative_vectors.iter().zip(negative_weights) {
404                // Compute similarity to negative vector
405                let sim = self.cosine_similarity(vec, neg_vec);
406                // Penalize score based on similarity to negative vector
407                *score -= sim * weight;
408            }
409        }
410    }
411
412    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
413        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
414        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
415        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
416
417        if norm_a > 0.0 && norm_b > 0.0 {
418            dot / (norm_a * norm_b)
419        } else {
420            0.0
421        }
422    }
423}
424
425/// Search statistics for monitoring
426#[derive(Debug, Clone, Default, Serialize, Deserialize)]
427pub struct SearchStats {
428    /// Total candidates considered
429    pub candidates_considered: usize,
430    /// Results after threshold filtering
431    pub after_threshold: usize,
432    /// Results after MMR reranking
433    pub after_mmr: usize,
434    /// Number of groups (if grouping enabled)
435    pub num_groups: usize,
436    /// Search latency in milliseconds
437    pub latency_ms: u64,
438}
439
#[cfg(test)]
mod tests {
    use super::*;

    /// `multi` assigns every positive vector the same weight.
    #[test]
    fn test_multi_vector_query() {
        let query = MultiVectorQuery::multi(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]], 10);

        assert_eq!(query.positive_vectors.len(), 2);
        assert_eq!(query.positive_weights.len(), 2);
        assert_eq!(query.positive_weights[0], 0.5);
    }

    /// A single unit vector is returned unchanged as the effective query vector.
    #[test]
    fn test_query_vector_computation() {
        let query = MultiVectorQuery::single(vec![1.0, 0.0, 0.0], 10);
        let computed = query.compute_query_vector(3);

        assert_eq!(computed.len(), 3);
        assert!((computed[0] - 1.0).abs() < 0.01);
    }

    /// `with_negative` records both the vector and its weight.
    #[test]
    fn test_negative_vector() {
        let query = MultiVectorQuery::single(vec![1.0, 0.0, 0.0], 10)
            .with_negative(vec![0.0, 1.0, 0.0], 0.5);

        assert_eq!(query.negative_vectors.len(), 1);
        assert_eq!(query.negative_weights[0], 0.5);
    }

    /// MMR prefers diverse candidates over a near-duplicate of the top hit.
    #[test]
    fn test_mmr_reranker() {
        let reranker = MmrReranker::new(0.5);

        // Create candidates with varying similarities
        let candidates = vec![
            ("a".to_string(), 0.9, vec![1.0, 0.0, 0.0]),
            ("b".to_string(), 0.85, vec![0.95, 0.1, 0.0]), // Similar to a
            ("c".to_string(), 0.8, vec![0.0, 1.0, 0.0]),   // Different from a
            ("d".to_string(), 0.75, vec![0.0, 0.0, 1.0]),  // Different from both
        ];

        let results = reranker.rerank(&candidates, 3);

        assert_eq!(results.len(), 3);
        // First should be highest relevance
        assert_eq!(results[0].id, "a");
        // Due to MMR, "c" and "d" rank ahead of "b": "b" is almost parallel to "a", so its
        // similarity penalty outweighs its slightly higher relevance.
        assert_eq!(results[1].id, "c");
        assert_eq!(results[2].id, "d");
        assert!(results.iter().all(|r| r.id != "b"));
    }

    /// Cosine range queries keep only scores at or above the threshold.
    #[test]
    fn test_range_query() {
        let range = RangeQuery::new(DistanceMetric::Cosine, 0.8);

        let results = vec![
            ("a".to_string(), 0.95),
            ("b".to_string(), 0.75), // Below threshold
            ("c".to_string(), 0.85),
        ];

        let filtered = range.filter(results);
        assert_eq!(filtered.len(), 2);
        assert!(filtered.iter().all(|(_, s)| *s >= 0.8));
    }

    /// Grouping caps each group at `max_per_group` entries.
    #[test]
    fn test_result_grouper() {
        let grouper = ResultGrouper::new("category".to_string(), 2);

        let results = vec![
            (
                "a".to_string(),
                0.9,
                Some(serde_json::json!({"category": "tech"})),
            ),
            (
                "b".to_string(),
                0.85,
                Some(serde_json::json!({"category": "tech"})),
            ),
            (
                "c".to_string(),
                0.8,
                Some(serde_json::json!({"category": "tech"})),
            ), // Should be excluded (max 2)
            (
                "d".to_string(),
                0.75,
                Some(serde_json::json!({"category": "science"})),
            ),
        ];

        let groups = grouper.group(results);

        assert_eq!(groups.len(), 2);
        assert_eq!(groups["tech"].len(), 2);
        assert_eq!(groups["science"].len(), 1);
    }

    /// Without MMR the executor returns the top-k candidates by score.
    #[test]
    fn test_advanced_search_executor() {
        let config = AdvancedSearchConfig {
            enable_mmr: false,
            ..Default::default()
        };
        let executor = AdvancedSearchExecutor::new(config);

        let candidates = vec![
            ("a".to_string(), 0.9, vec![1.0, 0.0]),
            ("b".to_string(), 0.8, vec![0.0, 1.0]),
            ("c".to_string(), 0.7, vec![0.5, 0.5]),
        ];

        let query = MultiVectorQuery::single(vec![1.0, 0.0], 2);
        let results = executor.process_results(candidates, &query);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, "a");
        assert_eq!(results[1].id, "b");
    }

    /// A query threshold removes candidates scoring below it.
    #[test]
    fn test_threshold_filtering() {
        let config = AdvancedSearchConfig::default();
        let executor = AdvancedSearchExecutor::new(config);

        let candidates = vec![
            ("a".to_string(), 0.9, vec![1.0, 0.0]),
            ("b".to_string(), 0.5, vec![0.0, 1.0]), // Below threshold
            ("c".to_string(), 0.85, vec![0.5, 0.5]),
        ];

        let query = MultiVectorQuery::single(vec![1.0, 0.0], 10).with_threshold(0.7);
        let results = executor.process_results(candidates, &query);

        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.score >= 0.7));
    }
}