use std::{collections::HashMap, fs};
use crossbeam::channel::Receiver;
use gskits::{gsbam::bam_record_ext::BamRecord, pbar};
use mm2::AlignResult;
use rust_htslib::bam::ext::BamRecordExtensions;
pub mod hsc_csbq;
pub mod smc_csbq;
/// The four nucleotide bases as ASCII bytes, in the order `A`, `C`, `G`, `T`.
pub static ALL_BASES: [u8; 4] = [b'A', b'C', b'G', b'T'];
/// ASCII `+`: the "called base" byte used when looking up an insertion in the model.
pub static INS: u8 = b'+';
/// ASCII `-`: the "called base" byte used when looking up a deletion in the model.
pub static DEL: u8 = b'-';
/// Probability lookup table loaded from a plain-text model file.
///
/// Keys are short strings. The byte-oriented lookups on this type use either a
/// single byte (a reference base) or two bytes (reference base followed by the
/// called base).
pub struct Model {
    refcalled2prob: HashMap<String, f32>,
}

impl Model {
    /// Loads the model stored in the file `fname`.
    ///
    /// Every non-blank line must have the form `<key><whitespace><probability>`.
    /// Blank lines are skipped, so a file that ends with a newline (or uses
    /// `\r\n` line endings) loads fine.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be read, if a non-blank line has no whitespace
    /// separator, or if the probability is not a valid `f32`.
    pub fn new(fname: &str) -> Self {
        let content = fs::read_to_string(fname)
            .unwrap_or_else(|err| panic!("failed to read model file '{}': {}", fname, err));

        let refcalled2prob = content
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty())
            .map(|line| {
                let (key, prob) = line.split_once(char::is_whitespace).unwrap_or_else(|| {
                    panic!(
                        "malformed line '{}' in model file '{}': expected '<key> <prob>'",
                        line, fname
                    )
                });
                let prob = prob.trim().parse::<f32>().unwrap_or_else(|err| {
                    panic!(
                        "invalid probability in line '{}' of model file '{}': {}",
                        line, fname, err
                    )
                });
                (key.to_string(), prob)
            })
            .collect::<HashMap<_, _>>();

        Self { refcalled2prob }
    }

    /// Returns the probability stored under `key`.
    ///
    /// # Panics
    ///
    /// Panics if `key` is not present in the model.
    pub fn get_prob(&self, key: &str) -> f32 {
        match self.refcalled2prob.get(key) {
            Some(prob) => *prob,
            None => panic!("key '{}' not found in model", key),
        }
    }

    /// Returns the probability stored under the two-byte key `ref_base`, `called_base`.
    ///
    /// # Panics
    ///
    /// Panics if the two bytes are not valid UTF-8 or the key is missing.
    pub fn get_prob_with_ref_called_base(&self, ref_base: u8, called_base: u8) -> f32 {
        self.get_prob_for_bytes(&[ref_base, called_base])
    }

    /// Returns the probability stored under the one-byte key `ref_base`.
    ///
    /// # Panics
    ///
    /// Panics if the byte is not valid UTF-8 or the key is missing.
    pub fn get_prob_with_ref_base(&self, ref_base: u8) -> f32 {
        self.get_prob_for_bytes(&[ref_base])
    }

    /// Looks up a key given as raw bytes, validating it as UTF-8 first.
    ///
    /// The key lives on the stack and is checked, so no allocation and no
    /// `unsafe` is needed; invalid bytes produce a panic instead of an invalid
    /// `String`.
    fn get_prob_for_bytes(&self, key: &[u8]) -> f32 {
        let key = std::str::from_utf8(key)
            .unwrap_or_else(|_| panic!("model key bytes {:?} are not valid UTF-8", key));
        self.get_prob(key)
    }
}
/// One read's observation at a reference locus, as recorded in the pileup.
#[derive(Debug, Clone, Copy)]
pub enum PlpState {
    /// The query base equals the locus base; carries that query base.
    Eq(u8),
    /// The query base differs from the locus base; carries the query base.
    Diff(u8),
    /// The read has no query base at this locus.
    Del,
    /// The read carries an extra query base with no reference position of its
    /// own; carries that query base.
    Ins(u8),
}
/// Pileup observations gathered for a single locus.
pub struct LocusInfo {
    /// Position handed to [`LocusInfo::new`].
    pos: usize,
    /// Base currently assigned to this locus.
    cur_base: u8,
    /// Observations pushed so far, in insertion order.
    plp_infos: Vec<PlpState>,
}

