use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use crate::error::Result;
use crate::lexical::query::Query;
use crate::lexical::query::QueryResult;
use crate::lexical::query::boolean::{BooleanQuery, Occur};
use crate::lexical::query::matcher::Matcher;
use crate::lexical::query::scorer::Scorer;
use crate::lexical::reader::LexicalIndexReader;
/// Execution settings for an [`AdvancedQuery`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdvancedQueryConfig {
/// When `true`, `AdvancedQuery::optimize` folds filters and negative filters into
/// the core query as a single boolean query.
pub enable_optimization: bool,
/// Maximum clause count. NOTE(review): not read by any code in this file.
pub max_clause_count: usize,
/// NOTE(review): not read by any code in this file; presumably reserved for a
/// result cache — confirm.
pub enable_caching: bool,
/// Wall-clock budget for `AdvancedQuery::execute`, in milliseconds; `0` disables it.
pub timeout_ms: u64,
/// When `true`, `AdvancedQuery::execute` stops collecting hits after a fixed count.
pub enable_early_termination: bool,
/// Minimum score a hit must reach; `execute` uses the larger of this and the
/// query's own minimum score.
pub min_score: f32,
}
impl Default for AdvancedQueryConfig {
fn default() -> Self {
AdvancedQueryConfig {
enable_optimization: true,
max_clause_count: 1024,
enable_caching: true,
timeout_ms: 30000, enable_early_termination: true,
min_score: 0.0,
}
}
}
/// A query wrapper that layers boosting, score thresholds and filtering on top of
/// a core query.
#[derive(Debug)]
pub struct AdvancedQuery {
/// The query whose matches and scores drive execution.
core_query: Box<dyn Query>,
/// Per-field boosts pushed into the core query by `apply_field_boosts`.
field_boosts: HashMap<String, f32>,
/// Multiplier applied to every hit's score in `execute`.
boost: f32,
/// Query-level minimum score; `execute` compares against the larger of this and
/// `config.min_score`.
min_score: f32,
/// Queries a hit must also match; `optimize` folds them into the core query as
/// `Occur::Filter` clauses.
filters: Vec<Box<dyn Query>>,
/// Queries a hit must not match; `optimize` folds them in as `Occur::MustNot` clauses.
negative_filters: Vec<Box<dyn Query>>,
/// Queries checked per hit by `execute`, after scoring and the minimum-score check.
post_filters: Vec<Box<dyn Query>>,
/// Execution settings (optimization, timeout, early termination, minimum score).
config: AdvancedQueryConfig,
}
impl AdvancedQuery {
    /// Maximum number of hits `execute` collects when
    /// `AdvancedQueryConfig::enable_early_termination` is set.
    const EARLY_TERMINATION_LIMIT: usize = 10_000;

    /// Wraps `core_query` with unit boost, no minimum score, no filters and the
    /// default [`AdvancedQueryConfig`].
    pub fn new(core_query: Box<dyn Query>) -> Self {
        AdvancedQuery {
            core_query,
            field_boosts: HashMap::new(),
            boost: 1.0,
            min_score: 0.0,
            filters: Vec::new(),
            negative_filters: Vec::new(),
            post_filters: Vec::new(),
            config: AdvancedQueryConfig::default(),
        }
    }

    /// Records a boost for `field`, replacing any earlier boost for the same field.
    pub fn add_field_boost(mut self, field: String, boost: f32) -> Self {
        self.field_boosts.insert(field, boost);
        self
    }

    /// Sets the multiplier applied to every hit's score.
    pub fn with_boost(mut self, boost: f32) -> Self {
        self.boost = boost;
        self
    }

    /// Sets the query-level minimum score (combined with `config.min_score` via `max`).
    pub fn with_min_score(mut self, min_score: f32) -> Self {
        self.min_score = min_score;
        self
    }

    /// Adds a query that every hit must also match (does not affect scoring).
    pub fn with_filter(mut self, filter: Box<dyn Query>) -> Self {
        self.filters.push(filter);
        self
    }

    /// Adds a query that no hit may match.
    pub fn with_negative_filter(mut self, filter: Box<dyn Query>) -> Self {
        self.negative_filters.push(filter);
        self
    }

    /// Adds a query checked per hit after scoring and the minimum-score check.
    pub fn with_post_filter(mut self, filter: Box<dyn Query>) -> Self {
        self.post_filters.push(filter);
        self
    }

    /// Replaces the execution settings.
    pub fn with_config(mut self, config: AdvancedQueryConfig) -> Self {
        self.config = config;
        self
    }

