use super::sort::TemplateCoordKey;
use crate::commands::common::MethylationModeArg;
use anyhow::{Context, Result, bail};
use clap::Args;
use fgumi_consensus::MethylationMode;
use fgumi_consensus::methylation::is_cpg_context;
use log::info;
use noodles::fasta;
use rand::{Rng, RngExt};
use std::fs::File;
use std::io::BufReader;
use std::path::PathBuf;
// Command-line options grouping the RNG seed, molecule count, read length and UMI length.
// Plain `//` comments are used deliberately: `///` on clap fields would become `--help` text.
#[derive(Args, Debug, Clone)]
pub struct SimulationCommon {
// `--seed`: optional 64-bit seed; `None` when the flag is not given.
#[arg(long = "seed")]
pub seed: Option<u64>,
// `-n` / `--num-molecules`: defaults to 1000.
#[arg(short = 'n', long = "num-molecules", default_value = "1000")]
pub num_molecules: usize,
// `-l` / `--read-length`: defaults to 150.
#[arg(short = 'l', long = "read-length", default_value = "150")]
pub read_length: usize,
// `-u` / `--umi-length`: defaults to 8.
#[arg(short = 'u', long = "umi-length", default_value = "8")]
pub umi_length: usize,
}
// Command-line options that parameterise the per-position base-quality model and the R2 offset.
// The first six fields are forwarded to `PositionQualityModel::new` by `to_quality_model`;
// `r2_quality_offset` is forwarded to `ReadPairQualityBias::new` by `to_quality_bias`.
// Plain `//` comments are used deliberately: `///` on clap fields would become `--help` text.
#[derive(Args, Debug, Clone)]
pub struct QualityArgs {
// `--warmup-bases`: defaults to 10.
#[arg(long = "warmup-bases", default_value = "10")]
pub warmup_bases: usize,
// `--warmup-quality`: defaults to 25.
#[arg(long = "warmup-quality", default_value = "25")]
pub warmup_quality: u8,
// `--peak-quality`: defaults to 37.
#[arg(long = "peak-quality", default_value = "37")]
pub peak_quality: u8,
// `--decay-start`: defaults to 100.
#[arg(long = "decay-start", default_value = "100")]
pub decay_start: usize,
// `--decay-rate`: defaults to 0.08.
#[arg(long = "decay-rate", default_value = "0.08")]
pub decay_rate: f64,
// `--quality-noise`: defaults to 2.0; validated by `parse_noise_stddev` (finite and >= 0).
#[arg(long = "quality-noise", default_value = "2.0", value_parser = parse_noise_stddev)]
pub quality_noise: f64,
// `--r2-quality-offset`: defaults to -2; `allow_hyphen_values` lets negative numbers parse.
#[arg(long = "r2-quality-offset", default_value = "-2", allow_hyphen_values = true)]
pub r2_quality_offset: i8,
}
/// Clap value parser for `--quality-noise`.
///
/// Accepts any string that parses as a finite, non-negative `f64`; everything else is
/// rejected with a human-readable message.
fn parse_noise_stddev(s: &str) -> Result<f64, String> {
    match s.parse::<f64>() {
        Err(e) => Err(format!("invalid float: {e}")),
        Ok(val) if val.is_finite() && val >= 0.0 => Ok(val),
        Ok(val) => Err(format!("quality-noise must be finite and >= 0.0, got {val}")),
    }
}
impl QualityArgs {
pub fn to_quality_model(&self) -> crate::simulate::PositionQualityModel {
crate::simulate::PositionQualityModel::new(
self.warmup_bases,
self.warmup_quality,
self.peak_quality,
self.decay_start,
self.decay_rate,
2, self.quality_noise,
)
}
pub fn to_quality_bias(&self) -> crate::simulate::ReadPairQualityBias {
crate::simulate::ReadPairQualityBias::new(self.r2_quality_offset)
}
}
// Command-line options describing the insert-size distribution (mean, stddev, min, max).
// All four are forwarded unchanged to `InsertSizeModel::new` by `to_insert_size_model`.
// Plain `//` comments are used deliberately: `///` on clap fields would become `--help` text.
#[derive(Args, Debug, Clone)]
pub struct InsertSizeArgs {
// `--insert-size-mean`: defaults to 300.0.
#[arg(long = "insert-size-mean", default_value = "300.0")]
pub insert_size_mean: f64,
// `--insert-size-stddev`: defaults to 50.0.
#[arg(long = "insert-size-stddev", default_value = "50.0")]
pub insert_size_stddev: f64,
// `--insert-size-min`: defaults to 50.
#[arg(long = "insert-size-min", default_value = "50")]
pub insert_size_min: usize,
// `--insert-size-max`: defaults to 800.
#[arg(long = "insert-size-max", default_value = "800")]
pub insert_size_max: usize,
}
impl InsertSizeArgs {
pub fn to_insert_size_model(&self) -> crate::simulate::InsertSizeModel {
crate::simulate::InsertSizeModel::new(
self.insert_size_mean,
self.insert_size_stddev,
self.insert_size_min,
self.insert_size_max,
)
}
}
// Command-line options selecting and parameterising the family-size distribution.
// Plain `//` comments are used deliberately: `///` on clap fields would become `--help` text.
#[derive(Args, Debug, Clone)]
pub struct FamilySizeArgs {
// `--family-size-dist`: "lognormal" (default), "negbin", or any other string, which
// `to_family_size_distribution` passes to `FamilySizeDistribution::from_histogram` as a path.
#[arg(long = "family-size-dist", default_value = "lognormal")]
pub family_size_dist: String,
// `--family-size-mean`: defaults to 3.0; used only for "lognormal".
#[arg(long = "family-size-mean", default_value = "3.0")]
pub family_size_mean: f64,
// `--family-size-stddev`: defaults to 2.0; used only for "lognormal".
#[arg(long = "family-size-stddev", default_value = "2.0")]
pub family_size_stddev: f64,
// `--family-size-r`: defaults to 2.0; used only for "negbin".
#[arg(long = "family-size-r", default_value = "2.0")]
pub family_size_r: f64,
// `--family-size-p`: defaults to 0.5; used only for "negbin".
#[arg(long = "family-size-p", default_value = "0.5")]
pub family_size_p: f64,
// `--min-family-size`: defaults to 1. Not consumed by `to_family_size_distribution`;
// presumably passed by callers when sampling — confirm at the call sites.
#[arg(long = "min-family-size", default_value = "1")]
pub min_family_size: usize,
}
impl FamilySizeArgs {
pub fn to_family_size_distribution(
&self,
) -> anyhow::Result<crate::simulate::FamilySizeDistribution> {
match self.family_size_dist.as_str() {
"lognormal" => Ok(crate::simulate::FamilySizeDistribution::log_normal(
self.family_size_mean,
self.family_size_stddev,
)),
"negbin" => Ok(crate::simulate::FamilySizeDistribution::negative_binomial(
self.family_size_r,
self.family_size_p,
)),
path => {
crate::simulate::FamilySizeDistribution::from_histogram(path)
}
}
}
}
// Command-line options holding the two shape parameters forwarded to `StrandBiasModel::new`.
// Plain `//` comments are used deliberately: `///` on clap fields would become `--help` text.
#[derive(Args, Debug, Clone)]
pub struct StrandBiasArgs {
// `--strand-alpha`: defaults to 5.0.
#[arg(long = "strand-alpha", default_value = "5.0")]
pub strand_alpha: f64,
// `--strand-beta`: defaults to 5.0.
#[arg(long = "strand-beta", default_value = "5.0")]
pub strand_beta: f64,
}
impl StrandBiasArgs {
pub fn to_strand_bias_model(&self) -> crate::simulate::StrandBiasModel {
crate::simulate::StrandBiasModel::new(self.strand_alpha, self.strand_beta)
}
}
// Command-line options controlling methylation simulation.
// Rates are range-checked by `MethylationArgs::validate`, not by clap.
// Plain `//` comments are used deliberately: `///` on clap fields would become `--help` text.
#[derive(Args, Debug, Clone)]
pub struct MethylationArgs {
// `--methylation-mode`: optional; `None` is resolved by `resolve_methylation_mode` in `resolve`.
#[arg(long = "methylation-mode", value_enum)]
pub methylation_mode: Option<MethylationModeArg>,
// `--cpg-methylation-rate`: defaults to 0.75.
#[arg(long = "cpg-methylation-rate", default_value = "0.75")]
pub cpg_methylation_rate: f64,
// `--conversion-rate`: defaults to 0.98.
#[arg(long = "conversion-rate", default_value = "0.98")]
pub conversion_rate: f64,
}
impl MethylationArgs {
    /// Converts the parsed options into a [`MethylationConfig`].
    ///
    /// The rates are copied as-is; call [`MethylationArgs::validate`] to range-check them.
    pub fn resolve(&self) -> MethylationConfig {
        let mode = crate::commands::common::resolve_methylation_mode(self.methylation_mode);
        let Self { cpg_methylation_rate, conversion_rate, .. } = *self;
        MethylationConfig { mode, cpg_methylation_rate, conversion_rate }
    }

