1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
/// Pairs a segment identifier with a document id.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct DocAddress {
    /// Segment id in textual form; [`DocAddress::new`] writes it as 32 zero-padded
    /// lowercase hexadecimal digits.
    pub segment_id: String,
    /// Document id that accompanies `segment_id`.
    pub doc_id: DocId,
}
20
21impl DocAddress {
22 pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23 Self {
24 segment_id: format!("{:032x}", segment_id),
25 doc_id,
26 }
27 }
28
29 pub fn segment_id_u128(&self) -> Option<u128> {
31 u128::from_str_radix(&self.segment_id, 16).ok()
32 }
33}
34
/// A single matched position together with the score attributed to it.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct ScoredPosition {
    /// Raw position value.
    // NOTE(review): `SearchResult::extract_ordinals` shifts values above 0xFFFFF right by
    // 20 bits, so this may be a packed value — confirm the encoding with the producer.
    pub position: u32,
    /// Score associated with this position.
    pub score: f32,
}
45
46impl ScoredPosition {
47 pub fn new(position: u32, score: f32) -> Self {
48 Self { position, score }
49 }
50}
51
/// One scored document produced by a collector, optionally with matched positions.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResult {
    /// Id of the matching document.
    pub doc_id: DocId,
    /// Score of the document for the query.
    pub score: Score,
    /// `(field id, scored positions)` pairs.
    ///
    /// Omitted from serialized output when empty and defaulted to empty when missing on
    /// deserialization.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub positions: Vec<(u32, Vec<ScoredPosition>)>,
}
62
/// The ordinals that matched within one field.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct MatchedField {
    /// Id of the field the ordinals belong to.
    pub field_id: u32,
    /// Matched ordinals; `SearchResult::extract_ordinals` fills this sorted ascending
    /// without duplicates.
    pub ordinals: Vec<u32>,
}
72
73impl SearchResult {
74 pub fn extract_ordinals(&self) -> Vec<MatchedField> {
78 use rustc_hash::FxHashSet;
79
80 self.positions
81 .iter()
82 .map(|(field_id, scored_positions)| {
83 let mut ordinals: FxHashSet<u32> = FxHashSet::default();
84 for sp in scored_positions {
85 let ordinal = if sp.position > 0xFFFFF {
89 sp.position >> 20
90 } else {
91 sp.position
92 };
93 ordinals.insert(ordinal);
94 }
95 let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
96 ordinals.sort_unstable();
97 MatchedField {
98 field_id: *field_id,
99 ordinals,
100 }
101 })
102 .collect()
103 }
104
105 pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
107 self.positions
108 .iter()
109 .find(|(fid, _)| *fid == field_id)
110 .map(|(_, positions)| positions.as_slice())
111 }
112}
113
/// A single hit as exposed in a [`SearchResponse`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchHit {
    /// Address (segment id + document id) of the hit.
    pub address: DocAddress,
    /// Score of the hit.
    pub score: Score,
    /// Per-field matched ordinals.
    ///
    /// Omitted from serialized output when empty and defaulted to empty when missing on
    /// deserialization.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub matched_fields: Vec<MatchedField>,
}
124
/// Serializable search response: the returned hits plus a hit counter.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResponse {
    /// Hits included in this response.
    pub hits: Vec<SearchHit>,
    /// Hit count reported alongside `hits`.
    // NOTE(review): presumably the number of matching documents before truncation to the
    // requested limit — confirm with the code that fills this in.
    pub total_hits: u32,
}
131
132impl PartialEq for SearchResult {
133 fn eq(&self, other: &Self) -> bool {
134 self.doc_id == other.doc_id
135 }
136}
137
// Marker impl: `SearchResult` needs `Eq` in order to implement `Ord`.
impl Eq for SearchResult {}
139
140impl PartialOrd for SearchResult {
141 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
142 Some(self.cmp(other))
143 }
144}
145
146impl Ord for SearchResult {
147 fn cmp(&self, other: &Self) -> Ordering {
148 other
149 .score
150 .partial_cmp(&self.score)
151 .unwrap_or(Ordering::Equal)
152 .then_with(|| self.doc_id.cmp(&other.doc_id))
153 }
154}
155
/// Sink that receives matching documents one at a time while a segment is scanned.
pub trait Collector {
    /// Records one matching document with its score.
    ///
    /// `positions` has the same `(u32, Vec<ScoredPosition>)` shape as
    /// `SearchResult::positions`. `collect_segment_with_limit` passes an empty slice when
    /// [`Collector::needs_positions`] returns `false`.
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);

    /// Reports whether this collector wants position data; the default is `false`.
    fn needs_positions(&self) -> bool {
        false
    }
}
170
/// Collector that retains the `k` best-scoring documents it has been shown.
pub struct TopKCollector {
    /// Retained results; with `SearchResult`'s reversed ordering the heap's top is the
    /// lowest-scoring retained result.
    heap: BinaryHeap<SearchResult>,
    /// Maximum number of results to retain.
    k: usize,
    /// Whether positions are copied into retained results.
    collect_positions: bool,
    /// Number of documents passed to `collect`, including those not retained.
    total_seen: u32,
}
179
180impl TopKCollector {
181 pub fn new(k: usize) -> Self {
182 Self {
183 heap: BinaryHeap::with_capacity(k + 1),
184 k,
185 collect_positions: false,
186 total_seen: 0,
187 }
188 }
189
190 pub fn with_positions(k: usize) -> Self {
192 Self {
193 heap: BinaryHeap::with_capacity(k + 1),
194 k,
195 collect_positions: true,
196 total_seen: 0,
197 }
198 }
199
200 pub fn total_seen(&self) -> u32 {
202 self.total_seen
203 }
204
205 pub fn into_sorted_results(self) -> Vec<SearchResult> {
206 let mut results: Vec<_> = self.heap.into_vec();
207 results.sort_by(|a, b| {
208 b.score
209 .partial_cmp(&a.score)
210 .unwrap_or(Ordering::Equal)
211 .then_with(|| a.doc_id.cmp(&b.doc_id))
212 });
213 results
214 }
215
216 pub fn into_results_with_count(self) -> (Vec<SearchResult>, u32) {
218 let total = self.total_seen;
219 (self.into_sorted_results(), total)
220 }
221}
222
223impl Collector for TopKCollector {
224 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
225 self.total_seen += 1;
226
227 let dominated =
231 self.heap.len() >= self.k && self.heap.peek().is_some_and(|min| score <= min.score);
232 if dominated {
233 return;
234 }
235
236 let positions = if self.collect_positions {
237 positions.to_vec()
238 } else {
239 Vec::new()
240 };
241
242 if self.heap.len() < self.k {
243 self.heap.push(SearchResult {
244 doc_id,
245 score,
246 positions,
247 });
248 } else {
249 self.heap.pop();
250 self.heap.push(SearchResult {
251 doc_id,
252 score,
253 positions,
254 });
255 }
256 }
257
258 fn needs_positions(&self) -> bool {
259 self.collect_positions
260 }
261}
262
/// Collector that only counts how many documents it has been shown.
#[derive(Default)]
pub struct CountCollector {
    /// Number of documents collected so far.
    count: u64,
}

