1use ferrolearn_core::error::FerroError;
40use ferrolearn_core::introspection::HasCoefficients;
41use ferrolearn_core::pipeline::{FittedPipelineEstimator, PipelineEstimator};
42use ferrolearn_core::traits::{Fit, Predict};
43use ndarray::{Array1, Array2, Axis, ScalarOperand};
44use num_traits::{Float, FromPrimitive};
45
/// Bayesian ridge regression estimator (hyperparameters only; fitting is
/// provided by the `Fit` impl, which produces a `FittedBayesianRidge`).
#[derive(Debug, Clone)]
pub struct BayesianRidge<F> {
    /// Maximum number of evidence-maximization iterations.
    pub max_iter: usize,
    /// Relative convergence tolerance on the `alpha`/`lambda` updates.
    pub tol: F,
    /// Initial noise precision; must be positive.
    pub alpha_init: F,
    /// Initial weight (prior) precision; must be positive.
    pub lambda_init: F,
    /// Whether to center `X`/`y` and recover an intercept after fitting.
    pub fit_intercept: bool,
}
68
69impl<F: Float + FromPrimitive> BayesianRidge<F> {
70 #[must_use]
75 pub fn new() -> Self {
76 Self {
77 max_iter: 300,
78 tol: F::from(1e-3).unwrap(),
79 alpha_init: F::one(),
80 lambda_init: F::one(),
81 fit_intercept: true,
82 }
83 }
84
85 #[must_use]
87 pub fn with_max_iter(mut self, max_iter: usize) -> Self {
88 self.max_iter = max_iter;
89 self
90 }
91
92 #[must_use]
94 pub fn with_tol(mut self, tol: F) -> Self {
95 self.tol = tol;
96 self
97 }
98
99 #[must_use]
101 pub fn with_alpha_init(mut self, alpha_init: F) -> Self {
102 self.alpha_init = alpha_init;
103 self
104 }
105
106 #[must_use]
108 pub fn with_lambda_init(mut self, lambda_init: F) -> Self {
109 self.lambda_init = lambda_init;
110 self
111 }
112
113 #[must_use]
115 pub fn with_fit_intercept(mut self, fit_intercept: bool) -> Self {
116 self.fit_intercept = fit_intercept;
117 self
118 }
119}
120
/// `Default` delegates to [`BayesianRidge::new`], so struct-update syntax
/// (`..Default::default()`) yields the documented defaults.
impl<F: Float + FromPrimitive> Default for BayesianRidge<F> {
    fn default() -> Self {
        Self::new()
    }
}
126
/// A fitted Bayesian ridge model, produced by fitting a [`BayesianRidge`].
#[derive(Debug, Clone)]
pub struct FittedBayesianRidge<F> {
    // Posterior mean of the weights.
    coefficients: Array1<F>,
    // Intercept recovered from data means; zero when fit without intercept.
    intercept: F,
    // Estimated noise precision.
    alpha: F,
    // Estimated weight (prior) precision.
    lambda: F,
    // Diagonal of the posterior covariance of the weights.
    sigma: Array1<F>,
}
145
impl<F: Float> FittedBayesianRidge<F> {
    /// Estimated noise precision.
    pub fn alpha(&self) -> F {
        self.alpha
    }

    /// Estimated weight (prior) precision.
    pub fn lambda(&self) -> F {
        self.lambda
    }

