use crate::Confidence;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
pub use super::script::Script;
/// Removes the BOM and Unicode bidirectional control characters from `s`.
///
/// Returns `Cow::Borrowed` when nothing needs stripping, allocating a new
/// `String` only when at least one such character is present.
fn strip_bom_and_bidi_controls(s: &str) -> Cow<'_, str> {
    // BOM/ZWNBSP, ALM, LRM/RLM, the embedding/override controls, and the
    // isolate controls.
    fn is_invisible_control(ch: char) -> bool {
        matches!(
            ch,
            '\u{FEFF}'
                | '\u{061C}'
                | '\u{200E}'
                | '\u{200F}'
                | '\u{202A}'..='\u{202E}'
                | '\u{2066}'..='\u{2069}'
        )
    }
    match s.chars().any(is_invisible_control) {
        false => Cow::Borrowed(s),
        true => Cow::Owned(s.chars().filter(|&ch| !is_invisible_control(ch)).collect()),
    }
}
/// Canonicalizes a string for comparison.
///
/// Strips the BOM and bidi controls, then scrubs via `textprep` with the
/// configured options: NFC normalization, lowercasing, and whitespace
/// collapsing; diacritics are deliberately preserved.
pub fn normalize(s: &str) -> String {
    let cleaned = strip_bom_and_bidi_controls(s);
    let scrub = textprep::ScrubConfig {
        case: textprep::ScrubCase::Lower,
        normalization: textprep::ScrubNormalization::Nfc,
        collapse_whitespace: true,
        strip_diacritics: false,
        ..textprep::ScrubConfig::default()
    };
    textprep::scrub_with(cleaned.as_ref(), &scrub)
}
/// Tuning knobs for [`Similarity`] scoring.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityConfig {
    // When true, both inputs are passed through `normalize` before scoring.
    pub normalize: bool,
    // NOTE(review): not read anywhere in this file — presumably a minimum
    // input length for n-gram scoring; confirm intended use before relying on it.
    pub min_ngram_length: usize,
    // Character n-gram size used for scripts without word boundaries.
    pub ngram_size: usize,
}
impl Default for SimilarityConfig {
    /// Defaults: normalization on, bigram comparison.
    fn default() -> Self {
        Self {
            // Normalize by default so case/whitespace noise doesn't dominate.
            normalize: true,
            min_ngram_length: 2,
            // Bigrams — a common choice for character-level CJK comparison.
            ngram_size: 2,
        }
    }
}
/// Script-aware string similarity scorer.
///
/// Picks word-level Jaccard for scripts with word boundaries and character
/// n-gram Jaccard otherwise.
#[derive(Debug, Clone, Default)]
pub struct Similarity {
    // Scoring configuration; see `SimilarityConfig`.
    config: SimilarityConfig,
}
impl Similarity {
    /// Creates a scorer with the default [`SimilarityConfig`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a scorer using an explicit configuration.
    pub fn with_config(config: SimilarityConfig) -> Self {
        Self { config }
    }

    /// Scores `a` against `b` in [0, 1].
    ///
    /// Identical strings (raw or after normalization) score 1.0 and a single
    /// empty side scores 0.0. Otherwise the metric is chosen by script:
    /// word-level Jaccard when both sides use word-delimited scripts,
    /// character n-gram Jaccard otherwise.
    pub fn compute(&self, a: &str, b: &str) -> f32 {
        if a == b {
            return 1.0;
        }
        if a.is_empty() || b.is_empty() {
            // `a == b` is handled above, so at most one side can be empty
            // here; the symmetric form is kept for readability.
            return if a.is_empty() && b.is_empty() { 1.0 } else { 0.0 };
        }
        let (left, right) = match self.config.normalize {
            true => (normalize(a), normalize(b)),
            false => (a.to_string(), b.to_string()),
        };
        if left == right {
            return 1.0;
        }
        let word_based = Script::detect(&left).has_word_boundaries()
            && Script::detect(&right).has_word_boundaries();
        if word_based {
            self.word_jaccard(&left, &right)
        } else {
            self.ngram_jaccard(&left, &right)
        }
    }

