1use crate::common::{CovarianceType, ModelSelection};
9use scirs2_core::ndarray::{s, Array1, Array2, Array3, Axis};
10use scirs2_core::random::{thread_rng, RandNormal, RandUniform};
11use sklears_core::{
12 error::{Result as SklResult, SklearsError},
13 traits::{Estimator, Fit, Predict, Trained, Untrained},
14};
15use std::collections::HashMap;
16
/// Strategy for combining the information carried by multiple modalities
/// when fitting and predicting with a multi-modal Gaussian mixture.
#[derive(Debug, Clone, PartialEq)]
pub enum FusionStrategy {
    /// Concatenate all modality features into one matrix and fit a single
    /// mixture over the combined feature space.
    EarlyFusion,
    /// Fit an independent mixture per modality and combine their predictions.
    LateFusion,
    /// Couple modalities through a shared latent space
    /// (requires `shared_latent_dim` to be set on the builder).
    IntermediateFusion,
    /// Fit all modalities jointly with shared (coupled) responsibilities.
    CoupledFusion,
}
29
/// Description of a single data modality (e.g. "visual", "textual").
#[derive(Debug, Clone)]
pub struct ModalitySpec {
    /// Unique name; input data is keyed by this name at fit/predict time.
    pub name: String,
    /// Expected number of feature columns for this modality.
    pub n_features: usize,
    /// Covariance structure used for this modality's Gaussians.
    pub covariance_type: CovarianceType,
    /// Relative influence of this modality when combining likelihoods
    /// (used by the coupled/late fusion paths).
    pub modality_weight: f64,
}
42
/// Full configuration of a multi-modal Gaussian mixture model.
#[derive(Debug, Clone)]
pub struct MultiModalConfig {
    /// Number of mixture components.
    pub n_components: usize,
    /// The modalities the model expects, in a fixed order.
    pub modalities: Vec<ModalitySpec>,
    /// How the modalities are combined.
    pub fusion_strategy: FusionStrategy,
    /// Dimension of the shared latent space; required for
    /// `FusionStrategy::IntermediateFusion`.
    pub shared_latent_dim: Option<usize>,
    /// Cross-modality coupling strength in [0, 1].
    pub coupling_strength: f64,
    /// Maximum number of EM iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the log-likelihood change.
    pub tol: f64,
    /// Additive variance regularization applied in covariance updates.
    pub regularization_strength: f64,
    /// Optional RNG seed.
    /// NOTE(review): currently not consumed by the fitting code, which always
    /// uses `thread_rng()` — confirm whether seeding was intended.
    pub random_state: Option<u64>,
}
65
/// Multi-modal Gaussian mixture model, parameterized by a typestate marker
/// (`Untrained` by default; see the `Fit` impl for training).
#[derive(Debug, Clone)]
pub struct MultiModalGaussianMixture<S = Untrained> {
    // Immutable model configuration built by the builder.
    config: MultiModalConfig,
    // Zero-sized typestate marker (Untrained / Trained).
    _phantom: std::marker::PhantomData<S>,
}
77
/// Parameters of a fitted multi-modal Gaussian mixture.
#[derive(Debug, Clone)]
pub struct MultiModalGaussianMixtureTrained {
    /// Mixing weights, length `n_components`.
    pub weights: Array1<f64>,
    /// Per-modality component means, `(n_components, n_features)` each.
    pub modality_means: HashMap<String, Array2<f64>>,
    /// Per-modality component covariances; the storage layout depends on the
    /// modality's `CovarianceType` (see `initialize_parameters`).
    pub modality_covariances: HashMap<String, Array3<f64>>,
    /// Component means in the shared latent space
    /// (only `Some` for intermediate fusion).
    pub shared_latent_means: Option<Array2<f64>>,
    /// Per-modality projections into the latent space
    /// (populated only by intermediate fusion).
    pub latent_projections: HashMap<String, Array2<f64>>,
    /// Modality-by-modality coupling matrix (empty for early/late fusion).
    pub coupling_parameters: Array2<f64>,
    /// Log-likelihood recorded per EM iteration (placeholder `[0.0]` for
    /// strategies that do not track it).
    pub log_likelihood_history: Vec<f64>,
    /// Number of iterations actually performed.
    pub n_iter: usize,
    /// The configuration the model was trained with.
    pub config: MultiModalConfig,
}
100
/// Fluent builder for [`MultiModalGaussianMixture`].
///
/// Collects the configuration incrementally and validates it in `build()`.
#[derive(Debug, Clone)]
pub struct MultiModalGaussianMixtureBuilder {
    n_components: usize,
    modalities: Vec<ModalitySpec>,
    fusion_strategy: FusionStrategy,
    shared_latent_dim: Option<usize>,
    coupling_strength: f64,
    max_iter: usize,
    tol: f64,
    regularization_strength: f64,
    random_state: Option<u64>,
}
114
115impl MultiModalGaussianMixtureBuilder {
116 pub fn new(n_components: usize) -> Self {
118 Self {
119 n_components,
120 modalities: Vec::new(),
121 fusion_strategy: FusionStrategy::IntermediateFusion,
122 shared_latent_dim: None,
123 coupling_strength: 0.1,
124 max_iter: 100,
125 tol: 1e-4,
126 regularization_strength: 0.01,
127 random_state: None,
128 }
129 }
130
131 pub fn add_modality(mut self, modality: ModalitySpec) -> Self {
133 self.modalities.push(modality);
134 self
135 }
136
137 pub fn add_modality_simple(mut self, name: &str, n_features: usize) -> Self {
139 let modality = ModalitySpec {
140 name: name.to_string(),
141 n_features,
142 covariance_type: CovarianceType::Full,
143 modality_weight: 1.0,
144 };
145 self.modalities.push(modality);
146 self
147 }
148
149 pub fn fusion_strategy(mut self, strategy: FusionStrategy) -> Self {
151 self.fusion_strategy = strategy;
152 self
153 }
154
155 pub fn shared_latent_dim(mut self, dim: usize) -> Self {
157 self.shared_latent_dim = Some(dim);
158 self
159 }
160
161 pub fn coupling_strength(mut self, strength: f64) -> Self {
163 self.coupling_strength = strength.clamp(0.0, 1.0);
164 self
165 }
166
167 pub fn max_iter(mut self, max_iter: usize) -> Self {
169 self.max_iter = max_iter;
170 self
171 }
172
173 pub fn tolerance(mut self, tol: f64) -> Self {
175 self.tol = tol;
176 self
177 }
178
179 pub fn regularization_strength(mut self, strength: f64) -> Self {
181 self.regularization_strength = strength.max(0.0);
182 self
183 }
184
185 pub fn random_state(mut self, random_state: u64) -> Self {
187 self.random_state = Some(random_state);
188 self
189 }
190
191 pub fn build(self) -> SklResult<MultiModalGaussianMixture<Untrained>> {
193 if self.modalities.is_empty() {
194 return Err(SklearsError::InvalidInput(
195 "At least one modality must be specified".to_string(),
196 ));
197 }
198
199 if self.fusion_strategy == FusionStrategy::IntermediateFusion
201 && self.shared_latent_dim.is_none()
202 {
203 return Err(SklearsError::InvalidInput(
204 "Shared latent dimension must be specified for intermediate fusion".to_string(),
205 ));
206 }
207
208 let config = MultiModalConfig {
209 n_components: self.n_components,
210 modalities: self.modalities,
211 fusion_strategy: self.fusion_strategy,
212 shared_latent_dim: self.shared_latent_dim,
213 coupling_strength: self.coupling_strength,
214 max_iter: self.max_iter,
215 tol: self.tol,
216 regularization_strength: self.regularization_strength,
217 random_state: self.random_state,
218 };
219
220 Ok(MultiModalGaussianMixture {
221 config,
222 _phantom: std::marker::PhantomData,
223 })
224 }
225}
226
// Estimator plumbing for the untrained typestate: exposes the configuration.
impl Estimator<Untrained> for MultiModalGaussianMixture<Untrained> {
    type Config = MultiModalConfig;
    type Error = SklearsError;
    type Float = f64;

