1use crate::common::CovarianceType;
25use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
26use scirs2_core::random::thread_rng;
27use sklears_core::{
28 error::{Result as SklResult, SklearsError},
29 traits::{Estimator, Fit, Predict, Untrained},
30 types::Float,
31};
32use std::f64::consts::PI;
33
/// Regularization penalty variants for GMM parameter estimation.
#[derive(Debug, Clone, PartialEq)]
pub enum RegularizationType {
    /// L1 (lasso) penalty with strength `lambda`; promotes sparse parameters.
    L1 { lambda: f64 },
    /// L2 (ridge) penalty with strength `lambda`; shrinks parameters smoothly.
    L2 { lambda: f64 },
    /// Elastic-net penalty: mix of L1 and L2 controlled by `l1_ratio`.
    ElasticNet { l1_ratio: f64, lambda: f64 },
    /// Group lasso penalty: zeroes out entire feature groups at once.
    GroupLasso {
        /// Overall regularization strength.
        lambda: f64,
        /// Feature indices belonging to each group.
        groups: Vec<Vec<usize>>,
    },
}
49
/// Gaussian mixture model with an L1 (lasso) penalty applied to the
/// component means during fitting.
///
/// `S` is a typestate marker (`Untrained` or `L1RegularizedGMMTrained`).
/// NOTE(review): `S` is held only as `PhantomData`, so the trained
/// parameters are NOT stored on this struct (see `with_state`).
#[derive(Debug, Clone)]
pub struct L1RegularizedGMM<S = Untrained> {
    /// Number of mixture components.
    n_components: usize,
    /// L1 regularization strength applied to component means.
    lambda: f64,
    /// Covariance parameterization; `fit` as written only uses a diagonal.
    covariance_type: CovarianceType,
    /// Maximum number of EM iterations.
    max_iter: usize,
    /// Convergence tolerance on the penalized log-likelihood.
    tol: f64,
    /// Value floored onto covariance diagonals for numerical stability.
    reg_covar: f64,
    /// Optional RNG seed — currently not consumed by `fit` (see NOTE there).
    random_state: Option<u64>,
    _phantom: std::marker::PhantomData<S>,
}
83
/// Parameters learned by fitting an [`L1RegularizedGMM`].
#[derive(Debug, Clone)]
pub struct L1RegularizedGMMTrained {
    /// Mixture weights, length `n_components`, normalized to sum to 1.
    pub weights: Array1<f64>,
    /// Component means, shape `(n_components, n_features)`.
    pub means: Array2<f64>,
    /// Shared diagonal covariance, shape `(n_features, n_features)`.
    pub covariances: Array2<f64>,
    /// Per-component mask of which mean entries are non-zero (|m| > 1e-10).
    pub sparsity_pattern: Vec<Vec<bool>>,
    /// Total count of non-zero mean entries across all components.
    pub n_nonzero: usize,
    /// Penalized log-likelihood recorded at each EM iteration.
    pub log_likelihood_history: Vec<f64>,
    /// Number of EM iterations actually performed.
    pub n_iter: usize,
    /// Whether the tolerance criterion was met before `max_iter`.
    pub converged: bool,
}
104
/// Fluent builder for [`L1RegularizedGMM`].
///
/// Fields mirror the model struct; see `L1RegularizedGMM` for meanings.
#[derive(Debug, Clone)]
pub struct L1RegularizedGMMBuilder {
    n_components: usize,
    lambda: f64,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
}
116
117impl L1RegularizedGMMBuilder {
118 pub fn new() -> Self {
120 Self {
121 n_components: 1,
122 lambda: 0.01,
123 covariance_type: CovarianceType::Diagonal,
124 max_iter: 100,
125 tol: 1e-3,
126 reg_covar: 1e-6,
127 random_state: None,
128 }
129 }
130
131 pub fn n_components(mut self, n_components: usize) -> Self {
133 self.n_components = n_components;
134 self
135 }
136
137 pub fn lambda(mut self, lambda: f64) -> Self {
139 self.lambda = lambda;
140 self
141 }
142
143 pub fn covariance_type(mut self, covariance_type: CovarianceType) -> Self {
145 self.covariance_type = covariance_type;
146 self
147 }
148
149 pub fn max_iter(mut self, max_iter: usize) -> Self {
151 self.max_iter = max_iter;
152 self
153 }
154
155 pub fn tol(mut self, tol: f64) -> Self {
157 self.tol = tol;
158 self
159 }
160
161 pub fn reg_covar(mut self, reg_covar: f64) -> Self {
163 self.reg_covar = reg_covar;
164 self
165 }
166
167 pub fn random_state(mut self, random_state: u64) -> Self {
169 self.random_state = Some(random_state);
170 self
171 }
172
173 pub fn build(self) -> L1RegularizedGMM<Untrained> {
175 L1RegularizedGMM {
176 n_components: self.n_components,
177 lambda: self.lambda,
178 covariance_type: self.covariance_type,
179 max_iter: self.max_iter,
180 tol: self.tol,
181 reg_covar: self.reg_covar,
182 random_state: self.random_state,
183 _phantom: std::marker::PhantomData,
184 }
185 }
186}
187
impl Default for L1RegularizedGMMBuilder {
    /// Equivalent to [`L1RegularizedGMMBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
193
194impl L1RegularizedGMM<Untrained> {
195 pub fn builder() -> L1RegularizedGMMBuilder {
197 L1RegularizedGMMBuilder::new()
198 }
199
200 fn soft_threshold(x: f64, lambda: f64) -> f64 {
202 if x > lambda {
203 x - lambda
204 } else if x < -lambda {
205 x + lambda
206 } else {
207 0.0
208 }
209 }
210}
211
impl Estimator for L1RegularizedGMM<Untrained> {
    // This estimator has no separate configuration object.
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    /// Returns a reference to the unit config (a promoted 'static value).
    fn config(&self) -> &Self::Config {
        &()
    }
}
221
222impl Fit<ArrayView2<'_, Float>, ()> for L1RegularizedGMM<Untrained> {
223 type Fitted = L1RegularizedGMM<L1RegularizedGMMTrained>;
224
225 #[allow(non_snake_case)]
226 fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
227 let X_owned = X.to_owned();
228 let (n_samples, n_features) = X_owned.dim();
229
230 if n_samples < self.n_components {
231 return Err(SklearsError::InvalidInput(
232 "Number of samples must be >= number of components".to_string(),
233 ));
234 }
235
236 let mut rng = thread_rng();
238 if let Some(_seed) = self.random_state {
239 }
241
242 let mut means = Array2::zeros((self.n_components, n_features));
243 let mut used_indices = Vec::new();
244 for k in 0..self.n_components {
245 let idx = loop {
246 let candidate = rng.gen_range(0..n_samples);
247 if !used_indices.contains(&candidate) {
248 used_indices.push(candidate);
249 break candidate;
250 }
251 };
252 means.row_mut(k).assign(&X_owned.row(idx));
253 }
254
255 let mut weights = Array1::from_elem(self.n_components, 1.0 / self.n_components as f64);
256 let mut covariances =
257 Array2::<f64>::eye(n_features) + &(Array2::<f64>::eye(n_features) * self.reg_covar);
258
259 let mut log_likelihood_history = Vec::new();
260 let mut converged = false;
261
262 for iter in 0..self.max_iter {
264 let mut responsibilities = Array2::zeros((n_samples, self.n_components));
266
267 for i in 0..n_samples {
268 let x = X_owned.row(i);
269 let mut log_probs = Vec::new();
270
271 for k in 0..self.n_components {
272 let mean = means.row(k);
273 let diff = &x.to_owned() - &mean.to_owned();
274
275 let mahal = diff
276 .iter()
277 .zip(covariances.diag().iter())
278 .map(|(d, c): (&f64, &f64)| d * d / c.max(self.reg_covar))
279 .sum::<f64>();
280
281 let log_det = covariances
282 .diag()
283 .iter()
284 .map(|c| c.max(self.reg_covar).ln())
285 .sum::<f64>();
286
287 let log_prob = weights[k].ln()
288 - 0.5 * (n_features as f64 * (2.0 * PI).ln() + log_det)
289 - 0.5 * mahal;
290
291 log_probs.push(log_prob);
292 }
293
294 let max_log = log_probs.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
295 let sum_exp: f64 = log_probs.iter().map(|&lp| (lp - max_log).exp()).sum();
296
297 for k in 0..self.n_components {
298 responsibilities[[i, k]] =
299 ((log_probs[k] - max_log).exp() / sum_exp).max(1e-10);
300 }
301 }
302
303 for k in 0..self.n_components {
305 let resps = responsibilities.column(k);
306 let nk = resps.sum().max(1e-10);
307
308 weights[k] = nk / n_samples as f64;
309
310 let mut new_mean = Array1::zeros(n_features);
312 for i in 0..n_samples {
313 new_mean += &(X_owned.row(i).to_owned() * resps[i]);
314 }
315 new_mean /= nk;
316
317 for j in 0..n_features {
319 new_mean[j] = Self::soft_threshold(new_mean[j], self.lambda);
320 }
321 means.row_mut(k).assign(&new_mean);
322
323 let mut new_cov = Array1::zeros(n_features);
325 for i in 0..n_samples {
326 let diff = &X_owned.row(i).to_owned() - &new_mean;
327 new_cov += &(diff.mapv(|x| x * x) * resps[i]);
328 }
329 new_cov = new_cov / nk + Array1::from_elem(n_features, self.reg_covar);
330 covariances.diag_mut().assign(&new_cov);
331 }
332
333 weights /= weights.sum();
334
335 let mut log_lik = 0.0;
337 for i in 0..n_samples {
338 let mut ll = 0.0;
339 for k in 0..self.n_components {
340 ll += responsibilities[[i, k]];
341 }
342 log_lik += ll.max(1e-10).ln();
343 }
344
345 let l1_penalty: f64 = means.iter().map(|&m| m.abs()).sum::<f64>() * self.lambda;
347 log_lik -= l1_penalty;
348
349 log_likelihood_history.push(log_lik);
350
351 if iter > 0 {
352 let improvement = (log_lik - log_likelihood_history[iter - 1]).abs();
353 if improvement < self.tol {
354 converged = true;
355 break;
356 }
357 }
358 }
359
360 let mut sparsity_pattern = Vec::new();
362 let mut n_nonzero = 0;
363 for k in 0..self.n_components {
364 let mut pattern = Vec::new();
365 for j in 0..n_features {
366 let is_nonzero = means[[k, j]].abs() > 1e-10;
367 pattern.push(is_nonzero);
368 if is_nonzero {
369 n_nonzero += 1;
370 }
371 }
372 sparsity_pattern.push(pattern);
373 }
374
375 let n_iter = log_likelihood_history.len();
376 let trained_state = L1RegularizedGMMTrained {
377 weights,
378 means,
379 covariances,
380 sparsity_pattern,
381 n_nonzero,
382 log_likelihood_history,
383 n_iter,
384 converged,
385 };
386
387 Ok(L1RegularizedGMM {
388 n_components: self.n_components,
389 lambda: self.lambda,
390 covariance_type: self.covariance_type,
391 max_iter: self.max_iter,
392 tol: self.tol,
393 reg_covar: self.reg_covar,
394 random_state: self.random_state,
395 _phantom: std::marker::PhantomData,
396 }
397 .with_state(trained_state))
398 }
399}
400
impl L1RegularizedGMM<Untrained> {
    /// Consume `self` and re-tag it with the trained typestate marker.
    ///
    /// NOTE(review): `_state` — the fitted weights/means/covariances — is
    /// dropped here, because the struct stores the state type only as
    /// `PhantomData`. The learned parameters are therefore unreachable
    /// after `fit`; consider adding a field to hold the trained state.
    fn with_state(
        self,
        _state: L1RegularizedGMMTrained,
    ) -> L1RegularizedGMM<L1RegularizedGMMTrained> {
        L1RegularizedGMM {
            n_components: self.n_components,
            lambda: self.lambda,
            covariance_type: self.covariance_type,
            max_iter: self.max_iter,
            tol: self.tol,
            reg_covar: self.reg_covar,
            random_state: self.random_state,
            _phantom: std::marker::PhantomData,
        }
    }
}
418
impl Predict<ArrayView2<'_, Float>, Array1<usize>> for L1RegularizedGMM<L1RegularizedGMMTrained> {
    /// Predict a component label for each row of `X`.
    ///
    /// NOTE(review): placeholder — always returns label 0 for every
    /// sample. A real assignment is impossible as written because the
    /// trained parameters are not stored on the model (see `with_state`).
    #[allow(non_snake_case)]
    fn predict(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array1<usize>> {
        let (n_samples, _) = X.dim();
        Ok(Array1::zeros(n_samples))
    }
}
426
/// Gaussian mixture model intended to carry an L2 (ridge) penalty.
///
/// NOTE(review): only the builder/configuration side is implemented in
/// this file — no `Fit` impl is visible for this type.
#[derive(Debug, Clone)]
pub struct L2RegularizedGMM<S = Untrained> {
    /// Number of mixture components.
    n_components: usize,
    /// L2 regularization strength.
    lambda: f64,
    /// Covariance parameterization.
    covariance_type: CovarianceType,
    /// Maximum number of EM iterations.
    max_iter: usize,
    /// Convergence tolerance.
    tol: f64,
    /// Covariance stability floor.
    reg_covar: f64,
    /// Optional RNG seed.
    random_state: Option<u64>,
    _phantom: std::marker::PhantomData<S>,
}
439
/// Parameters learned by fitting an [`L2RegularizedGMM`].
#[derive(Debug, Clone)]
pub struct L2RegularizedGMMTrained {
    /// Mixture weights, length `n_components`.
    pub weights: Array1<f64>,
    /// Component means, shape `(n_components, n_features)`.
    pub means: Array2<f64>,
    /// Covariance matrix.
    pub covariances: Array2<f64>,
    /// Log-likelihood recorded per iteration.
    pub log_likelihood_history: Vec<f64>,
    /// Number of iterations performed.
    pub n_iter: usize,
    /// Whether the tolerance criterion was met.
    pub converged: bool,
}
449
/// Fluent builder for [`L2RegularizedGMM`].
///
/// Fields mirror the model struct; see `L2RegularizedGMM` for meanings.
#[derive(Debug, Clone)]
pub struct L2RegularizedGMMBuilder {
    n_components: usize,
    lambda: f64,
    covariance_type: CovarianceType,
    max_iter: usize,
    tol: f64,
    reg_covar: f64,
    random_state: Option<u64>,
}
460
461impl L2RegularizedGMMBuilder {
462 pub fn new() -> Self {
463 Self {
464 n_components: 1,
465 lambda: 0.01,
466 covariance_type: CovarianceType::Diagonal,
467 max_iter: 100,
468 tol: 1e-3,
469 reg_covar: 1e-6,
470 random_state: None,
471 }
472 }
473
474 pub fn n_components(mut self, n: usize) -> Self {
475 self.n_components = n;
476 self
477 }
478
479 pub fn lambda(mut self, l: f64) -> Self {
480 self.lambda = l;
481 self
482 }
483
484 pub fn build(self) -> L2RegularizedGMM<Untrained> {
485 L2RegularizedGMM {
486 n_components: self.n_components,
487 lambda: self.lambda,
488 covariance_type: self.covariance_type,
489 max_iter: self.max_iter,
490 tol: self.tol,
491 reg_covar: self.reg_covar,
492 random_state: self.random_state,
493 _phantom: std::marker::PhantomData,
494 }
495 }
496}
497
impl Default for L2RegularizedGMMBuilder {
    /// Equivalent to [`L2RegularizedGMMBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
503
impl L2RegularizedGMM<Untrained> {
    /// Convenience entry point for the builder API.
    pub fn builder() -> L2RegularizedGMMBuilder {
        L2RegularizedGMMBuilder::new()
    }
}
509
/// Gaussian mixture model intended to carry an elastic-net penalty.
///
/// NOTE(review): only the builder/configuration side is implemented in
/// this file — no `Fit` impl is visible for this type.
#[derive(Debug, Clone)]
pub struct ElasticNetGMM<S = Untrained> {
    /// Number of mixture components.
    n_components: usize,
    /// L1/L2 mixing ratio.
    l1_ratio: f64,
    /// Overall regularization strength.
    lambda: f64,
    _phantom: std::marker::PhantomData<S>,
}
518
/// Parameters learned by fitting an [`ElasticNetGMM`].
#[derive(Debug, Clone)]
pub struct ElasticNetGMMTrained {
    /// Mixture weights, length `n_components`.
    pub weights: Array1<f64>,
    /// Component means, shape `(n_components, n_features)`.
    pub means: Array2<f64>,
}
524
/// Fluent builder for [`ElasticNetGMM`].
#[derive(Debug, Clone)]
pub struct ElasticNetGMMBuilder {
    n_components: usize,
    l1_ratio: f64,
    lambda: f64,
}
531
532impl ElasticNetGMMBuilder {
533 pub fn new() -> Self {
534 Self {
535 n_components: 1,
536 l1_ratio: 0.5,
537 lambda: 0.01,
538 }
539 }
540
541 pub fn n_components(mut self, n: usize) -> Self {
542 self.n_components = n;
543 self
544 }
545
546 pub fn l1_ratio(mut self, r: f64) -> Self {
547 self.l1_ratio = r;
548 self
549 }
550
551 pub fn lambda(mut self, l: f64) -> Self {
552 self.lambda = l;
553 self
554 }
555
556 pub fn build(self) -> ElasticNetGMM<Untrained> {
557 ElasticNetGMM {
558 n_components: self.n_components,
559 l1_ratio: self.l1_ratio,
560 lambda: self.lambda,
561 _phantom: std::marker::PhantomData,
562 }
563 }
564}
565
impl Default for ElasticNetGMMBuilder {
    /// Equivalent to [`ElasticNetGMMBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
571
impl ElasticNetGMM<Untrained> {
    /// Convenience entry point for the builder API.
    pub fn builder() -> ElasticNetGMMBuilder {
        ElasticNetGMMBuilder::new()
    }
}
577
/// Gaussian mixture model intended to carry a group-lasso penalty.
///
/// NOTE(review): only the builder/configuration side is implemented in
/// this file — no `Fit` impl is visible for this type.
#[derive(Debug, Clone)]
pub struct GroupLassoGMM<S = Untrained> {
    /// Number of mixture components.
    n_components: usize,
    /// Regularization strength.
    lambda: f64,
    /// Feature indices belonging to each group.
    groups: Vec<Vec<usize>>,
    _phantom: std::marker::PhantomData<S>,
}
586
/// Parameters learned by fitting a [`GroupLassoGMM`].
#[derive(Debug, Clone)]
pub struct GroupLassoGMMTrained {
    /// Mixture weights, length `n_components`.
    pub weights: Array1<f64>,
    /// Component means, shape `(n_components, n_features)`.
    pub means: Array2<f64>,
    /// Per-group flag of whether the group survived the penalty.
    pub active_groups: Vec<bool>,
}
593
/// Fluent builder for [`GroupLassoGMM`].
#[derive(Debug, Clone)]
pub struct GroupLassoGMMBuilder {
    n_components: usize,
    lambda: f64,
    groups: Vec<Vec<usize>>,
}
600
601impl GroupLassoGMMBuilder {
602 pub fn new() -> Self {
603 Self {
604 n_components: 1,
605 lambda: 0.01,
606 groups: Vec::new(),
607 }
608 }
609
610 pub fn n_components(mut self, n: usize) -> Self {
611 self.n_components = n;
612 self
613 }
614
615 pub fn lambda(mut self, l: f64) -> Self {
616 self.lambda = l;
617 self
618 }
619
620 pub fn add_group(mut self, group: Vec<usize>) -> Self {
621 self.groups.push(group);
622 self
623 }
624
625 pub fn build(self) -> GroupLassoGMM<Untrained> {
626 GroupLassoGMM {
627 n_components: self.n_components,
628 lambda: self.lambda,
629 groups: self.groups,
630 _phantom: std::marker::PhantomData,
631 }
632 }
633}
634
impl Default for GroupLassoGMMBuilder {
    /// Equivalent to [`GroupLassoGMMBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
640
impl GroupLassoGMM<Untrained> {
    /// Convenience entry point for the builder API.
    pub fn builder() -> GroupLassoGMMBuilder {
        GroupLassoGMMBuilder::new()
    }
}
646
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    // Soft-thresholding: shrink by lambda outside the dead zone,
    // clamp to exactly 0.0 inside it.
    #[test]
    fn test_soft_threshold() {
        assert_eq!(L1RegularizedGMM::soft_threshold(2.0, 0.5), 1.5);
        assert_eq!(L1RegularizedGMM::soft_threshold(-2.0, 0.5), -1.5);
        assert_eq!(L1RegularizedGMM::soft_threshold(0.3, 0.5), 0.0);
    }

