1use std::sync::OnceLock;
14
15use regex::Regex;
16
17use crate::types::SearchType;
18
19#[derive(Debug, Clone)]
21pub struct RouteResult {
22 pub search_type: SearchType,
24 pub confidence: f32,
26 pub runner_up: SearchType,
28 pub runner_up_score: f32,
30 pub all_scores: Vec<(SearchType, f32)>,
34}
35
36impl RouteResult {
37 pub fn is_confident(&self) -> bool {
41 self.confidence >= 2.0 * self.runner_up_score.max(1.0)
42 }
43}
44
45const DEFAULT_TYPE: SearchType = SearchType::GraphCompletion;
48const DEFAULT_BASE_SCORE: f32 = 2.0;
49const NEGATION_WINDOW: usize = 20;
51const NEGATION_WORDS: &[&str] = &["not", "n't", "no", "never", "without", "lack"];
53
/// Reports whether the `len`-byte span starting at byte `idx` of `text` is
/// delimited on both sides by something other than a word character.
///
/// A word character is alphanumeric or `_`. The start and the end of `text`
/// count as delimiters.
///
/// # Panics
///
/// Panics if `idx` (or an in-range `idx + len`) does not fall on a character
/// boundary of `text`, or if `idx` is past the end.
fn is_word_boundary(text: &str, idx: usize, len: usize) -> bool {
    let is_word_char = |c: char| c.is_alphanumeric() || c == '_';

    // Character immediately before the span, if any.
    let preceding = text[..idx].chars().next_back();

    // Character immediately after the span, if any; a span that reaches or
    // overshoots the end has nothing after it.
    let span_end = idx + len;
    let following = if span_end >= text.len() {
        None
    } else {
        text[span_end..].chars().next()
    };

    let clear_before = !preceding.map_or(false, is_word_char);
    let clear_after = !following.map_or(false, is_word_char);
    clear_before && clear_after
}
79
80fn contains_word(text: &str, kw: &str) -> Option<usize> {
85 if kw.is_empty() {
86 return None;
87 }
88 let mut cursor = 0usize;
89 while let Some(rel) = text[cursor..].find(kw) {
90 let pos = cursor + rel;
91 if is_word_boundary(text, pos, kw.len()) {
92 return Some(pos);
93 }
94 let mut step = pos + 1;
97 while step < text.len() && !text.is_char_boundary(step) {
98 step += 1;
99 }
100 cursor = step;
101 }
102 None
103}
104
105fn is_negated(lower: &str, match_start: usize) -> bool {
109 let mut window_start = match_start.saturating_sub(NEGATION_WINDOW);
110 while window_start > 0 && !lower.is_char_boundary(window_start) {
111 window_start -= 1;
112 }
113 let prefix = &lower[window_start..match_start];
114 for neg in NEGATION_WORDS {
115 if contains_word(prefix, neg).is_some() {
117 return true;
118 }
119 }
120 false
121}
122
123enum Matcher {
128 Keywords(&'static [&'static str]),
131 Regex {
134 cell: &'static OnceLock<Regex>,
135 pattern: &'static str,
136 case_insensitive: bool,
137 },
138}
139
140struct Rule {
141 matcher: Matcher,
142 target: SearchType,
143 weight: f32,
144 respects_negation: bool,
146}
147
148static RE_CYPHER_PREFIX: OnceLock<Regex> = OnceLock::new();
152static RE_LEXICAL_QUOTED: OnceLock<Regex> = OnceLock::new();
153static RE_CODE_SYNTAX: OnceLock<Regex> = OnceLock::new();
154static RE_RELATIONSHIP_HOW: OnceLock<Regex> = OnceLock::new();
155static RE_RELATIONSHIP_WHAT: OnceLock<Regex> = OnceLock::new();
156static RE_YEAR: OnceLock<Regex> = OnceLock::new();
157static RE_YEAR_RANGE: OnceLock<Regex> = OnceLock::new();
158
159fn rules() -> &'static [Rule] {
160 static RULES: OnceLock<Vec<Rule>> = OnceLock::new();
161 RULES.get_or_init(|| {
162 vec![
163 Rule {
165 matcher: Matcher::Regex {
166 cell: &RE_CYPHER_PREFIX,
167 pattern: r"(^MATCH\s|^RETURN\s|^CREATE\s|^MERGE\s|--\(|\)--)",
169 case_insensitive: false,
170 },
171 target: SearchType::Cypher,
172 weight: 10.0,
173 respects_negation: true,
178 },
179 Rule {
181 matcher: Matcher::Keywords(&[
183 "coding rule",
184 "coding rules",
185 "code review",
186 "best practice",
187 "lint",
188 "linting",
189 "linter",
190 "refactor",
191 "refactoring",
192 ]),
193 target: SearchType::CodingRules,
194 weight: 5.0,
195 respects_negation: true,
196 },
197 Rule {
198 matcher: Matcher::Regex {
200 cell: &RE_CODE_SYNTAX,
201 pattern: r"\b(def |return |async |await |import |class \w+\(|\.py\b|function\s+\w+\()",
202 case_insensitive: true,
203 },
204 target: SearchType::CodingRules,
205 weight: 3.0,
206 respects_negation: true,
207 },
208 Rule {
210 matcher: Matcher::Regex {
212 cell: &RE_LEXICAL_QUOTED,
213 pattern: r#"^"[^"]+"$"#,
214 case_insensitive: false,
215 },
216 target: SearchType::ChunksLexical,
217 weight: 8.0,
218 respects_negation: true,
219 },
220 Rule {
221 matcher: Matcher::Keywords(&[
226 "exact",
227 "verbatim",
228 "literal",
229 "word for word",
230 "word-for-word",
231 "word.for.word",
232 "word_for_word",
233 ]),
234 target: SearchType::ChunksLexical,
235 weight: 4.0,
236 respects_negation: true,
237 },
238 Rule {
240 matcher: Matcher::Keywords(&[
242 "summarize",
243 "summary",
244 "overview",
245 "outline",
246 "tldr",
247 "tl;dr",
248 "gist",
249 "main point",
250 "main points",
251 "key takeaway",
252 "key takeaways",
253 "high level",
254 "high-level",
255 "highlevel",
256 ]),
257 target: SearchType::GraphSummaryCompletion,
258 weight: 5.0,
259 respects_negation: true,
260 },
261 Rule {
263 matcher: Matcher::Keywords(&[
265 "why",
266 "explain",
267 "reasoning",
268 "step by step",
269 "step-by-step",
270 "step.by.step",
271 "chain of thought",
272 ]),
273 target: SearchType::GraphCompletionCot,
274 weight: 4.0,
275 respects_negation: true,
276 },
277 Rule {
278 matcher: Matcher::Keywords(&["because", "therefore", "consequently"]),
280 target: SearchType::GraphCompletionCot,
281 weight: 2.0,
282 respects_negation: true,
283 },
284 Rule {
286 matcher: Matcher::Regex {
288 cell: &RE_RELATIONSHIP_HOW,
289 pattern: r"\b(how (is|are|does|do)\s+\w+\s+(related|connected|linked))\b",
290 case_insensitive: true,
291 },
292 target: SearchType::GraphCompletionContextExtension,
293 weight: 5.0,
294 respects_negation: true,
295 },
296 Rule {
297 matcher: Matcher::Regex {
299 cell: &RE_RELATIONSHIP_WHAT,
300 pattern: r"\b(what (connects|links|ties)|path between|degree of separation)\b",
301 case_insensitive: true,
302 },
303 target: SearchType::GraphCompletionContextExtension,
304 weight: 5.0,
305 respects_negation: true,
306 },
307 Rule {
308 matcher: Matcher::Keywords(&[
310 "connection",
311 "relationship",
312 "related to",
313 "linked to",
314 ]),
315 target: SearchType::GraphCompletionContextExtension,
316 weight: 3.0,
317 respects_negation: true,
318 },
319 Rule {
321 matcher: Matcher::Keywords(&[
323 "when", "before", "after", "during", "since", "until",
324 ]),
325 target: SearchType::Temporal,
326 weight: 3.0,
327 respects_negation: true,
328 },
329 Rule {
330 matcher: Matcher::Keywords(&[
332 "timeline",
333 "chronolog",
334 "chronology",
335 "chronological",
336 "era",
337 "decade",
338 "century",
339 ]),
340 target: SearchType::Temporal,
341 weight: 4.0,
342 respects_negation: true,
343 },
344 Rule {
345 matcher: Matcher::Regex {
347 cell: &RE_YEAR,
348 pattern: r"\b\d{4}s?\b",
349 case_insensitive: false,
350 },
351 target: SearchType::Temporal,
352 weight: 3.0,
353 respects_negation: true,
354 },
355 Rule {
356 matcher: Matcher::Regex {
358 cell: &RE_YEAR_RANGE,
359 pattern: r"\bbetween\s+\d{4}\s+and\s+\d{4}\b",
360 case_insensitive: true,
361 },
362 target: SearchType::Temporal,
363 weight: 6.0,
364 respects_negation: true,
365 },
366 ]
367 })
368}
369
370fn compile(
371 cell: &'static OnceLock<Regex>,
372 pattern: &str,
373 case_insensitive: bool,
374) -> &'static Regex {
375 cell.get_or_init(|| {
376 let mut builder = regex::RegexBuilder::new(pattern);
377 builder.case_insensitive(case_insensitive);
378 builder
379 .build()
380 .unwrap_or_else(|e| panic!("query_router: failed to compile regex {pattern:?}: {e}"))
381 })
382}
383
384fn rule_match(rule: &Rule, trimmed: &str, lower: &str) -> Option<usize> {
387 match &rule.matcher {
388 Matcher::Keywords(kws) => {
389 let mut earliest: Option<usize> = None;
394 for kw in *kws {
395 if let Some(pos) = contains_word(lower, kw) {
396 earliest = Some(earliest.map_or(pos, |e| e.min(pos)));
397 }
398 }
399 earliest
400 }
401 Matcher::Regex {
402 cell,
403 pattern,
404 case_insensitive,
405 } => {
406 let re = compile(cell, pattern, *case_insensitive);
407 re.find(trimmed).map(|m| m.start())
408 }
409 }
410}
411
412pub fn route_query(query: &str) -> RouteResult {
424 let trimmed = query.trim();
425 let lower = trimmed.to_lowercase();
426
427 let mut scores: Vec<(SearchType, f32)> = Vec::new();
430
431 for rule in rules() {
432 let Some(m_start) = rule_match(rule, trimmed, &lower) else {
433 continue;
434 };
435 if rule.respects_negation && is_negated(&lower, m_start) {
441 continue;
442 }
443 if let Some(entry) = scores.iter_mut().find(|(s, _)| *s == rule.target) {
444 entry.1 += rule.weight;
445 } else {
446 scores.push((rule.target, rule.weight));
447 }
448 }
449
450 if scores.is_empty() {
451 return RouteResult {
452 search_type: DEFAULT_TYPE,
453 confidence: DEFAULT_BASE_SCORE,
454 runner_up: DEFAULT_TYPE,
455 runner_up_score: 0.0,
456 all_scores: Vec::new(),
457 };
458 }
459
460 scores.sort_by(|a, b| b.1.total_cmp(&a.1));
462
463 let (best_type, best_score) = scores[0];
464 let (ru_type, ru_score) = scores.get(1).copied().unwrap_or((DEFAULT_TYPE, 0.0));
465
466 if best_score < DEFAULT_BASE_SCORE {
467 return RouteResult {
470 search_type: DEFAULT_TYPE,
471 confidence: best_score,
472 runner_up: best_type,
473 runner_up_score: best_score,
474 all_scores: scores,
475 };
476 }
477
478 RouteResult {
479 search_type: best_type,
480 confidence: best_score,
481 runner_up: ru_type,
482 runner_up_score: ru_score,
483 all_scores: scores,
484 }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the search type `route_query` selects.
    fn routed(query: &str) -> SearchType {
        route_query(query).search_type
    }

    /// Plain factual questions stay on the default route.
    mod factual_queries {
        use super::*;

        #[test]
        fn simple_who() {
            assert_eq!(routed("Who won Nobel Prizes?"), SearchType::GraphCompletion);
        }

        #[test]
        fn simple_what() {
            assert_eq!(
                routed("What did Einstein discover?"),
                SearchType::GraphCompletion
            );
        }

        #[test]
        fn short_list() {
            assert_eq!(routed("List all scientists"), SearchType::GraphCompletion);
        }
    }

    /// Raw Cypher statements.
    mod cypher {
        use super::*;

        #[test]
        fn match_statement() {
            assert_eq!(routed("MATCH (n:Person) RETURN n.name"), SearchType::Cypher);
        }

        #[test]
        fn return_statement() {
            assert_eq!(routed("RETURN 1"), SearchType::Cypher);
        }
    }

    /// Coding-rule phrasing, and everyday words that must not trigger it.
    mod coding_rules {
        use super::*;

        #[test]
        fn coding_rules_phrase() {
            assert_eq!(
                routed("What coding rules apply to error handling?"),
                SearchType::CodingRules
            );
        }

        #[test]
        fn code_review() {
            assert_eq!(
                routed("Show me the code review guidelines"),
                SearchType::CodingRules
            );
        }

        #[test]
        fn bare_class_is_not_code() {
            assert_ne!(
                routed("What class of animal is a dolphin?"),
                SearchType::CodingRules
            );
        }

        #[test]
        fn bare_function_is_not_code() {
            assert_ne!(
                routed("What is the function of the liver?"),
                SearchType::CodingRules
            );
        }
    }

    /// Exact-text lookups.
    mod lexical {
        use super::*;

        #[test]
        fn quoted_phrase() {
            assert_eq!(routed("\"polonium and radium\""), SearchType::ChunksLexical);
        }

        #[test]
        fn exact_keyword() {
            assert_eq!(
                routed("Find the exact phrase in the documents"),
                SearchType::ChunksLexical
            );
        }
    }

    /// Summary requests.
    mod summary {
        use super::*;

        #[test]
        fn summarize() {
            assert_eq!(
                routed("Summarize everything about Marie Curie"),
                SearchType::GraphSummaryCompletion
            );
        }

        #[test]
        fn overview() {
            assert_eq!(
                routed("Give me an overview of the project"),
                SearchType::GraphSummaryCompletion
            );
        }

        #[test]
        fn tldr() {
            assert_eq!(
                routed("tldr of the report"),
                SearchType::GraphSummaryCompletion
            );
        }
    }

    /// Reasoning-style questions.
    mod reasoning {
        use super::*;

        #[test]
        fn why_question() {
            assert_eq!(
                routed("Why did Curie win two Nobel Prizes?"),
                SearchType::GraphCompletionCot
            );
        }

        #[test]
        fn explain() {
            assert_eq!(
                routed("Explain the theory of relativity"),
                SearchType::GraphCompletionCot
            );
        }
    }

    /// Relationship questions.
    mod relationship {
        use super::*;

        #[test]
        fn connection_between() {
            assert_eq!(
                routed("How is Einstein connected to the Sorbonne?"),
                SearchType::GraphCompletionContextExtension
            );
        }

        #[test]
        fn related_to() {
            assert_eq!(
                routed("What entities are related to physics?"),
                SearchType::GraphCompletionContextExtension
            );
        }

        #[test]
        fn between_not_temporal() {
            assert_eq!(
                routed("What is the relationship between supply and demand?"),
                SearchType::GraphCompletionContextExtension
            );
        }
    }

    /// Time-oriented questions.
    mod temporal {
        use super::*;

        #[test]
        fn when_question() {
            assert_eq!(routed("When did Einstein publish?"), SearchType::Temporal);
        }

        #[test]
        fn year_range() {
            assert_eq!(
                routed("What happened between 1910 and 1920?"),
                SearchType::Temporal
            );
        }

        #[test]
        fn timeline() {
            assert_eq!(
                routed("Show the timeline of discoveries"),
                SearchType::Temporal
            );
        }

        #[test]
        fn specific_year() {
            assert_eq!(
                routed("What was discovered in 1915?"),
                SearchType::Temporal
            );
        }
    }

    /// A negation just before a match suppresses the rule.
    mod negation {
        use super::*;

        #[test]
        fn not_related_suppresses_graph() {
            assert_ne!(
                routed("What is not related to physics?"),
                SearchType::GraphCompletionContextExtension
            );
        }

        #[test]
        fn no_connection_suppresses_graph() {
            assert_ne!(
                routed("There is no connection between these topics"),
                SearchType::GraphCompletionContextExtension
            );
        }

        #[test]
        fn negation_does_not_affect_distant_match() {
            let r = route_query(
                "This is not about food at all, however I want to know how is X connected to Y?",
            );
            assert_eq!(r.search_type, SearchType::GraphCompletionContextExtension);
        }
    }

    /// Scores, runner-up bookkeeping and the default confidence.
    mod confidence {
        use super::*;

        #[test]
        fn high_confidence_for_cypher() {
            let r = route_query("MATCH (n) RETURN n");
            assert!(r.confidence >= 10.0);
            assert!(r.is_confident());
        }

        #[test]
        fn runner_up_populated() {
            let r = route_query("Summarize the timeline of discoveries");
            assert_eq!(r.search_type, SearchType::GraphSummaryCompletion);
            // "timeline" also scores, so Temporal must be the runner-up.
            assert_eq!(r.runner_up, SearchType::Temporal);
            assert!(r.runner_up_score > 0.0);
            assert!(r.all_scores.len() >= 2);
        }

        #[test]
        fn default_has_base_confidence() {
            let r = route_query("Tell me something interesting");
            assert_eq!(r.search_type, SearchType::GraphCompletion);
            // No rule fired, so the confidence is exactly the base score.
            assert!((r.confidence - DEFAULT_BASE_SCORE).abs() < f32::EPSILON);
            assert!(r.all_scores.is_empty());
        }
    }

    /// Queries that hit several categories, or none.
    mod ambiguous {
        use super::*;

        #[test]
        fn temporal_beats_graph_for_years() {
            assert_eq!(
                routed("What happened between 1910 and 1920?"),
                SearchType::Temporal
            );
        }

        #[test]
        fn summary_with_temporal_word() {
            assert_eq!(
                routed("Summarize the timeline of Einstein's work"),
                SearchType::GraphSummaryCompletion
            );
        }

        #[test]
        fn default_for_vague_query() {
            assert_eq!(routed("Tell me something"), SearchType::GraphCompletion);
        }

        #[test]
        fn blank_query_defaults() {
            assert_eq!(routed("   "), SearchType::GraphCompletion);
        }
    }
}