Skip to main content

sklears_linear/
lib.rs

1#![allow(dead_code)]
2#![allow(non_snake_case)]
3#![allow(missing_docs)]
4#![allow(deprecated)]
5#![allow(clippy::needless_range_loop)]
6#![allow(clippy::derivable_impls)]
7#![allow(clippy::needless_borrow)]
8//! Linear models for sklears
9//!
10//! This crate provides implementations of linear models including:
11//! - Linear Regression (OLS, Ridge, Lasso)
12//! - Logistic Regression
13//! - Generalized Linear Models
14//!
15//! These implementations leverage scirs2's linear algebra and optimization capabilities.
16
17#[cfg(feature = "admm")]
18pub mod admm;
19#[cfg(feature = "gpu")]
20pub mod advanced_gpu_acceleration;
21#[cfg(feature = "bayesian")]
22pub mod bayesian;
23pub mod builder_enhancements;
24#[cfg(feature = "feature-selection")]
25pub mod categorical_encoding;
26#[cfg(feature = "chunked-processing")]
27pub mod chunked_processing;
28#[cfg(feature = "diagnostics")]
29pub mod classification_diagnostics;
30#[cfg(feature = "constrained-optimization")]
31// TODO: Migrate to scirs2-linalg (uses nalgebra types)
32//pub mod constrained_optimization;
33#[cfg(feature = "convergence-analysis")]
34pub mod convergence_visualization;
35#[cfg(feature = "coordinate-descent")]
36pub mod coordinate_descent;
37#[cfg(feature = "cross-validation")]
38pub mod cross_validation;
39#[cfg(feature = "early-stopping")]
40pub mod early_stopping;
41#[cfg(feature = "elastic-net")]
42pub mod elastic_net_cv;
43pub mod errors;
44#[cfg(feature = "feature-selection")]
45pub mod feature_scaling;
46#[cfg(feature = "feature-selection")]
47pub mod feature_selection;
48#[cfg(feature = "glm")]
49// TODO: Migrate to scirs2-linalg (uses ndarray_linalg::Solve)
50//pub mod glm;
51#[cfg(feature = "gpu")]
52pub mod gpu_acceleration;
53#[cfg(feature = "huber")]
54pub mod huber;
55pub mod irls;
56#[cfg(feature = "lasso")]
57pub mod lars;
58#[cfg(feature = "lasso")]
59pub mod lasso_cv;
60#[cfg(feature = "lasso")]
61pub mod lasso_lars;
62#[cfg(feature = "linear-regression")]
63pub mod linear_regression;
64#[cfg(feature = "logistic-regression")]
65pub mod logistic_regression;
66// TODO: Temporarily disabled until cross_val_score is generalized for LogisticRegression
67// #[cfg(feature = "logistic-regression")]
68// pub mod logistic_regression_cv;
69#[cfg(feature = "memory-mapping")]
70pub mod memory_efficient_ops;
71#[cfg(feature = "memory-mapping")]
72pub mod mmap_arrays;
73#[cfg(any(feature = "multi-task", feature = "all-algorithms"))]
74// TODO: Migrate to scirs2-linalg (uses nalgebra types)
75//pub mod multi_output_regression;
76#[cfg(feature = "multi-task-elastic-net")]
77pub mod multi_task_elastic_net;
78#[cfg(feature = "multi-task-elastic-net")]
79pub mod multi_task_elastic_net_cv;
80#[cfg(feature = "multi-task")]
81pub mod multi_task_feature_selection;
82#[cfg(feature = "multi-task-lasso")]
83pub mod multi_task_lasso;
84#[cfg(feature = "multi-task-lasso")]
85pub mod multi_task_lasso_cv;
86#[cfg(feature = "multi-task")]
87pub mod multi_task_shared_representation;
88#[cfg(feature = "lasso")]
89pub mod omp;
90#[cfg(feature = "online-learning")]
91pub mod online_learning;
92pub mod optimizer;
93#[cfg(feature = "sgd")]
94pub mod passive_aggressive;
95#[cfg(feature = "regularization-path")]
96pub mod paths;
97#[cfg(feature = "sgd")]
98pub mod perceptron;
99#[cfg(feature = "feature-selection")]
100pub mod polynomial_features;
101#[cfg(feature = "quantile-regression")]
102// TODO: Migrate to scirs2-linalg (uses ndarray_linalg::Solve)
103//pub mod quantile;
104#[cfg(feature = "ransac")]
105pub mod ransac;
106#[cfg(feature = "feature-selection")]
107pub mod recursive_feature_elimination;
108#[cfg(feature = "residual-analysis")]
109pub mod residual_analysis;
110#[cfg(feature = "ridge")]
111pub mod ridge_classifier;
112#[cfg(feature = "ridge")]
113pub mod ridge_cv;
114#[cfg(feature = "serde")]
115// TODO: Migrate to scirs2-linalg (uses nalgebra types)
116//pub mod serialization;
117#[cfg(feature = "sgd")]
118pub mod sgd;
119#[cfg(feature = "simd")]
120pub mod simd_optimizations;
121pub mod solver;
122#[cfg(feature = "sparse")]
123pub mod sparse;
124#[cfg(feature = "sparse")]
125pub mod sparse_linear_regression;
126#[cfg(feature = "sparse")]
127pub mod sparse_regularized;
128#[cfg(feature = "feature-selection")]
129pub mod stability_selection;
130#[cfg(feature = "streaming")]
131pub mod streaming_algorithms;
132pub mod utils;
133
134// New modular framework modules
135pub mod large_scale_variational_inference;
136pub mod loss_functions;
137pub mod modular_framework;
138pub mod regularization_schemes;
139pub mod solver_implementations;
140pub mod type_safety;
141pub mod uncertainty_quantification;
142
143#[cfg(feature = "theil-sen")]
144pub mod theil_sen;
145
146//#[allow(non_snake_case)]
147#[cfg(test)]
148//pub mod advanced_property_tests;
149#[cfg(feature = "admm")]
150pub use admm::{AdmmConfig, AdmmSolution, AdmmSolver};
151
152#[cfg(feature = "bayesian")]
153pub use bayesian::{
154    ARDRegression, ARDRegressionConfig, BayesianRidge, BayesianRidgeConfig,
155    VariationalBayesianConfig, VariationalBayesianRegression,
156};
157pub use builder_enhancements::{
158    EnhancedLinearRegressionBuilder, ModelPreset, ModelValidation, ValidationConfig,
159};
160
161#[cfg(feature = "logistic-regression")]
162pub use builder_enhancements::EnhancedLogisticRegressionBuilder;
163
164#[cfg(feature = "feature-selection")]
165pub use categorical_encoding::{
166    CategoricalEncoder, CategoricalEncodingConfig, CategoricalEncodingResult,
167    CategoricalEncodingStrategy, CategoricalFeatureInfo, UnknownHandling,
168};
169
170#[cfg(feature = "chunked-processing")]
171pub use chunked_processing::{
172    ChunkProcessingConfig, ChunkProcessingResult, ChunkedDataIterator, ChunkedLinearRegression,
173    ChunkedMatrixProcessor, ChunkedProcessingUtils, ChunkedProcessor, MemoryStats,
174    ParallelChunkedProcessor,
175};
176
177#[cfg(feature = "diagnostics")]
178pub use classification_diagnostics::{
179    CalibrationResult, ClassImbalanceResult, ClassificationDiagnostics,
180    ClassificationDiagnosticsConfig, DecisionBoundaryResult, FeatureImportanceMethod,
181    FeatureImportanceResult,
182};
183
184#[cfg(feature = "constrained-optimization")]
185// TODO: Migrate to scirs2-linalg (uses nalgebra types)
186// pub use constrained_optimization::{
187//     ConstrainedLinearRegression, ConstrainedOptimizationBuilder, ConstrainedOptimizationConfig,
188//     ConstrainedOptimizationProblem, ConstrainedOptimizationResult, ConstraintType,
189//     InteriorPointSolver,
190// };
191#[cfg(feature = "convergence-analysis")]
192pub use convergence_visualization::{
193    ComparisonResult, ConvergenceAnalysis, ConvergenceConfig, ConvergenceCriteria,
194    ConvergenceCriterion, ConvergenceMetric, ConvergenceReport, ConvergenceStatus,
195    ConvergenceTracker, MetricHistory, MetricSummary, PlotData,
196};
197
198#[cfg(feature = "coordinate-descent")]
199pub use coordinate_descent::{CoordinateDescentSolver, ValidationInfo};
200
201#[cfg(feature = "cross-validation")]
202pub use cross_validation::{
203    cross_validate_with_early_stopping, CVStrategy, CrossValidationResult,
204    CrossValidatorWithEarlyStopping, StratifiedKFold,
205};
206
207#[cfg(feature = "early-stopping")]
208pub use early_stopping::{
209    train_validation_split, EarlyStopping, EarlyStoppingCallback, EarlyStoppingConfig,
210    StoppingCriterion,
211};
212
213#[cfg(feature = "elastic-net")]
214pub use elastic_net_cv::{ElasticNetCV, ElasticNetCVConfig};
215pub use errors::{
216    ConfigurationError, ConfigurationErrorKind, ConvergenceInfo, CrossValidationError,
217    CrossValidationErrorKind, DataError, DataErrorKind, ErrorBuilder, ErrorSeverity, FeatureError,
218    FeatureErrorKind, FoldInfo, LinearModelError, MatrixError, MatrixErrorKind, MatrixInfo,
219    NumericalError, NumericalErrorKind, OptimizationError, OptimizationErrorKind, ResourceError,
220    ResourceErrorKind, ResourceInfo, StateError, StateErrorKind,
221};
222#[cfg(feature = "feature-selection")]
223pub use feature_scaling::{
224    FeatureScaler, FeatureScalerBuilder, FeatureScalingConfig, FeatureStats, PowerTransformMethod,
225    ScalingMethod,
226};
227#[cfg(feature = "feature-selection")]
228pub use feature_selection::{
229    FeatureScore, FeatureSelectionConfig, FeatureSelector, ModelBasedEstimator, UnivariateScoreFunc,
230};
231#[cfg(feature = "glm")]
232// TODO: Migrate to scirs2-linalg (uses ndarray_linalg::Solve)
233// pub use glm::{Family, GLMConfig, GeneralizedLinearModel, Link};
234#[cfg(feature = "huber")]
235pub use huber::{HuberRegressor, HuberRegressorConfig};
236// TODO: Migrate to scirs2-linalg (uses ndarray_linalg::Solve)
237// pub use irls::{IRLSConfig, IRLSEstimator, IRLSResult, ScaleEstimator, WeightFunction};
238#[cfg(feature = "lasso")]
239pub use lars::{Lars, LarsConfig};
240#[cfg(feature = "lasso")]
241pub use lasso_cv::{LassoCV, LassoCVConfig};
242#[cfg(feature = "lasso")]
243pub use lasso_lars::{LassoLars, LassoLarsConfig};
244#[cfg(feature = "linear-regression")]
245pub use linear_regression::{LinearRegression, LinearRegressionConfig};
246#[cfg(feature = "logistic-regression")]
247pub use logistic_regression::{LogisticRegression, LogisticRegressionConfig};
248// TODO: Temporarily disabled until cross_val_score is generalized for LogisticRegression
249// #[cfg(feature = "logistic-regression")]
250// pub use logistic_regression_cv::{LogisticRegressionCV, LogisticRegressionCVConfig};
251#[cfg(feature = "memory-mapping")]
252pub use memory_efficient_ops::{
253    MemoryEfficiencyConfig, MemoryEfficientCoordinateDescent, MemoryEfficientOps, MemoryOperation,
254    NormType,
255};
256#[cfg(feature = "memory-mapping")]
257pub use mmap_arrays::{
258    MmapAdvice, MmapConfig, MmapMatrix, MmapMatrixMut, MmapUtils, MmapVector, MmapVectorMut,
259};
260#[cfg(any(feature = "multi-task", feature = "all-algorithms"))]
261// TODO: Migrate to scirs2-linalg (uses nalgebra types)
262// pub use multi_output_regression::{
263//     MultiOutputConfig, MultiOutputRegression, MultiOutputRegressionBuilder, MultiOutputResult,
264//     MultiOutputStrategy,
265// };
266#[cfg(feature = "multi-task-elastic-net")]
267pub use multi_task_elastic_net::{MultiTaskElasticNet, MultiTaskElasticNetConfig};
268#[cfg(feature = "multi-task-elastic-net")]
269pub use multi_task_elastic_net_cv::{MultiTaskElasticNetCV, MultiTaskElasticNetCVConfig};
270#[cfg(feature = "multi-task")]
271pub use multi_task_feature_selection::{
272    FeatureSelectionResult, FeatureSelectionStrategy, MultiTaskFeatureSelectionConfig,
273    MultiTaskFeatureSelector, SelectionSummary,
274};
275#[cfg(feature = "multi-task-lasso")]
276pub use multi_task_lasso::{MultiTaskLasso, MultiTaskLassoConfig};
277#[cfg(feature = "multi-task-lasso")]
278pub use multi_task_lasso_cv::{MultiTaskLassoCV, MultiTaskLassoCVConfig};
279#[cfg(feature = "multi-task")]
280pub use multi_task_shared_representation::{
281    MultiTaskSharedRepresentation, SharedRepresentationBuilder, SharedRepresentationConfig,
282    SharedRepresentationStrategy,
283};
284#[cfg(feature = "lasso")]
285pub use omp::{OrthogonalMatchingPursuit, OrthogonalMatchingPursuitConfig};
286#[cfg(feature = "online-learning")]
287pub use online_learning::{
288    LearningRateSchedule, MiniBatchConfig, MiniBatchIterator, OnlineCoordinateDescent,
289    OnlineLearningConfig, OnlineLinearRegression, OnlineLogisticRegression, SGDVariant,
290};
291pub use optimizer::{
292    FistaOptimizer, LbfgsOptimizer, NesterovAcceleratedGradient, ProximalGradientOptimizer,
293    SagOptimizer, SagaOptimizer,
294};
295#[cfg(feature = "sgd")]
296pub use passive_aggressive::{
297    PassiveAggressiveClassifier, PassiveAggressiveClassifierConfig, PassiveAggressiveLoss,
298    PassiveAggressiveRegressor, PassiveAggressiveRegressorConfig,
299};
300#[cfg(feature = "sgd")]
301pub use perceptron::{Perceptron, PerceptronConfig, PerceptronPenalty};
302#[cfg(feature = "feature-selection")]
303pub use polynomial_features::{
304    FeatureInfo, PolynomialConfig, PolynomialFeatures, PolynomialFeaturesBuilder, PolynomialUtils,
305};
306#[cfg(feature = "quantile-regression")]
307// TODO: Migrate to scirs2-linalg (uses ndarray_linalg::Solve)
308// pub use quantile::{QuantileRegressor, QuantileRegressorConfig, QuantileSolver, SolverOptions};
309#[cfg(feature = "ransac")]
310pub use ransac::{RANSACLoss, RANSACRegressor, RANSACRegressorConfig};
311#[cfg(feature = "feature-selection")]
312pub use recursive_feature_elimination::{
313    RFEConfig, RFEEstimator, RFEFeatureInfo, RFEResult, RecursiveFeatureElimination, ScoringMetric,
314};
315#[cfg(feature = "residual-analysis")]
316pub use residual_analysis::{
317    AssumptionResult, AssumptionTests, InfluenceMeasures, OutlierAnalysis, ResidualAnalysisConfig,
318    ResidualAnalysisResult, ResidualAnalyzer, ResidualStats, StatisticalTests, TestResult,
319};
320#[cfg(feature = "ridge")]
321pub use ridge_classifier::{RidgeClassifier, RidgeClassifierConfig};
322#[cfg(feature = "ridge")]
323pub use ridge_cv::{RidgeCV, RidgeCVConfig};
324#[cfg(feature = "serde")]
325// TODO: Migrate to scirs2-linalg (uses nalgebra types)
326// pub use serialization::{
327//     ModelMetadata, ModelRegistry, ModelSerializer, ModelVersioning, PerformanceMetrics,
328//     SerializableConstrainedOptimization, SerializableLassoRegression, SerializableLinearRegression,
329//     SerializableMatrix, SerializableModel, SerializableMultiOutputRegression,
330//     SerializableRidgeRegression, SerializableVector, SerializationFormat, TrainingInfo,
331// };
332#[cfg(feature = "sgd")]
333pub use sgd::{
334    SGDClassifier, SGDClassifierConfig, SGDLoss, SGDPenalty, SGDRegressor, SGDRegressorConfig,
335};
336// #[cfg(feature = "simd")]
337// TODO: Migrate to scirs2-linalg (uses nalgebra types)
338// pub use simd_optimizations::{
339//     SimdConfig, SimdCoordinateDescent, SimdFeatures, SimdLinearRegression, SimdOps,
340// };
341#[cfg(feature = "theil-sen")]
342pub use theil_sen::{TheilSenRegressor, TheilSenRegressorConfig};
343// Exports for new modular framework
344pub use modular_framework::{
345    create_modular_linear_regression, BayesianPredictionProvider, CompositeObjective,
346    LinearPredictionProvider, ModularConfig, ModularFramework, ModularLinearModel, Objective,
347    ObjectiveData, ObjectiveMetadata, OptimizationResult, OptimizationSolver, PredictionProvider,
348    PredictionWithConfidence, PredictionWithUncertainty, ProbabilisticPredictionProvider,
349    SolverInfo, SolverRecommendations,
350};
351pub use solver::Solver;
352#[cfg(feature = "feature-selection")]
353pub use stability_selection::{
354    BaseSelector, BootstrapResult, StabilityPath, StabilitySelection, StabilitySelectionConfig,
355    StabilitySelectionResult,
356};
357#[cfg(feature = "streaming")]
358pub use streaming_algorithms::{
359    DataStreamIterator, StreamingConfig, StreamingLasso, StreamingLinearRegression,
360    StreamingLinearRegressionBuilder, StreamingUtils,
361};
362
363pub use loss_functions::{
364    AbsoluteLoss, EpsilonInsensitiveLoss, HingeLoss, HuberLoss, LogisticLoss, LossFactory,
365    QuantileLoss, SquaredHingeLoss, SquaredLoss,
366};
367
368pub use regularization_schemes::{
369    CompositeRegularization, ElasticNetRegularization, GroupLassoRegularization, L1Regularization,
370    L2Regularization, RegularizationFactory,
371};
372
373pub use modular_framework::Regularization;
374
375pub use solver_implementations::{
376    BacktrackingConfig, CoordinateDescentConfig, CoordinateDescentResult, GradientDescentConfig,
377    GradientDescentResult, GradientDescentSolver, LineSearchConfig, ProximalGradientConfig,
378    ProximalGradientResult, ProximalGradientSolver, SolverFactory,
379};
380
381pub use type_safety::{
382    problem_type, solver_capability, ComputationalComplexity, ConfigurationHints,
383    ConfigurationValidator, FeatureValidator, FixedSizeOps, L1Scheme, L2Scheme,
384    LargeLinearRegression, MediumLinearRegression, MemoryRequirements, RegularizationConstraint,
385    RegularizationScheme, SmallLinearRegression, SolverConstraint, Trained, TypeSafeConfig,
386    TypeSafeFit, TypeSafeLinearModel, TypeSafeModelBuilder, TypeSafePredict,
387    TypeSafeSolverSelector, Untrained,
388};
389
390pub use large_scale_variational_inference::{
391    ARDConfiguration, LargeScaleVariationalConfig, LargeScaleVariationalRegression,
392    LearningRateDecay, PriorConfiguration, VariationalPosterior,
393};
394
395pub use uncertainty_quantification::{
396    CalibrationMetrics, UncertaintyCapable, UncertaintyConfig, UncertaintyMethod,
397    UncertaintyQuantifier, UncertaintyResult,
398};
399
400// Re-export path functions
401#[cfg(feature = "regularization-path")]
402pub use paths::{
403    enet_path, enet_path_enhanced, lars_path, lars_path_gram, lasso_path, ElasticNetPathConfig,
404    ElasticNetPathResult,
405};
406
407// Re-export utility functions
408pub use crate::utils::{
409    accurate_condition_number, adaptive_least_squares, condition_number,
410    diagnose_numerical_stability, enhanced_ridge_regression, orthogonal_mp, orthogonal_mp_gram,
411    qr_ridge_regression, rank_revealing_qr, ridge_regression, solve_with_iterative_refinement,
412    stable_normal_equations, stable_ridge_regression, svd_ridge_regression, NumericalDiagnostics,
413};
414
415// Re-export sparse matrix functionality
416#[cfg(feature = "sparse")]
417pub use sparse::{
418    Either, SparseConfig, SparseCoordinateDescentSolver, SparseMatrix, SparseMatrixCSR,
419    SparsityAnalysis,
420};
421
422#[cfg(feature = "sparse")]
423pub use sparse_linear_regression::{SparseLinearRegression, SparseLinearRegressionConfig};
424
425#[cfg(feature = "sparse")]
426pub use sparse_regularized::{
427    SparseElasticNet, SparseElasticNetConfig, SparseLasso, SparseLassoConfig,
428};
429
430// Disabled error functions are defined inline when sparse feature is not enabled
431#[cfg(not(feature = "sparse"))]
432pub fn sparse_feature_disabled_error() -> sklears_core::error::SklearsError {
433    sklears_core::error::SklearsError::InvalidParameter {
434        name: "sparse".to_string(),
435        reason: "Sparse matrix support requires the 'sparse' feature".to_string(),
436    }
437}
438
439#[cfg(not(feature = "sparse"))]
440pub fn sparse_linear_regression_disabled_error() -> sklears_core::error::SklearsError {
441    sklears_core::error::SklearsError::InvalidParameter {
442        name: "sparse-linear-regression".to_string(),
443        reason: "Sparse linear regression requires the 'sparse' feature".to_string(),
444    }
445}
446
447#[cfg(not(feature = "sparse"))]
448pub fn sparse_regularized_disabled_error() -> sklears_core::error::SklearsError {
449    sklears_core::error::SklearsError::InvalidParameter {
450        name: "sparse-regularized".to_string(),
451        reason: "Sparse regularized models require the 'sparse' feature".to_string(),
452    }
453}
454
/// Regularization penalty applied by a regularized linear model.
///
/// The default is [`Penalty::None`], i.e. no regularization. The type is `Copy`, so
/// values can be passed around freely; equality is `PartialEq` only because the
/// payloads are `f64`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum Penalty {
    /// Apply no regularization (the default variant).
    #[default]
    None,
    /// L1 (Lasso) regularization.
    ///
    /// NOTE(review): the `f64` is presumably the penalty strength; confirm against
    /// the estimators that consume this enum.
    L1(f64),
    /// L2 (Ridge) regularization.
    ///
    /// NOTE(review): the `f64` is presumably the penalty strength; confirm against
    /// the estimators that consume this enum.
    L2(f64),
    /// Elastic Net regularization, combining the L1 and L2 penalties.
    ElasticNet {
        /// NOTE(review): presumably the share of the penalty assigned to the L1
        /// term; the valid range is not established in this file.
        l1_ratio: f64,
        /// NOTE(review): presumably the overall penalty strength; confirm against
        /// the estimators that consume this enum.
        alpha: f64,
    },
}