use std::path::PathBuf;
use anyhow::Result;
use clap::Args;
use kuva::plot::legend::LegendPosition;
use kuva::plot::{Histogram, LinePlot};
use kuva::render::annotations::{Orientation, ShadedRegion};
use kuva::render::layout::{Layout, TickFormat};
use kuva::render::plots::Plot;
use noodles::sam::Header;
use riker_derive::MetricDocs;
use serde::{Deserialize, Serialize};
use crate::collector::{Collector, drive_collector_single_threaded};
use crate::commands::command::Command;
use crate::commands::common::{InputOptions, OptionalReferenceOptions, OutputOptions};
use crate::metrics::write_tsv;
use crate::plotting::{
FG_BLUE, FG_GREEN, FG_PACIFIC, FG_RED, FG_TEAL, PLOT_HEIGHT, PLOT_WIDTH, write_plot_pdf,
};
use crate::progress::ProgressLogger;
use crate::sam::alignment_reader::AlignmentReader;
use crate::sam::record_utils::derive_sample;
use crate::sam::riker_record::{RikerRecord, RikerRecordRequirements};
/// File-name suffix of the per-cycle base distribution table.
pub const BASE_DIST_SUFFIX: &str = ".base-distribution-by-cycle.txt";
/// File-name suffix of the PDF plot of the per-cycle base distribution.
pub const BASE_DIST_PLOT_SUFFIX: &str = ".base-distribution-by-cycle.pdf";

/// File-name suffix of the mean-quality-by-cycle table.
pub const MEAN_QUAL_SUFFIX: &str = ".mean-quality-by-cycle.txt";
/// File-name suffix of the PDF plot of mean quality by cycle.
pub const MEAN_QUAL_PLOT_SUFFIX: &str = ".mean-quality-by-cycle.pdf";

/// File-name suffix of the quality-score distribution table.
pub const QUAL_DIST_SUFFIX: &str = ".quality-score-distribution.txt";
/// File-name suffix of the PDF plot of the quality-score distribution.
pub const QUAL_DIST_PLOT_SUFFIX: &str = ".quality-score-distribution.pdf";
// CLI arguments for the `basic` subcommand: input alignments, an output prefix and an
// optional reference.
// NOTE(review): plain `//` comments are used on purpose — clap derive turns `///` doc
// comments into help text. The bare `long_about` below takes its text from such a doc
// comment and none is present here; confirm the intended long help.
#[derive(Args, Debug, Clone)]
#[command(
    long_about,
    after_long_help = "\
Examples:
riker basic -i input.bam -o out_prefix"
)]
pub struct Basic {
    // Input options; `input.input` is the alignment file that gets opened and read.
    #[command(flatten)]
    pub input: InputOptions,
    // Output options; `output.output` is the prefix every output file name is built from.
    #[command(flatten)]
    pub output: OutputOptions,
    // Optional reference; `reference.reference` is forwarded to `AlignmentReader::open`.
    #[command(flatten)]
    pub reference: OptionalReferenceOptions,
}
impl Command for Basic {
    /// Opens the input alignments and streams every record through a [`BasicCollector`].
    ///
    /// Progress is logged every 5,000,000 reads. Metric tables and plots are written under
    /// the configured output prefix when the collector finishes.
    ///
    /// # Errors
    ///
    /// Fails if the input cannot be opened or if collection/output fails.
    fn execute(&self) -> Result<()> {
        let input = &self.input.input;
        let reference = self.reference.reference.as_deref();

        let mut reader = AlignmentReader::open(input, reference)?;
        let mut collector = BasicCollector::new(input, &self.output.output);
        let mut progress = ProgressLogger::new("basic", "reads", 5_000_000);

        drive_collector_single_threaded(&mut reader, &mut collector, &mut progress)
    }
}
/// Mask folding an ASCII base byte into a small table slot.
///
/// ASCII upper- and lower-case letters differ only in bit 0x20, so the low five bits give
/// the same slot for `A` and `a`.
const BASE_BITS: u8 = 0x1F;

/// Number of distinct slots reachable after masking with [`BASE_BITS`] (32).
const N_BASE_SLOTS: usize = BASE_BITS as usize + 1;

/// Table slot for an ASCII base byte.
const fn base_slot(base: u8) -> usize {
    (base & BASE_BITS) as usize
}

/// Slot of `A`/`a`.
const IDX_A: usize = base_slot(b'A');
/// Slot of `C`/`c`.
const IDX_C: usize = base_slot(b'C');
/// Slot of `G`/`g`.
const IDX_G: usize = base_slot(b'G');
/// Slot of `T`/`t`.
const IDX_T: usize = base_slot(b'T');

