assay_core/errors/
similarity.rs1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
4pub struct ClosestMatch {
5 pub prompt: String,
6 pub similarity: f64,
7}
8
9pub fn closest_prompt<'a>(
10 needle: &str,
11 hay: impl Iterator<Item = &'a String>,
12) -> Option<ClosestMatch> {
13 let mut best: Option<ClosestMatch> = None;
14
15 const THRESHOLD: f64 = 0.55;
17
18 for candidate in hay {
19 let sim = strsim::normalized_levenshtein(needle, candidate);
20 if sim >= THRESHOLD && best.as_ref().is_none_or(|b| sim > b.similarity) {
21 best = Some(ClosestMatch {
22 prompt: candidate.clone(),
23 similarity: sim,
24 });
25 }
26 }
27 best
28}
29
#[cfg(test)]
mod tests {
    use super::*;

    /// A candidate identical to the needle is returned with a perfect score.
    #[test]
    fn test_closest_prompt_exact() {
        let pool = ["foo", "bar"].map(String::from);
        let found = closest_prompt("foo", pool.iter()).unwrap();
        assert_eq!(found.prompt, "foo");
        assert!((found.similarity - 1.0).abs() < f64::EPSILON);
    }

    /// A one-character typo still resolves to the intended candidate.
    #[test]
    fn test_closest_prompt_typo() {
        let pool = ["capitol", "bar"].map(String::from);
        let found = closest_prompt("capital", pool.iter()).unwrap();
        assert_eq!(found.prompt, "capitol");
        assert!(found.similarity > 0.8);
    }

    /// When every candidate is too dissimilar, nothing is suggested.
    #[test]
    fn test_closest_prompt_none() {
        let pool = ["zulu", "bar"].map(String::from);
        let found = closest_prompt("alpha", pool.iter());
        assert!(found.is_none());
    }
}