    /// Folds filters and negative filters into the core query.
    ///
    /// Filters become `Occur::Filter` clauses and negative filters become
    /// `Occur::MustNot` clauses. Together they form one boolean query that is
    /// evaluated by a single matcher. The folded filters are moved out of
    /// `filters`/`negative_filters`, leaving both empty.
    ///
    /// Does nothing when `config.enable_optimization` is false or there is nothing
    /// to fold. In that case `execute` evaluates the filters per hit instead.
    ///
    /// # Errors
    /// Currently never fails; the `Result` leaves room for fallible rewrites.
    pub fn optimize(&mut self) -> Result<()> {
        if !self.config.enable_optimization
            || (self.filters.is_empty() && self.negative_filters.is_empty())
        {
            return Ok(());
        }
        let mut builder =
            BooleanQueryBuilder::new().add_clause(self.core_query.clone_box(), Occur::Must);
        // Move the filters into the boolean query instead of cloning and then clearing.
        for filter in self.filters.drain(..) {
            builder = builder.add_clause(filter, Occur::Filter);
        }
        for neg_filter in self.negative_filters.drain(..) {
            builder = builder.add_clause(neg_filter, Occur::MustNot);
        }
        self.core_query = Box::new(builder.build());
        Ok(())
    }

    /// Runs the query against `reader` and returns hits sorted by descending score.
    ///
    /// Each hit's score is multiplied by the query boost. A hit is dropped when its
    /// score is below `max(self.min_score, config.min_score)` or is NaN. It is also
    /// dropped when it fails any filter that `optimize` did not fold in, or any
    /// post-filter.
    ///
    /// Collection stops early when `config.timeout_ms` (if non-zero) elapses. It
    /// also stops once `EARLY_TERMINATION_LIMIT` hits are gathered, when early
    /// termination is enabled. In either case the result is the best of the hits
    /// seen so far, not the global top hits.
    ///
    /// # Errors
    /// Propagates errors from building or advancing matchers and scorers,
    /// including those of the filters.
    pub fn execute(&mut self, reader: &dyn LexicalIndexReader) -> Result<Vec<QueryResult>> {
        self.optimize()?;
        let mut matcher = self.core_query.matcher(reader)?;
        let scorer = self.core_query.scorer(reader)?;
        let threshold = self.min_score.max(self.config.min_score);
        let start_time = std::time::Instant::now();
        let mut results = Vec::new();
        while matcher.next()? {
            if self.config.timeout_ms > 0
                && start_time.elapsed().as_millis() > u128::from(self.config.timeout_ms)
            {
                break;
            }
            let doc_id = matcher.doc_id();
            let score = scorer.score(doc_id, matcher.term_freq() as f32, None) * self.boost;
            // A NaN score can satisfy no threshold. It would also poison the ranking.
            if score.is_nan() || score < threshold {
                continue;
            }
            // Filters are still attached here only when optimize() declined to fold them.
            if !self.passes_unfolded_filters(doc_id, reader)? {
                continue;
            }
            if !self.apply_post_filters(doc_id, reader)? {
                continue;
            }
            results.push(QueryResult { doc_id, score });
            if self.config.enable_early_termination
                && results.len() >= Self::EARLY_TERMINATION_LIMIT
            {
                break;
            }
        }
        // total_cmp is a total order, so this cannot panic. NaN was filtered above anyway.
        results.sort_by(|a, b| b.score.total_cmp(&a.score));
        Ok(results)
    }

