use nalgebra::{DMatrix, DVector};
use crate::error::AsrError;
use crate::tree::asr::alphabet::Alphabet;
/// General time-reversible (GTR) substitution model over the states of alphabet `A`.
///
/// The rate matrix `Q` is not stored directly. The model instead keeps the
/// eigen-decomposition of the symmetrised matrix `S = D^{1/2} Q D^{-1/2}`, where
/// `D = diag(pi)`. This lets [`GtrModel::transition`] evaluate `P(t) = exp(Q t)` as
/// `D^{-1/2} U exp(Λ t) Uᵀ D^{1/2}` from a decomposition computed once, up front.
#[derive(Debug, Clone)]
pub struct GtrModel<A: Alphabet> {
    /// Equilibrium state frequencies, rescaled to sum to one.
    pi: DVector<f64>,
    /// Eigenvalues `Λ` of `S`, in the same order as the columns of `eigenvectors`.
    eigenvalues: DVector<f64>,
    /// Orthonormal eigenvectors `U` of `S`, stored column-wise.
    eigenvectors: DMatrix<f64>,
    /// Element-wise `sqrt(pi)`, i.e. the diagonal of `D^{1/2}`.
    sqrt_pi: DVector<f64>,
    /// Element-wise `1 / sqrt(pi)`, i.e. the diagonal of `D^{-1/2}`.
    inv_sqrt_pi: DVector<f64>,
    /// Ties the model to its alphabet at the type level; no `A` value is stored.
    _phantom: std::marker::PhantomData<A>,
}
impl<A: Alphabet> GtrModel<A> {
    /// Relative tolerance for accepting `w[(i, j)]` and `w[(j, i)]` as equal.
    const SYMMETRY_RTOL: f64 = 1e-10;

    /// Builds a GTR model from equilibrium frequencies `pi` and exchangeabilities `w`.
    ///
    /// Off-diagonal rates are `Q[i][j] = pi[j] * w[i][j]`. Diagonal rates are chosen so
    /// each row of `Q` sums to zero. `pi` is rescaled to sum to one, and the diagonal of
    /// `w` is ignored. When `normalize` is true, `Q` is divided by
    /// `mu = -Σ pi[i] Q[i][i]`, the mean substitution rate at equilibrium. Branch lengths
    /// are then measured in expected substitutions per site.
    ///
    /// # Errors
    ///
    /// - [`AsrError::AlphabetMismatch`] if `pi` does not have `A::N_STATES` entries or `w`
    ///   is not `A::N_STATES × A::N_STATES`.
    /// - [`AsrError::NumericalInstability`] in any of these cases:
    ///   - a `pi` entry is not finite and strictly positive;
    ///   - an off-diagonal entry of `w` is negative or not finite;
    ///   - `w` is not symmetric (within a relative tolerance);
    ///   - `normalize` is requested but `mu` is not a finite positive number.
    pub fn new(pi: Vec<f64>, w: DMatrix<f64>, normalize: bool) -> Result<Self, AsrError> {
        let n = A::N_STATES;
        if pi.len() != n {
            return Err(AsrError::AlphabetMismatch(
                "pi length does not match alphabet states".to_string(),
            ));
        }
        if w.nrows() != n || w.ncols() != n {
            return Err(AsrError::AlphabetMismatch(
                "W matrix dimensions do not match alphabet states".to_string(),
            ));
        }
        // `inv_sqrt_pi` below divides by sqrt(pi). A zero, negative or non-finite
        // frequency would fill every transition matrix with inf/NaN.
        if pi.iter().any(|&p| !p.is_finite() || p <= 0.0) {
            return Err(AsrError::NumericalInstability);
        }
        Self::validate_exchangeabilities(&w)?;

        let pi_vec = DVector::from_vec(pi);
        let pi_sum: f64 = pi_vec.sum();
        let pi_norm = pi_vec / pi_sum;

        // Off-diagonal rates Q[i][j] = pi[j] * W[i][j]. The diagonal of W is never read.
        let mut q = DMatrix::from_fn(n, n, |i, j| {
            if i == j {
                0.0
            } else {
                pi_norm[j] * w[(i, j)]
            }
        });
        // Each diagonal entry is still 0 here, so the row sum covers only off-diagonal rates.
        for i in 0..n {
            let row_sum: f64 = q.row(i).iter().sum();
            q[(i, i)] = -row_sum;
        }

        if normalize {
            let mu: f64 = (0..n).map(|i| -pi_norm[i] * q[(i, i)]).sum();
            // Also rejects NaN, which a plain `mu <= 0.0` test would let through.
            if !mu.is_finite() || mu <= 0.0 {
                return Err(AsrError::NumericalInstability);
            }
            q /= mu;
        }

        let sqrt_pi = pi_norm.map(f64::sqrt);
        let inv_sqrt_pi = sqrt_pi.map(|x| 1.0 / x);
        // S = D^{1/2} Q D^{-1/2} has entries sqrt(pi_i * pi_j) * W[i][j] off the diagonal.
        // It is symmetric because W was validated to be.
        let s = DMatrix::from_fn(n, n, |i, j| sqrt_pi[i] * q[(i, j)] * inv_sqrt_pi[j]);
        let eigen = s.symmetric_eigen();
        Ok(Self {
            pi: pi_norm,
            eigenvalues: eigen.eigenvalues,
            eigenvectors: eigen.eigenvectors,
            sqrt_pi,
            inv_sqrt_pi,
            _phantom: std::marker::PhantomData,
        })
    }

