use core::hash::BuildHasher;
use core::slice;
use crate::alloc::{collections::BinaryHeap, string::String, vec::Vec};
use crate::aff::{CaseHandling, HIDDEN_HOMONYM_FLAG, MAX_SUGGESTIONS};
use crate::{FlagSet, FULL_WORD};
use super::Suggester;
/// Evaluates to `true` when the optional flag `$flag` is `Some` and is contained in the
/// flag collection `$flags`. A `None` flag is never considered present.
macro_rules! has_flag {
    ( $flags:expr, $flag:expr ) => {
        matches!($flag, Some(flag) if $flags.contains(&flag))
    };
}
/// A payload paired with a score, ordered so that a *lower* score compares as *greater*.
///
/// Wrapping items in `MinScored` turns the max-heap `BinaryHeap` into a min-heap on
/// `score`: `peek`/`pop` yield the lowest-scored item, which is the one evicted by the
/// bounded "keep the best N" collections. Sorting ascending by this ordering (or calling
/// `BinaryHeap::into_sorted_vec`) puts the highest score first.
///
/// Equality and ordering both look only at `score`; `inner` is an opaque payload. (Deriving
/// `PartialEq`/`Eq` would also compare `inner`, so `a == b` could disagree with
/// `a.cmp(&b) == Equal`, which the `Ord` contract forbids.)
#[derive(Debug)]
struct MinScored<T> {
    score: isize,
    inner: T,
}

impl<T> PartialEq for MinScored<T> {
    fn eq(&self, other: &Self) -> bool {
        self.score == other.score
    }
}

impl<T> Eq for MinScored<T> {}

impl<T> Ord for MinScored<T> {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        // Reversed so that `BinaryHeap` behaves as a min-heap on `score`.
        self.score.cmp(&other.score).reverse()
    }
}