    /// Borrow the model configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
236
// Estimator plumbing for the trained typestate.
// NOTE(review): fitting returns the separate `MultiModalGaussianMixtureTrained`
// struct, and no constructor for `MultiModalGaussianMixture<Trained>` is
// visible in this file — confirm this impl is actually reachable.
impl Estimator<Trained> for MultiModalGaussianMixture<Trained> {
    type Config = MultiModalConfig;
    type Error = SklearsError;
    type Float = f64;

    /// Borrow the model configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
246
247impl Fit<HashMap<String, Array2<f64>>, Option<Array1<usize>>>
248 for MultiModalGaussianMixture<Untrained>
249{
250 type Fitted = MultiModalGaussianMixtureTrained;
251
252 fn fit(
253 self,
254 X: &HashMap<String, Array2<f64>>,
255 y: &Option<Array1<usize>>,
256 ) -> SklResult<Self::Fitted> {
257 for modality in &self.config.modalities {
259 if !X.contains_key(&modality.name) {
260 return Err(SklearsError::InvalidInput(format!(
261 "Missing data for modality: {}",
262 modality.name
263 )));
264 }
265 let data = &X[&modality.name];
266 if data.ncols() != modality.n_features {
267 return Err(SklearsError::InvalidInput(format!(
268 "Feature dimension mismatch for modality {}: expected {}, got {}",
269 modality.name,
270 modality.n_features,
271 data.ncols()
272 )));
273 }
274 }
275
276 let n_samples = X.values().next().unwrap().nrows();
278 for (name, data) in X.iter() {
279 if data.nrows() != n_samples {
280 return Err(SklearsError::InvalidInput(format!(
281 "Sample size mismatch for modality {}: expected {}, got {}",
282 name,
283 n_samples,
284 data.nrows()
285 )));
286 }
287 }
288
289 match self.config.fusion_strategy {
290 FusionStrategy::EarlyFusion => self.fit_early_fusion(X, y),
291 FusionStrategy::LateFusion => self.fit_late_fusion(X, y),
292 FusionStrategy::IntermediateFusion => self.fit_intermediate_fusion(X, y),
293 FusionStrategy::CoupledFusion => self.fit_coupled_fusion(X, y),
294 }
295 }
296}
297
impl MultiModalGaussianMixture<Untrained> {
    /// Draw random initial parameters shared by all fitting strategies.
    ///
    /// Returns `(weights, means, covariances)` where:
    /// - `weights` is a uniform prior over the components;
    /// - `means[modality]` is `(n_components, n_features)`, each row copied
    ///   from a randomly chosen training sample (duplicates are possible);
    /// - `covariances[modality]` starts at identity, in a layout that depends
    ///   on the covariance type: `Full` -> (k, d, d), `Diagonal` -> (k, d, 1),
    ///   `Tied` -> (1, d, d), `Spherical` -> (k, 1, 1).
    ///
    /// NOTE(review): `config.random_state` is never consulted — `thread_rng()`
    /// makes initialization non-reproducible; confirm whether seeding was
    /// intended.
    fn initialize_parameters(
        &self,
        X: &HashMap<String, Array2<f64>>,
    ) -> SklResult<(
        Array1<f64>,
        HashMap<String, Array2<f64>>,
        HashMap<String, Array3<f64>>,
    )> {
        // Callers validate X beforehand, so the unwrap cannot fail here.
        let n_samples = X.values().next().unwrap().nrows();
        let n_components = self.config.n_components;

        // Uniform mixing weights.
        let weights = Array1::ones(n_components) / n_components as f64;

        let mut modality_means = HashMap::new();
        let mut modality_covariances = HashMap::new();

        let mut rng = thread_rng();

        for modality in &self.config.modalities {
            let data = &X[&modality.name];
            let n_features = data.ncols();

            // Pick a random data point as each component's initial mean.
            let mut means = Array2::zeros((n_components, n_features));
            for k in 0..n_components {
                // Presumably a half-open index range [0, n_samples) — verify
                // against the scirs2_core RandUniform contract.
                let uniform = RandUniform::new(0, n_samples).map_err(|e| {
                    SklearsError::InvalidInput(format!("Uniform distribution error: {}", e))
                })?;
                let sample_idx = rng.sample(uniform);
                means.row_mut(k).assign(&data.row(sample_idx));
            }

            // Identity covariance, stored in the type-specific layout.
            let covariances = match modality.covariance_type {
                CovarianceType::Full => {
                    let mut cov = Array3::zeros((n_components, n_features, n_features));
                    for k in 0..n_components {
                        for i in 0..n_features {
                            cov[[k, i, i]] = 1.0;
                        }
                    }
                    cov
                }
                CovarianceType::Diagonal => {
                    let mut cov = Array3::zeros((n_components, n_features, 1));
                    for k in 0..n_components {
                        for i in 0..n_features {
                            cov[[k, i, 0]] = 1.0;
                        }
                    }
                    cov
                }
                CovarianceType::Tied => {
                    // One covariance matrix shared by all components.
                    let mut cov = Array3::zeros((1, n_features, n_features));
                    for i in 0..n_features {
                        cov[[0, i, i]] = 1.0;
                    }
                    cov
                }
                CovarianceType::Spherical => Array3::ones((n_components, 1, 1)),
            };

            modality_means.insert(modality.name.clone(), means);
            modality_covariances.insert(modality.name.clone(), covariances);
        }

        Ok((weights, modality_means, modality_covariances))
    }

