use crate::bam_io::{
BamWriter, RawBamWriter, create_bam_reader_for_pipeline_with_opts, create_bam_writer,
create_optional_bam_writer, create_raw_bam_reader_with_opts, create_raw_bam_writer,
};
use crate::bitenc::BitEnc;
use crate::dna::reverse_complement_str;
use crate::grouper::TemplateGrouper;
use crate::logging::OperationTimer;
use crate::metrics::correct::UmiCorrectionMetrics;
use crate::per_thread_accumulator::PerThreadAccumulator;
use crate::progress::ProgressTracker;
use crate::sam::{SamTag, header_as_unsorted};
use crate::sort::bam_fields;
use crate::template::TemplateBatch;
use crate::unified_pipeline::{Grouper, MemoryEstimate, run_bam_pipeline_from_reader};
use crate::validation::validate_file_exists;
use ahash::AHashMap;
use anyhow::{Context, Result, bail};
use clap::Parser;
use fgumi_raw_bam::RawRecord;
use log::{info, warn};
use lru::LruCache;
use noodles::sam::Header;
use noodles::sam::alignment::record::data::field::Tag;
use parking_lot::Mutex;
use std::io;
use std::num::NonZero;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::commands::command::Command;
use crate::commands::common::{
BamIoOptions, CompressionOptions, QueueMemoryOptions, RejectsOptions, SchedulerOptions,
ThreadingOptions, build_pipeline_config, parse_bool,
};
/// Outcome of comparing one observed UMI segment against the set of fixed UMIs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UmiMatch {
/// `true` when the best hit is within the mismatch budget and sufficiently
/// separated from the next-best hit.
pub matched: bool,
/// The closest fixed UMI; empty when the fixed set contained no entries.
pub umi: String,
/// Mismatch count to the closest fixed UMI (`usize::MAX` when the fixed set was empty).
pub mismatches: usize,
}
// Command-line definition of the `correct` subcommand.
//
// NOTE: plain `//` comments are used on purpose below; clap's derive turns `///`
// doc comments into user-visible help text, which would change runtime output.
#[derive(Parser, Debug, Clone)]
#[command(
name = "correct",
author,
version,
about = "\x1b[38;5;30m[UMI EXTRACTION]\x1b[0m \x1b[36mCorrect UMIs in a BAM file to a fixed set of UMIs\x1b[0m",
long_about = r#"
Corrects UMIs stored in BAM files when a set of fixed UMIs is in use.
If the set of UMIs used in an experiment is known and is a _subset_ of the possible randomers
of the same length, it is possible to error-correct UMIs prior to grouping reads by UMI. This
tool takes an input BAM with UMIs in the `RX` tag and set of known UMIs (either on
the command line or in a file) and produces:
1. A new BAM with corrected UMIs written to the `RX` tag
2. Optionally a set of metrics about the representation of each UMI in the set
3. Optionally a second BAM file of reads whose UMIs could not be corrected within the specific parameters
All of the fixed UMIs must be of the same length, and all UMIs in the BAM file must also have
the same length. Multiple UMIs that are concatenated with hyphens (e.g. `AACCAGT-AGGTAGA`) are
split apart, corrected individually and then re-assembled. A read is accepted only if all the
UMIs can be corrected.
## Correction Parameters
Correction is controlled by two parameters that are applied per-UMI:
1. **--max-mismatches** controls how many mismatches (no-calls are counted as mismatches) are
tolerated between a UMI as read and a fixed UMI
2. **--min-distance** controls how many *more* mismatches the next best hit must have
For example, with two fixed UMIs `AAAAA` and `CCCCC` and `--max-mismatches=3` and `--min-distance=2`:
- AAAAA would match to AAAAA
- AAGTG would match to AAAAA with three mismatches because CCCCC has six mismatches and 6 >= 3 + 2
- AACCA would be rejected because it is 2 mismatches to AAAAA and 3 to CCCCC and 3 <= 2 + 2
## Specifying UMIs
The set of fixed UMIs may be specified on the command line using `--umis umi1 umi2 ...` or via
one or more files of UMIs with a single sequence per line using `--umi-files umis.txt more_umis.txt`.
If there are multiple UMIs per template, leading to hyphenated UMI tags, the values for the fixed
UMIs should be single, non-hyphenated UMIs (e.g. if a record has `RX:Z:ACGT-GGCA`, you would use
`--umis ACGT GGCA`).
## Original UMI Storage
Records which have their UMIs corrected (i.e. the UMI is not identical to one of the expected
UMIs but is close enough to be corrected) will by default have their original UMI stored in the
`OX` tag. This can be disabled with the `--dont-store-original-umis` option.
"#
)]
pub struct CorrectUmis {
// Input/output BAM locations and reader options shared with other commands.
#[command(flatten)]
pub io: BamIoOptions,
// Optional path for a BAM of records whose UMIs could not be corrected.
#[command(flatten)]
pub rejects_opts: RejectsOptions,
// Where to write per-UMI metrics; nothing is written when absent.
#[arg(short = 'M', long)]
pub metrics: Option<PathBuf>,
// Maximum mismatches tolerated between an observed UMI and a fixed UMI.
#[arg(long, default_value = "2")]
pub max_mismatches: usize,
// How many more mismatches the next-best fixed UMI must have than the best one.
// No default is declared here, so clap treats the option as required.
#[arg(short = 'd', long = "min-distance")]
pub min_distance_diff: usize,
// Fixed UMIs given directly on the command line (upper-cased on load).
#[arg(short = 'u', long)]
pub umis: Vec<String>,
// Files listing fixed UMIs, one sequence per line; blank lines are ignored.
#[arg(short = 'U', long)]
pub umi_files: Vec<PathBuf>,
// When true, the original UMI is NOT copied to the `OX` tag on corrected records.
#[arg(long, default_value = "false", num_args = 0..=1, default_missing_value = "true", action = clap::ArgAction::Set, value_parser = parse_bool)]
pub dont_store_original_umis: bool,
// Capacity of the LRU cache of observed-UMI lookups; 0 disables caching.
#[arg(long, default_value = "100000")]
pub cache_size: usize,
// Minimum kept/total record ratio (0..=1); the command fails when the ratio is lower.
#[arg(long)]
pub min_corrected: Option<f64>,
// Reverse-complement each UMI segment (and reverse segment order) before matching.
#[arg(long, default_value = "false", num_args = 0..=1, default_missing_value = "true", action = clap::ArgAction::Set, value_parser = parse_bool)]
pub revcomp: bool,
// Thread count; the multi-threaded pipeline is used only when a count is given.
#[command(flatten)]
pub threading: ThreadingOptions,
// Output compression settings (level is passed to every BAM writer).
#[command(flatten)]
pub compression: CompressionOptions,
// Scheduler settings forwarded to the pipeline configuration.
#[command(flatten)]
pub scheduler_opts: SchedulerOptions,
// Queue memory settings forwarded to the pipeline configuration.
#[command(flatten)]
pub queue_memory: QueueMemoryOptions,
}
/// Why a template's UMI could not be corrected.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
enum RejectionReason {
/// At least one UMI segment did not have the fixed UMI length.
WrongLength,
/// At least one UMI segment had no acceptable match in the fixed set.
Mismatched,
/// The template was not rejected.
#[default]
None,
}
/// Result of attempting to correct the (possibly hyphenated) UMI of one template.
#[derive(Debug)]
struct TemplateCorrection {
/// `true` only when every UMI segment matched a fixed UMI.
matched: bool,
/// The matched fixed UMIs joined with `-`; `None` when the template was rejected.
corrected_umi: Option<String>,
/// The UMI string exactly as it was read from the record.
original_umi: String,
/// Whether the UMI tag must be rewritten (any mismatch, or reverse-complementing was requested).
needs_correction: bool,
/// Whether any segment matched with one or more mismatches.
has_mismatches: bool,
/// Per-segment match results, in segment order; empty for wrong-length rejections.
matches: Vec<UmiMatch>,
/// Why the template was rejected, or `RejectionReason::None` when it matched.
rejection_reason: RejectionReason,
}
/// Output of processing one batch of templates in the multi-threaded pipeline.
struct CorrectProcessedBatch {
/// Records whose UMIs matched, with corrections already applied.
kept_raw_records: Vec<RawRecord>,
/// Number of templates in the batch.
templates_count: u64,
/// Number of records rejected because no UMI could be extracted.
missing_umis: u64,
/// Number of records rejected because a UMI segment had the wrong length.
wrong_length: u64,
/// Number of records rejected because a UMI segment had no acceptable match.
mismatched: u64,
/// Per-UMI match counters accumulated for this batch.
umi_matches: AHashMap<String, UmiCorrectionMetrics>,
}
impl MemoryEstimate for CorrectProcessedBatch {
    /// Approximates the heap held by the kept records: the vector's slot storage
    /// plus the capacity of every record. The per-UMI metrics map is not counted.
    fn estimate_heap_size(&self) -> usize {
        let slot_bytes = self.kept_raw_records.capacity() * std::mem::size_of::<RawRecord>();
        let mut payload_bytes: usize = 0;
        for rec in &self.kept_raw_records {
            payload_bytes += rec.capacity();
        }
        payload_bytes + slot_bytes
    }
}
/// Running totals accumulated per thread slot while the pipeline serializes batches.
#[derive(Default)]
struct CollectedCorrectMetrics {
/// Total number of templates seen.
templates_processed: u64,
/// Records rejected because no UMI could be extracted.
missing_umis: u64,
/// Records rejected because a UMI segment had the wrong length.
wrong_length: u64,
/// Records rejected because a UMI segment had no acceptable match.
mismatched: u64,
/// Per-UMI match counters merged across batches.
umi_matches: AHashMap<String, UmiCorrectionMetrics>,
}
impl Command for CorrectUmis {
    /// Runs UMI correction end to end: validates arguments, loads the fixed UMIs,
    /// warns about ambiguous UMI pairs, then dispatches to the multi-threaded
    /// pipeline (when a thread count was given) or the single-threaded loop.
    ///
    /// # Errors
    /// Propagates any validation, I/O or processing error.
    fn execute(&self, command_line: &str) -> Result<()> {
        self.validate()?;
        let timer = OperationTimer::new("Correcting UMIs");

        let (umi_sequences, umi_length) = self.load_umi_sequences()?;
        let encoded_umi_set = EncodedUmiSet::new(&umi_sequences);
        self.check_umi_distances(&umi_sequences);

        let reader_opts = self.io.pipeline_reader_opts();
        let (reader, header) = create_bam_reader_for_pipeline_with_opts(&self.io.input, reader_opts)?;
        let header = crate::commands::common::add_pg_record(header, command_line)?;

        let track_rejects = self.rejects_opts.rejects.is_some();
        let total_records = match self.threading.threads {
            Some(threads) => self.execute_threads_mode(
                threads,
                reader,
                header,
                Arc::new(encoded_umi_set),
                umi_length,
                track_rejects,
            )?,
            None => {
                // The single-threaded path opens its own raw reader.
                drop(reader);
                self.execute_single_thread_mode(
                    header,
                    encoded_umi_set,
                    umi_length,
                    track_rejects,
                )?
            }
        };

        timer.log_completion(total_records);
        Ok(())
    }
}
impl CorrectUmis {
/// Checks the command-line arguments before any work is done.
///
/// # Errors
/// Fails when no UMI source was given, when the input BAM fails the existence
/// check, or when `--min-corrected` lies outside `0..=1`.
fn validate(&self) -> Result<()> {
    let no_umi_source = self.umis.is_empty() && self.umi_files.is_empty();
    if no_umi_source {
        bail!("At least one UMI or UMI file must be provided.");
    }

    validate_file_exists(&self.io.input, "input BAM file")?;

    match self.min_corrected {
        Some(min) if !(0.0..=1.0).contains(&min) => {
            bail!("--min-corrected must be between 0 and 1.")
        }
        _ => Ok(()),
    }
}
fn load_umi_sequences(&self) -> Result<(Vec<String>, usize)> {
let mut umi_set: std::collections::HashSet<String> =
self.umis.iter().map(|s| s.to_uppercase()).collect();
for file in &self.umi_files {
let content = std::fs::read_to_string(file)?;
for line in content.lines() {
let umi = line.trim().to_uppercase();
if !umi.is_empty() {
umi_set.insert(umi);
}
}
}
if umi_set.is_empty() {
bail!("No UMIs provided.");
}
let mut umi_sequences: Vec<String> = umi_set.into_iter().collect();
umi_sequences.sort_unstable();
let first_len = umi_sequences[0].len();
if !umi_sequences.iter().all(|u| u.len() == first_len) {
bail!("All UMIs must have the same length.");
}
info!("Loaded {} UMI sequences of length {}", umi_sequences.len(), first_len);
Ok((umi_sequences, first_len))
}
/// Warns about pairs of fixed UMIs that are closer than `--min-distance`.
///
/// Two fixed UMIs within `min_distance_diff - 1` mismatches of each other can
/// never both satisfy the distance requirement, so reads near them may be rejected.
fn check_umi_distances(&self, umi_sequences: &[String]) {
    // With a minimum distance of zero every best hit passes the distance check, so
    // no pair can be ambiguous; `0 - 1` would also underflow `usize`.
    let Some(threshold) = self.min_distance_diff.checked_sub(1) else {
        return;
    };

    let pairs = find_umi_pairs_within_distance(umi_sequences, threshold);
    if pairs.is_empty() {
        return;
    }

    warn!("###################################################################");
    warn!("# WARNING: Found pairs of UMIs within min-distance-diff threshold!");
    warn!("# These pairs may be ambiguous and fail to match:");
    for (umi1, umi2, dist) in &pairs {
        warn!("# {umi1} <-> {umi2} (distance {dist})");
    }
    warn!("###################################################################");
}
#[allow(clippy::too_many_arguments)]
fn compute_template_correction(
umi: &str,
umi_length: usize,
revcomp: bool,
max_mismatches: usize,
min_distance_diff: usize,
encoded_umi_set: &EncodedUmiSet,
cache: &mut Option<LruCache<Vec<u8>, UmiMatch>>,
) -> TemplateCorrection {
let original_umi = umi.to_string();
let sequences: Vec<String> = if revcomp {
umi.split('-').map(reverse_complement_str).rev().collect()
} else {
umi.split('-').map(std::string::ToString::to_string).collect()
};
if sequences.iter().any(|s| s.len() != umi_length) {
return TemplateCorrection {
matched: false,
corrected_umi: None,
original_umi,
needs_correction: false,
has_mismatches: false,
matches: Vec::new(),
rejection_reason: RejectionReason::WrongLength,
};
}
let mut matches = Vec::with_capacity(sequences.len());
for seq in &sequences {
let seq_bytes: Vec<u8> = seq.bytes().map(|b| b.to_ascii_uppercase()).collect();
let umi_match = if let Some(c) = cache {
if let Some(cached) = c.get(&seq_bytes[..]) {
cached.clone()
} else {
let result = find_best_match_encoded(
&seq_bytes,
encoded_umi_set,
max_mismatches,
min_distance_diff,
);
c.put(seq_bytes, result.clone());
result
}
} else {
find_best_match_encoded(
&seq_bytes,
encoded_umi_set,
max_mismatches,
min_distance_diff,
)
};
matches.push(umi_match);
}
let all_matched = matches.iter().all(|m| m.matched);
let has_mismatches = matches.iter().any(|m| m.mismatches > 0);
let needs_correction = has_mismatches || revcomp;
if all_matched {
let corrected_umi: String =
matches.iter().map(|m| m.umi.clone()).collect::<Vec<_>>().join("-");
TemplateCorrection {
matched: true,
corrected_umi: Some(corrected_umi),
original_umi,
needs_correction,
has_mismatches,
matches,
rejection_reason: RejectionReason::None,
}
} else {
TemplateCorrection {
matched: false,
corrected_umi: None,
original_umi,
needs_correction: false,
has_mismatches: false,
matches,
rejection_reason: RejectionReason::Mismatched,
}
}
}
/// Reads the UMI tag shared by all records of a template.
///
/// Returns `Ok(None)` when there are no records, when any record is shorter than
/// 32 bytes, or when no record carries the tag as a string.
///
/// # Errors
/// Fails when the first record has the tag with a non-string type, when records
/// disagree on the UMI value, or when only some records carry the tag.
fn extract_and_validate_template_umi_raw(
    raw_records: &[RawRecord],
    umi_tag: [u8; 2],
) -> anyhow::Result<Option<String>> {
    use crate::sort::bam_fields;

    let Some((head, rest)) = raw_records.split_first() else {
        return Ok(None);
    };
    if !raw_records.iter().all(|r| r.len() >= 32) {
        return Ok(None);
    }

    let head_aux = bam_fields::aux_data_slice(head);
    let expected = bam_fields::find_string_tag(head_aux, &umi_tag);
    if expected.is_none() {
        // The tag may still be present with a different value type; that is an error.
        if let Some(tag_type) = bam_fields::find_tag_type(head_aux, &umi_tag) {
            anyhow::bail!(
                "UMI tag {:?} exists but has non-string type '{}', expected 'Z'",
                std::str::from_utf8(&umi_tag).unwrap_or("??"),
                tag_type as char,
            );
        }
    }

    for raw in rest {
        let observed = bam_fields::find_string_tag(bam_fields::aux_data_slice(raw), &umi_tag);
        match (expected, observed) {
            (Some(first), Some(current)) => {
                if first != current {
                    anyhow::bail!(
                        "Template has mismatched UMIs: first={:?}, current={:?}",
                        String::from_utf8_lossy(first),
                        String::from_utf8_lossy(current)
                    );
                }
            }
            (None, None) => {}
            _ => {
                anyhow::bail!("Template has inconsistent UMI presence across records");
            }
        }
    }

    Ok(expected.map(|bytes| String::from_utf8_lossy(bytes).into_owned()))
}
fn apply_correction_to_raw(
record: &mut RawRecord,
correction: &TemplateCorrection,
umi_tag: [u8; 2],
dont_store_original_umis: bool,
) {
use crate::sort::bam_fields;
if correction.needs_correction {
if let Some(ref corrected) = correction.corrected_umi {
bam_fields::update_string_tag(record.as_mut_vec(), &umi_tag, corrected.as_bytes());
}
if !dont_store_original_umis && correction.has_mismatches {
bam_fields::update_string_tag(
record.as_mut_vec(),
&SamTag::OX,
correction.original_umi.as_bytes(),
);
}
}
}
/// Fills in the derived metric columns and optionally writes the metrics file.
///
/// * `fraction_of_matches` is each UMI's share of all counted records.
/// * `representation` is each UMI's count relative to the mean count over the
///   fixed UMIs (the `unmatched_umi` placeholder is excluded from that mean).
///
/// When nothing was counted (empty input, or every record rejected) the derived
/// values are reported as `0.0` instead of the NaN a division by zero would give.
///
/// # Errors
/// Fails when the metrics file cannot be written.
#[allow(clippy::cast_precision_loss)]
fn finalize_metrics(
    &self,
    umi_metrics: &mut AHashMap<String, UmiCorrectionMetrics>,
    unmatched_umi: &str,
) -> Result<()> {
    let mut total: u64 = 0;
    let mut matched_total: u64 = 0;
    let mut umi_count: usize = 0;
    for (umi, metric) in umi_metrics.iter() {
        total += metric.total_matches;
        if umi.as_str() != unmatched_umi {
            matched_total += metric.total_matches;
            umi_count += 1;
        }
    }

    let mean = if umi_count > 0 { matched_total as f64 / umi_count as f64 } else { 0.0 };

    for metric in umi_metrics.values_mut() {
        let count = metric.total_matches as f64;
        // Guard both divisions so degenerate inputs never put NaN into the metrics.
        metric.fraction_of_matches = if total > 0 { count / total as f64 } else { 0.0 };
        metric.representation = if mean > 0.0 { count / mean } else { 0.0 };
    }

    if let Some(path) = &self.metrics {
        let mut metrics: Vec<UmiCorrectionMetrics> = umi_metrics.values().cloned().collect();
        metrics.sort_by(|a, b| a.umi.cmp(&b.umi));
        UmiCorrectionMetrics::write_metrics(&metrics, path)?;
    }
    Ok(())
}
/// Runs UMI correction through the multi-threaded BAM pipeline.
///
/// Templates are grouped in batches, corrected in `process_fn`, and the kept
/// records are serialized in `serialize_fn`. Rejected records are streamed to
/// the rejects writer (when one is configured) from inside `process_fn`.
///
/// Returns the total number of records accounted for (written plus rejected).
///
/// # Errors
/// Fails on pipeline configuration or I/O errors, when the rejects file cannot
/// be finished, when metrics cannot be written, or when the kept ratio is below
/// `--min-corrected`.
#[allow(clippy::too_many_lines)]
fn execute_threads_mode(
&self,
num_threads: usize,
reader: Box<dyn std::io::Read + Send>,
header: Header,
encoded_umi_set: Arc<EncodedUmiSet>,
umi_length: usize,
track_rejects: bool,
) -> Result<u64> {
let mut pipeline_config = build_pipeline_config(
&self.scheduler_opts,
&self.compression,
&self.queue_memory,
num_threads,
)?;
{
// Configure the group key from the library index built from the header.
use crate::read_info::LibraryIndex;
use crate::unified_pipeline::GroupKeyConfig;
let library_index = LibraryIndex::from_header(&header);
pipeline_config.group_key_config = Some(GroupKeyConfig::new_raw_no_cell(library_index));
}
// Metrics are accumulated per thread slot and merged after the pipeline ends.
let collected_metrics = PerThreadAccumulator::<CollectedCorrectMetrics>::new(num_threads);
let collected_for_serialize = Arc::clone(&collected_metrics);
// The rejects writer is shared behind a mutex; the inner Option lets it be taken
// out and finished once the pipeline has returned.
let rejects_writer: Option<Arc<Mutex<Option<RawBamWriter>>>> = if track_rejects {
if let Some(path) = self.rejects_opts.rejects.as_ref() {
let writer_threads = self.threading.num_threads();
let rejects_header = header_as_unsorted(&header);
let w = create_raw_bam_writer(
path,
&rejects_header,
writer_threads,
self.compression.compression_level,
)?;
Some(Arc::new(Mutex::new(Some(w))))
} else {
None
}
} else {
None
};
let rejects_writer_for_process = rejects_writer.as_ref().map(Arc::clone);
// Copy the settings into locals so the `move` closures below do not capture `self`.
const BATCH_SIZE: usize = 1000; let max_mismatches = self.max_mismatches;
let min_distance_diff = self.min_distance_diff;
let umi_tag = Tag::from(SamTag::RX);
let revcomp = self.revcomp;
let cache_size = self.cache_size;
let dont_store_original_umis = self.dont_store_original_umis;
let progress_counter = Arc::new(AtomicU64::new(0));
let progress_for_process = Arc::clone(&progress_counter);
// All-N placeholder key under which every rejected record is counted.
let unmatched_umi = "N".repeat(umi_length);
let unmatched_umi_for_process = unmatched_umi.clone();
let encoded_umi_set_for_metrics = Arc::clone(&encoded_umi_set);
let grouper_fn = move |_header: &Header| {
Box::new(TemplateGrouper::new(BATCH_SIZE))
as Box<dyn Grouper<Group = TemplateBatch> + Send>
};
let process_fn = move |batch: TemplateBatch| -> io::Result<CorrectProcessedBatch> {
// One LRU cache per worker thread, created lazily on first use.
thread_local! {
static CACHE: std::cell::RefCell<Option<LruCache<Vec<u8>, UmiMatch>>> = const { std::cell::RefCell::new(None) };
}
CACHE.with(|cache_cell| {
let mut cache_ref = cache_cell.borrow_mut();
if cache_ref.is_none() && cache_size > 0 {
*cache_ref = Some(LruCache::new(
NonZero::new(cache_size).expect("cache_size > 0 checked above"),
));
}
let mut kept_raw_records: Vec<RawRecord> = Vec::new();
let mut missing_umis = 0u64;
let mut wrong_length = 0u64;
let mut mismatched = 0u64;
let mut umi_matches_map: AHashMap<String, UmiCorrectionMetrics> = AHashMap::new();
let templates_count = batch.len() as u64;
let mut total_input_records = 0u64;
let umi_tag_bytes: [u8; 2] = [umi_tag.as_ref()[0], umi_tag.as_ref()[1]];
// Writes rejected records to the shared rejects writer while holding its lock;
// a no-op when no rejects writer exists.
let flush_raw_records = |recs: Vec<RawRecord>| -> io::Result<()> {
if let Some(ref rw_arc) = rejects_writer_for_process {
if !recs.is_empty() {
let mut guard = rw_arc.lock();
if let Some(w) = guard.as_mut() {
for raw in &recs {
w.write_raw_record(raw.as_ref())?;
}
}
}
}
Ok(())
};
for template in batch {
total_input_records += template.read_count() as u64;
{
let raw_records: Vec<RawRecord> = template.into_records();
let umi_opt = Self::extract_and_validate_template_umi_raw(
&raw_records,
umi_tag_bytes,
)
.map_err(io::Error::other)?;
match umi_opt {
None => {
// No usable UMI: count every record of the template as missing/unmatched.
let num_records = raw_records.len() as u64;
missing_umis += num_records;
let entry = umi_matches_map
.entry(unmatched_umi_for_process.clone())
.or_insert_with(|| {
UmiCorrectionMetrics::new(unmatched_umi_for_process.clone())
});
entry.total_matches += num_records;
if track_rejects {
flush_raw_records(raw_records)?;
}
}
Some(umi) => {
let correction = Self::compute_template_correction(
&umi,
umi_length,
revcomp,
max_mismatches,
min_distance_diff,
&encoded_umi_set,
&mut cache_ref,
);
let num_records = raw_records.len() as u64;
if correction.matched {
// Counters are per record, so each segment's UMI is credited once per record.
for m in &correction.matches {
if m.matched {
let entry = umi_matches_map
.entry(m.umi.clone())
.or_insert_with(|| {
UmiCorrectionMetrics::new(m.umi.clone())
});
entry.total_matches += num_records;
match m.mismatches {
0 => entry.perfect_matches += num_records,
1 => entry.one_mismatch_matches += num_records,
2 => entry.two_mismatch_matches += num_records,
_ => entry.other_matches += num_records,
}
}
}
for mut raw in raw_records {
Self::apply_correction_to_raw(
&mut raw,
&correction,
umi_tag_bytes,
dont_store_original_umis,
);
kept_raw_records.push(raw);
}
} else {
match correction.rejection_reason {
RejectionReason::WrongLength => {
wrong_length += num_records;
}
RejectionReason::Mismatched => {
mismatched += num_records;
}
RejectionReason::None => {}
}
let entry = umi_matches_map
.entry(unmatched_umi_for_process.clone())
.or_insert_with(|| {
UmiCorrectionMetrics::new(
unmatched_umi_for_process.clone(),
)
});
entry.total_matches += num_records;
if track_rejects {
flush_raw_records(raw_records)?;
}
}
}
}
}
}
// Log once each time the shared record count crosses a multiple of one million.
let count = progress_for_process.fetch_add(total_input_records, Ordering::Relaxed);
if (count + total_input_records) / 1_000_000 > count / 1_000_000 {
info!("Processed {} records", count + total_input_records);
}
Ok(CorrectProcessedBatch {
kept_raw_records,
templates_count,
missing_umis,
wrong_length,
mismatched,
umi_matches: umi_matches_map,
})
})
};
let serialize_fn = move |mut processed: CorrectProcessedBatch,
_header: &Header,
output: &mut Vec<u8>|
-> io::Result<u64> {
// Fold this batch's counters into the accumulator slot before emitting bytes.
let umi_matches = std::mem::take(&mut processed.umi_matches);
collected_for_serialize.with_slot(|m| {
m.templates_processed += processed.templates_count;
m.missing_umis += processed.missing_umis;
m.wrong_length += processed.wrong_length;
m.mismatched += processed.mismatched;
for (umi, counts) in umi_matches {
merge_umi_counts(&mut m.umi_matches, umi, &counts);
}
});
// Each kept record is emitted as a 4-byte little-endian length followed by its bytes.
let mut kept_count = 0u64;
for raw in &processed.kept_raw_records {
let block_size = raw.len() as u32;
output.extend_from_slice(&block_size.to_le_bytes());
output.extend_from_slice(raw);
kept_count += 1;
}
Ok(kept_count)
};
let pipeline_result = run_bam_pipeline_from_reader(
pipeline_config,
reader,
header,
&self.io.output,
None, grouper_fn,
process_fn,
serialize_fn,
);
// Finish the rejects writer even when the pipeline failed, so both errors can be reported.
let rejects_finish_result = rejects_writer
.and_then(|rw_arc| rw_arc.lock().take())
.map(|writer| writer.finish().context("Failed to finish rejects file"));
let records_written = match (pipeline_result, rejects_finish_result) {
(Ok(records_written), Some(Ok(()))) => {
info!("Rejected records streamed to rejects file during processing");
records_written
}
(Ok(records_written), None) => records_written,
(Ok(_), Some(Err(finish_err))) => return Err(finish_err),
(Err(pipeline_err), Some(Err(finish_err))) => {
return Err(anyhow::anyhow!(
"Pipeline error: {pipeline_err}; additionally failed to finish rejects file: {finish_err}"
));
}
(Err(pipeline_err), _) => return Err(pipeline_err.into()),
};
// Merge the per-slot totals collected during serialization.
let mut total_templates = 0u64;
let mut total_missing = 0u64;
let mut total_wrong_length = 0u64;
let mut total_mismatched = 0u64;
let mut merged_umi_matches: AHashMap<String, UmiCorrectionMetrics> = AHashMap::new();
for slot in collected_metrics.slots() {
let mut m = slot.lock();
total_templates += m.templates_processed;
total_missing += m.missing_umis;
total_wrong_length += m.wrong_length;
total_mismatched += m.mismatched;
for (umi, counts) in m.umi_matches.drain() {
merge_umi_counts(&mut merged_umi_matches, umi, &counts);
}
}
// Ensure every fixed UMI and the unmatched placeholder has an entry, even with zero counts.
for umi in encoded_umi_set_for_metrics.strings.iter().chain(std::iter::once(&unmatched_umi))
{
merged_umi_matches
.entry(umi.clone())
.or_insert_with(|| UmiCorrectionMetrics::new(umi.clone()));
}
self.finalize_metrics(&mut merged_umi_matches, &unmatched_umi)?;
// The input total is reconstructed from what was written plus what was rejected.
let rejected = total_missing + total_wrong_length + total_mismatched;
let total_records = records_written + rejected;
info!("Read {total_records}; kept {records_written} and rejected {rejected}");
info!("Total templates processed: {total_templates}");
if total_missing > 0 || total_wrong_length > 0 {
warn!("###################################################################");
if total_missing > 0 {
warn!("# {total_missing} were missing UMI attributes in the BAM file!");
}
if total_wrong_length > 0 {
warn!(
"# {total_wrong_length} had unexpected UMIs of differing lengths in the BAM file!"
);
}
warn!("###################################################################");
}
if let Some(min) = self.min_corrected {
// With zero records the ratio is NaN, which never compares below `min`.
#[allow(clippy::cast_precision_loss)]
let ratio_kept = records_written as f64 / total_records as f64;
if ratio_kept < min {
bail!(
"Final ratio of reads kept / total was {ratio_kept:.2} (user specified minimum was {min:.2}). \
This could indicate a mismatch between library preparation and the provided UMI file."
);
}
}
Ok(total_records)
}
/// Runs UMI correction in a single streaming loop on the calling thread.
///
/// Records are read one at a time and collected into a template for as long as
/// the read name stays the same; each completed template is then corrected and
/// written to the output, or to the rejects writer when one is configured.
/// NOTE(review): this relies on all records of a template being adjacent in the
/// input — confirm that callers guarantee this ordering.
///
/// Returns the total number of records read.
///
/// # Errors
/// Fails on reader/writer errors, inconsistent UMIs within a template, a metrics
/// write failure, or when the kept ratio is below `--min-corrected`.
#[allow(clippy::too_many_lines)]
fn execute_single_thread_mode(
&self,
header: Header,
encoded_umi_set: EncodedUmiSet,
umi_length: usize,
track_rejects: bool,
) -> Result<u64> {
info!("Using single-threaded mode with template-level UMI correction");
// A fresh raw reader is opened here; the header it returns is ignored in favour
// of the `header` argument.
let (mut bam_reader, _) =
create_raw_bam_reader_with_opts(&self.io.input, 1, self.io.pipeline_reader_opts())?;
let mut writer =
create_bam_writer(&self.io.output, &header, 1, self.compression.compression_level)?;
let mut reject_writer = create_optional_bam_writer(
self.rejects_opts.rejects.as_ref(),
&header,
1,
self.compression.compression_level,
)?;
// Lookup cache for observed UMIs; disabled when the configured size is zero.
let mut cache: Option<LruCache<Vec<u8>, UmiMatch>> = if self.cache_size > 0 {
Some(LruCache::new(
NonZero::new(self.cache_size).expect("cache_size > 0 checked above"),
))
} else {
None
};
// All-N placeholder key under which every rejected record is counted.
let unmatched_umi = "N".repeat(umi_length);
let umi_sequences = encoded_umi_set.strings.clone();
// Pre-populate the metrics map so the lookups below can never miss.
let mut umi_metrics: AHashMap<String, UmiCorrectionMetrics> = umi_sequences
.iter()
.chain(std::iter::once(&unmatched_umi))
.map(|umi| (umi.clone(), UmiCorrectionMetrics::new(umi.clone())))
.collect();
let mut total_records = 0u64;
let mut total_templates = 0u64;
let mut missing_umis = 0u64;
let mut wrong_length = 0u64;
let mut mismatched = 0u64;
let progress = ProgressTracker::new("Processed records").with_interval(1_000_000);
let umi_tag_bytes: [u8; 2] = *b"RX";
let max_mismatches = self.max_mismatches;
let min_distance_diff = self.min_distance_diff;
let revcomp = self.revcomp;
let dont_store_original_umis = self.dont_store_original_umis;
// Writes one record as a 4-byte little-endian length followed by the record bytes.
#[allow(clippy::cast_possible_truncation)]
fn write_raw(writer: &mut BamWriter, raw: &[u8]) -> Result<()> {
use std::io::Write;
let block_size = raw.len() as u32;
writer.get_mut().write_all(&block_size.to_le_bytes())?;
writer.get_mut().write_all(raw)?;
Ok(())
}
let mut record = RawRecord::new();
let mut current_template: Vec<RawRecord> = Vec::new();
let mut current_name: Option<Vec<u8>> = None;
loop {
let bytes_read = bam_reader.read_record(&mut record)?;
// A zero-byte read is treated as end of input.
let eof = bytes_read == 0;
// Flush the buffered template at end of input or when the read name changes.
let flush = if eof {
!current_template.is_empty()
} else {
let name = bam_fields::read_name(record.as_ref());
current_name.as_deref().is_some_and(|cn| cn != name)
};
if flush {
let mut raw_records = std::mem::take(&mut current_template);
#[allow(clippy::cast_possible_truncation)]
let num_records = raw_records.len() as u64;
total_records += num_records;
total_templates += 1;
let umi_opt = CorrectUmis::extract_and_validate_template_umi_raw(
&raw_records,
umi_tag_bytes,
)?;
match umi_opt {
None => {
// No usable UMI: count every record of the template as missing/unmatched.
missing_umis += num_records;
umi_metrics
.get_mut(&unmatched_umi)
.expect("unmatched_umi key initialized in metrics map")
.total_matches += num_records;
if track_rejects {
if let Some(rw) = reject_writer.as_mut() {
for raw in raw_records.drain(..) {
write_raw(rw, &raw)?;
}
}
}
}
Some(umi) => {
let correction = CorrectUmis::compute_template_correction(
&umi,
umi_length,
revcomp,
max_mismatches,
min_distance_diff,
&encoded_umi_set,
&mut cache,
);
if correction.matched {
// Counters are per record, so each segment's UMI is credited once per record.
for m in &correction.matches {
if m.matched {
let entry = umi_metrics.get_mut(&m.umi).expect(
"UMI key initialized in metrics map from allowed UMIs",
);
entry.total_matches += num_records;
match m.mismatches {
0 => entry.perfect_matches += num_records,
1 => entry.one_mismatch_matches += num_records,
2 => entry.two_mismatch_matches += num_records,
_ => entry.other_matches += num_records,
}
}
}
for mut raw in raw_records.drain(..) {
CorrectUmis::apply_correction_to_raw(
&mut raw,
&correction,
umi_tag_bytes,
dont_store_original_umis,
);
write_raw(&mut writer, &raw)?;
}
} else {
match correction.rejection_reason {
RejectionReason::WrongLength => wrong_length += num_records,
RejectionReason::Mismatched => mismatched += num_records,
RejectionReason::None => {}
}
umi_metrics
.get_mut(&unmatched_umi)
.expect("unmatched_umi key initialized in metrics map")
.total_matches += num_records;
if track_rejects {
if let Some(rw) = reject_writer.as_mut() {
for raw in raw_records.drain(..) {
write_raw(rw, &raw)?;
}
}
}
}
}
}
progress.log_if_needed(num_records);
}
if eof {
break;
}
// The record just read starts or extends the current template; `take` moves it
// into the buffer and leaves a default record behind for the next read.
current_name = Some(bam_fields::read_name(record.as_ref()).to_vec());
let taken = std::mem::take(&mut record);
current_template.push(taken);
}
progress.log_final();
writer.into_inner().finish()?;
if let Some(rw) = reject_writer {
rw.into_inner().finish()?;
}
self.finalize_metrics(&mut umi_metrics, &unmatched_umi)?;
let rejected = missing_umis + wrong_length + mismatched;
let kept = total_records - rejected;
info!("Read {total_records}; kept {kept} and rejected {rejected}");
info!("Total templates processed: {total_templates}");
if missing_umis > 0 || wrong_length > 0 {
warn!("###################################################################");
if missing_umis > 0 {
warn!("# {missing_umis} were missing UMI attributes in the BAM file!");
}
if wrong_length > 0 {
warn!("# {wrong_length} had unexpected UMIs of differing lengths in the BAM file!");
}
warn!("###################################################################");
}
if let Some(min) = self.min_corrected {
// With zero records the ratio is NaN, which never compares below `min`.
#[allow(clippy::cast_precision_loss)]
let ratio_kept = kept as f64 / total_records as f64;
if ratio_kept < min {
bail!(
"Final ratio of reads kept / total was {ratio_kept:.2} (user specified minimum was {min:.2}). \
This could indicate a mismatch between library preparation and the provided UMI file."
);
}
}
Ok(total_records)
}
}
/// Adds the counters in `counts` to the entry for `umi` in `dst`, creating a
/// fresh entry for that UMI first when none exists.
fn merge_umi_counts(
    dst: &mut AHashMap<String, UmiCorrectionMetrics>,
    umi: String,
    counts: &UmiCorrectionMetrics,
) {
    let target = dst.entry(umi).or_insert_with_key(|key| UmiCorrectionMetrics::new(key.clone()));

    // Mismatch-bucket counters.
    target.perfect_matches += counts.perfect_matches;
    target.one_mismatch_matches += counts.one_mismatch_matches;
    target.two_mismatch_matches += counts.two_mismatch_matches;
    target.other_matches += counts.other_matches;

    // Overall counter.
    target.total_matches += counts.total_matches;
}
/// Counts positions at which `a` and `b` differ, stopping early once the count
/// exceeds `max_mismatches`.
///
/// When the scan completes, the length difference of the two slices is added to
/// the result. On an early stop the value returned is `max_mismatches + 1` and
/// the length difference is not added.
#[must_use]
pub fn count_mismatches_with_max(a: &[u8], b: &[u8], max_mismatches: usize) -> usize {
    let mut differing = 0usize;
    for (left, right) in a.iter().zip(b) {
        if left == right {
            continue;
        }
        differing += 1;
        if differing > max_mismatches {
            return differing;
        }
    }
    differing + a.len().abs_diff(b.len())
}
/// The fixed UMI set held in three parallel representations, indexed identically.
#[derive(Clone)]
pub struct EncodedUmiSet {
/// ASCII-upper-cased bytes of each UMI, used for byte-wise comparison.
bytes: Vec<Vec<u8>>,
/// Bit-encoded form of each UMI; `None` where `BitEnc::from_bytes` produced no encoding.
encoded: Vec<Option<BitEnc>>,
/// Upper-cased string form of each UMI, used in results and metrics keys.
strings: Vec<String>,
}
impl EncodedUmiSet {
    /// Builds the set from `sequences`, upper-casing each one and attempting a
    /// bit encoding of its bytes.
    pub fn new(sequences: &[String]) -> Self {
        let capacity = sequences.len();
        let mut bytes: Vec<Vec<u8>> = Vec::with_capacity(capacity);
        let mut encoded: Vec<Option<BitEnc>> = Vec::with_capacity(capacity);
        let mut strings: Vec<String> = Vec::with_capacity(capacity);

        for sequence in sequences {
            let upper_bytes: Vec<u8> = sequence.bytes().map(|b| b.to_ascii_uppercase()).collect();
            encoded.push(BitEnc::from_bytes(&upper_bytes));
            bytes.push(upper_bytes);
            strings.push(sequence.to_uppercase());
        }

        Self { bytes, encoded, strings }
    }

