use std::collections::HashMap;
/// Position of a node inside the classifier's node arena.
type NodeIdx = usize;
/// Sentinel index meaning "no link" (the initial value of `fail` and `dict`).
const NIL: NodeIdx = usize::MAX;

/// One trie node, which doubles as a state of the Aho–Corasick automaton.
struct TrieNode<T> {
    /// Edges to child nodes, keyed by the ASCII-lowercased input byte.
    children: HashMap<u8, NodeIdx>,
    /// Class value recorded when a keyword ends exactly at this node.
    output: Option<T>,
    /// Failure link: the state matching resumes from when no child edge fits.
    fail: NodeIdx,
    /// Dictionary link: nearest node on the failure chain that has an output,
    /// or `NIL` if there is none.
    dict: NodeIdx,
}

impl<T> TrieNode<T> {
    /// Returns a node with no edges, no output, and both links unset (`NIL`).
    fn new() -> Self {
        let children = HashMap::default();
        TrieNode {
            output: None,
            fail: NIL,
            dict: NIL,
            children,
        }
    }
}
/// Classifies text by searching it for many keywords at once.
///
/// Internally this is an Aho–Corasick automaton stored as an arena of
/// `TrieNode`s. `classify` scans the text once, left to right, instead of
/// searching for each keyword separately. Matching works on raw bytes and
/// folds only ASCII `A`–`Z` to lowercase; other bytes must match exactly.
pub struct KeywordClassifier<T> {
    /// Node arena; index 0 is the root (the empty prefix).
    nodes: Vec<TrieNode<T>>,
}

impl<T: Clone> KeywordClassifier<T> {
    /// Builds a classifier from `(keywords, class)` rules.
    ///
    /// Every keyword of a rule maps to that rule's class. If the same keyword
    /// (after ASCII case folding) appears in several rules, the class of the
    /// earliest rule is kept. Empty keywords are ignored; see `insert`.
    pub fn new(rules: &[(&[&str], T)]) -> Self {
        let mut classifier = Self {
            nodes: vec![TrieNode::new()],
        };
        for (keywords, cls) in rules {
            for &kw in *keywords {
                classifier.insert(kw, cls.clone());
            }
        }
        classifier.build_failure_links();
        classifier
    }

    /// Returns the class of the first keyword occurrence in `text`, or `None`.
    ///
    /// "First" means the occurrence that *ends* earliest in `text`. When
    /// several keywords end at the same byte, the longest one wins, because
    /// the current state's own output is checked before its dictionary link.
    pub fn classify(&self, text: &str) -> Option<&T> {
        let mut state: NodeIdx = 0;
        for &byte in text.as_bytes() {
            state = self.next_state(state, byte.to_ascii_lowercase());
            let node = &self.nodes[state];
            if let Some(out) = node.output.as_ref() {
                return Some(out);
            }
            // No keyword ends exactly here; a shorter keyword that is a
            // suffix of the current match may, and `dict` points at it.
            if node.dict != NIL {
                if let Some(out) = self.nodes[node.dict].output.as_ref() {
                    return Some(out);
                }
            }
        }
        None
    }

    /// Automaton transition: from `state`, consume the case-folded byte `ch`.
    ///
    /// Follows failure links until a node with a `ch` edge is found and
    /// returns that child. Falls back to the root (0) when even the root has
    /// no such edge. Every non-root node visited must already have its
    /// `fail` link set.
    fn next_state(&self, mut state: NodeIdx, ch: u8) -> NodeIdx {
        loop {
            if let Some(&next) = self.nodes[state].children.get(&ch) {
                return next;
            }
            if state == 0 {
                return 0;
            }
            state = self.nodes[state].fail;
        }
    }

    /// Adds `word` (ASCII-case-folded) to the trie and records `cls` at its
    /// final node, unless an earlier rule already claimed that keyword.
    ///
    /// Empty words are skipped. Storing one would give the root an output,
    /// which dictionary links would then expose at every depth-1 state. Every
    /// non-empty text would match it (shadowing all other rules), while
    /// `classify("")` would still return `None`.
    fn insert(&mut self, word: &str, cls: T) {
        if word.is_empty() {
            return;
        }
        let mut node_idx: NodeIdx = 0;
        for &byte in word.as_bytes() {
            let ch = byte.to_ascii_lowercase();
            node_idx = match self.nodes[node_idx].children.get(&ch).copied() {
                Some(child_idx) => child_idx,
                None => {
                    let child_idx = self.nodes.len();
                    self.nodes.push(TrieNode::new());
                    self.nodes[node_idx].children.insert(ch, child_idx);
                    child_idx
                }
            };
        }
        // First insertion wins: later rules never overwrite an existing class.
        let slot = &mut self.nodes[node_idx].output;
        if slot.is_none() {
            *slot = Some(cls);
        }
    }

