use crate::{FxHashMap, FxHashSet};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Category of a lexical edge; serialized in snake_case.
///
/// `normalize_query` rewrites along `Morphological` and `Abbreviation` edges,
/// `preprocess_grounded` additionally substitutes along `Synonym` edges, and
/// `semantic_hits` reports only `Semantic` edges.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EdgeKind {
/// Inflected or derived form pointing at a base form (e.g. "canceling" -> "cancel" in the bundled graphs).
Morphological,
/// Short form pointing at its expansion (e.g. "pr" -> "pull request" in the bundled graphs).
Abbreviation,
/// Alternative wording for a target word (e.g. "terminate" -> "cancel" in the bundled graphs).
Synonym,
/// Looser association that is reported as a hit but never rewritten into the query by this file.
Semantic,
}
/// One outgoing edge of a [`LexicalGraph`] entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HebbianEdge {
/// Word or phrase this edge points to (`LexicalGraph::add` stores it lowercased).
pub target: String,
/// Edge strength; the rewrite passes in this file compare it against 0.97 and 0.90 cut-offs.
pub weight: f32,
/// Category of the relation, which decides which pass may follow the edge.
pub kind: EdgeKind,
}
/// Word-level graph used to normalize queries before intent scoring.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LexicalGraph {
/// Outgoing edges per source word; `add` lowercases the key before inserting.
pub edges: HashMap<String, Vec<HebbianEdge>>,
}
/// Outcome of running a query through `LexicalGraph` preprocessing.
#[derive(Debug, Clone)]
pub struct PreprocessResult {
    /// The query exactly as the caller supplied it.
    pub original: String,
    /// The query after token rewriting, tokens joined by single spaces.
    pub normalized: String,
    /// The text handed to scoring; both preprocess functions in this file
    /// set it to the same string as `normalized`.
    pub expanded: String,
    /// Human-readable notes about substitutions (`"word → target"`); only
    /// `preprocess_grounded` fills this.
    pub injected: Vec<String>,
    /// `(source word, target, weight)` triples for semantic edges that matched.
    pub semantic_hits: Vec<(String, String, f32)>,
    /// Whether the processed text differs from the lowercased input.
    pub was_modified: bool,
}
impl LexicalGraph {
pub fn new() -> Self {
Self {
edges: HashMap::new(),
}
}
/// Appends an edge `from -> to` with the given weight and kind.
///
/// Both endpoints are lowercased before storage. No de-duplication is
/// performed: adding the same pair twice stores two edges.
pub fn add(&mut self, from: &str, to: &str, weight: f32, kind: EdgeKind) {
    let source = from.to_lowercase();
    let edge = HebbianEdge {
        target: to.to_lowercase(),
        weight,
        kind,
    };
    let bucket = self.edges.entry(source).or_default();
    bucket.push(edge);
}
pub fn reinforce(&mut self, from: &str, to: &str, delta: f32) {
let from = from.to_lowercase();
let to = to.to_lowercase();
if let Some(edges) = self.edges.get_mut(&from) {
for e in edges.iter_mut() {
if e.target == to {
e.weight = (e.weight + delta * (1.0 - e.weight)).min(1.0);
return;
}
}
}
self.add(&from, &to, 0.60 + delta, EdgeKind::Synonym);
}
/// Public entry point to the tokenizer used by the normalization passes;
/// forwards unchanged to the private `l1_tokens`.
pub fn l1_tokens_pub(query: &str) -> Vec<String> {
Self::l1_tokens(query)
}
/// Splits a query into lowercase tokens for graph lookup.
///
/// If the lowercased query contains any character for which
/// `crate::tokenizer::is_cjk` returns true, the original (not lowercased)
/// query is handed to `crate::tokenizer::tokenize` and its output is returned
/// as-is. Otherwise the query is split on whitespace, each piece is stripped
/// of leading/trailing non-alphanumeric characters, and a standalone `"."`
/// token is emitted after every piece that ended in `.`, `!` or `?`.
fn l1_tokens(query: &str) -> Vec<String> {
    let lowered = query.to_lowercase();
    if lowered.chars().any(crate::tokenizer::is_cjk) {
        return crate::tokenizer::tokenize(query);
    }
    let mut tokens = Vec::new();
    for raw in lowered.split_whitespace() {
        let trimmed = raw.trim_matches(|c: char| !c.is_alphanumeric());
        if !trimmed.is_empty() {
            tokens.push(trimmed.to_string());
        }
        // Sentence-ending punctuation is kept as a single "." boundary marker.
        if matches!(raw.chars().last(), Some('.' | '!' | '?')) {
            tokens.push(".".to_string());
        }
    }
    tokens
}
/// Rewrites each token to its canonical form and joins the result with spaces.
///
/// A token is replaced by the target of the first edge stored for it that is
/// `Morphological` or `Abbreviation` with weight of at least 0.97; tokens
/// without such an edge are kept unchanged.
pub fn normalize_query(&self, query: &str) -> String {
    Self::l1_tokens(query)
        .into_iter()
        .map(|token| {
            let rewrite = self.edges.get(&token).and_then(|edges| {
                edges.iter().find(|edge| {
                    edge.weight >= 0.97
                        && matches!(
                            edge.kind,
                            EdgeKind::Morphological | EdgeKind::Abbreviation
                        )
                })
            });
            match rewrite {
                Some(edge) => edge.target.clone(),
                None => token,
            }
        })
        .collect::<Vec<String>>()
        .join(" ")
}
/// Lists every `Semantic` edge reachable from a token of `query`.
///
/// Returns `(token, edge target, edge weight)` triples in token order, and
/// within one token in stored edge order. The query itself is not modified.
pub fn semantic_hits(&self, query: &str) -> Vec<(String, String, f32)> {
    let tokens = Self::l1_tokens(query);
    tokens
        .iter()
        .filter_map(|token| self.edges.get(token.as_str()).map(|edges| (token, edges)))
        .flat_map(|(token, edges)| {
            edges
                .iter()
                .filter(|edge| matches!(edge.kind, EdgeKind::Semantic))
                .map(move |edge| (token.clone(), edge.target.clone(), edge.weight))
        })
        .collect()
}
/// Normalizes `query` and collects semantic hits on the normalized text.
///
/// `expanded` is set to the same string as `normalized`, `injected` is always
/// empty, and `was_modified` is true when the normalized text differs from
/// the lowercased input string.
pub fn preprocess(&self, query: &str) -> PreprocessResult {
    let normalized = self.normalize_query(query);
    let was_modified = query.to_lowercase() != normalized;
    let semantic_hits = self.semantic_hits(&normalized);
    PreprocessResult {
        original: query.to_owned(),
        expanded: normalized.clone(),
        normalized,
        injected: Vec::new(),
        semantic_hits,
        was_modified,
    }
}
/// Normalizes `query`, but only rewrites tokens that are missing from
/// `known_words` and only toward targets that are present in it.
///
/// For each out-of-vocabulary token, in order of preference:
/// 1. the first `Morphological`/`Abbreviation` edge with weight >= 0.97 whose
///    target is known replaces the token silently;
/// 2. otherwise the first `Synonym` edge with weight >= 0.90 whose target is
///    known replaces the token and is recorded in `injected` as
///    `"token → target"`;
/// 3. otherwise the token is kept.
///
/// `normalized` and `expanded` hold the same rewritten text, and
/// `semantic_hits` is always empty for this entry point.
pub fn preprocess_grounded(
    &self,
    query: &str,
    known_words: &std::collections::HashSet<&str>,
) -> PreprocessResult {
    let tokens = Self::l1_tokens(query);
    let mut rewritten: Vec<String> = Vec::with_capacity(tokens.len());
    let mut injected: Vec<String> = Vec::new();
    for token in &tokens {
        // Known tokens, and unknown tokens without any edges, pass through.
        let candidates = match self.edges.get(token.as_str()) {
            Some(edges) if !known_words.contains(token.as_str()) => edges,
            _ => {
                rewritten.push(token.clone());
                continue;
            }
        };
        let canonical = candidates.iter().find(|edge| {
            matches!(
                edge.kind,
                EdgeKind::Morphological | EdgeKind::Abbreviation
            ) && edge.weight >= 0.97
                && known_words.contains(edge.target.as_str())
        });
        if let Some(edge) = canonical {
            rewritten.push(edge.target.clone());
            continue;
        }
        let synonym = candidates.iter().find(|edge| {
            matches!(edge.kind, EdgeKind::Synonym)
                && edge.weight >= 0.90
                && known_words.contains(edge.target.as_str())
        });
        match synonym {
            Some(edge) => {
                injected.push(format!("{} → {}", token, edge.target));
                rewritten.push(edge.target.clone());
            }
            None => rewritten.push(token.clone()),
        }
    }
    let expanded = rewritten.join(" ");
    let was_modified = query.to_lowercase() != expanded;
    PreprocessResult {
        original: query.to_owned(),
        normalized: expanded.clone(),
        expanded,
        injected,
        semantic_hits: Vec::new(),
        was_modified,
    }
}
/// Writes the graph to `path` as pretty-printed JSON.
///
/// # Errors
/// Serialization failures are reported as `InvalidData`; file-system errors
/// from `std::fs::write` are returned unchanged.
pub fn save(&self, path: &str) -> std::io::Result<()> {
    match serde_json::to_string_pretty(self) {
        Ok(json) => std::fs::write(path, json),
        Err(err) => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidData,
            err,
        )),
    }
}
/// Reads a graph previously written as JSON from `path`.
///
/// # Errors
/// Returns the underlying I/O error if the file cannot be read, or an
/// `InvalidData` error wrapping the parser error if the JSON is malformed.
pub fn load(path: &str) -> std::io::Result<Self> {
    let raw = std::fs::read_to_string(path)?;
    let graph: Self = serde_json::from_str(&raw)
        .map_err(|err| std::io::Error::new(std::io::ErrorKind::InvalidData, err))?;
    Ok(graph)
}
}
pub fn saas_test_graph() -> LexicalGraph {
let mut g = LexicalGraph::new();
for v in &["canceling", "cancelled", "cancellation", "cancels"] {
g.add(v, "cancel", 0.99, EdgeKind::Morphological);
}
for v in &["refunding", "refunded", "refunds"] {
g.add(v, "refund", 0.99, EdgeKind::Morphological);
}
for v in &["charging", "charged", "charges"] {
g.add(v, "charge", 0.99, EdgeKind::Morphological);
}
for v in &["shipped", "shipment", "shipments"] {
g.add(v, "ship", 0.99, EdgeKind::Morphological);
}
for v in &["merging", "merged", "merges"] {
g.add(v, "merge", 0.99, EdgeKind::Morphological);
}
for v in &["listing", "listed", "lists"] {
g.add(v, "list", 0.99, EdgeKind::Morphological);
}
for v in &["creating", "created", "creates", "creation"] {
g.add(v, "create", 0.99, EdgeKind::Morphological);
}
for v in &["scheduling", "scheduled", "schedules"] {
g.add(v, "schedule", 0.99, EdgeKind::Morphological);
}
for v in &["inviting", "invited", "invites"] {
g.add(v, "invite", 0.99, EdgeKind::Morphological);
}
for v in &["sending", "sent", "sends"] {
g.add(v, "send", 0.99, EdgeKind::Morphological);
}
for v in &["closing", "closed", "closes"] {
g.add(v, "close", 0.99, EdgeKind::Morphological);
}
g.add("pr", "pull request", 0.99, EdgeKind::Abbreviation);
g.add("prs", "pull requests", 0.99, EdgeKind::Abbreviation);
g.add("repo", "repository", 0.99, EdgeKind::Abbreviation);
g.add("repos", "repositories", 0.99, EdgeKind::Abbreviation);
g.add("sub", "subscription", 0.99, EdgeKind::Abbreviation);
g.add("subs", "subscriptions", 0.99, EdgeKind::Abbreviation);
g.add("msg", "message", 0.99, EdgeKind::Abbreviation);
g.add("msgs", "messages", 0.99, EdgeKind::Abbreviation);
g.add("chan", "channel", 0.99, EdgeKind::Abbreviation);
for (v, w) in &[
("terminate", 0.92f32),
("terminating", 0.92),
("terminated", 0.92),
] {
g.add(v, "cancel", *w, EdgeKind::Synonym);
}
for (v, w) in &[("kill", 0.85f32), ("killing", 0.85), ("killed", 0.85)] {
g.add(v, "cancel", *w, EdgeKind::Synonym);
}
for (v, w) in &[("axe", 0.83f32), ("axed", 0.83), ("axing", 0.83)] {
g.add(v, "cancel", *w, EdgeKind::Synonym);
}
g.add("ditch", "cancel", 0.80, EdgeKind::Synonym);
g.add("ping", "send", 0.92, EdgeKind::Synonym);
g.add("dm", "send", 0.90, EdgeKind::Synonym);
g.add("notify", "send", 0.85, EdgeKind::Synonym);
g.add("blast", "send", 0.80, EdgeKind::Synonym);
g.add("spin", "create", 0.82, EdgeKind::Synonym);
g.add("make", "create", 0.85, EdgeKind::Synonym);
g.add("build", "create", 0.82, EdgeKind::Synonym);
g.add("reimburse", "refund", 0.90, EdgeKind::Synonym);
g.add("reimbursement", "refund", 0.90, EdgeKind::Synonym);
g.add("compensate", "refund", 0.80, EdgeKind::Synonym);
g.add("run", "charge", 0.82, EdgeKind::Synonym); g.add("bill", "charge", 0.85, EdgeKind::Synonym);
g.add("show", "list", 0.85, EdgeKind::Synonym);
g.add("fetch", "list", 0.82, EdgeKind::Synonym);
g.add("integrate", "merge", 0.82, EdgeKind::Synonym);
g.add("squash", "merge", 0.80, EdgeKind::Synonym);
g.add("stop", "cancel", 0.65, EdgeKind::Semantic);
g.add("end", "cancel", 0.62, EdgeKind::Semantic);
g.add("drop", "cancel", 0.68, EdgeKind::Semantic);
g.add("fire", "send", 0.70, EdgeKind::Semantic); g.add("throw", "create", 0.65, EdgeKind::Semantic);
g
}
#[cfg(test)]
mod tests {
    //! Unit tests for `LexicalGraph` normalization, semantic hits and
    //! reinforcement, all driven by `saas_test_graph`.
    use super::*;

