use crate::error::LexError;
use crate::node::Node;
use crate::utils::{read_lines_from_file, validate_expression};
/// Prefix tree over `char`s whose nodes are kept in one flat vector and refer to
/// their children by index into that vector.
pub struct Trie {
/// Node storage. `new` places the root at index 0 and `add` appends every newly
/// created node at the end, so a node's position is the id it was created with.
nodes: Vec<Node>,
/// Running total of the `count` values passed to `add` for non-empty words;
/// exposed through `word_count`.
num_of_words: usize,
}
impl Trie {
    /// Creates an empty trie holding only the root node.
    ///
    /// The root is built with `Node::new(0, '\0')` and stays at index 0 of the node
    /// vector; every traversal in this type starts there.
    pub fn new() -> Self {
        Trie {
            nodes: vec![Node::new(0, '\0')],
            num_of_words: 0,
        }
    }

    /// Returns the number of nodes in the trie, the root included.
    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the running word total, i.e. the sum of every `count` given to
    /// [`Trie::add`] for a non-empty word (repeated insertions are counted again).
    pub fn word_count(&self) -> usize {
        self.num_of_words
    }

    /// Returns `true` when `word` ends on a node flagged as end-of-word.
    ///
    /// The empty string is always reported as present, even on an empty trie.
    pub fn contains(&self, word: &str) -> bool {
        if word.is_empty() {
            return true;
        }
        match self.prefix_node(word) {
            Some(node_idx) => self.nodes[node_idx].eow,
            None => false,
        }
    }

    /// Returns `true` when some path from the root spells `prefix`.
    ///
    /// The empty prefix always matches because it resolves to the root.
    pub fn contains_prefix(&self, prefix: &str) -> bool {
        self.prefix_node(prefix).is_some()
    }

    /// Walks `prefix` from the root and returns the index of the node it ends on,
    /// or `None` as soon as a character has no matching child.
    ///
    /// An empty `prefix` yields `Some(0)` (the root).
    fn prefix_node(&self, prefix: &str) -> Option<usize> {
        let mut node_idx = 0;
        for ch in prefix.chars() {
            node_idx = *self.nodes[node_idx].children.get(&ch)?;
        }
        Some(node_idx)
    }

    /// Inserts `word`, creating any missing nodes along its path.
    ///
    /// The final node is flagged as end-of-word, `count` is added to that node's
    /// counter and to the trie-wide word total. An empty `word` changes nothing.
    ///
    /// # Errors
    ///
    /// The body shown here never produces an `Err`; the `Result` is part of the
    /// existing interface.
    pub fn add(&mut self, word: &str, count: usize) -> Result<(), LexError> {
        if word.is_empty() {
            return Ok(());
        }
        let mut node_idx = 0;
        for ch in word.chars() {
            // `copied()` ends the shared borrow of `self.nodes` before the `None`
            // arm needs to push a new node.
            node_idx = match self.nodes[node_idx].children.get(&ch).copied() {
                Some(existing) => existing,
                None => {
                    // New nodes are appended, so the id doubles as the vector index.
                    let new_id = self.nodes.len();
                    self.nodes.push(Node::new(new_id, ch));
                    self.nodes[node_idx].children.insert(ch, new_id);
                    new_id
                }
            };
        }
        let last = &mut self.nodes[node_idx];
        last.eow = true;
        last.count += count;
        self.num_of_words += count;
        Ok(())
    }

    /// Inserts every word from `words` with a count of 1, stopping at the first error.
    ///
    /// # Errors
    ///
    /// Returns whatever [`Trie::add`] returns for the failing word.
    pub fn add_all<I: IntoIterator<Item = String>>(&mut self, words: I) -> Result<(), LexError> {
        words.into_iter().try_for_each(|word| self.add(&word, 1))
    }

    /// Inserts every entry produced by `read_lines_from_file(path)` with a count of 1.
    ///
    /// # Errors
    ///
    /// Propagates the error from `read_lines_from_file`, or from [`Trie::add`].
    pub fn add_from_file(&mut self, path: &str) -> Result<(), LexError> {
        for word in read_lines_from_file(path)? {
            self.add(&word, 1)?;
        }
        Ok(())
    }

