use std::f32;
/// Kind of a state in a [`ProfileHmm`].
///
/// `get_transition_index` uses the (source, target) kinds to choose which slot
/// of `HmmState::transitions` scores a move; it permits no move into `Begin`
/// or `End`, and none out of `End`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StateType {
/// Match state; `ProfileHmm::from_msa` fills its emissions with log column frequencies.
Match,
/// Insert state; `from_msa` gives it a flat `-5.0` emission score.
Insert,
/// Delete state; `from_msa` gives it `NEG_INFINITY` emissions, and since the
/// DP routines add an emission score for every state they enter, any path
/// through it scores `NEG_INFINITY`.
Delete,
/// Start state; models built by `from_msa`/`from_pfam` put it at index 0,
/// where the DP routines start every path.
Begin,
/// Terminal state; `from_msa`/`from_pfam` give it no transitions.
End,
}
/// One state of a [`ProfileHmm`] together with its scores.
#[derive(Debug, Clone)]
pub struct HmmState {
/// What kind of state this is.
pub state_type: StateType,
/// Emission scores indexed by residue index; the DP routines add them to
/// transition scores (log space). Models built in this file use 24 entries.
pub emissions: Vec<f32>,
/// Outgoing transition scores, indexed by the slot `get_transition_index`
/// returns for (this state, target state). The scoring routines treat a
/// missing slot as `NEG_INFINITY`.
pub transitions: Vec<f32>,
}
/// A profile hidden Markov model: an ordered list of states plus metadata.
///
/// The DP routines only allow a move from a state to a strictly
/// higher-indexed state, so the order of `states` defines the topology.
#[derive(Debug, Clone)]
pub struct ProfileHmm {
/// States in model order; `from_msa` lays out `Begin`, then one
/// Match/Insert/Delete triple per alignment column, then `End`.
pub states: Vec<HmmState>,
/// Number of model positions (alignment columns for `from_msa`; 1 for `from_pfam`).
pub length: usize,
/// Model name (`"custom"` for `from_msa`; the requested name for `from_pfam`).
pub name: String,
}
/// Result of [`ProfileHmm::viterbi`].
#[derive(Debug, Clone)]
pub struct ViterbiPath {
/// State kinds along the traced-back path, one entry per DP row.
pub states: Vec<StateType>,
/// Best log-space score over the final DP row.
pub score: f32,
/// CIGAR-style summary of the alignment.
pub cigar: String,
}
/// A region of a query sequence reported by [`ProfileHmm::find_domains`].
#[derive(Debug, Clone)]
pub struct Domain {
/// Generated label of the form `domain_<k>`, numbered in order of discovery.
pub name: String,
/// Start coordinate of the region (see `find_domains` for the convention).
pub start: usize,
/// End coordinate of the region (see `find_domains` for the convention).
pub end: usize,
/// E-value. NOTE(review): `find_domains` currently fills a fixed 0.001
/// rather than deriving it from the score (cf. `score_to_evalue`).
pub evalue: f64,
/// Score of the whole Viterbi path; every domain from one sequence shares it.
pub bit_score: f32,
/// Alignment text; `find_domains` currently leaves it empty.
pub alignment: String,
}
/// Errors produced while building, scoring, or training a [`ProfileHmm`].
#[derive(Debug)]
pub enum HmmError {
    /// The model could not be built from the given input (e.g. an empty alignment).
    InvalidModel(String),
    /// A scoring or training routine was given unusable input.
    ComputationFailed(String),
    /// NOTE(review): not constructed in this file; presumably reserved for
    /// external model lookup — confirm.
    DatabaseError(String),
}

impl std::fmt::Display for HmmError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HmmError::InvalidModel(msg) => write!(f, "invalid model: {}", msg),
            HmmError::ComputationFailed(msg) => write!(f, "computation failed: {}", msg),
            HmmError::DatabaseError(msg) => write!(f, "database error: {}", msg),
        }
    }
}

