caps_sa/ext_mem.rs
1//! External-memory suffix array construction.
2//!
3//! Implements upstream CaPS-SA's *sample-sort + LCP-enhanced merge*
4//! external-memory algorithm:
5//!
//! 1. **Sort + sample + spill.** Split the `n` positions to sort into `p`
//!    subarrays of `⌈n / p⌉` positions each (trailing ones may be shorter
//!    or empty). In parallel, sort each with the in-memory merge-sort
//!    kernel (`sample_sort::merge_sort`), take ~`c · ln n` evenly-spaced
//!    samples (deterministic, no RNG) from each sorted subarray, and spill
//!    the sorted `(position, lcp)` records to a per-subarray
//!    [`ExtMemBucket`][crate::ext_bucket::ExtMemBucket].
12//! 2. **Select pivots.** Globally sort the pooled samples and pick
13//! `p − 1` pivots at evenly-spaced ranks, splitting the suffix order
14//! into `p` partitions.
15//! 3. **Distribute.** For each sorted subarray, binary-search the pivots
16//! to find its `p` sub-subarray split points; append each sub-subarray
17//! to the corresponding *partition* bucket, marking a boundary after
18//! each contribution.
19//! 4. **Per-partition merge.** Load each partition's bucket into RAM
20//! (≈ `n / p` records); cascade 2-way LCP-enhanced merges across the
21//! `p` sub-subarrays to produce that partition's globally-sorted slice.
22//! 5. **Stream output.** Iterate partitions in order and emit each
23//! position via the caller's closure.
24//!
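//! Peak RAM ≈ `text` + `O(n / p)` per active worker (one partition's
//! merge working set), plus whatever the position source itself holds
//! (e.g. the caller's `Vec<u64>` for the `_for_positions` entry points).
//! With the default auto-picked `p` (`⌈n / 65 536⌉`, clamped to
//! `[num_threads, 8192]`), an `n = 6e9` (human-scale) input gets
//! `p = 8192` partitions of ≈ 730 K records each — far below the in-memory
//! path's `~4 × n × 8 = ~200 GB`. The full SA is never materialized in RAM.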
30
31use std::cmp::Ordering;
32use std::io;
33use std::path::{Path, PathBuf};
34use std::sync::Mutex;
35use std::time::Instant;
36
37use rayon::prelude::*;
38
39use crate::Index;
40use crate::ext_bucket::{BucketPool, BucketRecord, BucketStore, InMemBucket, SaLcp};
41use crate::lcp::{LcpDispatch, Symbol};
42use crate::limits::{LimitProvider, PlainText};
43use crate::sample_sort;
44
/// Print `caps-sa profile <message>` to stderr, but only when the
/// `CAPS_SA_PROFILE` environment variable is present (any value).
///
/// The ext-mem driver calls this after each phase with its elapsed wall
/// time, so per-phase cost can be inspected on demand without paying for
/// logging on ordinary runs. `bench/README.md` ("Where AVX-512 helps and
/// where it doesn't") describes how it is used.
fn profile_log(message: &str) {
    let enabled = std::env::var_os("CAPS_SA_PROFILE").is_some();
    if !enabled {
        return;
    }
    eprintln!("caps-sa profile {message}");
}
55
/// Tunable options for [`build_ext_mem`] and the other sample-sort entry
/// points in this module.
#[derive(Clone, Debug)]
pub struct ExtMemOpts {
    /// Bound on LCP-extension comparisons inside merges. `usize::MAX`
    /// (default) is unbounded.
    pub max_context: usize,
    /// Number of subarrays (`p` in upstream CaPS-SA). `0` (default)
    /// picks `p` from the input size: `⌈n / 65 536⌉` subarrays, raised to
    /// at least `rayon::current_num_threads()` and capped at 8192.
    /// Whether requested or auto-picked, the final value is clamped to
    /// `[1, n]`.
    pub subproblem_count: usize,
    /// Directory for temp files. Defaults to [`std::env::temp_dir`].
    pub work_dir: PathBuf,
    /// Number of physical files in each bucket pool. The ext-mem build
    /// uses two pools of this size — one for the phase-1 subarray buckets
    /// and one for the phase-3 partition buckets. `0` (default) picks
    /// `rayon::current_num_threads()` — the right answer in practice:
    /// one writable inode per worker keeps kernel-level write contention
    /// bounded.
    ///
    /// The `2 × p` logical buckets (typically thousands at genome scale)
    /// collapse onto these pools of anonymous tempfiles via
    /// `bucket_id % physical_file_count`. Larger values lower kernel
    /// write contention; smaller values are kinder to networked
    /// filesystems with high metadata cost. The `CAPS_SA_N_PHYS` env var
    /// (an integer ≥ 1) overrides this for one-off benches.
    pub physical_file_count: usize,
}

impl Default for ExtMemOpts {
    fn default() -> Self {
        Self {
            max_context: usize::MAX,
            subproblem_count: 0,
            work_dir: std::env::temp_dir(),
            physical_file_count: 0,
        }
    }
}

impl ExtMemOpts {
    /// Options with the supplied `work_dir` and defaults for everything
    /// else.
    pub fn with_work_dir(work_dir: impl AsRef<Path>) -> Self {
        Self {
            work_dir: work_dir.as_ref().to_path_buf(),
            ..Self::default()
        }
    }
}
103
104/// Build the suffix array of `text` with bounded RAM, streaming each
105/// output position to `emit` in lexicographic order.
106///
107/// Returns an [`io::Error`] if temp-file I/O fails. The callback may also
108/// return an error to abort construction; partial work is discarded and
109/// temp files are cleaned up when their bucket drops.
110///
111/// Equivalent in semantics to [`crate::build_in_memory`]: produces a
112/// standard lexicographic suffix array with the "shorter suffix is
113/// smaller when one runs off the end" tie-break.
114pub fn build_ext_mem<S, F>(text: &[S], opts: &ExtMemOpts, emit: F) -> io::Result<()>
115where
116 S: Symbol,
117 F: FnMut(u64) -> io::Result<()>,
118{
119 build_ext_mem_with(text, &PlainText::new(text.len()), opts, emit)
120}
121
122/// Variant of [`build_ext_mem`] that accepts a [`LimitProvider`]. With
123/// [`PlainText`] this matches [`build_ext_mem`] exactly (and
124/// monomorphizes to identical assembly); with
125/// [`SegmentedText`][crate::limits::SegmentedText] the LCP scans stop
126/// at segment boundaries.
127pub fn build_ext_mem_with<S, L, F>(text: &[S], lp: &L, opts: &ExtMemOpts, emit: F) -> io::Result<()>
128where
129 S: Symbol,
130 L: LimitProvider,
131 F: FnMut(u64) -> io::Result<()>,
132{
133 // Dispatch on text size: when every suffix position fits in `u32`
134 // (n ≤ 2^32), use the narrow record type to halve all the SaLcp
135 // bytes — bucket disk I/O, phase-1 records, and the per-partition
136 // load in phase 4.
137 if text.len() <= u32::MAX as usize + 1 {
138 build_ext_mem_inner::<S, u32, L, F>(
139 text,
140 PositionSource::Identity(text.len()),
141 lp,
142 opts,
143 emit,
144 )
145 } else {
146 build_ext_mem_inner::<S, u64, L, F>(
147 text,
148 PositionSource::Identity(text.len()),
149 lp,
150 opts,
151 emit,
152 )
153 }
154}
155
156/// Like [`build_ext_mem`] but sorts only the caller-supplied
157/// `positions` by the lexicographic order of their suffixes in `text`.
158/// Suffix content is always `text[position..]`; no positions are
159/// dropped from the input. To filter, the caller constructs
160/// `positions` with only the indices they want.
161///
162/// This lets STAR-style genome indexing skip the bin-padding
163/// pathology: pass only the ACGT-starting positions and the
164/// spacer-suffix work disappears.
165///
166/// `positions` does not need to be pre-sorted in any way.
167pub fn build_ext_mem_for_positions<S, F>(
168 text: &[S],
169 positions: Vec<u64>,
170 opts: &ExtMemOpts,
171 emit: F,
172) -> io::Result<()>
173where
174 S: Symbol,
175 F: FnMut(u64) -> io::Result<()>,
176{
177 build_ext_mem_for_positions_with(text, positions, &PlainText::new(text.len()), opts, emit)
178}
179
180/// Variant of [`build_ext_mem_for_positions`] that accepts a
181/// [`LimitProvider`]. See [`build_ext_mem_with`] for the semantics.
182pub fn build_ext_mem_for_positions_with<S, L, F>(
183 text: &[S],
184 positions: Vec<u64>,
185 lp: &L,
186 opts: &ExtMemOpts,
187 emit: F,
188) -> io::Result<()>
189where
190 S: Symbol,
191 L: LimitProvider,
192 F: FnMut(u64) -> io::Result<()>,
193{
194 // We hold a reference to `positions` for the duration of the build;
195 // `phase1_sort_sample_spill` copies each chunk out. The Vec is
196 // dropped once phase 1 returns.
197 if text.len() <= u32::MAX as usize + 1 {
198 build_ext_mem_inner::<S, u32, L, F>(
199 text,
200 PositionSource::Subset(&positions),
201 lp,
202 opts,
203 emit,
204 )
205 } else {
206 build_ext_mem_inner::<S, u64, L, F>(
207 text,
208 PositionSource::Subset(&positions),
209 lp,
210 opts,
211 emit,
212 )
213 }
214}
215
216/// Like [`build_ext_mem_for_positions`] but takes a **predicate** over
217/// text positions instead of a pre-materialised `Vec<u64>` of kept
218/// positions.
219///
220/// caps-sa walks the predicate **once** to build a bitmap of kept
221/// positions + a tiny per-block popcount prefix-sum (together ~`n / 8`
222/// bytes — ~770 MB on the human genome, vs the ~50 GB the equivalent
223/// `Vec<u64>` would take). Phase 1's per-subarray fill is then driven
224/// by popcount-walking the bitmap; the predicate is **never invoked
225/// again** after the initial build. See [`FilteredSource`] for the
226/// memory accounting and the inner loop.
227///
228/// Use this entry when the caller already has the text in RAM and
229/// the kept positions are described by a cheap per-position
230/// predicate (e.g. STAR's `text[p] < 4` for ACGT-only suffix
231/// sampling). It is **the right entry for genome-scale inputs** —
232/// the `Vec<u64>` path can dominate peak RSS otherwise.
233///
234/// `keep` is invoked from rayon worker threads in parallel during
235/// the bitmap build; it must be `Send + Sync` (typically a plain
236/// closure capturing only `&[u8]` references is fine).
237pub fn build_ext_mem_for_filter<S, F, Pred>(
238 text: &[S],
239 keep: Pred,
240 opts: &ExtMemOpts,
241 emit: F,
242) -> io::Result<()>
243where
244 S: Symbol,
245 F: FnMut(u64) -> io::Result<()>,
246 Pred: Fn(u64) -> bool + Send + Sync,
247{
248 build_ext_mem_for_filter_with(text, keep, &PlainText::new(text.len()), opts, emit)
249}
250
251/// Variant of [`build_ext_mem_for_filter`] that accepts a
252/// [`LimitProvider`]. See [`build_ext_mem_with`] for the semantics.
253pub fn build_ext_mem_for_filter_with<S, L, F, Pred>(
254 text: &[S],
255 keep: Pred,
256 lp: &L,
257 opts: &ExtMemOpts,
258 emit: F,
259) -> io::Result<()>
260where
261 S: Symbol,
262 L: LimitProvider,
263 F: FnMut(u64) -> io::Result<()>,
264 Pred: Fn(u64) -> bool + Send + Sync,
265{
266 let filtered = FilteredSource::new(text.len(), keep);
267 if text.len() <= u32::MAX as usize + 1 {
268 build_ext_mem_inner::<S, u32, L, F>(
269 text,
270 PositionSource::Filtered(filtered),
271 lp,
272 opts,
273 emit,
274 )
275 } else {
276 build_ext_mem_inner::<S, u64, L, F>(
277 text,
278 PositionSource::Filtered(filtered),
279 lp,
280 opts,
281 emit,
282 )
283 }
284}
285
286fn build_ext_mem_inner<S, I, L, F>(
287 text: &[S],
288 source: PositionSource<'_>,
289 lp: &L,
290 opts: &ExtMemOpts,
291 mut emit: F,
292) -> io::Result<()>
293where
294 S: Symbol,
295 I: Index,
296 L: LimitProvider,
297 SaLcp<I>: BucketRecord,
298 F: FnMut(u64) -> io::Result<()>,
299{
300 let n = source.len();
301 if n == 0 {
302 return Ok(());
303 }
304 let p = effective_subproblem_count(n, opts.subproblem_count);
305 let dispatch = LcpDispatch::detect();
306 let work_dir = opts.work_dir.clone();
307
308 // Pool the `2 × p` bucket files into one anonymous tempfile per
309 // worker thread. With `p` in the thousands and `num_threads` in
310 // the dozens this collapses the openat/close/unlink budget from
311 // ~3·p (per-bucket-file path) to N (per-worker-file pool),
312 // eliminating the metadata-syscall pain on networked filesystems
313 // and the open-file-handle limit headache on tiny inputs. Per-
314 // bucket in-memory buffers and write volumes are unchanged, so
315 // local-disk wall time is neutral or marginally improved. See
316 // `bench/README.md` for the empirical sizing.
317 let n_phys = effective_physical_file_count(opts.physical_file_count);
318 let phase1_pool = BucketPool::new(n_phys, &work_dir)?;
319 let phase3_pool = BucketPool::new(n_phys, &work_dir)?;
320
321 profile_log(&format!(
322 "build_ext_mem n={n} p={p} index_width={}b n_phys={n_phys}",
323 std::mem::size_of::<I>() * 8
324 ));
325
326 let sub_factory = |i: usize| phase1_pool.new_bucket::<SaLcp<I>>(i);
327 let part_factory = |j: usize| phase3_pool.new_bucket::<SaLcp<I>>(j);
328
329 let t = Instant::now();
330 let (mut subarray_buckets, samples) = phase1_sort_sample_spill::<S, I, L, _, _>(
331 text,
332 lp,
333 &source,
334 p,
335 opts,
336 dispatch,
337 sub_factory,
338 )?;
339 profile_log(&format!(
340 "phase1 (sort+sample+spill) {:.3}s",
341 t.elapsed().as_secs_f64()
342 ));
343
344 // Drop the position source as soon as phase 1 returns — phases
345 // 2/3/4 don't touch it. For `PositionSource::Subset` this frees
346 // the caller's `Vec<u64>` (e.g. ~47 GB on a human-scale
347 // _for_positions build); for `PositionSource::Filtered` it
348 // frees the bitmap + cumsum (~770 MB); for `Identity` it's a
349 // no-op. The text and the spilled `subarray_buckets` are all
350 // phase 2+ needs.
351 drop(source);
352
353 let t = Instant::now();
354 let pivots = phase2_select_pivots::<S, I, L>(text, lp, samples, p, opts.max_context, dispatch);
355 profile_log(&format!(
356 "phase2 (select pivots) {:.3}s",
357 t.elapsed().as_secs_f64()
358 ));
359
360 let t = Instant::now();
361 let mut partition_buckets = phase3_distribute::<S, I, L, _, _>(
362 text,
363 lp,
364 &mut subarray_buckets,
365 &pivots,
366 p,
367 opts,
368 dispatch,
369 part_factory,
370 )?;
371 profile_log(&format!(
372 "phase3 (distribute) {:.3}s",
373 t.elapsed().as_secs_f64()
374 ));
375
376 drop(subarray_buckets);
377
378 let t = Instant::now();
379 let result = phase4_merge_and_emit::<S, I, L, _, F>(
380 text,
381 lp,
382 &mut partition_buckets,
383 opts.max_context,
384 &mut emit,
385 dispatch,
386 );
387 profile_log(&format!(
388 "phase4 (merge+emit) {:.3}s",
389 t.elapsed().as_secs_f64()
390 ));
391 result
392}
393
394/// Same algorithm as [`build_ext_mem_inner`] but with the disk-backed
395/// [`ExtMemBucket`] replaced by [`InMemBucket`] throughout — phase 1
396/// sorts each subarray and keeps the result in a `Vec<SaLcp<I>>`,
397/// phase 3 distributes into in-RAM partition Vecs, phase 4 cascade-
398/// merges the in-RAM partitions. No disk I/O.
399///
400/// Trades RAM for wall time: peak memory is ~`n × sizeof(SaLcp<I>)`
401/// (the post-phase-1 records sitting around until phase 3 consumes
402/// them), so ~25 GB on the human genome with `I = u32`. In exchange,
403/// the disk-spill / distribute-write / partition-load round-trip is
404/// gone — useful on machines with enough RAM to hold the working set.
405fn build_in_memory_ss_inner<S, I, L, F>(
406 text: &[S],
407 source: PositionSource<'_>,
408 lp: &L,
409 opts: &ExtMemOpts,
410 mut emit: F,
411) -> io::Result<()>
412where
413 S: Symbol,
414 I: Index,
415 L: LimitProvider,
416 SaLcp<I>: BucketRecord,
417 F: FnMut(u64) -> io::Result<()>,
418{
419 let n = source.len();
420 if n == 0 {
421 return Ok(());
422 }
423 let p = effective_subproblem_count(n, opts.subproblem_count);
424 let dispatch = LcpDispatch::detect();
425
426 let factory = |_i: usize| InMemBucket::<SaLcp<I>>::new();
427
428 let (mut subarray_buckets, samples) =
429 phase1_sort_sample_spill::<S, I, L, _, _>(text, lp, &source, p, opts, dispatch, factory)?;
430 // Same rationale as in `build_ext_mem_inner` — drop the source
431 // as soon as phase 1's `fill_chunk` calls have stopped.
432 drop(source);
433 let pivots = phase2_select_pivots::<S, I, L>(text, lp, samples, p, opts.max_context, dispatch);
434 let mut partition_buckets = phase3_distribute::<S, I, L, _, _>(
435 text,
436 lp,
437 &mut subarray_buckets,
438 &pivots,
439 p,
440 opts,
441 dispatch,
442 factory,
443 )?;
444 drop(subarray_buckets);
445 phase4_merge_and_emit::<S, I, L, _, F>(
446 text,
447 lp,
448 &mut partition_buckets,
449 opts.max_context,
450 &mut emit,
451 dispatch,
452 )
453}
454
455/// In-memory variant of the sample-sort algorithm used by
456/// [`build_ext_mem`]. Skips all disk I/O at the cost of holding the
457/// (`pos`, `lcp`) records in RAM throughout. Picks `u32` records when
458/// `n ≤ 2³²`, falls back to `u64` otherwise. The caller's `emit`
459/// closure is called once per output position in lex order, just like
460/// in the ext-mem path.
461pub fn build_in_memory_sample_sort<S, F>(text: &[S], opts: &ExtMemOpts, emit: F) -> io::Result<()>
462where
463 S: Symbol,
464 F: FnMut(u64) -> io::Result<()>,
465{
466 build_in_memory_sample_sort_with(text, &PlainText::new(text.len()), opts, emit)
467}
468
469/// Variant of [`build_in_memory_sample_sort`] that accepts a
470/// [`LimitProvider`].
471pub fn build_in_memory_sample_sort_with<S, L, F>(
472 text: &[S],
473 lp: &L,
474 opts: &ExtMemOpts,
475 emit: F,
476) -> io::Result<()>
477where
478 S: Symbol,
479 L: LimitProvider,
480 F: FnMut(u64) -> io::Result<()>,
481{
482 if text.len() <= u32::MAX as usize + 1 {
483 build_in_memory_ss_inner::<S, u32, L, F>(
484 text,
485 PositionSource::Identity(text.len()),
486 lp,
487 opts,
488 emit,
489 )
490 } else {
491 build_in_memory_ss_inner::<S, u64, L, F>(
492 text,
493 PositionSource::Identity(text.len()),
494 lp,
495 opts,
496 emit,
497 )
498 }
499}
500
501/// Subset-positions variant of [`build_in_memory_sample_sort`]. Same
502/// shape as [`build_ext_mem_for_positions`].
503pub fn build_in_memory_sample_sort_for_positions<S, F>(
504 text: &[S],
505 positions: Vec<u64>,
506 opts: &ExtMemOpts,
507 emit: F,
508) -> io::Result<()>
509where
510 S: Symbol,
511 F: FnMut(u64) -> io::Result<()>,
512{
513 build_in_memory_sample_sort_for_positions_with(
514 text,
515 positions,
516 &PlainText::new(text.len()),
517 opts,
518 emit,
519 )
520}
521
522/// Variant of [`build_in_memory_sample_sort_for_positions`] that
523/// accepts a [`LimitProvider`].
524pub fn build_in_memory_sample_sort_for_positions_with<S, L, F>(
525 text: &[S],
526 positions: Vec<u64>,
527 lp: &L,
528 opts: &ExtMemOpts,
529 emit: F,
530) -> io::Result<()>
531where
532 S: Symbol,
533 L: LimitProvider,
534 F: FnMut(u64) -> io::Result<()>,
535{
536 if text.len() <= u32::MAX as usize + 1 {
537 build_in_memory_ss_inner::<S, u32, L, F>(
538 text,
539 PositionSource::Subset(&positions),
540 lp,
541 opts,
542 emit,
543 )
544 } else {
545 build_in_memory_ss_inner::<S, u64, L, F>(
546 text,
547 PositionSource::Subset(&positions),
548 lp,
549 opts,
550 emit,
551 )
552 }
553}
554
555/// Source of the positions to sort.
556///
557/// - [`PositionSource::Identity`] is the all-suffixes case `0..n`;
558/// avoids materialising a `Vec<u64>` of length `n`, which on the
559/// human genome would itself be ~50 GB.
560/// - [`PositionSource::Subset`] holds a caller-supplied `&[u64]` of
561/// the positions to sort, in any order. Random-access by index.
562/// - [`PositionSource::Filtered`] is the streaming-predicate variant:
563/// the caller hands in a `Fn(u64) -> bool` over text positions and
564/// caps-sa walks the filter forward through the text on demand. A
565/// tiny prefix-sum (`text.len() / BLOCK_SIZE × 8` B ≈ 760 KB on
566/// the human genome at the 64 KB block size) lets `fill_chunk`
567/// locate the i-th kept position in `O(log n_blocks + BLOCK_SIZE +
568/// chunk_size)`. **No kept-positions list is materialised** — the
569/// memory saving over `Subset` is `≈ 8 × n_kept` bytes, dominant
570/// on genome-scale inputs.
571enum PositionSource<'a> {
572 Identity(usize),
573 Subset(&'a [u64]),
574 Filtered(FilteredSource),
575}
576
/// Number of `u64` bitmap words summarised by each entry of
/// `FilteredSource`'s popcount prefix sum.
///
/// 1024 words = 65 536 text positions per block. The prefix sum therefore
/// costs one `u64` per 64 Ki positions (~760 KB on the human genome). A
/// `fill_chunk` lookup skips through at most one block — 8 KB of bitmap,
/// cache-resident — before it starts emitting positions.
const FILTERED_WORDS_PER_BLOCK: usize = 1 << 10;
583
/// Position source backed by a **bitmap** of kept text positions plus a
/// per-block popcount prefix sum. No `Vec<u64>` of kept positions is ever
/// materialised.
///
/// Memory: `8 × ⌈n / 64⌉` bytes for the bitmap (≈ `n / 8`, ~770 MB on the
/// human genome at `n ≈ 6.2 B`) plus `8 × (n_blocks + 1)` bytes for the
/// prefix sum (~760 KB). This avoids entirely the ~47 GB `Vec<u64>` that
/// the [`Subset`][PositionSource::Subset] variant needs.
///
/// Lookup path inside [`fill_chunk`][Self::fill_chunk]:
/// 1. `partition_point` over the prefix sum finds the block containing
///    the `start`-th kept position (`O(log n_blocks)` ≈ 17 steps).
/// 2. Whole bitmap words in that block are skipped by popcount until the
///    remaining skip count falls inside one word. That word's lowest
///    set bits are then cleared.
/// 3. Words are walked with `trailing_zeros`, emitting one position per
///    set bit into `dst`. The cost is one iteration per emitted position
///    plus one load per bitmap word spanned, with no per-position
///    predicate calls.
///
/// A truly `O(1)` select-1 structure (darray / Elias-Fano) on top of the
/// bitmap would shrink steps (1)+(2) further. With `chunk_size` ≈ 750 K
/// and only `p` ≈ 8192 `fill_chunk` calls per build, though, the
/// `O(log n_blocks)` search plus at most one block walk is already a
/// rounding error. See `bench/README.md` for the follow-up note if that
/// ever changes.
struct FilteredSource {
    /// Length of the text the bitmap covers.
    text_len: usize,
    /// Total number of set bits in `bitmap`, i.e. the number of kept
    /// positions.
    total_kept: usize,
    /// `(bitmap[w] >> b) & 1 == 1` iff text position `64 * w + b` is
    /// kept. Length = `text_len.div_ceil(64)`; bits at or past
    /// `text_len` in the final word are 0.
    bitmap: Vec<u64>,
    /// `cumsum[i]` is the number of set bits in the first
    /// `i * FILTERED_WORDS_PER_BLOCK` bitmap words (capped at the bitmap
    /// length). Length = `n_blocks + 1`; `cumsum[0] == 0` and the last
    /// entry equals `total_kept`.
    cumsum: Vec<u64>,
}
620
impl FilteredSource {
    /// Build a [`FilteredSource`] by evaluating `keep` once for every
    /// position in `0..text_len` to fill the bitmap, then accumulating
    /// per-block popcounts into the prefix sum.
    ///
    /// The predicate is invoked exactly `text_len` times, all of them here
    /// and in parallel on rayon workers. Once the bitmap is built,
    /// `fill_chunk` never calls it again. One full predicate pass is a
    /// clear win when, as in caps-sa's phase 1, every position is
    /// afterwards read at least once.
    fn new<Pred>(text_len: usize, keep: Pred) -> Self
    where
        Pred: Fn(u64) -> bool + Send + Sync,
    {
        let n_words = text_len.div_ceil(64);
        // Parallel per-word bitmap build. Word `w` covers text positions
        // `64 * w .. min(64 * (w + 1), text_len)`; bits past `text_len`
        // in the final word stay 0.
        let bitmap: Vec<u64> = (0..n_words)
            .into_par_iter()
            .map(|w| {
                let mut word: u64 = 0;
                let base = (w as u64) * 64;
                // Number of valid positions in this word: 64, except
                // possibly for the last word.
                let limit = ((w + 1) * 64).min(text_len) - w * 64;
                for b in 0..limit {
                    if keep(base + b as u64) {
                        word |= 1u64 << b;
                    }
                }
                word
            })
            .collect();

        // Per-block popcounts. Each block covers
        // `FILTERED_WORDS_PER_BLOCK` words = `64 * FILTERED_WORDS_PER_BLOCK`
        // text positions; the last block may be shorter.
        let n_blocks = n_words.div_ceil(FILTERED_WORDS_PER_BLOCK);
        let per_block: Vec<u64> = (0..n_blocks)
            .into_par_iter()
            .map(|i| {
                let start = i * FILTERED_WORDS_PER_BLOCK;
                let end = ((i + 1) * FILTERED_WORDS_PER_BLOCK).min(n_words);
                let mut c: u64 = 0;
                for &word in &bitmap[start..end] {
                    c += word.count_ones() as u64;
                }
                c
            })
            .collect();
        // Exclusive prefix sum with a trailing total: `cumsum[i]` counts
        // the set bits in blocks `0..i`, and `cumsum[n_blocks]` is the
        // grand total.
        let mut cumsum = Vec::with_capacity(n_blocks + 1);
        let mut s: u64 = 0;
        cumsum.push(0);
        for &k in &per_block {
            s += k;
            cumsum.push(s);
        }
        let total_kept = s as usize;
        Self {
            text_len,
            total_kept,
            bitmap,
            cumsum,
        }
    }

    /// Number of kept positions.
    #[inline]
    fn len(&self) -> usize {
        self.total_kept
    }

    /// Fill `dst` with the next `dst.len()` kept positions, in increasing
    /// text order, starting from the `start`-th (0-based) kept position.
    /// This is the hot path of phase 1's per-subarray fill.
    ///
    /// Caller contract: `start + dst.len() <= self.len()`. It is checked
    /// only by a `debug_assert!`; in release builds a violation is not
    /// diagnosed.
    fn fill_chunk<I: Index>(&self, start: usize, dst: &mut [I]) {
        debug_assert!(start + dst.len() <= self.total_kept);
        if dst.is_empty() {
            return;
        }

        // (1) Locate the block containing the `start`-th set bit.
        // `partition_point(|c| c <= start)` gives the first cumsum entry
        // strictly greater than `start`; the index before it is the
        // containing block. Empty blocks repeat the previous cumsum value,
        // so they are never selected.
        let pp = self.cumsum.partition_point(|&c| c <= start as u64);
        debug_assert!(pp > 0);
        let block_idx = pp - 1;
        let mut word_idx = block_idx * FILTERED_WORDS_PER_BLOCK;
        // Set bits inside this block that precede the `start`-th one.
        let mut skip = start as u64 - self.cumsum[block_idx];

        // (2) Skip the first `skip` set bits — possibly spanning several
        // bitmap words. Whole words with `popcount ≤ skip` are consumed
        // wholesale. The final partial word has its lowest `skip` set bits
        // cleared, so the emit loop sees only un-skipped 1s.
        //
        // Note: a naive `while skip >= 64 { … }` is wrong because a word's
        // popcount can be far less than 64. Each iteration must subtract
        // the actual popcount, not 64. This matters any time the bitmap is
        // sparser than ~50%.
        let n_words = self.bitmap.len();
        let mut word: u64 = if word_idx < n_words { self.bitmap[word_idx] } else { 0 };
        while skip > 0 {
            let pc = word.count_ones() as u64;
            if skip < pc {
                // Consume the `skip` lowest set bits inside the current
                // word (`word &= word - 1` clears the lowest set bit); the
                // emit loop continues from the remaining ones.
                for _ in 0..skip {
                    word &= word - 1;
                }
                break;
            }
            // skip ≥ pc: consume the whole word and advance.
            skip -= pc;
            word_idx += 1;
            word = if word_idx < n_words { self.bitmap[word_idx] } else { 0 };
        }

        // (3) Walk `word` and the words after it, emitting one position
        // per set bit. `trailing_zeros` jumps straight to the next 1 inside
        // a word, which `word &= word - 1` then clears.
        let mut written = 0usize;
        let need = dst.len();
        loop {
            while word != 0 && written < need {
                let bit = word.trailing_zeros() as u64;
                let pos = (word_idx as u64) * 64 + bit;
                debug_assert!((pos as usize) < self.text_len);
                dst[written] = I::from_usize(pos as usize);
                written += 1;
                word &= word - 1;
            }
            if written == need {
                break;
            }
            word_idx += 1;
            debug_assert!(
                word_idx < n_words,
                "FilteredSource::fill_chunk: walked past bitmap end \
                 ({written}/{need} emitted, word_idx={word_idx}, n_words={n_words})"
            );
            word = self.bitmap[word_idx];
        }
    }
}
767
768impl<'a> PositionSource<'a> {
769 fn len(&self) -> usize {
770 match self {
771 Self::Identity(n) => *n,
772 Self::Subset(p) => p.len(),
773 Self::Filtered(f) => f.len(),
774 }
775 }
776
777 /// Fill `dst` with positions for the half-open subarray range
778 /// `[start, start + dst.len())`, narrowing the caller's `u64`
779 /// positions into `I` via [`Index::from_usize`]. For
780 /// [`PositionSource::Identity`] this generates the contiguous
781 /// integer range on the fly; for [`PositionSource::Subset`] it
782 /// reads from the caller's slice; for [`PositionSource::Filtered`]
783 /// it walks the predicate forward from the right text block.
784 fn fill_chunk<I: Index>(&self, start: usize, dst: &mut [I]) {
785 match self {
786 Self::Identity(_) => {
787 for (i, slot) in dst.iter_mut().enumerate() {
788 *slot = I::from_usize(start + i);
789 }
790 }
791 Self::Subset(p) => {
792 let end = start + dst.len();
793 for (slot, &v) in dst.iter_mut().zip(p[start..end].iter()) {
794 *slot = I::from_usize(v as usize);
795 }
796 }
797 Self::Filtered(f) => f.fill_chunk(start, dst),
798 }
799 }
800}
801
/// Target subarray size used by `effective_subproblem_count` when
/// auto-picking `p`. A smaller value means more, smaller subarrays. That
/// lowers per-task phase-1 scratch, at the cost of more phase-3
/// distribute work (`O(p² · log(n/p))` in total) and more logical
/// buckets, each with its own in-memory buffer.
const PHASE1_TARGET_CHUNK: usize = 65_536;
/// Hard cap on the number of auto-picked subarrays. It matches upstream
/// CaPS-SA's default of 8192.
///
/// Phase 3 distributes each subarray independently across rayon workers
/// into per-partition `Mutex`-guarded buckets, so its `O(p²)` total work
/// is no longer a sequential bottleneck. The cap still bounds the number
/// of logical buckets (`2 × p`), and with it the per-bucket buffer
/// overhead.
const PHASE1_MAX_PARTITIONS: usize = 8192;
815
816/// Resolve [`ExtMemOpts::physical_file_count`] for the current build.
817/// `0` (the default) means "let the runtime decide"; we pick
818/// `rayon::current_num_threads()` so the pool has one inode per
819/// concurrent writer, which empirically matches per-bucket-file wall
820/// time while collapsing thousands of small files into dozens of
821/// large ones. The `CAPS_SA_N_PHYS` env var overrides at the call
822/// site for benchmarks.
823fn effective_physical_file_count(requested: usize) -> usize {
824 if let Some(v) = std::env::var("CAPS_SA_N_PHYS")
825 .ok()
826 .and_then(|s| s.parse::<usize>().ok())
827 .filter(|&v| v >= 1)
828 {
829 return v;
830 }
831 if requested >= 1 {
832 return requested;
833 }
834 rayon::current_num_threads().max(1)
835}
836
837fn effective_subproblem_count(n: usize, requested: usize) -> usize {
838 if n == 0 {
839 return 0;
840 }
841 let raw = if requested == 0 {
842 let nthreads = rayon::current_num_threads().max(1);
843 let p_from_size = n.div_ceil(PHASE1_TARGET_CHUNK);
844 // At least one chunk per thread (otherwise we leave cores idle),
845 // at most `PHASE1_MAX_PARTITIONS` (so phase 3's sequential
846 // `O(p²)` sweep and the temp-file count stay manageable). For
847 // small inputs `p_from_size` is well below the cap, so the
848 // formula degrades gracefully to roughly "one chunk per thread";
849 // for human-scale inputs the cap binds and per-task scratch
850 // stays in the tens-of-MB range.
851 p_from_size.clamp(nthreads, PHASE1_MAX_PARTITIONS)
852 } else {
853 requested
854 };
855 raw.clamp(1, n)
856}
857
/// Phase 1: sort each subarray in parallel, sample from it, and spill its
/// `(position, lcp)` records into its own bucket.
///
/// Subarray `i` covers source indices
/// `[i * ⌈n / p⌉, (i + 1) * ⌈n / p⌉)`, clamped to `n`, so trailing
/// subarrays may be empty. For each non-empty subarray:
/// - positions are fetched with `source.fill_chunk`;
/// - they are sorted with `sample_sort::merge_sort`, which also fills the
///   LCP array;
/// - the records are written to the bucket made by `mk_bucket(i)`;
/// - evenly-spaced sorted positions are kept as pivot-selection samples.
///
/// Returns the `p` buckets in subarray order plus all samples
/// concatenated in subarray order. Any bucket-write error is returned.
///
/// One rayon task runs per subarray. Rayon's work-stealing scheduler keeps
/// all worker threads busy and lets `merge_sort`'s inner parallel recursion
/// use idle slots. Each task allocates four `len`-sized scratch arrays
/// plus a `len`-sized record buffer, and buffers are not reused across
/// subarrays. Peak scratch is therefore about
/// `num_threads × per_task_scratch`. With the auto-picked `p` this stays
/// in the tens of MiB per task on human-scale inputs.
// Seven parameters are the phase's natural inputs; bundling them into a
// struct would only add indirection for a single private call site pair.
#[allow(clippy::too_many_arguments)]
fn phase1_sort_sample_spill<S, I, L, B, MkB>(
    text: &[S],
    lp: &L,
    source: &PositionSource<'_>,
    p: usize,
    opts: &ExtMemOpts,
    dispatch: LcpDispatch,
    mk_bucket: MkB,
) -> io::Result<(Vec<B>, Vec<I>)>
where
    S: Symbol,
    I: Index,
    L: LimitProvider,
    SaLcp<I>: BucketRecord,
    B: BucketStore<SaLcp<I>> + Send,
    MkB: Fn(usize) -> B + Send + Sync,
{
    let n = source.len();
    let chunk_size = n.div_ceil(p);
    let samples_target_total = sample_target_total(n, p);

    let per_subarray: Vec<(B, Vec<I>)> = (0..p)
        .into_par_iter()
        .map(|i| {
            // Half-open source-index range of subarray `i`.
            let start = (i * chunk_size).min(n);
            let end = ((i + 1) * chunk_size).min(n);
            let len = end - start;

            // Every subarray gets a bucket, even an empty one, so the
            // returned Vec always has exactly `p` entries.
            let mut bucket = mk_bucket(i);
            if len == 0 {
                return Ok::<_, io::Error>((bucket, Vec::new()));
            }

            // In-memory sort of this subarray with LCP maintenance.
            let mut sa: Vec<I> = vec![I::zero(); len];
            source.fill_chunk(start, &mut sa);
            let mut sa_w = vec![I::zero(); len];
            let mut lcp_arr = vec![I::zero(); len];
            let mut lcp_w = vec![I::zero(); len];
            sample_sort::merge_sort(
                text,
                lp,
                &mut sa,
                &mut sa_w,
                &mut lcp_arr,
                &mut lcp_w,
                opts.max_context,
                dispatch,
            );

            // Pull `samples_per_subarray` evenly-spaced positions out of
            // the now-sorted subarray. Deterministic — no RNG is needed
            // for pivot selection to be globally well-distributed.
            let samples_per_subarray = samples_target_total.div_ceil(p).min(len);
            let samples = evenly_spaced(&sa, samples_per_subarray);

            // Spill (position, lcp) records to the bucket. `lcp[0]`
            // remains 0 (set by the merge-sort base case), making each
            // subarray its own well-formed LCP-annotated sorted run.
            let records: Vec<SaLcp<I>> = sa
                .iter()
                .zip(lcp_arr.iter())
                .map(|(&pos, &lcp)| SaLcp { pos, lcp })
                .collect();
            bucket.add_slice(&records)?;

            Ok((bucket, samples))
        })
        .collect::<Result<Vec<_>, _>>()?;

    // Unzip into the bucket list and the pooled sample list.
    let mut buckets = Vec::with_capacity(p);
    let mut all_samples = Vec::with_capacity(samples_target_total);
    for (bucket, samples) in per_subarray {
        buckets.push(bucket);
        all_samples.extend(samples);
    }
    Ok((buckets, all_samples))
}
946
/// Target sample count *across all subarrays*.
///
/// Matches upstream CaPS-SA's "`c · ln n`" rule per subarray with `c = 4`,
/// so the global pool is `p · ⌈4 · ln n⌉` samples. `ln n` is floored at 1,
/// so tiny inputs still get a few samples per subarray.
///
/// The result is raised to at least `p`, which leaves enough samples to pick
/// `p − 1` pivots. It is then capped at `n`, because there cannot be more
/// samples than positions. When `p > n` the cap takes precedence and the
/// result is `n`; in particular `n == 0` yields 0.
fn sample_target_total(n: usize, p: usize) -> usize {
    let ln_n = (n as f64).ln().max(1.0);
    let per = (4.0 * ln_n).ceil() as usize;
    // `max` then `min` instead of `clamp(p, n)`: `clamp` panics when
    // `p > n`, whereas here the upper bound `n` simply wins.
    p.saturating_mul(per).max(p).min(n)
}
956
/// Return `count` elements of `xs` taken at evenly-spaced midpoints.
///
/// Element `i` (for `i` in `0..count`) is read from index
/// `⌊(2i + 1) · len / (2 · count)⌋`, the midpoint of the `i`-th of `count`
/// equal slices. Neither endpoint is ever chosen, which keeps pivots away
/// from the extreme corners of the order. If `count` is at least the slice
/// length, the whole slice is returned. An empty slice or a zero `count`
/// yields an empty vector. No RNG is involved, so results are reproducible.
fn evenly_spaced<T: Copy>(xs: &[T], count: usize) -> Vec<T> {
    match (xs.len(), count) {
        (0, _) | (_, 0) => Vec::new(),
        (len, want) if want >= len => xs.to_vec(),
        (len, want) => {
            // Odd numerators 1, 3, …, 2·want − 1 over a shared denominator.
            let denom = 2 * want;
            (1..denom)
                .step_by(2)
                .map(|odd| xs[odd * len / denom])
                .collect()
        }
    }
}
974
975/// Phase 2: globally sort the pooled samples and pick `p − 1` pivots at
976/// evenly-spaced ranks.
977fn phase2_select_pivots<S, I, L>(
978 text: &[S],
979 lp: &L,
980 mut samples: Vec<I>,
981 p: usize,
982 max_ctx: usize,
983 dispatch: LcpDispatch,
984) -> Vec<I>
985where
986 S: Symbol,
987 I: Index,
988 L: LimitProvider,
989{
990 if p <= 1 || samples.is_empty() {
991 return Vec::new();
992 }
993 let n_samples = samples.len();
994 let mut sa_w = vec![I::zero(); n_samples];
995 let mut lcp = vec![I::zero(); n_samples];
996 let mut lcp_w = vec![I::zero(); n_samples];
997 sample_sort::merge_sort(
998 text,
999 lp,
1000 &mut samples,
1001 &mut sa_w,
1002 &mut lcp,
1003 &mut lcp_w,
1004 max_ctx,
1005 dispatch,
1006 );
1007
1008 // p-1 pivots at evenly-spaced ranks across the sorted sample pool.
1009 (1..p).map(|j| samples[(j * n_samples) / p]).collect()
1010}
1011
1012/// Phase 3: walk each subarray *in parallel*, load it into RAM,
1013/// binary-search the pivots to find its `p` sub-subarray boundaries,
1014/// and append each sub-subarray to the corresponding partition bucket.
1015///
1016/// Partition buckets are wrapped in a [`Mutex`] each so multiple
1017/// threads can write to different partitions concurrently without
1018/// shard-merging afterwards. With `p` in the thousands and `T` in the
1019/// tens, lock contention is negligible (probability that two threads
1020/// want the same partition at the same instant is `~T/p`); the lock
1021/// scope per acquisition is one `add_slice` + `mark_boundary` of a
1022/// few-KB sub-subarray.
1023///
1024/// Phase 4 doesn't care about the relative order of sub-subarrays
1025/// within a partition — only that each one between consecutive
1026/// boundaries is internally sorted. Both properties hold under
1027/// arbitrary thread interleaving.
1028#[allow(clippy::too_many_arguments)]
1029fn phase3_distribute<S, I, L, B, MkB>(
1030 text: &[S],
1031 lp: &L,
1032 subarray_buckets: &mut [B],
1033 pivots: &[I],
1034 p: usize,
1035 opts: &ExtMemOpts,
1036 dispatch: LcpDispatch,
1037 mk_bucket: MkB,
1038) -> io::Result<Vec<B>>
1039where
1040 S: Symbol,
1041 I: Index,
1042 L: LimitProvider,
1043 SaLcp<I>: BucketRecord,
1044 B: BucketStore<SaLcp<I>> + Send,
1045 MkB: Fn(usize) -> B + Send + Sync,
1046{
1047 let _ = opts; // work_dir is used only by the ext-mem factory closure now
1048 let partition_buckets: Vec<Mutex<B>> = (0..p).map(|j| Mutex::new(mk_bucket(j))).collect();
1049
1050 subarray_buckets
1051 .par_iter_mut()
1052 .try_for_each(|sub_bucket| -> io::Result<()> {
1053 if sub_bucket.total_records() == 0 {
1054 return Ok(());
1055 }
1056 let records = sub_bucket.load_all()?;
1057
1058 // Find p-1 split points by binary-searching each pivot's
1059 // *upper bound* in the sorted subarray.
1060 let mut splits = Vec::with_capacity(p + 1);
1061 splits.push(0usize);
1062 for &pivot in pivots {
1063 splits.push(upper_bound_by_pivot(
1064 &records,
1065 pivot,
1066 text,
1067 lp,
1068 opts.max_context,
1069 dispatch,
1070 ));
1071 }
1072 splits.push(records.len());
1073
1074 // Distribute each sub-subarray. Reset the first record's
1075 // `lcp` to 0 so the per-partition merge sees a well-formed
1076 // boundary.
1077 for j in 0..p {
1078 let lo = splits[j];
1079 let hi = splits[j + 1];
1080 if lo >= hi {
1081 continue;
1082 }
1083 let mut sub: Vec<SaLcp<I>> = records[lo..hi].to_vec();
1084 sub[0].lcp = I::zero();
1085 let mut bucket = partition_buckets[j].lock().unwrap();
1086 bucket.add_slice(&sub)?;
1087 bucket.mark_boundary();
1088 }
1089 Ok(())
1090 })?;
1091
1092 // Unwrap the Mutexes — at this point only this thread holds
1093 // references, so the locks are uncontended.
1094 Ok(partition_buckets
1095 .into_iter()
1096 .map(|m| m.into_inner().expect("partition mutex poisoned"))
1097 .collect())
1098}
1099
1100/// Upper-bound binary search: returns the first index `i` such that the
1101/// suffix at `records[i].pos` is **strictly greater than** the suffix at
1102/// `pivot`.
1103fn upper_bound_by_pivot<S, I, L>(
1104 records: &[SaLcp<I>],
1105 pivot: I,
1106 text: &[S],
1107 lp: &L,
1108 max_ctx: usize,
1109 dispatch: LcpDispatch,
1110) -> usize
1111where
1112 S: Symbol,
1113 I: Index,
1114 L: LimitProvider,
1115{
1116 let mut lo = 0;
1117 let mut hi = records.len();
1118 while lo < hi {
1119 let mid = lo + (hi - lo) / 2;
1120 match dispatch.suffix_cmp_with(
1121 text,
1122 lp,
1123 records[mid].pos.to_usize(),
1124 pivot.to_usize(),
1125 max_ctx,
1126 ) {
1127 Ordering::Greater => hi = mid,
1128 Ordering::Equal | Ordering::Less => lo = mid + 1,
1129 }
1130 }
1131 lo
1132}
1133
1134/// Phase 4 + 5: parallel-merge partitions in chunks of `num_threads`,
1135/// emitting each chunk's results in lex order before starting the next.
1136///
1137/// Each worker thread holds its own [`CascadeWorkspace`] for the duration
1138/// of one partition merge. Within a chunk, all `T` workspaces live in
1139/// parallel; between chunks, they are dropped (so peak workspace memory
1140/// scales with `T`, not with the number of partitions). The merged result
1141/// for each partition is then drained sequentially via `emit` to preserve
1142/// streaming-output order without ever holding the full SA in RAM.
1143///
1144/// Peak transient RAM ≈ `T × max_partition_size × 16 bytes` for the
1145/// merged-result buffers, plus the workspaces themselves (~`2 × T ×
1146/// max_partition_size × 16` bytes). On a typical run with `p = 4 × T`
1147/// subarrays the per-partition size is `≈ n / p`, so this stays
1148/// proportional to `n / 4 = 0.25 n` even at the peak — well below the
1149/// in-memory path's `~4 n` working set.
1150fn phase4_merge_and_emit<S, I, L, B, F>(
1151 text: &[S],
1152 lp: &L,
1153 partition_buckets: &mut [B],
1154 max_ctx: usize,
1155 emit: &mut F,
1156 dispatch: LcpDispatch,
1157) -> io::Result<()>
1158where
1159 S: Symbol,
1160 I: Index,
1161 L: LimitProvider,
1162 SaLcp<I>: BucketRecord,
1163 B: BucketStore<SaLcp<I>> + Send,
1164 F: FnMut(u64) -> io::Result<()>,
1165{
1166 let n_partitions = partition_buckets.len();
1167 if n_partitions == 0 {
1168 return Ok(());
1169 }
1170 // `chunk_size = 4 × num_threads` (not `= num_threads`): with one
1171 // partition per thread per chunk, rayon's `par_iter_mut` assigns
1172 // 1-to-1 with no opportunity to steal, and the chunk's wall is
1173 // set by its slowest partition. Sample-sort partition sizes vary
1174 // ~2× from random sampling, so the slow tail leaves ~half the
1175 // cores idle waiting (observed: 52% parallel efficiency on
1176 // GRCh38 / 32 t).
1177 //
1178 // Bumping the chunk to `4 × num_threads` gives rayon four
1179 // partitions per thread to dispatch — fast threads can steal from
1180 // slow neighbours, smoothing out the size variance. Peak RAM
1181 // grows linearly: each in-flight merged partition holds its
1182 // result `Vec<I>` (~3 MB at human-genome scale with `u32`
1183 // indices), so the chunk's transient cost goes from `32 × 3 MB =
1184 // 96 MB` to `128 × 3 MB = 384 MB` — well within the budget we
1185 // already spend on phase 1.
1186 let chunk_size = rayon::current_num_threads().max(1) * 4;
1187
1188 // Per-thread CPU-µs accumulators for the two parallel sub-steps. They
1189 // add across threads, so the printed values are CPU-time (sum), not
1190 // wall-time; the ratio between them still tells us where the work is.
1191 use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
1192 let profile = std::env::var_os("CAPS_SA_PROFILE").is_some();
1193 let load_us = AtomicU64::new(0);
1194 let merge_us = AtomicU64::new(0);
1195 let mut emit_secs: f64 = 0.0;
1196
1197 let mut start = 0;
1198 while start < n_partitions {
1199 let end = (start + chunk_size).min(n_partitions);
1200 let chunk = &mut partition_buckets[start..end];
1201
1202 // Parallel-merge each non-empty bucket in this chunk. `par_iter_mut`
1203 // preserves index order in the collected `Vec`, so the subsequent
1204 // sequential emit yields positions in lex order.
1205 let merged: Vec<Vec<I>> = chunk
1206 .par_iter_mut()
1207 .map(|bucket| -> io::Result<Vec<I>> {
1208 if bucket.total_records() == 0 {
1209 return Ok(Vec::new());
1210 }
1211 let t = Instant::now();
1212 let records = bucket.load_all()?;
1213 let boundaries: Vec<usize> = bucket.boundaries().to_vec();
1214 if profile {
1215 load_us.fetch_add(t.elapsed().as_micros() as u64, AtomicOrdering::Relaxed);
1216 }
1217
1218 let t = Instant::now();
1219 let workspace = CascadeWorkspace::<I>::new();
1220 // `cascade_merge` consumes the workspace and returns
1221 // the result side directly — the other three buffers
1222 // drop along with `workspace` here, without an
1223 // intermediate `to_vec()` copy.
1224 let result =
1225 workspace.cascade_merge(text, lp, &records, &boundaries, max_ctx, dispatch);
1226 if profile {
1227 merge_us.fetch_add(t.elapsed().as_micros() as u64, AtomicOrdering::Relaxed);
1228 }
1229 Ok(result)
1230 })
1231 .collect::<Result<Vec<_>, io::Error>>()?;
1232
1233 let t = Instant::now();
1234 for positions in merged {
1235 for pos in positions {
1236 // Widen back to the public `u64` emit contract.
1237 emit(pos.to_usize() as u64)?;
1238 }
1239 }
1240 if profile {
1241 emit_secs += t.elapsed().as_secs_f64();
1242 }
1243
1244 start = end;
1245 }
1246 if profile {
1247 profile_log(&format!(
1248 "phase4 breakdown CPU: load {:.3}s merge {:.3}s; wall emit {:.3}s",
1249 load_us.load(AtomicOrdering::Relaxed) as f64 * 1e-6,
1250 merge_us.load(AtomicOrdering::Relaxed) as f64 * 1e-6,
1251 emit_secs,
1252 ));
1253 }
1254 Ok(())
1255}
1256
1257/// Reusable ping-pong scratch for the partition cascade merge.
1258///
1259/// Holds two `(sa, lcp)` buffers each sized to the largest partition seen.
1260/// The cascade alternates reads from one side and writes to the other,
1261/// flipping a `src_is_a` flag after each level. Avoids the
1262/// per-level allocations that the previous immutable-`Vec` cascade
1263/// performed for every pair of sub-subarrays.
1264struct CascadeWorkspace<I> {
1265 a_sa: Vec<I>,
1266 a_lcp: Vec<I>,
1267 b_sa: Vec<I>,
1268 b_lcp: Vec<I>,
1269}
1270
1271impl<I: Index> CascadeWorkspace<I> {
1272 fn new() -> Self {
1273 Self {
1274 a_sa: Vec::new(),
1275 a_lcp: Vec::new(),
1276 b_sa: Vec::new(),
1277 b_lcp: Vec::new(),
1278 }
1279 }
1280
1281 /// Grow all four buffers to at least `n` elements. The contents past
1282 /// the cascade's actual run lengths are don't-care.
1283 fn ensure_capacity(&mut self, n: usize) {
1284 if self.a_sa.len() < n {
1285 self.a_sa.resize(n, I::zero());
1286 self.a_lcp.resize(n, I::zero());
1287 self.b_sa.resize(n, I::zero());
1288 self.b_lcp.resize(n, I::zero());
1289 }
1290 }
1291
1292 /// Cascade 2-way LCP-enhanced merges across the sub-subarrays of one
1293 /// partition (delimited by `boundaries`) until a single sorted run
1294 /// remains. **Consumes the workspace** and returns the result side
1295 /// as a `Vec<u64>`; the other three buffers (`a_lcp`, the opposing
1296 /// `*_sa`, the opposing `*_lcp`) drop immediately. This shape lets
1297 /// the caller skip the per-partition `to_vec()` round-trip that
1298 /// would otherwise sit briefly alongside all four workspace buffers
1299 /// at peak.
1300 fn cascade_merge<S, L>(
1301 mut self,
1302 text: &[S],
1303 lp: &L,
1304 records: &[SaLcp<I>],
1305 boundaries: &[usize],
1306 max_ctx: usize,
1307 dispatch: LcpDispatch,
1308 ) -> Vec<I>
1309 where
1310 S: Symbol,
1311 L: LimitProvider,
1312 {
1313 let n = records.len();
1314 if n == 0 {
1315 return Vec::new();
1316 }
1317 self.ensure_capacity(n);
1318
1319 // Initialize side A in SOA form from the AOS `records`, and
1320 // collect the lengths of the non-empty sub-subarrays.
1321 let mut run_lens: Vec<usize> = boundaries
1322 .windows(2)
1323 .filter_map(|w| {
1324 let l = w[1] - w[0];
1325 if l > 0 { Some(l) } else { None }
1326 })
1327 .collect();
1328 for (i, r) in records.iter().enumerate() {
1329 self.a_sa[i] = r.pos;
1330 self.a_lcp[i] = r.lcp;
1331 }
1332
1333 let mut src_is_a = true;
1334 while run_lens.len() > 1 {
1335 run_lens = self.merge_one_level(src_is_a, &run_lens, text, lp, max_ctx, dispatch);
1336 src_is_a = !src_is_a;
1337 }
1338
1339 // Take ownership of the buffer holding the result, truncate to
1340 // the actual record count, drop the other three buffers with
1341 // `self` going out of scope.
1342 let mut result = if src_is_a { self.a_sa } else { self.b_sa };
1343 result.truncate(n);
1344 result
1345 }
1346
1347 /// Pair the runs in `run_lens` (last odd one passes through unchanged),
1348 /// running each pair through the LCP-enhanced 2-way merge from the
1349 /// `src_is_a`-selected buffer side into the other. Returns the new
1350 /// run-length list (each entry is the sum of the two it replaced, or
1351 /// the carry-over for an odd tail).
1352 fn merge_one_level<S, L>(
1353 &mut self,
1354 src_is_a: bool,
1355 run_lens: &[usize],
1356 text: &[S],
1357 lp: &L,
1358 max_ctx: usize,
1359 dispatch: LcpDispatch,
1360 ) -> Vec<usize>
1361 where
1362 S: Symbol,
1363 L: LimitProvider,
1364 {
1365 // Destructure self so the borrow checker can see the two sides as
1366 // disjoint locals — we borrow one immutably and the other mutably.
1367 let Self {
1368 a_sa,
1369 a_lcp,
1370 b_sa,
1371 b_lcp,
1372 } = self;
1373 let (src_sa, src_lcp, dst_sa, dst_lcp) = if src_is_a {
1374 (
1375 a_sa.as_slice(),
1376 a_lcp.as_slice(),
1377 b_sa.as_mut_slice(),
1378 b_lcp.as_mut_slice(),
1379 )
1380 } else {
1381 (
1382 b_sa.as_slice(),
1383 b_lcp.as_slice(),
1384 a_sa.as_mut_slice(),
1385 a_lcp.as_mut_slice(),
1386 )
1387 };
1388
1389 let mut new_lens = Vec::with_capacity(run_lens.len().div_ceil(2));
1390 let mut src_off = 0usize;
1391 let mut dst_off = 0usize;
1392 let mut i = 0;
1393 while i < run_lens.len() {
1394 let l1 = run_lens[i];
1395 if i + 1 < run_lens.len() {
1396 let l2 = run_lens[i + 1];
1397 let x_end = src_off + l1;
1398 let xy_end = x_end + l2;
1399 let dst_end = dst_off + l1 + l2;
1400 sample_sort::merge(
1401 text,
1402 lp,
1403 &src_sa[src_off..x_end],
1404 &src_sa[x_end..xy_end],
1405 &src_lcp[src_off..x_end],
1406 &src_lcp[x_end..xy_end],
1407 &mut dst_sa[dst_off..dst_end],
1408 &mut dst_lcp[dst_off..dst_end],
1409 max_ctx,
1410 dispatch,
1411 );
1412 new_lens.push(l1 + l2);
1413 src_off = xy_end;
1414 dst_off = dst_end;
1415 i += 2;
1416 } else {
1417 // Odd run carries over unchanged.
1418 let end = dst_off + l1;
1419 dst_sa[dst_off..end].copy_from_slice(&src_sa[src_off..src_off + l1]);
1420 dst_lcp[dst_off..end].copy_from_slice(&src_lcp[src_off..src_off + l1]);
1421 new_lens.push(l1);
1422 src_off += l1;
1423 dst_off = end;
1424 i += 1;
1425 }
1426 }
1427 new_lens
1428 }
1429}
1430
#[cfg(test)]
mod tests {
    use super::*;
    use crate::build_in_memory;
    use tempfile::tempdir;

    /// `ExtMemOpts` for a test build with `p` subproblems and scratch files
    /// under `dir`.
    fn opts_in(dir: &std::path::Path, p: usize) -> ExtMemOpts {
        ExtMemOpts {
            subproblem_count: p,
            work_dir: dir.to_path_buf(),
            ..ExtMemOpts::default()
        }
    }

    /// Reference SA from the in-memory builder, widened to the `u64` values
    /// that the emit callbacks produce.
    fn reference_sa(text: &[u8]) -> Vec<u64> {
        let want: Vec<u32> = build_in_memory(text);
        want.into_iter().map(u64::from).collect()
    }

    /// `positions` sorted by brute-force comparison of their suffixes.
    fn brute_force_sorted(text: &[u8], mut positions: Vec<u64>) -> Vec<u64> {
        positions.sort_by(|&a, &b| text[a as usize..].cmp(&text[b as usize..]));
        positions
    }

    fn ext_mem_sa(text: &[u8], p: usize) -> Vec<u64> {
        let dir = tempdir().unwrap();
        let opts = opts_in(dir.path(), p);
        let mut out: Vec<u64> = Vec::with_capacity(text.len());
        build_ext_mem(text, &opts, |pos| {
            out.push(pos);
            Ok(())
        })
        .unwrap();
        out
    }

    fn assert_matches_in_memory(text: &[u8], p: usize) {
        let want64 = reference_sa(text);
        let got = ext_mem_sa(text, p);
        assert_eq!(got, want64, "mismatch on text {text:?} with p={p}");
    }

    #[test]
    fn ext_mem_empty() {
        let got = ext_mem_sa(b"", 4);
        assert!(got.is_empty());
    }

    #[test]
    fn ext_mem_single_partition() {
        assert_matches_in_memory(b"banana", 1);
    }

    #[test]
    fn ext_mem_p_greater_than_n() {
        assert_matches_in_memory(b"abc", 10);
    }

    #[test]
    fn ext_mem_banana_p4() {
        assert_matches_in_memory(b"banana", 4);
    }

    #[test]
    fn ext_mem_mississippi_p3() {
        assert_matches_in_memory(b"mississippi", 3);
    }

    #[test]
    fn ext_mem_random_byte_texts() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xCAFE);
        for &n in &[16usize, 100, 1000, 5000] {
            for &p in &[1usize, 2, 4, 16] {
                let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
                assert_matches_in_memory(&text, p);
            }
        }
    }

    #[test]
    fn ext_mem_with_unique_terminator() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xF00D);
        for &n in &[10usize, 200, 2000] {
            for &p in &[1usize, 3, 8] {
                let mut text: Vec<u8> = (0..n).map(|_| rng.random_range(0..5u8)).collect();
                text.push(200);
                assert_matches_in_memory(&text, p);
            }
        }
    }

    fn ext_mem_for_positions(text: &[u8], positions: Vec<u64>, p: usize) -> Vec<u64> {
        let dir = tempdir().unwrap();
        let opts = opts_in(dir.path(), p);
        let mut out: Vec<u64> = Vec::with_capacity(positions.len());
        build_ext_mem_for_positions(text, positions, &opts, |pos| {
            out.push(pos);
            Ok(())
        })
        .unwrap();
        out
    }

    #[test]
    fn ext_mem_for_positions_full_set_matches_ext_mem() {
        let text = b"mississippi";
        let want = ext_mem_sa(text, 3);
        let positions: Vec<u64> = (0..text.len() as u64).collect();
        let got = ext_mem_for_positions(text, positions, 3);
        assert_eq!(got, want);
    }

    #[test]
    fn ext_mem_for_positions_subset_matches_brute_force() {
        let text = b"mississippi";
        let positions: Vec<u64> = (0..text.len() as u64).step_by(2).collect();
        let want = brute_force_sorted(text, positions.clone());
        let got = ext_mem_for_positions(text, positions, 4);
        assert_eq!(got, want);
    }

    #[test]
    fn ext_mem_for_positions_random_subsets() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xC0DE);
        for &n in &[50usize, 500, 2000] {
            for &p in &[1usize, 3, 8] {
                let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..5u8)).collect();
                let mut positions: Vec<u64> = (0..n as u64).collect();
                positions.retain(|_| rng.random_range(0..10) < 7);
                let want = brute_force_sorted(&text, positions.clone());
                let got = ext_mem_for_positions(&text, positions, p);
                assert_eq!(got, want, "subset ext-mem mismatch n={n} p={p}");
            }
        }
    }

    fn in_memory_sample_sort(text: &[u8], p: usize) -> Vec<u64> {
        let dir = tempdir().unwrap();
        let opts = opts_in(dir.path(), p);
        let mut out: Vec<u64> = Vec::with_capacity(text.len());
        build_in_memory_sample_sort(text, &opts, |pos| {
            out.push(pos);
            Ok(())
        })
        .unwrap();
        out
    }

    #[test]
    fn in_memory_sample_sort_matches_in_memory() {
        for text in [b"banana" as &[u8], b"mississippi", b"abracadabra"] {
            let want64 = reference_sa(text);
            let got = in_memory_sample_sort(text, 0);
            assert_eq!(got, want64, "in-mem sample-sort mismatch on {text:?}");
        }
    }

    #[test]
    fn in_memory_sample_sort_random_byte_texts() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xC0DE_C0DE);
        for &n in &[16usize, 200, 2000] {
            for &p in &[1usize, 4, 16] {
                let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
                let want64 = reference_sa(&text);
                let got = in_memory_sample_sort(&text, p);
                assert_eq!(got, want64, "in-mem ss mismatch n={n} p={p}");
            }
        }
    }

    /// Helper that drives [`build_ext_mem_for_filter`] and collects the
    /// emitted positions.
    fn ext_mem_for_filter<Pred>(text: &[u8], keep: Pred, p: usize) -> Vec<u64>
    where
        Pred: Fn(u64) -> bool + Send + Sync,
    {
        let dir = tempdir().unwrap();
        let opts = opts_in(dir.path(), p);
        let mut out: Vec<u64> = Vec::new();
        build_ext_mem_for_filter(text, keep, &opts, |pos| {
            out.push(pos);
            Ok(())
        })
        .unwrap();
        out
    }

    #[test]
    fn ext_mem_for_filter_matches_for_positions_on_full_set() {
        // A filter that accepts every position must equal the
        // identity-positions ext-mem build.
        let text = b"mississippi";
        let want = ext_mem_sa(text, 3);
        let got = ext_mem_for_filter(text, |_p| true, 3);
        assert_eq!(got, want);
    }

    #[test]
    fn ext_mem_for_filter_matches_for_positions_on_dna_subset() {
        // STAR-style "keep ACGT (`< 4`), drop N (`4`)/spacer (`5`)"
        // filter. The filter API must produce exactly the same SA as
        // pre-materialising the kept positions and going through the
        // _for_positions path.
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xCA75_5A);
        for &n in &[50usize, 500, 2000] {
            for &p in &[1usize, 3, 8] {
                let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
                let positions: Vec<u64> = (0..n as u64).filter(|&i| text[i as usize] < 4).collect();
                let want = ext_mem_for_positions(&text, positions, p);
                let got = ext_mem_for_filter(&text, |i| text[i as usize] < 4, p);
                assert_eq!(got, want, "filter vs positions mismatch n={n} p={p}");
            }
        }
    }

    #[test]
    fn ext_mem_for_filter_handles_block_aligned_boundaries() {
        // Exercise the bitmap's word and block boundaries. A popcount block
        // covers 1024 × 64 bits = 64 K positions, so a 200 K-position text
        // spans several blocks. The cumulative counts are therefore
        // consulted past the first block.
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xB10C_C0DE);
        let n = 200_000usize;
        let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
        let positions: Vec<u64> = (0..n as u64).filter(|&i| text[i as usize] < 4).collect();
        let want = ext_mem_for_positions(&text, positions, 8);
        let got = ext_mem_for_filter(&text, |i| text[i as usize] < 4, 8);
        assert_eq!(got, want, "filter API mismatch across block boundaries");
    }

    #[test]
    fn ext_mem_for_filter_sparse_predicate() {
        // ~5% acceptance — exercises long runs of zero-bits in the
        // bitmap (skip-loop across whole words).
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0x5_AA_55);
        let n = 50_000usize;
        let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..20u8)).collect();
        let positions: Vec<u64> = (0..n as u64).filter(|&i| text[i as usize] < 1).collect();
        let want = ext_mem_for_positions(&text, positions, 4);
        let got = ext_mem_for_filter(&text, |i| text[i as usize] < 1, 4);
        assert_eq!(got, want, "filter API mismatch on sparse predicate");
    }

    #[test]
    fn ext_mem_repetitive_does_not_blow_up() {
        // Many copies of a long repeat — what killed the Phase 2 v1
        // linear-scan merge. The sample-sort + LCP-enhanced cascade
        // should handle it in proportional time.
        use std::time::Instant;
        let unit = b"ACGTACGTACGTACGTACGTACGTACGT"; // 28 bases
        let mut text: Vec<u8> = Vec::new();
        for _ in 0..100 {
            text.extend_from_slice(unit);
        }
        text.push(200);
        let start = Instant::now();
        let got = ext_mem_sa(&text, 8);
        let elapsed = start.elapsed();
        let want64 = reference_sa(&text);
        assert_eq!(got, want64);
        // The build is expected to finish well under a second. The 2 s bound
        // leaves slack so slow or debug-mode CI runs don't flake.
        assert!(
            elapsed.as_secs() < 2,
            "ext-mem build on a tiny repetitive text took {elapsed:?}"
        );
    }
}