use super::subcommands::center::CenterPosition;
use super::subcommands::center::CenteredFiberData;
use super::utils::input_bam::FiberFilters;
use super::*;
use crate::utils::bamranges::*;
use crate::utils::basemods::BaseMods;
use crate::utils::bio_io::*;
use crate::utils::ftexpression::apply_filter_fsd;
use rayon::prelude::*;
use rust_htslib::bam::Read;
use rust_htslib::{bam, bam::ext::BamRecordExtensions, bam::record::Aux, bam::HeaderView};
use std::collections::HashMap;
use std::fmt::Write;
/// One fiber-seq BAM record together with the features decoded from its tags.
#[derive(Debug, Clone, PartialEq)]
pub struct FiberseqData {
/// The underlying BAM record.
pub record: bam::Record,
/// MSP ranges built from the `as` (starts) / `al` (lengths) tags; qualities come from `aq` when present.
pub msp: Ranges,
/// Nucleosome ranges built from the `ns` (starts) / `nl` (lengths) tags.
pub nuc: Ranges,
/// m6A calls, taken from `base_mods`.
pub m6a: Ranges,
/// CpG (5mC) calls, taken from `base_mods`.
pub cpg: Ranges,
/// All base modifications parsed from the record (filtered by `min_ml_score` at construction).
pub base_mods: BaseMods,
/// Value of the float `ec` tag, or 0.0 when the tag is absent.
pub ec: f32,
/// Reference (contig) name, or "." when none was supplied.
pub target_name: String,
/// Read group from the `RG` tag, or "." when absent.
pub rg: String,
/// Position this fiber was centered on, if any; `new` initializes it to `None`.
pub center_position: Option<CenterPosition>,
}
impl FiberseqData {
pub fn new(record: bam::Record, target_name: Option<&String>, filters: &FiberFilters) -> Self {
let rg = if let Ok(Aux::String(f)) = record.aux(b"RG") {
log::trace!("{f}");
f
} else {
"."
}
.to_string();
let nuc_starts = get_u32_tag(&record, b"ns");
let msp_starts = get_u32_tag(&record, b"as");
let nuc_length = get_u32_tag(&record, b"nl");
let msp_length = get_u32_tag(&record, b"al");
let nuc = Ranges::new(&record, nuc_starts, None, Some(nuc_length));
let mut msp = Ranges::new(&record, msp_starts, None, Some(msp_length));
let msp_qual = get_u8_tag(&record, b"aq");
if !msp_qual.is_empty() {
msp.set_qual(msp_qual);
}
let ec = if let Ok(Aux::Float(f)) = record.aux(b"ec") {
log::trace!("{f}");
f
} else {
0.0
};
let target_name = match target_name {
Some(t) => t.clone(),
None => ".".to_string(),
};
let base_mods = BaseMods::new(&record, filters.min_ml_score);
let m6a = base_mods.m6a();
let cpg = base_mods.cpg();
let mut fsd = FiberseqData {
record,
msp,
nuc,
m6a,
base_mods,
cpg,
ec,
target_name,
rg,
center_position: None,
};
apply_filter_fsd(&mut fsd, filters).expect("Failed to apply filter to FiberseqData");
fsd
}
/// Build a `tid -> contig name` map from a BAM header.
///
/// Names are decoded lossily from UTF-8.
///
/// # Panics
/// Panics if a target name listed in the header cannot be resolved back to a tid.
pub fn dict_from_head_view(head_view: &HeaderView) -> HashMap<i32, String> {
    // Guard kept: with zero targets there is nothing to list, so return early.
    if head_view.target_count() == 0 {
        return HashMap::new();
    }
    let mut dict = HashMap::new();
    for name in head_view.target_names() {
        let tid = head_view.tid(name).expect("Unable to get tid");
        dict.insert(tid as i32, String::from_utf8_lossy(name).to_string());
    }
    dict
}
/// Look up the contig name for `tid`; tids absent from `target_dict` yield `None`.
pub fn target_name_from_tid(tid: i32, target_dict: &HashMap<i32, String>) -> Option<&String> {
    HashMap::get(target_dict, &tid)
}
/// Convert a batch of BAM records into `FiberseqData` in parallel (rayon).
///
/// Contig names are resolved once from `head_view`, and each record is built with
/// `FiberseqData::new` using `filters`. Output order matches input order.
pub fn from_records(
    records: Vec<bam::Record>,
    head_view: &HeaderView,
    filters: &FiberFilters,
) -> Vec<Self> {
    let target_dict = Self::dict_from_head_view(head_view);
    let build = |record: bam::Record| {
        let target_name = Self::target_name_from_tid(record.tid(), &target_dict);
        Self::new(record, target_name, filters)
    };
    records.into_par_iter().map(build).collect()
}
/// Value of the float `rq` tag, or `None` when it is absent or not a float.
pub fn get_rq(&self) -> Option<f32> {
    match self.record.aux(b"rq") {
        Ok(Aux::Float(rq)) => Some(rq),
        _ => None,
    }
}
/// Haplotype label from the `HP` tag, formatted as `H<value>`, or `"UNK"` when absent.
///
/// Integer tags may be stored with any width (htslib picks the smallest, other writers may
/// use `i32`), so every integer `Aux` variant is accepted, not only `U8`.
pub fn get_hp(&self) -> String {
    let hp: Option<i64> = match self.record.aux(b"HP") {
        Ok(Aux::U8(v)) => Some(i64::from(v)),
        Ok(Aux::I8(v)) => Some(i64::from(v)),
        Ok(Aux::U16(v)) => Some(i64::from(v)),
        Ok(Aux::I16(v)) => Some(i64::from(v)),
        Ok(Aux::U32(v)) => Some(i64::from(v)),
        Ok(Aux::I32(v)) => Some(i64::from(v)),
        _ => None,
    };
    match hp {
        Some(h) => format!("H{h}"),
        None => "UNK".to_string(),
    }
}
/// Shift every present position by `-offset`; on the `'-'` strand also negate it and reverse
/// the whole slice, so the positions read in the mirrored orientation.
///
/// A stored value of exactly `-1` is not shifted; it is replaced by `i64::MIN`
/// (presumably a legacy missing-value marker — verify). `None` entries are left untouched
/// (but still move when the slice is reversed).
fn apply_offset(positions: &mut [Option<i64>], offset: i64, strand: char) {
    let mirror = strand == '-';
    positions
        .iter_mut()
        .filter_map(Option::as_mut)
        .for_each(|pos| {
            *pos = match *pos {
                -1 => i64::MIN,
                value if mirror => -(value - offset),
                value => value - offset,
            };
        });
    if mirror {
        positions.reverse();
    }
}
/// Apply `apply_offset` to paired `starts`/`ends` and keep each pair ordered start <= end.
///
/// When both ends of a pair are known, they are swapped if start > end (mirroring on `'-'`
/// inverts their order). When only one end is known, `Option` ordering (`None < Some`) says
/// nothing about the coordinates. The decision is therefore made by strand: mirroring turns
/// the old start into the new end, so the pair is swapped only on `'-'`.
fn offset_range(
    starts: &mut [Option<i64>],
    ends: &mut [Option<i64>],
    offset: i64,
    strand: char,
) {
    FiberseqData::apply_offset(starts, offset, strand);
    FiberseqData::apply_offset(ends, offset, strand);
    for (start, end) in starts.iter_mut().zip(ends.iter_mut()) {
        let swap = match (*start, *end) {
            (Some(s), Some(e)) => s > e,
            (Some(_), None) | (None, Some(_)) => strand == '-',
            (None, None) => false,
        };
        if swap {
            std::mem::swap(start, end);
        }
    }
}
/// Return a copy of this fiber with its coordinates re-expressed relative to `center_position`.
///
/// `CenteredFiberData::find_offsets` supplies a reference offset and a molecular offset.
/// Molecular coordinates are shifted by the molecular offset, and reference coordinates by the
/// reference offset. When `center_position.strand` is `'-'`, shifted coordinates are negated and
/// the affected vectors are reversed so entries stay aligned:
/// * m6A / CpG: only `starts` and `reference_starts` are shifted; `qual` is reversed. Their
///   `ends`/`lengths` are left unchanged.
/// * MSP / nucleosome: starts and ends (molecular and reference) are shifted; lengths are
///   reversed (MSP `qual` too).
///
/// The returned copy stores `center_position`, so callers can tell which position it was
/// centered on. Currently always returns `Some`.
pub fn center(&self, center_position: &CenterPosition) -> Option<Self> {
    let mut new = self.clone();
    let strand = center_position.strand;
    let (ref_offset, mol_offset) =
        CenteredFiberData::find_offsets(&self.record, center_position);
    // `new()` initializes this to `None`; record the centering on the returned copy.
    new.center_position = Some(center_position.clone());

    // Single-base features: shift start coordinates only.
    FiberseqData::apply_offset(&mut new.m6a.starts, mol_offset, strand);
    FiberseqData::apply_offset(&mut new.m6a.reference_starts, ref_offset, strand);
    FiberseqData::apply_offset(&mut new.cpg.starts, mol_offset, strand);
    FiberseqData::apply_offset(&mut new.cpg.reference_starts, ref_offset, strand);

    // Interval features: shift both ends and keep each pair ordered.
    FiberseqData::offset_range(&mut new.msp.starts, &mut new.msp.ends, mol_offset, strand);
    FiberseqData::offset_range(
        &mut new.msp.reference_starts,
        &mut new.msp.reference_ends,
        ref_offset,
        strand,
    );
    FiberseqData::offset_range(&mut new.nuc.starts, &mut new.nuc.ends, mol_offset, strand);
    FiberseqData::offset_range(
        &mut new.nuc.reference_starts,
        &mut new.nuc.reference_ends,
        ref_offset,
        strand,
    );

    // Positions were reversed inside apply_offset on '-'; reverse the parallel vectors too.
    if strand == '-' {
        new.m6a.qual.reverse();
        new.cpg.qual.reverse();
        new.msp.lengths.reverse();
        new.msp.reference_lengths.reverse();
        new.msp.qual.reverse();
        new.nuc.lengths.reverse();
        new.nuc.reference_lengths.reverse();
    }
    Some(new)
}
/// BED12 line for the MSPs, in reference (`reference == true`) or molecular coordinates.
pub fn write_msp(&self, reference: bool) -> String {
    let msp = &self.msp;
    let (starts, lengths) = match reference {
        true => (&msp.reference_starts, &msp.reference_lengths),
        false => (&msp.starts, &msp.lengths),
    };
    self.to_bed12(reference, starts, lengths, LINKER_COLOR)
}
/// BED12 line for the nucleosomes, in reference (`reference == true`) or molecular coordinates.
pub fn write_nuc(&self, reference: bool) -> String {
    let nuc = &self.nuc;
    let (starts, lengths) = match reference {
        true => (&nuc.reference_starts, &nuc.reference_lengths),
        false => (&nuc.starts, &nuc.lengths),
    };
    self.to_bed12(reference, starts, lengths, NUC_COLOR)
}
/// BED12 line with one 1 bp block per m6A call, in reference or molecular coordinates.
pub fn write_m6a(&self, reference: bool) -> String {
    let positions = match reference {
        true => &self.m6a.reference_starts,
        false => &self.m6a.starts,
    };
    let unit_lengths: Vec<Option<i64>> =
        std::iter::repeat(Some(1)).take(positions.len()).collect();
    self.to_bed12(reference, positions, &unit_lengths, M6A_COLOR)
}
/// BED12 line with one 1 bp block per CpG (5mC) call, in reference or molecular coordinates.
pub fn write_cpg(&self, reference: bool) -> String {
    let positions = match reference {
        true => &self.cpg.reference_starts,
        false => &self.cpg.starts,
    };
    let unit_lengths: Vec<Option<i64>> =
        std::iter::repeat(Some(1)).take(positions.len()).collect();
    self.to_bed12(reference, positions, &unit_lengths, CPG_COLOR)
}
/// Format one BED12 line for this fiber with one block per feature.
///
/// * `reference == true`: chrom is `target_name` and the span is the alignment
///   (`reference_start()..reference_end()`).
/// * `reference == false`: chrom is the read name and the span is `0..seq_len`.
///
/// `starts[i]` / `lengths[i]` describe feature `i`; features where either value is `None` are
/// skipped. Two 1 bp sentinel blocks are added, at offset 0 and at the last base, so the
/// blocks cover the whole span as BED12 requires. The score is the rounded `ec` value, and the
/// strand comes from the record's reverse flag.
///
/// Returns an empty string when `starts` is empty, when reference coordinates are requested
/// for an unmapped record, or when no feature has both a start and a length.
pub fn to_bed12(
    &self,
    reference: bool,
    starts: &[Option<i64>],
    lengths: &[Option<i64>],
    color: &str,
) -> String {
    if starts.is_empty() || (reference && self.record.is_unmapped()) {
        return String::new();
    }
    let name = String::from_utf8_lossy(self.record.qname()).to_string();
    let (ct, start, end) = if reference {
        (
            self.target_name.as_str(),
            self.record.reference_start(),
            self.record.reference_end(),
        )
    } else {
        (name.as_str(), 0, self.record.seq_len() as i64)
    };

    // Pair first, then drop incomplete pairs: flattening each side separately would
    // misalign every later start/length once only one side of a pair is missing.
    let (filtered_starts, filtered_lengths): (Vec<i64>, Vec<i64>) = starts
        .iter()
        .zip(lengths)
        .filter_map(|(st, ln)| Some(((*st)?, (*ln)?)))
        .unzip();
    if filtered_starts.is_empty() {
        return String::new();
    }

    let score = self.ec.round() as i64;
    let strand = if self.record.is_reverse() { '-' } else { '+' };

    // Leading and trailing 1 bp sentinel blocks anchor the first/last block at the span edges.
    let block_count = filtered_starts.len() + 2;
    let mut block_sizes = String::from("1,");
    let mut block_starts = String::from("0,");
    for (st, ln) in filtered_starts.iter().zip(&filtered_lengths) {
        write!(block_sizes, "{ln},").unwrap();
        write!(block_starts, "{},", st - start).unwrap();
    }
    block_sizes.push('1');
    write!(block_starts, "{}", end - start - 1).unwrap();

    format!(
        "{ct}\t{start}\t{end}\t{name}\t{score}\t{strand}\t{start}\t{end}\t{color}\t{block_count}\t{block_sizes}\t{block_starts}\n"
    )
}
/// Tab-separated header line (starting with `#`, ending with `\n`) for `write_all` rows.
///
/// `fiber_sequence` is included unless `simplify` is set; `fiber_qual` is included when
/// `quality` is set. Both sit between `fiber_length` and `ec`.
pub fn all_header(simplify: bool, quality: bool) -> String {
    // Per-read descriptive columns, always first.
    const LEADING: [&str; 10] = [
        "ct",
        "st",
        "en",
        "fiber",
        "score",
        "strand",
        "sam_flag",
        "HP",
        "RG",
        "fiber_length",
    ];
    // Summary counts followed by the comma-separated per-feature columns.
    const TRAILING: [&str; 22] = [
        "ec",
        "rq",
        "total_AT_bp",
        "total_m6a_bp",
        "total_nuc_bp",
        "total_msp_bp",
        "total_5mC_bp",
        "nuc_starts",
        "nuc_lengths",
        "ref_nuc_starts",
        "ref_nuc_lengths",
        "msp_starts",
        "msp_lengths",
        "fire",
        "ref_msp_starts",
        "ref_msp_lengths",
        "m6a",
        "ref_m6a",
        "m6a_qual",
        "5mC",
        "ref_5mC",
        "5mC_qual",
    ];
    let mut header = String::from("#");
    for column in LEADING.iter() {
        header.push_str(column);
        header.push('\t');
    }
    if !simplify {
        header.push_str("fiber_sequence\t");
    }
    if quality {
        header.push_str("fiber_qual\t");
    }
    header.push_str(&TRAILING.join("\t"));
    header.push('\n');
    header
}
/// Format this fiber as one tab-separated row whose columns match `FiberseqData::all_header`.
///
/// * `simplify`: when `false`, the read sequence is written as an extra column.
/// * `quality`: when `true`, base qualities are written as a Phred+33 string. A record whose
///   qualities are absent (BAM stores them as all `0xFF`) gets `*`, as in SAM, instead of
///   overflowing the `+33` conversion.
///
/// Unmapped records report `.` for contig and strand and `0` for start/end. List columns are
/// comma-separated with a trailing comma, `-1` marks a `None` entry, and an empty list is
/// written as `.`. The row ends with `\n`.
pub fn write_all(&self, simplify: bool, quality: bool) -> String {
    let record = &self.record;
    // Lossy decoding: a non-UTF-8 read name must not panic (consistent with `to_bed12`).
    let name = String::from_utf8_lossy(record.qname());
    let score = self.ec.round() as i64;
    let q_len = record.seq_len() as i64;
    let rq = match self.get_rq() {
        Some(x) => x.to_string(),
        None => ".".to_string(),
    };
    let (ct, start, end, strand) = if record.is_unmapped() {
        (".", 0, 0, '.')
    } else {
        let strand = if record.is_reverse() { '-' } else { '+' };
        (
            self.target_name.as_str(),
            record.reference_start(),
            record.reference_end(),
            strand,
        )
    };
    let seq = record.seq().as_bytes();
    let at_count = seq.iter().filter(|&&b| b == b'A' || b == b'T').count() as i64;
    let m6a_count = self.m6a.starts.len();
    let cpg_count = self.cpg.starts.len();
    let m6a_qual: Vec<Option<i64>> = self.m6a.qual.iter().map(|a| Some(*a as i64)).collect();
    let cpg_qual: Vec<Option<i64>> = self.cpg.qual.iter().map(|a| Some(*a as i64)).collect();
    let fire: Vec<Option<i64>> = self.msp.qual.iter().map(|a| Some(*a as i64)).collect();
    let total_nuc_bp: i64 = self.nuc.lengths.iter().flatten().sum();
    let total_msp_bp: i64 = self.msp.lengths.iter().flatten().sum();

    let mut rtn = String::new();
    write!(
        rtn,
        "{ct}\t{start}\t{end}\t{name}\t{score}\t{strand}\t{}\t{}\t{}\t{q_len}\t",
        record.flags(),
        self.get_hp(),
        self.rg
    )
    .unwrap();
    if !simplify {
        write!(rtn, "{}\t", String::from_utf8_lossy(&seq)).unwrap();
    }
    if quality {
        let qual = record.qual();
        if !qual.is_empty() && qual.iter().all(|&q| q == 0xff) {
            rtn.push_str("*\t");
        } else {
            // saturating_add: never panic on out-of-range quality bytes.
            let phred33: Vec<u8> = qual.iter().map(|q| q.saturating_add(33)).collect();
            write!(rtn, "{}\t", String::from_utf8_lossy(&phred33)).unwrap();
        }
    }
    write!(
        rtn,
        "{}\t{rq}\t{at_count}\t{m6a_count}\t{total_nuc_bp}\t{total_msp_bp}\t{cpg_count}\t",
        self.ec
    )
    .unwrap();

    // Order must match the trailing columns of `all_header`.
    let columns: [&[Option<i64>]; 15] = [
        &self.nuc.starts,
        &self.nuc.lengths,
        &self.nuc.reference_starts,
        &self.nuc.reference_lengths,
        &self.msp.starts,
        &self.msp.lengths,
        &fire,
        &self.msp.reference_starts,
        &self.msp.reference_lengths,
        &self.m6a.starts,
        &self.m6a.reference_starts,
        &m6a_qual,
        &self.cpg.starts,
        &self.cpg.reference_starts,
        &cpg_qual,
    ];
    for (idx, column) in columns.iter().enumerate() {
        if idx > 0 {
            rtn.push('\t');
        }
        if column.is_empty() {
            rtn.push('.');
        } else {
            for value in column.iter() {
                write!(rtn, "{},", value.unwrap_or(-1)).unwrap();
            }
        }
    }
    rtn.push('\n');
    rtn
}
}
/// Iterator over a BAM reader's records, yielding each one as a [`FiberseqData`].
///
/// Records are pulled from the reader in chunks (`BamChunk`), and each chunk is converted with
/// `FiberseqData::from_records`, which processes the chunk in parallel.
pub struct FiberseqRecords<'a> {
// Source of record chunks; borrows the reader for `'a`.
bam_chunk: BamChunk<'a>,
// Copy of the reader's header, used to resolve contig names.
header: HeaderView,
// Filters applied to every record as it is converted.
filters: FiberFilters,
// Converted records of the current chunk, stored reversed so `pop` yields BAM order.
cur_chunk: Vec<FiberseqData>,
}
impl<'a> FiberseqRecords<'a> {
    /// Wrap `bam` so it can be iterated as `FiberseqData`, applying `filters` to every record.
    pub fn new(bam: &'a mut bam::Reader, filters: FiberFilters) -> Self {
        // Clone the header before `records()` takes the mutable borrow of the reader.
        let header = bam.header().clone();
        Self {
            bam_chunk: BamChunk::new(bam.records(), None),
            header,
            filters,
            cur_chunk: Vec::new(),
        }
    }
}
impl<'a> Iterator for FiberseqRecords<'a> {
    type Item = FiberseqData;

    /// Yield the next record in BAM order, converting a new chunk when the buffer is empty.
    ///
    /// Keeps pulling chunks until one yields records, so an empty chunk does not end the
    /// iteration early. Returns `None` only once the chunk source is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        while self.cur_chunk.is_empty() {
            let recs = self.bam_chunk.next()?;
            self.cur_chunk = FiberseqData::from_records(recs, &self.header, &self.filters);
            // Reverse so that `pop` hands records out in their original order.
            self.cur_chunk.reverse();
        }
        self.cur_chunk.pop()
    }
}