Skip to main content

divsufsort_rs/
utils.rs

1use alloc::format;
2use alloc::string::String;
3use alloc::vec;
4use alloc::vec::Vec;
5
6use crate::DivSufSortError;
7use crate::constants::ALPHABET_SIZE;
8
9/// Error returned by [`sufcheck`] when the suffix array is found to be invalid.
10#[derive(Debug, PartialEq, Eq)]
11pub struct SufCheckError {
12    /// Human-readable description of the first inconsistency found.
13    pub message: String,
14}
15
16impl core::fmt::Display for SufCheckError {
17    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
18        write!(f, "sufcheck: {}", self.message)
19    }
20}
21
22#[cfg(feature = "std")]
23impl std::error::Error for SufCheckError {}
24
25/// Verifies that `sa` is the correct suffix array of `t`.
26///
27/// Checks that all SA entries are in range, that suffixes are lexicographically ordered,
28/// and that the inverse SA mapping is consistent.
29///
30/// If `verbose` is `true`, diagnostic messages are printed to stderr.
31///
32/// # Errors
33///
34/// Returns [`SufCheckError`] describing the first inconsistency found.
35pub fn sufcheck(t: &[u8], sa: &[i32], #[allow(unused)] verbose: bool) -> Result<(), SufCheckError> {
36    let n = t.len();
37
38    if n == 0 {
39        #[cfg(feature = "std")]
40        if verbose {
41            eprintln!("sufcheck: Done.");
42        }
43        return Ok(());
44    }
45
46    if sa.len() != n {
47        let msg = format!("SA length {} != T length {}", sa.len(), n);
48        #[cfg(feature = "std")]
49        if verbose {
50            eprintln!("sufcheck: {msg}");
51        }
52        return Err(SufCheckError { message: msg });
53    }
54
55    for (i, &s) in sa.iter().enumerate().take(n) {
56        if s < 0 || s as usize >= n {
57            let msg = format!("Out of the range [0,{}]. SA[{}]={}", n - 1, i, s);
58            #[cfg(feature = "std")]
59            if verbose {
60                eprintln!("sufcheck: {msg}");
61            }
62            return Err(SufCheckError { message: msg });
63        }
64    }
65
66    for i in 1..n {
67        if t[sa[i - 1] as usize] > t[sa[i] as usize] {
68            let msg = format!(
69                "Suffixes in wrong order. T[SA[{}]={}]={} > T[SA[{}]={}]={}",
70                i - 1,
71                sa[i - 1],
72                t[sa[i - 1] as usize],
73                i,
74                sa[i],
75                t[sa[i] as usize]
76            );
77            #[cfg(feature = "std")]
78            if verbose {
79                eprintln!("sufcheck: {msg}");
80            }
81            return Err(SufCheckError { message: msg });
82        }
83    }
84
85    let mut c_table = [0i32; ALPHABET_SIZE];
86    for &ch in t {
87        c_table[ch as usize] += 1;
88    }
89    let mut p = 0i32;
90    for entry in c_table.iter_mut() {
91        let t_val = *entry;
92        *entry = p;
93        p += t_val;
94    }
95
96    let q = c_table[t[n - 1] as usize];
97    c_table[t[n - 1] as usize] += 1;
98
99    for i in 0..n {
100        let sai = sa[i];
101        let (c, t_val) = if sai > 0 {
102            let c = t[(sai - 1) as usize] as usize;
103            (c, c_table[c])
104        } else {
105            let c = t[n - 1] as usize;
106            (c, q)
107        };
108
109        if t_val < 0 || sa[t_val as usize] != (if sai > 0 { sai - 1 } else { (n - 1) as i32 }) {
110            let msg = format!(
111                "Suffix in wrong position. SA[{}]={} or SA[{}]={}",
112                t_val,
113                if t_val >= 0 { sa[t_val as usize] } else { -1 },
114                i,
115                sa[i]
116            );
117            #[cfg(feature = "std")]
118            if verbose {
119                eprintln!("sufcheck: {msg}");
120            }
121            return Err(SufCheckError { message: msg });
122        }
123
124        if t_val != q {
125            c_table[c] += 1;
126            if c_table[c] as usize >= n || t[sa[c_table[c] as usize] as usize] as usize != c {
127                c_table[c] = -1;
128            }
129        }
130    }
131
132    #[cfg(feature = "std")]
133    if verbose {
134        eprintln!("sufcheck: Done.");
135    }
136    Ok(())
137}
138
/// Lower-bound search over the first `size` elements of `a`.
///
/// Returns the smallest index `k` in `0..=size` with `a[k] >= value`, assuming `a[..size]`
/// is sorted ascending (`size` when every element is smaller). A non-positive `size`
/// yields 0.
fn binarysearch_lower(a: &[i32], size: i32, value: i32) -> i32 {
    let mut lo = 0i32;
    let mut remaining = size;
    while remaining > 0 {
        let step = remaining / 2;
        let mid = lo + step;
        if a[mid as usize] < value {
            // Answer lies strictly right of `mid`.
            lo = mid + 1;
            remaining -= step + 1;
        } else {
            // `mid` is a candidate; keep only the part left of it.
            remaining = step;
        }
    }
    lo
}
152
153/// Computes the Burrows-Wheeler Transform of `t`, writing the result into `u`.
154///
155/// If `sa` is `Some`, the provided suffix array is used directly. If `sa` is `None`,
156/// the transform is computed via [`crate::divbwt`].
157///
158/// On success, `*idx` is set to the primary index (1-based).
159///
160/// # Errors
161///
162/// Returns [`DivSufSortError::InvalidArgument`] if arguments are inconsistent.
163pub fn bw_transform(
164    t: &[u8],
165    u: &mut [u8],
166    sa: Option<&mut [i32]>,
167    idx: &mut i32,
168) -> Result<(), DivSufSortError> {
169    let n = t.len();
170
171    if n == 0 {
172        *idx = 0;
173        return Ok(());
174    }
175    if n == 1 {
176        u[0] = t[0];
177        *idx = 1;
178        return Ok(());
179    }
180
181    match sa {
182        None => {
183            // delegate to divbwt
184            let primary_idx = crate::divbwt(t, u, None)?;
185            *idx = primary_idx;
186            Ok(())
187        }
188        Some(a) => {
189            // T != U case (straightforward implementation)
190            u[0] = t[n - 1];
191            let mut i = 0usize;
192            while a[i] != 0 {
193                u[i + 1] = t[(a[i] - 1) as usize];
194                i += 1;
195            }
196            *idx = (i + 1) as i32;
197            i += 1;
198            while i < n {
199                u[i] = t[(a[i] - 1) as usize];
200                i += 1;
201            }
202            Ok(())
203        }
204    }
205}
206
207/// Inverts the Burrows-Wheeler Transform.
208///
209/// Given the BWT `t` and its primary index `idx`, reconstructs the original string
210/// into `u`. `a` is an optional scratch buffer of length ≥ `t.len()`; if `None`,
211/// an internal allocation is used.
212///
213/// # Errors
214///
215/// Returns [`DivSufSortError::InvalidArgument`] if `idx` is out of range or `a` is
216/// too short.
217pub fn inverse_bw_transform(
218    t: &[u8],
219    u: &mut [u8],
220    a: Option<&mut [i32]>,
221    idx: i32,
222) -> Result<(), DivSufSortError> {
223    let n = t.len();
224
225    if idx < 0 || idx as usize > n || (n > 0 && idx == 0) {
226        return Err(DivSufSortError::InvalidArgument);
227    }
228    if n <= 1 {
229        if n == 1 {
230            u[0] = t[0];
231        }
232        return Ok(());
233    }
234
235    let mut b_buf: Vec<i32>;
236    let b: &mut [i32] = match a {
237        Some(ref_a) => {
238            if ref_a.len() < n {
239                return Err(DivSufSortError::InvalidArgument);
240            }
241            ref_a
242        }
243        None => {
244            b_buf = vec![0i32; n];
245            &mut b_buf
246        }
247    };
248
249    let mut c_table = [0i32; ALPHABET_SIZE];
250    for &ch in t {
251        c_table[ch as usize] += 1;
252    }
253
254    // convert C table to cumulative sums and record seen characters in D
255    let mut d_buf = [0u8; ALPHABET_SIZE];
256    let mut d_len = 0usize;
257    let mut acc = 0i32;
258    for (c, cnt_ref) in c_table.iter_mut().enumerate() {
259        let cnt = *cnt_ref;
260        if cnt > 0 {
261            *cnt_ref = acc;
262            d_buf[d_len] = c as u8;
263            d_len += 1;
264            acc += cnt;
265        }
266    }
267
268    let idx_usize = idx as usize;
269    for i in 0..idx_usize {
270        b[c_table[t[i] as usize] as usize] = i as i32;
271        c_table[t[i] as usize] += 1;
272    }
273    for i in idx_usize..n {
274        b[c_table[t[i] as usize] as usize] = (i + 1) as i32;
275        c_table[t[i] as usize] += 1;
276    }
277
278    let d = &d_buf[..d_len];
279    let mut c_d = vec![0i32; d_len];
280    for (ci, &dc) in d.iter().enumerate() {
281        c_d[ci] = c_table[dc as usize];
282    }
283
284    let mut p = idx;
285    for u_elem in u.iter_mut().take(n) {
286        let pos = binarysearch_lower(&c_d, d_len as i32, p);
287        *u_elem = d[pos as usize];
288        p = b[(p - 1) as usize];
289    }
290
291    Ok(())
292}
293
/// Compares the suffix of `t` starting at `suf` against the pattern `p`.
///
/// Comparison resumes at offset `*match_len`, because the caller already knows that many
/// leading bytes agree. On return, `*match_len` holds the length of the common prefix.
///
/// Returns:
/// - a negative value if the suffix sorts before `p`, including when the suffix is a
///   proper prefix of `p`;
/// - a positive value if the suffix sorts after `p`;
/// - `0` if `p` is a prefix of the suffix.
fn compare(t: &[u8], p: &[u8], suf: i32, match_len: &mut i32) -> i32 {
    let text_len = t.len() as i32;
    let pat_len = p.len() as i32;
    let mut ti = suf + *match_len;
    let mut pi = *match_len;
    let mut diff = 0i32;
    while ti < text_len && pi < pat_len {
        diff = i32::from(t[ti as usize]) - i32::from(p[pi as usize]);
        if diff != 0 {
            break;
        }
        ti += 1;
        pi += 1;
    }
    *match_len = pi;
    match diff {
        0 if pi != pat_len => -1,
        0 => 0,
        other => other,
    }
}

