use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::lexical::index::inverted::core::terms::TermsEnum;
use crate::lexical::query::Query;
use crate::lexical::reader::LexicalIndexReader;
/// A query that stands for a set of index terms and is executed by rewriting itself into a
/// disjunction of single-term queries.
///
/// Implementors describe *which* terms match (`enumerate_terms` / `get_terms_enum`); the
/// provided [`rewrite`](MultiTermQuery::rewrite) turns those terms into a boolean query.
pub trait MultiTermQuery: Query {
    /// Field the expanded terms belong to; `rewrite` uses it as the field of every generated
    /// term query.
    fn field(&self) -> &str;

    /// Strategy `rewrite` uses when gathering terms from a terms enumerator.
    fn rewrite_method(&self) -> RewriteMethod;

    /// Eagerly lists all matching terms.
    ///
    /// `rewrite` falls back to this when `get_terms_enum` yields `None`, and uses the result
    /// as-is (no limit is applied on this path). The tuple layout is presumably
    /// `(term, doc_freq, boost)`, mirroring what `rewrite` builds on the streaming path.
    fn enumerate_terms(
        &self,
        reader: &dyn LexicalIndexReader,
    ) -> Result<Vec<(String, u64, f32)>>;

    /// Expansion limit advertised by this query; defaults to 50.
    ///
    /// The provided `rewrite` does not consult this value: it takes the limit carried by the
    /// `RewriteMethod` variant instead.
    fn max_expansions(&self) -> usize {
        50
    }

    /// Optional streaming source of matching terms.
    ///
    /// The default implementation reports that no enumerator is available (`Ok(None)`), which
    /// makes `rewrite` use `enumerate_terms`.
    fn get_terms_enum(
        &self,
        _reader: &dyn LexicalIndexReader,
    ) -> Result<Option<Box<dyn TermsEnum>>> {
        Ok(None)
    }

