1use ferrolearn_core::error::FerroError;
39use ferrolearn_core::introspection::HasCoefficients;
40use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
41use ferrolearn_core::traits::{Fit, Predict};
42use ndarray::{Array1, Array2, Axis, ScalarOperand};
43use num_traits::{Float, FromPrimitive};
44
/// Hyperparameters for Automatic Relevance Determination (ARD) regression.
///
/// Fitting alternates a regularised least-squares solve with re-estimation of
/// a scalar noise precision and one weight precision per feature. Features
/// whose estimated precision ends up above [`threshold_lambda`] have their
/// coefficient set to zero.
///
/// [`threshold_lambda`]: ARDRegression::threshold_lambda
#[derive(Debug, Clone)]
pub struct ARDRegression<F> {
    /// Maximum number of re-estimation iterations performed while fitting.
    pub max_iter: usize,
    /// Early-stopping threshold: iteration ends once the relative change of
    /// the noise precision and of every per-feature precision is below it.
    pub tol: F,
    /// Added (doubled) to the numerator of the noise-precision update.
    pub alpha_1: F,
    /// Added (doubled) to the denominator of the noise-precision update.
    pub alpha_2: F,
    /// Added (doubled) to the numerator of each weight-precision update.
    pub lambda_1: F,
    /// Added (doubled) to the denominator of each weight-precision update.
    pub lambda_2: F,
    /// Coefficients whose estimated precision exceeds this value are zeroed
    /// once the iterations have finished.
    pub threshold_lambda: F,
    /// When `true`, the data are centred before fitting and an intercept is
    /// recovered afterwards; when `false`, the intercept is zero.
    pub fit_intercept: bool,
}
72
73impl<F: Float + FromPrimitive> ARDRegression<F> {
74 #[must_use]
80 pub fn new() -> Self {
81 Self {
82 max_iter: 300,
83 tol: F::from(1e-3).unwrap(),
84 alpha_1: F::from(1e-6).unwrap(),
85 alpha_2: F::from(1e-6).unwrap(),
86 lambda_1: F::from(1e-6).unwrap(),
87 lambda_2: F::from(1e-6).unwrap(),
88 threshold_lambda: F::from(1e4).unwrap(),
89 fit_intercept: true,
90 }
91 }
92
93 #[must_use]
95 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
96 self.max_iter = max_iter;
97 self
98 }
99
100 #[must_use]
102 pub fn with_tol(mut self, tol: F) -> Self {
103 self.tol = tol;
104 self
105 }
106
107 #[must_use]
109 pub fn with_alpha_1(mut self, alpha_1: F) -> Self {
110 self.alpha_1 = alpha_1;
111 self
112 }
113
114 #[must_use]
116 pub fn with_alpha_2(mut self, alpha_2: F) -> Self {
117 self.alpha_2 = alpha_2;
118 self
119 }
120
121 #[must_use]
123 pub fn with_lambda_1(mut self, lambda_1: F) -> Self {
124 self.lambda_1 = lambda_1;
125 self
126 }
127
128 #[must_use]
130 pub fn with_lambda_2(mut self, lambda_2: F) -> Self {
131 self.lambda_2 = lambda_2;
132 self
133 }
134
135 #[must_use]
137 pub fn with_threshold_lambda(mut self, threshold_lambda: F) -> Self {
138 self.threshold_lambda = threshold_lambda;
139 self
140 }
141
142 #[must_use]
144 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
145 self.fit_intercept = fit_intercept;
146 self
147 }
148}
149
150impl<F: Float + FromPrimitive> Default for ARDRegression<F> {
151 fn default() -> Self {
152 Self::new()
153 }
154}
155
156#[derive(Debug, Clone)]
162pub struct FittedARDRegression<F> {
163 coefficients: Array1<F>,
165 intercept: F,
167 alpha: F,
169 lambda: Array1<F>,
171 sigma: Array1<F>,
173}
174
175impl<F: Float> FittedARDRegression<F> {
176 #[must_use]
178 pub fn alpha(&self) -> F {
179 self.alpha
180 }
181
182 #[must_use]
184 pub fn lambda(&self) -> &Array1<F> {
185 &self.lambda
186 }
187
188 #[must_use]
190 pub fn sigma(&self) -> &Array1<F> {
191 &self.sigma
192 }
193}
194
195fn ard_solve<F: Float + FromPrimitive + 'static>(
199 x: &Array2<F>,
200 y: &Array1<F>,
201 alpha: F,
202 lambda: &Array1<F>,
203) -> Result<(Array1<F>, Array1<F>), FerroError> {
204 let n_features = x.ncols();
205 let xt = x.t();
206 let mut xtx = xt.dot(x);
207
208 for i in 0..n_features {
210 for j in 0..n_features {
211 xtx[[i, j]] = xtx[[i, j]] * alpha;
212 }
213 xtx[[i, i]] = xtx[[i, i]] + lambda[i];
214 }
215
216 let xty = xt.dot(y);
217 let xty_scaled: Array1<F> = xty.mapv(|v| v * alpha);
218
219 let n = n_features;
221 let mut l = Array2::<F>::zeros((n, n));
222
223 for i in 0..n {
224 for j in 0..=i {
225 let mut s = xtx[[i, j]];
226 for k in 0..j {
227 s = s - l[[i, k]] * l[[j, k]];
228 }
229 if i == j {
230 if s <= F::zero() {
231 return Err(FerroError::NumericalInstability {
232 message: "ARD: matrix not positive definite".into(),
233 });
234 }
235 l[[i, j]] = s.sqrt();
236 } else {
237 l[[i, j]] = s / l[[j, j]];
238 }
239 }
240 }
241
242 let mut z = Array1::<F>::zeros(n);
244 for i in 0..n {
245 let mut s = xty_scaled[i];
246 for j in 0..i {
247 s = s - l[[i, j]] * z[j];
248 }
249 z[i] = s / l[[i, i]];
250 }
251
252 let mut w = Array1::<F>::zeros(n);
254 for i in (0..n).rev() {
255 let mut s = z[i];
256 for j in (i + 1)..n {
257 s = s - l[[j, i]] * w[j];
258 }
259 w[i] = s / l[[i, i]];
260 }
261
262 let mut sigma_diag = Array1::<F>::zeros(n);
264 for col in 0..n {
265 let mut z_inv = Array1::<F>::zeros(n);
266 z_inv[col] = F::one() / l[[col, col]];
267 for i in (col + 1)..n {
268 let mut s = F::zero();
269 for k in col..i {
270 s = s + l[[i, k]] * z_inv[k];
271 }
272 z_inv[i] = -s / l[[i, i]];
273 }
274 for i in 0..n {
275 sigma_diag[i] = sigma_diag[i] + z_inv[i] * z_inv[i];
276 }
277 }
278
279 Ok((w, sigma_diag))
280}
281
282impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
283 for ARDRegression<F>
284{
285 type Fitted = FittedARDRegression<F>;
286 type Error = FerroError;
287
288 fn fit(
296 &self,
297 x: &Array2<F>,
298 y: &Array1<F>,
299 ) -> Result<FittedARDRegression<F>, FerroError> {
300 let (n_samples, n_features) = x.dim();
301
302 if n_samples != y.len() {
303 return Err(FerroError::ShapeMismatch {
304 expected: vec![n_samples],
305 actual: vec![y.len()],
306 context: "y length must match number of samples in X".into(),
307 });
308 }
309
310 if n_samples < 2 {
311 return Err(FerroError::InsufficientSamples {
312 required: 2,
313 actual: n_samples,
314 context: "ARDRegression requires at least 2 samples".into(),
315 });
316 }
317
318 let n_f = F::from(n_samples).unwrap();
319
320 let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
322 let x_mean = x
323 .mean_axis(Axis(0))
324 .ok_or_else(|| FerroError::NumericalInstability {
325 message: "failed to compute column means".into(),
326 })?;
327 let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
328 message: "failed to compute target mean".into(),
329 })?;
330 let x_c = x - &x_mean;
331 let y_c = y - y_mean;
332 (x_c, y_c, Some(x_mean), Some(y_mean))
333 } else {
334 (x.clone(), y.clone(), None, None)
335 };
336
337 let mut alpha = F::one();
338 let mut lambda = Array1::<F>::from_elem(n_features, F::one());
339 let clamp_max = F::from(1e10).unwrap();
340 let clamp_min = F::from(1e-10).unwrap();
341
342 let mut w = Array1::<F>::zeros(n_features);
343 let mut sigma_diag = Array1::<F>::ones(n_features);
344
345 for _iter in 0..self.max_iter {
346 let alpha_old = alpha;
347 let lambda_old = lambda.clone();
348
349 let (w_new, sd_new) = ard_solve(&x_work, &y_work, alpha, &lambda)?;
351
352 let gamma: Array1<F> = Array1::from_shape_fn(n_features, |i| {
354 F::one() - lambda[i] * sd_new[i]
355 });
356
357 let gamma_sum: F = gamma.iter().fold(F::zero(), |a, &b| a + b);
358
359 let residual = &y_work - x_work.dot(&w_new);
361 let sse = residual.dot(&residual);
362 let two = F::from(2.0).unwrap();
363 let new_alpha = (n_f - gamma_sum + two * self.alpha_1)
364 / (sse + two * self.alpha_2).max(F::from(1e-300).unwrap());
365
366 let mut new_lambda = Array1::<F>::zeros(n_features);
368 for i in 0..n_features {
369 let wi_sq = w_new[i] * w_new[i];
370 new_lambda[i] = (gamma[i] + two * self.lambda_1)
371 / (wi_sq + two * self.lambda_2).max(F::from(1e-300).unwrap());
372 }
373
374 alpha = new_alpha.min(clamp_max).max(clamp_min);
376 for i in 0..n_features {
377 new_lambda[i] = new_lambda[i].min(clamp_max).max(clamp_min);
378 }
379 lambda = new_lambda;
380
381 w = w_new;
382 sigma_diag = sd_new;
383
384 let delta_alpha =
386 (alpha - alpha_old).abs() / (alpha_old.abs() + F::from(1e-10).unwrap());
387 let mut max_delta_lambda = F::zero();
388 for i in 0..n_features {
389 let delta = (lambda[i] - lambda_old[i]).abs()
390 / (lambda_old[i].abs() + F::from(1e-10).unwrap());
391 if delta > max_delta_lambda {
392 max_delta_lambda = delta;
393 }
394 }
395
396 if delta_alpha < self.tol && max_delta_lambda < self.tol {
397 break;
398 }
399 }
400
401 for i in 0..n_features {
403 if lambda[i] > self.threshold_lambda {
404 w[i] = F::zero();
405 }
406 }
407
408 let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
409 *ym - xm.dot(&w)
410 } else {
411 F::zero()
412 };
413
414 Ok(FittedARDRegression {
415 coefficients: w,
416 intercept,
417 alpha,
418 lambda,
419 sigma: sigma_diag,
420 })
421 }
422}
423
424impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
425 for FittedARDRegression<F>
426{
427 type Output = Array1<F>;
428 type Error = FerroError;
429
430 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
439 let n_features = x.ncols();
440 if n_features != self.coefficients.len() {
441 return Err(FerroError::ShapeMismatch {
442 expected: vec![self.coefficients.len()],
443 actual: vec![n_features],
444 context: "number of features must match fitted model".into(),
445 });
446 }
447
448 let preds = x.dot(&self.coefficients) + self.intercept;
449 Ok(preds)
450 }
451}
452
453impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
454 for FittedARDRegression<F>
455{
456 fn coefficients(&self) -> &Array1<F> {
457 &self.coefficients
458 }
459
460 fn intercept(&self) -> F {
461 self.intercept
462 }
463}
464
465impl<F> PipelineEstimator<F> for ARDRegression<F>
467where
468 F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
469{
470 fn fit_pipeline(
471 &self,
472 x: &Array2<F>,
473 y: &Array1<F>,
474 ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
475 let fitted = self.fit(x, y)?;
476 Ok(Box::new(fitted))
477 }
478}
479
480impl<F> FittedPipelineEstimator<F> for FittedARDRegression<F>
481where
482 F: Float + ScalarOperand + Send + Sync + 'static,
483{
484 fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
485 self.predict(x)
486 }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    /// Builds an `(n, 1)` design matrix from a single feature column.
    fn column(values: &[f64]) -> Array2<f64> {
        Array2::from_shape_vec((values.len(), 1), values.to_vec()).unwrap()
    }

    /// Fits a default model to `y = 2x` sampled at `x = 1..=5`.
    fn fit_doubling() -> FittedARDRegression<f64> {
        let x = column(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];
        ARDRegression::<f64>::new().fit(&x, &y).unwrap()
    }

    #[test]
    fn test_default_constructor() {
        let model = ARDRegression::<f64>::new();
        assert_eq!(model.max_iter, 300);
        assert!(model.fit_intercept);
        assert_relative_eq!(model.alpha_1, 1e-6);
    }

    #[test]
    fn test_builder_setters() {
        let model = ARDRegression::<f64>::new()
            .with_max_iter(50)
            .with_tol(1e-6)
            .with_alpha_1(1e-3)
            .with_alpha_2(1e-3)
            .with_lambda_1(1e-3)
            .with_lambda_2(1e-3)
            .with_threshold_lambda(1e5)
            .with_fit_intercept(false);
        assert_eq!(model.max_iter, 50);
        assert!(!model.fit_intercept);
        assert_relative_eq!(model.threshold_lambda, 1e5);
    }

    #[test]
    fn test_shape_mismatch() {
        // Three rows in X but only two targets.
        let x = column(&[1.0, 2.0, 3.0]);
        let y = array![1.0, 2.0];
        let outcome = ARDRegression::<f64>::new().fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_insufficient_samples() {
        // A single sample is below the two-sample minimum.
        let x = column(&[1.0]);
        let y = array![1.0];
        let outcome = ARDRegression::<f64>::new().fit(&x, &y);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_fits_linear_data() {
        // y = 2x + 1.
        let x = column(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = ARDRegression::<f64>::new().fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 0.5);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1.5);
    }

    #[test]
    fn test_alpha_positive() {
        let fitted = fit_doubling();
        assert!(fitted.alpha() > 0.0);
    }

    #[test]
    fn test_lambda_positive() {
        let fitted = fit_doubling();
        for &v in fitted.lambda().iter() {
            assert!(v > 0.0, "lambda must be positive, got {v}");
        }
    }

    #[test]
    fn test_sigma_positive() {
        let fitted = fit_doubling();
        for &v in fitted.sigma().iter() {
            assert!(v > 0.0, "sigma diagonal must be positive, got {v}");
        }
    }

    #[test]
    fn test_predict_length() {
        let x = column(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = ARDRegression::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 5);
    }

    #[test]
    fn test_predict_feature_mismatch() {
        // Fitted on two features, queried with one.
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = ARDRegression::<f64>::new().fit(&x, &y).unwrap();

        let x_bad = column(&[1.0, 2.0, 3.0]);
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_no_intercept() {
        let x = column(&[1.0, 2.0, 3.0, 4.0]);
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = ARDRegression::<f64>::new()
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sparsity_on_irrelevant_features() {
        // Second column is exactly 100 times the first.
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![
                1.0, 100.0, 2.0, 200.0, 3.0, 300.0, 4.0, 400.0, 5.0, 500.0, 6.0, 600.0,
            ],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0];

        let fitted = ARDRegression::<f64>::new()
            .with_max_iter(1000)
            .fit(&x, &y)
            .unwrap();

        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    #[test]
    fn test_has_coefficients_length() {
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 0.5, 2.0, 1.0, 1.0, 3.0, 0.0, 1.5, 4.0, 1.0, 2.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = ARDRegression::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 3);
    }

    #[test]
    fn test_pipeline_integration() {
        let x = column(&[1.0, 2.0, 3.0, 4.0]);
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = ARDRegression::<f64>::new();
        let fitted_pipe = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted_pipe.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}