    /// Returns the stored words matching the wildcard `pattern`.
    ///
    /// Same matching as [`Trie::search_with_count`], with the counts dropped.
    pub fn search(&self, pattern: &str) -> Result<Vec<String>, LexError> {
        self.search_with_count(pattern).map(Self::strip_counts)
    }

    /// Returns `(word, count)` for every stored word matching the wildcard `pattern`.
    ///
    /// The pattern is first passed through `validate_expression`; the result is then
    /// matched from the root by `words_with_wildcard`, where `?` consumes exactly one
    /// character and `*` consumes zero or more. An empty `pattern` yields no results.
    ///
    /// # Errors
    ///
    /// The body shown here never produces an `Err`; the `Result` is part of the
    /// existing interface.
    pub fn search_with_count(&self, pattern: &str) -> Result<Vec<(String, usize)>, LexError> {
        let mut results = Vec::new();
        if pattern.is_empty() {
            return Ok(results);
        }
        let pattern = validate_expression(pattern);
        let pat_chars: Vec<char> = pattern.chars().collect();
        let mut current = String::new();
        words_with_wildcard(&self.nodes, 0, &pat_chars, 0, &mut current, &mut results);
        Ok(results)
    }

    /// Returns every stored word that starts with `prefix` (the prefix itself included
    /// when it is a stored word). An empty `prefix` yields no results.
    pub fn search_with_prefix(&self, prefix: &str) -> Vec<String> {
        Self::strip_counts(self.search_with_prefix_count(prefix))
    }

    /// Returns `(word, count)` for every stored word that starts with `prefix`.
    ///
    /// The prefix is resolved to its node, then a lone `*` pattern is expanded from
    /// there with `prefix` as the already-matched text. An empty or unknown `prefix`
    /// yields no results.
    pub fn search_with_prefix_count(&self, prefix: &str) -> Vec<(String, usize)> {
        let mut results = Vec::new();
        if prefix.is_empty() {
            return results;
        }
        if let Some(node_idx) = self.prefix_node(prefix) {
            let mut current = prefix.to_string();
            words_with_wildcard(&self.nodes, node_idx, &['*'], 0, &mut current, &mut results);
        }
        results
    }

    /// Returns every stored word whose edit distance to `word` is at most `dist`.
    ///
    /// Same search as [`Trie::search_within_distance_count`], with the counts dropped.
    pub fn search_within_distance(&self, word: &str, dist: usize) -> Vec<String> {
        Self::strip_counts(self.search_within_distance_count(word, dist))
    }

    /// Returns `(word, count)` for every stored word whose edit distance to `word`
    /// (insertions, deletions and substitutions each cost 1) is at most `dist`.
    ///
    /// The first distance row `0..=len(word)` is seeded here and each child of the
    /// root is explored by `search_within_distance_inner`.
    pub fn search_within_distance_count(&self, word: &str, dist: usize) -> Vec<(String, usize)> {
        let target: Vec<char> = word.chars().collect();
        let row: Vec<usize> = (0..=target.len()).collect();
        let mut results = Vec::new();
        for (&ch, &child_idx) in &self.nodes[0].children {
            let mut current_word = ch.to_string();
            search_within_distance_inner(
                &self.nodes,
                child_idx,
                &target,
                ch,
                &mut current_word,
                &row,
                dist,
                &mut results,
            );
        }
        results
    }

