use crate::cli::Strandedness;
use crate::gtf::Gene;
use anyhow::{Context, Result};
use indexmap::IndexMap;
use log::debug;
use std::collections::HashMap;
use std::fs;
use std::io::Write;
use std::path::Path;
/// One annotated span on a chromosome together with the strand it lies on.
///
/// `GeneModel::find_strands` treats the span as half-open: an interval
/// overlaps a query when `start < qend` and `end > qstart`.
#[derive(Debug, Clone)]
struct TranscriptInterval {
/// Span start. `GeneModel::from_genes` stores the annotated start minus one
/// (saturating at zero) — presumably converting 1-based GTF coordinates to
/// 0-based; confirm against the GTF parser.
start: u64,
/// Span end, stored exactly as annotated and compared as an exclusive bound.
end: u64,
/// Strand byte; `GeneModel::from_genes` only ever stores `b'+'` or `b'-'`.
strand: u8,
}
/// Per-chromosome collection of stranded annotation intervals used to look up
/// which strand(s) overlap a query region.
#[derive(Debug, Default)]
pub struct GeneModel {
/// Chromosome name -> intervals on that chromosome. `from_genes` sorts each
/// list by `start`; `find_strands` relies on that ordering for its
/// `partition_point` search.
intervals: HashMap<String, Vec<TranscriptInterval>>,
}
impl GeneModel {
    /// Builds the strand lookup model from parsed annotation genes.
    ///
    /// A gene without transcripts contributes a single interval covering the
    /// gene itself; otherwise every transcript contributes its own interval
    /// (keyed by the transcript's chromosome) and the gene span is not added.
    /// Records whose strand is neither `'+'` nor `'-'` are skipped. Each
    /// stored interval has its start decremented by one (saturating at zero)
    /// and its end kept unchanged. Once everything is collected, each
    /// chromosome's intervals are sorted by start with a stable sort.
    pub fn from_genes(genes: &IndexMap<String, Gene>) -> Self {
        let mut model = GeneModel::default();
        let mut count: u64 = 0;

        for gene in genes.values() {
            if gene.transcripts.is_empty() {
                let added = model.push_interval(&gene.chrom, gene.start, gene.end, gene.strand);
                count += u64::from(added);
                continue;
            }
            for tx in &gene.transcripts {
                let added = model.push_interval(&tx.chrom, tx.start, tx.end, tx.strand);
                count += u64::from(added);
            }
        }

        // Stable sort: intervals sharing a start keep their insertion order,
        // which determines the order strands are reported by `find_strands`.
        model
            .intervals
            .values_mut()
            .for_each(|per_chrom| per_chrom.sort_by_key(|iv| iv.start));

        debug!("Loaded {} transcript intervals from GTF annotation", count);
        model
    }

    /// Records one annotated span under `chrom`, returning `true` if it was
    /// stored and `false` if it was skipped because the strand is not
    /// `'+'` or `'-'`.
    fn push_interval(&mut self, chrom: &str, start: u64, end: u64, strand: char) -> bool {
        let strand = match strand {
            '+' => b'+',
            '-' => b'-',
            _ => return false,
        };
        self.intervals
            .entry(chrom.to_owned())
            .or_default()
            .push(TranscriptInterval {
                start: start.saturating_sub(1),
                end,
                strand,
            });
        true
    }

