1use ghostflow_core::Tensor;
7use rand::prelude::*;
8
9pub struct GaussianHMM {
11 pub n_components: usize, pub n_features: usize, pub covariance_type: HMMCovarianceType,
14 pub max_iter: usize,
15 pub tol: f32,
16 pub n_init: usize,
17
18 start_prob: Vec<f32>, trans_prob: Vec<Vec<f32>>, means: Vec<Vec<f32>>, covariances: Vec<Vec<f32>>, converged: bool,
24}
25
/// Covariance structure used for the Gaussian emission densities.
///
/// NOTE(review): `Full` is currently handled identically to `Diag` throughout
/// this module (per-feature variances only); a true full covariance matrix is
/// not implemented.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HMMCovarianceType {
    /// One variance per feature.
    Diag,
    /// Intended full covariance (currently behaves like `Diag`).
    Full,
    /// A single variance shared across all features.
    Spherical,
}
32
33impl GaussianHMM {
34 pub fn new(n_components: usize, n_features: usize) -> Self {
35 Self {
36 n_components,
37 n_features,
38 covariance_type: HMMCovarianceType::Diag,
39 max_iter: 100,
40 tol: 1e-2,
41 n_init: 1,
42 start_prob: vec![1.0 / n_components as f32; n_components],
43 trans_prob: vec![vec![1.0 / n_components as f32; n_components]; n_components],
44 means: Vec::new(),
45 covariances: Vec::new(),
46 converged: false,
47 }
48 }
49
50 pub fn covariance_type(mut self, cov_type: HMMCovarianceType) -> Self {
51 self.covariance_type = cov_type;
52 self
53 }
54
55 pub fn max_iter(mut self, iter: usize) -> Self {
56 self.max_iter = iter;
57 self
58 }
59
60 pub fn fit(&mut self, sequences: &[Tensor]) {
62 if sequences.is_empty() {
63 return;
64 }
65
66 let mut best_log_likelihood = f32::NEG_INFINITY;
67 let mut best_start_prob = Vec::new();
68 let mut best_trans_prob = Vec::new();
69 let mut best_means = Vec::new();
70 let mut best_covariances = Vec::new();
71
72 for _ in 0..self.n_init {
73 self.initialize_parameters(sequences);
75
76 let mut prev_log_likelihood = f32::NEG_INFINITY;
77
78 for _ in 0..self.max_iter {
80 let (log_likelihood, gamma, xi) = self.e_step(sequences);
82
83 self.m_step(sequences, &gamma, &xi);
85
86 if (log_likelihood - prev_log_likelihood).abs() < self.tol {
88 self.converged = true;
89 break;
90 }
91
92 prev_log_likelihood = log_likelihood;
93 }
94
95 let final_log_likelihood = self.compute_log_likelihood(sequences);
97 if final_log_likelihood > best_log_likelihood {
98 best_log_likelihood = final_log_likelihood;
99 best_start_prob = self.start_prob.clone();
100 best_trans_prob = self.trans_prob.clone();
101 best_means = self.means.clone();
102 best_covariances = self.covariances.clone();
103 }
104 }
105
106 self.start_prob = best_start_prob;
107 self.trans_prob = best_trans_prob;
108 self.means = best_means;
109 self.covariances = best_covariances;
110 }
111
112 fn initialize_parameters(&mut self, sequences: &[Tensor]) {
113 let mut rng = thread_rng();
114
115 self.start_prob = vec![1.0 / self.n_components as f32; self.n_components];
117
118 self.trans_prob = vec![vec![1.0 / self.n_components as f32; self.n_components]; self.n_components];
120
121 let mut all_obs = Vec::new();
123 for seq in sequences {
124 let seq_data = seq.data_f32();
125 let seq_len = seq.dims()[0];
126 for t in 0..seq_len {
127 all_obs.push(seq_data[t * self.n_features..(t + 1) * self.n_features].to_vec());
128 }
129 }
130
131 self.means = Vec::with_capacity(self.n_components);
132
133 let first_idx = rng.gen_range(0..all_obs.len());
135 self.means.push(all_obs[first_idx].clone());
136
137 for _ in 1..self.n_components {
139 let mut distances = vec![f32::MAX; all_obs.len()];
140
141 for (i, obs) in all_obs.iter().enumerate() {
142 let min_dist = self.means.iter()
143 .map(|mean| {
144 obs.iter().zip(mean.iter())
145 .map(|(x, m)| (x - m).powi(2))
146 .sum::<f32>()
147 })
148 .min_by(|a, b| a.partial_cmp(b).unwrap())
149 .unwrap();
150 distances[i] = min_dist;
151 }
152
153 let total_dist: f32 = distances.iter().sum();
154 let mut cumsum = 0.0;
155 let rand_val = rng.gen::<f32>() * total_dist;
156
157 let mut selected_idx = 0;
158 for (i, &dist) in distances.iter().enumerate() {
159 cumsum += dist;
160 if cumsum >= rand_val {
161 selected_idx = i;
162 break;
163 }
164 }
165
166 self.means.push(all_obs[selected_idx].clone());
167 }
168
169 self.covariances = match self.covariance_type {
171 HMMCovarianceType::Diag | HMMCovarianceType::Full => {
172 (0..self.n_components)
173 .map(|_| vec![1.0; self.n_features])
174 .collect()
175 }
176 HMMCovarianceType::Spherical => {
177 (0..self.n_components)
178 .map(|_| vec![1.0])
179 .collect()
180 }
181 };
182 }
183
184 fn e_step(&self, sequences: &[Tensor]) -> (f32, Vec<Vec<Vec<f32>>>, Vec<Vec<Vec<Vec<f32>>>>) {
186 let mut total_log_likelihood = 0.0;
187 let mut all_gamma = Vec::new();
188 let mut all_xi = Vec::new();
189
190 for seq in sequences {
191 let seq_data = seq.data_f32();
192 let seq_len = seq.dims()[0];
193
194 let (alpha, log_likelihood) = self.forward(&seq_data, seq_len);
196 total_log_likelihood += log_likelihood;
197
198 let beta = self.backward(&seq_data, seq_len);
200
201 let gamma = self.calculate_gamma(&alpha, &beta, seq_len);
203
204 let xi = self.calculate_xi(&alpha, &beta, &seq_data, seq_len);
206
207 all_gamma.push(gamma);
208 all_xi.push(xi);
209 }
210
211 (total_log_likelihood, all_gamma, all_xi)
212 }
213
214 fn forward(&self, seq_data: &[f32], seq_len: usize) -> (Vec<Vec<f32>>, f32) {
216 let mut alpha = vec![vec![0.0; self.n_components]; seq_len];
217 let mut scaling = vec![0.0; seq_len];
218
219 for i in 0..self.n_components {
221 let obs = &seq_data[0..self.n_features];
222 alpha[0][i] = self.start_prob[i] * self.emission_prob(obs, i);
223 scaling[0] += alpha[0][i];
224 }
225
226 if scaling[0] > 0.0 {
228 for i in 0..self.n_components {
229 alpha[0][i] /= scaling[0];
230 }
231 }
232
233 for t in 1..seq_len {
235 for j in 0..self.n_components {
236 let mut sum = 0.0;
237 for i in 0..self.n_components {
238 sum += alpha[t - 1][i] * self.trans_prob[i][j];
239 }
240 let obs = &seq_data[t * self.n_features..(t + 1) * self.n_features];
241 alpha[t][j] = sum * self.emission_prob(obs, j);
242 scaling[t] += alpha[t][j];
243 }
244
245 if scaling[t] > 0.0 {
247 for j in 0..self.n_components {
248 alpha[t][j] /= scaling[t];
249 }
250 }
251 }
252
253 let log_likelihood: f32 = scaling.iter().map(|&s| s.max(1e-10).ln()).sum();
255
256 (alpha, log_likelihood)
257 }
258
259 fn backward(&self, seq_data: &[f32], seq_len: usize) -> Vec<Vec<f32>> {
261 let mut beta = vec![vec![0.0; self.n_components]; seq_len];
262
263 for i in 0..self.n_components {
265 beta[seq_len - 1][i] = 1.0;
266 }
267
268 for t in (0..seq_len - 1).rev() {
270 for i in 0..self.n_components {
271 let mut sum = 0.0;
272 for j in 0..self.n_components {
273 let obs = &seq_data[(t + 1) * self.n_features..(t + 2) * self.n_features];
274 sum += self.trans_prob[i][j] * self.emission_prob(obs, j) * beta[t + 1][j];
275 }
276 beta[t][i] = sum;
277 }
278
279 let total: f32 = beta[t].iter().sum();
281 if total > 0.0 {
282 for i in 0..self.n_components {
283 beta[t][i] /= total;
284 }
285 }
286 }
287
288 beta
289 }
290
291 fn calculate_gamma(&self, alpha: &[Vec<f32>], beta: &[Vec<f32>], seq_len: usize) -> Vec<Vec<f32>> {
293 let mut gamma = vec![vec![0.0; self.n_components]; seq_len];
294
295 for t in 0..seq_len {
296 let mut total = 0.0;
297 for i in 0..self.n_components {
298 gamma[t][i] = alpha[t][i] * beta[t][i];
299 total += gamma[t][i];
300 }
301
302 if total > 0.0 {
304 for i in 0..self.n_components {
305 gamma[t][i] /= total;
306 }
307 }
308 }
309
310 gamma
311 }
312
313 fn calculate_xi(&self, alpha: &[Vec<f32>], beta: &[Vec<f32>], seq_data: &[f32], seq_len: usize) -> Vec<Vec<Vec<f32>>> {
315 let mut xi = vec![vec![vec![0.0; self.n_components]; self.n_components]; seq_len - 1];
316
317 for t in 0..seq_len - 1 {
318 let mut total = 0.0;
319 for i in 0..self.n_components {
320 for j in 0..self.n_components {
321 let obs = &seq_data[(t + 1) * self.n_features..(t + 2) * self.n_features];
322 xi[t][i][j] = alpha[t][i] * self.trans_prob[i][j] *
323 self.emission_prob(obs, j) * beta[t + 1][j];
324 total += xi[t][i][j];
325 }
326 }
327
328 if total > 0.0 {
330 for i in 0..self.n_components {
331 for j in 0..self.n_components {
332 xi[t][i][j] /= total;
333 }
334 }
335 }
336 }
337
338 xi
339 }
340
341 fn m_step(&mut self, sequences: &[Tensor], all_gamma: &[Vec<Vec<f32>>], all_xi: &[Vec<Vec<Vec<f32>>>]) {
343 for i in 0..self.n_components {
345 self.start_prob[i] = all_gamma.iter().map(|gamma| gamma[0][i]).sum::<f32>() / sequences.len() as f32;
346 }
347
348 for i in 0..self.n_components {
350 let mut denom = 0.0;
351 for j in 0..self.n_components {
352 let mut numer = 0.0;
353 for xi in all_xi {
354 for t in 0..xi.len() {
355 numer += xi[t][i][j];
356 }
357 }
358
359 for gamma in all_gamma {
360 for t in 0..gamma.len() - 1 {
361 denom += gamma[t][i];
362 }
363 }
364
365 self.trans_prob[i][j] = if denom > 0.0 { numer / denom } else { 1.0 / self.n_components as f32 };
366 }
367 }
368
369 for i in 0..self.n_components {
371 let mut weighted_sum = vec![0.0; self.n_features];
372 let mut weight_total = 0.0;
373
374 for (seq_idx, seq) in sequences.iter().enumerate() {
375 let seq_data = seq.data_f32();
376 let seq_len = seq.dims()[0];
377 let gamma = &all_gamma[seq_idx];
378
379 for t in 0..seq_len {
380 let obs = &seq_data[t * self.n_features..(t + 1) * self.n_features];
381 for j in 0..self.n_features {
382 weighted_sum[j] += gamma[t][i] * obs[j];
383 }
384 weight_total += gamma[t][i];
385 }
386 }
387
388 for j in 0..self.n_features {
390 self.means[i][j] = if weight_total > 0.0 { weighted_sum[j] / weight_total } else { 0.0 };
391 }
392
393 let mut weighted_var = vec![0.0; self.n_features];
395 for (seq_idx, seq) in sequences.iter().enumerate() {
396 let seq_data = seq.data_f32();
397 let seq_len = seq.dims()[0];
398 let gamma = &all_gamma[seq_idx];
399
400 for t in 0..seq_len {
401 let obs = &seq_data[t * self.n_features..(t + 1) * self.n_features];
402 for j in 0..self.n_features {
403 let diff = obs[j] - self.means[i][j];
404 weighted_var[j] += gamma[t][i] * diff * diff;
405 }
406 }
407 }
408
409 match self.covariance_type {
410 HMMCovarianceType::Diag | HMMCovarianceType::Full => {
411 for j in 0..self.n_features {
412 self.covariances[i][j] = if weight_total > 0.0 {
413 (weighted_var[j] / weight_total).max(1e-6)
414 } else {
415 1.0
416 };
417 }
418 }
419 HMMCovarianceType::Spherical => {
420 let avg_var = weighted_var.iter().sum::<f32>() / self.n_features as f32;
421 self.covariances[i][0] = if weight_total > 0.0 {
422 (avg_var / weight_total).max(1e-6)
423 } else {
424 1.0
425 };
426 }
427 }
428 }
429 }
430
431 fn emission_prob(&self, obs: &[f32], state: usize) -> f32 {
433 let mean = &self.means[state];
434 let cov = &self.covariances[state];
435
436 match self.covariance_type {
437 HMMCovarianceType::Diag | HMMCovarianceType::Full => {
438 let mut exponent = 0.0;
439 let mut det = 1.0;
440
441 for i in 0..self.n_features {
442 let diff = obs[i] - mean[i];
443 exponent += diff * diff / cov[i];
444 det *= cov[i];
445 }
446
447 let norm = 1.0 / ((2.0 * std::f32::consts::PI).powf(self.n_features as f32 / 2.0) * det.sqrt());
448 (norm * (-0.5 * exponent).exp()).max(1e-10)
449 }
450 HMMCovarianceType::Spherical => {
451 let variance = cov[0];
452 let mut exponent = 0.0;
453
454 for i in 0..self.n_features {
455 let diff = obs[i] - mean[i];
456 exponent += diff * diff;
457 }
458
459 let norm = 1.0 / ((2.0 * std::f32::consts::PI * variance).powf(self.n_features as f32 / 2.0));
460 (norm * (-exponent / (2.0 * variance)).exp()).max(1e-10)
461 }
462 }
463 }
464
465 fn compute_log_likelihood(&self, sequences: &[Tensor]) -> f32 {
467 let mut total_log_likelihood = 0.0;
468
469 for seq in sequences {
470 let seq_data = seq.data_f32();
471 let seq_len = seq.dims()[0];
472 let (_, log_likelihood) = self.forward(&seq_data, seq_len);
473 total_log_likelihood += log_likelihood;
474 }
475
476 total_log_likelihood
477 }
478
479 pub fn predict(&self, sequence: &Tensor) -> Tensor {
481 let seq_data = sequence.data_f32();
482 let seq_len = sequence.dims()[0];
483
484 let mut delta = vec![vec![0.0; self.n_components]; seq_len];
485 let mut psi = vec![vec![0; self.n_components]; seq_len];
486
487 for i in 0..self.n_components {
489 let obs = &seq_data[0..self.n_features];
490 delta[0][i] = self.start_prob[i].ln() + self.emission_prob(obs, i).ln();
491 }
492
493 for t in 1..seq_len {
495 for j in 0..self.n_components {
496 let mut max_val = f32::NEG_INFINITY;
497 let mut max_idx = 0;
498
499 for i in 0..self.n_components {
500 let val = delta[t - 1][i] + self.trans_prob[i][j].ln();
501 if val > max_val {
502 max_val = val;
503 max_idx = i;
504 }
505 }
506
507 let obs = &seq_data[t * self.n_features..(t + 1) * self.n_features];
508 delta[t][j] = max_val + self.emission_prob(obs, j).ln();
509 psi[t][j] = max_idx;
510 }
511 }
512
513 let mut path = vec![0; seq_len];
515 path[seq_len - 1] = delta[seq_len - 1]
516 .iter()
517 .enumerate()
518 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
519 .map(|(idx, _)| idx)
520 .unwrap();
521
522 for t in (0..seq_len - 1).rev() {
523 path[t] = psi[t + 1][path[t + 1]];
524 }
525
526 let path_f32: Vec<f32> = path.iter().map(|&x| x as f32).collect();
527 Tensor::from_slice(&path_f32, &[seq_len]).unwrap()
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: fit a 2-state model on a tiny two-cluster 2-D sequence
    /// and check that decoding returns one state per timestep.
    #[test]
    fn test_gaussian_hmm() {
        let observations = Tensor::from_slice(
            &[0.0f32, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1],
            &[4, 2],
        )
        .unwrap();
        let sequences = vec![observations];

        let mut hmm = GaussianHMM::new(2, 2)
            .covariance_type(HMMCovarianceType::Diag)
            .max_iter(20);
        hmm.fit(&sequences);

        let query = Tensor::from_slice(&[0.0f32, 0.0, 5.0, 5.0], &[2, 2]).unwrap();
        let states = hmm.predict(&query);
        assert_eq!(states.dims()[0], 2);
    }
}
557
558