/// Bit set of the four canonical-base slots; bit `i` is set when slot `i` is A, C, G or T.
const ACGT_BITMASK: u32 = (1u32 << IDX_A) | (1u32 << IDX_C) | (1u32 << IDX_G) | (1u32 << IDX_T);
/// Per-cycle accumulator for one read end.
#[derive(Clone, Default, Debug)]
struct CycleStats {
    /// Sum of the quality values of every base seen at this cycle, including non-ACGT bases.
    qual_sum: u64,
    /// Base counts indexed by `byte & BASE_BITS`, so upper- and lower-case letters share a
    /// slot; every slot (A/C/G/T, N, IUPAC codes, `=`) contributes to the cycle total.
    base_counts: [u64; N_BASE_SLOTS],
}
impl CycleStats {
    /// Number of bases observed at this cycle, summed over every base slot.
    fn total(&self) -> u64 {
        self.base_counts.iter().fold(0, |acc, &count| acc + count)
    }
}
/// Accumulates per-cycle base/quality statistics and a quality-score histogram over a
/// stream of records, then writes them as TSV tables and PDF plots.
#[derive(Debug)]
pub struct BasicCollector {
    /// Input alignment path; combined with the header to derive the sample label.
    input: PathBuf,
    /// Output path of the base-distribution-by-cycle table.
    base_dist_path: PathBuf,
    /// Output path of the mean-quality-by-cycle table.
    mean_qual_path: PathBuf,
    /// Output path of the quality-score-distribution table.
    qual_dist_path: PathBuf,
    /// Output path of the base-distribution plot.
    base_dist_plot_path: PathBuf,
    /// Output path of the mean-quality plot.
    mean_qual_plot_path: PathBuf,
    /// Output path of the quality-distribution plot.
    qual_dist_plot_path: PathBuf,
    /// Sample label written to every metric row; empty until the collector is initialized.
    sample: String,
    /// Prefix used in plot titles; empty until the collector is initialized.
    plot_title_prefix: String,
    /// Per-cycle stats for every record not classified as R2 (includes unpaired reads).
    r1_cycles: Vec<CycleStats>,
    /// Per-cycle stats for the last segment of segmented (paired) records.
    r2_cycles: Vec<CycleStats>,
    /// Four banks of a 128-slot histogram of quality values for A/C/G/T bases; the banks are
    /// summed when the distribution is built.
    qual_counts: [[u64; 128]; 4],
}
impl BasicCollector {
    /// Creates an empty collector for `input`, deriving every output path from `prefix`.
    ///
    /// The sample label and plot-title prefix stay empty until [`Collector::initialize`]
    /// fills them in from the input path and header.
    #[must_use]
    pub fn new(input: &std::path::Path, prefix: &std::path::Path) -> Self {
        Self {
            input: input.to_path_buf(),
            base_dist_path: super::command::output_path(prefix, BASE_DIST_SUFFIX),
            mean_qual_path: super::command::output_path(prefix, MEAN_QUAL_SUFFIX),
            qual_dist_path: super::command::output_path(prefix, QUAL_DIST_SUFFIX),
            base_dist_plot_path: super::command::output_path(prefix, BASE_DIST_PLOT_SUFFIX),
            mean_qual_plot_path: super::command::output_path(prefix, MEAN_QUAL_PLOT_SUFFIX),
            qual_dist_plot_path: super::command::output_path(prefix, QUAL_DIST_PLOT_SUFFIX),
            sample: String::new(),
            plot_title_prefix: String::new(),
            r1_cycles: Vec::new(),
            r2_cycles: Vec::new(),
            qual_counts: [[0u64; 128]; 4],
        }
    }

    /// Grows `cycles` with zeroed [`CycleStats`] until it holds at least `len` entries.
    /// Never shrinks.
    fn ensure_capacity(cycles: &mut Vec<CycleStats>, len: usize) {
        if len > cycles.len() {
            cycles.resize(len, CycleStats::default());
        }
    }

