Skip to main content

caps_sa/
ext_mem.rs

1//! External-memory suffix array construction.
2//!
3//! Implements upstream CaPS-SA's *sample-sort + LCP-enhanced merge*
4//! external-memory algorithm:
5//!
6//! 1. **Sort + sample + spill.** Split positions `0..n` into `p` subarrays
7//!    of size `n / p`. In parallel, sort each with the in-memory
8//!    merge-sort kernel (`sample_sort::merge_sort`), sample ~`c · ln n`
9//!    positions uniformly from each sorted subarray, and spill the sorted
10//!    `(position, lcp)` records to a per-subarray
11//!    [`ExtMemBucket`][crate::ext_bucket::ExtMemBucket].
12//! 2. **Select pivots.** Globally sort the pooled samples and pick
13//!    `p − 1` pivots at evenly-spaced ranks, splitting the suffix order
14//!    into `p` partitions.
15//! 3. **Distribute.** For each sorted subarray, binary-search the pivots
16//!    to find its `p` sub-subarray split points; append each sub-subarray
17//!    to the corresponding *partition* bucket, marking a boundary after
18//!    each contribution.
19//! 4. **Per-partition merge.** Load each partition's bucket into RAM
20//!    (≈ `n / p` records); cascade 2-way LCP-enhanced merges across the
21//!    `p` sub-subarrays to produce that partition's globally-sorted slice.
22//! 5. **Stream output.** Iterate partitions in order and emit each
23//!    position via the caller's closure.
24//!
//! Peak RAM ≈ `text` + `O(n / p)` per active worker (one partition's
//! merge working set). When `subproblem_count` is left at `0`, `p` is
//! auto-picked as `⌈n / 65 536⌉`, raised to at least the rayon thread
//! count and capped at 8192. With `num_threads = 8` that's a few hundred
//! MB on a `n = 6e9` (human-scale) input — well below in-memory's
//! `~4 × n × 8 = ~200 GB`. The full SA is never materialized in RAM.
30
31use std::cmp::Ordering;
32use std::io;
33use std::path::{Path, PathBuf};
34use std::sync::Mutex;
35use std::time::Instant;
36
37use rayon::prelude::*;
38
39use crate::Index;
40use crate::ext_bucket::{BucketPool, BucketRecord, BucketStore, InMemBucket, SaLcp};
41use crate::lcp::LcpDispatch;
42use crate::sample_sort;
43
/// Write one phase-timing line to stderr, but only when the
/// `CAPS_SA_PROFILE` environment variable is present (any value, even
/// empty). With the variable unset the call prints nothing, so the
/// ext-mem path can stay instrumented without always logging — see
/// `bench/README.md` "Where AVX-512 helps and where it doesn't" for
/// how the output is used.
///
/// Each line is prefixed with `caps-sa profile  ` so it is easy to
/// filter out of a run's stderr.
fn profile_log(message: &str) {
    let enabled = std::env::var_os("CAPS_SA_PROFILE").is_some();
    if !enabled {
        return;
    }
    eprintln!("caps-sa profile  {message}");
}
54
/// Tunable options for [`build_ext_mem`] (also accepted by the
/// in-memory sample-sort entry points in this module).
#[derive(Clone, Debug)]
pub struct ExtMemOpts {
    /// Bound on LCP-extension comparisons inside merges. `usize::MAX`
    /// (default) is unbounded.
    pub max_context: usize,
    /// Number of subarrays (`p` in upstream CaPS-SA). `0` (default)
    /// auto-picks `⌈n / PHASE1_TARGET_CHUNK⌉`, raised to at least
    /// `rayon::current_num_threads()` and capped at
    /// `PHASE1_MAX_PARTITIONS`. A non-zero value is used as given.
    /// Either way the result is finally clamped to `[1, n]`.
    pub subproblem_count: usize,
    /// Directory for temp files. Defaults to [`std::env::temp_dir`].
    pub work_dir: PathBuf,
    /// Number of physical files in the bucket pool (one pool for the
    /// phase-1 subarray buckets and a second for the phase-3 partition
    /// buckets). `0` (default) picks `rayon::current_num_threads()` —
    /// the right answer in practice: one writable inode per worker
    /// keeps kernel-level write contention bounded.
    ///
    /// The `2 × p` logical buckets (typically thousands at genome
    /// scale) collapse onto this pool of anonymous tempfiles via
    /// `bucket_id % physical_file_count`. Larger values lower kernel
    /// write contention; smaller values are kinder to networked
    /// filesystems with high metadata cost. The `CAPS_SA_N_PHYS` env
    /// var, when set to an integer `>= 1`, takes precedence over this
    /// field (intended for one-off benches).
    pub physical_file_count: usize,
}
80
81impl Default for ExtMemOpts {
82    fn default() -> Self {
83        Self {
84            max_context: usize::MAX,
85            subproblem_count: 0,
86            work_dir: std::env::temp_dir(),
87            physical_file_count: 0,
88        }
89    }
90}
91
92impl ExtMemOpts {
93    /// Convenience constructor with the supplied `work_dir` and defaults
94    /// for everything else.
95    pub fn with_work_dir(work_dir: impl AsRef<Path>) -> Self {
96        Self {
97            work_dir: work_dir.as_ref().to_path_buf(),
98            ..Self::default()
99        }
100    }
101}
102
103/// Build the suffix array of `text` with bounded RAM, streaming each
104/// output position to `emit` in lexicographic order.
105///
106/// Returns an [`io::Error`] if temp-file I/O fails. The callback may also
107/// return an error to abort construction; partial work is discarded and
108/// temp files are cleaned up when their bucket drops.
109///
110/// Equivalent in semantics to [`crate::build_in_memory`]: produces a
111/// standard lexicographic suffix array with the "shorter suffix is
112/// smaller when one runs off the end" tie-break.
113pub fn build_ext_mem<S, F>(text: &[S], opts: &ExtMemOpts, emit: F) -> io::Result<()>
114where
115    S: Ord + Copy + Sync + 'static,
116    F: FnMut(u64) -> io::Result<()>,
117{
118    // Dispatch on text size: when every suffix position fits in `u32`
119    // (n ≤ 2^32), use the narrow record type to halve all the SaLcp
120    // bytes — bucket disk I/O, phase-1 records, and the per-partition
121    // load in phase 4.
122    if text.len() <= u32::MAX as usize + 1 {
123        build_ext_mem_inner::<S, u32, F>(text, PositionSource::Identity(text.len()), opts, emit)
124    } else {
125        build_ext_mem_inner::<S, u64, F>(text, PositionSource::Identity(text.len()), opts, emit)
126    }
127}
128
129/// Like [`build_ext_mem`] but sorts only the caller-supplied
130/// `positions` by the lexicographic order of their suffixes in `text`.
131/// Suffix content is always `text[position..]`; no positions are
132/// dropped from the input. To filter, the caller constructs
133/// `positions` with only the indices they want.
134///
135/// This lets STAR-style genome indexing skip the bin-padding
136/// pathology: pass only the ACGT-starting positions and the
137/// spacer-suffix work disappears.
138///
139/// `positions` does not need to be pre-sorted in any way.
140pub fn build_ext_mem_for_positions<S, F>(
141    text: &[S],
142    positions: Vec<u64>,
143    opts: &ExtMemOpts,
144    emit: F,
145) -> io::Result<()>
146where
147    S: Ord + Copy + Sync + 'static,
148    F: FnMut(u64) -> io::Result<()>,
149{
150    // We hold a reference to `positions` for the duration of the build;
151    // `phase1_sort_sample_spill` copies each chunk out. The Vec is
152    // dropped once phase 1 returns.
153    if text.len() <= u32::MAX as usize + 1 {
154        build_ext_mem_inner::<S, u32, F>(text, PositionSource::Subset(&positions), opts, emit)
155    } else {
156        build_ext_mem_inner::<S, u64, F>(text, PositionSource::Subset(&positions), opts, emit)
157    }
158}
159
160fn build_ext_mem_inner<S, I, F>(
161    text: &[S],
162    source: PositionSource<'_>,
163    opts: &ExtMemOpts,
164    mut emit: F,
165) -> io::Result<()>
166where
167    S: Ord + Copy + Sync + 'static,
168    I: Index,
169    SaLcp<I>: BucketRecord,
170    F: FnMut(u64) -> io::Result<()>,
171{
172    let n = source.len();
173    if n == 0 {
174        return Ok(());
175    }
176    let p = effective_subproblem_count(n, opts.subproblem_count);
177    let dispatch = LcpDispatch::detect();
178    let work_dir = opts.work_dir.clone();
179
180    // Pool the `2 × p` bucket files into one anonymous tempfile per
181    // worker thread. With `p` in the thousands and `num_threads` in
182    // the dozens this collapses the openat/close/unlink budget from
183    // ~3·p (per-bucket-file path) to N (per-worker-file pool),
184    // eliminating the metadata-syscall pain on networked filesystems
185    // and the open-file-handle limit headache on tiny inputs. Per-
186    // bucket in-memory buffers and write volumes are unchanged, so
187    // local-disk wall time is neutral or marginally improved. See
188    // `bench/README.md` for the empirical sizing.
189    let n_phys = effective_physical_file_count(opts.physical_file_count);
190    let phase1_pool = BucketPool::new(n_phys, &work_dir)?;
191    let phase3_pool = BucketPool::new(n_phys, &work_dir)?;
192
193    profile_log(&format!(
194        "build_ext_mem n={n} p={p} index_width={}b n_phys={n_phys}",
195        std::mem::size_of::<I>() * 8
196    ));
197
198    let sub_factory = |i: usize| phase1_pool.new_bucket::<SaLcp<I>>(i);
199    let part_factory = |j: usize| phase3_pool.new_bucket::<SaLcp<I>>(j);
200
201    let t = Instant::now();
202    let (mut subarray_buckets, samples) =
203        phase1_sort_sample_spill::<S, I, _, _>(text, &source, p, opts, dispatch, sub_factory)?;
204    profile_log(&format!(
205        "phase1 (sort+sample+spill) {:.3}s",
206        t.elapsed().as_secs_f64()
207    ));
208
209    let t = Instant::now();
210    let pivots = phase2_select_pivots::<S, I>(text, samples, p, opts.max_context, dispatch);
211    profile_log(&format!(
212        "phase2 (select pivots)      {:.3}s",
213        t.elapsed().as_secs_f64()
214    ));
215
216    let t = Instant::now();
217    let mut partition_buckets = phase3_distribute::<S, I, _, _>(
218        text,
219        &mut subarray_buckets,
220        &pivots,
221        p,
222        opts,
223        dispatch,
224        part_factory,
225    )?;
226    profile_log(&format!(
227        "phase3 (distribute)          {:.3}s",
228        t.elapsed().as_secs_f64()
229    ));
230
231    drop(subarray_buckets);
232
233    let t = Instant::now();
234    let result = phase4_merge_and_emit::<S, I, _, F>(
235        text,
236        &mut partition_buckets,
237        opts.max_context,
238        &mut emit,
239        dispatch,
240    );
241    profile_log(&format!(
242        "phase4 (merge+emit)          {:.3}s",
243        t.elapsed().as_secs_f64()
244    ));
245    result
246}
247
248/// Same algorithm as [`build_ext_mem_inner`] but with the disk-backed
249/// [`ExtMemBucket`] replaced by [`InMemBucket`] throughout — phase 1
250/// sorts each subarray and keeps the result in a `Vec<SaLcp<I>>`,
251/// phase 3 distributes into in-RAM partition Vecs, phase 4 cascade-
252/// merges the in-RAM partitions. No disk I/O.
253///
254/// Trades RAM for wall time: peak memory is ~`n × sizeof(SaLcp<I>)`
255/// (the post-phase-1 records sitting around until phase 3 consumes
256/// them), so ~25 GB on the human genome with `I = u32`. In exchange,
257/// the disk-spill / distribute-write / partition-load round-trip is
258/// gone — useful on machines with enough RAM to hold the working set.
259fn build_in_memory_ss_inner<S, I, F>(
260    text: &[S],
261    source: PositionSource<'_>,
262    opts: &ExtMemOpts,
263    mut emit: F,
264) -> io::Result<()>
265where
266    S: Ord + Copy + Sync + 'static,
267    I: Index,
268    SaLcp<I>: BucketRecord,
269    F: FnMut(u64) -> io::Result<()>,
270{
271    let n = source.len();
272    if n == 0 {
273        return Ok(());
274    }
275    let p = effective_subproblem_count(n, opts.subproblem_count);
276    let dispatch = LcpDispatch::detect();
277
278    let factory = |_i: usize| InMemBucket::<SaLcp<I>>::new();
279
280    let (mut subarray_buckets, samples) =
281        phase1_sort_sample_spill::<S, I, _, _>(text, &source, p, opts, dispatch, factory)?;
282    let pivots = phase2_select_pivots::<S, I>(text, samples, p, opts.max_context, dispatch);
283    let mut partition_buckets = phase3_distribute::<S, I, _, _>(
284        text,
285        &mut subarray_buckets,
286        &pivots,
287        p,
288        opts,
289        dispatch,
290        factory,
291    )?;
292    drop(subarray_buckets);
293    phase4_merge_and_emit::<S, I, _, F>(
294        text,
295        &mut partition_buckets,
296        opts.max_context,
297        &mut emit,
298        dispatch,
299    )
300}
301
302/// In-memory variant of the sample-sort algorithm used by
303/// [`build_ext_mem`]. Skips all disk I/O at the cost of holding the
304/// (`pos`, `lcp`) records in RAM throughout. Picks `u32` records when
305/// `n ≤ 2³²`, falls back to `u64` otherwise. The caller's `emit`
306/// closure is called once per output position in lex order, just like
307/// in the ext-mem path.
308pub fn build_in_memory_sample_sort<S, F>(text: &[S], opts: &ExtMemOpts, emit: F) -> io::Result<()>
309where
310    S: Ord + Copy + Sync + 'static,
311    F: FnMut(u64) -> io::Result<()>,
312{
313    if text.len() <= u32::MAX as usize + 1 {
314        build_in_memory_ss_inner::<S, u32, F>(
315            text,
316            PositionSource::Identity(text.len()),
317            opts,
318            emit,
319        )
320    } else {
321        build_in_memory_ss_inner::<S, u64, F>(
322            text,
323            PositionSource::Identity(text.len()),
324            opts,
325            emit,
326        )
327    }
328}
329
330/// Subset-positions variant of [`build_in_memory_sample_sort`]. Same
331/// shape as [`build_ext_mem_for_positions`].
332pub fn build_in_memory_sample_sort_for_positions<S, F>(
333    text: &[S],
334    positions: Vec<u64>,
335    opts: &ExtMemOpts,
336    emit: F,
337) -> io::Result<()>
338where
339    S: Ord + Copy + Sync + 'static,
340    F: FnMut(u64) -> io::Result<()>,
341{
342    if text.len() <= u32::MAX as usize + 1 {
343        build_in_memory_ss_inner::<S, u32, F>(text, PositionSource::Subset(&positions), opts, emit)
344    } else {
345        build_in_memory_ss_inner::<S, u64, F>(text, PositionSource::Subset(&positions), opts, emit)
346    }
347}
348
/// Source of the positions to sort. The all-suffixes case
/// ([`PositionSource::Identity`]) avoids materialising a `Vec<u64>` of
/// length `n`, which on the human genome would itself be ~25 GB.
enum PositionSource<'a> {
    /// Every position `0..n`, where `n` is the stored count. Positions
    /// are generated on demand rather than stored.
    Identity(usize),
    /// An explicit, caller-supplied list of positions, borrowed for
    /// `'a` and read in the order given.
    Subset(&'a [u64]),
}
356
357impl<'a> PositionSource<'a> {
358    fn len(&self) -> usize {
359        match self {
360            Self::Identity(n) => *n,
361            Self::Subset(p) => p.len(),
362        }
363    }
364
365    /// Fill `dst` with positions for the half-open subarray range
366    /// `[start, start + dst.len())`, narrowing the caller's `u64`
367    /// positions into `I` via [`Index::from_usize`]. For
368    /// [`PositionSource::Identity`] this generates the contiguous
369    /// integer range on the fly; for [`PositionSource::Subset`] it
370    /// reads from the caller's slice.
371    fn fill_chunk<I: Index>(&self, start: usize, dst: &mut [I]) {
372        match self {
373            Self::Identity(_) => {
374                for (i, slot) in dst.iter_mut().enumerate() {
375                    *slot = I::from_usize(start + i);
376                }
377            }
378            Self::Subset(p) => {
379                let end = start + dst.len();
380                for (slot, &v) in dst.iter_mut().zip(p[start..end].iter()) {
381                    *slot = I::from_usize(v as usize);
382                }
383            }
384        }
385    }
386}
387
/// Target subarray size (in positions) used by
/// [`effective_subproblem_count`] when auto-picking `p`. Smaller means
/// more (smaller) subarrays — lower per-task phase-1 scratch, at the
/// cost of more phase-3 distribute work (each of the `p` subarrays
/// binary-searches `p − 1` pivots, so total work grows as
/// `O(p² · log(n/p))`, spread across rayon workers) and more logical
/// buckets.
const PHASE1_TARGET_CHUNK: usize = 65_536;
/// Upper bound on the *auto-picked* number of subarrays; an explicit
/// `subproblem_count` is not subject to it. Matches upstream CaPS-SA's
/// default of 8192 — phase 3 is parallelised across rayon workers
/// (each subarray distributes independently into per-partition
/// `Mutex`-guarded buckets), so a sequential `O(p²)` distribute no
/// longer constrains us. The cap is still finite to keep the number of
/// buckets bounded.
const PHASE1_MAX_PARTITIONS: usize = 8192;
401
402/// Resolve [`ExtMemOpts::physical_file_count`] for the current build.
403/// `0` (the default) means "let the runtime decide"; we pick
404/// `rayon::current_num_threads()` so the pool has one inode per
405/// concurrent writer, which empirically matches per-bucket-file wall
406/// time while collapsing thousands of small files into dozens of
407/// large ones. The `CAPS_SA_N_PHYS` env var overrides at the call
408/// site for benchmarks.
409fn effective_physical_file_count(requested: usize) -> usize {
410    if let Some(v) = std::env::var("CAPS_SA_N_PHYS")
411        .ok()
412        .and_then(|s| s.parse::<usize>().ok())
413        .filter(|&v| v >= 1)
414    {
415        return v;
416    }
417    if requested >= 1 {
418        return requested;
419    }
420    rayon::current_num_threads().max(1)
421}
422
423fn effective_subproblem_count(n: usize, requested: usize) -> usize {
424    if n == 0 {
425        return 0;
426    }
427    let raw = if requested == 0 {
428        let nthreads = rayon::current_num_threads().max(1);
429        let p_from_size = n.div_ceil(PHASE1_TARGET_CHUNK);
430        // At least one chunk per thread (otherwise we leave cores idle),
431        // at most `PHASE1_MAX_PARTITIONS` (so phase 3's sequential
432        // `O(p²)` sweep and the temp-file count stay manageable). For
433        // small inputs `p_from_size` is well below the cap, so the
434        // formula degrades gracefully to roughly "one chunk per thread";
435        // for human-scale inputs the cap binds and per-task scratch
436        // stays in the tens-of-MB range.
437        p_from_size.clamp(nthreads, PHASE1_MAX_PARTITIONS)
438    } else {
439        requested
440    };
441    raw.clamp(1, n)
442}
443
444/// Phase 1: sort each subarray in parallel, sample from it, and spill
445/// `(position, lcp)` records to its own [`ExtMemBucket`].
446///
447/// One rayon task per subarray; rayon's work-stealing scheduler keeps
448/// all worker threads busy and lets `merge_sort`'s inner
449/// [`rayon::join`] recursion steal idle slots. With the auto-picked
450/// `p` (target chunk ~ 64 K records), per-task scratch is ~18 MiB on
451/// human-scale inputs, so the `num_threads × per_task_scratch` peak
452/// stays bounded even though we don't reuse buffers across iterations.
453#[allow(clippy::too_many_arguments)]
454fn phase1_sort_sample_spill<S, I, B, MkB>(
455    text: &[S],
456    source: &PositionSource<'_>,
457    p: usize,
458    opts: &ExtMemOpts,
459    dispatch: LcpDispatch,
460    mk_bucket: MkB,
461) -> io::Result<(Vec<B>, Vec<I>)>
462where
463    S: Ord + Copy + Sync + 'static,
464    I: Index,
465    SaLcp<I>: BucketRecord,
466    B: BucketStore<SaLcp<I>> + Send,
467    MkB: Fn(usize) -> B + Send + Sync,
468{
469    let n = source.len();
470    let chunk_size = n.div_ceil(p);
471    let samples_target_total = sample_target_total(n, p);
472
473    let per_subarray: Vec<(B, Vec<I>)> = (0..p)
474        .into_par_iter()
475        .map(|i| {
476            let start = (i * chunk_size).min(n);
477            let end = ((i + 1) * chunk_size).min(n);
478            let len = end - start;
479
480            let mut bucket = mk_bucket(i);
481            if len == 0 {
482                return Ok::<_, io::Error>((bucket, Vec::new()));
483            }
484
485            // In-memory sort of this subarray with LCP maintenance.
486            let mut sa: Vec<I> = vec![I::zero(); len];
487            source.fill_chunk(start, &mut sa);
488            let mut sa_w = vec![I::zero(); len];
489            let mut lcp_arr = vec![I::zero(); len];
490            let mut lcp_w = vec![I::zero(); len];
491            sample_sort::merge_sort(
492                text,
493                &mut sa,
494                &mut sa_w,
495                &mut lcp_arr,
496                &mut lcp_w,
497                opts.max_context,
498                dispatch,
499            );
500
501            // Pull `samples_per_subarray` evenly-spaced positions out of
502            // the now-sorted subarray. Deterministic — no RNG needed for
503            // pivot selection to be globally well-distributed.
504            let samples_per_subarray = samples_target_total.div_ceil(p).min(len);
505            let samples = evenly_spaced(&sa, samples_per_subarray);
506
507            // Spill (position, lcp) records to the bucket. `lcp[0]`
508            // remains 0 (set by the merge-sort base case), making each
509            // subarray its own well-formed LCP-annotated sorted run.
510            let records: Vec<SaLcp<I>> = sa
511                .iter()
512                .zip(lcp_arr.iter())
513                .map(|(&pos, &lcp)| SaLcp { pos, lcp })
514                .collect();
515            bucket.add_slice(&records)?;
516
517            Ok((bucket, samples))
518        })
519        .collect::<Result<Vec<_>, _>>()?;
520
521    let mut buckets = Vec::with_capacity(p);
522    let mut all_samples = Vec::with_capacity(samples_target_total);
523    for (bucket, samples) in per_subarray {
524        buckets.push(bucket);
525        all_samples.extend(samples);
526    }
527    Ok((buckets, all_samples))
528}
529
/// Target sample count *across all subarrays*.
///
/// Follows the "`c · ln n`" per-subarray rule with `c = 4` (with
/// `ln n` floored at 1), so the raw target is `p · ⌈4 · ln n⌉`. The
/// result is then raised to at least `p` (enough to pick `p − 1`
/// pivots) and finally capped at `n` (there are only `n` positions to
/// sample from).
///
/// The cap is applied last, so the function never panics: for the
/// degenerate `p > n` case (including `n == 0`) it returns `n`.
fn sample_target_total(n: usize, p: usize) -> usize {
    let ln_n = (n as f64).ln().max(1.0);
    let per_subarray = (4.0 * ln_n).ceil() as usize;
    // `clamp(p, n)` would assert `p <= n` and panic otherwise; floor
    // then cap gives the same value whenever `p <= n`.
    p.saturating_mul(per_subarray).max(p).min(n)
}
539
/// Return `count` elements of `xs` taken at evenly-spaced indices.
///
/// * `count == 0` or an empty `xs` → empty vector.
/// * `count >= xs.len()` → a copy of the whole slice.
/// * otherwise → the elements at indices `⌊(2k + 1) · n / (2 · count)⌋`
///   for `k = 0..count`, i.e. the midpoints of `count` equal-width
///   strata. Neither endpoint is favoured, which keeps pivots away
///   from the extreme corners of the order.
///
/// Fully deterministic, so the algorithm stays reproducible without an
/// RNG dependency.
fn evenly_spaced<T: Copy>(xs: &[T], count: usize) -> Vec<T> {
    let len = xs.len();
    match count {
        0 => Vec::new(),
        // Also covers the empty slice: copying it yields an empty Vec.
        wanted if wanted >= len => xs.to_vec(),
        wanted => {
            let denom = 2 * wanted;
            // Odd numerators 1, 3, …, 2·wanted − 1 are the `2k + 1`
            // factors of the stratum midpoints.
            (1..denom)
                .step_by(2)
                .map(|odd| xs[odd * len / denom])
                .collect()
        }
    }
}
557
558/// Phase 2: globally sort the pooled samples and pick `p − 1` pivots at
559/// evenly-spaced ranks.
560fn phase2_select_pivots<S, I>(
561    text: &[S],
562    mut samples: Vec<I>,
563    p: usize,
564    max_ctx: usize,
565    dispatch: LcpDispatch,
566) -> Vec<I>
567where
568    S: Ord + Copy + Sync + 'static,
569    I: Index,
570{
571    if p <= 1 || samples.is_empty() {
572        return Vec::new();
573    }
574    let n_samples = samples.len();
575    let mut sa_w = vec![I::zero(); n_samples];
576    let mut lcp = vec![I::zero(); n_samples];
577    let mut lcp_w = vec![I::zero(); n_samples];
578    sample_sort::merge_sort(
579        text,
580        &mut samples,
581        &mut sa_w,
582        &mut lcp,
583        &mut lcp_w,
584        max_ctx,
585        dispatch,
586    );
587
588    // p-1 pivots at evenly-spaced ranks across the sorted sample pool.
589    (1..p).map(|j| samples[(j * n_samples) / p]).collect()
590}
591
592/// Phase 3: walk each subarray *in parallel*, load it into RAM,
593/// binary-search the pivots to find its `p` sub-subarray boundaries,
594/// and append each sub-subarray to the corresponding partition bucket.
595///
596/// Partition buckets are wrapped in a [`Mutex`] each so multiple
597/// threads can write to different partitions concurrently without
598/// shard-merging afterwards. With `p` in the thousands and `T` in the
599/// tens, lock contention is negligible (probability that two threads
600/// want the same partition at the same instant is `~T/p`); the lock
601/// scope per acquisition is one `add_slice` + `mark_boundary` of a
602/// few-KB sub-subarray.
603///
604/// Phase 4 doesn't care about the relative order of sub-subarrays
605/// within a partition — only that each one between consecutive
606/// boundaries is internally sorted. Both properties hold under
607/// arbitrary thread interleaving.
608#[allow(clippy::too_many_arguments)]
609fn phase3_distribute<S, I, B, MkB>(
610    text: &[S],
611    subarray_buckets: &mut [B],
612    pivots: &[I],
613    p: usize,
614    opts: &ExtMemOpts,
615    dispatch: LcpDispatch,
616    mk_bucket: MkB,
617) -> io::Result<Vec<B>>
618where
619    S: Ord + Copy + Sync + 'static,
620    I: Index,
621    SaLcp<I>: BucketRecord,
622    B: BucketStore<SaLcp<I>> + Send,
623    MkB: Fn(usize) -> B + Send + Sync,
624{
625    let _ = opts; // work_dir is used only by the ext-mem factory closure now
626    let partition_buckets: Vec<Mutex<B>> = (0..p).map(|j| Mutex::new(mk_bucket(j))).collect();
627
628    subarray_buckets
629        .par_iter_mut()
630        .try_for_each(|sub_bucket| -> io::Result<()> {
631            if sub_bucket.total_records() == 0 {
632                return Ok(());
633            }
634            let records = sub_bucket.load_all()?;
635
636            // Find p-1 split points by binary-searching each pivot's
637            // *upper bound* in the sorted subarray.
638            let mut splits = Vec::with_capacity(p + 1);
639            splits.push(0usize);
640            for &pivot in pivots {
641                splits.push(upper_bound_by_pivot(
642                    &records,
643                    pivot,
644                    text,
645                    opts.max_context,
646                    dispatch,
647                ));
648            }
649            splits.push(records.len());
650
651            // Distribute each sub-subarray. Reset the first record's
652            // `lcp` to 0 so the per-partition merge sees a well-formed
653            // boundary.
654            for j in 0..p {
655                let lo = splits[j];
656                let hi = splits[j + 1];
657                if lo >= hi {
658                    continue;
659                }
660                let mut sub: Vec<SaLcp<I>> = records[lo..hi].to_vec();
661                sub[0].lcp = I::zero();
662                let mut bucket = partition_buckets[j].lock().unwrap();
663                bucket.add_slice(&sub)?;
664                bucket.mark_boundary();
665            }
666            Ok(())
667        })?;
668
669    // Unwrap the Mutexes — at this point only this thread holds
670    // references, so the locks are uncontended.
671    Ok(partition_buckets
672        .into_iter()
673        .map(|m| m.into_inner().expect("partition mutex poisoned"))
674        .collect())
675}
676
677/// Upper-bound binary search: returns the first index `i` such that the
678/// suffix at `records[i].pos` is **strictly greater than** the suffix at
679/// `pivot`.
680fn upper_bound_by_pivot<S, I>(
681    records: &[SaLcp<I>],
682    pivot: I,
683    text: &[S],
684    max_ctx: usize,
685    dispatch: LcpDispatch,
686) -> usize
687where
688    S: Ord + Copy + 'static,
689    I: Index,
690{
691    let mut lo = 0;
692    let mut hi = records.len();
693    while lo < hi {
694        let mid = lo + (hi - lo) / 2;
695        match dispatch.suffix_cmp(text, records[mid].pos.to_usize(), pivot.to_usize(), max_ctx) {
696            Ordering::Greater => hi = mid,
697            Ordering::Equal | Ordering::Less => lo = mid + 1,
698        }
699    }
700    lo
701}
702
703/// Phase 4 + 5: parallel-merge partitions in chunks of `num_threads`,
704/// emitting each chunk's results in lex order before starting the next.
705///
706/// Each worker thread holds its own [`CascadeWorkspace`] for the duration
707/// of one partition merge. Within a chunk, all `T` workspaces live in
708/// parallel; between chunks, they are dropped (so peak workspace memory
709/// scales with `T`, not with the number of partitions). The merged result
710/// for each partition is then drained sequentially via `emit` to preserve
711/// streaming-output order without ever holding the full SA in RAM.
712///
713/// Peak transient RAM ≈ `T × max_partition_size × 16 bytes` for the
714/// merged-result buffers, plus the workspaces themselves (~`2 × T ×
715/// max_partition_size × 16` bytes). On a typical run with `p = 4 × T`
716/// subarrays the per-partition size is `≈ n / p`, so this stays
717/// proportional to `n / 4 = 0.25 n` even at the peak — well below the
718/// in-memory path's `~4 n` working set.
719fn phase4_merge_and_emit<S, I, B, F>(
720    text: &[S],
721    partition_buckets: &mut [B],
722    max_ctx: usize,
723    emit: &mut F,
724    dispatch: LcpDispatch,
725) -> io::Result<()>
726where
727    S: Ord + Copy + Sync + 'static,
728    I: Index,
729    SaLcp<I>: BucketRecord,
730    B: BucketStore<SaLcp<I>> + Send,
731    F: FnMut(u64) -> io::Result<()>,
732{
733    let n_partitions = partition_buckets.len();
734    if n_partitions == 0 {
735        return Ok(());
736    }
737    let chunk_size = rayon::current_num_threads().max(1);
738
739    // Per-thread CPU-µs accumulators for the two parallel sub-steps. They
740    // add across threads, so the printed values are CPU-time (sum), not
741    // wall-time; the ratio between them still tells us where the work is.
742    use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
743    let profile = std::env::var_os("CAPS_SA_PROFILE").is_some();
744    let load_us = AtomicU64::new(0);
745    let merge_us = AtomicU64::new(0);
746    let mut emit_secs: f64 = 0.0;
747
748    let mut start = 0;
749    while start < n_partitions {
750        let end = (start + chunk_size).min(n_partitions);
751        let chunk = &mut partition_buckets[start..end];
752
753        // Parallel-merge each non-empty bucket in this chunk. `par_iter_mut`
754        // preserves index order in the collected `Vec`, so the subsequent
755        // sequential emit yields positions in lex order.
756        let merged: Vec<Vec<I>> = chunk
757            .par_iter_mut()
758            .map(|bucket| -> io::Result<Vec<I>> {
759                if bucket.total_records() == 0 {
760                    return Ok(Vec::new());
761                }
762                let t = Instant::now();
763                let records = bucket.load_all()?;
764                let boundaries: Vec<usize> = bucket.boundaries().to_vec();
765                if profile {
766                    load_us.fetch_add(t.elapsed().as_micros() as u64, AtomicOrdering::Relaxed);
767                }
768
769                let t = Instant::now();
770                let workspace = CascadeWorkspace::<I>::new();
771                // `cascade_merge` consumes the workspace and returns
772                // the result side directly — the other three buffers
773                // drop along with `workspace` here, without an
774                // intermediate `to_vec()` copy.
775                let result =
776                    workspace.cascade_merge(text, &records, &boundaries, max_ctx, dispatch);
777                if profile {
778                    merge_us.fetch_add(t.elapsed().as_micros() as u64, AtomicOrdering::Relaxed);
779                }
780                Ok(result)
781            })
782            .collect::<Result<Vec<_>, io::Error>>()?;
783
784        let t = Instant::now();
785        for positions in merged {
786            for pos in positions {
787                // Widen back to the public `u64` emit contract.
788                emit(pos.to_usize() as u64)?;
789            }
790        }
791        if profile {
792            emit_secs += t.elapsed().as_secs_f64();
793        }
794
795        start = end;
796    }
797    if profile {
798        profile_log(&format!(
799            "phase4 breakdown CPU: load {:.3}s merge {:.3}s; wall emit {:.3}s",
800            load_us.load(AtomicOrdering::Relaxed) as f64 * 1e-6,
801            merge_us.load(AtomicOrdering::Relaxed) as f64 * 1e-6,
802            emit_secs,
803        ));
804    }
805    Ok(())
806}
807
/// Reusable ping-pong scratch for the partition cascade merge.
///
/// Holds two `(sa, lcp)` buffers each sized to the largest partition seen.
/// The cascade alternates reads from one side and writes to the other,
/// flipping a `src_is_a` flag after each level. Avoids the
/// per-level allocations that the previous immutable-`Vec` cascade
/// performed for every pair of sub-subarrays.
///
/// NOTE(review): `phase4_merge_and_emit` builds a fresh workspace per
/// partition and `cascade_merge` takes `self` by value, so "reusable"
/// presumably means reuse across cascade levels within one merge, not
/// across partitions — confirm.
struct CascadeWorkspace<I> {
    /// Side `a`: suffix positions.
    a_sa: Vec<I>,
    /// Side `a`: LCP values paired with `a_sa`.
    a_lcp: Vec<I>,
    /// Side `b`: suffix positions.
    b_sa: Vec<I>,
    /// Side `b`: LCP values paired with `b_sa`.
    b_lcp: Vec<I>,
}
821
822impl<I: Index> CascadeWorkspace<I> {
823    fn new() -> Self {
824        Self {
825            a_sa: Vec::new(),
826            a_lcp: Vec::new(),
827            b_sa: Vec::new(),
828            b_lcp: Vec::new(),
829        }
830    }
831
832    /// Grow all four buffers to at least `n` elements. The contents past
833    /// the cascade's actual run lengths are don't-care.
834    fn ensure_capacity(&mut self, n: usize) {
835        if self.a_sa.len() < n {
836            self.a_sa.resize(n, I::zero());
837            self.a_lcp.resize(n, I::zero());
838            self.b_sa.resize(n, I::zero());
839            self.b_lcp.resize(n, I::zero());
840        }
841    }
842
843    /// Cascade 2-way LCP-enhanced merges across the sub-subarrays of one
844    /// partition (delimited by `boundaries`) until a single sorted run
845    /// remains. **Consumes the workspace** and returns the result side
846    /// as a `Vec<u64>`; the other three buffers (`a_lcp`, the opposing
847    /// `*_sa`, the opposing `*_lcp`) drop immediately. This shape lets
848    /// the caller skip the per-partition `to_vec()` round-trip that
849    /// would otherwise sit briefly alongside all four workspace buffers
850    /// at peak.
851    fn cascade_merge<S>(
852        mut self,
853        text: &[S],
854        records: &[SaLcp<I>],
855        boundaries: &[usize],
856        max_ctx: usize,
857        dispatch: LcpDispatch,
858    ) -> Vec<I>
859    where
860        S: Ord + Copy + 'static,
861    {
862        let n = records.len();
863        if n == 0 {
864            return Vec::new();
865        }
866        self.ensure_capacity(n);
867
868        // Initialize side A in SOA form from the AOS `records`, and
869        // collect the lengths of the non-empty sub-subarrays.
870        let mut run_lens: Vec<usize> = boundaries
871            .windows(2)
872            .filter_map(|w| {
873                let l = w[1] - w[0];
874                if l > 0 { Some(l) } else { None }
875            })
876            .collect();
877        for (i, r) in records.iter().enumerate() {
878            self.a_sa[i] = r.pos;
879            self.a_lcp[i] = r.lcp;
880        }
881
882        let mut src_is_a = true;
883        while run_lens.len() > 1 {
884            run_lens = self.merge_one_level(src_is_a, &run_lens, text, max_ctx, dispatch);
885            src_is_a = !src_is_a;
886        }
887
888        // Take ownership of the buffer holding the result, truncate to
889        // the actual record count, drop the other three buffers with
890        // `self` going out of scope.
891        let mut result = if src_is_a { self.a_sa } else { self.b_sa };
892        result.truncate(n);
893        result
894    }
895
896    /// Pair the runs in `run_lens` (last odd one passes through unchanged),
897    /// running each pair through the LCP-enhanced 2-way merge from the
898    /// `src_is_a`-selected buffer side into the other. Returns the new
899    /// run-length list (each entry is the sum of the two it replaced, or
900    /// the carry-over for an odd tail).
901    fn merge_one_level<S>(
902        &mut self,
903        src_is_a: bool,
904        run_lens: &[usize],
905        text: &[S],
906        max_ctx: usize,
907        dispatch: LcpDispatch,
908    ) -> Vec<usize>
909    where
910        S: Ord + Copy + 'static,
911    {
912        // Destructure self so the borrow checker can see the two sides as
913        // disjoint locals — we borrow one immutably and the other mutably.
914        let Self {
915            a_sa,
916            a_lcp,
917            b_sa,
918            b_lcp,
919        } = self;
920        let (src_sa, src_lcp, dst_sa, dst_lcp) = if src_is_a {
921            (
922                a_sa.as_slice(),
923                a_lcp.as_slice(),
924                b_sa.as_mut_slice(),
925                b_lcp.as_mut_slice(),
926            )
927        } else {
928            (
929                b_sa.as_slice(),
930                b_lcp.as_slice(),
931                a_sa.as_mut_slice(),
932                a_lcp.as_mut_slice(),
933            )
934        };
935
936        let mut new_lens = Vec::with_capacity(run_lens.len().div_ceil(2));
937        let mut src_off = 0usize;
938        let mut dst_off = 0usize;
939        let mut i = 0;
940        while i < run_lens.len() {
941            let l1 = run_lens[i];
942            if i + 1 < run_lens.len() {
943                let l2 = run_lens[i + 1];
944                let x_end = src_off + l1;
945                let xy_end = x_end + l2;
946                let dst_end = dst_off + l1 + l2;
947                sample_sort::merge(
948                    text,
949                    &src_sa[src_off..x_end],
950                    &src_sa[x_end..xy_end],
951                    &src_lcp[src_off..x_end],
952                    &src_lcp[x_end..xy_end],
953                    &mut dst_sa[dst_off..dst_end],
954                    &mut dst_lcp[dst_off..dst_end],
955                    max_ctx,
956                    dispatch,
957                );
958                new_lens.push(l1 + l2);
959                src_off = xy_end;
960                dst_off = dst_end;
961                i += 2;
962            } else {
963                // Odd run carries over unchanged.
964                let end = dst_off + l1;
965                dst_sa[dst_off..end].copy_from_slice(&src_sa[src_off..src_off + l1]);
966                dst_lcp[dst_off..end].copy_from_slice(&src_lcp[src_off..src_off + l1]);
967                new_lens.push(l1);
968                src_off += l1;
969                dst_off = end;
970                i += 1;
971            }
972        }
973        new_lens
974    }
975}
976
#[cfg(test)]
mod tests {
    use super::*;
    use crate::build_in_memory;
    use tempfile::tempdir;

