use std::sync::Arc;
use crate::dsl::Field;
use crate::segment::SegmentReader;
use crate::structures::{BlockPostingIterator, BlockPostingList, PositionPostingList, TERMINATED};
use crate::{DocId, Score};
use super::{CountFuture, EmptyScorer, GlobalStats, Query, Scorer, ScorerFuture};
/// Query that matches documents in which `terms` occur as a phrase in `field`.
///
/// Build one with [`PhraseQuery::new`] (pre-tokenized terms) or
/// [`PhraseQuery::text`] (whitespace-split, lower-cased text), then optionally
/// chain `with_slop` / `with_global_stats`.
#[derive(Clone)]
pub struct PhraseQuery {
/// Field the phrase is searched in.
pub field: Field,
/// Raw term bytes, in phrase order (term `i` is expected `i` positions after the first term).
pub terms: Vec<Vec<u8>>,
/// Maximum allowed distance between a term's actual and expected position;
/// `0` requires each term at exactly its expected position.
pub slop: u32,
// Set through `with_global_stats`. NOTE(review): nothing in this part of the
// file reads it when building scorers — confirm whether it is used elsewhere.
global_stats: Option<Arc<GlobalStats>>,
}
impl std::fmt::Display for PhraseQuery {
    /// Renders the query as `Phrase(<field id>:"<terms separated by spaces>")`.
    ///
    /// Term bytes are decoded as UTF-8 lossily. When `slop` is non-zero,
    /// `~<slop>` is inserted just before the closing parenthesis.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Assemble the space-separated phrase text first.
        let mut phrase = String::new();
        for (idx, term) in self.terms.iter().enumerate() {
            if idx != 0 {
                phrase.push(' ');
            }
            phrase.push_str(&String::from_utf8_lossy(term));
        }

        write!(f, "Phrase({}:\"{}\"", self.field.0, phrase)?;
        if self.slop != 0 {
            write!(f, "~{}", self.slop)?;
        }
        write!(f, ")")
    }
}
impl std::fmt::Debug for PhraseQuery {
    /// Debug output with `field`, `terms` and `slop`.
    ///
    /// Terms are shown as lossily decoded UTF-8 strings instead of byte
    /// vectors; `global_stats` is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Cow<str>` debug-prints exactly like the string it holds.
        let readable_terms: Vec<_> = self
            .terms
            .iter()
            .map(|bytes| String::from_utf8_lossy(bytes))
            .collect();

        let mut out = f.debug_struct("PhraseQuery");
        out.field("field", &self.field);
        out.field("terms", &readable_terms);
        out.field("slop", &self.slop);
        out.finish()
    }
}
impl PhraseQuery {
    /// Creates an exact phrase query (slop 0, no global stats) from
    /// already-tokenized term bytes.
    pub fn new(field: Field, terms: Vec<Vec<u8>>) -> Self {
        Self {
            field,
            terms,
            slop: 0,
            global_stats: None,
        }
    }

    /// Creates an exact phrase query from free text.
    ///
    /// `phrase` is split on whitespace and every word is lower-cased before
    /// being stored as a term.
    pub fn text(field: Field, phrase: &str) -> Self {
        let mut terms = Vec::new();
        for word in phrase.split_whitespace() {
            terms.push(word.to_lowercase().into_bytes());
        }
        Self::new(field, terms)
    }

    /// Returns the query with its slop replaced by `slop`.
    pub fn with_slop(self, slop: u32) -> Self {
        Self { slop, ..self }
    }