    /// Early fusion: concatenate all modality features into one matrix, run
    /// EM for a single diagonal-covariance mixture on it, then slice the
    /// learned parameters back into per-modality blocks.
    fn fit_early_fusion(
        &self,
        X: &HashMap<String, Array2<f64>>,
        _y: &Option<Array1<usize>>,
    ) -> SklResult<MultiModalGaussianMixtureTrained> {
        let n_samples = X.values().next().unwrap().nrows();

        // Collect modality matrices in configured order, then concatenate
        // horizontally (rows stay sample-aligned).
        let mut concatenated_features = Vec::new();
        for modality in &self.config.modalities {
            concatenated_features.push(X[&modality.name].clone());
        }

        let mut combined_data = concatenated_features[0].clone();
        for i in 1..concatenated_features.len() {
            let current_cols = combined_data.ncols();
            let new_cols = concatenated_features[i].ncols();
            let mut new_data = Array2::zeros((n_samples, current_cols + new_cols));
            new_data
                .slice_mut(s![.., ..current_cols])
                .assign(&combined_data);
            new_data
                .slice_mut(s![.., current_cols..])
                .assign(&concatenated_features[i]);
            combined_data = new_data;
        }

        // Only the weights from the shared initializer are reused; parameters
        // for the concatenated space are re-initialized below.
        let (mut weights, _means_map, _covariances_map) = self.initialize_parameters(X)?;
        let mut log_likelihood_history = Vec::new();

        // NOTE(review): combined means start at all-zeros (the random
        // per-modality initialization is discarded), so the first E-step
        // measures distances to the origin — confirm this is intended.
        let total_features: usize = self.config.modalities.iter().map(|m| m.n_features).sum();
        let mut combined_means = Array2::zeros((self.config.n_components, total_features));
        let mut combined_covariances =
            Array3::zeros((self.config.n_components, total_features, total_features));

        // Start from identity covariances.
        for k in 0..self.config.n_components {
            for i in 0..total_features {
                combined_covariances[[k, i, i]] = 1.0;
            }
        }

        for iter in 0..self.config.max_iter {
            let old_log_likelihood = if log_likelihood_history.is_empty() {
                f64::NEG_INFINITY
            } else {
                *log_likelihood_history.last().unwrap()
            };

            // E-step: per-sample component responsibilities.
            let mut responsibilities = Array2::zeros((n_samples, self.config.n_components));
            let mut log_likelihood = 0.0;

            for i in 0..n_samples {
                let sample = combined_data.row(i);
                let mut log_probs = Array1::zeros(self.config.n_components);

                for k in 0..self.config.n_components {
                    let mean = combined_means.row(k);
                    let diff = &sample.to_owned() - &mean.to_owned();
                    // Log-determinant from the covariance diagonal only.
                    let log_det = combined_covariances
                        .slice(s![k, .., ..])
                        .diag()
                        .mapv(|x: f64| x.ln())
                        .sum();
                    // NOTE(review): the quadratic term is plain squared
                    // Euclidean distance (implicit identity covariance) while
                    // log_det uses the stored variances — the density is
                    // therefore only an approximation.
                    let inv_quad = diff.dot(&diff);
                    log_probs[k] = weights[k].ln()
                        - 0.5
                            * (total_features as f64 * (2.0 * std::f64::consts::PI).ln()
                                + log_det
                                + inv_quad);
                }

                // Log-sum-exp for numerical stability.
                let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let log_sum_exp =
                    (log_probs.mapv(|x| (x - max_log_prob).exp()).sum()).ln() + max_log_prob;
                log_likelihood += log_sum_exp;

                for k in 0..self.config.n_components {
                    // Floor responsibilities to keep later divisions finite.
                    responsibilities[[i, k]] = ((log_probs[k] - log_sum_exp).exp()).max(1e-15);
                }
            }

            log_likelihood_history.push(log_likelihood);

            // Converged when the log-likelihood change falls below `tol`.
            if iter > 0 && (log_likelihood - old_log_likelihood).abs() < self.config.tol {
                break;
            }

            // M-step: effective counts, weights, then means.
            let n_k: Array1<f64> = responsibilities.sum_axis(Axis(0));

            weights = &n_k / n_samples as f64;

            for k in 0..self.config.n_components {
                if n_k[k] > 1e-15 {
                    let weighted_sum = responsibilities.column(k).iter().enumerate().fold(
                        Array1::zeros(total_features),
                        |mut acc, (i, &resp)| {
                            let sample = combined_data.row(i);
                            for j in 0..total_features {
                                acc[j] += resp * sample[j];
                            }
                            acc
                        },
                    );
                    combined_means.row_mut(k).assign(&(weighted_sum / n_k[k]));
                }
            }

            // Only the diagonal of the covariance is re-estimated
            // (off-diagonal terms stay zero), regularized and floored.
            for k in 0..self.config.n_components {
                if n_k[k] > 1e-15 {
                    for j in 0..total_features {
                        let mut weighted_var = 0.0;
                        for i in 0..n_samples {
                            let diff = combined_data[[i, j]] - combined_means[[k, j]];
                            weighted_var += responsibilities[[i, k]] * diff * diff;
                        }
                        combined_covariances[[k, j, j]] =
                            (weighted_var / n_k[k] + self.config.regularization_strength).max(1e-6);
                    }
                }
            }
        }

        // Slice the combined parameters back into per-modality blocks,
        // walking the feature columns in modality order.
        let mut final_means = HashMap::new();
        let mut final_covariances = HashMap::new();
        let mut feature_start = 0;

        for modality in &self.config.modalities {
            let n_features = modality.n_features;
            let modality_means = combined_means
                .slice(s![.., feature_start..feature_start + n_features])
                .to_owned();
            let modality_cov_slice = combined_covariances
                .slice(s![
                    ..,
                    feature_start..feature_start + n_features,
                    feature_start..feature_start + n_features
                ])
                .to_owned();

            final_means.insert(modality.name.clone(), modality_means);
            final_covariances.insert(modality.name.clone(), modality_cov_slice);
            feature_start += n_features;
        }

        let n_iter = log_likelihood_history.len();
        Ok(MultiModalGaussianMixtureTrained {
            weights,
            modality_means: final_means,
            modality_covariances: final_covariances,
            shared_latent_means: None,
            latent_projections: HashMap::new(),
            coupling_parameters: Array2::zeros((0, 0)),
            log_likelihood_history,
            n_iter,
            config: self.config.clone(),
        })
    }