    /// Checks that every off-diagonal exchangeability is finite, non-negative and symmetric.
    ///
    /// Symmetry matters because `symmetric_eigen` reads only one triangle of its input.
    /// An asymmetric `w` would otherwise silently yield a different model from the one
    /// requested.
    fn validate_exchangeabilities(w: &DMatrix<f64>) -> Result<(), AsrError> {
        let n = w.nrows();
        for i in 0..n {
            for j in (i + 1)..n {
                let (upper, lower) = (w[(i, j)], w[(j, i)]);
                if !upper.is_finite() || !lower.is_finite() || upper < 0.0 || lower < 0.0 {
                    return Err(AsrError::NumericalInstability);
                }
                if (upper - lower).abs() > Self::SYMMETRY_RTOL * upper.max(lower) {
                    return Err(AsrError::NumericalInstability);
                }
            }
        }
        Ok(())
    }

    /// Returns the transition probability matrix `P(t) = exp(Q t)`.
    ///
    /// Entry `(i, j)` is the probability of being in state `j` after time `t`, given
    /// state `i` at time 0. `t == 0.0` returns the identity exactly. `t` is expected to be a
    /// non-negative branch length. Negative values are not rejected, and the result is then
    /// generally not a stochastic matrix. Entries may deviate from their exact values by
    /// floating-point rounding (e.g. tiny negative values for very small probabilities).
    pub fn transition(&self, t: f64) -> DMatrix<f64> {
        let n = A::N_STATES;
        if t == 0.0 {
            return DMatrix::identity(n, n);
        }
        // U * diag(exp(λ t)): scale each eigenvector column in place rather than
        // multiplying by a dense diagonal matrix (saves one O(n³) product).
        let mut u_scaled = self.eigenvectors.clone();
        for (mut column, &lambda) in u_scaled.column_iter_mut().zip(self.eigenvalues.iter()) {
            column *= (lambda * t).exp();
        }
        let s_t = u_scaled * self.eigenvectors.transpose();
        // Undo the symmetrisation: P(t) = D^{-1/2} exp(S t) D^{1/2}.
        DMatrix::from_fn(n, n, |i, j| {
            self.inv_sqrt_pi[i] * s_t[(i, j)] * self.sqrt_pi[j]
        })
    }