    /// Build options for a test run: `p` subproblems, scratch files under
    /// `work_dir`, every other field at its default.
    fn test_opts(work_dir: &std::path::Path, p: usize) -> ExtMemOpts {
        ExtMemOpts {
            subproblem_count: p,
            work_dir: work_dir.to_path_buf(),
            ..ExtMemOpts::default()
        }
    }

    /// Reference suffix array from the in-memory builder, widened to the
    /// `u64` positions that the emit closures receive.
    fn in_memory_sa_u64(text: &[u8]) -> Vec<u64> {
        let sa: Vec<u32> = build_in_memory(text);
        sa.into_iter().map(u64::from).collect()
    }

    /// Reference ordering of `positions` by direct comparison of the
    /// suffixes of `text` that start at each position.
    fn brute_force_sorted(text: &[u8], positions: &[u64]) -> Vec<u64> {
        let mut sorted = positions.to_vec();
        sorted.sort_by(|&a, &b| text[a as usize..].cmp(&text[b as usize..]));
        sorted
    }

    /// Run `build_ext_mem` in a fresh temp dir and collect what it emits.
    fn ext_mem_sa(text: &[u8], p: usize) -> Vec<u64> {
        let dir = tempdir().unwrap();
        let opts = test_opts(dir.path(), p);
        let mut out: Vec<u64> = Vec::with_capacity(text.len());
        build_ext_mem(text, &opts, |pos| {
            out.push(pos);
            Ok(())
        })
        .unwrap();
        out
    }

