use crate::Index;
use crate::lcp::{LcpDispatch, Symbol};
use crate::limits::{LimitProvider, PlainText};
use rayon::join;
/// Options accepted by the in-memory builders in this file.
#[derive(Clone, Debug)]
pub struct Opts {
    /// Upper bound forwarded to the merge step as `max_ctx`, where it caps
    /// how far the symbol-by-symbol LCP scan of two suffixes may extend.
    pub max_context: usize,
}

impl Default for Opts {
    /// Returns options with `max_context` set to `usize::MAX`, i.e. no
    /// additional cap beyond the limits reported by the limit provider.
    fn default() -> Self {
        Opts {
            max_context: usize::MAX,
        }
    }
}
/// Sorts every start position of `text` using [`Opts::default`].
///
/// Shorthand for [`build_in_memory_with_opts`] with default options; the
/// returned vector holds the positions `0..text.len()` in sorted order.
pub fn build_in_memory<S, I>(text: &[S]) -> Vec<I>
where
    S: Symbol,
    I: Index,
{
    let defaults = Opts::default();
    build_in_memory_with_opts::<S, I>(text, &defaults)
}
/// Sorts every start position of `text` with caller-supplied options.
///
/// Uses a [`PlainText`] limit provider constructed from `text.len()` and
/// delegates to [`build_in_memory_with`].
pub fn build_in_memory_with_opts<S, I>(text: &[S], opts: &Opts) -> Vec<I>
where
    S: Symbol,
    I: Index,
{
    let plain_limits = PlainText::new(text.len());
    build_in_memory_with::<S, I, _>(text, &plain_limits, opts)
}
/// Sorts every start position of `text` under the limit provider `lp`.
///
/// Materialises the positions `0..text.len()` (converted with
/// `I::from_usize`) and hands them to
/// [`build_in_memory_for_positions_with`].
pub fn build_in_memory_with<S, I, L>(text: &[S], lp: &L, opts: &Opts) -> Vec<I>
where
    S: Symbol,
    I: Index,
    L: LimitProvider,
{
    let symbol_count = text.len();
    let mut every_start: Vec<I> = Vec::with_capacity(symbol_count);
    every_start.extend((0..symbol_count).map(I::from_usize));
    build_in_memory_for_positions_with(text, every_start, lp, opts)
}
/// Sorts only the given `positions` of `text` using [`Opts::default`].
///
/// Shorthand for [`build_in_memory_for_positions_with_opts`] with default
/// options.
pub fn build_in_memory_for_positions<S, I>(text: &[S], positions: Vec<I>) -> Vec<I>
where
    S: Symbol,
    I: Index,
{
    let defaults = Opts::default();
    build_in_memory_for_positions_with_opts::<S, I>(text, positions, &defaults)
}
/// Sorts only the given `positions` of `text` with caller-supplied options.
///
/// Uses a [`PlainText`] limit provider constructed from `text.len()` and
/// delegates to [`build_in_memory_for_positions_with`].
pub fn build_in_memory_for_positions_with_opts<S, I>(
    text: &[S],
    positions: Vec<I>,
    opts: &Opts,
) -> Vec<I>
where
    S: Symbol,
    I: Index,
{
    let plain_limits = PlainText::new(text.len());
    build_in_memory_for_positions_with::<S, I, _>(text, positions, &plain_limits, opts)
}
/// Sorts the given `positions` of `text` under the limit provider `lp`.
///
/// This is the entry point every other builder in this file funnels into.
/// It allocates three zero-filled buffers of `positions.len()` entries (a
/// scratch copy of the positions, an LCP array and its scratch copy), runs
/// [`merge_sort`] over them with `opts.max_context` as the context cap, and
/// returns the reordered positions. An empty `positions` yields an empty
/// vector without touching `text`.
pub fn build_in_memory_for_positions_with<S, I, L>(
    text: &[S],
    positions: Vec<I>,
    lp: &L,
    opts: &Opts,
) -> Vec<I>
where
    S: Symbol,
    I: Index,
    L: LimitProvider,
{
    if positions.is_empty() {
        return Vec::new();
    }

    let count = positions.len();
    let zeroed = || -> Vec<I> { vec![I::zero(); count] };

    // The input vector is sorted in place and becomes the result.
    let mut suffixes = positions;
    let mut suffix_scratch = zeroed();
    let mut lcps = zeroed();
    let mut lcp_scratch = zeroed();

    merge_sort(
        text,
        lp,
        &mut suffixes,
        &mut suffix_scratch,
        &mut lcps,
        &mut lcp_scratch,
        opts.max_context,
        LcpDispatch::detect(),
    );
    suffixes
}
/// Recursively sorts `sa` in place and fills `lcp_arr` alongside it.
///
/// The slice is halved, both halves are sorted concurrently via
/// `rayon::join`, and the halves are then combined by [`merge`] into the
/// scratch buffers `sa_w` / `lcp_w`, which are finally copied back into
/// `sa` / `lcp_arr`. All four slices must have the same length (checked with
/// `debug_assert_eq!`). `max_ctx` and `dispatch` are passed through to
/// [`merge`] unchanged.
// One flat parameter list keeps the recursion free of per-call allocation.
#[allow(clippy::too_many_arguments)]
pub(crate) fn merge_sort<S, I, L>(
    text: &[S],
    lp: &L,
    sa: &mut [I],
    sa_w: &mut [I],
    lcp_arr: &mut [I],
    lcp_w: &mut [I],
    max_ctx: usize,
    dispatch: LcpDispatch,
) where
    S: Symbol,
    I: Index,
    L: LimitProvider,
{
    let count = sa.len();
    debug_assert_eq!(sa_w.len(), count);
    debug_assert_eq!(lcp_arr.len(), count);
    debug_assert_eq!(lcp_w.len(), count);

    // Base cases: nothing to order; a lone entry gets an LCP of zero.
    match count {
        0 => return,
        1 => {
            lcp_arr[0] = I::zero();
            return;
        }
        _ => {}
    }

    let half = count / 2;
    let (left_sa, right_sa) = sa.split_at_mut(half);
    let (left_lcp, right_lcp) = lcp_arr.split_at_mut(half);

    {
        // The scratch buffers are only split for the duration of the
        // recursive calls; the merge below needs them whole again.
        let (left_sa_w, right_sa_w) = sa_w.split_at_mut(half);
        let (left_lcp_w, right_lcp_w) = lcp_w.split_at_mut(half);
        let sort_left = || {
            merge_sort(
                text, lp, left_sa, left_sa_w, left_lcp, left_lcp_w, max_ctx, dispatch,
            )
        };
        let sort_right = || {
            merge_sort(
                text,
                lp,
                right_sa,
                right_sa_w,
                right_lcp,
                right_lcp_w,
                max_ctx,
                dispatch,
            )
        };
        join(sort_left, sort_right);
    }

    merge(
        text, lp, left_sa, right_sa, left_lcp, right_lcp, sa_w, lcp_w, max_ctx, dispatch,
    );

    // The merged result lives in the scratch buffers; move it back.
    sa.copy_from_slice(sa_w);
    lcp_arr.copy_from_slice(lcp_w);
}
/// Merges two position runs `x` and `y` (with their companion arrays `lcp_x`
/// and `lcp_y`) into `z`, writing a companion value for every output entry
/// into `lcp_z`.
///
/// `z` and `lcp_z` must each hold exactly `x.len() + y.len()` entries
/// (checked with `debug_assert_eq!`). If either input is empty the other is
/// copied through unchanged.
///
/// The loop keeps two roles, "A" and "B", and swaps them whenever an entry is
/// taken from B, so the run that produced the most recent output is always A.
/// `m` is carried between iterations and compared against A's stored value to
/// decide the order without touching `text`; only when the two are equal is
/// `text` actually scanned.
///
/// NOTE(review): the code reads as an LCP-aware merge in which `lcp_*[i]` is
/// the common-prefix length of entry `i` with its predecessor in the same run
/// and `m` is the common-prefix length of B's head with the last output —
/// confirm against the callers' expectations.
#[allow(clippy::too_many_arguments)] pub(crate) fn merge<S, I, L>(
text: &[S],
lp: &L,
x: &[I],
y: &[I],
lcp_x: &[I],
lcp_y: &[I],
z: &mut [I],
lcp_z: &mut [I],
max_ctx: usize,
dispatch: LcpDispatch,
) where
S: Symbol,
I: Index,
L: LimitProvider,
{
let len_x = x.len();
let len_y = y.len();
debug_assert_eq!(z.len(), len_x + len_y);
debug_assert_eq!(lcp_z.len(), len_x + len_y);
// Trivial cases: one side is empty, so the other is already the result.
if len_x == 0 {
z.copy_from_slice(y);
lcp_z.copy_from_slice(lcp_y);
return;
}
if len_y == 0 {
z.copy_from_slice(x);
lcp_z.copy_from_slice(lcp_x);
return;
}
// Role bindings; these are swapped (not the data) when B wins a step.
let mut arr_a: &[I] = x;
let mut arr_b: &[I] = y;
let mut lcp_a: &[I] = lcp_x;
let mut lcp_b: &[I] = lcp_y;
let mut len_a = len_x;
let mut len_b = len_y;
let mut i_a: usize = 0;
let mut i_b: usize = 0;
// `m` is the value carried for B's head; `k` is the next output slot.
let mut m: usize = 0;
let mut k: usize = 0;
while i_a < len_a && i_b < len_b {
let l_a = lcp_a[i_a].to_usize();
// Each arm yields (take from A?, value stored for the output, next `m`).
let (output_a, lcp_for_output, new_m) = if l_a > m {
// A's stored value exceeds `m`: A is emitted and `m` is unchanged.
(true, l_a, m)
} else if l_a < m {
// `m` exceeds A's stored value: B is emitted and `l_a` becomes `m`.
(false, m, l_a)
} else {
// Equal values: the order has to be decided by looking at `text`.
let p_a = arr_a[i_a].to_usize();
let p_b = arr_b[i_b].to_usize();
let lim_a = lp.lim_at(p_a);
let lim_b = lp.lim_at(p_b);
// Never scan past either position's limit nor past `max_ctx`.
let cap = lim_a.min(lim_b).min(max_ctx);
let remaining_ctx = cap.saturating_sub(m);
// Presumably returns the number of equal symbols starting at the two
// offsets, at most `remaining_ctx` — confirm in `LcpDispatch::lcp`.
let ext = dispatch.lcp(text, p_a + m, p_b + m, remaining_ctx);
let total = m + ext;
// NOTE(review): if the scan stopped because `max_ctx` was reached
// (`total == max_ctx`, still below both limits), the symbols at offset
// `total` are compared anyway, and if they are equal `a_smaller` is
// false, so B is emitted first. Confirm this is the intended handling.
let a_smaller = if total < lim_a && total < lim_b {
text[p_a + total] < text[p_b + total]
} else {
// At least one side hit its limit: the provider decides the order.
lp.boundary_order(p_a, lim_a, p_b, lim_b).is_lt()
};
(a_smaller, m, total)
};
if output_a {
z[k] = arr_a[i_a];
lcp_z[k] = I::from_usize(lcp_for_output);
i_a += 1;
} else {
z[k] = arr_b[i_b];
lcp_z[k] = I::from_usize(lcp_for_output);
i_b += 1;
// The run just emitted from becomes A for the next iteration.
std::mem::swap(&mut arr_a, &mut arr_b);
std::mem::swap(&mut lcp_a, &mut lcp_b);
std::mem::swap(&mut len_a, &mut len_b);
std::mem::swap(&mut i_a, &mut i_b);
}
m = new_m;
k += 1;
}
// Every iteration advances the cursor of the run that ends up as A and
// leaves B's cursor untouched, so the loop can only exit with A exhausted:
// the first call copies nothing and the second copies B's remainder, giving
// its first entry the carried value `m`.
drain(arr_a, lcp_a, i_a, len_a, z, lcp_z, &mut k, m);
drain(arr_b, lcp_b, i_b, len_b, z, lcp_z, &mut k, m);
}
/// Copies the unconsumed tail `arr[i..len]` into `z` starting at slot `*k`.
///
/// The companion value written to `lcp_z` for the first copied entry is
/// `boundary_m`; every later entry keeps the value stored for it in
/// `lcp_src`. `*k` is advanced by the number of entries copied. When
/// `i >= len` nothing is written and `*k` is left unchanged.
#[inline]
// Called only from `merge`, which already carries this state as locals.
#[allow(clippy::too_many_arguments)]
fn drain<I: Index>(
    arr: &[I],
    lcp_src: &[I],
    i: usize,
    len: usize,
    z: &mut [I],
    lcp_z: &mut [I],
    k: &mut usize,
    boundary_m: usize,
) {
    let base = *k;
    for (offset, src) in (i..len).enumerate() {
        let dst = base + offset;
        z[dst] = arr[src];
        lcp_z[dst] = match offset {
            // Only the entry adjacent to the already-merged output is overridden.
            0 => I::from_usize(boundary_m),
            _ => lcp_src[src],
        };
    }
    *k = base + len.saturating_sub(i);
}
// Unit tests for the in-memory builders. Three groups: plain texts checked
// against a brute-force suffix sort, segmented texts checked for validity
// under a reference comparator, and a custom limit provider whose tie-break
// rule is checked against its own brute-force reference.
#[cfg(test)]
mod tests {
use super::*;
// Reference implementation: sorts all start offsets by comparing the
// corresponding suffix slices lexicographically.
fn brute_force_sa(text: &[u8]) -> Vec<u32> {
let mut sa: Vec<u32> = (0..text.len() as u32).collect();
sa.sort_by(|&a, &b| text[a as usize..].cmp(&text[b as usize..]));
sa
}
// Builds with `build_in_memory` and asserts equality with the reference.
fn assert_matches_brute(text: &[u8]) {
let got: Vec<u32> = build_in_memory(text);
let want = brute_force_sa(text);
assert_eq!(got, want, "mismatch on text {text:?}");
}
// Empty input must produce an empty result.
#[test]
fn empty_text() {
let sa: Vec<u32> = build_in_memory::<u8, u32>(&[]);
assert!(sa.is_empty());
}
// A one-symbol text has exactly one position, 0.
#[test]
fn single_symbol() {
let sa: Vec<u32> = build_in_memory(&[7u8]);
assert_eq!(sa, vec![0]);
}
#[test]
fn banana() {
assert_matches_brute(b"banana");
}
#[test]
fn mississippi() {
assert_matches_brute(b"mississippi");
}
// Small hand-written text with repeated prefixes.
#[test]
fn small_distinct_sentinel() {
let text: Vec<u8> = vec![0, 1, 2, 0, 1, 5, 0, 2, 1, 6];
let got: Vec<u32> = build_in_memory(&text);
let want = brute_force_sa(&text);
assert_eq!(got, want);
}
// Seeded random texts over a 6-symbol alphabet, several lengths.
#[test]
fn random_byte_texts() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xC0FFEE);
for &n in &[1usize, 2, 3, 7, 33, 200, 1000] {
let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
let got: Vec<u32> = build_in_memory(&text);
let want = brute_force_sa(&text);
assert_eq!(got, want, "mismatch on random text len={n}");
}
}
// Passing every position explicitly must agree with `build_in_memory`.
#[test]
fn for_positions_full_set_matches_build_in_memory() {
let text = b"banana";
let want: Vec<u32> = build_in_memory(text);
let positions: Vec<u32> = (0..text.len() as u32).collect();
let got = build_in_memory_for_positions(text, positions);
assert_eq!(got, want);
}
// Only the even positions are sorted; the reference sorts the same subset.
#[test]
fn for_positions_subset_matches_brute_force() {
let text = b"mississippi";
let positions: Vec<u32> = (0..text.len() as u32).step_by(2).collect();
let mut want = positions.clone();
want.sort_by(|&a, &b| text[a as usize..].cmp(&text[b as usize..]));
let got = build_in_memory_for_positions(text, positions);
assert_eq!(got, want);
}
// Seeded random subsets: each position is kept when the draw from 0..10
// is below 7.
#[test]
fn for_positions_random_subsets() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xFEED);
for &n in &[33usize, 200, 1000] {
let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
let mut positions: Vec<u32> = (0..n as u32).collect();
positions.retain(|_| rng.random_range(0..10) < 7);
let mut want = positions.clone();
want.sort_by(|&a, &b| text[a as usize..].cmp(&text[b as usize..]));
let got = build_in_memory_for_positions(&text, positions);
assert_eq!(got, want, "subset sort mismatch n={n}");
}
}
// Random text over symbols 0..5 followed by one symbol (250) that occurs
// nowhere else in the text.
#[test]
fn random_with_unique_terminator() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xBEEF);
for &n in &[1usize, 50, 500] {
let mut text: Vec<u8> = (0..n).map(|_| rng.random_range(0..5u8)).collect();
text.push(250); let got: Vec<u32> = build_in_memory(&text);
let want = brute_force_sa(&text);
assert_eq!(got, want);
}
}
use crate::limits::SegmentedText;
// Reference comparator for segmented texts: compares symbol by symbol up to
// the smaller of the two limits, then orders by limit (smaller limit first).
fn segmented_cmp(text: &[u8], lp: &SegmentedText, a: usize, b: usize) -> std::cmp::Ordering {
use crate::limits::LimitProvider;
let lim_a = lp.lim_at(a);
let lim_b = lp.lim_at(b);
let lim = lim_a.min(lim_b);
for i in 0..lim {
if text[a + i] != text[b + i] {
return text[a + i].cmp(&text[b + i]);
}
}
lim_a.cmp(&lim_b)
}
// Checks that `sa` is a permutation of `positions` and that no adjacent pair
// is out of order under `segmented_cmp`. Pairs that compare equal are
// accepted in either order.
fn assert_segmented_sa_valid(text: &[u8], lengths: &[usize], positions: &[u32], sa: &[u32]) {
let lp = SegmentedText::from_lengths(text.len(), lengths);
let mut expected = positions.to_vec();
expected.sort();
let mut got_sorted = sa.to_vec();
got_sorted.sort();
assert_eq!(got_sorted, expected, "sa is not a permutation of positions");
for w in sa.windows(2) {
let a = w[0] as usize;
let b = w[1] as usize;
let ord = segmented_cmp(text, &lp, a, b);
assert_ne!(
ord,
std::cmp::Ordering::Greater,
"out of order: pos {a} > pos {b} under segmented comparator",
);
}
}
// Four segments of lengths 5, 5, 6 and 11 covering the whole text.
#[test]
fn segmented_in_memory_matches_brute_force_small() {
let text: Vec<u8> = b"helloworldbananamississippi".to_vec();
let lengths = &[5usize, 5, 6, 11];
let lp = SegmentedText::from_lengths(text.len(), lengths);
let sa: Vec<u32> = build_in_memory_with(&text, &lp, &Opts::default());
let all_positions: Vec<u32> = (0..text.len() as u32).collect();
assert_segmented_sa_valid(&text, lengths, &all_positions, &sa);
}
// A single segment spanning the whole text must give the same result as the
// plain builder.
#[test]
fn segmented_single_segment_equals_unsegmented() {
let text = b"mississippi";
let lp = SegmentedText::from_lengths(text.len(), &[text.len()]);
let got_segmented: Vec<u32> = build_in_memory_with(text, &lp, &Opts::default());
let got_plain: Vec<u32> = build_in_memory(text);
assert_eq!(got_segmented, got_plain);
}
// Twenty seeded rounds: 1..10 segments of 5..50 symbols over a 3-symbol
// alphabet, validated with `assert_segmented_sa_valid`.
#[test]
fn segmented_random_validity() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0x5E6);
for _ in 0..20 {
let n_segments = rng.random_range(1..10usize);
let lengths: Vec<usize> = (0..n_segments)
.map(|_| rng.random_range(5..50usize))
.collect();
let n: usize = lengths.iter().sum();
let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..3u8)).collect();
let lp = SegmentedText::from_lengths(n, &lengths);
let sa: Vec<u32> = build_in_memory_with(&text, &lp, &Opts::default());
let all_positions: Vec<u32> = (0..n as u32).collect();
assert_segmented_sa_valid(&text, &lengths, &all_positions, &sa);
}
}
// Segmented text, but only the even positions are sorted.
#[test]
fn segmented_for_positions_subset_validity() {
let text: Vec<u8> = b"helloworldbananamississippi".to_vec();
let lengths = &[5usize, 5, 6, 11];
let positions: Vec<u32> = (0..text.len() as u32).step_by(2).collect();
let lp = SegmentedText::from_lengths(text.len(), lengths);
let sa =
build_in_memory_for_positions_with(&text, positions.clone(), &lp, &Opts::default());
assert_segmented_sa_valid(&text, lengths, &positions, &sa);
}
// Test-only limit provider: takes its limits from a `SegmentedText` but
// supplies its own tie-break rule in `boundary_order`.
struct StarConvention {
inner: SegmentedText,
}
impl crate::limits::LimitProvider for StarConvention {
fn lim_at(&self, p: usize) -> usize {
self.inner.lim_at(p)
}
// Larger limit sorts first; equal limits are ordered by ascending position.
fn boundary_order(
&self,
p_a: usize,
lim_a: usize,
p_b: usize,
lim_b: usize,
) -> std::cmp::Ordering {
lim_b.cmp(&lim_a).then(p_a.cmp(&p_b))
}
}
// Brute-force reference using the same rule as `StarConvention`: compare
// symbols up to the smaller limit, then larger limit first, then position.
fn star_brute_force_sa(text: &[u8], lengths: &[usize]) -> Vec<u32> {
use crate::limits::LimitProvider;
let lp = SegmentedText::from_lengths(text.len(), lengths);
let mut sa: Vec<u32> = (0..text.len() as u32).collect();
sa.sort_by(|&a, &b| {
let pa = a as usize;
let pb = b as usize;
let lim_a = lp.lim_at(pa);
let lim_b = lp.lim_at(pb);
let lim = lim_a.min(lim_b);
for i in 0..lim {
if text[pa + i] != text[pb + i] {
return text[pa + i].cmp(&text[pb + i]);
}
}
lim_b.cmp(&lim_a).then(pa.cmp(&pb))
});
sa
}
// Exact match against the reference, including tie order.
#[test]
fn star_convention_matches_brute_force_small() {
let text: Vec<u8> = b"helloworldbananamississippi".to_vec();
let lengths = &[5usize, 5, 6, 11];
let lp = StarConvention {
inner: SegmentedText::from_lengths(text.len(), lengths),
};
let got: Vec<u32> = build_in_memory_with(&text, &lp, &Opts::default());
let want = star_brute_force_sa(&text, lengths);
assert_eq!(got, want, "STAR-convention SA mismatch");
}
// All symbols are identical, so the expected order 0, 1, 2, 3 is produced by
// `boundary_order` alone.
#[test]
fn star_convention_within_segment_longer_first() {
let text = b"AAAA";
let lp = StarConvention {
inner: SegmentedText::from_lengths(text.len(), &[text.len()]),
};
let got: Vec<u32> = build_in_memory_with(text, &lp, &Opts::default());
assert_eq!(got, vec![0u32, 1, 2, 3]);
}
// Twenty seeded rounds with random segment layouts, compared exactly against
// the reference.
#[test]
fn star_convention_random_matches_brute_force() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xCAFE);
for _ in 0..20 {
let n_segments = rng.random_range(1..10usize);
let lengths: Vec<usize> = (0..n_segments)
.map(|_| rng.random_range(5..50usize))
.collect();
let n: usize = lengths.iter().sum();
let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..3u8)).collect();
let lp = StarConvention {
inner: SegmentedText::from_lengths(n, &lengths),
};
let got: Vec<u32> = build_in_memory_with(&text, &lp, &Opts::default());
let want = star_brute_force_sa(&text, &lengths);
assert_eq!(
got, want,
"STAR-convention SA mismatch (lengths={lengths:?}, text={text:?})",
);
}
}
}