    /// Word-level Jaccard overlap (delegates to `textprep`).
    fn word_jaccard(&self, a: &str, b: &str) -> f32 {
        textprep::similarity::word_jaccard(a, b) as f32
    }

    /// Character n-gram Jaccard overlap using `config.ngram_size`.
    fn ngram_jaccard(&self, a: &str, b: &str) -> f32 {
        textprep::similarity::char_ngram_jaccard(a, b, self.config.ngram_size) as f32
    }
}
/// Computes the Levenshtein edit distance between `a` and `b`, counted in
/// Unicode scalar values (chars).
///
/// Uses a single-row dynamic program: `row[j]` holds the distance between
/// the processed prefix of `a` and the first `j` chars of `b`.
pub fn levenshtein_distance(a: &str, b: &str) -> usize {
    let source: Vec<char> = a.chars().collect();
    let target: Vec<char> = b.chars().collect();
    if source.is_empty() {
        return target.len();
    }
    if target.is_empty() {
        return source.len();
    }
    let mut row: Vec<usize> = (0..=target.len()).collect();
    for (i, &sc) in source.iter().enumerate() {
        // `diag` carries the previous row's value one column to the left.
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &tc) in target.iter().enumerate() {
            let substitution = diag + usize::from(sc != tc);
            diag = row[j + 1];
            row[j + 1] = substitution.min(row[j] + 1).min(diag + 1);
        }
    }
    row[target.len()]
}

/// Normalized Levenshtein similarity in [0, 1]: one minus the distance
/// divided by the longer string's char count. Equal strings (including two
/// empty strings) score 1.0.
pub fn levenshtein_similarity(a: &str, b: &str) -> f32 {
    if a == b {
        return 1.0;
    }
    let longest = a.chars().count().max(b.chars().count());
    // `a != b` implies at least one side is non-empty, so `longest > 0` here;
    // the guard keeps the division obviously safe.
    if longest == 0 {
        return 1.0;
    }
    1.0 - levenshtein_distance(a, b) as f32 / longest as f32
}
/// Jaro similarity in [0, 1], computed over Unicode scalar values.
///
/// Characters match when equal and within the standard search window of
/// `max(len)/2 - 1`. Transpositions are matched characters that appear in a
/// different order; Jaro's definition uses *half* the count of out-of-order
/// matches, which may be fractional, so it is kept as a float.
pub fn jaro_similarity(a: &str, b: &str) -> f32 {
    if a == b {
        return 1.0;
    }
    let a_chars: Vec<char> = a.chars().collect();
    let b_chars: Vec<char> = b.chars().collect();
    let a_len = a_chars.len();
    let b_len = b_chars.len();
    if a_len == 0 || b_len == 0 {
        return 0.0;
    }
    // Maximum index distance at which two equal chars still count as a match.
    let match_distance = (a_len.max(b_len) / 2).saturating_sub(1);
    let mut a_matches = vec![false; a_len];
    let mut b_matches = vec![false; b_len];
    let mut matches = 0usize;
    for i in 0..a_len {
        let start = i.saturating_sub(match_distance);
        let end = (i + match_distance + 1).min(b_len);
        for j in start..end {
            if !b_matches[j] && a_chars[i] == b_chars[j] {
                a_matches[i] = true;
                b_matches[j] = true;
                matches += 1;
                break;
            }
        }
    }
    if matches == 0 {
        return 0.0;
    }
    // Walk both matched sequences in order and count positions where they
    // disagree (out-of-order matches).
    let mut transpositions = 0usize;
    let mut k = 0;
    for i in 0..a_len {
        if !a_matches[i] {
            continue;
        }
        while !b_matches[k] {
            k += 1;
        }
        if a_chars[i] != b_chars[k] {
            transpositions += 1;
        }
        k += 1;
    }
    let m = matches as f32;
    // BUGFIX: use float division. The previous integer `transpositions / 2`
    // floored away the half-transposition whenever the out-of-order count was
    // odd (e.g. "abcdef" vs "bcadef" has 3 out-of-order matches → t = 1.5,
    // not 1), inflating the score relative to the standard definition.
    let t = transpositions as f32 / 2.0;
    (m / a_len as f32 + m / b_len as f32 + (m - t) / m) / 3.0
}