    /// Rewrites this query into a boolean query with one `Should` term clause per matching term.
    ///
    /// Terms come from `get_terms_enum` when it provides an enumerator (bounded through
    /// `collect_top_terms` for the top-terms rewrite methods, fully drained otherwise) and from
    /// `enumerate_terms` when it does not. When nothing matches, an empty boolean query without
    /// clauses or boost is returned.
    ///
    /// # Errors
    ///
    /// Propagates any error raised while obtaining or iterating the terms.
    fn rewrite(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Query>> {
        use crate::lexical::query::boolean::{BooleanClause, BooleanQuery, Occur};
        use crate::lexical::query::term::TermQuery;

        let method = self.rewrite_method();

        let expansions = match self.get_terms_enum(reader)? {
            // No streaming source: rely on the implementor's eager enumeration.
            None => self.enumerate_terms(reader)?,
            Some(mut stream) => match method {
                RewriteMethod::TopTermsScoring { max_expansions }
                | RewriteMethod::TopTermsBlended { max_expansions } => {
                    collect_top_terms(&mut *stream, max_expansions)?
                }
                RewriteMethod::ConstantScore | RewriteMethod::BooleanQuery => {
                    // Unbounded methods take every term with a neutral boost.
                    let mut all_terms = Vec::new();
                    while let Some(stats) = stream.next()? {
                        all_terms.push((stats.term.clone(), stats.doc_freq, 1.0));
                    }
                    all_terms
                }
            },
        };

        if expansions.is_empty() {
            return Ok(Box::new(BooleanQuery::new()));
        }

        let mut disjunction = BooleanQuery::new();
        disjunction.set_boost(self.boost());

        // Every rewrite method currently produces the same clause shape, so a single loop
        // covers all of them.
        for (term, _doc_freq, _boost) in expansions {
            let single_term = TermQuery::new(MultiTermQuery::field(self).to_string(), term);
            disjunction.add_clause(BooleanClause::new(Box::new(single_term), Occur::Should));
        }

        Ok(Box::new(disjunction))
    }
}
/// Candidate term tracked while selecting the best expansions of a multi-term query.
struct ScoredTerm {
    /// Term text; tie-breaker of the ordering.
    term: String,
    /// Document frequency of the term; primary key of the ordering.
    doc_freq: u64,
    /// Boost carried through to the result tuple; not part of the ordering.
    boost: f32,
}

// Equality is derived from `Ord` so that `PartialEq`, `Eq`, `PartialOrd` and `Ord` all agree,
// and so the `f32` boost (which has no total equality) never takes part in it.
impl PartialEq for ScoredTerm {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for ScoredTerm {}

impl PartialOrd for ScoredTerm {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for ScoredTerm {
    /// Orders candidates best-first: a higher `doc_freq` compares as `Less`, and equal
    /// frequencies fall back to ascending term order. The *greatest* value is therefore the
    /// weakest candidate.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        other
            .doc_freq
            .cmp(&self.doc_freq)
            .then_with(|| self.term.cmp(&other.term))
    }
}

/// Drains `terms_enum` and returns at most `max_expansions` terms with the highest document
/// frequency as `(term, doc_freq, boost)` tuples, best first.
///
/// Ties on document frequency are resolved in favour of the lexicographically smaller term.
/// Every returned boost is `1.0`. The enumerator is always consumed to the end, even when
/// `max_expansions` is `0` (in which case the result is empty).
///
/// # Errors
///
/// Propagates the first error returned by `terms_enum.next()`.
fn collect_top_terms(
    terms_enum: &mut dyn TermsEnum,
    max_expansions: usize,
) -> Result<Vec<(String, u64, f32)>> {
    use std::collections::BinaryHeap;

    // The limit may come from deserialized input, so never preallocate more than a modest
    // amount up front; the heap still grows on demand up to `max_expansions + 1` entries.
    const MAX_PREALLOCATED_TERMS: usize = 1024;

    // `BinaryHeap` is a max-heap and `ScoredTerm`'s ordering puts the weakest candidate last,
    // so the heap's top is exactly the entry to evict on overflow. (Wrapping entries in
    // `Reverse` here would cancel that ordering and evict the most frequent term instead.)
    let mut heap: BinaryHeap<ScoredTerm> = BinaryHeap::with_capacity(
        max_expansions
            .saturating_add(1)
            .min(MAX_PREALLOCATED_TERMS),
    );

    while let Some(term_stats) = terms_enum.next()? {
        heap.push(ScoredTerm {
            term: term_stats.term.clone(),
            doc_freq: term_stats.doc_freq,
            boost: 1.0,
        });
        if heap.len() > max_expansions {
            heap.pop();
        }
    }

    // Ascending order under `ScoredTerm`'s `Ord` means best candidate first.
    Ok(heap
        .into_sorted_vec()
        .into_iter()
        .map(|candidate| (candidate.term, candidate.doc_freq, candidate.boost))
        .collect())
}
/// Strategy selecting how `MultiTermQuery::rewrite` gathers the terms it expands into.
///
/// NOTE(review): the variant names suggest different scoring behaviour, but the default
/// `rewrite` in this file builds the same `Should` term clauses for every variant; only the
/// term-gathering step differs (bounded vs. unbounded).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RewriteMethod {
/// Bounded expansion: when a terms enumerator is available, `rewrite` passes
/// `max_expansions` to `collect_top_terms`.
TopTermsScoring { max_expansions: usize },
/// Bounded expansion, gathered the same way as `TopTermsScoring`. This is the variant
/// returned by `Default`, with `max_expansions` set to 50.
TopTermsBlended { max_expansions: usize },
/// Unbounded expansion: `rewrite` collects every term the enumerator yields.
ConstantScore,
/// Unbounded expansion: `rewrite` collects every term the enumerator yields.
BooleanQuery,
}
impl Default for RewriteMethod {
fn default() -> Self {
RewriteMethod::TopTermsBlended { max_expansions: 50 }
}
}
impl RewriteMethod {
pub fn max_expansions(&self) -> Option<usize> {
match self {
RewriteMethod::TopTermsScoring { max_expansions } => Some(*max_expansions),
RewriteMethod::TopTermsBlended { max_expansions } => Some(*max_expansions),
RewriteMethod::ConstantScore => None,
RewriteMethod::BooleanQuery => None,
}
}
pub fn is_constant_score(&self) -> bool {
matches!(self, RewriteMethod::ConstantScore)
}
pub fn is_top_terms(&self) -> bool {
matches!(
self,
RewriteMethod::TopTermsScoring { .. } | RewriteMethod::TopTermsBlended { .. }
)
}
}