Skip to main content

hermes_core/query/
vector.rs

1//! Vector query types for dense and sparse vector search
2
3use crate::dsl::Field;
4use crate::segment::{SegmentReader, VectorSearchResult};
5use crate::{DocId, Score, TERMINATED};
6
7use super::ScoredPosition;
8use super::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
9
/// Strategy for combining scores when a document has multiple values for the same field.
///
/// Each value (ordinal) of a multi-valued field is scored on its own; the
/// combiner reduces those per-ordinal scores to one document score.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiValueCombiner {
    /// Sum all scores (accumulates dot product contributions)
    Sum,
    /// Take the maximum score
    Max,
    /// Take the average score
    Avg,
    /// Log-Sum-Exp: smooth maximum approximation (the `Default` combiner).
    ///
    /// `score = (1/t) * ln(Σ exp(t * sᵢ))`
    ///
    /// For `t > 0` the result lies in `[max(s), max(s) + ln(n)/t]`: the best
    /// score plus a bounded bonus for the other matches. Higher temperature →
    /// closer to max; lower temperature → larger multi-match bonus.
    LogSumExp {
        /// Temperature parameter (default: 1.5); expected to be finite and positive
        temperature: f32,
    },
    /// Weighted Top-K: decay-weighted mean of the highest scores.
    ///
    /// `score = Σ wᵢ * sorted_scores[i] / Σ wᵢ` where `wᵢ = decay^i` and
    /// `sorted_scores` is sorted descending and truncated to `k` entries.
    WeightedTopK {
        /// Number of top scores to consider (default: 5)
        k: usize,
        /// Decay factor per rank (default: 0.7)
        decay: f32,
    },
}

impl Default for MultiValueCombiner {
    /// `LogSumExp { temperature: 1.5 }`: balances max (best relevance) against
    /// a saturating bonus for multiple matches.
    fn default() -> Self {
        MultiValueCombiner::LogSumExp { temperature: 1.5 }
    }
}

impl MultiValueCombiner {
    /// Create LogSumExp combiner with default temperature (1.5)
    pub fn log_sum_exp() -> Self {
        Self::LogSumExp { temperature: 1.5 }
    }

    /// Create LogSumExp combiner with custom temperature
    pub fn log_sum_exp_with_temperature(temperature: f32) -> Self {
        Self::LogSumExp { temperature }
    }

    /// Create WeightedTopK combiner with defaults (k=5, decay=0.7)
    pub fn weighted_top_k() -> Self {
        Self::WeightedTopK { k: 5, decay: 0.7 }
    }

    /// Create WeightedTopK combiner with custom parameters
    pub fn weighted_top_k_with_params(k: usize, decay: f32) -> Self {
        Self::WeightedTopK { k, decay }
    }

    /// Combine one document's `(ordinal, score)` pairs into a single score.
    ///
    /// Ordinals are ignored; only the scores participate. Returns `0.0` for an
    /// empty slice. For `LogSumExp`, a zero or non-finite temperature makes the
    /// formula undefined, so it falls back to `Max`. A negative temperature
    /// yields a soft minimum.
    pub fn combine(&self, scores: &[(u32, f32)]) -> f32 {
        if scores.is_empty() {
            return 0.0;
        }

        match *self {
            MultiValueCombiner::Sum => scores.iter().map(|(_, s)| s).sum(),
            MultiValueCombiner::Max => Self::max_score(scores),
            MultiValueCombiner::Avg => {
                let sum: f32 = scores.iter().map(|(_, s)| s).sum();
                sum / scores.len() as f32
            }
            MultiValueCombiner::LogSumExp { temperature: t } => {
                // The formula divides by `t` (0/0 = NaN for a single score), and
                // `t = ±inf` turns the pivot's own exponent into `inf * 0 = NaN`.
                // Fall back to Max, the `t → +inf` limit, instead of emitting NaN/inf.
                if t == 0.0 || !t.is_finite() {
                    return Self::max_score(scores);
                }
                // Numerically stable form: LSE = p + ln(Σ exp(t·(sᵢ − p))) / t.
                // Pivoting on the max (t > 0) or the min (t < 0) keeps every
                // exponent ≤ 0, so `exp` cannot overflow.
                let pivot = if t > 0.0 {
                    Self::max_score(scores)
                } else {
                    Self::min_score(scores)
                };
                let sum_exp: f32 = scores
                    .iter()
                    .map(|(_, s)| (t * (s - pivot)).exp())
                    .sum();
                pivot + sum_exp.ln() / t
            }
            MultiValueCombiner::WeightedTopK { k, decay } => {
                // Sort scores descending and take top k
                let mut sorted: Vec<f32> = scores.iter().map(|(_, s)| *s).collect();
                sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
                sorted.truncate(k);

                // Apply exponential decay weights: 1, decay, decay², ...
                let mut weight = 1.0f32;
                let mut weighted_sum = 0.0f32;
                let mut weight_total = 0.0f32;
                for score in sorted {
                    weighted_sum += weight * score;
                    weight_total += weight;
                    weight *= decay;
                }

                // Guards k == 0 (nothing summed) and non-positive totals from a
                // negative decay.
                if weight_total > 0.0 {
                    weighted_sum / weight_total
                } else {
                    0.0
                }
            }
        }
    }