    /// Assert that the external-memory build emits the in-memory SA.
    fn assert_matches_in_memory(text: &[u8], p: usize) {
        let want64 = in_memory_sa_u64(text);
        let got = ext_mem_sa(text, p);
        assert_eq!(got, want64, "mismatch on text {text:?} with p={p}");
    }

    #[test]
    fn ext_mem_empty() {
        let got = ext_mem_sa(b"", 4);
        assert!(got.is_empty());
    }

    #[test]
    fn ext_mem_single_partition() {
        assert_matches_in_memory(b"banana", 1);
    }

    #[test]
    fn ext_mem_p_greater_than_n() {
        assert_matches_in_memory(b"abc", 10);
    }

    #[test]
    fn ext_mem_banana_p4() {
        assert_matches_in_memory(b"banana", 4);
    }

    #[test]
    fn ext_mem_mississippi_p3() {
        assert_matches_in_memory(b"mississippi", 3);
    }

    #[test]
    fn ext_mem_random_byte_texts() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xCAFE);
        for &n in &[16usize, 100, 1000, 5000] {
            for &p in &[1usize, 2, 4, 16] {
                let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
                assert_matches_in_memory(&text, p);
            }
        }
    }

    #[test]
    fn ext_mem_with_unique_terminator() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xF00D);
        for &n in &[10usize, 200, 2000] {
            for &p in &[1usize, 3, 8] {
                let mut text: Vec<u8> = (0..n).map(|_| rng.random_range(0..5u8)).collect();
                // Byte 200 lies outside the 0..5 alphabet, so it occurs once.
                text.push(200);
                assert_matches_in_memory(&text, p);
            }
        }
    }

    /// Run `build_ext_mem_for_positions` in a fresh temp dir and collect
    /// what it emits.
    fn ext_mem_for_positions(text: &[u8], positions: Vec<u64>, p: usize) -> Vec<u64> {
        let dir = tempdir().unwrap();
        let opts = test_opts(dir.path(), p);
        let mut out: Vec<u64> = Vec::with_capacity(positions.len());
        build_ext_mem_for_positions(text, positions, &opts, |pos| {
            out.push(pos);
            Ok(())
        })
        .unwrap();
        out
    }

    #[test]
    fn ext_mem_for_positions_full_set_matches_ext_mem() {
        let text = b"mississippi";
        let want = ext_mem_sa(text, 3);
        let positions: Vec<u64> = (0..text.len() as u64).collect();
        let got = ext_mem_for_positions(text, positions, 3);
        assert_eq!(got, want);
    }

    #[test]
    fn ext_mem_for_positions_subset_matches_brute_force() {
        let text = b"mississippi";
        let positions: Vec<u64> = (0..text.len() as u64).step_by(2).collect();
        let want = brute_force_sorted(text, &positions);
        let got = ext_mem_for_positions(text, positions, 4);
        assert_eq!(got, want);
    }

    #[test]
    fn ext_mem_for_positions_random_subsets() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xC0DE);
        for &n in &[50usize, 500, 2000] {
            for &p in &[1usize, 3, 8] {
                let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..5u8)).collect();
                let mut positions: Vec<u64> = (0..n as u64).collect();
                positions.retain(|_| rng.random_range(0..10) < 7);
                let want = brute_force_sorted(&text, &positions);
                let got = ext_mem_for_positions(&text, positions, p);
                assert_eq!(got, want, "subset ext-mem mismatch n={n} p={p}");
            }
        }
    }

    /// Run `build_in_memory_sample_sort` and collect what it emits.
    fn in_memory_sample_sort(text: &[u8], p: usize) -> Vec<u64> {
        let dir = tempdir().unwrap();
        let opts = test_opts(dir.path(), p);
        let mut out: Vec<u64> = Vec::with_capacity(text.len());
        build_in_memory_sample_sort(text, &opts, |pos| {
            out.push(pos);
            Ok(())
        })
        .unwrap();
        out
    }

    #[test]
    fn in_memory_sample_sort_matches_in_memory() {
        for text in [b"banana" as &[u8], b"mississippi", b"abracadabra"] {
            let want64 = in_memory_sa_u64(text);
            let got = in_memory_sample_sort(text, 0);
            assert_eq!(got, want64, "in-mem sample-sort mismatch on {text:?}");
        }
    }

    #[test]
    fn in_memory_sample_sort_random_byte_texts() {
        use rand::{RngExt, SeedableRng};
        let mut rng = rand::rngs::StdRng::seed_from_u64(0xC0DE_C0DE);
        for &n in &[16usize, 200, 2000] {
            for &p in &[1usize, 4, 16] {
                let text: Vec<u8> = (0..n).map(|_| rng.random_range(0..6u8)).collect();
                let want64 = in_memory_sa_u64(&text);
                let got = in_memory_sample_sort(&text, p);
                assert_eq!(got, want64, "in-mem ss mismatch n={n} p={p}");
            }
        }
    }

    #[test]
    fn ext_mem_repetitive_does_not_blow_up() {
        // Many copies of a long repeat — what killed the Phase 2 v1
        // linear-scan merge. The sample-sort + LCP-enhanced cascade
        // should handle it in proportional time.
        use std::time::Instant;
        let unit = b"ACGTACGTACGTACGTACGTACGTACGT"; // 28 bases
        let mut text: Vec<u8> = Vec::new();
        for _ in 0..100 {
            text.extend_from_slice(unit);
        }
        text.push(200);
        let start = Instant::now();
        let got = ext_mem_sa(&text, 8);
        let elapsed = start.elapsed();
        let want64 = in_memory_sa_u64(&text);
        assert_eq!(got, want64);
        // Sanity: expected to finish well under a second on this input;
        // the 2 s bound leaves headroom for slow or loaded machines.
        assert!(
            elapsed.as_secs() < 2,
            "ext-mem build on a tiny repetitive text took {elapsed:?}"
        );
    }
}