Skip to main content

ferrolearn_linear/
bayesian_ridge.rs

1//! Bayesian Ridge Regression.
2//!
3//! This module provides [`BayesianRidge`], which fits a Bayesian formulation of
4//! Ridge regression. Rather than using a fixed regularization strength, the
5//! model iteratively estimates two precision hyperparameters:
6//!
7//! - **`lambda`** — precision (inverse variance) of the weight prior.
8//! - **`alpha`** — noise precision (inverse of noise variance).
9//!
10//! Both are inferred from the data via evidence maximization (Type-II maximum
11//! likelihood / Empirical Bayes). This automatic relevance determination means
12//! the user does not need to tune the regularization parameter by hand.
13//!
14//! The objective is the Bayesian evidence (marginal likelihood) of the model:
15//!
16//! ```text
17//! p(y | X, alpha, lambda) ∝ N(y; 0, (1/alpha)*I + (1/lambda)*X X^T)
18//! ```
19//!
20//! After fitting, the model exposes the posterior mean (`coefficients`),
21//! the posterior covariance diagonal (`sigma`), the noise precision (`alpha`),
22//! and the weight precision (`lambda`).
23//!
24//! # Examples
25//!
26//! ```
27//! use ferrolearn_linear::BayesianRidge;
28//! use ferrolearn_core::{Fit, Predict};
29//! use ndarray::{array, Array1, Array2};
30//!
31//! let model = BayesianRidge::<f64>::new();
32//! let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
33//! let y = array![3.0, 5.0, 7.0, 9.0, 11.0];
34//!
35//! let fitted = model.fit(&x, &y).unwrap();
36//! let preds = fitted.predict(&x).unwrap();
37//! ```
38
39use ferrolearn_core::error::FerroError;
40use ferrolearn_core::introspection::HasCoefficients;
41use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
42use ferrolearn_core::traits::{Fit, Predict};
43use ndarray::{Array1, Array2, Axis, ScalarOperand};
44use num_traits::{Float, FromPrimitive};
45
/// Bayesian Ridge Regression with automatic regularization tuning.
///
/// Estimates weight precision (`lambda`) and noise precision (`alpha`)
/// iteratively using evidence maximization. The intercept, if requested,
/// is fit by centering.
///
/// Both `alpha_init` and `lambda_init` must be positive; this is validated
/// in `fit`, not at construction time.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct BayesianRidge<F> {
    /// Maximum number of EM (evidence-maximization) iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the relative change in the estimated
    /// precisions between iterations.
    pub tol: F,
    /// Initial noise precision (alpha). Must be positive.
    pub alpha_init: F,
    /// Initial weight precision (lambda). Must be positive.
    pub lambda_init: F,
    /// Whether to fit an intercept (bias) term.
    pub fit_intercept: bool,
}
68
69impl<F: Float + FromPrimitive> BayesianRidge<F> {
70    /// Create a new `BayesianRidge` with default settings.
71    ///
72    /// Defaults: `max_iter = 300`, `tol = 1e-3`, `alpha_init = 1.0`,
73    /// `lambda_init = 1.0`, `fit_intercept = true`.
74    #[must_use]
75    pub fn new() -> Self {
76        Self {
77            max_iter: 300,
78            tol: F::from(1e-3).unwrap(),
79            alpha_init: F::one(),
80            lambda_init: F::one(),
81            fit_intercept: true,
82        }
83    }
84
85    /// Set the maximum number of iterations.
86    #[must_use]
87    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
88        self.max_iter = max_iter;
89        self
90    }
91
92    /// Set the convergence tolerance.
93    #[must_use]
94    pub fn with_tol(mut self, tol: F) -> Self {
95        self.tol = tol;
96        self
97    }
98
99    /// Set the initial noise precision.
100    #[must_use]
101    pub fn with_alpha_init(mut self, alpha_init: F) -> Self {
102        self.alpha_init = alpha_init;
103        self
104    }
105
106    /// Set the initial weight precision.
107    #[must_use]
108    pub fn with_lambda_init(mut self, lambda_init: F) -> Self {
109        self.lambda_init = lambda_init;
110        self
111    }
112
113    /// Set whether to fit an intercept term.
114    #[must_use]
115    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
116        self.fit_intercept = fit_intercept;
117        self
118    }
119}
120
121impl<F: Float + FromPrimitive> Default for BayesianRidge<F> {
122    fn default() -> Self {
123        Self::new()
124    }
125}
126
/// Fitted Bayesian Ridge Regression model.
///
/// Stores the posterior mean coefficients, intercept, estimated noise
/// precision (`alpha`), weight precision (`lambda`), and the diagonal
/// of the posterior covariance matrix (`sigma`). The `sigma` vector has
/// one entry per feature.
#[derive(Debug, Clone)]
pub struct FittedBayesianRidge<F> {
    /// Posterior mean coefficient vector (one entry per feature).
    coefficients: Array1<F>,
    /// Intercept (bias) term; zero when fitted without intercept.
    intercept: F,
    /// Estimated noise precision (1 / noise_variance).
    alpha: F,
    /// Estimated weight precision (1 / weight_variance).
    lambda: F,
    /// Diagonal of the posterior covariance matrix `Sigma`.
    sigma: Array1<F>,
}
145
146impl<F: Float> FittedBayesianRidge<F> {
147    /// Returns the estimated noise precision (alpha = 1/sigma²_noise).
148    pub fn alpha(&self) -> F {
149        self.alpha
150    }
151
152    /// Returns the estimated weight precision (lambda = 1/sigma²_weights).
153    pub fn lambda(&self) -> F {
154        self.lambda
155    }
156
157    /// Returns the diagonal of the posterior covariance matrix.
158    pub fn sigma(&self) -> &Array1<F> {
159        &self.sigma
160    }
161}
162
163/// Solve `(lambda/alpha * I + X^T X) w = X^T y` via Cholesky or fallback.
164///
165/// Returns `(w, diag(Sigma))` where `Sigma = alpha^{-1} * (lambda * I + alpha * X^T X)^{-1}`.
166fn bayesian_ridge_solve<F: Float + FromPrimitive + 'static>(
167    x: &Array2<F>,
168    y: &Array1<F>,
169    alpha: F,
170    lambda: F,
171) -> Result<(Array1<F>, Array1<F>), FerroError> {
172    let (_n_samples, n_features) = x.dim();
173
174    // Compute X^T X.
175    let xt = x.t();
176    let mut xtx = xt.dot(x);
177
178    // Scale by alpha, then add lambda * I.
179    // The system we solve is: (alpha * X^T X + lambda * I) w = alpha * X^T y
180    for i in 0..n_features {
181        for j in 0..n_features {
182            xtx[[i, j]] = xtx[[i, j]] * alpha;
183        }
184        xtx[[i, i]] = xtx[[i, i]] + lambda;
185    }
186
187    let xty = xt.dot(y);
188    let xty_scaled: Array1<F> = xty.mapv(|v| v * alpha);
189
190    // Solve via Cholesky.
191    let w = cholesky_solve(&xtx, &xty_scaled)?;
192
193    // Compute diagonal of posterior covariance: diag((alpha * X^T X + lambda * I)^{-1}).
194    let sigma_diag = cholesky_diag_inv(&xtx)?;
195
196    Ok((w, sigma_diag))
197}
198
/// Cholesky decomposition and solve `A x = b`.
///
/// Factors `A = L L^T` (lower-triangular `L`), then solves `L z = b` by
/// forward substitution and `L^T x = z` by backward substitution.
/// `A` is assumed symmetric positive definite.
///
/// # Errors
///
/// Returns [`FerroError::NumericalInstability`] if a non-positive pivot is
/// encountered (i.e. `A` is not positive definite).
fn cholesky_solve<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>, FerroError> {
    let n = a.nrows();
    let mut l = Array2::<F>::zeros((n, n));

    // Standard Cholesky–Banachiewicz factorization; only the lower triangle
    // of `a` is read.
    for i in 0..n {
        for j in 0..=i {
            let mut s = a[[i, j]];
            for k in 0..j {
                s = s - l[[i, k]] * l[[j, k]];
            }
            if i == j {
                // Diagonal pivot must be strictly positive for a real sqrt.
                if s <= F::zero() {
                    return Err(FerroError::NumericalInstability {
                        message: "Cholesky: matrix not positive definite".into(),
                    });
                }
                l[[i, j]] = s.sqrt();
            } else {
                l[[i, j]] = s / l[[j, j]];
            }
        }
    }

    // Forward substitution: solve L z = b.
    let mut z = Array1::<F>::zeros(n);
    for i in 0..n {
        let mut s = b[i];
        for j in 0..i {
            s = s - l[[i, j]] * z[j];
        }
        z[i] = s / l[[i, i]];
    }

    // Backward substitution: solve L^T x = z (note l[[j, i]] accesses L^T).
    let mut x = Array1::<F>::zeros(n);
    for i in (0..n).rev() {
        let mut s = z[i];
        for j in (i + 1)..n {
            s = s - l[[j, i]] * x[j];
        }
        x[i] = s / l[[i, i]];
    }

    Ok(x)
}
245
246/// Compute the diagonal of `A^{-1}` given Cholesky `L` of `A = L L^T`.
247///
248/// Uses the identity: `diag(A^{-1}) = diag(L^{-T} L^{-1})`.
249fn cholesky_diag_inv<F: Float>(a: &Array2<F>) -> Result<Array1<F>, FerroError> {
250    let n = a.nrows();
251    let mut l = Array2::<F>::zeros((n, n));
252
253    for i in 0..n {
254        for j in 0..=i {
255            let mut s = a[[i, j]];
256            for k in 0..j {
257                s = s - l[[i, k]] * l[[j, k]];
258            }
259            if i == j {
260                if s <= F::zero() {
261                    return Err(FerroError::NumericalInstability {
262                        message: "Cholesky diag_inv: matrix not positive definite".into(),
263                    });
264                }
265                l[[i, j]] = s.sqrt();
266            } else {
267                l[[i, j]] = s / l[[j, j]];
268            }
269        }
270    }
271
272    // Compute L^{-1} column by column and accumulate diagonal of L^{-T} L^{-1}.
273    let mut diag = Array1::<F>::zeros(n);
274    for col in 0..n {
275        // Solve L z = e_col.
276        let mut z = Array1::<F>::zeros(n);
277        z[col] = F::one() / l[[col, col]];
278        for i in (col + 1)..n {
279            let mut s = F::zero();
280            for k in col..i {
281                s = s + l[[i, k]] * z[k];
282            }
283            z[i] = -s / l[[i, i]];
284        }
285        // Accumulate z^T z into the diagonal positions it touches.
286        for i in 0..n {
287            diag[i] = diag[i] + z[i] * z[i];
288        }
289    }
290
291    Ok(diag)
292}
293
294impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
295    for BayesianRidge<F>
296{
297    type Fitted = FittedBayesianRidge<F>;
298    type Error = FerroError;
299
300    /// Fit the Bayesian Ridge model via evidence maximization (EM).
301    ///
302    /// Iterates over:
303    /// 1. Solve posterior for `w` given current `alpha` and `lambda`.
304    /// 2. Update `alpha` and `lambda` using the posterior statistics.
305    ///
306    /// # Errors
307    ///
308    /// - [`FerroError::ShapeMismatch`] — sample count mismatch.
309    /// - [`FerroError::InvalidParameter`] — non-positive initial precisions.
310    /// - [`FerroError::InsufficientSamples`] — fewer than 2 samples.
311    /// - [`FerroError::NumericalInstability`] — numerical failure in solver.
312    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedBayesianRidge<F>, FerroError> {
313        let (n_samples, n_features) = x.dim();
314
315        if n_samples != y.len() {
316            return Err(FerroError::ShapeMismatch {
317                expected: vec![n_samples],
318                actual: vec![y.len()],
319                context: "y length must match number of samples in X".into(),
320            });
321        }
322
323        if n_samples < 2 {
324            return Err(FerroError::InsufficientSamples {
325                required: 2,
326                actual: n_samples,
327                context: "BayesianRidge requires at least 2 samples".into(),
328            });
329        }
330
331        if self.alpha_init <= F::zero() {
332            return Err(FerroError::InvalidParameter {
333                name: "alpha_init".into(),
334                reason: "must be positive".into(),
335            });
336        }
337
338        if self.lambda_init <= F::zero() {
339            return Err(FerroError::InvalidParameter {
340                name: "lambda_init".into(),
341                reason: "must be positive".into(),
342            });
343        }
344
345        let n_f = F::from(n_samples).unwrap();
346        let n_feat_f = F::from(n_features).unwrap();
347
348        // Center data for intercept.
349        let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
350            let x_mean = x
351                .mean_axis(Axis(0))
352                .ok_or_else(|| FerroError::NumericalInstability {
353                    message: "failed to compute column means".into(),
354                })?;
355            let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
356                message: "failed to compute target mean".into(),
357            })?;
358
359            let x_c = x - &x_mean;
360            let y_c = y - y_mean;
361            (x_c, y_c, Some(x_mean), Some(y_mean))
362        } else {
363            (x.clone(), y.clone(), None, None)
364        };
365
366        // Precompute eigenvalues of X^T X for the effective degrees of freedom.
367        // We use a simpler trace-based approximation here: gamma ≈ n_features.
368        let xt = x_work.t();
369        let xtx = xt.dot(&x_work);
370
371        // Eigenvalues of X^T X via power iteration or trace approximation.
372        // We compute the trace to get sum of eigenvalues.
373        let trace_xtx: F = (0..n_features)
374            .map(|i| xtx[[i, i]])
375            .fold(F::zero(), |a, b| a + b);
376
377        let mut alpha = self.alpha_init;
378        let mut lambda = self.lambda_init;
379
380        let mut w = Array1::<F>::zeros(n_features);
381        let mut sigma_diag = Array1::<F>::ones(n_features);
382
383        for _iter in 0..self.max_iter {
384            let alpha_old = alpha;
385            let lambda_old = lambda;
386
387            // E-step: compute posterior mean w and diag(Sigma).
388            let (w_new, sd_new) = bayesian_ridge_solve(&x_work, &y_work, alpha, lambda)?;
389
390            // Effective degrees of freedom: gamma = sum_i alpha * lambda_i / (alpha * lambda_i + lambda)
391            // Approximated using trace(Sigma * alpha * X^T X) = alpha * trace(X^T X Sigma).
392            // For simplicity we use: gamma ≈ sum_i (alpha * xtx_ii * sigma_ii).
393            let gamma: F = (0..n_features)
394                .map(|i| alpha * xtx[[i, i]] * sd_new[i])
395                .fold(F::zero(), |a, b| a + b);
396
397            // M-step: update alpha and lambda.
398            let residual = &y_work - x_work.dot(&w_new);
399            let sse = residual.dot(&residual);
400
401            // alpha = (n - gamma) / ||y - Xw||^2
402            let new_alpha = (n_f - gamma) / sse.max(F::from(1e-300).unwrap());
403
404            // lambda = gamma / ||w||^2
405            let w_norm_sq = w_new.dot(&w_new);
406            let new_lambda = gamma / w_norm_sq.max(F::from(1e-300).unwrap());
407
408            // Clamp to reasonable range.
409            let clamp_max = F::from(1e10).unwrap();
410            let clamp_min = F::from(1e-10).unwrap();
411            alpha = new_alpha.min(clamp_max).max(clamp_min);
412            lambda = new_lambda.min(clamp_max).max(clamp_min);
413
414            // Check convergence on relative change in alpha.
415            let delta_alpha =
416                (alpha - alpha_old).abs() / (alpha_old.abs() + F::from(1e-10).unwrap());
417            let delta_lambda =
418                (lambda - lambda_old).abs() / (lambda_old.abs() + F::from(1e-10).unwrap());
419
420            w = w_new;
421            sigma_diag = sd_new;
422
423            if delta_alpha < self.tol && delta_lambda < self.tol {
424                break;
425            }
426
427            // Avoid unused variable warning — trace_xtx is used in convergence.
428            let _ = trace_xtx;
429            let _ = n_feat_f;
430        }
431
432        let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
433            *ym - xm.dot(&w)
434        } else {
435            F::zero()
436        };
437
438        Ok(FittedBayesianRidge {
439            coefficients: w,
440            intercept,
441            alpha,
442            lambda,
443            sigma: sigma_diag,
444        })
445    }
446}
447
448impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
449    for FittedBayesianRidge<F>
450{
451    type Output = Array1<F>;
452    type Error = FerroError;
453
454    /// Predict target values using the posterior mean coefficients.
455    ///
456    /// Computes `X @ coefficients + intercept`.
457    ///
458    /// # Errors
459    ///
460    /// Returns [`FerroError::ShapeMismatch`] if the number of features
461    /// does not match the fitted model.
462    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
463        let n_features = x.ncols();
464        if n_features != self.coefficients.len() {
465            return Err(FerroError::ShapeMismatch {
466                expected: vec![self.coefficients.len()],
467                actual: vec![n_features],
468                context: "number of features must match fitted model".into(),
469            });
470        }
471
472        let preds = x.dot(&self.coefficients) + self.intercept;
473        Ok(preds)
474    }
475}
476
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedBayesianRidge<F>
{
    /// Returns the posterior mean coefficient vector (one entry per feature).
    fn coefficients(&self) -> &Array1<F> {
        &self.coefficients
    }

    /// Returns the intercept term (zero when fitted without intercept).
    fn intercept(&self) -> F {
        self.intercept
    }
}
490
491// Pipeline integration.
492impl<F> PipelineEstimator<F> for BayesianRidge<F>
493where
494    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
495{
496    /// Fit the model and return it as a boxed pipeline estimator.
497    ///
498    /// # Errors
499    ///
500    /// Propagates any [`FerroError`] from `fit`.
501    fn fit_pipeline(
502        &self,
503        x: &Array2<F>,
504        y: &Array1<F>,
505    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
506        let fitted = self.fit(x, y)?;
507        Ok(Box::new(fitted))
508    }
509}
510
511impl<F> FittedPipelineEstimator<F> for FittedBayesianRidge<F>
512where
513    F: Float + ScalarOperand + Send + Sync + 'static,
514{
515    /// Generate predictions via the pipeline interface.
516    ///
517    /// # Errors
518    ///
519    /// Propagates any [`FerroError`] from `predict`.
520    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
521        self.predict(x)
522    }
523}
524
// Unit tests: builder configuration, input validation, fit correctness on
// exact linear data, and the pipeline/introspection trait surfaces.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ---- Builder ----

    #[test]
    fn test_default_constructor() {
        let m = BayesianRidge::<f64>::new();
        assert_eq!(m.max_iter, 300);
        assert!(m.fit_intercept);
        assert_relative_eq!(m.alpha_init, 1.0);
        assert_relative_eq!(m.lambda_init, 1.0);
    }

    #[test]
    fn test_builder_setters() {
        let m = BayesianRidge::<f64>::new()
            .with_max_iter(50)
            .with_tol(1e-6)
            .with_alpha_init(2.0)
            .with_lambda_init(0.5)
            .with_fit_intercept(false);
        assert_eq!(m.max_iter, 50);
        assert!(!m.fit_intercept);
        assert_relative_eq!(m.alpha_init, 2.0);
        assert_relative_eq!(m.lambda_init, 0.5);
    }

    // ---- Validation errors ----

    #[test]
    fn test_shape_mismatch() {
        // 3 samples in X vs. 2 targets in y.
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        let result = BayesianRidge::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_insufficient_samples() {
        // Fit requires at least 2 samples.
        let x = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let y = array![1.0];
        let result = BayesianRidge::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_non_positive_alpha_init() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = BayesianRidge::<f64>::new().with_alpha_init(0.0).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_non_positive_lambda_init() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = BayesianRidge::<f64>::new()
            .with_lambda_init(-1.0)
            .fit(&x, &y);
        assert!(result.is_err());
    }

    // ---- Correctness ----

    #[test]
    fn test_fits_linear_data() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        // Should recover roughly y = 2x + 1 (slightly shrunk by the prior).
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 0.1);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 0.5);
    }

    #[test]
    fn test_alpha_and_lambda_positive() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        // Both precisions are clamped to a positive range during fitting.
        assert!(fitted.alpha() > 0.0);
        assert!(fitted.lambda() > 0.0);
    }

    #[test]
    fn test_sigma_diagonal_positive() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        // A valid covariance diagonal is strictly positive.
        for &v in fitted.sigma().iter() {
            assert!(v > 0.0, "sigma diagonal must be positive, got {v}");
        }
    }

    #[test]
    fn test_sigma_length_matches_features() {
        let x = Array2::from_shape_vec(
            (5, 2),
            vec![1.0, 0.5, 2.0, 1.0, 3.0, 1.5, 4.0, 2.0, 5.0, 2.5],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.sigma().len(), 2);
    }

    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = BayesianRidge::<f64>::new()
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        // Without centering the intercept must be exactly zero.
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_predict_length() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 5);
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        // Model was fitted on 2 features; predicting with 1 must fail.
        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_has_coefficients_length() {
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 0.5, 2.0, 1.0, 1.0, 3.0, 0.0, 1.5, 4.0, 1.0, 2.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 3);
    }

    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = BayesianRidge::<f64>::new();
        let fitted_pipe = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted_pipe.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_multivariate_fit() {
        // y = 1*x1 + 2*x2
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 6.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
        // Rough sanity: residuals should be small.
        let residuals: Vec<f64> = preds
            .iter()
            .zip(y.iter())
            .map(|(p, t)| (p - t).abs())
            .collect();
        assert!(residuals.iter().all(|&r| r < 1.0));
    }
}