extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
/// Computes the Burrows–Wheeler transform of `block`.
///
/// Returns `(last_col, primary)`: `last_col[r]` is the byte that cyclically
/// precedes the start of the rotation ranked `r` by [`sort_rotations`], and
/// `primary` is the rank of the rotation that starts at offset 0 (the
/// original block).
///
/// The block is expected to be non-empty; this is only checked in debug
/// builds.
pub(super) fn forward(block: &[u8]) -> (Vec<u8>, usize) {
    let len = block.len();
    debug_assert!(len >= 1);

    let rotation_order = sort_rotations(block);
    let mut primary = 0usize;
    let mut last_col = vec![0u8; len];

    for (rank, &rotation) in rotation_order.iter().enumerate() {
        last_col[rank] = match rotation as usize {
            // The unrotated block: its predecessor byte wraps to the end.
            0 => {
                primary = rank;
                block[len - 1]
            }
            offset => block[offset - 1],
        };
    }

    (last_col, primary)
}
/// Returns the start offsets of every cyclic rotation of `block`, ordered
/// lexicographically; identical rotations are ordered by ascending offset.
///
/// The rotations are ranked by suffix-sorting `block ++ block ++ [0]`, where
/// each byte is shifted up by one so that the trailing 0 is strictly smaller
/// than any real symbol. Only suffixes that start inside the first copy are
/// kept.
fn sort_rotations(block: &[u8]) -> Vec<u32> {
    let len = block.len();
    debug_assert!(len >= 1);
    if len == 1 {
        return vec![0u32];
    }

    // Doubled text over the alphabet 1..=256, terminated by the sentinel 0.
    let mut doubled: Vec<i32> = Vec::with_capacity(2 * len + 1);
    doubled.extend(block.iter().map(|&byte| byte as i32 + 1));
    doubled.extend(block.iter().map(|&byte| byte as i32 + 1));
    doubled.push(0);

    let suffixes = sa_is(&doubled, 257);

    let mut order: Vec<u32> = Vec::with_capacity(len);
    order.extend(
        suffixes
            .iter()
            .map(|&suffix| suffix as usize)
            .filter(|&suffix| suffix < len)
            .map(|suffix| suffix as u32),
    );
    debug_assert_eq!(order.len(), len);

    // When the block is a repetition of a shorter unit, offsets that agree
    // modulo the period denote identical rotations. Each maximal run of such
    // offsets is put into ascending order.
    let period = smallest_period(block);
    if period < len {
        let same_rotation =
            |a: &u32, b: &u32| *a as usize % period == *b as usize % period;
        for run in order.chunk_by_mut(same_rotation) {
            if run.len() > 1 {
                run.sort_unstable();
            }
        }
    }

    order
}
/// Returns the length of the shortest unit whose whole-number repetition
/// equals `block`, or `block.len()` when no shorter such unit exists.
///
/// Blocks of length 0 or 1 return their own length.
fn smallest_period(block: &[u8]) -> usize {
    let len = block.len();
    if len <= 1 {
        return len;
    }

    // border[i] = length of the longest proper prefix of block[..=i] that is
    // also a suffix of it (the KMP failure table).
    let mut border = vec![0usize; len];
    let mut matched = 0usize;
    for (i, &byte) in block.iter().enumerate().skip(1) {
        while matched > 0 && byte != block[matched] {
            matched = border[matched - 1];
        }
        if byte == block[matched] {
            matched += 1;
        }
        border[i] = matched;
    }

    // The candidate only counts as a period if it tiles the block exactly.
    let candidate = len - border[len - 1];
    if len.is_multiple_of(candidate) {
        candidate
    } else {
        len
    }
}
/// Allocates a suffix array for `text` and fills it via [`sa_is_inner`].
///
/// `alphabet_size` is forwarded unchanged. The returned vector has the same
/// length as `text`.
fn sa_is(text: &[i32], alphabet_size: usize) -> Vec<i32> {
    let mut suffixes = vec![-1i32; text.len()];
    sa_is_inner(text, suffixes.as_mut_slice(), alphabet_size);
    suffixes
}
/// Recursive core of the suffix sorter: overwrites `sa` with the start
/// positions of the suffixes of `text`, using induced sorting seeded from the
/// LMS positions.
///
/// * `text` - symbols to sort; each symbol is used directly as an index into
///   tables of length `alphabet_size`, so it must lie in `0..alphabet_size`.
/// * `sa` - output buffer of the same length as `text` (checked in debug
///   builds only). It doubles as scratch space for the reduced problem.
/// * `alphabet_size` - length of the per-symbol count and bucket tables.
///
/// NOTE(review): the only seeds are LMS positions, so this appears to rely on
/// the last symbol of `text` being a unique minimum (the caller in this file
/// appends a 0 below all other symbols) - confirm before reusing elsewhere.
fn sa_is_inner(text: &[i32], sa: &mut [i32], alphabet_size: usize) {
let n = text.len();
debug_assert_eq!(sa.len(), n);
// Inputs of length 0, 1 and 2 are answered directly.
if n == 0 {
return;
}
if n == 1 {
sa[0] = 0;
return;
}
if n == 2 {
// Equal symbols fall into the `else` arm: the shorter suffix sorts first.
if text[0] < text[1] {
sa[0] = 0;
sa[1] = 1;
} else {
sa[0] = 1;
sa[1] = 0;
}
return;
}
// Type classification, scanned right to left: t[i] is true (S-type) when
// text[i] is smaller than text[i + 1], or equal with t[i + 1] true; it is
// false (L-type) otherwise. The last position is marked S-type.
let mut t = vec![false; n];
t[n - 1] = true;
let mut lms_positions: Vec<i32> = Vec::new();
for i in (0..n - 1).rev() {
let si = match text[i].cmp(&text[i + 1]) {
core::cmp::Ordering::Less => true,
core::cmp::Ordering::Equal => t[i + 1],
core::cmp::Ordering::Greater => false,
};
t[i] = si;
// An S-type position directly after an L-type position is an LMS position.
if t[i + 1] && !si {
lms_positions.push((i + 1) as i32);
}
}
// Collected from right to left; reverse into ascending text order.
lms_positions.reverse();
let n1 = lms_positions.len();
// Histogram of symbols, from which bucket boundaries are derived.
let mut counts = vec![0i32; alphabet_size];
for &c in text {
counts[c as usize] += 1;
}
let mut bucket = vec![0i32; alphabet_size];
// First pass: drop every LMS position at the tail of the bucket of its
// first symbol, then induce the remaining entries from those seeds.
sa.fill(-1);
fill_bucket_ends(&counts, &mut bucket);
for &p in &lms_positions {
let c = text[p as usize] as usize;
bucket[c] -= 1;
sa[bucket[c] as usize] = p;
}
induce_sort_l(text, sa, &t, &counts, &mut bucket);
induce_sort_s(text, sa, &t, &counts, &mut bucket);
// Gather the LMS positions, in the order the first pass left them, into
// sa[..n1]. The write index never passes the read index.
let mut j1 = 0usize;
for i in 0..n {
if sa[i] >= 0 && is_lms(&t, sa[i] as usize) {
sa[j1] = sa[i];
j1 += 1;
}
}
debug_assert_eq!(j1, n1);
// Clear the remainder so it can hold the names assigned below.
for slot in sa.iter_mut().take(n).skip(n1) {
*slot = -1;
}
// Naming: walk the gathered LMS positions and give each one a number that
// increases whenever its LMS substring differs from the reference one.
// `name` counts the distinct names handed out so far; `prev` is the position
// that introduced the current name (-1 before the first).
let mut name: i32 = 0;
let mut prev: i32 = -1;
for i in 0..n1 {
let pos = sa[i] as usize;
let mut diff = false;
if prev == -1 {
diff = true;
} else {
let p = prev as usize;
let mut d = 0usize;
// Compare symbol by symbol and type by type, stopping at the first
// mismatch, at the end of the text, or at the next LMS position.
loop {
if pos + d >= n || p + d >= n {
diff = true;
break;
}
if text[pos + d] != text[p + d] || t[pos + d] != t[p + d] {
diff = true;
break;
}
if d > 0 && (is_lms(&t, pos + d) || is_lms(&t, p + d)) {
if is_lms(&t, pos + d) != is_lms(&t, p + d) {
diff = true;
}
break;
}
d += 1;
}
}
if diff {
name += 1;
prev = pos as i32;
}
// Store the zero-based name at a slot keyed by pos / 2, inside sa[n1..].
sa[n1 + pos / 2] = name - 1;
}
// Pack the stored names against the end of `sa`, scanning right to left so
// their relative order is kept: sa[n - n1..] becomes the reduced text.
let mut j = n - 1;
for i in (n1..n).rev() {
if sa[i] >= 0 {
sa[j] = sa[i];
j -= 1;
}
}
let new_alpha = (name as usize) + 1;
// Split the buffer: the front receives the reduced suffix array, the back
// holds the reduced text.
let (sa1_area, t1_area) = sa.split_at_mut(n - n1);
if (name as usize) == n1 {
// Every name is distinct, so the reduced suffix array is the inverse of
// the name sequence.
for (i, &name_of_pos) in t1_area.iter().enumerate() {
sa1_area[name_of_pos as usize] = i as i32;
}
} else {
// Some names repeat: sort the reduced text recursively.
let reduced_text: &[i32] = &t1_area[..n1];
let sa1 = &mut sa1_area[..n1];
sa_is_inner(reduced_text, sa1, new_alpha);
}
// Translate reduced-text indices back into positions in `text`.
for slot in sa.iter_mut().take(n1) {
let idx = *slot as usize;
*slot = lms_positions[idx];
}
for slot in sa.iter_mut().take(n).skip(n1) {
*slot = -1;
}
// Copy the sorted LMS positions out, then reset the whole array.
let mut lms_sorted: Vec<i32> = Vec::with_capacity(n1);
lms_sorted.extend_from_slice(&sa[..n1]);
for slot in sa.iter_mut().take(n) {
*slot = -1;
}
// Second pass: seed the bucket tails again, this time walking the sorted
// list backwards so each bucket keeps that sorted order, and induce the
// final result.
fill_bucket_ends(&counts, &mut bucket);
for &pos in lms_sorted.iter().rev() {
let c = text[pos as usize] as usize;
bucket[c] -= 1;
sa[bucket[c] as usize] = pos;
}
induce_sort_l(text, sa, &t, &counts, &mut bucket);
induce_sort_s(text, sa, &t, &counts, &mut bucket);
}
/// Reports whether position `i` is an LMS position: `t[i]` is set while
/// `t[i - 1]` is not. Position 0 never qualifies.
#[inline(always)]
fn is_lms(t: &[bool], i: usize) -> bool {
    match i.checked_sub(1) {
        Some(before) => t[i] && !t[before],
        None => false,
    }
}
/// Writes into `out[k]` the sum of `counts[..k]`, i.e. the index at which
/// bucket `k` begins.
///
/// Only the first `min(out.len(), counts.len())` entries of `out` are
/// written.
#[inline]
fn fill_bucket_starts(counts: &[i32], out: &mut [i32]) {
    let starts = counts.iter().scan(0i32, |total, &count| {
        let start = *total;
        *total += count;
        Some(start)
    });
    for (slot, start) in out.iter_mut().zip(starts) {
        *slot = start;
    }
}
/// Writes into `out[k]` the sum of `counts[..=k]`, i.e. the index one past
/// the end of bucket `k`.
///
/// Only the first `min(out.len(), counts.len())` entries of `out` are
/// written.
#[inline]
fn fill_bucket_ends(counts: &[i32], out: &mut [i32]) {
    let ends = counts.iter().scan(0i32, |total, &count| {
        *total += count;
        Some(*total)
    });
    for (slot, end) in out.iter_mut().zip(ends) {
        *slot = end;
    }
}
/// Left-to-right induction pass over `sa`.
///
/// `bucket` is reset to the bucket start indices derived from `counts`. Then,
/// for every positive entry `v` met while scanning `sa` from the front, the
/// preceding position `v - 1` is appended at the head of its symbol's bucket,
/// provided `t[v - 1]` is false (L-type). Entries that are zero or negative
/// are skipped.
fn induce_sort_l(text: &[i32], sa: &mut [i32], t: &[bool], counts: &[i32], bucket: &mut [i32]) {
    fill_bucket_starts(counts, bucket);

    // Indexed scan on purpose: slots ahead of the cursor are written while
    // the scan is in progress and must be seen when the cursor reaches them.
    for cursor in 0..text.len() {
        let suffix = sa[cursor];
        if suffix > 0 {
            let before = suffix as usize - 1;
            if !t[before] {
                let head = &mut bucket[text[before] as usize];
                sa[*head as usize] = before as i32;
                *head += 1;
            }
        }
    }
}
/// Right-to-left induction pass over `sa`.
///
/// `bucket` is reset to the bucket end indices derived from `counts`. Then,
/// for every positive entry `v` met while scanning `sa` from the back, the
/// preceding position `v - 1` is placed at the tail of its symbol's bucket,
/// provided `t[v - 1]` is true (S-type). Entries that are zero or negative
/// are skipped.
fn induce_sort_s(text: &[i32], sa: &mut [i32], t: &[bool], counts: &[i32], bucket: &mut [i32]) {
    fill_bucket_ends(counts, bucket);

    // Indexed scan on purpose: slots before the cursor are written while the
    // scan is in progress and must be seen when the cursor reaches them.
    for cursor in (0..text.len()).rev() {
        let suffix = sa[cursor];
        if suffix > 0 {
            let before = suffix as usize - 1;
            if t[before] {
                let tail = &mut bucket[text[before] as usize];
                *tail -= 1;
                sa[*tail as usize] = before as i32;
            }
        }
    }
}
/// Inverts the transform produced by [`forward`], appending the recovered
/// block to `out`.
///
/// * `last_col` - the transformed bytes.
/// * `primary` - rank of the original block among the sorted rotations.
/// * `out` - receives exactly `last_col.len()` bytes on success; existing
///   contents are kept.
///
/// # Errors
///
/// Returns `Error::Corrupt` when `last_col` is empty or `primary` is not a
/// valid index into it. Nothing is appended to `out` in that case.
pub(super) fn inverse(last_col: &[u8], primary: usize, out: &mut Vec<u8>) -> Result<(), Error> {
    let len = last_col.len();
    if last_col.is_empty() || primary >= len {
        return Err(Error::Corrupt);
    }

    let mut histogram = [0usize; 256];
    for &byte in last_col {
        histogram[usize::from(byte)] += 1;
    }

    // cursor[c] starts at the first row of the sorted first column that holds
    // byte `c`, and advances as occurrences of `c` are assigned.
    let mut cursor = [0usize; 256];
    let mut total = 0usize;
    for (slot, &count) in cursor.iter_mut().zip(histogram.iter()) {
        *slot = total;
        total += count;
    }

    // next[row] = index in `last_col` of the same byte occurrence that sits
    // in the first column at `row`; occurrences are matched in order.
    let mut next = vec![0u32; len];
    for (row, &byte) in last_col.iter().enumerate() {
        let slot = &mut cursor[usize::from(byte)];
        next[*slot] = row as u32;
        *slot += 1;
    }

    // Follow the chain from the primary row, emitting one byte per step.
    let mut row = next[primary] as usize;
    out.extend((0..len).map(|_| {
        let byte = last_col[row];
        row = next[row] as usize;
        byte
    }));

    Ok(())
}
#[cfg(test)]
mod transform_tests {
    use super::*;
    use alloc::vec::Vec;