    /// Late fusion: run an independent EM loop per modality.
    ///
    /// NOTE(review): the per-modality mixing weights (`modality_weights`)
    /// computed inside the loop are used during that modality's EM but
    /// discarded afterwards; the returned `weights` are the uniform values
    /// from initialization, and the likelihood history is a single 0.0
    /// placeholder. Confirm whether this is intended.
    fn fit_late_fusion(
        &self,
        X: &HashMap<String, Array2<f64>>,
        _y: &Option<Array1<usize>>,
    ) -> SklResult<MultiModalGaussianMixtureTrained> {
        let n_samples = X.values().next().unwrap().nrows();
        let (weights, mut modality_means, mut modality_covariances) =
            self.initialize_parameters(X)?;
        let mut log_likelihood_history = Vec::new();

        for modality in &self.config.modalities {
            let data = &X[&modality.name];
            let n_features = modality.n_features;
            let n_components = self.config.n_components;

            let mut modality_weights: Array1<f64> =
                Array1::ones(n_components) / n_components as f64;
            let mut means = modality_means[&modality.name].clone();
            let mut covariances = modality_covariances[&modality.name].clone();

            // Fixed-count EM: no convergence check, always max_iter passes.
            for _iter in 0..self.config.max_iter {
                // E-step.
                let mut responsibilities = Array2::zeros((n_samples, n_components));

                for i in 0..n_samples {
                    let sample = data.row(i);
                    let mut log_probs = Array1::zeros(n_components);

                    for k in 0..n_components {
                        let mean = means.row(k);
                        let diff = &sample.to_owned() - &mean.to_owned();

                        // Log-determinant respecting each covariance type's
                        // storage layout (diagonal entries only).
                        let log_det = match modality.covariance_type {
                            CovarianceType::Full => covariances
                                .slice(s![k, .., ..])
                                .diag()
                                .mapv(|x| x.ln())
                                .sum(),
                            CovarianceType::Diagonal => {
                                covariances.slice(s![k, .., 0]).mapv(|x| x.ln()).sum()
                            }
                            CovarianceType::Spherical => {
                                n_features as f64 * covariances[[k, 0, 0]].ln()
                            }
                            CovarianceType::Tied => covariances
                                .slice(s![0, .., ..])
                                .diag()
                                .mapv(|x| x.ln())
                                .sum(),
                        };

                        // NOTE(review): only the spherical case scales the
                        // quadratic term by its variance; the other types use
                        // plain Euclidean distance (implicit identity).
                        let inv_quad = match modality.covariance_type {
                            CovarianceType::Spherical => diff.dot(&diff) / covariances[[k, 0, 0]],
                            _ => diff.dot(&diff),
                        };

                        log_probs[k] = modality_weights[k].ln()
                            - 0.5
                                * (n_features as f64 * (2.0 * std::f64::consts::PI).ln()
                                    + log_det
                                    + inv_quad);
                    }

                    // Normalize via log-sum-exp.
                    let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let log_sum_exp =
                        (log_probs.mapv(|x| (x - max_log_prob).exp()).sum()).ln() + max_log_prob;

                    for k in 0..n_components {
                        responsibilities[[i, k]] = ((log_probs[k] - log_sum_exp).exp()).max(1e-15);
                    }
                }

                // M-step: weights and means.
                let n_k: Array1<f64> = responsibilities.sum_axis(Axis(0));
                modality_weights = &n_k / n_samples as f64;

                for k in 0..n_components {
                    if n_k[k] > 1e-15 {
                        let weighted_sum = responsibilities.column(k).iter().enumerate().fold(
                            Array1::zeros(n_features),
                            |mut acc, (i, &resp)| {
                                let sample = data.row(i);
                                for j in 0..n_features {
                                    acc[j] += resp * sample[j];
                                }
                                acc
                            },
                        );
                        means.row_mut(k).assign(&(weighted_sum / n_k[k]));
                    }
                }

                // Covariance update, regularized and floored at 1e-6.
                match modality.covariance_type {
                    CovarianceType::Spherical => {
                        for k in 0..n_components {
                            if n_k[k] > 1e-15 {
                                let mut weighted_var = 0.0;
                                for i in 0..n_samples {
                                    let sample = data.row(i);
                                    let mean = means.row(k);
                                    let diff = &sample.to_owned() - &mean.to_owned();
                                    weighted_var += responsibilities[[i, k]] * diff.dot(&diff);
                                }
                                covariances[[k, 0, 0]] = (weighted_var
                                    / (n_k[k] * n_features as f64)
                                    + self.config.regularization_strength)
                                    .max(1e-6);
                            }
                        }
                    }
                    _ => {
                        // Per-feature variance update.
                        // NOTE(review): for `Tied` (storage (1, d, d)) the
                        // index (k, j, j) is out of bounds for k >= 1 —
                        // verify this path for n_components > 1.
                        for k in 0..n_components {
                            if n_k[k] > 1e-15 {
                                for j in 0..n_features {
                                    let mut weighted_var = 0.0;
                                    for i in 0..n_samples {
                                        let diff = data[[i, j]] - means[[k, j]];
                                        weighted_var += responsibilities[[i, k]] * diff * diff;
                                    }
                                    let var_idx = match modality.covariance_type {
                                        CovarianceType::Diagonal => (k, j, 0),
                                        _ => (k, j, j),
                                    };
                                    covariances[var_idx] = (weighted_var / n_k[k]
                                        + self.config.regularization_strength)
                                        .max(1e-6);
                                }
                            }
                        }
                    }
                }
            }

            // Store this modality's fitted parameters.
            modality_means.insert(modality.name.clone(), means);
            modality_covariances.insert(modality.name.clone(), covariances);
        }

        // Placeholder history: per-modality likelihoods are not aggregated.
        log_likelihood_history.push(0.0);
        Ok(MultiModalGaussianMixtureTrained {
            weights,
            modality_means,
            modality_covariances,
            shared_latent_means: None,
            latent_projections: HashMap::new(),
            coupling_parameters: Array2::zeros((0, 0)),
            log_likelihood_history,
            n_iter: 1,
            config: self.config.clone(),
        })
    }