/// Searches for pattern `p` in text `t` using the suffix array `sa`.
///
/// Returns `(count, left)`:
/// - `count` is the number of suffixes that start with `p`;
/// - `left` is the rank (index into `sa`) of the first such suffix.
///
/// Special cases:
/// - if `t` or `sa` is empty, the result is `(0, -1)`;
/// - if `p` is empty, the result is `(sa.len(), 0)`;
/// - if `p` does not occur, `count` is 0 and `left` is the rank at which `p` would be
///   inserted (assuming `sa` is a valid suffix array of `t`).
pub fn sa_search(t: &[u8], p: &[u8], sa: &[i32]) -> (i32, i32) {
    // Binary-searches ranks `[pos, pos + size)` for the first rank where `goes_right`
    // fails on the comparison result. `lo_match` and `hi_match` are the known common-prefix
    // lengths at the window's left and right boundaries, and are used to skip bytes.
    fn advance_while<G, P>(
        mut pos: i32,
        mut size: i32,
        mut lo_match: i32,
        mut hi_match: i32,
        goes_right: G,
        probe: &P,
    ) -> i32
    where
        G: Fn(i32) -> bool,
        P: Fn(i32, &mut i32) -> i32,
    {
        let mut half = size >> 1;
        while size > 0 {
            let mut matched = lo_match.min(hi_match);
            let ord = probe(pos + half, &mut matched);
            if goes_right(ord) {
                pos += half + 1;
                half -= (size & 1) ^ 1;
                lo_match = matched;
            } else {
                hi_match = matched;
            }
            size = half;
            half >>= 1;
        }
        pos
    }

    if sa.is_empty() || t.is_empty() {
        return (0, -1);
    }
    if p.is_empty() {
        return (sa.len() as i32, 0);
    }

    let probe = |rank: i32, matched: &mut i32| compare(t, p, sa[rank as usize], matched);

    let mut lo = 0i32;
    let mut size = sa.len() as i32;
    let mut half = size >> 1;
    let mut lo_match = 0i32;
    let mut hi_match = 0i32;

    while size > 0 {
        let mid = lo + half;
        let mut matched = lo_match.min(hi_match);
        let ord = probe(mid, &mut matched);
        if ord == 0 {
            // `p` matches at `mid`. Find the first match on the left and one past the last
            // match on the right.
            let first = advance_while(lo, half, lo_match, matched, |o| o < 0, &probe);
            let end = advance_while(
                mid + 1,
                size - half - 1,
                matched,
                hi_match,
                |o| o <= 0,
                &probe,
            );
            return (end - first, first);
        }
        if ord < 0 {
            lo = mid + 1;
            half -= (size & 1) ^ 1;
            lo_match = matched;
        } else {
            hi_match = matched;
        }
        size = half;
        half >>= 1;
    }

    (0, lo)
}
401
/// Searches for a single character `c` in text `t` using the suffix array `sa`.
///
/// Returns `(count, left)`:
/// - `count` is the number of suffixes whose first byte is `c`;
/// - `left` is the rank (index into `sa`) of the first such suffix.
///
/// Special cases:
/// - if `t` or `sa` is empty, the result is `(0, -1)`;
/// - if `c` does not occur, `count` is 0 and `left` is the rank at which such a suffix
///   would be inserted (assuming `sa` is sorted).
///
/// Entries of `sa` that are `>= t.len()` compare as smaller than every byte.
pub fn sa_simplesearch(t: &[u8], sa: &[i32], c: u8) -> (i32, i32) {
    // Moves `pos` right through the window `[pos, pos + size)` while `goes_right(rank)`
    // holds for the probed rank, and returns the resulting boundary.
    fn advance_while(mut pos: i32, mut size: i32, goes_right: impl Fn(i32) -> bool) -> i32 {
        let mut half = size >> 1;
        while size > 0 {
            if goes_right(pos + half) {
                pos += half + 1;
                half -= (size & 1) ^ 1;
            }
            size = half;
            half >>= 1;
        }
        pos
    }

    if sa.is_empty() || t.is_empty() {
        return (0, -1);
    }

    let text_len = t.len() as i32;
    let target = i32::from(c);
    // Signed difference between the first byte of the suffix at `rank` and `c`.
    let cmp_at = |rank: i32| -> i32 {
        let pos = sa[rank as usize];
        if pos < text_len {
            i32::from(t[pos as usize]) - target
        } else {
            -1
        }
    };

    let mut lo = 0i32;
    let mut size = sa.len() as i32;
    let mut half = size >> 1;
    while size > 0 {
        let mid = lo + half;
        let ord = cmp_at(mid);
        if ord == 0 {
            // `c` found at `mid`. Find the first match on the left and one past the last
            // match on the right.
            let first = advance_while(lo, half, |rank| cmp_at(rank) < 0);
            let end = advance_while(mid + 1, size - half - 1, |rank| cmp_at(rank) <= 0);
            return (end - first, first);
        }
        if ord < 0 {
            lo = mid + 1;
            half -= (size & 1) ^ 1;
        }
        size = half;
        half >>= 1;
    }

    (0, lo)
}
481
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sufcheck_empty() {
        assert_eq!(sufcheck(b"", &[], false), Ok(()));
    }

    #[test]
    fn test_sufcheck_single() {
        assert_eq!(sufcheck(b"a", &[0], false), Ok(()));
    }

    #[test]
    fn test_sufcheck_ba() {
        // "ba": suffixes are "a"(1) < "ba"(0)
        assert_eq!(sufcheck(b"ba", &[1, 0], false), Ok(()));
    }

    #[test]
    fn test_sufcheck_wrong_order() {
        // SA is in reverse order, so this should fail
        assert!(sufcheck(b"ba", &[0, 1], false).is_err());
    }

    #[test]
    fn test_sufcheck_out_of_range() {
        assert!(sufcheck(b"ba", &[0, 2], false).is_err());
    }

    #[test]
    fn test_sufcheck_banana() {
        // known suffix array for "banana"
        assert_eq!(sufcheck(b"banana", &[5, 3, 1, 0, 4, 2], false), Ok(()));
    }

    #[test]
    fn test_sufcheck_length_mismatch() {
        let err = sufcheck(b"ba", &[1], false).unwrap_err();
        assert_eq!(err.message, "SA length 1 != T length 2");
    }

    #[test]
    fn test_sufcheck_wrong_position() {
        // First characters are sorted (a, a, a, b, n, n) but "anana" and "ana" are swapped,
        // which only the final consistency pass can detect.
        let err = sufcheck(b"banana", &[5, 1, 3, 0, 4, 2], false).unwrap_err();
        assert_eq!(err.message, "Suffix in wrong position. SA[1]=1 or SA[4]=4");
    }

    #[test]
    fn test_sufcheck_error_display() {
        let err = sufcheck(b"ba", &[0, 2], false).unwrap_err();
        assert_eq!(format!("{}", err), "sufcheck: Out of the range [0,1]. SA[1]=2");
    }

    #[test]
    fn test_sa_search_empty_sa() {
        let (count, _) = sa_search(b"banana", b"an", &[]);
        assert_eq!(count, 0);
    }

    #[test]
    fn test_sa_search_empty_pattern() {
        let sa = [5i32, 3, 1, 0, 4, 2];
        let (count, left) = sa_search(b"banana", b"", &sa);
        assert_eq!(count, 6);
        assert_eq!(left, 0);
    }

    #[test]
    fn test_sa_search_found() {
        // SA of "banana" = [5,3,1,0,4,2]
        // "an" occurs at positions 1, 3 → ranks 1,2 in SA
        let sa = [5i32, 3, 1, 0, 4, 2];
        let (count, left) = sa_search(b"banana", b"an", &sa);
        assert_eq!(count, 2);
        assert_eq!(left, 1);
        // count matches starting at left
        for idx in left..left + count {
            let suf_start = sa[idx as usize] as usize;
            assert!(b"banana"[suf_start..].starts_with(b"an"));
        }
    }

    #[test]
    fn test_sa_search_other_patterns() {
        let sa = [5i32, 3, 1, 0, 4, 2];
        assert_eq!(sa_search(b"banana", b"na", &sa), (2, 4));
        assert_eq!(sa_search(b"banana", b"a", &sa), (3, 0));
        assert_eq!(sa_search(b"banana", b"banana", &sa), (1, 3));
        // a pattern longer than any matching suffix
        assert_eq!(sa_search(b"banana", b"bananas", &sa).0, 0);
    }

    #[test]
    fn test_sa_search_not_found() {
        let sa = [5i32, 3, 1, 0, 4, 2];
        let (count, _) = sa_search(b"banana", b"xyz", &sa);
        assert_eq!(count, 0);
    }

    #[test]
    fn test_sa_simplesearch_found() {
        // 'a' appears 3 times in "banana"
        let sa = [5i32, 3, 1, 0, 4, 2];
        let (count, left) = sa_simplesearch(b"banana", &sa, b'a');
        assert_eq!(count, 3);
        for idx in left..left + count {
            let suf_start = sa[idx as usize] as usize;
            assert_eq!(b"banana"[suf_start], b'a');
        }
    }

    #[test]
    fn test_sa_simplesearch_ranges() {
        let sa = [5i32, 3, 1, 0, 4, 2];
        assert_eq!(sa_simplesearch(b"banana", &sa, b'b'), (1, 3));
        assert_eq!(sa_simplesearch(b"banana", &sa, b'n'), (2, 4));
        assert_eq!(sa_simplesearch(b"", &[], b'a'), (0, -1));
    }

    #[test]
    fn test_sa_simplesearch_not_found() {
        let sa = [5i32, 3, 1, 0, 4, 2];
        let (count, _) = sa_simplesearch(b"banana", &sa, b'z');
        assert_eq!(count, 0);
    }

    #[test]
    fn test_bw_transform_banana() {
        // "banana" SA=[5,3,1,0,4,2] → BWT="annbaa", idx=4
        let t = b"banana";
        let mut sa = vec![5i32, 3, 1, 0, 4, 2];
        let mut u = vec![0u8; t.len()];
        let mut idx = 0i32;
        bw_transform(t, &mut u, Some(&mut sa), &mut idx).unwrap();
        assert_eq!(&u, b"annbaa");
        assert_eq!(idx, 4);
    }

    #[test]
    fn test_bw_transform_trivial_lengths() {
        let mut idx = -1i32;
        let mut empty: [u8; 0] = [];
        bw_transform(b"", &mut empty, None, &mut idx).unwrap();
        assert_eq!(idx, 0);

        let mut u = [0u8; 1];
        let mut sa = [0i32];
        bw_transform(b"x", &mut u, Some(&mut sa[..]), &mut idx).unwrap();
        assert_eq!(u, [b'x']);
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_inverse_bw_transform_banana() {
        // BWT="annbaa", idx=4 → "banana"
        let bwt = b"annbaa";
        let mut u = vec![0u8; bwt.len()];
        inverse_bw_transform(bwt, &mut u, None, 4).unwrap();
        assert_eq!(&u, b"banana");
    }

    #[test]
    fn test_inverse_bw_transform_single() {
        let mut one = [0u8; 1];
        inverse_bw_transform(b"z", &mut one, None, 1).unwrap();
        assert_eq!(one, [b'z']);
    }

    #[test]
    fn test_inverse_bw_transform_rejects_bad_idx() {
        let bwt = b"annbaa";
        let mut u = vec![0u8; bwt.len()];
        for &bad in &[-1i32, 0, 7] {
            assert!(matches!(
                inverse_bw_transform(bwt, &mut u, None, bad),
                Err(crate::DivSufSortError::InvalidArgument)
            ));
        }
    }

    #[test]
    fn test_inverse_bw_transform_with_scratch() {
        let bwt = b"annbaa";
        let mut u = vec![0u8; bwt.len()];

        let mut short = [0i32; 3];
        assert!(matches!(
            inverse_bw_transform(bwt, &mut u, Some(&mut short[..]), 4),
            Err(crate::DivSufSortError::InvalidArgument)
        ));

        let mut scratch = vec![0i32; bwt.len()];
        inverse_bw_transform(bwt, &mut u, Some(&mut scratch[..]), 4).unwrap();
        assert_eq!(&u, b"banana");
    }

    #[test]
    fn test_bw_roundtrip() {
        let t = b"mississippi";
        // known suffix array for "mississippi"
        let mut sa = vec![10i32, 7, 4, 1, 0, 9, 8, 6, 3, 5, 2];
        let mut bwt = vec![0u8; t.len()];
        let mut idx = 0i32;
        bw_transform(t, &mut bwt, Some(&mut sa), &mut idx).unwrap();
        assert_eq!(&bwt, b"ipssmpissii");
        assert_eq!(idx, 5);

        let mut restored = vec![0u8; t.len()];
        inverse_bw_transform(&bwt, &mut restored, None, idx).unwrap();
        assert_eq!(restored, t);
    }

    #[test]
    fn test_binarysearch_lower() {
        let a = [3i32, 4, 6];
        assert_eq!(binarysearch_lower(&a, 3, 3), 0);
        assert_eq!(binarysearch_lower(&a, 3, 4), 1);
        assert_eq!(binarysearch_lower(&a, 3, 5), 2);
        assert_eq!(binarysearch_lower(&a, 3, 6), 2);
        assert_eq!(binarysearch_lower(&a, 3, 7), 3);
        assert_eq!(binarysearch_lower(&a, 0, 7), 0);
    }
}