    /// Diagonal of the posterior covariance of the coefficients.
    pub fn sigma(&self) -> &Array1<F> {
        &self.sigma
    }
}
162
163fn bayesian_ridge_solve<F: Float + FromPrimitive + 'static>(
167 x: &Array2<F>,
168 y: &Array1<F>,
169 alpha: F,
170 lambda: F,
171) -> Result<(Array1<F>, Array1<F>), FerroError> {
172 let (_n_samples, n_features) = x.dim();
173
174 let xt = x.t();
176 let mut xtx = xt.dot(x);
177
178 for i in 0..n_features {
181 for j in 0..n_features {
182 xtx[[i, j]] = xtx[[i, j]] * alpha;
183 }
184 xtx[[i, i]] = xtx[[i, i]] + lambda;
185 }
186
187 let xty = xt.dot(y);
188 let xty_scaled: Array1<F> = xty.mapv(|v| v * alpha);
189
190 let w = cholesky_solve(&xtx, &xty_scaled)?;
192
193 let sigma_diag = cholesky_diag_inv(&xtx)?;
195
196 Ok((w, sigma_diag))
197}
198
/// Solves `A x = b` for symmetric positive-definite `A` via an unblocked
/// Cholesky factorization `A = L L^T` followed by forward and backward
/// substitution.
///
/// # Errors
/// Returns `FerroError::NumericalInstability` when a non-positive pivot is
/// encountered, i.e. `A` is not (numerically) positive definite.
fn cholesky_solve<F: Float>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>, FerroError> {
    let n = a.nrows();
    let mut l = Array2::<F>::zeros((n, n));

    // Factorization: fill the lower triangle of L row by row.
    for i in 0..n {
        for j in 0..=i {
            let mut s = a[[i, j]];
            for k in 0..j {
                s = s - l[[i, k]] * l[[j, k]];
            }
            if i == j {
                // A non-positive pivot means A is not SPD (or is singular
                // to working precision); sqrt would produce NaN.
                if s <= F::zero() {
                    return Err(FerroError::NumericalInstability {
                        message: "Cholesky: matrix not positive definite".into(),
                    });
                }
                l[[i, j]] = s.sqrt();
            } else {
                l[[i, j]] = s / l[[j, j]];
            }
        }
    }

    // Forward substitution: L z = b.
    let mut z = Array1::<F>::zeros(n);
    for i in 0..n {
        let mut s = b[i];
        for j in 0..i {
            s = s - l[[i, j]] * z[j];
        }
        z[i] = s / l[[i, i]];
    }

    // Backward substitution: L^T x = z (l[[j, i]] reads L^T's entry (i, j)).
    let mut x = Array1::<F>::zeros(n);
    for i in (0..n).rev() {
        let mut s = z[i];
        for j in (i + 1)..n {
            s = s - l[[j, i]] * x[j];
        }
        x[i] = s / l[[i, i]];
    }

    Ok(x)
}
245
246fn cholesky_diag_inv<F: Float>(a: &Array2<F>) -> Result<Array1<F>, FerroError> {
250 let n = a.nrows();
251 let mut l = Array2::<F>::zeros((n, n));
252
253 for i in 0..n {
254 for j in 0..=i {
255 let mut s = a[[i, j]];
256 for k in 0..j {
257 s = s - l[[i, k]] * l[[j, k]];
258 }
259 if i == j {
260 if s <= F::zero() {
261 return Err(FerroError::NumericalInstability {
262 message: "Cholesky diag_inv: matrix not positive definite".into(),
263 });
264 }
265 l[[i, j]] = s.sqrt();
266 } else {
267 l[[i, j]] = s / l[[j, j]];
268 }
269 }
270 }
271
272 let mut diag = Array1::<F>::zeros(n);
274 for col in 0..n {
275 let mut z = Array1::<F>::zeros(n);
277 z[col] = F::one() / l[[col, col]];
278 for i in (col + 1)..n {
279 let mut s = F::zero();
280 for k in col..i {
281 s = s + l[[i, k]] * z[k];
282 }
283 z[i] = -s / l[[i, i]];
284 }
285 for i in 0..n {
287 diag[i] = diag[i] + z[i] * z[i];
288 }
289 }
290
291 Ok(diag)
292}
293
294impl<F: Float + Send + Sync + ScalarOperand + FromPrimitive + 'static> Fit<Array2<F>, Array1<F>>
295 for BayesianRidge<F>
296{
297 type Fitted = FittedBayesianRidge<F>;
298 type Error = FerroError;
299
300 fn fit(&self, x: &Array2<F>, y: &Array1<F>) -> Result<FittedBayesianRidge<F>, FerroError> {
313 let (n_samples, n_features) = x.dim();
314
315 if n_samples != y.len() {
316 return Err(FerroError::ShapeMismatch {
317 expected: vec![n_samples],
318 actual: vec![y.len()],
319 context: "y length must match number of samples in X".into(),
320 });
321 }
322
323 if n_samples < 2 {
324 return Err(FerroError::InsufficientSamples {
325 required: 2,
326 actual: n_samples,
327 context: "BayesianRidge requires at least 2 samples".into(),
328 });
329 }
330
331 if self.alpha_init <= F::zero() {
332 return Err(FerroError::InvalidParameter {
333 name: "alpha_init".into(),
334 reason: "must be positive".into(),
335 });
336 }
337
338 if self.lambda_init <= F::zero() {
339 return Err(FerroError::InvalidParameter {
340 name: "lambda_init".into(),
341 reason: "must be positive".into(),
342 });
343 }
344
345 let n_f = F::from(n_samples).unwrap();
346 let n_feat_f = F::from(n_features).unwrap();
347
348 let (x_work, y_work, x_mean, y_mean) = if self.fit_intercept {
350 let x_mean = x
351 .mean_axis(Axis(0))
352 .ok_or_else(|| FerroError::NumericalInstability {
353 message: "failed to compute column means".into(),
354 })?;
355 let y_mean = y.mean().ok_or_else(|| FerroError::NumericalInstability {
356 message: "failed to compute target mean".into(),
357 })?;
358
359 let x_c = x - &x_mean;
360 let y_c = y - y_mean;
361 (x_c, y_c, Some(x_mean), Some(y_mean))
362 } else {
363 (x.clone(), y.clone(), None, None)
364 };
365
366 let xt = x_work.t();
369 let xtx = xt.dot(&x_work);
370
371 let trace_xtx: F = (0..n_features)
374 .map(|i| xtx[[i, i]])
375 .fold(F::zero(), |a, b| a + b);
376
377 let mut alpha = self.alpha_init;
378 let mut lambda = self.lambda_init;
379
380 let mut w = Array1::<F>::zeros(n_features);
381 let mut sigma_diag = Array1::<F>::ones(n_features);
382
383 for _iter in 0..self.max_iter {
384 let alpha_old = alpha;
385 let lambda_old = lambda;
386
387 let (w_new, sd_new) = bayesian_ridge_solve(&x_work, &y_work, alpha, lambda)?;
389
390 let gamma: F = (0..n_features)
394 .map(|i| alpha * xtx[[i, i]] * sd_new[i])
395 .fold(F::zero(), |a, b| a + b);
396
397 let residual = &y_work - x_work.dot(&w_new);
399 let sse = residual.dot(&residual);
400
401 let new_alpha = (n_f - gamma) / sse.max(F::from(1e-300).unwrap());
403
404 let w_norm_sq = w_new.dot(&w_new);
406 let new_lambda = gamma / w_norm_sq.max(F::from(1e-300).unwrap());
407
408 let clamp_max = F::from(1e10).unwrap();
410 let clamp_min = F::from(1e-10).unwrap();
411 alpha = new_alpha.min(clamp_max).max(clamp_min);
412 lambda = new_lambda.min(clamp_max).max(clamp_min);
413
414 let delta_alpha =
416 (alpha - alpha_old).abs() / (alpha_old.abs() + F::from(1e-10).unwrap());
417 let delta_lambda =
418 (lambda - lambda_old).abs() / (lambda_old.abs() + F::from(1e-10).unwrap());
419
420 w = w_new;
421 sigma_diag = sd_new;
422
423 if delta_alpha < self.tol && delta_lambda < self.tol {
424 break;
425 }
426
427 let _ = trace_xtx;
429 let _ = n_feat_f;
430 }
431
432 let intercept = if let (Some(xm), Some(ym)) = (&x_mean, &y_mean) {
433 *ym - xm.dot(&w)
434 } else {
435 F::zero()
436 };
437
438 Ok(FittedBayesianRidge {
439 coefficients: w,
440 intercept,
441 alpha,
442 lambda,
443 sigma: sigma_diag,
444 })
445 }
446}
447
448impl<F: Float + Send + Sync + ScalarOperand + 'static> Predict<Array2<F>>
449 for FittedBayesianRidge<F>
450{
451 type Output = Array1<F>;
452 type Error = FerroError;
453
454 fn predict(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
463 let n_features = x.ncols();
464 if n_features != self.coefficients.len() {
465 return Err(FerroError::ShapeMismatch {
466 expected: vec![self.coefficients.len()],
467 actual: vec![n_features],
468 context: "number of features must match fitted model".into(),
469 });
470 }
471
472 let preds = x.dot(&self.coefficients) + self.intercept;
473 Ok(preds)
474 }
475}
476
impl<F: Float + Send + Sync + ScalarOperand + 'static> HasCoefficients<F>
    for FittedBayesianRidge<F>
{
    /// Posterior mean weights of the fitted model.
    fn coefficients(&self) -> &Array1<F> {
        &self.coefficients
    }

    /// Fitted intercept (zero when fit without an intercept).
    fn intercept(&self) -> F {
        self.intercept
    }
}
490
impl<F> PipelineEstimator<F> for BayesianRidge<F>
where
    F: Float + FromPrimitive + ScalarOperand + Send + Sync + 'static,
{
    /// Type-erased fit entry point for pipelines; delegates to `Fit::fit`
    /// and boxes the fitted model.
    fn fit_pipeline(
        &self,
        x: &Array2<F>,
        y: &Array1<F>,
    ) -> Result<Box<dyn FittedPipelineEstimator<F>>, FerroError> {
        let fitted = self.fit(x, y)?;
        Ok(Box::new(fitted))
    }
}
510
impl<F> FittedPipelineEstimator<F> for FittedBayesianRidge<F>
where
    F: Float + ScalarOperand + Send + Sync + 'static,
{
    /// Type-erased predict entry point for pipelines; delegates to
    /// `Predict::predict`.
    fn predict_pipeline(&self, x: &Array2<F>) -> Result<Array1<F>, FerroError> {
        self.predict(x)
    }
}
524
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;

    // Defaults match the documented hyperparameters.
    #[test]
    fn test_default_constructor() {
        let m = BayesianRidge::<f64>::new();
        assert_eq!(m.max_iter, 300);
        assert!(m.fit_intercept);
        assert_relative_eq!(m.alpha_init, 1.0);
        assert_relative_eq!(m.lambda_init, 1.0);
    }

    // Builder setters each override exactly one field.
    #[test]
    fn test_builder_setters() {
        let m = BayesianRidge::<f64>::new()
            .with_max_iter(50)
            .with_tol(1e-6)
            .with_alpha_init(2.0)
            .with_lambda_init(0.5)
            .with_fit_intercept(false);
        assert_eq!(m.max_iter, 50);
        assert!(!m.fit_intercept);
        assert_relative_eq!(m.alpha_init, 2.0);
        assert_relative_eq!(m.lambda_init, 0.5);
    }

    // Mismatched X rows vs y length is rejected.
    #[test]
    fn test_shape_mismatch() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0];
        let result = BayesianRidge::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    // Fitting needs at least 2 samples.
    #[test]
    fn test_insufficient_samples() {
        let x = Array2::from_shape_vec((1, 1), vec![1.0]).unwrap();
        let y = array![1.0];
        let result = BayesianRidge::<f64>::new().fit(&x, &y);
        assert!(result.is_err());
    }

    // alpha_init must be strictly positive.
    #[test]
    fn test_non_positive_alpha_init() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = BayesianRidge::<f64>::new().with_alpha_init(0.0).fit(&x, &y);
        assert!(result.is_err());
    }

    // lambda_init must be strictly positive.
    #[test]
    fn test_non_positive_lambda_init() {
        let x = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let result = BayesianRidge::<f64>::new()
            .with_lambda_init(-1.0)
            .fit(&x, &y);
        assert!(result.is_err());
    }

    // Exact linear data y = 2x + 1: slope and intercept are recovered
    // approximately (the ridge prior shrinks them slightly, hence epsilons).
    #[test]
    fn test_fits_linear_data() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0, 11.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        assert_relative_eq!(fitted.coefficients()[0], 2.0, epsilon = 0.1);
        assert_relative_eq!(fitted.intercept(), 1.0, epsilon = 0.5);
    }

    // Estimated precisions stay positive after clamping/updates.
    #[test]
    fn test_alpha_and_lambda_positive() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        assert!(fitted.alpha() > 0.0);
        assert!(fitted.lambda() > 0.0);
    }

    // Posterior covariance diagonal entries are variances, so positive.
    #[test]
    fn test_sigma_diagonal_positive() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        for &v in fitted.sigma().iter() {
            assert!(v > 0.0, "sigma diagonal must be positive, got {v}");
        }
    }

    // sigma has one entry per feature.
    #[test]
    fn test_sigma_length_matches_features() {
        let x = Array2::from_shape_vec(
            (5, 2),
            vec![1.0, 0.5, 2.0, 1.0, 3.0, 1.5, 4.0, 2.0, 5.0, 2.5],
        )
        .unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.sigma().len(), 2);
    }

    // With fit_intercept(false) the intercept is exactly zero.
    #[test]
    fn test_no_intercept() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0];

        let fitted = BayesianRidge::<f64>::new()
            .with_fit_intercept(false)
            .fit(&x, &y)
            .unwrap();

        assert_relative_eq!(fitted.intercept(), 0.0, epsilon = 1e-10);
    }

    // Prediction yields one value per input row.
    #[test]
    fn test_predict_length() {
        let x = Array2::from_shape_vec((5, 1), vec![1.0, 2.0, 3.0, 4.0, 5.0]).unwrap();
        let y = array![2.0, 4.0, 6.0, 8.0, 10.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 5);
    }

    // Predicting with the wrong feature count is rejected.
    #[test]
    fn test_predict_feature_mismatch() {
        let x = Array2::from_shape_vec((3, 2), vec![1.0, 0.0, 2.0, 0.0, 3.0, 0.0]).unwrap();
        let y = array![1.0, 2.0, 3.0];
        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();

        let x_bad = Array2::from_shape_vec((3, 1), vec![1.0, 2.0, 3.0]).unwrap();
        assert!(fitted.predict(&x_bad).is_err());
    }

    // HasCoefficients exposes one coefficient per feature.
    #[test]
    fn test_has_coefficients_length() {
        let x = Array2::from_shape_vec(
            (4, 3),
            vec![1.0, 0.0, 0.5, 2.0, 1.0, 1.0, 3.0, 0.0, 1.5, 4.0, 1.0, 2.0],
        )
        .unwrap();
        let y = array![1.0, 2.0, 3.0, 4.0];
        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        assert_eq!(fitted.coefficients().len(), 3);
    }

    // The type-erased pipeline entry points round-trip fit + predict.
    #[test]
    fn test_pipeline_integration() {
        let x = Array2::from_shape_vec((4, 1), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let y = array![3.0, 5.0, 7.0, 9.0];

        let model = BayesianRidge::<f64>::new();
        let fitted_pipe = model.fit_pipeline(&x, &y).unwrap();
        let preds = fitted_pipe.predict_pipeline(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // Two-feature fit: y = x1 + 2*x2; residuals stay small after fitting.
    #[test]
    fn test_multivariate_fit() {
        let x =
            Array2::from_shape_vec((4, 2), vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 2.0]).unwrap();
        let y = array![1.0, 2.0, 3.0, 6.0];

        let fitted = BayesianRidge::<f64>::new().fit(&x, &y).unwrap();
        let preds = fitted.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
        let residuals: Vec<f64> = preds
            .iter()
            .zip(y.iter())
            .map(|(p, t)| (p - t).abs())
            .collect();
        assert!(residuals.iter().all(|&r| r < 1.0));
    }
}