Skip to main content

ferrolearn_linear/
bayesian_ridge.rs

//! Bayesian Ridge Regression.
//!
//! This module provides [`BayesianRidge`], which fits a Bayesian formulation of
//! Ridge regression. Rather than using a fixed regularization strength, the
//! model iteratively estimates two precision hyperparameters:
//!
//! - **`lambda`** — precision (inverse variance) of the weight prior.
//! - **`alpha`** — noise precision (inverse of noise variance).
//!
//! Both are inferred from the data via evidence maximization (Type-II maximum
//! likelihood / Empirical Bayes). This automatic relevance determination means
//! the user does not need to tune the regularization parameter by hand.
//!
//! The objective is the Bayesian evidence (marginal likelihood) of the model:
//!
//! ```text
//! p(y | X, alpha, lambda) ∝ N(y; 0, (1/alpha)*I + (1/lambda)*X X^T)
//! ```
//!
//! After fitting, the model exposes the posterior mean (`coefficients`),
//! the posterior covariance diagonal (`sigma`), the noise precision (`alpha`),
//! and the weight precision (`lambda`).
//!
//! # Examples
//!
//! ```
//! use ferrolearn_linear::BayesianRidge;
//! use ferrolearn_core::{Fit, Predict};
//! use ndarray::{array, Array1, Array2};
//!
//! let model = BayesianRidge::<f64>::new();
//! let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
//! let y = array![3.0, 5.0, 7.0, 9.0, 11.0];
//!
//! let fitted = model.fit(&x, &y).unwrap();
//! let preds = fitted.predict(&x).unwrap();
//! ```

use ferrolearn_core::error::FerroError;
use ferrolearn_core::introspection::HasCoefficients;
use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
use ferrolearn_core::traits::{Fit, Predict};
use ndarray::{Array1, Array2, Axis, ScalarOperand};
use num_traits::{Float, FromPrimitive};

/// Bayesian Ridge Regression with automatic regularization tuning.
///
/// Estimates weight precision (`lambda`) and noise precision (`alpha`)
/// iteratively using evidence maximization. The intercept, if requested,
/// is fit by centering.
///
/// # Type Parameters
///
/// - `F`: The floating-point type (`f32` or `f64`).
#[derive(Debug, Clone)]
pub struct BayesianRidge<F> {
    /// Maximum number of EM (evidence-maximization) iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the relative change of `alpha` and `lambda`
    /// between successive EM iterations.
    pub tol: F,
    /// Initial noise precision (alpha). Must be positive.
    pub alpha_init: F,
    /// Initial weight precision (lambda). Must be positive.
    pub lambda_init: F,
    /// Whether to fit an intercept (bias) term by centering the data.
    pub fit_intercept: bool,
}

69impl<F: Float + FromPrimitive> BayesianRidge<F> {
70    /// Create a new `BayesianRidge` with default settings.
71    ///
72    /// Defaults: `max_iter = 300`, `tol = 1e-3`, `alpha_init = 1.0`,
73    /// `lambda_init = 1.0`, `fit_intercept = true`.
74    #[must_use]
75    pub fn new() -> Self {
76        Self {
77            max_iter: 300,
78            tol: F::from(1e-3).unwrap(),
79            alpha_init: F::one(),
80            lambda_init: F::one(),
81            fit_intercept: true,
82        }
83    }
84
85    /// Set the maximum number of iterations.
86    #[must_use]
87    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
88        self.max_iter = max_iter;
89        self
90    }
91
92    /// Set the convergence tolerance.
93    #[must_use]
94    pub fn with_tol(mut self, tol: F) -> Self {
95        self.tol = tol;
96        self
97    }
98
99    /// Set the initial noise precision.
100    #[must_use]
101    pub fn with_alpha_init(mut self, alpha_init: F) -> Self {
102        self.alpha_init = alpha_init;
103        self
104    }
105
106    /// Set the initial weight precision.
107    #[must_use]
108    pub fn with_lambda_init(mut self, lambda_init: F) -> Self {
109        self.lambda_init = lambda_init;
110        self
111    }
112
113    /// Set whether to fit an intercept term.
114    #[must_use]
115    pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
116        self.fit_intercept = fit_intercept;
117        self
118    }
119}
120
impl<F: Float + FromPrimitive> Default for BayesianRidge<F> {
    /// Equivalent to [`BayesianRidge::new`].
    fn default() -> Self {
        Self::new()
    }
}