    /// Checks that both rates are finite values in `[0.0, 1.0]`.
    ///
    /// # Errors
    /// Returns the first failure reported by `validate_rate`; the CpG rate is checked first.
    pub fn validate(&self) -> anyhow::Result<()> {
        let checks = [
            (self.cpg_methylation_rate, "cpg-methylation-rate"),
            (self.conversion_rate, "conversion-rate"),
        ];
        for (value, flag) in checks {
            validate_rate(value, flag)?;
        }
        Ok(())
    }
}
/// Resolved methylation settings consumed by `apply_methylation_conversion`.
#[derive(Debug, Clone, Copy)]
pub struct MethylationConfig {
/// Chemistry being simulated; conversion is skipped entirely when `mode.is_enabled()` is false.
pub mode: MethylationMode,
/// Probability that a cytosine in CpG context is drawn as methylated.
pub cpg_methylation_rate: f64,
/// Probability that a base selected as a conversion target is actually converted.
pub conversion_rate: f64,
}
/// Ensures `value` is a finite number inside the closed interval `[0.0, 1.0]`.
///
/// `name` is the flag name without leading dashes; it is only used in the error message.
///
/// # Errors
/// Returns an error naming the flag and the offending value when the check fails.
pub(super) fn validate_rate(value: f64, name: &str) -> anyhow::Result<()> {
    let in_unit_interval = (0.0..=1.0).contains(&value);
    if value.is_finite() && in_unit_interval {
        return Ok(());
    }
    anyhow::bail!("--{name} must be a finite value between 0.0 and 1.0, got {value}")
}
// Command-line option carrying the path to the reference file.
// Plain `//` comments are used deliberately: `///` on clap fields would become `--help` text.
#[derive(Args, Debug, Clone)]
pub struct ReferenceArgs {
// `-r` / `--reference`: required path (presumably the FASTA given to `ReferenceGenome::load`
// — confirm at the call sites).
#[arg(short = 'r', long = "reference", required = true)]
pub reference: PathBuf,
}
/// Contigs with fewer bases than this are dropped by `ReferenceGenome::load`; the same value is
/// the size of the N-free window required by `ReferenceGenome::sample_positions`.
const MIN_CONTIG_LENGTH: usize = 1000;
/// Reference contigs held fully in memory; the vectors are parallel, indexed by contig.
pub(super) struct ReferenceGenome {
/// Contig names, in file order.
names: Vec<String>,
/// Upper-cased contig bases.
sequences: Vec<Vec<u8>>,
/// Running sum of contig lengths: entry `i` is the total length of contigs `0..=i`.
cumulative_lengths: Vec<usize>,
/// Sum of all retained contig lengths.
total_length: usize,
}
impl ReferenceGenome {
    /// Reads every record of the FASTA file at `path` into memory.
    ///
    /// Bases are upper-cased on load and contigs shorter than `MIN_CONTIG_LENGTH` are dropped.
    ///
    /// # Errors
    /// Fails when the file cannot be opened, a record cannot be read, a contig name is not valid
    /// UTF-8, or no contig survives the length filter.
    pub fn load<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
        let path = path.as_ref();
        info!("Loading reference from {}", path.display());
        let file = File::open(path)
            .with_context(|| format!("Failed to open reference: {}", path.display()))?;
        let mut fasta_reader = fasta::io::Reader::new(BufReader::new(file));

        let mut names = Vec::new();
        let mut sequences = Vec::new();
        let mut cumulative_lengths = Vec::new();
        let mut total_length = 0usize;

        for result in fasta_reader.records() {
            let record = result.context("Failed to read FASTA record")?;
            // The name is validated before the length filter, so a malformed name on a short
            // contig is still reported as an error.
            let name = std::str::from_utf8(record.name())
                .context("Invalid chromosome name")?
                .to_string();
            let seq: Vec<u8> =
                record.sequence().as_ref().iter().map(|&b| b.to_ascii_uppercase()).collect();
            if seq.len() >= MIN_CONTIG_LENGTH {
                total_length += seq.len();
                cumulative_lengths.push(total_length);
                names.push(name);
                sequences.push(seq);
            }
        }

