use std::cmp::Ordering;
use async_stream::stream;
use bstr::ByteSlice;
use rust_htslib::{
bam::Read as BamRead,
bcf::{self, Read, record::GenotypeAllele},
};
use tokio_stream::Stream;
use tracing::{debug, error, trace};
use crate::{
common::{coords::GenomeRegion, list_of_regions::Regions},
counter,
vcf::pipeline::{clusterizer::cluster::ClusteringSettings, reader::VCFReader},
};
/// Incremental clustering state used by `find_clusters`.
///
/// Records are appended to `buffer` while `strategy` says they belong to it.
/// The buffer is flushed (emitted as a cluster or as plain records) when a
/// record does not belong to it or when the contig changes.
struct State {
    /// Decides cluster membership and whether a flushed buffer forms a cluster.
    strategy: ClusteringSettings,
    /// Record the VCF reader decodes into; reused for every read.
    record_buf: bcf::Record,
    /// Records collected since the last flush.
    buffer: Vec<bcf::Record>,
    /// Contig id (`rid`) of the most recently consumed record.
    last_rid: Option<u32>,
}

impl State {
    /// Creates an empty state that groups records with `strategy` and reads
    /// into `record_buf`.
    fn new(strategy: ClusteringSettings, record_buf: bcf::Record) -> Self {
        Self {
            strategy,
            record_buf,
            buffer: vec![],
            last_rid: None,
        }
    }

    /// Feeds one record into the clusterizer.
    ///
    /// Returns the previously buffered records when `record` closes the
    /// current group. That happens when `record` lies on a different contig
    /// than the last consumed record, or when `strategy` says it does not
    /// belong to the buffer. `record` itself is never dropped: it either joins
    /// the current buffer or starts the next one.
    fn consume_record(&mut self, record: bcf::Record) -> Option<ClusterOrRecords> {
        let rid = record.rid();
        if rid != self.last_rid {
            // A new contig always starts a fresh buffer. Membership against
            // the previous contig's buffer is meaningless, so it is not
            // consulted (previously a record failing that check was lost).
            self.last_rid = rid;
            let flushed = self.flush();
            self.buffer.push(record);
            return flushed;
        }
        if self
            .strategy
            .belongs(&self.buffer, &record)
            .unwrap_or(false)
        {
            self.buffer.push(record);
            None
        } else {
            let flushed = self.flush();
            self.buffer.push(record);
            flushed
        }
    }

    /// Empties the buffer.
    ///
    /// The contents are returned as a `Cluster` if `strategy` classifies them
    /// as one, as `Records` otherwise, or `None` when there is nothing to emit.
    fn flush(&mut self) -> Option<ClusterOrRecords> {
        if self.strategy.is_cluster(&self.buffer) {
            let buf = std::mem::take(&mut self.buffer);
            Some(ClusterOrRecords::Cluster(buf))
        } else if !self.buffer.is_empty() {
            let buf = std::mem::take(&mut self.buffer);
            Some(ClusterOrRecords::Records(buf))
        } else {
            None
        }
    }
}
/// Event emitted by the clustering stream.
pub enum ClusterOrRecords {
    /// Records that are emitted together as one cluster.
    Cluster(Vec<bcf::Record>),
    /// Records that are emitted individually (not treated as a cluster).
    Records(Vec<bcf::Record>),
    /// Marker emitted when the reader has no further record for the current
    /// sequence.
    SequenceDone,
}

