Skip to main content

hermes_core/query/
collector.rs

1//! Search result collection and response types
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
/// Unique document address: segment_id (hex) + local doc_id within segment
///
/// The segment id is stored as a `String`. [`DocAddress::new`] renders it as 32 lowercase,
/// zero-padded hex digits, and [`DocAddress::segment_id_u128`] parses it back.
/// `PartialEq`/`Hash` are derived, so `segment_id` is compared byte-for-byte: the same
/// numeric id spelled with different letter case or padding compares unequal.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct DocAddress {
    /// Segment ID as hex string (32 chars)
    pub segment_id: String,
    /// Document ID within the segment
    pub doc_id: DocId,
}
20
21impl DocAddress {
22    pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23        Self {
24            segment_id: format!("{:032x}", segment_id),
25            doc_id,
26        }
27    }
28
29    /// Parse segment_id from hex string
30    pub fn segment_id_u128(&self) -> Option<u128> {
31        u128::from_str_radix(&self.segment_id, 16).ok()
32    }
33}
34
/// A scored position/ordinal within a field
/// For text fields: position is the token position
/// For vector fields: position is the ordinal (which vector in multi-value)
///
/// NOTE(review): `SearchResult::extract_ordinals` treats text values above `0xFFFFF` as
/// encoded (element ordinal in the bits above bit 20) — confirm that this is the encoding
/// the scorers emit.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct ScoredPosition {
    /// Position (text) or ordinal (vector)
    pub position: u32,
    /// Individual score contribution from this position/ordinal
    pub score: f32,
}
45
46impl ScoredPosition {
47    pub fn new(position: u32, score: f32) -> Self {
48        Self { position, score }
49    }
50}
51
/// Search result with doc_id and score (internal use)
///
/// Equality and ordering are implemented by hand rather than derived. Equality looks only
/// at `doc_id`. Ordering ranks higher scores first and breaks ties by ascending `doc_id`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResult {
    /// Document id as yielded by the segment's scorer (local to that segment).
    pub doc_id: DocId,
    /// Score reported by the query scorer for this document.
    pub score: Score,
    /// Matched positions per field: (field_id, scored_positions)
    /// Each position includes its individual score contribution
    /// Left empty unless positions were collected (e.g. `TopKCollector::with_positions`);
    /// omitted from serialized output when empty and defaulted to empty when absent.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub positions: Vec<(u32, Vec<ScoredPosition>)>,
}
62
/// Matched field info with ordinals (for multi-valued fields)
///
/// `SearchResult::extract_ordinals` builds these with `ordinals` sorted ascending and
/// free of duplicates.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct MatchedField {
    /// Field ID
    pub field_id: u32,
    /// Matched element ordinals (for multi-valued fields with position tracking)
    /// Empty if position tracking is not enabled for this field
    pub ordinals: Vec<u32>,
}
72
73impl SearchResult {
74    /// Extract unique ordinals from positions for each field
75    /// For text fields: ordinal = position >> 20 (from encoded position)
76    /// For vector fields: position IS the ordinal directly
77    pub fn extract_ordinals(&self) -> Vec<MatchedField> {
78        use rustc_hash::FxHashSet;
79
80        self.positions
81            .iter()
82            .map(|(field_id, scored_positions)| {
83                let mut ordinals: FxHashSet<u32> = FxHashSet::default();
84                for sp in scored_positions {
85                    // For text fields with encoded positions, extract ordinal
86                    // For vector fields, position IS the ordinal
87                    // We use a heuristic: if position > 0xFFFFF (20 bits), it's encoded
88                    let ordinal = if sp.position > 0xFFFFF {
89                        sp.position >> 20
90                    } else {
91                        sp.position
92                    };
93                    ordinals.insert(ordinal);
94                }
95                let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
96                ordinals.sort_unstable();
97                MatchedField {
98                    field_id: *field_id,
99                    ordinals,
100                }
101            })
102            .collect()
103    }
104
105    /// Get all scored positions for a specific field
106    pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
107        self.positions
108            .iter()
109            .find(|(fid, _)| *fid == field_id)
110            .map(|(_, positions)| positions.as_slice())
111    }
112}
113
/// Search hit with unique document address and score
///
/// A hit identifies its document by [`DocAddress`] (segment id + local doc id). By
/// contrast, `SearchResult` carries only a segment-local `doc_id`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchHit {
    /// Unique document address (segment_id + local doc_id)
    pub address: DocAddress,
    /// Score of this hit (presumably carried over from the originating `SearchResult` —
    /// confirm at the construction site).
    pub score: Score,
    /// Matched fields with element ordinals (populated when position tracking is enabled)
    /// Omitted from serialized output when empty; defaults to empty when absent on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub matched_fields: Vec<MatchedField>,
}
124
/// Search response with hits (IDs only, no documents)
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResponse {
    /// Returned hits; each carries an address, a score and optional matched fields.
    pub hits: Vec<SearchHit>,
    /// Total number of matching documents.
    ///
    /// NOTE(review): presumably this can exceed `hits.len()` when hits are limited. Also,
    /// `CountCollector::count` returns `u64`, so any conversion into this `u32` should be a
    /// checked one — confirm at the call site.
    pub total_hits: u32,
}
131
132impl PartialEq for SearchResult {
133    fn eq(&self, other: &Self) -> bool {
134        self.doc_id == other.doc_id
135    }
136}
137
// Equality (the manual `PartialEq`) compares only `doc_id`, which is a full equivalence
// relation, so the `Eq` marker is sound; `score` does not participate in equality.
impl Eq for SearchResult {}
139
140impl PartialOrd for SearchResult {
141    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
142        Some(self.cmp(other))
143    }
144}
145
146impl Ord for SearchResult {
147    fn cmp(&self, other: &Self) -> Ordering {
148        other
149            .score
150            .partial_cmp(&self.score)
151            .unwrap_or(Ordering::Equal)
152            .then_with(|| self.doc_id.cmp(&other.doc_id))
153    }
154}
155
156/// Trait for search result collectors
157///
158/// Implement this trait to create custom collectors that can be
159/// combined and passed to query execution.
160pub trait Collector {
161    /// Called for each matching document
162    /// positions: Vec of (field_id, scored_positions)
163    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);
164
165    /// Whether this collector needs position information
166    fn needs_positions(&self) -> bool {
167        false
168    }
169}
170
171/// Collector for top-k results
172pub struct TopKCollector {
173    heap: BinaryHeap<SearchResult>,
174    k: usize,
175    collect_positions: bool,
176}
177
178impl TopKCollector {
179    pub fn new(k: usize) -> Self {
180        Self {
181            heap: BinaryHeap::with_capacity(k + 1),
182            k,
183            collect_positions: false,
184        }
185    }
186
187    /// Create a collector that also collects positions
188    pub fn with_positions(k: usize) -> Self {
189        Self {
190            heap: BinaryHeap::with_capacity(k + 1),
191            k,
192            collect_positions: true,
193        }
194    }
195
196    pub fn into_sorted_results(self) -> Vec<SearchResult> {
197        let mut results: Vec<_> = self.heap.into_vec();
198        results.sort_by(|a, b| {
199            b.score
200                .partial_cmp(&a.score)
201                .unwrap_or(Ordering::Equal)
202                .then_with(|| a.doc_id.cmp(&b.doc_id))
203        });
204        results
205    }
206}
207
208impl Collector for TopKCollector {
209    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
210        let positions = if self.collect_positions {
211            positions.to_vec()
212        } else {
213            Vec::new()
214        };
215
216        if self.heap.len() < self.k {
217            self.heap.push(SearchResult {
218                doc_id,
219                score,
220                positions,
221            });
222        } else if let Some(min) = self.heap.peek()
223            && score > min.score
224        {
225            self.heap.pop();
226            self.heap.push(SearchResult {
227                doc_id,
228                score,
229                positions,
230            });
231        }
232    }
233
234    fn needs_positions(&self) -> bool {
235        self.collect_positions
236    }
237}
238
/// Collector that counts all matching documents
///
/// Scores and positions are ignored; only the number of `collect` calls is tracked.
#[derive(Debug, Default)]
pub struct CountCollector {
    /// Number of documents passed to `collect` so far.
    count: u64,
}
244
245impl CountCollector {
246    pub fn new() -> Self {
247        Self { count: 0 }
248    }
249
250    /// Get the total count
251    pub fn count(&self) -> u64 {
252        self.count
253    }
254}
255
256impl Collector for CountCollector {
257    #[inline]
258    fn collect(
259        &mut self,
260        _doc_id: DocId,
261        _score: Score,
262        _positions: &[(u32, Vec<ScoredPosition>)],
263    ) {
264        self.count += 1;
265    }
266}
267
268/// Execute a search query on a single segment (async)
269pub async fn search_segment(
270    reader: &SegmentReader,
271    query: &dyn Query,
272    limit: usize,
273) -> Result<Vec<SearchResult>> {
274    let mut collector = TopKCollector::new(limit);
275    collect_segment(reader, query, &mut collector).await?;
276    Ok(collector.into_sorted_results())
277}
278
279/// Execute a search query on a single segment with position collection (async)
280pub async fn search_segment_with_positions(
281    reader: &SegmentReader,
282    query: &dyn Query,
283    limit: usize,
284) -> Result<Vec<SearchResult>> {
285    let mut collector = TopKCollector::with_positions(limit);
286    collect_segment(reader, query, &mut collector).await?;
287    Ok(collector.into_sorted_results())
288}
289
290/// Count all documents matching a query on a single segment (async)
291pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
292    let mut collector = CountCollector::new();
293    collect_segment(reader, query, &mut collector).await?;
294    Ok(collector.count())
295}
296
297// Implement Collector for tuple of 2 collectors
298impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
299    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
300        self.0.collect(doc_id, score, positions);
301        self.1.collect(doc_id, score, positions);
302    }
303    fn needs_positions(&self) -> bool {
304        self.0.needs_positions() || self.1.needs_positions()
305    }
306}
307
308// Implement Collector for tuple of 3 collectors
309impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
310    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
311        self.0.collect(doc_id, score, positions);
312        self.1.collect(doc_id, score, positions);
313        self.2.collect(doc_id, score, positions);
314    }
315    fn needs_positions(&self) -> bool {
316        self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
317    }
318}
319
320/// Execute a query with one or more collectors (async)
321///
322/// # Examples
323/// ```ignore
324/// // Single collector
325/// let mut top_k = TopKCollector::new(10);
326/// collect_segment(reader, query, &mut top_k).await?;
327///
328/// // Multiple collectors (tuple)
329/// let mut top_k = TopKCollector::new(10);
330/// let mut count = CountCollector::new();
331/// collect_segment(reader, query, &mut (&mut top_k, &mut count)).await?;
332/// ```
333pub async fn collect_segment<C: Collector>(
334    reader: &SegmentReader,
335    query: &dyn Query,
336    collector: &mut C,
337) -> Result<()> {
338    let needs_positions = collector.needs_positions();
339    // Use large limit to disable WAND skipping, but not usize::MAX to avoid overflow
340    let mut scorer = query.scorer(reader, usize::MAX / 2).await?;
341
342    let mut doc = scorer.doc();
343    while doc != TERMINATED {
344        let positions = if needs_positions {
345            scorer.matched_positions().unwrap_or_default()
346        } else {
347            Vec::new()
348        };
349        collector.collect(doc, scorer.score(), &positions);
350        doc = scorer.advance();
351    }
352
353    Ok(())
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Doc ids of `results`, in the order given.
    fn doc_ids(results: &[SearchResult]) -> Vec<DocId> {
        results.iter().map(|r| r.doc_id).collect()
    }

    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);

        // Scores 4.0, 3.0, 2.0 survive; 1.0 is evicted and 0.5 never gets in.
        assert_eq!(doc_ids(&collector.into_sorted_results()), vec![3, 1, 2]);
    }

    #[test]
    fn test_top_k_ties_keep_earliest_and_sort_by_doc_id() {
        let mut collector = TopKCollector::new(2);
        collector.collect(7, 1.0, &[]);
        collector.collect(5, 1.0, &[]);
        // An equal score is not strictly better than the weakest kept hit, so 9 is dropped.
        collector.collect(9, 1.0, &[]);

        assert_eq!(doc_ids(&collector.into_sorted_results()), vec![5, 7]);
    }

    #[test]
    fn test_top_k_zero_limit_keeps_nothing() {
        let mut collector = TopKCollector::new(0);
        collector.collect(0, 10.0, &[]);
        assert!(collector.into_sorted_results().is_empty());
    }

    #[test]
    fn test_top_k_positions_only_kept_when_requested() {
        let positions: Vec<(u32, Vec<ScoredPosition>)> =
            vec![(3, vec![ScoredPosition::new(7, 0.5)])];

        let mut tracking = TopKCollector::with_positions(1);
        assert!(tracking.needs_positions());
        tracking.collect(0, 1.0, &positions);
        let results = tracking.into_sorted_results();
        let field = results[0]
            .field_positions(3)
            .expect("field 3 should have recorded positions");
        assert_eq!(field.len(), 1);
        assert_eq!(field[0].position, 7);
        assert!(results[0].field_positions(4).is_none());

        let mut plain = TopKCollector::new(1);
        assert!(!plain.needs_positions());
        plain.collect(0, 1.0, &positions);
        assert!(plain.into_sorted_results()[0].positions.is_empty());
    }

    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);

        assert_eq!(collector.count(), 3);
    }

    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();

        {
            // Drive both collectors through the tuple `Collector` impl, as `collect_segment`
            // does when given `&mut (&mut top_k, &mut count)`.
            let mut both = (&mut top_k, &mut count);
            assert!(!both.needs_positions());
            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                both.collect(doc_id, score, &[]);
            }
        }

        // Count should have all 5 documents
        assert_eq!(count.count(), 5);

        // TopK should only have top 2 results: scores 4.0 and 3.0
        assert_eq!(doc_ids(&top_k.into_sorted_results()), vec![3, 1]);
    }

    #[test]
    fn test_triple_collector_needs_positions_if_any_member_does() {
        let mut count = CountCollector::new();
        let mut plain = TopKCollector::new(1);
        let mut tracking = TopKCollector::with_positions(1);

        let mut all = (&mut count, &mut plain, &mut tracking);
        assert!(all.needs_positions());
        all.collect(0, 1.0, &[]);

        assert_eq!(count.count(), 1);
        assert_eq!(doc_ids(&plain.into_sorted_results()), vec![0]);
        assert_eq!(doc_ids(&tracking.into_sorted_results()), vec![0]);
    }

    #[test]
    fn test_doc_address_round_trips_segment_id() {
        let address = DocAddress::new(0xabc, 7);
        assert_eq!(address.segment_id, format!("{:0>32}", "abc"));
        assert_eq!(address.segment_id_u128(), Some(0xabc));
        assert_eq!(address.doc_id, 7);

        let invalid = DocAddress {
            segment_id: "not-hex".to_string(),
            doc_id: 0,
        };
        assert_eq!(invalid.segment_id_u128(), None);
    }

    #[test]
    fn test_extract_ordinals_decodes_sorts_and_dedups() {
        let result = SearchResult {
            doc_id: 0,
            score: 1.0,
            positions: vec![(
                1,
                vec![
                    // Encoded text position: ordinal 2 in the high bits, token 5 below.
                    ScoredPosition::new((2 << 20) | 5, 1.0),
                    // Small values are taken as-is (vector ordinals); duplicates collapse.
                    ScoredPosition::new(3, 0.5),
                    ScoredPosition::new(3, 0.25),
                ],
            )],
        };

        let fields = result.extract_ordinals();
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].field_id, 1);
        assert_eq!(fields[0].ordinals, vec![2, 3]);
    }
}