    /// Number of UMIs in the set.
    #[inline]
    pub fn len(&self) -> usize {
        self.bytes.len()
    }

    /// Whether the set contains no UMIs.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Finds the fixed UMI closest to `observed` and decides whether it is an
/// acceptable match.
///
/// The bit-encoded Hamming distance is used when both the observed and the fixed
/// UMI have an encoding; otherwise bytes are compared directly. A hit is accepted
/// when its mismatch count is at most `max_mismatches` and the runner-up has at
/// least `min_distance_diff` more mismatches.
fn find_best_match_encoded(
    observed: &[u8],
    umi_set: &EncodedUmiSet,
    max_mismatches: usize,
    min_distance_diff: usize,
) -> UmiMatch {
    let mut winner: Option<usize> = None;
    let mut lowest = usize::MAX;
    let mut runner_up = usize::MAX;

    let observed_encoded = BitEnc::from_bytes(observed);
    let candidates = umi_set.encoded.iter().zip(umi_set.bytes.iter());
    for (idx, (fixed_encoded, fixed_bytes)) in candidates.enumerate() {
        let mismatches = match (&observed_encoded, fixed_encoded) {
            (Some(obs), Some(fixed)) => fixed.hamming_distance(obs) as usize,
            // Anything beyond the current runner-up cannot change the outcome,
            // so the byte-wise count may stop there.
            _ => count_mismatches_with_max(observed, fixed_bytes, runner_up),
        };

        if mismatches < lowest {
            runner_up = lowest;
            lowest = mismatches;
            winner = Some(idx);
        } else if mismatches < runner_up {
            runner_up = mismatches;
        }
    }

    let Some(idx) = winner else {
        return UmiMatch { matched: false, umi: String::new(), mismatches: usize::MAX };
    };

    let within_budget = lowest <= max_mismatches;
    let matched = within_budget && runner_up.saturating_sub(lowest) >= min_distance_diff;

    UmiMatch { matched, umi: umi_set.strings[idx].clone(), mismatches: lowest }
}
/// Returns every unordered pair of UMIs that differ by at most `distance`
/// mismatches, together with that mismatch count.
///
/// Pairs are reported in input order (`umis[i]`, `umis[j]` with `i < j`).
#[must_use]
pub fn find_umi_pairs_within_distance(
    umis: &[String],
    distance: usize,
) -> Vec<(String, String, usize)> {
    let mut pairs = Vec::new();
    for (i, first) in umis.iter().enumerate() {
        for second in &umis[i + 1..] {
            // `distance` itself is a sufficient cut-off: the count is exact whenever it
            // is <= distance, and any early stop yields distance + 1, which is excluded.
            // This also avoids the overflow of `distance + 1` at `usize::MAX`.
            let d = count_mismatches_with_max(first.as_bytes(), second.as_bytes(), distance);
            if d <= distance {
                pairs.push((first.clone(), second.clone(), d));
            }
        }
    }
    pairs
}
#[cfg(test)]
mod tests {
use super::*;
use noodles::sam;
use noodles::sam::alignment::io::Write as SamWrite;
use noodles::sam::alignment::record_buf::RecordBuf;
use rstest::rstest;
use std::{io::Write as IoWrite, path::Path};
use tempfile::{NamedTempFile, TempDir};
/// Paths inside a per-test temporary directory.
struct TestPaths {
// Owns the temporary directory that all paths below point into.
// NOTE(review): presumably the directory is deleted when this is dropped — confirm
// against the tempfile documentation.
dir: TempDir,
/// Path of the main output BAM.
pub output: PathBuf,
/// Path of the rejects BAM.
pub rejects: PathBuf,
/// Path of the metrics file.
pub metrics: PathBuf,
}
impl TestPaths {
    /// Creates a new temporary directory and derives the standard file paths in it.
    fn new() -> Result<Self> {
        let dir = TempDir::new()?;
        let root = dir.path();
        let output = root.join("output.bam");
        let rejects = root.join("rejects.bam");
        let metrics = root.join("metrics.txt");
        Ok(Self { dir, output, rejects, metrics })
    }

