1pub mod calibration;
23pub mod error;
24pub mod handle;
25pub mod layers;
26pub mod ptx_kernels;
27pub mod uncertainty;
28pub mod variational;
29
30pub mod prelude {
34 pub use crate::calibration::beta::{BetaCalibConfig, BetaCalibrator};
35 pub use crate::calibration::conformal::{
36 ConformalClassifier, ConformalRegressor, RapsClassifier, conformal_quantile,
37 };
38 pub use crate::calibration::histogram::{
39 BinStrategy, HistogramBinCalibrator, HistogramBinConfig,
40 };
41 pub use crate::calibration::isotonic::IsotonicRegressor;
42 pub use crate::calibration::metrics::{
43 ReliabilityBin, ReliabilityDiagram, adaptive_calibration_error, brier_score,
44 expected_calibration_error, maximum_calibration_error, negative_log_likelihood,
45 reliability_diagram, top1_confidences,
46 };
47 pub use crate::calibration::platt::{PlattFitConfig, PlattScaler};
48 pub use crate::calibration::temperature::{TemperatureFitConfig, TemperatureScaler};
49 pub use crate::calibration::vector_scaling::{ScalingMode, VectorScaler, VectorScalingConfig};
50 pub use crate::error::{BayesError, BayesResult};
51 pub use crate::handle::{BayesHandle, LcgRng, SmVersion};
52 pub use crate::layers::bayes_conv::BayesConv2d;
53 pub use crate::layers::bayes_gru::{BayesGru, BayesGruConfig, BayesGruState, BayesGruWeights};
54 pub use crate::layers::bayes_linear::{BayesLinear, softplus};
55 pub use crate::layers::flipout::{FlipoutConv2d, FlipoutLinear};
56 pub use crate::ptx_kernels::{
57 ece_bucket_ptx, ensemble_aggregate_ptx, f32_hex, flipout_perturb_ptx, kl_gaussian_ptx,
58 local_reparam_ptx, mc_dropout_mask_ptx, temp_scale_logits_ptx,
59 };
60 pub use crate::uncertainty::deep_ensemble::{DeepEnsemble, EnsembleStats};
61 pub use crate::uncertainty::entropy::{
62 aleatoric_entropy, epistemic_entropy, mutual_information, predictive_entropy,
63 };
64 pub use crate::uncertainty::functional_laplace::{FunctionalLaplace, FunctionalLaplaceConfig};
65 pub use crate::uncertainty::laplace::LastLayerLaplace;
66 pub use crate::uncertainty::mc_dropout::{McDropoutPredictor, mc_dropout_predict};
67 pub use crate::uncertainty::swag::SwagPosterior;
68 pub use crate::variational::elbo::{ElboConfig, elbo, iwae, kl_gaussian, kl_gaussian_vec};
69 pub use crate::variational::flows::{PlanarFlow, RadialFlow};
70 pub use crate::variational::hmc::{Hmc, HmcConfig, HmcResult, Nuts, NutsConfig, NutsResult};
71 pub use crate::variational::mean_field::MeanFieldDist;
72 pub use crate::variational::reparam::{
73 gaussian_log_prob, gaussian_sample, laplacian_log_prob, laplacian_sample,
74 log_prob_gaussian_vec, sample_gaussian_vec, straight_through,
75 };
76 pub use crate::variational::vcl::{VclConfig, VclState};
77}
78
#[cfg(test)]
mod e2e_tests {
    //! End-to-end tests that drive the public prelude: post-hoc calibration,
    //! calibration metrics, uncertainty estimators, and PTX generation.

    use crate::prelude::*;

