use std::cmp::Ordering;
use std::collections::BTreeMap;
/// Stored state for one suggestion: its ranking score and optional opaque payload.
#[derive(Debug, Clone, PartialEq)]
pub struct SuggestionEntry {
    /// Ranking weight; higher scores are returned first by lookups.
    pub score: f64,
    /// Optional caller-supplied bytes kept alongside the suggestion.
    pub payload: Option<Vec<u8>>,
}
/// One lookup result: the matched suggestion plus, on request, its score and payload.
#[derive(Debug, Clone, PartialEq)]
pub struct SuggestionHit {
    /// The matched suggestion bytes.
    pub value: Vec<u8>,
    /// The stored score, present only when scores were requested.
    pub score: Option<f64>,
    /// The stored payload, present only when payloads were requested and one exists.
    pub payload: Option<Vec<u8>>,
}
/// A dictionary of byte-string suggestions kept in byte-lexicographic key order.
#[derive(Default, Debug)]
pub struct SuggestionDict {
    /// Suggestions keyed by their bytes; the `BTreeMap` keeps keys sorted.
    entries: BTreeMap<Vec<u8>, SuggestionEntry>,
}
impl SuggestionDict {
    /// Creates an empty suggestion dictionary.
    #[must_use]
    pub fn new() -> Self {
        Self {
            entries: BTreeMap::new(),
        }
    }

    /// Returns the number of distinct suggestions currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` if the dictionary holds no suggestions.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Inserts `suggestion`, or updates it if already present, and returns the
    /// number of distinct suggestions afterwards.
    ///
    /// - Not present: stored with `score` and `payload` (`incr` is irrelevant).
    /// - Present and `incr` is `true`: `score` is added to the stored score; the
    ///   stored payload is replaced only when `payload` is `Some`.
    /// - Present and `incr` is `false`: score and payload are both overwritten,
    ///   so a `None` payload clears any stored payload.
    ///
    /// Any `f64` is accepted, including `NaN` and infinities (incrementing
    /// `+inf` by `-inf` yields `NaN`); [`SuggestionDict::get`] orders such
    /// scores deterministically.
    pub fn add(
        &mut self,
        suggestion: Vec<u8>,
        score: f64,
        incr: bool,
        payload: Option<Vec<u8>>,
    ) -> usize {
        match self.entries.get_mut(&suggestion) {
            Some(existing) if incr => {
                existing.score += score;
                // Incrementing keeps the old payload unless a new one is supplied.
                if payload.is_some() {
                    existing.payload = payload;
                }
            }
            Some(existing) => {
                existing.score = score;
                existing.payload = payload;
            }
            None => {
                self.entries
                    .insert(suggestion, SuggestionEntry { score, payload });
            }
        }
        self.entries.len()
    }

    /// Removes `suggestion`; returns `true` if it was present.
    pub fn del(&mut self, suggestion: &[u8]) -> bool {
        self.entries.remove(suggestion).is_some()
    }

    /// Returns up to `max` suggestions matching `prefix`, best first.
    ///
    /// - `fuzzy == false`: only suggestions starting with `prefix` byte-for-byte.
    /// - `fuzzy == true`: suggestions having some prefix within one edit
    ///   (substitution, insertion, or deletion) of `prefix`.
    ///
    /// Hits are ordered by descending score, with ties broken by ascending
    /// byte-lexicographic value. `NaN` scores sort after every other score.
    /// Each hit's `score` / `payload` is filled only when `with_scores` /
    /// `with_payloads` is set. `max == 0` yields an empty vector.
    #[must_use]
    pub fn get(
        &self,
        prefix: &[u8],
        max: usize,
        fuzzy: bool,
        with_scores: bool,
        with_payloads: bool,
    ) -> Vec<SuggestionHit> {
        if max == 0 {
            return Vec::new();
        }
        let mut candidates: Vec<(&Vec<u8>, &SuggestionEntry)> = if fuzzy {
            // Fuzzy matches are not contiguous in key order, so scan everything.
            self.entries
                .iter()
                .filter(|(key, _)| fuzzy_prefix_match(prefix, key, 1))
                .collect()
        } else {
            // Keys starting with `prefix` form one contiguous run beginning at
            // `prefix` itself; borrow it as the lower bound instead of copying.
            self.entries
                .range::<[u8], _>((
                    std::ops::Bound::Included(prefix),
                    std::ops::Bound::Unbounded,
                ))
                .take_while(|(key, _)| key.starts_with(prefix))
                .collect()
        };
        // Keys are unique and the comparator is a total order, so an unstable
        // sort produces exactly the same sequence as a stable one.
        candidates.sort_unstable_by(|a, b| {
            Self::cmp_score_desc(a.1.score, b.1.score).then_with(|| a.0.cmp(b.0))
        });
        candidates.truncate(max);
        candidates
            .into_iter()
            .map(|(key, entry)| SuggestionHit {
                value: key.clone(),
                score: if with_scores { Some(entry.score) } else { None },
                payload: if with_payloads {
                    entry.payload.clone()
                } else {
                    None
                },
            })
            .collect()
    }

