Skip to main content

gbz_base/
utils.rs

1//! Utility functions and structures.
2
3use std::collections::HashMap;
4use std::fs::File;
5use std::ops::{Range, RangeInclusive};
6use std::path::{Path, PathBuf};
7use std::io::{self, BufRead, BufReader};
8
9use flate2::read::MultiGzDecoder;
10
11use gbz::{GBWT, Pos, ENDMARKER};
12use pggname::GraphName;
13use simple_sds::binaries;
14
15//-----------------------------------------------------------------------------
16
17/// Returns the full file name for a specific test file.
18pub fn get_test_data(filename: &'static str) -> PathBuf {
19    let mut buf = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
20    buf.push("test-data");
21    buf.push(filename);
22    buf
23}
24
25//-----------------------------------------------------------------------------
26
27/// Returns a human-readable string representation of a size in bytes.
28///
29/// Reports the size using three decimal places.
30pub fn human_readable_size(size: usize) -> String {
31    let (size, unit) = binaries::human_readable_size(size);
32    format!("{:.3} {}", size, unit)
33}
34
35/// Returns a human-readable string representation of the file size in bytes.
36///
37/// Reports the size using three decimal places.
38/// Returns [`None`] if the file does not exist or the size cannot be determined.
39/// See [`simple_sds::binaries::file_size`] and [`simple_sds::binaries::human_readable_size`] for further information.
40pub fn file_size<P: AsRef<Path>>(filename: P) -> Option<String> {
41    let (size, unit) = binaries::file_size(filename)?;
42    Some(format!("{:.3} {}", size, unit))
43}
44
45/// Prints the peak resident set size to stderr if it can be determined.
46///
47/// Reports the size using three decimal places.
48pub fn report_peak_memory_usage() {
49    let peak_memory = binaries::peak_memory_usage();
50    if peak_memory.is_err() {
51        return;
52    }
53    let (size, unit) = binaries::human_readable_size(peak_memory.unwrap());
54    eprintln!("Peak memory usage: {:.3} {}", size, unit);
55}
56
/// Returns `true` if the reader appears to be gzip-compressed.
///
/// # Errors
///
/// Passes through all I/O errors from the reader.
pub fn is_gzipped<R: BufRead>(reader: &mut R) -> io::Result<bool> {
    // Every gzip stream starts with the two-byte magic number 0x1F 0x8B.
    const GZIP_MAGIC: [u8; 2] = [0x1F, 0x8B];
    // `fill_buf` peeks at buffered data without consuming it, so the caller
    // can still read the stream from the beginning.
    let buffer = reader.fill_buf()?;
    Ok(buffer.starts_with(&GZIP_MAGIC))
}
67
68/// Returns a buffered reader for the file, which may be gzip-compressed.
69///
70/// Use `-` as the file name to read from standard input.
71///
72/// # Errors
73///
74/// Passes through any I/O errors from trying to open and read the file.
75pub fn open_file<P: AsRef<Path>>(filename: P) -> Result<Box<dyn BufRead>, String> {
76    let mut inner = if filename.as_ref() == Path::new("-") {
77        Box::new(BufReader::new(io::stdin())) as Box<dyn BufRead>
78    } else {
79        let file = File::open(&filename).map_err(|x| format!("Failed to open file {}: {}", filename.as_ref().display(), x))?;
80        Box::new(BufReader::new(file)) as Box<dyn BufRead>
81    };
82    if is_gzipped(&mut inner).map_err(|x| format!("Failed to read file {}: {}", filename.as_ref().display(), x))? {
83        let gz_inner = MultiGzDecoder::new(inner);
84        Ok(Box::new(BufReader::new(gz_inner)))
85    } else {
86        Ok(inner)
87    }
88}
89
90//-----------------------------------------------------------------------------
91
92// Working with `Vec<u8>` buffers.
93
/// Appends the decimal representation of an unsigned integer to a string represented as `Vec<u8>`.
pub fn append_usize(buffer: &mut Vec<u8>, value: usize) {
    buffer.extend(value.to_string().into_bytes());
}
98
/// Appends the decimal representation of a signed integer to a string represented as `Vec<u8>`.
pub fn append_isize(buffer: &mut Vec<u8>, value: isize) {
    buffer.extend(value.to_string().into_bytes());
}
103
104//-----------------------------------------------------------------------------
105
106// Sequence encoding and decoding.
107
// TODO: Precompute the decoding table for a byte.
// Decoding table: code 0 is a sentinel marking the end of a partially filled byte;
// codes 1..=5 map to the bases A, C, G, T, N (inverse of the table built by `generate_encoding`).
const DECODE: [u8; 6] = [0, b'A', b'C', b'G', b'T', b'N'];
110
111/// Decodes a single base encoded with [`encode_base`].
112///
113/// # Panics
114///
115/// Panics if `encoded > 5`.
116#[inline]
117pub fn decode_base(encoded: usize) -> u8 {
118    DECODE[encoded]
119}
120
121/// Decodes a sequence encoded with [`encode_sequence`].
122pub fn decode_sequence(encoded: &[u8]) -> Vec<u8> {
123    let capacity = if encoded.is_empty() { 0 } else { 3 * encoded.len() };
124    let mut result = Vec::with_capacity(capacity);
125
126    for byte in encoded {
127        let mut value = *byte as usize;
128        for _ in 0..3 {
129            let decoded = DECODE[value % DECODE.len()];
130            if decoded == 0 {
131                return result;
132            }
133            value /= DECODE.len();
134            result.push(decoded);
135        }
136    }
137
138    result
139}
140
// Builds the 256-entry encoding table at compile time:
// `acgt` / `ACGT` map to codes 1..=4; every other byte maps to 5 (`N`).
const fn generate_encoding() -> [u8; 256] {
    let mut table = [5u8; 256];
    let bases = [b'a', b'c', b'g', b't'];
    let mut i = 0;
    while i < bases.len() {
        let code = (i + 1) as u8;
        table[bases[i] as usize] = code;
        table[bases[i].to_ascii_uppercase() as usize] = code;
        i += 1;
    }
    table
}
149
// Encoding table: maps `acgtACGT` to codes 1..=4 and every other byte value to 5 (`N`).
const ENCODE: [u8; 256] = generate_encoding();
151
152/// Encodes a single base.
153///
154/// Use [`decode_base`] to decode.
155#[inline]
156pub fn encode_base(base: u8) -> usize {
157    ENCODE[base as usize] as usize
158}
159
160/// Encodes a DNA sequence into a byte array, storing three bases in a byte.
161///
162/// Values outside `acgtACGT` are encoded as `N`.
163/// The last encoded symbol may be a special 0 character in order to preserve the length.
164/// This sentinel is not used when the length is a multiple of 3.
165/// Use [`decode_sequence`] to decode the sequence.
166pub fn encode_sequence(sequence: &[u8]) -> Vec<u8> {
167    let mut result: Vec<u8> = Vec::with_capacity(encoded_length(sequence.len()));
168
169    let mut offset = 0;
170    while offset + 3 <= sequence.len() {
171        let byte = ENCODE[sequence[offset] as usize] +
172            6 * ENCODE[sequence[offset + 1] as usize] +
173            36 * ENCODE[sequence[offset + 2] as usize];
174        result.push(byte);
175        offset += 3;
176    }
177    if sequence.len() - offset == 1 {
178        let byte = ENCODE[sequence[offset] as usize];
179        result.push(byte);
180    } else if sequence.len() - offset == 2 {
181        let byte = ENCODE[sequence[offset] as usize] + 6 * ENCODE[sequence[offset + 1] as usize];
182        result.push(byte);
183    }
184
185    result
186}
187
/// Returns the length of the encoding for a sequence of the given length.
pub fn encoded_length(sequence_length: usize) -> usize {
    // Ceiling division by 3: one byte per started triplet of bases.
    let full = sequence_length / 3;
    if sequence_length % 3 == 0 { full } else { full + 1 }
}
192
193//-----------------------------------------------------------------------------
194
195/// Returns an error if the given graph is not a valid reference for the given alignments.
196///
197/// The comparison is based on the provided [`GraphName`] objects.
198/// If either graph name is missing, no error is returned.
199/// Otherwise the graph name for the alignments must be a subgraph of the reference graph.
200pub fn require_valid_reference(alignments: &GraphName, reference: &GraphName) -> Result<(), String> {
201    if !alignments.has_name() || !reference.has_name() {
202        return Ok(());
203    }
204    if !alignments.is_subgraph_of(reference) {
205        let description = alignments.describe_relationship(reference, "alignments", "reference graph");
206        return Err(format!("The graph is not a valid reference for the alignments:\n{}", description));
207    }
208    Ok(())
209}
210
211//-----------------------------------------------------------------------------
212
#[derive(Clone, Debug)]
struct NodeIdCluster {
    // Inclusive range of node ids in the cluster.
    node_id_range: RangeInclusive<usize>,
    // Range of indices in the original node id array.
    array_range: Range<usize>,
    // Array offset immediately after the largest gap, or `None` if the cluster has no gaps
    // (i.e. it covers fewer than two node ids).
    max_gap_offset: Option<usize>,
}

