use super::sort::TemplateCoordKey;
use crate::bam_io::create_bam_writer;
use crate::commands::command::Command;
use crate::commands::common::{CompressionOptions, parse_bool};
use crate::commands::simulate::common::{
FamilySizeArgs, InsertSizeArgs, MoleculeInfo, PositionDistArgs, QualityArgs, ReferenceArgs,
SimulationCommon, StrandBiasArgs, compute_position, generate_random_sequence, pad_sequence,
};
use crate::dna::reverse_complement;
use crate::progress::ProgressTracker;
use crate::sam::builder::RecordBuilder;
use crate::simulate::{
FamilySizeDistribution, InsertSizeModel, PositionQualityModel, ReadPairQualityBias,
StrandBiasModel, create_rng,
};
use anyhow::{Context, Result};
use bstr::BString;
use clap::Parser;
use log::info;
use noodles::sam::alignment::io::Write as AlignmentWrite;
use noodles::sam::alignment::record_buf::RecordBuf;
use noodles::sam::header::Header;
use noodles::sam::header::record::value::Map;
use noodles::sam::header::record::value::map::ReferenceSequence;
use noodles::sam::header::record::value::map::header::{self as HeaderRecord, Tag as HeaderTag};
use rand::{Rng, RngExt};
use std::fs::File;
use std::io::{BufWriter, Write};
use std::num::NonZeroUsize;
use std::path::PathBuf;
// CLI definition of the `grouped-reads` simulation subcommand.
//
// Plain `//` comments are used here and on the flattened option groups so that the explicit
// `about` / `long_about` strings below stay the only command-level help text. The `///`
// comments on the individual arguments are picked up by clap as their `--help` descriptions.
#[derive(Parser, Debug)]
#[command(
    name = "grouped-reads",
    about = "Generate grouped BAM with MI tags for consensus calling",
    long_about = r#"
Generate synthetic BAM files with MI (molecule ID) tags already assigned.
The output is template-coordinate sorted and suitable for input to
`fgumi simplex`, `fgumi duplex`, or `fgumi codec`.
"#
)]
pub struct GroupedReads {
    /// Output BAM file for the simulated, MI-tagged read pairs
    #[arg(short = 'o', long = "output", required = true)]
    pub output: PathBuf,

    /// Output tab-separated truth file with one row per generated read pair
    #[arg(long = "truth", required = true)]
    pub truth_output: PathBuf,

    /// Generate duplex data: reads are split into A and B strands and tagged `<id>/A` or `<id>/B`
    #[arg(long = "duplex", default_value = "false", num_args = 0..=1, default_missing_value = "true", action = clap::ArgAction::Set, value_parser = parse_bool)]
    pub duplex: bool,

    /// Mapping quality assigned to every read and reported as the mate mapping quality (MQ tag)
    #[arg(long = "mapq", default_value = "60")]
    pub mapq: u8,

    /// Number of threads handed to the BAM writer
    #[arg(short = 't', long = "threads", default_value = "1")]
    pub threads: usize,

    // Compression level used for the output BAM.
    #[command(flatten)]
    pub compression: CompressionOptions,

    // Shared simulation options: molecule count, read length, UMI length and RNG seed.
    #[command(flatten)]
    pub common: SimulationCommon,

    // Options that build the per-position quality model and the read-pair quality bias.
    #[command(flatten)]
    pub quality: QualityArgs,

    // Options that build the family-size distribution, including the minimum family size.
    #[command(flatten)]
    pub family_size: FamilySizeArgs,

    // Options that build the insert-size model.
    #[command(flatten)]
    pub insert_size: InsertSizeArgs,

    // Name and length of the single simulated reference sequence.
    #[command(flatten)]
    pub reference: ReferenceArgs,

    // Optional number of distinct positions; defaults to the number of molecules.
    #[command(flatten)]
    pub position_dist: PositionDistArgs,