    /// Checks `doc_id` against filters and negative filters that were not folded
    /// into the core query.
    fn passes_unfolded_filters(
        &self,
        doc_id: u64,
        reader: &dyn LexicalIndexReader,
    ) -> Result<bool> {
        for filter in &self.filters {
            if !Self::doc_matches(&**filter, doc_id, reader)? {
                return Ok(false);
            }
        }
        for neg_filter in &self.negative_filters {
            if Self::doc_matches(&**neg_filter, doc_id, reader)? {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Returns `true` when `doc_id` matches every post-filter.
    fn apply_post_filters(&self, doc_id: u64, reader: &dyn LexicalIndexReader) -> Result<bool> {
        for filter in &self.post_filters {
            if !Self::doc_matches(&**filter, doc_id, reader)? {
                return Ok(false);
            }
        }
        Ok(true)
    }

    /// Returns `true` when `filter` matches exactly `doc_id`.
    ///
    /// Builds a fresh matcher for each call and skips it to `doc_id`.
    fn doc_matches(filter: &dyn Query, doc_id: u64, reader: &dyn LexicalIndexReader) -> Result<bool> {
        let mut matcher = filter.matcher(reader)?;
        Ok(matcher.skip_to(doc_id)? && matcher.doc_id() == doc_id)
    }
}
impl Query for AdvancedQuery {
fn matcher(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Matcher>> {
self.core_query.matcher(reader)
}
fn scorer(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Scorer>> {
self.core_query.scorer(reader)
}
fn boost(&self) -> f32 {
self.boost
}
fn set_boost(&mut self, boost: f32) {
self.boost = boost;
}
fn description(&self) -> String {
format!(
"AdvancedQuery(core: {}, boost: {})",
self.core_query.description(),
self.boost
)
}
fn is_empty(&self, reader: &dyn LexicalIndexReader) -> Result<bool> {
self.core_query.is_empty(reader)
}
fn cost(&self, reader: &dyn LexicalIndexReader) -> Result<u64> {
let base_cost = self.core_query.cost(reader)?;
let filter_cost = self
.filters
.iter()
.map(|f| f.cost(reader))
.collect::<Result<Vec<_>>>()?
.iter()
.sum::<u64>();
Ok(base_cost + filter_cost)
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn apply_field_boosts(&mut self, boosts: &HashMap<String, f32>) {
if !self.field_boosts.is_empty() {
self.core_query.apply_field_boosts(&self.field_boosts);
}
self.core_query.apply_field_boosts(boosts);
for filter in &mut self.filters {
filter.apply_field_boosts(boosts);
}
for filter in &mut self.negative_filters {
filter.apply_field_boosts(boosts);
}
for filter in &mut self.post_filters {
filter.apply_field_boosts(boosts);
}
}
fn clone_box(&self) -> Box<dyn Query> {
Box::new(self.clone())
}
}
impl Clone for AdvancedQuery {
fn clone(&self) -> Self {
AdvancedQuery {
core_query: self.core_query.clone_box(),
field_boosts: self.field_boosts.clone(),
boost: self.boost,
min_score: self.min_score,
filters: self.filters.iter().map(|f| f.clone_box()).collect(),
negative_filters: self
.negative_filters
.iter()
.map(|f| f.clone_box())
.collect(),
post_filters: self.post_filters.iter().map(|f| f.clone_box()).collect(),
config: self.config.clone(),
}
}
}
/// Fluent builder that accumulates clauses and produces a [`BooleanQuery`].
#[derive(Debug)]
pub struct BooleanQueryBuilder {
/// `(query, occur)` pairs, added to the boolean query in insertion order by `build`.
clauses: Vec<(Box<dyn Query>, Occur)>,
/// Passed to the built query by `build` only when greater than zero.
minimum_should_match: usize,
/// Boost given to the built query.
boost: f32,
/// NOTE(review): stored via `config()` but not read by `build`.
config: AdvancedQueryConfig,
}
impl BooleanQueryBuilder {
    /// Creates an empty builder.
    ///
    /// The builder starts with no clauses, no minimum-should-match, unit boost and
    /// the default config.
    pub fn new() -> Self {
        Self {
            clauses: Vec::new(),
            minimum_should_match: 0,
            boost: 1.0,
            config: AdvancedQueryConfig::default(),
        }
    }

    /// Appends `query` with the given occurrence.
    pub fn add_clause(self, query: Box<dyn Query>, occur: Occur) -> Self {
        let mut builder = self;
        builder.clauses.push((query, occur));
        builder
    }

    /// Sets the minimum number of `Should` clauses that must match (0 = unset).
    pub fn minimum_should_match(self, count: usize) -> Self {
        Self {
            minimum_should_match: count,
            ..self
        }
    }

    /// Sets the boost of the built query.
    pub fn boost(self, boost: f32) -> Self {
        Self { boost, ..self }
    }

    /// Replaces the stored configuration.
    pub fn config(self, config: AdvancedQueryConfig) -> Self {
        Self { config, ..self }
    }

    /// Consumes the builder and produces the boolean query.
    ///
    /// Clauses are added in insertion order. Minimum-should-match is applied only
    /// when it is non-zero.
    pub fn build(self) -> BooleanQuery {
        let Self {
            clauses,
            minimum_should_match,
            boost,
            ..
        } = self;

        let mut query = BooleanQuery::new();
        for (clause, occur) in clauses {
            match occur {
                Occur::Must => query.add_must(clause),
                Occur::Should => query.add_should(clause),
                Occur::MustNot => query.add_must_not(clause),
                Occur::Filter => query.add_filter(clause),
            }
        }

        let query = if minimum_should_match > 0 {
            query.with_minimum_should_match(minimum_should_match)
        } else {
            query
        };
        query.with_boost(boost)
    }
}
impl Default for BooleanQueryBuilder {
fn default() -> Self {
Self::new()
}
}
/// Searches one piece of text across several fields, building one term query per field.
#[derive(Debug, Clone)]
pub struct MultiFieldQuery {
/// Text used as the term in every per-field term query.
query_text: String,
/// Field name -> boost for that field.
fields: HashMap<String, f32>,
/// Strategy used to combine the per-field term queries.
query_type: MultiFieldQueryType,
/// NOTE(review): set by `tie_breaker()` but not read by the query-building code in this file.
tie_breaker: f32,
}
/// Strategy a [`MultiFieldQuery`] uses to combine its per-field term queries.
///
/// NOTE(review): the variant names presumably mirror Elasticsearch's `multi_match`
/// types — confirm the intended semantics against that reference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MultiFieldQueryType {
/// Default strategy chosen by `MultiFieldQuery::new`.
BestFields,
MostFields,
/// Per-field term queries are combined as `Should` clauses of a plain `BooleanQuery`.
CrossFields,
Boolean,
}
impl MultiFieldQuery {
    /// Creates a query for `query_text`.
    ///
    /// The new query has no fields, uses the `BestFields` strategy and has a zero
    /// tie-breaker.
    pub fn new(query_text: String) -> Self {
        Self {
            query_type: MultiFieldQueryType::BestFields,
            tie_breaker: 0.0,
            fields: HashMap::new(),
            query_text,
        }
    }

    /// Adds `field` with the given boost, replacing any boost already set for it.
    pub fn add_field(self, field: String, boost: f32) -> Self {
        let mut query = self;
        query.fields.insert(field, boost);
        query
    }

    /// Selects how the per-field term queries are combined.
    pub fn query_type(self, query_type: MultiFieldQueryType) -> Self {
        Self { query_type, ..self }
    }

    /// Stores the tie-breaker value.
    pub fn tie_breaker(self, tie_breaker: f32) -> Self {
        Self { tie_breaker, ..self }
    }
}
impl Query for MultiFieldQuery {
fn matcher(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Matcher>> {
let mut boolean_builder = BooleanQueryBuilder::new();
match self.query_type {
MultiFieldQueryType::BestFields | MultiFieldQueryType::Boolean => {
for field in self.fields.keys() {
let term_query = crate::lexical::query::term::TermQuery::new(
field.clone(),
self.query_text.clone(),
);
boolean_builder =
boolean_builder.add_clause(Box::new(term_query), Occur::Should);
}
}
MultiFieldQueryType::MostFields => {
for field in self.fields.keys() {
let term_query = crate::lexical::query::term::TermQuery::new(
field.clone(),
self.query_text.clone(),
);
boolean_builder = boolean_builder.add_clause(Box::new(term_query), Occur::Must);
}
}
MultiFieldQueryType::CrossFields => {
let mut combined_query = BooleanQuery::new();
for field in self.fields.keys() {
let term_query = crate::lexical::query::term::TermQuery::new(
field.clone(),
self.query_text.clone(),
);
combined_query.add_should(Box::new(term_query));
}
return combined_query.matcher(reader);
}
}
boolean_builder.build().matcher(reader)
}
fn scorer(&self, reader: &dyn LexicalIndexReader) -> Result<Box<dyn Scorer>> {
let mut boolean_builder = BooleanQueryBuilder::new();
match self.query_type {
MultiFieldQueryType::BestFields | MultiFieldQueryType::Boolean => {
for field in self.fields.keys() {
let term_query = crate::lexical::query::term::TermQuery::new(
field.clone(),
self.query_text.clone(),
);
boolean_builder =
boolean_builder.add_clause(Box::new(term_query), Occur::Should);
}
}
MultiFieldQueryType::MostFields => {
for field in self.fields.keys() {
let term_query = crate::lexical::query::term::TermQuery::new(
field.clone(),
self.query_text.clone(),
);
boolean_builder = boolean_builder.add_clause(Box::new(term_query), Occur::Must);
}
}
MultiFieldQueryType::CrossFields => {
let mut combined_query = BooleanQuery::new();
for field in self.fields.keys() {
let term_query = crate::lexical::query::term::TermQuery::new(
field.clone(),
self.query_text.clone(),
);
combined_query.add_should(Box::new(term_query));
}
return combined_query.scorer(reader);
}
}
boolean_builder.build().scorer(reader)
}
fn boost(&self) -> f32 {
1.0 }
fn set_boost(&mut self, _boost: f32) {
}
fn apply_field_boosts(&mut self, boosts: &HashMap<String, f32>) {
for (f, &b) in boosts {
if let Some(field_boost) = self.fields.get_mut(f) {
*field_boost *= b;
}
}
}
fn description(&self) -> String {
format!(
"MultiFieldQuery(text: {}, fields: {:?})",
self.query_text,
self.fields.keys().collect::<Vec<_>>()
)
}
fn is_empty(&self, _reader: &dyn LexicalIndexReader) -> Result<bool> {
Ok(self.query_text.is_empty() || self.fields.is_empty())
}
fn cost(&self, _reader: &dyn LexicalIndexReader) -> Result<u64> {
Ok(self.fields.len() as u64 * 100)
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn clone_box(&self) -> Box<dyn Query> {
Box::new(self.clone())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lexical::query::term::TermQuery;

    /// Shorthand for a boxed term query.
    fn term(field: &str, text: &str) -> Box<dyn Query> {
        Box::new(TermQuery::new(field.to_string(), text.to_string()))
    }

    #[test]
    fn test_advanced_query_creation() {
        let core_query = Box::new(TermQuery::new("title".to_string(), "test".to_string()));
        let advanced_query = AdvancedQuery::new(core_query)
            .with_boost(2.0)
            .with_min_score(0.5)
            .add_field_boost("title".to_string(), 1.5);
        assert_eq!(advanced_query.boost, 2.0);
        assert_eq!(advanced_query.min_score, 0.5);
        assert_eq!(advanced_query.field_boosts.get("title"), Some(&1.5));
    }

    #[test]
    fn test_advanced_query_collects_filters() {
        let query = AdvancedQuery::new(term("title", "test"))
            .with_filter(term("status", "published"))
            .with_negative_filter(term("status", "draft"))
            .with_post_filter(term("lang", "en"));
        assert_eq!(query.filters.len(), 1);
        assert_eq!(query.negative_filters.len(), 1);
        assert_eq!(query.post_filters.len(), 1);
    }

    #[test]
    fn test_advanced_query_clone_copies_everything() {
        let query = AdvancedQuery::new(term("title", "test"))
            .with_boost(3.0)
            .with_filter(term("status", "published"))
            .with_post_filter(term("lang", "en"));
        let cloned = query.clone();
        assert_eq!(cloned.boost, 3.0);
        assert_eq!(cloned.filters.len(), 1);
        assert_eq!(cloned.negative_filters.len(), 0);
        assert_eq!(cloned.post_filters.len(), 1);
    }

    #[test]
    fn test_boolean_query_builder() {
        let builder = BooleanQueryBuilder::new()
            .minimum_should_match(2)
            .boost(1.5);
        assert_eq!(builder.minimum_should_match, 2);
        assert_eq!(builder.boost, 1.5);
    }

    #[test]
    fn test_boolean_query_builder_default() {
        let builder = BooleanQueryBuilder::default();
        assert!(builder.clauses.is_empty());
        assert_eq!(builder.minimum_should_match, 0);
        assert_eq!(builder.boost, 1.0);
    }

    #[test]
    fn test_multi_field_query() {
        let query = MultiFieldQuery::new("test query".to_string())
            .add_field("title".to_string(), 2.0)
            .add_field("content".to_string(), 1.0)
            .query_type(MultiFieldQueryType::BestFields)
            .tie_breaker(0.3);
        assert_eq!(query.query_text, "test query");
        assert_eq!(query.fields.len(), 2);
        assert_eq!(query.tie_breaker, 0.3);
    }

    #[test]
    fn test_advanced_query_config() {
        let config = AdvancedQueryConfig {
            enable_optimization: false,
            max_clause_count: 500,
            timeout_ms: 10000,
            ..Default::default()
        };
        assert!(!config.enable_optimization);
        assert_eq!(config.max_clause_count, 500);
        assert_eq!(config.timeout_ms, 10000);
    }

    #[test]
    fn test_advanced_query_config_default() {
        let config = AdvancedQueryConfig::default();
        assert!(config.enable_optimization);
        assert_eq!(config.max_clause_count, 1024);
        assert!(config.enable_caching);
        assert_eq!(config.timeout_ms, 30_000);
        assert!(config.enable_early_termination);
        assert_eq!(config.min_score, 0.0);
    }
}