impl NodeIdCluster {
    // Returns a new cluster covering the given range in the node id array.
    // Assumes sorted and deduplicated node ids.
    // Returns `None` if the array or the range is empty, or if the range is out of bounds.
    fn new(node_ids: &[usize], array_range: Range<usize>) -> Option<Self> {
        if node_ids.is_empty() || array_range.is_empty() || array_range.end > node_ids.len() {
            return None;
        }

        let first = node_ids[array_range.start];
        let last = node_ids[array_range.end - 1];

        // Find the first offset that follows a gap of maximal length.
        // With sorted, deduplicated ids every gap is > 0, so any range of at
        // least two ids records an offset.
        let mut max_gap_length = 0;
        let mut max_gap_offset = None;
        for i in (array_range.start + 1)..array_range.end {
            let gap = node_ids[i] - node_ids[i - 1];
            if gap > max_gap_length {
                max_gap_length = gap;
                max_gap_offset = Some(i);
            }
        }

        Some(Self {
            node_id_range: first..=last,
            array_range,
            max_gap_offset,
        })
    }

    // Returns the length of the largest gap between successive node ids in the cluster,
    // or `None` if the cluster contains no gaps.
    fn max_gap_length(&self, node_ids: &[usize]) -> Option<usize> {
        let offset = self.max_gap_offset?;
        Some(node_ids[offset] - node_ids[offset - 1])
    }