    // Options that build the model splitting a duplex family between the A and B strands.
    #[command(flatten)]
    pub strand_bias: StrandBiasArgs,
}
/// Read-only settings shared by every molecule during read generation.
///
/// Built once in `GroupedReads::execute` from the parsed command-line options and passed by
/// reference to `generate_molecule_reads`.
struct GenerationParams {
    // --- Placement on the reference ---
    /// Length of the simulated reference sequence, forwarded to `compute_position`.
    ref_length: usize,
    /// Number of distinct positions, forwarded to `compute_position`.
    num_positions: usize,

    // --- Shape of each read ---
    /// Length every read sequence is padded or truncated to.
    read_length: usize,
    /// Number of bases in the random UMI stored in the RX tag.
    umi_length: usize,
    /// Mapping quality written on each record and in its MQ tag.
    mapq: u8,

    // --- Family layout ---
    /// When true, each family is split into A-strand and B-strand read pairs.
    duplex: bool,
    /// Minimum family size handed to `family_dist.sample`.
    min_family_size: usize,

    // --- Sampling models ---
    /// Distribution the family size is sampled from.
    family_dist: FamilySizeDistribution,
    /// Model the insert size is sampled from.
    insert_model: InsertSizeModel,
    /// Model that splits a duplex family into A and B read counts.
    strand_bias_model: StrandBiasModel,
    /// Model that generates base qualities for a read.
    quality_model: PositionQualityModel,
    /// Bias applied to the R2 qualities after they are generated.
    quality_bias: ReadPairQualityBias,
}
impl Command for GroupedReads {
    /// Writes the simulated grouped BAM and the matching truth file.
    ///
    /// The run has four stages:
    /// 1. check that the requested number of positions fits on the reference,
    /// 2. build the SAM header (`SO:unsorted`, `GO:query`, `SS:template-coordinate`, one
    ///    reference sequence) and pass it through `add_pg_to_builder` with `command_line`,
    /// 3. draw one seed per molecule, derive its template-coordinate sort key and sort the
    ///    molecules by that key,
    /// 4. regenerate every molecule from its seed in sorted order, writing its records to the
    ///    BAM and one row per read pair to the truth file.
    ///
    /// # Errors
    ///
    /// Returns an error when the positions cannot fit on the reference (less than one base per
    /// position), when the reference length is zero, when the family-size distribution cannot
    /// be built, or when creating or writing either output file fails.
    fn execute(&self, command_line: &str) -> Result<()> {
        // Bases subtracted from the reference length before positions are spaced out.
        const RESERVED_BASES: usize = 1000;

        info!("Generating grouped reads");
        info!(" Output: {}", self.output.display());
        info!(" Truth: {}", self.truth_output.display());
        info!(" Duplex: {}", self.duplex);
        info!(" Num molecules: {}", self.common.num_molecules);
        info!(" Read length: {}", self.common.read_length);
        info!(" UMI length: {}", self.common.umi_length);
        info!(" Threads: {}", self.threads);

        // --- Stage 1: position spacing check ---
        let num_positions = self.position_dist.num_positions.unwrap_or(self.common.num_molecules);
        let ref_length = self.reference.ref_length;
        let usable_bases = ref_length.saturating_sub(RESERVED_BASES);
        let bases_per_position = usable_bases as f64 / num_positions as f64;
        if bases_per_position < 1.0 {
            // Include the reserved bases in the suggestion; without them the suggested length
            // could itself fail this check for small position counts.
            let suggested_ref_length = num_positions.saturating_mul(2).saturating_add(RESERVED_BASES);
            anyhow::bail!(
                "Position collision: {num_positions} positions cannot fit in {ref_length} bp reference ({bases_per_position:.2} bp/position). \
                 Increase --ref-length to at least {suggested_ref_length} or reduce --num-molecules."
            );
        } else if bases_per_position < 10.0 {
            log::warn!(
                "Low position spacing ({bases_per_position:.1} bp/position) may cause UMI collisions. \
                 Consider increasing --ref-length."
            );
        }

        // --- Stage 2: header ---
        let ref_name = self.reference.ref_name.clone();
        let mut header = Header::builder();
        // Two-byte tags that are not predefined variants always convert to `Tag::Other`.
        let HeaderTag::Other(so_tag) = HeaderTag::from([b'S', b'O']) else { unreachable!() };
        let HeaderTag::Other(go_tag) = HeaderTag::from([b'G', b'O']) else { unreachable!() };
        let HeaderTag::Other(ss_tag) = HeaderTag::from([b'S', b'S']) else { unreachable!() };
        let header_map = Map::<HeaderRecord::Header>::builder()
            .insert(so_tag, "unsorted")
            .insert(go_tag, "query")
            .insert(ss_tag, "template-coordinate")
            .build()
            .expect("header map with valid SO/GO/SS tags");
        header = header.set_header(header_map);

        // The reference length comes from the command line, so report a zero length as an
        // error instead of panicking.
        let length =
            NonZeroUsize::try_from(ref_length).context("Reference length must be > 0")?;
        let ref_seq: Map<ReferenceSequence> = Map::<ReferenceSequence>::new(length);
        header = header.add_reference_sequence(BString::from(&*ref_name), ref_seq);
        header = crate::commands::common::add_pg_to_builder(header, command_line)?;
        let header = header.build();

        let params = GenerationParams {
            read_length: self.common.read_length,
            umi_length: self.common.umi_length,
            mapq: self.mapq,
            duplex: self.duplex,
            min_family_size: self.family_size.min_family_size,
            num_positions,
            ref_length,
            quality_model: self.quality.to_quality_model(),
            quality_bias: self.quality.to_quality_bias(),
            family_dist: self.family_size.to_family_size_distribution()?,
            insert_model: self.insert_size.to_insert_size_model(),
            strand_bias_model: self.strand_bias.to_strand_bias_model(),
        };

        // --- Stage 3: per-molecule seeds and sort keys ---
        let mut seed_rng = create_rng(self.common.seed);
        info!("Computing molecule positions and sort keys...");
        let mut molecules: Vec<MoleculeInfo> = (0..self.common.num_molecules)
            .map(|mol_id| {
                let seed: u64 = seed_rng.random();
                let pos1 = compute_position(mol_id, num_positions, ref_length);

                // Replay the draws `generate_molecule_reads` makes before it samples the insert
                // size (UMI bases, then family size) so the insert size sampled here matches
                // the one used when the reads are generated from the same seed.
                // NOTE(review): assumes `generate_random_sequence` consumes exactly one
                // `random_range(0..4)` draw per base; confirm against its implementation.
                let mut mol_rng = create_rng(Some(seed));
                for _ in 0..params.umi_length {
                    let _: usize = mol_rng.random_range(0..4);
                }
                let _ = params.family_dist.sample(&mut mol_rng, params.min_family_size);
                let insert_size = params.insert_model.sample(&mut mol_rng);

                let sort_key = TemplateCoordKey::for_f1r2_pair(
                    0,
                    pos1,
                    insert_size,
                    mol_id.to_string(),
                    format!("mol{mol_id:08}"),
                );
                MoleculeInfo { mol_id, seed, sort_key }
            })
            .collect();

        info!("Sorting {} molecules by template-coordinate...", molecules.len());
        molecules.sort_unstable();

        // --- Stage 4: generate and write ---
        let mut writer = create_bam_writer(
            &self.output,
            &header,
            self.threads,
            self.compression.compression_level,
        )?;
        let truth_file = File::create(&self.truth_output)
            .with_context(|| format!("Failed to create {}", self.truth_output.display()))?;
        let mut truth_writer = BufWriter::new(truth_file);
        writeln!(
            truth_writer,
            "read_name\ttrue_umi\tmolecule_id\tmi_tag\tchrom\tposition\tstrand"
        )?;

        info!("Generating and writing records in sorted order...");
        let progress = ProgressTracker::new("Processed molecules").with_interval(100_000);
        let mut total_pairs = 0usize;
        for mol_info in molecules {
            progress.log_if_needed(1);
            let pairs = generate_molecule_reads(mol_info.mol_id, mol_info.seed, &params);
            for (r1, r2, read_name, umi_str, mi_tag, strand, insert_size) in pairs {
                // B-strand pairs whose insert is shorter than two read lengths are written
                // R2 first.
                // NOTE(review): presumably this matches the record order expected within a
                // template-coordinate sorted template; confirm against the sort module.
                let r2_first = strand == 'B' && insert_size < 2 * params.read_length;
                if r2_first {
                    writer.write_alignment_record(&header, &r2)?;
                    writer.write_alignment_record(&header, &r1)?;
                } else {
                    writer.write_alignment_record(&header, &r1)?;
                    writer.write_alignment_record(&header, &r2)?;
                }
                writeln!(
                    truth_writer,
                    "{}\t{}\t{}\t{}\t{}\t{}\t{}",
                    read_name,
                    umi_str,
                    mol_info.mol_id,
                    mi_tag,
                    ref_name,
                    mol_info.sort_key.pos1,
                    strand
                )?;
                total_pairs += 1;
            }
        }
        progress.log_final();
        // Flush explicitly: dropping a `BufWriter` discards any flush error.
        truth_writer.flush()?;
        info!("Generated {total_pairs} read pairs");
        info!("Done");
        Ok(())
    }
}
/// Generates every read pair belonging to one molecule.
///
/// All randomness comes from an RNG seeded with `seed`, so the same `(mol_id, seed, params)`
/// always yields the same pairs. The draw order is: UMI bases, family size, insert size, the
/// duplex strand split (duplex mode only), then the per-pair sequence and quality draws.
/// `GroupedReads::execute` replays the first draws to pre-compute the sort key, so that order
/// must not change.
///
/// In duplex mode the family is split into A-strand pairs (MI tag `<mol_id>/A`, read names
/// `mol<id>_readA<idx>`) followed by B-strand pairs (`<mol_id>/B`, `mol<id>_readB<idx>`).
/// Otherwise every pair gets MI tag `<mol_id>`, strand `'+'` and name `mol<id>_read<idx>`.
///
/// Returns one tuple per pair:
/// `(r1, r2, read_name, umi, mi_tag, strand, insert_size)`.
fn generate_molecule_reads(
    mol_id: usize,
    seed: u64,
    params: &GenerationParams,
) -> Vec<(RecordBuf, RecordBuf, String, String, String, char, usize)> {
    let mut rng = create_rng(Some(seed));
    let position = compute_position(mol_id, params.num_positions, params.ref_length);

    let umi = generate_random_sequence(params.umi_length, &mut rng);
    let umi_str = String::from_utf8_lossy(&umi).to_string();
    let family_size = params.family_dist.sample(&mut rng, params.min_family_size);
    let insert_size = params.insert_model.sample(&mut rng);

    // The strand split draws from the RNG before any read is generated, so it has to happen
    // here, ahead of the closure that borrows the RNG for the rest of the function.
    let strand_counts = if params.duplex {
        Some(params.strand_bias_model.split_reads(family_size, &mut rng))
    } else {
        None
    };

    let mut pairs = Vec::new();

    // Builds one read pair and records it together with its truth metadata.
    let mut emit_pair = |read_name: String, mi_tag: &str, strand: char| {
        let (r1_record, r2_record) = generate_read_pair_records(
            &read_name,
            &umi_str,
            mi_tag,
            position,
            insert_size,
            params.read_length,
            params.mapq,
            strand,
            &params.quality_model,
            &params.quality_bias,
            &mut rng,
        );
        pairs.push((
            r1_record,
            r2_record,
            read_name,
            umi_str.clone(),
            mi_tag.to_owned(),
            strand,
            insert_size,
        ));
    };

    match strand_counts {
        Some((a_count, b_count)) => {
            // All A-strand pairs are generated before any B-strand pair.
            let mi_tag_a = format!("{mol_id}/A");
            for read_idx in 0..a_count {
                emit_pair(format!("mol{mol_id:08}_readA{read_idx:04}"), &mi_tag_a, 'A');
            }
            let mi_tag_b = format!("{mol_id}/B");
            for read_idx in 0..b_count {
                emit_pair(format!("mol{mol_id:08}_readB{read_idx:04}"), &mi_tag_b, 'B');
            }
        }
        None => {
            let mi_tag = mol_id.to_string();
            for read_idx in 0..family_size {
                emit_pair(format!("mol{mol_id:08}_read{read_idx:04}"), &mi_tag, '+');
            }
        }
    }

    pairs
}
/// Builds the R1 and R2 records of one read pair.
///
/// A random template of `insert_size` bases is drawn. R1 is its first `read_length` bases and
/// R2 is the reverse complement of its last `read_length` bases; both are passed through
/// `pad_sequence` with `read_length`. Qualities are generated per read and the quality bias
/// is applied to R2 only.
///
/// R1 is placed at `position` and R2 at `position + insert_size - read_length`, clamped at
/// zero. For strand `'B'` R1 is flagged reverse and R2 forward; for any other strand R1 is
/// forward and R2 reverse. The template length is `+insert_size` on R1 and `-insert_size` on
/// R2. `position` is 0-based; `build_record` converts it to a 1-based start.
///
/// The RNG draw order (template, R1 padding, R2 padding, R1 qualities, R2 qualities) is kept
/// fixed so that output stays reproducible for a given seed.
// The argument list mirrors the fields of a SAM record pair; bundling them into a struct
// would only move the same list to the call sites.
#[allow(clippy::too_many_arguments)]
fn generate_read_pair_records(
    read_name: &str,
    umi_str: &str,
    mi_tag: &str,
    position: usize,
    insert_size: usize,
    read_length: usize,
    mapq: u8,
    strand: char,
    quality_model: &PositionQualityModel,
    quality_bias: &ReadPairQualityBias,
    rng: &mut impl Rng,
) -> (RecordBuf, RecordBuf) {
    let template = generate_random_sequence(insert_size, rng);

    let r1_seq: Vec<u8> = template.iter().take(read_length).copied().collect();
    let r1_seq = pad_sequence(r1_seq, read_length, rng);

    // R2 covers the last `read_length` bases of the template, or the whole template when the
    // insert is shorter than a read.
    let r2_start = insert_size.saturating_sub(read_length);
    let r2_template: Vec<u8> = template.iter().skip(r2_start).take(read_length).copied().collect();
    let r2_seq = reverse_complement(&r2_template);
    let r2_seq = pad_sequence(r2_seq, read_length, rng);

    let r1_quals = quality_model.generate_qualities(read_length, rng);
    let r2_quals_raw = quality_model.generate_qualities(read_length, rng);
    let r2_quals = quality_bias.apply_to_vec(&r2_quals_raw, true);

    // Both reads are emitted as full-length matches, so each mate's CIGAR is the same.
    let mate_cigar = format!("{read_length}M");

    let r1_is_reverse = strand == 'B';
    let r2_is_reverse = !r1_is_reverse;

    // With an insert shorter than a read and a position near the reference start the plain
    // subtraction would underflow `usize`; clamp to the first base instead.
    let r2_position = (position + insert_size).saturating_sub(read_length);

    // Saturate rather than wrap if the insert size does not fit the signed TLEN field.
    let tlen = i32::try_from(insert_size).unwrap_or(i32::MAX);

    let r1_record = build_record(
        read_name,
        &r1_seq,
        &r1_quals,
        0,
        position,
        mapq,
        true,
        r1_is_reverse,
        r2_position,
        tlen,
        umi_str,
        mi_tag,
        &mate_cigar,
        mapq,
    );
    let r2_record = build_record(
        read_name,
        &r2_seq,
        &r2_quals,
        0,
        r2_position,
        mapq,
        false,
        r2_is_reverse,
        position,
        -tlen,
        umi_str,
        mi_tag,
        &mate_cigar,
        mapq,
    );
    (r1_record, r2_record)
}
/// Assembles a single paired-end alignment record.
///
/// `pos` and `mate_pos` are 0-based and are written as 1-based alignment starts. The mate is
/// always flagged with the opposite strand of this read, and both reads share `ref_id`.
/// `is_first` selects the first-segment flag. The record carries four tags: `RX` (UMI),
/// `MI` (molecule id), `MC` (mate CIGAR) and `MQ` (mate mapping quality).
// The parameters map one-to-one onto SAM record fields, hence the long argument list.
#[allow(clippy::too_many_arguments)]
fn build_record(
    name: &str,
    seq: &[u8],
    quals: &[u8],
    ref_id: usize,
    pos: usize,
    mapq: u8,
    is_first: bool,
    is_reverse: bool,
    mate_pos: usize,
    tlen: i32,
    umi: &str,
    mi_tag: &str,
    mate_cigar: &str,
    mate_mapq: u8,
) -> RecordBuf {
    // Invalid UTF-8 bytes, if any, are replaced rather than rejected.
    let bases = String::from_utf8_lossy(seq);

    // SAM coordinates are 1-based.
    let start_one_based = pos + 1;
    let mate_start_one_based = mate_pos + 1;

    // The two reads of a pair are always on opposite strands.
    let mate_is_reverse = !is_reverse;
    let mate_mapq_value = i32::from(mate_mapq);

    RecordBuilder::new()
        .name(name)
        .sequence(&bases)
        .qualities(quals)
        .reference_sequence_id(ref_id)
        .alignment_start(start_one_based)
        .mapping_quality(mapq)
        .paired(true)
        .first_segment(is_first)
        .reverse_complement(is_reverse)
        .mate_reverse_complement(mate_is_reverse)
        .mate_reference_sequence_id(ref_id)
        .mate_alignment_start(mate_start_one_based)
        .template_length(tlen)
        .tag("RX", umi)
        .tag("MI", mi_tag)
        .tag("MC", mate_cigar)
        .tag("MQ", mate_mapq_value)
        .build()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::simulate::create_rng;

    // ---------------------------------------------------------------------------------------
    // Shared fixtures
    // ---------------------------------------------------------------------------------------

    /// Returns true for the four unambiguous DNA bases.
    fn is_dna_base(base: u8) -> bool {
        matches!(base, b'A' | b'C' | b'G' | b'T')
    }

    /// Generates one read pair from a fresh RNG seeded with 42, using default quality models,
    /// position 1000, read length 150 and MAPQ 60.
    fn seeded_pair(mi_tag: &str, insert_size: usize, strand: char) -> (RecordBuf, RecordBuf) {
        let mut rng = create_rng(Some(42));
        let quality_model = PositionQualityModel::default();
        let quality_bias = ReadPairQualityBias::default();
        generate_read_pair_records(
            "read_001",
            "ACGTACGT",
            mi_tag,
            1000,
            insert_size,
            150,
            60,
            strand,
            &quality_model,
            &quality_bias,
            &mut rng,
        )
    }

    /// Asserts segment flags and that R1/R2 sit on opposite strands with R1 as given.
    fn assert_orientation(r1: &RecordBuf, r2: &RecordBuf, r1_is_reverse: bool) {
        assert!(r1.flags().is_first_segment());
        assert_eq!(r1.flags().is_reverse_complemented(), r1_is_reverse);
        assert!(r2.flags().is_last_segment());
        assert_eq!(r2.flags().is_reverse_complemented(), !r1_is_reverse);
    }

    /// Builds a forward first-segment record over `ACGT` with the given MI tag.
    fn forward_r1_record(mi_tag: &str) -> RecordBuf {
        let quals = vec![30, 30, 30, 30];
        build_record(
            "test_read", b"ACGT", &quals, 0, 100, 60, true, false, 200, 150, "AAAAAAAA", mi_tag,
            "4M", 60,
        )
    }

    /// Test-local copy of the collision check performed at the start of `execute`.
    fn check_collision(num_positions: usize, ref_length: usize) -> Result<(), String> {
        let usable_bases = ref_length.saturating_sub(1000);
        let bases_per_position = usable_bases as f64 / num_positions as f64;
        if bases_per_position < 1.0 {
            Err(format!(
                "Position collision: {num_positions} positions cannot fit in {ref_length} bp reference ({bases_per_position:.2} bp/position)"
            ))
        } else {
            Ok(())
        }
    }

    // ---------------------------------------------------------------------------------------
    // Sequence helpers
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_generate_random_sequence_length() {
        let mut rng = create_rng(Some(42));
        for len in [0, 1, 8, 100, 300] {
            let seq = generate_random_sequence(len, &mut rng);
            assert_eq!(seq.len(), len);
        }
    }

    #[test]
    fn test_generate_random_sequence_valid_bases() {
        let mut rng = create_rng(Some(42));
        let seq = generate_random_sequence(1000, &mut rng);
        for &base in &seq {
            assert!(is_dna_base(base), "Invalid base: {}", base as char);
        }
    }

    #[test]
    fn test_reverse_complement_basic() {
        assert_eq!(reverse_complement(b"A"), b"T");
        assert_eq!(reverse_complement(b"T"), b"A");
        assert_eq!(reverse_complement(b"C"), b"G");
        assert_eq!(reverse_complement(b"G"), b"C");
    }

    #[test]
    fn test_reverse_complement_sequence() {
        // ACGT is its own reverse complement.
        assert_eq!(reverse_complement(b"ACGT"), b"ACGT");
        assert_eq!(reverse_complement(b"AAAA"), b"TTTT");
        assert_eq!(reverse_complement(b"AACG"), b"CGTT");
    }

    #[test]
    fn test_reverse_complement_empty() {
        let empty: &[u8] = b"";
        assert_eq!(reverse_complement(empty), Vec::<u8>::new());
    }

    #[test]
    fn test_reverse_complement_unknown_base() {
        assert_eq!(reverse_complement(b"N"), b"N");
        assert_eq!(reverse_complement(b"ANCG"), b"CGNT");
    }

    #[test]
    fn test_pad_sequence_already_correct_length() {
        let mut rng = create_rng(Some(42));
        let seq = vec![b'A', b'C', b'G', b'T'];
        let padded = pad_sequence(seq.clone(), 4, &mut rng);
        assert_eq!(padded, seq);
    }

    #[test]
    fn test_pad_sequence_needs_padding() {
        let mut rng = create_rng(Some(42));
        let seq = vec![b'A', b'C'];
        let padded = pad_sequence(seq, 6, &mut rng);
        assert_eq!(padded.len(), 6);
        // The original bases are kept as the prefix.
        assert_eq!(&padded[0..2], b"AC");
    }

    #[test]
    fn test_pad_sequence_truncates() {
        let mut rng = create_rng(Some(42));
        let seq = vec![b'A', b'C', b'G', b'T', b'A', b'C'];
        let padded = pad_sequence(seq, 3, &mut rng);
        assert_eq!(padded, vec![b'A', b'C', b'G']);
    }

    #[test]
    fn test_pad_sequence_empty_to_length() {
        let mut rng = create_rng(Some(42));
        let seq: Vec<u8> = vec![];
        let padded = pad_sequence(seq, 5, &mut rng);
        assert_eq!(padded.len(), 5);
        for &base in &padded {
            assert!(is_dna_base(base));
        }
    }

    // ---------------------------------------------------------------------------------------
    // build_record
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_build_record_with_mi_tag() {
        let record = forward_r1_record("42/A");
        assert!(record.name().is_some());
        let flags = record.flags();
        assert!(flags.is_segmented());
        assert!(flags.is_first_segment());
    }

    #[test]
    fn test_build_record_simplex_mi_tag() {
        let record = forward_r1_record("42");
        assert!(record.name().is_some());
    }

    #[test]
    fn test_build_record_r2_flags() {
        let seq = b"ACGT";
        let quals = vec![30, 30, 30, 30];
        let record = build_record(
            "test_read", seq, &quals, 0, 200, 60, false, true, 100, -150, "AAAAAAAA", "42/B",
            "4M", 60,
        );
        let flags = record.flags();
        assert!(flags.is_last_segment());
        assert!(!flags.is_first_segment());
        assert!(flags.is_reverse_complemented());
        assert!(!flags.is_mate_reverse_complemented());
    }

    #[test]
    fn test_build_record_positions() {
        let seq = b"ACGTACGT";
        let quals = vec![30; 8];
        let record = build_record(
            "read", seq, &quals, 0, 1000, 60, true, false, 1100, 200, "AAA", "0/A", "8M", 60,
        );
        // 0-based inputs are stored as 1-based starts.
        assert_eq!(
            record.alignment_start().expect("record should have alignment start").get(),
            1001
        );
        assert_eq!(
            record.mate_alignment_start().expect("record should have mate alignment start").get(),
            1101
        );
        assert_eq!(record.template_length(), 200);
    }

    #[test]
    fn test_build_record_mapping_quality() {
        let seq = b"ACGT";
        let quals = vec![30; 4];
        for mapq in [0, 30, 60] {
            let record = build_record(
                "read", seq, &quals, 0, 100, mapq, true, false, 200, 100, "AAA", "0", "4M", mapq,
            );
            assert_eq!(
                record.mapping_quality().expect("record should have mapping quality").get(),
                mapq
            );
        }
    }

    #[test]
    fn test_build_record_quality_scores() {
        let seq = b"ACGTACGT";
        let quals = vec![10, 20, 30, 40, 30, 20, 10, 5];
        let record = build_record(
            "read", seq, &quals, 0, 100, 60, true, false, 200, 100, "AAA", "0", "8M", 60,
        );
        let record_quals: Vec<u8> = record.quality_scores().iter().collect();
        assert_eq!(record_quals, quals);
    }

    // ---------------------------------------------------------------------------------------
    // generate_read_pair_records
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_generate_read_pair_records() {
        let (r1, r2) = seeded_pair("0/A", 300, 'A');
        assert_orientation(&r1, &r2, false);
    }

    #[test]
    fn test_generate_read_pair_records_b_strand() {
        let (r1, r2) = seeded_pair("5/B", 300, 'B');
        assert_orientation(&r1, &r2, true);
    }

    #[test]
    fn test_generate_read_pair_records_simplex() {
        let (r1, r2) = seeded_pair("10", 300, '+');
        assert_orientation(&r1, &r2, false);
    }

    #[test]
    fn test_generate_read_pair_records_small_insert() {
        // Insert (50) shorter than the read length (150): generation must still succeed.
        let (r1, r2) = seeded_pair("0/A", 50, 'A');
        assert!(r1.flags().is_first_segment());
        assert!(r2.flags().is_last_segment());
    }

    #[test]
    fn test_generate_read_pair_records_reproducibility() {
        let (r1_a, r2_a) = seeded_pair("0/A", 300, 'A');
        let (r1_b, r2_b) = seeded_pair("0/A", 300, 'A');
        assert_eq!(r1_a.flags(), r1_b.flags());
        assert_eq!(r2_a.flags(), r2_b.flags());
        // Flags do not depend on the RNG; comparing whole records also covers the randomly
        // generated bases and qualities.
        assert_eq!(r1_a, r1_b);
        assert_eq!(r2_a, r2_b);
    }

    #[test]
    fn test_a_strand_orientation() {
        let (r1, r2) = seeded_pair("1/A", 300, 'A');
        assert_orientation(&r1, &r2, false);
    }

    // ---------------------------------------------------------------------------------------
    // Position collision check
    // ---------------------------------------------------------------------------------------

    #[test]
    fn test_collision_detection_error() {
        // (10_000_000 - 1000) / 5_000_000 is just under 2 bp per position.
        let result = check_collision(5_000_000, 10_000_000);
        assert!(result.is_ok(), "Should pass with ~2.0 bp/position");
        let result = check_collision(10_000_000, 10_000_000);
        assert!(result.is_err(), "Should fail with <1 bp/position");
        let result = check_collision(20_000_000, 10_000_000);
        assert!(result.is_err());
    }

    #[test]
    fn test_collision_detection_success() {
        let result = check_collision(5_000_000, 250_000_000);
        assert!(result.is_ok());
        let result = check_collision(1_000_000, 250_000_000);
        assert!(result.is_ok());
    }
}