fibertools_rs/utils/
basemods.rs

1use crate::utils::bamranges::*;
2use crate::utils::bio_io::*;
3use bio::alphabets::dna::revcomp;
4use lazy_static::lazy_static;
5use regex::Regex;
6use rust_htslib::{
7    bam,
8    bam::record::{Aux, AuxArray},
9};
10use std::collections::HashMap;
11
12use std::convert::TryFrom;
13
14#[derive(Eq, PartialEq, Debug, PartialOrd, Ord, Clone)]
15pub struct BaseMod {
16    pub modified_base: u8,
17    pub strand: char,
18    pub modification_type: char,
19    pub ranges: Ranges,
20    pub record_is_reverse: bool,
21}
22
23impl BaseMod {
24    pub fn new(
25        record: &bam::Record,
26        modified_base: u8,
27        strand: char,
28        modification_type: char,
29        modified_bases_forward: Vec<i64>,
30        modified_probabilities_forward: Vec<u8>,
31    ) -> Self {
32        let tmp = modified_bases_forward.clone();
33        let mut ranges = Ranges::new(record, modified_bases_forward, None, None);
34        ranges.set_qual(modified_probabilities_forward);
35        let record_is_reverse = record.is_reverse();
36        //assert_eq!(tmp, ranges.get_starts(), "starts not equal");
37        assert_eq!(tmp, ranges.get_forward_starts(), "forward starts not equal");
38        Self {
39            modified_base,
40            strand,
41            modification_type,
42            ranges,
43            record_is_reverse,
44        }
45    }
46
47    pub fn is_m6a(&self) -> bool {
48        self.modification_type == 'a'
49    }
50
51    pub fn is_cpg(&self) -> bool {
52        self.modification_type == 'm'
53    }
54
55    pub fn filter_at_read_ends(&mut self, n_strip: i64) {
56        if n_strip <= 0 {
57            return;
58        }
59        self.ranges.filter_starts_at_read_ends(n_strip);
60    }
61}
62
63#[derive(Eq, PartialEq, Debug, Clone)]
64pub struct BaseMods {
65    pub base_mods: Vec<BaseMod>,
66}
67
68impl BaseMods {
69    pub fn new(record: &bam::Record, min_ml_score: u8) -> BaseMods {
70        // my basemod parser is ~25% faster than rust_htslib's
71        BaseMods::my_mm_ml_parser(record, min_ml_score)
72    }
73
74    pub fn my_mm_ml_parser(record: &bam::Record, min_ml_score: u8) -> BaseMods {
75        // regex for matching the MM tag
76        lazy_static! {
77            // MM:Z:([ACGTUN][-+]([a-z]+|[0-9]+)[.?]?(,[0-9]+)*;)*
78            static ref MM_RE: Regex =
79                Regex::new(r"((([ACGTUN])([-+])([a-z]+|[0-9]+))[.?]?((,[0-9]+)*;)*)").unwrap();
80        }
81        // Array to store all the different modifications within the MM tag
82        let mut rtn = vec![];
83
84        let ml_tag = get_u8_tag(record, b"ML");
85
86        let mut num_mods_seen = 0;
87
88        // if there is an MM tag iterate over all the regex matches
89        if let Ok(Aux::String(mm_text)) = record.aux(b"MM") {
90            for cap in MM_RE.captures_iter(mm_text) {
91                let mod_base = cap.get(3).map(|m| m.as_str().as_bytes()[0]).unwrap();
92                let mod_strand = cap.get(4).map_or("", |m| m.as_str());
93                let modification_type = cap.get(5).map_or("", |m| m.as_str());
94                let mod_dists_str = cap.get(6).map_or("", |m| m.as_str());
95                // parse the string containing distances between modifications into a vector of i64
96                let mod_dists: Vec<i64> = mod_dists_str
97                    .trim_end_matches(';')
98                    .split(',')
99                    .map(|s| s.trim())
100                    .filter(|s| !s.is_empty())
101                    .map(|s| s.parse().unwrap())
102                    .collect();
103
104                // get forward sequence bases from the bam record
105                let forward_bases = if record.is_reverse() {
106                    revcomp(record.seq().as_bytes())
107                } else {
108                    record.seq().as_bytes()
109                };
110                log::trace!(
111                    "mod_base: {}, mod_strand: {}, modification_type: {}, mod_dists: {:?}",
112                    mod_base as char,
113                    mod_strand,
114                    modification_type,
115                    mod_dists
116                );
117                // find real positions in the forward sequence
118                let mut cur_mod_idx = 0;
119                let mut cur_seq_idx = 0;
120                let mut dist_from_last_mod_base = 0;
121                let mut unfiltered_modified_positions: Vec<i64> = vec![0; mod_dists.len()];
122                while cur_seq_idx < forward_bases.len() && cur_mod_idx < mod_dists.len() {
123                    let cur_base = forward_bases[cur_seq_idx];
124                    if cur_base == mod_base && dist_from_last_mod_base == mod_dists[cur_mod_idx] {
125                        unfiltered_modified_positions[cur_mod_idx] =
126                            i64::try_from(cur_seq_idx).unwrap();
127                        dist_from_last_mod_base = 0;
128                        cur_mod_idx += 1;
129                    } else if cur_base == mod_base {
130                        dist_from_last_mod_base += 1
131                    }
132                    cur_seq_idx += 1;
133                }
134                // assert that we extract the same number of modifications as we have distances
135                assert_eq!(
136                    cur_mod_idx,
137                    mod_dists.len(),
138                    "{:?} {}",
139                    String::from_utf8_lossy(record.qname()),
140                    record.is_reverse()
141                );
142
143                // check for the probability of modification.
144                let num_mods_cur_end = num_mods_seen + unfiltered_modified_positions.len();
145                let unfiltered_modified_probabilities = if num_mods_cur_end > ml_tag.len() {
146                    let needed_num_of_zeros = num_mods_cur_end - ml_tag.len();
147                    let mut to_add = vec![0; needed_num_of_zeros];
148                    let mut has = ml_tag[num_mods_seen..ml_tag.len()].to_vec();
149                    has.append(&mut to_add);
150                    log::warn!(
151                        "ML tag is too short for the number of modifications found in the MM tag. Assuming an ML value of 0 after the first {num_mods_cur_end} modifications."
152                    );
153                    has
154                } else {
155                    ml_tag[num_mods_seen..num_mods_cur_end].to_vec()
156                };
157                num_mods_seen = num_mods_cur_end;
158
159                // must be true for filtering, and at this point
160                assert_eq!(
161                    unfiltered_modified_positions.len(),
162                    unfiltered_modified_probabilities.len()
163                );
164
165                // Filter mods based on probabilities
166                let (modified_probabilities, modified_positions): (Vec<u8>, Vec<i64>) =
167                    unfiltered_modified_probabilities
168                        .iter()
169                        .zip(unfiltered_modified_positions.iter())
170                        .filter(|(&ml, &_mm)| ml >= min_ml_score)
171                        .unzip();
172
173                // don't add empty basemods
174                if modified_positions.is_empty() {
175                    continue;
176                }
177                // add to a struct
178                let mods = BaseMod::new(
179                    record,
180                    mod_base,
181                    mod_strand.chars().next().unwrap(),
182                    modification_type.chars().next().unwrap(),
183                    modified_positions,
184                    modified_probabilities,
185                );
186                rtn.push(mods);
187            }
188        } else {
189            log::trace!("No MM tag found");
190        }
191
192        if ml_tag.len() != num_mods_seen {
193            log::warn!(
194                "ML tag ({}) different number than MM tag ({}).",
195                ml_tag.len(),
196                num_mods_seen
197            );
198        }
199        // needed so I can compare methods
200        rtn.sort();
201        BaseMods { base_mods: rtn }
202    }
203
204    pub fn hashmap_to_basemods(
205        map: HashMap<(i32, i32, i32), Vec<(i64, u8)>>,
206        record: &bam::Record,
207    ) -> BaseMods {
208        let mut rtn = vec![];
209        for (mod_info, mods) in map {
210            let mod_base = mod_info.0 as u8;
211            let mod_type = mod_info.1 as u8 as char;
212            let mod_strand = if mod_info.2 == 0 { '+' } else { '-' };
213            let (mut positions, mut qualities): (Vec<i64>, Vec<u8>) = mods.into_iter().unzip();
214            if record.is_reverse() {
215                let length = record.seq_len() as i64;
216                positions = positions
217                    .into_iter()
218                    .rev()
219                    .map(|p| length - p - 1)
220                    .collect();
221                qualities.reverse();
222            }
223            let mods = BaseMod::new(record, mod_base, mod_strand, mod_type, positions, qualities);
224            rtn.push(mods);
225        }
226        // needed so I can compare methods
227        rtn.sort();
228        BaseMods { base_mods: rtn }
229    }
230
231    /// remove m6a base mods from the struct
232    pub fn drop_m6a(&mut self) {
233        self.base_mods.retain(|bm| !bm.is_m6a());
234    }
235
236    /// remove cpg/5mc base mods from the struct
237    pub fn drop_cpg(&mut self) {
238        self.base_mods.retain(|bm| !bm.is_cpg());
239    }
240
241    /// drop the forward stand of basemod calls
242    pub fn drop_forward(&mut self) {
243        self.base_mods.retain(|bm| bm.strand == '-');
244    }
245
246    /// drop the reverse strand of basemod calls   
247    pub fn drop_reverse(&mut self) {
248        self.base_mods.retain(|bm| bm.strand == '+');
249    }
250
251    /// drop m6A modifications with a qual less than the min_ml_score
252    pub fn filter_m6a(&mut self, min_ml_score: u8) {
253        self.base_mods
254            .iter_mut()
255            .filter(|bm| bm.is_m6a())
256            .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
257    }
258
259    /// drop 5mC modifications with a qual less than the min_ml_score
260    pub fn filter_5mc(&mut self, min_ml_score: u8) {
261        self.base_mods
262            .iter_mut()
263            .filter(|bm| bm.is_cpg())
264            .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
265    }
266
267    /// filter the basemods at the read ends
268    pub fn filter_at_read_ends(&mut self, n_strip: i64) {
269        if n_strip <= 0 {
270            return;
271        }
272        self.base_mods
273            .iter_mut()
274            .for_each(|bm| bm.filter_at_read_ends(n_strip));
275    }
276
277    /// combine the forward and reverse m6a data
278    pub fn m6a(&self) -> Ranges {
279        let ranges = self
280            .base_mods
281            .iter()
282            .filter(|x| x.is_m6a())
283            .map(|x| &x.ranges)
284            .collect();
285        Ranges::merge_ranges(ranges)
286    }
287
288    /// combine the forward and reverse cpd/5mc data
289    pub fn cpg(&self) -> Ranges {
290        let ranges = self
291            .base_mods
292            .iter()
293            .filter(|x| x.is_cpg())
294            .map(|x| &x.ranges)
295            .collect();
296        Ranges::merge_ranges(ranges)
297    }
298
299    /// Example MM tag: MM:Z:C+m,11,6,10;A+a,0,0,0;
300    /// Example ML tag: ML:B:C,157,30,2,164,118,255
301    pub fn add_mm_and_ml_tags(&self, record: &mut bam::Record) {
302        // init the mm and ml tag to be populated
303        let mut ml_tag: Vec<u8> = vec![];
304        let mut mm_tag = "".to_string();
305        // need the original sequence for distances between bases.
306        let mut seq = record.seq().as_bytes();
307        if record.is_reverse() {
308            seq = revcomp(seq);
309        }
310        // add to the ml and mm tag.
311        for basemod in self.base_mods.iter() {
312            // adding quality values (ML)
313            ml_tag.extend(basemod.ranges.get_forward_quals());
314            // get MM tag values
315            let mut cur_mm = vec![];
316            let positions = basemod.ranges.get_forward_starts();
317            let mut last_pos = 0;
318            for pos in positions {
319                let u_pos = pos as usize;
320                let mut in_between = 0;
321                if last_pos < u_pos {
322                    for base in seq[last_pos..u_pos].iter() {
323                        if *base == basemod.modified_base {
324                            in_between += 1;
325                        }
326                    }
327                }
328                last_pos = u_pos + 1;
329                cur_mm.push(in_between);
330            }
331            // Add to the MM string
332            mm_tag.push(basemod.modified_base as char);
333            mm_tag.push(basemod.strand);
334            mm_tag.push(basemod.modification_type);
335            for diff in cur_mm {
336                mm_tag.push_str(&format!(",{}", diff));
337            }
338            mm_tag.push(';')
339            // next basemod
340        }
341        log::trace!(
342            "{}\n{}\n{}\n",
343            record.is_reverse(),
344            mm_tag,
345            String::from_utf8_lossy(&seq)
346        );
347        // clear out the old base mods
348        record.remove_aux(b"MM").unwrap_or(());
349        record.remove_aux(b"ML").unwrap_or(());
350        // Add MM
351        let aux_integer_field = Aux::String(&mm_tag);
352        record.push_aux(b"MM", aux_integer_field).unwrap();
353        // Add ML
354        let aux_array: AuxArray<u8> = (&ml_tag).into();
355        let aux_array_field = Aux::ArrayU8(aux_array);
356        record.push_aux(b"ML", aux_array_field).unwrap();
357    }
358}