Skip to main content

hermes_core/query/
collector.rs

1//! Search result collection and response types
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
12/// Unique document address: segment_id (hex) + local doc_id within segment
13#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
14pub struct DocAddress {
15    /// Segment ID as hex string (32 chars)
16    pub segment_id: String,
17    /// Document ID within the segment
18    pub doc_id: DocId,
19}
20
21impl DocAddress {
22    pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23        Self {
24            segment_id: format!("{:032x}", segment_id),
25            doc_id,
26        }
27    }
28
29    /// Parse segment_id from hex string
30    pub fn segment_id_u128(&self) -> Option<u128> {
31        u128::from_str_radix(&self.segment_id, 16).ok()
32    }
33}
34
35/// A scored position/ordinal within a field
36/// For text fields: position is the token position
37/// For vector fields: position is the ordinal (which vector in multi-value)
38#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
39pub struct ScoredPosition {
40    /// Position (text) or ordinal (vector)
41    pub position: u32,
42    /// Individual score contribution from this position/ordinal
43    pub score: f32,
44}
45
46impl ScoredPosition {
47    pub fn new(position: u32, score: f32) -> Self {
48        Self { position, score }
49    }
50}
51
52/// Search result with doc_id and score (internal use)
53#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
54pub struct SearchResult {
55    pub doc_id: DocId,
56    pub score: Score,
57    /// Segment ID (set by searcher after collection)
58    #[serde(default, skip_serializing_if = "is_zero_u128")]
59    pub segment_id: u128,
60    /// Matched positions per field: (field_id, scored_positions)
61    /// Each position includes its individual score contribution
62    #[serde(default, skip_serializing_if = "Vec::is_empty")]
63    pub positions: Vec<(u32, Vec<ScoredPosition>)>,
64}
65
/// serde `skip_serializing_if` predicate: true for a zero (unset) segment id.
fn is_zero_u128(v: &u128) -> bool {
    matches!(*v, 0)
}
69
70/// Matched field info with ordinals (for multi-valued fields)
71#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
72pub struct MatchedField {
73    /// Field ID
74    pub field_id: u32,
75    /// Matched element ordinals (for multi-valued fields with position tracking)
76    /// Empty if position tracking is not enabled for this field
77    pub ordinals: Vec<u32>,
78}
79
80impl SearchResult {
81    /// Extract unique ordinals from positions for each field
82    /// For text fields: ordinal = position >> 20 (from encoded position)
83    /// For vector fields: position IS the ordinal directly
84    pub fn extract_ordinals(&self) -> Vec<MatchedField> {
85        use rustc_hash::FxHashSet;
86
87        self.positions
88            .iter()
89            .map(|(field_id, scored_positions)| {
90                let mut ordinals: FxHashSet<u32> = FxHashSet::default();
91                for sp in scored_positions {
92                    // For text fields with encoded positions, extract ordinal
93                    // For vector fields, position IS the ordinal
94                    // We use a heuristic: if position > 0xFFFFF (20 bits), it's encoded
95                    let ordinal = if sp.position > 0xFFFFF {
96                        sp.position >> 20
97                    } else {
98                        sp.position
99                    };
100                    ordinals.insert(ordinal);
101                }
102                let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
103                ordinals.sort_unstable();
104                MatchedField {
105                    field_id: *field_id,
106                    ordinals,
107                }
108            })
109            .collect()
110    }
111
112    /// Get all scored positions for a specific field
113    pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
114        self.positions
115            .iter()
116            .find(|(fid, _)| *fid == field_id)
117            .map(|(_, positions)| positions.as_slice())
118    }
119}
120
121/// Search hit with unique document address and score
122#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
123pub struct SearchHit {
124    /// Unique document address (segment_id + local doc_id)
125    pub address: DocAddress,
126    pub score: Score,
127    /// Matched fields with element ordinals (populated when position tracking is enabled)
128    #[serde(default, skip_serializing_if = "Vec::is_empty")]
129    pub matched_fields: Vec<MatchedField>,
130}
131
132/// Search response with hits (IDs only, no documents)
133#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
134pub struct SearchResponse {
135    pub hits: Vec<SearchHit>,
136    pub total_hits: u32,
137}
138
139impl PartialEq for SearchResult {
140    fn eq(&self, other: &Self) -> bool {
141        self.doc_id == other.doc_id
142    }
143}
144
145impl Eq for SearchResult {}
146
147impl PartialOrd for SearchResult {
148    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
149        Some(self.cmp(other))
150    }
151}
152
153impl Ord for SearchResult {
154    fn cmp(&self, other: &Self) -> Ordering {
155        other
156            .score
157            .partial_cmp(&self.score)
158            .unwrap_or(Ordering::Equal)
159            .then_with(|| self.doc_id.cmp(&other.doc_id))
160    }
161}
162
163/// Trait for search result collectors
164///
165/// Implement this trait to create custom collectors that can be
166/// combined and passed to query execution.
167pub trait Collector {
168    /// Called for each matching document
169    /// positions: Vec of (field_id, scored_positions)
170    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);
171
172    /// Whether this collector needs position information
173    fn needs_positions(&self) -> bool {
174        false
175    }
176}
177
178/// Collector for top-k results
179pub struct TopKCollector {
180    heap: BinaryHeap<SearchResult>,
181    k: usize,
182    collect_positions: bool,
183    /// Total documents seen by this collector
184    total_seen: u32,
185}
186
187impl TopKCollector {
188    pub fn new(k: usize) -> Self {
189        Self {
190            heap: BinaryHeap::with_capacity(k + 1),
191            k,
192            collect_positions: false,
193            total_seen: 0,
194        }
195    }
196
197    /// Create a collector that also collects positions
198    pub fn with_positions(k: usize) -> Self {
199        Self {
200            heap: BinaryHeap::with_capacity(k + 1),
201            k,
202            collect_positions: true,
203            total_seen: 0,
204        }
205    }
206
207    /// Get the total number of documents seen (scored) by this collector
208    pub fn total_seen(&self) -> u32 {
209        self.total_seen
210    }
211
212    pub fn into_sorted_results(self) -> Vec<SearchResult> {
213        let mut results: Vec<_> = self.heap.into_vec();
214        results.sort_by(|a, b| {
215            b.score
216                .partial_cmp(&a.score)
217                .unwrap_or(Ordering::Equal)
218                .then_with(|| a.doc_id.cmp(&b.doc_id))
219        });
220        results
221    }
222
223    /// Consume collector and return (sorted_results, total_seen)
224    pub fn into_results_with_count(self) -> (Vec<SearchResult>, u32) {
225        let total = self.total_seen;
226        (self.into_sorted_results(), total)
227    }
228}
229
230impl Collector for TopKCollector {
231    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
232        self.total_seen += 1;
233
234        // Only clone positions when the document will actually be kept in the heap.
235        // This avoids deep-cloning Vec<ScoredPosition> for documents that are
236        // immediately discarded (the common case for large result sets).
237        let dominated =
238            self.heap.len() >= self.k && self.heap.peek().is_some_and(|min| score <= min.score);
239        if dominated {
240            return;
241        }
242
243        let positions = if self.collect_positions {
244            positions.to_vec()
245        } else {
246            Vec::new()
247        };
248
249        if self.heap.len() >= self.k {
250            self.heap.pop();
251        }
252        self.heap.push(SearchResult {
253            doc_id,
254            score,
255            segment_id: 0,
256            positions,
257        });
258    }
259
260    fn needs_positions(&self) -> bool {
261        self.collect_positions
262    }
263}
264
/// Collector that tallies every matching document without keeping any of them.
#[derive(Default)]
pub struct CountCollector {
    count: u64,
}