    /// Path of a numbered output BAM (`output{n}.bam`) in the same directory.
    fn output_n(&self, n: usize) -> PathBuf {
        let file_name = format!("output{n}.bam");
        self.dir.path().join(file_name)
    }
}
/// Four homopolymer UMIs of length six used as the fixed set in the matching tests.
const FIXED_UMIS: &[&str] = &["AAAAAA", "CCCCCC", "GGGGGG", "TTTTTT"];
/// Builds an `EncodedUmiSet` from `FIXED_UMIS`.
fn get_encoded_umi_set() -> EncodedUmiSet {
    let owned: Vec<String> = FIXED_UMIS.iter().copied().map(String::from).collect();
    EncodedUmiSet::new(&owned)
}
#[test]
fn test_find_best_match_perfect() {
    // An observed UMI identical to a fixed UMI matches it with zero mismatches.
    let umi_set = get_encoded_umi_set();
    for exact in ["AAAAAA", "CCCCCC"] {
        let hit = find_best_match_encoded(exact.as_bytes(), &umi_set, 2, 2);
        assert!(hit.matched);
        assert_eq!(hit.mismatches, 0);
        assert_eq!(hit.umi, exact);
    }
}
#[test]
fn test_find_best_match_with_mismatches() {
    // (observed, expected `matched`, expected mismatch count) with max_mismatches = 2.
    let cases: [(&str, bool, usize); 4] = [
        ("AAAAAA", true, 0),
        ("AAAAAT", true, 1),
        ("AAAACT", true, 2),
        ("AAAGCT", false, 3),
    ];
    let umi_set = get_encoded_umi_set();
    for (observed, expect_match, expect_mismatches) in cases {
        let result = find_best_match_encoded(observed.as_bytes(), &umi_set, 2, 2);
        assert_eq!(result.matched, expect_match);
        if expect_match {
            assert_eq!(result.umi, "AAAAAA");
        }
        assert_eq!(result.mismatches, expect_mismatches);
    }
}
/// When the known UMIs differ from each other by a single base, even an exact observation
/// must not be reported as matched with `min_distance_diff = 2`.
#[test]
fn test_no_match_when_umis_too_similar() {
    let near_identical = ["AAAG".to_string(), "AAAT".to_string()];
    let known = EncodedUmiSet::new(&near_identical);
    let result = find_best_match_encoded(b"AAAG", &known, 2, 2);
    assert!(!result.matched);
}
/// With `max_mismatches = 3`, a UMI three mismatches from "AAAAAA" still matches unless it
/// is too close to a second known UMI, in which case it is reported as unmatched.
#[test]
fn test_match_with_many_mismatches() {
    let known = get_encoded_umi_set();
    // (observed bases, expected `matched`, expected UMI when matched, expected mismatches)
    let cases = [
        (b"AAAAAA", true, Some("AAAAAA"), 0),
        (b"AAACGT", true, Some("AAAAAA"), 3),
        (b"AAACCC", false, None, 3),
        (b"AAACCT", false, None, 3),
    ];
    for (observed, should_match, expected_umi, expected_mismatches) in cases {
        let result = find_best_match_encoded(observed, &known, 3, 2);
        assert_eq!(result.matched, should_match);
        if let Some(umi) = expected_umi {
            assert_eq!(result.umi, umi);
        }
        assert_eq!(result.mismatches, expected_mismatches);
    }
}
/// Four homopolymer UMIs differ at every position, so no pair is within distance 2.
#[test]
fn test_find_umi_pairs_within_distance_none() {
    let well_separated = ["AAAA", "TTTT", "CCCC", "GGGG"].map(String::from);
    let pairs = find_umi_pairs_within_distance(&well_separated, 2);
    assert!(pairs.is_empty());
}
/// Exactly two pairs of the six UMIs lie within distance 2, and each is reported with its
/// mismatch count.
#[test]
fn test_find_umi_pairs_within_distance_some() {
    let umis: Vec<String> = ["ACACAC", "CTCTCT", "GAGAGA", "TGTGTG", "ACAGAC", "AGAGAG"]
        .into_iter()
        .map(String::from)
        .collect();
    let pairs = find_umi_pairs_within_distance(&umis, 2);
    assert_eq!(pairs.len(), 2);
    for (first, second, mismatches) in [("ACACAC", "ACAGAC", 1), ("ACAGAC", "AGAGAG", 2)] {
        let expected = (first.to_string(), second.to_string(), mismatches);
        assert!(pairs.contains(&expected));
    }
}
/// Checks `reverse_complement_str` on homopolymers, a mixed sequence and a palindrome.
#[test]
fn test_reverse_complement() {
    let cases = [
        ("AAAAAA", "TTTTTT"),
        ("AAAAGA", "TCTTTT"),
        ("ACGT", "ACGT"),
        ("GGGGGG", "CCCCCC"),
    ];
    for (sequence, expected) in cases {
        assert_eq!(reverse_complement_str(sequence), expected);
    }
}
/// Checks exact counts below the limit, and that a limit of 2 yields 3 for sequences that
/// differ at all six positions.
#[test]
fn test_count_mismatches_with_max() {
    // (left, right, max, expected result)
    let cases = [
        (b"AAAAAA", b"AAAAAA", 10, 0),
        (b"AAAAAA", b"AAAAAT", 10, 1),
        (b"AAAAAA", b"AAAATT", 10, 2),
        (b"AAAAAA", b"CCCCCC", 10, 6),
        (b"AAAAAA", b"CCCCCC", 2, 3),
    ];
    for (left, right, max, expected) in cases {
        assert_eq!(count_mismatches_with_max(left, right, max), expected);
    }
}
/// Writes a temporary BAM file containing one record per `(name, umi)` entry.
///
/// The header declares a single reference, `chr1`, of length 1000. Every record is placed
/// on reference id 0 at position 0 with mapping quality 60, ten `A` bases and qualities of
/// 40; when a UMI is supplied it is stored in the `RX` tag, otherwise no `RX` tag is added.
///
/// # Errors
/// Returns an error if the temporary file cannot be created or the BAM cannot be written.
///
/// # Panics
/// Panics if a built raw record cannot be converted into a `RecordBuf`.
fn create_test_bam(records: Vec<(&str, Option<&str>)>) -> Result<NamedTempFile> {
    use fgumi_raw_bam::{
        SamBuilder as RawSamBuilder, raw_record_to_record_buf, testutil::encode_op,
    };
    use noodles::sam::header::record::value::map::Map;

    let bam_file = NamedTempFile::new()?;
    let bam_path = bam_file.path().to_path_buf();

    let chr1 = Map::<sam::header::record::value::map::ReferenceSequence>::new(
        std::num::NonZero::new(1000).expect("non-zero reference length"),
    );
    let header = sam::Header::builder().add_reference_sequence("chr1", chr1).build();

    let mut bam_writer = noodles::bam::io::writer::Builder.build_from_path(&bam_path)?;
    bam_writer.write_header(&header)?;

    // Raw records are converted against an empty default header, not the one written above.
    let conversion_header = sam::Header::default();
    for (read_name, rx) in records {
        let mut builder = RawSamBuilder::new();
        builder
            .read_name(read_name.as_bytes())
            .ref_id(0)
            .pos(0)
            .mapq(60)
            .cigar_ops(&[encode_op(0, 10)])
            .sequence(b"AAAAAAAAAA")
            .qualities(&[40u8; 10]);
        if let Some(umi_seq) = rx {
            builder.add_string_tag(b"RX", umi_seq.as_bytes());
        }
        let raw = builder.build();
        let record_buf = raw_record_to_record_buf(&raw, &conversion_header)
            .expect("raw_record_to_record_buf failed");
        bam_writer.write_alignment_record(&header, &record_buf)?;
    }
    bam_writer.try_finish()?;
    Ok(bam_file)
}
/// Reads the BAM at `path` and returns the name of every record in file order.
///
/// A record without a name contributes an empty string.
///
/// # Errors
/// Returns an error if the file cannot be opened, or its header or any record cannot be read.
fn read_bam_record_names(path: &Path) -> Result<Vec<String>> {
    let mut reader = noodles::bam::io::reader::Builder.build_from_path(path)?;
    // The header has to be consumed before records can be read; its content is not needed.
    reader.read_header()?;
    reader
        .records()
        .map(|entry| -> Result<String> {
            let record = entry?;
            Ok(record.name().map(std::string::ToString::to_string).unwrap_or_default())
        })
        .collect()
}
/// A read written without an `RX` tag must end up in the rejects BAM and not in the output.
#[test]
fn test_rejects_reads_without_umis() -> Result<()> {
    let input_bam = create_test_bam(vec![("q1", None)])?;
    let paths = TestPaths::new()?;
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: paths.output.clone(),
        async_reader: false,
    };
    let rejects_opts = RejectsOptions { rejects: Some(paths.rejects.clone()) };
    let umis = Vec::from(["AAAAAA", "CCCCCC"].map(String::from));
    let cmd = CorrectUmis {
        io,
        rejects_opts,
        metrics: None,
        max_mismatches: 2,
        min_distance_diff: 2,
        umis,
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let kept = read_bam_record_names(&paths.output)?;
    assert_eq!(kept.len(), 0);
    let rejected = read_bam_record_names(&paths.rejects)?;
    assert_eq!(rejected.len(), 1);
    Ok(())
}
/// Expects `execute` to return an error when the known UMIs ("AAAAAA" and "CCC") have
/// different lengths.
#[test]
fn test_validation_different_length_umis() {
    let input_file = NamedTempFile::new().expect("failed to create temp file");
    let output_file = NamedTempFile::new().expect("failed to create temp file");
    let io = BamIoOptions {
        input: input_file.path().to_path_buf(),
        output: output_file.path().to_path_buf(),
        async_reader: false,
    };
    let unequal_length_umis = Vec::from(["AAAAAA", "CCC"].map(String::from));
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions::default(),
        metrics: None,
        max_mismatches: 2,
        min_distance_diff: 2,
        umis: unequal_length_umis,
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    let outcome = cmd.execute("test");
    assert!(outcome.is_err());
}
/// A read whose UMI ("ACGT") is shorter than the six-base known UMIs must be written to the
/// rejects BAM and not to the output.
#[test]
fn test_rejects_incorrect_length_umis() -> Result<()> {
    let input_bam = create_test_bam(vec![("q1", Some("ACGT"))])?;
    let work_dir = TempDir::new()?;
    let output_path = work_dir.path().join("output.bam");
    let rejects_path = work_dir.path().join("rejects.bam");
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: output_path.clone(),
        async_reader: false,
    };
    let rejects_opts = RejectsOptions { rejects: Some(rejects_path.clone()) };
    let umis = Vec::from(["AAAAAA", "CCCCCC"].map(String::from));
    let cmd = CorrectUmis {
        io,
        rejects_opts,
        metrics: None,
        max_mismatches: 2,
        min_distance_diff: 2,
        umis,
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let kept = read_bam_record_names(&output_path)?;
    assert_eq!(kept.len(), 0);
    let rejected = read_bam_record_names(&rejects_path)?;
    assert_eq!(rejected.len(), 1);
    Ok(())
}
/// End-to-end run over ten single-UMI reads with `max_mismatches = 3`: eight reads are
/// expected in the output, `q9` and `q10` in the rejects, and the metrics row for "AAAAAA"
/// must break its five matches down by mismatch count.
#[test]
fn test_end_to_end_single_umis() -> Result<()> {
    let reads = vec![
        ("q1", Some("AAAAAA")),
        ("q2", Some("AAAAAA")),
        ("q3", Some("AATAAA")),
        ("q4", Some("AATTAA")),
        ("q5", Some("AAATGC")),
        ("q6", Some("CCCCCC")),
        ("q7", Some("GGGGGG")),
        ("q8", Some("TTTTTT")),
        ("q9", Some("GGGTTT")),
        ("q10", Some("AAACCC")),
    ];
    let input_bam = create_test_bam(reads)?;
    let work_dir = TempDir::new()?;
    let output_path = work_dir.path().join("output.bam");
    let rejects_path = work_dir.path().join("rejects.bam");
    let metrics_path = work_dir.path().join("metrics.txt");
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: output_path.clone(),
        async_reader: false,
    };
    let umis = Vec::from(["AAAAAA", "CCCCCC", "GGGGGG", "TTTTTT"].map(String::from));
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions { rejects: Some(rejects_path.clone()) },
        metrics: Some(metrics_path.clone()),
        max_mismatches: 3,
        min_distance_diff: 2,
        umis,
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let kept = read_bam_record_names(&output_path)?;
    assert_eq!(kept.len(), 8);
    for name in ["q1", "q8"] {
        assert!(kept.iter().any(|n| n == name));
    }

