use rust_htslib::bam;
use std::collections::{HashMap, HashSet};
use std::hash::{BuildHasher, Hasher};
use std::io::Write;
use std::path::Path;
use anyhow::{Context, Result};
use indexmap::IndexMap;
use log::debug;
use rand::Rng;
use crate::gtf::Gene;
/// Hasher factory carrying a fixed `u64` seed.
///
/// Every hasher it builds is a `DefaultHasher` that has already absorbed the
/// seed, so two states with the same seed produce identical hash streams.
#[derive(Clone, Debug)]
pub(crate) struct TinHashState(u64);

impl BuildHasher for TinHashState {
    type Hasher = std::collections::hash_map::DefaultHasher;

    /// Returns a fresh `DefaultHasher` pre-fed with this state's seed.
    fn build_hasher(&self) -> Self::Hasher {
        let Self(seed) = self;
        let mut seeded = std::collections::hash_map::DefaultHasher::new();
        seeded.write_u64(*seed);
        seeded
    }
}
impl TinHashState {
    /// Creates a hash state from `seed`, or from a freshly drawn random `u64`
    /// when no seed is supplied.
    fn new(seed: Option<u64>) -> Self {
        match seed {
            Some(fixed) => TinHashState(fixed),
            None => TinHashState(rand::rng().next_u64()),
        }
    }
}
/// Per-transcript sampling plan: where the transcript lies and which positions are tracked.
#[derive(Debug, Clone)]
pub struct TranscriptSampling {
/// Identifier reported for this record; `TinIndex::from_genes` stores the transcript ID here.
pub gene_id: String,
/// Chromosome name, copied from the owning gene in `TinIndex::from_genes`.
pub chrom: String,
/// Upper-cased copy of `chrom`; used as the key into the per-chromosome index maps.
pub chrom_upper: String,
/// Start of the first exon region (after the `saturating_sub(1)` shift applied in `from_genes`).
pub tx_start: u64,
/// End of the last exon region.
pub tx_end: u64,
/// Exon intervals as `(start, end)` pairs, with the start shifted down by one in `from_genes`.
#[allow(dead_code)] pub exon_regions: Vec<(u64, u64)>,
/// Sum of `end - start` over `exon_regions`.
#[allow(dead_code)] pub exon_length: u64,
/// Positions returned by `sample_exonic_positions`; coverage is tallied once per entry.
pub sampled_positions: Vec<u64>,
/// Count returned alongside `sampled_positions`; `TinAccum::into_result` uses it as the
/// denominator passed to `compute_tin`.
pub total_position_count: usize,
}
/// Lookup tables over all sampled transcripts, keyed by upper-cased chromosome name.
#[derive(Debug, Default)]
pub struct TinIndex {
/// Every indexed transcript; a transcript's position in this vector is its `tx_idx`.
pub transcripts: Vec<TranscriptSampling>,
/// Per chromosome: `(position, tx_idx, slot_idx)` triples, where `slot_idx` indexes the
/// transcript's `sampled_positions`. Sorted by `build`.
pub chrom_positions: HashMap<String, Vec<(u64, u32, u32)>>,
/// Per chromosome: `(tx_start, tx_end, tx_idx)` triples. Sorted by `build`.
pub chrom_spans: HashMap<String, Vec<(u64, u64, u32)>>,
}
/// TIN outcome for a single transcript.
#[derive(Debug)]
pub struct TinResult {
/// Identifier copied from `TranscriptSampling::gene_id`.
pub gene_id: String,
/// Chromosome name copied from `TranscriptSampling::chrom`.
pub chrom: String,
/// Transcript start copied from `TranscriptSampling::tx_start`.
pub tx_start: u64,
/// Transcript end copied from `TranscriptSampling::tx_end`.
pub tx_end: u64,
/// TIN score; `TinAccum::into_result` stores `0.0` when `passed_threshold` is false.
pub tin: f64,
/// Whether the transcript collected more distinct read starts than the `min_cov` cutoff.
pub passed_threshold: bool,
}
/// Collection of per-transcript TIN results.
///
/// `TinAccum::into_result` fills it with one entry per indexed transcript, in index order.
#[derive(Debug)]
pub struct TinResults {
/// The individual transcript results.
pub transcripts: Vec<TinResult>,
}
impl TinResults {
    /// Number of transcript results held.
    pub fn len(&self) -> usize {
        let Self { transcripts } = self;
        transcripts.len()
    }
}
impl TinIndex {
    /// Builds the sampling index for every transcript of every gene.
    ///
    /// For each transcript the exon intervals are converted with
    /// `saturating_sub(1)` on the start coordinate, positions are chosen by
    /// `sample_exonic_positions`, and the transcript is registered in the
    /// per-chromosome tables. Transcripts whose summed exon length is zero, or
    /// that yield no sampled positions, are skipped. The tables are sorted
    /// before the index is returned.
    ///
    /// `sample_size` is forwarded unchanged as the target number of sampled
    /// positions per transcript.
    pub fn from_genes(genes: &IndexMap<String, Gene>, sample_size: usize) -> Self {
        let mut index = TinIndex::default();
        for gene in genes.values() {
            for tx in &gene.transcripts {
                let exon_regions: Vec<(u64, u64)> = tx
                    .exons
                    .iter()
                    .map(|&(s, e)| (s.saturating_sub(1), e))
                    .collect();
                let exon_length: u64 = exon_regions.iter().map(|(s, e)| e - s).sum();
                if exon_length == 0 {
                    continue;
                }
                let chrom = gene.chrom.clone();
                let chrom_upper = chrom.to_uppercase();
                // The span is taken from the first and last exon as stored.
                // NOTE(review): this presumes `tx.exons` is ordered by coordinate — confirm
                // against the GTF loader.
                let tx_start = exon_regions.first().map(|r| r.0).unwrap_or(0);
                let tx_end = exon_regions.last().map(|r| r.1).unwrap_or(0);
                let (sampled, total_count) =
                    sample_exonic_positions(&exon_regions, sample_size, tx_start, tx_end);
                if sampled.is_empty() {
                    continue;
                }
                index.add_transcript(TranscriptSampling {
                    // The record is keyed by transcript ID even though the field is `gene_id`.
                    gene_id: tx.transcript_id.clone(),
                    chrom,
                    chrom_upper,
                    tx_start,
                    tx_end,
                    exon_regions,
                    exon_length,
                    sampled_positions: sampled,
                    total_position_count: total_count,
                });
            }
        }
        index.build();
        index
    }