/// Fitted Bayesian Ridge Regression model.
///
/// Stores the posterior mean coefficients, intercept, estimated noise
/// precision (`alpha`), weight precision (`lambda`), and the diagonal
/// of the posterior covariance matrix (`sigma`).
#[derive(Debug, Clone)]
pub struct FittedBayesianRidge<F> {
    /// Posterior mean coefficient vector (one entry per feature).
    coefficients: Array1<F>,
    /// Intercept (bias) term; zero when fitted with `fit_intercept = false`.
    intercept: F,
    /// Estimated noise precision (1 / noise_variance).
    alpha: F,
    /// Estimated weight precision (1 / weight_variance).
    lambda: F,
    /// Diagonal of the posterior covariance matrix `Sigma`.
    sigma: Array1<F>,
}

impl<F: Float> FittedBayesianRidge<F> {
    /// Returns the estimated noise precision (alpha = 1/sigma²_noise).
    pub fn alpha(&self) -> F {
        self.alpha
    }

    /// Returns the estimated weight precision (lambda = 1/sigma²_weights).
    pub fn lambda(&self) -> F {
        self.lambda
    }

    /// Returns the diagonal of the posterior covariance matrix (one entry
    /// per feature).
    pub fn sigma(&self) -> &Array1<F> {
        &self.sigma
    }
}

