mod args;
mod engine;
mod input;
mod workflow;

use std::fs::File;
use std::io::Write;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};

use anyhow::{anyhow, Result};
use crossbeam_channel::bounded;
use log::info;
use rayon::prelude::*;

use redicat_lib::bam2mtx::anndata_output::{AnnDataConfig, AnnDataConverter, BuilderMsg};
use redicat_lib::bam2mtx::barcode::BarcodeProcessor;
use redicat_lib::bam2mtx::processor::BamProcessorConfig;
use redicat_lib::utils;

use crate::commands::{common, is_standard_contig};

pub use args::Bam2MtxArgs;

use engine::{OptimizedChunkProcessor, SkippedSite};
use input::{chunk_positions, filter_positions_by, read_positions};
use workflow::prepare_positions_file;
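
/// Runs the end-to-end bam2mtx workflow: prepares the position list, piles up
/// each chunk of positions in parallel, and streams per-barcode counts into an
/// AnnData output file.
///
/// A minimal invocation sketch (hypothetical: it assumes `Bam2MtxArgs` derives
/// `clap::Parser` and that these flag names exist; neither is shown in this
/// module):
///
/// ```ignore
/// let args = Bam2MtxArgs::parse_from([
///     "bam2mtx",
///     "--bam", "sample.bam",
///     "--barcodes", "barcodes.tsv",
///     "--output", "counts.h5ad",
/// ]);
/// run_bam2mtx(args)?;
/// ```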
pub fn run_bam2mtx(args: Bam2MtxArgs) -> Result<()> {
    let start_time = Instant::now();
    info!("Starting bam2mtx processing for {:?}", args.bam);
    if args.two_pass {
        info!("Two-pass mode enabled; running the bulk pass first");
    }
    let positions_path = prepare_positions_file(&args)?;
    info!("Using position list: {:?}", positions_path);
    let active_threads = common::configure_global_thread_pool(args.threads)?;
    utils::make_parent_dirs(&args.output)?;
    info!("Rayon pool configured with {} threads", active_threads);
    info!("Loading cell barcodes from {:?}", args.barcodes);
    let barcode_processor = Arc::new(BarcodeProcessor::from_file(&args.barcodes)?);
    info!("Loaded {} valid barcodes", barcode_processor.len());
info!("Reading positions...");
let mut positions = read_positions(&positions_path)?;
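    // Unless the user asked for all contigs, restrict the position list to
    // canonical contigs (whatever is_standard_contig accepts).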
    if !args.all_contigs {
        let before = positions.len();
        positions = filter_positions_by(positions, is_standard_contig);
        info!(
            "Filtered positions to canonical contigs: {} → {} entries",
            before,
            positions.len()
        );
    }
    if positions.is_empty() {
        return Err(anyhow!(
            "Position list {:?} contains no usable entries",
            positions_path
        ));
    }
    let manifest_positions = positions.len();
    let chunk_size = args.chunk_size();
    let chunk_size_max_depth = args.chunk_size_max_depth();
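    // Output-side configuration: the builder batches incoming chunks at the
    // same granularity used for pileup; triplet_spill_nnz presumably bounds
    // the in-memory triplet buffer before it spills.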
    let adata_config = AnnDataConfig {
        stranded: args.stranded,
        compression: Some("gzip".to_string()),
        threads: active_threads,
        chunk_size,
        matrix_density: args.matrix_density,
        batch_size: chunk_size,
        total_positions: manifest_positions,
        triplet_spill_nnz: (chunk_size as usize * 16).max(500_000),
    };
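    // Pileup-side configuration: quality filters plus the BAM tags that
    // identify UMIs and cell barcodes.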
    let processor_config = BamProcessorConfig {
        min_mapping_quality: args.min_mapq,
        min_base_quality: args.min_baseq,
        stranded: args.stranded,
        max_depth: args.max_depth,
        umi_tag: args.umi_tag.clone(),
        cell_barcode_tag: args.cb_tag.clone(),
    };
    info!(
        "Chunking positions with chunk_size={} and chunk_size_max_depth={}",
        chunk_size, chunk_size_max_depth
    );
    let chunks = chunk_positions(positions, chunk_size, chunk_size_max_depth);
    let total_near_max_positions: usize = chunks
        .iter()
        .map(|chunk| chunk.near_max_depth_count())
        .sum();
    info!(
        "Chunked positions into {} batches ({} near-max-depth positions)",
        chunks.len(),
        total_near_max_positions
    );
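    // One shared processor owns the immutable state (header, config); each
    // Rayon worker opens its own BAM reader later via open_reader().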
    let processor = OptimizedChunkProcessor::new(
        args.bam.clone(),
        processor_config,
        Arc::clone(&barcode_processor),
    )?;
    let contig_names = processor.contig_names();
info!("Processing chunks with streaming parallel pipeline...");
let processing_start = Instant::now();
let total_chunks = chunks.len();
let processed_chunks = AtomicUsize::new(0);
let total_positions = AtomicUsize::new(0);
let near_max_remaining = AtomicUsize::new(total_near_max_positions);
let long_report_interval = Duration::from_secs(1800);
let last_timed_report_secs = AtomicU64::new(0);
let last_percent_bucket = AtomicUsize::new(0);
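    // A bounded channel provides backpressure: once the builder falls behind
    // by roughly one batch per worker, producers block instead of buffering
    // unbounded amounts of pileup data.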
    let channel_bound = active_threads.max(4);
    let (tx, rx) = bounded::<BuilderMsg>(channel_bound);
    let skipped_records = parking_lot::Mutex::new(Vec::<SkippedSite>::new());
    let output_path = args.output.clone();
    let builder_config = adata_config.clone();
    let builder_barcodes = Arc::clone(&barcode_processor);
    let builder_contigs = Arc::clone(&contig_names);
    let builder_output = output_path.clone();
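    // Dedicated builder thread: drains batches from the channel and assembles
    // the sparse matrix while the Rayon workers keep producing.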
    let builder_handle = std::thread::Builder::new()
        .name("anndata-builder".to_string())
        .spawn(move || -> Result<()> {
            let converter =
                AnnDataConverter::new(builder_config, builder_barcodes, builder_contigs);
            let adata = converter.convert_streaming_channel(rx, &builder_output)?;
            converter.write_to_file(&adata, &builder_output)?;
            Ok(())
        })?;
    let send_error: parking_lot::Mutex<Option<anyhow::Error>> = parking_lot::Mutex::new(None);
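    // Fan the chunks out across the Rayon pool. for_each_init gives every
    // worker thread its own long-lived BAM reader, and the first error
    // observed is stashed in send_error so remaining workers can bail early.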
    chunks.into_par_iter().enumerate().for_each_init(
        || {
            processor
                .open_reader()
                .expect("failed to open BAM reader for Rayon thread")
        },
        |reader, (_idx, chunk)| {
            if send_error.lock().is_some() {
                return;
            }
            let chunk_near = chunk.near_max_depth_count();
            let result = processor.process_chunk(&chunk, reader);
            let (data, skipped) = match result {
                Ok(r) => r,
                Err(e) => {
                    let mut err = send_error.lock();
                    if err.is_none() {
                        *err = Some(e);
                    }
                    return;
                }
            };
            if !skipped.is_empty() {
                skipped_records.lock().extend(skipped);
            }
            if !data.is_empty() {
                total_positions.fetch_add(data.len(), Ordering::Relaxed);
            }
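            // fetch_sub returns the value *before* the subtraction, so take
            // chunk_near off again to report the post-update remainder.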
            let near_remaining = if chunk_near == 0 {
                near_max_remaining.load(Ordering::Relaxed)
            } else {
                near_max_remaining
                    .fetch_sub(chunk_near, Ordering::Relaxed)
                    .saturating_sub(chunk_near)
            };
            let completed = processed_chunks.fetch_add(1, Ordering::Relaxed) + 1;
            let elapsed = processing_start.elapsed();
            let elapsed_secs = elapsed.as_secs();
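            // Progress is logged at most once per 5% bucket and additionally
            // every 30 minutes; compare_exchange guarantees a single winner
            // per bucket/interval so concurrent workers do not double-log.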
            if total_chunks > 0 {
                let percent = (completed as f64 * 100.0) / total_chunks as f64;
                let bucket = ((percent / 5.0).floor() as usize).min(20);
                let mut percent_log_due = false;
                if completed < total_chunks {
                    let previous = last_percent_bucket.load(Ordering::Relaxed);
                    if bucket > previous
                        && last_percent_bucket
                            .compare_exchange(
                                previous,
                                bucket,
                                Ordering::Relaxed,
                                Ordering::Relaxed,
                            )
                            .is_ok()
                    {
                        percent_log_due = true;
                    }
                }
                let mut timed_log_due = false;
                if completed < total_chunks {
                    let last_marker = last_timed_report_secs.load(Ordering::Relaxed);
                    if elapsed_secs.saturating_sub(last_marker) >= long_report_interval.as_secs()
                        && last_timed_report_secs
                            .compare_exchange(
                                last_marker,
                                elapsed_secs,
                                Ordering::Relaxed,
                                Ordering::Relaxed,
                            )
                            .is_ok()
                    {
                        timed_log_due = true;
                    }
                }
                if percent_log_due || timed_log_due {
                    info!(
                        "Processed {:.1}% ({} / {} chunks, {} near-max-depth positions remaining)",
                        percent, completed, total_chunks, near_remaining
                    );
                }
            }
            if !data.is_empty() {
                if let Err(_e) = tx.send(BuilderMsg::Batch(data)) {
                    let mut err = send_error.lock();
                    if err.is_none() {
                        *err = Some(anyhow!("builder thread terminated early"));
                    }
                }
            }
        },
    );
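    // Signal end-of-stream so the builder can finalize; dropping the sender
    // closes the channel even if the Done message could not be delivered.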
    let _ = tx.send(BuilderMsg::Done);
    drop(tx);
    if let Some(err) = send_error.into_inner() {
        return Err(err);
    }
    info!("Processed 100.0%");
    info!(
        "Chunk traversal finished in {:?} ({} positions)",
        processing_start.elapsed(),
        total_positions.load(Ordering::Relaxed)
    );
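    // With all batches delivered, wait for the builder to finish assembling
    // and writing the AnnData file.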
    let parallel_assembly_start = Instant::now();
    match builder_handle.join() {
        Ok(Ok(())) => {
            info!(
                "Streaming sparse matrix assembly completed in {:?}",
                parallel_assembly_start.elapsed()
            );
        }
        Ok(Err(e)) => return Err(e),
        Err(panic_payload) => {
            // Panic payloads are usually a String or &'static str; fall back
            // to a placeholder for anything else.
            let msg = panic_payload
                .downcast_ref::<String>()
                .map(String::as_str)
                .or_else(|| panic_payload.downcast_ref::<&str>().copied())
                .unwrap_or("<non-string panic payload>");
            return Err(anyhow!("builder thread panicked: {}", msg));
        }
    }
info!("AnnData write finished -> {:?}", output_path);
let skipped_records = skipped_records.into_inner();
let skipped_count = skipped_records.len();
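    // Depth-capped sites are recorded next to the main output in a
    // tab-separated sidecar named "<stem>_skipped_sites.txt", e.g.
    // (illustrative values):
    //
    //   CHR    POS     DEPTH
    //   chr1   10543   100001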
    let mut skipped_path = output_path.clone();
    let stem = skipped_path
        .file_stem()
        .and_then(|s| s.to_str())
        .unwrap_or("output");
    skipped_path.set_file_name(format!("{}_skipped_sites.txt", stem));
    let mut file = File::create(&skipped_path)?;
    writeln!(file, "CHR\tPOS\tDEPTH")?;
    for site in skipped_records {
        writeln!(file, "{}\t{}\t{}", site.contig, site.pos, site.depth)?;
    }
    info!(
        "Recorded {} depth-capped sites at {:?}",
        skipped_count, skipped_path
    );
    info!("bam2mtx workflow finished in {:?}", start_time.elapsed());
    Ok(())
}