        if sequences.is_empty() {
            bail!("No valid sequences found in reference FASTA");
        }
        info!("Loaded {} chromosomes, total {} bp", sequences.len(), total_length);
        Ok(Self { names, sequences, cumulative_lengths, total_length })
    }

    /// Name of the contig at `chrom_idx`.
    ///
    /// # Panics
    /// Panics if `chrom_idx` is out of range.
    pub fn name(&self, chrom_idx: usize) -> &str {
        &self.names[chrom_idx]
    }

    /// Maps a position in the concatenated genome to `(chrom_idx, offset_within_contig)`.
    ///
    /// Callers must pass `genome_pos < self.total_length`; otherwise the returned contig index
    /// is one past the last contig.
    fn locate(&self, genome_pos: usize) -> (usize, usize) {
        let chrom_idx = self.cumulative_lengths.partition_point(|&cum| cum <= genome_pos);
        let chrom_start =
            chrom_idx.checked_sub(1).map_or(0, |prev| self.cumulative_lengths[prev]);
        (chrom_idx, genome_pos - chrom_start)
    }

    /// True when `bases` contains an ambiguous `N`/`n`.
    fn has_ambiguous_base(bases: &[u8]) -> bool {
        bases.iter().any(|&b| b == b'N' || b == b'n')
    }

    /// Draws a random N-free window of `length` bases that lies within a single contig.
    ///
    /// Returns `(chrom_idx, start, bases)`. Gives up with `None` after 10 attempts, or
    /// immediately when `length` is zero or exceeds the total genome length.
    pub fn sample_sequence(
        &self,
        length: usize,
        rng: &mut impl Rng,
    ) -> Option<(usize, usize, Vec<u8>)> {
        if length == 0 || length > self.total_length {
            return None;
        }
        let start_bound = self.total_length - length + 1;
        for _ in 0..10 {
            let (chrom_idx, local_pos) = self.locate(rng.random_range(0..start_bound));
            // A window that would run off the end of its contig is rejected and redrawn.
            let Some(template) = self.sequences[chrom_idx].get(local_pos..local_pos + length)
            else {
                continue;
            };
            if Self::has_ambiguous_base(template) {
                continue;
            }
            return Some((chrom_idx, local_pos, template.to_vec()));
        }
        None
    }

    /// Sum of the lengths of all retained contigs.
    pub fn total_length(&self) -> usize {
        self.total_length
    }

    /// Like [`Self::sequence_at`], but addressed by a position in the concatenated genome.
    ///
    /// `genome_pos` is reduced modulo the total length, so any value is accepted.
    #[allow(dead_code)]
    pub fn sequence_at_genome_pos(&self, genome_pos: usize, length: usize) -> Option<Vec<u8>> {
        if self.total_length == 0 {
            return None;
        }
        let (chrom_idx, local_pos) = self.locate(genome_pos % self.total_length);
        self.sequence_at(chrom_idx, local_pos, length)
    }

    /// Copies `length` bases starting at `pos` on contig `chrom_idx`.
    ///
    /// Returns `None` when the contig index is out of range, the window does not fit inside
    /// the contig (including when `pos + length` would overflow), or the window contains `N`.
    pub fn sequence_at(&self, chrom_idx: usize, pos: usize, length: usize) -> Option<Vec<u8>> {
        let seq = self.sequences.get(chrom_idx)?;
        // `checked_add` keeps absurdly large inputs on the `None` path instead of panicking.
        let end = pos.checked_add(length)?;
        let subseq = seq.get(pos..end)?;
        if Self::has_ambiguous_base(subseq) {
            return None;
        }
        Some(subseq.to_vec())
    }

    /// Builds a SAM/BAM header with one reference-sequence entry per retained contig.
    pub(super) fn build_bam_header(&self) -> noodles::sam::header::Header {
        use bstr::BString;
        use noodles::sam::header::Header;
        use noodles::sam::header::record::value::Map;
        use noodles::sam::header::record::value::map::ReferenceSequence;
        use std::num::NonZeroUsize;

        let mut builder = Header::builder();
        for (name, seq) in self.names.iter().zip(self.sequences.iter()) {
            // `load` only keeps contigs of at least `MIN_CONTIG_LENGTH` bases, so this holds.
            let length = NonZeroUsize::try_from(seq.len()).expect("chromosome length must be > 0");
            let ref_seq = Map::<ReferenceSequence>::new(length);
            builder = builder.add_reference_sequence(BString::from(name.as_str()), ref_seq);
        }
        builder.build()
    }

    /// Length of the longest contig, or 0 when there are none.
    pub(super) fn max_contig_length(&self) -> usize {
        self.sequences.iter().map(Vec::len).max().unwrap_or(0)
    }

    /// Number of retained contigs.
    #[allow(dead_code)]
    pub(super) fn num_chromosomes(&self) -> usize {
        self.sequences.len()
    }

    /// Length of contig `chrom_idx`.
    ///
    /// # Panics
    /// Panics if `chrom_idx` is out of range.
    #[allow(dead_code)]
    pub(super) fn chromosome_length(&self, chrom_idx: usize) -> usize {
        self.sequences[chrom_idx].len()
    }

    /// Draws `num_positions` random `(chrom_idx, offset)` pairs whose following window of up to
    /// `MIN_CONTIG_LENGTH` bases (clipped at the contig end) contains no `N`.
    ///
    /// # Panics
    /// Panics when `100 * num_positions` attempts (at least one) were not enough.
    pub(super) fn sample_positions(
        &self,
        num_positions: usize,
        rng: &mut impl Rng,
    ) -> Vec<(usize, usize)> {
        const WINDOW: usize = MIN_CONTIG_LENGTH;
        let max_attempts = num_positions.saturating_mul(100).max(1);
        let mut positions = Vec::with_capacity(num_positions);
        let mut attempts = 0usize;
        while positions.len() < num_positions {
            assert!(
                attempts < max_attempts,
                "sample_positions: exhausted {max_attempts} attempts to find \
                 {num_positions} N-free positions in the reference"
            );
            attempts += 1;
            let (chrom_idx, local_pos) = self.locate(rng.random_range(0..self.total_length));
            let seq = &self.sequences[chrom_idx];
            // `local_pos < seq.len()`, so the clipped end is never before the start.
            let window_end = (local_pos + WINDOW).min(seq.len());
            if Self::has_ambiguous_base(&seq[local_pos..window_end]) {
                continue;
            }
            positions.push((chrom_idx, local_pos));
        }
        positions
    }
}
// Command-line options about how molecules are spread over genomic positions.
// Plain `//` comments are used deliberately: `///` on clap fields would become `--help` text.
#[derive(Args, Debug, Clone)]
pub struct PositionDistArgs {
// `--num-positions`: optional; `None` when the flag is not given (the fallback behaviour is
// decided by callers not visible here).
#[arg(long = "num-positions")]
pub num_positions: Option<usize>,
// `--umis-per-position`: defaults to 1.
#[arg(long = "umis-per-position", default_value = "1")]
pub umis_per_position: usize,
}
/// Returns `len` bases drawn uniformly from `ACGT`, consuming one RNG draw per base.
pub(super) fn generate_random_sequence(len: usize, rng: &mut impl Rng) -> Vec<u8> {
    const BASES: &[u8] = b"ACGT";
    (0..len).map(|_| BASES[rng.random_range(0..4)]).collect()
}
/// Forces `seq` to exactly `target_len` bases.
///
/// Short input is extended with random `ACGT` bases (one RNG draw each); long input is
/// truncated without touching the RNG.
pub(super) fn pad_sequence(mut seq: Vec<u8>, target_len: usize, rng: &mut impl Rng) -> Vec<u8> {
    let missing = target_len.saturating_sub(seq.len());
    if missing > 0 {
        seq.extend((0..missing).map(|_| b"ACGT"[rng.random_range(0..4)]));
    } else {
        seq.truncate(target_len);
    }
    seq
}
/// Applies simulated methylation-dependent base conversion to `read_seq` in place.
///
/// Read base `i` is paired with reference base `ref_offset + i`; read bases past the end of
/// `ref_seq` are left untouched. On the top strand a reference `C` may become `T`, on the
/// bottom strand a reference `G` may become `A`. Whether a site is eligible is decided by
/// `is_conversion_target`; eligible sites convert with probability `config.conversion_rate`.
/// Nothing happens when `config.mode` is not enabled.
pub(super) fn apply_methylation_conversion(
    read_seq: &mut [u8],
    ref_seq: &[u8],
    ref_offset: usize,
    is_top_strand: bool,
    config: &MethylationConfig,
    rng: &mut impl Rng,
) {
    if !config.mode.is_enabled() {
        return;
    }

    // Reference base that is susceptible on this strand, and what the read base becomes.
    let (susceptible, converted) = if is_top_strand { (b'C', b'T') } else { (b'G', b'A') };

    let aligned_reference = ref_seq.iter().enumerate().skip(ref_offset);
    for (base, (ref_pos, ref_byte)) in read_seq.iter_mut().zip(aligned_reference) {
        if ref_byte.to_ascii_uppercase() != susceptible {
            continue;
        }
        let cpg = is_cpg_context(ref_seq, ref_pos, is_top_strand);
        // Draw order matters for reproducibility: target decision first, then (only when the
        // site is a target) the conversion draw.
        if is_conversion_target(config.mode, cpg, config.cpg_methylation_rate, rng)
            && rng.random::<f64>() < config.conversion_rate
        {
            *base = converted;
        }
    }
}
/// Decides whether a susceptible base is eligible for conversion under `mode`.
///
/// Outside CpG context no randomness is consumed: EM-Seq always targets the base, the other
/// modes never do. Inside CpG context one RNG draw decides methylation (probability
/// `cpg_methylation_rate`); EM-Seq targets unmethylated sites, TAPs targets methylated ones,
/// and `Disabled` targets nothing.
fn is_conversion_target(
    mode: MethylationMode,
    is_cpg: bool,
    cpg_methylation_rate: f64,
    rng: &mut impl Rng,
) -> bool {
    if !is_cpg {
        return match mode {
            MethylationMode::EmSeq => true,
            MethylationMode::Taps | MethylationMode::Disabled => false,
        };
    }

    // The draw happens for every mode so the RNG stream does not depend on `mode`.
    let methylated = rng.random::<f64>() < cpg_methylation_rate;
    match mode {
        MethylationMode::EmSeq => !methylated,
        MethylationMode::Taps => methylated,
        MethylationMode::Disabled => false,
    }
}
/// Bookkeeping record for one simulated molecule; ordering and equality use `sort_key` only.
#[derive(Debug)]
pub(super) struct MoleculeInfo {
/// Identifier of the molecule (ignored by comparisons).
pub mol_id: usize,
/// Per-molecule seed — presumably used to regenerate the molecule's reads; confirm at callers.
pub seed: u64,
/// Key that alone defines `Ord`/`PartialEq` for this type.
pub sort_key: TemplateCoordKey,
/// Flag carried alongside the key; not consulted by any code in this file's visible part.
pub is_unmapped: bool,
}
/// Molecules are ordered purely by their `sort_key`.
impl Ord for MoleculeInfo {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.sort_key, &other.sort_key)
    }
}

/// Total order, so `partial_cmp` always defers to `cmp`.
impl PartialOrd for MoleculeInfo {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

/// Two molecules are equal when their sort keys are equal; all other fields are ignored.
impl PartialEq for MoleculeInfo {
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.sort_key, &other.sort_key)
    }
}

