levenshtein_diff/
distance.rs

1use std::cmp::{max, min};
2
3use crate::util::*;
4
/// Returns the Levenshtein distance between source and target using naive recursion.
///
/// **Using this function is ill-advised because of its terrible performance
/// characteristics.**
///
/// Every pair of differing trailing items spawns three recursive calls, so this
/// implementation has a time complexity of O(3^n).
///
/// # Arguments
///
/// * `source` - The source sequence
/// * `target` - The target sequence
///
/// # Examples
///
/// ```
/// use levenshtein_diff as levenshtein;
///
/// let s1 = "SATURDAY";
/// let s2 = "SUNDAY";
/// let expected_leven = 3;
///
/// let leven_naive = levenshtein::levenshtein_naive(s1.as_bytes(), s2.as_bytes());
/// assert_eq!(leven_naive, expected_leven);
/// ```
pub fn levenshtein_naive<T: PartialEq>(source: &[T], target: &[T]) -> usize {
    match (source.split_last(), target.split_last()) {
        (Some((source_last, source_rest)), Some((target_last, target_rest))) => {
            if source_last == target_last {
                // Identical trailing items cost nothing; only the prefixes matter
                levenshtein_naive(source_rest, target_rest)
            } else {
                // The trailing items differ, so try every edit and keep the cheapest one
                let delete = levenshtein_naive(source_rest, target);
                let insert = levenshtein_naive(source, target_rest);
                let substitute = levenshtein_naive(source_rest, target_rest);

                // Whichever edit wins, it costs exactly one operation
                1 + min(min(delete, insert), substitute)
            }
        }
        // Base case: at least one sequence is empty, so the distance is the length of the
        // other one (every remaining item has to be inserted or deleted)
        _ => max(source.len(), target.len()),
    }
}
48
49/// Returns the Levenshtein distance and the distance matrix between source and target using
50/// dynamic programming with tabulation.
51///
52/// This implementation has a time complexity of O(n^2) and a space complexity of O(n^2).
53///
54/// # Arguments
55///
56/// * `source` - The source sequence
57/// * `target` - The target sequence
58///
59/// # Examples
60///
61/// ```
62/// use levenshtein_diff as levenshtein;
63///
64/// let s1 = "SATURDAY";
65/// let s2 = "SUNDAY";
66/// let expected_leven = 3;
67
68/// let (leven_naive, _) = levenshtein::levenshtein_tabulation(s1.as_bytes(), s2.as_bytes());
69/// assert_eq!(leven_naive, expected_leven);
70/// ```
71pub fn levenshtein_tabulation<T: PartialEq>(source: &[T], target: &[T]) -> (usize, DistanceMatrix) {
72    let m = source.len();
73    let n = target.len();
74
75    // table of distances
76    let mut distances = get_distance_table(m, n);
77
78    for i in 1..distances.len() {
79        for j in 1..distances[0].len() {
80            if source[i - 1] == target[j - 1] {
81                // The item being looked at is the same, so the distance won't increase
82                distances[i][j] = distances[i - 1][j - 1];
83                continue;
84            }
85
86            let delete = distances[i - 1][j] + 1;
87            let insert = distances[i][j - 1] + 1;
88            let substitute = distances[i - 1][j - 1] + 1;
89
90            distances[i][j] = min(min(delete, insert), substitute);
91        }
92    }
93
94    (distances[m][n], distances)
95}
96
97/// Returns the Levenshtein distance and the distance matrix between source and target using
98/// dynamic programming with memoization.
99///
100/// This implementation has a time complexity of O(n^2) and a space complexity of O(n^2).
101///
102/// # Arguments
103///
104/// * `source` - The source sequence
105/// * `target` - The target sequence
106///
107/// # Examples
108///
109/// ```
110/// use levenshtein_diff as levenshtein;
111///
112/// let s1 = "SATURDAY";
113/// let s2 = "SUNDAY";
114/// let expected_leven = 3;
115
116/// let (leven_naive, _) = levenshtein::levenshtein_memoization(s1.as_bytes(), s2.as_bytes());
117/// assert_eq!(leven_naive, expected_leven);
118/// ```
119pub fn levenshtein_memoization<T: PartialEq>(
120    source: &[T],
121    target: &[T],
122) -> (usize, DistanceMatrix) {
123    fn levenshtein_memoization_helper<T: PartialEq>(
124        source: &[T],
125        target: &[T],
126        distances: &mut DistanceMatrix,
127    ) -> usize {
128        // check the cache first
129        if distances[source.len()][target.len()] < usize::MAX {
130            return distances[source.len()][target.len()];
131        }
132
133        // base case
134        if source.is_empty() || target.is_empty() {
135            return max(source.len(), target.len());
136        }
137
138        // couldn't find the value, time to recursively calculate it
139
140        let k = if source.last() == target.last() { 0 } else { 1 };
141
142        let delete = levenshtein_memoization_helper(up_to_last(source), target, distances) + 1;
143        let insert = levenshtein_memoization_helper(source, up_to_last(target), distances) + 1;
144        let substitute =
145            levenshtein_memoization_helper(up_to_last(source), up_to_last(target), distances) + k;
146
147        let distance = min(min(delete, insert), substitute);
148
149        // update the cache
150        distances[source.len()][target.len()] = distance;
151
152        distance
153    }
154
155    let mut distances = get_distance_table(source.len(), target.len());
156
157    let distance = levenshtein_memoization_helper(source, target, &mut distances);
158
159    (distance, distances)
160}
161
#[cfg(test)]
mod tests {
    use crate::distance::*;

    // Shared fixture: every implementation must report the same distance for this pair
    const SOURCE: &str = "LAWN";
    const TARGET: &str = "FFLAWANN";
    const EXPECTED_DISTANCE: usize = 4;

    /// The naive recursive implementation yields the expected distance.
    #[test]
    fn levenshtein_naive_test() {
        let distance = levenshtein_naive(SOURCE.as_bytes(), TARGET.as_bytes());

        assert_eq!(distance, EXPECTED_DISTANCE);
    }

    /// The memoized implementation yields the expected distance.
    #[test]
    fn levenshtein_memoization_test() {
        let (distance, _) = levenshtein_memoization(SOURCE.as_bytes(), TARGET.as_bytes());

        assert_eq!(distance, EXPECTED_DISTANCE);
    }

    /// The tabulated implementation yields the expected distance.
    #[test]
    fn levenshtein_tabulation_test() {
        let (distance, _) = levenshtein_tabulation(SOURCE.as_bytes(), TARGET.as_bytes());

        assert_eq!(distance, EXPECTED_DISTANCE);
    }
}