use crate::error::{SeqError, SeqResult};
/// Hidden Markov model with a finite (discrete) observation alphabet.
///
/// All parameters are stored as probabilities (not log-probabilities) in flat,
/// row-major `Vec<f64>`s. [`HmmDiscrete::new`] checks the shapes and that `pi` and
/// every row of `a` and `b` is a probability distribution. The fields are `pub`, so
/// code that mutates them directly bypasses that validation.
#[derive(Debug, Clone)]
pub struct HmmDiscrete {
/// Number of hidden states.
pub n_states: usize,
/// Size of the observation alphabet; observations are indices in `0..n_obs`.
pub n_obs: usize,
/// Initial state distribution, length `n_states`.
pub pi: Vec<f64>,
/// Transition matrix, `n_states x n_states`, row-major: entry `(i, j)` lives at
/// `a[i * n_states + j]` and row `i` is a distribution over the next state.
pub a: Vec<f64>,
/// Emission matrix, `n_states x n_obs`, row-major: entry `(s, o)` lives at
/// `b[s * n_obs + o]` and row `s` is a distribution over observation symbols.
pub b: Vec<f64>,
}
impl HmmDiscrete {
    /// Builds a discrete HMM after validating shapes and probabilities.
    ///
    /// * `pi` — initial state distribution, length `n_states`.
    /// * `a` — row-major `n_states x n_states` transition matrix.
    /// * `b` — row-major `n_states x n_obs` emission matrix.
    ///
    /// # Errors
    ///
    /// * [`SeqError::InvalidConfiguration`] if `n_states` or `n_obs` is zero, or if `pi`
    ///   or any row of `a`/`b` does not sum to 1.
    /// * [`SeqError::ShapeMismatch`] if a vector has the wrong length.
    /// * [`SeqError::ProbabilityOutOfRange`] if an entry is not a valid probability.
    pub fn new(
        n_states: usize,
        n_obs: usize,
        pi: Vec<f64>,
        a: Vec<f64>,
        b: Vec<f64>,
    ) -> SeqResult<Self> {
        if n_states == 0 || n_obs == 0 {
            return Err(SeqError::InvalidConfiguration(
                "n_states and n_obs must be > 0".to_string(),
            ));
        }

        let expect_len = |expected: usize, got: usize| -> SeqResult<()> {
            if expected == got {
                Ok(())
            } else {
                Err(SeqError::ShapeMismatch { expected, got })
            }
        };
        expect_len(n_states, pi.len())?;
        expect_len(n_states * n_states, a.len())?;
        expect_len(n_states * n_obs, b.len())?;

        // `pi` is a single distribution, so it is reported as entry 0 in error messages.
        validate_distribution(&pi, "pi", 0)?;
        // Both sizes are non-zero and the lengths were checked above, so each iterator
        // yields exactly `n_states` full rows.
        let a_rows = a.chunks_exact(n_states);
        let b_rows = b.chunks_exact(n_obs);
        for (i, (a_row, b_row)) in a_rows.zip(b_rows).enumerate() {
            validate_distribution(a_row, "A row", i)?;
            validate_distribution(b_row, "B row", i)?;
        }

        Ok(Self {
            n_states,
            n_obs,
            pi,
            a,
            b,
        })
    }

    /// Log-probability of emitting symbol `obs` from state `state`.
    ///
    /// A zero (or otherwise non-positive/non-finite) stored probability yields
    /// `f64::NEG_INFINITY` via [`log_safe`].
    ///
    /// # Errors
    ///
    /// * [`SeqError::IndexOutOfBounds`] if `state >= n_states`.
    /// * [`SeqError::InvalidObservation`] if `obs >= n_obs`.
    pub fn log_emission(&self, state: usize, obs: usize) -> SeqResult<f64> {
        if state >= self.n_states {
            return Err(SeqError::IndexOutOfBounds {
                index: state,
                len: self.n_states,
            });
        }
        if obs >= self.n_obs {
            return Err(SeqError::InvalidObservation(format!(
                "obs index {obs} >= n_obs {}",
                self.n_obs
            )));
        }
        Ok(log_safe(self.b[state * self.n_obs + obs]))
    }

    /// Log of the transition entry `(i, j)` of `a`.
    ///
    /// # Errors
    ///
    /// [`SeqError::IndexOutOfBounds`] if either index is `>= n_states`; the reported
    /// index is the larger of the two, which is always out of range.
    pub fn log_trans(&self, i: usize, j: usize) -> SeqResult<f64> {
        if i >= self.n_states || j >= self.n_states {
            return Err(SeqError::IndexOutOfBounds {
                index: i.max(j),
                len: self.n_states,
            });
        }
        Ok(log_safe(self.a[i * self.n_states + j]))
    }

