libsufr/
sufr_builder.rs

1//! Create on-disk suffix/LCP arrays
2//!
3//! Sufr builds the suffix and LCP (longest common prefix) arrays on disk.
4//! The suffixes are partitioned into temporary files, and these are sorted
5//! in parallel.
//! `SufrBuilder::new` sorts the suffixes and then serializes the data
//! structures to a file (preferably with the _.sufr_ extension) that can
//! be read by `sufr_file`.
9//!
10
11use crate::{
12    types::{
13        FromUsize, Int, SeedMask, SuffixSortType, SufrBuilderArgs, OUTFILE_VERSION,
14        SENTINEL_CHARACTER,
15    },
16    util::{find_lcp_full_offset, slice_u8_to_vec, usize_to_bytes, vec_to_slice_u8},
17};
18use anyhow::{anyhow, bail, Result};
19use log::info;
20use rand::{rngs::StdRng, Rng, RngCore, SeedableRng};
21use rayon::prelude::*;
22use std::{
23    cmp::{max, min, Ordering},
24    collections::HashSet,
25    fs::{self, File, OpenOptions},
26    io::{BufWriter, Seek, SeekFrom, Write},
27    mem,
28    ops::Range,
29    path::PathBuf,
30    sync::{Arc, Mutex},
31    time::Instant,
32};
33use tempfile::NamedTempFile;
34
35// --------------------------------------------------
36/// A struct for partitioning, sorting, and writing suffixes to disk
37#[derive(Debug)]
38pub struct SufrBuilder<T>
39where
40    T: Int + FromUsize<T> + Sized + Send + Sync + serde::ser::Serialize,
41{
42    /// The serialization version.
43    pub version: u8,
44
45    /// Whether or not the sequence is nucleotide.
46    pub is_dna: bool,
47
48    /// Whether or not the nucleotide sequence allows characters
49    /// other than A, C, G, or T.
50    pub allow_ambiguity: bool,
51
52    /// Whether or not the nucleotide sequence ignores
53    /// softmasked/lowercase bases.
54    pub ignore_softmask: bool,
55
56    /// The length of the given text.
57    pub text_len: T,
58
59    /// The number of suffixes that were indexed from the text, which
60    /// could be less than `text_len` when ambiguity/softmasked values
61    /// are ignored.
62    pub num_suffixes: T,
63
64    /// The number of sequences represented in the text.
65    pub num_sequences: T,
66
67    /// The positions in the text where each sequence starts.
68    pub sequence_starts: Vec<T>,
69
70    /// The names of the sequences in the text. Should be the same length
71    /// as `sequence_starts`.
72    pub sequence_names: Vec<String>,
73
74    /// The text that was indexed.
75    pub text: Vec<u8>,
76
77    /// Whether the text is sorted fully, using a maximum query length,
78    /// or a seed mask.
79    pub sort_type: SuffixSortType,
80
81    /// The number of partitions to use when building.
82    partitions: Vec<Partition>,
83
84    /// The locations of long runs of Ns in nucleotide text.
85    pub n_ranges: Vec<Range<usize>>,
86
87    /// The name of the output file
88    pub path: String,
89}
90
91// --------------------------------------------------
92impl<T> SufrBuilder<T>
93where
94    T: Int + FromUsize<T> + Sized + Send + Sync,
95{
96    /// Create a new suffix/LCP array.
97    /// The results will live in temporary files on-disk.
98    /// The integer values representing the positions of each suffix will
99    /// be `u32` when the length of the text is less than 2^32 and `u64`,
100    /// otherwise.
101    ///
102    /// ```
103    /// use anyhow::Result;
104    /// use std::{fs, path::Path};
105    /// use libsufr::{
106    ///     sufr_builder::SufrBuilder,
107    ///     types::SufrBuilderArgs,
108    ///     util::read_sequence_file,
109    /// };
110    ///
111    /// fn main() -> Result<()> {
112    ///     let path = Path::new("../data/inputs/1.fa");
113    ///     let sequence_delimiter = b'%';
114    ///     let seq_data = read_sequence_file(path, sequence_delimiter)?;
115    ///     let text_len = seq_data.seq.len() as u64;
116    ///     let outfile = "1.sufr";
117    ///     let builder_args = SufrBuilderArgs {
118    ///         text: seq_data.seq,
119    ///         low_memory: false,
120    ///         path: Some(outfile.to_string()),
121    ///         max_query_len: None,
122    ///         is_dna: true,
123    ///         allow_ambiguity: false,
124    ///         ignore_softmask: true,
125    ///         sequence_starts: seq_data.start_positions.into_iter().collect(),
126    ///         sequence_names: seq_data.sequence_names,
127    ///         num_partitions: 1024,
128    ///         seed_mask: None,
129    ///         random_seed: 42,
130    ///     };
131    ///
132    ///     if text_len < u32::MAX as u64 {
133    ///         let sufr_builder: SufrBuilder<u32> = SufrBuilder::new(builder_args)?;
134    ///     } else {
135    ///         let sufr_builder: SufrBuilder<u64> = SufrBuilder::new(builder_args)?;
136    ///     }
137    ///
138    ///     fs::remove_file(&outfile)?;
139    ///
140    ///     Ok(())
141    /// }
142    /// ```
143    pub fn new(args: SufrBuilderArgs) -> Result<SufrBuilder<T>> {
144        let text: Vec<_> = args
145            .text
146            .iter()
147            .map(|b| {
148                // Check for lowercase
149                if (97..=122).contains(b) {
150                    if args.ignore_softmask {
151                        b'N'
152                    } else {
153                        // only shift lowercase ASCII
154                        b & 0b1011111
155                    }
156                } else {
157                    *b
158                }
159            })
160            .collect();
161        let text_len = T::from_usize(text.len());
162
163        if args.seed_mask.is_some() && args.max_query_len.is_some() {
164            bail!("Cannot use max_query_len and seed_mask together");
165        }
166
167        let sort_type = if let Some(mask) = args.seed_mask {
168            let seed_mask = SeedMask::new(&mask)?;
169            SuffixSortType::Mask(seed_mask)
170        } else {
171            SuffixSortType::MaxQueryLen(args.max_query_len.unwrap_or(0))
172        };
173
174        // Check for long runs of Ns when ambiguous bases are allowed.
175        let mut n_ranges: Vec<Range<usize>> = vec![];
176        if args.allow_ambiguity {
177            let mut n_start: Option<usize> = None;
178            let min_n = 1000;
179            let now = Instant::now();
180            for (i, &byte) in text.iter().enumerate() {
181                if byte == b'N' {
182                    if n_start.is_none() {
183                        n_start = Some(i);
184                    }
185                } else {
186                    if let Some(prev) = n_start {
187                        if i - prev >= min_n {
188                            n_ranges.push(prev..i);
189                        }
190                    }
191                    n_start = None;
192                }
193            }
194            info!("Scanned for runs of Ns in {:?}", now.elapsed());
195        }
196
197        let mut sa = SufrBuilder {
198            version: OUTFILE_VERSION,
199            is_dna: args.is_dna,
200            allow_ambiguity: args.allow_ambiguity,
201            ignore_softmask: args.ignore_softmask,
202            sort_type,
203            text_len,
204            num_suffixes: T::default(),
205            text,
206            num_sequences: T::from_usize(args.sequence_starts.len()),
207            sequence_starts: args
208                .sequence_starts
209                .into_iter()
210                .map(T::from_usize)
211                .collect::<Vec<_>>(),
212            sequence_names: args.sequence_names,
213            partitions: vec![],
214            n_ranges,
215            path: args.path.unwrap_or("out.sufr".to_string()),
216        };
217        sa.sort(args.num_partitions, args.random_seed)?;
218        sa.write()?;
219        Ok(sa)
220    }
221
222    // --------------------------------------------------
223    // TODO: Remove? Only useful during debugging
224    // Return the string at a given suffix position
225    // Warning: Assumes pos is always found.
226    //
227    // Args:
228    // * `pos`: the suffix position
229    //pub(crate) fn string_at(&self, pos: usize) -> String {
230    //    self.text
231    //        .get(pos..)
232    //        .map(|v| String::from_utf8(v.to_vec()).unwrap())
233    //        .unwrap()
234    //}
235
236    // --------------------------------------------------
237    /// If a suffix is in a long run of Ns, return the position of the final N
238    ///
239    /// Args:
240    /// * `suffix`: suffix position
241    fn find_n_run(&self, suffix: usize) -> Option<usize> {
242        self.n_ranges
243            .binary_search_by(|range| {
244                if range.contains(&suffix) {
245                    Ordering::Equal
246                } else if range.start < suffix {
247                    Ordering::Less
248                } else {
249                    Ordering::Greater
250                }
251            })
252            .ok()
253            .map(|i| self.n_ranges[i].end)
254    }
255
256    // --------------------------------------------------
257    /// Find the longest common prefix between two suffixes.
258    ///
259    /// Args:
260    /// * `start1`: position of first suffix
261    /// * `start2`: position of second suffix
262    /// * `len`: the maximum length to check, e.g., the maximum query
263    ///   length or the weight of the seed mask (number of 1/"care" positions)
264    /// * `skip`: skip over the this many characters at the beginning.
265    ///   Because of the incremental way the LCPs are calculated, we may know
266    ///   that two suffixes share the `skip` number in common already.
267    #[inline(always)]
268    fn find_lcp(&self, start1: usize, start2: usize, len: T, skip: usize) -> T {
269        // TODO: Could we use traits for SortType, parameterize the builder
270        // on initialization and avoid conditionals here?
271        match &self.sort_type {
272            SuffixSortType::Mask(mask) => {
273                // Use the seed diff vector to select only the
274                // "care" positions up to the length of the text
275                let a_vals = mask
276                    .positions
277                    .iter()
278                    .skip(skip)
279                    .map(|&offset| start1 + offset)
280                    .filter(|&v| v < self.text_len.to_usize());
281
282                let b_vals = mask
283                    .positions
284                    .iter()
285                    .skip(skip)
286                    .map(|&offset| start2 + offset)
287                    .filter(|&v| v < self.text_len.to_usize());
288
289                unsafe {
290                    T::from_usize(
291                        skip + a_vals
292                            .zip(b_vals)
293                            .take_while(|(a, b)| {
294                                self.text.get_unchecked(*a)
295                                    == self.text.get_unchecked(*b)
296                            })
297                            .count(),
298                    )
299                }
300            }
301            SuffixSortType::MaxQueryLen(max_query_len) => {
302                match (&self.find_n_run(start1), &self.find_n_run(start2)) {
303                    // If the two suffixes start in long stretches of Ns
304                    // Then use the min of the end positions
305                    (Some(end1), Some(end2)) => {
306                        T::from_usize(min(end1 - start1, end2 - start2))
307                    }
308                    _ => {
309                        let text_len = self.text_len.to_usize();
310                        let len = if max_query_len > &0 {
311                            *max_query_len
312                        } else {
313                            len.to_usize()
314                        };
315                        let start1 = start1 + skip;
316                        let start2 = start2 + skip;
317                        let end1 = min(start1 + len, text_len);
318                        let end2 = min(start2 + len, text_len);
319                        unsafe {
320                            T::from_usize(
321                                skip + (start1..end1)
322                                    .zip(start2..end2)
323                                    .take_while(|(a, b)| {
324                                        self.text.get_unchecked(*a)
325                                            == self.text.get_unchecked(*b)
326                                    })
327                                    .count(),
328                            )
329                        }
330                    }
331                }
332            }
333        }
334    }
335
336    // --------------------------------------------------
337    /// Determine whether or not the first suffix position is lexicographically
338    /// less than the second.
339    /// This function is used to place suffixes into the highest partition
340    /// for sorting.
341    ///
342    /// Args:
343    /// * `start1`: the position of the first suffix
344    /// * `start2`: the position of the second suffix
345    #[inline(always)]
346    fn is_less(&self, start1: T, start2: T) -> bool {
347        if start1 == start2 {
348            false
349        } else {
350            let max_query_len = match &self.sort_type {
351                SuffixSortType::Mask(seed_mask) => T::from_usize(seed_mask.weight),
352                SuffixSortType::MaxQueryLen(max_query_len) => {
353                    if max_query_len > &0 {
354                        T::from_usize(*max_query_len)
355                    } else {
356                        self.text_len
357                    }
358                }
359            };
360
361            let len_lcp = find_lcp_full_offset(
362                self.find_lcp(start1.to_usize(), start2.to_usize(), max_query_len, 0)
363                    .to_usize(),
364                &self.sort_type,
365            );
366
367            if len_lcp >= max_query_len.to_usize() {
368                // The strings are equal(ish)
369                false
370            } else {
371                // Look at the next character
372                match (
373                    self.text.get(start1.to_usize() + len_lcp),
374                    self.text.get(start2.to_usize() + len_lcp),
375                ) {
376                    (Some(a), Some(b)) => a < b,
377                    (None, Some(_)) => true,
378                    _ => false,
379                }
380            }
381        }
382    }
383
384    // --------------------------------------------------
385    /// Find the highest partition to place a suffix for sorting.
386    ///
387    /// Args:
388    /// * `suffix`: a suffix position
389    /// * `pivots`: randomly selected suffix positions sorted lexicographically
390    #[inline(always)]
391    fn upper_bound(&self, suffix: T, pivots: &[T]) -> usize {
392        // Returns 0 when pivots is empty
393        pivots.partition_point(|&p| self.is_less(p, suffix))
394    }
395
396    // --------------------------------------------------
397    /// Write the suffixes into temporary files for sorting
398    ///
399    /// Args:
400    /// * `num_partitions`: how many partitions to create
401    /// * `random_seed`: a value for initializing the pseudo-random number
402    ///   generator for reproducibility when selecting the suffixes used
403    ///   for partitioning
404    fn partition(
405        &mut self,
406        num_partitions: usize,
407        random_seed: u64,
408    ) -> Result<PartitionBuildResult<T>> {
409        // Create more partitions than requested because
410        // we can't know how big they will end up being
411        let max_partitions = self.text_len.to_usize() / 4;
412        let num_partitions = if num_partitions * 10 < max_partitions {
413            num_partitions * 10
414        } else if num_partitions * 5 < max_partitions {
415            num_partitions * 5
416        } else if num_partitions * 2 < max_partitions {
417            num_partitions * 2
418        } else if num_partitions < max_partitions {
419            num_partitions
420        } else {
421            max_partitions
422        };
423
424        // Randomly select some pivots
425        let now = Instant::now();
426        let pivot_sa = self.select_pivots(self.text.len(), num_partitions, random_seed);
427        let num_pivots = pivot_sa.len();
428        info!(
429            "Selected {num_pivots} pivot{} in {:?}",
430            if num_pivots == 1 { "" } else { "s" },
431            now.elapsed()
432        );
433
434        let capacity = 4096;
435        let mut builders: Vec<_> = vec![];
436        for _ in 0..num_partitions {
437            let builder: PartitionBuilder<T> = PartitionBuilder::new(capacity)?;
438            builders.push(Arc::new(Mutex::new(builder)));
439        }
440
441        let now = Instant::now();
442        self.text
443            .par_iter()
444            .enumerate()
445            .try_for_each(|(i, &val)| -> Result<()> {
446                if val == SENTINEL_CHARACTER
447                    || !self.is_dna // Allow anything else if not DNA
448                    || (b"ACGT".contains(&val) || self.allow_ambiguity)
449                {
450                    let suffix = T::from_usize(i);
451                    let partition_num = self.upper_bound(suffix, &pivot_sa);
452                    match builders[partition_num].lock() {
453                        Ok(mut partition) => {
454                            if partition.add(suffix).is_err() {
455                                bail!("Unable to write data to disk")
456                            }
457                        }
458                        Err(e) => bail!("{e}"),
459                    }
460                }
461                Ok(())
462            })?;
463
464        // Flush out any remaining buffers
465        let mut num_suffixes = 0;
466        for builder in &builders {
467            match builder.lock() {
468                Ok(mut val) => {
469                    val.write()?;
470                    num_suffixes += val.total_len;
471                }
472                Err(e) => panic!("Failed to lock: {e}"),
473            }
474        }
475
476        info!(
477            "Wrote {num_suffixes} unsorted suffixes to partition{} in {:?}",
478            if num_pivots == 1 { "" } else { "s" },
479            now.elapsed()
480        );
481
482        //Ok((builders, num_suffixes))
483        Ok(PartitionBuildResult {
484            builders,
485            num_suffixes,
486        })
487    }
488
489    // --------------------------------------------------
490    /// Sort the suffixes.
491    ///
492    /// Args:
493    /// * `num_partitions`: the number of partitions to used
494    /// * `random_seed`: a value for initializing the RNG
495    fn sort(&mut self, num_partitions: usize, random_seed: u64) -> Result<()> {
496        let mut partition_build = self.partition(num_partitions, random_seed)?;
497
498        // Be sure to round up to get all the suffixes
499        let num_per_partition = (partition_build.num_suffixes as f64
500            / num_partitions as f64)
501            .ceil() as usize;
502        let total_sort_time = Instant::now();
503        let mut num_taken = 0;
504        let mut partition_inputs = vec![vec![]; num_partitions];
505
506        // We (probably) have many more partitions than we need,
507        // so here we accumulate the small partitions from the left
508        // stopping when we reach a boundary like 1M/partition.
509        // This evens out the workload to sort the partitions.
510        #[allow(clippy::needless_range_loop)]
511        for partition_num in 0..num_partitions {
512            let boundary = num_per_partition * (partition_num + 1);
513            while !partition_build.builders.is_empty() {
514                let part = partition_build.builders.remove(0);
515                match part.lock() {
516                    Ok(builder) => {
517                        if builder.total_len > 0 {
518                            partition_inputs[partition_num]
519                                .push((builder.path.clone(), builder.total_len));
520                            num_taken += builder.total_len;
521                        }
522                    }
523                    Err(e) => panic!("Can't get partition: {e}"),
524                }
525
526                // Let the last partition soak up the rest
527                if partition_num < num_partitions - 1 && num_taken > boundary {
528                    break;
529                }
530            }
531        }
532
533        // Ensure we got all the suffixes
534        if num_taken != partition_build.num_suffixes {
535            bail!(
536                "Took {num_taken} but needed to take {}",
537                partition_build.num_suffixes
538            );
539        }
540
541        let mut partitions: Vec<Option<Partition>> =
542            (0..num_partitions).map(|_| None).collect();
543
544        partitions.par_iter_mut().enumerate().try_for_each(
545            |(partition_num, partition)| -> Result<()> {
546                // Find the suffixes in this partition
547                let mut part_sa = vec![];
548                for (path, len) in &partition_inputs[partition_num] {
549                    let buffer = fs::read(path)?;
550                    let mut part: Vec<T> = slice_u8_to_vec(&buffer, *len);
551                    part_sa.append(&mut part);
552                    fs::remove_file(path)?;
553                }
554
555                let len = part_sa.len();
556                if len > 0 {
557                    let mut sa_w = part_sa.clone();
558                    let mut lcp = vec![T::default(); len];
559                    let mut lcp_w = vec![T::default(); len];
560                    self.merge_sort(&mut sa_w, &mut part_sa, len, &mut lcp, &mut lcp_w);
561
562                    // Write to disk
563                    let mut sa_file = NamedTempFile::new()?;
564                    let _ = sa_file.write(vec_to_slice_u8(&part_sa))?;
565                    let mut lcp_file = NamedTempFile::new()?;
566                    let _ = lcp_file.write(vec_to_slice_u8(&lcp))?;
567                    let (_, sa_path) = sa_file.keep()?;
568                    let (_, lcp_path) = lcp_file.keep()?;
569
570                    *partition = Some(Partition {
571                        order: partition_num,
572                        len,
573                        first_suffix: part_sa.first().unwrap().to_usize(),
574                        last_suffix: part_sa.last().unwrap().to_usize(),
575                        sa_path,
576                        lcp_path,
577                    });
578                }
579                Ok(())
580            },
581        )?;
582
583        // Get rid of None/unwrap Some, put in order
584        let mut partitions: Vec<_> = partitions.into_iter().flatten().collect();
585        partitions.sort_by_key(|p| p.order);
586
587        let sizes: Vec<_> = partitions.iter().map(|p| p.len).collect();
588        let total_size = sizes.iter().sum::<usize>();
589        info!(
590            "Sorted {total_size} suffixes in {num_partitions} partitions (avg {}) in {:?}",
591            total_size / num_partitions,
592            total_sort_time.elapsed()
593        );
594        self.num_suffixes = T::from_usize(sizes.iter().sum());
595        self.partitions = partitions;
596
597        Ok(())
598    }
599
600    // --------------------------------------------------
601    fn merge_sort(
602        &self,
603        x: &mut [T],
604        y: &mut [T],
605        n: usize,
606        lcp: &mut [T],
607        lcp_w: &mut [T],
608    ) {
609        if n == 1 {
610            lcp[0] = T::default();
611        } else {
612            let mid = n / 2;
613            self.merge_sort(
614                &mut y[..mid],
615                &mut x[..mid],
616                mid,
617                &mut lcp_w[..mid],
618                &mut lcp[..mid],
619            );
620
621            self.merge_sort(
622                &mut y[mid..],
623                &mut x[mid..],
624                n - mid,
625                &mut lcp_w[mid..],
626                &mut lcp[mid..],
627            );
628
629            self.merge(x, mid, lcp_w, y, lcp);
630        }
631    }
632
633    // --------------------------------------------------
    /// Merge two sorted runs of suffixes into `target_sa`, filling
    /// `target_lcp` with each output suffix's LCP with its predecessor.
    ///
    /// Args:
    /// * `suffix_array`: the two sorted runs, `[..mid]` and `[mid..]`
    /// * `mid`: the index where the second run starts
    /// * `lcp_w`: LCP values for `suffix_array`, where each entry is the LCP
    ///   with the previous suffix in the same run
    /// * `target_sa`: receives the merged suffixes
    /// * `target_lcp`: receives the LCP values for `target_sa`
    ///
    /// Suffixes that are equal over the comparison context are ordered with
    /// the shorter suffix (the larger position) first.
    fn merge(
        &self,
        suffix_array: &mut [T],
        mid: usize,
        lcp_w: &mut [T],
        target_sa: &mut [T],
        target_lcp: &mut [T],
    ) {
        // `x` and `y` start as the left and right runs. After output is taken
        // from `y`, the two sides (and their lengths, LCPs, and indices) are
        // swapped, so `x` always names the run that supplied the latest output.
        let (mut x, mut y) = suffix_array.split_at_mut(mid);
        let (mut lcp_x, mut lcp_y) = lcp_w.split_at_mut(mid);
        let mut len_x = x.len();
        let mut len_y = y.len();
        let mut m = T::default(); // LCP of y's head with the latest output (0 before any output)
        let mut idx_x = 0; // Index into x
        let mut idx_y = 0; // Index into y
        let mut idx_target = 0; // Index into target

        while idx_x < len_x && idx_y < len_y {
            // LCP of x's head with its predecessor in x, which is the latest
            // output because x supplied it
            let l_x = lcp_x[idx_x];

            match l_x.cmp(&m) {
                // x's head shares more with the latest output than y's head
                // does, so it sorts first
                Ordering::Greater => {
                    target_sa[idx_target] = x[idx_x];
                    target_lcp[idx_target] = l_x;
                }
                // y's head sorts first; its LCP with x's head is then l_x
                Ordering::Less => {
                    target_sa[idx_target] = y[idx_y];
                    target_lcp[idx_target] = m;
                    m = l_x;
                }
                // Undecided: compare the heads directly, skipping the `m`
                // characters both are known to share
                Ordering::Equal => {
                    // The larger position is the shorter suffix; max_n is the
                    // number of characters it has left
                    let shorter_suffix = max(x[idx_x], y[idx_y]);
                    let max_n = self.text_len - shorter_suffix;

                    // How far the comparison may extend
                    let context = match &self.sort_type {
                        SuffixSortType::Mask(seed_mask) => T::from_usize(
                            seed_mask
                                .positions
                                .iter()
                                .filter(|&i| *i < max_n.to_usize())
                                .count(),
                        ),
                        SuffixSortType::MaxQueryLen(max_query_len) => {
                            if max_query_len > &0 {
                                min(T::from_usize(*max_query_len), max_n)
                            } else {
                                max_n
                            }
                        }
                    };

                    // LCP(X_i, Y_j), plus its offset in the text (which
                    // differs from the LCP itself for seed masks)
                    let (len_lcp, full_len_lcp) = if m < context {
                        let lcp = self.find_lcp(
                            x[idx_x].to_usize(),
                            y[idx_y].to_usize(),
                            context - m,
                            m.to_usize(), // skip
                        );
                        let full_lcp =
                            find_lcp_full_offset(lcp.to_usize(), &self.sort_type);
                        (lcp, T::from_usize(full_lcp))
                    } else {
                        (context, context)
                    };

                    // If full LCP equals context/MQL, take shorter suffix
                    if len_lcp >= context {
                        target_sa[idx_target] = shorter_suffix;
                    }
                    // Else, look at the next char after the LCP to determine order.
                    else {
                        let cmp = self.text[(x[idx_x] + full_len_lcp).to_usize()]
                            .cmp(&self.text[(y[idx_y] + full_len_lcp).to_usize()]);

                        match cmp {
                            Ordering::Equal => {
                                target_sa[idx_target] = shorter_suffix;
                            }
                            Ordering::Less => {
                                target_sa[idx_target] = x[idx_x];
                            }
                            Ordering::Greater => {
                                target_sa[idx_target] = y[idx_y];
                            }
                        }
                    }

                    // Record the LCP with the previous output; in this branch
                    // l_x == m, so both arms store the same value
                    if target_sa[idx_target] == x[idx_x] {
                        target_lcp[idx_target] = l_x;
                    } else {
                        target_lcp[idx_target] = m
                    }

                    // The head left behind shares len_lcp with the one just output
                    m = len_lcp;
                }
            }

            // Advance the side that supplied the output; if that was y, swap
            // so that it becomes x
            if target_sa[idx_target] == x[idx_x] {
                idx_x += 1;
            } else {
                idx_y += 1;
                mem::swap(&mut x, &mut y);
                mem::swap(&mut len_x, &mut len_y);
                mem::swap(&mut lcp_x, &mut lcp_y);
                mem::swap(&mut idx_x, &mut idx_y);
            }
            idx_target += 1;
        }

        // Copy rest of the data from X to the target.
        while idx_x < len_x {
            target_sa[idx_target] = x[idx_x];
            target_lcp[idx_target] = lcp_x[idx_x];
            idx_x += 1;
            idx_target += 1;
        }

        // Copy rest of the data from Y to the target. The first leftover
        // suffix's LCP with the latest output is `m`; the rest keep their
        // in-run LCPs.
        if idx_y < len_y {
            target_sa[idx_target] = y[idx_y];
            target_lcp[idx_target] = m;
            idx_y += 1;
            idx_target += 1;

            while idx_y < len_y {
                target_sa[idx_target] = y[idx_y];
                target_lcp[idx_target] = lcp_y[idx_y];
                idx_y += 1;
                idx_target += 1;
            }
        }
    }
768
769    // --------------------------------------------------
770    #[inline(always)]
771    fn select_pivots(
772        &self,
773        text_len: usize,
774        num_partitions: usize,
775        random_seed: u64,
776    ) -> Vec<T> {
777        if num_partitions > 1 {
778            // Use a HashMap because selecting pivots one-at-a-time
779            // can result in duplicates.
780            let num_pivots = num_partitions - 1;
781            let mut rng: Box<dyn RngCore> = if random_seed > 0 {
782                Box::new(StdRng::seed_from_u64(random_seed))
783            } else {
784                Box::new(rand::rng())
785            };
786            let mut pivot_sa = HashSet::<T>::new();
787            loop {
788                let pos = rng.random_range(0..text_len);
789                if self.is_dna && !b"ACGT$".contains(&self.text[pos]) {
790                    continue;
791                }
792                let _ = pivot_sa.insert(T::from_usize(pos));
793                if pivot_sa.len() == num_pivots {
794                    break;
795                }
796            }
797
798            // Sort the selected pivots
799            let mut pivot_sa: Vec<T> = pivot_sa.iter().cloned().collect();
800            let mut sa_w = pivot_sa.clone();
801            let len = pivot_sa.len();
802            let mut lcp = vec![T::default(); len];
803            let mut lcp_w = vec![T::default(); len];
804            self.merge_sort(&mut sa_w, &mut pivot_sa, len, &mut lcp, &mut lcp_w);
805            pivot_sa
806        } else {
807            vec![]
808        }
809    }
810
811    // --------------------------------------------------
812    /// Serialize contents of the sorted partitions to a _.sufr_ file.
813    /// Returns the number of bytes written to disk.
814    ///
815    /// Args:
816    /// * `filename`: the name of the output file.
817    fn write(&self) -> Result<()> {
818        let filename = &self.path;
819        let mut file = BufWriter::new(
820            File::create(filename).map_err(|e| anyhow!("{filename}: {e}"))?,
821        );
822
823        let mut bytes_out: usize = 0;
824
825        // Various metadata
826        let is_dna: u8 = if self.is_dna { 1 } else { 0 };
827        let allow_ambiguity: u8 = if self.allow_ambiguity { 1 } else { 0 };
828        let ignore_softmask: u8 = if self.ignore_softmask { 1 } else { 0 };
829        bytes_out +=
830            file.write(&[OUTFILE_VERSION, is_dna, allow_ambiguity, ignore_softmask])?;
831
832        // Text length
833        bytes_out += file.write(&usize_to_bytes(self.text_len.to_usize()))?;
834
835        // Locations of text, suffix array, and LCP
836        // Will be corrected at the end
837        let locs_pos = file.stream_position()?;
838        bytes_out += file.write(&usize_to_bytes(0usize))?;
839        bytes_out += file.write(&usize_to_bytes(0usize))?;
840        bytes_out += file.write(&usize_to_bytes(0usize))?;
841
842        // Number of suffixes
843        bytes_out += file.write(&usize_to_bytes(self.num_suffixes.to_usize()))?;
844
845        // Max query length
846        let max_query_len = if let SuffixSortType::MaxQueryLen(val) = &self.sort_type {
847            *val
848        } else {
849            0
850        };
851        bytes_out += file.write(&usize_to_bytes(max_query_len))?;
852
853        // Number of sequences
854        bytes_out += file.write(&usize_to_bytes(self.sequence_starts.len()))?;
855
856        // Sequence starts
857        bytes_out += file.write(vec_to_slice_u8(&self.sequence_starts))?;
858
859        // Seed mask
860        match &self.sort_type {
861            SuffixSortType::Mask(seed_mask) => {
862                bytes_out += file.write(&usize_to_bytes(seed_mask.bytes.len()))?;
863                file.write_all(&seed_mask.bytes)?;
864                bytes_out += seed_mask.bytes.len();
865            }
866            _ => bytes_out += file.write(&usize_to_bytes(0))?,
867        }
868
869        // Text
870        let text_pos = bytes_out;
871        file.write_all(&self.text)?;
872        bytes_out += self.text.len();
873
874        // Stitch partitioned suffix files together
875        let sa_pos = bytes_out;
876        for partition in &self.partitions {
877            let buffer = fs::read(&partition.sa_path)?;
878            bytes_out += &buffer.len();
879            file.write_all(&buffer)?;
880            fs::remove_file(&partition.sa_path)?;
881        }
882
883        let lcp_pos = bytes_out;
884
885        // Stitch partitioned LCP files together
886        for (i, partition) in self.partitions.iter().enumerate() {
887            let buffer = fs::read(&partition.lcp_path)?;
888            bytes_out += &buffer.len();
889
890            if i == 0 {
891                file.write_all(&buffer)?;
892            } else {
893                // Fix LCP boundary
894                let mut lcp: Vec<T> = slice_u8_to_vec(&buffer, partition.len);
895                if let Some(val) = lcp.first_mut() {
896                    *val = self.find_lcp(
897                        self.partitions[i - 1].last_suffix,
898                        partition.first_suffix,
899                        self.text_len,
900                        0, // start at beginning
901                    );
902                }
903                file.write_all(vec_to_slice_u8(&lcp))?;
904            }
905            fs::remove_file(&partition.lcp_path)?;
906        }
907
908        // Sequence names are variable in length so they are at the end
909        _ = file.write(&bincode::serialize(&self.sequence_names)?)?;
910
911        // Go back to header and record the locations
912        file.seek(SeekFrom::Start(locs_pos))?;
913        let _ = file.write(&usize_to_bytes(text_pos))?;
914        let _ = file.write(&usize_to_bytes(sa_pos))?;
915        let _ = file.write(&usize_to_bytes(lcp_pos))?;
916
917        Ok(())
918    }
919}
920
921// --------------------------------------------------
/// Represents the partition values written to disk: where one sorted
/// partition's suffix and LCP files live, plus the bookkeeping needed to
/// stitch adjacent partitions together.
#[derive(Debug)]
struct Partition {
    /// The sorted position of this partition.
    order: usize,