    /// Intermediate fusion: set up random per-modality projections into a
    /// shared latent space of dimension `shared_latent_dim`.
    ///
    /// NOTE(review): this is an initialization-only placeholder — no EM is
    /// run; means/covariances come straight from `initialize_parameters`, the
    /// projections and latent means are random, and the likelihood history is
    /// a single 0.0 entry.
    fn fit_intermediate_fusion(
        &self,
        X: &HashMap<String, Array2<f64>>,
        _y: &Option<Array1<usize>>,
    ) -> SklResult<MultiModalGaussianMixtureTrained> {
        let _n_samples = X.values().next().unwrap().nrows();
        // Safe: the builder guarantees shared_latent_dim is Some for this
        // strategy.
        let latent_dim = self.config.shared_latent_dim.unwrap();
        let (weights, modality_means, modality_covariances) = self.initialize_parameters(X)?;

        // Random linear maps (latent_dim x n_features) per modality with
        // N(0, 0.1) entries.
        let mut latent_projections = HashMap::new();
        let mut rng = thread_rng();

        for modality in &self.config.modalities {
            let normal = RandNormal::new(0.0, 0.1).map_err(|e| {
                SklearsError::InvalidInput(format!("Normal distribution error: {}", e))
            })?;
            let mut projection = Array2::zeros((latent_dim, modality.n_features));
            for i in 0..latent_dim {
                for j in 0..modality.n_features {
                    projection[[i, j]] = rng.sample(normal);
                }
            }
            latent_projections.insert(modality.name.clone(), projection);
        }

        // Standard-normal latent means per component.
        let mut shared_latent_means = Array2::zeros((self.config.n_components, latent_dim));
        for k in 0..self.config.n_components {
            for d in 0..latent_dim {
                let normal = RandNormal::new(0.0, 1.0).map_err(|e| {
                    SklearsError::InvalidInput(format!("Normal distribution error: {}", e))
                })?;
                shared_latent_means[[k, d]] = rng.sample(normal);
            }
        }

        let mut log_likelihood_history = Vec::new();
        // Symmetric coupling matrix: 1 on the diagonal, coupling_strength
        // elsewhere.
        let mut coupling_parameters =
            Array2::zeros((self.config.modalities.len(), self.config.modalities.len()));

        for i in 0..self.config.modalities.len() {
            coupling_parameters[[i, i]] = 1.0;
            for j in (i + 1)..self.config.modalities.len() {
                coupling_parameters[[i, j]] = self.config.coupling_strength;
                coupling_parameters[[j, i]] = self.config.coupling_strength;
            }
        }

        log_likelihood_history.push(0.0);

        Ok(MultiModalGaussianMixtureTrained {
            weights,
            modality_means,
            modality_covariances,
            shared_latent_means: Some(shared_latent_means),
            latent_projections,
            coupling_parameters,
            log_likelihood_history,
            n_iter: 1,
            config: self.config.clone(),
        })
    }

