use ahash::{HashMap, HashSet};
use anyhow::{Context, Result};
use comfy_table::{
Cell, ContentArrangement, Table, modifiers::UTF8_ROUND_CORNERS, presets::UTF8_FULL,
};
use noodles::sam;
use noodles::sam::alignment::record::data::field::Value;
use noodles::sam::alignment::record::data::field::tag::Tag;
use rust_lapper::Lapper;
use std::fmt::Display;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::bam_utils::CB;
/// Atomic tallies of how many reads `BamReadFilter::is_valid` has examined and,
/// for every rejected read, which check rejected it.
///
/// A rejected read is counted under the first check it fails; the counters are
/// updated with relaxed atomic operations.
#[derive(Debug)]
pub struct BamReadFilterStats {
/// Number of reads passed to `is_valid`, whether or not they were kept.
n_total: AtomicU64,
/// Reads whose flags could not be read, or that were not flagged as properly
/// segmented while the proper-pair requirement was enabled.
n_failed_proper_pair: AtomicU64,
/// Reads that were unmapped, had a mapping quality below the minimum (or an
/// unreadable one), or had a missing/unreadable alignment start.
n_failed_mapq: AtomicU64,
/// Reads whose sequence length fell outside the configured length bounds.
n_failed_length: AtomicU64,
/// Reads that overlapped a blacklisted interval on their reference sequence.
n_failed_blacklist: AtomicU64,
/// Reads with no cell barcode, or with a barcode absent from the allowlist.
n_failed_barcode: AtomicU64,
/// Reads whose read-group tag was not a string or differed from the requested one.
n_not_in_read_group: AtomicU64,
/// Reads on the opposite strand to the one requested.
n_incorrect_strand: AtomicU64,
/// Reads rejected by the custom tag filter (bad tag name, missing tag,
/// non-string value, or different value).
n_failed_tag_filter: AtomicU64,
/// Reads whose absolute template length fell outside the fragment-length bounds.
n_failed_fragment_length: AtomicU64,
}
impl Default for BamReadFilterStats {
fn default() -> Self {
Self::new()
}
}
impl BamReadFilterStats {
    /// Creates a set of counters that all start at zero.
    pub fn new() -> Self {
        let zero = || AtomicU64::new(0);
        Self {
            n_total: zero(),
            n_failed_proper_pair: zero(),
            n_failed_mapq: zero(),
            n_failed_length: zero(),
            n_failed_blacklist: zero(),
            n_failed_barcode: zero(),
            n_not_in_read_group: zero(),
            n_incorrect_strand: zero(),
            n_failed_tag_filter: zero(),
            n_failed_fragment_length: zero(),
        }
    }