    /// Yields the bytes of the rotation of `block` that begins at `start`.
    fn rotation(block: &[u8], start: usize) -> impl Iterator<Item = u8> + '_ {
        let n = block.len();
        (0..n).map(move |off| block[(start + off) % n])
    }

    /// Straightforward oracle: sorts all rotations by direct comparison.
    /// The sort is stable, so identical rotations stay in ascending start
    /// order.
    fn reference_forward(block: &[u8]) -> (Vec<u8>, usize) {
        let n = block.len();
        let mut starts: Vec<usize> = (0..n).collect();
        starts.sort_by(|&a, &b| rotation(block, a).cmp(rotation(block, b)));

        let last: Vec<u8> = starts
            .iter()
            .map(|&start| match start {
                0 => block[n - 1],
                s => block[s - 1],
            })
            .collect();
        let primary = starts.iter().position(|&start| start == 0).unwrap_or(0);
        (last, primary)
    }

    /// Asserts that `inverse(forward(block))` reproduces `block`.
    fn roundtrip(block: &[u8]) {
        let (transformed, primary) = forward(block);
        assert_eq!(transformed.len(), block.len());

        let mut restored = Vec::new();
        inverse(&transformed, primary, &mut restored).unwrap();
        assert_eq!(restored, block, "roundtrip mismatch for {block:?}");
    }

    #[test]
    fn forward_matches_reference() {
        let cases: &[&[u8]] = &[
            b"banana",
            b"mississippi",
            b"abracadabra",
            b"aaaaaa",
            b"a",
            b"ab",
            b"ba",
            b"the quick brown fox jumps over the lazy dog",
            &[0, 0, 0, 1, 0, 0],
            &[255, 0, 255, 0, 255],
        ];
        for &c in cases {
            assert_eq!(forward(c), reference_forward(c), "forward mismatch for {c:?}");
            roundtrip(c);
        }
    }

    #[test]
    fn single_byte() {
        let (last, primary) = forward(b"Z");
        assert_eq!(last, b"Z");
        assert_eq!(primary, 0);
        roundtrip(b"Z");
    }

    #[test]
    fn all_same_byte() {
        let block = [7u8; 64];
        let (last, primary) = forward(&block);
        assert_eq!(last, block);
        assert_eq!(primary, 0);
        roundtrip(&block);
    }

    #[test]
    fn inverse_rejects_bad_primary() {
        let mut sink = Vec::new();
        assert_eq!(inverse(b"abc", 3, &mut sink), Err(Error::Corrupt));
        assert_eq!(inverse(b"", 0, &mut sink), Err(Error::Corrupt));
    }

    #[test]
    fn banana_known_last_column() {
        let (last, primary) = forward(b"banana");
        assert_eq!(last, b"nnbaaa");
        assert_eq!(primary, 3);
    }
}