    /// Largest score, treating incomparable (NaN) pairs as equal; `0.0` if empty.
    fn max_score(scores: &[(u32, f32)]) -> f32 {
        scores
            .iter()
            .map(|(_, s)| *s)
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(0.0)
    }

    /// Smallest score, with the same NaN treatment as `max_score`; `0.0` if empty.
    fn min_score(scores: &[(u32, f32)]) -> f32 {
        scores
            .iter()
            .map(|(_, s)| *s)
            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(0.0)
    }
}
125
126/// Dense vector query for similarity search
127#[derive(Debug, Clone)]
128pub struct DenseVectorQuery {
129    /// Field containing the dense vectors
130    pub field: Field,
131    /// Query vector
132    pub vector: Vec<f32>,
133    /// Number of clusters to probe (for IVF indexes)
134    pub nprobe: usize,
135    /// Re-ranking factor (multiplied by k for candidate selection, e.g. 3.0)
136    pub rerank_factor: f32,
137    /// How to combine scores for multi-valued documents
138    pub combiner: MultiValueCombiner,
139}
140
141impl std::fmt::Display for DenseVectorQuery {
142    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143        write!(
144            f,
145            "Dense({}, dim={}, nprobe={}, rerank={})",
146            self.field.0,
147            self.vector.len(),
148            self.nprobe,
149            self.rerank_factor
150        )
151    }
152}
153
154impl DenseVectorQuery {
155    /// Create a new dense vector query
156    pub fn new(field: Field, vector: Vec<f32>) -> Self {
157        Self {
158            field,
159            vector,
160            nprobe: 32,
161            rerank_factor: 3.0,
162            combiner: MultiValueCombiner::Max,
163        }
164    }
165
166    /// Set the number of clusters to probe (for IVF indexes)
167    pub fn with_nprobe(mut self, nprobe: usize) -> Self {
168        self.nprobe = nprobe;
169        self
170    }
171
172    /// Set the re-ranking factor (e.g. 3.0 = fetch 3x candidates for reranking)
173    pub fn with_rerank_factor(mut self, factor: f32) -> Self {
174        self.rerank_factor = factor;
175        self
176    }
177
178    /// Set the multi-value score combiner
179    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
180        self.combiner = combiner;
181        self
182    }
183}
184
185impl Query for DenseVectorQuery {
186    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
187        let field = self.field;
188        let vector = self.vector.clone();
189        let nprobe = self.nprobe;
190        let rerank_factor = self.rerank_factor;
191        let combiner = self.combiner;
192        Box::pin(async move {
193            let results = reader
194                .search_dense_vector(field, &vector, limit, nprobe, rerank_factor, combiner)
195                .await?;
196
197            Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
198        })
199    }
200
201    #[cfg(feature = "sync")]
202    fn scorer_sync<'a>(
203        &self,
204        reader: &'a SegmentReader,
205        limit: usize,
206    ) -> crate::Result<Box<dyn Scorer + 'a>> {
207        let results = reader.search_dense_vector_sync(
208            self.field,
209            &self.vector,
210            limit,
211            self.nprobe,
212            self.rerank_factor,
213            self.combiner,
214        )?;
215        Ok(Box::new(DenseVectorScorer::new(results, self.field.0)) as Box<dyn Scorer>)
216    }
217
218    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
219        Box::pin(async move { Ok(u32::MAX) })
220    }
221}
222
223/// Scorer for dense vector search results with ordinal tracking
224struct DenseVectorScorer {
225    results: Vec<VectorSearchResult>,
226    position: usize,
227    field_id: u32,
228}
229
230impl DenseVectorScorer {
231    fn new(mut results: Vec<VectorSearchResult>, field_id: u32) -> Self {
232        // Sort by doc_id ascending — DocSet contract requires monotonic doc IDs
233        results.sort_unstable_by_key(|r| r.doc_id);
234        Self {
235            results,
236            position: 0,
237            field_id,
238        }
239    }
240}
241
242impl super::docset::DocSet for DenseVectorScorer {
243    fn doc(&self) -> DocId {
244        if self.position < self.results.len() {
245            self.results[self.position].doc_id
246        } else {
247            TERMINATED
248        }
249    }
250
251    fn advance(&mut self) -> DocId {
252        self.position += 1;
253        self.doc()
254    }
255
256    fn seek(&mut self, target: DocId) -> DocId {
257        // Binary search within remaining results for O(log k) seek
258        let remaining = &self.results[self.position..];
259        let offset = remaining.partition_point(|r| r.doc_id < target);
260        self.position += offset;
261        self.doc()
262    }
263
264    fn size_hint(&self) -> u32 {
265        (self.results.len() - self.position) as u32
266    }
267}
268
269impl Scorer for DenseVectorScorer {
270    fn score(&self) -> Score {
271        if self.position < self.results.len() {
272            self.results[self.position].score
273        } else {
274            0.0
275        }
276    }
277
278    fn matched_positions(&self) -> Option<MatchedPositions> {
279        if self.position >= self.results.len() {
280            return None;
281        }
282        let result = &self.results[self.position];
283        let scored_positions: Vec<ScoredPosition> = result
284            .ordinals
285            .iter()
286            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
287            .collect();
288        Some(vec![(self.field_id, scored_positions)])
289    }
290}
291
292/// Sparse vector query for similarity search
293#[derive(Debug, Clone)]
294pub struct SparseVectorQuery {
295    /// Field containing the sparse vectors
296    pub field: Field,
297    /// Query vector as (dimension_id, weight) pairs
298    pub vector: Vec<(u32, f32)>,
299    /// How to combine scores for multi-valued documents
300    pub combiner: MultiValueCombiner,
301    /// Approximate search factor (1.0 = exact, lower values = faster but approximate)
302    /// Controls MaxScore pruning aggressiveness in block-max scoring
303    pub heap_factor: f32,
304    /// Minimum abs(weight) for query dimensions (0.0 = no filtering)
305    /// Dimensions below this threshold are dropped before search.
306    pub weight_threshold: f32,
307    /// Maximum number of query dimensions to process (None = all)
308    /// Keeps only the top-k dimensions by abs(weight).
309    pub max_query_dims: Option<usize>,
310    /// Fraction of query dimensions to keep (0.0-1.0), same semantics as
311    /// indexing-time `pruning`: sort by abs(weight) descending,
312    /// keep top fraction. None or 1.0 = no pruning.
313    pub pruning: Option<f32>,
314    /// Multiplier on executor limit for ordinal deduplication (1.0 = no over-fetch)
315    pub over_fetch_factor: f32,
316    /// Cached pruned vector; None = use `vector` as-is (no pruning applied)
317    pruned: Option<Vec<(u32, f32)>>,
318}
319
320impl std::fmt::Display for SparseVectorQuery {
321    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322        let dims = self.pruned_dims();
323        write!(f, "Sparse({}, dims={}", self.field.0, dims.len())?;
324        if self.heap_factor < 1.0 {
325            write!(f, ", heap={}", self.heap_factor)?;
326        }
327        if self.vector.len() != dims.len() {
328            write!(f, ", orig={}", self.vector.len())?;
329        }
330        write!(f, ")")
331    }
332}
333
334impl SparseVectorQuery {
335    /// Create a new sparse vector query
336    ///
337    /// Default combiner is `LogSumExp { temperature: 0.7 }` which provides
338    /// saturation for documents with many sparse vectors (e.g., 100+ ordinals).
339    /// This prevents over-weighting from multiple matches while still allowing
340    /// additional matches to contribute to the score.
341    pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
342        Self {
343            field,
344            vector,
345            combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
346            heap_factor: 1.0,
347            weight_threshold: 0.0,
348            max_query_dims: None,
349            pruning: None,
350            over_fetch_factor: 2.0,
351            pruned: None,
352        }
353    }
354
355    /// Effective query dimensions after pruning. Returns `vector` if no pruning is configured.
356    fn pruned_dims(&self) -> &[(u32, f32)] {
357        self.pruned.as_deref().unwrap_or(&self.vector)
358    }
359
360    /// Set the multi-value score combiner
361    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
362        self.combiner = combiner;
363        self
364    }
365
366    /// Set executor over-fetch factor for multi-valued fields.
367    /// After MaxScore execution, ordinal combining may reduce result count;
368    /// this multiplier compensates by fetching more from the executor.
369    /// (1.0 = no over-fetch, 2.0 = fetch 2x then combine down)
370    pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
371        self.over_fetch_factor = factor.max(1.0);
372        self
373    }
374
375    /// Set the heap factor for approximate search
376    ///
377    /// Controls the trade-off between speed and recall:
378    /// - 1.0 = exact search (default)
379    /// - 0.8-0.9 = ~20-40% faster with minimal recall loss
380    /// - Lower values = more aggressive pruning, faster but lower recall
381    pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
382        self.heap_factor = heap_factor.clamp(0.0, 1.0);
383        self
384    }
385
386    /// Set minimum weight threshold for query dimensions
387    /// Dimensions with abs(weight) below this are dropped before search.
388    pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
389        self.weight_threshold = threshold;
390        self.pruned = Some(self.compute_pruned_vector());
391        self
392    }
393
394    /// Set maximum number of query dimensions (top-k by weight)
395    pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
396        self.max_query_dims = Some(max_dims);
397        self.pruned = Some(self.compute_pruned_vector());
398        self
399    }
400
401    /// Set pruning fraction (0.0-1.0): keep top fraction of query dims by weight.
402    /// Same semantics as indexing-time `pruning`.
403    pub fn with_pruning(mut self, fraction: f32) -> Self {
404        self.pruning = Some(fraction.clamp(0.0, 1.0));
405        self.pruned = Some(self.compute_pruned_vector());
406        self
407    }
408
409    /// Apply weight_threshold, pruning, and max_query_dims, returning the pruned vector.
410    fn compute_pruned_vector(&self) -> Vec<(u32, f32)> {
411        let original_len = self.vector.len();
412
413        // Step 1: weight_threshold — drop dimensions below minimum weight
414        let mut v: Vec<(u32, f32)> = if self.weight_threshold > 0.0 {
415            self.vector
416                .iter()
417                .copied()
418                .filter(|(_, w)| w.abs() >= self.weight_threshold)
419                .collect()
420        } else {
421            self.vector.clone()
422        };
423        let after_threshold = v.len();
424
425        // Step 2: pruning — keep top fraction by abs(weight), same as indexing
426        let mut sorted_by_weight = false;
427        if let Some(fraction) = self.pruning
428            && fraction < 1.0
429            && v.len() > 1
430        {
431            v.sort_unstable_by(|a, b| {
432                b.1.abs()
433                    .partial_cmp(&a.1.abs())
434                    .unwrap_or(std::cmp::Ordering::Equal)
435            });
436            sorted_by_weight = true;
437            let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
438            v.truncate(keep);
439        }
440        let after_pruning = v.len();
441
442        // Step 3: max_query_dims — absolute cap on dimensions
443        if let Some(max_dims) = self.max_query_dims
444            && v.len() > max_dims
445        {
446            if !sorted_by_weight {
447                v.sort_unstable_by(|a, b| {
448                    b.1.abs()
449                        .partial_cmp(&a.1.abs())
450                        .unwrap_or(std::cmp::Ordering::Equal)
451                });
452            }
453            v.truncate(max_dims);
454        }
455
456        if v.len() < original_len {
457            let src: Vec<_> = self
458                .vector
459                .iter()
460                .map(|(d, w)| format!("({},{:.4})", d, w))
461                .collect();
462            let pruned_fmt: Vec<_> = v.iter().map(|(d, w)| format!("({},{:.4})", d, w)).collect();
463            log::debug!(
464                "[sparse query] field={}: pruned {}->{} dims \
465                 (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{}), \
466                 source=[{}], pruned=[{}]",
467                self.field.0,
468                original_len,
469                v.len(),
470                original_len,
471                after_threshold,
472                after_threshold,
473                after_pruning,
474                after_pruning,
475                v.len(),
476                src.join(", "),
477                pruned_fmt.join(", "),
478            );
479        }
480
481        v
482    }
483
484    /// Create from separate indices and weights vectors
485    pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
486        let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
487        Self::new(field, vector)
488    }
489
490    /// Create from raw text using a HuggingFace tokenizer (single segment)
491    ///
492    /// This method tokenizes the text and creates a sparse vector query.
493    /// For multi-segment indexes, use `from_text_with_stats` instead.
494    ///
495    /// # Arguments
496    /// * `field` - The sparse vector field to search
497    /// * `text` - Raw text to tokenize
498    /// * `tokenizer_name` - HuggingFace tokenizer path (e.g., "bert-base-uncased")
499    /// * `weighting` - Weighting strategy for tokens
500    /// * `sparse_index` - Optional sparse index for IDF lookup (required for IDF weighting)
501    #[cfg(feature = "native")]
502    pub fn from_text(
503        field: Field,
504        text: &str,
505        tokenizer_name: &str,
506        weighting: crate::structures::QueryWeighting,
507        sparse_index: Option<&crate::segment::SparseIndex>,
508    ) -> crate::Result<Self> {
509        use crate::structures::QueryWeighting;
510        use crate::tokenizer::tokenizer_cache;
511
512        let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
513        let token_ids = tokenizer.tokenize_unique(text)?;
514
515        let weights: Vec<f32> = match weighting {
516            QueryWeighting::One => vec![1.0f32; token_ids.len()],
517            QueryWeighting::Idf => {
518                if let Some(index) = sparse_index {
519                    index.idf_weights(&token_ids)
520                } else {
521                    vec![1.0f32; token_ids.len()]
522                }
523            }
524            QueryWeighting::IdfFile => {
525                use crate::tokenizer::idf_weights_cache;
526                if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name) {
527                    token_ids.iter().map(|&id| idf.get(id)).collect()
528                } else {
529                    vec![1.0f32; token_ids.len()]
530                }
531            }
532        };
533
534        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
535        Ok(Self::new(field, vector))
536    }
537
538    /// Create from raw text using global statistics (multi-segment)
539    ///
540    /// This is the recommended method for multi-segment indexes as it uses
541    /// aggregated IDF values across all segments for consistent ranking.
542    ///
543    /// # Arguments
544    /// * `field` - The sparse vector field to search
545    /// * `text` - Raw text to tokenize
546    /// * `tokenizer` - Pre-loaded HuggingFace tokenizer
547    /// * `weighting` - Weighting strategy for tokens
548    /// * `global_stats` - Global statistics for IDF computation
549    #[cfg(feature = "native")]
550    pub fn from_text_with_stats(
551        field: Field,
552        text: &str,
553        tokenizer: &crate::tokenizer::HfTokenizer,
554        weighting: crate::structures::QueryWeighting,
555        global_stats: Option<&super::GlobalStats>,
556    ) -> crate::Result<Self> {
557        use crate::structures::QueryWeighting;
558
559        let token_ids = tokenizer.tokenize_unique(text)?;
560
561        let weights: Vec<f32> = match weighting {
562            QueryWeighting::One => vec![1.0f32; token_ids.len()],
563            QueryWeighting::Idf => {
564                if let Some(stats) = global_stats {
565                    // Clamp to zero: negative weights don't make sense for IDF
566                    stats
567                        .sparse_idf_weights(field, &token_ids)
568                        .into_iter()
569                        .map(|w| w.max(0.0))
570                        .collect()
571                } else {
572                    vec![1.0f32; token_ids.len()]
573                }
574            }
575            QueryWeighting::IdfFile => {
576                // IdfFile requires a tokenizer name for HF model lookup;
577                // this code path doesn't have one, so fall back to 1.0
578                vec![1.0f32; token_ids.len()]
579            }
580        };
581
582        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
583        Ok(Self::new(field, vector))
584    }
585
586    /// Create from raw text, loading tokenizer from index directory
587    ///
588    /// This method supports the `index://` prefix for tokenizer paths,
589    /// loading tokenizer.json from the index directory.
590    ///
591    /// # Arguments
592    /// * `field` - The sparse vector field to search
593    /// * `text` - Raw text to tokenize
594    /// * `tokenizer_bytes` - Tokenizer JSON bytes (pre-loaded from directory)
595    /// * `weighting` - Weighting strategy for tokens
596    /// * `global_stats` - Global statistics for IDF computation
597    #[cfg(feature = "native")]
598    pub fn from_text_with_tokenizer_bytes(
599        field: Field,
600        text: &str,
601        tokenizer_bytes: &[u8],
602        weighting: crate::structures::QueryWeighting,
603        global_stats: Option<&super::GlobalStats>,
604    ) -> crate::Result<Self> {
605        use crate::structures::QueryWeighting;
606        use crate::tokenizer::HfTokenizer;
607
608        let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
609        let token_ids = tokenizer.tokenize_unique(text)?;
610
611        let weights: Vec<f32> = match weighting {
612            QueryWeighting::One => vec![1.0f32; token_ids.len()],
613            QueryWeighting::Idf => {
614                if let Some(stats) = global_stats {
615                    // Clamp to zero: negative weights don't make sense for IDF
616                    stats
617                        .sparse_idf_weights(field, &token_ids)
618                        .into_iter()
619                        .map(|w| w.max(0.0))
620                        .collect()
621                } else {
622                    vec![1.0f32; token_ids.len()]
623                }
624            }
625            QueryWeighting::IdfFile => {
626                // IdfFile requires a tokenizer name for HF model lookup;
627                // this code path doesn't have one, so fall back to 1.0
628                vec![1.0f32; token_ids.len()]
629            }
630        };
631
632        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
633        Ok(Self::new(field, vector))
634    }
635}
636
637impl SparseVectorQuery {
638    /// Build the inner query for this sparse vector search against a segment.
639    /// Filters pruned dims to those present in the segment, then returns:
640    /// - None if no dims match
641    /// - A single SparseTermQuery if one dim matches
642    /// - A BooleanQuery of SHOULD SparseTermQuery clauses otherwise
643    fn build_inner_query(&self, reader: &SegmentReader) -> Option<Box<dyn Query>> {
644        let si = reader.sparse_index(self.field)?;
645        let matched: Vec<(u32, f32)> = self
646            .pruned_dims()
647            .iter()
648            .filter(|(d, _)| si.has_dimension(*d))
649            .copied()
650            .collect();
651        if matched.is_empty() {
652            return None;
653        }
654
655        let make_term = |(dim_id, weight)| {
656            SparseTermQuery::new(self.field, dim_id, weight)
657                .with_heap_factor(self.heap_factor)
658                .with_combiner(self.combiner)
659                .with_over_fetch_factor(self.over_fetch_factor)
660        };
661
662        if matched.len() == 1 {
663            return Some(Box::new(make_term(matched[0])));
664        }
665
666        let mut bool_q = super::BooleanQuery::new();
667        for dims in matched {
668            bool_q = bool_q.should(make_term(dims));
669        }
670        Some(Box::new(bool_q))
671    }
672}
673
674impl Query for SparseVectorQuery {
675    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
676        match self.build_inner_query(reader) {
677            None => Box::pin(async { Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer>) }),
678            Some(q) => q.scorer(reader, limit),
679        }
680    }
681
682    #[cfg(feature = "sync")]
683    fn scorer_sync<'a>(
684        &self,
685        reader: &'a SegmentReader,
686        limit: usize,
687    ) -> crate::Result<Box<dyn Scorer + 'a>> {
688        match self.build_inner_query(reader) {
689            None => Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>),
690            Some(q) => q.scorer_sync(reader, limit),
691        }
692    }
693
694    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
695        Box::pin(async move { Ok(u32::MAX) })
696    }
697}
698
699// ── SparseTermQuery: single sparse dimension query (like TermQuery for text) ──
700
701/// Query for a single sparse vector dimension.
702///
703/// Analogous to `TermQuery` for text: searches one dimension's posting list
704/// with a given weight. Multiple `SparseTermQuery` instances are combined as
705/// `BooleanQuery` SHOULD clauses to form a full sparse vector search.
706#[derive(Debug, Clone)]
707pub struct SparseTermQuery {
708    pub field: Field,
709    pub dim_id: u32,
710    pub weight: f32,
711    /// MaxScore heap factor (1.0 = exact, lower = approximate)
712    pub heap_factor: f32,
713    /// Multi-value combiner for ordinal deduplication
714    pub combiner: MultiValueCombiner,
715    /// Multiplier on executor limit to compensate for ordinal deduplication
716    pub over_fetch_factor: f32,
717}
718
719impl std::fmt::Display for SparseTermQuery {
720    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
721        write!(
722            f,
723            "SparseTerm({}, dim={}, w={:.3})",
724            self.field.0, self.dim_id, self.weight
725        )
726    }
727}
728
729impl SparseTermQuery {
730    pub fn new(field: Field, dim_id: u32, weight: f32) -> Self {
731        Self {
732            field,
733            dim_id,
734            weight,
735            heap_factor: 1.0,
736            combiner: MultiValueCombiner::default(),
737            over_fetch_factor: 2.0,
738        }
739    }
740
741    pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
742        self.heap_factor = heap_factor;
743        self
744    }
745
746    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
747        self.combiner = combiner;
748        self
749    }
750
751    pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
752        self.over_fetch_factor = factor.max(1.0);
753        self
754    }
755
756    /// Create a SparseTermScorer from this query's config against a segment.
757    /// Returns EmptyScorer if the dimension doesn't exist.
758    fn make_scorer<'a>(
759        &self,
760        reader: &'a SegmentReader,
761    ) -> crate::Result<Option<SparseTermScorer<'a>>> {
762        let si = match reader.sparse_index(self.field) {
763            Some(si) => si,
764            None => return Ok(None),
765        };
766        let (skip_start, skip_count, global_max, block_data_offset) =
767            match si.get_skip_range_full(self.dim_id) {
768                Some(v) => v,
769                None => return Ok(None),
770            };
771        let cursor = super::TermCursor::sparse(
772            si,
773            self.weight,
774            skip_start,
775            skip_count,
776            global_max,
777            block_data_offset,
778        );
779        Ok(Some(SparseTermScorer {
780            cursor,
781            field_id: self.field.0,
782        }))
783    }
784}
785
786impl Query for SparseTermQuery {
787    fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
788        let query = self.clone();
789        Box::pin(async move {
790            let mut scorer = match query.make_scorer(reader)? {
791                Some(s) => s,
792                None => return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>),
793            };
794            scorer.cursor.ensure_block_loaded().await.ok();
795            Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
796        })
797    }
798
799    #[cfg(feature = "sync")]
800    fn scorer_sync<'a>(
801        &self,
802        reader: &'a SegmentReader,
803        _limit: usize,
804    ) -> crate::Result<Box<dyn Scorer + 'a>> {
805        let mut scorer = match self.make_scorer(reader)? {
806            Some(s) => s,
807            None => return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>),
808        };
809        scorer.cursor.ensure_block_loaded_sync().ok();
810        Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
811    }
812
813    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
814        let field = self.field;
815        let dim_id = self.dim_id;
816        Box::pin(async move {
817            let si = match reader.sparse_index(field) {
818                Some(si) => si,
819                None => return Ok(0),
820            };
821            match si.get_skip_range_full(dim_id) {
822                Some((_, skip_count, _, _)) => Ok((skip_count * 256) as u32),
823                None => Ok(0),
824            }
825        })
826    }
827
828    fn as_sparse_term_query_info(&self) -> Option<super::SparseTermQueryInfo> {
829        Some(super::SparseTermQueryInfo {
830            field: self.field,
831            dim_id: self.dim_id,
832            weight: self.weight,
833            heap_factor: self.heap_factor,
834            combiner: self.combiner,
835            over_fetch_factor: self.over_fetch_factor,
836        })
837    }
838}
839
840/// Lazy scorer for a single sparse dimension, backed by `TermCursor::Sparse`.
841///
842/// Iterates through the posting list block-by-block using sync I/O.
843/// Score for each doc = `query_weight * quantized_stored_weight`.
844struct SparseTermScorer<'a> {
845    cursor: super::TermCursor<'a>,
846    field_id: u32,
847}
848
849impl super::docset::DocSet for SparseTermScorer<'_> {
850    fn doc(&self) -> DocId {
851        let d = self.cursor.doc();
852        if d == u32::MAX { TERMINATED } else { d }
853    }
854
855    fn advance(&mut self) -> DocId {
856        match self.cursor.advance_sync() {
857            Ok(d) if d == u32::MAX => TERMINATED,
858            Ok(d) => d,
859            Err(_) => TERMINATED,
860        }
861    }
862
863    fn seek(&mut self, target: DocId) -> DocId {
864        match self.cursor.seek_sync(target) {
865            Ok(d) if d == u32::MAX => TERMINATED,
866            Ok(d) => d,
867            Err(_) => TERMINATED,
868        }
869    }
870
871    fn size_hint(&self) -> u32 {
872        0
873    }
874}
875
876impl Scorer for SparseTermScorer<'_> {
877    fn score(&self) -> Score {
878        self.cursor.score()
879    }
880
881    fn matched_positions(&self) -> Option<MatchedPositions> {
882        let ordinal = self.cursor.ordinal();
883        let score = self.cursor.score();
884        if score == 0.0 {
885            return None;
886        }
887        Some(vec![(
888            self.field_id,
889            vec![ScoredPosition::new(ordinal as u32, score)],
890        )])
891    }
892}
893
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    /// Tolerance for results that should match up to float rounding.
    const EPS: f32 = 1e-6;

    /// Asserts `|actual - expected| < tol`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected} (±{tol}), got {actual}"
        );
    }

    #[test]
    fn test_dense_vector_query_builder() {
        let q = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5.0);

        assert_eq!(q.field, Field(0));
        assert_eq!(q.vector.len(), 3);
        assert_eq!(q.nprobe, 64);
        assert_eq!(q.rerank_factor, 5.0);
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let expected = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let q = SparseVectorQuery::new(Field(0), expected.clone());

        assert_eq!(q.field, Field(0));
        assert_eq!(q.vector, expected);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let q = SparseVectorQuery::from_indices_weights(
            Field(0),
            vec![1, 5, 10],
            vec![0.5, 0.3, 0.2],
        );
        assert_eq!(q.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }

    #[test]
    fn test_combiner_sum() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        assert_close(MultiValueCombiner::Sum.combine(&scores), 6.0, EPS);
    }

    #[test]
    fn test_combiner_max() {
        let scores = vec![(0, 1.0), (1, 3.0), (2, 2.0)];
        assert_close(MultiValueCombiner::Max.combine(&scores), 3.0, EPS);
    }

    #[test]
    fn test_combiner_avg() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        assert_close(MultiValueCombiner::Avg.combine(&scores), 2.0, EPS);
    }

    #[test]
    fn test_combiner_log_sum_exp() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let lse = MultiValueCombiner::log_sum_exp().combine(&scores);
        // Smooth max bounds: max (3.0) <= LSE <= max + ln(n) / t, with n = 3, t = 1.5.
        let upper = 3.0 + (3.0_f32).ln() / 1.5;
        assert!((3.0..=upper).contains(&lse));
    }

    #[test]
    fn test_combiner_log_sum_exp_approaches_max_with_high_temp() {
        let scores = vec![(0, 1.0), (1, 5.0), (2, 2.0)];
        // A high temperature should pull the soft maximum close to the true max (5.0).
        let lse = MultiValueCombiner::log_sum_exp_with_temperature(10.0).combine(&scores);
        assert_close(lse, 5.0, 0.5);
    }

    #[test]
    fn test_combiner_weighted_top_k() {
        let scores = vec![(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)];
        // Top 3 = [5, 3, 1] with weights [1, 0.5, 0.25]:
        // (5 + 1.5 + 0.25) / (1 + 0.5 + 0.25) = 6.75 / 1.75 ≈ 3.857
        let got = MultiValueCombiner::weighted_top_k_with_params(3, 0.5).combine(&scores);
        assert_close(got, 3.857, 0.01);
    }

    #[test]
    fn test_combiner_weighted_top_k_less_than_k() {
        let scores = vec![(0, 2.0), (1, 1.0)];
        // Fewer scores than k, so only weights [1, 0.7] apply:
        // (2 + 0.7) / (1 + 0.7) = 2.7 / 1.7 ≈ 1.588
        let got = MultiValueCombiner::weighted_top_k_with_params(5, 0.7).combine(&scores);
        assert_close(got, 1.588, 0.01);
    }

    #[test]
    fn test_combiner_empty_scores() {
        let scores: Vec<(u32, f32)> = vec![];
        let combiners = [
            MultiValueCombiner::Sum,
            MultiValueCombiner::Max,
            MultiValueCombiner::Avg,
            MultiValueCombiner::log_sum_exp(),
            MultiValueCombiner::weighted_top_k(),
        ];
        for combiner in combiners {
            assert_eq!(combiner.combine(&scores), 0.0);
        }
    }

    #[test]
    fn test_combiner_single_score() {
        let scores = vec![(0, 5.0)];
        // Every combiner should collapse a single score to that score.
        let combiners = [
            MultiValueCombiner::Sum,
            MultiValueCombiner::Max,
            MultiValueCombiner::Avg,
            MultiValueCombiner::log_sum_exp(),
            MultiValueCombiner::weighted_top_k(),
        ];
        for combiner in combiners {
            assert_close(combiner.combine(&scores), 5.0, EPS);
        }
    }

    #[test]
    fn test_default_combiner_is_log_sum_exp() {
        if let MultiValueCombiner::LogSumExp { temperature } = MultiValueCombiner::default() {
            assert_close(temperature, 1.5, EPS);
        } else {
            panic!("Default combiner should be LogSumExp");
        }
    }
}