Skip to main content

blr_core/
ard.rs

1//! BLR+ARD EM fitting, prediction API, and configuration.
2//!
3//! Implements MacKay (1992) / Tipping (2001) empirical Bayes for Bayesian
4//! Linear Regression with Automatic Relevance Determination.
5//!
6//! ## Overview
7//!
8//! The main entry point is [`fit`], which runs an EM loop to find the
9//! posterior distribution over weights and the noise precision hyperparameter β.
10//! ARD places an independent precision hyperparameter `α_d` on each weight;
11//! features with low signal converge to `α_d → ∞`, effectively removing them
12//! from the model.
13//!
14//! After fitting, use [`FittedArd`] to:
15//! - Inspect posterior weight mean and covariance
16//! - Identify active features via [`FittedArd::relevant_features`]
17//! - Predict on new data via [`FittedArd::predict`]
18//!
19//! ## Example: Basic Fit
20//!
21//! ```rust
22//! use blr_core::{fit, ArdConfig};
23//!
24//! // 20 observations, 3 features (row-major feature matrix)
25//! let phi: Vec<f64> = vec![1.0; 60];
26//! let y:   Vec<f64> = vec![0.5; 20];
27//! let config = ArdConfig::default();
28//!
29//! let fitted = fit(&phi, &y, 20, 3, &config)
30//!     .expect("fit should succeed with valid input");
31//!
32//! assert!(fitted.noise_std() > 0.0);
33//! assert_eq!(fitted.relevant_features(None).len(), 3);
34//! ```
35//!
36//! ## Example: Inspect ARD Relevance
37//!
38//! ```rust
39//! use blr_core::{fit, ArdConfig};
40//!
41//! let phi: Vec<f64> = vec![1.0; 60];
42//! let y:   Vec<f64> = vec![0.5; 20];
43//! let fitted = fit(&phi, &y, 20, 3, &ArdConfig::default()).unwrap();
44//!
45//! // relevance() returns 1/αd — larger means more relevant
46//! let rel = fitted.relevance();
47//! println!("Feature relevances: {:?}", rel);
48//!
49//! // relevant_features() returns a boolean mask
50//! let active = fitted.relevant_features(None);
51//! let n_active = active.iter().filter(|&&x| x).count();
52//! println!("{} of {} features are active", n_active, active.len());
53//! ```
54//!
55//! ## EM Algorithm Summary
56//!
57//! Each iteration:
58//!
59//! 1. **E-step**: Compute posterior mean `μ` and covariance `Σ`
60//!    using the current `{α_d, β}`.
61//! 2. **M-step**: Update each `α_d` and optionally β using the posterior
62//!    statistics (gamma updates from MacKay 1992 Eq. 32–33).
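//! 3. **Convergence**: Stop when the change in log-evidence falls below
//!    `ArdConfig::tol` — once four or more values exist, the change is measured
//!    between the means of the last two and the previous two iterations
//!    (a period-2 paired delta) — or after `max_iter` iterations.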
65//!
66//! ## References
67//!
68//! - MacKay, D. J. C. (1992). "Bayesian Interpolation."
69//!   *Neural Computation*, 4(3), 415–447.
70//! - Tipping, M. E. (2001). "Sparse Bayesian Learning and the Relevance Vector Machine."
71//!   *Journal of Machine Learning Research*, 1, 211–244.
72
73use std::f64::consts::PI;
74
75use faer::linalg::{matmul, solvers::Solve};
76use faer::{Accum, Mat, Par, Side};
77
78use crate::gaussian::cholesky_logdet;
79use crate::{BLRError, Gaussian};
80
81// ─── BLRPrior ─────────────────────────────────────────────────────────────────
82
83/// Custom prior for BLR+ARD fitting (batch transfer learning).
84///
85/// Encodes prior knowledge about the parameter distribution of a sensor
86/// ensemble, aggregated from N reference sensors calibrated in Phase 1.
87/// Pass this to [`fit_with_prior`] to accelerate calibration of new
88/// production sensors (Phase 2) using transferred knowledge.
89///
90/// ## Fields
91///
92/// All three vectors must have the same length D (feature dimension).
93///
94/// ## Reference
95///
96/// Berger, Schott, Paul. "Bayesian Sensor Calibration."
97/// IEEE Sensors Journal, Vol. 22, No. 20, October 2022.
98/// Equations (20)–(21): prior mean and covariance from ensemble.
99#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
100pub struct BLRPrior {
101    /// Prior weight mean μ_0 of length D.
102    pub mean: Vec<f64>,
103    /// Prior weight covariance Σ_0 (D×D, symmetric positive-definite, row-major).
104    pub cov: Vec<f64>,
105    /// Prior ARD precision hyperparameters α_0 of length D.
106    pub alphas: Vec<f64>,
107}
108
109impl BLRPrior {
110    /// Validate that dimensions are consistent and covariance is positive-definite.
111    ///
112    /// Checks:
113    /// 1. `mean.len() == alphas.len()` (both equal D).
114    /// 2. `cov.len() == D * D` (square matrix).
115    /// 3. Cholesky factorization of `cov` succeeds (confirms PSD).
116    pub fn validate(&self) -> Result<(), BLRError> {
117        let d = self.mean.len();
118        if self.alphas.len() != d {
119            return Err(BLRError::DimMismatch {
120                expected: d,
121                got: self.alphas.len(),
122            });
123        }
124        if self.cov.len() != d * d {
125            return Err(BLRError::DimMismatch {
126                expected: d * d,
127                got: self.cov.len(),
128            });
129        }
130        if d == 0 {
131            return Err(BLRError::DimMismatch {
132                expected: 1,
133                got: 0,
134            });
135        }
136        // Verify positive-definiteness by attempting Cholesky factorization.
137        let cov_mat = Mat::<f64>::from_fn(d, d, |i, j| self.cov[i * d + j]);
138        cov_mat
139            .llt(Side::Lower)
140            .map_err(|_| BLRError::SingularMatrix)?;
141        Ok(())
142    }
143}
144
145// ─── Configuration ────────────────────────────────────────────────────────────
146
/// Tuning knobs for the BLR+ARD EM fitting loop.
///
/// The [`Default`] values mirror the Python reference implementation:
/// `alpha_init=1.0, beta_init=1.0, max_iter=100, tol=1e-5, update_beta=true`.
#[derive(Debug, Clone)]
pub struct ArdConfig {
    /// Starting value shared by every ARD precision α_j.
    pub alpha_init: f64,
    /// Starting noise precision β = 1/σ².
    pub beta_init: f64,
    /// Upper bound on the number of EM iterations.
    pub max_iter: usize,
    /// Stopping threshold for the period-2 log-evidence delta: the absolute
    /// difference between the mean of the last two log-evidences and the mean
    /// of the two before them (the plain consecutive difference is used while
    /// fewer than four values exist).
    pub tol: f64,
    /// When `true`, β is re-estimated in every M-step; otherwise it stays at
    /// `beta_init`.
    pub update_beta: bool,
}

impl Default for ArdConfig {
    /// Python-reference defaults: unit initial precisions, at most 100
    /// iterations, tolerance 1e-5, and β re-estimation enabled.
    fn default() -> Self {
        let unit_precision = 1.0;
        ArdConfig {
            max_iter: 100,
            tol: 1e-5,
            update_beta: true,
            alpha_init: unit_precision,
            beta_init: unit_precision,
        }
    }
}
176
177// ─── Predictive distribution (per test point) ────────────────────────────────
178
/// Marginal (per-point) predictive distributions for a batch of test points.
///
/// Each point's predictive variance is split into an aleatoric part
/// (observation noise, identical for every point) and an epistemic part
/// (weight uncertainty), matching the output of the Python `predict()`.
///
/// Derives `Debug`, `Clone` and `PartialEq` so results can be logged, copied
/// and compared in tests; every field is plain `f64` data.
#[derive(Debug, Clone, PartialEq)]
pub struct PredictiveMarginals {
    /// Predictive mean E\[y_*\] = φ_* μ for each test point.
    pub mean: Vec<f64>,
    /// Aleatoric std √(1/β) (noise; the same for all points).
    pub aleatoric_std: f64,
    /// Epistemic std √(φ_* Σ φ_*ᵀ) for each test point.
    pub epistemic_std: Vec<f64>,
    /// Total std √(aleatoric² + epistemic²) for each test point.
    pub total_std: Vec<f64>,
}
193
194// ─── Fitted model ─────────────────────────────────────────────────────────────
195
196/// Result of a successful `fit()` call.
197pub struct FittedArd {
198    /// Weight-space posterior N(μ, Σ).
199    pub posterior: Gaussian,
200    /// ARD precision hyperparameters α (D,).
201    pub alpha: Vec<f64>,
202    /// Noise precision β.
203    pub beta: f64,
204    /// Log marginal likelihood per EM iteration.
205    pub log_evidences: Vec<f64>,
206    /// Number of training samples N used to fit this model.
207    pub n_samples: usize,
208}
209
210impl FittedArd {
211    // ── Prediction ──────────────────────────────────────────────────────────
212
213    /// Marginal predictions with decomposed uncertainty.
214    ///
215    /// `phi_test` is the N_test × D feature matrix in row-major order.
216    pub fn predict(
217        &self,
218        phi_test: &[f64],
219        n_test: usize,
220        n_features: usize,
221    ) -> PredictiveMarginals {
222        let d = n_features;
223        let sigma_mat = Mat::<f64>::from_fn(d, d, |i, j| self.posterior.cov[i * d + j]);
224        let mu_col = Mat::<f64>::from_fn(d, 1, |i, _| self.posterior.mean[i]);
225
226        let aleatoric_var = 1.0 / self.beta;
227        let aleatoric_std = aleatoric_var.sqrt();
228
229        let mut mean = Vec::with_capacity(n_test);
230        let mut epistemic_std = Vec::with_capacity(n_test);
231        let mut total_std = Vec::with_capacity(n_test);
232
233        for i in 0..n_test {
234            let phi_row = Mat::<f64>::from_fn(1, d, |_, j| phi_test[i * d + j]);
235
236            // mean[i] = phi_row * mu
237            let mut m_mat = Mat::<f64>::zeros(1, 1);
238            matmul::matmul(
239                m_mat.as_mut(),
240                Accum::Replace,
241                phi_row.as_ref(),
242                mu_col.as_ref(),
243                1.0_f64,
244                Par::Seq,
245            );
246            mean.push(m_mat[(0, 0)]);
247
248            // epistemic_var[i] = phi_row * Sigma * phi_row^T
249            let mut sigma_phi_t = Mat::<f64>::zeros(d, 1);
250            matmul::matmul(
251                sigma_phi_t.as_mut(),
252                Accum::Replace,
253                sigma_mat.as_ref(),
254                phi_row.as_ref().transpose(),
255                1.0_f64,
256                Par::Seq,
257            );
258            let mut ep_var_mat = Mat::<f64>::zeros(1, 1);
259            matmul::matmul(
260                ep_var_mat.as_mut(),
261                Accum::Replace,
262                phi_row.as_ref(),
263                sigma_phi_t.as_ref(),
264                1.0_f64,
265                Par::Seq,
266            );
267            let ep_var = ep_var_mat[(0, 0)].max(0.0);
268            epistemic_std.push(ep_var.sqrt());
269            total_std.push((aleatoric_var + ep_var).sqrt());
270        }
271
272        PredictiveMarginals {
273            mean,
274            aleatoric_std,
275            epistemic_std,
276            total_std,
277        }
278    }
279
280    /// Full joint predictive Gaussian over all M test points.
281    ///
282    /// Returns N(Φ_test μ, Φ_test Σ Φ_test^T + (1/β) I_M).
283    pub fn predict_gaussian(
284        &self,
285        phi_test: &[f64],
286        n_test: usize,
287        n_features: usize,
288    ) -> Result<Gaussian, BLRError> {
289        let d = n_features;
290        let m = n_test;
291
292        let phi_mat = Mat::<f64>::from_fn(m, d, |i, j| phi_test[i * d + j]);
293        let sigma_mat = Mat::<f64>::from_fn(d, d, |i, j| self.posterior.cov[i * d + j]);
294        let mu_col = Mat::<f64>::from_fn(d, 1, |i, _| self.posterior.mean[i]);
295
296        // pred_mean = Φ_test * μ  (M×1)
297        let mut pred_mean_mat = Mat::<f64>::zeros(m, 1);
298        matmul::matmul(
299            pred_mean_mat.as_mut(),
300            Accum::Replace,
301            phi_mat.as_ref(),
302            mu_col.as_ref(),
303            1.0_f64,
304            Par::Seq,
305        );
306
307        // pred_cov = Φ_test * Σ * Φ_test^T + (1/β) I_M  (M×M)
308        // Step 1: tmp = Φ_test * Σ  (M×D)
309        let mut tmp = Mat::<f64>::zeros(m, d);
310        matmul::matmul(
311            tmp.as_mut(),
312            Accum::Replace,
313            phi_mat.as_ref(),
314            sigma_mat.as_ref(),
315            1.0_f64,
316            Par::Seq,
317        );
318        // Step 2: pred_cov = tmp * Φ_test^T  (M×M)
319        let mut pred_cov = Mat::<f64>::zeros(m, m);
320        matmul::matmul(
321            pred_cov.as_mut(),
322            Accum::Replace,
323            tmp.as_ref(),
324            phi_mat.as_ref().transpose(),
325            1.0_f64,
326            Par::Seq,
327        );
328        // Step 3: add noise + jitter to diagonal
329        let noise_var = 1.0 / self.beta;
330        for i in 0..m {
331            pred_cov[(i, i)] += noise_var + 1e-9; // jitter for PSD guarantee
332        }
333
334        let pred_cov_ref = pred_cov.as_ref();
335        let pred_mean_vec: Vec<f64> = (0..m).map(|i| pred_mean_mat[(i, 0)]).collect();
336        let pred_cov_vec: Vec<f64> = (0..m)
337            .flat_map(|i| (0..m).map(move |j| pred_cov_ref[(i, j)]))
338            .collect();
339
340        Gaussian::new(pred_mean_vec, pred_cov_vec)
341    }
342
343    // ── Interpretability ────────────────────────────────────────────────────
344
345    /// Feature relevance scores: 1/α_j (higher = more relevant).
346    pub fn relevance(&self) -> Vec<f64> {
347        self.alpha.iter().map(|a| 1.0 / a).collect()
348    }
349
350    /// Boolean mask: `true` where feature j is relevant (α_j < threshold).
351    ///
352    /// Default threshold = geometric mean of α (exp(mean(ln(α_j)))).
353    ///
354    /// TODO: replace geometric mean with median for a more robust heuristic
355    /// in a future iteration.
356    pub fn relevant_features(&self, threshold: Option<f64>) -> Vec<bool> {
357        let t = threshold.unwrap_or_else(|| {
358            let ln_mean = self.alpha.iter().map(|a| a.ln()).sum::<f64>() / self.alpha.len() as f64;
359            ln_mean.exp()
360        });
361        self.alpha.iter().map(|a| *a < t).collect()
362    }
363
364    // ── Summary scalars ─────────────────────────────────────────────────────
365
366    /// Noise standard deviation 1/√β.
367    pub fn noise_std(&self) -> f64 {
368        1.0 / self.beta.sqrt()
369    }
370
371    /// Log marginal likelihood at the last EM iteration.
372    pub fn log_marginal_likelihood(&self) -> f64 {
373        *self.log_evidences.last().unwrap_or(&f64::NEG_INFINITY)
374    }
375
376    // ── Active Learning API ─────────────────────────────────────────────────
377
378    /// Noise precision β accessor (= 1/σ²_noise).
379    pub fn noise_precision(&self) -> f64 {
380        self.beta
381    }
382
383    /// Posterior covariance Σ as a flat row-major D×D slice.
384    pub fn posterior_covariance(&self) -> &[f64] {
385        &self.posterior.cov
386    }
387
388    /// Number of training samples N used during fitting.
389    pub fn sample_count(&self) -> usize {
390        self.n_samples
391    }
392
393    /// Posterior standard deviations for arbitrary test points.
394    ///
395    /// # Arguments
396    /// - `phi_test`: N_test × D feature matrix, row-major flat slice
397    /// - `n_test`: number of test points
398    /// - `n_features`: feature dimension D (must match the training feature dim)
399    ///
400    /// # Returns
401    /// Vec of length `n_test` with posterior std for each point.
402    pub fn posterior_std(&self, phi_test: &[f64], n_test: usize, n_features: usize) -> Vec<f64> {
403        let d = n_features;
404        let sigma_cov = &self.posterior.cov;
405        let noise_var = 1.0 / self.beta.max(1e-10);
406        (0..n_test)
407            .map(|i| {
408                let phi_i = &phi_test[i * d..(i + 1) * d];
409                let mut sigma_phi = vec![0.0_f64; d];
410                for row in 0..d {
411                    for col in 0..d {
412                        sigma_phi[row] += sigma_cov[row * d + col] * phi_i[col];
413                    }
414                }
415                let epistemic: f64 = phi_i.iter().zip(sigma_phi.iter()).map(|(a, b)| a * b).sum();
416                (noise_var + epistemic.max(0.0)).sqrt()
417            })
418            .collect()
419    }
420
421    /// Posterior std evaluated on a uniform 1-D input grid.
422    ///
423    /// # Arguments
424    /// - `input_range`: (min, max) of the input domain
425    /// - `resolution`: number of grid points (≥ 2)
426    /// - `feature_fn`: maps a scalar input to a feature vector of length D
427    ///
428    /// # Returns
429    /// `(grid_points, std_devs)` — both of length `resolution`.
430    pub fn posterior_std_grid(
431        &self,
432        input_range: (f64, f64),
433        resolution: usize,
434        feature_fn: &dyn Fn(f64) -> Vec<f64>,
435    ) -> (Vec<f64>, Vec<f64>) {
436        let d_sq = self.posterior.cov.len();
437        let d = (d_sq as f64).sqrt() as usize;
438        let resolution = resolution.max(2);
439        let step = (input_range.1 - input_range.0) / (resolution - 1) as f64;
440        let grid: Vec<f64> = (0..resolution)
441            .map(|k| input_range.0 + k as f64 * step)
442            .collect();
443        let mut phi_grid = Vec::with_capacity(resolution * d);
444        for &x in &grid {
445            let feats = feature_fn(x);
446            let actual = feats.len().min(d);
447            phi_grid.extend_from_slice(&feats[..actual]);
448            if actual < d {
449                phi_grid.extend(std::iter::repeat(0.0).take(d - actual));
450            }
451        }
452        let stds = self.posterior_std(&phi_grid, resolution, d);
453        (grid, stds)
454    }
455}
456
457// ─── Log evidence helper ──────────────────────────────────────────────────────
458
/// Log marginal likelihood (type-II evidence) of the BLR+ARD model.
///
/// Mirrors the Python `_log_evidence` reference, including its extra
/// `+D·ln(2π)/2` constant. That constant shifts every value by the same
/// amount, so differences between successive values are unaffected.
///
/// ```text
/// L = ½ (Σ_j ln α_j + N ln β − ln|Σ⁻¹| − β‖r‖² − Σ_j α_j μ_j² + D ln 2π) − ½ N ln 2π
/// ```
///
/// # Arguments
/// - `n`: number of observations N.
/// - `d`: number of features D (only enters the constant term).
/// - `alpha`: ARD precisions α_j.
/// - `beta`: noise precision β.
/// - `mu`: posterior mean μ, paired element-wise with `alpha`.
/// - `logdet_sigma_inv`: ln|Σ⁻¹|.
/// - `residual_sq`: squared residual norm ‖y − Φμ‖².
fn log_evidence(
    n: usize,
    d: usize,
    alpha: &[f64],
    beta: f64,
    mu: &[f64],
    logdet_sigma_inv: f64,
    residual_sq: f64,
) -> f64 {
    let ln_two_pi = (2.0 * PI).ln();
    let n_obs = n as f64;
    let n_feat = d as f64;

    let sum_ln_alpha = alpha.iter().map(|&a| a.ln()).sum::<f64>();
    // Weight penalty μᵀ diag(α) μ.
    let weight_penalty = mu
        .iter()
        .zip(alpha)
        .map(|(&m, &a)| a * m * m)
        .sum::<f64>();

    let bracket = sum_ln_alpha + n_obs * beta.ln()
        - logdet_sigma_inv
        - beta * residual_sq
        - weight_penalty
        + n_feat * ln_two_pi;
    0.5 * bracket - 0.5 * n_obs * ln_two_pi
}
483
484// ─── Fit entry point ──────────────────────────────────────────────────────────
485
486/// Fit BLR+ARD via EM (Type-II maximum likelihood / empirical Bayes).
487///
488/// # Arguments
489/// - `phi`: N×D feature matrix, row-major.
490/// - `y`: N target values.
491/// - `n`: number of training points (rows of phi).
492/// - `d`: number of features (columns of phi).
493/// - `config`: fitting hyperparameters.
494///
495/// # Returns
496/// `Ok(FittedArd)` on success, `Err(BLRError)` if a matrix inversion fails.
497pub fn fit(
498    phi: &[f64],
499    y: &[f64],
500    n: usize,
501    d: usize,
502    config: &ArdConfig,
503) -> Result<FittedArd, BLRError> {
504    if phi.len() != n * d {
505        return Err(BLRError::DimMismatch {
506            expected: n * d,
507            got: phi.len(),
508        });
509    }
510    if y.len() != n {
511        return Err(BLRError::DimMismatch {
512            expected: n,
513            got: y.len(),
514        });
515    }
516
517    let phi_mat = Mat::<f64>::from_fn(n, d, |i, j| phi[i * d + j]);
518    let y_mat = Mat::<f64>::from_fn(n, 1, |i, _| y[i]);
519
520    // Pre-compute Φᵀ Φ (D×D) and Φᵀ y (D×1) — reused every iteration.
521    let mut phi_t_phi = Mat::<f64>::zeros(d, d);
522    matmul::matmul(
523        phi_t_phi.as_mut(),
524        Accum::Replace,
525        phi_mat.as_ref().transpose(),
526        phi_mat.as_ref(),
527        1.0_f64,
528        Par::Seq,
529    );
530
531    let mut phi_t_y = Mat::<f64>::zeros(d, 1);
532    matmul::matmul(
533        phi_t_y.as_mut(),
534        Accum::Replace,
535        phi_mat.as_ref().transpose(),
536        y_mat.as_ref(),
537        1.0_f64,
538        Par::Seq,
539    );
540
541    // Initialise hyperparameters.
542    let mut alpha = vec![config.alpha_init; d];
543    let mut beta = config.beta_init;
544    let mut log_evidences: Vec<f64> = Vec::new();
545
546    // Working storage reused across iterations.
547    let mut sigma_mat = Mat::<f64>::zeros(d, d);
548    let mut mu_vec = vec![0.0_f64; d];
549
550    for _iter in 0..config.max_iter {
551        // ── E-step ────────────────────────────────────────────────────────
552        // σ_inv = diag(α) + β Φᵀ Φ
553        let mut sigma_inv = Mat::<f64>::from_fn(d, d, |i, j| beta * phi_t_phi[(i, j)]);
554        for j in 0..d {
555            sigma_inv[(j, j)] += alpha[j];
556        }
557
558        // Cholesky: L Lᵀ = σ_inv
559        let llt = sigma_inv
560            .llt(Side::Lower)
561            .map_err(|_| BLRError::SingularMatrix)?;
562
563        // Σ = σ_inv⁻¹  (solve with identity)
564        let eye = Mat::<f64>::identity(d, d);
565        sigma_mat = llt.solve(eye.as_ref());
566
567        // μ = β Σ Φᵀ y  (solve σ_inv · μ = β Φᵀ y)
568        let mut rhs = phi_t_y.clone();
569        for i in 0..d {
570            rhs[(i, 0)] *= beta;
571        }
572        let mu_mat = llt.solve(rhs.as_ref());
573        for i in 0..d {
574            mu_vec[i] = mu_mat[(i, 0)];
575        }
576
577        // Log-determinant of σ_inv via manual Cholesky diagonal
578        let logdet_sigma_inv = cholesky_logdet(&sigma_inv, d)?;
579
580        // ── Residuals (needed for log-evidence and β update) ──────────────
581        let mut phi_mu = Mat::<f64>::zeros(n, 1);
582        let mu_mat_ref = Mat::<f64>::from_fn(d, 1, |i, _| mu_vec[i]);
583        matmul::matmul(
584            phi_mu.as_mut(),
585            Accum::Replace,
586            phi_mat.as_ref(),
587            mu_mat_ref.as_ref(),
588            1.0_f64,
589            Par::Seq,
590        );
591        let residual_sq: f64 = (0..n)
592            .map(|i| {
593                let r = y[i] - phi_mu[(i, 0)];
594                r * r
595            })
596            .sum();
597
598        // ── M-step ────────────────────────────────────────────────────────
599        // γ_j = 1 − α_j Σ_jj     (effective parameters per feature)
600        let gamma: Vec<f64> = (0..d).map(|j| 1.0 - alpha[j] * sigma_mat[(j, j)]).collect();
601
602        // α_j = γ_j / (μ_j² + ε),  clamp to ≥ 1e-8
603        for j in 0..d {
604            alpha[j] = (gamma[j] / (mu_vec[j] * mu_vec[j] + 1e-10)).max(1e-8);
605        }
606
607        // β = (N − Σγ_j) / (||r||² + ε),  clamp to ≥ 1e-8
608        if config.update_beta {
609            let gamma_sum: f64 = gamma.iter().sum();
610            beta = ((n as f64 - gamma_sum) / (residual_sq + 1e-10)).max(1e-8);
611        }
612
613        // ── Log evidence — computed after M-step (matches Python ordering) ─
614        // Uses updated alpha/beta but E-step logdet_sigma_inv and mu from this iter.
615        let lml = log_evidence(n, d, &alpha, beta, &mu_vec, logdet_sigma_inv, residual_sq);
616        log_evidences.push(lml);
617
618        // ── Convergence: period-2 paired log-evidence delta ───────────────
619        let n_ev = log_evidences.len();
620        let delta = if n_ev >= 4 {
621            let mean_curr = 0.5 * (log_evidences[n_ev - 1] + log_evidences[n_ev - 2]);
622            let mean_prev = 0.5 * (log_evidences[n_ev - 3] + log_evidences[n_ev - 4]);
623            (mean_curr - mean_prev).abs()
624        } else if n_ev >= 2 {
625            (log_evidences[n_ev - 1] - log_evidences[n_ev - 2]).abs()
626        } else {
627            f64::INFINITY
628        };
629
630        if delta < config.tol {
631            break;
632        }
633    }
634
635    // Build final posterior Gaussian.
636    let mu_final: Vec<f64> = mu_vec.clone();
637    let cov_final: Vec<f64> = {
638        let sigma_ref = sigma_mat.as_ref();
639        (0..d)
640            .flat_map(|i| (0..d).map(move |j| sigma_ref[(i, j)]))
641            .collect()
642    };
643    let posterior = Gaussian::new(mu_final, cov_final)?;
644
645    Ok(FittedArd {
646        posterior,
647        alpha,
648        beta,
649        log_evidences,
650        n_samples: n,
651    })
652}
653
654// ─── fit_with_prior entry point ───────────────────────────────────────────────
655
656/// Fit BLR+ARD with an optional informed prior (batch transfer learning).
657///
658/// When `prior` is `Some`, the EM loop initialises weight mean and ARD alphas
659/// from the prior values rather than the `ArdConfig` defaults. This allows
660/// knowledge from a reference batch of sensors to accelerate convergence on a
661/// new production sensor.
662///
663/// When `prior` is `None`, this function is numerically equivalent to `fit()`.
664///
665/// # Arguments
666/// - `phi`: N×D feature matrix, row-major.
667/// - `y`: N target values.
668/// - `n`: number of training points (rows of phi).
669/// - `d`: number of features (columns of phi).
670/// - `config`: fitting hyperparameters.
671/// - `prior`: optional `BLRPrior` from a reference batch.
672///
673/// # Returns
674/// `Ok(FittedArd)` on success, `Err(BLRError)` if input validation or
675/// matrix operations fail.
676pub fn fit_with_prior(
677    phi: &[f64],
678    y: &[f64],
679    n: usize,
680    d: usize,
681    config: &ArdConfig,
682    prior: Option<&BLRPrior>,
683) -> Result<FittedArd, BLRError> {
684    if phi.len() != n * d {
685        return Err(BLRError::DimMismatch {
686            expected: n * d,
687            got: phi.len(),
688        });
689    }
690    if y.len() != n {
691        return Err(BLRError::DimMismatch {
692            expected: n,
693            got: y.len(),
694        });
695    }
696
697    // Validate prior dimensions if provided.
698    if let Some(p) = prior {
699        p.validate()?;
700        if p.mean.len() != d {
701            return Err(BLRError::DimMismatch {
702                expected: d,
703                got: p.mean.len(),
704            });
705        }
706    }
707
708    let phi_mat = Mat::<f64>::from_fn(n, d, |i, j| phi[i * d + j]);
709    let y_mat = Mat::<f64>::from_fn(n, 1, |i, _| y[i]);
710
711    // Pre-compute Φᵀ Φ (D×D) and Φᵀ y (D×1) — reused every iteration.
712    let mut phi_t_phi = Mat::<f64>::zeros(d, d);
713    matmul::matmul(
714        phi_t_phi.as_mut(),
715        Accum::Replace,
716        phi_mat.as_ref().transpose(),
717        phi_mat.as_ref(),
718        1.0_f64,
719        Par::Seq,
720    );
721
722    let mut phi_t_y = Mat::<f64>::zeros(d, 1);
723    matmul::matmul(
724        phi_t_y.as_mut(),
725        Accum::Replace,
726        phi_mat.as_ref().transpose(),
727        y_mat.as_ref(),
728        1.0_f64,
729        Par::Seq,
730    );
731
732    // Initialise hyperparameters from prior or config defaults.
733    let mut alpha: Vec<f64> = prior
734        .map(|p| p.alphas.clone())
735        .unwrap_or_else(|| vec![config.alpha_init; d]);
736    let mut beta = config.beta_init;
737    let mut log_evidences: Vec<f64> = Vec::new();
738
739    // Working storage reused across iterations.
740    let mut sigma_mat = Mat::<f64>::zeros(d, d);
741    // Initialise mu from prior mean if available, else zeros.
742    let mut mu_vec: Vec<f64> = prior
743        .map(|p| p.mean.clone())
744        .unwrap_or_else(|| vec![0.0f64; d]);
745
746    for _iter in 0..config.max_iter {
747        // ── E-step ────────────────────────────────────────────────────────
748        // σ_inv = diag(α) + β Φᵀ Φ
749        let mut sigma_inv = Mat::<f64>::from_fn(d, d, |i, j| beta * phi_t_phi[(i, j)]);
750        for j in 0..d {
751            sigma_inv[(j, j)] += alpha[j];
752        }
753
754        // Cholesky: L Lᵀ = σ_inv
755        let llt = sigma_inv
756            .llt(Side::Lower)
757            .map_err(|_| BLRError::SingularMatrix)?;
758
759        // Σ = σ_inv⁻¹  (solve with identity)
760        let eye = Mat::<f64>::identity(d, d);
761        sigma_mat = llt.solve(eye.as_ref());
762
763        // μ = β Σ Φᵀ y  (solve σ_inv · μ = β Φᵀ y)
764        let mut rhs = phi_t_y.clone();
765        for i in 0..d {
766            rhs[(i, 0)] *= beta;
767        }
768        let mu_mat = llt.solve(rhs.as_ref());
769        for i in 0..d {
770            mu_vec[i] = mu_mat[(i, 0)];
771        }
772
773        // Log-determinant of σ_inv via manual Cholesky diagonal
774        let logdet_sigma_inv = cholesky_logdet(&sigma_inv, d)?;
775
776        // ── Residuals (needed for log-evidence and β update) ──────────────
777        let mut phi_mu = Mat::<f64>::zeros(n, 1);
778        let mu_mat_ref = Mat::<f64>::from_fn(d, 1, |i, _| mu_vec[i]);
779        matmul::matmul(
780            phi_mu.as_mut(),
781            Accum::Replace,
782            phi_mat.as_ref(),
783            mu_mat_ref.as_ref(),
784            1.0_f64,
785            Par::Seq,
786        );
787        let residual_sq: f64 = (0..n)
788            .map(|i| {
789                let r = y[i] - phi_mu[(i, 0)];
790                r * r
791            })
792            .sum();
793
794        // ── M-step ────────────────────────────────────────────────────────
795        let gamma: Vec<f64> = (0..d).map(|j| 1.0 - alpha[j] * sigma_mat[(j, j)]).collect();
796
797        for j in 0..d {
798            alpha[j] = (gamma[j] / (mu_vec[j] * mu_vec[j] + 1e-10)).max(1e-8);
799        }
800
801        if config.update_beta {
802            let gamma_sum: f64 = gamma.iter().sum();
803            beta = ((n as f64 - gamma_sum) / (residual_sq + 1e-10)).max(1e-8);
804        }
805
806        let lml = log_evidence(n, d, &alpha, beta, &mu_vec, logdet_sigma_inv, residual_sq);
807        log_evidences.push(lml);
808
809        let n_ev = log_evidences.len();
810        let delta = if n_ev >= 4 {
811            let mean_curr = 0.5 * (log_evidences[n_ev - 1] + log_evidences[n_ev - 2]);
812            let mean_prev = 0.5 * (log_evidences[n_ev - 3] + log_evidences[n_ev - 4]);
813            (mean_curr - mean_prev).abs()
814        } else if n_ev >= 2 {
815            (log_evidences[n_ev - 1] - log_evidences[n_ev - 2]).abs()
816        } else {
817            f64::INFINITY
818        };
819
820        if delta < config.tol {
821            break;
822        }
823    }
824
825    let mu_final: Vec<f64> = mu_vec.clone();
826    let cov_final: Vec<f64> = {
827        let sigma_ref = sigma_mat.as_ref();
828        (0..d)
829            .flat_map(|i| (0..d).map(move |j| sigma_ref[(i, j)]))
830            .collect()
831    };
832    let posterior = Gaussian::new(mu_final, cov_final)?;
833
834    Ok(FittedArd {
835        posterior,
836        alpha,
837        beta,
838        log_evidences,
839        n_samples: n,
840    })
841}
#[cfg(test)]
mod tests {
    use super::*;

    /// 5×3 row-major design matrix shared by several tests. Columns 0 and 1
    /// are exactly proportional, so the problem is strongly collinear.
    const PHI_5X3: [f64; 15] = [
        1.0, 0.5, 0.25, 2.0, 1.0, 0.5, 0.5, 0.25, 0.125, 1.5, 0.75, 0.3, 0.8, 0.4, 0.2,
    ];
    const Y_5: [f64; 5] = [1.0, 2.0, 0.5, 1.5, 0.8];

    /// Relative closeness for values produced by different summation orders.
    fn close(a: f64, b: f64) -> bool {
        (a - b).abs() <= 1e-9 * (1.0 + a.abs().max(b.abs()))
    }

    #[test]
    fn test_ard_config_defaults() {
        let cfg = ArdConfig::default();
        assert_eq!(cfg.alpha_init, 1.0);
        assert_eq!(cfg.beta_init, 1.0);
        assert_eq!(cfg.max_iter, 100);
        assert_eq!(cfg.tol, 1e-5);
        assert!(cfg.update_beta);
    }

    #[test]
    fn test_log_evidence_helper() {
        // 0.5·(0 + 0 − 5 − 2 − 0 + 3 ln 2π) − 5 ln 2π = −3.5 − 3.5 ln 2π
        let lml = log_evidence(10, 3, &[1.0; 3], 1.0, &[0.0; 3], 5.0, 2.0);
        let expected = -3.5 - 3.5 * (2.0 * PI).ln();
        assert!((lml - expected).abs() < 1e-12, "log_evidence = {lml}");
    }

    #[test]
    fn test_blr_prior_valid() {
        let d = 3;
        let prior = BLRPrior {
            mean: vec![0.0; d],
            cov: vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], // identity
            alphas: vec![1.0; d],
        };
        assert!(prior.validate().is_ok());
    }

    #[test]
    fn test_blr_prior_invalid_dimensions() {
        let prior = BLRPrior {
            mean: vec![0.0; 3],
            cov: vec![1.0, 0.0, 0.0, 1.0], // 2×2, should be 3×3
            alphas: vec![1.0; 3],
        };
        assert!(prior.validate().is_err());
    }

    #[test]
    fn test_blr_prior_empty_rejected() {
        let prior = BLRPrior {
            mean: vec![],
            cov: vec![],
            alphas: vec![],
        };
        assert!(matches!(
            prior.validate(),
            Err(BLRError::DimMismatch { expected: 1, got: 0 })
        ));
    }

    #[test]
    fn test_blr_prior_not_psd() {
        let d = 2;
        let prior = BLRPrior {
            mean: vec![0.0; d],
            cov: vec![-1.0, 0.0, 0.0, -1.0], // negative diagonal → not positive-definite
            alphas: vec![1.0; d],
        };
        assert!(matches!(prior.validate(), Err(BLRError::SingularMatrix)));
    }

    #[test]
    fn test_fit_rejects_mismatched_inputs() {
        let cfg = ArdConfig::default();
        assert!(matches!(
            fit(&PHI_5X3, &Y_5[..4], 5, 3, &cfg),
            Err(BLRError::DimMismatch { expected: 5, got: 4 })
        ));
        assert!(matches!(
            fit(&PHI_5X3[..14], &Y_5, 5, 3, &cfg),
            Err(BLRError::DimMismatch { expected: 15, got: 14 })
        ));
    }

    #[test]
    fn test_fit_with_prior_rejects_wrong_prior_dim() {
        let prior = BLRPrior {
            mean: vec![0.0; 2],
            cov: vec![1.0, 0.0, 0.0, 1.0],
            alphas: vec![1.0; 2],
        };
        let r = fit_with_prior(&PHI_5X3, &Y_5, 5, 3, &ArdConfig::default(), Some(&prior));
        assert!(matches!(r, Err(BLRError::DimMismatch { expected: 3, got: 2 })));
    }

    #[test]
    fn test_fit_with_prior_none_equals_fit() {
        // fit_with_prior(None) must be numerically equivalent to fit().
        let phi: Vec<f64> = vec![1.0, 0.5, 0.25, 2.0, 1.0, 0.5, 0.5, 0.25, 0.125];
        let y: Vec<f64> = vec![1.0, 2.0, 0.5];
        let config = ArdConfig::default();

        let r1 = fit(&phi, &y, 3, 3, &config).unwrap();
        let r2 = fit_with_prior(&phi, &y, 3, 3, &config, None).unwrap();

        assert_eq!(r1.alpha.len(), r2.alpha.len());
        for (a1, a2) in r1.alpha.iter().zip(r2.alpha.iter()) {
            assert!((a1 - a2).abs() < 1e-10, "alpha mismatch: {a1} vs {a2}");
        }
        assert!((r1.beta - r2.beta).abs() < 1e-10);
        assert_eq!(r1.log_evidences.len(), r2.log_evidences.len());
        for (m1, m2) in r1.posterior.mean.iter().zip(r2.posterior.mean.iter()) {
            assert!((m1 - m2).abs() < 1e-10, "mean mismatch: {m1} vs {m2}");
        }
        for (c1, c2) in r1.posterior.cov.iter().zip(r2.posterior.cov.iter()) {
            assert!((c1 - c2).abs() < 1e-10, "cov mismatch: {c1} vs {c2}");
        }
    }

    #[test]
    fn test_fit_with_prior_some_compiles_and_runs() {
        let d = 3;
        let config = ArdConfig::default();

        let prior = BLRPrior {
            mean: vec![0.5; d],
            cov: vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0],
            alphas: vec![0.5; d],
        };
        let result = fit_with_prior(&PHI_5X3, &Y_5, 5, d, &config, Some(&prior));
        assert!(
            result.is_ok(),
            "fit_with_prior should succeed: {:?}",
            result.err()
        );
        let fitted = result.unwrap();
        assert!(fitted.noise_std() > 0.0);
        assert_eq!(fitted.alpha.len(), d);
    }

    #[test]
    fn test_fit_with_prior_convergence_faster() {
        let d = 3;
        let n = 5;
        // Tight tolerance so that differences in convergence speed are visible.
        let config = ArdConfig {
            max_iter: 200,
            tol: 1e-9,
            ..ArdConfig::default()
        };

        let baseline = fit_with_prior(&PHI_5X3, &Y_5, n, d, &config, None).unwrap();

        // Of the prior's fields, only `alphas` affects the EM loop (as the
        // initial α). Seeding it with the converged baseline α should
        // therefore need no more iterations than the uninformed run.
        let prior = BLRPrior {
            mean: baseline.posterior.mean.clone(),
            cov: baseline.posterior.cov.clone(),
            alphas: baseline.alpha.clone(),
        };
        let informed = fit_with_prior(&PHI_5X3, &Y_5, n, d, &config, Some(&prior)).unwrap();

        assert!(informed.noise_std() > 0.0);
        assert!(
            informed.log_evidences.len() <= baseline.log_evidences.len(),
            "informed iterations {} should be <= baseline iterations {}",
            informed.log_evidences.len(),
            baseline.log_evidences.len()
        );
    }

    #[test]
    fn test_relevance_and_relevant_features() {
        let fitted = fit(&PHI_5X3, &Y_5, 5, 3, &ArdConfig::default()).unwrap();

        let rel = fitted.relevance();
        assert_eq!(rel.len(), fitted.alpha.len());
        for (r, a) in rel.iter().zip(&fitted.alpha) {
            assert_eq!(*r, 1.0 / a);
        }

        // Explicit threshold: strict `α_j < t`.
        let t = fitted.alpha[0];
        let mask = fitted.relevant_features(Some(t));
        let expected: Vec<bool> = fitted.alpha.iter().map(|a| *a < t).collect();
        assert_eq!(mask, expected);
        assert!(!mask[0]);
    }

    #[test]
    fn test_prediction_paths_agree() {
        let fitted = fit(&PHI_5X3, &Y_5, 5, 3, &ArdConfig::default()).unwrap();
        let phi_test = &PHI_5X3[..6]; // first two rows

        let marg = fitted.predict(phi_test, 2, 3);
        let stds = fitted.posterior_std(phi_test, 2, 3);
        let joint = fitted.predict_gaussian(phi_test, 2, 3).unwrap();

        assert!(close(marg.aleatoric_std, fitted.noise_std()));
        for i in 0..2 {
            let row = &phi_test[i * 3..(i + 1) * 3];
            let manual_mean: f64 = row
                .iter()
                .zip(&fitted.posterior.mean)
                .map(|(p, m)| p * m)
                .sum();
            assert!(close(marg.mean[i], manual_mean));
            assert!(close(joint.mean[i], marg.mean[i]));
            assert!(close(marg.total_std[i], stds[i]));
            // predict_gaussian adds a 1e-9 jitter on the diagonal.
            assert!(close(joint.cov[i * 2 + i], marg.total_std[i].powi(2) + 1e-9));
        }
    }
}