/// Lets `HmmError` flow through `?` into `Box<dyn Error>` and error-reporting crates.
impl std::error::Error for HmmError {}
/// Returns the slot in the source state's `transitions` vector that scores a
/// move from a `from_type` state to a `to_type` state, or `None` when the
/// model topology forbids that move.
///
/// `Begin` and `Match` fan out to Match/Insert/Delete (slots 0/1/2); `Insert`
/// reaches Match/Insert (0/1); `Delete` reaches Match/Delete (0/1). Nothing
/// enters `Begin` or `End`, and nothing leaves `End`.
fn get_transition_index(from_type: StateType, to_type: StateType) -> Option<usize> {
    let slot = match from_type {
        StateType::Begin | StateType::Match => match to_type {
            StateType::Match => 0,
            StateType::Insert => 1,
            StateType::Delete => 2,
            StateType::Begin | StateType::End => return None,
        },
        StateType::Insert => match to_type {
            StateType::Match => 0,
            StateType::Insert => 1,
            StateType::Delete | StateType::Begin | StateType::End => return None,
        },
        StateType::Delete => match to_type {
            StateType::Match => 0,
            StateType::Delete => 1,
            StateType::Insert | StateType::Begin | StateType::End => return None,
        },
        StateType::End => return None,
    };
    Some(slot)
}
/// Converts a sequence byte into an emission-vector index, clamped to `0..=23`.
///
/// The byte is treated as an already-encoded residue index: values `0..=23`
/// pass through unchanged and anything larger maps to the last slot, 23.
///
/// NOTE(review): ASCII residue letters (e.g. `b'M'` == 77) all exceed 23 and
/// therefore collapse onto slot 23, whereas `ProfileHmm::from_msa` fills
/// emissions by `AminoAcid::index()`. Confirm whether callers are expected to
/// pre-encode sequences before scoring.
#[inline]
fn get_safe_aa_index(byte: u8) -> usize {
    const LAST_SLOT: u8 = 23;
    usize::from(byte.min(LAST_SLOT))
}
impl ProfileHmm {
    /// Builds a profile HMM from a multiple sequence alignment.
    ///
    /// The model is laid out as `Begin`, then one `Match`/`Insert`/`Delete`
    /// triple per alignment column (the column count is taken from the first
    /// row), then `End`. Match emissions are the natural log of each symbol's
    /// per-column frequency (indexed by `AminoAcid::index()`), floored at 0.001
    /// before taking the log. Gap symbols are counted like any other symbol.
    /// Characters `AminoAcid::from_code` rejects are ignored, and rows shorter
    /// than the first contribute nothing to the missing columns.
    ///
    /// # Errors
    /// `HmmError::InvalidModel` if the alignment or its first row is empty.
    pub fn from_msa(alignment: &[Vec<char>]) -> Result<Self, HmmError> {
        let seq_len = match alignment.first() {
            Some(first) if !first.is_empty() => first.len(),
            _ => return Err(HmmError::InvalidModel("Empty alignment".to_string())),
        };
        let num_seqs = alignment.len() as f32;

        let mut states = Vec::with_capacity(3 * seq_len + 2);
        states.push(HmmState {
            state_type: StateType::Begin,
            emissions: vec![0.0; 24],
            transitions: vec![0.99; 3],
        });
        for pos in 0..seq_len {
            let mut aa_counts = vec![0.0f32; 24];
            for ch in alignment.iter().filter_map(|seq| seq.get(pos)) {
                if let Ok(aa) = crate::protein::AminoAcid::from_code(*ch) {
                    // Out-of-range indices are skipped rather than panicking.
                    if let Some(count) = aa_counts.get_mut(aa.index()) {
                        *count += 1.0;
                    }
                }
            }
            let match_emissions = aa_counts
                .into_iter()
                .map(|count| (count / num_seqs).max(0.001).ln())
                .collect();
            states.push(HmmState {
                state_type: StateType::Match,
                emissions: match_emissions,
                transitions: vec![-0.1, -0.1, -0.1],
            });
            states.push(HmmState {
                state_type: StateType::Insert,
                emissions: vec![-5.0f32; 24],
                transitions: vec![-0.1, -0.1, -0.1],
            });
            states.push(HmmState {
                state_type: StateType::Delete,
                emissions: vec![f32::NEG_INFINITY; 24],
                transitions: vec![-0.1, -0.1, -0.1],
            });
        }
        states.push(HmmState {
            state_type: StateType::End,
            emissions: vec![0.0; 24],
            transitions: vec![],
        });
        Ok(ProfileHmm {
            states,
            length: seq_len,
            name: "custom".to_string(),
        })
    }

