1use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
8use scirs2_core::random::{Random, Rng};
9use serde::{Deserialize, Serialize};
13use sklears_core::{
14 error::{Result as SklResult, SklearsError},
15 traits::{Estimator, Fit, Transform, Untrained},
16 types::Float,
17};
18use std::collections::HashMap;
19
/// Configuration for sampling-based imputation.
///
/// NOTE(review): `importance_sampling`, `use_quasi_random`,
/// `quasi_sequence_type`, `adaptive_sampling`, `confidence_level` and
/// `max_iterations` are not read by the imputers in this file — confirm
/// their consumers elsewhere before relying on them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
    /// Number of samples drawn per feature when building distributions.
    pub n_samples: usize,
    /// Sampling scheme used to build the per-feature distributions.
    pub strategy: SamplingStrategy,
    /// Whether to apply importance weighting (unused in this file).
    pub importance_sampling: bool,
    /// How sample weights are computed.
    pub weight_function: WeightFunction,
    /// Feature indices used for stratification (`None` = no stratification).
    pub stratify_by: Option<Vec<usize>>,
    /// Number of strata per stratification feature.
    pub n_strata: usize,
    /// Whether to use a quasi-random sequence (unused in this file).
    pub use_quasi_random: bool,
    /// Which low-discrepancy sequence to use (unused in this file).
    pub quasi_sequence_type: QuasiSequenceType,
    /// Whether to adapt sampling iteratively (unused in this file).
    pub adaptive_sampling: bool,
    /// Confidence level for interval estimates, e.g. 0.95.
    pub confidence_level: f64,
    /// Maximum number of adaptive-sampling iterations.
    pub max_iterations: usize,
}
46
impl Default for SamplingConfig {
    /// Defaults: 1000 simple samples, uniform weights, 5 strata, Halton
    /// quasi-sequence, no importance/quasi-random/adaptive sampling,
    /// 95% confidence, at most 100 iterations.
    fn default() -> Self {
        Self {
            n_samples: 1000,
            strategy: SamplingStrategy::Simple,
            importance_sampling: false,
            weight_function: WeightFunction::Uniform,
            stratify_by: None,
            n_strata: 5,
            use_quasi_random: false,
            quasi_sequence_type: QuasiSequenceType::Halton,
            adaptive_sampling: false,
            confidence_level: 0.95,
            max_iterations: 100,
        }
    }
}
64
/// Sampling scheme used to build per-feature distributions.
///
/// Only `Simple`, `Importance`, `Bootstrap` and `LatinHypercube` have
/// dedicated builders in this file; the remaining variants currently fall
/// back to simple sampling (see `compute_sample_statistics`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SamplingStrategy {
    /// Uniform random sampling with replacement.
    Simple,
    /// Stratified sampling (currently falls back to simple).
    Stratified,
    /// Cluster sampling (currently falls back to simple).
    Cluster,
    /// Systematic sampling (currently falls back to simple).
    Systematic,
    /// Inverse-density importance weighting.
    Importance,
    /// Latin hypercube sampling over the empirical CDF.
    LatinHypercube,
    /// Bootstrap resampling of the summary statistic.
    Bootstrap,
    /// Reservoir sampling (currently falls back to simple).
    Reservoir,
}
85
/// How weights are assigned to drawn samples.
///
/// NOTE(review): the builders in this file compute weights directly and do
/// not dispatch on this enum — confirm its consumers elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WeightFunction {
    /// Equal weight for every sample.
    Uniform,
    /// Weight by inverse selection probability.
    InverseProbability,
    /// Weight derived from an estimated density.
    DensityBased,
    /// Weight derived from a distance measure.
    DistanceBased,
    /// Caller-supplied weights.
    Custom(Vec<f64>),
}
100
/// Low-discrepancy sequence families for quasi-random sampling.
///
/// NOTE(review): not consumed by the imputers in this file — confirm use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum QuasiSequenceType {
    Halton,
    Sobol,
    Faure,
    Niederreiter,
}
113
/// Imputer that fills missing values by sampling from per-feature
/// distributions built with a configurable sampling scheme.
///
/// `S` is the typestate: `Untrained` before `fit`,
/// `SamplingSimpleImputerTrained` after.
#[derive(Debug)]
pub struct SamplingSimpleImputer<S = Untrained> {
    state: S,
    /// Summary-statistic strategy: "mean", "median" or "mode".
    strategy: String,
    /// Sentinel marking missing entries (NaN by default).
    missing_values: f64,
    config: SamplingConfig,
}
122
/// Fitted state for [`SamplingSimpleImputer`].
#[derive(Debug)]
pub struct SamplingSimpleImputerTrained {
    /// Per-feature summary statistic (mean/median/mode of the distribution).
    sample_statistics_: Array1<f64>,
    /// Per-feature sample distributions used to draw imputation values.
    sample_distributions_: Vec<SampleDistribution>,
    /// Number of features seen during fit; checked in `transform`.
    n_features_in_: usize,
    /// Sampling configuration captured at fit time.
    config: SamplingConfig,
}
131
/// Weighted empirical distribution for one feature.
#[derive(Debug, Clone)]
pub struct SampleDistribution {
    /// Sampled values.
    pub values: Vec<f64>,
    /// Per-value weights (normalized to sum to 1 by the builders).
    pub weights: Vec<f64>,
    /// Running sum of `weights`; enables inverse-CDF sampling.
    pub cumulative_weights: Vec<f64>,
    /// Representation used to build this distribution.
    pub distribution_type: DistributionType,
}
144
/// Result of stratified fitting: per-stratum, per-feature sample
/// distributions (keyed by the stratum index vector), plus the stratum
/// boundaries for each stratification feature.
type StratifiedDistributionsResult = Result<
    (
        HashMap<Vec<usize>, HashMap<usize, SampleDistribution>>,
        Vec<Array1<f64>>,
    ),
    SklearsError,
