use std::{collections::VecDeque, time::Duration};
use anyhow::bail;
use create_records::create_records;
use rust_htslib::bcf::{self, Record};
use tokio::{sync::mpsc, time::timeout};
use tracing::{Span, error, instrument, warn};
use tracing_indicatif::span_ext::IndicatifSpanExt;
use crate::{
common::{
SequencePair,
aligner::{
self,
result::{
AlignmentFailure, SoftFailureReason, TwitcherAlignment, TwitcherAlignmentCase,
},
},
},
counter,
vcf::pipeline::writer::region_writer::RegionWriter,
};
use super::Message;
mod create_records;
pub mod region_writer;
/// Consumes [`Message`]s from a channel and writes the resulting records.
pub struct Writer {
/// Channel the messages are received from; `run` stops once it yields `None`.
input: mpsc::Receiver<Message>,
/// Main record output, wrapped in a [`SortingWriter`].
output: SortingWriter,
/// Optional second output; it is handed only the newly created records.
region_output: Option<RegionWriter>,
}
impl Writer {
/// Creates a writer that reads messages from `input` and writes records to `output`.
///
/// `output` is wrapped in a [`SortingWriter`]; `region_output`, when present, additionally
/// receives every newly created record.
pub(super) fn new(
    input: mpsc::Receiver<Message>,
    output: OutputWriter,
    region_output: Option<RegionWriter>,
) -> Self {
    // 1000 is passed as the `buf_coord_len` of the sorting buffer.
    let output = SortingWriter::new(1000, output);
    Self {
        input,
        output,
        region_output,
    }
}
#[instrument(skip_all, fields(indicatif.pb_show = true))]
pub async fn run(mut self) -> anyhow::Result<()> {
loop {
let r = timeout(Duration::from_millis(100), self.input.recv()).await;
match r {
Ok(Some(m)) => match self.handle_message(m).await {
Ok(()) => {}
Err(e) => {
error!("Error in writer: {e}");
}
},
Ok(None) => break, Err(_) => {} }
Self::tick_progress();
}
if let Some(mut rw) = self.region_output {
rw.flush().await?;
}
Ok(())
}
async fn handle_message(&mut self, m: Message) -> anyhow::Result<()> {
if m.is_passthrough() {
self.write_unchanged_records(m.orig_records)?;
} else {
let mut ts_alignments = Vec::new();
for (cluster, ref_start, sequences, mut rx) in m.clusters {
'retry: loop {
let result =
match tokio::time::timeout(Duration::from_millis(100), rx.recv()).await {
Ok(res) => res?,
Err(_elapsed) => {
Self::tick_progress();
continue 'retry;
}
};
match &*result {
Ok(realignment) => match &realignment.result {
TwitcherAlignmentCase::FoundTS { .. } => {
ts_alignments.push((
cluster,
ref_start,
sequences,
realignment.clone(),
));
}
TwitcherAlignmentCase::NoTS { .. } => {}
},
Err(AlignmentFailure::SoftFailure {
reason: SoftFailureReason::OutOfMemory | SoftFailureReason::Timeout(_),
}) => {}
Err(AlignmentFailure::SoftFailure {
reason: SoftFailureReason::Other(error),
}) => {
warn!("Soft error: {error}");
}
Err(AlignmentFailure::Error { error }) => {
error!("{error}");
}
}
break 'retry;
}
}
if ts_alignments.is_empty() {
self.write_unchanged_records(m.orig_records)?;
} else if ts_alignments.len() == 1 {
self.write_one_alignment(&m.orig_records, ts_alignments)
.await?;
} else {
warn!("Received complex cluster, not yet implemented!");
self.write_unchanged_records(m.orig_records)?;
}
}
Ok(())
}
async fn write_one_alignment(
&mut self,
orig_records: &[Record],
ts_alignments: Vec<(
super::message::Cluster,
crate::common::coords::GenomePosition,
SequencePair,
TwitcherAlignment,
)>,
) -> Result<(), anyhow::Error> {
let (cluster, ref_start, sequences, realignment) = &ts_alignments[0];
let unchanged_records: Vec<&bcf::Record> = cluster
.mask
.iter_zeros()
.map(|ix| &orig_records[ix])
.collect();
let TwitcherAlignment {
result:
TwitcherAlignmentCase::FoundTS {
alignment_with_ts, ..
},
stats,
} = realignment
else {
bail!("Implementation error")
};
let replaced_records = cluster.apply_to_records(orig_records);
let new_records = create_records(
alignment_with_ts.iter_compact_cloned(),
(stats.reference_offset(), &sequences.reference),
(stats.query_offset(), &sequences.query),
Some(replaced_records),
&self.output.inner.inner().empty_record(),
ref_start,
)?;
if let Some(rw) = &mut self.region_output {
rw.write(new_records.iter().map(|(r, _)| r)).await?;
}
let mut result = new_records;
result.extend(
unchanged_records
.into_iter()
.cloned()
.map(|rec| (rec, RecordProperties::OldRecord)),
);
result.sort_by_key(|(rec, _)| rec.pos());
for (rec, prop) in result {
self.output.write(rec, prop)?;
}
Ok(())
}
/// Updates the progress bar of the current span from the alignment counters.
///
/// The bar length is the `alignments` counter, its position the `alignments.finished`
/// counter, and the message shows how many alignments are currently running.
fn tick_progress() {
    let bar = Span::current();
    let queued = counter!("alignments").get();
    let finished = counter!("alignments.finished").get();
    bar.pb_set_length(queued as u64);
    bar.pb_set_position(finished as u64);
    let running = aligner::RUNNING.load(std::sync::atomic::Ordering::Relaxed);
    bar.pb_set_message(&format!("({running} alignments running)"));
}
/// Writes `records` to the main output, in order, each tagged as
/// [`RecordProperties::OldRecord`].
///
/// # Errors
///
/// Stops at, and returns, the first write error.
fn write_unchanged_records(
    &mut self,
    records: Vec<Record>,
) -> Result<(), rust_htslib::errors::Error> {
    records
        .into_iter()
        .try_for_each(|rec| self.output.write(rec, RecordProperties::OldRecord))
}
}
/// Buffers records, keeps the buffer ordered by `(rid, pos)` and forwards records to the
/// inner [`OutputWriter`] once they lie far enough behind the last buffered record.
pub struct SortingWriter {
/// Pending records together with their properties, kept ordered by `(rid, pos)`.
buf: VecDeque<(Record, RecordProperties)>,
/// Coordinate distance a record must lie behind the last buffered record (on the same
/// `rid`) before it is forwarded.
buf_coord_len: i64,
/// Destination of the forwarded records.
inner: OutputWriter,
}
impl SortingWriter {
/// Creates an empty sorting buffer in front of `inner`.
///
/// `buf_coord_len` is the coordinate distance a record has to lie behind the last buffered
/// record before it is forwarded.
fn new(buf_coord_len: i64, inner: OutputWriter) -> Self {
    let buf = VecDeque::new();
    Self {
        buf,
        buf_coord_len,
        inner,
    }
}
/// Adds `record` to the buffer at its ordered position and forwards every record that is
/// now far enough behind.
///
/// # Errors
///
/// Returns the error of the first forwarded record that could not be written.
fn write(
    &mut self,
    record: Record,
    properties: RecordProperties,
) -> Result<(), rust_htslib::errors::Error> {
    let entry = (record, properties);
    // Append first, then let the new entry move forward to where it belongs.
    self.buf.push_back(entry);
    self.swim_last();
    self.flush()
}
fn swim_last(&mut self) {
if self.buf.len() <= 1 {
return;
}
let mut ix = self.buf.len() - 1;
while ix >= 1
&& (self.buf[ix].0.rid(), self.buf[ix].0.pos())
< (self.buf[ix - 1].0.rid(), self.buf[ix - 1].0.pos())
{
self.buf.swap(ix, ix - 1);
ix -= 1;
}
}
fn flush(&mut self) -> Result<(), rust_htslib::tpool::Error> {
let last = self
.buf
.back()
.map(|(rec, _)| (rec.rid().unwrap(), rec.pos()));
let len = self.buf_coord_len;
self.flush_while(|(rec, _)| {
last.is_some_and(|(last_rid, last_pos)| {
(last_rid, last_pos - len) > (rec.rid().unwrap(), rec.pos())
})
})
}
/// Pops records from the front of the buffer and writes them to the inner writer for as
/// long as `condition` holds for the front record.
///
/// Each write runs inside `tokio::task::block_in_place`.
///
/// # Errors
///
/// Stops at the first failed write; the record that failed has already been removed from
/// the buffer.
fn flush_while(
    &mut self,
    condition: impl Fn(&(Record, RecordProperties)) -> bool,
) -> Result<(), rust_htslib::tpool::Error> {
    while let Some(front) = self.buf.front() {
        if !condition(front) {
            break;
        }
        let Some((record, properties)) = self.buf.pop_front() else {
            break;
        };
        tokio::task::block_in_place(|| self.inner.write(record, properties))?;
    }
    Ok(())
}
}
/// Writes out whatever is still buffered when the writer goes away.
impl Drop for SortingWriter {
    fn drop(&mut self) {
        // `drop` cannot return an error, so a failed write can only be logged.
        let outcome = self.flush_while(|_| true);
        if let Some(err) = outcome.err() {
            error!("Could not flush all the records: {err}");
        }
    }
}
/// Tag attached to every record on its way to the output; the output writers decide from it
/// whether a record is written, skipped or treated as a point of interest.
enum RecordProperties {
/// A record taken over unchanged from the input.
OldRecord,
/// Any record that is not an `OldRecord`; never constructed in this file — presumably
/// produced by `create_records` (TODO confirm).
Realigned {
// Not read anywhere in this file (every match uses `..`), hence the lint allowance.
#[allow(unused)]
has_ts: bool,
},
}
/// The two ways records can reach the underlying [`bcf::Writer`].
pub enum OutputWriter {
/// Writes straight to `inner`; with `only_ts` set, `OldRecord` records are skipped.
Native { inner: bcf::Writer, only_ts: bool },
/// Routes every record through a [`FilteredOutputWriter`].
Buffered(FilteredOutputWriter),
}
impl OutputWriter {
pub fn new_native(inner: bcf::Writer, only_ts: bool) -> Self {
Self::Native { inner, only_ts }
}
pub fn new_buffered(inner: bcf::Writer, plus_minus: i64) -> Self {
Self::Buffered(FilteredOutputWriter {
buf: VecDeque::new(),
out: inner,
last_keep_pos: None,
current_pos: 0,
max_distance: plus_minus,
})
}
fn inner(&self) -> &bcf::Writer {
match self {
OutputWriter::Native { inner, .. } => inner,
OutputWriter::Buffered(filtered_output_writer) => &filtered_output_writer.out,
}
}
fn write(
&mut self,
record: Record,
properties: RecordProperties,
) -> Result<(), rust_htslib::errors::Error> {
match (self, properties) {
(OutputWriter::Native { only_ts: true, .. }, RecordProperties::OldRecord) => {}
(
OutputWriter::Native {
inner,
only_ts: false,
},
RecordProperties::OldRecord,
)
| (OutputWriter::Native { inner, .. }, RecordProperties::Realigned { .. }) => {
inner.write(&record)?;
}
(OutputWriter::Buffered(filtered_output_writer), RecordProperties::OldRecord) => {
filtered_output_writer.write(record, false)?;
}
(
OutputWriter::Buffered(filtered_output_writer),
RecordProperties::Realigned { .. },
) => {
filtered_output_writer.write(record, true)?;
}
}
Ok(())
}
}
/// Writes only the records that lie within `max_distance` of a record passed with
/// `keep == true`; other records are held back in a buffer and eventually discarded.
///
/// NOTE(review): only positions are compared, never the contig id — confirm that the window
/// cannot carry over from one contig to the next.
pub struct FilteredOutputWriter {
/// Records held back because no kept record is close enough (yet).
buf: VecDeque<Record>,
/// Destination of the records that pass the filter.
out: bcf::Writer,
/// Position of the most recent kept record; reset to `None` once a record beyond the window
/// is seen.
last_keep_pos: Option<i64>,
/// Position of the record most recently passed to `write`.
current_pos: i64,
/// Half-width of the window around a kept record.
max_distance: i64,
}
impl FilteredOutputWriter {
/// Writes `record` if it lies within `max_distance` of a kept record, otherwise holds it
/// back.
///
/// With `keep` set, the record's position becomes the new kept position and the buffered
/// records inside the backward window are written first. The record itself is then written
/// if it is at most `max_distance` past the kept position. The bound is inclusive, matching
/// the backward window, so a kept record is written even when `max_distance` is 0.
///
/// A record beyond the window clears the kept position and is buffered. While no kept
/// position is set, records more than `max_distance` behind the current one are discarded
/// before buffering.
///
/// # Errors
///
/// Propagates errors from writing to the underlying writer.
fn write(&mut self, record: Record, keep: bool) -> Result<(), rust_htslib::errors::Error> {
    self.current_pos = record.pos();
    if keep {
        self.last_keep_pos = Some(self.current_pos);
        self.flush_relevant_from_buf()?;
    }
    match self.last_keep_pos {
        // Saturating, like the subtraction in `discard_stale_records`, so a huge
        // `max_distance` cannot overflow.
        Some(ts_pos) if self.current_pos <= ts_pos.saturating_add(self.max_distance) => {
            self.out.write(&record)?;
        }
        Some(_) => {
            // Past the window of the last kept record: go back to buffering.
            self.last_keep_pos = None;
            self.buf.push_back(record);
        }
        None => {
            self.discard_stale_records();
            self.buf.push_back(record);
        }
    }
    Ok(())
}
/// Empties the buffer, writing the records that are at most `max_distance` before the kept
/// position and dropping the rest.
///
/// Does nothing while no kept position is set.
///
/// # Errors
///
/// Stops at the first failed write; the records still buffered at that point stay buffered.
fn flush_relevant_from_buf(&mut self) -> Result<(), rust_htslib::errors::Error> {
    if let Some(ts_pos) = self.last_keep_pos {
        while let Some(record) = self.buf.pop_front() {
            let within_window = record.pos() >= ts_pos - self.max_distance;
            if within_window {
                self.out.write(&record)?;
            }
        }
    }
    Ok(())
}
/// Drops records from the front of the buffer that are more than `max_distance` behind the
/// current position.
///
/// Stops at the first record that is still within range.
fn discard_stale_records(&mut self) {
    let oldest_allowed = self.current_pos.saturating_sub(self.max_distance);
    while self
        .buf
        .front()
        .is_some_and(|rec| rec.pos() < oldest_allowed)
    {
        self.buf.pop_front();
    }
}
}