impl LocusInfo {
    /// Creates a locus at `pos` holding `base`, with no observations yet.
    pub fn new(pos: usize, base: u8) -> Self {
        let plp_infos = Vec::new();
        Self {
            pos,
            cur_base: base,
            plp_infos,
        }
    }

    /// Appends one observation to this locus.
    pub fn push_plp_state(&mut self, plp_state: PlpState) {
        self.plp_infos.push(plp_state)
    }

    /// Returns every entry of `ALL_BASES` that is not the locus base, keeping
    /// the order of `ALL_BASES`.
    pub fn get_other_bases(&self) -> Vec<u8> {
        let mut others = Vec::new();
        for &candidate in ALL_BASES.iter() {
            if candidate != self.cur_base {
                others.push(candidate);
            }
        }
        others
    }

    /// Returns the position this locus was created with.
    pub fn get_pos(&self) -> usize {
        self.pos
    }
}
/// Consumes alignment results, accumulates pileup information and calibrates
/// every contig.
///
/// The function runs in two phases:
/// 1. drain `recv`, feeding each record of each alignment result into the
///    per-contig locus vector selected by the record's `tid`;
/// 2. compute one quality vector per contig with
///    `calibrate_single_contig_use_bayes`.
///
/// # Parameters
/// - `recv`: channel of alignment results; the function returns once the
///   channel is closed and drained.
/// - `all_contig_locus_info`: per-contig locus vectors keyed by `tid`; updated
///   in place during phase 1.
/// - `model`: probability model used for the calibration.
/// - `use_pbar`: when `true`, a spinner progress bar is shown for each phase.
///
/// # Returns
/// A map from `tid` to the calibrated quality values of that contig.
///
/// # Panics
/// Panics if a record carries a non-negative `tid` that has no entry in
/// `all_contig_locus_info`.
pub fn cali_worker_for_hsc(
    recv: Receiver<AlignResult>,
    all_contig_locus_info: &mut HashMap<i32, Vec<LocusInfo>>,
    model: &Model,
    use_pbar: bool,
) -> HashMap<i32, Vec<u8>> {
    let pb = use_pbar
        .then(|| pbar::get_spin_pb("collect plp info".to_string(), pbar::DEFAULT_INTERVAL));

    for align_res in recv {
        if let Some(pb) = &pb {
            pb.inc(1);
        }
        for record in align_res.records {
            let tid = record.tid();
            // A negative tid means the record is not placed on any contig.
            // `collect_plp_info_from_record` ignores unmapped records anyway, so
            // skip them here instead of panicking on the failed map lookup.
            if tid < 0 {
                continue;
            }
            let single_contig_locus_info = all_contig_locus_info
                .get_mut(&tid)
                .unwrap_or_else(|| panic!("no locus info registered for tid {}", tid));
            collect_plp_info_from_record(&record, single_contig_locus_info);
        }
    }
    if let Some(pb) = &pb {
        pb.finish();
    }

    let pb = use_pbar
        .then(|| pbar::get_spin_pb(" do calibration".to_string(), pbar::DEFAULT_INTERVAL));

    let calibrated_qual: HashMap<i32, Vec<u8>> = all_contig_locus_info
        .iter()
        .map(|(tid, single_contig_locus_info)| {
            if let Some(pb) = &pb {
                pb.inc(1);
            }
            let qual = calibrate_single_contig_use_bayes(single_contig_locus_info, model);
            (*tid, qual)
        })
        .collect();
    if let Some(pb) = &pb {
        pb.finish();
    }

    calibrated_qual
}
/// Adds the pileup observations of one aligned record to a contig's loci.
///
/// Secondary, unmapped and supplementary records are ignored. For every other
/// record the full aligned-pair list is walked and one [`PlpState`] is pushed
/// per pair onto the locus selected by the current reference position:
///
/// - pair without a query position -> `PlpState::Del`;
/// - pair without a reference position -> `PlpState::Ins`, attached to the most
///   recent reference position seen;
/// - pair with both -> `PlpState::Eq` or `PlpState::Diff`, depending on whether
///   the query base equals the locus base.
///
/// Pairs that occur before the first reference position or before the first
/// query position are skipped, and the walk stops after the base aligned to the
/// last reference position of the record, so trailing unaligned query bases are
/// not counted.
///
/// # Panics
/// Panics if the record refers to a reference position outside
/// `single_contig_locus_info` or to a query position outside its sequence.
/// These accesses are bounds-checked on purpose: the indices come from the
/// record, so a mismatch between record and locus vector must not turn into
/// undefined behaviour.
pub fn collect_plp_info_from_record(
    record: &BamRecord,
    single_contig_locus_info: &mut Vec<LocusInfo>,
) {
    if record.is_secondary() || record.is_unmapped() || record.is_supplementary() {
        return;
    }

    let ref_start = record.reference_start();
    let ref_end = record.reference_end();
    let query_seq = record.seq().as_bytes();
    let n_loci = single_contig_locus_info.len();

    // Most recent reference position seen; insertions are attributed to it.
    let mut rpos_cursor = None;
    // Becomes true once the first query position has been seen.
    let mut seen_query_base = false;

    for [qpos, rpos] in record.aligned_pairs_full() {
        seen_query_base |= qpos.is_some();
        if rpos.is_some() {
            rpos_cursor = rpos;
        }

        let anchor = match rpos_cursor {
            Some(anchor) => anchor,
            // No reference position yet (e.g. leading unaligned query bases).
            None => continue,
        };
        if anchor < ref_start {
            continue;
        }
        if anchor >= ref_end {
            break;
        }
        if !seen_query_base {
            continue;
        }

        let ref_idx = anchor as usize;
        let locus_info = single_contig_locus_info
            .get_mut(ref_idx)
            .unwrap_or_else(|| {
                panic!(
                    "reference position {} is outside the contig locus info (len {})",
                    ref_idx, n_loci
                )
            });

        let (state, is_aligned_base) = match (qpos, rpos) {
            (None, _) => (PlpState::Del, false),
            (Some(q), None) => (PlpState::Ins(query_seq[q as usize]), false),
            (Some(q), Some(_)) => {
                let base = query_seq[q as usize];
                let state = if locus_info.cur_base == base {
                    PlpState::Eq(base)
                } else {
                    PlpState::Diff(base)
                };
                (state, true)
            }
        };
        locus_info.push_plp_state(state);

        // Stop right after the base aligned to the last reference position so
        // that whatever follows it in the query is not recorded as insertions.
        if is_aligned_base && ref_idx == (ref_end as usize - 1) {
            break;
        }
    }
}
/// Computes one quality value per locus of a contig.
///
/// Each entry of the returned vector is `single_locus_bayes` applied to the
/// locus at the same index of `single_contig_locus_info`.
pub fn calibrate_single_contig_use_bayes(
    single_contig_locus_info: &Vec<LocusInfo>,
    model: &Model,
) -> Vec<u8> {
    let mut quals = Vec::with_capacity(single_contig_locus_info.len());
    for locus in single_contig_locus_info {
        quals.push(single_locus_bayes(locus, model));
    }
    quals
}
/// Natural log of the joint probability of `cur_base` and the observations.
///
/// This is `ln(model[cur_base]) + sum(ln(model[cur_base, called]))`, where
/// `called` is the observed base for `Eq`/`Diff`, `INS` for insertions and
/// `DEL` for deletions. Staying in log space keeps deep pileups from
/// underflowing to zero.
fn ln_join_prob(cur_base: u8, plp_infos: &[PlpState], model: &Model) -> f32 {
    let ln_likelihood: f32 = plp_infos
        .iter()
        .map(|plp_state| {
            let called_base = match *plp_state {
                PlpState::Eq(called_base) | PlpState::Diff(called_base) => called_base,
                PlpState::Ins(_) => INS,
                PlpState::Del => DEL,
            };
            model
                .get_prob_with_ref_called_base(cur_base, called_base)
                .ln()
        })
        .sum();
    ln_likelihood + model.get_prob_with_ref_base(cur_base).ln()
}