    /// Adds one read's bases and qualities to the per-cycle stats and the quality histogram.
    ///
    /// Position `i` is counted at cycle `i`, or at cycle `n - 1 - i` when `REVERSE` is set
    /// (the record is stored reverse-complemented). Every base adds to its cycle's base count
    /// and quality sum; only A/C/G/T bases are added to `qual_counts`.
    ///
    /// # Panics
    ///
    /// Panics if `quals` or `cycles` is shorter than `seq`.
    #[inline]
    fn process_record<const REVERSE: bool>(
        seq: &[u8],
        quals: &[u8],
        cycles: &mut [CycleStats],
        qual_counts: &mut [[u64; 128]; 4],
    ) {
        let n = seq.len();
        // Slice to `n` up front so a too-short `quals`/`cycles` panics before any counter is
        // modified, instead of part-way through the read.
        let quals = &quals[..n];
        let cycles = &mut cycles[..n];
        for (i, (&base, &q)) in seq.iter().zip(quals).enumerate() {
            let cycle_idx = if REVERSE { n - 1 - i } else { i };
            let bi = usize::from(base & BASE_BITS);
            let stats = &mut cycles[cycle_idx];
            stats.base_counts[bi] += 1;
            stats.qual_sum += u64::from(q);
            if (ACGT_BITMASK >> bi) & 1 != 0 {
                // Consecutive positions rotate over four banks that are summed again in
                // `build_qual_dist_metrics` — presumably to shorten dependency chains when
                // neighbouring bases share a quality. Masking with 0x7F keeps the index
                // inside the 128-slot histogram.
                qual_counts[i & 3][usize::from(q & 0x7F)] += 1;
            }
        }
    }

    /// Converts per-cycle base counts into per-cycle base fractions.
    ///
    /// Rows are emitted for R1 and then R2, with `cycle` 1-based within each read end.
    /// Cycles with no observed bases are skipped, and `frac_n` covers every base that is not
    /// A, C, G or T.
    fn build_base_dist_metrics(&self) -> Vec<BaseDistributionByCycleMetric> {
        let mut metrics = Vec::with_capacity(self.r1_cycles.len() + self.r2_cycles.len());
        for (read_end, cycles) in [(1u8, &self.r1_cycles), (2u8, &self.r2_cycles)] {
            for (i, c) in cycles.iter().enumerate() {
                let total = c.total();
                if total == 0 {
                    continue;
                }
                let t = total as f64;
                let a = c.base_counts[IDX_A];
                let cnt_c = c.base_counts[IDX_C];
                let g = c.base_counts[IDX_G];
                let cnt_t = c.base_counts[IDX_T];
                let acgt = a + cnt_c + g + cnt_t;
                debug_assert!(acgt <= total);
                metrics.push(BaseDistributionByCycleMetric {
                    sample: self.sample.clone(),
                    read_end,
                    cycle: i + 1,
                    frac_a: a as f64 / t,
                    frac_c: cnt_c as f64 / t,
                    frac_g: g as f64 / t,
                    frac_t: cnt_t as f64 / t,
                    frac_n: (total - acgt) as f64 / t,
                });
            }
        }
        metrics
    }

    /// Computes the mean quality (over all bases, including non-ACGT) at each cycle.
    ///
    /// R1 cycles are numbered `1..=r1_len`; R2 cycles continue from there
    /// (`r1_len + 1, ...`) so both read ends share one cycle axis. Cycles with no observed
    /// bases are skipped.
    fn build_mean_qual_metrics(&self) -> Vec<MeanQualityByCycleMetric> {
        let r1_max = self.r1_cycles.len();
        let r1 = self.r1_cycles.iter().enumerate().map(|(i, c)| (i + 1, c));
        let r2 = self.r2_cycles.iter().enumerate().map(|(i, c)| (r1_max + i + 1, c));
        r1.chain(r2)
            .filter_map(|(cycle, c)| {
                let total = c.total();
                (total > 0).then(|| MeanQualityByCycleMetric {
                    sample: self.sample.clone(),
                    cycle,
                    mean_quality: c.qual_sum as f64 / total as f64,
                })
            })
            .collect()
    }

    /// Collapses the four histogram banks into a single quality-score distribution.
    ///
    /// Only quality values with a non-zero count are emitted, in ascending order;
    /// `frac_bases` is each count divided by the total over all qualities.
    // The histogram has 128 slots, so every slot index fits in a `u8`.
    #[allow(clippy::cast_possible_truncation)]
    fn build_qual_dist_metrics(&self) -> Vec<QualityScoreDistributionMetric> {
        let mut combined = [0u64; 128];
        for bank in &self.qual_counts {
            for (slot, &c) in combined.iter_mut().zip(bank) {
                *slot += c;
            }
        }
        let total: u64 = combined.iter().sum();
        let total_f = total as f64;
        combined
            .iter()
            .enumerate()
            .filter(|&(_, count)| *count > 0)
            .map(|(q, count)| QualityScoreDistributionMetric {
                sample: self.sample.clone(),
                quality: q as u8,
                count: *count,
                frac_bases: if total > 0 { *count as f64 / total_f } else { 0.0 },
            })
            .collect()
    }