    /// Drops the counts from `(word, count)` pairs, keeping the words in order.
    fn strip_counts(pairs: Vec<(String, usize)>) -> Vec<String> {
        pairs.into_iter().map(|(word, _)| word).collect()
    }
}
impl Default for Trie {
fn default() -> Self {
Self::new()
}
}
/// Collects every word reachable from `nodes[node_idx]` that matches
/// `pattern[index..]`, appending `(word, count)` pairs to `results`.
///
/// Pattern symbols:
/// * `?` consumes exactly one character (any child),
/// * `*` consumes zero or more characters,
/// * anything else must match a child character literally.
///
/// `current` holds the text matched so far; it is extended while descending and
/// restored before returning, so the caller gets it back unchanged. A node is
/// reported only when the whole pattern is consumed, the node is flagged `eow`,
/// and `current` is non-empty.
///
/// A run of consecutive `*` symbols is treated as a single `*`. Without that,
/// each extra `*` opens another path to the same match and the word is pushed
/// more than once. Other redundant combinations (for example `*?*`) are not
/// deduplicated here.
///
/// # Panics
///
/// Panics if `node_idx`, or a child index stored in a node, is out of bounds
/// for `nodes`.
pub(crate) fn words_with_wildcard(
    nodes: &[Node],
    node_idx: usize,
    pattern: &[char],
    index: usize,
    current: &mut String,
    results: &mut Vec<(String, usize)>,
) {
    let node = &nodes[node_idx];
    let symbol = match pattern.get(index) {
        Some(&symbol) => symbol,
        None => {
            // Pattern exhausted: this node is a match only if it ends a word.
            if node.eow && !current.is_empty() {
                results.push((current.clone(), node.count));
            }
            return;
        }
    };
    match symbol {
        '?' => {
            for (&ch, &child_idx) in &node.children {
                current.push(ch);
                words_with_wildcard(nodes, child_idx, pattern, index + 1, current, results);
                current.pop();
            }
        }
        '*' => {
            // Index of the first symbol after this run of `*` (or the pattern end).
            let after_stars = pattern[index..]
                .iter()
                .position(|&c| c != '*')
                .map_or(pattern.len(), |offset| index + offset);
            // The star run matches nothing: continue with the rest of the pattern here.
            words_with_wildcard(nodes, node_idx, pattern, after_stars, current, results);
            // The star run swallows one more character: stay on the same `*` in each child.
            for (&ch, &child_idx) in &node.children {
                current.push(ch);
                words_with_wildcard(nodes, child_idx, pattern, index, current, results);
                current.pop();
            }
        }
        literal => {
            if let Some(&child_idx) = node.children.get(&literal) {
                current.push(literal);
                words_with_wildcard(nodes, child_idx, pattern, index + 1, current, results);
                current.pop();
            }
        }
    }
}
/// Extends an edit-distance search by one trie node and recurses into its children.
///
/// `prev_row` is the distance row of the parent path against every prefix of
/// `target`; this call derives the row for the path extended by `letter` (the
/// character leading to `nodes[node_idx]`). Insertions, deletions and
/// substitutions each cost 1.
///
/// * If the node is flagged `eow` and the full-target distance (last cell of the
///   new row) is at most `dist`, `(current_word, count)` is appended to `results`.
/// * Children are explored only while some cell of the new row is still within
///   `dist`; otherwise the whole subtree is skipped.
///
/// `current_word` is extended while descending and restored before returning.
///
/// # Panics
///
/// Panics if `prev_row` has fewer than `target.len() + 1` entries, or if
/// `node_idx` (or a stored child index) is out of bounds for `nodes`.
pub(crate) fn search_within_distance_inner(
    nodes: &[Node],
    node_idx: usize,
    target: &[char],
    letter: char,
    current_word: &mut String,
    prev_row: &[usize],
    dist: usize,
    results: &mut Vec<(String, usize)>,
) {
    let mut curr_row: Vec<usize> = Vec::with_capacity(target.len() + 1);

    // `left` always holds the cell just written, i.e. the left neighbour of the next one.
    let mut left = prev_row[0] + 1;
    let mut smallest = left;
    curr_row.push(left);

    for (offset, &target_ch) in target.iter().enumerate() {
        let diagonal = prev_row[offset] + usize::from(target_ch != letter);
        let above = prev_row[offset + 1] + 1;
        left = (left + 1).min(above).min(diagonal);
        smallest = smallest.min(left);
        curr_row.push(left);
    }

    let node = &nodes[node_idx];

    // After the loop `left` is the last cell: the distance to the whole target.
    if node.eow && left <= dist {
        results.push((current_word.clone(), node.count));
    }

    // No cell within budget means no extension of this path can get back under `dist`.
    if smallest > dist {
        return;
    }

    for (&ch, &child_idx) in &node.children {
        current_word.push(ch);
        search_within_distance_inner(
            nodes,
            child_idx,
            target,
            ch,
            current_word,
            &curr_row,
            dist,
            results,
        );
        current_word.pop();
    }
}