use std::collections::{BTreeSet, HashMap, HashSet};
/// A single spelling suggestion returned by a lookup.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Suggestion {
    /// The suggested dictionary word.
    pub term: String,
    /// Frequency recorded for `term` in the dictionary.
    pub frequency: usize,
    /// Damerau-Levenshtein (optimal string alignment) distance from the query.
    pub distance: u8,
}
/// Controls how many suggestions a lookup returns.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Verbosity {
    /// Only the single best suggestion.
    Top,
    /// Every suggestion at the smallest distance found.
    Closest,
    /// Every suggestion within the allowed distance.
    All,
}
/// In-memory spelling-correction index using the symmetric-delete approach.
///
/// Every dictionary word is stored together with all strings obtainable by
/// deleting up to `max_distance` characters from it, so lookups only need to
/// generate deletions of the query rather than all possible edits.
#[derive(Debug, Clone)]
pub struct SymSpell {
    /// Largest edit distance the index was built for; lookups are clamped to it.
    max_distance: u8,
    /// Word -> frequency.
    dictionary: HashMap<String, usize>,
    /// Deletion string -> set of dictionary words that produce it.
    deletes: HashMap<String, HashSet<String>>,
}
impl SymSpell {
    /// Upper bound on the number of query variants (the term plus its
    /// deletions) that a single lookup explores; guards against runaway work
    /// on very long inputs.
    const MAX_QUERY_VARIANTS: usize = 10_000;

    /// Creates an empty index that supports lookups up to `max_distance` edits.
    pub fn new(max_distance: u8) -> Self {
        Self {
            max_distance,
            dictionary: HashMap::new(),
            deletes: HashMap::new(),
        }
    }

    /// Builds an index with the given maximum edit distance and loads every
    /// `(word, frequency)` pair from `iter` into it.
    pub fn from_iter<I, S>(max_distance: u8, iter: I) -> Self
    where
        I: IntoIterator<Item = (S, usize)>,
        S: Into<String>,
    {
        let mut sym = SymSpell::new(max_distance);
        sym.load_iter(iter);
        sym
    }

    /// Adds `(word, frequency)` pairs to the index.
    ///
    /// Empty words are skipped. Loading a word that is already present
    /// overwrites its frequency.
    pub fn load_iter<I, S>(&mut self, iter: I)
    where
        I: IntoIterator<Item = (S, usize)>,
        S: Into<String>,
    {
        for (word, freq) in iter {
            let word: String = word.into();
            if word.is_empty() {
                continue;
            }
            for deletion in generate_deletes(&word, self.max_distance) {
                self.deletes
                    .entry(deletion)
                    .or_default()
                    .insert(word.clone());
            }
            self.dictionary.insert(word, freq);
        }
    }

    /// Finds dictionary words within `max_distance` edits of `term`.
    ///
    /// `max_distance` is clamped to the distance the index was built with.
    /// Results are ordered by distance (ascending), then frequency
    /// (descending), then term (ascending) so ties are deterministic, and are
    /// trimmed according to `verbosity`. An empty `term` yields no results.
    pub fn lookup(&self, term: &str, max_distance: u8, verbosity: Verbosity) -> Vec<Suggestion> {
        if term.is_empty() {
            return Vec::new();
        }
        let max_distance = max_distance.min(self.max_distance);

        // Symmetric delete: a word within `max_distance` edits of `term` shares
        // with `term` a string reachable by at most `max_distance` deletions on
        // each side. That shared string is either the word itself (looked up in
        // `dictionary`) or one of its precomputed deletions (looked up in
        // `deletes`).
        let mut candidates: HashSet<&str> = HashSet::new();
        for variant in Self::query_variants(term, max_distance) {
            if let Some((word, _)) = self.dictionary.get_key_value(variant.as_str()) {
                candidates.insert(word.as_str());
            }
            if let Some(words) = self.deletes.get(variant.as_str()) {
                candidates.extend(words.iter().map(String::as_str));
            }
        }

        let mut results: Vec<Suggestion> = Vec::with_capacity(candidates.len());
        for word in candidates {
            let distance = damerau_levenshtein(term, word);
            if distance <= max_distance {
                results.push(Suggestion {
                    term: word.to_string(),
                    frequency: self.frequency(word).unwrap_or(0),
                    distance,
                });
            }
        }
        Self::select(results, verbosity)
    }

