use std::io::Write;
use std::path::PathBuf;
use anyhow::{Result, bail};
use clap::Parser;
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use rand_distr::{Distribution, Geometric};
use super::command::Command;
use super::common::{BedOptions, ReferenceOptions, SeedOptions};
use crate::bed::TargetRegions;
use crate::fasta::Fasta;
use crate::ploidy::PloidyMap;
use crate::seed::resolve_seed;
/// The four unambiguous DNA bases. Reference positions holding any other byte
/// are never mutated, and substituted or inserted bases are drawn from this set.
const BASES: [u8; 4] = [b'A', b'C', b'G', b'T'];
// Command-line arguments of the `mutate` subcommand, which simulates SNPs,
// indels and MNPs along a reference and writes them to a VCF file.
//
// NOTE(review): plain `//` comments are used on purpose. clap's derive macro
// turns `///` doc comments into user-visible help text, so adding them here
// would change what `--help` prints.
#[derive(Parser, Debug)]
#[command(after_long_help = "EXAMPLES:\n \
holodeck mutate -r ref.fa -o muts.vcf --snp-rate 0.001\n \
holodeck mutate -r ref.fa -o muts.vcf --snp-rate 0.001 --ploidy 2 \
--ploidy-override chrX=1 --ploidy-override chrX:10001-2781479=2")]
pub struct Mutate {
// Shared reference options; `reference.reference` is the FASTA path that is
// loaded and mutated.
#[command(flatten)]
pub reference: ReferenceOptions,
// Shared BED options; when `bed.targets` is set, only positions overlapping a
// target region are eligible for mutation.
#[command(flatten)]
pub bed: BedOptions,
// Shared seed options; `seed.seed` is handed to `resolve_seed` together with a
// description of this run.
#[command(flatten)]
pub seed: SeedOptions,
// Path of the VCF file to create.
#[arg(short = 'o', long, value_name = "VCF")]
pub output: PathBuf,
// Per-position probability that a SNP starts at an eligible base. Must be >= 0.
#[arg(long, default_value_t = 0.001, value_name = "FLOAT")]
pub snp_rate: f64,
// Per-position probability that an indel starts at an eligible base. Must be >= 0.
#[arg(long, default_value_t = 0.0001, value_name = "FLOAT")]
pub indel_rate: f64,
// Per-position probability that an MNP starts at an eligible base. Must be >= 0.
// The three rates together must not exceed 1.
#[arg(long, default_value_t = 0.00005, value_name = "FLOAT")]
pub mnp_rate: f64,
// Success probability of the geometric distribution from which indel lengths
// are sampled; must lie strictly between 0 and 1.
#[arg(long, default_value_t = 0.7, value_name = "FLOAT")]
pub indel_length_param: f64,
// Ratio `r` of heterozygous to homozygous-alternate genotypes; a non-haploid
// site is heterozygous with probability `r / (1 + r)`. Must be >= 0.
#[arg(long, default_value_t = 2.0, value_name = "FLOAT")]
pub het_hom_ratio: f64,
// Default ploidy handed to `PloidyMap::new`; must be at least 1.
#[arg(long, default_value_t = 2, value_name = "INT")]
pub ploidy: u8,
// Repeatable ploidy overrides handed verbatim to `PloidyMap::new`. The examples
// above use `CONTIG=N` and `CONTIG:START-END=N`; the exact grammar is defined
// by `PloidyMap` (not visible in this file).
#[arg(long, value_name = "SPEC")]
pub ploidy_override: Vec<String>,
}
impl Command for Mutate {
    /// Runs the subcommand: the arguments are validated first, and the
    /// simulation only starts when validation succeeds.
    fn execute(&self) -> Result<()> {
        self.validate().and_then(|()| self.run_mutation())
    }
}
impl Mutate {
    /// Checks the numeric arguments before any file is opened.
    ///
    /// # Errors
    ///
    /// Fails when a mutation rate is negative, when the combined rate is not
    /// in `[0, 1]` (this also catches NaN rates), when `--indel-length-param`
    /// is not strictly between 0 and 1, when `--het-hom-ratio` is negative or
    /// not finite, or when `--ploidy` is zero.
    fn validate(&self) -> Result<()> {
        let rates = [self.snp_rate, self.indel_rate, self.mnp_rate];
        if rates.iter().any(|&rate| rate < 0.0) {
            bail!("Individual mutation rates must be >= 0");
        }
        let total_rate = self.snp_rate + self.indel_rate + self.mnp_rate;
        if !(0.0..=1.0).contains(&total_rate) {
            bail!("Combined mutation rate must be in [0, 1], got {total_rate}");
        }
        // NaN compares false against everything, so it has to be tested
        // explicitly; otherwise it would only surface later, after the
        // reference has already been loaded.
        let p = self.indel_length_param;
        if p.is_nan() || p <= 0.0 || p >= 1.0 {
            bail!("--indel-length-param must be in (0, 1)");
        }
        // The float parser accepts "nan" and "inf". Either one would turn the
        // heterozygous probability r / (1 + r) into NaN and silently yield
        // only homozygous genotypes, so both are rejected here.
        if !self.het_hom_ratio.is_finite() || self.het_hom_ratio < 0.0 {
            bail!("--het-hom-ratio must be a finite number >= 0");
        }
        if self.ploidy == 0 {
            bail!("--ploidy must be >= 1");
        }
        Ok(())
    }

