1use ferrolearn_core::error::FerroError;
40use ferrolearn_core::introspection::HasCoefficients;
41use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
42use ferrolearn_core::traits::{Fit, Predict};
43use ndarray::{Array1, Array2, Axis, ScalarOperand};
44use num_traits::{Float, FromPrimitive};
45
/// Bayesian ridge regression whose regularisation is estimated from the data.
///
/// During `fit`, the noise precision `alpha` and the weight precision `lambda`
/// are re-estimated iteratively, starting from `alpha_init` and `lambda_init`.
/// Configure with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct BayesianRidge<F> {
    /// Maximum number of hyper-parameter update iterations.
    pub max_iter: usize,
    /// Convergence tolerance on the relative change of `alpha` and `lambda`
    /// between iterations.
    pub tol: F,
    /// Initial noise precision (inverse noise variance); must be positive.
    pub alpha_init: F,
    /// Initial weight precision (added to the diagonal of the scaled `X^T X`);
    /// must be positive.
    pub lambda_init: F,
    /// Whether to centre `X` and `y` and fit an intercept term.
    pub fit_intercept: bool,
}
68
69impl<F: Float + FromPrimitive> BayesianRidge<F> {
70 #[must_use]
75 pub fn new() -> Self {
76 Self {
77 max_iter: 300,
78 tol: F::from(1e-3).unwrap(),
79 alpha_init: F::one(),
80 lambda_init: F::one(),
81 fit_intercept: true,
82 }
83 }
84
85 #[must_use]
87 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
88 self.max_iter = max_iter;
89 self
90 }
91
92 #[must_use]
94 pub fn with_tol(mut self, tol: F) -> Self {
95 self.tol = tol;
96 self
97 }
98
99 #[must_use]
101 pub fn with_alpha_init(mut self, alpha_init: F) -> Self {
102 self.alpha_init = alpha_init;
103 self
104 }
105
106 #[must_use]
108 pub fn with_lambda_init(mut self, lambda_init: F) -> Self {
109 self.lambda_init = lambda_init;
110 self
111 }
112
113 #[must_use]
115 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
116 self.fit_intercept = fit_intercept;
117 self
118 }
119}
120
121impl<F: Float + FromPrimitive> Default for BayesianRidge<F> {
122 fn default() -> Self {
123 Self::new()
124 }
125}
126
/// A fitted Bayesian ridge model, produced by `BayesianRidge::fit`.
#[derive(Debug, Clone)]
pub struct FittedBayesianRidge<F> {
    /// Posterior mean of the coefficients, one entry per feature.
    coefficients: Array1<F>,
    /// Intercept term (zero when the model was fitted without an intercept).
    intercept: F,
    /// Estimated noise precision.
    alpha: F,
    /// Estimated weight precision.
    lambda: F,
    /// Diagonal of the posterior covariance `(alpha * X^T X + lambda * I)^{-1}`.
    sigma: Array1<F>,
}
145
146impl<F: Float> FittedBayesianRidge<F> {
147 pub fn alpha(&self) -> F {
149 self.alpha
150 }
151
152 pub fn lambda(&self) -> F {
154 self.lambda
155 }
156
157 pub fn sigma(&self) -> &Array1<F> {
159 &self.sigma
160 }
161}
162
163fn bayesian_ridge_solve<F: Float + FromPrimitive + 'static>(
167 x: &Array2<F>,
168 y: &Array1<F>,
169 alpha: F,
170 lambda: F,
171) -> Result<(Array1<F>, Array1<F>), FerroError> {
172 let (_n_samples, n_features) = x.dim();
173
174 let xt = x.t();
176 let mut xtx = xt.dot(x);
177
178 for i in 0..n_features {
181 for j in 0..n_features {
182 xtx[[i, j]] = xtx[[i, j]] * alpha;
183 }
184 xtx[[i, i]] = xtx[[i, i]] + lambda;
185 }
186
187 let xty = xt.dot(y);
188 let xty_scaled: Array1<F> = xty.mapv(|v| v * alpha);
189
190 let w = cholesky_solve(&xtx, &xty_scaled)?;
192
193 let sigma_diag = cholesky_diag_inv(&xtx)?;
195
196 Ok((w, sigma_diag))
197}
198
199fn cholesky_solve<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>, FerroError> {
201 let n = a.nrows();
202 let mut l = Array2::<F>::zeros((n, n));
203
204 for i in 0..n {
205 for j in 0..=i {
206 let mut s = a[[i, j]];
207 for k in 0..j {
208 s = s - l[[i, k]] * l[[j, k]];
209 }
210 if i == j {
211 if s <= F::zero() {
212 return Err(FerroError::NumericalInstability {
213 message: "Cholesky: matrix not positive definite".into(),
214 });
215 }
216 l[[i, j]] = s.sqrt();
217 } else {
218 l[[i, j]] = s / l[[j, j]];
219 }
220 }
221 }
222
223 let mut z = Array1::<F>::zeros(n);
225 for i in 0..n {
226 let mut s = b[i];
227 for j in 0..i {
228 s = s - l[[i, j]] * z[j];
229 }
230 z[i] = s / l[[i, i]];
231 }
232
233 let mut x = Array1::<F>::zeros(n);
235 for i in (0..n).rev() {
236 let mut s = z[i];
237 for j in (i + 1)..n {
238 s = s - l[[j, i]] * x[j];
239 }
240 x[i] = s / l[[i, i]];
241 }
242
243 Ok(x)
244}
245
246fn cholesky_diag_inv<F: Float>(a: &Array2<F>) -> Result<Array1<F>, FerroError> {
250 let n = a.nrows();
251 let mut l = Array2::<F>::zeros((n, n));
252
253 for i in 0..n {
254 for j in 0..=i {
255 let mut s = a[[i, j]];
256 for k in 0..j {
257 s = s - l[[i, k]] * l[[j, k]];
258 }
259 if i == j {
260 if s <= F::zero() {
261 return Err(FerroError::NumericalInstability {
262 message: "Cholesky diag_inv: matrix not positive definite".into(),
263 });
264 }
265 l[[i, j]] = s.sqrt();
266 } else {
267 l[[i, j]] = s / l[[j, j]];
268 }
269 }
270 }
271
272 let mut diag = Array1::<F>::zeros(n);
274 for col in 0..n {
275 let mut z = Array1::<F>::zeros(n);
277 z[col] = F::one() / l[[col, col]];
278 for i in (col + 1)..n {
279 let mut s = F::zero();
280 for k in col..i {
281 s = s + l[[i, k]] * z[k];
282 }
283 z[i] = -s / l[[i, i]];
284 }
285 for i in 0..n {
287 diag[i] = diag[i] + z[i] * z[i];
288 }
289 }
290
291 Ok(diag)
292}
293
294impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
295 for BayesianRidge<F>
296{
297 type Fitted = FittedBayesianRidge<F>;
298 type Error = FerroError;
299
300 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedBayesianRidge<F>, FerroError> {
313 let (n_samples, n_features) = x.dim();
314
315 if n_samples != y.len() {
316 return Err(FerroError::ShapeMismatch {
317 expected: vec![n_samples],
318 actual: vec![y.len()],
319 context: "y length must match number of samples in X".into(),
320 });
321 }
322
323 if n_samples < 2 {
324 return Err(FerroError::InsufficientSamples {
325 required: 2,
326 actual: n_samples,
327 context: "BayesianRidge requires at least 2 samples".into(),
328 });
329 }
330
331 if self.alpha_init <= F::zero() {
332 return Err(FerroError::InvalidParameter {
333 name: "alpha_init".into(),
334 reason: "must be positive".into(),
335 });
336 }
337
338 if self.lambda_init <= F::zero() {
339 return Err(FerroError::InvalidParameter {
340 name: "lambda_init".into(),
341 reason: "must be positive".into(),
342 });
343 }
344
345 let n_f = F::from(n_samples).unwrap();
346 let n_feat_f = F::from(n_features).unwrap();
347
348 let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
350 let x_mean = x
351 .mean_axis(Axis(0))
352 .ok_or_else(|| FerroError::NumericalInstability {
353 message: "failed to compute column means".into(),
354 })?;
355 let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
356 message: "failed to compute target mean".into(),
357 })?;
358
359 let x_c = x - &x_mean;
360 let y_c = y - y_mean;
361 (x_c, y_c, Some(x_mean), Some(y_mean))
362 } else {
363 (x.clone(), y.clone(), None, None)
364 };
365
366 let xt = x_work.t();
369 let xtx = xt.dot(&x_work);
370
371 let trace_xtx: F = (0..n_features)
374 .map(|i| xtx[[i, i]])
375 .fold(F::zero(), |a, b| a + b);
376
377 let mut alpha = self.alpha_init;
378 let mut lambda = self.lambda_init;
379
380 let mut w = Array1::<F>::zeros(n_features);
381 let mut sigma_diag = Array1::<F>::ones(n_features);
382
383 for _iter in 0..self.max_iter {
384 let alpha_old = alpha;
385 let lambda_old = lambda;
386
387 let (w_new, sd_new) = bayesian_ridge_solve(&x_work, &y_work, alpha, lambda)?;
389
390 let gamma: F = (0..n_features)
394 .map(|i| alpha * xtx[[i, i]] * sd_new[i])
395 .fold(F::zero(), |a, b| a + b);
396
397 let residual = &y_work - x_work.dot(&w_new);
399 let sse = residual.dot(&residual);
400
401 let new_alpha = (n_f - gamma) / sse.max(F::from(1e-300).unwrap());
403
404 let w_norm_sq = w_new.dot(&w_new);
406 let new_lambda = gamma / w_norm_sq.max(F::from(1e-300).unwrap());
407
408 let clamp_max = F::from(1e10).unwrap();
410 let clamp_min = F::from(1e-10).unwrap();
411 alpha = new_alpha.min(clamp_max).max(clamp_min);
412 lambda = new_lambda.min(clamp_max).max(clamp_min);
413
414 let delta_alpha =
416 (alpha - alpha_old).abs() / (alpha_old.abs() + F::from(1e-10).unwrap());
417 let delta_lambda =
418 (lambda - lambda_old).abs() / (lambda_old.abs() + F::from(1e-10).unwrap());
419
420 w = w_new;
421 sigma_diag = sd_new;
422
423 if delta_alpha < self.tol && delta_lambda < self.tol {
424 break;
425 }
426
427 let _ = trace_xtx;
429 let _ = n_feat_f;
430 }
431
432 let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
433 *ym - xm.dot(&w)
434 } else {
435 F::zero()
436 };
437
438 Ok(FittedBayesianRidge {
439 coefficients: w,
440 intercept,
441 alpha,
442 lambda,
443 sigma: sigma_diag,
444 })
445 }
446}
447
448impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
449 for FittedBayesianRidge<F>
450{
451 type Output = Array1<F>;
452 type Error = FerroError;
453
454 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
463 let n_features = x.ncols();
464 if n_features != self.coefficients.len() {
465 return Err(FerroError::ShapeMismatch {
466 expected: vec![self.coefficients.len()],
467 actual: vec![n_features],
468 context: "number of features must match fitted model".into(),
469 });
470 }
471
472 let preds = x.dot(&self.coefficients) + self.intercept;
473 Ok(preds)
474 }
475}
476
477impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
478 for FittedBayesianRidge<F>
479{
480 fn coefficients(&self) -> &Array1<F> {
482 &self.coefficients
483 }
484
485 fn intercept(&self) -> F {
487 self.intercept
488 }
489}
490
491impl PipelineEstimator<f64> for BayesianRidge<f64> {
493 fn fit_pipeline(
499 &self,
500 x: &Array2<f64>,
501 y: &Array1<f64>,
502 ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
503 let fitted = self.fit(x, y)?;
504 Ok(Box::new(fitted))
505 }
506}
507
508impl FittedPipelineEstimator<f64> for FittedBayesianRidge<f64> {
509 fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
515 self.predict(x)
516 }
517}
518
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    #[test]
    fn test_default_constructor() {
        let m = BayesianRidge::<f64>::new();
        assert_eq!(m.max_iter, 300);
        assert!(m.fit_intercept);
        assert_relative_eq!(m.tol, 1e-3);
        assert_relative_eq!(m.alpha_init, 1.0);
        assert_relative_eq!(m.lambda_init, 1.0);
    }

    #[test]
    fn test_default_trait_matches_new() {
        let d = BayesianRidge::<f64>::default();
        let n = BayesianRidge::<f64>::new();
        assert_eq!(d.max_iter, n.max_iter);
        assert_eq!(d.fit_intercept, n.fit_intercept);
        assert_relative_eq!(d.tol, n.tol);
        assert_relative_eq!(d.alpha_init, n.alpha_init);
        assert_relative_eq!(d.lambda_init, n.lambda_init);
    }

    #[test]
    fn test_builder_setters() {
        let m = BayesianRidge::<f64>::new()
            .with_max_iter(50)
            .with_tol(1e-6)
            .with_alpha_init(2.0)
            .with_lambda_init(0.5)
            .with_fit_intercept(false);
        assert_eq!(m.max_iter, 50);
        assert!(!m.fit_intercept);
        assert_relative_eq!(m.tol, 1e-6);
        assert_relative_eq!(m.alpha_init, 2.0);
        assert_relative_eq!(m.lambda_init, 0.5);
    }

    #[test]
    fn test_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        let result = BayesianRidge::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_insufficient_samples() {
        let x = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let y = array![1.0];
        let result = BayesianRidge::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_non_positive_alpha_init() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = BayesianRidge::<f64>::new().with_alpha_init(0.0).fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_non_positive_lambda_init() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = BayesianRidge::<f64>::new()
            .with_lambda_init(-1.0)
            .fit(&x, &y);
        assert!(result.is_err());
    }

    #[test]
    fn test_fits_linear_data() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 0.1);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 0.5);
    }

    #[test]
    fn test_alpha_and_lambda_positive() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        assert!(fitted.alpha() > 0.0);
        assert!(fitted.lambda() > 0.0);
    }

    #[test]
    fn test_sigma_diagonal_positive() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        for &v in fitted.sigma().iter() {
            assert!(v > 0.0, "sigma diagonal must be positive, got {v}");
        }
    }

    #[test]
    fn test_sigma_length_matches_features() {
        let x = Array2::from_shape_vec(
            (5, 2),
            vec![1.0, 0.5, 2.0, 1.0, 3.0, 1.5, 4.0, 2.0, 5.0, 2.5],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.sigma().len(), 2);
    }

    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = BayesianRidge::<f64>::new()
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_predict_length() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 5);
    }

    #[test]
    fn test_predict_feature_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    #[test]
    fn test_has_coefficients_length() {
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 0.5, 2.0, 1.0, 1.0, 3.0, 0.0, 1.5, 4.0, 1.0, 2.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 3);
    }

    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = BayesianRidge::<f64>::new();
        let fitted_pipe = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted_pipe.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_multivariate_fit() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 6.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
        let residuals: Vec<f64> = preds
            .iter()
            .zip(y.iter())
            .map(|(p, t)| (p - t).abs())
            .collect();
        assert!(residuals.iter().all(|&r| r < 1.0));
    }

    #[test]
    fn test_cholesky_solve_known_system() {
        // [[4, 2], [2, 3]] x = [2, 1] has the exact solution x = [0.5, 0].
        let a = array![[4.0, 2.0], [2.0, 3.0]];
        let b = array![2.0, 1.0];
        let x = cholesky_solve(&a, &b).unwrap();
        assert_relative_eq!(x[0], 0.5, epsilon = 1e-12);
        assert_relative_eq!(x[1], 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_cholesky_rejects_indefinite_matrix() {
        // Eigenvalues 3 and -1: not positive definite.
        let a = array![[1.0, 2.0], [2.0, 1.0]];
        assert!(cholesky_solve(&a, &array![1.0, 1.0]).is_err());
        assert!(cholesky_diag_inv(&a).is_err());
    }

    #[test]
    fn test_cholesky_diag_inv_diagonal_matrix() {
        let a = array![[2.0, 0.0], [0.0, 4.0]];
        let d = cholesky_diag_inv(&a).unwrap();
        assert_relative_eq!(d[0], 0.5, epsilon = 1e-12);
        assert_relative_eq!(d[1], 0.25, epsilon = 1e-12);
    }

    #[test]
    fn test_cholesky_diag_inv_trace() {
        // trace([[4, 2], [2, 3]]^{-1}) = (3 + 4) / 8 = 0.875
        let a = array![[4.0, 2.0], [2.0, 3.0]];
        let d = cholesky_diag_inv(&a).unwrap();
        assert_relative_eq!(d.sum(), 0.875, epsilon = 1e-12);
    }
}