1use std::cmp::Ordering;
4use std::collections::BinaryHeap;
5
6use crate::segment::SegmentReader;
7use crate::structures::TERMINATED;
8use crate::{DocId, Result, Score};
9
10use super::Query;
11
/// Pairs a segment identifier with a document id.
#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub struct DocAddress {
    /// Segment identifier as text. [`DocAddress::new`] writes it as 32
    /// zero-padded lowercase hexadecimal digits of a `u128`; the field is
    /// public, so other contents are possible when set directly.
    pub segment_id: String,
    /// Document id (presumably local to the segment — confirm against callers).
    pub doc_id: DocId,
}
20
21impl DocAddress {
22 pub fn new(segment_id: u128, doc_id: DocId) -> Self {
23 Self {
24 segment_id: format!("{:032x}", segment_id),
25 doc_id,
26 }
27 }
28
29 pub fn segment_id_u128(&self) -> Option<u128> {
31 u128::from_str_radix(&self.segment_id, 16).ok()
32 }
33}
34
/// A scored document, optionally carrying the positions where it matched.
///
/// Equality and ordering are not derived; they are implemented by hand
/// elsewhere in this file.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResult {
    /// Id of the matching document.
    pub doc_id: DocId,
    /// Score assigned to the document.
    pub score: Score,
    /// Match positions as `(field_id, encoded positions)` pairs; see
    /// `extract_ordinals` for how the encoded values are read.
    /// Defaults to empty when absent on deserialization and is omitted from
    /// serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub positions: Vec<(u32, Vec<u32>)>,
}
44
/// The ordinals within one field at which a document matched.
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct MatchedField {
    /// Id of the field the ordinals belong to.
    pub field_id: u32,
    /// Matched ordinals; `SearchResult::extract_ordinals` produces them
    /// de-duplicated and in ascending order.
    pub ordinals: Vec<u32>,
}
54
55impl SearchResult {
56 pub fn extract_ordinals(&self) -> Vec<MatchedField> {
59 use rustc_hash::FxHashSet;
60
61 self.positions
62 .iter()
63 .map(|(field_id, positions)| {
64 let mut ordinals: FxHashSet<u32> = FxHashSet::default();
65 for &pos in positions {
66 let ordinal = pos >> 20; ordinals.insert(ordinal);
68 }
69 let mut ordinals: Vec<u32> = ordinals.into_iter().collect();
70 ordinals.sort_unstable();
71 MatchedField {
72 field_id: *field_id,
73 ordinals,
74 }
75 })
76 .collect()
77 }
78}
79
/// A single hit: a document address, its score and the fields that matched.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchHit {
    /// Segment and document id of the hit.
    pub address: DocAddress,
    /// Score of the hit.
    pub score: Score,
    /// Per-field matched ordinals. Defaults to empty when absent on
    /// deserialization and is omitted from serialized output when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub matched_fields: Vec<MatchedField>,
}
90
/// A list of hits together with a hit count.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SearchResponse {
    /// The hits carried by this response.
    pub hits: Vec<SearchHit>,
    /// Hit count (presumably the total number of matches, which may exceed
    /// `hits.len()` — confirm against the code that fills it in).
    pub total_hits: u32,
}
97
98impl PartialEq for SearchResult {
99 fn eq(&self, other: &Self) -> bool {
100 self.doc_id == other.doc_id
101 }
102}
103
104impl Eq for SearchResult {}
105
106impl PartialOrd for SearchResult {
107 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
108 Some(self.cmp(other))
109 }
110}
111
112impl Ord for SearchResult {
113 fn cmp(&self, other: &Self) -> Ordering {
114 other
115 .score
116 .partial_cmp(&self.score)
117 .unwrap_or(Ordering::Equal)
118 .then_with(|| self.doc_id.cmp(&other.doc_id))
119 }
120}
121
122pub struct TopKCollector {
124 heap: BinaryHeap<SearchResult>,
125 k: usize,
126}
127
128impl TopKCollector {
129 pub fn new(k: usize) -> Self {
130 Self {
131 heap: BinaryHeap::with_capacity(k + 1),
132 k,
133 }
134 }
135
136 pub fn collect(&mut self, doc_id: DocId, score: Score) {
137 self.collect_with_positions(doc_id, score, Vec::new());
138 }
139
140 pub fn collect_with_positions(
141 &mut self,
142 doc_id: DocId,
143 score: Score,
144 positions: Vec<(u32, Vec<u32>)>,
145 ) {
146 if self.heap.len() < self.k {
147 self.heap.push(SearchResult {
148 doc_id,
149 score,
150 positions,
151 });
152 } else if let Some(min) = self.heap.peek()
153 && score > min.score
154 {
155 self.heap.pop();
156 self.heap.push(SearchResult {
157 doc_id,
158 score,
159 positions,
160 });
161 }
162 }
163
164 pub fn into_sorted_results(self) -> Vec<SearchResult> {
165 let mut results: Vec<_> = self.heap.into_vec();
166 results.sort_by(|a, b| {
167 b.score
168 .partial_cmp(&a.score)
169 .unwrap_or(Ordering::Equal)
170 .then_with(|| a.doc_id.cmp(&b.doc_id))
171 });
172 results
173 }
174}
175
176pub async fn search_segment(
178 reader: &SegmentReader,
179 query: &dyn Query,
180 limit: usize,
181) -> Result<Vec<SearchResult>> {
182 search_segment_with_positions(reader, query, limit, false).await
183}
184
185pub async fn search_segment_with_positions(
187 reader: &SegmentReader,
188 query: &dyn Query,
189 limit: usize,
190 collect_positions: bool,
191) -> Result<Vec<SearchResult>> {
192 let mut scorer = query.scorer(reader, limit).await?;
193 let mut collector = TopKCollector::new(limit);
194
195 let mut doc = scorer.doc();
196
197 while doc != TERMINATED {
198 let positions = if collect_positions {
199 scorer.matched_positions().unwrap_or_default()
200 } else {
201 Vec::new()
202 };
203 collector.collect_with_positions(doc, scorer.score(), positions);
204 doc = scorer.advance();
205 }
206
207 Ok(collector.into_sorted_results())
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    /// Offering five scored documents to a top-3 collector keeps the three
    /// highest-scoring ones, returned best first.
    #[test]
    fn test_top_k_collector() {
        let mut collector = TopKCollector::new(3);

        for (doc_id, score) in [(0, 1.0), (1, 3.0), (2, 2.0), (3, 4.0), (4, 0.5)] {
            collector.collect(doc_id, score);
        }

        let ranked: Vec<DocId> = collector
            .into_sorted_results()
            .into_iter()
            .map(|hit| hit.doc_id)
            .collect();

        assert_eq!(ranked, vec![3, 1, 2]);
    }
}