impl<T> PartialOrd<Self> for MinScored<T> {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl<S: BuildHasher> Suggester<'_, S> {
    /// Appends suggestions for `word_str` found by n-gram similarity against the entire word
    /// list.
    ///
    /// Runs in three phases:
    ///
    /// 1. Score every stem in the word list (skipping entries flagged forbidden, no-suggest,
    ///    only-in-compound or hidden-homonym) and keep the best `MAX_ROOTS`.
    /// 2. Expand those stems with their affixes (see `expand_stem_for_ngram`), drop
    ///    expansions scoring below a threshold derived from mangled copies of the word, and
    ///    keep the best `MAX_GUESSES`.
    /// 3. Re-score the surviving guesses with a finer metric (longest common subsequence,
    ///    weighted bigrams, 4-grams, common prefix, same-position and swapped characters)
    ///    and push the best ones onto `out`.
    ///
    /// At most `max_ngram_suggestions` new entries are added, nothing is added once `out`
    /// holds `MAX_SUGGESTIONS` entries, and a guess that contains an existing suggestion as
    /// a substring is not added.
    pub(super) fn ngram_suggest(&self, word_str: &str, out: &mut Vec<String>) {
        // Upper bounds on the candidates kept after phases 1 and 2.
        const MAX_ROOTS: usize = 100;
        const MAX_GUESSES: usize = 200;

        let mut word_buf = Vec::with_capacity(word_str.len());
        let word = CharsStr::new(word_str, &mut word_buf);

        // Phase 1: find the `MAX_ROOTS` stems most similar to `word`.
        let mut stem_buf = Vec::with_capacity(word.len_chars() * 2);
        let mut lowercase_stem_buf = Vec::with_capacity(stem_buf.capacity());
        // `MinScored` reverses the ordering, so `peek` is the lowest-scored root kept so far.
        let mut roots = BinaryHeap::with_capacity(MAX_ROOTS);
        for entry @ (stem, flagset) in self.checker.words.iter() {
            if flagset.contains(&self.checker.aff.options.forbidden_word_flag)
                || has_flag!(flagset, self.checker.aff.options.no_suggest_flag)
                || has_flag!(flagset, self.checker.aff.options.only_in_compound_flag)
                || flagset.contains(&HIDDEN_HOMONYM_FLAG)
            {
                continue;
            }
            let stem = CharsStr::new(stem.as_str(), &mut stem_buf);
            let mut score =
                left_common_substring_length(&self.checker.aff.options.case_handling, word, stem)
                    as isize;
            let lowercase_stem = self
                .checker
                .aff
                .options
                .case_handling
                .lowercase(stem.as_str());
            let lowercase_stem = CharsStr::new(lowercase_stem.as_str(), &mut lowercase_stem_buf);
            score += ngram_similarity_longer_worse(3, word, lowercase_stem);
            let root = MinScored {
                score,
                inner: entry,
            };
            if roots.len() < MAX_ROOTS {
                roots.push(root);
            } else if roots.peek().is_some_and(|entry| score > entry.score) {
                roots.pop();
                roots.push(root);
            }
        }

        // Phase 2a: the minimum acceptable score for an expanded word. Mangle the word three
        // ways -- replacing every fourth character, starting at character k = 1, 2, 3, with
        // '*' -- and average how similar each mangled copy is to the original. When the word
        // has no character k the copy is left unmangled, but it still counts towards the
        // average: the division by 3 below assumes exactly three terms.
        let char_boundaries = word.char_indices().as_slice();
        let mut mangled_word = String::new();
        let mut threshold = 0isize;
        for k in 1..=3 {
            let k_byte_idx = usize::from(char_boundaries[k.min(word.len_chars())]);
            mangled_word.clear();
            mangled_word.push_str(&word_str[..k_byte_idx]);
            mangled_word.extend(word_str[k_byte_idx..].chars().enumerate().map(|(i, ch)| {
                if i % 4 == 0 {
                    '*'
                } else {
                    ch
                }
            }));
            threshold += ngram_similarity_any_mismatch(word.len_chars(), word, &mangled_word);
        }
        threshold /= 3;

        // Phase 2b: expand each root with its affixes and keep the `MAX_GUESSES` best
        // expansions that reach the threshold.
        let mut expanded_list = Vec::new();
        let mut expanded_cross_affix = Vec::new();
        let mut expanded_word_buf = Vec::with_capacity(word.len_chars() * 2);
        let mut guess_words = BinaryHeap::new();
        for MinScored {
            inner: (stem, flags),
            ..
        } in roots
        {
            expanded_cross_affix.clear();
            self.expand_stem_for_ngram(
                stem.as_str(),
                flags,
                word_str,
                &mut expanded_list,
                &mut expanded_cross_affix,
            );
            for expanded_word in expanded_list.drain(..) {
                let mut score = left_common_substring_length(
                    &self.checker.aff.options.case_handling,
                    word,
                    CharsStr::new(&expanded_word, &mut expanded_word_buf),
                ) as isize;
                let lower_expanded_word = self
                    .checker
                    .aff
                    .options
                    .case_handling
                    .lowercase(&expanded_word);
                score +=
                    ngram_similarity_any_mismatch(word.len_chars(), word, &lower_expanded_word);
                if score < threshold {
                    continue;
                }
                let guess_word = MinScored {
                    score,
                    inner: expanded_word,
                };
                if guess_words.len() < MAX_GUESSES {
                    guess_words.push(guess_word);
                } else if guess_words.peek().is_some_and(|entry| score > entry.score) {
                    guess_words.pop();
                    guess_words.push(guess_word);
                }
            }
        }

        // Phase 3: re-score the guesses with a more precise metric.
        let mut lcs_state = Vec::new();
        let mut guess_words = guess_words.into_sorted_vec();
        let mut lower_guess_word_buf = Vec::with_capacity(word.len_chars());
        for MinScored {
            score,
            inner: guess_word,
        } in guess_words.iter_mut()
        {
            let lower_guess_word = self.checker.aff.options.case_handling.lowercase(guess_word);
            let lower_guess_word = CharsStr::new(&lower_guess_word, &mut lower_guess_word_buf);
            let lcs = longest_common_subsequence_length(word, lower_guess_word, &mut lcs_state);
            // The lowercased guess is identical to `word`: boost it above everything else.
            // Re-scoring stops here, so the remaining guesses keep their phase-2 scores.
            if word.len_chars() == lower_guess_word.len_chars() && word.len_chars() == lcs {
                *score += 2000;
                break;
            }
            let mut ngram2 = ngram_similarity_any_mismatch_weighted(2, word, lower_guess_word);
            ngram2 += ngram_similarity_any_mismatch_weighted(2, lower_guess_word, word);
            let ngram4 = ngram_similarity_any_mismatch(4, word, lower_guess_word.as_str());
            let left_common = left_common_substring_length(
                &self.checker.aff.options.case_handling,
                word,
                lower_guess_word,
            );
            let (num_eq_chars_same_pos, eq_char_is_swapped) =
                count_eq_at_same_pos(word, lower_guess_word);
            *score = 2 * lcs as isize;
            *score -= (word.len_chars() as isize - lower_guess_word.len_chars() as isize).abs();
            *score += left_common as isize + ngram2 + ngram4;
            if num_eq_chars_same_pos != 0 {
                *score += 1;
            }
            if eq_char_is_swapped {
                *score += 10;
            }
            // Heavily penalise guesses whose bigram similarity is too low relative to the
            // combined length, scaled by the MAXDIFF-style factor from the options.
            if 5 * ngram2
                < ((word.len_chars() + lower_guess_word.len_chars())
                    * (10 - self.checker.aff.options.max_diff_factor as usize))
                    as isize
            {
                *score -= 1000;
            }
        }
        // Highest score first (`MinScored` orders by descending score).
        guess_words.sort_unstable();

        let be_more_selective = guess_words.first().is_some_and(|guess| guess.score > 1000);
        let old_num_suggestions = out.len();
        let max_suggestions = MAX_SUGGESTIONS
            .min(old_num_suggestions + self.checker.aff.options.max_ngram_suggestions as usize);
        for MinScored {
            score,
            inner: guess_word,
        } in guess_words.into_iter()
        {
            // `>=` rather than `==` so nothing is added even if `out` was already over the
            // limit on entry.
            if out.len() >= max_suggestions {
                break;
            }
            if be_more_selective && score <= 1000 {
                break;
            }
            if score < -100
                && (old_num_suggestions != out.len() || self.checker.aff.options.only_max_diff)
            {
                break;
            }
            if out.iter().any(|sug| guess_word.contains(sug)) {
                if score < -100 {
                    break;
                } else {
                    continue;
                }
            }
            out.push(guess_word);
        }
    }

