use super::sort::TemplateCoordKey;
use clap::Args;
use rand::{Rng, RngExt};
use std::path::PathBuf;
// Command-line options carrying the seed, molecule count, read length and UMI length.
// Comments in this struct are plain `//` on purpose: clap's derive turns `///` doc
// comments into `--help` text, so switching to `///` would change the CLI output.
#[derive(Args, Debug, Clone)]
pub struct SimulationCommon {
// `--seed <u64>`: optional; stays `None` when the flag is not given.
// NOTE(review): presumably seeds the simulation RNG — confirm at the call site.
#[arg(long = "seed")]
pub seed: Option<u64>,
// `-n` / `--num-molecules`: defaults to 1000.
#[arg(short = 'n', long = "num-molecules", default_value = "1000")]
pub num_molecules: usize,
// `-l` / `--read-length`: defaults to 150.
// NOTE(review): unit is presumably bases — confirm.
#[arg(short = 'l', long = "read-length", default_value = "150")]
pub read_length: usize,
// `-u` / `--umi-length`: defaults to 8.
#[arg(short = 'u', long = "umi-length", default_value = "8")]
pub umi_length: usize,
}
// Command-line options that `QualityArgs::to_quality_model` and
// `QualityArgs::to_quality_bias` turn into the simulator's quality-score models.
// Comments in this struct are plain `//` on purpose: clap's derive turns `///` doc
// comments into `--help` text, so switching to `///` would change the CLI output.
#[derive(Args, Debug, Clone)]
pub struct QualityArgs {
// `--warmup-bases`: defaults to 10. Passed as the first model argument.
#[arg(long = "warmup-bases", default_value = "10")]
pub warmup_bases: usize,
// `--warmup-quality`: defaults to 25. The tests read it back as `model.warmup_start`.
#[arg(long = "warmup-quality", default_value = "25")]
pub warmup_quality: u8,
// `--peak-quality`: defaults to 37.
#[arg(long = "peak-quality", default_value = "37")]
pub peak_quality: u8,
// `--decay-start`: defaults to 100.
// NOTE(review): presumably the read position at which quality starts to decay — confirm.
#[arg(long = "decay-start", default_value = "100")]
pub decay_start: usize,
// `--decay-rate`: defaults to 0.08. Not range-checked here.
#[arg(long = "decay-rate", default_value = "0.08")]
pub decay_rate: f64,
// `--quality-noise`: defaults to 2.0. Parsed by `parse_noise_stddev`, which rejects
// non-numeric, non-finite and negative values at argument-parsing time.
#[arg(long = "quality-noise", default_value = "2.0", value_parser = parse_noise_stddev)]
pub quality_noise: f64,
// `--r2-quality-offset`: defaults to -2. `allow_hyphen_values` lets a value that starts
// with `-` (such as `-2`) be accepted as the option's value.
#[arg(long = "r2-quality-offset", default_value = "-2", allow_hyphen_values = true)]
pub r2_quality_offset: i8,
}
/// clap value parser for `--quality-noise`.
///
/// Accepts any string that parses as a finite `f64` greater than or equal to zero.
///
/// # Errors
///
/// * `"invalid float: …"` when `s` is not a float literal.
/// * `"quality-noise must be finite and >= 0.0, got …"` when the parsed value is NaN,
///   infinite, or negative.
fn parse_noise_stddev(s: &str) -> Result<f64, String> {
    match s.parse::<f64>() {
        Err(parse_err) => Err(format!("invalid float: {parse_err}")),
        // NaN fails `is_finite`, so it is rejected together with the infinities.
        Ok(stddev) if stddev.is_finite() && stddev >= 0.0 => Ok(stddev),
        Ok(stddev) => Err(format!("quality-noise must be finite and >= 0.0, got {stddev}")),
    }
}
impl QualityArgs {
    /// Builds a `PositionQualityModel` from the parsed quality flags.
    ///
    /// Values are forwarded in constructor order: warm-up bases, warm-up quality, peak
    /// quality, decay start, decay rate, the fixed literal `2`, and the noise standard
    /// deviation. `r2_quality_offset` is not used here.
    pub fn to_quality_model(&self) -> crate::simulate::PositionQualityModel {
        // Every field is `Copy`, so destructuring through `*self` copies the values out.
        let Self {
            warmup_bases,
            warmup_quality,
            peak_quality,
            decay_start,
            decay_rate,
            quality_noise,
            ..
        } = *self;

        crate::simulate::PositionQualityModel::new(
            warmup_bases,
            warmup_quality,
            peak_quality,
            decay_start,
            decay_rate,
            // NOTE(review): hard-coded sixth argument with no CLI flag; presumably a
            // minimum-quality floor — confirm against `PositionQualityModel::new`.
            2,
            quality_noise,
        )
    }

