// sklears_impute/lib.rs

1#![allow(dead_code)]
2#![allow(non_snake_case)]
3#![allow(missing_docs)]
4#![allow(deprecated)]
5#![allow(clippy::all)]
6#![allow(clippy::pedantic)]
7#![allow(clippy::nursery)]
8//! Missing value imputation strategies
9//!
10//! This module provides various strategies for handling missing values in datasets.
11//! It includes simple imputation methods as well as more sophisticated approaches
12//! like iterative imputation, KNN-based imputation, matrix factorization, and Bayesian methods.
13
14// #![warn(missing_docs)]
15
16// Re-export the main modules
17pub mod advanced;
18pub mod approximate;
19pub mod bayesian;
20pub mod benchmarks;
21pub mod categorical;
22pub mod core;
23pub mod dimensionality;
24pub mod distributed;
25pub mod domain_specific;
26pub mod ensemble;
27pub mod fluent_api;
28pub mod independence;
29pub mod information_theoretic;
30pub mod kernel;
31pub mod memory_profiler;
32pub mod mixed_type;
33pub mod multivariate;
34pub mod neural;
35pub mod out_of_core;
36pub mod parallel;
37pub mod sampling;
38pub mod simd_ops;
39pub mod simple;
40// TODO: Temporarily disabled until ndarray 0.17 HRTB trait bound issues are resolved
41// pub mod testing_pipeline;
42pub mod timeseries;
43pub mod type_safe;
44// TODO: Temporarily disabled until ndarray 0.17 HRTB trait bound issues are resolved
45// pub mod validation;
46pub mod visualization;
47
48// Re-export commonly used types and functions for convenience
49pub use advanced::{
50    analyze_breakdown_point, BreakdownPointAnalysis, CopulaImputer, CopulaParameters, EmpiricalCDF,
51    EmpiricalQuantile, FactorAnalysisImputer, KDEImputer, LocalLinearImputer, LowessImputer,
52    MultivariateNormalImputer, RobustRegressionImputer, TrimmedMeanImputer,
53};
54pub use bayesian::{
55    BayesianLinearImputer, BayesianLogisticImputer, BayesianModel, BayesianModelAveraging,
56    BayesianModelAveragingResults, BayesianMultipleImputer, ConvergenceDiagnostics,
57    HierarchicalBayesianImputer, HierarchicalBayesianSample, PooledResults,
58    VariationalBayesImputer,
59};
60pub use benchmarks::{
61    AccuracyMetrics, BenchmarkDatasetGenerator, BenchmarkSuite, ImputationBenchmark,
62    ImputationComparison, MissingPattern, MissingPatternGenerator,
63};
64pub use categorical::{
65    AssociationRule, AssociationRuleImputer, CategoricalClusteringImputer,
66    CategoricalRandomForestImputer, HotDeckImputer, Item, Itemset,
67};
68pub use core::{
69    utils, ConvergenceInfo, ImputationError, ImputationMetadata, ImputationOutputWithMetadata,
70    ImputationResult, Imputer, ImputerConfig, MissingPatternHandler, QualityAssessment,
71    StatisticalValidator, TrainableImputer, TransformableImputer,
72};
73pub use dimensionality::{
74    CompressedSensingImputer, ICAImputer, ManifoldLearningImputer, PCAImputer, SparseImputer,
75};
76pub use domain_specific::{
77    CreditScoringImputer, DemographicDataImputer, EconomicIndicatorImputer,
78    FinancialTimeSeriesImputer, GenomicImputer, LongitudinalStudyImputer, MetabolomicsImputer,
79    MissingResponseHandler, PhylogeneticImputer, PortfolioDataImputer, ProteinExpressionImputer,
80    RiskFactorImputer, SingleCellRNASeqImputer, SocialNetworkImputer, SurveyDataImputer,
81};
82pub use ensemble::{ExtraTreesImputer, GradientBoostingImputer, RandomForestImputer};
83pub use fluent_api::{
84    pluggable::{
85        ComposedPipeline, DataCharacteristics, DataType, ImputationInstance, ImputationMiddleware,
86        ImputationModule, LogLevel, LoggingMiddleware, MissingPatternType, ModuleConfig,
87        ModuleConfigSchema, ModuleRegistry, ParameterGroup, ParameterRange, ParameterSchema,
88        ParameterType, PipelineComposer, PipelineStage, StageCondition, ValidationMiddleware,
89    },
90    quick, DeepLearningBuilder, EnsembleImputationBuilder, GaussianProcessBuilder,
91    ImputationBuilder, ImputationMethod, ImputationPipeline, ImputationPreset,
92    IterativeImputationBuilder, KNNImputationBuilder, PostprocessingConfig, PreprocessingConfig,
93    SimpleImputationBuilder, ValidationConfig,
94};
95pub use independence::{
96    chi_square_independence_test, cramers_v_association_test, fisher_exact_independence_test,
97    kolmogorov_smirnov_independence_test, pattern_sensitivity_analysis,
98    run_independence_test_suite, sensitivity_analysis, ChiSquareTestResult, CramersVTestResult,
99    FisherExactTestResult, IndependenceTestSuite, KolmogorovSmirnovTestResult, MARSensitivityCase,
100    MNARSensitivityCase, MissingDataAssessment, PatternSensitivityResult, RobustnessSummary,
101    SensitivityAnalysisResult,
102};
103pub use information_theoretic::{
104    EntropyImputer, InformationGainImputer, MDLImputer, MaxEntropyImputer, MutualInformationImputer,
105};
106pub use kernel::{
107    GPPredictionResult, GaussianProcessImputer, KernelRidgeImputer, ReproducingKernelImputer,
108    SVRImputer,
109};
110pub use memory_profiler::{
111    ImputationMemoryBenchmark, MemoryProfiler, MemoryProfilingResult, MemoryStats,
112};
113pub use mixed_type::{
114    HeterogeneousImputer, MixedTypeMICEImputer, MixedTypeMultipleImputationResults, OrdinalImputer,
115    VariableMetadata, VariableParameters, VariableType,
116};
117pub use multivariate::CanonicalCorrelationImputer;
118pub use neural::{
119    AutoencoderImputer, DiffusionImputer, GANImputer, MLPImputer, NeuralODEImputer,
120    NormalizingFlowImputer, VAEImputer,
121};
122pub use parallel::{
123    AdaptiveStreamingImputer, MemoryEfficientImputer, MemoryMappedData, MemoryOptimizedImputer,
124    MemoryStrategy, OnlineStatistics, ParallelConfig, ParallelIterativeImputer, ParallelKNNImputer,
125    SharedDataRef, SparseMatrix, StreamingImputer,
126};
127pub use simd_ops::{
128    SimdDistanceCalculator, SimdImputationOps, SimdKMeans, SimdMatrixOps, SimdStatistics,
129};
130pub use simple::{MissingIndicator, SimpleImputer};
131pub use timeseries::{
132    ARIMAImputer, KalmanFilterImputer, SeasonalDecompositionImputer, StateSpaceImputer,
133};
134pub use type_safe::{
135    ClassifiedArray, Complete, CompleteArray, FixedSizeArray, FixedSizeValidation,
136    ImputationQualityMetrics, MARArray, MCARArray, MNARArray, MissingMechanism,
137    MissingPatternValidator, MissingValueDetector, NaNDetector, SentinelDetector,
138    TypeSafeImputation, TypeSafeMeanImputer, TypeSafeMissingOps, TypedArray, UnknownMechanism,
139    WithMissing, MAR, MCAR, MNAR,
140};
141// TODO: Temporarily disabled until ndarray 0.17 HRTB trait bound issues are resolved
142// pub use validation::{
143//     validate_with_holdout, CrossValidationResults, CrossValidationStrategy, HoldOutValidator,
144//     ImputationCrossValidator, ImputationMetrics, MissingDataPattern, SyntheticMissingValidator,
145// };
146pub use visualization::{
147    create_completeness_matrix, create_missing_correlation_heatmap,
148    create_missing_distribution_plot, create_missing_pattern_plot, export_correlation_csv,
149    export_missing_pattern_csv, generate_missing_summary_stats, CompletenessMatrix,
150    MissingCorrelationHeatmap, MissingDistributionPlot, MissingPatternPlot,
151};
152
153// New modules - Advanced Algorithms
154pub use approximate::{
155    ApproximateConfig, ApproximateKNNImputer, ApproximateSimpleImputer, ApproximationStrategy,
156    LocalityHashTable, SketchingImputer,
157};
158pub use distributed::{
159    CommunicationStrategy, DistributedConfig, DistributedKNNImputer, DistributedSimpleImputer,
160    DistributedWorker, ImputationCoordinator,
161};
162pub use out_of_core::{
163    IndexType, MemoryManager, NeighborIndex, OutOfCoreConfig, OutOfCoreKNNImputer,
164    OutOfCoreSimpleImputer, PrefetchStrategy,
165};
166pub use sampling::{
167    AdaptiveSamplingImputer, ImportanceSamplingImputer, ParametricDistribution,
168    ProposalDistribution, QuasiSequenceType, SampleDistribution, SamplingConfig,
169    SamplingSimpleImputer, SamplingStrategy, StratifiedSamplingImputer, WeightFunction,
170};
171// TODO: Temporarily disabled until ndarray 0.17 HRTB trait bound issues are resolved
172// pub use testing_pipeline::{
173//     AutomatedTestPipeline, CompletedTestCase, PerformanceBenchmarks, QualityThresholds, TestCase,
174//     TestDataset, TestPipelineConfig, TestResults, TestRunner, TestStatus,
175// };
176
177// ✅ SciRS2 Policy compliant imports
178use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
179use sklears_core::{
180    error::{Result as SklResult, SklearsError},
181    traits::{Estimator, Fit, Transform, Untrained},
182    types::Float,
183};
184use std::collections::HashMap;
185
186// Legacy implementations (to be moved to separate modules when fully refactored)
187
188/// K-Nearest Neighbors Imputer
189///
190/// Imputation for completing missing values using k-Nearest Neighbors.
191/// Each missing value is imputed using values from k nearest neighbors
192/// found in the training set.
193///
194/// # Parameters
195///
196/// * `n_neighbors` - Number of neighboring samples to use for imputation
197/// * `weights` - Weight function used in prediction ('uniform' or 'distance')
198/// * `metric` - Distance metric for searching neighbors ('nan_euclidean')
199/// * `missing_values` - The placeholder for missing values (NaN by default)
200/// * `add_indicator` - Whether to add a missing value indicator
201///
202/// # Examples
203///
204/// ```
205/// use sklears_impute::KNNImputer;
206/// use sklears_core::traits::{Transform, Fit};
207/// use scirs2_core::ndarray::array;
208///
209/// let X = array![[1.0, 2.0, 3.0], [4.0, f64::NAN, 6.0], [7.0, 8.0, 9.0]];
210///
211/// let imputer = KNNImputer::new()
212///     .n_neighbors(2);
213/// let fitted = imputer.fit(&X.view(), &()).unwrap();
214/// let X_imputed = fitted.transform(&X.view()).unwrap();
215/// ```
#[derive(Debug, Clone)]
pub struct KNNImputer<S = Untrained> {
    /// Type-state marker: `Untrained` before `fit`, `KNNImputerTrained` after.
    state: S,
    /// Number of neighboring training rows consulted per missing value (default 5).
    n_neighbors: usize,
    /// Weighting scheme for the neighbor average: "uniform" or "distance".
    weights: String,
    /// Distance metric name; only "nan_euclidean" is implemented in this file.
    metric: String,
    /// Sentinel value that marks a missing entry (NaN by default).
    missing_values: f64,
    /// Whether to add a missing-value indicator. NOTE(review): not consulted
    /// by the `transform` implementation shown in this file.
    add_indicator: bool,
}
225
/// Trained state for KNNImputer
#[derive(Debug, Clone)]
pub struct KNNImputerTrained {
    /// Full training matrix retained verbatim (missing entries included)
    /// for the neighbor search performed at transform time.
    X_train_: Array2<f64>,
    /// Number of features seen during `fit`; `transform` rejects inputs
    /// whose column count differs.
    n_features_in_: usize,
}
232
233impl KNNImputer<Untrained> {
234    /// Create a new KNNImputer instance
235    pub fn new() -> Self {
236        Self {
237            state: Untrained,
238            n_neighbors: 5,
239            weights: "uniform".to_string(),
240            metric: "nan_euclidean".to_string(),
241            missing_values: f64::NAN,
242            add_indicator: false,
243        }
244    }
245
246    /// Set the number of neighbors
247    pub fn n_neighbors(mut self, n_neighbors: usize) -> Self {
248        self.n_neighbors = n_neighbors;
249        self
250    }
251
252    /// Set the weight function
253    pub fn weights(mut self, weights: String) -> Self {
254        self.weights = weights;
255        self
256    }
257
258    /// Set the distance metric
259    pub fn metric(mut self, metric: String) -> Self {
260        self.metric = metric;
261        self
262    }
263
264    /// Set the missing values placeholder
265    pub fn missing_values(mut self, missing_values: f64) -> Self {
266        self.missing_values = missing_values;
267        self
268    }
269
270    /// Set whether to add missing indicator
271    pub fn add_indicator(mut self, add_indicator: bool) -> Self {
272        self.add_indicator = add_indicator;
273        self
274    }
275
276    fn is_missing(&self, value: f64) -> bool {
277        if self.missing_values.is_nan() {
278            value.is_nan()
279        } else {
280            (value - self.missing_values).abs() < f64::EPSILON
281        }
282    }
283}
284
/// `Default` delegates to [`KNNImputer::new`], so `KNNImputer::default()`
/// yields exactly the documented default hyper-parameters.
impl Default for KNNImputer<Untrained> {
    fn default() -> Self {
        Self::new()
    }
}
290
/// Marker `Estimator` implementation. This imputer carries no external
/// configuration object, so `Config` is the unit type and `config`
/// returns a reference to it.
impl Estimator for KNNImputer<Untrained> {
    type Config = ();
    type Error = SklearsError;
    type Float = Float;