    // Splits the cluster into two at the largest gap, if any.
    // The return values are the cluster before the gap and the cluster after the gap.
    fn split(self, node_ids: &[usize]) -> (Option<Self>, Option<Self>) {
        // let-else replaces the earlier `is_none()` + `unwrap()` pattern.
        let Some(offset) = self.max_gap_offset else {
            return (Some(self), None);
        };
        let left = NodeIdCluster::new(node_ids, self.array_range.start..offset);
        let right = NodeIdCluster::new(node_ids, offset..self.array_range.end);
        (left, right)
    }
}
274
275// TODO: If we stick to a constant threshold, we could determine the final clusters in a single pass.
276/// Returns a set of closed ranges that cover all node identifiers in the given set.
277///
278/// Initially there is a single cluster containing all node ids.
279/// Each cluster is recursively split at the longest gap between successive identifiers.
280/// The recursion stops when the length of the longest gap is at most `threshold`.
281/// This can be useful for partitioning a [`crate::Subgraph`] into multiple ranges before querying [`crate::GAFBase`].
282///
283/// # Examples
284///
285/// ```
286/// use gbz_base::utils;
287///
288/// let node_ids = vec![1, 2, 4, 6, 30, 31, 35];
289/// let threshold = 10;
290/// let clusters = utils::cluster_node_ids(node_ids, threshold);
291/// assert_eq!(clusters.len(), 2);
292/// assert_eq!(clusters[0], 1..=6);
293/// assert_eq!(clusters[1], 30..=35);
294/// ```
295pub fn cluster_node_ids(node_ids: Vec<usize>, threshold: usize) -> Vec<RangeInclusive<usize>> {
296    let mut node_ids = node_ids;
297    node_ids.sort_unstable();
298    node_ids.dedup();
299
300    let mut stack: Vec<NodeIdCluster> = Vec::new();
301    let mut result: Vec<RangeInclusive<usize>> = Vec::new();
302    let initial = NodeIdCluster::new(&node_ids, 0..node_ids.len());
303    if initial.is_none() {
304        return result;
305    }
306    stack.push(initial.unwrap());
307
308    while let Some(curr) = stack.pop() {
309        if let Some(len) = curr.max_gap_length(&node_ids) {
310            if len > threshold {
311                let (left, right) = curr.split(&node_ids);
312                if let Some(right) = right {
313                    stack.push(right);
314                }
315                if let Some(left) = left {
316                    stack.push(left);
317                }
318            } else {
319                result.push(curr.node_id_range);
320            }
321        } else {
322            result.push(curr.node_id_range);
323        }
324    }
325
326    result
327}
328
329//-----------------------------------------------------------------------------
330
/// A structure that determines the GBWT starting positions for paths in a graph.
///
/// The starting positions are iterated in order.
/// If a GBWT index is provided, the starting positions are determined from the paths in the index.
/// Otherwise the starting positions for a unidirectional index are computed on the fly.
pub enum PathStartSource<'a> {
    /// A GBWT index of the paths and the next sequence id that has not been iterated.
    Index(&'a GBWT, usize),
    /// A map storing the number of paths starting from each node so far.
    /// The stored count for a node is used as the offset of the next path starting from it.
    Map(HashMap<usize, usize>),
}
342
343impl<'a> PathStartSource<'a> {
344    /// Returns a new source that computes the starting positions on the fly.
345    pub fn new() -> Self {
346        Self::default()
347    }
348
349    /// Returns the starting position of the next path.
350    ///
351    /// If the source is a GBWT index, the provided node identifier is ignored.
352    /// Returns [`None`] if the path is empty or all paths in the index have been iterated.
353    ///
354    /// If the source is a map, returns the starting position that would be assigned to the next path starting from the given node.
355    /// Returns [`None`] if the path is empty (if the node identifier is [`ENDMARKER`]).
356    pub fn next(&mut self, node_id: usize) -> Option<Pos> {
357        match self {
358            Self::Index(index, seq_id) => {
359                let result = index.start(*seq_id);
360                *seq_id += if index.is_bidirectional() { 2 } else { 1 };
361                result
362            }
363            Self::Map(map) => {
364                if node_id == ENDMARKER {
365                    return None;
366                }
367                let count = map.entry(node_id).or_insert(0);
368                let result = Pos::new(node_id, *count);
369                *count += 1;
370                Some(result)
371            }
372        }
373    }
374}
375
376impl<'a> Default for PathStartSource<'a> {
377    fn default() -> Self {
378        Self::Map(HashMap::new())
379    }
380}
381
382impl<'a> From<&'a GBWT> for PathStartSource<'a> {
383    fn from(index: &'a GBWT) -> Self {
384        Self::Index(index, 0)
385    }
386}
387
388//-----------------------------------------------------------------------------
389
#[cfg(test)]
mod tests {
    use super::*;