    /// Returns the query with `stats` attached as its global statistics.
    pub fn with_global_stats(self, stats: Arc<GlobalStats>) -> Self {
        Self {
            global_stats: Some(stats),
            ..self
        }
    }
}
/// Builds a boxed [`PhraseScorer`] from the per-term posting and position lists.
///
/// * `term_data` - one `(postings, positions)` pair per phrase term, in phrase order.
/// * `slop` - positional tolerance forwarded to the scorer.
/// * `reader` - supplies the document count and the average length of `field`.
/// * `field` - field whose average length is looked up.
///
/// The scorer's IDF is the sum over all terms of `super::bm25_idf`, each
/// computed from that term's `doc_count()` and the reader's `num_docs()`.
fn build_phrase_scorer<'a>(
    term_data: Vec<(BlockPostingList, PositionPostingList)>,
    slop: u32,
    reader: &SegmentReader,
    field: Field,
) -> Box<dyn Scorer + 'a> {
    // The document count is the same for every term, so read it once instead
    // of once per term inside the closure.
    let num_docs = reader.num_docs() as f32;
    let idf: f32 = term_data
        .iter()
        .map(|(postings, _)| super::bm25_idf(postings.doc_count() as f32, num_docs))
        .sum();
    let avg_field_len = reader.avg_field_len(field);

    // Split the pairs into the two parallel vectors the scorer expects.
    let (postings, positions): (Vec<_>, Vec<_>) = term_data.into_iter().unzip();
    Box::new(PhraseScorer::new(
        postings,
        positions,
        slop,
        idf,
        avg_field_len,
    ))
}
// Expands to the early-return checks shared by `scorer` and `scorer_sync`.
// In order:
//   * no terms                           -> return an `EmptyScorer`
//   * exactly one term                   -> delegate to a `TermQuery` on that term
//   * `$reader.has_positions($field)` is false
//                                        -> delegate to a `BooleanQuery` that
//                                           requires (`must`) every term
// `$scorer_fn` is the method called on the delegate query; any trailing
// tokens (e.g. `await`) are appended to that call as `.token`.
// If none of the conditions hold, expansion falls through to the caller's code.
macro_rules! phrase_early_returns {
($field:expr, $terms:expr, $reader:expr, $limit:expr,
$scorer_fn:ident $(, $aw:tt)*) => {
// Nothing to match.
if $terms.is_empty() {
return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + '_>);
}
// A single-term "phrase" is handled as a plain term query.
if $terms.len() == 1 {
let tq = super::TermQuery::new($field, $terms[0].clone());
return tq.$scorer_fn($reader, $limit) $(. $aw)* ;
}
// Without positions, fall back to requiring all terms in the document.
if !$reader.has_positions($field) {
let mut bq = super::BooleanQuery::new();
for t in $terms.iter() {
bq = bq.must(super::TermQuery::new($field, t.clone()));
}
return bq.$scorer_fn($reader, $limit) $(. $aw)* ;
}
};
}
impl Query for PhraseQuery {
    /// Builds the asynchronous scorer for this phrase on `reader`.
    ///
    /// The `phrase_early_returns!` checks run first (empty phrase, single
    /// term, field without positions); `limit` is only forwarded to the
    /// delegate queries used there. Otherwise terms are loaded one after
    /// another, with each term's postings and positions fetched concurrently.
    /// If any term has no postings or no positions, an `EmptyScorer` is
    /// returned.
    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
        // Copy/clone what the future needs so it does not borrow `self`.
        let (field, slop) = (self.field, self.slop);
        let terms = self.terms.clone();
        Box::pin(async move {
            phrase_early_returns!(field, terms, reader, limit, scorer, await);

            let mut term_data = Vec::with_capacity(terms.len());
            for term in terms.iter() {
                let (postings, positions) = futures::join!(
                    reader.get_postings(field, term),
                    reader.get_positions(field, term)
                );
                // Postings errors are reported before positions errors.
                let postings = postings?;
                let positions = positions?;
                match postings.zip(positions) {
                    Some(pair) => term_data.push(pair),
                    None => return Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
                }
            }
            Ok(build_phrase_scorer(term_data, slop, reader, field))
        })
    }

    /// Synchronous counterpart of `scorer`.
    ///
    /// Applies the same early returns, then loads every term's postings and
    /// positions through rayon's `par_iter`. A load error takes precedence;
    /// otherwise a term missing postings or positions yields an `EmptyScorer`.
    #[cfg(feature = "sync")]
    fn scorer_sync<'a>(
        &self,
        reader: &'a SegmentReader,
        limit: usize,
    ) -> crate::Result<Box<dyn Scorer + 'a>> {
        phrase_early_returns!(self.field, self.terms, reader, limit, scorer_sync);
        use rayon::prelude::*;

        let loaded = self
            .terms
            .par_iter()
            .map(
                |term| -> crate::Result<Option<(BlockPostingList, PositionPostingList)>> {
                    let postings = reader.get_postings_sync(self.field, term)?;
                    let positions = reader.get_positions_sync(self.field, term)?;
                    Ok(postings.zip(positions))
                },
            )
            .collect::<crate::Result<Vec<_>>>()?;

        // `None` as soon as one term lacked postings or positions.
        match loaded.into_iter().collect::<Option<Vec<_>>>() {
            Some(term_data) => Ok(build_phrase_scorer(
                term_data, self.slop, reader, self.field,
            )),
            None => Ok(Box::new(EmptyScorer) as Box<dyn Scorer + 'a>),
        }
    }

    /// Estimates how many documents match the phrase.
    ///
    /// Returns `0` when there are no terms or when any term has no postings.
    /// Otherwise returns one tenth of the smallest per-term `doc_count()`,
    /// but never less than `1`.
    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
        let field = self.field;
        let terms = self.terms.clone();
        Box::pin(async move {
            if terms.is_empty() {
                return Ok(0);
            }

            let mut rarest = u32::MAX;
            for term in terms.iter() {
                let list = match reader.get_postings(field, term).await? {
                    None => return Ok(0),
                    Some(list) => list,
                };
                rarest = std::cmp::min(rarest, list.doc_count());
            }
            Ok(std::cmp::max(rarest / 10, 1))
        })
    }
}
/// Scorer that walks the posting lists of all phrase terms together and stops
/// only on documents whose term positions form the phrase.
struct PhraseScorer {
/// One posting iterator per phrase term, in phrase order.
posting_iters: Vec<BlockPostingIterator<'static>>,
/// One position list per phrase term, parallel to `posting_iters`.
position_lists: Vec<PositionPostingList>,
/// Maximum allowed distance between a term's actual and expected position.
slop: u32,
/// Document the scorer is currently positioned on, or `TERMINATED`.
current_doc: DocId,
/// IDF value passed to `bm25_score` when scoring.
idf: f32,
/// Average field length passed to `bm25_score` when scoring.
avg_field_len: f32,
/// Per-term scratch buffers that `check_phrase_positions` fills with the
/// positions of the candidate document.
position_bufs: Vec<Vec<u32>>,
}
impl PhraseScorer {
    /// Creates the scorer and positions it on the first phrase match.
    ///
    /// `posting_lists` and `position_lists` are parallel, one entry per
    /// phrase term. After construction `current_doc` is either the first
    /// matching document or `TERMINATED`.
    fn new(
        posting_lists: Vec<BlockPostingList>,
        position_lists: Vec<PositionPostingList>,
        slop: u32,
        idf: f32,
        avg_field_len: f32,
    ) -> Self {
        // One empty scratch buffer per term.
        let position_bufs = vec![Vec::new(); position_lists.len()];
        let posting_iters = posting_lists
            .into_iter()
            .map(|list| list.into_iterator())
            .collect();

        let mut this = Self {
            posting_iters,
            position_lists,
            slop,
            current_doc: 0,
            idf,
            avg_field_len,
            position_bufs,
        };
        this.find_next_phrase_match();
        this
    }