    /// Returns the equilibrium state frequencies, normalised to sum to one.
    pub fn equilibrium(&self) -> &DVector<f64> {
        &self.pi
    }

    /// Builds the Jukes–Cantor model for alphabet `A`.
    ///
    /// Frequencies are uniform and all exchangeabilities are equal. The model is normalised
    /// to one expected substitution per unit time.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`GtrModel::new`].
    pub fn jukes_cantor() -> Result<Self, AsrError> {
        let n = A::N_STATES;
        let pi = vec![1.0 / (n as f64); n];
        let w = DMatrix::from_element(n, n, 1.0);
        Self::new(pi, w, true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tree::asr::alphabet::Nucleotide;

    /// Absolute tolerance for comparing floating-point probabilities.
    const TOL: f64 = 1e-10;

    /// Number of states of the alphabet under test.
    fn n_states() -> usize {
        <Nucleotide as Alphabet>::N_STATES
    }

    /// Closed-form Jukes–Cantor probabilities `(P_ii(t), P_ij(t))`, with `i != j`.
    ///
    /// Applies to an `n`-state alphabet whose rate matrix is normalised to one expected
    /// substitution per unit time.
    fn jc_closed_form(n: usize, t: f64) -> (f64, f64) {
        let n = n as f64;
        let decay = (-n * t / (n - 1.0)).exp();
        (1.0 / n + (n - 1.0) / n * decay, (1.0 - decay) / n)
    }

    /// A normalised GTR model with distinct frequencies and non-uniform symmetric rates.
    fn general_model() -> GtrModel<Nucleotide> {
        let n = n_states();
        let pi: Vec<f64> = (1..=n).map(|k| k as f64).collect();
        let w = DMatrix::from_fn(n, n, |i, j| 1.0 + (i + j) as f64);
        GtrModel::new(pi, w, true).unwrap()
    }

    #[test]
    fn test_jc_transition() {
        let model = GtrModel::<Nucleotide>::jukes_cantor().unwrap();
        let n = n_states();
        for &t in &[0.05, 1.0, 3.0] {
            let p_t = model.transition(t);
            let (diag, off_diag) = jc_closed_form(n, t);
            for i in 0..n {
                for j in 0..n {
                    let expected = if i == j { diag } else { off_diag };
                    assert!(
                        (p_t[(i, j)] - expected).abs() < TOL,
                        "P({t})[{i}][{j}] = {}, expected {expected}",
                        p_t[(i, j)]
                    );
                }
                let sum: f64 = p_t.row(i).iter().sum();
                assert!((sum - 1.0).abs() < TOL);
            }
        }
    }

    #[test]
    fn test_zero_time_is_identity() {
        let n = n_states();
        assert_eq!(general_model().transition(0.0), DMatrix::identity(n, n));
    }

    #[test]
    fn test_equilibrium_is_normalised_and_stationary() {
        let model = general_model();
        let n = n_states();
        let pi = model.equilibrium();
        let total: f64 = (1..=n).map(|k| k as f64).sum();
        for k in 0..n {
            assert!((pi[k] - (k + 1) as f64 / total).abs() < TOL);
        }

        let p_t = model.transition(0.7);
        for j in 0..n {
            // Stationarity: (pi^T P(t))_j == pi_j.
            let inflow: f64 = (0..n).map(|i| pi[i] * p_t[(i, j)]).sum();
            assert!((inflow - pi[j]).abs() < TOL);
            // Detailed balance (time reversibility): pi_i P_ij == pi_j P_ji.
            for i in 0..n {
                assert!((pi[i] * p_t[(i, j)] - pi[j] * p_t[(j, i)]).abs() < TOL);
            }
        }
    }

    #[test]
    fn test_chapman_kolmogorov() {
        let model = general_model();
        let composed = model.transition(0.3) * model.transition(0.5);
        let direct = model.transition(0.8);
        assert!((composed - direct).amax() < TOL);
    }
}