    // Builder setters propagate into the built (untrained) model.
    #[test]
    fn test_l1_regularized_gmm_builder() {
        let model = L1RegularizedGMM::builder()
            .n_components(3)
            .lambda(0.05)
            .max_iter(50)
            .build();

        assert_eq!(model.n_components, 3);
        assert_eq!(model.lambda, 0.05);
        assert_eq!(model.max_iter, 50);
    }

    // Smoke test: fit succeeds on two well-separated clusters.
    // NOTE(review): only checks Ok; fitted parameters are not inspectable
    // because the trained state is discarded (see `with_state`).
    #[test]
    fn test_l1_regularized_gmm_fit() {
        let X = array![[1.0, 2.0], [1.5, 2.5], [10.0, 11.0], [10.5, 11.5]];

        let model = L1RegularizedGMM::builder()
            .n_components(2)
            .lambda(0.01)
            .max_iter(20)
            .build();

        let result = model.fit(&X.view(), &());
        assert!(result.is_ok());
    }

    // L2 builder: configuration round-trip.
    #[test]
    fn test_l2_regularized_gmm_builder() {
        let model = L2RegularizedGMM::builder()
            .n_components(2)
            .lambda(0.1)
            .build();

        assert_eq!(model.n_components, 2);
        assert_eq!(model.lambda, 0.1);
    }