    /// Returns the distinct strand bytes of all intervals on `chrom` that
    /// overlap the query, i.e. those with `start < qend` and `end > qstart`.
    ///
    /// Strands appear in the order they are first encountered while walking
    /// the start-sorted interval list. An unknown chromosome yields an empty
    /// vector.
    pub fn find_strands(&self, chrom: &str, qstart: u64, qend: u64) -> Vec<u8> {
        let sorted = match self.intervals.get(chrom) {
            Some(per_chrom) => per_chrom,
            None => return Vec::new(),
        };

        // Intervals are sorted by start, so everything from `cutoff` onward
        // begins at or after the query end and cannot overlap.
        let cutoff = sorted.partition_point(|iv| iv.start < qend);

        let mut seen = Vec::new();
        for candidate in &sorted[..cutoff] {
            if candidate.end <= qstart || seen.contains(&candidate.strand) {
                continue;
            }
            seen.push(candidate.strand);
        }
        seen
    }
}
/// Summary of a strandedness-inference run, as consumed by
/// `write_infer_experiment` and `infer_strandedness`.
#[derive(Debug)]
pub struct InferExperimentResult {
/// Presumably the number of reads sampled (confirm with the producer);
/// a value of zero makes `infer_strandedness` return `Undetermined`.
pub total_sampled: u64,
/// Library layout label. `write_infer_experiment` recognises `"PairEnd"` and
/// `"SingleEnd"`; any other value is reported as "Unknown Data type".
pub library_type: String,
/// Written to the report as "Fraction of reads failed to determine".
pub frac_failed: f64,
/// Written to the report as the fraction explained by `1++,1--,2+-,2-+`
/// (PairEnd) or `++,--` (SingleEnd); `infer_strandedness` maps a dominant
/// value here to `Forward`.
pub frac_protocol1: f64,
/// Written to the report as the fraction explained by `1+-,1-+,2++,2--`
/// (PairEnd) or `+-,-+` (SingleEnd); `infer_strandedness` maps a dominant
/// value here to `Reverse`.
pub frac_protocol2: f64,
}
/// Writes the strandedness report for `result` to `output_path`, creating or
/// truncating the file.
///
/// The report starts with two blank lines. For a `library_type` of
/// `"PairEnd"` or `"SingleEnd"` it then contains a "This is ... Data" line
/// followed by the failed / protocol-1 / protocol-2 fractions, each printed
/// with four decimal places. Any other `library_type` produces the single
/// line "Unknown Data type".
///
/// # Errors
///
/// Returns an error if the file cannot be created, if any write fails, or if
/// the buffered output cannot be flushed to disk.
pub fn write_infer_experiment<P: AsRef<Path>>(
    result: &InferExperimentResult,
    output_path: P,
) -> Result<()> {
    let output_path = output_path.as_ref();
    let file = fs::File::create(output_path)
        .with_context(|| format!("Failed to create output file: {}", output_path.display()))?;
    // Buffer the output: every `writeln!` on a bare `File` turns into one or
    // more write syscalls.
    let mut writer = std::io::BufWriter::new(file);

    // (layout label, read patterns for protocol 1, read patterns for protocol 2)
    let layout = match result.library_type.as_str() {
        "PairEnd" => Some(("PairEnd", "1++,1--,2+-,2-+", "1+-,1-+,2++,2--")),
        "SingleEnd" => Some(("SingleEnd", "++,--", "+-,-+")),
        _ => None,
    };

    // Every variant of the report begins with two empty lines.
    writeln!(writer)?;
    writeln!(writer)?;

    match layout {
        Some((label, protocol1_patterns, protocol2_patterns)) => {
            writeln!(writer, "This is {} Data", label)?;
            writeln!(
                writer,
                "Fraction of reads failed to determine: {:.4}",
                result.frac_failed
            )?;
            writeln!(
                writer,
                "Fraction of reads explained by \"{}\": {:.4}",
                protocol1_patterns, result.frac_protocol1
            )?;
            writeln!(
                writer,
                "Fraction of reads explained by \"{}\": {:.4}",
                protocol2_patterns, result.frac_protocol2
            )?;
        }
        None => writeln!(writer, "Unknown Data type")?,
    }

    // Flush explicitly: dropping a `BufWriter` discards any flush error.
    writer
        .flush()
        .with_context(|| format!("Failed to write output file: {}", output_path.display()))?;
    Ok(())
}
/// A protocol fraction strictly greater than this value makes
/// `infer_strandedness` report that protocol's orientation (protocol 1 is
/// checked first).
const STRAND_DOMINANT_THRESHOLD: f64 = 0.75;
/// `infer_strandedness` reports `Unstranded` when both protocol fractions are
/// strictly below this value.
///
/// NOTE(review): this equals `STRAND_DOMINANT_THRESHOLD`, so with at least one
/// sampled read the final `Undetermined` branch is reached only when a
/// fraction is exactly 0.75 (or NaN) — confirm that is intended.
const STRAND_UNSTRANDED_UPPER: f64 = 0.75;
/// Library strandedness as classified by `infer_strandedness`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InferredStrandedness {
    /// The protocol-1 fraction is dominant.
    Forward,
    /// The protocol-2 fraction is dominant.
    Reverse,
    /// Neither protocol fraction is dominant.
    Unstranded,
    /// No reads were sampled, or the fractions fit none of the other cases.
    Undetermined,
}

