use std::cmp::Ordering;
use std::io;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::time::Instant;
use rayon::prelude::*;
use crate::Index;
use crate::ext_bucket::{BucketPool, BucketRecord, BucketStore, InMemBucket, SaLcp};
use crate::lcp::{LcpDispatch, Symbol};
use crate::limits::{LimitProvider, PlainText};
use crate::sample_sort;
/// Writes `message` to stderr, prefixed with `caps-sa profile `, but only when
/// the `CAPS_SA_PROFILE` environment variable is set (to any value).
fn profile_log(message: &str) {
    let profiling_enabled = std::env::var_os("CAPS_SA_PROFILE").is_some();
    if !profiling_enabled {
        return;
    }
    eprintln!("caps-sa profile {message}");
}
/// Options shared by the sample-sort builders in this file.
#[derive(Clone, Debug)]
pub struct ExtMemOpts {
    /// Forwarded as the context limit (`max_ctx`) to every sort, merge and
    /// suffix-comparison call. Defaults to `usize::MAX`.
    pub max_context: usize,
    /// Requested number of subproblems; `0` asks the builder to derive one.
    pub subproblem_count: usize,
    /// Directory handed to the bucket pools of the external-memory build.
    pub work_dir: PathBuf,
    /// Requested physical file count for the bucket pools; `0` asks the
    /// builder to derive one.
    pub physical_file_count: usize,
}

impl Default for ExtMemOpts {
    /// Unlimited context, automatic counts, and the system temp directory.
    fn default() -> Self {
        ExtMemOpts {
            work_dir: std::env::temp_dir(),
            physical_file_count: 0,
            subproblem_count: 0,
            max_context: usize::MAX,
        }
    }
}