impl CountCollector {
    /// Creates a collector whose count starts at zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of documents collected so far.
    pub fn count(&self) -> u64 {
        let Self { count } = self;
        *count
    }
}
279
280impl Collector for CountCollector {
281 #[inline]
282 fn collect(
283 &mut self,
284 _doc_id: DocId,
285 _score: Score,
286 _positions: &[(u32, Vec<ScoredPosition>)],
287 ) {
288 self.count += 1;
289 }
290}
291
292pub async fn search_segment(
294 reader: &SegmentReader,
295 query: &dyn Query,
296 limit: usize,
297) -> Result<Vec<SearchResult>> {
298 let mut collector = TopKCollector::new(limit);
299 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
300 Ok(collector.into_sorted_results())
301}
302
303pub async fn search_segment_with_count(
305 reader: &SegmentReader,
306 query: &dyn Query,
307 limit: usize,
308) -> Result<(Vec<SearchResult>, u32)> {
309 let mut collector = TopKCollector::new(limit);
310 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
311 Ok(collector.into_results_with_count())
312}
313
314pub async fn search_segment_with_positions(
316 reader: &SegmentReader,
317 query: &dyn Query,
318 limit: usize,
319) -> Result<Vec<SearchResult>> {
320 let mut collector = TopKCollector::with_positions(limit);
321 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
322 Ok(collector.into_sorted_results())
323}
324
325pub async fn search_segment_with_positions_and_count(
327 reader: &SegmentReader,
328 query: &dyn Query,
329 limit: usize,
330) -> Result<(Vec<SearchResult>, u32)> {
331 let mut collector = TopKCollector::with_positions(limit);
332 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
333 Ok(collector.into_results_with_count())
334}
335
336pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
338 let mut collector = CountCollector::new();
339 collect_segment(reader, query, &mut collector).await?;
340 Ok(collector.count())
341}
342
343impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
345 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
346 self.0.collect(doc_id, score, positions);
347 self.1.collect(doc_id, score, positions);
348 }
349 fn needs_positions(&self) -> bool {
350 self.0.needs_positions() || self.1.needs_positions()
351 }
352}
353
354impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
356 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
357 self.0.collect(doc_id, score, positions);
358 self.1.collect(doc_id, score, positions);
359 self.2.collect(doc_id, score, positions);
360 }
361 fn needs_positions(&self) -> bool {
362 self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
363 }
364}
365
366pub async fn collect_segment<C: Collector>(
384 reader: &SegmentReader,
385 query: &dyn Query,
386 collector: &mut C,
387) -> Result<()> {
388 collect_segment_with_limit(reader, query, collector, usize::MAX / 2).await
390}
391
392pub async fn collect_segment_with_limit<C: Collector>(
401 reader: &SegmentReader,
402 query: &dyn Query,
403 collector: &mut C,
404 limit: usize,
405) -> Result<()> {
406 let needs_positions = collector.needs_positions();
407 let doc_id_offset = reader.doc_id_offset();
408 let mut scorer = query.scorer(reader, limit).await?;
409
410 let mut doc = scorer.doc();
411 while doc != TERMINATED {
412 if needs_positions {
414 let positions = scorer.matched_positions().unwrap_or_default();
415 collector.collect(doc + doc_id_offset, scorer.score(), &positions);
416 } else {
417 collector.collect(doc + doc_id_offset, scorer.score(), &[]);
418 }
419 doc = scorer.advance();
420 }
421
422 Ok(())
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    /// The collector keeps the three highest-scoring documents, best first.
    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);

        let results = collector.into_sorted_results();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3); // score 4.0
        assert_eq!(results[1].doc_id, 1); // score 3.0
        assert_eq!(results[2].doc_id, 2); // score 2.0
    }

    /// Every collected document is counted, regardless of score.
    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);

        assert_eq!(collector.count(), 3);
    }

    /// The tuple `Collector` impl forwards every document to both wrapped collectors.
    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();

        {
            // Drive both collectors through the `(&mut A, &mut B)` impl rather than
            // calling each one by hand, so the fan-out itself is what gets tested.
            let mut multi = (&mut top_k, &mut count);
            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                multi.collect(doc_id, score, &[]);
            }
        }

        // The counter saw every document.
        assert_eq!(count.count(), 5);

        // The top-k collector kept only the two best.
        let results = top_k.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3); // score 4.0
        assert_eq!(results[1].doc_id, 1); // score 3.0
    }
}