/// Renders the variant as its lowercase name (`forward`, `reverse`,
/// `unstranded`, `undetermined`). Width and alignment flags are not applied.
impl std::fmt::Display for InferredStrandedness {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            InferredStrandedness::Forward => "forward",
            InferredStrandedness::Reverse => "reverse",
            InferredStrandedness::Unstranded => "unstranded",
            InferredStrandedness::Undetermined => "undetermined",
        };
        f.write_str(label)
    }
}
/// Classifies library strandedness from the protocol fractions in `result`.
///
/// Rules, evaluated in order:
/// 1. `total_sampled == 0` gives `Undetermined`.
/// 2. `frac_protocol1 > STRAND_DOMINANT_THRESHOLD` gives `Forward`.
/// 3. `frac_protocol2 > STRAND_DOMINANT_THRESHOLD` gives `Reverse`.
/// 4. Both fractions below `STRAND_UNSTRANDED_UPPER` gives `Unstranded`.
/// 5. Anything else gives `Undetermined`.
///
/// All comparisons are strict, so a NaN fraction never satisfies rules 2-4.
pub fn infer_strandedness(result: &InferExperimentResult) -> InferredStrandedness {
    if result.total_sampled == 0 {
        return InferredStrandedness::Undetermined;
    }

    let protocol1 = result.frac_protocol1;
    let protocol2 = result.frac_protocol2;

    if protocol1 > STRAND_DOMINANT_THRESHOLD {
        return InferredStrandedness::Forward;
    }
    if protocol2 > STRAND_DOMINANT_THRESHOLD {
        return InferredStrandedness::Reverse;
    }

    let neither_dominant =
        protocol1 < STRAND_UNSTRANDED_UPPER && protocol2 < STRAND_UNSTRANDED_UPPER;
    if neither_dominant {
        InferredStrandedness::Unstranded
    } else {
        InferredStrandedness::Undetermined
    }
}
/// Compares the user-specified strandedness with the one inferred from
/// `result`.
///
/// Returns `None` when the inference is `Undetermined` or when it agrees with
/// `specified`. Otherwise returns `Some((inferred, suggestion))`, where
/// `suggestion` is the `Strandedness` value corresponding to the inferred
/// classification.
pub fn check_strandedness_mismatch(
    result: &InferExperimentResult,
    specified: Strandedness,
) -> Option<(InferredStrandedness, Strandedness)> {
    let inferred = infer_strandedness(result);

    // For each decidable inference, pair the matching CLI setting with
    // whether the caller already specified exactly that setting.
    let (suggestion, agrees) = match inferred {
        InferredStrandedness::Undetermined => return None,
        InferredStrandedness::Forward => (
            Strandedness::Forward,
            matches!(specified, Strandedness::Forward),
        ),
        InferredStrandedness::Reverse => (
            Strandedness::Reverse,
            matches!(specified, Strandedness::Reverse),
        ),
        InferredStrandedness::Unstranded => (
            Strandedness::Unstranded,
            matches!(specified, Strandedness::Unstranded),
        ),
    };

    if agrees {
        None
    } else {
        Some((inferred, suggestion))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gtf::{Exon, Gene};
    use indexmap::IndexMap;

    /// Builds a transcript-less gene whose single exon spans the whole gene.
    fn make_gene(gene_id: &str, chrom: &str, start: u64, end: u64, strand: char) -> Gene {
        let spanning_exon = Exon {
            chrom: chrom.to_string(),
            start,
            end,
            strand,
        };
        Gene {
            gene_id: gene_id.to_string(),
            chrom: chrom.to_string(),
            start,
            end,
            strand,
            exons: vec![spanning_exon],
            effective_length: end - start + 1,
            attributes: HashMap::new(),
            transcripts: Vec::new(),
        }
    }

    /// Builds a model from `(gene_id, chrom, start, end, strand)` tuples,
    /// inserted in the order given.
    fn model_from(specs: &[(&str, &str, u64, u64, char)]) -> GeneModel {
        let mut genes = IndexMap::new();
        for &(gene_id, chrom, start, end, strand) in specs {
            genes.insert(
                gene_id.to_string(),
                make_gene(gene_id, chrom, start, end, strand),
            );
        }
        GeneModel::from_genes(&genes)
    }

    /// Builds a model holding the given intervals on `chr1`, bypassing
    /// `from_genes`.
    fn chr1_model(intervals: Vec<TranscriptInterval>) -> GeneModel {
        let mut model = GeneModel::default();
        model.intervals.insert("chr1".to_string(), intervals);
        model
    }

    /// Builds a paired-end result whose failed fraction is the remainder.
    fn make_result(total: u64, frac_p1: f64, frac_p2: f64) -> InferExperimentResult {
        InferExperimentResult {
            total_sampled: total,
            library_type: "PairEnd".to_string(),
            frac_failed: 1.0 - frac_p1 - frac_p2,
            frac_protocol1: frac_p1,
            frac_protocol2: frac_p2,
        }
    }

    #[test]
    fn test_from_genes_basic() {
        let model = model_from(&[
            ("GENE1", "chr1", 100, 500, '+'),
            ("GENE2", "chr1", 1000, 2000, '-'),
        ]);
        assert_eq!(model.intervals.len(), 1);

        let stored: Vec<(u64, u64, u8)> = model.intervals["chr1"]
            .iter()
            .map(|iv| (iv.start, iv.end, iv.strand))
            .collect();
        assert_eq!(stored, vec![(99, 500, b'+'), (999, 2000, b'-')]);
    }

    #[test]
    fn test_from_genes_skips_unknown_strand() {
        let model = model_from(&[
            ("GENE1", "chr1", 100, 500, '.'),
            ("GENE2", "chr1", 600, 800, '+'),
        ]);
        let chr1 = &model.intervals["chr1"];
        assert_eq!(chr1.len(), 1);
        assert_eq!(chr1[0].strand, b'+');
    }

    #[test]
    fn test_from_genes_find_strands() {
        let model = model_from(&[
            ("GENE1", "chr1", 100, 500, '+'),
            ("GENE2", "chr1", 300, 800, '-'),
        ]);

        // Inside both genes: both strands, order not asserted.
        let mut in_both = model.find_strands("chr1", 350, 400);
        in_both.sort_unstable();
        assert_eq!(in_both, vec![b'+', b'-']);

        // Past the end of GENE1: only the minus-strand gene remains.
        assert_eq!(model.find_strands("chr1", 550, 600), vec![b'-']);
    }

    #[test]
    fn test_find_strands_no_overlap() {
        let empty = GeneModel::default();
        assert!(empty.find_strands("chr1", 100, 200).is_empty());
    }

    #[test]
    fn test_find_strands_single_overlap() {
        let model = chr1_model(vec![TranscriptInterval {
            start: 100,
            end: 500,
            strand: b'+',
        }]);
        assert_eq!(model.find_strands("chr1", 200, 300), vec![b'+']);
    }

    #[test]
    fn test_find_strands_both_strands() {
        let model = chr1_model(vec![
            TranscriptInterval {
                start: 100,
                end: 500,
                strand: b'+',
            },
            TranscriptInterval {
                start: 200,
                end: 600,
                strand: b'-',
            },
        ]);
        let mut found = model.find_strands("chr1", 250, 350);
        found.sort_unstable();
        assert_eq!(found, vec![b'+', b'-']);
    }

    #[test]
    fn test_infer_strandedness_forward() {
        assert_eq!(
            infer_strandedness(&make_result(10000, 0.95, 0.03)),
            InferredStrandedness::Forward
        );
    }

    #[test]
    fn test_infer_strandedness_reverse() {
        assert_eq!(
            infer_strandedness(&make_result(10000, 0.03, 0.95)),
            InferredStrandedness::Reverse
        );
    }

    #[test]
    fn test_infer_strandedness_unstranded() {
        assert_eq!(
            infer_strandedness(&make_result(10000, 0.48, 0.48)),
            InferredStrandedness::Unstranded
        );
    }

    #[test]
    fn test_infer_strandedness_undetermined_no_reads() {
        assert_eq!(
            infer_strandedness(&make_result(0, 0.0, 0.0)),
            InferredStrandedness::Undetermined
        );
    }

    #[test]
    fn test_check_mismatch_forward_vs_reverse() {
        let forward_data = make_result(10000, 0.95, 0.03);
        assert_eq!(
            check_strandedness_mismatch(&forward_data, Strandedness::Reverse),
            Some((InferredStrandedness::Forward, Strandedness::Forward))
        );
    }

    #[test]
    fn test_check_mismatch_reverse_vs_forward() {
        let reverse_data = make_result(10000, 0.03, 0.95);
        assert_eq!(
            check_strandedness_mismatch(&reverse_data, Strandedness::Forward),
            Some((InferredStrandedness::Reverse, Strandedness::Reverse))
        );
    }

    #[test]
    fn test_check_mismatch_unstranded_vs_stranded() {
        let unstranded_data = make_result(10000, 0.48, 0.48);
        assert_eq!(
            check_strandedness_mismatch(&unstranded_data, Strandedness::Reverse),
            Some((InferredStrandedness::Unstranded, Strandedness::Unstranded))
        );
    }

    #[test]
    fn test_check_no_mismatch_when_matching() {
        let forward_data = make_result(10000, 0.95, 0.03);
        assert!(check_strandedness_mismatch(&forward_data, Strandedness::Forward).is_none());

        let reverse_data = make_result(10000, 0.03, 0.95);
        assert!(check_strandedness_mismatch(&reverse_data, Strandedness::Reverse).is_none());

        let unstranded_data = make_result(10000, 0.48, 0.48);
        assert!(check_strandedness_mismatch(&unstranded_data, Strandedness::Unstranded).is_none());
    }

    #[test]
    fn test_check_no_mismatch_when_undetermined() {
        let no_reads = make_result(0, 0.0, 0.0);
        assert!(check_strandedness_mismatch(&no_reads, Strandedness::Forward).is_none());
    }
}