1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
14pub struct DocAddress {
15 pub segment_id: String,
17 pub doc_id: DocId,
19}
20
21impl DocAddress {
22 pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23 Self {
24 segment_id: format!("{:032x}", segment_id),
25 doc_id,
26 }
27 }
28
29 pub fn segment_id_u128(&self) -> Option<u128> {
31 u128::from_str_radix(&self.segment_id, 16).ok()
32 }
33}
34
35#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
39pub struct ScoredPosition {
40 pub position: u32,
42 pub score: f32,
44}
45
46impl ScoredPosition {
47 pub fn new(position: u32, score: f32) -> Self {
48 Self { position, score }
49 }
50}
51
52#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
54pub struct SearchResult {
55 pub doc_id: DocId,
56 pub score: Score,
57 #[serde(default, skip_serializing_if = "is_zero_u128")]
59 pub segment_id: u128,
60 #[serde(default, skip_serializing_if = "Vec::is_empty")]
63 pub positions: Vec<(u32, Vec<ScoredPosition>)>,
64}
65
/// Serde `skip_serializing_if` predicate: `true` when the value is zero.
fn is_zero_u128(v: &u128) -> bool {
    matches!(*v, 0)
}
69
70#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
72pub struct MatchedField {
73 pub field_id: u32,
75 pub ordinals: Vec<u32>,
78}
79
80impl SearchResult {
81 pub fn extract_ordinals(&self) -> Vec<MatchedField> {
85 use rustc_hash::FxHashSet;
86
87 self.positions
88 .iter()
89 .map(|(field_id, scored_positions)| {
90 let mut ordinals: FxHashSet<u32> = FxHashSet::default();
91 for sp in scored_positions {
92 let ordinal = if sp.position > 0xFFFFF {
96 sp.position >> 20
97 } else {
98 sp.position
99 };
100 ordinals.insert(ordinal);
101 }
102 let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
103 ordinals.sort_unstable();
104 MatchedField {
105 field_id: *field_id,
106 ordinals,
107 }
108 })
109 .collect()
110 }
111
112 pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
114 self.positions
115 .iter()
116 .find(|(fid, _)| *fid == field_id)
117 .map(|(_, positions)| positions.as_slice())
118 }
119}
120
121#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
123pub struct SearchHit {
124 pub address: DocAddress,
126 pub score: Score,
127 #[serde(default, skip_serializing_if = "Vec::is_empty")]
129 pub matched_fields: Vec<MatchedField>,
130}
131
132#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
134pub struct SearchResponse {
135 pub hits: Vec<SearchHit>,
136 pub total_hits: u32,
137}
138
139impl PartialEq for SearchResult {
140 fn eq(&self, other: &Self) -> bool {
141 self.doc_id == other.doc_id
142 }
143}
144
145impl Eq for SearchResult {}
146
147impl PartialOrd for SearchResult {
148 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
149 Some(self.cmp(other))
150 }
151}
152
153impl Ord for SearchResult {
154 fn cmp(&self, other: &Self) -> Ordering {
155 other
156 .score
157 .partial_cmp(&self.score)
158 .unwrap_or(Ordering::Equal)
159 .then_with(|| self.doc_id.cmp(&other.doc_id))
160 }
161}
162
163pub trait Collector {
168 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);
171
172 fn needs_positions(&self) -> bool {
174 false
175 }
176}
177
178pub struct TopKCollector {
180 heap: BinaryHeap<SearchResult>,
181 k: usize,
182 collect_positions: bool,
183 total_seen: u32,
185}
186
187impl TopKCollector {
188 pub fn new(k: usize) -> Self {
189 Self {
190 heap: BinaryHeap::with_capacity(k + 1),
191 k,
192 collect_positions: false,
193 total_seen: 0,
194 }
195 }
196
197 pub fn with_positions(k: usize) -> Self {
199 Self {
200 heap: BinaryHeap::with_capacity(k + 1),
201 k,
202 collect_positions: true,
203 total_seen: 0,
204 }
205 }
206
207 pub fn total_seen(&self) -> u32 {
209 self.total_seen
210 }
211
212 pub fn into_sorted_results(self) -> Vec<SearchResult> {
213 let mut results: Vec<_> = self.heap.into_vec();
214 results.sort_by(|a, b| {
215 b.score
216 .partial_cmp(&a.score)
217 .unwrap_or(Ordering::Equal)
218 .then_with(|| a.doc_id.cmp(&b.doc_id))
219 });
220 results
221 }
222
223 pub fn into_results_with_count(self) -> (Vec<SearchResult>, u32) {
225 let total = self.total_seen;
226 (self.into_sorted_results(), total)
227 }
228}
229
230impl Collector for TopKCollector {
231 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
232 self.total_seen += 1;
233
234 let dominated =
238 self.heap.len() >= self.k && self.heap.peek().is_some_and(|min| score <= min.score);
239 if dominated {
240 return;
241 }
242
243 let positions = if self.collect_positions {
244 positions.to_vec()
245 } else {
246 Vec::new()
247 };
248
249 if self.heap.len() >= self.k {
250 self.heap.pop();
251 }
252 self.heap.push(SearchResult {
253 doc_id,
254 score,
255 segment_id: 0,
256 positions,
257 });
258 }
259
260 fn needs_positions(&self) -> bool {
261 self.collect_positions
262 }
263}
264
/// Collector that only counts the documents it is offered.
#[derive(Default)]
pub struct CountCollector {
    /// Number of `collect` calls so far.
    count: u64,
}
270
271impl CountCollector {
272 pub fn new() -> Self {
273 Self { count: 0 }
274 }
275
276 pub fn count(&self) -> u64 {
278 self.count
279 }
280}
281
282impl Collector for CountCollector {
283 #[inline]
284 fn collect(
285 &mut self,
286 _doc_id: DocId,
287 _score: Score,
288 _positions: &[(u32, Vec<ScoredPosition>)],
289 ) {
290 self.count += 1;
291 }
292}
293
294pub async fn search_segment(
296 reader: &SegmentReader,
297 query: &dyn Query,
298 limit: usize,
299) -> Result<Vec<SearchResult>> {
300 let mut collector = TopKCollector::new(limit);
301 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
302 Ok(collector.into_sorted_results())
303}
304
305pub async fn search_segment_with_count(
307 reader: &SegmentReader,
308 query: &dyn Query,
309 limit: usize,
310) -> Result<(Vec<SearchResult>, u32)> {
311 let mut collector = TopKCollector::new(limit);
312 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
313 Ok(collector.into_results_with_count())
314}
315
316pub async fn search_segment_with_positions(
318 reader: &SegmentReader,
319 query: &dyn Query,
320 limit: usize,
321) -> Result<Vec<SearchResult>> {
322 let mut collector = TopKCollector::with_positions(limit);
323 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
324 Ok(collector.into_sorted_results())
325}
326
327pub async fn search_segment_with_positions_and_count(
329 reader: &SegmentReader,
330 query: &dyn Query,
331 limit: usize,
332) -> Result<(Vec<SearchResult>, u32)> {
333 let mut collector = TopKCollector::with_positions(limit);
334 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
335 Ok(collector.into_results_with_count())
336}
337
338pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
340 let mut collector = CountCollector::new();
341 collect_segment(reader, query, &mut collector).await?;
342 Ok(collector.count())
343}
344
345impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
347 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
348 self.0.collect(doc_id, score, positions);
349 self.1.collect(doc_id, score, positions);
350 }
351 fn needs_positions(&self) -> bool {
352 self.0.needs_positions() || self.1.needs_positions()
353 }
354}
355
356impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
358 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
359 self.0.collect(doc_id, score, positions);
360 self.1.collect(doc_id, score, positions);
361 self.2.collect(doc_id, score, positions);
362 }
363 fn needs_positions(&self) -> bool {
364 self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
365 }
366}
367
368pub async fn collect_segment<C: Collector>(
386 reader: &SegmentReader,
387 query: &dyn Query,
388 collector: &mut C,
389) -> Result<()> {
390 collect_segment_with_limit(reader, query, collector, usize::MAX / 2).await
392}
393
394pub async fn collect_segment_with_limit<C: Collector>(
403 reader: &SegmentReader,
404 query: &dyn Query,
405 collector: &mut C,
406 limit: usize,
407) -> Result<()> {
408 let needs_positions = collector.needs_positions();
409 let doc_id_offset = reader.doc_id_offset();
410 let mut scorer = query.scorer(reader, limit).await?;
411
412 let mut doc = scorer.doc();
413 while doc != TERMINATED {
414 if needs_positions {
416 let positions = scorer.matched_positions().unwrap_or_default();
417 collector.collect(doc + doc_id_offset, scorer.score(), &positions);
418 } else {
419 collector.collect(doc + doc_id_offset, scorer.score(), &[]);
420 }
421 doc = scorer.advance();
422 }
423
424 Ok(())
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);

        // Every offered document is counted, not only the retained ones.
        assert_eq!(collector.total_seen(), 5);

        let results = collector.into_sorted_results();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
        assert_eq!(results[2].doc_id, 2);
    }

    #[test]
    fn test_top_k_collector_ties_keep_earlier_docs() {
        let mut collector = TopKCollector::new(2);

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 1.0, &[]);
        // Equal to the current minimum, so it is rejected.
        collector.collect(2, 1.0, &[]);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 0);
        assert_eq!(results[1].doc_id, 1);
    }

    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);

        assert_eq!(collector.count(), 3);
    }

    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();

        {
            // Drive both collectors through the tuple `Collector` impl.
            let mut both = (&mut top_k, &mut count);
            assert!(!both.needs_positions());
            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                both.collect(doc_id, score, &[]);
            }
        }

        assert_eq!(count.count(), 5);

        let results = top_k.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
    }

    #[test]
    fn test_top_k_collector_positions_only_when_requested() {
        let positions: Vec<(u32, Vec<ScoredPosition>)> =
            vec![(7, vec![ScoredPosition::new(3, 0.5)])];

        let mut plain = TopKCollector::new(1);
        assert!(!plain.needs_positions());
        plain.collect(0, 1.0, &positions);
        assert!(plain.into_sorted_results()[0].positions.is_empty());

        let mut with_pos = TopKCollector::with_positions(1);
        assert!(with_pos.needs_positions());
        with_pos.collect(0, 1.0, &positions);
        let results = with_pos.into_sorted_results();
        assert_eq!(results[0].field_positions(7).map(|p| p.len()), Some(1));
        assert_eq!(results[0].field_positions(8).map(|p| p.len()), None);
    }
}