use std::sync::OnceLock;
use regex::Regex;
use crate::types::SearchType;
/// Outcome of [`route_query`]: the search type a query was routed to, plus
/// the scores that led to that decision.
#[derive(Debug, Clone)]
pub struct RouteResult {
/// Search type the query should be dispatched to.
pub search_type: SearchType,
/// Score backing `search_type`. For a rule-matched type this is the sum of
/// the weights of every matching rule for that type; for the default
/// fallback see [`route_query`].
pub confidence: f32,
/// Second-best candidate. When there is no other scored candidate this is
/// the default search type with a `runner_up_score` of 0.0.
pub runner_up: SearchType,
/// Score of `runner_up`.
pub runner_up_score: f32,
/// Every search type that received a score, sorted by descending score;
/// empty when no rule matched.
pub all_scores: Vec<(SearchType, f32)>,
}
impl RouteResult {
    /// Reports whether the winning score clearly dominates the runner-up.
    ///
    /// "Clearly" means at least twice the runner-up score. The runner-up is
    /// floored at 1.0, so a result with no real competitor still needs a
    /// confidence of at least 2.0 to count as confident.
    pub fn is_confident(&self) -> bool {
        let rival = f32::max(self.runner_up_score, 1.0);
        self.confidence >= rival * 2.0
    }
}
/// Search type chosen when no rule fires, or when the best rule score is
/// below `DEFAULT_BASE_SCORE`.
const DEFAULT_TYPE: SearchType = SearchType::GraphCompletion;
/// Minimum accumulated score a rule target needs in order to beat
/// `DEFAULT_TYPE`; also reported as the confidence when no rule matches.
const DEFAULT_BASE_SCORE: f32 = 2.0;
/// Number of bytes (not characters) before a match that are scanned for a
/// negation word.
const NEGATION_WINDOW: usize = 20;
/// Negation cues searched for, in the lowercased query, within
/// `NEGATION_WINDOW` bytes before a match. A hit suppresses rules that have
/// `respects_negation` set.
const NEGATION_WORDS: &[&str] = &["not", "n't", "no", "never", "without", "lack"];
/// Reports whether the byte span `idx..idx + len` of `text` is delimited by
/// word boundaries on both sides.
///
/// A side counts as a boundary when it is the start/end of `text`, or when
/// the neighbouring character is neither alphanumeric nor `_`. `idx` (and
/// `idx + len` when it lies inside `text`) must fall on char boundaries.
fn is_word_boundary(text: &str, idx: usize, len: usize) -> bool {
    let is_word_char = |c: char| c.is_alphanumeric() || c == '_';
    let end = idx + len;

    // An empty prefix yields `None`, which correctly counts as a boundary.
    let left_clear = !text[..idx].chars().next_back().is_some_and(is_word_char);
    let right_clear =
        end >= text.len() || !text[end..].chars().next().is_some_and(is_word_char);

    left_clear && right_clear
}
/// Returns the byte offset of the first occurrence of `kw` in `text` that
/// stands as a whole word (see [`is_word_boundary`]), or `None` if there is
/// no such occurrence. An empty `kw` never matches.
///
/// Occurrences embedded in a longer word are skipped. The search then
/// resumes one character later, so overlapping candidates are still
/// considered.
fn contains_word(text: &str, kw: &str) -> Option<usize> {
    if kw.is_empty() {
        return None;
    }
    let mut search_from = 0;
    loop {
        let hit = search_from + text[search_from..].find(kw)?;
        if is_word_boundary(text, hit, kw.len()) {
            return Some(hit);
        }
        // Step over exactly one (possibly multi-byte) character so the next
        // slice starts on a char boundary.
        let first_char_len = text[hit..].chars().next().map_or(1, char::len_utf8);
        search_from = hit + first_char_len;
    }
}
/// Returns `true` when a negation cue occurs in the short stretch of `lower`
/// (the lowercased query) just before byte offset `match_start`.
///
/// The stretch covers the `NEGATION_WINDOW` bytes preceding `match_start`,
/// widened backwards to the nearest char boundary so slicing never splits a
/// multi-byte character. `match_start` is also clamped to `lower` and
/// snapped down to a char boundary, so this function never panics.
///
/// Most negation words must appear as whole words (via [`contains_word`]).
/// Contraction suffixes starting with `n'` (i.e. `n't`) are different: they
/// are glued onto the preceding word ("isn't", "don't", "can't"), so only
/// the *end* of the occurrence has to sit on a word boundary. The
/// typographic apostrophe U+2019 (’) is treated like the ASCII one.
fn is_negated(lower: &str, match_start: usize) -> bool {
    // Defensive: keep the window end inside `lower` and on a char boundary
    // even if the caller's offset came from a differently-sized string.
    let mut window_end = match_start.min(lower.len());
    while !lower.is_char_boundary(window_end) {
        window_end -= 1;
    }
    let mut window_start = window_end.saturating_sub(NEGATION_WINDOW);
    while window_start > 0 && !lower.is_char_boundary(window_start) {
        window_start -= 1;
    }

    // Both apostrophes are non-word characters, so folding one into the
    // other cannot move any word boundary.
    let prefix = lower[window_start..window_end].replace('\u{2019}', "'");

    NEGATION_WORDS.iter().any(|&neg| {
        if neg.starts_with("n'") {
            contains_word_suffix(&prefix, neg)
        } else {
            contains_word(&prefix, neg).is_some()
        }
    })
}