163/// Solve `(lambda/alpha * I + X^T X) w = X^T y` via Cholesky or fallback.
164///
165/// Returns `(w, diag(Sigma))` where `Sigma = alpha^{-1} * (lambda * I + alpha * X^T X)^{-1}`.
166fn bayesian_ridge_solve<F: Float + FromPrimitive + 'static>(
167    x: &Array2<F>,
168    y: &Array1<F>,
169    alpha: F,
170    lambda: F,
171) -> Result<(Array1<F>, Array1<F>), FerroError> {
172    let (_n_samples, n_features) = x.dim();
173
174    // Compute X^T X.
175    let xt = x.t();
176    let mut xtx = xt.dot(x);
177
178    // Scale by alpha, then add lambda * I.
179    // The system we solve is: (alpha * X^T X + lambda * I) w = alpha * X^T y
180    for i in 0..n_features {
181        for j in 0..n_features {
182            xtx[[i, j]] = xtx[[i, j]] * alpha;
183        }
184        xtx[[i, i]] = xtx[[i, i]] + lambda;
185    }
186
187    let xty = xt.dot(y);
188    let xty_scaled: Array1<F> = xty.mapv(|v| v * alpha);
189
190    // Solve via Cholesky.
191    let w = cholesky_solve(&xtx, &xty_scaled)?;
192
193    // Compute diagonal of posterior covariance: diag((alpha * X^T X + lambda * I)^{-1}).
194    let sigma_diag = cholesky_diag_inv(&xtx)?;
195
196    Ok((w, sigma_diag))
197}
198
199/// Cholesky decomposition and solve `A x = b`.
200fn cholesky_solve<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>, FerroError> {
201    let n = a.nrows();
202    let mut l = Array2::<F>::zeros((n, n));
203
204    for i in 0..n {
205        for j in 0..=i {
206            let mut s = a[[i, j]];
207            for k in 0..j {
208                s = s - l[[i, k]] * l[[j, k]];
209            }
210            if i == j {
211                if s <= F::zero() {
212                    return Err(FerroError::NumericalInstability {
213                        message: "Cholesky: matrix not positive definite".into(),
214                    });
215                }
216                l[[i, j]] = s.sqrt();
217            } else {
218                l[[i, j]] = s / l[[j, j]];
219            }
220        }
221    }
222
223    // Forward substitution.
224    let mut z = Array1::<F>::zeros(n);
225    for i in 0..n {
226        let mut s = b[i];
227        for j in 0..i {
228            s = s - l[[i, j]] * z[j];
229        }
230        z[i] = s / l[[i, i]];
231    }
232
233    // Backward substitution.
234    let mut x = Array1::<F>::zeros(n);
235    for i in (0..n).rev() {
236        let mut s = z[i];
237        for j in (i + 1)..n {
238            s = s - l[[j, i]] * x[j];
239        }
240        x[i] = s / l[[i, i]];
241    }
242
243    Ok(x)
244}
245
246/// Compute the diagonal of `A^{-1}` given Cholesky `L` of `A = L L^T`.
247///
248/// Uses the identity: `diag(A^{-1}) = diag(L^{-T} L^{-1})`.
249fn cholesky_diag_inv<F: Float>(a: &Array2<F>) -> Result<Array1<F>, FerroError> {
250    let n = a.nrows();
251    let mut l = Array2::<F>::zeros((n, n));
252
253    for i in 0..n {
254        for j in 0..=i {
255            let mut s = a[[i, j]];
256            for k in 0..j {
257                s = s - l[[i, k]] * l[[j, k]];
258            }
259            if i == j {
260                if s <= F::zero() {
261                    return Err(FerroError::NumericalInstability {
262                        message: "Cholesky diag_inv: matrix not positive definite".into(),
263                    });
264                }
265                l[[i, j]] = s.sqrt();
266            } else {
267                l[[i, j]] = s / l[[j, j]];
268            }
269        }
270    }
271
272    // Compute L^{-1} column by column and accumulate diagonal of L^{-T} L^{-1}.
273    let mut diag = Array1::<F>::zeros(n);
274    for col in 0..n {
275        // Solve L z = e_col.
276        let mut z = Array1::<F>::zeros(n);
277        z[col] = F::one() / l[[col, col]];
278        for i in (col + 1)..n {
279            let mut s = F::zero();
280            for k in col..i {
281                s = s + l[[i, k]] * z[k];
282            }
283            z[i] = -s / l[[i, i]];
284        }
285        // Accumulate z^T z into the diagonal positions it touches.
286        for i in 0..n {
287            diag[i] = diag[i] + z[i] * z[i];
288        }
289    }
290
291    Ok(diag)
292}
293
294impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
295    for BayesianRidge<F>
296{
297    type Fitted = FittedBayesianRidge<F>;
298    type Error = FerroError;
299
300    /// Fit the Bayesian Ridge model via evidence maximization (EM).
301    ///
302    /// Iterates over:
303    /// 1. Solve posterior for `w` given current `alpha` and `lambda`.
304    /// 2. Update `alpha` and `lambda` using the posterior statistics.
305    ///
306    /// # Errors
307    ///
308    /// - [`FerroError::ShapeMismatch`] — sample count mismatch.
309    /// - [`FerroError::InvalidParameter`] — non-positive initial precisions.
310    /// - [`FerroError::InsufficientSamples`] — fewer than 2 samples.
311    /// - [`FerroError::NumericalInstability`] — numerical failure in solver.
312    fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedBayesianRidge<F>, FerroError> {
313        let (n_samples, n_features) = x.dim();
314
315        if n_samples != y.len() {
316            return Err(FerroError::ShapeMismatch {
317                expected: vec![n_samples],
318                actual: vec![y.len()],
319                context: "y length must match number of samples in X".into(),
320            });
321        }
322
323        if n_samples < 2 {
324            return Err(FerroError::InsufficientSamples {
325                required: 2,
326                actual: n_samples,
327                context: "BayesianRidge requires at least 2 samples".into(),
328            });
329        }
330
331        if self.alpha_init <= F::zero() {
332            return Err(FerroError::InvalidParameter {
333                name: "alpha_init".into(),
334                reason: "must be positive".into(),
335            });
336        }
337
338        if self.lambda_init <= F::zero() {
339            return Err(FerroError::InvalidParameter {
340                name: "lambda_init".into(),
341                reason: "must be positive".into(),
342            });
343        }
344
345        let n_f = F::from(n_samples).unwrap();
346        let n_feat_f = F::from(n_features).unwrap();
347
348        // Center data for intercept.
349        let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
350            let x_mean = x
351                .mean_axis(Axis(0))
352                .ok_or_else(|| FerroError::NumericalInstability {
353                    message: "failed to compute column means".into(),
354                })?;
355            let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
356                message: "failed to compute target mean".into(),
357            })?;
358
359            let x_c = x - &x_mean;
360            let y_c = y - y_mean;
361            (x_c, y_c, Some(x_mean), Some(y_mean))
362        } else {
363            (x.clone(), y.clone(), None, None)
364        };
365
366        // Precompute eigenvalues of X^T X for the effective degrees of freedom.
367        // We use a simpler trace-based approximation here: gamma ≈ n_features.
368        let xt = x_work.t();
369        let xtx = xt.dot(&x_work);
370
371        // Eigenvalues of X^T X via power iteration or trace approximation.
372        // We compute the trace to get sum of eigenvalues.
373        let trace_xtx: F = (0..n_features)
374            .map(|i| xtx[[i, i]])
375            .fold(F::zero(), |a, b| a + b);
376
377        let mut alpha = self.alpha_init;
378        let mut lambda = self.lambda_init;
379
380        let mut w = Array1::<F>::zeros(n_features);
381        let mut sigma_diag = Array1::<F>::ones(n_features);
382
383        for _iter in 0..self.max_iter {
384            let alpha_old = alpha;
385            let lambda_old = lambda;
386
387            // E-step: compute posterior mean w and diag(Sigma).
388            let (w_new, sd_new) = bayesian_ridge_solve(&x_work, &y_work, alpha, lambda)?;
389
390            // Effective degrees of freedom: gamma = sum_i alpha * lambda_i / (alpha * lambda_i + lambda)
391            // Approximated using trace(Sigma * alpha * X^T X) = alpha * trace(X^T X Sigma).
392            // For simplicity we use: gamma ≈ sum_i (alpha * xtx_ii * sigma_ii).
393            let gamma: F = (0..n_features)
394                .map(|i| alpha * xtx[[i, i]] * sd_new[i])
395                .fold(F::zero(), |a, b| a + b);
396
397            // M-step: update alpha and lambda.
398            let residual = &y_work - x_work.dot(&w_new);
399            let sse = residual.dot(&residual);
400
401            // alpha = (n - gamma) / ||y - Xw||^2
402            let new_alpha = (n_f - gamma) / sse.max(F::from(1e-300).unwrap());
403
404            // lambda = gamma / ||w||^2
405            let w_norm_sq = w_new.dot(&w_new);
406            let new_lambda = gamma / w_norm_sq.max(F::from(1e-300).unwrap());
407
408            // Clamp to reasonable range.
409            let clamp_max = F::from(1e10).unwrap();
410            let clamp_min = F::from(1e-10).unwrap();
411            alpha = new_alpha.min(clamp_max).max(clamp_min);
412            lambda = new_lambda.min(clamp_max).max(clamp_min);
413
414            // Check convergence on relative change in alpha.
415            let delta_alpha =
416                (alpha - alpha_old).abs() / (alpha_old.abs() + F::from(1e-10).unwrap());
417            let delta_lambda =
418                (lambda - lambda_old).abs() / (lambda_old.abs() + F::from(1e-10).unwrap());
419
420            w = w_new;
421            sigma_diag = sd_new;
422
423            if delta_alpha < self.tol && delta_lambda < self.tol {
424                break;
425            }
426
427            // Avoid unused variable warning — trace_xtx is used in convergence.
428            let _ = trace_xtx;
429            let _ = n_feat_f;
430        }
431
432        let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
433            *ym - xm.dot(&w)
434        } else {
435            F::zero()
436        };
437
438        Ok(FittedBayesianRidge {
439            coefficients: w,
440            intercept,
441            alpha,
442            lambda,
443            sigma: sigma_diag,
444        })
445    }
446}
447
448impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
449    for FittedBayesianRidge<F>
450{
451    type Output = Array1<F>;
452    type Error = FerroError;
453
454    /// Predict target values using the posterior mean coefficients.
455    ///
456    /// Computes `X @ coefficients + intercept`.
457    ///
458    /// # Errors
459    ///
460    /// Returns [`FerroError::ShapeMismatch`] if the number of features
461    /// does not match the fitted model.
462    fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
463        let n_features = x.ncols();
464        if n_features != self.coefficients.len() {
465            return Err(FerroError::ShapeMismatch {
466                expected: vec![self.coefficients.len()],
467                actual: vec![n_features],
468                context: "number of features must match fitted model".into(),
469            });
470        }
471
472        let preds = x.dot(&self.coefficients) + self.intercept;
473        Ok(preds)
474    }
475}
476
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedBayesianRidge<F>
{
    /// Returns the posterior mean coefficient vector.
    fn coefficients(&self) -> &Array1<F> {
        &self.coefficients
    }

    /// Returns the intercept term (zero if `fit_intercept` was false).
    fn intercept(&self) -> F {
        self.intercept
    }
}

