use std::collections::HashMap;
use std::hash::Hash;
use std::ops::{Deref, DerefMut};
/// A single node of a [`Trie`].
///
/// A node owns its child nodes, keyed by the next element of a sequence, and
/// optionally holds the value of the sequence that ends at this node.
#[derive(Clone, Debug)]
pub struct TrieNode<K: Eq + Hash + Clone, V> {
/// Child nodes, keyed by the element that leads to them.
children: HashMap<K, TrieNode<K, V>>,
/// Value stored for the sequence that ends exactly at this node, if any.
terminal: Option<V>,
}
impl<K: Eq + Hash + Clone, V> Default for TrieNode<K, V> {
fn default() -> Self {
Self {
children: HashMap::new(),
terminal: None,
}
}
}
impl<K: Eq + Hash + Clone, V> TrieNode<K, V> {
pub fn new() -> Self {
Self::default()
}
pub fn is_empty(&self) -> bool {
self.children.is_empty()
}
pub fn num_children(&self) -> usize {
self.children.len()
}
pub fn terminal(&self) -> Option<&V> {
self.terminal.as_ref()
}
pub fn get_child(&self, key: &K) -> Option<&TrieNode<K, V>> {
self.children.get(key)
}
}
/// A trie (prefix tree) that maps sequences of `K` elements to values of type `V`.
#[derive(Clone, Debug)]
pub struct Trie<K: Eq + Hash + Clone, V> {
/// Root node; it corresponds to the empty sequence.
root: TrieNode<K, V>,
/// Number of sequences that currently have a stored value.
len: usize,
}
impl<K: Eq + Hash + Clone, V> Default for Trie<K, V> {
fn default() -> Self {
Self::new()
}
}
impl<K: Eq + Hash + Clone, V> Trie<K, V> {
pub fn new() -> Self {
Self {
root: TrieNode::new(),
len: 0,
}
}
pub fn insert<I>(&mut self, sequence: I, value: V) -> Option<V>
where
I: IntoIterator<Item = K>,
{
let mut node = &mut self.root;
for element in sequence {
node = node.children.entry(element).or_default();
}
let old = node.terminal.take();
node.terminal = Some(value);
if old.is_none() {
self.len += 1;
}
old
}
pub fn get<I>(&self, sequence: I) -> Option<&V>
where
I: IntoIterator<Item = K>,
{
let mut node = &self.root;
for element in sequence {
match node.children.get(&element) {
Some(child) => node = child,
None => return None,
}
}
node.terminal.as_ref()
}
pub fn contains<I>(&self, sequence: I) -> bool
where
I: IntoIterator<Item = K>,
{
self.get(sequence).is_some()
}
pub fn has_prefix<I>(&self, prefix: I) -> bool
where
I: IntoIterator<Item = K>,
{
self.get_node(prefix).is_some()
}
pub fn get_node<I>(&self, sequence: I) -> Option<&TrieNode<K, V>>
where
I: IntoIterator<Item = K>,
{
let mut node = &self.root;
for element in sequence {
match node.children.get(&element) {
Some(child) => node = child,
None => return None,
}
}
Some(node)
}
pub fn children<I>(&self, context: I) -> Vec<(K, &V)>
where
I: IntoIterator<Item = K>,
{
match self.get_node(context) {
Some(node) => node
.children
.iter()
.filter_map(|(k, v)| v.terminal.as_ref().map(|t| (k.clone(), t)))
.collect(),
None => Vec::new(),
}
}
pub fn len(&self) -> usize {
self.len
}
pub fn is_empty(&self) -> bool {
self.len == 0
}
pub fn clear(&mut self) {
self.root = TrieNode::new();
self.len = 0;
}
}
/// Set-like helpers for tries that store no payload (`V = ()`).
impl<K: Eq + Hash + Clone> Trie<K, ()> {
    /// Adds `sequence` to the set.
    ///
    /// Returns `true` if the sequence was not present before, `false` if it
    /// was already stored.
    pub fn insert_seq<I>(&mut self, sequence: I) -> bool
    where
        I: IntoIterator<Item = K>,
    {
        self.insert(sequence, ()).is_none()
    }

    /// Returns the length of the longest stored sequence that is a prefix of
    /// `elements`, looking at no more than `max_len` elements.
    ///
    /// Returns `0` when no stored sequence of at least one element matches.
    /// A stored empty sequence is never reported, since the result would
    /// also be `0`.
    pub fn longest_match(&self, elements: &[K], max_len: usize) -> usize {
        self.longest_match_in(elements.iter(), max_len)
    }

