use std::io::{self, BufRead};
use std::path::Path;
use anyhow::Result;
use rand::Rng;
/// Parameters controlling [`inject_indel_errors`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct IndelErrorModel {
    /// Probability, per input position, that an indel event starts there.
    pub indel_rate: f64,
    /// Probability that an event is an insertion; otherwise it is a deletion.
    pub insertion_fraction: f64,
    /// Cap on the number of bases in one event. Event length starts at 1 and is
    /// only extended while it is below this cap, so a value of 0 still yields
    /// 1-base events.
    pub max_length: usize,
}
/// Injects insertion and deletion errors into a read, then forces both `seq`
/// and `qual` back to exactly `read_length` entries.
///
/// Each position of the input is independently chosen with probability
/// `model.indel_rate`. A chosen position becomes an insertion with probability
/// `model.insertion_fraction`, otherwise a deletion. The event length starts at
/// 1 and grows by one with probability 0.3 per step while it is below
/// `model.max_length`.
///
/// * Insertions add uniformly random `A`/`C`/`G`/`T` bases immediately after the
///   chosen position; every inserted base copies that position's quality.
/// * Deletions remove up to `len` bases starting at the chosen position.
///
/// Afterwards `seq` is truncated or padded with `b'N'`, and `qual` is truncated
/// or padded with `0`, so that both are exactly `read_length` long.
///
/// # Panics
/// Panics if `qual` is shorter than `seq` at a position where an event is
/// applied.
pub fn inject_indel_errors(
    seq: &mut Vec<u8>,
    qual: &mut Vec<u8>,
    read_length: usize,
    model: &IndelErrorModel,
    rng: &mut impl Rng,
) {
    const BASES: [u8; 4] = [b'A', b'C', b'G', b'T'];

    // Draw every event up front, walking from the 3' end to the 5' end.
    // Applying them in that same descending order means an edit only shifts
    // bases *after* its own position, so pending event coordinates stay valid.
    let mut events: Vec<(usize, bool, usize)> = Vec::new();
    for pos in (0..seq.len()).rev() {
        if rng.random::<f64>() < model.indel_rate {
            let is_insertion = rng.random::<f64>() < model.insertion_fraction;
            let mut len = 1usize;
            while len < model.max_length && rng.random::<f64>() < 0.3 {
                len += 1;
            }
            events.push((pos, is_insertion, len));
        }
    }

    for (pos, is_insertion, len) in events {
        // Defensive: with descending positions this cannot trigger, but it keeps
        // the indexing below obviously in bounds.
        if pos >= seq.len() {
            continue;
        }
        if is_insertion {
            let insert_pos = pos + 1;
            let q = qual[pos];
            // One splice shifts the tail once per event instead of once per
            // inserted base; bases are still drawn in order, one per slot.
            seq.splice(
                insert_pos..insert_pos,
                (0..len).map(|_| BASES[rng.random_range(0..4)]),
            );
            qual.splice(insert_pos..insert_pos, std::iter::repeat(q).take(len));
        } else {
            let end = pos + len.min(seq.len() - pos);
            seq.drain(pos..end);
            qual.drain(pos..end);
        }
    }

    // Resize each vector independently so both always end up exactly
    // `read_length` long, in release builds too, even if their lengths diverged.
    seq.resize(read_length, b'N');
    qual.resize(read_length, 0);
}
/// Per-cycle substitution error rates used by [`inject_cycle_errors`].
///
/// Entry `i` is the probability that the base at read position (cycle) `i` is
/// substituted.
#[derive(Debug, Clone, PartialEq)]
pub struct CycleErrorCurve {
    curve: Vec<f64>,
}
impl CycleErrorCurve {
    /// Builds a curve with the same substitution rate at each of `read_length`
    /// cycles.
    pub fn flat(read_length: usize, base_error_rate: f64) -> Self {
        Self {
            curve: vec![base_error_rate; read_length],
        }
    }

    /// Builds a curve that is flat up to a tail and then rises exponentially.
    ///
    /// The tail starts at cycle `floor(tail_start_fraction * read_length)`
    /// (negative or NaN fractions saturate to 0). Before it every cycle uses
    /// `base_error_rate`; from it onward the rate is
    /// `base_error_rate * tail_rate_multiplier.powf(t)`, with `t` rising
    /// linearly from 0 at the tail start to 1 at the last cycle. When the tail
    /// spans at least two cycles, the last cycle's rate is therefore
    /// `base_error_rate * tail_rate_multiplier`.
    ///
    /// `tail_rate_multiplier` should be positive. Its natural log is used as the
    /// growth constant, so 0 or a negative value produces 0 or NaN rates in the
    /// tail.
    pub fn exponential(
        read_length: usize,
        base_error_rate: f64,
        tail_start_fraction: f64,
        tail_rate_multiplier: f64,
    ) -> Self {
        let tail_start = (tail_start_fraction * read_length as f64) as usize;
        let k = tail_rate_multiplier.ln();
        // Distance from the tail start to the last cycle. It is 1.0 for a tail
        // of at most one cycle, so the division below never divides by zero.
        let denom = if read_length > 1 && tail_start < read_length - 1 {
            (read_length - 1 - tail_start) as f64
        } else {
            1.0
        };
        let curve = (0..read_length)
            .map(|i| {
                if i < tail_start {
                    base_error_rate
                } else {
                    let t = (i - tail_start) as f64 / denom;
                    base_error_rate * (k * t).exp()
                }
            })
            .collect();
        Self { curve }
    }