    let rejected = read_bam_record_names(&rejects_path)?;
    assert_eq!(rejected.len(), 2);
    for name in ["q9", "q10"] {
        assert!(rejected.iter().any(|n| n == name));
    }

    let rows = UmiCorrectionMetrics::read_metrics(&metrics_path)?;
    let Some(poly_a) = rows.iter().find(|row| row.umi == "AAAAAA") else {
        panic!("AAAAAA metric not found");
    };
    assert_eq!(poly_a.total_matches, 5);
    assert_eq!(poly_a.perfect_matches, 2);
    assert_eq!(poly_a.one_mismatch_matches, 1);
    assert_eq!(poly_a.two_mismatch_matches, 1);
    assert_eq!(poly_a.other_matches, 1);
    Ok(())
}
/// End-to-end run over five reads carrying dash-separated duplex UMIs: three reads are
/// expected in the output (including `q1` and `q4`) and two in the rejects (including `q3`).
#[test]
fn test_end_to_end_duplex_umis() -> Result<()> {
    let reads = vec![
        ("q1", Some("AAAAAA-CCCCCC")),
        ("q2", Some("AAACAA-CCCCAC")),
        ("q3", Some("AAAAAA-ACTACT")),
        ("q4", Some("GGGGGG-TTTTTT")),
        ("q5", Some("GCGCGC-TTTTTT")),
    ];
    let input_bam = create_test_bam(reads)?;
    let work_dir = TempDir::new()?;
    let output_path = work_dir.path().join("output.bam");
    let rejects_path = work_dir.path().join("rejects.bam");
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: output_path.clone(),
        async_reader: false,
    };
    let umis = Vec::from(["AAAAAA", "CCCCCC", "GGGGGG", "TTTTTT"].map(String::from));
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions { rejects: Some(rejects_path.clone()) },
        metrics: None,
        max_mismatches: 3,
        min_distance_diff: 2,
        umis,
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let kept = read_bam_record_names(&output_path)?;
    assert_eq!(kept.len(), 3);
    for name in ["q1", "q4"] {
        assert!(kept.iter().any(|n| n == name));
    }