    /// Orders two scores from highest to lowest as a *total* order.
    ///
    /// `partial_cmp` alone is not total once `NaN` is involved, and sorting
    /// with a non-total comparator gives input-dependent order (and may panic
    /// on recent Rust). Here `NaN` sorts after every real score and equals
    /// other `NaN`s, while `-0.0` and `0.0` still compare equal.
    fn cmp_score_desc(a: f64, b: f64) -> Ordering {
        match (a.is_nan(), b.is_nan()) {
            // Neither is NaN, so `partial_cmp` always returns `Some`.
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (true, true) => Ordering::Equal,
        }
    }
}
/// Reports whether `prefix` is within `max_errors` Levenshtein edits
/// (single-byte substitutions, insertions, or deletions) of some prefix of
/// `candidate`.
///
/// - `max_errors == 0` reduces to `candidate.starts_with(prefix)`.
/// - An empty `prefix` matches every candidate.
///
/// With `m = prefix.len()` and `k = max_errors`, only candidate prefixes whose
/// length lies in `m - k ..= m + k` can be within `k` edits. The dynamic program
/// therefore looks at no more than the first `m + k` candidate bytes and only
/// fills the diagonal band `|i - j| <= k`. Cells outside the band hold
/// `usize::MAX`, meaning "unreachable".
fn fuzzy_prefix_match(prefix: &[u8], candidate: &[u8], max_errors: usize) -> bool {
    if max_errors == 0 {
        return candidate.starts_with(prefix);
    }
    let m = prefix.len();
    if m == 0 {
        return true;
    }
    let k = max_errors;
    // Saturating so an enormous `max_errors` cannot overflow.
    let txt = &candidate[..m.saturating_add(k).min(candidate.len())];
    let n = txt.len();

    // While processing row `i`, `prev[j]` is the edit distance between
    // `prefix[..i - 1]` and `txt[..j]`, and `curr[j]` is filled for `prefix[..i]`.
    let mut prev = vec![usize::MAX; n + 1];
    let mut curr = vec![usize::MAX; n + 1];
    // Row 0: the empty prefix needs `j` insertions to become `txt[..j]`.
    // Seeding these cells with 0 instead would let the match begin anywhere in
    // the first `k` candidate bytes at no cost (a substring search). That would
    // accept e.g. "xllo" against "hello" with a single edit.
    for (j, cell) in prev.iter_mut().enumerate().take(n.min(k) + 1) {
        *cell = j;
    }
    for i in 1..=m {
        // Deleting all `i` prefix bytes reaches the empty candidate prefix.
        curr[0] = i;
        let lo = i.saturating_sub(k);
        let hi = i.saturating_add(k).min(n);
        for j in 1..=n {
            if j < lo || j > hi {
                curr[j] = usize::MAX;
                continue;
            }
            let cost = usize::from(prefix[i - 1] != txt[j - 1]);
            let sub = prev[j - 1].saturating_add(cost);
            let del = prev[j].saturating_add(1);
            let ins = curr[j - 1].saturating_add(1);
            curr[j] = sub.min(del).min(ins);
        }
        std::mem::swap(&mut prev, &mut curr);
    }
    // After the final swap `prev` holds row `m`. Accept if any admissible
    // candidate-prefix length is within budget. The range is empty when the
    // candidate is shorter than `m - k`.
    let lo = m.saturating_sub(k);
    let hi = m.saturating_add(k).min(n);
    prev.iter().take(hi + 1).skip(lo).any(|&d| d <= k)
}
// Unit tests for `SuggestionDict` and `fuzzy_prefix_match`, plus property tests
// driven by the external `hegel` crate (its generators and `TestCase`).
#[cfg(test)]
mod tests {
use super::*;
use hegel::generators as gs;
use hegel::TestCase;
// `add` returns the dictionary size; re-adding an existing key does not grow it.
#[test]
fn add_grows_dict() {
let mut d = SuggestionDict::new();
assert_eq!(d.add(b"alpha".to_vec(), 1.0, false, None), 1);
assert_eq!(d.add(b"beta".to_vec(), 1.0, false, None), 2);
assert_eq!(d.add(b"alpha".to_vec(), 5.0, false, None), 2);
}
// Without `incr`, a second `add` replaces the stored score.
#[test]
fn replace_overwrites_score() {
let mut d = SuggestionDict::new();
d.add(b"hello".to_vec(), 1.0, false, None);
d.add(b"hello".to_vec(), 7.0, false, None);
let hits = d.get(b"hello", 5, false, true, false);
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].score, Some(7.0));
}
// With `incr`, the new score is added to the existing one (1.5 + 0.5).
#[test]
fn incr_adds_to_score() {
let mut d = SuggestionDict::new();
d.add(b"hello".to_vec(), 1.5, false, None);
d.add(b"hello".to_vec(), 0.5, true, None);
let hits = d.get(b"hello", 5, false, true, false);
assert!((hits[0].score.unwrap() - 2.0).abs() < 1e-9);
}
// "helo" is one substitution away from the candidate prefix "hell".
#[test]
fn fuzzy_one_edit_substitution() {
assert!(fuzzy_prefix_match(b"helo", b"hello world", 1));
}
// "heallo" matches "hello" once the extra 'a' is removed (one edit).
#[test]
fn fuzzy_one_edit_insertion() {
assert!(fuzzy_prefix_match(b"heallo", b"hello world", 1));
}
// "hllo" matches "hello" once the missing 'e' is supplied (one edit).
#[test]
fn fuzzy_one_edit_deletion() {
assert!(fuzzy_prefix_match(b"hllo", b"hello world", 1));
}
// A prefix needing two edits must be rejected when only one is allowed.
#[test]
fn fuzzy_two_edits_rejected_at_k1() {
assert!(!fuzzy_prefix_match(b"hxylo", b"hello world", 1));
}
// `max_errors == 0` takes the exact `starts_with` path.
#[test]
fn strict_prefix_rejects_substitution() {
assert!(!fuzzy_prefix_match(b"helo", b"hello world", 0));
assert!(fuzzy_prefix_match(b"hell", b"hello world", 0));
}
// Results come back highest score first.
#[test]
fn get_orders_by_descending_score() {
let mut d = SuggestionDict::new();
d.add(b"apple".to_vec(), 1.0, false, None);
d.add(b"apricot".to_vec(), 5.0, false, None);
d.add(b"avocado".to_vec(), 3.0, false, None);
let hits = d.get(b"a", 5, false, true, false);
let names: Vec<&[u8]> = hits.iter().map(|h| h.value.as_slice()).collect();
assert_eq!(names, vec![&b"apricot"[..], &b"avocado"[..], &b"apple"[..]]);
}
// Equal scores fall back to ascending byte order; the empty prefix matches all keys.
#[test]
fn get_breaks_score_ties_lexicographically() {
let mut d = SuggestionDict::new();
d.add(b"banana".to_vec(), 1.0, false, None);
d.add(b"apple".to_vec(), 1.0, false, None);
d.add(b"cherry".to_vec(), 1.0, false, None);
let hits = d.get(b"", 5, false, false, false);
let names: Vec<&[u8]> = hits.iter().map(|h| h.value.as_slice()).collect();
assert_eq!(names, vec![&b"apple"[..], &b"banana"[..], &b"cherry"[..]]);
}
// `del` reports whether the key was present; a second delete returns false.
#[test]
fn del_returns_presence() {
let mut d = SuggestionDict::new();
d.add(b"alpha".to_vec(), 1.0, false, None);
assert!(d.del(b"alpha"));
assert!(!d.del(b"alpha"));
assert_eq!(d.len(), 0);
}
// Generator: a 0..=8 byte string over the alphabet 'a'..='d'. The small
// alphabet makes shared prefixes and duplicate keys likely.
fn arb_suggestion(tc: &TestCase) -> Vec<u8> {
let len = tc.draw(gs::integers::<usize>().min_value(0).max_value(8));
let mut out = Vec::with_capacity(len);
for _ in 0..len {
let c = tc.draw(gs::integers::<u8>().min_value(b'a').max_value(b'd'));
out.push(c);
}
out
}
// Generator: a score in 0.0..=100.0 in steps of 0.1.
fn arb_score(tc: &TestCase) -> f64 {
let n = tc.draw(gs::integers::<i32>().min_value(0).max_value(1000));
f64::from(n) / 10.0
}
// Generator: up to 8 (suggestion, score) pairs; duplicate suggestions are possible.
fn arb_corpus(tc: &TestCase) -> Vec<(Vec<u8>, f64)> {
let n = tc.draw(gs::integers::<usize>().min_value(0).max_value(8));
let mut out = Vec::with_capacity(n);
for _ in 0..n {
out.push((arb_suggestion(tc), arb_score(tc)));
}
out
}
// Property: `len` equals the number of distinct suggestions added.
// NOTE(review): `test_cases = 256` presumably sets the number of generated
// cases per property — confirm against the hegel docs.
#[hegel::test(test_cases = 256)]
fn len_equals_unique_suggestion_count(tc: TestCase) {
let corpus = arb_corpus(&tc);
let mut d = SuggestionDict::new();
for (s, score) in &corpus {
d.add(s.clone(), *score, false, None);
}
let mut unique: std::collections::BTreeSet<Vec<u8>> = std::collections::BTreeSet::new();
for (s, _) in &corpus {
unique.insert(s.clone());
}
assert_eq!(d.len(), unique.len());
}
// Property: every non-fuzzy hit starts with the requested prefix.
#[hegel::test(test_cases = 256)]
fn strict_prefix_hits_carry_prefix(tc: TestCase) {
let corpus = arb_corpus(&tc);
let prefix = arb_suggestion(&tc);
let mut d = SuggestionDict::new();
for (s, score) in &corpus {
d.add(s.clone(), *score, false, None);
}
let hits = d.get(&prefix, 50, false, false, false);
for hit in &hits {
assert!(
hit.value.starts_with(&prefix),
"non-prefix hit {hit:?} for prefix {prefix:?}",
);
}
}
// Property: adjacent hits are in descending score order. Equal scores are in
// ascending byte order. Comparisons use an EPSILON tolerance.
#[hegel::test(test_cases = 256)]
fn get_results_are_sorted(tc: TestCase) {
let corpus = arb_corpus(&tc);
let mut d = SuggestionDict::new();
for (s, score) in &corpus {
d.add(s.clone(), *score, false, None);
}
let hits = d.get(b"", 100, false, true, false);
for w in hits.windows(2) {
let (a, b) = (&w[0], &w[1]);
let a_score = a.score.unwrap();
let b_score = b.score.unwrap();
assert!(
b_score <= a_score + f64::EPSILON,
"score order broken: {a:?} then {b:?}",
);
if (a_score - b_score).abs() <= f64::EPSILON {
assert!(a.value <= b.value, "lex tie-break broken at {a:?} {b:?}");
}
}
}
// Property: after deleting a key, a second delete returns false and the key
// no longer appears in prefix lookups.
#[hegel::test(test_cases = 256)]
fn del_is_idempotent(tc: TestCase) {
let corpus = arb_corpus(&tc);
let target = arb_suggestion(&tc);
let mut d = SuggestionDict::new();
for (s, score) in &corpus {
d.add(s.clone(), *score, false, None);
}
let _ = d.del(&target);
assert!(!d.del(&target));
let hits = d.get(&target, 50, false, false, false);
assert!(
hits.iter().all(|h| h.value != target),
"deleted entry leaked into hits: {hits:?}",
);
}
}