1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
use std::cmp;
use std::collections::HashMap;
use std::collections::HashSet;
use std::hash::Hash;

/// Dynamic-programming table for longest-common-subsequence ("LCS") queries between two slices.
///
/// Finding longest common subsequences between two sequences of lengths *n* and *m* requires
/// building an *(n + 1) x (m + 1)* table of prefix LCS lengths. That table is expensive to build
/// and several different queries can be answered from it, so `LcsTable` builds it once and keeps
/// it together with the two input slices.
#[derive(Debug)]
pub struct LcsTable<'a, T: 'a> {
    /// `lengths[i][j]` is the length of a longest common subsequence of `a[..i]` and `b[..j]`.
    /// Row 0 and column 0 are all zeros (an empty prefix has no common elements).
    lengths: Vec<Vec<i64>>,

    /// First input sequence; element `a[i - 1]` corresponds to row `i` of `lengths`.
    a: &'a [T],
    /// Second input sequence; element `b[j - 1]` corresponds to column `j` of `lengths`.
    b: &'a [T],
}

impl<'a, T> LcsTable<'a, T>
where
    T: Eq,
{
    /// Constructs an `LcsTable` for matching between two sequences `a` and `b`.
    ///
    /// Takes *O(n * m)* time and space, where *n* and *m* are the lengths of `a` and `b`.
    pub fn new(a: &'a [T], b: &'a [T]) -> LcsTable<'a, T> {
        let mut lengths = vec![vec![0; b.len() + 1]; a.len() + 1];

        for (i, x) in a.iter().enumerate() {
            for (j, y) in b.iter().enumerate() {
                lengths[i + 1][j + 1] = if x == y {
                    lengths[i][j] + 1
                } else {
                    cmp::max(lengths[i + 1][j], lengths[i][j + 1])
                };
            }
        }

        LcsTable { lengths, a, b }
    }

    /// Gets the longest common subsequence between `a` and `b`.
    ///
    /// If several longest common subsequences exist, one of them is returned. The backtrack is
    /// iterative, so long inputs cannot overflow the stack.
    ///
    /// Example:
    ///
    /// ```
    /// use lcs::LcsTable;
    ///
    /// let a: Vec<_> = "a--b---c".chars().collect();
    /// let b: Vec<_> = "abc".chars().collect();
    ///
    /// let table = LcsTable::new(&a, &b);
    /// let lcs = table.longest_common_subsequence();
    ///
    /// assert_eq!(vec![&'a', &'b', &'c'], lcs);
    /// ```
    pub fn longest_common_subsequence(&self) -> Vec<&T> {
        self.find_lcs(self.a.len(), self.b.len())
    }

    /// Walks back from cell `(i, j)` and returns one LCS of `a[..i]` and `b[..j]`.
    ///
    /// On a mismatch it steps left (`j - 1`) only when that cell is strictly longer, and
    /// otherwise steps up (`i - 1`).
    fn find_lcs(&self, mut i: usize, mut j: usize) -> Vec<&T> {
        // Elements are discovered from the end backwards; reversed once at the end.
        let mut reversed = Vec::new();

        // Row 0 and column 0 of `lengths` are zero, so this also stops at empty prefixes and
        // guarantees `i > 0 && j > 0` inside the loop.
        while self.lengths[i][j] > 0 {
            if self.a[i - 1] == self.b[j - 1] {
                reversed.push(&self.a[i - 1]);
                i -= 1;
                j -= 1;
            } else if self.lengths[i][j - 1] > self.lengths[i - 1][j] {
                j -= 1;
            } else {
                i -= 1;
            }
        }

        reversed.reverse();
        reversed
    }

    /// Gets all longest common subsequences between `a` and `b`.
    ///
    /// Intermediate results are memoized per table cell. The number of distinct longest common
    /// subsequences can still be very large for some inputs, and so can the returned set.
    ///
    /// Example:
    ///
    /// ```
    /// use lcs::LcsTable;
    ///
    /// let a: Vec<_> = "aaabbb-cccddd".chars().collect();
    /// let b: Vec<_> = "cdab".chars().collect();
    ///
    /// let table = LcsTable::new(&a, &b);
    /// let lcses = table.longest_common_subsequences();
    ///
    /// assert_eq!(2, lcses.len());
    /// assert!(lcses.contains(&vec![&'a', &'b']));
    /// assert!(lcses.contains(&vec![&'c', &'d']));
    /// ```
    pub fn longest_common_subsequences(&self) -> HashSet<Vec<&T>>
    where
        T: Hash,
    {
        self.find_all_lcs(self.a.len(), self.b.len())
    }

    /// Returns the cells the all-LCS backtrack moves to from `(i, j)`.
    ///
    /// The result is empty when `lengths[i][j] == 0`, since no element is shared there. On a
    /// match it is the diagonal cell. Otherwise it is every neighbour (left and/or up) whose
    /// length equals the larger of the two.
    fn backtrack_steps(&self, i: usize, j: usize) -> Vec<(usize, usize)> {
        if self.lengths[i][j] == 0 {
            return Vec::new();
        }
        // `lengths[i][j] > 0` implies `i > 0 && j > 0` (row/column 0 are zero).
        if self.a[i - 1] == self.b[j - 1] {
            return vec![(i - 1, j - 1)];
        }

        let mut steps = Vec::with_capacity(2);
        if self.lengths[i][j - 1] >= self.lengths[i - 1][j] {
            steps.push((i, j - 1));
        }
        if self.lengths[i - 1][j] >= self.lengths[i][j - 1] {
            steps.push((i - 1, j));
        }
        steps
    }

    /// Collects every LCS of `a[..i]` and `b[..j]`.
    ///
    /// Each cell reachable from `(i, j)` is solved exactly once and cached in `memo`.
    /// An explicit work stack replaces recursion so deep tables cannot overflow the stack.
    fn find_all_lcs(&self, i: usize, j: usize) -> HashSet<Vec<&T>>
    where
        T: Hash,
    {
        let mut memo: HashMap<(usize, usize), HashSet<Vec<&T>>> = HashMap::new();

        // `ready == false`: successors still need pushing. `ready == true`: successors were
        // pushed above this entry, so they are all in `memo` by the time it is popped again.
        let mut pending = vec![(i, j, false)];

        while let Some((x, y, ready)) = pending.pop() {
            if memo.contains_key(&(x, y)) {
                // The same cell can be pushed by several predecessors; solve it once.
                continue;
            }

            let steps = self.backtrack_steps(x, y);

            if !ready {
                pending.push((x, y, true));
                pending.extend(steps.into_iter().map(|(sx, sy)| (sx, sy, false)));
                continue;
            }

            let sequences: HashSet<Vec<&T>> = if self.lengths[x][y] == 0 {
                // No common elements in these prefixes: only the empty subsequence.
                std::iter::once(Vec::new()).collect()
            } else if self.a[x - 1] == self.b[y - 1] {
                let element = &self.a[x - 1];
                memo[&(x - 1, y - 1)]
                    .iter()
                    .map(|prefix| {
                        let mut sequence = prefix.clone();
                        sequence.push(element);
                        sequence
                    })
                    .collect()
            } else {
                let mut merged = HashSet::new();
                for step in &steps {
                    merged.extend(memo[step].iter().cloned());
                }
                merged
            };

            memo.insert((x, y), sequences);
        }

        memo.remove(&(i, j))
            .expect("the starting cell is solved before the work stack empties")
    }
}