    /// Copies the current value of every counter into a plain-integer snapshot.
    ///
    /// Each counter is read independently with a relaxed load, starting with
    /// the total, so the result is not an atomic view across all counters.
    pub fn snapshot(&self) -> BamReadFilterStatsSnapshot {
        let read = |counter: &AtomicU64| counter.load(Ordering::Relaxed);
        BamReadFilterStatsSnapshot {
            n_total: read(&self.n_total),
            n_failed_proper_pair: read(&self.n_failed_proper_pair),
            n_failed_mapq: read(&self.n_failed_mapq),
            n_failed_length: read(&self.n_failed_length),
            n_failed_blacklist: read(&self.n_failed_blacklist),
            n_failed_barcode: read(&self.n_failed_barcode),
            n_not_in_read_group: read(&self.n_not_in_read_group),
            n_incorrect_strand: read(&self.n_incorrect_strand),
            n_failed_tag_filter: read(&self.n_failed_tag_filter),
            n_failed_fragment_length: read(&self.n_failed_fragment_length),
        }
    }
}
/// Plain-integer copy of the `BamReadFilterStats` counters, as produced by
/// `BamReadFilterStats::snapshot`.
///
/// Each field holds the value its same-named atomic counter had when it was read.
#[derive(Clone, Copy, Debug)]
pub struct BamReadFilterStatsSnapshot {
/// Reads examined by the filter.
n_total: u64,
/// Reads rejected by the proper-pair check (or with unreadable flags).
n_failed_proper_pair: u64,
/// Reads rejected as unmapped, low/unreadable MAPQ, or lacking an alignment start.
n_failed_mapq: u64,
/// Reads rejected by the read-length bounds.
n_failed_length: u64,
/// Reads rejected for overlapping a blacklisted interval.
n_failed_blacklist: u64,
/// Reads rejected by the barcode allowlist.
n_failed_barcode: u64,
/// Reads rejected by the read-group check.
n_not_in_read_group: u64,
/// Reads rejected for being on the wrong strand.
n_incorrect_strand: u64,
/// Reads rejected by the custom tag filter.
n_failed_tag_filter: u64,
/// Reads rejected by the fragment-length bounds.
n_failed_fragment_length: u64,
}
impl Display for BamReadFilterStatsSnapshot {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let rows = self.stage_rows();
let mut table = Table::new();
table
.load_preset(UTF8_FULL)
.apply_modifier(UTF8_ROUND_CORNERS)
.set_content_arrangement(ContentArrangement::Dynamic)
.set_header(vec!["stage", "remain", "drop", "% total"]);
for (label, remain, dropped, pct) in rows {
let dropped = if dropped == 0 {
"-".to_string()
} else {
format!("-{}", Self::format_count(dropped))
};
table.add_row(vec![
Cell::new(label),
Cell::new(Self::format_count(remain))
.set_alignment(comfy_table::CellAlignment::Right),
Cell::new(dropped).set_alignment(comfy_table::CellAlignment::Right),
Cell::new(format!("{pct:.1}%")).set_alignment(comfy_table::CellAlignment::Right),
]);
}
writeln!(f, "Read Filtering Funnel")?;
write!(f, "{table}")?;
Ok(())
}
}
impl BamReadFilterStatsSnapshot {
    /// Formats `count` in decimal with a comma between every group of three
    /// digits, e.g. `1234567` becomes `"1,234,567"`.
    fn format_count(count: u64) -> String {
        let digits = count.to_string();
        let len = digits.len();
        let mut out = String::with_capacity(len + len / 3);
        for (i, ch) in digits.chars().enumerate() {
            // A separator goes wherever a multiple of three digits remains.
            if i > 0 && (len - i) % 3 == 0 {
                out.push(',');
            }
            out.push(ch);
        }
        out
    }

    /// Sum of every per-reason failure counter.
    fn n_failed_total(&self) -> u64 {
        [
            self.n_failed_proper_pair,
            self.n_failed_mapq,
            self.n_failed_length,
            self.n_failed_blacklist,
            self.n_failed_barcode,
            self.n_not_in_read_group,
            self.n_incorrect_strand,
            self.n_failed_tag_filter,
            self.n_failed_fragment_length,
        ]
        .iter()
        .fold(0u64, |acc, &n| acc.saturating_add(n))
    }

    /// Number of reads that were rejected by any filter.
    pub fn n_filtered(&self) -> u64 {
        // Cannot underflow: the retained count is never larger than `n_total`.
        self.n_total - self.n_reads_after_filtering()
    }

    /// Percentage (0–100) of examined reads that passed every filter; `0.0`
    /// when no reads were examined.
    pub fn pass_rate(&self) -> f64 {
        self.pct_of_total(self.n_reads_after_filtering())
    }

    /// `count` as a percentage of `n_total`, or `0.0` when `n_total` is zero.
    fn pct_of_total(&self, count: u64) -> f64 {
        if self.n_total == 0 {
            0.0
        } else {
            (count as f64 / self.n_total as f64) * 100.0
        }
    }