    use gbz::support::{self, Orientation};
    use simple_sds::serialize;

    // Round-trip test for sequence encoding: every prefix length of the test
    // sequence is used, covering all tail lengths (0, 1, 2 bases) and `N` bases.
    #[test]
    fn sequence_encoding() {
        let full_sequence = b"GATTACACACCAGATNNNNNACATTGAACCTTACACAGTCTGAC";
        for i in 0..full_sequence.len() {
            let sequence = &full_sequence[0..i];
            let encoded = encode_sequence(sequence);
            let decoded = decode_sequence(&encoded);
            assert_eq!(decoded, sequence, "Wrong sequence encoding for length {}", i);
        }
    }

    // Checks that `cluster_node_ids` produces exactly the expected ranges
    // for the given input and gap threshold.
    fn test_cluster(node_ids: Vec<usize>, expected: Vec<RangeInclusive<usize>>, gap_threshold: usize, test_case: &str) {
        let clusters = cluster_node_ids(node_ids, gap_threshold);
        assert_eq!(clusters.len(), expected.len(), "Wrong number of clusters for {}", test_case);
        for (i, cluster) in clusters.iter().enumerate() {
            assert_eq!(cluster, &expected[i], "Wrong cluster {} for {}", i, test_case);
        }
    }

    #[test]
    fn cluster_node_ids_test() {
        let node_ids = Vec::new();
        let expected = Vec::new();
        test_cluster(node_ids, expected, 10, "empty");

        let node_ids = vec![5];
        let expected = vec![5..=5];
        test_cluster(node_ids, expected, 10, "single node");

        let node_ids = vec![5, 6, 7, 8, 9];
        let expected = vec![5..=9];
        test_cluster(node_ids, expected, 10, "continuous nodes");

        // Input order should not matter, as the ids are sorted first.
        let node_ids = vec![6, 9, 7, 5, 8];
        let expected = vec![5..=9];
        test_cluster(node_ids, expected, 10, "continuous nodes unsorted");

        let node_ids = vec![5, 7, 9];
        let expected = vec![5..=9];
        test_cluster(node_ids, expected, 10, "equal gaps");

        let node_ids = vec![5, 6, 7, 20, 21, 22];
        let expected = vec![5..=7, 20..=22];
        test_cluster(node_ids, expected, 10, "two clusters");

        let node_ids = vec![1, 50, 52, 53, 63, 64, 200];
        let expected = vec![1..=1, 50..=64, 200..=200];
        test_cluster(node_ids, expected, 10, "one cluster and outliers");

        let node_ids = vec![1, 50, 52, 53, 73, 74, 200];
        let expected = vec![1..=1, 50..=53, 73..=74, 200..=200];
        test_cluster(node_ids, expected, 10, "two clusters and outliers");
    }

