use std::collections::BTreeMap;
/// Computes the Levenshtein edit distance (insertions, deletions,
/// substitutions) between `a` and `b`.
///
/// NOTE: the comparison is over UTF-8 *bytes*, so a multi-byte character
/// counts as several edit units; for ASCII input this matches the
/// character-level distance.
///
/// Uses a single rolling row plus a diagonal carry, so memory is
/// O(min(len(a), len(b))) and time is O(len(a) * len(b)).
pub fn levenshtein(a: &str, b: &str) -> usize {
    let (mut long, mut short) = (a.as_bytes(), b.as_bytes());
    // Keep the shorter sequence as the row dimension to minimize memory.
    if long.len() < short.len() {
        std::mem::swap(&mut long, &mut short);
    }
    if short.is_empty() {
        // One side empty: distance is simply the other side's length
        // (covers the both-empty case too, returning 0).
        return long.len();
    }
    // row[j] holds D[i][j]: distance between long[..i] and short[..j].
    let mut row: Vec<usize> = (0..=short.len()).collect();
    for (i, &lc) in long.iter().enumerate() {
        // diag carries D[i][j-1] from the previous row before overwrite.
        let mut diag = row[0];
        row[0] = i + 1;
        for (j, &sc) in short.iter().enumerate() {
            let substitute = diag + usize::from(lc != sc);
            diag = row[j + 1];
            row[j + 1] = substitute.min(diag + 1).min(row[j] + 1);
        }
    }
    row[short.len()]
}
/// Normalized similarity in [0, 1] derived from [`levenshtein`]:
/// `1 - distance / max(byte length)`. Two empty strings score 1.0.
pub fn levenshtein_similarity(a: &str, b: &str) -> f64 {
    match a.len().max(b.len()) {
        // Both strings empty: identical by definition.
        0 => 1.0,
        longest => 1.0 - levenshtein(a, b) as f64 / longest as f64,
    }
}
/// Jaccard similarity of the character n-gram sets of `a` and `b`:
/// `|A ∩ B| / |A ∪ B|`.
///
/// Returns 0.0 when `n == 0`, when either string is empty, or when
/// neither string is long enough to produce a single n-gram.
pub fn jaccard_ngram_similarity(a: &str, b: &str, n: usize) -> f64 {
    if n == 0 || a.is_empty() || b.is_empty() {
        return 0.0;
    }
    let set_a = char_ngram_set(a, n);
    let set_b = char_ngram_set(b, n);
    let intersection = set_a.intersection(&set_b).count();
    // Inclusion-exclusion: |A ∪ B| = |A| + |B| - |A ∩ B|. Avoids the
    // previous clone-and-extend of an entire set just to count the union.
    let union = set_a.len() + set_b.len() - intersection;
    if union == 0 {
        0.0
    } else {
        intersection as f64 / union as f64
    }
}
/// Returns the set of distinct character n-grams of `s`.
///
/// Empty when `n == 0` (previously `chars.windows(0)` would panic,
/// since `slice::windows` rejects a zero window size) or when `s` has
/// fewer than `n` characters.
fn char_ngram_set(s: &str, n: usize) -> std::collections::BTreeSet<String> {
    let chars: Vec<char> = s.chars().collect();
    let mut set = std::collections::BTreeSet::new();
    // `n > 0` guards against the windows(0) panic.
    if n > 0 && chars.len() >= n {
        for window in chars.windows(n) {
            set.insert(window.iter().collect());
        }
    }
    set
}
/// Counts each character n-gram of `s` (windows of `n` `char`s).
///
/// Returns an empty map when `n == 0` (previously this panicked, since
/// `slice::windows` rejects a zero window size) or when `s` has fewer
/// than `n` characters.
pub fn char_ngrams(s: &str, n: usize) -> BTreeMap<String, usize> {
    let mut counts = BTreeMap::new();
    if n == 0 {
        // Guard: a zero-width gram is meaningless and windows(0) panics.
        return counts;
    }
    let chars: Vec<char> = s.chars().collect();
    if chars.len() >= n {
        for window in chars.windows(n) {
            let gram: String = window.iter().collect();
            *counts.entry(gram).or_insert(0) += 1;
        }
    }
    counts
}
/// Counts each n-gram of whitespace-separated words in `s`, with the
/// words of a gram joined by a single space.
///
/// Returns an empty map when `n == 0` (previously this panicked, since
/// `slice::windows` rejects a zero window size) or when `s` has fewer
/// than `n` words.
pub fn word_ngrams(s: &str, n: usize) -> BTreeMap<String, usize> {
    let mut counts = BTreeMap::new();
    if n == 0 {
        // Guard: windows(0) panics; define the zero-gram count as empty.
        return counts;
    }
    let words: Vec<&str> = s.split_whitespace().collect();
    if words.len() >= n {
        for window in words.windows(n) {
            *counts.entry(window.join(" ")).or_insert(0) += 1;
        }
    }
    counts
}
pub fn tokenize_whitespace(s: &str) -> Vec<(usize, usize)> {
let bytes = s.as_bytes();
let mut spans = Vec::new();
let mut i = 0;
while i < bytes.len() {
while i < bytes.len() && bytes[i].is_ascii_whitespace() {
i += 1;
}
if i >= bytes.len() {
break;
}
let start = i;
while i < bytes.len() && !bytes[i].is_ascii_whitespace() {
i += 1;
}
spans.push((start, i));
}
spans
}
/// Tokenizes `s` on whitespace, then peels leading and trailing ASCII
/// punctuation off each chunk as individual one-character tokens; the
/// remaining middle (which may contain interior punctuation, e.g.
/// apostrophes) is emitted as a single token.
pub fn tokenize_words(s: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for chunk in s.split_whitespace() {
        let chars: Vec<char> = chunk.chars().collect();
        // First index that is NOT punctuation (== len for all-punct chunks).
        let head = chars
            .iter()
            .position(|c| !c.is_ascii_punctuation())
            .unwrap_or(chars.len());
        // One past the LAST non-punctuation index; falls back to `head`
        // so an all-punctuation chunk yields an empty body slice.
        let tail = chars
            .iter()
            .rposition(|c| !c.is_ascii_punctuation())
            .map_or(head, |p| p + 1);
        // Leading punctuation, one token per character, in order.
        tokens.extend(chars[..head].iter().map(|c| c.to_string()));
        if tail > head {
            tokens.push(chars[head..tail].iter().collect());
        }
        // Trailing punctuation, one token per character, in order.
        tokens.extend(chars[tail..].iter().map(|c| c.to_string()));
    }
    tokens
}
/// Returns a copy of `s` with ASCII uppercase letters ('A'..='Z')
/// lowered; every other character, including non-ASCII, is unchanged.
pub fn ascii_lowercase(s: &str) -> String {
    // Identical behavior to the former manual `c as u8 + 32` byte trick,
    // delegated to the standard library.
    s.to_ascii_lowercase()
}
/// Returns `s` with every ASCII punctuation character removed; all
/// other characters are kept in order.
pub fn strip_punctuation(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        if !c.is_ascii_punctuation() {
            out.push(c);
        }
    }
    out
}
/// Relative frequency of each ASCII-lowercased, whitespace-separated
/// word in `s` (counts divided by total word count). Returns an empty
/// map when `s` contains no words.
pub fn term_frequency(s: &str) -> BTreeMap<String, f64> {
    let mut counts: BTreeMap<String, usize> = BTreeMap::new();
    let mut total = 0usize;
    // Single pass: count occurrences and the total simultaneously.
    for word in s.split_whitespace() {
        *counts.entry(ascii_lowercase(word)).or_insert(0) += 1;
        total += 1;
    }
    if total == 0 {
        return BTreeMap::new();
    }
    let denom = total as f64;
    counts
        .into_iter()
        .map(|(word, count)| (word, count as f64 / denom))
        .collect()
}
/// Cosine similarity between two sparse vectors keyed by term.
/// Returns 0.0 when either vector has zero norm (avoids NaN from 0/0).
///
/// Summation follows BTreeMap key order, matching the previous
/// accumulation order exactly (bit-identical floating-point results).
pub fn cosine_similarity(a: &BTreeMap<String, f64>, b: &BTreeMap<String, f64>) -> f64 {
    // Dot product over the keys present in both maps.
    let dot: f64 = a
        .iter()
        .filter_map(|(key, va)| b.get(key).map(|vb| va * vb))
        .sum();
    let norm_a: f64 = a.values().map(|v| v * v).sum::<f64>().sqrt();
    let norm_b: f64 = b.values().map(|v| v * v).sum::<f64>().sqrt();
    let denom = norm_a * norm_b;
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
// Unit tests. All inputs are ASCII, so the byte-based Levenshtein
// functions behave identically to a character-based implementation here.
#[cfg(test)]
mod tests {
    use super::*;

    // Zero distance for identical strings.
    #[test]
    fn test_levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    // A single insertion costs 1.
    #[test]
    fn test_levenshtein_insert() {
        assert_eq!(levenshtein("abc", "abcd"), 1);
    }

    // A single deletion costs 1.
    #[test]
    fn test_levenshtein_delete() {
        assert_eq!(levenshtein("abcd", "abc"), 1);
    }

    // A single substitution costs 1.
    #[test]
    fn test_levenshtein_substitute() {
        assert_eq!(levenshtein("abc", "axc"), 1);
    }

    // Distance to/from the empty string equals the other string's length.
    #[test]
    fn test_levenshtein_empty() {
        assert_eq!(levenshtein("", "hello"), 5);
        assert_eq!(levenshtein("hello", ""), 5);
        assert_eq!(levenshtein("", ""), 0);
    }

    // Classic textbook example: 2 substitutions + 1 insertion = 3.
    #[test]
    fn test_levenshtein_kitten_sitting() {
        assert_eq!(levenshtein("kitten", "sitting"), 3);
    }

    // Similarity is 1.0 for identical strings and 0.0 when every
    // position differs (distance == max length).
    #[test]
    fn test_levenshtein_similarity() {
        let sim = levenshtein_similarity("hello", "hello");
        assert!((sim - 1.0).abs() < 1e-10);
        let sim2 = levenshtein_similarity("abc", "xyz");
        assert!((sim2 - 0.0).abs() < 1e-10);
    }

    // Identical strings share all bigrams: Jaccard = 1.0.
    #[test]
    fn test_jaccard_identical() {
        let sim = jaccard_ngram_similarity("hello", "hello", 2);
        assert!((sim - 1.0).abs() < 1e-10);
    }

    // No shared bigrams: Jaccard = 0.0.
    #[test]
    fn test_jaccard_disjoint() {
        let sim = jaccard_ngram_similarity("abc", "xyz", 2);
        assert!((sim - 0.0).abs() < 1e-10);
    }

    // "hello" has exactly 4 bigrams, each occurring once.
    #[test]
    fn test_char_ngrams() {
        let grams = char_ngrams("hello", 2);
        assert_eq!(grams["he"], 1);
        assert_eq!(grams["el"], 1);
        assert_eq!(grams["ll"], 1);
        assert_eq!(grams["lo"], 1);
        assert_eq!(grams.len(), 4);
    }

    // Word bigrams are space-joined adjacent word pairs.
    #[test]
    fn test_word_ngrams() {
        let grams = word_ngrams("the quick brown fox", 2);
        assert_eq!(grams["the quick"], 1);
        assert_eq!(grams["quick brown"], 1);
        assert_eq!(grams["brown fox"], 1);
        assert_eq!(grams.len(), 3);
    }

    // Spans are half-open byte offsets into the original string;
    // leading/trailing whitespace produces no empty spans.
    #[test]
    fn test_tokenize_whitespace() {
        let spans = tokenize_whitespace("  hello  world  ");
        assert_eq!(spans, vec![(2, 7), (10, 15)]);
    }

    // Leading/trailing punctuation splits off as one-char tokens.
    #[test]
    fn test_tokenize_words() {
        let tokens = tokenize_words("Hello, world! (test)");
        assert_eq!(tokens, vec!["Hello", ",", "world", "!", "(", "test", ")"]);
    }

    #[test]
    fn test_ascii_lowercase() {
        assert_eq!(ascii_lowercase("Hello WORLD"), "hello world");
    }

    #[test]
    fn test_strip_punctuation() {
        assert_eq!(strip_punctuation("hello, world!"), "hello world");
    }

    // Frequencies are counts over the total word count (6 words here).
    #[test]
    fn test_term_frequency() {
        let tf = term_frequency("the cat sat on the mat");
        assert!((tf["the"] - 2.0 / 6.0).abs() < 1e-10);
        assert!((tf["cat"] - 1.0 / 6.0).abs() < 1e-10);
    }

    // A vector compared with itself has cosine similarity 1.0.
    #[test]
    fn test_cosine_similarity_identical() {
        let tf = term_frequency("hello world");
        let sim = cosine_similarity(&tf, &tf);
        assert!((sim - 1.0).abs() < 1e-10);
    }

    // Disjoint vocabularies produce orthogonal vectors: similarity 0.0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = term_frequency("cat dog");
        let b = term_frequency("fish bird");
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 0.0).abs() < 1e-10);
    }

    // Repeated runs must give identical results (no hidden state,
    // no hash-order dependence — BTree collections iterate sorted).
    #[test]
    fn test_determinism() {
        for _ in 0..10 {
            assert_eq!(levenshtein("kitten", "sitting"), 3);
            let grams = char_ngrams("deterministic", 3);
            assert_eq!(grams.len(), 11);
        }
    }
}