pub mod bm25;
pub mod buffered_union;
pub mod bulk;
pub mod collector;
pub mod conjunction;
pub mod expression;
pub mod highlight;
pub mod reader;
pub mod results;
pub mod rrf;
pub mod searcher;
pub mod segment_store;
pub mod wand;
use crate::core::{NO_MORE_DOCS, Scorer, SegmentId};
use crate::search::collector::TopDocsCollector;
/// Strategy for merging a rescore query's score with the original query score.
///
/// Both inputs are scaled by their own weight before the strategy is applied.
#[derive(Clone, Copy, Debug)]
pub enum RescoreScoreMode {
    /// Sum of the two weighted scores.
    Total,
    /// Product of the two weighted scores.
    Multiply,
    /// Arithmetic mean of the two weighted scores.
    Avg,
    /// Larger of the two weighted scores.
    Max,
    /// Smaller of the two weighted scores.
    Min,
}

impl RescoreScoreMode {
    /// Merges `original` (scaled by `query_weight`) with `rescore`
    /// (scaled by `rescore_weight`) according to this mode.
    fn combine(&self, original: f32, rescore: f32, query_weight: f32, rescore_weight: f32) -> f32 {
        let (weighted_original, weighted_rescore) =
            (original * query_weight, rescore * rescore_weight);
        match *self {
            RescoreScoreMode::Total => weighted_original + weighted_rescore,
            RescoreScoreMode::Multiply => weighted_original * weighted_rescore,
            RescoreScoreMode::Avg => (weighted_original + weighted_rescore) / 2.0,
            RescoreScoreMode::Max => f32::max(weighted_original, weighted_rescore),
            RescoreScoreMode::Min => f32::min(weighted_original, weighted_rescore),
        }
    }
}
/// Hit count reported for a search, together with whether that count is exact
/// or only a lower bound.
#[derive(Clone, Debug)]
pub struct TotalHits {
    /// Number of hits (possibly capped, see `relation`).
    pub value: u64,
    /// Whether `value` is exact or a lower bound.
    pub relation: TotalHitsRelation,
}
/// How a reported hit count relates to the true number of matches.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum TotalHitsRelation {
    /// The count is exact (serialized as `"eq"`).
    EqualTo,
    /// The count is a lower bound (serialized as `"gte"`).
    GreaterThanOrEqualTo,
}
/// Requested precision for the total-hits count of a search.
#[derive(Clone, Copy, Debug)]
pub enum TrackTotalHits {
    /// Report the raw count exactly.
    Exact,
    /// Do not report a meaningful count.
    Disabled,
    /// Report exactly up to the given cap, then report the cap as a lower bound.
    UpTo(u64),
}
impl TotalHits {
pub fn exact(value: u64) -> Self {
Self {
value,
relation: TotalHitsRelation::EqualTo,
}
}
pub fn resolve(raw_total: u64, track: TrackTotalHits) -> Self {
match track {
TrackTotalHits::Exact => Self {
value: raw_total,
relation: TotalHitsRelation::EqualTo,
},
TrackTotalHits::Disabled => Self {
value: 0,
relation: TotalHitsRelation::GreaterThanOrEqualTo,
},
TrackTotalHits::UpTo(cap) => {
if raw_total <= cap {
Self {
value: raw_total,
relation: TotalHitsRelation::EqualTo,
}
} else {
Self {
value: cap,
relation: TotalHitsRelation::GreaterThanOrEqualTo,
}
}
}
}
}
pub fn to_json(&self) -> serde_json::Value {
serde_json::json!({
"value": self.value,
"relation": match self.relation {
TotalHitsRelation::EqualTo => "eq",
TotalHitsRelation::GreaterThanOrEqualTo => "gte",
}
})
}
}
/// Selects which parts of a stored source document are returned with a hit
/// (applied by `filter_source`).
#[derive(Clone, Debug)]
pub enum SourceFilter {
    /// Return the whole source unchanged.
    Enabled,
    /// Return no source at all.
    Disabled,
    /// Return only the listed top-level keys.
    Fields(Vec<String>),
    /// Return top-level keys that are included (every key when `includes` is
    /// empty) and not listed in `excludes`.
    IncludeExclude {
        includes: Vec<String>,
        excludes: Vec<String>,
    },
}
/// Applies `filter` to a stored source document.
///
/// Returns `None` when the filter is `Disabled`. Also returns `None` when a
/// key-based filter (`Fields` / `IncludeExclude`) meets a source that is not a
/// JSON object. Only top-level keys are matched, by exact string equality.
pub fn filter_source(
    source: &serde_json::Value,
    filter: &SourceFilter,
) -> Option<serde_json::Value> {
    /// Copies the top-level entries of `source` whose key satisfies `keep`.
    fn project(
        source: &serde_json::Value,
        keep: impl Fn(&str) -> bool,
    ) -> Option<serde_json::Value> {
        let mut kept = serde_json::Map::new();
        for (key, value) in source.as_object()? {
            if keep(key.as_str()) {
                kept.insert(key.clone(), value.clone());
            }
        }
        Some(serde_json::Value::Object(kept))
    }

    /// True when `key` appears verbatim in `names`.
    fn listed(names: &[String], key: &str) -> bool {
        names.iter().any(|name| name == key)
    }

    match filter {
        SourceFilter::Disabled => None,
        SourceFilter::Enabled => Some(source.clone()),
        SourceFilter::Fields(fields) => project(source, |key| listed(fields, key)),
        SourceFilter::IncludeExclude { includes, excludes } => project(source, |key| {
            (includes.is_empty() || listed(includes, key)) && !listed(excludes, key)
        }),
    }
}
/// A numeric value with a human-readable description and nested
/// sub-explanations that break the value down.
#[derive(Clone, Debug)]
pub struct Explanation {
    /// The value being explained.
    pub value: f32,
    /// Human-readable account of how `value` was obtained.
    pub description: String,
    /// Child explanations contributing to this one.
    pub details: Vec<Explanation>,
}
impl Explanation {
pub fn matched(value: f32, description: String, details: Vec<Explanation>) -> Self {
Self {
value,
description,
details,
}
}
pub fn leaf(value: f32, description: String) -> Self {
Self {
value,
description,
details: Vec::new(),
}
}
pub fn no_match(description: String) -> Self {
Self {
value: 0.0,
description,
details: Vec::new(),
}
}
pub fn to_json(&self) -> serde_json::Value {
serde_json::json!({
"value": self.value,
"description": self.description,
"details": self.details.iter().map(|d| d.to_json()).collect::<Vec<_>>()
})
}
}
/// One key of a multi-key sort specification.
#[derive(Clone, Debug)]
pub struct SortField {
    /// What to sort on.
    pub field: SortFieldType,
    /// Direction of the sort.
    pub order: SortOrder,
    /// Where hits lacking a value for this key are placed.
    pub missing: MissingValue,
}
/// The source of a sort key.
#[derive(Clone, Debug)]
pub enum SortFieldType {
    /// Sort by score.
    Score,
    /// Sort by document id / order.
    Doc,
    /// Sort by the named field's value.
    Field(String),
}
/// Direction of a sort key; `Desc` reverses the natural ordering.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum SortOrder {
    /// Smallest values first.
    Asc,
    /// Largest values first.
    Desc,
}
/// Placement of `Null` sort values relative to non-null ones. The placement
/// holds for both ascending and descending order.
#[derive(Clone, Copy, Debug)]
pub enum MissingValue {
    /// Nulls come after every non-null value.
    Last,
    /// Nulls come before every non-null value.
    First,
}
/// A single sort-key value carried by a hit.
#[derive(Clone, Debug)]
pub enum SortValue {
    /// A score value.
    Score(f32),
    /// A document id / ordinal.
    Doc(u64),
    /// A floating-point field value.
    F64(f64),
    /// An integer field value.
    I64(i64),
    /// A string field value.
    Str(String),
    /// A boolean field value.
    Bool(bool),
    /// No value; placement is governed by `MissingValue`.
    Null,
}
impl SortValue {
pub fn compare(&self, other: &SortValue, field: &SortField) -> std::cmp::Ordering {
use std::cmp::Ordering;
let natural = match (self, other) {
(SortValue::Null, SortValue::Null) => Ordering::Equal,
(SortValue::Null, _) => match field.missing {
MissingValue::Last => match field.order {
SortOrder::Asc => Ordering::Greater,
SortOrder::Desc => Ordering::Less,
},
MissingValue::First => match field.order {
SortOrder::Asc => Ordering::Less,
SortOrder::Desc => Ordering::Greater,
},
},
(_, SortValue::Null) => match field.missing {
MissingValue::Last => match field.order {
SortOrder::Asc => Ordering::Less,
SortOrder::Desc => Ordering::Greater,
},
MissingValue::First => match field.order {
SortOrder::Asc => Ordering::Greater,
SortOrder::Desc => Ordering::Less,
},
},
(SortValue::F64(a), SortValue::F64(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
(SortValue::I64(a), SortValue::I64(b)) => a.cmp(b),
(SortValue::Str(a), SortValue::Str(b)) => a.cmp(b),
(SortValue::Bool(a), SortValue::Bool(b)) => a.cmp(b),
(SortValue::Score(a), SortValue::Score(b)) => {
a.partial_cmp(b).unwrap_or(Ordering::Equal)
}
(SortValue::Doc(a), SortValue::Doc(b)) => a.cmp(b),
_ => Ordering::Equal,
};
match field.order {
SortOrder::Asc => natural,
SortOrder::Desc => natural.reverse(),
}
}
pub fn to_json(&self) -> serde_json::Value {
match self {
SortValue::Score(s) => serde_json::json!(s),
SortValue::Doc(d) => serde_json::json!(d),
SortValue::F64(f) => serde_json::json!(f),
SortValue::I64(i) => serde_json::json!(i),
SortValue::Str(s) => serde_json::json!(s),
SortValue::Bool(b) => serde_json::json!(b),
SortValue::Null => serde_json::Value::Null,
}
}
}
/// Compares two hits' sort keys key by key and returns the first non-`Equal` result.
///
/// `a[i]` and `b[i]` are compared under `sort_fields[i]` via `SortValue::compare`.
/// Comparison covers only as many keys as the shortest of the three slices. A
/// hit carrying fewer sort values than `sort_fields` is therefore compared on
/// the keys it has, instead of panicking on an out-of-bounds index. If every
/// compared key ties, the result is `Equal`.
pub fn compare_sort_values_cascade(
    a: &[SortValue],
    b: &[SortValue],
    sort_fields: &[SortField],
) -> std::cmp::Ordering {
    sort_fields
        .iter()
        .zip(a.iter().zip(b))
        .map(|(sf, (x, y))| x.compare(y, sf))
        .find(|ord| ord.is_ne())
        .unwrap_or(std::cmp::Ordering::Equal)
}
/// Feeds every document produced by `scorer` into `collector` and returns how
/// many documents were collected.
///
/// The current document is read with `doc_id()` before the first `next()`.
/// NOTE(review): this assumes the scorer is already positioned on its first
/// candidate; confirm against the `Scorer` contract.
///
/// Whenever `collector.min_score()` rises above the last threshold forwarded,
/// the new value is passed to `scorer.set_min_competitive_score`. That
/// threshold starts at `0.0`, so non-positive minimums are never forwarded.
#[inline]
pub fn score_loop<S: Scorer>(
    scorer: &mut S,
    collector: &mut TopDocsCollector,
    seg_id: SegmentId,
) -> u64 {
    let mut collected: u64 = 0;
    let mut forwarded_threshold: f32 = 0.0;
    let mut doc = scorer.doc_id();
    while doc != NO_MORE_DOCS {
        collector.collect(doc, seg_id, scorer.score());
        collected += 1;
        let floor = collector.min_score();
        if floor > forwarded_threshold {
            forwarded_threshold = floor;
            scorer.set_min_competitive_score(forwarded_threshold);
        }
        scorer.next();
        doc = scorer.doc_id();
    }
    collected
}