1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
14pub struct DocAddress {
15 pub segment_id: String,
17 pub doc_id: DocId,
19}
20
21impl DocAddress {
22 pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23 Self {
24 segment_id: format!("{:032x}", segment_id),
25 doc_id,
26 }
27 }
28
29 pub fn segment_id_u128(&self) -> Option<u128> {
31 u128::from_str_radix(&self.segment_id, 16).ok()
32 }
33}
34
35#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
37pub struct SearchResult {
38 pub doc_id: DocId,
39 pub score: Score,
40 #[serde(default, skip_serializing_if = "Vec::is_empty")]
42 pub positions: Vec<(u32, Vec<u32>)>,
43}
44
45#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
47pub struct MatchedField {
48 pub field_id: u32,
50 pub ordinals: Vec<u32>,
53}
54
55impl SearchResult {
56 pub fn extract_ordinals(&self) -> Vec<MatchedField> {
59 use rustc_hash::FxHashSet;
60
61 self.positions
62 .iter()
63 .map(|(field_id, positions)| {
64 let mut ordinals: FxHashSet<u32> = FxHashSet::default();
65 for &pos in positions {
66 let ordinal = pos >> 20; ordinals.insert(ordinal);
68 }
69 let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
70 ordinals.sort_unstable();
71 MatchedField {
72 field_id: *field_id,
73 ordinals,
74 }
75 })
76 .collect()
77 }
78}
79
80#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
82pub struct SearchHit {
83 pub address: DocAddress,
85 pub score: Score,
86 #[serde(default, skip_serializing_if = "Vec::is_empty")]
88 pub matched_fields: Vec<MatchedField>,
89}
90
91#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
93pub struct SearchResponse {
94 pub hits: Vec<SearchHit>,
95 pub total_hits: u32,
96}
97
98impl PartialEq for SearchResult {
99 fn eq(&self, other: &Self) -> bool {
100 self.doc_id == other.doc_id
101 }
102}
103
104impl Eq for SearchResult {}
105
106impl PartialOrd for SearchResult {
107 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108 Some(self.cmp(other))
109 }
110}
111
112impl Ord for SearchResult {
113 fn cmp(&self, other: &Self) -> Ordering {
114 other
115 .score
116 .partial_cmp(&self.score)
117 .unwrap_or(Ordering::Equal)
118 .then_with(|| self.doc_id.cmp(&other.doc_id))
119 }
120}
121
122pub trait Collector {
127 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]);
129
130 fn needs_positions(&self) -> bool {
132 false
133 }
134}
135
136pub struct TopKCollector {
138 heap: BinaryHeap<SearchResult>,
139 k: usize,
140 collect_positions: bool,
141}
142
143impl TopKCollector {
144 pub fn new(k: usize) -> Self {
145 Self {
146 heap: BinaryHeap::with_capacity(k + 1),
147 k,
148 collect_positions: false,
149 }
150 }
151
152 pub fn with_positions(k: usize) -> Self {
154 Self {
155 heap: BinaryHeap::with_capacity(k + 1),
156 k,
157 collect_positions: true,
158 }
159 }
160
161 pub fn into_sorted_results(self) -> Vec<SearchResult> {
162 let mut results: Vec<_> = self.heap.into_vec();
163 results.sort_by(|a, b| {
164 b.score
165 .partial_cmp(&a.score)
166 .unwrap_or(Ordering::Equal)
167 .then_with(|| a.doc_id.cmp(&b.doc_id))
168 });
169 results
170 }
171}
172
173impl Collector for TopKCollector {
174 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]) {
175 let positions = if self.collect_positions {
176 positions.to_vec()
177 } else {
178 Vec::new()
179 };
180
181 if self.heap.len() < self.k {
182 self.heap.push(SearchResult {
183 doc_id,
184 score,
185 positions,
186 });
187 } else if let Some(min) = self.heap.peek()
188 && score > min.score
189 {
190 self.heap.pop();
191 self.heap.push(SearchResult {
192 doc_id,
193 score,
194 positions,
195 });
196 }
197 }
198
199 fn needs_positions(&self) -> bool {
200 self.collect_positions
201 }
202}
203
/// Collector that only counts matching documents, ignoring scores and positions.
#[derive(Default)]
pub struct CountCollector {
    /// Number of `collect` calls received so far.
    count: u64,
}
209
210impl CountCollector {
211 pub fn new() -> Self {
212 Self { count: 0 }
213 }
214
215 pub fn count(&self) -> u64 {
217 self.count
218 }
219}
220
221impl Collector for CountCollector {
222 #[inline]
223 fn collect(&mut self, _doc_id: DocId, _score: Score, _positions: &[(u32, Vec<u32>)]) {
224 self.count += 1;
225 }
226}
227
228pub async fn search_segment(
230 reader: &SegmentReader,
231 query: &dyn Query,
232 limit: usize,
233) -> Result<Vec<SearchResult>> {
234 let mut collector = TopKCollector::new(limit);
235 collect_segment(reader, query, &mut collector).await?;
236 Ok(collector.into_sorted_results())
237}
238
239pub async fn search_segment_with_positions(
241 reader: &SegmentReader,
242 query: &dyn Query,
243 limit: usize,
244) -> Result<Vec<SearchResult>> {
245 let mut collector = TopKCollector::with_positions(limit);
246 collect_segment(reader, query, &mut collector).await?;
247 Ok(collector.into_sorted_results())
248}
249
250pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
252 let mut collector = CountCollector::new();
253 collect_segment(reader, query, &mut collector).await?;
254 Ok(collector.count())
255}
256
257impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
259 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]) {
260 self.0.collect(doc_id, score, positions);
261 self.1.collect(doc_id, score, positions);
262 }
263 fn needs_positions(&self) -> bool {
264 self.0.needs_positions() || self.1.needs_positions()
265 }
266}
267
268impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
270 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<u32>)]) {
271 self.0.collect(doc_id, score, positions);
272 self.1.collect(doc_id, score, positions);
273 self.2.collect(doc_id, score, positions);
274 }
275 fn needs_positions(&self) -> bool {
276 self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
277 }
278}
279
280pub async fn collect_segment<C: Collector>(
294 reader: &SegmentReader,
295 query: &dyn Query,
296 collector: &mut C,
297) -> Result<()> {
298 let needs_positions = collector.needs_positions();
299 let mut scorer = query.scorer(reader, usize::MAX / 2).await?;
301
302 let mut doc = scorer.doc();
303 while doc != TERMINATED {
304 let positions = if needs_positions {
305 scorer.matched_positions().unwrap_or_default()
306 } else {
307 Vec::new()
308 };
309 collector.collect(doc, scorer.score(), &positions);
310 doc = scorer.advance();
311 }
312
313 Ok(())
314}
315
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);

        let results = collector.into_sorted_results();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
        assert_eq!(results[2].doc_id, 2);
    }

    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);

        assert_eq!(collector.count(), 3);
    }

    /// Drives two collectors through the tuple `Collector` impl, so the fan-out itself is tested.
    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();

        {
            let mut both = (&mut top_k, &mut count);
            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                both.collect(doc_id, score, &[]);
            }
        }

        assert_eq!(count.count(), 5);

        let results = top_k.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
    }

    #[test]
    fn test_tuple_collectors_propagate_needs_positions() {
        let mut plain = TopKCollector::new(1);
        let mut count = CountCollector::new();
        assert!(!(&mut plain, &mut count).needs_positions());

        let mut positional = TopKCollector::with_positions(1);
        assert!((&mut positional, &mut count).needs_positions());
        assert!((&mut plain, &mut count, &mut positional).needs_positions());
    }

    #[test]
    fn test_top_k_keeps_positions_only_when_requested() {
        let positions = vec![(0u32, vec![1u32, 2])];

        let mut with = TopKCollector::with_positions(1);
        with.collect(0, 1.0, &positions);
        assert_eq!(with.into_sorted_results()[0].positions, positions);

        let mut without = TopKCollector::new(1);
        without.collect(0, 1.0, &positions);
        assert!(without.into_sorted_results()[0].positions.is_empty());
    }
}