    let rejected = read_bam_record_names(&rejects_path)?;
    assert_eq!(rejected.len(), 2);
    assert!(rejected.iter().any(|n| n == "q3"));
    Ok(())
}
/// Runs the same input twice, once with the known UMIs given inline and once with them read
/// from a file, and checks that both runs report the same UMIs and total match counts.
#[test]
fn test_umi_loading_from_file() -> Result<()> {
    const KNOWN_UMIS: [&str; 4] = ["AAAAAA", "CCCCCC", "GGGGGG", "TTTTTT"];

    let input_bam = create_test_bam(vec![("q1", Some("AAAAAA")), ("q2", Some("AATAAA"))])?;

    // One UMI per line, in the same order as the inline list.
    let mut umi_file = NamedTempFile::new()?;
    for umi in KNOWN_UMIS {
        writeln!(umi_file, "{umi}")?;
    }
    umi_file.flush()?;

    let work_dir = TempDir::new()?;
    let inline_output = work_dir.path().join("output1.bam");
    let inline_metrics = work_dir.path().join("metrics1.txt");
    let file_output = work_dir.path().join("output2.bam");
    let file_metrics = work_dir.path().join("metrics2.txt");

    // Both runs share every setting except where the known UMIs come from.
    let run = |umis: Vec<String>,
               umi_files: Vec<PathBuf>,
               output: &Path,
               metrics: &Path|
     -> Result<()> {
        let cmd = CorrectUmis {
            io: BamIoOptions {
                input: input_bam.path().to_path_buf(),
                output: output.to_path_buf(),
                async_reader: false,
            },
            rejects_opts: RejectsOptions::default(),
            metrics: Some(metrics.to_path_buf()),
            max_mismatches: 3,
            min_distance_diff: 2,
            umis,
            umi_files,
            dont_store_original_umis: false,
            cache_size: 100_000,
            min_corrected: None,
            revcomp: false,
            threading: ThreadingOptions::none(),
            compression: CompressionOptions { compression_level: 1 },
            scheduler_opts: SchedulerOptions::default(),
            queue_memory: QueueMemoryOptions::default(),
        };
        cmd.execute("test")?;
        Ok(())
    };

    run(Vec::from(KNOWN_UMIS.map(String::from)), Vec::new(), &inline_output, &inline_metrics)?;
    run(Vec::new(), vec![umi_file.path().to_path_buf()], &file_output, &file_metrics)?;

    let inline_rows = UmiCorrectionMetrics::read_metrics(&inline_metrics)?;
    let file_rows = UmiCorrectionMetrics::read_metrics(&file_metrics)?;
    assert_eq!(inline_rows.len(), file_rows.len());
    for (from_inline, from_file) in inline_rows.iter().zip(file_rows.iter()) {
        assert_eq!(from_inline.umi, from_file.umi);
        assert_eq!(from_inline.total_matches, from_file.total_matches);
    }
    Ok(())
}
/// Verifies when the original UMI is recorded in the `OX` tag.
///
/// With `dont_store_original_umis = false`, a read whose UMI already equals a known UMI must
/// not receive an `OX` tag, while a read whose UMI was corrected must. With the flag set to
/// `true`, no record may carry an `OX` tag. Both passes also assert that the expected
/// records were actually seen, so an empty output cannot make the test pass vacuously.
#[test]
fn test_original_umi_storage() -> Result<()> {
    let input =
        create_test_bam(vec![("exact", Some("AAAAAA")), ("correctable", Some("AAAAGA"))])?;
    let dir = TempDir::new()?;
    let ox_tag = Tag::ORIGINAL_UMI_BARCODE_SEQUENCE;

    // First pass: original UMIs are stored.
    let output = dir.path().join("output.bam");
    let corrector = CorrectUmis {
        io: BamIoOptions {
            input: input.path().to_path_buf(),
            output: output.clone(),
            async_reader: false,
        },
        rejects_opts: RejectsOptions::default(),
        metrics: None,
        max_mismatches: 2,
        min_distance_diff: 2,
        umis: vec!["AAAAAA".to_string(), "TTTTTT".to_string(), "CCCCCC".to_string()],
        umi_files: vec![],
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    corrector.execute("test")?;

    let mut reader = noodles::bam::io::reader::Builder.build_from_path(&output)?;
    let header = reader.read_header()?;
    let mut saw_exact = false;
    let mut saw_correctable = false;
    for result in reader.records() {
        let record = result?;
        let record_buf = RecordBuf::try_from_alignment_record(&header, &record)?;
        let name = record_buf.name().map(std::string::ToString::to_string).unwrap_or_default();
        match name.as_str() {
            "exact" => {
                saw_exact = true;
                assert!(record_buf.data().get(&ox_tag).is_none());
            }
            "correctable" => {
                saw_correctable = true;
                assert!(record_buf.data().get(&ox_tag).is_some());
            }
            _ => {}
        }
    }
    assert!(saw_exact, "the 'exact' read should be present in the output");
    assert!(saw_correctable, "the 'correctable' read should be present in the output");

    // Second pass: storing original UMIs is disabled.
    let output2 = dir.path().join("output2.bam");
    let corrector2 = CorrectUmis {
        io: BamIoOptions {
            input: input.path().to_path_buf(),
            output: output2.clone(),
            async_reader: false,
        },
        rejects_opts: RejectsOptions::default(),
        metrics: None,
        max_mismatches: 2,
        min_distance_diff: 2,
        umis: vec!["AAAAAA".to_string(), "TTTTTT".to_string(), "CCCCCC".to_string()],
        umi_files: vec![],
        dont_store_original_umis: true,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    corrector2.execute("test")?;

    let mut reader2 = noodles::bam::io::reader::Builder.build_from_path(&output2)?;
    let header2 = reader2.read_header()?;
    let mut checked = 0usize;
    for result in reader2.records() {
        let record = result?;
        let record_buf = RecordBuf::try_from_alignment_record(&header2, &record)?;
        assert!(record_buf.data().get(&ox_tag).is_none());
        checked += 1;
    }
    assert_eq!(checked, 2, "both reads should be present in the second output");
    Ok(())
}
/// With `revcomp = true`, reads carry reverse-complemented UMIs; the corrected read must
/// keep its UMI as observed ("TCTTTT") in the `OX` tag.
///
/// The test fails if the `correctable` read is missing from the output or if its `OX` tag
/// is absent or not a string, rather than silently skipping the value check.
#[test]
fn test_revcomp_option() -> Result<()> {
    let exact_umi = reverse_complement_str("AAAAAA");
    let correctable_umi = reverse_complement_str("AAAAGA");
    let input = create_test_bam(vec![
        ("exact", Some(&exact_umi)),
        ("correctable", Some(&correctable_umi)),
    ])?;
    let paths = TestPaths::new()?;
    let corrector = CorrectUmis {
        io: BamIoOptions {
            input: input.path().to_path_buf(),
            output: paths.output.clone(),
            async_reader: false,
        },
        rejects_opts: RejectsOptions::default(),
        metrics: None,
        max_mismatches: 2,
        min_distance_diff: 2,
        umis: vec!["AAAAAA".to_string(), "TTTTTT".to_string(), "CCCCCC".to_string()],
        umi_files: vec![],
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: true,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    corrector.execute("test")?;

    let mut reader = noodles::bam::io::reader::Builder.build_from_path(&paths.output)?;
    let header = reader.read_header()?;
    let mut saw_correctable = false;
    for result in reader.records() {
        let record = result?;
        let record_buf = RecordBuf::try_from_alignment_record(&header, &record)?;
        let name = record_buf.name().map(std::string::ToString::to_string).unwrap_or_default();
        if name != "correctable" {
            continue;
        }
        saw_correctable = true;
        match record_buf.data().get(&Tag::ORIGINAL_UMI_BARCODE_SEQUENCE) {
            Some(sam::alignment::record_buf::data::field::Value::String(s)) => {
                assert_eq!(s.to_string(), "TCTTTT");
            }
            other => panic!("expected a string OX tag on the corrected read, found {other:?}"),
        }
    }
    assert!(saw_correctable, "the 'correctable' read should be present in the output");
    Ok(())
}
/// With `max_mismatches = 0`, only reads whose UMI equals a known UMI are kept: `q1` and
/// `q3` are expected in the output and `q2` in the rejects.
#[test]
fn test_exact_match_only_mode() -> Result<()> {
    let reads = vec![("q1", Some("AAAAAA")), ("q2", Some("AAAAAB")), ("q3", Some("CCCCCC"))];
    let input_bam = create_test_bam(reads)?;
    let paths = TestPaths::new()?;
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: paths.output.clone(),
        async_reader: false,
    };
    let umis = Vec::from(["AAAAAA", "CCCCCC"].map(String::from));
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions { rejects: Some(paths.rejects.clone()) },
        metrics: None,
        max_mismatches: 0,
        min_distance_diff: 1,
        umis,
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let kept = read_bam_record_names(&paths.output)?;
    assert_eq!(kept.len(), 2);
    for name in ["q1", "q3"] {
        assert!(kept.iter().any(|n| n == name));
    }

    let rejected = read_bam_record_names(&paths.rejects)?;
    assert_eq!(rejected.len(), 1);
    assert!(rejected.iter().any(|n| n == "q2"));
    Ok(())
}
/// When every read already carries a known UMI, all four reads are kept, none are rejected,
/// and each metrics row whose UMI does not start with `N` reports one perfect match and no
/// one-mismatch matches.
#[test]
fn test_all_reads_already_correct() -> Result<()> {
    let reads = vec![
        ("q1", Some("AAAAAA")),
        ("q2", Some("CCCCCC")),
        ("q3", Some("GGGGGG")),
        ("q4", Some("TTTTTT")),
    ];
    let input_bam = create_test_bam(reads)?;
    let paths = TestPaths::new()?;
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: paths.output.clone(),
        async_reader: false,
    };
    let umis = Vec::from(["AAAAAA", "CCCCCC", "GGGGGG", "TTTTTT"].map(String::from));
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions { rejects: Some(paths.rejects.clone()) },
        metrics: Some(paths.metrics.clone()),
        max_mismatches: 2,
        min_distance_diff: 2,
        umis,
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let kept = read_bam_record_names(&paths.output)?;
    assert_eq!(kept.len(), 4);
    let rejected = read_bam_record_names(&paths.rejects)?;
    assert_eq!(rejected.len(), 0);

    let rows = UmiCorrectionMetrics::read_metrics(&paths.metrics)?;
    // Rows whose UMI starts with 'N' are excluded from the per-UMI checks.
    for row in rows.iter().filter(|row| !row.umi.starts_with('N')) {
        assert_eq!(row.perfect_matches, 1);
        assert_eq!(row.one_mismatch_matches, 0);
    }
    Ok(())
}
/// With a single known UMI ("AAAAAA") and `max_mismatches = 2`, three of the four reads are
/// expected in the output and `q4` ("CCCCCC") in the rejects.
#[test]
fn test_single_known_umi() -> Result<()> {
    let reads = vec![
        ("q1", Some("AAAAAA")),
        ("q2", Some("AAAAAB")),
        ("q3", Some("AAAACC")),
        ("q4", Some("CCCCCC")),
    ];
    let input_bam = create_test_bam(reads)?;
    let paths = TestPaths::new()?;
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: paths.output.clone(),
        async_reader: false,
    };
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions { rejects: Some(paths.rejects.clone()) },
        metrics: None,
        max_mismatches: 2,
        min_distance_diff: 2,
        umis: vec![String::from("AAAAAA")],
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let kept = read_bam_record_names(&paths.output)?;
    assert_eq!(kept.len(), 3);
    let rejected = read_bam_record_names(&paths.rejects)?;
    assert_eq!(rejected.len(), 1);
    assert!(rejected.iter().any(|n| n == "q4"));
    Ok(())
}
/// Reads whose UMI contains a single `N` are still expected in the output alongside the
/// exact read when the only known UMI is "AAAAAA" and `max_mismatches = 2`.
#[test]
fn test_umis_with_n_bases() -> Result<()> {
    let reads = vec![("q1", Some("ANAAAA")), ("q2", Some("AANAAA")), ("q3", Some("AAAAAA"))];
    let input_bam = create_test_bam(reads)?;
    let paths = TestPaths::new()?;
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: paths.output.clone(),
        async_reader: false,
    };
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions { rejects: Some(paths.rejects.clone()) },
        metrics: None,
        max_mismatches: 2,
        min_distance_diff: 2,
        umis: vec![String::from("AAAAAA")],
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let kept = read_bam_record_names(&paths.output)?;
    assert_eq!(kept.len(), 3);
    Ok(())
}
/// Twenty-base UMIs: exact reads and reads with one trailing mismatch are all expected in
/// the output with `max_mismatches = 1` and `min_distance_diff = 5`.
#[test]
fn test_very_long_umis() -> Result<()> {
    let reads = vec![
        ("q1", Some("AAAAAAAAAAAAAAAAAAAA")),
        ("q2", Some("AAAAAAAAAAAAAAAAAAAB")),
        ("q3", Some("CCCCCCCCCCCCCCCCCCCC")),
        ("q4", Some("CCCCCCCCCCCCCCCCCCCT")),
    ];
    let input_bam = create_test_bam(reads)?;
    let paths = TestPaths::new()?;
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: paths.output.clone(),
        async_reader: false,
    };
    let umis =
        Vec::from(["AAAAAAAAAAAAAAAAAAAA", "CCCCCCCCCCCCCCCCCCCC"].map(String::from));
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions { rejects: Some(paths.rejects.clone()) },
        metrics: None,
        max_mismatches: 1,
        min_distance_diff: 5,
        umis,
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let kept = read_bam_record_names(&paths.output)?;
    assert_eq!(kept.len(), 4);
    Ok(())
}
/// Exercises `min_corrected = 0.75` in two runs: the first is expected to succeed with three
/// reads in the output; the second, where only one of four reads carries the known UMI, is
/// expected to fail with a message mentioning "Final ratio of reads kept".
#[test]
fn test_min_corrected_threshold() -> Result<()> {
    let paths = TestPaths::new()?;

    // Run 1: three of four reads are expected to be kept.
    let passing_reads = vec![
        ("q1", Some("AAAAAA")),
        ("q2", Some("AAAAAB")),
        ("q3", Some("AAAAAC")),
        ("q4", Some("TTTTTT")),
    ];
    let passing_input = create_test_bam(passing_reads)?;
    let passing_cmd = CorrectUmis {
        io: BamIoOptions {
            input: passing_input.path().to_path_buf(),
            output: paths.output.clone(),
            async_reader: false,
        },
        rejects_opts: RejectsOptions::default(),
        metrics: None,
        max_mismatches: 1,
        min_distance_diff: 2,
        umis: vec![String::from("AAAAAA")],
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: Some(0.75),
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    passing_cmd.execute("test")?;
    let kept = read_bam_record_names(&paths.output)?;
    assert_eq!(kept.len(), 3);

    // Run 2: only one of four reads carries the known UMI.
    let failing_reads = vec![
        ("q1", Some("AAAAAA")),
        ("q2", Some("TTTTTT")),
        ("q3", Some("GGGGGG")),
        ("q4", Some("CCCCCC")),
    ];
    let failing_input = create_test_bam(failing_reads)?;
    let second_output = paths.output_n(2);
    let failing_cmd = CorrectUmis {
        io: BamIoOptions {
            input: failing_input.path().to_path_buf(),
            output: second_output.clone(),
            async_reader: false,
        },
        rejects_opts: RejectsOptions::default(),
        metrics: None,
        max_mismatches: 1,
        min_distance_diff: 2,
        umis: vec![String::from("AAAAAA")],
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: Some(0.75),
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    let outcome = failing_cmd.execute("test");
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Final ratio of reads kept"));
    Ok(())
}
/// Lower-case and mixed-case spellings of the known UMI are expected in the output together
/// with the upper-case read, even with `max_mismatches = 0`.
#[test]
fn test_case_sensitivity() -> Result<()> {
    let reads = vec![("q1", Some("aaaaaa")), ("q2", Some("AaAaAa")), ("q3", Some("AAAAAA"))];
    let input_bam = create_test_bam(reads)?;
    let paths = TestPaths::new()?;
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: paths.output.clone(),
        async_reader: false,
    };
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions::default(),
        metrics: None,
        max_mismatches: 0,
        min_distance_diff: 1,
        umis: vec![String::from("AAAAAA")],
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let kept = read_bam_record_names(&paths.output)?;
    assert_eq!(kept.len(), 3);
    Ok(())
}
/// Runs 100 reads through the command with `ThreadingOptions::new(4)` and checks that the
/// output BAM lists them in exactly the input order.
///
/// Read names are owned by a local `Vec<String>` and borrowed for the records, so no memory
/// is leaked to obtain `'static` names.
#[test]
fn test_output_order_preserved_with_multiple_threads() -> Result<()> {
    let umis = ["AAAAAA", "CCCCCC", "GGGGGG", "TTTTTT"];
    // Zero-padded names so that the expected order is easy to regenerate below.
    let names: Vec<String> = (0..100usize).map(|i| format!("read_{i:04}")).collect();
    let records: Vec<(&str, Option<&str>)> = names
        .iter()
        .enumerate()
        .map(|(i, name)| (name.as_str(), Some(umis[i % 4])))
        .collect();
    let input = create_test_bam(records)?;
    let dir = TempDir::new()?;
    let output = dir.path().join("output.bam");
    let corrector = CorrectUmis {
        io: BamIoOptions {
            input: input.path().to_path_buf(),
            output: output.clone(),
            async_reader: false,
        },
        rejects_opts: RejectsOptions::default(),
        metrics: None,
        max_mismatches: 2,
        min_distance_diff: 2,
        umis: vec![
            "AAAAAA".to_string(),
            "CCCCCC".to_string(),
            "GGGGGG".to_string(),
            "TTTTTT".to_string(),
        ],
        umi_files: vec![],
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::new(4),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    corrector.execute("test")?;
    let output_names = read_bam_record_names(&output)?;
    assert_eq!(output_names.len(), 100);
    for (i, name) in output_names.iter().enumerate() {
        let expected = format!("read_{i:04}");
        assert_eq!(
            name, &expected,
            "Output order mismatch at index {i}: expected {expected}, got {name}",
        );
    }
    Ok(())
}
/// Runs 500 reads through the command with `ThreadingOptions::new(8)` and checks that the
/// output BAM lists them in exactly the input order.
///
/// Read names are owned by a local `Vec<String>` and borrowed for the records, so no memory
/// is leaked to obtain `'static` names.
#[test]
fn test_output_order_preserved_with_8_threads() -> Result<()> {
    let umis = ["AAAAAA", "CCCCCC", "GGGGGG", "TTTTTT"];
    // Zero-padded names so that the expected order is easy to regenerate below.
    let names: Vec<String> = (0..500usize).map(|i| format!("rec_{i:05}")).collect();
    let records: Vec<(&str, Option<&str>)> = names
        .iter()
        .enumerate()
        .map(|(i, name)| (name.as_str(), Some(umis[i % 4])))
        .collect();
    let input = create_test_bam(records)?;
    let dir = TempDir::new()?;
    let output = dir.path().join("output.bam");
    let corrector = CorrectUmis {
        io: BamIoOptions {
            input: input.path().to_path_buf(),
            output: output.clone(),
            async_reader: false,
        },
        rejects_opts: RejectsOptions::default(),
        metrics: None,
        max_mismatches: 2,
        min_distance_diff: 2,
        umis: vec![
            "AAAAAA".to_string(),
            "CCCCCC".to_string(),
            "GGGGGG".to_string(),
            "TTTTTT".to_string(),
        ],
        umi_files: vec![],
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::new(8),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    corrector.execute("test")?;
    let output_names = read_bam_record_names(&output)?;
    assert_eq!(output_names.len(), 500);
    for (i, name) in output_names.iter().enumerate() {
        let expected = format!("rec_{i:05}");
        assert_eq!(
            name, &expected,
            "Output order mismatch at index {i}: expected {expected}, got {name}",
        );
    }
    Ok(())
}
/// A UMI identical to a known UMI is matched as-is: no correction needed, no mismatches,
/// and no rejection reason.
#[test]
fn test_compute_template_correction_perfect_match() {
    let known = get_encoded_umi_set();
    let mut cache = None;
    let outcome =
        CorrectUmis::compute_template_correction("AAAAAA", 6, false, 2, 2, &known, &mut cache);
    assert!(outcome.matched);
    assert_eq!(outcome.corrected_umi.as_deref(), Some("AAAAAA"));
    assert_eq!(outcome.original_umi, "AAAAAA");
    assert!(!outcome.needs_correction);
    assert!(!outcome.has_mismatches);
    assert_eq!(outcome.rejection_reason, RejectionReason::None);
}
/// A UMI one base away from "AAAAAA" is corrected to it, keeps the observed sequence as the
/// original UMI, and is flagged as both needing correction and having mismatches.
#[test]
fn test_compute_template_correction_with_mismatch() {
    let known = get_encoded_umi_set();
    let mut cache = None;
    let outcome =
        CorrectUmis::compute_template_correction("AAAAAT", 6, false, 2, 2, &known, &mut cache);
    assert!(outcome.matched);
    assert_eq!(outcome.corrected_umi.as_deref(), Some("AAAAAA"));
    assert_eq!(outcome.original_umi, "AAAAAT");
    assert!(outcome.needs_correction);
    assert!(outcome.has_mismatches);
    assert_eq!(outcome.rejection_reason, RejectionReason::None);
}
/// A four-base UMI checked against an expected length of six is rejected with
/// `RejectionReason::WrongLength` and yields no corrected UMI.
#[test]
fn test_compute_template_correction_wrong_length() {
    let known = get_encoded_umi_set();
    let mut cache = None;
    let outcome = CorrectUmis::compute_template_correction(
        "AAAA", 6, false, 2, 2, &known, &mut cache,
    );
    assert!(!outcome.matched);
    assert!(outcome.corrected_umi.is_none());
    assert_eq!(outcome.rejection_reason, RejectionReason::WrongLength);
}
/// "AAAGGG" cannot be matched with the given limits, so it is rejected with
/// `RejectionReason::Mismatched` and yields no corrected UMI.
#[test]
fn test_compute_template_correction_too_many_mismatches() {
    let known = get_encoded_umi_set();
    let mut cache = None;
    let outcome =
        CorrectUmis::compute_template_correction("AAAGGG", 6, false, 2, 2, &known, &mut cache);
    assert!(!outcome.matched);
    assert!(outcome.corrected_umi.is_none());
    assert_eq!(outcome.rejection_reason, RejectionReason::Mismatched);
}
/// With reverse-complementing enabled, "TTTTTT" is corrected to "AAAAAA": the read needs
/// correction even though no mismatches are reported.
#[test]
fn test_compute_template_correction_with_revcomp() {
    let known = get_encoded_umi_set();
    let mut cache = None;
    let outcome = CorrectUmis::compute_template_correction(
        "TTTTTT", 6, true, 2, 2, &known, &mut cache,
    );
    assert!(outcome.matched);
    assert_eq!(outcome.corrected_umi.as_deref(), Some("AAAAAA"));
    assert!(outcome.needs_correction);
    assert!(!outcome.has_mismatches);
}
/// A dash-separated dual UMI whose two halves are both known UMIs is matched unchanged and
/// does not need correction.
#[test]
fn test_compute_template_correction_dual_umi() {
    let known = get_encoded_umi_set();
    let mut cache = None;
    let dual_umi = "AAAAAA-CCCCCC";
    let outcome =
        CorrectUmis::compute_template_correction(dual_umi, 6, false, 2, 2, &known, &mut cache);
    assert!(outcome.matched);
    assert_eq!(outcome.corrected_umi.as_deref(), Some("AAAAAA-CCCCCC"));
    assert!(!outcome.needs_correction);
}
/// The metrics file must contain a row for each known UMI plus an "NNNNNN" row that counts
/// the three reads that were not matched, with every per-mismatch column at zero.
#[test]
fn test_metrics_includes_unmatched_umi_row() -> Result<()> {
    let reads = vec![
        ("q1", Some("AAAAAA")),
        ("q2", Some("GGGTTT")),
        ("q3", Some("CCCCCC")),
        ("q4", Some("ACTGAC")),
        ("q5", None),
    ];
    let input_bam = create_test_bam(reads)?;
    let paths = TestPaths::new()?;
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: paths.output.clone(),
        async_reader: false,
    };
    let umis = Vec::from(["AAAAAA", "CCCCCC"].map(String::from));
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions::default(),
        metrics: Some(paths.metrics.clone()),
        max_mismatches: 1,
        min_distance_diff: 2,
        umis,
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let rows = UmiCorrectionMetrics::read_metrics(&paths.metrics)?;
    assert_eq!(rows.len(), 3, "Expected 3 UMI rows in metrics");
    let Some(unmatched) = rows.iter().find(|row| row.umi == "NNNNNN") else {
        panic!("Unmatched UMI row (NNNNNN) not found in metrics");
    };
    assert_eq!(
        unmatched.total_matches, 3,
        "Unmatched row should have 3 total_matches for uncorrectable reads"
    );
    assert_eq!(unmatched.perfect_matches, 0);
    assert_eq!(unmatched.one_mismatch_matches, 0);
    assert_eq!(unmatched.two_mismatch_matches, 0);
    assert_eq!(unmatched.other_matches, 0);
    Ok(())
}
/// Known UMIs that no read matched must still appear in the metrics with zero totals, next
/// to the matched UMI and the "NNNNNN" row.
#[test]
fn test_metrics_includes_all_umi_rows_even_with_zero_counts() -> Result<()> {
    let input_bam = create_test_bam(vec![("q1", Some("AAAAAA")), ("q2", Some("AAAAAA"))])?;
    let paths = TestPaths::new()?;
    let io = BamIoOptions {
        input: input_bam.path().to_path_buf(),
        output: paths.output.clone(),
        async_reader: false,
    };
    let umis = Vec::from(["AAAAAA", "CCCCCC", "GGGGGG"].map(String::from));
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions::default(),
        metrics: Some(paths.metrics.clone()),
        max_mismatches: 1,
        min_distance_diff: 2,
        umis,
        umi_files: Vec::new(),
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let rows = UmiCorrectionMetrics::read_metrics(&paths.metrics)?;
    assert_eq!(rows.len(), 4, "Expected 4 UMI rows in metrics (including zeros)");

    // Every expected row must be present, whether or not any read matched it.
    let has_row = |umi: &str| rows.iter().any(|row| row.umi == umi);
    assert!(has_row("AAAAAA"), "AAAAAA should be in metrics");
    assert!(has_row("CCCCCC"), "CCCCCC should be in metrics (even with 0 matches)");
    assert!(has_row("GGGGGG"), "GGGGGG should be in metrics (even with 0 matches)");
    assert!(has_row("NNNNNN"), "NNNNNN (unmatched) should be in metrics");

    let total_for = |umi: &str| {
        rows.iter()
            .find(|row| row.umi == umi)
            .unwrap_or_else(|| panic!("{umi} metric not found"))
            .total_matches
    };
    assert_eq!(total_for("CCCCCC"), 0, "CCCCCC should have 0 total_matches");
    assert_eq!(total_for("GGGGGG"), 0, "GGGGGG should have 0 total_matches");
    assert_eq!(total_for("AAAAAA"), 2, "AAAAAA should have 2 total_matches");
    Ok(())
}
/// Same expectations as the single-threaded metrics tests, but run with
/// `ThreadingOptions::new(4)`: the per-UMI counts and the `NNNNNN` row must come out identical.
#[test]
fn test_metrics_unmatched_row_with_multithreaded() -> Result<()> {
    // 50 reads cycling through five UMIs (10 reads each): two exact matches, one UMI a single
    // mismatch away from AAAAAA, and two that the assertions below expect to be unmatched.
    let umi_cycle = ["AAAAAA", "CCCCCC", "AAAAAB", "GGGTTT", "ACTGAC"];
    let records: Vec<(&str, Option<&str>)> = umi_cycle
        .iter()
        .cycle()
        .take(50)
        .enumerate()
        .map(|(i, umi)| {
            // Read names are leaked to obtain `&'static str`; acceptable in a short-lived test.
            let name: &'static str = Box::leak(format!("read_{i:03}").into_boxed_str());
            (name, Some(*umi))
        })
        .collect();
    let input = create_test_bam(records)?;
    let paths = TestPaths::new()?;

    let fixed_umis = vec!["AAAAAA".to_string(), "CCCCCC".to_string()];
    let cmd = CorrectUmis {
        io: BamIoOptions {
            input: input.path().to_path_buf(),
            output: paths.output.clone(),
            async_reader: false,
        },
        rejects_opts: RejectsOptions::default(),
        metrics: Some(paths.metrics.clone()),
        max_mismatches: 1,
        min_distance_diff: 2,
        umis: fixed_umis,
        umi_files: vec![],
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::new(4),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    let rows = UmiCorrectionMetrics::read_metrics(&paths.metrics)?;
    assert_eq!(rows.len(), 3, "Expected 3 UMI rows in metrics");

    // GGGTTT and ACTGAC reads: 2 x 10.
    let unmatched_row = rows
        .iter()
        .find(|row| row.umi == "NNNNNN")
        .expect("unmatched UMI row (NNNNNN) not found");
    assert_eq!(unmatched_row.total_matches, 20, "Unmatched row should have 20 total_matches");

    // AAAAAA collects its 10 exact reads plus the 10 AAAAAB reads.
    let a_row = rows.iter().find(|row| row.umi == "AAAAAA").expect("AAAAAA metric not found");
    assert_eq!(a_row.total_matches, 20, "AAAAAA should have 20 total_matches");
    assert_eq!(a_row.perfect_matches, 10);
    assert_eq!(a_row.one_mismatch_matches, 10);

    let c_row = rows.iter().find(|row| row.umi == "CCCCCC").expect("CCCCCC metric not found");
    assert_eq!(c_row.total_matches, 10, "CCCCCC should have 10 total_matches");
    assert_eq!(c_row.perfect_matches, 10);
    Ok(())
}
/// Runs the command once per threading configuration and checks that every input record makes
/// it to the output, whichever execution path is selected.
#[rstest]
#[case::fast_path(ThreadingOptions::none())]
#[case::pipeline_1(ThreadingOptions::new(1))]
#[case::pipeline_2(ThreadingOptions::new(2))]
fn test_threading_modes(#[case] threading: ThreadingOptions) -> Result<()> {
    let reads = vec![
        ("read1", Some("AAAAAA")),
        ("read2", Some("AAAAAG")),
        ("read3", Some("CCCCCC")),
    ];
    let input = create_test_bam(reads)?;
    let output = NamedTempFile::new()?;

    let io = BamIoOptions {
        input: input.path().to_path_buf(),
        output: output.path().to_path_buf(),
        async_reader: false,
    };
    let cmd = CorrectUmis {
        io,
        rejects_opts: RejectsOptions::default(),
        metrics: None,
        max_mismatches: 2,
        min_distance_diff: 2,
        umis: FIXED_UMIS.iter().map(|s| (*s).to_string()).collect(),
        umi_files: vec![],
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading,
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    cmd.execute("test")?;

    // Only the record count is checked here; tag contents are covered by other tests.
    let written = read_bam_record_names(output.path())?;
    assert_eq!(written.len(), 3, "Should have 3 records");
    Ok(())
}
fn make_raw_bam_for_correct(name: &[u8], flag: u16, umi: &[u8]) -> RawRecord {
let l_read_name = (name.len() + 1) as u8;
let seq_len = 4usize;
let cigar_ops: &[u32] = if (flag & crate::sort::bam_fields::flags::UNMAPPED) == 0 {
&[(seq_len as u32) << 4]
} else {
&[]
};
let n_cigar_op = cigar_ops.len() as u16;
let seq_bytes = seq_len.div_ceil(2);
let total = 32 + l_read_name as usize + cigar_ops.len() * 4 + seq_bytes + seq_len;
let mut buf = vec![0u8; total];
buf[0..4].copy_from_slice(&0i32.to_le_bytes());
buf[4..8].copy_from_slice(&100i32.to_le_bytes());
buf[8] = l_read_name;
buf[9] = 30; buf[12..14].copy_from_slice(&n_cigar_op.to_le_bytes());
buf[14..16].copy_from_slice(&flag.to_le_bytes());
buf[16..20].copy_from_slice(&(seq_len as u32).to_le_bytes());
buf[20..24].copy_from_slice(&(-1i32).to_le_bytes());
buf[24..28].copy_from_slice(&(-1i32).to_le_bytes());
let name_start = 32;
buf[name_start..name_start + name.len()].copy_from_slice(name);
buf[name_start + name.len()] = 0;
let cigar_start = name_start + l_read_name as usize;
for (i, &op) in cigar_ops.iter().enumerate() {
let off = cigar_start + i * 4;
buf[off..off + 4].copy_from_slice(&op.to_le_bytes());
}
buf.extend_from_slice(b"RXZ");
buf.extend_from_slice(umi);
buf.push(0);
RawRecord::from(buf)
}
/// A template made of a single record yields that record's UMI.
#[test]
fn test_extract_and_validate_template_umi_raw_single_record() {
    let first_of_pair =
        crate::sort::bam_fields::flags::PAIRED | crate::sort::bam_fields::flags::FIRST_SEGMENT;
    let template = [make_raw_bam_for_correct(b"rea", first_of_pair, b"AAAAAA")];

    let umi =
        CorrectUmis::extract_and_validate_template_umi_raw(&template, *SamTag::RX).unwrap();

    assert_eq!(umi, Some("AAAAAA".to_string()));
}
/// When both mates of a pair carry the same UMI, that UMI is returned for the template.
#[test]
fn test_extract_and_validate_template_umi_raw_matching_pair() {
    let first_of_pair =
        crate::sort::bam_fields::flags::PAIRED | crate::sort::bam_fields::flags::FIRST_SEGMENT;
    let second_of_pair =
        crate::sort::bam_fields::flags::PAIRED | crate::sort::bam_fields::flags::LAST_SEGMENT;
    let template = [
        make_raw_bam_for_correct(b"rea", first_of_pair, b"ACGTAC"),
        make_raw_bam_for_correct(b"rea", second_of_pair, b"ACGTAC"),
    ];

    let umi =
        CorrectUmis::extract_and_validate_template_umi_raw(&template, *SamTag::RX).unwrap();

    assert_eq!(umi, Some("ACGTAC".to_string()));
}
/// Mates of the same template carrying different UMIs must produce an error whose message
/// mentions the mismatch.
#[test]
fn test_extract_and_validate_template_umi_raw_mismatch_errors() {
    let first_of_pair =
        crate::sort::bam_fields::flags::PAIRED | crate::sort::bam_fields::flags::FIRST_SEGMENT;
    let second_of_pair =
        crate::sort::bam_fields::flags::PAIRED | crate::sort::bam_fields::flags::LAST_SEGMENT;
    let template = [
        make_raw_bam_for_correct(b"rea", first_of_pair, b"AAAAAA"),
        make_raw_bam_for_correct(b"rea", second_of_pair, b"CCCCCC"),
    ];

    let message = CorrectUmis::extract_and_validate_template_umi_raw(&template, *SamTag::RX)
        .unwrap_err()
        .to_string();

    assert!(message.contains("mismatched UMIs"));
}
/// An empty template is not an error; it simply has no UMI.
#[test]
fn test_extract_and_validate_template_umi_raw_empty() {
    let umi = CorrectUmis::extract_and_validate_template_umi_raw(&[], *SamTag::RX).unwrap();

    assert!(umi.is_none());
}
/// A record without any aux data (and therefore without an `RX` tag) yields no UMI.
#[test]
fn test_extract_and_validate_template_umi_raw_no_umi_tag() {
    let first_of_pair =
        crate::sort::bam_fields::flags::PAIRED | crate::sort::bam_fields::flags::FIRST_SEGMENT;

    // Hand-built record: 32 fixed bytes + 4-byte name + 2 packed sequence bytes + 4 quality
    // bytes = 42 bytes, with nothing after the qualities.
    let mut bytes = vec![0u8; 42];
    bytes[8] = 4; // l_read_name, including the NUL terminator
    bytes[14..16].copy_from_slice(&first_of_pair.to_le_bytes());
    bytes[16..20].copy_from_slice(&4u32.to_le_bytes()); // l_seq
    bytes[32..36].copy_from_slice(b"rea\0");
    let template = [RawRecord::from(bytes)];

    let umi =
        CorrectUmis::extract_and_validate_template_umi_raw(&template, *SamTag::RX).unwrap();

    assert!(umi.is_none());
}
/// Applying a correction rewrites `RX` to the corrected UMI and, with the last argument set to
/// `false`, records the original UMI in `OX`.
#[test]
fn test_apply_correction_to_raw_corrects_umi() {
    use crate::sort::bam_fields;

    let first_of_pair = bam_fields::flags::PAIRED | bam_fields::flags::FIRST_SEGMENT;
    let mut record = make_raw_bam_for_correct(b"rea", first_of_pair, b"AAAAAG");

    // One-mismatch correction AAAAAG -> AAAAAA.
    let fix = TemplateCorrection {
        original_umi: "AAAAAG".to_string(),
        corrected_umi: Some("AAAAAA".to_string()),
        matched: true,
        needs_correction: true,
        has_mismatches: true,
        matches: vec![],
        rejection_reason: RejectionReason::None,
    };
    CorrectUmis::apply_correction_to_raw(&mut record, &fix, *SamTag::RX, false);

    let rx = bam_fields::find_string_tag_in_record(&record, &SamTag::RX);
    assert_eq!(rx, Some(b"AAAAAA".as_ref()));

    let original = bam_fields::find_string_tag_in_record(&record, &SamTag::OX);
    assert_eq!(original, Some(b"AAAAAG".as_ref()));
}
/// When the correction says nothing needs to change, the record must keep its length and its
/// `RX` value.
#[test]
fn test_apply_correction_to_raw_no_correction_needed() {
    use crate::sort::bam_fields;

    let first_of_pair = bam_fields::flags::PAIRED | bam_fields::flags::FIRST_SEGMENT;
    let mut record = make_raw_bam_for_correct(b"rea", first_of_pair, b"AAAAAA");
    let len_before = record.len();

    // Exact match: the "corrected" UMI equals the original and no rewrite is requested.
    let fix = TemplateCorrection {
        original_umi: "AAAAAA".to_string(),
        corrected_umi: Some("AAAAAA".to_string()),
        matched: true,
        needs_correction: false,
        has_mismatches: false,
        matches: vec![],
        rejection_reason: RejectionReason::None,
    };
    CorrectUmis::apply_correction_to_raw(&mut record, &fix, *SamTag::RX, false);

    assert_eq!(record.len(), len_before);
    let rx = bam_fields::find_string_tag_in_record(&record, &SamTag::RX);
    assert_eq!(rx, Some(b"AAAAAA".as_ref()));
}
/// End-to-end check of the command with `ThreadingOptions::none()` on paired-end input.
///
/// Two templates are written to a temporary BAM: `t1` with UMI `AAAAAG` and `t2` with UMI
/// `CCCCCC`. After running the command the test expects:
///
/// * all four records in the main output and none in the rejects file;
/// * `t1` records to carry `RX = AAAAAA` and `OX = AAAAAG`;
/// * `t2` records to keep `RX = CCCCCC` and have no `OX` tag.
#[test]
fn test_single_thread_mode_produces_correct_output() -> Result<()> {
    use fgumi_raw_bam::{
        SamBuilder as RawSamBuilder, flags, raw_record_to_record_buf, testutil::encode_op,
    };
    use noodles::sam::header::record::value::{Map, map::ReferenceSequence};
    use std::num::NonZeroUsize;

    // Header with a single reference sequence so the mapped records below are valid.
    let header = sam::Header::builder()
        .add_reference_sequence(
            "chr1",
            Map::<ReferenceSequence>::new(
                NonZeroUsize::new(248_956_422).expect("non-zero chr1 length"),
            ),
        )
        .build();

    // Build the input BAM: two templates, each written as R1 followed by R2.
    let input = {
        let temp = tempfile::NamedTempFile::new()?;
        let mut writer = noodles::bam::io::writer::Builder.build_from_path(temp.path())?;
        writer.write_header(&header)?;
        // NOTE(review): op code 0 with length 100 is presumably a 100-base match; confirm
        // against `encode_op`.
        let cigar = encode_op(0, 100);
        let seq = vec![b'A'; 100];
        let quals = vec![30u8; 100];
        for (pos1, pos2, name, umi) in
            [(99i32, 199i32, "t1", "AAAAAG"), (299i32, 399i32, "t2", "CCCCCC")]
        {
            // First segment, with the mate flagged as reverse.
            let mut b1 = RawSamBuilder::new();
            b1.read_name(name.as_bytes())
                .flags(flags::PAIRED | flags::FIRST_SEGMENT | flags::MATE_REVERSE)
                .ref_id(0)
                .pos(pos1)
                .mapq(60)
                .cigar_ops(&[cigar])
                .sequence(&seq)
                .qualities(&quals)
                .mate_ref_id(0)
                .mate_pos(pos2);
            b1.add_string_tag(b"RX", umi.as_bytes());
            let r1 = raw_record_to_record_buf(&b1.build(), &header)?;
            writer.write_alignment_record(&header, &r1)?;

            // Last segment, itself reverse, pointing back at the first segment.
            let mut b2 = RawSamBuilder::new();
            b2.read_name(name.as_bytes())
                .flags(flags::PAIRED | flags::LAST_SEGMENT | flags::REVERSE)
                .ref_id(0)
                .pos(pos2)
                .mapq(60)
                .cigar_ops(&[cigar])
                .sequence(&seq)
                .qualities(&quals)
                .mate_ref_id(0)
                .mate_pos(pos1);
            b2.add_string_tag(b"RX", umi.as_bytes());
            let r2 = raw_record_to_record_buf(&b2.build(), &header)?;
            writer.write_alignment_record(&header, &r2)?;
        }
        writer.try_finish()?;
        // Returning the handle keeps the temporary file alive for the rest of the test.
        temp
    };

    let paths = TestPaths::new()?;
    let corrector = CorrectUmis {
        io: BamIoOptions {
            input: input.path().to_path_buf(),
            output: paths.output.clone(),
            async_reader: false,
        },
        rejects_opts: RejectsOptions { rejects: Some(paths.rejects.clone()) },
        metrics: Some(paths.metrics.clone()),
        max_mismatches: 2,
        min_distance_diff: 2,
        umis: FIXED_UMIS.iter().map(|s| (*s).to_string()).collect(),
        umi_files: vec![],
        dont_store_original_umis: false,
        cache_size: 100_000,
        min_corrected: None,
        revcomp: false,
        threading: ThreadingOptions::none(),
        compression: CompressionOptions { compression_level: 1 },
        scheduler_opts: SchedulerOptions::default(),
        queue_memory: QueueMemoryOptions::default(),
    };
    corrector.execute("test")?;

    // Both mates of both templates must be kept.
    let output_names = read_bam_record_names(&paths.output)?;
    assert_eq!(output_names.len(), 4, "Expected 4 records in output");
    assert_eq!(
        output_names.iter().filter(|n| *n == "t1").count(),
        2,
        "Expected 2 records for template t1"
    );
    assert_eq!(
        output_names.iter().filter(|n| *n == "t2").count(),
        2,
        "Expected 2 records for template t2"
    );
    let reject_names = read_bam_record_names(&paths.rejects)?;
    assert_eq!(reject_names.len(), 0, "Expected no rejected records");

    // Re-read the output and verify the RX/OX tags record by record.
    let mut reader = noodles::bam::io::reader::Builder.build_from_path(&paths.output)?;
    let header = reader.read_header()?;
    let rx_tag = Tag::from(SamTag::RX);
    let ox_tag = Tag::ORIGINAL_UMI_BARCODE_SEQUENCE;
    for result in reader.records() {
        let record = result?;
        let record_buf = RecordBuf::try_from_alignment_record(&header, &record)?;
        let name = record_buf.name().map(std::string::ToString::to_string).unwrap_or_default();
        match name.as_str() {
            // t1 was one mismatch away: RX is rewritten and the original is kept in OX.
            "t1" => {
                if let Some(sam::alignment::record_buf::data::field::Value::String(s)) =
                    record_buf.data().get(&rx_tag)
                {
                    assert_eq!(
                        String::from_utf8_lossy(s),
                        "AAAAAA",
                        "t1 RX tag should be corrected to AAAAAA"
                    );
                } else {
                    panic!("t1: RX tag not found or wrong type");
                }
                if let Some(sam::alignment::record_buf::data::field::Value::String(s)) =
                    record_buf.data().get(&ox_tag)
                {
                    assert_eq!(
                        String::from_utf8_lossy(s),
                        "AAAAAG",
                        "t1 OX tag should store original UMI AAAAAG"
                    );
                } else {
                    panic!("t1: OX tag not found or wrong type");
                }
            }
            // t2 matched exactly: RX is untouched and no OX tag is added.
            "t2" => {
                if let Some(sam::alignment::record_buf::data::field::Value::String(s)) =
                    record_buf.data().get(&rx_tag)
                {
                    assert_eq!(
                        String::from_utf8_lossy(s),
                        "CCCCCC",
                        "t2 RX tag should remain CCCCCC"
                    );
                } else {
                    panic!("t2: RX tag not found or wrong type");
                }
                assert!(
                    record_buf.data().get(&ox_tag).is_none(),
                    "t2 should not have an OX tag (exact match)"
                );
            }
            other => panic!("Unexpected record name: {other}"),
        }
    }
    Ok(())
}
/// With the last argument set to `true`, the corrected UMI is written to `RX` but the original
/// UMI must not be preserved in an `OX` tag.
#[test]
fn test_apply_correction_to_raw_dont_store_original() {
    use crate::sort::bam_fields;

    let first_of_pair = bam_fields::flags::PAIRED | bam_fields::flags::FIRST_SEGMENT;
    let mut record = make_raw_bam_for_correct(b"rea", first_of_pair, b"AAAAAG");

    // One-mismatch correction AAAAAG -> AAAAAA.
    let fix = TemplateCorrection {
        original_umi: "AAAAAG".to_string(),
        corrected_umi: Some("AAAAAA".to_string()),
        matched: true,
        needs_correction: true,
        has_mismatches: true,
        matches: vec![],
        rejection_reason: RejectionReason::None,
    };
    CorrectUmis::apply_correction_to_raw(&mut record, &fix, *SamTag::RX, true);

    let rx = bam_fields::find_string_tag_in_record(&record, &SamTag::RX);
    assert_eq!(rx, Some(b"AAAAAA".as_ref()));

    let original = bam_fields::find_string_tag_in_record(&record, &SamTag::OX);
    assert!(original.is_none());
}
}