// ferrolearn_linear/ard.rs
1//! Automatic Relevance Determination (ARD) Regression.
2//!
3//! This module provides [`ARDRegression`], a Bayesian linear regression model
4//! with per-feature weight precision priors. Features whose precision
5//! (`lambda_i`) exceeds a threshold are pruned — their weights are driven to
6//! zero, achieving automatic feature selection.
7//!
8//! # Algorithm
9//!
10//! Starting from initial alpha (noise precision) and per-feature lambda_i
11//! (weight precision) values, the model iterates:
12//!
13//! 1. Solve the regularised posterior: `w = (alpha * X^T X + diag(lambda))^{-1} alpha X^T y`.
14//! 2. Update gamma_i (effective degrees of freedom): `gamma_i = 1 - lambda_i * Sigma_{ii}`.
15//! 3. Update alpha: `alpha = (n - sum(gamma)) / ||y - Xw||^2`.
16//! 4. Update lambda_i: `lambda_i = gamma_i / w_i^2`.
17//!
18//! Features where `lambda_i > threshold_lambda` are pruned.
19//!
20//! # Examples
21//!
22//! ```
23//! use ferrolearn_linear::ard::ARDRegression;
24//! use ferrolearn_core::{Fit, Predict};
25//! use ndarray::{array, Array1, Array2};
26//!
27//! let x = Array2::from_shape_vec((5, 2), vec![
28//!     1.0, 0.0, 2.0, 0.0, 3.0, 0.0, 4.0, 0.0, 5.0, 0.0,
29//! ]).unwrap();
30//! let y = array![2.0, 4.0, 6.0, 8.0, 10.0];
31//!
32//! let model = ARDRegression::<f64>::new();
33//! let fitted = model.fit(&x, &y).unwrap();
34//! let preds = fitted.predict(&x).unwrap();
35//! assert_eq!(preds.len(), 5);
36//! ```
37
38use ferrolearn_core::error::FerroError;
39use ferrolearn_core::introspection::HasCoefficients;
40use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
41use ferrolearn_core::traits::{Fit, Predict};
42use ndarray::{Array1, Array2, Axis, ScalarOperand};
43use num_traits::{Float, FromPrimitive};
44
/// Automatic Relevance Determination Regression.
///
/// Bayesian linear regression with per-feature precision priors. Features
/// with high precision (small variance) are pruned, achieving sparsity.
///
/// Construct with [`ARDRegression::new`] (or via [`Default`]) and customise
/// using the `with_*` builder methods; fit with the [`Fit`] trait.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct ARDRegression<F> {
    /// Maximum number of EM iterations. Default: `300`.
    pub max_iter: usize,
    /// Convergence tolerance on the relative change in alpha/lambda. Default: `1e-3`.
    pub tol: F,
    /// Shape hyperparameter for the alpha (noise precision) Gamma prior. Default: `1e-6`.
    pub alpha_1: F,
    /// Rate hyperparameter for the alpha (noise precision) Gamma prior. Default: `1e-6`.
    pub alpha_2: F,
    /// Shape hyperparameter for the lambda (weight precision) Gamma prior. Default: `1e-6`.
    pub lambda_1: F,
    /// Rate hyperparameter for the lambda (weight precision) Gamma prior. Default: `1e-6`.
    pub lambda_2: F,
    /// Features with `lambda_i > threshold_lambda` are pruned. Default: `1e4`.
    pub threshold_lambda: F,
    /// Whether to fit an intercept (bias) term. Default: `true`.
    pub fit_intercept: bool,
}
72
73impl<F: Float + FromPrimitive> ARDRegression<F> {
74    /// Create a new `ARDRegression` with default settings.
75    ///
76    /// Defaults: `max_iter = 300`, `tol = 1e-3`, `alpha_1 = alpha_2 = 1e-6`,
77    /// `lambda_1 = lambda_2 = 1e-6`, `threshold_lambda = 1e4`,
78    /// `fit_intercept = true`.
79    #[must_use]
80    pub fn new() -> Self {
81        Self {
82            max_iter: 300,
83            tol: F::from(1e-3).unwrap(),
84            alpha_1: F::from(1e-6).unwrap(),
85            alpha_2: F::from(1e-6).unwrap(),
86            lambda_1: F::from(1e-6).unwrap(),
87            lambda_2: F::from(1e-6).unwrap(),
88            threshold_lambda: F::from(1e4).unwrap(),
89            fit_intercept: true,
90        }
91    }
92
93    /// Set the maximum number of iterations.
94    #[must_use]
95    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
96        self.max_iter = max_iter;
97        self
98    }
99
100    /// Set the convergence tolerance.
101    #[must_use]
102    pub fn with_tol(mut self, tol: F) -> Self {
103        self.tol = tol;
104        self
105    }
106
107    /// Set the alpha shape hyperparameter.
108    #[must_use]
109    pub fn with_alpha_1(mut self, alpha_1: F) -> Self {
110        self.alpha_1 = alpha_1;
111        self
112    }
113
114    /// Set the alpha rate hyperparameter.
115    #[must_use]
116    pub fn with_alpha_2(mut self, alpha_2: F) -> Self {
117        self.alpha_2 = alpha_2;
118        self
119    }
120
121    /// Set the lambda shape hyperparameter.
122    #[must_use]
123    pub fn with_lambda_1(mut self, lambda_1: F) -> Self {
124        self.lambda_1 = lambda_1;
125        self
126    }
127
128    /// Set the lambda rate hyperparameter.
129    #[must_use]
130    pub fn with_lambda_2(mut self, lambda_2: F) -> Self {
131        self.lambda_2 = lambda_2;
132        self
133    }
134
135    /// Set the pruning threshold for feature lambda values.
136    #[must_use]
137    pub fn with_threshold_lambda(mut self, threshold_lambda: F) -> Self {
138        self.threshold_lambda = threshold_lambda;
139        self
140    }
141
142    /// Set whether to fit an intercept term.
143    #[must_use]
144    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
145        self.fit_intercept = fit_intercept;
146        self
147    }
148}
149
impl<F: Float + FromPrimitive> Default for ARDRegression<F> {
    /// Equivalent to [`ARDRegression::new`].
    fn default() -> Self {
        Self::new()
    }
}
155
/// Fitted ARD Regression model.
///
/// Stores the posterior mean coefficients, intercept, estimated noise
/// precision (`alpha`), per-feature weight precisions (`lambda`), and
/// the diagonal of the posterior covariance.
///
/// Produced by [`ARDRegression`]'s `Fit` implementation; coefficients of
/// pruned features (those whose `lambda_i` exceeded the threshold) are
/// exactly zero.
#[derive(Debug, Clone)]
pub struct FittedARDRegression<F> {
    /// Posterior mean coefficient vector.
    coefficients: Array1<F>,
    /// Intercept (bias) term; zero when fitted without an intercept.
    intercept: F,
    /// Estimated noise precision (1 / noise_variance).
    alpha: F,
    /// Per-feature weight precisions.
    lambda: Array1<F>,
    /// Diagonal of the posterior covariance matrix.
    sigma: Array1<F>,
}
174
175impl<F: Float> FittedARDRegression<F> {
176    /// Returns the estimated noise precision (alpha = 1/sigma^2_noise).
177    #[must_use]
178    pub fn alpha(&self) -> F {
179        self.alpha
180    }
181
182    /// Returns the per-feature weight precisions.
183    #[must_use]
184    pub fn lambda(&self) -> &Array1<F> {
185        &self.lambda
186    }
187
188    /// Returns the diagonal of the posterior covariance matrix.
189    #[must_use]
190    pub fn sigma(&self) -> &Array1<F> {
191        &self.sigma
192    }
193}
194
195/// Solve the ARD system: `(alpha * X^T X + diag(lambda)) w = alpha * X^T y`.
196///
197/// Returns `(w, diag(Sigma))`.
198fn ard_solve<F: Float + FromPrimitive + 'static>(
199    x: &Array2<F>,
200    y: &Array1<F>,
201    alpha: F,
202    lambda: &Array1<F>,
203) -> Result<(Array1<F>, Array1<F>), FerroError> {
204    let n_features = x.ncols();
205    let xt = x.t();
206    let mut xtx = xt.dot(x);
207
208    // Scale by alpha, then add diag(lambda).
209    for i in 0..n_features {
210        for j in 0..n_features {
211            xtx[[i, j]] = xtx[[i, j]] * alpha;
212        }
213        xtx[[i, i]] = xtx[[i, i]] + lambda[i];
214    }
215
216    let xty = xt.dot(y);
217    let xty_scaled: Array1<F> = xty.mapv(|v| v * alpha);
218
219    // Cholesky solve.
220    let n = n_features;
221    let mut l = Array2::<F>::zeros((n, n));
222
223    for i in 0..n {
224        for j in 0..=i {
225            let mut s = xtx[[i, j]];
226            for k in 0..j {
227                s = s - l[[i, k]] * l[[j, k]];
228            }
229            if i == j {
230                if s <= F::zero() {
231                    return Err(FerroError::NumericalInstability {
232                        message: "ARD: matrix not positive definite".into(),
233                    });
234                }
235                l[[i, j]] = s.sqrt();
236            } else {
237                l[[i, j]] = s / l[[j, j]];
238            }
239        }
240    }
241
242    // Forward substitution.
243    let mut z = Array1::<F>::zeros(n);
244    for i in 0..n {
245        let mut s = xty_scaled[i];
246        for j in 0..i {
247            s = s - l[[i, j]] * z[j];
248        }
249        z[i] = s / l[[i, i]];
250    }
251
252    // Back substitution.
253    let mut w = Array1::<F>::zeros(n);
254    for i in (0..n).rev() {
255        let mut s = z[i];
256        for j in (i + 1)..n {
257            s = s - l[[j, i]] * w[j];
258        }
259        w[i] = s / l[[i, i]];
260    }
261
262    // Compute diagonal of posterior covariance: diag((alpha * X^T X + diag(lambda))^{-1}).
263    let mut sigma_diag = Array1::<F>::zeros(n);
264    for col in 0..n {
265        let mut z_inv = Array1::<F>::zeros(n);
266        z_inv[col] = F::one() / l[[col, col]];
267        for i in (col + 1)..n {
268            let mut s = F::zero();
269            for k in col..i {
270                s = s + l[[i, k]] * z_inv[k];
271            }
272            z_inv[i] = -s / l[[i, i]];
273        }
274        for i in 0..n {
275            sigma_diag[i] = sigma_diag[i] + z_inv[i] * z_inv[i];
276        }
277    }
278
279    Ok((w, sigma_diag))
280}
281
impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
    for ARDRegression<F>
{
    type Fitted = FittedARDRegression<F>;
    type Error = FerroError;

    /// Fit the ARD model via iterative evidence maximization.
    ///
    /// The loop alternates between solving the regularised posterior for the
    /// current `(alpha, lambda)` and re-estimating `(alpha, lambda)` from
    /// that posterior, stopping when both change by less than `tol`
    /// (relative) or after `max_iter` iterations. Weights of features whose
    /// `lambda_i` exceeds `threshold_lambda` are zeroed afterwards.
    ///
    /// # Errors
    ///
    /// - [`FerroError::ShapeMismatch`] — sample count mismatch.
    /// - [`FerroError::InsufficientSamples`] — fewer than 2 samples.
    /// - [`FerroError::NumericalInstability`] — numerical failure in solver.
    fn fit(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<FittedARDRegression<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        if n_samples != y.len() {
            return Err(FerroError::ShapeMismatch {
                expected: vec![n_samples],
                actual: vec![y.len()],
                context: "y length must match number of samples in X".into(),
            });
        }

        if n_samples < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples,
                context: "ARDRegression requires at least 2 samples".into(),
            });
        }

        let n_f = F::from(n_samples).unwrap();

        // Center data when fitting an intercept, so the solver never sees a
        // bias column; the intercept is recovered from the means at the end.
        let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
            let x_mean = x
                .mean_axis(Axis(0))
                .ok_or_else(|| FerroError::NumericalInstability {
                    message: "failed to compute column means".into(),
                })?;
            let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
                message: "failed to compute target mean".into(),
            })?;
            let x_c = x - &x_mean;
            let y_c = y - y_mean;
            (x_c, y_c, Some(x_mean), Some(y_mean))
        } else {
            (x.clone(), y.clone(), None, None)
        };

        // Uniform initial precisions; clamps keep the updates from diverging
        // to infinity (pruned features) or collapsing to zero.
        let mut alpha = F::one();
        let mut lambda = Array1::<F>::from_elem(n_features, F::one());
        let clamp_max = F::from(1e10).unwrap();
        let clamp_min = F::from(1e-10).unwrap();

        let mut w = Array1::<F>::zeros(n_features);
        let mut sigma_diag = Array1::<F>::ones(n_features);

        for _iter in 0..self.max_iter {
            // Snapshot the previous hyperparameters for the convergence test.
            let alpha_old = alpha;
            let lambda_old = lambda.clone();

            // E-step: compute posterior mean and covariance diagonal.
            let (w_new, sd_new) = ard_solve(&x_work, &y_work, alpha, &lambda)?;

            // Compute gamma_i = 1 - lambda_i * Sigma_ii (effective degrees of
            // freedom per feature, using the *old* lambda with the new Sigma).
            let gamma: Array1<F> = Array1::from_shape_fn(n_features, |i| {
                F::one() - lambda[i] * sd_new[i]
            });

            let gamma_sum: F = gamma.iter().fold(F::zero(), |a, &b| a + b);

            // Update alpha: (n - sum(gamma) + 2*alpha_1) / (||y - Xw||^2 + 2*alpha_2).
            // The tiny floor on the denominator guards against division by zero
            // on perfectly-fit data.
            let residual = &y_work - x_work.dot(&w_new);
            let sse = residual.dot(&residual);
            let two = F::from(2.0).unwrap();
            let new_alpha = (n_f - gamma_sum + two * self.alpha_1)
                / (sse + two * self.alpha_2).max(F::from(1e-300).unwrap());

            // Update lambda_i: (gamma_i + 2*lambda_1) / (w_i^2 + 2*lambda_2).
            let mut new_lambda = Array1::<F>::zeros(n_features);
            for i in 0..n_features {
                let wi_sq = w_new[i] * w_new[i];
                new_lambda[i] = (gamma[i] + two * self.lambda_1)
                    / (wi_sq + two * self.lambda_2).max(F::from(1e-300).unwrap());
            }

            // Clamp both precisions into [1e-10, 1e10] for numerical stability.
            alpha = new_alpha.min(clamp_max).max(clamp_min);
            for i in 0..n_features {
                new_lambda[i] = new_lambda[i].min(clamp_max).max(clamp_min);
            }
            lambda = new_lambda;

            w = w_new;
            sigma_diag = sd_new;

            // Check convergence on the relative change of alpha and of the
            // worst-case lambda component (1e-10 avoids division by zero).
            let delta_alpha =
                (alpha - alpha_old).abs() / (alpha_old.abs() + F::from(1e-10).unwrap());
            let mut max_delta_lambda = F::zero();
            for i in 0..n_features {
                let delta = (lambda[i] - lambda_old[i]).abs()
                    / (lambda_old[i].abs() + F::from(1e-10).unwrap());
                if delta > max_delta_lambda {
                    max_delta_lambda = delta;
                }
            }

            if delta_alpha < self.tol && max_delta_lambda < self.tol {
                break;
            }
        }

        // Prune features with lambda > threshold. This happens *before* the
        // intercept is computed, so pruned features do not contribute to it.
        for i in 0..n_features {
            if lambda[i] > self.threshold_lambda {
                w[i] = F::zero();
            }
        }

        // Recover the intercept from the means of the uncentered data.
        let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
            *ym - xm.dot(&w)
        } else {
            F::zero()
        };

        Ok(FittedARDRegression {
            coefficients: w,
            intercept,
            alpha,
            lambda,
            sigma: sigma_diag,
        })
    }
}
423
424impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
425    for FittedARDRegression<F>
426{
427    type Output = Array1<F>;
428    type Error = FerroError;
429
430    /// Predict target values using the posterior mean coefficients.
431    ///
432    /// Computes `X @ coefficients + intercept`.
433    ///
434    /// # Errors
435    ///
436    /// Returns [`FerroError::ShapeMismatch`] if the number of features
437    /// does not match the fitted model.
438    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
439        let n_features = x.ncols();
440        if n_features != self.coefficients.len() {
441            return Err(FerroError::ShapeMismatch {
442                expected: vec![self.coefficients.len()],
443                actual: vec![n_features],
444                context: "number of features must match fitted model".into(),
445            });
446        }
447
448        let preds = x.dot(&self.coefficients) + self.intercept;
449        Ok(preds)
450    }
451}
452
453impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
454    for FittedARDRegression<F>
455{
456    fn coefficients(&self) -> &Array1<F> {
457        &self.coefficients
458    }
459
460    fn intercept(&self) -> F {
461        self.intercept
462    }
463}
464
465// Pipeline integration.
466impl<F> PipelineEstimator<F> for ARDRegression<F>
467where
468    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
469{
470    fn fit_pipeline(
471        &self,
472        x: &Array2<F>,
473        y: &Array1<F>,
474    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
475        let fitted = self.fit(x, y)?;
476        Ok(Box::new(fitted))
477    }
478}
479
480impl<F> FittedPipelineEstimator<F> for FittedARDRegression<F>
481where
482    F: Float + ScalarOperand + Send + Sync + 'static,
483{
484    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
485        self.predict(x)
486    }
487}
488
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // Defaults advertised by `new()` are actually set.
    #[test]
    fn test_default_constructor() {
        let m = ARDRegression::<f64>::new();
        assert_eq!(m.max_iter, 300);
        assert!(m.fit_intercept);
        assert_relative_eq!(m.alpha_1, 1e-6);
    }

    // Every builder method stores its value and chains.
    #[test]
    fn test_builder_setters() {
        let m = ARDRegression::<f64>::new()
            .with_max_iter(50)
            .with_tol(1e-6)
            .with_alpha_1(1e-3)
            .with_alpha_2(1e-3)
            .with_lambda_1(1e-3)
            .with_lambda_2(1e-3)
            .with_threshold_lambda(1e5)
            .with_fit_intercept(false);
        assert_eq!(m.max_iter, 50);
        assert!(!m.fit_intercept);
        assert_relative_eq!(m.threshold_lambda, 1e5);
    }

    // Mismatched sample counts must be rejected before fitting.
    #[test]
    fn test_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        let result = ARDRegression::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    // A single sample is below the documented minimum of 2.
    #[test]
    fn test_insufficient_samples() {
        let x = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let y = array![1.0];
        let result = ARDRegression::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    // Noise-free linear data should be recovered approximately.
    #[test]
    fn test_fits_linear_data() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = ARDRegression::<f64>::new().fit(&x, &y).unwrap();

        // Should recover roughly y = 2x + 1.
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 0.5);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 1.5);
    }

    // Noise precision is a precision — it can never be negative or zero.
    #[test]
    fn test_alpha_positive() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = ARDRegression::<f64>::new().fit(&x, &y).unwrap();
        assert!(fitted.alpha() > 0.0);
    }

    // Weight precisions stay positive thanks to the clamping in fit().
    #[test]
    fn test_lambda_positive() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = ARDRegression::<f64>::new().fit(&x, &y).unwrap();
        for &v in fitted.lambda().iter() {
            assert!(v > 0.0, "lambda must be positive, got {v}");
        }
    }

    // Posterior variances (covariance diagonal) must be strictly positive.
    #[test]
    fn test_sigma_positive() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = ARDRegression::<f64>::new().fit(&x, &y).unwrap();
        for &v in fitted.sigma().iter() {
            assert!(v > 0.0, "sigma diagonal must be positive, got {v}");
        }
    }

    // predict() returns one value per input row.
    #[test]
    fn test_predict_length() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = ARDRegression::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 5);
    }

    // Predicting with the wrong feature count must error, not panic.
    #[test]
    fn test_predict_feature_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = ARDRegression::<f64>::new().fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    // With fit_intercept = false the intercept must be exactly zero.
    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = ARDRegression::<f64>::new()
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    // Smoke test: fitting with a redundant (collinear) feature still works.
    #[test]
    fn test_sparsity_on_irrelevant_features() {
        // y depends only on x1, x2 is noise-free irrelevant.
        let x = Array2::from_shape_vec(
            (6, 2),
            vec![1.0, 100.0, 2.0, 200.0, 3.0, 300.0, 4.0, 400.0, 5.0, 500.0, 6.0, 600.0],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0, 12.0]; // y = 2 * x1

        let fitted = ARDRegression::<f64>::new()
            .with_max_iter(1000)
            .fit(&x, &y)
            .unwrap();

        // The model should learn that x1 is relevant.
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 6);
    }

    // One coefficient per input column, exposed through HasCoefficients.
    #[test]
    fn test_has_coefficients_length() {
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 0.5, 2.0, 1.0, 1.0, 3.0, 0.0, 1.5, 4.0, 1.0, 2.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = ARDRegression::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 3);
    }

    // Round-trip through the type-erased pipeline traits.
    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = ARDRegression::<f64>::new();
        let fitted_pipe = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted_pipe.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }
}