Skip to main content

oxicuda_seq/hmm/
posterior_decoding.rs

//! Posterior (max-marginal) decoding for discrete HMMs.
//!
//! Viterbi decoding returns the single jointly-most-probable state path
//! `argmax_y P(y | o)`.  **Posterior decoding** (also "maximum-marginal" or
//! "MPM — maximum posterior marginal" decoding) instead chooses, independently
//! at each time step, the state with the largest *marginal* posterior:
//!
//! ```text
//! ŷ_t = argmax_j  γ_t(j) ,   γ_t(j) = P(s_t = j | o_1…o_T)
//! ```
//!
//! This minimises the expected number of individually-mislabelled positions
//! (per-symbol error / Hamming risk), whereas Viterbi minimises the
//! whole-sequence 0/1 error.  The two can disagree, and the posterior path is
//! not guaranteed to be a *legal* path (it may traverse a zero-probability
//! transition), so we also expose [`posterior_path_is_feasible`] to check it.
//!
//! The marginals `γ` come from log-space forward-backward; this module also
//! returns the **expected number of correct labels** under the model, which is
//! a useful confidence proxy: `Σ_t max_j γ_t(j)`.
21
22use super::forward_backward::forward_backward;
23use super::hmm::HmmDiscrete;
24use crate::error::{SeqError, SeqResult};
25
/// Result of posterior (max-marginal) decoding.
///
/// Every per-position vector is indexed by time step `t ∈ 0..T`, where `T` is
/// the length of the decoded observation sequence.
///
/// Derives `PartialEq` so decodes can be compared directly (e.g. with
/// `assert_eq!`); the comparison is exact on the `f64` fields.
#[derive(Debug, Clone, PartialEq)]
pub struct PosteriorDecode {
    /// Per-position max-marginal labels `ŷ_t = argmax_j γ_t(j)`, length `T`.
    pub path: Vec<usize>,
    /// The marginal posterior `γ_t(ŷ_t)` of the chosen label at each
    /// position, length `T`.
    pub marginal: Vec<f64>,
    /// State posteriors `γ` (`T × n_states`, row-major: `γ_t(j)` is stored at
    /// index `t * n_states + j`), kept for downstream use.
    pub gamma: Vec<f64>,
    /// Expected number of correctly-labelled positions, `Σ_t γ_t(ŷ_t)`.
    pub expected_correct: f64,
}
38
39/// Decode an observation sequence by maximum posterior marginal.
40///
41/// # Errors
42///
43/// * [`SeqError::EmptyInput`] — if `obs` is empty.
44/// * Propagates emission/transition errors from forward-backward (e.g. an
45///   observation symbol out of range).
46pub fn posterior_decode(hmm: &HmmDiscrete, obs: &[usize]) -> SeqResult<PosteriorDecode> {
47    if obs.is_empty() {
48        return Err(SeqError::EmptyInput);
49    }
50    let n = hmm.n_states;
51    let fb = forward_backward(hmm, obs)?;
52    let t_max = obs.len();
53
54    let mut path = vec![0usize; t_max];
55    let mut marginal = vec![0.0_f64; t_max];
56    let mut expected_correct = 0.0_f64;
57
58    for t in 0..t_max {
59        let row = &fb.gamma[t * n..t * n + n];
60        let mut best = f64::NEG_INFINITY;
61        let mut argmax = 0usize;
62        for (j, &g) in row.iter().enumerate() {
63            if g > best {
64                best = g;
65                argmax = j;
66            }
67        }
68        path[t] = argmax;
69        marginal[t] = best;
70        expected_correct += best;
71    }
72
73    Ok(PosteriorDecode {
74        path,
75        marginal,
76        gamma: fb.gamma,
77        expected_correct,
78    })
79}
80
81/// Whether a decoded path is *feasible* under the model — i.e. every visited
82/// transition has positive probability and the initial state has positive `π`.
83///
84/// Posterior decoding can yield infeasible paths (Viterbi cannot); this lets a
85/// caller detect that and fall back to Viterbi if a legal path is required.
86///
87/// # Errors
88///
89/// * [`SeqError::EmptyInput`]       — if `path` is empty.
90/// * [`SeqError::IndexOutOfBounds`] — if any state index `≥ n_states`.
91pub fn posterior_path_is_feasible(hmm: &HmmDiscrete, path: &[usize]) -> SeqResult<bool> {
92    if path.is_empty() {
93        return Err(SeqError::EmptyInput);
94    }
95    let n = hmm.n_states;
96    for &s in path {
97        if s >= n {
98            return Err(SeqError::IndexOutOfBounds { index: s, len: n });
99        }
100    }
101    if hmm.pi[path[0]] <= 0.0 {
102        return Ok(false);
103    }
104    for w in path.windows(2) {
105        if hmm.a[w[0] * n + w[1]] <= 0.0 {
106            return Ok(false);
107        }
108    }
109    Ok(true)
110}
111
112// ─── Tests ───────────────────────────────────────────────────────────────────
113
#[cfg(test)]
mod tests {
    use super::*;

