Skip to main content

hermes_core/query/
collector.rs

1//! Search result collection and response types
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
12/// Unique document address: segment_id (hex) + local doc_id within segment
13#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
14pub struct DocAddress {
15    /// Segment ID as hex string (32 chars)
16    pub segment_id: String,
17    /// Document ID within the segment
18    pub doc_id: DocId,
19}
20
21impl DocAddress {
22    pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23        Self {
24            segment_id: format!("{:032x}", segment_id),
25            doc_id,
26        }
27    }
28
29    /// Parse segment_id from hex string
30    pub fn segment_id_u128(&self) -> Option<u128> {
31        u128::from_str_radix(&self.segment_id, 16).ok()
32    }
33}
34
35/// A scored position/ordinal within a field
36/// For text fields: position is the token position
37/// For vector fields: position is the ordinal (which vector in multi-value)
38#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
39pub struct ScoredPosition {
40    /// Position (text) or ordinal (vector)
41    pub position: u32,
42    /// Individual score contribution from this position/ordinal
43    pub score: f32,
44}
45
46impl ScoredPosition {
47    pub fn new(position: u32, score: f32) -> Self {
48        Self { position, score }
49    }
50}
51
52/// Search result with doc_id and score (internal use)
53#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
54pub struct SearchResult {
55    pub doc_id: DocId,
56    pub score: Score,
57    /// Segment ID (set by searcher after collection)
58    #[serde(default, skip_serializing_if = "is_zero_u128")]
59    pub segment_id: u128,
60    /// Matched positions per field: (field_id, scored_positions)
61    /// Each position includes its individual score contribution
62    #[serde(default, skip_serializing_if = "Vec::is_empty")]
63    pub positions: Vec<(u32, Vec<ScoredPosition>)>,
64}
65
/// serde `skip_serializing_if` predicate: true when the value is zero.
fn is_zero_u128(v: &u128) -> bool {
    v == &0
}
69
70/// Matched field info with ordinals (for multi-valued fields)
71#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
72pub struct MatchedField {
73    /// Field ID
74    pub field_id: u32,
75    /// Matched element ordinals (for multi-valued fields with position tracking)
76    /// Empty if position tracking is not enabled for this field
77    pub ordinals: Vec<u32>,
78}
79
80impl SearchResult {
81    /// Extract unique ordinals from positions for each field
82    /// For text fields: ordinal = position >> 20 (from encoded position)
83    /// For vector fields: position IS the ordinal directly
84    pub fn extract_ordinals(&self) -> Vec<MatchedField> {
85        use rustc_hash::FxHashSet;
86
87        self.positions
88            .iter()
89            .map(|(field_id, scored_positions)| {
90                let mut ordinals: FxHashSet<u32> = FxHashSet::default();
91                for sp in scored_positions {
92                    // For text fields with encoded positions, extract ordinal
93                    // For vector fields, position IS the ordinal
94                    // We use a heuristic: if position > 0xFFFFF (20 bits), it's encoded
95                    let ordinal = if sp.position > 0xFFFFF {
96                        sp.position >> 20
97                    } else {
98                        sp.position
99                    };
100                    ordinals.insert(ordinal);
101                }
102                let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
103                ordinals.sort_unstable();
104                MatchedField {
105                    field_id: *field_id,
106                    ordinals,
107                }
108            })
109            .collect()
110    }
111
112    /// Get all scored positions for a specific field
113    pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
114        self.positions
115            .iter()
116            .find(|(fid, _)| *fid == field_id)
117            .map(|(_, positions)| positions.as_slice())
118    }
119}
120
121/// Search hit with unique document address and score
122#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
123pub struct SearchHit {
124    /// Unique document address (segment_id + local doc_id)
125    pub address: DocAddress,
126    pub score: Score,
127    /// Matched fields with element ordinals (populated when position tracking is enabled)
128    #[serde(default, skip_serializing_if = "Vec::is_empty")]
129    pub matched_fields: Vec<MatchedField>,
130}
131
132/// Search response with hits (IDs only, no documents)
133#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
134pub struct SearchResponse {
135    pub hits: Vec<SearchHit>,
136    pub total_hits: u32,
137}
138
139impl PartialEq for SearchResult {
140    fn eq(&self, other: &Self) -> bool {
141        self.segment_id == other.segment_id && self.doc_id == other.doc_id
142    }
143}
144
145impl Eq for SearchResult {}
146
147impl PartialOrd for SearchResult {
148    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
149        Some(self.cmp(other))
150    }
151}
152
153impl Ord for SearchResult {
154    fn cmp(&self, other: &Self) -> Ordering {
155        other
156            .score
157            .partial_cmp(&self.score)
158            .unwrap_or(Ordering::Equal)
159            .then_with(|| self.segment_id.cmp(&other.segment_id))
160            .then_with(|| self.doc_id.cmp(&other.doc_id))
161    }
162}
163
164/// Trait for search result collectors
165///
166/// Implement this trait to create custom collectors that can be
167/// combined and passed to query execution.
168pub trait Collector {
169    /// Called for each matching document
170    /// positions: Vec of (field_id, scored_positions)
171    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);
172
173    /// Whether this collector needs position information
174    fn needs_positions(&self) -> bool {
175        false
176    }
177}
178
179/// Collector for top-k results
180pub struct TopKCollector {
181    heap: BinaryHeap<SearchResult>,
182    k: usize,
183    collect_positions: bool,
184    /// Total documents seen by this collector
185    total_seen: u32,
186}
187
188impl TopKCollector {
189    pub fn new(k: usize) -> Self {
190        Self {
191            heap: BinaryHeap::with_capacity(k + 1),
192            k,
193            collect_positions: false,
194            total_seen: 0,
195        }
196    }
197
198    /// Create a collector that also collects positions
199    pub fn with_positions(k: usize) -> Self {
200        Self {
201            heap: BinaryHeap::with_capacity(k + 1),
202            k,
203            collect_positions: true,
204            total_seen: 0,
205        }
206    }
207
208    /// Get the total number of documents seen (scored) by this collector
209    pub fn total_seen(&self) -> u32 {
210        self.total_seen
211    }
212
213    pub fn into_sorted_results(self) -> Vec<SearchResult> {
214        let mut results: Vec<_> = self.heap.into_vec();
215        results.sort_by(|a, b| {
216            b.score
217                .partial_cmp(&a.score)
218                .unwrap_or(Ordering::Equal)
219                .then_with(|| a.doc_id.cmp(&b.doc_id))
220        });
221        results
222    }
223
224    /// Consume collector and return (sorted_results, total_seen)
225    pub fn into_results_with_count(self) -> (Vec<SearchResult>, u32) {
226        let total = self.total_seen;
227        (self.into_sorted_results(), total)
228    }
229}
230
231impl Collector for TopKCollector {
232    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
233        self.total_seen += 1;
234
235        // Only clone positions when the document will actually be kept in the heap.
236        // This avoids deep-cloning Vec<ScoredPosition> for documents that are
237        // immediately discarded (the common case for large result sets).
238        let dominated =
239            self.heap.len() >= self.k && self.heap.peek().is_some_and(|min| score <= min.score);
240        if dominated {
241            return;
242        }
243
244        let positions = if self.collect_positions {
245            positions.to_vec()
246        } else {
247            Vec::new()
248        };
249
250        if self.heap.len() >= self.k {
251            self.heap.pop();
252        }
253        self.heap.push(SearchResult {
254            doc_id,
255            score,
256            segment_id: 0,
257            positions,
258        });
259    }
260
261    fn needs_positions(&self) -> bool {
262        self.collect_positions
263    }
264}
265
266/// Collector that counts all matching documents
267#[derive(Default)]
268pub struct CountCollector {
269    count: u64,
270}
271
272impl CountCollector {
273    pub fn new() -> Self {
274        Self { count: 0 }
275    }
276
277    /// Get the total count
278    pub fn count(&self) -> u64 {
279        self.count
280    }
281}
282
283impl Collector for CountCollector {
284    #[inline]
285    fn collect(
286        &mut self,
287        _doc_id: DocId,
288        _score: Score,
289        _positions: &[(u32, Vec<ScoredPosition>)],
290    ) {
291        self.count += 1;
292    }
293}
294
295/// Execute a search query on a single segment (async)
296pub async fn search_segment(
297    reader: &SegmentReader,
298    query: &dyn Query,
299    limit: usize,
300) -> Result<Vec<SearchResult>> {
301    let mut collector = TopKCollector::new(limit);
302    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
303    Ok(collector.into_sorted_results())
304}
305
306/// Execute a search query on a single segment and return (results, total_seen) (async)
307pub async fn search_segment_with_count(
308    reader: &SegmentReader,
309    query: &dyn Query,
310    limit: usize,
311) -> Result<(Vec<SearchResult>, u32)> {
312    let mut collector = TopKCollector::new(limit);
313    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
314    Ok(collector.into_results_with_count())
315}
316
317/// Execute a search query on a single segment with position collection (async)
318pub async fn search_segment_with_positions(
319    reader: &SegmentReader,
320    query: &dyn Query,
321    limit: usize,
322) -> Result<Vec<SearchResult>> {
323    let mut collector = TopKCollector::with_positions(limit);
324    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
325    Ok(collector.into_sorted_results())
326}
327
328/// Execute a search query on a single segment with positions and return (results, total_seen)
329pub async fn search_segment_with_positions_and_count(
330    reader: &SegmentReader,
331    query: &dyn Query,
332    limit: usize,
333) -> Result<(Vec<SearchResult>, u32)> {
334    let mut collector = TopKCollector::with_positions(limit);
335    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
336    Ok(collector.into_results_with_count())
337}
338
339/// Count all documents matching a query on a single segment (async)
340pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
341    let mut collector = CountCollector::new();
342    collect_segment(reader, query, &mut collector).await?;
343    Ok(collector.count())
344}
345
346// Implement Collector for tuple of 2 collectors
347impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
348    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
349        self.0.collect(doc_id, score, positions);
350        self.1.collect(doc_id, score, positions);
351    }
352    fn needs_positions(&self) -> bool {
353        self.0.needs_positions() || self.1.needs_positions()
354    }
355}
356
357// Implement Collector for tuple of 3 collectors
358impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
359    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
360        self.0.collect(doc_id, score, positions);
361        self.1.collect(doc_id, score, positions);
362        self.2.collect(doc_id, score, positions);
363    }
364    fn needs_positions(&self) -> bool {
365        self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
366    }
367}
368
369/// Execute a query with one or more collectors (async)
370///
371/// Uses a large limit for the scorer to disable MaxScore pruning.
372/// For queries that benefit from MaxScore pruning (e.g., sparse vector search),
373/// use `collect_segment_with_limit` instead.
374///
375/// # Examples
376/// ```ignore
377/// // Single collector
378/// let mut top_k = TopKCollector::new(10);
379/// collect_segment(reader, query, &mut top_k).await?;
380///
381/// // Multiple collectors (tuple)
382/// let mut top_k = TopKCollector::new(10);
383/// let mut count = CountCollector::new();
384/// collect_segment(reader, query, &mut (&mut top_k, &mut count)).await?;
385/// ```
386pub async fn collect_segment<C: Collector>(
387    reader: &SegmentReader,
388    query: &dyn Query,
389    collector: &mut C,
390) -> Result<()> {
391    // Use large limit to disable MaxScore skipping for exhaustive collection
392    collect_segment_with_limit(reader, query, collector, usize::MAX / 2).await
393}
394
395/// Execute a query with one or more collectors and a specific limit (async)
396///
397/// The limit is passed to the scorer to enable MaxScore pruning for queries
398/// that support it (e.g., sparse vector search). This significantly improves
399/// performance when only the top-k results are needed.
400///
401/// Doc IDs in the collector are segment-local. The searcher stamps each result
402/// with its segment_id, making (segment_id, doc_id) the unique document key.
403pub async fn collect_segment_with_limit<C: Collector>(
404    reader: &SegmentReader,
405    query: &dyn Query,
406    collector: &mut C,
407    limit: usize,
408) -> Result<()> {
409    let needs_positions = collector.needs_positions();
410    let mut scorer = query.scorer(reader, limit).await?;
411
412    let mut doc = scorer.doc();
413    while doc != TERMINATED {
414        if needs_positions {
415            let positions = scorer.matched_positions().unwrap_or_default();
416            collector.collect(doc, scorer.score(), &positions);
417        } else {
418            collector.collect(doc, scorer.score(), &[]);
419        }
420        doc = scorer.advance();
421    }
422
423    Ok(())
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample hits shared by the top-k and multi-collector tests.
    const HITS: [(DocId, Score); 5] = [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)];

    /// Feeds `(doc_id, score)` pairs into `collector`, without positions.
    fn feed<C: Collector>(collector: &mut C, hits: &[(DocId, Score)]) {
        for &(doc_id, score) in hits {
            collector.collect(doc_id, score, &[]);
        }
    }

    /// Doc ids of `results`, in result order.
    fn doc_ids(results: &[SearchResult]) -> Vec<DocId> {
        results.iter().map(|r| r.doc_id).collect()
    }

    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);
        feed(&mut collector, &HITS);

        // Every offered document is counted, kept or not.
        assert_eq!(collector.total_seen(), 5);
        // Highest three scores, best first: 4.0, 3.0, 2.0
        assert_eq!(doc_ids(&collector.into_sorted_results()), vec![3, 1, 2]);
    }

    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();
        feed(&mut collector, &[(0, 1.0), (1, 2.0), (2, 3.0)]);

        assert_eq!(collector.count(), 3);
    }

    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();

        // Drive both through the tuple `Collector` impl, as `collect_segment`
        // does when given `&mut (&mut top_k, &mut count)`.
        let mut both = (&mut top_k, &mut count);
        assert!(!both.needs_positions());
        feed(&mut both, &HITS);

        // Count sees all 5 documents
        assert_eq!(count.count(), 5);
        // TopK keeps only the best two: 4.0, 3.0
        assert_eq!(doc_ids(&top_k.into_sorted_results()), vec![3, 1]);
    }

    #[test]
    fn test_multi_collector_needs_positions() {
        let mut plain = TopKCollector::new(1);
        let mut with_pos = TopKCollector::with_positions(1);
        let mut count = CountCollector::new();

        assert!(!(&mut plain, &mut count).needs_positions());
        assert!((&mut plain, &mut with_pos).needs_positions());
        assert!((&mut plain, &mut count, &mut with_pos).needs_positions());
    }
}