1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
/// Identifies a document by a segment id plus a `DocId`.
///
/// `segment_id` is stored as text; [`DocAddress::new`] fills it with the
/// zero-padded, 32-digit lowercase hexadecimal rendering of a `u128`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct DocAddress {
    /// Segment identifier in textual (hexadecimal) form.
    pub segment_id: String,
    /// Document id paired with `segment_id`.
    pub doc_id: DocId,
}
20
21impl DocAddress {
22 pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23 Self {
24 segment_id: format!("{:032x}", segment_id),
25 doc_id,
26 }
27 }
28
29 pub fn segment_id_u128(&self) -> Option<u128> {
31 u128::from_str_radix(&self.segment_id, 16).ok()
32 }
33}
34
/// A matched position together with the score attached to it.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct ScoredPosition {
    /// Raw position value. `SearchResult::extract_ordinals` treats values above
    /// `0xFFFFF` as carrying an ordinal in the bits above bit 19.
    pub position: u32,
    /// Score associated with this position.
    pub score: f32,
}
45
46impl ScoredPosition {
47 pub fn new(position: u32, score: f32) -> Self {
48 Self { position, score }
49 }
50}
51
/// One scored document produced by a collector.
///
/// Note that the hand-written `PartialEq` in this file compares only
/// `segment_id` and `doc_id`, whereas `Ord` ranks by `score` first.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResult {
    /// Id of the matching document.
    pub doc_id: DocId,
    /// Score reported for the document.
    pub score: Score,
    /// Segment identifier; filled with the default when missing on input and
    /// left out of serialized output when it is zero.
    #[serde(default, skip_serializing_if = "is_zero_u128")]
    pub segment_id: u128,
    /// Matched positions as `(field_id, positions)` pairs; left out of
    /// serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub positions: Vec<(u32, Vec<ScoredPosition>)>,
}
65
/// Serde `skip_serializing_if` predicate: true exactly when the value is zero.
fn is_zero_u128(v: &u128) -> bool {
    matches!(*v, 0)
}
69
/// The ordinals that matched within one field.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct MatchedField {
    /// Id of the field the ordinals belong to.
    pub field_id: u32,
    /// Matched ordinals; `SearchResult::extract_ordinals` produces them sorted
    /// ascending without duplicates.
    pub ordinals: Vec<u32>,
}
79
80impl SearchResult {
81 pub fn extract_ordinals(&self) -> Vec<MatchedField> {
85 use rustc_hash::FxHashSet;
86
87 self.positions
88 .iter()
89 .map(|(field_id, scored_positions)| {
90 let mut ordinals: FxHashSet<u32> = FxHashSet::default();
91 for sp in scored_positions {
92 let ordinal = if sp.position > 0xFFFFF {
96 sp.position >> 20
97 } else {
98 sp.position
99 };
100 ordinals.insert(ordinal);
101 }
102 let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
103 ordinals.sort_unstable();
104 MatchedField {
105 field_id: *field_id,
106 ordinals,
107 }
108 })
109 .collect()
110 }
111
112 pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
114 self.positions
115 .iter()
116 .find(|(fid, _)| *fid == field_id)
117 .map(|(_, positions)| positions.as_slice())
118 }
119}
120
/// A single hit as carried inside a [`SearchResponse`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchHit {
    /// Address of the matching document.
    pub address: DocAddress,
    /// Score of the hit.
    pub score: Score,
    /// Per-field matched ordinals; defaults to empty on input and is left out
    /// of serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub matched_fields: Vec<MatchedField>,
}
131
/// Serializable search response: a list of hits plus a hit count.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResponse {
    /// The hits carried by this response.
    pub hits: Vec<SearchHit>,
    /// Hit count reported alongside `hits`.
    // NOTE(review): presumably the number of matches before truncation to the
    // requested limit — confirm with the code that builds the response.
    pub total_hits: u32,
}
138
139impl PartialEq for SearchResult {
140 fn eq(&self, other: &Self) -> bool {
141 self.segment_id == other.segment_id && self.doc_id == other.doc_id
142 }
143}
144
145impl Eq for SearchResult {}
146
147impl PartialOrd for SearchResult {
148 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
149 Some(self.cmp(other))
150 }
151}
152
153impl Ord for SearchResult {
154 fn cmp(&self, other: &Self) -> Ordering {
155 other
156 .score
157 .partial_cmp(&self.score)
158 .unwrap_or(Ordering::Equal)
159 .then_with(|| self.segment_id.cmp(&other.segment_id))
160 .then_with(|| self.doc_id.cmp(&other.doc_id))
161 }
162}
163
/// Sink that is fed every document produced while a query runs over a segment.
pub trait Collector {
    /// Receives one document: its id, its score and its matched positions.
    ///
    /// `positions` holds `(u32, Vec<ScoredPosition>)` pairs. The driver
    /// `collect_segment_with_limit` passes an empty slice whenever
    /// [`Collector::needs_positions`] returned `false`.
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);

    /// Whether the driver should gather matched positions before calling
    /// [`Collector::collect`]. The default implementation answers `false`.
    fn needs_positions(&self) -> bool {
        false
    }
}
178
/// Collector that retains the best-scoring documents, bounded by `k`.
pub struct TopKCollector {
    /// Retained hits. `Ord for SearchResult` inverts the score order, so the
    /// top of this max-heap is the lowest-scoring retained hit.
    heap: BinaryHeap<SearchResult>,
    /// Number of hits to keep.
    k: usize,
    /// When `true`, matched positions are copied into each retained hit.
    collect_positions: bool,
    /// Number of `collect` calls received, whether or not the document was kept.
    total_seen: u32,
}
187
188impl TopKCollector {
189 pub fn new(k: usize) -> Self {
190 Self {
191 heap: BinaryHeap::with_capacity(k + 1),
192 k,
193 collect_positions: false,
194 total_seen: 0,
195 }
196 }
197
198 pub fn with_positions(k: usize) -> Self {
200 Self {
201 heap: BinaryHeap::with_capacity(k + 1),
202 k,
203 collect_positions: true,
204 total_seen: 0,
205 }
206 }
207
208 pub fn total_seen(&self) -> u32 {
210 self.total_seen
211 }
212
213 pub fn into_sorted_results(self) -> Vec<SearchResult> {
214 let mut results: Vec<_> = self.heap.into_vec();
215 results.sort_by(|a, b| {
216 b.score
217 .partial_cmp(&a.score)
218 .unwrap_or(Ordering::Equal)
219 .then_with(|| a.doc_id.cmp(&b.doc_id))
220 });
221 results
222 }
223
224 pub fn into_results_with_count(self) -> (Vec<SearchResult>, u32) {
226 let total = self.total_seen;
227 (self.into_sorted_results(), total)
228 }
229}
230
231impl Collector for TopKCollector {
232 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
233 self.total_seen += 1;
234
235 let dominated =
239 self.heap.len() >= self.k && self.heap.peek().is_some_and(|min| score <= min.score);
240 if dominated {
241 return;
242 }
243
244 let positions = if self.collect_positions {
245 positions.to_vec()
246 } else {
247 Vec::new()
248 };
249
250 if self.heap.len() >= self.k {
251 self.heap.pop();
252 }
253 self.heap.push(SearchResult {
254 doc_id,
255 score,
256 segment_id: 0,
257 positions,
258 });
259 }
260
261 fn needs_positions(&self) -> bool {
262 self.collect_positions
263 }
264}
265
/// Collector that only counts the documents it is offered.
#[derive(Default)]
pub struct CountCollector {
    /// Number of `collect` calls received so far.
    count: u64,
}
271
272impl CountCollector {
273 pub fn new() -> Self {
274 Self { count: 0 }
275 }
276
277 pub fn count(&self) -> u64 {
279 self.count
280 }
281}
282
283impl Collector for CountCollector {
284 #[inline]
285 fn collect(
286 &mut self,
287 _doc_id: DocId,
288 _score: Score,
289 _positions: &[(u32, Vec<ScoredPosition>)],
290 ) {
291 self.count += 1;
292 }
293}
294
295pub async fn search_segment(
297 reader: &SegmentReader,
298 query: &dyn Query,
299 limit: usize,
300) -> Result<Vec<SearchResult>> {
301 let mut collector = TopKCollector::new(limit);
302 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
303 Ok(collector.into_sorted_results())
304}
305
306pub async fn search_segment_with_count(
308 reader: &SegmentReader,
309 query: &dyn Query,
310 limit: usize,
311) -> Result<(Vec<SearchResult>, u32)> {
312 let mut collector = TopKCollector::new(limit);
313 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
314 Ok(collector.into_results_with_count())
315}
316
317pub async fn search_segment_with_positions(
319 reader: &SegmentReader,
320 query: &dyn Query,
321 limit: usize,
322) -> Result<Vec<SearchResult>> {
323 let mut collector = TopKCollector::with_positions(limit);
324 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
325 Ok(collector.into_sorted_results())
326}
327
328pub async fn search_segment_with_positions_and_count(
330 reader: &SegmentReader,
331 query: &dyn Query,
332 limit: usize,
333) -> Result<(Vec<SearchResult>, u32)> {
334 let mut collector = TopKCollector::with_positions(limit);
335 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
336 Ok(collector.into_results_with_count())
337}
338
339pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
341 let mut collector = CountCollector::new();
342 collect_segment(reader, query, &mut collector).await?;
343 Ok(collector.count())
344}
345
346impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
348 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
349 self.0.collect(doc_id, score, positions);
350 self.1.collect(doc_id, score, positions);
351 }
352 fn needs_positions(&self) -> bool {
353 self.0.needs_positions() || self.1.needs_positions()
354 }
355}
356
357impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
359 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
360 self.0.collect(doc_id, score, positions);
361 self.1.collect(doc_id, score, positions);
362 self.2.collect(doc_id, score, positions);
363 }
364 fn needs_positions(&self) -> bool {
365 self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
366 }
367}
368
369pub async fn collect_segment<C: Collector>(
387 reader: &SegmentReader,
388 query: &dyn Query,
389 collector: &mut C,
390) -> Result<()> {
391 collect_segment_with_limit(reader, query, collector, usize::MAX / 2).await
393}
394
395pub async fn collect_segment_with_limit<C: Collector>(
404 reader: &SegmentReader,
405 query: &dyn Query,
406 collector: &mut C,
407 limit: usize,
408) -> Result<()> {
409 let needs_positions = collector.needs_positions();
410 let mut scorer = query.scorer(reader, limit).await?;
411
412 let mut doc = scorer.doc();
413 while doc != TERMINATED {
414 if needs_positions {
415 let positions = scorer.matched_positions().unwrap_or_default();
416 collector.collect(doc, scorer.score(), &positions);
417 } else {
418 collector.collect(doc, scorer.score(), &[]);
419 }
420 doc = scorer.advance();
421 }
422
423 Ok(())
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the `k` best-scoring documents survive, returned best first.
    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);

        let results = collector.into_sorted_results();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
        assert_eq!(results[2].doc_id, 2);
    }

    /// Every offered document is counted.
    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);

        assert_eq!(collector.count(), 3);
    }

    /// The tuple `Collector` impl forwards every document to both members.
    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();

        {
            // Drive both collectors through the `(&mut A, &mut B)` impl instead of
            // calling each one by hand, so the fan-out itself is under test.
            let mut multi = (&mut top_k, &mut count);
            assert!(!multi.needs_positions());

            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                multi.collect(doc_id, score, &[]);
            }
        }

        assert_eq!(count.count(), 5);
        assert_eq!(top_k.total_seen(), 5);

        let results = top_k.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
    }
}