1use std::fs::File;
2use std::io::BufWriter;
3
4use anyhow::{Result, bail};
5use clap::Parser;
6use pooled_writer::PoolBuilder;
7use pooled_writer::bgzf::BgzfCompressor;
8use rand::rngs::SmallRng;
9use rand::{Rng, SeedableRng};
10use rand_distr::{Distribution, Normal};
11
12use super::command::{Command, output_path};
13use super::common::{BedOptions, OutputPrefixOptions, ReferenceOptions, SeedOptions, VcfOptions};
14use crate::bed::{PaddedIntervalSampler, TargetRegions};
15use crate::error_model::illumina::IlluminaErrorModel;
16use crate::fasta::Fasta;
17use crate::fragment::extract_fragment;
18use crate::haplotype::build_haplotypes;
19use crate::output::fastq::FastqWriter;
20use crate::output::golden_bam::{GoldenBamMetadata, GoldenBamWriter};
21use crate::read::generate_read_pair;
22use crate::seed::{derive_seed, resolve_seed};
23use crate::sequence_dict::SequenceDictionary;
24use crate::version::VERSION;
25
/// Sample name written into the golden BAM metadata when no VCF sample was resolved.
const DEFAULT_SAMPLE_NAME: &str = "holodeck-simulation";

/// Default value of `--adapter-r1`.
// NOTE(review): looks like the Illumina TruSeq read-1 adapter — confirm before documenting it as such.
const DEFAULT_ADAPTER_R1: &str = "AGATCGGAAGAGCACACGTCTGAACTCCAGTCA";

