use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
use scirs2_core::rand_prelude::SliceRandom;
use scirs2_core::random::{Random, Rng};
use sklears_core::{
    error::{Result as SklResult, SklearsError},
    traits::{Estimator, Fit, Transform, Untrained},
    types::Float,
};
use std::collections::HashMap;

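/// Imputes missing values by training, for each feature that contains
/// missing entries, a regression random forest on the remaining features,
/// then predicting replacements at transform time.
///
/// A minimal usage sketch (assuming `Float = f64`; the builder methods and
/// `Fit`/`Transform` impls are defined below):
///
/// ```ignore
/// let fitted = RandomForestImputer::new()
///     .n_estimators(50)
///     .max_depth(8)
///     .random_state(42)
///     .fit(&x.view(), &())?;
/// let x_imputed = fitted.transform(&x.view())?;
/// ```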
#[derive(Debug, Clone)]
pub struct RandomForestImputer<S = Untrained> {
    state: S,
    n_estimators: usize,
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
    max_features: String,
    bootstrap: bool,
    random_state: Option<u64>,
    missing_values: f64,
}

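/// Trained state for [`RandomForestImputer`]: one forest per feature that
/// had missing values at fit time, plus column means used as fallback
/// prediction inputs.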
#[derive(Debug, Clone)]
pub struct RandomForestImputerTrained {
    forests: HashMap<usize, RandomForest>,
    feature_means_: Array1<f64>,
    n_features_in_: usize,
}

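/// A forest of regression trees that predicts `target_feature` from the
/// columns listed in `feature_indices`.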
#[derive(Debug, Clone)]
pub struct RandomForest {
    trees: Vec<DecisionTree>,
    feature_indices: Vec<usize>,
    target_feature: usize,
}

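/// Gradient-boosting based imputer. Note that model training is not yet
/// implemented: `fit` only records column means, and `transform` falls
/// back to mean imputation.
///
/// A minimal usage sketch (assuming `Float = f64`):
///
/// ```ignore
/// let fitted = GradientBoostingImputer::new()
///     .n_estimators(200)
///     .learning_rate(0.05)
///     .fit(&x.view(), &())?;
/// let x_imputed = fitted.transform(&x.view())?;
/// ```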
#[derive(Debug, Clone)]
pub struct GradientBoostingImputer<S = Untrained> {
    state: S,
    n_estimators: usize,
    learning_rate: f64,
    max_depth: usize,
    min_samples_split: usize,
    min_samples_leaf: usize,
    subsample: f64,
    random_state: Option<u64>,
    missing_values: f64,
}

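/// Trained state for [`GradientBoostingImputer`] (currently only the
/// column means are populated).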
#[derive(Debug, Clone)]
pub struct GradientBoostingImputerTrained {
    boosting_models: HashMap<usize, GradientBoostingModel>,
    feature_means_: Array1<f64>,
    n_features_in_: usize,
}

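/// A boosted tree ensemble for one target feature (not yet produced by
/// `fit`; see [`GradientBoostingImputer`]).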
#[derive(Debug, Clone)]
pub struct GradientBoostingModel {
    trees: Vec<DecisionTree>,
    learning_rate: f64,
    initial_prediction: f64,
    target_feature: usize,
}

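/// Extremely-randomized-trees based imputer. As with
/// [`GradientBoostingImputer`], forest training is not yet implemented:
/// `fit` only records column means and `transform` falls back to mean
/// imputation.
///
/// A minimal usage sketch (assuming `Float = f64`):
///
/// ```ignore
/// let fitted = ExtraTreesImputer::new()
///     .n_estimators(100)
///     .random_state(7)
///     .fit(&x.view(), &())?;
/// let x_imputed = fitted.transform(&x.view())?;
/// ```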
#[derive(Debug, Clone)]
pub struct ExtraTreesImputer<S = Untrained> {
    state: S,
    n_estimators: usize,
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
    max_features: String,
    bootstrap: bool,
    random_state: Option<u64>,
    missing_values: f64,
}

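/// Trained state for [`ExtraTreesImputer`] (currently only the column
/// means are populated).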
#[derive(Debug, Clone)]
pub struct ExtraTreesImputerTrained {
    forests: HashMap<usize, ExtraTreesForest>,
    feature_means_: Array1<f64>,
    n_features_in_: usize,
}

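/// An extra-trees ensemble predicting `target_feature` from the columns
/// listed in `feature_indices`.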
#[derive(Debug, Clone)]
pub struct ExtraTreesForest {
    trees: Vec<DecisionTree>,
    feature_indices: Vec<usize>,
    target_feature: usize,
}

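/// A regression tree stored as a flat vector of nodes; children are
/// referenced by index into `nodes`, with the root at index 0.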
#[derive(Debug, Clone)]
pub struct DecisionTree {
    nodes: Vec<TreeNode>,
    max_depth: Option<usize>,
    min_samples_split: usize,
    min_samples_leaf: usize,
}

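/// A single tree node. Internal nodes route on `feature_index` and
/// `threshold`; leaves carry the mean target of their samples in `value`.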
#[derive(Debug, Clone)]
pub struct TreeNode {
    feature_index: Option<usize>,
    threshold: Option<f64>,
    left_child: Option<usize>,
    right_child: Option<usize>,
    value: f64,
    n_samples: usize,
    is_leaf: bool,
}

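/// Training data threaded through recursive tree construction; each node
/// works on the subset of rows selected by `sample_indices`.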
#[derive(Debug, Clone)]
struct TreeTrainingData {
    features: Array2<f64>,
    targets: Array1<f64>,
    sample_indices: Vec<usize>,
}

impl RandomForestImputer<Untrained> {
    /// Creates an imputer with scikit-learn-style defaults
    /// (100 trees, `sqrt` feature sampling, bootstrap resampling).
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_estimators: 100,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: "sqrt".to_string(),
            bootstrap: true,
            random_state: None,
            missing_values: f64::NAN,
        }
    }

    /// Sets the number of trees per forest.
    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.n_estimators = n_estimators;
        self
    }

    /// Sets the maximum tree depth (unlimited by default).
    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.max_depth = Some(max_depth);
        self
    }

    /// Sets the minimum number of samples required to split a node.
    pub fn min_samples_split(mut self, min_samples_split: usize) -> Self {
        self.min_samples_split = min_samples_split;
        self
    }

    /// Sets the minimum number of samples required at a leaf.
    pub fn min_samples_leaf(mut self, min_samples_leaf: usize) -> Self {
        self.min_samples_leaf = min_samples_leaf;
        self
    }

    /// Sets the feature-sampling strategy: `"sqrt"`, `"log2"`, or `"all"`.
    pub fn max_features(mut self, max_features: String) -> Self {
        self.max_features = max_features;
        self
    }

    /// Enables or disables bootstrap resampling of training rows.
    pub fn bootstrap(mut self, bootstrap: bool) -> Self {
        self.bootstrap = bootstrap;
        self
    }

    /// Sets the random seed.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Sets the sentinel that marks missing values (`NaN` by default).
    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }

    /// Returns `true` if `value` matches the configured missing-value sentinel.
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl Default for RandomForestImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for RandomForestImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for RandomForestImputer<Untrained> {
    type Fitted = RandomForestImputer<RandomForestImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        // Materialize an owned copy so the data can be indexed repeatedly.
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        if n_samples == 0 || n_features == 0 {
            return Err(SklearsError::InvalidInput("Empty dataset".to_string()));
        }

        // NOTE: `random_state` is stored but not yet wired into the RNG;
        // a default-initialized generator is used for now.
        let mut rng = Random::default();

        // Per-column means over observed values, used as fallback inputs
        // when a prediction row itself contains missing entries.
        let feature_means = compute_feature_means(&X, self.missing_values);

        // Train one forest per feature that actually has missing values.
        let mut forests = HashMap::new();

        for target_feature in 0..n_features {
            let has_missing = (0..n_samples).any(|i| self.is_missing(X[[i, target_feature]]));

            if has_missing {
                let forest = self.train_random_forest(&X, target_feature, &mut rng)?;
                forests.insert(target_feature, forest);
            }
        }

        Ok(RandomForestImputer {
            state: RandomForestImputerTrained {
                forests,
                feature_means_: feature_means,
                n_features_in_: n_features,
            },
            n_estimators: self.n_estimators,
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
            max_features: self.max_features,
            bootstrap: self.bootstrap,
            random_state: self.random_state,
            missing_values: self.missing_values,
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for RandomForestImputer<RandomForestImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();

        for (&target_feature, forest) in &self.state.forests {
            for i in 0..n_samples {
                if self.is_missing(X_imputed[[i, target_feature]]) {
                    // Build the prediction input from the other features,
                    // substituting column means for any that are also missing.
                    let mut input_features = Vec::new();
                    for j in 0..n_features {
                        if j != target_feature {
                            if self.is_missing(X_imputed[[i, j]]) {
                                input_features.push(self.state.feature_means_[j]);
                            } else {
                                input_features.push(X_imputed[[i, j]]);
                            }
                        }
                    }

                    let input_array = Array1::from_vec(input_features);
                    let predicted_value = self.predict_forest(forest, &input_array)?;
                    X_imputed[[i, target_feature]] = predicted_value;
                }
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl RandomForestImputer<Untrained> {
    /// Trains a regression forest that predicts `target_feature` from all
    /// other features, using only rows where the full feature vector is
    /// observed.
    fn train_random_forest(
        &self,
        X: &Array2<f64>,
        target_feature: usize,
        rng: &mut impl Rng,
    ) -> SklResult<RandomForest> {
        let (n_samples, n_features) = X.dim();

        // Collect complete-case training rows: the target must be observed
        // and no predictor feature may be missing.
        let mut training_data = Vec::new();
        let mut training_targets = Vec::new();

        for i in 0..n_samples {
            if !self.is_missing(X[[i, target_feature]]) {
                let mut features = Vec::new();
                let mut has_missing = false;

                for j in 0..n_features {
                    if j != target_feature {
                        if self.is_missing(X[[i, j]]) {
                            has_missing = true;
                            break;
                        }
                        features.push(X[[i, j]]);
                    }
                }

                if !has_missing {
                    training_data.push(features);
                    training_targets.push(X[[i, target_feature]]);
                }
            }
        }

        if training_data.is_empty() {
            return Err(SklearsError::InvalidInput(
                "No valid training samples for feature".to_string(),
            ));
        }

        let n_training_features = training_data[0].len();
        let training_X =
            Array2::from_shape_fn((training_data.len(), n_training_features), |(i, j)| {
                training_data[i][j]
            });
        let training_y = Array1::from_vec(training_targets);

        // Record which original columns feed the forest (all but the target).
        let mut feature_indices = Vec::new();
        for j in 0..n_features {
            if j != target_feature {
                feature_indices.push(j);
            }
        }

        let mut trees = Vec::new();
        for _ in 0..self.n_estimators {
            let tree = self.train_tree(&training_X, &training_y, rng)?;
            trees.push(tree);
        }

        Ok(RandomForest {
            trees,
            feature_indices,
            target_feature,
        })
    }

    /// Trains a single tree, optionally on a bootstrap resample of the rows.
    fn train_tree(
        &self,
        X: &Array2<f64>,
        y: &Array1<f64>,
        rng: &mut impl Rng,
    ) -> SklResult<DecisionTree> {
        let (n_samples, _n_features) = X.dim();
        let mut sample_indices: Vec<usize> = (0..n_samples).collect();

        if self.bootstrap {
            // Sample n rows with replacement.
            sample_indices = (0..n_samples)
                .map(|_| rng.gen_range(0..n_samples))
                .collect();
        }

        let training_data = TreeTrainingData {
            features: X.clone(),
            targets: y.clone(),
            sample_indices,
        };

        let mut tree = DecisionTree {
            nodes: Vec::new(),
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
        };

        self.build_tree(&mut tree, &training_data, 0, rng)?;
        Ok(tree)
    }

    /// Recursively grows `tree`, returning the index of the node created
    /// for the current set of samples.
    fn build_tree(
        &self,
        tree: &mut DecisionTree,
        data: &TreeTrainingData,
        depth: usize,
        rng: &mut impl Rng,
    ) -> SklResult<usize> {
        let sample_indices = &data.sample_indices;
        let n_samples = sample_indices.len();

        // A node predicts the mean target of its samples.
        let node_value = if n_samples > 0 {
            sample_indices.iter().map(|&i| data.targets[i]).sum::<f64>() / n_samples as f64
        } else {
            0.0
        };

        let should_stop = n_samples < self.min_samples_split
            || n_samples < self.min_samples_leaf * 2
            || self.max_depth.is_some_and(|max_d| depth >= max_d)
            || self.all_targets_equal(data, sample_indices);

        if should_stop {
            // Stopping criterion met: emit a leaf.
            let node_index = tree.nodes.len();
            tree.nodes.push(TreeNode {
                feature_index: None,
                threshold: None,
                left_child: None,
                right_child: None,
                value: node_value,
                n_samples,
                is_leaf: true,
            });
            return Ok(node_index);
        }

        let (best_feature, best_threshold) = self.find_best_split(data, sample_indices, rng)?;

        if best_feature.is_none() {
            // No usable split was found: emit a leaf.
            let node_index = tree.nodes.len();
            tree.nodes.push(TreeNode {
                feature_index: None,
                threshold: None,
                left_child: None,
                right_child: None,
                value: node_value,
                n_samples,
                is_leaf: true,
            });
            return Ok(node_index);
        }

        let feature_idx = best_feature.unwrap();
        let threshold = best_threshold.unwrap();

        let (left_indices, right_indices) =
            self.split_samples(data, sample_indices, feature_idx, threshold);

        if left_indices.is_empty() || right_indices.is_empty() {
            // Degenerate split: emit a leaf.
            let node_index = tree.nodes.len();
            tree.nodes.push(TreeNode {
                feature_index: None,
                threshold: None,
                left_child: None,
                right_child: None,
                value: node_value,
                n_samples,
                is_leaf: true,
            });
            return Ok(node_index);
        }

        // Internal node; child indices are patched in after recursion.
        let node_index = tree.nodes.len();
        tree.nodes.push(TreeNode {
            feature_index: Some(feature_idx),
            threshold: Some(threshold),
            left_child: None,
            right_child: None,
            value: node_value,
            n_samples,
            is_leaf: false,
        });

        let left_data = TreeTrainingData {
            features: data.features.clone(),
            targets: data.targets.clone(),
            sample_indices: left_indices,
        };
        let left_child_idx = self.build_tree(tree, &left_data, depth + 1, rng)?;

        let right_data = TreeTrainingData {
            features: data.features.clone(),
            targets: data.targets.clone(),
            sample_indices: right_indices,
        };
        let right_child_idx = self.build_tree(tree, &right_data, depth + 1, rng)?;

        tree.nodes[node_index].left_child = Some(left_child_idx);
        tree.nodes[node_index].right_child = Some(right_child_idx);

        Ok(node_index)
    }

    fn all_targets_equal(&self, data: &TreeTrainingData, sample_indices: &[usize]) -> bool {
        if sample_indices.is_empty() {
            return true;
        }

        let first_target = data.targets[sample_indices[0]];
        sample_indices
            .iter()
            .all(|&i| (data.targets[i] - first_target).abs() < 1e-8)
    }

    /// Evaluates candidate thresholds on a random feature subset and
    /// returns the split with the best variance reduction.
    fn find_best_split(
        &self,
        data: &TreeTrainingData,
        sample_indices: &[usize],
        rng: &mut impl Rng,
    ) -> SklResult<(Option<usize>, Option<f64>)> {
        let n_features = data.features.ncols();

        let max_features = match self.max_features.as_str() {
            "sqrt" => (n_features as f64).sqrt() as usize,
            "log2" => (n_features as f64).log2() as usize,
            "all" => n_features,
            _ => n_features,
        };

        // Randomly subsample the candidate features (at least one).
        let mut feature_candidates: Vec<usize> = (0..n_features).collect();
        feature_candidates.shuffle(rng);
        feature_candidates.truncate(max_features.max(1));

        let mut best_score = f64::NEG_INFINITY;
        let mut best_feature = None;
        let mut best_threshold = None;

        for &feature_idx in &feature_candidates {
            // Candidate thresholds are midpoints between consecutive
            // distinct feature values. `total_cmp` gives a total order,
            // avoiding a panic should a NaN ever slip in.
            let mut feature_values: Vec<f64> = sample_indices
                .iter()
                .map(|&i| data.features[[i, feature_idx]])
                .collect();
            feature_values.sort_by(|a, b| a.total_cmp(b));
            feature_values.dedup_by(|a, b| (*a - *b).abs() < 1e-8);

            if feature_values.len() < 2 {
                continue;
            }

            for i in 0..(feature_values.len() - 1) {
                let threshold = (feature_values[i] + feature_values[i + 1]) / 2.0;
                let score =
                    self.calculate_split_score(data, sample_indices, feature_idx, threshold);

                if score > best_score {
                    best_score = score;
                    best_feature = Some(feature_idx);
                    best_threshold = Some(threshold);
                }
            }
        }

        Ok((best_feature, best_threshold))
    }

    /// Scores a split by its variance reduction: parent variance minus the
    /// size-weighted mean of the child variances.
    fn calculate_split_score(
        &self,
        data: &TreeTrainingData,
        sample_indices: &[usize],
        feature_idx: usize,
        threshold: f64,
    ) -> f64 {
        let (left_indices, right_indices) =
            self.split_samples(data, sample_indices, feature_idx, threshold);

        if left_indices.is_empty() || right_indices.is_empty() {
            return f64::NEG_INFINITY;
        }

        let total_variance = self.calculate_variance(data, sample_indices);
        let left_variance = self.calculate_variance(data, &left_indices);
        let right_variance = self.calculate_variance(data, &right_indices);

        let left_weight = left_indices.len() as f64 / sample_indices.len() as f64;
        let right_weight = right_indices.len() as f64 / sample_indices.len() as f64;

        let weighted_variance = left_weight * left_variance + right_weight * right_variance;
        total_variance - weighted_variance
    }

    /// Population variance of the targets over `sample_indices`.
    fn calculate_variance(&self, data: &TreeTrainingData, sample_indices: &[usize]) -> f64 {
        if sample_indices.len() <= 1 {
            return 0.0;
        }

        let mean = sample_indices.iter().map(|&i| data.targets[i]).sum::<f64>()
            / sample_indices.len() as f64;
        sample_indices
            .iter()
            .map(|&i| (data.targets[i] - mean).powi(2))
            .sum::<f64>()
            / sample_indices.len() as f64
    }

    /// Partitions `sample_indices` by comparing `feature_idx` against
    /// `threshold` (values `<= threshold` go left).
    fn split_samples(
        &self,
        data: &TreeTrainingData,
        sample_indices: &[usize],
        feature_idx: usize,
        threshold: f64,
    ) -> (Vec<usize>, Vec<usize>) {
        let mut left_indices = Vec::new();
        let mut right_indices = Vec::new();

        for &sample_idx in sample_indices {
            if data.features[[sample_idx, feature_idx]] <= threshold {
                left_indices.push(sample_idx);
            } else {
                right_indices.push(sample_idx);
            }
        }

        (left_indices, right_indices)
    }
}

impl RandomForestImputer<RandomForestImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }

    /// Averages the predictions of all trees in the forest.
    fn predict_forest(&self, forest: &RandomForest, input: &Array1<f64>) -> SklResult<f64> {
        // Guard against division by zero when `n_estimators` is zero.
        if forest.trees.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Forest contains no trees".to_string(),
            ));
        }

        let mut predictions = Vec::new();

        for tree in &forest.trees {
            let prediction = self.predict_tree(tree, input)?;
            predictions.push(prediction);
        }

        Ok(predictions.iter().sum::<f64>() / predictions.len() as f64)
    }

    /// Walks the tree from the root until a leaf is reached.
    fn predict_tree(&self, tree: &DecisionTree, input: &Array1<f64>) -> SklResult<f64> {
        let mut current_node_idx = 0;

        loop {
            if current_node_idx >= tree.nodes.len() {
                return Err(SklearsError::InvalidInput(
                    "Invalid tree structure".to_string(),
                ));
            }

            let node = &tree.nodes[current_node_idx];

            if node.is_leaf {
                return Ok(node.value);
            }

            let feature_idx = node.feature_index.ok_or_else(|| {
                SklearsError::InvalidInput("Non-leaf node missing feature index".to_string())
            })?;
            let threshold = node.threshold.ok_or_else(|| {
                SklearsError::InvalidInput("Non-leaf node missing threshold".to_string())
            })?;

            if feature_idx >= input.len() {
                return Err(SklearsError::InvalidInput(
                    "Feature index out of bounds".to_string(),
                ));
            }

            if input[feature_idx] <= threshold {
                current_node_idx = node
                    .left_child
                    .ok_or_else(|| SklearsError::InvalidInput("Missing left child".to_string()))?;
            } else {
                current_node_idx = node
                    .right_child
                    .ok_or_else(|| SklearsError::InvalidInput("Missing right child".to_string()))?;
            }
        }
    }
}

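/// Computes per-column means over observed (non-missing) values; a column
/// with no observed values gets a mean of 0.0.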
fn compute_feature_means(X: &Array2<f64>, missing_values: f64) -> Array1<f64> {
    let (_, n_features) = X.dim();
    let mut means = Array1::zeros(n_features);

    let is_missing_nan = missing_values.is_nan();

    for j in 0..n_features {
        let column = X.column(j);
        let valid_values: Vec<f64> = column
            .iter()
            .filter(|&&x| {
                if is_missing_nan {
                    !x.is_nan()
                } else {
                    (x - missing_values).abs() >= f64::EPSILON
                }
            })
            .cloned()
            .collect();

        means[j] = if valid_values.is_empty() {
            0.0
        } else {
            valid_values.iter().sum::<f64>() / valid_values.len() as f64
        };
    }

    means
}

impl GradientBoostingImputer<Untrained> {
    /// Creates an imputer with common gradient-boosting defaults
    /// (100 estimators, learning rate 0.1, depth-3 trees).
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_estimators: 100,
            learning_rate: 0.1,
            max_depth: 3,
            min_samples_split: 2,
            min_samples_leaf: 1,
            subsample: 1.0,
            random_state: None,
            missing_values: f64::NAN,
        }
    }

    /// Sets the number of boosting stages.
    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.n_estimators = n_estimators;
        self
    }

    /// Sets the shrinkage applied to each stage's contribution.
    pub fn learning_rate(mut self, learning_rate: f64) -> Self {
        self.learning_rate = learning_rate;
        self
    }

    /// Sets the maximum depth of each stage's tree.
    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.max_depth = max_depth;
        self
    }

    /// Sets the fraction of rows sampled per stage.
    pub fn subsample(mut self, subsample: f64) -> Self {
        self.subsample = subsample;
        self
    }

    /// Sets the random seed.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Sets the sentinel that marks missing values (`NaN` by default).
    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }
}

impl Default for GradientBoostingImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for GradientBoostingImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for GradientBoostingImputer<Untrained> {
    type Fitted = GradientBoostingImputer<GradientBoostingImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.mapv(|x| x);
        let (_n_samples, n_features) = X.dim();

        let feature_means = compute_feature_means(&X, self.missing_values);
        // NOTE: no boosting models are trained yet; `transform` currently
        // falls back to mean imputation using `feature_means_`.
        let boosting_models = HashMap::new();

        Ok(GradientBoostingImputer {
            state: GradientBoostingImputerTrained {
                boosting_models,
                feature_means_: feature_means,
                n_features_in_: n_features,
            },
            n_estimators: self.n_estimators,
            learning_rate: self.learning_rate,
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
            subsample: self.subsample,
            random_state: self.random_state,
            missing_values: self.missing_values,
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for GradientBoostingImputer<GradientBoostingImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();

        // Mean-imputation fallback until boosting models are trained in `fit`.
        for i in 0..n_samples {
            for j in 0..n_features {
                if self.is_missing(X_imputed[[i, j]]) {
                    X_imputed[[i, j]] = self.state.feature_means_[j];
                }
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl GradientBoostingImputer<GradientBoostingImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}

impl ExtraTreesImputer<Untrained> {
    /// Creates an imputer with extra-trees defaults (100 trees, `sqrt`
    /// feature sampling, no bootstrap resampling).
    pub fn new() -> Self {
        Self {
            state: Untrained,
            n_estimators: 100,
            max_depth: None,
            min_samples_split: 2,
            min_samples_leaf: 1,
            max_features: "sqrt".to_string(),
            bootstrap: false,
            random_state: None,
            missing_values: f64::NAN,
        }
    }

    /// Sets the number of trees per forest.
    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
        self.n_estimators = n_estimators;
        self
    }

    /// Sets the maximum tree depth (unlimited by default).
    pub fn max_depth(mut self, max_depth: usize) -> Self {
        self.max_depth = Some(max_depth);
        self
    }

    /// Sets the random seed.
    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    /// Sets the sentinel that marks missing values (`NaN` by default).
    pub fn missing_values(mut self, missing_values: f64) -> Self {
        self.missing_values = missing_values;
        self
    }
}

impl Default for ExtraTreesImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}

impl Estimator for ExtraTreesImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        &()
    }
}

impl Fit<ArrayView2<'_, Float>, ()> for ExtraTreesImputer<Untrained> {
    type Fitted = ExtraTreesImputer<ExtraTreesImputerTrained>;

    #[allow(non_snake_case)]
    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
        let X = X.mapv(|x| x);
        let (_n_samples, n_features) = X.dim();

        let feature_means = compute_feature_means(&X, self.missing_values);
        // NOTE: no extra-trees forests are trained yet; `transform` currently
        // falls back to mean imputation using `feature_means_`.
        let forests = HashMap::new();

        Ok(ExtraTreesImputer {
            state: ExtraTreesImputerTrained {
                forests,
                feature_means_: feature_means,
                n_features_in_: n_features,
            },
            n_estimators: self.n_estimators,
            max_depth: self.max_depth,
            min_samples_split: self.min_samples_split,
            min_samples_leaf: self.min_samples_leaf,
            max_features: self.max_features,
            bootstrap: self.bootstrap,
            random_state: self.random_state,
            missing_values: self.missing_values,
        })
    }
}

impl Transform<ArrayView2<'_, Float>, Array2<Float>>
    for ExtraTreesImputer<ExtraTreesImputerTrained>
{
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();

        // Mean-imputation fallback until forests are trained in `fit`.
        for i in 0..n_samples {
            for j in 0..n_features {
                if self.is_missing(X_imputed[[i, j]]) {
                    X_imputed[[i, j]] = self.state.feature_means_[j];
                }
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}

impl ExtraTreesImputer<ExtraTreesImputerTrained> {
    fn is_missing(&self, value: f64) -> bool {
        if self.missing_values.is_nan() {
            value.is_nan()
        } else {
            (value - self.missing_values).abs() < f64::EPSILON
        }
    }
}
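
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke-test sketch (not from the original source): assumes
    // `Float = f64` and exercises the `Fit`/`Transform` impls above on a
    // tiny matrix with one NaN entry.
    #[test]
    fn random_forest_imputer_fills_missing_values() {
        // Column 1 is exactly 2x column 0, with one missing entry.
        let x = Array2::from_shape_vec(
            (4, 2),
            vec![1.0, 2.0, 2.0, f64::NAN, 3.0, 6.0, 4.0, 8.0],
        )
        .unwrap();

        let fitted = RandomForestImputer::new()
            .n_estimators(10)
            .random_state(0)
            .fit(&x.view(), &())
            .unwrap();
        let imputed = fitted.transform(&x.view()).unwrap();

        // The missing entry is replaced by a finite forest prediction.
        assert!(imputed[[1, 1]].is_finite());
        // Observed entries pass through unchanged.
        assert_eq!(imputed[[0, 0]], 1.0);
    }
}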