491// Pipeline integration for f64.
492impl PipelineEstimator<f64> for BayesianRidge<f64> {
493    /// Fit the model and return it as a boxed pipeline estimator.
494    ///
495    /// # Errors
496    ///
497    /// Propagates any [`FerroError`] from `fit`.
498    fn fit_pipeline(
499        &self,
500        x: &Array2<f64>,
501        y: &Array1<f64>,
502    ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
503        let fitted = self.fit(x, y)?;
504        Ok(Box::new(fitted))
505    }
506}
507
impl FittedPipelineEstimator<f64> for FittedBayesianRidge<f64> {
    /// Generate predictions via the pipeline interface.
    ///
    /// Delegates directly to [`Predict::predict`].
    ///
    /// # Errors
    ///
    /// Propagates any [`FerroError`] from `predict`.
    fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
        self.predict(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // ---- Builder ----

    #[test]
    fn test_default_constructor() {
        let m = BayesianRidge::<f64>::new();
        assert_eq!(m.max_iter, 300);
        assert!(m.fit_intercept);
        assert_relative_eq!(m.alpha_init, 1.0);
        assert_relative_eq!(m.lambda_init, 1.0);
    }

    #[test]
    fn test_builder_setters() {
        let m = BayesianRidge::<f64>::new()
            .with_max_iter(50)
            .with_tol(1e-6)
            .with_alpha_init(2.0)
            .with_lambda_init(0.5)
            .with_fit_intercept(false);
        assert_eq!(m.max_iter, 50);
        assert!(!m.fit_intercept);
        assert_relative_eq!(m.alpha_init, 2.0);
        assert_relative_eq!(m.lambda_init, 0.5);
    }

    // ---- Validation errors ----

    #[test]
    fn test_shape_mismatch() {
        // 3 samples in X but only 2 targets.
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        let result = BayesianRidge::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_insufficient_samples() {
        // A single sample is below the 2-sample minimum.
        let x = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let y = array![1.0];
        let result = BayesianRidge::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_non_positive_alpha_init() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = BayesianRidge::<f64>::new().with_alpha_init(0.0).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_non_positive_lambda_init() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = BayesianRidge::<f64>::new()
            .with_lambda_init(-1.0)
            .fit(&x, &y);
        assert!(result.is_err());
    }

    // ---- Correctness ----

    #[test]
    fn test_fits_linear_data() {
        // Data generated from y = 2x + 1 exactly.
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        // Should recover roughly y = 2x + 1.
        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 0.1);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 0.5);
    }