    /// Coupled fusion: a single EM loop in which every modality contributes
    /// to a shared ("global") set of responsibilities, scaled by its
    /// `modality_weight`.
    fn fit_coupled_fusion(
        &self,
        X: &HashMap<String, Array2<f64>>,
        _y: &Option<Array1<usize>>,
    ) -> SklResult<MultiModalGaussianMixtureTrained> {
        let n_samples = X.values().next().unwrap().nrows();
        let (mut weights, mut modality_means, mut modality_covariances) =
            self.initialize_parameters(X)?;
        let mut log_likelihood_history = Vec::new();

        // Symmetric coupling matrix: 1 on the diagonal, coupling_strength
        // elsewhere.
        // NOTE(review): only diagonal entries (always 1.0) are read below, so
        // the off-diagonal coupling currently has no effect on the fit.
        let n_modalities = self.config.modalities.len();
        let mut coupling_parameters = Array2::zeros((n_modalities, n_modalities));

        for i in 0..n_modalities {
            coupling_parameters[[i, i]] = 1.0;
            for j in (i + 1)..n_modalities {
                coupling_parameters[[i, j]] = self.config.coupling_strength;
                coupling_parameters[[j, i]] = self.config.coupling_strength;
            }
        }

        for iter in 0..self.config.max_iter {
            let old_log_likelihood = if log_likelihood_history.is_empty() {
                f64::NEG_INFINITY
            } else {
                *log_likelihood_history.last().unwrap()
            };

            let mut total_log_likelihood = 0.0;
            let mut global_responsibilities = Array2::zeros((n_samples, self.config.n_components));

            // E-step: per-modality responsibilities, accumulated into the
            // global matrix.
            for (modality_idx, modality) in self.config.modalities.iter().enumerate() {
                let data = &X[&modality.name];
                let means = &modality_means[&modality.name];
                let covariances = &modality_covariances[&modality.name];
                let mut modality_responsibilities =
                    Array2::zeros((n_samples, self.config.n_components));

                for i in 0..n_samples {
                    let sample = data.row(i);
                    let mut log_probs = Array1::zeros(self.config.n_components);

                    for k in 0..self.config.n_components {
                        let mean = means.row(k);
                        let diff = &sample.to_owned() - &mean.to_owned();

                        let (log_det, inv_quad) = match modality.covariance_type {
                            CovarianceType::Spherical => {
                                let variance = covariances[[k, 0, 0]];
                                let log_det = modality.n_features as f64 * variance.ln();
                                let inv_quad = diff.dot(&diff) / variance;
                                (log_det, inv_quad)
                            }
                            _ => {
                                // NOTE(review): `0.min(dim - 1)` is always 0,
                                // so this reads covariances[[k, j, 0]]. That
                                // matches the Diagonal layout (last axis = 1),
                                // but for Full it reads column 0 of row j
                                // rather than the diagonal element (j, j) —
                                // verify.
                                let log_det = (0..modality.n_features)
                                    .map(|j| {
                                        covariances[[k, j, 0.min(covariances.dim().2 - 1)]].ln()
                                    })
                                    .sum::<f64>();
                                // Quadratic term ignores the covariance
                                // (implicit identity), as elsewhere.
                                let inv_quad = diff.dot(&diff);
                                (log_det, inv_quad)
                            }
                        };

                        log_probs[k] = weights[k].ln()
                            - 0.5
                                * (modality.n_features as f64 * (2.0 * std::f64::consts::PI).ln()
                                    + log_det
                                    + inv_quad);
                    }

                    // Log-sum-exp normalization; modality_weight scales this
                    // modality's contribution to the joint likelihood.
                    let max_log_prob = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                    let log_sum_exp =
                        (log_probs.mapv(|x| (x - max_log_prob).exp()).sum()).ln() + max_log_prob;
                    total_log_likelihood += log_sum_exp * modality.modality_weight;

                    for k in 0..self.config.n_components {
                        modality_responsibilities[[i, k]] =
                            ((log_probs[k] - log_sum_exp).exp()).max(1e-15);
                    }
                }

                // Accumulate, scaled by modality weight and the (unit)
                // self-coupling term.
                for i in 0..n_samples {
                    for k in 0..self.config.n_components {
                        global_responsibilities[[i, k]] += modality.modality_weight
                            * coupling_parameters[[modality_idx, modality_idx]]
                            * modality_responsibilities[[i, k]];
                    }
                }
            }

            // Renormalize rows so responsibilities sum to one per sample.
            for i in 0..n_samples {
                let sum: f64 = global_responsibilities.row(i).sum();
                if sum > 1e-15 {
                    global_responsibilities.row_mut(i).mapv_inplace(|x| x / sum);
                }
            }

            log_likelihood_history.push(total_log_likelihood);

            // Converged when the joint likelihood change falls below `tol`.
            if iter > 0 && (total_log_likelihood - old_log_likelihood).abs() < self.config.tol {
                break;
            }

            // M-step driven by the shared responsibilities.
            let n_k: Array1<f64> = global_responsibilities.sum_axis(Axis(0));
            weights = &n_k / n_samples as f64;

            for modality in &self.config.modalities {
                let data = &X[&modality.name];
                let mut means = modality_means[&modality.name].clone();
                let mut covariances = modality_covariances[&modality.name].clone();

                // Means.
                for k in 0..self.config.n_components {
                    if n_k[k] > 1e-15 {
                        let weighted_sum = global_responsibilities
                            .column(k)
                            .iter()
                            .enumerate()
                            .fold(Array1::zeros(modality.n_features), |mut acc, (i, &resp)| {
                                let sample = data.row(i);
                                for j in 0..modality.n_features {
                                    acc[j] += resp * sample[j];
                                }
                                acc
                            });
                        means.row_mut(k).assign(&(weighted_sum / n_k[k]));
                    }
                }

                // Covariances, regularized and floored (same layout handling
                // as fit_late_fusion, including the Tied caveat noted there).
                match modality.covariance_type {
                    CovarianceType::Spherical => {
                        for k in 0..self.config.n_components {
                            if n_k[k] > 1e-15 {
                                let mut weighted_var = 0.0;
                                for i in 0..n_samples {
                                    let sample = data.row(i);
                                    let mean = means.row(k);
                                    let diff = &sample.to_owned() - &mean.to_owned();
                                    weighted_var +=
                                        global_responsibilities[[i, k]] * diff.dot(&diff);
                                }
                                covariances[[k, 0, 0]] = (weighted_var
                                    / (n_k[k] * modality.n_features as f64)
                                    + self.config.regularization_strength)
                                    .max(1e-6);
                            }
                        }
                    }
                    _ => {
                        for k in 0..self.config.n_components {
                            if n_k[k] > 1e-15 {
                                for j in 0..modality.n_features {
                                    let mut weighted_var = 0.0;
                                    for i in 0..n_samples {
                                        let diff = data[[i, j]] - means[[k, j]];
                                        weighted_var +=
                                            global_responsibilities[[i, k]] * diff * diff;
                                    }
                                    let var_idx = match modality.covariance_type {
                                        CovarianceType::Diagonal => (k, j, 0),
                                        _ => (k, j, j),
                                    };
                                    covariances[var_idx] = (weighted_var / n_k[k]
                                        + self.config.regularization_strength)
                                        .max(1e-6);
                                }
                            }
                        }
                    }
                }

                modality_means.insert(modality.name.clone(), means);
                modality_covariances.insert(modality.name.clone(), covariances);
            }
        }

        let n_iter = log_likelihood_history.len();
        Ok(MultiModalGaussianMixtureTrained {
            weights,
            modality_means,
            modality_covariances,
            shared_latent_means: None,
            latent_projections: HashMap::new(),
            coupling_parameters,
            log_likelihood_history,
            n_iter,
            config: self.config.clone(),
        })
    }
}
975
976impl Predict<HashMap<String, Array2<f64>>, Array1<usize>> for MultiModalGaussianMixtureTrained {
977 fn predict(&self, X: &HashMap<String, Array2<f64>>) -> SklResult<Array1<usize>> {
978 let probabilities = self.predict_proba(X)?;
979 let n_samples = probabilities.nrows();
980 let mut predictions = Array1::zeros(n_samples);
981
982 for i in 0..n_samples {
983 let mut max_prob = 0.0;
984 let mut best_component = 0;
985
986 for k in 0..self.config.n_components {
987 if probabilities[[i, k]] > max_prob {
988 max_prob = probabilities[[i, k]];
989 best_component = k;
990 }
991 }
992 predictions[i] = best_component;
993 }
994
995 Ok(predictions)
996 }
997}
998
impl MultiModalGaussianMixtureTrained {
    /// Posterior component probabilities for each sample, dispatched on the
    /// fusion strategy used at fit time.
    ///
    /// Returns an `(n_samples, n_components)` matrix.
    ///
    /// # Errors
    /// `InvalidInput` when a configured modality is missing from `X`.
    /// NOTE(review): unlike `fit`, feature widths and per-modality sample
    /// counts are not re-validated here.
    pub fn predict_proba(&self, X: &HashMap<String, Array2<f64>>) -> SklResult<Array2<f64>> {
        // Every configured modality must be present.
        for modality in &self.config.modalities {
            if !X.contains_key(&modality.name) {
                return Err(SklearsError::InvalidInput(format!(
                    "Missing data for modality: {}",
                    modality.name
                )));
            }
        }

        let n_samples = X.values().next().unwrap().nrows();
        let mut probabilities = Array2::zeros((n_samples, self.config.n_components));

        match self.config.fusion_strategy {
            FusionStrategy::EarlyFusion => {
                self.predict_proba_early_fusion(X, &mut probabilities)?;
            }
            FusionStrategy::LateFusion => {
                self.predict_proba_late_fusion(X, &mut probabilities)?;
            }
            FusionStrategy::IntermediateFusion => {
                self.predict_proba_intermediate_fusion(X, &mut probabilities)?;
            }
            FusionStrategy::CoupledFusion => {
                self.predict_proba_coupled_fusion(X, &mut probabilities)?;
            }
        }

        Ok(probabilities)
    }