    /// Log of the initial-state probability `pi[i]`.
    ///
    /// # Errors
    ///
    /// [`SeqError::IndexOutOfBounds`] if `i >= n_states`.
    pub fn log_init(&self, i: usize) -> SeqResult<f64> {
        if i >= self.n_states {
            return Err(SeqError::IndexOutOfBounds {
                index: i,
                len: self.n_states,
            });
        }
        Ok(log_safe(self.pi[i]))
    }
}
/// Hidden Markov model whose observations are real vectors of length `dim`.
///
/// Each state models the dimensions as independent univariate normals (i.e. a
/// diagonal-covariance Gaussian). Parameters are flat, row-major `Vec<f64>`s.
/// [`HmmGaussian::new`] checks shapes, that `pi` and each row of `a` are probability
/// distributions, and that every variance is finite and strictly positive. The fields
/// are `pub`, so direct mutation bypasses that validation.
#[derive(Debug, Clone)]
pub struct HmmGaussian {
/// Number of hidden states.
pub n_states: usize,
/// Length of each observation vector.
pub dim: usize,
/// Initial state distribution, length `n_states`.
pub pi: Vec<f64>,
/// Transition matrix, `n_states x n_states`, row-major (`a[i * n_states + j]`).
pub a: Vec<f64>,
/// Per-state means, `n_states x dim`, row-major (`means[state * dim + d]`).
pub means: Vec<f64>,
/// Per-state, per-dimension variances, `n_states x dim`, row-major
/// (`vars[state * dim + d]`).
pub vars: Vec<f64>,
}
impl HmmGaussian {
    /// Builds a Gaussian-emission HMM after validating shapes and parameters.
    ///
    /// * `pi` — initial state distribution, length `n_states`.
    /// * `a` — row-major `n_states x n_states` transition matrix; every row must be a
    ///   probability distribution.
    /// * `means` — row-major `n_states x dim`; every entry must be finite.
    /// * `vars` — row-major `n_states x dim` per-dimension variances; every entry must be
    ///   finite and strictly positive.
    ///
    /// # Errors
    ///
    /// * [`SeqError::InvalidConfiguration`] if `n_states` or `dim` is zero, or if `pi` or
    ///   a row of `a` does not sum to 1.
    /// * [`SeqError::ShapeMismatch`] if a vector has the wrong length.
    /// * [`SeqError::InvalidParameter`] for a non-positive or non-finite variance
    ///   (`name == "variance"`) or a non-finite mean (`name == "mean"`).
    /// * [`SeqError::ProbabilityOutOfRange`] if an entry of `pi` or `a` is not a valid
    ///   probability.
    pub fn new(
        n_states: usize,
        dim: usize,
        pi: Vec<f64>,
        a: Vec<f64>,
        means: Vec<f64>,
        vars: Vec<f64>,
    ) -> SeqResult<Self> {
        if n_states == 0 || dim == 0 {
            return Err(SeqError::InvalidConfiguration(
                "n_states and dim must be > 0".to_string(),
            ));
        }

        let expect_len = |expected: usize, got: usize| -> SeqResult<()> {
            if expected == got {
                Ok(())
            } else {
                Err(SeqError::ShapeMismatch { expected, got })
            }
        };
        expect_len(n_states, pi.len())?;
        expect_len(n_states * n_states, a.len())?;
        expect_len(n_states * dim, means.len())?;
        expect_len(n_states * dim, vars.len())?;

        // `v <= 0.0` is false for NaN, so the `is_finite` test is what rejects NaN here.
        if let Some(&v) = vars.iter().find(|&&v| v <= 0.0 || !v.is_finite()) {
            return Err(SeqError::InvalidParameter {
                name: "variance".to_string(),
                value: v,
            });
        }
        // A NaN mean would silently turn every emission log-likelihood of its state into
        // NaN (and an infinite one pins it to -inf), so reject them up front.
        if let Some(&m) = means.iter().find(|m| !m.is_finite()) {
            return Err(SeqError::InvalidParameter {
                name: "mean".to_string(),
                value: m,
            });
        }

        // `pi` is a single distribution, so it is reported as entry 0 in error messages.
        validate_distribution(&pi, "pi", 0)?;
        for (i, row) in a.chunks_exact(n_states).enumerate() {
            validate_distribution(row, "A row", i)?;
        }

        Ok(Self {
            n_states,
            dim,
            pi,
            a,
            means,
            vars,
        })
    }

    /// Log-density of observation `x` under state `state`.
    ///
    /// Computed as the sum over dimensions `d` of the univariate normal log-density
    /// `-0.5 * (ln(2*pi) + ln(var_d) + (x_d - mu_d)^2 / var_d)`.
    ///
    /// # Errors
    ///
    /// * [`SeqError::IndexOutOfBounds`] if `state >= n_states`.
    /// * [`SeqError::ShapeMismatch`] if `x.len() != dim`.
    pub fn log_emission(&self, state: usize, x: &[f64]) -> SeqResult<f64> {
        if state >= self.n_states {
            return Err(SeqError::IndexOutOfBounds {
                index: state,
                len: self.n_states,
            });
        }
        if x.len() != self.dim {
            return Err(SeqError::ShapeMismatch {
                expected: self.dim,
                got: x.len(),
            });
        }
        let row = state * self.dim..(state + 1) * self.dim;
        let log_2pi = (2.0 * std::f64::consts::PI).ln();
        let ll: f64 = x
            .iter()
            .zip(&self.means[row.clone()])
            .zip(&self.vars[row])
            .map(|((&xi, &mu), &var)| {
                let diff = xi - mu;
                -0.5 * (log_2pi + var.ln() + diff * diff / var)
            })
            .sum();
        Ok(ll)
    }