    /// Walks every contig of the reference and writes simulated variants to
    /// the output VCF.
    ///
    /// A position is eligible when its reference base is one of `BASES` and,
    /// if targets were given, when it overlaps a target region. At each
    /// eligible position a variant starts with probability
    /// `snp_rate + indel_rate + mnp_rate`; its type is then chosen in
    /// proportion to the three rates. After a variant is written the walk
    /// resumes past the reference bases it consumed, so records never overlap.
    ///
    /// Only the first base of a variant is tested for eligibility: a deletion
    /// or MNP may extend over bases that are outside the targets or not in
    /// `BASES`.
    ///
    /// # Errors
    ///
    /// Propagates failures from loading the reference, targets or ploidy
    /// overrides, from constructing the indel length distribution, and from
    /// creating, writing or flushing the output file. Also fails when a contig
    /// is too long for 32-bit coordinates.
    fn run_mutation(&self) -> Result<()> {
        let seed_desc = format!(
            "mutate:{}:{}:{}:{}:{}",
            self.reference.reference.display(),
            self.snp_rate,
            self.indel_rate,
            self.mnp_rate,
            self.ploidy,
        );
        let seed = resolve_seed(self.seed.seed, &seed_desc);
        let mut rng = SmallRng::seed_from_u64(seed);
        log::info!("Using random seed: {seed}");

        let mut fasta = Fasta::from_path(&self.reference.reference)?;
        let dict = fasta.dict().clone();
        log::info!(
            "Loaded reference with {} contigs, total {} bp",
            dict.len(),
            dict.total_length()
        );

        let targets = match &self.bed.targets {
            Some(bed_path) => {
                let t = TargetRegions::from_path(bed_path, &dict)?;
                log::info!(
                    "Restricting mutations to {} bp of target territory",
                    t.total_territory()
                );
                Some(t)
            }
            None => None,
        };

        let ploidy_map = PloidyMap::new(self.ploidy, &self.ploidy_override)?;
        let indel_dist = Geometric::new(self.indel_length_param)
            .map_err(|e| anyhow::anyhow!("Invalid indel length distribution: {e}"))?;

        // One record is written per variant; buffering avoids one system call
        // per line.
        let mut vcf_out = std::io::BufWriter::new(std::fs::File::create(&self.output)?);
        Self::write_vcf_header(&mut vcf_out, &dict)?;

        let total_rate = self.snp_rate + self.indel_rate + self.mnp_rate;
        let mut total_variants = 0u64;
        let contig_names: Vec<String> = dict.names().into_iter().map(String::from).collect();

        for contig_name in &contig_names {
            let reference = fasta.load_contig(contig_name)?;
            let contig_idx = dict
                .get_by_name(contig_name)
                .expect("contig name was taken from this dictionary")
                .index();
            // Positions are tracked as u32; refuse longer contigs instead of
            // silently truncating their length.
            let contig_len = u32::try_from(reference.len()).map_err(|_| {
                anyhow::anyhow!(
                    "Contig {contig_name} is longer than {} bp, which is not supported",
                    u32::MAX
                )
            })?;

            let mut pos = 0u32;
            while pos < contig_len {
                // Ambiguous or otherwise unexpected bases are never mutated.
                if !BASES.contains(&reference[pos as usize]) {
                    pos += 1;
                    continue;
                }
                if let Some(tgt) = &targets
                    && !tgt.overlaps(contig_idx, pos, pos + 1)
                {
                    pos += 1;
                    continue;
                }
                // Bernoulli trial: does any variant start here?
                if rng.random::<f64>() >= total_rate {
                    pos += 1;
                    continue;
                }

                let ploidy = ploidy_map.ploidy_at(contig_name, pos);
                // A region with zero copies cannot carry a variant, and an
                // empty GT field would make the record invalid.
                // NOTE(review): assumes overrides may assign ploidy 0; confirm
                // against PloidyMap.
                if ploidy == 0 {
                    pos += 1;
                    continue;
                }

                // Choose the variant type in proportion to the three rates.
                let type_roll: f64 = rng.random::<f64>() * total_rate;
                let (ref_allele, alt_allele, advance) = if type_roll < self.snp_rate {
                    generate_snp(reference[pos as usize], &mut rng)
                } else if type_roll < self.snp_rate + self.indel_rate {
                    generate_indel(&reference, pos, contig_len, &indel_dist, &mut rng)
                } else {
                    // An MNP needs at least two reference bases.
                    if pos + 2 > contig_len {
                        pos += 1;
                        continue;
                    }
                    generate_mnp(&reference, pos, contig_len, &mut rng)
                };

                let gt = generate_genotype(ploidy, self.het_hom_ratio, &mut rng);
                // VCF positions are 1-based.
                writeln!(
                    vcf_out,
                    "{contig_name}\t{}\t.\t{}\t{}\t100\tPASS\t.\tGT\t{gt}",
                    pos + 1,
                    String::from_utf8_lossy(&ref_allele),
                    String::from_utf8_lossy(&alt_allele),
                )?;
                total_variants += 1;
                pos += advance;
            }
        }

        // Dropping a BufWriter discards flush errors, so flush explicitly.
        vcf_out.flush()?;
        log::info!("Generated {total_variants} variants");
        Ok(())
    }

