Skip to main content

hermes_core/query/
collector.rs

1//! Search result collection and response types
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
12/// Unique document address: segment_id (hex) + local doc_id within segment
13#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
14pub struct DocAddress {
15    /// Segment ID as hex string (32 chars)
16    pub segment_id: String,
17    /// Document ID within the segment
18    pub doc_id: DocId,
19}
20
21impl DocAddress {
22    pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23        Self {
24            segment_id: format!("{:032x}", segment_id),
25            doc_id,
26        }
27    }
28
29    /// Parse segment_id from hex string
30    pub fn segment_id_u128(&self) -> Option<u128> {
31        u128::from_str_radix(&self.segment_id, 16).ok()
32    }
33}
34
35/// A scored position/ordinal within a field
36/// For text fields: position is the token position
37/// For vector fields: position is the ordinal (which vector in multi-value)
38#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
39pub struct ScoredPosition {
40    /// Position (text) or ordinal (vector)
41    pub position: u32,
42    /// Individual score contribution from this position/ordinal
43    pub score: f32,
44}
45
46impl ScoredPosition {
47    pub fn new(position: u32, score: f32) -> Self {
48        Self { position, score }
49    }
50}
51
52/// Search result with doc_id and score (internal use)
53#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
54pub struct SearchResult {
55    pub doc_id: DocId,
56    pub score: Score,
57    /// Matched positions per field: (field_id, scored_positions)
58    /// Each position includes its individual score contribution
59    #[serde(default, skip_serializing_if = "Vec::is_empty")]
60    pub positions: Vec<(u32, Vec<ScoredPosition>)>,
61}
62
63/// Matched field info with ordinals (for multi-valued fields)
64#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
65pub struct MatchedField {
66    /// Field ID
67    pub field_id: u32,
68    /// Matched element ordinals (for multi-valued fields with position tracking)
69    /// Empty if position tracking is not enabled for this field
70    pub ordinals: Vec<u32>,
71}
72
73impl SearchResult {
74    /// Extract unique ordinals from positions for each field
75    /// For text fields: ordinal = position >> 20 (from encoded position)
76    /// For vector fields: position IS the ordinal directly
77    pub fn extract_ordinals(&self) -> Vec<MatchedField> {
78        use rustc_hash::FxHashSet;
79
80        self.positions
81            .iter()
82            .map(|(field_id, scored_positions)| {
83                let mut ordinals: FxHashSet<u32> = FxHashSet::default();
84                for sp in scored_positions {
85                    // For text fields with encoded positions, extract ordinal
86                    // For vector fields, position IS the ordinal
87                    // We use a heuristic: if position > 0xFFFFF (20 bits), it's encoded
88                    let ordinal = if sp.position > 0xFFFFF {
89                        sp.position >> 20
90                    } else {
91                        sp.position
92                    };
93                    ordinals.insert(ordinal);
94                }
95                let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
96                ordinals.sort_unstable();
97                MatchedField {
98                    field_id: *field_id,
99                    ordinals,
100                }
101            })
102            .collect()
103    }
104
105    /// Get all scored positions for a specific field
106    pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
107        self.positions
108            .iter()
109            .find(|(fid, _)| *fid == field_id)
110            .map(|(_, positions)| positions.as_slice())
111    }
112}
113
114/// Search hit with unique document address and score
115#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
116pub struct SearchHit {
117    /// Unique document address (segment_id + local doc_id)
118    pub address: DocAddress,
119    pub score: Score,
120    /// Matched fields with element ordinals (populated when position tracking is enabled)
121    #[serde(default, skip_serializing_if = "Vec::is_empty")]
122    pub matched_fields: Vec<MatchedField>,
123}
124
125/// Search response with hits (IDs only, no documents)
126#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
127pub struct SearchResponse {
128    pub hits: Vec<SearchHit>,
129    pub total_hits: u32,
130}
131
132impl PartialEq for SearchResult {
133    fn eq(&self, other: &Self) -> bool {
134        self.doc_id == other.doc_id
135    }
136}
137
138impl Eq for SearchResult {}
139
140impl PartialOrd for SearchResult {
141    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
142        Some(self.cmp(other))
143    }
144}
145
146impl Ord for SearchResult {
147    fn cmp(&self, other: &Self) -> Ordering {
148        other
149            .score
150            .partial_cmp(&self.score)
151            .unwrap_or(Ordering::Equal)
152            .then_with(|| self.doc_id.cmp(&other.doc_id))
153    }
154}
155
156/// Trait for search result collectors
157///
158/// Implement this trait to create custom collectors that can be
159/// combined and passed to query execution.
160pub trait Collector {
161    /// Called for each matching document
162    /// positions: Vec of (field_id, scored_positions)
163    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);
164
165    /// Whether this collector needs position information
166    fn needs_positions(&self) -> bool {
167        false
168    }
169}
170
171/// Collector for top-k results
172pub struct TopKCollector {
173    heap: BinaryHeap<SearchResult>,
174    k: usize,
175    collect_positions: bool,
176    /// Total documents seen by this collector
177    total_seen: u32,
178}
179
180impl TopKCollector {
181    pub fn new(k: usize) -> Self {
182        Self {
183            heap: BinaryHeap::with_capacity(k + 1),
184            k,
185            collect_positions: false,
186            total_seen: 0,
187        }
188    }
189
190    /// Create a collector that also collects positions
191    pub fn with_positions(k: usize) -> Self {
192        Self {
193            heap: BinaryHeap::with_capacity(k + 1),
194            k,
195            collect_positions: true,
196            total_seen: 0,
197        }
198    }
199
200    /// Get the total number of documents seen (scored) by this collector
201    pub fn total_seen(&self) -> u32 {
202        self.total_seen
203    }
204
205    pub fn into_sorted_results(self) -> Vec<SearchResult> {
206        let mut results: Vec<_> = self.heap.into_vec();
207        results.sort_by(|a, b| {
208            b.score
209                .partial_cmp(&a.score)
210                .unwrap_or(Ordering::Equal)
211                .then_with(|| a.doc_id.cmp(&b.doc_id))
212        });
213        results
214    }
215
216    /// Consume collector and return (sorted_results, total_seen)
217    pub fn into_results_with_count(self) -> (Vec<SearchResult>, u32) {
218        let total = self.total_seen;
219        (self.into_sorted_results(), total)
220    }
221}
222
223impl Collector for TopKCollector {
224    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
225        self.total_seen += 1;
226
227        let positions = if self.collect_positions {
228            positions.to_vec()
229        } else {
230            Vec::new()
231        };
232
233        if self.heap.len() < self.k {
234            self.heap.push(SearchResult {
235                doc_id,
236                score,
237                positions,
238            });
239        } else if let Some(min) = self.heap.peek()
240            && score > min.score
241        {
242            self.heap.pop();
243            self.heap.push(SearchResult {
244                doc_id,
245                score,
246                positions,
247            });
248        }
249    }
250
251    fn needs_positions(&self) -> bool {
252        self.collect_positions
253    }
254}
255
/// Collector that only counts the documents it is given.
#[derive(Default)]
pub struct CountCollector {
    /// Number of documents collected so far.
    count: u64,
}
261
262impl CountCollector {
263    pub fn new() -> Self {
264        Self { count: 0 }
265    }
266
267    /// Get the total count
268    pub fn count(&self) -> u64 {
269        self.count
270    }
271}
272
273impl Collector for CountCollector {
274    #[inline]
275    fn collect(
276        &mut self,
277        _doc_id: DocId,
278        _score: Score,
279        _positions: &[(u32, Vec<ScoredPosition>)],
280    ) {
281        self.count += 1;
282    }
283}
284
285/// Execute a search query on a single segment (async)
286pub async fn search_segment(
287    reader: &SegmentReader,
288    query: &dyn Query,
289    limit: usize,
290) -> Result<Vec<SearchResult>> {
291    let mut collector = TopKCollector::new(limit);
292    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
293    Ok(collector.into_sorted_results())
294}
295
296/// Execute a search query on a single segment and return (results, total_seen) (async)
297pub async fn search_segment_with_count(
298    reader: &SegmentReader,
299    query: &dyn Query,
300    limit: usize,
301) -> Result<(Vec<SearchResult>, u32)> {
302    let mut collector = TopKCollector::new(limit);
303    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
304    Ok(collector.into_results_with_count())
305}
306
307/// Execute a search query on a single segment with position collection (async)
308pub async fn search_segment_with_positions(
309    reader: &SegmentReader,
310    query: &dyn Query,
311    limit: usize,
312) -> Result<Vec<SearchResult>> {
313    let mut collector = TopKCollector::with_positions(limit);
314    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
315    Ok(collector.into_sorted_results())
316}
317
318/// Execute a search query on a single segment with positions and return (results, total_seen)
319pub async fn search_segment_with_positions_and_count(
320    reader: &SegmentReader,
321    query: &dyn Query,
322    limit: usize,
323) -> Result<(Vec<SearchResult>, u32)> {
324    let mut collector = TopKCollector::with_positions(limit);
325    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
326    Ok(collector.into_results_with_count())
327}
328
329/// Count all documents matching a query on a single segment (async)
330pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
331    let mut collector = CountCollector::new();
332    collect_segment(reader, query, &mut collector).await?;
333    Ok(collector.count())
334}
335
336// Implement Collector for tuple of 2 collectors
337impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
338    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
339        self.0.collect(doc_id, score, positions);
340        self.1.collect(doc_id, score, positions);
341    }
342    fn needs_positions(&self) -> bool {
343        self.0.needs_positions() || self.1.needs_positions()
344    }
345}
346
347// Implement Collector for tuple of 3 collectors
348impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
349    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
350        self.0.collect(doc_id, score, positions);
351        self.1.collect(doc_id, score, positions);
352        self.2.collect(doc_id, score, positions);
353    }
354    fn needs_positions(&self) -> bool {
355        self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
356    }
357}
358
359/// Execute a query with one or more collectors (async)
360///
361/// Uses a large limit for the scorer to disable WAND pruning.
362/// For queries that benefit from WAND pruning (e.g., sparse vector search),
363/// use `collect_segment_with_limit` instead.
364///
365/// # Examples
366/// ```ignore
367/// // Single collector
368/// let mut top_k = TopKCollector::new(10);
369/// collect_segment(reader, query, &mut top_k).await?;
370///
371/// // Multiple collectors (tuple)
372/// let mut top_k = TopKCollector::new(10);
373/// let mut count = CountCollector::new();
374/// collect_segment(reader, query, &mut (&mut top_k, &mut count)).await?;
375/// ```
376pub async fn collect_segment<C: Collector>(
377    reader: &SegmentReader,
378    query: &dyn Query,
379    collector: &mut C,
380) -> Result<()> {
381    // Use large limit to disable WAND skipping for exhaustive collection
382    collect_segment_with_limit(reader, query, collector, usize::MAX / 2).await
383}
384
385/// Execute a query with one or more collectors and a specific limit (async)
386///
387/// The limit is passed to the scorer to enable WAND pruning for queries
388/// that support it (e.g., sparse vector search). This significantly improves
389/// performance when only the top-k results are needed.
390pub async fn collect_segment_with_limit<C: Collector>(
391    reader: &SegmentReader,
392    query: &dyn Query,
393    collector: &mut C,
394    limit: usize,
395) -> Result<()> {
396    let needs_positions = collector.needs_positions();
397    let mut scorer = query.scorer(reader, limit).await?;
398
399    let mut doc = scorer.doc();
400    while doc != TERMINATED {
401        let positions = if needs_positions {
402            scorer.matched_positions().unwrap_or_default()
403        } else {
404            Vec::new()
405        };
406        collector.collect(doc, scorer.score(), &positions);
407        doc = scorer.advance();
408    }
409
410    Ok(())
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Top-k keeps the best `k` documents, best first, and counts every document offered.
    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);

        // Documents that miss the top-k are still counted.
        assert_eq!(collector.total_seen(), 5);

        let results = collector.into_sorted_results();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3); // score 4.0
        assert_eq!(results[1].doc_id, 1); // score 3.0
        assert_eq!(results[2].doc_id, 2); // score 2.0
    }

    /// Positions are stored only by collectors created with `with_positions`.
    #[test]
    fn test_top_k_collector_positions() {
        let positions = vec![(7, vec![ScoredPosition::new(2, 0.5)])];

        let mut keeping = TopKCollector::with_positions(1);
        let mut dropping = TopKCollector::new(1);
        assert!(keeping.needs_positions());
        assert!(!dropping.needs_positions());

        keeping.collect(0, 1.0, &positions);
        dropping.collect(0, 1.0, &positions);

        assert_eq!(keeping.into_sorted_results()[0].positions.len(), 1);
        assert!(dropping.into_sorted_results()[0].positions.is_empty());
    }

    /// The count collector counts every document it is given.
    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);

        assert_eq!(collector.count(), 3);
    }

    /// A tuple of collectors forwards each document to all of its members.
    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();

        {
            // Drive both collectors through the tuple `Collector` impl, the same
            // way `collect_segment` does when it is handed a tuple.
            let mut both = (&mut top_k, &mut count);
            assert!(!both.needs_positions());
            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                both.collect(doc_id, score, &[]);
            }
        }

        // Count should have all 5 documents
        assert_eq!(count.count(), 5);

        // TopK should only have top 2 results
        let results = top_k.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3); // score 4.0
        assert_eq!(results[1].doc_id, 1); // score 3.0
    }
}