Skip to main content

oxicuda_ssl/
lib.rs

1//! `oxicuda-ssl` — Self-supervised learning primitives for OxiCUDA.
2//!
3//! Pure-Rust implementation of the four canonical SSL families, suitable for
4//! CPU simulation and PTX kernel generation for GPU execution.
5//!
6//! # Architecture
7//!
8//! ```text
9//! oxicuda-ssl
10//! ├── contrastive/      — SimCLR (NT-Xent), MoCo (memory-bank InfoNCE)
11//! ├── non_contrastive/  — BYOL (cosine), Barlow Twins, VICReg
12//! ├── masked/           — MAE (random patch mask + reconstruction MSE)
13//! ├── clustering/       — SwAV (Sinkhorn-Knopp), DINO (centred + sharpened CE)
14//! ├── augment/          — Color jitter, multi-crop helpers
15//! ├── momentum/         — EmaUpdater for momentum-encoder schemes
16//! ├── head/             — MlpProjector, PredictorHead
17//! ├── error             — SslError / SslResult
18//! ├── handle            — SslHandle (SmVersion + LcgRng)
19//! └── ptx_kernels       — GPU PTX kernel strings
20//! ```
21
22// ─── Module declarations ─────────────────────────────────────────────────────
23
24pub mod augment;
25pub mod clustering;
26pub mod contrastive;
27pub mod error;
28pub mod handle;
29pub mod head;
30pub mod masked;
31pub mod momentum;
32pub mod non_contrastive;
33pub mod ptx_kernels;
34
35// ─── Prelude ─────────────────────────────────────────────────────────────────
36
37/// Convenience re-exports for common SSL types.
38pub mod prelude {
39    pub use crate::augment::color::{color_jitter, random_grayscale_chw};
40    pub use crate::augment::multi_crop::{MultiCropConfig, multi_crop};
41    pub use crate::clustering::dino::{DinoConfig, dino_loss};
42    pub use crate::clustering::swav::{SwavConfig, sinkhorn_knopp, swav_loss};
43    pub use crate::contrastive::info_nce::info_nce_loss;
44    pub use crate::contrastive::moco::{MocoQueue, moco_loss};
45    pub use crate::contrastive::simclr::{SimClrConfig, simclr_loss};
46    pub use crate::error::{SslError, SslResult};
47    pub use crate::handle::{LcgRng, SmVersion, SslHandle};
48    pub use crate::head::predictor::PredictorHead;
49    pub use crate::head::projector::MlpProjector;
50    pub use crate::masked::mae::{MaeConfig, mae_reconstruction_loss, random_patch_mask};
51    pub use crate::momentum::ema::{EmaUpdater, cosine_momentum};
52    pub use crate::non_contrastive::barlow::{BarlowTwinsConfig, barlow_twins_loss};
53    pub use crate::non_contrastive::byol::{ByolPredictor, byol_loss};
54    pub use crate::non_contrastive::vicreg::{VicRegConfig, vicreg_loss};
55    pub use crate::ptx_kernels::{
56        barlow_cross_corr_ptx, byol_cosine_loss_ptx, cosine_similarity_ptx, f32_hex,
57        gather_features_ptx, momentum_update_ptx, nt_xent_softmax_ptx, random_mask_ptx,
58    };
59}
60
61// ─── End-to-end integration tests ────────────────────────────────────────────
62
63#[cfg(test)]
64mod e2e_tests {
65    use crate::prelude::*;
66
67    /// Build a "perfectly aligned" projection batch where each row is a
68    /// distinct one-hot basis vector — diagonal cosine = 1, off-diagonal = 0.
69    fn aligned_projections(n: usize, d: usize) -> Vec<f32> {
70        let mut z = vec![0.0_f32; n * d];
71        for i in 0..n {
72            z[i * d + i % d] = 1.0;
73        }
74        z
75    }
76
77    #[test]
78    fn e2e_simclr_loss_drops_with_aligned_pairs() {
79        let n = 8;
80        let d = 16;
81        let z = aligned_projections(n, d);
82        let cfg = SimClrConfig::default();
83        let (loss, acc) = simclr_loss(&z, &z, n, d, &cfg).unwrap();
84        assert!(loss.is_finite() && loss < 1.0, "loss = {loss}");
85        assert!((acc - 1.0).abs() < 1e-6);
86    }
87
88    #[test]
89    fn e2e_moco_queue_lifecycle_fifo() {
90        let mut q = MocoQueue::new(8, 4).unwrap();
91        for batch_id in 0..6 {
92            let mut batch = vec![0.0_f32; 4];
93            batch[batch_id % 4] = 1.0;
94            q.enqueue(&batch).unwrap();
95        }
96        assert_eq!(q.len(), 6);
97        // Run MoCo loss with a meaningful query/key pair.
98        let q_vec = vec![1.0_f32, 0.0, 0.0, 0.0];
99        let k_vec = q_vec.clone();
100        let l = moco_loss(&q_vec, &k_vec, 1, 4, &q, 0.1).unwrap();
101        assert!(l.is_finite());
102    }
103
104    #[test]
105    fn e2e_byol_loss_zero_for_identical_inputs() {
106        let z = vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0];
107        let l = byol_loss(&z, &z, 2, 3).unwrap();
108        assert!(l.abs() < 1e-4);
109    }
110
111    #[test]
112    fn e2e_barlow_twins_low_for_identical_inputs() {
113        let n = 16;
114        let d = 4;
115        // Each column has distinct mean → standardisation makes columns
116        // independent. Identical Z_A = Z_B → diag(C) ≈ 1.
117        let mut z = vec![0.0_f32; n * d];
118        for i in 0..n {
119            for j in 0..d {
120                z[i * d + j] = (i as f32) * 0.1 + (j as f32) * 0.7;
121            }
122        }
123        let cfg = BarlowTwinsConfig::default();
124        let l = barlow_twins_loss(&z, &z, n, d, &cfg).unwrap();
125        assert!(l.is_finite());
126    }
127
128    #[test]
129    fn e2e_vicreg_three_terms_combine() {
130        let n = 16;
131        let d = 4;
132        let z_a: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.013).sin()).collect();
133        let z_b: Vec<f32> = (0..n * d)
134            .map(|i| (i as f32 * 0.013).sin() + 0.01)
135            .collect();
136        let cfg = VicRegConfig::default();
137        let l = vicreg_loss(&z_a, &z_b, n, d, &cfg).unwrap();
138        assert!(l.is_finite() && l > 0.0);
139    }
140
141    #[test]
142    fn e2e_mae_mask_ratio_respected() {
143        let mut handle = SslHandle::default_handle();
144        let mask = random_patch_mask(196, 0.75, handle.rng_mut()).unwrap();
145        let n_masked = mask.iter().filter(|&&v| v == 0.0).count();
146        assert_eq!(n_masked, 147); // floor(196 * 0.75)
147        // Reconstruction MSE on a perfect predictor is zero.
148        let target = vec![1.5_f32; 196 * 4];
149        let pred = target.clone();
150        let l = mae_reconstruction_loss(&target, &pred, &mask, 196, 4).unwrap();
151        assert!(l.abs() < 1e-7);
152    }
153
154    #[test]
155    fn e2e_swav_sinkhorn_normalises_uniform() {
156        let n = 8;
157        let k = 4;
158        let mut q = vec![1.0_f32; n * k];
159        sinkhorn_knopp(&mut q, n, k, 5).unwrap();
160        // After Sinkhorn, each row sums to 1 and is uniform.
161        for i in 0..n {
162            let s: f32 = q[i * k..(i + 1) * k].iter().sum();
163            assert!((s - 1.0).abs() < 1e-4, "row sum = {s}");
164        }
165    }
166
167    #[test]
168    fn e2e_dino_centred_softmax_returns_finite() {
169        let n = 4;
170        let k = 8;
171        let mut handle = SslHandle::default_handle();
172        let mut s = vec![0.0_f32; n * k];
173        let mut t = vec![0.0_f32; n * k];
174        handle.rng_mut().fill_normal(&mut s);
175        handle.rng_mut().fill_normal(&mut t);
176        let centre = vec![0.0_f32; k];
177        let cfg = DinoConfig::default();
178        let l = dino_loss(&s, &t, &centre, n, k, &cfg).unwrap();
179        assert!(l.is_finite() && l > 0.0);
180    }
181
182    #[test]
183    fn e2e_ema_converges_to_online_when_momentum_zero() {
184        let mut updater = EmaUpdater::new();
185        let mut target = vec![5.0_f32; 8];
186        let online = vec![10.0_f32; 8];
187        updater.update(&mut target, &online, 0.0).unwrap();
188        for &v in &target {
189            assert!((v - 10.0).abs() < 1e-6);
190        }
191        // cosine_momentum is monotone increasing.
192        let m1 = cosine_momentum(0, 100, 0.5, 1.0).unwrap();
193        let m2 = cosine_momentum(100, 100, 0.5, 1.0).unwrap();
194        assert!(m1 < m2);
195    }
196
197    #[test]
198    fn e2e_mlp_projector_forward_correct_shape() {
199        let mut handle = SslHandle::default_handle();
200        let p = MlpProjector::new(64, 32, 16, handle.rng_mut()).unwrap();
201        let x = vec![0.1_f32; 64];
202        let y = p.forward(&x).unwrap();
203        assert_eq!(y.len(), 16);
204        // Predictor head similar interface
205        let pred = PredictorHead::new(16, 32, 16, handle.rng_mut()).unwrap();
206        let y2 = pred.forward(&y).unwrap();
207        assert_eq!(y2.len(), 16);
208    }
209
210    #[test]
211    fn e2e_multi_crop_returns_n_crops() {
212        let cfg = MultiCropConfig::default();
213        let crops = multi_crop(&cfg).unwrap();
214        assert_eq!(crops.len(), cfg.n_crops());
215        // First two are global.
216        assert!(crops[0].is_global);
217        assert!(crops[1].is_global);
218        // Color jitter on a sample image runs without error.
219        let mut handle = SslHandle::default_handle();
220        let h = 8;
221        let w = 8;
222        let mut img = vec![0.5_f32; 3 * h * w];
223        color_jitter(&mut img, h, w, 0.5, handle.rng_mut()).unwrap();
224        let _converted = random_grayscale_chw(&mut img, h, w, 0.5, handle.rng_mut()).unwrap();
225        for v in &img {
226            assert!((0.0..=1.0).contains(v));
227        }
228    }
229
230    #[test]
231    fn e2e_ptx_kernels_all_sm_versions() {
232        for sm in [75_u32, 80, 86, 90, 100, 120] {
233            for prog in [
234                nt_xent_softmax_ptx(sm),
235                momentum_update_ptx(sm),
236                byol_cosine_loss_ptx(sm),
237                barlow_cross_corr_ptx(sm),
238                random_mask_ptx(sm),
239                cosine_similarity_ptx(sm),
240                gather_features_ptx(sm),
241            ] {
242                assert!(prog.contains(&format!("sm_{sm}")));
243                assert!(prog.contains(".visible .entry"));
244            }
245        }
246        // Smoke-test f32_hex to keep the prelude path live.
247        assert_eq!(f32_hex(1.0_f32), "0F3F800000");
248    }
249}