>;
153
/// How a fitted sample distribution is represented.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum DistributionType {
    /// Raw empirical sample with weights.
    Empirical,
    /// Built via kernel-density (importance) weighting.
    KernelDensity,
    /// Closed-form parametric family.
    Parametric(ParametricDistribution),
}
164
/// Parametric families available for [`DistributionType::Parametric`].
///
/// NOTE(review): no code in this file constructs these variants — confirm
/// their use elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ParametricDistribution {
    Normal { mean: f64, std: f64 },
    LogNormal { mean_log: f64, std_log: f64 },
    Exponential { rate: f64 },
    Gamma { shape: f64, rate: f64 },
    Beta { alpha: f64, beta: f64 },
    Uniform { low: f64, high: f64 },
}
181
/// Imputer that samples replacement values within strata defined by
/// selected feature columns.
///
/// `S` is the typestate: `Untrained` before `fit`,
/// `StratifiedSamplingImputerTrained` after.
#[derive(Debug)]
pub struct StratifiedSamplingImputer<S = Untrained> {
    state: S,
    /// Sentinel marking missing entries (NaN by default).
    missing_values: f64,
    config: SamplingConfig,
    /// Feature indices whose values define the strata.
    stratification_features: Vec<usize>,
}
190
/// Fitted state for [`StratifiedSamplingImputer`].
#[derive(Debug)]
pub struct StratifiedSamplingImputerTrained {
    /// Per-stratum, per-feature empirical distributions.
    strata_distributions_: HashMap<Vec<usize>, HashMap<usize, SampleDistribution>>,
    /// Quantile boundaries for each stratification feature.
    feature_strata_: Vec<Array1<f64>>,
    /// Number of features seen during fit; checked in `transform`.
    n_features_in_: usize,
    /// Sampling configuration captured at fit time.
    config: SamplingConfig,
}
199
/// Imputer intended to use importance sampling with a proposal
/// distribution.
///
/// NOTE(review): no `Fit`/`Transform` impls for this type are visible in
/// this file — confirm they live elsewhere.
#[derive(Debug)]
pub struct ImportanceSamplingImputer<S = Untrained> {
    state: S,
    /// Sentinel marking missing entries.
    missing_values: f64,
    config: SamplingConfig,
    /// Proposal distribution to draw candidate samples from.
    proposal_distribution: ProposalDistribution,
}
208
/// Fitted state for [`ImportanceSamplingImputer`].
///
/// NOTE(review): fields are not read in this file — confirm consumers.
#[derive(Debug)]
pub struct ImportanceSamplingImputerTrained {
    importance_weights_: Array2<f64>,
    proposal_samples_: Array2<f64>,
    target_density_: Array1<f64>,
    n_features_in_: usize,
    config: SamplingConfig,
}
218
/// Proposal distribution choices for importance sampling.
#[derive(Debug, Clone)]
pub enum ProposalDistribution {
    /// Use the empirical data distribution as the proposal.
    Empirical,
    /// Gaussian mixture with the given number of components.
    GaussianMixture { n_components: usize },
    /// Kernel density estimate with a fixed bandwidth.
    KernelDensity { bandwidth: f64 },
}
229
/// Imputer intended to refine its estimates iteratively.
///
/// NOTE(review): no `Fit`/`Transform` impls for this type are visible in
/// this file — confirm they live elsewhere.
#[derive(Debug)]
pub struct AdaptiveSamplingImputer<S = Untrained> {
    state: S,
    /// Sentinel marking missing entries.
    missing_values: f64,
    config: SamplingConfig,
    /// Threshold on the convergence criterion for stopping.
    convergence_threshold: f64,
}
238
/// Fitted state for [`AdaptiveSamplingImputer`].
///
/// NOTE(review): fields are not read in this file — confirm consumers.
#[derive(Debug)]
pub struct AdaptiveSamplingImputerTrained {
    adaptive_samples_: Vec<Array1<f64>>,
    convergence_history_: Vec<f64>,
    final_estimates_: Array1<f64>,
    confidence_intervals_: Array2<f64>,
    n_features_in_: usize,
    config: SamplingConfig,
}
249
250impl SamplingSimpleImputer<Untrained> {
251 pub fn new() -> Self {
252 Self {
253 state: Untrained,
254 strategy: "mean".to_string(),
255 missing_values: f64::NAN,
256 config: SamplingConfig::default(),
257 }
258 }
259
260 pub fn strategy(mut self, strategy: String) -> Self {
261 self.strategy = strategy;
262 self
263 }
264
265 pub fn sampling_config(mut self, config: SamplingConfig) -> Self {
266 self.config = config;
267 self
268 }
269
270 pub fn n_samples(mut self, n_samples: usize) -> Self {
271 self.config.n_samples = n_samples;
272 self
273 }
274
275 pub fn sampling_strategy(mut self, strategy: SamplingStrategy) -> Self {
276 self.config.strategy = strategy;
277 self
278 }
279
280 pub fn weight_function(mut self, weight_function: WeightFunction) -> Self {
281 self.config.weight_function = weight_function;
282 self
283 }
284
285 fn is_missing(&self, value: f64) -> bool {
286 if self.missing_values.is_nan() {
287 value.is_nan()
288 } else {
289 (value - self.missing_values).abs() < f64::EPSILON
290 }
291 }
292}
293
impl Default for SamplingSimpleImputer<Untrained> {
    /// Equivalent to [`SamplingSimpleImputer::new`].
    fn default() -> Self {
        Self::new()
    }
}
299
impl Estimator for SamplingSimpleImputer<Untrained> {
    type Config = SamplingConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Borrow the sampling configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
309
310impl Fit<ArrayView2<'_, Float>, ()> for SamplingSimpleImputer<Untrained> {
311 type Fitted = SamplingSimpleImputer<SamplingSimpleImputerTrained>;
312
313 #[allow(non_snake_case)]
314 fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
315 let X = X.mapv(|x| x);
316 let (_n_samples, n_features) = X.dim();
317
318 let (sample_statistics, sample_distributions) = self.compute_sample_statistics(&X)?;
319
320 Ok(SamplingSimpleImputer {
321 state: SamplingSimpleImputerTrained {
322 sample_statistics_: sample_statistics,
323 sample_distributions_: sample_distributions,
324 n_features_in_: n_features,
325 config: self.config,
326 },
327 strategy: self.strategy,
328 missing_values: self.missing_values,
329 config: Default::default(),
330 })
331 }
332}
333
334impl Transform<ArrayView2<'_, Float>, Array2<Float>>
335 for SamplingSimpleImputer<SamplingSimpleImputerTrained>
336{
337 #[allow(non_snake_case)]
338 fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
339 let X = X.mapv(|x| x);
340 let (n_samples, n_features) = X.dim();
341
342 if n_features != self.state.n_features_in_ {
343 return Err(SklearsError::InvalidInput(format!(
344 "Number of features {} does not match training features {}",
345 n_features, self.state.n_features_in_
346 )));
347 }
348
349 let mut X_imputed = X.clone();
350
351 for i in 0..n_samples {
353 for j in 0..n_features {
354 if self.is_missing(X_imputed[[i, j]]) {
355 let imputed_value = self.sample_imputed_value(j)?;
356 X_imputed[[i, j]] = imputed_value;
357 }
358 }
359 }
360
361 Ok(X_imputed.mapv(|x| x as Float))
362 }
363}
364
impl SamplingSimpleImputer<Untrained> {
    /// Build, for every feature, a sample distribution from the observed
    /// (non-missing) values plus a scalar summary statistic chosen by
    /// `self.strategy` ("mean" / "median" / "mode").
    fn compute_sample_statistics(
        &self,
        X: &Array2<f64>,
    ) -> Result<(Array1<f64>, Vec<SampleDistribution>), SklearsError> {
        let (_, n_features) = X.dim();
        let mut sample_statistics = Array1::<f64>::zeros(n_features);
        let mut sample_distributions = Vec::new();

        for j in 0..n_features {
            let column = X.column(j);
            let valid_values: Vec<f64> = column
                .iter()
                .filter(|&&x| !self.is_missing(x))
                .cloned()
                .collect();

            // Entirely-missing column: fall back to a degenerate
            // distribution at 0.0 so transform still produces a value.
            if valid_values.is_empty() {
                sample_statistics[j] = 0.0;
                sample_distributions.push(SampleDistribution {
                    values: vec![0.0],
                    weights: vec![1.0],
                    cumulative_weights: vec![1.0],
                    distribution_type: DistributionType::Empirical,
                });
                continue;
            }

            // Strategies without a dedicated builder (Stratified, Cluster,
            // Systematic, Reservoir) fall back to simple sampling.
            let distribution = match self.config.strategy {
                SamplingStrategy::Simple => {
                    self.create_simple_sample_distribution(&valid_values)?
                }
                SamplingStrategy::Importance => {
                    self.create_importance_sample_distribution(&valid_values)?
                }
                SamplingStrategy::Bootstrap => {
                    self.create_bootstrap_sample_distribution(&valid_values)?
                }
                SamplingStrategy::LatinHypercube => {
                    self.create_latin_hypercube_distribution(&valid_values)?
                }
                _ => self.create_simple_sample_distribution(&valid_values)?,
            };

            // Unknown strategy strings default to the weighted mean.
            sample_statistics[j] = match self.strategy.as_str() {
                "mean" => self.compute_weighted_mean(&distribution),
                "median" => self.compute_weighted_median(&distribution),
                "mode" => self.compute_weighted_mode(&distribution),
                _ => self.compute_weighted_mean(&distribution),
            };

            sample_distributions.push(distribution);
        }

        Ok((sample_statistics, sample_distributions))
    }

