use crate::error::Result;
use crate::lexical::query::Query;
use crate::lexical::query::matcher::{
AllMatcher, ConjunctionMatcher, ConjunctionNotMatcher, DisjunctionMatcher, EmptyMatcher,
Matcher, NotMatcher,
};
use crate::lexical::query::scorer::{BM25Scorer, Scorer};
use crate::lexical::reader::LexicalIndexReader;
/// How a clause participates in a [`BooleanQuery`].
///
/// `Hash` is derived alongside the comparison traits so the enum can be used
/// as a key in hashed collections (e.g. to group clauses by occurrence).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Occur {
    /// Required clause. Pooled with `Filter` clauses when the matcher is
    /// built, handed to the scorer, and rendered with a `+` prefix.
    Must,
    /// Optional clause. Drives matching when no required clause exists (or
    /// when `minimum_should_match > 0`), is handed to the scorer, and is
    /// rendered without a prefix.
    Should,
    /// Excluding clause. Its matches are removed from the result and it is
    /// rendered with a `-` prefix.
    MustNot,
    /// Required clause that is not handed to the scorer. Rendered with a `#`
    /// prefix.
    Filter,
}
/// A sub-query together with the [`Occur`] mode that says how it takes part
/// in the enclosing boolean query.
#[derive(Debug)]
pub struct BooleanClause {
/// The wrapped sub-query.
pub query: Box<dyn Query>,
/// How this clause participates (must / should / must-not / filter).
pub occur: Occur,
}
impl Clone for BooleanClause {
fn clone(&self) -> Self {
BooleanClause {
query: self.query.clone_box(),
occur: self.occur,
}
}
}
impl BooleanClause {
    /// Pairs `query` with an explicit occurrence mode.
    pub fn new(query: Box<dyn Query>, occur: Occur) -> Self {
        Self { query, occur }
    }

    /// Shorthand for a clause with [`Occur::Must`].
    pub fn must(query: Box<dyn Query>) -> Self {
        Self {
            query,
            occur: Occur::Must,
        }
    }

    /// Shorthand for a clause with [`Occur::Should`].
    pub fn should(query: Box<dyn Query>) -> Self {
        Self {
            query,
            occur: Occur::Should,
        }
    }

    /// Shorthand for a clause with [`Occur::MustNot`].
    pub fn must_not(query: Box<dyn Query>) -> Self {
        Self {
            query,
            occur: Occur::MustNot,
        }
    }

