Skip to main content

holodeck_lib/commands/
simulate.rs

1use std::fs::File;
2use std::io::BufWriter;
3
4use anyhow::{Result, bail};
5use clap::Parser;
6use pooled_writer::PoolBuilder;
7use pooled_writer::bgzf::BgzfCompressor;
8use rand::rngs::SmallRng;
9use rand::{Rng, SeedableRng};
10use rand_distr::{Distribution, Normal};
11
12use super::command::{Command, output_path};
13use super::common::{BedOptions, OutputPrefixOptions, ReferenceOptions, SeedOptions, VcfOptions};
14use crate::bed::{PaddedIntervalSampler, TargetRegions};
15use crate::error_model::illumina::IlluminaErrorModel;
16use crate::fasta::Fasta;
17use crate::fragment::extract_fragment;
18use crate::haplotype::build_haplotypes;
19use crate::output::fastq::FastqWriter;
20use crate::output::golden_bam::{GoldenBamMetadata, GoldenBamWriter};
21use crate::read::generate_read_pair;
22use crate::seed::resolve_seed;
23use crate::sequence_dict::SequenceDictionary;
24use crate::version::VERSION;
25
/// Sample name written to the `SM` field of the golden BAM `@RG` header line
/// when no VCF sample is driving the simulation.
const DEFAULT_SAMPLE_NAME: &str = "holodeck-simulation";

/// Default value of `--adapter-r1`: the Illumina TruSeq read-1 adapter.
const DEFAULT_ADAPTER_R1: &str = "AGATCGGAAGAGCACACGTCTGAACTCCAGTCA";

