sklears_cross_decomposition/
lib.rs

//! Cross-decomposition algorithms (PLS, CCA, and related multi-view methods).
//!
//! This crate is part of sklears, which provides scikit-learn-compatible
//! machine learning algorithms in Rust.
5
6#![allow(missing_docs)]
7#![allow(unused_imports)]
8#![allow(deprecated)]
9#![allow(clippy::all)]
10#![allow(clippy::pedantic)]
11#![allow(clippy::nursery)]
12#![allow(non_snake_case)]
13#![allow(unused_variables)]
14#![allow(unused_assignments)]
15#![allow(unused_mut)]
16#![allow(dead_code)]
17#![allow(clippy::clone_on_copy)]
18#![allow(clippy::assign_op_pattern)]
19#![allow(clippy::derivable_impls)]
20#![allow(clippy::needless_range_loop)]
21#![allow(clippy::too_many_arguments)]
22#![allow(clippy::field_reassign_with_default)]
23#![allow(clippy::type_complexity)]
24#![allow(clippy::collapsible_else_if)]
25#![allow(clippy::op_ref)]
26#![allow(clippy::redundant_closure)]
27#![allow(clippy::manual_clamp)]
28#![allow(clippy::useless_format)]
29#![allow(clippy::unnecessary_cast)]
30#![allow(clippy::ptr_arg)]
31#![allow(unpredictable_function_pointer_comparisons)]
32#![recursion_limit = "512"]
33
34pub mod bayesian;
35pub mod benchmarks;
36pub mod cca;
37pub mod consensus_pca;
38// TODO: Disabled due to ndarray 0.17 HRTB trait bound issues
39// pub mod cross_validation;
40pub mod deep_cca;
41pub mod deep_learning;
42pub mod differential_geometry;
43pub mod federated_learning;
44pub mod finance;
45pub mod generalized_cca;
46pub mod genomics;
47pub mod gpu_acceleration;
48pub mod graph_regularization;
49pub mod information_theory;
50pub mod interactive_visualization;
51pub mod jive;
52pub mod kernel_cca;
53pub mod manifold_learning;
54pub mod multi_omics;
55pub mod multiblock_pls;
56pub mod multitask;
57pub mod multiview_cca;
58pub mod multiview_clustering;
59pub mod neuroimaging;
60pub mod opls;
61pub mod out_of_core;
62pub mod parallel;
63// TODO: Disabled due to ndarray 0.17 HRTB trait bound issues
64// pub mod permutation_tests;
65pub mod pls;
66pub mod pls_canonical;
67pub mod pls_da;
68pub mod pls_svd;
69pub mod quantum_methods;
70pub mod regularization;
71pub mod riemannian_optimization;
72pub mod robust_methods;
73pub mod scalability;
74pub mod simd_acceleration;
75pub mod sparse_pls;
76pub mod tensor_methods;
77pub mod time_series;
78pub mod type_safe_linalg;
79pub mod validation_framework;
80
81pub use bayesian::{
82    BayesianCCA, BayesianCCAResults, HierarchicalBayesianCCA, HierarchicalBayesianCCAResults,
83    VariationalPLS, VariationalPLSResults,
84};
85pub use benchmarks::{
86    AccuracyResults, BenchmarkResults, BenchmarkSuite, DecompositionResult, MethodBenchmarkResults,
87    ScalabilityResults, SpeedResult, SummaryStats,
88};
89pub use cca::{RidgeCCA, SparseCCA, CCA};
90pub use consensus_pca::ConsensusPCA;
91// TODO: Disabled due to ndarray 0.17 HRTB trait bound issues
92// pub use cross_validation::{
93//     CVResults, CVStrategy, CrossValidator, NestedCrossValidator, ScoringFunction,
94// };
95pub use deep_cca::{ActivationFunction, DeepCCA};
96pub use deep_learning::{
97    ActivationFunction as DeepActivationFunction, AttentionActivation, AttentionConfig,
98    AttentionLayer, AttentionOutput, AttentionTensorDecomposition, AttentionType,
99    CrossModalAttention, CrossModalAttentionOutput, CrossModalSimilarity, CrossModalVAE,
100    MultiHeadAttention, NeuralActivation, NeuralParafacDecomposition, NeuralTensorConfig,
101    NeuralTensorResults, NeuralTuckerDecomposition, TransformerDecoderBlock,
102    TransformerEncoderBlock, VAEConfig, VAETrainingResults, VariationalTensorNetwork,
103};
104pub use federated_learning::{
105    AggregationStrategy as FederatedAggregationStrategy, ClientId, CommunicationConfig,
106    FederatedCCA, FederatedCCAResults, FederatedClient, FederatedError, FederatedPCA,
107    FederatedPCAResults, FederatedServer, PrivacyBudget,
108};
109pub use finance::{
110    FactorConstrainedOptimization, FactorRotation, FactorStatistics, FinanceError,
111    FinancialFactorAnalysis, FittedFinancialFactorAnalysis, FittedMacroeconomicFactorAnalysis,
112    ForecastingModel, MacroFactorStatistics, MacroeconomicFactorAnalysis, OptimizedPortfolio,
113    RiskDecomposition,
114};
115pub use generalized_cca::GeneralizedCCA;
116pub use genomics::{
117    ConsensusMethod, EnhancedPathwayAnalysis, EnhancedPathwayResults, EnrichmentMethod,
118    FittedGeneEnvironmentInteraction, FittedMultiOmicsIntegration, FittedSingleCellMultiModal,
119    FittedTemporalGeneExpression, GeneEnvironmentInteraction, GenomicsError, MLScoringConfig,
120    MissingDataStrategy, MultiModalConfig, MultiOmicsIntegration, MultipleTestingCorrection,
121    NetworkAnalysisConfig, PathwayAnalysis, PathwayAnalysisConfig, PathwayDatabase,
122    SingleCellMultiModal, TemporalAnalysisConfig, TemporalGeneExpression,
123};
124pub use gpu_acceleration::{
125    GpuAcceleratedContext, GpuCCA, GpuCCAFitted, GpuMatrixOps, GpuMemoryInfo,
126};
127pub use graph_regularization::{
128    CommunityAlgorithm, CommunityDetectionConfig, CommunityDetector, CommunityStructure,
129    GraphBuilder, GraphRegularizationConfig, GraphRegularizationError, GraphRegularizedCCA,
130    GraphStructure, GraphType, Hypergraph, HypergraphCCA, HypergraphCCAResults,
131    HypergraphCentrality, HypergraphConfig, HypergraphLaplacianType, MotifType, MultiGraphCCA,
132    MultiWayInteractionAnalyzer, NetworkConstrainedPLS, RegularizationType,
133    TemporalAnalysisResults, TemporalMotif, TemporalNetwork, TemporalNetworkAnalyzer,
134    TemporalNetworkConfig,
135};
136pub use information_theory::{
137    ComponentInterpretation, ComponentInterpreter, ComponentSelection, ComponentSimilarityAnalysis,
138    DistanceBasedConfig, DistanceBasedMetric, DistanceBasedResults, DistanceCCA,
139    DistanceCovariance, EntropyComponentSelection, EntropyEstimator, FeatureContribution,
140    FeatureImportanceAnalyzer, FeatureImportanceResults, FittedMutualInformationCCA,
141    HigherOrderAnalyzer, HigherOrderConfig, HigherOrderResults, ImportanceMethod,
142    InformationGeometry, InformationMeasure, InformationTheoreticRegularization,
143    InformationTheoryError, KLDivergenceMethods, ManifoldStructure, MutualInformationCCA,
144    NonGaussianComponentAnalysis, NonGaussianResults, PolyspectralCCA, PolyspectralResults,
145    RegularizationMethod, RiemannianOptimizer, SelectionCriteria, SingleComponentInterpretation,
146    VariableInterpretation, HSIC,
147};
148pub use interactive_visualization::{
149    ColorScheme, InteractivePlot, InteractiveVisualizationConfig, InteractiveVisualizer, PlotData,
150    PlotType, VisualizationError,
151};
152pub use jive::JIVE;
153pub use kernel_cca::{KernelCCA, KernelType};
154pub use manifold_learning::{
155    AdvancedManifoldLearning, ConvergenceInfo, CrossModalAlignment, DistanceMetric, EigenSolver,
156    FittedManifoldAwareCCA, FittedManifoldCCA as FittedAdvancedManifoldCCA, GeodesicMethod,
157    ManifoldAwareCCA, ManifoldCCA as AdvancedManifoldCCA, ManifoldError, ManifoldLearning,
158    ManifoldLearningResult, ManifoldMethod, ManifoldProperties, ManifoldRegularization,
159    ManifoldResults, OptimizationParams, PathMethod,
160};
161pub use multi_omics::GenomicsError as MultiOmicsGenomicsError;
162pub use multiblock_pls::{BlockScaling, MultiBlockPLS};
163pub use multitask::{
164    DomainAdaptationCCA, FewShotCCA, MultiTaskCCA, SharedComponentAnalysis, TransferLearningCCA,
165};
166pub use multiview_cca::MultiViewCCA;
167pub use multiview_clustering::{
168    DistanceMetric as MultiViewDistanceMetric, InitMethod, MultiViewClustering,
169};
170pub use neuroimaging::{
171    BrainBehaviorCorrelation, BrainBehaviorResults, ConnectivityType, CorrelationMethod,
172    FunctionalConnectivity, FunctionalConnectivityResults, NetworkMeasures,
173};
174pub use opls::OPLS;
175pub use out_of_core::{
176    OOCAlgorithm, OutOfCoreCCA, OutOfCoreCCAResults, OutOfCorePLS, OutOfCorePLSResults,
177};
178pub use parallel::{
179    EigenMethod, OptimizedMatrixOps, ParallelEigenSolver, ParallelMatrixOps, ParallelSVD,
180    SVDAlgorithm, WorkStealingThreadPool,
181};
182// TODO: Disabled due to ndarray 0.17 HRTB trait bound issues
183// pub use permutation_tests::{
184//     ComputeStatistic, PermutationTest, PermutationTestResults, StabilityResults,
185//     StabilitySelection, TestStatistic,
186// };
187pub use pls::PLSRegression;
188pub use pls_canonical::PLSCanonical;
189pub use pls_da::PLSDA;
190pub use pls_svd::PLSSVD;
191pub use quantum_methods::{
192    QuantumCCA, QuantumCCAResults, QuantumCircuit, QuantumError, QuantumFeatureSelection,
193    QuantumGate, QuantumMethod, QuantumPCA, QuantumPCAResults, QuantumState,
194};
195pub use regularization::{AdaptiveLasso, ElasticNet, FusedLasso, GroupLasso, MCP, SCAD};
196pub use riemannian_optimization::{
197    CCAObjective, GrassmannManifold, LineSearchParams, ManifoldType, RiemannianAlgorithm,
198    RiemannianConfig, RiemannianError, RiemannianManifold, RiemannianObjective,
199    RiemannianOptimizer as RiemannianOptimizerAdvanced, RiemannianResults, SPDManifold,
200    StiefelManifold, TrustRegionParams,
201};
202pub use robust_methods::{MEstimatorType, RobustCCA, RobustPLS};
203pub use scalability::{
204    AggregationStrategy, DistributedCCA, DistributedCCAResults, MemoryEfficientCCA,
205};
206pub use simd_acceleration::{
207    AdvancedSimdConfig, AdvancedSimdOps, SimdBenchmarkResults, SimdCCA, SimdCCAFitted,
208    SimdMatrixOps,
209};
210pub use sparse_pls::SparsePLS;
211pub use tensor_methods::{
212    BayesianParafac, ParafacDecomposition, ProbabilisticConfig, ProbabilisticTensorResults,
213    ProbabilisticTucker, RobustProbabilisticTensor, SparseTensorDecomposition, TensorCCA,
214    TensorCompletion, TensorInitMethod, TuckerDecomposition,
215};
216pub use time_series::{
217    DynamicCCA, DynamicCCAResults, DynamicCCASummary, FittedRegimeSwitchingModel,
218    FittedStateSpaceModel, FittedVAR, GrangerCausalityTest, GrangerTestResult,
219    InformationCriterion, RegimeSwitchingModel, StateSpaceForecast, StateSpaceModel,
220    StateSpaceModelDiagnostics, StreamingCCA, TrendType, VARMethod, VectorAutoregression,
221};
222pub use type_safe_linalg::{
223    decomp, ops, Dim, MatrixDimension, SquareMatrix, TypeSafeMatrix, TypeSafeVector,
224};
225pub use validation_framework::{
226    BenchmarkDataset, CaseStudy, ComputationalBenchmarks, CorrelationStructure, CriterionType,
227    CrossValidationResult, CrossValidationSettings, DatasetCharacteristics,
228    DatasetValidationResult, DistributionType, PerformanceMetric, PerformanceRange,
229    PerformanceSummary, RobustnessAnalysis, ScalabilityAnalysis, SignificanceTest,
230    StatisticalTestResult, ValidationError, ValidationFramework, ValidationResults,
231};
232
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use sklears_core::traits::{Fit, Predict};

    /// End-to-end smoke test for the crate-level `PLSRegression` re-export.
    ///
    /// Fits a one-component model, predicts on the training data, and checks
    /// both the output shape and the predicted values.
    ///
    /// The data is exactly linear (`x1 = x0 + 1` and `y = x0 + 0.5`), so a
    /// single latent component should fully explain the target. Predictions
    /// are therefore expected to match `y` up to floating-point round-off.
    #[test]
    fn test_pls_integration() {
        let x = array![[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]];
        let y = array![[1.5], [2.5], [3.5], [4.5]];

        let pls = PLSRegression::new(1);
        let fitted = pls
            .fit(&x, &y)
            .expect("PLS fit on small, exactly linear data should succeed");
        let predictions = fitted
            .predict(&x)
            .expect("prediction on the training data should succeed");

        assert_eq!(predictions.shape(), &[4, 1]);

        // A shape check alone also passes for all-zero or NaN output, so pin
        // the values as well. The tolerance absorbs solver round-off.
        for (pred, expected) in predictions.iter().zip(y.iter()) {
            assert!(
                (pred - expected).abs() < 1e-6,
                "prediction {pred} differs from target {expected}"
            );
        }
    }
}