    /// The number of suffixes/LCP values contained in this partition.
    len: usize,

    /// The value of the first suffix. Used in stitching together the LCPs
    /// (compared with the preceding partition's `last_suffix`).
    first_suffix: usize,

    /// The value of the last suffix. Used in stitching together the LCPs
    /// (compared with the following partition's `first_suffix`).
    last_suffix: usize,

    /// The path to the file containing this partition's suffix array.
    sa_path: PathBuf,

    /// The path to the file containing this partition's LCP array.
    lcp_path: PathBuf,
}
943
944// --------------------------------------------------
/// This struct provides access to the on-disk partitions via the
/// `PartitionBuilder`s that wrote them.
#[derive(Debug)]
struct PartitionBuildResult<T>
where
    T: Int + FromUsize<T> + Sized + Send + Sync + serde::ser::Serialize,
{
    /// The `PartitionBuilder` values. Each one is wrapped in `Arc<Mutex<_>>`
    /// so it can be shared and mutated across threads; the `Vec` itself is
    /// not synchronized. Presumably one builder per partition (TODO: confirm
    /// against the caller).
    builders: Vec<Arc<Mutex<PartitionBuilder<T>>>>,

    /// The total number of suffixes that were written to disk.
    num_suffixes: usize,
}
957
958// --------------------------------------------------
/// A struct for writing suffixes to disk.
///
/// Suffix positions are buffered in `vals` and appended to the file at
/// `path` whenever the buffer fills (see `add` and `write`).
#[derive(Debug)]
struct PartitionBuilder<T>
where
    T: Int + FromUsize<T> + Sized + Send + Sync + serde::ser::Serialize,
{
    /// Fixed-size buffer of `capacity` values, reused between flushes; only
    /// the first `len` entries are meaningful.
    vals: Vec<T>,
    /// Number of values buffered before `add` flushes them to disk.
    capacity: usize,
    /// Number of valid entries at the front of `vals`.
    len: usize,
    /// Running total of values appended to the file at `path`.
    total_len: usize,
    /// Path of the temporary file that the values are appended to.
    path: PathBuf,
}
971
972// --------------------------------------------------
973impl<T> PartitionBuilder<T>
974where
975    T: Int + FromUsize<T> + Sized + Send + Sync + serde::ser::Serialize,
976{
977    /// Create a new `PartitionBuilder`. This struct is used to write it's
978    /// suffix positions to a temporary file.
979    ///
980    /// Args:
981    /// * `capacity`: the number of suffixes to hold in memory until the
982    ///   writing to disk. This minimizes the number of times we access the disk
983    ///   while also limiting the amount of memory used. Currently set to 4096
984    ///   but it might be worth tuning this, perhaps use more memory to hit
985    ///   disk less? Or if memory use is too high, lower and take a performance
986    ///   hit for disk access?
987    fn new(capacity: usize) -> Result<Self> {
988        let tmp = NamedTempFile::new()?;
989        let (_, path) = tmp.keep()?;
990
991        Ok(PartitionBuilder {
992            // Re-use a static vector to avoid repeated allocations
993            vals: vec![T::default(); capacity],
994            len: 0,
995            total_len: 0,
996            capacity,
997            path,
998        })
999    }
1000
1001    /// Add a suffix to the partition. When the internal array of values hits
1002    /// `capacity`, then write all the values to disk.
1003    ///
1004    /// Args:
1005    /// * `val`: the suffix position to add
1006    pub fn add(&mut self, val: T) -> Result<()> {
1007        self.vals[self.len] = val;
1008        self.len += 1;
1009        if self.len == self.capacity {
1010            self.write()?;
1011            self.len = 0;
1012        }
1013
1014        Ok(())
1015    }
1016
1017    /// Write the suffixes to disk. This must be called at the end to flush
1018    /// any remaining values after the last call(s) from `add`.
1019    pub fn write(&mut self) -> Result<()> {
1020        if self.len > 0 {
1021            let mut file = OpenOptions::new()
1022                .create(true)
1023                .append(true)
1024                .open(&self.path)?;
1025            file.write_all(vec_to_slice_u8(&self.vals[0..self.len]))?;
1026            self.total_len += self.len;
1027        }
1028        Ok(())
1029    }
1030}
1031
1032// --------------------------------------------------
#[cfg(test)]
mod test {
    use super::{SufrBuilder, SufrBuilderArgs};
    use anyhow::Result;
    use pretty_assertions::assert_eq;
    use std::fs;
    use tempfile::NamedTempFile;