    /// Loads a curve from a tab-separated file of `cycle<TAB>rate` lines.
    ///
    /// Blank lines are skipped and fields are whitespace-trimmed. The points
    /// are sorted by cycle (stably, so duplicate cycles keep their file order).
    /// The curve gets one entry per cycle in `0..read_length`: linearly
    /// interpolated between neighbouring points, and held constant before the
    /// first point and after the last one.
    ///
    /// # Errors
    /// Fails if any of the following holds:
    /// * the file cannot be opened or read;
    /// * a line has no tab-separated rate column;
    /// * either field does not parse;
    /// * the file contains no data lines.
    ///
    /// Per-line errors report the file path and 1-based line number.
    pub fn from_tsv(path: &Path, read_length: usize) -> Result<Self> {
        let file = std::fs::File::open(path)?;
        let reader = io::BufReader::new(file);
        let mut points: Vec<(usize, f64)> = Vec::new();
        for (idx, line) in reader.lines().enumerate() {
            let line_no = idx + 1;
            let line = line?;
            let line = line.trim();
            if line.is_empty() {
                continue;
            }
            // `splitn(2, ..)` keeps everything after the first tab in the rate
            // field, so a stray extra column surfaces as a rate parse error.
            let mut parts = line.splitn(2, '\t');
            let cycle_field = parts
                .next()
                .ok_or_else(|| {
                    anyhow::anyhow!(
                        "{}:{}: missing cycle column in TSV",
                        path.display(),
                        line_no
                    )
                })?
                .trim();
            let cycle = cycle_field.parse::<usize>().map_err(|e| {
                anyhow::anyhow!(
                    "{}:{}: invalid cycle {:?}: {}",
                    path.display(),
                    line_no,
                    cycle_field,
                    e
                )
            })?;
            let rate_field = parts
                .next()
                .ok_or_else(|| {
                    anyhow::anyhow!(
                        "{}:{}: missing rate column in TSV",
                        path.display(),
                        line_no
                    )
                })?
                .trim();
            let rate = rate_field.parse::<f64>().map_err(|e| {
                anyhow::anyhow!(
                    "{}:{}: invalid rate {:?}: {}",
                    path.display(),
                    line_no,
                    rate_field,
                    e
                )
            })?;
            points.push((cycle, rate));
        }
        anyhow::ensure!(
            !points.is_empty(),
            "cycle error TSV is empty: {}",
            path.display()
        );
        // Stable sort: `interpolate` relies on ordering, and duplicates keep
        // their relative file order.
        points.sort_by_key(|&(c, _)| c);
        let curve = (0..read_length)
            .map(|i| Self::interpolate(&points, i))
            .collect();
        Ok(Self { curve })
    }

    /// Per-cycle substitution rates; index `i` is cycle `i`.
    pub fn rates(&self) -> &[f64] {
        &self.curve
    }

    /// Builds a curve from the first `read_length` values of `iter`. If the
    /// iterator is shorter, the curve is padded with 0.0.
    ///
    /// At most `read_length` items are consumed, so infinite iterators such as
    /// `std::iter::repeat(rate)` are accepted.
    pub fn from_rates(iter: impl Iterator<Item = f64>, read_length: usize) -> Self {
        let mut curve: Vec<f64> = iter.take(read_length).collect();
        curve.resize(read_length, 0.0);
        Self { curve }
    }

    /// Linearly interpolates the rate at cycle `i` from `points`.
    ///
    /// `points` must be non-empty and sorted by cycle. Cycles at or before the
    /// first point return its rate; cycles at or after the last point return
    /// the last rate.
    fn interpolate(points: &[(usize, f64)], i: usize) -> f64 {
        if i <= points[0].0 {
            return points[0].1;
        }
        let last = points[points.len() - 1];
        if i >= last.0 {
            return last.1;
        }
        // `pos` is the first point with cycle > i. Because
        // points[0].0 < i < last.0, it satisfies 1 <= pos < len, and
        // c1 > i >= c0, so the divisor below is never zero.
        let pos = points.partition_point(|&(c, _)| c <= i);
        let (c0, r0) = points[pos - 1];
        let (c1, r1) = points[pos];
        let t = (i - c0) as f64 / (c1 - c0) as f64;
        r0 + t * (r1 - r0)
    }
}
/// Applies independent, cycle-dependent substitution errors to `seq` in place.
///
/// Position `i` is substituted with probability equal to entry `i` of the
/// curve. Positions past the end of the curve use probability 0. One uniform
/// draw is consumed for every position regardless of its rate. A substituted
/// `A`/`C`/`G`/`T` becomes one of the other three bases, chosen uniformly.
pub fn inject_cycle_errors(seq: &mut [u8], model: &CycleErrorCurve, rng: &mut impl Rng) {
    let rates = model.curve.as_slice();
    for (cycle, slot) in seq.iter_mut().enumerate() {
        let p_error = rates.get(cycle).copied().unwrap_or(0.0);
        // Always draw, even when p_error is 0, so the RNG stream advances by
        // exactly one value per position.
        if rng.random::<f64>() < p_error {
            let substitutes: [u8; 3] = match *slot {
                b'A' => [b'C', b'G', b'T'],
                b'C' => [b'A', b'G', b'T'],
                b'G' => [b'A', b'C', b'T'],
                b'T' => [b'A', b'C', b'G'],
                // NOTE(review): any other byte (e.g. `N` padding or lowercase
                // bases) can only become A, C or G here, never T. Confirm this
                // is intended.
                _ => [b'A', b'C', b'G'],
            };
            *slot = substitutes[rng.random_range(0..3)];
        }
    }
}
/// Adjustments applied only to read 2 of a pair.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StrandBiasModel {
    // `dead_code` is allowed because nothing visible in this file reads this
    // field. NOTE(review): confirm whether it is meant to scale R2 error rates.
    #[allow(dead_code)]
    pub r2_error_multiplier: f64,
    /// Amount subtracted from every R2 quality; negative values raise qualities.
    pub r2_quality_offset: i8,
}

