use std::fmt;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc;
use anyhow::{Result, anyhow};
use clap::{Args, ValueEnum};
use crossbeam_channel::{Receiver, Sender};
use noodles::sam::Header;
use crate::collector::Collector;
use crate::commands::alignment::{AlignmentCollector, MultiAlignmentOptions};
use crate::commands::basic::BasicCollector;
use crate::commands::command::Command;
use crate::commands::common::{InputOptions, OptionalReferenceOptions, OutputOptions};
use crate::commands::error::{ErrorCollector, MultiErrorOptions};
use crate::commands::gcbias::{GcBiasCollector, MultiGcBiasOptions};
use crate::commands::hybcap::{HybCapCollector, MultiHybCapOptions};
use crate::commands::isize::{InsertSizeCollector, MultiIsizeOptions};
use crate::commands::wgs::{MultiWgsOptions, WgsCollector};
use crate::fasta::Fasta;
use crate::progress::ProgressLogger;
use crate::sam::alignment_reader::AlignmentReader;
use crate::sam::record_utils::derive_sample;
use crate::sam::riker_record::{RikerRecord, RikerRecordRequirements};
/// Number of record slots in each pooled batch buffer used by the parallel
/// path (`run_parallel` sizes every pooled `Vec` to this length).
const BATCH_SIZE: usize = 128;
/// Number of batch buffers `run_parallel` pre-allocates into the recycling
/// pool; it also feeds the capacity bound of the bounded work queue.
const NUM_BATCHES_POOLED: usize = 16;
// Command-line arguments for the `multi` subcommand, which runs several
// collectors over one input (see `Command::execute` for `Multi`).
//
// NOTE: plain `//` comments are used here on purpose. clap's derive macros
// turn `///` doc comments into user-visible help text, so adding them would
// change the generated CLI help.
#[derive(Args, Debug, Clone)]
#[command(
long_about,
after_long_help = "\
Examples:
riker multi -i input.bam -o out_prefix -r ref.fa
riker multi -i input.bam -o out_prefix -r ref.fa --tools alignment basic isize
riker multi -i input.bam -o out_prefix -r ref.fa --threads 4
riker multi -i input.bam -o out_prefix --tools hybcap --hybcap::targets t.bed --hybcap::baits b.bed"
)]
pub struct Multi {
// Input options; `input.input` is the path handed to the reader and to
// most collectors.
#[command(flatten)]
pub input: InputOptions,
// Output options; `output.output` is passed to every collector constructor.
#[command(flatten)]
pub output: OutputOptions,
// Optional reference. `execute` requires it for the error, gcbias and wgs
// tools and passes it through (when present) to the others that accept it.
#[command(flatten)]
pub reference: OptionalReferenceOptions,
// Tools to run; one or more values, defaulting to alignment, basic and
// isize. Duplicates are collapsed by `execute`, keeping first-seen order.
#[arg(
long,
num_args(1..),
default_values_t = [CollectorKind::Alignment, CollectorKind::Basic, CollectorKind::Isize],
help_heading = "Multi Command Options"
)]
pub tools: Vec<CollectorKind>,
// Total thread budget (default 2). `execute` rejects 0, uses the
// single-threaded path for 1 and the reader + worker-pool path otherwise.
#[arg(long, default_value_t = 2, help_heading = "Multi Command Options")]
pub threads: usize,
// Per-tool option groups; each is cloned and validated only when the
// corresponding tool is actually built (hybcap, isize, alignment, error,
// gcbias, wgs).
#[command(flatten)]
pub alignment_opts: MultiAlignmentOptions,
#[command(flatten)]
pub error_opts: MultiErrorOptions,
#[command(flatten)]
pub gcbias_opts: MultiGcBiasOptions,
#[command(flatten)]
pub hybcap_opts: MultiHybCapOptions,
#[command(flatten)]
pub isize_opts: MultiIsizeOptions,
#[command(flatten)]
pub wgs_opts: MultiWgsOptions,
}
impl Multi {
    /// Instantiates one boxed collector per entry of `kinds`, preserving order.
    ///
    /// Each tool's option group is cloned and validated immediately before its
    /// collector is constructed, so the first failing validation, FASTA load or
    /// constructor aborts the whole build and is returned to the caller.
    ///
    /// `header` is only consulted by the hybcap tool, to derive the sample name.
    ///
    /// # Panics
    ///
    /// Panics if `kinds` contains `Error`, `GcBias` or `Wgs` while no reference
    /// path is configured. `execute` rejects that combination before it calls
    /// this method, so the `unwrap()` calls below encode that invariant.
    fn build_collectors(
        &self,
        kinds: &[CollectorKind],
        header: &Header,
    ) -> Result<Vec<Box<dyn Collector>>> {
        let input = &self.input.input;
        let output = &self.output.output;
        let reference_path = &self.reference.reference;

        let mut built: Vec<Box<dyn Collector>> = Vec::with_capacity(kinds.len());
        for kind in kinds {
            let collector: Box<dyn Collector> = match kind {
                CollectorKind::Alignment => {
                    let opts = self.alignment_opts.clone().validate()?;
                    Box::new(AlignmentCollector::new(
                        input,
                        output,
                        reference_path.clone(),
                        &opts,
                    ))
                }
                CollectorKind::Basic => Box::new(BasicCollector::new(input, output)),
                CollectorKind::Error => {
                    // Reference presence is guaranteed by `execute`.
                    let ref_path = reference_path.as_ref().unwrap();
                    let reference = Fasta::from_path(ref_path)?;
                    // Fall back to the shared reference when the error tool
                    // was not given one of its own.
                    let mut error_opts = self.error_opts.clone();
                    if error_opts.error_reference.is_none() {
                        error_opts.error_reference = Some(ref_path.clone());
                    }
                    let opts = error_opts.validate()?;
                    Box::new(ErrorCollector::new(input, output, reference, &opts)?)
                }
                CollectorKind::GcBias => {
                    // Reference presence is guaranteed by `execute`.
                    let ref_path = reference_path.as_ref().unwrap();
                    let reference = Fasta::from_path(ref_path)?;
                    let opts = self.gcbias_opts.clone().validate()?;
                    Box::new(GcBiasCollector::new(input, output, reference, &opts))
                }
                CollectorKind::HybCap => {
                    let opts = self.hybcap_opts.clone().validate()?;
                    let sample = derive_sample(input, header);
                    // The reference is optional for this tool: load it only
                    // when a path was supplied.
                    let fasta = match reference_path.as_ref() {
                        Some(path) => Some(Fasta::from_path(path)?),
                        None => None,
                    };
                    Box::new(HybCapCollector::new(output, fasta, sample, &opts))
                }
                CollectorKind::Isize => {
                    let opts = self.isize_opts.clone().validate()?;
                    Box::new(InsertSizeCollector::new(input, output, &opts))
                }
                CollectorKind::Wgs => {
                    // Reference presence is guaranteed by `execute`.
                    let ref_path = reference_path.as_ref().unwrap();
                    let reference = Fasta::from_path(ref_path)?;
                    let opts = self.wgs_opts.clone().validate()?;
                    Box::new(WgsCollector::new(input, output, reference, &opts)?)
                }
            };
            built.push(collector);
        }
        Ok(built)
    }
}
impl Command for Multi {
fn execute(&self) -> Result<()> {
if self.threads == 0 {
return Err(anyhow!("--threads must be >= 1"));
}
let mut seen = Vec::new();
for kind in &self.tools {
if !seen.contains(kind) {
seen.push(*kind);
}
}
for kind in &seen {
match kind {
CollectorKind::Error if self.reference.reference.is_none() => {
return Err(anyhow!("Error collector requires --reference"));
}
CollectorKind::GcBias if self.reference.reference.is_none() => {
return Err(anyhow!("GC bias collector requires --reference"));
}
CollectorKind::Wgs if self.reference.reference.is_none() => {
return Err(anyhow!("WGS collector requires --reference"));
}
_ => {}
}
}
let reader = AlignmentReader::open(&self.input.input, self.reference.reference.as_deref())?;
let collectors = self.build_collectors(&seen, reader.header())?;
if self.threads > 1 {
run_parallel(reader, collectors, self.threads)?;
} else {
run_single_threaded(reader, collectors)?;
}
Ok(())
}
}
// Identifies one of the tools that `multi` can run (the values accepted by
// `--tools`).
//
// NOTE: plain `//` comments are used on purpose. clap's `ValueEnum` derive
// uses `///` doc comments on variants as help text for the possible values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum CollectorKind {
// Built as an `AlignmentCollector`.
Alignment,
// Built as a `BasicCollector`.
Basic,
// Built as an `ErrorCollector`; `execute` requires `--reference` for it.
#[value(name = "error")]
Error,
// Built as a `GcBiasCollector`; `execute` requires `--reference` for it.
#[value(name = "gcbias")]
GcBias,
// Built as a `HybCapCollector`; the reference is optional for it.
#[value(name = "hybcap")]
HybCap,
// Built as an `InsertSizeCollector`.
Isize,
// Built as a `WgsCollector`; `execute` requires `--reference` for it.
Wgs,
}
impl fmt::Display for CollectorKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CollectorKind::Alignment => write!(f, "alignment"),
CollectorKind::Basic => write!(f, "basic"),
CollectorKind::Error => write!(f, "error"),
CollectorKind::GcBias => write!(f, "gcbias"),
CollectorKind::HybCap => write!(f, "hybcap"),
CollectorKind::Isize => write!(f, "isize"),
CollectorKind::Wgs => write!(f, "wgs"),
}
}
}
/// A filled batch shared by reference count between the work items that
/// refer to it.
type Batch = Arc<RecyclableBatch>;
/// One unit of work: the index of the target collector slot and the batch it
/// should consume.
type WorkItem = (usize, Batch);
/// Sending half of the work queue (held by the reader thread).
type WorkTx = Sender<WorkItem>;
/// Receiving half of the work queue (cloned into every pool worker).
type WorkRx = Receiver<WorkItem>;
/// A pooled record buffer together with the channel that takes it back.
///
/// Dropping the batch sends the buffer down `return_tx`, which is how buffers
/// find their way back to the reader for reuse.
struct RecyclableBatch {
    /// Backing buffer; only the first `len` entries belong to this batch.
    records: Vec<RikerRecord>,
    /// Number of leading entries of `records` that were filled.
    len: usize,
    /// Channel the buffer is handed back on when the batch is dropped.
    return_tx: mpsc::Sender<Vec<RikerRecord>>,
}

