//! Learn an empirical sequencing-error profile from an existing BAM file.
//!
//! `ProfileLearner` samples aligned records, tallies per-cycle base
//! qualities, insert sizes, and GC content, and assembles the result into a
//! `ProfileJson` for downstream use.

use std::collections::HashMap;
use std::io::BufReader;
use std::path::Path;
use anyhow::{Context, Result};
use noodles_bam as bam;
use noodles_sam::{
self as sam,
alignment::{record::Sequence as _, record_buf::RecordBuf},
};
use crate::core::error_profile::{ContextEffectJson, ProfileJson, QualityDistributionJson};
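/// Settings for profile learning: the maximum number of records to sample
/// and the minimum mapping quality a record must have to be counted.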
#[derive(Debug, Clone)]
pub struct LearnerConfig {
pub sample_size: usize,
pub min_mapq: u8,
}
impl Default for LearnerConfig {
fn default() -> Self {
Self {
sample_size: 1_000_000,
min_mapq: 20,
}
}
}
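/// Running tallies accumulated while scanning records.
///
/// Quality counts are kept per cycle (one `HashMap<quality, count>` per
/// position, split by read 1 / read 2), and `gc_bias` has 101 bins covering
/// GC percentages 0-100. The substitution and context maps are consumed by
/// `build_profile` but are not filled in by the BAM scan itself.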
#[derive(Debug)]
pub struct LearnerStats {
pub reads_examined: u64,
pub reads_skipped: u64,
pub quality_counts_r1: Vec<HashMap<u8, u64>>,
pub quality_counts_r2: Vec<HashMap<u8, u64>>,
pub insert_size_counts: HashMap<i32, u64>,
pub substitution_counts: HashMap<String, u64>,
pub gc_bias: Vec<(u64, u64)>,
pub context_error_counts: HashMap<String, u64>,
pub context_total_counts: HashMap<String, u64>,
pub max_read_length: usize,
}
impl LearnerStats {
fn new(max_read_length: usize) -> Self {
Self {
reads_examined: 0,
reads_skipped: 0,
quality_counts_r1: vec![HashMap::new(); max_read_length],
quality_counts_r2: vec![HashMap::new(); max_read_length],
insert_size_counts: HashMap::new(),
substitution_counts: HashMap::new(),
gc_bias: vec![(0, 0); 101],
context_error_counts: HashMap::new(),
context_total_counts: HashMap::new(),
max_read_length,
}
}
}
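/// Learns an empirical error profile from aligned reads.
///
/// A minimal usage sketch, assuming a BAM file at a hypothetical path
/// (marked `ignore` because the surrounding module path is omitted here):
///
/// ```ignore
/// let learner = ProfileLearner::new(LearnerConfig::default());
/// let profile = learner.learn_from_bam(Path::new("sample.bam"))?;
/// assert!(profile.read_length > 0);
/// ```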
pub struct ProfileLearner {
config: LearnerConfig,
}
impl ProfileLearner {
pub fn new(config: LearnerConfig) -> Self {
Self { config }
}
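    /// Streams records from `bam_path`, tallying up to `sample_size` usable
    /// reads, then builds a profile from the tallies.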
pub fn learn_from_bam(&self, bam_path: &Path) -> Result<ProfileJson> {
let file = std::fs::File::open(bam_path)
.with_context(|| format!("cannot open BAM: {}", bam_path.display()))?;
let mut reader = bam::io::Reader::new(BufReader::new(file));
let header = reader.read_header().context("failed to read BAM header")?;
let initial_read_length = read_length_from_header(&header).unwrap_or(150);
let mut stats = LearnerStats::new(initial_read_length);
for result in reader.record_bufs(&header) {
            if stats.reads_examined >= self.config.sample_size as u64 {
break;
}
let record = result.context("failed to read BAM record")?;
if should_skip(&record, self.config.min_mapq) {
stats.reads_skipped += 1;
continue;
}
let rlen = record.sequence().len();
            if rlen > stats.max_read_length {
                stats.quality_counts_r1.resize(rlen, HashMap::new());
                stats.quality_counts_r2.resize(rlen, HashMap::new());
                stats.max_read_length = rlen;
            }
let is_r2 = record.flags().is_last_segment();
let quals: Vec<u8> = record.quality_scores().as_ref().to_vec();
let target = if is_r2 {
&mut stats.quality_counts_r2
} else {
&mut stats.quality_counts_r1
};
for (pos, &q) in quals.iter().enumerate() {
if pos < target.len() {
*target[pos].entry(q).or_insert(0) += 1;
}
}
if !is_r2 {
let tlen = record.template_length();
if tlen > 0 {
*stats.insert_size_counts.entry(tlen).or_insert(0) += 1;
}
}
let seq_bytes: Vec<u8> = record.sequence().iter().collect();
if !seq_bytes.is_empty() {
let gc = gc_percent(&seq_bytes);
                stats.gc_bias[gc].0 += 1;
                stats.gc_bias[gc].1 += 1;
}
            stats.reads_examined += 1;
}
self.build_profile(stats)
}
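    /// Finalizes accumulated stats into a `ProfileJson`. Fails when no
    /// usable reads were examined; a read-2 distribution is emitted only if
    /// read-2 qualities were observed.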
pub fn build_profile(&self, stats: LearnerStats) -> Result<ProfileJson> {
anyhow::ensure!(
stats.reads_examined > 0,
"no usable reads found in BAM (all filtered or file empty)"
);
let read_length = stats.max_read_length;
let read1 = build_quality_distribution(&stats.quality_counts_r1, read_length);
let has_r2 = stats.quality_counts_r2.iter().any(|h| !h.is_empty());
let read2 = if has_r2 {
Some(build_quality_distribution(
&stats.quality_counts_r2,
read_length,
))
} else {
None
};
let substitution_matrix = build_substitution_matrix(&stats.substitution_counts);
let context_effects =
build_context_effects(&stats.context_error_counts, &stats.context_total_counts);
Ok(ProfileJson {
platform: None,
read_length,
quality_distribution: QualityDistributionJson { read1, read2 },
substitution_matrix,
context_effects,
})
}
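    /// Test-only variant of `learn_from_bam` that consumes in-memory
    /// records instead of reading a file.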
#[cfg(test)]
pub fn learn_from_records(
&self,
records: &[RecordBuf],
read_length: usize,
) -> Result<ProfileJson> {
let mut stats = LearnerStats::new(read_length);
for record in records {
            if stats.reads_examined >= self.config.sample_size as u64 {
                break;
            }
            if should_skip(record, self.config.min_mapq) {
                stats.reads_skipped += 1;
                continue;
            }
let rlen = record.sequence().len();
if rlen > stats.max_read_length {
stats.quality_counts_r1.resize(rlen, HashMap::new());
stats.quality_counts_r2.resize(rlen, HashMap::new());
stats.max_read_length = rlen;
}
let is_r2 = record.flags().is_last_segment();
let quals: Vec<u8> = record.quality_scores().as_ref().to_vec();
let target = if is_r2 {
&mut stats.quality_counts_r2
} else {
&mut stats.quality_counts_r1
};
for (pos, &q) in quals.iter().enumerate() {
if pos < target.len() {
*target[pos].entry(q).or_insert(0) += 1;
}
}
if !is_r2 {
let tlen = record.template_length();
if tlen > 0 {
*stats.insert_size_counts.entry(tlen).or_insert(0) += 1;
}
}
let seq_bytes: Vec<u8> = record.sequence().iter().collect();
if !seq_bytes.is_empty() {
let gc = gc_percent(&seq_bytes);
stats.gc_bias[gc].0 += 1;
stats.gc_bias[gc].1 += 1;
}
stats.reads_examined += 1;
}
self.build_profile(stats)
}
}
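/// Returns `true` for records that should not contribute to the profile:
/// unmapped, secondary, supplementary, duplicate, or mapped below
/// `min_mapq`.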
fn should_skip(record: &RecordBuf, min_mapq: u8) -> bool {
let flags = record.flags();
if flags.is_unmapped()
|| flags.is_supplementary()
|| flags.is_secondary()
|| flags.is_duplicate()
{
return true;
}
if let Some(mapq) = record.mapping_quality() {
if u8::from(mapq) < min_mapq {
return true;
}
}
false
}
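/// Integer GC percentage of a sequence, in `0..=100`. For example,
/// `gc_percent(b"ACGT")` is 50: two of the four bases are G or C.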
fn gc_percent(seq: &[u8]) -> usize {
if seq.is_empty() {
return 0;
}
let gc_count = seq
.iter()
.filter(|&&b| matches!(b, b'G' | b'C' | b'g' | b'c'))
.count();
(gc_count * 100 / seq.len()).min(100)
}
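/// Turns per-cycle quality counts into per-cycle `[quality, weight]` pairs
/// whose weights sum to 1.0. Cycles with no observations (including any
/// padding out to `read_length`) fall back to a single Q30 entry.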
fn build_quality_distribution(
counts: &[HashMap<u8, u64>],
read_length: usize,
) -> Vec<Vec<[f64; 2]>> {
let mut dist = Vec::with_capacity(read_length);
for pos_counts in counts.iter().take(read_length) {
let total: u64 = pos_counts.values().sum();
if total == 0 {
dist.push(vec![[30.0, 1.0]]);
continue;
}
let total_f = total as f64;
let mut entries: Vec<[f64; 2]> = pos_counts
.iter()
.map(|(&q, &cnt)| [q as f64, cnt as f64 / total_f])
.collect();
        // Sort by descending quality so the output order is deterministic
        // (HashMap iteration order is not).
        entries.sort_by(|a, b| {
            b[0].partial_cmp(&a[0]).unwrap_or(std::cmp::Ordering::Equal)
        });
dist.push(entries);
}
while dist.len() < read_length {
dist.push(vec![[30.0, 1.0]]);
}
dist
}
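/// Normalizes raw substitution counts into conditional rates. Keys are
/// expected to name the reference base first (e.g. `"AC"` for an A->C
/// substitution), so each count is divided by the total observed for its
/// `from` base.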
fn build_substitution_matrix(counts: &HashMap<String, u64>) -> HashMap<String, f64> {
let mut from_totals: HashMap<u8, u64> = HashMap::new();
for (key, &cnt) in counts {
if let Some(from) = key.as_bytes().first() {
*from_totals.entry(*from).or_insert(0) += cnt;
}
}
let mut matrix = HashMap::new();
for (key, &cnt) in counts {
if cnt == 0 {
continue;
}
let from = key.as_bytes()[0];
let total = *from_totals.get(&from).unwrap_or(&1);
if total > 0 {
matrix.insert(key.clone(), cnt as f64 / total as f64);
}
}
matrix
}
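/// Flags sequence contexts whose error rate is well above the global rate.
/// A context needs at least 100 observations and a rate more than 5x the
/// global rate; the penalty is then the Phred-scale gap between the two,
/// `round(-10 * log10(global_rate / rate))`, clamped to `1..=20`. For
/// instance, a context rate of 0.01 against a global rate of 0.001 yields
/// `round(-10 * log10(0.1))` = 10.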
fn build_context_effects(
error_counts: &HashMap<String, u64>,
total_counts: &HashMap<String, u64>,
) -> HashMap<String, ContextEffectJson> {
if error_counts.is_empty() || total_counts.is_empty() {
return HashMap::new();
}
let total_errors: u64 = error_counts.values().sum();
let total_obs: u64 = total_counts.values().sum();
if total_obs == 0 {
return HashMap::new();
}
let global_rate = total_errors as f64 / total_obs as f64;
let mut effects = HashMap::new();
for (ctx, &obs) in total_counts {
        if obs < 100 {
            continue;
        }
let errors = error_counts.get(ctx).copied().unwrap_or(0);
let rate = errors as f64 / obs as f64;
if rate > global_rate * 5.0 {
let penalty = (-10.0 * (global_rate / rate).log10()).round() as u8;
let penalty = penalty.clamp(1, 20);
effects.insert(
ctx.clone(),
ContextEffectJson {
quality_penalty: penalty,
},
);
}
}
effects
}
fn read_length_from_header(header: &sam::Header) -> Option<usize> {
    // BAM headers do not reliably record read length, so this is currently
    // a stub; `learn_from_bam` falls back to a default of 150 when it
    // returns `None`.
    let _ = header;
    None
}
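/// Writes the read-1 quality distribution as TSV: a `cycle` column followed
/// by `q0..q93` columns, each cell holding that quality's weight at that
/// cycle (0 when unobserved). A cycle where every base is Q37, for example,
/// writes 1.000000 in the `q37` column and 0.000000 elsewhere.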
pub fn write_quality_tsv(profile: &ProfileJson, path: &std::path::Path) -> anyhow::Result<()> {
use std::io::Write as _;
let f = std::fs::File::create(path)
.with_context(|| format!("failed to create profile TSV: {}", path.display()))?;
let mut w = std::io::BufWriter::new(f);
let mut header = "cycle".to_string();
for q in 0u8..=93 {
header.push_str(&format!("\tq{}", q));
}
writeln!(w, "{}", header)?;
for (cycle, entries) in profile.quality_distribution.read1.iter().enumerate() {
let mut weights = vec![0.0f64; 94];
for entry in entries {
let q = entry[0] as usize;
if q < 94 {
weights[q] = entry[1];
}
}
let mut row = cycle.to_string();
for &wt in &weights {
row.push_str(&format!("\t{:.6}", wt));
}
writeln!(w, "{}", row)?;
}
w.flush()?;
Ok(())
}
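/// Reads a TSV written by `write_quality_tsv` back into per-cycle
/// `[quality, weight]` pairs, dropping zero-weight columns and falling back
/// to a single Q30 entry for rows with no positive weights.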
#[allow(dead_code)]
pub fn read_quality_tsv(path: &std::path::Path) -> anyhow::Result<Vec<Vec<[f64; 2]>>> {
use std::io::BufRead as _;
let f = std::fs::File::open(path)
.with_context(|| format!("failed to open profile TSV: {}", path.display()))?;
let reader = std::io::BufReader::new(f);
let mut lines = reader.lines();
lines
.next()
.ok_or_else(|| anyhow::anyhow!("profile TSV is empty"))??;
let mut distribution = Vec::new();
for line in lines {
let line = line?;
if line.trim().is_empty() {
continue;
}
let fields: Vec<&str> = line.split('\t').collect();
let mut entries: Vec<[f64; 2]> = fields[1..]
.iter()
.enumerate()
.filter_map(|(i, s)| {
let w: f64 = s.trim().parse().unwrap_or(0.0);
if w > 0.0 {
Some([i as f64, w])
} else {
None
}
})
.collect();
        if entries.is_empty() {
            entries.push([30.0, 1.0]);
        }
distribution.push(entries);
}
Ok(distribution)
}
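/// Builds a synthetic paired-end record for tests: an all-match CIGAR
/// (`<len>M`), a mate placed immediately downstream, and strand/segment
/// flags chosen from `is_r2`.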
#[cfg(test)]
pub fn make_test_record(
seq: &[u8],
quals: &[u8],
is_r2: bool,
template_len: i32,
mapq: u8,
pos: usize,
ref_id: usize,
) -> Result<RecordBuf> {
use crate::io::bam::parse_cigar;
use noodles_core::Position;
use noodles_sam::alignment::{
record::MappingQuality,
record_buf::{Cigar, QualityScores, Sequence},
};
let cigar_str = format!("{}M", seq.len());
let cigar_ops = parse_cigar(&cigar_str)?;
let mut flags = noodles_sam::alignment::record::Flags::SEGMENTED
| noodles_sam::alignment::record::Flags::PROPERLY_SEGMENTED;
if is_r2 {
flags |= noodles_sam::alignment::record::Flags::REVERSE_COMPLEMENTED
| noodles_sam::alignment::record::Flags::LAST_SEGMENT;
} else {
flags |= noodles_sam::alignment::record::Flags::MATE_REVERSE_COMPLEMENTED
| noodles_sam::alignment::record::Flags::FIRST_SEGMENT;
}
let alignment_pos =
Position::new(pos + 1).ok_or_else(|| anyhow::anyhow!("invalid position"))?;
let mate_pos = Position::new(pos + seq.len() + 1)
.ok_or_else(|| anyhow::anyhow!("invalid mate position"))?;
let record = RecordBuf::builder()
.set_flags(flags)
.set_reference_sequence_id(ref_id)
.set_alignment_start(alignment_pos)
.set_mapping_quality(
MappingQuality::new(mapq).ok_or_else(|| anyhow::anyhow!("invalid mapq"))?,
)
.set_cigar(cigar_ops.into_iter().collect::<Cigar>())
.set_mate_reference_sequence_id(ref_id)
.set_mate_alignment_start(mate_pos)
.set_template_length(template_len)
.set_sequence(Sequence::from(seq))
.set_quality_scores(QualityScores::from(quals.to_vec()))
.build();
Ok(record)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::error_profile::EmpiricalQualityModel;
use crate::core::quality::QualityModel;
use rand::rngs::StdRng;
use rand::SeedableRng;
fn learner() -> ProfileLearner {
ProfileLearner::new(LearnerConfig {
sample_size: 100_000,
            min_mapq: 0,
        })
}
fn uniform_r1_records(n: usize, read_len: usize, q: u8, tlen: i32) -> Vec<RecordBuf> {
let seq = vec![b'A'; read_len];
let quals = vec![q; read_len];
(0..n)
.map(|i| {
make_test_record(&seq, &quals, false, tlen, 60, i * (read_len + 50), 0)
.expect("make_test_record should not fail")
})
.collect()
}
#[test]
fn test_quality_extraction() {
let records = uniform_r1_records(200, 10, 37, 200);
let profile = learner()
.learn_from_records(&records, 10)
.expect("learn_from_records should succeed");
assert_eq!(
profile.quality_distribution.read1.len(),
10,
"wrong number of positions"
);
for (pos, entries) in profile.quality_distribution.read1.iter().enumerate() {
assert_eq!(
entries.len(),
1,
"pos {pos}: expected exactly 1 quality bucket"
);
assert!(
(entries[0][0] - 37.0).abs() < f64::EPSILON,
"pos {pos}: quality should be 37"
);
assert!(
(entries[0][1] - 1.0).abs() < 1e-9,
"pos {pos}: weight should be 1.0"
);
}
}
#[test]
fn test_fragment_size_extraction() {
let read_len = 10usize;
let seq = vec![b'A'; read_len];
let quals = vec![30u8; read_len];
let tlens = [150i32, 150, 150, 200, 200, 250];
let records: Vec<RecordBuf> = tlens
.iter()
.enumerate()
.map(|(i, &tlen)| {
make_test_record(&seq, &quals, false, tlen, 60, i * 300, 0)
.expect("make_test_record failed")
})
.collect();
let learner = ProfileLearner::new(LearnerConfig {
sample_size: 100_000,
min_mapq: 0,
});
let mut stats = LearnerStats::new(read_len);
for record in &records {
let tlen = record.template_length();
if tlen > 0 {
*stats.insert_size_counts.entry(tlen).or_insert(0) += 1;
}
stats.reads_examined += 1;
}
assert_eq!(*stats.insert_size_counts.get(&150).unwrap_or(&0), 3);
assert_eq!(*stats.insert_size_counts.get(&200).unwrap_or(&0), 2);
assert_eq!(*stats.insert_size_counts.get(&250).unwrap_or(&0), 1);
let profile = learner
.learn_from_records(&records, read_len)
.expect("learn_from_records should succeed");
assert!(!profile.quality_distribution.read1.is_empty());
}
#[test]
fn test_gc_bias_extraction() {
let read_len = 10usize;
let at_seq = vec![b'A'; read_len];
let gc_seq = vec![b'G'; read_len];
let quals = vec![30u8; read_len];
let at_records: Vec<RecordBuf> = (0..5)
.map(|i| {
make_test_record(&at_seq, &quals, false, 200, 60, i * 300, 0)
.expect("make_test_record failed")
})
.collect();
let gc_records: Vec<RecordBuf> = (5..10)
.map(|i| {
make_test_record(&gc_seq, &quals, false, 200, 60, i * 300, 0)
.expect("make_test_record failed")
})
.collect();
let mut all: Vec<RecordBuf> = at_records;
all.extend(gc_records);
let mut stats = LearnerStats::new(read_len);
for record in &all {
let seq_bytes: Vec<u8> = record.sequence().iter().collect();
let gc = gc_percent(&seq_bytes);
stats.gc_bias[gc].0 += 1;
stats.gc_bias[gc].1 += 1;
stats.reads_examined += 1;
}
assert_eq!(stats.gc_bias[0].1, 5, "expected 5 reads in GC=0 bin");
assert_eq!(stats.gc_bias[100].1, 5, "expected 5 reads in GC=100 bin");
for pct in 1..100 {
assert_eq!(stats.gc_bias[pct].1, 0, "unexpected reads in GC={pct} bin");
}
}
#[test]
fn test_output_format() {
let records = uniform_r1_records(50, 10, 35, 180);
let profile = learner()
.learn_from_records(&records, 10)
.expect("learn_from_records should succeed");
assert!(
!profile.quality_distribution.read1.is_empty(),
"read1 quality distribution must not be empty"
);
assert!(profile.read_length > 0, "read_length must be > 0");
let json = serde_json::to_string(&profile).expect("profile must serialize to JSON");
assert!(json.contains("read1"), "JSON must contain read1 key");
assert!(
json.contains("quality_distribution"),
"JSON must contain quality_distribution"
);
let model = EmpiricalQualityModel::from_json_str(&json)
.expect("serialized profile must be parseable by EmpiricalQualityModel");
assert_eq!(model.platform, None);
}
#[test]
fn test_sampling() {
let records = uniform_r1_records(200, 10, 30, 150);
let small_learner = ProfileLearner::new(LearnerConfig {
sample_size: 50,
min_mapq: 0,
});
let profile = small_learner
.learn_from_records(&records, 10)
.expect("learn_from_records should succeed with sample_size limit");
assert_eq!(
profile.quality_distribution.read1.len(),
10,
"quality distribution length should equal read_length"
);
let mut stats = LearnerStats::new(10);
for (i, record) in records.iter().enumerate() {
if i >= 50 {
break;
}
let quals: Vec<u8> = record.quality_scores().as_ref().to_vec();
for (pos, &q) in quals.iter().enumerate() {
*stats.quality_counts_r1[pos].entry(q).or_insert(0) += 1;
}
stats.reads_examined += 1;
}
assert_eq!(
stats.reads_examined, 50,
"should have examined exactly 50 reads"
);
}
#[test]
fn test_round_trip() {
let records = uniform_r1_records(500, 20, 30, 200);
let profile = learner()
.learn_from_records(&records, 20)
.expect("learn_from_records should succeed");
let json = serde_json::to_string(&profile).expect("profile must serialize");
let model =
EmpiricalQualityModel::from_json_str(&json).expect("serialized profile must reload");
let mut rng = StdRng::seed_from_u64(42);
let mut total: u64 = 0;
let n_reads = 1_000usize;
for _ in 0..n_reads {
let quals = model.generate_qualities(20, &mut rng);
total += quals.iter().map(|&q| q as u64).sum::<u64>();
}
let mean_q = total as f64 / (n_reads * 20) as f64;
assert!(
(mean_q - 30.0).abs() < 1.0,
"mean quality {mean_q} should be close to 30"
);
}
}