#[test]
fn test_lcs_table() {
    // Worked example from
    // https://en.wikipedia.org/wiki/Longest_common_subsequence_problem#Worked_example
    let gac: Vec<char> = "gac".chars().collect();
    let agcat: Vec<char> = "agcat".chars().collect();

    // One row per prefix of "gac" (including the empty prefix), one column per prefix of "agcat".
    let expected = vec![
        vec![0, 0, 0, 0, 0, 0],
        vec![0, 0, 1, 1, 1, 1],
        vec![0, 1, 1, 1, 2, 2],
        vec![0, 1, 1, 2, 2, 2],
    ];
    let table = LcsTable::new(&gac, &agcat);

    assert_eq!(expected, table.lengths);
}

#[test]
fn test_lcs_lcs() {
    // Only 'a', 'b' and 'c' appear in both inputs, and in the same relative order.
    let xs: Vec<char> = "XXXaXXXbXXXc".chars().collect();
    let ys: Vec<char> = "YYaYYbYYc".chars().collect();
    let table = LcsTable::new(&xs, &ys);

    let expected = vec![&'a', &'b', &'c'];
    assert_eq!(expected, table.longest_common_subsequence());
}

#[test]
fn test_longest_common_subsequences() {
    let first: Vec<char> = "gac".chars().collect();
    let second: Vec<char> = "agcat".chars().collect();
    let table = LcsTable::new(&first, &second);

    // Exactly these three length-2 subsequences are common to both inputs.
    let expected: HashSet<Vec<&char>> = vec![
        vec![&'a', &'c'],
        vec![&'g', &'a'],
        vec![&'g', &'c'],
    ]
    .into_iter()
    .collect();

    assert_eq!(expected, table.longest_common_subsequences());
}