1use super::forward_backward::forward_backward;
23use super::hmm::HmmDiscrete;
24use crate::error::{SeqError, SeqResult};
25
/// Result of posterior (maximum-marginal) decoding of one observation sequence.
///
/// Built by [`posterior_decode`]. All per-position vectors have one entry per
/// observation; `gamma` is flattened with `n_states` entries per position.
#[derive(Debug, Clone, PartialEq)]
pub struct PosteriorDecode {
    /// Decoded state at each position: the index of the largest entry in that
    /// position's `gamma` row (lowest index wins on ties).
    pub path: Vec<usize>,
    /// Posterior probability of the chosen state at each position, i.e.
    /// `gamma[t * n_states + path[t]]`.
    pub marginal: Vec<f64>,
    /// Posterior state probabilities from forward-backward, stored row-major:
    /// entry `t * n_states + j` is the posterior of state `j` at position `t`.
    pub gamma: Vec<f64>,
    /// Sum of `marginal`: the posterior expected number of positions whose
    /// decoded state is correct.
    pub expected_correct: f64,
}
38
39pub fn posterior_decode(hmm: &HmmDiscrete, obs: &[usize]) -> SeqResult<PosteriorDecode> {
47 if obs.is_empty() {
48 return Err(SeqError::EmptyInput);
49 }
50 let n = hmm.n_states;
51 let fb = forward_backward(hmm, obs)?;
52 let t_max = obs.len();
53
54 let mut path = vec![0usize; t_max];
55 let mut marginal = vec![0.0_f64; t_max];
56 let mut expected_correct = 0.0_f64;
57
58 for t in 0..t_max {
59 let row = &fb.gamma[t * n..t * n + n];
60 let mut best = f64::NEG_INFINITY;
61 let mut argmax = 0usize;
62 for (j, &g) in row.iter().enumerate() {
63 if g > best {
64 best = g;
65 argmax = j;
66 }
67 }
68 path[t] = argmax;
69 marginal[t] = best;
70 expected_correct += best;
71 }
72
73 Ok(PosteriorDecode {
74 path,
75 marginal,
76 gamma: fb.gamma,
77 expected_correct,
78 })
79}
80
81pub fn posterior_path_is_feasible(hmm: &HmmDiscrete, path: &[usize]) -> SeqResult<bool> {
92 if path.is_empty() {
93 return Err(SeqError::EmptyInput);
94 }
95 let n = hmm.n_states;
96 for &s in path {
97 if s >= n {
98 return Err(SeqError::IndexOutOfBounds { index: s, len: n });
99 }
100 }
101 if hmm.pi[path[0]] <= 0.0 {
102 return Ok(false);
103 }
104 for w in path.windows(2) {
105 if hmm.a[w[0] * n + w[1]] <= 0.0 {
106 return Ok(false);
107 }
108 }
109 Ok(true)
110}
111
#[cfg(test)]
mod tests {
    use super::*;

