use crate::bam_io::{RawBamWriter, create_raw_bam_reader, create_raw_bam_writer};
use crate::logging::OperationTimer;
use crate::progress::ProgressTracker;
use crate::sam::is_template_coordinate_sorted;
use crate::validation::validate_file_exists;
use anyhow::{Result, bail};
use clap::Parser;
use fgumi_raw_bam::{
RawBamReader, RawRecord, aux_data_slice, find_int_tag, find_string_tag, read_name,
};
use log::info;
use rand::SeedableRng;
use rand::rngs::StdRng;
use std::collections::{BTreeMap, HashSet};
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use crate::commands::command::Command;
use crate::commands::common::{BamIoOptions, CompressionOptions, parse_bool};
/// Command-line options for the `downsample` subcommand.
///
/// The explicit `about` / `long_about` values in the `#[command(...)]` attribute below
/// supply the help text for the command itself; the doc comments on the individual
/// fields describe each option.
#[derive(Debug, Parser)]
#[command(
name = "downsample",
about = "\x1b[38;5;166m[UTILITIES]\x1b[0m \x1b[36mDownsample BAM by UMI family using streaming\x1b[0m",
long_about = r#"
Downsample a BAM file by UMI family using a single-pass streaming algorithm.
This tool reads a BAM file that has been processed by fgumi group (or fgbio GroupReadsByUmi)
containing MI tags, uniformly samples UMI families, and outputs kept reads directly to a BAM file.
Requires input BAM to be in template-coordinate order:
- SO:unsorted (or not set)
- GO:query
- SS:unsorted:template-coordinate or SS:template-coordinate
The tool processes families in streaming fashion by grouping consecutive reads with the same
MI tag value. For each family, a random decision is made based on the fraction parameter to
either keep or reject all reads in that family.
Example usage:
fgumi downsample -i grouped.bam -o downsampled.bam -f 0.1 --seed 42
fgumi downsample -i grouped.bam -o kept.bam -f 0.5 --rejects rejected.bam
fgumi downsample -i grouped.bam -o kept.bam -f 0.1 --histogram-kept kept_hist.txt
"#
)]
pub struct Downsample {
/// Input and output BAM options; kept families are written to the output BAM.
#[command(flatten)]
pub io: BamIoOptions,
/// Probability of keeping each UMI family; must be greater than 0.0 and at most 1.0.
#[arg(short = 'f', long = "fraction")]
pub fraction: f64,
/// Optional BAM file that receives the reads of every family that was not kept.
#[arg(long = "rejects")]
pub rejects: Option<PathBuf>,
/// Seed for the random number generator; the same seed and input give the same selection.
#[arg(long = "seed")]
pub seed: Option<u64>,
/// Fail if an MI value is seen again after its family has ended (remembers every MI seen).
#[arg(long = "validate-mi-order", default_value = "false", num_args = 0..=1, default_missing_value = "true", action = clap::ArgAction::Set, value_parser = parse_bool)]
pub validate_mi_order: bool,
/// Optional path for a tab-separated histogram of family sizes among kept families.
#[arg(long = "histogram-kept")]
pub histogram_kept: Option<PathBuf>,
/// Optional path for a tab-separated histogram of family sizes among rejected families.
#[arg(long = "histogram-rejected")]
pub histogram_rejected: Option<PathBuf>,
/// Compression options; the compression level is applied to every BAM written.
#[command(flatten)]
pub compression: CompressionOptions,
}
impl Command for Downsample {
    /// Streams the input BAM once and keeps or rejects each MI family as a unit.
    ///
    /// The steps, in order:
    /// 1. Check that the input exists and that `fraction` lies in `(0.0, 1.0]`.
    /// 2. Open the input and require a template-coordinate header.
    /// 3. Pass the header through `add_pg_record` together with `command_line` and open
    ///    the output writer (plus the rejects writer when `--rejects` was given).
    /// 4. For every run of consecutive records sharing an MI value, draw one random
    ///    `f64` and keep the whole family when the draw is below `fraction`.
    /// 5. Write the optional histograms, finish the writers and log a summary.
    ///
    /// # Errors
    ///
    /// Returns an error when input validation fails, when `fraction` is NaN or outside
    /// `(0.0, 1.0]`, when the header is not template-coordinate sorted, when a record has
    /// no MI tag, when `--validate-mi-order` is set and an MI value reappears after its
    /// family ended, or when reading or writing any file fails.
    fn execute(&self, command_line: &str) -> Result<()> {
        validate_file_exists(&self.io.input, "Input BAM")?;

        // NaN fails every ordered comparison, so the two range checks alone would let
        // it through; sampling would then reject every family (`x < NaN` is never true).
        if self.fraction.is_nan() || self.fraction <= 0.0 || self.fraction > 1.0 {
            bail!(
                "--fraction must be between 0.0 (exclusive) and 1.0 (inclusive), got {}",
                self.fraction
            );
        }

        let timer = OperationTimer::new("Downsampling reads");
        info!("Starting Downsample");
        info!("Input: {}", self.io.input.display());
        info!("Output: {}", self.io.output.display());
        info!("Target fraction: {}", self.fraction);
        if let Some(seed) = self.seed {
            info!("Random seed: {seed}");
        }
        if self.validate_mi_order {
            info!("MI order validation: enabled");
        }

        // A fixed seed makes the keep/reject sequence reproducible.
        let mut rng = match self.seed {
            Some(seed) => StdRng::seed_from_u64(seed),
            None => rand::make_rng(),
        };

        let (reader, header) = create_raw_bam_reader(&self.io.input, 1)?;
        if !is_template_coordinate_sorted(&header) {
            bail!(
                "Input BAM must be template-coordinate sorted (output from group).\n\n\
                 Expected header fields: SO:unsorted, GO:query, SS:template-coordinate\n\n\
                 The input to this tool should be the output of fgumi group or fgbio GroupReadsByUmi."
            );
        }
        info!("Header validation passed (template-coordinate order confirmed)");

        let header = crate::commands::common::add_pg_record(header, command_line)?;
        let mut writer =
            create_raw_bam_writer(&self.io.output, &header, 1, self.compression.compression_level)?;
        // Only opened when a rejects path was supplied.
        let mut rejects_writer: Option<RawBamWriter> = self
            .rejects
            .as_ref()
            .map(|path| create_raw_bam_writer(path, &header, 1, self.compression.compression_level))
            .transpose()?;

        let mut total_families: u64 = 0;
        let mut kept_families: u64 = 0;
        let mut kept_reads: u64 = 0;
        let mut rejected_reads: u64 = 0;
        let mut record_count: usize = 0;
        let progress = ProgressTracker::new("Processed records").with_interval(1_000_000);
        // Family size -> number of families of that size, ordered by size.
        let mut hist_kept: BTreeMap<usize, u64> = BTreeMap::new();
        let mut hist_rejected: BTreeMap<usize, u64> = BTreeMap::new();
        // Every MI seen so far; only populated when order validation is enabled.
        let mut seen_mis: HashSet<String> = HashSet::new();

        info!("Processing reads...");
        let mut family_iter = FamilyIterator::new(raw_record_iter(reader));
        while let Some((mi, family)) = family_iter.next_family()? {
            total_families += 1;
            let family_size = family.len();

            if self.validate_mi_order {
                // Families are runs of consecutive records, so meeting an MI a second
                // time means its records were not contiguous in the input.
                if seen_mis.contains(&mi) {
                    bail!(
                        "MI tag '{mi}' seen non-consecutively. Input BAM may not be properly grouped by MI."
                    );
                }
                seen_mis.insert(mi);
            }

            // One draw per family: all of its reads share the same fate.
            let keep = rand::RngExt::random::<f64>(&mut rng) < self.fraction;
            record_count += family_size;

            if keep {
                kept_families += 1;
                kept_reads += family_size as u64;
                *hist_kept.entry(family_size).or_insert(0) += 1;
                for record in &family {
                    writer.write_raw_record(record.as_ref())?;
                }
            } else {
                rejected_reads += family_size as u64;
                *hist_rejected.entry(family_size).or_insert(0) += 1;
                if let Some(ref mut rw) = rejects_writer {
                    for record in &family {
                        rw.write_raw_record(record.as_ref())?;
                    }
                }
            }
            progress.log_if_needed(family_size as u64);
        }
        progress.log_final();

        if let Some(ref path) = self.histogram_kept {
            write_histogram(&hist_kept, path)?;
            info!("Wrote kept histogram to: {}", path.display());
        }
        if let Some(ref path) = self.histogram_rejected {
            write_histogram(&hist_rejected, path)?;
            info!("Wrote rejected histogram to: {}", path.display());
        }

        writer.finish()?;
        if let Some(rw) = rejects_writer {
            rw.finish()?;
        }

        info!("=== Summary ===");
        info!("Total reads processed: {}", kept_reads + rejected_reads);
        info!("Input families: {total_families}");
        if total_families > 0 {
            let kept_pct = 100.0 * kept_families as f64 / total_families as f64;
            info!("Kept families: {kept_families} ({kept_pct:.2}%)");
        } else {
            // Avoids a 0/0 percentage for empty input.
            info!("Kept families: 0");
        }
        info!("Kept reads: {kept_reads}");
        info!("Rejected reads: {rejected_reads}");
        info!("Output BAM: {}", self.io.output.display());
        if let Some(ref rejects) = self.rejects {
            info!("Rejects BAM: {}", rejects.display());
        }
        timer.log_completion(record_count as u64);
        Ok(())
    }
}
/// Writes a family-size histogram to `path` as tab-separated text.
///
/// The first line is the header `family_size\tcount`; each following line holds one
/// family size and the number of families of that size, in ascending size order
/// (the iteration order of the `BTreeMap`).
///
/// # Errors
///
/// Returns an error if the file cannot be created or any write, including the final
/// flush, fails.
fn write_histogram(histogram: &BTreeMap<usize, u64>, path: &PathBuf) -> Result<()> {
    // Buffer the output so each row does not become its own write syscall.
    let mut out = std::io::BufWriter::new(File::create(path)?);
    writeln!(out, "family_size\tcount")?;
    for (size, count) in histogram {
        writeln!(out, "{size}\t{count}")?;
    }
    // Flush explicitly: dropping a BufWriter discards any error from the final write.
    out.flush()?;
    Ok(())
}
fn raw_record_iter<R: std::io::Read>(
mut reader: RawBamReader<R>,
) -> impl Iterator<Item = Result<RawRecord>> {
let mut exhausted = false;
std::iter::from_fn(move || {
if exhausted {
return None;
}
let mut rec = RawRecord::new();
match reader.read_record(&mut rec) {
Ok(0) => {
exhausted = true;
None
}
Ok(_) => Some(Ok(rec)),
Err(e) => {
exhausted = true;
Some(Err(anyhow::Error::from(e)))
}
}
})
}
/// Groups a stream of records into families: runs of consecutive records that share
/// the same MI value.
struct FamilyIterator<I: Iterator<Item = Result<RawRecord>>> {
    /// Source records, peekable so the first record of the next family can be
    /// inspected without being consumed.
    records: std::iter::Peekable<I>,
}
impl<I> FamilyIterator<I>
where
I: Iterator<Item = Result<RawRecord>>,
{
fn new(records: I) -> Self {
Self { records: records.peekable() }
}
fn next_family(&mut self) -> Result<Option<(String, Vec<RawRecord>)>> {
let mi = match self.records.peek() {
Some(Ok(record)) => get_mi_tag(record)?,
Some(Err(_)) => {
return Err(self.records.next().expect("peek() returned Some").unwrap_err());
}
None => return Ok(None),
};
let mut family = Vec::new();
while let Some(peek_result) = self.records.peek() {
match peek_result {
Ok(record) => {
let record_mi = get_mi_tag(record)?;
if record_mi != mi {
break;
}
family.push(self.records.next().expect("peek() returned Some")?);
}
Err(_) => {
return Err(self.records.next().expect("peek() returned Some").unwrap_err());
}
}
}
Ok(Some((mi, family)))
}
}
/// Returns the record's MI tag as a `String`.
///
/// A string-typed MI tag is tried first and must be valid UTF-8; otherwise an
/// integer-typed MI tag is rendered in decimal.
///
/// # Errors
///
/// Fails when a string MI tag is not valid UTF-8, or when the record has no MI tag at
/// all. In the latter case the message names the read, or `<unknown>` if the read name
/// is empty.
fn get_mi_tag(record: &RawRecord) -> Result<String> {
    let aux = aux_data_slice(record.as_ref());

    if let Some(raw) = find_string_tag(aux, b"MI") {
        return match std::str::from_utf8(raw) {
            Ok(text) => Ok(text.to_string()),
            Err(e) => Err(anyhow::anyhow!("MI tag is not valid UTF-8: {e}")),
        };
    }

    if let Some(value) = find_int_tag(aux, b"MI") {
        return Ok(value.to_string());
    }

    // No MI tag: build an error that identifies the offending read.
    let name = String::from_utf8_lossy(read_name(record.as_ref()));
    let display_name: &str = if name.is_empty() { "<unknown>" } else { &*name };
    bail!("Read '{display_name}' is missing required MI tag")
}
// Unit tests for MI-tag extraction, family grouping, histogram output and seeded sampling.
#[cfg(test)]
mod tests {
use super::*;
use fgumi_raw_bam::SamBuilder as RawSamBuilder;
// Builds a record with the given read name and a string-typed MI tag.
fn create_test_record(name: &str, mi: &str) -> RawRecord {
let mut b = RawSamBuilder::new();
b.read_name(name.as_bytes());
b.add_string_tag(b"MI", mi.as_bytes());
b.build()
}
// Builds a record with the given read name and an integer-typed MI tag.
fn create_test_record_int_mi(name: &str, mi: i32) -> RawRecord {
let mut b = RawSamBuilder::new();
b.read_name(name.as_bytes());
b.add_int_tag(b"MI", mi);
b.build()
}
// Builds a record with the given read name and no MI tag at all.
fn create_test_record_no_mi(name: &str) -> RawRecord {
let mut b = RawSamBuilder::new();
b.read_name(name.as_bytes());
b.build()
}
// Placeholder I/O options; the paths are never opened by the tests that use them.
fn test_bam_io_options() -> BamIoOptions {
BamIoOptions {
input: PathBuf::from("input.bam"),
output: PathBuf::from("output.bam"),
async_reader: false,
}
}
// A string MI tag is returned verbatim.
#[test]
fn test_get_mi_tag_string() {
let record = create_test_record("read1", "12345");
let mi = get_mi_tag(&record).expect("get_mi_tag should succeed for string MI");
assert_eq!(mi, "12345");
}
// An integer MI tag is returned as its decimal string.
#[test]
fn test_get_mi_tag_integer() {
let record = create_test_record_int_mi("read1", 42);
let mi = get_mi_tag(&record).expect("get_mi_tag should succeed for integer MI");
assert_eq!(mi, "42");
}
// A record without an MI tag yields the "missing required MI tag" error.
#[test]
fn test_get_mi_tag_missing() {
let record = create_test_record_no_mi("read1");
let result = get_mi_tag(&record);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("missing required MI tag"));
}
// Three records with one MI value form a single family, after which the iterator is empty.
#[test]
fn test_family_iterator_single_family() {
let records = vec![
Ok(create_test_record("r1", "100")),
Ok(create_test_record("r2", "100")),
Ok(create_test_record("r3", "100")),
];
let mut iter = FamilyIterator::new(records.into_iter());
let family1 =
iter.next_family().expect("next_family should succeed").expect("expected a family");
assert_eq!(family1.0, "100");
assert_eq!(family1.1.len(), 3);
let family2 = iter.next_family().expect("next_family should succeed");
assert!(family2.is_none());
}
// Consecutive runs of MI 100, 200 and 300 come back as three families of sizes 2, 3 and 1.
#[test]
fn test_family_iterator_multiple_families() {
let records = vec![
Ok(create_test_record("r1", "100")),
Ok(create_test_record("r2", "100")),
Ok(create_test_record("r3", "200")),
Ok(create_test_record("r4", "200")),
Ok(create_test_record("r5", "200")),
Ok(create_test_record("r6", "300")),
];
let mut iter = FamilyIterator::new(records.into_iter());
let family1 =
iter.next_family().expect("next_family should succeed").expect("expected family 1");
assert_eq!(family1.0, "100");
assert_eq!(family1.1.len(), 2);
let family2 =
iter.next_family().expect("next_family should succeed").expect("expected family 2");
assert_eq!(family2.0, "200");
assert_eq!(family2.1.len(), 3);
let family3 =
iter.next_family().expect("next_family should succeed").expect("expected family 3");
assert_eq!(family3.0, "300");
assert_eq!(family3.1.len(), 1);
let family4 = iter.next_family().expect("next_family should succeed");
assert!(family4.is_none());
}
// An empty record stream produces no family.
#[test]
fn test_family_iterator_empty() {
let records: Vec<Result<RawRecord>> = vec![];
let mut iter = FamilyIterator::new(records.into_iter());
let family = iter.next_family().expect("next_family should succeed");
assert!(family.is_none());
}
// NOTE(review): this asserts only on the literal assigned just above; it never calls
// `execute`, so the range check on `fraction` is not actually exercised.
#[test]
fn test_validate_fraction_too_low() {
let cmd = Downsample {
io: test_bam_io_options(),
fraction: 0.0,
rejects: None,
seed: None,
validate_mi_order: false,
histogram_kept: None,
histogram_rejected: None,
compression: CompressionOptions { compression_level: 1 },
};
assert!(cmd.fraction <= 0.0);
}
// NOTE(review): same limitation as above; the assertion restates the constructed value.
#[test]
fn test_validate_fraction_too_high() {
let cmd = Downsample {
io: test_bam_io_options(),
fraction: 1.5,
rejects: None,
seed: None,
validate_mi_order: false,
histogram_kept: None,
histogram_rejected: None,
compression: CompressionOptions { compression_level: 1 },
};
assert!(cmd.fraction > 1.0);
}
// NOTE(review): same limitation as above; the assertion restates the constructed value.
#[test]
fn test_validate_fraction_valid() {
let cmd = Downsample {
io: test_bam_io_options(),
fraction: 0.5,
rejects: None,
seed: None,
validate_mi_order: false,
histogram_kept: None,
histogram_rejected: None,
compression: CompressionOptions { compression_level: 1 },
};
assert!(cmd.fraction > 0.0 && cmd.fraction <= 1.0);
}
// The histogram file contains the header line and one `size<TAB>count` row per entry.
#[test]
fn test_write_histogram() {
use tempfile::NamedTempFile;
let mut hist = BTreeMap::new();
hist.insert(1, 10);
hist.insert(2, 20);
hist.insert(5, 5);
let temp_file = NamedTempFile::new().expect("failed to create temp file");
write_histogram(&hist, &temp_file.path().to_path_buf())
.expect("write_histogram should succeed");
let contents =
std::fs::read_to_string(temp_file.path()).expect("failed to read histogram file");
assert!(contents.contains("family_size\tcount"));
assert!(contents.contains("1\t10"));
assert!(contents.contains("2\t20"));
assert!(contents.contains("5\t5"));
}
// Every option can be populated on the struct.
// `float_cmp` is allowed because `fraction` is compared with the exact literal it was
// constructed from, so no rounding is involved.
// NOTE(review): the assertions only read back the values assigned in the constructor.
#[test]
#[allow(clippy::float_cmp)] fn test_downsample_parameters() {
let cmd = Downsample {
io: test_bam_io_options(),
fraction: 0.1,
rejects: Some(PathBuf::from("rejects.bam")),
seed: Some(42),
validate_mi_order: true,
histogram_kept: Some(PathBuf::from("kept.txt")),
histogram_rejected: Some(PathBuf::from("rejected.txt")),
compression: CompressionOptions { compression_level: 1 },
};
assert_eq!(cmd.fraction, 0.1);
assert_eq!(cmd.seed, Some(42));
assert!(cmd.validate_mi_order);
assert!(cmd.rejects.is_some());
assert!(cmd.histogram_kept.is_some());
assert!(cmd.histogram_rejected.is_some());
}
// Two generators built from the same seed make the same 100 keep/reject decisions.
#[test]
fn test_deterministic_sampling_with_seed() {
use rand::RngExt;
let seed = 12345u64;
let mut rng1 = StdRng::seed_from_u64(seed);
let results1: Vec<bool> = (0..100).map(|_| rng1.random::<f64>() < 0.5).collect();
let mut rng2 = StdRng::seed_from_u64(seed);
let results2: Vec<bool> = (0..100).map(|_| rng2.random::<f64>() < 0.5).collect();
assert_eq!(results1, results2);
}
// A `BTreeMap` iterates its keys in ascending order, which is what gives the
// histogram output its ordering by family size.
#[test]
fn test_histogram_sorted_by_family_size() {
let mut hist = BTreeMap::new();
hist.insert(5, 10);
hist.insert(1, 20);
hist.insert(3, 15);
let sizes: Vec<usize> = hist.keys().copied().collect();
assert_eq!(sizes, vec![1, 3, 5]);
}
}