    /// Returns a minimal placeholder model named `name`.
    ///
    /// NOTE(review): no Pfam data is read. Regardless of `name`, this builds a
    /// fixed three-state `Begin` → `Match` → `End` model with uniform `-1.0`
    /// match emissions and `length == 1`. It never returns `Err` today.
    pub fn from_pfam(name: &str) -> Result<Self, HmmError> {
        let states = vec![
            HmmState {
                state_type: StateType::Begin,
                emissions: vec![0.0; 24],
                transitions: vec![0.99; 3],
            },
            HmmState {
                state_type: StateType::Match,
                emissions: vec![-1.0; 24],
                transitions: vec![-0.1, -0.1, -0.1],
            },
            HmmState {
                state_type: StateType::End,
                emissions: vec![0.0; 24],
                transitions: vec![],
            },
        ];
        Ok(ProfileHmm {
            states,
            length: 1,
            name: name.to_string(),
        })
    }

    /// Forward-algorithm score of `sequence`, returned as `exp(log_score)`.
    ///
    /// The log score is the log-sum-exp over all paths ending in any state
    /// after the last residue (see `forward_pass`). Returns `0.0` when no path
    /// can generate the sequence. Long sequences may underflow to `0.0` after
    /// exponentiation.
    ///
    /// # Errors
    /// `HmmError::ComputationFailed` when the model has no states or the sequence is empty.
    pub fn forward_score(&self, sequence: &[u8]) -> Result<f32, HmmError> {
        if self.states.is_empty() || sequence.is_empty() {
            return Err(HmmError::ComputationFailed("Invalid input".to_string()));
        }
        let alpha = self.forward_pass(sequence)?;
        let log_total = Self::log_sum(&alpha[sequence.len()]);
        Ok(if log_total.is_finite() { log_total.exp() } else { 0.0 })
    }

    /// Most likely state path for `sequence` (Viterbi algorithm).
    ///
    /// `ViterbiPath::states` has `sequence.len() + 1` entries. Entry 0 is the
    /// state at DP row 0 (`Begin` for models built here), and entry `k >= 1` is
    /// the state that emitted residue `k - 1`. The traceback starts from the
    /// best-scoring state in the final row. `End` cannot be used because
    /// `get_transition_index` allows no move into it. `cigar` is derived from
    /// the path; it is empty if no path can generate the sequence.
    ///
    /// # Errors
    /// `HmmError::ComputationFailed` when the model has no states or the sequence is empty.
    pub fn viterbi(&self, sequence: &[u8]) -> Result<ViterbiPath, HmmError> {
        if self.states.is_empty() || sequence.is_empty() {
            return Err(HmmError::ComputationFailed("Invalid input".to_string()));
        }
        let n = sequence.len();
        let m = self.states.len();
        let mut dp = vec![vec![f32::NEG_INFINITY; m]; n + 1];
        let mut back = vec![vec![0usize; m]; n + 1];
        dp[0][0] = 0.0;
        for i in 1..=n {
            let aa_idx = get_safe_aa_index(sequence[i - 1]);
            for j in 0..m {
                if self.states[j].state_type == StateType::Begin {
                    continue;
                }
                for prev_j in 0..j {
                    if let Some((_, step)) = self.step_score(prev_j, j, aa_idx) {
                        let score = dp[i - 1][prev_j] + step;
                        if score > dp[i][j] {
                            dp[i][j] = score;
                            back[i][j] = prev_j;
                        }
                    }
                }
            }
        }

        let (mut current, score) = dp[n].iter().copied().enumerate().fold(
            (0, f32::NEG_INFINITY),
            |(best_j, best), (j, s)| if s > best { (j, s) } else { (best_j, best) },
        );
        let mut states_path = vec![StateType::Begin; n + 1];
        for i in (1..=n).rev() {
            states_path[i] = self.states[current].state_type;
            current = back[i][current];
        }
        states_path[0] = self.states[current].state_type;

        let cigar = Self::cigar_from_path(&states_path);
        Ok(ViterbiPath {
            states: states_path,
            score,
            cigar,
        })
    }

