oxicuda_seq/hmm/
forward_backward.rs1use super::hmm::{HmmDiscrete, HmmGaussian, log_safe};
4use crate::error::{SeqError, SeqResult};
5
/// Output of one log-space forward–backward pass over a single observation sequence.
///
/// Every per-time-step table is a flat, row-major vector: with `n` hidden
/// states and `t_max` time steps, entry `(t, i)` is stored at `t * n + i`.
#[derive(Debug, Clone)]
pub struct ForwardBackward {
    /// Log forward variables, length `t_max * n`.
    pub log_alpha: Vec<f64>,
    /// Log backward variables, length `t_max * n`; the final time step is all `0.0`.
    pub log_beta: Vec<f64>,
    /// State posteriors, length `t_max * n`; each time step's row is
    /// renormalised by its sum whenever that sum is positive.
    pub gamma: Vec<f64>,
    /// Transition posteriors, length `(t_max - 1) * n * n`; entry `(t, i, j)`
    /// is stored at `t * n * n + i * n + j`.
    pub xi: Vec<f64>,
    /// Sequence log-likelihood: the log-sum-exp of the last row of `log_alpha`.
    pub log_likelihood: f64,
}
20
/// Computes `ln(Σ exp(x))` over `xs`, shifting by the maximum so that large
/// magnitudes do not overflow `exp`.
///
/// * An empty slice, or one whose maximum is `-∞`, yields `-∞`.
/// * A slice whose maximum is `+∞` yields `+∞` (subtracting the maximum would
///   otherwise evaluate `∞ - ∞` and poison the result with NaN).
/// * With a finite maximum, a NaN entry propagates to the result.
#[inline]
pub fn logsumexp(xs: &[f64]) -> f64 {
    // `f64::max` skips NaN operands, so `peak` is never NaN here.
    let peak = xs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
    if peak.is_infinite() {
        // ±∞ already is the answer; shifting by it would produce NaN.
        return peak;
    }
    let shifted_sum: f64 = xs.iter().map(|&x| (x - peak).exp()).sum();
    peak + shifted_sum.ln()
}
31
32pub fn forward_backward(hmm: &HmmDiscrete, obs: &[usize]) -> SeqResult<ForwardBackward> {
34 if obs.is_empty() {
35 return Err(SeqError::EmptyInput);
36 }
37 let t_max = obs.len();
38 let n = hmm.n_states;
39
40 let mut log_em = vec![f64::NEG_INFINITY; t_max * n];
42 for t in 0..t_max {
43 for j in 0..n {
44 log_em[t * n + j] = hmm.log_emission(j, obs[t])?;
45 }
46 }
47
48 let mut log_a = vec![f64::NEG_INFINITY; n * n];
50 for i in 0..n {
51 for j in 0..n {
52 log_a[i * n + j] = log_safe(hmm.a[i * n + j]);
53 }
54 }
55 let log_pi: Vec<f64> = hmm.pi.iter().map(|&p| log_safe(p)).collect();
56
57 forward_backward_log(&log_pi, &log_a, &log_em, n, t_max)
58}
59
60pub fn forward_backward_gaussian(hmm: &HmmGaussian, x: &[f64]) -> SeqResult<ForwardBackward> {
63 if x.is_empty() {
64 return Err(SeqError::EmptyInput);
65 }
66 if x.len() % hmm.dim != 0 {
67 return Err(SeqError::DimensionMismatch {
68 a: x.len(),
69 b: hmm.dim,
70 });
71 }
72 let t_max = x.len() / hmm.dim;
73 let n = hmm.n_states;
74 let mut log_em = vec![f64::NEG_INFINITY; t_max * n];
75 for t in 0..t_max {
76 let row = &x[t * hmm.dim..(t + 1) * hmm.dim];
77 for j in 0..n {
78 log_em[t * n + j] = hmm.log_emission(j, row)?;
79 }
80 }
81 let mut log_a = vec![f64::NEG_INFINITY; n * n];
82 for i in 0..n {
83 for j in 0..n {
84 log_a[i * n + j] = log_safe(hmm.a[i * n + j]);
85 }
86 }
87 let log_pi: Vec<f64> = hmm.pi.iter().map(|&p| log_safe(p)).collect();
88 forward_backward_log(&log_pi, &log_a, &log_em, n, t_max)
89}
90
91fn forward_backward_log(
93 log_pi: &[f64],
94 log_a: &[f64],
95 log_em: &[f64],
96 n: usize,
97 t_max: usize,
98) -> SeqResult<ForwardBackward> {
99 let mut log_alpha = vec![f64::NEG_INFINITY; t_max * n];
100 let mut log_beta = vec![f64::NEG_INFINITY; t_max * n];
101
102 for j in 0..n {
104 log_alpha[j] = log_pi[j] + log_em[j];
105 }
106
107 let mut tmp = vec![0.0; n];
109 for t in 1..t_max {
110 for j in 0..n {
111 for i in 0..n {
112 tmp[i] = log_alpha[(t - 1) * n + i] + log_a[i * n + j];
113 }
114 log_alpha[t * n + j] = logsumexp(&tmp) + log_em[t * n + j];
115 }
116 }
117
118 for i in 0..n {
120 log_beta[(t_max - 1) * n + i] = 0.0;
121 }
122 for t in (0..t_max - 1).rev() {
124 for i in 0..n {
125 for j in 0..n {
126 tmp[j] = log_a[i * n + j] + log_em[(t + 1) * n + j] + log_beta[(t + 1) * n + j];
127 }
128 log_beta[t * n + i] = logsumexp(&tmp);
129 }
130 }
131
132 let last_alpha = &log_alpha[(t_max - 1) * n..t_max * n];
134 let ll = logsumexp(last_alpha);
135
136 let mut gamma = vec![0.0; t_max * n];
138 for t in 0..t_max {
139 for i in 0..n {
140 gamma[t * n + i] = (log_alpha[t * n + i] + log_beta[t * n + i] - ll).exp();
141 }
142 let s: f64 = gamma[t * n..t * n + n].iter().sum();
144 if s > 0.0 {
145 for i in 0..n {
146 gamma[t * n + i] /= s;
147 }
148 }
149 }
150
151 let mut xi = vec![0.0; (t_max.saturating_sub(1)) * n * n];
153 for t in 0..t_max.saturating_sub(1) {
154 let mut s = 0.0;
155 for i in 0..n {
156 for j in 0..n {
157 let v = (log_alpha[t * n + i]
158 + log_a[i * n + j]
159 + log_em[(t + 1) * n + j]
160 + log_beta[(t + 1) * n + j]
161 - ll)
162 .exp();
163 xi[t * n * n + i * n + j] = v;
164 s += v;
165 }
166 }
167 if s > 0.0 {
168 for v in xi[t * n * n..(t + 1) * n * n].iter_mut() {
169 *v /= s;
170 }
171 }
172 }
173
174 Ok(ForwardBackward {
175 log_alpha,
176 log_beta,
177 gamma,
178 xi,
179 log_likelihood: ll,
180 })
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Small discrete model shared by the tests below.
    // NOTE(review): the two leading `2`s are presumably the state and symbol
    // counts — confirm against `HmmDiscrete::new`.
    fn small_hmm() -> HmmDiscrete {
        HmmDiscrete::new(
            2,
            2,
            vec![0.6, 0.4],
            vec![0.7, 0.3, 0.4, 0.6],
            vec![0.1, 0.9, 0.8, 0.2],
        )
        .expect("ok")
    }

    /// Three observations and two states give 3×2 tables and 2×2×2 transition posteriors.
    #[test]
    fn forward_alpha_dimensions() {
        let h = small_hmm();
        let fb = forward_backward(&h, &[0, 1, 0]).expect("ok");
        assert_eq!(fb.log_alpha.len(), 6);
        assert_eq!(fb.gamma.len(), 6);
        assert_eq!(fb.xi.len(), 8);
    }

    /// Each time step's state posterior must be a probability distribution.
    #[test]
    fn gamma_rows_sum_to_one() {
        let h = small_hmm();
        let fb = forward_backward(&h, &[0, 1, 0, 1]).expect("ok");
        assert_eq!(fb.gamma.len(), 8);
        for (t, row) in fb.gamma.chunks_exact(2).enumerate() {
            let s: f64 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-9, "γ_{t} sums to {s}");
        }
    }

    /// All-`-∞` input must give exactly `-∞`; `is_infinite()` alone would also accept `+∞`.
    #[test]
    fn logsumexp_neg_inf() {
        let xs = [f64::NEG_INFINITY, f64::NEG_INFINITY];
        assert_eq!(logsumexp(&xs), f64::NEG_INFINITY);
    }

    /// A single element is returned unchanged (up to rounding).
    #[test]
    fn logsumexp_single() {
        let xs = [5.0];
        assert!((logsumexp(&xs) - 5.0).abs() < 1e-12);
    }
}