    /// Builds the filtering funnel as `(label, remaining, dropped, % of total)` rows.
    ///
    /// The first row is the initial read count and the last is the retained
    /// count; in between there is one row per filter that dropped at least one
    /// read, in the order the filters are listed here.
    pub fn stage_rows(&self) -> Vec<(&'static str, u64, u64, f64)> {
        let mut remaining = self.n_total;
        let mut rows = vec![("Initial reads", remaining, 0, 100.0)];
        let failures = [
            ("Pass strand filter", self.n_incorrect_strand),
            ("Pass proper-pair filter", self.n_failed_proper_pair),
            ("Pass minimum MAPQ", self.n_failed_mapq),
            ("Pass read-length filter", self.n_failed_length),
            ("Pass fragment-length filter", self.n_failed_fragment_length),
            ("Outside blacklist", self.n_failed_blacklist),
            ("Pass barcode allowlist", self.n_failed_barcode),
            ("Pass read-group filter", self.n_not_in_read_group),
            ("Pass tag filter", self.n_failed_tag_filter),
        ];
        for (label, dropped) in failures {
            if dropped > 0 {
                // The counters are sampled independently, so a snapshot taken
                // while reads are still being filtered can hold failure counts
                // that exceed the total; clamp at zero instead of underflowing.
                remaining = remaining.saturating_sub(dropped);
                rows.push((label, remaining, dropped, self.pct_of_total(remaining)));
            }
        }
        rows.push((
            "Reads retained",
            self.n_reads_after_filtering(),
            0,
            self.pass_rate(),
        ));
        rows
    }

    /// Number of reads that passed every filter.
    ///
    /// Clamped at zero when the failure counters add up to more than `n_total`,
    /// which can happen for a snapshot taken during concurrent filtering.
    pub fn n_reads_after_filtering(&self) -> u64 {
        self.n_total.saturating_sub(self.n_failed_total())
    }

    /// Reads examined by the filter.
    pub fn n_total(&self) -> u64 {
        self.n_total
    }

    /// Reads rejected as unmapped, low/unreadable MAPQ, or lacking an alignment start.
    pub fn n_failed_mapq(&self) -> u64 {
        self.n_failed_mapq
    }

    /// Reads rejected by the read-length bounds.
    pub fn n_failed_length(&self) -> u64 {
        self.n_failed_length
    }

    /// Reads rejected for being on the wrong strand.
    pub fn n_incorrect_strand(&self) -> u64 {
        self.n_incorrect_strand
    }

    /// Reads rejected by the proper-pair check.
    pub fn n_failed_proper_pair(&self) -> u64 {
        self.n_failed_proper_pair
    }

    /// Reads rejected for overlapping a blacklisted interval.
    pub fn n_failed_blacklist(&self) -> u64 {
        self.n_failed_blacklist
    }

    /// Reads rejected by the barcode allowlist.
    pub fn n_failed_barcode(&self) -> u64 {
        self.n_failed_barcode
    }

    /// Reads rejected by the read-group check.
    pub fn n_not_in_read_group(&self) -> u64 {
        self.n_not_in_read_group
    }

    /// Reads rejected by the custom tag filter.
    pub fn n_failed_tag_filter(&self) -> u64 {
        self.n_failed_tag_filter
    }