    /// Two-state, two-symbol model with every parameter strictly positive.
    fn small_hmm() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            2,
            vec![0.6, 0.4],
            vec![0.7, 0.3, 0.4, 0.6],
            vec![0.1, 0.9, 0.8, 0.2],
        )
        .expect("hmm")
    }

    /// Sticky two-state model whose emission rows are close to the identity,
    /// so the decoded state should match the observed symbol.
    fn deterministic_hmm() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            2,
            vec![0.5, 0.5],
            vec![0.9, 0.1, 0.1, 0.9],
            vec![0.99, 0.01, 0.01, 0.99],
        )
        .expect("hmm")
    }

    #[test]
    fn decode_rejects_empty() {
        let h = small_hmm();
        assert!(matches!(
            posterior_decode(&h, &[]),
            Err(SeqError::EmptyInput)
        ));
    }

    #[test]
    fn decode_shapes_correct() {
        let h = small_hmm();
        let d = posterior_decode(&h, &[0, 1, 0, 1]).expect("ok");
        assert_eq!(d.path.len(), 4);
        assert_eq!(d.marginal.len(), 4);
        assert_eq!(d.gamma.len(), 4 * 2);
    }

    #[test]
    fn marginals_are_probabilities() {
        let h = small_hmm();
        let d = posterior_decode(&h, &[0, 1, 1, 0]).expect("ok");
        for &m in &d.marginal {
            assert!((0.0..=1.0).contains(&m), "marginal {m} out of [0,1]");
        }
    }

    #[test]
    fn chosen_label_is_argmax_of_gamma() {
        let h = small_hmm();
        let d = posterior_decode(&h, &[0, 1, 0]).expect("ok");
        let n = h.n_states;
        for t in 0..3 {
            let row = &d.gamma[t * n..t * n + n];
            // Reference argmax using the decoder's tie rule (first maximum
            // wins). `Iterator::max_by` would return the *last* maximum and
            // disagree with the decoder on tied rows.
            let best = row.iter().copied().fold(f64::NEG_INFINITY, f64::max);
            let true_arg = row.iter().position(|&g| g == best).expect("nonempty");
            assert_eq!(d.path[t], true_arg);
            assert!((d.marginal[t] - row[true_arg]).abs() < 1e-12);
        }
    }

    #[test]
    fn ties_break_toward_lowest_state() {
        // Every parameter is uniform, so both states receive identical
        // posteriors at each position; the decoder must report state 0.
        let h = HmmDiscrete::new(2, 2, vec![0.5, 0.5], vec![0.5; 4], vec![0.5; 4])
            .expect("hmm");
        let d = posterior_decode(&h, &[0, 1, 1]).expect("ok");
        assert_eq!(d.path, vec![0, 0, 0]);
        for &m in &d.marginal {
            assert!((m - 0.5).abs() < 1e-12, "marginal {m} should be 0.5");
        }
    }

    #[test]
    fn gamma_rows_sum_to_one() {
        let h = small_hmm();
        let d = posterior_decode(&h, &[0, 0, 1, 1]).expect("ok");
        let n = h.n_states;
        for t in 0..4 {
            let s: f64 = d.gamma[t * n..t * n + n].iter().sum();
            assert!((s - 1.0).abs() < 1e-9);
        }
    }

    #[test]
    fn expected_correct_bounds() {
        let h = small_hmm();
        let obs = [0, 1, 0, 1, 0];
        let d = posterior_decode(&h, &obs).expect("ok");
        // Each marginal is a probability, so the sum is at most the length.
        assert!(d.expected_correct <= obs.len() as f64 + 1e-9);
        // With two states the winning posterior is at least 1/2 per position.
        assert!(d.expected_correct >= obs.len() as f64 / 2.0 - 1e-9);
        // `expected_correct` is exactly the sum of the per-position marginals.
        let s: f64 = d.marginal.iter().sum();
        assert!((d.expected_correct - s).abs() < 1e-12);
    }

    #[test]
    fn deterministic_recovers_symbol_states() {
        let h = deterministic_hmm();
        // Near-identity emissions: the posterior should follow the symbols.
        let obs = [0usize, 1, 0, 1, 1, 0];
        let d = posterior_decode(&h, &obs).expect("ok");
        for (t, &o) in obs.iter().enumerate() {
            assert_eq!(d.path[t], o, "pos {t}: expected state {o}");
            assert!(d.marginal[t] >= 0.5, "winner marginal must be ≥ 0.5");
        }
        let mean_conf: f64 = d.marginal.iter().sum::<f64>() / obs.len() as f64;
        assert!(
            mean_conf > 0.8,
            "mean confidence {mean_conf} should be high"
        );
    }

    #[test]
    fn feasible_path_check_accepts_valid() {
        // Every start and transition in `small_hmm` is strictly positive.
        let h = small_hmm();
        assert!(posterior_path_is_feasible(&h, &[0, 1, 0, 1]).expect("ok"));
    }

    #[test]
    fn feasible_path_check_rejects_zero_transition() {
        // Self-transitions are impossible: the chain must alternate states.
        let h = HmmDiscrete::new(
            2,
            2,
            vec![0.5, 0.5],
            vec![0.0, 1.0, 1.0, 0.0],
            vec![0.6, 0.4, 0.4, 0.6],
        )
        .expect("hmm");
        assert!(!posterior_path_is_feasible(&h, &[0, 0]).expect("ok"));
        assert!(posterior_path_is_feasible(&h, &[0, 1]).expect("ok"));
    }

    #[test]
    fn feasible_rejects_zero_initial() {
        // State 1 can never be the first state.
        let h = HmmDiscrete::new(
            2,
            2,
            vec![1.0, 0.0],
            vec![0.5, 0.5, 0.5, 0.5],
            vec![0.6, 0.4, 0.4, 0.6],
        )
        .expect("hmm");
        assert!(!posterior_path_is_feasible(&h, &[1, 0]).expect("ok"));
    }

    #[test]
    fn feasible_rejects_empty_and_oob() {
        let h = small_hmm();
        assert!(matches!(
            posterior_path_is_feasible(&h, &[]),
            Err(SeqError::EmptyInput)
        ));
        assert!(matches!(
            posterior_path_is_feasible(&h, &[0, 5]),
            Err(SeqError::IndexOutOfBounds { .. })
        ));
    }

    #[test]
    fn single_observation_decodes() {
        let h = small_hmm();
        let d = posterior_decode(&h, &[1]).expect("ok");
        assert_eq!(d.path.len(), 1);
        assert!(d.path[0] < 2);
        assert!((d.expected_correct - d.marginal[0]).abs() < 1e-12);
    }

    #[test]
    fn out_of_range_symbol_errors() {
        let h = small_hmm();
        assert!(posterior_decode(&h, &[0, 5]).is_err());
    }
}