impl StrandBiasModel {
    /// Shifts every R2 base quality in `qual` down by `r2_quality_offset`.
    ///
    /// * Positive offset: qualities are lowered, saturating at 0.
    /// * Negative offset: qualities are raised (saturating at 255), then capped
    ///   at 93.
    /// * Zero: `qual` is left unchanged.
    ///
    /// The cap of 93 is applied only when raising; lowering does not clamp
    /// values that were already above 93.
    pub fn apply_to_r2_qual(&self, qual: &mut [u8]) {
        // NOTE(review): 93 is presumably the highest Phred score printable in
        // Phred+33 FASTQ ('~'); confirm against the FASTQ writer.
        const MAX_QUAL: u8 = 93;
        let offset = self.r2_quality_offset;
        // `unsigned_abs` is exact for every i8, including i8::MIN (-> 128).
        let delta = offset.unsigned_abs();
        // The sign of the offset is loop-invariant, so branch once up front.
        match offset.cmp(&0) {
            std::cmp::Ordering::Greater => {
                for q in qual.iter_mut() {
                    *q = q.saturating_sub(delta);
                }
            }
            std::cmp::Ordering::Less => {
                for q in qual.iter_mut() {
                    *q = q.saturating_add(delta).min(MAX_QUAL);
                }
            }
            std::cmp::Ordering::Equal => {}
        }
    }
}
/// Parameters for [`inject_burst_errors`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CorrelatedErrorModel {
    /// Per-position probability of starting a burst (checked only while no
    /// burst is active).
    pub burst_rate: f64,
    /// Approximate mean burst length. The continuation probability is
    /// `1 - 1 / burst_length_mean` and the draw is truncated at the read end,
    /// so values in (0, 1] yield 1-base bursts.
    pub burst_length_mean: f64,
}
/// Overwrites random stretches of `seq` with a single repeated wrong base and
/// marks each affected position with quality 12 in `qual`.
///
/// While no burst is active, each position starts one with probability
/// `model.burst_rate`. The burst length is drawn geometrically: each extra base
/// is added with probability `1 - 1 / model.burst_length_mean`, and the burst
/// never runs past the end of the read. The replacement base is drawn uniformly
/// from `A`/`C`/`G`/`T`, excluding the base at the burst's first position, and
/// that same base is written to every position of the burst.
///
/// # Panics
/// Panics if a burst reaches a position where `qual` has no entry (i.e. `qual`
/// is shorter than `seq`).
pub fn inject_burst_errors(
    seq: &mut [u8],
    qual: &mut [u8],
    model: &CorrelatedErrorModel,
    rng: &mut impl Rng,
) {
    const BASES: [u8; 4] = [b'A', b'C', b'G', b'T'];
    const BURST_QUAL: u8 = 12;

    let read_len = seq.len();
    let continue_prob = 1.0 - 1.0 / model.burst_length_mean;
    // Positions still to overwrite after the current one, and the base to use.
    let mut pending = 0usize;
    let mut fill_base = b'A';

    for pos in 0..read_len {
        if pending > 0 {
            // Still inside a burst: extend it with the same base.
            pending -= 1;
        } else if rng.random::<f64>() < model.burst_rate {
            // Start a new burst; it may cover at most the rest of the read.
            let room = read_len - pos;
            let mut span = 1usize;
            while span < room && rng.random::<f64>() < continue_prob {
                span += 1;
            }
            let template = seq[pos];
            // Rejection-sample a base that differs from the one being replaced.
            fill_base = loop {
                let pick = BASES[rng.random_range(0..4)];
                if pick != template {
                    break pick;
                }
            };
            pending = span - 1;
        } else {
            continue;
        }
        seq[pos] = fill_base;
        qual[pos] = BURST_QUAL;
    }
}
/// Applies k-mer-context-dependent substitution errors to `seq` in place.
///
/// Position `i` is substituted with probability
/// `base_rate * model.sub_multiplier_at(seq, i)`. The multiplier is looked up
/// on the sequence *as it was passed in*. Positions whose effective rate is
/// `<= 0` are skipped without drawing randomness. A substituted `A`/`C`/`G`/`T`
/// becomes one of the other three bases, chosen uniformly.
pub fn inject_context_errors(
    seq: &mut [u8],
    base_rate: f64,
    model: &KmerErrorModel,
    rng: &mut impl Rng,
) {
    // Look up every context before mutating anything. Reading contexts while
    // the sequence is being edited would let a substitution made earlier in
    // this pass change the k-mer (and so the error rate) seen by the next k - 1
    // positions. Errors would then cascade instead of depending only on the
    // input sequence.
    let multipliers: Vec<f64> = (0..seq.len())
        .map(|i| f64::from(model.sub_multiplier_at(seq, i)))
        .collect();
    for (base, multiplier) in seq.iter_mut().zip(multipliers) {
        let effective_rate = base_rate * multiplier;
        if effective_rate <= 0.0 {
            continue;
        }
        if rng.random::<f64>() < effective_rate {
            let alts: [u8; 3] = match *base {
                b'A' => [b'C', b'G', b'T'],
                b'C' => [b'A', b'G', b'T'],
                b'G' => [b'A', b'C', b'T'],
                b'T' => [b'A', b'C', b'G'],
                // NOTE(review): non-ACGT bytes (e.g. `N`) can only become A, C
                // or G here, never T. Confirm this is intended.
                _ => [b'A', b'C', b'G'],
            };
            *base = alts[rng.random_range(0..3)];
        }
    }
}
/// Maps a nucleotide byte to its 2-bit code, case-insensitively:
/// A=0, C=1, G=2, T=3.
///
/// Any other byte, including `N`, also maps to 0, so it is indistinguishable
/// from `A` in k-mer indices.
fn base_to_bits(b: u8) -> usize {
    match b.to_ascii_uppercase() {
        b'C' => 1,
        b'G' => 2,
        b'T' => 3,
        _ => 0,
    }
}
/// JSON layout of a k-mer error profile, as read by
/// `KmerErrorModel::from_profile_json`.
#[derive(Debug, serde::Deserialize)]
pub struct KmerProfileJson {
    /// Length of every context string in `rules`.
    pub kmer_length: usize,
    /// Per-context multiplier overrides; contexts not listed keep 1.0.
    pub rules: Vec<KmerRuleJson>,
}

