use clap::Parser;
use std::num::NonZeroUsize;
use anyhow::Context;
use arrow2::{array::Float64Array, chunk::Chunk, datatypes::Field};
use std::{fs::File, io, path::PathBuf};
use num_format::{Locale, ToFormattedString};
use serde::Serialize;
use serde_json::json;
use tracing::info;
use tracing_subscriber::{filter::LevelFilter, fmt, prelude::*, EnvFilter};
use noodles_bam as bam;
use noodles_bgzf as bgzf;
mod alignment_parser;
mod bootstrap;
mod em;
mod util;
use crate::util::normalize_probability::normalize_read_probs;
use crate::util::oarfish_types::{
AlignmentFilters, EMInfo, InMemoryAlignmentStore, TranscriptInfo,
};
use crate::util::read_function::read_short_quant_vec;
use crate::util::write_function::{write_infrep_file, write_output};
use crate::util::{binomial_probability::binomial_continuous_prob, kde_utils};
// Named presets for the alignment filters, selectable on the command line via
// `--filter-group`. `get_filter_opts` translates each variant into a concrete
// `AlignmentFilters` value.
#[derive(Clone, Debug, clap::ValueEnum, Serialize)]
enum FilterGroup {
// Use the most permissive value for every filter (see `get_filter_opts`);
// the strand filter and coverage-model flag are still taken from the arguments.
NoFilters,
// Use the fixed set of thresholds that `get_filter_opts` logs as the
// "nanocount defaults"; this preset forces the forward strand.
NanocountFilters,
}
/// Converts a command-line strand specifier into a `Strand` value.
///
/// Accepted spellings:
/// * forward: `+`, `fw`, `FW`, `f`, `F`
/// * reverse: `-`, `rc`, `RC`, `r`, `R`
/// * unknown (no strand restriction): `.`, `both`, `either`
///
/// # Errors
///
/// Returns an error naming the offending argument when it matches none of
/// the spellings above.
fn parse_strand(arg: &str) -> anyhow::Result<bio_types::strand::Strand> {
    let strand = match arg {
        "+" | "fw" | "FW" | "f" | "F" => bio_types::strand::Strand::Forward,
        "-" | "rc" | "RC" | "r" | "R" => bio_types::strand::Strand::Reverse,
        "." | "both" | "either" => bio_types::strand::Strand::Unknown,
        // Anything else is rejected; `bail!` returns early with the error.
        _ => anyhow::bail!("Cannot parse {} as a valid strand type", arg),
    };
    Ok(strand)
}
// Command-line arguments of the program, parsed by clap's derive API.
//
// NOTE: plain `//` comments are used on purpose. clap turns `///` doc comments
// on a derived struct and its fields into help text, so adding them would
// change the program's `--help` output.
#[derive(Parser, Debug, Serialize)]
#[clap(author, version, about, long_about = None)]
struct Args {
// When set, `main` lowers the log filter to WARN. Cannot be combined with
// `--verbose`.
#[arg(long, conflicts_with = "verbose")]
quiet: bool,
// When set, `main` raises the log filter to TRACE.
#[arg(long)]
verbose: bool,
// Path to the alignment file; `main` opens it and reads it as BGZF-compressed
// BAM. Required.
#[arg(short, long, required = true)]
alignments: PathBuf,
// Output location handed to `write_output` and `write_infrep_file`. Required.
#[arg(short, long, required = true)]
output: PathBuf,
// Optional filter preset. Every individual filter option below declares a
// conflict with this flag, so a preset and hand-picked thresholds cannot be
// mixed.
#[arg(long, help_heading = "filters", value_enum)]
filter_group: Option<FilterGroup>,
// Value forwarded to `AlignmentFilters::three_prime_clip` when no preset is
// selected. Defaults to `u32::MAX` widened to `i64`.
// NOTE(review): exact filtering semantics live in `AlignmentFilters` — confirm there.
#[arg(short, long, conflicts_with = "filter-group", help_heading="filters", default_value_t = u32::MAX as i64)]
three_prime_clip: i64,
// Value forwarded to `AlignmentFilters::five_prime_clip` when no preset is
// selected. Defaults to `u32::MAX`.
#[arg(short, long, conflicts_with = "filter-group", help_heading="filters", default_value_t = u32::MAX)]
five_prime_clip: u32,
// Value forwarded to `AlignmentFilters::score_threshold` when no preset is
// selected. Defaults to 0.95.
#[arg(
short,
long,
conflicts_with = "filter-group",
help_heading = "filters",
default_value_t = 0.95
)]
score_threshold: f32,
// Value forwarded to `AlignmentFilters::min_aligned_fraction` when no preset
// is selected. Defaults to 0.5.
#[arg(
short,
long,
conflicts_with = "filter-group",
help_heading = "filters",
default_value_t = 0.5
)]
min_aligned_fraction: f32,
// Value forwarded to `AlignmentFilters::min_aligned_len` when no preset is
// selected. Short flag is `-l`; defaults to 50.
#[arg(
short = 'l',
long,
conflicts_with = "filter-group",
help_heading = "filters",
default_value_t = 50
)]
min_aligned_len: u32,
// Strand restriction, parsed from text by `parse_strand`. Short flag is `-d`;
// defaults to `Strand::Unknown`.
#[arg(
short = 'd',
long,
conflicts_with = "filter-group",
help_heading = "filters",
default_value_t = bio_types::strand::Strand::Unknown,
value_parser = parse_strand
)]
strand_filter: bio_types::strand::Strand,
// Enables the coverage model: `main` then builds binned `TranscriptInfo`
// entries and runs `binomial_continuous_prob` / `normalize_read_probs`
// before the EM.
#[arg(long, help_heading = "coverage model", value_parser)]
model_coverage: bool,
// Copied into `EMInfo::max_iter`. Defaults to 1000.
#[arg(long, help_heading = "EM", default_value_t = 1000)]
max_em_iter: u32,
// Copied into `EMInfo::convergence_thresh`. Defaults to 1e-3.
#[arg(long, help_heading = "EM", default_value_t = 1e-3)]
convergence_thresh: f64,
// Thread budget (`-j`). `main` gives `threads - 1` workers (at least one) to
// BGZF decompression and selects `em::em_par` instead of `em::em` when the
// value exceeds 4. Defaults to 1.
#[arg(short = 'j', long, default_value_t = 1)]
threads: usize,
// Optional path (`-q`) passed to `read_short_quant_vec`; the result is used
// as `EMInfo::init_abundances`.
#[arg(short = 'q', long, help_heading = "EM")]
short_quant: Option<String>,
// Number of bootstrap replicates; when greater than zero `main` runs
// `em::bootstrap` and writes the replicates with `write_infrep_file`.
// Defaults to 0.
#[arg(long, default_value_t = 0)]
num_bootstraps: u32,
// Bin width passed to `TranscriptInfo::with_len_and_bin_width` and to the
// coverage-probability helpers when `model_coverage` is set. Defaults to 100.
#[arg(short, long, help_heading = "coverage model", default_value_t = 100)]
bin_width: u32,
// Hidden from the help output. When set, `main` builds a KDE model via
// `kde_utils::get_kde_model` and stores it in `EMInfo::kde_model`.
#[arg(short, long, hide = true)]
use_kde: bool,
}
/// Builds the `AlignmentFilters` that correspond to the parsed arguments.
///
/// * `--filter-group no-filters`: the most permissive value for every
///   threshold, keeping the strand filter from `args`.
/// * `--filter-group nanocount-filters`: a fixed set of thresholds with the
///   strand forced to forward.
/// * no preset: every threshold is taken verbatim from `args`.
///
/// The coverage-model flag is always taken from `args`. One `info!` line is
/// logged describing which of the three cases applied.
fn get_filter_opts(args: &Args) -> AlignmentFilters {
    // Pick the six filter values first; the builder chain below is shared.
    let (five_prime, three_prime, min_score, min_fraction, min_len, strand) =
        match &args.filter_group {
            Some(FilterGroup::NoFilters) => {
                info!("disabling alignment filters.");
                (
                    u32::MAX,
                    i64::MAX,
                    0_f32,
                    0_f32,
                    1_u32,
                    args.strand_filter,
                )
            }
            Some(FilterGroup::NanocountFilters) => {
                info!("setting filters to nanocount defaults.");
                (
                    u32::MAX,
                    50_i64,
                    0.95_f32,
                    0.5_f32,
                    50_u32,
                    bio_types::strand::Strand::Forward,
                )
            }
            None => {
                info!("setting user-provided filter parameters.");
                (
                    args.five_prime_clip,
                    args.three_prime_clip,
                    args.score_threshold,
                    args.min_aligned_fraction,
                    args.min_aligned_len,
                    args.strand_filter,
                )
            }
        };

    AlignmentFilters::builder()
        .five_prime_clip(five_prime)
        .three_prime_clip(three_prime)
        .score_threshold(min_score)
        .min_aligned_fraction(min_fraction)
        .min_aligned_len(min_len)
        .which_strand(strand)
        .model_coverage(args.model_coverage)
        .build()
}
/// Assembles the JSON metadata object describing this run.
///
/// The object records the probability model in use (`"scaled_binomial"` when
/// the coverage model is enabled, `"no_coverage"` otherwise), the bin width,
/// the filter options and discard table held by the alignment store, a
/// selection of the command-line arguments, and the supplied seqcol digest.
fn get_json_info(args: &Args, emi: &EMInfo, seqcol_digest: &str) -> serde_json::Value {
    let prob_model = match args.model_coverage {
        true => "scaled_binomial",
        false => "no_coverage",
    };
    // The alignment store carries both the effective filters and the
    // per-reason discard counts.
    let store = emi.eq_map;

    json!({
        "prob_model" : prob_model,
        "bin_width" : args.bin_width,
        "filter_options" : &store.filter_opts,
        "discard_table" : &store.discard_table,
        "alignments": &args.alignments,
        "output": &args.output,
        "verbose": &args.verbose,
        "quiet": &args.quiet,
        "em_max_iter": &args.max_em_iter,
        "em_convergence_thresh": &args.convergence_thresh,
        "threads": &args.threads,
        "filter_group": &args.filter_group,
        "short_quant": &args.short_quant,
        "num_bootstraps": &args.num_bootstraps,
        "seqcol_digest": seqcol_digest
    })
}
fn main() -> anyhow::Result<()> {
let env_filter = EnvFilter::builder()
.with_default_directive(LevelFilter::INFO.into())
.from_env_lossy();
let (filtered_layer, reload_handle) = tracing_subscriber::reload::Layer::new(env_filter);
tracing_subscriber::registry()
.with(fmt::layer().with_writer(io::stderr))
.with(filtered_layer)
.init();
let args = Args::parse();
if args.quiet {
reload_handle.modify(|filter| *filter = EnvFilter::new("WARN"))?;
}
if args.verbose {
reload_handle.modify(|filter| *filter = EnvFilter::new("TRACE"))?;
}
let filter_opts = get_filter_opts(&args);
let afile = File::open(&args.alignments)?;
let worker_count = NonZeroUsize::new(1.max(args.threads.saturating_sub(1)))
.expect("decompression threads >= 1");
let decoder = bgzf::MultithreadedReader::with_worker_count(worker_count, afile);
let mut reader = bam::io::Reader::from(decoder);
let header = alignment_parser::read_and_verify_header(&mut reader, &args.alignments)?;
let seqcol_digest = {
info!("calculating seqcol digest");
let sc = seqcol_rs::SeqCol::from_sam_header(
header
.reference_sequences()
.iter()
.map(|(k, v)| (k.as_slice(), v.length().into())),
);
let d = sc.digest(seqcol_rs::DigestConfig::default()).context(
"failed to compute the seqcol digest for the information from the alignment header",
)?;
info!("done calculating seqcol digest");
d
};
let num_ref_seqs = header.reference_sequences().len();
let mut txps: Vec<TranscriptInfo> = Vec::with_capacity(num_ref_seqs);
let mut txps_name: Vec<String> = Vec::with_capacity(num_ref_seqs);
if args.model_coverage {
for (rseq, rmap) in header.reference_sequences().iter() {
txps.push(TranscriptInfo::with_len_and_bin_width(
rmap.length(),
args.bin_width,
));
txps_name.push(rseq.to_string());
}
} else {
for (rseq, rmap) in header.reference_sequences().iter() {
txps.push(TranscriptInfo::with_len(rmap.length()));
txps_name.push(rseq.to_string());
}
}
info!(
"parsed reference information for {} transcripts.",
txps.len()
);
let mut store = InMemoryAlignmentStore::new(filter_opts, &header);
alignment_parser::parse_alignments(&mut store, &header, &mut reader, &mut txps)?;
info!("\ndiscard_table: \n{}\n", store.discard_table.to_table());
drop(reader);
let kde_opt: Option<kders::kde::KDEModel> = if args.use_kde {
Some(kde_utils::get_kde_model(&txps, &store)?)
} else {
None
};
if store.filter_opts.model_coverage {
binomial_continuous_prob(&mut txps, &args.bin_width, args.threads);
normalize_read_probs(&mut store, &txps, &args.bin_width);
}
info!(
"Total number of alignment records : {}",
store.total_len().to_formatted_string(&Locale::en)
);
info!(
"number of aligned reads : {}",
store.num_aligned_reads().to_formatted_string(&Locale::en)
);
info!(
"number of unique alignments : {}",
store.unique_alignments().to_formatted_string(&Locale::en)
);
let init_abundances = args.short_quant.as_ref().map(|sr_path| {
read_short_quant_vec(sr_path, &txps_name).unwrap_or_else(|e| panic!("{}", e))
});
let emi = EMInfo {
eq_map: &store,
txp_info: &txps,
max_iter: args.max_em_iter,
convergence_thresh: args.convergence_thresh,
init_abundances,
kde_model: kde_opt,
};
if args.use_kde {
}
let counts = if args.threads > 4 {
em::em_par(&emi, args.threads)
} else {
em::em(&emi, args.threads)
};
let aux_txp_counts = crate::util::aux_counts::get_aux_counts(&store, &txps)?;
let json_info = get_json_info(&args, &emi, &seqcol_digest);
write_output(&args.output, json_info, &header, &counts, &aux_txp_counts)?;
if args.num_bootstraps > 0 {
let breps = em::bootstrap(&emi, args.num_bootstraps, args.threads);
let mut new_arrays = vec![];
let mut bs_fields = vec![];
for (i, b) in breps.into_iter().enumerate() {
let bs_array = Float64Array::from_vec(b);
bs_fields.push(Field::new(
format!("bootstrap.{}", i),
bs_array.data_type().clone(),
false,
));
new_arrays.push(bs_array.boxed());
}
let chunk = Chunk::new(new_arrays);
write_infrep_file(&args.output, bs_fields, chunk)?;
}
Ok(())
}