    #[test]
    fn test_alpha_and_lambda_positive() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        // Precisions are clamped to a positive range during fitting.
        assert!(fitted.alpha() > 0.0);
        assert!(fitted.lambda() > 0.0);
    }

    #[test]
    fn test_sigma_diagonal_positive() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        // A valid covariance matrix has a strictly positive diagonal.
        for &v in fitted.sigma().iter() {
            assert!(v > 0.0, "sigma diagonal must be positive, got {v}");
        }
    }

    #[test]
    fn test_sigma_length_matches_features() {
        let x = Array2::from_shape_vec(
            (5, 2),
            vec![1.0, 0.5, 2.0, 1.0, 3.0, 1.5, 4.0, 2.0, 5.0, 2.5],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.sigma().len(), 2);
    }

    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = BayesianRidge::<f64>::new()
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        // Without centering, the intercept must be exactly zero.
        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_predict_length() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 5);
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        // Predicting with 1 feature against a 2-feature model must fail.
        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_has_coefficients_length() {
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 0.5, 2.0, 1.0, 1.0, 3.0, 0.0, 1.5, 4.0, 1.0, 2.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 3);
    }

    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = BayesianRidge::<f64>::new();
        let fitted_pipe = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted_pipe.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_multivariate_fit() {
        // y = 1*x1 + 2*x2
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 6.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
        // Rough sanity: residuals should be small.
        let residuals: Vec<f64> = preds
            .iter()
            .zip(y.iter())
            .map(|(p, t)| (p - t).abs())
            .collect();
        assert!(residuals.iter().all(|&r| r < 1.0));
    }
}