use crate::bam_io::create_raw_bam_writer;
use crate::commands::command::Command;
use crate::commands::common::CompressionOptions;
use crate::commands::simulate::common::generate_random_sequence;
use crate::progress::ProgressTracker;
use crate::simulate::create_rng;
use anyhow::{Context, Result};
use clap::Parser;
use crossbeam_channel::bounded;
use fgumi_raw_bam::{RawRecord, SamBuilder, flags as raw_flags};
use log::info;
use noodles::sam::header::Header;
use rand::{Rng, RngExt};
use rayon::prelude::*;
use std::fs::File;
use std::io::{BufWriter, Write};
use std::path::PathBuf;
use std::sync::Arc;
use std::thread;
/// Command-line arguments for the `correct-reads` simulation subcommand.
///
/// The command writes three files: an unmapped paired-end BAM whose reads carry an `RX` UMI tag,
/// an includelist with one UMI per line, and a tab-separated truth table describing how each
/// observed UMI was derived from its true UMI.
#[derive(Parser, Debug)]
#[command(
    name = "correct-reads",
    about = "Generate BAM and includelist for correct",
    long_about = r#"
Generate synthetic unmapped BAM and UMI includelist for testing `fgumi correct`.
Reads are generated with varying edit distances from the includelist UMIs
to test the correction algorithm's ability to find the correct UMI.
"#
)]
pub struct CorrectReads {
    /// Output BAM file for the simulated unmapped read pairs.
    #[arg(short = 'o', long = "output", required = true)]
    pub output: PathBuf,

    /// Output text file receiving the generated UMIs, one per line, sorted.
    #[arg(short = 'i', long = "includelist", required = true)]
    pub includelist: PathBuf,

    /// Output TSV with the true UMI, observed UMI, expected correction, edit distance and error
    /// type of every read pair.
    #[arg(long = "truth", required = true)]
    pub truth_output: PathBuf,

    /// Number of read pairs to generate.
    #[arg(short = 'n', long = "num-reads", default_value = "10000")]
    pub num_reads: usize,

    /// Number of unique UMIs to place in the includelist (at most 4^umi-length).
    #[arg(long = "num-umis", default_value = "1000")]
    pub num_umis: usize,

    /// Length of each UMI in bases.
    #[arg(short = 'u', long = "umi-length", default_value = "8")]
    pub umi_length: usize,

    /// Length in bases of the random sequence generated for each of R1 and R2.
    #[arg(short = 'l', long = "read-length", default_value = "100")]
    pub read_length: usize,

    /// Random seed; supply a value to make the generated data reproducible.
    #[arg(long = "seed")]
    pub seed: Option<u64>,

    /// Number of threads used for read generation and for BAM writing.
    #[arg(short = 't', long = "threads", default_value = "1")]
    pub threads: usize,

    #[command(flatten)]
    pub compression: CompressionOptions,

    /// Fraction of read pairs whose observed UMI equals the true UMI.
    #[arg(long = "exact-fraction", default_value = "0.4")]
    pub exact_fraction: f64,

    /// Fraction of read pairs whose observed UMI has one substitution.
    #[arg(long = "edit1-fraction", default_value = "0.3")]
    pub edit1_fraction: f64,

    /// Fraction of read pairs whose observed UMI has two substitutions.
    #[arg(long = "edit2-fraction", default_value = "0.2")]
    pub edit2_fraction: f64,

    /// Fraction of read pairs whose observed UMI has three or more substitutions.
    /// The four fractions must sum to 1.0.
    #[arg(long = "multi-fraction", default_value = "0.1")]
    pub multi_fraction: f64,

