// ghostflow_ml/hmm.rs

//! Hidden Markov Models (HMM)
//!
//! Statistical models for sequential data where the system is assumed to be
//! a Markov process with hidden states.

use ghostflow_core::Tensor;
use rand::prelude::*;

/// Hidden Markov Model with Gaussian emissions.
///
/// Parameters are learned with Baum-Welch (EM); the most likely hidden-state
/// sequence is decoded with Viterbi (`predict`).
pub struct GaussianHMM {
    pub n_components: usize,  // Number of hidden states
    pub n_features: usize,    // Dimensionality of observations
    pub covariance_type: HMMCovarianceType,
    pub max_iter: usize,      // Maximum Baum-Welch iterations per restart
    pub tol: f32,             // Convergence threshold on the log-likelihood change
    pub n_init: usize,        // Number of random restarts; the best run is kept
    
    // Model parameters (populated by `fit`)
    start_prob: Vec<f32>,           // Initial state probabilities (n_components,)
    trans_prob: Vec<Vec<f32>>,      // Transition probabilities (n_components, n_components)
    means: Vec<Vec<f32>>,           // Emission means (n_components, n_features)
    covariances: Vec<Vec<f32>>,     // Emission covariances; inner length depends on covariance_type
    converged: bool,                // True when an EM run met `tol` before exhausting `max_iter`
}
25
/// Shape of each state's Gaussian emission covariance.
#[derive(Clone, Copy)]
pub enum HMMCovarianceType {
    Diag,      // Diagonal covariance: one variance per feature
    Full,      // Full covariance (NOTE(review): currently treated identically to Diag in this file)
    Spherical, // Single variance shared by all features
}
32
33impl GaussianHMM {
34    pub fn new(n_components: usize, n_features: usize) -> Self {
35        Self {
36            n_components,
37            n_features,
38            covariance_type: HMMCovarianceType::Diag,
39            max_iter: 100,
40            tol: 1e-2,
41            n_init: 1,
42            start_prob: vec![1.0 / n_components as f32; n_components],
43            trans_prob: vec![vec![1.0 / n_components as f32; n_components]; n_components],
44            means: Vec::new(),
45            covariances: Vec::new(),
46            converged: false,
47        }
48    }
49
    /// Builder-style setter: choose the emission covariance shape.
    pub fn covariance_type(mut self, cov_type: HMMCovarianceType) -> Self {
        self.covariance_type = cov_type;
        self
    }
54
    /// Builder-style setter: maximum number of Baum-Welch iterations per restart.
    pub fn max_iter(mut self, iter: usize) -> Self {
        self.max_iter = iter;
        self
    }
59
60    /// Fit the HMM using Baum-Welch algorithm (EM for HMMs)
61    pub fn fit(&mut self, sequences: &[Tensor]) {
62        if sequences.is_empty() {
63            return;
64        }
65
66        let mut best_log_likelihood = f32::NEG_INFINITY;
67        let mut best_start_prob = Vec::new();
68        let mut best_trans_prob = Vec::new();
69        let mut best_means = Vec::new();
70        let mut best_covariances = Vec::new();
71
72        for _ in 0..self.n_init {
73            // Initialize parameters
74            self.initialize_parameters(sequences);
75
76            let mut prev_log_likelihood = f32::NEG_INFINITY;
77
78            // Baum-Welch algorithm
79            for _ in 0..self.max_iter {
80                // E-step: Forward-backward algorithm
81                let (log_likelihood, gamma, xi) = self.e_step(sequences);
82
83                // M-step: Update parameters
84                self.m_step(sequences, &gamma, &xi);
85
86                // Check convergence
87                if (log_likelihood - prev_log_likelihood).abs() < self.tol {
88                    self.converged = true;
89                    break;
90                }
91
92                prev_log_likelihood = log_likelihood;
93            }
94
95            // Keep best result
96            let final_log_likelihood = self.compute_log_likelihood(sequences);
97            if final_log_likelihood > best_log_likelihood {
98                best_log_likelihood = final_log_likelihood;
99                best_start_prob = self.start_prob.clone();
100                best_trans_prob = self.trans_prob.clone();
101                best_means = self.means.clone();
102                best_covariances = self.covariances.clone();
103            }
104        }
105
106        self.start_prob = best_start_prob;
107        self.trans_prob = best_trans_prob;
108        self.means = best_means;
109        self.covariances = best_covariances;
110    }
111
112    fn initialize_parameters(&mut self, sequences: &[Tensor]) {
113        let mut rng = thread_rng();
114
115        // Initialize start probabilities uniformly
116        self.start_prob = vec![1.0 / self.n_components as f32; self.n_components];
117
118        // Initialize transition probabilities uniformly
119        self.trans_prob = vec![vec![1.0 / self.n_components as f32; self.n_components]; self.n_components];
120
121        // Initialize means using k-means++ on all observations
122        let mut all_obs = Vec::new();
123        for seq in sequences {
124            let seq_data = seq.data_f32();
125            let seq_len = seq.dims()[0];
126            for t in 0..seq_len {
127                all_obs.push(seq_data[t * self.n_features..(t + 1) * self.n_features].to_vec());
128            }
129        }
130
131        self.means = Vec::with_capacity(self.n_components);
132        
133        // First mean: random observation
134        let first_idx = rng.gen_range(0..all_obs.len());
135        self.means.push(all_obs[first_idx].clone());
136
137        // Remaining means: k-means++ strategy
138        for _ in 1..self.n_components {
139            let mut distances = vec![f32::MAX; all_obs.len()];
140            
141            for (i, obs) in all_obs.iter().enumerate() {
142                let min_dist = self.means.iter()
143                    .map(|mean| {
144                        obs.iter().zip(mean.iter())
145                            .map(|(x, m)| (x - m).powi(2))
146                            .sum::<f32>()
147                    })
148                    .min_by(|a, b| a.partial_cmp(b).unwrap())
149                    .unwrap();
150                distances[i] = min_dist;
151            }
152
153            let total_dist: f32 = distances.iter().sum();
154            let mut cumsum = 0.0;
155            let rand_val = rng.gen::<f32>() * total_dist;
156            
157            let mut selected_idx = 0;
158            for (i, &dist) in distances.iter().enumerate() {
159                cumsum += dist;
160                if cumsum >= rand_val {
161                    selected_idx = i;
162                    break;
163                }
164            }
165
166            self.means.push(all_obs[selected_idx].clone());
167        }
168
169        // Initialize covariances
170        self.covariances = match self.covariance_type {
171            HMMCovarianceType::Diag | HMMCovarianceType::Full => {
172                (0..self.n_components)
173                    .map(|_| vec![1.0; self.n_features])
174                    .collect()
175            }
176            HMMCovarianceType::Spherical => {
177                (0..self.n_components)
178                    .map(|_| vec![1.0])
179                    .collect()
180            }
181        };
182    }
183
184    /// E-step: Forward-backward algorithm
185    fn e_step(&self, sequences: &[Tensor]) -> (f32, Vec<Vec<Vec<f32>>>, Vec<Vec<Vec<Vec<f32>>>>) {
186        let mut total_log_likelihood = 0.0;
187        let mut all_gamma = Vec::new();
188        let mut all_xi = Vec::new();
189
190        for seq in sequences {
191            let seq_data = seq.data_f32();
192            let seq_len = seq.dims()[0];
193
194            // Forward algorithm
195            let (alpha, log_likelihood) = self.forward(&seq_data, seq_len);
196            total_log_likelihood += log_likelihood;
197
198            // Backward algorithm
199            let beta = self.backward(&seq_data, seq_len);
200
201            // Calculate gamma (state probabilities)
202            let gamma = self.calculate_gamma(&alpha, &beta, seq_len);
203
204            // Calculate xi (transition probabilities)
205            let xi = self.calculate_xi(&alpha, &beta, &seq_data, seq_len);
206
207            all_gamma.push(gamma);
208            all_xi.push(xi);
209        }
210
211        (total_log_likelihood, all_gamma, all_xi)
212    }
213
214    /// Forward algorithm
215    fn forward(&self, seq_data: &[f32], seq_len: usize) -> (Vec<Vec<f32>>, f32) {
216        let mut alpha = vec![vec![0.0; self.n_components]; seq_len];
217        let mut scaling = vec![0.0; seq_len];
218
219        // Initialize
220        for i in 0..self.n_components {
221            let obs = &seq_data[0..self.n_features];
222            alpha[0][i] = self.start_prob[i] * self.emission_prob(obs, i);
223            scaling[0] += alpha[0][i];
224        }
225
226        // Scale
227        if scaling[0] > 0.0 {
228            for i in 0..self.n_components {
229                alpha[0][i] /= scaling[0];
230            }
231        }
232
233        // Recursion
234        for t in 1..seq_len {
235            for j in 0..self.n_components {
236                let mut sum = 0.0;
237                for i in 0..self.n_components {
238                    sum += alpha[t - 1][i] * self.trans_prob[i][j];
239                }
240                let obs = &seq_data[t * self.n_features..(t + 1) * self.n_features];
241                alpha[t][j] = sum * self.emission_prob(obs, j);
242                scaling[t] += alpha[t][j];
243            }
244
245            // Scale
246            if scaling[t] > 0.0 {
247                for j in 0..self.n_components {
248                    alpha[t][j] /= scaling[t];
249                }
250            }
251        }
252
253        // Calculate log likelihood
254        let log_likelihood: f32 = scaling.iter().map(|&s| s.max(1e-10).ln()).sum();
255
256        (alpha, log_likelihood)
257    }
258
259    /// Backward algorithm
260    fn backward(&self, seq_data: &[f32], seq_len: usize) -> Vec<Vec<f32>> {
261        let mut beta = vec![vec![0.0; self.n_components]; seq_len];
262
263        // Initialize
264        for i in 0..self.n_components {
265            beta[seq_len - 1][i] = 1.0;
266        }
267
268        // Recursion
269        for t in (0..seq_len - 1).rev() {
270            for i in 0..self.n_components {
271                let mut sum = 0.0;
272                for j in 0..self.n_components {
273                    let obs = &seq_data[(t + 1) * self.n_features..(t + 2) * self.n_features];
274                    sum += self.trans_prob[i][j] * self.emission_prob(obs, j) * beta[t + 1][j];
275                }
276                beta[t][i] = sum;
277            }
278
279            // Normalize
280            let total: f32 = beta[t].iter().sum();
281            if total > 0.0 {
282                for i in 0..self.n_components {
283                    beta[t][i] /= total;
284                }
285            }
286        }
287
288        beta
289    }
290
291    /// Calculate gamma (state probabilities)
292    fn calculate_gamma(&self, alpha: &[Vec<f32>], beta: &[Vec<f32>], seq_len: usize) -> Vec<Vec<f32>> {
293        let mut gamma = vec![vec![0.0; self.n_components]; seq_len];
294
295        for t in 0..seq_len {
296            let mut total = 0.0;
297            for i in 0..self.n_components {
298                gamma[t][i] = alpha[t][i] * beta[t][i];
299                total += gamma[t][i];
300            }
301
302            // Normalize
303            if total > 0.0 {
304                for i in 0..self.n_components {
305                    gamma[t][i] /= total;
306                }
307            }
308        }
309
310        gamma
311    }
312
313    /// Calculate xi (transition probabilities)
314    fn calculate_xi(&self, alpha: &[Vec<f32>], beta: &[Vec<f32>], seq_data: &[f32], seq_len: usize) -> Vec<Vec<Vec<f32>>> {
315        let mut xi = vec![vec![vec![0.0; self.n_components]; self.n_components]; seq_len - 1];
316
317        for t in 0..seq_len - 1 {
318            let mut total = 0.0;
319            for i in 0..self.n_components {
320                for j in 0..self.n_components {
321                    let obs = &seq_data[(t + 1) * self.n_features..(t + 2) * self.n_features];
322                    xi[t][i][j] = alpha[t][i] * self.trans_prob[i][j] * 
323                                  self.emission_prob(obs, j) * beta[t + 1][j];
324                    total += xi[t][i][j];
325                }
326            }
327
328            // Normalize
329            if total > 0.0 {
330                for i in 0..self.n_components {
331                    for j in 0..self.n_components {
332                        xi[t][i][j] /= total;
333                    }
334                }
335            }
336        }
337
338        xi
339    }
340
341    /// M-step: Update parameters
342    fn m_step(&mut self, sequences: &[Tensor], all_gamma: &[Vec<Vec<f32>>], all_xi: &[Vec<Vec<Vec<f32>>>]) {
343        // Update start probabilities
344        for i in 0..self.n_components {
345            self.start_prob[i] = all_gamma.iter().map(|gamma| gamma[0][i]).sum::<f32>() / sequences.len() as f32;
346        }
347
348        // Update transition probabilities
349        for i in 0..self.n_components {
350            let mut denom = 0.0;
351            for j in 0..self.n_components {
352                let mut numer = 0.0;
353                for xi in all_xi {
354                    for t in 0..xi.len() {
355                        numer += xi[t][i][j];
356                    }
357                }
358                
359                for gamma in all_gamma {
360                    for t in 0..gamma.len() - 1 {
361                        denom += gamma[t][i];
362                    }
363                }
364                
365                self.trans_prob[i][j] = if denom > 0.0 { numer / denom } else { 1.0 / self.n_components as f32 };
366            }
367        }
368
369        // Update emission parameters
370        for i in 0..self.n_components {
371            let mut weighted_sum = vec![0.0; self.n_features];
372            let mut weight_total = 0.0;
373
374            for (seq_idx, seq) in sequences.iter().enumerate() {
375                let seq_data = seq.data_f32();
376                let seq_len = seq.dims()[0];
377                let gamma = &all_gamma[seq_idx];
378
379                for t in 0..seq_len {
380                    let obs = &seq_data[t * self.n_features..(t + 1) * self.n_features];
381                    for j in 0..self.n_features {
382                        weighted_sum[j] += gamma[t][i] * obs[j];
383                    }
384                    weight_total += gamma[t][i];
385                }
386            }
387
388            // Update mean
389            for j in 0..self.n_features {
390                self.means[i][j] = if weight_total > 0.0 { weighted_sum[j] / weight_total } else { 0.0 };
391            }
392
393            // Update covariance
394            let mut weighted_var = vec![0.0; self.n_features];
395            for (seq_idx, seq) in sequences.iter().enumerate() {
396                let seq_data = seq.data_f32();
397                let seq_len = seq.dims()[0];
398                let gamma = &all_gamma[seq_idx];
399
400                for t in 0..seq_len {
401                    let obs = &seq_data[t * self.n_features..(t + 1) * self.n_features];
402                    for j in 0..self.n_features {
403                        let diff = obs[j] - self.means[i][j];
404                        weighted_var[j] += gamma[t][i] * diff * diff;
405                    }
406                }
407            }
408
409            match self.covariance_type {
410                HMMCovarianceType::Diag | HMMCovarianceType::Full => {
411                    for j in 0..self.n_features {
412                        self.covariances[i][j] = if weight_total > 0.0 { 
413                            (weighted_var[j] / weight_total).max(1e-6)
414                        } else { 
415                            1.0 
416                        };
417                    }
418                }
419                HMMCovarianceType::Spherical => {
420                    let avg_var = weighted_var.iter().sum::<f32>() / self.n_features as f32;
421                    self.covariances[i][0] = if weight_total > 0.0 { 
422                        (avg_var / weight_total).max(1e-6)
423                    } else { 
424                        1.0 
425                    };
426                }
427            }
428        }
429    }
430
431    /// Calculate emission probability
432    fn emission_prob(&self, obs: &[f32], state: usize) -> f32 {
433        let mean = &self.means[state];
434        let cov = &self.covariances[state];
435
436        match self.covariance_type {
437            HMMCovarianceType::Diag | HMMCovarianceType::Full => {
438                let mut exponent = 0.0;
439                let mut det = 1.0;
440                
441                for i in 0..self.n_features {
442                    let diff = obs[i] - mean[i];
443                    exponent += diff * diff / cov[i];
444                    det *= cov[i];
445                }
446
447                let norm = 1.0 / ((2.0 * std::f32::consts::PI).powf(self.n_features as f32 / 2.0) * det.sqrt());
448                (norm * (-0.5 * exponent).exp()).max(1e-10)
449            }
450            HMMCovarianceType::Spherical => {
451                let variance = cov[0];
452                let mut exponent = 0.0;
453                
454                for i in 0..self.n_features {
455                    let diff = obs[i] - mean[i];
456                    exponent += diff * diff;
457                }
458
459                let norm = 1.0 / ((2.0 * std::f32::consts::PI * variance).powf(self.n_features as f32 / 2.0));
460                (norm * (-exponent / (2.0 * variance)).exp()).max(1e-10)
461            }
462        }
463    }
464
465    /// Compute log likelihood
466    fn compute_log_likelihood(&self, sequences: &[Tensor]) -> f32 {
467        let mut total_log_likelihood = 0.0;
468
469        for seq in sequences {
470            let seq_data = seq.data_f32();
471            let seq_len = seq.dims()[0];
472            let (_, log_likelihood) = self.forward(&seq_data, seq_len);
473            total_log_likelihood += log_likelihood;
474        }
475
476        total_log_likelihood
477    }
478
479    /// Predict hidden state sequence using Viterbi algorithm
480    pub fn predict(&self, sequence: &Tensor) -> Tensor {
481        let seq_data = sequence.data_f32();
482        let seq_len = sequence.dims()[0];
483
484        let mut delta = vec![vec![0.0; self.n_components]; seq_len];
485        let mut psi = vec![vec![0; self.n_components]; seq_len];
486
487        // Initialize
488        for i in 0..self.n_components {
489            let obs = &seq_data[0..self.n_features];
490            delta[0][i] = self.start_prob[i].ln() + self.emission_prob(obs, i).ln();
491        }
492
493        // Recursion
494        for t in 1..seq_len {
495            for j in 0..self.n_components {
496                let mut max_val = f32::NEG_INFINITY;
497                let mut max_idx = 0;
498
499                for i in 0..self.n_components {
500                    let val = delta[t - 1][i] + self.trans_prob[i][j].ln();
501                    if val > max_val {
502                        max_val = val;
503                        max_idx = i;
504                    }
505                }
506
507                let obs = &seq_data[t * self.n_features..(t + 1) * self.n_features];
508                delta[t][j] = max_val + self.emission_prob(obs, j).ln();
509                psi[t][j] = max_idx;
510            }
511        }
512
513        // Backtrack
514        let mut path = vec![0; seq_len];
515        path[seq_len - 1] = delta[seq_len - 1]
516            .iter()
517            .enumerate()
518            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
519            .map(|(idx, _)| idx)
520            .unwrap();
521
522        for t in (0..seq_len - 1).rev() {
523            path[t] = psi[t + 1][path[t + 1]];
524        }
525
526        let path_f32: Vec<f32> = path.iter().map(|&x| x as f32).collect();
527        Tensor::from_slice(&path_f32, &[seq_len]).unwrap()
528    }
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: fit a 2-state HMM on a short 2-D sequence containing two
    // well-separated clusters, then decode a held-out sequence.
    #[test]
    fn test_gaussian_hmm() {
        // One sequence of 4 observations, 2 features each.
        let seq1 = Tensor::from_slice(
            &[0.0f32, 0.0, 0.1, 0.1, 5.0, 5.0, 5.1, 5.1],
            &[4, 2],
        ).unwrap();

        let sequences = vec![seq1];

        let mut hmm = GaussianHMM::new(2, 2)
            .covariance_type(HMMCovarianceType::Diag)
            .max_iter(20);

        hmm.fit(&sequences);

        let test_seq = Tensor::from_slice(&[0.0f32, 0.0, 5.0, 5.0], &[2, 2]).unwrap();
        let states = hmm.predict(&test_seq);

        // Only the output length is asserted; state labels are permutation-
        // dependent because initialization is random.
        assert_eq!(states.dims()[0], 2); // Number of observations
    }
}
557
558