assay_core/errors/similarity.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
4pub struct ClosestMatch {
5    pub prompt: String,
6    pub similarity: f64,
7}
8
9pub fn closest_prompt<'a>(
10    needle: &str,
11    hay: impl Iterator<Item = &'a String>,
12) -> Option<ClosestMatch> {
13    let mut best: Option<ClosestMatch> = None;
14
15    // Threshold for suggestion. 0.55 is a reasonable heuristic.
16    const THRESHOLD: f64 = 0.55;
17
18    for candidate in hay {
19        let sim = strsim::normalized_levenshtein(needle, candidate);
20        if sim >= THRESHOLD && best.as_ref().is_none_or(|b| sim > b.similarity) {
21            best = Some(ClosestMatch {
22                prompt: candidate.clone(),
23                similarity: sim,
24            });
25        }
26    }
27    best
28}
29
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_closest_prompt_exact() {
        let candidates = ["foo".to_string(), "bar".to_string()];
        let hit = closest_prompt("foo", candidates.iter()).unwrap();
        assert_eq!(hit.prompt, "foo");
        assert!((hit.similarity - 1.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_closest_prompt_typo() {
        let candidates = ["capitol".to_string(), "bar".to_string()];
        let hit = closest_prompt("capital", candidates.iter()).unwrap();
        assert_eq!(hit.prompt, "capitol");
        assert!(hit.similarity > 0.8);
    }

    #[test]
    fn test_closest_prompt_none() {
        let candidates = ["zulu".to_string(), "bar".to_string()];
        let hit = closest_prompt("alpha", candidates.iter());
        assert!(hit.is_none());
    }

    #[test]
    fn test_closest_prompt_empty_haystack() {
        let candidates: Vec<String> = Vec::new();
        assert!(closest_prompt("foo", candidates.iter()).is_none());
    }

    #[test]
    fn test_closest_prompt_prefers_first_on_tie() {
        // "cat" and "bat" are each one substitution away from "hat",
        // so they score identically; the earlier candidate must win.
        let candidates = ["cat".to_string(), "bat".to_string()];
        let hit = closest_prompt("hat", candidates.iter()).unwrap();
        assert_eq!(hit.prompt, "cat");
    }

    #[test]
    fn test_closest_prompt_later_better_match_wins() {
        // "capitol" passes the threshold first, but the exact match that
        // follows it must replace it.
        let candidates = ["capitol".to_string(), "capital".to_string()];
        let hit = closest_prompt("capital", candidates.iter()).unwrap();
        assert_eq!(hit.prompt, "capital");
        assert!((hit.similarity - 1.0).abs() < f64::EPSILON);
    }
}