textdistance/algorithms/
damerau_levenshtein.rs

1//! Damerau-Levenshtein distance
2#![cfg(feature = "std")]
3use crate::{Algorithm, Result};
4use alloc::vec;
5use alloc::vec::Vec;
6use core::hash::Hash;
7use std::collections::HashMap;
8
/// [Damerau-Levenshtein distance] is an edit distance between two sequences.
///
/// It is an improved version of [Levenshtein](crate::Levenshtein) that also includes
/// transpositions.
///
/// It is the minimum number of operations (consisting of insertions, deletions or
/// substitutions of a single character, or transposition of two adjacent characters)
/// required to change one text into the other.
///
/// Every operation has a configurable cost; by default each one costs 1.
///
/// [Damerau-Levenshtein distance]: https://en.wikipedia.org/wiki/Damerau%E2%80%93Levenshtein_distance
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DamerauLevenshtein {
    /// Selects which variant of the distance is computed.
    ///
    /// If `false` (default), the unrestricted distance is computed: elements may be
    /// edited again after they took part in a transposition.
    ///
    /// If `true`, the restricted variant (also known as optimal string alignment) is
    /// computed: only directly adjacent elements can be swapped and the swapped pair
    /// cannot be edited any further.
    ///
    /// Both variants support adjacent transpositions.
    pub restricted: bool,

    /// The cost of removing a character.
    pub del_cost: usize,

    /// The cost of adding a new character.
    pub ins_cost: usize,

    /// The cost of replacing a character with another one.
    pub sub_cost: usize,

    /// The cost of swapping two adjacent characters.
    pub trans_cost: usize,
}
35
36impl Default for DamerauLevenshtein {
37    fn default() -> Self {
38        Self {
39            restricted: false,
40            del_cost: 1,
41            ins_cost: 1,
42            sub_cost: 1,
43            trans_cost: 1,
44        }
45    }
46}
47
48impl DamerauLevenshtein {
49    fn get_unrestricted<E: Eq + Hash>(&self, s1: &[E], s2: &[E]) -> Result<usize> {
50        let l1 = s1.len();
51        let l2 = s2.len();
52        let max_dist = l2 + l1;
53
54        let mut mat: Vec<Vec<usize>> = vec![vec![0; l2 + 2]; l1 + 2];
55        mat[0][0] = max_dist;
56        for i in 0..=l1 {
57            mat[i + 1][0] = max_dist;
58            mat[i + 1][1] = i;
59        }
60        for i in 0..=l2 {
61            mat[0][i + 1] = max_dist;
62            mat[1][i + 1] = i;
63        }
64
65        let mut char_map: HashMap<&E, usize> = HashMap::new();
66        for (i1, c1) in s1.iter().enumerate() {
67            let mut db = 0;
68            let i1 = i1 + 1;
69
70            for (i2, c2) in s2.iter().enumerate() {
71                let i2 = i2 + 1;
72                let last = *char_map.get(&c2).unwrap_or(&0);
73
74                let sub_cost = if c1 == c2 { 0 } else { self.sub_cost };
75                mat[i1 + 1][i2 + 1] = min4(
76                    mat[i1][i2] + sub_cost,                                    // substitution
77                    mat[i1 + 1][i2] + self.del_cost,                           // deletion
78                    mat[i1][i2 + 1] + self.ins_cost,                           // insertion
79                    mat[last][db] + i1 + i2 - 2 + self.trans_cost - last - db, // transposition
80                );
81
82                if c1 == c2 {
83                    db = i2;
84                }
85            }
86
87            char_map.insert(c1, i1);
88        }
89
90        Result {
91            is_distance: true,
92            abs: mat[l1 + 1][l2 + 1],
93            max: l1.max(l2),
94            len1: l1,
95            len2: l2,
96        }
97    }
98
99    #[allow(clippy::needless_range_loop)]
100    fn get_restricted<E: Eq + Hash>(&self, s1: &[E], s2: &[E]) -> Result<usize> {
101        let l1 = s1.len();
102        let l2 = s2.len();
103
104        let mut mat: Vec<Vec<usize>> = vec![vec![0; l2 + 2]; l1 + 2];
105        for i in 0..=l1 {
106            mat[i][0] = i;
107        }
108        for i in 0..=l2 {
109            mat[0][i] = i;
110        }
111
112        for (i1, c1) in s1.iter().enumerate() {
113            for (i2, c2) in s2.iter().enumerate() {
114                let sub_cost = if c1 == c2 { 0 } else { self.sub_cost };
115                mat[i1 + 1][i2 + 1] = min3(
116                    mat[i1][i2 + 1] + self.del_cost, // deletion
117                    mat[i1 + 1][i2] + self.ins_cost, // insertion
118                    mat[i1][i2] + sub_cost,          // substitution
119                );
120
121                // transposition
122                if i1 == 0 || i2 == 0 {
123                    continue;
124                };
125                if c1 != &s2[i2 - 1] {
126                    continue;
127                };
128                if &s1[i1 - 1] != c2 {
129                    continue;
130                };
131                let trans_cost = if c1 == c2 { 0 } else { self.trans_cost };
132                mat[i1 + 1][i2 + 1] = mat[i1 + 1][i2 + 1].min(mat[i1 - 1][i2 - 1] + trans_cost);
133            }
134        }
135
136        Result {
137            is_distance: true,
138            abs: mat[l1][l2],
139            max: l1.max(l2),
140            len1: l1,
141            len2: l2,
142        }
143    }
144}
145
146impl Algorithm<usize> for DamerauLevenshtein {
147    fn for_vec<E: Eq + Hash>(&self, s1: &[E], s2: &[E]) -> Result<usize> {
148        if self.restricted {
149            self.get_restricted(s1, s2)
150        } else {
151            self.get_unrestricted(s1, s2)
152        }
153    }
154}
155
/// Returns the smallest of the four given values.
fn min4(a: usize, b: usize, c: usize, d: usize) -> usize {
    [b, c, d].iter().fold(a, |smallest, &next| smallest.min(next))
}
159
/// Returns the smallest of the three given values.
fn min3(a: usize, b: usize, c: usize) -> usize {
    [b, c].iter().fold(a, |smallest, &next| smallest.min(next))
}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::str::{damerau_levenshtein, damerau_levenshtein_restricted};
    use assert2::assert;
    use proptest::prelude::*;
    use rstest::rstest;

    /// Inputs for which both variants are expected to return the same distance.
    #[rstest]
    #[case("", "", 0)]
    #[case("", "\0", 1)]
    #[case("", "abc", 3)]
    #[case("abc", "", 3)]
    #[case("hannah", "hannha", 1)]
    #[case("FOO", "BOR", 2)]
    #[case("BAR", "BOR", 1)]
    #[case("hansi", "hasni", 1)]
    #[case("zzaabbio", "zzababoi", 2)]
    #[case("zzaabb", "zzabab", 1)]
    #[case("abcdef", "badcfe", 3)]
    #[case("klmb", "klm", 1)]
    #[case("klm", "klmb", 1)]
    #[case("test", "text", 1)]
    #[case("test", "tset", 1)]
    #[case("test", "qwy", 4)]
    #[case("test", "testit", 2)]
    #[case("test", "tesst", 1)]
    #[case("test", "tet", 1)]
    #[case("cat", "hat", 1)]
    #[case("Niall", "Neil", 3)]
    #[case("aluminum", "Catalan", 7)]
    #[case("ATCG", "TAGC", 2)]
    #[case("ab", "ba", 1)]
    #[case("ab", "cde", 3)]
    #[case("ab", "ac", 1)]
    #[case("ab", "bc", 2)]
    fn function_str(#[case] s1: &str, #[case] s2: &str, #[case] exp: usize) {
        let unrestricted = damerau_levenshtein(s1, s2);
        let restricted = damerau_levenshtein_restricted(s1, s2);
        assert!(unrestricted == restricted);
        assert!(unrestricted == exp);
    }

    /// Inputs on which the restricted variant needs more edits than the unrestricted one.
    #[test]
    fn restricted() {
        let alg = DamerauLevenshtein {
            restricted: true,
            ..DamerauLevenshtein::default()
        };
        assert!(alg.for_str("ab", "bca").val() == 3);
        assert!(alg.for_str("abcd", "bdac").val() == 4);
    }

    /// The same inputs as in `restricted`, measured with the default (unrestricted) variant.
    #[test]
    fn unrestricted() {
        let alg = DamerauLevenshtein::default();
        assert!(alg.for_str("ab", "bca").val() == 2);
        assert!(alg.for_str("abcd", "bdac").val() == 3);
    }

    proptest! {
        /// The unrestricted distance is symmetric and bounded by the longer input.
        #[test]
        fn prop_default(s1 in ".*", s2 in ".*") {
            let forward = damerau_levenshtein(&s1, &s2);
            let backward = damerau_levenshtein(&s2, &s1);
            prop_assert_eq!(forward, backward);
            prop_assert!(forward <= s1.len().max(s2.len()));
        }

        /// The restricted distance is symmetric and bounded by the longer input.
        #[test]
        fn prop_restricted(s1 in ".*", s2 in ".*") {
            let forward = damerau_levenshtein_restricted(&s1, &s2);
            let backward = damerau_levenshtein_restricted(&s2, &s1);
            prop_assert_eq!(forward, backward);
            prop_assert!(forward <= s1.len().max(s2.len()));
        }
    }
}