    /// Simple random sampling with replacement: draw up to
    /// `config.n_samples` values uniformly and weight them equally.
    fn create_simple_sample_distribution(
        &self,
        values: &[f64],
    ) -> Result<SampleDistribution, SklearsError> {
        let n_samples = self.config.n_samples.min(values.len());
        let mut rng = Random::default();

        let mut sampled_values = Vec::new();
        let mut weights = Vec::new();

        for _ in 0..n_samples {
            let idx = rng.gen_range(0..values.len());
            sampled_values.push(values[idx]);
            weights.push(1.0 / n_samples as f64);
        }

        // Running sum of weights; the last entry reaches 1.0 (up to
        // round-off), enabling inverse-CDF sampling at transform time.
        let mut cumulative_weights = Vec::new();
        let mut cumsum = 0.0;
        for &weight in &weights {
            cumsum += weight;
            cumulative_weights.push(cumsum);
        }

        Ok(SampleDistribution {
            values: sampled_values,
            weights,
            cumulative_weights,
            distribution_type: DistributionType::Empirical,
        })
    }

    /// Importance sampling: weight each value by the inverse of its kernel
    /// density estimate, so low-density (rare) values get larger weights.
    fn create_importance_sample_distribution(
        &self,
        values: &[f64],
    ) -> Result<SampleDistribution, SklearsError> {
        let n_samples = self.config.n_samples.min(values.len());

        let mut weighted_values = Vec::new();
        let mut importance_weights = Vec::new();

        let bandwidth = self.compute_bandwidth(values);

        // NOTE(review): this takes the first `n_samples` values in input
        // order rather than drawing randomly — confirm this is intended.
        for &value in values.iter().take(n_samples) {
            let density = self.kernel_density_estimate(value, values, bandwidth);
            // 1e-8 guards against division by zero where density ~ 0.
            let importance_weight = 1.0 / (density + 1e-8);
            weighted_values.push(value);
            importance_weights.push(importance_weight);
        }

        // Normalize the weights to sum to 1.
        let total_weight: f64 = importance_weights.iter().sum();
        for weight in &mut importance_weights {
            *weight /= total_weight;
        }

        let mut cumulative_weights = Vec::new();
        let mut cumsum = 0.0;
        for &weight in &importance_weights {
            cumsum += weight;
            cumulative_weights.push(cumsum);
        }

        Ok(SampleDistribution {
            values: weighted_values,
            weights: importance_weights,
            cumulative_weights,
            distribution_type: DistributionType::KernelDensity,
        })
    }