    /// Returns the stored frequency of `word`, if it is in the dictionary.
    pub fn frequency(&self, word: &str) -> Option<usize> {
        self.dictionary.get(word).copied()
    }

    /// Returns `term` followed by every distinct string obtained by deleting
    /// up to `max_distance` characters from it, capped at
    /// `MAX_QUERY_VARIANTS` entries.
    ///
    /// Deletions operate on `char`s, so multi-byte UTF-8 input never hits a
    /// non-boundary `String::remove` (which would panic).
    fn query_variants(term: &str, max_distance: u8) -> Vec<String> {
        let mut seen: HashSet<String> = HashSet::new();
        seen.insert(term.to_string());
        let mut variants: Vec<String> = vec![term.to_string()];
        let mut frontier: Vec<String> = vec![term.to_string()];
        for _ in 0..max_distance {
            let mut next: Vec<String> = Vec::new();
            for current in &frontier {
                for (idx, _) in current.char_indices() {
                    if variants.len() >= Self::MAX_QUERY_VARIANTS {
                        return variants;
                    }
                    let mut shorter = current.clone();
                    shorter.remove(idx);
                    if seen.insert(shorter.clone()) {
                        variants.push(shorter.clone());
                        next.push(shorter);
                    }
                }
            }
            if next.is_empty() {
                break;
            }
            frontier = next;
        }
        variants
    }

    /// Orders `results` best-first (distance asc, frequency desc, term asc)
    /// and trims them according to `verbosity`.
    fn select(mut results: Vec<Suggestion>, verbosity: Verbosity) -> Vec<Suggestion> {
        results.sort_by(|a, b| {
            a.distance
                .cmp(&b.distance)
                .then_with(|| b.frequency.cmp(&a.frequency))
                .then_with(|| a.term.cmp(&b.term))
        });
        match verbosity {
            Verbosity::Top => results.truncate(1),
            Verbosity::Closest => {
                if let Some(best) = results.first().map(|s| s.distance) {
                    results.retain(|s| s.distance == best);
                }
            }
            Verbosity::All => {}
        }
        results
    }
}
/// Read-only spelling-correction index backed by static `phf` maps
/// (presumably generated at build time — confirm against the generator).
pub struct EmbeddedSymSpell {
    /// Largest edit distance lookups may use; requests are clamped to it.
    pub max_distance: u8,
    /// Word -> frequency.
    pub dict: &'static ::phf::Map<&'static str, usize>,
    /// Deletion string -> dictionary words that produce it.
    pub deletes: &'static ::phf::Map<&'static str, &'static [&'static str]>,
}
impl EmbeddedSymSpell {
    /// Upper bound on the number of query variants (the term plus its
    /// deletions) that a single lookup explores; guards against runaway work
    /// on very long inputs.
    const MAX_QUERY_VARIANTS: usize = 10_000;