    /// Reports maximal runs of `Match` states on the Viterbi path as domains.
    ///
    /// Coordinates are 0-based residue indices with `start` inclusive and `end`
    /// exclusive. Every domain shares the path score as `bit_score`.
    /// NOTE(review): `evalue` is a fixed placeholder (0.001) and `alignment`
    /// is left empty.
    ///
    /// # Errors
    /// Propagates errors from [`ProfileHmm::viterbi`]. An empty sequence yields no domains.
    pub fn find_domains(&self, sequence: &[u8]) -> Result<Vec<Domain>, HmmError> {
        if sequence.is_empty() {
            return Ok(vec![]);
        }
        let path = self.viterbi(sequence)?;
        let make_domain = |index: usize, start: usize, end: usize| Domain {
            name: format!("domain_{}", index),
            start,
            end,
            evalue: 0.001,
            bit_score: path.score,
            alignment: String::new(),
        };

        let mut domains: Vec<Domain> = Vec::new();
        let mut open_start: Option<usize> = None;
        // path.states[0] is DP row 0 (nothing emitted yet); entry k >= 1 emitted residue k - 1.
        for (residue, &state) in path.states.iter().skip(1).enumerate() {
            match (state == StateType::Match, open_start) {
                (true, None) => open_start = Some(residue),
                (false, Some(start)) => {
                    domains.push(make_domain(domains.len(), start, residue));
                    open_start = None;
                }
                _ => {}
            }
        }
        if let Some(start) = open_start {
            domains.push(make_domain(domains.len(), start, sequence.len()));
        }
        Ok(domains)
    }

    /// Re-estimates transition and emission scores with Baum–Welch (EM) on
    /// `sequences`, for at most `iterations` rounds.
    ///
    /// Each round runs the forward and backward passes and accumulates
    /// expected counts, normalised by each sequence's likelihood.
    /// Transition counts are kept per transition slot. The scores are then
    /// rewritten as pseudocount-smoothed log frequencies. A state that receives
    /// no expected counts keeps its old scores. Training stops early once the
    /// summed log-likelihood changes by less than `1e-5` between rounds.
    /// Progress is written to stderr.
    ///
    /// # Errors
    /// `HmmError::ComputationFailed` when `sequences` is empty.
    pub fn train(&mut self, sequences: &[&[u8]], iterations: usize) -> Result<(), HmmError> {
        if sequences.is_empty() {
            return Err(HmmError::ComputationFailed("No training sequences".to_string()));
        }
        const PSEUDOCOUNT: f32 = 0.01;
        let mut prev_log_likelihood = f32::NEG_INFINITY;
        for iteration in 0..iterations {
            // One counter per transition slot of each state, one per residue index for emissions.
            let mut transition_counts: Vec<Vec<f32>> = self
                .states
                .iter()
                .map(|state| vec![0.0f32; state.transitions.len()])
                .collect();
            let mut emission_counts = vec![vec![0.0f32; 24]; self.states.len()];
            let mut total_log_likelihood = 0.0f32;
            for sequence in sequences {
                let alpha = self.forward_pass(sequence)?;
                let beta = self.backward_pass(sequence)?;
                total_log_likelihood +=
                    alpha.last().map_or(f32::NEG_INFINITY, |row| Self::log_sum(row));
                self.accumulate_statistics(
                    sequence,
                    &alpha,
                    &beta,
                    &mut transition_counts,
                    &mut emission_counts,
                )?;
            }
            self.update_parameters(&transition_counts, &emission_counts, PSEUDOCOUNT)?;
            let likelihood_delta = (total_log_likelihood - prev_log_likelihood).abs();
            if likelihood_delta < 1e-5 {
                eprintln!("Baum-Welch converged at iteration {} (Δ log-L: {:.2e})", iteration, likelihood_delta);
                break;
            }
            prev_log_likelihood = total_log_likelihood;
            if iteration % 10 == 0 {
                eprintln!("Baum-Welch iteration {}: Log-likelihood = {:.4}", iteration, total_log_likelihood);
            }
        }
        Ok(())
    }