    /// Log of the transition entry `(i, j)` of `a`.
    ///
    /// # Errors
    ///
    /// [`SeqError::IndexOutOfBounds`] if either index is `>= n_states`; the reported
    /// index is the larger of the two, which is always out of range.
    pub fn log_trans(&self, i: usize, j: usize) -> SeqResult<f64> {
        if i >= self.n_states || j >= self.n_states {
            return Err(SeqError::IndexOutOfBounds {
                index: i.max(j),
                len: self.n_states,
            });
        }
        Ok(log_safe(self.a[i * self.n_states + j]))
    }

    /// Log of the initial-state probability `pi[i]`.
    ///
    /// # Errors
    ///
    /// [`SeqError::IndexOutOfBounds`] if `i >= n_states`.
    pub fn log_init(&self, i: usize) -> SeqResult<f64> {
        if i >= self.n_states {
            return Err(SeqError::IndexOutOfBounds {
                index: i,
                len: self.n_states,
            });
        }
        Ok(log_safe(self.pi[i]))
    }
}
/// Natural logarithm that never yields NaN.
///
/// Returns `x.ln()` for finite, strictly positive `x`. Every other input — zero,
/// negatives, NaN, and also `+inf` — maps to `f64::NEG_INFINITY`, which is the log of
/// a zero probability.
#[inline]
pub fn log_safe(x: f64) -> f64 {
    let usable = x.is_finite() && x > 0.0;
    match usable {
        true => x.ln(),
        false => f64::NEG_INFINITY,
    }
}
/// Checks that `p` is a probability distribution.
///
/// Each entry must be finite and lie in `[0, 1 + 1e-9]` (a small slack for rounding),
/// and the entries must sum to 1 within an absolute tolerance of `1e-5`. `label` and
/// `idx` are used only to build the error message.
///
/// # Errors
///
/// * [`SeqError::ProbabilityOutOfRange`] carrying the first offending entry.
/// * [`SeqError::InvalidConfiguration`] if the entries do not sum to 1.
fn validate_distribution(p: &[f64], label: &str, idx: usize) -> SeqResult<()> {
    const ENTRY_SLACK: f64 = 1e-9;
    const SUM_TOLERANCE: f64 = 1e-5;

    // Accumulate left to right, bailing out on the first entry outside the valid range.
    let s = p.iter().try_fold(0.0, |acc, &prob| {
        let acceptable = prob.is_finite() && (0.0..=1.0 + ENTRY_SLACK).contains(&prob);
        if acceptable {
            Ok(acc + prob)
        } else {
            Err(SeqError::ProbabilityOutOfRange(prob))
        }
    })?;

    let deviation = (s - 1.0).abs();
    if deviation > SUM_TOLERANCE {
        return Err(SeqError::InvalidConfiguration(format!(
            "{label}[{idx}] sums to {s}, expected 1"
        )));
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A valid 2-state, 3-symbol model builds, and the emission lookup returns the log
    /// of the stored probability `b[0 * 3 + 2] == 0.5`.
    #[test]
    fn discrete_hmm_basic() {
        let pi = vec![0.6, 0.4];
        let a = vec![0.7, 0.3, 0.4, 0.6];
        let b = vec![0.1, 0.4, 0.5, 0.6, 0.3, 0.1];
        let hmm = HmmDiscrete::new(2, 3, pi, a, b).expect("ok");
        assert_eq!(hmm.n_states, 2);

        let got = hmm.log_emission(0, 2).expect("ok");
        let want = 0.5_f64.ln();
        assert!((got - want).abs() < 1e-12);
    }

    /// Zero, negative and NaN inputs all collapse to an infinite result, while a
    /// positive input matches the ordinary logarithm.
    #[test]
    fn log_safe_negative_is_neg_inf() {
        for bad in [0.0, -1.0, f64::NAN] {
            assert!(log_safe(bad).is_infinite());
        }
        let two = 2.0_f64;
        assert!((log_safe(two) - two.ln()).abs() < 1e-12);
    }

    /// A 1-D standard normal evaluated at its mean has log-density `-0.5 * ln(2*pi)`.
    #[test]
    fn gaussian_hmm_emission() {
        let pi = vec![0.5, 0.5];
        let a = vec![0.9, 0.1, 0.1, 0.9];
        let means = vec![0.0, 5.0];
        let vars = vec![1.0, 1.0];
        let hmm = HmmGaussian::new(2, 1, pi, a, means, vars).expect("ok");

        let got = hmm.log_emission(0, &[0.0]).expect("ok");
        let want = -0.5 * (2.0 * std::f64::consts::PI).ln();
        assert!((got - want).abs() < 1e-12);
    }
}