use crate::core::{DocId, FieldId, NO_MORE_DOCS, Result, ScoreMode, Scorer, TwoPhaseIterator};
use crate::query::{BoundQuery, Query, ScorerSupplier};
use crate::search::bm25::{BlockMaxBm25Scorer, Bm25Scorer, Bm25Weight};
use crate::search::searcher::Searcher;
use crate::segment::reader::SegmentReader;
/// A query matching documents that contain one exact term in one field.
///
/// `value` is handed to the index lookups unchanged; this type does not
/// analyze or normalize it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TermQuery {
    /// Name of the field to search.
    pub field: String,
    /// Term to look up in `field`.
    pub value: String,
}
impl Query for TermQuery {
    /// Binds this query to `searcher`, snapshotting the statistics it reports.
    ///
    /// The returned [`BoundTermQuery`] owns copies of the field name and term
    /// together with the searcher's total document count, the term's document
    /// frequency and the field's average length. This method never returns `Err`.
    fn bind(&self, searcher: &Searcher, score_mode: ScoreMode) -> Result<Box<dyn BoundQuery>> {
        // Read the statistics up front so the struct literal below stays flat.
        let total_docs = searcher.total_docs();
        let doc_freq = searcher.doc_freq(&self.field, &self.value);
        let avg_field_length = searcher.avg_field_length(&self.field);

        let bound = BoundTermQuery {
            field: self.field.clone(),
            value: self.value.clone(),
            score_mode,
            total_docs,
            doc_freq,
            avg_field_length,
        };
        Ok(Box::new(bound))
    }
}
/// A term query bound to a searcher: the term itself plus the statistics
/// captured when `TermQuery::bind` ran.
pub(crate) struct BoundTermQuery {
/// Name of the field to search; matched against each segment's header to
/// obtain that segment's `FieldId`.
pub(crate) field: String,
/// Term to look up in the field.
pub(crate) value: String,
/// Decides whether the filter-only scorer or a BM25 scorer is used.
pub(crate) score_mode: ScoreMode,
/// Document count reported by the searcher at bind time; fed into the idf.
pub(crate) total_docs: u32,
// Captured at bind time but not read anywhere in this file: `bulk_score`,
// `scorer_supplier` and `explain` each re-read the document frequency from
// the segment reader they are given. Hence the `dead_code` allowance.
#[allow(dead_code)]
/// Document frequency of the term as reported by the searcher at bind time.
pub(crate) doc_freq: u32,
/// Average length of the field as reported by the searcher at bind time.
pub(crate) avg_field_length: f32,
}
impl BoundTermQuery {
    /// Looks up this query's field by name in the segment header.
    ///
    /// Returns the id of the first header entry whose name equals
    /// `self.field`, or `None` when the segment has no such field.
    fn resolve_field(&self, reader: &SegmentReader) -> Option<FieldId> {
        for descriptor in reader.header().fields.iter() {
            if descriptor.field_name == self.field {
                return Some(descriptor.field_id);
            }
        }
        None
    }
}
impl BoundQuery for BoundTermQuery {
    /// Runs the whole term query against one segment, feeding hits to `collector`.
    ///
    /// Returns `Ok(Some(0))` without touching the collector when the segment
    /// lacks the field, the term's per-segment document frequency is zero, or
    /// the reader yields no posting list for the term. Otherwise returns
    /// `Ok(Some(n))` where `n` is the value reported by `score_loop`.
    ///
    /// Scorer selection, in order:
    /// 1. scores not needed -> `FilterScorer`;
    /// 2. norms are uniform -> `ConstantBm25Scorer` with a score precomputed for tf = 1.0;
    /// 3. block-max postings available -> `BlockMaxBm25Scorer`;
    /// 4. otherwise -> `Bm25Scorer`.
    ///
    /// # Panics
    ///
    /// Panics if scores are needed and the segment has no norms for the field.
    fn bulk_score(
        &self,
        reader: &SegmentReader,
        collector: &mut crate::search::collector::TopDocsCollector,
        segment_id: crate::core::SegmentId,
    ) -> Result<Option<u64>> {
        let field_id = match self.resolve_field(reader) {
            Some(id) => id,
            None => return Ok(Some(0)),
        };
        // NOTE(review): this is the per-segment document frequency, while
        // `total_docs` and `avg_field_length` were captured from the searcher
        // at bind time — confirm that mixing the two in the idf is intended.
        let doc_freq = reader.doc_freq(field_id, &self.value);
        if doc_freq == 0 {
            return Ok(Some(0));
        }

        if !self.score_mode.needs_scores() {
            // A missing posting list is treated as "no hits", matching how
            // `explain` reports the same situation, instead of panicking.
            let postings = match reader.postings(field_id, &self.value) {
                Some(p) => p,
                None => return Ok(Some(0)),
            };
            let mut scorer = FilterScorer::new(postings);
            return Ok(Some(crate::search::score_loop(
                &mut scorer,
                collector,
                segment_id,
            )));
        }

        let weight = Bm25Weight::new(self.total_docs, doc_freq, self.avg_field_length);
        let norms = reader
            .norms(field_id)
            .expect("norms missing for a field that is present in the segment header");

        if let Some(dl) = norms.uniform_norm() {
            // Every document shares one norm, so a single score (computed for
            // tf = 1.0) is reused for all hits.
            let constant =
                crate::search::bm25::bm25_score(weight.idf, 1.0, dl, weight.avg_field_length);
            let postings = match reader.postings(field_id, &self.value) {
                Some(p) => p,
                None => return Ok(Some(0)),
            };
            let mut scorer = ConstantBm25Scorer::new(postings, constant);
            return Ok(Some(crate::search::score_loop(
                &mut scorer,
                collector,
                segment_id,
            )));
        }

        if let Some(block_postings) = reader.postings_block_max(field_id, &self.value) {
            let mut scorer = BlockMaxBm25Scorer::new(weight, block_postings, norms);
            return Ok(Some(crate::search::score_loop(
                &mut scorer,
                collector,
                segment_id,
            )));
        }

        let postings = match reader.postings(field_id, &self.value) {
            Some(p) => p,
            None => return Ok(Some(0)),
        };
        let mut scorer = Bm25Scorer::new(weight, postings, norms);
        Ok(Some(crate::search::score_loop(
            &mut scorer,
            collector,
            segment_id,
        )))
    }

