use std::collections::HashSet;
use regex::Regex;
use crate::engine::search::{FusionAlgorithm, HybridMode, SearchQuery, SearchRequest};
use crate::error::{LaurusError, Result};
use crate::lexical::query::parser::LexicalQueryParser;
use crate::lexical::search::searcher::LexicalSearchQuery;
use crate::vector::query::parser::VectorQueryParser;
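/// Splits a single query string into lexical and vector clauses and parses
/// each side with the matching sub-parser.
///
/// Clauses whose field name is registered as a vector field (for example
/// `content:"cats"`) are routed to the [`VectorQueryParser`]; everything
/// else goes to the [`LexicalQueryParser`]. When both kinds of clause are
/// present, the result is a hybrid query fused with the configured fusion
/// algorithm.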
pub struct UnifiedQueryParser {
lexical_parser: LexicalQueryParser,
vector_parser: VectorQueryParser,
vector_fields: HashSet<String>,
default_fusion: FusionAlgorithm,
}
impl UnifiedQueryParser {
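    /// Creates a parser that fuses hybrid results with reciprocal rank
    /// fusion (`RRF { k: 60.0 }`) by default.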
pub fn new(
lexical_parser: LexicalQueryParser,
vector_parser: VectorQueryParser,
vector_fields: HashSet<String>,
) -> Self {
Self {
lexical_parser,
vector_parser,
vector_fields,
default_fusion: FusionAlgorithm::RRF { k: 60.0 },
}
}
pub fn with_fusion(mut self, fusion: FusionAlgorithm) -> Self {
self.default_fusion = fusion;
self
}
pub async fn parse(&self, query_str: &str) -> Result<SearchRequest> {
let query_str = query_str.trim();
if query_str.is_empty() {
return Err(LaurusError::invalid_argument(
"Query string must not be empty",
));
}
let (lexical_str, vector_str, hybrid_mode) = self.split_query(query_str)?;
let lexical = if let Some(ref s) = lexical_str {
Some(self.lexical_parser.parse(s)?)
} else {
None
};
let vector = if let Some(ref s) = vector_str {
Some(self.vector_parser.parse(s).await?)
} else {
None
};
if lexical.is_none() && vector.is_none() {
return Err(LaurusError::invalid_argument(
"Query must contain at least one lexical or vector clause",
));
}
let fusion = if lexical.is_some() && vector.is_some() {
Some(self.default_fusion)
} else {
None
};
let query = match (lexical, vector) {
(Some(lex_query), Some(vec_req)) => SearchQuery::Hybrid {
lexical: LexicalSearchQuery::Obj(lex_query),
vector: vec_req.query,
mode: hybrid_mode,
},
(Some(lex_query), None) => SearchQuery::Lexical(LexicalSearchQuery::Obj(lex_query)),
(None, Some(vec_req)) => SearchQuery::Vector(vec_req.query),
(None, None) => unreachable!(),
};
Ok(SearchRequest {
query,
fusion_algorithm: fusion,
..Default::default()
})
}
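    /// Splits the raw input into an optional lexical part, an optional
    /// vector part, and a [`HybridMode`].
    ///
    /// Vector clauses are extracted by field name and joined with spaces;
    /// the remainder is cleaned of dangling `AND`/`OR` operators and kept
    /// for the lexical parser. A `+` prefix on any vector clause selects
    /// [`HybridMode::Intersection`]; otherwise the mode is
    /// [`HybridMode::Union`].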
fn split_query(&self, input: &str) -> Result<(Option<String>, Option<String>, HybridMode)> {
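        // With no vector fields configured, the whole input is lexical.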
if self.vector_fields.is_empty() {
return Ok((Some(input.to_string()), None, HybridMode::Union));
}
let fields_pattern: String = self
.vector_fields
.iter()
.map(|f| regex::escape(f))
.collect::<Vec<_>>()
.join("|");
self.check_unsupported_vector_syntax(input, &fields_pattern)?;
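        // One vector clause: an optional `+` (capture 1), a known vector
        // field name (capture 2), a colon, then either a quoted string or a
        // bare token, and an optional `^boost`. The `\b` keeps a field such
        // as `content` from matching inside a longer word like `discontent`.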
        let clause_pattern = format!(
            r#"(\+)?\b({fields})(?::(?:"[^"]*"|[^\s"^~\[\{{]+)(?:\^\d+(?:\.\d+)?)?)"#,
            fields = fields_pattern,
        );
let vector_re = Regex::new(&clause_pattern).unwrap();
let mut vector_clauses: Vec<String> = Vec::new();
let mut has_required = false;
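        // Pull out each vector clause, stripping the `+` marker but noting
        // that intersection semantics were requested.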
for caps in vector_re.captures_iter(input) {
if caps.get(1).is_some() {
has_required = true;
}
let full_match = caps.get(0).unwrap().as_str();
let clause = full_match.strip_prefix('+').unwrap_or(full_match);
vector_clauses.push(clause.to_string());
}
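        // Remove the vector clauses from the input, replacing each with a
        // space so adjacent lexical tokens do not merge.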
let lexical_raw = vector_re.replace_all(input, " ");
let lexical_cleaned = clean_lexical_string(&lexical_raw);
let lexical = if lexical_cleaned.is_empty() {
None
} else {
Some(lexical_cleaned)
};
let vector = if vector_clauses.is_empty() {
None
} else {
Some(vector_clauses.join(" "))
};
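        // Any required (`+`) vector clause upgrades the hybrid mode.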
let mode = if has_required {
HybridMode::Intersection
} else {
HybridMode::Union
};
Ok((lexical, vector, mode))
}
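    /// Rejects lexical-only syntax on a vector field: fuzzy/proximity `~`
    /// modifiers and `[...]`/`{...}` range queries. A `~` inside a quoted
    /// string is plain text and is allowed.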
fn check_unsupported_vector_syntax(&self, input: &str, fields_pattern: &str) -> Result<()> {
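        // A complete vector clause immediately followed by `~` and an
        // optional distance.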
        let tilde_pattern = format!(
            r#"(\b(?:{fields})(?::(?:"[^"]*"|[^\s"^~]+)(?:\^\d+(?:\.\d+)?)?))(~\d*)"#,
            fields = fields_pattern,
        );
let tilde_re = Regex::new(&tilde_pattern).unwrap();
if let Some(caps) = tilde_re.captures(input) {
let clause = caps.get(1).unwrap().as_str();
let modifier = caps.get(2).unwrap().as_str();
return Err(LaurusError::invalid_argument(format!(
"Proximity/fuzzy modifier '{modifier}' is not supported on vector field \
clause '{clause}'. The '~' modifier is only valid for lexical queries \
(e.g. \"term~2\" for fuzzy, '\"phrase\"~10' for proximity)."
)));
}
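        // A vector field name followed by `[` or `{`, the start of a range.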
        let range_pattern = format!(r#"\b({fields}):(\[|\{{)"#, fields = fields_pattern);
let range_re = Regex::new(&range_pattern).unwrap();
if let Some(caps) = range_re.captures(input) {
let field = caps.get(1).unwrap().as_str();
let bracket = caps.get(2).unwrap().as_str();
let kind = if bracket == "[" {
"inclusive range ([...TO...])"
} else {
"exclusive range ({...TO...})"
};
return Err(LaurusError::invalid_argument(format!(
"Range query on vector field '{field}' is not supported. \
The {kind} syntax is only valid for lexical fields."
)));
}
Ok(())
}
}
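/// Normalizes the lexical remainder left after vector clauses are removed:
/// collapses whitespace and drops `AND`/`OR` operators that end up leading,
/// trailing, or doubled (e.g. `"hello AND AND world"` becomes
/// `"hello AND world"`).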
fn clean_lexical_string(s: &str) -> String {
    // `split_whitespace` both trims and collapses runs of whitespace, so no
    // separate normalization pass is needed.
    let mut result: Vec<&str> = Vec::new();
    for token in s.split_whitespace() {
        if token.eq_ignore_ascii_case("AND") || token.eq_ignore_ascii_case("OR") {
            // Keep an operator only when it directly follows a term.
            if let Some(last) = result.last()
                && !last.eq_ignore_ascii_case("AND")
                && !last.eq_ignore_ascii_case("OR")
            {
                result.push(token);
            }
        } else {
            result.push(token);
        }
    }
    // A trailing operator has nothing to connect to; drop it.
    if let Some(last) = result.last()
        && (last.eq_ignore_ascii_case("AND") || last.eq_ignore_ascii_case("OR"))
    {
        result.pop();
    }
    result.join(" ")
}
#[cfg(test)]
mod tests {
use std::any::Any;
use std::sync::Arc;
use async_trait::async_trait;
use super::*;
use crate::analysis::analyzer::standard::StandardAnalyzer;
use crate::embedding::embedder::{EmbedInput, EmbedInputType, Embedder};
use crate::engine::search::VectorSearchQuery;
    use crate::error::Result as LaurusResult;
use crate::vector::core::vector::Vector;
use crate::vector::store::request::QueryVector;
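    /// Embedder stub that always returns a 4-dimensional zero vector, so
    /// the tests exercise parsing without a real embedding model.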
#[derive(Debug)]
struct MockEmbedder;
#[async_trait]
impl Embedder for MockEmbedder {
        async fn embed(&self, _input: &EmbedInput<'_>) -> LaurusResult<Vector> {
Ok(Vector::new(vec![0.0; 4]))
}
fn supported_input_types(&self) -> Vec<EmbedInputType> {
vec![EmbedInputType::Text]
}
fn name(&self) -> &str {
"mock"
}
fn as_any(&self) -> &dyn Any {
self
}
}
fn assert_lexical_only(request: &SearchRequest) {
assert!(matches!(request.query, SearchQuery::Lexical(_)));
}
fn assert_vector_only(request: &SearchRequest) -> &VectorSearchQuery {
match &request.query {
SearchQuery::Vector(v) => v,
_ => panic!("Expected SearchQuery::Vector"),
}
}
fn assert_hybrid(
request: &SearchRequest,
) -> (&LexicalSearchQuery, &VectorSearchQuery, &HybridMode) {
match &request.query {
SearchQuery::Hybrid {
lexical,
vector,
mode,
} => (lexical, vector, mode),
_ => panic!("Expected SearchQuery::Hybrid"),
}
}
fn get_vectors(vq: &VectorSearchQuery) -> &Vec<QueryVector> {
match vq {
VectorSearchQuery::Vectors(v) => v,
_ => panic!("Expected VectorSearchQuery::Vectors"),
}
}
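    /// Builds a parser with `title` as the lexical default field and
    /// `content` as the only vector field.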
fn make_parser() -> UnifiedQueryParser {
let analyzer = Arc::new(StandardAnalyzer::new().unwrap());
let lexical = LexicalQueryParser::new(analyzer).with_default_field("title");
let embedder: Arc<dyn Embedder> = Arc::new(MockEmbedder);
let vector = VectorQueryParser::new(embedder).with_default_field("content");
let vector_fields: HashSet<String> = ["content".to_string()].into_iter().collect();
UnifiedQueryParser::new(lexical, vector, vector_fields)
}
#[tokio::test]
async fn test_lexical_only() {
let parser = make_parser();
let request = parser.parse("title:hello").await.unwrap();
assert_lexical_only(&request);
assert!(request.fusion_algorithm.is_none());
}
#[tokio::test]
async fn test_vector_only_quoted() {
let parser = make_parser();
let request = parser.parse(r#"content:"cats""#).await.unwrap();
let vq = assert_vector_only(&request);
assert!(request.fusion_algorithm.is_none());
let vecs = get_vectors(vq);
assert_eq!(vecs.len(), 1);
assert_eq!(vecs[0].fields.as_ref().unwrap()[0], "content");
}
#[tokio::test]
async fn test_vector_only_unquoted() {
let parser = make_parser();
let request = parser.parse("content:cats").await.unwrap();
let vq = assert_vector_only(&request);
let vecs = get_vectors(vq);
assert_eq!(vecs.len(), 1);
assert_eq!(vecs[0].fields.as_ref().unwrap()[0], "content");
}
#[tokio::test]
async fn test_hybrid() {
let parser = make_parser();
let request = parser.parse(r#"title:hello content:"cats""#).await.unwrap();
assert_hybrid(&request);
assert!(request.fusion_algorithm.is_some());
if let Some(FusionAlgorithm::RRF { k }) = request.fusion_algorithm {
assert!((k - 60.0).abs() < f64::EPSILON);
} else {
panic!("Expected RRF fusion");
}
}
#[tokio::test]
async fn test_hybrid_unquoted_vector() {
let parser = make_parser();
let request = parser.parse("title:hello content:cats").await.unwrap();
assert_hybrid(&request);
}
#[tokio::test]
async fn test_vector_with_boost() {
let parser = make_parser();
let request = parser.parse(r#"content:"text"^0.8"#).await.unwrap();
let vq = assert_vector_only(&request);
let vecs = get_vectors(vq);
assert!((vecs[0].weight - 0.8).abs() < f32::EPSILON);
}
#[tokio::test]
async fn test_unquoted_vector_with_boost() {
let parser = make_parser();
let request = parser.parse("content:python^0.8").await.unwrap();
let vq = assert_vector_only(&request);
let vecs = get_vectors(vq);
assert!((vecs[0].weight - 0.8).abs() < f32::EPSILON);
}
#[tokio::test]
async fn test_multiple_vector_clauses() {
let analyzer = Arc::new(StandardAnalyzer::new().unwrap());
let lexical = LexicalQueryParser::new(analyzer).with_default_field("title");
let embedder: Arc<dyn Embedder> = Arc::new(MockEmbedder);
let vector = VectorQueryParser::new(embedder);
let vector_fields: HashSet<String> =
["a".to_string(), "b".to_string()].into_iter().collect();
let parser = UnifiedQueryParser::new(lexical, vector, vector_fields);
let request = parser.parse(r#"a:"x" b:"y"^0.5"#).await.unwrap();
let vq = assert_vector_only(&request);
let vecs = get_vectors(vq);
assert_eq!(vecs.len(), 2);
assert_eq!(vecs[0].fields.as_ref().unwrap()[0], "a");
assert_eq!(vecs[1].fields.as_ref().unwrap()[0], "b");
assert!((vecs[1].weight - 0.5).abs() < f32::EPSILON);
}
#[tokio::test]
async fn test_lexical_and_with_vector() {
let parser = make_parser();
let request = parser
.parse(r#"title:hello AND title:world content:"cats""#)
.await
.unwrap();
assert_hybrid(&request);
assert!(request.fusion_algorithm.is_some());
}
#[tokio::test]
async fn test_vector_between_and() {
let parser = make_parser();
let request = parser
.parse(r#"title:hello AND content:"cats" AND title:world"#)
.await
.unwrap();
assert_hybrid(&request);
}
#[tokio::test]
async fn test_empty_query_error() {
let parser = make_parser();
assert!(parser.parse("").await.is_err());
assert!(parser.parse(" ").await.is_err());
}
#[tokio::test]
async fn test_custom_fusion() {
let analyzer = Arc::new(StandardAnalyzer::new().unwrap());
let lexical = LexicalQueryParser::new(analyzer).with_default_field("title");
let embedder: Arc<dyn Embedder> = Arc::new(MockEmbedder);
let vector = VectorQueryParser::new(embedder).with_default_field("content");
let vector_fields: HashSet<String> = ["content".to_string()].into_iter().collect();
let parser = UnifiedQueryParser::new(lexical, vector, vector_fields).with_fusion(
FusionAlgorithm::WeightedSum {
lexical_weight: 0.7,
vector_weight: 0.3,
},
);
let request = parser.parse(r#"title:hello content:"cats""#).await.unwrap();
if let Some(FusionAlgorithm::WeightedSum {
lexical_weight,
vector_weight,
}) = request.fusion_algorithm
{
assert!((lexical_weight - 0.7).abs() < f32::EPSILON);
assert!((vector_weight - 0.3).abs() < f32::EPSILON);
} else {
panic!("Expected WeightedSum fusion");
}
}
#[tokio::test]
async fn test_unicode_vector_text() {
let parser = make_parser();
let request = parser.parse(r#"content:"日本語テスト""#).await.unwrap();
let vq = assert_vector_only(&request);
let vecs = get_vectors(vq);
assert_eq!(vecs.len(), 1);
assert_eq!(vecs[0].fields.as_ref().unwrap()[0], "content");
assert_eq!(vecs[0].vector.dimension(), 4);
}
#[tokio::test]
async fn test_no_vector_fields_all_lexical() {
let analyzer = Arc::new(StandardAnalyzer::new().unwrap());
let lexical = LexicalQueryParser::new(analyzer).with_default_field("title");
let embedder: Arc<dyn Embedder> = Arc::new(MockEmbedder);
let vector = VectorQueryParser::new(embedder);
let parser = UnifiedQueryParser::new(lexical, vector, HashSet::new());
let request = parser.parse("title:hello").await.unwrap();
assert_lexical_only(&request);
}
#[tokio::test]
async fn test_vector_field_with_proximity_error() {
let parser = make_parser();
let result = parser.parse(r#"content:"hello world"~2"#).await;
assert!(result.is_err());
let msg = result.err().unwrap().to_string();
assert!(msg.contains("Proximity/fuzzy modifier"), "got: {msg}");
assert!(msg.contains("content:"), "got: {msg}");
}
#[tokio::test]
async fn test_vector_field_with_fuzzy_error() {
let parser = make_parser();
let result = parser.parse("content:python~2").await;
assert!(result.is_err());
let msg = result.err().unwrap().to_string();
assert!(msg.contains("Proximity/fuzzy modifier"), "got: {msg}");
}
#[tokio::test]
async fn test_vector_field_with_tilde_only_error() {
let parser = make_parser();
let result = parser.parse("content:python~").await;
assert!(result.is_err());
let msg = result.err().unwrap().to_string();
assert!(msg.contains("Proximity/fuzzy modifier"), "got: {msg}");
}
#[tokio::test]
async fn test_vector_field_with_inclusive_range_error() {
let parser = make_parser();
let result = parser.parse("content:[A TO Z]").await;
assert!(result.is_err());
let msg = result.err().unwrap().to_string();
assert!(msg.contains("Range query"), "got: {msg}");
assert!(msg.contains("content"), "got: {msg}");
assert!(msg.contains("inclusive"), "got: {msg}");
}
#[tokio::test]
async fn test_vector_field_with_exclusive_range_error() {
let parser = make_parser();
let result = parser.parse("content:{100 TO 500}").await;
assert!(result.is_err());
let msg = result.err().unwrap().to_string();
assert!(msg.contains("Range query"), "got: {msg}");
assert!(msg.contains("exclusive"), "got: {msg}");
}
#[tokio::test]
async fn test_lexical_field_range_still_works() {
let parser = make_parser();
let request = parser.parse("title:[A TO Z]").await.unwrap();
assert_lexical_only(&request);
}
#[tokio::test]
async fn test_lexical_field_fuzzy_still_works() {
let parser = make_parser();
let request = parser.parse("title:hello~2").await.unwrap();
assert_lexical_only(&request);
}
#[tokio::test]
async fn test_tilde_inside_quotes_is_valid_vector_text() {
let parser = make_parser();
let request = parser.parse(r#"content:"python~2""#).await.unwrap();
let vq = assert_vector_only(&request);
let vecs = get_vectors(vq);
assert_eq!(vecs.len(), 1);
assert_eq!(vecs[0].fields.as_ref().unwrap()[0], "content");
}
#[tokio::test]
async fn test_hybrid_union_by_default() {
let parser = make_parser();
let request = parser.parse(r#"title:hello content:"cats""#).await.unwrap();
let (_, _, mode) = assert_hybrid(&request);
assert_eq!(*mode, HybridMode::Union);
}
#[tokio::test]
async fn test_hybrid_intersection_with_plus_vector() {
let parser = make_parser();
let request = parser
.parse(r#"title:hello +content:"cats""#)
.await
.unwrap();
let (_, _, mode) = assert_hybrid(&request);
assert_eq!(*mode, HybridMode::Intersection);
}
#[tokio::test]
async fn test_hybrid_intersection_plus_on_both() {
let parser = make_parser();
let request = parser
.parse(r#"+title:hello +content:"cats""#)
.await
.unwrap();
let (_, _, mode) = assert_hybrid(&request);
assert_eq!(*mode, HybridMode::Intersection);
}
#[tokio::test]
async fn test_hybrid_intersection_unquoted_vector() {
let parser = make_parser();
let request = parser.parse("title:hello +content:cats").await.unwrap();
let (_, _, mode) = assert_hybrid(&request);
assert_eq!(*mode, HybridMode::Intersection);
}
#[tokio::test]
async fn test_vector_only_with_plus_stays_vector() {
let parser = make_parser();
let request = parser.parse(r#"+content:"cats""#).await.unwrap();
assert_vector_only(&request);
}
#[test]
fn test_clean_leading_boolean() {
assert_eq!(clean_lexical_string("AND hello"), "hello");
assert_eq!(clean_lexical_string("OR hello"), "hello");
}
#[test]
fn test_clean_trailing_boolean() {
assert_eq!(clean_lexical_string("hello AND"), "hello");
assert_eq!(clean_lexical_string("hello OR"), "hello");
}
#[test]
fn test_clean_consecutive_boolean() {
assert_eq!(
clean_lexical_string("hello AND AND world"),
"hello AND world"
);
assert_eq!(clean_lexical_string("hello OR OR world"), "hello OR world");
assert_eq!(
clean_lexical_string("hello AND OR world"),
"hello AND world"
);
}
#[test]
fn test_clean_multiple_spaces() {
assert_eq!(clean_lexical_string("hello world"), "hello world");
}
#[test]
fn test_clean_empty() {
assert_eq!(clean_lexical_string(""), "");
assert_eq!(clean_lexical_string(" "), "");
assert_eq!(clean_lexical_string("AND"), "");
assert_eq!(clean_lexical_string("AND OR"), "");
}
}