use blr_core::{fit, ArdConfig};
/// Python reference posterior mean, one entry per `build_phi` feature column.
const REF_MU: [f64; 6] = [
    3.318614945429841e-06,
    1.499444429251276e-01,
    -1.356245421649093e-05,
    7.494213126776456e-06,
    -9.342072600966224e-06,
    4.298196216852146e+00,
];
/// Python reference diagonal of the posterior covariance, one entry per
/// `build_phi` feature column.
const REF_DIAG_SIGMA: [f64; 6] = [
    3.681735557898095e-06,
    4.810623773542053e-02,
    5.030677519628666e-07,
    1.595545582086612e-07,
    3.381798564590973e-05,
    6.918883836481319e-01,
];
/// Python reference per-weight `alpha` hyperparameters, one entry per
/// `build_phi` feature column.
const REF_ALPHA: [f64; 6] = [
    2.724_259_807_119_9e5,
    1.416704911931874e+01,
    1.986686948228129e+06,
    6.261289835045856e+06,
    2.956906379257813e+04,
    5.217467162490756e-02,
];
/// Python reference value of the fitted `beta` hyperparameter.
const REF_BETA: f64 = 3.373685356006518e-01;

/// Python reference log evidence, compared against `log_marginal_likelihood()`.
/// The `_LAST` suffix presumably means the value from the final iteration of
/// the reference run — confirm against the script that produced it.
const REF_LOG_EV_LAST: f64 = -4.611079075240905e+01;
/// Builds the flattened, row-major design matrix used by the parity test.
///
/// Every sample `x_i` becomes one row of six features:
/// `[1, x_i, x_i^2, x_i^3, tanh(x_i / 0.8), tanh(x_i / 1.5)]`.
///
/// Returns the `x.len() * 6` feature buffer together with the column count (6).
fn build_phi(x: &[f64]) -> (Vec<f64>, usize) {
    const FEATURES: usize = 6;

    let design: Vec<f64> = x
        .iter()
        .flat_map(|&xi| {
            let row: [f64; FEATURES] = [
                1.0,
                xi,
                xi.powi(2),
                xi.powi(3),
                (xi / 0.8).tanh(),
                (xi / 1.5).tanh(),
            ];
            row
        })
        .collect();

    (design, FEATURES)
}
/// Loads the `(x, y)` samples from `data/lin_data.csv` under this crate's
/// manifest directory.
///
/// The first line is skipped as a header. Lines with fewer than two
/// comma-separated fields are ignored. Otherwise the first two fields are
/// trimmed and parsed as `f64`.
///
/// # Panics
/// Panics if the file cannot be read, if a field does not parse as `f64`, or
/// if the number of parsed rows is not exactly 25.
fn load_lin_data() -> (Vec<f64>, Vec<f64>) {
    let csv_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("data")
        .join("lin_data.csv");
    let content = std::fs::read_to_string(&csv_path)
        .unwrap_or_else(|e| panic!("Cannot read {}: {e}", csv_path.display()));

    let (xs, ys): (Vec<f64>, Vec<f64>) = content
        .lines()
        .skip(1)
        .filter_map(|line| {
            // Rows without at least an x and a y field are ignored.
            let mut fields = line.split(',');
            let x_field = fields.next()?;
            let y_field = fields.next()?;
            Some((x_field, y_field))
        })
        .map(|(x_field, y_field)| {
            let xi: f64 = x_field.trim().parse().expect("parse x");
            let yi: f64 = y_field.trim().parse().expect("parse y");
            (xi, yi)
        })
        .unzip();

    assert_eq!(xs.len(), 25, "Expected 25 rows in lin_data.csv");
    (xs, ys)
}
/// End-to-end parity check of `blr_core::fit` against Python reference values.
///
/// Fits the ARD model to `data/lin_data.csv` using the six-column basis from
/// `build_phi`. It then compares the following against the `REF_*` constants:
/// - the posterior mean
/// - the diagonal of the posterior covariance
/// - the per-weight `alpha`
/// - `beta`
/// - the log marginal likelihood
///
/// Tolerances:
/// - `mu`: absolute 1e-6
/// - `diag_sigma`: absolute 1e-5
/// - `alpha`: relative 1e-4
/// - `beta`: relative 1e-3
/// - log marginal likelihood: absolute 1e-4
///
/// NOTE(review): the absolute 1e-5 bound on `diag_sigma` is larger than
/// `REF_DIAG_SIGMA[0]`, `[2]` and `[3]` (all below 4e-6), so those entries are
/// effectively unchecked. Likewise, 1e-6 on `mu` is ~30% of `REF_MU[0]`.
/// Consider a mixed absolute/relative tolerance once the achievable agreement
/// has been measured.
#[test]
fn test_numerical_parity_hall_sensor() {
    let (x, y) = load_lin_data();
    let (phi, d) = build_phi(&x);
    let n = y.len();

    // Every reference vector is indexed by feature column. If the basis changes,
    // fail here with a clear message. Otherwise the result is an out-of-bounds
    // panic in one loop and a silently shortened comparison in another.
    assert_eq!(d, REF_MU.len(), "feature count does not match REF_MU");
    assert_eq!(
        d,
        REF_DIAG_SIGMA.len(),
        "feature count does not match REF_DIAG_SIGMA"
    );
    assert_eq!(d, REF_ALPHA.len(), "feature count does not match REF_ALPHA");

    let cfg = ArdConfig {
        alpha_init: 1.0,
        beta_init: 1.0,
        max_iter: 2000,
        tol: 1e-4,
        update_beta: true,
    };
    let model = fit(&phi, &y, n, d, &cfg).expect("fit failed");

    let mu = &model.posterior.mean;
    // The covariance is read as a flat `d * d` buffer; take its diagonal.
    let diag_sigma: Vec<f64> = (0..d).map(|i| model.posterior.cov[i * d + i]).collect();

    for (j, &expected_mu) in REF_MU.iter().enumerate() {
        let diff = (mu[j] - expected_mu).abs();
        assert!(
            diff < 1e-6,
            "mu[{j}]: Rust={:.6e} Python={:.6e} diff={diff:.2e}",
            mu[j],
            expected_mu
        );
    }

    for (j, &expected_sigma) in REF_DIAG_SIGMA.iter().enumerate() {
        let diff = (diag_sigma[j] - expected_sigma).abs();
        assert!(
            diff < 1e-5,
            "diag_sigma[{j}]: Rust={:.6e} Python={:.6e} diff={diff:.2e}",
            diag_sigma[j],
            expected_sigma
        );
    }

    for (j, &expected_alpha) in REF_ALPHA.iter().enumerate() {
        // alpha values span ~1e-2..1e7, so compare relatively. The floor only
        // guards against division by zero.
        let rel = (model.alpha[j] - expected_alpha).abs() / expected_alpha.max(1e-30);
        assert!(
            rel < 1e-4,
            "alpha[{j}]: Rust={:.6e} Python={:.6e} rel={rel:.2e}",
            model.alpha[j],
            expected_alpha
        );
    }

    let beta_rel = (model.beta - REF_BETA).abs() / REF_BETA;
    assert!(
        beta_rel < 1e-3,
        "beta: Rust={:.6e} Python={:.6e} rel={beta_rel:.2e}",
        model.beta,
        REF_BETA
    );

    let lml = model.log_marginal_likelihood();
    let lml_diff = (lml - REF_LOG_EV_LAST).abs();
    assert!(
        lml_diff < 1e-4,
        "log_ev_last: Rust={lml:.6e} Python={REF_LOG_EV_LAST:.6e} diff={lml_diff:.2e}"
    );
}