    /// Renders per-cycle A/C/G/T/N fractions as one line per base.
    ///
    /// R2 points are placed after R1 (x = `r1_len + cycle`). When R2 data exists the R1 and
    /// R2 spans are shaded in different colours. Does nothing when `metrics` is empty.
    ///
    /// # Errors
    ///
    /// Returns any error from writing the PDF.
    fn plot_base_distribution(&self, metrics: &[BaseDistributionByCycleMetric]) -> Result<()> {
        if metrics.is_empty() {
            return Ok(());
        }
        let r1_max = self.r1_cycles.len();
        let max_cycle = (r1_max + self.r2_cycles.len()) as f64;
        let base_names = ["A", "C", "G", "T", "N"];
        let base_colors = [FG_BLUE, FG_GREEN, FG_TEAL, FG_PACIFIC, FG_RED];
        let plots: Vec<Plot> = (0..5)
            .map(|base_idx| {
                let xy: Vec<(f64, f64)> = metrics
                    .iter()
                    .map(|m| {
                        let x = if m.read_end == 1 {
                            m.cycle as f64
                        } else {
                            (r1_max + m.cycle) as f64
                        };
                        let y = match base_idx {
                            0 => m.frac_a,
                            1 => m.frac_c,
                            2 => m.frac_g,
                            3 => m.frac_t,
                            _ => m.frac_n,
                        };
                        (x, y)
                    })
                    .collect();
                Plot::Line(
                    LinePlot::new()
                        .with_data(xy)
                        .with_color(base_colors[base_idx])
                        .with_legend(base_names[base_idx]),
                )
            })
            .collect();
        let mut layout = Layout::auto_from_plots(&plots)
            .with_width(PLOT_WIDTH)
            .with_height(PLOT_HEIGHT)
            .with_title(format!("{} Base Distribution by Cycle", self.plot_title_prefix))
            .with_x_label("Cycle")
            .with_y_label("Fraction")
            .with_x_axis_max(max_cycle + 1.0)
            .with_legend_position(LegendPosition::OutsideRightMiddle);
        if !self.r2_cycles.is_empty() {
            // With R2-only input `r1_max` is 0 and the R1 band would be inverted
            // (1.0 .. 0.5), so only shade R1 when it actually has cycles.
            if r1_max > 0 {
                layout.shaded_regions.push(ShadedRegion {
                    orientation: Orientation::Vertical,
                    min_val: 1.0,
                    max_val: r1_max as f64 + 0.5,
                    color: FG_BLUE.to_owned(),
                    opacity: 0.1,
                });
            }
            layout.shaded_regions.push(ShadedRegion {
                orientation: Orientation::Vertical,
                min_val: r1_max as f64 + 0.5,
                max_val: max_cycle + 1.0,
                color: FG_TEAL.to_owned(),
                opacity: 0.1,
            });
        }
        write_plot_pdf(plots, layout, &self.base_dist_plot_path)
    }

    /// Renders mean quality by cycle as filled lines, one per read end that has data.
    ///
    /// R2 rows are recognised by their offset cycle numbers (`cycle > r1_len`). Does nothing
    /// when `metrics` is empty.
    ///
    /// # Errors
    ///
    /// Returns any error from writing the PDF.
    fn plot_mean_quality(&self, metrics: &[MeanQualityByCycleMetric]) -> Result<()> {
        if metrics.is_empty() {
            return Ok(());
        }
        let r1_max = self.r1_cycles.len();
        let has_r2 = !self.r2_cycles.is_empty();
        let r1_xy: Vec<(f64, f64)> = metrics
            .iter()
            .filter(|m| m.cycle <= r1_max)
            .map(|m| (m.cycle as f64, m.mean_quality))
            .collect();
        let mut plots: Vec<Plot> = Vec::with_capacity(2);
        // R2-only input leaves the R1 series empty; don't emit a line with no data.
        if !r1_xy.is_empty() {
            plots.push(Plot::Line(
                LinePlot::new()
                    .with_data(r1_xy)
                    .with_color(FG_BLUE)
                    .with_fill()
                    .with_fill_opacity(0.3),
            ));
        }
        if has_r2 {
            let r2_xy: Vec<(f64, f64)> = metrics
                .iter()
                .filter(|m| m.cycle > r1_max)
                .map(|m| (m.cycle as f64, m.mean_quality))
                .collect();
            plots.push(Plot::Line(
                LinePlot::new()
                    .with_data(r2_xy)
                    .with_color(FG_TEAL)
                    .with_fill()
                    .with_fill_opacity(0.3),
            ));
        }
        let max_cycle = (r1_max + self.r2_cycles.len()) as f64;
        let layout = Layout::auto_from_plots(&plots)
            .with_width(PLOT_WIDTH)
            .with_height(PLOT_HEIGHT)
            .with_title(format!("{} Mean Quality by Cycle", self.plot_title_prefix))
            .with_x_label("Cycle")
            .with_y_label("Mean Quality")
            .with_x_axis_max(max_cycle + 1.0);
        write_plot_pdf(plots, layout, &self.mean_qual_plot_path)
    }

