use std::borrow::Cow;
use unicode_normalization::UnicodeNormalization;
use unicode_normalization::char::is_combining_mark;
/// Quality of a match between a candidate string and a query, best to worst.
///
/// Different variants compare by tier:
/// `CaseSensitiveEqual` > `Equal` > `StartsWith` > `WordStartsWith` > `Contains` >
/// `Acronym` > `Matches(_)` > `NoMatch`.
///
/// Two `Matches` values compare by their `f64` sub-score instead, so equality and
/// ordering between them follow floating-point rules (a `NaN` sub-score is unequal to
/// everything and unordered against other `Matches`).
#[derive(Debug, Clone, Copy)]
pub enum Ranking {
    /// The candidate equals the query exactly, including case.
    CaseSensitiveEqual,
    /// The candidate equals the query once both are lowercased.
    Equal,
    /// The lowercased candidate begins with the lowercased query.
    StartsWith,
    /// An occurrence of the query directly follows a space in the candidate.
    WordStartsWith,
    /// The query occurs somewhere inside the candidate.
    Contains,
    /// The acronym built from the candidate's words contains the query.
    Acronym,
    /// The query's characters appear in order in the candidate; the payload scores
    /// how tightly they cluster (higher is tighter).
    Matches(f64),
    /// The candidate does not match the query.
    NoMatch,
}

impl Ranking {
    /// Tier number used when comparing different variants; larger means a better match.
    fn tier_value(&self) -> u8 {
        match *self {
            Self::NoMatch => 0,
            Self::Matches(_) => 1,
            Self::Acronym => 2,
            Self::Contains => 3,
            Self::WordStartsWith => 4,
            Self::StartsWith => 5,
            Self::Equal => 6,
            Self::CaseSensitiveEqual => 7,
        }
    }
}

impl PartialEq for Ranking {
    fn eq(&self, other: &Self) -> bool {
        // Two fuzzy matches are equal only when their sub-scores are.
        if let (Self::Matches(lhs), Self::Matches(rhs)) = (self, other) {
            return lhs == rhs;
        }
        self.tier_value() == other.tier_value()
    }
}

impl PartialOrd for Ranking {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        if let (Self::Matches(lhs), Self::Matches(rhs)) = (self, other) {
            // Sub-score comparison; `None` when either score is NaN.
            lhs.partial_cmp(rhs)
        } else {
            Some(self.tier_value().cmp(&other.tier_value()))
        }
    }
}
/// Scores how tightly the characters of `query` appear, in order, within `candidate`.
///
/// Each query character is matched greedily against the earliest remaining candidate
/// character; comparison is exact (case-sensitive). If some query character cannot be
/// found, the result is [`Ranking::NoMatch`].
///
/// Otherwise the result is [`Ranking::Matches`] with sub-score `1 + 1 / spread`, where
/// `spread` is the distance in characters between the first and last matched
/// positions. A spread of zero (empty or single-character query) scores `2.0`, so
/// sub-scores lie in `(1.0, 2.0]`.
pub fn get_closeness_ranking(candidate: &str, query: &str) -> Ranking {
    let mut remaining = candidate.chars().enumerate();
    // (first, last) character positions of the matched query characters so far.
    let mut bounds: Option<(usize, usize)> = None;
    for needle in query.chars() {
        let position = match remaining.find(|&(_, ch)| ch == needle) {
            Some((position, _)) => position,
            None => return Ranking::NoMatch,
        };
        bounds = Some(match bounds {
            Some((first, _)) => (first, position),
            None => (position, position),
        });
    }
    let spread = bounds.map_or(0, |(first, last)| last - first);
    let sub_score = if spread == 0 {
        2.0
    } else {
        1.0 + 1.0 / spread as f64
    };
    Ranking::Matches(sub_score)
}
/// Returns `true` for characters that separate words when building an acronym.
///
/// Only spaces and hyphens count; underscores and other punctuation do not.
fn is_acronym_delimiter(c: char) -> bool {
    c == ' ' || c == '-'
}