    /// Builds a lazy scorer factory for one segment.
    ///
    /// Returns `Ok(None)` when the segment lacks the field or the term's
    /// per-segment document frequency is zero. The returned supplier keeps a
    /// raw pointer to `reader`, so `reader` must outlive both the supplier and
    /// any scorer created from it.
    fn scorer_supplier(&self, reader: &SegmentReader) -> Result<Option<Box<dyn ScorerSupplier>>> {
        let field_id = match self.resolve_field(reader) {
            Some(id) => id,
            None => return Ok(None),
        };
        let doc_freq = reader.doc_freq(field_id, &self.value);
        if doc_freq == 0 {
            return Ok(None);
        }
        Ok(Some(Box::new(TermScorerSupplier {
            field_id,
            value: self.value.clone(),
            score_mode: self.score_mode,
            doc_freq,
            total_docs: self.total_docs,
            avg_field_length: self.avg_field_length,
            // The borrow is erased here; see the lifetime requirement above.
            segment_data: reader as *const SegmentReader,
        })))
    }

    /// Explains how `doc` scores against this term in the given segment.
    ///
    /// Produces a no-match explanation when the field is absent, the term has
    /// no postings, or `doc` does not contain the term. Otherwise returns a
    /// matched explanation whose children are the idf, tf, document length and
    /// average field length that went into the BM25 score.
    ///
    /// # Panics
    ///
    /// Panics if the document matches but the segment has no norms for the field.
    fn explain(&self, reader: &SegmentReader, doc: DocId) -> Result<crate::search::Explanation> {
        use crate::search::Explanation;
        use crate::search::bm25::{bm25_idf, bm25_score};

        let field_id = match self.resolve_field(reader) {
            Some(id) => id,
            None => {
                return Ok(Explanation::no_match(format!(
                    "no field '{}' in segment",
                    self.field
                )));
            }
        };
        let doc_freq = reader.doc_freq(field_id, &self.value);
        if doc_freq == 0 {
            return Ok(Explanation::no_match(format!(
                "term '{}' not found in field '{}'",
                self.value, self.field
            )));
        }
        let mut postings = match reader.postings(field_id, &self.value) {
            Some(p) => p,
            None => {
                return Ok(Explanation::no_match(format!(
                    "term '{}' not found",
                    self.value
                )));
            }
        };

        // Scan the postings until we reach or pass `doc`. `tf` stays 0 when the
        // document is absent: either the scan passed it or the list ran out.
        let mut tf = 0u32;
        while let Some((did, t)) = postings.next() {
            if did >= doc {
                if did == doc {
                    tf = t;
                }
                break;
            }
        }
        if tf == 0 {
            return Ok(Explanation::no_match(format!(
                "doc {} does not contain term '{}'",
                doc.as_u32(),
                self.value
            )));
        }

        let norms = reader
            .norms(field_id)
            .expect("norms missing for a field that is present in the segment header");
        // NOTE(review): when norms are uniform, `bulk_score` scores with a
        // constant computed for tf = 1.0, whereas this path always uses the
        // document's actual tf and decoded norm — confirm both agree for such fields.
        let dl = crate::inverted::norms::decode_norm(norms.raw_byte(doc));
        let avgdl = self.avg_field_length;
        let idf = bm25_idf(self.total_docs, doc_freq);
        let score = bm25_score(idf, tf as f32, dl, avgdl);

        let idf_exp = Explanation::leaf(
            idf,
            format!("idf(docFreq={}, docCount={})", doc_freq, self.total_docs),
        );
        let tf_exp = Explanation::leaf(
            tf as f32,
            format!("tf(freq={} in doc {})", tf, doc.as_u32()),
        );
        let dl_exp = Explanation::leaf(dl, format!("dl(fieldLength={})", dl));
        let avgdl_exp = Explanation::leaf(avgdl, format!("avgdl(avgFieldLength={:.1})", avgdl));
        Ok(Explanation::matched(
            score,
            format!(
                "score(freq={}) = idf * tf_norm, term={}, field={}",
                tf, self.value, self.field
            ),
            vec![idf_exp, tf_exp, dl_exp, avgdl_exp],
        ))
    }
}
/// Deferred scorer factory for one term in one segment.
///
/// Holds everything `scorer` needs to build the scorer later, including a raw
/// pointer back to the segment reader it was created from.
struct TermScorerSupplier {
/// Id of the field within the segment.
field_id: FieldId,
/// Term to look up.
value: String,
/// Decides whether the filter-only scorer or a BM25 scorer is built.
score_mode: ScoreMode,
/// Document frequency supplied at construction; reported as the cost and fed
/// into the BM25 weight.
doc_freq: u32,
/// Document count fed into the BM25 weight.
total_docs: u32,
/// Average field length fed into the BM25 weight.
avg_field_length: f32,
/// Raw pointer to the segment reader; dereferenced in `scorer`. The pointee
/// must stay alive for as long as this supplier and any scorer built from it.
segment_data: *const SegmentReader,
}
// SAFETY: NOTE(review) — not verifiable from this file. The raw
// `segment_data` pointer makes this type `!Send` by default; this impl asserts
// that moving it to another thread is sound. That presumably relies on
// `SegmentReader` being safe to read from another thread (`Sync`) and on the
// reader outliving the supplier — TODO confirm both and record them here.
unsafe impl Send for TermScorerSupplier {}
impl ScorerSupplier for TermScorerSupplier {
    /// Cost estimate for the scorer: the document frequency this supplier was
    /// built with.
    fn cost(&self) -> u64 {
        u64::from(self.doc_freq)
    }