    /// Bootstrap: resample the data 100 times, compute the per-resample
    /// statistic (mean or median per `self.strategy`), and store the
    /// resulting estimates as an equally-weighted distribution.
    fn create_bootstrap_sample_distribution(
        &self,
        values: &[f64],
    ) -> Result<SampleDistribution, SklearsError> {
        let n_bootstrap = 100;
        let n_samples_per_bootstrap = values.len();
        let mut bootstrap_estimates = Vec::new();

        let mut rng = Random::default();

        for _ in 0..n_bootstrap {
            let mut bootstrap_sample = Vec::new();

            // Resample with replacement, same size as the original data.
            for _ in 0..n_samples_per_bootstrap {
                let idx = rng.gen_range(0..values.len());
                bootstrap_sample.push(values[idx]);
            }

            // "mode" (and any other string) falls through to the mean here.
            let estimate = match self.strategy.as_str() {
                "mean" => bootstrap_sample.iter().sum::<f64>() / bootstrap_sample.len() as f64,
                "median" => {
                    let mut sorted = bootstrap_sample.clone();
                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
                    let mid = sorted.len() / 2;
                    if sorted.len() % 2 == 0 {
                        (sorted[mid - 1] + sorted[mid]) / 2.0
                    } else {
                        sorted[mid]
                    }
                }
                _ => bootstrap_sample.iter().sum::<f64>() / bootstrap_sample.len() as f64,
            };

            bootstrap_estimates.push(estimate);
        }

        let uniform_weight = 1.0 / bootstrap_estimates.len() as f64;
        let weights = vec![uniform_weight; bootstrap_estimates.len()];

        let mut cumulative_weights = Vec::new();
        let mut cumsum = 0.0;
        for &weight in &weights {
            cumsum += weight;
            cumulative_weights.push(cumsum);
        }

        Ok(SampleDistribution {
            values: bootstrap_estimates,
            weights,
            cumulative_weights,
            distribution_type: DistributionType::Empirical,
        })
    }