impl RecyclableBatch {
    /// Returns the filled prefix of the buffer.
    fn records(&self) -> &[RikerRecord] {
        let Self { records, len, .. } = self;
        &records[..*len]
    }
}

impl Drop for RecyclableBatch {
    /// Moves the buffer out (leaving an empty `Vec` behind) and returns it to
    /// the pool. A failed send just means nobody is listening any more, so
    /// the buffer is simply freed.
    fn drop(&mut self) {
        let buffer = std::mem::take(&mut self.records);
        let _ = self.return_tx.send(buffer);
    }
}
/// Drives every collector over the whole input on the calling thread.
///
/// Collectors are initialised with the reader's header, fed every record, and
/// then finished. The progress logger is always finalised, even when reading
/// fails; in that case the read error is returned and no collector is finished.
fn run_single_threaded(
    mut reader: AlignmentReader,
    mut collectors: Vec<Box<dyn Collector>>,
) -> Result<()> {
    let header = reader.header().clone();
    collectors.iter_mut().try_for_each(|c| c.initialize(&header))?;

    let requirements = combined_requirements(&collectors);
    let mut progress = ProgressLogger::new("multi", "reads", 5_000_000);

    // Hold on to the outcome so the progress logger is closed out before any
    // read error propagates.
    let outcome = run_single_threaded_loop(
        &mut reader,
        &mut collectors,
        &requirements,
        &header,
        &mut progress,
    );
    progress.finish();
    outcome?;

    collectors.iter_mut().try_for_each(|c| c.finish())
}
/// Reads records one at a time into a single reused buffer and hands each to
/// every collector in order.
///
/// Returns `Ok(())` once the reader reports that no further record was filled,
/// or the first error raised by the reader or by a collector.
fn run_single_threaded_loop(
    reader: &mut AlignmentReader,
    collectors: &mut [Box<dyn Collector>],
    requirements: &RikerRecordRequirements,
    header: &Header,
    progress: &mut ProgressLogger,
) -> Result<()> {
    let mut record = reader.empty_record();
    loop {
        let filled = reader.fill_record(requirements, &mut record)?;
        if !filled {
            return Ok(());
        }
        progress.record_with(&record, header);
        collectors.iter_mut().try_for_each(|c| c.accept(&record, header))?;
    }
}
/// Returns the union of the record-field needs declared by all `collectors`,
/// starting from `RikerRecordRequirements::NONE`.
fn combined_requirements(collectors: &[Box<dyn Collector>]) -> RikerRecordRequirements {
    let mut combined = RikerRecordRequirements::NONE;
    for collector in collectors {
        combined = combined.union(collector.field_needs());
    }
    combined
}
/// Runs all collectors over the input using one reader thread plus a pool of
/// worker threads.
///
/// `threads` is the total thread budget: one thread reads records into pooled
/// batches, and the remaining `threads - 1` workers consume
/// `(collector index, batch)` work items. Each collector sits behind its own
/// `Mutex`, so at most one worker drives a given collector at a time.
///
/// Callers are expected to pass `threads >= 2` (checked by a `debug_assert!`).
/// Because that assertion is compiled out of release builds, the worker count
/// is additionally clamped to at least one: with zero workers nothing would
/// ever consume the work queue, and `threads - 1` would underflow for 0.
///
/// # Errors
///
/// * an error from initialising a collector;
/// * the reader thread's error (takes precedence over worker errors), or a
///   generic error if the reader thread panicked;
/// * otherwise the first worker error or worker panic, in spawn order;
/// * otherwise the first error from finishing a collector.
fn run_parallel(
    mut reader: AlignmentReader,
    mut collectors: Vec<Box<dyn Collector>>,
    threads: usize,
) -> Result<()> {
    debug_assert!(threads >= 2, "run_parallel requires at least 2 total threads");
    // Never let the pool be empty (or underflow) when the precondition is
    // violated in a build without debug assertions.
    let pool_workers = threads.saturating_sub(1).max(1);
    let header = reader.header().clone();
    for collector in &mut collectors {
        collector.initialize(&header)?;
    }
    let requirements = combined_requirements(&collectors);
    let slots: Vec<Mutex<Box<dyn Collector>>> = collectors.into_iter().map(Mutex::new).collect();
    // Every pooled batch is enqueued once per collector, so this bound leaves
    // room for all pooled batches to be queued for all collectors at once.
    let work_queue_bound = (NUM_BATCHES_POOLED + 1) * slots.len().max(1);
    let (work_tx, work_rx): (WorkTx, WorkRx) = crossbeam_channel::bounded(work_queue_bound);
    // Recycling pool: buffers flow reader -> workers -> (on drop) back here.
    let (pool_tx, pool_rx) = mpsc::channel::<Vec<RikerRecord>>();
    for _ in 0..NUM_BATCHES_POOLED {
        let mut vec: Vec<RikerRecord> = Vec::with_capacity(BATCH_SIZE);
        vec.resize_with(BATCH_SIZE, || reader.empty_record());
        pool_tx.send(vec).expect("pool send cannot fail: channel is unbounded and rx is alive");
    }
    // Shared stop flag: set by a failing worker (or here on reader failure)
    // and polled by the reader and the workers.
    let poison = AtomicBool::new(false);
    let n_collectors = slots.len();
    let header_ref = &header;
    std::thread::scope(|scope| -> Result<()> {
        let slots_ref: &[Mutex<Box<dyn Collector>>] = &slots;
        let poison_ref: &AtomicBool = &poison;
        let mut pool_handles = Vec::with_capacity(pool_workers);
        for _ in 0..pool_workers {
            let work_rx = work_rx.clone();
            pool_handles.push(
                scope.spawn(move || pool_worker_loop(work_rx, slots_ref, header_ref, poison_ref)),
            );
        }
        // Only the workers' clones may keep the receiving side alive.
        drop(work_rx);
        let requirements_ref = &requirements;
        let reader_handle = scope.spawn(move || {
            let reader_result = reader_thread_loop(
                &mut reader,
                header_ref,
                &work_tx,
                n_collectors,
                pool_tx,
                pool_rx,
                requirements_ref,
                poison_ref,
            );
            // Dropping the only sender disconnects the queue, which lets the
            // workers fall out of their receive loops once it is drained.
            drop(work_tx);
            reader_result
        });
        let reader_result = reader_handle.join().map_err(|_| anyhow!("reader thread panicked"))?;
        if let Err(e) = reader_result {
            // Reader failure wins: tell workers to stop, wait for them, and
            // discard whatever they report.
            poison.store(true, Ordering::Relaxed);
            for handle in pool_handles {
                let _ = handle.join();
            }
            return Err(e);
        }
        // Join every worker, remembering only the first failure.
        let mut first_error: Option<anyhow::Error> = None;
        for handle in pool_handles {
            match handle.join() {
                Ok(Ok(())) => {}
                Ok(Err(e)) => {
                    if first_error.is_none() {
                        first_error = Some(e);
                    }
                }
                Err(_) => {
                    if first_error.is_none() {
                        first_error = Some(anyhow!("pool thread panicked"));
                    }
                }
            }
        }
        if let Some(e) = first_error {
            return Err(e);
        }
        Ok(())
    })?;
    // All threads have been joined; take the collectors back and finish them.
    for slot in slots {
        let mut collector = slot.into_inner().unwrap();
        collector.finish()?;
    }
    Ok(())
}
/// Body of the reader thread: fills pooled buffers with records and dispatches
/// each filled batch once per collector.
///
/// Returns `Ok(())` when the input yields no more records, when the poison
/// flag is observed, when the pool channel is closed, or when dispatching is
/// refused; returns the reader's error if `fill_record` fails. The progress
/// logger is finalised on every exit path.
#[allow(
clippy::needless_pass_by_value,
clippy::too_many_arguments,
reason = "pool_tx and pool_rx move into this (scoped) thread: the reader \
is the only thread that receives from the pool, and we want the \
reader's handle on pool_tx to drop when the reader exits so \
in-flight RecyclableBatch Drops (which each carry a clone of \
pool_tx) become the last senders and the pool channel can \
close naturally on shutdown"
)]
fn reader_thread_loop(
reader: &mut AlignmentReader,
header: &Header,
work_tx: &WorkTx,
n_collectors: usize,
pool_tx: mpsc::Sender<Vec<RikerRecord>>,
pool_rx: mpsc::Receiver<Vec<RikerRecord>>,
requirements: &RikerRecordRequirements,
poison: &AtomicBool,
) -> Result<()> {
let mut progress = ProgressLogger::new("multi", "reads", 5_000_000);
// The loop's break value is the function's result; `progress.finish()` runs
// after it regardless of how the loop ended.
let result = 'outer: loop {
// Another thread signalled failure: stop reading.
if poison.load(Ordering::Relaxed) {
break Ok(());
}
// Block until a recycled buffer is available. A receive error means the
// pool channel has no senders left, which is treated as a clean stop.
let Ok(mut records) = pool_rx.recv() else {
break Ok(());
};
// Fill as many slots of the buffer as the reader can provide.
let mut len = 0;
while len < records.len() {
match reader.fill_record(requirements, &mut records[len]) {
// No record was filled (presumably end of input): stop filling.
Ok(false) => break,
Ok(true) => {
progress.record_with(&records[len], header);
len += 1;
}
// A read failure ends the whole loop with that error.
Err(e) => break 'outer Err(e),
}
}
// Nothing was read into this buffer: hand it back to the pool and stop.
if len == 0 {
let _ = pool_tx.send(records);
break Ok(());
}
// Wrap the buffer so that dropping the last reference returns it to the
// pool (see `Drop for RecyclableBatch`).
let batch = Arc::new(RecyclableBatch { records, len, return_tx: pool_tx.clone() });
if !dispatch_batch(work_tx, &batch, n_collectors, poison) {
break Ok(());
}
// Release the reader's own reference; the queued work items keep the
// batch alive until the workers are done with it.
drop(batch);
};
progress.finish();
result
}
/// Enqueues `batch` once for every collector index in `0..n_collectors`.
///
/// Returns `false` as soon as the poison flag is seen or a send fails (the
/// receiving side is gone), leaving the remaining indices un-dispatched.
/// Returns `true` when every index was queued, including the case of zero
/// collectors.
fn dispatch_batch(
    work_tx: &WorkTx,
    batch: &Batch,
    n_collectors: usize,
    poison: &AtomicBool,
) -> bool {
    // `all` short-circuits on the first index that cannot be dispatched, and
    // `&&` skips the send once the poison flag is set.
    (0..n_collectors).all(|slot| {
        !poison.load(Ordering::Relaxed) && work_tx.send((slot, Arc::clone(batch))).is_ok()
    })
}
/// Body of a pool worker: pulls `(collector index, batch)` items off the work
/// queue and feeds each batch to the collector in that slot.
///
/// Returns `Ok(())` when the queue is disconnected or when the poison flag is
/// seen after receiving an item. If a collector fails, the poison flag is set
/// and that error is returned.
///
/// # Panics
///
/// Panics if the slot's mutex is poisoned or if `idx` is out of range for
/// `slots`.
#[allow(
    clippy::needless_pass_by_value,
    reason = "each worker owns its own clone of the MPMC receiver so we can \
              drop the outer handle in run_parallel; passing by reference \
              would leave that handle alive and keep the queue open"
)]
fn pool_worker_loop(
    work_rx: WorkRx,
    slots: &[Mutex<Box<dyn Collector>>],
    header: &Header,
    poison: &AtomicBool,
) -> Result<()> {
    loop {
        // A receive error means the sender is gone and the queue is empty.
        let Ok((idx, batch)) = work_rx.recv() else {
            return Ok(());
        };
        // Some other thread already failed: abandon the item without
        // touching the collector.
        if poison.load(Ordering::Relaxed) {
            return Ok(());
        }
        // The slot lock is held for the whole batch.
        let mut collector = slots[idx].lock().unwrap();
        match collector.accept_multiple(batch.records(), header) {
            Ok(_) => {}
            Err(err) => {
                poison.store(true, Ordering::Relaxed);
                return Err(err);
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sam::riker_record::RikerRecord;
/// Test collector that counts accepted records and starts failing once the
/// count reaches `fail_after`.
struct FailingCollector {
// Number of records seen so far.
seen: u64,
// Record count at (and after) which `accept` returns an error.
fail_after: u64,
}
impl Collector for FailingCollector {
fn initialize(&mut self, _h: &Header) -> Result<()> {
Ok(())
}
// `run_parallel` drives collectors through `accept_multiple`, which this
// type does not override; presumably the trait's provided implementation
// forwards each record to `accept` — TODO confirm against `Collector`.
fn accept(&mut self, _r: &RikerRecord, _h: &Header) -> Result<()> {
self.seen += 1;
if self.seen >= self.fail_after {
return Err(anyhow!("synthetic failure after {} records", self.seen));
}
Ok(())
}
fn finish(&mut self) -> Result<()> {
Ok(())
}
fn name(&self) -> &'static str {
"failing"
}
// Asks for no record fields at all.
fn field_needs(&self) -> RikerRecordRequirements {
RikerRecordRequirements::NONE
}
}
/// A collector error raised on a pool worker must surface as the error
/// returned by `run_parallel`.
#[test]
fn run_parallel_propagates_collector_error() -> Result<()> {
use std::path::Path;
use noodles::bam;
use noodles::sam::Header;
use noodles::sam::alignment::RecordBuf;
use noodles::sam::alignment::io::Write as _;
use noodles::sam::alignment::record::Flags;
use noodles::sam::alignment::record::cigar::Op;
use noodles::sam::alignment::record::cigar::op::Kind;
use noodles::sam::alignment::record_buf::{Cigar, QualityScores, Sequence};
use noodles::sam::header::record::value::{Map, map::ReferenceSequence};
// Header with a single reference sequence, `chr1`, of length 10,000.
let header = Header::builder()
.add_reference_sequence(
"chr1",
Map::<ReferenceSequence>::new(std::num::NonZeroUsize::new(10_000).unwrap()),
)
.build();
// Write a temporary BAM with 2,000 records (50 bases each, 50M CIGAR).
// The inner scope drops the writer before the file is read back.
let tmp = tempfile::NamedTempFile::with_suffix(".bam")?;
{
let file = std::fs::File::create(tmp.path())?;
let mut writer = bam::io::Writer::new(std::io::BufWriter::new(file));
writer.write_header(&header)?;
let cigar: Cigar = [Op::new(Kind::Match, 50)].into_iter().collect();
for i in 0u32..2_000 {
// Start positions cycle through 1..=9,000.
let pos =
noodles::core::Position::new(usize::try_from(i).unwrap() % 9_000 + 1).unwrap();
let record = RecordBuf::builder()
.set_name(format!("r{i}").into_bytes())
.set_flags(Flags::empty())
.set_reference_sequence_id(0)
.set_alignment_start(pos)
.set_cigar(cigar.clone())
.set_sequence(Sequence::from(vec![b'A'; 50]))
.set_quality_scores(QualityScores::from(vec![30u8; 50]))
.build();
writer.write_alignment_record(&header, &record)?;
}
}
let reader = AlignmentReader::open(Path::new(tmp.path()), None)?;
// One collector that fails at its 100th record.
let collectors: Vec<Box<dyn Collector>> =
vec![Box::new(FailingCollector { seen: 0, fail_after: 100 })];
// Two total threads: one reader plus one pool worker.
let result = run_parallel(reader, collectors, 2);
let err = result.expect_err("run_parallel should propagate the collector error");
assert!(
err.to_string().contains("synthetic failure"),
"expected the failing collector's error, got: {err}"
);
Ok(())
}
}