    #[test]
    fn test_is_less() -> Result<()> {
        //           012345
        let text = b"TTTAGC".to_vec();
        let outfile = NamedTempFile::new()?;
        let args = SufrBuilderArgs {
            text,
            low_memory: true,
            path: Some(outfile.path().to_string_lossy().to_string()),
            max_query_len: None,
            is_dna: false,
            allow_ambiguity: false,
            ignore_softmask: false,
            sequence_starts: vec![0],
            sequence_names: vec!["1".to_string()],
            num_partitions: 2,
            seed_mask: None,
            random_seed: 0,
        };
        let sufr = SufrBuilder::<u32>::new(args)?;

        // 1: TTAGC
        // 0: TTTAGC
        assert!(sufr.is_less(1, 0));

        // 0: TTTAGC
        // 1: TTAGC
        assert!(!sufr.is_less(0, 1));

        // 2: TAGC
        // 3: AGC
        assert!(!sufr.is_less(2, 3));

        // 3: AGC
        // 0: TTTAGC
        assert!(sufr.is_less(3, 0));

        fs::remove_file(outfile)?;

        Ok(())
    }

    #[test]
    fn test_is_less_max_query_len() -> Result<()> {
        //           012345
        let text = b"TTTAGC".to_vec();
        let outfile = NamedTempFile::new()?;
        let args = SufrBuilderArgs {
            text,
            low_memory: true,
            path: Some(outfile.path().to_string_lossy().to_string()),
            max_query_len: Some(2),
            is_dna: false,
            allow_ambiguity: false,
            ignore_softmask: false,
            sequence_starts: vec![0],
            sequence_names: vec!["1".to_string()],
            num_partitions: 2,
            seed_mask: None,
            random_seed: 0,
        };
        let sufr = SufrBuilder::<u32>::new(args)?;

        // 1: TTAGC
        // 0: TTTAGC
        // This is true w/o MQL 2 but here they are equal
        // ("TT" == "TT")
        assert!(!sufr.is_less(1, 0));

        // 0: TTTAGC
        // 1: TTAGC
        // ("TT" == "TT")
        assert!(!sufr.is_less(0, 1));

        // 2: TAGC
        // 3: AGC
        assert!(!sufr.is_less(2, 3));

        // 3: AGC
        // 0: TTTAGC
        assert!(sufr.is_less(3, 0));

        fs::remove_file(outfile)?;

        Ok(())
    }