    /// Latin hypercube sampling over the empirical CDF: one uniform draw in
    /// each of `n_samples` equal-probability bins, mapped back to data
    /// values through the sorted sample (empirical quantiles).
    fn create_latin_hypercube_distribution(
        &self,
        values: &[f64],
    ) -> Result<SampleDistribution, SklearsError> {
        let n_samples = self.config.n_samples.min(values.len());

        let mut sorted_values = values.to_vec();
        sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let mut lhs_values = Vec::new();
        let mut rng = Random::default();

        for i in 0..n_samples {
            // Stratify [0, 1) into n_samples bins and draw once per bin.
            let lower_bound = i as f64 / n_samples as f64;
            let upper_bound = (i + 1) as f64 / n_samples as f64;
            let uniform_sample: f64 = rng.gen();
            let stratified_sample = lower_bound + uniform_sample * (upper_bound - lower_bound);

            // Map the [0, 1) draw to an index into the sorted values.
            let quantile_idx = (stratified_sample * (sorted_values.len() - 1) as f64) as usize;
            let quantile_idx = quantile_idx.min(sorted_values.len() - 1);

            lhs_values.push(sorted_values[quantile_idx]);
        }

        let uniform_weight = 1.0 / lhs_values.len() as f64;
        let weights = vec![uniform_weight; lhs_values.len()];

        let mut cumulative_weights = Vec::new();
        let mut cumsum = 0.0;
        for &weight in &weights {
            cumsum += weight;
            cumulative_weights.push(cumsum);
        }

        Ok(SampleDistribution {
            values: lhs_values,
            weights,
            cumulative_weights,
            distribution_type: DistributionType::Empirical,
        })
    }

    /// Rule-of-thumb bandwidth for a Gaussian kernel:
    /// `sigma * (4 / (3n))^(1/5)`; returns 1.0 for fewer than two values.
    fn compute_bandwidth(&self, values: &[f64]) -> f64 {
        if values.len() < 2 {
            return 1.0;
        }

        let n = values.len() as f64;
        let mean = values.iter().sum::<f64>() / n;
        // Sample (n - 1) variance.
        let variance = values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0);
        let std_dev = variance.sqrt();