    /// Early-fusion posterior.
    ///
    /// NOTE(review): placeholder — ignores `_X` entirely and returns the
    /// mixing weights (the prior) for every sample; confirm whether a real
    /// density evaluation over the concatenated features was intended.
    fn predict_proba_early_fusion(
        &self,
        _X: &HashMap<String, Array2<f64>>,
        probabilities: &mut Array2<f64>,
    ) -> SklResult<()> {
        let n_samples = probabilities.nrows();
        for i in 0..n_samples {
            probabilities.row_mut(i).assign(&self.weights);
        }
        Ok(())
    }

    /// Late-fusion posterior: per-modality scores from a unit-covariance
    /// (squared-Euclidean) log-density, normalized per modality, then summed
    /// with `modality_weight` and renormalized per sample.
    fn predict_proba_late_fusion(
        &self,
        X: &HashMap<String, Array2<f64>>,
        probabilities: &mut Array2<f64>,
    ) -> SklResult<()> {
        let n_samples = probabilities.nrows();
        probabilities.fill(0.0);

        for modality in &self.config.modalities {
            let data = &X[&modality.name];
            let means = &self.modality_means[&modality.name];

            for i in 0..n_samples {
                let sample = data.row(i);
                let mut modality_probs = Array1::zeros(self.config.n_components);

                for k in 0..self.config.n_components {
                    let mean = means.row(k);
                    let diff = &sample.to_owned() - &mean.to_owned();
                    // Unnormalized log-density with implicit identity
                    // covariance (the fitted covariances are not used here).
                    let log_prob = self.weights[k].ln() - 0.5 * diff.dot(&diff);
                    modality_probs[k] = log_prob.exp();
                }

                // Normalize this modality's contribution.
                let sum: f64 = modality_probs.sum();
                if sum > 1e-15 {
                    modality_probs.mapv_inplace(|x| x / sum);
                }

                for k in 0..self.config.n_components {
                    probabilities[[i, k]] += modality.modality_weight * modality_probs[k];
                }
            }
        }

        // Final per-sample normalization across components.
        for i in 0..n_samples {
            let sum: f64 = probabilities.row(i).sum();
            if sum > 1e-15 {
                probabilities.row_mut(i).mapv_inplace(|x| x / sum);
            }
        }

        Ok(())
    }

    /// Intermediate-fusion posterior.
    ///
    /// NOTE(review): placeholder — ignores `_X` and the latent projections,
    /// returning the prior weights for every sample.
    fn predict_proba_intermediate_fusion(
        &self,
        _X: &HashMap<String, Array2<f64>>,
        probabilities: &mut Array2<f64>,
    ) -> SklResult<()> {
        let n_samples = probabilities.nrows();
        for i in 0..n_samples {
            probabilities.row_mut(i).assign(&self.weights);
        }
        Ok(())
    }

    /// Coupled-fusion posterior: late-fusion probabilities with a coupling
    /// adjustment.
    ///
    /// NOTE(review): every entry in a row is scaled by the same factor
    /// `1 + coupling_strength` and then the row is renormalized, so the net
    /// result equals plain late fusion — the adjustment is currently a no-op.
    fn predict_proba_coupled_fusion(
        &self,
        X: &HashMap<String, Array2<f64>>,
        probabilities: &mut Array2<f64>,
    ) -> SklResult<()> {
        self.predict_proba_late_fusion(X, probabilities)?;

        let n_samples = probabilities.nrows();
        for i in 0..n_samples {
            for k in 0..self.config.n_components {
                probabilities[[i, k]] *= 1.0 + self.config.coupling_strength;
            }

            let sum: f64 = probabilities.row(i).sum();
            if sum > 1e-15 {
                probabilities.row_mut(i).mapv_inplace(|x| x / sum);
            }
        }

        Ok(())
    }

    /// Aggregate score: sum of log posterior probabilities over all samples
    /// and components (floored at 1e-15 before the log).
    ///
    /// NOTE(review): this is not the conventional data log-likelihood (which
    /// would log-sum weighted densities per sample); treat it as a relative
    /// score only.
    pub fn score(&self, X: &HashMap<String, Array2<f64>>) -> SklResult<f64> {
        let probabilities = self.predict_proba(X)?;
        let log_likelihood = probabilities.mapv(|p| p.max(1e-15).ln()).sum();
        Ok(log_likelihood)
    }

    /// AIC/BIC model-selection metrics.
    ///
    /// The parameter count assumes a single full-covariance mixture over the
    /// concatenated feature space, regardless of the per-modality covariance
    /// types — an approximation.
    pub fn model_selection(&self, X: &HashMap<String, Array2<f64>>) -> SklResult<ModelSelection> {
        let n_samples = X.values().next().unwrap().nrows();
        let total_features: usize = self.config.modalities.iter().map(|m| m.n_features).sum();

        let n_parameters = ModelSelection::n_parameters(
            self.config.n_components,
            total_features,
            &CovarianceType::Full,
        );

        let log_likelihood = self.score(X)?;
        let aic = ModelSelection::aic(log_likelihood, n_parameters);
        let bic = ModelSelection::bic(log_likelihood, n_parameters, n_samples);

        Ok(ModelSelection {
            aic,
            bic,
            log_likelihood,
            n_parameters,
        })
    }
}
1169
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Builds a two-modality fixture: 100 samples of 2-D "visual" features
    /// (linear ramp) and 100 samples of 3-D "textual" features (sine wave).
    fn create_test_multi_modal_data() -> HashMap<String, Array2<f64>> {
        let visual_values: Vec<f64> = (0..200).map(|i| i as f64 * 0.1).collect();
        let textual_values: Vec<f64> = (0..300).map(|i| (i as f64 * 0.05).sin()).collect();

        let mut data = HashMap::new();
        data.insert(
            "visual".to_string(),
            Array2::from_shape_vec((100, 2), visual_values).unwrap(),
        );
        data.insert(
            "textual".to_string(),
            Array2::from_shape_vec((100, 3), textual_values).unwrap(),
        );
        data
    }