    fn config(&self) -> &Self::Config {
        // `&()` is a promoted 'static reference to the zero-sized unit value.
        &()
    }
}
300
301impl Fit<ArrayView2<'_, Float>, ()> for KNNImputer<Untrained> {
302    type Fitted = KNNImputer<KNNImputerTrained>;
303
304    #[allow(non_snake_case)]
305    fn fit(self, X: &ArrayView2<'_, Float>, _y: &()) -> SklResult<Self::Fitted> {
306        let X = X.mapv(|x| x);
307        let (_, n_features) = X.dim();
308
309        Ok(KNNImputer {
310            state: KNNImputerTrained {
311                X_train_: X.clone(),
312                n_features_in_: n_features,
313            },
314            n_neighbors: self.n_neighbors,
315            weights: self.weights,
316            metric: self.metric,
317            missing_values: self.missing_values,
318            add_indicator: self.add_indicator,
319        })
320    }
321}
322
/// Imputes each missing entry as the (optionally distance-weighted) average
/// of the same feature over the nearest training rows, using a NaN-aware
/// Euclidean distance.
impl Transform<ArrayView2<'_, Float>, Array2<Float>> for KNNImputer<KNNImputerTrained> {
    #[allow(non_snake_case)]
    fn transform(&self, X: &ArrayView2<'_, Float>) -> SklResult<Array2<Float>> {
        // Owned working copy: imputed values are written back into it below.
        let X = X.mapv(|x| x);
        let (n_samples, n_features) = X.dim();

        // Column count must match what `fit` saw.
        if n_features != self.state.n_features_in_ {
            return Err(SklearsError::InvalidInput(format!(
                "Number of features {} does not match training features {}",
                n_features, self.state.n_features_in_
            )));
        }

        let mut X_imputed = X.clone();
        let X_train = &self.state.X_train_;

        for i in 0..n_samples {
            for j in 0..n_features {
                if self.is_missing(X_imputed[[i, j]]) {
                    // Find k nearest neighbors via a brute-force scan over all
                    // training rows. NOTE(review): distances use
                    // `X_imputed.row(i)`, so values imputed for earlier
                    // features of the same row influence later features.
                    let mut distances: Vec<(f64, usize)> = Vec::new();

                    for train_idx in 0..X_train.nrows() {
                        let distance =
                            self.nan_euclidean_distance(X_imputed.row(i), X_train.row(train_idx));
                        distances.push((distance, train_idx));
                    }

                    // NOTE(review): the unwrap is safe only while distances are
                    // never NaN. With a finite missing sentinel, literal NaN
                    // data values would yield NaN distances and panic here.
                    distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

                    // Take k nearest neighbors that have this feature observed.
                    // Only the 3k closest candidates are examined, so fewer
                    // than k donors may be found for a sparsely observed column.
                    let mut neighbor_values = Vec::new();
                    let mut weights = Vec::new();

                    for &(distance, train_idx) in distances.iter().take(self.n_neighbors * 3) {
                        if !self.is_missing(X_train[[train_idx, j]]) {
                            neighbor_values.push(X_train[[train_idx, j]]);
                            let weight = match self.weights.as_str() {
                                "distance" => {
                                    if distance > 0.0 {
                                        1.0 / distance
                                    } else {
                                        // Coincident rows: large finite weight
                                        // instead of dividing by zero.
                                        1e6
                                    }
                                }
                                // "uniform" and any unrecognized scheme.
                                _ => 1.0,
                            };
                            weights.push(weight);

                            if neighbor_values.len() >= self.n_neighbors {
                                break;
                            }
                        }
                    }

                    if neighbor_values.is_empty() {
                        // Fallback to mean of training data for this feature.
                        let column = X_train.column(j);
                        let valid_values: Vec<f64> = column
                            .iter()
                            .filter(|&&x| !self.is_missing(x))
                            .cloned()
                            .collect();

                        if !valid_values.is_empty() {
                            X_imputed[[i, j]] =
                                valid_values.iter().sum::<f64>() / valid_values.len() as f64;
                        }
                        // else: the feature has no observed training values at
                        // all, and the cell is left unchanged (still missing).
                    } else {
                        // Weighted average of neighbor values.
                        let total_weight: f64 = weights.iter().sum();
                        let weighted_sum: f64 = neighbor_values
                            .iter()
                            .zip(weights.iter())
                            .map(|(&value, &weight)| value * weight)
                            .sum();

                        X_imputed[[i, j]] = if total_weight > 0.0 {
                            weighted_sum / total_weight
                        } else {
                            // Degenerate (all-zero) weights: plain mean.
                            neighbor_values.iter().sum::<f64>() / neighbor_values.len() as f64
                        };
                    }
                }
            }
        }

        Ok(X_imputed.mapv(|x| x as Float))
    }
}
413
414impl KNNImputer<KNNImputerTrained> {
415    fn is_missing(&self, value: f64) -> bool {
416        if self.missing_values.is_nan() {
417            value.is_nan()
418        } else {
419            (value - self.missing_values).abs() < f64::EPSILON
420        }
421    }
422
423    fn nan_euclidean_distance(&self, row1: ArrayView1<f64>, row2: ArrayView1<f64>) -> f64 {
424        let mut sum_sq = 0.0;
425        let mut valid_count = 0;
426
427        for (x1, x2) in row1.iter().zip(row2.iter()) {
428            if !self.is_missing(*x1) && !self.is_missing(*x2) {
429                sum_sq += (x1 - x2).powi(2);
430                valid_count += 1;
431            }
432        }
433
434        if valid_count > 0 {
435            (sum_sq / valid_count as f64).sqrt()
436        } else {
437            f64::INFINITY
438        }
439    }
440}
441
442/// Analysis functions for missing data patterns
443#[allow(non_snake_case)]
444pub fn analyze_missing_patterns(
445    X: &ArrayView2<'_, Float>,
446    missing_values: f64,
447) -> SklResult<HashMap<String, Vec<usize>>> {
448    let X = X.mapv(|x| x);
449    let (n_samples, n_features) = X.dim();
450    let mut patterns = HashMap::new();
451
452    for i in 0..n_samples {
453        let mut pattern = Vec::new();
454        for j in 0..n_features {
455            let is_missing = if missing_values.is_nan() {
456                X[[i, j]].is_nan()
457            } else {
458                (X[[i, j]] - missing_values).abs() < f64::EPSILON
459            };
460            if is_missing {
461                pattern.push(j);
462            }
463        }
464
465        let pattern_key = format!("{:?}", pattern);
466        patterns.entry(pattern_key).or_insert_with(Vec::new).push(i);
467    }
468
469    Ok(patterns)
470}
471
472/// Compute missing correlation matrix
473#[allow(non_snake_case)]
474pub fn missing_correlation_matrix(
475    X: &ArrayView2<'_, Float>,
476    missing_values: f64,
477) -> SklResult<Array2<f64>> {
478    let X = X.mapv(|x| x);
479    let (n_samples, n_features) = X.dim();
480
481    // Create missing indicators
482    let mut missing_indicators = Array2::zeros((n_samples, n_features));
483    for i in 0..n_samples {
484        for j in 0..n_features {
485            let is_missing = if missing_values.is_nan() {
486                X[[i, j]].is_nan()
487            } else {
488                (X[[i, j]] - missing_values).abs() < f64::EPSILON
489            };
490            missing_indicators[[i, j]] = if is_missing { 1.0 } else { 0.0 };
491        }
492    }
493
494    // Compute correlation matrix
495    let mut correlation_matrix = Array2::zeros((n_features, n_features));
496    for i in 0..n_features {
497        for j in 0..n_features {
498            if i == j {
499                correlation_matrix[[i, j]] = 1.0;
500            } else {
501                let col_i = missing_indicators.column(i);
502                let col_j = missing_indicators.column(j);
503                correlation_matrix[[i, j]] =
504                    compute_correlation(&col_i.to_owned(), &col_j.to_owned());
505            }
506        }
507    }
508
509    Ok(correlation_matrix)
510}
511
512/// Compute missing completeness matrix
513#[allow(non_snake_case)]
514pub fn missing_completeness_matrix(
515    X: &ArrayView2<'_, Float>,
516    missing_values: f64,
517) -> SklResult<Array2<f64>> {
518    let X = X.mapv(|x| x);
519    let (n_samples, n_features) = X.dim();
520
521    let mut completeness_matrix = Array2::zeros((n_features, n_features));
522
523    for i in 0..n_features {
524        for j in 0..n_features {
525            let mut joint_observed = 0;
526
527            for sample_idx in 0..n_samples {
528                let i_observed = if missing_values.is_nan() {
529                    !X[[sample_idx, i]].is_nan()
530                } else {
531                    (X[[sample_idx, i]] - missing_values).abs() >= f64::EPSILON
532                };
533
534                let j_observed = if missing_values.is_nan() {
535                    !X[[sample_idx, j]].is_nan()
536                } else {
537                    (X[[sample_idx, j]] - missing_values).abs() >= f64::EPSILON
538                };
539
540                if i_observed && j_observed {
541                    joint_observed += 1;
542                }
543            }
544
545            completeness_matrix[[i, j]] = joint_observed as f64 / n_samples as f64;
546        }
547    }
548
549    Ok(completeness_matrix)
550}
551
/// Generate comprehensive missing data summary
///
/// Thin convenience wrapper around the re-exported
/// `visualization::generate_missing_summary_stats`; see that function for
/// the exact report contents and format.
pub fn missing_data_summary(X: &ArrayView2<'_, Float>, missing_values: f64) -> SklResult<String> {
    generate_missing_summary_stats(X, missing_values)
}
556
557// Helper function for correlation computation
558fn compute_correlation(x: &Array1<f64>, y: &Array1<f64>) -> f64 {
559    let n = x.len() as f64;
560    if n == 0.0 {
561        return 0.0;
562    }
563
564    let mean_x = x.sum() / n;
565    let mean_y = y.sum() / n;
566
567    let mut numerator = 0.0;
568    let mut var_x = 0.0;
569    let mut var_y = 0.0;
570
571    for i in 0..x.len() {
572        let dx = x[i] - mean_x;
573        let dy = y[i] - mean_y;
574
575        numerator += dx * dy;
576        var_x += dx * dx;
577        var_y += dy * dy;
578    }
579
580    let denominator = (var_x * var_y).sqrt();
581    if denominator == 0.0 {
582        0.0
583    } else {
584        numerator / denominator
585    }
586}
587
588// Test module declaration - tests are in separate file for better organization
589#[allow(non_snake_case)]
590#[cfg(test)]
591mod tests;