/// Joint probability of `cur_base` and the observations in `plp_infos`.
///
/// Returns the fixed value `1e-10` when there are no observations. For long
/// observation lists the result can underflow to `0.0`; use the quality
/// returned by [`single_locus_bayes`] rather than ratios of this value.
pub fn join_prob(cur_base: u8, plp_infos: &Vec<PlpState>, model: &Model) -> f32 {
    if plp_infos.is_empty() {
        return 1e-10;
    }
    ln_join_prob(cur_base, plp_infos, model).exp()
}

/// Phred-scaled posterior quality of the locus base given its observations.
///
/// The posterior is the joint probability of the locus base divided by the sum
/// of the joint probabilities over `ALL_BASES`, capped at `1 - 1e-5`; the
/// result is `-10 * log10(1 - posterior)` truncated to `u8`.
///
/// Returns `0` when the locus has no observations or when the joint
/// probability of the locus base is zero.
///
/// The normalisation is done in log space: the posterior is evaluated as
/// `1 / sum(exp(ln_joint(base) - ln_joint(cur_base)))`. Exponentiating the
/// joint probabilities first would underflow for deep pileups and report
/// quality 0 for loci that are in fact strongly supported.
pub fn single_locus_bayes(locus_info: &LocusInfo, model: &Model) -> u8 {
    const MAX_POSTERIOR: f32 = 1. - 1e-5;

    if locus_info.plp_infos.is_empty() {
        return 0;
    }

    let ln_numerator = ln_join_prob(locus_info.cur_base, &locus_info.plp_infos, model);
    if !ln_numerator.is_finite() {
        // Zero (or undefined) joint probability for the locus base.
        return 0;
    }

    let inv_posterior: f32 = ALL_BASES
        .iter()
        .map(|&base| (ln_join_prob(base, &locus_info.plp_infos, model) - ln_numerator).exp())
        .sum();

    let posterior = 1. / inv_posterior;
    // Explicit comparison (not `min`) so that a NaN posterior still ends up as
    // quality 0 through the saturating cast below.
    let posterior = if posterior > MAX_POSTERIOR {
        MAX_POSTERIOR
    } else {
        posterior
    };

    (-10. * (1. - posterior).log10()) as u8
}
#[cfg(test)]
mod test {
    use crate::{join_prob, single_locus_bayes, LocusInfo, Model, PlpState};

    /// Two matching observations of `A`; the expected value is tied to the
    /// contents of `model/default.txt`.
    #[test]
    fn test_join_prob() {
        let model = Model::new("model/default.txt");
        let base = b'A';
        let observations = vec![PlpState::Eq(b'A'), PlpState::Eq(b'A')];

        let joint = join_prob(base, &observations, &model);

        assert!((joint - 0.2025).abs() < 1e-4);
    }

    /// Quality after one matching observation, then after an additional
    /// mismatching one; expected values are tied to `model/default.txt`.
    #[test]
    fn test_single_locus_bayes() {
        let model = Model::new("model/default.txt");
        let base = b'A';
        let mut locus = LocusInfo::new(0, base);

        locus.push_plp_state(PlpState::Eq(b'A'));
        assert_eq!(single_locus_bayes(&locus, &model), 12);

        locus.push_plp_state(PlpState::Diff(b'C'));
        assert_eq!(single_locus_bayes(&locus, &model), 2);
    }
}