    /// Expands `stem` with the affixes named in `flags`, writing the results to
    /// `expanded_list`.
    ///
    /// Only affixes that pass `is_outer_affix_valid::<_, FULL_WORD>`, are not circumfixes,
    /// whose strip text and condition match, and whose appended text also appears at the
    /// matching end of `word` (the misspelling) are applied. In order, the output holds:
    /// the bare stem (unless it carries the need-affix flag), every suffixed form, every
    /// prefixed form of a cross-product suffixed form, and every prefixed form of the bare
    /// stem.
    ///
    /// `cross_affix[i]` records whether the form in `expanded_list[i]` may take a
    /// cross-product prefix. It is filled only for the bare stem and the suffixed forms.
    fn expand_stem_for_ngram(
        &self,
        stem: &str,
        flags: &FlagSet,
        word: &str,
        expanded_list: &mut Vec<String>,
        cross_affix: &mut Vec<bool>,
    ) {
        expanded_list.clear();
        cross_affix.clear();
        if !has_flag!(flags, self.checker.aff.options.need_affix_flag) {
            expanded_list.push(String::from(stem));
            cross_affix.push(false);
        }
        if flags.is_empty() {
            return;
        }
        for suffix in self.checker.aff.suffixes.iter() {
            if !flags.contains(&suffix.flag) {
                continue;
            }
            if !self.checker.is_outer_affix_valid::<_, FULL_WORD>(suffix) {
                continue;
            }
            if self.checker.is_circumfix(suffix) {
                continue;
            }
            if suffix
                .strip
                .as_ref()
                .is_some_and(|suf| !stem.ends_with(&**suf))
            {
                continue;
            }
            if !suffix.condition_matches(stem) {
                continue;
            }
            if !suffix.add.is_empty() && !word.ends_with(&*suffix.add) {
                continue;
            }
            let expanded = suffix.to_derived(stem);
            expanded_list.push(expanded);
            cross_affix.push(suffix.crossproduct);
        }
        // The range is evaluated once, so the prefixed forms pushed below are not revisited.
        for i in 0..expanded_list.len() {
            if !cross_affix[i] {
                continue;
            }
            for prefix in self.checker.aff.prefixes.iter() {
                let suffixed_stem = &expanded_list[i];
                if !flags.contains(&prefix.flag) {
                    continue;
                }
                if !self.checker.is_outer_affix_valid::<_, FULL_WORD>(prefix) {
                    continue;
                }
                if self.checker.is_circumfix(prefix) {
                    continue;
                }
                if prefix
                    .strip
                    .as_ref()
                    .is_some_and(|pre| !suffixed_stem.starts_with(&**pre))
                {
                    continue;
                }
                if !prefix.condition_matches(suffixed_stem) {
                    continue;
                }
                if !prefix.add.is_empty() && !word.starts_with(&*prefix.add) {
                    continue;
                }
                // Derive from the suffixed form that was just validated, producing a
                // prefix + stem + suffix word.
                let expanded = prefix.to_derived(suffixed_stem);
                expanded_list.push(expanded);
            }
        }
        for prefix in self.checker.aff.prefixes.iter() {
            if !flags.contains(&prefix.flag) {
                continue;
            }
            if !self.checker.is_outer_affix_valid::<_, FULL_WORD>(prefix) {
                continue;
            }
            if self.checker.is_circumfix(prefix) {
                continue;
            }
            if prefix
                .strip
                .as_ref()
                .is_some_and(|pre| !stem.starts_with(&**pre))
            {
                continue;
            }
            if !prefix.condition_matches(stem) {
                continue;
            }
            if !prefix.add.is_empty() && !word.starts_with(&*prefix.add) {
                continue;
            }
            let expanded = prefix.to_derived(stem);
            expanded_list.push(expanded);
        }
    }
}
/// A string slice paired with the byte offset of every character boundary, so characters
/// can be addressed by index in O(1).
///
/// `char_indices` holds the starting byte of each character, followed by the total byte
/// length of `inner`, so it always has `len_chars() + 1` strictly increasing entries (a
/// single `0` for the empty string). The offsets are `u16`, which limits `inner` to
/// `u16::MAX` bytes. Only `new` builds values, so these invariants always hold, and the
/// `unsafe` slicing below relies on them.
///
/// The index table lives in a caller-provided buffer (lifetime `'i`) so one allocation can
/// be reused for many strings.
#[derive(Clone, Copy)]
struct CharsStr<'s, 'i> {
    inner: &'s str,
    char_indices: &'i [u16],
}

