use super::bamranges::*;
use super::bio_io::*;
use bio::alphabets::dna::revcomp;
use lazy_static::lazy_static;
use regex::Regex;
use rust_htslib::{
bam,
bam::record::{Aux, AuxArray},
};
use std::collections::HashMap;
use std::convert::TryFrom;
/// One kind of base modification called on a single read, e.g. all calls from
/// one `MM` entry such as `C+m,...;`.
///
/// Field order matters: the derived `PartialOrd`/`Ord` compare fields in
/// declaration order, and that ordering is what `BaseMods` constructors use
/// when they sort their entries.
#[derive(Eq, PartialEq, Debug, PartialOrd, Ord, Clone)]
pub struct BaseMod {
    /// Canonical base the modification is called on, as an ASCII byte (e.g. `b'C'`).
    pub modified_base: u8,
    /// Strand of the modification call: `'+'` or `'-'`.
    pub strand: char,
    /// Single-character modification code; `is_m6a` matches `'a'` and `is_cpg` matches `'m'`.
    pub modification_type: char,
    /// Modified-base positions (built from forward-orientation coordinates) with per-call qualities.
    pub ranges: Ranges,
    /// `record.is_reverse()` of the read this modification was built from.
    pub record_is_reverse: bool,
}
impl BaseMod {
pub fn new(
record: &bam::Record,
modified_base: u8,
strand: char,
modification_type: char,
modified_bases_forward: Vec<i64>,
modified_probabilities_forward: Vec<u8>,
) -> Self {
let tmp = modified_bases_forward.clone();
let mut ranges = Ranges::new(record, modified_bases_forward, None, None);
ranges.set_qual(modified_probabilities_forward);
let record_is_reverse = record.is_reverse();
assert_eq!(tmp, ranges.get_forward_starts(), "forward starts not equal");
Self {
modified_base,
strand,
modification_type,
ranges,
record_is_reverse,
}
}
pub fn is_m6a(&self) -> bool {
self.modification_type == 'a'
}
pub fn is_cpg(&self) -> bool {
self.modification_type == 'm'
}
}
/// All base modifications associated with a single read.
#[derive(Eq, PartialEq, Debug, Clone)]
pub struct BaseMods {
    /// Individual modification groups; the constructors in this file
    /// (`my_mm_ml_parser`, `hashmap_to_basemods`) sort this vector before returning.
    pub base_mods: Vec<BaseMod>,
}
impl BaseMods {
pub fn new(record: &bam::Record, min_ml_score: u8) -> BaseMods {
BaseMods::my_mm_ml_parser(record, min_ml_score)
}
/// Parse the `MM`/`ML` base-modification tags of `record` into [`BaseMods`].
///
/// Each `MM` entry (e.g. `C+m,0,3;`) is decoded into absolute positions on the
/// forward-oriented read sequence. That is the reverse complement of SEQ when
/// the record is reverse-strand. Each position is paired with its `ML` value,
/// and only calls with `ML >= min_ml_score` are kept. Entries left with no
/// calls are dropped, and the result is sorted.
///
/// If `ML` holds fewer values than `MM` lists calls, the missing values are
/// treated as 0 and a warning is logged. If the counts differ at all, a
/// warning is logged at the end.
///
/// # Panics
///
/// Panics if the `MM` skip counts for an entry cannot all be located in the
/// sequence, or if a skip count does not parse as an `i64`.
pub fn my_mm_ml_parser(record: &bam::Record, min_ml_score: u8) -> BaseMods {
    lazy_static! {
        static ref MM_RE: Regex =
            Regex::new(r"((([ACGTUN])([-+])([a-z]+|[0-9]+))[.?]?((,[0-9]+)*;)*)").unwrap();
    }
    let mut rtn = vec![];
    let ml_tag = get_u8_tag(record, b"ML");
    let mut num_mods_seen = 0;
    if let Ok(Aux::String(mm_text)) = record.aux(b"MM") {
        // The read sequence is the same for every MM entry, so orient it once
        // instead of copying (and possibly reverse-complementing) it per entry.
        let forward_bases = if record.is_reverse() {
            revcomp(record.seq().as_bytes())
        } else {
            record.seq().as_bytes()
        };
        for cap in MM_RE.captures_iter(mm_text) {
            let mod_base = cap.get(3).map(|m| m.as_str().as_bytes()[0]).unwrap();
            let mod_strand = cap.get(4).map_or("", |m| m.as_str());
            let modification_type = cap.get(5).map_or("", |m| m.as_str());
            let mod_dists_str = cap.get(6).map_or("", |m| m.as_str());
            // Skip counts: number of `mod_base` occurrences to pass before each call.
            let mod_dists: Vec<i64> = mod_dists_str
                .trim_end_matches(';')
                .split(',')
                .map(|s| s.trim())
                .filter(|s| !s.is_empty())
                .map(|s| s.parse().unwrap())
                .collect();
            log::trace!(
                "mod_base: {}, mod_strand: {}, modification_type: {}, mod_dists: {:?}",
                mod_base as char,
                mod_strand,
                modification_type,
                mod_dists
            );
            // Walk the forward sequence and convert relative skip counts into
            // absolute 0-based positions.
            let mut cur_mod_idx = 0;
            let mut cur_seq_idx = 0;
            let mut dist_from_last_mod_base = 0;
            let mut unfiltered_modified_positions: Vec<i64> = vec![0; mod_dists.len()];
            while cur_seq_idx < forward_bases.len() && cur_mod_idx < mod_dists.len() {
                let cur_base = forward_bases[cur_seq_idx];
                if cur_base == mod_base && dist_from_last_mod_base == mod_dists[cur_mod_idx] {
                    unfiltered_modified_positions[cur_mod_idx] =
                        i64::try_from(cur_seq_idx).unwrap();
                    dist_from_last_mod_base = 0;
                    cur_mod_idx += 1;
                } else if cur_base == mod_base {
                    dist_from_last_mod_base += 1
                }
                cur_seq_idx += 1;
            }
            assert_eq!(
                cur_mod_idx,
                mod_dists.len(),
                "{:?} {}",
                String::from_utf8_lossy(record.qname()),
                record.is_reverse()
            );
            // Take this entry's slice of ML. Both bounds are clamped to the ML
            // length so the slice stays valid even when an earlier entry has
            // already used up all of ML. Any missing values are padded with 0.
            let num_positions = unfiltered_modified_positions.len();
            let num_mods_cur_end = num_mods_seen + num_positions;
            let ml_len = ml_tag.len();
            let ml_start = num_mods_seen.min(ml_len);
            let ml_end = num_mods_cur_end.min(ml_len);
            let mut unfiltered_modified_probabilities = ml_tag[ml_start..ml_end].to_vec();
            if unfiltered_modified_probabilities.len() < num_positions {
                log::warn!(
                    "ML tag is too short for the number of modifications found in the MM tag. Assuming an ML value of 0 after the first {ml_len} modifications."
                );
                unfiltered_modified_probabilities.resize(num_positions, 0);
            }
            num_mods_seen = num_mods_cur_end;
            let (modified_probabilities, modified_positions): (Vec<u8>, Vec<i64>) =
                unfiltered_modified_probabilities
                    .iter()
                    .zip(unfiltered_modified_positions.iter())
                    .filter(|(&ml, &_mm)| ml >= min_ml_score)
                    .unzip();
            if modified_positions.is_empty() {
                continue;
            }
            let mods = BaseMod::new(
                record,
                mod_base,
                mod_strand.chars().next().unwrap(),
                modification_type.chars().next().unwrap(),
                modified_positions,
                modified_probabilities,
            );
            rtn.push(mods);
        }
    } else {
        log::debug!("No MM tag found");
    }
    if ml_tag.len() != num_mods_seen {
        log::warn!(
            "ML tag ({}) different number than MM tag ({}).",
            ml_tag.len(),
            num_mods_seen
        );
    }
    rtn.sort();
    BaseMods { base_mods: rtn }
}
pub fn hashmap_to_basemods(
map: HashMap<(i32, i32, i32), Vec<(i64, u8)>>,
record: &bam::Record,
) -> BaseMods {
let mut rtn = vec![];
for (mod_info, mods) in map {
let mod_base = mod_info.0 as u8;
let mod_type = mod_info.1 as u8 as char;
let mod_strand = if mod_info.2 == 0 { '+' } else { '-' };
let (mut positions, mut qualities): (Vec<i64>, Vec<u8>) = mods.into_iter().unzip();
if record.is_reverse() {
let length = record.seq_len() as i64;
positions = positions
.into_iter()
.rev()
.map(|p| length - p - 1)
.collect();
qualities.reverse();
}
let mods = BaseMod::new(record, mod_base, mod_strand, mod_type, positions, qualities);
rtn.push(mods);
}
rtn.sort();
BaseMods { base_mods: rtn }
}
pub fn drop_m6a(&mut self) {
self.base_mods.retain(|bm| !bm.is_m6a());
}
pub fn drop_cpg(&mut self) {
self.base_mods.retain(|bm| !bm.is_cpg());
}
/// Merge the ranges of all m6A (`'a'`) modifications into a single `Ranges`.
pub fn m6a(&self) -> Ranges {
    Ranges::merge_ranges(
        self.base_mods
            .iter()
            .filter_map(|bm| if bm.is_m6a() { Some(&bm.ranges) } else { None })
            .collect(),
    )
}
/// Merge the ranges of all `'m'` modifications into a single `Ranges`.
pub fn cpg(&self) -> Ranges {
    Ranges::merge_ranges(
        self.base_mods
            .iter()
            .filter_map(|bm| if bm.is_cpg() { Some(&bm.ranges) } else { None })
            .collect(),
    )
}
/// Write these modifications onto `record` as fresh `MM` and `ML` tags.
///
/// Any existing `MM`/`ML` tags are removed first. Entries are written in
/// `self.base_mods` order. Each `MM` entry has the form
/// `<base><strand><type>,<skip>,...;`, where each `<skip>` counts the
/// occurrences of `modified_base` in the forward-oriented sequence between the
/// previous call and this one. `ML` is the concatenation of each entry's
/// forward-orientation qualities, in the same order.
///
/// # Panics
///
/// Panics if a forward start lies beyond the sequence, or if pushing either
/// tag onto the record fails.
pub fn add_mm_and_ml_tags(&self, record: &mut bam::Record) {
    // MM skip counts are defined on the read in its original orientation.
    let forward_seq = {
        let stored = record.seq().as_bytes();
        if record.is_reverse() {
            revcomp(stored)
        } else {
            stored
        }
    };
    let mut mm_text = String::new();
    let mut ml_values: Vec<u8> = Vec::new();
    for entry in &self.base_mods {
        ml_values.extend(entry.ranges.get_forward_quals());
        let starts = entry.ranges.get_forward_starts();
        mm_text.push(entry.modified_base as char);
        mm_text.push(entry.strand);
        mm_text.push(entry.modification_type);
        // Index just past the previous call; skips are counted from here.
        let mut window_begin = 0;
        for start in starts {
            let window_end = start as usize;
            let skipped = if window_begin < window_end {
                forward_seq[window_begin..window_end]
                    .iter()
                    .filter(|&&b| b == entry.modified_base)
                    .count()
            } else {
                0
            };
            window_begin = window_end + 1;
            mm_text.push_str(&format!(",{}", skipped));
        }
        mm_text.push(';');
    }
    log::trace!(
        "{}\n{}\n{}\n",
        record.is_reverse(),
        mm_text,
        String::from_utf8_lossy(&forward_seq)
    );
    // Missing tags are fine here; only a failed push is treated as fatal.
    record.remove_aux(b"MM").unwrap_or(());
    record.remove_aux(b"ML").unwrap_or(());
    record.push_aux(b"MM", Aux::String(&mm_text)).unwrap();
    let ml_array: AuxArray<u8> = (&ml_values).into();
    record.push_aux(b"ML", Aux::ArrayU8(ml_array)).unwrap();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use env_logger::{Builder, Target};
    use log;
    use rust_htslib::{bam, bam::Read};

    /// Round-trip check: for every record in the test BAM, parsing MM/ML,
    /// writing the tags back with `add_mm_and_ml_tags`, and parsing again must
    /// yield identical `BaseMods`.
    #[test]
    fn test_mods_do_not_change() {
        // `try_init` rather than `init`: another test in the same binary may
        // already have installed the global logger, and `init` panics then.
        let _ = Builder::new()
            .target(Target::Stderr)
            .filter(None, log::LevelFilter::Debug)
            .try_init();
        let mut reader = bam::Reader::from_path("tests/data/all.bam")
            .expect("failed to open tests/data/all.bam");
        for rec in reader.records() {
            let mut rec = rec.expect("failed to read a BAM record");
            let mods = BaseMods::new(&rec, 0);
            mods.add_mm_and_ml_tags(&mut rec);
            let mods_2 = BaseMods::new(&rec, 0);
            assert_eq!(mods, mods_2);
        }
    }
}