    /// A simple 2-state, 2-symbol HMM.
    fn small_hmm() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            2,
            vec![0.6, 0.4],
            vec![0.7, 0.3, 0.4, 0.6],
            vec![0.1, 0.9, 0.8, 0.2],
        )
        .expect("hmm")
    }

    /// A near-deterministic HMM whose states are strongly tied to symbols.
    fn deterministic_hmm() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            2,
            vec![0.5, 0.5],
            vec![0.9, 0.1, 0.1, 0.9],
            // state 0 almost always emits symbol 0; state 1 → symbol 1.
            vec![0.99, 0.01, 0.01, 0.99],
        )
        .expect("hmm")
    }

    #[test]
    fn decode_rejects_empty() {
        let h = small_hmm();
        assert!(matches!(
            posterior_decode(&h, &[]),
            Err(SeqError::EmptyInput)
        ));
    }

    #[test]
    fn decode_shapes_correct() {
        let h = small_hmm();
        let d = posterior_decode(&h, &[0, 1, 0, 1]).expect("ok");
        assert_eq!(d.path.len(), 4);
        assert_eq!(d.marginal.len(), 4);
        assert_eq!(d.gamma.len(), 4 * 2);
    }

    #[test]
    fn marginals_are_probabilities() {
        let h = small_hmm();
        let d = posterior_decode(&h, &[0, 1, 1, 0]).expect("ok");
        for &m in &d.marginal {
            assert!((0.0..=1.0).contains(&m), "marginal {m} out of [0,1]");
        }
    }

    #[test]
    fn chosen_label_is_argmax_of_gamma() {
        let h = small_hmm();
        let d = posterior_decode(&h, &[0, 1, 0]).expect("ok");
        let n = h.n_states;
        for t in 0..3 {
            let row = &d.gamma[t * n..t * n + n];
            let chosen = d.path[t];
            // `marginal` reports the posterior of the chosen label...
            assert!((d.marginal[t] - row[chosen]).abs() < 1e-12);
            // ...no state beats it (tie-agnostic argmax check)...
            for &g in row {
                assert!(g <= d.marginal[t], "γ {g} exceeds chosen marginal at t={t}");
            }
            // ...and ties resolve to the lowest state index.
            for &g in &row[..chosen] {
                assert!(g < d.marginal[t], "lower-index tie not preferred at t={t}");
            }
        }
    }

    #[test]
    fn gamma_rows_sum_to_one() {
        let h = small_hmm();
        let d = posterior_decode(&h, &[0, 0, 1, 1]).expect("ok");
        let n = h.n_states;
        for t in 0..4 {
            let s: f64 = d.gamma[t * n..t * n + n].iter().sum();
            assert!((s - 1.0).abs() < 1e-9);
        }
    }

    #[test]
    fn expected_correct_bounds() {
        let h = small_hmm();
        let obs = [0, 1, 0, 1, 0];
        let d = posterior_decode(&h, &obs).expect("ok");
        // Σ_t max_j γ ≥ T/n_states (uniform lower bound) and ≤ T.
        let t_len = obs.len() as f64;
        assert!(d.expected_correct <= t_len + 1e-9);
        assert!(d.expected_correct >= t_len / h.n_states as f64 - 1e-9);
        // Also equals the sum of the per-position marginals.
        let s: f64 = d.marginal.iter().sum();
        assert!((d.expected_correct - s).abs() < 1e-12);
    }

    #[test]
    fn deterministic_recovers_symbol_states() {
        // With near-deterministic emissions, the max-marginal state at each
        // position matches the emitted symbol.
        let h = deterministic_hmm();
        let obs = [0usize, 1, 0, 1, 1, 0];
        let d = posterior_decode(&h, &obs).expect("ok");
        for (t, &o) in obs.iter().enumerate() {
            assert_eq!(d.path[t], o, "pos {t}: expected state {o}");
            // Each marginal is the winner over two states, hence ≥ 0.5.
            assert!(d.marginal[t] >= 0.5, "winner marginal must be ≥ 0.5");
        }
        // On average the model is highly confident on this clean sequence.
        let mean_conf: f64 = d.marginal.iter().sum::<f64>() / obs.len() as f64;
        assert!(
            mean_conf > 0.8,
            "mean confidence {mean_conf} should be high"
        );
    }

    #[test]
    fn feasible_path_check_accepts_valid() {
        let h = small_hmm();
        // All transitions are positive in this HMM, so any path is feasible.
        assert!(posterior_path_is_feasible(&h, &[0, 1, 0, 1]).expect("ok"));
    }

    #[test]
    fn decoded_path_feasible_when_all_transitions_positive() {
        // π and A are strictly positive in `small_hmm`, so whatever posterior
        // decoding returns must pass the feasibility check.
        let h = small_hmm();
        let d = posterior_decode(&h, &[0, 1, 1, 0, 1]).expect("ok");
        assert!(posterior_path_is_feasible(&h, &d.path).expect("ok"));
    }

    #[test]
    fn feasible_path_check_rejects_zero_transition() {
        // Build an HMM with a forbidden 0→0 self-transition.
        let h = HmmDiscrete::new(
            2,
            2,
            vec![0.5, 0.5],
            vec![0.0, 1.0, 1.0, 0.0], // A[0,0] = 0
            vec![0.6, 0.4, 0.4, 0.6],
        )
        .expect("hmm");
        assert!(!posterior_path_is_feasible(&h, &[0, 0]).expect("ok"));
        assert!(posterior_path_is_feasible(&h, &[0, 1]).expect("ok"));
    }

    #[test]
    fn feasible_rejects_zero_initial() {
        let h = HmmDiscrete::new(
            2,
            2,
            vec![1.0, 0.0], // state 1 has zero prior
            vec![0.5, 0.5, 0.5, 0.5],
            vec![0.6, 0.4, 0.4, 0.6],
        )
        .expect("hmm");
        assert!(!posterior_path_is_feasible(&h, &[1, 0]).expect("ok"));
    }

    #[test]
    fn feasible_rejects_empty_and_oob() {
        let h = small_hmm();
        assert!(matches!(
            posterior_path_is_feasible(&h, &[]),
            Err(SeqError::EmptyInput)
        ));
        assert!(matches!(
            posterior_path_is_feasible(&h, &[0, 5]),
            Err(SeqError::IndexOutOfBounds { .. })
        ));
    }

    #[test]
    fn single_observation_decodes() {
        let h = small_hmm();
        let d = posterior_decode(&h, &[1]).expect("ok");
        assert_eq!(d.path.len(), 1);
        // γ_0(j) ∝ π_j B_j(o_0): state 0 → 0.6·0.9 = 0.54, state 1 → 0.4·0.2 = 0.08,
        // so state 0 wins with posterior 0.54 / 0.62.
        assert_eq!(d.path[0], 0);
        let expected = 0.54 / (0.54 + 0.08);
        assert!(
            (d.marginal[0] - expected).abs() < 1e-9,
            "marginal {} != {expected}",
            d.marginal[0]
        );
        assert!((d.expected_correct - expected).abs() < 1e-9);
    }

    #[test]
    fn out_of_range_symbol_errors() {
        let h = small_hmm();
        // symbol 5 ≥ n_obs=2 → emission lookup fails inside forward-backward.
        assert!(posterior_decode(&h, &[0, 5]).is_err());
    }
}