/// Jaro–Winkler similarity: the Jaro score boosted by up to four chars of
/// common prefix with Winkler's standard scaling factor p = 0.1.
pub fn jaro_winkler_similarity(a: &str, b: &str) -> f32 {
    let jaro = jaro_similarity(a, b);
    // Common-prefix length capped at 4, computed lazily over the char
    // iterators (no need to collect both strings into Vecs first).
    let prefix_len = a
        .chars()
        .zip(b.chars())
        .take(4)
        .take_while(|(ca, cb)| ca == cb)
        .count();
    const P: f32 = 0.1;
    jaro + prefix_len as f32 * P * (1.0 - jaro)
}
/// Convenience wrapper: scores `a` vs `b` with a default-configured [`Similarity`].
pub fn multilingual_similarity(a: &str, b: &str) -> f32 {
    Similarity::new().compute(a, b)
}
pub fn cross_lingual_similarity(a: &str, b: &str) -> f32 {
let script_a = Script::detect(a);
let script_b = Script::detect(b);
if script_a == script_b && script_a != Script::Mixed {
return multilingual_similarity(a, b);
}
let known_pairs: &[(&str, &str)] = &[
("moscow", "москва"),
("saint petersburg", "санкт-петербург"),
("kiev", "киев"),
("beijing", "北京"),
("shanghai", "上海"),
("guangzhou", "广州"),
("shenzhen", "深圳"),
("tokyo", "東京"),
("東京", "tokyo"), ("osaka", "大阪"),
("kyoto", "京都"),
("cairo", "القاهرة"),
("riyadh", "الرياض"),
("dubai", "دبي"),
("putin", "путин"),
("xi jinping", "习近平"),
("abe", "安倍"),
];
let a_norm = normalize(a);
let b_norm = normalize(b);
let base_a = a
.split('(')
.next()
.and_then(|s| s.split('(').next())
.map(|s| s.trim());
let base_b = b
.split('(')
.next()
.and_then(|s| s.split('(').next())
.map(|s| s.trim());
if let Some(ba) = base_a {
for (pair_a, pair_b) in known_pairs {
if (ba == *pair_b && b_norm == *pair_a) || (ba == *pair_a && b_norm == *pair_b) {
return 0.85;
}
}
}
if let Some(bb) = base_b {
let bb_norm = normalize(bb);
for (pair_a, pair_b) in known_pairs {
let a_matches_a =
a_norm == *pair_a || a == *pair_a || a_norm.contains(pair_a) || a.contains(pair_a);
let a_matches_b =
a_norm == *pair_b || a == *pair_b || a_norm.contains(pair_b) || a.contains(pair_b);
let bb_matches_a = bb_norm == *pair_a
|| bb == *pair_a
|| bb_norm.contains(pair_a)
|| bb.contains(pair_a);
let bb_matches_b = bb_norm == *pair_b
|| bb == *pair_b
|| bb_norm.contains(pair_b)
|| bb.contains(pair_b);
if (a_matches_a && bb_matches_b) || (a_matches_b && bb_matches_a) {
return 0.85;
}
}
}
for (pair_a, pair_b) in known_pairs {
let mut a_has_a = a_norm.contains(pair_a) || a.contains(pair_a);
let mut a_has_b = a_norm.contains(pair_b) || a.contains(pair_b);
let mut b_has_a = b_norm.contains(pair_a) || b.contains(pair_a);
let mut b_has_b = b_norm.contains(pair_b) || b.contains(pair_b);
if let Some(ba) = base_a {
let ba_norm = normalize(ba);
a_has_a = a_has_a || ba_norm.contains(pair_a) || ba.contains(pair_a);
a_has_b = a_has_b || ba_norm.contains(pair_b) || ba.contains(pair_b);
}
if let Some(bb) = base_b {
let bb_norm = normalize(bb);
b_has_a = b_has_a || bb_norm.contains(pair_a) || bb.contains(pair_a);
b_has_b = b_has_b || bb_norm.contains(pair_b) || bb.contains(pair_b);
}
if (a_has_a && b_has_b) || (a_has_b && b_has_a) {
return 0.85; }
if let (Some(ba), Some(bb)) = (base_a, base_b) {
let ba_norm = normalize(ba);
let bb_norm = normalize(bb);
if (ba_norm == *pair_a && bb_norm == *pair_b)
|| (ba_norm == *pair_b && bb_norm == *pair_a)
{
return 0.9;
}
if (ba == *pair_a && bb == *pair_b) || (ba == *pair_b && bb == *pair_a) {
return 0.9;
}
}
}
let has_translit_pattern = |text: &str| -> bool {
text.contains('(') && text.contains(')')
|| text.contains('/')
|| text.contains('(') && text.contains(')') };
if has_translit_pattern(a) || has_translit_pattern(b) {
let extract_variants = |text: &str| -> Vec<String> {
let mut variants = Vec::new();
variants.push(text.to_string());
let base_text = text
.split('(')
.next()
.and_then(|s| s.split('(').next())
.and_then(|s| s.split('/').next())
.map(|s| s.trim().to_string());
if let Some(base) = base_text {
if !base.is_empty() && base != text {
variants.push(base);
}
}
if let Some(start) = text.find('(') {
let after_start = &text[start..];
if let Some(end_offset) = after_start.find(')') {
let content = &text[start + 1..start + end_offset];
let content = content.trim();
if !content.is_empty() {
variants.push(content.to_string());
}
}
}
if let Some(start) = text.find('(') {
if let Some(end) = text[start..].find(')') {
let content = text[start + 1..start + end].trim();
if !content.is_empty() {
variants.push(content.to_string());
}
}
}
if let Some(slash_pos) = text.find('/') {
let after_slash = text[slash_pos + 1..].trim();
if !after_slash.is_empty() {
variants.push(after_slash.to_string());
}
}
variants
};
let variants_a = extract_variants(a);
let variants_b = extract_variants(b);
let b_norm = normalize(b);
let a_norm = normalize(a);
for va in &variants_a {
if va == b {
return 0.85;
}
let va_norm = normalize(va);
if va_norm == b_norm {
return 0.85;
}
if va.eq_ignore_ascii_case(b) {
return 0.85;
}
}
for vb in &variants_b {
if vb == a {
return 0.85;
}
let vb_norm = normalize(vb);
if vb_norm == a_norm || vb.eq_ignore_ascii_case(a) {
return 0.85;
}
}
for va in &variants_a {
let va_norm = normalize(va);
for vb in &variants_b {
if va == vb {
return 0.9;
}
let vb_norm = normalize(vb);
if va_norm == vb_norm || va.eq_ignore_ascii_case(vb) {
return 0.9;
}
}
}
for va in &variants_a {
for vb in &variants_b {
let va_norm = normalize(va);
let vb_norm = normalize(vb);
if va_norm == vb_norm || va == vb {
return 0.9;
}
}
}
}
let base_sim = multilingual_similarity(a, b);
if base_sim > 0.3 {
base_sim * 1.2
} else {
base_sim
}
.min(1.0)
}
/// Heuristic test for whether one string is an acronym of the other.
///
/// The shorter string is treated as the candidate acronym. It must be 2–10
/// chars long with at least half of them uppercase and half alphabetic, and
/// it must equal (ASCII-case-insensitively) the alphabetic initials of the
/// longer string's words, split on whitespace and '-'.
pub fn is_acronym_match(a: &str, b: &str) -> bool {
    let (candidate, expansion) = match a.chars().count() < b.chars().count() {
        true => (a, b),
        false => (b, a),
    };
    let len = candidate.chars().count();
    if !(2..=10).contains(&len) {
        return false;
    }
    let uppercase = candidate.chars().filter(|c| c.is_uppercase()).count();
    let alphabetic = candidate.chars().filter(|c| c.is_alphabetic()).count();
    if uppercase < len / 2 || alphabetic < len / 2 {
        return false;
    }
    // First alphabetic char of each word; empty words contribute nothing
    // because `chars().next()` yields None for them.
    let initials: String = expansion
        .split(|c: char| c.is_whitespace() || c == '-')
        .filter_map(|word| word.chars().next())
        .filter(|c| c.is_alphabetic())
        .collect();
    initials.eq_ignore_ascii_case(candidate)
}
/// A provider of synonym lookups keyed to a canonical identifier.
///
/// `Send + Sync` so implementations can be shared across threads.
pub trait SynonymSource: Send + Sync {
    /// Resolves a term to its canonical entry, or `None` when unknown.
    fn lookup(&self, term: &str) -> Option<SynonymMatch>;
    /// Returns a match when both terms resolve to the same canonical id.
    ///
    /// The combined confidence is the arithmetic mean of the two lookups'
    /// confidences; the source label is taken from the first term's match.
    fn are_synonyms(&self, a: &str, b: &str) -> Option<SynonymMatch> {
        let match_a = self.lookup(a)?;
        let match_b = self.lookup(b)?;
        if match_a.canonical_id == match_b.canonical_id {
            Some(SynonymMatch {
                canonical_id: match_a.canonical_id,
                confidence: Confidence::new((match_a.confidence + match_b.confidence) / 2.0),
                source: match_a.source,
            })
        } else {
            None
        }
    }
    /// Human-readable name identifying this source.
    fn source_name(&self) -> &str;
}
/// Result of a successful synonym lookup.
#[derive(Debug, Clone)]
pub struct SynonymMatch {
    // Canonical identifier the looked-up term resolves to.
    pub canonical_id: String,
    // Confidence attached to this resolution.
    pub confidence: Confidence,
    // Name of the source that produced the match (see `SynonymSource::source_name`).
    pub source: String,
}
/// Null object: a [`SynonymSource`] that never matches anything.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoSynonyms;
impl SynonymSource for NoSynonyms {
    /// Always `None` — this source knows no synonyms.
    fn lookup(&self, _term: &str) -> Option<SynonymMatch> {
        None
    }
    fn source_name(&self) -> &str {
        "none"
    }
}
/// Composite [`SynonymSource`] that queries a list of sources in order.
#[derive(Default)]
pub struct ChainedSynonyms {
    // Queried front-to-back; the first hit wins (see the `SynonymSource` impl).
    sources: Vec<Box<dyn SynonymSource>>,
}
impl ChainedSynonyms {
    /// Creates an empty chain (every lookup returns `None` until sources are added).
    pub fn new() -> Self {
        Self::default()
    }
    /// Builder-style: appends `source` to the end of the chain.
    pub fn with_source<S: SynonymSource + 'static>(mut self, source: S) -> Self {
        self.sources.push(Box::new(source));
        self
    }
}
impl SynonymSource for ChainedSynonyms {
    /// Returns the first successful lookup among the chained sources, in
    /// registration order.
    fn lookup(&self, term: &str) -> Option<SynonymMatch> {
        self.sources.iter().find_map(|source| source.lookup(term))
    }