    /// Builds the scorer for this term.
    ///
    /// Scorer selection, in order:
    /// 1. scores not needed -> `FilterScorer`;
    /// 2. norms are uniform -> `ConstantBm25Scorer` with a score precomputed for tf = 1.0;
    /// 3. block-max postings available -> `BlockMaxBm25Scorer`;
    /// 4. otherwise -> `Bm25Scorer`.
    ///
    /// # Panics
    ///
    /// Panics if the segment yields no posting list for the term, or (when
    /// scores are needed) no norms for the field.
    fn scorer(self: Box<Self>) -> Result<Box<dyn Scorer>> {
        // SAFETY: `segment_data` is only ever set from a live `&SegmentReader`
        // when the supplier is constructed. Nothing in the type system ties
        // that borrow to this supplier or to the scorer returned below, so
        // soundness depends on the caller keeping the segment reader alive and
        // in place for as long as either of them exists.
        let reader = unsafe { &*self.segment_data };

        if !self.score_mode.needs_scores() {
            let postings = reader
                .postings(self.field_id, &self.value)
                .expect("postings missing for a term the supplier was built for");
            return Ok(Box::new(FilterScorer::new(postings)));
        }

        let weight = Bm25Weight::new(self.total_docs, self.doc_freq, self.avg_field_length);
        let norms = reader
            .norms(self.field_id)
            .expect("norms missing for a field the supplier was built for");

        if let Some(dl) = norms.uniform_norm() {
            // Every document shares one norm, so a single score (computed for
            // tf = 1.0) is reused for all hits.
            let constant =
                crate::search::bm25::bm25_score(weight.idf, 1.0, dl, weight.avg_field_length);
            let postings = reader
                .postings(self.field_id, &self.value)
                .expect("postings missing for a term the supplier was built for");
            return Ok(Box::new(ConstantBm25Scorer::new(postings, constant)));
        }

        if let Some(block_postings) = reader.postings_block_max(self.field_id, &self.value) {
            return Ok(Box::new(BlockMaxBm25Scorer::new(
                weight,
                block_postings,
                norms,
            )));
        }

        let postings = reader
            .postings(self.field_id, &self.value)
            .expect("postings missing for a term the supplier was built for");
        Ok(Box::new(Bm25Scorer::new(weight, postings, norms)))
    }
}
/// Scorer that walks a posting list and reports one precomputed score for
/// every document it visits.
struct ConstantBm25Scorer<'a> {
/// Posting list being iterated.
postings: crate::inverted::postings::PostingListReader<'a>,
/// Document the scorer is positioned on, or `NO_MORE_DOCS` once exhausted.
current: DocId,
/// Score returned for every document; also reported as the maximum score.
constant_score: f32,
}
impl<'a> ConstantBm25Scorer<'a> {
    /// Creates a scorer already positioned on the first posting.
    ///
    /// If `postings` is empty the scorer starts out at `NO_MORE_DOCS`.
    fn new(
        mut postings: crate::inverted::postings::PostingListReader<'a>,
        constant_score: f32,
    ) -> Self {
        // Pull the first entry eagerly; the term frequency is not needed here.
        let current = postings.next().map_or(NO_MORE_DOCS, |(id, _)| id);
        Self {
            postings,
            current,
            constant_score,
        }
    }
}
impl Scorer for ConstantBm25Scorer<'_> {
    /// Document the scorer is positioned on, or `NO_MORE_DOCS` once exhausted.
    fn doc_id(&self) -> DocId {
        self.current
    }

    /// Moves to the next posting and returns its document id, or
    /// `NO_MORE_DOCS` when the list is exhausted.
    fn next(&mut self) -> DocId {
        let upcoming = self.postings.next().map_or(NO_MORE_DOCS, |(id, _)| id);
        self.current = upcoming;
        upcoming
    }

    /// Steps forward one posting at a time until the current document is at or
    /// beyond `target`, or the list is exhausted; returns the resulting position.
    fn advance(&mut self, target: DocId) -> DocId {
        loop {
            if self.current == NO_MORE_DOCS || self.current >= target {
                return self.current;
            }
            self.next();
        }
    }

    /// Returns the precomputed score, regardless of the current document.
    fn score(&mut self) -> f32 {
        self.constant_score
    }

    /// This scorer has no two-phase iterator.
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }

    /// The maximum score equals the constant score.
    fn max_score(&self) -> f32 {
        self.constant_score
    }
}
/// Scorer used when scores are not needed: walks a posting list and reports a
/// fixed score of `1.0` for every document.
pub(crate) struct FilterScorer<'a> {
/// Posting list being iterated.
postings: crate::inverted::postings::PostingListReader<'a>,
/// Document the scorer is positioned on, or `NO_MORE_DOCS` once exhausted.
current: DocId,
}
impl<'a> FilterScorer<'a> {
    /// Creates a scorer already positioned on the first posting.
    ///
    /// If `postings` is empty the scorer starts out at `NO_MORE_DOCS`.
    pub(crate) fn new(mut postings: crate::inverted::postings::PostingListReader<'a>) -> Self {
        // Pull the first entry eagerly; the term frequency is not needed here.
        let current = postings.next().map_or(NO_MORE_DOCS, |(id, _)| id);
        Self { postings, current }
    }
}
impl Scorer for FilterScorer<'_> {
    /// Document the scorer is positioned on, or `NO_MORE_DOCS` once exhausted.
    fn doc_id(&self) -> DocId {
        self.current
    }

    /// Moves to the next posting and returns its document id, or
    /// `NO_MORE_DOCS` when the list is exhausted.
    fn next(&mut self) -> DocId {
        let upcoming = self.postings.next().map_or(NO_MORE_DOCS, |(id, _)| id);
        self.current = upcoming;
        upcoming
    }

    /// Steps forward one posting at a time until the current document is at or
    /// beyond `target`, or the list is exhausted; returns the resulting position.
    fn advance(&mut self, target: DocId) -> DocId {
        loop {
            if self.current == NO_MORE_DOCS || self.current >= target {
                return self.current;
            }
            self.next();
        }
    }

    /// Every matching document gets the same fixed score.
    fn score(&mut self) -> f32 {
        1.0
    }

    /// This scorer has no two-phase iterator.
    fn two_phase(&mut self) -> Option<&mut dyn TwoPhaseIterator> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::{AnalyzerRegistry, Token};
    use crate::core::SegmentId;
    use crate::mapping::{FieldType, Mapping};
    use crate::search::segment_store::SegmentStore;
    use crate::segment::builder::SegmentBuilder;

    /// Turns each term into a token whose position is its index in `terms`.
    fn make_tokens(terms: &[&str]) -> Vec<Token> {
        let mut tokens = Vec::with_capacity(terms.len());
        for (position, term) in terms.iter().enumerate() {
            tokens.push(Token::new(*term, 0, term.len(), position as u32));
        }
        tokens
    }

    /// Mapping with a text field `body` and a keyword field `tag`.
    fn test_schema() -> Mapping {
        let builder = Mapping::builder()
            .field("body", FieldType::Text)
            .field("tag", FieldType::Keyword);
        builder.build()
    }

    /// Builds a three-document segment. The first and third documents carry
    /// tag `a` (the tests below expect them as doc ids 0 and 2); the second
    /// carries tag `b`.
    fn build_test_segment() -> SegmentReader {
        let body = FieldId::new(0);
        let tag = FieldId::new(1);
        let schema = test_schema();
        let mut builder = SegmentBuilder::new(SegmentId::new(1), &schema);

        builder.add_document(
            &[
                (body, make_tokens(&["hello", "world"])),
                (tag, make_tokens(&["a"])),
            ],
            br#"{"body":"hello world","tag":"a"}"#,
        );
        builder.add_document(
            &[
                (body, make_tokens(&["hello", "luci"])),
                (tag, make_tokens(&["b"])),
            ],
            br#"{"body":"hello luci","tag":"b"}"#,
        );
        builder.add_document(
            &[
                (body, make_tokens(&["goodbye"])),
                (tag, make_tokens(&["a"])),
            ],
            br#"{"body":"goodbye","tag":"a"}"#,
        );

        let segment = builder.build();
        SegmentReader::open(segment).unwrap()
    }

    /// Binding a term query succeeds.
    #[test]
    fn term_query_creates_weight() {
        let store = SegmentStore::new(
            vec![build_test_segment()],
            AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);
        let query = TermQuery {
            field: "tag".into(),
            value: "a".into(),
        };

        let bound = query.bind(&searcher, ScoreMode::Complete).unwrap();
        drop(bound);
    }

    /// The scorer visits exactly the documents tagged `a`, in order.
    #[test]
    fn term_query_scorer_iterates() {
        let store = SegmentStore::new(
            vec![build_test_segment()],
            AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);
        let query = TermQuery {
            field: "tag".into(),
            value: "a".into(),
        };

        let bound = query.bind(&searcher, ScoreMode::Complete).unwrap();
        let segment = &searcher.segments()[0];
        let supplier = bound.scorer_supplier(segment).unwrap().unwrap();
        assert_eq!(supplier.cost(), 2);

        let mut scorer = supplier.scorer().unwrap();
        assert_eq!(scorer.doc_id(), DocId::new(0));
        assert_eq!(scorer.next(), DocId::new(2));
        assert_eq!(scorer.next(), NO_MORE_DOCS);
    }

    /// A term absent from the segment yields no supplier.
    #[test]
    fn term_query_missing_term() {
        let store = SegmentStore::new(
            vec![build_test_segment()],
            AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);
        let query = TermQuery {
            field: "tag".into(),
            value: "nonexistent".into(),
        };

        let bound = query.bind(&searcher, ScoreMode::Complete).unwrap();
        let segment = &searcher.segments()[0];
        let supplier = bound.scorer_supplier(segment).unwrap();
        assert!(supplier.is_none());
    }

    /// A field absent from the segment yields no supplier.
    #[test]
    fn term_query_missing_field() {
        let store = SegmentStore::new(
            vec![build_test_segment()],
            AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);
        let query = TermQuery {
            field: "nosuchfield".into(),
            value: "x".into(),
        };

        let bound = query.bind(&searcher, ScoreMode::Complete).unwrap();
        let segment = &searcher.segments()[0];
        let supplier = bound.scorer_supplier(segment).unwrap();
        assert!(supplier.is_none());
    }

    /// Without scores the scorer still visits the same documents and reports 1.0.
    #[test]
    fn term_query_filter_context() {
        let store = SegmentStore::new(
            vec![build_test_segment()],
            AnalyzerRegistry::new(),
            None,
            None,
        );
        let searcher = Searcher::new(&store);
        let query = TermQuery {
            field: "tag".into(),
            value: "a".into(),
        };

        let bound = query.bind(&searcher, ScoreMode::CompleteNoScores).unwrap();
        let segment = &searcher.segments()[0];
        let supplier = bound.scorer_supplier(segment).unwrap().unwrap();

        let mut scorer = supplier.scorer().unwrap();
        assert_eq!(scorer.doc_id(), DocId::new(0));
        assert_eq!(scorer.score(), 1.0);
        assert_eq!(scorer.next(), DocId::new(2));
        assert_eq!(scorer.next(), NO_MORE_DOCS);
    }
}