    /// Moves `current_doc` to the next document (starting at the iterators'
    /// current positions) that contains all terms and passes the position
    /// check, or to `TERMINATED` if there is none.
    fn find_next_phrase_match(&mut self) {
        self.current_doc = loop {
            let candidate = self.find_next_and_match();
            if candidate == TERMINATED || self.check_phrase_positions(candidate) {
                break candidate;
            }
            // Candidate has all terms but not as a phrase: step past it.
            self.posting_iters[0].advance();
        };
    }

    /// Returns the next document on which every posting iterator is
    /// positioned, or `TERMINATED` when there are no iterators or any of
    /// them runs out.
    fn find_next_and_match(&mut self) -> DocId {
        loop {
            // The largest current doc is the earliest possible common doc.
            let target = match self.posting_iters.iter().map(|it| it.doc()).max() {
                Some(doc) if doc != TERMINATED => doc,
                _ => return TERMINATED,
            };

            let mut aligned = true;
            for cursor in self.posting_iters.iter_mut() {
                let landed = cursor.seek(target);
                if landed == TERMINATED {
                    return TERMINATED;
                }
                if landed != target {
                    aligned = false;
                }
            }
            if aligned {
                return target;
            }
        }
    }

    /// Loads each term's positions for `doc_id` into its scratch buffer and
    /// reports whether they form the phrase. Returns `false` as soon as a
    /// position list reports failure for `doc_id`.
    fn check_phrase_positions(&mut self, doc_id: DocId) -> bool {
        let all_loaded = self
            .position_lists
            .iter()
            .zip(self.position_bufs.iter_mut())
            .all(|(list, buf)| list.get_positions_into(doc_id, buf));
        all_loaded && self.find_phrase_match_from_bufs()
    }

