use std::ops::Bound;
use std::ops::Bound::Unbounded;
use std::str::FromStr;
#[cfg(feature = "metrics")]
use opentelemetry::metrics::Counter;
#[cfg(feature = "metrics")]
use opentelemetry::{global, KeyValue};
use summa_proto::proto;
use tantivy::query::{
AllQuery, BooleanQuery, BoostQuery, DisjunctionMaxQuery, EmptyQuery, MoreLikeThisQuery, Occur, PhraseQuery, Query, RangeQuery, RegexQuery, TermQuery,
};
use tantivy::schema::{Field, FieldEntry, FieldType, IndexRecordOption, Schema};
use tantivy::{Document, Index, Score, TantivyDocument, Term};
use tracing::info;
use crate::components::queries::ExistsQuery;
use crate::components::query_parser::morphology::MorphologyManager;
use crate::components::query_parser::utils::cast_field_to_typed_term;
use crate::components::query_parser::{QueryParser, QueryParserError};
use crate::configs::core::QueryParserConfig;
use crate::errors::{Error, SummaResult, ValidationError};
#[cfg(feature = "metrics")]
use crate::metrics::ToLabel;
/// Parser that converts `proto::query::Query` messages into tantivy `Query` trait objects.
///
/// An instance is created with `for_index` and keeps everything needed to resolve field
/// names, tokenize phrase values and build nested textual query parsers.
#[derive(Clone)]
pub struct ProtoQueryParser {
/// Handle of the index being queried (cloned from the index given to `for_index`); used to
/// obtain per-field tokenizers and to construct nested `QueryParser`s for match queries.
index: Index,
/// Schema obtained from the index when the parser was constructed; used for field lookups
/// and for parsing the JSON document of more-like-this queries.
cached_schema: Schema,
/// Counter incremented once per `parse_query` call, labelled with the query kind.
#[cfg(feature = "metrics")]
query_counter: Counter<u64>,
/// Counter incremented once per `parse_subquery` call (i.e. for every node of the query
/// tree), labelled with the query kind.
#[cfg(feature = "metrics")]
subquery_counter: Counter<u64>,
/// Base parser configuration: supplies the field aliases and is the starting point that
/// per-request match-query configuration is merged into.
query_parser_config: QueryParserConfig,
/// Morphology manager handed to nested `QueryParser`s; default-constructed in `for_index`.
morphology_manager: MorphologyManager,
}
pub enum QueryParserDefaultMode {
Boolean,
DisjuctionMax { tie_breaker: Score },
}
impl From<Option<proto::query_parser_config::DefaultMode>> for QueryParserDefaultMode {
fn from(value: Option<proto::query_parser_config::DefaultMode>) -> Self {
match value {
Some(proto::query_parser_config::DefaultMode::BooleanShouldMode(_)) | None => QueryParserDefaultMode::Boolean,
Some(proto::query_parser_config::DefaultMode::DisjuctionMaxMode(proto::MatchQueryDisjuctionMaxMode { tie_breaker })) => {
QueryParserDefaultMode::DisjuctionMax { tie_breaker }
}
}
}
}
/// Converts one textual end of a range into a `Bound<Term>`.
///
/// * `"*"` produces an open end (`Unbounded`).
/// * Any other text is cast to a typed term with `cast_field_to_typed_term`; the term is
///   wrapped in `Bound::Included` when `including` is `true`, otherwise in `Bound::Excluded`.
///
/// # Errors
///
/// Propagates the error returned by `cast_field_to_typed_term`.
fn cast_value_to_bound_term(field: &Field, full_path: &str, field_type: &FieldType, value: &str, including: bool) -> SummaResult<Bound<Term>> {
    if value == "*" {
        return Ok(Unbounded);
    }
    let term = cast_field_to_typed_term(field, full_path, field_type, value)?;
    Ok(if including { Bound::Included(term) } else { Bound::Excluded(term) })
}
impl ProtoQueryParser {
/// Creates a parser bound to `index`, using `query_parser_config` as the base configuration.
///
/// The index handle is cloned and its schema is read once and stored. With the `metrics`
/// feature enabled, the `query_counter` and `subquery_counter` instruments are created on
/// the global `summa` meter.
pub fn for_index(index: &Index, query_parser_config: proto::QueryParserConfig) -> SummaResult<ProtoQueryParser> {
    #[cfg(feature = "metrics")]
    let query_counter = {
        let builder = global::meter("summa").u64_counter("query_counter");
        builder.with_description("Queries counter").init()
    };
    #[cfg(feature = "metrics")]
    let subquery_counter = {
        let builder = global::meter("summa").u64_counter("subquery_counter");
        builder.with_description("Sub-queries counter").init()
    };

    let parser = ProtoQueryParser {
        index: index.clone(),
        cached_schema: index.schema(),
        #[cfg(feature = "metrics")]
        query_counter,
        #[cfg(feature = "metrics")]
        subquery_counter,
        query_parser_config: QueryParserConfig(query_parser_config),
        morphology_manager: MorphologyManager::default(),
    };
    Ok(parser)
}
/// Translates `field_name` through the configured field aliases.
///
/// Returns the alias target when `field_name` is a key of `field_aliases`; otherwise returns
/// `field_name` unchanged.
pub fn resolve_field_name<'a>(&'a self, field_name: &'a str) -> &str {
    let aliases = &self.query_parser_config.0.field_aliases;
    match aliases.get(field_name) {
        Some(target) => target.as_str(),
        None => field_name,
    }
}
/// Looks up `field_name` (after alias resolution) in the cached schema.
///
/// Returns the field handle, the path reported by `Schema::find_field`, and the schema entry
/// of the field.
///
/// # Errors
///
/// Returns `ValidationError::MissingField` carrying the *requested* (unresolved) name when
/// the schema has no matching field.
#[inline]
pub(crate) fn field_and_field_entry<'a>(&'a self, field_name: &'a str) -> SummaResult<(Field, &str, &FieldEntry)> {
    let resolved_name = self.resolve_field_name(field_name);
    let (field, full_path) = self
        .cached_schema
        .find_field(resolved_name)
        .ok_or_else(|| ValidationError::MissingField(field_name.to_string()))?;
    Ok((field, full_path, self.cached_schema.get_field_entry(field)))
}
fn parse_subquery(&self, query: proto::query::Query) -> SummaResult<Box<dyn Query>> {
#[cfg(feature = "metrics")]
self.subquery_counter.add(1, &[KeyValue::new("query", query.to_label())]);
Ok(match query {
proto::query::Query::All(_) => Box::new(AllQuery),
proto::query::Query::Empty(_) => Box::new(EmptyQuery),
proto::query::Query::Boolean(boolean_query_proto) => {
let mut subqueries = vec![];
for subquery in boolean_query_proto.subqueries {
subqueries.push((
match subquery.occur() {
proto::Occur::Should => Occur::Should,
proto::Occur::Must => Occur::Must,
proto::Occur::MustNot => Occur::MustNot,
},
self.parse_subquery(subquery.query.and_then(|query| query.query).ok_or(Error::EmptyQuery)?)?,
))
}
Box::new(BooleanQuery::new(subqueries))
}
proto::query::Query::DisjunctionMax(disjunction_max_proto) => Box::new(DisjunctionMaxQuery::with_tie_breaker(
disjunction_max_proto
.disjuncts
.into_iter()
.map(|disjunct| self.parse_subquery(disjunct.query.ok_or(Error::EmptyQuery)?))
.collect::<SummaResult<Vec<_>>>()?,
match disjunction_max_proto.tie_breaker.as_str() {
"" => 0.0,
s => f32::from_str(s).map_err(|_e| Error::InvalidSyntax(format!("cannot parse {} as f32", disjunction_max_proto.tie_breaker)))?,
},
)),
proto::query::Query::Match(match_query_proto) => {
let mut new_query_parser_config = self.query_parser_config.clone();
if let Some(query_parser_config) = match_query_proto.query_parser_config {
new_query_parser_config.merge(QueryParserConfig(query_parser_config));
}
let nested_query_parser = QueryParser::for_index(&self.index, new_query_parser_config.clone(), &self.morphology_manager)?;
match nested_query_parser.parse_query(&match_query_proto.value) {
Ok(parsed_query) => {
info!(query = ?match_query_proto.value, parsed_match_query = ?parsed_query, query_parser_config = ?new_query_parser_config);
Ok(parsed_query)
}
Err(QueryParserError::FieldDoesNotExist(field)) => Err(ValidationError::MissingField(field).into()),
Err(e) => Err(Error::InvalidQuerySyntax(Box::new(e), match_query_proto.value.to_owned())),
}?
}
proto::query::Query::Range(range_query_proto) => {
let (field, full_path, field_entry) = self.field_and_field_entry(&range_query_proto.field)?;
let value = range_query_proto.value.as_ref().ok_or(ValidationError::MissingRange)?;
let left = cast_value_to_bound_term(&field, full_path, field_entry.field_type(), &value.left, value.including_left)?;
let right = cast_value_to_bound_term(&field, full_path, field_entry.field_type(), &value.right, value.including_right)?;
Box::new(RangeQuery::new_term_bounds(
range_query_proto.field.clone(),
field_entry.field_type().value_type(),
&left,
&right,
))
}
proto::query::Query::Boost(boost_query_proto) => Box::new(BoostQuery::new(
self.parse_subquery(boost_query_proto.query.and_then(|query| query.query).ok_or(Error::EmptyQuery)?)?,
f32::from_str(&boost_query_proto.score).map_err(|_e| Error::InvalidSyntax(format!("cannot parse {} as f32", boost_query_proto.score)))?,
)),
proto::query::Query::Regex(regex_query_proto) => {
let (field, _, _) = self.field_and_field_entry(®ex_query_proto.field)?;
Box::new(RegexQuery::from_pattern(®ex_query_proto.value, field)?)
}
proto::query::Query::Phrase(phrase_query_proto) => {
let (field, full_path, field_entry) = self.field_and_field_entry(&phrase_query_proto.field)?;
let mut tokenizer = self.index.tokenizer_for_field(field)?;
let mut token_stream = tokenizer.token_stream(&phrase_query_proto.value);
let mut terms: Vec<(usize, Term)> = vec![];
while let Some(token) = token_stream.next() {
terms.push((
token.position,
cast_field_to_typed_term(&field, full_path, field_entry.field_type(), &token.text)?,
))
}
if terms.is_empty() {
Box::new(EmptyQuery)
} else if terms.len() == 1 {
Box::new(TermQuery::new(
terms[0].1.clone(),
field_entry.field_type().index_record_option().unwrap_or(IndexRecordOption::Basic),
))
} else {
Box::new(PhraseQuery::new_with_offset_and_slop(terms, phrase_query_proto.slop))
}
}
proto::query::Query::Term(term_query_proto) => {
let (field, full_path, field_entry) = self.field_and_field_entry(&term_query_proto.field)?;
let value = term_query_proto.value.to_lowercase();
Box::new(TermQuery::new(
cast_field_to_typed_term(&field, full_path, field_entry.field_type(), &value)?,
field_entry.field_type().index_record_option().unwrap_or(IndexRecordOption::Basic),
))
}
proto::query::Query::MoreLikeThis(more_like_this_query_proto) => {
let document = TantivyDocument::parse_json(&self.cached_schema, &more_like_this_query_proto.document)
.map_err(|_e| Error::InvalidSyntax("bad document".to_owned()))?;
let field_values = document
.get_sorted_field_values()
.into_iter()
.map(|(field, field_values)| (field, field_values.into_iter().cloned().collect()))
.collect();
let mut query_builder = MoreLikeThisQuery::builder();
if let Some(min_doc_frequency) = more_like_this_query_proto.min_doc_frequency {
query_builder = query_builder.with_min_doc_frequency(min_doc_frequency);
}
if let Some(max_doc_frequency) = more_like_this_query_proto.max_doc_frequency {
query_builder = query_builder.with_max_doc_frequency(max_doc_frequency);
}
if let Some(min_term_frequency) = more_like_this_query_proto.min_term_frequency {
query_builder = query_builder.with_min_term_frequency(min_term_frequency as usize);
}
if let Some(max_query_terms) = more_like_this_query_proto.max_query_terms {
query_builder = query_builder.with_max_query_terms(max_query_terms as usize);
}
if let Some(min_word_length) = more_like_this_query_proto.min_word_length {
query_builder = query_builder.with_min_word_length(min_word_length as usize);
}
if let Some(max_word_length) = more_like_this_query_proto.max_word_length {
query_builder = query_builder.with_max_word_length(max_word_length as usize);
}
if let Some(ref boost) = more_like_this_query_proto.boost {
query_builder =
query_builder.with_boost_factor(f32::from_str(boost).map_err(|_e| Error::InvalidSyntax(format!("cannot parse {boost} as f32")))?);
}
query_builder = query_builder.with_stop_words(more_like_this_query_proto.stop_words);
Box::new(query_builder.with_document_fields(field_values))
}
proto::query::Query::Exists(exists_query_proto) => {
let (field, full_path, field_entry) = self.field_and_field_entry(&exists_query_proto.field)?;
if !field_entry.field_type().is_indexed() {
let fni = QueryParserError::FieldNotIndexed(field_entry.name().to_string());
return Err(Error::InvalidQuerySyntax(Box::new(fni), exists_query_proto.field.to_string()));
}
Box::new(ExistsQuery::new(field, full_path))
}
})
}
/// Public entry point: converts a top-level proto query into a tantivy query.
///
/// With the `metrics` feature enabled the top-level query counter is incremented (labelled
/// with the query kind) before the work is delegated to `parse_subquery`.
pub fn parse_query(&self, query: proto::query::Query) -> SummaResult<Box<dyn Query>> {
    #[cfg(feature = "metrics")]
    {
        let labels = [KeyValue::new("query", query.to_label())];
        self.query_counter.add(1, &labels);
    }
    self.parse_subquery(query)
}
}