    #[test]
    fn test_is_less_seed_mask() -> Result<()> {
        //           012345
        let text = b"TTTTAT".to_vec();
        let outfile = NamedTempFile::new()?;
        let args = SufrBuilderArgs {
            text,
            low_memory: true,
            path: Some(outfile.path().to_string_lossy().to_string()),
            max_query_len: None,
            is_dna: false,
            allow_ambiguity: false,
            ignore_softmask: false,
            sequence_starts: vec![0],
            sequence_names: vec!["1".to_string()],
            num_partitions: 2,
            seed_mask: Some("101".to_string()),
            random_seed: 0,
        };
        let sufr: SufrBuilder<u32> = SufrBuilder::new(args)?;

        // 0: TTTTAT
        // 1: TTTAT
        // "T-T" vs "T-T"
        assert!(!sufr.is_less(0, 1));

        // 1: TTTAT
        // 0: TTTTAT
        // "T-T" vs "T-T"
        assert!(!sufr.is_less(1, 0));

        // 0: TTTTAT
        // 3: TAT
        // "T-T" vs "T-T"
        assert!(!sufr.is_less(0, 3));

        // 3: TAT
        // 0: TTTTAT
        // "T-T" vs "T-T"
        assert!(!sufr.is_less(3, 0));

        fs::remove_file(outfile)?;

        Ok(())
    }