        std_dev * (4.0 / (3.0 * n)).powf(0.2)
    }

    /// Gaussian kernel density estimate of `values` at point `x`.
    fn kernel_density_estimate(&self, x: f64, values: &[f64], bandwidth: f64) -> f64 {
        let n = values.len() as f64;
        let sum: f64 = values
            .iter()
            .map(|&xi| {
                let u = (x - xi) / bandwidth;
                (-0.5 * u * u).exp()
            })
            .sum();

        // Normalize by n * h * sqrt(2*pi) (Gaussian kernel constant).
        sum / (n * bandwidth * (2.0 * std::f64::consts::PI).sqrt())
    }

    /// Weighted mean of the distribution: sum of value_i * weight_i.
    fn compute_weighted_mean(&self, distribution: &SampleDistribution) -> f64 {
        distribution
            .values
            .iter()
            .zip(distribution.weights.iter())
            .map(|(&value, &weight)| value * weight)
            .sum()
    }

    /// Weighted median: first value whose cumulative weight reaches 0.5.
    ///
    /// NOTE(review): a true weighted median requires `values` to be sorted;
    /// the simple and bootstrap builders do not sort them — confirm intent.
    fn compute_weighted_median(&self, distribution: &SampleDistribution) -> f64 {
        for (i, &cum_weight) in distribution.cumulative_weights.iter().enumerate() {
            if cum_weight >= 0.5 {
                return distribution.values[i];
            }
        }

        // Round-off fallback: last value (0.0 if somehow empty).
        distribution.values.last().copied().unwrap_or(0.0)
    }

    /// Weighted mode: bucket values at 1e-6 resolution and return the
    /// bucket with the largest total weight.
    fn compute_weighted_mode(&self, distribution: &SampleDistribution) -> f64 {
        let mut value_weights = HashMap::new();

        for (&value, &weight) in distribution.values.iter().zip(distribution.weights.iter()) {
            // Quantize to 6 decimal places so nearby floats share a bucket.
            let key = (value * 1e6) as i64;
            *value_weights.entry(key).or_insert(0.0) += weight;
        }

        value_weights
            .into_iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            .map(|(key, _)| key as f64 / 1e6)
            .unwrap_or(0.0)
    }
}
676
677impl SamplingSimpleImputer<SamplingSimpleImputerTrained> {
678 fn sample_imputed_value(&self, feature_idx: usize) -> Result<f64, SklearsError> {
680 let distribution = &self.state.sample_distributions_[feature_idx];
681 let mut rng = Random::default();
682 let random_value: f64 = rng.gen();
683
684 for (i, &cum_weight) in distribution.cumulative_weights.iter().enumerate() {
686 if random_value <= cum_weight {
687 return Ok(distribution.values[i]);
688 }
689 }
690
691 Ok(distribution.values.last().copied().unwrap_or(0.0))
693 }
694
695 fn is_missing(&self, value: f64) -> bool {
696 if self.missing_values.is_nan() {
697 value.is_nan()
698 } else {
699 (value - self.missing_values).abs() < f64::EPSILON
700 }
701 }
702
703 pub fn distribution(&self, feature_idx: usize) -> Option<&SampleDistribution> {
705 self.state.sample_distributions_.get(feature_idx)
706 }
707
708 pub fn statistics(&self) -> &Array1<f64> {
710 &self.state.sample_statistics_
711 }
712}
713
714impl StratifiedSamplingImputer<Untrained> {
716 pub fn new() -> Self {
717 Self {
718 state: Untrained,
719 missing_values: f64::NAN,
720 config: SamplingConfig::default(),
721 stratification_features: Vec::new(),
722 }
723 }
724
725 pub fn sampling_config(mut self, config: SamplingConfig) -> Self {
726 self.config = config;
727 self
728 }
729
730 pub fn stratify_by(mut self, features: Vec<usize>) -> Self {
731 self.stratification_features = features.clone();
732 self.config.stratify_by = Some(features);
733 self
734 }
735
736 pub fn n_strata(mut self, n_strata: usize) -> Self {
737 self.config.n_strata = n_strata;
738 self
739 }
740
741 fn is_missing(&self, value: f64) -> bool {
742 if self.missing_values.is_nan() {
743 value.is_nan()
744 } else {
745 (value - self.missing_values).abs() < f64::EPSILON
746 }
747 }
748}
749
impl Default for StratifiedSamplingImputer<Untrained> {
    /// Equivalent to [`StratifiedSamplingImputer::new`].
    fn default() -> Self {
        Self::new()
    }
}
755
impl Estimator for StratifiedSamplingImputer<Untrained> {
    type Config = SamplingConfig;
    type Error = SklearsError;
    type Float = Float;

