use std::collections::HashMap;
/// A single node of the trie; each node owns its entire subtree through `children`.
#[derive(Debug)]
pub struct TrieNode<T> {
/// Child nodes, keyed by the next character of a key.
children: HashMap<char, TrieNode<T>>,
/// Value written by `Trie::insert` on the node where a key ends; `None` on a fresh node.
value: Option<T>,
/// Set to `true` by `Trie::insert` on the node where a key ends; lookups check this
/// flag before they read `value`.
is_end: bool,
}
impl<T> Default for TrieNode<T> {
fn default() -> Self {
Self {
children: HashMap::new(),
value: None,
is_end: false,
}
}
}
/// A trie keyed by the characters of a string, holding one value of type `T` per inserted key
/// (inserting the same key again overwrites the stored value).
#[derive(Debug)]
pub struct Trie<T> {
/// Root node; it stands for the empty prefix and every insert or lookup starts here.
root: TrieNode<T>,
}
impl<T> Default for Trie<T> {
fn default() -> Self {
Self {
root: TrieNode::default(),
}
}
}
impl<T: Clone> Trie<T> {
pub fn new() -> Self {
Self {
root: TrieNode::default(),
}
}
pub fn insert(&mut self, word: &str, value: T) {
let mut node = &mut self.root;
for ch in word.chars() {
node = node.children.entry(ch).or_default();
}
node.is_end = true;
node.value = Some(value);
}
#[allow(dead_code)]
pub fn get(&self, word: &str) -> Option<&T> {
let mut node = &self.root;
for ch in word.chars() {
match node.children.get(&ch) {
Some(n) => node = n,
None => return None,
}
}
if node.is_end {
node.value.as_ref()
} else {
None
}
}
#[allow(dead_code)]
pub fn contains(&self, word: &str) -> bool {
self.get(word).is_some()
}
pub fn find_longest_prefix<'a>(&self, text: &'a str) -> Option<(&'a str, &T, usize)> {
let mut node = &self.root;
let mut last_match: Option<(&'a str, &T, usize)> = None;
let mut current_len = 0;
for ch in text.chars() {
match node.children.get(&ch) {
Some(n) => {
node = n;
current_len += ch.len_utf8();
if node.is_end {
if let Some(ref value) = node.value {
last_match = Some((&text[..current_len], value, current_len));
}
}
}
None => break,
}
}
last_match
}
#[allow(dead_code)]
pub fn find_all<'a>(&self, text: &'a str) -> Vec<(usize, &'a str, &T)> {
let mut results = Vec::new();
let chars: Vec<char> = text.chars().collect();
let mut byte_pos = 0;
for (i, _) in chars.iter().enumerate() {
let remaining = &text[byte_pos..];
if let Some((matched, value, _)) = self.find_longest_prefix(remaining) {
results.push((byte_pos, matched, value));
}
byte_pos += chars[i].len_utf8();
}
results
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a trie by inserting the given `(key, value)` pairs in slice order.
    fn trie_of(entries: &[(&str, i32)]) -> Trie<i32> {
        let mut trie = Trie::new();
        for &(key, value) in entries {
            trie.insert(key, value);
        }
        trie
    }

    /// Exact-key lookups succeed for inserted keys and fail for bare prefixes and unknown keys.
    #[test]
    fn test_trie_insert_and_get() {
        let trie = trie_of(&[("广东省", 1), ("广东", 2), ("广州市", 3)]);

        // Every inserted key resolves to its own value.
        assert_eq!(trie.get("广东省"), Some(&1));
        assert_eq!(trie.get("广东"), Some(&2));
        assert_eq!(trie.get("广州市"), Some(&3));

        // A shared prefix that was never inserted, and a key with no path at all.
        assert_eq!(trie.get("广"), None);
        assert_eq!(trie.get("北京"), None);
    }

    /// When one key is a prefix of another, the longer key wins.
    #[test]
    fn test_find_longest_prefix() {
        let trie = trie_of(&[("广东", 1), ("广东省", 2)]);

        let result = trie.find_longest_prefix("广东省深圳市");
        assert!(result.is_some());

        let (matched, value, len) = result.unwrap();
        assert_eq!(matched, "广东省");
        assert_eq!(*value, 2);
        assert_eq!(len, "广东省".len());
    }

    /// Scanning a text reports every inserted key, in order of appearance.
    #[test]
    fn test_find_all() {
        let trie = trie_of(&[("广东省", 1), ("深圳市", 2), ("南山区", 3)]);

        let results = trie.find_all("广东省深圳市南山区科技园");
        let matched: Vec<&str> = results.iter().map(|hit| hit.1).collect();

        assert_eq!(matched, ["广东省", "深圳市", "南山区"]);
    }
}