    /// Shorthand for a clause with [`Occur::Filter`].
    pub fn filter(query: Box<dyn Query>) -> Self {
        Self {
            query,
            occur: Occur::Filter,
        }
    }
}
/// A query composed of [`BooleanClause`]s combined according to their
/// [`Occur`] modes.
#[derive(Debug)]
pub struct BooleanQuery {
// Clauses in insertion order.
clauses: Vec<BooleanClause>,
// Boost factor; `BooleanQuery::new` initialises it to 1.0.
boost: f32,
// Minimum-should-match setting; `BooleanQuery::new` initialises it to 0.
minimum_should_match: usize,
}
impl BooleanQuery {
pub fn new() -> Self {
BooleanQuery {
clauses: Vec::new(),
boost: 1.0,
minimum_should_match: 0,
}
}
pub fn add_clause(&mut self, clause: BooleanClause) {
self.clauses.push(clause);
}
pub fn add_must(&mut self, query: Box<dyn Query>) {
self.add_clause(BooleanClause::must(query));
}
pub fn add_should(&mut self, query: Box<dyn Query>) {
self.add_clause(BooleanClause::should(query));
}
pub fn add_must_not(&mut self, query: Box<dyn Query>) {
self.add_clause(BooleanClause::must_not(query));
}
pub fn add_filter(&mut self, query: Box<dyn Query>) {
self.add_clause(BooleanClause::filter(query));
}
pub fn with_boost(mut self, boost: f32) -> Self {
self.boost = boost;
self
}
pub fn with_minimum_should_match(mut self, minimum: usize) -> Self {
self.minimum_should_match = minimum;
self
}
pub fn clauses(&self) -> &[BooleanClause] {
&self.clauses
}
pub fn minimum_should_match(&self) -> usize {
self.minimum_should_match
}
pub fn is_empty(&self) -> bool {
self.clauses.is_empty()
}
pub fn clauses_by_occur(&self, occur: Occur) -> Vec<&BooleanClause> {
self.clauses.iter().filter(|c| c.occur == occur).collect()
}
}
impl Default for BooleanQuery {
fn default() -> Self {
Self::new()
}
}
impl Clone for BooleanQuery {
fn clone(&self) -> Self {
BooleanQuery {
clauses: self
.clauses
.iter()
.map(|c| BooleanClause {
query: c.query.clone_box(),
occur: c.occur,
})
.collect(),
boost: self.boost,
minimum_should_match: self.minimum_should_match,
}
}
}
impl Query for BooleanQuery {
/// Builds the matcher for this boolean query against `reader`.
///
/// Clauses are grouped by [`Occur`]; `Must` and `Filter` clauses are pooled as
/// "required". The result is assembled as follows:
///
/// * no clauses at all: `EmptyMatcher`;
/// * at least one required clause, or only `MustNot` clauses with no `Should`
///   clause: a positive matcher (the required matchers, or an `AllMatcher`
///   when there are none), optionally combined with the `Should` matchers
///   when `minimum_should_match > 0`, then restricted by the `MustNot`
///   matchers;
/// * otherwise, if there are `Should` clauses: the non-exhausted `Should`
///   matchers, restricted by the `MustNot` matchers.
///
/// NOTE(review): only `minimum_should_match > 0` is consulted, and only in the
/// first branch; the numeric value is never compared with the number of
/// `Should` clauses that match. Confirm this is the intended semantics.
///
/// # Errors
///
/// Propagates any error returned by a sub-query's `matcher` call.
fn matcher(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Matcher>> {
// Nothing to match against.
if self.clauses.is_empty() {
return Ok(Box::new(EmptyMatcher::new()));
}
// Partition the clauses by occurrence mode.
let must_clauses = self.clauses_by_occur(Occur::Must);
let filter_clauses = self.clauses_by_occur(Occur::Filter);
let should_clauses = self.clauses_by_occur(Occur::Should);
let must_not_clauses = self.clauses_by_occur(Occur::MustNot);
// Must and Filter are treated identically for matching purposes.
let mut required_clauses: Vec<&BooleanClause> = Vec::new();
required_clauses.extend(&must_clauses);
required_clauses.extend(&filter_clauses);
// Branch 1: there is something required, or the query is purely negative
// (MustNot clauses only, no Should clauses).
if !required_clauses.is_empty()
|| (!must_not_clauses.is_empty() && should_clauses.is_empty())
{
let mut positive_matcher = if !required_clauses.is_empty() {
if required_clauses.len() == 1 {
// A single required clause is used directly (no exhaustion check here).
required_clauses[0].query.matcher(reader)?
} else {
let mut matchers = Vec::new();
for clause in &required_clauses {
let matcher = clause.query.matcher(reader)?;
// One exhausted required matcher makes the whole query empty.
if matcher.is_exhausted() {
return Ok(Box::new(EmptyMatcher::new()));
}
matchers.push(matcher);
}
Box::new(ConjunctionMatcher::new(matchers))
}
} else {
// Purely negative query: start from an AllMatcher sized by `reader.max_doc()`.
Box::new(AllMatcher::new(reader.max_doc()))
};
// Should clauses only take part in matching here when a minimum is requested.
if self.minimum_should_match > 0 && !should_clauses.is_empty() {
let mut should_matchers = Vec::new();
for clause in &should_clauses {
let matcher = clause.query.matcher(reader)?;
// Exhausted Should matchers are dropped.
if !matcher.is_exhausted() {
should_matchers.push(matcher);
}
}
if !should_matchers.is_empty() {
let should_matcher = if should_matchers.len() == 1 {
should_matchers.into_iter().next().unwrap()
} else {
Box::new(DisjunctionMatcher::new(should_matchers))
};
// The positive matcher is wrapped together with the Should matcher.
positive_matcher = Box::new(ConjunctionMatcher::new(vec![
positive_matcher,
should_matcher,
]));
} else if !required_clauses.is_empty() {
// A minimum was requested but every Should matcher is exhausted.
return Ok(Box::new(EmptyMatcher::new()));
}
}
if !must_not_clauses.is_empty() {
let mut negative_matchers = Vec::new();
for clause in &must_not_clauses {
let matcher = clause.query.matcher(reader)?;
// Exhausted MustNot matchers are dropped.
if !matcher.is_exhausted() {
negative_matchers.push(matcher);
}
}
if !negative_matchers.is_empty() {
if required_clauses.is_empty() {
// Purely negative query: `positive_matcher` (the AllMatcher) is not used;
// a NotMatcher is built from the negatives and `reader.max_doc()`.
if negative_matchers.len() == 1 {
Ok(Box::new(NotMatcher::new(
negative_matchers.into_iter().next().unwrap(),
reader.max_doc(),
)))
} else {
// Several negatives are first merged into one DisjunctionMatcher.
let combined_negatives =
Box::new(DisjunctionMatcher::new(negative_matchers));
Ok(Box::new(NotMatcher::new(
combined_negatives,
reader.max_doc(),
)))
}
} else {
// Required clauses present: combine positive and negative matchers.
Ok(Box::new(ConjunctionNotMatcher::new(
positive_matcher,
negative_matchers,
)))
}
} else {
// Every MustNot matcher was exhausted.
Ok(positive_matcher)
}
} else {
Ok(positive_matcher)
}
// Branch 2: no required clauses, but at least one Should clause.
} else if !should_clauses.is_empty() {
let mut should_matchers = Vec::new();
for clause in &should_clauses {
let matcher = clause.query.matcher(reader)?;
if !matcher.is_exhausted() {
should_matchers.push(matcher);
}
}
// Every Should matcher is exhausted.
if should_matchers.is_empty() {
return Ok(Box::new(EmptyMatcher::new()));
}
let positive_matcher = if should_matchers.len() == 1 {
should_matchers.into_iter().next().unwrap()
} else {
Box::new(DisjunctionMatcher::new(should_matchers))
};
if !must_not_clauses.is_empty() {
let mut negative_matchers = Vec::new();
for clause in &must_not_clauses {
let matcher = clause.query.matcher(reader)?;
if !matcher.is_exhausted() {
negative_matchers.push(matcher);
}
}
if !negative_matchers.is_empty() {
Ok(Box::new(ConjunctionNotMatcher::new(
positive_matcher,
negative_matchers,
)))
} else {
Ok(positive_matcher)
}
} else {
Ok(positive_matcher)
}
} else {
// Defensive fallback: with the four `Occur` variants and the early return
// for an empty clause list, this arm is not expected to be reached.
Ok(Box::new(EmptyMatcher::new()))
}
}
fn scorer(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Scorer>> {
use crate::lexical::query::scorer::BooleanScorer;
let mut sub_queries = Vec::new();
for clause in &self.clauses {
if clause.occur == Occur::Must || clause.occur == Occur::Should {
sub_queries.push(clause.query.clone_box());
}
}
if sub_queries.is_empty() {
let scorer = BM25Scorer::new(
1,
1,
reader.doc_count(),
10.0,
reader.doc_count(),
self.boost,
);
return Ok(Box::new(scorer));
}
let mut boolean_scorer = BooleanScorer::new(reader, sub_queries)?;
boolean_scorer.set_boost(self.boost);
Ok(Box::new(boolean_scorer))
}
/// Returns the boost factor currently stored on this query.
fn boost(&self) -> f32 {
self.boost
}
/// Replaces the stored boost factor with `boost`; sub-queries are not touched.
fn set_boost(&mut self, boost: f32) {
self.boost = boost;
}
/// Renders the query as text.
///
/// An empty query renders as `()`. Otherwise each clause's description is
/// prefixed according to its occurrence mode (`+` Must, none for Should,
/// `-` MustNot, `#` Filter), the parts are joined with single spaces and
/// wrapped in parentheses. A `^<boost>` suffix is appended whenever the boost
/// is not exactly `1.0`.
fn description(&self) -> String {
    if self.clauses.is_empty() {
        return String::from("()");
    }

    let rendered: Vec<String> = self
        .clauses
        .iter()
        .map(|clause| {
            let inner = clause.query.description();
            match clause.occur {
                Occur::Must => format!("+{}", inner),
                Occur::Should => inner,
                Occur::MustNot => format!("-{}", inner),
                Occur::Filter => format!("#{}", inner),
            }
        })
        .collect();

    let grouped = format!("({})", rendered.join(" "));
    if self.boost != 1.0 {
        format!("{}^{}", grouped, self.boost)
    } else {
        grouped
    }
}
/// Returns a boxed clone of this query as a `Query` trait object.
fn clone_box(&self) -> Box<dyn Query> {
Box::new(self.clone())
}
/// Reports whether this query is empty with respect to `reader`.
///
/// Returns `Ok(true)` when there are no clauses or when every clause's
/// sub-query reports itself empty; returns `Ok(false)` as soon as one
/// sub-query reports non-empty (later clauses are then not consulted).
///
/// # Errors
///
/// Propagates any error from a sub-query's `is_empty` call.
fn is_empty(&self, reader: &dyn LexicalIndexReader) -> Result<bool> {
    // With no clauses the loop body never runs and we fall through to `true`.
    for clause in &self.clauses {
        let sub_query_empty = clause.query.is_empty(reader)?;
        if sub_query_empty {
            continue;
        }
        return Ok(false);
    }
    Ok(true)
}
/// Estimates the cost of this query as the sum of its clauses' costs.
///
/// The sum saturates at `u64::MAX` instead of overflowing, so a clause that
/// reports a very large cost cannot trigger an overflow panic (debug builds)
/// or wrap around to a misleadingly small total (release builds). An empty
/// query costs `0`.
///
/// # Errors
///
/// Propagates any error from a sub-query's `cost` call.
fn cost(&self, reader: &dyn LexicalIndexReader) -> Result<u64> {
    let mut total_cost: u64 = 0;
    for clause in &self.clauses {
        total_cost = total_cost.saturating_add(clause.query.cost(reader)?);
    }
    Ok(total_cost)
}
/// Exposes this query as `&dyn Any`, which allows callers to downcast it.
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn apply_field_boosts(&mut self, boosts: &std::collections::HashMap<String, f32>) {
if let Some(f) = self.field()
&& let Some(&b) = boosts.get(f)
{
self.set_boost(self.boost() * b);
}
for clause in &mut self.clauses {
clause.query.apply_field_boosts(boosts);
}
}
}
/// Chained (consuming) builder for [`BooleanQuery`].
#[derive(Debug)]
pub struct BooleanQueryBuilder {
// The query being assembled; handed out by `build`.
query: BooleanQuery,
}
impl BooleanQueryBuilder {
pub fn new() -> Self {
BooleanQueryBuilder {
query: BooleanQuery::new(),
}
}
pub fn must(mut self, query: Box<dyn Query>) -> Self {
self.query.add_must(query);
self
}
pub fn should(mut self, query: Box<dyn Query>) -> Self {
self.query.add_should(query);
self
}
pub fn must_not(mut self, query: Box<dyn Query>) -> Self {
self.query.add_must_not(query);
self
}
pub fn filter(mut self, query: Box<dyn Query>) -> Self {
self.query.add_filter(query);
self
}
pub fn boost(mut self, boost: f32) -> Self {
self.query = self.query.with_boost(boost);
self
}
pub fn minimum_should_match(mut self, minimum: usize) -> Self {
self.query = self.query.with_minimum_should_match(minimum);
self
}
pub fn build(self) -> BooleanQuery {
self.query
}
}
impl Default for BooleanQueryBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexical::index::inverted::reader::{InvertedIndexReader, InvertedIndexReaderConfig};
    use crate::lexical::query::term::TermQuery;
    use crate::storage::memory::MemoryStorage;
    use crate::storage::memory::MemoryStorageConfig;
    use std::sync::Arc;

    /// A freshly created query has no clauses and default settings.
    #[test]
    fn test_boolean_query_creation() {
        let query = BooleanQuery::new();
        assert!(query.is_empty());
        assert_eq!(query.clauses().len(), 0);
        assert_eq!(query.boost(), 1.0);
        assert_eq!(query.minimum_should_match(), 0);
    }

    /// Clauses added through the `add_*` helpers are retrievable by occurrence.
    #[test]
    fn test_boolean_query_clauses() {
        let mut query = BooleanQuery::new();
        query.add_must(Box::new(TermQuery::new("title", "hello")));
        query.add_should(Box::new(TermQuery::new("body", "world")));
        query.add_must_not(Box::new(TermQuery::new("title", "spam")));
        assert_eq!(query.clauses().len(), 3);
        assert!(!query.is_empty());
        let must_clauses = query.clauses_by_occur(Occur::Must);
        let should_clauses = query.clauses_by_occur(Occur::Should);
        let must_not_clauses = query.clauses_by_occur(Occur::MustNot);
        assert_eq!(must_clauses.len(), 1);
        assert_eq!(should_clauses.len(), 1);
        assert_eq!(must_not_clauses.len(), 1);
        // No filter clause was added.
        assert!(query.clauses_by_occur(Occur::Filter).is_empty());
    }

    /// The builder forwards clauses and settings to the built query.
    #[test]
    fn test_boolean_query_builder() {
        let query = BooleanQueryBuilder::new()
            .must(Box::new(TermQuery::new("title", "hello")))
            .should(Box::new(TermQuery::new("body", "world")))
            .must_not(Box::new(TermQuery::new("title", "spam")))
            .boost(2.0)
            .minimum_should_match(1)
            .build();
        assert_eq!(query.clauses().len(), 3);
        assert_eq!(query.boost(), 2.0);
        assert_eq!(query.minimum_should_match(), 1);
    }

    /// Must / Should / MustNot clauses are rendered with their prefixes.
    #[test]
    fn test_boolean_query_description() {
        let query = BooleanQueryBuilder::new()
            .must(Box::new(TermQuery::new("title", "hello")))
            .should(Box::new(TermQuery::new("body", "world")))
            .must_not(Box::new(TermQuery::new("title", "spam")))
            .build();
        let desc = query.description();
        assert!(desc.contains("+title:hello"));
        assert!(desc.contains("body:world"));
        assert!(desc.contains("-title:spam"));
    }

    /// Filter clauses use the `#` prefix and a non-default boost is appended
    /// as a `^<boost>` suffix.
    #[test]
    fn test_boolean_query_description_filter_and_boost() {
        let query = BooleanQueryBuilder::new()
            .filter(Box::new(TermQuery::new("title", "hello")))
            .boost(2.0)
            .build();
        assert_eq!(query.clauses_by_occur(Occur::Filter).len(), 1);
        let desc = query.description();
        assert!(desc.contains("#title:hello"));
        assert!(desc.ends_with("^2"));
    }

    /// A query without clauses renders as an empty pair of parentheses.
    #[test]
    fn test_empty_boolean_query_description() {
        assert_eq!(BooleanQuery::new().description(), "()");
    }

    /// Building a matcher against an empty index succeeds.
    #[test]
    fn test_boolean_query_matcher() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let reader =
            InvertedIndexReader::new(vec![], storage, InvertedIndexReaderConfig::default())
                .unwrap();
        let query = BooleanQueryBuilder::new()
            .must(Box::new(TermQuery::new("title", "hello")))
            .build();
        let matcher = query.matcher(&reader).unwrap();
        assert!(matcher.is_exhausted() || matcher.doc_id() != u64::MAX);
    }