    // Elastic-net builder: configuration round-trip.
    #[test]
    fn test_elastic_net_gmm_builder() {
        let model = ElasticNetGMM::builder()
            .n_components(3)
            .l1_ratio(0.7)
            .lambda(0.05)
            .build();

        assert_eq!(model.n_components, 3);
        assert_eq!(model.l1_ratio, 0.7);
        assert_eq!(model.lambda, 0.05);
    }

    // Group-lasso builder: groups accumulate in order.
    #[test]
    fn test_group_lasso_gmm_builder() {
        let model = GroupLassoGMM::builder()
            .n_components(2)
            .lambda(0.02)
            .add_group(vec![0, 1, 2])
            .add_group(vec![3, 4])
            .build();

        assert_eq!(model.n_components, 2);
        assert_eq!(model.lambda, 0.02);
        assert_eq!(model.groups.len(), 2);
    }

    // Enum variants construct and pattern-match as expected.
    #[test]
    fn test_regularization_type() {
        let l1 = RegularizationType::L1 { lambda: 0.1 };
        let l2 = RegularizationType::L2 { lambda: 0.2 };
        let enet = RegularizationType::ElasticNet {
            l1_ratio: 0.5,
            lambda: 0.15,
        };

        assert!(matches!(l1, RegularizationType::L1 { .. }));
        assert!(matches!(l2, RegularizationType::L2 { .. }));
        assert!(matches!(enet, RegularizationType::ElasticNet { .. }));
    }
}