fibertools_rs/utils/
basemods.rs

1use crate::utils::bamranges::*;
2use crate::utils::bio_io::*;
3use bio::alphabets::dna::revcomp;
4use lazy_static::lazy_static;
5use regex::Regex;
6use rust_htslib::{
7    bam,
8    bam::record::{Aux, AuxArray},
9};
10use std::collections::HashMap;
11
12use std::convert::TryFrom;
13
/// One group of base-modification calls for a single BAM record: a canonical
/// base, a strand symbol, a modification code, and the called positions.
///
/// NOTE: `PartialOrd`/`Ord` are derived, so values compare field-by-field in
/// declaration order. `BaseMods` sorts its list with this ordering, so
/// reordering the fields would change the sorted output.
#[derive(Eq, PartialEq, Debug, PartialOrd, Ord, Clone)]
pub struct BaseMod {
    /// Canonical base the modification is reported on, as an ASCII byte
    /// (the MM parser in this file accepts `A`, `C`, `G`, `T`, `U`, or `N`).
    pub modified_base: u8,
    /// Strand symbol of the call: `'+'` or `'-'`.
    pub strand: char,
    /// Single-character modification code; `'a'` is treated as m6A and `'m'`
    /// as 5mC/CpG by `is_m6a` / `is_cpg`.
    pub modification_type: char,
    /// Positions of the modified bases, with the modification qualities
    /// attached via `Ranges::set_qual`.
    pub ranges: Ranges,
    /// Value of `record.is_reverse()` for the record these calls came from.
    pub record_is_reverse: bool,
}
22
23impl BaseMod {
24    pub fn new(
25        record: &bam::Record,
26        modified_base: u8,
27        strand: char,
28        modification_type: char,
29        modified_bases_forward: Vec<i64>,
30        modified_probabilities_forward: Vec<u8>,
31    ) -> Self {
32        let tmp = modified_bases_forward.clone();
33        let mut ranges = Ranges::new(record, modified_bases_forward, None, None);
34        ranges.set_qual(modified_probabilities_forward);
35        let record_is_reverse = record.is_reverse();
36        //assert_eq!(tmp, ranges.get_starts(), "starts not equal");
37        assert_eq!(tmp, ranges.get_forward_starts(), "forward starts not equal");
38        Self {
39            modified_base,
40            strand,
41            modification_type,
42            ranges,
43            record_is_reverse,
44        }
45    }
46
47    pub fn is_m6a(&self) -> bool {
48        self.modification_type == 'a'
49    }
50
51    pub fn is_cpg(&self) -> bool {
52        self.modification_type == 'm'
53    }
54
55    pub fn filter_at_read_ends(&mut self, n_strip: i64) {
56        if n_strip <= 0 {
57            return;
58        }
59        self.ranges.filter_starts_at_read_ends(n_strip);
60    }
61}
62
/// Collection of all base-modification groups found on one BAM record.
#[derive(Eq, PartialEq, Debug, Clone)]
pub struct BaseMods {
    /// One entry per (base, strand, modification code) group; the constructors
    /// in this file sort this list before returning.
    pub base_mods: Vec<BaseMod>,
}
67
68impl BaseMods {
69    pub fn new(record: &bam::Record, min_ml_score: u8) -> BaseMods {
70        // my basemod parser is ~25% faster than rust_htslib's
71        BaseMods::my_mm_ml_parser(record, min_ml_score)
72    }
73
74    pub fn my_mm_ml_parser(record: &bam::Record, min_ml_score: u8) -> BaseMods {
75        // regex for matching the MM tag
76        lazy_static! {
77            // MM:Z:([ACGTUN][-+]([A-Za-z]+|[0-9]+)[.?]?(,[0-9]+)*;)*
78            static ref MM_RE: Regex =
79                Regex::new(r"((([ACGTUN])([-+])([A-Za-z]+|[0-9]+))[.?]?((,[0-9]+)*;)*)").unwrap();
80        }
81        // Array to store all the different modifications within the MM tag
82        let mut rtn = vec![];
83
84        let ml_tag = get_u8_tag(record, b"ML");
85
86        let mut num_mods_seen = 0;
87
88        // if there is an MM tag iterate over all the regex matches
89        if let Ok(Aux::String(mm_text)) = record.aux(b"MM") {
90            for cap in MM_RE.captures_iter(mm_text) {
91                let mod_base = cap.get(3).map(|m| m.as_str().as_bytes()[0]).unwrap();
92                let mod_strand = cap.get(4).map_or("", |m| m.as_str());
93                let modification_type = cap.get(5).map_or("", |m| m.as_str());
94                let mod_dists_str = cap.get(6).map_or("", |m| m.as_str());
95                // parse the string containing distances between modifications into a vector of i64
96                let mod_dists: Vec<i64> = mod_dists_str
97                    .trim_end_matches(';')
98                    .split(',')
99                    .map(|s| s.trim())
100                    .filter(|s| !s.is_empty())
101                    .map(|s| s.parse().unwrap())
102                    .collect();
103
104                // get forward sequence bases from the bam record
105                let forward_bases = if record.is_reverse() {
106                    convert_seq_uppercase(revcomp(record.seq().as_bytes()))
107                } else {
108                    convert_seq_uppercase(record.seq().as_bytes())
109                };
110                log::trace!(
111                    "mod_base: {}, mod_strand: {}, modification_type: {}, mod_dists: {:?}",
112                    mod_base as char,
113                    mod_strand,
114                    modification_type,
115                    mod_dists
116                );
117                // find real positions in the forward sequence
118                let mut cur_mod_idx = 0;
119                let mut cur_seq_idx = 0;
120                let mut dist_from_last_mod_base = 0;
121                let mut unfiltered_modified_positions: Vec<i64> = vec![0; mod_dists.len()];
122                while cur_seq_idx < forward_bases.len() && cur_mod_idx < mod_dists.len() {
123                    let cur_base = forward_bases[cur_seq_idx];
124                    if (cur_base == mod_base || mod_base == b'N')
125                        && dist_from_last_mod_base == mod_dists[cur_mod_idx]
126                    {
127                        unfiltered_modified_positions[cur_mod_idx] =
128                            i64::try_from(cur_seq_idx).unwrap();
129                        dist_from_last_mod_base = 0;
130                        cur_mod_idx += 1;
131                    } else if cur_base == mod_base {
132                        dist_from_last_mod_base += 1
133                    }
134                    cur_seq_idx += 1;
135                }
136                // assert that we extract the same number of modifications as we have distances
137                assert_eq!(
138                    cur_mod_idx,
139                    mod_dists.len(),
140                    "{:?} {}",
141                    String::from_utf8_lossy(record.qname()),
142                    record.is_reverse()
143                );
144
145                // check for the probability of modification.
146                let num_mods_cur_end = num_mods_seen + unfiltered_modified_positions.len();
147                let unfiltered_modified_probabilities = if num_mods_cur_end > ml_tag.len() {
148                    let needed_num_of_zeros = num_mods_cur_end - ml_tag.len();
149                    let mut to_add = vec![0; needed_num_of_zeros];
150                    let mut has = ml_tag[num_mods_seen..ml_tag.len()].to_vec();
151                    has.append(&mut to_add);
152                    log::warn!(
153                        "ML tag is too short for the number of modifications found in the MM tag. Assuming an ML value of 0 after the first {num_mods_cur_end} modifications."
154                    );
155                    has
156                } else {
157                    ml_tag[num_mods_seen..num_mods_cur_end].to_vec()
158                };
159                num_mods_seen = num_mods_cur_end;
160
161                // must be true for filtering, and at this point
162                assert_eq!(
163                    unfiltered_modified_positions.len(),
164                    unfiltered_modified_probabilities.len()
165                );
166
167                // Filter mods based on probabilities
168                let (modified_probabilities, modified_positions): (Vec<u8>, Vec<i64>) =
169                    unfiltered_modified_probabilities
170                        .iter()
171                        .zip(unfiltered_modified_positions.iter())
172                        .filter(|(&ml, &_mm)| ml >= min_ml_score)
173                        .unzip();
174
175                // don't add empty basemods
176                if modified_positions.is_empty() {
177                    continue;
178                }
179                // add to a struct
180                let mods = BaseMod::new(
181                    record,
182                    mod_base,
183                    mod_strand.chars().next().unwrap(),
184                    modification_type.chars().next().unwrap(),
185                    modified_positions,
186                    modified_probabilities,
187                );
188                rtn.push(mods);
189            }
190        } else {
191            log::trace!("No MM tag found");
192        }
193
194        if ml_tag.len() != num_mods_seen {
195            log::warn!(
196                "ML tag ({}) different number than MM tag ({}).",
197                ml_tag.len(),
198                num_mods_seen
199            );
200        }
201        // needed so I can compare methods
202        rtn.sort();
203        BaseMods { base_mods: rtn }
204    }
205
206    pub fn hashmap_to_basemods(
207        map: HashMap<(i32, i32, i32), Vec<(i64, u8)>>,
208        record: &bam::Record,
209    ) -> BaseMods {
210        let mut rtn = vec![];
211        for (mod_info, mods) in map {
212            let mod_base = mod_info.0 as u8;
213            let mod_type = mod_info.1 as u8 as char;
214            let mod_strand = if mod_info.2 == 0 { '+' } else { '-' };
215            let (mut positions, mut qualities): (Vec<i64>, Vec<u8>) = mods.into_iter().unzip();
216            if record.is_reverse() {
217                let length = record.seq_len() as i64;
218                positions = positions
219                    .into_iter()
220                    .rev()
221                    .map(|p| length - p - 1)
222                    .collect();
223                qualities.reverse();
224            }
225            let mods = BaseMod::new(record, mod_base, mod_strand, mod_type, positions, qualities);
226            rtn.push(mods);
227        }
228        // needed so I can compare methods
229        rtn.sort();
230        BaseMods { base_mods: rtn }
231    }
232
233    /// remove m6a base mods from the struct
234    pub fn drop_m6a(&mut self) {
235        self.base_mods.retain(|bm| !bm.is_m6a());
236    }
237
238    /// remove cpg/5mc base mods from the struct
239    pub fn drop_cpg(&mut self) {
240        self.base_mods.retain(|bm| !bm.is_cpg());
241    }
242
243    /// drop the forward stand of basemod calls
244    pub fn drop_forward(&mut self) {
245        self.base_mods.retain(|bm| bm.strand == '-');
246    }
247
248    /// drop the reverse strand of basemod calls   
249    pub fn drop_reverse(&mut self) {
250        self.base_mods.retain(|bm| bm.strand == '+');
251    }
252
253    /// drop m6A modifications with a qual less than the min_ml_score
254    pub fn filter_m6a(&mut self, min_ml_score: u8) {
255        self.base_mods
256            .iter_mut()
257            .filter(|bm| bm.is_m6a())
258            .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
259    }
260
261    /// drop 5mC modifications with a qual less than the min_ml_score
262    pub fn filter_5mc(&mut self, min_ml_score: u8) {
263        self.base_mods
264            .iter_mut()
265            .filter(|bm| bm.is_cpg())
266            .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
267    }
268
269    /// filter the basemods at the read ends
270    pub fn filter_at_read_ends(&mut self, n_strip: i64) {
271        if n_strip <= 0 {
272            return;
273        }
274        self.base_mods
275            .iter_mut()
276            .for_each(|bm| bm.filter_at_read_ends(n_strip));
277    }
278
279    /// combine the forward and reverse m6a data
280    pub fn m6a(&self) -> Ranges {
281        let ranges = self
282            .base_mods
283            .iter()
284            .filter(|x| x.is_m6a())
285            .map(|x| &x.ranges)
286            .collect();
287        Ranges::merge_ranges(ranges)
288    }
289
290    /// combine the forward and reverse cpd/5mc data
291    pub fn cpg(&self) -> Ranges {
292        let ranges = self
293            .base_mods
294            .iter()
295            .filter(|x| x.is_cpg())
296            .map(|x| &x.ranges)
297            .collect();
298        Ranges::merge_ranges(ranges)
299    }
300
301    /// Example MM tag: MM:Z:C+m,11,6,10;A+a,0,0,0;
302    /// Example ML tag: ML:B:C,157,30,2,164,118,255
303    pub fn add_mm_and_ml_tags(&self, record: &mut bam::Record) {
304        // init the mm and ml tag to be populated
305        let mut ml_tag: Vec<u8> = vec![];
306        let mut mm_tag = "".to_string();
307        // need the original sequence for distances between bases.
308        let mut seq = record.seq().as_bytes();
309        if record.is_reverse() {
310            seq = revcomp(seq);
311        }
312        // add to the ml and mm tag.
313        for basemod in self.base_mods.iter() {
314            // adding quality values (ML)
315            ml_tag.extend(basemod.ranges.get_forward_quals());
316            // get MM tag values
317            let mut cur_mm = vec![];
318            let positions = basemod.ranges.get_forward_starts();
319            let mut last_pos = 0;
320            for pos in positions {
321                let u_pos = pos as usize;
322                let mut in_between = 0;
323                if last_pos < u_pos {
324                    for base in seq[last_pos..u_pos].iter() {
325                        if *base == basemod.modified_base {
326                            in_between += 1;
327                        }
328                    }
329                }
330                last_pos = u_pos + 1;
331                cur_mm.push(in_between);
332            }
333            // Add to the MM string
334            mm_tag.push(basemod.modified_base as char);
335            mm_tag.push(basemod.strand);
336            mm_tag.push(basemod.modification_type);
337            for diff in cur_mm {
338                mm_tag.push_str(&format!(",{diff}"));
339            }
340            mm_tag.push(';')
341            // next basemod
342        }
343        log::trace!(
344            "{}\n{}\n{}\n",
345            record.is_reverse(),
346            mm_tag,
347            String::from_utf8_lossy(&seq)
348        );
349        // clear out the old base mods
350        record.remove_aux(b"MM").unwrap_or(());
351        record.remove_aux(b"ML").unwrap_or(());
352        // Add MM
353        let aux_integer_field = Aux::String(&mm_tag);
354        record.push_aux(b"MM", aux_integer_field).unwrap();
355        // Add ML
356        let aux_array: AuxArray<u8> = (&ml_tag).into();
357        let aux_array_field = Aux::ArrayU8(aux_array);
358        record.push_aux(b"ML", aux_array_field).unwrap();
359    }
360}