impl<'s, 'i> CharsStr<'s, 'i> {
    /// Builds the boundary table for `s` in `slab`, discarding its previous contents.
    ///
    /// # Panics
    ///
    /// Panics if `s` is longer than `u16::MAX` bytes.
    fn new(s: &'s str, slab: &'i mut Vec<u16>) -> Self {
        let len_bytes = s.len();
        assert!(len_bytes <= u16::MAX as usize);
        slab.clear();
        slab.extend(s.char_indices().map(|(i, _ch)| i as u16));
        slab.push(len_bytes as u16);
        Self {
            inner: s,
            char_indices: slab.as_slice(),
        }
    }

    /// Number of characters (Unicode scalar values) in the string.
    const fn len_chars(&self) -> usize {
        self.char_indices.len() - 1
    }

    /// Whether the string has no characters.
    const fn is_empty(&self) -> bool {
        self.char_indices.len() == 1
    }

    /// The underlying string slice.
    const fn as_str(&self) -> &str {
        self.inner
    }

    /// Returns the characters in `char_range` (character indices, end-exclusive).
    ///
    /// # Panics
    ///
    /// Panics if `char_range.end > len_chars()` or `char_range.start > char_range.end`.
    fn char_slice(&self, char_range: core::ops::Range<usize>) -> &str {
        // Required for soundness: `get_unchecked` with start > end is undefined behavior.
        assert!(
            char_range.start <= char_range.end,
            "char_slice: range start is after its end"
        );
        let start_byte = self.char_indices[char_range.start] as usize;
        let end_byte = self.char_indices[char_range.end] as usize;
        // SAFETY: both offsets were read (bounds-checked) from `char_indices`, so they lie
        // on char boundaries of `inner` or equal its length. The table is strictly
        // increasing, so `start <= end` (asserted above) gives `start_byte <= end_byte`.
        unsafe { self.inner.get_unchecked(start_byte..end_byte) }
    }

    /// Returns the single character at `char_idx` as a string slice.
    ///
    /// # Panics
    ///
    /// Panics if `char_idx >= len_chars()`.
    fn char_at(&self, char_idx: usize) -> &str {
        let start_byte = self.char_indices[char_idx] as usize;
        let end_byte = self.char_indices[char_idx + 1] as usize;
        // SAFETY: consecutive (bounds-checked) entries of `char_indices` are the start and
        // end boundaries of one character of `inner`, with `start_byte < end_byte`.
        unsafe { self.inner.get_unchecked(start_byte..end_byte) }
    }

