use clap::Parser;
use arrow2::{array::Float64Array, chunk::Chunk, datatypes::Field};
use std::{
fs::File,
io::{self, BufReader},
path::PathBuf,
};
use num_format::{Locale, ToFormattedString};
use serde::Serialize;
use serde_json::json;
use tracing::info;
use tracing_subscriber::{filter::LevelFilter, fmt, prelude::*, EnvFilter};
use noodles_bam as bam;
mod alignment_parser;
mod bootstrap;
mod em;
mod util;
use crate::util::binomial_probability::binomial_continuous_prob;
use crate::util::normalize_probability::normalize_read_probs;
use crate::util::oarfish_types::{
AlignmentFilters, EMInfo, InMemoryAlignmentStore, TranscriptInfo,
};
use crate::util::read_function::read_short_quant_vec;
use crate::util::write_function::{write_infrep_file, write_output};
// Named presets for the alignment filters, selectable on the command line via
// `--filter-group` (see `get_filter_opts` for the concrete values each preset maps to).
#[derive(Clone, Debug, clap::ValueEnum, Serialize)]
enum FilterGroup {
// Preset that `get_filter_opts` logs as "disabling alignment filters."
NoFilters,
// Preset that `get_filter_opts` logs as "setting filters to nanocount defaults."
NanocountFilters,
}
// Command-line arguments of the program, parsed with clap's derive API.
//
// NOTE: plain `//` comments are used on purpose in this struct; `///` doc comments on
// the fields would be picked up by clap and emitted as `--help` text.
//
// The individual filter options conflict with `--filter-group`. clap 4's derive uses the
// field name verbatim as the argument id, so the conflict target must be spelled
// `filter_group` (the id), not `filter-group` (the long flag). Referring to a non-existent
// id trips clap's debug assertions and leaves the conflict unenforced in release builds.
#[derive(Parser, Debug, Serialize)]
#[clap(author, version, about, long_about = None)]
struct Args {
    // When set, `main` lowers the log filter to WARN. Mutually exclusive with `--verbose`.
    #[arg(long, conflicts_with = "verbose")]
    quiet: bool,
    // When set, `main` raises the log filter to TRACE.
    #[arg(long)]
    verbose: bool,
    // Path of the alignment file; `main` opens it with the BAM reader.
    #[arg(short, long, required = true)]
    alignments: PathBuf,
    // Output location handed to `write_output` (and `write_infrep_file` for bootstraps).
    #[arg(short, long, required = true)]
    output: PathBuf,
    // Optional filter preset; when given, the individual filter options below may not be used.
    #[arg(long, help_heading = "filters", value_enum)]
    filter_group: Option<FilterGroup>,
    // Forwarded to the `three_prime_clip` setting of `AlignmentFilters`.
    #[arg(
        short,
        long,
        conflicts_with = "filter_group",
        help_heading = "filters",
        default_value_t = u32::MAX as i64
    )]
    three_prime_clip: i64,
    // Forwarded to the `five_prime_clip` setting of `AlignmentFilters`.
    #[arg(
        short,
        long,
        conflicts_with = "filter_group",
        help_heading = "filters",
        default_value_t = u32::MAX
    )]
    five_prime_clip: u32,
    // Forwarded to the `score_threshold` setting of `AlignmentFilters`.
    #[arg(
        short,
        long,
        conflicts_with = "filter_group",
        help_heading = "filters",
        default_value_t = 0.95
    )]
    score_threshold: f32,
    // Forwarded to the `min_aligned_fraction` setting of `AlignmentFilters`.
    #[arg(
        short,
        long,
        conflicts_with = "filter_group",
        help_heading = "filters",
        default_value_t = 0.5
    )]
    min_aligned_fraction: f32,
    // Forwarded to the `min_aligned_len` setting of `AlignmentFilters`.
    #[arg(
        short = 'l',
        long,
        conflicts_with = "filter_group",
        help_heading = "filters",
        default_value_t = 50
    )]
    min_aligned_len: u32,
    // Forwarded to the `allow_rc` setting of `AlignmentFilters`.
    #[arg(
        short = 'n',
        long,
        conflicts_with = "filter_group",
        help_heading = "filters",
        value_parser
    )]
    allow_negative_strand: bool,
    // Enables the coverage model: transcripts are created with bins and `main` runs the
    // binomial probability and read-probability normalization steps before the EM.
    #[arg(long, help_heading = "coverage model", value_parser)]
    model_coverage: bool,
    // Iteration cap passed to the EM through `EMInfo::max_iter`.
    #[arg(long, help_heading = "EM", default_value_t = 1000)]
    max_em_iter: u32,
    // Convergence threshold passed to the EM through `EMInfo::convergence_thresh`.
    #[arg(long, help_heading = "EM", default_value_t = 1e-3)]
    convergence_thresh: f64,
    // Thread count handed to the EM, coverage and bootstrap routines; values above 4
    // make `main` pick the parallel EM implementation.
    #[arg(short = 'j', long, default_value_t = 1)]
    threads: usize,
    // Optional path read with `read_short_quant_vec` to obtain initial abundances for the EM.
    #[arg(short = 'q', long, help_heading = "EM")]
    short_quant: Option<String>,
    // Number of bootstrap replicates to generate; 0 skips bootstrapping.
    #[arg(long, default_value_t = 0)]
    num_bootstraps: u32,
    // Number of bins used per transcript when the coverage model is enabled.
    #[arg(short, long, help_heading = "coverage model", default_value_t = 10)]
    bins: u32,
}
/// Builds the [`AlignmentFilters`] for this run from the parsed command line.
///
/// A `--filter-group` preset, when present, determines every filter value; otherwise the
/// individual filter options from `args` are used. In all cases `model_coverage` is taken
/// from `args`. The chosen mode is logged at INFO level.
fn get_filter_opts(args: &Args) -> AlignmentFilters {
    // Resolve the settings first, in the order:
    // (five_prime_clip, three_prime_clip, score_threshold,
    //  min_aligned_fraction, min_aligned_len, allow_rc)
    let (five_clip, three_clip, min_score, min_fraction, min_len, allow_rc): (
        u32,
        i64,
        f32,
        f32,
        u32,
        bool,
    ) = match &args.filter_group {
        Some(FilterGroup::NoFilters) => {
            info!("disabling alignment filters.");
            (u32::MAX, i64::MAX, 0.0, 0.0, 1, true)
        }
        Some(FilterGroup::NanocountFilters) => {
            info!("setting filters to nanocount defaults.");
            (u32::MAX, 50, 0.95, 0.5, 50, false)
        }
        None => {
            info!("setting user-provided filter parameters.");
            (
                args.five_prime_clip,
                args.three_prime_clip,
                args.score_threshold,
                args.min_aligned_fraction,
                args.min_aligned_len,
                args.allow_negative_strand,
            )
        }
    };

    // A single builder chain shared by all three modes.
    AlignmentFilters::builder()
        .five_prime_clip(five_clip)
        .three_prime_clip(three_clip)
        .score_threshold(min_score)
        .min_aligned_fraction(min_fraction)
        .min_aligned_len(min_len)
        .allow_rc(allow_rc)
        .model_coverage(args.model_coverage)
        .build()
}
/// Assembles the run metadata as a JSON object.
///
/// The object records the probability-model label (`"binomial"` when the coverage model
/// is enabled, `"no_coverage"` otherwise), the bin count, the filter options and discard
/// table held by the alignment store behind `emi.eq_map`, and the command-line settings
/// relevant to the run.
fn get_json_info(args: &Args, emi: &EMInfo) -> serde_json::Value {
    let prob_model: &str = if args.model_coverage {
        "binomial"
    } else {
        "no_coverage"
    };

    // `json!` serializes each value through a borrow, so plain scalars are written by
    // value here while the larger fields keep their explicit reference.
    json!({
        "prob_model" : prob_model,
        "num_bins" : args.bins,
        "filter_options" : &emi.eq_map.filter_opts,
        "discard_table" : &emi.eq_map.discard_table,
        "alignments": &args.alignments,
        "output": &args.output,
        "verbose": args.verbose,
        "quiet": args.quiet,
        "em_max_iter": args.max_em_iter,
        "em_convergence_thresh": args.convergence_thresh,
        "threads": args.threads,
        "filter_group": &args.filter_group,
        "short_quant": &args.short_quant,
        "num_bootstraps": args.num_bootstraps
    })
}
/// Program entry point.
///
/// Sets up logging, parses the command line, reads and filters the BAM alignments,
/// optionally applies the coverage model, runs the EM, writes the quantification output
/// and, if requested, the bootstrap replicates.
///
/// # Errors
///
/// Returns an error if the log filter cannot be reloaded, the alignment file cannot be
/// opened or parsed, the short-read quantification file (`--short-quant`) cannot be read,
/// or any of the output files cannot be written.
fn main() -> anyhow::Result<()> {
    // Logging defaults to INFO. The filter is wrapped in a reload layer so that it can be
    // replaced after the command line has been parsed.
    let env_filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::INFO.into())
        .from_env_lossy();
    let (filtered_layer, reload_handle) = tracing_subscriber::reload::Layer::new(env_filter);
    tracing_subscriber::registry()
        .with(fmt::layer().with_writer(io::stderr))
        .with(filtered_layer)
        .init();

    let args = Args::parse();

    // `--quiet` and `--verbose` are mutually exclusive, so at most one of these applies.
    if args.quiet {
        reload_handle.modify(|filter| *filter = EnvFilter::new("WARN"))?;
    }
    if args.verbose {
        reload_handle.modify(|filter| *filter = EnvFilter::new("TRACE"))?;
    }

    let filter_opts = get_filter_opts(&args);

    let mut reader = File::open(&args.alignments)
        .map(BufReader::new)
        .map(bam::io::Reader::new)?;
    let header = alignment_parser::read_and_verify_header(&mut reader, &args.alignments)?;

    // One `TranscriptInfo` (and name) per reference sequence in the header; bins are only
    // allocated when the coverage model is enabled.
    let num_ref_seqs = header.reference_sequences().len();
    let mut txps: Vec<TranscriptInfo> = Vec::with_capacity(num_ref_seqs);
    let mut txps_name: Vec<String> = Vec::with_capacity(num_ref_seqs);
    for (rseq, rmap) in header.reference_sequences().iter() {
        let txp = if args.model_coverage {
            TranscriptInfo::with_len_and_bins(rmap.length(), args.bins)
        } else {
            TranscriptInfo::with_len(rmap.length())
        };
        txps.push(txp);
        txps_name.push(rseq.to_string());
    }
    info!(
        "parsed reference information for {} transcripts.",
        txps.len()
    );

    // Read the alignments into the in-memory store, applying the filters on the way.
    let mut store = InMemoryAlignmentStore::new(filter_opts, &header);
    alignment_parser::parse_alignments(&mut store, &header, &mut reader, &mut txps)?;
    info!("\ndiscard_table: \n{}\n", store.discard_table.to_table());

    if store.filter_opts.model_coverage {
        binomial_continuous_prob(&mut txps, &args.bins, args.threads);
        normalize_read_probs(&mut store, &txps, &args.bins);
    }

    info!(
        "Total number of alignment records : {}",
        store.total_len().to_formatted_string(&Locale::en)
    );
    info!(
        "number of aligned reads : {}",
        store.num_aligned_reads().to_formatted_string(&Locale::en)
    );

    // Optional initial abundances from a short-read quantification file. A failure to
    // read that file is reported as a regular error instead of panicking.
    let init_abundances = args
        .short_quant
        .as_ref()
        .map(|sr_path| {
            read_short_quant_vec(sr_path, &txps_name).map_err(|e| {
                anyhow::anyhow!(
                    "failed to read short-read quantification from {}: {}",
                    sr_path,
                    e
                )
            })
        })
        .transpose()?;

    let emi = EMInfo {
        eq_map: &store,
        txp_info: &mut txps,
        max_iter: args.max_em_iter,
        convergence_thresh: args.convergence_thresh,
        init_abundances,
    };

    // The parallel EM is only used when more than 4 threads were requested.
    let counts = if args.threads > 4 {
        em::em_par(&emi, args.threads)
    } else {
        em::em(&emi, args.threads)
    };

    let json_info = get_json_info(&args, &emi);
    write_output(&args.output, json_info, &header, &counts)?;

    if args.num_bootstraps > 0 {
        let breps = em::bootstrap(&emi, args.num_bootstraps, args.threads);

        // Each replicate becomes one non-nullable Float64 column named `bootstrap.<i>`.
        let mut new_arrays = vec![];
        let mut bs_fields = vec![];
        for (i, b) in breps.into_iter().enumerate() {
            let bs_array = Float64Array::from_vec(b);
            bs_fields.push(Field::new(
                format!("bootstrap.{}", i),
                bs_array.data_type().clone(),
                false,
            ));
            new_arrays.push(bs_array.boxed());
        }
        let chunk = Chunk::new(new_arrays);
        write_infrep_file(&args.output, bs_fields, chunk)?;
    }

    Ok(())
}