// rumi_lib/lib.rs

1use basebits::{hamming_dist_none, BaseBits};
2//use rayon::iter::ParBridge;
3use rayon::prelude::*;
4use rust_htslib::bam::errors::Error;
5use rust_htslib::bam::record::{Aux, Cigar, CigarString};
6use rust_htslib::bam::{self, Read};
7use std::cmp::Ordering;
8use std::collections::hash_map::{Entry::Occupied, Entry::Vacant};
9use std::collections::{BTreeMap, HashMap, HashSet, VecDeque};
10use std::fmt::Display;
11use std::process;
12use std::sync::mpsc::channel;
13use std::sync::{Arc, Mutex};
14
15#[cfg(test)]
16mod test;
17
/// Runtime options for a dedup/group run.
#[derive(Debug)]
pub struct Config {
    /// Max hamming distance between two UMIs for them to be connected.
    pub allowed_read_dist: u32,
    /// Count factor for directional edges: `freq_a >= factor * freq_b - 1`.
    pub allowed_count_factor: u32,
    /// How many hops to follow when merging connected UMI networks.
    pub allowed_network_depth: usize,
    /// BAM aux tag holding the UMI (ignored when `umi_in_read_id` is set).
    pub umi_tag: String,
    /// Path to the input BAM file.
    pub input_bam: String,
    /// Path to the output BAM file.
    pub output_bam: String,
    /// If true, the UMI is the last `_`-separated field of the read name.
    pub umi_in_read_id: bool,
    /// If true, all spliced reads collapse to the same splice marker.
    pub ignore_splice_pos: bool,
    /// If true, annotate groups (UG/BX tags) instead of deduplicating.
    pub group_only: bool,
    /// If true, input is paired-end; mates are rescued in a second pass.
    pub is_paired: bool,
}
31
/// A vertex in the per-position UMI graph: one UMI, the read(s) and count
/// observed for it, and directed edges to the nodes it dominates.
#[derive(Debug, Clone, PartialEq)]
pub struct Node {
    /// Bit-packed UMI sequence.
    umi: BaseBits,
    /// Retained read(s) plus occurrence count for this UMI at this position.
    freq: ReadFreq,
    /// Indices (into the graph `Vec`) of connected, smaller nodes.
    connections: Vec<usize>,
}
38
/// The positional signature of a read, used as the grouping key for dedup.
#[derive(Hash, PartialEq, Eq, Debug, Clone)]
pub struct Position {
    /// Soft-clip-corrected start position (clip-corrected end for reverse reads).
    pos: i32,
    /// Offset of the first splice from the scan start; `Some(0)` when splice
    /// positions are ignored, `None` when unspliced.
    is_spliced: Option<u32>,
    /// Read maps to the reverse strand.
    is_rev: bool,
    /// Target (chromosome) id from the BAM header.
    target: i32,
    /// Insert size; only `Some` in paired-end mode.
    tlen: Option<i32>,
}
47
48impl PartialOrd for Position {
49    fn partial_cmp(&self, other: &Position) -> Option<Ordering> {
50        Some(self.cmp(other))
51    }
52}
53
54impl Ord for Position {
55    fn cmp(&self, other: &Position) -> Ordering {
56        let comp = self.target.cmp(&other.target);
57        if comp != Ordering::Equal {
58            return comp;
59        }
60
61        let comp = self.pos.cmp(&other.pos);
62        if comp != Ordering::Equal {
63            return comp;
64        }
65
66        let comp = self.tlen.cmp(&other.tlen);
67        if comp != Ordering::Equal {
68            return comp;
69        }
70
71        let comp = self.is_spliced.cmp(&other.is_spliced);
72        if comp != Ordering::Equal {
73            return comp;
74        }
75
76        self.is_rev.cmp(&other.is_rev)
77    }
78}
79
80impl Position {
81    /// Takes a read and determins the position to use as a key in the returned group.
82    pub fn new(record: &bam::record::Record, ignore_splice_pos: bool, use_tlen: bool) -> Self {
83        let mut pos = record.pos();
84        let mut is_spliced: Option<u32>;
85        let tlen: Option<i32>;
86        let cigarview = record.cigar();
87        let cigar = &cigarview;
88
89        if record.is_reverse() {
90            pos = cigarview.end_pos();
91            // if the end of the read was soft clipped, add that amount back to its pos
92            if let Cigar::SoftClip(num) = cigar[cigar.len() - 1] {
93                pos = pos + num as i32;
94            }
95
96            is_spliced = Position::find_splice(&cigar, true);
97        } else {
98            if let Cigar::SoftClip(num) = cigar[0] {
99                pos = pos - num as i32;
100            }
101            is_spliced = Position::find_splice(&cigar, false);
102        }
103        if ignore_splice_pos && is_spliced.is_some() {
104            is_spliced = Some(0);
105        }
106
107        if use_tlen {
108            tlen = Some(record.insert_size());
109        } else {
110            tlen = None;
111        }
112        Self {
113            pos: pos,
114            is_rev: record.is_reverse(),
115            target: record.tid(),
116            is_spliced: is_spliced,
117            tlen: tlen,
118        }
119    }
120
121    /// Takes a cigar string and finds the first splice postion as an offset from the start.
122    /// Equivalent of `find_splice` in umi_tools
123    fn find_splice(cigar: &CigarString, is_reversed: bool) -> Option<u32> {
124        let mut range: Vec<usize> = (0..cigar.len()).collect();
125        let mut offset = 0;
126
127        if is_reversed {
128            range = (0..cigar.len()).rev().collect();
129        }
130        if let Cigar::SoftClip(num) = cigar[range[0]] {
131            offset = num;
132            range.remove(0);
133        }
134        for i in range {
135            match cigar[i] {
136                // Found splice
137                Cigar::RefSkip(_) | Cigar::SoftClip(_) => return Some(offset),
138                // Reference consumeing operations
139                Cigar::Match(num) | Cigar::Del(num) | Cigar::Equal(num) | Cigar::Diff(num) => {
140                    offset += num
141                }
142                // Non-reference consuming operations
143                Cigar::Ins(_) | Cigar::HardClip(_) | Cigar::Pad(_) => continue,
144            }
145        }
146        None
147    }
148}
149
/// Abstraction so ReadFreq can hold a single best read for its read signature
/// or hold all reads for its read signature (used for --group_only).
#[derive(Debug, Clone, PartialEq)]
pub enum ReadCollection {
    /// The single best read seen for a signature (dedup mode).
    SingleRead(bam::record::Record),
    /// Every read seen for a signature (--group_only mode).
    ManyReads(Vec<bam::record::Record>),
}
157
/// A Read or Reads and the number of times that read signature has been seen.
/// Read signature meaning Position + UMI.
#[derive(Debug, Clone, PartialEq)]
pub struct ReadFreq {
    /// The retained read(s) for this signature.
    read: ReadCollection,
    /// How many times this signature was observed.
    freq: u32,
}
165
/// A group of reads that have been deduplicated: a cluster of UMI nodes
/// collapsed into one logical molecule.
#[derive(Debug)]
pub struct Group<'a> {
    /// Members of the group (borrowed from the graph).
    nodes: Vec<&'a Node>,
    /// Consensus UMI — the UMI of the highest-frequency member.
    umi: &'a BaseBits,
    /// Index into `nodes` of the highest-frequency member.
    master_node: usize,
}
173
/// Classification of a record by `check_record`, used for stats counting and
/// for deciding whether a read enters the grouping tables.
#[derive(Debug)]
pub enum RecordEvent {
    /// Usable read.
    RecordMapped,
    /// Read itself is unmapped.
    RecordUnmapped,
    /// Paired-end mode, but the read is not flagged as paired.
    RecordUnpaired,
    /// Mate is unmapped (the read is still processed, just counted).
    RecordMateUnmapped,
    /// Read and mate map to different targets.
    RecordChimeric,
}
182
/// Counters summarizing a run; printed via `Display` when processing ends.
#[derive(Debug)]
pub struct Stats {
    /// Total records pulled from the input BAM.
    reads_in: u32,
    /// Records written to the output BAM.
    reads_out: u32,
    /// Records skipped because they were unmapped.
    reads_unmapped: u32,
    /// Records skipped in paired mode for not being flagged as paired.
    reads_unpaired: u32,
    /// Records whose mate was unmapped (still processed).
    mate_unmapped: u32,
    /// Records skipped because read and mate map to different targets.
    chimeric: u32,
}
192
193impl Stats {
194    pub fn new() -> Self {
195        Stats {
196            reads_in: 0,
197            reads_out: 0,
198            reads_unmapped: 0,
199            reads_unpaired: 0,
200            mate_unmapped: 0,
201            chimeric: 0,
202        }
203    }
204    pub fn update(&mut self, other: &Self) {
205        self.reads_in += other.reads_in;
206        self.reads_out += other.reads_out;
207        self.reads_unmapped += other.reads_unmapped;
208        self.reads_unpaired += other.reads_unpaired;
209        self.mate_unmapped += other.mate_unmapped;
210        self.chimeric += other.chimeric;
211    }
212}
213
214impl Display for Stats {
215    fn fmt(&self, fmt: &mut std::fmt::Formatter) -> std::result::Result<(), std::fmt::Error> {
216        write!(fmt, "Reads In: {}\nReads Out: {}\nReads Unmapped: {}\nReads Unpaired: {}\nMates Unmapped: {}\nReads Chimeric: {}", self.reads_in, self.reads_out, self.reads_unmapped, self.reads_unpaired, self.mate_unmapped, self.chimeric)
217    }
218}
219
/// Per-position table: UMI -> retained read(s) and frequency.
pub type UmiMap = HashMap<BaseBits, ReadFreq>;
/// All positions (ordered by `Position`'s `Ord`) mapped to their UMI tables.
pub type ReadMap = BTreeMap<Position, UmiMap>;
222
/// The main function to coordinate the deduplication process.
///
/// Pipeline: read `config.input_bam` -> bundle records per chromosome
/// (`Bundler`) -> in parallel (rayon), group each bundle by position + UMI
/// (`group_reads`) and collapse each position's UMIs (`dedup`) -> funnel the
/// surviving reads through a channel to this (single) writer. In paired mode
/// a second pass writes the mates of every surviving read.
///
/// Note: currently always returns `Ok(())`; I/O failures panic or call
/// `process::exit(1)` rather than surfacing as `Err`.
pub fn run_dedup(config: &Config) -> Result<(), &'static str> {
    let mut bam = bam::Reader::from_path(&config.input_bam).unwrap();
    let header = bam::Header::from_template(bam.header());
    let mut writer = bam::Writer::from_path(&config.output_bam, &header, bam::Format::BAM).unwrap();
    // qnames of surviving first-mates; used to rescue mates in paired mode.
    let mut read_store: HashSet<Vec<u8>> = HashSet::new();
    let (sender, reciever) = channel();
    // Shared so the parallel stage below can fold in per-bundle stats.
    let global_stats = Arc::new(Mutex::new(Stats::new()));

    let bundler = Bundler {
        records: bam.records(),
        last_chr: None,
        next_bundle: vec![],
    };

    bundler
        .par_bridge()
        .flat_map(|bundle| {
            let (x, stats) = group_reads(bundle, &config);
            let y: Vec<(Position, UmiMap)> = x.into_iter().collect();
            let mut g_stats = global_stats.lock().unwrap();
            g_stats.update(&stats);
            y
        })
        .flat_map(|(_, reads)| dedup(reads, config))
        .for_each_with(sender, |s, x| s.send(x).unwrap());

    let mut reads_out = 0;
    reciever.iter().for_each(|read| {
        reads_out += 1;
        writer.write(&read).unwrap_or_else(|err| {
            eprintln!("Problem writing: {}", err);
            process::exit(1);
        });
        if config.is_paired {
            read_store.insert(read.qname().to_vec());
        }
    });

    // Second pass (paired mode): write the mate of every surviving read.
    // NOTE(review): this calls `bam.records()` again on the reader the first
    // pass already consumed — confirm the iterator restarts from the top of
    // the file; if it resumes at EOF this pass writes nothing.
    if config.is_paired {
        bam.records()
            .map(|read| read.unwrap())
            .filter(|read| read.is_last_in_template() && read_store.contains(read.qname()))
            .for_each(|read| {
                reads_out += 1;
                writer.write(&read).unwrap_or_else(|err| {
                    eprintln!("Problem writing: {}", err);
                    process::exit(1);
                });
            });
    }

    let mut stats = global_stats.lock().unwrap();
    stats.reads_out += reads_out;
    println!("{}", stats);
    Ok(())
}
280
281pub fn run_group(config: &Config) -> Result<(), &'static str> {
282    let mut bam = bam::Reader::from_path(&config.input_bam).unwrap();
283    let header = bam::Header::from_template(bam.header());
284    let mut writer = bam::Writer::from_path(&config.output_bam, &header, bam::Format::BAM).unwrap();
285    let mut read_store: HashMap<Vec<u8>, (bam::record::Aux, Vec<u8>)> = HashMap::new();
286    let global_stats = Arc::new(Mutex::new(Stats::new()));
287    let (sender, reciever) = channel();
288
289    let bundler = Bundler {
290        records: bam.records(),
291        last_chr: None,
292        next_bundle: vec![],
293    };
294
295    bundler
296        .par_bridge()
297        .flat_map(|bundle| {
298            let (x, stats) = group_reads(bundle, &config);
299            let y: Vec<(Position, UmiMap)> = x.into_iter().collect();
300            let mut g_stats = global_stats.lock().unwrap();
301            g_stats.update(&stats);
302            y
303        })
304        .flat_map(|(_, reads)| label_groups(reads, config))
305        .for_each_with(sender, |s, x| s.send(x).unwrap());
306
307    let mut group_count: i64 = 0;
308    let mut reads_out = 0;
309    reciever.iter().for_each(|mut group| {
310        for read in group.iter_mut() {
311            reads_out += 1;
312            read.push_aux(b"UG", &bam::record::Aux::Integer(group_count));
313            writer.write(&read).unwrap_or_else(|err| {
314                eprintln!("Problem writing: {}", err);
315                process::exit(1);
316            });
317            if config.is_paired {
318                let umi = read.aux(b"BX").unwrap().string().to_vec();
319                read_store.insert(
320                    read.qname().to_vec(),
321                    (bam::record::Aux::Integer(group_count), umi),
322                );
323            }
324            group_count += 1;
325        }
326    });
327
328    if config.is_paired {
329        bam.records()
330            .map(|read| read.unwrap())
331            .filter(|read| read.is_last_in_template())
332            .for_each(|mut read| {
333                if let Some((ug, bx_val)) = read_store.get(read.qname()) {
334                    reads_out += 1;
335                    read.push_aux(b"UG", ug);
336                    read.push_aux(b"BX", &bam::record::Aux::String(&bx_val));
337                    writer.write(&read).unwrap_or_else(|err| {
338                        eprintln!("Problem writing: {}", err);
339                        process::exit(1);
340                    });
341                }
342            });
343    }
344    let mut stats = global_stats.lock().unwrap();
345    stats.reads_out = reads_out;
346    println!("{}", stats);
347    Ok(())
348}
349
/// Iterator adaptor that batches BAM records into one `Vec` per contiguous
/// run of identical target ids (i.e. per chromosome in a sorted BAM), so
/// each chromosome can be processed in parallel downstream.
struct Bundler<I>
where
    I: Iterator<Item = Result<rust_htslib::bam::record::Record, Error>>,
{
    /// Underlying record stream.
    records: I,
    /// Target id of the chromosome currently being bundled.
    last_chr: Option<i32>,
    /// First record of the next chromosome, carried over between `next` calls.
    next_bundle: Vec<rust_htslib::bam::record::Record>,
}
358
359impl<I> Iterator for Bundler<I>
360where
361    I: Iterator<Item = Result<rust_htslib::bam::record::Record, Error>>,
362{
363    type Item = Vec<rust_htslib::bam::record::Record>;
364
365    fn next(&mut self) -> Option<Self::Item> {
366        let mut bundle = vec![];
367        std::mem::swap(&mut self.next_bundle, &mut bundle);
368        while let Some(r) = self.records.next() {
369            let record = r.unwrap();
370            if let Some(tid) = self.last_chr {
371                if tid == record.tid() {
372                    bundle.push(record);
373                } else {
374                    self.last_chr = Some(record.tid());
375                    self.next_bundle.push(record);
376                    break;
377                }
378            } else {
379                self.last_chr = Some(record.tid());
380                bundle.push(record);
381            }
382        }
383        if bundle.len() > 0 {
384            Some(bundle)
385        } else {
386            None
387        }
388    }
389}
390
391fn get_tag<'a>(record: &'a bam::record::Record, config: &Config) -> &'a [u8] {
392    if config.umi_in_read_id {
393        match record.qname().split(|&c| c == b'_').last() {
394            Some(tag) => tag,
395            None => panic!("No tag in read id"),
396        }
397    } else {
398        match record.aux(config.umi_tag.as_bytes()) {
399            Some(tag) => tag.string(),
400            None => panic!("No tag on read"),
401        }
402    }
403}
404
405pub fn check_record(record: &bam::record::Record, paired_end: bool) -> RecordEvent {
406    if paired_end {
407        if record.is_unmapped() {
408            return RecordEvent::RecordUnmapped;
409        }
410        if !record.is_paired() {
411            return RecordEvent::RecordUnpaired;
412        }
413        if record.tid() != record.mtid() {
414            return RecordEvent::RecordChimeric;
415        }
416        if record.is_mate_unmapped() {
417            return RecordEvent::RecordMateUnmapped;
418        }
419    } else {
420        if record.is_unmapped() {
421            return RecordEvent::RecordUnmapped;
422        }
423    }
424    RecordEvent::RecordMapped
425}
426
427/// Group reads together based on their positions.
428pub fn group_reads(
429    records: Vec<rust_htslib::bam::record::Record>,
430    config: &Config,
431) -> (ReadMap, Stats) {
432    let mut read_map: ReadMap = BTreeMap::new();
433    let mut stats = Stats::new();
434
435    for record in records.into_iter() {
436        stats.reads_in += 1;
437
438        if config.is_paired && record.is_last_in_template() {
439            continue;
440        }
441
442        match check_record(&record, config.is_paired) {
443            RecordEvent::RecordMapped => {}
444            RecordEvent::RecordUnmapped => {
445                stats.reads_unmapped += 1;
446                continue;
447            }
448            RecordEvent::RecordUnpaired => {
449                stats.reads_unpaired += 1;
450                continue;
451            }
452            RecordEvent::RecordMateUnmapped => {
453                stats.mate_unmapped += 1;
454            }
455            RecordEvent::RecordChimeric => {
456                stats.chimeric += 1;
457                continue;
458            }
459        }
460
461        let tag = get_tag(&record, config);
462        let position = Position::new(&record, config.ignore_splice_pos, config.is_paired);
463
464        // Add to my reverse lookup
465        let bb = BaseBits::new(tag).unwrap();
466        let position_map = read_map.entry(position).or_insert(HashMap::new());
467        match position_map.entry(bb) {
468            Occupied(entry) => {
469                let rf = entry.into_mut();
470                match &rf.read {
471                    ReadCollection::SingleRead(read) => {
472                        if !read_a_ge_b(&read, &record) {
473                            rf.read = ReadCollection::SingleRead(record);
474                        }
475                    }
476                    ReadCollection::ManyReads(reads) => {
477                        // This is stupid but it's only for the group version so....
478                        let mut reads = reads.clone();
479                        reads.push(record);
480                        rf.read = ReadCollection::ManyReads(reads);
481                    }
482                }
483                rf.freq += 1;
484            }
485            Vacant(entry) => {
486                if !config.group_only {
487                    entry.insert(ReadFreq {
488                        read: ReadCollection::SingleRead(record),
489                        freq: 1,
490                    });
491                } else {
492                    entry.insert(ReadFreq {
493                        read: ReadCollection::ManyReads(vec![record]),
494                        freq: 1,
495                    });
496                }
497            }
498        };
499    }
500    (read_map, stats)
501}
502
503/// Create a graph from the UmiMap
504/// TODO: Inline?
505pub fn build_graph(reads: UmiMap) -> Vec<Node> {
506    reads
507        .into_iter()
508        .map(|(umi, freqs)| Node {
509            umi,
510            connections: vec![],
511            freq: freqs,
512        })
513        .collect()
514}
515
516/// Create the connections between the umis via an all vs all comparison.
517/// A Connection will only be formed from a larger node to a smaller node.
518/// Larger being defined as node_a >= 2x node_b - 1, the provides the directionality.
519/// TODO: Keep a seen list here instead of later? Some connections will be redundant.
520pub fn connect_graph(mut graph: Vec<Node>, dist: u32, counts_factor: u32) -> Vec<Node> {
521    for i in 0..graph.len() {
522        for j in 0..graph.len() {
523            if i == j {
524                continue;
525            }
526            if hamming_dist_none(&graph[i].umi, &graph[j].umi) <= dist
527                && graph[i].freq.freq >= (counts_factor * graph[j].freq.freq) - 1
528            {
529                graph[i].connections.push(j);
530            }
531        }
532    }
533    graph
534}
535
536// TODO: Use proper bk tree for faster lookups
537fn determine_umi<'a>(graph: &'a Vec<Node>, allowed_network_depth: usize) -> Vec<Group> {
538    // Group the umis by distance
539    let mut groups = vec![];
540    let mut seen: Vec<usize> = Vec::new();
541    // Create a vec of nodes indicies going from highest counts to lowest
542    let mut graph_indicies: Vec<usize> = (0..graph.len()).collect();
543    &graph_indicies.sort_by(|&a, &b| graph[b].freq.freq.cmp(&graph[a].freq.freq));
544
545    for &x in graph_indicies.iter() {
546        if seen.contains(&x) {
547            continue;
548        }
549        seen.push(x);
550        let node = &graph[x];
551        let mut group_holder: Vec<Vec<&Node>> = Vec::new();
552
553        // Get all the nodes within 1 hamming dist
554        let mut group: Vec<&Node> = vec![];
555        for &x in node.connections.iter() {
556            if !seen.contains(&x) {
557                seen.push(x);
558                group.push(&graph[x]);
559            }
560        }
561
562        // Get all the nodes within k hamming dist
563        // If two nodes lie equidistant away from a smaller node, it shouldn't matter which node
564        // gets the discrepent reads, there would be no real biological way to tell...
565        let mut queue: VecDeque<Vec<usize>> = VecDeque::new();
566        queue.push_back(group.iter().flat_map(|n| n.connections.clone()).collect());
567        for _ in 0..(allowed_network_depth - 1) {
568            if let Some(connections) = queue.pop_front() {
569                let mut new_group: Vec<&Node> = vec![];
570                queue.push_back(
571                    connections
572                        .iter()
573                        .flat_map(|&x| {
574                            if !seen.contains(&x) {
575                                seen.push(x);
576                                new_group.push(&graph[x]);
577                            }
578                            graph[x].connections.clone()
579                        })
580                        .collect(),
581                );
582                group_holder.push(new_group);
583            }
584        }
585        // Must add after, otherwise it will be searched again
586        group.push(node);
587
588        // Merge all groups and choose concensus umi
589        for g in group_holder {
590            group.extend(g.iter());
591        }
592
593        let master_node = group.iter().enumerate().fold(0, |max, (i, x)| {
594            if x.freq.freq > group[max].freq.freq {
595                i
596            } else {
597                max
598            }
599        });
600        let umi = &group[master_node].umi;
601        let group = Group {
602            nodes: group,
603            umi,
604            master_node,
605        };
606
607        groups.push(group);
608    }
609    groups
610}
611
612/// Deduplicate a group of reads that all positioned at the same position
613fn dedup(reads: UmiMap, config: &Config) -> Vec<bam::record::Record> {
614    let graph = build_graph(reads);
615    let graph = connect_graph(graph, config.allowed_read_dist, config.allowed_count_factor);
616    let groups = determine_umi(&graph, config.allowed_network_depth);
617    let mut final_reads = vec![];
618
619    for group in groups.into_iter() {
620        let node = group.nodes[group.master_node];
621        if let ReadCollection::SingleRead(read) = &node.freq.read {
622            let read = read.clone();
623            final_reads.push(read);
624        } else {
625            unreachable!();
626        }
627    }
628    final_reads
629}
630
631/// TODO: Don't clone the read :(
632fn label_groups(reads: UmiMap, config: &Config) -> Vec<Vec<bam::record::Record>> {
633    let graph = build_graph(reads);
634    let graph = connect_graph(graph, config.allowed_read_dist, config.allowed_count_factor);
635    let groups = determine_umi(&graph, config.allowed_network_depth);
636    let mut records = vec![];
637
638    for group in groups.into_iter() {
639        let mut group_list = vec![];
640        let master_umi = group.nodes[group.master_node].umi;
641        for node in group.nodes {
642            if let ReadCollection::ManyReads(reads) = &node.freq.read {
643                for read in reads.into_iter() {
644                    let mut read = read.clone();
645                    read.push_aux(b"BX", &bam::record::Aux::String(&master_umi.decode()));
646                    group_list.push(read);
647                }
648            } else {
649                unreachable!();
650            }
651        }
652        records.push(group_list);
653    }
654    records
655}
656
657/////////////////////// Helpers
658/// Decide wich read is better.
659/// For now this uses the simplistic approach of comparing mapq values.
660/// Returns true if alpha is better than beta, false otherwise
661fn read_a_ge_b(alpha: &bam::record::Record, beta: &bam::record::Record) -> bool {
662    // Take the read with the hightest mapq
663    match alpha.mapq().cmp(&beta.mapq()) {
664        Ordering::Less => false,
665        Ordering::Greater => true,
666        // Take the read with the lowest number of muli mappings
667        Ordering::Equal => match alpha
668            .aux(b"NH")
669            .unwrap_or(Aux::Integer(0))
670            .integer()
671            .cmp(&beta.aux(b"NH").unwrap_or(Aux::Integer(0)).integer())
672        {
673            Ordering::Less => true,
674            Ordering::Greater => false,
675            // Take the read with the smallest edit distance
676            Ordering::Equal => match alpha
677                .aux(b"NM")
678                .unwrap_or(Aux::Integer(0))
679                .integer()
680                .cmp(&beta.aux(b"NM").unwrap_or(Aux::Integer(0)).integer())
681            {
682                Ordering::Less => true,
683                Ordering::Greater => false,
684                // Take the read with the longer sequence
685                Ordering::Equal => match alpha.seq().len().cmp(&beta.seq().len()) {
686                    Ordering::Less => false,
687                    Ordering::Greater => true,
688                    // The incumbant wins if we get this far
689                    Ordering::Equal => true,
690                },
691            },
692        },
693    }
694}