/// One context-specific override inside a [`KmerProfileJson`].
#[derive(Debug, serde::Deserialize)]
pub struct KmerRuleJson {
    /// The k-mer ending at the affected base, e.g. `"GGC"` for k = 3.
    pub context: String,
    /// Factor applied to the substitution rate in this context.
    pub sub_multiplier: f32,
    /// Factor stored for the indel rate in this context.
    pub indel_multiplier: f32,
}
/// Context-dependent error multipliers indexed by the k-mer ending at a
/// position.
///
/// Each table is indexed by the k-mer packed two bits per base
/// (A=0, C=1, G=2, T=3), with the first base in the most significant bits.
#[derive(Debug, Clone, PartialEq)]
pub struct KmerErrorModel {
    /// Context length.
    pub k: usize,
    /// Substitution-rate multiplier per k-mer index.
    sub_multipliers: Vec<f32>,
    /// Indel-rate multiplier per k-mer index.
    indel_multipliers: Vec<f32>,
}
impl KmerErrorModel {
    /// Largest `kmer_length` that `from_profile_json` accepts.
    ///
    /// Both tables hold `4^k` `f32`s, so k = 12 is already ~16.8M entries per
    /// table. A larger value read from a profile file would otherwise exhaust
    /// memory or overflow the table-size shift in `uniform`.
    const MAX_PROFILE_K: usize = 12;

    /// Creates a model for contexts of length `k` with every multiplier 1.0.
    ///
    /// Allocates two tables of `4^k` (`1 << (k * 2)`) entries. Very large `k`
    /// either exhausts memory or overflows that shift (from k = 32 on 64-bit
    /// targets).
    pub fn uniform(k: usize) -> Self {
        let size = 1 << (k * 2);
        Self {
            k,
            sub_multipliers: vec![1.0f32; size],
            indel_multipliers: vec![1.0f32; size],
        }
    }

    /// Overrides both multipliers for one context.
    ///
    /// Bytes outside `ACGTacgt` are encoded as `A`, so e.g. `"GGN"` updates the
    /// `"GGA"` entry.
    ///
    /// # Panics
    /// Panics if `context.len()` (in bytes) differs from `self.k`.
    pub fn set_rule(&mut self, context: &str, sub_multiplier: f32, indel_multiplier: f32) {
        assert_eq!(
            context.len(),
            self.k,
            "context length {} != k={}",
            context.len(),
            self.k
        );
        let idx = self.kmer_index(context.as_bytes());
        self.sub_multipliers[idx] = sub_multiplier;
        self.indel_multipliers[idx] = indel_multiplier;
    }

    /// Packs `bases` into a table index: two bits per base, first base most
    /// significant.
    fn kmer_index(&self, bases: &[u8]) -> usize {
        bases
            .iter()
            .fold(0usize, |acc, &b| (acc << 2) | base_to_bits(b))
    }

    /// Substitution multiplier for the k-mer ending at `pos`, i.e.
    /// `seq[pos + 1 - k..=pos]`.
    ///
    /// Returns 1.0 for the first `k - 1` positions, where no full context
    /// exists. May panic if `pos >= seq.len()`.
    pub fn sub_multiplier_at(&self, seq: &[u8], pos: usize) -> f32 {
        if pos + 1 < self.k {
            return 1.0;
        }
        let start = pos + 1 - self.k;
        let idx = self.kmer_index(&seq[start..=pos]);
        self.sub_multipliers[idx]
    }

    /// Indel multiplier for the k-mer ending at `pos`, i.e.
    /// `seq[pos + 1 - k..=pos]`.
    ///
    /// Returns 1.0 for the first `k - 1` positions. May panic if
    /// `pos >= seq.len()`.
    // `dead_code` is allowed because nothing visible in this file calls this
    // accessor. NOTE(review): confirm whether the indel path should use it.
    #[allow(dead_code)]
    pub fn indel_multiplier_at(&self, seq: &[u8], pos: usize) -> f32 {
        if pos + 1 < self.k {
            return 1.0;
        }
        let start = pos + 1 - self.k;
        let idx = self.kmer_index(&seq[start..=pos]);
        self.indel_multipliers[idx]
    }

