1use crate::error::{SeqError, SeqResult};
4
/// Hidden Markov model whose observations come from a finite alphabet `0..n_obs`.
///
/// Every matrix is stored flattened in row-major order:
/// - `pi[i]` is the initial probability of state `i`;
/// - `a[i * n_states + j]` is the probability of transitioning from state `i` to state `j`;
/// - `b[i * n_obs + k]` is the probability of state `i` emitting symbol `k`.
///
/// `HmmDiscrete::new` checks these lengths and that `pi` and every row of `a` and `b` is a
/// probability distribution. The fields are public, so editing them afterwards bypasses
/// that validation.
#[derive(Clone, Debug)]
pub struct HmmDiscrete {
    /// Number of hidden states.
    pub n_states: usize,
    /// Size of the observation alphabet.
    pub n_obs: usize,
    /// Initial state distribution (length `n_states`).
    pub pi: Vec<f64>,
    /// Row-major `n_states x n_states` transition matrix.
    pub a: Vec<f64>,
    /// Row-major `n_states x n_obs` emission matrix.
    pub b: Vec<f64>,
}
18
19impl HmmDiscrete {
20 pub fn new(
23 n_states: usize,
24 n_obs: usize,
25 pi: Vec<f64>,
26 a: Vec<f64>,
27 b: Vec<f64>,
28 ) -> SeqResult<Self> {
29 if n_states == 0 || n_obs == 0 {
30 return Err(SeqError::InvalidConfiguration(
31 "n_states and n_obs must be > 0".to_string(),
32 ));
33 }
34 if pi.len() != n_states {
35 return Err(SeqError::ShapeMismatch {
36 expected: n_states,
37 got: pi.len(),
38 });
39 }
40 if a.len() != n_states * n_states {
41 return Err(SeqError::ShapeMismatch {
42 expected: n_states * n_states,
43 got: a.len(),
44 });
45 }
46 if b.len() != n_states * n_obs {
47 return Err(SeqError::ShapeMismatch {
48 expected: n_states * n_obs,
49 got: b.len(),
50 });
51 }
52 validate_distribution(&pi, "pi", 1)?;
53 for i in 0..n_states {
54 validate_distribution(&a[i * n_states..(i + 1) * n_states], "A row", i)?;
55 validate_distribution(&b[i * n_obs..(i + 1) * n_obs], "B row", i)?;
56 }
57 Ok(Self {
58 n_states,
59 n_obs,
60 pi,
61 a,
62 b,
63 })
64 }
65
66 pub fn log_emission(&self, state: usize, obs: usize) -> SeqResult<f64> {
68 if state >= self.n_states {
69 return Err(SeqError::IndexOutOfBounds {
70 index: state,
71 len: self.n_states,
72 });
73 }
74 if obs >= self.n_obs {
75 return Err(SeqError::InvalidObservation(format!(
76 "obs index {obs} >= n_obs {}",
77 self.n_obs
78 )));
79 }
80 Ok(log_safe(self.b[state * self.n_obs + obs]))
81 }
82
83 pub fn log_trans(&self, i: usize, j: usize) -> SeqResult<f64> {
85 if i >= self.n_states || j >= self.n_states {
86 return Err(SeqError::IndexOutOfBounds {
87 index: i.max(j),
88 len: self.n_states,
89 });
90 }
91 Ok(log_safe(self.a[i * self.n_states + j]))
92 }
93
94 pub fn log_init(&self, i: usize) -> SeqResult<f64> {
96 if i >= self.n_states {
97 return Err(SeqError::IndexOutOfBounds {
98 index: i,
99 len: self.n_states,
100 });
101 }
102 Ok(log_safe(self.pi[i]))
103 }
104}
105
/// Hidden Markov model with continuous, diagonal-covariance Gaussian emissions.
///
/// Storage is flattened row-major:
/// - `pi[i]` is the initial probability of state `i`;
/// - `a[i * n_states + j]` is the probability of transitioning from state `i` to state `j`;
/// - `means[i * dim + d]` and `vars[i * dim + d]` are the mean and variance of dimension `d`
///   of state `i`'s emission density (dimensions are treated independently).
///
/// `HmmGaussian::new` checks the lengths, that variances are finite and positive, and that
/// `pi` and every row of `a` are probability distributions. The fields are public, so
/// mutating them afterwards bypasses that validation.
#[derive(Clone, Debug)]
pub struct HmmGaussian {
    /// Number of hidden states.
    pub n_states: usize,
    /// Dimensionality of each observation vector.
    pub dim: usize,
    /// Initial state distribution (length `n_states`).
    pub pi: Vec<f64>,
    /// Row-major `n_states x n_states` transition matrix.
    pub a: Vec<f64>,
    /// Row-major `n_states x dim` per-state means.
    pub means: Vec<f64>,
    /// Row-major `n_states x dim` per-state, per-dimension variances.
    pub vars: Vec<f64>,
}
119
120impl HmmGaussian {
121 pub fn new(
123 n_states: usize,
124 dim: usize,
125 pi: Vec<f64>,
126 a: Vec<f64>,
127 means: Vec<f64>,
128 vars: Vec<f64>,
129 ) -> SeqResult<Self> {
130 if n_states == 0 || dim == 0 {
131 return Err(SeqError::InvalidConfiguration(
132 "n_states and dim must be > 0".to_string(),
133 ));
134 }
135 if pi.len() != n_states {
136 return Err(SeqError::ShapeMismatch {
137 expected: n_states,
138 got: pi.len(),
139 });
140 }
141 if a.len() != n_states * n_states {
142 return Err(SeqError::ShapeMismatch {
143 expected: n_states * n_states,
144 got: a.len(),
145 });
146 }
147 if means.len() != n_states * dim {
148 return Err(SeqError::ShapeMismatch {
149 expected: n_states * dim,
150 got: means.len(),
151 });
152 }
153 if vars.len() != n_states * dim {
154 return Err(SeqError::ShapeMismatch {
155 expected: n_states * dim,
156 got: vars.len(),
157 });
158 }
159 for &v in &vars {
160 if v <= 0.0 || !v.is_finite() {
161 return Err(SeqError::InvalidParameter {
162 name: "variance".to_string(),
163 value: v,
164 });
165 }
166 }
167 validate_distribution(&pi, "pi", 1)?;
168 for i in 0..n_states {
169 validate_distribution(&a[i * n_states..(i + 1) * n_states], "A row", i)?;
170 }
171 Ok(Self {
172 n_states,
173 dim,
174 pi,
175 a,
176 means,
177 vars,
178 })
179 }
180
181 pub fn log_emission(&self, state: usize, x: &[f64]) -> SeqResult<f64> {
183 if state >= self.n_states {
184 return Err(SeqError::IndexOutOfBounds {
185 index: state,
186 len: self.n_states,
187 });
188 }
189 if x.len() != self.dim {
190 return Err(SeqError::ShapeMismatch {
191 expected: self.dim,
192 got: x.len(),
193 });
194 }
195 let mut ll = 0.0;
196 let log_2pi = (2.0 * std::f64::consts::PI).ln();
197 for d in 0..self.dim {
198 let mu = self.means[state * self.dim + d];
199 let var = self.vars[state * self.dim + d];
200 let diff = x[d] - mu;
201 ll += -0.5 * (log_2pi + var.ln() + diff * diff / var);
202 }
203 Ok(ll)
204 }
205
206 pub fn log_trans(&self, i: usize, j: usize) -> SeqResult<f64> {
208 if i >= self.n_states || j >= self.n_states {
209 return Err(SeqError::IndexOutOfBounds {
210 index: i.max(j),
211 len: self.n_states,
212 });
213 }
214 Ok(log_safe(self.a[i * self.n_states + j]))
215 }
216
217 pub fn log_init(&self, i: usize) -> SeqResult<f64> {
219 if i >= self.n_states {
220 return Err(SeqError::IndexOutOfBounds {
221 index: i,
222 len: self.n_states,
223 });
224 }
225 Ok(log_safe(self.pi[i]))
226 }
227}
228
/// Natural logarithm that never yields NaN.
///
/// Returns `ln(x)` for finite, strictly positive `x`. Every other input (zero, negatives,
/// NaN and both infinities) maps to `f64::NEG_INFINITY`, the log of an impossible event.
#[inline]
pub fn log_safe(x: f64) -> f64 {
    let usable = x.is_finite() && x > 0.0;
    if usable {
        x.ln()
    } else {
        f64::NEG_INFINITY
    }
}
238
239fn validate_distribution(p: &[f64], label: &str, idx: usize) -> SeqResult<()> {
241 let mut s = 0.0;
242 for &v in p {
243 if !(0.0..=1.0 + 1e-9).contains(&v) || !v.is_finite() {
244 return Err(SeqError::ProbabilityOutOfRange(v));
245 }
246 s += v;
247 }
248 if (s - 1.0).abs() > 1e-5 {
249 return Err(SeqError::InvalidConfiguration(format!(
250 "{label}[{idx}] sums to {s}, expected 1"
251 )));
252 }
253 Ok(())
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Valid two-state, three-symbol model shared by the discrete tests.
    fn toy_discrete() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            3,
            vec![0.6, 0.4],
            vec![0.7, 0.3, 0.4, 0.6],
            vec![0.1, 0.4, 0.5, 0.6, 0.3, 0.1],
        )
        .expect("toy discrete HMM is valid")
    }

    /// Valid two-state, one-dimensional Gaussian model with unit variances.
    fn toy_gaussian() -> HmmGaussian {
        HmmGaussian::new(
            2,
            1,
            vec![0.5, 0.5],
            vec![0.9, 0.1, 0.1, 0.9],
            vec![0.0, 5.0],
            vec![1.0, 1.0],
        )
        .expect("toy Gaussian HMM is valid")
    }

    #[test]
    fn discrete_hmm_basic() {
        let hmm = toy_discrete();
        assert_eq!(hmm.n_states, 2);
        let le = hmm.log_emission(0, 2).expect("ok");
        assert!((le - 0.5_f64.ln()).abs() < 1e-12);
        let lt = hmm.log_trans(1, 0).expect("ok");
        assert!((lt - 0.4_f64.ln()).abs() < 1e-12);
        let li = hmm.log_init(1).expect("ok");
        assert!((li - 0.4_f64.ln()).abs() < 1e-12);
    }

    #[test]
    fn discrete_hmm_rejects_out_of_range_indices() {
        let hmm = toy_discrete();
        assert!(matches!(
            hmm.log_emission(2, 0),
            Err(SeqError::IndexOutOfBounds { index: 2, len: 2 })
        ));
        assert!(matches!(
            hmm.log_emission(0, 3),
            Err(SeqError::InvalidObservation(_))
        ));
        assert!(matches!(
            hmm.log_trans(0, 5),
            Err(SeqError::IndexOutOfBounds { index: 5, len: 2 })
        ));
        assert!(matches!(
            hmm.log_init(2),
            Err(SeqError::IndexOutOfBounds { index: 2, len: 2 })
        ));
    }

    #[test]
    fn discrete_hmm_validates_configuration() {
        let b = vec![0.1, 0.4, 0.5, 0.6, 0.3, 0.1];
        let zero = HmmDiscrete::new(0, 3, vec![], vec![], vec![]).unwrap_err();
        assert!(matches!(zero, SeqError::InvalidConfiguration(_)));

        // A has 3 entries instead of 2 * 2.
        let short_a =
            HmmDiscrete::new(2, 3, vec![0.6, 0.4], vec![0.7, 0.3, 0.4], b.clone()).unwrap_err();
        assert!(matches!(
            short_a,
            SeqError::ShapeMismatch {
                expected: 4,
                got: 3
            }
        ));

        // First row of A sums to 1.1.
        let bad_row = HmmDiscrete::new(2, 3, vec![0.6, 0.4], vec![0.7, 0.4, 0.4, 0.6], b)
            .unwrap_err();
        assert!(matches!(bad_row, SeqError::InvalidConfiguration(_)));

        // Entry outside [0, 1] even though the row sums to 1.
        let out_of_range = HmmDiscrete::new(
            2,
            3,
            vec![0.6, 0.4],
            vec![0.7, 0.3, 0.4, 0.6],
            vec![1.5, -0.5, 0.0, 0.6, 0.3, 0.1],
        )
        .unwrap_err();
        assert!(matches!(out_of_range, SeqError::ProbabilityOutOfRange(_)));
    }

    #[test]
    fn log_safe_negative_is_neg_inf() {
        // Compare against NEG_INFINITY exactly: `is_infinite()` would also accept +inf.
        assert_eq!(log_safe(0.0), f64::NEG_INFINITY);
        assert_eq!(log_safe(-1.0), f64::NEG_INFINITY);
        assert_eq!(log_safe(f64::NAN), f64::NEG_INFINITY);
        assert_eq!(log_safe(f64::INFINITY), f64::NEG_INFINITY);
        assert!((log_safe(2.0) - 2.0_f64.ln()).abs() < 1e-12);
    }

    #[test]
    fn gaussian_hmm_emission() {
        let hmm = toy_gaussian();
        let expected = -0.5 * (2.0 * std::f64::consts::PI).ln();
        let le0 = hmm.log_emission(0, &[0.0]).expect("ok");
        assert!((le0 - expected).abs() < 1e-12);
        // State 1 is centred at 5.0, so x = 5.0 gives the same peak density.
        let le1 = hmm.log_emission(1, &[5.0]).expect("ok");
        assert!((le1 - expected).abs() < 1e-12);
        // One unit away from the mean costs exactly 0.5 nats with unit variance.
        let off = hmm.log_emission(0, &[1.0]).expect("ok");
        assert!((off - (expected - 0.5)).abs() < 1e-12);
    }

    #[test]
    fn gaussian_hmm_rejects_bad_input() {
        let hmm = toy_gaussian();
        assert!(matches!(
            hmm.log_emission(0, &[0.0, 1.0]),
            Err(SeqError::ShapeMismatch {
                expected: 1,
                got: 2
            })
        ));
        assert!(matches!(
            hmm.log_emission(2, &[0.0]),
            Err(SeqError::IndexOutOfBounds { index: 2, len: 2 })
        ));

        let zero_var = HmmGaussian::new(
            2,
            1,
            vec![0.5, 0.5],
            vec![0.9, 0.1, 0.1, 0.9],
            vec![0.0, 5.0],
            vec![1.0, 0.0],
        )
        .unwrap_err();
        assert!(matches!(zero_var, SeqError::InvalidParameter { .. }));
    }
}