/// Returns `true` if `suffix` occurs in `text` and is followed by a word
/// boundary (end of text, or a character that is neither alphanumeric nor
/// `_`). Unlike [`contains_word`], the character *before* the occurrence is
/// not checked, so suffixes attached to a word (`n't` in "isn't") match.
fn contains_word_suffix(text: &str, suffix: &str) -> bool {
    !suffix.is_empty()
        && text.match_indices(suffix).any(|(pos, hit)| {
            !matches!(
                text[pos + hit.len()..].chars().next(),
                Some(c) if c.is_alphanumeric() || c == '_'
            )
        })
}
/// How a [`Rule`] recognises a query.
enum Matcher {
/// Keywords/phrases searched for as whole words in the lowercased query
/// (see `rule_match`); entries must therefore be lowercase.
Keywords(&'static [&'static str]),
/// A regular expression run against the trimmed, original-case query.
Regex {
/// Cache slot the compiled regex is stored in on first use.
cell: &'static OnceLock<Regex>,
/// Pattern source, compiled lazily by `compile`.
pattern: &'static str,
/// Whether the pattern is compiled case-insensitively.
case_insensitive: bool,
},
}
/// One routing heuristic: when `matcher` fires, `weight` is added to the
/// score of `target` in `route_query`.
struct Rule {
/// What to look for in the query.
matcher: Matcher,
/// Search type that receives `weight` when the rule fires.
target: SearchType,
/// Score contributed when the rule fires; contributions from rules sharing
/// a target are summed.
weight: f32,
/// If `true`, the rule is skipped when a negation cue appears shortly
/// before the match position (see `is_negated`).
respects_negation: bool,
}
// Compiled-regex caches, one per regex rule in `rules()`. Each cell is
// filled on first use by `compile` and reused afterwards, so a pattern is
// not recompiled on every query.
static RE_CYPHER_PREFIX: OnceLock<Regex> = OnceLock::new();
static RE_LEXICAL_QUOTED: OnceLock<Regex> = OnceLock::new();
static RE_CODE_SYNTAX: OnceLock<Regex> = OnceLock::new();
static RE_RELATIONSHIP_HOW: OnceLock<Regex> = OnceLock::new();
static RE_RELATIONSHIP_WHAT: OnceLock<Regex> = OnceLock::new();
static RE_YEAR: OnceLock<Regex> = OnceLock::new();
static RE_YEAR_RANGE: OnceLock<Regex> = OnceLock::new();
/// Returns the routing rule table, built once on first use.
///
/// Each rule that matches a query (and is not suppressed by negation when it
/// respects negation) adds its `weight` to its `target`'s score in
/// [`route_query`]; weights of rules sharing a target accumulate.
///
/// Keyword lists are matched as *whole words* against the lowercased query
/// (see `contains_word`). Every keyword must therefore be lowercase and a
/// complete word or phrase. A stem such as "chronolog" would never match
/// "chronology", so inflected forms are listed explicitly.
fn rules() -> &'static [Rule] {
    static RULES: OnceLock<Vec<Rule>> = OnceLock::new();
    RULES.get_or_init(|| {
        vec![
            // Raw Cypher: a leading clause keyword (case-sensitive) or a
            // relationship-arrow fragment anywhere in the query.
            Rule {
                matcher: Matcher::Regex {
                    cell: &RE_CYPHER_PREFIX,
                    pattern: r"(^MATCH\s|^RETURN\s|^CREATE\s|^MERGE\s|--\(|\)--)",
                    case_insensitive: false,
                },
                target: SearchType::Cypher,
                weight: 10.0,
                respects_negation: true,
            },
            // Explicit coding-guideline vocabulary.
            Rule {
                matcher: Matcher::Keywords(&[
                    "coding rule",
                    "coding rules",
                    "code review",
                    "best practice",
                    "lint",
                    "linting",
                    "linter",
                    "refactor",
                    "refactoring",
                ]),
                target: SearchType::CodingRules,
                weight: 5.0,
                respects_negation: true,
            },
            // Code-like syntax; "class"/"function" only count when followed
            // by a name and an opening parenthesis.
            Rule {
                matcher: Matcher::Regex {
                    cell: &RE_CODE_SYNTAX,
                    pattern: r"\b(def |return |async |await |import |class \w+\(|\.py\b|function\s+\w+\()",
                    case_insensitive: true,
                },
                target: SearchType::CodingRules,
                weight: 3.0,
                respects_negation: true,
            },
            // The whole query is a single double-quoted string.
            Rule {
                matcher: Matcher::Regex {
                    cell: &RE_LEXICAL_QUOTED,
                    pattern: r#"^"[^"]+"$"#,
                    case_insensitive: false,
                },
                target: SearchType::ChunksLexical,
                weight: 8.0,
                respects_negation: true,
            },
            // Requests for literal / exact wording.
            Rule {
                matcher: Matcher::Keywords(&[
                    "exact",
                    "verbatim",
                    "literal",
                    "word for word",
                    "word-for-word",
                    "word.for.word",
                    "word_for_word",
                ]),
                target: SearchType::ChunksLexical,
                weight: 4.0,
                respects_negation: true,
            },
            // Summary / overview requests.
            Rule {
                matcher: Matcher::Keywords(&[
                    "summarize",
                    "summary",
                    "overview",
                    "outline",
                    "tldr",
                    "tl;dr",
                    "gist",
                    "main point",
                    "main points",
                    "key takeaway",
                    "key takeaways",
                    "high level",
                    "high-level",
                    "highlevel",
                ]),
                target: SearchType::GraphSummaryCompletion,
                weight: 5.0,
                respects_negation: true,
            },
            // Explicit requests for reasoning or explanation.
            Rule {
                matcher: Matcher::Keywords(&[
                    "why",
                    "explain",
                    "reasoning",
                    "step by step",
                    "step-by-step",
                    "step.by.step",
                    "chain of thought",
                ]),
                target: SearchType::GraphCompletionCot,
                weight: 4.0,
                respects_negation: true,
            },
            // Weaker causal connectives.
            Rule {
                matcher: Matcher::Keywords(&["because", "therefore", "consequently"]),
                target: SearchType::GraphCompletionCot,
                weight: 2.0,
                respects_negation: true,
            },
            // "How is X related/connected/linked ..." questions.
            Rule {
                matcher: Matcher::Regex {
                    cell: &RE_RELATIONSHIP_HOW,
                    pattern: r"\b(how (is|are|does|do)\s+\w+\s+(related|connected|linked))\b",
                    case_insensitive: true,
                },
                target: SearchType::GraphCompletionContextExtension,
                weight: 5.0,
                respects_negation: true,
            },
            // "What connects ...", "path between ...", "degree of separation".
            Rule {
                matcher: Matcher::Regex {
                    cell: &RE_RELATIONSHIP_WHAT,
                    pattern: r"\b(what (connects|links|ties)|path between|degree of separation)\b",
                    case_insensitive: true,
                },
                target: SearchType::GraphCompletionContextExtension,
                weight: 5.0,
                respects_negation: true,
            },
            // Relationship vocabulary.
            Rule {
                matcher: Matcher::Keywords(&[
                    "connection",
                    "relationship",
                    "related to",
                    "linked to",
                ]),
                target: SearchType::GraphCompletionContextExtension,
                weight: 3.0,
                respects_negation: true,
            },
            // Temporal connectives.
            Rule {
                matcher: Matcher::Keywords(&[
                    "when", "before", "after", "during", "since", "until",
                ]),
                target: SearchType::Temporal,
                weight: 3.0,
                respects_negation: true,
            },
            // Temporal nouns/adjectives. Whole-word matching means every
            // inflection must be listed; "chronologically" replaces the old
            // stem "chronolog", which could never match a real word.
            Rule {
                matcher: Matcher::Keywords(&[
                    "timeline",
                    "chronology",
                    "chronological",
                    "chronologically",
                    "era",
                    "decade",
                    "century",
                ]),
                target: SearchType::Temporal,
                weight: 4.0,
                respects_negation: true,
            },
            // A standalone four-digit number, optionally a decade ("1920s").
            Rule {
                matcher: Matcher::Regex {
                    cell: &RE_YEAR,
                    pattern: r"\b\d{4}s?\b",
                    case_insensitive: false,
                },
                target: SearchType::Temporal,
                weight: 3.0,
                respects_negation: true,
            },
            // An explicit "between YYYY and YYYY" range.
            Rule {
                matcher: Matcher::Regex {
                    cell: &RE_YEAR_RANGE,
                    pattern: r"\bbetween\s+\d{4}\s+and\s+\d{4}\b",
                    case_insensitive: true,
                },
                target: SearchType::Temporal,
                weight: 6.0,
                respects_negation: true,
            },
        ]
    })
}
/// Returns the regex cached in `cell`, compiling `pattern` into it on the
/// first call.
///
/// The cell is filled exactly once; later calls with the same cell return
/// the cached regex and ignore `pattern` and `case_insensitive`.
///
/// # Panics
///
/// Panics if `pattern` is not a valid regex. Patterns come from the static
/// rule table, so a failure here is a programming error rather than bad
/// input.
fn compile(
    cell: &'static OnceLock<Regex>,
    pattern: &str,
    case_insensitive: bool,
) -> &'static Regex {
    cell.get_or_init(|| {
        regex::RegexBuilder::new(pattern)
            .case_insensitive(case_insensitive)
            .build()
            .unwrap_or_else(|e| panic!("query_router: failed to compile regex {pattern:?}: {e}"))
    })
}
/// Finds where `rule` first matches the query and returns that byte offset
/// **in `lower`**, or `None` if the rule does not fire.
///
/// `trimmed` is the trimmed query in its original case, and `lower` must be
/// `trimmed.to_lowercase()`.
///
/// * Keyword matchers search `lower` for whole-word hits and return the
///   earliest one.
/// * Regex matchers run against `trimmed`, so case-sensitive patterns such
///   as the Cypher prefix see the original text. The hit is then translated
///   into `lower`'s coordinates.
///
/// The translation matters because lowercasing can change byte length. The
/// Kelvin sign U+212A (3 bytes) lowercases to ASCII `k` (1 byte), and `İ`
/// (2 bytes) lowercases to `i̇` (3 bytes). Callers slice `lower` with the
/// returned offset, so an offset taken straight from `trimmed` could land
/// past the end or mid-character and panic. Lowercasing is per character
/// except for Greek final sigma, whose two forms have the same UTF-8
/// length. So the lowercased prefix length is exactly the matching offset
/// in `lower`.
fn rule_match(rule: &Rule, trimmed: &str, lower: &str) -> Option<usize> {
    match &rule.matcher {
        Matcher::Keywords(kws) => kws
            .iter()
            .filter_map(|kw| contains_word(lower, kw))
            .min(),
        Matcher::Regex {
            cell,
            pattern,
            case_insensitive,
        } => {
            let re = compile(cell, pattern, *case_insensitive);
            re.find(trimmed)
                .map(|m| trimmed[..m.start()].to_lowercase().len())
        }
    }
}
/// Routes `query` to the search type whose rules score highest.
///
/// The query is trimmed and lowercased, then every rule in [`rules`] is
/// tried. A rule that matches adds its weight to its target's score, unless
/// it respects negation and a negation cue precedes the match.
///
/// * No rule matched: the default type wins with confidence
///   `DEFAULT_BASE_SCORE`, runner-up is the default type with score 0.0, and
///   `all_scores` is empty.
/// * Best score below `DEFAULT_BASE_SCORE`: the default type still wins and
///   is credited with `DEFAULT_BASE_SCORE`. The strongest contender is
///   reported as runner-up, matching how the no-match case reports the
///   default's own score.
/// * Otherwise: the best-scoring type wins with its summed score, and the
///   second-best (or the default type with 0.0) is the runner-up.
///
/// `all_scores` is sorted by descending score. The sort is stable, so ties
/// keep the order in which targets first matched.
pub fn route_query(query: &str) -> RouteResult {
    let trimmed = query.trim();
    let lower = trimmed.to_lowercase();

    // A Vec is enough for the handful of search types and preserves
    // first-seen order for the stable sort below.
    let mut scores: Vec<(SearchType, f32)> = Vec::new();
    for rule in rules() {
        let Some(m_start) = rule_match(rule, trimmed, &lower) else {
            continue;
        };
        if rule.respects_negation && is_negated(&lower, m_start) {
            continue;
        }
        if let Some(entry) = scores.iter_mut().find(|(s, _)| *s == rule.target) {
            entry.1 += rule.weight;
        } else {
            scores.push((rule.target, rule.weight));
        }
    }

    if scores.is_empty() {
        return RouteResult {
            search_type: DEFAULT_TYPE,
            confidence: DEFAULT_BASE_SCORE,
            runner_up: DEFAULT_TYPE,
            runner_up_score: 0.0,
            all_scores: Vec::new(),
        };
    }

    scores.sort_by(|a, b| b.1.total_cmp(&a.1));
    let (best_type, best_score) = scores[0];

    if best_score < DEFAULT_BASE_SCORE {
        // Only weak evidence. The default wins with its base score, exactly
        // as in the no-match case above; reporting `best_score` here would
        // credit the default with its competitor's score.
        return RouteResult {
            search_type: DEFAULT_TYPE,
            confidence: DEFAULT_BASE_SCORE,
            runner_up: best_type,
            runner_up_score: best_score,
            all_scores: scores,
        };
    }

    let (ru_type, ru_score) = scores.get(1).copied().unwrap_or((DEFAULT_TYPE, 0.0));
    RouteResult {
        search_type: best_type,
        confidence: best_score,
        runner_up: ru_type,
        runner_up_score: ru_score,
        all_scores: scores,
    }
}
#[cfg(test)]
mod tests {
    //! Behavioural tests for `route_query`, grouped by the routing outcome
    //! each group pins down.
    use super::*;

