use anyhow::{Context, Error};
use iterator::{ClusterOrRecords, find_clusters};
use lib_tsalign::a_star_aligner::alignment_geometry::{AlignmentCoordinates, AlignmentRange};
use rust_htslib::bcf;
use tokio::{pin, sync::mpsc};
use tokio_stream::StreamExt;
use tracing::{error, info_span, instrument, warn};
use crate::{
common::{
ImmutableSequence, SequencePair,
aligner::{AlignmentOrchestrator, AlignmentQuery, cli::CliAlignmentArgs},
contig::ContigName,
coords::{GenomePosition, GenomeRegion},
list_of_regions::Regions,
reference::{ReferenceQueryResult, ReferenceReader},
},
vcf::pipeline::{
message::{Cluster, MaskedRecords},
reader::VCFReader,
},
};
use super::Message;
pub mod cluster;
mod iterator;
/// Pipeline stage that turns a stream of VCF records into [`Message`]s.
///
/// All fields are consumed by `run`, which takes `self` by value.
pub struct Clusterizer {
/// Source of VCF records; handed to `find_clusters` in `run`.
input: VCFReader,
/// Optional set of regions, forwarded unchanged to `find_clusters`.
targets: Option<Regions>,
/// Reference reader used to fetch the (padded) sequence around each cluster.
reference: ReferenceReader,
/// Optional indexed BAM reader, forwarded unchanged to `find_clusters`.
bam_f: Option<rust_htslib::bam::IndexedReader>,
/// Channel on which the resulting messages are sent.
output: mpsc::Sender<Message>,
/// Clustering strategy and aligner configuration.
settings: ClusterSettings,
}
// Command-line settings for the clustering stage.
// Both members are flattened into the enclosing clap command.
#[derive(clap::Args, Debug, Default)]
pub struct ClusterSettings {
// Strategy settings; passed to `find_clusters` by `Clusterizer::run`.
#[command(flatten)]
pub cluster_strategy: cluster::ClusteringSettings,
// Aligner arguments; `Clusterizer::run` builds the `AlignmentOrchestrator` from them
// and reads `padding` from them when preparing a cluster.
#[command(flatten)]
pub aligner: CliAlignmentArgs,
}
impl Clusterizer {
pub(super) fn new(
input: VCFReader,
reference: ReferenceReader,
targets: Option<Regions>,
bam_f: Option<rust_htslib::bam::IndexedReader>,
output: mpsc::Sender<Message>,
settings: ClusterSettings,
) -> Self {
Self {
input,
targets,
reference,
bam_f,
output,
settings,
}
}
pub async fn run(self) {
let aligner = match AlignmentOrchestrator::try_from(&self.settings.aligner) {
Ok(aligner) => aligner,
Err(e) => {
error!("Cannot initialize aligners: {e}");
return;
}
};
let strategy = self.settings.cluster_strategy;
let clusters = find_clusters(strategy, self.input, self.targets, self.bam_f);
pin!(clusters);
while let Some(e) = clusters.next().await {
let m = match e {
ClusterOrRecords::Cluster(records) => {
let cluster = Cluster::new(bitvec::bitvec![1; records.len()]);
let masked = cluster.apply_to_records(&records);
match prepare_cluster(&self.reference, masked, self.settings.aligner.padding)
.await
{
Ok(Some(prepared)) => {
match aligner.get_or_compute_alignment(
self.reference.get_name(),
&prepared.reference_region.clone(),
prepared.cluster_region,
prepared.query.clone(),
) {
Ok(pending) => Message::cluster(
records,
vec![(
cluster,
prepared.reference_region.start().clone(),
prepared.query.sequences,
pending,
)],
),
Err(e) => {
error!(
"Could not get or start the computation of an alignment: {e}"
);
Message::passthrough(records)
}
}
}
Ok(None) => Message::passthrough(records),
Err(err) => {
error!(
"Could not prepare cluster for realignment: {err}. Passing records through.",
);
Message::passthrough(records)
}
}
}
ClusterOrRecords::Records(records) => Message::passthrough(records),
ClusterOrRecords::SequenceDone => {
aligner.clear_cache();
continue;
}
};
self.output.send(m).await.expect("Channel closed !?");
}
}
}
/// Result of `prepare_cluster`: everything needed to request an alignment for one cluster.
#[derive(Clone, Debug)]
struct PreparedCluster {
/// Reference/query sequence pair together with the alignment ranges inside them.
query: AlignmentQuery,
/// Region actually returned by the reference reader (the cluster region plus padding).
reference_region: GenomeRegion,
/// Region spanned by the records of the cluster, as computed by `extract_region`.
cluster_region: GenomeRegion,
}
/// Fetches the padded reference around `cluster` and builds the matching query sequence.
///
/// The region covered by the records is requested from `reference` with `padding` bases on
/// both sides. The records are then applied to the returned reference sequence to obtain
/// the query sequence, and the alignment ranges are set so that the padding actually
/// obtained on either side lies outside of them.
///
/// Returns `Ok(None)` when the reference reader yields no sequence for the region.
///
/// # Errors
///
/// Propagates failures from `extract_region`, from the reference lookup and from
/// `apply_mutations`.
async fn prepare_cluster(
    reference: &ReferenceReader,
    cluster: MaskedRecords<'_, '_>,
    padding: usize,
) -> Result<Option<PreparedCluster>, Error> {
    let cluster_region = extract_region(cluster)?;
    let fetched = reference
        .get_seq(cluster_region.clone(), padding, padding)
        .await?;
    let (padded_region, ref_seq) = match fetched {
        Some(ReferenceQueryResult {
            region, sequence, ..
        }) => (region, sequence),
        None => return Ok(None),
    };
    // The span is entered only after the last await point of this function.
    let _span = info_span!("prepare_cluster", pos = %cluster_region).entered();

    // The padding actually obtained may differ from the padding requested.
    let ref_len = ref_seq.len();
    let left_pad = cluster_region.start().position_0() - padded_region.start().position_0();
    let right_pad = ref_len - left_pad - cluster_region.len().unwrap();

    let mutated = apply_mutations(
        &ref_seq,
        padded_region.start().position_0(),
        cluster.masked_iter(),
    )?;

    // Both sequences share the same left padding; the right padding is measured from the
    // end of each sequence.
    let range_start = AlignmentCoordinates::new(left_pad, left_pad);
    let range_end = AlignmentCoordinates::new(ref_len - right_pad, mutated.len() - right_pad);
    let ranges = AlignmentRange::new_offset_limit(range_start, range_end);

    let sequences = SequencePair {
        reference: ref_seq,
        query: mutated,
    };
    Ok(Some(PreparedCluster {
        query: AlignmentQuery { sequences, ranges },
        reference_region: padded_region,
        cluster_region,
    }))
}
/// Computes the region spanned by the masked records of `cluster`.
///
/// The region starts at the smallest `pos()` and ends at the largest `end()` of the records.
/// The contig name is taken from the first record that carries a rid.
///
/// # Errors
///
/// Fails if a rid cannot be resolved to a name, if the start or end does not fit into a
/// `usize`, or if no record carries a rid.
fn extract_region(cluster: MaskedRecords<'_, '_>) -> anyhow::Result<GenomeRegion> {
    let mut min_start = i64::MAX;
    let mut max_end = 0;
    let mut contig_name = None;

    for record in cluster.masked_iter() {
        min_start = min_start.min(record.pos());
        max_end = max_end.max(record.end());

        // Only the first available rid is resolved.
        if contig_name.is_some() {
            continue;
        }
        let Some(rid) = record.rid() else {
            continue;
        };
        let name = record.header().rid2name(rid)?;
        contig_name = Some(ContigName::new(name));
    }

    let start = usize::try_from(min_start)?;
    let end = usize::try_from(max_end)?;
    let contig = contig_name.context("No rid present in records??")?;

    let anchor = GenomePosition::new_0(contig, start);
    Ok(GenomeRegion::new_bounded(anchor, end - start))
}
/// Builds the query sequence by substituting the records' ALT alleles into `reference`.
///
/// `reference` holds the bases starting at 0-based position `reference_start`. The records
/// are applied in iteration order: the reference bases between the previous record's end and
/// the current record's start are copied, followed by the record's first ALT allele
/// (further ALT alleles are ignored). The reference bases after the last applied record are
/// appended at the end.
///
/// A record is skipped with a warning when it starts before the end of the previously
/// applied record (or before `reference_start`), when it extends past the end of
/// `reference`, or when its end lies before its start.
///
/// # Errors
///
/// Fails if a record's position or end does not fit into a `usize`, or if a record that
/// would be applied has no ALT allele.
#[instrument(name = "build_query_sequence", skip_all)]
fn apply_mutations<'a>(
    reference: &[u8],
    reference_start: usize,
    mutations: impl Iterator<Item = &'a bcf::Record>,
) -> anyhow::Result<ImmutableSequence> {
    // The query has roughly the length of the reference; reserve once up front.
    let mut result = Vec::with_capacity(reference.len());
    let mut last_pos = reference_start;
    let reference_end = reference_start + reference.len();
    for m in mutations {
        let pos = usize::try_from(m.pos())?;
        let end = usize::try_from(m.end())?;
        if pos < last_pos {
            warn!(
                "Overlapping Mutations -- last mutation ends at {last_pos}, but this starts at {}",
                pos
            );
            continue;
        }
        if end > reference_end || pos > end {
            warn!(
                "Skipping mutation spanning {pos}..{end}: it is malformed or extends past the end of the reference at {reference_end}"
            );
            continue;
        }
        // Index 0 is the REF allele; a record without ALT has no entry at index 1, so
        // report it as an error instead of panicking on an out-of-bounds index.
        let alleles = m.alleles();
        let alt = alleles
            .get(1)
            .copied()
            .with_context(|| format!("Record at position {pos} has no ALT allele"))?;
        result.extend_from_slice(&reference[last_pos - reference_start..pos - reference_start]);
        result.extend_from_slice(alt);
        last_pos = end;
    }
    result.extend_from_slice(&reference[last_pos - reference_start..]);
    Ok(result.into())
}