    /// Renders the quality-score distribution as a unit-bin histogram of base fractions.
    ///
    /// Does nothing when `metrics` is empty.
    ///
    /// # Errors
    ///
    /// Returns any error from writing the PDF.
    fn plot_qual_distribution(&self, metrics: &[QualityScoreDistributionMetric]) -> Result<()> {
        if metrics.is_empty() {
            return Ok(());
        }
        let (edges, counts) = Self::qual_dist_histogram_bins(metrics);
        let x_axis_max = *edges.last().expect("histogram bins always produces non-empty edges");
        let plots = vec![Plot::Histogram(Histogram::from_bins(edges, counts).with_color(FG_BLUE))];
        let layout = Layout::auto_from_plots(&plots)
            .with_width(PLOT_WIDTH)
            .with_height(PLOT_HEIGHT)
            .with_title(format!("{} Quality Score Distribution", self.plot_title_prefix))
            .with_x_label("Quality Score")
            .with_y_label("Fraction of Bases")
            .with_x_axis_min(0.0)
            .with_x_axis_max(x_axis_max)
            .with_x_tick_step(5.0)
            .with_minor_ticks(5)
            .with_x_tick_format(TickFormat::Integer);
        write_plot_pdf(plots, layout, &self.qual_dist_plot_path)
    }

    /// Lays the quality distribution out as unit-width histogram bins.
    ///
    /// Returns `(edges, fractions)`: bin `q` spans `[q, q + 1)` and holds the `frac_bases`
    /// of quality `q` (0.0 when absent). The range always reaches at least Q45, so there are
    /// `max(max_quality, 45) + 1` bins and one more edge than bins.
    fn qual_dist_histogram_bins(
        metrics: &[QualityScoreDistributionMetric],
    ) -> (Vec<f64>, Vec<f64>) {
        let upper = metrics.iter().map(|m| usize::from(m.quality)).fold(45, usize::max);
        let mut frac_by_q = vec![0.0_f64; upper + 1];
        for m in metrics {
            // `upper` is at least every observed quality, so this index is always in bounds.
            frac_by_q[usize::from(m.quality)] = m.frac_bases;
        }
        let edges: Vec<f64> = (0..=upper + 1).map(|q| q as f64).collect();
        (edges, frac_by_q)
    }
}
impl Collector for BasicCollector {
    /// Resolves the sample label from the input path and header and uses it both for
    /// metric rows and as the plot-title prefix.
    fn initialize(&mut self, header: &Header) -> Result<()> {
        self.sample = derive_sample(&self.input, header);
        self.plot_title_prefix.clone_from(&self.sample);
        Ok(())
    }

    /// Adds one record to the per-cycle and quality-histogram accumulators.
    ///
    /// Secondary, supplementary and QC-failed records are ignored, as are records with no
    /// usable positions. Only the first `min(sequence length, quality length)` positions
    /// are counted.
    fn accept(&mut self, record: &RikerRecord, _header: &Header) -> Result<()> {
        let flags = record.flags();
        let excluded = flags.is_secondary() || flags.is_supplementary() || flags.is_qc_fail();
        if excluded {
            return Ok(());
        }

        let seq: &[u8] = record.sequence();
        let quals: &[u8] = record.quality_scores();
        let usable = seq.len().min(quals.len());
        if usable == 0 {
            return Ok(());
        }
        let (seq, quals) = (&seq[..usable], &quals[..usable]);

        // Only the last segment of a segmented record counts as R2; everything else
        // (first-of-pair and unpaired reads) goes on the R1 track.
        let cycles = if flags.is_segmented() && flags.is_last_segment() {
            &mut self.r2_cycles
        } else {
            &mut self.r1_cycles
        };
        Self::ensure_capacity(cycles, usable);

        // Reverse-complemented records are walked with cycles counted from the far end.
        if flags.is_reverse_complemented() {
            Self::process_record::<true>(seq, quals, cycles, &mut self.qual_counts);
        } else {
            Self::process_record::<false>(seq, quals, cycles, &mut self.qual_counts);
        }
        Ok(())
    }