    /// Forward matrix in log space.
    ///
    /// `alpha[i][j]` is the log of the summed score of all paths that have
    /// emitted `sequence[..i]` and are in state `j`. Row 0 holds only state 0
    /// (the start). `Begin` states are never re-entered.
    fn forward_pass(&self, sequence: &[u8]) -> Result<Vec<Vec<f32>>, HmmError> {
        let n = sequence.len();
        let m = self.states.len();
        let mut alpha = vec![vec![f32::NEG_INFINITY; m]; n + 1];
        if let Some(start) = alpha[0].first_mut() {
            *start = 0.0;
        }
        for i in 1..=n {
            let aa_idx = get_safe_aa_index(sequence[i - 1]);
            for j in 0..m {
                if self.states[j].state_type == StateType::Begin {
                    continue;
                }
                let mut acc = f32::NEG_INFINITY;
                for prev_j in 0..j {
                    let prev = alpha[i - 1][prev_j];
                    if !prev.is_finite() {
                        continue;
                    }
                    if let Some((_, step)) = self.step_score(prev_j, j, aa_idx) {
                        acc = Self::log_add(acc, prev + step);
                    }
                }
                alpha[i][j] = acc;
            }
        }
        Ok(alpha)
    }

    /// Backward matrix in log space.
    ///
    /// `beta[i][j]` is the log of the summed score of emitting `sequence[i..]`
    /// given the path is in state `j` after `i` residues. Every state in the
    /// final row is seeded with 0.0 to mirror the forward termination, which
    /// sums over all states of the last row.
    fn backward_pass(&self, sequence: &[u8]) -> Result<Vec<Vec<f32>>, HmmError> {
        let n = sequence.len();
        let m = self.states.len();
        let mut beta = vec![vec![f32::NEG_INFINITY; m]; n + 1];
        beta[n] = vec![0.0; m];
        for i in (0..n).rev() {
            let aa_idx = get_safe_aa_index(sequence[i]);
            for j in 0..m {
                let mut acc = f32::NEG_INFINITY;
                for next_j in (j + 1)..m {
                    let next = beta[i + 1][next_j];
                    if !next.is_finite() {
                        continue;
                    }
                    if let Some((_, step)) = self.step_score(j, next_j, aa_idx) {
                        acc = Self::log_add(acc, step + next);
                    }
                }
                beta[i][j] = acc;
            }
        }
        Ok(beta)
    }

    /// Adds one sequence's expected emission and transition counts.
    ///
    /// Counts are posterior expectations (normalised by the sequence's total
    /// forward score). `transition_counts[j]` is indexed by transition slot.
    /// `emission_counts[j]` is indexed by residue index. Out-of-range slots are
    /// skipped. A sequence the model cannot generate contributes nothing.
    ///
    /// # Errors
    /// `HmmError::ComputationFailed` when `alpha`/`beta` do not have `sequence.len() + 1` rows.
    fn accumulate_statistics(
        &self,
        sequence: &[u8],
        alpha: &[Vec<f32>],
        beta: &[Vec<f32>],
        transition_counts: &mut [Vec<f32>],
        emission_counts: &mut [Vec<f32>],
    ) -> Result<(), HmmError> {
        let n = sequence.len();
        let m = self.states.len();
        if alpha.len() != n + 1 || beta.len() != n + 1 {
            return Err(HmmError::ComputationFailed(
                "Forward/backward matrices do not match the sequence length".to_string(),
            ));
        }
        let log_px = Self::log_sum(&alpha[n]);
        if !log_px.is_finite() {
            return Ok(());
        }

        // Emissions: the state occupied at DP row i (1..=n) emitted residue i - 1.
        for i in 1..=n {
            let aa_idx = get_safe_aa_index(sequence[i - 1]);
            for j in 0..m {
                let (a, b) = (alpha[i][j], beta[i][j]);
                if !(a.is_finite() && b.is_finite()) {
                    continue;
                }
                if let Some(count) = emission_counts.get_mut(j).and_then(|row| row.get_mut(aa_idx)) {
                    *count += (a + b - log_px).exp();
                }
            }
        }

        // Transitions: moving from row i to row i + 1 emits residue i in the target state.
        for i in 0..n {
            let aa_idx = get_safe_aa_index(sequence[i]);
            for j in 0..m {
                let a = alpha[i][j];
                if !a.is_finite() {
                    continue;
                }
                for next_j in (j + 1)..m {
                    let b = beta[i + 1][next_j];
                    if !b.is_finite() {
                        continue;
                    }
                    if let Some((slot, step)) = self.step_score(j, next_j, aa_idx) {
                        let xi = (a + step + b - log_px).exp();
                        if let Some(count) =
                            transition_counts.get_mut(j).and_then(|row| row.get_mut(slot))
                        {
                            *count += xi;
                        }
                    }
                }
            }
        }
        Ok(())
    }