    /// Builds a deliberately overconfident `n_classes`-way classification set.
    ///
    /// Every row has logit `6.0` on class 0 and `0.0` on all other classes, so
    /// the argmax is always class 0 with high softmax confidence (≈0.995 for
    /// three classes). Rows whose index fraction `i / n` is below `acc_ratio`
    /// are labelled 0 (correct); the rest cycle through classes
    /// `1..n_classes` (incorrect).
    ///
    /// Returns `(logits, labels)`, where `logits` is row-major with length
    /// `n * n_classes` and `labels` has length `n`.
    ///
    /// # Panics
    /// Panics if `n_classes < 2`, because at least one "wrong" class is needed.
    fn synthetic_overconfident(
        n: usize,
        n_classes: usize,
        acc_ratio: f32,
    ) -> (Vec<f32>, Vec<usize>) {
        assert!(
            n_classes >= 2,
            "synthetic_overconfident needs at least 2 classes, got {n_classes}"
        );
        let logits: Vec<f32> = (0..n)
            .flat_map(|_| (0..n_classes).map(|k| if k == 0 { 6.0 } else { 0.0 }))
            .collect();
        let labels: Vec<usize> = (0..n)
            .map(|i| {
                let frac = i as f32 / n as f32;
                if frac < acc_ratio {
                    0
                } else {
                    // `i % (n_classes - 1)` is already in `0..=n_classes - 2`,
                    // so the label always lands in `1..n_classes`.
                    1 + i % (n_classes - 1)
                }
            })
            .collect();
        (logits, labels)
    }

    /// Numerically stable row-wise softmax over a row-major `[n, k]` logit
    /// buffer. A trailing partial row (if `logits.len()` is not a multiple of
    /// `k`) is ignored.
    fn softmax_rows(logits: &[f32], k: usize) -> Vec<f32> {
        let mut out = Vec::with_capacity(logits.len());
        for row in logits.chunks_exact(k) {
            // Shift by the row maximum so `exp` cannot overflow.
            let m = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let start = out.len();
            out.extend(row.iter().map(|&v| (v - m).exp()));
            let inv = 1.0_f32 / out[start..].iter().sum::<f32>();
            for p in &mut out[start..] {
                *p *= inv;
            }
        }
        out
    }

    #[test]
    fn e2e_temperature_scaling_recalibrates_overconfident_classifier() {
        let n_classes = 3;
        let (logits, labels) = synthetic_overconfident(300, n_classes, 0.7);
        let scaler = TemperatureScaler::fit_default(&logits, &labels, n_classes).unwrap();
        let probs_before = softmax_rows(&logits, n_classes);
        let probs_after = scaler.apply(&logits, n_classes).unwrap();
        let (c_before, ok_before) = top1_confidences(&probs_before, &labels, n_classes).unwrap();
        let (c_after, ok_after) = top1_confidences(&probs_after, &labels, n_classes).unwrap();
        let ece_before = expected_calibration_error(&c_before, &ok_before, 10).unwrap();
        let ece_after = expected_calibration_error(&c_after, &ok_after, 10).unwrap();
        assert!(
            ece_after <= ece_before + 1e-4,
            "Temperature scaling should not worsen ECE (before={ece_before}, after={ece_after})"
        );
        // Scaling logits by a positive temperature must not change the argmax,
        // so the number of correct top-1 predictions is unchanged.
        let acc_before = ok_before.iter().filter(|&&x| x).count();
        let acc_after = ok_after.iter().filter(|&&x| x).count();
        assert_eq!(acc_before, acc_after);
    }

    #[test]
    fn e2e_isotonic_recalibrates_binary_scores() {
        let n = 100;
        let scores: Vec<f32> = (0..n).map(|i| i as f32 / n as f32).collect();
        // Labels are a monotone step in the score: positive once s^2 > 0.5.
        let labels: Vec<f32> = scores
            .iter()
            .map(|&s| if s * s > 0.5 { 1.0 } else { 0.0 })
            .collect();
        let r = IsotonicRegressor::fit(&scores, &labels).unwrap();
        let preds = r.predict(&scores);
        assert_eq!(preds.len(), n);
        for w in preds.windows(2) {
            assert!(w[0] <= w[1] + 1e-6, "isotonic must be non-decreasing");
        }
        // Fitted values are averages of 0/1 labels, hence valid probabilities.
        for &p in &preds {
            assert!((0.0..=1.0).contains(&p), "isotonic output out of [0, 1]: {p}");
        }
        // The labels are already monotone, so the fit should separate the
        // all-negative low end from the all-positive high end.
        assert!(preds[0] < 0.5, "lowest score should map below 0.5");
        assert!(preds[n - 1] > 0.5, "highest score should map above 0.5");
    }