    /// Building a scorer against an empty index succeeds and scores are
    /// non-negative.
    #[test]
    fn test_boolean_query_scorer() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let reader =
            InvertedIndexReader::new(vec![], storage, InvertedIndexReaderConfig::default())
                .unwrap();
        let query = BooleanQueryBuilder::new()
            .must(Box::new(TermQuery::new("title", "hello")))
            .build();
        let scorer = query.scorer(&reader).unwrap();
        assert!(scorer.score(0, 1.0, None) >= 0.0);
    }

    /// Each convenience constructor sets the matching occurrence mode.
    #[test]
    fn test_boolean_clause_creation() {
        let query = Box::new(TermQuery::new("title", "hello"));
        let must_clause = BooleanClause::must(query.clone_box());
        assert_eq!(must_clause.occur, Occur::Must);
        let should_clause = BooleanClause::should(query.clone_box());
        assert_eq!(should_clause.occur, Occur::Should);
        let must_not_clause = BooleanClause::must_not(query.clone_box());
        assert_eq!(must_not_clause.occur, Occur::MustNot);
        let filter_clause = BooleanClause::filter(query.clone_box());
        assert_eq!(filter_clause.occur, Occur::Filter);
    }

    /// An empty query has zero cost and an exhausted matcher.
    #[test]
    fn test_empty_boolean_query() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let reader =
            InvertedIndexReader::new(vec![], storage, InvertedIndexReaderConfig::default())
                .unwrap();
        let query = BooleanQuery::new();
        assert!(query.is_empty());
        assert_eq!(query.cost(&reader).unwrap(), 0);
        let matcher = query.matcher(&reader).unwrap();
        assert!(matcher.is_exhausted());
    }
}