impl ExtMemOpts {
    /// Returns the default options with `work_dir` replaced by the given path.
    pub fn with_work_dir(work_dir: impl AsRef<Path>) -> Self {
        let chosen_dir = PathBuf::from(work_dir.as_ref());
        Self {
            work_dir: chosen_dir,
            ..Default::default()
        }
    }
}
/// Sorts every position of `text` with the external-memory pipeline, using a
/// `PlainText` limit provider built from `text.len()`.
///
/// This is a convenience wrapper around [`build_ext_mem_with`]. `emit` is
/// called once per output position (as `u64`), in output order; the first
/// error it returns aborts the build and is propagated.
pub fn build_ext_mem<S, F>(text: &[S], opts: &ExtMemOpts, emit: F) -> io::Result<()>
where
S: Symbol,
F: FnMut(u64) -> io::Result<()>,
{
build_ext_mem_with(text, &PlainText::new(text.len()), opts, emit)
}
/// Sorts every position of `text` with the external-memory pipeline, using the
/// caller-supplied limit provider `lp`.
///
/// The index type is picked from the text length: `u32` when every position in
/// `0..text.len()` fits (i.e. `text.len() <= 2^32`), `u64` otherwise.
///
/// `emit` receives each output position as `u64`, in output order.
///
/// # Errors
/// Propagates any `io::Error` from the bucket pools or from `emit`.
pub fn build_ext_mem_with<S, L, F>(text: &[S], lp: &L, opts: &ExtMemOpts, emit: F) -> io::Result<()>
where
    S: Symbol,
    L: LimitProvider,
    F: FnMut(u64) -> io::Result<()>,
{
    // Evaluate the bound in u64: `u32::MAX as usize + 1` overflows `usize` on
    // 32-bit targets, whereas this form is valid on every pointer width.
    let fits_u32 = text.len() as u64 <= u64::from(u32::MAX) + 1;
    let source = PositionSource::Identity(text.len());
    if fits_u32 {
        build_ext_mem_inner::<S, u32, L, F>(text, source, lp, opts, emit)
    } else {
        build_ext_mem_inner::<S, u64, L, F>(text, source, lp, opts, emit)
    }
}
/// Sorts only the given `positions` of `text` with the external-memory
/// pipeline, using a `PlainText` limit provider built from `text.len()`.
///
/// Convenience wrapper around [`build_ext_mem_for_positions_with`]. `emit` is
/// called once per output position, in output order.
pub fn build_ext_mem_for_positions<S, F>(
text: &[S],
positions: Vec<u64>,
opts: &ExtMemOpts,
emit: F,
) -> io::Result<()>
where
S: Symbol,
F: FnMut(u64) -> io::Result<()>,
{
build_ext_mem_for_positions_with(text, positions, &PlainText::new(text.len()), opts, emit)
}
/// Sorts only the given `positions` of `text` with the external-memory
/// pipeline, using the caller-supplied limit provider `lp`.
///
/// The index type is picked from `text.len()` (not from `positions`): `u32`
/// when `text.len() <= 2^32`, `u64` otherwise. The entries of `positions` are
/// not validated by this function.
///
/// `emit` receives each output position as `u64`, in output order.
///
/// # Errors
/// Propagates any `io::Error` from the bucket pools or from `emit`.
pub fn build_ext_mem_for_positions_with<S, L, F>(
    text: &[S],
    positions: Vec<u64>,
    lp: &L,
    opts: &ExtMemOpts,
    emit: F,
) -> io::Result<()>
where
    S: Symbol,
    L: LimitProvider,
    F: FnMut(u64) -> io::Result<()>,
{
    // Evaluate the bound in u64: `u32::MAX as usize + 1` overflows `usize` on
    // 32-bit targets, whereas this form is valid on every pointer width.
    let fits_u32 = text.len() as u64 <= u64::from(u32::MAX) + 1;
    let source = PositionSource::Subset(positions.as_slice());
    if fits_u32 {
        build_ext_mem_inner::<S, u32, L, F>(text, source, lp, opts, emit)
    } else {
        build_ext_mem_inner::<S, u64, L, F>(text, source, lp, opts, emit)
    }
}
/// Sorts the positions of `text` for which `keep` returns `true`, using the
/// external-memory pipeline and a `PlainText` limit provider built from
/// `text.len()`.
///
/// Convenience wrapper around [`build_ext_mem_for_filter_with`]. `emit` is
/// called once per output position, in output order.
pub fn build_ext_mem_for_filter<S, F, Pred>(
text: &[S],
keep: Pred,
opts: &ExtMemOpts,
emit: F,
) -> io::Result<()>
where
S: Symbol,
F: FnMut(u64) -> io::Result<()>,
Pred: Fn(u64) -> bool + Send + Sync,
{
build_ext_mem_for_filter_with(text, keep, &PlainText::new(text.len()), opts, emit)
}
/// Sorts the positions of `text` for which `keep` returns `true`, using the
/// external-memory pipeline and the caller-supplied limit provider `lp`.
///
/// `keep` is evaluated up front for every position in `0..text.len()` (in
/// parallel) while the position bitmap is built, so it must be `Send + Sync`.
/// The index type is `u32` when `text.len() <= 2^32`, `u64` otherwise.
///
/// `emit` receives each output position as `u64`, in output order.
///
/// # Errors
/// Propagates any `io::Error` from the bucket pools or from `emit`.
pub fn build_ext_mem_for_filter_with<S, L, F, Pred>(
    text: &[S],
    keep: Pred,
    lp: &L,
    opts: &ExtMemOpts,
    emit: F,
) -> io::Result<()>
where
    S: Symbol,
    L: LimitProvider,
    F: FnMut(u64) -> io::Result<()>,
    Pred: Fn(u64) -> bool + Send + Sync,
{
    let source = PositionSource::Filtered(FilteredSource::new(text.len(), keep));
    // Evaluate the bound in u64: `u32::MAX as usize + 1` overflows `usize` on
    // 32-bit targets, whereas this form is valid on every pointer width.
    let fits_u32 = text.len() as u64 <= u64::from(u32::MAX) + 1;
    if fits_u32 {
        build_ext_mem_inner::<S, u32, L, F>(text, source, lp, opts, emit)
    } else {
        build_ext_mem_inner::<S, u64, L, F>(text, source, lp, opts, emit)
    }
}
/// Runs the four-phase pipeline over the positions yielded by `source`, with
/// buckets obtained from two `BucketPool`s created in `opts.work_dir`.
///
/// `I` is the integer type used for positions and LCP values. Phases, in order:
/// 1. sort each subarray, take samples, and store the sorted records;
/// 2. select pivots from the samples;
/// 3. distribute every subarray's records into per-pivot partitions;
/// 4. merge each partition and pass its positions to `emit`.
///
/// Each phase's wall time is reported through `profile_log`.
///
/// Returns `Ok(())` immediately, before any pool is created, when `source` is
/// empty. Errors from pool creation, bucket I/O or `emit` are propagated.
fn build_ext_mem_inner<S, I, L, F>(
text: &[S],
source: PositionSource<'_>,
lp: &L,
opts: &ExtMemOpts,
mut emit: F,
) -> io::Result<()>
where
S: Symbol,
I: Index,
L: LimitProvider,
SaLcp<I>: BucketRecord,
F: FnMut(u64) -> io::Result<()>,
{
let n = source.len();
if n == 0 {
return Ok(());
}
let p = effective_subproblem_count(n, opts.subproblem_count);
let dispatch = LcpDispatch::detect();
let work_dir = opts.work_dir.clone();
let n_phys = effective_physical_file_count(opts.physical_file_count);
// One pool backs the phase-1 (subarray) buckets, a second one the phase-3
// (partition) buckets. Both are created before phase 1 starts.
let phase1_pool = BucketPool::new(n_phys, &work_dir)?;
let phase3_pool = BucketPool::new(n_phys, &work_dir)?;
profile_log(&format!(
"build_ext_mem n={n} p={p} index_width={}b n_phys={n_phys}",
std::mem::size_of::<I>() * 8
));
let sub_factory = |i: usize| phase1_pool.new_bucket::<SaLcp<I>>(i);
let part_factory = |j: usize| phase3_pool.new_bucket::<SaLcp<I>>(j);
let t = Instant::now();
let (mut subarray_buckets, samples) = phase1_sort_sample_spill::<S, I, L, _, _>(
text,
lp,
&source,
p,
opts,
dispatch,
sub_factory,
)?;
profile_log(&format!(
"phase1 (sort+sample+spill) {:.3}s",
t.elapsed().as_secs_f64()
));
// The position source is only read by phase 1; release it before the
// remaining phases run.
drop(source);
let t = Instant::now();
let pivots = phase2_select_pivots::<S, I, L>(text, lp, samples, p, opts.max_context, dispatch);
profile_log(&format!(
"phase2 (select pivots) {:.3}s",
t.elapsed().as_secs_f64()
));
let t = Instant::now();
let mut partition_buckets = phase3_distribute::<S, I, L, _, _>(
text,
lp,
&mut subarray_buckets,
&pivots,
p,
opts,
dispatch,
part_factory,
)?;
profile_log(&format!(
"phase3 (distribute) {:.3}s",
t.elapsed().as_secs_f64()
));
// Every record now lives in a partition bucket; the subarray buckets are
// no longer needed.
drop(subarray_buckets);
let t = Instant::now();
// The result is held (not `?`-propagated) so the phase-4 timing line is
// logged on failure as well as on success.
let result = phase4_merge_and_emit::<S, I, L, _, F>(
text,
lp,
&mut partition_buckets,
opts.max_context,
&mut emit,
dispatch,
);
profile_log(&format!(
"phase4 (merge+emit) {:.3}s",
t.elapsed().as_secs_f64()
));
result
}
/// In-memory variant of the four-phase pipeline: identical phase sequence to
/// the external-memory build, but every bucket is an `InMemBucket` and no
/// bucket pool is created.
///
/// Returns `Ok(())` immediately when `source` yields no positions. Errors
/// from the buckets or from `emit` are propagated.
fn build_in_memory_ss_inner<S, I, L, F>(
    text: &[S],
    source: PositionSource<'_>,
    lp: &L,
    opts: &ExtMemOpts,
    mut emit: F,
) -> io::Result<()>
where
    S: Symbol,
    I: Index,
    L: LimitProvider,
    SaLcp<I>: BucketRecord,
    F: FnMut(u64) -> io::Result<()>,
{
    let total = source.len();
    if total == 0 {
        return Ok(());
    }
    let parts = effective_subproblem_count(total, opts.subproblem_count);
    let dispatch = LcpDispatch::detect();
    // The bucket index is irrelevant for in-memory buckets.
    let make_bucket = |_slot: usize| InMemBucket::<SaLcp<I>>::new();

    // Phase 1: sort each subarray and collect samples.
    let (mut sorted_runs, samples) = phase1_sort_sample_spill::<S, I, L, _, _>(
        text,
        lp,
        &source,
        parts,
        opts,
        dispatch,
        make_bucket,
    )?;
    drop(source);

    // Phase 2: choose the pivots from the samples.
    let max_ctx = opts.max_context;
    let pivots = phase2_select_pivots::<S, I, L>(text, lp, samples, parts, max_ctx, dispatch);

    // Phase 3: route every sorted run into the pivot-delimited partitions.
    let mut partitions = phase3_distribute::<S, I, L, _, _>(
        text,
        lp,
        &mut sorted_runs,
        &pivots,
        parts,
        opts,
        dispatch,
        make_bucket,
    )?;
    drop(sorted_runs);

    // Phase 4: merge each partition and hand the positions to `emit`.
    phase4_merge_and_emit::<S, I, L, _, F>(
        text,
        lp,
        &mut partitions,
        max_ctx,
        &mut emit,
        dispatch,
    )
}
/// Sorts every position of `text` with the in-memory sample-sort pipeline,
/// using a `PlainText` limit provider built from `text.len()`.
///
/// Convenience wrapper around [`build_in_memory_sample_sort_with`]. `emit` is
/// called once per output position, in output order.
pub fn build_in_memory_sample_sort<S, F>(text: &[S], opts: &ExtMemOpts, emit: F) -> io::Result<()>
where
S: Symbol,
F: FnMut(u64) -> io::Result<()>,
{
build_in_memory_sample_sort_with(text, &PlainText::new(text.len()), opts, emit)
}
/// Sorts every position of `text` with the in-memory sample-sort pipeline,
/// using the caller-supplied limit provider `lp`.
///
/// The index type is `u32` when `text.len() <= 2^32`, `u64` otherwise.
/// `emit` receives each output position as `u64`, in output order.
///
/// # Errors
/// Propagates any `io::Error` from the buckets or from `emit`.
pub fn build_in_memory_sample_sort_with<S, L, F>(
    text: &[S],
    lp: &L,
    opts: &ExtMemOpts,
    emit: F,
) -> io::Result<()>
where
    S: Symbol,
    L: LimitProvider,
    F: FnMut(u64) -> io::Result<()>,
{
    // Evaluate the bound in u64: `u32::MAX as usize + 1` overflows `usize` on
    // 32-bit targets, whereas this form is valid on every pointer width.
    let fits_u32 = text.len() as u64 <= u64::from(u32::MAX) + 1;
    let source = PositionSource::Identity(text.len());
    if fits_u32 {
        build_in_memory_ss_inner::<S, u32, L, F>(text, source, lp, opts, emit)
    } else {
        build_in_memory_ss_inner::<S, u64, L, F>(text, source, lp, opts, emit)
    }
}
/// Sorts only the given `positions` of `text` with the in-memory sample-sort
/// pipeline, using a `PlainText` limit provider built from `text.len()`.
///
/// Convenience wrapper around
/// [`build_in_memory_sample_sort_for_positions_with`].
pub fn build_in_memory_sample_sort_for_positions<S, F>(
text: &[S],
positions: Vec<u64>,
opts: &ExtMemOpts,
emit: F,
) -> io::Result<()>
where
S: Symbol,
F: FnMut(u64) -> io::Result<()>,
{
build_in_memory_sample_sort_for_positions_with(
text,
positions,
&PlainText::new(text.len()),
opts,
emit,
)
}
/// Sorts only the given `positions` of `text` with the in-memory sample-sort
/// pipeline, using the caller-supplied limit provider `lp`.
///
/// The index type is picked from `text.len()` (not from `positions`): `u32`
/// when `text.len() <= 2^32`, `u64` otherwise. The entries of `positions` are
/// not validated by this function.
///
/// `emit` receives each output position as `u64`, in output order.
///
/// # Errors
/// Propagates any `io::Error` from the buckets or from `emit`.
pub fn build_in_memory_sample_sort_for_positions_with<S, L, F>(
    text: &[S],
    positions: Vec<u64>,
    lp: &L,
    opts: &ExtMemOpts,
    emit: F,
) -> io::Result<()>
where
    S: Symbol,
    L: LimitProvider,
    F: FnMut(u64) -> io::Result<()>,
{
    // Evaluate the bound in u64: `u32::MAX as usize + 1` overflows `usize` on
    // 32-bit targets, whereas this form is valid on every pointer width.
    let fits_u32 = text.len() as u64 <= u64::from(u32::MAX) + 1;
    let source = PositionSource::Subset(positions.as_slice());
    if fits_u32 {
        build_in_memory_ss_inner::<S, u32, L, F>(text, source, lp, opts, emit)
    } else {
        build_in_memory_ss_inner::<S, u64, L, F>(text, source, lp, opts, emit)
    }
}
/// Describes which text positions a build should sort.
enum PositionSource<'a> {
/// Every position `0..n`; the payload is `n`.
Identity(usize),
/// An explicit, caller-provided list of positions.
Subset(&'a [u64]),
/// Positions selected by a predicate, held as a bitmap.
Filtered(FilteredSource),
}
/// Number of 64-bit bitmap words summarised by one entry of
/// `FilteredSource::cumsum`.
const FILTERED_WORDS_PER_BLOCK: usize = 1024;
/// Bitmap of the kept positions plus per-block prefix counts, so that the
/// kept position of a given rank can be located without scanning from the
/// start of the bitmap.
struct FilteredSource {
// Length of the text the bitmap was built over.
text_len: usize,
// Total number of set bits in `bitmap`.
total_kept: usize,
// Bit `b` of word `w` is set iff position `w * 64 + b` is kept.
bitmap: Vec<u64>,
// `cumsum[i]` is the number of kept positions in blocks `0..i`; it has one
// more entry than there are blocks and starts with 0.
cumsum: Vec<u64>,
}
impl FilteredSource {
/// Evaluates `keep` for every position in `0..text_len` and records the
/// result as a bitmap together with per-block cumulative counts.
///
/// Both the bitmap words and the per-block popcounts are computed with
/// rayon parallel iterators, hence the `Send + Sync` bound on `keep`.
fn new<Pred>(text_len: usize, keep: Pred) -> Self
where
Pred: Fn(u64) -> bool + Send + Sync,
{
let n_words = text_len.div_ceil(64);
let bitmap: Vec<u64> = (0..n_words)
.into_par_iter()
.map(|w| {
let mut word: u64 = 0;
let base = (w as u64) * 64;
// Number of valid bits in this word; less than 64 only for the last
// word when `text_len` is not a multiple of 64.
let limit = ((w + 1) * 64).min(text_len) - w * 64;
for b in 0..limit {
if keep(base + b as u64) {
word |= 1u64 << b;
}
}
word
})
.collect();
let n_blocks = n_words.div_ceil(FILTERED_WORDS_PER_BLOCK);
// Count of set bits in each block of FILTERED_WORDS_PER_BLOCK words.
let per_block: Vec<u64> = (0..n_blocks)
.into_par_iter()
.map(|i| {
let start = i * FILTERED_WORDS_PER_BLOCK;
let end = ((i + 1) * FILTERED_WORDS_PER_BLOCK).min(n_words);
let mut c: u64 = 0;
for &word in &bitmap[start..end] {
c += word.count_ones() as u64;
}
c
})
.collect();
// Exclusive prefix sums over the block counts, with a leading 0.
let mut cumsum = Vec::with_capacity(n_blocks + 1);
let mut s: u64 = 0;
cumsum.push(0);
for &k in &per_block {
s += k;
cumsum.push(s);
}
let total_kept = s as usize;
Self {
text_len,
total_kept,
bitmap,
cumsum,
}
}
/// Number of kept positions.
#[inline]
fn len(&self) -> usize {
self.total_kept
}
/// Fills `dst` with the kept positions of rank `start..start + dst.len()`,
/// in ascending position order.
///
/// The caller must ensure `start + dst.len() <= self.len()`; this is only
/// checked with `debug_assert!`.
fn fill_chunk<I: Index>(&self, start: usize, dst: &mut [I]) {
debug_assert!(start + dst.len() <= self.total_kept);
if dst.is_empty() {
return;
}
// Find the last block whose cumulative count is <= start: that block
// contains the kept position of rank `start`.
let pp = self.cumsum.partition_point(|&c| c <= start as u64);
debug_assert!(pp > 0);
let block_idx = pp - 1;
let mut word_idx = block_idx * FILTERED_WORDS_PER_BLOCK;
// Set bits still to be skipped inside the block before rank `start`.
let mut skip = start as u64 - self.cumsum[block_idx];
let n_words = self.bitmap.len();
let mut word: u64 = if word_idx < n_words { self.bitmap[word_idx] } else { 0 };
while skip > 0 {
let pc = word.count_ones() as u64;
if skip < pc {
// The target bit is in this word: drop the `skip` lowest set bits.
for _ in 0..skip {
word &= word - 1;
}
break;
}
skip -= pc;
word_idx += 1;
word = if word_idx < n_words { self.bitmap[word_idx] } else { 0 };
}
let mut written = 0usize;
let need = dst.len();
loop {
// Emit the set bits of the current word from lowest to highest.
while word != 0 && written < need {
let bit = word.trailing_zeros() as u64;
let pos = (word_idx as u64) * 64 + bit;
debug_assert!((pos as usize) < self.text_len);
dst[written] = I::from_usize(pos as usize);
written += 1;
// Clear the lowest set bit.
word &= word - 1;
}
if written == need {
break;
}
word_idx += 1;
debug_assert!(
word_idx < n_words,
"FilteredSource::fill_chunk: walked past bitmap end \
({written}/{need} emitted, word_idx={word_idx}, n_words={n_words})"
);
word = self.bitmap[word_idx];
}
}
}
impl<'a> PositionSource<'a> {
    /// Number of positions this source yields.
    fn len(&self) -> usize {
        match self {
            Self::Identity(count) => *count,
            Self::Subset(positions) => positions.len(),
            Self::Filtered(filtered) => filtered.len(),
        }
    }

