use std::f32;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HmmStateType {
Begin,
Match,
Insert,
Delete,
End,
}
const NUM_AMINO_ACIDS: usize = 24;
pub struct ViterbiKernel;
impl ViterbiKernel {
pub fn viterbi_forward(
sequence: &[u8],
match_emissions: &[Vec<f32>],
insert_emissions: &[Vec<f32>],
transitions: &[Vec<f32>],
) -> Result<(Vec<Vec<f32>>, Vec<Vec<usize>>), String> {
let n = sequence.len();
let m = match_emissions.len();
if n == 0 || m == 0 {
return Err("Empty input".to_string());
}
let mut dp = vec![vec![f32::NEG_INFINITY; 3]; n + 1]; let mut path = vec![vec![0usize; 3]; n + 1];
dp[0][0] = 0.0;
for i in 1..=n {
let aa_idx = (sequence[i - 1] as usize) % NUM_AMINO_ACIDS;
let match_emission = match_emissions.get(i - 1)
.and_then(|e| e.get(aa_idx))
.copied()
.unwrap_or(f32::NEG_INFINITY);
let prev_match = dp[i - 1][0] + transitions.get(0).and_then(|t| t.get(0)).copied().unwrap_or(-1.0);
let prev_insert = dp[i - 1][1] + transitions.get(1).and_then(|t| t.get(0)).copied().unwrap_or(-1.0);
let prev_delete = dp[i - 1][2] + transitions.get(2).and_then(|t| t.get(0)).copied().unwrap_or(-1.0);
let max_prev = prev_match.max(prev_insert).max(prev_delete);
if max_prev.is_finite() {
dp[i][0] = max_prev + match_emission;
path[i][0] = if max_prev == prev_match { 0 } else if max_prev == prev_insert { 1 } else { 2 };
}
let insert_emission = insert_emissions.get(i - 1)
.and_then(|e| e.get(aa_idx))
.copied()
.unwrap_or(f32::NEG_INFINITY);
let prev_match_ins = dp[i - 1][0] + transitions.get(0).and_then(|t| t.get(1)).copied().unwrap_or(-1.0);
let prev_insert_ins = dp[i - 1][1] + transitions.get(1).and_then(|t| t.get(1)).copied().unwrap_or(-1.0);
let max_prev_ins = prev_match_ins.max(prev_insert_ins);
if max_prev_ins.is_finite() {
dp[i][1] = max_prev_ins + insert_emission;
path[i][1] = if max_prev_ins == prev_match_ins { 0 } else { 1 };
}
let prev_match_del = dp[i - 1][0] + transitions.get(0).and_then(|t| t.get(2)).copied().unwrap_or(-1.0);
let prev_delete_del = dp[i - 1][2] + transitions.get(2).and_then(|t| t.get(2)).copied().unwrap_or(-1.0);
let max_prev_del = prev_match_del.max(prev_delete_del);
if max_prev_del.is_finite() {
dp[i][2] = max_prev_del;
path[i][2] = if max_prev_del == prev_match_del { 0 } else { 2 };
}
}
Ok((dp, path))
}
pub fn backtrack(
sequence: &[u8],
_dp: &[Vec<f32>],
path_indices: &[Vec<usize>],
) -> Vec<HmmStateType> {
let mut alignment = Vec::new();
let mut current_state = 0;
let n = sequence.len();
for i in (1..=n).rev() {
alignment.push(match current_state {
0 => HmmStateType::Match,
1 => HmmStateType::Insert,
2 => HmmStateType::Delete,
_ => HmmStateType::Match,
});
current_state = path_indices.get(i).and_then(|p| p.get(current_state)).copied().unwrap_or(0);
}
alignment.reverse();
alignment
}
}
pub struct ForwardKernel;
impl ForwardKernel {
pub fn forward(
sequence: &[u8],
match_emissions: &[Vec<f32>],
insert_emissions: &[Vec<f32>],
transitions: &[Vec<f32>],
) -> Result<f32, String> {
let n = sequence.len();
let m = match_emissions.len();
if n == 0 || m == 0 {
return Err("Empty input".to_string());
}
let mut dp = vec![vec![f32::NEG_INFINITY; 3]; n + 1];
dp[0][0] = 0.0;
for i in 1..=n {
let aa_idx = (sequence[i - 1] as usize) % NUM_AMINO_ACIDS;
let match_emission = match_emissions.get(i - 1)
.and_then(|e| e.get(aa_idx))
.copied()
.unwrap_or(f32::NEG_INFINITY);
let insert_emission = insert_emissions.get(i - 1)
.and_then(|e| e.get(aa_idx))
.copied()
.unwrap_or(f32::NEG_INFINITY);
let match_prev = dp[i - 1][0] + transitions.get(0).and_then(|t| t.get(0)).copied().unwrap_or(-1.0);
let insert_prev = dp[i - 1][1] + transitions.get(1).and_then(|t| t.get(0)).copied().unwrap_or(-1.0);
let delete_prev = dp[i - 1][2] + transitions.get(2).and_then(|t| t.get(0)).copied().unwrap_or(-1.0);
dp[i][0] = Self::logsumexp(&[match_prev, insert_prev, delete_prev]) + match_emission;
let match_prev_ins = dp[i - 1][0] + transitions.get(0).and_then(|t| t.get(1)).copied().unwrap_or(-1.0);
let insert_prev_ins = dp[i - 1][1] + transitions.get(1).and_then(|t| t.get(1)).copied().unwrap_or(-1.0);
dp[i][1] = Self::logsumexp(&[match_prev_ins, insert_prev_ins]) + insert_emission;
let match_prev_del = dp[i - 1][0] + transitions.get(0).and_then(|t| t.get(2)).copied().unwrap_or(-1.0);
let delete_prev_del = dp[i - 1][2] + transitions.get(2).and_then(|t| t.get(2)).copied().unwrap_or(-1.0);
dp[i][2] = Self::logsumexp(&[match_prev_del, delete_prev_del]);
}
let final_val = Self::logsumexp(&[dp[n][0], dp[n][1], dp[n][2]]);
Ok(if final_val.is_finite() { final_val } else { f32::NEG_INFINITY })
}
fn logsumexp(values: &[f32]) -> f32 {
if values.is_empty() {
return f32::NEG_INFINITY;
}
let max_val = values.iter().copied().fold(f32::NEG_INFINITY, f32::max);
if !max_val.is_finite() {
return f32::NEG_INFINITY;
}
let sum_exp: f32 = values
.iter()
.map(|&v| (v - max_val).exp())
.sum();
max_val + sum_exp.ln()
}
}
pub struct BackwardKernel;
impl BackwardKernel {
pub fn backward(
sequence: &[u8],
match_emissions: &[Vec<f32>],
insert_emissions: &[Vec<f32>],
transitions: &[Vec<f32>],
) -> Result<Vec<Vec<f32>>, String> {
let n = sequence.len();
let m = match_emissions.len();
if n == 0 || m == 0 {
return Err("Empty input".to_string());
}
let mut dp = vec![vec![f32::NEG_INFINITY; 3]; n + 1];
dp[n][0] = 0.0;
for i in (1..n).rev() {
let aa_idx = (sequence[i] as usize) % NUM_AMINO_ACIDS;
let match_emission = match_emissions.get(i)
.and_then(|e| e.get(aa_idx))
.copied()
.unwrap_or(f32::NEG_INFINITY);
let insert_emission = insert_emissions.get(i)
.and_then(|e| e.get(aa_idx))
.copied()
.unwrap_or(f32::NEG_INFINITY);
let next_match = dp[i + 1][0] + transitions.get(0).and_then(|t| t.get(0)).copied().unwrap_or(-1.0) + match_emission;
let next_insert = dp[i + 1][1] + transitions.get(1).and_then(|t| t.get(0)).copied().unwrap_or(-1.0) + insert_emission;
let next_delete = dp[i + 1][2] + transitions.get(2).and_then(|t| t.get(0)).copied().unwrap_or(-1.0);
dp[i][0] = ForwardKernel::logsumexp(&[next_match, next_insert, next_delete]);
let next_match_ins = dp[i + 1][0] + transitions.get(0).and_then(|t| t.get(1)).copied().unwrap_or(-1.0) + match_emission;
let next_insert_ins = dp[i + 1][1] + transitions.get(1).and_then(|t| t.get(1)).copied().unwrap_or(-1.0) + insert_emission;
dp[i][1] = ForwardKernel::logsumexp(&[next_match_ins, next_insert_ins]);
let next_match_del = dp[i + 1][0] + transitions.get(0).and_then(|t| t.get(2)).copied().unwrap_or(-1.0);
let next_delete_del = dp[i + 1][2] + transitions.get(2).and_then(|t| t.get(2)).copied().unwrap_or(-1.0);
dp[i][2] = ForwardKernel::logsumexp(&[next_match_del, next_delete_del]);
}
Ok(dp)
}
}
pub struct BaumWelchKernel;
impl BaumWelchKernel {
pub fn baum_welch_iteration(
sequences: &[Vec<u8>],
match_emissions: &mut [Vec<f32>],
insert_emissions: &mut [Vec<f32>],
transitions: &mut [Vec<f32>],
iteration: usize,
) -> Result<f32, String> {
let mut likelihood = 0.0f32;
let mut emission_counts = vec![vec![0.0f32; NUM_AMINO_ACIDS]; match_emissions.len()];
let mut transition_counts = vec![vec![0.0f32; 3]; 3];
for sequence in sequences {
let forward = ForwardKernel::forward(sequence, match_emissions, insert_emissions, transitions)?;
let _backward = BackwardKernel::backward(sequence, match_emissions, insert_emissions, transitions)?;
likelihood += forward;
for i in 0..sequence.len() {
let aa_idx = (sequence[i] as usize) % NUM_AMINO_ACIDS;
if i < match_emissions.len() {
emission_counts[i][aa_idx] += 1.0;
}
if i > 0 {
let _prev_aa = (sequence[i - 1] as usize) % NUM_AMINO_ACIDS;
transition_counts[0][0] += 1.0; }
}
}
for i in 0..match_emissions.len() {
let total: f32 = emission_counts[i].iter().sum();
if total > 0.0 {
for j in 0..NUM_AMINO_ACIDS {
match_emissions[i][j] = (emission_counts[i][j] / total).max(0.001).ln();
}
}
}
let damping = 0.9f32.powi(iteration as i32);
let _ = damping;
Ok(likelihood / sequences.len() as f32)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_viterbi_simple() {
let sequence = vec![0u8, 1u8, 2u8]; let match_emissions = vec![
vec![-1.0; NUM_AMINO_ACIDS],
vec![-1.0; NUM_AMINO_ACIDS],
vec![-1.0; NUM_AMINO_ACIDS],
];
let insert_emissions = vec![
vec![-2.0; NUM_AMINO_ACIDS],
vec![-2.0; NUM_AMINO_ACIDS],
vec![-2.0; NUM_AMINO_ACIDS],
];
let transitions = vec![
vec![-0.5, -1.0, -1.0],
vec![-0.5, -1.0, -1.0],
vec![-0.5, -1.0, -1.0],
];
let result = ViterbiKernel::viterbi_forward(&sequence, &match_emissions, &insert_emissions, &transitions);
assert!(result.is_ok());
let (dp, _path) = result.unwrap();
assert_eq!(dp.len(), 4); }
#[test]
fn test_forward_algorithm() {
let sequence = vec![0u8, 1u8];
let match_emissions = vec![
vec![-1.0; NUM_AMINO_ACIDS],
vec![-1.0; NUM_AMINO_ACIDS],
];
let insert_emissions = vec![
vec![-2.0; NUM_AMINO_ACIDS],
vec![-2.0; NUM_AMINO_ACIDS],
];
let transitions = vec![
vec![-0.5, -1.0, -1.0],
vec![-0.5, -1.0, -1.0],
];
let result = ForwardKernel::forward(&sequence, &match_emissions, &insert_emissions, &transitions);
assert!(result.is_ok());
let score = result.unwrap();
assert!(score.is_finite());
}
#[test]
fn test_backward_algorithm() {
let sequence = vec![0u8, 1u8];
let match_emissions = vec![
vec![-1.0; NUM_AMINO_ACIDS],
vec![-1.0; NUM_AMINO_ACIDS],
];
let insert_emissions = vec![
vec![-2.0; NUM_AMINO_ACIDS],
vec![-2.0; NUM_AMINO_ACIDS],
];
let transitions = vec![
vec![-0.5, -1.0, -1.0],
vec![-0.5, -1.0, -1.0],
];
let result = BackwardKernel::backward(&sequence, &match_emissions, &insert_emissions, &transitions);
assert!(result.is_ok());
let dp = result.unwrap();
assert_eq!(dp.len(), 3); }
#[test]
fn test_logsumexp_stability() {
let values = vec![-100.0, -101.0, -102.0];
let result = ForwardKernel::logsumexp(&values);
assert!(result.is_finite());
assert!(result < -99.0); }
#[test]
fn test_viterbi_backtrack() {
let sequence = vec![0u8];
let dp = vec![
vec![0.0, f32::NEG_INFINITY, f32::NEG_INFINITY],
vec![-1.0, -2.0, -3.0],
];
let path = vec![
vec![0, 0, 0],
vec![0, 1, 2],
];
let alignment = ViterbiKernel::backtrack(&sequence, &dp, &path);
assert_eq!(alignment.len(), 1);
}
#[test]
fn test_baum_welch_iteration() {
let sequences = vec![vec![0u8, 1u8], vec![2u8, 3u8]];
let mut match_emissions = vec![
vec![-1.0; NUM_AMINO_ACIDS],
vec![-1.0; NUM_AMINO_ACIDS],
];
let mut insert_emissions = vec![
vec![-2.0; NUM_AMINO_ACIDS],
vec![-2.0; NUM_AMINO_ACIDS],
];
let mut transitions = vec![
vec![-0.5, -1.0, -1.0],
vec![-0.5, -1.0, -1.0],
];
let result = BaumWelchKernel::baum_welch_iteration(
&sequences,
&mut match_emissions,
&mut insert_emissions,
&mut transitions,
0,
);
assert!(result.is_ok());
let likelihood = result.unwrap();
assert!(likelihood.is_finite());
}
#[test]
fn test_empty_sequence_error() {
let sequence = vec![];
let match_emissions = vec![vec![-1.0; NUM_AMINO_ACIDS]];
let insert_emissions = vec![vec![-2.0; NUM_AMINO_ACIDS]];
let transitions = vec![vec![-0.5, -1.0, -1.0]];
let result = ViterbiKernel::viterbi_forward(&sequence, &match_emissions, &insert_emissions, &transitions);
assert!(result.is_err());
}
#[test]
fn test_single_amino_acid() {
let sequence = vec![0u8];
let match_emissions = vec![vec![-1.0; NUM_AMINO_ACIDS]];
let insert_emissions = vec![vec![-2.0; NUM_AMINO_ACIDS]];
let transitions = vec![vec![-0.5, -1.0, -1.0]];
let result = ViterbiKernel::viterbi_forward(&sequence, &match_emissions, &insert_emissions, &transitions);
assert!(result.is_ok());
}
#[test]
fn test_long_sequence() {
let sequence: Vec<u8> = (0..100).map(|i| (i % NUM_AMINO_ACIDS as u8) as u8).collect();
let match_emissions = vec![vec![-1.0; NUM_AMINO_ACIDS]; 100];
let insert_emissions = vec![vec![-2.0; NUM_AMINO_ACIDS]; 100];
let transitions = vec![vec![-0.5, -1.0, -1.0]; 100];
let result = ViterbiKernel::viterbi_forward(&sequence, &match_emissions, &insert_emissions, &transitions);
assert!(result.is_ok());
let (dp, _path) = result.unwrap();
assert_eq!(dp.len(), 101);
}
#[test]
fn test_forward_backward_consistency() {
let sequence = vec![0u8, 1u8, 2u8];
let match_emissions = vec![vec![-1.0; NUM_AMINO_ACIDS]; 3];
let insert_emissions = vec![vec![-2.0; NUM_AMINO_ACIDS]; 3];
let transitions = vec![vec![-0.5, -1.0, -1.0]; 3];
let forward_result = ForwardKernel::forward(&sequence, &match_emissions, &insert_emissions, &transitions);
let backward_result = BackwardKernel::backward(&sequence, &match_emissions, &insert_emissions, &transitions);
assert!(forward_result.is_ok());
assert!(backward_result.is_ok());
let forward_score = forward_result.unwrap();
assert!(forward_score.is_finite());
}
#[test]
fn test_viterbi_different_paths() {
let sequence = vec![0u8, 1u8, 2u8, 3u8, 4u8];
let match_emissions = vec![vec![-1.0; NUM_AMINO_ACIDS]; 5];
let mut insert_emissions = vec![vec![-2.0; NUM_AMINO_ACIDS]; 5];
insert_emissions[1][0] = -0.5;
let transitions = vec![vec![-0.5, -1.0, -1.0]; 5];
let result = ViterbiKernel::viterbi_forward(&sequence, &match_emissions, &insert_emissions, &transitions);
assert!(result.is_ok());
}
#[test]
fn test_logsumexp_zero_values() {
let values = vec![0.0, 0.0, 0.0];
let result = ForwardKernel::logsumexp(&values);
assert!(result.is_finite());
assert!(result > 0.0);
}
#[test]
fn test_logsumexp_extreme_values() {
let values = vec![-1e10, -1e10, -1e10];
let result = ForwardKernel::logsumexp(&values);
assert!(result.is_finite());
assert!(result <= -9.99999999e9); }
#[test]
fn test_multiple_sequence_alignment() {
let sequences = vec![
vec![0u8, 1u8, 2u8],
vec![0u8, 1u8, 3u8],
vec![1u8, 2u8, 3u8],
];
let mut match_emissions = vec![vec![-1.0; NUM_AMINO_ACIDS]; 3];
let mut insert_emissions = vec![vec![-2.0; NUM_AMINO_ACIDS]; 3];
let mut transitions = vec![vec![-0.5, -1.0, -1.0]; 3];
let result = BaumWelchKernel::baum_welch_iteration(
&sequences,
&mut match_emissions,
&mut insert_emissions,
&mut transitions,
0,
);
assert!(result.is_ok());
}
#[test]
fn test_baum_welch_convergence() {
let sequences = vec![vec![0u8, 1u8], vec![0u8, 1u8], vec![0u8, 1u8]];
let mut match_emissions = vec![vec![-1.0; NUM_AMINO_ACIDS]; 2];
let mut insert_emissions = vec![vec![-2.0; NUM_AMINO_ACIDS]; 2];
let mut transitions = vec![vec![-0.5, -1.0, -1.0]; 2];
let result1 = BaumWelchKernel::baum_welch_iteration(
&sequences,
&mut match_emissions,
&mut insert_emissions,
&mut transitions,
0,
).unwrap();
let result2 = BaumWelchKernel::baum_welch_iteration(
&sequences,
&mut match_emissions,
&mut insert_emissions,
&mut transitions,
1,
).unwrap();
assert!(result1.is_finite());
assert!(result2.is_finite());
}
#[test]
fn test_viterbi_all_amino_acids() {
let sequence: Vec<u8> = (0..24).collect();
let match_emissions = vec![vec![-1.0; NUM_AMINO_ACIDS]; 24];
let insert_emissions = vec![vec![-2.0; NUM_AMINO_ACIDS]; 24];
let transitions = vec![vec![-0.5, -1.0, -1.0]; 24];
let result = ViterbiKernel::viterbi_forward(&sequence, &match_emissions, &insert_emissions, &transitions);
assert!(result.is_ok());
}
#[test]
fn test_forward_score_range() {
let sequence = vec![0u8, 1u8, 2u8];
let match_emissions = vec![vec![-1.0; NUM_AMINO_ACIDS]; 3];
let insert_emissions = vec![vec![-2.0; NUM_AMINO_ACIDS]; 3];
let transitions = vec![vec![-0.5, -1.0, -1.0]; 3];
let result1 = ForwardKernel::forward(&sequence, &match_emissions, &insert_emissions, &transitions).unwrap();
let match_emissions2 = vec![vec![-10.0; NUM_AMINO_ACIDS]; 3];
let result2 = ForwardKernel::forward(&sequence, &match_emissions2, &insert_emissions, &transitions).unwrap();
assert!(result1 > result2);
}
}