    /// Builds the three metric tables, writes each as TSV, then renders the three plots.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first failure from writing a table or a plot.
    fn finish(&mut self) -> Result<()> {
        let base_dist = self.build_base_dist_metrics();
        let mean_qual = self.build_mean_qual_metrics();
        let qual_dist = self.build_qual_dist_metrics();

        write_tsv(&self.base_dist_path, &base_dist)?;
        write_tsv(&self.mean_qual_path, &mean_qual)?;
        write_tsv(&self.qual_dist_path, &qual_dist)?;

        self.plot_base_distribution(&base_dist)?;
        self.plot_mean_quality(&mean_qual)?;
        self.plot_qual_distribution(&qual_dist)
    }

    /// Name of this collector.
    fn name(&self) -> &'static str {
        "basic"
    }

    /// Record fields this collector reads.
    ///
    /// NOTE(review): `accept` also reads `quality_scores()`; confirm that `with_sequence()`
    /// makes qualities available as well.
    fn field_needs(&self) -> RikerRecordRequirements {
        RikerRecordRequirements::NONE.with_sequence()
    }
}
// One row of the base-distribution-by-cycle table: the fraction of each base call seen at
// one cycle of one read end. Field names and order define the serialized columns.
// NOTE(review): `//` is used rather than `///` so any doc-reading derive (e.g.
// `MetricDocs`) sees no change; consider promoting these to `///` field docs.
#[derive(Debug, Serialize, Deserialize, MetricDocs)]
pub struct BaseDistributionByCycleMetric {
    // Sample label derived when the collector is initialized.
    pub sample: String,
    // 1 for R1 (including unpaired reads), 2 for the last segment of a paired record.
    pub read_end: u8,
    // 1-based cycle number within this read end.
    pub cycle: usize,
    // Fraction of bases at this cycle that are A (serialized via `serialize_f64_5dp`).
    #[serde(serialize_with = "crate::metrics::serialize_f64_5dp")]
    pub frac_a: f64,
    // Fraction of bases at this cycle that are C.
    #[serde(serialize_with = "crate::metrics::serialize_f64_5dp")]
    pub frac_c: f64,
    // Fraction of bases at this cycle that are G.
    #[serde(serialize_with = "crate::metrics::serialize_f64_5dp")]
    pub frac_g: f64,
    // Fraction of bases at this cycle that are T.
    #[serde(serialize_with = "crate::metrics::serialize_f64_5dp")]
    pub frac_t: f64,
    // Fraction of bases that are anything other than A/C/G/T (N, IUPAC codes, `=`).
    #[serde(serialize_with = "crate::metrics::serialize_f64_5dp")]
    pub frac_n: f64,
}
// One row of the mean-quality-by-cycle table. Field names and order define the serialized
// columns.
// NOTE(review): `//` is used rather than `///` so any doc-reading derive (e.g.
// `MetricDocs`) sees no change; consider promoting these to `///` field docs.
#[derive(Debug, Serialize, Deserialize, MetricDocs)]
pub struct MeanQualityByCycleMetric {
    // Sample label derived when the collector is initialized.
    pub sample: String,
    // 1-based cycle; R2 cycles continue after the last R1 cycle instead of restarting at 1.
    pub cycle: usize,
    // Mean quality of all bases (including non-ACGT) observed at this cycle
    // (serialized via `serialize_f64_2dp`).
    #[serde(serialize_with = "crate::metrics::serialize_f64_2dp")]
    pub mean_quality: f64,
}
// One row of the quality-score distribution table. Field names and order define the
// serialized columns.
// NOTE(review): `//` is used rather than `///` so any doc-reading derive (e.g.
// `MetricDocs`) sees no change; consider promoting these to `///` field docs.
#[derive(Debug, Serialize, Deserialize, MetricDocs)]
pub struct QualityScoreDistributionMetric {
    // Sample label derived when the collector is initialized.
    pub sample: String,
    // Quality value (0..=127; raw values are masked with 0x7F when counted).
    pub quality: u8,
    // Number of A/C/G/T bases, across both read ends, with this quality.
    pub count: u64,
    // `count` divided by the total number of counted bases
    // (serialized via `serialize_f64_5dp`).
    #[serde(serialize_with = "crate::metrics::serialize_f64_5dp")]
    pub frac_bases: f64,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Slot used for `N` bases.
    const IDX_N: usize = (b'N' & BASE_BITS) as usize;

