use std::collections::HashMap;
/// One node of a [`Trie`]: its outgoing edges keyed by the next character,
/// plus a flag recording whether an inserted word ends exactly at this node.
#[derive(Debug, Default)]
struct TrieNode {
    /// Child nodes, one per distinct next character.
    children: HashMap<char, TrieNode>,
    /// `true` if some inserted word ends exactly here.
    is_end: bool,
}

/// A prefix tree keyed by `char` that supports inserting words, exact-word
/// lookup, and enumerating every stored word that shares a given prefix.
#[derive(Debug, Default)]
pub struct Trie {
    root: TrieNode,
}

impl Trie {
    /// Creates an empty trie (equivalent to [`Trie::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts `word`, creating any missing nodes along its path.
    ///
    /// Re-inserting an existing word has no further effect. Inserting `""`
    /// marks the root itself as a word end, after which `search("")` is `true`.
    pub fn insert(&mut self, word: &str) {
        let mut node = &mut self.root;
        for ch in word.chars() {
            node = node.children.entry(ch).or_default();
        }
        node.is_end = true;
    }

    /// Returns `true` if `word` was inserted as a whole word.
    ///
    /// A mere prefix of an inserted word does not count: after inserting
    /// `"drone"`, `search("dron")` is `false`.
    #[must_use]
    pub fn search(&self, word: &str) -> bool {
        matches!(self.find_node(word), Some(node) if node.is_end)
    }

    /// Returns every inserted word that begins with `prefix`, including
    /// `prefix` itself if it was inserted as a word.
    ///
    /// Note that, unlike the common `starts_with -> bool` convention, this
    /// returns the matching words; an empty `Vec` means no stored word has
    /// this prefix. The order of the results is unspecified (it follows
    /// `HashMap` iteration order), so sort them if a stable order is needed.
    #[must_use]
    pub fn starts_with(&self, prefix: &str) -> Vec<String> {
        let mut results = Vec::new();
        if let Some(node) = self.find_node(prefix) {
            let mut buf = prefix.to_string();
            Self::collect_words(node, &mut buf, &mut results);
        }
        results
    }

    /// Walks from the root along the characters of `prefix` and returns the
    /// node reached, or `None` as soon as a character has no matching edge.
    fn find_node(&self, prefix: &str) -> Option<&TrieNode> {
        prefix
            .chars()
            .try_fold(&self.root, |node, ch| node.children.get(&ch))
    }

