use noodles_bam as bam;
use noodles_core::{Position, Region};
use rayon::prelude::*;
use std::fs::File;
use std::io::BufReader;
use std::sync::Arc;
/// Summary statistics over the reads of one alignment file.
///
/// Every field is filled in by the code that constructs the value; the
/// struct itself enforces no invariants between fields.
#[derive(Debug, Clone)]
pub struct ReadStats {
    /// N50 of the read lengths, in bases.
    pub n50: u32,
    /// Arithmetic mean of the read lengths.
    pub mean_len: f64,
    /// Median of the read lengths.
    pub median_len: f64,
    /// Mean of the per-read average quality values.
    pub mean_qual: f64,
    /// Median of the per-read average quality values.
    pub median_qual: f64,
    /// Number of reads that contributed to the statistics.
    pub num_reads: u64,
    /// Total number of bases across all counted reads.
    pub num_bases: u64,
    /// The individual read lengths, when the producer chose to keep them.
    pub lengths: Option<Vec<u32>>,
}
/// Raw per-read measurements gathered from one slice of the input (a single
/// reference sequence, or the unmapped reads) before they are merged into
/// the final `ReadStats`.
#[derive(Default)]
struct PartialStats {
/// Sequence length of every read seen, in encounter order.
lengths: Vec<u32>,
/// Mean quality of each read; reads with no quality scores add no entry,
/// so this may be shorter than `lengths`.
quals: Vec<f64>,
/// Number of reads seen.
num_reads: u64,
/// Sum of the lengths of all reads seen.
num_bases: u64,
}
/// Result type for work executed on the rayon pool: the boxed error carries
/// `Send + Sync` bounds so a failure can be handed back across threads.
type StatResult<T> = Result<T, Box<dyn std::error::Error + Send + Sync>>;
#[hotpath::measure]
pub fn extract_read_stats(
bam_path: &std::path::Path,
) -> Result<ReadStats, Box<dyn std::error::Error>> {
let mut reader = bam::io::Reader::new(BufReader::new(File::open(bam_path)?));
let header = reader.read_header()?;
let reference_sequences: Vec<(String, u32)> = header
.reference_sequences()
.iter()
.map(|(name, ref_seq)| (name.to_string(), ref_seq.length().get() as u32))
.collect();
let header = Arc::new(header);
let index = Arc::new(bam::bai::fs::read(bam_path.with_extension("bam.bai"))?);
let worker_count = std::cmp::max(1, num_cpus::get() / 2);
let pool = rayon::ThreadPoolBuilder::new()
.num_threads(worker_count)
.build()?;
let partials_result: StatResult<Vec<PartialStats>> =
pool.install(|| -> StatResult<Vec<PartialStats>> {
reference_sequences
.par_iter()
.map(|(chrom, chrom_length)| -> StatResult<PartialStats> {
let mut reader = bam::io::Reader::new(BufReader::new(File::open(bam_path)?));
let region = Region::new(
chrom.clone(),
Position::try_from(1usize)
.map_err(|e| format!("Invalid start position: {e}"))?
..=Position::try_from(*chrom_length as usize)
.map_err(|e| format!("Invalid end position: {e}"))?,
);
let query = reader.query(header.as_ref(), index.as_ref(), ®ion)?;
let mut stats = PartialStats::default();
for result in query {
let record = result?;
let len = record.sequence().len() as u32;
stats.lengths.push(len);
stats.num_reads += 1;
stats.num_bases += len as u64;
let qual = record.quality_scores();
let qual_slice = qual.as_ref();
if !qual_slice.is_empty() {
let q: f64 = qual_slice.iter().map(|&q| q as u32).sum::<u32>() as f64
/ qual_slice.len() as f64;
stats.quals.push(q);
}
}
Ok(stats)
})
.collect()
});
let mut partials = partials_result.map_err(|e| e as Box<dyn std::error::Error>)?;
{
let mut reader = bam::io::Reader::new(BufReader::new(File::open(bam_path)?));
let query = reader.query_unmapped(index.as_ref())?;
let mut unmapped = PartialStats::default();
for result in query {
let record = result?;
let len = record.sequence().len() as u32;
unmapped.lengths.push(len);
unmapped.num_reads += 1;
unmapped.num_bases += len as u64;
let qual = record.quality_scores();
let qual_slice = qual.as_ref();
if !qual_slice.is_empty() {
let q: f64 = qual_slice.iter().map(|&q| q as u32).sum::<u32>() as f64
/ qual_slice.len() as f64;
unmapped.quals.push(q);
}
}
partials.push(unmapped);
}
let mut lengths = Vec::new();
let mut quals = Vec::new();
let mut num_reads: u64 = 0;
let mut num_bases: u64 = 0;
for part in partials {
lengths.extend(part.lengths);
quals.extend(part.quals);
num_reads += part.num_reads;
num_bases += part.num_bases;
}
lengths.sort_unstable_by(|a, b| b.cmp(a));
let total: u64 = lengths.iter().map(|&l| l as u64).sum();
let mut acc = 0;
let mut n50 = 0;
for &l in &lengths {
acc += l as u64;
if acc >= total / 2 {
n50 = l;
break;
}
}
let mean_len = if lengths.is_empty() {
0.0
} else {
lengths.iter().sum::<u32>() as f64 / lengths.len() as f64
};
let median_len = if lengths.is_empty() {
0.0
} else {
let mid = lengths.len() / 2;
if lengths.len() % 2 == 0 {
(lengths[mid - 1] + lengths[mid]) as f64 / 2.0
} else {
lengths[mid] as f64
}
};
let mean_qual = if quals.is_empty() {
0.0
} else {
quals.iter().sum::<f64>() / quals.len() as f64
};
let median_qual = if quals.is_empty() {
0.0
} else {
let mut sorted = quals.clone();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
let mid = sorted.len() / 2;
if sorted.len() % 2 == 0 {
(sorted[mid - 1] + sorted[mid]) / 2.0
} else {
sorted[mid]
}
};
Ok(ReadStats {
n50,
mean_len,
median_len,
mean_qual,
median_qual,
num_reads,
num_bases,
lengths: Some(lengths),
})
}