    #[test]
    fn test_find_lcp_no_seed_mask() -> Result<()> {
        //           012345
        let text = b"TTTAGC".to_vec();
        let outfile = NamedTempFile::new()?;
        let args = SufrBuilderArgs {
            text,
            low_memory: true,
            path: Some(outfile.path().to_string_lossy().to_string()),
            max_query_len: None,
            is_dna: false,
            allow_ambiguity: false,
            ignore_softmask: false,
            sequence_starts: vec![0],
            sequence_names: vec!["1".to_string()],
            num_partitions: 2,
            seed_mask: None,
            random_seed: 0,
        };
        let sufr: SufrBuilder<u32> = SufrBuilder::new(args)?;

        // 0: TTTAGC
        // 1:  TTAGC
        // 6: len of text
        assert_eq!(sufr.find_lcp(0, 1, 6, 0), 2);

        // 0: TTTAGC
        // 2:   TAGC
        // 6: len of text
        assert_eq!(sufr.find_lcp(0, 2, 6, 0), 1);

        // 0: TTTAGC
        // 1:  TTAGC
        // 1: max query len = 1
        assert_eq!(sufr.find_lcp(0, 1, 1, 0), 1);

        // 0: TTTAGC
        // 3:    AGC
        // 6: len of text
        assert_eq!(sufr.find_lcp(0, 3, 6, 0), 0);

        // TODO: Add a test with skip

        fs::remove_file(outfile)?;

        Ok(())
    }