    /// Writes the positions of rank `start..start + dst.len()` into `dst`.
    ///
    /// For `Subset` the requested range is sliced out of the list, so an
    /// out-of-range request panics.
    fn fill_chunk<I: Index>(&self, start: usize, dst: &mut [I]) {
        match self {
            Self::Identity(_) => {
                // Rank and position coincide.
                dst.iter_mut()
                    .zip(start..)
                    .for_each(|(slot, pos)| *slot = I::from_usize(pos));
            }
            Self::Subset(positions) => {
                let window = &positions[start..start + dst.len()];
                for (slot, pos) in dst.iter_mut().zip(window.iter().copied()) {
                    *slot = I::from_usize(pos as usize);
                }
            }
            Self::Filtered(filtered) => filtered.fill_chunk(start, dst),
        }
    }
}
// Target number of positions per subproblem when the subproblem count is
// derived automatically (see `effective_subproblem_count`).
const PHASE1_TARGET_CHUNK: usize = 65_536;
// Upper bound on the automatically derived subproblem count.
const PHASE1_MAX_PARTITIONS: usize = 8192;
/// Resolves the physical file count for the bucket pools.
///
/// Precedence: a `CAPS_SA_N_PHYS` environment variable that parses to an
/// integer `>= 1`, then `requested` when it is `>= 1`, then the rayon thread
/// count (at least 1).
fn effective_physical_file_count(requested: usize) -> usize {
    let from_env = match std::env::var("CAPS_SA_N_PHYS") {
        Ok(raw) => raw.parse::<usize>().ok().filter(|&count| count >= 1),
        Err(_) => None,
    };
    match from_env {
        Some(count) => count,
        None if requested >= 1 => requested,
        None => rayon::current_num_threads().max(1),
    }
}
/// Resolves how many subproblems to split `n` positions into.
///
/// Returns 0 for `n == 0`. With `requested == 0` the count is
/// `ceil(n / PHASE1_TARGET_CHUNK)` clamped to
/// `[rayon threads, PHASE1_MAX_PARTITIONS]`; otherwise it is `requested`.
/// Either way the result is finally clamped to `[1, n]`.
fn effective_subproblem_count(n: usize, requested: usize) -> usize {
    match (n, requested) {
        (0, _) => 0,
        (_, 0) => {
            let workers = rayon::current_num_threads().max(1);
            n.div_ceil(PHASE1_TARGET_CHUNK)
                .clamp(workers, PHASE1_MAX_PARTITIONS)
                .clamp(1, n)
        }
        (_, explicit) => explicit.clamp(1, n),
    }
}
/// Phase 1: splits the positions of `source` into `p` contiguous chunks and,
/// for each chunk in parallel, sorts it, draws evenly spaced samples from the
/// sorted positions, and stores the sorted `(pos, lcp)` records in a bucket
/// created by `mk_bucket(i)`.
///
/// Returns the `p` buckets (in chunk order; empty chunks still get a bucket)
/// and the concatenation of all chunks' samples. The first I/O error raised
/// while storing records is propagated.
#[allow(clippy::too_many_arguments)]
fn phase1_sort_sample_spill<S, I, L, B, MkB>(
text: &[S],
lp: &L,
source: &PositionSource<'_>,
p: usize,
opts: &ExtMemOpts,
dispatch: LcpDispatch,
mk_bucket: MkB,
) -> io::Result<(Vec<B>, Vec<I>)>
where
S: Symbol,
I: Index,
L: LimitProvider,
SaLcp<I>: BucketRecord,
B: BucketStore<SaLcp<I>> + Send,
MkB: Fn(usize) -> B + Send + Sync,
{
let n = source.len();
let chunk_size = n.div_ceil(p);
let samples_target_total = sample_target_total(n, p);
let per_subarray: Vec<(B, Vec<I>)> = (0..p)
.into_par_iter()
.map(|i| {
// Because chunk_size is rounded up, trailing chunks can be empty;
// `min(n)` keeps both ends in range.
let start = (i * chunk_size).min(n);
let end = ((i + 1) * chunk_size).min(n);
let len = end - start;
let mut bucket = mk_bucket(i);
if len == 0 {
return Ok::<_, io::Error>((bucket, Vec::new()));
}
let mut sa: Vec<I> = vec![I::zero(); len];
source.fill_chunk(start, &mut sa);
// Scratch and LCP buffers handed to `merge_sort`; after the call `sa`
// and `lcp_arr` are the ones read below.
let mut sa_w = vec![I::zero(); len];
let mut lcp_arr = vec![I::zero(); len];
let mut lcp_w = vec![I::zero(); len];
sample_sort::merge_sort(
text,
lp,
&mut sa,
&mut sa_w,
&mut lcp_arr,
&mut lcp_w,
opts.max_context,
dispatch,
);
// Equal share of the global sample budget, capped at the chunk length.
let samples_per_subarray = samples_target_total.div_ceil(p).min(len);
let samples = evenly_spaced(&sa, samples_per_subarray);
let records: Vec<SaLcp<I>> = sa
.iter()
.zip(lcp_arr.iter())
.map(|(&pos, &lcp)| SaLcp { pos, lcp })
.collect();
bucket.add_slice(&records)?;
Ok((bucket, samples))
})
.collect::<Result<Vec<_>, _>>()?;
let mut buckets = Vec::with_capacity(p);
let mut all_samples = Vec::with_capacity(samples_target_total);
for (bucket, samples) in per_subarray {
buckets.push(bucket);
all_samples.extend(samples);
}
Ok((buckets, all_samples))
}
/// Total number of samples to draw across all `p` subarrays of an `n`-element
/// input: `p * ceil(4 * max(ln n, 1))`, capped at `n`.
///
/// For `p <= n` the result lies in `[p, n]`. When `p > n` the result is `n`
/// rather than a panic.
fn sample_target_total(n: usize, p: usize) -> usize {
    let ln_n = (n as f64).ln().max(1.0);
    // `ln_n >= 1`, so `per >= 4` and `p * per >= p`: the lower bound `p`
    // holds without an explicit clamp.
    let per = (4.0 * ln_n).ceil() as usize;
    // Only the upper bound needs enforcing. `clamp(p, n)` would panic for
    // `p > n`, whereas `min` degrades gracefully.
    p.saturating_mul(per).min(n)
}
/// Picks `count` elements of `xs` at evenly spaced offsets, taking the middle
/// of each of `count` equal-width strides.
///
/// Returns an empty vector when `count` or `xs` is empty, and a full copy of
/// `xs` when `count >= xs.len()`.
fn evenly_spaced<T: Copy>(xs: &[T], count: usize) -> Vec<T> {
    if count == 0 || xs.is_empty() {
        return Vec::new();
    }
    let len = xs.len();
    if count >= len {
        return xs.to_vec();
    }
    let denominator = 2 * count;
    let mut picked = Vec::with_capacity(count);
    for slot in 0..count {
        // Midpoint of the slot-th stride: (slot + 1/2) * len / count.
        let idx = (2 * slot + 1) * len / denominator;
        picked.push(xs[idx]);
    }
    picked
}
/// Phase 2: sorts the collected samples and picks `p - 1` pivots from them at
/// evenly spaced ranks (`j * n_samples / p` for `j` in `1..p`).
///
/// Returns an empty vector when `p <= 1` or there are no samples.
fn phase2_select_pivots<S, I, L>(
    text: &[S],
    lp: &L,
    mut samples: Vec<I>,
    p: usize,
    max_ctx: usize,
    dispatch: LcpDispatch,
) -> Vec<I>
where
    S: Symbol,
    I: Index,
    L: LimitProvider,
{
    let sample_count = samples.len();
    if p <= 1 || sample_count == 0 {
        return Vec::new();
    }

    // Scratch and LCP buffers required by `merge_sort`; only the sorted
    // `samples` are read afterwards.
    let mut scratch_sa = vec![I::zero(); sample_count];
    let mut lcp_out = vec![I::zero(); sample_count];
    let mut scratch_lcp = vec![I::zero(); sample_count];
    sample_sort::merge_sort(
        text,
        lp,
        &mut samples,
        &mut scratch_sa,
        &mut lcp_out,
        &mut scratch_lcp,
        max_ctx,
        dispatch,
    );

    let mut pivots = Vec::with_capacity(p - 1);
    for j in 1..p {
        pivots.push(samples[(j * sample_count) / p]);
    }
    pivots
}
/// Phase 3: splits every sorted subarray at the pivots and appends each piece
/// to the partition bucket with the matching index.
///
/// Subarrays are processed in parallel; each of the `p` partition buckets is
/// wrapped in a `Mutex` for the duration of the phase. After every appended
/// piece `mark_boundary` is called on the partition bucket, so each piece can
/// later be treated as a separate sorted run.
///
/// Returns the `p` partition buckets in index order, or the first I/O error.
#[allow(clippy::too_many_arguments)]
fn phase3_distribute<S, I, L, B, MkB>(
text: &[S],
lp: &L,
subarray_buckets: &mut [B],
pivots: &[I],
p: usize,
opts: &ExtMemOpts,
dispatch: LcpDispatch,
mk_bucket: MkB,
) -> io::Result<Vec<B>>
where
S: Symbol,
I: Index,
L: LimitProvider,
SaLcp<I>: BucketRecord,
B: BucketStore<SaLcp<I>> + Send,
MkB: Fn(usize) -> B + Send + Sync,
{
let _ = opts; let partition_buckets: Vec<Mutex<B>> = (0..p).map(|j| Mutex::new(mk_bucket(j))).collect();
subarray_buckets
.par_iter_mut()
.try_for_each(|sub_bucket| -> io::Result<()> {
if sub_bucket.total_records() == 0 {
return Ok(());
}
let records = sub_bucket.load_all()?;
// splits[j]..splits[j + 1] is the slice of `records` that belongs to
// partition j. The indexing below relies on `pivots.len() == p - 1`.
let mut splits = Vec::with_capacity(p + 1);
splits.push(0usize);
for &pivot in pivots {
splits.push(upper_bound_by_pivot(
&records,
pivot,
text,
lp,
opts.max_context,
dispatch,
));
}
splits.push(records.len());
for j in 0..p {
let lo = splits[j];
let hi = splits[j + 1];
if lo >= hi {
continue;
}
let mut sub: Vec<SaLcp<I>> = records[lo..hi].to_vec();
// The first record of the piece gets a zero LCP.
// NOTE(review): presumably because its predecessor is no longer part
// of the same run — confirm against `sample_sort::merge`.
sub[0].lcp = I::zero();
let mut bucket = partition_buckets[j].lock().unwrap();
bucket.add_slice(&sub)?;
bucket.mark_boundary();
}
Ok(())
})?;
Ok(partition_buckets
.into_iter()
.map(|m| m.into_inner().expect("partition mutex poisoned"))
.collect())
}
/// Binary-searches the sorted `records` for the first record whose suffix
/// compares `Greater` than the suffix at `pivot`, and returns its index
/// (`records.len()` if there is none).
///
/// Records comparing `Equal` to the pivot are counted on the left side.
fn upper_bound_by_pivot<S, I, L>(
    records: &[SaLcp<I>],
    pivot: I,
    text: &[S],
    lp: &L,
    max_ctx: usize,
    dispatch: LcpDispatch,
) -> usize
where
    S: Symbol,
    I: Index,
    L: LimitProvider,
{
    let (mut low, mut high) = (0usize, records.len());
    while low < high {
        let probe = low + (high - low) / 2;
        let candidate = records[probe].pos.to_usize();
        let ordering = dispatch.suffix_cmp_with(text, lp, candidate, pivot.to_usize(), max_ctx);
        if ordering == Ordering::Greater {
            high = probe;
        } else {
            low = probe + 1;
        }
    }
    low
}
/// Phase 4: merges the runs of every partition bucket and passes the merged
/// positions to `emit`, partition by partition in index order.
///
/// Partitions are handled in chunks of `4 * rayon threads`: the buckets of a
/// chunk are loaded and merged in parallel, then their positions are emitted
/// sequentially before the next chunk starts.
///
/// When `CAPS_SA_PROFILE` is set, summed per-bucket load and merge times and
/// the wall time spent in `emit` are logged at the end.
///
/// Errors from loading a bucket or from `emit` are propagated.
fn phase4_merge_and_emit<S, I, L, B, F>(
text: &[S],
lp: &L,
partition_buckets: &mut [B],
max_ctx: usize,
emit: &mut F,
dispatch: LcpDispatch,
) -> io::Result<()>
where
S: Symbol,
I: Index,
L: LimitProvider,
SaLcp<I>: BucketRecord,
B: BucketStore<SaLcp<I>> + Send,
F: FnMut(u64) -> io::Result<()>,
{
let n_partitions = partition_buckets.len();
if n_partitions == 0 {
return Ok(());
}
let chunk_size = rayon::current_num_threads().max(1) * 4;
use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
let profile = std::env::var_os("CAPS_SA_PROFILE").is_some();
// Microsecond totals accumulated across worker threads; only updated when
// profiling is enabled.
let load_us = AtomicU64::new(0);
let merge_us = AtomicU64::new(0);
let mut emit_secs: f64 = 0.0;
let mut start = 0;
while start < n_partitions {
let end = (start + chunk_size).min(n_partitions);
let chunk = &mut partition_buckets[start..end];
// `merged[k]` holds the merged positions of partition `start + k`.
let merged: Vec<Vec<I>> = chunk
.par_iter_mut()
.map(|bucket| -> io::Result<Vec<I>> {
if bucket.total_records() == 0 {
return Ok(Vec::new());
}
let t = Instant::now();
let records = bucket.load_all()?;
let boundaries: Vec<usize> = bucket.boundaries().to_vec();
if profile {
load_us.fetch_add(t.elapsed().as_micros() as u64, AtomicOrdering::Relaxed);
}
let t = Instant::now();
let workspace = CascadeWorkspace::<I>::new();
let result =
workspace.cascade_merge(text, lp, &records, &boundaries, max_ctx, dispatch);
if profile {
merge_us.fetch_add(t.elapsed().as_micros() as u64, AtomicOrdering::Relaxed);
}
Ok(result)
})
.collect::<Result<Vec<_>, io::Error>>()?;
let t = Instant::now();
// Emission happens on the calling thread, in partition order.
for positions in merged {
for pos in positions {
emit(pos.to_usize() as u64)?;
}
}
if profile {
emit_secs += t.elapsed().as_secs_f64();
}
start = end;
}
if profile {
profile_log(&format!(
"phase4 breakdown CPU: load {:.3}s merge {:.3}s; wall emit {:.3}s",
load_us.load(AtomicOrdering::Relaxed) as f64 * 1e-6,
merge_us.load(AtomicOrdering::Relaxed) as f64 * 1e-6,
emit_secs,
));
}
Ok(())
}
/// Two pairs of position/LCP buffers (`a_*` and `b_*`) used alternately as
/// source and destination while the runs of one partition are merged.
struct CascadeWorkspace<I> {
a_sa: Vec<I>,
a_lcp: Vec<I>,
b_sa: Vec<I>,
b_lcp: Vec<I>,
}
impl<I: Index> CascadeWorkspace<I> {
/// Creates a workspace with empty buffers; they are sized on first use.
fn new() -> Self {
Self {
a_sa: Vec::new(),
a_lcp: Vec::new(),
b_sa: Vec::new(),
b_lcp: Vec::new(),
}
}
/// Grows all four buffers to at least `n` zero-initialised elements.
/// Buffers are never shrunk.
fn ensure_capacity(&mut self, n: usize) {
if self.a_sa.len() < n {
self.a_sa.resize(n, I::zero());
self.a_lcp.resize(n, I::zero());
self.b_sa.resize(n, I::zero());
self.b_lcp.resize(n, I::zero());
}
}
/// Merges the sorted runs contained in `records` into a single sorted
/// sequence and returns its positions.
///
/// Run lengths are the differences between consecutive entries of
/// `boundaries` (zero-length runs are dropped). Adjacent runs are merged
/// pairwise, level by level, alternating between the `a_*` and `b_*`
/// buffers until one run is left. With zero or one run no merging happens
/// and the positions are returned in their stored order.
///
/// Consumes the workspace, because the result is moved out of one of its
/// buffers.
fn cascade_merge<S, L>(
mut self,
text: &[S],
lp: &L,
records: &[SaLcp<I>],
boundaries: &[usize],
max_ctx: usize,
dispatch: LcpDispatch,
) -> Vec<I>
where
S: Symbol,
L: LimitProvider,
{
let n = records.len();
if n == 0 {
return Vec::new();
}
self.ensure_capacity(n);
let mut run_lens: Vec<usize> = boundaries
.windows(2)
.filter_map(|w| {
let l = w[1] - w[0];
if l > 0 { Some(l) } else { None }
})
.collect();
// Level 0 lives in the `a_*` buffers.
for (i, r) in records.iter().enumerate() {
self.a_sa[i] = r.pos;
self.a_lcp[i] = r.lcp;
}
let mut src_is_a = true;
while run_lens.len() > 1 {
run_lens = self.merge_one_level(src_is_a, &run_lens, text, lp, max_ctx, dispatch);
// The destination of this level is the source of the next one.
src_is_a = !src_is_a;
}
let mut result = if src_is_a { self.a_sa } else { self.b_sa };
result.truncate(n);
result
}
/// Performs one merge level: runs `2k` and `2k + 1` of the source buffers
/// are merged into the destination buffers with `sample_sort::merge`; a
/// trailing unpaired run is copied over unchanged.
///
/// `src_is_a` selects which buffer pair is the source. Returns the run
/// lengths of the destination.
fn merge_one_level<S, L>(
&mut self,
src_is_a: bool,
run_lens: &[usize],
text: &[S],
lp: &L,
max_ctx: usize,
dispatch: LcpDispatch,
) -> Vec<usize>
where
S: Symbol,
L: LimitProvider,
{
// Destructure so the source pair can be borrowed immutably while the
// destination pair is borrowed mutably.
let Self {
a_sa,
a_lcp,
b_sa,
b_lcp,
} = self;
let (src_sa, src_lcp, dst_sa, dst_lcp) = if src_is_a {
(
a_sa.as_slice(),
a_lcp.as_slice(),
b_sa.as_mut_slice(),
b_lcp.as_mut_slice(),
)
} else {
(
b_sa.as_slice(),
b_lcp.as_slice(),
a_sa.as_mut_slice(),
a_lcp.as_mut_slice(),
)
};
let mut new_lens = Vec::with_capacity(run_lens.len().div_ceil(2));
let mut src_off = 0usize;
let mut dst_off = 0usize;
let mut i = 0;
while i < run_lens.len() {
let l1 = run_lens[i];
if i + 1 < run_lens.len() {
// Merge runs i and i + 1 into one run of length l1 + l2.
let l2 = run_lens[i + 1];
let x_end = src_off + l1;
let xy_end = x_end + l2;
let dst_end = dst_off + l1 + l2;
sample_sort::merge(
text,
lp,
&src_sa[src_off..x_end],
&src_sa[x_end..xy_end],
&src_lcp[src_off..x_end],
&src_lcp[x_end..xy_end],
&mut dst_sa[dst_off..dst_end],
&mut dst_lcp[dst_off..dst_end],
max_ctx,
dispatch,
);
new_lens.push(l1 + l2);
src_off = xy_end;
dst_off = dst_end;
i += 2;
} else {
// Odd run out: carry it to the destination unchanged.
let end = dst_off + l1;
dst_sa[dst_off..end].copy_from_slice(&src_sa[src_off..src_off + l1]);
dst_lcp[dst_off..end].copy_from_slice(&src_lcp[src_off..src_off + l1]);
new_lens.push(l1);
src_off += l1;
dst_off = end;
i += 1;
}
}
new_lens
}
}
// Tests for the sample-sort builders: each one compares the emitted positions
// against `build_in_memory`, a brute-force sort, or another builder.
#[cfg(test)]
mod tests {
use super::*;
use crate::build_in_memory;
use tempfile::tempdir;
// Runs `build_ext_mem` on `text` with `p` subproblems in a fresh temp
// directory and collects the emitted positions.
fn ext_mem_sa(text: &[u8], p: usize) -> Vec<u64> {
let dir = tempdir().unwrap();
let opts = ExtMemOpts {
subproblem_count: p,
work_dir: dir.path().to_path_buf(),
..ExtMemOpts::default()
};
let mut out: Vec<u64> = Vec::with_capacity(text.len());
build_ext_mem(text, &opts, |pos| {
out.push(pos);
Ok(())
})
.unwrap();
out
}
// Asserts that `build_ext_mem` emits exactly what `build_in_memory` returns.
fn assert_matches_in_memory(text: &[u8], p: usize) {
let want: Vec<u32> = build_in_memory(text);
let want64: Vec<u64> = want.iter().map(|&x| x as u64).collect();
let got = ext_mem_sa(text, p);
assert_eq!(got, want64, "mismatch on text {text:?} with p={p}");
}
// Empty input emits nothing.
#[test]
fn ext_mem_empty() {
let got = ext_mem_sa(b"", 4);
assert!(got.is_empty());
}
// A single subproblem means no pivots are selected.
#[test]
fn ext_mem_single_partition() {
assert_matches_in_memory(b"banana", 1);
}
// Requesting more subproblems than positions must still work.
#[test]
fn ext_mem_p_greater_than_n() {
assert_matches_in_memory(b"abc", 10);
}
#[test]
fn ext_mem_banana_p4() {
assert_matches_in_memory(b"banana", 4);
}
#[test]
fn ext_mem_mississippi_p3() {
assert_matches_in_memory(b"mississippi", 3);
}
// Seeded random texts over a 6-symbol alphabet, several sizes and
// subproblem counts.
#[test]
fn ext_mem_random_byte_texts() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xCAFE);
for &n in &[16usize, 100, 1000, 5000] {
for &p in &[1usize, 2, 4, 16] {
let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
assert_matches_in_memory(&text, p);
}
}
}
// Same as above, but the text ends with a symbol (200) that occurs nowhere
// else in it.
#[test]
fn ext_mem_with_unique_terminator() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xF00D);
for &n in &[10usize, 200, 2000] {
for &p in &[1usize, 3, 8] {
let mut text: Vec<u8> = (0..n).map(|_| rng.random_range(0..5u8)).collect();
text.push(200);
assert_matches_in_memory(&text, p);
}
}
}
// Runs `build_ext_mem_for_positions` and collects the emitted positions.
fn ext_mem_for_positions(text: &[u8], positions: Vec<u64>, p: usize) -> Vec<u64> {
let dir = tempdir().unwrap();
let opts = ExtMemOpts {
subproblem_count: p,
work_dir: dir.path().to_path_buf(),
..ExtMemOpts::default()
};
let mut out: Vec<u64> = Vec::with_capacity(positions.len());
build_ext_mem_for_positions(text, positions, &opts, |pos| {
out.push(pos);
Ok(())
})
.unwrap();
out
}
// Passing every position must reproduce the full build.
#[test]
fn ext_mem_for_positions_full_set_matches_ext_mem() {
let text = b"mississippi";
let want = ext_mem_sa(text, 3);
let positions: Vec<u64> = (0..text.len() as u64).collect();
let got = ext_mem_for_positions(text, positions, 3);
assert_eq!(got, want);
}
// Every second position, checked against a brute-force suffix sort.
#[test]
fn ext_mem_for_positions_subset_matches_brute_force() {
let text = b"mississippi";
let positions: Vec<u64> = (0..text.len() as u64).step_by(2).collect();
let mut want = positions.clone();
want.sort_by(|&a, &b| text[a as usize..].cmp(&text[b as usize..]));
let got = ext_mem_for_positions(text, positions, 4);
assert_eq!(got, want);
}
// Random subsets (each position kept with probability 7/10), checked
// against a brute-force suffix sort.
#[test]
fn ext_mem_for_positions_random_subsets() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xC0DE);
for &n in &[50usize, 500, 2000] {
for &p in &[1usize, 3, 8] {
let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..5u8)).collect();
let mut positions: Vec<u64> = (0..n as u64).collect();
positions.retain(|_| rng.random_range(0..10) < 7);
let mut want = positions.clone();
want.sort_by(|&a, &b| text[a as usize..].cmp(&text[b as usize..]));
let got = ext_mem_for_positions(&text, positions, p);
assert_eq!(got, want, "subset ext-mem mismatch n={n} p={p}");
}
}
}
// Runs `build_in_memory_sample_sort` and collects the emitted positions.
fn in_memory_sample_sort(text: &[u8], p: usize) -> Vec<u64> {
let dir = tempdir().unwrap();
let opts = ExtMemOpts {
subproblem_count: p,
work_dir: dir.path().to_path_buf(),
..ExtMemOpts::default()
};
let mut out: Vec<u64> = Vec::with_capacity(text.len());
build_in_memory_sample_sort(text, &opts, |pos| {
out.push(pos);
Ok(())
})
.unwrap();
out
}
// Uses p = 0, i.e. the automatically derived subproblem count.
#[test]
fn in_memory_sample_sort_matches_in_memory() {
for text in [b"banana" as &[u8], b"mississippi", b"abracadabra"] {
let want: Vec<u32> = build_in_memory(text);
let want64: Vec<u64> = want.iter().map(|&x| x as u64).collect();
let got = in_memory_sample_sort(text, 0);
assert_eq!(got, want64, "in-mem sample-sort mismatch on {text:?}");
}
}
#[test]
fn in_memory_sample_sort_random_byte_texts() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xC0DE_C0DE);
for &n in &[16usize, 200, 2000] {
for &p in &[1usize, 4, 16] {
let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
let want: Vec<u32> = build_in_memory(&text);
let want64: Vec<u64> = want.iter().map(|&x| x as u64).collect();
let got = in_memory_sample_sort(&text, p);
assert_eq!(got, want64, "in-mem ss mismatch n={n} p={p}");
}
}
}
// Runs `build_ext_mem_for_filter` and collects the emitted positions.
fn ext_mem_for_filter<Pred>(text: &[u8], keep: Pred, p: usize) -> Vec<u64>
where
Pred: Fn(u64) -> bool + Send + Sync,
{
let dir = tempdir().unwrap();
let opts = ExtMemOpts {
subproblem_count: p,
work_dir: dir.path().to_path_buf(),
..ExtMemOpts::default()
};
let mut out: Vec<u64> = Vec::new();
build_ext_mem_for_filter(text, keep, &opts, |pos| {
out.push(pos);
Ok(())
})
.unwrap();
out
}
// An always-true predicate must reproduce the full build.
#[test]
fn ext_mem_for_filter_matches_for_positions_on_full_set() {
let text = b"mississippi";
let want = ext_mem_sa(text, 3);
let got = ext_mem_for_filter(text, |_p| true, 3);
assert_eq!(got, want);
}
// The filter API and the explicit-positions API must agree when given the
// same selection (symbols < 4 out of a 6-symbol alphabet).
#[test]
fn ext_mem_for_filter_matches_for_positions_on_dna_subset() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xCA75_5A);
for &n in &[50usize, 500, 2000] {
for &p in &[1usize, 3, 8] {
let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
let positions: Vec<u64> = (0..n as u64).filter(|&i| text[i as usize] < 4).collect();
let want = ext_mem_for_positions(&text, positions, p);
let got = ext_mem_for_filter(&text, |i| text[i as usize] < 4, p);
assert_eq!(got, want, "filter vs positions mismatch n={n} p={p}");
}
}
}
// 200_000 positions span more than one cumsum block of the filter bitmap
// (one block covers 1024 * 64 = 65_536 positions).
#[test]
fn ext_mem_for_filter_handles_block_aligned_boundaries() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0xB10C_C0DE);
let n = 200_000usize;
let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
let positions: Vec<u64> = (0..n as u64).filter(|&i| text[i as usize] < 4).collect();
let want = ext_mem_for_positions(&text, positions, 8);
let got = ext_mem_for_filter(&text, |i| text[i as usize] < 4, 8);
assert_eq!(got, want, "filter API mismatch across block boundaries");
}
// Only symbol 0 out of 20 is kept, so the bitmap is sparse.
#[test]
fn ext_mem_for_filter_sparse_predicate() {
use rand::{RngExt, SeedableRng};
let mut rng = rand::rngs::StdRng::seed_from_u64(0x5_AA_55);
let n = 50_000usize;
let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..20u8)).collect();
let positions: Vec<u64> = (0..n as u64).filter(|&i| text[i as usize] < 1).collect();
let want = ext_mem_for_positions(&text, positions, 4);
let got = ext_mem_for_filter(&text, |i| text[i as usize] < 1, 4);
assert_eq!(got, want, "filter API mismatch on sparse predicate");
}
// Highly repetitive input (a 28-byte unit repeated 100 times, plus a
// terminator): checks correctness and that the build finishes in under two
// seconds of wall time.
// NOTE(review): the wall-clock bound may be flaky on a heavily loaded
// machine.
#[test]
fn ext_mem_repetitive_does_not_blow_up() {
use std::time::Instant;
let unit = b"ACGTACGTACGTACGTACGTACGTACGT"; let mut text: Vec<u8> = Vec::new();
for _ in 0..100 {
text.extend_from_slice(unit);
}
text.push(200);
let start = Instant::now();
let got = ext_mem_sa(&text, 8);
let elapsed = start.elapsed();
let want: Vec<u32> = build_in_memory(&text);
let want64: Vec<u64> = want.iter().map(|&x| x as u64).collect();
assert_eq!(got, want64);
assert!(
elapsed.as_secs() < 2,
"ext-mem build on a tiny repetitive text took {elapsed:?}"
);
}
}