1use std::collections::BTreeMap;
7
/// Computes the Levenshtein edit distance between `a` and `b`, compared
/// byte-by-byte (multi-byte UTF-8 characters count as multiple units).
///
/// Classic two-row dynamic-programming formulation; rows are sized by the
/// shorter input so extra memory stays O(min(len(a), len(b))).
pub fn levenshtein(a: &str, b: &str) -> usize {
    let (x, y) = (a.as_bytes(), b.as_bytes());

    // Distance to an empty string is just the other string's length.
    if x.is_empty() {
        return y.len();
    }
    if y.is_empty() {
        return x.len();
    }

    // Arrange so `cols` is the shorter slice; row storage is cols.len() + 1.
    let (cols, rows) = if x.len() <= y.len() { (x, y) } else { (y, x) };

    let mut previous: Vec<usize> = (0..=cols.len()).collect();
    let mut current: Vec<usize> = vec![0; cols.len() + 1];

    for (i, &row_byte) in rows.iter().enumerate() {
        // First column: i+1 deletions to reach the empty prefix.
        current[0] = i + 1;
        for (j, &col_byte) in cols.iter().enumerate() {
            let substitution = previous[j] + usize::from(row_byte != col_byte);
            let deletion = previous[j + 1] + 1;
            let insertion = current[j] + 1;
            current[j + 1] = substitution.min(deletion).min(insertion);
        }
        std::mem::swap(&mut previous, &mut current);
    }

    // After the final swap the answer lives in `previous`.
    previous[cols.len()]
}
50
51pub fn levenshtein_similarity(a: &str, b: &str) -> f64 {
56 let max_len = a.len().max(b.len());
57 if max_len == 0 {
58 return 1.0;
59 }
60 let dist = levenshtein(a, b);
61 1.0 - (dist as f64) / (max_len as f64)
62}
63
64pub fn jaccard_ngram_similarity(a: &str, b: &str, n: usize) -> f64 {
71 if n == 0 || a.is_empty() || b.is_empty() {
72 return 0.0;
73 }
74
75 let set_a = char_ngram_set(a, n);
76 let set_b = char_ngram_set(b, n);
77
78 let intersection = set_a.iter().filter(|g| set_b.contains(*g)).count();
79 let union = {
80 let mut all = set_a.clone();
81 all.extend(set_b.iter().cloned());
82 all.len()
83 };
84
85 if union == 0 {
86 0.0
87 } else {
88 intersection as f64 / union as f64
89 }
90}
91
/// Collects the set of distinct character n-grams of `s`.
///
/// Returns an empty set when `n` is 0 — `slice::windows` panics on a window
/// size of 0, and the old `chars.len() >= n` guard was always true for n=0 —
/// or when `s` has fewer than `n` characters.
fn char_ngram_set(s: &str, n: usize) -> std::collections::BTreeSet<String> {
    let mut set = std::collections::BTreeSet::new();
    if n == 0 {
        return set;
    }
    let chars: Vec<char> = s.chars().collect();
    if chars.len() >= n {
        for window in chars.windows(n) {
            set.insert(window.iter().collect());
        }
    }
    set
}
102
/// Counts the occurrences of each character n-gram in `s`.
///
/// Returns an empty map when `n` is 0 — guarding the `windows(0)` panic the
/// previous version hit, since `chars.len() >= 0` is always true — or when
/// `s` has fewer than `n` characters.
pub fn char_ngrams(s: &str, n: usize) -> BTreeMap<String, usize> {
    let mut counts = BTreeMap::new();
    if n == 0 {
        return counts;
    }
    let chars: Vec<char> = s.chars().collect();
    if chars.len() >= n {
        for window in chars.windows(n) {
            let gram: String = window.iter().collect();
            *counts.entry(gram).or_insert(0) += 1;
        }
    }
    counts
}
119
/// Counts word n-grams of `s`, where words are whitespace-separated tokens
/// and each n-gram is the words joined by a single space.
///
/// Returns an empty map when `n` is 0 — guarding the `windows(0)` panic the
/// previous version hit — or when there are fewer than `n` words.
pub fn word_ngrams(s: &str, n: usize) -> BTreeMap<String, usize> {
    let mut counts = BTreeMap::new();
    if n == 0 {
        return counts;
    }
    let words: Vec<&str> = s.split_whitespace().collect();
    if words.len() >= n {
        for window in words.windows(n) {
            let gram = window.join(" ");
            *counts.entry(gram).or_insert(0) += 1;
        }
    }
    counts
}
134
/// Returns the byte-offset spans `(start, end)` of each maximal run of
/// non-whitespace bytes in `s`. Whitespace is ASCII whitespace only and
/// `end` is exclusive.
pub fn tokenize_whitespace(s: &str) -> Vec<(usize, usize)> {
    let mut spans = Vec::new();
    // Start offset of the token currently being scanned, if any.
    let mut token_start: Option<usize> = None;

    for (pos, byte) in s.bytes().enumerate() {
        if byte.is_ascii_whitespace() {
            // Whitespace closes any open token.
            if let Some(start) = token_start.take() {
                spans.push((start, pos));
            }
        } else if token_start.is_none() {
            token_start = Some(pos);
        }
    }
    // A token running to the end of the string has no trailing whitespace
    // to close it.
    if let Some(start) = token_start {
        spans.push((start, s.len()));
    }
    spans
}
159
/// Splits `s` on whitespace, then peels leading and trailing ASCII
/// punctuation off each chunk into their own single-character tokens,
/// keeping the remaining core (which may contain interior punctuation)
/// as one token.
pub fn tokenize_words(s: &str) -> Vec<String> {
    let mut out = Vec::new();
    for piece in s.split_whitespace() {
        let cs: Vec<char> = piece.chars().collect();
        let total = cs.len();

        // Number of leading punctuation characters.
        let head = cs.iter().take_while(|c| c.is_ascii_punctuation()).count();
        // Number of trailing punctuation characters, never overlapping the
        // leading run (counted only over the remainder after `head`).
        let tail = cs[head..]
            .iter()
            .rev()
            .take_while(|c| c.is_ascii_punctuation())
            .count();
        let body_end = total - tail;

        // Leading punctuation, one token per character.
        out.extend(cs[..head].iter().map(|c| c.to_string()));
        // The non-empty core between the punctuation runs, if any.
        if body_end > head {
            out.push(cs[head..body_end].iter().collect());
        }
        // Trailing punctuation, one token per character.
        out.extend(cs[body_end..].iter().map(|c| c.to_string()));
    }
    out
}
199
/// Lowercases ASCII letters only, leaving every other character untouched.
pub fn ascii_lowercase(s: &str) -> String {
    // `str::to_ascii_lowercase` performs exactly the A-Z → a-z mapping the
    // previous manual `(c as u8 + 32) as char` arithmetic did, without the
    // error-prone casts.
    s.to_ascii_lowercase()
}
210
/// Removes every ASCII punctuation character from `s`; all other
/// characters are kept in order.
pub fn strip_punctuation(s: &str) -> String {
    let mut kept = String::with_capacity(s.len());
    for ch in s.chars() {
        if !ch.is_ascii_punctuation() {
            kept.push(ch);
        }
    }
    kept
}
215
/// Relative frequency of each whitespace-separated token in `s`, with
/// tokens ASCII-lowercased first.
///
/// Returns an empty map when `s` contains no tokens; otherwise each value
/// is count / total_tokens, so the values sum to 1.0.
pub fn term_frequency(s: &str) -> BTreeMap<String, f64> {
    // Single counting pass; `to_ascii_lowercase` applies the same
    // ASCII-only lowering used elsewhere in this module.
    let mut counts: BTreeMap<String, usize> = BTreeMap::new();
    let mut total = 0usize;
    for token in s.split_whitespace() {
        *counts.entry(token.to_ascii_lowercase()).or_insert(0) += 1;
        total += 1;
    }
    // With no tokens, `counts` is empty and this produces an empty map
    // without ever dividing by zero.
    counts
        .into_iter()
        .map(|(word, count)| (word, count as f64 / total as f64))
        .collect()
}
237
/// Cosine similarity between two sparse vectors keyed by term.
///
/// Keys absent from one map are treated as zero components. Returns 0.0
/// when either vector has zero magnitude.
pub fn cosine_similarity(a: &BTreeMap<String, f64>, b: &BTreeMap<String, f64>) -> f64 {
    // Dot product over the keys present in both maps.
    let dot: f64 = a
        .iter()
        .filter_map(|(key, va)| b.get(key).map(|vb| va * vb))
        .sum();

    let norm_a: f64 = a.values().map(|v| v * v).sum();
    let norm_b: f64 = b.values().map(|v| v * v).sum();

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
266
#[cfg(test)]
mod tests {
    use super::*;

    // Identical strings are at distance 0.
    #[test]
    fn test_levenshtein_identical() {
        assert_eq!(levenshtein("hello", "hello"), 0);
    }

    // A single insertion costs 1.
    #[test]
    fn test_levenshtein_insert() {
        assert_eq!(levenshtein("abc", "abcd"), 1);
    }

    // A single deletion costs 1.
    #[test]
    fn test_levenshtein_delete() {
        assert_eq!(levenshtein("abcd", "abc"), 1);
    }

    // A single substitution costs 1.
    #[test]
    fn test_levenshtein_substitute() {
        assert_eq!(levenshtein("abc", "axc"), 1);
    }

    // Distance to/from the empty string equals the other string's length.
    #[test]
    fn test_levenshtein_empty() {
        assert_eq!(levenshtein("", "hello"), 5);
        assert_eq!(levenshtein("hello", ""), 5);
        assert_eq!(levenshtein("", ""), 0);
    }

    // The canonical textbook example: 2 substitutions + 1 insertion.
    #[test]
    fn test_levenshtein_kitten_sitting() {
        assert_eq!(levenshtein("kitten", "sitting"), 3);
    }

    // Similarity is 1.0 for identical strings and 0.0 when every position
    // differs (distance equals the longer length).
    #[test]
    fn test_levenshtein_similarity() {
        let sim = levenshtein_similarity("hello", "hello");
        assert!((sim - 1.0).abs() < 1e-10);
        let sim2 = levenshtein_similarity("abc", "xyz");
        assert!((sim2 - 0.0).abs() < 1e-10);
    }

    // Identical strings share every bigram: Jaccard = 1.0.
    #[test]
    fn test_jaccard_identical() {
        let sim = jaccard_ngram_similarity("hello", "hello", 2);
        assert!((sim - 1.0).abs() < 1e-10);
    }

    // No shared bigrams: Jaccard = 0.0.
    #[test]
    fn test_jaccard_disjoint() {
        let sim = jaccard_ngram_similarity("abc", "xyz", 2);
        assert!((sim - 0.0).abs() < 1e-10);
    }

    // "hello" has exactly the four bigrams he/el/ll/lo, each appearing once.
    #[test]
    fn test_char_ngrams() {
        let grams = char_ngrams("hello", 2);
        assert_eq!(grams["he"], 1);
        assert_eq!(grams["el"], 1);
        assert_eq!(grams["ll"], 1);
        assert_eq!(grams["lo"], 1);
        assert_eq!(grams.len(), 4);
    }

    // Four words yield three space-joined word bigrams.
    #[test]
    fn test_word_ngrams() {
        let grams = word_ngrams("the quick brown fox", 2);
        assert_eq!(grams["the quick"], 1);
        assert_eq!(grams["quick brown"], 1);
        assert_eq!(grams["brown fox"], 1);
        assert_eq!(grams.len(), 3);
    }

    // Spans are byte offsets with exclusive ends; surrounding whitespace
    // is skipped.
    #[test]
    fn test_tokenize_whitespace() {
        let spans = tokenize_whitespace("  hello   world  ");
        assert_eq!(spans, vec![(2, 7), (10, 15)]);
    }

    // Leading/trailing punctuation is split into single-character tokens.
    #[test]
    fn test_tokenize_words() {
        let tokens = tokenize_words("Hello, world! (test)");
        assert_eq!(tokens, vec!["Hello", ",", "world", "!", "(", "test", ")"]);
    }

    // Only ASCII uppercase letters are lowered.
    #[test]
    fn test_ascii_lowercase() {
        assert_eq!(ascii_lowercase("Hello WORLD"), "hello world");
    }

    // ASCII punctuation is removed; everything else (including spaces)
    // is kept.
    #[test]
    fn test_strip_punctuation() {
        assert_eq!(strip_punctuation("hello, world!"), "hello world");
    }

    // Frequencies are counts divided by the total token count (6 here).
    #[test]
    fn test_term_frequency() {
        let tf = term_frequency("the cat sat on the mat");
        assert!((tf["the"] - 2.0 / 6.0).abs() < 1e-10);
        assert!((tf["cat"] - 1.0 / 6.0).abs() < 1e-10);
    }

    // A vector compared with itself has cosine similarity 1.0.
    #[test]
    fn test_cosine_similarity_identical() {
        let tf = term_frequency("hello world");
        let sim = cosine_similarity(&tf, &tf);
        assert!((sim - 1.0).abs() < 1e-10);
    }

    // Disjoint vocabularies give a zero dot product.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = term_frequency("cat dog");
        let b = term_frequency("fish bird");
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 0.0).abs() < 1e-10);
    }

    // Repeated runs must produce identical results (no hidden state or
    // iteration-order dependence; BTree* collections are deterministic).
    #[test]
    fn test_determinism() {
        for _ in 0..10 {
            assert_eq!(levenshtein("kitten", "sitting"), 3);
            let grams = char_ngrams("deterministic", 3);
            assert_eq!(grams.len(), 11);
        }
    }
}