1use crate::utils::bamranges::*;
2use crate::utils::bio_io::*;
3use bio::alphabets::dna::revcomp;
4use lazy_static::lazy_static;
5use regex::Regex;
6use rust_htslib::{
7 bam,
8 bam::record::{Aux, AuxArray},
9};
10use std::collections::HashMap;
11
12use std::convert::TryFrom;
13
14#[derive(Eq, PartialEq, Debug, PartialOrd, Ord, Clone)]
15pub struct BaseMod {
16 pub modified_base: u8,
17 pub strand: char,
18 pub modification_type: char,
19 pub ranges: Ranges,
20 pub record_is_reverse: bool,
21}
22
23impl BaseMod {
24 pub fn new(
25 record: &bam::Record,
26 modified_base: u8,
27 strand: char,
28 modification_type: char,
29 modified_bases_forward: Vec<i64>,
30 modified_probabilities_forward: Vec<u8>,
31 ) -> Self {
32 let tmp = modified_bases_forward.clone();
33 let mut ranges = Ranges::new(record, modified_bases_forward, None, None);
34 ranges.set_qual(modified_probabilities_forward);
35 let record_is_reverse = record.is_reverse();
36 assert_eq!(tmp, ranges.get_forward_starts(), "forward starts not equal");
38 Self {
39 modified_base,
40 strand,
41 modification_type,
42 ranges,
43 record_is_reverse,
44 }
45 }
46
47 pub fn is_m6a(&self) -> bool {
48 self.modification_type == 'a'
49 }
50
51 pub fn is_cpg(&self) -> bool {
52 self.modification_type == 'm'
53 }
54
55 pub fn filter_at_read_ends(&mut self, n_strip: i64) {
56 if n_strip <= 0 {
57 return;
58 }
59 self.ranges.filter_starts_at_read_ends(n_strip);
60 }
61}
62
63#[derive(Eq, PartialEq, Debug, Clone)]
64pub struct BaseMods {
65 pub base_mods: Vec<BaseMod>,
66}
67
68impl BaseMods {
69 pub fn new(record: &bam::Record, min_ml_score: u8) -> BaseMods {
70 BaseMods::my_mm_ml_parser(record, min_ml_score)
72 }
73
74 pub fn my_mm_ml_parser(record: &bam::Record, min_ml_score: u8) -> BaseMods {
75 lazy_static! {
77 static ref MM_RE: Regex =
79 Regex::new(r"((([ACGTUN])([-+])([a-z]+|[0-9]+))[.?]?((,[0-9]+)*;)*)").unwrap();
80 }
81 let mut rtn = vec![];
83
84 let ml_tag = get_u8_tag(record, b"ML");
85
86 let mut num_mods_seen = 0;
87
88 if let Ok(Aux::String(mm_text)) = record.aux(b"MM") {
90 for cap in MM_RE.captures_iter(mm_text) {
91 let mod_base = cap.get(3).map(|m| m.as_str().as_bytes()[0]).unwrap();
92 let mod_strand = cap.get(4).map_or("", |m| m.as_str());
93 let modification_type = cap.get(5).map_or("", |m| m.as_str());
94 let mod_dists_str = cap.get(6).map_or("", |m| m.as_str());
95 let mod_dists: Vec<i64> = mod_dists_str
97 .trim_end_matches(';')
98 .split(',')
99 .map(|s| s.trim())
100 .filter(|s| !s.is_empty())
101 .map(|s| s.parse().unwrap())
102 .collect();
103
104 let forward_bases = if record.is_reverse() {
106 revcomp(record.seq().as_bytes())
107 } else {
108 record.seq().as_bytes()
109 };
110 log::trace!(
111 "mod_base: {}, mod_strand: {}, modification_type: {}, mod_dists: {:?}",
112 mod_base as char,
113 mod_strand,
114 modification_type,
115 mod_dists
116 );
117 let mut cur_mod_idx = 0;
119 let mut cur_seq_idx = 0;
120 let mut dist_from_last_mod_base = 0;
121 let mut unfiltered_modified_positions: Vec<i64> = vec![0; mod_dists.len()];
122 while cur_seq_idx < forward_bases.len() && cur_mod_idx < mod_dists.len() {
123 let cur_base = forward_bases[cur_seq_idx];
124 if cur_base == mod_base && dist_from_last_mod_base == mod_dists[cur_mod_idx] {
125 unfiltered_modified_positions[cur_mod_idx] =
126 i64::try_from(cur_seq_idx).unwrap();
127 dist_from_last_mod_base = 0;
128 cur_mod_idx += 1;
129 } else if cur_base == mod_base {
130 dist_from_last_mod_base += 1
131 }
132 cur_seq_idx += 1;
133 }
134 assert_eq!(
136 cur_mod_idx,
137 mod_dists.len(),
138 "{:?} {}",
139 String::from_utf8_lossy(record.qname()),
140 record.is_reverse()
141 );
142
143 let num_mods_cur_end = num_mods_seen + unfiltered_modified_positions.len();
145 let unfiltered_modified_probabilities = if num_mods_cur_end > ml_tag.len() {
146 let needed_num_of_zeros = num_mods_cur_end - ml_tag.len();
147 let mut to_add = vec![0; needed_num_of_zeros];
148 let mut has = ml_tag[num_mods_seen..ml_tag.len()].to_vec();
149 has.append(&mut to_add);
150 log::warn!(
151 "ML tag is too short for the number of modifications found in the MM tag. Assuming an ML value of 0 after the first {num_mods_cur_end} modifications."
152 );
153 has
154 } else {
155 ml_tag[num_mods_seen..num_mods_cur_end].to_vec()
156 };
157 num_mods_seen = num_mods_cur_end;
158
159 assert_eq!(
161 unfiltered_modified_positions.len(),
162 unfiltered_modified_probabilities.len()
163 );
164
165 let (modified_probabilities, modified_positions): (Vec<u8>, Vec<i64>) =
167 unfiltered_modified_probabilities
168 .iter()
169 .zip(unfiltered_modified_positions.iter())
170 .filter(|(&ml, &_mm)| ml >= min_ml_score)
171 .unzip();
172
173 if modified_positions.is_empty() {
175 continue;
176 }
177 let mods = BaseMod::new(
179 record,
180 mod_base,
181 mod_strand.chars().next().unwrap(),
182 modification_type.chars().next().unwrap(),
183 modified_positions,
184 modified_probabilities,
185 );
186 rtn.push(mods);
187 }
188 } else {
189 log::trace!("No MM tag found");
190 }
191
192 if ml_tag.len() != num_mods_seen {
193 log::warn!(
194 "ML tag ({}) different number than MM tag ({}).",
195 ml_tag.len(),
196 num_mods_seen
197 );
198 }
199 rtn.sort();
201 BaseMods { base_mods: rtn }
202 }
203
204 pub fn hashmap_to_basemods(
205 map: HashMap<(i32, i32, i32), Vec<(i64, u8)>>,
206 record: &bam::Record,
207 ) -> BaseMods {
208 let mut rtn = vec![];
209 for (mod_info, mods) in map {
210 let mod_base = mod_info.0 as u8;
211 let mod_type = mod_info.1 as u8 as char;
212 let mod_strand = if mod_info.2 == 0 { '+' } else { '-' };
213 let (mut positions, mut qualities): (Vec<i64>, Vec<u8>) = mods.into_iter().unzip();
214 if record.is_reverse() {
215 let length = record.seq_len() as i64;
216 positions = positions
217 .into_iter()
218 .rev()
219 .map(|p| length - p - 1)
220 .collect();
221 qualities.reverse();
222 }
223 let mods = BaseMod::new(record, mod_base, mod_strand, mod_type, positions, qualities);
224 rtn.push(mods);
225 }
226 rtn.sort();
228 BaseMods { base_mods: rtn }
229 }
230
231 pub fn drop_m6a(&mut self) {
233 self.base_mods.retain(|bm| !bm.is_m6a());
234 }
235
236 pub fn drop_cpg(&mut self) {
238 self.base_mods.retain(|bm| !bm.is_cpg());
239 }
240
241 pub fn drop_forward(&mut self) {
243 self.base_mods.retain(|bm| bm.strand == '-');
244 }
245
246 pub fn drop_reverse(&mut self) {
248 self.base_mods.retain(|bm| bm.strand == '+');
249 }
250
251 pub fn filter_m6a(&mut self, min_ml_score: u8) {
253 self.base_mods
254 .iter_mut()
255 .filter(|bm| bm.is_m6a())
256 .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
257 }
258
259 pub fn filter_5mc(&mut self, min_ml_score: u8) {
261 self.base_mods
262 .iter_mut()
263 .filter(|bm| bm.is_cpg())
264 .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
265 }
266
267 pub fn filter_at_read_ends(&mut self, n_strip: i64) {
269 if n_strip <= 0 {
270 return;
271 }
272 self.base_mods
273 .iter_mut()
274 .for_each(|bm| bm.filter_at_read_ends(n_strip));
275 }
276
277 pub fn m6a(&self) -> Ranges {
279 let ranges = self
280 .base_mods
281 .iter()
282 .filter(|x| x.is_m6a())
283 .map(|x| &x.ranges)
284 .collect();
285 Ranges::merge_ranges(ranges)
286 }
287
288 pub fn cpg(&self) -> Ranges {
290 let ranges = self
291 .base_mods
292 .iter()
293 .filter(|x| x.is_cpg())
294 .map(|x| &x.ranges)
295 .collect();
296 Ranges::merge_ranges(ranges)
297 }
298
299 pub fn add_mm_and_ml_tags(&self, record: &mut bam::Record) {
302 let mut ml_tag: Vec<u8> = vec![];
304 let mut mm_tag = "".to_string();
305 let mut seq = record.seq().as_bytes();
307 if record.is_reverse() {
308 seq = revcomp(seq);
309 }
310 for basemod in self.base_mods.iter() {
312 ml_tag.extend(basemod.ranges.get_forward_quals());
314 let mut cur_mm = vec![];
316 let positions = basemod.ranges.get_forward_starts();
317 let mut last_pos = 0;
318 for pos in positions {
319 let u_pos = pos as usize;
320 let mut in_between = 0;
321 if last_pos < u_pos {
322 for base in seq[last_pos..u_pos].iter() {
323 if *base == basemod.modified_base {
324 in_between += 1;
325 }
326 }
327 }
328 last_pos = u_pos + 1;
329 cur_mm.push(in_between);
330 }
331 mm_tag.push(basemod.modified_base as char);
333 mm_tag.push(basemod.strand);
334 mm_tag.push(basemod.modification_type);
335 for diff in cur_mm {
336 mm_tag.push_str(&format!(",{}", diff));
337 }
338 mm_tag.push(';')
339 }
341 log::trace!(
342 "{}\n{}\n{}\n",
343 record.is_reverse(),
344 mm_tag,
345 String::from_utf8_lossy(&seq)
346 );
347 record.remove_aux(b"MM").unwrap_or(());
349 record.remove_aux(b"ML").unwrap_or(());
350 let aux_integer_field = Aux::String(&mm_tag);
352 record.push_aux(b"MM", aux_integer_field).unwrap();
353 let aux_array: AuxArray<u8> = (&ml_tag).into();
355 let aux_array_field = Aux::ArrayU8(aux_array);
356 record.push_aux(b"ML", aux_array_field).unwrap();
357 }
358}