Skip to main content

hermes_core/query/
collector.rs

1//! Search result collection and response types
2
3use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
/// Unique document address: segment_id (hex) + local doc_id within segment.
///
/// Build with [`DocAddress::new`], which renders the numeric segment ID as 32
/// zero-padded lowercase hex digits; [`DocAddress::segment_id_u128`] parses it back.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct DocAddress {
    /// Segment ID as hex string (32 chars when built via `DocAddress::new`).
    pub segment_id: String,
    /// Document ID within the segment named by `segment_id`.
    pub doc_id: DocId,
}
20
21impl DocAddress {
22    pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23        Self {
24            segment_id: format!("{:032x}", segment_id),
25            doc_id,
26        }
27    }
28
29    /// Parse segment_id from hex string
30    pub fn segment_id_u128(&self) -> Option<u128> {
31        u128::from_str_radix(&self.segment_id, 16).ok()
32    }
33}
34
/// Search result with doc_id and score (internal use).
///
/// Collected per segment by [`TopKCollector`]. Equality and ordering are
/// implemented by hand in this module rather than derived.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResult {
    /// Document ID local to the segment that produced this result.
    pub doc_id: DocId,
    /// Score reported by the query's scorer for this document.
    pub score: Score,
    /// Matched positions per field: (field_id, encoded_positions).
    ///
    /// Left empty by `search_segment_with_positions` unless `collect_positions`
    /// is true. Each encoded position carries an element ordinal in its high bits
    /// (decoded as `position >> 20` by [`SearchResult::extract_ordinals`]).
    /// Omitted from serialized output when empty; defaults to empty when missing
    /// on deserialization.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub positions: Vec<(u32, Vec<u32>)>,
}
44
/// Matched field info with ordinals (for multi-valued fields).
///
/// [`SearchResult::extract_ordinals`] produces one of these per
/// `(field_id, positions)` entry of `SearchResult::positions`.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct MatchedField {
    /// Field ID
    pub field_id: u32,
    /// Matched element ordinals (for multi-valued fields with position tracking).
    /// Sorted ascending without duplicates when built by `extract_ordinals`.
    /// Empty if position tracking is not enabled for this field
    pub ordinals: Vec<u32>,
}
54
55impl SearchResult {
56    /// Extract unique ordinals from positions for each field
57    /// Uses the position encoding: ordinal = position >> 20
58    pub fn extract_ordinals(&self) -> Vec<MatchedField> {
59        use rustc_hash::FxHashSet;
60
61        self.positions
62            .iter()
63            .map(|(field_id, positions)| {
64                let mut ordinals: FxHashSet<u32> = FxHashSet::default();
65                for &pos in positions {
66                    let ordinal = pos >> 20; // Extract ordinal from encoded position
67                    ordinals.insert(ordinal);
68                }
69                let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
70                ordinals.sort_unstable();
71                MatchedField {
72                    field_id: *field_id,
73                    ordinals,
74                }
75            })
76            .collect()
77    }
78}
79
/// Search hit with unique document address and score.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchHit {
    /// Unique document address (segment_id + local doc_id)
    pub address: DocAddress,
    /// Score of this hit.
    pub score: Score,
    /// Matched fields with element ordinals (populated when position tracking is enabled).
    /// Omitted from serialized output when empty; defaults to empty when missing
    /// on deserialization.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub matched_fields: Vec<MatchedField>,
}
90
/// Search response with hits (IDs only, no documents).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResponse {
    /// Returned hits, each identified by a [`DocAddress`] rather than stored content.
    pub hits: Vec<SearchHit>,
    /// Total hit count.
    /// NOTE(review): whether this counts every matching document or only
    /// `hits.len()` is decided where the response is built — confirm there.
    pub total_hits: u32,
}
97
98impl PartialEq for SearchResult {
99    fn eq(&self, other: &Self) -> bool {
100        self.doc_id == other.doc_id
101    }
102}
103
104impl Eq for SearchResult {}
105
106impl PartialOrd for SearchResult {
107    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108        Some(self.cmp(other))
109    }
110}
111
112impl Ord for SearchResult {
113    fn cmp(&self, other: &Self) -> Ordering {
114        other
115            .score
116            .partial_cmp(&self.score)
117            .unwrap_or(Ordering::Equal)
118            .then_with(|| self.doc_id.cmp(&other.doc_id))
119    }
120}
121
122/// Collector for top-k results
123pub struct TopKCollector {
124    heap: BinaryHeap<SearchResult>,
125    k: usize,
126}
127
128impl TopKCollector {
129    pub fn new(k: usize) -> Self {
130        Self {
131            heap: BinaryHeap::with_capacity(k + 1),
132            k,
133        }
134    }
135
136    pub fn collect(&mut self, doc_id: DocId, score: Score) {
137        self.collect_with_positions(doc_id, score, Vec::new());
138    }
139
140    pub fn collect_with_positions(
141        &mut self,
142        doc_id: DocId,
143        score: Score,
144        positions: Vec<(u32, Vec<u32>)>,
145    ) {
146        if self.heap.len() < self.k {
147            self.heap.push(SearchResult {
148                doc_id,
149                score,
150                positions,
151            });
152        } else if let Some(min) = self.heap.peek()
153            && score > min.score
154        {
155            self.heap.pop();
156            self.heap.push(SearchResult {
157                doc_id,
158                score,
159                positions,
160            });
161        }
162    }
163
164    pub fn into_sorted_results(self) -> Vec<SearchResult> {
165        let mut results: Vec<_> = self.heap.into_vec();
166        results.sort_by(|a, b| {
167            b.score
168                .partial_cmp(&a.score)
169                .unwrap_or(Ordering::Equal)
170                .then_with(|| a.doc_id.cmp(&b.doc_id))
171        });
172        results
173    }
174}
175
176/// Execute a search query on a single segment (async)
177pub async fn search_segment(
178    reader: &SegmentReader,
179    query: &dyn Query,
180    limit: usize,
181) -> Result<Vec<SearchResult>> {
182    search_segment_with_positions(reader, query, limit, false).await
183}
184
185/// Execute a search query on a single segment with optional position collection (async)
186pub async fn search_segment_with_positions(
187    reader: &SegmentReader,
188    query: &dyn Query,
189    limit: usize,
190    collect_positions: bool,
191) -> Result<Vec<SearchResult>> {
192    let mut scorer = query.scorer(reader, limit).await?;
193    let mut collector = TopKCollector::new(limit);
194
195    let mut doc = scorer.doc();
196
197    while doc != TERMINATED {
198        let positions = if collect_positions {
199            scorer.matched_positions().unwrap_or_default()
200        } else {
201            Vec::new()
202        };
203        collector.collect_with_positions(doc, scorer.score(), positions);
204        doc = scorer.advance();
205    }
206
207    Ok(collector.into_sorted_results())
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);
        for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
            collector.collect(doc_id, score);
        }

        let ranked: Vec<_> = collector
            .into_sorted_results()
            .iter()
            .map(|r| r.doc_id)
            .collect();

        // Docs 0 (1.0) and 4 (0.5) fall out; the rest come back best-first:
        // doc 3 (4.0), doc 1 (3.0), doc 2 (2.0).
        assert_eq!(ranked, vec![3, 1, 2]);
    }
}