Skip to main content

oxicuda_bayes/
lib.rs

//! `oxicuda-bayes` — Bayesian deep learning primitives for OxiCUDA.
//!
//! Pure-Rust implementation of variational inference and Bayesian neural network
//! building blocks suitable for CPU simulation and PTX kernel generation for GPU
//! execution.
//!
//! # Architecture
//!
//! ```text
//! oxicuda-bayes
//! ├── layers/         — BayesLinear, BayesConv2d, BayesGru, Flipout layers
//! ├── variational/    — ELBO/IWAE, normalizing flows, mean-field, reparameterization,
//! │                     HMC/NUTS, VCL
//! ├── calibration/    — Temperature/vector scaling, ECE/MCE/ACE, isotonic, Platt, beta,
//! │                     histogram binning, conformal prediction, Brier, NLL
//! ├── uncertainty/    — MC Dropout, Deep Ensembles, SWAG, last-layer and functional
//! │                     Laplace, BALD
//! ├── error           — BayesError / BayesResult
//! ├── handle          — BayesHandle (SmVersion + LcgRng)
//! └── ptx_kernels     — GPU PTX kernel strings
//! ```
19
20// ─── Module declarations ─────────────────────────────────────────────────────
21
22pub mod calibration;
23pub mod error;
24pub mod handle;
25pub mod layers;
26pub mod ptx_kernels;
27pub mod uncertainty;
28pub mod variational;
29
30// ─── Prelude ─────────────────────────────────────────────────────────────────
31
32/// Convenience re-exports for common Bayesian deep learning types.
33pub mod prelude {
34    pub use crate::calibration::beta::{BetaCalibConfig, BetaCalibrator};
35    pub use crate::calibration::conformal::{
36        ConformalClassifier, ConformalRegressor, RapsClassifier, conformal_quantile,
37    };
38    pub use crate::calibration::histogram::{
39        BinStrategy, HistogramBinCalibrator, HistogramBinConfig,
40    };
41    pub use crate::calibration::isotonic::IsotonicRegressor;
42    pub use crate::calibration::metrics::{
43        ReliabilityBin, ReliabilityDiagram, adaptive_calibration_error, brier_score,
44        expected_calibration_error, maximum_calibration_error, negative_log_likelihood,
45        reliability_diagram, top1_confidences,
46    };
47    pub use crate::calibration::platt::{PlattFitConfig, PlattScaler};
48    pub use crate::calibration::temperature::{TemperatureFitConfig, TemperatureScaler};
49    pub use crate::calibration::vector_scaling::{ScalingMode, VectorScaler, VectorScalingConfig};
50    pub use crate::error::{BayesError, BayesResult};
51    pub use crate::handle::{BayesHandle, LcgRng, SmVersion};
52    pub use crate::layers::bayes_conv::BayesConv2d;
53    pub use crate::layers::bayes_gru::{BayesGru, BayesGruConfig, BayesGruState, BayesGruWeights};
54    pub use crate::layers::bayes_linear::{BayesLinear, softplus};
55    pub use crate::layers::flipout::{FlipoutConv2d, FlipoutLinear};
56    pub use crate::ptx_kernels::{
57        ece_bucket_ptx, ensemble_aggregate_ptx, f32_hex, flipout_perturb_ptx, kl_gaussian_ptx,
58        local_reparam_ptx, mc_dropout_mask_ptx, temp_scale_logits_ptx,
59    };
60    pub use crate::uncertainty::deep_ensemble::{DeepEnsemble, EnsembleStats};
61    pub use crate::uncertainty::entropy::{
62        aleatoric_entropy, epistemic_entropy, mutual_information, predictive_entropy,
63    };
64    pub use crate::uncertainty::functional_laplace::{FunctionalLaplace, FunctionalLaplaceConfig};
65    pub use crate::uncertainty::laplace::LastLayerLaplace;
66    pub use crate::uncertainty::mc_dropout::{McDropoutPredictor, mc_dropout_predict};
67    pub use crate::uncertainty::swag::SwagPosterior;
68    pub use crate::variational::elbo::{ElboConfig, elbo, iwae, kl_gaussian, kl_gaussian_vec};
69    pub use crate::variational::flows::{PlanarFlow, RadialFlow};
70    pub use crate::variational::hmc::{Hmc, HmcConfig, HmcResult, Nuts, NutsConfig, NutsResult};
71    pub use crate::variational::mean_field::MeanFieldDist;
72    pub use crate::variational::reparam::{
73        gaussian_log_prob, gaussian_sample, laplacian_log_prob, laplacian_sample,
74        log_prob_gaussian_vec, sample_gaussian_vec, straight_through,
75    };
76    pub use crate::variational::vcl::{VclConfig, VclState};
77}
78
79// ─── End-to-end integration tests ────────────────────────────────────────────
80
81#[cfg(test)]
82mod e2e_tests {
83    use crate::prelude::*;
84
85    /// Generate a synthetic over-confident classifier output: argmax is always
86    /// class 0 (with margin 6) but only `acc_ratio` of true labels match.
87    fn synthetic_overconfident(
88        n: usize,
89        n_classes: usize,
90        acc_ratio: f32,
91    ) -> (Vec<f32>, Vec<usize>) {
92        let mut logits = Vec::with_capacity(n * n_classes);
93        let mut labels = Vec::with_capacity(n);
94        for i in 0..n {
95            for k in 0..n_classes {
96                logits.push(if k == 0 { 6.0 } else { 0.0 });
97            }
98            let frac = i as f32 / n as f32;
99            labels.push(if frac < acc_ratio {
100                0
101            } else {
102                1 + (i % n_classes.saturating_sub(1)).min(n_classes - 2)
103            });
104        }
105        (logits, labels)
106    }
107
108    /// Apply softmax row-wise to a `[N, K]` buffer (returns fresh allocation).
109    fn softmax_rows(logits: &[f32], k: usize) -> Vec<f32> {
110        let n = logits.len() / k;
111        let mut out = Vec::with_capacity(logits.len());
112        for i in 0..n {
113            let row = &logits[i * k..(i + 1) * k];
114            let mut m = f32::NEG_INFINITY;
115            for &v in row {
116                if v > m {
117                    m = v;
118                }
119            }
120            let mut s = 0.0_f32;
121            let mut tmp = Vec::with_capacity(k);
122            for &v in row {
123                let e = (v - m).exp();
124                tmp.push(e);
125                s += e;
126            }
127            let inv = 1.0_f32 / s;
128            for v in tmp.iter_mut() {
129                *v *= inv;
130            }
131            out.extend_from_slice(&tmp);
132        }
133        out
134    }
135
136    #[test]
137    fn e2e_temperature_scaling_recalibrates_overconfident_classifier() {
138        let n_classes = 3;
139        let (logits, labels) = synthetic_overconfident(300, n_classes, 0.7);
140        let scaler = TemperatureScaler::fit_default(&logits, &labels, n_classes).unwrap();
141        let probs_before = softmax_rows(&logits, n_classes);
142        let probs_after = scaler.apply(&logits, n_classes).unwrap();
143        let (c_before, ok_before) = top1_confidences(&probs_before, &labels, n_classes).unwrap();
144        let (c_after, ok_after) = top1_confidences(&probs_after, &labels, n_classes).unwrap();
145        let ece_before = expected_calibration_error(&c_before, &ok_before, 10).unwrap();
146        let ece_after = expected_calibration_error(&c_after, &ok_after, 10).unwrap();
147        assert!(
148            ece_after <= ece_before + 1e-4,
149            "Temperature scaling should not worsen ECE (before={ece_before}, after={ece_after})"
150        );
151        // Argmax preserved → accuracy unchanged.
152        let acc_before = ok_before.iter().filter(|&&x| x).count();
153        let acc_after = ok_after.iter().filter(|&&x| x).count();
154        assert_eq!(acc_before, acc_after);
155    }
156
157    #[test]
158    fn e2e_isotonic_recalibrates_binary_scores() {
159        let n = 100;
160        let mut scores = Vec::new();
161        let mut targets = Vec::new();
162        for i in 0..n {
163            let s = i as f32 / n as f32;
164            scores.push(s);
165            // True calibration: P(y=1|s) = s² (under-confident at mid range).
166            targets.push((s * s).clamp(0.0, 1.0));
167        }
168        let labels: Vec<f32> = targets
169            .iter()
170            .map(|&p| if p > 0.5 { 1.0 } else { 0.0 })
171            .collect();
172        let r = IsotonicRegressor::fit(&scores, &labels).unwrap();
173        let preds = r.predict(&scores);
174        for w in preds.windows(2) {
175            assert!(w[0] <= w[1] + 1e-6, "isotonic must be non-decreasing");
176        }
177    }
178
179    #[test]
180    fn e2e_platt_scales_binary_logits() {
181        let mut scores = Vec::new();
182        let mut labels = Vec::new();
183        for i in 0..100 {
184            let s = (i as f32 - 50.0) * 0.1;
185            scores.push(s);
186            labels.push(if s > 0.0 { 1_u8 } else { 0_u8 });
187        }
188        let p = PlattScaler::fit_default(&scores, &labels).unwrap();
189        assert!(p.predict_one(5.0) > p.predict_one(-5.0));
190    }
191
192    #[test]
193    fn e2e_mc_dropout_quantifies_uncertainty() {
194        let mut handle = BayesHandle::default_handle();
195        // Simulate a stochastic model: y = base + ε (ε ~ N(0, σ²)). Each
196        // closure call burns four LCG draws to decorrelate the components
197        // (the Knuth LCG has visible correlation between consecutive normals).
198        let base = [0.3_f32, 0.7];
199        let stats = mc_dropout_predict(2048, handle.rng_mut(), |r| {
200            let (e0, _) = r.next_normal_pair();
201            let (e1, _) = r.next_normal_pair();
202            Ok(vec![base[0] + 0.1 * e0, base[1] + 0.1 * e1])
203        })
204        .unwrap();
205        eprintln!(
206            "MC Dropout mean=({:.3}, {:.3}), var=({:.4}, {:.4})",
207            stats.mean[0], stats.mean[1], stats.variance[0], stats.variance[1]
208        );
209        // Variance is the principal outcome — should be on the order of 0.01.
210        for v in &stats.variance {
211            assert!(
212                *v > 0.001 && *v < 0.05,
213                "variance out of expected range: {v}"
214            );
215        }
216        // Predictive mean should be in the [0, 1] band (sanity).
217        assert!((0.0..=1.0).contains(&stats.mean[0]));
218        assert!((0.0..=1.0).contains(&stats.mean[1]));
219    }
220
221    #[test]
222    fn e2e_deep_ensemble_aggregates_disagreement() {
223        let preds = vec![
224            vec![0.9_f32, 0.05, 0.05],
225            vec![0.05_f32, 0.9, 0.05],
226            vec![0.05_f32, 0.05, 0.9],
227        ];
228        let ensemble = DeepEnsemble::new(preds).unwrap();
229        let stats = ensemble.aggregate_probabilities().unwrap();
230        // Mean is ~ uniform — very high disagreement.
231        for v in &stats.mean {
232            assert!((v - 1.0 / 3.0).abs() < 0.01);
233        }
234        // Variance per class should be substantial.
235        for v in &stats.variance {
236            assert!(*v > 0.1);
237        }
238    }
239
240    #[test]
241    fn e2e_swag_posterior_sampling_round_trip() {
242        let mut handle = BayesHandle::default_handle();
243        let mut posterior = SwagPosterior::new(4, 3).unwrap();
244        // Inject a few SGD-like iterates around mean (1, 2, 3, 4).
245        for offset in [-0.1_f32, -0.05, 0.0, 0.05, 0.1] {
246            let iterate: Vec<f32> = (0..4).map(|i| (i + 1) as f32 + offset).collect();
247            posterior.update(&iterate).unwrap();
248        }
249        // Mean should be approximately (1, 2, 3, 4).
250        for (i, &m) in posterior.mean.iter().enumerate() {
251            assert!((m - (i + 1) as f32).abs() < 1e-4);
252        }
253        let theta = posterior.sample(handle.rng_mut()).unwrap();
254        assert_eq!(theta.len(), 4);
255        for v in theta {
256            assert!(v.is_finite());
257        }
258    }
259
260    #[test]
261    fn e2e_laplace_widens_predictions_with_low_precision() {
262        let map = vec![1.0_f32, -1.0];
263        let phi: Vec<f32> = (0..30)
264            .flat_map(|i| {
265                let x = (i as f32 - 15.0) * 0.2;
266                vec![x, x * 0.5]
267            })
268            .collect();
269        let labels: Vec<u8> = (0..30)
270            .map(|i| if (i as f32 - 15.0) * 0.2 > 0.0 { 1 } else { 0 })
271            .collect();
272        let laplace = LastLayerLaplace::fit_binary_logistic(&map, &phi, &labels, 1.0).unwrap();
273        let p = laplace.predictive_probability(&[2.0_f32, 1.0]).unwrap();
274        assert!((0.0..=1.0).contains(&p));
275    }
276
277    #[test]
278    fn e2e_bald_finds_disagreement() {
279        // 3 ensemble members on a 2-class problem disagreeing strongly.
280        let samples = vec![
281            0.95_f32, 0.05, // member 1
282            0.05_f32, 0.95, // member 2
283            0.5_f32, 0.5, // member 3
284        ];
285        let mi = mutual_information(&samples, 2, 3).unwrap();
286        let ent = predictive_entropy(&samples, 2, 3).unwrap();
287        let aleatoric = aleatoric_entropy(&samples, 2, 3).unwrap();
288        assert!((ent - aleatoric - mi).abs() < 1e-5);
289        assert!(mi > 0.0);
290    }
291
292    #[test]
293    fn e2e_brier_and_nll_agree_on_perfect_predictor() {
294        let probs = vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
295        let labels = vec![0_usize, 1, 2];
296        let bs = brier_score(&probs, &labels, 3).unwrap();
297        let nll = negative_log_likelihood(&probs, &labels, 3).unwrap();
298        assert!(bs < 1e-5);
299        assert!(nll < 1e-3);
300    }
301
302    #[test]
303    fn e2e_ptx_kernels_all_sm_versions() {
304        for sm in [75_u32, 80, 86, 90, 100, 120] {
305            for prog in [
306                kl_gaussian_ptx(sm),
307                mc_dropout_mask_ptx(sm),
308                local_reparam_ptx(sm),
309                ece_bucket_ptx(sm),
310                ensemble_aggregate_ptx(sm),
311                flipout_perturb_ptx(sm),
312                temp_scale_logits_ptx(sm),
313            ] {
314                assert!(prog.contains(&format!("sm_{sm}")));
315                assert!(prog.contains(".visible .entry"));
316            }
317        }
318    }
319
320    #[test]
321    fn e2e_reliability_diagram_serialises_to_json() {
322        let c = vec![0.1_f32, 0.5, 0.9];
323        let ok = vec![false, true, true];
324        let rd = reliability_diagram(&c, &ok, 5).unwrap();
325        // Spot-check counts.
326        assert_eq!(rd.bins.len(), 5);
327        assert_eq!(rd.n_samples, 3);
328        let total_count: usize = rd.bins.iter().map(|b| b.count).sum();
329        assert_eq!(total_count, 3);
330        // Use the diagram for ECE.
331        let ece1 = rd.ece();
332        let ece2 = expected_calibration_error(&c, &ok, 5).unwrap();
333        assert!((ece1 - ece2).abs() < 1e-6);
334    }
335}