    /// Routes `query` and returns only the chosen search type.
    fn routed(query: &str) -> SearchType {
        route_query(query).search_type
    }

    mod factual_queries {
        //! Plain factual questions carry no routing signal and fall back to
        //! the default search type.
        use super::*;

        #[test]
        fn simple_who() {
            assert_eq!(routed("Who won Nobel Prizes?"), SearchType::GraphCompletion);
        }

        #[test]
        fn simple_what() {
            assert_eq!(
                routed("What did Einstein discover?"),
                SearchType::GraphCompletion
            );
        }

        #[test]
        fn short_list() {
            assert_eq!(routed("List all scientists"), SearchType::GraphCompletion);
        }
    }

    mod cypher {
        //! Raw Cypher statements are routed to the Cypher backend.
        use super::*;

        #[test]
        fn match_statement() {
            assert_eq!(
                routed("MATCH (n:Person) RETURN n.name"),
                SearchType::Cypher
            );
        }

        #[test]
        fn return_statement() {
            assert_eq!(routed("RETURN 1"), SearchType::Cypher);
        }
    }

    mod coding_rules {
        //! Coding-guideline vocabulary and code-like syntax.
        use super::*;

        #[test]
        fn coding_rules_phrase() {
            assert_eq!(
                routed("What coding rules apply to error handling?"),
                SearchType::CodingRules
            );
        }