    /// Rewrites each state's transition and emission scores from the expected
    /// counts; see `reestimate` for the smoothing rule.
    fn update_parameters(
        &mut self,
        transition_counts: &[Vec<f32>],
        emission_counts: &[Vec<f32>],
        pseudocount: f32,
    ) -> Result<(), HmmError> {
        for (j, state) in self.states.iter_mut().enumerate() {
            if let Some(counts) = transition_counts.get(j) {
                Self::reestimate(&mut state.transitions, counts, pseudocount);
            }
            if let Some(counts) = emission_counts.get(j) {
                Self::reestimate(&mut state.emissions, counts, pseudocount);
            }
        }
        Ok(())
    }

    /// Converts a bit score to an E-value with fixed statistical constants.
    ///
    /// Algebraically the result is `(k / e^ln_k) · db_size · 2^(−bit_score)`.
    /// Because `ln_k` ≈ `ln(k)`, that is ≈ `db_size · 2^(−bit_score)`, with
    /// `db_size` fixed at 6e9. Never returns `Err` today.
    pub fn score_to_evalue(&self, bit_score: f32) -> Result<f64, HmmError> {
        let lambda = 0.3176;
        let k = 0.134;
        let ln_k = -2.004;
        let db_size = 6e9;
        let raw_score = (bit_score as f64 * std::f64::consts::LN_2 + ln_k) / lambda;
        let evalue = k * db_size * (-lambda * raw_score).exp();
        Ok(evalue)
    }

    /// Transition slot and combined log score for moving from state `from` to
    /// state `to` and emitting residue index `aa_idx` in `to`.
    ///
    /// Returns `None` when `get_transition_index` forbids the move. A missing
    /// transition slot or residue index scores `NEG_INFINITY`.
    fn step_score(&self, from: usize, to: usize, aa_idx: usize) -> Option<(usize, f32)> {
        let source = &self.states[from];
        let target = &self.states[to];
        let slot = get_transition_index(source.state_type, target.state_type)?;
        let trans = source.transitions.get(slot).copied().unwrap_or(f32::NEG_INFINITY);
        let emission = target.emissions.get(aa_idx).copied().unwrap_or(f32::NEG_INFINITY);
        Some((slot, trans + emission))
    }

    /// `ln(e^a + e^b)` computed without overflow; `NEG_INFINITY` is the identity.
    fn log_add(a: f32, b: f32) -> f32 {
        if a == f32::NEG_INFINITY {
            return b;
        }
        if b == f32::NEG_INFINITY {
            return a;
        }
        let (hi, lo) = if a >= b { (a, b) } else { (b, a) };
        hi + (lo - hi).exp().ln_1p()
    }

    /// Log-sum-exp of a row of log scores (`NEG_INFINITY` for an empty row).
    fn log_sum(row: &[f32]) -> f32 {
        row.iter().fold(f32::NEG_INFINITY, |acc, &x| Self::log_add(acc, x))
    }

    /// Run-length encodes the Match/Insert/Delete states of a path as CIGAR
    /// operations `M`/`I`/`D`; `Begin` and `End` are skipped.
    fn cigar_from_path(path: &[StateType]) -> String {
        let mut cigar = String::new();
        let mut run: Option<(char, usize)> = None;
        let ops = path.iter().filter_map(|state| match state {
            StateType::Match => Some('M'),
            StateType::Insert => Some('I'),
            StateType::Delete => Some('D'),
            StateType::Begin | StateType::End => None,
        });
        for op in ops {
            run = match run {
                Some((prev, len)) if prev == op => Some((prev, len + 1)),
                Some((prev, len)) => {
                    cigar.push_str(&len.to_string());
                    cigar.push(prev);
                    Some((op, 1))
                }
                None => Some((op, 1)),
            };
        }
        if let Some((prev, len)) = run {
            cigar.push_str(&len.to_string());
            cigar.push(prev);
        }
        cigar
    }