    /// Computes `fail` and `dict` links for every non-root node, in BFS order.
    ///
    /// BFS ensures that all shallower nodes have final `fail` links before a
    /// node's children are processed, which `next_state` relies on.
    fn build_failure_links(&mut self) {
        // Depth-1 nodes always fail to the root. Their `dict` stays NIL
        // because the root never carries an output (empty keywords are skipped).
        let mut queue: Vec<NodeIdx> = self.nodes[0].children.values().copied().collect();
        for &child_idx in &queue {
            self.nodes[child_idx].fail = 0;
            self.nodes[child_idx].dict = NIL;
        }
        let mut head: usize = 0;
        while head < queue.len() {
            let node_idx = queue[head];
            head += 1;
            // Snapshot the edges so `self.nodes` can be mutated below.
            let children: Vec<(u8, NodeIdx)> = self.nodes[node_idx]
                .children
                .iter()
                .map(|(&ch, &idx)| (ch, idx))
                .collect();
            let parent_fail = self.nodes[node_idx].fail;
            for (ch, child_idx) in children {
                // The fail chain from `parent_fail` only visits nodes shallower
                // than `node_idx`, so the result can never be `child_idx` itself.
                let child_fail = self.next_state(parent_fail, ch);
                self.nodes[child_idx].fail = child_fail;
                self.nodes[child_idx].dict = if self.nodes[child_fail].output.is_some() {
                    child_fail
                } else {
                    self.nodes[child_fail].dict
                };
                queue.push(child_idx);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `classifier.classify(input)` against the expected class for each case.
    fn assert_cases<T: Clone + PartialEq + std::fmt::Debug>(
        classifier: &KeywordClassifier<T>,
        cases: &[(&str, Option<T>)],
    ) {
        for (input, expected) in cases {
            assert_eq!(
                classifier.classify(input),
                expected.as_ref(),
                "classify({:?})",
                input
            );
        }
    }

    #[test]
    fn basic_classification() {
        let classifier: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["blocked", "403", "captcha"], "blocked"),
            (&["timeout"], "transient"),
        ]);
        assert_cases(
            &classifier,
            &[
                ("Error 403 Forbidden", Some("blocked")),
                ("Request timed out: timeout", Some("transient")),
                ("success", None),
            ],
        );
    }

    #[test]
    fn case_insensitive() {
        let classifier: KeywordClassifier<&str> =
            KeywordClassifier::new(&[(&["captcha"], "blocked")]);
        assert_cases(
            &classifier,
            &[
                ("CAPTCHA detected", Some("blocked")),
                ("CaPtChA", Some("blocked")),
                ("captcha", Some("blocked")),
            ],
        );
    }

    #[test]
    fn first_rule_wins() {
        let classifier: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["timeout"], "blocked"),
            (&["timeout"], "transient"),
        ]);
        assert_cases(&classifier, &[("timeout error", Some("blocked"))]);
    }

    #[test]
    fn overlapping_patterns() {
        let classifier: KeywordClassifier<&str> = KeywordClassifier::new(&[
            (&["bot detect", "bot protection"], "blocked"),
            (&["err_connection_reset", "err_connection_closed"], "transient"),
        ]);
        assert_cases(
            &classifier,
            &[
                ("Detected bot detection script", Some("blocked")),
                ("net::ERR_CONNECTION_RESET", Some("transient")),
            ],
        );
    }

    #[test]
    fn no_match_returns_none() {
        let classifier: KeywordClassifier<&str> =
            KeywordClassifier::new(&[(&["foo"], "a"), (&["bar"], "b")]);
        assert_cases(&classifier, &[("baz qux", None), ("", None)]);
    }

    #[test]
    fn substring_matching() {
        let classifier: KeywordClassifier<&str> =
            KeywordClassifier::new(&[(&["403"], "blocked")]);
        assert_cases(&classifier, &[("HTTP/1.1 403 Forbidden", Some("blocked"))]);
    }

    #[test]
    fn multiple_keywords_same_rule() {
        let classifier: KeywordClassifier<i32> = KeywordClassifier::new(&[
            (&["alpha", "beta", "gamma"], 1),
            (&["delta", "epsilon"], 2),
        ]);
        assert_cases(
            &classifier,
            &[
                ("testing beta value", Some(1)),
                ("epsilon result", Some(2)),
                ("zeta", None),
            ],
        );
    }

    #[test]
    fn aho_corasick_shared_prefix() {
        let nested: KeywordClassifier<&str> =
            KeywordClassifier::new(&[(&["abcde"], "first"), (&["bcd"], "second")]);
        assert_cases(
            &nested,
            &[("xxbcdxx", Some("second")), ("abcde", Some("second"))],
        );

        let disjoint: KeywordClassifier<&str> =
            KeywordClassifier::new(&[(&["xyz"], "first"), (&["abc"], "second")]);
        assert_cases(
            &disjoint,
            &[("xxxyzxx", Some("first")), ("xxabcxx", Some("second"))],
        );
    }

    #[test]
    fn real_world_error_messages() {
        #[derive(Clone, Debug, PartialEq)]
        enum ErrorClass {
            Blocked,
            Auth,
            BackendDown,
            Transient,
        }

        let classifier: KeywordClassifier<ErrorClass> = KeywordClassifier::new(&[
            (
                &[
                    "bot detect",
                    "blocked",
                    "403",
                    "captcha",
                    "checking your browser",
                    "access denied",
                ],
                ErrorClass::Blocked,
            ),
            (&["401", "unauthorized"], ErrorClass::Auth),
            (
                &["backend unavailable", "503", "service unavailable"],
                ErrorClass::BackendDown,
            ),
            (
                &["err_connection_reset", "timeout", "websocket closed"],
                ErrorClass::Transient,
            ),
        ]);

        assert_cases(
            &classifier,
            &[
                ("Error: 403 Forbidden - Access Denied", Some(ErrorClass::Blocked)),
                ("HTTP 401 Unauthorized", Some(ErrorClass::Auth)),
                ("503 Service Temporarily Unavailable", Some(ErrorClass::BackendDown)),
                ("net::ERR_CONNECTION_RESET at navigation", Some(ErrorClass::Transient)),
                ("Page loaded successfully", None),
            ],
        );
    }
}