    /// Writes the VCF meta-information lines and the column header line.
    ///
    /// Emits the file format and source lines, one `##contig` line per entry
    /// of `dict`, the `GT` FORMAT definition, and the `#CHROM` header with a
    /// single sample column named `SAMPLE`.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error reported by `out`.
    fn write_vcf_header(
        out: &mut impl Write,
        dict: &crate::sequence_dict::SequenceDictionary,
    ) -> Result<()> {
        writeln!(out, "##fileformat=VCFv4.3")?;
        writeln!(out, "##source=holodeck-mutate")?;
        for meta in dict.iter() {
            writeln!(out, "##contig=<ID={},length={}>", meta.name(), meta.length())?;
        }
        writeln!(out, "##FORMAT=<ID=GT,Number=1,Type=String,Description=\"Genotype\">")?;
        writeln!(out, "#CHROM\tPOS\tID\tREF\tALT\tQUAL\tFILTER\tINFO\tFORMAT\tSAMPLE")?;
        Ok(())
    }
}
/// Builds a single-base substitution for a position whose reference base is
/// `ref_base`.
///
/// Bases are drawn from `BASES` until one differs from `ref_base`. Returns
/// `(ref_allele, alt_allele, advance)`; a SNP consumes exactly one reference
/// base, so `advance` is always 1.
fn generate_snp(ref_base: u8, rng: &mut impl Rng) -> (Vec<u8>, Vec<u8>, u32) {
    let substitute = std::iter::repeat_with(|| BASES[rng.random_range(0..4)])
        .find(|&candidate| candidate != ref_base)
        .expect("an endless iterator always yields another candidate");
    (vec![ref_base], vec![substitute], 1)
}
/// Builds an insertion or a deletion anchored at `pos` (0-based).
///
/// The length comes from `indel_dist`, raised to 1 when the sample is 0, and
/// a coin flip chooses between insertion and deletion. Both alleles start
/// with the reference base at `pos`, which serves as the anchor.
///
/// * Insertion: the alt allele is the anchor followed by `length` bases drawn
///   from `BASES`; one reference base is consumed.
/// * Deletion: the ref allele is the anchor plus up to `length` following
///   bases, capped at the end of the contig. When no base follows the anchor,
///   an insertion is produced instead.
///
/// Returns `(ref_allele, alt_allele, advance)`, where `advance` is the number
/// of reference bases covered by the ref allele.
///
/// # Panics
///
/// Panics if `pos` is not a valid index into `reference`.
fn generate_indel(
    reference: &[u8],
    pos: u32,
    contig_len: u32,
    indel_dist: &Geometric,
    rng: &mut impl Rng,
) -> (Vec<u8>, Vec<u8>, u32) {
    #[expect(
        clippy::cast_possible_truncation,
        reason = "geometric distribution produces small values"
    )]
    let sampled = indel_dist.sample(rng) as u32;
    // The distribution can yield 0; an indel must change at least one base.
    let length = if sampled == 0 { 1 } else { sampled };
    let wants_insertion: bool = rng.random();
    let anchor = reference[pos as usize];

    // Number of bases after the anchor that a deletion could remove. Forced to
    // zero for insertions so both "insert" cases share one code path.
    let deletable = if wants_insertion {
        0
    } else {
        contig_len.saturating_sub(pos + 1)
    };

    if deletable == 0 {
        let alt: Vec<u8> = std::iter::once(anchor)
            .chain((0..length).map(|_| BASES[rng.random_range(0..4)]))
            .collect();
        return (vec![anchor], alt, 1);
    }

    let del_len = length.min(deletable);
    let ref_end = pos + 1 + del_len;
    let ref_allele = reference[pos as usize..ref_end as usize].to_vec();
    // The ref allele spans the anchor plus the deleted bases.
    (ref_allele, vec![anchor], 1 + del_len)
}
/// Builds a multi-nucleotide substitution starting at `pos` (0-based).
///
/// The variant spans 2 or 3 reference bases, shortened to whatever remains of
/// the contig. Each base is independently replaced with probability 0.8 by a
/// different base from `BASES`; if that leaves the allele unchanged, one
/// randomly chosen base is replaced so that ref and alt always differ.
///
/// Returns `(ref_allele, alt_allele, advance)`, where `advance` equals the
/// allele length.
///
/// # Panics
///
/// Panics if the chosen span does not lie inside `reference`. `pos` must be
/// smaller than `contig_len`.
fn generate_mnp(
    reference: &[u8],
    pos: u32,
    contig_len: u32,
    rng: &mut impl Rng,
) -> (Vec<u8>, Vec<u8>, u32) {
    /// Draws bases from `BASES` until one differs from `original`.
    fn different_base(original: u8, rng: &mut impl Rng) -> u8 {
        loop {
            let candidate = BASES[rng.random_range(0..4)];
            if candidate != original {
                return candidate;
            }
        }
    }

    let length = rng.random_range(2..=3u32).min(contig_len - pos);
    let ref_allele = reference[pos as usize..(pos + length) as usize].to_vec();

    let mut alt_allele: Vec<u8> = ref_allele
        .iter()
        .map(|&base| {
            if rng.random::<f64>() < 0.8 {
                different_base(base, &mut *rng)
            } else {
                base
            }
        })
        .collect();

    // Every base may have been left alone; force a single substitution then.
    if alt_allele == ref_allele {
        let idx = rng.random_range(0..alt_allele.len());
        let original = alt_allele[idx];
        alt_allele[idx] = different_base(original, &mut *rng);
    }

    (ref_allele, alt_allele, length)
}
/// Draws an unphased genotype string for a site with the given `ploidy`.
///
/// * `ploidy == 0`: returns `"."`, the VCF missing-value marker, instead of an
///   empty (invalid) genotype field.
/// * `ploidy == 1`: always `"1"`; no random number is consumed.
/// * otherwise: with probability `het_hom_ratio / (1 + het_hom_ratio)` the
///   genotype is heterozygous, with a single alternate allele placed on a
///   randomly chosen haplotype (e.g. `"0/1"` or `"1/0"`); otherwise every
///   haplotype carries the alternate allele (e.g. `"1/1"`).
///
/// A `het_hom_ratio` of `+inf` means "always heterozygous". A NaN ratio never
/// compares below the random draw and therefore always yields a homozygous
/// genotype.
fn generate_genotype(ploidy: u8, het_hom_ratio: f64, rng: &mut impl Rng) -> String {
    let copies = usize::from(ploidy);
    match ploidy {
        0 => ".".to_string(),
        1 => "1".to_string(),
        _ => {
            // inf / (1 + inf) evaluates to NaN, which would make the comparison
            // below always false; the limit of r / (1 + r) is 1.
            let p_het = if het_hom_ratio == f64::INFINITY {
                1.0
            } else {
                het_hom_ratio / (1.0 + het_hom_ratio)
            };
            if rng.random::<f64>() < p_het {
                let mut alleles = vec!["0"; copies];
                let alt_hap = rng.random_range(0..copies);
                alleles[alt_hap] = "1";
                alleles.join("/")
            } else {
                vec!["1"; copies].join("/")
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight-base reference shared by the indel and MNP tests.
    const REFERENCE: &[u8] = b"ACGTACGT";

    /// A SNP keeps the reference base, changes it to another base and
    /// consumes one position.
    #[test]
    fn test_generate_snp() {
        let mut rng = rand::rng();
        for _ in 0..100 {
            let (ref_allele, alt_allele, advance) = generate_snp(b'A', &mut rng);
            assert_eq!(ref_allele, vec![b'A']);
            assert_eq!(alt_allele.len(), 1);
            assert_ne!(alt_allele[0], b'A');
            assert_eq!(advance, 1);
        }
    }

    /// Insertions are anchored on the reference base and consume one position.
    #[test]
    fn test_generate_indel_insertion() {
        let indel_dist = Geometric::new(0.7).unwrap();
        let mut rng = SmallRng::seed_from_u64(42);
        let mut insertions = 0usize;
        for _ in 0..100 {
            let (ref_allele, alt_allele, advance) =
                generate_indel(REFERENCE, 2, 8, &indel_dist, &mut rng);
            if alt_allele.len() <= ref_allele.len() {
                continue;
            }
            insertions += 1;
            assert_eq!(ref_allele.len(), 1);
            assert!(alt_allele.len() >= 2);
            assert_eq!(ref_allele[0], alt_allele[0]);
            assert_eq!(advance, 1);
        }
        assert!(insertions > 0, "Should have generated at least one insertion");
    }

    /// Deletions keep only the anchor base in the alt allele.
    #[test]
    fn test_generate_indel_deletion() {
        let indel_dist = Geometric::new(0.7).unwrap();
        let mut rng = SmallRng::seed_from_u64(99);
        let mut deletions = 0usize;
        for _ in 0..100 {
            let (ref_allele, alt_allele, _advance) =
                generate_indel(REFERENCE, 2, 8, &indel_dist, &mut rng);
            if ref_allele.len() <= alt_allele.len() {
                continue;
            }
            deletions += 1;
            assert_eq!(alt_allele.len(), 1);
            assert!(ref_allele.len() >= 2);
            assert_eq!(ref_allele[0], alt_allele[0]);
        }
        assert!(deletions > 0, "Should have generated at least one deletion");
    }

    /// MNPs span two or three bases, keep their length and always differ from
    /// the reference.
    #[test]
    fn test_generate_mnp() {
        let mut rng = rand::rng();
        for _ in 0..100 {
            let (ref_allele, alt_allele, advance) = generate_mnp(REFERENCE, 1, 8, &mut rng);
            assert!((2..=3).contains(&ref_allele.len()));
            assert_eq!(ref_allele.len(), alt_allele.len());
            assert_ne!(ref_allele, alt_allele);
            assert_eq!(advance as usize, ref_allele.len());
        }
    }

    /// Haploid sites always carry the alternate allele.
    #[test]
    fn test_generate_genotype_haploid() {
        let mut rng = rand::rng();
        assert_eq!(generate_genotype(1, 2.0, &mut rng), "1");
    }

    /// With a het:hom ratio of 2, most but not all diploid calls are
    /// heterozygous.
    #[test]
    fn test_generate_genotype_diploid_het() {
        let mut rng = rand::rng();
        let genotypes: Vec<String> =
            (0..1000).map(|_| generate_genotype(2, 2.0, &mut rng)).collect();
        let het_count = genotypes
            .iter()
            .filter(|gt| matches!(gt.as_str(), "0/1" | "1/0"))
            .count();
        let hom_count = genotypes.iter().filter(|gt| gt.as_str() == "1/1").count();
        assert!(het_count > 500, "expected majority het, got {het_count}");
        assert!(hom_count > 200, "expected some hom, got {hom_count}");
    }

    /// A ratio of zero never produces a heterozygous call.
    #[test]
    fn test_generate_genotype_triploid() {
        let mut rng = rand::rng();
        assert_eq!(generate_genotype(3, 0.0, &mut rng), "1/1/1");
    }
}