1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
12#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
14pub struct DocAddress {
15 pub segment_id: String,
17 pub doc_id: DocId,
19}
20
21impl DocAddress {
22 pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23 Self {
24 segment_id: format!("{:032x}", segment_id),
25 doc_id,
26 }
27 }
28
29 pub fn segment_id_u128(&self) -> Option<u128> {
31 u128::from_str_radix(&self.segment_id, 16).ok()
32 }
33}
34
35#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
39pub struct ScoredPosition {
40 pub position: u32,
42 pub score: f32,
44}
45
46impl ScoredPosition {
47 pub fn new(position: u32, score: f32) -> Self {
48 Self { position, score }
49 }
50}
51
52#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
54pub struct SearchResult {
55 pub doc_id: DocId,
56 pub score: Score,
57 #[serde(default, skip_serializing_if = "Vec::is_empty")]
60 pub positions: Vec<(u32, Vec<ScoredPosition>)>,
61}
62
63#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
65pub struct MatchedField {
66 pub field_id: u32,
68 pub ordinals: Vec<u32>,
71}
72
73impl SearchResult {
74 pub fn extract_ordinals(&self) -> Vec<MatchedField> {
78 use rustc_hash::FxHashSet;
79
80 self.positions
81 .iter()
82 .map(|(field_id, scored_positions)| {
83 let mut ordinals: FxHashSet<u32> = FxHashSet::default();
84 for sp in scored_positions {
85 let ordinal = if sp.position > 0xFFFFF {
89 sp.position >> 20
90 } else {
91 sp.position
92 };
93 ordinals.insert(ordinal);
94 }
95 let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
96 ordinals.sort_unstable();
97 MatchedField {
98 field_id: *field_id,
99 ordinals,
100 }
101 })
102 .collect()
103 }
104
105 pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
107 self.positions
108 .iter()
109 .find(|(fid, _)| *fid == field_id)
110 .map(|(_, positions)| positions.as_slice())
111 }
112}
113
114#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
116pub struct SearchHit {
117 pub address: DocAddress,
119 pub score: Score,
120 #[serde(default, skip_serializing_if = "Vec::is_empty")]
122 pub matched_fields: Vec<MatchedField>,
123}
124
125#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
127pub struct SearchResponse {
128 pub hits: Vec<SearchHit>,
129 pub total_hits: u32,
130}
131
132impl PartialEq for SearchResult {
133 fn eq(&self, other: &Self) -> bool {
134 self.doc_id == other.doc_id
135 }
136}
137
138impl Eq for SearchResult {}
139
140impl PartialOrd for SearchResult {
141 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
142 Some(self.cmp(other))
143 }
144}
145
146impl Ord for SearchResult {
147 fn cmp(&self, other: &Self) -> Ordering {
148 other
149 .score
150 .partial_cmp(&self.score)
151 .unwrap_or(Ordering::Equal)
152 .then_with(|| self.doc_id.cmp(&other.doc_id))
153 }
154}
155
156pub trait Collector {
161 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);
164
165 fn needs_positions(&self) -> bool {
167 false
168 }
169}
170
171pub struct TopKCollector {
173 heap: BinaryHeap<SearchResult>,
174 k: usize,
175 collect_positions: bool,
176 total_seen: u32,
178}
179
180impl TopKCollector {
181 pub fn new(k: usize) -> Self {
182 Self {
183 heap: BinaryHeap::with_capacity(k + 1),
184 k,
185 collect_positions: false,
186 total_seen: 0,
187 }
188 }
189
190 pub fn with_positions(k: usize) -> Self {
192 Self {
193 heap: BinaryHeap::with_capacity(k + 1),
194 k,
195 collect_positions: true,
196 total_seen: 0,
197 }
198 }
199
200 pub fn total_seen(&self) -> u32 {
202 self.total_seen
203 }
204
205 pub fn into_sorted_results(self) -> Vec<SearchResult> {
206 let mut results: Vec<_> = self.heap.into_vec();
207 results.sort_by(|a, b| {
208 b.score
209 .partial_cmp(&a.score)
210 .unwrap_or(Ordering::Equal)
211 .then_with(|| a.doc_id.cmp(&b.doc_id))
212 });
213 results
214 }
215
216 pub fn into_results_with_count(self) -> (Vec<SearchResult>, u32) {
218 let total = self.total_seen;
219 (self.into_sorted_results(), total)
220 }
221}
222
223impl Collector for TopKCollector {
224 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
225 self.total_seen += 1;
226
227 let positions = if self.collect_positions {
228 positions.to_vec()
229 } else {
230 Vec::new()
231 };
232
233 if self.heap.len() < self.k {
234 self.heap.push(SearchResult {
235 doc_id,
236 score,
237 positions,
238 });
239 } else if let Some(min) = self.heap.peek()
240 && score > min.score
241 {
242 self.heap.pop();
243 self.heap.push(SearchResult {
244 doc_id,
245 score,
246 positions,
247 });
248 }
249 }
250
251 fn needs_positions(&self) -> bool {
252 self.collect_positions
253 }
254}
255
/// Collector that only counts the documents it is shown.
#[derive(Default)]
pub struct CountCollector {
    /// Number of `collect` calls made so far.
    count: u64,
}
261
262impl CountCollector {
263 pub fn new() -> Self {
264 Self { count: 0 }
265 }
266
267 pub fn count(&self) -> u64 {
269 self.count
270 }
271}
272
273impl Collector for CountCollector {
274 #[inline]
275 fn collect(
276 &mut self,
277 _doc_id: DocId,
278 _score: Score,
279 _positions: &[(u32, Vec<ScoredPosition>)],
280 ) {
281 self.count += 1;
282 }
283}
284
285pub async fn search_segment(
287 reader: &SegmentReader,
288 query: &dyn Query,
289 limit: usize,
290) -> Result<Vec<SearchResult>> {
291 let mut collector = TopKCollector::new(limit);
292 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
293 Ok(collector.into_sorted_results())
294}
295
296pub async fn search_segment_with_count(
298 reader: &SegmentReader,
299 query: &dyn Query,
300 limit: usize,
301) -> Result<(Vec<SearchResult>, u32)> {
302 let mut collector = TopKCollector::new(limit);
303 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
304 Ok(collector.into_results_with_count())
305}
306
307pub async fn search_segment_with_positions(
309 reader: &SegmentReader,
310 query: &dyn Query,
311 limit: usize,
312) -> Result<Vec<SearchResult>> {
313 let mut collector = TopKCollector::with_positions(limit);
314 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
315 Ok(collector.into_sorted_results())
316}
317
318pub async fn search_segment_with_positions_and_count(
320 reader: &SegmentReader,
321 query: &dyn Query,
322 limit: usize,
323) -> Result<(Vec<SearchResult>, u32)> {
324 let mut collector = TopKCollector::with_positions(limit);
325 collect_segment_with_limit(reader, query, &mut collector, limit).await?;
326 Ok(collector.into_results_with_count())
327}
328
329pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
331 let mut collector = CountCollector::new();
332 collect_segment(reader, query, &mut collector).await?;
333 Ok(collector.count())
334}
335
336impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
338 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
339 self.0.collect(doc_id, score, positions);
340 self.1.collect(doc_id, score, positions);
341 }
342 fn needs_positions(&self) -> bool {
343 self.0.needs_positions() || self.1.needs_positions()
344 }
345}
346
347impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
349 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
350 self.0.collect(doc_id, score, positions);
351 self.1.collect(doc_id, score, positions);
352 self.2.collect(doc_id, score, positions);
353 }
354 fn needs_positions(&self) -> bool {
355 self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
356 }
357}
358
359pub async fn collect_segment<C: Collector>(
377 reader: &SegmentReader,
378 query: &dyn Query,
379 collector: &mut C,
380) -> Result<()> {
381 collect_segment_with_limit(reader, query, collector, usize::MAX / 2).await
383}
384
385pub async fn collect_segment_with_limit<C: Collector>(
391 reader: &SegmentReader,
392 query: &dyn Query,
393 collector: &mut C,
394 limit: usize,
395) -> Result<()> {
396 let needs_positions = collector.needs_positions();
397 let mut scorer = query.scorer(reader, limit).await?;
398
399 let mut doc = scorer.doc();
400 while doc != TERMINATED {
401 let positions = if needs_positions {
402 scorer.matched_positions().unwrap_or_default()
403 } else {
404 Vec::new()
405 };
406 collector.collect(doc, scorer.score(), &positions);
407 doc = scorer.advance();
408 }
409
410 Ok(())
411}
412
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the three best of five documents survive, highest score first.
    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);

        let results = collector.into_sorted_results();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
        assert_eq!(results[2].doc_id, 2);
    }

    /// Every collected document is counted.
    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);

        assert_eq!(collector.count(), 3);
    }

    /// Drives a top-k and a count collector through the tuple `Collector`
    /// impl, so the fan-out itself is exercised rather than each collector
    /// being called on its own.
    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();

        {
            let mut multi = (&mut top_k, &mut count);
            assert!(!multi.needs_positions());

            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                multi.collect(doc_id, score, &[]);
            }
        }

        assert_eq!(count.count(), 5);
        assert_eq!(top_k.total_seen(), 5);

        let results = top_k.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
    }
}