1pub mod bm25;
6pub mod buffered_union;
7pub mod bulk;
8pub mod collector;
9pub mod conjunction;
10pub mod expression;
11pub mod highlight;
12pub mod reader;
13pub mod results;
14pub mod rrf;
15pub mod searcher;
16pub mod segment_store;
17pub mod wand;
18
19use crate::core::{NO_MORE_DOCS, Scorer, SegmentId};
20
21use crate::search::collector::TopDocsCollector;
22
/// Strategy for merging a hit's original query score with its rescore score.
///
/// Both inputs are multiplied by their respective weights before the mode is applied.
#[derive(Clone, Copy, Debug)]
pub enum RescoreScoreMode {
    /// Sum of the two weighted scores.
    Total,
    /// Product of the two weighted scores.
    Multiply,
    /// Arithmetic mean of the two weighted scores.
    Avg,
    /// Larger of the two weighted scores.
    Max,
    /// Smaller of the two weighted scores.
    Min,
}

impl RescoreScoreMode {
    /// Combines `original` and `rescore` into a single score.
    ///
    /// `original` is scaled by `query_weight` and `rescore` by `rescore_weight`; the scaled values
    /// are then merged according to `self`.
    fn combine(&self, original: f32, rescore: f32, query_weight: f32, rescore_weight: f32) -> f32 {
        let weighted_original = original * query_weight;
        let weighted_rescore = rescore * rescore_weight;

        match *self {
            RescoreScoreMode::Total => weighted_original + weighted_rescore,
            RescoreScoreMode::Multiply => weighted_original * weighted_rescore,
            RescoreScoreMode::Avg => (weighted_original + weighted_rescore) / 2.0,
            RescoreScoreMode::Max => f32::max(weighted_original, weighted_rescore),
            RescoreScoreMode::Min => f32::min(weighted_original, weighted_rescore),
        }
    }
}
49
/// A reported hit count together with how that count should be interpreted.
#[derive(Clone, Debug)]
pub struct TotalHits {
    /// The reported number of hits.
    pub value: u64,
    /// Whether `value` is the exact count or only a lower bound.
    pub relation: TotalHitsRelation,
}
59
/// How the `value` of a [`TotalHits`] relates to the actual number of matches.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TotalHitsRelation {
    /// The value is the exact count; serialized as `"eq"` by `TotalHits::to_json`.
    EqualTo,
    /// The value is a lower bound; serialized as `"gte"` by `TotalHits::to_json`.
    GreaterThanOrEqualTo,
}
65
/// Policy that `TotalHits::resolve` applies to a raw hit count.
#[derive(Clone, Copy, Debug)]
pub enum TrackTotalHits {
    /// Report the raw count as an exact value.
    Exact,
    /// Do not report a count: resolves to `0` with a greater-than-or-equal relation.
    Disabled,
    /// Report the raw count exactly while it does not exceed the given cap; a larger count
    /// resolves to the cap itself with a greater-than-or-equal relation.
    UpTo(u64),
}
73
74impl TotalHits {
75 pub fn exact(value: u64) -> Self {
76 Self {
77 value,
78 relation: TotalHitsRelation::EqualTo,
79 }
80 }
81
82 pub fn resolve(raw_total: u64, track: TrackTotalHits) -> Self {
83 match track {
84 TrackTotalHits::Exact => Self {
85 value: raw_total,
86 relation: TotalHitsRelation::EqualTo,
87 },
88 TrackTotalHits::Disabled => Self {
89 value: 0,
90 relation: TotalHitsRelation::GreaterThanOrEqualTo,
91 },
92 TrackTotalHits::UpTo(cap) => {
93 if raw_total <= cap {
94 Self {
95 value: raw_total,
96 relation: TotalHitsRelation::EqualTo,
97 }
98 } else {
99 Self {
100 value: cap,
101 relation: TotalHitsRelation::GreaterThanOrEqualTo,
102 }
103 }
104 }
105 }
106 }
107
108 pub fn to_json(&self) -> serde_json::Value {
109 serde_json::json!({
110 "value": self.value,
111 "relation": match self.relation {
112 TotalHitsRelation::EqualTo => "eq",
113 TotalHitsRelation::GreaterThanOrEqualTo => "gte",
114 }
115 })
116 }
117}
118
/// Selects which part of a document source `filter_source` returns.
///
/// Key names are compared for exact equality against the top-level keys of the source object.
#[derive(Clone, Debug)]
pub enum SourceFilter {
    /// Return the source unchanged.
    Enabled,
    /// Return no source at all.
    Disabled,
    /// Keep only the top-level keys listed here.
    Fields(Vec<String>),
    /// Keep the top-level keys that pass `includes` and are not listed in `excludes`.
    IncludeExclude {
        /// Keys to keep; an empty list keeps every key.
        includes: Vec<String>,
        /// Keys to drop, even when they are also listed in `includes`.
        excludes: Vec<String>,
    },
}
135
136pub fn filter_source(
138 source: &serde_json::Value,
139 filter: &SourceFilter,
140) -> Option<serde_json::Value> {
141 match filter {
142 SourceFilter::Enabled => Some(source.clone()),
143 SourceFilter::Disabled => None,
144 SourceFilter::Fields(fields) => {
145 let obj = source.as_object()?;
146 let filtered: serde_json::Map<String, serde_json::Value> = obj
147 .iter()
148 .filter(|(k, _)| fields.iter().any(|f| f == *k))
149 .map(|(k, v)| (k.clone(), v.clone()))
150 .collect();
151 Some(serde_json::Value::Object(filtered))
152 }
153 SourceFilter::IncludeExclude { includes, excludes } => {
154 let obj = source.as_object()?;
155 let filtered: serde_json::Map<String, serde_json::Value> = obj
156 .iter()
157 .filter(|(k, _)| {
158 let included = includes.is_empty() || includes.iter().any(|f| f == *k);
159 let excluded = excludes.iter().any(|f| f == *k);
160 included && !excluded
161 })
162 .map(|(k, v)| (k.clone(), v.clone()))
163 .collect();
164 Some(serde_json::Value::Object(filtered))
165 }
166 }
167}
168
/// A node in an explanation tree: a numeric value, a description of it, and nested
/// sub-explanations.
// NOTE(review): `value` is presumably a score or score component — confirm against producers.
#[derive(Clone, Debug)]
pub struct Explanation {
    /// Numeric value attached to this node (`0.0` for nodes built with `no_match`).
    pub value: f32,
    /// Human-readable text describing this node.
    pub description: String,
    /// Child explanations; empty for leaves.
    pub details: Vec<Explanation>,
}
179
180impl Explanation {
181 pub fn matched(value: f32, description: String, details: Vec<Explanation>) -> Self {
182 Self {
183 value,
184 description,
185 details,
186 }
187 }
188
189 pub fn leaf(value: f32, description: String) -> Self {
190 Self {
191 value,
192 description,
193 details: Vec::new(),
194 }
195 }
196
197 pub fn no_match(description: String) -> Self {
198 Self {
199 value: 0.0,
200 description,
201 details: Vec::new(),
202 }
203 }
204
205 pub fn to_json(&self) -> serde_json::Value {
206 serde_json::json!({
207 "value": self.value,
208 "description": self.description,
209 "details": self.details.iter().map(|d| d.to_json()).collect::<Vec<_>>()
210 })
211 }
212}
213
/// One sort criterion: what to sort on, in which direction, and where missing values go.
#[derive(Clone, Debug)]
pub struct SortField {
    /// What this criterion sorts on.
    pub field: SortFieldType,
    /// Sort direction; `SortValue::compare` reverses the natural ordering for `Desc`.
    pub order: SortOrder,
    /// Where `SortValue::Null` is placed relative to present values.
    pub missing: MissingValue,
}
224
/// The kind of value a [`SortField`] sorts on.
// NOTE(review): the code that extracts sort values is not visible here; the variant meanings
// below are inferred from their names and should be confirmed.
#[derive(Clone, Debug)]
pub enum SortFieldType {
    /// Presumably the hit's score.
    Score,
    /// Presumably the document id.
    Doc,
    /// Presumably a named document field.
    Field(String),
}
234
/// Direction of a sort criterion.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SortOrder {
    /// Keep the natural ordering of the compared values.
    Asc,
    /// Reverse the natural ordering of the compared values.
    Desc,
}
240
/// Placement of missing (`SortValue::Null`) values relative to present ones.
///
/// As implemented by `SortValue::compare`, the placement holds for both sort directions.
#[derive(Clone, Copy, Debug)]
pub enum MissingValue {
    /// Missing values sort after present values.
    Last,
    /// Missing values sort before present values.
    First,
}
246
/// A single value used as a sort key for a hit.
///
/// `SortValue::compare` orders two values of the same variant; two different non-null variants
/// compare as equal.
#[derive(Clone, Debug)]
pub enum SortValue {
    /// A 32-bit float score.
    Score(f32),
    /// An unsigned 64-bit document number.
    Doc(u64),
    /// A 64-bit float.
    F64(f64),
    /// A signed 64-bit integer.
    I64(i64),
    /// A string, compared with `String`'s ordering.
    Str(String),
    /// A boolean (`false` orders before `true`).
    Bool(bool),
    /// A missing value; placed according to `SortField::missing`.
    Null,
}
258
259impl SortValue {
260 pub fn compare(&self, other: &SortValue, field: &SortField) -> std::cmp::Ordering {
262 use std::cmp::Ordering;
263
264 let natural = match (self, other) {
265 (SortValue::Null, SortValue::Null) => Ordering::Equal,
266 (SortValue::Null, _) => match field.missing {
267 MissingValue::Last => match field.order {
268 SortOrder::Asc => Ordering::Greater,
269 SortOrder::Desc => Ordering::Less,
270 },
271 MissingValue::First => match field.order {
272 SortOrder::Asc => Ordering::Less,
273 SortOrder::Desc => Ordering::Greater,
274 },
275 },
276 (_, SortValue::Null) => match field.missing {
277 MissingValue::Last => match field.order {
278 SortOrder::Asc => Ordering::Less,
279 SortOrder::Desc => Ordering::Greater,
280 },
281 MissingValue::First => match field.order {
282 SortOrder::Asc => Ordering::Greater,
283 SortOrder::Desc => Ordering::Less,
284 },
285 },
286 (SortValue::F64(a), SortValue::F64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
287 (SortValue::I64(a), SortValue::I64(b)) => a.cmp(b),
288 (SortValue::Str(a), SortValue::Str(b)) => a.cmp(b),
289 (SortValue::Bool(a), SortValue::Bool(b)) => a.cmp(b),
290 (SortValue::Score(a), SortValue::Score(b)) => {
291 a.partial_cmp(b).unwrap_or(Ordering::Equal)
292 }
293 (SortValue::Doc(a), SortValue::Doc(b)) => a.cmp(b),
294 _ => Ordering::Equal,
295 };
296
297 match field.order {
298 SortOrder::Asc => natural,
299 SortOrder::Desc => natural.reverse(),
300 }
301 }
302
303 pub fn to_json(&self) -> serde_json::Value {
305 match self {
306 SortValue::Score(s) => serde_json::json!(s),
307 SortValue::Doc(d) => serde_json::json!(d),
308 SortValue::F64(f) => serde_json::json!(f),
309 SortValue::I64(i) => serde_json::json!(i),
310 SortValue::Str(s) => serde_json::json!(s),
311 SortValue::Bool(b) => serde_json::json!(b),
312 SortValue::Null => serde_json::Value::Null,
313 }
314 }
315}
316
317pub fn compare_sort_values_cascade(
319 a: &[SortValue],
320 b: &[SortValue],
321 sort_fields: &[SortField],
322) -> std::cmp::Ordering {
323 for (i, sf) in sort_fields.iter().enumerate() {
324 let cmp = a[i].compare(&b[i], sf);
325 if cmp != std::cmp::Ordering::Equal {
326 return cmp;
327 }
328 }
329 std::cmp::Ordering::Equal
330}
331
332#[inline]
338pub fn score_loop<S: Scorer>(
339 scorer: &mut S,
340 collector: &mut TopDocsCollector,
341 seg_id: SegmentId,
342) -> u64 {
343 let mut last_min: f32 = 0.0;
344 let mut total_hits: u64 = 0;
345 loop {
346 let doc = scorer.doc_id();
347 if doc == NO_MORE_DOCS {
348 break;
349 }
350 let score = scorer.score();
351 collector.collect(doc, seg_id, score);
352 total_hits += 1;
353 let min = collector.min_score();
354 if min > last_min {
355 scorer.set_min_competitive_score(min);
356 last_min = min;
357 }
358 scorer.next();
359 }
360 total_hits
361}