        #[test]
        fn code_review() {
            assert_eq!(
                routed("Show me the code review guidelines"),
                SearchType::CodingRules
            );
        }

        // "class" and "function" in ordinary prose must not look like code.
        #[test]
        fn bare_class_is_not_code() {
            assert_ne!(
                routed("What class of animal is a dolphin?"),
                SearchType::CodingRules
            );
        }

        #[test]
        fn bare_function_is_not_code() {
            assert_ne!(
                routed("What is the function of the liver?"),
                SearchType::CodingRules
            );
        }
    }

    mod lexical {
        //! Quoted phrases and requests for exact wording.
        use super::*;

        #[test]
        fn quoted_phrase() {
            assert_eq!(
                routed("\"polonium and radium\""),
                SearchType::ChunksLexical
            );
        }

        #[test]
        fn exact_keyword() {
            assert_eq!(
                routed("Find the exact phrase in the documents"),
                SearchType::ChunksLexical
            );
        }
    }

    mod summary {
        //! Summary / overview requests.
        use super::*;

        #[test]
        fn summarize() {
            assert_eq!(
                routed("Summarize everything about Marie Curie"),
                SearchType::GraphSummaryCompletion
            );
        }

        #[test]
        fn overview() {
            assert_eq!(
                routed("Give me an overview of the project"),
                SearchType::GraphSummaryCompletion
            );
        }

