1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
/// Identifies a document by the segment it belongs to plus a [`DocId`].
///
/// The segment id is stored as text; [`DocAddress::new`] produces it as the
/// 32-digit, zero-padded, lowercase hexadecimal form of a `u128`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct DocAddress {
    /// Hex-encoded segment identifier (see [`DocAddress::new`]).
    pub segment_id: String,
    /// Document id. NOTE(review): presumably local to the segment above;
    /// confirm against the callers that build addresses.
    pub doc_id: DocId,
}
20
21impl DocAddress {
22 pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23 Self {
24 segment_id: format!("{:032x}", segment_id),
25 doc_id,
26 }
27 }
28
29 pub fn segment_id_u128(&self) -> Option<u128> {
31 u128::from_str_radix(&self.segment_id, 16).ok()
32 }
33}
34
/// A matched position paired with the score attributed to it.
#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)]
pub struct ScoredPosition {
    /// Position value. `SearchResult::extract_ordinals` treats values above
    /// `0xFFFFF` as packed and reads an ordinal from the bits above the low 20.
    /// NOTE(review): the exact encoding is defined elsewhere; confirm there.
    pub position: u32,
    /// Score attached to this position.
    pub score: f32,
}
45
46impl ScoredPosition {
47 pub fn new(position: u32, score: f32) -> Self {
48 Self { position, score }
49 }
50}
51
/// A document matched by a query, with its score and optional match positions.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResult {
    /// Id of the matched document.
    pub doc_id: DocId,
    /// Score assigned to the document.
    pub score: Score,
    /// Matched positions grouped as `(field_id, positions)` pairs.
    ///
    /// Omitted from the serialized form when empty, and filled with an empty
    /// vector when the key is missing on deserialization.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub positions: Vec<(u32, Vec<ScoredPosition>)>,
}
62
/// The ordinals that matched inside a single field.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct MatchedField {
    /// Id of the field the ordinals belong to.
    pub field_id: u32,
    /// Matched ordinals. `SearchResult::extract_ordinals` fills this with
    /// unique values in ascending order.
    pub ordinals: Vec<u32>,
}
72
73impl SearchResult {
74 pub fn extract_ordinals(&self) -> Vec<MatchedField> {
78 use rustc_hash::FxHashSet;
79
80 self.positions
81 .iter()
82 .map(|(field_id, scored_positions)| {
83 let mut ordinals: FxHashSet<u32> = FxHashSet::default();
84 for sp in scored_positions {
85 let ordinal = if sp.position > 0xFFFFF {
89 sp.position >> 20
90 } else {
91 sp.position
92 };
93 ordinals.insert(ordinal);
94 }
95 let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
96 ordinals.sort_unstable();
97 MatchedField {
98 field_id: *field_id,
99 ordinals,
100 }
101 })
102 .collect()
103 }
104
105 pub fn field_positions(&self, field_id: u32) -> Option<&[ScoredPosition]> {
107 self.positions
108 .iter()
109 .find(|(fid, _)| *fid == field_id)
110 .map(|(_, positions)| positions.as_slice())
111 }
112}
113
/// A single hit as returned to the caller: where the document lives, how it
/// scored, and which fields matched.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchHit {
    /// Address of the matched document.
    pub address: DocAddress,
    /// Score of the hit.
    pub score: Score,
    /// Per-field match information. Omitted from the serialized form when
    /// empty, and defaulted to empty when missing on deserialization.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub matched_fields: Vec<MatchedField>,
}
124
/// Response payload of a search: the returned hits plus a hit count.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResponse {
    /// Hits included in this response.
    pub hits: Vec<SearchHit>,
    /// Total hit count. NOTE(review): presumably the number of matching
    /// documents overall, which may exceed `hits.len()` — confirm with the
    /// code that fills it.
    pub total_hits: u32,
}
131
132impl PartialEq for SearchResult {
133 fn eq(&self, other: &Self) -> bool {
134 self.doc_id == other.doc_id
135 }
136}
137
138impl Eq for SearchResult {}
139
140impl PartialOrd for SearchResult {
141 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
142 Some(self.cmp(other))
143 }
144}
145
146impl Ord for SearchResult {
147 fn cmp(&self, other: &Self) -> Ordering {
148 other
149 .score
150 .partial_cmp(&self.score)
151 .unwrap_or(Ordering::Equal)
152 .then_with(|| self.doc_id.cmp(&other.doc_id))
153 }
154}
155
/// Sink for matching documents produced while scanning a segment.
pub trait Collector {
    /// Receives one matching document with its score.
    ///
    /// `positions` holds `(field_id, positions)` pairs. `collect_segment`
    /// passes an empty slice here unless [`Collector::needs_positions`]
    /// returned `true`.
    fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]);

    /// Whether this collector wants matched positions to be gathered.
    ///
    /// Defaults to `false`.
    fn needs_positions(&self) -> bool {
        false
    }
}
170
/// Collector that retains at most `k` results, preferring higher scores.
pub struct TopKCollector {
    /// Retained results. `SearchResult`'s ordering is reversed on score, so the
    /// top of this heap is the lowest-scoring retained result.
    heap: BinaryHeap<SearchResult>,
    /// Maximum number of results to retain.
    k: usize,
    /// Whether positions passed to `collect` are copied into the stored results.
    collect_positions: bool,
}
177
178impl TopKCollector {
179 pub fn new(k: usize) -> Self {
180 Self {
181 heap: BinaryHeap::with_capacity(k + 1),
182 k,
183 collect_positions: false,
184 }
185 }
186
187 pub fn with_positions(k: usize) -> Self {
189 Self {
190 heap: BinaryHeap::with_capacity(k + 1),
191 k,
192 collect_positions: true,
193 }
194 }
195
196 pub fn into_sorted_results(self) -> Vec<SearchResult> {
197 let mut results: Vec<_> = self.heap.into_vec();
198 results.sort_by(|a, b| {
199 b.score
200 .partial_cmp(&a.score)
201 .unwrap_or(Ordering::Equal)
202 .then_with(|| a.doc_id.cmp(&b.doc_id))
203 });
204 results
205 }
206}
207
208impl Collector for TopKCollector {
209 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
210 let positions = if self.collect_positions {
211 positions.to_vec()
212 } else {
213 Vec::new()
214 };
215
216 if self.heap.len() < self.k {
217 self.heap.push(SearchResult {
218 doc_id,
219 score,
220 positions,
221 });
222 } else if let Some(min) = self.heap.peek()
223 && score > min.score
224 {
225 self.heap.pop();
226 self.heap.push(SearchResult {
227 doc_id,
228 score,
229 positions,
230 });
231 }
232 }
233
234 fn needs_positions(&self) -> bool {
235 self.collect_positions
236 }
237}
238
/// Collector that only counts how many documents it was handed.
#[derive(Default)]
pub struct CountCollector {
    /// Number of `collect` calls seen so far.
    count: u64,
}
244
245impl CountCollector {
246 pub fn new() -> Self {
247 Self { count: 0 }
248 }
249
250 pub fn count(&self) -> u64 {
252 self.count
253 }
254}
255
256impl Collector for CountCollector {
257 #[inline]
258 fn collect(
259 &mut self,
260 _doc_id: DocId,
261 _score: Score,
262 _positions: &[(u32, Vec<ScoredPosition>)],
263 ) {
264 self.count += 1;
265 }
266}
267
268pub async fn search_segment(
270 reader: &SegmentReader,
271 query: &dyn Query,
272 limit: usize,
273) -> Result<Vec<SearchResult>> {
274 let mut collector = TopKCollector::new(limit);
275 collect_segment(reader, query, &mut collector).await?;
276 Ok(collector.into_sorted_results())
277}
278
279pub async fn search_segment_with_positions(
281 reader: &SegmentReader,
282 query: &dyn Query,
283 limit: usize,
284) -> Result<Vec<SearchResult>> {
285 let mut collector = TopKCollector::with_positions(limit);
286 collect_segment(reader, query, &mut collector).await?;
287 Ok(collector.into_sorted_results())
288}
289
290pub async fn count_segment(reader: &SegmentReader, query: &dyn Query) -> Result<u64> {
292 let mut collector = CountCollector::new();
293 collect_segment(reader, query, &mut collector).await?;
294 Ok(collector.count())
295}
296
297impl<A: Collector, B: Collector> Collector for (&mut A, &mut B) {
299 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
300 self.0.collect(doc_id, score, positions);
301 self.1.collect(doc_id, score, positions);
302 }
303 fn needs_positions(&self) -> bool {
304 self.0.needs_positions() || self.1.needs_positions()
305 }
306}
307
308impl<A: Collector, B: Collector, C: Collector> Collector for (&mut A, &mut B, &mut C) {
310 fn collect(&mut self, doc_id: DocId, score: Score, positions: &[(u32, Vec<ScoredPosition>)]) {
311 self.0.collect(doc_id, score, positions);
312 self.1.collect(doc_id, score, positions);
313 self.2.collect(doc_id, score, positions);
314 }
315 fn needs_positions(&self) -> bool {
316 self.0.needs_positions() || self.1.needs_positions() || self.2.needs_positions()
317 }
318}
319
320pub async fn collect_segment<C: Collector>(
334 reader: &SegmentReader,
335 query: &dyn Query,
336 collector: &mut C,
337) -> Result<()> {
338 let needs_positions = collector.needs_positions();
339 let mut scorer = query.scorer(reader, usize::MAX / 2).await?;
341
342 let mut doc = scorer.doc();
343 while doc != TERMINATED {
344 let positions = if needs_positions {
345 scorer.matched_positions().unwrap_or_default()
346 } else {
347 Vec::new()
348 };
349 collector.collect(doc, scorer.score(), &positions);
350 doc = scorer.advance();
351 }
352
353 Ok(())
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the three best-scoring documents survive, best first.
    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 3.0, &[]);
        collector.collect(2, 2.0, &[]);
        collector.collect(3, 4.0, &[]);
        collector.collect(4, 0.5, &[]);

        let results = collector.into_sorted_results();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
        assert_eq!(results[2].doc_id, 2);
    }

    /// Every collected document is counted, regardless of score.
    #[test]
    fn test_count_collector() {
        let mut collector = CountCollector::new();

        collector.collect(0, 1.0, &[]);
        collector.collect(1, 2.0, &[]);
        collector.collect(2, 3.0, &[]);

        assert_eq!(collector.count(), 3);
    }

    /// Drives both collectors through the tuple `Collector` impl, so the
    /// forwarding logic itself is what gets exercised.
    #[test]
    fn test_multi_collector() {
        let mut top_k = TopKCollector::new(2);
        let mut count = CountCollector::new();

        {
            let mut both = (&mut top_k, &mut count);
            assert!(!both.needs_positions());

            for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
                both.collect(doc_id, score, &[]);
            }
        }

        assert_eq!(count.count(), 5);

        let results = top_k.into_sorted_results();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].doc_id, 3);
        assert_eq!(results[1].doc_id, 1);
    }
}