    /// Builds a collector with sample label `s` for exercising the metric builders.
    fn collector() -> BasicCollector {
        let mut c =
            BasicCollector::new(std::path::Path::new("in.bam"), std::path::Path::new("out"));
        c.sample = "s".to_string();
        c
    }

    /// Builds a `CycleStats` holding `count` bases in `slot` with the given quality sum.
    fn stats_with(slot: usize, count: u64, qual_sum: u64) -> CycleStats {
        let mut base_counts = [0u64; N_BASE_SLOTS];
        base_counts[slot] = count;
        CycleStats { qual_sum, base_counts }
    }

    #[test]
    fn test_base_indices_distinct() {
        let idxs = [IDX_A, IDX_C, IDX_G, IDX_T];
        for (i, a) in idxs.iter().enumerate() {
            for b in &idxs[i + 1..] {
                assert_ne!(a, b, "base index collision: {a} == {b}");
            }
            assert!(*a < N_BASE_SLOTS);
        }
    }

    #[test]
    fn test_base_mask_is_case_insensitive() {
        for (upper, lower) in [(b'A', b'a'), (b'C', b'c'), (b'G', b'g'), (b'T', b't')] {
            assert_eq!(upper & BASE_BITS, lower & BASE_BITS);
        }
    }

    #[test]
    fn test_base_mask_no_collisions_with_iupac_or_n() {
        let canonical = [(b'A', IDX_A), (b'C', IDX_C), (b'G', IDX_G), (b'T', IDX_T)];
        let non_acgt = b"=MRSVWYHKDBN";
        for &b in non_acgt {
            let slot = (b & BASE_BITS) as usize;
            for &(_, canon) in &canonical {
                assert_ne!(
                    slot, canon,
                    "non-ACGT byte 0x{b:02x} ({}) collides with canonical slot {canon}",
                    b as char
                );
            }
        }
        assert_eq!(b'N' & BASE_BITS, b'n' & BASE_BITS);
    }

    #[test]
    fn test_qual_dist_histogram_bins_includes_top_quality_bin() {
        let metrics = vec![
            QualityScoreDistributionMetric {
                sample: "s".to_string(),
                quality: 30,
                count: 100,
                frac_bases: 0.4,
            },
            QualityScoreDistributionMetric {
                sample: "s".to_string(),
                quality: 60,
                count: 150,
                frac_bases: 0.6,
            },
        ];
        let (edges, counts) = BasicCollector::qual_dist_histogram_bins(&metrics);
        assert_eq!(edges.len(), counts.len() + 1);
        assert_eq!(counts.len(), 61);
        assert!((counts[60] - 0.6).abs() < 1e-9);
        assert!((counts[30] - 0.4).abs() < 1e-9);
        assert!((edges[0] - 0.0).abs() < 1e-9);
        assert!((edges[edges.len() - 1] - 61.0).abs() < 1e-9);
    }

    #[test]
    fn test_qual_dist_histogram_bins_extends_to_q45_floor() {
        let metrics = vec![QualityScoreDistributionMetric {
            sample: "s".to_string(),
            quality: 10,
            count: 100,
            frac_bases: 1.0,
        }];
        let (edges, counts) = BasicCollector::qual_dist_histogram_bins(&metrics);
        assert_eq!(counts.len(), 46);
        assert_eq!(edges.len(), 47);
        assert!((edges[edges.len() - 1] - 46.0).abs() < 1e-9);
    }

    #[test]
    fn test_acgt_bitmask_membership() {
        for &b in b"ACGTacgt" {
            let bi = (b & BASE_BITS) as usize;
            assert_ne!(
                (ACGT_BITMASK >> bi) & 1,
                0,
                "byte 0x{b:02x} ({}) should be in ACGT_BITMASK",
                b as char
            );
        }
        for &b in b"NnWSMKRYBDHV=" {
            let bi = (b & BASE_BITS) as usize;
            assert_eq!(
                (ACGT_BITMASK >> bi) & 1,
                0,
                "non-ACGT byte 0x{b:02x} ({}) should not be in ACGT_BITMASK",
                b as char
            );
        }
    }

    /// Forward reads count position `i` at cycle `i`.
    #[test]
    fn test_process_record_forward_maps_positions_to_cycles() {
        let mut cycles = vec![CycleStats::default(); 3];
        let mut qual_counts = [[0u64; 128]; 4];
        BasicCollector::process_record::<false>(
            b"ACN",
            &[10, 20, 30],
            &mut cycles,
            &mut qual_counts,
        );
        assert_eq!(cycles[0].base_counts[IDX_A], 1);
        assert_eq!(cycles[1].base_counts[IDX_C], 1);
        assert_eq!(cycles[2].base_counts[IDX_N], 1);
        assert_eq!(cycles[0].qual_sum, 10);
        assert_eq!(cycles[2].qual_sum, 30);
    }

