1use crate::utils::bamranges::*;
2use crate::utils::bio_io::*;
3use bio::alphabets::dna::revcomp;
4use lazy_static::lazy_static;
5use regex::Regex;
6use rust_htslib::{
7 bam,
8 bam::record::{Aux, AuxArray},
9};
10use std::collections::HashMap;
11
12use std::convert::TryFrom;
13
14#[derive(Eq, PartialEq, Debug, PartialOrd, Ord, Clone)]
15pub struct BaseMod {
16 pub modified_base: u8,
17 pub strand: char,
18 pub modification_type: char,
19 pub ranges: Ranges,
20 pub record_is_reverse: bool,
21}
22
23impl BaseMod {
24 pub fn new(
25 record: &bam::Record,
26 modified_base: u8,
27 strand: char,
28 modification_type: char,
29 modified_bases_forward: Vec<i64>,
30 modified_probabilities_forward: Vec<u8>,
31 ) -> Self {
32 let tmp = modified_bases_forward.clone();
33 let mut ranges = Ranges::new(record, modified_bases_forward, None, None);
34 ranges.set_qual(modified_probabilities_forward);
35 let record_is_reverse = record.is_reverse();
36 assert_eq!(tmp, ranges.get_forward_starts(), "forward starts not equal");
38 Self {
39 modified_base,
40 strand,
41 modification_type,
42 ranges,
43 record_is_reverse,
44 }
45 }
46
47 pub fn is_m6a(&self) -> bool {
48 self.modification_type == 'a'
49 }
50
51 pub fn is_cpg(&self) -> bool {
52 self.modification_type == 'm'
53 }
54
55 pub fn filter_at_read_ends(&mut self, n_strip: i64) {
56 if n_strip <= 0 {
57 return;
58 }
59 self.ranges.filter_starts_at_read_ends(n_strip);
60 }
61}
62
63#[derive(Eq, PartialEq, Debug, Clone)]
64pub struct BaseMods {
65 pub base_mods: Vec<BaseMod>,
66}
67
68impl BaseMods {
69 pub fn new(record: &bam::Record, min_ml_score: u8) -> BaseMods {
70 BaseMods::my_mm_ml_parser(record, min_ml_score)
72 }
73
74 pub fn my_mm_ml_parser(record: &bam::Record, min_ml_score: u8) -> BaseMods {
75 lazy_static! {
77 static ref MM_RE: Regex =
79 Regex::new(r"((([ACGTUN])([-+])([A-Za-z]+|[0-9]+))[.?]?((,[0-9]+)*;)*)").unwrap();
80 }
81 let mut rtn = vec![];
83
84 let ml_tag = get_u8_tag(record, b"ML");
85
86 let mut num_mods_seen = 0;
87
88 if let Ok(Aux::String(mm_text)) = record.aux(b"MM") {
90 for cap in MM_RE.captures_iter(mm_text) {
91 let mod_base = cap.get(3).map(|m| m.as_str().as_bytes()[0]).unwrap();
92 let mod_strand = cap.get(4).map_or("", |m| m.as_str());
93 let modification_type = cap.get(5).map_or("", |m| m.as_str());
94 let mod_dists_str = cap.get(6).map_or("", |m| m.as_str());
95 let mod_dists: Vec<i64> = mod_dists_str
97 .trim_end_matches(';')
98 .split(',')
99 .map(|s| s.trim())
100 .filter(|s| !s.is_empty())
101 .map(|s| s.parse().unwrap())
102 .collect();
103
104 let forward_bases = if record.is_reverse() {
106 convert_seq_uppercase(revcomp(record.seq().as_bytes()))
107 } else {
108 convert_seq_uppercase(record.seq().as_bytes())
109 };
110 log::trace!(
111 "mod_base: {}, mod_strand: {}, modification_type: {}, mod_dists: {:?}",
112 mod_base as char,
113 mod_strand,
114 modification_type,
115 mod_dists
116 );
117 let mut cur_mod_idx = 0;
119 let mut cur_seq_idx = 0;
120 let mut dist_from_last_mod_base = 0;
121 let mut unfiltered_modified_positions: Vec<i64> = vec![0; mod_dists.len()];
122 while cur_seq_idx < forward_bases.len() && cur_mod_idx < mod_dists.len() {
123 let cur_base = forward_bases[cur_seq_idx];
124 if (cur_base == mod_base || mod_base == b'N')
125 && dist_from_last_mod_base == mod_dists[cur_mod_idx]
126 {
127 unfiltered_modified_positions[cur_mod_idx] =
128 i64::try_from(cur_seq_idx).unwrap();
129 dist_from_last_mod_base = 0;
130 cur_mod_idx += 1;
131 } else if cur_base == mod_base {
132 dist_from_last_mod_base += 1
133 }
134 cur_seq_idx += 1;
135 }
136 assert_eq!(
138 cur_mod_idx,
139 mod_dists.len(),
140 "{:?} {}",
141 String::from_utf8_lossy(record.qname()),
142 record.is_reverse()
143 );
144
145 let num_mods_cur_end = num_mods_seen + unfiltered_modified_positions.len();
147 let unfiltered_modified_probabilities = if num_mods_cur_end > ml_tag.len() {
148 let needed_num_of_zeros = num_mods_cur_end - ml_tag.len();
149 let mut to_add = vec![0; needed_num_of_zeros];
150 let mut has = ml_tag[num_mods_seen..ml_tag.len()].to_vec();
151 has.append(&mut to_add);
152 log::warn!(
153 "ML tag is too short for the number of modifications found in the MM tag. Assuming an ML value of 0 after the first {num_mods_cur_end} modifications."
154 );
155 has
156 } else {
157 ml_tag[num_mods_seen..num_mods_cur_end].to_vec()
158 };
159 num_mods_seen = num_mods_cur_end;
160
161 assert_eq!(
163 unfiltered_modified_positions.len(),
164 unfiltered_modified_probabilities.len()
165 );
166
167 let (modified_probabilities, modified_positions): (Vec<u8>, Vec<i64>) =
169 unfiltered_modified_probabilities
170 .iter()
171 .zip(unfiltered_modified_positions.iter())
172 .filter(|(&ml, &_mm)| ml >= min_ml_score)
173 .unzip();
174
175 if modified_positions.is_empty() {
177 continue;
178 }
179 let mods = BaseMod::new(
181 record,
182 mod_base,
183 mod_strand.chars().next().unwrap(),
184 modification_type.chars().next().unwrap(),
185 modified_positions,
186 modified_probabilities,
187 );
188 rtn.push(mods);
189 }
190 } else {
191 log::trace!("No MM tag found");
192 }
193
194 if ml_tag.len() != num_mods_seen {
195 log::warn!(
196 "ML tag ({}) different number than MM tag ({}).",
197 ml_tag.len(),
198 num_mods_seen
199 );
200 }
201 rtn.sort();
203 BaseMods { base_mods: rtn }
204 }
205
206 pub fn hashmap_to_basemods(
207 map: HashMap<(i32, i32, i32), Vec<(i64, u8)>>,
208 record: &bam::Record,
209 ) -> BaseMods {
210 let mut rtn = vec![];
211 for (mod_info, mods) in map {
212 let mod_base = mod_info.0 as u8;
213 let mod_type = mod_info.1 as u8 as char;
214 let mod_strand = if mod_info.2 == 0 { '+' } else { '-' };
215 let (mut positions, mut qualities): (Vec<i64>, Vec<u8>) = mods.into_iter().unzip();
216 if record.is_reverse() {
217 let length = record.seq_len() as i64;
218 positions = positions
219 .into_iter()
220 .rev()
221 .map(|p| length - p - 1)
222 .collect();
223 qualities.reverse();
224 }
225 let mods = BaseMod::new(record, mod_base, mod_strand, mod_type, positions, qualities);
226 rtn.push(mods);
227 }
228 rtn.sort();
230 BaseMods { base_mods: rtn }
231 }
232
233 pub fn drop_m6a(&mut self) {
235 self.base_mods.retain(|bm| !bm.is_m6a());
236 }
237
238 pub fn drop_cpg(&mut self) {
240 self.base_mods.retain(|bm| !bm.is_cpg());
241 }
242
243 pub fn drop_forward(&mut self) {
245 self.base_mods.retain(|bm| bm.strand == '-');
246 }
247
248 pub fn drop_reverse(&mut self) {
250 self.base_mods.retain(|bm| bm.strand == '+');
251 }
252
253 pub fn filter_m6a(&mut self, min_ml_score: u8) {
255 self.base_mods
256 .iter_mut()
257 .filter(|bm| bm.is_m6a())
258 .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
259 }
260
261 pub fn filter_5mc(&mut self, min_ml_score: u8) {
263 self.base_mods
264 .iter_mut()
265 .filter(|bm| bm.is_cpg())
266 .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
267 }
268
269 pub fn filter_at_read_ends(&mut self, n_strip: i64) {
271 if n_strip <= 0 {
272 return;
273 }
274 self.base_mods
275 .iter_mut()
276 .for_each(|bm| bm.filter_at_read_ends(n_strip));
277 }
278
279 pub fn m6a(&self) -> Ranges {
281 let ranges = self
282 .base_mods
283 .iter()
284 .filter(|x| x.is_m6a())
285 .map(|x| &x.ranges)
286 .collect();
287 Ranges::merge_ranges(ranges)
288 }
289
290 pub fn cpg(&self) -> Ranges {
292 let ranges = self
293 .base_mods
294 .iter()
295 .filter(|x| x.is_cpg())
296 .map(|x| &x.ranges)
297 .collect();
298 Ranges::merge_ranges(ranges)
299 }
300
301 pub fn add_mm_and_ml_tags(&self, record: &mut bam::Record) {
304 let mut ml_tag: Vec<u8> = vec![];
306 let mut mm_tag = "".to_string();
307 let mut seq = record.seq().as_bytes();
309 if record.is_reverse() {
310 seq = revcomp(seq);
311 }
312 for basemod in self.base_mods.iter() {
314 ml_tag.extend(basemod.ranges.get_forward_quals());
316 let mut cur_mm = vec![];
318 let positions = basemod.ranges.get_forward_starts();
319 let mut last_pos = 0;
320 for pos in positions {
321 let u_pos = pos as usize;
322 let mut in_between = 0;
323 if last_pos < u_pos {
324 for base in seq[last_pos..u_pos].iter() {
325 if *base == basemod.modified_base {
326 in_between += 1;
327 }
328 }
329 }
330 last_pos = u_pos + 1;
331 cur_mm.push(in_between);
332 }
333 mm_tag.push(basemod.modified_base as char);
335 mm_tag.push(basemod.strand);
336 mm_tag.push(basemod.modification_type);
337 for diff in cur_mm {
338 mm_tag.push_str(&format!(",{diff}"));
339 }
340 mm_tag.push(';')
341 }
343 log::trace!(
344 "{}\n{}\n{}\n",
345 record.is_reverse(),
346 mm_tag,
347 String::from_utf8_lossy(&seq)
348 );
349 record.remove_aux(b"MM").unwrap_or(());
351 record.remove_aux(b"ML").unwrap_or(());
352 let aux_integer_field = Aux::String(&mm_tag);
354 record.push_aux(b"MM", aux_integer_field).unwrap();
355 let aux_array: AuxArray<u8> = (&ml_tag).into();
357 let aux_array_field = Aux::ArrayU8(aux_array);
358 record.push_aux(b"ML", aux_array_field).unwrap();
359 }
360}