use crate::algorithms::sampler::{ReservoirSampler, TwoPassSampler};
use crate::core::{SeqReader, SeqRecord, SeqWriter};
use crate::utils::random::RandomGenerator;
use anyhow::{bail, Context, Result};
use clap::Args;
use rustc_hash::FxHashSet;
// Command-line arguments consumed by `run` to subsample sequence records.
//
// NOTE: plain `//` comments are used on purpose. clap's derive macros turn
// `///` doc comments into help text, so switching to doc comments here would
// change the program's `--help` output.
#[derive(Args, Debug)]
pub struct SampleArgs {
// Positional input path. The literal "-" makes the single-pass code read
// from stdin; two-pass count sampling rejects it.
#[arg(value_name = "in.fq")]
pub input: String,
// Sampling target. Values >= 1.0 are rounded and used as a record count;
// anything else is used as the per-record keep threshold (a fraction).
#[arg(value_name = "FRAC|NUM")]
pub target: f64,
// Seed handed to the random generator / samplers (`-s`, default 11).
#[arg(short = 's', long, value_name = "INT", default_value = "11")]
pub seed: u64,
// `-2` flag: request the two-pass sampler. Only honoured when `target` is a
// count; for fractions `run` prints a warning and ignores it.
#[arg(short = '2', long)]
pub two_pass: bool,
}
/// Runs the sampling command described by `args`.
///
/// `args.target` is interpreted as a record count when it is `>= 1.0`
/// (rounded to the nearest integer) and as a sampling fraction otherwise.
/// Count sampling uses the two-pass sampler when `args.two_pass` is set and
/// the single-pass sampler otherwise; fraction sampling is always single-pass.
///
/// # Errors
///
/// Returns an error when:
/// - `args.target` is NaN, infinite, or negative;
/// - two-pass sampling is actually going to be used and the input is `"-"`
///   (stdin), because the input has to be opened twice;
/// - the selected sampling routine fails.
pub fn run(args: &SampleArgs) -> Result<()> {
    // NaN and negative values would otherwise fall into the "fraction" branch
    // and silently produce empty output; infinity would saturate to a
    // `usize::MAX` sample count.
    if !args.target.is_finite() || args.target < 0.0 {
        bail!("采样目标必须是非负的有限数值: {}", args.target);
    }

    let is_fraction = args.target < 1.0;
    let sample_count = if is_fraction {
        0
    } else {
        // Float-to-int `as` casts saturate, so an oversized count clamps to
        // `usize::MAX` instead of wrapping.
        args.target.round() as usize
    };

    // `-2` only has an effect when sampling a fixed number of records.
    let use_two_pass = args.two_pass && !is_fraction;
    if args.two_pass && is_fraction {
        eprintln!("[警告] 采样分数时,-2选项将被忽略");
    }

    // Only reject stdin when two-pass mode is really in effect; when `-2` has
    // just been reported as ignored, reading from stdin is fine.
    if use_two_pass && args.input == "-" {
        bail!("两遍模式下,输入不能是stdin");
    }

    if use_two_pass {
        sample_two_pass(args, sample_count)
    } else {
        sample_single_pass(args, is_fraction, sample_count)
    }
}
/// Samples records in a single pass over the input and writes them to stdout.
///
/// The input is stdin when `args.input` is `"-"`, otherwise the named file.
/// When `is_fraction` is true each record is written as soon as a fresh random
/// draw is below `args.target`; otherwise every record is offered to a
/// reservoir sampler of capacity `sample_count` and the retained samples are
/// written once the input is exhausted.
///
/// # Errors
///
/// Fails if the input file cannot be opened, or if reading, writing, or the
/// final flush fails.
fn sample_single_pass(args: &SampleArgs, is_fraction: bool, sample_count: usize) -> Result<()> {
    let mut source = if args.input != "-" {
        SeqReader::from_path(&args.input)
            .with_context(|| format!("无法打开输入文件: {}", args.input))?
    } else {
        SeqReader::from_stdin()
    };
    let mut sink = SeqWriter::to_stdout();

    // One scratch record is reused for every read in either mode.
    let mut scratch = SeqRecord::new(Vec::new(), Vec::new());

    if !is_fraction {
        let mut reservoir = ReservoirSampler::new(sample_count, Some(args.seed));
        loop {
            if !source.read_next(&mut scratch)? {
                break;
            }
            // The scratch buffer is overwritten by the next read, so the
            // sampler needs its own copy.
            reservoir.add(scratch.clone());
        }
        for kept in reservoir.into_samples() {
            sink.write_record(&kept)?;
        }
    } else {
        let mut rng = RandomGenerator::with_seed(args.seed);
        loop {
            if !source.read_next(&mut scratch)? {
                break;
            }
            // Exactly one draw per record, as in a Bernoulli trial.
            let draw = rng.random();
            if draw < args.target {
                sink.write_record(&scratch)?;
            }
        }
    }

    sink.flush()?;
    Ok(())
}
/// Samples `sample_count` records by reading the input file twice.
///
/// The first pass only counts records, calling `TwoPassSampler::add_index`
/// once per record, and then asks the sampler which zero-based record indices
/// were selected. The second pass re-reads the file and writes the records at
/// those indices to stdout, in input order.
///
/// # Errors
///
/// Fails if the input file cannot be opened for either pass, or if reading,
/// writing, or the final flush fails.
fn sample_two_pass(args: &SampleArgs, sample_count: usize) -> Result<()> {
    // Pass 1: count the records and collect the chosen indices into a set so
    // that the membership test in pass 2 is a hash lookup.
    let selected_set: FxHashSet<usize> = {
        let mut reader = SeqReader::from_path(&args.input)
            .with_context(|| format!("无法打开输入文件进行第一遍扫描: {}", args.input))?;
        let mut sampler = TwoPassSampler::new(sample_count, Some(args.seed));
        let mut record = SeqRecord::new(Vec::new(), Vec::new());
        while reader.read_next(&mut record)? {
            sampler.add_index();
        }
        sampler.get_selected_indices().into_iter().collect()
    };

    // Pass 2: stream the file again and emit only the selected records.
    let mut reader = SeqReader::from_path(&args.input)
        .with_context(|| format!("无法打开输入文件进行第二遍扫描: {}", args.input))?;
    let mut writer = SeqWriter::to_stdout();
    let mut record = SeqRecord::new(Vec::new(), Vec::new());
    let mut current_index: usize = 0;
    while reader.read_next(&mut record)? {
        if selected_set.contains(&current_index) {
            writer.write_record(&record)?;
        }
        current_index += 1;
    }

    writer.flush()?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `record_count` FASTA records (`>seqN` / `ACGT`) to a fresh
    /// temporary file and returns the handle, which keeps the file alive.
    fn fasta_fixture(record_count: usize) -> Result<NamedTempFile> {
        let mut temp_file = NamedTempFile::new()?;
        for i in 0..record_count {
            writeln!(temp_file, ">seq{}", i)?;
            writeln!(temp_file, "ACGT")?;
        }
        temp_file.flush()?;
        Ok(temp_file)
    }

    /// Builds arguments pointing at `file` with a fixed seed of 42.
    fn args_for(file: &NamedTempFile, target: f64, two_pass: bool) -> SampleArgs {
        SampleArgs {
            input: file
                .path()
                .to_str()
                .expect("temporary file path should be valid UTF-8")
                .to_string(),
            target,
            seed: 42,
            two_pass,
        }
    }

    // The sampled records go straight to stdout, so these tests can only
    // check that each sampling mode runs to completion without an error.

    #[test]
    fn test_sample_count() -> Result<()> {
        let temp_file = fasta_fixture(100)?;
        run(&args_for(&temp_file, 10.0, false))
    }

    #[test]
    fn test_sample_fraction() -> Result<()> {
        let temp_file = fasta_fixture(100)?;
        run(&args_for(&temp_file, 0.1, false))
    }

    #[test]
    fn test_sample_count_two_pass() -> Result<()> {
        let temp_file = fasta_fixture(100)?;
        run(&args_for(&temp_file, 10.0, true))
    }

    #[test]
    fn test_two_pass_rejects_stdin() {
        // Two-pass count sampling has to open the input twice, which is
        // impossible for stdin; `run` must fail before reading anything.
        let args = SampleArgs {
            input: "-".to_string(),
            target: 10.0,
            seed: 42,
            two_pass: true,
        };
        assert!(run(&args).is_err());
    }
}