impl std::fmt::Debug for ClusterOrRecords {
    /// Summarises the event by its kind and record count instead of dumping
    /// the records themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Cluster(members) => write!(f, "cluster of {}", members.len()),
            Self::Records(members) => write!(f, "{} records", members.len()),
            Self::SequenceDone => f.write_str("sequence done"),
        }
    }
}
fn phase_cluster(
bam_reader: &mut rust_htslib::bam::IndexedReader,
cluster: Vec<bcf::Record>,
) -> ClusterOrRecords {
for rec in &cluster {
if rec.allele_count() > 2 {
error!("Record with too many alleles... discarding cluster");
return ClusterOrRecords::Records(cluster);
}
}
let a = cluster
.iter()
.map(rust_htslib::bcf::Record::pos)
.min()
.unwrap();
let chrom = cluster[0].rid().unwrap();
bam_reader.fetch((chrom, a, a + 1)).unwrap();
let mut hits = 0;
let magic_number = 4;
debug!("Phasing cluster {}, {}", a, chrom);
for rec in &cluster {
let alleles = rec.alleles();
trace!(
"{}, {}, {}",
alleles[0].as_bstr(),
alleles[1].as_bstr(),
rec.pos() - a
);
}
let u_a: u32 = a.try_into().unwrap();
for p in bam_reader.pileup() {
let Ok(p) = p else {
continue;
};
if p.pos() != u_a {
continue;
}
for aln in p.alignments() {
if hits >= magic_number {
break;
}
let q_pos = aln.qpos();
if q_pos.is_none() {
continue;
}
let off: u32 = q_pos.unwrap().try_into().unwrap();
let read = aln.record();
let cig = read.cigar();
let r_seq = read.seq();
let seq = r_seq.as_bytes();
let mut off: usize = off.try_into().unwrap();
let indel_match = cluster.iter().all(|rec| {
let ref_s = rec.alleles()[0];
let var_s = rec.alleles()[1];
let ref_l = ref_s.len() as u32;
let var_l = var_s.len() as u32;
let pos: u32 = rec.pos().try_into().unwrap();
match ref_l.cmp(&var_l) {
Ordering::Equal => true,
Ordering::Less | Ordering::Greater =>
{
match cig.read_pos(pos - 1, true, true) {
Ok(Some(a_pos)) => match cig.read_pos(pos + ref_l, true, true) {
Ok(Some(b_pos)) => a_pos + 1 + var_l == b_pos,
_ => false,
},
_ => false,
}
}
}
});
if !indel_match {
continue;
}
let c_match = cluster.iter().all(|rec| {
let s_off: usize = (rec.pos() - a).try_into().unwrap();
let ref_s = rec.alleles()[0];
let var_s = rec.alleles()[1];
let ref_l = ref_s.len();
let var_l = var_s.len();
match ref_l.cmp(&var_l) {
Ordering::Less => {
for i in 0..var_l {
if i + off + s_off >= seq.len() || var_s[i] != seq[i + off + s_off] {
return false;
}
}
off += var_l - ref_l;
}
Ordering::Equal => {
for i in 0..var_l {
if i + off + s_off >= seq.len() {
return false;
}
if var_s[i] != seq[i + off + s_off] {
return false;
}
}
}
Ordering::Greater => {
for i in 0..var_l {
if i + off + s_off >= seq.len() {
return false;
}
if seq[i + off + s_off] != var_s[i] {
return false;
}
}
off += var_l - ref_l;
}
}
true
});
hits += i32::from(c_match);
}
break;
}
if hits >= 4 {
debug!("Found consensus");
ClusterOrRecords::Cluster(cluster)
} else {
debug!("No consensus... discarding cluster");
ClusterOrRecords::Records(cluster)
}
}
/// Turns the VCF records produced by `reader` into a stream of
/// `ClusterOrRecords` events.
///
/// Records are grouped with `strategy` via `State`. When `targets` is
/// `Some`, only records whose `GenomeRegion` conversion succeeds and is
/// contained in `targets` are considered; all others are skipped.
///
/// Every emitted cluster increments the `"clusters"` counter. How a cluster
/// is then emitted:
/// - If its genotypes already look phase-resolved, it is passed through as a
///   cluster. A single-record cluster always counts as resolved.
/// - Otherwise, if `bam_f` is given, it goes through `phase_cluster`.
/// - Otherwise it is emitted as `ClusterOrRecords::Records`.
///
/// NOTE(review): record reading runs inside `tokio::task::block_in_place`,
/// which requires a multi-threaded Tokio runtime — confirm all callers poll
/// this stream on one.
pub(super) fn find_clusters(
    strategy: ClusteringSettings,
    mut reader: VCFReader,
    targets: Option<Regions>,
    mut bam_f: Option<rust_htslib::bam::IndexedReader>,
) -> impl Stream<Item = ClusterOrRecords>
where
{
    let record_buf = reader.empty_record();
    let mut state = State::new(strategy, record_buf);
    stream! {
        loop {
            // Decode the next record into the reused `state.record_buf`;
            // the htslib read is blocking, hence `block_in_place`.
            let read =
                tokio::task::block_in_place(|| bcf::Read::read(&mut reader, &mut state.record_buf));
            let event = match read {
                Some(Ok(())) => {
                    let reg = GenomeRegion::try_from(&state.record_buf);
                    // Keep the record when there are no targets, or when its
                    // region converts and is contained in the targets.
                    if targets
                        .as_ref()
                        .is_none_or(|t| reg.is_ok_and(|r| t.contains(&r)))
                    {
                        // Clone because `record_buf` is overwritten by the
                        // next read.
                        state.consume_record(state.record_buf.clone())
                    } else {
                        None }
                }
                // No record and nothing buffered: stop once the reader reports
                // it is done, otherwise signal the end of the current sequence.
                // (Presumably `None` marks the end of the reader's current
                // sequence/region — confirm against `VCFReader`.)
                None if state.buffer.is_empty() => {
                    if reader.done() {
                        break;
                    }
                    Some(ClusterOrRecords::SequenceDone)
                }
                // Flush buffered records first; this relies on the reader
                // returning `None` again so the next iteration reaches the arm
                // above — TODO confirm.
                None => state.flush(),
                // Read errors are logged and the record is skipped.
                Some(Err(e)) => {
                    error!("Error reading record: {e}");
                    None }
            };
            let Some(event) = event else {
                continue;
            };
            let event = match event {
                ClusterOrRecords::Cluster(cluster) => {
                    counter!("clusters").inc(1);
                    // Only the first sample's genotype (`get(0)`) is inspected.
                    // Treated as resolved when one of these holds:
                    // (1) every adjacent pair shares the same genotype, whose
                    //     two alleles differ and whose second allele is phased;
                    // (2) every adjacent pair shares the same genotype, whose
                    //     two alleles are equal and not the reference (index 0);
                    // (3) the cluster has a single record.
                    // NOTE(review): `genotypes().unwrap()` panics if GT cannot
                    // be read, and `gt1[1]` panics on haploid genotypes.
                    // `GenotypeAllele` equality presumably includes the phased
                    // flag, so e.g. `0|0` satisfies `gt1[0] != gt1[1]` — confirm
                    // this is intended.
                    let already_phase_resolved = cluster.windows(2).all(|elems| {
                        let gt1 = elems[0].genotypes().unwrap().get(0);
                        let gt2 = elems[1].genotypes().unwrap().get(0);
                        gt1[0] != gt1[1] && gt1 == gt2 && matches!(gt1[1], GenotypeAllele::Phased(_))
                    }) || cluster.windows(2).all(|elems| {
                        let gt1 = elems[0].genotypes().unwrap().get(0);
                        let gt2 = elems[1].genotypes().unwrap().get(0);
                        gt1[0] == gt1[1] && gt1 == gt2 && gt1[0].index() != Some(0)
                    }) || cluster.len() == 1;
                    if already_phase_resolved {
                        counter!("clusters.trivial").inc(1);
                        ClusterOrRecords::Cluster(cluster)
                    } else if let Some(bam_f_r) = &mut bam_f {
                        match phase_cluster(bam_f_r, cluster) {
                            c @ ClusterOrRecords::Cluster(_) => {
                                counter!("clusters.phased.successfully").inc(1);
                                c
                            }
                            r @ ClusterOrRecords::Records(_) => {
                                counter!("clusters.phased.unsuccessfully").inc(1);
                                r
                            }
                            // `phase_cluster` only returns Cluster or Records;
                            // anything else is passed through unchanged.
                            d => d,
                        }
                    } else {
                        // No alignments to phase with: emit the records individually.
                        ClusterOrRecords::Records(cluster)
                    }
                }
                records_or_done => records_or_done,
            };
            yield event;
        }
    }
}