use super::cli::FireOptions;
use super::fiber::FiberseqData;
use super::*;
use crate::fiber::FiberseqRecords;
use anyhow;
use bam::record::{Aux, AuxArray};
use decorator::get_fire_color;
use derive_builder::Builder;
use gbdt::decision_tree::{Data, DataVec};
use gbdt::gradient_boost::GBDT;
use itertools::Itertools;
use ordered_float::OrderedFloat;
use rayon::prelude::*;
use serde::Deserialize;
use std::collections::BTreeMap;
use std::fs;
use tempfile::NamedTempFile;
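// FIRE model and FDR-calibration table embedded in the binary at compile time.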
pub static FIRE_MODEL: &str = include_str!("../models/FIRE.gbdt.json");
pub static FIRE_CONF_JSON: &str = include_str!("../models/FIRE.conf.json");
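/// Load the GBDT model and the FDR-to-precision table, using the embedded
/// defaults unless both `fire_opts.model` and `fire_opts.fdr_table` are provided.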
fn get_model(fire_opts: &FireOptions) -> (GBDT, MapPrecisionValues) {
let mut remove_temp_file = false;
    let (model_file, fdr_table_file) = match (&fire_opts.model, &fire_opts.fdr_table) {
        (Some(model_file), Some(fdr_table_path)) => {
            let fdr_table =
                fs::read_to_string(fdr_table_path).expect("Unable to read the FDR table file");
            (model_file.clone(), fdr_table)
        }
        // if either option is missing, fall back to the model and FDR table compiled
        // into the binary; GBDT::from_xgboost_dump needs a file on disk, so the
        // embedded model is written to a temporary file first
        _ => {
            let temp_file = NamedTempFile::new().expect("Unable to make a temp file");
            let (mut temp_file, path) = temp_file.keep().expect("Unable to keep temp file");
            let temp_file_name = path
                .as_os_str()
                .to_str()
                .expect("Unable to convert the path of the named temp file to an &str.");
            temp_file
                .write_all(FIRE_MODEL.as_bytes())
                .expect("Unable to write the FIRE model to the temp file");
            remove_temp_file = true;
            (temp_file_name.to_string(), FIRE_CONF_JSON.to_string())
        }
};
log::info!("Using model: {}", model_file);
let model =
GBDT::from_xgboost_dump(&model_file, "binary:logistic").expect("failed to load FIRE model");
if remove_temp_file {
fs::remove_file(model_file).expect("Unable to remove temp file");
}
let precision_table: PrecisionTable =
serde_json::from_str(&fdr_table_file).expect("Precision table JSON was not well-formatted");
let precision_converter = MapPrecisionValues::new(&precision_table);
(model, precision_converter)
}
fn get_mid_point(start: i64, end: i64) -> i64 {
(start + end) / 2
}
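/// Count A/T bases in `seq[start..end)`.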
fn get_at_count(seq: &[u8], start: i64, end: i64) -> usize {
let subseq = &seq[start as usize..end as usize];
subseq
.iter()
.filter(|&&bp| bp == b'T' || bp == b'A')
.count()
}
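/// Tile `bin_num` windows of `bin_width` bp around `mid_point`, clamping each
/// window to `[0, max_end]`. For example, `get_bins(100, 3, 10, 1_000)` yields
/// `[(85, 95), (95, 105), (105, 115)]`.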
pub fn get_bins(mid_point: i64, bin_num: i64, bin_width: i64, max_end: i64) -> Vec<(i64, i64)> {
let mut bins = Vec::new();
for i in 0..bin_num {
let mut bin_start = mid_point - (bin_num / 2 - i) * bin_width - bin_width / 2;
let mut bin_end = bin_start + bin_width;
if bin_start < 0 {
bin_start = 0;
}
if bin_end < 0 {
bin_end = 0;
}
if bin_start > max_end {
bin_start = max_end - 1;
}
if bin_end > max_end {
bin_end = max_end;
}
bins.push((bin_start, bin_end));
}
bins
}
fn get_5mc_count(rec: &FiberseqData, start: i64, end: i64) -> usize {
rec.cpg
.starts
.iter()
.flatten()
.filter(|&&pos| pos >= start && pos < end)
.count()
}
fn get_m6a_count(rec: &FiberseqData, start: i64, end: i64) -> usize {
rec.m6a
.starts
.iter()
.flatten()
.filter(|&&pos| pos >= start && pos < end)
.count()
}
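/// Summarize the gaps (run lengths) between consecutive m6A calls inside
/// `[start, end]`: returns the length-normalized sum of squared gaps and the
/// median gap, or `(-1.0, -1.0)` when the window holds no m6A pair.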
fn get_m6a_rle_data(rec: &FiberseqData, start: i64, end: i64) -> (f32, f32) {
let mut m6a_rles = vec![];
let mut max = 0;
let mut _max_pos = 0;
let mut weighted_rle = 0.0;
for (m6a_1, m6a_2) in rec.m6a.starts.iter().flatten().tuple_windows() {
if *m6a_1 < start || *m6a_1 > end || *m6a_2 < start || *m6a_2 > end {
continue;
}
let rle = (m6a_2 - m6a_1).abs();
m6a_rles.push(rle);
if rle > max {
max = rle;
_max_pos = ((*m6a_1 + *m6a_2) / 2 - (end + start) / 2).abs();
}
weighted_rle += (rle * rle) as f32;
}
weighted_rle /= (end - start) as f32;
if m6a_rles.is_empty() {
return (-1.0, -1.0);
}
let mid_length = m6a_rles.len() / 2;
let (_, median, _) = m6a_rles.select_nth_unstable(mid_length);
(weighted_rle, *median as f32)
}
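// Per-window feature names used for the text header; keep in sync with the
// per-window values pushed at the end of `msp_get_fire_features`.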
const FEATS_IN_USE: [&str; 3] = [
"m6a_count",
"frac_m6a",
"m6a_fc",
];
#[derive(Debug, Clone, Builder)]
pub struct FireFeatsInRange {
pub m6a_count: f32,
pub at_count: f32,
#[allow(unused)]
pub count_5mc: f32,
pub frac_m6a: f32,
pub m6a_fc: f32,
pub weighted_m6a_rle: f32,
pub median_m6a_rle: f32,
}
impl FireFeatsInRange {
pub fn header(tag: &str) -> String {
let mut out = "".to_string();
for col in FEATS_IN_USE.iter() {
out += &format!("\t{}_{}", tag, col);
}
out
}
}
#[derive(Debug)]
pub struct FireFeats<'a> {
rec: &'a FiberseqData,
#[allow(unused)]
at_count: usize,
m6a_count: usize,
frac_m6a: f32,
fire_opts: &'a FireOptions,
seq: Vec<u8>,
fire_feats: Vec<(i64, i64, Vec<f32>)>,
}
impl<'a> FireFeats<'a> {
pub fn new(rec: &'a FiberseqData, fire_opts: &'a FireOptions) -> Self {
let seq_len = rec.record.seq_len();
let seq = rec.record.seq().as_bytes();
log::trace!("new FireFeats {}", seq_len);
let at_count = get_at_count(&seq, 0, seq_len as i64);
let m6a_count = get_m6a_count(rec, 0, seq_len as i64);
log::trace!("new FireFeats");
let frac_m6a = if at_count > 0 {
m6a_count as f32 / at_count as f32
} else {
0.0
};
let mut rtn = Self {
rec,
at_count,
m6a_count,
frac_m6a,
fire_opts,
seq,
fire_feats: vec![],
};
rtn.get_fire_features();
rtn
}
fn m6a_fc_over_expected(&self, m6a_count: usize, at_count: usize) -> f32 {
let expected = self.frac_m6a * at_count as f32;
let observed = m6a_count as f32;
if expected == 0.0 || observed == 0.0 {
return 0.0;
}
let fc = observed / expected;
fc.log2()
}
fn feats_in_range(&self, start: i64, end: i64) -> FireFeatsInRange {
let m6a_count = get_m6a_count(self.rec, start, end);
let at_count = get_at_count(&self.seq, start, end);
let count_5mc = get_5mc_count(self.rec, start, end);
let frac_m6a = if at_count > 0 {
m6a_count as f32 / at_count as f32
} else {
0.0
};
let m6a_fc = self.m6a_fc_over_expected(m6a_count, at_count);
let (weighted_m6a_rle, median_m6a_rle) = get_m6a_rle_data(self.rec, start, end);
FireFeatsInRange {
m6a_count: m6a_count as f32,
at_count: at_count as f32,
count_5mc: count_5mc as f32,
frac_m6a,
m6a_fc,
weighted_m6a_rle,
median_m6a_rle,
}
}
pub fn fire_feats_header(fire_opts: &FireOptions) -> String {
let mut out = "#chrom\tstart\tend\tfiber".to_string();
out += "\tmsp_len\tmsp_len_times_m6a_fc\tccs_passes";
out += "\tfiber_m6a_count\tfiber_m6a_frac";
out += &FireFeatsInRange::header("msp");
out += &FireFeatsInRange::header("best");
out += &FireFeatsInRange::header("worst");
for bin_num in 0..fire_opts.bin_num {
out += &FireFeatsInRange::header(&format!("bin_{}", bin_num));
}
out += "\n";
out
}
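    /// Build the feature vector for a single MSP: fiber-level m6A statistics,
    /// whole-MSP features, the most and least m6A-dense `best_window_size`
    /// windows, and `bin_num` bins of `width_bin` bp centered on the densest
    /// window. Returns an empty vector for MSPs shorter than
    /// `min_msp_length_for_positive_fire_call`.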
fn msp_get_fire_features(&self, start: i64, end: i64) -> Vec<f32> {
let msp_len = end - start;
if msp_len < self.fire_opts.min_msp_length_for_positive_fire_call {
return vec![];
}
let ccs_passes = self.rec.ec;
let mut max_m6a_count = 0;
let mut max_m6a_start = 0;
let mut max_m6a_end = 0;
let mut min_m6a_count = usize::MAX;
let mut min_m6a_start = 0;
let mut min_m6a_end = 0;
let mut centering_pos = get_mid_point(start, end);
        // the best/worst window scan only makes sense when the MSP can hold at
        // least two non-overlapping windows of `best_window_size`
        if (end - start) < (2 * self.fire_opts.best_window_size) {
            log::trace!("MSP window is not large enough for best and worst window analysis");
        } else {
            for st_idx in start..end {
                let en_idx = (st_idx + self.fire_opts.best_window_size).min(end);
                let m6a_count = get_m6a_count(self.rec, st_idx, en_idx);
                if m6a_count > max_m6a_count {
                    max_m6a_count = m6a_count;
                    max_m6a_start = st_idx;
                    max_m6a_end = en_idx;
                    centering_pos = get_mid_point(st_idx, en_idx);
                }
                if m6a_count < min_m6a_count {
                    min_m6a_count = m6a_count;
                    min_m6a_start = st_idx;
                    min_m6a_end = en_idx;
                }
                if en_idx == end {
                    break;
                }
            }
        }
let best_fire_feats = self.feats_in_range(max_m6a_start, max_m6a_end);
let worst_fire_feats = self.feats_in_range(min_m6a_start, min_m6a_end);
let msp_feats = self.feats_in_range(start, end);
let bins = get_bins(
centering_pos,
self.fire_opts.bin_num,
self.fire_opts.width_bin,
self.rec.record.seq_len() as i64,
);
let bin_feats = bins
.into_iter()
.map(|(start, end)| self.feats_in_range(start, end))
.collect::<Vec<FireFeatsInRange>>();
let msp_len_times_m6a_fc = msp_feats.m6a_fc * (msp_len as f32);
let mut rtn = vec![
msp_len as f32,
msp_len_times_m6a_fc,
ccs_passes,
self.m6a_count as f32,
self.frac_m6a,
];
let feat_sets = vec![&msp_feats, &best_fire_feats, &worst_fire_feats]
.into_iter()
.chain(bin_feats.iter());
for feat_set in feat_sets {
rtn.push(feat_set.m6a_count);
rtn.push(feat_set.frac_m6a);
rtn.push(feat_set.m6a_fc);
}
rtn
}
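    /// Compute the feature vector for every MSP on this fiber, stored with the
    /// MSP's reference start/end (0, 0 when reference coordinates are missing).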
pub fn get_fire_features(&mut self) {
let msp_data = self.rec.msp.into_iter().collect_vec();
self.fire_feats = msp_data
.into_iter()
.map(|(s, e, _l, refs)| {
let (rs, re, _rl) = refs.unwrap_or((0, 0, 0));
(rs, re, self.msp_get_fire_features(s, e))
})
.collect();
}
pub fn dump_fire_feats(&self, out_buffer: &mut Box<dyn Write>) -> Result<(), anyhow::Error> {
for (s, e, row) in self.fire_feats.iter() {
if row.is_empty() {
continue;
}
let lead_feats = format!(
"{}\t{}\t{}\t{}\t",
self.rec.target_name,
s,
e,
String::from_utf8_lossy(self.rec.record.qname())
);
out_buffer.write_all(lead_feats.as_bytes())?;
out_buffer.write_all(row.iter().join("\t").as_bytes())?;
out_buffer.write_all(b"\n")?;
}
Ok(())
}
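    /// Score each MSP feature vector with the GBDT model and convert the raw
    /// scores into 0-255 precision values; MSPs that were too short to get
    /// features receive a precision of 0.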
pub fn predict_with_xgb(
&self,
gbdt_model: &GBDT,
precision_converter: &MapPrecisionValues,
) -> Vec<u8> {
let count = self.fire_feats.len();
if count == 0 {
return vec![];
}
let mut gbdt_data: DataVec = Vec::new();
for (_st, _en, window) in self.fire_feats.iter() {
if window.is_empty() {
continue;
}
let d = Data::new_test_data(window.to_vec(), None);
gbdt_data.push(d);
}
let predictions_without_short_ones = gbdt_model.predict(&gbdt_data);
let mut precisions = Vec::with_capacity(count);
let mut cur_pos = 0;
for (_st, _en, window) in self.fire_feats.iter() {
if window.is_empty() {
precisions.push(0);
} else {
let precision = precision_converter
.precision_from_float(predictions_without_short_ones[cur_pos]);
precisions.push(precision);
cur_pos += 1;
}
}
assert_eq!(cur_pos, predictions_without_short_ones.len());
assert_eq!(precisions.len(), count);
precisions
}
}
#[derive(Debug, Deserialize)]
pub struct PrecisionTable {
pub columns: Vec<String>,
pub data: Vec<(f32, f32)>,
}
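/// Nearest-neighbor lookup from a raw model score to an 8-bit precision value,
/// where precision = round(255 * (1 - q-value)) from the FDR table.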
pub struct MapPrecisionValues {
pub map: BTreeMap<OrderedFloat<f32>, u8>,
}
impl MapPrecisionValues {
pub fn new(pt: &PrecisionTable) -> Self {
let mut map = BTreeMap::new();
for (mokapot_score, mokapot_q_value) in pt.data.iter() {
let precision = ((1.0 - mokapot_q_value) * 255.0).round() as u8;
map.insert(OrderedFloat(*mokapot_score), precision);
}
        // guarantee an anchor at score 0.0 (precision 0 unless the table already has one)
        map.insert(
            OrderedFloat(0.0),
            *map.get(&OrderedFloat(0.0)).unwrap_or(&0),
        );
Self { map }
}
pub fn precision_from_float(&self, value: f32) -> u8 {
        let key = OrderedFloat(value);
        // nearest table entry strictly below the score (defaults to score 0.0 -> precision 0)
        let (less_key, less_val) = self
            .map
            .range(..key)
            .next_back()
            .unwrap_or((&OrderedFloat(0.0), &0));
        // nearest table entry at or above the score (defaults to score 1.0 -> precision 255)
        let (more_key, more_val) = self
            .map
            .range(key..)
            .next()
            .unwrap_or((&OrderedFloat(1.0), &255));
        // use whichever table entry is closer to the query score
        if (more_key - key).abs() < (less_key - key).abs() {
            *more_val
        } else {
            *less_val
        }
}
}
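/// Predict a FIRE precision for each MSP of the record and store the values in
/// the `aq` aux tag, reversing their order for reverse-strand alignments.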
pub fn add_fire_to_rec(
rec: &mut FiberseqData,
fire_opts: &FireOptions,
model: &GBDT,
precision_table: &MapPrecisionValues,
) {
let fire_feats = FireFeats::new(rec, fire_opts);
let mut precisions = fire_feats.predict_with_xgb(model, precision_table);
if rec.record.is_reverse() {
precisions.reverse();
}
let aux_array: AuxArray<u8> = (&precisions).into();
let aux_array_field = Aux::ArrayU8(aux_array);
rec.record.remove_aux(b"aq").unwrap_or(()); rec.record
.push_aux(b"aq", aux_array_field)
.expect("Cannot add FIRE precision to bam");
log::trace!("precisions: {:?}", precisions);
}
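/// FIRE driver: depending on the options, write per-MSP features as text
/// (`feats_to_text`), FIRE elements as BED9+ (`extract`), or a BAM annotated
/// with `aq` precision tags, skipping records that fail the m6A/MSP filters.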
pub fn add_fire_to_bam(fire_opts: &FireOptions) -> Result<(), anyhow::Error> {
let (model, precision_table) = get_model(fire_opts);
let mut bam = bio_io::bam_reader(&fire_opts.bam, 8);
if fire_opts.feats_to_text {
let mut first = true;
let mut out_buffer = bio_io::writer(&fire_opts.out)?;
for chunk in &FiberseqRecords::new(&mut bam, 0).chunks(1_000) {
if first {
out_buffer.write_all(FireFeats::fire_feats_header(fire_opts).as_bytes())?;
first = false;
}
let chunk: Vec<FiberseqData> = chunk.collect();
let feats: Vec<FireFeats> =
chunk.iter().map(|r| FireFeats::new(r, fire_opts)).collect();
            for f in &feats {
                f.dump_fire_feats(&mut out_buffer)?;
            }
}
    } else if fire_opts.extract {
        fire_to_bed9(fire_opts, &mut bam)?;
    } else {
let mut out = bam_writer(&fire_opts.out, &bam, 8);
let mut skip_because_no_m6a = 0;
let mut skip_because_num_msp = 0;
let mut skip_because_ave_msp_length = 0;
for recs in &FiberseqRecords::new(&mut bam, 0).chunks(2_000) {
let mut recs: Vec<FiberseqData> = recs.collect();
recs.par_iter_mut().for_each(|r| {
add_fire_to_rec(r, fire_opts, &model, &precision_table);
});
for rec in recs {
let n_msps = rec.msp.starts.len();
if fire_opts.skip_no_m6a || fire_opts.min_msp > 0 || fire_opts.min_ave_msp_size > 0
{
if rec.m6a.starts.is_empty() || n_msps == 0 {
skip_because_no_m6a += 1;
continue;
}
if n_msps < fire_opts.min_msp {
skip_because_num_msp += 1;
continue;
}
let ave_msp_size =
rec.msp.lengths.iter().flatten().sum::<i64>() / n_msps as i64;
if ave_msp_size < fire_opts.min_ave_msp_size {
skip_because_ave_msp_length += 1;
continue;
}
}
out.write(&rec.record)?;
}
}
log::info!(
"Skipped {} records because they had an average MSP length less than {}; {} records because they had fewer than {} MSPs; and {} records because they had no m6A sites",
skip_because_ave_msp_length,
fire_opts.min_ave_msp_size,
skip_because_num_msp,
fire_opts.min_msp,
skip_because_no_m6a,
);
}
Ok(())
}
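/// Write MSPs and nucleosomes from primary, mapped reads as BED9+2 records:
/// the score column holds the FIRE FDR (percent) for MSPs and the sentinel 101
/// for nucleosomes, followed by the FDR fraction and the haplotype tag.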
pub fn fire_to_bed9(fire_opts: &FireOptions, bam: &mut bam::Reader) -> Result<(), anyhow::Error> {
let mut out_buffer = bio_io::writer(&fire_opts.out)?;
for rec in FiberseqRecords::new(bam, 0) {
if rec.record.is_secondary() || rec.record.is_supplementary() || rec.record.is_unmapped() {
continue;
}
let start_iter = rec
.msp
.reference_starts
.iter()
.chain(rec.nuc.reference_starts.iter());
let end_iter = rec
.msp
.reference_ends
.iter()
.chain(rec.nuc.reference_ends.iter());
let qual_iter = rec.msp.qual.iter().chain(rec.nuc.qual.iter());
let n_msps = rec.msp.reference_starts.len();
for (count, ((start, end), qual)) in start_iter.zip(end_iter).zip(qual_iter).enumerate() {
if let (Some(start), Some(end)) = (start, end) {
                // MSPs come first in the chained iterators, so the first n_msps entries
                // are MSPs (FDR from the 0-255 qual) and the rest are nucleosomes (sentinel 101)
                let fdr = if count < n_msps {
                    100.0 - *qual as f32 / 255.0 * 100.0
                } else {
                    101.0
                };
let color = get_fire_color(fdr);
let bed9 = format!(
"{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n",
rec.target_name,
start,
end,
String::from_utf8_lossy(rec.record.qname()),
fdr.round(),
if rec.record.is_reverse() { "-" } else { "+" },
start,
end,
color,
fdr / 100.0,
rec.get_hp()
);
out_buffer.write_all(bed9.as_bytes())?;
}
}
}
Ok(())
}
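// A minimal test sketch (not part of the original module) for the pure helper
// functions above; the expected values are hand-worked from the code as written.
#[cfg(test)]
mod fire_helper_tests {
    use super::*;

    #[test]
    fn test_mid_point_and_bins() {
        assert_eq!(get_mid_point(0, 10), 5);
        // three 10 bp bins tiled around position 100 within a 1 kb sequence
        assert_eq!(
            get_bins(100, 3, 10, 1_000),
            vec![(85, 95), (95, 105), (105, 115)]
        );
        // bins falling off the start of the sequence are clamped to 0
        assert_eq!(get_bins(0, 1, 10, 1_000), vec![(0, 5)]);
    }

    #[test]
    fn test_precision_from_float() {
        let pt = PrecisionTable {
            columns: vec!["mokapot score".to_string(), "mokapot qvalue".to_string()],
            data: vec![(0.25, 1.0), (0.75, 0.0)],
        };
        let converter = MapPrecisionValues::new(&pt);
        // scores near the top of the table map to full precision (255)
        assert_eq!(converter.precision_from_float(0.9), 255);
        // scores near the bottom of the table map to precision 0
        assert_eq!(converter.precision_from_float(0.1), 0);
    }
}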