impl CountCollector {
    /// Create a collector whose tally starts at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Number of documents collected so far.
    pub fn count(&self) -> u64 {
        self.count
    }
}
281
282impl Collector for CountCollector {
283    #[inline]
284    fn collect(
285        &mut self,
286        _doc_id: DocId,
287        _score: Score,
288        _positions: &[(u32, Vec<ScoredPosition>)],
289    ) {
290        self.count += 1;
291    }
292}
293
294/// Execute a search query on a single segment (async)
295pub async fn search_segment(
296    reader: &SegmentReader,
297    query: &dyn Query,
298    limit: usize,
299) -> Result<Vec<SearchResult>> {
300    let mut collector = TopKCollector::new(limit);
301    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
302    Ok(collector.into_sorted_results())
303}
304
305/// Execute a search query on a single segment and return (results, total_seen) (async)
306pub async fn search_segment_with_count(
307    reader: &SegmentReader,
308    query: &dyn Query,
309    limit: usize,
310) -> Result<(Vec<SearchResult>, u32)> {
311    let mut collector = TopKCollector::new(limit);
312    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
313    Ok(collector.into_results_with_count())
314}
315
316/// Execute a search query on a single segment with position collection (async)
317pub async fn search_segment_with_positions(
318    reader: &SegmentReader,
319    query: &dyn Query,
320    limit: usize,
321) -> Result<Vec<SearchResult>> {
322    let mut collector = TopKCollector::with_positions(limit);
323    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
324    Ok(collector.into_sorted_results())
325}
326
327/// Execute a search query on a single segment with positions and return (results, total_seen)
328pub async fn search_segment_with_positions_and_count(
329    reader: &SegmentReader,
330    query: &dyn Query,
331    limit: usize,
332) -> Result<(Vec<SearchResult>, u32)> {
333    let mut collector = TopKCollector::with_positions(limit);
334    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
335    Ok(collector.into_results_with_count())
336}
337
338/// Count all documents matching a query on a single segment (async)
339pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
340    let mut collector = CountCollector::new();
341    collect_segment(reader, query, &mut collector).await?;
342    Ok(collector.count())
343}
344
345// Implement Collector for tuple of 2 collectors
346impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
347    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
348        self.0.collect(doc_id, score, positions);
349        self.1.collect(doc_id, score, positions);
350    }
351    fn needs_positions(&self) -> bool {
352        self.0.needs_positions() || self.1.needs_positions()
353    }
354}
355
356// Implement Collector for tuple of 3 collectors
357impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
358    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
359        self.0.collect(doc_id, score, positions);
360        self.1.collect(doc_id, score, positions);
361        self.2.collect(doc_id, score, positions);
362    }
363    fn needs_positions(&self) -> bool {
364        self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
365    }
366}
367
368/// Execute a query with one or more collectors (async)
369///
370/// Uses a large limit for the scorer to disable WAND pruning.
371/// For queries that benefit from WAND pruning (e.g., sparse vector search),
372/// use `collect_segment_with_limit` instead.
373///
374/// # Examples
375/// ```ignore
376/// // Single collector
377/// let mut top_k = TopKCollector::new(10);
378/// collect_segment(reader, query, &mut top_k).await?;
379///
380/// // Multiple collectors (tuple)
381/// let mut top_k = TopKCollector::new(10);
382/// let mut count = CountCollector::new();
383/// collect_segment(reader, query, &mut (&mut top_k, &mut count)).await?;
384/// ```
385pub async fn collect_segment<C: Collector>(
386    reader: &SegmentReader,
387    query: &dyn Query,
388    collector: &mut C,
389) -> Result<()> {
390    // Use large limit to disable WAND skipping for exhaustive collection
391    collect_segment_with_limit(reader, query, collector, usize::MAX / 2).await
392}
393
394/// Execute a query with one or more collectors and a specific limit (async)
395///
396/// The limit is passed to the scorer to enable WAND pruning for queries
397/// that support it (e.g., sparse vector search). This significantly improves
398/// performance when only the top-k results are needed.
399///
400/// Doc IDs are automatically adjusted by the segment's doc_id_offset to produce
401/// global doc IDs that can be used across all segments.
402pub async fn collect_segment_with_limit<C: Collector>(
403    reader: &SegmentReader,
404    query: &dyn Query,
405    collector: &mut C,
406    limit: usize,
407) -> Result<()> {
408    let needs_positions = collector.needs_positions();
409    let doc_id_offset = reader.doc_id_offset();
410    let mut scorer = query.scorer(reader, limit).await?;
411
412    let mut doc = scorer.doc();
413    while doc != TERMINATED {
414        // Add doc_id_offset to convert segment-local ID to global ID
415        if needs_positions {
416            let positions = scorer.matched_positions().unwrap_or_default();
417            collector.collect(doc + doc_id_offset, scorer.score(), &positions);
418        } else {
419            collector.collect(doc + doc_id_offset, scorer.score(), &[]);
420        }
421        doc = scorer.advance();
422    }
423
424    Ok(())
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);
        for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
            collector.collect(doc_id, score, &[]);
        }
        // Every offered document is counted, retained or not.
        assert_eq!(collector.total_seen(), 5);

        let ids: Vec<_> = collector
            .into_sorted_results()
            .into_iter()
            .map(|r| r.doc_id)
            .collect();
        // Scores 4.0, 3.0, 2.0 survive, best first.
        assert_eq!(ids, vec![3, 1, 2]);
    }

    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();
        for (doc_id, score) in [(0, 1.0), (1, 2.0), (2, 3.0)] {
            collector.collect(doc_id, score, &[]);
        }
        assert_eq!(collector.count(), 3);
    }

    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();
        {
            // Drive both collectors through the tuple `Collector` impl, exactly
            // as `collect_segment(reader, query, &mut (&mut a, &mut b))` would.
            let mut both = (&mut top_k, &mut count);
            assert!(!both.needs_positions());
            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                both.collect(doc_id, score, &[]);
            }
        }

        // Count should have all 5 documents
        assert_eq!(count.count(), 5);

        // TopK should only have the top 2 results
        let ids: Vec<_> = top_k
            .into_sorted_results()
            .into_iter()
            .map(|r| r.doc_id)
            .collect();
        assert_eq!(ids, vec![3, 1]);
    }

    #[test]
    fn test_tuple_needs_positions() {
        let mut plain = TopKCollector::new(1);
        let mut positional = TopKCollector::with_positions(1);
        let mut count = CountCollector::new();
        assert!(!(&mut plain, &mut count).needs_positions());
        assert!((&mut plain, &mut positional).needs_positions());
        assert!((&mut count, &mut plain, &mut positional).needs_positions());
    }

    #[test]
    fn test_doc_address_round_trip() {
        let addr = DocAddress::new(0xabc, 7);
        assert_eq!(addr.segment_id.len(), 32);
        assert_eq!(addr.segment_id_u128(), Some(0xabc));
        assert_eq!(addr.doc_id, 7);
    }

    #[test]
    fn test_extract_ordinals_and_field_positions() {
        let result = SearchResult {
            doc_id: 0,
            score: 1.0,
            segment_id: 0,
            positions: vec![(
                7,
                vec![
                    ScoredPosition::new((1 << 20) | 3, 0.5),
                    ScoredPosition::new((1 << 20) | 9, 0.25),
                    ScoredPosition::new(2, 0.25),
                ],
            )],
        };

        let matched = result.extract_ordinals();
        assert_eq!(matched.len(), 1);
        assert_eq!(matched[0].field_id, 7);
        // Two encoded positions in element 1 collapse to one ordinal; 2 is taken verbatim.
        assert_eq!(matched[0].ordinals, vec![1, 2]);

        assert_eq!(result.field_positions(7).map(|p| p.len()), Some(3));
        assert!(result.field_positions(8).is_none());
    }
}