    #[test]
    fn test_find_lcp_with_seed_mask() -> Result<()> {
        //           012345
        let text = b"TTTTTA".to_vec();
        let outfile = NamedTempFile::new()?;
        let args = SufrBuilderArgs {
            text,
            low_memory: true,
            path: Some(outfile.path().to_string_lossy().to_string()),
            max_query_len: None,
            is_dna: false,
            allow_ambiguity: false,
            ignore_softmask: false,
            sequence_starts: vec![0],
            sequence_names: vec!["1".to_string()],
            num_partitions: 2,
            seed_mask: Some("1101".to_string()),
            random_seed: 42,
        };
        let sufr: SufrBuilder<u32> = SufrBuilder::new(args)?;

        // 0: TTTTTA
        // 1:  TTTTA
        assert_eq!(sufr.find_lcp(0, 1, 3, 0), 3);

        // 0: TTTTTA
        // 2:   TTTA
        assert_eq!(sufr.find_lcp(0, 2, 3, 0), 2);

        // 0: TTTTTA
        // 5:      A
        assert_eq!(sufr.find_lcp(0, 5, 3, 0), 0);

        fs::remove_file(outfile)?;

        Ok(())
    }

    #[test]
    fn test_upper_bound_1() -> Result<()> {
        //           012345
        let text = b"TTTAGC".to_vec();
        let outfile = NamedTempFile::new()?;
        let args = SufrBuilderArgs {
            text,
            low_memory: true,
            path: Some(outfile.path().to_string_lossy().to_string()),
            max_query_len: None,
            is_dna: false,
            allow_ambiguity: false,
            ignore_softmask: false,
            sequence_starts: vec![0],
            sequence_names: vec!["1".to_string()],
            num_partitions: 2,
            seed_mask: None,
            random_seed: 42,
        };
        let sufr: SufrBuilder<u32> = SufrBuilder::new(args)?;

        // The suffix "AGC$" is found before "GC$" and "C$"
        assert_eq!(sufr.upper_bound(3, &[5, 4]), 0);

        // The suffix "TAGC$" is beyond all the values
        assert_eq!(sufr.upper_bound(2, &[3, 4, 5]), 3);

        // The "C$" is the last value
        assert_eq!(sufr.upper_bound(5, &[3, 4, 5]), 1);

        fs::remove_file(outfile)?;

        Ok(())
    }

