hermes_core/query/collector.rs

//! Search result collection and response types

use std::cmp::Ordering;
use std::collections::BinaryHeap;

use crate::segment::SegmentReader;
use crate::structures::TERMINATED;
use crate::{DocId, Result, Score};

use super::Query;

/// Unique document address: segment_id + local doc_id within segment.
/// Stores segment_id as u128 internally (16 bytes) but serializes as hex string
/// for backward compatibility with JSON/gRPC clients.
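///
/// # Example
///
/// A minimal sketch of constructing an address and reading back the hex form
/// (the 32-character lowercase hex matches the serializer below); the literal
/// values are illustrative only:
///
/// ```ignore
/// let addr = DocAddress::new(0xdead_beef, 7);
/// assert_eq!(addr.doc_id, 7);
/// assert_eq!(addr.segment_id(), format!("{:032x}", 0xdead_beef_u128));
/// ```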
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DocAddress {
    /// Segment ID as u128 (avoids heap allocation vs String)
    segment_id_raw: u128,
    /// Document ID within the segment
    pub doc_id: DocId,
}

impl DocAddress {
    pub fn new(segment_id: u128, doc_id: DocId) -> Self {
        Self {
            segment_id_raw: segment_id,
            doc_id,
        }
    }

    /// Get segment_id as hex string (for display/API)
    pub fn segment_id(&self) -> String {
        format!("{:032x}", self.segment_id_raw)
    }

    /// Get segment_id as u128 (zero-cost); always returns `Some`
    pub fn segment_id_u128(&self) -> Option<u128> {
        Some(self.segment_id_raw)
    }
}

impl serde::Serialize for DocAddress {
    fn serialize<S: serde::Serializer>(
        &self,
        serializer: S,
    ) -> std::result::Result<S::Ok, S::Error> {
        use serde::ser::SerializeStruct;
        let mut s = serializer.serialize_struct("DocAddress", 2)?;
        s.serialize_field("segment_id", &format!("{:032x}", self.segment_id_raw))?;
        s.serialize_field("doc_id", &self.doc_id)?;
        s.end()
    }
}

impl<'de> serde::Deserialize<'de> for DocAddress {
    fn deserialize<D: serde::Deserializer<'de>>(
        deserializer: D,
    ) -> std::result::Result<Self, D::Error> {
        #[derive(serde::Deserialize)]
        struct Helper {
            segment_id: String,
            doc_id: DocId,
        }
        let h = Helper::deserialize(deserializer)?;
        let raw = u128::from_str_radix(&h.segment_id, 16).map_err(serde::de::Error::custom)?;
        Ok(DocAddress {
            segment_id_raw: raw,
            doc_id: h.doc_id,
        })
    }
}

/// A scored position/ordinal within a field
/// For text fields: position is the token position
/// For vector fields: position is the ordinal (which vector in multi-value)
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct ScoredPosition {
    /// Position (text) or ordinal (vector)
    pub position: u32,
    /// Individual score contribution from this position/ordinal
    pub score: f32,
}

impl ScoredPosition {
    pub fn new(position: u32, score: f32) -> Self {
        Self { position, score }
    }
}

/// Search result with doc_id and score (internal use)
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResult {
    pub doc_id: DocId,
    pub score: Score,
    /// Segment ID (set by searcher after collection)
    #[serde(default, skip_serializing_if = "is_zero_u128")]
    pub segment_id: u128,
    /// Matched positions per field: (field_id, scored_positions)
    /// Each position includes its individual score contribution
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub positions: Vec<(u32, Vec<ScoredPosition>)>,
}

fn is_zero_u128(v: &u128) -> bool {
    *v == 0
}

/// Matched field info with ordinals (for multi-valued fields)
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct MatchedField {
    /// Field ID
    pub field_id: u32,
    /// Matched element ordinals (for multi-valued fields with position tracking)
    /// Empty if position tracking is not enabled for this field
    pub ordinals: Vec<u32>,
}

impl SearchResult {
    /// Extract unique ordinals from positions for each field
    /// For text fields: ordinal = position >> 20 (from encoded position)
    /// For vector fields: position IS the ordinal directly
    pub fn extract_ordinals(&self) -> Vec<MatchedField> {
        use rustc_hash::FxHashSet;

        self.positions
            .iter()
            .map(|(field_id, scored_positions)| {
                let mut ordinals: FxHashSet<u32> = FxHashSet::default();
                for sp in scored_positions {
                    // For text fields with encoded positions, extract ordinal
                    // For vector fields, position IS the ordinal
                    // We use a heuristic: if position > 0xFFFFF (20 bits), it's encoded
                    let ordinal = if sp.position > 0xFFFFF {
                        sp.position >> 20
                    } else {
                        sp.position
                    };
                    ordinals.insert(ordinal);
                }
                let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
                ordinals.sort_unstable();
                MatchedField {
                    field_id: *field_id,
                    ordinals,
                }
            })
            .collect()
    }

    /// Get all scored positions for a specific field
    pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
        self.positions
            .iter()
            .find(|(fid, _)| *fid == field_id)
            .map(|(_, positions)| positions.as_slice())
    }
}

/// Search hit with unique document address and score
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchHit {
    /// Unique document address (segment_id + local doc_id)
    pub address: DocAddress,
    pub score: Score,
    /// Matched fields with element ordinals (populated when position tracking is enabled)
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub matched_fields: Vec<MatchedField>,
}

/// Search response with hits (IDs only, no documents)
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResponse {
    pub hits: Vec<SearchHit>,
    pub total_hits: u32,
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        self.segment_id == other.segment_id && self.doc_id == other.doc_id
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed score comparison: the "greatest" result is the lowest-scoring one,
        // so BinaryHeap (a max-heap) keeps the current minimum on top for top-k eviction.
        // Ties break on (segment_id, doc_id) to keep the ordering total and deterministic.
        other
            .score
            .partial_cmp(&self.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.segment_id.cmp(&other.segment_id))
            .then_with(|| self.doc_id.cmp(&other.doc_id))
    }
}

/// Trait for search result collectors
///
/// Implement this trait to create custom collectors that can be
/// combined and passed to query execution.
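///
/// # Example
///
/// A minimal sketch of a custom collector that just tracks the highest score
/// seen (the `MaxScoreCollector` name is illustrative, not part of the crate):
///
/// ```ignore
/// struct MaxScoreCollector {
///     max: Option<Score>,
/// }
///
/// impl Collector for MaxScoreCollector {
///     fn collect(&mut self, _doc_id: DocId, score: Score, _positions: &[(u32, Vec<ScoredPosition>)]) {
///         self.max = Some(self.max.map_or(score, |m| m.max(score)));
///     }
/// }
/// ```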
pub trait Collector {
    /// Called for each matching document
    /// positions: Vec of (field_id, scored_positions)
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);

    /// Whether this collector needs position information
    fn needs_positions(&self) -> bool {
        false
    }
}

/// Collector for top-k results
pub struct TopKCollector {
    heap: BinaryHeap<SearchResult>,
    k: usize,
    collect_positions: bool,
    /// Total documents seen by this collector
    total_seen: u32,
}

impl TopKCollector {
    pub fn new(k: usize) -> Self {
        Self {
            heap: BinaryHeap::with_capacity(k + 1),
            k,
            collect_positions: false,
            total_seen: 0,
        }
    }

    /// Create a collector that also collects positions
    pub fn with_positions(k: usize) -> Self {
        Self {
            heap: BinaryHeap::with_capacity(k + 1),
            k,
            collect_positions: true,
            total_seen: 0,
        }
    }

    /// Get the total number of documents seen (scored) by this collector
    pub fn total_seen(&self) -> u32 {
        self.total_seen
    }

    pub fn into_sorted_results(self) -> Vec<SearchResult> {
        let mut results: Vec<_> = self.heap.into_vec();
        results.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(Ordering::Equal)
                .then_with(|| a.doc_id.cmp(&b.doc_id))
        });
        results
    }

    /// Consume collector and return (sorted_results, total_seen)
    pub fn into_results_with_count(self) -> (Vec<SearchResult>, u32) {
        let total = self.total_seen;
        (self.into_sorted_results(), total)
    }
}

impl Collector for TopKCollector {
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
        self.total_seen = self.total_seen.saturating_add(1);

        // Only clone positions when the document will actually be kept in the heap.
        // This avoids deep-cloning Vec<ScoredPosition> for documents that are
        // immediately discarded (the common case for large result sets).
        let dominated =
            self.heap.len() >= self.k && self.heap.peek().is_some_and(|min| score <= min.score);
        if dominated {
            return;
        }

        let positions = if self.collect_positions {
            positions.to_vec()
        } else {
            Vec::new()
        };

        if self.heap.len() >= self.k {
            self.heap.pop();
        }
        self.heap.push(SearchResult {
            doc_id,
            score,
            segment_id: 0,
            positions,
        });
    }

    fn needs_positions(&self) -> bool {
        self.collect_positions
    }
}

/// Collector that counts all matching documents
#[derive(Default)]
pub struct CountCollector {
    count: u64,
}

impl CountCollector {
    pub fn new() -> Self {
        Self { count: 0 }
    }

    /// Get the total count
    pub fn count(&self) -> u64 {
        self.count
    }
}

impl Collector for CountCollector {
    #[inline]
    fn collect(
        &mut self,
        _doc_id: DocId,
        _score: Score,
        _positions: &[(u32, Vec<ScoredPosition>)],
    ) {
        self.count += 1;
    }
}

/// Execute a search query on a single segment and return (results, total_seen) (async)
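///
/// A usage sketch (assumes a `SegmentReader` and a boxed `Query` are already in hand):
///
/// ```ignore
/// let (results, total_seen) = search_segment_with_count(&reader, query.as_ref(), 10).await?;
/// for r in &results {
///     println!("doc {} scored {}", r.doc_id, r.score);
/// }
/// ```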
pub async fn search_segment_with_count(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<(Vec<SearchResult>, u32)> {
    let mut collector = TopKCollector::new(limit);
    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
    Ok(collector.into_results_with_count())
}

/// Execute a search query on a single segment with positions and return (results, total_seen)
pub async fn search_segment_with_positions_and_count(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<(Vec<SearchResult>, u32)> {
    let mut collector = TopKCollector::with_positions(limit);
    collect_segment_with_limit(reader, query, &mut collector, limit).await?;
    Ok(collector.into_results_with_count())
}

// Implement Collector for tuple of 2 collectors
impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
        self.0.collect(doc_id, score, positions);
        self.1.collect(doc_id, score, positions);
    }
    fn needs_positions(&self) -> bool {
        self.0.needs_positions() || self.1.needs_positions()
    }
}

// Implement Collector for tuple of 3 collectors
impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
        self.0.collect(doc_id, score, positions);
        self.1.collect(doc_id, score, positions);
        self.2.collect(doc_id, score, positions);
    }
    fn needs_positions(&self) -> bool {
        self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
    }
}

/// Execute a query with one or more collectors (async)
///
/// Uses a large limit for the scorer to disable MaxScore pruning.
/// For queries that benefit from MaxScore pruning (e.g., sparse vector search),
/// use `collect_segment_with_limit` instead.
///
/// # Examples
/// ```ignore
/// // Single collector
/// let mut top_k = TopKCollector::new(10);
/// collect_segment(reader, query, &mut top_k).await?;
///
/// // Multiple collectors (tuple)
/// let mut top_k = TopKCollector::new(10);
/// let mut count = CountCollector::new();
/// collect_segment(reader, query, &mut (&mut top_k, &mut count)).await?;
/// ```
pub async fn collect_segment<C: Collector>(
    reader: &SegmentReader,
    query: &dyn Query,
    collector: &mut C,
) -> Result<()> {
    // Use large limit to disable MaxScore skipping for exhaustive collection
    collect_segment_with_limit(reader, query, collector, usize::MAX / 2).await
}

/// Execute a query with one or more collectors and a specific limit (async)
///
/// The limit is passed to the scorer to enable MaxScore pruning for queries
/// that support it (e.g., sparse vector search). This significantly improves
/// performance when only the top-k results are needed.
///
/// Doc IDs in the collector are segment-local. The searcher stamps each result
/// with its segment_id, making (segment_id, doc_id) the unique document key.
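///
/// # Examples
///
/// A sketch mirroring `collect_segment` above, with the limit matching the
/// collector's k (reader and query are assumed to be in hand):
///
/// ```ignore
/// let mut top_k = TopKCollector::new(10);
/// collect_segment_with_limit(reader, query, &mut top_k, 10).await?;
/// let (results, total_seen) = top_k.into_results_with_count();
/// ```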
pub async fn collect_segment_with_limit<C: Collector>(
    reader: &SegmentReader,
    query: &dyn Query,
    collector: &mut C,
    limit: usize,
) -> Result<()> {
    let mut scorer = query.scorer(reader, limit).await?;
    drive_scorer(scorer.as_mut(), collector);
    Ok(())
}

/// Drive a scorer through a collector (shared by async and sync paths).
fn drive_scorer<C: Collector>(scorer: &mut dyn super::Scorer, collector: &mut C) {
    let needs_positions = collector.needs_positions();
    let mut doc = scorer.doc();
    while doc != TERMINATED {
        if needs_positions {
            let positions = scorer.matched_positions().unwrap_or_default();
            collector.collect(doc, scorer.score(), &positions);
        } else {
            collector.collect(doc, scorer.score(), &[]);
        }
        doc = scorer.advance();
    }
}

// ── Synchronous collector functions (mmap/RAM only) ─────────────────────────

/// Synchronous segment search — returns (results, total_seen).
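///
/// A sketch of the synchronous path (same shape as the async variant, without `.await`):
///
/// ```ignore
/// let (results, total_seen) = search_segment_with_count_sync(&reader, query.as_ref(), 10)?;
/// ```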
#[cfg(feature = "sync")]
pub fn search_segment_with_count_sync(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<(Vec<SearchResult>, u32)> {
    let mut collector = TopKCollector::new(limit);
    collect_segment_with_limit_sync(reader, query, &mut collector, limit)?;
    Ok(collector.into_results_with_count())
}

/// Synchronous segment search with positions — returns (results, total_seen).
#[cfg(feature = "sync")]
pub fn search_segment_with_positions_and_count_sync(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<(Vec<SearchResult>, u32)> {
    let mut collector = TopKCollector::with_positions(limit);
    collect_segment_with_limit_sync(reader, query, &mut collector, limit)?;
    Ok(collector.into_results_with_count())
}

/// Synchronous collect with limit — uses `scorer_sync`.
#[cfg(feature = "sync")]
pub fn collect_segment_with_limit_sync<C: Collector>(
    reader: &SegmentReader,
    query: &dyn Query,
    collector: &mut C,
    limit: usize,
) -> Result<()> {
    let mut scorer = query.scorer_sync(reader, limit)?;
    drive_scorer(scorer.as_mut(), collector);
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);

        let results = collector.into_sorted_results();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3); // score 4.0
        assert_eq!(results[1].doc_id, 1); // score 3.0
        assert_eq!(results[2].doc_id, 2); // score 2.0
    }

    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);

        assert_eq!(collector.count(), 3);
    }

    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();

        // Simulate what collect_segment does when driven with a tuple collector
        for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
            top_k.collect(doc_id, score, &[]);
            count.collect(doc_id, score, &[]);
        }

        // Count should have all 5 documents
        assert_eq!(count.count(), 5);

        // TopK should only have top 2 results
        let results = top_k.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3); // score 4.0
        assert_eq!(results[1].doc_id, 1); // score 3.0
    }
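
    // Added illustrative tests (not from the original file): they exercise the
    // ordinal-extraction heuristic and the DocAddress serde round-trip exactly
    // as defined above. The round-trip test assumes `serde_json` is available
    // as a dev-dependency.
    #[test]
    fn test_extract_ordinals_heuristic() {
        let result = SearchResult {
            doc_id: 0,
            score: 1.0,
            segment_id: 0,
            positions: vec![(
                1,
                vec![
                    // Encoded text position: ordinal 3 in the upper bits, token position 7 below.
                    ScoredPosition::new((3 << 20) | 7, 0.8),
                    // Small value (<= 0xFFFFF) is treated as the ordinal itself (vector case).
                    ScoredPosition::new(1, 0.4),
                ],
            )],
        };

        let matched = result.extract_ordinals();
        assert_eq!(matched.len(), 1);
        assert_eq!(matched[0].field_id, 1);
        assert_eq!(matched[0].ordinals, vec![1, 3]);
    }

    #[test]
    fn test_doc_address_serde_round_trip() {
        let addr = DocAddress::new(0xabcd, 42);
        let json = serde_json::to_string(&addr).unwrap();
        // segment_id is serialized as a zero-padded 32-char hex string.
        assert!(json.contains(&format!("{:032x}", 0xabcdu128)));
        let back: DocAddress = serde_json::from_str(&json).unwrap();
        assert_eq!(back, addr);
    }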
}