use std::sync::Arc;
use pest::Parser;
use pest_derive::Parser;
use crate::analysis::analyzer::analyzer::Analyzer;
use crate::analysis::analyzer::per_field::PerFieldAnalyzer;
use crate::analysis::analyzer::standard::StandardAnalyzer;
use crate::error::{LaurusError, Result};
use crate::lexical::core::field::NumericType;
use crate::lexical::query::Query;
use crate::lexical::query::boolean::{BooleanClause, BooleanQuery, Occur};
use crate::lexical::query::fuzzy::FuzzyQuery;
use crate::lexical::query::phrase::PhraseQuery;
use crate::lexical::query::range::NumericRangeQuery;
use crate::lexical::query::term::TermQuery;
use crate::lexical::query::wildcard::WildcardQuery;
/// Pest-generated parser for the query-string grammar.
///
/// The `#[grammar]` attribute points at the `.pest` file; the derive also
/// generates the `Rule` enum referenced throughout this module.
#[derive(Parser)]
#[grammar = "lexical/query/parser.pest"]
struct QueryStringParser;
/// Parses Lucene-style query strings into `Query` trees.
///
/// Supports field-scoped terms, boolean operators (`AND`/`OR`), `+`/`-`
/// clause markers, phrases (with `~slop`), fuzzy terms (`~n`), wildcards,
/// ranges, and `^boost` factors — see the grammar in `parser.pest`.
pub struct LexicalQueryParser {
/// Analyzer used to tokenize terms and phrases before query construction.
analyzer: Arc<dyn Analyzer>,
/// Fields searched when a clause does not name a field explicitly.
default_fields: Vec<String>,
/// Occur applied to clauses that carry no explicit operator or marker.
default_occur: Occur,
}
impl std::fmt::Debug for LexicalQueryParser {
    // Fix: the debug struct name previously read "QueryParser", which does not
    // match the actual type name `LexicalQueryParser` (likely a leftover from a
    // rename). The analyzer trait object is not `Debug`, so only its name is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LexicalQueryParser")
            .field("analyzer", &self.analyzer.name())
            .field("default_fields", &self.default_fields)
            .field("default_occur", &self.default_occur)
            .finish()
    }
}
impl LexicalQueryParser {
/// Creates a parser using `analyzer`, with no default fields and
/// `Occur::Should` as the default clause occurrence.
pub fn new(analyzer: Arc<dyn Analyzer>) -> Self {
    Self {
        default_fields: Vec::new(),
        default_occur: Occur::Should,
        analyzer,
    }
}
pub fn with_standard_analyzer() -> Result<Self> {
Ok(LexicalQueryParser::new(Arc::new(StandardAnalyzer::new()?)))
}
/// Builder-style setter: replaces the default fields with the single `field`.
pub fn with_default_field(mut self, field: impl Into<String>) -> Self {
    let name = field.into();
    self.default_fields = vec![name];
    self
}
/// Builder-style setter: replaces the full list of default fields.
pub fn with_default_fields(mut self, fields: Vec<String>) -> Self {
    self.default_fields = fields;
    self
}
/// Builder-style setter: changes the occurrence used for unmarked clauses.
pub fn with_default_occur(mut self, occur: Occur) -> Self {
    self.default_occur = occur;
    self
}
/// Returns the fields searched when a clause names no field.
pub fn default_fields(&self) -> &[String] {
    self.default_fields.as_slice()
}
/// Runs `creator` against the explicit `field` if given, otherwise against
/// the configured default fields.
///
/// With several default fields the per-field queries are OR-ed together
/// (each added as `Occur::Should`) in a [`BooleanQuery`].
///
/// # Errors
/// Fails when no field is given and no defaults are configured, or when
/// `creator` itself fails.
fn create_query_over_fields<F>(&self, field: Option<&str>, creator: F) -> Result<Box<dyn Query>>
where
    F: Fn(&str) -> Result<Box<dyn Query>>,
{
    if let Some(name) = field {
        return creator(name);
    }
    match self.default_fields.as_slice() {
        [] => Err(LaurusError::parse("No field specified".to_string())),
        [only] => creator(only),
        many => {
            let mut combined = BooleanQuery::new();
            for name in many {
                combined.add_clause(BooleanClause::new(creator(name)?, Occur::Should));
            }
            Ok(Box::new(combined))
        }
    }
}
/// Parses `query_str` scoped to a single `field`.
///
/// A multi-word value that is not already quoted is wrapped in quotes
/// (with inner quotes escaped) so it parses as one phrase rather than
/// separate clauses.
pub fn parse_field(&self, field: &str, query_str: &str) -> Result<Box<dyn Query>> {
    let needs_quoting = query_str.contains(' ') && !query_str.starts_with('"');
    let rewritten = if needs_quoting {
        format!("{field}:\"{}\"", query_str.replace('"', "\\\""))
    } else {
        format!("{field}:{query_str}")
    };
    self.parse(&rewritten)
}
/// Parses a full query string into a query tree.
///
/// Runs the pest grammar, then descends into the first `boolean_query`
/// node found under a `query` node.
///
/// # Errors
/// Returns a parse error if the grammar rejects the input or no
/// `boolean_query` node is produced.
pub fn parse(&self, query_str: &str) -> Result<Box<dyn Query>> {
    let parsed = QueryStringParser::parse(Rule::query, query_str)
        .map_err(|e| LaurusError::parse(format!("Parse error: {e}")))?;
    for top in parsed.filter(|p| p.as_rule() == Rule::query) {
        if let Some(boolean) = top.into_inner().find(|p| p.as_rule() == Rule::boolean_query) {
            return self.parse_boolean_query(boolean);
        }
    }
    Err(LaurusError::parse("No valid query found".to_string()))
}
/// Builds a query from a `boolean_query` node: a sequence of clauses
/// optionally separated by `AND`/`OR`.
///
/// An operator sets the occurrence of the clause that FOLLOWS it
/// (`AND` -> `Must`, `OR` -> `Should`); after each clause the occurrence
/// resets to the parser default.
fn parse_boolean_query(&self, pair: pest::iterators::Pair<Rule>) -> Result<Box<dyn Query>> {
    let mut current_occur = self.default_occur;
    let mut terms: Vec<(Occur, Box<dyn Query>)> = Vec::new();
    for inner_pair in pair.into_inner() {
        match inner_pair.as_rule() {
            Rule::boolean_op => {
                let op = inner_pair.as_str();
                current_occur = match op.to_uppercase().as_str() {
                    "AND" => Occur::Must,
                    "OR" => Occur::Should,
                    // Unknown operators fall back to the OR semantics.
                    _ => Occur::Should,
                };
            }
            Rule::clause => {
                let (occur, query) = self.parse_clause(inner_pair, current_occur)?;
                terms.push((occur, query));
                current_occur = self.default_occur;
            }
            _ => {}
        }
    }
    // Single-clause shortcut: return the inner query directly instead of
    // wrapping it in a one-clause BooleanQuery. Fix: this shortcut previously
    // applied to MustNot as well, so a lone prohibited clause like "-world"
    // silently lost its negation and matched documents containing "world".
    // A lone MustNot clause now keeps its BooleanQuery wrapper.
    if terms.len() == 1 && !matches!(terms[0].0, Occur::MustNot) {
        return Ok(terms.into_iter().next().unwrap().1);
    }
    let mut bool_query = BooleanQuery::new();
    for (occur, query) in terms {
        bool_query.add_clause(BooleanClause::new(query, occur));
    }
    Ok(Box::new(bool_query))
}
/// Resolves one clause to `(occur, query)`.
///
/// `+term` forces `Must`, `-term` forces `MustNot`; a bare sub-clause uses
/// `default_occur` (already adjusted for any preceding AND/OR by the caller).
fn parse_clause(
    &self,
    pair: pest::iterators::Pair<Rule>,
    default_occur: Occur,
) -> Result<(Occur, Box<dyn Query>)> {
    for node in pair.into_inner() {
        let (occur, sub_clause) = match node.as_rule() {
            Rule::required_clause => (
                Occur::Must,
                node.into_inner().find(|p| p.as_rule() == Rule::sub_clause),
            ),
            Rule::prohibited_clause => (
                Occur::MustNot,
                node.into_inner().find(|p| p.as_rule() == Rule::sub_clause),
            ),
            Rule::sub_clause => (default_occur, Some(node)),
            _ => continue,
        };
        if let Some(sub) = sub_clause {
            return Ok((occur, self.parse_sub_clause(sub)?));
        }
    }
    Err(LaurusError::parse("Invalid clause".to_string()))
}
/// Dispatches a sub-clause to the grouped / field-scoped / bare-term parser.
fn parse_sub_clause(&self, pair: pest::iterators::Pair<Rule>) -> Result<Box<dyn Query>> {
    for node in pair.into_inner() {
        return match node.as_rule() {
            Rule::grouped_query => self.parse_grouped_query(node),
            Rule::field_query => self.parse_field_query(node),
            Rule::term_query => self.parse_term_query(node),
            _ => continue,
        };
    }
    Err(LaurusError::parse("Invalid sub-clause".to_string()))
}
/// Parses a parenthesized group `( ... )` with an optional trailing `^boost`.
fn parse_grouped_query(&self, pair: pest::iterators::Pair<Rule>) -> Result<Box<dyn Query>> {
    let mut parsed: Option<Box<dyn Query>> = None;
    let mut boost_factor = 1.0;
    for node in pair.into_inner() {
        match node.as_rule() {
            Rule::boolean_query => parsed = Some(self.parse_boolean_query(node)?),
            Rule::boost => boost_factor = self.parse_boost(node)?,
            _ => {}
        }
    }
    let mut query =
        parsed.ok_or_else(|| LaurusError::parse("Invalid grouped query".to_string()))?;
    // Only touch the boost when it differs from the neutral 1.0 default.
    if boost_factor != 1.0 {
        query.set_boost(boost_factor);
    }
    Ok(query)
}
/// Parses `field:value`, forwarding the value with the captured field name.
fn parse_field_query(&self, pair: pest::iterators::Pair<Rule>) -> Result<Box<dyn Query>> {
    let mut field_name: Option<String> = None;
    for node in pair.into_inner() {
        match node.as_rule() {
            Rule::field => field_name = Some(node.as_str().to_string()),
            Rule::field_value => {
                // The grammar should always yield the field before its value.
                return match field_name.take() {
                    Some(name) => self.parse_field_value(node, Some(&name)),
                    None => Err(LaurusError::parse("Missing field name".to_string())),
                };
            }
            _ => {}
        }
    }
    Err(LaurusError::parse("Invalid field query".to_string()))
}
/// Parses a bare term (no explicit field); the default fields apply downstream.
fn parse_term_query(&self, pair: pest::iterators::Pair<Rule>) -> Result<Box<dyn Query>> {
    match pair.into_inner().find(|p| p.as_rule() == Rule::field_value) {
        Some(value) => self.parse_field_value(value, None),
        None => Err(LaurusError::parse("Invalid term query".to_string())),
    }
}
/// Dispatches a field value to the concrete query parser for its shape
/// (range, phrase, fuzzy, wildcard, or plain term).
fn parse_field_value(
    &self,
    pair: pest::iterators::Pair<Rule>,
    field: Option<&str>,
) -> Result<Box<dyn Query>> {
    for node in pair.into_inner() {
        return match node.as_rule() {
            Rule::range_query => self.parse_range_query(node, field),
            Rule::phrase_query => self.parse_phrase_query(node, field),
            Rule::fuzzy_term => self.parse_fuzzy_term(node, field),
            Rule::wildcard_term => self.parse_wildcard_term(node, field),
            Rule::simple_term => self.parse_simple_term(node, field),
            _ => continue,
        };
    }
    Err(LaurusError::parse("Invalid field value".to_string()))
}
/// Parses `[a TO b]` (inclusive) or `{a TO b}` (exclusive) range syntax.
///
/// Bounds that parse as `f64` yield a [`NumericRangeQuery`]; otherwise the
/// bracketed range text is kept verbatim as a [`TermQuery`] fallback.
/// Fix: the inclusive and exclusive match arms were copy-paste duplicates
/// differing only in the two inclusivity flags — consolidated into one path.
fn parse_range_query(
    &self,
    pair: pest::iterators::Pair<Rule>,
    field: Option<&str>,
) -> Result<Box<dyn Query>> {
    let mut lower_inclusive = true;
    let mut upper_inclusive = true;
    let mut lower: Option<String> = None;
    let mut upper: Option<String> = None;
    for inner_pair in pair.into_inner() {
        let inclusive = match inner_pair.as_rule() {
            Rule::range_inclusive => true,
            Rule::range_exclusive => false,
            _ => continue,
        };
        lower_inclusive = inclusive;
        upper_inclusive = inclusive;
        // The first range_value seen is the lower bound, the second the upper.
        for range_part in inner_pair.into_inner() {
            if range_part.as_rule() == Rule::range_value {
                let value = self.parse_range_value(range_part)?;
                if lower.is_none() {
                    lower = Some(value);
                } else {
                    upper = Some(value);
                }
            }
        }
    }
    // "*" (open bound) and non-numeric text fail the parse and stay None.
    let lower_num = lower.as_ref().and_then(|s| s.parse::<f64>().ok());
    let upper_num = upper.as_ref().and_then(|s| s.parse::<f64>().ok());
    self.create_query_over_fields(field, |field_name| {
        if lower_num.is_some() || upper_num.is_some() {
            let query = NumericRangeQuery::new(
                field_name,
                NumericType::Float,
                lower_num,
                upper_num,
                lower_inclusive,
                upper_inclusive,
            );
            Ok(Box::new(query))
        } else {
            // Non-numeric range: preserve the original bracketed form as a term.
            let term = format!(
                "{}{} TO {}{}",
                if lower_inclusive { "[" } else { "{" },
                lower.as_deref().unwrap_or("*"),
                upper.as_deref().unwrap_or("*"),
                if upper_inclusive { "]" } else { "}" }
            );
            Ok(Box::new(TermQuery::new(field_name, &term)))
        }
    })
}
/// Extracts one range bound: `*` means an open bound; quoted values are
/// returned with their surrounding quotes stripped.
fn parse_range_value(&self, pair: pest::iterators::Pair<Rule>) -> Result<String> {
    let raw = pair.as_str();
    if raw == "*" {
        Ok(raw.to_string())
    } else {
        Ok(raw.trim_matches('"').to_string())
    }
}
/// Parses a quoted phrase, with optional `~slop` proximity and `^boost`.
/// The phrase text is run through the analyzer before building the query.
fn parse_phrase_query(
    &self,
    pair: pest::iterators::Pair<Rule>,
    field: Option<&str>,
) -> Result<Box<dyn Query>> {
    let mut content = String::new();
    let mut slop_value: Option<u32> = None;
    let mut boost_factor = 1.0;
    for node in pair.into_inner() {
        match node.as_rule() {
            Rule::phrase_content => content = node.as_str().to_string(),
            Rule::proximity => {
                // Unparseable proximity numbers degrade to a slop of 0.
                if let Some(num) =
                    node.into_inner().filter(|p| p.as_rule() == Rule::number).last()
                {
                    slop_value = Some(num.as_str().parse().unwrap_or(0));
                }
            }
            Rule::boost => boost_factor = self.parse_boost(node)?,
            _ => {}
        }
    }
    self.create_query_over_fields(field, |field_name| {
        let analyzed = self.analyze_term(Some(field_name), &content)?;
        let mut query = PhraseQuery::new(field_name, analyzed);
        if let Some(slop) = slop_value {
            query = query.with_slop(slop);
        }
        if boost_factor != 1.0 {
            query = query.with_boost(boost_factor);
        }
        Ok(Box::new(query))
    })
}
/// Parses `term~n` into a fuzzy query; `n` defaults to 2 edits when absent
/// or unparseable. The term is normalized through the analyzer first; if
/// analysis yields nothing, the raw term is used as-is.
fn parse_fuzzy_term(
    &self,
    pair: pest::iterators::Pair<Rule>,
    field: Option<&str>,
) -> Result<Box<dyn Query>> {
    let mut raw_term = String::new();
    let mut max_edits: u8 = 2;
    for node in pair.into_inner() {
        match node.as_rule() {
            Rule::term => raw_term = node.as_str().to_string(),
            Rule::fuzziness => {
                if let Some(num) =
                    node.into_inner().filter(|p| p.as_rule() == Rule::number).last()
                {
                    max_edits = num.as_str().parse().unwrap_or(2);
                }
            }
            _ => {}
        }
    }
    self.create_query_over_fields(field, |field_name| {
        let analyzed = self.analyze_term(Some(field_name), &raw_term)?;
        let normalized = analyzed.first().map(String::as_str).unwrap_or(&raw_term);
        Ok(Box::new(
            FuzzyQuery::new(field_name, normalized).max_edits(u32::from(max_edits)),
        ))
    })
}
/// Parses a wildcard pattern (e.g. `hel*`) into a [`WildcardQuery`].
fn parse_wildcard_term(
    &self,
    pair: pest::iterators::Pair<Rule>,
    field: Option<&str>,
) -> Result<Box<dyn Query>> {
    // Keep the last pattern node, mirroring the overwrite-on-repeat behavior.
    let pattern = pair
        .into_inner()
        .filter(|p| p.as_rule() == Rule::wildcard_pattern)
        .last()
        .map(|p| p.as_str().to_string())
        .unwrap_or_default();
    self.create_query_over_fields(field, |field_name| {
        Ok(Box::new(WildcardQuery::new(field_name, &pattern)?))
    })
}
/// Parses a plain term with an optional `^boost`.
///
/// The term is analyzed first: one token produces a [`TermQuery`], several
/// tokens a [`PhraseQuery`], and zero tokens is an error.
fn parse_simple_term(
    &self,
    pair: pest::iterators::Pair<Rule>,
    field: Option<&str>,
) -> Result<Box<dyn Query>> {
    let mut raw_term = String::new();
    let mut boost_factor = 1.0;
    for node in pair.into_inner() {
        match node.as_rule() {
            Rule::term => raw_term = node.as_str().to_string(),
            Rule::boost => boost_factor = self.parse_boost(node)?,
            _ => {}
        }
    }
    self.create_query_over_fields(field, |field_name| {
        let tokens = self.analyze_term(Some(field_name), &raw_term)?;
        if tokens.is_empty() {
            return Err(LaurusError::parse("No terms after analysis".to_string()));
        }
        let boosted = boost_factor != 1.0;
        if tokens.len() == 1 {
            let query = TermQuery::new(field_name, &tokens[0]);
            if boosted {
                Ok(Box::new(query.with_boost(boost_factor)))
            } else {
                Ok(Box::new(query))
            }
        } else {
            let query = PhraseQuery::new(field_name, tokens);
            if boosted {
                Ok(Box::new(query.with_boost(boost_factor)))
            } else {
                Ok(Box::new(query))
            }
        }
    })
}
/// Extracts a `^boost` factor; missing or unparseable values yield the
/// neutral boost 1.0.
fn parse_boost(&self, pair: pest::iterators::Pair<Rule>) -> Result<f32> {
    let factor = pair
        .into_inner()
        .find(|p| p.as_rule() == Rule::boost_value)
        .map(|p| p.as_str().parse::<f32>().unwrap_or(1.0))
        .unwrap_or(1.0);
    Ok(factor)
}
/// Tokenizes `term` into its analyzed token texts.
///
/// When the parser's analyzer is a [`PerFieldAnalyzer`] and a field name is
/// supplied, the field-specific analyzer is used; otherwise the general
/// analyzer applies.
fn analyze_term(&self, field: Option<&str>, term: &str) -> Result<Vec<String>> {
    let per_field = self.analyzer.as_any().downcast_ref::<PerFieldAnalyzer>();
    let stream = match (field, per_field) {
        (Some(field_name), Some(analyzer)) => analyzer.analyze_field(field_name, term)?,
        _ => self.analyzer.analyze(term)?,
    };
    Ok(stream.into_iter().map(|token| token.text).collect())
}
}
/// Builder for [`LexicalQueryParser`]; mirrors its fields and defaults.
pub struct QueryParserBuilder {
/// Analyzer the built parser will use.
analyzer: Arc<dyn Analyzer>,
/// Fields searched when a clause names no field (empty by default).
default_fields: Vec<String>,
/// Occurrence for unmarked clauses (defaults to `Occur::Should`).
default_occur: Occur,
}
impl QueryParserBuilder {
    /// Starts a builder with `analyzer`, no default fields, and
    /// `Occur::Should` as the default occurrence.
    pub fn new(analyzer: Arc<dyn Analyzer>) -> Self {
        Self {
            default_fields: Vec::new(),
            default_occur: Occur::Should,
            analyzer,
        }
    }
    /// Replaces the default fields with the single `field`.
    pub fn default_field(mut self, field: impl Into<String>) -> Self {
        let name = field.into();
        self.default_fields = vec![name];
        self
    }
    /// Replaces the full list of default fields.
    pub fn default_fields(mut self, fields: Vec<String>) -> Self {
        self.default_fields = fields;
        self
    }
    /// Sets the occurrence used for unmarked clauses.
    pub fn default_occur(mut self, occur: Occur) -> Self {
        self.default_occur = occur;
        self
    }
    /// Finalizes the builder into a [`LexicalQueryParser`].
    pub fn build(self) -> Result<LexicalQueryParser> {
        let Self {
            analyzer,
            default_fields,
            default_occur,
        } = self;
        Ok(LexicalQueryParser {
            analyzer,
            default_fields,
            default_occur,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::analysis::analyzer::standard::StandardAnalyzer;

    /// Builds a parser backed by a `StandardAnalyzer`.
    fn create_test_parser() -> LexicalQueryParser {
        LexicalQueryParser::new(Arc::new(StandardAnalyzer::new().unwrap()))
    }

    /// Parses `query_str` with "content" as the default field and returns
    /// the debug rendering of the resulting query.
    fn parsed_debug(query_str: &str) -> String {
        let parser = create_test_parser().with_default_field("content");
        format!("{:?}", parser.parse(query_str).unwrap())
    }

    #[test]
    fn test_simple_term() {
        assert!(parsed_debug("hello").contains("TermQuery"));
    }

    #[test]
    fn test_field_query() {
        assert!(parsed_debug("title:hello").contains("TermQuery"));
    }

    #[test]
    fn test_boolean_query() {
        assert!(parsed_debug("hello AND world").contains("BooleanQuery"));
    }

    #[test]
    fn test_phrase_query() {
        assert!(parsed_debug("\"hello world\"").contains("PhraseQuery"));
    }

    #[test]
    fn test_fuzzy_query() {
        assert!(parsed_debug("hello~2").contains("FuzzyQuery"));
    }

    #[test]
    fn test_wildcard_query() {
        assert!(parsed_debug("hel*").contains("WildcardQuery"));
    }

    #[test]
    fn test_required_clause() {
        assert!(parsed_debug("+hello world").contains("BooleanQuery"));
    }

    #[test]
    fn test_prohibited_clause() {
        assert!(parsed_debug("hello -world").contains("BooleanQuery"));
    }

    #[test]
    fn test_grouped_query() {
        assert!(parsed_debug("(hello OR world) AND test").contains("BooleanQuery"));
    }

    #[test]
    fn test_proximity_search() {
        assert!(parsed_debug("\"hello world\"~10").contains("PhraseQuery"));
    }

    #[test]
    fn test_multiple_default_fields() {
        let parser =
            create_test_parser().with_default_fields(vec!["title".to_string(), "body".to_string()]);
        let rendered = format!("{:?}", parser.parse("hello").unwrap());
        assert!(rendered.contains("BooleanQuery"));
    }
}