use anyhow::{Context, bail};
use iterator::{ClusterOrRecords, find_clusters};
use lib_tsalign::a_star_aligner::{
alignment_geometry::{AlignmentCoordinates, AlignmentRange},
alignment_result::alignment::Alignment,
template_switch_distance::AlignmentType,
};
use rust_htslib::bcf;
use tokio::{pin, sync::mpsc};
use tokio_stream::StreamExt;
use tracing::{error, instrument, warn, warn_span};
use crate::{
common::{
ImmutableSequence, SequencePair,
aligner::{AlignmentOrchestrator, AlignmentQuery, cli::CliAlignmentArgs},
alignment::ForwardAlignment,
contig::ContigName,
coords::{GenomePosition, GenomeRegion},
csv::CSVAuxData,
list_of_regions::Regions,
reference::{ReferenceQueryResult, ReferenceReader},
},
vcf::pipeline::{
message::{Cluster, MaskedRecords},
reader::VCFReader,
},
};
use super::Message;
pub mod cluster;
mod iterator;
/// Pipeline stage that groups VCF records into clusters, prepares each cluster for
/// alignment against the reference and emits the outcome as [`Message`]s.
pub struct Clusterizer {
    /// Source of the VCF records to cluster.
    input: VCFReader,
    /// Optional target regions, forwarded to `find_clusters`.
    targets: Option<Regions>,
    /// Optional indexed BAM reader, forwarded to `find_clusters`.
    bam_f: Option<rust_htslib::bam::IndexedReader>,
    /// Reference used to build the sequences of each cluster.
    reference: ReferenceReader,
    /// Clustering strategy and aligner configuration.
    settings: ClusterSettings,
    /// Where the produced messages are sent.
    output: mpsc::Sender<Message>,
}
// Command-line settings for the clustering stage.
//
// NOTE: plain `//` comments are used deliberately here, because `///` doc comments on a
// clap-derived type can end up in the generated CLI help text.
#[derive(clap::Args, Debug, Default)]
pub struct ClusterSettings {
// How records are grouped into clusters; handed to `find_clusters` by `Clusterizer::run`.
#[command(flatten)]
pub cluster_strategy: cluster::ClusteringSettings,
// Aligner configuration. `Clusterizer::run` also reads `padding` from here, i.e. the
// number of reference bases requested on each side of a cluster.
#[command(flatten)]
pub aligner: CliAlignmentArgs,
}
impl Clusterizer {
pub(super) fn new(
input: VCFReader,
reference: ReferenceReader,
targets: Option<Regions>,
bam_f: Option<rust_htslib::bam::IndexedReader>,
output: mpsc::Sender<Message>,
settings: ClusterSettings,
) -> Self {
Self {
input,
targets,
reference,
bam_f,
output,
settings,
}
}
pub async fn run(self) {
let aligner = match AlignmentOrchestrator::try_from(&self.settings.aligner) {
Ok(aligner) => aligner,
Err(e) => {
error!("Cannot initialize aligners: {e}");
return;
}
};
let strategy = self.settings.cluster_strategy;
let vcf_name = self.input.name().to_string();
let clusters = find_clusters(strategy, self.input, self.targets, self.bam_f);
pin!(clusters);
let mut cluster_id = 0;
while let Some(e) = clusters.next().await {
let m = match e {
ClusterOrRecords::Cluster(records) => {
cluster_id += 1;
let cluster = Cluster::new(bitvec::bitvec![1; records.len()]);
let masked = cluster.apply_to_records(&records);
match prepare_cluster(&self.reference, masked, self.settings.aligner.padding)
.await
{
Ok(Some(prepared)) => {
match aligner.get_or_compute_alignment(
self.reference.get_name(),
&prepared.reference_region.clone(),
prepared.cluster_region.clone(),
prepared.query.clone(),
) {
Ok(pending) => Message::cluster(
records,
vec![(
cluster,
pending,
CSVAuxData {
cluster_id: cluster_id.to_string(),
sequences: prepared.query.sequences,
ref_context_region: prepared.reference_region.clone(),
alt_context_region: prepared.reference_region,
cluster_region: prepared.cluster_region,
alt_id: Some(vcf_name.clone()),
forward_alignment: ForwardAlignment(
prepared.fw_alignment,
),
cost: aligner.costs.clone(),
},
)],
),
Err(e) => {
error!(
"Could not get or start the computation of an alignment: {e}"
);
Message::passthrough(records)
}
}
}
Ok(None) => Message::passthrough(records),
Err((reg, err)) => {
let _s =
reg.map(|r| warn_span!("Failed to prepare", pos = %r).entered());
warn!("{err}. Passing records through.",);
Message::passthrough(records)
}
}
}
ClusterOrRecords::Records(records) => Message::passthrough(records),
ClusterOrRecords::SequenceDone => {
aligner.clear_cache();
continue;
}
};
self.output.send(m).await.expect("Channel closed !?");
}
}
}
/// Everything `prepare_cluster` computes for one cluster before it is aligned.
#[derive(Clone, Debug)]
struct PreparedCluster {
/// Padded reference sequence, mutated query sequence, and the ranges of both that
/// correspond to the cluster itself (excluding the flanks).
query: AlignmentQuery,
/// Genome region actually returned by the reference lookup, flanks included.
reference_region: GenomeRegion,
/// Genome region spanned by the cluster's records (see `extract_region`).
cluster_region: GenomeRegion,
/// Alignment of the reference against the query implied by the applied records.
fw_alignment: Alignment<AlignmentType>,
}
async fn prepare_cluster(
reference: &ReferenceReader,
cluster: MaskedRecords<'_, '_>,
padding: usize,
) -> Result<Option<PreparedCluster>, (Option<GenomeRegion>, anyhow::Error)> {
let query_region = extract_region(cluster).map_err(|e| (None, e))?;
let Some(ReferenceQueryResult {
region: actual_region,
sequence: reference_sequence,
..
}) = reference
.get_seq(query_region.clone(), padding, padding)
.await
.map_err(|e| (Some(query_region.clone()), e))?
else {
return Ok(None);
};
let _span = warn_span!("prepare_cluster", pos = %query_region).entered();
let actual_padding_left =
query_region.start().position_0() - actual_region.start().position_0();
let actual_padding_right =
reference_sequence.len() - actual_padding_left - query_region.len().unwrap();
let (query_sequence, fw_alignment) = apply_mutations(
&reference_sequence,
actual_region.start().position_0(),
cluster.masked_iter(),
)
.map_err(|e| (Some(query_region.clone()), e))?;
let ranges = AlignmentRange::new_offset_limit(
AlignmentCoordinates::new(actual_padding_left, actual_padding_left),
AlignmentCoordinates::new(
reference_sequence.len() - actual_padding_right,
query_sequence.len() - actual_padding_right,
),
);
Ok(Some(PreparedCluster {
query: AlignmentQuery {
sequences: SequencePair {
reference: reference_sequence,
query: query_sequence,
},
ranges,
},
reference_region: actual_region,
cluster_region: query_region,
fw_alignment,
}))
}
fn extract_region(cluster: MaskedRecords<'_, '_>) -> anyhow::Result<GenomeRegion> {
let (start, end, contig) = {
let (mut start, mut end, mut contig) = (i64::MAX, 0, None);
for r in cluster.masked_iter() {
start = start.min(r.pos());
end = end.max(r.end());
if contig.is_none()
&& let Some(present) = r.rid() {
let name = r.header().rid2name(present)?;
contig = Some(ContigName::new(name));
}
}
(
usize::try_from(start)?,
usize::try_from(end)?,
contig.context("No rid present in records??")?,
)
};
let query_region = GenomeRegion::new_bounded(GenomePosition::new_0(contig, start), end - start);
Ok(query_region)
}
/// Builds the query sequence obtained by applying `mutations` to `reference`, together
/// with the forward alignment of the reference against that query.
///
/// `reference` holds the reference bases starting at the 0-based genome position
/// `reference_start`. For each record only the first ALT allele is applied. A leading base
/// shared by REF and ALT (the usual VCF anchor base of indels) is stripped first.
/// Reference stretches between records are copied verbatim and aligned as matches.
///
/// Records are skipped with a warning when any of the following holds:
/// - they have no ALT allele, or are empty after stripping;
/// - the ALT allele does not spell out bases (symbolic `<…>`, breakend, `*` or `.`);
/// - they extend past the end of `reference`;
/// - their end lies before their position;
/// - their REF length disagrees with their span. This last check keeps the alignment
///   consuming exactly `reference.len()` reference bases.
///
/// # Errors
/// Fails if a record starts before the end of the previously applied record (or before
/// `reference_start`), i.e. records must be sorted and non-overlapping. Also fails if a
/// coordinate does not fit in `usize`.
#[instrument(name = "build_query_sequence", skip_all)]
fn apply_mutations<'a>(
    reference: &[u8],
    reference_start: usize,
    mutations: impl Iterator<Item = &'a bcf::Record>,
) -> anyhow::Result<(ImmutableSequence, Alignment<AlignmentType>)> {
    // ALT alleles that are not literal sequence: symbolic (`<DEL>`, `<*>`), breakends
    // (`A[chr2:10[`, `.A`), the spanning deletion `*` and missing `.`.
    let is_non_sequence = |allele: &[u8]| {
        allele
            .iter()
            .any(|b| matches!(b, b'<' | b'>' | b'[' | b']' | b'*' | b'.'))
    };
    let mut alt_sequence = Vec::with_capacity(reference.len());
    let mut cigar = Alignment::new();
    let mut last_end = reference_start;
    let reference_end = reference_start + reference.len();
    let pos2off = |pos: usize| pos - reference_start;
    for m in mutations {
        let mut pos = usize::try_from(m.pos())?;
        let end = usize::try_from(m.end())?;
        let alleles = m.alleles();
        if alleles.len() <= 1 {
            warn!("Record at {pos} has no alt allele");
            continue;
        }
        let mut ref_allele = alleles[0];
        let mut alt_allele = alleles[1];
        if is_non_sequence(alt_allele) {
            warn!(
                "Skipping record at {pos} with non-sequence alt allele {}",
                String::from_utf8_lossy(alt_allele)
            );
            continue;
        }
        if !ref_allele.is_empty() && !alt_allele.is_empty() && ref_allele[0] == alt_allele[0] {
            ref_allele = &ref_allele[1..];
            alt_allele = &alt_allele[1..];
            pos += 1;
        }
        if ref_allele.is_empty() && alt_allele.is_empty() {
            warn!("Empty record; Skipping.");
            continue;
        }
        if pos < last_end {
            bail!(
                "Overlapping mutation: last mutation ended at {last_end}, \
                but this one starts at {pos}"
            );
        }
        if end > reference_end {
            warn!(
                "Skipping out-of-bounds mutation at {pos}: \
                record end {end} exceeds reference end {reference_end}"
            );
            continue;
        }
        if pos > end {
            warn!("Skip negative-sized mutation: the position is {pos} but the end is {end}",);
            continue;
        }
        // The alignment consumes `ref_allele.len()` reference bases while `last_end`
        // advances by `end - pos`; they must agree or the cigar drifts from the sequence.
        if end - pos != ref_allele.len() {
            warn!(
                "Skipping record at {pos}: REF allele length {} does not match its span {pos}..{end}",
                ref_allele.len()
            );
            continue;
        }
        if pos > last_end {
            alt_sequence.extend_from_slice(&reference[pos2off(last_end)..pos2off(pos)]);
            cigar.push_n(pos - last_end, AlignmentType::PrimaryMatch);
        }
        alt_sequence.extend_from_slice(alt_allele);
        // Compare REF and ALT base by base over their common length, then account for the
        // length difference as a trailing deletion (REF longer) or insertion (ALT longer).
        for (ref_base, alt_base) in ref_allele.iter().zip(alt_allele) {
            cigar.push(if ref_base == alt_base {
                AlignmentType::PrimaryMatch
            } else {
                AlignmentType::PrimarySubstitution
            });
        }
        let (ref_len, alt_len) = (ref_allele.len(), alt_allele.len());
        if ref_len > alt_len {
            cigar.push_n(ref_len - alt_len, AlignmentType::PrimaryDeletion);
        } else if alt_len > ref_len {
            cigar.push_n(alt_len - ref_len, AlignmentType::PrimaryInsertion);
        }
        last_end = end;
    }
    alt_sequence.extend_from_slice(&reference[pos2off(last_end)..]);
    // Avoid emitting a zero-length match when the last record reaches the reference end.
    let tail = reference_end - last_end;
    if tail > 0 {
        cigar.push_n(tail, AlignmentType::PrimaryMatch);
    }
    Ok((alt_sequence.into(), cigar))
}