Skip to main content

oxicuda_ssl/
lib.rs

1//! `oxicuda-ssl` — Self-supervised learning primitives for OxiCUDA.
2//!
3//! Pure-Rust implementation of the four canonical SSL families, suitable for
4//! CPU simulation and PTX kernel generation for GPU execution.
5//!
6//! # Architecture
7//!
8//! ```text
9//! oxicuda-ssl
10//! ├── contrastive/      — SimCLR (NT-Xent), MoCo (memory-bank InfoNCE)
11//! ├── non_contrastive/  — BYOL (cosine), Barlow Twins, VICReg
12//! ├── masked/           — MAE (random patch mask + reconstruction MSE)
13//! ├── clustering/       — SwAV (Sinkhorn-Knopp), DINO (centred + sharpened CE)
14//! ├── augment/          — Color jitter, multi-crop helpers
15//! ├── metrics/          — Uniformity, alignment, effective rank, collapse score
16//! ├── momentum/         — EmaUpdater for momentum-encoder schemes
17//! ├── head/             — MlpProjector, PredictorHead
18//! ├── error             — SslError / SslResult
19//! ├── handle            — SslHandle (SmVersion + LcgRng)
20//! └── ptx_kernels       — GPU PTX kernel strings
21//! ```
22
23// ─── Module declarations ─────────────────────────────────────────────────────
24
25pub mod augment;
26pub mod clustering;
27pub mod contrastive;
28pub mod error;
29pub mod handle;
30pub mod head;
31pub mod masked;
32pub mod metrics;
33pub mod momentum;
34pub mod non_contrastive;
35pub mod ptx_kernels;
36
37// ─── Prelude ─────────────────────────────────────────────────────────────────
38
39/// Convenience re-exports for common SSL types.
40pub mod prelude {
41    pub use crate::augment::color::{color_jitter, random_grayscale_chw};
42    pub use crate::augment::multi_crop::{MultiCropConfig, multi_crop};
43    pub use crate::augment::rand_augment::{
44        AugOp, AutoAugPolicy, AutoAugmentConfig, RandAugmentConfig, SubPolicy, all_aug_ops,
45        apply_aug_op, auto_augment, rand_augment,
46    };
47    pub use crate::augment::solarize_blur::{
48        SimClrBlurSolarConfig, add_gaussian_noise, gaussian_blur_chw, random_gaussian_blur_chw,
49        random_solarize, simclr_blur_solar, solarize,
50    };
51    pub use crate::clustering::deep_cluster::{
52        DeepClusterConfig, DeepClusterResult, DeeperClusterConfig, DeeperClusterResult,
53        deep_cluster, deep_cluster_loss, deeper_cluster, pca_whiten,
54    };
55    pub use crate::clustering::dino::{DinoConfig, dino_loss};
56    pub use crate::clustering::ibot::{
57        IBotCenters, IBotConfig, IBotResult, ibot_centers_init, ibot_cls_loss, ibot_loss,
58        ibot_mim_loss, ibot_random_patch_mask, ibot_update_centers,
59    };
60    pub use crate::clustering::swav::{SwavConfig, sinkhorn_knopp, swav_loss};
61    pub use crate::contrastive::info_nce::info_nce_loss;
62    pub use crate::contrastive::moco::{MocoQueue, moco_loss};
63    pub use crate::contrastive::moco_v3::{
64        MocoV3Config, MocoV3State, moco_v3_loss, moco_v3_symmetric_loss,
65    };
66    pub use crate::contrastive::simclr::{SimClrConfig, simclr_loss};
67    pub use crate::error::{SslError, SslResult};
68    pub use crate::handle::{LcgRng, SmVersion, SslHandle};
69    pub use crate::head::linear_probe::{
70        FittedLinearProbe, LinearProbeConfig, LinearProbeResult, linear_probe_eval,
71        linear_probe_fit, linear_probe_predict,
72    };
73    pub use crate::head::predictor::PredictorHead;
74    pub use crate::head::projector::MlpProjector;
75    pub use crate::masked::beit::{
76        BeitConfig, BeitResult, VqCodebook, beit_block_mask, beit_loss, vq_codebook_init,
77        vq_encode, vq_update_codebook,
78    };
79    pub use crate::masked::data2vec::{
80        Data2VecConfig, Data2VecResult, Data2VecState, data2vec_batch_loss, data2vec_loss,
81        data2vec_mask, huber_loss, normalize_teacher_targets,
82    };
83    pub use crate::masked::mae::{MaeConfig, mae_reconstruction_loss, random_patch_mask};
84    pub use crate::masked::simmim::{
85        SimMimConfig, simmim_block_mask, simmim_l1_loss, simmim_l2_loss, simmim_random_mask,
86        simmim_reconstruction_loss,
87    };
88    pub use crate::metrics::feature_metrics::{
89        alignment_loss, collapse_score, effective_rank, pairwise_cosine_stats, uniformity_loss,
90    };
91    pub use crate::metrics::knn_eval::{KnnEvalConfig, KnnEvalResult, knn_eval};
92    pub use crate::momentum::ema::{EmaUpdater, cosine_momentum};
93    pub use crate::non_contrastive::barlow::{BarlowTwinsConfig, barlow_twins_loss};
94    pub use crate::non_contrastive::byol::{ByolPredictor, byol_loss};
95    pub use crate::non_contrastive::dense_cl::{
96        DenseCLConfig, DenseCLResult, PixProConfig, dense_cl_loss, dense_correspondence,
97        dense_infonce, pixpro_loss,
98    };
99    pub use crate::non_contrastive::msn::{
100        MsnConfig, MsnPrototypes, MsnResult, msn_loss, msn_prototype_init, msn_random_mask,
101        msn_update_prototypes,
102    };
103    pub use crate::non_contrastive::simsiam::{
104        SimSiamConfig, SimSiamPredictor, is_collapsed, simsiam_loss, simsiam_loss_batch,
105    };
106    pub use crate::non_contrastive::vicreg::{VicRegConfig, vicreg_loss};
107    pub use crate::ptx_kernels::{
108        barlow_cross_corr_ptx, byol_cosine_loss_ptx, cosine_similarity_ptx, f32_hex,
109        gather_features_ptx, momentum_update_ptx, nt_xent_softmax_ptx, random_mask_ptx,
110    };
111}
112
113// ─── End-to-end integration tests ────────────────────────────────────────────
114
115#[cfg(test)]
116mod e2e_tests {
117    use crate::prelude::*;
118
119    /// Build a "perfectly aligned" projection batch where each row is a
120    /// distinct one-hot basis vector — diagonal cosine = 1, off-diagonal = 0.
121    fn aligned_projections(n: usize, d: usize) -> Vec<f32> {
122        let mut z = vec![0.0_f32; n * d];
123        for i in 0..n {
124            z[i * d + i % d] = 1.0;
125        }
126        z
127    }
128
129    #[test]
130    fn e2e_simclr_loss_drops_with_aligned_pairs() {
131        let n = 8;
132        let d = 16;
133        let z = aligned_projections(n, d);
134        let cfg = SimClrConfig::default();
135        let (loss, acc) = simclr_loss(&z, &z, n, d, &cfg).unwrap();
136        assert!(loss.is_finite() && loss < 1.0, "loss = {loss}");
137        assert!((acc - 1.0).abs() < 1e-6);
138    }
139
140    #[test]
141    fn e2e_moco_queue_lifecycle_fifo() {
142        let mut q = MocoQueue::new(8, 4).unwrap();
143        for batch_id in 0..6 {
144            let mut batch = vec![0.0_f32; 4];
145            batch[batch_id % 4] = 1.0;
146            q.enqueue(&batch).unwrap();
147        }
148        assert_eq!(q.len(), 6);
149        // Run MoCo loss with a meaningful query/key pair.
150        let q_vec = vec![1.0_f32, 0.0, 0.0, 0.0];
151        let k_vec = q_vec.clone();
152        let l = moco_loss(&q_vec, &k_vec, 1, 4, &q, 0.1).unwrap();
153        assert!(l.is_finite());
154    }
155
156    #[test]
157    fn e2e_byol_loss_zero_for_identical_inputs() {
158        let z = vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0];
159        let l = byol_loss(&z, &z, 2, 3).unwrap();
160        assert!(l.abs() < 1e-4);
161    }
162
163    #[test]
164    fn e2e_barlow_twins_low_for_identical_inputs() {
165        let n = 16;
166        let d = 4;
167        // Each column has distinct mean → standardisation makes columns
168        // independent. Identical Z_A = Z_B → diag(C) ≈ 1.
169        let mut z = vec![0.0_f32; n * d];
170        for i in 0..n {
171            for j in 0..d {
172                z[i * d + j] = (i as f32) * 0.1 + (j as f32) * 0.7;
173            }
174        }
175        let cfg = BarlowTwinsConfig::default();
176        let l = barlow_twins_loss(&z, &z, n, d, &cfg).unwrap();
177        assert!(l.is_finite());
178    }
179
180    #[test]
181    fn e2e_vicreg_three_terms_combine() {
182        let n = 16;
183        let d = 4;
184        let z_a: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.013).sin()).collect();
185        let z_b: Vec<f32> = (0..n * d)
186            .map(|i| (i as f32 * 0.013).sin() + 0.01)
187            .collect();
188        let cfg = VicRegConfig::default();
189        let l = vicreg_loss(&z_a, &z_b, n, d, &cfg).unwrap();
190        assert!(l.is_finite() && l > 0.0);
191    }
192
193    #[test]
194    fn e2e_mae_mask_ratio_respected() {
195        let mut handle = SslHandle::default_handle();
196        let mask = random_patch_mask(196, 0.75, handle.rng_mut()).unwrap();
197        let n_masked = mask.iter().filter(|&&v| v == 0.0).count();
198        assert_eq!(n_masked, 147); // floor(196 * 0.75)
199        // Reconstruction MSE on a perfect predictor is zero.
200        let target = vec![1.5_f32; 196 * 4];
201        let pred = target.clone();
202        let l = mae_reconstruction_loss(&target, &pred, &mask, 196, 4).unwrap();
203        assert!(l.abs() < 1e-7);
204    }
205
206    #[test]
207    fn e2e_swav_sinkhorn_normalises_uniform() {
208        let n = 8;
209        let k = 4;
210        let mut q = vec![1.0_f32; n * k];
211        sinkhorn_knopp(&mut q, n, k, 5).unwrap();
212        // After Sinkhorn, each row sums to 1 and is uniform.
213        for i in 0..n {
214            let s: f32 = q[i * k..(i + 1) * k].iter().sum();
215            assert!((s - 1.0).abs() < 1e-4, "row sum = {s}");
216        }
217    }
218
219    #[test]
220    fn e2e_dino_centred_softmax_returns_finite() {
221        let n = 4;
222        let k = 8;
223        let mut handle = SslHandle::default_handle();
224        let mut s = vec![0.0_f32; n * k];
225        let mut t = vec![0.0_f32; n * k];
226        handle.rng_mut().fill_normal(&mut s);
227        handle.rng_mut().fill_normal(&mut t);
228        let centre = vec![0.0_f32; k];
229        let cfg = DinoConfig::default();
230        let l = dino_loss(&s, &t, &centre, n, k, &cfg).unwrap();
231        assert!(l.is_finite() && l > 0.0);
232    }
233
234    #[test]
235    fn e2e_ema_converges_to_online_when_momentum_zero() {
236        let mut updater = EmaUpdater::new();
237        let mut target = vec![5.0_f32; 8];
238        let online = vec![10.0_f32; 8];
239        updater.update(&mut target, &online, 0.0).unwrap();
240        for &v in &target {
241            assert!((v - 10.0).abs() < 1e-6);
242        }
243        // cosine_momentum is monotone increasing.
244        let m1 = cosine_momentum(0, 100, 0.5, 1.0).unwrap();
245        let m2 = cosine_momentum(100, 100, 0.5, 1.0).unwrap();
246        assert!(m1 < m2);
247    }
248
249    #[test]
250    fn e2e_mlp_projector_forward_correct_shape() {
251        let mut handle = SslHandle::default_handle();
252        let p = MlpProjector::new(64, 32, 16, handle.rng_mut()).unwrap();
253        let x = vec![0.1_f32; 64];
254        let y = p.forward(&x).unwrap();
255        assert_eq!(y.len(), 16);
256        // Predictor head similar interface
257        let pred = PredictorHead::new(16, 32, 16, handle.rng_mut()).unwrap();
258        let y2 = pred.forward(&y).unwrap();
259        assert_eq!(y2.len(), 16);
260    }
261
262    #[test]
263    fn e2e_multi_crop_returns_n_crops() {
264        let cfg = MultiCropConfig::default();
265        let crops = multi_crop(&cfg).unwrap();
266        assert_eq!(crops.len(), cfg.n_crops());
267        // First two are global.
268        assert!(crops[0].is_global);
269        assert!(crops[1].is_global);
270        // Color jitter on a sample image runs without error.
271        let mut handle = SslHandle::default_handle();
272        let h = 8;
273        let w = 8;
274        let mut img = vec![0.5_f32; 3 * h * w];
275        color_jitter(&mut img, h, w, 0.5, handle.rng_mut()).unwrap();
276        let _converted = random_grayscale_chw(&mut img, h, w, 0.5, handle.rng_mut()).unwrap();
277        for v in &img {
278            assert!((0.0..=1.0).contains(v));
279        }
280    }
281
282    #[test]
283    fn e2e_ptx_kernels_all_sm_versions() {
284        for sm in [75_u32, 80, 86, 90, 100, 120] {
285            for prog in [
286                nt_xent_softmax_ptx(sm),
287                momentum_update_ptx(sm),
288                byol_cosine_loss_ptx(sm),
289                barlow_cross_corr_ptx(sm),
290                random_mask_ptx(sm),
291                cosine_similarity_ptx(sm),
292                gather_features_ptx(sm),
293            ] {
294                assert!(prog.contains(&format!("sm_{sm}")));
295                assert!(prog.contains(".visible .entry"));
296            }
297        }
298        // Smoke-test f32_hex to keep the prelude path live.
299        assert_eq!(f32_hex(1.0_f32), "0F3F800000");
300    }
301}