    // Compares `PathStartSource` output against `GBWT::start` for both variants,
    // using a mix of unidirectional and bidirectional test indexes.
    #[test]
    fn path_start_source() {
        let gbwt_files = vec![
            get_test_data("micb-kir3dl1_HG003.gbwt"),
            get_test_data("bidirectional.gbwt"),
            get_test_data("empty.gbwt"),
            support::get_test_data("example.gbwt"),
            support::get_test_data("translation.gbwt"),
            support::get_test_data("with-empty.gbwt"),
        ];

        for gbwt_file in gbwt_files.iter() {
            let index = serialize::load_from(gbwt_file);
            assert!(index.is_ok(), "Failed to load GBWT index from {}: {}", gbwt_file.display(), index.unwrap_err());
            let index: GBWT = index.unwrap();

            // A bidirectional index stores two sequences (orientations) per path.
            let paths = if index.is_bidirectional() {
                index.sequences() / 2
            } else {
                index.sequences()
            };
            let mut index_source = PathStartSource::from(&index);
            let mut map_source = PathStartSource::new();
            for path_id in 0..paths {
                let seq_id = if index.is_bidirectional() { support::encode_path(path_id, Orientation::Forward) } else { path_id };
                let truth = index.start(seq_id);
                let node_id = truth.map(|pos| pos.node).unwrap_or(ENDMARKER);
                let index_pos = index_source.next(0);
                assert_eq!(index_pos, truth, "Wrong path start from index for path {} in {}", path_id, gbwt_file.display());
                // The map source only matches the truth for unidirectional indexes.
                if !index.is_bidirectional() {
                    let map_pos = map_source.next(node_id);
                    assert_eq!(map_pos, truth, "Wrong path start from map for path {} in {}", path_id, gbwt_file.display());
                }
            }
        }
    }
}
488
489//-----------------------------------------------------------------------------