    #[test]
    fn e2e_platt_scales_binary_logits() {
        let mut scores = Vec::new();
        let mut labels = Vec::new();
        for i in 0..100 {
            let s = (i as f32 - 50.0) * 0.1;
            scores.push(s);
            labels.push(if s > 0.0 { 1_u8 } else { 0_u8 });
        }
        let scaler = PlattScaler::fit_default(&scores, &labels).unwrap();
        let p_hi = scaler.predict_one(5.0);
        let p_lo = scaler.predict_one(-5.0);
        assert!(p_hi > p_lo);
        assert!((0.0..=1.0).contains(&p_hi) && (0.0..=1.0).contains(&p_lo));
        // The data are separable at 0, so +/-5 lie firmly on opposite sides.
        assert!(p_hi > 0.5 && p_lo < 0.5, "p_hi={p_hi}, p_lo={p_lo}");
    }

    #[test]
    fn e2e_mc_dropout_quantifies_uncertainty() {
        let mut handle = BayesHandle::default_handle();
        let base = [0.3_f32, 0.7];
        // Each simulated stochastic forward pass adds 0.1 * N(0, 1) noise to
        // `base`, so the expected per-output variance is 0.01.
        let stats = mc_dropout_predict(2048, handle.rng_mut(), |r| {
            let (e0, _) = r.next_normal_pair();
            let (e1, _) = r.next_normal_pair();
            Ok(vec![base[0] + 0.1 * e0, base[1] + 0.1 * e1])
        })
        .unwrap();
        eprintln!(
            "MC Dropout mean=({:.3}, {:.3}), var=({:.4}, {:.4})",
            stats.mean[0], stats.mean[1], stats.variance[0], stats.variance[1]
        );
        for v in &stats.variance {
            assert!(
                *v > 0.001 && *v < 0.05,
                "variance out of expected range: {v}"
            );
        }
        // The standard error of each mean is about 0.1 / sqrt(2048) ≈ 0.002,
        // so a 0.02 tolerance is very loose but still pins the estimate.
        assert_eq!(stats.mean.len(), base.len());
        for (k, &target) in base.iter().enumerate() {
            let m = f64::from(stats.mean[k]);
            assert!(
                (m - f64::from(target)).abs() < 0.02,
                "mean[{k}]={m} not close to {target}"
            );
        }
    }

    #[test]
    fn e2e_deep_ensemble_aggregates_disagreement() {
        // Three members, each confidently voting for a different class.
        let preds = vec![
            vec![0.9_f32, 0.05, 0.05],
            vec![0.05_f32, 0.9, 0.05],
            vec![0.05_f32, 0.05, 0.9],
        ];
        let ensemble = DeepEnsemble::new(preds).unwrap();
        let stats = ensemble.aggregate_probabilities().unwrap();
        // Averaging cancels the disagreement into a uniform mean...
        for v in &stats.mean {
            assert!((v - 1.0 / 3.0).abs() < 0.01);
        }
        // ...while the spread across members stays large (≈0.16 population,
        // ≈0.24 sample variance per class).
        for v in &stats.variance {
            assert!(*v > 0.1);
        }
    }