    /// Tries every position of the first term as the phrase start; `false`
    /// when there are no buffers or the first term has no positions.
    fn find_phrase_match_from_bufs(&self) -> bool {
        match self.position_bufs.first() {
            Some(first_term_positions) => first_term_positions
                .iter()
                .any(|&start| self.check_phrase_from_position(start, &self.position_bufs)),
            None => false,
        }
    }

    /// Checks whether a phrase can start at `start_pos`.
    ///
    /// For the term at index `i > 0` the expected position is
    /// `start_pos + i`; the term must have a position whose distance from
    /// that expected position is at most `slop`. The first entry of
    /// `term_positions` is not inspected.
    fn check_phrase_from_position(&self, start_pos: u32, term_positions: &[Vec<u32>]) -> bool {
        (1u32..)
            .zip(term_positions.iter().skip(1))
            .all(|(offset, positions)| {
                let expected_pos = start_pos + offset;
                // With slop 0 this is an exact equality check.
                positions
                    .iter()
                    .any(|&pos| pos.abs_diff(expected_pos) <= self.slop)
            })
    }
}
impl super::docset::DocSet for PhraseScorer {
    /// Returns the document the scorer is positioned on, or `TERMINATED`.
    fn doc(&self) -> DocId {
        self.current_doc
    }

    /// Moves to the next phrase match after the current document and returns
    /// it; returns `TERMINATED` once the scorer is exhausted.
    fn advance(&mut self) -> DocId {
        if self.current_doc == TERMINATED {
            return TERMINATED;
        }
        // Not terminated, so there is at least one posting iterator: with no
        // iterators the constructor's initial search already ends in TERMINATED.
        self.posting_iters[0].advance();
        self.find_next_phrase_match();
        self.current_doc
    }

    /// Moves to the first phrase match at or after `target` and returns it.
    ///
    /// Seeking an exhausted scorer returns `TERMINATED` without touching the
    /// posting iterators, mirroring `advance`. This also avoids indexing
    /// `posting_iters[0]` on a scorer that has no iterators. Seeking to
    /// `TERMINATED` exhausts the scorer.
    fn seek(&mut self, target: DocId) -> DocId {
        if self.current_doc == TERMINATED {
            return TERMINATED;
        }
        if target == TERMINATED {
            self.current_doc = TERMINATED;
            return TERMINATED;
        }
        self.posting_iters[0].seek(target);
        self.find_next_phrase_match();
        self.current_doc
    }

    /// No size estimate is provided; always `0`.
    fn size_hint(&self) -> u32 {
        0
    }
}
impl Scorer for PhraseScorer {
    /// Scores the current document; `0.0` when the scorer is exhausted.
    ///
    /// The term frequencies of all posting iterators are summed and fed to
    /// `super::bm25_score` together with the stored IDF and average field
    /// length; the result is multiplied by a fixed factor of 1.5.
    fn score(&self) -> Score {
        // Fixed multiplier applied to the BM25 value of a phrase match.
        const PHRASE_BOOST: f32 = 1.5;

        if self.current_doc == TERMINATED {
            0.0
        } else {
            let summed_tf = self
                .posting_iters
                .iter()
                .map(|it| it.term_freq() as f32)
                .sum::<f32>();
            // NOTE(review): the summed tf is also passed as the third
            // argument — presumably standing in for the document's field
            // length; confirm against `bm25_score`'s signature.
            super::bm25_score(summed_tf, self.idf, summed_tf, self.avg_field_len) * PHRASE_BOOST
        }
    }
}