    /// Replaces `params` with pseudocount-smoothed log frequencies of `counts`.
    ///
    /// Only the first `min(params.len(), counts.len())` entries are used. When
    /// those counts sum to zero (no evidence for this state), `params` is left
    /// untouched. This keeps e.g. `Delete` emissions at `NEG_INFINITY` instead
    /// of resetting them to a uniform distribution.
    fn reestimate(params: &mut [f32], counts: &[f32], pseudocount: f32) {
        let k = params.len().min(counts.len());
        let observed: f32 = counts[..k].iter().sum();
        if k == 0 || !observed.is_finite() || observed <= 0.0 {
            return;
        }
        let total = observed + pseudocount * k as f32;
        for (param, &count) in params[..k].iter_mut().zip(&counts[..k]) {
            *param = ((count + pseudocount) / total).max(1e-10).ln();
        }
    }
}
/// Max-product (Viterbi-style) backward scores for `sequence` under `hmm`.
///
/// `dp[i][j]` is the best log score for emitting `sequence[i..]` given the
/// path is in state `j` after consuming `i` residues. Each step moves to a
/// strictly higher-indexed state allowed by `get_transition_index` and adds
/// that state's emission score for the next residue.
///
/// Every state in the final row is seeded with 0.0. `forward_backward`
/// terminates by taking the best score over all states of its final forward
/// row, and `get_transition_index` never permits a move into `End`. An
/// `End`-only seed would therefore make every entry `NEG_INFINITY`.
///
/// # Errors
/// `HmmError::ComputationFailed` when the model has no states or the sequence is empty.
pub fn backward_algorithm(hmm: &ProfileHmm, sequence: &[u8]) -> Result<Vec<Vec<f32>>, HmmError> {
    if hmm.states.is_empty() || sequence.is_empty() {
        return Err(HmmError::ComputationFailed("Invalid input".to_string()));
    }
    let n = sequence.len();
    let m = hmm.states.len();
    let mut dp = vec![vec![f32::NEG_INFINITY; m]; n + 1];
    dp[n] = vec![0.0; m];
    for i in (0..n).rev() {
        let aa_idx = get_safe_aa_index(sequence[i]);
        for j in 0..m {
            let from = &hmm.states[j];
            let mut best = f32::NEG_INFINITY;
            for next_j in (j + 1)..m {
                let to = &hmm.states[next_j];
                if let Some(trans_idx) = get_transition_index(from.state_type, to.state_type) {
                    let trans = from.transitions.get(trans_idx).copied().unwrap_or(f32::NEG_INFINITY);
                    let emission = to.emissions.get(aa_idx).copied().unwrap_or(f32::NEG_INFINITY);
                    best = best.max(dp[i + 1][next_j] + trans + emission);
                }
            }
            dp[i][j] = best;
        }
    }
    Ok(dp)
}
/// Per-residue, per-state max-marginal scores for `sequence` under `hmm`.
///
/// Returns an `n × m` matrix (`n` residues, `m` states). Entry `[r][j]` is
/// `exp(best path through state j at residue r − best overall path)`. Both
/// forward and backward passes use max-product recurrences, so each entry
/// lies in `[0, 1]`, but it is a max-marginal ratio rather than a sum-product
/// posterior. Residue `r` lives in DP row `r + 1`, because row 0 is the state
/// before any residue is emitted. Entries with no path are `0.0`.
///
/// # Errors
/// `HmmError::ComputationFailed` when the model has no states or the sequence
/// is empty; propagates errors from [`backward_algorithm`].
pub fn forward_backward(hmm: &ProfileHmm, sequence: &[u8]) -> Result<Vec<Vec<f32>>, HmmError> {
    if hmm.states.is_empty() || sequence.is_empty() {
        return Err(HmmError::ComputationFailed("Invalid input".to_string()));
    }
    let n = sequence.len();
    let m = hmm.states.len();
    let mut forward = vec![vec![f32::NEG_INFINITY; m]; n + 1];
    forward[0][0] = 0.0;
    for i in 1..=n {
        let aa_idx = get_safe_aa_index(sequence[i - 1]);
        for j in 0..m {
            for prev_j in 0..j {
                if let Some(trans_idx) = get_transition_index(hmm.states[prev_j].state_type, hmm.states[j].state_type) {
                    let trans = hmm.states[prev_j].transitions.get(trans_idx).copied().unwrap_or(f32::NEG_INFINITY);
                    let emission = hmm.states[j].emissions.get(aa_idx).copied().unwrap_or(f32::NEG_INFINITY);
                    let score = forward[i - 1][prev_j] + trans + emission;
                    forward[i][j] = forward[i][j].max(score);
                }
            }
        }
    }
    let backward = backward_algorithm(hmm, sequence)?;
    let total_score = forward[n].iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

    let mut posteriors = vec![vec![0.0; m]; n];
    if total_score.is_finite() {
        for (residue, row) in posteriors.iter_mut().enumerate() {
            // Residue `residue` was emitted by the state occupied at DP row `residue + 1`.
            let dp_row = residue + 1;
            for (j, cell) in row.iter_mut().enumerate() {
                let (f, b) = (forward[dp_row][j], backward[dp_row][j]);
                if f.is_finite() && b.is_finite() {
                    *cell = (f + b - total_score).exp();
                }
            }
        }
    }
    Ok(posteriors)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Query sequence scored against the toy model in most tests.
    const QUERY: &[u8] = b"MVLSPAD";

    /// Three-sequence, seven-column toy alignment shared by the tests below.
    fn create_alignment() -> Vec<Vec<char>> {
        ["MVLSPAD", "MVLSPAD", "MPLSPAD"]
            .iter()
            .map(|row| row.chars().collect())
            .collect()
    }

    /// Builds the model for `create_alignment`, failing the test on error.
    fn alignment_hmm() -> ProfileHmm {
        ProfileHmm::from_msa(&create_alignment()).expect("from_msa should accept the toy alignment")
    }

    #[test]
    fn test_hmm_from_msa() {
        let hmm = alignment_hmm();
        assert!(!hmm.states.is_empty());
        assert_eq!(hmm.states[0].state_type, StateType::Begin);
        assert_eq!(hmm.length, 7);
    }

    #[test]
    fn test_viterbi_algorithm() {
        let path = alignment_hmm().viterbi(QUERY).expect("viterbi should succeed");
        assert!(!path.states.is_empty());
        assert!(path.score.is_finite());
        assert!(!path.cigar.is_empty());
    }

    #[test]
    fn test_forward_algorithm() {
        let score = alignment_hmm().forward_score(QUERY).expect("forward_score should succeed");
        assert!(!score.is_nan());
    }

    #[test]
    fn test_domain_detection() {
        let domains = alignment_hmm().find_domains(QUERY).expect("find_domains should succeed");
        for domain in &domains {
            assert!(!domain.name.is_empty());
            assert!(domain.start < domain.end);
            assert!(domain.evalue >= 0.0);
        }
    }

    #[test]
    fn test_baum_welch_training() {
        let mut hmm = alignment_hmm();
        let sequences = [QUERY, QUERY];
        hmm.train(&sequences, 3).expect("training should succeed");
        assert_eq!(hmm.length, 7);
    }

    #[test]
    fn test_pfam_loading() {
        let hmm = ProfileHmm::from_pfam("kinase").expect("from_pfam should succeed");
        assert_eq!(hmm.name, "kinase");
        assert!(!hmm.states.is_empty());
    }

    #[test]
    fn test_backward_algorithm() {
        let backward = backward_algorithm(&alignment_hmm(), QUERY).expect("backward should succeed");
        assert!(!backward.is_empty());
    }

    #[test]
    fn test_forward_backward() {
        let posteriors = forward_backward(&alignment_hmm(), QUERY).expect("forward_backward should succeed");
        assert!(!posteriors.is_empty());
        assert!(posteriors
            .iter()
            .flatten()
            .all(|&p| (0.0..=1.0 + 0.01).contains(&p)));
    }

    #[test]
    fn test_evalue_computation() {
        let evalue = alignment_hmm().score_to_evalue(10.0).expect("score_to_evalue should succeed");
        assert!(evalue > 0.0);
    }
}