        #[test]
        fn tldr() {
            assert_eq!(
                routed("tldr of the report"),
                SearchType::GraphSummaryCompletion
            );
        }
    }

    mod reasoning {
        //! "Why" questions and explicit requests for explanation.
        use super::*;

        #[test]
        fn why_question() {
            assert_eq!(
                routed("Why did Curie win two Nobel Prizes?"),
                SearchType::GraphCompletionCot
            );
        }

        #[test]
        fn explain() {
            assert_eq!(
                routed("Explain the theory of relativity"),
                SearchType::GraphCompletionCot
            );
        }
    }

    mod relationship {
        //! Questions about how entities connect.
        use super::*;

        #[test]
        fn connection_between() {
            assert_eq!(
                routed("How is Einstein connected to the Sorbonne?"),
                SearchType::GraphCompletionContextExtension
            );
        }

        #[test]
        fn related_to() {
            assert_eq!(
                routed("What entities are related to physics?"),
                SearchType::GraphCompletionContextExtension
            );
        }

        // "between" without years must not trigger the temporal range rule.
        #[test]
        fn between_not_temporal() {
            assert_eq!(
                routed("What is the relationship between supply and demand?"),
                SearchType::GraphCompletionContextExtension
            );
        }
    }

    mod temporal {
        //! Time-oriented questions, years and year ranges.
        use super::*;