    /// A single inflected verb is reduced to its base form.
    #[test]
    fn morph_canceling() {
        let g = saas_test_graph();
        assert_eq!(
            g.normalize_query("canceling my subscription"),
            "cancel my subscription"
        );
    }

    /// Past participles are rewritten too, regardless of position.
    #[test]
    fn morph_cancelled() {
        let g = saas_test_graph();
        assert_eq!(
            g.normalize_query("the order was cancelled"),
            "the order was cancel"
        );
    }

    /// Several rewrites (morphology and abbreviation) apply in one query.
    #[test]
    fn morph_multiple_in_one_query() {
        let g = saas_test_graph();
        assert_eq!(
            g.normalize_query("merged the pr and closed the issue"),
            "merge the pull request and close the issue"
        );
    }

    #[test]
    fn morph_shipped() {
        let g = saas_test_graph();
        assert_eq!(
            g.normalize_query("get all shipped orders"),
            "get all ship orders"
        );
    }

    #[test]
    fn abbrev_sub() {
        let g = saas_test_graph();
        assert_eq!(g.normalize_query("cancel my sub"), "cancel my subscription");
    }

    #[test]
    fn abbrev_pr_and_repo() {
        let g = saas_test_graph();
        assert_eq!(
            g.normalize_query("merge the pr in that repo"),
            "merge the pull request in that repository"
        );
    }

