1use sklears_core::error::SklearsError;
8use std::cmp::Ordering;
9
/// Robust weight function applied to standardized residuals `r / scale` during IRLS.
///
/// Every variant carries its own tuning constant `c`, which controls where the
/// function starts down-weighting large residuals.
#[derive(Debug, Clone, PartialEq)]
pub enum WeightFunction {
    /// Huber: weight `1` for `|r| <= c`, `c / |r|` beyond that.
    Huber { c: f64 },
    /// Tukey bisquare (biweight): `(1 - (r / c)^2)^2` for `|r| <= c`, `0` beyond that.
    Bisquare { c: f64 },
    /// Andrews sine wave: redescending, giving zero weight to sufficiently large residuals.
    Andrews { c: f64 },
    /// Cauchy: `1 / (1 + (r / c)^2)`.
    Cauchy { c: f64 },
    /// Fair: `1 / (1 + |r| / c)`.
    Fair { c: f64 },
    /// Logistic: `tanh(z) / z` of a residual `z` rescaled by `c`.
    Logistic { c: f64 },
}
26
/// Strategy for estimating the residual scale used to standardize residuals.
#[derive(Debug, Clone, PartialEq)]
pub enum ScaleEstimator {
    /// Median of the absolute residuals (taken about zero), rescaled by 1.4826.
    MAD,
    /// Sample standard deviation of the residuals.
    StandardDeviation,
    /// Interquartile range of the residuals divided by 1.349.
    IQR,
    /// A user-supplied constant scale.
    Fixed(f64),
}
39
40#[derive(Debug, Clone)]
42pub struct IRLSConfig {
43 pub weight_function: WeightFunction,
45 pub scale_estimator: ScaleEstimator,
47 pub max_iter: usize,
49 pub tol: f64,
51 pub fit_intercept: bool,
53 pub initial_scale: Option<f64>,
55 pub min_weight: f64,
57 pub update_scale: bool,
59 pub alpha: f64,
61}
62
63impl Default for IRLSConfig {
64 fn default() -> Self {
65 Self {
66 weight_function: WeightFunction::Huber { c: 1.345 },
67 scale_estimator: ScaleEstimator::MAD,
68 max_iter: 100,
69 tol: 1e-6,
70 fit_intercept: true,
71 initial_scale: None,
72 min_weight: 1e-8,
73 update_scale: true,
74 alpha: 0.0,
75 }
76 }
77}
78
79#[derive(Debug, Clone)]
81pub struct IRLSResult {
82 pub coefficients: Vec<f64>,
84 pub intercept: Option<f64>,
86 pub weights: Vec<f64>,
88 pub scale: f64,
90 pub n_iter: usize,
92 pub converged: bool,
94 pub convergence_history: Vec<f64>,
96 pub config: IRLSConfig,
98}
99
100pub struct IRLSEstimator {
102 config: IRLSConfig,
103 is_fitted: bool,
104 result: Option<IRLSResult>,
105}
106
107impl IRLSEstimator {
108 pub fn new() -> Self {
110 Self {
111 config: IRLSConfig::default(),
112 is_fitted: false,
113 result: None,
114 }
115 }
116
117 pub fn with_config(config: IRLSConfig) -> Self {
119 Self {
120 config,
121 is_fitted: false,
122 result: None,
123 }
124 }
125
126 pub fn with_weight_function(mut self, weight_function: WeightFunction) -> Self {
128 self.config.weight_function = weight_function;
129 self
130 }
131
132 pub fn with_scale_estimator(mut self, scale_estimator: ScaleEstimator) -> Self {
134 self.config.scale_estimator = scale_estimator;
135 self
136 }
137
138 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
140 self.config.max_iter = max_iter;
141 self
142 }
143
144 pub fn with_tolerance(mut self, tol: f64) -> Self {
146 self.config.tol = tol;
147 self
148 }
149
150 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
152 self.config.fit_intercept = fit_intercept;
153 self
154 }
155
156 pub fn with_alpha(mut self, alpha: f64) -> Self {
158 self.config.alpha = alpha;
159 self
160 }
161
162 pub fn fit(&mut self, x: &[Vec<f64>], y: &[f64]) -> Result<(), SklearsError> {
164 if x.is_empty() || y.is_empty() {
165 return Err(SklearsError::InvalidInput(
166 "Cannot fit IRLS on empty dataset".to_string(),
167 ));
168 }
169
170 let n_samples = x.len();
171 let n_features = x[0].len();
172
173 if y.len() != n_samples {
174 return Err(SklearsError::ShapeMismatch {
175 expected: format!("target.len() == {}", n_samples),
176 actual: format!("target.len() == {}", y.len()),
177 });
178 }
179
180 for (i, row) in x.iter().enumerate() {
182 if row.len() != n_features {
183 return Err(SklearsError::ShapeMismatch {
184 expected: format!("row[{}].len() == {}", i, n_features),
185 actual: format!("row[{}].len() == {}", i, row.len()),
186 });
187 }
188 }
189
190 let x_matrix = if self.config.fit_intercept {
192 self.add_intercept_column(x)
193 } else {
194 x.to_vec()
195 };
196
197 let _effective_n_features = x_matrix[0].len();
198
199 let mut coefficients = self.ordinary_least_squares(&x_matrix, y)?;
201 let mut weights = vec![1.0; n_samples];
202 let mut convergence_history = Vec::new();
203
204 let mut residuals = self.compute_residuals(&x_matrix, y, &coefficients);
206 let mut scale = self.estimate_scale(&residuals)?;
207
208 let mut converged = false;
209 let mut n_iter = 0;
210
211 for iteration in 0..self.config.max_iter {
213 n_iter = iteration + 1;
214
215 self.update_weights(&residuals, scale, &mut weights);
217
218 let new_coefficients = self.weighted_least_squares(&x_matrix, y, &weights)?;
220
221 let coefficient_change =
223 self.compute_coefficient_change(&coefficients, &new_coefficients);
224 convergence_history.push(coefficient_change);
225
226 if coefficient_change < self.config.tol {
227 converged = true;
228 coefficients = new_coefficients;
229 break;
230 }
231
232 coefficients = new_coefficients;
233
234 residuals = self.compute_residuals(&x_matrix, y, &coefficients);
236
237 if self.config.update_scale {
239 scale = self.estimate_scale(&residuals)?;
240 }
241 }
242
243 let (final_coefficients, intercept) = if self.config.fit_intercept {
245 let intercept = coefficients[0];
246 let coefs = coefficients[1..].to_vec();
247 (coefs, Some(intercept))
248 } else {
249 (coefficients, None)
250 };
251
252 self.result = Some(IRLSResult {
253 coefficients: final_coefficients,
254 intercept,
255 weights,
256 scale,
257 n_iter,
258 converged,
259 convergence_history,
260 config: self.config.clone(),
261 });
262
263 self.is_fitted = true;
264 Ok(())
265 }
266
267 pub fn predict(&self, x: &[Vec<f64>]) -> Result<Vec<f64>, SklearsError> {
269 if !self.is_fitted {
270 return Err(SklearsError::NotFitted {
271 operation: "predict".to_string(),
272 });
273 }
274
275 let result = self.result.as_ref().unwrap();
276
277 if x.is_empty() {
278 return Ok(Vec::new());
279 }
280
281 let n_features = x[0].len();
282 if n_features != result.coefficients.len() {
283 return Err(SklearsError::FeatureMismatch {
284 expected: result.coefficients.len(),
285 actual: n_features,
286 });
287 }
288
289 let mut predictions = Vec::new();
290
291 for row in x {
292 let mut pred = 0.0;
293 for (i, &coef) in result.coefficients.iter().enumerate() {
294 pred += coef * row[i];
295 }
296
297 if let Some(intercept) = result.intercept {
298 pred += intercept;
299 }
300
301 predictions.push(pred);
302 }
303
304 Ok(predictions)
305 }
306
307 pub fn get_result(&self) -> Option<&IRLSResult> {
309 self.result.as_ref()
310 }
311
312 pub fn get_coefficients(&self) -> Option<&Vec<f64>> {
314 self.result.as_ref().map(|r| &r.coefficients)
315 }
316
317 pub fn get_intercept(&self) -> Option<f64> {
319 self.result.as_ref().and_then(|r| r.intercept)
320 }
321
322 pub fn get_weights(&self) -> Option<&Vec<f64>> {
324 self.result.as_ref().map(|r| &r.weights)
325 }
326
327 fn add_intercept_column(&self, x: &[Vec<f64>]) -> Vec<Vec<f64>> {
329 x.iter()
330 .map(|row| {
331 let mut new_row = vec![1.0];
332 new_row.extend(row);
333 new_row
334 })
335 .collect()
336 }
337
338 fn ordinary_least_squares(&self, x: &[Vec<f64>], y: &[f64]) -> Result<Vec<f64>, SklearsError> {
340 let n_samples = x.len();
341 let n_features = x[0].len();
342
343 let mut xtx = vec![vec![0.0; n_features]; n_features];
345 #[allow(clippy::needless_range_loop)]
346 for i in 0..n_features {
347 for j in 0..n_features {
348 #[allow(clippy::needless_range_loop)]
349 for k in 0..n_samples {
350 xtx[i][j] += x[k][i] * x[k][j];
351 }
352
353 if i == j {
355 xtx[i][j] += self.config.alpha;
356 }
357 }
358 }
359
360 let mut xty = vec![0.0; n_features];
362 #[allow(clippy::needless_range_loop)]
363 for i in 0..n_features {
364 for j in 0..n_samples {
365 xty[i] += x[j][i] * y[j];
366 }
367 }
368
369 self.solve_linear_system(&xtx, &xty)
371 }
372
373 fn weighted_least_squares(
375 &self,
376 x: &[Vec<f64>],
377 y: &[f64],
378 weights: &[f64],
379 ) -> Result<Vec<f64>, SklearsError> {
380 let n_samples = x.len();
381 let n_features = x[0].len();
382
383 let mut xtwx = vec![vec![0.0; n_features]; n_features];
385 #[allow(clippy::needless_range_loop)]
386 for i in 0..n_features {
387 for j in 0..n_features {
388 for k in 0..n_samples {
389 xtwx[i][j] += weights[k] * x[k][i] * x[k][j];
390 }
391
392 if i == j {
394 xtwx[i][j] += self.config.alpha;
395 }
396 }
397 }
398
399 let mut xtwy = vec![0.0; n_features];
401 #[allow(clippy::needless_range_loop)]
402 for i in 0..n_features {
403 for j in 0..n_samples {
404 xtwy[i] += weights[j] * x[j][i] * y[j];
405 }
406 }
407
408 self.solve_linear_system(&xtwx, &xtwy)
410 }
411
412 fn solve_linear_system(&self, a: &[Vec<f64>], b: &[f64]) -> Result<Vec<f64>, SklearsError> {
414 match Self::gaussian_elimination(a, b) {
415 Ok(solution) => Ok(solution),
416 Err(original_error) => {
417 let mut regularized = a.to_vec();
419 if regularized.is_empty() || regularized[0].is_empty() {
420 return Err(original_error);
421 }
422
423 let ridge = if self.config.alpha > 0.0 {
424 self.config.alpha
425 } else {
426 1e-6
427 };
428
429 #[allow(clippy::needless_range_loop)]
430 for i in 0..regularized.len() {
431 regularized[i][i] += ridge;
432 }
433
434 match Self::gaussian_elimination(®ularized, b) {
435 Ok(solution) => Ok(solution),
436 Err(_) => Err(original_error),
437 }
438 }
439 }
440 }
441
442 fn gaussian_elimination(a: &[Vec<f64>], b: &[f64]) -> Result<Vec<f64>, SklearsError> {
443 let n = a.len();
444 if n == 0 || b.len() != n {
445 return Err(SklearsError::InvalidInput(
446 "Matrix dimensions do not align for Gaussian elimination".to_string(),
447 ));
448 }
449
450 let mut aug_matrix = vec![vec![0.0; n + 1]; n];
451
452 for i in 0..n {
454 if a[i].len() != n {
455 return Err(SklearsError::InvalidInput(
456 "Matrix must be square for Gaussian elimination".to_string(),
457 ));
458 }
459
460 for j in 0..n {
461 aug_matrix[i][j] = a[i][j];
462 }
463 aug_matrix[i][n] = b[i];
464 }
465
466 for i in 0..n {
468 let mut max_row = i;
470 for k in i + 1..n {
471 if aug_matrix[k][i].abs() > aug_matrix[max_row][i].abs() {
472 max_row = k;
473 }
474 }
475
476 if max_row != i {
478 aug_matrix.swap(i, max_row);
479 }
480
481 if aug_matrix[i][i].abs() < 1e-12 {
483 return Err(SklearsError::InvalidInput(
484 "Matrix is singular or nearly singular. Add regularization or check for multicollinearity".to_string(),
485 ));
486 }
487
488 for k in i + 1..n {
490 let factor = aug_matrix[k][i] / aug_matrix[i][i];
491 for j in i..n + 1 {
492 aug_matrix[k][j] -= factor * aug_matrix[i][j];
493 }
494 }
495 }
496
497 let mut solution = vec![0.0; n];
499 for i in (0..n).rev() {
500 solution[i] = aug_matrix[i][n];
501 for j in i + 1..n {
502 solution[i] -= aug_matrix[i][j] * solution[j];
503 }
504 solution[i] /= aug_matrix[i][i];
505 }
506
507 Ok(solution)
508 }
509
510 fn compute_residuals(&self, x: &[Vec<f64>], y: &[f64], coefficients: &[f64]) -> Vec<f64> {
512 let mut residuals = Vec::new();
513
514 for (i, row) in x.iter().enumerate() {
515 let mut pred = 0.0;
516 for (j, &coef) in coefficients.iter().enumerate() {
517 pred += coef * row[j];
518 }
519 residuals.push(y[i] - pred);
520 }
521
522 residuals
523 }
524
525 fn estimate_scale(&self, residuals: &[f64]) -> Result<f64, SklearsError> {
527 if residuals.is_empty() {
528 return Err(SklearsError::InvalidInput(
529 "Cannot estimate scale from empty residuals".to_string(),
530 ));
531 }
532
533 let scale = match &self.config.scale_estimator {
534 ScaleEstimator::MAD => {
535 let mut abs_residuals: Vec<f64> = residuals.iter().map(|&r| r.abs()).collect();
536 abs_residuals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
537 let median = abs_residuals[abs_residuals.len() / 2];
538 median * 1.4826 }
540
541 ScaleEstimator::StandardDeviation => {
542 let mean = residuals.iter().sum::<f64>() / residuals.len() as f64;
543 let variance = residuals.iter().map(|&r| (r - mean).powi(2)).sum::<f64>()
544 / (residuals.len() - 1) as f64;
545 variance.sqrt()
546 }
547
548 ScaleEstimator::IQR => {
549 let mut sorted_residuals: Vec<f64> = residuals.to_vec();
550 sorted_residuals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
551 let n = sorted_residuals.len();
552 let q1 = sorted_residuals[n / 4];
553 let q3 = sorted_residuals[3 * n / 4];
554 (q3 - q1) / 1.349 }
556
557 ScaleEstimator::Fixed(scale) => *scale,
558 };
559
560 if scale <= 0.0 {
561 Ok(1e-6) } else {
563 Ok(scale)
564 }
565 }
566
567 fn update_weights(&self, residuals: &[f64], scale: f64, weights: &mut [f64]) {
569 for (i, &residual) in residuals.iter().enumerate() {
570 let standardized_residual = residual / scale;
571 weights[i] = self
572 .compute_weight(standardized_residual)
573 .max(self.config.min_weight);
574 }
575 }
576
577 fn compute_weight(&self, r: f64) -> f64 {
579 match &self.config.weight_function {
580 WeightFunction::Huber { c } => {
581 if r.abs() <= *c {
582 1.0
583 } else {
584 c / r.abs()
585 }
586 }
587
588 WeightFunction::Bisquare { c } => {
589 if r.abs() <= *c {
590 let ratio = r / c;
591 (1.0 - ratio.powi(2)).powi(2)
592 } else {
593 0.0
594 }
595 }
596
597 WeightFunction::Andrews { c } => {
598 if r.abs() <= *c {
599 let ratio = std::f64::consts::PI * r / c;
600 if ratio.abs() < 1e-10 {
601 1.0
602 } else {
603 ratio.sin() / ratio
604 }
605 } else {
606 0.0
607 }
608 }
609
610 WeightFunction::Cauchy { c } => 1.0 / (1.0 + (r / c).powi(2)),
611
612 WeightFunction::Fair { c } => 1.0 / (1.0 + r.abs() / c),
613
614 WeightFunction::Logistic { c } => {
615 let cr = c * r;
616 if cr.abs() < 1e-10 {
617 1.0
618 } else {
619 cr.tanh() / cr
620 }
621 }
622 }
623 }
624
625 fn compute_coefficient_change(&self, old_coefs: &[f64], new_coefs: &[f64]) -> f64 {
627 old_coefs
628 .iter()
629 .zip(new_coefs.iter())
630 .map(|(&old, &new)| (old - new).powi(2))
631 .sum::<f64>()
632 .sqrt()
633 }
634}
635
636impl Default for IRLSEstimator {
637 fn default() -> Self {
638 Self::new()
639 }
640}
641
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten points on `y = 2x + 1` with the last two targets replaced by gross outliers.
    fn create_sample_data() -> (Vec<Vec<f64>>, Vec<f64>) {
        let x: Vec<Vec<f64>> = (1..=10).map(|i| vec![f64::from(i)]).collect();
        let mut y: Vec<f64> = (1..=10).map(|i| 2.0 * f64::from(i) + 1.0).collect();
        y[8] = 50.0; // far above the line (clean value 19)
        y[9] = 5.0; // far below the line (clean value 21)
        (x, y)
    }

    #[test]
    fn test_irls_basic() {
        let mut model = IRLSEstimator::new();
        let (x, y) = create_sample_data();

        assert!(model.fit(&x, &y).is_ok());

        let slope = model.get_coefficients().unwrap();
        assert_eq!(slope.len(), 1);
        // The robust slope should stay near the true value of 2 despite the outliers.
        assert!((slope[0] - 2.0).abs() < 0.5);

        let intercept = model.get_intercept().unwrap();
        assert!((intercept - 1.0).abs() < 1.0);
    }

    #[test]
    fn test_irls_huber() {
        let mut model =
            IRLSEstimator::new().with_weight_function(WeightFunction::Huber { c: 1.345 });
        let (x, y) = create_sample_data();

        assert!(model.fit(&x, &y).is_ok());

        let weights = model.get_weights().unwrap();
        assert_eq!(weights.len(), y.len());
        // Both outliers must be down-weighted relative to a clean point.
        assert!(weights[8] < weights[0]);
        assert!(weights[9] < weights[0]);
    }

    #[test]
    fn test_irls_bisquare() {
        let mut model =
            IRLSEstimator::new().with_weight_function(WeightFunction::Bisquare { c: 4.685 });
        let (x, y) = create_sample_data();

        assert!(model.fit(&x, &y).is_ok());
        assert!(model.is_fitted);
    }

    #[test]
    fn test_irls_prediction() {
        let mut model = IRLSEstimator::new();
        let (x, y) = create_sample_data();
        model.fit(&x, &y).unwrap();

        let predictions = model.predict(&[vec![5.5], vec![7.5]]).unwrap();
        assert_eq!(predictions.len(), 2);
        // On the clean line these inputs map to 12 and 16.
        assert!(predictions[0] > 10.0 && predictions[0] < 15.0);
        assert!(predictions[1] > 14.0 && predictions[1] < 18.0);
    }

    #[test]
    fn test_irls_no_intercept() {
        let mut model = IRLSEstimator::new().with_fit_intercept(false);
        let x: Vec<Vec<f64>> = [1.0, 2.0, 3.0].iter().map(|&v| vec![v]).collect();
        let y: Vec<f64> = x.iter().map(|row| 2.0 * row[0]).collect(); // y = 2x exactly

        assert!(model.fit(&x, &y).is_ok());
        assert_eq!(model.get_intercept(), None);

        let slope = model.get_coefficients().unwrap();
        assert!((slope[0] - 2.0).abs() < 0.1);
    }

    #[test]
    fn test_irls_multivariate() {
        let mut model = IRLSEstimator::new();
        let x: Vec<Vec<f64>> = (1..=5)
            .map(|i| {
                let a = f64::from(i);
                vec![a, a + 1.0]
            })
            .collect();
        // y = 2*x1 + 3*x2 → 8, 13, 18, 23, 28
        let y: Vec<f64> = x.iter().map(|row| 2.0 * row[0] + 3.0 * row[1]).collect();

        assert!(model.fit(&x, &y).is_ok());
        assert_eq!(model.get_coefficients().unwrap().len(), 2);
    }

    #[test]
    fn test_irls_convergence() {
        let mut model = IRLSEstimator::new().with_max_iter(5).with_tolerance(1e-3);
        let (x, y) = create_sample_data();
        model.fit(&x, &y).unwrap();

        let summary = model.get_result().unwrap();
        assert!(summary.n_iter <= 5);
        assert!(!summary.convergence_history.is_empty());
    }

    #[test]
    fn test_irls_different_scale_estimators() {
        let (x, y) = create_sample_data();

        for estimator in [
            ScaleEstimator::MAD,
            ScaleEstimator::StandardDeviation,
            ScaleEstimator::IQR,
            ScaleEstimator::Fixed(1.0),
        ] {
            let mut model = IRLSEstimator::new().with_scale_estimator(estimator);
            let outcome = model.fit(&x, &y);
            assert!(
                outcome.is_ok(),
                "Failed with scale estimator: {:?}",
                model.config.scale_estimator
            );
        }
    }

    #[test]
    fn test_irls_weight_functions() {
        let (x, y) = create_sample_data();
        let candidates = [
            WeightFunction::Huber { c: 1.345 },
            WeightFunction::Bisquare { c: 4.685 },
            WeightFunction::Andrews { c: 1.339 },
            WeightFunction::Cauchy { c: 2.385 },
            WeightFunction::Fair { c: 1.4 },
            WeightFunction::Logistic { c: 1.2 },
        ];

        for candidate in &candidates {
            let mut model = IRLSEstimator::new().with_weight_function(candidate.clone());
            let outcome = model.fit(&x, &y);
            assert!(
                outcome.is_ok(),
                "Failed with weight function: {:?}",
                candidate
            );
        }
    }

    #[test]
    fn test_irls_empty_data_error() {
        let no_rows: Vec<Vec<f64>> = Vec::new();
        let no_targets: Vec<f64> = Vec::new();
        assert!(IRLSEstimator::new().fit(&no_rows, &no_targets).is_err());
    }

    #[test]
    fn test_irls_dimension_mismatch_error() {
        let (x, _) = create_sample_data();
        let short_targets = vec![1.0, 2.0]; // 2 targets for 10 rows
        assert!(IRLSEstimator::new().fit(&x, &short_targets).is_err());
    }

    #[test]
    fn test_irls_predict_before_fit_error() {
        let unfitted = IRLSEstimator::new();
        assert!(unfitted.predict(&[vec![1.0]]).is_err());
    }

    #[test]
    fn test_irls_regularization() {
        let mut model = IRLSEstimator::new().with_alpha(0.1);
        let (x, y) = create_sample_data();

        assert!(model.fit(&x, &y).is_ok());

        // The ridge penalty should not push the slope above its unpenalized neighborhood.
        let slope = model.get_coefficients().unwrap();
        assert!(slope[0] < 2.1);
    }
}