    #[test]
    fn test_upper_bound_2() -> Result<()> {
        //           0123456789
        let text = b"ACGTNNACGT".to_vec();
        let outfile = NamedTempFile::new()?;
        let args = SufrBuilderArgs {
            text,
            low_memory: true,
            path: Some(outfile.path().to_string_lossy().to_string()),
            max_query_len: None,
            is_dna: false,
            allow_ambiguity: false,
            ignore_softmask: false,
            sequence_starts: vec![0],
            sequence_names: vec!["1".to_string()],
            num_partitions: 2,
            seed_mask: None,
            random_seed: 42,
        };

        let sufr: SufrBuilder<u64> = SufrBuilder::new(args)?;

        // ACGTNNACGT$ == ACGTNNACGT$
        assert_eq!(sufr.upper_bound(0, &[0]), 0);

        // ACGTNNACGT$ (0) > ACGT$ (6)
        assert_eq!(sufr.upper_bound(0, &[6]), 1);

        // ACGT$ < ACGTNNACGT$
        assert_eq!(sufr.upper_bound(6, &[0]), 0);

        // ACGT$ == ACGT$
        assert_eq!(sufr.upper_bound(6, &[6]), 0);

        // Pivots = [CGT$, GT$]
        // ACGTNNACGT$ < CGT$ => p0
        assert_eq!(sufr.upper_bound(0, &[7, 8]), 0);

        // CGTNNACGT$ > CGT$  => p1
        assert_eq!(sufr.upper_bound(1, &[7, 8]), 1);

        // GT$ == GT$  => p1
        assert_eq!(sufr.upper_bound(8, &[7, 8]), 1);

        // T$ > GT$  => p2
        assert_eq!(sufr.upper_bound(9, &[7, 8]), 2);

        // T$ < TNNACGT$ => p0
        assert_eq!(sufr.upper_bound(9, &[3]), 0);

        fs::remove_file(outfile)?;

        Ok(())
    }