    /// Builds a `ReadPairQualityBias` from `r2_quality_offset`.
    pub fn to_quality_bias(&self) -> crate::simulate::ReadPairQualityBias {
        let r2_offset = self.r2_quality_offset;
        crate::simulate::ReadPairQualityBias::new(r2_offset)
    }
}
// Command-line options that `InsertSizeArgs::to_insert_size_model` forwards to
// `InsertSizeModel::new` (mean, standard deviation, minimum, maximum).
// Comments in this struct are plain `//` on purpose: clap's derive turns `///` doc
// comments into `--help` text, so switching to `///` would change the CLI output.
#[derive(Args, Debug, Clone)]
pub struct InsertSizeArgs {
// `--insert-size-mean`: defaults to 300.0.
#[arg(long = "insert-size-mean", default_value = "300.0")]
pub insert_size_mean: f64,
// `--insert-size-stddev`: defaults to 50.0.
#[arg(long = "insert-size-stddev", default_value = "50.0")]
pub insert_size_stddev: f64,
// `--insert-size-min`: defaults to 50. The tests expect samples to stay at or above it.
#[arg(long = "insert-size-min", default_value = "50")]
pub insert_size_min: usize,
// `--insert-size-max`: defaults to 800. The tests expect samples to stay at or below it.
// NOTE(review): nothing in this file checks that min <= max — confirm the model handles it.
#[arg(long = "insert-size-max", default_value = "800")]
pub insert_size_max: usize,
}
impl InsertSizeArgs {
    /// Builds an `InsertSizeModel` from the parsed flags, passing mean, standard deviation,
    /// minimum and maximum in that order.
    pub fn to_insert_size_model(&self) -> crate::simulate::InsertSizeModel {
        // All four fields are `Copy`; pull them out under shorter local names.
        let Self {
            insert_size_mean: mean,
            insert_size_stddev: stddev,
            insert_size_min: lower,
            insert_size_max: upper,
        } = *self;

        crate::simulate::InsertSizeModel::new(mean, stddev, lower, upper)
    }
}
// Command-line options describing the family-size distribution; consumed by
// `FamilySizeArgs::to_family_size_distribution`.
// Comments in this struct are plain `//` on purpose: clap's derive turns `///` doc
// comments into `--help` text, so switching to `///` would change the CLI output.
#[derive(Args, Debug, Clone)]
pub struct FamilySizeArgs {
// `--family-size-dist`: defaults to "lognormal". The exact strings "lognormal" and
// "negbin" select a built-in distribution; any other value is treated as the path of
// a histogram file.
#[arg(long = "family-size-dist", default_value = "lognormal")]
pub family_size_dist: String,
// `--family-size-mean`: defaults to 3.0. Read only for "lognormal".
#[arg(long = "family-size-mean", default_value = "3.0")]
pub family_size_mean: f64,
// `--family-size-stddev`: defaults to 2.0. Read only for "lognormal".
#[arg(long = "family-size-stddev", default_value = "2.0")]
pub family_size_stddev: f64,
// `--family-size-r`: defaults to 2.0. Read only for "negbin".
#[arg(long = "family-size-r", default_value = "2.0")]
pub family_size_r: f64,
// `--family-size-p`: defaults to 0.5. Read only for "negbin".
#[arg(long = "family-size-p", default_value = "0.5")]
pub family_size_p: f64,
// `--min-family-size`: defaults to 1. Not read by `to_family_size_distribution`.
// NOTE(review): presumably handed to `FamilySizeDistribution::sample` by the caller — confirm.
#[arg(long = "min-family-size", default_value = "1")]
pub min_family_size: usize,
}
impl FamilySizeArgs {
    /// Builds the family-size distribution selected by `--family-size-dist`.
    ///
    /// * `"lognormal"` — log-normal built from `family_size_mean` and `family_size_stddev`.
    /// * `"negbin"` — negative binomial built from `family_size_r` and `family_size_p`.
    /// * anything else — interpreted as the path of a histogram file.
    ///
    /// Keyword matching is exact and case-sensitive. `min_family_size` is not consulted.
    ///
    /// # Errors
    ///
    /// Returns the error from `FamilySizeDistribution::from_histogram`, wrapped in a
    /// message that names the accepted keywords. Without that wrapper a mistyped keyword
    /// (for example `"LogNormal"`) surfaces only as a bare histogram-loading error.
    pub fn to_family_size_distribution(
        &self,
    ) -> anyhow::Result<crate::simulate::FamilySizeDistribution> {
        match self.family_size_dist.as_str() {
            "lognormal" => Ok(crate::simulate::FamilySizeDistribution::log_normal(
                self.family_size_mean,
                self.family_size_stddev,
            )),
            "negbin" => Ok(crate::simulate::FamilySizeDistribution::negative_binomial(
                self.family_size_r,
                self.family_size_p,
            )),
            path => {
                let loaded = crate::simulate::FamilySizeDistribution::from_histogram(path);
                // Fully-qualified call so the trait does not need a new `use` in this file.
                anyhow::Context::with_context(loaded, || {
                    format!(
                        "--family-size-dist '{path}' is neither 'lognormal' nor 'negbin' \
                         and could not be loaded as a histogram file"
                    )
                })
            }
        }
    }
}
// Command-line options that `StrandBiasArgs::to_strand_bias_model` forwards to
// `StrandBiasModel::new`.
// Comments in this struct are plain `//` on purpose: clap's derive turns `///` doc
// comments into `--help` text, so switching to `///` would change the CLI output.
#[derive(Args, Debug, Clone)]
pub struct StrandBiasArgs {
// `--strand-alpha`: defaults to 5.0. In the tests, alpha > beta pushes the sampled
// A-strand fraction above one half.
// NOTE(review): presumably a Beta-distribution shape parameter — confirm in `StrandBiasModel`.
#[arg(long = "strand-alpha", default_value = "5.0")]
pub strand_alpha: f64,
// `--strand-beta`: defaults to 5.0. In the tests, beta > alpha pushes the sampled
// A-strand fraction below one half.
#[arg(long = "strand-beta", default_value = "5.0")]
pub strand_beta: f64,
}
impl StrandBiasArgs {
    /// Builds a `StrandBiasModel` from `strand_alpha` and `strand_beta`, in that order.
    pub fn to_strand_bias_model(&self) -> crate::simulate::StrandBiasModel {
        let (alpha, beta) = (self.strand_alpha, self.strand_beta);
        crate::simulate::StrandBiasModel::new(alpha, beta)
    }
}
// Command-line options naming the reference; how they are consumed is not visible in
// this file.
// Comments in this struct are plain `//` on purpose: clap's derive turns `///` doc
// comments into `--help` text, so switching to `///` would change the CLI output.
#[derive(Args, Debug, Clone)]
pub struct ReferenceArgs {
// `-r` / `--reference <path>`: optional; `None` when the flag is not given.
// NOTE(review): presumably a reference sequence file (e.g. FASTA) — confirm at the call site.
#[arg(short = 'r', long = "reference")]
pub reference: Option<PathBuf>,
// `--ref-name`: defaults to "chr1".
// NOTE(review): presumably used when no reference file is supplied — confirm.
#[arg(long = "ref-name", default_value = "chr1")]
pub ref_name: String,
// `--ref-length`: defaults to 250000000.
// NOTE(review): presumably the value passed to `compute_position` as `ref_length` — confirm.
#[arg(long = "ref-length", default_value = "250000000")]
pub ref_length: usize,
}
// Command-line options controlling how molecules are spread over positions; how they
// are consumed is not visible in this file.
// Comments in this struct are plain `//` on purpose: clap's derive turns `///` doc
// comments into `--help` text, so switching to `///` would change the CLI output.
#[derive(Args, Debug, Clone)]
pub struct PositionDistArgs {
// `--num-positions`: optional; `None` when the flag is not given.
// NOTE(review): presumably feeds `compute_position`'s `num_positions` — confirm the
// fallback used when it is `None`.
#[arg(long = "num-positions")]
pub num_positions: Option<usize>,
// `--umis-per-position`: defaults to 1.
#[arg(long = "umis-per-position", default_value = "1")]
pub umis_per_position: usize,
}
/// Returns `len` bases, each picked from `A`, `C`, `G`, `T` by one `random_range(0..4)`
/// draw on `rng`.
///
/// Draws happen in output order, so the result is reproducible for a given RNG state.
/// `len == 0` yields an empty vector and leaves `rng` untouched.
pub(super) fn generate_random_sequence(len: usize, rng: &mut impl Rng) -> Vec<u8> {
    const ALPHABET: &[u8] = b"ACGT";
    (0..len).map(|_| ALPHABET[rng.random_range(0..4)]).collect()
}
/// Forces `seq` to exactly `target_len` bytes.
///
/// A shorter input is extended with random `A`/`C`/`G`/`T` bases (one
/// `random_range(0..4)` draw per added base); a longer input is cut at `target_len`.
/// An input that already has the right length is returned as-is without touching `rng`.
pub(super) fn pad_sequence(mut seq: Vec<u8>, target_len: usize, rng: &mut impl Rng) -> Vec<u8> {
    // Zero when the sequence is already long enough, so the extend below is a no-op.
    let shortfall = target_len.saturating_sub(seq.len());
    seq.extend((0..shortfall).map(|_| b"ACGT"[rng.random_range(0..4)]));
    // Only has an effect when the input was longer than requested.
    seq.truncate(target_len);
    seq
}
/// Maps a molecule id to a position on a reference of length `ref_length`.
///
/// Molecule ids are folded into `num_positions` slots (`mol_id % num_positions`). Slot
/// `i` is placed at `base + floor(i / num_positions * span)`, where
/// `base = min(100, ref_length - 1)` and `span = ref_length - 1000` (both saturating at
/// zero).
///
/// When `num_positions` is zero, or the reference is 1000 or shorter so that `span` is
/// zero, every molecule gets `base`.
#[inline]
pub(super) fn compute_position(mol_id: usize, num_positions: usize, ref_length: usize) -> usize {
    const BASE_OFFSET: usize = 100;
    const END_MARGIN: usize = 1000;

    let base = BASE_OFFSET.min(ref_length.saturating_sub(1));
    let span = ref_length.saturating_sub(END_MARGIN);
    if num_positions == 0 || span == 0 {
        return base;
    }

    let slot = mol_id % num_positions;
    let fraction = slot as f64 / num_positions as f64;
    // The float-to-integer cast truncates toward zero.
    base + (fraction * span as f64) as usize
}
/// Per-molecule record carrying an id, a seed and a sort key.
///
/// The `Ord`, `PartialOrd` and `PartialEq` impls in this file compare `sort_key` only,
/// so two records with different `mol_id` or `seed` still compare equal when their keys
/// match.
#[derive(Debug)]
pub(super) struct MoleculeInfo {
/// Identifier of the molecule; ignored by comparisons.
pub mol_id: usize,
/// Seed stored with the molecule; ignored by comparisons.
/// NOTE(review): presumably a per-molecule RNG seed — confirm at the call site.
pub seed: u64,
/// Key that alone determines ordering and equality of records.
pub sort_key: TemplateCoordKey,
}
/// Total order over molecules, taken entirely from `sort_key`; `mol_id` and `seed` do
/// not participate.
impl Ord for MoleculeInfo {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let Self { sort_key: mine, .. } = self;
        let Self { sort_key: theirs, .. } = other;
        mine.cmp(theirs)
    }
}
/// Partial order that always defers to the total order, so it never returns `None`.
impl PartialOrd for MoleculeInfo {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
/// Equality on `sort_key` alone: records that differ only in `mol_id` or `seed` are equal.
impl PartialEq for MoleculeInfo {
    fn eq(&self, other: &Self) -> bool {
        let Self { sort_key: mine, .. } = self;
        let Self { sort_key: theirs, .. } = other;
        mine == theirs
    }
}
// Marker impl required by `Ord`; equality itself is defined by the `PartialEq` impl,
// which compares `sort_key` only.
impl Eq for MoleculeInfo {}
// Unit tests for the argument-to-model conversions, the noise parser, `compute_position`
// and the `MoleculeInfo` ordering. RNG-dependent tests use a fixed seed (42).
#[cfg(test)]
mod tests {
use super::*;
use crate::simulate::create_rng;
use rstest::rstest;
use std::io::Write;
use tempfile::NamedTempFile;
// The CLI default values are forwarded unchanged into the model; `warmup_quality`
// is read back as `model.warmup_start`.
#[test]
fn test_quality_args_default_model() {
let args = QualityArgs {
warmup_bases: 10,
warmup_quality: 25,
peak_quality: 37,
decay_start: 100,
decay_rate: 0.08,
quality_noise: 2.0,
r2_quality_offset: -2,
};
let model = args.to_quality_model();
assert_eq!(model.warmup_bases, 10);
assert_eq!(model.warmup_start, 25);
assert_eq!(model.peak_quality, 37);
assert_eq!(model.decay_start, 100);
assert!((model.decay_rate - 0.08).abs() < f64::EPSILON);
}
// Non-default values reach both the quality model and the R2 bias (`r2_offset`).
#[test]
fn test_quality_args_custom_values() {
let args = QualityArgs {
warmup_bases: 5,
warmup_quality: 20,
peak_quality: 40,
decay_start: 80,
decay_rate: 0.1,
quality_noise: 1.5,
r2_quality_offset: -3,
};
let model = args.to_quality_model();
assert_eq!(model.warmup_bases, 5);
assert_eq!(model.peak_quality, 40);
let bias = args.to_quality_bias();
assert_eq!(bias.r2_offset, -3);
}
// A positive offset (+3) raises the quality: `apply(30, true)` must give 33.
// NOTE(review): the `true` argument presumably marks the read as R2 — confirm in
// `ReadPairQualityBias::apply`.
#[test]
fn test_quality_bias_positive_offset() {
let args = QualityArgs {
warmup_bases: 10,
warmup_quality: 25,
peak_quality: 37,
decay_start: 100,
decay_rate: 0.08,
quality_noise: 2.0,
r2_quality_offset: 3, };
let bias = args.to_quality_bias();
assert_eq!(bias.apply(30, true), 33);
}
// Default insert-size flags are carried over to the model's `mean`, `min` and `max`.
#[test]
fn test_insert_size_args_default() {
let args = InsertSizeArgs {
insert_size_mean: 300.0,
insert_size_stddev: 50.0,
insert_size_min: 50,
insert_size_max: 800,
};
let model = args.to_insert_size_model();
assert!((model.mean - 300.0).abs() < f64::EPSILON);
assert_eq!(model.min, 50);
assert_eq!(model.max, 800);
}
// With a tight [180, 220] window, 100 seeded samples must all stay inside it.
#[test]
fn test_insert_size_args_narrow_range() {
let args = InsertSizeArgs {
insert_size_mean: 200.0,
insert_size_stddev: 10.0,
insert_size_min: 180,
insert_size_max: 220,
};
let model = args.to_insert_size_model();
let mut rng = create_rng(Some(42));
for _ in 0..100 {
let size = model.sample(&mut rng);
assert!((180..=220).contains(&size));
}
}
// The "lognormal" keyword builds a distribution; a sample drawn with minimum 1 is >= 1.
#[test]
fn test_family_size_args_lognormal() {
let args = FamilySizeArgs {
family_size_dist: "lognormal".to_string(),
family_size_mean: 3.0,
family_size_stddev: 2.0,
family_size_r: 2.0,
family_size_p: 0.5,
min_family_size: 1,
};
let dist =
args.to_family_size_distribution().expect("lognormal distribution should be created");
let mut rng = create_rng(Some(42));
let size = dist.sample(&mut rng, 1);
assert!(size >= 1);
}
// The "negbin" keyword builds a distribution; a sample drawn with minimum 1 is >= 1.
#[test]
fn test_family_size_args_negbin() {
let args = FamilySizeArgs {
family_size_dist: "negbin".to_string(),
family_size_mean: 3.0,
family_size_stddev: 2.0,
family_size_r: 2.0,
family_size_p: 0.5,
min_family_size: 1,
};
let dist =
args.to_family_size_distribution().expect("negbin distribution should be created");
let mut rng = create_rng(Some(42));
let size = dist.sample(&mut rng, 1);
assert!(size >= 1);
}
// Any other `family_size_dist` value is treated as a histogram path. The temp file
// holds a `family_size<TAB>count` header plus rows for sizes 1..=3, so every sample
// must fall in that range.
#[test]
fn test_family_size_args_from_histogram() -> anyhow::Result<()> {
let mut temp = NamedTempFile::new()?;
writeln!(temp, "family_size\tcount")?;
writeln!(temp, "1\t50")?;
writeln!(temp, "2\t30")?;
writeln!(temp, "3\t20")?;
// Flush so the rows are on disk before the path is re-opened by the loader.
temp.flush()?;
let args = FamilySizeArgs {
family_size_dist: temp.path().to_string_lossy().to_string(),
family_size_mean: 3.0,
family_size_stddev: 2.0,
family_size_r: 2.0,
family_size_p: 0.5,
min_family_size: 1,
};
let dist = args.to_family_size_distribution()?;
let mut rng = create_rng(Some(42));
for _ in 0..100 {
let size = dist.sample(&mut rng, 1);
assert!((1..=3).contains(&size));
}
Ok(())
}
// A histogram path that does not exist must yield an error rather than a panic.
#[test]
fn test_family_size_args_invalid_histogram() {
let args = FamilySizeArgs {
family_size_dist: "/nonexistent/path/histogram.tsv".to_string(),
family_size_mean: 3.0,
family_size_stddev: 2.0,
family_size_r: 2.0,
family_size_p: 0.5,
min_family_size: 1,
};
let result = args.to_family_size_distribution();
assert!(result.is_err());
}
// alpha == beta: parameters are stored on the model and the mean of 1000 sampled
// A fractions lies within 0.05 of one half.
#[test]
fn test_strand_bias_args_symmetric() {
let args = StrandBiasArgs { strand_alpha: 5.0, strand_beta: 5.0 };
let model = args.to_strand_bias_model();
assert!((model.alpha - 5.0).abs() < f64::EPSILON);
assert!((model.beta - 5.0).abs() < f64::EPSILON);
let mut rng = create_rng(Some(42));
let fractions: Vec<f64> = (0..1000).map(|_| model.sample_a_fraction(&mut rng)).collect();
let mean: f64 = fractions.iter().sum::<f64>() / fractions.len() as f64;
assert!((mean - 0.5).abs() < 0.05);
}
// alpha (8) > beta (2): the mean sampled A fraction must exceed 0.7.
#[test]
fn test_strand_bias_args_a_biased() {
let args = StrandBiasArgs { strand_alpha: 8.0, strand_beta: 2.0 };
let model = args.to_strand_bias_model();
let mut rng = create_rng(Some(42));
let fractions: Vec<f64> = (0..1000).map(|_| model.sample_a_fraction(&mut rng)).collect();
let mean: f64 = fractions.iter().sum::<f64>() / fractions.len() as f64;
assert!(mean > 0.7);
}
// alpha (2) < beta (8): the mean sampled A fraction must stay below 0.3.
#[test]
fn test_strand_bias_args_b_biased() {
let args = StrandBiasArgs { strand_alpha: 2.0, strand_beta: 8.0 };
let model = args.to_strand_bias_model();
let mut rng = create_rng(Some(42));
let fractions: Vec<f64> = (0..1000).map(|_| model.sample_a_fraction(&mut rng)).collect();
let mean: f64 = fractions.iter().sum::<f64>() / fractions.len() as f64;
assert!(mean < 0.3);
}
// A zero-length warm-up is accepted and preserved on the model.
#[test]
fn test_quality_args_zero_warmup() {
let args = QualityArgs {
warmup_bases: 0,
warmup_quality: 25,
peak_quality: 37,
decay_start: 100,
decay_rate: 0.08,
quality_noise: 2.0,
r2_quality_offset: -2,
};
let model = args.to_quality_model();
assert_eq!(model.warmup_bases, 0);
}
// Peak quality 41 with zero noise: generating 50 qualities returns a non-empty result.
#[test]
fn test_quality_args_high_peak_quality() {
let args = QualityArgs {
warmup_bases: 10,
warmup_quality: 30,
peak_quality: 41, decay_start: 100,
decay_rate: 0.08,
quality_noise: 0.0, r2_quality_offset: 0,
};
let model = args.to_quality_model();
let mut rng = create_rng(Some(42));
let quals = model.generate_qualities(50, &mut rng);
assert!(!quals.is_empty());
}
// min == max collapses the insert-size range: every sample must equal 200.
#[test]
fn test_insert_size_args_min_equals_max() {
let args = InsertSizeArgs {
insert_size_mean: 200.0,
insert_size_stddev: 50.0,
insert_size_min: 200,
insert_size_max: 200,
};
let model = args.to_insert_size_model();
let mut rng = create_rng(Some(42));
for _ in 0..10 {
let size = model.sample(&mut rng);
assert_eq!(size, 200);
}
}
// Samples respect a minimum of 5. Note that the bound enforced here is the `5` passed
// to `sample`; `to_family_size_distribution` does not read `min_family_size`.
#[test]
fn test_family_size_args_high_min() {
let args = FamilySizeArgs {
family_size_dist: "lognormal".to_string(),
family_size_mean: 3.0,
family_size_stddev: 2.0,
family_size_r: 2.0,
family_size_p: 0.5,
min_family_size: 5, };
let dist = args
.to_family_size_distribution()
.expect("lognormal distribution with high min should be created");
let mut rng = create_rng(Some(42));
for _ in 0..100 {
let size = dist.sample(&mut rng, 5);
assert!(size >= 5, "Size {size} should be >= min 5");
}
}
// `split_reads` must conserve the total: the two parts always sum to the input.
#[test]
fn test_strand_bias_split_reads() {
let args = StrandBiasArgs { strand_alpha: 5.0, strand_beta: 5.0 };
let model = args.to_strand_bias_model();
let mut rng = create_rng(Some(42));
for total in [0, 1, 2, 5, 10, 100] {
let (a, b) = model.split_reads(total, &mut rng);
assert_eq!(a + b, total, "A ({a}) + B ({b}) should equal total ({total})");
}
}
// Splitting zero reads yields (0, 0).
#[test]
fn test_strand_bias_split_zero_total() {
let args = StrandBiasArgs { strand_alpha: 5.0, strand_beta: 5.0 };
let model = args.to_strand_bias_model();
let mut rng = create_rng(Some(42));
let (a, b) = model.split_reads(0, &mut rng);
assert_eq!(a, 0);
assert_eq!(b, 0);
}
// A single read goes to exactly one side.
#[test]
fn test_strand_bias_split_one_read() {
let args = StrandBiasArgs { strand_alpha: 5.0, strand_beta: 5.0 };
let model = args.to_strand_bias_model();
let mut rng = create_rng(Some(42));
let (a, b) = model.split_reads(1, &mut rng);
assert_eq!(a + b, 1);
assert!(a <= 1 && b <= 1);
}
// `parse_noise_stddev` accepts zero and positive finite values and returns them unchanged.
#[rstest]
#[case("0.0", 0.0)]
#[case("2.5", 2.5)]
#[case("100.0", 100.0)]
fn test_parse_noise_stddev_accepts_valid(#[case] input: &str, #[case] expected: f64) {
let parsed = parse_noise_stddev(input).expect("valid noise stddev should parse");
assert!((parsed - expected).abs() < f64::EPSILON);
}
// `parse_noise_stddev` rejects negative, NaN, infinite and non-numeric input.
#[rstest]
#[case("-0.1")]
#[case("NaN")]
#[case("inf")]
#[case("-inf")]
#[case("abc")]
fn test_parse_noise_stddev_rejects_invalid(#[case] input: &str) {
assert!(parse_noise_stddev(input).is_err(), "input should be rejected: {input}");
}
// A noise value of 0.0 is carried over to `model.noise_stddev`.
#[test]
fn test_quality_args_zero_noise() {
let args = QualityArgs {
warmup_bases: 10,
warmup_quality: 25,
peak_quality: 37,
decay_start: 100,
decay_rate: 0.08,
quality_noise: 0.0, r2_quality_offset: 0,
};
let model = args.to_quality_model();
assert!((model.noise_stddev - 0.0).abs() < f64::EPSILON);
}
// With the default flags, the mean of 1000 seeded samples lies between 280 and 320.
#[test]
fn test_insert_size_sample_distribution() {
let args = InsertSizeArgs {
insert_size_mean: 300.0,
insert_size_stddev: 50.0,
insert_size_min: 50,
insert_size_max: 800,
};
let model = args.to_insert_size_model();
let mut rng = create_rng(Some(42));
let samples: Vec<usize> = (0..1000).map(|_| model.sample(&mut rng)).collect();
let mean: f64 = samples.iter().map(|&s| s as f64).sum::<f64>() / samples.len() as f64;
assert!(mean > 280.0 && mean < 320.0, "Mean {mean} not close to expected 300");
}
// NOTE(review): despite the name, only the lowercase spelling is exercised.
// `to_family_size_distribution` matches the keywords exactly, so a value such as
// "LogNormal" would be treated as a histogram path. Rename the test or add
// mixed-case coverage once the intended behavior is decided.
#[test]
fn test_family_size_distribution_type_case_insensitive() {
let args_lower = FamilySizeArgs {
family_size_dist: "lognormal".to_string(),
family_size_mean: 3.0,
family_size_stddev: 2.0,
family_size_r: 2.0,
family_size_p: 0.5,
min_family_size: 1,
};
let _ = args_lower
.to_family_size_distribution()
.expect("lowercase distribution name should be accepted");
}
// Range checks for `compute_position`; columns are
// (mol_id, num_positions, ref_length, inclusive lower bound, exclusive upper bound).
#[rstest]
// num_positions == 0 takes the fallback path.
#[case(5, 0, 250_000_000, 0, 250_000_000)]
// References of 1000 or shorter have no usable span and also take the fallback path.
#[case(0, 10, 500, 0, 500)]
// The fallback is capped at ref_length - 1, so it stays inside a 50-long reference.
#[case(0, 10, 50, 0, 50)]
// First slot on a reference with a usable span.
#[case(0, 10, 10_000, 0, 10_000)]
// Last slot must land beyond the base offset of 100.
#[case(9, 10, 10_000, 101, 10_000)]
fn test_compute_position(
#[case] mol_id: usize,
#[case] num_positions: usize,
#[case] ref_length: usize,
#[case] min_expected: usize,
#[case] max_expected: usize,
) {
let pos = compute_position(mol_id, num_positions, ref_length);
assert!(
pos >= min_expected && pos < max_expected,
"compute_position({mol_id}, {num_positions}, {ref_length}) = {pos}, expected [{min_expected}, {max_expected})"
);
}
// Test helper: builds a `MoleculeInfo` whose sort key differs only in `tid1` and `pos1`;
// every other key field is zeroed, false or empty, and `seed` is 0.
fn make_molecule_info(mol_id: usize, tid1: i32, pos1: i64) -> MoleculeInfo {
MoleculeInfo {
mol_id,
seed: 0,
sort_key: TemplateCoordKey {
tid1,
tid2: 0,
pos1,
pos2: 0,
neg1: false,
neg2: false,
mid: String::new(),
name: String::new(),
is_upper_of_pair: false,
},
}
}
// Ordering follows the sort key (`tid1` before `pos1`), and records that differ only in
// `mol_id` compare equal.
#[test]
fn test_molecule_info_ordering() {
let a = make_molecule_info(0, 1, 100);
let b = make_molecule_info(1, 1, 200);
let c = make_molecule_info(2, 2, 50);
assert!(a < b);
assert!(b > a);
// A higher `tid1` outranks a lower `pos1`.
assert!(b < c);
assert_eq!(a.partial_cmp(&b), Some(std::cmp::Ordering::Less));
// Same key, different `mol_id`: still equal.
let a2 = make_molecule_info(99, 1, 100);
assert_eq!(a, a2);
assert_eq!(a.cmp(&a2), std::cmp::Ordering::Equal);
}
}