    /// Wraps pre-built static tables into a lookup handle.
    pub fn from_phf(
        max_distance: u8,
        dict: &'static ::phf::Map<&'static str, usize>,
        deletes: &'static ::phf::Map<&'static str, &'static [&'static str]>,
    ) -> Self {
        Self {
            max_distance,
            dict,
            deletes,
        }
    }

    /// Returns the stored frequency of `word`, if it is in the dictionary.
    pub fn frequency(&self, word: &str) -> Option<usize> {
        self.dict.get(word).copied()
    }

    /// Finds dictionary words within `max_distance` edits of `term`.
    ///
    /// `max_distance` is clamped to `self.max_distance`. For `Top` and
    /// `Closest`, an exact dictionary hit is returned on its own immediately.
    /// Results are ordered by distance (ascending), then frequency
    /// (descending), then term (ascending), and trimmed per `verbosity`.
    /// An empty `term` yields no results.
    pub fn lookup(&self, term: &str, max_distance: u8, verbosity: Verbosity) -> Vec<Suggestion> {
        if term.is_empty() {
            return Vec::new();
        }
        let max_distance = max_distance.min(self.max_distance);

        // An exact hit is the unique distance-0 answer, so Top/Closest can stop
        // here. `All` must still collect the other nearby words.
        if verbosity != Verbosity::All {
            if let Some(frequency) = self.frequency(term) {
                return vec![Suggestion {
                    term: term.to_string(),
                    frequency,
                    distance: 0,
                }];
            }
        }

        // Symmetric delete: probe every deletion of `term` (including `term`
        // itself) against both the word table and the deletion table.
        let mut candidates: HashSet<String> = HashSet::new();
        for variant in Self::query_variants(term, max_distance) {
            if let Some(words) = self.deletes.get(variant.as_str()) {
                candidates.extend(words.iter().map(|&w| w.to_owned()));
            }
            if self.dict.contains_key(variant.as_str()) {
                candidates.insert(variant);
            }
        }

        let mut results: Vec<Suggestion> = Vec::with_capacity(candidates.len());
        for cand in candidates {
            let distance = damerau_levenshtein(term, &cand);
            if distance <= max_distance {
                let frequency = self.frequency_or_zero(&cand);
                results.push(Suggestion {
                    term: cand,
                    frequency,
                    distance,
                });
            }
        }

        if results.is_empty() {
            // NOTE(review): if `deletes` was generated with at least
            // `self.max_distance` deletions per word, this full scan cannot find
            // anything the candidate search missed. It is kept in case the
            // embedded table was built with a smaller distance — confirm against
            // the generator and drop it if so, since it costs O(|dict|) per miss.
            for (&word, &frequency) in self.dict.entries() {
                let distance = damerau_levenshtein(term, word);
                if distance <= max_distance {
                    results.push(Suggestion {
                        term: word.to_string(),
                        frequency,
                        distance,
                    });
                }
            }
        }

        Self::select(results, verbosity)
    }

    /// Best suggestion for `term` at `self.max_distance`, if any.
    pub fn find_top(&self, term: &str) -> Option<Suggestion> {
        self.lookup(term, self.max_distance, Verbosity::Top)
            .into_iter()
            .next()
    }

    /// All suggestions at the smallest distance found for `term`.
    pub fn find_closest(&self, term: &str) -> Vec<Suggestion> {
        self.lookup(term, self.max_distance, Verbosity::Closest)
    }

    /// All suggestions within `self.max_distance` of `term`.
    pub fn find_all(&self, term: &str) -> Vec<Suggestion> {
        self.lookup(term, self.max_distance, Verbosity::All)
    }

    /// Whether `word` is in the dictionary.
    pub fn contains(&self, word: &str) -> bool {
        self.dict.contains_key(word)
    }

    /// The underlying word -> frequency table.
    pub fn dict_map(&self) -> &'static ::phf::Map<&'static str, usize> {
        self.dict
    }

    /// The underlying deletion -> words table.
    pub fn deletes_map(&self) -> &'static ::phf::Map<&'static str, &'static [&'static str]> {
        self.deletes
    }

    /// Dictionary words registered under the deletion string `deletion`.
    pub fn candidates_for_deletion(&self, deletion: &str) -> Option<&'static [&'static str]> {
        self.deletes.get(deletion).copied()
    }

    /// Frequency of `word`, or `0` when it is not in the dictionary.
    pub fn frequency_or_zero(&self, word: &str) -> usize {
        self.frequency(word).unwrap_or(0)
    }

    /// Returns `term` followed by every distinct string obtained by deleting
    /// up to `max_distance` characters from it, capped at
    /// `MAX_QUERY_VARIANTS` entries. Deletions are char-based, so non-ASCII
    /// input does not panic in `String::remove`.
    fn query_variants(term: &str, max_distance: u8) -> Vec<String> {
        let mut seen: HashSet<String> = HashSet::new();
        seen.insert(term.to_string());
        let mut variants: Vec<String> = vec![term.to_string()];
        let mut frontier: Vec<String> = vec![term.to_string()];
        for _ in 0..max_distance {
            let mut next: Vec<String> = Vec::new();
            for current in &frontier {
                for (idx, _) in current.char_indices() {
                    if variants.len() >= Self::MAX_QUERY_VARIANTS {
                        return variants;
                    }
                    let mut shorter = current.clone();
                    shorter.remove(idx);
                    if seen.insert(shorter.clone()) {
                        variants.push(shorter.clone());
                        next.push(shorter);
                    }
                }
            }
            if next.is_empty() {
                break;
            }
            frontier = next;
        }
        variants
    }

    /// Orders `results` best-first (distance asc, frequency desc, term asc)
    /// and trims them according to `verbosity`.
    fn select(mut results: Vec<Suggestion>, verbosity: Verbosity) -> Vec<Suggestion> {
        results.sort_by(|a, b| {
            a.distance
                .cmp(&b.distance)
                .then_with(|| b.frequency.cmp(&a.frequency))
                .then_with(|| a.term.cmp(&b.term))
        });
        match verbosity {
            Verbosity::Top => results.truncate(1),
            Verbosity::Closest => {
                if let Some(best) = results.first().map(|s| s.distance) {
                    results.retain(|s| s.distance == best);
                }
            }
            Verbosity::All => {}
        }
        results
    }
}
/// Returns every distinct string reachable from `word` by deleting between
/// 1 and `max_distance` characters. The word itself is not included, but the
/// empty string is when `word` has at most `max_distance` characters.
///
/// Deletions operate on `char`s rather than bytes. Calling `String::remove`
/// at a byte offset inside a multi-byte UTF-8 sequence panics, which
/// previously made any non-ASCII dictionary word crash `load_iter`.
fn generate_deletes(word: &str, max_distance: u8) -> HashSet<String> {
    let mut deletes: HashSet<String> = HashSet::new();
    // Strings produced at the previous deletion depth; BTreeSet keeps the
    // expansion order deterministic.
    let mut frontier: BTreeSet<String> = BTreeSet::new();
    frontier.insert(word.to_string());
    for _depth in 0..max_distance {
        let mut next: BTreeSet<String> = BTreeSet::new();
        for current in &frontier {
            // `char_indices` yields only valid char-boundary byte offsets.
            for (idx, _) in current.char_indices() {
                let mut shorter = current.clone();
                shorter.remove(idx);
                if deletes.insert(shorter.clone()) {
                    next.insert(shorter);
                }
            }
        }
        if next.is_empty() {
            break;
        }
        frontier = next;
    }
    deletes
}
/// Optimal-string-alignment (restricted Damerau-Levenshtein) distance between
/// `a` and `b`, counted over `char`s and clamped to `u8::MAX`.
///
/// Insertions, deletions, substitutions and adjacent transpositions each cost 1;
/// a substring may not be edited again after a transposition (e.g.
/// `"ca"` -> `"abc"` is 3, not 2).
///
/// Only three rows of the dynamic-programming table are kept alive (rows
/// i-2, i-1 and i), so memory is O(|b|) instead of O(|a|·|b|).
fn damerau_levenshtein(a: &str, b: &str) -> u8 {
    let clamp = |d: usize| -> u8 { d.min(usize::from(u8::MAX)) as u8 };
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();
    if source.is_empty() {
        return clamp(target.len());
    }
    if target.is_empty() {
        return clamp(source.len());
    }

    let width = target.len() + 1;
    let mut two_back: Vec<usize> = vec![0; width]; // row i-2
    let mut prev: Vec<usize> = (0..width).collect(); // row i-1 (row 0 = 0..=|b|)
    let mut curr: Vec<usize> = vec![0; width]; // row i

    for (i, &sc) in source.iter().enumerate() {
        curr[0] = i + 1;
        for (j, &tc) in target.iter().enumerate() {
            let substitution = prev[j] + usize::from(sc != tc);
            let deletion = prev[j + 1] + 1;
            let insertion = curr[j] + 1;
            let mut best = substitution.min(deletion).min(insertion);
            // Adjacent transposition: source[i-1..=i] == reversed target[j-1..=j].
            if i > 0 && j > 0 && sc == target[j - 1] && source[i - 1] == tc {
                best = best.min(two_back[j - 1] + 1);
            }
            curr[j + 1] = best;
        }
        // Rotate rows: i-1 becomes i-2, i becomes i-1; the old i-2 buffer is
        // reused (and fully overwritten) for the next row.
        std::mem::swap(&mut two_back, &mut prev);
        std::mem::swap(&mut prev, &mut curr);
    }

    clamp(prev[target.len()])
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Identity, single deletion and a single adjacent transposition.
    #[test]
    fn test_damerau_basic() {
        let cases = [("abc", "abc", 0u8), ("abc", "ab", 1), ("ab", "ba", 1)];
        for &(left, right, expected) in cases.iter() {
            assert_eq!(damerau_levenshtein(left, right), expected);
        }
    }

    /// A one-letter omission ("helo") should surface "hello" among the
    /// closest suggestions.
    #[test]
    fn test_symspell_lookup() {
        let words = [("hello", 100usize), ("hell", 50), ("help", 10), ("world", 200)];
        let index = SymSpell::from_iter(
            2,
            words.iter().map(|&(w, f)| (w.to_string(), f)).collect::<Vec<_>>(),
        );
        let found = index.lookup("helo", 2, Verbosity::Closest);
        assert!(found.iter().any(|s| s.term == "hello"));
    }
}