/// Default value of `--adapter-r2`: the Illumina TruSeq read-2 adapter.
const DEFAULT_ADAPTER_R2: &str = "AGATCGGAAGAGCGTCGTGTAGGGAAAGAGTGT";
35
36/// Simulate sequencing reads from a reference genome.
37///
38/// Generates paired-end or single-end FASTQ files with optional ground-truth
39/// BAM and VCF outputs for benchmarking alignment and variant calling pipelines.
40/// Variants are applied from an input VCF to construct haplotype sequences, and
41/// reads are sampled with a position-dependent Illumina error model.
42#[derive(Parser, Debug)]
43#[command(after_long_help = "EXAMPLES:\n  \
44    holodeck simulate -r ref.fa -o out --coverage 30\n  \
45    holodeck simulate -r ref.fa -v vars.vcf -o out --coverage 30 --golden-bam\n  \
46    holodeck simulate -r ref.fa -v vars.vcf -b targets.bed -o out --coverage 100")]
47#[allow(clippy::struct_excessive_bools)] // CLI flags are naturally boolean
48pub struct Simulate {
49    #[command(flatten)]
50    pub reference: ReferenceOptions,
51
52    #[command(flatten)]
53    pub vcf: VcfOptions,
54
55    #[command(flatten)]
56    pub bed: BedOptions,
57
58    #[command(flatten)]
59    pub output: OutputPrefixOptions,
60
61    #[command(flatten)]
62    pub seed: SeedOptions,
63
64    /// Mean sequencing coverage depth.
65    #[arg(short = 'c', long, default_value_t = 30.0, value_name = "FLOAT")]
66    pub coverage: f64,
67
68    /// Length of each read in bases.
69    #[arg(short = 'l', long, default_value_t = 150, value_name = "INT")]
70    pub read_length: usize,
71
72    /// Mean insert size (outer distance between read pair ends).
73    #[arg(short = 'd', long, default_value_t = 300, value_name = "INT")]
74    pub fragment_mean: usize,
75
76    /// Standard deviation of the insert size distribution.
77    #[arg(short = 's', long, default_value_t = 50, value_name = "INT")]
78    pub fragment_stddev: usize,
79
80    /// Minimum fragment length in bases; sampled lengths below this are clamped
81    /// up.  Fragments shorter than the read length are padded with adapter
82    /// sequence.  Must be at least 1.
83    #[arg(long, default_value_t = 20, value_name = "INT")]
84    pub min_fragment_length: usize,
85
86    /// Generate single-end reads instead of paired-end.
87    #[arg(long)]
88    pub single_end: bool,
89
90    /// Adapter sequence appended to read 1 when the fragment is shorter than
91    /// the read length.
92    #[arg(long, default_value = DEFAULT_ADAPTER_R1, value_name = "SEQ")]
93    pub adapter_r1: String,
94
95    /// Adapter sequence appended to read 2 when the fragment is shorter than
96    /// the read length.
97    #[arg(long, default_value = DEFAULT_ADAPTER_R2, value_name = "SEQ")]
98    pub adapter_r2: String,
99
100    /// Minimum per-base error rate, applied at the start of reads.
101    #[arg(long, default_value_t = 0.001, value_name = "FLOAT")]
102    pub min_error_rate: f64,
103
104    /// Maximum per-base error rate, applied at the end of reads.
105    #[arg(long, default_value_t = 0.01, value_name = "FLOAT")]
106    pub max_error_rate: f64,
107
108    /// Write a ground-truth BAM file with correct alignments.
109    #[arg(long)]
110    pub golden_bam: bool,
111
112    /// Write a ground-truth VCF annotated with simulated coverage.
113    #[arg(long)]
114    pub golden_vcf: bool,
115
116    /// Use simple read names (`holodeck:N`) instead of encoding truth
117    /// coordinates in the read name.
118    #[arg(long)]
119    pub simple_names: bool,
120
121    /// BGZF compression level (0-12). Lower values are faster with larger
122    /// output files; higher values produce smaller files at the cost of speed.
123    /// Level 0 is no compression; 1 is fastest; 12 is maximum compression.
124    #[arg(long, default_value_t = 1, value_name = "INT")]
125    pub compression: u8,
126
127    /// Number of threads for parallel BGZF output compression.
128    #[arg(short = 't', long, default_value_t = 4, value_name = "INT")]
129    pub threads: usize,
130}
131
132impl Command for Simulate {
133    fn execute(&self) -> Result<()> {
134        let resolved_sample = self.validate()?;
135        self.run_simulation(resolved_sample.as_deref())
136    }
137}
138
139impl Simulate {
140    /// Validate command-line arguments before running and return the
141    /// resolved VCF sample name (if a VCF was provided).  Resolving the
142    /// sample here means the VCF header is read only once per run.
143    fn validate(&self) -> Result<Option<String>> {
144        if !self.coverage.is_finite() || self.coverage <= 0.0 {
145            bail!("--coverage must be a finite positive number");
146        }
147        if self.read_length == 0 {
148            bail!("--read-length must be > 0");
149        }
150        if self.min_fragment_length == 0 {
151            bail!("--min-fragment-length must be at least 1");
152        }
153        if self.min_error_rate < 0.0 || self.max_error_rate < 0.0 {
154            bail!("Error rates must be >= 0");
155        }
156        if self.min_error_rate > self.max_error_rate {
157            bail!("--min-error-rate must be <= --max-error-rate");
158        }
159        if self.compression > 12 {
160            bail!("--compression must be between 0 and 12");
161        }
162
163        // --sample without --vcf is nonsensical.
164        if self.vcf.sample.is_some() && self.vcf.vcf.is_none() {
165            bail!("--sample requires --vcf");
166        }
167
168        // Validate VCF sample configuration upfront so the user gets a clear
169        // error before the simulation loop starts, and capture the resolved
170        // sample name for use in downstream metadata (e.g. the golden BAM
171        // `@RG` line).
172        let resolved_sample = if let Some(vcf_path) = &self.vcf.vcf {
173            Some(crate::vcf::validate_vcf_sample(vcf_path, self.vcf.sample.as_deref())?)
174        } else {
175            None
176        };
177
178        // Validate output parent directory exists.
179        if let Some(parent) = self.output.output.parent()
180            && !parent.as_os_str().is_empty()
181            && !parent.exists()
182        {
183            bail!("Output directory does not exist: {}", parent.display());
184        }
185
186        Ok(resolved_sample)
187    }
188
189    /// Run the main simulation pipeline.
190    ///
191    /// `resolved_vcf_sample` is the sample name resolved from the VCF during
192    /// validation (if any), used for the golden BAM `@RG SM` field.
193    fn run_simulation(&self, resolved_vcf_sample: Option<&str>) -> Result<()> {
194        let seed = self.compute_seed();
195        let mut rng = SmallRng::seed_from_u64(seed);
196        log::info!("Using random seed: {seed}");
197
198        let mut fasta = Fasta::from_path(&self.reference.reference)?;
199        let dict = fasta.dict().clone();
200        log::info!(
201            "Loaded reference with {} contigs, total {} bp",
202            dict.len(),
203            dict.total_length()
204        );
205
206        let targets = self.load_targets(&dict)?;
207        let effective_size = targets
208            .as_ref()
209            .map_or(dict.total_length(), |t| t.effective_territory(self.fragment_mean));
210        if effective_size == 0 {
211            bail!("Effective genome size is 0; nothing to simulate");
212        }
213
214        let total_reads = self.compute_total_reads(effective_size);
215        log::info!("Will generate {total_reads} read pairs for {:.1}x coverage", self.coverage);
216
217        let error_model =
218            IlluminaErrorModel::new(self.read_length, self.min_error_rate, self.max_error_rate);
219        let frag_dist = Normal::new(self.fragment_mean as f64, self.fragment_stddev as f64)
220            .map_err(|e| anyhow::anyhow!("Invalid fragment distribution parameters: {e}"))?;
221
222        let compression = self.compression;
223        let use_pool = self.threads > 1;
224
225        // When using multiple threads, create a shared compression pool that
226        // handles BGZF block compression and writing across all output files.
227        let mut pool_builder: Option<PoolBuilder<BufWriter<File>, BgzfCompressor>> = if use_pool {
228            let pb = PoolBuilder::new()
229                .threads(self.threads)
230                .compression_level(compression)
231                .map_err(|e| anyhow::anyhow!("failed to set compression level: {e}"))?;
232            log::info!("Using {} threads for BGZF compression", self.threads);
233            Some(pb)
234        } else {
235            None
236        };
237
238        let mut r1_writer = self.create_fastq_writer(".r1.fastq.gz", &mut pool_builder)?;
239        let mut r2_writer = if self.single_end {
240            None
241        } else {
242            Some(self.create_fastq_writer(".r2.fastq.gz", &mut pool_builder)?)
243        };
244
245        let mut golden_bam_writer = if self.golden_bam {
246            let bam_path = output_path(&self.output.output, ".golden.bam");
247            log::info!("Writing golden BAM to: {}", bam_path.display());
248            let meta = Self::golden_bam_metadata(resolved_vcf_sample);
249            if let Some(pb) = &mut pool_builder {
250                let file = File::create(&bam_path)?;
251                let pooled = pb.exchange(BufWriter::new(file));
252                Some(GoldenBamWriter::from_writer(Box::new(pooled), &dict, &meta)?)
253            } else {
254                Some(GoldenBamWriter::new(&bam_path, &dict, compression, &meta)?)
255            }
256        } else {
257            None
258        };
259
260        // Build the pool after all writers have been exchanged.
261        let mut pool = pool_builder
262            .map(PoolBuilder::build)
263            .transpose()
264            .map_err(|e| anyhow::anyhow!("failed to build compression pool: {e}"))?;
265
266        if self.golden_vcf {
267            log::warn!("--golden-vcf is not yet implemented; skipping");
268        }
269
270        let mut read_num: u64 = 0;
271        let contig_names: Vec<String> = dict.names().into_iter().map(String::from).collect();
272
273        for contig_name in &contig_names {
274            read_num += self.simulate_contig(
275                contig_name,
276                &dict,
277                &mut fasta,
278                targets.as_ref(),
279                total_reads,
280                &error_model,
281                &frag_dist,
282                &mut r1_writer,
283                &mut r2_writer,
284                &mut golden_bam_writer,
285                read_num,
286                &mut rng,
287            )?;
288        }
289
290        // Close writers first so pooled writers flush their buffers to the
291        // pool, then stop the pool to wait for all compression/writing to
292        // complete.
293        r1_writer.close();
294        if let Some(w) = r2_writer {
295            w.close();
296        }
297        if let Some(w) = golden_bam_writer {
298            w.close();
299        }
300        if let Some(ref mut p) = pool {
301            p.stop_pool().map_err(|e| anyhow::anyhow!("failed to stop compression pool: {e}"))?;
302        }
303
304        log::info!("Generated {read_num} total read pairs");
305        Ok(())
306    }
307
308    /// Create a FASTQ writer, using the compression pool if available or
309    /// single-threaded BGZF otherwise.
310    fn create_fastq_writer(
311        &self,
312        suffix: &str,
313        pool_builder: &mut Option<PoolBuilder<BufWriter<File>, BgzfCompressor>>,
314    ) -> Result<FastqWriter> {
315        let path = output_path(&self.output.output, suffix);
316        if let Some(pb) = pool_builder {
317            let file = File::create(&path)?;
318            let pooled = pb.exchange(BufWriter::new(file));
319            Ok(FastqWriter::from_writer(pooled))
320        } else {
321            FastqWriter::new(&path, self.compression)
322        }
323    }
324
325    /// Build the `@PG`/`@RG` metadata for the golden BAM header.  The
326    /// command line is captured verbatim from `std::env::args_os`, using
327    /// lossy UTF-8 conversion so that non-Unicode arguments do not panic.
328    /// `resolved_vcf_sample` should be the sample name returned by
329    /// [`crate::vcf::validate_vcf_sample`] during validation; when absent,
330    /// the sample defaults to [`DEFAULT_SAMPLE_NAME`].
331    fn golden_bam_metadata(resolved_vcf_sample: Option<&str>) -> GoldenBamMetadata {
332        let command_line = std::env::args_os()
333            .map(|arg| arg.to_string_lossy().into_owned())
334            .collect::<Vec<_>>()
335            .join(" ");
336        let sample =
337            resolved_vcf_sample.map_or_else(|| DEFAULT_SAMPLE_NAME.to_string(), str::to_string);
338        GoldenBamMetadata { command_line, version: VERSION.clone(), sample }
339    }
340
341    /// Compute the deterministic seed from simulation parameters.
342    fn compute_seed(&self) -> u64 {
343        let seed_desc = format!(
344            "{}:{}:{}:{}:{}:{}:{}",
345            self.reference.reference.display(),
346            self.coverage,
347            self.read_length,
348            self.fragment_mean,
349            self.fragment_stddev,
350            self.min_error_rate,
351            self.max_error_rate,
352        );
353        resolve_seed(self.seed.seed, &seed_desc)
354    }
355
356    /// Load BED target regions if specified.
357    fn load_targets(&self, dict: &SequenceDictionary) -> Result<Option<TargetRegions>> {
358        match &self.bed.targets {
359            Some(bed_path) => {
360                let t = TargetRegions::from_path(bed_path, dict)?;
361                log::info!("Loaded {} bp of target territory", t.total_territory());
362                Ok(Some(t))
363            }
364            None => Ok(None),
365        }
366    }
367
368    /// Compute total number of read pairs needed for the requested coverage.
369    fn compute_total_reads(&self, effective_size: u64) -> u64 {
370        let bases_per_read =
371            if self.single_end { self.read_length as u64 } else { self.read_length as u64 * 2 };
372        #[expect(clippy::cast_possible_truncation, reason = "read count fits u64")]
373        #[expect(clippy::cast_sign_loss, reason = "coverage is positive")]
374        let n = ((self.coverage * effective_size as f64) / bases_per_read as f64).round() as u64;
375        n
376    }
377
378    /// Simulate reads for a single contig.
379    #[allow(clippy::too_many_arguments, clippy::too_many_lines)]
380    fn simulate_contig(
381        &self,
382        contig_name: &str,
383        dict: &SequenceDictionary,
384        fasta: &mut Fasta,
385        targets: Option<&TargetRegions>,
386        total_reads: u64,
387        error_model: &IlluminaErrorModel,
388        frag_dist: &Normal<f64>,
389        r1_writer: &mut FastqWriter,
390        r2_writer: &mut Option<FastqWriter>,
391        golden_bam: &mut Option<GoldenBamWriter>,
392        start_read_num: u64,
393        rng: &mut SmallRng,
394    ) -> Result<u64> {
395        let contig_meta = dict.get_by_name(contig_name).unwrap();
396        let contig_len = contig_meta.length() as u64;
397        let contig_idx = contig_meta.index();
398
399        // Compute reads proportional to effective territory (if BED) or contig
400        // size (whole genome).  For targeted mode, effective territory accounts
401        // for the fact that fragments extend beyond targets — see
402        // TargetRegions::effective_territory for the derivation.
403        let contig_effective_size = targets
404            .map_or(contig_len, |t| t.contig_effective_territory(contig_idx, self.fragment_mean));
405        let effective_total =
406            targets.map_or(dict.total_length(), |t| t.effective_territory(self.fragment_mean));
407
408        if contig_effective_size == 0 || effective_total == 0 {
409            return Ok(0);
410        }
411
412        #[expect(clippy::cast_possible_truncation, reason = "read count fits u64")]
413        #[expect(clippy::cast_sign_loss, reason = "fraction is positive")]
414        let contig_reads = (total_reads as f64 * contig_effective_size as f64
415            / effective_total as f64)
416            .round() as u64;
417
418        if contig_reads == 0 {
419            return Ok(0);
420        }
421
422        log::info!("Simulating {contig_reads} reads for contig {contig_name} ({contig_len} bp)");
423
424        // Build a padded interval sampler when targets exist.  The pad covers
425        // the catchment zone — fragment start positions outside a target whose
426        // fragment still extends into the target.  Fragments whose drawn length
427        // is too short to actually reach a target are caught by the overlap
428        // check below (rare with this padding).
429        #[expect(clippy::cast_possible_truncation, reason = "pad fits u32")]
430        let sampler = targets.map(|tgt| {
431            let pad = (self.fragment_mean + 4 * self.fragment_stddev) as u32;
432            PaddedIntervalSampler::new(tgt.contig_intervals(contig_idx), pad, contig_len as u32)
433        });
434
435        let reference = fasta.load_contig(contig_name)?;
436
437        // Load variants if VCF provided.
438        let variants = if let Some(vcf_path) = &self.vcf.vcf {
439            crate::vcf::load_variants_for_contig(
440                vcf_path,
441                contig_name,
442                self.vcf.sample.as_deref(),
443                dict,
444            )?
445        } else {
446            Vec::new()
447        };
448
449        if !variants.is_empty() {
450            log::info!("  Loaded {} variants for {contig_name}", variants.len());
451        }
452
453        let max_ploidy = variants.iter().map(|v| v.genotype.ploidy()).max().unwrap_or(2);
454        let haplotypes = build_haplotypes(&variants, max_ploidy, rng);
455
456        let mut generated: u64 = 0;
457        let mut attempts: u64 = 0;
458        let max_attempts = contig_reads * 100;
459
460        while generated < contig_reads && attempts < max_attempts {
461            attempts += 1;
462
463            // Draw fragment length, clamped to [min_fragment_length, contig_length].
464            // Fragments shorter than read_length are padded with adapter sequence
465            // by the read extraction layer.
466            #[expect(clippy::cast_possible_truncation, reason = "fragment length fits usize")]
467            #[expect(clippy::cast_sign_loss, reason = "clamped to positive")]
468            let frag_len = frag_dist
469                .sample(rng)
470                .round()
471                .clamp(self.min_fragment_length as f64, contig_len as f64)
472                as usize;
473
474            if frag_len == 0 {
475                continue;
476            }
477
478            // Pick a random start position.  When targets exist, sample from
479            // the padded target regions so that nearly every draw overlaps a
480            // target — vastly more efficient than rejection-sampling across
481            // the whole contig.
482            #[expect(clippy::cast_possible_truncation, reason = "position fits u32")]
483            let ref_start = if let Some(samp) = &sampler {
484                let s = samp.sample_start(rng).unwrap();
485                // Ensure the fragment fits within the contig.
486                s.min((contig_len - frag_len as u64) as u32)
487            } else {
488                let max_start = contig_len - frag_len as u64;
489                if max_start > 0 { rng.random_range(0..=max_start) as u32 } else { 0 }
490            };
491
492            // Check BED target overlap — with padded sampling this rarely
493            // rejects, but catches the occasional short fragment drawn from
494            // the pad zone that doesn't reach the target.
495            #[expect(clippy::cast_possible_truncation, reason = "frag end fits u32")]
496            let frag_end = ref_start + frag_len as u32;
497            if let Some(tgt) = targets
498                && !tgt.overlaps(contig_idx, ref_start, frag_end)
499            {
500                continue;
501            }
502
503            let hap_idx = rng.random_range(0..haplotypes.len());
504            let is_forward: bool = rng.random();
505            let fragment =
506                extract_fragment(&haplotypes[hap_idx], &reference, ref_start, frag_len, is_forward);
507
508            let read_num = start_read_num + generated + 1;
509            let pair = generate_read_pair(
510                &fragment,
511                contig_name,
512                read_num,
513                self.read_length,
514                !self.single_end,
515                self.adapter_r1.as_bytes(),
516                self.adapter_r2.as_bytes(),
517                error_model,
518                self.simple_names,
519                rng,
520            );
521
522            r1_writer.write_read(&pair.read1)?;
523            if let Some(w) = r2_writer
524                && let Some(r2) = &pair.read2
525            {
526                w.write_read(r2)?;
527            }
528            if let Some(bam_w) = golden_bam {
529                bam_w.write_pair(&pair)?;
530            }
531
532            generated += 1;
533        }
534
535        if generated < contig_reads {
536            log::warn!(
537                "Only generated {generated}/{contig_reads} reads for {contig_name} \
538                 after {max_attempts} attempts"
539            );
540        }
541
542        Ok(generated)
543    }
544}