impl Eq for MoleculeInfo {}
#[cfg(test)]
mod tests {
use super::*;
use crate::simulate::create_rng;
use rstest::rstest;
use std::io::Write;
use tempfile::NamedTempFile;
/// The CLI default values must reach the quality model unchanged.
#[test]
fn test_quality_args_default_model() {
    let defaults = QualityArgs {
        warmup_bases: 10,
        warmup_quality: 25,
        peak_quality: 37,
        decay_start: 100,
        decay_rate: 0.08,
        quality_noise: 2.0,
        r2_quality_offset: -2,
    };

    let built = defaults.to_quality_model();

    assert_eq!(built.warmup_bases, 10);
    assert_eq!(built.warmup_start, 25);
    assert_eq!(built.peak_quality, 37);
    assert_eq!(built.decay_start, 100);
    assert!((built.decay_rate - 0.08).abs() < f64::EPSILON);
}
/// Non-default values propagate to both the quality model and the R2 bias.
#[test]
fn test_quality_args_custom_values() {
    let custom = QualityArgs {
        warmup_bases: 5,
        warmup_quality: 20,
        peak_quality: 40,
        decay_start: 80,
        decay_rate: 0.1,
        quality_noise: 1.5,
        r2_quality_offset: -3,
    };

    let built = custom.to_quality_model();
    assert_eq!(built.warmup_bases, 5);
    assert_eq!(built.peak_quality, 40);

    assert_eq!(custom.to_quality_bias().r2_offset, -3);
}
/// A positive R2 offset raises the quality passed through `apply`.
#[test]
fn test_quality_bias_positive_offset() {
    let with_positive_offset = QualityArgs {
        warmup_bases: 10,
        warmup_quality: 25,
        peak_quality: 37,
        decay_start: 100,
        decay_rate: 0.08,
        quality_noise: 2.0,
        r2_quality_offset: 3,
    };

    let r2_bias = with_positive_offset.to_quality_bias();

    assert_eq!(r2_bias.apply(30, true), 33);
}
/// Default insert-size options are copied into the model.
#[test]
fn test_insert_size_args_default() {
    let defaults = InsertSizeArgs {
        insert_size_mean: 300.0,
        insert_size_stddev: 50.0,
        insert_size_min: 50,
        insert_size_max: 800,
    };

    let built = defaults.to_insert_size_model();

    assert!((built.mean - 300.0).abs() < f64::EPSILON);
    assert_eq!(built.min, 50);
    assert_eq!(built.max, 800);
}
/// Every sample from a tightly bounded model stays inside `[min, max]`.
#[test]
fn test_insert_size_args_narrow_range() {
    let narrow = InsertSizeArgs {
        insert_size_mean: 200.0,
        insert_size_stddev: 10.0,
        insert_size_min: 180,
        insert_size_max: 220,
    };
    let model = narrow.to_insert_size_model();
    let mut rng = create_rng(Some(42));

    for size in (0..100).map(|_| model.sample(&mut rng)) {
        assert!((180..=220).contains(&size));
    }
}
/// `"lognormal"` selects a distribution that yields sizes of at least the requested minimum.
#[test]
fn test_family_size_args_lognormal() {
    let lognormal = FamilySizeArgs {
        family_size_dist: "lognormal".to_string(),
        family_size_mean: 3.0,
        family_size_stddev: 2.0,
        family_size_r: 2.0,
        family_size_p: 0.5,
        min_family_size: 1,
    };
    let distribution = lognormal
        .to_family_size_distribution()
        .expect("lognormal distribution should be created");
    let mut rng = create_rng(Some(42));

    let drawn = distribution.sample(&mut rng, 1);

    assert!(drawn >= 1);
}
/// `"negbin"` selects a distribution that yields sizes of at least the requested minimum.
#[test]
fn test_family_size_args_negbin() {
    let negbin = FamilySizeArgs {
        family_size_dist: "negbin".to_string(),
        family_size_mean: 3.0,
        family_size_stddev: 2.0,
        family_size_r: 2.0,
        family_size_p: 0.5,
        min_family_size: 1,
    };
    let distribution =
        negbin.to_family_size_distribution().expect("negbin distribution should be created");
    let mut rng = create_rng(Some(42));

    let drawn = distribution.sample(&mut rng, 1);

    assert!(drawn >= 1);
}
/// Any value other than the built-in names is treated as a histogram file path.
#[test]
fn test_family_size_args_from_histogram() -> anyhow::Result<()> {
    let mut histogram = NamedTempFile::new()?;
    for row in ["family_size\tcount", "1\t50", "2\t30", "3\t20"] {
        writeln!(histogram, "{row}")?;
    }
    histogram.flush()?;

    let from_file = FamilySizeArgs {
        family_size_dist: histogram.path().to_string_lossy().to_string(),
        family_size_mean: 3.0,
        family_size_stddev: 2.0,
        family_size_r: 2.0,
        family_size_p: 0.5,
        min_family_size: 1,
    };
    let distribution = from_file.to_family_size_distribution()?;
    let mut rng = create_rng(Some(42));

    // Only sizes present in the histogram may come out.
    for drawn in (0..100).map(|_| distribution.sample(&mut rng, 1)) {
        assert!((1..=3).contains(&drawn));
    }
    Ok(())
}
/// A histogram path that does not exist is reported as an error.
#[test]
fn test_family_size_args_invalid_histogram() {
    let missing_file = FamilySizeArgs {
        family_size_dist: "/nonexistent/path/histogram.tsv".to_string(),
        family_size_mean: 3.0,
        family_size_stddev: 2.0,
        family_size_r: 2.0,
        family_size_p: 0.5,
        min_family_size: 1,
    };

    let outcome = missing_file.to_family_size_distribution();

    assert!(outcome.is_err());
}
/// Equal alpha and beta give an A-strand fraction centred on 0.5.
#[test]
fn test_strand_bias_args_symmetric() {
    let symmetric = StrandBiasArgs { strand_alpha: 5.0, strand_beta: 5.0 };
    let model = symmetric.to_strand_bias_model();
    assert!((model.alpha - 5.0).abs() < f64::EPSILON);
    assert!((model.beta - 5.0).abs() < f64::EPSILON);

    let mut rng = create_rng(Some(42));
    let draws = 1000;
    let total: f64 = (0..draws).map(|_| model.sample_a_fraction(&mut rng)).sum::<f64>();
    let mean = total / draws as f64;

    assert!((mean - 0.5).abs() < 0.05);
}
/// Alpha much larger than beta pushes the mean A-strand fraction above 0.7.
#[test]
fn test_strand_bias_args_a_biased() {
    let a_heavy = StrandBiasArgs { strand_alpha: 8.0, strand_beta: 2.0 };
    let model = a_heavy.to_strand_bias_model();
    let mut rng = create_rng(Some(42));

    let draws = 1000;
    let total: f64 = (0..draws).map(|_| model.sample_a_fraction(&mut rng)).sum::<f64>();
    let mean = total / draws as f64;

    assert!(mean > 0.7);
}
/// Beta much larger than alpha pulls the mean A-strand fraction below 0.3.
#[test]
fn test_strand_bias_args_b_biased() {
    let b_heavy = StrandBiasArgs { strand_alpha: 2.0, strand_beta: 8.0 };
    let model = b_heavy.to_strand_bias_model();
    let mut rng = create_rng(Some(42));

    let draws = 1000;
    let total: f64 = (0..draws).map(|_| model.sample_a_fraction(&mut rng)).sum::<f64>();
    let mean = total / draws as f64;

    assert!(mean < 0.3);
}
/// A warm-up length of zero is passed through to the model.
#[test]
fn test_quality_args_zero_warmup() {
    let no_warmup = QualityArgs {
        warmup_bases: 0,
        warmup_quality: 25,
        peak_quality: 37,
        decay_start: 100,
        decay_rate: 0.08,
        quality_noise: 2.0,
        r2_quality_offset: -2,
    };

    let built = no_warmup.to_quality_model();

    assert_eq!(built.warmup_bases, 0);
}
/// A high peak quality with zero noise still produces a non-empty quality string.
#[test]
fn test_quality_args_high_peak_quality() {
    let high_peak = QualityArgs {
        warmup_bases: 10,
        warmup_quality: 30,
        peak_quality: 41,
        decay_start: 100,
        decay_rate: 0.08,
        quality_noise: 0.0,
        r2_quality_offset: 0,
    };
    let model = high_peak.to_quality_model();
    let mut rng = create_rng(Some(42));

    let generated = model.generate_qualities(50, &mut rng);

    assert!(!generated.is_empty());
}
/// When min equals max every sample is that single value.
#[test]
fn test_insert_size_args_min_equals_max() {
    let pinned = InsertSizeArgs {
        insert_size_mean: 200.0,
        insert_size_stddev: 50.0,
        insert_size_min: 200,
        insert_size_max: 200,
    };
    let model = pinned.to_insert_size_model();
    let mut rng = create_rng(Some(42));

    for size in (0..10).map(|_| model.sample(&mut rng)) {
        assert_eq!(size, 200);
    }
}
/// The minimum passed to `sample` is a hard lower bound on the drawn size.
#[test]
fn test_family_size_args_high_min() {
    let high_min = FamilySizeArgs {
        family_size_dist: "lognormal".to_string(),
        family_size_mean: 3.0,
        family_size_stddev: 2.0,
        family_size_r: 2.0,
        family_size_p: 0.5,
        min_family_size: 5,
    };
    let distribution = high_min
        .to_family_size_distribution()
        .expect("lognormal distribution with high min should be created");
    let mut rng = create_rng(Some(42));

    for size in (0..100).map(|_| distribution.sample(&mut rng, 5)) {
        assert!(size >= 5, "Size {size} should be >= min 5");
    }
}
/// Splitting never loses or invents reads, whatever the total.
#[test]
fn test_strand_bias_split_reads() {
    let model = StrandBiasArgs { strand_alpha: 5.0, strand_beta: 5.0 }.to_strand_bias_model();
    let mut rng = create_rng(Some(42));

    let totals = [0, 1, 2, 5, 10, 100];
    for total in totals {
        let (a, b) = model.split_reads(total, &mut rng);
        assert_eq!(a + b, total, "A ({a}) + B ({b}) should equal total ({total})");
    }
}
/// Splitting zero reads yields zero on both strands.
#[test]
fn test_strand_bias_split_zero_total() {
    let model = StrandBiasArgs { strand_alpha: 5.0, strand_beta: 5.0 }.to_strand_bias_model();
    let mut rng = create_rng(Some(42));

    let (a_reads, b_reads) = model.split_reads(0, &mut rng);

    assert_eq!(a_reads, 0);
    assert_eq!(b_reads, 0);
}
/// A single read lands on exactly one strand.
#[test]
fn test_strand_bias_split_one_read() {
    let model = StrandBiasArgs { strand_alpha: 5.0, strand_beta: 5.0 }.to_strand_bias_model();
    let mut rng = create_rng(Some(42));

    let (a_reads, b_reads) = model.split_reads(1, &mut rng);

    assert_eq!(a_reads + b_reads, 1);
    assert!(a_reads <= 1 && b_reads <= 1);
}
/// Finite, non-negative inputs parse to the expected value.
#[rstest]
#[case("0.0", 0.0)]
#[case("2.5", 2.5)]
#[case("100.0", 100.0)]
fn test_parse_noise_stddev_accepts_valid(#[case] input: &str, #[case] expected: f64) {
    let value = parse_noise_stddev(input).expect("valid noise stddev should parse");
    let difference = (value - expected).abs();
    assert!(difference < f64::EPSILON);
}
/// Negative, non-finite and non-numeric inputs are all rejected.
#[rstest]
#[case("-0.1")]
#[case("NaN")]
#[case("inf")]
#[case("-inf")]
#[case("abc")]
fn test_parse_noise_stddev_rejects_invalid(#[case] input: &str) {
    let outcome = parse_noise_stddev(input);
    assert!(outcome.is_err(), "input should be rejected: {input}");
}
/// A noise value of zero is stored in the model's `noise_stddev`.
#[test]
fn test_quality_args_zero_noise() {
    let noiseless = QualityArgs {
        warmup_bases: 10,
        warmup_quality: 25,
        peak_quality: 37,
        decay_start: 100,
        decay_rate: 0.08,
        quality_noise: 0.0,
        r2_quality_offset: 0,
    };

    let built = noiseless.to_quality_model();

    assert!((built.noise_stddev - 0.0).abs() < f64::EPSILON);
}
/// The empirical mean of many samples is close to the configured mean.
#[test]
fn test_insert_size_sample_distribution() {
    let defaults = InsertSizeArgs {
        insert_size_mean: 300.0,
        insert_size_stddev: 50.0,
        insert_size_min: 50,
        insert_size_max: 800,
    };
    let model = defaults.to_insert_size_model();
    let mut rng = create_rng(Some(42));

    let draws = 1000;
    let total: f64 = (0..draws).map(|_| model.sample(&mut rng) as f64).sum::<f64>();
    let mean = total / draws as f64;

    assert!(mean > 280.0 && mean < 320.0, "Mean {mean} not close to expected 300");
}
/// The lowercase distribution name is accepted.
///
/// NOTE(review): despite the test name, only the lowercase spelling is exercised here, and
/// `to_family_size_distribution` matches names case-sensitively.
#[test]
fn test_family_size_distribution_type_case_insensitive() {
    let lowercase = FamilySizeArgs {
        family_size_dist: "lognormal".to_string(),
        family_size_mean: 3.0,
        family_size_stddev: 2.0,
        family_size_r: 2.0,
        family_size_p: 0.5,
        min_family_size: 1,
    };

    let _distribution = lowercase
        .to_family_size_distribution()
        .expect("lowercase distribution name should be accepted");
}
/// Builds a `MoleculeInfo` whose sort key differs only in `tid1` and `pos1`.
fn make_molecule_info(mol_id: usize, tid1: i32, pos1: i64) -> MoleculeInfo {
    let sort_key = TemplateCoordKey {
        tid1,
        tid2: 0,
        pos1,
        pos2: 0,
        neg1: false,
        neg2: false,
        mid: String::new(),
        name: String::new(),
        is_upper_of_pair: false,
    };
    MoleculeInfo { mol_id, seed: 0, sort_key, is_unmapped: false }
}
/// Ordering follows the sort key; `mol_id` plays no part in comparison or equality.
#[test]
fn test_molecule_info_ordering() {
    let first = make_molecule_info(0, 1, 100);
    let second = make_molecule_info(1, 1, 200);
    let third = make_molecule_info(2, 2, 50);

    assert!(first < second);
    assert!(second > first);
    assert!(second < third);
    assert_eq!(first.partial_cmp(&second), Some(std::cmp::Ordering::Less));

    // Same key, different molecule id.
    let same_key = make_molecule_info(99, 1, 100);
    assert_eq!(first, same_key);
    assert_eq!(first.cmp(&same_key), std::cmp::Ordering::Equal);
}
/// Writes a single-record FASTA (`>chr1`) containing `seq` to a temporary file.
fn write_test_fasta(seq: &[u8]) -> NamedTempFile {
    let mut fasta = NamedTempFile::new().unwrap();
    fasta.write_all(b">chr1\n").unwrap();
    fasta.write_all(seq).unwrap();
    fasta.write_all(b"\n").unwrap();
    fasta.flush().unwrap();
    fasta
}
/// Loads a single-contig reference and checks both direct access and random sampling.
#[test]
fn test_reference_genome_load_and_sample() {
    let seq = b"ACGT".repeat(375);
    let fasta = write_test_fasta(&seq);
    let genome = ReferenceGenome::load(fasta.path()).unwrap();
    assert_eq!(genome.name(0), "chr1");
    assert!(genome.sequence_at(0, 0, 1500).is_some());

    // The contig is N-free and the sampler only proposes starts that fit, so sampling a
    // 100 bp window must succeed and agree with a direct lookup at the same coordinates.
    let mut rng = create_rng(Some(42));
    let (chrom_idx, pos, sampled) =
        genome.sample_sequence(100, &mut rng).expect("sampling an N-free contig should succeed");
    assert_eq!(chrom_idx, 0);
    assert_eq!(sampled, genome.sequence_at(0, pos, 100).unwrap());
}
/// An in-range request returns exactly the requested bases.
#[test]
fn test_reference_genome_sequence_at_valid() {
    let fasta = write_test_fasta(&b"ACGT".repeat(375));
    let genome = ReferenceGenome::load(fasta.path()).unwrap();

    let window = genome.sequence_at(0, 4, 8).unwrap();

    assert_eq!(window, b"ACGTACGT");
}
/// A window that runs past the contig end yields `None`.
#[test]
fn test_reference_genome_sequence_at_out_of_bounds() {
    let fasta = write_test_fasta(&b"ACGT".repeat(375));
    let genome = ReferenceGenome::load(fasta.path()).unwrap();

    let past_end = genome.sequence_at(0, 1495, 10);

    assert!(past_end.is_none());
}
/// An unknown contig index yields `None`.
#[test]
fn test_reference_genome_sequence_at_invalid_chrom() {
    let fasta = write_test_fasta(&b"ACGT".repeat(375));
    let genome = ReferenceGenome::load(fasta.path()).unwrap();

    let unknown_contig = genome.sequence_at(99, 0, 10);

    assert!(unknown_contig.is_none());
}
/// Windows overlapping an `N` are rejected; windows that avoid it are returned.
#[test]
fn test_reference_genome_sequence_at_n_bases() {
    let mut bases = b"ACGT".repeat(375);
    bases[10] = b'N';
    let fasta = write_test_fasta(&bases);
    let genome = ReferenceGenome::load(fasta.path()).unwrap();

    let overlapping_n = genome.sequence_at(0, 8, 4);
    let clear_of_n = genome.sequence_at(0, 0, 4);

    assert!(overlapping_n.is_none());
    assert!(clear_of_n.is_some());
}
/// Contigs below the minimum length are dropped, so the long contig becomes index 0.
#[test]
fn test_reference_genome_skips_short_sequences() {
    let mut fasta = NamedTempFile::new().unwrap();
    // 500 bp (filtered out) followed by 1500 bp (kept).
    for (name, repeats) in [("short", 125), ("long", 375)] {
        writeln!(fasta, ">{name}").unwrap();
        fasta.write_all(&b"ACGT".repeat(repeats)).unwrap();
        writeln!(fasta).unwrap();
    }
    fasta.flush().unwrap();

    let genome = ReferenceGenome::load(fasta.path()).unwrap();

    assert_eq!(genome.name(0), "long");
    assert!(genome.sequence_at(1, 0, 1).is_none());
}
/// The total length equals the single contig's length.
#[test]
fn test_reference_genome_total_length() {
    let fasta = write_test_fasta(&b"ACGT".repeat(375));
    let genome = ReferenceGenome::load(fasta.path()).unwrap();

    let total = genome.total_length();

    assert_eq!(total, 1500);
}
/// On a single contig, genome coordinates coincide with contig coordinates.
#[test]
fn test_reference_genome_sequence_at_genome_pos_single_chrom() {
    let fasta = write_test_fasta(&b"ACGT".repeat(375));
    let genome = ReferenceGenome::load(fasta.path()).unwrap();

    let window = genome.sequence_at_genome_pos(4, 8).unwrap();

    assert_eq!(window, b"ACGTACGT");
}
/// Positions beyond the genome length wrap around modulo the total length.
#[test]
fn test_reference_genome_sequence_at_genome_pos_wraps_around() {
    let fasta = write_test_fasta(&b"ACGT".repeat(375));
    let genome = ReferenceGenome::load(fasta.path()).unwrap();

    let in_range = genome.sequence_at_genome_pos(4, 8).unwrap();
    let one_lap_later = genome.sequence_at_genome_pos(4 + genome.total_length(), 8).unwrap();

    assert_eq!(in_range, one_lap_later);
}
/// Genome positions resolve to the right contig when several are loaded.
#[test]
fn test_reference_genome_sequence_at_genome_pos_multi_chrom() {
    let mut fasta = NamedTempFile::new().unwrap();
    // Two 1500 bp contigs: all-A followed by all-C.
    for (name, unit) in [("chr1", b"AAAA"), ("chr2", b"CCCC")] {
        writeln!(fasta, ">{name}").unwrap();
        fasta.write_all(&unit.repeat(375)).unwrap();
        writeln!(fasta).unwrap();
    }
    fasta.flush().unwrap();

    let genome = ReferenceGenome::load(fasta.path()).unwrap();
    assert_eq!(genome.total_length(), 3000);

    let chr1_window = genome.sequence_at_genome_pos(0, 4).unwrap();
    assert_eq!(chr1_window, b"AAAA");

    // Position 1500 is the first base of the second contig.
    let chr2_window = genome.sequence_at_genome_pos(1500, 4).unwrap();
    assert_eq!(chr2_window, b"CCCC");
}
/// A window crossing the end of the contig is rejected.
#[test]
fn test_reference_genome_sequence_at_genome_pos_boundary() {
    let fasta = write_test_fasta(&b"ACGT".repeat(375));
    let genome = ReferenceGenome::load(fasta.path()).unwrap();

    let crossing_end = genome.sequence_at_genome_pos(1490, 20);

    assert!(crossing_end.is_none());
}
/// Sampling from an N-free contig returns a window of the requested length on contig 0.
#[test]
fn test_reference_genome_sample_sequence_returns_valid() {
    let fasta = write_test_fasta(&b"ACGT".repeat(375));
    let genome = ReferenceGenome::load(fasta.path()).unwrap();
    let mut rng = create_rng(Some(42));

    let sampled = genome.sample_sequence(100, &mut rng);

    assert!(sampled.is_some());
    let (chrom_idx, _pos, bases) = sampled.unwrap();
    assert_eq!(chrom_idx, 0);
    assert_eq!(bases.len(), 100);
}
/// Requesting the whole genome leaves position 0 as the only possible start.
#[test]
fn test_reference_genome_sample_sequence_exact_fit() {
    let fasta = write_test_fasta(&b"ACGT".repeat(375));
    let genome = ReferenceGenome::load(fasta.path()).unwrap();
    let mut rng = create_rng(Some(42));

    let sampled = genome.sample_sequence(genome.total_length(), &mut rng);

    assert!(sampled.is_some());
    let (_chrom_idx, start, bases) = sampled.unwrap();
    assert_eq!(start, 0);
    assert_eq!(bases.len(), genome.total_length());
}
/// A request longer than the genome cannot be satisfied.
#[test]
fn test_reference_genome_sample_sequence_too_large() {
    let fasta = write_test_fasta(&b"ACGT".repeat(375));
    let genome = ReferenceGenome::load(fasta.path()).unwrap();
    let mut rng = create_rng(Some(42));

    let oversized = genome.sample_sequence(genome.total_length() + 1, &mut rng);

    assert!(oversized.is_none());
}
/// A zero-length request is rejected.
#[test]
fn test_reference_genome_sample_sequence_zero_length() {
    let fasta = write_test_fasta(&b"ACGT".repeat(375));
    let genome = ReferenceGenome::load(fasta.path()).unwrap();
    let mut rng = create_rng(Some(42));

    let empty_request = genome.sample_sequence(0, &mut rng);

    assert!(empty_request.is_none());
}
/// The struct stores the rate values it was constructed with.
#[test]
fn test_methylation_args_defaults() {
    let defaults = MethylationArgs {
        methylation_mode: None,
        cpg_methylation_rate: 0.75,
        conversion_rate: 0.98,
    };

    let cpg_delta = (defaults.cpg_methylation_rate - 0.75).abs();
    let conversion_delta = (defaults.conversion_rate - 0.98).abs();

    assert!(cpg_delta < f64::EPSILON);
    assert!(conversion_delta < f64::EPSILON);
}
/// With no mode given, the resolved configuration is disabled.
#[test]
fn test_methylation_args_resolve_disabled() {
    let no_mode = MethylationArgs {
        methylation_mode: None,
        cpg_methylation_rate: 0.75,
        conversion_rate: 0.98,
    };

    let resolved = no_mode.resolve();

    assert_eq!(resolved.mode, MethylationMode::Disabled);
}
/// EM-Seq resolves to the EM-Seq mode and keeps both rates.
#[test]
fn test_methylation_args_resolve_emseq() {
    let emseq = MethylationArgs {
        methylation_mode: Some(MethylationModeArg::EmSeq),
        cpg_methylation_rate: 0.75,
        conversion_rate: 0.98,
    };

    let resolved = emseq.resolve();

    assert_eq!(resolved.mode, MethylationMode::EmSeq);
    assert!((resolved.cpg_methylation_rate - 0.75).abs() < f64::EPSILON);
    assert!((resolved.conversion_rate - 0.98).abs() < f64::EPSILON);
}
/// TAPs resolves to the TAPs mode.
#[test]
fn test_methylation_args_resolve_taps() {
    let taps = MethylationArgs {
        methylation_mode: Some(MethylationModeArg::Taps),
        cpg_methylation_rate: 0.75,
        conversion_rate: 0.98,
    };

    let resolved = taps.resolve();

    assert_eq!(resolved.mode, MethylationMode::Taps);
}
/// Both interval endpoints and an interior value pass validation.
#[test]
fn test_methylation_args_validate_valid_rates() {
    let acceptable = [0.0, 0.5, 1.0];
    for rate in acceptable {
        let candidate = MethylationArgs {
            methylation_mode: None,
            cpg_methylation_rate: rate,
            conversion_rate: rate,
        };
        let outcome = candidate.validate();
        assert!(outcome.is_ok(), "rate {rate} should be valid");
    }
}
/// Out-of-range and non-finite CpG rates fail validation.
#[rstest]
#[case(-0.1)]
#[case(1.1)]
#[case(f64::NAN)]
#[case(f64::INFINITY)]
#[case(f64::NEG_INFINITY)]
fn test_methylation_args_validate_invalid_cpg_rate(#[case] rate: f64) {
    let candidate = MethylationArgs {
        methylation_mode: None,
        cpg_methylation_rate: rate,
        conversion_rate: 0.98,
    };
    let outcome = candidate.validate();
    assert!(outcome.is_err(), "cpg rate {rate} should be invalid");
}
/// Out-of-range and non-finite conversion rates fail validation.
///
/// Covers the same set of bad values as the CpG-rate test, including negative infinity.
#[rstest]
#[case(-0.1)]
#[case(1.1)]
#[case(f64::NAN)]
#[case(f64::INFINITY)]
#[case(f64::NEG_INFINITY)]
fn test_methylation_args_validate_invalid_conversion_rate(#[case] rate: f64) {
    let args = MethylationArgs {
        methylation_mode: None,
        cpg_methylation_rate: 0.75,
        conversion_rate: rate,
    };
    assert!(args.validate().is_err(), "conversion rate {rate} should be invalid");
}
/// Test helper: runs `apply_methylation_conversion` with a freshly seeded RNG and an ad-hoc
/// configuration built from the given mode and rates.
#[allow(clippy::too_many_arguments)]
fn convert(
    seq: &mut [u8],
    ref_seq: &[u8],
    ref_offset: usize,
    is_top: bool,
    mode: MethylationMode,
    cpg_rate: f64,
    conv_rate: f64,
    seed: u64,
) {
    let mut seeded_rng = create_rng(Some(seed));
    let settings =
        MethylationConfig { mode, cpg_methylation_rate: cpg_rate, conversion_rate: conv_rate };
    apply_methylation_conversion(seq, ref_seq, ref_offset, is_top, &settings, &mut seeded_rng);
}
/// EM-Seq: fully methylated CpG cytosines are left as `C`.
#[test]
fn test_emseq_cpg_all_methylated_no_conversion() {
    let ref_seq = b"ACGTACGT";
    let mut bases = ref_seq.to_vec();

    convert(&mut bases, ref_seq, 0, true, MethylationMode::EmSeq, 1.0, 1.0, 42);

    for cpg_index in [1, 5] {
        assert_eq!(bases[cpg_index], b'C', "CpG C should be protected when methylated");
    }
}
/// EM-Seq: unmethylated CpG cytosines all become `T` at full conversion rate.
#[test]
fn test_emseq_cpg_all_unmethylated_full_conversion() {
    let ref_seq = b"ACGTACGT";
    let mut bases = ref_seq.to_vec();

    convert(&mut bases, ref_seq, 0, true, MethylationMode::EmSeq, 0.0, 1.0, 42);

    for cpg_index in [1, 5] {
        assert_eq!(bases[cpg_index], b'T', "unmethylated CpG C should convert to T");
    }
}
/// EM-Seq: cytosines outside CpG context convert regardless of the CpG methylation rate.
#[test]
fn test_emseq_non_cpg_c_always_converts() {
    let ref_seq = b"ACCTA";
    let mut bases = ref_seq.to_vec();

    convert(&mut bases, ref_seq, 0, true, MethylationMode::EmSeq, 0.75, 1.0, 42);

    for c_index in [1, 2] {
        assert_eq!(bases[c_index], b'T', "non-CpG C should convert to T in EM-Seq");
    }
}
/// TAPs: fully methylated CpG cytosines all become `T` at full conversion rate.
#[test]
fn test_taps_cpg_all_methylated_full_conversion() {
    let ref_seq = b"ACGTACGT";
    let mut bases = ref_seq.to_vec();

    convert(&mut bases, ref_seq, 0, true, MethylationMode::Taps, 1.0, 1.0, 42);

    for cpg_index in [1, 5] {
        assert_eq!(bases[cpg_index], b'T', "methylated CpG C should convert in TAPs");
    }
}
/// TAPs: unmethylated CpG cytosines are left as `C`.
#[test]
fn test_taps_cpg_all_unmethylated_no_conversion() {
    let ref_seq = b"ACGTACGT";
    let mut bases = ref_seq.to_vec();

    convert(&mut bases, ref_seq, 0, true, MethylationMode::Taps, 0.0, 1.0, 42);

    for cpg_index in [1, 5] {
        assert_eq!(bases[cpg_index], b'C', "unmethylated CpG C should not convert in TAPs");
    }
}
/// TAPs, top strand: both Cs in `ACCTA` (neither followed by G) must stay C.
#[test]
fn test_taps_non_cpg_c_never_converts() {
    let reference = b"ACCTA";
    let mut bases = reference.to_vec();
    convert(&mut bases, reference, 0, true, MethylationMode::Taps, 0.75, 1.0, 42);
    for c_pos in [1_usize, 2] {
        assert_eq!(bases[c_pos], b'C', "non-CpG C should not convert in TAPs");
    }
}
/// EM-Seq with `is_top = false` and CpG methylation rate 0.0: the G at index 2 of `ACGT`
/// must be read as A.
#[test]
fn test_bottom_strand_emseq_converts_g_to_a() {
    let reference = b"ACGT";
    let g_pos = 2_usize;
    let mut bases = reference.to_vec();
    convert(&mut bases, reference, 0, false, MethylationMode::EmSeq, 0.0, 1.0, 42);
    assert_eq!(bases[g_pos], b'A', "bottom strand unmethylated CpG G should convert to A");
}
/// TAPs with `is_top = false` and CpG methylation rate 1.0: the G at index 2 of `ACGT`
/// must be read as A.
#[test]
fn test_bottom_strand_taps_converts_g_to_a() {
    let reference = b"ACGT";
    let g_pos = 2_usize;
    let mut bases = reference.to_vec();
    convert(&mut bases, reference, 0, false, MethylationMode::Taps, 1.0, 1.0, 42);
    assert_eq!(
        bases[g_pos],
        b'A',
        "bottom strand methylated CpG G should convert to A in TAPs"
    );
}
/// Top-strand EM-Seq on a sequence with no C (`AGTAGT`) must return the read unchanged.
#[test]
fn test_non_target_bases_unchanged_top_strand() {
    let reference = b"AGTAGT";
    let mut bases = reference.to_vec();
    convert(&mut bases, reference, 0, true, MethylationMode::EmSeq, 0.0, 1.0, 42);
    assert_eq!(bases, reference);
}
/// Bottom-strand EM-Seq on a sequence with no G (`ACTACT`) must return the read unchanged.
#[test]
fn test_non_target_bases_unchanged_bottom_strand() {
    let reference = b"ACTACT";
    let mut bases = reference.to_vec();
    convert(&mut bases, reference, 0, false, MethylationMode::EmSeq, 0.0, 1.0, 42);
    assert_eq!(bases, reference);
}
/// `MethylationMode::Disabled` with CpG methylation rate 0.0 and conversion rate 1.0 must
/// return the read identical to the reference.
#[test]
fn test_disabled_mode_no_conversion() {
    let reference = b"ACGTACGT";
    let mut bases = reference.to_vec();
    convert(&mut bases, reference, 0, true, MethylationMode::Disabled, 0.0, 1.0, 42);
    assert_eq!(bases, reference, "Disabled mode should not modify any bases");
}
/// An empty read against an empty reference must complete and leave the read empty.
#[test]
fn test_empty_sequence() {
    let empty_reference: &[u8] = b"";
    let mut read = Vec::<u8>::new();
    convert(&mut read, empty_reference, 0, true, MethylationMode::EmSeq, 0.75, 0.98, 42);
    assert!(read.is_empty());
}
/// With `ref_offset = 2`, the read `CGT` lines up with the `CGT` inside `AACGTAA`; its
/// leading C must be read as T under EM-Seq with CpG methylation rate 0.0.
#[test]
fn test_ref_offset_nonzero() {
    let reference = b"AACGTAA";
    let read_start = 2;
    let mut bases = b"CGT".to_vec();
    convert(&mut bases, reference, read_start, true, MethylationMode::EmSeq, 0.0, 1.0, 42);
    assert_eq!(bases[0], b'T', "C at ref_offset=2 (CpG) should convert");
}
/// EM-Seq with conversion rate 0.0: both Cs in `ACCTA` must stay C.
#[test]
fn test_conversion_rate_zero_no_conversion() {
    let reference = b"ACCTA";
    let mut bases = reference.to_vec();
    convert(&mut bases, reference, 0, true, MethylationMode::EmSeq, 0.0, 0.0, 42);
    for c_pos in [1_usize, 2] {
        assert_eq!(bases[c_pos], b'C', "conversion_rate=0 should prevent conversion");
    }
}
/// Statistical check: over 10,000 seeds (0..10_000), EM-Seq on `CG` with CpG methylation
/// rate 0.5 and conversion rate 1.0 must convert the C in a fraction within 0.05 of 0.5.
#[test]
fn test_probabilistic_emseq_cpg_partial_methylation() {
    let reference = b"CG";
    let trials: u64 = 10_000;
    // One independent run per seed; count the runs in which the C was read as T.
    let converted = (0..trials)
        .filter(|&seed| {
            let mut bases = reference.to_vec();
            convert(&mut bases, reference, 0, true, MethylationMode::EmSeq, 0.5, 1.0, seed);
            bases[0] == b'T'
        })
        .count();
    let fraction = converted as f64 / trials as f64;
    assert!(
        (fraction - 0.5).abs() < 0.05,
        "Expected ~50% conversion at CpG with methylation_rate=0.5, got {fraction:.3}"
    );
}
/// Statistical check: over 10,000 seeds (0..10_000), EM-Seq on `ACT` with conversion rate
/// 0.5 must convert the C at index 1 in a fraction within 0.05 of 0.5.
#[test]
fn test_probabilistic_emseq_non_cpg_partial_conversion_rate() {
    let reference = b"ACT";
    let trials: u64 = 10_000;
    // One independent run per seed; count the runs in which the C was read as T.
    let converted = (0..trials)
        .filter(|&seed| {
            let mut bases = reference.to_vec();
            convert(&mut bases, reference, 0, true, MethylationMode::EmSeq, 0.75, 0.5, seed);
            bases[1] == b'T'
        })
        .count();
    let fraction = converted as f64 / trials as f64;
    assert!(
        (fraction - 0.5).abs() < 0.05,
        "Expected ~50% conversion with conversion_rate=0.5, got {fraction:.3}"
    );
}
/// EM-Seq with conversion rate 0.0 on an alternating C/A sequence must return the read
/// identical to the reference.
#[test]
fn test_conversion_rate_zero_leaves_bases_unchanged() {
    let reference = b"CACACACACACACACAC";
    let mut bases = reference.to_vec();
    convert(&mut bases, reference, 0, true, MethylationMode::EmSeq, 0.0, 0.0, 42);
    assert_eq!(bases, reference, "conversion_rate=0 should leave all bases unchanged");
}
/// EM-Seq with conversion rate 1.0 on an alternating C/A sequence: every reference C must
/// be read as T and every other base must equal the reference.
#[test]
fn test_conversion_rate_one_converts_all_targets() {
    let reference = b"CACACACACACACACAC";
    let mut bases = reference.to_vec();
    convert(&mut bases, reference, 0, true, MethylationMode::EmSeq, 0.0, 1.0, 42);
    // The read is a same-length copy of the reference, so pairing by position covers it all.
    for (i, (&observed, &original)) in bases.iter().zip(reference.iter()).enumerate() {
        match original {
            b'C' => {
                assert_eq!(observed, b'T', "position {i}: C should be converted with rate=1.0");
            }
            _ => {
                assert_eq!(observed, original, "position {i}: non-C should be unchanged");
            }
        }
    }
}
/// `MethylationMode::Disabled` with CpG methylation rate 0.75 and conversion rate 1.0 must
/// return the read identical to the reference.
#[test]
fn test_disabled_mode_never_converts() {
    let reference = b"CACGTCACGTCACGT";
    let mut bases = reference.to_vec();
    convert(&mut bases, reference, 0, true, MethylationMode::Disabled, 0.75, 1.0, 42);
    assert_eq!(bases, reference, "Disabled mode should never modify bases");
}
/// Writes a four-record FASTA to a new temporary file, flushes it, and returns the handle.
///
/// Records, in order: `chr1` (2000 bases), `chr2` (1500 bases), `chr3` (1800 bases), and
/// `short_contig` (500 bases). Each sequence is a 4-base unit repeated and written on a
/// single line.
fn write_multi_contig_fasta() -> NamedTempFile {
    // (record name, repeat unit, number of repeats)
    let records = [
        ("chr1", b"ACGT", 500_usize),
        ("chr2", b"CCGG", 375),
        ("chr3", b"AATT", 450),
        ("short_contig", b"ACGT", 125),
    ];
    let mut fasta = NamedTempFile::new().unwrap();
    for (name, unit, copies) in records {
        writeln!(fasta, ">{name}").unwrap();
        fasta.write_all(&unit.repeat(copies)).unwrap();
        writeln!(fasta).unwrap();
    }
    fasta.flush().unwrap();
    fasta
}
/// The header built from the multi-contig FASTA must list exactly `chr1`, `chr2`, and
/// `chr3`; `short_contig` must be absent.
#[test]
fn test_build_bam_header_has_all_contigs() {
    let fasta_file = write_multi_contig_fasta();
    let genome = ReferenceGenome::load(fasta_file.path()).unwrap();
    let header = genome.build_bam_header();
    let contig_keys: Vec<_> = header.reference_sequences().keys().collect();
    assert_eq!(contig_keys.len(), 3, "short_contig should be excluded");
    // Decode each key as UTF-8 so it can be compared against plain string literals.
    let mut names: Vec<&str> = Vec::with_capacity(contig_keys.len());
    for key in &contig_keys {
        names.push(std::str::from_utf8(key.as_ref()).unwrap());
    }
    assert!(names.contains(&"chr1"));
    assert!(names.contains(&"chr2"));
    assert!(names.contains(&"chr3"));
    assert!(!names.contains(&"short_contig"));
}
/// The header must record lengths 2000, 1500, and 1800 for `chr1`, `chr2`, and `chr3`.
#[test]
fn test_build_bam_header_contig_lengths() {
    let fasta_file = write_multi_contig_fasta();
    let genome = ReferenceGenome::load(fasta_file.path()).unwrap();
    let header = genome.build_bam_header();
    let ref_seqs = header.reference_sequences();
    let length_of = |name: &str| -> usize {
        ref_seqs.get(&bstr::BString::from(name)).unwrap().length().get()
    };
    // All three lookups run, in order, before any length is compared.
    let [chr1_len, chr2_len, chr3_len] = ["chr1", "chr2", "chr3"].map(length_of);
    assert_eq!(chr1_len, 2000);
    assert_eq!(chr2_len, 1500);
    assert_eq!(chr3_len, 1800);
}
/// Loading the multi-contig FASTA must report three chromosomes.
#[test]
fn test_num_chromosomes() {
    let fasta_file = write_multi_contig_fasta();
    let loaded = ReferenceGenome::load(fasta_file.path()).unwrap();
    let chromosome_count = loaded.num_chromosomes();
    assert_eq!(chromosome_count, 3);
}
/// Chromosome indices 0, 1, and 2 must report lengths 2000, 1500, and 1800.
#[test]
fn test_chromosome_length() {
    let fasta_file = write_multi_contig_fasta();
    let genome = ReferenceGenome::load(fasta_file.path()).unwrap();
    // (chromosome index, expected length)
    for (chrom_idx, expected_len) in [(0, 2000), (1, 1500), (2, 1800)] {
        assert_eq!(genome.chromosome_length(chrom_idx), expected_len);
    }
}
/// Sampling 20 positions (seed 42) must yield 20 entries, each with a valid chromosome
/// index and a position below that chromosome's length.
#[test]
fn test_sample_positions_count_and_bounds() {
    let fasta_file = write_multi_contig_fasta();
    let genome = ReferenceGenome::load(fasta_file.path()).unwrap();
    let requested = 20;
    let mut rng = create_rng(Some(42));
    let positions = genome.sample_positions(requested, &mut rng);
    assert_eq!(positions.len(), 20);
    positions.iter().for_each(|(chrom_idx, local_pos)| {
        assert!(*chrom_idx < genome.num_chromosomes());
        assert!(*local_pos < genome.chromosome_length(*chrom_idx));
    });
}
/// Two RNGs created from the same seed (99) must yield identical samples of 10 positions.
#[test]
fn test_sample_positions_deterministic() {
    let fasta_file = write_multi_contig_fasta();
    let genome = ReferenceGenome::load(fasta_file.path()).unwrap();
    let shared_seed = 99;
    let draws = 10;
    let mut first_rng = create_rng(Some(shared_seed));
    let mut second_rng = create_rng(Some(shared_seed));
    let first_sample = genome.sample_positions(draws, &mut first_rng);
    let second_sample = genome.sample_positions(draws, &mut second_rng);
    assert_eq!(first_sample, second_sample);
}
/// Sampling 100 positions (seed 7) must touch at least two distinct chromosome indices.
#[test]
fn test_sample_positions_spans_chromosomes() {
    let fasta_file = write_multi_contig_fasta();
    let genome = ReferenceGenome::load(fasta_file.path()).unwrap();
    let mut rng = create_rng(Some(7));
    let positions = genome.sample_positions(100, &mut rng);
    // Collect the distinct chromosome indices seen across all sampled positions.
    let mut seen_chroms = std::collections::HashSet::<usize>::new();
    for (chrom_idx, _) in &positions {
        seen_chroms.insert(*chrom_idx);
    }
    let distinct = seen_chroms.len();
    assert!(
        distinct >= 2,
        "100 positions should span at least 2 chromosomes, got {}",
        distinct
    );
}
}