        #[test]
        fn when_question() {
            assert_eq!(routed("When did Einstein publish?"), SearchType::Temporal);
        }

        #[test]
        fn year_range() {
            assert_eq!(
                routed("What happened between 1910 and 1920?"),
                SearchType::Temporal
            );
        }

        #[test]
        fn timeline() {
            assert_eq!(
                routed("Show the timeline of discoveries"),
                SearchType::Temporal
            );
        }

        #[test]
        fn specific_year() {
            assert_eq!(
                routed("What was discovered in 1915?"),
                SearchType::Temporal
            );
        }
    }

    mod negation {
        //! A negation cue right before a relationship cue suppresses it; one
        //! far away does not.
        use super::*;

        #[test]
        fn not_related_suppresses_graph() {
            assert_ne!(
                routed("What is not related to physics?"),
                SearchType::GraphCompletionContextExtension
            );
        }

        #[test]
        fn no_connection_suppresses_graph() {
            assert_ne!(
                routed("There is no connection between these topics"),
                SearchType::GraphCompletionContextExtension
            );
        }

        #[test]
        fn negation_does_not_affect_distant_match() {
            let query =
                "This is not about food at all, however I want to know how is X connected to Y?";
            assert_eq!(routed(query), SearchType::GraphCompletionContextExtension);
        }
    }

    mod confidence {
        //! Confidence values and score bookkeeping on `RouteResult`.
        use super::*;

        #[test]
        fn high_confidence_for_cypher() {
            let result = route_query("MATCH (n) RETURN n");
            assert!(result.confidence >= 10.0);
            assert!(result.is_confident());
        }

        #[test]
        fn runner_up_populated() {
            let result = route_query("Summarize the timeline of discoveries");
            assert_eq!(result.search_type, SearchType::GraphSummaryCompletion);
            assert!(!result.all_scores.is_empty());
        }

        #[test]
        fn default_has_base_confidence() {
            let result = route_query("Tell me something interesting");
            assert_eq!(result.search_type, SearchType::GraphCompletion);
            assert!(result.confidence >= 0.0);
        }
    }

    mod ambiguous {
        //! Queries carrying signals for more than one search type.
        use super::*;

        #[test]
        fn temporal_beats_graph_for_years() {
            assert_eq!(
                routed("What happened between 1910 and 1920?"),
                SearchType::Temporal
            );
        }

        #[test]
        fn summary_with_temporal_word() {
            assert_eq!(
                routed("Summarize the timeline of Einstein's work"),
                SearchType::GraphSummaryCompletion
            );
        }

        #[test]
        fn default_for_vague_query() {
            assert_eq!(routed("Tell me something"), SearchType::GraphCompletion);
        }
    }
}