Skip to main content

hermes_core/query/
collector.rs

1//! Search result collection and response types
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
12/// Unique document address: segment_id (hex) + local doc_id within segment
13#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
14pub struct DocAddress {
15    /// Segment ID as hex string (32 chars)
16    pub segment_id: String,
17    /// Document ID within the segment
18    pub doc_id: DocId,
19}
20
21impl DocAddress {
22    pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23        Self {
24            segment_id: format!("{:032x}", segment_id),
25            doc_id,
26        }
27    }
28
29    /// Parse segment_id from hex string
30    pub fn segment_id_u128(&self) -> Option<u128> {
31        u128::from_str_radix(&self.segment_id, 16).ok()
32    }
33}
34
35/// Search result with doc_id and score (internal use)
36#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
37pub struct SearchResult {
38    pub doc_id: DocId,
39    pub score: Score,
40    /// Matched positions per field: (field_id, encoded_positions)
41    #[serde(default, skip_serializing_if = "Vec::is_empty")]
42    pub positions: Vec<(u32, Vec<u32>)>,
43}
44
45/// Matched field info with ordinals (for multi-valued fields)
46#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
47pub struct MatchedField {
48    /// Field ID
49    pub field_id: u32,
50    /// Matched element ordinals (for multi-valued fields with position tracking)
51    /// Empty if position tracking is not enabled for this field
52    pub ordinals: Vec<u32>,
53}
54
55impl SearchResult {
56    /// Extract unique ordinals from positions for each field
57    /// Uses the position encoding: ordinal = position >> 20
58    pub fn extract_ordinals(&self) -> Vec<MatchedField> {
59        use rustc_hash::FxHashSet;
60
61        self.positions
62            .iter()
63            .map(|(field_id, positions)| {
64                let mut ordinals: FxHashSet<u32> = FxHashSet::default();
65                for &pos in positions {
66                    let ordinal = pos >> 20; // Extract ordinal from encoded position
67                    ordinals.insert(ordinal);
68                }
69                let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
70                ordinals.sort_unstable();
71                MatchedField {
72                    field_id: *field_id,
73                    ordinals,
74                }
75            })
76            .collect()
77    }
78}
79
80/// Search hit with unique document address and score
81#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
82pub struct SearchHit {
83    /// Unique document address (segment_id + local doc_id)
84    pub address: DocAddress,
85    pub score: Score,
86    /// Matched fields with element ordinals (populated when position tracking is enabled)
87    #[serde(default, skip_serializing_if = "Vec::is_empty")]
88    pub matched_fields: Vec<MatchedField>,
89}
90
91/// Search response with hits (IDs only, no documents)
92#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
93pub struct SearchResponse {
94    pub hits: Vec<SearchHit>,
95    pub total_hits: u32,
96}
97
98impl PartialEq for SearchResult {
99    fn eq(&self, other: &Self) -> bool {
100        self.doc_id == other.doc_id
101    }
102}
103
104impl Eq for SearchResult {}
105
106impl PartialOrd for SearchResult {
107    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108        Some(self.cmp(other))
109    }
110}
111
112impl Ord for SearchResult {
113    fn cmp(&self, other: &Self) -> Ordering {
114        other
115            .score
116            .partial_cmp(&self.score)
117            .unwrap_or(Ordering::Equal)
118            .then_with(|| self.doc_id.cmp(&other.doc_id))
119    }
120}
121
122/// Trait for search result collectors
123///
124/// Implement this trait to create custom collectors that can be
125/// combined and passed to query execution.
126pub trait Collector {
127    /// Called for each matching document
128    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]);
129
130    /// Whether this collector needs position information
131    fn needs_positions(&self) -> bool {
132        false
133    }
134}
135
136/// Collector for top-k results
137pub struct TopKCollector {
138    heap: BinaryHeap<SearchResult>,
139    k: usize,
140    collect_positions: bool,
141}
142
143impl TopKCollector {
144    pub fn new(k: usize) -> Self {
145        Self {
146            heap: BinaryHeap::with_capacity(k + 1),
147            k,
148            collect_positions: false,
149        }
150    }
151
152    /// Create a collector that also collects positions
153    pub fn with_positions(k: usize) -> Self {
154        Self {
155            heap: BinaryHeap::with_capacity(k + 1),
156            k,
157            collect_positions: true,
158        }
159    }
160
161    pub fn into_sorted_results(self) -> Vec<SearchResult> {
162        let mut results: Vec<_> = self.heap.into_vec();
163        results.sort_by(|a, b| {
164            b.score
165                .partial_cmp(&a.score)
166                .unwrap_or(Ordering::Equal)
167                .then_with(|| a.doc_id.cmp(&b.doc_id))
168        });
169        results
170    }
171}
172
173impl Collector for TopKCollector {
174    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]) {
175        let positions = if self.collect_positions {
176            positions.to_vec()
177        } else {
178            Vec::new()
179        };
180
181        if self.heap.len() < self.k {
182            self.heap.push(SearchResult {
183                doc_id,
184                score,
185                positions,
186            });
187        } else if let Some(min) = self.heap.peek()
188            && score > min.score
189        {
190            self.heap.pop();
191            self.heap.push(SearchResult {
192                doc_id,
193                score,
194                positions,
195            });
196        }
197    }
198
199    fn needs_positions(&self) -> bool {
200        self.collect_positions
201    }
202}
203
/// Collector that tallies matching documents without retaining any of them.
#[derive(Default)]
pub struct CountCollector {
    /// Number of `collect` calls observed so far.
    count: u64,
}
209
210impl CountCollector {
211    pub fn new() -> Self {
212        Self { count: 0 }
213    }
214
215    /// Get the total count
216    pub fn count(&self) -> u64 {
217        self.count
218    }
219}
220
221impl Collector for CountCollector {
222    #[inline]
223    fn collect(&mut self, _doc_id: DocId, _score: Score, _positions: &[(u32, Vec<u32>)]) {
224        self.count += 1;
225    }
226}
227
228/// Execute a search query on a single segment (async)
229pub async fn search_segment(
230    reader: &SegmentReader,
231    query: &dyn Query,
232    limit: usize,
233) -> Result<Vec<SearchResult>> {
234    let mut collector = TopKCollector::new(limit);
235    collect_segment(reader, query, &mut collector).await?;
236    Ok(collector.into_sorted_results())
237}
238
239/// Execute a search query on a single segment with position collection (async)
240pub async fn search_segment_with_positions(
241    reader: &SegmentReader,
242    query: &dyn Query,
243    limit: usize,
244) -> Result<Vec<SearchResult>> {
245    let mut collector = TopKCollector::with_positions(limit);
246    collect_segment(reader, query, &mut collector).await?;
247    Ok(collector.into_sorted_results())
248}
249
250/// Count all documents matching a query on a single segment (async)
251pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
252    let mut collector = CountCollector::new();
253    collect_segment(reader, query, &mut collector).await?;
254    Ok(collector.count())
255}
256
257// Implement Collector for tuple of 2 collectors
258impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
259    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]) {
260        self.0.collect(doc_id, score, positions);
261        self.1.collect(doc_id, score, positions);
262    }
263    fn needs_positions(&self) -> bool {
264        self.0.needs_positions() || self.1.needs_positions()
265    }
266}
267
268// Implement Collector for tuple of 3 collectors
269impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
270    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]) {
271        self.0.collect(doc_id, score, positions);
272        self.1.collect(doc_id, score, positions);
273        self.2.collect(doc_id, score, positions);
274    }
275    fn needs_positions(&self) -> bool {
276        self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
277    }
278}
279
280/// Execute a query with one or more collectors (async)
281///
282/// # Examples
283/// ```ignore
284/// // Single collector
285/// let mut top_k = TopKCollector::new(10);
286/// collect_segment(reader, query, &mut top_k).await?;
287///
288/// // Multiple collectors (tuple)
289/// let mut top_k = TopKCollector::new(10);
290/// let mut count = CountCollector::new();
291/// collect_segment(reader, query, &mut (&mut top_k, &mut count)).await?;
292/// ```
293pub async fn collect_segment<C: Collector>(
294    reader: &SegmentReader,
295    query: &dyn Query,
296    collector: &mut C,
297) -> Result<()> {
298    let needs_positions = collector.needs_positions();
299    // Use large limit to disable WAND skipping, but not usize::MAX to avoid overflow
300    let mut scorer = query.scorer(reader, usize::MAX / 2).await?;
301
302    let mut doc = scorer.doc();
303    while doc != TERMINATED {
304        let positions = if needs_positions {
305            scorer.matched_positions().unwrap_or_default()
306        } else {
307            Vec::new()
308        };
309        collector.collect(doc, scorer.score(), &positions);
310        doc = scorer.advance();
311    }
312
313    Ok(())
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    /// Doc ids of `results`, preserving their order.
    fn doc_ids(results: &[SearchResult]) -> Vec<DocId> {
        results.iter().map(|r| r.doc_id).collect()
    }

    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);
        for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
            collector.collect(doc_id, score, &[]);
        }

        // Scores 4.0, 3.0 and 2.0 survive, best first.
        assert_eq!(doc_ids(&collector.into_sorted_results()), vec![3, 1, 2]);
    }

    #[test]
    fn test_top_k_ties_keep_earliest_docs() {
        // A later hit only displaces a kept one when its score is strictly higher.
        let mut collector = TopKCollector::new(2);
        for doc_id in 0..4 {
            collector.collect(doc_id, 2.0, &[]);
        }
        assert_eq!(doc_ids(&collector.into_sorted_results()), vec![0, 1]);
    }

    #[test]
    fn test_top_k_with_zero_k_keeps_nothing() {
        let mut collector = TopKCollector::new(0);
        collector.collect(0, 1.0, &[]);
        assert!(collector.into_sorted_results().is_empty());
    }

    #[test]
    fn test_top_k_positions_only_kept_when_requested() {
        let positions: Vec<(u32, Vec<u32>)> = vec![(1, vec![5, 6])];

        let mut plain = TopKCollector::new(1);
        assert!(!plain.needs_positions());
        plain.collect(0, 1.0, &positions);
        assert!(plain.into_sorted_results()[0].positions.is_empty());

        let mut tracked = TopKCollector::with_positions(1);
        assert!(tracked.needs_positions());
        tracked.collect(0, 1.0, &positions);
        assert_eq!(tracked.into_sorted_results()[0].positions, positions);
    }

    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();
        for (doc_id, score) in [(0, 1.0), (1, 2.0), (2, 3.0)] {
            collector.collect(doc_id, score, &[]);
        }
        assert_eq!(collector.count(), 3);
    }

    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();
        {
            // Same shape callers pass to `collect_segment`: a tuple of `&mut` collectors.
            let mut both = (&mut top_k, &mut count);
            assert!(!both.needs_positions());
            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                both.collect(doc_id, score, &[]);
            }
        }

        // The counter sees every document; the top-k keeps only the best two.
        assert_eq!(count.count(), 5);
        assert_eq!(doc_ids(&top_k.into_sorted_results()), vec![3, 1]);
    }

    #[test]
    fn test_triple_collector_needs_positions_if_any_member_does() {
        let mut plain = TopKCollector::new(1);
        let mut tracked = TopKCollector::with_positions(1);
        let mut count = CountCollector::new();
        {
            let mut all = (&mut plain, &mut tracked, &mut count);
            assert!(all.needs_positions());
            all.collect(7, 1.0, &[]);
        }
        assert_eq!(count.count(), 1);
        assert_eq!(doc_ids(&plain.into_sorted_results()), vec![7]);
        assert_eq!(doc_ids(&tracked.into_sorted_results()), vec![7]);
    }

    #[test]
    fn test_extract_ordinals() {
        let result = SearchResult {
            doc_id: 0,
            score: 1.0,
            positions: vec![(3, vec![(2 << 20) | 5, 1, 2 << 20, 1 << 20]), (4, Vec::new())],
        };
        let fields = result.extract_ordinals();
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].field_id, 3);
        assert_eq!(fields[0].ordinals, vec![0, 1, 2]);
        assert_eq!(fields[1].field_id, 4);
        assert!(fields[1].ordinals.is_empty());
    }

    #[test]
    fn test_doc_address_hex_roundtrip() {
        let addr = DocAddress::new(0xabc, 7);
        assert_eq!(addr.segment_id, format!("{}abc", "0".repeat(29)));
        assert_eq!(addr.doc_id, 7);
        assert_eq!(addr.segment_id_u128(), Some(0xabc));
        assert_eq!(DocAddress::new(u128::MAX, 0).segment_id_u128(), Some(u128::MAX));
    }
}
369}