1use pest::Parser;
11use pest_derive::Parser;
12use std::sync::Arc;
13
14use super::query_field_router::{QueryFieldRouter, RoutingMode};
15use super::schema::{Field, Schema};
16use crate::query::{BooleanQuery, Query, TermQuery};
17use crate::tokenizer::{BoxedTokenizer, TokenizerRegistry};
18
/// Pest-generated parser for the query DSL; the grammar is defined in
/// `dsl/ql/ql.pest` and compiled in by `pest_derive`.
#[derive(Parser)]
#[grammar = "dsl/ql/ql.pest"]
struct QueryParser;
22
/// Intermediate AST produced from the pest parse tree before being lowered
/// into executable `Query` objects by `QueryLanguageParser::build_query`.
#[derive(Debug, Clone)]
pub enum ParsedQuery {
    /// A single term, optionally scoped to a field (`title:rust` vs `rust`).
    Term {
        field: Option<String>,
        term: String,
    },
    /// A quoted phrase, optionally scoped to a field.
    Phrase {
        field: Option<String>,
        phrase: String,
    },
    /// Approximate-nearest-neighbor dense-vector query.
    Ann {
        field: String,
        vector: Vec<f32>,
        // Defaults to 32 when not specified in the query string.
        nprobe: usize,
        // Defaults to 3.0 when not specified in the query string.
        rerank: f32,
    },
    /// Sparse-vector query as (dimension index, weight) pairs.
    Sparse {
        field: String,
        vector: Vec<(u32, f32)>,
    },
    And(Vec<ParsedQuery>),
    Or(Vec<ParsedQuery>),
    Not(Box<ParsedQuery>),
}
50
/// Parses the textual query DSL into executable `Query` objects, resolving
/// field names against `schema` and tokenizing terms via `tokenizers`.
pub struct QueryLanguageParser {
    schema: Arc<Schema>,
    // Fields searched when a term carries no explicit `field:` prefix.
    default_fields: Vec<Field>,
    tokenizers: Arc<TokenizerRegistry>,
    // Optional pattern-based router consulted before normal parsing; it can
    // redirect a whole query string (e.g. a DOI) to a specific field.
    field_router: Option<QueryFieldRouter>,
}
59
impl QueryLanguageParser {
    /// Creates a parser with no field router installed.
    pub fn new(
        schema: Arc<Schema>,
        default_fields: Vec<Field>,
        tokenizers: Arc<TokenizerRegistry>,
    ) -> Self {
        Self {
            field_router: None,
            schema,
            default_fields,
            tokenizers,
        }
    }
73
74 pub fn with_router(
76 schema: Arc<Schema>,
77 default_fields: Vec<Field>,
78 tokenizers: Arc<TokenizerRegistry>,
79 router: QueryFieldRouter,
80 ) -> Self {
81 Self {
82 schema,
83 default_fields,
84 tokenizers,
85 field_router: Some(router),
86 }
87 }
88
89 pub fn set_router(&mut self, router: QueryFieldRouter) {
91 self.field_router = Some(router);
92 }
93
94 pub fn router(&self) -> Option<&QueryFieldRouter> {
96 self.field_router.as_ref()
97 }
98
99 pub fn parse(&self, query_str: &str) -> Result<Box<dyn Query>, String> {
109 let query_str = query_str.trim();
110 if query_str.is_empty() {
111 return Err("Empty query".to_string());
112 }
113
114 if let Some(router) = &self.field_router
116 && let Some(routed) = router.route(query_str)
117 {
118 return self.build_routed_query(
119 &routed.query,
120 &routed.target_field,
121 routed.mode,
122 query_str,
123 );
124 }
125
126 self.parse_normal(query_str)
128 }
129
130 fn build_routed_query(
132 &self,
133 routed_query: &str,
134 target_field: &str,
135 mode: RoutingMode,
136 original_query: &str,
137 ) -> Result<Box<dyn Query>, String> {
138 let _field_id = self
140 .schema
141 .get_field(target_field)
142 .ok_or_else(|| format!("Unknown target field: {}", target_field))?;
143
144 let target_query = self.build_term_query(Some(target_field), routed_query)?;
146
147 match mode {
148 RoutingMode::Exclusive => {
149 Ok(target_query)
151 }
152 RoutingMode::Additional => {
153 let mut bool_query = BooleanQuery::new();
155 bool_query = bool_query.should(target_query);
156
157 if let Ok(default_query) = self.parse_normal(original_query) {
159 bool_query = bool_query.should(default_query);
160 }
161
162 Ok(Box::new(bool_query))
163 }
164 }
165 }
166
167 fn parse_normal(&self, query_str: &str) -> Result<Box<dyn Query>, String> {
169 match self.parse_query_string(query_str) {
171 Ok(parsed) => self.build_query(&parsed),
172 Err(_) => {
173 self.parse_plain_text(query_str)
176 }
177 }
178 }
179
180 fn parse_plain_text(&self, text: &str) -> Result<Box<dyn Query>, String> {
182 if self.default_fields.is_empty() {
183 return Err("No default fields configured".to_string());
184 }
185
186 let tokenizer = self.get_tokenizer(self.default_fields[0]);
187 let tokens: Vec<String> = tokenizer
188 .tokenize(text)
189 .into_iter()
190 .map(|t| t.text.to_lowercase())
191 .collect();
192
193 if tokens.is_empty() {
194 return Err("No tokens in query".to_string());
195 }
196
197 let mut bool_query = BooleanQuery::new();
198 for token in &tokens {
199 for &field_id in &self.default_fields {
200 bool_query = bool_query.should(TermQuery::text(field_id, token));
201 }
202 }
203 Ok(Box::new(bool_query))
204 }
205
206 fn parse_query_string(&self, query_str: &str) -> Result<ParsedQuery, String> {
207 let pairs = QueryParser::parse(Rule::query, query_str)
208 .map_err(|e| format!("Parse error: {}", e))?;
209
210 let query_pair = pairs.into_iter().next().ok_or("No query found")?;
211
212 self.parse_or_expr(query_pair.into_inner().next().unwrap())
214 }
215
216 fn parse_or_expr(&self, pair: pest::iterators::Pair<Rule>) -> Result<ParsedQuery, String> {
217 let mut inner = pair.into_inner();
218 let first = self.parse_and_expr(inner.next().unwrap())?;
219
220 let rest: Vec<ParsedQuery> = inner
221 .filter(|p| p.as_rule() == Rule::and_expr)
222 .map(|p| self.parse_and_expr(p))
223 .collect::<Result<Vec<_>, _>>()?;
224
225 if rest.is_empty() {
226 Ok(first)
227 } else {
228 let mut all = vec![first];
229 all.extend(rest);
230 Ok(ParsedQuery::Or(all))
231 }
232 }
233
234 fn parse_and_expr(&self, pair: pest::iterators::Pair<Rule>) -> Result<ParsedQuery, String> {
235 let mut inner = pair.into_inner();
236 let first = self.parse_primary(inner.next().unwrap())?;
237
238 let rest: Vec<ParsedQuery> = inner
239 .filter(|p| p.as_rule() == Rule::primary)
240 .map(|p| self.parse_primary(p))
241 .collect::<Result<Vec<_>, _>>()?;
242
243 if rest.is_empty() {
244 Ok(first)
245 } else {
246 let mut all = vec![first];
247 all.extend(rest);
248 Ok(ParsedQuery::And(all))
249 }
250 }
251
252 fn parse_primary(&self, pair: pest::iterators::Pair<Rule>) -> Result<ParsedQuery, String> {
253 let mut negated = false;
254 let mut inner_query = None;
255
256 for inner in pair.into_inner() {
257 match inner.as_rule() {
258 Rule::not_op => negated = true,
259 Rule::group => {
260 let or_expr = inner.into_inner().next().unwrap();
261 inner_query = Some(self.parse_or_expr(or_expr)?);
262 }
263 Rule::ann_query => {
264 inner_query = Some(self.parse_ann_query(inner)?);
265 }
266 Rule::sparse_query => {
267 inner_query = Some(self.parse_sparse_query(inner)?);
268 }
269 Rule::phrase_query => {
270 inner_query = Some(self.parse_phrase_query(inner)?);
271 }
272 Rule::term_query => {
273 inner_query = Some(self.parse_term_query(inner)?);
274 }
275 _ => {}
276 }
277 }
278
279 let query = inner_query.ok_or("No query in primary")?;
280
281 if negated {
282 Ok(ParsedQuery::Not(Box::new(query)))
283 } else {
284 Ok(query)
285 }
286 }
287
288 fn parse_term_query(&self, pair: pest::iterators::Pair<Rule>) -> Result<ParsedQuery, String> {
289 let mut field = None;
290 let mut term = String::new();
291
292 for inner in pair.into_inner() {
293 match inner.as_rule() {
294 Rule::field_spec => {
295 field = Some(inner.into_inner().next().unwrap().as_str().to_string());
296 }
297 Rule::term => {
298 term = inner.as_str().to_string();
299 }
300 _ => {}
301 }
302 }
303
304 Ok(ParsedQuery::Term { field, term })
305 }
306
307 fn parse_phrase_query(&self, pair: pest::iterators::Pair<Rule>) -> Result<ParsedQuery, String> {
308 let mut field = None;
309 let mut phrase = String::new();
310
311 for inner in pair.into_inner() {
312 match inner.as_rule() {
313 Rule::field_spec => {
314 field = Some(inner.into_inner().next().unwrap().as_str().to_string());
315 }
316 Rule::quoted_string => {
317 let s = inner.as_str();
318 phrase = s[1..s.len() - 1].to_string();
319 }
320 _ => {}
321 }
322 }
323
324 Ok(ParsedQuery::Phrase { field, phrase })
325 }
326
327 fn parse_ann_query(&self, pair: pest::iterators::Pair<Rule>) -> Result<ParsedQuery, String> {
329 let mut field = String::new();
330 let mut vector = Vec::new();
331 let mut nprobe = 32usize;
332 let mut rerank = 3.0f32;
333
334 for inner in pair.into_inner() {
335 match inner.as_rule() {
336 Rule::field_spec => {
337 field = inner.into_inner().next().unwrap().as_str().to_string();
338 }
339 Rule::vector_array => {
340 for num in inner.into_inner() {
341 if num.as_rule() == Rule::number
342 && let Ok(v) = num.as_str().parse::<f32>()
343 {
344 vector.push(v);
345 }
346 }
347 }
348 Rule::ann_params => {
349 for param in inner.into_inner() {
350 if param.as_rule() == Rule::ann_param {
351 let param_str = param.as_str();
353 if let Some(eq_pos) = param_str.find('=') {
354 let name = ¶m_str[..eq_pos];
355 let value = ¶m_str[eq_pos + 1..];
356 match name {
357 "nprobe" => nprobe = value.parse().unwrap_or(0),
358 "rerank" => rerank = value.parse().unwrap_or(0.0),
359 _ => {}
360 }
361 }
362 }
363 }
364 }
365 _ => {}
366 }
367 }
368
369 Ok(ParsedQuery::Ann {
370 field,
371 vector,
372 nprobe,
373 rerank,
374 })
375 }
376
377 fn parse_sparse_query(&self, pair: pest::iterators::Pair<Rule>) -> Result<ParsedQuery, String> {
379 let mut field = String::new();
380 let mut vector = Vec::new();
381
382 for inner in pair.into_inner() {
383 match inner.as_rule() {
384 Rule::field_spec => {
385 field = inner.into_inner().next().unwrap().as_str().to_string();
386 }
387 Rule::sparse_map => {
388 for entry in inner.into_inner() {
389 if entry.as_rule() == Rule::sparse_entry {
390 let mut entry_inner = entry.into_inner();
391 if let (Some(idx), Some(weight)) =
392 (entry_inner.next(), entry_inner.next())
393 && let (Ok(i), Ok(w)) =
394 (idx.as_str().parse::<u32>(), weight.as_str().parse::<f32>())
395 {
396 vector.push((i, w));
397 }
398 }
399 }
400 }
401 _ => {}
402 }
403 }
404
405 Ok(ParsedQuery::Sparse { field, vector })
406 }
407
408 fn build_query(&self, parsed: &ParsedQuery) -> Result<Box<dyn Query>, String> {
409 use crate::query::{DenseVectorQuery, SparseVectorQuery};
410
411 match parsed {
412 ParsedQuery::Term { field, term } => self.build_term_query(field.as_deref(), term),
413 ParsedQuery::Phrase { field, phrase } => {
414 self.build_phrase_query(field.as_deref(), phrase)
415 }
416 ParsedQuery::Ann {
417 field,
418 vector,
419 nprobe,
420 rerank,
421 } => {
422 let field_id = self
423 .schema
424 .get_field(field)
425 .ok_or_else(|| format!("Unknown field: {}", field))?;
426 let query = DenseVectorQuery::new(field_id, vector.clone())
427 .with_nprobe(*nprobe)
428 .with_rerank_factor(*rerank);
429 Ok(Box::new(query))
430 }
431 ParsedQuery::Sparse { field, vector } => {
432 let field_id = self
433 .schema
434 .get_field(field)
435 .ok_or_else(|| format!("Unknown field: {}", field))?;
436 let query = SparseVectorQuery::new(field_id, vector.clone());
437 Ok(Box::new(query))
438 }
439 ParsedQuery::And(queries) => {
440 let mut bool_query = BooleanQuery::new();
441 for q in queries {
442 bool_query = bool_query.must(self.build_query(q)?);
443 }
444 Ok(Box::new(bool_query))
445 }
446 ParsedQuery::Or(queries) => {
447 let mut bool_query = BooleanQuery::new();
448 for q in queries {
449 bool_query = bool_query.should(self.build_query(q)?);
450 }
451 Ok(Box::new(bool_query))
452 }
453 ParsedQuery::Not(inner) => {
454 let mut bool_query = BooleanQuery::new();
456 bool_query = bool_query.must_not(self.build_query(inner)?);
457 Ok(Box::new(bool_query))
458 }
459 }
460 }
461
462 fn build_term_query(&self, field: Option<&str>, term: &str) -> Result<Box<dyn Query>, String> {
463 if let Some(field_name) = field {
464 let field_id = self
466 .schema
467 .get_field(field_name)
468 .ok_or_else(|| format!("Unknown field: {}", field_name))?;
469 let tokenizer = self.get_tokenizer(field_id);
470 let tokens: Vec<String> = tokenizer
471 .tokenize(term)
472 .into_iter()
473 .map(|t| t.text.to_lowercase())
474 .collect();
475
476 if tokens.is_empty() {
477 return Err("No tokens in term".to_string());
478 }
479
480 if tokens.len() == 1 {
481 Ok(Box::new(TermQuery::text(field_id, &tokens[0])))
482 } else {
483 let mut bool_query = BooleanQuery::new();
485 for token in &tokens {
486 bool_query = bool_query.must(TermQuery::text(field_id, token));
487 }
488 Ok(Box::new(bool_query))
489 }
490 } else if !self.default_fields.is_empty() {
491 let tokenizer = self.get_tokenizer(self.default_fields[0]);
493 let tokens: Vec<String> = tokenizer
494 .tokenize(term)
495 .into_iter()
496 .map(|t| t.text.to_lowercase())
497 .collect();
498
499 if tokens.is_empty() {
500 return Err("No tokens in term".to_string());
501 }
502
503 let mut bool_query = BooleanQuery::new();
505 for token in &tokens {
506 for &field_id in &self.default_fields {
507 bool_query = bool_query.should(TermQuery::text(field_id, token));
508 }
509 }
510 Ok(Box::new(bool_query))
511 } else {
512 Err("No field specified and no default fields configured".to_string())
513 }
514 }
515
516 fn build_phrase_query(
517 &self,
518 field: Option<&str>,
519 phrase: &str,
520 ) -> Result<Box<dyn Query>, String> {
521 let field_id = if let Some(field_name) = field {
523 self.schema
524 .get_field(field_name)
525 .ok_or_else(|| format!("Unknown field: {}", field_name))?
526 } else if !self.default_fields.is_empty() {
527 self.default_fields[0]
528 } else {
529 return Err("No field specified and no default fields configured".to_string());
530 };
531
532 let tokenizer = self.get_tokenizer(field_id);
533 let tokens: Vec<String> = tokenizer
534 .tokenize(phrase)
535 .into_iter()
536 .map(|t| t.text.to_lowercase())
537 .collect();
538
539 if tokens.is_empty() {
540 return Err("No tokens in phrase".to_string());
541 }
542
543 if tokens.len() == 1 {
544 return Ok(Box::new(TermQuery::text(field_id, &tokens[0])));
545 }
546
547 let mut bool_query = BooleanQuery::new();
549 for token in &tokens {
550 bool_query = bool_query.must(TermQuery::text(field_id, token));
551 }
552
553 if field.is_none() && self.default_fields.len() > 1 {
555 let mut outer = BooleanQuery::new();
556 for &f in &self.default_fields {
557 let tokenizer = self.get_tokenizer(f);
558 let tokens: Vec<String> = tokenizer
559 .tokenize(phrase)
560 .into_iter()
561 .map(|t| t.text.to_lowercase())
562 .collect();
563
564 let mut field_query = BooleanQuery::new();
565 for token in &tokens {
566 field_query = field_query.must(TermQuery::text(f, token));
567 }
568 outer = outer.should(field_query);
569 }
570 return Ok(Box::new(outer));
571 }
572
573 Ok(Box::new(bool_query))
574 }
575
576 fn get_tokenizer(&self, field: Field) -> BoxedTokenizer {
577 let tokenizer_name = self
579 .schema
580 .get_field_entry(field)
581 .and_then(|entry| entry.tokenizer.as_deref())
582 .unwrap_or("default");
583
584 self.tokenizers
585 .get(tokenizer_name)
586 .unwrap_or_else(|| Box::new(crate::tokenizer::LowercaseTokenizer))
587 }
588}
589
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::SchemaBuilder;
    use crate::tokenizer::TokenizerRegistry;

    // Builds a two-text-field schema ("title", "body"), with both fields as
    // defaults, and a default tokenizer registry.
    fn setup() -> (Arc<Schema>, Vec<Field>, Arc<TokenizerRegistry>) {
        let mut builder = SchemaBuilder::default();
        let title = builder.add_text_field("title", true, true);
        let body = builder.add_text_field("body", true, true);
        let schema = Arc::new(builder.build());
        let tokenizers = Arc::new(TokenizerRegistry::default());
        (schema, vec![title, body], tokenizers)
    }

    // A bare term with default fields should parse without error.
    #[test]
    fn test_simple_term() {
        let (schema, default_fields, tokenizers) = setup();
        let parser = QueryLanguageParser::new(schema, default_fields, tokenizers);

        let _query = parser.parse("rust").unwrap();
    }

    // A field-scoped term (`field:term`) should parse without error.
    #[test]
    fn test_field_term() {
        let (schema, default_fields, tokenizers) = setup();
        let parser = QueryLanguageParser::new(schema, default_fields, tokenizers);

        let _query = parser.parse("title:rust").unwrap();
    }

    // Explicit AND between two terms should parse without error.
    #[test]
    fn test_boolean_and() {
        let (schema, default_fields, tokenizers) = setup();
        let parser = QueryLanguageParser::new(schema, default_fields, tokenizers);

        let _query = parser.parse("rust AND programming").unwrap();
    }

    // Two unquoted terms (implicit match over default fields) should parse.
    #[test]
    fn test_match_query() {
        let (schema, default_fields, tokenizers) = setup();
        let parser = QueryLanguageParser::new(schema, default_fields, tokenizers);

        let _query = parser.parse("hello world").unwrap();
    }

    // A quoted phrase should parse without error.
    #[test]
    fn test_phrase_query() {
        let (schema, default_fields, tokenizers) = setup();
        let parser = QueryLanguageParser::new(schema, default_fields, tokenizers);

        let _query = parser.parse("\"hello world\"").unwrap();
    }

    // Explicit OR between two terms should parse without error.
    #[test]
    fn test_boolean_or() {
        let (schema, default_fields, tokenizers) = setup();
        let parser = QueryLanguageParser::new(schema, default_fields, tokenizers);

        let _query = parser.parse("rust OR python").unwrap();
    }

    // Parenthesized groups combined with AND should parse without error.
    #[test]
    fn test_complex_query() {
        let (schema, default_fields, tokenizers) = setup();
        let parser = QueryLanguageParser::new(schema, default_fields, tokenizers);

        let _query = parser.parse("(rust OR python) AND programming").unwrap();
    }

    // A DOI-shaped query matched by an Exclusive rule is routed to the target
    // field even though no default fields are configured.
    #[test]
    fn test_router_exclusive_mode() {
        use crate::dsl::query_field_router::{QueryFieldRouter, QueryRouterRule, RoutingMode};

        let mut builder = SchemaBuilder::default();
        let _title = builder.add_text_field("title", true, true);
        let _uri = builder.add_text_field("uri", true, true);
        let schema = Arc::new(builder.build());
        let tokenizers = Arc::new(TokenizerRegistry::default());

        let router = QueryFieldRouter::from_rules(&[QueryRouterRule {
            pattern: r"^doi:(10\.\d{4,}/[^\s]+)$".to_string(),
            substitution: "doi://{1}".to_string(),
            target_field: "uri".to_string(),
            mode: RoutingMode::Exclusive,
        }])
        .unwrap();

        let parser = QueryLanguageParser::with_router(schema, vec![], tokenizers, router);

        let _query = parser.parse("doi:10.1234/test.123").unwrap();
    }

    // An Additional rule routes to the target field AND still searches the
    // default fields with the original query string.
    #[test]
    fn test_router_additional_mode() {
        use crate::dsl::query_field_router::{QueryFieldRouter, QueryRouterRule, RoutingMode};

        let mut builder = SchemaBuilder::default();
        let title = builder.add_text_field("title", true, true);
        let _uri = builder.add_text_field("uri", true, true);
        let schema = Arc::new(builder.build());
        let tokenizers = Arc::new(TokenizerRegistry::default());

        let router = QueryFieldRouter::from_rules(&[QueryRouterRule {
            pattern: r"#(\d+)".to_string(),
            substitution: "{1}".to_string(),
            target_field: "uri".to_string(),
            mode: RoutingMode::Additional,
        }])
        .unwrap();

        let parser = QueryLanguageParser::with_router(schema, vec![title], tokenizers, router);

        let _query = parser.parse("#42").unwrap();
    }

    // A query that does not match any router rule goes through normal parsing.
    #[test]
    fn test_router_no_match_falls_through() {
        use crate::dsl::query_field_router::{QueryFieldRouter, QueryRouterRule, RoutingMode};

        let mut builder = SchemaBuilder::default();
        let title = builder.add_text_field("title", true, true);
        let _uri = builder.add_text_field("uri", true, true);
        let schema = Arc::new(builder.build());
        let tokenizers = Arc::new(TokenizerRegistry::default());

        let router = QueryFieldRouter::from_rules(&[QueryRouterRule {
            pattern: r"^doi:".to_string(),
            substitution: "{0}".to_string(),
            target_field: "uri".to_string(),
            mode: RoutingMode::Exclusive,
        }])
        .unwrap();

        let parser = QueryLanguageParser::with_router(schema, vec![title], tokenizers, router);

        let _query = parser.parse("rust programming").unwrap();
    }

    // A rule targeting a field absent from the schema must surface an error.
    #[test]
    fn test_router_invalid_target_field() {
        use crate::dsl::query_field_router::{QueryFieldRouter, QueryRouterRule, RoutingMode};

        let mut builder = SchemaBuilder::default();
        let _title = builder.add_text_field("title", true, true);
        let schema = Arc::new(builder.build());
        let tokenizers = Arc::new(TokenizerRegistry::default());

        let router = QueryFieldRouter::from_rules(&[QueryRouterRule {
            pattern: r"test".to_string(),
            substitution: "{0}".to_string(),
            target_field: "nonexistent".to_string(),
            mode: RoutingMode::Exclusive,
        }])
        .unwrap();

        let parser = QueryLanguageParser::with_router(schema, vec![], tokenizers, router);

        let result = parser.parse("test");
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(err.contains("Unknown target field"));
    }

    // The ann(...) syntax should produce a ParsedQuery::Ann with the given
    // vector, the given nprobe, and the default rerank factor.
    #[test]
    fn test_parse_ann_query() {
        let mut builder = SchemaBuilder::default();
        let embedding = builder.add_dense_vector_field("embedding", 128, true, true);
        let schema = Arc::new(builder.build());
        let tokenizers = Arc::new(TokenizerRegistry::default());

        let parser = QueryLanguageParser::new(schema, vec![embedding], tokenizers);

        let result = parser.parse_query_string("embedding:ann([1.0, 2.0, 3.0], nprobe=32)");
        assert!(result.is_ok(), "Failed to parse ANN query: {:?}", result);

        if let Ok(ParsedQuery::Ann {
            field,
            vector,
            nprobe,
            rerank,
        }) = result
        {
            assert_eq!(field, "embedding");
            assert_eq!(vector, vec![1.0, 2.0, 3.0]);
            assert_eq!(nprobe, 32);
            // rerank was not specified, so the default (3.0) applies.
            assert_eq!(rerank, 3.0);
        } else {
            panic!("Expected Ann query, got: {:?}", result);
        }
    }

    // The sparse({...}) syntax should produce a ParsedQuery::Sparse with the
    // listed (index, weight) pairs in order.
    #[test]
    fn test_parse_sparse_query() {
        let mut builder = SchemaBuilder::default();
        let sparse = builder.add_text_field("sparse", true, true);
        let schema = Arc::new(builder.build());
        let tokenizers = Arc::new(TokenizerRegistry::default());

        let parser = QueryLanguageParser::new(schema, vec![sparse], tokenizers);

        let result = parser.parse_query_string("sparse:sparse({1: 0.5, 5: 0.3})");
        assert!(result.is_ok(), "Failed to parse sparse query: {:?}", result);

        if let Ok(ParsedQuery::Sparse { field, vector }) = result {
            assert_eq!(field, "sparse");
            assert_eq!(vector, vec![(1, 0.5), (5, 0.3)]);
        } else {
            panic!("Expected Sparse query, got: {:?}", result);
        }
    }
}