    /// Base quality assigned to every base of every generated read.
    #[arg(long = "quality", default_value = "30")]
    pub quality: u8,
}
/// One generated read pair together with its truth-table entry, as sent over the channel from
/// the generation threads to the writer thread.
struct CorrectReadPair {
/// Read name shared by both mates; also written as the first truth-table column.
read_name: String,
/// First-of-pair record.
r1_record: RawRecord,
/// Second-of-pair record.
r2_record: RawRecord,
/// Truth entry in the order
/// `(true_umi, observed_umi, expected_correction, edit_distance, error_type)`.
truth: (String, String, String, usize, &'static str),
}
/// Read-generation settings copied out of the command-line arguments so they can be shared
/// with the generation threads.
#[derive(Clone)]
struct GenerationParams {
/// Length of the random sequence generated for each mate.
read_length: usize,
/// Base quality assigned to every base.
quality: u8,
/// Fraction of reads that keep the true UMI unchanged.
exact_fraction: f64,
/// Fraction of reads that receive one substitution.
edit1_fraction: f64,
/// Fraction of reads that receive two substitutions.
edit2_fraction: f64,
/// UMI length; caps the number of substitutions drawn for "multi" reads.
umi_length: usize,
}
/// Capacity of the bounded channel that carries generated read pairs to the writer thread.
const CHANNEL_CAPACITY: usize = 1_000;
impl Command for CorrectReads {
    /// Generates the UMI includelist, the simulated unmapped BAM and the truth table.
    ///
    /// The includelist is written first. Read pairs are then generated in parallel on a
    /// dedicated rayon pool, each from its own seed drawn from the master RNG, and streamed
    /// over a bounded channel to a single writer thread that writes the BAM and the truth TSV.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - more unique UMIs are requested than exist for the UMI length,
    /// - reads are requested but `num_umis` is zero,
    /// - any fraction is not a number within `[0, 1]`, or the fractions do not sum to 1.0,
    /// - an output file cannot be created or written,
    /// - the thread pool cannot be built, or the writer thread fails or panics.
    fn execute(&self, command_line: &str) -> Result<()> {
        // Saturate rather than truncate if the UMI length does not fit the exponent type.
        let umi_exponent = u32::try_from(self.umi_length).unwrap_or(u32::MAX);
        let max_possible_umis = 4_usize.saturating_pow(umi_exponent);
        if self.num_umis > max_possible_umis {
            anyhow::bail!(
                "Cannot generate {} unique UMIs with {}-base UMIs (maximum possible: 4^{} = {})",
                self.num_umis,
                self.umi_length,
                self.umi_length,
                max_possible_umis
            );
        }
        // Every read picks its true UMI from the includelist, so an empty list cannot be sampled.
        if self.num_umis == 0 && self.num_reads > 0 {
            anyhow::bail!(
                "Cannot generate {} reads from an empty includelist (--num-umis must be at least 1)",
                self.num_reads
            );
        }
        // Range-check each fraction individually: NaN compares false against every threshold and
        // would otherwise slip through the sum check below.
        let fractions = [
            ("exact-fraction", self.exact_fraction),
            ("edit1-fraction", self.edit1_fraction),
            ("edit2-fraction", self.edit2_fraction),
            ("multi-fraction", self.multi_fraction),
        ];
        for (name, value) in fractions {
            if !(0.0..=1.0).contains(&value) {
                anyhow::bail!("--{name} must be between 0.0 and 1.0 (got {value})");
            }
        }
        let total =
            self.exact_fraction + self.edit1_fraction + self.edit2_fraction + self.multi_fraction;
        if (total - 1.0).abs() > 0.001 {
            anyhow::bail!("Error fractions must sum to 1.0 (got {total})");
        }

        info!("Generating correct-reads data");
        info!(" Output: {}", self.output.display());
        info!(" Includelist: {}", self.includelist.display());
        info!(" Truth: {}", self.truth_output.display());
        info!(" Num reads: {}", self.num_reads);
        info!(" Num UMIs: {}", self.num_umis);
        info!(" UMI length: {}", self.umi_length);
        info!(" Threads: {}", self.threads);

        let mut rng = create_rng(self.seed);

        // Draw random UMIs until enough distinct ones have been collected.
        info!("Generating {} unique UMIs", self.num_umis);
        let mut umis: Vec<String> = Vec::with_capacity(self.num_umis);
        let mut seen = std::collections::HashSet::new();
        while umis.len() < self.num_umis {
            let umi = generate_random_sequence(self.umi_length, &mut rng);
            let umi_str = String::from_utf8_lossy(&umi).into_owned();
            // `insert` returns true only for a value not seen before.
            if seen.insert(umi_str.clone()) {
                umis.push(umi_str);
            }
        }
        umis.sort();

        let includelist_file = File::create(&self.includelist)
            .with_context(|| format!("Failed to create {}", self.includelist.display()))?;
        let mut includelist_writer = BufWriter::new(includelist_file);
        for umi in &umis {
            writeln!(includelist_writer, "{umi}")?;
        }
        includelist_writer.flush()?;
        info!("Wrote includelist with {} UMIs", umis.len());
        let umis = Arc::new(umis);

        let mut header_builder = Header::builder();
        header_builder = crate::commands::common::add_pg_to_builder(header_builder, command_line)?;
        let header = header_builder.build();

        let params = Arc::new(GenerationParams {
            read_length: self.read_length,
            quality: self.quality,
            exact_fraction: self.exact_fraction,
            edit1_fraction: self.edit1_fraction,
            edit2_fraction: self.edit2_fraction,
            umi_length: self.umi_length,
        });

        // One seed per read pair, drawn sequentially from the master RNG, so each pair is
        // generated from its own seed regardless of which thread processes it.
        let read_seeds: Vec<u64> = (0..self.num_reads).map(|_| rng.random()).collect();

        // Build the generation pool before starting the writer so that a pool failure does not
        // leave a detached writer thread behind.
        let gen_threads = self.threads.max(1);
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(gen_threads)
            .build()
            .context("Failed to create thread pool")?;

        let (sender, receiver) = bounded::<CorrectReadPair>(CHANNEL_CAPACITY);
        let output_path = self.output.clone();
        let truth_path = self.truth_output.clone();
        let compression_level = self.compression.compression_level;
        let writer_threads = self.threads;
        let header_clone = header.clone();

        let writer_handle = thread::spawn(move || -> Result<(u64, u64, u64, u64, u64)> {
            let mut writer = create_raw_bam_writer(
                &output_path,
                &header_clone,
                writer_threads,
                compression_level,
            )?;
            let truth_file = File::create(&truth_path)
                .with_context(|| format!("Failed to create {}", truth_path.display()))?;
            let mut truth_writer = BufWriter::new(truth_file);
            writeln!(
                truth_writer,
                "read_name\ttrue_umi\tobserved_umi\texpected_correction\tedit_distance\terror_type"
            )?;

            let mut read_count = 0u64;
            let mut exact_count = 0u64;
            let mut edit1_count = 0u64;
            let mut edit2_count = 0u64;
            let mut multi_count = 0u64;
            let progress = ProgressTracker::new("Generated reads").with_interval(100_000);

            // The loop ends once every sender has been dropped and the channel is drained.
            for pair in receiver {
                read_count += 1;
                progress.log_if_needed(1);
                writer.write_raw_record(pair.r1_record.as_ref())?;
                writer.write_raw_record(pair.r2_record.as_ref())?;
                let (true_umi, observed_umi, expected_correction, edit_distance, error_type) =
                    &pair.truth;
                match *error_type {
                    "exact" => exact_count += 1,
                    "edit1" => edit1_count += 1,
                    "edit2" => edit2_count += 1,
                    "multi" => multi_count += 1,
                    _ => {}
                }
                writeln!(
                    truth_writer,
                    "{}\t{true_umi}\t{observed_umi}\t{expected_correction}\t{edit_distance}\t{error_type}",
                    pair.read_name
                )?;
            }
            progress.log_final();
            truth_writer.flush()?;
            writer.finish()?;
            Ok((read_count, exact_count, edit1_count, edit2_count, multi_count))
        });

        let generation_result: Result<(), crossbeam_channel::SendError<CorrectReadPair>> = pool
            .install(|| {
                read_seeds.into_par_iter().enumerate().try_for_each(|(read_idx, seed)| {
                    let pair = generate_correct_read_pair(read_idx, seed, &params, &umis);
                    sender.send(pair)
                })
            });
        // Dropping the last sender closes the channel and lets the writer loop finish.
        drop(sender);

        // Join the writer before looking at the send result: a send can only fail because the
        // writer stopped receiving, so the writer's own error is the real cause to report.
        let writer_result =
            writer_handle.join().map_err(|_| anyhow::anyhow!("Writer thread panicked"))?;
        let (read_count, exact_count, edit1_count, edit2_count, multi_count) = writer_result?;
        if let Err(e) = generation_result {
            return Err(anyhow::anyhow!("Failed to send record to writer: {e}"));
        }

        // Guard the division so that zero reads are reported as 0.0% instead of NaN.
        let percent = |count: u64| -> f64 {
            if read_count == 0 { 0.0 } else { 100.0 * count as f64 / read_count as f64 }
        };
        info!("Generated {read_count} reads:");
        info!(" Exact (0 edits): {} ({:.1}%)", exact_count, percent(exact_count));
        info!(" Edit1 (1 edit): {} ({:.1}%)", edit1_count, percent(edit1_count));
        info!(" Edit2 (2 edits): {} ({:.1}%)", edit2_count, percent(edit2_count));
        info!(" Multi (3+ edits): {} ({:.1}%)", multi_count, percent(multi_count));
        info!("Done");
        Ok(())
    }
}
/// Builds one simulated unmapped read pair and its truth-table entry.
///
/// A true UMI is picked at random from `umis`, then a uniform draw against the cumulative
/// `exact` / `edit1` / `edit2` fractions in `params` selects the error class; anything above
/// the `edit2` threshold is a "multi" read with 3 to `min(umi_length, 5)` substitutions.
/// Both mates get the same read name, independent random sequences of `params.read_length`
/// bases, a constant base quality, and the observed UMI in the `RX` tag.
///
/// The number of substitutions is clamped to the UMI length, so the reported edit distance
/// always matches the substitutions actually applied. The error-type label still names the
/// class that was drawn, so for UMIs shorter than three bases a "multi" read carries fewer
/// than three edits.
///
/// For "multi" reads the expected correction is the observed UMI itself; for every other class
/// it is the true UMI.
///
/// # Panics
///
/// Panics if `umis` is empty, because there is no UMI to sample.
fn generate_correct_read_pair(
    read_idx: usize,
    seed: u64,
    params: &GenerationParams,
    umis: &[String],
) -> CorrectReadPair {
    let mut rng = create_rng(Some(seed));
    let read_name = format!("read_{read_idx:08}");
    let true_umi_idx = rng.random_range(0..umis.len());
    let true_umi = &umis[true_umi_idx];

    let exact_threshold = params.exact_fraction;
    let edit1_threshold = exact_threshold + params.edit1_fraction;
    let edit2_threshold = edit1_threshold + params.edit2_fraction;
    let r: f64 = rng.random();
    let (requested_errors, error_type): (usize, &'static str) = if r < exact_threshold {
        (0, "exact")
    } else if r < edit1_threshold {
        (1, "edit1")
    } else if r < edit2_threshold {
        (2, "edit2")
    } else {
        let max_errors = params.umi_length.min(5);
        // `3..=max_errors` is empty for UMIs shorter than three bases and sampling it would
        // panic; in that case mutate as many bases as the UMI allows.
        let num_errors =
            if max_errors >= 3 { rng.random_range(3..=max_errors) } else { max_errors };
        (num_errors, "multi")
    };

    // A UMI cannot differ from itself at more positions than it has bases.
    let edit_distance = requested_errors.min(true_umi.len());
    let observed_umi = if edit_distance == 0 {
        true_umi.clone()
    } else {
        introduce_n_errors(true_umi, edit_distance, &mut rng)
    };
    let expected_correction =
        if error_type == "multi" { observed_umi.clone() } else { true_umi.clone() };

    let template_r1 = generate_random_sequence(params.read_length, &mut rng);
    let template_r2 = generate_random_sequence(params.read_length, &mut rng);

    let mut r1_builder = SamBuilder::new();
    r1_builder
        .read_name(read_name.as_bytes())
        .flags(
            raw_flags::PAIRED
                | raw_flags::FIRST_SEGMENT
                | raw_flags::UNMAPPED
                | raw_flags::MATE_UNMAPPED,
        )
        .sequence(&template_r1)
        .qualities(&vec![params.quality; template_r1.len()])
        .add_string_tag(b"RX", observed_umi.as_bytes());
    let r1_record = r1_builder.build();

    let mut r2_builder = SamBuilder::new();
    r2_builder
        .read_name(read_name.as_bytes())
        .flags(
            raw_flags::PAIRED
                | raw_flags::LAST_SEGMENT
                | raw_flags::UNMAPPED
                | raw_flags::MATE_UNMAPPED,
        )
        .sequence(&template_r2)
        .qualities(&vec![params.quality; template_r2.len()])
        .add_string_tag(b"RX", observed_umi.as_bytes());
    let r2_record = r2_builder.build();

    CorrectReadPair {
        read_name,
        r1_record,
        r2_record,
        truth: (true_umi.clone(), observed_umi, expected_correction, edit_distance, error_type),
    }
}
/// Returns a copy of `umi` with substitutions at `n` distinct positions.
///
/// Positions are chosen by a partial shuffle of the index list; each chosen base is replaced
/// by a randomly drawn base from `ACGT` that differs from the base it replaces. When `n`
/// exceeds the UMI length every position is substituted, and an empty UMI is returned as is.
fn introduce_n_errors(umi: &str, n: usize, rng: &mut impl Rng) -> String {
    const BASES: &[u8] = b"ACGT";

    let mut mutated = umi.as_bytes().to_vec();
    let len = mutated.len();
    let num_sites = n.min(len);

    // Shuffle only the first `num_sites` slots; they end up holding the positions to mutate.
    let mut order: Vec<usize> = (0..len).collect();
    for slot in 0..num_sites {
        let pick = rng.random_range(slot..len);
        order.swap(slot, pick);
    }

    for &site in &order[..num_sites] {
        let original = mutated[site];
        // Redraw until the candidate differs from the base being replaced.
        let replacement = loop {
            let candidate = BASES[rng.random_range(0..BASES.len())];
            if candidate != original {
                break candidate;
            }
        };
        mutated[site] = replacement;
    }

    String::from_utf8_lossy(&mutated).into_owned()
}
// Unit tests for random sequence generation, UMI error introduction, raw paired-read record
// construction, and the UMI-count validation performed by `execute`.
#[cfg(test)]
// The distribution test counts bases with `iter().filter(..).count()`, which this lint flags.
#[allow(clippy::naive_bytecount)]
mod tests {
use super::*;
use crate::sam::SamTag;
use crate::simulate::create_rng;
use noodles::sam::alignment::record::data::field::Tag;
/// The generated sequence has exactly the requested length, including zero.
#[test]
fn test_generate_random_sequence_length() {
let mut rng = create_rng(Some(42));
for len in [0, 1, 8, 16, 100] {
let seq = generate_random_sequence(len, &mut rng);
assert_eq!(seq.len(), len);
}
}
/// Only the bases A, C, G and T are generated.
#[test]
fn test_generate_random_sequence_valid_bases() {
let mut rng = create_rng(Some(42));
let seq = generate_random_sequence(1000, &mut rng);
for &base in &seq {
assert!(
base == b'A' || base == b'C' || base == b'G' || base == b'T',
"Invalid base: {}",
base as char
);
}
}
/// Two RNGs created from the same seed produce the same sequence.
#[test]
fn test_generate_random_sequence_reproducibility() {
let mut rng1 = create_rng(Some(42));
let mut rng2 = create_rng(Some(42));
let seq1 = generate_random_sequence(100, &mut rng1);
let seq2 = generate_random_sequence(100, &mut rng2);
assert_eq!(seq1, seq2);
}
/// Each base makes up between 20% and 30% of a 10,000-base sequence.
#[test]
fn test_generate_random_sequence_distribution() {
let mut rng = create_rng(Some(42));
let seq = generate_random_sequence(10000, &mut rng);
let a_count = seq.iter().filter(|&&b| b == b'A').count();
let c_count = seq.iter().filter(|&&b| b == b'C').count();
let g_count = seq.iter().filter(|&&b| b == b'G').count();
let t_count = seq.iter().filter(|&&b| b == b'T').count();
for (base, count) in [('A', a_count), ('C', c_count), ('G', g_count), ('T', t_count)] {
let fraction = count as f64 / 10000.0;
assert!(
(0.20..0.30).contains(&fraction),
"Base {} has unexpected frequency: {:.2}%",
base,
fraction * 100.0
);
}
}
/// Requesting zero errors returns the UMI unchanged.
#[test]
fn test_introduce_n_errors_zero() {
let mut rng = create_rng(Some(42));
let umi = "ACGTACGT";
let result = introduce_n_errors(umi, 0, &mut rng);
assert_eq!(result, umi);
}
/// Errors are substitutions only, so the UMI length never changes.
#[test]
fn test_introduce_n_errors_preserves_length() {
let mut rng = create_rng(Some(42));
let umi = "ACGTACGT";
for n in 1..=8 {
let result = introduce_n_errors(umi, n, &mut rng);
assert_eq!(result.len(), umi.len());
}
}
/// Exactly `n` positions differ from the original UMI.
#[test]
fn test_introduce_n_errors_exact_count() {
let mut rng = create_rng(Some(42));
let umi = "AAAAAAAA";
for n in 1..=8 {
let result = introduce_n_errors(umi, n, &mut rng);
let diff_count = umi.chars().zip(result.chars()).filter(|(a, b)| a != b).count();
assert_eq!(
diff_count, n,
"Expected {n} errors but got {diff_count} for UMI {umi} -> {result}"
);
}
}
/// Requesting more errors than bases mutates every position exactly once.
#[test]
fn test_introduce_n_errors_more_than_length() {
let mut rng = create_rng(Some(42));
let umi = "ACGT";
let result = introduce_n_errors(umi, 10, &mut rng);
assert_eq!(result.len(), 4);
let diff_count = umi.chars().zip(result.chars()).filter(|(a, b)| a != b).count();
assert_eq!(diff_count, 4);
}
/// A mutated position never holds the base it started with.
#[test]
fn test_introduce_n_errors_different_base_guarantee() {
let mut rng = create_rng(Some(42));
let umi = "AAAAAAAA";
for _ in 0..100 {
let result = introduce_n_errors(umi, 4, &mut rng);
for (orig, new) in umi.chars().zip(result.chars()) {
if orig != new {
assert_ne!(new, 'A', "Mutated base should not equal original");
}
}
}
}
/// Mutated UMIs contain only A, C, G and T.
#[test]
fn test_introduce_n_errors_valid_bases() {
let mut rng = create_rng(Some(42));
let umi = "ACGTACGT";
for _ in 0..100 {
let result = introduce_n_errors(umi, 4, &mut rng);
for c in result.chars() {
assert!(
c == 'A' || c == 'C' || c == 'G' || c == 'T',
"Invalid base in result: {c}"
);
}
}
}
/// The same seed yields the same mutated UMI.
#[test]
fn test_introduce_n_errors_reproducibility() {
let umi = "ACGTACGT";
let mut rng1 = create_rng(Some(42));
let mut rng2 = create_rng(Some(42));
let result1 = introduce_n_errors(umi, 3, &mut rng1);
let result2 = introduce_n_errors(umi, 3, &mut rng2);
assert_eq!(result1, result2);
}
/// A one-base UMI is replaced by one of the three other bases.
#[test]
fn test_introduce_n_errors_single_base_umi() {
let mut rng = create_rng(Some(42));
let umi = "A";
let result = introduce_n_errors(umi, 1, &mut rng);
assert_eq!(result.len(), 1);
assert_ne!(result, "A");
assert!(result == "C" || result == "G" || result == "T");
}
/// An empty UMI stays empty regardless of the requested error count.
#[test]
fn test_introduce_n_errors_empty_umi() {
let mut rng = create_rng(Some(42));
let umi = "";
let result = introduce_n_errors(umi, 5, &mut rng);
assert_eq!(result, "");
}
/// Error positions are distinct: six requested errors always give six differing positions.
#[test]
fn test_introduce_n_errors_unique_positions() {
let mut rng = create_rng(Some(42));
let umi = "AAAAAAAAAAAA";
for _ in 0..100 {
let result = introduce_n_errors(umi, 6, &mut rng);
let diff_count = umi.chars().zip(result.chars()).filter(|(a, b)| a != b).count();
assert_eq!(diff_count, 6, "Should have exactly 6 different positions");
}
}
/// Converts a raw record into a `RecordBuf` using a default header so the tests can inspect
/// flags, names, tags and qualities.
fn to_record_buf(raw: fgumi_raw_bam::RawRecord) -> noodles::sam::alignment::RecordBuf {
fgumi_raw_bam::raw_record_to_record_buf(&raw, &noodles::sam::Header::default())
.expect("raw_record_to_record_buf failed in test")
}
/// The R1 flag combination decodes as paired, first segment, unmapped, mate unmapped (77).
#[test]
fn test_paired_read_r1_flags() {
use fgumi_raw_bam::{SamBuilder as RawSamBuilder, flags as raw_flags};
let mut b = RawSamBuilder::new();
b.read_name(b"test_read")
.flags(
raw_flags::PAIRED
| raw_flags::FIRST_SEGMENT
| raw_flags::UNMAPPED
| raw_flags::MATE_UNMAPPED,
)
.sequence(b"ACGT");
b.add_string_tag(b"RX", b"AAAAAAAA");
let r1 = to_record_buf(b.build());
assert!(r1.flags().is_unmapped());
assert!(r1.flags().is_segmented());
assert!(r1.flags().is_first_segment());
assert!(r1.flags().is_mate_unmapped());
assert!(!r1.flags().is_last_segment());
let flag_bits: u16 = r1.flags().bits();
assert_eq!(flag_bits, 77, "R1 flag should be 77");
}
/// The R2 flag combination decodes as paired, last segment, unmapped, mate unmapped (141).
#[test]
fn test_paired_read_r2_flags() {
use fgumi_raw_bam::{SamBuilder as RawSamBuilder, flags as raw_flags};
let mut b = RawSamBuilder::new();
b.read_name(b"test_read")
.flags(
raw_flags::PAIRED
| raw_flags::LAST_SEGMENT
| raw_flags::UNMAPPED
| raw_flags::MATE_UNMAPPED,
)
.sequence(b"ACGT");
b.add_string_tag(b"RX", b"AAAAAAAA");
let r2 = to_record_buf(b.build());
assert!(r2.flags().is_unmapped());
assert!(r2.flags().is_segmented());
assert!(r2.flags().is_last_segment());
assert!(r2.flags().is_mate_unmapped());
assert!(!r2.flags().is_first_segment());
let flag_bits: u16 = r2.flags().bits();
assert_eq!(flag_bits, 141, "R2 flag should be 141");
}
/// Two mates built with the same name and UMI decode with equal names and equal RX tags.
#[test]
fn test_paired_reads_same_name_and_umi() {
use fgumi_raw_bam::{SamBuilder as RawSamBuilder, flags as raw_flags};
let read_name = "test_pair";
let umi = "ACGTACGT";
let mut b1 = RawSamBuilder::new();
b1.read_name(read_name.as_bytes())
.flags(
raw_flags::PAIRED
| raw_flags::FIRST_SEGMENT
| raw_flags::UNMAPPED
| raw_flags::MATE_UNMAPPED,
)
.sequence(b"AAAA");
b1.add_string_tag(b"RX", umi.as_bytes());
let r1 = to_record_buf(b1.build());
let mut b2 = RawSamBuilder::new();
b2.read_name(read_name.as_bytes())
.flags(
raw_flags::PAIRED
| raw_flags::LAST_SEGMENT
| raw_flags::UNMAPPED
| raw_flags::MATE_UNMAPPED,
)
.sequence(b"CCCC");
b2.add_string_tag(b"RX", umi.as_bytes());
let r2 = to_record_buf(b2.build());
assert_eq!(r1.name(), r2.name());
let rx_tag: Tag = Tag::from(SamTag::RX);
assert!(r1.data().get(&rx_tag).is_some());
assert!(r2.data().get(&rx_tag).is_some());
assert_eq!(r1.data().get(&rx_tag), r2.data().get(&rx_tag));
}
/// A constant quality vector survives the round trip through the raw record.
#[test]
fn test_paired_read_qualities() {
use fgumi_raw_bam::{SamBuilder as RawSamBuilder, flags as raw_flags};
let quality: u8 = 35;
let seq = b"ACGTACGT";
let mut b = RawSamBuilder::new();
b.read_name(b"read")
.flags(
raw_flags::PAIRED
| raw_flags::FIRST_SEGMENT
| raw_flags::UNMAPPED
| raw_flags::MATE_UNMAPPED,
)
.sequence(seq)
.qualities(&vec![quality; seq.len()]);
let r1 = to_record_buf(b.build());
let quals: Vec<u8> = r1.quality_scores().iter().collect();
assert_eq!(quals.len(), 8);
assert!(quals.iter().all(|&q| q == quality));
}
/// `execute` rejects a request for 1000 unique UMIs of length 4 (only 256 exist) and the error
/// message names the requested count, the UMI length and the maximum.
#[test]
fn test_num_umis_exceeds_possible_umis_fails() {
use crate::commands::command::Command;
use tempfile::tempdir;
let dir = tempdir().expect("failed to create temp dir");
let output = dir.path().join("output.bam");
let includelist = dir.path().join("includelist.txt");
let truth = dir.path().join("truth.tsv");
let cmd = CorrectReads {
output,
includelist,
truth_output: truth,
num_reads: 100,
num_umis: 1000,
umi_length: 4,
read_length: 100,
seed: Some(42),
threads: 1,
compression: CompressionOptions::default(),
exact_fraction: 0.4,
edit1_fraction: 0.3,
edit2_fraction: 0.2,
multi_fraction: 0.1,
quality: 30,
};
let result = cmd.execute("test");
assert!(result.is_err());
let err_msg = result.unwrap_err().to_string();
assert!(
err_msg.contains("Cannot generate 1000 unique UMIs"),
"Error message should mention UMI count: {err_msg}"
);
assert!(
err_msg.contains("4-base UMIs"),
"Error message should mention UMI length: {err_msg}"
);
assert!(err_msg.contains("256"), "Error message should mention max possible: {err_msg}");
}
/// Checks the boundary arithmetic: 16 UMIs do not exceed the 4^2 = 16 possible 2-base UMIs.
// NOTE(review): despite its name this test never calls `execute`; it only re-computes the
// limit locally, so the command's behavior at the boundary is not exercised here.
#[test]
fn test_num_umis_at_max_succeeds() {
let max_possible = 4_usize.pow(2);
assert_eq!(max_possible, 16);
let num_umis = 16;
let umi_length = 2;
let max_possible_umis = 4_usize.saturating_pow(umi_length as u32);
assert!(num_umis <= max_possible_umis, "16 UMIs should be allowed for 2-base UMIs");
}
}