    /// Returns every stored sequence as an owned vector of keys.
    ///
    /// The order of the result follows the child maps' iteration order and
    /// is therefore unspecified.
    pub fn all_sequences(&self) -> Vec<Vec<K>> {
        let mut result = Vec::with_capacity(self.len());
        let mut prefix = Vec::new();
        Self::collect_sequences(&self.root, &mut prefix, &mut result);
        result
    }

    /// Depth-first walk that records `prefix` whenever `node` is terminal.
    ///
    /// `prefix` is used as a scratch stack: every push is undone before the
    /// function returns, so the caller sees it unchanged.
    fn collect_sequences(node: &TrieNode<K, ()>, prefix: &mut Vec<K>, result: &mut Vec<Vec<K>>) {
        if node.terminal.is_some() {
            result.push(prefix.clone());
        }
        for (key, child) in &node.children {
            prefix.push(key.clone());
            Self::collect_sequences(child, prefix, result);
            prefix.pop();
        }
    }

    /// Same as [`Trie::longest_match`], but reads the elements from an
    /// iterator instead of a slice.
    ///
    /// The iterator is consumed lazily: no intermediate buffer is allocated
    /// and nothing is pulled past the first element that has no matching
    /// child (or past `max_len` elements).
    pub fn longest_match_iter<I>(&self, sequence: I, max_len: usize) -> usize
    where
        I: IntoIterator<Item = K>,
    {
        self.longest_match_in(sequence.into_iter(), max_len)
    }

    /// Shared walk behind `longest_match` and `longest_match_iter`.
    ///
    /// Accepts owned keys as well as references to keys, so neither caller
    /// has to clone or buffer its input.
    fn longest_match_in<Q, I>(&self, elements: I, max_len: usize) -> usize
    where
        Q: std::borrow::Borrow<K>,
        I: Iterator<Item = Q>,
    {
        let mut node = &self.root;
        let mut longest = 0;
        for (index, element) in elements.take(max_len).enumerate() {
            let key = <Q as std::borrow::Borrow<K>>::borrow(&element);
            match node.children.get(key) {
                Some(child) => node = child,
                // No stored sequence continues with this element.
                None => break,
            }
            if node.terminal.is_some() {
                longest = index + 1;
            }
        }
        longest
    }
}
/// A trie that keeps a `u64` counter per sequence.
///
/// It wraps a `Trie<K, u64>` in which each terminal value is the counter of
/// the sequence ending at that node.
#[derive(Clone, Debug)]
pub struct CountTrie<K: Eq + Hash + Clone> {
/// Underlying trie holding the per-sequence counters.
inner: Trie<K, u64>,
}
impl<K: Eq + Hash + Clone> Default for CountTrie<K> {
fn default() -> Self {
Self::new()
}
}
/// Gives shared access to the wrapped `Trie<K, u64>`, so its `&self`
/// methods can be called directly on a `CountTrie`.
impl<K: Eq + Hash + Clone> Deref for CountTrie<K> {
    type Target = Trie<K, u64>;