    #[test]
    fn abbrev_msg_chan() {
        let g = saas_test_graph();
        assert_eq!(
            g.normalize_query("send a msg to the chan"),
            "send a message to the channel"
        );
    }

    #[test]
    fn combined_morph_and_abbrev() {
        let g = saas_test_graph();
        assert_eq!(
            g.normalize_query("canceling my sub"),
            "cancel my subscription"
        );
    }

    /// `preprocess` never injects synonyms; it only normalizes.
    #[test]
    fn combined_morph_then_synonym() {
        let g = saas_test_graph();
        let r = g.preprocess("canceling my sub");
        assert_eq!(r.normalized, "cancel my subscription");
        assert!(
            r.injected.is_empty(),
            "should not inject anything when already canonical"
        );
    }

    /// Semantic edges are reported as hits but must not alter the query text.
    #[test]
    fn semantic_stop_does_not_expand() {
        let g = saas_test_graph();
        let r = g.preprocess("stop sending me emails");
        assert!(
            !r.expanded.contains("cancel"),
            "semantic word should not expand query"
        );
        let hit = r
            .semantic_hits
            .iter()
            .any(|(src, tgt, _)| src == "stop" && tgt == "cancel");
        assert!(hit, "stop → cancel should appear as semantic hit");
    }

    #[test]
    fn semantic_end_does_not_expand() {
        let g = saas_test_graph();
        let r = g.preprocess("at the end of the month");
        assert!(!r.expanded.contains("cancel"));
    }

    /// An already-canonical lowercase query comes back untouched.
    #[test]
    fn no_modification_clean_query() {
        let g = saas_test_graph();
        let r = g.preprocess("cancel my subscription");
        assert!(!r.was_modified);
        assert_eq!(r.expanded, "cancel my subscription");
    }

    #[test]
    fn reinforce_strengthens_existing_edge() {
        let mut g = saas_test_graph();
        let before = g.edges["terminate"][0].weight;
        g.reinforce("terminate", "cancel", 0.05);
        let after = g.edges["terminate"][0].weight;
        assert!(after > before, "reinforcement should increase weight");
        assert!(after <= 1.0, "should not exceed 1.0");
    }

    #[test]
    fn reinforce_creates_new_edge() {
        let mut g = saas_test_graph();
        g.reinforce("nuke", "cancel", 0.05);
        let has_edge = g
            .edges
            .get("nuke")
            .map(|es| es.iter().any(|e| e.target == "cancel"))
            .unwrap_or(false);
        assert!(has_edge, "new word should get a learned edge");
    }