/// Builds an acronym from the first character of every word in `s`.
///
/// Words are separated by runs of spaces and/or hyphens (see `is_acronym_delimiter`).
/// Leading, trailing and repeated delimiters never contribute characters. For example,
/// `"north-west airlines"` yields `"nwa"` and `" -foo bar"` yields `"fb"`. An empty
/// input yields an empty string.
pub fn get_acronym(s: &str) -> String {
    // Acronyms are a handful of bytes; String's first allocation already covers the
    // typical case, so no pre-scan of `s` to size the buffer is needed.
    let mut acronym = String::new();
    // Treat the start of the string as if it followed a delimiter. The first word's
    // initial is then captured by the same rule as every other word, and a leading
    // delimiter is skipped instead of being copied into the acronym.
    let mut prev_was_delimiter = true;
    for c in s.chars() {
        let is_delimiter = is_acronym_delimiter(c);
        if prev_was_delimiter && !is_delimiter {
            acronym.push(c);
        }
        prev_was_delimiter = is_delimiter;
    }
    acronym
}
/// ASCII base letters for the Latin-1 Supplement letters U+00C0..=U+00FF.
///
/// Indexed by `code_point - 0xC0`, which equals the UTF-8 trail byte minus `0x80` when
/// the lead byte is `0xC3`. An entry of `0` means the character has no base-letter
/// substitute here and is kept as-is (e.g. Æ, Ð, ×, Ø, Þ, ß and their lowercase
/// counterparts, plus ÷).
const LATIN1_STRIP: [u8; 64] = [
    // U+00C0..=U+00C7: À Á Â Ã Ä Å Æ Ç
    b'A', b'A', b'A', b'A', b'A', b'A', 0, b'C',
    // U+00C8..=U+00CF: È É Ê Ë Ì Í Î Ï
    b'E', b'E', b'E', b'E', b'I', b'I', b'I', b'I',
    // U+00D0..=U+00D7: Ð Ñ Ò Ó Ô Õ Ö ×
    0, b'N', b'O', b'O', b'O', b'O', b'O', 0,
    // U+00D8..=U+00DF: Ø Ù Ú Û Ü Ý Þ ß
    0, b'U', b'U', b'U', b'U', b'Y', 0, 0,
    // U+00E0..=U+00E7: à á â ã ä å æ ç
    b'a', b'a', b'a', b'a', b'a', b'a', 0, b'c',
    // U+00E8..=U+00EF: è é ê ë ì í î ï
    b'e', b'e', b'e', b'e', b'i', b'i', b'i', b'i',
    // U+00F0..=U+00F7: ð ñ ò ó ô õ ö ÷
    0, b'n', b'o', b'o', b'o', b'o', b'o', 0,
    // U+00F8..=U+00FF: ø ù ú û ü ý þ ÿ
    0, b'u', b'u', b'u', b'u', b'y', 0, b'y',
];
fn strip_latin1_diacritics(s: &str) -> Option<Cow<'_, str>> {
let bytes = s.as_bytes();
let mut needs_strip = false;
let mut i = 0;
while i < bytes.len() {
let b = bytes[i];
if b < 0x80 {
i += 1;
} else if b == 0xC3 && i + 1 < bytes.len() {
let trail = bytes[i + 1];
let offset = trail.wrapping_sub(0x80);
if offset < 64 && LATIN1_STRIP[offset as usize] != 0 {
needs_strip = true;
}
i += 2;
} else if b == 0xC2 && i + 1 < bytes.len() {
i += 2;
} else {
return None;
}
}
if !needs_strip {
return Some(Cow::Borrowed(s));
}
let mut result = String::with_capacity(s.len());
let mut j = 0;
while j < bytes.len() {
let b = bytes[j];
if b < 0x80 {
result.push(b as char);
j += 1;
} else if b == 0xC3 && j + 1 < bytes.len() {
let trail = bytes[j + 1];
let offset = trail.wrapping_sub(0x80);
if offset < 64 {
let base = LATIN1_STRIP[offset as usize];
if base != 0 {
result.push(base as char);
} else {
result.push_str(
std::str::from_utf8(&bytes[j..j + 2]).expect("valid 2-byte UTF-8 sequence"),
);
}
}
j += 2;
} else {
result.push_str(
std::str::from_utf8(&bytes[j..j + 2]).expect("valid 2-byte UTF-8 sequence"),
);
j += 2;
}
}
Some(Cow::Owned(result))
}
pub fn prepare_value_for_comparison(s: &str, keep_diacritics: bool) -> Cow<'_, str> {
if keep_diacritics {
return Cow::Borrowed(s);
}
if s.is_ascii() {
return Cow::Borrowed(s);
}
if let Some(result) = strip_latin1_diacritics(s) {
return result;
}
let mut nfd = s.nfd();
let mut prefix_len: usize = 0;
loop {
match nfd.next() {
Some(c) if is_combining_mark(c) => break,
Some(_) => {
prefix_len += 1;
}
None => {
return Cow::Borrowed(s);
}
}
}
let mut result = String::with_capacity(s.len());
for (i, c) in s.nfd().enumerate() {
if i >= prefix_len {
break;
}
result.push(c);
}
for c in nfd {
if !is_combining_mark(c) {
result.push(c);
}
}
Cow::Owned(result)
}
pub(crate) struct PreparedQuery {
prepared: String,
pub(crate) lower: String,
char_count: usize,
}
impl PreparedQuery {
pub(crate) fn new(query: &str, keep_diacritics: bool) -> Self {
let prepared = prepare_value_for_comparison(query, keep_diacritics).into_owned();
let lower = prepared.to_lowercase();
let char_count = if lower.is_ascii() {
lower.len()
} else {
lower.chars().count()
};
Self {
prepared,
lower,
char_count,
}
}
}
/// Writes the lowercase form of `s` into `buf`, replacing its previous contents.
///
/// Lowercasing is done per `char` with [`char::to_lowercase`], with no context-sensitive
/// rules such as Greek final sigma. `buf`'s existing allocation is reused, and a string
/// that is already lowercase is copied verbatim.
fn lowercase_into(s: &str, buf: &mut String) {
    buf.clear();
    buf.reserve(s.len());
    if s.is_ascii() {
        if s.bytes().any(|b| b.is_ascii_uppercase()) {
            buf.extend(s.bytes().map(|b| char::from(b.to_ascii_lowercase())));
        } else {
            buf.push_str(s);
        }
    } else if s.chars().all(|c| c.to_lowercase().eq(std::iter::once(c))) {
        // Testing "lowercases to itself" rather than `!c.is_uppercase()` also catches
        // titlecase letters such as 'ǅ' (U+01C5). They are not uppercase, yet they still
        // lowercase to a different character ('ǆ').
        buf.push_str(s);
    } else {
        buf.extend(s.chars().flat_map(char::to_lowercase));
    }
}
/// Ranks `test_string` (the candidate) against a pre-processed query.
///
/// This is the reusable core of [`get_match_ranking`]. Callers that rank many
/// candidates against one query build the [`PreparedQuery`] and the `memmem` finder
/// once, and pass `candidate_buf` as scratch space for the lowercased candidate.
///
/// `finder` must search for `pq.lower` when the query is non-empty. It must be `None`
/// for an empty query, which every candidate is treated as starting with.
///
/// Tiers are tried best-first:
/// 1. case-sensitive equality;
/// 2. case-insensitive equality;
/// 3. prefix;
/// 4. an occurrence directly after a space;
/// 5. substring;
/// 6. acronym containment;
/// 7. finally the in-order fuzzy score from `get_closeness_ranking`.
pub(crate) fn get_match_ranking_prepared(
    test_string: &str,
    pq: &PreparedQuery,
    keep_diacritics: bool,
    candidate_buf: &mut String,
    finder: Option<&memchr::memmem::Finder<'_>>,
) -> Ranking {
    let candidate = prepare_value_for_comparison(test_string, keep_diacritics);
    if *candidate == *pq.prepared {
        return Ranking::CaseSensitiveEqual;
    }

    lowercase_into(&candidate, candidate_buf);

    // Length pre-check on the lowercased forms of both strings. Lowercasing can add
    // characters ('İ' becomes "i\u{307}"), so comparing the lowercased query's length
    // with the raw candidate's would reject real matches, even identical strings.
    // Only candidates shorter than the query are rejected, so lowercasing them first
    // stays cheap.
    let candidate_char_count = if candidate_buf.is_ascii() {
        candidate_buf.len()
    } else {
        candidate_buf.chars().count()
    };
    if pq.char_count > candidate_char_count {
        return Ranking::NoMatch;
    }

    let finder = match finder {
        Some(finder) => finder,
        // No finder means an empty query, which every candidate starts with.
        None if candidate_buf.is_empty() => return Ranking::Equal,
        None => return Ranking::StartsWith,
    };

    let haystack = candidate_buf.as_bytes();
    if let Some(first) = finder.find(haystack) {
        if first == 0 {
            return if haystack.len() == pq.lower.len() {
                Ranking::Equal
            } else {
                Ranking::StartsWith
            };
        }
        let needle = pq.lower.as_bytes();
        // An occurrence starts a word when it immediately follows a space. No space
        // before `first - 1` can precede an occurrence, so the scan starts there.
        // Testing after every space, instead of walking `find_iter`, also finds
        // occurrences that overlap an earlier match ("ab ab" inside "xab ab ab").
        // `find_iter` yields only non-overlapping matches and would skip them.
        let tail = &haystack[first - 1..];
        let starts_word = memchr::memchr_iter(b' ', tail)
            .any(|space| tail[space + 1..].starts_with(needle));
        return if starts_word {
            Ranking::WordStartsWith
        } else {
            Ranking::Contains
        };
    }

    // A single character that is not a substring cannot match any looser tier.
    if pq.char_count == 1 {
        return Ranking::NoMatch;
    }
    if get_acronym(candidate_buf.as_str()).contains(pq.lower.as_str()) {
        return Ranking::Acronym;
    }
    get_closeness_ranking(candidate_buf.as_str(), &pq.lower)
}
/// Ranks how well `test_string` (the candidate) matches `string_to_rank` (the query).
///
/// This is a one-shot convenience wrapper. It prepares the query, builds a substring
/// finder for it, and delegates to `get_match_ranking_prepared`. Diacritics are
/// stripped from both strings before comparison unless `keep_diacritics` is `true`.
pub fn get_match_ranking(
    test_string: &str,
    string_to_rank: &str,
    keep_diacritics: bool,
) -> Ranking {
    let query = PreparedQuery::new(string_to_rank, keep_diacritics);
    // An empty query has no needle to search for; the prepared ranking handles a
    // missing finder by reporting StartsWith (or Equal for an empty candidate).
    let finder = (!query.lower.is_empty())
        .then(|| memchr::memmem::Finder::new(query.lower.as_bytes()));
    let mut scratch = String::new();
    get_match_ranking_prepared(
        test_string,
        &query,
        keep_diacritics,
        &mut scratch,
        finder.as_ref(),
    )
}
// Unit tests for ranking order, acronyms, diacritic stripping, closeness scoring,
// end-to-end match ranking, and the lowercase helper.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn full_tier_ordering_descending() {
        assert!(Ranking::CaseSensitiveEqual > Ranking::Equal);
        assert!(Ranking::Equal > Ranking::StartsWith);
        assert!(Ranking::StartsWith > Ranking::WordStartsWith);
        assert!(Ranking::WordStartsWith > Ranking::Contains);
        assert!(Ranking::Contains > Ranking::Acronym);
        assert!(Ranking::Acronym > Ranking::Matches(1.5));
        assert!(Ranking::Matches(1.5) > Ranking::NoMatch);
    }

    #[test]
    fn matches_sub_score_ordering() {
        assert!(Ranking::Matches(1.9) > Ranking::Matches(1.1));
        assert!(Ranking::Matches(2.0) > Ranking::Matches(1.5));
        assert!(Ranking::Matches(1.5) > Ranking::Matches(1.01));
    }

    #[test]
    fn matches_below_acronym_above_no_match() {
        assert!(Ranking::Acronym > Ranking::Matches(2.0));
        assert!(Ranking::Matches(1.01) > Ranking::NoMatch);
    }

    #[test]
    fn equality_same_fixed_tiers() {
        assert_eq!(Ranking::CaseSensitiveEqual, Ranking::CaseSensitiveEqual);
        assert_eq!(Ranking::Equal, Ranking::Equal);
        assert_eq!(Ranking::NoMatch, Ranking::NoMatch);
    }

    #[test]
    fn equality_same_matches_sub_score() {
        assert_eq!(Ranking::Matches(1.5), Ranking::Matches(1.5));
    }

    #[test]
    fn inequality_different_tiers() {
        assert_ne!(Ranking::CaseSensitiveEqual, Ranking::Equal);
        assert_ne!(Ranking::Matches(1.5), Ranking::NoMatch);
    }

    #[test]
    fn inequality_different_sub_scores() {
        assert_ne!(Ranking::Matches(1.2), Ranking::Matches(1.8));
    }

    #[test]
    fn debug_formatting() {
        let debug_str = format!("{:?}", Ranking::Matches(1.5));
        assert!(debug_str.contains("Matches"));
        assert!(debug_str.contains("1.5"));
        let debug_str = format!("{:?}", Ranking::CaseSensitiveEqual);
        assert!(debug_str.contains("CaseSensitiveEqual"));
    }

    #[test]
    fn copy_produces_equal_value() {
        let original = Ranking::Matches(1.75);
        let copied = original;
        assert_eq!(original, copied);
        let original = Ranking::Contains;
        let copied = original;
        assert_eq!(original, copied);
    }

    #[test]
    fn matches_at_boundary_values() {
        assert!(Ranking::Acronym > Ranking::Matches(2.0));
        assert!(Ranking::Matches(1.001) > Ranking::NoMatch);
    }

    #[test]
    fn acronym_hyphen_and_space() {
        assert_eq!(get_acronym("north-west airlines"), "nwa");
    }

    #[test]
    fn acronym_space_only() {
        assert_eq!(get_acronym("san francisco"), "sf");
    }

    #[test]
    fn acronym_single_word() {
        assert_eq!(get_acronym("single"), "s");
    }

    #[test]
    fn acronym_empty_string() {
        assert_eq!(get_acronym(""), "");
    }

    #[test]
    fn acronym_underscores_not_delimiters() {
        assert_eq!(get_acronym("snake_case_word"), "s");
    }

    #[test]
    fn acronym_consecutive_spaces() {
        // Two spaces between the words: the second space must not start a new word.
        assert_eq!(get_acronym("hello  world"), "hw");
    }

    #[test]
    fn acronym_consecutive_hyphens() {
        assert_eq!(get_acronym("a--b"), "ab");
    }

    #[test]
    fn acronym_mixed_delimiters() {
        assert_eq!(get_acronym("one two-three four"), "ottf");
    }

    #[test]
    fn acronym_single_char() {
        assert_eq!(get_acronym("x"), "x");
    }

    #[test]
    fn acronym_trailing_delimiter() {
        assert_eq!(get_acronym("hello "), "h");
    }

    #[test]
    fn strips_combining_acute_accent() {
        let result = prepare_value_for_comparison("cafe\u{0301}", false);
        assert_eq!(result, "cafe");
        assert!(matches!(result, Cow::Owned(_)));
    }

    #[test]
    fn returns_borrowed_for_plain_ascii() {
        let result = prepare_value_for_comparison("cafe", false);
        assert_eq!(result, "cafe");
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    #[test]
    fn returns_borrowed_when_keep_diacritics_is_true() {
        let input = "cafe\u{0301}";
        let result = prepare_value_for_comparison(input, true);
        assert_eq!(result, input);
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    #[test]
    fn strips_precomposed_accent() {
        let result = prepare_value_for_comparison("caf\u{00E9}", false);
        assert_eq!(result, "cafe");
        assert!(matches!(result, Cow::Owned(_)));
    }

    #[test]
    fn strips_multiple_diacritics() {
        let result = prepare_value_for_comparison("\u{00FC}ber-ma\u{00F1}ana", false);
        assert_eq!(result, "uber-manana");
        assert!(matches!(result, Cow::Owned(_)));
    }

    #[test]
    fn returns_borrowed_for_empty_string() {
        let result = prepare_value_for_comparison("", false);
        assert_eq!(result, "");
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    #[test]
    fn returns_borrowed_for_non_ascii_without_diacritics() {
        let result = prepare_value_for_comparison("\u{4e16}\u{754c}", false);
        assert_eq!(result, "\u{4e16}\u{754c}");
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    #[test]
    fn keep_diacritics_true_with_plain_ascii() {
        let result = prepare_value_for_comparison("hello", true);
        assert_eq!(result, "hello");
        assert!(matches!(result, Cow::Borrowed(_)));
    }

    #[test]
    fn strips_combining_tilde() {
        let result = prepare_value_for_comparison("n\u{0303}", false);
        assert_eq!(result, "n");
        assert!(matches!(result, Cow::Owned(_)));
    }

    #[test]
    fn strips_multiple_combining_marks_on_single_base() {
        let result = prepare_value_for_comparison("a\u{0300}\u{0301}", false);
        assert_eq!(result, "a");
        assert!(matches!(result, Cow::Owned(_)));
    }

    #[test]
    fn preserves_o_stroke_unchanged() {
        let result = prepare_value_for_comparison("\u{00F8}slo", false);
        assert_eq!(result, "\u{00F8}slo");
        assert!(
            matches!(result, Cow::Borrowed(_)),
            "o-stroke string should be returned borrowed (no allocation)"
        );
    }

    #[test]
    fn closeness_fuzzy_match_playground() {
        let rank = get_closeness_ranking("playground", "plgnd");
        match rank {
            Ranking::Matches(s) => {
                assert!(s > 1.0, "sub-score {s} should be > 1.0");
                assert!(s < 2.0, "sub-score {s} should be < 2.0");
                let expected = 1.0 + 1.0 / 9.0;
                assert!(
                    (s - expected).abs() < f64::EPSILON,
                    "expected {expected}, got {s}"
                );
            }
            other => panic!("expected Matches, got {other:?}"),
        }
    }

    #[test]
    fn closeness_no_match() {
        assert_eq!(get_closeness_ranking("abc", "xyz"), Ranking::NoMatch);
    }

    #[test]
    fn closeness_single_char_match() {
        assert_eq!(get_closeness_ranking("ab", "a"), Ranking::Matches(2.0));
    }

    #[test]
    fn closeness_single_char_not_found() {
        assert_eq!(get_closeness_ranking("ab", "z"), Ranking::NoMatch);
    }

    #[test]
    fn closeness_adjacent_chars() {
        let rank = get_closeness_ranking("abcdef", "abc");
        assert_eq!(rank, Ranking::Matches(1.5));
    }

    #[test]
    fn closeness_two_char_query() {
        let rank = get_closeness_ranking("abcdef", "ad");
        match rank {
            Ranking::Matches(s) => {
                let expected = 1.0 + 1.0 / 3.0;
                assert!(
                    (s - expected).abs() < f64::EPSILON,
                    "expected {expected}, got {s}"
                );
            }
            other => panic!("expected Matches, got {other:?}"),
        }
    }

    #[test]
    fn closeness_partial_mismatch() {
        assert_eq!(get_closeness_ranking("abcdef", "az"), Ranking::NoMatch);
    }

    #[test]
    fn closeness_query_longer_than_candidate() {
        assert_eq!(get_closeness_ranking("ab", "abcdef"), Ranking::NoMatch);
    }

    #[test]
    fn closeness_result_always_in_range() {
        let cases = [
            ("abcdefghijklmnop", "ap"),
            ("abcdefghijklmnop", "abop"),
            ("abcdef", "af"),
            ("ab", "ab"),
        ];
        for (candidate, query) in cases {
            let rank = get_closeness_ranking(candidate, query);
            match rank {
                Ranking::Matches(s) => {
                    assert!(
                        s > 1.0 && s <= 2.0,
                        "score {s} out of range for ({candidate}, {query})"
                    );
                }
                other => panic!("expected Matches for ({candidate}, {query}), got {other:?}"),
            }
        }
    }

    #[test]
    fn closeness_case_sensitive() {
        assert_eq!(get_closeness_ranking("abc", "A"), Ranking::NoMatch);
    }

    #[test]
    fn closeness_empty_query() {
        assert_eq!(get_closeness_ranking("anything", ""), Ranking::Matches(2.0));
    }

    #[test]
    fn closeness_unicode_chars() {
        let rank = get_closeness_ranking("a\u{00E9}c", "ac");
        assert_eq!(rank, Ranking::Matches(1.5));
    }

    #[test]
    fn ranking_equal() {
        assert_eq!(get_match_ranking("Green", "green", false), Ranking::Equal);
    }

    #[test]
    fn ranking_case_sensitive_equal() {
        assert_eq!(
            get_match_ranking("Green", "Green", false),
            Ranking::CaseSensitiveEqual
        );
    }

    #[test]
    fn ranking_starts_with() {
        assert_eq!(
            get_match_ranking("Greenland", "green", false),
            Ranking::StartsWith
        );
    }

    #[test]
    fn ranking_word_starts_with() {
        assert_eq!(
            get_match_ranking("San Francisco", "fran", false),
            Ranking::WordStartsWith
        );
    }

    #[test]
    fn ranking_contains() {
        assert_eq!(get_match_ranking("abcdef", "cde", false), Ranking::Contains);
    }

    #[test]
    fn ranking_acronym() {
        assert_eq!(
            get_match_ranking("North-West Airlines", "nwa", false),
            Ranking::Acronym
        );
    }

    #[test]
    fn ranking_fuzzy_matches() {
        let rank = get_match_ranking("playground", "plgnd", false);
        match rank {
            Ranking::Matches(s) => {
                assert!(s > 1.0, "sub-score {s} should be > 1.0");
                assert!(s < 2.0, "sub-score {s} should be < 2.0");
            }
            other => panic!("expected Matches, got {other:?}"),
        }
    }

    #[test]
    fn ranking_no_match() {
        assert_eq!(get_match_ranking("abc", "xyz", false), Ranking::NoMatch);
    }

    #[test]
    fn ranking_query_longer_than_candidate() {
        assert_eq!(get_match_ranking("ab", "abcdef", false), Ranking::NoMatch);
    }

    #[test]
    fn ranking_single_char_not_substring() {
        assert_eq!(get_match_ranking("abcdef", "z", false), Ranking::NoMatch);
    }

    #[test]
    fn ranking_single_char_substring_found() {
        assert_eq!(get_match_ranking("abcdef", "a", false), Ranking::StartsWith);
    }

    #[test]
    fn ranking_single_char_equal() {
        assert_eq!(
            get_match_ranking("a", "a", false),
            Ranking::CaseSensitiveEqual
        );
    }

    #[test]
    fn ranking_empty_query() {
        assert_eq!(
            get_match_ranking("anything", "", false),
            Ranking::StartsWith
        );
    }

    #[test]
    fn ranking_both_empty() {
        assert_eq!(
            get_match_ranking("", "", false),
            Ranking::CaseSensitiveEqual
        );
    }

    #[test]
    fn ranking_word_boundary_only_spaces() {
        assert_eq!(
            get_match_ranking("North-West", "west", false),
            Ranking::Contains
        );
    }

    #[test]
    fn ranking_word_boundary_second_occurrence() {
        assert_eq!(
            get_match_ranking("xfoo bar foo", "foo", false),
            Ranking::WordStartsWith
        );
    }

    #[test]
    fn ranking_diacritics_stripping() {
        assert_eq!(
            get_match_ranking("caf\u{00E9}", "cafe", false),
            Ranking::CaseSensitiveEqual
        );
    }

    #[test]
    fn ranking_diacritics_kept() {
        assert_eq!(
            get_match_ranking("caf\u{00E9}", "cafe", true),
            Ranking::NoMatch
        );
    }

    #[test]
    fn ranking_unicode_char_count_vs_byte_count() {
        assert_eq!(get_match_ranking("\u{00E9}", "ab", true), Ranking::NoMatch);
    }

    #[test]
    fn ranking_acronym_not_reached_for_single_char() {
        assert_eq!(get_match_ranking("a b c", "x", false), Ranking::NoMatch);
    }

    #[test]
    fn ranking_acronym_multi_word() {
        assert_eq!(
            get_match_ranking("as soon as possible", "asap", false),
            Ranking::Acronym
        );
    }

    #[test]
    fn ranking_contains_mid_string() {
        assert_eq!(
            get_match_ranking("hello world", "lo w", false),
            Ranking::Contains
        );
    }

    #[test]
    fn ranking_query_longer_than_candidate_unicode() {
        assert_eq!(
            get_match_ranking("\u{4e16}\u{754c}", "abc", false),
            Ranking::NoMatch
        );
    }

    #[test]
    fn lowercase_into_already_lowercase_ascii() {
        let mut buf = String::new();
        lowercase_into("hello world", &mut buf);
        assert_eq!(buf, "hello world");
    }

    #[test]
    fn lowercase_into_already_lowercase_ascii_no_realloc() {
        let mut buf = String::new();
        lowercase_into("hello world", &mut buf);
        assert_eq!(buf, "hello world");
        let ptr_before = buf.as_ptr();
        let cap_before = buf.capacity();
        lowercase_into("hello world", &mut buf);
        assert_eq!(buf, "hello world");
        assert_eq!(buf.as_ptr(), ptr_before);
        assert_eq!(buf.capacity(), cap_before);
    }

    #[test]
    fn lowercase_into_mixed_case_ascii() {
        let mut buf = String::new();
        lowercase_into("Hello World", &mut buf);
        assert_eq!(buf, "hello world");
    }

    #[test]
    fn lowercase_into_all_uppercase_ascii() {
        let mut buf = String::new();
        lowercase_into("HELLO WORLD", &mut buf);
        assert_eq!(buf, "hello world");
    }

    #[test]
    fn lowercase_into_already_lowercase_non_ascii() {
        // Must contain a non-ASCII character to exercise the non-ASCII branch.
        let mut buf = String::new();
        lowercase_into("caf\u{00E9}", &mut buf);
        assert_eq!(buf, "caf\u{00E9}");
    }

    #[test]
    fn lowercase_into_non_ascii_with_uppercase() {
        // Must contain a non-ASCII character to exercise the non-ASCII branch.
        let mut buf = String::new();
        lowercase_into("Universit\u{00E4}t", &mut buf);
        assert_eq!(buf, "universit\u{00E4}t");
    }

    #[test]
    fn lowercase_into_empty_string() {
        let mut buf = String::new();
        lowercase_into("", &mut buf);
        assert_eq!(buf, "");
    }

    #[test]
    fn lowercase_into_clears_previous_contents() {
        let mut buf = String::from("leftover data");
        lowercase_into("new", &mut buf);
        assert_eq!(buf, "new");
    }

    #[test]
    fn lowercase_into_non_ascii_already_lowercase_cjk() {
        let mut buf = String::new();
        lowercase_into("\u{4e16}\u{754c}", &mut buf);
        assert_eq!(buf, "\u{4e16}\u{754c}");
    }

    #[test]
    fn lowercase_into_non_ascii_mixed_case_with_accent() {
        let mut buf = String::new();
        lowercase_into("Caf\u{00C9}", &mut buf);
        assert_eq!(buf, "caf\u{00E9}");
    }
}