    /// Depth-first walk that appends every word in the subtree rooted at
    /// `node` to `results`.
    ///
    /// On entry `buf` holds the characters on the path from the root to
    /// `node`. Each child's character is pushed before recursing and popped
    /// afterwards, so `buf` is unchanged when this returns. Recursion depth
    /// equals the length (in chars) of the longest word below `node`.
    fn collect_words(node: &TrieNode, buf: &mut String, results: &mut Vec<String>) {
        if node.is_end {
            results.push(buf.clone());
        }
        for (&ch, child) in &node.children {
            buf.push(ch);
            Self::collect_words(child, buf, results);
            buf.pop();
        }
    }
}
/// Returns the start position of every occurrence of `pattern` in `text`,
/// found with the Rabin–Karp rolling-hash algorithm.
///
/// * Positions are **`char` indices** (counted over [`str::chars`]), not byte
///   offsets, so they cannot be used directly to slice `text` when it contains
///   non-ASCII characters. For ASCII input the two coincide.
/// * Overlapping occurrences are all reported, in ascending order
///   (`"aaaa"` searched for `"aa"` gives `[0, 1, 2]`).
/// * An empty `pattern`, or one longer than `text`, yields no matches.
///
/// Every hash hit is confirmed by a direct character comparison, so hash
/// collisions never produce false positives. Only the current window of
/// `pattern.chars().count()` characters is buffered; `text` is streamed.
#[must_use]
pub fn rabin_karp(text: &str, pattern: &str) -> Vec<usize> {
    const BASE: u64 = 31;
    const MOD: u64 = 1_000_000_007;

    let pat: Vec<char> = pattern.chars().collect();
    let m = pat.len();
    let mut matches = Vec::new();
    if m == 0 {
        return matches;
    }

    // Fill the first window straight from the char iterator so `text` is
    // never collected in full; bail out if it has fewer than `m` chars.
    let mut chars = text.chars();
    let mut window: std::collections::VecDeque<char> = chars.by_ref().take(m).collect();
    if window.len() < m {
        return matches;
    }

    // Map each char to a digit in 1..=0x11_0000; the +1 keeps '\0' from
    // contributing nothing to the hash. Digits are < 2^21 and residues are
    // < 2^30, so every intermediate product below stays far under u64::MAX
    // and plain (non-wrapping) arithmetic is exact.
    let digit = |c: char| u64::from(c) + 1;
    // Horner's rule: hash(s) = sum of digit(s[i]) * BASE^(m-1-i), mod MOD.
    let push = |hash: u64, &c: &char| (hash * BASE + digit(c)) % MOD;
    let pat_hash = pat.iter().fold(0u64, push);
    let mut win_hash = window.iter().fold(0u64, push);
    // Weight of the leftmost char of a window: BASE^(m-1) mod MOD.
    let lead_weight = (1..m).fold(1u64, |w, _| w * BASE % MOD);

    if win_hash == pat_hash && window.iter().eq(&pat) {
        matches.push(0);
    }

    for (start, incoming) in (1..).zip(chars) {
        let outgoing = window
            .pop_front()
            .expect("window always holds exactly m >= 1 chars");
        window.push_back(incoming);
        // Remove the outgoing char's contribution (adding MOD keeps the
        // subtraction non-negative), shift by one digit, append the new char.
        win_hash = (win_hash + MOD - (digit(outgoing) * lead_weight) % MOD) % MOD;
        win_hash = (win_hash * BASE + digit(incoming)) % MOD;
        if win_hash == pat_hash && window.iter().eq(&pat) {
            matches.push(start);
        }
    }
    matches
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference implementation: compares every char window of `text` with
    /// `pattern` directly. An empty pattern yields no matches, like `rabin_karp`.
    fn naive_find(text: &str, pattern: &str) -> Vec<usize> {
        let t: Vec<char> = text.chars().collect();
        let p: Vec<char> = pattern.chars().collect();
        if p.is_empty() {
            return Vec::new();
        }
        t.windows(p.len())
            .enumerate()
            .filter(|(_, w)| *w == p.as_slice())
            .map(|(i, _)| i)
            .collect()
    }

    /// Deterministic pseudo-random string over {'a', 'b'} whose length is in
    /// `0..=max_len`. It advances `state` with a 64-bit LCG.
    fn lcg_string(state: &mut u64, max_len: u64) -> String {
        let mut step = || {
            *state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            *state >> 33
        };
        let len = step() % (max_len + 1);
        (0..len)
            .map(|_| if step() % 2 == 0 { 'a' } else { 'b' })
            .collect()
    }

    #[test]
    fn trie_insert_search() {
        let mut t = Trie::new();
        t.insert("drone");
        t.insert("droning");
        assert!(t.search("drone"));
        assert!(!t.search("dron"));
        assert!(t.search("droning"));
    }

    #[test]
    fn trie_starts_with() {
        let mut t = Trie::new();
        for w in &["alert", "alerting", "alarm", "base"] {
            t.insert(w);
        }
        let mut r = t.starts_with("al");
        r.sort();
        assert_eq!(r, vec!["alarm", "alert", "alerting"]);
    }

    #[test]
    fn trie_edge_cases() {
        let mut t = Trie::new();
        assert!(!t.search(""));
        assert!(t.starts_with("").is_empty());
        t.insert("base");
        assert!(!t.search("bases"));
        assert!(t.starts_with("x").is_empty());
        assert_eq!(t.starts_with("base"), vec!["base"]);
        t.insert("");
        assert!(t.search(""));
    }

    #[test]
    fn rabin_karp_multiple() {
        assert_eq!(rabin_karp("ababab", "ab"), vec![0, 2, 4]);
    }

    #[test]
    fn rabin_karp_single() {
        assert_eq!(rabin_karp("hello world", "world"), vec![6]);
    }

    #[test]
    fn rabin_karp_no_match() {
        assert!(rabin_karp("hello", "xyz").is_empty());
    }

    #[test]
    fn rabin_karp_overlapping() {
        assert_eq!(rabin_karp("aaaa", "aa"), vec![0, 1, 2]);
    }

    #[test]
    fn rabin_karp_empty_or_too_long_pattern() {
        assert!(rabin_karp("abc", "").is_empty());
        assert!(rabin_karp("", "a").is_empty());
        assert!(rabin_karp("ab", "abc").is_empty());
        assert_eq!(rabin_karp("abc", "abc"), vec![0]);
    }

    #[test]
    fn rabin_karp_reports_char_indices() {
        // 'é' is two bytes in UTF-8 but one char, so these are char indices.
        assert_eq!(rabin_karp("héllo héllo", "llo"), vec![2, 8]);
        assert_eq!(rabin_karp("\0a\0a", "\0a"), vec![0, 2]);
    }

    #[test]
    fn rabin_karp_agrees_with_naive_search() {
        let mut state = 0x2545_f491_4f6c_dd1d_u64;
        for _ in 0..500 {
            let text = lcg_string(&mut state, 24);
            let pattern = lcg_string(&mut state, 4);
            assert_eq!(
                rabin_karp(&text, &pattern),
                naive_find(&text, &pattern),
                "text = {:?}, pattern = {:?}",
                text,
                pattern
            );
        }
    }
}