    /// End-to-end `preprocess` check: the text is rewritten, flagged as
    /// modified, and nothing is injected.
    #[test]
    fn pipeline_merged_the_pr() {
        let g = saas_test_graph();
        let r = g.preprocess("merged the pr");
        assert_eq!(r.normalized, "merge the pull request");
        // The previous combined assertion (`!was_modified || injected.is_empty()`)
        // could never fail because `preprocess` always leaves `injected` empty;
        // check both facts independently instead.
        assert!(r.was_modified, "rewritten query should be flagged as modified");
        assert!(r.injected.is_empty(), "preprocess never injects synonyms");
    }
}
/// Bonus rule that fires when all of its words occur together in a query.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug)]
pub struct ConjunctionRule {
/// Words that must all be present (compared against token bases with any `not_` prefix stripped).
pub words: Vec<String>,
/// Intent id that receives the bonus.
pub intent: String,
/// Amount added to the intent's score when the rule fires.
pub bonus: f32,
}
/// Final routing decision produced by `IntentIndex::resolve`.
#[derive(Debug, Clone)]
pub struct RouteResult {
    /// Intents accepted by multi-round scoring, with the score they were
    /// confirmed at.
    pub confirmed: Vec<(String, f32)>,
    /// Best single-pass scores, truncated to the requested `top_n`.
    pub ranked: Vec<(String, f32)>,
    /// Routing verdict; `resolve` produces `"no_match"`, `"escalate"`,
    /// `"low_confidence"` or `"confident"`.
    pub disposition: String,
    /// Whether a negation marker was seen while scoring the query.
    pub has_negation: bool,
}
/// Diagnostic record of one scoring round in `score_multi_normalized_traced`.
#[derive(serde::Serialize, Clone, Debug)]
pub struct RoundTrace {
/// Tokens that were still unconsumed when the round started.
pub tokens_in: Vec<String>,
/// Highest-scoring intents of the round (the tracer keeps at most five).
pub scored: Vec<(String, f32)>,
/// Intent ids confirmed in this round (empty when the round stopped the loop).
pub confirmed: Vec<String>,
/// Tokens removed from the remaining set at the end of this round.
pub consumed: Vec<String>,
}
/// Full diagnostic trace of a multi-intent scoring run.
#[derive(serde::Serialize, Clone, Debug)]
pub struct MultiIntentTrace {
/// One entry per recorded round, in execution order.
pub rounds: Vec<RoundTrace>,
/// Human-readable reason the round loop ended.
pub stop_reason: String,
}
/// Word-to-intent activation index with IDF weighting and conjunction bonuses.
///
/// Every serialized field carries `#[serde(default)]`, so missing keys
/// deserialize to empty/zero values. The two private caches are skipped by
/// serde and are repopulated by `rebuild_idf`.
#[derive(serde::Serialize, serde::Deserialize, Clone, Debug, Default)]
pub struct IntentIndex {
/// For each word, the intents it activates and the activation weight.
#[serde(default)]
pub word_intent: HashMap<String, Vec<(String, f32)>>,
/// Rules that add a bonus when all of their words co-occur.
#[serde(default)]
pub conjunctions: Vec<ConjunctionRule>,
/// For each intent, the set of 4-character n-grams indexed from its phrases.
#[serde(default)]
pub char_ngrams: HashMap<String, std::collections::HashSet<String>>,
/// Number of distinct intents, used as the numerator of the IDF formula.
#[serde(default)]
pub intent_count: usize,
// Per-word IDF values; not serialized.
#[serde(skip)]
idf_cache: FxHashMap<String, f32>,
// Set of intent ids seen so far; not serialized.
#[serde(skip)]
known_intents: FxHashSet<String>,
}
impl IntentIndex {
pub fn new() -> Self {
Self::default()
}
/// Learning rate `learn_phrase` passes to `learn_word`.
const PHRASE_RATE: f32 = 0.4;
/// Learning rate `learn_query_words` passes to `learn_word`.
const LEARN_RATE: f32 = 0.3;
/// Recomputes the intent set, `intent_count` and every cached IDF value
/// from scratch using the current contents of `word_intent`.
///
/// IDF for a word is `max(0, ln(N / k))`, where `N` is the number of distinct
/// intents (at least 1) and `k` is the number of intents the word activates.
pub fn rebuild_idf(&mut self) {
    self.known_intents.clear();
    self.known_intents.extend(
        self.word_intent
            .values()
            .flatten()
            .map(|(intent, _)| intent.clone()),
    );
    self.intent_count = self.known_intents.len();

    let total = self.intent_count.max(1) as f32;
    self.idf_cache.clear();
    self.idf_cache
        .extend(self.word_intent.iter().map(|(word, activations)| {
            let idf = (total / activations.len() as f32).ln().max(0.0);
            (word.clone(), idf)
        }));
}
/// Updates the cached IDF of a single word using the current `intent_count`,
/// or drops the cache entry if the word is no longer indexed.
fn refresh_idf_for(&mut self, word: &str) {
    let total = self.intent_count.max(1) as f32;
    match self.word_intent.get(word).map(Vec::len) {
        Some(fanout) => {
            let idf = (total / fanout as f32).ln().max(0.0);
            self.idf_cache.insert(word.to_string(), idf);
        }
        None => {
            self.idf_cache.remove(word);
        }
    }
}
/// Returns the IDF of `word`: the cached value when present, otherwise a
/// value computed on the fly from `word_intent`, or 0.0 for unknown words.
#[inline]
fn idf(&self, word: &str) -> f32 {
    if let Some(&cached) = self.idf_cache.get(word) {
        return cached;
    }
    match self.word_intent.get(word) {
        Some(activations) => {
            let total = self.intent_count.max(1) as f32;
            (total / activations.len() as f32).ln().max(0.0)
        }
        None => 0.0,
    }
}
/// Associates `word` with `intent`, or strengthens an existing association.
///
/// * Empty words are ignored.
/// * An existing `(word, intent)` weight `w` becomes `min(1, w + rate * (1 - w))`.
/// * A new pair is stored with weight `rate`. If the intent has not been seen
///   before, the whole IDF table is rebuilt; otherwise only this word's IDF
///   is refreshed.
pub fn learn_word(&mut self, word: &str, intent: &str, rate: f32) {
    if word.is_empty() {
        return;
    }
    let activations = self.word_intent.entry(word.to_string()).or_default();
    if let Some(slot) = activations.iter_mut().find(|(id, _)| id == intent) {
        slot.1 = (slot.1 + rate * (1.0 - slot.1)).min(1.0);
        return;
    }
    activations.push((intent.to_string(), rate));

    let first_sighting = self.known_intents.insert(intent.to_string());
    if first_sighting {
        self.intent_count += 1;
        // A new intent changes N for every word, so every IDF must be redone.
        self.rebuild_idf();
    } else {
        self.refresh_idf_for(word);
    }
}
/// Learns every word of a seed phrase for `intent` at `PHRASE_RATE`.
pub fn learn_phrase(&mut self, words: &[&str], intent: &str) {
    words
        .iter()
        .for_each(|word| self.learn_word(word, intent, Self::PHRASE_RATE));
}
/// Adds the character 4-grams of `phrase` to the n-gram set of `intent`.
///
/// The phrase is lowercased, runs of whitespace are collapsed to single
/// spaces, and one space is added on each side before windows are taken.
/// Padded strings shorter than four characters are ignored (no set is created).
pub fn index_char_ngrams(&mut self, phrase: &str, intent: &str) {
    let lowered = phrase.to_lowercase();
    let collapsed = lowered.split_whitespace().collect::<Vec<_>>().join(" ");
    let padded: Vec<char> = format!(" {} ", collapsed).chars().collect();
    if padded.len() < 4 {
        return;
    }
    let grams = self.char_ngrams.entry(intent.to_string()).or_default();
    for gram in padded.windows(4) {
        grams.insert(gram.iter().collect::<String>());
    }
}
pub fn apply_char_ngram_tiebreaker(
&self,
query: &str,
ranked: Vec<(String, f32)>,
ratio_threshold: f32,
alpha: f32,
) -> Vec<(String, f32)> {
if ranked.len() < 2 {
return ranked;
}
let top1 = ranked[0].1;
let top2 = ranked[1].1;
if top1 + top2 <= 0.0 {
return ranked;
}
let ratio = top1 / (top1 + top2);
if ratio >= ratio_threshold {
return ranked;
}
let normalized: String = query.to_lowercase();
let s: String = format!(
" {} ",
normalized.split_whitespace().collect::<Vec<_>>().join(" ")
);
if s.chars().count() < 4 {
return ranked;
}
let chars: Vec<char> = s.chars().collect();
let mut q_ngrams: FxHashSet<String> = FxHashSet::default();
for window in chars.windows(4) {
let ngram: String = window.iter().collect();
q_ngrams.insert(ngram);
}
if q_ngrams.is_empty() {
return ranked;
}
let k = ranked.len().min(5);
let (head, tail) = ranked.split_at(k);
let mut rescored: Vec<(String, f32)> = head
.iter()
.map(|(id, score)| {
let intent_set = self.char_ngrams.get(id);
let jaccard = match intent_set {
Some(iset) if !iset.is_empty() => {
let inter = q_ngrams.iter().filter(|n| iset.contains(*n)).count();
let uni = q_ngrams.len() + iset.len() - inter;
if uni == 0 {
0.0
} else {
inter as f32 / uni as f32
}
}
_ => 0.0,
};
(id.clone(), score + alpha * jaccard)
})
.collect();
rescored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
rescored.extend_from_slice(tail);
rescored
}
/// Learns every word of an observed query for `intent` at `LEARN_RATE`.
pub fn learn_query_words(&mut self, words: &[&str], intent: &str) {
    words
        .iter()
        .for_each(|word| self.learn_word(word, intent, Self::LEARN_RATE));
}
/// Returns the fixed default score threshold, 0.3.
pub fn default_threshold(&self) -> f32 {
0.3
}
/// Returns the fixed default gap value, 1.5.
pub fn default_gap(&self) -> f32 {
1.5
}
/// Applies feedback of strength `delta` to each `(word, intent)` pair.
///
/// * Existing pair, `delta >= 0`: weight `w` becomes `min(1, w + delta * (1 - w))`.
/// * Existing pair, `delta < 0`: weight `w` becomes `max(0, w * (1 + delta))`.
/// * Missing pair, `delta > 0`: the pair is created with weight `min(delta, 1)`.
/// * Missing pair, `delta <= 0`: nothing happens.
///
/// Unlike a plain `entry(..).or_default()`, this never leaves an empty
/// activation list behind for words that received no association. When a
/// created pair introduces an intent that was not known before, the whole IDF
/// table is rebuilt so `intent_count` stays in sync (mirroring `learn_word`);
/// otherwise only the affected word's IDF is refreshed.
pub fn reinforce(&mut self, words: &[&str], intent: &str, delta: f32) {
    for word in words {
        if let Some(entry) = self
            .word_intent
            .get_mut(*word)
            .and_then(|entries| entries.iter_mut().find(|(id, _)| id == intent))
        {
            entry.1 = if delta >= 0.0 {
                (entry.1 + delta * (1.0 - entry.1)).min(1.0)
            } else {
                (entry.1 * (1.0 + delta)).max(0.0)
            };
            continue;
        }
        if delta > 0.0 {
            self.word_intent
                .entry(word.to_string())
                .or_default()
                .push((intent.to_string(), delta.min(1.0)));
            if self.known_intents.insert(intent.to_string()) {
                // New intent: N changed, so every cached IDF is stale.
                self.rebuild_idf();
            } else {
                self.refresh_idf_for(word);
            }
        }
    }
}
/// Returns the positions (into `conjunctions`) of every rule whose words are
/// all contained in `words`, in ascending order.
pub fn fired_conjunction_indices(&self, words: &[&str]) -> Vec<usize> {
    let present: FxHashSet<&str> = words.iter().copied().collect();
    let mut fired = Vec::new();
    for (idx, rule) in self.conjunctions.iter().enumerate() {
        let satisfied = rule
            .words
            .iter()
            .all(|needed| present.contains(needed.as_str()));
        if satisfied {
            fired.push(idx);
        }
    }
    fired
}
/// Adjusts the bonus of the conjunction rule at `idx`; out-of-range indices
/// are ignored.
///
/// Non-negative `delta` moves the bonus toward 1.0 (`b + delta * (1 - b)`,
/// capped at 1.0); negative `delta` scales it down (`b * (1 + delta)`,
/// floored at 0.0).
pub fn reinforce_conjunction(&mut self, idx: usize, delta: f32) {
    let Some(rule) = self.conjunctions.get_mut(idx) else {
        return;
    };
    rule.bonus = if delta >= 0.0 {
        (rule.bonus + delta * (1.0 - rule.bonus)).min(1.0)
    } else {
        (rule.bonus * (1.0 + delta)).max(0.0)
    };
}
pub fn score_normalized(&self, normalized: &str) -> (Vec<(String, f32)>, bool) {
const CJK_NEG: &[char] = &['不', '没', '别', '未'];
let cjk_negated = normalized.chars().any(|c| CJK_NEG.contains(&c));
let query_for_tokenize: std::borrow::Cow<str> = if cjk_negated {
std::borrow::Cow::Owned(
normalized
.chars()
.map(|c| if CJK_NEG.contains(&c) { ' ' } else { c })
.collect(),
)
} else {
std::borrow::Cow::Borrowed(normalized)
};
let tokens = crate::tokenizer::tokenize(&query_for_tokenize);
let mut scores: FxHashMap<String, f32> = FxHashMap::default();
let mut has_negation = cjk_negated;
let all_bases: FxHashSet<&str> = tokens
.iter()
.map(|t| t.strip_prefix("not_").unwrap_or(t.as_str()))
.collect();
for token in &tokens {
let is_negated = token.starts_with("not_");
let base = if is_negated {
&token["not_".len()..]
} else {
token.as_str()
};
if is_negated {
has_negation = true;
}
if let Some(activations) = self.word_intent.get(base) {
let idf = self.idf(base);
for (intent, weight) in activations {
let delta = weight * idf;
*scores.entry(intent.clone()).or_insert(0.0) +=
if is_negated { -delta } else { delta };
}
}
}
for rule in &self.conjunctions {
if rule.words.iter().all(|w| all_bases.contains(w.as_str())) {
*scores.entry(rule.intent.clone()).or_insert(0.0) += rule.bonus;
}
}
let mut result: Vec<(String, f32)> = scores.into_iter().filter(|(_, s)| *s > 0.0).collect();
result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
(result, has_negation)
}
/// Preprocesses `query` with `layer1` and scores the resulting `expanded`
/// text via `score_normalized`.
///
/// Returns the ranked `(intent, score)` list and the negation flag.
pub fn score(&self, layer1: &LexicalGraph, query: &str) -> (Vec<(String, f32)>, bool) {
let preprocessed = layer1.preprocess(query);
self.score_normalized(&preprocessed.expanded)
}
/// Preprocesses `query` with `layer1` and runs multi-intent scoring on the
/// resulting `expanded` text via `score_multi_normalized`.
///
/// `threshold` and `gap` are forwarded unchanged.
pub fn score_multi(
&self,
layer1: &LexicalGraph,
query: &str,
threshold: f32,
gap: f32,
) -> (Vec<(String, f32)>, bool) {
let preprocessed = layer1.preprocess(query);
self.score_multi_normalized(&preprocessed.expanded, threshold, gap)
}
/// Multi-intent scoring without diagnostics: calls
/// `score_multi_normalized_traced` with tracing disabled and discards the
/// (absent) trace.
pub fn score_multi_normalized(
&self,
normalized: &str,
threshold: f32,
gap: f32,
) -> (Vec<(String, f32)>, bool) {
let (results, neg, _trace) =
self.score_multi_normalized_traced(normalized, threshold, gap, false);
(results, neg)
}
/// Detects up to several intents in one query by scoring in rounds.
///
/// Each round scores the still-unconsumed tokens (ignoring intents that were
/// already confirmed), confirms every intent scoring at least 90% of the
/// round's best score and at least `threshold`, and then drops every token
/// whose highest-weight activation belongs to a confirmed intent. The loop
/// runs at most `MAX_ROUNDS` times and stops early when nothing scores, when
/// the best score falls below `threshold`, when a later round's best score
/// falls below `GATE_RATIO` times the first round's best, or when no tokens
/// remain.
///
/// `_gap` is accepted for signature compatibility but is not read.
///
/// Returns the confirmed `(intent, score)` pairs in confirmation order, the
/// negation flag (CJK marker present or any `not_`-prefixed token), and a
/// trace that is `Some` only when `with_trace` is true.
pub fn score_multi_normalized_traced(
&self,
normalized: &str,
threshold: f32,
_gap: f32,
with_trace: bool,
) -> (Vec<(String, f32)>, bool, Option<MultiIntentTrace>) {
// Later rounds must reach this fraction of the first round's top score.
const GATE_RATIO: f32 = 0.55;
// Upper bound on the number of scoring rounds.
const MAX_ROUNDS: usize = 3;
// CJK characters treated as negation markers; they are blanked out before tokenizing.
const CJK_NEG: &[char] = &['不', '没', '别', '未'];
let cjk_negated = normalized.chars().any(|c| CJK_NEG.contains(&c));
let query_for_tokenize: std::borrow::Cow<str> = if cjk_negated {
std::borrow::Cow::Owned(
normalized
.chars()
.map(|c| if CJK_NEG.contains(&c) { ' ' } else { c })
.collect(),
)
} else {
std::borrow::Cow::Borrowed(normalized)
};
let all_tokens: Vec<String> = crate::tokenizer::tokenize(&query_for_tokenize);
let has_negation = cjk_negated || all_tokens.iter().any(|t| t.starts_with("not_"));
// Tokens not yet attributed to a confirmed intent.
let mut remaining: Vec<String> = all_tokens;
let mut confirmed: Vec<(String, f32)> = Vec::new();
let mut confirmed_ids: FxHashSet<String> = FxHashSet::default();
// Best score of round 0; reference point for the GATE_RATIO check.
let mut original_top: f32 = 0.0;
let mut trace_rounds: Vec<RoundTrace> = Vec::new();
// Only written when tracing; stays None if the loop runs out of rounds.
let mut stop_reason: Option<String> = None;
for round in 0..MAX_ROUNDS {
let scored = self.score_tokens(&remaining, &confirmed_ids);
if scored.is_empty() {
if with_trace {
stop_reason = Some("no scores".into());
}
break;
}
// `score_tokens` returns results sorted descending, so index 0 is the best.
let round_top = scored[0].1;
if round == 0 {
original_top = round_top;
}
if round_top < threshold {
if with_trace {
stop_reason =
Some(format!("top {:.2} < threshold {:.2}", round_top, threshold));
trace_rounds.push(RoundTrace {
tokens_in: remaining.clone(),
scored: scored.iter().take(5).cloned().collect(),
confirmed: vec![],
consumed: vec![],
});
}
break;
}
// Secondary intents must be reasonably strong relative to the primary one.
if round > 0 && round_top < original_top * GATE_RATIO {
if with_trace {
stop_reason = Some(format!(
"top {:.2} < gate {:.2}",
round_top,
original_top * GATE_RATIO
));
trace_rounds.push(RoundTrace {
tokens_in: remaining.clone(),
scored: scored.iter().take(5).cloned().collect(),
confirmed: vec![],
consumed: vec![],
});
}
break;
}
// Confirm the round's winner together with any near-ties (within 10%).
let mut round_confirmed: Vec<(String, f32)> = Vec::new();
for (intent, score) in &scored {
if *score >= round_top * 0.90 && *score >= threshold {
round_confirmed.push((intent.clone(), *score));
confirmed_ids.insert(intent.clone());
}
}
// NOTE(review): appears unreachable for finite scores, since the top entry
// already passed the threshold check above — confirm before relying on it.
if round_confirmed.is_empty() {
if with_trace {
stop_reason = Some("no confirmed".into());
}
break;
}
confirmed.extend(round_confirmed.iter().cloned());
let tokens_before: Vec<String> = remaining.clone();
// Keep a token only if its strongest activation is NOT a confirmed intent;
// tokens unknown to the index are always kept.
remaining.retain(|token| {
let base = token.strip_prefix("not_").unwrap_or(token.as_str());
if let Some(activations) = self.word_intent.get(base) {
let best_intent = activations
.iter()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
match best_intent {
Some((id, _)) => !confirmed_ids.contains(id.as_str()),
None => true,
}
} else {
true }
});
if with_trace {
// Consumed = tokens present before the retain whose text no longer appears afterwards.
let remaining_set: FxHashSet<&String> = remaining.iter().collect();
let consumed: Vec<String> = tokens_before
.iter()
.filter(|t| !remaining_set.contains(t))
.cloned()
.collect();
trace_rounds.push(RoundTrace {
tokens_in: tokens_before,
scored: scored.iter().take(5).cloned().collect(),
confirmed: round_confirmed.iter().map(|(id, _)| id.clone()).collect(),
consumed,
});
}
if remaining.is_empty() {
if with_trace {
stop_reason = Some("all tokens consumed".into());
}
break;
}
}
let trace = if with_trace {
Some(MultiIntentTrace {
rounds: trace_rounds,
stop_reason: stop_reason.unwrap_or_else(|| "max rounds reached".into()),
})
} else {
None
};
(confirmed, has_negation, trace)
}
fn score_tokens(
&self,
tokens: &[String],
exclude_intents: &FxHashSet<String>,
) -> Vec<(String, f32)> {
let mut scores: FxHashMap<String, f32> = FxHashMap::default();
for token in tokens {
let is_negated = token.starts_with("not_");
let base = if is_negated {
&token["not_".len()..]
} else {
token.as_str()
};
if let Some(activations) = self.word_intent.get(base) {
let idf = self.idf(base);
for (intent, weight) in activations {
if exclude_intents.contains(intent) {
continue;
}
let delta = weight * idf;
*scores.entry(intent.clone()).or_insert(0.0) +=
if is_negated { -delta } else { delta };
}
}
}
let all_bases: FxHashSet<&str> = tokens
.iter()
.map(|t| t.strip_prefix("not_").unwrap_or(t.as_str()))
.collect();
for rule in &self.conjunctions {
if !exclude_intents.contains(&rule.intent)
&& rule.words.iter().all(|w| all_bases.contains(w.as_str()))
{
*scores.entry(rule.intent.clone()).or_insert(0.0) += rule.bonus;
}
}
let mut result: Vec<(String, f32)> = scores.into_iter().filter(|(_, s)| *s > 0.0).collect();
result.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
result
}
/// Writes the index to `path` as pretty-printed JSON.
///
/// # Errors
/// Serialization errors are wrapped with `std::io::Error::other`; errors
/// from `std::fs::write` are returned unchanged.
pub fn save(&self, path: &str) -> std::io::Result<()> {
    serde_json::to_string_pretty(self)
        .map_err(std::io::Error::other)
        .and_then(|json| std::fs::write(path, json))
}
pub fn load(path: &str) -> std::io::Result<Self> {
let json = std::fs::read_to_string(path)?;
serde_json::from_str(&json).map_err(std::io::Error::other)
}
/// Routes a query end to end and classifies how trustworthy the result is.
///
/// The query is preprocessed with `layer1` when one is given (otherwise used
/// verbatim), scored once for the `ranked` list (truncated to `top_n`) and
/// once in multi-intent mode for the `confirmed` list. When more than one
/// intent is confirmed, `disambiguate_providers` may drop duplicates.
///
/// Disposition:
/// * `"no_match"` — nothing was confirmed;
/// * `"escalate"` — at least three intents confirmed and the third scores at
///   least 75% of the first;
/// * `"low_confidence"` — the first confirmed score is below `2 * threshold`;
/// * `"confident"` — otherwise.
pub fn resolve(
    &self,
    layer1: Option<&LexicalGraph>,
    query: &str,
    threshold: f32,
    top_n: usize,
) -> RouteResult {
    let processed = layer1.map_or_else(
        || query.to_string(),
        |graph| graph.preprocess(query).expanded,
    );

    let (mut ranked, has_negation) = self.score_normalized(&processed);
    ranked.truncate(top_n);

    let (mut confirmed, _) = self.score_multi_normalized(&processed, threshold, 0.0);
    if confirmed.len() > 1 {
        self.disambiguate_providers(&mut confirmed, &processed);
    }

    let disposition = match confirmed.first() {
        None => "no_match",
        Some(&(_, top)) => {
            let crowded = confirmed.len() >= 3 && confirmed[2].1 / top >= 0.75;
            if crowded {
                "escalate"
            } else if top < threshold * 2.0 {
                "low_confidence"
            } else {
                "confident"
            }
        }
    };

    RouteResult {
        confirmed,
        ranked,
        disposition: disposition.to_string(),
        has_negation,
    }
}
/// Removes confirmed intents that share an action with a better-supported
/// sibling.
///
/// The action of an intent id is its second `:`-separated segment, or the
/// whole id when there is none. For every action shared by several confirmed
/// intents, the intent with the most query tokens that activate it and no
/// other confirmed intent is kept and its siblings are removed — but only if
/// that count is greater than zero; otherwise the group is left alone.
fn disambiguate_providers(&self, confirmed: &mut Vec<(String, f32)>, query: &str) {
    if confirmed.len() < 2 {
        return;
    }

    // Positions in `confirmed`, bucketed by action name.
    let mut by_action: FxHashMap<String, Vec<usize>> = FxHashMap::default();
    for (pos, (id, _)) in confirmed.iter().enumerate() {
        let action = id.split(':').nth(1).unwrap_or(id.as_str());
        by_action.entry(action.to_string()).or_default().push(pos);
    }
    let contested: Vec<Vec<usize>> = by_action
        .into_values()
        .filter(|members| members.len() > 1)
        .collect();
    if contested.is_empty() {
        return;
    }

    // Count, per confirmed intent, the tokens that point at it exclusively.
    let tokens = crate::tokenizer::tokenize(query);
    let confirmed_ids: FxHashSet<&str> = confirmed.iter().map(|(id, _)| id.as_str()).collect();
    let mut exclusive_hits: FxHashMap<&str, usize> = FxHashMap::default();
    for token in &tokens {
        let base = token.strip_prefix("not_").unwrap_or(token.as_str());
        let Some(activations) = self.word_intent.get(base) else {
            continue;
        };
        let mut hits = activations
            .iter()
            .filter(|(id, _)| confirmed_ids.contains(id.as_str()))
            .map(|(id, _)| id.as_str());
        // Exactly one confirmed intent is activated by this token.
        if let (Some(only), None) = (hits.next(), hits.next()) {
            *exclusive_hits.entry(only).or_insert(0) += 1;
        }
    }

    let mut doomed: FxHashSet<usize> = FxHashSet::default();
    for members in &contested {
        let winner = members.iter().max_by_key(|&&pos| {
            exclusive_hits
                .get(confirmed[pos].0.as_str())
                .copied()
                .unwrap_or(0)
        });
        let Some(&winner) = winner else {
            continue;
        };
        let support = exclusive_hits
            .get(confirmed[winner].0.as_str())
            .copied()
            .unwrap_or(0);
        if support == 0 {
            // No token singles out any member, so there is no basis to choose.
            continue;
        }
        doomed.extend(members.iter().copied().filter(|&pos| pos != winner));
    }
    if doomed.is_empty() {
        return;
    }

    let mut pos = 0usize;
    confirmed.retain(|_| {
        let keep = !doomed.contains(&pos);
        pos += 1;
        keep
    });
}
/// Returns `(indexed words, total word→intent activations, conjunction rules)`.
pub fn stats(&self) -> (usize, usize, usize) {
    let words = self.word_intent.len();
    let activation_edges = self.word_intent.values().map(Vec::len).sum::<usize>();
    (words, activation_edges, self.conjunctions.len())
}
}
/// Builds a lexical graph containing only English morphological edges.
///
/// Every listed form points at its base form with weight 0.99 and kind
/// `EdgeKind::Morphological`. Edges are inserted in table order.
pub fn english_morphology_base() -> LexicalGraph {
    // (base form, inflected/derived forms)
    const FAMILIES: &[(&str, &[&str])] = &[
        (
            "cancel",
            &[
                "canceling",
                "cancelling",
                "cancelled",
                "canceled",
                "cancellation",
                "cancels",
            ],
        ),
        (
            "refund",
            &[
                "refunding",
                "refunded",
                "refunds",
                "reimbursing",
                "reimbursed",
            ],
        ),
        ("charge", &["charging", "charged", "charges"]),
        ("update", &["updating", "updated", "updates"]),
        ("create", &["creating", "created", "creates", "creation"]),
        ("delete", &["deleting", "deleted", "deletes", "deletion"]),
        ("send", &["sending", "sent", "sends"]),
        ("receive", &["receiving", "received", "receives"]),
        ("reset", &["resetting", "resetted", "resets"]),
        ("change", &["changing", "changed", "changes"]),
        ("upgrade", &["upgrading", "upgraded", "upgrades"]),
        ("downgrade", &["downgrading", "downgraded", "downgrades"]),
        ("connect", &["connecting", "connected", "connects", "connection"]),
        ("disconnect", &["disconnecting", "disconnected", "disconnects"]),
        ("install", &["installing", "installed", "installs", "installation"]),
        ("remove", &["removing", "removed", "removes", "removal"]),
        ("enable", &["enabling", "enabled", "enables"]),
        ("disable", &["disabling", "disabled", "disables"]),
        ("block", &["blocking", "blocked", "blocks"]),
        ("report", &["reporting", "reported", "reports"]),
        ("transfer", &["transferring", "transferred", "transfers"]),
        ("schedule", &["scheduling", "scheduled", "schedules"]),
        ("merge", &["merging", "merged", "merges"]),
        ("ship", &["shipping", "shipped", "shipment", "shipments"]),
        ("pay", &["paying", "paid", "pays", "payment", "payments"]),
        (
            "subscribe",
            &[
                "subscribing",
                "subscribed",
                "subscribes",
                "subscription",
                "subscriptions",
            ],
        ),
        ("list", &["listing", "listed", "lists"]),
        ("invite", &["inviting", "invited", "invites", "invitation"]),
        ("verify", &["verifying", "verified", "verifies", "verification"]),
        ("access", &["accessing", "accessed", "accesses"]),
        ("close", &["closing", "closed", "closes", "closure"]),
        ("open", &["opening", "opened", "opens"]),
        (
            "configure",
            &["configuring", "configured", "configures", "configuration"],
        ),
        ("deploy", &["deploying", "deployed", "deploys", "deployment"]),
        ("detect", &["detecting", "detected", "detects", "detection"]),
        ("fail", &["failing", "failed", "fails", "failure"]),
        ("expire", &["expiring", "expired", "expires", "expiration"]),
        ("renew", &["renewing", "renewed", "renews", "renewal"]),
        ("approve", &["approving", "approved", "approves", "approval"]),
        ("reject", &["rejecting", "rejected", "rejects", "rejection"]),
    ];

    let mut graph = LexicalGraph::new();
    for &(base, forms) in FAMILIES {
        for &form in forms {
            graph.add(form, base, 0.99, EdgeKind::Morphological);
        }
    }
    graph
}
// Tests for `IntentIndex` scoring on top of the `saas_test_graph` lexical layer.
#[cfg(test)]
mod intent_graph_tests {
use super::*;
// Fixture: the SaaS lexical graph plus an index that knows three intents
// (cancel_subscription, cancel_order, send_message) and one conjunction rule.
fn mini_intent_graph() -> (LexicalGraph, IntentIndex) {
let layer1 = saas_test_graph();
let mut ig = IntentIndex::new();
ig.learn_phrase(&["cancel", "subscription"], "cancel_subscription");
ig.conjunctions.push(ConjunctionRule {
words: vec!["cancel".into(), "subscription".into()],
intent: "cancel_subscription".into(),
bonus: 0.50,
});
ig.learn_phrase(&["cancel", "order"], "cancel_order");
ig.learn_phrase(&["send", "message"], "send_message");
(layer1, ig)
}
// A canonical query activates the expected intent with a positive score and no negation.
#[test]
fn layer3_basic_activation() {
let (l1, ig) = mini_intent_graph();
let (scores, neg) = ig.score(&l1, "cancel my subscription");
let top = &scores[0];
assert_eq!(top.0, "cancel_subscription");
assert!(top.1 > 0.0, "cancel_subscription should score positively");
assert!(!neg, "no negation in this query");
}
// "sub" is expanded by the lexical layer, so the query still routes to cancel_subscription.
#[test]
fn layer3_oov_via_layer1() {
let (l1, ig) = mini_intent_graph();
let (scores, _) = ig.score(&l1, "terminate my sub");
assert_eq!(scores[0].0, "cancel_subscription");
}
// The word unique to one intent ("order") should outweigh the shared word ("cancel").
#[test]
fn layer3_idf_disambiguates() {
let (l1, ig) = mini_intent_graph();
let (scores, _) = ig.score(&l1, "cancel order");
assert_eq!(
scores[0].0, "cancel_order",
"IDF should push cancel_order above cancel_subscription (unique word 'order')"
);
}
// Positive reinforcement of a previously unknown word raises the intent's score.
#[test]
fn layer3_reinforcement() {
let (l1, mut ig) = mini_intent_graph();
let (before, _) = ig.score(&l1, "kill the subscription");
let kill_sub_before = before
.iter()
.find(|(id, _)| id == "cancel_subscription")
.map(|(_, s)| *s)
.unwrap_or(0.0);
ig.reinforce(&["kill"], "cancel_subscription", 0.80);
let (after, _) = ig.score(&l1, "kill the subscription");
let kill_sub_after = after
.iter()
.find(|(id, _)| id == "cancel_subscription")
.map(|(_, s)| *s)
.unwrap_or(0.0);
assert!(
kill_sub_after > kill_sub_before,
"reinforcement should improve score"
);
}
// Two intents joined by "and" should both be confirmed by multi-round scoring.
#[test]
fn layer3_multi_intent() {
let (l1, ig) = mini_intent_graph();
let (results, _) = ig.score_multi(&l1, "cancel subscription and send message", 0.4, 2.0);
let ids: Vec<&str> = results.iter().map(|(id, _)| id.as_str()).collect();
assert!(
ids.contains(&"cancel_subscription"),
"should detect cancel_subscription"
);
assert!(ids.contains(&"send_message"), "should detect send_message");
}
// NOTE(review): the test name says negation "flags, not suppresses", but the
// assertions below require the negated intent to score <= 0 (i.e. suppressed)
// in addition to the flag being set — confirm which behavior is intended.
#[test]
fn layer3_negation_flags_not_suppresses() {
let (l1, ig) = mini_intent_graph();
let (with_neg, neg_flag) = ig.score(&l1, "don't cancel my subscription");
let (without_neg, _) = ig.score(&l1, "cancel my subscription");
let neg_score = with_neg
.iter()
.find(|(id, _)| id == "cancel_subscription")
.map(|(_, s)| *s)
.unwrap_or(0.0);
let pos_score = without_neg
.iter()
.find(|(id, _)| id == "cancel_subscription")
.map(|(_, s)| *s)
.unwrap_or(0.0);
assert!(
neg_score <= 0.0,
"cancel_subscription should be suppressed by negation (score={neg_score})"
);
assert!(
pos_score > 0.0,
"cancel_subscription should route without negation"
);
assert!(neg_flag, "has_negation flag should be true");
}
// A CJK negation marker sets the flag while the intent itself is still reported.
#[test]
fn layer3_cjk_negation() {
let (_, mut ig) = mini_intent_graph();
ig.learn_phrase(&["取消", "订阅"], "cancel_subscription");
let (pos_scores, pos_neg) = ig.score_normalized("取消订阅");
assert!(!pos_scores.is_empty(), "positive CJK should score");
assert_eq!(pos_scores[0].0, "cancel_subscription");
assert!(!pos_neg, "no negation in positive query");
let (neg_scores, neg_flag) = ig.score_normalized("不取消订阅");
assert!(neg_flag, "CJK negation marker 不 should set has_negation");
let found = neg_scores.iter().any(|(id, _)| id == "cancel_subscription");
assert!(
found,
"cancel_subscription should still appear (intent is about cancellation)"
);
}
// Regression test: IDF values must be refreshed for all words when a new
// intent appears, otherwise words learned earlier keep a stale (zero) IDF.
#[test]
fn idf_stays_correct_when_intents_are_added_incrementally() {
let mut ig = IntentIndex::new();
ig.learn_phrase(&["foo"], "intent_a");
ig.learn_phrase(&["foo", "bar"], "intent_b");
ig.learn_phrase(&["baz"], "intent_c");
let lex = LexicalGraph::default();
let (scores_foo, _) = ig.score(&lex, "foo");
assert!(
!scores_foo.is_empty(),
"fresh-namespace IDF bug: 'foo' should score against intent_a and intent_b"
);
assert!(
scores_foo.iter().all(|(_, s)| *s > 0.0),
"scores should be positive after the IDF rebuild on new intent"
);
let (scores_bar, _) = ig.score(&lex, "bar");
assert!(
scores_bar.iter().any(|(id, _)| id == "intent_b"),
"'bar' should score against intent_b"
);
let (scores_baz, _) = ig.score(&lex, "baz");
assert!(
scores_baz.iter().any(|(id, _)| id == "intent_c"),
"'baz' should score against intent_c"
);
}
}