    #[test]
    fn test_multi_modal_builder() {
        // The builder should faithfully record every configured option.
        let built = MultiModalGaussianMixtureBuilder::new(3)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .coupling_strength(0.2)
            .max_iter(10)
            .build()
            .unwrap();

        assert_eq!(built.config.n_components, 3);
        assert_eq!(built.config.modalities.len(), 2);
        assert_eq!(built.config.fusion_strategy, FusionStrategy::EarlyFusion);
        assert_abs_diff_eq!(built.config.coupling_strength, 0.2, epsilon = 1e-10);
    }

    #[test]
    fn test_early_fusion_fit() {
        let data = create_test_multi_modal_data();
        let untrained = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .max_iter(5)
            .build()
            .unwrap();

        let fitted = untrained.fit(&data, &None).unwrap();

        assert_eq!(fitted.weights.len(), 2);
        assert!(fitted.modality_means.contains_key("visual"));
        assert!(fitted.modality_means.contains_key("textual"));
        // Per-modality means are (n_components, n_features) = (2, 2) for "visual".
        assert_eq!(fitted.modality_means["visual"].nrows(), 2);
        assert_eq!(fitted.modality_means["visual"].ncols(), 2);
    }

    #[test]
    fn test_late_fusion_fit() {
        let data = create_test_multi_modal_data();
        let untrained = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::LateFusion)
            .max_iter(5)
            .build()
            .unwrap();

        let fitted = untrained.fit(&data, &None).unwrap();

        assert_eq!(fitted.weights.len(), 2);
        assert!(fitted.modality_means.contains_key("visual"));
        assert!(fitted.modality_means.contains_key("textual"));
    }

    #[test]
    fn test_intermediate_fusion_fit() {
        let data = create_test_multi_modal_data();
        let untrained = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::IntermediateFusion)
            .shared_latent_dim(4)
            .max_iter(5)
            .build()
            .unwrap();

        let fitted = untrained.fit(&data, &None).unwrap();

        assert_eq!(fitted.weights.len(), 2);
        // Intermediate fusion must populate the shared latent space:
        // one mean per component, each of the requested latent dimension.
        assert!(fitted.shared_latent_means.is_some());
        let latent_means = fitted.shared_latent_means.as_ref().unwrap();
        assert_eq!(latent_means.nrows(), 2);
        assert_eq!(latent_means.ncols(), 4);
    }

    #[test]
    fn test_coupled_fusion_fit() {
        let data = create_test_multi_modal_data();
        let untrained = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::CoupledFusion)
            .coupling_strength(0.3)
            .max_iter(5)
            .build()
            .unwrap();

        let fitted = untrained.fit(&data, &None).unwrap();

        assert_eq!(fitted.weights.len(), 2);
        // Coupling parameters form a square matrix over the modalities.
        assert_eq!(fitted.coupling_parameters.nrows(), 2);
        assert_eq!(fitted.coupling_parameters.ncols(), 2);
    }

    #[test]
    fn test_prediction() {
        let data = create_test_multi_modal_data();
        let untrained = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::LateFusion)
            .max_iter(3)
            .build()
            .unwrap();

        let fitted = untrained.fit(&data, &None).unwrap();
        let labels = fitted.predict(&data).unwrap();

        // One label per sample, each within the component index range.
        assert_eq!(labels.len(), 100);
        assert!(labels.iter().all(|&label| label < 2));
    }

    #[test]
    fn test_predict_proba() {
        let data = create_test_multi_modal_data();
        let untrained = MultiModalGaussianMixtureBuilder::new(3)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::LateFusion)
            .max_iter(3)
            .build()
            .unwrap();

        let fitted = untrained.fit(&data, &None).unwrap();
        let responsibilities = fitted.predict_proba(&data).unwrap();

        assert_eq!(responsibilities.nrows(), 100);
        assert_eq!(responsibilities.ncols(), 3);

        // Every row must be a proper probability distribution over components.
        for sample in 0..100 {
            let row_total: f64 = responsibilities.row(sample).sum();
            assert_abs_diff_eq!(row_total, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_model_selection() {
        let data = create_test_multi_modal_data();
        let untrained = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .max_iter(3)
            .build()
            .unwrap();

        let fitted = untrained.fit(&data, &None).unwrap();
        let selection = fitted.model_selection(&data).unwrap();

        // All criteria should be finite and the parameter count positive.
        assert!(selection.log_likelihood.is_finite());
        assert!(selection.aic.is_finite());
        assert!(selection.bic.is_finite());
        assert!(selection.n_parameters > 0);
    }

    #[test]
    fn test_validation_missing_modality() {
        // Drop a modality the model was configured with; fitting must fail.
        let mut data = create_test_multi_modal_data();
        data.remove("textual");

        let untrained = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .build()
            .unwrap();

        assert!(untrained.fit(&data, &None).is_err());
    }

    #[test]
    fn test_validation_feature_dimension_mismatch() {
        let data = create_test_multi_modal_data();

        // "visual" is declared as 5-D but the fixture provides 2-D data.
        let untrained = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 5)
            .add_modality_simple("textual", 3)
            .fusion_strategy(FusionStrategy::EarlyFusion)
            .build()
            .unwrap();

        assert!(untrained.fit(&data, &None).is_err());
    }

    #[test]
    fn test_intermediate_fusion_requires_latent_dim() {
        // Intermediate fusion without a shared latent dimension is invalid
        // and must be rejected at build time.
        let built = MultiModalGaussianMixtureBuilder::new(2)
            .add_modality_simple("visual", 2)
            .fusion_strategy(FusionStrategy::IntermediateFusion)
            .build();

        assert!(built.is_err());
    }
}