1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash)]
16pub struct DocAddress {
17 segment_id_raw: u128,
19 pub doc_id: DocId,
21}
22
23impl DocAddress {
24 pub fn new(segment_id: u128, doc_id: DocId) -> Self {
25 Self {
26 segment_id_raw: segment_id,
27 doc_id,
28 }
29 }
30
31 pub fn segment_id(&self) -> String {
33 format!("{:032x}", self.segment_id_raw)
34 }
35
36 pub fn segment_id_u128(&self) -> Option<u128> {
38 Some(self.segment_id_raw)
39 }
40}
41
42impl serde::Serialize for DocAddress {
43 fn serialize<S: serde::Serializer>(
44 &self,
45 serializer: S,
46 ) -> std::result::Result<S::Ok, S::Error> {
47 use serde::ser::SerializeStruct;
48 let mut s = serializer.serialize_struct("DocAddress", 2)?;
49 s.serialize_field("segment_id", &format!("{:032x}", self.segment_id_raw))?;
50 s.serialize_field("doc_id", &self.doc_id)?;
51 s.end()
52 }
53}
54
55impl<'de> serde::Deserialize<'de> for DocAddress {
56 fn deserialize<D: serde::Deserializer<'de>>(
57 deserializer: D,
58 ) -> std::result::Result<Self, D::Error> {
59 #[derive(serde::Deserialize)]
60 struct Helper {
61 segment_id: String,
62 doc_id: DocId,
63 }
64 let h = Helper::deserialize(deserializer)?;
65 let raw = u128::from_str_radix(&h.segment_id, 16).map_err(serde::de::Error::custom)?;
66 Ok(DocAddress {
67 segment_id_raw: raw,
68 doc_id: h.doc_id,
69 })
70 }
71}
72
/// One matched position of a document together with the score attributed to it.
///
/// NOTE(review): `SearchResult::extract_ordinals` treats `position` values above
/// `0xFFFFF` as `ordinal << 20 | low 20 bits`; confirm that is the writer's encoding.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct ScoredPosition {
    /// Matched position value (see the encoding note above).
    pub position: u32,
    /// Score attributed to this position.
    pub score: f32,
}
83
84impl ScoredPosition {
85 pub fn new(position: u32, score: f32) -> Self {
86 Self { position, score }
87 }
88}
89
/// A scored hit as produced by a collector such as `TopKCollector`.
///
/// The hand-written `Ord` impl sorts results best-first (descending score).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResult {
    /// Document id of the hit.
    pub doc_id: DocId,
    /// Score assigned by the query's scorer.
    pub score: Score,
    /// Raw segment id. Omitted from serialized output when zero and defaults to zero when
    /// absent on input. `TopKCollector` always stores 0 here; presumably callers fill it in
    /// when merging segments — confirm.
    #[serde(default, skip_serializing_if = "is_zero_u128")]
    pub segment_id: u128,
    /// Matched positions grouped as `(field_id, positions)` pairs; omitted when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub positions: Vec<(u32, Vec<ScoredPosition>)>,
}
103
/// Serde `skip_serializing_if` predicate: reports whether the segment id is unset (zero).
fn is_zero_u128(v: &u128) -> bool {
    matches!(*v, 0)
}
107
/// The ordinals of one field that contributed to a hit.
///
/// NOTE(review): "ordinal" presumably indexes a value of a multi-valued field — confirm.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct MatchedField {
    /// Field identifier.
    pub field_id: u32,
    /// Matched ordinals; `SearchResult::extract_ordinals` emits them sorted and deduplicated.
    pub ordinals: Vec<u32>,
}
117
118impl SearchResult {
119 pub fn extract_ordinals(&self) -> Vec<MatchedField> {
123 use rustc_hash::FxHashSet;
124
125 self.positions
126 .iter()
127 .map(|(field_id, scored_positions)| {
128 let mut ordinals: FxHashSet<u32> = FxHashSet::default();
129 for sp in scored_positions {
130 let ordinal = if sp.position > 0xFFFFF {
134 sp.position >> 20
135 } else {
136 sp.position
137 };
138 ordinals.insert(ordinal);
139 }
140 let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
141 ordinals.sort_unstable();
142 MatchedField {
143 field_id: *field_id,
144 ordinals,
145 }
146 })
147 .collect()
148 }
149
150 pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
152 self.positions
153 .iter()
154 .find(|(fid, _)| *fid == field_id)
155 .map(|(_, positions)| positions.as_slice())
156 }
157}
158
/// A search hit: where the document lives, its score, and which fields matched.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchHit {
    /// Segment and document id of the hit.
    pub address: DocAddress,
    /// Score of the hit.
    pub score: Score,
    /// Per-field matched ordinals; omitted from serialized output when empty and defaults to
    /// empty on input.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub matched_fields: Vec<MatchedField>,
}
169
/// A set of hits plus a total hit count.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResponse {
    /// The returned hits.
    pub hits: Vec<SearchHit>,
    /// Total number of matching documents. This is presumably counted before truncation to
    /// `hits` (compare `TopKCollector::total_seen`) — confirm with the producer.
    pub total_hits: u32,
}
176
177impl PartialEq for SearchResult {
178 fn eq(&self, other: &Self) -> bool {
179 self.segment_id == other.segment_id && self.doc_id == other.doc_id
180 }
181}
182
183impl Eq for SearchResult {}
184
185impl PartialOrd for SearchResult {
186 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
187 Some(self.cmp(other))
188 }
189}
190
191impl Ord for SearchResult {
192 fn cmp(&self, other: &Self) -> Ordering {
193 other
194 .score
195 .partial_cmp(&self.score)
196 .unwrap_or(Ordering::Equal)
197 .then_with(|| self.segment_id.cmp(&other.segment_id))
198 .then_with(|| self.doc_id.cmp(&other.doc_id))
199 }
200}
201
202pub trait Collector {
207 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);
210
211 fn needs_positions(&self) -> bool {
213 false
214 }
215}
216
/// Collector that retains the `k` best-scoring hits it is offered.
pub struct TopKCollector {
    /// Retained hits. `SearchResult`'s `Ord` reverses the score comparison, so the max-heap's
    /// root is the worst retained hit, i.e. the eviction candidate.
    heap: BinaryHeap<SearchResult>,
    /// Maximum number of hits to retain.
    k: usize,
    /// Whether matched positions are copied into retained hits.
    collect_positions: bool,
    /// Number of hits offered to `collect`, including rejected ones (saturating).
    total_seen: u32,
}
225
226impl TopKCollector {
227 pub fn new(k: usize) -> Self {
228 Self {
229 heap: BinaryHeap::with_capacity(k + 1),
230 k,
231 collect_positions: false,
232 total_seen: 0,
233 }
234 }
235
236 pub fn with_positions(k: usize) -> Self {
238 Self {
239 heap: BinaryHeap::with_capacity(k + 1),
240 k,
241 collect_positions: true,
242 total_seen: 0,
243 }
244 }
245
246 pub fn total_seen(&self) -> u32 {
248 self.total_seen
249 }
250
251 pub fn into_sorted_results(self) -> Vec<SearchResult> {
252 let mut results: Vec<_> = self.heap.into_vec();
253 results.sort_by(|a, b| {
254 b.score
255 .partial_cmp(&a.score)
256 .unwrap_or(Ordering::Equal)
257 .then_with(|| a.doc_id.cmp(&b.doc_id))
258 });
259 results
260 }
261
262 pub fn into_results_with_count(self) -> (Vec<SearchResult>, u32) {
264 let total = self.total_seen;
265 (self.into_sorted_results(), total)
266 }
267}
268
269impl Collector for TopKCollector {
270 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
271 self.total_seen = self.total_seen.saturating_add(1);
272
273 let dominated =
277 self.heap.len() >= self.k && self.heap.peek().is_some_and(|min| score <= min.score);
278 if dominated {
279 return;
280 }
281
282 let positions = if self.collect_positions {
283 positions.to_vec()
284 } else {
285 Vec::new()
286 };
287
288 if self.heap.len() >= self.k {
289 self.heap.pop();
290 }
291 self.heap.push(SearchResult {
292 doc_id,
293 score,
294 segment_id: 0,
295 positions,
296 });
297 }
298
299 fn needs_positions(&self) -> bool {
300 self.collect_positions
301 }
302}
303
/// Collector that only counts the documents it is offered.
#[derive(Default)]
pub struct CountCollector {
    /// Number of `collect` calls so far.
    count: u64,
}
309
310impl CountCollector {
311 pub fn new() -> Self {
312 Self { count: 0 }
313 }
314
315 pub fn count(&self) -> u64 {
317 self.count
318 }
319}
320
321impl Collector for CountCollector {
322 #[inline]
323 fn collect(
324 &mut self,
325 _doc_id: DocId,
326 _score: Score,
327 _positions: &[(u32, Vec<ScoredPosition>)],
328 ) {
329 self.count += 1;
330 }
331}
332
333pub async fn search_segment_with_count(
335 reader: &SegmentReader,
336 query: &dyn Query,
337 limit: usize,
338) -> Result<(Vec<SearchResult>, u32)> {
339 let mut collector = TopKCollector::new(limit);
340 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
341 Ok(collector.into_results_with_count())
342}
343
344pub async fn search_segment_with_positions_and_count(
346 reader: &SegmentReader,
347 query: &dyn Query,
348 limit: usize,
349) -> Result<(Vec<SearchResult>, u32)> {
350 let mut collector = TopKCollector::with_positions(limit);
351 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
352 Ok(collector.into_results_with_count())
353}
354
355impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
357 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
358 self.0.collect(doc_id, score, positions);
359 self.1.collect(doc_id, score, positions);
360 }
361 fn needs_positions(&self) -> bool {
362 self.0.needs_positions() || self.1.needs_positions()
363 }
364}
365
366impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
368 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
369 self.0.collect(doc_id, score, positions);
370 self.1.collect(doc_id, score, positions);
371 self.2.collect(doc_id, score, positions);
372 }
373 fn needs_positions(&self) -> bool {
374 self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
375 }
376}
377
378pub async fn collect_segment<C: Collector>(
396 reader: &SegmentReader,
397 query: &dyn Query,
398 collector: &mut C,
399) -> Result<()> {
400 collect_segment_with_limit(reader, query, collector, usize::MAX / 2).await
402}
403
404pub async fn collect_segment_with_limit<C: Collector>(
413 reader: &SegmentReader,
414 query: &dyn Query,
415 collector: &mut C,
416 limit: usize,
417) -> Result<()> {
418 let mut scorer = query.scorer(reader, limit).await?;
419 drive_scorer(scorer.as_mut(), collector);
420 Ok(())
421}
422
423fn drive_scorer<C: Collector>(scorer: &mut dyn super::Scorer, collector: &mut C) {
425 let needs_positions = collector.needs_positions();
426 let mut doc = scorer.doc();
427 while doc != TERMINATED {
428 if needs_positions {
429 let positions = scorer.matched_positions().unwrap_or_default();
430 collector.collect(doc, scorer.score(), &positions);
431 } else {
432 collector.collect(doc, scorer.score(), &[]);
433 }
434 doc = scorer.advance();
435 }
436}
437
#[cfg(feature = "sync")]
/// Synchronous counterpart of [`search_segment_with_count`]: runs `query` through a
/// [`TopKCollector`] of size `limit` and returns best-first hits plus the offered count.
///
/// # Errors
/// Propagates any error from building the scorer.
pub fn search_segment_with_count_sync(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<(Vec<SearchResult>, u32)> {
    let mut top_k = TopKCollector::new(limit);
    collect_segment_with_limit_sync(reader, query, &mut top_k, limit)?;
    let (hits, seen) = top_k.into_results_with_count();
    Ok((hits, seen))
}
451
#[cfg(feature = "sync")]
/// Synchronous counterpart of [`search_segment_with_positions_and_count`]: like
/// [`search_segment_with_count_sync`], but retained hits also carry matched positions.
///
/// # Errors
/// Propagates any error from building the scorer.
pub fn search_segment_with_positions_and_count_sync(
    reader: &SegmentReader,
    query: &dyn Query,
    limit: usize,
) -> Result<(Vec<SearchResult>, u32)> {
    let mut top_k = TopKCollector::with_positions(limit);
    collect_segment_with_limit_sync(reader, query, &mut top_k, limit)?;
    let (hits, seen) = top_k.into_results_with_count();
    Ok((hits, seen))
}
463
#[cfg(feature = "sync")]
/// Synchronous counterpart of [`collect_segment_with_limit`]: builds the scorer with
/// `Query::scorer_sync` and feeds every document it yields to `collector`.
///
/// # Errors
/// Propagates any error from building the scorer.
pub fn collect_segment_with_limit_sync<C: Collector>(
    reader: &SegmentReader,
    query: &dyn Query,
    collector: &mut C,
    limit: usize,
) -> Result<()> {
    let mut doc_scorer = query.scorer_sync(reader, limit)?;
    drive_scorer(doc_scorer.as_mut(), collector);
    Ok(())
}
476
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);

        let results = collector.into_sorted_results();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
        assert_eq!(results[2].doc_id, 2);
    }

    #[test]
    fn test_top_k_total_seen_counts_rejected_hits() {
        let mut collector = TopKCollector::new(2);
        for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
            collector.collect(doc_id, score, &[]);
        }
        assert_eq!(collector.total_seen(), 5);

        let (results, total) = collector.into_results_with_count();
        assert_eq!(total, 5);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_top_k_rejects_tie_with_worst_when_full() {
        let mut collector = TopKCollector::new(1);
        collector.collect(7, 2.0, &[]);
        collector.collect(8, 2.0, &[]);

        let results = collector.into_sorted_results();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, 7);
    }

    #[test]
    fn test_top_k_positions_only_when_requested() {
        let positions = vec![(1u32, vec![ScoredPosition::new(4, 0.5)])];

        let mut plain = TopKCollector::new(1);
        assert!(!plain.needs_positions());
        plain.collect(0, 1.0, &positions);
        assert!(plain.into_sorted_results()[0].positions.is_empty());

        let mut with_positions = TopKCollector::with_positions(1);
        assert!(with_positions.needs_positions());
        with_positions.collect(0, 1.0, &positions);
        let results = with_positions.into_sorted_results();
        assert_eq!(results[0].field_positions(1).map(|p| p.len()), Some(1));
    }

    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);

        assert_eq!(collector.count(), 3);
    }

    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();

        {
            // Drive both collectors through the tuple `Collector` impl under test.
            let mut both = (&mut top_k, &mut count);
            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                both.collect(doc_id, score, &[]);
            }
            assert!(!both.needs_positions());
        }

        assert_eq!(count.count(), 5);

        let results = top_k.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
    }

    #[test]
    fn test_triple_collector_needs_positions_if_any_member_does() {
        let mut first = CountCollector::new();
        let mut second = CountCollector::new();
        let mut third = TopKCollector::with_positions(1);

        {
            let mut all = (&mut first, &mut second, &mut third);
            assert!(all.needs_positions());
            all.collect(0, 1.0, &[]);
        }

        assert_eq!(first.count(), 1);
        assert_eq!(second.count(), 1);
        assert_eq!(third.total_seen(), 1);
    }

    #[test]
    fn test_doc_address_segment_id_is_zero_padded_hex() {
        let addr = DocAddress::new(0xff, 9);
        assert_eq!(addr.segment_id(), format!("{}ff", "0".repeat(30)));
        assert_eq!(addr.segment_id_u128(), Some(0xff));
        assert_eq!(addr.doc_id, 9);
    }

    #[test]
    fn test_extract_ordinals_sorts_and_dedups() {
        let result = SearchResult {
            doc_id: 0,
            score: 1.0,
            segment_id: 0,
            positions: vec![(
                2,
                vec![
                    ScoredPosition::new((3 << 20) | 7, 1.0),
                    ScoredPosition::new(1 << 20, 1.0),
                    ScoredPosition::new((3 << 20) | 1, 1.0),
                ],
            )],
        };

        let fields = result.extract_ordinals();
        assert_eq!(fields.len(), 1);
        assert_eq!(fields[0].field_id, 2);
        assert_eq!(fields[0].ordinals, vec![1, 3]);
    }
}