    /// Reverse-complemented reads count position `i` at cycle `n - 1 - i`.
    #[test]
    fn test_process_record_reverse_maps_positions_to_cycles_backwards() {
        let mut cycles = vec![CycleStats::default(); 3];
        let mut qual_counts = [[0u64; 128]; 4];
        BasicCollector::process_record::<true>(
            b"ACN",
            &[10, 20, 30],
            &mut cycles,
            &mut qual_counts,
        );
        assert_eq!(cycles[2].base_counts[IDX_A], 1);
        assert_eq!(cycles[1].base_counts[IDX_C], 1);
        assert_eq!(cycles[0].base_counts[IDX_N], 1);
        assert_eq!(cycles[2].qual_sum, 10);
        assert_eq!(cycles[0].qual_sum, 30);
    }

    /// N bases feed the per-cycle quality sum but not the quality histogram.
    #[test]
    fn test_process_record_excludes_non_acgt_from_quality_histogram() {
        let mut cycles = vec![CycleStats::default(); 4];
        let mut qual_counts = [[0u64; 128]; 4];
        BasicCollector::process_record::<false>(
            b"ANGN",
            &[30, 30, 30, 30],
            &mut cycles,
            &mut qual_counts,
        );
        let q30: u64 = qual_counts.iter().map(|bank| bank[30]).sum();
        assert_eq!(q30, 2);
        let qual_total: u64 = cycles.iter().map(|c| c.qual_sum).sum();
        assert_eq!(qual_total, 120);
    }

    /// Mean-quality rows number R2 cycles after the last R1 cycle.
    #[test]
    fn test_mean_qual_metrics_number_r2_cycles_after_r1() {
        let mut c = collector();
        c.r1_cycles = vec![stats_with(IDX_A, 2, 60), stats_with(IDX_C, 1, 35)];
        c.r2_cycles = vec![stats_with(IDX_G, 4, 80)];
        let metrics = c.build_mean_qual_metrics();
        let cycles: Vec<usize> = metrics.iter().map(|m| m.cycle).collect();
        assert_eq!(cycles, vec![1, 2, 3]);
        assert!((metrics[0].mean_quality - 30.0).abs() < 1e-9);
        assert!((metrics[1].mean_quality - 35.0).abs() < 1e-9);
        assert!((metrics[2].mean_quality - 20.0).abs() < 1e-9);
    }

    /// Base-distribution rows restart cycle numbering per read end and lump non-ACGT into N.
    #[test]
    fn test_base_dist_metrics_restart_cycles_per_read_end() {
        let mut c = collector();
        let mut counts = [0u64; N_BASE_SLOTS];
        counts[IDX_A] = 2;
        counts[IDX_T] = 1;
        counts[IDX_N] = 1;
        c.r1_cycles = vec![CycleStats { qual_sum: 0, base_counts: counts }];
        c.r2_cycles = vec![stats_with(IDX_G, 3, 0)];
        let metrics = c.build_base_dist_metrics();
        assert_eq!(metrics.len(), 2);
        assert_eq!((metrics[0].read_end, metrics[0].cycle), (1, 1));
        assert_eq!((metrics[1].read_end, metrics[1].cycle), (2, 1));
        assert!((metrics[0].frac_a - 0.5).abs() < 1e-9);
        assert!((metrics[0].frac_t - 0.25).abs() < 1e-9);
        assert!((metrics[0].frac_n - 0.25).abs() < 1e-9);
        assert!((metrics[1].frac_g - 1.0).abs() < 1e-9);
    }

    /// The quality distribution sums all four histogram banks.
    #[test]
    fn test_qual_dist_metrics_sum_all_banks() {
        let mut c = collector();
        c.qual_counts[0][30] = 1;
        c.qual_counts[3][30] = 2;
        c.qual_counts[1][40] = 1;
        let metrics = c.build_qual_dist_metrics();
        assert_eq!(metrics.len(), 2);
        assert_eq!((metrics[0].quality, metrics[0].count), (30, 3));
        assert_eq!((metrics[1].quality, metrics[1].count), (40, 1));
        assert!((metrics[0].frac_bases - 0.75).abs() < 1e-9);
        assert!((metrics[1].frac_bases - 0.25).abs() < 1e-9);
    }
}