    /// Reads rejected by the fragment-length bounds.
    pub fn n_failed_fragment_length(&self) -> u64 {
        self.n_failed_fragment_length
    }
}
/// Criteria used by `is_valid` to decide whether an alignment record is kept,
/// together with running statistics about the decisions made.
#[derive(Debug, Clone)]
pub struct BamReadFilter {
/// Required strand; `Forward` rejects reverse-complemented reads and
/// `Reverse` rejects the others.
strand: bio_types::strand::Strand,
/// When `true`, reads not flagged as properly segmented are rejected.
proper_pair: bool,
/// Minimum mapping quality a read must have when it reports one.
min_mapq: u8,
/// Smallest accepted sequence length (inclusive).
min_length: u32,
/// Largest accepted sequence length (inclusive).
max_length: u32,
/// Optional blacklisted intervals, keyed by the read's reference sequence id.
blacklisted_locations: Option<HashMap<usize, Lapper<usize, u32>>>,
/// Optional set of accepted cell barcodes.
whitelisted_barcodes: Option<HashSet<String>>,
/// Optional read group that a read's read-group tag must equal.
read_group: Option<String>,
/// Name of a tag to filter on; only used when `filter_tag_value` is also set.
filter_tag: Option<String>,
/// String value the `filter_tag` tag must hold.
filter_tag_value: Option<String>,
/// Optional lower bound on the absolute template length.
min_fragment_length: Option<u32>,
/// Optional upper bound on the absolute template length.
max_fragment_length: Option<u32>,
/// Counters updated by `is_valid`; held in an `Arc`, so clones of the filter
/// share the same counters.
stats: Arc<BamReadFilterStats>,
}
/// Writes a tab-indented summary of the filter settings, one setting per line.
impl Display for BamReadFilter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "\tOnly allowing reads with strand: {}", self.strand)?;
        writeln!(f, "\tProper pair: {}", self.proper_pair)?;
        writeln!(f, "\tMinimum mapping quality: {}", self.min_mapq)?;
        writeln!(f, "\tMinimum read length: {}", self.min_length)?;
        writeln!(f, "\tMaximum read length: {}", self.max_length)?;

        // Total interval count across all reference sequences; 0 without a blacklist.
        let n_blacklisted: usize = self.blacklisted_locations.as_ref().map_or(0, |per_chrom| {
            per_chrom
                .values()
                .map(|intervals| intervals.len())
                .sum::<usize>()
        });
        writeln!(f, "\tNumber of Blacklisted locations: {}", n_blacklisted)?;

        let n_barcodes = self
            .whitelisted_barcodes
            .as_ref()
            .map_or(0, |barcodes| barcodes.len());
        writeln!(f, "\tNumber of Whitelisted barcodes: {}", n_barcodes)?;

        // The tag filter is only reported when both the tag and its value are set.
        if let (Some(tag), Some(value)) = (&self.filter_tag, &self.filter_tag_value) {
            writeln!(f, "\tFilter tag: {} = {}", tag, value)
        } else {
            writeln!(f, "\tFilter tag: None")
        }
    }
}
/// The default filter accepts either strand, requires proper pairs, sets no
/// MAPQ floor, accepts read lengths from 0 to 1000, and enables none of the
/// optional filters.
impl Default for BamReadFilter {
    fn default() -> Self {
        let strand = bio_types::strand::Strand::Unknown;
        let require_proper_pair = true;
        let (min_mapq, min_length, max_length) = (Some(0), Some(0), Some(1000));

        Self::new(
            strand,
            require_proper_pair,
            min_mapq,
            min_length,
            max_length,
            None, // read_group
            None, // blacklisted_locations
            None, // whitelisted_barcodes
            None, // filter_tag
            None, // filter_tag_value
            None, // min_fragment_length
            None, // max_fragment_length
        )
    }
}
impl BamReadFilter {
#[allow(clippy::too_many_arguments)]
pub fn new(
strand: bio_types::strand::Strand,
proper_pair: bool,
min_mapq: Option<u8>,
min_length: Option<u32>,
max_length: Option<u32>,
read_group: Option<String>,
blacklisted_locations: Option<HashMap<usize, Lapper<usize, u32>>>,
whitelisted_barcodes: Option<HashSet<String>>,
filter_tag: Option<String>,
filter_tag_value: Option<String>,
min_fragment_length: Option<u32>,
max_fragment_length: Option<u32>,
) -> Self {
let min_mapq = min_mapq.unwrap_or(0);
let min_length = min_length.unwrap_or(0);
let max_length = max_length.unwrap_or(u32::MAX);
Self {
strand,
proper_pair,
min_mapq,
min_length,
max_length,
blacklisted_locations,
whitelisted_barcodes,
read_group,
filter_tag,
filter_tag_value,
min_fragment_length,
max_fragment_length,
stats: Arc::new(BamReadFilterStats::new()),
}
}
pub fn is_valid<R>(&self, alignment: &R, header: Option<&sam::Header>) -> Result<bool>
where
R: sam::alignment::Record,
{
self.stats.n_total.fetch_add(1, Ordering::Relaxed);
let flags = match alignment.flags() {
Ok(flags) => flags,
Err(_) => {
self.stats
.n_failed_proper_pair
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
};
match flags.is_reverse_complemented() {
true => {
if self.strand == bio_types::strand::Strand::Forward {
self.stats
.n_incorrect_strand
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
}
false => {
if self.strand == bio_types::strand::Strand::Reverse {
self.stats
.n_incorrect_strand
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
}
}
if self.proper_pair && !flags.is_properly_segmented() {
self.stats
.n_failed_proper_pair
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
if flags.is_unmapped() {
self.stats.n_failed_mapq.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
match alignment.mapping_quality() {
Some(Ok(mapping_quality)) => {
if mapping_quality.get() < self.min_mapq {
self.stats.n_failed_mapq.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
}
None => {
}
_ => {
self.stats.n_failed_mapq.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
}
let alignment_length = alignment.sequence().len();
if alignment_length < self.min_length as usize
|| alignment_length > self.max_length as usize
{
self.stats.n_failed_length.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
if self.min_fragment_length.is_some() || self.max_fragment_length.is_some() {
let tlen = alignment
.template_length()
.map_err(|e| anyhow::anyhow!(e))?
.unsigned_abs();
if let Some(min_flen) = self.min_fragment_length
&& tlen < min_flen
{
self.stats
.n_failed_fragment_length
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
if let Some(max_flen) = self.max_fragment_length
&& tlen > max_flen
{
self.stats
.n_failed_fragment_length
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
}
let header = header.expect("No header provided");
let chrom_id = alignment
.reference_sequence_id(header)
.context("Failed to get reference sequence ID")??;
let start = match alignment.alignment_start() {
Some(Ok(start)) => start.get(),
_ => {
self.stats.n_failed_mapq.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
};
let end = start + alignment_length;
if let Some(blacklisted_locations) = &self.blacklisted_locations
&& let Some(blacklist) = blacklisted_locations.get(&chrom_id)
&& blacklist.count(start, end) > 0
{
self.stats
.n_failed_blacklist
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
if let Some(barcodes) = &self.whitelisted_barcodes {
let barcode = get_cell_barcode(alignment);
match barcode {
Some(barcode) => {
if !barcodes.contains(&barcode) {
self.stats.n_failed_barcode.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
}
None => {
self.stats.n_failed_barcode.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
}
}
if let Some(read_group) = &self.read_group {
let rg = noodles::sam::alignment::record::data::field::tag::Tag::READ_GROUP;
let data = alignment.data();
let read_group_value = data.get(&rg).context("Failed to get read group")??;
match read_group_value {
Value::String(value) => {
if value != read_group {
self.stats
.n_not_in_read_group
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
}
_ => {
self.stats
.n_not_in_read_group
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
}
}
if let (Some(filter_tag), Some(filter_tag_value)) =
(&self.filter_tag, &self.filter_tag_value)
{
if filter_tag.len() != 2 {
self.stats
.n_failed_tag_filter
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
let tag_bytes = filter_tag.as_bytes();
let tag = Tag::new(tag_bytes[0], tag_bytes[1]);
let data = alignment.data();
let tag_value = data.get(&tag);
match tag_value {
Some(Ok(Value::String(value))) => {
if value != filter_tag_value {
self.stats
.n_failed_tag_filter
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
}
_ => {
self.stats
.n_failed_tag_filter
.fetch_add(1, Ordering::Relaxed);
return Ok(false);
}
}
}
Ok(true)
}
pub fn stats(&self) -> BamReadFilterStatsSnapshot {
self.stats.snapshot()
}
}
/// Returns the record's cell barcode (the `CB` tag) as an owned string, or
/// `None` when the tag is absent, unreadable, or not a string value.
fn get_cell_barcode<R>(alignment: &R) -> Option<String>
where
    R: sam::alignment::Record,
{
    let data = alignment.data();
    let cb_tag = Tag::from(CB);
    // Bind the lookup result to a local so it is dropped before `data`.
    let lookup = data.get(&cb_tag);
    match lookup {
        Some(Ok(Value::String(raw))) => Some(raw.to_string()),
        _ => None,
    }
}