sklears_linear/
builder_enhancements.rs

1//! Enhanced builder patterns for linear models
2//!
3//! This module provides advanced builder patterns with:
4//! - Fluent API for complex model configurations
5//! - Compile-time parameter validation
6//! - Configuration presets for common use cases
7//! - Method chaining for model configuration
8//!
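//! The builders carry a phantom `State` type parameter that records the most recent
//! state-changing step (preset, solver, or penalty). Parameter combinations that the
//! type parameter cannot express, such as an L1 penalty with the normal-equations
//! solver, are rejected at runtime when the linear builder's `build()` is called.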
10
11use std::marker::PhantomData;
12
13use sklears_core::{
14    error::{Result, SklearsError},
15    traits::Estimator,
16    types::Float,
17};
18
19use crate::{LinearRegression, LinearRegressionConfig, Penalty, Solver};
20
21#[cfg(feature = "logistic-regression")]
22use crate::{LogisticRegression, LogisticRegressionConfig};
23
/// Configuration presets for common use cases
///
/// A preset is turned into concrete solver / penalty / iteration settings by the
/// `apply_preset` method of each enhanced builder; the exact values differ between
/// the linear and the logistic builder.
///
/// The enum is field-less, so it also derives `Eq` and `Hash` and can be used as a
/// map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelPreset {
    /// Smallest iteration budget (100 iterations in both builders).
    Quick,
    /// `Solver::Auto` with an L2 penalty and 1000 iterations.
    Balanced,
    /// Largest iteration budget of all presets, plus 5-fold cross-validation.
    HighAccuracy,
    /// L1 penalty with 2000 iterations.
    Robust,
    /// 500 iterations with the normal-equations solver (linear builder) or
    /// 1000 iterations with `Solver::Sag` (logistic builder).
    MemoryEfficient,
    /// Elastic-net penalty with cross-validation and early stopping enabled.
    Production,
}
34
/// Marker traits for compile-time validation
///
/// The traits carry no methods. Each one is implemented in this file for the
/// enhanced builder whose `State` parameter is the marker struct of the same
/// name, which lets generic code require a particular builder state through a
/// trait bound.
pub mod validation {
    /// Marker for models that have been properly configured
    pub trait Configured {}

    /// Marker for models that have regularization configured
    pub trait WithRegularization {}

    /// Marker for models that have solver configured
    pub trait WithSolver {}
}
46
/// Enhanced builder for Linear Regression with preset configurations
///
/// `State` is a marker type (defaulting to [`Unconfigured`]) that records which
/// configuration step produced this builder. It is stored only as `PhantomData`.
#[derive(Debug, Clone)]
pub struct EnhancedLinearRegressionBuilder<State = Unconfigured> {
    /// Model hyper-parameters accumulated by the setter methods.
    config: LinearRegressionConfig,
    /// Validation settings (cross-validation folds, split, early stopping, seed).
    validation_config: ValidationConfig,
    /// Type-level state tag; holds no runtime data.
    _state: PhantomData<State>,
}
54
/// Enhanced builder for Logistic Regression with preset configurations
///
/// Only compiled when the `logistic-regression` feature is enabled. `State` is a
/// marker type (defaulting to [`Unconfigured`]) that records which configuration
/// step produced this builder. It is stored only as `PhantomData`.
#[cfg(feature = "logistic-regression")]
#[derive(Debug, Clone)]
pub struct EnhancedLogisticRegressionBuilder<State = Unconfigured> {
    /// Model hyper-parameters accumulated by the setter methods.
    config: LogisticRegressionConfig,
    /// Validation settings (cross-validation folds, split, early stopping, seed).
    validation_config: ValidationConfig,
    /// Type-level state tag; holds no runtime data.
    _state: PhantomData<State>,
}
63
/// Marker type for unconfigured builders
///
/// This is the default `State` of both enhanced builders and the state produced
/// by their `new()` and `Default` constructors.
#[derive(Debug, Clone, Copy)]
pub struct Unconfigured;
67
/// Marker type for configured builders
///
/// Builders enter this state through `with_preset` / `apply_preset`.
#[derive(Debug, Clone, Copy)]
pub struct Configured;
71
/// Marker type for builders with regularization
///
/// Builders enter this state through their `penalty` method, whatever state
/// they were in before.
#[derive(Debug, Clone, Copy)]
pub struct WithRegularization;
75
/// Marker type for builders with solver
///
/// Builders enter this state through their `solver` method, whatever state
/// they were in before.
#[derive(Debug, Clone, Copy)]
pub struct WithSolver;
79
/// Validation configuration for enhanced models
///
/// The derived `Default` leaves every option unset: no cross-validation, no
/// validation split, early stopping disabled, and no random seed.
///
/// NOTE(review): the `build()` methods in this file do not pass these settings on
/// to the constructed model; they are only readable through the builders'
/// `validation_config()` accessor. Confirm whether that is intentional.
#[derive(Debug, Clone, Default)]
pub struct ValidationConfig {
    /// Number of cross-validation folds
    pub cross_validation_folds: Option<usize>,
    /// Validation split ratio
    // NOTE(review): presumably a fraction in (0, 1); nothing in this file checks the range.
    pub validation_split: Option<Float>,
    /// Whether to use early stopping
    pub early_stopping: bool,
    /// Random state for reproducibility
    pub random_state: Option<u64>,
}
92
93// Enhanced Linear Regression Builder Implementation
94impl Default for EnhancedLinearRegressionBuilder<Unconfigured> {
95    fn default() -> Self {
96        Self {
97            config: LinearRegressionConfig::default(),
98            validation_config: ValidationConfig::default(),
99            _state: PhantomData,
100        }
101    }
102}
103
104impl EnhancedLinearRegressionBuilder<Unconfigured> {
105    /// Create a new enhanced linear regression builder
106    pub fn new() -> Self {
107        Self::default()
108    }
109
110    /// Start with a preset configuration
111    pub fn with_preset(preset: ModelPreset) -> EnhancedLinearRegressionBuilder<Configured> {
112        let builder = Self::new();
113        builder.apply_preset(preset)
114    }
115
116    /// Apply a configuration preset
117    pub fn apply_preset(
118        mut self,
119        preset: ModelPreset,
120    ) -> EnhancedLinearRegressionBuilder<Configured> {
121        match preset {
122            ModelPreset::Quick => {
123                self.config.solver = Solver::Normal;
124                self.config.fit_intercept = true;
125                self.config.max_iter = 100;
126            }
127            ModelPreset::Balanced => {
128                self.config.solver = Solver::Auto;
129                self.config.fit_intercept = true;
130                self.config.max_iter = 1000;
131                self.config.penalty = Penalty::L2(0.1);
132            }
133            ModelPreset::HighAccuracy => {
134                self.config.solver = Solver::Normal;
135                self.config.fit_intercept = true;
136                self.config.max_iter = 5000;
137                self.config.penalty = Penalty::L2(0.01);
138                self.validation_config.cross_validation_folds = Some(5);
139            }
140            ModelPreset::Robust => {
141                self.config.solver = Solver::Auto;
142                self.config.fit_intercept = true;
143                self.config.penalty = Penalty::L1(0.1);
144                self.config.max_iter = 2000;
145            }
146            ModelPreset::MemoryEfficient => {
147                self.config.solver = Solver::Normal;
148                self.config.fit_intercept = true;
149                self.config.max_iter = 500;
150            }
151            ModelPreset::Production => {
152                self.config.solver = Solver::Auto;
153                self.config.fit_intercept = true;
154                self.config.penalty = Penalty::ElasticNet {
155                    l1_ratio: 0.5,
156                    alpha: 0.1,
157                };
158                self.config.max_iter = 3000;
159                self.validation_config.cross_validation_folds = Some(10);
160                self.validation_config.early_stopping = true;
161            }
162        }
163
164        EnhancedLinearRegressionBuilder {
165            config: self.config,
166            validation_config: self.validation_config,
167            _state: PhantomData,
168        }
169    }
170}
171
172impl<State> EnhancedLinearRegressionBuilder<State> {
173    /// Set the solver
174    pub fn solver(mut self, solver: Solver) -> EnhancedLinearRegressionBuilder<WithSolver> {
175        self.config.solver = solver;
176        EnhancedLinearRegressionBuilder {
177            config: self.config,
178            validation_config: self.validation_config,
179            _state: PhantomData,
180        }
181    }
182
183    /// Set whether to fit intercept
184    pub fn fit_intercept(mut self, fit_intercept: bool) -> Self {
185        self.config.fit_intercept = fit_intercept;
186        self
187    }
188
189    /// Set regularization penalty
190    pub fn penalty(
191        mut self,
192        penalty: Penalty,
193    ) -> EnhancedLinearRegressionBuilder<WithRegularization> {
194        self.config.penalty = penalty;
195        EnhancedLinearRegressionBuilder {
196            config: self.config,
197            validation_config: self.validation_config,
198            _state: PhantomData,
199        }
200    }
201
202    /// Set maximum iterations
203    pub fn max_iter(mut self, max_iter: usize) -> Self {
204        self.config.max_iter = max_iter;
205        self
206    }
207
208    /// Set tolerance for convergence
209    pub fn tolerance(mut self, tol: f64) -> Self {
210        self.config.tol = tol;
211        self
212    }
213
214    /// Enable warm start
215    pub fn warm_start(mut self, warm_start: bool) -> Self {
216        self.config.warm_start = warm_start;
217        self
218    }
219
220    /// Configure cross-validation
221    pub fn with_cross_validation(mut self, folds: usize) -> Self {
222        self.validation_config.cross_validation_folds = Some(folds);
223        self
224    }
225
226    /// Configure validation split
227    pub fn with_validation_split(mut self, split: Float) -> Self {
228        self.validation_config.validation_split = Some(split);
229        self
230    }
231
232    /// Enable early stopping
233    pub fn with_early_stopping(mut self) -> Self {
234        self.validation_config.early_stopping = true;
235        self
236    }
237
238    /// Set random state for reproducibility
239    pub fn random_state(mut self, seed: u64) -> Self {
240        self.validation_config.random_state = Some(seed);
241        self
242    }
243
244    /// Build the linear regression model
245    pub fn build(self) -> Result<LinearRegression> {
246        LinearRegression::new()
247            .penalty(self.config.penalty)
248            .solver(self.config.solver)
249            .fit_intercept(self.config.fit_intercept)
250            .max_iter(self.config.max_iter)
251            .warm_start(self.config.warm_start)
252            .validate_config()
253    }
254
255    /// Get the configuration
256    pub fn config(&self) -> &LinearRegressionConfig {
257        &self.config
258    }
259
260    /// Get the validation configuration
261    pub fn validation_config(&self) -> &ValidationConfig {
262        &self.validation_config
263    }
264}
265
// Enhanced Logistic Regression Builder Implementation
#[cfg(feature = "logistic-regression")]
impl Default for EnhancedLogisticRegressionBuilder<Unconfigured> {
    /// Creates an unconfigured builder whose model and validation settings both
    /// come from their own `Default` implementations.
    fn default() -> Self {
        let config = LogisticRegressionConfig::default();
        let validation_config = ValidationConfig::default();

        Self {
            config,
            validation_config,
            _state: PhantomData,
        }
    }
}
277
#[cfg(feature = "logistic-regression")]
impl EnhancedLogisticRegressionBuilder<Unconfigured> {
    /// Returns a builder holding the default configuration.
    pub fn new() -> Self {
        Default::default()
    }

    /// Creates a fresh builder and immediately applies `preset` to it.
    pub fn with_preset(preset: ModelPreset) -> EnhancedLogisticRegressionBuilder<Configured> {
        Self::new().apply_preset(preset)
    }

    /// Overwrites the builder's settings with those of `preset` and moves it into
    /// the [`Configured`] state.
    ///
    /// Every preset sets the solver, `max_iter`, the penalty and the tolerance.
    /// The number of cross-validation folds and early stopping are only touched
    /// by the presets that define them; otherwise the current values are kept.
    pub fn apply_preset(
        self,
        preset: ModelPreset,
    ) -> EnhancedLogisticRegressionBuilder<Configured> {
        let Self {
            mut config,
            mut validation_config,
            ..
        } = self;

        // Columns: solver, max_iter, penalty, tolerance, CV folds override,
        // whether to switch early stopping on.
        let (solver, max_iter, penalty, tol, cv_folds, early_stopping) = match preset {
            ModelPreset::Quick => (Solver::Lbfgs, 100, Penalty::L2(1.0), 1e-3, None, false),
            ModelPreset::Balanced => (Solver::Auto, 1000, Penalty::L2(1.0), 1e-4, None, false),
            ModelPreset::HighAccuracy => (
                Solver::Lbfgs,
                10000,
                Penalty::ElasticNet {
                    l1_ratio: 0.5,
                    alpha: 1.0,
                },
                1e-6,
                Some(5),
                false,
            ),
            ModelPreset::Robust => (Solver::Saga, 2000, Penalty::L1(1.0), 1e-4, None, false),
            ModelPreset::MemoryEfficient => {
                (Solver::Sag, 1000, Penalty::L2(1.0), 1e-3, None, false)
            }
            ModelPreset::Production => (
                Solver::Lbfgs,
                5000,
                Penalty::ElasticNet {
                    l1_ratio: 0.1,
                    alpha: 1.0,
                },
                1e-5,
                Some(5),
                true,
            ),
        };

        // Set by every preset.
        config.solver = solver;
        config.max_iter = max_iter;
        config.penalty = penalty;
        config.tol = tol;

        // Optional parts: leave the existing value alone when the preset is silent.
        if let Some(folds) = cv_folds {
            validation_config.cross_validation_folds = Some(folds);
        }
        if early_stopping {
            validation_config.early_stopping = true;
        }

        EnhancedLogisticRegressionBuilder {
            config,
            validation_config,
            _state: PhantomData,
        }
    }
}
351
#[cfg(feature = "logistic-regression")]
impl<State> EnhancedLogisticRegressionBuilder<State> {
    /// Selects the regularization penalty and moves the builder into the
    /// [`WithRegularization`] state.
    ///
    /// All other settings are carried over unchanged.
    pub fn penalty(
        self,
        penalty: Penalty,
    ) -> EnhancedLogisticRegressionBuilder<WithRegularization> {
        let Self {
            mut config,
            validation_config,
            ..
        } = self;
        config.penalty = penalty;

        EnhancedLogisticRegressionBuilder {
            config,
            validation_config,
            _state: PhantomData,
        }
    }

    /// Selects the solver and moves the builder into the [`WithSolver`] state.
    ///
    /// All other settings are carried over unchanged.
    pub fn solver(self, solver: Solver) -> EnhancedLogisticRegressionBuilder<WithSolver> {
        let Self {
            mut config,
            validation_config,
            ..
        } = self;
        config.solver = solver;

        EnhancedLogisticRegressionBuilder {
            config,
            validation_config,
            _state: PhantomData,
        }
    }

    /// Sets the iteration cap stored in the configuration.
    pub fn max_iter(self, max_iter: usize) -> Self {
        let mut builder = self;
        builder.config.max_iter = max_iter;
        builder
    }

    /// Sets the tolerance stored in the configuration.
    pub fn tolerance(self, tol: f64) -> Self {
        let mut builder = self;
        builder.config.tol = tol;
        builder
    }

    /// Chooses whether the model fits an intercept term.
    pub fn fit_intercept(self, fit_intercept: bool) -> Self {
        let mut builder = self;
        builder.config.fit_intercept = fit_intercept;
        builder
    }

    /// Records the requested number of cross-validation folds.
    pub fn with_cross_validation(self, folds: usize) -> Self {
        let mut builder = self;
        builder.validation_config.cross_validation_folds = Some(folds);
        builder
    }

    /// Records the requested validation split ratio.
    pub fn with_validation_split(self, split: Float) -> Self {
        let mut builder = self;
        builder.validation_config.validation_split = Some(split);
        builder
    }

    /// Switches early stopping on in the validation settings.
    pub fn with_early_stopping(self) -> Self {
        let mut builder = self;
        builder.validation_config.early_stopping = true;
        builder
    }

    /// Stores `seed` as the random state of both the model configuration and the
    /// validation settings.
    pub fn random_state(self, seed: u64) -> Self {
        let mut builder = self;
        let seed = Some(seed);
        builder.config.random_state = seed;
        builder.validation_config.random_state = seed;
        builder
    }

    /// Constructs the logistic regression model.
    ///
    /// This method always returns `Ok`; no validation step runs here.
    ///
    /// NOTE(review): only the penalty, solver, `max_iter` and `fit_intercept`
    /// settings are forwarded. `tol`, `random_state` and the validation settings
    /// are not passed to the model here. Confirm whether that is intentional.
    pub fn build(self) -> Result<LogisticRegression> {
        let LogisticRegressionConfig {
            penalty,
            solver,
            max_iter,
            fit_intercept,
            ..
        } = self.config;

        let model = LogisticRegression::new()
            .penalty(penalty)
            .solver(solver)
            .max_iter(max_iter)
            .fit_intercept(fit_intercept);

        Ok(model)
    }

    /// Borrows the model configuration accumulated so far.
    pub fn config(&self) -> &LogisticRegressionConfig {
        let Self { config, .. } = self;
        config
    }

    /// Borrows the validation settings accumulated so far.
    pub fn validation_config(&self) -> &ValidationConfig {
        let Self {
            validation_config, ..
        } = self;
        validation_config
    }
}
439
/// Compile-time validation trait implementations
// Each marker trait is implemented only for the linear builder in the state of
// the same name, so a bound such as `B: validation::WithSolver` is satisfied
// only by a builder whose `State` parameter is `WithSolver`.
impl validation::Configured for EnhancedLinearRegressionBuilder<Configured> {}
impl validation::WithRegularization for EnhancedLinearRegressionBuilder<WithRegularization> {}
impl validation::WithSolver for EnhancedLinearRegressionBuilder<WithSolver> {}
444
// Marker trait implementations for the logistic builder, one per builder state.
// Each impl is gated on the `logistic-regression` feature, as is the builder itself.
#[cfg(feature = "logistic-regression")]
impl validation::Configured for EnhancedLogisticRegressionBuilder<Configured> {}
#[cfg(feature = "logistic-regression")]
impl validation::WithRegularization for EnhancedLogisticRegressionBuilder<WithRegularization> {}
#[cfg(feature = "logistic-regression")]
impl validation::WithSolver for EnhancedLogisticRegressionBuilder<WithSolver> {}
451
/// Extension trait for model validation
///
/// The single method takes the value by ownership and returns it inside `Ok`, so
/// a validation step can sit at the end of a by-value builder chain.
pub trait ModelValidation {
    /// Error type returned when validation fails.
    type Error;

    /// Validate the model configuration
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementor considers the configuration
    /// invalid; otherwise the owned value is returned in `Ok`.
    fn validate_config(self) -> std::result::Result<Self, Self::Error>
    where
        Self: Sized;
}
461
462impl ModelValidation for LinearRegression {
463    type Error = SklearsError;
464
465    fn validate_config(self) -> std::result::Result<Self, Self::Error> {
466        // Add validation logic here
467        match self.config().penalty {
468            Penalty::L1(_) | Penalty::ElasticNet { .. } => {
469                if matches!(self.config().solver, Solver::Normal) {
470                    return Err(SklearsError::InvalidInput(
471                        "Normal equations solver does not support L1 regularization. Use CoordinateDescent or other iterative solver.".to_string()
472                    ));
473                }
474            }
475            _ => {}
476        }
477
478        if self.config().max_iter == 0 {
479            return Err(SklearsError::InvalidInput(
480                "max_iter must be greater than 0".to_string(),
481            ));
482        }
483
484        Ok(self)
485    }
486}
487
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Linear presets build successfully and expose their characteristic setting.
    #[test]
    fn test_enhanced_linear_regression_builder_presets() {
        let quick_builder = EnhancedLinearRegressionBuilder::with_preset(ModelPreset::Quick);
        let quick_model = quick_builder.build().unwrap();
        assert_eq!(quick_model.config().solver, Solver::Normal);

        let balanced_builder = EnhancedLinearRegressionBuilder::with_preset(ModelPreset::Balanced);
        let balanced_model = balanced_builder.build().unwrap();
        assert_eq!(balanced_model.config().solver, Solver::Auto);

        let production_builder =
            EnhancedLinearRegressionBuilder::with_preset(ModelPreset::Production);
        let production_model = production_builder.build().unwrap();
        assert!(matches!(
            production_model.config().penalty,
            Penalty::ElasticNet { .. }
        ));
    }

    /// Logistic presets build successfully and keep the preset's solver.
    #[test]
    #[cfg(feature = "logistic-regression")]
    fn test_enhanced_logistic_regression_builder_presets() {
        let quick_builder = EnhancedLogisticRegressionBuilder::with_preset(ModelPreset::Quick);
        let quick_model = quick_builder.build().unwrap();
        assert_eq!(quick_model.config().solver, Solver::Lbfgs);

        let robust_builder = EnhancedLogisticRegressionBuilder::with_preset(ModelPreset::Robust);
        let robust_model = robust_builder.build().unwrap();
        assert_eq!(robust_model.config().solver, Solver::Saga);
    }

    /// Settings applied through a chain of setters reach the built model.
    #[test]
    fn test_builder_method_chaining() {
        let chained = EnhancedLinearRegressionBuilder::new()
            .solver(Solver::CoordinateDescent)
            .penalty(Penalty::L1(0.5))
            .max_iter(2000)
            .fit_intercept(false)
            .with_cross_validation(5)
            .with_early_stopping();
        let model = chained.build().unwrap();

        assert_eq!(model.config().solver, Solver::CoordinateDescent);
        assert!(matches!(model.config().penalty, Penalty::L1(_)));
        assert_eq!(model.config().max_iter, 2000);
        assert!(!model.config().fit_intercept);
    }

    /// The logistic builder records every value passed through the fluent API.
    #[test]
    #[cfg(feature = "logistic-regression")]
    fn test_fluent_api() {
        let builder = EnhancedLogisticRegressionBuilder::new()
            .penalty(Penalty::L2(2.0))
            .solver(Solver::Saga)
            .max_iter(1500)
            .tolerance(1e-5)
            .random_state(42);

        assert!(matches!(builder.config().penalty, Penalty::L2(_)));
        assert_eq!(builder.config().solver, Solver::Saga);
        assert_eq!(builder.config().max_iter, 1500);
        assert_eq!(builder.config().random_state, Some(42));
    }

    /// An L1 penalty combined with the normal-equations solver must be rejected.
    #[test]
    fn test_configuration_validation() {
        let invalid_builder = EnhancedLinearRegressionBuilder::new()
            .solver(Solver::Normal)
            .penalty(Penalty::L1(1.0));
        let result = invalid_builder.build();

        assert!(result.is_err());
    }

    /// Distinct presets yield distinct iteration caps and cross-validation settings.
    #[test]
    fn test_preset_configurations_differ() {
        let quick = EnhancedLinearRegressionBuilder::with_preset(ModelPreset::Quick);
        let production = EnhancedLinearRegressionBuilder::with_preset(ModelPreset::Production);

        assert_ne!(quick.config().max_iter, production.config().max_iter);

        let quick_folds = quick.validation_config().cross_validation_folds;
        let production_folds = production.validation_config().cross_validation_folds;
        assert_ne!(quick_folds, production_folds);
    }
}
585}