sklears_impute/
fluent_api.rs

1//! Fluent API and builder patterns for easy imputation configuration
2//!
3//! This module provides a convenient, chainable API for configuring and using
4//! imputation methods with sensible defaults and validation.
5
6use scirs2_core::ndarray::{Array1, Array2, ArrayView2};
7use sklears_core::{
8    error::{Result as SklResult, SklearsError},
9    traits::{Fit, Transform},
10    types::Float,
11};
12use std::collections::HashMap;
13
14use crate::{parallel::ParallelConfig, KNNImputer, ParallelKNNImputer, SimpleImputer};
15
/// Result of a preprocessing pass: optional per-feature means and standard
/// deviations (both `None` when neither normalization nor scaling was applied),
/// used later by postprocessing to invert the transform.
type PreprocessingResult = SklResult<(Option<Array1<Float>>, Option<Array1<Float>>)>;
18
/// Fluent API builder for imputation pipelines
///
/// Chain configuration calls (method selection, validation, pre-/postprocessing,
/// parallelism) and call `build()` to obtain an `ImputationPipeline`.
#[derive(Debug, Clone)]
pub struct ImputationBuilder {
    /// Selected imputation algorithm and its parameters.
    method: ImputationMethod,
    /// How (and whether) to validate imputation quality.
    validation: ValidationConfig,
    /// Transformations applied to the data before imputation.
    preprocessing: PreprocessingConfig,
    /// Transformations applied to the imputed data afterwards.
    postprocessing: PostprocessingConfig,
    /// Optional parallel-execution settings; `None` means sequential.
    parallel_config: Option<ParallelConfig>,
}
28
/// Available imputation methods with their configurations
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ImputationMethod {
    /// Simple per-column statistics (mean/median/most-frequent/constant).
    Simple(SimpleImputationConfig),
    /// K-nearest-neighbors imputation.
    KNN(KNNImputationConfig),
    /// Iterative multivariate imputation (MICE-style).
    Iterative(IterativeImputationConfig),
    /// Gaussian-process-based imputation.
    GaussianProcess(GaussianProcessConfig),
    /// Low-rank matrix-factorization imputation.
    MatrixFactorization(MatrixFactorizationConfig),
    /// Bayesian multiple imputation.
    Bayesian(BayesianImputationConfig),
    /// Ensemble-model imputation (e.g. random forest, gradient boosting).
    Ensemble(EnsembleImputationConfig),
    /// Deep-learning imputation (autoencoder/VAE/GAN).
    DeepLearning(DeepLearningConfig),
    /// User-defined method identified by name plus a parameter map.
    Custom(CustomImputationConfig),
}
52
/// Configuration for simple imputation methods
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SimpleImputationConfig {
    /// Imputation strategy: "mean", "median", "most_frequent", or "constant".
    pub strategy: String,
    /// Fill value used when `strategy` is "constant".
    pub fill_value: Option<f64>,
    /// Whether to operate on a copy of the input rather than in place.
    pub copy: bool,
}
64
/// Configuration for KNN imputation
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct KNNImputationConfig {
    /// Number of neighboring samples used for each imputation.
    pub n_neighbors: usize,
    /// Neighbor weighting scheme: "uniform" or "distance".
    pub weights: String,
    /// Distance metric name, e.g. "euclidean" or "manhattan".
    pub metric: String,
    /// Whether to append a missing-value indicator to the output.
    pub add_indicator: bool,
}
78
/// Configuration for iterative imputation (MICE)
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct IterativeImputationConfig {
    /// Maximum number of imputation rounds.
    pub max_iter: usize,
    /// Convergence tolerance used to stop early.
    pub tol: f64,
    /// Optional cap on the number of features used to predict each column.
    pub n_nearest_features: Option<usize>,
    /// Sample from the posterior predictive distribution instead of the mean.
    pub sample_posterior: bool,
    /// Seed for reproducible randomness.
    pub random_state: Option<u64>,
}
94
/// Configuration for Gaussian Process imputation
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct GaussianProcessConfig {
    /// Kernel name, e.g. "rbf" or "matern".
    pub kernel: String,
    /// Noise/jitter term added for numerical stability.
    pub alpha: f64,
    /// Number of optimizer restarts when fitting kernel hyperparameters.
    pub n_restarts_optimizer: usize,
    /// Whether to normalize target values before fitting.
    pub normalize_y: bool,
    /// Seed for reproducible randomness.
    pub random_state: Option<u64>,
}
110
/// Configuration for Matrix Factorization imputation
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MatrixFactorizationConfig {
    /// Number of latent components (factorization rank).
    pub n_components: usize,
    /// Maximum number of optimization iterations.
    pub max_iter: usize,
    /// Convergence tolerance used to stop early.
    pub tol: f64,
    /// Regularization strength applied during factorization.
    pub regularization: f64,
    /// Seed for reproducible randomness.
    pub random_state: Option<u64>,
}
126
/// Configuration for Bayesian imputation
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct BayesianImputationConfig {
    /// Number of imputed datasets to draw (multiple imputation).
    pub n_imputations: usize,
    /// Maximum number of sampling iterations.
    pub max_iter: usize,
    /// Number of initial samples discarded as burn-in.
    pub burn_in: usize,
    /// Variance of the prior distribution.
    pub prior_variance: f64,
    /// Seed for reproducible randomness.
    pub random_state: Option<u64>,
}
142
/// Configuration for ensemble methods
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct EnsembleImputationConfig {
    /// Ensemble algorithm: "random_forest", "gradient_boosting", "extra_trees", etc.
    pub method: String,
    /// Number of estimators in the ensemble.
    pub n_estimators: usize,
    /// Optional maximum depth of each estimator.
    pub max_depth: Option<usize>,
    /// Seed for reproducible randomness.
    pub random_state: Option<u64>,
}
156
/// Configuration for deep learning methods
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct DeepLearningConfig {
    /// Model family: "autoencoder", "vae", or "gan".
    pub method: String,
    /// Sizes of the hidden layers, outermost first.
    pub hidden_dims: Vec<usize>,
    /// Optimizer learning rate.
    pub learning_rate: f64,
    /// Number of training epochs.
    pub epochs: usize,
    /// Mini-batch size used during training.
    pub batch_size: usize,
    /// Compute device: "cpu" or "cuda".
    pub device: String,
}
174
/// Configuration for custom imputation methods
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CustomImputationConfig {
    /// Identifier of the custom method.
    pub name: String,
    /// Arbitrary method parameters as JSON values (serde enabled).
    #[cfg(feature = "serde")]
    pub parameters: HashMap<String, serde_json::Value>,
    /// Arbitrary method parameters as plain strings (serde disabled).
    #[cfg(not(feature = "serde"))]
    pub parameters: HashMap<String, String>,
}
186
/// Validation configuration
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ValidationConfig {
    /// Whether to run cross-validation of imputation quality.
    pub cross_validation: bool,
    /// Number of cross-validation folds.
    pub cv_folds: usize,
    /// Optional fraction of data to hold out for validation.
    pub holdout_fraction: Option<f64>,
    /// Metric names to compute, e.g. "rmse", "mae".
    pub metrics: Vec<String>,
    /// Synthetic missingness patterns to evaluate against, e.g. "mcar".
    pub synthetic_missing_patterns: Vec<String>,
}
202
/// Preprocessing configuration
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PreprocessingConfig {
    /// Center features to zero mean before imputation.
    pub normalize: bool,
    /// Scale features to unit variance before imputation.
    pub scale: bool,
    /// Drop features whose values are constant.
    pub remove_constant_features: bool,
    /// Apply outlier handling before imputation.
    pub handle_outliers: bool,
    /// Outlier-detection method name, e.g. "iqr".
    pub outlier_method: String,
}
218
/// Postprocessing configuration
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PostprocessingConfig {
    /// Optional `(min, max)` range to clamp imputed values into.
    pub clip_values: Option<(f64, f64)>,
    /// Round every value to the nearest integer after imputation.
    pub round_integers: bool,
    /// Preserve the original column dtypes in the output.
    pub preserve_dtypes: bool,
    /// Attach uncertainty estimates to the imputations.
    pub add_uncertainty_estimates: bool,
}
232
/// Predefined configuration presets
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ImputationPreset {
    /// Fastest option: simple mean imputation.
    Fast,
    /// Balanced speed/quality: KNN with 5 uniform-weighted neighbors.
    Balanced,
    /// Highest quality: iterative imputation with posterior sampling,
    /// cross-validation, and uncertainty estimates.
    HighQuality,
    /// Low memory footprint: in-place median imputation plus
    /// constant-feature removal.
    Memory,
    /// Parallel KNN imputation with distance weighting.
    Parallel,
    /// Research-oriented: Bayesian multiple imputation, fixed seed,
    /// 10-fold cross-validation with an extended metric set.
    Academic,
    /// Production-oriented: seeded random-forest ensemble with
    /// validation and outlier handling.
    Production,
}
251
impl Default for ImputationBuilder {
    /// Equivalent to [`ImputationBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
257
258impl ImputationBuilder {
259    /// Create a new imputation builder with default settings
260    pub fn new() -> Self {
261        Self {
262            method: ImputationMethod::Simple(SimpleImputationConfig {
263                strategy: "mean".to_string(),
264                fill_value: None,
265                copy: true,
266            }),
267            validation: ValidationConfig {
268                cross_validation: false,
269                cv_folds: 5,
270                holdout_fraction: None,
271                metrics: vec!["rmse".to_string()],
272                synthetic_missing_patterns: vec!["mcar".to_string()],
273            },
274            preprocessing: PreprocessingConfig {
275                normalize: false,
276                scale: false,
277                remove_constant_features: false,
278                handle_outliers: false,
279                outlier_method: "iqr".to_string(),
280            },
281            postprocessing: PostprocessingConfig {
282                clip_values: None,
283                round_integers: false,
284                preserve_dtypes: true,
285                add_uncertainty_estimates: false,
286            },
287            parallel_config: None,
288        }
289    }
290
291    /// Apply a predefined configuration preset
292    pub fn preset(mut self, preset: ImputationPreset) -> Self {
293        match preset {
294            ImputationPreset::Fast => {
295                self.method = ImputationMethod::Simple(SimpleImputationConfig {
296                    strategy: "mean".to_string(),
297                    fill_value: None,
298                    copy: true,
299                });
300            }
301            ImputationPreset::Balanced => {
302                self.method = ImputationMethod::KNN(KNNImputationConfig {
303                    n_neighbors: 5,
304                    weights: "uniform".to_string(),
305                    metric: "euclidean".to_string(),
306                    add_indicator: false,
307                });
308            }
309            ImputationPreset::HighQuality => {
310                self.method = ImputationMethod::Iterative(IterativeImputationConfig {
311                    max_iter: 10,
312                    tol: 1e-3,
313                    n_nearest_features: None,
314                    sample_posterior: true,
315                    random_state: None,
316                });
317                self.validation.cross_validation = true;
318                self.postprocessing.add_uncertainty_estimates = true;
319            }
320            ImputationPreset::Memory => {
321                self.method = ImputationMethod::Simple(SimpleImputationConfig {
322                    strategy: "median".to_string(),
323                    fill_value: None,
324                    copy: false,
325                });
326                self.preprocessing.remove_constant_features = true;
327            }
328            ImputationPreset::Parallel => {
329                self.method = ImputationMethod::KNN(KNNImputationConfig {
330                    n_neighbors: 3,
331                    weights: "distance".to_string(),
332                    metric: "euclidean".to_string(),
333                    add_indicator: false,
334                });
335                self.parallel_config = Some(ParallelConfig::default());
336            }
337            ImputationPreset::Academic => {
338                self.method = ImputationMethod::Bayesian(BayesianImputationConfig {
339                    n_imputations: 5,
340                    max_iter: 100,
341                    burn_in: 20,
342                    prior_variance: 1.0,
343                    random_state: Some(42),
344                });
345                self.validation.cross_validation = true;
346                self.validation.cv_folds = 10;
347                self.validation.metrics = vec![
348                    "rmse".to_string(),
349                    "mae".to_string(),
350                    "bias".to_string(),
351                    "coverage".to_string(),
352                ];
353                self.postprocessing.add_uncertainty_estimates = true;
354            }
355            ImputationPreset::Production => {
356                self.method = ImputationMethod::Ensemble(EnsembleImputationConfig {
357                    method: "random_forest".to_string(),
358                    n_estimators: 100,
359                    max_depth: Some(10),
360                    random_state: Some(42),
361                });
362                self.validation.cross_validation = true;
363                self.preprocessing.handle_outliers = true;
364                self.postprocessing.preserve_dtypes = true;
365            }
366        }
367        self
368    }
369
370    /// Configure simple imputation
371    pub fn simple(self) -> SimpleImputationBuilder {
372        SimpleImputationBuilder::new(self)
373    }
374
375    /// Configure KNN imputation
376    pub fn knn(self) -> KNNImputationBuilder {
377        KNNImputationBuilder::new(self)
378    }
379
380    /// Configure iterative imputation
381    pub fn iterative(self) -> IterativeImputationBuilder {
382        IterativeImputationBuilder::new(self)
383    }
384
385    /// Configure Gaussian Process imputation
386    pub fn gaussian_process(self) -> GaussianProcessBuilder {
387        GaussianProcessBuilder::new(self)
388    }
389
390    /// Configure ensemble imputation
391    pub fn ensemble(self) -> EnsembleImputationBuilder {
392        EnsembleImputationBuilder::new(self)
393    }
394
395    /// Configure deep learning imputation
396    pub fn deep_learning(self) -> DeepLearningBuilder {
397        DeepLearningBuilder::new(self)
398    }
399
400    /// Enable parallel processing
401    pub fn parallel(mut self, config: Option<ParallelConfig>) -> Self {
402        self.parallel_config = config.or_else(|| Some(ParallelConfig::default()));
403        self
404    }
405
406    /// Configure validation
407    pub fn validation(mut self, config: ValidationConfig) -> Self {
408        self.validation = config;
409        self
410    }
411
412    /// Enable cross-validation
413    pub fn cross_validate(mut self, folds: usize) -> Self {
414        self.validation.cross_validation = true;
415        self.validation.cv_folds = folds;
416        self
417    }
418
419    /// Configure preprocessing
420    pub fn preprocessing(mut self, config: PreprocessingConfig) -> Self {
421        self.preprocessing = config;
422        self
423    }
424
425    /// Enable normalization
426    pub fn normalize(mut self) -> Self {
427        self.preprocessing.normalize = true;
428        self
429    }
430
431    /// Enable scaling
432    pub fn scale(mut self) -> Self {
433        self.preprocessing.scale = true;
434        self
435    }
436
437    /// Configure postprocessing
438    pub fn postprocessing(mut self, config: PostprocessingConfig) -> Self {
439        self.postprocessing = config;
440        self
441    }
442
443    /// Enable uncertainty estimation
444    pub fn with_uncertainty(mut self) -> Self {
445        self.postprocessing.add_uncertainty_estimates = true;
446        self
447    }
448
449    /// Build the imputation pipeline
450    pub fn build(self) -> SklResult<ImputationPipeline> {
451        ImputationPipeline::new(
452            self.method,
453            self.validation,
454            self.preprocessing,
455            self.postprocessing,
456            self.parallel_config,
457        )
458    }
459}
460
461/// Builder for simple imputation configuration
462pub struct SimpleImputationBuilder {
463    builder: ImputationBuilder,
464    config: SimpleImputationConfig,
465}
466
467impl SimpleImputationBuilder {
468    fn new(builder: ImputationBuilder) -> Self {
469        Self {
470            builder,
471            config: SimpleImputationConfig {
472                strategy: "mean".to_string(),
473                fill_value: None,
474                copy: true,
475            },
476        }
477    }
478
479    pub fn strategy(mut self, strategy: &str) -> Self {
480        self.config.strategy = strategy.to_string();
481        self
482    }
483
484    pub fn mean(mut self) -> Self {
485        self.config.strategy = "mean".to_string();
486        self
487    }
488
489    pub fn median(mut self) -> Self {
490        self.config.strategy = "median".to_string();
491        self
492    }
493
494    pub fn mode(mut self) -> Self {
495        self.config.strategy = "most_frequent".to_string();
496        self
497    }
498
499    pub fn constant(mut self, value: f64) -> Self {
500        self.config.strategy = "constant".to_string();
501        self.config.fill_value = Some(value);
502        self
503    }
504
505    pub fn finish(mut self) -> ImputationBuilder {
506        self.builder.method = ImputationMethod::Simple(self.config);
507        self.builder
508    }
509}
510
511/// Builder for KNN imputation configuration
512pub struct KNNImputationBuilder {
513    builder: ImputationBuilder,
514    config: KNNImputationConfig,
515}
516
517impl KNNImputationBuilder {
518    fn new(builder: ImputationBuilder) -> Self {
519        Self {
520            builder,
521            config: KNNImputationConfig {
522                n_neighbors: 5,
523                weights: "uniform".to_string(),
524                metric: "euclidean".to_string(),
525                add_indicator: false,
526            },
527        }
528    }
529
530    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
531        self.config.n_neighbors = n_neighbors;
532        self
533    }
534
535    pub fn weights(mut self, weights: &str) -> Self {
536        self.config.weights = weights.to_string();
537        self
538    }
539
540    pub fn uniform_weights(mut self) -> Self {
541        self.config.weights = "uniform".to_string();
542        self
543    }
544
545    pub fn distance_weights(mut self) -> Self {
546        self.config.weights = "distance".to_string();
547        self
548    }
549
550    pub fn metric(mut self, metric: &str) -> Self {
551        self.config.metric = metric.to_string();
552        self
553    }
554
555    pub fn euclidean(mut self) -> Self {
556        self.config.metric = "euclidean".to_string();
557        self
558    }
559
560    pub fn manhattan(mut self) -> Self {
561        self.config.metric = "manhattan".to_string();
562        self
563    }
564
565    pub fn add_indicator(mut self, add_indicator: bool) -> Self {
566        self.config.add_indicator = add_indicator;
567        self
568    }
569
570    pub fn finish(mut self) -> ImputationBuilder {
571        self.builder.method = ImputationMethod::KNN(self.config);
572        self.builder
573    }
574}
575
576/// Builder for iterative imputation configuration
577pub struct IterativeImputationBuilder {
578    builder: ImputationBuilder,
579    config: IterativeImputationConfig,
580}
581
582impl IterativeImputationBuilder {
583    fn new(builder: ImputationBuilder) -> Self {
584        Self {
585            builder,
586            config: IterativeImputationConfig {
587                max_iter: 10,
588                tol: 1e-3,
589                n_nearest_features: None,
590                sample_posterior: false,
591                random_state: None,
592            },
593        }
594    }
595
596    pub fn max_iter(mut self, max_iter: usize) -> Self {
597        self.config.max_iter = max_iter;
598        self
599    }
600
601    pub fn tolerance(mut self, tol: f64) -> Self {
602        self.config.tol = tol;
603        self
604    }
605
606    pub fn n_nearest_features(mut self, n_features: usize) -> Self {
607        self.config.n_nearest_features = Some(n_features);
608        self
609    }
610
611    pub fn sample_posterior(mut self, sample: bool) -> Self {
612        self.config.sample_posterior = sample;
613        self
614    }
615
616    pub fn random_state(mut self, seed: u64) -> Self {
617        self.config.random_state = Some(seed);
618        self
619    }
620
621    pub fn finish(mut self) -> ImputationBuilder {
622        self.builder.method = ImputationMethod::Iterative(self.config);
623        self.builder
624    }
625}
626
627/// Builder for Gaussian Process imputation configuration
628pub struct GaussianProcessBuilder {
629    builder: ImputationBuilder,
630    config: GaussianProcessConfig,
631}
632
633impl GaussianProcessBuilder {
634    fn new(builder: ImputationBuilder) -> Self {
635        Self {
636            builder,
637            config: GaussianProcessConfig {
638                kernel: "rbf".to_string(),
639                alpha: 1e-6,
640                n_restarts_optimizer: 0,
641                normalize_y: false,
642                random_state: None,
643            },
644        }
645    }
646
647    pub fn kernel(mut self, kernel: &str) -> Self {
648        self.config.kernel = kernel.to_string();
649        self
650    }
651
652    pub fn rbf_kernel(mut self) -> Self {
653        self.config.kernel = "rbf".to_string();
654        self
655    }
656
657    pub fn matern_kernel(mut self) -> Self {
658        self.config.kernel = "matern".to_string();
659        self
660    }
661
662    pub fn alpha(mut self, alpha: f64) -> Self {
663        self.config.alpha = alpha;
664        self
665    }
666
667    pub fn n_restarts(mut self, n_restarts: usize) -> Self {
668        self.config.n_restarts_optimizer = n_restarts;
669        self
670    }
671
672    pub fn normalize_y(mut self, normalize: bool) -> Self {
673        self.config.normalize_y = normalize;
674        self
675    }
676
677    pub fn random_state(mut self, seed: u64) -> Self {
678        self.config.random_state = Some(seed);
679        self
680    }
681
682    pub fn finish(mut self) -> ImputationBuilder {
683        self.builder.method = ImputationMethod::GaussianProcess(self.config);
684        self.builder
685    }
686}
687
688/// Builder for ensemble imputation configuration
689pub struct EnsembleImputationBuilder {
690    builder: ImputationBuilder,
691    config: EnsembleImputationConfig,
692}
693
694impl EnsembleImputationBuilder {
695    fn new(builder: ImputationBuilder) -> Self {
696        Self {
697            builder,
698            config: EnsembleImputationConfig {
699                method: "random_forest".to_string(),
700                n_estimators: 100,
701                max_depth: None,
702                random_state: None,
703            },
704        }
705    }
706
707    pub fn random_forest(mut self) -> Self {
708        self.config.method = "random_forest".to_string();
709        self
710    }
711
712    pub fn gradient_boosting(mut self) -> Self {
713        self.config.method = "gradient_boosting".to_string();
714        self
715    }
716
717    pub fn extra_trees(mut self) -> Self {
718        self.config.method = "extra_trees".to_string();
719        self
720    }
721
722    pub fn n_estimators(mut self, n_estimators: usize) -> Self {
723        self.config.n_estimators = n_estimators;
724        self
725    }
726
727    pub fn max_depth(mut self, max_depth: usize) -> Self {
728        self.config.max_depth = Some(max_depth);
729        self
730    }
731
732    pub fn random_state(mut self, seed: u64) -> Self {
733        self.config.random_state = Some(seed);
734        self
735    }
736
737    pub fn finish(mut self) -> ImputationBuilder {
738        self.builder.method = ImputationMethod::Ensemble(self.config);
739        self.builder
740    }
741}
742
743/// Builder for deep learning imputation configuration
744pub struct DeepLearningBuilder {
745    builder: ImputationBuilder,
746    config: DeepLearningConfig,
747}
748
749impl DeepLearningBuilder {
750    fn new(builder: ImputationBuilder) -> Self {
751        Self {
752            builder,
753            config: DeepLearningConfig {
754                method: "autoencoder".to_string(),
755                hidden_dims: vec![128, 64, 32],
756                learning_rate: 0.001,
757                epochs: 100,
758                batch_size: 32,
759                device: "cpu".to_string(),
760            },
761        }
762    }
763
764    pub fn autoencoder(mut self) -> Self {
765        self.config.method = "autoencoder".to_string();
766        self
767    }
768
769    pub fn vae(mut self) -> Self {
770        self.config.method = "vae".to_string();
771        self
772    }
773
774    pub fn gan(mut self) -> Self {
775        self.config.method = "gan".to_string();
776        self
777    }
778
779    pub fn hidden_dims(mut self, dims: Vec<usize>) -> Self {
780        self.config.hidden_dims = dims;
781        self
782    }
783
784    pub fn learning_rate(mut self, lr: f64) -> Self {
785        self.config.learning_rate = lr;
786        self
787    }
788
789    pub fn epochs(mut self, epochs: usize) -> Self {
790        self.config.epochs = epochs;
791        self
792    }
793
794    pub fn batch_size(mut self, batch_size: usize) -> Self {
795        self.config.batch_size = batch_size;
796        self
797    }
798
799    pub fn device(mut self, device: &str) -> Self {
800        self.config.device = device.to_string();
801        self
802    }
803
804    pub fn finish(mut self) -> ImputationBuilder {
805        self.builder.method = ImputationMethod::DeepLearning(self.config);
806        self.builder
807    }
808}
809
810/// Main imputation pipeline that handles the complete workflow
811pub struct ImputationPipeline {
812    method: ImputationMethod,
813    validation: ValidationConfig,
814    preprocessing: PreprocessingConfig,
815    postprocessing: PostprocessingConfig,
816    parallel_config: Option<ParallelConfig>,
817}
818
819impl ImputationPipeline {
820    fn new(
821        method: ImputationMethod,
822        validation: ValidationConfig,
823        preprocessing: PreprocessingConfig,
824        postprocessing: PostprocessingConfig,
825        parallel_config: Option<ParallelConfig>,
826    ) -> SklResult<Self> {
827        Ok(Self {
828            method,
829            validation,
830            preprocessing,
831            postprocessing,
832            parallel_config,
833        })
834    }
835
836    /// Fit and transform in one step
837    pub fn fit_transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
838        // Complete pipeline with preprocessing, imputation, and postprocessing
839
840        // Step 1: Preprocessing
841        let X_preprocessed = X.to_owned();
842        let (means, stds) = self.apply_preprocessing(&X_preprocessed)?;
843
844        // Step 2: Imputation
845        let X_imputed = match &self.method {
846            ImputationMethod::Simple(config) => {
847                let imputer = SimpleImputer::new().strategy(config.strategy.clone());
848                let fitted = imputer.fit(X, &())?;
849                fitted.transform(X)
850            }
851            ImputationMethod::KNN(config) => {
852                if let Some(parallel_config) = &self.parallel_config {
853                    let imputer = ParallelKNNImputer::new()
854                        .n_neighbors(config.n_neighbors)
855                        .weights(config.weights.clone())
856                        .metric(config.metric.clone())
857                        .parallel_config(parallel_config.clone());
858                    let fitted = imputer.fit(X, &())?;
859                    fitted.transform(X)
860                } else {
861                    let imputer = KNNImputer::new()
862                        .n_neighbors(config.n_neighbors)
863                        .weights(config.weights.clone())
864                        .metric(config.metric.clone());
865                    let fitted = imputer.fit(X, &())?;
866                    fitted.transform(X)
867                }
868            }
869            _ => {
870                // For other methods, fall back to simple imputation for now
871                let imputer = SimpleImputer::new().strategy("mean".to_string());
872                let fitted = imputer.fit(&X_preprocessed.view(), &())?;
873                fitted.transform(&X_preprocessed.view())
874            }
875        }?;
876
877        // Step 3: Postprocessing
878        let X_final = self.apply_postprocessing(X_imputed, &means, &stds)?;
879
880        Ok(X_final)
881    }
882
883    /// Apply preprocessing transformations
884    fn apply_preprocessing(&self, X: &Array2<Float>) -> PreprocessingResult {
885        let mut X_proc = X.clone();
886        let mut means = None;
887        let mut stds = None;
888
889        // Normalize or scale if requested
890        if self.preprocessing.normalize || self.preprocessing.scale {
891            let (n_samples, n_features) = X_proc.dim();
892            let mut feature_means = Array1::zeros(n_features);
893            let mut feature_stds = Array1::ones(n_features);
894
895            for j in 0..n_features {
896                let col = X_proc.column(j);
897                let valid_values: Vec<Float> =
898                    col.iter().filter(|x| x.is_finite()).copied().collect();
899
900                if !valid_values.is_empty() {
901                    let mean = valid_values.iter().sum::<Float>() / valid_values.len() as Float;
902                    feature_means[j] = mean;
903
904                    if self.preprocessing.scale {
905                        let variance = valid_values
906                            .iter()
907                            .map(|x| (x - mean).powi(2))
908                            .sum::<Float>()
909                            / valid_values.len() as Float;
910                        feature_stds[j] = variance.sqrt().max(1e-8);
911                    }
912                }
913            }
914
915            // Apply normalization/scaling
916            for j in 0..n_features {
917                for i in 0..n_samples {
918                    if X_proc[[i, j]].is_finite() {
919                        X_proc[[i, j]] = (X_proc[[i, j]] - feature_means[j]) / feature_stds[j];
920                    }
921                }
922            }
923
924            means = Some(feature_means);
925            stds = Some(feature_stds);
926        }
927
928        Ok((means, stds))
929    }
930
931    /// Apply postprocessing transformations
932    fn apply_postprocessing(
933        &self,
934        mut X: Array2<Float>,
935        means: &Option<Array1<Float>>,
936        stds: &Option<Array1<Float>>,
937    ) -> SklResult<Array2<Float>> {
938        let (n_samples, n_features) = X.dim();
939
940        // Reverse normalization/scaling if it was applied
941        if let (Some(means_arr), Some(stds_arr)) = (means, stds) {
942            for j in 0..n_features {
943                for i in 0..n_samples {
944                    X[[i, j]] = X[[i, j]] * stds_arr[j] + means_arr[j];
945                }
946            }
947        }
948
949        // Clip values if requested
950        if let Some((min_val, max_val)) = self.postprocessing.clip_values {
951            for value in X.iter_mut() {
952                *value = value.clamp(min_val, max_val);
953            }
954        }
955
956        // Round to integers if requested
957        if self.postprocessing.round_integers {
958            for value in X.iter_mut() {
959                *value = value.round();
960            }
961        }
962
963        Ok(X)
964    }
965
966    /// Get configuration as JSON
967    #[cfg(feature = "serde")]
968    pub fn to_json(&self) -> SklResult<String> {
969        #[derive(serde::Serialize)]
970        struct PipelineConfig<'a> {
971            method: &'a ImputationMethod,
972            validation: &'a ValidationConfig,
973            preprocessing: &'a PreprocessingConfig,
974            postprocessing: &'a PostprocessingConfig,
975            parallel_config: &'a Option<ParallelConfig>,
976        }
977
978        let config = PipelineConfig {
979            method: &self.method,
980            validation: &self.validation,
981            preprocessing: &self.preprocessing,
982            postprocessing: &self.postprocessing,
983            parallel_config: &self.parallel_config,
984        };
985
986        serde_json::to_string_pretty(&config).map_err(|e| {
987            SklearsError::SerializationError(format!("Failed to serialize config: {}", e))
988        })
989    }
990
991    /// Get configuration as JSON (disabled without serde feature)
992    #[cfg(not(feature = "serde"))]
993    pub fn to_json(&self) -> SklResult<String> {
994        Err(SklearsError::NotImplemented(
995            "to_json requires serde feature".to_string(),
996        ))
997    }
998
999    /// Load configuration from JSON
1000    pub fn from_json(_json: &str) -> SklResult<Self> {
1001        // This would need a proper deserialization implementation
1002        Err(SklearsError::NotImplemented(
1003            "from_json not yet implemented".to_string(),
1004        ))
1005    }
1006}
1007
1008/// Convenience functions for quick imputation
1009pub mod quick {
1010    use super::*;
1011
1012    /// Quick mean imputation
1013    pub fn mean_impute(X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
1014        ImputationBuilder::new()
1015            .simple()
1016            .mean()
1017            .finish()
1018            .build()?
1019            .fit_transform(X)
1020    }
1021
1022    /// Quick KNN imputation
1023    pub fn knn_impute(X: &ArrayView2<'_, Float>, n_neighbors: usize) -> SklResult<Array2<Float>> {
1024        ImputationBuilder::new()
1025            .knn()
1026            .n_neighbors(n_neighbors)
1027            .finish()
1028            .build()?
1029            .fit_transform(X)
1030    }
1031
1032    /// Quick parallel KNN imputation
1033    pub fn parallel_knn_impute(
1034        X: &ArrayView2<'_, Float>,
1035        n_neighbors: usize,
1036    ) -> SklResult<Array2<Float>> {
1037        ImputationBuilder::new()
1038            .knn()
1039            .n_neighbors(n_neighbors)
1040            .finish()
1041            .parallel(None)
1042            .build()?
1043            .fit_transform(X)
1044    }
1045
1046    /// Quick iterative imputation
1047    pub fn iterative_impute(X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
1048        ImputationBuilder::new()
1049            .iterative()
1050            .max_iter(10)
1051            .finish()
1052            .build()?
1053            .fit_transform(X)
1054    }
1055
1056    /// Quick high-quality imputation with validation
1057    pub fn high_quality_impute(X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
1058        ImputationBuilder::new()
1059            .preset(ImputationPreset::HighQuality)
1060            .build()?
1061            .fit_transform(X)
1062    }
1063}
1064
1065/// Trait-based pluggable architecture for imputation modules
1066pub mod pluggable {
1067    use super::*;
1068
    /// Core trait for all imputation modules
    ///
    /// A module is a factory plus metadata: it reports whether it can handle
    /// a dataset and creates configured [`ImputationInstance`]s on demand.
    pub trait ImputationModule: Send + Sync {
        /// Get the name of this imputation module (used as the registry key)
        fn name(&self) -> &str;

        /// Get the version of this module
        fn version(&self) -> &str;

        /// Check if this module can handle the given data characteristics
        fn can_handle(&self, data_info: &DataCharacteristics) -> bool;

        /// Get module-specific configuration schema
        fn config_schema(&self) -> ModuleConfigSchema;

        /// Create an instance of this module with given configuration
        fn create_instance(&self, config: &ModuleConfig) -> SklResult<Box<dyn ImputationInstance>>;

        /// Get module dependencies (names of other modules; empty by default)
        fn dependencies(&self) -> Vec<&str> {
            vec![]
        }

        /// Get module priority (higher = preferred when several modules match;
        /// defaults to 0)
        fn priority(&self) -> i32 {
            0
        }
    }
1096
    /// Trait for actual imputation instances
    ///
    /// A fittable imputer created by an [`ImputationModule`]. Optional
    /// capabilities (uncertainty, feature importance, streaming) have
    /// conservative defaults paired with `supports_*` probe methods.
    pub trait ImputationInstance: Send + Sync {
        /// Fit the imputation model
        fn fit(&mut self, X: &ArrayView2<Float>) -> SklResult<()>;

        /// Transform data using the fitted model
        fn transform(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>>;

        /// Fit and transform in one step (default: `fit` then `transform`)
        fn fit_transform(&mut self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
            self.fit(X)?;
            self.transform(X)
        }

        /// Get uncertainty estimates if supported
        /// (default: plain transform with `None` for the uncertainties)
        fn transform_with_uncertainty(
            &self,
            X: &ArrayView2<Float>,
        ) -> SklResult<(Array2<Float>, Option<Array2<Float>>)> {
            let result = self.transform(X)?;
            Ok((result, None))
        }

        /// Check if this instance supports uncertainty quantification
        fn supports_uncertainty(&self) -> bool {
            false
        }

        /// Get feature importance if supported (default: `None`)
        fn feature_importance(&self) -> Option<Array1<Float>> {
            None
        }

        /// Partial fit for streaming data (default: `NotImplemented` error)
        fn partial_fit(&mut self, _X: &ArrayView2<Float>) -> SklResult<()> {
            Err(SklearsError::NotImplemented(
                "Partial fit not supported".to_string(),
            ))
        }

        /// Check if partial fit is supported
        fn supports_partial_fit(&self) -> bool {
            false
        }
    }
1142
    /// Data characteristics for module selection
    ///
    /// Summarizes a dataset so modules can decide (via
    /// [`ImputationModule::can_handle`]) whether they are applicable.
    #[derive(Debug, Clone)]
    pub struct DataCharacteristics {
        /// Number of rows (observations) in the dataset
        pub n_samples: usize,
        /// Number of columns (features) in the dataset
        pub n_features: usize,
        /// Fraction of entries that are missing, in `[0, 1]`
        pub missing_rate: f64,
        /// Detected pattern of missingness (MCAR/MAR/MNAR/…)
        pub missing_pattern: MissingPatternType,
        /// Per-feature data types
        pub data_types: Vec<DataType>,
        /// Whether any feature is categorical
        pub has_categorical: bool,
        /// Whether any feature is temporal (time-ordered)
        pub has_temporal: bool,
        /// Whether the data should be treated as sparse
        pub is_sparse: bool,
        /// Optional upper bound on memory usage, in bytes
        pub memory_constraints: Option<usize>, // Max memory in bytes
    }
1165
    /// Missing pattern types
    #[derive(Debug, Clone, PartialEq)]
    pub enum MissingPatternType {
        /// Missing Completely At Random: missingness is independent of the data
        MCAR,
        /// Missing At Random: missingness depends only on observed values
        MAR,
        /// Missing Not At Random: missingness depends on the unobserved values
        MNAR,
        /// Pattern could not be determined
        Unknown,
        /// Missing values occur in contiguous blocks
        Block,
        /// Monotone pattern: missingness follows a nested/ordered structure
        Monotone,
    }
1182
    /// Data types for features
    #[derive(Debug, Clone, PartialEq)]
    pub enum DataType {
        /// Real-valued continuous feature
        Continuous,
        /// Unordered categorical feature
        Categorical,
        /// Ordered categorical feature
        Ordinal,
        /// Two-valued feature
        Binary,
        /// Non-negative integer count feature
        Count,
        /// Time-ordered feature (timestamps, time series)
        Temporal,
        /// Free-form text feature
        Text,
    }
1201
    /// Module configuration schema
    ///
    /// Describes the parameters a module accepts, which of them are
    /// mandatory, and how they are grouped for presentation.
    #[derive(Debug, Clone)]
    pub struct ModuleConfigSchema {
        /// Schema for each parameter, keyed by parameter name
        pub parameters: HashMap<String, ParameterSchema>,
        /// Names of parameters that must be supplied
        pub required_parameters: Vec<String>,
        /// Logical groupings of parameters (e.g. for UI display)
        pub parameter_groups: Vec<ParameterGroup>,
    }
1212
    /// Parameter schema definition
    #[derive(Debug, Clone)]
    pub struct ParameterSchema {
        /// Parameter name
        pub name: String,
        /// Expected type of the parameter value
        pub parameter_type: ParameterType,
        /// Default value (JSON form when the `serde` feature is enabled)
        #[cfg(feature = "serde")]
        pub default_value: Option<serde_json::Value>,
        /// Default value (string form without the `serde` feature)
        #[cfg(not(feature = "serde"))]
        pub default_value: Option<String>,
        /// Constraint on acceptable values, if any
        pub valid_range: Option<ParameterRange>,
        /// Human-readable description of the parameter
        pub description: String,
        /// Names of other parameters this one depends on
        pub dependencies: Vec<String>,
    }
1231
    /// Parameter types
    #[derive(Debug, Clone)]
    pub enum ParameterType {
        /// Whole-number value
        Integer,
        /// Floating-point value
        Float,
        /// Text value
        String,
        /// True/false value
        Boolean,
        /// Homogeneous list of the given element type
        Array(Box<ParameterType>),
        /// One of a fixed set of string choices
        Enum(Vec<String>),
        /// Nested key/value structure with typed fields
        Object(HashMap<String, ParameterType>),
    }
1250
    /// Parameter value ranges
    #[derive(Debug, Clone)]
    pub enum ParameterRange {
        /// Integer bounds; `None` leaves that side unbounded
        IntRange { min: Option<i64>, max: Option<i64> },
        /// Float bounds; `None` leaves that side unbounded
        FloatRange { min: Option<f64>, max: Option<f64> },
        /// String must match this regex pattern
        StringPattern(String), // regex pattern
        /// Bounds on array length; `None` leaves that side unbounded
        ArrayLength {
            min: Option<usize>,
            max: Option<usize>,
        },
    }
1266
    /// Parameter groups for UI organization
    #[derive(Debug, Clone)]
    pub struct ParameterGroup {
        /// Group display name
        pub name: String,
        /// Human-readable description of the group
        pub description: String,
        /// Names of the parameters that belong to this group
        pub parameters: Vec<String>,
        /// Whether the whole group may be omitted
        pub optional: bool,
    }
1279
    /// Module configuration
    ///
    /// Concrete parameter values passed to
    /// [`ImputationModule::create_instance`].
    #[derive(Debug, Clone)]
    pub struct ModuleConfig {
        /// Parameter values keyed by name (JSON values with the `serde` feature)
        #[cfg(feature = "serde")]
        pub parameters: HashMap<String, serde_json::Value>,
        /// Parameter values keyed by name (string form without `serde`)
        #[cfg(not(feature = "serde"))]
        pub parameters: HashMap<String, String>,
    }
1288
    /// Registry for managing imputation modules
    ///
    /// Holds registered modules keyed by name, plus an alias table mapping
    /// alternative names onto registered module names.
    pub struct ModuleRegistry {
        // Registered modules, keyed by `ImputationModule::name()`.
        modules: HashMap<String, Box<dyn ImputationModule>>,
        // Alias -> canonical module name.
        aliases: HashMap<String, String>,
    }
1294
    impl Default for ModuleRegistry {
        /// Equivalent to [`ModuleRegistry::new`]: an empty registry.
        fn default() -> Self {
            Self::new()
        }
    }
1300
1301    impl ModuleRegistry {
1302        pub fn new() -> Self {
1303            Self {
1304                modules: HashMap::new(),
1305                aliases: HashMap::new(),
1306            }
1307        }
1308
1309        /// Register a new imputation module
1310        pub fn register_module(&mut self, module: Box<dyn ImputationModule>) -> SklResult<()> {
1311            let name = module.name().to_string();
1312            if self.modules.contains_key(&name) {
1313                return Err(SklearsError::InvalidInput(format!(
1314                    "Module '{}' already registered",
1315                    name
1316                )));
1317            }
1318            self.modules.insert(name, module);
1319            Ok(())
1320        }
1321
1322        /// Register an alias for a module
1323        pub fn register_alias(&mut self, alias: String, module_name: String) -> SklResult<()> {
1324            if !self.modules.contains_key(&module_name) {
1325                return Err(SklearsError::InvalidInput(format!(
1326                    "Module '{}' not found",
1327                    module_name
1328                )));
1329            }
1330            self.aliases.insert(alias, module_name);
1331            Ok(())
1332        }
1333
1334        /// Get a module by name or alias
1335        pub fn get_module(&self, name: &str) -> Option<&dyn ImputationModule> {
1336            if let Some(actual_name) = self.aliases.get(name) {
1337                self.modules.get(actual_name).map(|m| m.as_ref())
1338            } else {
1339                self.modules.get(name).map(|m| m.as_ref())
1340            }
1341        }
1342
1343        /// List all available modules
1344        pub fn list_modules(&self) -> Vec<&str> {
1345            self.modules.keys().map(|s| s.as_str()).collect()
1346        }
1347
1348        /// Find suitable modules for given data characteristics
1349        pub fn find_suitable_modules(
1350            &self,
1351            data_info: &DataCharacteristics,
1352        ) -> Vec<&dyn ImputationModule> {
1353            let mut suitable: Vec<_> = self
1354                .modules
1355                .values()
1356                .filter(|m| m.can_handle(data_info))
1357                .map(|m| m.as_ref())
1358                .collect();
1359
1360            // Sort by priority (descending)
1361            suitable.sort_by_key(|b| std::cmp::Reverse(b.priority()));
1362            suitable
1363        }
1364
1365        /// Get recommended module for data characteristics
1366        pub fn recommend_module(
1367            &self,
1368            data_info: &DataCharacteristics,
1369        ) -> Option<&dyn ImputationModule> {
1370            self.find_suitable_modules(data_info).into_iter().next()
1371        }
1372    }
1373
    /// Pipeline composer for combining multiple modules
    ///
    /// Accumulates [`PipelineStage`]s and resolves them against a
    /// [`ModuleRegistry`] to build a runnable [`ComposedPipeline`].
    pub struct PipelineComposer {
        // Ordered list of stages to execute.
        stages: Vec<PipelineStage>,
        // Registry used to resolve each stage's module name.
        registry: ModuleRegistry,
    }
1379
    /// A stage in the imputation pipeline
    #[derive(Debug, Clone)]
    pub struct PipelineStage {
        /// Human-readable stage name
        pub name: String,
        /// Name (or alias) of the module to run in this stage
        pub module_name: String,
        /// Configuration passed to the module when instantiating it
        pub config: ModuleConfig,
        /// Optional gate: the stage runs only if this condition holds
        pub condition: Option<StageCondition>,
    }
1392
    /// Conditions for pipeline stage execution
    #[derive(Debug, Clone)]
    pub enum StageCondition {
        /// Execute only if the data's missing rate exceeds this threshold
        MissingRate(f64), // Execute only if missing rate > threshold
        /// Execute only if the feature count exceeds this threshold
        FeatureCount(usize), // Execute only if n_features > threshold
        /// Execute only if any feature has this data type
        DataType(DataType), // Execute only for specific data types
        /// Custom condition expression (currently always evaluates to true)
        Custom(String), // Custom condition expression
    }
1405
1406    impl PipelineComposer {
1407        pub fn new(registry: ModuleRegistry) -> Self {
1408            Self {
1409                stages: Vec::new(),
1410                registry,
1411            }
1412        }
1413
1414        /// Add a stage to the pipeline
1415        pub fn add_stage(&mut self, stage: PipelineStage) -> &mut Self {
1416            self.stages.push(stage);
1417            self
1418        }
1419
1420        /// Add a conditional stage
1421        pub fn add_conditional_stage(
1422            &mut self,
1423            stage: PipelineStage,
1424            condition: StageCondition,
1425        ) -> &mut Self {
1426            let mut stage = stage;
1427            stage.condition = Some(condition);
1428            self.stages.push(stage);
1429            self
1430        }
1431
1432        /// Build the complete pipeline
1433        pub fn build(&self) -> SklResult<ComposedPipeline> {
1434            let mut instances = Vec::new();
1435
1436            for stage in &self.stages {
1437                let module = self
1438                    .registry
1439                    .get_module(&stage.module_name)
1440                    .ok_or_else(|| {
1441                        SklearsError::InvalidInput(format!(
1442                            "Module '{}' not found",
1443                            stage.module_name
1444                        ))
1445                    })?;
1446
1447                let instance = module.create_instance(&stage.config)?;
1448                instances.push((stage.clone(), instance));
1449            }
1450
1451            Ok(ComposedPipeline { stages: instances })
1452        }
1453    }
1454
    /// A composed pipeline of imputation modules
    ///
    /// Pairs each stage description with its instantiated module; stages run
    /// in insertion order.
    pub struct ComposedPipeline {
        // (stage metadata, instantiated module) pairs, executed in order.
        stages: Vec<(PipelineStage, Box<dyn ImputationInstance>)>,
    }
1459
1460    impl ComposedPipeline {
1461        pub fn fit_transform(&mut self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
1462            let mut data = X.to_owned();
1463            let data_info = self.analyze_data(&data.view())?;
1464
1465            for (stage, instance) in &mut self.stages {
1466                // Check stage condition
1467                if let Some(condition) = &stage.condition {
1468                    if !Self::evaluate_condition_static(condition, &data_info)? {
1469                        continue;
1470                    }
1471                }
1472
1473                data = instance.fit_transform(&data.view())?;
1474            }
1475
1476            Ok(data)
1477        }
1478
1479        fn analyze_data(&self, X: &ArrayView2<Float>) -> SklResult<DataCharacteristics> {
1480            let (n_samples, n_features) = X.dim();
1481            let missing_count = X.iter().filter(|&&x| (x).is_nan()).count();
1482            let missing_rate = missing_count as f64 / (n_samples * n_features) as f64;
1483
1484            Ok(DataCharacteristics {
1485                n_samples,
1486                n_features,
1487                missing_rate,
1488                missing_pattern: MissingPatternType::Unknown, // Would need more analysis
1489                data_types: vec![DataType::Continuous; n_features], // Default assumption
1490                has_categorical: false,
1491                has_temporal: false,
1492                is_sparse: missing_rate > 0.5,
1493                memory_constraints: None,
1494            })
1495        }
1496
1497        fn evaluate_condition_static(
1498            condition: &StageCondition,
1499            data_info: &DataCharacteristics,
1500        ) -> SklResult<bool> {
1501            Ok(match condition {
1502                StageCondition::MissingRate(threshold) => data_info.missing_rate > *threshold,
1503                StageCondition::FeatureCount(threshold) => data_info.n_features > *threshold,
1504                StageCondition::DataType(data_type) => data_info.data_types.contains(data_type),
1505                StageCondition::Custom(_) => true, // Would need expression evaluator
1506            })
1507        }
1508    }
1509
    /// Middleware for imputation pipelines
    ///
    /// Hooks that wrap an imputation run. All hooks have pass-through
    /// defaults, so implementors override only what they need.
    pub trait ImputationMiddleware: Send + Sync {
        /// Process data before imputation (default: pass through unchanged)
        fn before_imputation(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
            Ok(X.to_owned())
        }

        /// Process data after imputation (default: pass through unchanged)
        fn after_imputation(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
            Ok(X.to_owned())
        }

        /// Handle errors during imputation (default: re-propagate the error)
        fn on_error(&self, error: &SklearsError) -> SklResult<()> {
            Err(error.clone())
        }
    }
1527
    /// Validation middleware
    ///
    /// Checks imputation output for remaining NaNs and/or values outside
    /// per-feature expected ranges.
    pub struct ValidationMiddleware {
        /// Fail if any NaN values remain after imputation
        pub validate_completeness: bool,
        /// Fail if any value falls outside its feature's expected range
        pub validate_ranges: bool,
        /// Expected (min, max) range per feature index; checked only when
        /// `validate_ranges` is true
        pub expected_ranges: Option<HashMap<usize, (f64, f64)>>,
    }
1537
1538    impl ImputationMiddleware for ValidationMiddleware {
1539        fn after_imputation(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
1540            if self.validate_completeness && X.iter().any(|&x| (x).is_nan()) {
1541                return Err(SklearsError::InvalidInput(
1542                    "Imputation failed: missing values remain".to_string(),
1543                ));
1544            }
1545
1546            if self.validate_ranges {
1547                if let Some(ranges) = &self.expected_ranges {
1548                    for ((_, j), &value) in X.indexed_iter() {
1549                        if let Some((min_val, max_val)) = ranges.get(&j) {
1550                            let val = value;
1551                            if val < *min_val || val > *max_val {
1552                                return Err(SklearsError::InvalidInput(
1553                                    format!("Imputed value {} out of expected range [{}, {}] for feature {}", 
1554                                           val, min_val, max_val, j)
1555                                ));
1556                            }
1557                        }
1558                    }
1559                }
1560            }
1561
1562            Ok(X.to_owned())
1563        }
1564    }
1565
    /// Logging middleware
    ///
    /// Prints basic progress information to stdout before and after a run.
    pub struct LoggingMiddleware {
        /// Verbosity setting (the hooks print only at Debug/Info)
        pub log_level: LogLevel,
        /// Whether to log performance metrics
        /// (not read by the current hook implementations)
        pub log_performance: bool,
    }
1573
    /// Log verbosity levels used by `LoggingMiddleware`
    #[derive(Debug, Clone)]
    pub enum LogLevel {
        /// Most verbose diagnostic output
        Debug,
        /// General informational messages
        Info,
        /// Potential problems only
        Warn,
        /// Errors only
        Error,
    }
1585
1586    impl ImputationMiddleware for LoggingMiddleware {
1587        fn before_imputation(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
1588            if matches!(self.log_level, LogLevel::Debug | LogLevel::Info) {
1589                let missing_count = X.iter().filter(|&&x| (x).is_nan()).count();
1590                println!(
1591                    "Starting imputation: {} missing values in {}x{} matrix",
1592                    missing_count,
1593                    X.nrows(),
1594                    X.ncols()
1595                );
1596            }
1597            Ok(X.to_owned())
1598        }
1599
1600        fn after_imputation(&self, X: &ArrayView2<Float>) -> SklResult<Array2<Float>> {
1601            if matches!(self.log_level, LogLevel::Debug | LogLevel::Info) {
1602                let remaining_missing = X.iter().filter(|&&x| (x).is_nan()).count();
1603                println!(
1604                    "Imputation completed: {} missing values remaining",
1605                    remaining_missing
1606                );
1607            }
1608            Ok(X.to_owned())
1609        }
1610    }
1611}
1612
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_fluent_api_simple_imputation() {
        let input =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0]).unwrap();

        let pipeline = ImputationBuilder::new()
            .simple()
            .mean()
            .finish()
            .build()
            .unwrap();

        let imputed = pipeline.fit_transform(&input.view()).unwrap();

        // Every gap must be filled…
        assert!(imputed.iter().all(|v| !v.is_nan()));

        // …while observed entries stay untouched.
        assert_abs_diff_eq!(imputed[[0, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(imputed[[0, 1]], 2.0, epsilon = 1e-10);
    }

    #[test]
    fn test_fluent_api_knn_imputation() {
        let input =
            Array2::from_shape_vec((4, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0, 7.0, 8.0])
                .unwrap();

        let pipeline = ImputationBuilder::new()
            .knn()
            .n_neighbors(2)
            .distance_weights()
            .finish()
            .build()
            .unwrap();

        let imputed = pipeline.fit_transform(&input.view()).unwrap();

        // Every gap must be filled.
        assert!(imputed.iter().all(|v| !v.is_nan()));
    }

    #[test]
    fn test_preset_configurations() {
        let input =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0]).unwrap();

        // Fast preset: builds and fills every gap.
        let fast = ImputationBuilder::new()
            .preset(ImputationPreset::Fast)
            .build()
            .unwrap();
        let imputed = fast.fit_transform(&input.view()).unwrap();
        assert!(imputed.iter().all(|v| !v.is_nan()));

        // Balanced preset: likewise.
        let balanced = ImputationBuilder::new()
            .preset(ImputationPreset::Balanced)
            .build()
            .unwrap();
        let imputed = balanced.fit_transform(&input.view()).unwrap();
        assert!(imputed.iter().all(|v| !v.is_nan()));
    }

    #[test]
    fn test_quick_functions() {
        let input =
            Array2::from_shape_vec((3, 2), vec![1.0, 2.0, f64::NAN, 4.0, 5.0, 6.0]).unwrap();

        // quick::mean_impute fills every gap.
        let imputed = quick::mean_impute(&input.view()).unwrap();
        assert!(imputed.iter().all(|v| !v.is_nan()));

        // quick::knn_impute fills every gap as well.
        let imputed = quick::knn_impute(&input.view(), 2).unwrap();
        assert!(imputed.iter().all(|v| !v.is_nan()));
    }

    #[test]
    fn test_method_chaining() {
        let builder = ImputationBuilder::new()
            .normalize()
            .cross_validate(5)
            .with_uncertainty()
            .parallel(None);

        // Every chained call must have recorded its setting on the builder.
        assert!(builder.validation.cross_validation);
        assert_eq!(builder.validation.cv_folds, 5);
        assert!(builder.preprocessing.normalize);
        assert!(builder.postprocessing.add_uncertainty_estimates);
        assert!(builder.parallel_config.is_some());
    }
}