    fn source_name(&self) -> &str {
        "chained"
    }
}
// Unit tests: normalization, the script-aware scorer, and the classic
// edit-distance / Jaro metrics.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_normalize() {
        // Lowercasing plus whitespace collapsing/trimming.
        assert_eq!(normalize(" Hello World "), "hello world");
        assert_eq!(normalize("UPPERCASE"), "uppercase");
    }
    #[test]
    fn test_similarity_identical() {
        let sim = Similarity::new();
        // Identical inputs short-circuit to 1.0 regardless of script.
        assert_eq!(sim.compute("test", "test"), 1.0);
        assert_eq!(sim.compute("北京", "北京"), 1.0);
    }
    #[test]
    fn test_similarity_cjk_partial() {
        let sim = Similarity::new();
        // CJK (no word boundaries) goes through character n-gram Jaccard.
        let score = sim.compute("中华人民共和国", "中华民国");
        assert!(score > 0.0 && score < 1.0, "CJK partial: {}", score);
    }
    #[test]
    fn test_similarity_english_words() {
        let sim = Similarity::new();
        // Latin script goes through word-level Jaccard.
        let score = sim.compute("Marie Curie", "Curie");
        assert!(score > 0.0 && score < 1.0, "English partial: {}", score);
    }
    #[test]
    fn test_levenshtein() {
        assert_eq!(levenshtein_distance("kitten", "sitting"), 3);
        assert_eq!(levenshtein_distance("", "test"), 4);
        assert_eq!(levenshtein_distance("same", "same"), 0);
    }
    #[test]
    fn test_levenshtein_similarity() {
        assert_eq!(levenshtein_similarity("same", "same"), 1.0);
        // kitten/sitting: distance 3 over length 7.
        let sim = levenshtein_similarity("kitten", "sitting");
        assert!(sim > 0.5 && sim < 1.0);
    }
    #[test]
    fn test_jaro_winkler() {
        // Classic textbook pairs for Jaro-Winkler.
        let sim = jaro_winkler_similarity("MARTHA", "MARHTA");
        assert!(sim > 0.9, "Jaro-Winkler for similar strings: {}", sim);
        let sim = jaro_winkler_similarity("DWAYNE", "DUANE");
        assert!(sim > 0.8, "Jaro-Winkler: {}", sim);
    }
    #[test]
    fn test_multilingual_api() {
        // Free-function wrapper around Similarity::compute.
        let sim = multilingual_similarity("test", "test");
        assert_eq!(sim, 1.0);
    }
}
// Property-based invariants for the similarity metrics, plus conventional
// unit tests for acronym matching, synonym sources, and cross-lingual scoring.
#[cfg(test)]
mod proptests {
    use super::*;
    use proptest::prelude::*;
    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]
        // compute(a, b) == compute(b, a) within float tolerance.
        #[test]
        fn similarity_symmetric(a in "\\PC{1,30}", b in "\\PC{1,30}") {
            let sim = Similarity::new();
            let ab = sim.compute(&a, &b);
            let ba = sim.compute(&b, &a);
            prop_assert!((ab - ba).abs() < 0.001,
                "Symmetry: {} vs {}", ab, ba);
        }
        // Scores always stay in [0, 1].
        #[test]
        fn similarity_bounded(a in "\\PC{0,50}", b in "\\PC{0,50}") {
            let sim = Similarity::new();
            let score = sim.compute(&a, &b);
            prop_assert!((0.0..=1.0).contains(&score),
                "Bounds: {}", score);
        }
        // A string is always maximally similar to itself.
        #[test]
        fn similarity_identity(s in "\\PC{0,50}") {
            let sim = Similarity::new();
            let score = sim.compute(&s, &s);
            prop_assert!((score - 1.0).abs() < 0.001,
                "Identity: {}", score);
        }
        // Distance never exceeds the total char count of both inputs.
        #[test]
        fn levenshtein_non_negative(a in "\\PC{0,30}", b in "\\PC{0,30}") {
            let dist = levenshtein_distance(&a, &b);
            prop_assert!(dist <= a.chars().count() + b.chars().count());
        }
        // Jaro-Winkler stays in [0, 1] for non-empty inputs.
        #[test]
        fn jaro_winkler_bounded(a in "\\PC{1,30}", b in "\\PC{1,30}") {
            let sim = jaro_winkler_similarity(&a, &b);
            prop_assert!((0.0..=1.0).contains(&sim),
                "Jaro-Winkler bounds: {}", sim);
        }
    }
    #[test]
    fn test_acronym_who() {
        // Argument order must not matter.
        assert!(is_acronym_match("WHO", "World Health Organization"));
        assert!(is_acronym_match("World Health Organization", "WHO"));
    }
    #[test]
    fn test_acronym_mrsa() {
        // Hyphens count as word separators when collecting initials.
        assert!(is_acronym_match(
            "MRSA",
            "Methicillin-resistant Staphylococcus aureus"
        ));
    }
    #[test]
    fn test_acronym_ibm() {
        assert!(is_acronym_match("IBM", "International Business Machines"));
    }
    #[test]
    fn test_acronym_german() {
        // Non-ASCII words still contribute their (ASCII) initials.
        assert!(is_acronym_match("DDR", "Deutsche Demokratische Republik"));
        assert!(is_acronym_match("EU", "Europäische Union"));
    }
    #[test]
    fn test_acronym_negative() {
        assert!(!is_acronym_match("IBM", "Apple"));
        assert!(!is_acronym_match("WHO", "United Nations"));
        assert!(!is_acronym_match("USA", "Canada"));
    }
    #[test]
    fn test_acronym_too_short() {
        // Candidate acronyms must be at least 2 chars long.
        assert!(!is_acronym_match("A", "Apple"));
    }
    #[test]
    fn test_acronym_not_mostly_uppercase() {
        // Lowercase candidates are rejected by the uppercase-ratio check.
        assert!(!is_acronym_match("who", "World Health Organization"));
    }
    #[test]
    fn test_no_synonyms_returns_none() {
        let source = NoSynonyms;
        assert!(source.lookup("test").is_none());
        assert!(source.are_synonyms("a", "b").is_none());
    }
    #[test]
    fn test_chained_synonyms_empty() {
        // An empty chain behaves like NoSynonyms.
        let chain = ChainedSynonyms::new();
        assert!(chain.lookup("test").is_none());
    }
    #[test]
    fn test_cross_lingual_same_script() {
        let sim = cross_lingual_similarity("Moscow", "Moskva");
        assert!((0.0..=1.0).contains(&sim), "Similarity should be in [0, 1]");
        let sim_cjk = cross_lingual_similarity("北京", "北京");
        assert!(
            (sim_cjk - 1.0).abs() < 0.01,
            "Identical strings should have similarity 1.0"
        );
    }
    #[test]
    fn test_cross_lingual_known_pairs() {
        // These hit the curated transliteration table.
        let sim = cross_lingual_similarity("Moscow", "Москва");
        assert!(sim > 0.8, "Moscow ↔ Москва should match");
        let sim = cross_lingual_similarity("Tokyo", "東京");
        assert!(sim > 0.8, "Tokyo ↔ 東京 should match");
        let sim = cross_lingual_similarity("Beijing", "北京");
        assert!(sim > 0.8, "Beijing ↔ 北京 should match");
    }
    #[test]
    fn test_cross_lingual_with_parentheses() {
        // "Name (Native)" layouts are split into variants before comparison.
        let sim = cross_lingual_similarity("Moscow (Москва)", "Москва");
        assert!(
            sim > 0.6,
            "Should extract variant from parentheses, got {}",
            sim
        );
        let sim = cross_lingual_similarity("東京 (Tokyo)", "Tokyo");
        assert!(
            sim > 0.5,
            "Should handle CJK with Latin transliteration, got {}",
            sim
        );
        let sim = cross_lingual_similarity("Tokyo", "東京 (Tokyo)");
        assert!(sim > 0.5, "Should work in reverse direction, got {}", sim);
    }
    #[test]
    fn test_cross_lingual_different_scripts() {
        // No table/pattern hit: only the [0, 1] bound is guaranteed.
        let sim = cross_lingual_similarity("Paris", "パリ");
        assert!((0.0..=1.0).contains(&sim));
    }
}