Skip to main content

oxicuda_seq/hmm/
hmm.rs

//! Hidden Markov Model (HMM) definitions — discrete- and Gaussian-emission models — and shared helpers.
2
3use crate::error::{SeqError, SeqResult};
4
5/// Discrete-emission Hidden Markov Model.
6///
7/// * `pi[i]` — initial state probability for state `i`
8/// * `a[i*n_states + j]` — transition probability from state `i` to state `j`
9/// * `b[i*n_obs + k]` — emission probability of observation `k` from state `i`
10#[derive(Debug, Clone)]
11pub struct HmmDiscrete {
12    pub n_states: usize,
13    pub n_obs: usize,
14    pub pi: Vec<f64>,
15    pub a: Vec<f64>,
16    pub b: Vec<f64>,
17}
18
19impl HmmDiscrete {
20    /// Construct a discrete HMM, validating shapes and that each row of `pi`,
21    /// `a`, and `b` sums to approximately 1.
22    pub fn new(
23        n_states: usize,
24        n_obs: usize,
25        pi: Vec<f64>,
26        a: Vec<f64>,
27        b: Vec<f64>,
28    ) -> SeqResult<Self> {
29        if n_states == 0 || n_obs == 0 {
30            return Err(SeqError::InvalidConfiguration(
31                "n_states and n_obs must be > 0".to_string(),
32            ));
33        }
34        if pi.len() != n_states {
35            return Err(SeqError::ShapeMismatch {
36                expected: n_states,
37                got: pi.len(),
38            });
39        }
40        if a.len() != n_states * n_states {
41            return Err(SeqError::ShapeMismatch {
42                expected: n_states * n_states,
43                got: a.len(),
44            });
45        }
46        if b.len() != n_states * n_obs {
47            return Err(SeqError::ShapeMismatch {
48                expected: n_states * n_obs,
49                got: b.len(),
50            });
51        }
52        validate_distribution(&pi, "pi", 1)?;
53        for i in 0..n_states {
54            validate_distribution(&a[i * n_states..(i + 1) * n_states], "A row", i)?;
55            validate_distribution(&b[i * n_obs..(i + 1) * n_obs], "B row", i)?;
56        }
57        Ok(Self {
58            n_states,
59            n_obs,
60            pi,
61            a,
62            b,
63        })
64    }
65
66    /// Log emission `log B_j(o_t)`.
67    pub fn log_emission(&self, state: usize, obs: usize) -> SeqResult<f64> {
68        if state >= self.n_states {
69            return Err(SeqError::IndexOutOfBounds {
70                index: state,
71                len: self.n_states,
72            });
73        }
74        if obs >= self.n_obs {
75            return Err(SeqError::InvalidObservation(format!(
76                "obs index {obs} >= n_obs {}",
77                self.n_obs
78            )));
79        }
80        Ok(log_safe(self.b[state * self.n_obs + obs]))
81    }
82
83    /// Log transition `log A_ij`.
84    pub fn log_trans(&self, i: usize, j: usize) -> SeqResult<f64> {
85        if i >= self.n_states || j >= self.n_states {
86            return Err(SeqError::IndexOutOfBounds {
87                index: i.max(j),
88                len: self.n_states,
89            });
90        }
91        Ok(log_safe(self.a[i * self.n_states + j]))
92    }
93
94    /// Log initial probability `log π_i`.
95    pub fn log_init(&self, i: usize) -> SeqResult<f64> {
96        if i >= self.n_states {
97            return Err(SeqError::IndexOutOfBounds {
98                index: i,
99                len: self.n_states,
100            });
101        }
102        Ok(log_safe(self.pi[i]))
103    }
104}
105
106/// Gaussian-emission HMM with diagonal covariance per state.
107///
108/// * `means[i*dim + d]` — mean of dim `d` for state `i`
109/// * `vars[i*dim + d]` — variance (diagonal) of dim `d` for state `i`
110#[derive(Debug, Clone)]
111pub struct HmmGaussian {
112    pub n_states: usize,
113    pub dim: usize,
114    pub pi: Vec<f64>,
115    pub a: Vec<f64>,
116    pub means: Vec<f64>,
117    pub vars: Vec<f64>,
118}
119
120impl HmmGaussian {
121    /// Construct a Gaussian HMM, validating shapes and probabilistic invariants.
122    pub fn new(
123        n_states: usize,
124        dim: usize,
125        pi: Vec<f64>,
126        a: Vec<f64>,
127        means: Vec<f64>,
128        vars: Vec<f64>,
129    ) -> SeqResult<Self> {
130        if n_states == 0 || dim == 0 {
131            return Err(SeqError::InvalidConfiguration(
132                "n_states and dim must be > 0".to_string(),
133            ));
134        }
135        if pi.len() != n_states {
136            return Err(SeqError::ShapeMismatch {
137                expected: n_states,
138                got: pi.len(),
139            });
140        }
141        if a.len() != n_states * n_states {
142            return Err(SeqError::ShapeMismatch {
143                expected: n_states * n_states,
144                got: a.len(),
145            });
146        }
147        if means.len() != n_states * dim {
148            return Err(SeqError::ShapeMismatch {
149                expected: n_states * dim,
150                got: means.len(),
151            });
152        }
153        if vars.len() != n_states * dim {
154            return Err(SeqError::ShapeMismatch {
155                expected: n_states * dim,
156                got: vars.len(),
157            });
158        }
159        for &v in &vars {
160            if v <= 0.0 || !v.is_finite() {
161                return Err(SeqError::InvalidParameter {
162                    name: "variance".to_string(),
163                    value: v,
164                });
165            }
166        }
167        validate_distribution(&pi, "pi", 1)?;
168        for i in 0..n_states {
169            validate_distribution(&a[i * n_states..(i + 1) * n_states], "A row", i)?;
170        }
171        Ok(Self {
172            n_states,
173            dim,
174            pi,
175            a,
176            means,
177            vars,
178        })
179    }
180
181    /// Log Gaussian-emission density `log N(x_t | μ_j, diag(σ²_j))`.
182    pub fn log_emission(&self, state: usize, x: &[f64]) -> SeqResult<f64> {
183        if state >= self.n_states {
184            return Err(SeqError::IndexOutOfBounds {
185                index: state,
186                len: self.n_states,
187            });
188        }
189        if x.len() != self.dim {
190            return Err(SeqError::ShapeMismatch {
191                expected: self.dim,
192                got: x.len(),
193            });
194        }
195        let mut ll = 0.0;
196        let log_2pi = (2.0 * std::f64::consts::PI).ln();
197        for d in 0..self.dim {
198            let mu = self.means[state * self.dim + d];
199            let var = self.vars[state * self.dim + d];
200            let diff = x[d] - mu;
201            ll += -0.5 * (log_2pi + var.ln() + diff * diff / var);
202        }
203        Ok(ll)
204    }
205
206    /// Log transition `log A_ij`.
207    pub fn log_trans(&self, i: usize, j: usize) -> SeqResult<f64> {
208        if i >= self.n_states || j >= self.n_states {
209            return Err(SeqError::IndexOutOfBounds {
210                index: i.max(j),
211                len: self.n_states,
212            });
213        }
214        Ok(log_safe(self.a[i * self.n_states + j]))
215    }
216
217    /// Log initial probability `log π_i`.
218    pub fn log_init(&self, i: usize) -> SeqResult<f64> {
219        if i >= self.n_states {
220            return Err(SeqError::IndexOutOfBounds {
221                index: i,
222                len: self.n_states,
223            });
224        }
225        Ok(log_safe(self.pi[i]))
226    }
227}
228
/// Natural logarithm that never yields `NaN`.
///
/// Returns `ln(x)` when `x` is finite and strictly positive. Every other input
/// returns `f64::NEG_INFINITY`: zero, negative values, `NaN`, and `±∞` alike.
/// The HMM accessors in this module use it so that zero probabilities become −∞.
#[inline]
pub fn log_safe(x: f64) -> f64 {
    let usable = x > 0.0 && x.is_finite();
    if usable {
        x.ln()
    } else {
        f64::NEG_INFINITY
    }
}
238
239/// Validate that `p` is a probability vector summing to ~1 (tolerance 1e-6).
240fn validate_distribution(p: &[f64], label: &str, idx: usize) -> SeqResult<()> {
241    let mut s = 0.0;
242    for &v in p {
243        if !(0.0..=1.0 + 1e-9).contains(&v) || !v.is_finite() {
244            return Err(SeqError::ProbabilityOutOfRange(v));
245        }
246        s += v;
247    }
248    if (s - 1.0).abs() > 1e-5 {
249        return Err(SeqError::InvalidConfiguration(format!(
250            "{label}[{idx}] sums to {s}, expected 1"
251        )));
252    }
253    Ok(())
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-state, three-symbol model shared by the discrete tests.
    fn discrete_fixture() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            3,
            vec![0.6, 0.4],
            vec![0.7, 0.3, 0.4, 0.6],
            vec![0.1, 0.4, 0.5, 0.6, 0.3, 0.1],
        )
        .expect("fixture parameters are valid")
    }

    /// Two-state, one-dimensional Gaussian model with unit variances.
    fn gaussian_fixture() -> HmmGaussian {
        HmmGaussian::new(
            2,
            1,
            vec![0.5, 0.5],
            vec![0.9, 0.1, 0.1, 0.9],
            vec![0.0, 5.0],
            vec![1.0, 1.0],
        )
        .expect("fixture parameters are valid")
    }

    #[test]
    fn discrete_hmm_basic() {
        let hmm = discrete_fixture();
        assert_eq!(hmm.n_states, 2);
        let le = hmm.log_emission(0, 2).expect("ok");
        assert!((le - 0.5_f64.ln()).abs() < 1e-12);
        let lt = hmm.log_trans(0, 1).expect("ok");
        assert!((lt - 0.3_f64.ln()).abs() < 1e-12);
        let li = hmm.log_init(1).expect("ok");
        assert!((li - 0.4_f64.ln()).abs() < 1e-12);
    }

    #[test]
    fn discrete_rejects_invalid_parameters() {
        let b = vec![0.1, 0.4, 0.5, 0.6, 0.3, 0.1];

        let err = HmmDiscrete::new(0, 3, vec![], vec![], vec![]).expect_err("zero states");
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));

        let err = HmmDiscrete::new(2, 3, vec![0.6, 0.4], vec![0.7, 0.3, 0.4], b.clone())
            .expect_err("A has 3 entries instead of 4");
        assert!(matches!(
            err,
            SeqError::ShapeMismatch {
                expected: 4,
                got: 3
            }
        ));

        // Row 0 of A sums to 0.9.
        let err = HmmDiscrete::new(2, 3, vec![0.6, 0.4], vec![0.7, 0.2, 0.4, 0.6], b)
            .expect_err("unnormalised A row");
        assert!(matches!(err, SeqError::InvalidConfiguration(_)));

        // Row 0 of B contains a negative entry.
        let err = HmmDiscrete::new(
            2,
            3,
            vec![0.6, 0.4],
            vec![0.7, 0.3, 0.4, 0.6],
            vec![-0.1, 0.6, 0.5, 0.6, 0.3, 0.1],
        )
        .expect_err("negative emission probability");
        assert!(matches!(err, SeqError::ProbabilityOutOfRange(_)));
    }

    #[test]
    fn discrete_index_errors() {
        let hmm = discrete_fixture();
        assert!(matches!(
            hmm.log_emission(2, 0),
            Err(SeqError::IndexOutOfBounds { index: 2, len: 2 })
        ));
        assert!(matches!(
            hmm.log_emission(0, 3),
            Err(SeqError::InvalidObservation(_))
        ));
        assert!(matches!(
            hmm.log_trans(0, 5),
            Err(SeqError::IndexOutOfBounds { index: 5, len: 2 })
        ));
        assert!(matches!(
            hmm.log_init(2),
            Err(SeqError::IndexOutOfBounds { index: 2, len: 2 })
        ));
    }

    #[test]
    fn log_safe_negative_is_neg_inf() {
        // Compare against NEG_INFINITY exactly: `is_infinite()` would also accept +inf.
        assert_eq!(log_safe(0.0), f64::NEG_INFINITY);
        assert_eq!(log_safe(-1.0), f64::NEG_INFINITY);
        assert_eq!(log_safe(f64::NAN), f64::NEG_INFINITY);
        assert!((log_safe(2.0) - 2.0_f64.ln()).abs() < 1e-12);
    }

    #[test]
    fn gaussian_hmm_emission() {
        let hmm = gaussian_fixture();
        let expected = -0.5 * (2.0 * std::f64::consts::PI).ln();
        let le0 = hmm.log_emission(0, &[0.0]).expect("ok");
        assert!((le0 - expected).abs() < 1e-12);
        // Each state's density peaks at its own mean.
        let le1 = hmm.log_emission(1, &[5.0]).expect("ok");
        assert!((le1 - expected).abs() < 1e-12);
        // One unit from the mean with unit variance costs exactly 0.5 nats.
        let off = hmm.log_emission(0, &[1.0]).expect("ok");
        assert!((off - (expected - 0.5)).abs() < 1e-12);
    }

    #[test]
    fn gaussian_rejects_invalid_input() {
        let err = HmmGaussian::new(
            2,
            1,
            vec![0.5, 0.5],
            vec![0.9, 0.1, 0.1, 0.9],
            vec![0.0, 5.0],
            vec![1.0, 0.0],
        )
        .expect_err("zero variance");
        assert!(matches!(err, SeqError::InvalidParameter { .. }));

        let hmm = gaussian_fixture();
        assert!(matches!(
            hmm.log_emission(0, &[0.0, 1.0]),
            Err(SeqError::ShapeMismatch {
                expected: 1,
                got: 2
            })
        ));
        assert!(matches!(
            hmm.log_emission(2, &[0.0]),
            Err(SeqError::IndexOutOfBounds { index: 2, len: 2 })
        ));
    }
}