    /// Iterates over the characters, each yielded as a string slice.
    fn char_iter(&self) -> impl Iterator<Item = &'s str> + '_ {
        self.char_indices.windows(2).map(|idxs| unsafe {
            // SAFETY: `windows(2)` always yields two-element slices. The two entries are
            // consecutive boundaries of one character of `inner` (see the type invariant),
            // so the byte range is valid and on char boundaries.
            let start = *idxs.get_unchecked(0) as usize;
            let end = *idxs.get_unchecked(1) as usize;
            self.inner.get_unchecked(start..end)
        })
    }

    /// Iterates over the boundary table: the starting byte of each character, followed by
    /// the total byte length.
    fn char_indices(&self) -> slice::Iter<'_, u16> {
        self.char_indices.iter()
    }
}
/// Number of leading characters that `left` and `right` have in common.
///
/// The first pair of characters also counts as equal when
/// `case_handling.is_char_eq_lowercase` says so; every following pair must be identical.
/// If `left` is exhausted without a mismatch, its full length is returned.
fn left_common_substring_length(
    case_handling: &CaseHandling,
    left: CharsStr,
    right: CharsStr,
) -> usize {
    let mut lhs = left.as_str().chars();
    let mut rhs = right.as_str().chars();
    match (lhs.next(), rhs.next()) {
        (Some(a), Some(b)) if a == b || case_handling.is_char_eq_lowercase(a, b) => {
            match index_of_mismatch(lhs, rhs) {
                // `+ 1` accounts for the first character consumed above.
                Some(offset) => offset + 1,
                None => left.len_chars(),
            }
        }
        _ => 0,
    }
}
/// Index (into `left`) of the first element that differs from `right` at the same
/// position, or at which `right` has already run out.
///
/// Returns `None` when every element of `left` matches, including when `right` is longer.
fn index_of_mismatch<T: Eq, I: Iterator<Item = T>>(mut left: I, mut right: I) -> Option<usize> {
    left.position(|l| right.next() != Some(l))
}
/// N-gram similarity of `left` against `right`, minus one point for every character by
/// which `right` is more than two characters longer than `left`.
///
/// Returns 0 when `right` is empty.
fn ngram_similarity_longer_worse(n: usize, left: CharsStr, right: CharsStr) -> isize {
    if right.is_empty() {
        return 0;
    }
    let excess = right.len_chars() as isize - left.len_chars() as isize - 2;
    ngram_similarity(n, left, right.as_str()) - excess.max(0)
}
/// Sums, over gram lengths `k = 1..=n`, how many k-grams of `left` occur anywhere in
/// `right`.
///
/// `n` is clamped to the character length of `left`. Summation stops after the first
/// gram length with fewer than two hits; that length's hits are still included.
fn ngram_similarity(n: usize, left: CharsStr, right: &str) -> isize {
    let left_len = left.len_chars();
    let mut total = 0;
    for gram_len in 1..=n.min(left_len) {
        let hits = (0..=left_len - gram_len)
            .filter(|&start| right.contains(left.char_slice(start..start + gram_len)))
            .count() as isize;
        total += hits;
        if hits < 2 {
            break;
        }
    }
    total
}
/// N-gram similarity of `left` against `right`, minus one point for every character by
/// which their lengths differ (in either direction) beyond two.
///
/// Returns 0 when `right` is empty.
fn ngram_similarity_any_mismatch(n: usize, left: CharsStr, right: &str) -> isize {
    if right.is_empty() {
        return 0;
    }
    let length_gap = (right.chars().count() as isize - left.len_chars() as isize).abs();
    ngram_similarity(n, left, right) - (length_gap - 2).max(0)
}
/// Length, in characters, of the longest common subsequence of `left` and `right`.
///
/// Uses a single-row dynamic-programming table kept in `state_buffer`, which is reset on
/// entry so that one allocation can be reused across calls.
fn longest_common_subsequence_length(
    left: CharsStr,
    right: CharsStr,
    state_buffer: &mut Vec<usize>,
) -> usize {
    state_buffer.clear();
    state_buffer.resize(right.len_chars(), 0);
    let mut result = 0;
    for left_char in left.char_iter() {
        // `diagonal`: previous row at column j - 1; `west`: current row at column j - 1.
        let mut diagonal = 0;
        let mut west = 0;
        for (cell, right_char) in state_buffer.iter_mut().zip(right.char_iter()) {
            let north = *cell;
            *cell = if left_char == right_char {
                diagonal + 1
            } else {
                north.max(west)
            };
            diagonal = north;
            west = *cell;
        }
        result = west;
    }
    result
}
/// Weighted n-gram similarity of `left` against `right` (see `ngram_similarity_weighted`),
/// minus one point for every character by which their lengths differ beyond two.
///
/// Returns 0 when `right` is empty.
fn ngram_similarity_any_mismatch_weighted(n: usize, left: CharsStr, right: CharsStr) -> isize {
    if right.is_empty() {
        return 0;
    }
    let length_gap = (right.len_chars() as isize - left.len_chars() as isize).abs();
    let penalty = (length_gap - 2).max(0);
    ngram_similarity_weighted(n, left, right.as_str()) - penalty
}
/// Like `ngram_similarity`, but penalises misses: for every gram length `k = 1..=n`
/// (`n` clamped to the length of `left`), each k-gram of `left` found in `right` scores
/// +1, and each one not found scores -1, or -2 if it is the first or last k-gram.
/// Unlike `ngram_similarity`, every gram length is always summed.
fn ngram_similarity_weighted(n: usize, left: CharsStr, right: &str) -> isize {
    let left_len = left.len_chars();
    (1..=n.min(left_len))
        .map(|gram_len| {
            let last_start = left_len - gram_len;
            (0..=last_start)
                .map(|start| {
                    if right.contains(left.char_slice(start..start + gram_len)) {
                        1
                    } else if start == 0 || start == last_start {
                        -2
                    } else {
                        -1
                    }
                })
                .sum::<isize>()
        })
        .sum()
}
/// Counts positions where `left` and `right` hold the same character, and reports whether
/// the strings are equally long and differ only by one pair of swapped characters.
///
/// Returns `(equal_positions, is_swap)`. Only positions present in both strings are
/// compared.
fn count_eq_at_same_pos(left: CharsStr, right: CharsStr) -> (usize, bool) {
    let mut equal = 0;
    let mut first_mismatch = None;
    let mut second_mismatch = None;
    for (l, r) in left.char_iter().zip(right.char_iter()) {
        if l == r {
            equal += 1;
        } else if first_mismatch.is_none() {
            first_mismatch = Some((l, r));
        } else if second_mismatch.is_none() {
            second_mismatch = Some((l, r));
        }
    }
    let exactly_two_differ =
        left.len_chars() == right.len_chars() && left.len_chars() - equal == 2;
    let is_swap = exactly_two_differ
        && matches!(
            (first_mismatch, second_mismatch),
            (Some((l1, r1)), Some((l2, r2))) if l1 == r2 && r1 == l2
        );
    (equal, is_swap)
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn index_of_mismatch_test() {
        let cases: [(&[u8], &[u8], Option<usize>); 4] = [
            (b"abcd", b"abcd", None),
            (b"abcd", b"abxy", Some(2)),
            (b"abcd", b"abc", Some(3)),
            (b"abc", b"abcd", None),
        ];
        for (left, right, expected) in cases {
            assert_eq!(index_of_mismatch(left.iter(), right.iter()), expected);
        }
    }

    #[test]
    fn ngram_similarity_test() {
        let mut buf = Vec::new();
        let actually = CharsStr::new("actually", &mut buf);
        assert_eq!(ngram_similarity(3, actually, "akchualy"), 11);
    }

    #[test]
    fn longest_common_subsequence_length_test() {
        let mut left_buf = Vec::new();
        let mut right_buf = Vec::new();
        let mut state = Vec::new();
        let mut lcs = |left: &str, right: &str| -> usize {
            longest_common_subsequence_length(
                CharsStr::new(left, &mut left_buf),
                CharsStr::new(right, &mut right_buf),
                &mut state,
            )
        };
        assert_eq!(lcs("aaa", "aaa"), 3);
        assert_eq!(lcs("aaaaa", "bbbaa"), 2);
    }

    #[test]
    fn count_eq_at_same_pos_test() {
        let mut left_buf = Vec::new();
        let mut right_buf = Vec::new();
        let mut count = |left: &str, right: &str| -> (usize, bool) {
            count_eq_at_same_pos(
                CharsStr::new(left, &mut left_buf),
                CharsStr::new(right, &mut right_buf),
            )
        };
        assert_eq!(count("abcd", "abcd"), (4, false));
        assert_eq!(count("abcd", "acbd"), (2, true));
    }
}