Skip to main content

cyanea_seq/
pattern.rs

//! Pattern matching algorithms for biological sequences.
//!
//! Provides exact and approximate string matching on `&[u8]` slices, suitable
//! for DNA, RNA, and protein sequences alike.
//!
//! ## Exact matchers
//!
//! - [`horspool`] — Boyer-Moore-Horspool with bad-character shift table, O(n/m) average
//! - [`kmp`] — Knuth-Morris-Pratt with failure function, O(n+m)
//! - [`shift_and`] — Shift-And, bit-parallel (pattern <= 64)
//! - [`bndm`] — Backward Nondeterministic DAWG Matching, bit-parallel (pattern <= 64)
//! - [`bom`] — Backward Oracle Matching with factor oracle
//!
//! ## Approximate matchers
//!
//! - [`myers_bitparallel`] — Myers bit-parallel edit distance (pattern <= 64)
//! - [`ukkonen`] — Ukkonen cut-off approximate matching via bounded DP
18
/// Boyer-Moore-Horspool exact pattern matching.
///
/// For every byte value, precomputes how far the window may slide when that
/// byte sits under the last pattern position (the "bad-character" rule).
/// Each window is compared right-to-left. Average O(n/m), worst case O(nm).
///
/// Returns the starting offsets of every exact occurrence, in increasing
/// order. An empty pattern, or one longer than the text, yields no matches.
pub fn horspool(text: &[u8], pattern: &[u8]) -> Vec<usize> {
    let m = pattern.len();
    if m == 0 || m > text.len() {
        return vec![];
    }
    let last = m - 1;

    // skip[c] = distance from the rightmost `c` in pattern[..last] to the end
    // of the pattern; bytes that never appear there permit a full jump of m.
    let mut skip = [m; 256];
    for (idx, &byte) in pattern[..last].iter().enumerate() {
        skip[usize::from(byte)] = last - idx;
    }

    let mut hits = Vec::new();
    let final_start = text.len() - m;
    let mut start = 0;
    while start <= final_start {
        let window = &text[start..start + m];
        // Compare from the right end of the window, as Horspool prescribes.
        if window.iter().rev().eq(pattern.iter().rev()) {
            hits.push(start);
        }
        // Shift is keyed on the byte aligned with the pattern's last position.
        start += skip[usize::from(window[last])];
    }
    hits
}
53
/// Knuth-Morris-Pratt exact pattern matching.
///
/// Precomputes the border (failure) table of the pattern in O(m), then scans
/// the text once in O(n), never moving backwards in the text. Total time
/// O(n+m), extra space O(m).
///
/// Returns the starting offsets of every exact occurrence, in increasing
/// order. An empty pattern, or one longer than the text, yields no matches.
pub fn kmp(text: &[u8], pattern: &[u8]) -> Vec<usize> {
    // border[i] = length of the longest proper prefix of pattern[..=i] that is
    // also a suffix of pattern[..=i].
    fn border_table(pattern: &[u8]) -> Vec<usize> {
        let mut border = vec![0usize; pattern.len()];
        let mut len = 0usize;
        for (i, &byte) in pattern.iter().enumerate().skip(1) {
            while len > 0 && pattern[len] != byte {
                len = border[len - 1];
            }
            if pattern[len] == byte {
                len += 1;
            }
            border[i] = len;
        }
        border
    }

    let m = pattern.len();
    if m == 0 || m > text.len() {
        return vec![];
    }

    let border = border_table(pattern);
    let mut hits = Vec::new();
    // Number of pattern bytes currently matched ending at the scanned text byte.
    let mut matched = 0usize;
    for (pos, &byte) in text.iter().enumerate() {
        while matched > 0 && pattern[matched] != byte {
            matched = border[matched - 1];
        }
        if pattern[matched] == byte {
            matched += 1;
        }
        if matched == m {
            hits.push(pos + 1 - m);
            // Fall back along the border so overlapping occurrences are found.
            matched = border[m - 1];
        }
    }
    hits
}
97
/// Shift-And bit-parallel exact pattern matching.
///
/// Keeps the states of the pattern's prefix NFA in one `u64`: bit `k` is set
/// when `pattern[..=k]` ends at the current text byte. Each text byte costs a
/// shift, an OR and an AND against a per-byte mask. Limited to patterns of
/// length <= 64.
///
/// Returns the starting offsets of every exact occurrence, in increasing
/// order. Returns an empty vec for an empty pattern, a pattern longer than
/// the text, or a pattern longer than 64 bytes.
pub fn shift_and(text: &[u8], pattern: &[u8]) -> Vec<usize> {
    let m = pattern.len();
    if m == 0 || m > 64 || m > text.len() {
        return vec![];
    }

    // masks[c]: bit k is set iff pattern[k] == c.
    let mut masks = [0u64; 256];
    for (k, &byte) in pattern.iter().enumerate() {
        masks[usize::from(byte)] |= 1u64 << k;
    }

    // The highest state bit means the whole pattern has just been matched.
    let found = 1u64 << (m - 1);
    let mut active = 0u64;
    text.iter()
        .enumerate()
        .filter_map(|(end, &byte)| {
            active = ((active << 1) | 1) & masks[usize::from(byte)];
            if active & found == 0 {
                None
            } else {
                Some(end + 1 - m)
            }
        })
        .collect()
}
130
/// Backward Nondeterministic DAWG Matching (BNDM) exact pattern matching.
///
/// Reads each length-`m` window from right to left while simulating a
/// nondeterministic suffix automaton of the pattern in one `u64`. The state
/// dies as soon as the bytes read are no longer a factor of the pattern, and
/// every time they form a pattern prefix the candidate shift is tightened.
/// Limited to patterns of length <= 64.
///
/// Returns the starting offsets of every exact occurrence, in increasing
/// order. Returns an empty vec for an empty pattern, a pattern longer than
/// the text, or a pattern longer than 64 bytes.
pub fn bndm(text: &[u8], pattern: &[u8]) -> Vec<usize> {
    let m = pattern.len();
    if m == 0 || m > 64 || m > text.len() {
        return vec![];
    }

    // masks[c]: bit k set iff pattern[m - 1 - k] == c. Indexing the pattern
    // from its end lets the automaton consume each window right to left.
    let mut masks = [0u64; 256];
    for (k, &byte) in pattern.iter().rev().enumerate() {
        masks[usize::from(byte)] |= 1u64 << k;
    }
    // Set in the state exactly when the bytes read so far form a pattern prefix.
    let prefix_bit = 1u64 << (m - 1);

    let mut hits = Vec::new();
    let last_start = text.len() - m;
    let mut pos = 0usize;
    while pos <= last_start {
        let window = &text[pos..pos + m];
        let mut advance = m;
        let mut state = u64::MAX;
        for j in (0..m).rev() {
            state &= masks[usize::from(window[j])];
            if state == 0 {
                break;
            }
            if state & prefix_bit != 0 {
                if j == 0 {
                    // The whole window is a prefix of length m: an occurrence.
                    hits.push(pos);
                    break;
                }
                // window[j..] is a pattern prefix, so the next occurrence
                // could start at window offset j.
                advance = j;
            }
            state <<= 1;
        }
        pos += advance;
    }
    hits
}
184
/// Backward Oracle Matching (BOM) exact pattern matching.
///
/// Builds the factor oracle of the reversed pattern, then reads each
/// length-`m` text window from right to left through the oracle. The oracle
/// accepts every factor of the reversed pattern (and possibly a few other
/// strings). So when a transition is missing, the bytes read so far cannot be
/// part of any occurrence, and the window can be shifted past them.
///
/// The transition table is dense: one 256-entry `u32` row per oracle state,
/// i.e. roughly `1 KiB * (m + 1)` of memory.
///
/// Returns the starting offsets of every exact occurrence, in increasing
/// order. An empty pattern, or one longer than the text, yields no matches.
///
/// # Panics
///
/// Panics if the pattern has `u32::MAX` or more bytes. Allocating the
/// transition table would fail long before that.
pub fn bom(text: &[u8], pattern: &[u8]) -> Vec<usize> {
    // Transition value meaning "no edge". This is unambiguous because state 0
    // (the initial state) is never the target of a transition: every edge
    // created below points to some state `i + 1 >= 1`.
    const NO_EDGE: u32 = 0;

    let n = text.len();
    let m = pattern.len();
    if m == 0 || m > n {
        return vec![];
    }

    // Factor oracle of the reversed pattern: states 0..=m, dense transition
    // table `delta[state][byte]`, and supply (suffix) links. `supply[0]` is
    // `None`, playing the role of the textbook "-1" state.
    let mut delta = vec![[NO_EDGE; 256]; m + 1];
    let mut supply: Vec<Option<usize>> = vec![None; m + 1];
    for (i, &byte) in pattern.iter().rev().enumerate() {
        let c = usize::from(byte);
        let target = u32::try_from(i + 1).expect("pattern too long for factor oracle");
        delta[i][c] = target;

        // Walk supply links, adding the missing `c`-edges to the new state,
        // until a state that already has a `c`-edge (or the chain end) is hit.
        let mut link = supply[i];
        while let Some(s) = link {
            if delta[s][c] != NO_EDGE {
                break;
            }
            delta[s][c] = target;
            link = supply[s];
        }
        supply[i + 1] = Some(match link {
            None => 0,
            Some(s) => delta[s][c] as usize,
        });
    }

    let mut results = Vec::new();
    let mut pos = 0usize;
    while pos <= n - m {
        let window = &text[pos..pos + m];
        // Read the window right to left; `unread` counts bytes not yet consumed.
        let mut state = 0usize;
        let mut unread = m;
        while unread > 0 {
            let next = delta[state][usize::from(window[unread - 1])];
            if next == NO_EDGE {
                break;
            }
            state = next as usize;
            unread -= 1;
        }

        if unread == 0 {
            // Whole window recognized: confirm with a direct comparison,
            // since the oracle may accept strings that are not factors.
            if window == pattern {
                results.push(pos);
            }
            pos += 1;
        } else {
            // window[unread - 1..] is not a factor of the pattern, so no
            // occurrence can start at or before window offset `unread - 1`.
            pos += unread;
        }
    }
    results
}
254
/// Myers bit-parallel approximate matching.
///
/// Runs Myers' 1999 bit-vector algorithm in its text-search form: the DP
/// column of edit distances between the pattern and the best text substring
/// ending at each position is encoded as vertical +1/-1 delta bit-vectors,
/// and updated with a constant number of word operations per text byte.
/// Limited to patterns of length <= 64.
///
/// Returns `(end_position, edit_distance)` pairs, in increasing order of
/// `end_position`, for every text position where the best alignment of the
/// whole pattern ending there has Levenshtein distance <= `max_dist`.
/// Returns an empty vec for an empty pattern, an empty text, or a pattern
/// longer than 64 bytes.
pub fn myers_bitparallel(
    text: &[u8],
    pattern: &[u8],
    max_dist: usize,
) -> Vec<(usize, usize)> {
    let m = pattern.len();
    if m == 0 || m > 64 || text.is_empty() {
        return vec![];
    }

    // peq[c]: bit k set iff pattern[k] == c.
    let mut peq = [0u64; 256];
    for (bit, &byte) in pattern.iter().enumerate() {
        peq[usize::from(byte)] |= 1u64 << bit;
    }

    // Low m bits set; for m == 64 the shift is 0 and every bit is kept.
    let mask = u64::MAX >> (64 - m);
    let high = 1u64 << (m - 1);
    // Vertical deltas start at +1 everywhere: column 0 is 0, 1, ..., m.
    let (mut pv, mut mv) = (mask, 0u64);
    let mut dist = m;
    let mut hits = Vec::new();

    for (i, &byte) in text.iter().enumerate() {
        let eq = peq[usize::from(byte)];
        let xv = eq | mv;
        let xh = ((eq & pv).wrapping_add(pv) ^ pv) | eq;

        // Horizontal deltas, clipped to m bits so complemented high bits
        // cannot leak into the score test or the next column.
        let ph = (mv | !(xh | pv)) & mask;
        let mh = pv & xh & mask;

        // Bit m-1 carries the delta of the last DP row, i.e. the score.
        if ph & high != 0 {
            dist += 1;
        }
        if mh & high != 0 {
            dist -= 1;
        }

        // Shifting in a 0 keeps row 0 at distance 0: matches may start anywhere.
        let ph_shifted = ph << 1;
        let mh_shifted = mh << 1;
        pv = (mh_shifted | !(xv | ph_shifted)) & mask;
        mv = ph_shifted & xv & mask;

        if dist <= max_dist {
            hits.push((i, dist));
        }
    }
    hits
}
321
/// Ukkonen cut-off approximate matching via bounded dynamic programming.
///
/// Evaluates the semi-global edit-distance DP one text column at a time. The
/// whole pattern must be aligned; the alignment may start and end anywhere in
/// the text.
///
/// The cut-off: each column evaluates only rows up to one past the previous
/// column's *last active row*, the deepest row whose value is <= `max_dist`.
/// Deeper cells are guaranteed to exceed `max_dist`, because DP values never
/// decrease along a diagonal. Those cells keep stale values, but every one of
/// them is > `max_dist`, which is all the recurrence needs to stay exact for
/// every value <= `max_dist`.
///
/// On random text the active region typically stays around `max_dist` rows,
/// giving roughly O(n * max_dist) work. The worst case is still O(n * m).
/// Space is O(m).
///
/// Returns `(end_position, edit_distance)` pairs, in increasing order of
/// `end_position`, for every text position where a semi-global alignment of
/// the pattern ends with edit distance <= `max_dist`. Returns an empty vec if
/// either input is empty.
pub fn ukkonen(
    text: &[u8],
    pattern: &[u8],
    max_dist: usize,
) -> Vec<(usize, usize)> {
    let m = pattern.len();
    if m == 0 || text.is_empty() {
        return vec![];
    }

    // col[j] = edit distance of pattern[..j] against the best text substring
    // ending at the current column. Before any text is read, col[j] = j.
    // col[0] stays 0 forever: the alignment may start anywhere in the text.
    let mut col: Vec<usize> = (0..=m).collect();
    // Deepest row whose value is <= max_dist.
    let mut last_active = max_dist.min(m);
    let mut results = Vec::new();

    for (i, &t) in text.iter().enumerate() {
        // A row can only become active if the row above it was active in the
        // previous column (diagonal monotonicity), so stop one row below.
        let limit = (last_active + 1).min(m);
        // Previous column's value for row j - 1 (row 0 is always 0).
        let mut diag = col[0];
        for j in 1..=limit {
            let cost = usize::from(pattern[j - 1] != t);
            // col[j] still holds the previous column (left neighbour);
            // col[j - 1] already holds the current column (upper neighbour).
            let value = (diag + cost).min(col[j] + 1).min(col[j - 1] + 1);
            diag = col[j];
            col[j] = value;
        }

        // Shrink back to the deepest row still within budget; row 0 (= 0)
        // guarantees termination.
        last_active = limit;
        while col[last_active] > max_dist {
            last_active -= 1;
        }

        // The full pattern is aligned within budget at this end position.
        if last_active == m {
            results.push((i, col[m]));
        }
    }
    results
}
376
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------------
    // Helpers
    // -----------------------------------------------------------------------

    /// Runs every exact matcher on the same input and asserts each returns
    /// `expected`. Bit-parallel matchers are checked only for patterns of at
    /// most 64 bytes, their documented limit.
    fn assert_exact(text: &[u8], pattern: &[u8], expected: &[usize]) {
        assert_eq!(horspool(text, pattern), expected, "horspool");
        assert_eq!(kmp(text, pattern), expected, "kmp");
        if pattern.len() <= 64 {
            assert_eq!(shift_and(text, pattern), expected, "shift_and");
            assert_eq!(bndm(text, pattern), expected, "bndm");
        }
        assert_eq!(bom(text, pattern), expected, "bom");
    }

    /// Ground truth for exact matching: compare every window directly.
    fn naive_exact(text: &[u8], pattern: &[u8]) -> Vec<usize> {
        if pattern.is_empty() || pattern.len() > text.len() {
            return vec![];
        }
        text.windows(pattern.len())
            .enumerate()
            .filter(|&(_, window)| window == pattern)
            .map(|(start, _)| start)
            .collect()
    }

    /// Ground truth for approximate matching: the full O(nm) semi-global
    /// edit-distance DP (free text prefix/suffix, whole pattern aligned).
    fn naive_approx(text: &[u8], pattern: &[u8], max_dist: usize) -> Vec<(usize, usize)> {
        let m = pattern.len();
        if m == 0 || text.is_empty() {
            return vec![];
        }
        let mut prev: Vec<usize> = (0..=m).collect();
        let mut hits = Vec::new();
        for (end, &t) in text.iter().enumerate() {
            let mut curr = vec![0usize; m + 1];
            for j in 1..=m {
                let cost = usize::from(pattern[j - 1] != t);
                curr[j] = (prev[j - 1] + cost).min(prev[j] + 1).min(curr[j - 1] + 1);
            }
            if curr[m] <= max_dist {
                hits.push((end, curr[m]));
            }
            prev = curr;
        }
        hits
    }

    /// Deterministic pseudo-random sequence over `alphabet` (xorshift64), so
    /// randomized tests are reproducible without an RNG dependency.
    fn pseudo_random_seq(len: usize, alphabet: &[u8], seed: u64) -> Vec<u8> {
        // xorshift never leaves the all-zero state, so avoid seeding with 0.
        let mut state = seed.max(1);
        (0..len)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 7;
                state ^= state << 17;
                alphabet[(state % alphabet.len() as u64) as usize]
            })
            .collect()
    }

    // -----------------------------------------------------------------------
    // Exact matchers: hand-picked cases
    // -----------------------------------------------------------------------
    #[test]
    fn exact_match_at_start() {
        assert_exact(b"ACGTACGT", b"ACGT", &[0, 4]);
    }

    #[test]
    fn exact_match_at_end() {
        assert_exact(b"TTTTACGT", b"ACGT", &[4]);
    }

    #[test]
    fn exact_match_in_middle() {
        assert_exact(b"TTACGTTT", b"ACGT", &[2]);
    }

    #[test]
    fn multiple_occurrences() {
        assert_exact(b"AAAAAA", b"AA", &[0, 1, 2, 3, 4]);
    }

    #[test]
    fn no_match() {
        assert_exact(b"ACGTACGT", b"TTTT", &[]);
    }

    #[test]
    fn empty_pattern() {
        assert_exact(b"ACGT", b"", &[]);
    }

    #[test]
    fn empty_text() {
        assert_exact(b"", b"ACGT", &[]);
    }

    #[test]
    fn pattern_longer_than_text() {
        assert_exact(b"AC", b"ACGT", &[]);
    }

    #[test]
    fn single_char_pattern() {
        assert_exact(b"AACAA", b"C", &[2]);
    }

    #[test]
    fn single_char_text_match() {
        assert_exact(b"A", b"A", &[0]);
    }

    #[test]
    fn single_char_text_no_match() {
        assert_exact(b"A", b"C", &[]);
    }

    #[test]
    fn full_text_match() {
        assert_exact(b"ACGT", b"ACGT", &[0]);
    }

    #[test]
    fn protein_sequence() {
        assert_exact(b"MKAILFVLV", b"AILF", &[2]);
    }

    #[test]
    fn overlapping_pattern() {
        assert_exact(b"ABABAB", b"ABAB", &[0, 2]);
    }

    // -----------------------------------------------------------------------
    // Exact matchers: agreement with a naive reference
    // -----------------------------------------------------------------------
    #[test]
    fn exact_matchers_agree_with_naive_on_random_dna() {
        let text = pseudo_random_seq(2000, b"ACGT", 0x9E37_79B9_7F4A_7C15);
        for &len in &[1usize, 2, 3, 5, 8, 13, 31, 64, 65, 100] {
            for &start in &[0usize, 17, 500, 1234, 2000 - len] {
                let pattern = &text[start..start + len];
                assert_exact(&text, pattern, &naive_exact(&text, pattern));
            }
        }
    }

    #[test]
    fn exact_matchers_agree_with_naive_on_periodic_input() {
        // Highly repetitive binary text stresses the shift logic of the
        // backward matchers (BNDM prefix tracking, BOM oracle construction).
        let text = pseudo_random_seq(500, b"ab", 7);
        for len in 1..=12 {
            for start in (0..text.len() - len).step_by(37) {
                let pattern = &text[start..start + len];
                assert_exact(&text, pattern, &naive_exact(&text, pattern));
            }
        }

        let fibonacci: &[u8] = b"abaababaabaababaababaabaababaabaab";
        let patterns: [&[u8]; 6] = [b"aba", b"abaab", b"baba", b"aabaa", b"abaababaab", b"bb"];
        for &pattern in patterns.iter() {
            assert_exact(fibonacci, pattern, &naive_exact(fibonacci, pattern));
        }
    }

    // -----------------------------------------------------------------------
    // Bit-parallel length limit
    // -----------------------------------------------------------------------
    #[test]
    fn shift_and_pattern_too_long() {
        let text = vec![b'A'; 128];
        let pattern = vec![b'A'; 65];
        assert!(shift_and(&text, &pattern).is_empty());
    }

    #[test]
    fn bndm_pattern_too_long() {
        let text = vec![b'A'; 128];
        let pattern = vec![b'A'; 65];
        assert!(bndm(&text, &pattern).is_empty());
    }

    #[test]
    fn shift_and_pattern_exactly_64() {
        let text = vec![b'A'; 128];
        let pattern = vec![b'A'; 64];
        let result = shift_and(&text, &pattern);
        assert_eq!(result.len(), 65); // 128 - 64 + 1
    }

    #[test]
    fn bndm_pattern_exactly_64() {
        let text = vec![b'A'; 128];
        let pattern = vec![b'A'; 64];
        let result = bndm(&text, &pattern);
        assert_eq!(result.len(), 65);
    }

    // -----------------------------------------------------------------------
    // Myers bit-parallel approximate matching
    // -----------------------------------------------------------------------
    #[test]
    fn myers_exact_match() {
        // End positions of the two exact matches.
        assert_eq!(myers_bitparallel(b"ACGTACGT", b"ACGT", 0), vec![(3, 0), (7, 0)]);
    }

    #[test]
    fn myers_single_substitution() {
        // Pattern ACGT vs text AXGT: one substitution at position 1.
        assert_eq!(myers_bitparallel(b"AXGT", b"ACGT", 1), vec![(3, 1)]);
    }

    #[test]
    fn myers_single_insertion() {
        // Pattern ACG vs text ACXG (end 4): one inserted X. Ends 2 ("AC") and
        // 3 ("ACX") also reach distance 1; no end position reaches 0.
        assert_eq!(
            myers_bitparallel(b"TACXGT", b"ACG", 1),
            vec![(2, 1), (3, 1), (4, 1)]
        );
    }

    #[test]
    fn myers_single_deletion() {
        // Pattern ACGT vs text AGT: the C is deleted.
        assert_eq!(myers_bitparallel(b"AGT", b"ACGT", 1), vec![(2, 1)]);
    }

    #[test]
    fn myers_no_match_within_distance() {
        let hits = myers_bitparallel(b"AAAA", b"CCCC", 1);
        assert!(hits.is_empty());
    }

    #[test]
    fn myers_empty_pattern() {
        assert!(myers_bitparallel(b"ACGT", b"", 2).is_empty());
    }

    #[test]
    fn myers_empty_text() {
        assert!(myers_bitparallel(b"", b"ACGT", 2).is_empty());
    }

    #[test]
    fn myers_pattern_too_long() {
        let text = vec![b'A'; 128];
        let pattern = vec![b'A'; 65];
        assert!(myers_bitparallel(&text, &pattern, 2).is_empty());
    }

    // -----------------------------------------------------------------------
    // Ukkonen approximate matching
    // -----------------------------------------------------------------------
    #[test]
    fn ukkonen_exact_match() {
        assert_eq!(ukkonen(b"ACGTACGT", b"ACGT", 0), vec![(3, 0), (7, 0)]);
    }

    #[test]
    fn ukkonen_single_substitution() {
        assert_eq!(ukkonen(b"AXGT", b"ACGT", 1), vec![(3, 1)]);
    }

    #[test]
    fn ukkonen_single_insertion() {
        assert_eq!(ukkonen(b"TACXGT", b"ACG", 1), vec![(2, 1), (3, 1), (4, 1)]);
    }

    #[test]
    fn ukkonen_single_deletion() {
        assert_eq!(ukkonen(b"AGT", b"ACGT", 1), vec![(2, 1)]);
    }

    #[test]
    fn ukkonen_no_match_within_distance() {
        let hits = ukkonen(b"AAAA", b"CCCC", 1);
        assert!(hits.is_empty());
    }

    #[test]
    fn ukkonen_empty_pattern() {
        assert!(ukkonen(b"ACGT", b"", 2).is_empty());
    }

    #[test]
    fn ukkonen_empty_text() {
        assert!(ukkonen(b"", b"ACGT", 2).is_empty());
    }

    // -----------------------------------------------------------------------
    // Cross-validation: approximate matchers agree on distance-0 positions
    // -----------------------------------------------------------------------
    #[test]
    fn approx_matchers_agree_on_exact() {
        let text = b"GATTACACGTACGTTTG";
        let pattern = b"ACGT";
        let myers = myers_bitparallel(text, pattern, 0);
        let ukk = ukkonen(text, pattern, 0);
        let myers_ends: Vec<usize> = myers.iter().map(|&(e, _)| e).collect();
        let ukk_ends: Vec<usize> = ukk.iter().map(|&(e, _)| e).collect();
        assert_eq!(myers_ends, ukk_ends);
    }

    // -----------------------------------------------------------------------
    // Cross-validation: approximate matchers agree with the full DP
    // -----------------------------------------------------------------------
    #[test]
    fn approx_matchers_agree_with_naive_dp() {
        let text = pseudo_random_seq(300, b"ACGT", 42);
        for &(start, len) in &[(0usize, 4usize), (10, 8), (50, 12), (100, 20), (200, 64)] {
            let exact = text[start..start + len].to_vec();
            let mut mutated = exact.clone();
            mutated[len / 2] = b'N';
            for pattern in [exact, mutated].iter() {
                for max_dist in 0..=5 {
                    let expected = naive_approx(&text, pattern, max_dist);
                    assert_eq!(myers_bitparallel(&text, pattern, max_dist), expected, "myers");
                    assert_eq!(ukkonen(&text, pattern, max_dist), expected, "ukkonen");
                }
            }
        }
    }

    // -----------------------------------------------------------------------
    // Cross-validation: exact matchers agree with approx at distance 0
    // -----------------------------------------------------------------------
    #[test]
    fn exact_vs_approx_agreement() {
        let text = b"AACGTAACGTAA";
        let pattern = b"ACGT";
        let exact = kmp(text, pattern);
        let approx = myers_bitparallel(text, pattern, 0);
        // Convert end positions to start positions.
        let approx_starts: Vec<usize> = approx
            .iter()
            .map(|&(e, _)| e + 1 - pattern.len())
            .collect();
        assert_eq!(exact, approx_starts);
    }
}