Skip to main content

fibertools_rs/utils/
basemods.rs

1use crate::utils::bamannotations::*;
2use crate::utils::bio_io::*;
3use bio::alphabets::dna::revcomp;
4use lazy_static::lazy_static;
5use regex::Regex;
6use rust_htslib::{
7    bam,
8    bam::record::{Aux, AuxArray},
9};
10use std::collections::HashMap;
11
12use std::convert::TryFrom;
13
14#[derive(Eq, PartialEq, Debug, PartialOrd, Ord, Clone)]
15pub struct BaseMod {
16    pub modified_base: u8,
17    pub strand: char,
18    pub modification_type: char,
19    pub ranges: Ranges,
20    pub record_is_reverse: bool,
21}
22
23impl BaseMod {
24    pub fn new(
25        record: &bam::Record,
26        modified_base: u8,
27        strand: char,
28        modification_type: char,
29        modified_bases_forward: Vec<i64>,
30        modified_probabilities_forward: Vec<u8>,
31    ) -> Self {
32        let tmp = modified_bases_forward.clone();
33        let mut ranges = Ranges::new(record, modified_bases_forward, None, None);
34        ranges.set_qual(modified_probabilities_forward);
35        let record_is_reverse = record.is_reverse();
36        assert_eq!(tmp, ranges.forward_starts(), "forward starts not equal");
37        Self {
38            modified_base,
39            strand,
40            modification_type,
41            ranges,
42            record_is_reverse,
43        }
44    }
45
46    pub fn is_m6a(&self) -> bool {
47        self.modification_type == 'a'
48    }
49
50    pub fn is_cpg(&self) -> bool {
51        self.modification_type == 'm'
52    }
53
54    pub fn filter_at_read_ends(&mut self, n_strip: i64) {
55        if n_strip <= 0 {
56            return;
57        }
58        self.ranges.filter_starts_at_read_ends(n_strip);
59    }
60}
61
62#[derive(Eq, PartialEq, Debug, Clone)]
63pub struct BaseMods {
64    pub base_mods: Vec<BaseMod>,
65}
66
67impl BaseMods {
68    pub fn new(record: &bam::Record, min_ml_score: u8) -> BaseMods {
69        // my basemod parser is ~25% faster than rust_htslib's
70        BaseMods::my_mm_ml_parser(record, min_ml_score)
71    }
72
73    pub fn my_mm_ml_parser(record: &bam::Record, min_ml_score: u8) -> BaseMods {
74        // regex for matching the MM tag
75        lazy_static! {
76            // MM:Z:([ACGTUN][-+]([A-Za-z]+|[0-9]+)[.?]?(,[0-9]+)*;)*
77            static ref MM_RE: Regex =
78                Regex::new(r"((([ACGTUN])([-+])([A-Za-z]+|[0-9]+))[.?]?((,[0-9]+)*;)*)").unwrap();
79        }
80        // Array to store all the different modifications within the MM tag
81        let mut rtn = vec![];
82
83        let ml_tag = get_u8_tag(record, b"ML");
84
85        let mut num_mods_seen = 0;
86
87        // if there is an MM tag iterate over all the regex matches
88        if let Ok(Aux::String(mm_text)) = record.aux(b"MM") {
89            for cap in MM_RE.captures_iter(mm_text) {
90                let mod_base = cap.get(3).map(|m| m.as_str().as_bytes()[0]).unwrap();
91                let mod_strand = cap.get(4).map_or("", |m| m.as_str());
92                let modification_type = cap.get(5).map_or("", |m| m.as_str());
93                let mod_dists_str = cap.get(6).map_or("", |m| m.as_str());
94                // parse the string containing distances between modifications into a vector of i64
95                let mod_dists: Vec<i64> = mod_dists_str
96                    .trim_end_matches(';')
97                    .split(',')
98                    .map(|s| s.trim())
99                    .filter(|s| !s.is_empty())
100                    .map(|s| s.parse().unwrap())
101                    .collect();
102
103                // get forward sequence bases from the bam record
104                let forward_bases = if record.is_reverse() {
105                    convert_seq_uppercase(revcomp(record.seq().as_bytes()))
106                } else {
107                    convert_seq_uppercase(record.seq().as_bytes())
108                };
109                log::trace!(
110                    "mod_base: {}, mod_strand: {}, modification_type: {}, mod_dists: {:?}",
111                    mod_base as char,
112                    mod_strand,
113                    modification_type,
114                    mod_dists
115                );
116                // find real positions in the forward sequence
117                let mut cur_mod_idx = 0;
118                let mut cur_seq_idx = 0;
119                let mut dist_from_last_mod_base = 0;
120                let mut unfiltered_modified_positions: Vec<i64> = vec![0; mod_dists.len()];
121                while cur_seq_idx < forward_bases.len() && cur_mod_idx < mod_dists.len() {
122                    let cur_base = forward_bases[cur_seq_idx];
123                    if (cur_base == mod_base || mod_base == b'N')
124                        && dist_from_last_mod_base == mod_dists[cur_mod_idx]
125                    {
126                        unfiltered_modified_positions[cur_mod_idx] =
127                            i64::try_from(cur_seq_idx).unwrap();
128                        dist_from_last_mod_base = 0;
129                        cur_mod_idx += 1;
130                    } else if cur_base == mod_base {
131                        dist_from_last_mod_base += 1
132                    }
133                    cur_seq_idx += 1;
134                }
135                // assert that we extract the same number of modifications as we have distances
136                assert_eq!(
137                    cur_mod_idx,
138                    mod_dists.len(),
139                    "{:?} {}",
140                    String::from_utf8_lossy(record.qname()),
141                    record.is_reverse()
142                );
143
144                // check for the probability of modification.
145                let num_mods_cur_end = num_mods_seen + unfiltered_modified_positions.len();
146                let unfiltered_modified_probabilities = if num_mods_cur_end > ml_tag.len() {
147                    let needed_num_of_zeros = num_mods_cur_end - ml_tag.len();
148                    let mut to_add = vec![0; needed_num_of_zeros];
149                    let mut has = ml_tag[num_mods_seen..ml_tag.len()].to_vec();
150                    has.append(&mut to_add);
151                    log::warn!(
152                        "ML tag is too short for the number of modifications found in the MM tag. Assuming an ML value of 0 after the first {num_mods_cur_end} modifications."
153                    );
154                    has
155                } else {
156                    ml_tag[num_mods_seen..num_mods_cur_end].to_vec()
157                };
158                num_mods_seen = num_mods_cur_end;
159
160                // must be true for filtering, and at this point
161                assert_eq!(
162                    unfiltered_modified_positions.len(),
163                    unfiltered_modified_probabilities.len()
164                );
165
166                // Filter mods based on probabilities
167                let (modified_probabilities, modified_positions): (Vec<u8>, Vec<i64>) =
168                    unfiltered_modified_probabilities
169                        .iter()
170                        .zip(unfiltered_modified_positions.iter())
171                        .filter(|(&ml, &_mm)| ml >= min_ml_score)
172                        .unzip();
173
174                // don't add empty basemods
175                if modified_positions.is_empty() {
176                    continue;
177                }
178                // add to a struct
179                let mods = BaseMod::new(
180                    record,
181                    mod_base,
182                    mod_strand.chars().next().unwrap(),
183                    modification_type.chars().next().unwrap(),
184                    modified_positions,
185                    modified_probabilities,
186                );
187                rtn.push(mods);
188            }
189        } else {
190            log::trace!("No MM tag found");
191        }
192
193        if ml_tag.len() != num_mods_seen {
194            log::warn!(
195                "ML tag ({}) different number than MM tag ({}).",
196                ml_tag.len(),
197                num_mods_seen
198            );
199        }
200        // needed so I can compare methods
201        rtn.sort();
202        BaseMods { base_mods: rtn }
203    }
204
205    pub fn hashmap_to_basemods(
206        map: HashMap<(i32, i32, i32), Vec<(i64, u8)>>,
207        record: &bam::Record,
208    ) -> BaseMods {
209        let mut rtn = vec![];
210        for (mod_info, mods) in map {
211            let mod_base = mod_info.0 as u8;
212            let mod_type = mod_info.1 as u8 as char;
213            let mod_strand = if mod_info.2 == 0 { '+' } else { '-' };
214            let (mut positions, mut qualities): (Vec<i64>, Vec<u8>) = mods.into_iter().unzip();
215            if record.is_reverse() {
216                let length = record.seq_len() as i64;
217                positions = positions
218                    .into_iter()
219                    .rev()
220                    .map(|p| length - p - 1)
221                    .collect();
222                qualities.reverse();
223            }
224            let mods = BaseMod::new(record, mod_base, mod_strand, mod_type, positions, qualities);
225            rtn.push(mods);
226        }
227        // needed so I can compare methods
228        rtn.sort();
229        BaseMods { base_mods: rtn }
230    }
231
232    /// remove m6a base mods from the struct
233    pub fn drop_m6a(&mut self) {
234        self.base_mods.retain(|bm| !bm.is_m6a());
235    }
236
237    /// remove cpg/5mc base mods from the struct
238    pub fn drop_cpg(&mut self) {
239        self.base_mods.retain(|bm| !bm.is_cpg());
240    }
241
242    /// drop the forward stand of basemod calls
243    pub fn drop_forward(&mut self) {
244        self.base_mods.retain(|bm| bm.strand == '-');
245    }
246
247    /// drop the reverse strand of basemod calls   
248    pub fn drop_reverse(&mut self) {
249        self.base_mods.retain(|bm| bm.strand == '+');
250    }
251
252    /// drop m6A modifications with a qual less than the min_ml_score
253    pub fn filter_m6a(&mut self, min_ml_score: u8) {
254        self.base_mods
255            .iter_mut()
256            .filter(|bm| bm.is_m6a())
257            .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
258    }
259
260    /// drop 5mC modifications with a qual less than the min_ml_score
261    pub fn filter_5mc(&mut self, min_ml_score: u8) {
262        self.base_mods
263            .iter_mut()
264            .filter(|bm| bm.is_cpg())
265            .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
266    }
267
268    /// filter the basemods at the read ends
269    pub fn filter_at_read_ends(&mut self, n_strip: i64) {
270        if n_strip <= 0 {
271            return;
272        }
273        self.base_mods
274            .iter_mut()
275            .for_each(|bm| bm.filter_at_read_ends(n_strip));
276    }
277
278    /// combine the forward and reverse m6a data
279    pub fn m6a(&self) -> Ranges {
280        let ranges = self
281            .base_mods
282            .iter()
283            .filter(|x| x.is_m6a())
284            .map(|x| &x.ranges)
285            .collect();
286        Ranges::merge_ranges(ranges)
287    }
288
289    /// combine the forward and reverse cpd/5mc data
290    pub fn cpg(&self) -> Ranges {
291        let ranges = self
292            .base_mods
293            .iter()
294            .filter(|x| x.is_cpg())
295            .map(|x| &x.ranges)
296            .collect();
297        Ranges::merge_ranges(ranges)
298    }
299
300    /// Example MM tag: MM:Z:C+m,11,6,10;A+a,0,0,0;
301    /// Example ML tag: ML:B:C,157,30,2,164,118,255
302    pub fn add_mm_and_ml_tags(&self, record: &mut bam::Record) {
303        // init the mm and ml tag to be populated
304        let mut ml_tag: Vec<u8> = vec![];
305        let mut mm_tag = "".to_string();
306        // need the original sequence for distances between bases.
307        let mut seq = record.seq().as_bytes();
308        if record.is_reverse() {
309            seq = revcomp(seq);
310        }
311        // add to the ml and mm tag.
312        for basemod in self.base_mods.iter() {
313            // adding quality values (ML)
314            ml_tag.extend(basemod.ranges.get_forward_quals());
315            // get MM tag values
316            let mut cur_mm = vec![];
317            let positions = basemod.ranges.forward_starts();
318            let mut last_pos = 0;
319            for pos in positions {
320                let u_pos = pos as usize;
321                let mut in_between = 0;
322                if last_pos < u_pos {
323                    for base in seq[last_pos..u_pos].iter() {
324                        if *base == basemod.modified_base {
325                            in_between += 1;
326                        }
327                    }
328                }
329                last_pos = u_pos + 1;
330                cur_mm.push(in_between);
331            }
332            // Add to the MM string
333            mm_tag.push(basemod.modified_base as char);
334            mm_tag.push(basemod.strand);
335            mm_tag.push(basemod.modification_type);
336            for diff in cur_mm {
337                mm_tag.push_str(&format!(",{diff}"));
338            }
339            mm_tag.push(';')
340            // next basemod
341        }
342        log::trace!(
343            "{}\n{}\n{}\n",
344            record.is_reverse(),
345            mm_tag,
346            String::from_utf8_lossy(&seq)
347        );
348        // clear out the old base mods
349        record.remove_aux(b"MM").unwrap_or(());
350        record.remove_aux(b"ML").unwrap_or(());
351        // Add MM
352        let aux_integer_field = Aux::String(&mm_tag);
353        record.push_aux(b"MM", aux_integer_field).unwrap();
354        // Add ML
355        let aux_array: AuxArray<u8> = (&ml_tag).into();
356        let aux_array_field = Aux::ArrayU8(aux_array);
357        record.push_aux(b"ML", aux_array_field).unwrap();
358    }
359}