/// Default value of `--adapter-r2`.
// NOTE(review): looks like the Illumina TruSeq read-2 adapter — confirm before documenting it as such.
const DEFAULT_ADAPTER_R2: &str = "AGATCGGAAGAGCGTCGTGTAGGGAAAGAGTGT";
35
36#[derive(Parser, Debug)]
43#[command(after_long_help = "EXAMPLES:\n \
44 holodeck simulate -r ref.fa -o out --coverage 30\n \
45 holodeck simulate -r ref.fa -v vars.vcf -o out --coverage 30 --golden-bam\n \
46 holodeck simulate -r ref.fa -v vars.vcf -b targets.bed -o out --coverage 100")]
47#[allow(clippy::struct_excessive_bools)] pub struct Simulate {
49 #[command(flatten)]
50 pub reference: ReferenceOptions,
51
52 #[command(flatten)]
53 pub vcf: VcfOptions,
54
55 #[command(flatten)]
56 pub bed: BedOptions,
57
58 #[command(flatten)]
59 pub output: OutputPrefixOptions,
60
61 #[command(flatten)]
62 pub seed: SeedOptions,
63
64 #[arg(short = 'c', long, default_value_t = 30.0, value_name = "FLOAT")]
66 pub coverage: f64,
67
68 #[arg(short = 'l', long, default_value_t = 150, value_name = "INT")]
70 pub read_length: usize,
71
72 #[arg(short = 'd', long, default_value_t = 300, value_name = "INT")]
74 pub fragment_mean: usize,
75
76 #[arg(short = 's', long, default_value_t = 50, value_name = "INT")]
78 pub fragment_stddev: usize,
79
80 #[arg(long, default_value_t = 20, value_name = "INT")]
84 pub min_fragment_length: usize,
85
86 #[arg(long)]
88 pub single_end: bool,
89
90 #[arg(long, default_value = DEFAULT_ADAPTER_R1, value_name = "SEQ")]
93 pub adapter_r1: String,
94
95 #[arg(long, default_value = DEFAULT_ADAPTER_R2, value_name = "SEQ")]
98 pub adapter_r2: String,
99
100 #[arg(long, default_value_t = 0.001, value_name = "FLOAT")]
102 pub min_error_rate: f64,
103
104 #[arg(long, default_value_t = 0.01, value_name = "FLOAT")]
106 pub max_error_rate: f64,
107
108 #[arg(long, default_value_t = 0.02, value_name = "FLOAT")]
116 pub max_n_frac: f64,
117
118 #[arg(long)]
120 pub golden_bam: bool,
121
122 #[arg(long)]
124 pub golden_vcf: bool,
125
126 #[arg(long)]
129 pub simple_names: bool,
130
131 #[arg(long, default_value_t = 1, value_name = "INT")]
135 pub compression: u8,
136
137 #[arg(short = 't', long, default_value_t = 4, value_name = "INT")]
139 pub threads: usize,
140}
141
142impl Command for Simulate {
143 fn execute(&self) -> Result<()> {
144 let resolved_sample = self.validate()?;
145 self.run_simulation(resolved_sample.as_deref())
146 }
147}
148
149impl Simulate {
150 fn validate(&self) -> Result<Option<String>> {
154 if !self.coverage.is_finite() || self.coverage <= 0.0 {
155 bail!("--coverage must be a finite positive number");
156 }
157 if self.read_length == 0 {
158 bail!("--read-length must be > 0");
159 }
160 if self.min_fragment_length == 0 {
161 bail!("--min-fragment-length must be at least 1");
162 }
163 if self.min_error_rate < 0.0 || self.max_error_rate < 0.0 {
164 bail!("Error rates must be >= 0");
165 }
166 if self.min_error_rate > self.max_error_rate {
167 bail!("--min-error-rate must be <= --max-error-rate");
168 }
169 if !self.max_n_frac.is_finite() || !(0.0..=1.0).contains(&self.max_n_frac) {
170 bail!("--max-n-frac must be in [0.0, 1.0]");
171 }
172 if self.compression > 12 {
173 bail!("--compression must be between 0 and 12");
174 }
175
176 if self.vcf.sample.is_some() && self.vcf.vcf.is_none() {
178 bail!("--sample requires --vcf");
179 }
180
181 let resolved_sample = if let Some(vcf_path) = &self.vcf.vcf {
186 Some(crate::vcf::validate_vcf_sample(vcf_path, self.vcf.sample.as_deref())?)
187 } else {
188 None
189 };
190
191 if let Some(parent) = self.output.output.parent()
193 && !parent.as_os_str().is_empty()
194 && !parent.exists()
195 {
196 bail!("Output directory does not exist: {}", parent.display());
197 }
198
199 Ok(resolved_sample)
200 }
201
202 fn run_simulation(&self, resolved_vcf_sample: Option<&str>) -> Result<()> {
207 let seed = self.compute_seed();
208 let mut rng = SmallRng::seed_from_u64(seed);
209 log::info!("Using random seed: {seed}");
210
211 let mut fasta = Fasta::from_path(&self.reference.reference)?;
212 let dict = fasta.dict().clone();
213 log::info!(
214 "Loaded reference with {} contigs, total {} bp",
215 dict.len(),
216 dict.total_length()
217 );
218
219 let targets = self.load_targets(&dict)?;
220 let effective_size = targets
221 .as_ref()
222 .map_or(dict.total_length(), |t| t.effective_territory(self.fragment_mean));
223 if effective_size == 0 {
224 bail!("Effective genome size is 0; nothing to simulate");
225 }
226
227 let total_reads = self.compute_total_reads(effective_size);
228 log::info!("Will generate {total_reads} read pairs for {:.1}x coverage", self.coverage);
229
230 let error_model =
231 IlluminaErrorModel::new(self.read_length, self.min_error_rate, self.max_error_rate);
232 let frag_dist = Normal::new(self.fragment_mean as f64, self.fragment_stddev as f64)
233 .map_err(|e| anyhow::anyhow!("Invalid fragment distribution parameters: {e}"))?;
234
235 let adapter_r1 = self.adapter_r1.to_ascii_uppercase();
239 let adapter_r2 = self.adapter_r2.to_ascii_uppercase();
240
241 let compression = self.compression;
242 let use_pool = self.threads > 1;
243
244 let mut pool_builder: Option<PoolBuilder<BufWriter<File>, BgzfCompressor>> = if use_pool {
247 let pb = PoolBuilder::new()
248 .threads(self.threads)
249 .compression_level(compression)
250 .map_err(|e| anyhow::anyhow!("failed to set compression level: {e}"))?;
251 log::info!("Using {} threads for BGZF compression", self.threads);
252 Some(pb)
253 } else {
254 None
255 };
256
257 let mut r1_writer = self.create_fastq_writer(".r1.fastq.gz", &mut pool_builder)?;
258 let mut r2_writer = if self.single_end {
259 None
260 } else {
261 Some(self.create_fastq_writer(".r2.fastq.gz", &mut pool_builder)?)
262 };
263
264 let mut golden_bam_writer = if self.golden_bam {
265 let bam_path = output_path(&self.output.output, ".golden.bam");
266 log::info!("Writing golden BAM to: {}", bam_path.display());
267 let meta = Self::golden_bam_metadata(resolved_vcf_sample);
268 if let Some(pb) = &mut pool_builder {
269 let file = File::create(&bam_path)?;
270 let pooled = pb.exchange(BufWriter::new(file));
271 Some(GoldenBamWriter::from_writer(Box::new(pooled), &dict, &meta)?)
272 } else {
273 Some(GoldenBamWriter::new(&bam_path, &dict, compression, &meta)?)
274 }
275 } else {
276 None
277 };
278
279 let mut pool = pool_builder
281 .map(PoolBuilder::build)
282 .transpose()
283 .map_err(|e| anyhow::anyhow!("failed to build compression pool: {e}"))?;
284
285 if self.golden_vcf {
286 log::warn!("--golden-vcf is not yet implemented; skipping");
287 }
288
289 let mut read_num: u64 = 0;
290 let contig_names: Vec<String> = dict.names().into_iter().map(String::from).collect();
291
292 for contig_name in &contig_names {
293 read_num += self.simulate_contig(
294 contig_name,
295 &dict,
296 &mut fasta,
297 targets.as_ref(),
298 total_reads,
299 &error_model,
300 &frag_dist,
301 adapter_r1.as_bytes(),
302 adapter_r2.as_bytes(),
303 &mut r1_writer,
304 &mut r2_writer,
305 &mut golden_bam_writer,
306 read_num,
307 seed,
308 &mut rng,
309 )?;
310 }
311
312 r1_writer.close();
316 if let Some(w) = r2_writer {
317 w.close();
318 }
319 if let Some(w) = golden_bam_writer {
320 w.close();
321 }
322 if let Some(ref mut p) = pool {
323 p.stop_pool().map_err(|e| anyhow::anyhow!("failed to stop compression pool: {e}"))?;
324 }
325
326 log::info!("Generated {read_num} total read pairs");
327 Ok(())
328 }
329
330 fn create_fastq_writer(
333 &self,
334 suffix: &str,
335 pool_builder: &mut Option<PoolBuilder<BufWriter<File>, BgzfCompressor>>,
336 ) -> Result<FastqWriter> {
337 let path = output_path(&self.output.output, suffix);
338 if let Some(pb) = pool_builder {
339 let file = File::create(&path)?;
340 let pooled = pb.exchange(BufWriter::new(file));
341 Ok(FastqWriter::from_writer(pooled))
342 } else {
343 FastqWriter::new(&path, self.compression)
344 }
345 }
346
347 fn golden_bam_metadata(resolved_vcf_sample: Option<&str>) -> GoldenBamMetadata {
354 let command_line = std::env::args_os()
355 .map(|arg| arg.to_string_lossy().into_owned())
356 .collect::<Vec<_>>()
357 .join(" ");
358 let sample =
359 resolved_vcf_sample.map_or_else(|| DEFAULT_SAMPLE_NAME.to_string(), str::to_string);
360 GoldenBamMetadata { command_line, version: VERSION.clone(), sample }
361 }
362
363 fn compute_seed(&self) -> u64 {
365 let seed_desc = format!(
366 "{}:{}:{}:{}:{}:{}:{}",
367 self.reference.reference.display(),
368 self.coverage,
369 self.read_length,
370 self.fragment_mean,
371 self.fragment_stddev,
372 self.min_error_rate,
373 self.max_error_rate,
374 );
375 resolve_seed(self.seed.seed, &seed_desc)
376 }
377
378 fn load_targets(&self, dict: &SequenceDictionary) -> Result<Option<TargetRegions>> {
380 match &self.bed.targets {
381 Some(bed_path) => {
382 let t = TargetRegions::from_path(bed_path, dict)?;
383 log::info!("Loaded {} bp of target territory", t.total_territory());
384 Ok(Some(t))
385 }
386 None => Ok(None),
387 }
388 }
389
390 fn compute_total_reads(&self, effective_size: u64) -> u64 {
392 let bases_per_read =
393 if self.single_end { self.read_length as u64 } else { self.read_length as u64 * 2 };
394 #[expect(clippy::cast_possible_truncation, reason = "read count fits u64")]
395 #[expect(clippy::cast_sign_loss, reason = "coverage is positive")]
396 let n = ((self.coverage * effective_size as f64) / bases_per_read as f64).round() as u64;
397 n
398 }
399
400 #[allow(clippy::too_many_arguments, clippy::too_many_lines)]
402 fn simulate_contig(
403 &self,
404 contig_name: &str,
405 dict: &SequenceDictionary,
406 fasta: &mut Fasta,
407 targets: Option<&TargetRegions>,
408 total_reads: u64,
409 error_model: &IlluminaErrorModel,
410 frag_dist: &Normal<f64>,
411 adapter_r1: &[u8],
412 adapter_r2: &[u8],
413 r1_writer: &mut FastqWriter,
414 r2_writer: &mut Option<FastqWriter>,
415 golden_bam: &mut Option<GoldenBamWriter>,
416 start_read_num: u64,
417 main_seed: u64,
418 rng: &mut SmallRng,
419 ) -> Result<u64> {
420 let contig_meta = dict.get_by_name(contig_name).unwrap();
421 let contig_len = contig_meta.length() as u64;
422 let contig_idx = contig_meta.index();
423
424 let contig_effective_size = targets
429 .map_or(contig_len, |t| t.contig_effective_territory(contig_idx, self.fragment_mean));
430 let effective_total =
431 targets.map_or(dict.total_length(), |t| t.effective_territory(self.fragment_mean));
432
433 if contig_effective_size == 0 || effective_total == 0 {
434 return Ok(0);
435 }
436
437 #[expect(clippy::cast_possible_truncation, reason = "read count fits u64")]
438 #[expect(clippy::cast_sign_loss, reason = "fraction is positive")]
439 let contig_reads = (total_reads as f64 * contig_effective_size as f64
440 / effective_total as f64)
441 .round() as u64;
442
443 if contig_reads == 0 {
444 return Ok(0);
445 }
446
447 log::info!("Simulating {contig_reads} reads for contig {contig_name} ({contig_len} bp)");
448
449 #[expect(clippy::cast_possible_truncation, reason = "pad fits u32")]
455 let sampler = targets.map(|tgt| {
456 let pad = (self.fragment_mean + 4 * self.fragment_stddev) as u32;
457 PaddedIntervalSampler::new(tgt.contig_intervals(contig_idx), pad, contig_len as u32)
458 });
459
460 let contig_seed = derive_seed(main_seed, contig_name);
463 let mut ref_rng = SmallRng::seed_from_u64(contig_seed);
464 let reference = fasta.load_contig(contig_name, &mut ref_rng)?;
465
466 let variants = if let Some(vcf_path) = &self.vcf.vcf {
468 crate::vcf::load_variants_for_contig(
469 vcf_path,
470 contig_name,
471 self.vcf.sample.as_deref(),
472 dict,
473 )?
474 } else {
475 Vec::new()
476 };
477
478 if !variants.is_empty() {
479 log::info!(" Loaded {} variants for {contig_name}", variants.len());
480 }
481
482 let max_ploidy = variants.iter().map(|v| v.genotype.ploidy()).max().unwrap_or(2);
483 let haplotypes = build_haplotypes(&variants, max_ploidy, rng);
484
485 let mut generated: u64 = 0;
486 let mut attempts: u64 = 0;
487 let max_attempts = contig_reads * 100;
488
489 while generated < contig_reads && attempts < max_attempts {
490 attempts += 1;
491
492 #[expect(clippy::cast_possible_truncation, reason = "fragment length fits usize")]
496 #[expect(clippy::cast_sign_loss, reason = "clamped to positive")]
497 let frag_len = frag_dist
498 .sample(rng)
499 .round()
500 .clamp(self.min_fragment_length as f64, contig_len as f64)
501 as usize;
502
503 if frag_len == 0 {
504 continue;
505 }
506
507 #[expect(clippy::cast_possible_truncation, reason = "position fits u32")]
512 let ref_start = if let Some(samp) = &sampler {
513 let s = samp.sample_start(rng).unwrap();
514 s.min((contig_len - frag_len as u64) as u32)
516 } else {
517 let max_start = contig_len - frag_len as u64;
518 if max_start > 0 { rng.random_range(0..=max_start) as u32 } else { 0 }
519 };
520
521 #[expect(clippy::cast_possible_truncation, reason = "frag end fits u32")]
525 let frag_end = ref_start + frag_len as u32;
526 if let Some(tgt) = targets
527 && !tgt.overlaps(contig_idx, ref_start, frag_end)
528 {
529 continue;
530 }
531
532 let hap_idx = rng.random_range(0..haplotypes.len());
533 let is_forward: bool = rng.random();
534 let fragment =
535 extract_fragment(&haplotypes[hap_idx], &reference, ref_start, frag_len, is_forward);
536
537 let read_num = start_read_num + generated + 1;
538 let Some(pair) = generate_read_pair(
539 &fragment,
540 contig_name,
541 read_num,
542 self.read_length,
543 !self.single_end,
544 adapter_r1,
545 adapter_r2,
546 self.max_n_frac,
547 error_model,
548 self.simple_names,
549 rng,
550 ) else {
551 continue;
553 };
554
555 r1_writer.write_read(&pair.read1)?;
556 if let Some(w) = r2_writer
557 && let Some(r2) = &pair.read2
558 {
559 w.write_read(r2)?;
560 }
561 if let Some(bam_w) = golden_bam {
562 bam_w.write_pair(&pair)?;
563 }
564
565 generated += 1;
566 }
567
568 if generated < contig_reads {
569 log::warn!(
570 "Only generated {generated}/{contig_reads} reads for {contig_name} \
571 after {max_attempts} attempts"
572 );
573 }
574
575 Ok(generated)
576 }
577}