    /// Borrows the wrapped trie.
    fn deref(&self) -> &Trie<K, u64> {
        &self.inner
    }
}
impl<K: Eq + Hash + Clone> DerefMut for CountTrie<K> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
impl<K: Eq + Hash + Clone> CountTrie<K> {
pub fn new() -> Self {
Self { inner: Trie::new() }
}
pub fn increment<I>(&mut self, sequence: I)
where
I: IntoIterator<Item = K>,
{
let inner = &mut self.inner;
let mut node = &mut inner.root;
for element in sequence {
node = node.children.entry(element).or_default();
}
match &mut node.terminal {
Some(count) => *count += 1,
None => {
node.terminal = Some(1);
inner.len += 1;
}
}
}
pub fn insert_count<I>(&mut self, sequence: I, count: u64)
where
I: IntoIterator<Item = K>,
{
let inner = &mut self.inner;
let mut node = &mut inner.root;
for element in sequence {
node = node.children.entry(element).or_default();
}
if node.terminal.is_none() {
inner.len += 1;
}
node.terminal = Some(count);
}
pub fn get_count<I>(&self, sequence: I) -> u64
where
I: IntoIterator<Item = K>,
{
self.inner.get(sequence).copied().unwrap_or(0)
}
pub fn children_count_sum<I>(&self, context: I) -> u64
where
I: IntoIterator<Item = K>,
{
match self.inner.get_node(context) {
Some(node) => node
.children
.values()
.filter_map(|child| child.terminal)
.sum(),
None => 0,
}
}
pub fn children_with_counts<I>(&self, context: I) -> Vec<(K, u64)>
where
I: IntoIterator<Item = K>,
{
match self.inner.get_node(context) {
Some(node) => node
.children
.iter()
.filter_map(|(k, v)| v.terminal.map(|count| (k.clone(), count)))
.collect(),
None => Vec::new(),
}
}
pub fn all_counts(&self) -> Vec<(Vec<K>, u64)> {
let mut result = Vec::with_capacity(self.len());
let mut prefix = Vec::new();
Self::collect_counts(&self.inner.root, &mut prefix, &mut result);
result
}
pub fn total_count(&self) -> u64 {
Self::sum_counts(&self.inner.root)
}
fn collect_counts(
node: &TrieNode<K, u64>,
prefix: &mut Vec<K>,
result: &mut Vec<(Vec<K>, u64)>,
) {
if let Some(count) = node.terminal {
result.push((prefix.clone(), count));
}
for (key, child) in &node.children {
prefix.push(key.clone());
Self::collect_counts(child, prefix, result);
prefix.pop();
}
}
fn sum_counts(node: &TrieNode<K, u64>) -> u64 {
let mut total = node.terminal.unwrap_or(0);
for child in node.children.values() {
total += Self::sum_counts(child);
}
total
}
}
// Unit tests for `Trie` (used as a set via `Trie<_, ()>`) and `CountTrie`.
#[cfg(test)]
mod tests {
use super::*;
// A freshly created trie reports itself as empty with length zero.
#[test]
fn test_new_trie_is_empty() {
let trie: Trie<char, ()> = Trie::new();
assert!(trie.is_empty());
assert_eq!(trie.len(), 0);
}
// `insert_seq` reports whether the sequence is new; `contains` matches
// whole sequences only, not prefixes or extensions.
#[test]
fn test_insert_and_contains() {
let mut trie: Trie<char, ()> = Trie::new();
assert!(trie.insert_seq("hello".chars()));
assert!(trie.insert_seq("world".chars()));
assert!(!trie.insert_seq("hello".chars()));
assert!(trie.contains("hello".chars()));
assert!(trie.contains("world".chars()));
assert!(!trie.contains("hell".chars()));
assert!(!trie.contains("hello!".chars()));
assert_eq!(trie.len(), 2);
}
// `has_prefix` is true for any existing path, including complete sequences.
#[test]
fn test_has_prefix() {
let mut trie: Trie<char, ()> = Trie::new();
trie.insert_seq("hello".chars());
trie.insert_seq("help".chars());
assert!(trie.has_prefix("hel".chars()));
assert!(trie.has_prefix("hello".chars()));
assert!(trie.has_prefix("help".chars()));
assert!(!trie.has_prefix("hex".chars()));
assert!(!trie.has_prefix("world".chars()));
}
// `longest_match` returns the length of the longest stored sequence that
// prefixes the input, or 0 when nothing matches.
#[test]
fn test_longest_match() {
let mut trie: Trie<char, ()> = Trie::new();
trie.insert_seq("he".chars());
trie.insert_seq("hello".chars());
trie.insert_seq("help".chars());
let chars: Vec<char> = "helloworld".chars().collect();
assert_eq!(trie.longest_match(&chars, 10), 5);
let chars: Vec<char> = "helping".chars().collect();
assert_eq!(trie.longest_match(&chars, 10), 4);
let chars: Vec<char> = "hex".chars().collect();
assert_eq!(trie.longest_match(&chars, 10), 2);
let chars: Vec<char> = "world".chars().collect();
assert_eq!(trie.longest_match(&chars, 10), 0); }
// `max_len` caps how many input elements are examined: with a cap of 3
// the five-element entry "hello" cannot be reached.
#[test]
fn test_longest_match_with_max_len() {
let mut trie: Trie<char, ()> = Trie::new();
trie.insert_seq("hello".chars());
let chars: Vec<char> = "helloworld".chars().collect();
assert_eq!(trie.longest_match(&chars, 3), 0); assert_eq!(trie.longest_match(&chars, 5), 5); assert_eq!(trie.longest_match(&chars, 10), 5); }
// Keys are `char`s, so multi-byte characters count as single elements.
#[test]
fn test_unicode() {
let mut trie: Trie<char, ()> = Trie::new();
trie.insert_seq("你好".chars());
trie.insert_seq("世界".chars());
trie.insert_seq("你好世界".chars());
assert!(trie.contains("你好".chars()));
assert!(trie.contains("世界".chars()));
assert!(!trie.contains("你".chars()));
let chars: Vec<char> = "你好世界".chars().collect();
assert_eq!(trie.longest_match(&chars, 10), 4); }
// `clear` removes every stored sequence.
#[test]
fn test_clear() {
let mut trie: Trie<char, ()> = Trie::new();
trie.insert_seq("hello".chars());
trie.insert_seq("world".chars());
assert_eq!(trie.len(), 2);
trie.clear();
assert!(trie.is_empty());
assert!(!trie.contains("hello".chars()));
}
// The iterator-based variant agrees with the slice-based `longest_match`.
#[test]
fn test_longest_match_iter() {
let mut trie: Trie<char, ()> = Trie::new();
trie.insert_seq("hello".chars());
assert_eq!(trie.longest_match_iter("helloworld".chars(), 10), 5);
assert_eq!(trie.longest_match_iter("world".chars(), 10), 0);
}
// The trie is generic over the key type: here each key is a `&str` phoneme.
#[test]
fn test_phoneme_trie() {
let mut trie: Trie<&str, ()> = Trie::new();
let hello_phonemes = ["h", "ə", "l", "oʊ"];
let help_phonemes = ["h", "ɛ", "l", "p"];
let world_phonemes = ["w", "ɜː", "l", "d"];
trie.insert_seq(hello_phonemes.iter().copied());
trie.insert_seq(help_phonemes.iter().copied());
trie.insert_seq(world_phonemes.iter().copied());
assert!(trie.contains(hello_phonemes.iter().copied()));
assert!(trie.contains(help_phonemes.iter().copied()));
assert!(!trie.contains(["h", "ə"].iter().copied()));
let test_phonemes = ["h", "ə", "l", "oʊ", "w", "ɜː", "l", "d"];
assert_eq!(trie.longest_match(&test_phonemes, 10), 4); }
// Same checks with integer keys.
#[test]
fn test_integer_trie() {
let mut trie: Trie<u32, ()> = Trie::new();
trie.insert_seq([1, 2, 3].iter().copied());
trie.insert_seq([1, 2, 4].iter().copied());
trie.insert_seq([5, 6, 7].iter().copied());
assert!(trie.contains([1, 2, 3].iter().copied()));
assert!(!trie.contains([1, 2].iter().copied()));
let sequence = [1, 2, 3, 5, 6, 7];
assert_eq!(trie.longest_match(&sequence, 10), 3);
}
// A new `CountTrie` has a zero count even for the empty sequence.
#[test]
fn test_count_trie_new_is_empty() {
let trie: CountTrie<String> = CountTrie::new();
assert_eq!(trie.get_count(std::iter::empty::<String>()), 0);
}
// Each `increment` raises the sequence's count by one.
#[test]
fn test_count_trie_increment_and_get_count() {
let mut trie: CountTrie<String> = CountTrie::new();
let seq = || ["the", "cat"].iter().map(|s| s.to_string());
trie.increment(seq());
assert_eq!(trie.get_count(seq()), 1);
trie.increment(seq());
assert_eq!(trie.get_count(seq()), 2);
trie.increment(seq());
assert_eq!(trie.get_count(seq()), 3);
}
// Sequences that were never counted report zero.
#[test]
fn test_count_trie_get_count_missing() {
let trie: CountTrie<String> = CountTrie::new();
assert_eq!(
trie.get_count(["the", "cat"].iter().map(|s| s.to_string())),
0
);
}
// `children_count_sum` adds up direct children only: the root's child
// "the" has no count of its own, so the root-level sum is zero.
#[test]
fn test_count_trie_children_count_sum() {
let mut trie: CountTrie<String> = CountTrie::new();
trie.increment(["the", "cat"].iter().map(|s| s.to_string()));
trie.increment(["the", "cat"].iter().map(|s| s.to_string()));
trie.increment(["the", "dog"].iter().map(|s| s.to_string()));
assert_eq!(
trie.children_count_sum(std::iter::once("the".to_string())),
3
);
assert_eq!(trie.children_count_sum(std::iter::empty::<String>()), 0);
}
// `children_with_counts` returns pairs in unspecified order, so the test
// sorts them before comparing.
#[test]
fn test_count_trie_children_with_counts() {
let mut trie: CountTrie<String> = CountTrie::new();
trie.increment(["the", "cat"].iter().map(|s| s.to_string()));
trie.increment(["the", "cat"].iter().map(|s| s.to_string()));
trie.increment(["the", "dog"].iter().map(|s| s.to_string()));
let mut children = trie.children_with_counts(std::iter::once("the".to_string()));
children.sort_by(|a, b| a.0.cmp(&b.0));
assert_eq!(children.len(), 2);
assert_eq!(children[0], ("cat".to_string(), 2));
assert_eq!(children[1], ("dog".to_string(), 1));
}
// An unknown context yields no children.
#[test]
fn test_count_trie_children_missing_context() {
let trie: CountTrie<String> = CountTrie::new();
let children = trie.children_with_counts(std::iter::once("nonexistent".to_string()));
assert!(children.is_empty());
}
// A sequence and its extension keep independent counts.
#[test]
fn test_count_trie_overlapping_prefixes() {
let mut trie: CountTrie<String> = CountTrie::new();
trie.increment(std::iter::once("a".to_string()));
trie.increment(std::iter::once("a".to_string()));
trie.increment(["a", "b"].iter().map(|s| s.to_string()));
assert_eq!(trie.get_count(std::iter::once("a".to_string())), 2);
assert_eq!(trie.get_count(["a", "b"].iter().map(|s| s.to_string())), 1);
}
// `clear` (reached through `DerefMut`) resets all counts.
#[test]
fn test_count_trie_clear() {
let mut trie: CountTrie<String> = CountTrie::new();
trie.increment(std::iter::once("hello".to_string()));
assert_eq!(trie.get_count(std::iter::once("hello".to_string())), 1);
trie.clear();
assert_eq!(trie.get_count(std::iter::once("hello".to_string())), 0);
}
// `all_counts` on an empty trie returns nothing.
#[test]
fn test_count_trie_all_counts_empty() {
let trie: CountTrie<String> = CountTrie::new();
assert!(trie.all_counts().is_empty());
}
// `all_counts` lists every counted sequence once; results are sorted
// before comparison because their order is unspecified.
#[test]
fn test_count_trie_all_counts() {
let mut trie: CountTrie<String> = CountTrie::new();
trie.increment(["the", "cat"].iter().map(|s| s.to_string()));
trie.increment(["the", "cat"].iter().map(|s| s.to_string()));
trie.increment(["the", "dog"].iter().map(|s| s.to_string()));
trie.increment(std::iter::once("a".to_string()));
let mut counts = trie.all_counts();
counts.sort_by(|a, b| a.0.cmp(&b.0));
assert_eq!(counts.len(), 3);
assert_eq!(counts[0], (vec!["a".to_string()], 1));
assert_eq!(counts[1], (vec!["the".to_string(), "cat".to_string()], 2));
assert_eq!(counts[2], (vec!["the".to_string(), "dog".to_string()], 1));
}
// The number of entries from `all_counts` equals `len()` (via `Deref`).
#[test]
fn test_count_trie_all_counts_length_matches_len() {
let mut trie: CountTrie<String> = CountTrie::new();
trie.increment(["a", "b"].iter().map(|s| s.to_string()));
trie.increment(["a", "c"].iter().map(|s| s.to_string()));
trie.increment(std::iter::once("x".to_string()));
assert_eq!(trie.all_counts().len(), trie.len());
}
// `total_count` of an empty trie is zero.
#[test]
fn test_count_trie_total_count_empty() {
let trie: CountTrie<String> = CountTrie::new();
assert_eq!(trie.total_count(), 0);
}
// `total_count` sums the counts of all sequences (2 + 1 + 1).
#[test]
fn test_count_trie_total_count() {
let mut trie: CountTrie<String> = CountTrie::new();
trie.increment(["the", "cat"].iter().map(|s| s.to_string()));
trie.increment(["the", "cat"].iter().map(|s| s.to_string()));
trie.increment(["the", "dog"].iter().map(|s| s.to_string()));
trie.increment(std::iter::once("a".to_string()));
assert_eq!(trie.total_count(), 4);
}
}