    #[test]
    fn e2e_swag_posterior_sampling_round_trip() {
        let mut handle = BayesHandle::default_handle();
        let mut posterior = SwagPosterior::new(4, 3).unwrap();
        // Symmetric offsets around [1, 2, 3, 4], so the running mean is exact.
        for offset in [-0.1_f32, -0.05, 0.0, 0.05, 0.1] {
            let iterate: Vec<f32> = (0..4).map(|i| (i + 1) as f32 + offset).collect();
            posterior.update(&iterate).unwrap();
        }
        for (i, &m) in posterior.mean.iter().enumerate() {
            assert!((m - (i + 1) as f32).abs() < 1e-4);
        }
        let theta = posterior.sample(handle.rng_mut()).unwrap();
        assert_eq!(theta.len(), 4);
        for v in theta {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn e2e_laplace_widens_predictions_with_low_precision() {
        let map = vec![1.0_f32, -1.0];
        let phi: Vec<f32> = (0..30)
            .flat_map(|i| {
                let x = (i as f32 - 15.0) * 0.2;
                vec![x, x * 0.5]
            })
            .collect();
        let labels: Vec<u8> = (0..30)
            .map(|i| if (i as f32 - 15.0) * 0.2 > 0.0 { 1 } else { 0 })
            .collect();
        let laplace = LastLayerLaplace::fit_binary_logistic(&map, &phi, &labels, 1.0).unwrap();
        let p = laplace.predictive_probability(&[2.0_f32, 1.0]).unwrap();
        // NOTE(review): this only checks that the output is a probability; it
        // does not yet compare against the MAP prediction to show widening.
        assert!((0.0..=1.0).contains(&p));
    }

    #[test]
    fn e2e_bald_finds_disagreement() {
        let samples = vec![0.95_f32, 0.05, 0.05_f32, 0.95, 0.5_f32, 0.5];
        let mi = mutual_information(&samples, 2, 3).unwrap();
        let ent = predictive_entropy(&samples, 2, 3).unwrap();
        let aleatoric = aleatoric_entropy(&samples, 2, 3).unwrap();
        // Decomposition check: predictive = aleatoric + mutual information.
        assert!((ent - aleatoric - mi).abs() < 1e-5);
        assert!(mi > 0.0);
    }

    #[test]
    fn e2e_brier_and_nll_agree_on_perfect_predictor() {
        let probs = vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let labels = vec![0_usize, 1, 2];
        let bs = brier_score(&probs, &labels, 3).unwrap();
        let nll = negative_log_likelihood(&probs, &labels, 3).unwrap();
        assert!(bs < 1e-5);
        assert!(nll < 1e-3);
    }

    #[test]
    fn e2e_ptx_kernels_all_sm_versions() {
        for sm in [75_u32, 80, 86, 90, 100, 120] {
            for prog in [
                kl_gaussian_ptx(sm),
                mc_dropout_mask_ptx(sm),
                local_reparam_ptx(sm),
                ece_bucket_ptx(sm),
                ensemble_aggregate_ptx(sm),
                flipout_perturb_ptx(sm),
                temp_scale_logits_ptx(sm),
            ] {
                assert!(prog.contains(&format!("sm_{sm}")));
                assert!(prog.contains(".visible .entry"));
            }
        }
    }

    /// Checks the reliability diagram's binning and that its own ECE matches
    /// the standalone metric. (Previously misnamed `..._serialises_to_json`;
    /// no serialisation is exercised here.)
    #[test]
    fn e2e_reliability_diagram_bins_agree_with_ece() {
        let c = vec![0.1_f32, 0.5, 0.9];
        let ok = vec![false, true, true];
        let rd = reliability_diagram(&c, &ok, 5).unwrap();
        assert_eq!(rd.bins.len(), 5);
        assert_eq!(rd.n_samples, 3);
        let total_count: usize = rd.bins.iter().map(|b| b.count).sum();
        assert_eq!(total_count, 3);
        let ece1 = rd.ece();
        let ece2 = expected_calibration_error(&c, &ok, 5).unwrap();
        assert!((ece1 - ece2).abs() < 1e-6);
    }
}