    /// Registers one transcript: appends its sampled positions and its span to
    /// the tables of its chromosome, then stores the transcript itself.
    ///
    /// The tables are left unsorted; `build` must run before lookups.
    fn add_transcript(&mut self, tx: TranscriptSampling) {
        // Indices are stored as `u32` to keep the tables compact; this truncates if more
        // than `u32::MAX` transcripts are ever indexed.
        let tx_idx = self.transcripts.len() as u32;

        // Resolve the chromosome's position table once instead of cloning the key and
        // repeating the map lookup for every sampled position. The emptiness guard keeps
        // the map free of entries for transcripts that contribute no positions.
        if !tx.sampled_positions.is_empty() {
            let positions = self
                .chrom_positions
                .entry(tx.chrom_upper.clone())
                .or_default();
            positions.extend(
                tx.sampled_positions
                    .iter()
                    .enumerate()
                    .map(|(slot_idx, &pos)| (pos, tx_idx, slot_idx as u32)),
            );
        }

        self.chrom_spans
            .entry(tx.chrom_upper.clone())
            .or_default()
            .push((tx.tx_start, tx.tx_end, tx_idx));
        self.transcripts.push(tx);
    }

    /// Sorts every per-chromosome table so that `partition_point` lookups work,
    /// and logs the index size at debug level.
    fn build(&mut self) {
        for positions in self.chrom_positions.values_mut() {
            positions.sort_unstable();
        }
        for spans in self.chrom_spans.values_mut() {
            spans.sort_unstable();
        }
        debug!(
            "TIN index: {} transcripts, {} chromosomes",
            self.transcripts.len(),
            self.chrom_positions.len()
        );
    }
}
/// Chooses the positions at which coverage is measured for one transcript.
///
/// `exon_regions` holds `(start, end)` intervals; the positions of an interval
/// are `start + 1 ..= end`. `n` is the target sample count, and `tx_start` /
/// `tx_end` are the transcript bounds.
///
/// Returns `(positions, total_count)`, where `positions` is sorted and free of
/// duplicates:
///
/// * No exonic bases: `(vec![], 0)`.
/// * At most `n` exonic bases: every exonic base plus `tx_start + 1` and
///   `tx_end`. `total_count` is the length *before* duplicates are removed,
///   so it can exceed `positions.len()`.
/// * More than `n` exonic bases: every exon boundary plus every
///   `(bases / n)`-th exonic base, counted across exons in the given order.
///   `total_count` equals `positions.len()`.
///
/// `n == 0` is treated like `n == 1` (exon boundaries plus the first exonic
/// base) rather than dividing by zero.
fn sample_exonic_positions(
    exon_regions: &[(u64, u64)],
    n: usize,
    tx_start: u64,
    tx_end: u64,
) -> (Vec<u64>, usize) {
    // Saturating so that a malformed interval (end < start) counts as empty, which is
    // also how the position range below treats it.
    let mrna_size: u64 = exon_regions
        .iter()
        .map(|&(start, end)| end.saturating_sub(start))
        .sum();
    if mrna_size == 0 {
        return (Vec::new(), 0);
    }

    // Lazily walks every exonic base in exon order; nothing is materialised here.
    let exonic_bases = || {
        exon_regions
            .iter()
            .flat_map(|&(start, end)| (start + 1)..=end)
    };

    if mrna_size <= n as u64 {
        let mut positions: Vec<u64> = Vec::with_capacity(mrna_size as usize + 2);
        positions.push(tx_start + 1);
        positions.push(tx_end);
        positions.extend(exonic_bases());
        positions.sort_unstable();
        // Deliberately counted before deduplication.
        let total_count = positions.len();
        positions.dedup();
        (positions, total_count)
    } else {
        // `n.max(1)` guards the division: `n == 0` always lands in this branch because
        // `mrna_size > 0`.
        let step_size = ((mrna_size as usize) / n.max(1)).max(1);

        let mut positions: Vec<u64> = Vec::with_capacity(2 * exon_regions.len());
        for &(start, end) in exon_regions {
            positions.push(start + 1);
            positions.push(end);
        }
        positions.extend(exonic_bases().step_by(step_size));

        // The result is sorted anyway, so sort + dedup yields the same set as
        // hash-based deduplication followed by a sort.
        positions.sort_unstable();
        positions.dedup();
        let total_count = positions.len();
        (positions, total_count)
    }
}
/// Mutable state that accumulates coverage and read-start evidence for every indexed transcript.
#[derive(Debug)]
pub struct TinAccum {
/// `coverage[tx_idx][slot_idx]`: reads counted at each sampled position of each transcript.
pub coverage: Vec<Vec<u32>>,
/// Distinct read start coordinates seen per transcript; reset to empty once the cutoff
/// has been exceeded.
pub unique_starts: Vec<HashSet<u64, TinHashState>>,
/// Set once a transcript has collected more than `min_cov` distinct read starts.
pub exceeded_threshold: Vec<bool>,
/// Number of sampled positions per transcript, recorded in `new`.
#[allow(dead_code)]
pub n_samples: Vec<u32>,
/// MAPQ cutoff handed to `new`; stored here but not read by this type's own methods.
#[allow(dead_code)]
pub mapq_cut: u8,
/// Number of distinct read starts a transcript must exceed to receive a TIN score.
pub min_cov: u32,
/// Hasher factory cloned into every `unique_starts` set.
hash_state: TinHashState,
/// Scratch buffer for the aligned blocks of the read being processed; reused across reads.
blocks_buf: Vec<(u64, u64)>,
/// Chromosome the two cursors below refer to.
cursor_chrom: String,
/// Resume offset into the current chromosome's `chrom_spans`.
cursor_span: usize,
/// Resume offset into the current chromosome's `chrom_positions`.
cursor_pos: usize,
}
impl TinAccum {
/// Creates an accumulator sized for `index`: one zeroed coverage vector (one slot per
/// sampled position) and one empty read-start set per transcript.
///
/// `seed` is passed to `TinHashState::new` and determines the hasher shared by all
/// read-start sets. `mapq_cut` and `min_cov` are stored as given.
pub fn new(index: &TinIndex, mapq_cut: u8, min_cov: u32, seed: Option<u64>) -> Self {
let n_transcripts = index.transcripts.len();
let mut coverage = Vec::with_capacity(n_transcripts);
let mut n_samples = Vec::with_capacity(n_transcripts);
let hash_state = TinHashState::new(seed);
for tx in &index.transcripts {
let n = tx.sampled_positions.len();
coverage.push(vec![0u32; n]);
n_samples.push(n as u32);
}
let unique_starts = (0..n_transcripts)
.map(|_| HashSet::with_hasher(hash_state.clone()))
.collect();
TinAccum {
coverage,
unique_starts,
exceeded_threshold: vec![false; n_transcripts],
n_samples,
mapq_cut,
min_cov,
hash_state,
blocks_buf: Vec::with_capacity(8),
cursor_chrom: String::new(),
cursor_span: 0,
cursor_pos: 0,
}
}
/// Folds one alignment into the accumulator.
///
/// `chrom_upper` must be the upper-cased reference name of `record`, matching the keys of
/// `index`. Reads flagged unmapped, QC-failed or secondary are ignored entirely. Duplicate
/// reads still contribute their start coordinate to the distinct-start count but add no
/// coverage.
///
/// The span and position cursors only move forward within a chromosome and are reset when
/// `chrom_upper` changes.
/// NOTE(review): this presumes reads of one chromosome arrive ordered by start coordinate —
/// confirm that the caller guarantees it.
pub fn process_read(&mut self, record: &bam::Record, chrom_upper: &str, index: &TinIndex) {
let flags = record.flags();
// SAM flag bits: 0x4 = unmapped, 0x200 = failed QC, 0x100 = secondary alignment.
if flags & 0x4 != 0 || flags & 0x200 != 0 || flags & 0x100 != 0 {
return;
}
fill_aligned_blocks(record, &mut self.blocks_buf);
let blocks = &self.blocks_buf;
if blocks.is_empty() {
return;
}
// Start of the first aligned block.
let read_start = blocks[0].0;
// Chromosomes without indexed transcripts are skipped without touching the cursors.
let chrom_positions = match index.chrom_positions.get(chrom_upper) {
Some(p) => p,
None => return,
};
let chrom_spans = match index.chrom_spans.get(chrom_upper) {
Some(s) => s,
None => return,
};
if chrom_upper != self.cursor_chrom {
self.cursor_chrom.clear();
self.cursor_chrom.push_str(chrom_upper);
self.cursor_span = 0;
self.cursor_pos = 0;
}
// Advance to one past the last span whose start is <= read_start; the search resumes
// from where the previous read left off.
let span_end = self.cursor_span
+ chrom_spans[self.cursor_span..].partition_point(|s| s.0 <= read_start);
self.cursor_span = span_end;
// Every span starting at or before the read is inspected; the end check keeps those
// that contain the read start (tx_start <= read_start < tx_end).
for &(_tx_start, tx_end, tx_idx) in &chrom_spans[..span_end] {
if read_start < tx_end {
let idx = tx_idx as usize;
if !self.exceeded_threshold[idx] {
self.unique_starts[idx].insert(read_start);
if self.unique_starts[idx].len() > self.min_cov as usize {
self.exceeded_threshold[idx] = true;
// The set is no longer needed once the cutoff is passed; replace it with an empty one.
self.unique_starts[idx] = HashSet::with_hasher(self.hash_state.clone());
}
}
}
}
// SAM flag bit 0x400 = duplicate: counted above for read starts, excluded from coverage.
if flags & 0x400 != 0 {
return;
}
let mut first_block = true;
for &(block_start, block_end) in blocks.iter() {
let base = self.cursor_pos;
// First sampled position strictly greater than block_start.
let pos_start = base + chrom_positions[base..].partition_point(|p| p.0 <= block_start);
// Only the first block moves the persistent cursor; later blocks of the same read
// search forward from that point.
if first_block {
self.cursor_pos = pos_start;
first_block = false;
}
// Count every sampled position p with block_start < p <= block_end.
for &(pos, tx_idx, slot_idx) in &chrom_positions[pos_start..] {
if pos > block_end {
break;
}
self.coverage[tx_idx as usize][slot_idx as usize] += 1;
}
}
}
/// Adds the counts of `other` into `self`.
///
/// Coverage is summed slot by slot. A transcript is marked as exceeding the cutoff if
/// either side already did, or if the union of both read-start sets is larger than
/// `min_cov`; in both cases its read-start set is reset to empty.
/// Indexing assumes both accumulators were built from the same index.
pub fn merge(&mut self, other: TinAccum) {
for (i, other_cov) in other.coverage.into_iter().enumerate() {
for (j, count) in other_cov.into_iter().enumerate() {
self.coverage[i][j] += count;
}
if self.exceeded_threshold[i] || other.exceeded_threshold[i] {
self.exceeded_threshold[i] = true;
self.unique_starts[i] = HashSet::with_hasher(self.hash_state.clone());
} else {
self.unique_starts[i].extend(other.unique_starts[i].iter());
if self.unique_starts[i].len() > self.min_cov as usize {
self.exceeded_threshold[i] = true;
self.unique_starts[i] = HashSet::with_hasher(self.hash_state.clone());
}
}
}
}
/// Consumes the accumulator and produces one `TinResult` per transcript of `index`, in
/// index order.
///
/// Transcripts that did not exceed the `min_cov` distinct-read-start cutoff get
/// `tin = 0.0` and `passed_threshold = false`; the others are scored with `compute_tin`
/// using their `total_position_count` as the denominator.
pub fn into_result(self, index: &TinIndex) -> TinResults {
let mut transcripts = Vec::with_capacity(index.transcripts.len());
for (i, tx) in index.transcripts.iter().enumerate() {
let passed =
self.exceeded_threshold[i] || self.unique_starts[i].len() as u32 > self.min_cov;
let coverage = &self.coverage[i];
if !passed {
transcripts.push(TinResult {
gene_id: tx.gene_id.clone(),
chrom: tx.chrom.clone(),
tx_start: tx.tx_start,
tx_end: tx.tx_end,
tin: 0.0,
passed_threshold: false,
});
continue;
}
let n_total = tx.total_position_count;
let tin = compute_tin(coverage, n_total);
transcripts.push(TinResult {
gene_id: tx.gene_id.clone(),
chrom: tx.chrom.clone(),
tx_start: tx.tx_start,
tx_end: tx.tx_end,
tin,
passed_threshold: true,
});
}
TinResults { transcripts }
}
}
/// Computes the TIN score from per-position coverage counts.
///
/// Positions with zero coverage are ignored. The Shannon entropy `H` of the
/// remaining coverage fractions is turned into `100 * exp(H) / n_total_positions`.
/// Returns `0.0` when `n_total_positions` is zero or no position is covered.
fn compute_tin(coverage: &[u32], n_total_positions: usize) -> f64 {
    if n_total_positions == 0 {
        return 0.0;
    }

    // Covered depths only, in input order; evaluated lazily on each call.
    let covered = || {
        coverage
            .iter()
            .copied()
            .filter(|&depth| depth > 0)
            .map(f64::from)
    };

    let depth_sum: f64 = covered().sum();
    // A zero sum also covers the case where no position has any coverage.
    if depth_sum == 0.0 {
        return 0.0;
    }

    let shannon: f64 = covered()
        .map(|depth| {
            let fraction = depth / depth_sum;
            -fraction * fraction.ln()
        })
        .sum();

    100.0 * shannon.exp() / n_total_positions as f64
}
/// Rewrites `buf` with the reference intervals covered by the aligned
/// (match / sequence-match / mismatch) CIGAR operations of `record`.
///
/// Each interval is `(start, start + len)` in the coordinates returned by
/// `record.pos()`. Deletions and reference skips advance the reference
/// coordinate without producing an interval; insertions, clips and padding
/// leave it untouched.
fn fill_aligned_blocks(record: &bam::Record, buf: &mut Vec<(u64, u64)>) {
    use rust_htslib::bam::record::Cigar;

    buf.clear();
    let mut ref_cursor = record.pos() as u64;
    for op in record.cigar().iter() {
        // Reference bases consumed by this operation, and whether they are aligned.
        let (consumed, aligned) = match *op {
            Cigar::Match(n) | Cigar::Equal(n) | Cigar::Diff(n) => (n as u64, true),
            Cigar::Del(n) | Cigar::RefSkip(n) => (n as u64, false),
            Cigar::Ins(_) | Cigar::SoftClip(_) | Cigar::HardClip(_) | Cigar::Pad(_) => continue,
        };
        let next_cursor = ref_cursor + consumed;
        if aligned {
            buf.push((ref_cursor, next_cursor));
        }
        ref_cursor = next_cursor;
    }
}
/// Writes the per-transcript TIN table to `output_path` as tab-separated text.
///
/// The file starts with the header `geneID  chrom  tx_start  tx_end  TIN`
/// (tab-separated), followed by one line per result. The TIN column is `0.0`
/// for transcripts that did not pass the coverage cutoff, `0` for transcripts
/// that passed but scored exactly zero, and the score itself otherwise.
///
/// # Errors
///
/// Returns an error if the file cannot be created or any write (including the
/// final flush) fails.
pub fn write_tin(results: &TinResults, output_path: &Path) -> Result<()> {
    let file = std::fs::File::create(output_path)
        .with_context(|| format!("Failed to create TIN output: {}", output_path.display()))?;
    // One line is written per transcript; buffering avoids a system call for each of them.
    let mut out = std::io::BufWriter::new(file);

    writeln!(out, "geneID\tchrom\ttx_start\ttx_end\tTIN")?;
    for r in &results.transcripts {
        write!(
            out,
            "{}\t{}\t{}\t{}\t",
            r.gene_id, r.chrom, r.tx_start, r.tx_end
        )?;
        if !r.passed_threshold {
            writeln!(out, "0.0")?;
        } else if r.tin == 0.0 {
            writeln!(out, "0")?;
        } else {
            writeln!(out, "{}", r.tin)?;
        }
    }

    // Flush explicitly: an error during the implicit flush on drop would be lost.
    out.flush()?;
    Ok(())
}
/// Writes a one-line TIN summary for `bam_name` to `output_path`.
///
/// Mean, median and standard deviation (dividing by the number of scores) are
/// taken over the transcripts that passed the coverage cutoff; all three are
/// `0.0` when none passed. The output is a tab-separated header line followed
/// by one data line.
///
/// # Errors
///
/// Returns an error if the file cannot be created or written.
pub fn write_tin_summary(results: &TinResults, bam_name: &str, output_path: &Path) -> Result<()> {
    let mut out = std::fs::File::create(output_path)
        .with_context(|| format!("Failed to create TIN summary: {}", output_path.display()))?;

    let mut scores: Vec<f64> = Vec::new();
    for record in &results.transcripts {
        if record.passed_threshold {
            scores.push(record.tin);
        }
    }

    let (mean, median, stdev) = match scores.len() {
        0 => (0.0, 0.0, 0.0),
        count => {
            let n = count as f64;
            let mean = scores.iter().sum::<f64>() / n;
            let median = crate::io::median(&scores);
            let squared_deviations = scores.iter().map(|x| (x - mean).powi(2)).sum::<f64>();
            (mean, median, (squared_deviations / n).sqrt())
        }
    };

    writeln!(out, "Bam_file\tTIN(mean)\tTIN(median)\tTIN(stdev)")?;
    writeln!(out, "{}\t{}\t{}\t{}", bam_name, mean, median, stdev)?;
    Ok(())
}
// Unit tests for position sampling and the TIN formula.
#[cfg(test)]
mod tests {
use super::*;
// 10 exonic bases with n = 5 takes the strided branch: step 2 picks 101, 103, ..., 109
// and the exon end 110 is added as a boundary, giving 6 distinct positions.
#[test]
fn test_sample_exonic_positions() {
let exons = vec![(100, 110)];
let (positions, total_count) = sample_exonic_positions(&exons, 5, 100, 110);
assert_eq!(positions.len(), 6); assert_eq!(total_count, 6); assert_eq!(positions[0], 101);
assert_eq!(positions[5], 110);
}
// 10 exonic bases with n = 10 takes the exhaustive branch: all 10 bases are kept, and the
// count taken before deduplication also includes the two transcript bounds (12).
#[test]
fn test_sample_multi_exon() {
let exons = vec![(100, 105), (200, 205)];
let (positions, total_count) = sample_exonic_positions(&exons, 10, 100, 205);
assert_eq!(positions.len(), 10); assert_eq!(total_count, 12); assert_eq!(positions[0], 101);
assert_eq!(positions[9], 205);
}
// A 3-base exon: 3 distinct positions, but 5 entries before deduplication
// (two transcript bounds plus three bases).
#[test]
fn test_sample_tiny_exon() {
let exons = vec![(100, 103)];
let (positions, total_count) = sample_exonic_positions(&exons, 100, 100, 103);
assert_eq!(positions.len(), 3); assert_eq!(total_count, 5); assert_eq!(positions[0], 101);
assert_eq!(positions[2], 103);
}
// Equal coverage everywhere maximises the entropy, so the score reaches 100.
#[test]
fn test_compute_tin_uniform() {
let cov = vec![10, 10, 10, 10, 10, 10, 10, 10, 10, 10];
let tin = compute_tin(&cov, cov.len());
assert!(
(tin - 100.0).abs() < 0.01,
"Uniform coverage TIN should be ~100: {tin}"
);
}
// Zero-coverage positions are dropped from the entropy but still count in the denominator.
#[test]
fn test_compute_tin_uniform_with_zeros() {
let cov = vec![10, 10, 10, 10, 10, 0, 0, 0, 0, 0];
let tin = compute_tin(&cov, cov.len());
assert!(
(tin - 50.0).abs() < 0.01,
"Half-covered uniform TIN should be ~50: {tin}"
);
}
// Uneven coverage must land strictly between the two extremes.
#[test]
fn test_compute_tin_degraded() {
let cov = vec![100, 80, 60, 40, 20, 10, 5, 2, 1, 1];
let tin = compute_tin(&cov, cov.len());
assert!(
tin > 0.0 && tin < 100.0,
"Degraded TIN should be between 0 and 100: {tin}"
);
}
// No covered position at all yields exactly zero.
#[test]
fn test_compute_tin_all_zero() {
let cov = vec![0, 0, 0, 0, 0];
let tin = compute_tin(&cov, cov.len());
assert_eq!(tin, 0.0);
}
// A single covered position has zero entropy, so the score is 100 * 1 / 5 = 20.
#[test]
fn test_compute_tin_single_nonzero() {
let cov = vec![0, 0, 100, 0, 0];
let tin = compute_tin(&cov, cov.len());
assert!(
(tin - 20.0).abs() < 0.01,
"Single position coverage should give TIN=20: {tin}"
);
}
}