    /// Borrow the sampling configuration.
    fn config(&self) -> &Self::Config {
        &self.config
    }
}
765
766impl Fit<ArrayView2<'_, Float>, ()> for StratifiedSamplingImputer<Untrained> {
767 type Fitted = StratifiedSamplingImputer<StratifiedSamplingImputerTrained>;
768
769 #[allow(non_snake_case)]
770 fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
771 let X = X.mapv(|x| x);
772 let (_n_samples, n_features) = X.dim();
773
774 if self.stratification_features.is_empty() {
775 return Err(SklearsError::InvalidInput(
776 "Stratification features must be specified".to_string(),
777 ));
778 }
779
780 let (strata_distributions, feature_strata) = self.compute_stratified_distributions(&X)?;
781
782 Ok(StratifiedSamplingImputer {
783 state: StratifiedSamplingImputerTrained {
784 strata_distributions_: strata_distributions,
785 feature_strata_: feature_strata,
786 n_features_in_: n_features,
787 config: self.config,
788 },
789 missing_values: self.missing_values,
790 config: Default::default(),
791 stratification_features: Vec::new(),
792 })
793 }
794}
795
impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for StratifiedSamplingImputer<StratifiedSamplingImputerTrained>
{
    /// Impute missing entries by first assigning each row to a stratum and
    /// then drawing from that stratum's per-feature distribution.
    ///
    /// Cells whose stratum (or feature within it) has no fitted
    /// distribution are left unchanged.
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();

        for i in 0..n_samples {
            // Stratum key is derived from the row's stratification features.
            let stratum_key = self.determine_stratum(&X_imputed.row(i).to_owned())?;

            for j in 0..n_features {
                if self.is_missing(X_imputed[[i, j]]) {
                    if let Some(stratum_dists) = self.state.strata_distributions_.get(&stratum_key)
                    {
                        if let Some(distribution) = stratum_dists.get(&j) {
                            let imputed_value = self.sample_from_distribution(distribution)?;
                            X_imputed[[i, j]] = imputed_value;
                        }
                    }
                }
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}
833
impl StratifiedSamplingImputer<Untrained> {
    /// Fit-time computation: derive quantile stratum boundaries for each
    /// stratification feature, assign every row to a stratum, and build an
    /// empirical distribution per (stratum, feature).
    fn compute_stratified_distributions(&self, X: &Array2<f64>) -> StratifiedDistributionsResult {
        let (n_samples, n_features) = X.dim();

        // Quantile boundaries (n_strata + 1 values) per stratification feature.
        let mut feature_strata = Vec::new();
        for &feature_idx in &self.stratification_features {
            let column = X.column(feature_idx);
            let valid_values: Vec<f64> = column
                .iter()
                .filter(|&&x| !self.is_missing(x))
                .cloned()
                .collect();

            // All-missing column: use a dummy [0, 1] boundary pair.
            if valid_values.is_empty() {
                feature_strata.push(Array1::from_vec(vec![0.0, 1.0]));
                continue;
            }

            let mut sorted_values = valid_values;
            sorted_values.sort_by(|a, b| a.partial_cmp(b).unwrap());

            let mut boundaries = Vec::new();
            for i in 0..=self.config.n_strata {
                let quantile = i as f64 / self.config.n_strata as f64;
                let idx = ((sorted_values.len() - 1) as f64 * quantile) as usize;
                let idx = idx.min(sorted_values.len() - 1);
                boundaries.push(sorted_values[idx]);
            }

            feature_strata.push(Array1::from_vec(boundaries));
        }

        // Group rows by their stratum key.
        let mut strata_samples: HashMap<Vec<usize>, Vec<Array1<f64>>> = HashMap::new();

        for i in 0..n_samples {
            let row = X.row(i).to_owned();
            let stratum_key = self.assign_to_stratum(&row, &feature_strata)?;

            strata_samples.entry(stratum_key).or_default().push(row);
        }

        // One empirical distribution per (stratum, feature) that has data.
        let mut strata_distributions = HashMap::new();

        for (stratum_key, samples) in strata_samples {
            let mut feature_distributions = HashMap::new();

            for j in 0..n_features {
                let feature_values: Vec<f64> = samples
                    .iter()
                    .map(|row| row[j])
                    .filter(|&x| !self.is_missing(x))
                    .collect();

                if !feature_values.is_empty() {
                    let distribution = self.create_empirical_distribution(&feature_values)?;
                    feature_distributions.insert(j, distribution);
                }
            }

            strata_distributions.insert(stratum_key, feature_distributions);
        }

        Ok((strata_distributions, feature_strata))
    }

    /// Map a row to its stratum key: one stratum index per stratification
    /// feature, chosen as the first boundary interval containing the value.
    /// Missing stratification values default to stratum 0.
    fn assign_to_stratum(
        &self,
        row: &Array1<f64>,
        feature_strata: &[Array1<f64>],
    ) -> Result<Vec<usize>, SklearsError> {
        let mut stratum_key = Vec::new();

        for (i, &feature_idx) in self.stratification_features.iter().enumerate() {
            let value = row[feature_idx];
            if self.is_missing(value) {
                stratum_key.push(0);
                continue;
            }

            let boundaries = &feature_strata[i];
            let mut stratum = 0;

            // NOTE(review): a value above every boundary keeps stratum 0;
            // at fit time the top boundary is the data maximum, so this
            // cannot trigger for observed values.
            for k in 1..boundaries.len() {
                if value <= boundaries[k] {
                    stratum = k - 1;
                    break;
                }
            }

            stratum_key.push(stratum);
        }

        Ok(stratum_key)
    }

    /// Uniformly-weighted empirical distribution over `values`.
    fn create_empirical_distribution(
        &self,
        values: &[f64],
    ) -> Result<SampleDistribution, SklearsError> {
        let uniform_weight = 1.0 / values.len() as f64;
        let weights = vec![uniform_weight; values.len()];

        // Running sum for inverse-CDF sampling at transform time.
        let mut cumulative_weights = Vec::new();
        let mut cumsum = 0.0;
        for &weight in &weights {
            cumsum += weight;
            cumulative_weights.push(cumsum);
        }

        Ok(SampleDistribution {
            values: values.to_vec(),
            weights,
            cumulative_weights,
            distribution_type: DistributionType::Empirical,
        })
    }
}
959
960impl StratifiedSamplingImputer<StratifiedSamplingImputerTrained> {
961 fn determine_stratum(&self, row: &Array1<f64>) -> Result<Vec<usize>, SklearsError> {
963 if let Some(ref stratify_features) = self.state.config.stratify_by {
964 let mut stratum_key = Vec::new();
965
966 for (i, &feature_idx) in stratify_features.iter().enumerate() {
967 let value = row[feature_idx];
968 if self.is_missing(value) {
969 stratum_key.push(0); continue;
971 }
972
973 let boundaries = &self.state.feature_strata_[i];
974 let mut stratum = 0;
975
976 for k in 1..boundaries.len() {
977 if value <= boundaries[k] {
978 stratum = k - 1;
979 break;
980 }
981 }
982
983 stratum_key.push(stratum);
984 }
985
986 Ok(stratum_key)
987 } else {
988 Ok(vec![0]) }
990 }
991
992 fn sample_from_distribution(
994 &self,
995 distribution: &SampleDistribution,
996 ) -> Result<f64, SklearsError> {
997 let mut rng = Random::default();
998 let random_value: f64 = rng.gen();
999
1000 for (i, &cum_weight) in distribution.cumulative_weights.iter().enumerate() {
1002 if random_value <= cum_weight {
1003 return Ok(distribution.values[i]);
1004 }
1005 }
1006
1007 Ok(distribution.values.last().copied().unwrap_or(0.0))
1009 }
1010
1011 fn is_missing(&self, value: f64) -> bool {
1012 if self.missing_values.is_nan() {
1013 value.is_nan()
1014 } else {
1015 (value - self.missing_values).abs() < f64::EPSILON
1016 }
1017 }
1018}
1019
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    /// Simple-strategy imputation fills the NaN and leaves observed cells intact.
    #[test]
    #[allow(non_snake_case)]
    fn test_sampling_simple_imputer() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let imputer = SamplingSimpleImputer::new()
            .strategy("mean".to_string())
            .n_samples(100)
            .sampling_strategy(SamplingStrategy::Simple);

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        // Missing cell must be filled; observed cells must pass through unchanged.
        assert!(!X_imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(X_imputed[[2, 2]], 9.0, epsilon = 1e-10);
    }

    /// Bootstrap-strategy imputation also fills NaNs without touching observed data.
    #[test]
    #[allow(non_snake_case)]
    fn test_bootstrap_sampling() {
        let X = array![[1.0, 2.0], [3.0, f64::NAN], [5.0, 6.0], [7.0, 8.0]];

        let imputer = SamplingSimpleImputer::new()
            .strategy("mean".to_string())
            .sampling_strategy(SamplingStrategy::Bootstrap);

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        assert!(!X_imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
    }

    /// Stratified imputation: column 2 (two distinct groups) defines the strata.
    #[test]
    #[allow(non_snake_case)]
    fn test_stratified_sampling_imputer() {
        let X = array![
            [1.0, 2.0, 0.0],
            [2.0, f64::NAN, 0.0],
            [8.0, 9.0, 1.0],
            [9.0, 10.0, 1.0]
        ];

        let imputer = StratifiedSamplingImputer::new()
            .stratify_by(vec![2])
            .n_strata(2);

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        assert!(!X_imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
    }

    /// Latin hypercube strategy fills NaNs even with a small sample budget.
    #[test]
    #[allow(non_snake_case)]
    fn test_latin_hypercube_sampling() {
        let X = array![
            [1.0, 2.0, 3.0],
            [4.0, f64::NAN, 6.0],
            [7.0, 8.0, 9.0],
            [10.0, 11.0, 12.0]
        ];

        let imputer = SamplingSimpleImputer::new()
            .strategy("mean".to_string())
            .sampling_strategy(SamplingStrategy::LatinHypercube)
            .n_samples(3);

        let fitted = imputer.fit(&X.view(), &()).unwrap();
        let X_imputed = fitted.transform(&X.view()).unwrap();

        assert!(!X_imputed[[1, 1]].is_nan());
        assert_abs_diff_eq!(X_imputed[[0, 0]], 1.0, epsilon = 1e-10);
    }

    /// SampleDistribution invariants: parallel vectors, cumulative weights end at 1.
    #[test]
    fn test_sample_distribution() {
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let weights = vec![0.1, 0.2, 0.4, 0.2, 0.1];
        let cumulative_weights = vec![0.1, 0.3, 0.7, 0.9, 1.0];

        let distribution = SampleDistribution {
            values,
            weights,
            cumulative_weights,
            distribution_type: DistributionType::Empirical,
        };

        assert_eq!(distribution.values.len(), 5);
        assert_eq!(distribution.weights.len(), 5);
        assert_eq!(distribution.cumulative_weights.len(), 5);
        assert!((distribution.cumulative_weights.last().unwrap() - 1.0).abs() < 1e-10);
    }

    /// Builder propagates a custom config (struct-update over defaults).
    #[test]
    fn test_sampling_config() {
        let config = SamplingConfig {
            n_samples: 500,
            strategy: SamplingStrategy::Importance,
            importance_sampling: true,
            ..Default::default()
        };

        let imputer = SamplingSimpleImputer::new().sampling_config(config.clone());

        assert_eq!(imputer.config.n_samples, 500);
        assert!(matches!(
            imputer.config.strategy,
            SamplingStrategy::Importance
        ));
        assert!(imputer.config.importance_sampling);
    }
}