extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
/// Computes the Burrows–Wheeler transform of `input` over its cyclic rotations.
///
/// Returns `(l, origin)`: `l` holds, for every rotation in sorted order, the
/// byte that cyclically precedes the rotation's first byte, and `origin` is
/// the sorted row whose rotation starts at `input[0]`. An empty input yields
/// `(vec![], 0)`.
///
/// The rotations are ordered by building a suffix array over `input`
/// repeated twice and terminated by a `0` sentinel (bytes are shifted up by
/// one so the sentinel stays strictly smallest), then keeping only suffixes
/// that start inside the first copy.
///
/// NOTE(review): suffix indices come back as `i32`, so this presumably
/// requires `2 * input.len() + 1` to fit in an `i32` — confirm at call sites.
pub(crate) fn bwt_forward(input: &[u8]) -> (Vec<u8>, u32) {
    let n = input.len();
    if n == 0 {
        return (Vec::new(), 0);
    }

    // input ++ input ++ [0], with every byte mapped to 1..=256.
    let mut text: Vec<i32> = Vec::with_capacity(2 * n + 1);
    text.extend(
        input
            .iter()
            .chain(input.iter())
            .map(|&byte| i32::from(byte) + 1),
    );
    text.push(0);

    let sa = sa_is(&text, 257);
    debug_assert_eq!(sa.len(), 2 * n + 1);
    debug_assert_eq!(sa[0] as usize, 2 * n);

    let mut l = Vec::with_capacity(n);
    let mut origin: u32 = 0;
    let mut row: u32 = 0;
    // Suffixes starting in the second copy (or at the sentinel) are skipped;
    // the remaining ones enumerate the rotations in sorted order.
    for start in sa.iter().map(|&s| s as usize).filter(|&s| s < n) {
        let preceding = match start {
            0 => {
                origin = row;
                n - 1
            }
            _ => start - 1,
        };
        l.push(input[preceding]);
        row += 1;
    }
    debug_assert_eq!(l.len(), n);
    (l, origin)
}
/// Allocates a suffix array for `text` (one `i32` slot per symbol, initially
/// `-1`) and lets `sa_is_inner` fill it in.
///
/// `alphabet_size` is forwarded unchanged to `sa_is_inner`.
fn sa_is(text: &[i32], alphabet_size: usize) -> Vec<i32> {
    let mut suffix_array = vec![-1i32; text.len()];
    sa_is_inner(text, suffix_array.as_mut_slice(), alphabet_size);
    suffix_array
}
/// Writes the suffix array of `text` into `sa` using induced sorting,
/// recursing on a reduced string when LMS substrings are not all distinct.
///
/// * `text` — symbols; each is used as an index into a table of
///   `alphabet_size` counters, so every value must lie in `0..alphabet_size`.
/// * `sa` — output slice of the same length as `text`; it doubles as scratch
///   space during construction, with `-1` marking an empty slot.
/// * `alphabet_size` — number of buckets to allocate.
///
/// NOTE(review): the only caller visible here passes a text whose last
/// symbol is a unique, strictly smallest sentinel; behavior on texts without
/// such a sentinel is not established by this file — confirm before reuse.
fn sa_is_inner(text: &[i32], sa: &mut [i32], alphabet_size: usize) {
let n = text.len();
debug_assert_eq!(sa.len(), n);
// Lengths 0, 1 and 2 are answered directly.
if n == 0 {
return;
}
if n == 1 {
sa[0] = 0;
return;
}
if n == 2 {
// On a tie the shorter suffix (index 1) is placed first.
if text[0] < text[1] {
sa[0] = 0;
sa[1] = 1;
} else {
sa[0] = 1;
sa[1] = 0;
}
return;
}
// Classify suffixes right to left: t[i] == true means S-type (text[i] is
// smaller than text[i + 1], or equal and i + 1 is S-type); false means
// L-type. The last position is S-type by definition.
let mut t = vec![false; n];
t[n - 1] = true;
for i in (0..n - 1).rev() {
t[i] = if text[i] < text[i + 1] {
true
} else if text[i] == text[i + 1] {
t[i + 1]
} else {
false
};
}
// Occurrences of each symbol, i.e. the size of each bucket.
let mut counts = vec![0i32; alphabet_size];
for &c in text {
counts[c as usize] += 1;
}
// Pass 1: seed every LMS position at the tail of its bucket (in text order,
// filling each bucket from its end backwards), then induce the L- and S-type
// entries from that seed.
sa.fill(-1);
let mut ends = bucket_ends(&counts);
for (i, &c_i) in text.iter().enumerate().take(n).skip(1) {
if is_lms(&t, i) {
let c = c_i as usize;
ends[c] -= 1;
sa[ends[c] as usize] = i as i32;
}
}
induce_sort_l(text, sa, &t, &counts);
induce_sort_s(text, sa, &t, &counts);
// Compact the LMS positions, in the order induced above, into sa[..n1]
// and clear everything after them.
let mut n1 = 0usize;
for i in 0..n {
if sa[i] >= 0 && is_lms(&t, sa[i] as usize) {
sa[n1] = sa[i];
n1 += 1;
}
}
for slot in sa.iter_mut().take(n).skip(n1) {
*slot = -1;
}
// Name the LMS substrings: walking them in sorted order, an entry gets a
// new name whenever its substring differs from the last distinct one.
// Names are zero-based; `name` ends up as the count of distinct names.
let mut name: i32 = 0;
let mut prev: i32 = -1;
for i in 0..n1 {
let pos = sa[i] as usize;
let mut diff = false;
if prev == -1 {
diff = true;
} else {
let p = prev as usize;
let mut d = 0usize;
// Compare symbol and type at each offset until a mismatch, the end of
// the text, or the next LMS position in either substring.
loop {
if pos + d >= n || p + d >= n {
diff = true;
break;
}
if text[pos + d] != text[p + d] || t[pos + d] != t[p + d] {
diff = true;
break;
}
if d > 0 && (is_lms(&t, pos + d) || is_lms(&t, p + d)) {
// Equal only if both substrings end at this same offset.
if is_lms(&t, pos + d) != is_lms(&t, p + d) {
diff = true;
}
break;
}
d += 1;
}
}
if diff {
name += 1;
prev = pos as i32;
}
// Park the name in the upper part of `sa`, keyed by pos / 2. Two LMS
// positions are never adjacent, so the halved indices do not collide.
sa[n1 + pos / 2] = name - 1;
}
// Pack the parked names against the end of `sa`, keeping text order; the
// last n1 slots then hold the reduced string.
let mut j = n - 1;
for i in (n1..n).rev() {
if sa[i] >= 0 {
sa[j] = sa[i];
j -= 1;
}
}
let new_alpha = (name as usize) + 1;
// sa1_area receives the reduced suffix array, t1_area holds the reduced text.
let (sa1_area, t1_area) = sa.split_at_mut(n - n1);
if (name as usize) == n1 {
// Every name is unique, so the reduced suffix array is simply the
// inverse of the name permutation.
for (i, &name_of_pos) in t1_area.iter().enumerate() {
sa1_area[name_of_pos as usize] = i as i32;
}
} else {
// Duplicate names remain: sort the reduced string recursively. It is
// copied out first because `sa1` and the reduced text both live in `sa`.
let sa1 = &mut sa1_area[..n1];
let mut reduced_text: Vec<i32> = Vec::with_capacity(n1);
reduced_text.extend_from_slice(&t1_area[..n1]);
sa_is_inner(&reduced_text, sa1, new_alpha);
}
// Translate reduced-string ranks back into text positions: collect the LMS
// positions in text order and index them by the reduced suffix array.
let mut lms_positions: Vec<i32> = Vec::with_capacity(n1);
for (i, &is_s) in t.iter().enumerate().take(n).skip(1) {
if is_s && !t[i - 1] {
lms_positions.push(i as i32);
}
}
debug_assert_eq!(lms_positions.len(), n1);
for slot in sa.iter_mut().take(n1) {
let idx = *slot as usize; *slot = lms_positions[idx];
}
// NOTE(review): this partial clear is superseded by the full clear below.
for slot in sa.iter_mut().take(n).skip(n1) {
*slot = -1;
}
// Pass 2: reseed the buckets with the now fully sorted LMS suffixes.
// Walking them in reverse while filling each bucket from its end keeps
// them in ascending order inside the bucket.
let mut ends = bucket_ends(&counts);
let mut lms_sorted: Vec<i32> = Vec::with_capacity(n1);
lms_sorted.extend_from_slice(&sa[..n1]);
for slot in sa.iter_mut().take(n) {
*slot = -1;
}
for &pos in lms_sorted.iter().rev() {
let c = text[pos as usize] as usize;
ends[c] -= 1;
sa[ends[c] as usize] = pos;
}
// Induce the remaining L- and S-type suffixes to complete the array.
induce_sort_l(text, sa, &t, &counts);
induce_sort_s(text, sa, &t, &counts);
}
/// Reports whether index `i` is a left-most S-type position: `t[i]` is set
/// and `t[i - 1]` is not. Index `0` has no predecessor and never qualifies.
fn is_lms(t: &[bool], i: usize) -> bool {
    match i.checked_sub(1) {
        None => false,
        Some(before) => t[i] && !t[before],
    }
}
/// Exclusive prefix sums of `counts`: element `k` of the result is the sum of
/// `counts[..k]`, i.e. the index at which bucket `k` begins.
fn bucket_starts(counts: &[i32]) -> Vec<i32> {
    counts
        .iter()
        .scan(0i32, |running, &count| {
            let begin = *running;
            *running += count;
            Some(begin)
        })
        .collect()
}
/// Inclusive prefix sums of `counts`: element `k` of the result is the sum of
/// `counts[..=k]`, i.e. one past the last index of bucket `k`.
fn bucket_ends(counts: &[i32]) -> Vec<i32> {
    counts
        .iter()
        .scan(0i32, |running, &count| {
            *running += count;
            Some(*running)
        })
        .collect()
}
/// Left-to-right induction pass over `sa`.
///
/// For every slot holding a positive suffix index, looks at the suffix one
/// position earlier in `text`; if `t` marks that predecessor as L-type
/// (`false`), it is written to the next free slot at the front of the bucket
/// for its first symbol. Slots holding `-1` or `0` are skipped, since they
/// are either empty or have no predecessor.
fn induce_sort_l(text: &[i32], sa: &mut [i32], t: &[bool], counts: &[i32]) {
    let mut heads = bucket_starts(counts);
    for scan in 0..text.len() {
        let suffix = sa[scan];
        if suffix > 0 {
            let pred = suffix as usize - 1;
            if !t[pred] {
                let bucket = text[pred] as usize;
                let dest = heads[bucket] as usize;
                sa[dest] = pred as i32;
                heads[bucket] += 1;
            }
        }
    }
}
/// Right-to-left induction pass over `sa`.
///
/// For every slot holding a positive suffix index, looks at the suffix one
/// position earlier in `text`; if `t` marks that predecessor as S-type
/// (`true`), it is written to the next free slot at the back of the bucket
/// for its first symbol. Slots holding `-1` or `0` are skipped, since they
/// are either empty or have no predecessor.
fn induce_sort_s(text: &[i32], sa: &mut [i32], t: &[bool], counts: &[i32]) {
    let mut tails = bucket_ends(counts);
    for scan in (0..text.len()).rev() {
        let suffix = sa[scan];
        if suffix > 0 {
            let pred = suffix as usize - 1;
            if t[pred] {
                let bucket = text[pred] as usize;
                tails[bucket] -= 1;
                let dest = tails[bucket] as usize;
                sa[dest] = pred as i32;
            }
        }
    }
}
/// Reconstructs the original bytes from a Burrows–Wheeler last column.
///
/// * `l` — the transformed bytes.
/// * `origin` — index of the sorted row that holds the original string.
///
/// Returns a vector of the same length as `l`; an empty `l` yields an empty
/// vector regardless of `origin`.
///
/// # Panics
///
/// Panics if `origin` is not a valid index into a non-empty `l` (via the
/// debug assertion in debug builds, via slice indexing otherwise).
pub(crate) fn bwt_inverse(l: &[u8], origin: u32) -> Vec<u8> {
    let n = l.len();
    if n == 0 {
        return Vec::new();
    }
    debug_assert!((origin as usize) < n);

    // Histogram of byte values in `l`.
    let mut freq = [0u32; 256];
    for &byte in l {
        freq[usize::from(byte)] += 1;
    }

    // Exclusive prefix sum: first sorted row occupied by each byte value.
    let mut next_row = [0u32; 256];
    let mut total: u32 = 0;
    for (slot, &occurrences) in next_row.iter_mut().zip(freq.iter()) {
        *slot = total;
        total += occurrences;
    }

    // Row `i` maps to the sorted row reserved for the k-th occurrence of
    // `l[i]`, where k counts earlier occurrences of the same byte in `l`.
    let mut link: Vec<u32> = Vec::with_capacity(n);
    for &byte in l {
        let cursor = &mut next_row[usize::from(byte)];
        link.push(*cursor);
        *cursor += 1;
    }

    // Follow the links from `origin`, filling the output back to front.
    let mut out = vec![0u8; n];
    let mut row = origin as usize;
    for dst in out.iter_mut().rev() {
        *dst = l[row];
        row = link[row] as usize;
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Runs the forward transform followed by the inverse and returns the
    /// reconstructed bytes.
    fn round_trip(input: &[u8]) -> Vec<u8> {
        let (l, origin) = bwt_forward(input);
        bwt_inverse(&l, origin)
    }

    #[test]
    fn round_trip_short() {
        let input = b"BANANA";
        assert_eq!(round_trip(input), input);
    }

    #[test]
    fn round_trip_single() {
        let input = b"a";
        let (l, origin) = bwt_forward(input);
        assert_eq!(l, b"a");
        assert_eq!(origin, 0);
        let back = bwt_inverse(&l, origin);
        assert_eq!(back, input);
    }

    #[test]
    fn round_trip_empty() {
        let (l, origin) = bwt_forward(&[]);
        assert!(l.is_empty());
        assert_eq!(origin, 0);
        let back = bwt_inverse(&l, origin);
        assert!(back.is_empty());
    }

    #[test]
    fn round_trip_longer() {
        let input = b"the quick brown fox jumps over the lazy dog";
        assert_eq!(round_trip(input), input);
    }

    #[test]
    fn round_trip_repeated_bytes() {
        let input = vec![b'a'; 50];
        assert_eq!(round_trip(&input), input);
    }

    #[test]
    fn round_trip_two_bytes() {
        for (a, b) in [(0u8, 0u8), (0, 255), (255, 0), (1, 2), (2, 1), (5, 5)] {
            let input = [a, b];
            assert_eq!(round_trip(&input), input);
        }
    }

    #[test]
    fn round_trip_three_bytes() {
        // Exhaustive over a three-symbol alphabet.
        for a in 0..3u8 {
            for b in 0..3u8 {
                for c in 0..3u8 {
                    let input = [a, b, c];
                    assert_eq!(round_trip(&input), input);
                }
            }
        }
    }

    #[test]
    fn round_trip_with_zero_bytes() {
        // Every byte value once, including 0 and 255.
        let input: Vec<u8> = (0u8..=255).collect();
        assert_eq!(round_trip(&input), input);
    }

    #[test]
    fn round_trip_many_zeros() {
        let input = vec![0u8; 500];
        assert_eq!(round_trip(&input), input);
    }

    #[test]
    fn round_trip_pseudo_random_4k() {
        // Deterministic LCG so the test is reproducible.
        let mut data = Vec::with_capacity(4096);
        let mut state: u32 = 0xDEAD_BEEF;
        for _ in 0..4096 {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            data.push((state >> 16) as u8);
        }
        assert_eq!(round_trip(&data), data);
    }

    #[test]
    fn matches_naive_on_small_inputs() {
        /// Reference transform: sorts all cyclic rotations by direct
        /// comparison and reads off the preceding byte of each.
        fn naive(input: &[u8]) -> (Vec<u8>, u32) {
            let n = input.len();
            if n == 0 {
                return (Vec::new(), 0);
            }
            let mut sa: Vec<usize> = (0..n).collect();
            sa.sort_by(|&a, &b| {
                for k in 0..n {
                    let ai = input[(a + k) % n];
                    let bi = input[(b + k) % n];
                    if ai != bi {
                        return ai.cmp(&bi);
                    }
                }
                core::cmp::Ordering::Equal
            });
            let mut l = Vec::with_capacity(n);
            let mut origin = 0u32;
            for (i, &s) in sa.iter().enumerate() {
                let prev = if s == 0 { n - 1 } else { s - 1 };
                l.push(input[prev]);
                if s == 0 {
                    origin = i as u32;
                }
            }
            (l, origin)
        }

        let cases: &[&[u8]] = &[
            b"",
            b"a",
            b"ab",
            b"ba",
            b"abc",
            b"cba",
            b"banana",
            b"mississippi",
            b"the quick brown fox jumps over the lazy dog",
            b"\0\0\0",
            b"\xff\xff\xff",
            b"\x00\xff\x00\xff\x00",
        ];
        for &case in cases {
            let (sa_l, sa_o) = bwt_forward(case);
            let back = bwt_inverse(&sa_l, sa_o);
            assert_eq!(back.as_slice(), case);

            let (nl, no) = naive(case);
            let nback = bwt_inverse(&nl, no);
            assert_eq!(nback.as_slice(), case);

            // The last column must agree with the reference. `origin` is
            // intentionally not compared: when the input is periodic,
            // several identical rotations tie and either implementation may
            // pick a different (equally valid) row among them.
            assert_eq!(
                sa_l, nl,
                "last column differs from naive transform for input {:?}",
                case
            );
        }
    }
}