1use crate::common::CovarianceType;
23use scirs2_core::ndarray::{s, Array1, Array2, ArrayView2};
24use scirs2_core::random::thread_rng;
25use sklears_core::{
26 error::{Result as SklResult, SklearsError},
27 traits::{Estimator, Fit, Predict, Untrained},
28 types::Float,
29};
30use std::f64::consts::PI;
31
/// Strategy for choosing the mini-batch size during training.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BatchStrategy {
    /// Always use a fixed batch size (capped at the number of samples).
    Fixed { size: usize },
    /// Start from `initial_size`, with `max_size` as the growth ceiling.
    /// NOTE(review): the fit loop in this file only reads `initial_size`;
    /// confirm whether adaptive growth is implemented elsewhere.
    Adaptive {
        initial_size: usize,
        max_size: usize,
    },
    /// Derive the batch size from a target memory budget in megabytes.
    Dynamic { target_memory_mb: usize },
}
45
/// Parallelization scheme for GMM training.
/// NOTE(review): in this file the strategy is only stored on `ParallelGMM`;
/// no code visible here consumes it — confirm usage elsewhere.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ParallelStrategy {
    /// Split the data across `n_threads` workers.
    DataParallel { n_threads: usize },
    /// Split the model (components) across `n_threads` workers.
    ModelParallel { n_threads: usize },
    /// Combine data- and model-parallelism.
    Hybrid {
        data_threads: usize,
        model_threads: usize,
    },
}
59
/// Mini-batch Gaussian Mixture Model with a type-state parameter `S`
/// (`Untrained` until `fit` is called).
///
/// NOTE(review): `S` appears only in `PhantomData`; the trained parameters
/// are not stored on this type (see `with_state`) — confirm whether a state
/// field was intended.
#[derive(Debug, Clone)]
pub struct MiniBatchGMM<S = Untrained> {
    // Number of mixture components.
    n_components: usize,
    // How mini-batch sizes are chosen during fitting.
    batch_strategy: BatchStrategy,
    // Requested covariance structure.
    covariance_type: CovarianceType,
    // Maximum number of passes over the data.
    max_iter: usize,
    // Convergence tolerance on the change in log-likelihood.
    tol: f64,
    // Regularization added to covariance diagonals to keep them positive.
    reg_covar: f64,
    // EMA step size for the online weight/mean updates.
    learning_rate: f64,
    // Momentum coefficient (stored; not read by the fit loop in this file).
    momentum: f64,
    // Optional RNG seed (stored; fit in this file uses `thread_rng`).
    random_state: Option<u64>,
    // Type-state marker.
    _phantom: std::marker::PhantomData<S>,
}
93
/// Parameters and diagnostics produced by fitting a [`MiniBatchGMM`].
#[derive(Debug, Clone)]
pub struct MiniBatchGMMTrained {
    /// Mixture weights, one per component (normalized to sum to 1 by fit).
    pub weights: Array1<f64>,
    /// Component means, shape `(n_components, n_features)`.
    pub means: Array2<f64>,
    /// A single covariance matrix (`n_features x n_features`) shared by all
    /// components; only its diagonal is read by the fitting code in this file.
    pub covariances: Array2<f64>,
    /// Log-likelihood value recorded once per outer iteration.
    pub log_likelihood_history: Vec<f64>,
    /// Actual size of every processed batch, in processing order.
    pub batch_sizes: Vec<usize>,
    /// Number of outer iterations performed.
    pub n_iter: usize,
    /// Whether the tolerance-based stopping criterion was met.
    pub converged: bool,
}
112
/// Builder for [`MiniBatchGMM`]; fields mirror the model's configuration and
/// are filled with defaults by [`MiniBatchGMMBuilder::new`].
#[derive(Debug, Clone)]
pub struct MiniBatchGMMBuilder {
    // See the corresponding fields on `MiniBatchGMM` for semantics.
    n_components: usize,
    batch_strategy: BatchStrategy,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    learning_rate: f64,
    momentum: f64,
    random_state: Option<u64>,
}
126
127impl MiniBatchGMMBuilder {
128 pub fn new() -> Self {
130 Self {
131 n_components: 1,
132 batch_strategy: BatchStrategy::Fixed { size: 256 },
133 covariance_type: CovarianceType::Diagonal,
134 max_iter: 100,
135 tol: 1e-3,
136 reg_covar: 1e-6,
137 learning_rate: 0.1,
138 momentum: 0.9,
139 random_state: None,
140 }
141 }
142
143 pub fn n_components(mut self, n: usize) -> Self {
145 self.n_components = n;
146 self
147 }
148
149 pub fn batch_strategy(mut self, strategy: BatchStrategy) -> Self {
151 self.batch_strategy = strategy;
152 self
153 }
154
155 pub fn covariance_type(mut self, cov_type: CovarianceType) -> Self {
157 self.covariance_type = cov_type;
158 self
159 }
160
161 pub fn max_iter(mut self, max_iter: usize) -> Self {
163 self.max_iter = max_iter;
164 self
165 }
166
167 pub fn tol(mut self, tol: f64) -> Self {
169 self.tol = tol;
170 self
171 }
172
173 pub fn learning_rate(mut self, lr: f64) -> Self {
175 self.learning_rate = lr;
176 self
177 }
178
179 pub fn momentum(mut self, m: f64) -> Self {
181 self.momentum = m;
182 self
183 }
184
185 pub fn build(self) -> MiniBatchGMM<Untrained> {
187 MiniBatchGMM {
188 n_components: self.n_components,
189 batch_strategy: self.batch_strategy,
190 covariance_type: self.covariance_type,
191 max_iter: self.max_iter,
192 tol: self.tol,
193 reg_covar: self.reg_covar,
194 learning_rate: self.learning_rate,
195 momentum: self.momentum,
196 random_state: self.random_state,
197 _phantom: std::marker::PhantomData,
198 }
199 }
200}
201
/// Equivalent to [`MiniBatchGMMBuilder::new`].
impl Default for MiniBatchGMMBuilder {
    fn default() -> Self {
        Self::new()
    }
}
207
impl MiniBatchGMM<Untrained> {
    /// Entry point for configuring a new mini-batch GMM.
    pub fn builder() -> MiniBatchGMMBuilder {
        MiniBatchGMMBuilder::new()
    }
}
214
impl Estimator for MiniBatchGMM<Untrained> {
    // Configuration happens through the builder, so the trait-level config
    // type is unit.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns the (unit) config; `&()` is valid here via rvalue static
    /// promotion of the zero-sized unit value.
    fn config(&self) -> &Self::Config {
        &()
    }
}
224
225impl Fit<ArrayView2<'_, Float>, ()> for MiniBatchGMM<Untrained> {
226 type Fitted = MiniBatchGMM<MiniBatchGMMTrained>;
227
228 #[allow(non_snake_case)]
229 fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
230 let X_owned = X.to_owned();
231 let (n_samples, n_features) = X_owned.dim();
232
233 if n_samples < self.n_components {
234 return Err(SklearsError::InvalidInput(
235 "Number of samples must be >= number of components".to_string(),
236 ));
237 }
238
239 let batch_size = match self.batch_strategy {
241 BatchStrategy::Fixed { size } => size.min(n_samples),
242 BatchStrategy::Adaptive { initial_size, .. } => initial_size.min(n_samples),
243 BatchStrategy::Dynamic { target_memory_mb } => {
244 let bytes_per_sample = n_features * 8; let target_bytes = target_memory_mb * 1024 * 1024;
247 (target_bytes / bytes_per_sample).min(n_samples)
248 }
249 };
250
251 let mut rng = thread_rng();
253 let mut means = Array2::zeros((self.n_components, n_features));
254 let mut used_indices = Vec::new();
255 for k in 0..self.n_components {
256 let idx = loop {
257 let candidate = rng.gen_range(0..n_samples);
258 if !used_indices.contains(&candidate) {
259 used_indices.push(candidate);
260 break candidate;
261 }
262 };
263 means.row_mut(k).assign(&X_owned.row(idx));
264 }
265
266 let mut weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);
267 let covariances =
268 Array2::<f64>::eye(n_features) + &(Array2::<f64>::eye(n_features) * self.reg_covar);
269
270 let mut log_likelihood_history = Vec::new();
271 let mut batch_sizes = Vec::new();
272 let mut converged = false;
273
274 for _iter in 0..self.max_iter {
276 for batch_start in (0..n_samples).step_by(batch_size) {
278 let batch_end = (batch_start + batch_size).min(n_samples);
279 let batch = X_owned.slice(s![batch_start..batch_end, ..]);
280
281 let batch_size_actual = batch_end - batch_start;
283 let mut responsibilities = Array2::zeros((batch_size_actual, self.n_components));
284
285 for i in 0..batch_size_actual {
286 let x = batch.row(i);
287 let mut log_probs = Vec::new();
288
289 for k in 0..self.n_components {
290 let mean = means.row(k);
291 let diff = &x.to_owned() - &mean.to_owned();
292
293 let mahal = diff
294 .iter()
295 .zip(covariances.diag().iter())
296 .map(|(d, c): (&f64, &f64)| d * d / c.max(self.reg_covar))
297 .sum::<f64>();
298
299 let log_det = covariances
300 .diag()
301 .iter()
302 .map(|c| c.max(self.reg_covar).ln())
303 .sum::<f64>();
304
305 let log_prob = weights[k].ln()
306 - 0.5 * (n_features as f64 * (2.0 * PI).ln() + log_det)
307 - 0.5 * mahal;
308
309 log_probs.push(log_prob);
310 }
311
312 let max_log = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
313 let sum_exp: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();
314
315 for k in 0..self.n_components {
316 responsibilities[[i, k]] =
317 ((log_probs[k] - max_log).exp() / sum_exp).max(1e-10);
318 }
319 }
320
321 for k in 0..self.n_components {
323 let resps = responsibilities.column(k);
324 let nk = resps.sum().max(1e-10);
325
326 let new_weight = nk / batch_size_actual as f64;
328 weights[k] =
329 (1.0 - self.learning_rate) * weights[k] + self.learning_rate * new_weight;
330
331 let mut batch_mean = Array1::zeros(n_features);
333 for i in 0..batch_size_actual {
334 batch_mean += &(batch.row(i).to_owned() * resps[i]);
335 }
336 batch_mean /= nk;
337
338 for j in 0..n_features {
339 means[[k, j]] = (1.0 - self.learning_rate) * means[[k, j]]
340 + self.learning_rate * batch_mean[j];
341 }
342 }
343
344 batch_sizes.push(batch_size_actual);
345 }
346
347 let weight_sum = weights.sum();
349 weights /= weight_sum;
350
351 let sample_size = 1000.min(n_samples);
353 let mut log_lik = 0.0;
354 for _i in 0..sample_size {
355 let mut sample_ll = 0.0;
356 for k in 0..self.n_components {
357 sample_ll += weights[k];
358 }
359 log_lik += sample_ll.max(1e-10).ln();
360 }
361 log_lik /= sample_size as f64;
362 log_likelihood_history.push(log_lik);
363
364 if log_likelihood_history.len() > 1 {
366 let improvement =
367 (log_lik - log_likelihood_history[log_likelihood_history.len() - 2]).abs();
368 if improvement < self.tol {
369 converged = true;
370 break;
371 }
372 }
373 }
374
375 let n_iter = log_likelihood_history.len();
376 let trained_state = MiniBatchGMMTrained {
377 weights,
378 means,
379 covariances,
380 log_likelihood_history,
381 batch_sizes,
382 n_iter,
383 converged,
384 };
385
386 Ok(MiniBatchGMM {
387 n_components: self.n_components,
388 batch_strategy: self.batch_strategy,
389 covariance_type: self.covariance_type,
390 max_iter: self.max_iter,
391 tol: self.tol,
392 reg_covar: self.reg_covar,
393 learning_rate: self.learning_rate,
394 momentum: self.momentum,
395 random_state: self.random_state,
396 _phantom: std::marker::PhantomData,
397 }
398 .with_state(trained_state))
399 }
400}
401
impl MiniBatchGMM<Untrained> {
    /// Transitions the type-state marker from `Untrained` to
    /// `MiniBatchGMMTrained`.
    ///
    /// NOTE(review): `_state` is dropped here — the struct has no field that
    /// can hold the trained parameters (the state type appears only in
    /// `PhantomData`), so fitted weights/means are not retrievable from the
    /// returned value. Confirm whether the trained state should be stored.
    fn with_state(self, _state: MiniBatchGMMTrained) -> MiniBatchGMM<MiniBatchGMMTrained> {
        MiniBatchGMM {
            n_components: self.n_components,
            batch_strategy: self.batch_strategy,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            learning_rate: self.learning_rate,
            momentum: self.momentum,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
    }
}
418
impl Predict<ArrayView2<'_, Float>, Array1<usize>> for MiniBatchGMM<MiniBatchGMMTrained> {
    /// Returns a component label for every row of `X`.
    ///
    /// NOTE(review): placeholder — always returns component 0 for every
    /// sample and never consults fitted parameters (which are not stored on
    /// this type; see `with_state`). A real implementation needs access to
    /// the trained weights/means.
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<usize>> {
        let (n_samples, _) = X.dim();
        Ok(Array1::zeros(n_samples))
    }
}
426
/// Gaussian Mixture Model configured for parallel training; the type-state
/// parameter `S` mirrors [`MiniBatchGMM`].
/// NOTE(review): no `Fit`/`Predict` implementations for this type are
/// visible in this file.
#[derive(Debug, Clone)]
pub struct ParallelGMM<S = Untrained> {
    // Number of mixture components.
    n_components: usize,
    // Requested parallelization scheme (stored; not consumed here).
    parallel_strategy: ParallelStrategy,
    // Type-state marker.
    _phantom: std::marker::PhantomData<S>,
}
434
/// Parameters produced by fitting a [`ParallelGMM`].
#[derive(Debug, Clone)]
pub struct ParallelGMMTrained {
    /// Mixture weights, one per component.
    pub weights: Array1<f64>,
    /// Component means.
    pub means: Array2<f64>,
}
440
/// Builder for [`ParallelGMM`].
#[derive(Debug, Clone)]
pub struct ParallelGMMBuilder {
    // Number of mixture components.
    n_components: usize,
    // Parallelization scheme to record on the built model.
    parallel_strategy: ParallelStrategy,
}
446
447impl ParallelGMMBuilder {
448 pub fn new() -> Self {
449 Self {
450 n_components: 1,
451 parallel_strategy: ParallelStrategy::DataParallel { n_threads: 4 },
452 }
453 }
454
455 pub fn n_components(mut self, n: usize) -> Self {
456 self.n_components = n;
457 self
458 }
459
460 pub fn parallel_strategy(mut self, strategy: ParallelStrategy) -> Self {
461 self.parallel_strategy = strategy;
462 self
463 }
464
465 pub fn build(self) -> ParallelGMM<Untrained> {
466 ParallelGMM {
467 n_components: self.n_components,
468 parallel_strategy: self.parallel_strategy,
469 _phantom: std::marker::PhantomData,
470 }
471 }
472}
473
/// Equivalent to [`ParallelGMMBuilder::new`].
impl Default for ParallelGMMBuilder {
    fn default() -> Self {
        Self::new()
    }
}
479
impl ParallelGMM<Untrained> {
    /// Entry point for configuring a new parallel GMM.
    pub fn builder() -> ParallelGMMBuilder {
        ParallelGMMBuilder::new()
    }
}
485
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_minibatch_gmm_builder() {
        // Explicitly configured values must be recorded verbatim.
        let gmm = MiniBatchGMM::builder()
            .n_components(3)
            .batch_strategy(BatchStrategy::Fixed { size: 128 })
            .learning_rate(0.05)
            .build();

        assert_eq!(gmm.n_components, 3);
        assert_eq!(gmm.batch_strategy, BatchStrategy::Fixed { size: 128 });
        assert_eq!(gmm.learning_rate, 0.05);
    }

    #[test]
    fn test_batch_strategy_types() {
        // Every batch-strategy variant must survive a round trip through
        // the builder unchanged.
        let cases = [
            BatchStrategy::Fixed { size: 100 },
            BatchStrategy::Adaptive {
                initial_size: 50,
                max_size: 500,
            },
            BatchStrategy::Dynamic {
                target_memory_mb: 100,
            },
        ];

        for &case in cases.iter() {
            let gmm = MiniBatchGMM::builder().batch_strategy(case).build();
            assert_eq!(gmm.batch_strategy, case);
        }
    }

    #[test]
    fn test_parallel_strategy_types() {
        // Every parallel-strategy variant must round-trip through the builder.
        let cases = [
            ParallelStrategy::DataParallel { n_threads: 4 },
            ParallelStrategy::ModelParallel { n_threads: 2 },
            ParallelStrategy::Hybrid {
                data_threads: 2,
                model_threads: 2,
            },
        ];

        for &case in cases.iter() {
            let gmm = ParallelGMM::builder().parallel_strategy(case).build();
            assert_eq!(gmm.parallel_strategy, case);
        }
    }

    #[test]
    fn test_minibatch_gmm_fit() {
        // Two well-separated clusters plus a middle pair; fit must succeed.
        let data = array![
            [1.0, 2.0],
            [1.5, 2.5],
            [10.0, 11.0],
            [10.5, 11.5],
            [5.0, 6.0],
            [5.5, 6.5]
        ];

        let gmm = MiniBatchGMM::builder()
            .n_components(2)
            .batch_strategy(BatchStrategy::Fixed { size: 3 })
            .max_iter(10)
            .build();

        assert!(gmm.fit(&data.view(), &()).is_ok());
    }

    #[test]
    fn test_builder_defaults() {
        // Unconfigured fields fall back to the builder's defaults.
        let gmm = MiniBatchGMM::builder().build();
        assert_eq!(gmm.n_components, 1);
        assert_eq!(gmm.learning_rate, 0.1);
        assert_eq!(gmm.momentum, 0.9);
    }

    #[test]
    fn test_parallel_gmm_builder() {
        let gmm = ParallelGMM::builder()
            .n_components(4)
            .parallel_strategy(ParallelStrategy::DataParallel { n_threads: 8 })
            .build();

        assert_eq!(gmm.n_components, 4);
    }
}