Skip to main content

caps_sa/
ext_mem.rs

1//! External-memory suffix array construction.
2//!
3//! Implements upstream CaPS-SA's *sample-sort + LCP-enhanced merge*
4//! external-memory algorithm:
5//!
6//! 1. **Sort + sample + spill.** Split positions `0..n` into `p` subarrays
7//!    of size `n / p`. In parallel, sort each with the in-memory
8//!    merge-sort kernel (`sample_sort::merge_sort`), sample ~`c · ln n`
9//!    positions uniformly from each sorted subarray, and spill the sorted
10//!    `(position, lcp)` records to a per-subarray
11//!    [`ExtMemBucket`][crate::ext_bucket::ExtMemBucket].
12//! 2. **Select pivots.** Globally sort the pooled samples and pick
13//!    `p − 1` pivots at evenly-spaced ranks, splitting the suffix order
14//!    into `p` partitions.
15//! 3. **Distribute.** For each sorted subarray, binary-search the pivots
16//!    to find its `p` sub-subarray split points; append each sub-subarray
17//!    to the corresponding *partition* bucket, marking a boundary after
18//!    each contribution.
19//! 4. **Per-partition merge.** Load each partition's bucket into RAM
20//!    (≈ `n / p` records); cascade 2-way LCP-enhanced merges across the
21//!    `p` sub-subarrays to produce that partition's globally-sorted slice.
22//! 5. **Stream output.** Iterate partitions in order and emit each
23//!    position via the caller's closure.
24//!
//! Peak RAM ≈ `text` + `O(n / p)` per active worker (one partition's
//! merge working set). By default `p` is `ceil(n / 65_536)`, raised to at
//! least `rayon::current_num_threads()` and capped at 8192; with
//! `num_threads = 8` that's a few hundred MB on a `n = 6e9` (human-scale)
//! input — well below in-memory's `~4 × n × 8 = ~200 GB`. The full SA is
//! never materialized in RAM.
30
31use std::cmp::Ordering;
32use std::io;
33use std::path::{Path, PathBuf};
34use std::sync::Mutex;
35use std::time::Instant;
36
37use rayon::prelude::*;
38
39use crate::Index;
40use crate::ext_bucket::{BucketPool, BucketRecord, BucketStore, InMemBucket, SaLcp};
41use crate::lcp::{LcpDispatch, Symbol};
42use crate::limits::{LimitProvider, PlainText};
43use crate::sample_sort;
44
/// Write one phase-timing line to stderr, but only when the
/// `CAPS_SA_PROFILE` environment variable is present (any value —
/// even an empty one — enables it).
///
/// The environment is consulted on every call; when the variable is
/// absent nothing is printed. This localises where the ext-mem path
/// spends its time without always logging — see `bench/README.md`
/// "Where AVX-512 helps and where it doesn't" for how the output is
/// used.
fn profile_log(message: &str) {
    let profiling_enabled = std::env::var_os("CAPS_SA_PROFILE").is_some();
    if !profiling_enabled {
        return;
    }
    eprintln!("caps-sa profile  {message}");
}
55
/// Tunable options for [`build_ext_mem`].
///
/// A zero in either count field means "pick automatically"; each
/// field documents what the automatic choice is.
#[derive(Clone, Debug)]
pub struct ExtMemOpts {
    /// Bound on LCP-extension comparisons inside merges. `usize::MAX`
    /// (default) is unbounded.
    pub max_context: usize,
    /// Number of subarrays (`p` in upstream CaPS-SA). `0` (default)
    /// auto-picks `ceil(n / 65_536)`, raised to at least
    /// `rayon::current_num_threads()` and capped at `8192`. A non-zero
    /// value is taken as given. Either way the result is finally
    /// clamped to `[1, n]`.
    pub subproblem_count: usize,
    /// Directory for temp files. Defaults to [`std::env::temp_dir`].
    pub work_dir: PathBuf,
    /// Number of physical files in each bucket pool (one pool for the
    /// phase-1 subarray buckets and a second for the phase-3 partition
    /// buckets). `0` (default) picks `rayon::current_num_threads()` —
    /// the right answer in practice: one writable inode per worker
    /// keeps kernel-level write contention bounded.
    ///
    /// The `2 × p` logical buckets (typically thousands at genome
    /// scale) collapse onto this pool of anonymous tempfiles via
    /// `bucket_id % physical_file_count`. Larger values lower kernel
    /// write contention; smaller values are kinder to networked
    /// filesystems with high metadata cost.
    ///
    /// The `CAPS_SA_N_PHYS` env var, when it parses to an integer
    /// `>= 1`, takes precedence over this field (including an explicit
    /// non-zero value); it is meant for one-off benches.
    pub physical_file_count: usize,
}
81
82impl Default for ExtMemOpts {
83    fn default() -> Self {
84        Self {
85            max_context: usize::MAX,
86            subproblem_count: 0,
87            work_dir: std::env::temp_dir(),
88            physical_file_count: 0,
89        }
90    }
91}
92
93impl ExtMemOpts {
94    /// Convenience constructor with the supplied `work_dir` and defaults
95    /// for everything else.
96    pub fn with_work_dir(work_dir: impl AsRef<Path>) -> Self {
97        Self {
98            work_dir: work_dir.as_ref().to_path_buf(),
99            ..Self::default()
100        }
101    }
102}
103
104/// Build the suffix array of `text` with bounded RAM, streaming each
105/// output position to `emit` in lexicographic order.
106///
107/// Returns an [`io::Error`] if temp-file I/O fails. The callback may also
108/// return an error to abort construction; partial work is discarded and
109/// temp files are cleaned up when their bucket drops.
110///
111/// Equivalent in semantics to [`crate::build_in_memory`]: produces a
112/// standard lexicographic suffix array with the "shorter suffix is
113/// smaller when one runs off the end" tie-break.
114pub fn build_ext_mem<S, F>(text: &[S], opts: &ExtMemOpts, emit: F) -> io::Result<()>
115where
116    S: Symbol,
117    F: FnMut(u64) -> io::Result<()>,
118{
119    build_ext_mem_with(text, &PlainText::new(text.len()), opts, emit)
120}
121
122/// Variant of [`build_ext_mem`] that accepts a [`LimitProvider`]. With
123/// [`PlainText`] this matches [`build_ext_mem`] exactly (and
124/// monomorphizes to identical assembly); with
125/// [`SegmentedText`][crate::limits::SegmentedText] the LCP scans stop
126/// at segment boundaries.
127pub fn build_ext_mem_with<S, L, F>(text: &[S], lp: &L, opts: &ExtMemOpts, emit: F) -> io::Result<()>
128where
129    S: Symbol,
130    L: LimitProvider,
131    F: FnMut(u64) -> io::Result<()>,
132{
133    // Dispatch on text size: when every suffix position fits in `u32`
134    // (n ≤ 2^32), use the narrow record type to halve all the SaLcp
135    // bytes — bucket disk I/O, phase-1 records, and the per-partition
136    // load in phase 4.
137    if text.len() <= u32::MAX as usize + 1 {
138        build_ext_mem_inner::<S, u32, L, F>(
139            text,
140            PositionSource::Identity(text.len()),
141            lp,
142            opts,
143            emit,
144        )
145    } else {
146        build_ext_mem_inner::<S, u64, L, F>(
147            text,
148            PositionSource::Identity(text.len()),
149            lp,
150            opts,
151            emit,
152        )
153    }
154}
155
156/// Like [`build_ext_mem`] but sorts only the caller-supplied
157/// `positions` by the lexicographic order of their suffixes in `text`.
158/// Suffix content is always `text[position..]`; no positions are
159/// dropped from the input. To filter, the caller constructs
160/// `positions` with only the indices they want.
161///
162/// This lets STAR-style genome indexing skip the bin-padding
163/// pathology: pass only the ACGT-starting positions and the
164/// spacer-suffix work disappears.
165///
166/// `positions` does not need to be pre-sorted in any way.
167pub fn build_ext_mem_for_positions<S, F>(
168    text: &[S],
169    positions: Vec<u64>,
170    opts: &ExtMemOpts,
171    emit: F,
172) -> io::Result<()>
173where
174    S: Symbol,
175    F: FnMut(u64) -> io::Result<()>,
176{
177    build_ext_mem_for_positions_with(text, positions, &PlainText::new(text.len()), opts, emit)
178}
179
180/// Variant of [`build_ext_mem_for_positions`] that accepts a
181/// [`LimitProvider`]. See [`build_ext_mem_with`] for the semantics.
182pub fn build_ext_mem_for_positions_with<S, L, F>(
183    text: &[S],
184    positions: Vec<u64>,
185    lp: &L,
186    opts: &ExtMemOpts,
187    emit: F,
188) -> io::Result<()>
189where
190    S: Symbol,
191    L: LimitProvider,
192    F: FnMut(u64) -> io::Result<()>,
193{
194    // We hold a reference to `positions` for the duration of the build;
195    // `phase1_sort_sample_spill` copies each chunk out. The Vec is
196    // dropped once phase 1 returns.
197    if text.len() <= u32::MAX as usize + 1 {
198        build_ext_mem_inner::<S, u32, L, F>(
199            text,
200            PositionSource::Subset(&positions),
201            lp,
202            opts,
203            emit,
204        )
205    } else {
206        build_ext_mem_inner::<S, u64, L, F>(
207            text,
208            PositionSource::Subset(&positions),
209            lp,
210            opts,
211            emit,
212        )
213    }
214}
215
216/// Like [`build_ext_mem_for_positions`] but takes a **predicate** over
217/// text positions instead of a pre-materialised `Vec<u64>` of kept
218/// positions.
219///
220/// caps-sa walks the predicate **once** to build a bitmap of kept
221/// positions + a tiny per-block popcount prefix-sum (together ~`n / 8`
222/// bytes — ~770 MB on the human genome, vs the ~50 GB the equivalent
223/// `Vec<u64>` would take). Phase 1's per-subarray fill is then driven
224/// by popcount-walking the bitmap; the predicate is **never invoked
225/// again** after the initial build. See [`FilteredSource`] for the
226/// memory accounting and the inner loop.
227///
228/// Use this entry when the caller already has the text in RAM and
229/// the kept positions are described by a cheap per-position
230/// predicate (e.g. STAR's `text[p] < 4` for ACGT-only suffix
231/// sampling). It is **the right entry for genome-scale inputs** —
232/// the `Vec<u64>` path can dominate peak RSS otherwise.
233///
234/// `keep` is invoked from rayon worker threads in parallel during
235/// the bitmap build; it must be `Send + Sync` (typically a plain
236/// closure capturing only `&[u8]` references is fine).
237pub fn build_ext_mem_for_filter<S, F, Pred>(
238    text: &[S],
239    keep: Pred,
240    opts: &ExtMemOpts,
241    emit: F,
242) -> io::Result<()>
243where
244    S: Symbol,
245    F: FnMut(u64) -> io::Result<()>,
246    Pred: Fn(u64) -> bool + Send + Sync,
247{
248    build_ext_mem_for_filter_with(text, keep, &PlainText::new(text.len()), opts, emit)
249}
250
251/// Variant of [`build_ext_mem_for_filter`] that accepts a
252/// [`LimitProvider`]. See [`build_ext_mem_with`] for the semantics.
253pub fn build_ext_mem_for_filter_with<S, L, F, Pred>(
254    text: &[S],
255    keep: Pred,
256    lp: &L,
257    opts: &ExtMemOpts,
258    emit: F,
259) -> io::Result<()>
260where
261    S: Symbol,
262    L: LimitProvider,
263    F: FnMut(u64) -> io::Result<()>,
264    Pred: Fn(u64) -> bool + Send + Sync,
265{
266    let filtered = FilteredSource::new(text.len(), keep);
267    if text.len() <= u32::MAX as usize + 1 {
268        build_ext_mem_inner::<S, u32, L, F>(
269            text,
270            PositionSource::Filtered(filtered),
271            lp,
272            opts,
273            emit,
274        )
275    } else {
276        build_ext_mem_inner::<S, u64, L, F>(
277            text,
278            PositionSource::Filtered(filtered),
279            lp,
280            opts,
281            emit,
282        )
283    }
284}
285
286fn build_ext_mem_inner<S, I, L, F>(
287    text: &[S],
288    source: PositionSource<'_>,
289    lp: &L,
290    opts: &ExtMemOpts,
291    mut emit: F,
292) -> io::Result<()>
293where
294    S: Symbol,
295    I: Index,
296    L: LimitProvider,
297    SaLcp<I>: BucketRecord,
298    F: FnMut(u64) -> io::Result<()>,
299{
300    let n = source.len();
301    if n == 0 {
302        return Ok(());
303    }
304    let p = effective_subproblem_count(n, opts.subproblem_count);
305    let dispatch = LcpDispatch::detect();
306    let work_dir = opts.work_dir.clone();
307
308    // Pool the `2 × p` bucket files into one anonymous tempfile per
309    // worker thread. With `p` in the thousands and `num_threads` in
310    // the dozens this collapses the openat/close/unlink budget from
311    // ~3·p (per-bucket-file path) to N (per-worker-file pool),
312    // eliminating the metadata-syscall pain on networked filesystems
313    // and the open-file-handle limit headache on tiny inputs. Per-
314    // bucket in-memory buffers and write volumes are unchanged, so
315    // local-disk wall time is neutral or marginally improved. See
316    // `bench/README.md` for the empirical sizing.
317    let n_phys = effective_physical_file_count(opts.physical_file_count);
318    let phase1_pool = BucketPool::new(n_phys, &work_dir)?;
319    let phase3_pool = BucketPool::new(n_phys, &work_dir)?;
320
321    profile_log(&format!(
322        "build_ext_mem n={n} p={p} index_width={}b n_phys={n_phys}",
323        std::mem::size_of::<I>() * 8
324    ));
325
326    let sub_factory = |i: usize| phase1_pool.new_bucket::<SaLcp<I>>(i);
327    let part_factory = |j: usize| phase3_pool.new_bucket::<SaLcp<I>>(j);
328
329    let t = Instant::now();
330    let (mut subarray_buckets, samples) = phase1_sort_sample_spill::<S, I, L, _, _>(
331        text,
332        lp,
333        &source,
334        p,
335        opts,
336        dispatch,
337        sub_factory,
338    )?;
339    profile_log(&format!(
340        "phase1 (sort+sample+spill) {:.3}s",
341        t.elapsed().as_secs_f64()
342    ));
343
344    // Drop the position source as soon as phase 1 returns — phases
345    // 2/3/4 don't touch it. For `PositionSource::Subset` this frees
346    // the caller's `Vec<u64>` (e.g. ~47 GB on a human-scale
347    // _for_positions build); for `PositionSource::Filtered` it
348    // frees the bitmap + cumsum (~770 MB); for `Identity` it's a
349    // no-op. The text and the spilled `subarray_buckets` are all
350    // phase 2+ needs.
351    drop(source);
352
353    let t = Instant::now();
354    let pivots = phase2_select_pivots::<S, I, L>(text, lp, samples, p, opts.max_context, dispatch);
355    profile_log(&format!(
356        "phase2 (select pivots)      {:.3}s",
357        t.elapsed().as_secs_f64()
358    ));
359
360    let t = Instant::now();
361    let mut partition_buckets = phase3_distribute::<S, I, L, _, _>(
362        text,
363        lp,
364        &mut subarray_buckets,
365        &pivots,
366        p,
367        opts,
368        dispatch,
369        part_factory,
370    )?;
371    profile_log(&format!(
372        "phase3 (distribute)          {:.3}s",
373        t.elapsed().as_secs_f64()
374    ));
375
376    drop(subarray_buckets);
377
378    let t = Instant::now();
379    let result = phase4_merge_and_emit::<S, I, L, _, F>(
380        text,
381        lp,
382        &mut partition_buckets,
383        opts.max_context,
384        &mut emit,
385        dispatch,
386    );
387    profile_log(&format!(
388        "phase4 (merge+emit)          {:.3}s",
389        t.elapsed().as_secs_f64()
390    ));
391    result
392}
393
394/// Same algorithm as [`build_ext_mem_inner`] but with the disk-backed
395/// [`ExtMemBucket`] replaced by [`InMemBucket`] throughout — phase 1
396/// sorts each subarray and keeps the result in a `Vec<SaLcp<I>>`,
397/// phase 3 distributes into in-RAM partition Vecs, phase 4 cascade-
398/// merges the in-RAM partitions. No disk I/O.
399///
400/// Trades RAM for wall time: peak memory is ~`n × sizeof(SaLcp<I>)`
401/// (the post-phase-1 records sitting around until phase 3 consumes
402/// them), so ~25 GB on the human genome with `I = u32`. In exchange,
403/// the disk-spill / distribute-write / partition-load round-trip is
404/// gone — useful on machines with enough RAM to hold the working set.
405fn build_in_memory_ss_inner<S, I, L, F>(
406    text: &[S],
407    source: PositionSource<'_>,
408    lp: &L,
409    opts: &ExtMemOpts,
410    mut emit: F,
411) -> io::Result<()>
412where
413    S: Symbol,
414    I: Index,
415    L: LimitProvider,
416    SaLcp<I>: BucketRecord,
417    F: FnMut(u64) -> io::Result<()>,
418{
419    let n = source.len();
420    if n == 0 {
421        return Ok(());
422    }
423    let p = effective_subproblem_count(n, opts.subproblem_count);
424    let dispatch = LcpDispatch::detect();
425
426    let factory = |_i: usize| InMemBucket::<SaLcp<I>>::new();
427
428    let (mut subarray_buckets, samples) =
429        phase1_sort_sample_spill::<S, I, L, _, _>(text, lp, &source, p, opts, dispatch, factory)?;
430    // Same rationale as in `build_ext_mem_inner` — drop the source
431    // as soon as phase 1's `fill_chunk` calls have stopped.
432    drop(source);
433    let pivots = phase2_select_pivots::<S, I, L>(text, lp, samples, p, opts.max_context, dispatch);
434    let mut partition_buckets = phase3_distribute::<S, I, L, _, _>(
435        text,
436        lp,
437        &mut subarray_buckets,
438        &pivots,
439        p,
440        opts,
441        dispatch,
442        factory,
443    )?;
444    drop(subarray_buckets);
445    phase4_merge_and_emit::<S, I, L, _, F>(
446        text,
447        lp,
448        &mut partition_buckets,
449        opts.max_context,
450        &mut emit,
451        dispatch,
452    )
453}
454
455/// In-memory variant of the sample-sort algorithm used by
456/// [`build_ext_mem`]. Skips all disk I/O at the cost of holding the
457/// (`pos`, `lcp`) records in RAM throughout. Picks `u32` records when
458/// `n ≤ 2³²`, falls back to `u64` otherwise. The caller's `emit`
459/// closure is called once per output position in lex order, just like
460/// in the ext-mem path.
461pub fn build_in_memory_sample_sort<S, F>(text: &[S], opts: &ExtMemOpts, emit: F) -> io::Result<()>
462where
463    S: Symbol,
464    F: FnMut(u64) -> io::Result<()>,
465{
466    build_in_memory_sample_sort_with(text, &PlainText::new(text.len()), opts, emit)
467}
468
469/// Variant of [`build_in_memory_sample_sort`] that accepts a
470/// [`LimitProvider`].
471pub fn build_in_memory_sample_sort_with<S, L, F>(
472    text: &[S],
473    lp: &L,
474    opts: &ExtMemOpts,
475    emit: F,
476) -> io::Result<()>
477where
478    S: Symbol,
479    L: LimitProvider,
480    F: FnMut(u64) -> io::Result<()>,
481{
482    if text.len() <= u32::MAX as usize + 1 {
483        build_in_memory_ss_inner::<S, u32, L, F>(
484            text,
485            PositionSource::Identity(text.len()),
486            lp,
487            opts,
488            emit,
489        )
490    } else {
491        build_in_memory_ss_inner::<S, u64, L, F>(
492            text,
493            PositionSource::Identity(text.len()),
494            lp,
495            opts,
496            emit,
497        )
498    }
499}
500
501/// Subset-positions variant of [`build_in_memory_sample_sort`]. Same
502/// shape as [`build_ext_mem_for_positions`].
503pub fn build_in_memory_sample_sort_for_positions<S, F>(
504    text: &[S],
505    positions: Vec<u64>,
506    opts: &ExtMemOpts,
507    emit: F,
508) -> io::Result<()>
509where
510    S: Symbol,
511    F: FnMut(u64) -> io::Result<()>,
512{
513    build_in_memory_sample_sort_for_positions_with(
514        text,
515        positions,
516        &PlainText::new(text.len()),
517        opts,
518        emit,
519    )
520}
521
522/// Variant of [`build_in_memory_sample_sort_for_positions`] that
523/// accepts a [`LimitProvider`].
524pub fn build_in_memory_sample_sort_for_positions_with<S, L, F>(
525    text: &[S],
526    positions: Vec<u64>,
527    lp: &L,
528    opts: &ExtMemOpts,
529    emit: F,
530) -> io::Result<()>
531where
532    S: Symbol,
533    L: LimitProvider,
534    F: FnMut(u64) -> io::Result<()>,
535{
536    if text.len() <= u32::MAX as usize + 1 {
537        build_in_memory_ss_inner::<S, u32, L, F>(
538            text,
539            PositionSource::Subset(&positions),
540            lp,
541            opts,
542            emit,
543        )
544    } else {
545        build_in_memory_ss_inner::<S, u64, L, F>(
546            text,
547            PositionSource::Subset(&positions),
548            lp,
549            opts,
550            emit,
551        )
552    }
553}
554
/// Source of the positions to sort.
///
/// - [`PositionSource::Identity`] is the all-suffixes case `0..n`;
///   avoids materialising a `Vec<u64>` of length `n`, which on the
///   human genome would itself be ~50 GB.
/// - [`PositionSource::Subset`] holds a caller-supplied `&[u64]` of
///   the positions to sort, in any order. Random-access by index.
///   The slice is only borrowed, so dropping the `PositionSource`
///   does not free the caller's buffer.
/// - [`PositionSource::Filtered`] owns a [`FilteredSource`]: a bitmap
///   of kept text positions plus a per-block popcount prefix-sum,
///   both built up-front from the caller's `Fn(u64) -> bool`
///   predicate. `fill_chunk` locates the i-th kept position by
///   searching the prefix-sum and walking bitmap words; the predicate
///   itself is not stored. **No kept-positions list is
///   materialised** — the bitmap costs about `n / 8` bytes where
///   `Subset` needs `8 × n_kept`, which dominates on genome-scale
///   inputs.
enum PositionSource<'a> {
    Identity(usize),
    Subset(&'a [u64]),
    Filtered(FilteredSource),
}
576
/// Block size in **u64 words** for [`FilteredSource`]'s popcount
/// prefix-sum. One block is 1024 words × 64 bits = 65,536 text
/// positions (8 KB of bitmap), and the prefix-sum stores one `u64`
/// per block — ~760 KB for the human genome. A `fill_chunk` lookup
/// lands on the block holding its first kept position, so it skips
/// through at most one block of words before it starts emitting.
const FILTERED_WORDS_PER_BLOCK: usize = 1024;
583
/// Streaming position source backed by a **bitmap** of kept positions
/// + a per-block popcount prefix-sum. No `Vec<u64>` of kept positions
/// is ever materialised.
///
/// Memory: `(n + 7) / 8` bytes for the bitmap (~770 MB on the human
/// genome at `n ≈ 6.2 B`) + `8 × n_blocks` for the prefix sum
/// (~760 KB). The 47 GB `Vec<u64>` the [`Subset`][PositionSource::Subset]
/// variant requires goes away entirely.
///
/// Lookup path inside `fill_chunk`:
/// 1. `partition_point` the prefix-sum to find the block containing
///    the `start`-th kept position (`O(log n_blocks)` ≈ 17 ops).
/// 2. Walk that block's bitmap words `popcount`-by-`popcount` until
///    we've skipped the right number of set bits to reach `start`.
/// 3. Walk forward through bitmap words using `trailing_zeros`-style
///    iteration, emitting each set bit's position to `dst`. The loop
///    does one step per emitted position plus one per bitmap word
///    scanned — no per-text-position closure calls.
///
/// A truly `O(1)` select-1 structure (darray / Elias-Fano) on top
/// of the bitmap would shrink step (1)+(2) further; with
/// `chunk_size` ≈ 750 K and only `p` ≈ 8192 fill_chunk calls per
/// build, the `O(log n_blocks)` + at-most-one-block-walk cost is
/// already a rounding error. See `bench/README.md` for the
/// follow-up note if that ever changes.
struct FilteredSource {
    /// Length of the text the bitmap describes. After construction it
    /// is only consulted by a debug assertion in `fill_chunk`.
    text_len: usize,
    /// Total number of set bits in `bitmap`, i.e. the number of kept
    /// positions.
    total_kept: usize,
    /// Bitmap: `(bitmap[w] >> b) & 1 == 1` iff position `64 * w + b`
    /// is kept. Length = `text_len.div_ceil(64)`.
    bitmap: Vec<u64>,
    /// `cumsum[i]` = number of set bits in
    /// `bitmap[0 .. i * FILTERED_WORDS_PER_BLOCK]`. Length =
    /// `n_blocks + 1`; the first entry is `0` and the last entry
    /// equals `total_kept`.
    cumsum: Vec<u64>,
}
620
621impl FilteredSource {
622    /// Build a [`FilteredSource`] by walking the predicate once over
623    /// `0..text_len` to fill the bitmap, then accumulating per-block
624    /// popcounts.
625    ///
626    /// The predicate is invoked exactly `text_len` times here; once
627    /// the bitmap is built, `fill_chunk` never calls it again. This
628    /// trades one full predicate pass for zero per-position calls
629    /// during all subsequent random-access `fill_chunk`s — a clear
630    /// win when (as in caps-sa's phase 1) every position is read at
631    /// least once.
632    fn new<Pred>(text_len: usize, keep: Pred) -> Self
633    where
634        Pred: Fn(u64) -> bool + Send + Sync,
635    {
636        let n_words = text_len.div_ceil(64);
637        // Parallel per-word bitmap build. Each word reads 64 text
638        // positions (clamped at `text_len`), packs them into a u64.
639        let bitmap: Vec<u64> = (0..n_words)
640            .into_par_iter()
641            .map(|w| {
642                let mut word: u64 = 0;
643                let base = (w as u64) * 64;
644                let limit = ((w + 1) * 64).min(text_len) - w * 64;
645                for b in 0..limit {
646                    if keep(base + b as u64) {
647                        word |= 1u64 << b;
648                    }
649                }
650                word
651            })
652            .collect();
653
654        // Per-block popcount cumsum. Each block covers
655        // `FILTERED_WORDS_PER_BLOCK` words = `FILTERED_BITS_PER_BLOCK`
656        // text positions.
657        let n_blocks = n_words.div_ceil(FILTERED_WORDS_PER_BLOCK);
658        let per_block: Vec<u64> = (0..n_blocks)
659            .into_par_iter()
660            .map(|i| {
661                let start = i * FILTERED_WORDS_PER_BLOCK;
662                let end = ((i + 1) * FILTERED_WORDS_PER_BLOCK).min(n_words);
663                let mut c: u64 = 0;
664                for &word in &bitmap[start..end] {
665                    c += word.count_ones() as u64;
666                }
667                c
668            })
669            .collect();
670        let mut cumsum = Vec::with_capacity(n_blocks + 1);
671        let mut s: u64 = 0;
672        cumsum.push(0);
673        for &k in &per_block {
674            s += k;
675            cumsum.push(s);
676        }
677        let total_kept = s as usize;
678        Self {
679            text_len,
680            total_kept,
681            bitmap,
682            cumsum,
683        }
684    }
685
686    /// Number of kept positions.
687    #[inline]
688    fn len(&self) -> usize {
689        self.total_kept
690    }
691
692    /// Fill `dst` with the next `dst.len()` kept positions starting
693    /// from the `start`-th (0-based) kept position. See type-level
694    /// doc for the algorithm; this is the hot path during phase 1
695    /// fill_chunk.
696    fn fill_chunk<I: Index>(&self, start: usize, dst: &mut [I]) {
697        debug_assert!(start + dst.len() <= self.total_kept);
698        if dst.is_empty() {
699            return;
700        }
701
702        // (1) Locate the block containing the `start`-th set bit.
703        // `partition_point(|c| c <= start)` gives the first cumsum
704        // entry strictly greater than `start`; previous index is the
705        // containing block.
706        let pp = self.cumsum.partition_point(|&c| c <= start as u64);
707        debug_assert!(pp > 0);
708        let block_idx = pp - 1;
709        let mut word_idx = block_idx * FILTERED_WORDS_PER_BLOCK;
710        let mut skip = start as u64 - self.cumsum[block_idx];
711
712        // (2) Skip the first `skip` set bits — possibly spanning
713        // several bitmap words. Whole words with `popcount ≤ skip`
714        // are consumed wholesale; the final partial word has its
715        // lowest `skip` set bits cleared so the emit loop sees only
716        // un-skipped 1s.
717        //
718        // Note: a naive `while skip >= 64 { … }` is wrong because a
719        // word's popcount can be far less than 64; we must subtract
720        // the actual popcount each iteration, not 64. This matters
721        // any time the bitmap is sparser than ~50%.
722        let n_words = self.bitmap.len();
723        let mut word: u64 = if word_idx < n_words { self.bitmap[word_idx] } else { 0 };
724        while skip > 0 {
725            let pc = word.count_ones() as u64;
726            if skip < pc {
727                // Consume `skip` lowest set bits inside the current
728                // word; emit loop continues from the remaining ones.
729                for _ in 0..skip {
730                    word &= word - 1;
731                }
732                break;
733            }
734            // Skip ≥ pc: consume the whole word and advance.
735            skip -= pc;
736            word_idx += 1;
737            word = if word_idx < n_words { self.bitmap[word_idx] } else { 0 };
738        }
739
740        // (3) Walk `word`+subsequent words, emitting one position per
741        // set bit. Uses `trailing_zeros` to jump straight to the next
742        // 1 inside a word, then clears it via `word &= word - 1`.
743        let mut written = 0usize;
744        let need = dst.len();
745        loop {
746            while word != 0 && written < need {
747                let bit = word.trailing_zeros() as u64;
748                let pos = (word_idx as u64) * 64 + bit;
749                debug_assert!((pos as usize) < self.text_len);
750                dst[written] = I::from_usize(pos as usize);
751                written += 1;
752                word &= word - 1;
753            }
754            if written == need {
755                break;
756            }
757            word_idx += 1;
758            debug_assert!(
759                word_idx < n_words,
760                "FilteredSource::fill_chunk: walked past bitmap end \
761                 ({written}/{need} emitted, word_idx={word_idx}, n_words={n_words})"
762            );
763            word = self.bitmap[word_idx];
764        }
765    }
766}
767
768impl<'a> PositionSource<'a> {
769    fn len(&self) -> usize {
770        match self {
771            Self::Identity(n) => *n,
772            Self::Subset(p) => p.len(),
773            Self::Filtered(f) => f.len(),
774        }
775    }
776
777    /// Fill `dst` with positions for the half-open subarray range
778    /// `[start, start + dst.len())`, narrowing the caller's `u64`
779    /// positions into `I` via [`Index::from_usize`]. For
780    /// [`PositionSource::Identity`] this generates the contiguous
781    /// integer range on the fly; for [`PositionSource::Subset`] it
782    /// reads from the caller's slice; for [`PositionSource::Filtered`]
783    /// it walks the predicate forward from the right text block.
784    fn fill_chunk<I: Index>(&self, start: usize, dst: &mut [I]) {
785        match self {
786            Self::Identity(_) => {
787                for (i, slot) in dst.iter_mut().enumerate() {
788                    *slot = I::from_usize(start + i);
789                }
790            }
791            Self::Subset(p) => {
792                let end = start + dst.len();
793                for (slot, &v) in dst.iter_mut().zip(p[start..end].iter()) {
794                    *slot = I::from_usize(v as usize);
795                }
796            }
797            Self::Filtered(f) => f.fill_chunk(start, dst),
798        }
799    }
800}
801
/// Target subarray size (in positions) used by
/// [`effective_subproblem_count`] when auto-picking `p`: the automatic
/// choice starts from `ceil(n / PHASE1_TARGET_CHUNK)`. Smaller means
/// more (smaller) subarrays — lower per-task phase-1 scratch, at the
/// cost of more phase-3 distribute work (which scales as
/// `O(p² · log(n/p))`, sequentially) and a higher temp-file count.
///
/// NOTE(review): "sequentially" here disagrees with the
/// [`PHASE1_MAX_PARTITIONS`] doc, which says phase 3 is now
/// parallelised. Phase 3 is not visible from this part of the file —
/// confirm which statement is current and reconcile the two.
const PHASE1_TARGET_CHUNK: usize = 65_536;
/// Hard cap on the number of subarrays when `p` is auto-picked.
/// Matches upstream CaPS-SA's default of 8192 — phase 3 is now
/// parallelised across rayon workers (each subarray distributes
/// independently into per-partition `Mutex<ExtMemBucket>` slots), so
/// the `O(p²)` sequential distribute of the original design no longer
/// constrains us. The cap is still finite to keep the temp-file count
/// bounded.
///
/// The cap is applied only on the automatic path of
/// [`effective_subproblem_count`]; an explicitly requested count is
/// not limited by it.
const PHASE1_MAX_PARTITIONS: usize = 8192;
815
816/// Resolve [`ExtMemOpts::physical_file_count`] for the current build.
817/// `0` (the default) means "let the runtime decide"; we pick
818/// `rayon::current_num_threads()` so the pool has one inode per
819/// concurrent writer, which empirically matches per-bucket-file wall
820/// time while collapsing thousands of small files into dozens of
821/// large ones. The `CAPS_SA_N_PHYS` env var overrides at the call
822/// site for benchmarks.
823fn effective_physical_file_count(requested: usize) -> usize {
824    if let Some(v) = std::env::var("CAPS_SA_N_PHYS")
825        .ok()
826        .and_then(|s| s.parse::<usize>().ok())
827        .filter(|&v| v >= 1)
828    {
829        return v;
830    }
831    if requested >= 1 {
832        return requested;
833    }
834    rayon::current_num_threads().max(1)
835}
836
837fn effective_subproblem_count(n: usize, requested: usize) -> usize {
838    if n == 0 {
839        return 0;
840    }
841    let raw = if requested == 0 {
842        let nthreads = rayon::current_num_threads().max(1);
843        let p_from_size = n.div_ceil(PHASE1_TARGET_CHUNK);
844        // At least one chunk per thread (otherwise we leave cores idle),
845        // at most `PHASE1_MAX_PARTITIONS` (so phase 3's sequential
846        // `O(p²)` sweep and the temp-file count stay manageable). For
847        // small inputs `p_from_size` is well below the cap, so the
848        // formula degrades gracefully to roughly "one chunk per thread";
849        // for human-scale inputs the cap binds and per-task scratch
850        // stays in the tens-of-MB range.
851        p_from_size.clamp(nthreads, PHASE1_MAX_PARTITIONS)
852    } else {
853        requested
854    };
855    raw.clamp(1, n)
856}
857
858/// Phase 1: sort each subarray in parallel, sample from it, and spill
859/// `(position, lcp)` records to its own [`ExtMemBucket`].
860///
861/// One rayon task per subarray; rayon's work-stealing scheduler keeps
862/// all worker threads busy and lets `merge_sort`'s inner
863/// [`rayon::join`] recursion steal idle slots. With the auto-picked
864/// `p` (target chunk ~ 64 K records), per-task scratch is ~18 MiB on
865/// human-scale inputs, so the `num_threads × per_task_scratch` peak
866/// stays bounded even though we don't reuse buffers across iterations.
867#[allow(clippy::too_many_arguments)]
868fn phase1_sort_sample_spill<S, I, L, B, MkB>(
869    text: &[S],
870    lp: &L,
871    source: &PositionSource<'_>,
872    p: usize,
873    opts: &ExtMemOpts,
874    dispatch: LcpDispatch,
875    mk_bucket: MkB,
876) -> io::Result<(Vec<B>, Vec<I>)>
877where
878    S: Symbol,
879    I: Index,
880    L: LimitProvider,
881    SaLcp<I>: BucketRecord,
882    B: BucketStore<SaLcp<I>> + Send,
883    MkB: Fn(usize) -> B + Send + Sync,
884{
885    let n = source.len();
886    let chunk_size = n.div_ceil(p);
887    let samples_target_total = sample_target_total(n, p);
888
889    let per_subarray: Vec<(B, Vec<I>)> = (0..p)
890        .into_par_iter()
891        .map(|i| {
892            let start = (i * chunk_size).min(n);
893            let end = ((i + 1) * chunk_size).min(n);
894            let len = end - start;
895
896            let mut bucket = mk_bucket(i);
897            if len == 0 {
898                return Ok::<_, io::Error>((bucket, Vec::new()));
899            }
900
901            // In-memory sort of this subarray with LCP maintenance.
902            let mut sa: Vec<I> = vec![I::zero(); len];
903            source.fill_chunk(start, &mut sa);
904            let mut sa_w = vec![I::zero(); len];
905            let mut lcp_arr = vec![I::zero(); len];
906            let mut lcp_w = vec![I::zero(); len];
907            sample_sort::merge_sort(
908                text,
909                lp,
910                &mut sa,
911                &mut sa_w,
912                &mut lcp_arr,
913                &mut lcp_w,
914                opts.max_context,
915                dispatch,
916            );
917
918            // Pull `samples_per_subarray` evenly-spaced positions out of
919            // the now-sorted subarray. Deterministic — no RNG needed for
920            // pivot selection to be globally well-distributed.
921            let samples_per_subarray = samples_target_total.div_ceil(p).min(len);
922            let samples = evenly_spaced(&sa, samples_per_subarray);
923
924            // Spill (position, lcp) records to the bucket. `lcp[0]`
925            // remains 0 (set by the merge-sort base case), making each
926            // subarray its own well-formed LCP-annotated sorted run.
927            let records: Vec<SaLcp<I>> = sa
928                .iter()
929                .zip(lcp_arr.iter())
930                .map(|(&pos, &lcp)| SaLcp { pos, lcp })
931                .collect();
932            bucket.add_slice(&records)?;
933
934            Ok((bucket, samples))
935        })
936        .collect::<Result<Vec<_>, _>>()?;
937
938    let mut buckets = Vec::with_capacity(p);
939    let mut all_samples = Vec::with_capacity(samples_target_total);
940    for (bucket, samples) in per_subarray {
941        buckets.push(bucket);
942        all_samples.extend(samples);
943    }
944    Ok((buckets, all_samples))
945}
946
/// Target sample count *across all subarrays*.
///
/// Follows upstream CaPS-SA's "`c · ln n`" rule per subarray with
/// `c = 4`, so the uncapped pool is `p · ceil(4 · ln n)` samples (`ln n`
/// is floored at 1 so tiny inputs still get a usable count).
///
/// The result is at least `p` (enough to pick `p − 1` pivots) and at most
/// `n` (there are only `n` positions to sample). When `p > n` the upper
/// bound wins and `n` is returned.
fn sample_target_total(n: usize, p: usize) -> usize {
    let ln_n = (n as f64).ln().max(1.0);
    let per_subarray = (4.0 * ln_n).ceil() as usize;
    // `max` then `min` rather than `clamp(p, n)`: `Ord::clamp` panics when
    // `p > n`, whereas this form simply caps the result at `n`.
    p.saturating_mul(per_subarray).max(p).min(n)
}
956
/// Pick `count` evenly-spaced elements from `xs`, deterministically.
///
/// * `count == 0` or an empty slice yields an empty vector.
/// * `count >= xs.len()` yields a copy of the whole slice.
/// * Otherwise the slice is viewed as `count` equal-width strata and the
///   element at index `(2·i + 1) · len / (2 · count)` (integer division),
///   the midpoint of stratum `i`, is taken from each.
///
/// No RNG is involved, so repeated runs sample identically.
fn evenly_spaced<T: Copy>(xs: &[T], count: usize) -> Vec<T> {
    let len = xs.len();
    match (count, len) {
        (0, _) | (_, 0) => Vec::new(),
        _ if count >= len => xs.to_vec(),
        _ => {
            let denom = 2 * count;
            let mut picked = Vec::with_capacity(count);
            for stratum in 0..count {
                picked.push(xs[(2 * stratum + 1) * len / denom]);
            }
            picked
        }
    }
}
974
975/// Phase 2: globally sort the pooled samples and pick `p − 1` pivots at
976/// evenly-spaced ranks.
977fn phase2_select_pivots<S, I, L>(
978    text: &[S],
979    lp: &L,
980    mut samples: Vec<I>,
981    p: usize,
982    max_ctx: usize,
983    dispatch: LcpDispatch,
984) -> Vec<I>
985where
986    S: Symbol,
987    I: Index,
988    L: LimitProvider,
989{
990    if p <= 1 || samples.is_empty() {
991        return Vec::new();
992    }
993    let n_samples = samples.len();
994    let mut sa_w = vec![I::zero(); n_samples];
995    let mut lcp = vec![I::zero(); n_samples];
996    let mut lcp_w = vec![I::zero(); n_samples];
997    sample_sort::merge_sort(
998        text,
999        lp,
1000        &mut samples,
1001        &mut sa_w,
1002        &mut lcp,
1003        &mut lcp_w,
1004        max_ctx,
1005        dispatch,
1006    );
1007
1008    // p-1 pivots at evenly-spaced ranks across the sorted sample pool.
1009    (1..p).map(|j| samples[(j * n_samples) / p]).collect()
1010}
1011
1012/// Phase 3: walk each subarray *in parallel*, load it into RAM,
1013/// binary-search the pivots to find its `p` sub-subarray boundaries,
1014/// and append each sub-subarray to the corresponding partition bucket.
1015///
1016/// Partition buckets are wrapped in a [`Mutex`] each so multiple
1017/// threads can write to different partitions concurrently without
1018/// shard-merging afterwards. With `p` in the thousands and `T` in the
1019/// tens, lock contention is negligible (probability that two threads
1020/// want the same partition at the same instant is `~T/p`); the lock
1021/// scope per acquisition is one `add_slice` + `mark_boundary` of a
1022/// few-KB sub-subarray.
1023///
1024/// Phase 4 doesn't care about the relative order of sub-subarrays
1025/// within a partition — only that each one between consecutive
1026/// boundaries is internally sorted. Both properties hold under
1027/// arbitrary thread interleaving.
1028#[allow(clippy::too_many_arguments)]
1029fn phase3_distribute<S, I, L, B, MkB>(
1030    text: &[S],
1031    lp: &L,
1032    subarray_buckets: &mut [B],
1033    pivots: &[I],
1034    p: usize,
1035    opts: &ExtMemOpts,
1036    dispatch: LcpDispatch,
1037    mk_bucket: MkB,
1038) -> io::Result<Vec<B>>
1039where
1040    S: Symbol,
1041    I: Index,
1042    L: LimitProvider,
1043    SaLcp<I>: BucketRecord,
1044    B: BucketStore<SaLcp<I>> + Send,
1045    MkB: Fn(usize) -> B + Send + Sync,
1046{
1047    let _ = opts; // work_dir is used only by the ext-mem factory closure now
1048    let partition_buckets: Vec<Mutex<B>> = (0..p).map(|j| Mutex::new(mk_bucket(j))).collect();
1049
1050    subarray_buckets
1051        .par_iter_mut()
1052        .try_for_each(|sub_bucket| -> io::Result<()> {
1053            if sub_bucket.total_records() == 0 {
1054                return Ok(());
1055            }
1056            let records = sub_bucket.load_all()?;
1057
1058            // Find p-1 split points by binary-searching each pivot's
1059            // *upper bound* in the sorted subarray.
1060            let mut splits = Vec::with_capacity(p + 1);
1061            splits.push(0usize);
1062            for &pivot in pivots {
1063                splits.push(upper_bound_by_pivot(
1064                    &records,
1065                    pivot,
1066                    text,
1067                    lp,
1068                    opts.max_context,
1069                    dispatch,
1070                ));
1071            }
1072            splits.push(records.len());
1073
1074            // Distribute each sub-subarray. Reset the first record's
1075            // `lcp` to 0 so the per-partition merge sees a well-formed
1076            // boundary.
1077            for j in 0..p {
1078                let lo = splits[j];
1079                let hi = splits[j + 1];
1080                if lo >= hi {
1081                    continue;
1082                }
1083                let mut sub: Vec<SaLcp<I>> = records[lo..hi].to_vec();
1084                sub[0].lcp = I::zero();
1085                let mut bucket = partition_buckets[j].lock().unwrap();
1086                bucket.add_slice(&sub)?;
1087                bucket.mark_boundary();
1088            }
1089            Ok(())
1090        })?;
1091
1092    // Unwrap the Mutexes — at this point only this thread holds
1093    // references, so the locks are uncontended.
1094    Ok(partition_buckets
1095        .into_iter()
1096        .map(|m| m.into_inner().expect("partition mutex poisoned"))
1097        .collect())
1098}
1099
1100/// Upper-bound binary search: returns the first index `i` such that the
1101/// suffix at `records[i].pos` is **strictly greater than** the suffix at
1102/// `pivot`.
1103fn upper_bound_by_pivot<S, I, L>(
1104    records: &[SaLcp<I>],
1105    pivot: I,
1106    text: &[S],
1107    lp: &L,
1108    max_ctx: usize,
1109    dispatch: LcpDispatch,
1110) -> usize
1111where
1112    S: Symbol,
1113    I: Index,
1114    L: LimitProvider,
1115{
1116    let mut lo = 0;
1117    let mut hi = records.len();
1118    while lo < hi {
1119        let mid = lo + (hi - lo) / 2;
1120        match dispatch.suffix_cmp_with(
1121            text,
1122            lp,
1123            records[mid].pos.to_usize(),
1124            pivot.to_usize(),
1125            max_ctx,
1126        ) {
1127            Ordering::Greater => hi = mid,
1128            Ordering::Equal | Ordering::Less => lo = mid + 1,
1129        }
1130    }
1131    lo
1132}
1133
1134/// Phase 4 + 5: parallel-merge partitions in chunks of `num_threads`,
1135/// emitting each chunk's results in lex order before starting the next.
1136///
1137/// Each worker thread holds its own [`CascadeWorkspace`] for the duration
1138/// of one partition merge. Within a chunk, all `T` workspaces live in
1139/// parallel; between chunks, they are dropped (so peak workspace memory
1140/// scales with `T`, not with the number of partitions). The merged result
1141/// for each partition is then drained sequentially via `emit` to preserve
1142/// streaming-output order without ever holding the full SA in RAM.
1143///
1144/// Peak transient RAM ≈ `T × max_partition_size × 16 bytes` for the
1145/// merged-result buffers, plus the workspaces themselves (~`2 × T ×
1146/// max_partition_size × 16` bytes). On a typical run with `p = 4 × T`
1147/// subarrays the per-partition size is `≈ n / p`, so this stays
1148/// proportional to `n / 4 = 0.25 n` even at the peak — well below the
1149/// in-memory path's `~4 n` working set.
1150fn phase4_merge_and_emit<S, I, L, B, F>(
1151    text: &[S],
1152    lp: &L,
1153    partition_buckets: &mut [B],
1154    max_ctx: usize,
1155    emit: &mut F,
1156    dispatch: LcpDispatch,
1157) -> io::Result<()>
1158where
1159    S: Symbol,
1160    I: Index,
1161    L: LimitProvider,
1162    SaLcp<I>: BucketRecord,
1163    B: BucketStore<SaLcp<I>> + Send,
1164    F: FnMut(u64) -> io::Result<()>,
1165{
1166    let n_partitions = partition_buckets.len();
1167    if n_partitions == 0 {
1168        return Ok(());
1169    }
1170    // `chunk_size = 4 × num_threads` (not `= num_threads`): with one
1171    // partition per thread per chunk, rayon's `par_iter_mut` assigns
1172    // 1-to-1 with no opportunity to steal, and the chunk's wall is
1173    // set by its slowest partition. Sample-sort partition sizes vary
1174    // ~2× from random sampling, so the slow tail leaves ~half the
1175    // cores idle waiting (observed: 52% parallel efficiency on
1176    // GRCh38 / 32 t).
1177    //
1178    // Bumping the chunk to `4 × num_threads` gives rayon four
1179    // partitions per thread to dispatch — fast threads can steal from
1180    // slow neighbours, smoothing out the size variance. Peak RAM
1181    // grows linearly: each in-flight merged partition holds its
1182    // result `Vec<I>` (~3 MB at human-genome scale with `u32`
1183    // indices), so the chunk's transient cost goes from `32 × 3 MB =
1184    // 96 MB` to `128 × 3 MB = 384 MB` — well within the budget we
1185    // already spend on phase 1.
1186    let chunk_size = rayon::current_num_threads().max(1) * 4;
1187
1188    // Per-thread CPU-µs accumulators for the two parallel sub-steps. They
1189    // add across threads, so the printed values are CPU-time (sum), not
1190    // wall-time; the ratio between them still tells us where the work is.
1191    use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
1192    let profile = std::env::var_os("CAPS_SA_PROFILE").is_some();
1193    let load_us = AtomicU64::new(0);
1194    let merge_us = AtomicU64::new(0);
1195    let mut emit_secs: f64 = 0.0;
1196
1197    let mut start = 0;
1198    while start < n_partitions {
1199        let end = (start + chunk_size).min(n_partitions);
1200        let chunk = &mut partition_buckets[start..end];
1201
1202        // Parallel-merge each non-empty bucket in this chunk. `par_iter_mut`
1203        // preserves index order in the collected `Vec`, so the subsequent
1204        // sequential emit yields positions in lex order.
1205        let merged: Vec<Vec<I>> = chunk
1206            .par_iter_mut()
1207            .map(|bucket| -> io::Result<Vec<I>> {
1208                if bucket.total_records() == 0 {
1209                    return Ok(Vec::new());
1210                }
1211                let t = Instant::now();
1212                let records = bucket.load_all()?;
1213                let boundaries: Vec<usize> = bucket.boundaries().to_vec();
1214                if profile {
1215                    load_us.fetch_add(t.elapsed().as_micros() as u64, AtomicOrdering::Relaxed);
1216                }
1217
1218                let t = Instant::now();
1219                let workspace = CascadeWorkspace::<I>::new();
1220                // `cascade_merge` consumes the workspace and returns
1221                // the result side directly — the other three buffers
1222                // drop along with `workspace` here, without an
1223                // intermediate `to_vec()` copy.
1224                let result =
1225                    workspace.cascade_merge(text, lp, &records, &boundaries, max_ctx, dispatch);
1226                if profile {
1227                    merge_us.fetch_add(t.elapsed().as_micros() as u64, AtomicOrdering::Relaxed);
1228                }
1229                Ok(result)
1230            })
1231            .collect::<Result<Vec<_>, io::Error>>()?;
1232
1233        let t = Instant::now();
1234        for positions in merged {
1235            for pos in positions {
1236                // Widen back to the public `u64` emit contract.
1237                emit(pos.to_usize() as u64)?;
1238            }
1239        }
1240        if profile {
1241            emit_secs += t.elapsed().as_secs_f64();
1242        }
1243
1244        start = end;
1245    }
1246    if profile {
1247        profile_log(&format!(
1248            "phase4 breakdown CPU: load {:.3}s merge {:.3}s; wall emit {:.3}s",
1249            load_us.load(AtomicOrdering::Relaxed) as f64 * 1e-6,
1250            merge_us.load(AtomicOrdering::Relaxed) as f64 * 1e-6,
1251            emit_secs,
1252        ));
1253    }
1254    Ok(())
1255}
1256
1257/// Ping-pong scratch for one partition's cascade merge.
1258///
1259/// Holds two `(sa, lcp)` buffer pairs, side A and side B, grown by
1260/// `ensure_capacity` to the record count of the partition being merged.
1261/// Each cascade level reads one side and writes the other; the caller
1262/// flips `src_is_a` between levels. `cascade_merge` takes the workspace
1263/// by value, so phase 4 creates a fresh one for every partition.
1264struct CascadeWorkspace<I> {
1265    a_sa: Vec<I>, // side A: suffix positions
1266    a_lcp: Vec<I>, // side A: LCP values, parallel to `a_sa`
1267    b_sa: Vec<I>, // side B: suffix positions
1268    b_lcp: Vec<I>, // side B: LCP values, parallel to `b_sa`
1269}
1270
1271impl<I: Index> CascadeWorkspace<I> {
1272    fn new() -> Self {
1273        Self {
1274            a_sa: Vec::new(),
1275            a_lcp: Vec::new(),
1276            b_sa: Vec::new(),
1277            b_lcp: Vec::new(),
1278        }
1279    }
1280
1281    /// Grow all four buffers to at least `n` elements. The contents past
1282    /// the cascade's actual run lengths are don't-care.
1283    fn ensure_capacity(&mut self, n: usize) {
1284        if self.a_sa.len() < n {
1285            self.a_sa.resize(n, I::zero());
1286            self.a_lcp.resize(n, I::zero());
1287            self.b_sa.resize(n, I::zero());
1288            self.b_lcp.resize(n, I::zero());
1289        }
1290    }
1291
1292    /// Cascade 2-way LCP-enhanced merges across the sub-subarrays of one
1293    /// partition (delimited by `boundaries`) until a single sorted run
1294    /// remains. **Consumes the workspace** and returns the result side
1295    /// as a `Vec<u64>`; the other three buffers (`a_lcp`, the opposing
1296    /// `*_sa`, the opposing `*_lcp`) drop immediately. This shape lets
1297    /// the caller skip the per-partition `to_vec()` round-trip that
1298    /// would otherwise sit briefly alongside all four workspace buffers
1299    /// at peak.
1300    fn cascade_merge<S, L>(
1301        mut self,
1302        text: &[S],
1303        lp: &L,
1304        records: &[SaLcp<I>],
1305        boundaries: &[usize],
1306        max_ctx: usize,
1307        dispatch: LcpDispatch,
1308    ) -> Vec<I>
1309    where
1310        S: Symbol,
1311        L: LimitProvider,
1312    {
1313        let n = records.len();
1314        if n == 0 {
1315            return Vec::new();
1316        }
1317        self.ensure_capacity(n);
1318
1319        // Initialize side A in SOA form from the AOS `records`, and
1320        // collect the lengths of the non-empty sub-subarrays.
1321        let mut run_lens: Vec<usize> = boundaries
1322            .windows(2)
1323            .filter_map(|w| {
1324                let l = w[1] - w[0];
1325                if l > 0 { Some(l) } else { None }
1326            })
1327            .collect();
1328        for (i, r) in records.iter().enumerate() {
1329            self.a_sa[i] = r.pos;
1330            self.a_lcp[i] = r.lcp;
1331        }
1332
1333        let mut src_is_a = true;
1334        while run_lens.len() > 1 {
1335            run_lens = self.merge_one_level(src_is_a, &run_lens, text, lp, max_ctx, dispatch);
1336            src_is_a = !src_is_a;
1337        }
1338
1339        // Take ownership of the buffer holding the result, truncate to
1340        // the actual record count, drop the other three buffers with
1341        // `self` going out of scope.
1342        let mut result = if src_is_a { self.a_sa } else { self.b_sa };
1343        result.truncate(n);
1344        result
1345    }
1346
1347    /// Pair the runs in `run_lens` (last odd one passes through unchanged),
1348    /// running each pair through the LCP-enhanced 2-way merge from the
1349    /// `src_is_a`-selected buffer side into the other. Returns the new
1350    /// run-length list (each entry is the sum of the two it replaced, or
1351    /// the carry-over for an odd tail).
1352    fn merge_one_level<S, L>(
1353        &mut self,
1354        src_is_a: bool,
1355        run_lens: &[usize],
1356        text: &[S],
1357        lp: &L,
1358        max_ctx: usize,
1359        dispatch: LcpDispatch,
1360    ) -> Vec<usize>
1361    where
1362        S: Symbol,
1363        L: LimitProvider,
1364    {
1365        // Destructure self so the borrow checker can see the two sides as
1366        // disjoint locals — we borrow one immutably and the other mutably.
1367        let Self {
1368            a_sa,
1369            a_lcp,
1370            b_sa,
1371            b_lcp,
1372        } = self;
1373        let (src_sa, src_lcp, dst_sa, dst_lcp) = if src_is_a {
1374            (
1375                a_sa.as_slice(),
1376                a_lcp.as_slice(),
1377                b_sa.as_mut_slice(),
1378                b_lcp.as_mut_slice(),
1379            )
1380        } else {
1381            (
1382                b_sa.as_slice(),
1383                b_lcp.as_slice(),
1384                a_sa.as_mut_slice(),
1385                a_lcp.as_mut_slice(),
1386            )
1387        };
1388
1389        let mut new_lens = Vec::with_capacity(run_lens.len().div_ceil(2));
1390        let mut src_off = 0usize;
1391        let mut dst_off = 0usize;
1392        let mut i = 0;
1393        while i < run_lens.len() {
1394            let l1 = run_lens[i];
1395            if i + 1 < run_lens.len() {
1396                let l2 = run_lens[i + 1];
1397                let x_end = src_off + l1;
1398                let xy_end = x_end + l2;
1399                let dst_end = dst_off + l1 + l2;
1400                sample_sort::merge(
1401                    text,
1402                    lp,
1403                    &src_sa[src_off..x_end],
1404                    &src_sa[x_end..xy_end],
1405                    &src_lcp[src_off..x_end],
1406                    &src_lcp[x_end..xy_end],
1407                    &mut dst_sa[dst_off..dst_end],
1408                    &mut dst_lcp[dst_off..dst_end],
1409                    max_ctx,
1410                    dispatch,
1411                );
1412                new_lens.push(l1 + l2);
1413                src_off = xy_end;
1414                dst_off = dst_end;
1415                i += 2;
1416            } else {
1417                // Odd run carries over unchanged.
1418                let end = dst_off + l1;
1419                dst_sa[dst_off..end].copy_from_slice(&src_sa[src_off..src_off + l1]);
1420                dst_lcp[dst_off..end].copy_from_slice(&src_lcp[src_off..src_off + l1]);
1421                new_lens.push(l1);
1422                src_off += l1;
1423                dst_off = end;
1424                i += 1;
1425            }
1426        }
1427        new_lens
1428    }
1429}
1430
1431#[cfg(test)]
1432mod tests {
1433    use super::*;
1434    use crate::build_in_memory;
1435    use tempfile::tempdir;
1436
1437    fn ext_mem_sa(text: &[u8], p: usize) -> Vec<u64> {
1438        let dir = tempdir().unwrap();
1439        let opts = ExtMemOpts {
1440            subproblem_count: p,
1441            work_dir: dir.path().to_path_buf(),
1442            ..ExtMemOpts::default()
1443        };
1444        let mut out: Vec<u64> = Vec::with_capacity(text.len());
1445        build_ext_mem(text, &opts, |pos| {
1446            out.push(pos);
1447            Ok(())
1448        })
1449        .unwrap();
1450        out
1451    }
1452
1453    fn assert_matches_in_memory(text: &[u8], p: usize) {
1454        let want: Vec<u32> = build_in_memory(text);
1455        let want64: Vec<u64> = want.iter().map(|&x| x as u64).collect();
1456        let got = ext_mem_sa(text, p);
1457        assert_eq!(got, want64, "mismatch on text {text:?} with p={p}");
1458    }
1459
1460    #[test]
1461    fn ext_mem_empty() {
1462        let got = ext_mem_sa(b"", 4);
1463        assert!(got.is_empty());
1464    }
1465
1466    #[test]
1467    fn ext_mem_single_partition() {
1468        assert_matches_in_memory(b"banana", 1);
1469    }
1470
1471    #[test]
1472    fn ext_mem_p_greater_than_n() {
1473        assert_matches_in_memory(b"abc", 10);
1474    }
1475
1476    #[test]
1477    fn ext_mem_banana_p4() {
1478        assert_matches_in_memory(b"banana", 4);
1479    }
1480
1481    #[test]
1482    fn ext_mem_mississippi_p3() {
1483        assert_matches_in_memory(b"mississippi", 3);
1484    }
1485
1486    #[test]
1487    fn ext_mem_random_byte_texts() {
1488        use rand::{RngExt, SeedableRng};
1489        let mut rng = rand::rngs::StdRng::seed_from_u64(0xCAFE);
1490        for &n in &[16usize, 100, 1000, 5000] {
1491            for &p in &[1usize, 2, 4, 16] {
1492                let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
1493                assert_matches_in_memory(&text, p);
1494            }
1495        }
1496    }
1497
1498    #[test]
1499    fn ext_mem_with_unique_terminator() {
1500        use rand::{RngExt, SeedableRng};
1501        let mut rng = rand::rngs::StdRng::seed_from_u64(0xF00D);
1502        for &n in &[10usize, 200, 2000] {
1503            for &p in &[1usize, 3, 8] {
1504                let mut text: Vec<u8> = (0..n).map(|_| rng.random_range(0..5u8)).collect();
1505                text.push(200);
1506                assert_matches_in_memory(&text, p);
1507            }
1508        }
1509    }
1510
1511    fn ext_mem_for_positions(text: &[u8], positions: Vec<u64>, p: usize) -> Vec<u64> {
1512        let dir = tempdir().unwrap();
1513        let opts = ExtMemOpts {
1514            subproblem_count: p,
1515            work_dir: dir.path().to_path_buf(),
1516            ..ExtMemOpts::default()
1517        };
1518        let mut out: Vec<u64> = Vec::with_capacity(positions.len());
1519        build_ext_mem_for_positions(text, positions, &opts, |pos| {
1520            out.push(pos);
1521            Ok(())
1522        })
1523        .unwrap();
1524        out
1525    }
1526
1527    #[test]
1528    fn ext_mem_for_positions_full_set_matches_ext_mem() {
1529        let text = b"mississippi";
1530        let want = ext_mem_sa(text, 3);
1531        let positions: Vec<u64> = (0..text.len() as u64).collect();
1532        let got = ext_mem_for_positions(text, positions, 3);
1533        assert_eq!(got, want);
1534    }
1535
1536    #[test]
1537    fn ext_mem_for_positions_subset_matches_brute_force() {
1538        let text = b"mississippi";
1539        let positions: Vec<u64> = (0..text.len() as u64).step_by(2).collect();
1540        let mut want = positions.clone();
1541        want.sort_by(|&a, &b| text[a as usize..].cmp(&text[b as usize..]));
1542        let got = ext_mem_for_positions(text, positions, 4);
1543        assert_eq!(got, want);
1544    }
1545
1546    #[test]
1547    fn ext_mem_for_positions_random_subsets() {
1548        use rand::{RngExt, SeedableRng};
1549        let mut rng = rand::rngs::StdRng::seed_from_u64(0xC0DE);
1550        for &n in &[50usize, 500, 2000] {
1551            for &p in &[1usize, 3, 8] {
1552                let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..5u8)).collect();
1553                let mut positions: Vec<u64> = (0..n as u64).collect();
1554                positions.retain(|_| rng.random_range(0..10) < 7);
1555                let mut want = positions.clone();
1556                want.sort_by(|&a, &b| text[a as usize..].cmp(&text[b as usize..]));
1557                let got = ext_mem_for_positions(&text, positions, p);
1558                assert_eq!(got, want, "subset ext-mem mismatch n={n} p={p}");
1559            }
1560        }
1561    }
1562
1563    fn in_memory_sample_sort(text: &[u8], p: usize) -> Vec<u64> {
1564        let dir = tempdir().unwrap();
1565        let opts = ExtMemOpts {
1566            subproblem_count: p,
1567            work_dir: dir.path().to_path_buf(),
1568            ..ExtMemOpts::default()
1569        };
1570        let mut out: Vec<u64> = Vec::with_capacity(text.len());
1571        build_in_memory_sample_sort(text, &opts, |pos| {
1572            out.push(pos);
1573            Ok(())
1574        })
1575        .unwrap();
1576        out
1577    }
1578
1579    #[test]
1580    fn in_memory_sample_sort_matches_in_memory() {
1581        for text in [b"banana" as &[u8], b"mississippi", b"abracadabra"] {
1582            let want: Vec<u32> = build_in_memory(text);
1583            let want64: Vec<u64> = want.iter().map(|&x| x as u64).collect();
1584            let got = in_memory_sample_sort(text, 0);
1585            assert_eq!(got, want64, "in-mem sample-sort mismatch on {text:?}");
1586        }
1587    }
1588
1589    #[test]
1590    fn in_memory_sample_sort_random_byte_texts() {
1591        use rand::{RngExt, SeedableRng};
1592        let mut rng = rand::rngs::StdRng::seed_from_u64(0xC0DE_C0DE);
1593        for &n in &[16usize, 200, 2000] {
1594            for &p in &[1usize, 4, 16] {
1595                let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
1596                let want: Vec<u32> = build_in_memory(&text);
1597                let want64: Vec<u64> = want.iter().map(|&x| x as u64).collect();
1598                let got = in_memory_sample_sort(&text, p);
1599                assert_eq!(got, want64, "in-mem ss mismatch n={n} p={p}");
1600            }
1601        }
1602    }
1603
1604    /// Helper that drives [`build_ext_mem_for_filter`] and collects the
1605    /// emitted positions.
1606    fn ext_mem_for_filter<Pred>(text: &[u8], keep: Pred, p: usize) -> Vec<u64>
1607    where
1608        Pred: Fn(u64) -> bool + Send + Sync,
1609    {
1610        let dir = tempdir().unwrap();
1611        let opts = ExtMemOpts {
1612            subproblem_count: p,
1613            work_dir: dir.path().to_path_buf(),
1614            ..ExtMemOpts::default()
1615        };
1616        let mut out: Vec<u64> = Vec::new();
1617        build_ext_mem_for_filter(text, keep, &opts, |pos| {
1618            out.push(pos);
1619            Ok(())
1620        })
1621        .unwrap();
1622        out
1623    }
1624
1625    #[test]
1626    fn ext_mem_for_filter_matches_for_positions_on_full_set() {
1627        // Filter that accepts every position → must equal the
1628        // identity-positions ext-mem build.
1629        let text = b"mississippi";
1630        let want = ext_mem_sa(text, 3);
1631        let got = ext_mem_for_filter(text, |_p| true, 3);
1632        assert_eq!(got, want);
1633    }
1634
1635    #[test]
1636    fn ext_mem_for_filter_matches_for_positions_on_dna_subset() {
1637        // STAR-style "keep ACGT (`< 4`), drop N (`4`)/spacer (`5`)"
1638        // filter. The filter API must produce exactly the same SA as
1639        // pre-materialising the kept positions and going through the
1640        // _for_positions path.
1641        use rand::{RngExt, SeedableRng};
1642        let mut rng = rand::rngs::StdRng::seed_from_u64(0xCA75_5A);
1643        for &n in &[50usize, 500, 2000] {
1644            for &p in &[1usize, 3, 8] {
1645                let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
1646                let positions: Vec<u64> = (0..n as u64).filter(|&i| text[i as usize] < 4).collect();
1647                let want = ext_mem_for_positions(&text, positions, p);
1648                let got = ext_mem_for_filter(&text, |i| text[i as usize] < 4, p);
1649                assert_eq!(got, want, "filter vs positions mismatch n={n} p={p}");
1650            }
1651        }
1652    }
1653
1654    #[test]
1655    fn ext_mem_for_filter_handles_block_aligned_boundaries() {
1656        // Exercise the bitmap word/block boundaries by using a text
1657        // longer than one popcount block (1024 × 64 bits = 64 K
1658        // positions) — but stay under that to keep the test fast.
1659        // 200 K positions touches the cumsum's second block too.
1660        use rand::{RngExt, SeedableRng};
1661        let mut rng = rand::rngs::StdRng::seed_from_u64(0xB10C_C0DE);
1662        let n = 200_000usize;
1663        let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
1664        let positions: Vec<u64> = (0..n as u64).filter(|&i| text[i as usize] < 4).collect();
1665        let want = ext_mem_for_positions(&text, positions, 8);
1666        let got = ext_mem_for_filter(&text, |i| text[i as usize] < 4, 8);
1667        assert_eq!(got, want, "filter API mismatch across block boundaries");
1668    }
1669
1670    #[test]
1671    fn ext_mem_for_filter_sparse_predicate() {
1672        // ~5% acceptance — exercises long runs of zero-bits in the
1673        // bitmap (skip-loop across whole words).
1674        use rand::{RngExt, SeedableRng};
1675        let mut rng = rand::rngs::StdRng::seed_from_u64(0x5_AA_55);
1676        let n = 50_000usize;
1677        let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..20u8)).collect();
1678        let positions: Vec<u64> = (0..n as u64).filter(|&i| text[i as usize] < 1).collect();
1679        let want = ext_mem_for_positions(&text, positions, 4);
1680        let got = ext_mem_for_filter(&text, |i| text[i as usize] < 1, 4);
1681        assert_eq!(got, want, "filter API mismatch on sparse predicate");
1682    }
1683
1684    #[test]
1685    fn ext_mem_repetitive_does_not_blow_up() {
1686        // Many copies of a long repeat — what killed the Phase 2 v1
1687        // linear-scan merge. The sample-sort + LCP-enhanced cascade
1688        // should handle it in proportional time.
1689        use std::time::Instant;
1690        let unit = b"ACGTACGTACGTACGTACGTACGTACGT"; // 28 bases
1691        let mut text: Vec<u8> = Vec::new();
1692        for _ in 0..100 {
1693            text.extend_from_slice(unit);
1694        }
1695        text.push(200);
1696        let start = Instant::now();
1697        let got = ext_mem_sa(&text, 8);
1698        let elapsed = start.elapsed();
1699        let want: Vec<u32> = build_in_memory(&text);
1700        let want64: Vec<u64> = want.iter().map(|&x| x as u64).collect();
1701        assert_eq!(got, want64);
1702        // Sanity: should finish in well under a second on this input.
1703        assert!(
1704            elapsed.as_secs() < 2,
1705            "ext-mem build on a tiny repetitive text took {elapsed:?}"
1706        );
1707    }
1708}