    /// Loads a model from a JSON k-mer profile.
    ///
    /// Starts from a uniform model of length `kmer_length` and applies every
    /// rule in file order, so later duplicates override earlier ones.
    ///
    /// # Errors
    /// Fails if any of the following holds:
    /// * the file cannot be read or parsed;
    /// * `kmer_length` exceeds [`Self::MAX_PROFILE_K`];
    /// * a rule's context is not exactly `kmer_length` bytes long;
    /// * a rule's context contains a byte other than `A`/`C`/`G`/`T`
    ///   (case-insensitive).
    pub fn from_profile_json(path: &std::path::Path) -> anyhow::Result<Self> {
        let text = std::fs::read_to_string(path)?;
        let profile: KmerProfileJson = serde_json::from_str(&text)?;
        let k = profile.kmer_length;
        // Reject an oversized k before `uniform` tries to allocate 4^k entries.
        anyhow::ensure!(
            k <= Self::MAX_PROFILE_K,
            "k-mer profile {}: kmer_length {} exceeds the supported maximum of {}",
            path.display(),
            k,
            Self::MAX_PROFILE_K
        );
        let mut model = Self::uniform(k);
        for (i, rule) in profile.rules.iter().enumerate() {
            // Validate file input here. Otherwise a bad context would hit the
            // panic in `set_rule`, or unknown bases would be silently encoded
            // as `A`.
            anyhow::ensure!(
                rule.context.len() == k,
                "k-mer profile {}: rule {} context {:?} has length {}, expected {}",
                path.display(),
                i,
                rule.context,
                rule.context.len(),
                k
            );
            anyhow::ensure!(
                rule.context
                    .bytes()
                    .all(|b| matches!(b.to_ascii_uppercase(), b'A' | b'C' | b'G' | b'T')),
                "k-mer profile {}: rule {} context {:?} contains a base other than A/C/G/T",
                path.display(),
                i,
                rule.context
            );
            model.set_rule(&rule.context, rule.sub_multiplier, rule.indel_multiplier);
        }
        Ok(model)
    }
}
// Unit tests for the error-injection models. Most tests drive the injectors
// with a fixed-seed `StdRng`, so their outcomes (and some thresholds) depend on
// the exact order in which each injector consumes random numbers.
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    // Builds an all-`A` read of length `len` with uniform quality 30.
    fn make_read(len: usize) -> (Vec<u8>, Vec<u8>) {
        let seq = vec![b'A'; len];
        let qual = vec![30u8; len];
        (seq, qual)
    }

    // With indel_rate 0.0 no event is ever drawn, and the input already has
    // read_length entries, so seq and qual must come back unchanged.
    #[test]
    fn test_indel_rate_zero_leaves_read_unchanged() {
        let model = IndelErrorModel {
            indel_rate: 0.0,
            insertion_fraction: 0.5,
            max_length: 3,
        };
        let read_length = 50;
        let mut rng = StdRng::seed_from_u64(42);
        let original_seq = vec![b'A'; read_length];
        let original_qual = vec![30u8; read_length];
        let mut seq = original_seq.clone();
        let mut qual = original_qual.clone();
        inject_indel_errors(&mut seq, &mut qual, read_length, &model, &mut rng);
        assert_eq!(seq, original_seq, "seq should be unchanged at rate 0.0");
        assert_eq!(qual, original_qual, "qual should be unchanged at rate 0.0");
    }

    // At rate 0.5 almost every 20-base read gets events. On an all-A read,
    // deletions show up as trailing `N` padding and most insertions as C/G/T
    // bases, so >90% of reads should differ from the input.
    #[test]
    fn test_high_indel_rate_modifies_reads() {
        let model = IndelErrorModel {
            indel_rate: 0.5,
            insertion_fraction: 0.5,
            max_length: 2,
        };
        let read_length = 20;
        let mut rng = StdRng::seed_from_u64(99);
        let mut changed = 0usize;
        let n_reads = 1_000;
        for _ in 0..n_reads {
            let original = vec![b'A'; read_length];
            let mut seq = original.clone();
            let mut qual = vec![30u8; read_length];
            inject_indel_errors(&mut seq, &mut qual, read_length, &model, &mut rng);
            if seq != original {
                changed += 1;
            }
        }
        assert!(
            changed > 900,
            "expected >90% reads modified at rate 0.5, got {}/{}",
            changed,
            n_reads
        );
    }

    // Across many reads with mixed insertions and deletions, both outputs must
    // always be exactly read_length long.
    #[test]
    fn test_fixed_length_contract() {
        let model = IndelErrorModel {
            indel_rate: 0.1,
            insertion_fraction: 0.5,
            max_length: 3,
        };
        let read_length = 100;
        let mut rng = StdRng::seed_from_u64(7);
        for _ in 0..10_000 {
            let (mut seq, mut qual) = make_read(read_length);
            inject_indel_errors(&mut seq, &mut qual, read_length, &model, &mut rng);
            assert_eq!(
                seq.len(),
                read_length,
                "seq length {} != {}",
                seq.len(),
                read_length
            );
            assert_eq!(
                qual.len(),
                read_length,
                "qual length {} != {}",
                qual.len(),
                read_length
            );
        }
    }

    // With insertion_fraction 1.0 only insertions happen, so length is restored
    // purely by truncation, and at least one inserted base should be non-A.
    #[test]
    fn test_only_insertions_when_fraction_one() {
        let model = IndelErrorModel {
            indel_rate: 0.3,
            insertion_fraction: 1.0,
            max_length: 1,
        };
        let read_length = 20;
        let mut rng = StdRng::seed_from_u64(7);
        let mut has_non_a = false;
        for _ in 0..100 {
            let mut seq = vec![b'A'; read_length];
            let mut qual = vec![30u8; read_length];
            inject_indel_errors(&mut seq, &mut qual, read_length, &model, &mut rng);
            assert_eq!(seq.len(), read_length, "fixed-length contract violated");
            if seq.iter().any(|&b| b != b'A') {
                has_non_a = true;
            }
        }
        assert!(
            has_non_a,
            "insertions should produce non-A bases in at least one read"
        );
    }

    // Rate 1.0 selects every position for an insertion.
    // NOTE(review): despite its name, this test does not inspect the
    // event-length distribution; it only checks that >50% of reads change.
    #[test]
    fn test_length_distribution() {
        let model = IndelErrorModel {
            indel_rate: 1.0,
            insertion_fraction: 1.0,
            max_length: 5,
        };
        let read_length = 20;
        let n_reads = 1_000;
        let mut rng = StdRng::seed_from_u64(13);
        let mut modified = 0usize;
        for _ in 0..n_reads {
            let original = vec![b'A'; read_length];
            let mut seq = original.clone();
            let mut qual = vec![30u8; read_length];
            inject_indel_errors(&mut seq, &mut qual, read_length, &model, &mut rng);
            assert_eq!(seq.len(), read_length, "fixed-length contract violated");
            if seq != original {
                modified += 1;
            }
        }
        assert!(
            modified > n_reads / 2,
            "expected >50% reads modified at rate 1.0, got {}/{}",
            modified,
            n_reads
        );
    }

    // Insertion-only, length-1 events at rate 0.1 give roughly two insertions
    // per 20-base read. Inserted bases are uniform over ACGT (so ~3/4 are non-A),
    // and some are pushed off the end by truncation. The resulting non-A
    // fraction is expected to land in [0.04, 0.09].
    #[test]
    fn test_indel_rate_statistical() {
        let model = IndelErrorModel {
            indel_rate: 0.1,
            insertion_fraction: 1.0,
            max_length: 1,
        };
        let mut rng = StdRng::seed_from_u64(12345);
        let read_length = 20usize;
        let n_reads = 10_000usize;
        let mut non_a_count = 0usize;
        let total_bases = n_reads * read_length;
        for _ in 0..n_reads {
            let mut seq = vec![b'A'; read_length];
            let mut qual = vec![30u8; read_length];
            inject_indel_errors(&mut seq, &mut qual, read_length, &model, &mut rng);
            non_a_count += seq.iter().filter(|&&b| b != b'A').count();
        }
        let observed_rate = non_a_count as f64 / total_bases as f64;
        assert!(
            (0.04..=0.09).contains(&observed_rate),
            "expected non-A rate in [0.04, 0.09], got {:.4}",
            observed_rate
        );
    }

    // A positive offset of 3 lowers every quality from 40 to 37.
    #[test]
    fn test_strand_bias_lowers_quality() {
        let model = StrandBiasModel {
            r2_error_multiplier: 1.3,
            r2_quality_offset: 3,
        };
        let mut qual = vec![40u8; 20];
        model.apply_to_r2_qual(&mut qual);
        assert!(
            qual.iter().all(|&q| q == 37),
            "all qualities should be 37 after subtracting offset 3 from 40"
        );
    }

    // An offset of 0 must leave qualities untouched.
    #[test]
    fn test_strand_bias_zero_offset_noop() {
        let model = StrandBiasModel {
            r2_error_multiplier: 1.3,
            r2_quality_offset: 0,
        };
        let original = vec![30u8, 25u8, 40u8, 10u8, 5u8];
        let mut qual = original.clone();
        model.apply_to_r2_qual(&mut qual);
        assert_eq!(qual, original, "zero offset must leave quality unchanged");
    }

    // A negative offset raises qualities (30 -> 35, well below the cap).
    #[test]
    fn test_strand_bias_negative_offset_raises_quality() {
        let model = StrandBiasModel {
            r2_error_multiplier: 1.0,
            r2_quality_offset: -5,
        };
        let mut qual = vec![30u8; 10];
        model.apply_to_r2_qual(&mut qual);
        assert!(
            qual.iter().all(|&q| q == 35),
            "negative offset -5 should raise quality from 30 to 35"
        );
    }

    // On an all-A template a burst's base is never A, so the first non-A run
    // is a burst (possibly merged with an adjacent one using the same base).
    // With mean length 5, the average length of that run should exceed 1.5.
    #[test]
    fn test_burst_errors_correlated() {
        let model = CorrelatedErrorModel {
            burst_rate: 0.1,
            burst_length_mean: 5.0,
        };
        let mut rng = StdRng::seed_from_u64(42);
        let n_reads = 10_000usize;
        let read_len = 50usize;
        let mut total_run_len = 0usize;
        let mut run_count = 0usize;
        for _ in 0..n_reads {
            let mut seq = vec![b'A'; read_len];
            let mut qual = vec![30u8; read_len];
            inject_burst_errors(&mut seq, &mut qual, &model, &mut rng);
            if let Some(start) = seq.iter().position(|&b| b != b'A') {
                let run_base = seq[start];
                let run_len = seq[start..].iter().take_while(|&&b| b == run_base).count();
                total_run_len += run_len;
                run_count += 1;
            }
        }
        assert!(run_count > 0, "expected some reads to have bursts");
        let avg_run = total_run_len as f64 / run_count as f64;
        assert!(
            avg_run > 1.5,
            "expected average burst run length > 1.5, got {:.3}",
            avg_run
        );
    }

    // Looks for at least one burst longer than one base.
    // NOTE(review): `run_len` is computed with take_while on equality to
    // burst_base, so the inner `all` assertion holds by construction. The
    // meaningful check is only that some multi-base burst exists.
    #[test]
    fn test_burst_base_consistent() {
        let model = CorrelatedErrorModel {
            burst_rate: 0.5,
            burst_length_mean: 4.0,
        };
        let mut rng = StdRng::seed_from_u64(99);
        let read_len = 50usize;
        let mut found_multi_base_burst = false;
        for _ in 0..1_000 {
            let mut seq = vec![b'A'; read_len];
            let mut qual = vec![30u8; read_len];
            inject_burst_errors(&mut seq, &mut qual, &model, &mut rng);
            if let Some(start) = seq.iter().position(|&b| b != b'A') {
                let burst_base = seq[start];
                let run_len = seq[start..]
                    .iter()
                    .take_while(|&&b| b == burst_base)
                    .count();
                if run_len > 1 {
                    assert!(
                        seq[start..start + run_len].iter().all(|&b| b == burst_base),
                        "burst bases are not all identical"
                    );
                    found_multi_base_burst = true;
                    break;
                }
            }
        }
        assert!(
            found_multi_base_burst,
            "expected to find at least one burst longer than 1 base"
        );
    }

    // burst_rate 0.0 never starts a burst, so nothing may change.
    #[test]
    fn test_burst_rate_zero_noop() {
        let model = CorrelatedErrorModel {
            burst_rate: 0.0,
            burst_length_mean: 3.0,
        };
        let mut rng = StdRng::seed_from_u64(7);
        let read_len = 50usize;
        for _ in 0..1_000 {
            let original_seq = vec![b'A'; read_len];
            let original_qual = vec![30u8; read_len];
            let mut seq = original_seq.clone();
            let mut qual = original_qual.clone();
            inject_burst_errors(&mut seq, &mut qual, &model, &mut rng);
            assert_eq!(seq, original_seq, "burst_rate 0.0 must not change seq");
            assert_eq!(qual, original_qual, "burst_rate 0.0 must not change qual");
        }
    }

    // Uses indirect evidence: C/G/T bases count as insertions and trailing `N`
    // padding counts as deletions.
    // NOTE(review): inserted bases are A one time in four, some are truncated
    // away, and an insertion plus a deletion in the same read cancel out. By a
    // back-of-envelope estimate this proxy's expected value is about 0.64
    // rather than 0.7. The assertion is also skipped entirely when
    // total_evidence <= 100.
    #[test]
    fn test_insertion_fraction_statistical() {
        let model = IndelErrorModel {
            indel_rate: 0.01,
            insertion_fraction: 0.7,
            max_length: 1,
        };
        let mut rng = StdRng::seed_from_u64(54321);
        let read_length = 20usize;
        let n_reads = 10_000usize;
        let mut insertion_evidence = 0usize;
        let mut deletion_evidence = 0usize;
        for _ in 0..n_reads {
            let mut seq = vec![b'A'; read_length];
            let mut qual = vec![30u8; read_length];
            inject_indel_errors(&mut seq, &mut qual, read_length, &model, &mut rng);
            for &b in &seq {
                if b != b'A' && b != b'N' {
                    insertion_evidence += 1;
                }
                if b == b'N' {
                    deletion_evidence += 1;
                }
            }
        }
        let total_evidence = insertion_evidence + deletion_evidence;
        if total_evidence > 100 {
            let observed_insertion_fraction = insertion_evidence as f64 / total_evidence as f64;
            assert!(
                (0.60..=0.80).contains(&observed_insertion_fraction),
                "expected insertion fraction ~0.7, got {:.4}",
                observed_insertion_fraction
            );
        }
    }

    // Every position is substituted with probability 0.1, and substituting an
    // A always yields a non-A base, so the non-A fraction should be ~0.1.
    #[test]
    fn test_flat_curve_rate() {
        let model = CycleErrorCurve::flat(50, 0.1);
        let mut rng = StdRng::seed_from_u64(1001);
        let n_reads = 10_000usize;
        let read_length = 50usize;
        let mut non_a = 0usize;
        for _ in 0..n_reads {
            let mut seq = vec![b'A'; read_length];
            inject_cycle_errors(&mut seq, &model, &mut rng);
            non_a += seq.iter().filter(|&&b| b != b'A').count();
        }
        let rate = non_a as f64 / (n_reads * read_length) as f64;
        assert!(
            (0.09..=0.11).contains(&rate),
            "expected flat rate in [0.09, 0.11], got {:.4}",
            rate
        );
    }

    // The tail starts at cycle 80 (0.8 * 100). The last cycle should be
    // base * multiplier = 0.01 * 10, and cycle 0 stays at the base rate.
    #[test]
    fn test_exponential_tail_rises() {
        let model = CycleErrorCurve::exponential(100, 0.01, 0.8, 10.0);
        let expected_last = 0.1f64;
        let actual_last = model.curve[99];
        assert!(
            (actual_last - expected_last).abs() / expected_last < 0.05,
            "expected curve[99] ≈ {:.4}, got {:.6}",
            expected_last,
            actual_last
        );
        assert_eq!(
            model.curve[0], 0.01,
            "curve[0] should equal base_error_rate"
        );
    }

    // Both constructors produce exactly one rate per cycle.
    #[test]
    fn test_exponential_curve_len() {
        let flat = CycleErrorCurve::flat(75, 0.005);
        assert_eq!(flat.curve.len(), 75, "flat curve length mismatch");
        let exp = CycleErrorCurve::exponential(120, 0.005, 0.7, 8.0);
        assert_eq!(exp.curve.len(), 120, "exponential curve length mismatch");
    }

    // With rate 1.0 every A must be replaced by C, G or T, never by A again.
    #[test]
    fn test_substitution_not_identity() {
        let model = CycleErrorCurve::flat(1, 1.0);
        let mut rng = StdRng::seed_from_u64(7777);
        for _ in 0..10_000 {
            let original = b'A';
            let mut seq = vec![original];
            inject_cycle_errors(&mut seq, &model, &mut rng);
            assert_ne!(
                seq[0], original,
                "substituted base must differ from original"
            );
        }
    }

    // A uniform model returns 1.0 everywhere, including the first k - 1
    // positions that lack a full context.
    #[test]
    fn test_uniform_model_all_ones() {
        let model = KmerErrorModel::uniform(3);
        assert!(
            model.sub_multipliers.iter().all(|&v| v == 1.0f32),
            "all sub_multipliers should be 1.0"
        );
        let seq = b"ACGTACGT";
        for pos in 0..seq.len() {
            assert_eq!(
                model.sub_multiplier_at(seq, pos),
                1.0f32,
                "expected 1.0 at pos {}",
                pos
            );
        }
    }

    // pos 5 ends the GGG k-mer; pos 2 ends AAA.
    #[test]
    fn test_set_rule_lookup() {
        let mut model = KmerErrorModel::uniform(3);
        model.set_rule("GGG", 5.0, 1.0);
        let seq = b"AAAGGG";
        assert_eq!(
            model.sub_multiplier_at(seq, 5),
            5.0f32,
            "expected 5.0 at GGG context"
        );
        assert_eq!(
            model.sub_multiplier_at(seq, 2),
            1.0f32,
            "expected 1.0 at AAA context"
        );
    }

    // NOTE(review): the reference value is computed with the same
    // `kmer_index` helper and slicing that `sub_multiplier_at` uses, so this
    // mostly re-checks the k - 1 boundary. No rolling hash exists in
    // `sub_multiplier_at`; it recomputes the index from a slice on every call.
    #[test]
    fn test_rolling_hash_matches_naive() {
        let mut model = KmerErrorModel::uniform(3);
        model.set_rule("GGC", 2.0, 3.0);
        model.set_rule("TTT", 4.0, 1.5);
        let alphabet = [b'A', b'C', b'G', b'T'];
        let mut rng = StdRng::seed_from_u64(42);
        use rand::Rng as _;
        for _ in 0..100 {
            let seq: Vec<u8> = (0..20).map(|_| alphabet[rng.random_range(0..4)]).collect();
            for pos in 0..seq.len() {
                let via_fn = model.sub_multiplier_at(&seq, pos);
                let naive = if pos + 1 < model.k {
                    1.0f32
                } else {
                    let start = pos + 1 - model.k;
                    let idx = model.kmer_index(&seq[start..=pos]);
                    model.sub_multipliers[idx]
                };
                assert_eq!(via_fn, naive, "mismatch at pos {} in seq {:?}", pos, seq);
            }
        }
    }

    // With k = 2: pos 0 has no full context, pos 1 ends "AA", pos 2 ends "AG",
    // pos 3 ends "GG" (x20), pos 5 ends "TA".
    // NOTE(review): the name suggests it measures error counts, but it only
    // checks multipliers and never calls inject_context_errors.
    #[test]
    fn test_elevated_context_increases_errors() {
        let mut model = KmerErrorModel::uniform(2);
        model.set_rule("GG", 20.0, 1.0);
        let seq = b"AAGGTAA";
        assert_eq!(model.sub_multiplier_at(seq, 0), 1.0f32);
        assert_eq!(model.sub_multiplier_at(seq, 1), 1.0f32);
        assert_eq!(model.sub_multiplier_at(seq, 2), 1.0f32);
        assert_eq!(model.sub_multiplier_at(seq, 3), 20.0f32);
        assert_eq!(model.sub_multiplier_at(seq, 5), 1.0f32);
    }

    // Both tables must have 4^k entries for k = 1..=4.
    #[test]
    fn test_kmer_size_1_to_4() {
        for k in 1usize..=4 {
            let model = KmerErrorModel::uniform(k);
            let expected = 4usize.pow(k as u32);
            assert_eq!(
                model.sub_multipliers.len(),
                expected,
                "k={}: expected {} sub_multipliers, got {}",
                k,
                expected,
                model.sub_multipliers.len()
            );
            assert_eq!(
                model.indel_multipliers.len(),
                expected,
                "k={}: expected {} indel_multipliers, got {}",
                k,
                expected,
                model.indel_multipliers.len()
            );
        }
    }
}