1use crate::utils::bamannotations::*;
2use crate::utils::bio_io::*;
3use bio::alphabets::dna::revcomp;
4use lazy_static::lazy_static;
5use regex::Regex;
6use rust_htslib::{
7 bam,
8 bam::record::{Aux, AuxArray},
9};
10use std::collections::HashMap;
11
12use std::convert::TryFrom;
13
14#[derive(Eq, PartialEq, Debug, PartialOrd, Ord, Clone)]
15pub struct BaseMod {
16 pub modified_base: u8,
17 pub strand: char,
18 pub modification_type: char,
19 pub ranges: Ranges,
20 pub record_is_reverse: bool,
21}
22
23impl BaseMod {
24 pub fn new(
25 record: &bam::Record,
26 modified_base: u8,
27 strand: char,
28 modification_type: char,
29 modified_bases_forward: Vec<i64>,
30 modified_probabilities_forward: Vec<u8>,
31 ) -> Self {
32 let tmp = modified_bases_forward.clone();
33 let mut ranges = Ranges::new(record, modified_bases_forward, None, None);
34 ranges.set_qual(modified_probabilities_forward);
35 let record_is_reverse = record.is_reverse();
36 assert_eq!(tmp, ranges.forward_starts(), "forward starts not equal");
37 Self {
38 modified_base,
39 strand,
40 modification_type,
41 ranges,
42 record_is_reverse,
43 }
44 }
45
46 pub fn is_m6a(&self) -> bool {
47 self.modification_type == 'a'
48 }
49
50 pub fn is_cpg(&self) -> bool {
51 self.modification_type == 'm'
52 }
53
54 pub fn filter_at_read_ends(&mut self, n_strip: i64) {
55 if n_strip <= 0 {
56 return;
57 }
58 self.ranges.filter_starts_at_read_ends(n_strip);
59 }
60}
61
62#[derive(Eq, PartialEq, Debug, Clone)]
63pub struct BaseMods {
64 pub base_mods: Vec<BaseMod>,
65}
66
67impl BaseMods {
68 pub fn new(record: &bam::Record, min_ml_score: u8) -> BaseMods {
69 BaseMods::my_mm_ml_parser(record, min_ml_score)
71 }
72
73 pub fn my_mm_ml_parser(record: &bam::Record, min_ml_score: u8) -> BaseMods {
74 lazy_static! {
76 static ref MM_RE: Regex =
78 Regex::new(r"((([ACGTUN])([-+])([A-Za-z]+|[0-9]+))[.?]?((,[0-9]+)*;)*)").unwrap();
79 }
80 let mut rtn = vec![];
82
83 let ml_tag = get_u8_tag(record, b"ML");
84
85 let mut num_mods_seen = 0;
86
87 if let Ok(Aux::String(mm_text)) = record.aux(b"MM") {
89 for cap in MM_RE.captures_iter(mm_text) {
90 let mod_base = cap.get(3).map(|m| m.as_str().as_bytes()[0]).unwrap();
91 let mod_strand = cap.get(4).map_or("", |m| m.as_str());
92 let modification_type = cap.get(5).map_or("", |m| m.as_str());
93 let mod_dists_str = cap.get(6).map_or("", |m| m.as_str());
94 let mod_dists: Vec<i64> = mod_dists_str
96 .trim_end_matches(';')
97 .split(',')
98 .map(|s| s.trim())
99 .filter(|s| !s.is_empty())
100 .map(|s| s.parse().unwrap())
101 .collect();
102
103 let forward_bases = if record.is_reverse() {
105 convert_seq_uppercase(revcomp(record.seq().as_bytes()))
106 } else {
107 convert_seq_uppercase(record.seq().as_bytes())
108 };
109 log::trace!(
110 "mod_base: {}, mod_strand: {}, modification_type: {}, mod_dists: {:?}",
111 mod_base as char,
112 mod_strand,
113 modification_type,
114 mod_dists
115 );
116 let mut cur_mod_idx = 0;
118 let mut cur_seq_idx = 0;
119 let mut dist_from_last_mod_base = 0;
120 let mut unfiltered_modified_positions: Vec<i64> = vec![0; mod_dists.len()];
121 while cur_seq_idx < forward_bases.len() && cur_mod_idx < mod_dists.len() {
122 let cur_base = forward_bases[cur_seq_idx];
123 if (cur_base == mod_base || mod_base == b'N')
124 && dist_from_last_mod_base == mod_dists[cur_mod_idx]
125 {
126 unfiltered_modified_positions[cur_mod_idx] =
127 i64::try_from(cur_seq_idx).unwrap();
128 dist_from_last_mod_base = 0;
129 cur_mod_idx += 1;
130 } else if cur_base == mod_base {
131 dist_from_last_mod_base += 1
132 }
133 cur_seq_idx += 1;
134 }
135 assert_eq!(
137 cur_mod_idx,
138 mod_dists.len(),
139 "{:?} {}",
140 String::from_utf8_lossy(record.qname()),
141 record.is_reverse()
142 );
143
144 let num_mods_cur_end = num_mods_seen + unfiltered_modified_positions.len();
146 let unfiltered_modified_probabilities = if num_mods_cur_end > ml_tag.len() {
147 let needed_num_of_zeros = num_mods_cur_end - ml_tag.len();
148 let mut to_add = vec![0; needed_num_of_zeros];
149 let mut has = ml_tag[num_mods_seen..ml_tag.len()].to_vec();
150 has.append(&mut to_add);
151 log::warn!(
152 "ML tag is too short for the number of modifications found in the MM tag. Assuming an ML value of 0 after the first {num_mods_cur_end} modifications."
153 );
154 has
155 } else {
156 ml_tag[num_mods_seen..num_mods_cur_end].to_vec()
157 };
158 num_mods_seen = num_mods_cur_end;
159
160 assert_eq!(
162 unfiltered_modified_positions.len(),
163 unfiltered_modified_probabilities.len()
164 );
165
166 let (modified_probabilities, modified_positions): (Vec<u8>, Vec<i64>) =
168 unfiltered_modified_probabilities
169 .iter()
170 .zip(unfiltered_modified_positions.iter())
171 .filter(|(&ml, &_mm)| ml >= min_ml_score)
172 .unzip();
173
174 if modified_positions.is_empty() {
176 continue;
177 }
178 let mods = BaseMod::new(
180 record,
181 mod_base,
182 mod_strand.chars().next().unwrap(),
183 modification_type.chars().next().unwrap(),
184 modified_positions,
185 modified_probabilities,
186 );
187 rtn.push(mods);
188 }
189 } else {
190 log::trace!("No MM tag found");
191 }
192
193 if ml_tag.len() != num_mods_seen {
194 log::warn!(
195 "ML tag ({}) different number than MM tag ({}).",
196 ml_tag.len(),
197 num_mods_seen
198 );
199 }
200 rtn.sort();
202 BaseMods { base_mods: rtn }
203 }
204
205 pub fn hashmap_to_basemods(
206 map: HashMap<(i32, i32, i32), Vec<(i64, u8)>>,
207 record: &bam::Record,
208 ) -> BaseMods {
209 let mut rtn = vec![];
210 for (mod_info, mods) in map {
211 let mod_base = mod_info.0 as u8;
212 let mod_type = mod_info.1 as u8 as char;
213 let mod_strand = if mod_info.2 == 0 { '+' } else { '-' };
214 let (mut positions, mut qualities): (Vec<i64>, Vec<u8>) = mods.into_iter().unzip();
215 if record.is_reverse() {
216 let length = record.seq_len() as i64;
217 positions = positions
218 .into_iter()
219 .rev()
220 .map(|p| length - p - 1)
221 .collect();
222 qualities.reverse();
223 }
224 let mods = BaseMod::new(record, mod_base, mod_strand, mod_type, positions, qualities);
225 rtn.push(mods);
226 }
227 rtn.sort();
229 BaseMods { base_mods: rtn }
230 }
231
232 pub fn drop_m6a(&mut self) {
234 self.base_mods.retain(|bm| !bm.is_m6a());
235 }
236
237 pub fn drop_cpg(&mut self) {
239 self.base_mods.retain(|bm| !bm.is_cpg());
240 }
241
242 pub fn drop_forward(&mut self) {
244 self.base_mods.retain(|bm| bm.strand == '-');
245 }
246
247 pub fn drop_reverse(&mut self) {
249 self.base_mods.retain(|bm| bm.strand == '+');
250 }
251
252 pub fn filter_m6a(&mut self, min_ml_score: u8) {
254 self.base_mods
255 .iter_mut()
256 .filter(|bm| bm.is_m6a())
257 .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
258 }
259
260 pub fn filter_5mc(&mut self, min_ml_score: u8) {
262 self.base_mods
263 .iter_mut()
264 .filter(|bm| bm.is_cpg())
265 .for_each(|bm| bm.ranges.filter_by_qual(min_ml_score));
266 }
267
268 pub fn filter_at_read_ends(&mut self, n_strip: i64) {
270 if n_strip <= 0 {
271 return;
272 }
273 self.base_mods
274 .iter_mut()
275 .for_each(|bm| bm.filter_at_read_ends(n_strip));
276 }
277
278 pub fn m6a(&self) -> Ranges {
280 let ranges = self
281 .base_mods
282 .iter()
283 .filter(|x| x.is_m6a())
284 .map(|x| &x.ranges)
285 .collect();
286 Ranges::merge_ranges(ranges)
287 }
288
289 pub fn cpg(&self) -> Ranges {
291 let ranges = self
292 .base_mods
293 .iter()
294 .filter(|x| x.is_cpg())
295 .map(|x| &x.ranges)
296 .collect();
297 Ranges::merge_ranges(ranges)
298 }
299
300 pub fn add_mm_and_ml_tags(&self, record: &mut bam::Record) {
303 let mut ml_tag: Vec<u8> = vec![];
305 let mut mm_tag = "".to_string();
306 let mut seq = record.seq().as_bytes();
308 if record.is_reverse() {
309 seq = revcomp(seq);
310 }
311 for basemod in self.base_mods.iter() {
313 ml_tag.extend(basemod.ranges.get_forward_quals());
315 let mut cur_mm = vec![];
317 let positions = basemod.ranges.forward_starts();
318 let mut last_pos = 0;
319 for pos in positions {
320 let u_pos = pos as usize;
321 let mut in_between = 0;
322 if last_pos < u_pos {
323 for base in seq[last_pos..u_pos].iter() {
324 if *base == basemod.modified_base {
325 in_between += 1;
326 }
327 }
328 }
329 last_pos = u_pos + 1;
330 cur_mm.push(in_between);
331 }
332 mm_tag.push(basemod.modified_base as char);
334 mm_tag.push(basemod.strand);
335 mm_tag.push(basemod.modification_type);
336 for diff in cur_mm {
337 mm_tag.push_str(&format!(",{diff}"));
338 }
339 mm_tag.push(';')
340 }
342 log::trace!(
343 "{}\n{}\n{}\n",
344 record.is_reverse(),
345 mm_tag,
346 String::from_utf8_lossy(&seq)
347 );
348 record.remove_aux(b"MM").unwrap_or(());
350 record.remove_aux(b"ML").unwrap_or(());
351 let aux_integer_field = Aux::String(&mm_tag);
353 record.push_aux(b"MM", aux_integer_field).unwrap();
354 let aux_array: AuxArray<u8> = (&ml_tag).into();
356 let aux_array_field = Aux::ArrayU8(aux_array);
357 record.push_aux(b"ML", aux_array_field).unwrap();
358 }
359}