use std::sync::Arc;
use pest::Parser;
use pest_derive::Parser;
use crate::data::DataValue;
use crate::embedding::embedder::{EmbedInput, Embedder};
use crate::embedding::per_field::PerFieldEmbedder;
use crate::error::{LaurusError, Result};
use crate::vector::core::vector::Vector;
use crate::vector::store::request::{
QueryPayload, QueryVector, VectorSearchQuery, VectorSearchRequest,
};
/// Pest parser for the vector query string syntax.
///
/// The grammar is read from `vector/query/parser.pest`. The `Parser` derive
/// also generates the `Rule` enum that the rest of this file matches on.
#[derive(Parser)]
#[grammar = "vector/query/parser.pest"]
struct VectorQueryStringParser;
/// Turns vector query strings such as `content:~"cute kitten"^0.8` into
/// [`VectorSearchRequest`]s.
///
/// The quoted text of every clause is embedded with the configured
/// [`Embedder`]; the optional `field:` prefix and `^boost` suffix select the
/// target field and the weight of the resulting query vector.
pub struct VectorQueryParser {
/// Embedder used to convert clause text into query vectors.
embedder: Arc<dyn Embedder>,
/// Fallback fields for clauses without a `field:` prefix. Only the first
/// entry is consulted when resolving a clause.
default_fields: Vec<String>,
}
impl VectorQueryParser {
/// Creates a parser that embeds clause text with `embedder`.
///
/// No default field is configured initially, so every clause must carry an
/// explicit `field:` prefix until [`with_default_field`](Self::with_default_field)
/// or [`with_default_fields`](Self::with_default_fields) is called.
pub fn new(embedder: Arc<dyn Embedder>) -> Self {
    let default_fields = Vec::new();
    Self {
        embedder,
        default_fields,
    }
}
/// Makes `field` the single default field, discarding any defaults that
/// were configured before.
///
/// The default field is used for clauses that have no `field:` prefix.
pub fn with_default_field(mut self, field: impl Into<String>) -> Self {
    let sole_default: String = field.into();
    self.default_fields = Vec::from([sole_default]);
    self
}
pub fn with_default_fields(mut self, fields: Vec<String>) -> Self {
self.default_fields = fields;
self
}
/// Parses `query_str` and embeds every clause into a [`VectorSearchRequest`].
///
/// Clauses are embedded one after another, in the order they appear. Each
/// clause yields one [`QueryVector`] tagged with the clause's field and
/// carrying the clause's boost as its weight. The request's `params` are
/// left at their default value.
///
/// # Errors
///
/// Returns an invalid-argument error when
/// - the string does not match the query grammar,
/// - the query contains no vector clause,
/// - a clause cannot be resolved (see `parse_vector_clause`), or
/// - a clause payload is neither text nor bytes and so cannot be embedded.
///
/// Errors reported by the embedder are propagated unchanged.
pub async fn parse(&self, query_str: &str) -> Result<VectorSearchRequest> {
    // Resolve every clause up front: embedding only starts once the whole
    // query is known to be well-formed, and the first bad clause aborts.
    let payloads = VectorQueryStringParser::parse(Rule::query, query_str)
        .map_err(|e| {
            LaurusError::invalid_argument(format!("Failed to parse vector query: {}", e))
        })?
        .filter(|pair| pair.as_rule() == Rule::query)
        .flat_map(|pair| pair.into_inner())
        .filter(|inner| inner.as_rule() == Rule::vector_clause)
        .map(|clause| self.parse_vector_clause(clause))
        .collect::<Result<Vec<_>>>()?;

    if payloads.is_empty() {
        return Err(LaurusError::invalid_argument(
            "Vector query must contain at least one clause",
        ));
    }

    let mut query_vectors = Vec::with_capacity(payloads.len());
    for payload in payloads {
        let input = match &payload.payload {
            DataValue::Text(t) => EmbedInput::Text(t),
            DataValue::Bytes(b, m) => EmbedInput::Bytes(b, m.as_deref()),
            // Dropping the clause silently could produce a request with no
            // query vectors at all; report the problem instead.
            _ => {
                return Err(LaurusError::invalid_argument(format!(
                    "Unsupported payload type for vector query field '{}'",
                    payload.field
                )));
            }
        };
        let vector = self.embed_for_field(&payload.field, &input).await?;
        query_vectors.push(QueryVector {
            vector,
            weight: payload.weight,
            fields: Some(vec![payload.field]),
        });
    }

    Ok(VectorSearchRequest {
        query: VectorSearchQuery::Vectors(query_vectors),
        params: Default::default(),
    })
}
/// Embeds `input` on behalf of `field`.
///
/// If the configured embedder is a [`PerFieldEmbedder`], the call goes
/// through its `embed_field` method together with the field name; any other
/// embedder is invoked through the plain `embed` method.
async fn embed_for_field(&self, field: &str, input: &EmbedInput<'_>) -> Result<Vector> {
    let per_field = self.embedder.as_any().downcast_ref::<PerFieldEmbedder>();
    match per_field {
        Some(router) => router.embed_field(field, input).await,
        None => self.embedder.embed(input).await,
    }
}
/// Converts one `vector_clause` parse node into a [`QueryPayload`].
///
/// A clause may carry a field prefix, a quoted text and a boost. When the
/// field prefix is absent the first configured default field is used; when
/// the boost is absent the weight stays at `1.0`.
///
/// # Errors
///
/// Returns an invalid-argument error when
/// - the boost cannot be parsed as `f32` or is not a finite number,
/// - the clause names no field and no default field is configured, or
/// - the clause contains no quoted text.
fn parse_vector_clause(&self, pair: pest::iterators::Pair<Rule>) -> Result<QueryPayload> {
    let mut field_name: Option<String> = None;
    let mut text: Option<String> = None;
    let mut weight: f32 = 1.0;

    for inner in pair.into_inner() {
        match inner.as_rule() {
            Rule::field_prefix => {
                for name in inner
                    .into_inner()
                    .filter(|p| p.as_rule() == Rule::field_name)
                {
                    field_name = Some(name.as_str().to_string());
                }
            }
            Rule::quoted_text => {
                for body in inner
                    .into_inner()
                    .filter(|p| p.as_rule() == Rule::inner_text)
                {
                    text = Some(body.as_str().to_string());
                }
            }
            Rule::boost => {
                for value in inner
                    .into_inner()
                    .filter(|p| p.as_rule() == Rule::float_value)
                {
                    let raw = value.as_str();
                    let parsed = raw.parse::<f32>().map_err(|e| {
                        LaurusError::invalid_argument(format!("Invalid boost value: {}", e))
                    })?;
                    // `str::parse::<f32>` succeeds with infinity for literals
                    // that overflow `f32`; such a weight must not reach scoring.
                    if !parsed.is_finite() {
                        return Err(LaurusError::invalid_argument(format!(
                            "Invalid boost value: {} is not a finite number",
                            raw
                        )));
                    }
                    weight = parsed;
                }
            }
            _ => {}
        }
    }

    // An explicit prefix wins; otherwise fall back to the first default field.
    let field = match field_name {
        Some(explicit) => explicit,
        None => self.default_fields.first().cloned().ok_or_else(|| {
            LaurusError::invalid_argument("No field specified and no default field configured")
        })?,
    };

    let text = text
        .ok_or_else(|| LaurusError::invalid_argument("Missing quoted text in vector clause"))?;

    Ok(QueryPayload::with_weight(
        field,
        DataValue::Text(text),
        weight,
    ))
}
}
#[cfg(test)]
mod tests {
    use std::any::Any;

