sklears_linear/irls.rs

//! Iteratively Reweighted Least Squares (IRLS) for robust regression
//!
//! IRLS is a method for fitting robust regression models by iteratively reweighting
//! observations based on their residuals. It's particularly useful for M-estimators
//! and can handle various loss functions and weight functions to downweight outliers.
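//!
//! # Example
//!
//! A minimal usage sketch. The import path below assumes these types are exposed
//! from `sklears_linear::irls`; adjust it to the crate's actual re-exports.
//!
//! ```ignore
//! use sklears_linear::irls::{IRLSEstimator, WeightFunction};
//!
//! let x = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
//! let y = vec![3.0, 5.0, 7.0, 9.0]; // y = 2*x + 1
//!
//! let mut irls = IRLSEstimator::new()
//!     .with_weight_function(WeightFunction::Huber { c: 1.345 });
//! irls.fit(&x, &y).unwrap();
//!
//! let predictions = irls.predict(&[vec![5.0]]).unwrap();
//! ```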

use sklears_core::error::SklearsError;
use std::cmp::Ordering;

/// Weight functions for IRLS
#[derive(Debug, Clone, PartialEq)]
pub enum WeightFunction {
    /// Huber weight function: w(r) = min(1, c/|r|)
    Huber { c: f64 },
    /// Bisquare (Tukey) weight function: w(r) = (1 - (r/c)^2)^2 if |r| <= c, 0 otherwise
    Bisquare { c: f64 },
    /// Andrews wave weight function: w(r) = sin(πr/c) / (πr/c) if |r| <= c, 0 otherwise
    Andrews { c: f64 },
    /// Cauchy weight function: w(r) = 1 / (1 + (r/c)^2)
    Cauchy { c: f64 },
    /// Fair weight function: w(r) = 1 / (1 + |r|/c)
    Fair { c: f64 },
    /// Logistic weight function: w(r) = tanh(c*r) / (c*r)
    Logistic { c: f64 },
}
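// Worked example of the weight formulas above (values rounded): with the Huber
// function and c = 1.345, a standardized residual r = 0.5 keeps full weight 1.0,
// while r = 3.0 is downweighted to 1.345 / 3.0 ≈ 0.45. With the Bisquare function
// and c = 4.685, r = 3.0 gives (1 - (3.0 / 4.685)^2)^2 ≈ 0.35, and any |r| > 4.685
// is rejected with weight 0.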

/// Scale estimation methods
#[derive(Debug, Clone, PartialEq)]
pub enum ScaleEstimator {
    /// Median Absolute Deviation (MAD)
    MAD,
    /// Standard deviation
    StandardDeviation,
    /// Interquartile range scaled to approximate standard deviation
    IQR,
    /// Custom fixed scale
    Fixed(f64),
}
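// Worked example for the MAD estimator: residuals [-1.0, 0.5, 2.0, -0.5, 10.0]
// have absolute values [0.5, 0.5, 1.0, 2.0, 10.0] after sorting, so the median is
// 1.0 and the scale estimate is 1.0 * 1.4826 ≈ 1.48. The large residual 10.0
// barely affects the estimate, which is what makes MAD robust compared to the
// standard deviation.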

/// IRLS configuration
#[derive(Debug, Clone)]
pub struct IRLSConfig {
    /// Weight function for robust estimation
    pub weight_function: WeightFunction,
    /// Scale estimation method
    pub scale_estimator: ScaleEstimator,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: f64,
    /// Whether to fit intercept
    pub fit_intercept: bool,
    /// Initial scale estimate (if None, estimated from data)
    pub initial_scale: Option<f64>,
    /// Minimum weight threshold (weights are clamped to be at least this value)
    pub min_weight: f64,
    /// Whether to update scale at each iteration
    pub update_scale: bool,
    /// Regularization parameter (L2)
    pub alpha: f64,
}

impl Default for IRLSConfig {
    fn default() -> Self {
        Self {
            weight_function: WeightFunction::Huber { c: 1.345 },
            scale_estimator: ScaleEstimator::MAD,
            max_iter: 100,
            tol: 1e-6,
            fit_intercept: true,
            initial_scale: None,
            min_weight: 1e-8,
            update_scale: true,
            alpha: 0.0,
        }
    }
}
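// A custom configuration can be built from the defaults via struct-update syntax,
// e.g. (sketch):
//
//     let config = IRLSConfig {
//         weight_function: WeightFunction::Bisquare { c: 4.685 },
//         scale_estimator: ScaleEstimator::MAD,
//         max_iter: 50,
//         ..IRLSConfig::default()
//     };
//     let estimator = IRLSEstimator::with_config(config);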

/// IRLS result
#[derive(Debug, Clone)]
pub struct IRLSResult {
    /// Fitted coefficients
    pub coefficients: Vec<f64>,
    /// Fitted intercept (if fit_intercept=true)
    pub intercept: Option<f64>,
    /// Final weights for each observation
    pub weights: Vec<f64>,
    /// Final scale estimate
    pub scale: f64,
    /// Number of iterations performed
    pub n_iter: usize,
    /// Whether the algorithm converged
    pub converged: bool,
    /// Convergence history (coefficient changes)
    pub convergence_history: Vec<f64>,
    /// Configuration used
    pub config: IRLSConfig,
}

/// Iteratively Reweighted Least Squares estimator
pub struct IRLSEstimator {
    config: IRLSConfig,
    is_fitted: bool,
    result: Option<IRLSResult>,
}

impl IRLSEstimator {
    /// Create a new IRLS estimator with default configuration
    pub fn new() -> Self {
        Self {
            config: IRLSConfig::default(),
            is_fitted: false,
            result: None,
        }
    }

    /// Create an IRLS estimator with custom configuration
    pub fn with_config(config: IRLSConfig) -> Self {
        Self {
            config,
            is_fitted: false,
            result: None,
        }
    }

    /// Set the weight function
    pub fn with_weight_function(mut self, weight_function: WeightFunction) -> Self {
        self.config.weight_function = weight_function;
        self
    }

    /// Set the scale estimator
    pub fn with_scale_estimator(mut self, scale_estimator: ScaleEstimator) -> Self {
        self.config.scale_estimator = scale_estimator;
        self
    }

    /// Set maximum iterations
    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
        self.config.max_iter = max_iter;
        self
    }

    /// Set convergence tolerance
    pub fn with_tolerance(mut self, tol: f64) -> Self {
        self.config.tol = tol;
        self
    }

    /// Set whether to fit intercept
    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
        self.config.fit_intercept = fit_intercept;
        self
    }

    /// Set regularization parameter
    pub fn with_alpha(mut self, alpha: f64) -> Self {
        self.config.alpha = alpha;
        self
    }

    /// Fit the IRLS model
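    ///
    /// The procedure, as implemented below: start from an (optionally ridge-regularized)
    /// least-squares solution, then repeat until convergence or `max_iter`:
    /// 1. standardize residuals by the current scale estimate and recompute weights,
    /// 2. solve the weighted least-squares problem for new coefficients,
    /// 3. stop if the L2 norm of the coefficient change falls below `tol`,
    /// 4. otherwise update residuals (and, if configured, the scale estimate).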
    pub fn fit(&mut self, x: &[Vec<f64>], y: &[f64]) -> Result<(), SklearsError> {
        if x.is_empty() || y.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Cannot fit IRLS on empty dataset".to_string(),
            ));
        }

        let n_samples = x.len();
        let n_features = x[0].len();

        if y.len() != n_samples {
            return Err(SklearsError::ShapeMismatch {
                expected: format!("target.len() == {}", n_samples),
                actual: format!("target.len() == {}", y.len()),
            });
        }

        // Validate input data
        for (i, row) in x.iter().enumerate() {
            if row.len() != n_features {
                return Err(SklearsError::ShapeMismatch {
                    expected: format!("row[{}].len() == {}", i, n_features),
                    actual: format!("row[{}].len() == {}", i, row.len()),
                });
            }
        }

        // Prepare design matrix
        let x_matrix = if self.config.fit_intercept {
            self.add_intercept_column(x)
        } else {
            x.to_vec()
        };

        let _effective_n_features = x_matrix[0].len();

        // Initialize coefficients with ordinary least squares
        let mut coefficients = self.ordinary_least_squares(&x_matrix, y)?;
        let mut weights = vec![1.0; n_samples];
        let mut convergence_history = Vec::new();

        // Initial residuals and scale estimate (use the configured initial scale if provided)
        let mut residuals = self.compute_residuals(&x_matrix, y, &coefficients);
        let mut scale = match self.config.initial_scale {
            Some(initial) => initial,
            None => self.estimate_scale(&residuals)?,
        };

        let mut converged = false;
        let mut n_iter = 0;

        // IRLS iterations
        for iteration in 0..self.config.max_iter {
            n_iter = iteration + 1;

            // Update weights based on residuals
            self.update_weights(&residuals, scale, &mut weights);

            // Weighted least squares
            let new_coefficients = self.weighted_least_squares(&x_matrix, y, &weights)?;

            // Check convergence
            let coefficient_change =
                self.compute_coefficient_change(&coefficients, &new_coefficients);
            convergence_history.push(coefficient_change);

            if coefficient_change < self.config.tol {
                converged = true;
                coefficients = new_coefficients;
                break;
            }

            coefficients = new_coefficients;

            // Update residuals
            residuals = self.compute_residuals(&x_matrix, y, &coefficients);

            // Update scale if requested
            if self.config.update_scale {
                scale = self.estimate_scale(&residuals)?;
            }
        }

        // Split coefficients and intercept
        let (final_coefficients, intercept) = if self.config.fit_intercept {
            let intercept = coefficients[0];
            let coefs = coefficients[1..].to_vec();
            (coefs, Some(intercept))
        } else {
            (coefficients, None)
        };

        self.result = Some(IRLSResult {
            coefficients: final_coefficients,
            intercept,
            weights,
            scale,
            n_iter,
            converged,
            convergence_history,
            config: self.config.clone(),
        });

        self.is_fitted = true;
        Ok(())
    }

    /// Predict using the fitted model
    pub fn predict(&self, x: &[Vec<f64>]) -> Result<Vec<f64>, SklearsError> {
        if !self.is_fitted {
            return Err(SklearsError::NotFitted {
                operation: "predict".to_string(),
            });
        }

        let result = self.result.as_ref().unwrap();

        if x.is_empty() {
            return Ok(Vec::new());
        }

        let n_features = x[0].len();
        if n_features != result.coefficients.len() {
            return Err(SklearsError::FeatureMismatch {
                expected: result.coefficients.len(),
                actual: n_features,
            });
        }

        let mut predictions = Vec::new();

        for row in x {
            let mut pred = 0.0;
            for (i, &coef) in result.coefficients.iter().enumerate() {
                pred += coef * row[i];
            }

            if let Some(intercept) = result.intercept {
                pred += intercept;
            }

            predictions.push(pred);
        }

        Ok(predictions)
    }

    /// Get the fitting result
    pub fn get_result(&self) -> Option<&IRLSResult> {
        self.result.as_ref()
    }

    /// Get fitted coefficients
    pub fn get_coefficients(&self) -> Option<&Vec<f64>> {
        self.result.as_ref().map(|r| &r.coefficients)
    }

    /// Get fitted intercept
    pub fn get_intercept(&self) -> Option<f64> {
        self.result.as_ref().and_then(|r| r.intercept)
    }

    /// Get final weights
    pub fn get_weights(&self) -> Option<&Vec<f64>> {
        self.result.as_ref().map(|r| &r.weights)
    }

    /// Add intercept column to design matrix
    fn add_intercept_column(&self, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
        x.iter()
            .map(|row| {
                let mut new_row = vec![1.0];
                new_row.extend(row);
                new_row
            })
            .collect()
    }

    /// Compute ordinary least squares solution
    fn ordinary_least_squares(&self, x: &[Vec<f64>], y: &[f64]) -> Result<Vec<f64>, SklearsError> {
        let n_samples = x.len();
        let n_features = x[0].len();

        // Compute X^T X
        let mut xtx = vec![vec![0.0; n_features]; n_features];
        #[allow(clippy::needless_range_loop)]
        for i in 0..n_features {
            for j in 0..n_features {
                #[allow(clippy::needless_range_loop)]
                for k in 0..n_samples {
                    xtx[i][j] += x[k][i] * x[k][j];
                }

                // Add regularization to diagonal
                if i == j {
                    xtx[i][j] += self.config.alpha;
                }
            }
        }

        // Compute X^T y
        let mut xty = vec![0.0; n_features];
        #[allow(clippy::needless_range_loop)]
        for i in 0..n_features {
            for j in 0..n_samples {
                xty[i] += x[j][i] * y[j];
            }
        }

        // Solve system
        self.solve_linear_system(&xtx, &xty)
    }

    /// Weighted least squares solution
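    ///
    /// Solves the normal equations (X^T W X + alpha*I) beta = X^T W y, where W is the
    /// diagonal matrix of the current observation weights and alpha is the L2
    /// regularization strength from the configuration.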
    fn weighted_least_squares(
        &self,
        x: &[Vec<f64>],
        y: &[f64],
        weights: &[f64],
    ) -> Result<Vec<f64>, SklearsError> {
        let n_samples = x.len();
        let n_features = x[0].len();

        // Compute X^T W X
        let mut xtwx = vec![vec![0.0; n_features]; n_features];
        #[allow(clippy::needless_range_loop)]
        for i in 0..n_features {
            for j in 0..n_features {
                for k in 0..n_samples {
                    xtwx[i][j] += weights[k] * x[k][i] * x[k][j];
                }

                // Add regularization to diagonal
                if i == j {
                    xtwx[i][j] += self.config.alpha;
                }
            }
        }

        // Compute X^T W y
        let mut xtwy = vec![0.0; n_features];
        #[allow(clippy::needless_range_loop)]
        for i in 0..n_features {
            for j in 0..n_samples {
                xtwy[i] += weights[j] * x[j][i] * y[j];
            }
        }

        // Solve system
        self.solve_linear_system(&xtwx, &xtwy)
    }

    /// Solve linear system using Gaussian elimination with partial pivoting
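    ///
    /// If plain Gaussian elimination fails (for example, because two feature columns
    /// are identical and the normal-equation matrix is singular), the solver retries
    /// once with a small ridge term added to the diagonal: `alpha` when it is
    /// positive, otherwise 1e-6.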
    fn solve_linear_system(&self, a: &[Vec<f64>], b: &[f64]) -> Result<Vec<f64>, SklearsError> {
        match Self::gaussian_elimination(a, b) {
            Ok(solution) => Ok(solution),
            Err(original_error) => {
                // Apply a small ridge regularization and retry to handle rank-deficient systems
                let mut regularized = a.to_vec();
                if regularized.is_empty() || regularized[0].is_empty() {
                    return Err(original_error);
                }

                let ridge = if self.config.alpha > 0.0 {
                    self.config.alpha
                } else {
                    1e-6
                };

                #[allow(clippy::needless_range_loop)]
                for i in 0..regularized.len() {
                    regularized[i][i] += ridge;
                }

                match Self::gaussian_elimination(&regularized, b) {
                    Ok(solution) => Ok(solution),
                    Err(_) => Err(original_error),
                }
            }
        }
    }

    fn gaussian_elimination(a: &[Vec<f64>], b: &[f64]) -> Result<Vec<f64>, SklearsError> {
        let n = a.len();
        if n == 0 || b.len() != n {
            return Err(SklearsError::InvalidInput(
                "Matrix dimensions do not align for Gaussian elimination".to_string(),
            ));
        }

        let mut aug_matrix = vec![vec![0.0; n + 1]; n];

        // Create augmented matrix
        for i in 0..n {
            if a[i].len() != n {
                return Err(SklearsError::InvalidInput(
                    "Matrix must be square for Gaussian elimination".to_string(),
                ));
            }

            for j in 0..n {
                aug_matrix[i][j] = a[i][j];
            }
            aug_matrix[i][n] = b[i];
        }

        // Gaussian elimination with partial pivoting
        for i in 0..n {
            // Find pivot
            let mut max_row = i;
            for k in i + 1..n {
                if aug_matrix[k][i].abs() > aug_matrix[max_row][i].abs() {
                    max_row = k;
                }
            }

            // Swap rows
            if max_row != i {
                aug_matrix.swap(i, max_row);
            }

            // Check for singularity
            if aug_matrix[i][i].abs() < 1e-12 {
                return Err(SklearsError::InvalidInput(
                    "Matrix is singular or nearly singular. Add regularization or check for multicollinearity".to_string(),
                ));
            }

            // Eliminate
            for k in i + 1..n {
                let factor = aug_matrix[k][i] / aug_matrix[i][i];
                for j in i..n + 1 {
                    aug_matrix[k][j] -= factor * aug_matrix[i][j];
                }
            }
        }

        // Back substitution
        let mut solution = vec![0.0; n];
        for i in (0..n).rev() {
            solution[i] = aug_matrix[i][n];
            for j in i + 1..n {
                solution[i] -= aug_matrix[i][j] * solution[j];
            }
            solution[i] /= aug_matrix[i][i];
        }

        Ok(solution)
    }

    /// Compute residuals
    fn compute_residuals(&self, x: &[Vec<f64>], y: &[f64], coefficients: &[f64]) -> Vec<f64> {
        let mut residuals = Vec::new();

        for (i, row) in x.iter().enumerate() {
            let mut pred = 0.0;
            for (j, &coef) in coefficients.iter().enumerate() {
                pred += coef * row[j];
            }
            residuals.push(y[i] - pred);
        }

        residuals
    }

    /// Estimate scale parameter
    fn estimate_scale(&self, residuals: &[f64]) -> Result<f64, SklearsError> {
        if residuals.is_empty() {
            return Err(SklearsError::InvalidInput(
                "Cannot estimate scale from empty residuals".to_string(),
            ));
        }

        let scale = match &self.config.scale_estimator {
            ScaleEstimator::MAD => {
                let mut abs_residuals: Vec<f64> = residuals.iter().map(|&r| r.abs()).collect();
                abs_residuals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
                let median = abs_residuals[abs_residuals.len() / 2];
                median * 1.4826 // MAD to standard deviation conversion
            }

            ScaleEstimator::StandardDeviation => {
                let mean = residuals.iter().sum::<f64>() / residuals.len() as f64;
                let variance = residuals.iter().map(|&r| (r - mean).powi(2)).sum::<f64>()
                    / (residuals.len() - 1) as f64;
                variance.sqrt()
            }

            ScaleEstimator::IQR => {
                let mut sorted_residuals: Vec<f64> = residuals.to_vec();
                sorted_residuals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
                let n = sorted_residuals.len();
                let q1 = sorted_residuals[n / 4];
                let q3 = sorted_residuals[3 * n / 4];
                (q3 - q1) / 1.349 // IQR to standard deviation conversion
            }

            ScaleEstimator::Fixed(scale) => *scale,
        };

        if scale <= 0.0 {
            Ok(1e-6) // Minimum scale to avoid division by zero
        } else {
            Ok(scale)
        }
    }

    /// Update weights based on residuals
    fn update_weights(&self, residuals: &[f64], scale: f64, weights: &mut [f64]) {
        for (i, &residual) in residuals.iter().enumerate() {
            let standardized_residual = residual / scale;
            weights[i] = self
                .compute_weight(standardized_residual)
                .max(self.config.min_weight);
        }
    }

    /// Compute weight for a given standardized residual
    fn compute_weight(&self, r: f64) -> f64 {
        match &self.config.weight_function {
            WeightFunction::Huber { c } => {
                if r.abs() <= *c {
                    1.0
                } else {
                    c / r.abs()
                }
            }

            WeightFunction::Bisquare { c } => {
                if r.abs() <= *c {
                    let ratio = r / c;
                    (1.0 - ratio.powi(2)).powi(2)
                } else {
                    0.0
                }
            }

            WeightFunction::Andrews { c } => {
                if r.abs() <= *c {
                    let ratio = std::f64::consts::PI * r / c;
                    if ratio.abs() < 1e-10 {
                        1.0
                    } else {
                        ratio.sin() / ratio
                    }
                } else {
                    0.0
                }
            }

            WeightFunction::Cauchy { c } => 1.0 / (1.0 + (r / c).powi(2)),

            WeightFunction::Fair { c } => 1.0 / (1.0 + r.abs() / c),

            WeightFunction::Logistic { c } => {
                let cr = c * r;
                if cr.abs() < 1e-10 {
                    1.0
                } else {
                    cr.tanh() / cr
                }
            }
        }
    }

    /// Compute coefficient change between iterations
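    ///
    /// Returns the L2 norm of the difference between the coefficient vectors.
    /// For example, old = [1.0, 2.0] and new = [1.1, 1.9] give
    /// sqrt(0.01 + 0.01) ≈ 0.141.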
    fn compute_coefficient_change(&self, old_coefs: &[f64], new_coefs: &[f64]) -> f64 {
        old_coefs
            .iter()
            .zip(new_coefs.iter())
            .map(|(&old, &new)| (old - new).powi(2))
            .sum::<f64>()
            .sqrt()
    }
}

impl Default for IRLSEstimator {
    fn default() -> Self {
        Self::new()
    }
}

#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    fn create_sample_data() -> (Vec<Vec<f64>>, Vec<f64>) {
        // Linear relationship with some outliers
        let x = vec![
            vec![1.0],
            vec![2.0],
            vec![3.0],
            vec![4.0],
            vec![5.0],
            vec![6.0],
            vec![7.0],
            vec![8.0],
            vec![9.0],
            vec![10.0],
        ];

        // y = 2*x + 1 with outliers
        let mut y = vec![3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0, 19.0, 21.0];
        y[8] = 50.0; // outlier
        y[9] = 5.0; // outlier

        (x, y)
    }

    #[test]
    fn test_irls_basic() {
        let mut irls = IRLSEstimator::new();
        let (x, y) = create_sample_data();

        let result = irls.fit(&x, &y);
        assert!(result.is_ok());

        let coefficients = irls.get_coefficients().unwrap();
        assert_eq!(coefficients.len(), 1);

        // Should be robust to outliers and close to true slope of 2
        assert!((coefficients[0] - 2.0).abs() < 0.5);

        let intercept = irls.get_intercept().unwrap();
        // Should be close to true intercept of 1
        assert!((intercept - 1.0).abs() < 1.0);
    }

    #[test]
    fn test_irls_huber() {
        let mut irls =
            IRLSEstimator::new().with_weight_function(WeightFunction::Huber { c: 1.345 });

        let (x, y) = create_sample_data();
        let result = irls.fit(&x, &y);

        assert!(result.is_ok());

        let weights = irls.get_weights().unwrap();
        assert_eq!(weights.len(), y.len());

        // Outliers should have lower weights
        assert!(weights[8] < weights[0]); // outlier has lower weight
        assert!(weights[9] < weights[0]); // outlier has lower weight
    }

    #[test]
    fn test_irls_bisquare() {
        let mut irls =
            IRLSEstimator::new().with_weight_function(WeightFunction::Bisquare { c: 4.685 });

        let (x, y) = create_sample_data();
        let result = irls.fit(&x, &y);

        assert!(result.is_ok());
        assert!(irls.is_fitted);
    }

    #[test]
    fn test_irls_prediction() {
        let mut irls = IRLSEstimator::new();
        let (x, y) = create_sample_data();

        irls.fit(&x, &y).unwrap();

        let x_test = vec![vec![5.5], vec![7.5]];
        let predictions = irls.predict(&x_test).unwrap();

        assert_eq!(predictions.len(), 2);

        // Predictions should be reasonable
        assert!(predictions[0] > 10.0 && predictions[0] < 15.0);
        assert!(predictions[1] > 14.0 && predictions[1] < 18.0);
    }

    #[test]
    fn test_irls_no_intercept() {
        let mut irls = IRLSEstimator::new().with_fit_intercept(false);

        let x = vec![vec![1.0], vec![2.0], vec![3.0]];
        let y = vec![2.0, 4.0, 6.0]; // y = 2*x

        let result = irls.fit(&x, &y);
        assert!(result.is_ok());

        assert_eq!(irls.get_intercept(), None);

        let coefficients = irls.get_coefficients().unwrap();
        assert!((coefficients[0] - 2.0).abs() < 0.1);
    }

    #[test]
    fn test_irls_multivariate() {
        let mut irls = IRLSEstimator::new();

        let x = vec![
            vec![1.0, 2.0],
            vec![2.0, 3.0],
            vec![3.0, 4.0],
            vec![4.0, 5.0],
            vec![5.0, 6.0],
        ];
        let y = vec![8.0, 13.0, 18.0, 23.0, 28.0]; // y = 2*x1 + 3*x2

        let result = irls.fit(&x, &y);
        assert!(result.is_ok());

        let coefficients = irls.get_coefficients().unwrap();
        assert_eq!(coefficients.len(), 2);
    }

    #[test]
    fn test_irls_convergence() {
        let mut irls = IRLSEstimator::new().with_max_iter(5).with_tolerance(1e-3);

        let (x, y) = create_sample_data();
        irls.fit(&x, &y).unwrap();

        let result = irls.get_result().unwrap();
        assert!(result.n_iter <= 5);
        assert!(!result.convergence_history.is_empty());
    }

    #[test]
    fn test_irls_different_scale_estimators() {
        let (x, y) = create_sample_data();

        let scale_estimators = vec![
            ScaleEstimator::MAD,
            ScaleEstimator::StandardDeviation,
            ScaleEstimator::IQR,
            ScaleEstimator::Fixed(1.0),
        ];

        for scale_estimator in scale_estimators {
            let mut irls = IRLSEstimator::new().with_scale_estimator(scale_estimator);

            let result = irls.fit(&x, &y);
            assert!(
                result.is_ok(),
                "Failed with scale estimator: {:?}",
                irls.config.scale_estimator
            );
        }
    }

    #[test]
    fn test_irls_weight_functions() {
        let (x, y) = create_sample_data();

        let weight_functions = vec![
            WeightFunction::Huber { c: 1.345 },
            WeightFunction::Bisquare { c: 4.685 },
            WeightFunction::Andrews { c: 1.339 },
            WeightFunction::Cauchy { c: 2.385 },
            WeightFunction::Fair { c: 1.4 },
            WeightFunction::Logistic { c: 1.2 },
        ];

        for weight_function in weight_functions {
            let mut irls = IRLSEstimator::new().with_weight_function(weight_function.clone());

            let result = irls.fit(&x, &y);
            assert!(
                result.is_ok(),
                "Failed with weight function: {:?}",
                weight_function
            );
        }
    }

    #[test]
    fn test_irls_empty_data_error() {
        let mut irls = IRLSEstimator::new();
        let x: Vec<Vec<f64>> = vec![];
        let y: Vec<f64> = vec![];

        let result = irls.fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_irls_dimension_mismatch_error() {
        let mut irls = IRLSEstimator::new();
        let (x, _) = create_sample_data();
        let wrong_y = vec![1.0, 2.0]; // Wrong length

        let result = irls.fit(&x, &wrong_y);
        assert!(result.is_err());
    }

    #[test]
    fn test_irls_predict_before_fit_error() {
        let irls = IRLSEstimator::new();
        let x = vec![vec![1.0]];

        let result = irls.predict(&x);
        assert!(result.is_err());
    }

    #[test]
    fn test_irls_regularization() {
        let mut irls = IRLSEstimator::new().with_alpha(0.1); // Add L2 regularization

        let (x, y) = create_sample_data();
        let result = irls.fit(&x, &y);

        assert!(result.is_ok());

        // L2 regularization shrinks the slope, so it should not exceed the true value of 2 by much
        let coefficients = irls.get_coefficients().unwrap();
        assert!(coefficients[0] < 2.1); // Regularized coefficient
    }
}