    #[test]
    fn test_upper_bound_seed_mask() -> Result<()> {
        //           0123456789
        let text = b"ACGTNNACGT".to_vec();
        let outfile = NamedTempFile::new()?;
        let args = SufrBuilderArgs {
            text,
            low_memory: true,
            path: Some(outfile.path().to_string_lossy().to_string()),
            max_query_len: None,
            is_dna: false,
            allow_ambiguity: false,
            ignore_softmask: false,
            sequence_starts: vec![0],
            sequence_names: vec!["1".to_string()],
            num_partitions: 2,
            seed_mask: Some("101".to_string()),
            random_seed: 42,
        };
        let sufr: SufrBuilder<u32> = SufrBuilder::new(args)?;

        // ACGTNNACGT$ == ACGTNNACGT$ (A-G)
        assert_eq!(sufr.upper_bound(0, &[0]), 0);

        // ACGTNNACGT$ == ACGT$ (A-G)
        assert_eq!(sufr.upper_bound(0, &[6]), 0);

        // ACGT$ == ACGTNNACGT$ (A-G)
        assert_eq!(sufr.upper_bound(6, &[0]), 0);

        // ACGT$ == ACGT$ (A-G)
        assert_eq!(sufr.upper_bound(6, &[6]), 0);

        // Pivots = [CGT$, GT$]
        // ACGTNNACGT$ < CGT$
        assert_eq!(sufr.upper_bound(0, &[7, 8]), 0);

        // Pivots = [CGT$, GT$]
        // CGTNNACGT$ == CGT$ (C-T)
        assert_eq!(sufr.upper_bound(1, &[7, 8]), 0);

        // Pivots = [CGT$, GT$]
        // GT$ == GT$
        assert_eq!(sufr.upper_bound(8, &[7, 8]), 1);

        // Pivots = [CGT$, GT$]
        // T$ > GT$  => p2
        assert_eq!(sufr.upper_bound(9, &[7, 8]), 2);

        // T$ == TNNACGT$ (only compare T)
        assert_eq!(sufr.upper_bound(9, &[3]), 0);

        fs::remove_file(outfile)?;

        Ok(())
    }
}