    use async_trait::async_trait;

    use super::*;
    use crate::embedding::embedder::EmbedInputType;

    /// Embedder stub that ignores its input and returns an all-zero vector
    /// of a fixed dimension.
    #[derive(Debug)]
    struct MockEmbedder {
        dimension: usize,
    }

    #[async_trait]
    impl Embedder for MockEmbedder {
        async fn embed(&self, _input: &EmbedInput<'_>) -> Result<Vector> {
            Ok(Vector::new(vec![0.0; self.dimension]))
        }

        fn supported_input_types(&self) -> Vec<EmbedInputType> {
            vec![EmbedInputType::Text]
        }

        fn name(&self) -> &str {
            "mock"
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    /// Builds the shared 4-dimensional mock embedder.
    fn mock_embedder() -> Arc<dyn Embedder> {
        Arc::new(MockEmbedder { dimension: 4 })
    }

    /// Unwraps the query vectors of `req`, panicking on any other query kind.
    fn get_vectors(req: &VectorSearchRequest) -> &[QueryVector] {
        match &req.query {
            VectorSearchQuery::Vectors(v) => v,
            _ => panic!("Expected VectorSearchQuery::Vectors"),
        }
    }

    /// Returns the first (and, for parser output, only) field of `qv`.
    fn qv_field(qv: &QueryVector) -> &str {
        &qv.fields.as_ref().unwrap()[0]
    }

    #[tokio::test]
    async fn test_basic_query() {
        let parser = VectorQueryParser::new(mock_embedder());
        let request = parser.parse(r#"content:~"cute kitten""#).await.unwrap();
        let vecs = get_vectors(&request);
        assert_eq!(vecs.len(), 1);
        let qv = &vecs[0];
        assert_eq!(qv_field(qv), "content");
        assert_eq!(qv.weight, 1.0);
        assert_eq!(qv.vector.dimension(), 4);
    }

    #[tokio::test]
    async fn test_boost() {
        let parser = VectorQueryParser::new(mock_embedder());
        let request = parser.parse(r#"content:~"text"^0.8"#).await.unwrap();
        let vecs = get_vectors(&request);
        assert_eq!(vecs.len(), 1);
        let qv = &vecs[0];
        assert_eq!(qv_field(qv), "content");
        assert!((qv.weight - 0.8).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn test_default_field() {
        let parser = VectorQueryParser::new(mock_embedder()).with_default_field("embedding");
        let request = parser.parse(r#"~"cute kitten""#).await.unwrap();
        let vecs = get_vectors(&request);
        assert_eq!(vecs.len(), 1);
        assert_eq!(qv_field(&vecs[0]), "embedding");
    }

    #[tokio::test]
    async fn test_default_fields_uses_first() {
        let parser = VectorQueryParser::new(mock_embedder())
            .with_default_fields(vec!["title".to_string(), "body".to_string()]);
        let request = parser.parse(r#"~"cute kitten""#).await.unwrap();
        let vecs = get_vectors(&request);
        assert_eq!(vecs.len(), 1);
        assert_eq!(qv_field(&vecs[0]), "title");
    }

    #[tokio::test]
    async fn test_explicit_field_overrides_default() {
        let parser = VectorQueryParser::new(mock_embedder()).with_default_field("embedding");
        let request = parser.parse(r#"content:~"cute kitten""#).await.unwrap();
        let vecs = get_vectors(&request);
        assert_eq!(vecs.len(), 1);
        assert_eq!(qv_field(&vecs[0]), "content");
    }

    #[tokio::test]
    async fn test_multiple_clauses() {
        let parser = VectorQueryParser::new(mock_embedder());
        let request = parser
            .parse(r#"content:~"cats" image:~"dogs"^0.5"#)
            .await
            .unwrap();
        let vecs = get_vectors(&request);
        assert_eq!(vecs.len(), 2);
        assert_eq!(qv_field(&vecs[0]), "content");
        assert_eq!(vecs[0].weight, 1.0);
        assert_eq!(qv_field(&vecs[1]), "image");
        assert!((vecs[1].weight - 0.5).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn test_empty_query_error() {
        let parser = VectorQueryParser::new(mock_embedder());
        assert!(parser.parse("").await.is_err());
    }

    #[tokio::test]
    async fn test_missing_tilde_error() {
        let parser = VectorQueryParser::new(mock_embedder());
        assert!(parser.parse(r#"content:"text""#).await.is_err());
    }

    #[tokio::test]
    async fn test_no_field_no_default_error() {
        let parser = VectorQueryParser::new(mock_embedder());
        assert!(parser.parse(r#"~"text""#).await.is_err());
    }

    #[tokio::test]
    async fn test_unicode_text() {
        let parser = VectorQueryParser::new(mock_embedder());
        let request = parser.parse(r#"content:~"日本語テスト""#).await.unwrap();
        let vecs = get_vectors(&request);
        assert_eq!(vecs.len(), 1);
        assert_eq!(qv_field(&vecs[0]), "content");
        assert_eq!(vecs[0].vector.dimension(), 4);
    }

    #[tokio::test]
    async fn test_integer_boost() {
        let parser = VectorQueryParser::new(mock_embedder());
        let request = parser.parse(r#"content:~"text"^2"#).await.unwrap();
        let vecs = get_vectors(&request);
        assert!((vecs[0].weight - 2.0).abs() < f32::EPSILON);
    }

    #[tokio::test]
    async fn test_field_with_underscore() {
        let parser = VectorQueryParser::new(mock_embedder());
        let request = parser.parse(r#"my_field:~"text""#).await.unwrap();
        let vecs = get_vectors(&request);
        assert_eq!(qv_field(&vecs[0]), "my_field");
    }

    #[tokio::test]
    async fn test_field_with_dot() {
        let parser = VectorQueryParser::new(mock_embedder());
        let request = parser.parse(r#"nested.field:~"text""#).await.unwrap();
        let vecs = get_vectors(&request);
        assert_eq!(qv_field(&vecs[0]), "nested.field");
    }
}