Skip to main content

oxicuda_ssl/
lib.rs

//! `oxicuda-ssl` — Self-supervised learning primitives for OxiCUDA.
//!
//! Pure-Rust implementation of the four canonical SSL families (contrastive,
//! non-contrastive, masked modelling, and clustering), suitable both for CPU
//! simulation and for generating PTX kernels for GPU execution.
//!
//! # Architecture
//!
//! ```text
//! oxicuda-ssl
//! ├── contrastive/      — SimCLR (NT-Xent), MoCo (memory-bank InfoNCE),
//! │                       MoCo v3, standalone InfoNCE
//! ├── non_contrastive/  — BYOL (cosine), Barlow Twins, VICReg, SimSiam,
//! │                       DenseCL / PixPro, MSN
//! ├── masked/           — MAE (random patch mask + reconstruction MSE),
//! │                       SimMIM, BEiT, data2vec, I-JEPA
//! ├── clustering/       — SwAV (Sinkhorn-Knopp), DINO (centred + sharpened CE),
//! │                       DINOv2, iBOT, DeepCluster / DeeperCluster
//! ├── augment/          — Color jitter, multi-crop helpers, RandAugment /
//! │                       AutoAugment, Gaussian blur / solarize
//! ├── metrics/          — Uniformity, alignment, effective rank, collapse score,
//! │                       k-NN evaluation
//! ├── momentum/         — EmaUpdater and cosine_momentum for momentum-encoder schemes
//! ├── head/             — MlpProjector, PredictorHead, linear probe
//! ├── ssl/              — Struct-based models: Data2VecModel, Jem, SimSiam
//! ├── error             — SslError / SslResult
//! ├── handle            — SslHandle (SmVersion + LcgRng)
//! └── ptx_kernels       — GPU PTX kernel strings
//! ```

23// ─── Module declarations ─────────────────────────────────────────────────────
24
25pub mod augment;
26pub mod clustering;
27pub mod contrastive;
28pub mod error;
29pub mod handle;
30pub mod head;
31pub mod masked;
32pub mod metrics;
33pub mod momentum;
34pub mod non_contrastive;
35pub mod ptx_kernels;
36pub mod ssl;
37
38// ─── Prelude ─────────────────────────────────────────────────────────────────
39
40/// Convenience re-exports for common SSL types.
41pub mod prelude {
42    pub use crate::augment::color::{color_jitter, random_grayscale_chw};
43    pub use crate::augment::multi_crop::{MultiCropConfig, multi_crop};
44    pub use crate::augment::rand_augment::{
45        AugOp, AutoAugPolicy, AutoAugmentConfig, RandAugmentConfig, SubPolicy, all_aug_ops,
46        apply_aug_op, auto_augment, rand_augment,
47    };
48    pub use crate::augment::solarize_blur::{
49        SimClrBlurSolarConfig, add_gaussian_noise, gaussian_blur_chw, random_gaussian_blur_chw,
50        random_solarize, simclr_blur_solar, solarize,
51    };
52    pub use crate::clustering::deep_cluster::{
53        DeepClusterConfig, DeepClusterResult, DeeperClusterConfig, DeeperClusterResult,
54        deep_cluster, deep_cluster_loss, deeper_cluster, pca_whiten,
55    };
56    pub use crate::clustering::dino::{DinoConfig, dino_loss};
57    pub use crate::clustering::dino_v2::{DinoV2, DinoV2Config};
58    pub use crate::clustering::ibot::{
59        IBotCenters, IBotConfig, IBotResult, ibot_centers_init, ibot_cls_loss, ibot_loss,
60        ibot_mim_loss, ibot_random_patch_mask, ibot_update_centers,
61    };
62    pub use crate::clustering::swav::{SwavConfig, sinkhorn_knopp, swav_loss};
63    pub use crate::contrastive::info_nce::info_nce_loss;
64    pub use crate::contrastive::moco::{MocoQueue, moco_loss};
65    pub use crate::contrastive::moco_v3::{
66        MocoV3Config, MocoV3State, moco_v3_loss, moco_v3_symmetric_loss,
67    };
68    pub use crate::contrastive::simclr::{SimClrConfig, simclr_loss};
69    pub use crate::error::{SslError, SslResult};
70    pub use crate::handle::{LcgRng, SmVersion, SslHandle};
71    pub use crate::head::linear_probe::{
72        FittedLinearProbe, LinearProbeConfig, LinearProbeResult, linear_probe_eval,
73        linear_probe_fit, linear_probe_predict,
74    };
75    pub use crate::head::predictor::PredictorHead;
76    pub use crate::head::projector::MlpProjector;
77    pub use crate::masked::beit::{
78        BeitConfig, BeitResult, VqCodebook, beit_block_mask, beit_loss, vq_codebook_init,
79        vq_encode, vq_update_codebook,
80    };
81    pub use crate::masked::data2vec::{
82        Data2VecConfig, Data2VecResult, Data2VecState, data2vec_batch_loss, data2vec_loss,
83        data2vec_mask, huber_loss, normalize_teacher_targets,
84    };
85    pub use crate::masked::i_jepa::{IJepa, IJepaConfig};
86    pub use crate::masked::mae::{MaeConfig, mae_reconstruction_loss, random_patch_mask};
87    pub use crate::masked::simmim::{
88        SimMimConfig, simmim_block_mask, simmim_l1_loss, simmim_l2_loss, simmim_random_mask,
89        simmim_reconstruction_loss,
90    };
91    pub use crate::metrics::feature_metrics::{
92        alignment_loss, collapse_score, effective_rank, pairwise_cosine_stats, uniformity_loss,
93    };
94    pub use crate::metrics::knn_eval::{KnnEvalConfig, KnnEvalResult, knn_eval};
95    pub use crate::momentum::ema::{EmaUpdater, cosine_momentum};
96    pub use crate::non_contrastive::barlow::{BarlowTwinsConfig, barlow_twins_loss};
97    pub use crate::non_contrastive::byol::{ByolPredictor, byol_loss};
98    pub use crate::non_contrastive::dense_cl::{
99        DenseCLConfig, DenseCLResult, PixProConfig, dense_cl_loss, dense_correspondence,
100        dense_infonce, pixpro_loss,
101    };
102    pub use crate::non_contrastive::msn::{
103        MsnConfig, MsnPrototypes, MsnResult, msn_loss, msn_prototype_init, msn_random_mask,
104        msn_update_prototypes,
105    };
106    pub use crate::non_contrastive::simsiam::{
107        SimSiamConfig, SimSiamPredictor, is_collapsed, simsiam_loss, simsiam_loss_batch,
108    };
109    pub use crate::non_contrastive::vicreg::{VicRegConfig, vicreg_loss};
110    pub use crate::ptx_kernels::{
111        barlow_cross_corr_ptx, byol_cosine_loss_ptx, cosine_similarity_ptx, f32_hex,
112        gather_features_ptx, momentum_update_ptx, nt_xent_softmax_ptx, random_mask_ptx,
113    };
114    pub use crate::ssl::data2vec_v2::{Data2VecModel, Data2VecModelConfig};
115    pub use crate::ssl::jem::{Jem, JemConfig};
116    pub use crate::ssl::sim_siam::{SimSiam, SimSiamConfig as SimSiamStructConfig};
117}
118
// ─── End-to-end integration tests ────────────────────────────────────────────

/// End-to-end smoke tests that reach every SSL family exclusively through
/// `crate::prelude`, so a dropped or renamed re-export fails to compile here.
#[cfg(test)]
mod e2e_tests {
    use crate::prelude::*;

    /// Build a "perfectly aligned" projection batch of `n` rows × `d` columns
    /// where row `i` is the one-hot basis vector `e_i` — diagonal cosine = 1,
    /// off-diagonal = 0.
    ///
    /// Rows are only distinct (and mutually orthogonal) when `n <= d`, so that
    /// is asserted instead of silently wrapping around and duplicating rows.
    fn aligned_projections(n: usize, d: usize) -> Vec<f32> {
        assert!(n <= d, "aligned_projections needs n <= d (got n = {n}, d = {d})");
        let mut z = vec![0.0_f32; n * d];
        for i in 0..n {
            z[i * d + i] = 1.0;
        }
        z
    }

    #[test]
    fn e2e_simclr_loss_drops_with_aligned_pairs() {
        let n = 8;
        let d = 16;
        let z = aligned_projections(n, d);
        let cfg = SimClrConfig::default();
        // Both views are identical and the rows orthogonal, so each sample's
        // positive is strictly its most similar candidate.
        let (loss, acc) = simclr_loss(&z, &z, n, d, &cfg).expect("simclr_loss should succeed");
        assert!(loss.is_finite() && loss < 1.0, "loss = {loss}");
        assert!((acc - 1.0).abs() < 1e-6, "acc = {acc}");
    }

    #[test]
    fn e2e_moco_queue_lifecycle_fifo() {
        let mut q = MocoQueue::new(8, 4).expect("new should succeed");
        // Six single-key enqueues (4 floats each). This presumably stays below
        // the queue size of 8 (assuming `new(queue_size, dim)` order), so it
        // exercises length bookkeeping rather than FIFO eviction.
        for batch_id in 0..6 {
            let mut batch = vec![0.0_f32; 4];
            batch[batch_id % 4] = 1.0;
            q.enqueue(&batch).expect("enqueue should succeed");
        }
        assert_eq!(q.len(), 6);
        // Run MoCo loss with a meaningful query/key pair.
        let q_vec = vec![1.0_f32, 0.0, 0.0, 0.0];
        let k_vec = q_vec.clone();
        let l = moco_loss(&q_vec, &k_vec, 1, 4, &q, 0.1).expect("moco_loss should succeed");
        assert!(l.is_finite(), "loss = {l}");
    }

    #[test]
    fn e2e_byol_loss_zero_for_identical_inputs() {
        let z = vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0];
        let l = byol_loss(&z, &z, 2, 3).expect("byol_loss should succeed");
        assert!(l.abs() < 1e-4, "loss = {l}");
    }

    #[test]
    fn e2e_barlow_twins_low_for_identical_inputs() {
        let n = 16;
        let d = 4;
        // Column j is the ±1 pattern of bit j of the row index (Walsh-style
        // columns). Over 16 = 2^4 rows every column has mean 0, identical
        // variance, and is exactly orthogonal to every other column. With
        // Z_A = Z_B the standardised cross-correlation C is therefore a
        // (possibly slightly scaled) identity: the invariance term is ~0 and
        // the redundancy term is exactly 0.
        //
        // Columns that differ only by a constant offset (e.g. `0.1 * i + 0.7 * j`)
        // do NOT work: standardisation removes the offset, every column becomes
        // identical, and every off-diagonal entry of C is 1.
        let z: Vec<f32> = (0..n)
            .flat_map(|i: usize| (0..d).map(move |j| if (i >> j) & 1 == 1 { 1.0 } else { -1.0 }))
            .collect();
        let cfg = BarlowTwinsConfig::default();
        let l = barlow_twins_loss(&z, &z, n, d, &cfg).expect("barlow_twins_loss should succeed");
        assert!(l.is_finite() && l < 0.1, "loss = {l}");
    }

    #[test]
    fn e2e_vicreg_three_terms_combine() {
        let n = 16;
        let d = 4;
        let z_a: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.013).sin()).collect();
        let z_b: Vec<f32> = (0..n * d)
            .map(|i| (i as f32 * 0.013).sin() + 0.01)
            .collect();
        let cfg = VicRegConfig::default();
        let l = vicreg_loss(&z_a, &z_b, n, d, &cfg).expect("vicreg_loss should succeed");
        assert!(l.is_finite() && l > 0.0, "loss = {l}");
    }

    #[test]
    fn e2e_mae_mask_ratio_respected() {
        let mut handle = SslHandle::default_handle();
        let mask = random_patch_mask(196, 0.75, handle.rng_mut()).expect("value should be present");
        // This test treats a mask value of 0.0 as "patch masked".
        let n_masked = mask.iter().filter(|&&v| v == 0.0).count();
        assert_eq!(n_masked, 147); // floor(196 * 0.75)
        // Reconstruction MSE on a perfect predictor is zero.
        let target = vec![1.5_f32; 196 * 4];
        let pred = target.clone();
        let l = mae_reconstruction_loss(&target, &pred, &mask, 196, 4)
            .expect("mae_reconstruction_loss should succeed");
        assert!(l.abs() < 1e-7, "loss = {l}");
    }

    #[test]
    fn e2e_swav_sinkhorn_normalises_uniform() {
        let n = 8;
        let k = 4;
        let mut q = vec![1.0_f32; n * k];
        sinkhorn_knopp(&mut q, n, k, 5).expect("sinkhorn_knopp should succeed");
        // A constant matrix stays constant under any row/column rescaling, so
        // after Sinkhorn every row must sum to 1 and every entry must be 1/k.
        let uniform = 1.0 / k as f32;
        for i in 0..n {
            let row = &q[i * k..(i + 1) * k];
            let s: f32 = row.iter().sum();
            assert!((s - 1.0).abs() < 1e-4, "row {i} sum = {s}");
            for &v in row {
                assert!((v - uniform).abs() < 1e-4, "row {i} entry = {v}");
            }
        }
    }

    #[test]
    fn e2e_dino_centred_softmax_returns_finite() {
        let n = 4;
        let k = 8;
        let mut handle = SslHandle::default_handle();
        let mut s = vec![0.0_f32; n * k];
        let mut t = vec![0.0_f32; n * k];
        handle.rng_mut().fill_normal(&mut s);
        handle.rng_mut().fill_normal(&mut t);
        let centre = vec![0.0_f32; k];
        let cfg = DinoConfig::default();
        let l = dino_loss(&s, &t, &centre, n, k, &cfg).expect("dino_loss should succeed");
        assert!(l.is_finite() && l > 0.0, "loss = {l}");
    }

    #[test]
    fn e2e_ema_converges_to_online_when_momentum_zero() {
        let mut updater = EmaUpdater::new();
        let mut target = vec![5.0_f32; 8];
        let online = vec![10.0_f32; 8];
        // With momentum 0 the EMA step copies the online weights outright.
        updater
            .update(&mut target, &online, 0.0)
            .expect("update should succeed");
        for &v in &target {
            assert!((v - 10.0).abs() < 1e-6, "target = {v}");
        }
        // cosine_momentum rises from the start to the end of the schedule...
        let m1 = cosine_momentum(0, 100, 0.5, 1.0).expect("cosine_momentum should succeed");
        let m2 = cosine_momentum(100, 100, 0.5, 1.0).expect("cosine_momentum should succeed");
        assert!(m1 < m2, "m(0) = {m1}, m(100) = {m2}");
        // ...and is monotone non-decreasing at every step in between.
        let mut prev = m1;
        for step in 1..=100 {
            let m = cosine_momentum(step, 100, 0.5, 1.0).expect("cosine_momentum should succeed");
            assert!(m >= prev, "momentum decreased at step {step}: {prev} -> {m}");
            prev = m;
        }
    }

    #[test]
    fn e2e_mlp_projector_forward_correct_shape() {
        let mut handle = SslHandle::default_handle();
        let p = MlpProjector::new(64, 32, 16, handle.rng_mut()).expect("value should be present");
        let x = vec![0.1_f32; 64];
        let y = p.forward(&x).expect("forward should succeed");
        assert_eq!(y.len(), 16);
        // Predictor head similar interface
        let pred =
            PredictorHead::new(16, 32, 16, handle.rng_mut()).expect("value should be present");
        let y2 = pred.forward(&y).expect("forward should succeed");
        assert_eq!(y2.len(), 16);
    }

    #[test]
    fn e2e_multi_crop_returns_n_crops() {
        let cfg = MultiCropConfig::default();
        let crops = multi_crop(&cfg).expect("multi_crop should succeed");
        assert_eq!(crops.len(), cfg.n_crops());
        // First two are global.
        assert!(crops[0].is_global);
        assert!(crops[1].is_global);
        // Color jitter on a sample image runs without error.
        let mut handle = SslHandle::default_handle();
        let h = 8;
        let w = 8;
        let mut img = vec![0.5_f32; 3 * h * w];
        color_jitter(&mut img, h, w, 0.5, handle.rng_mut()).expect("value should be present");
        let _converted = random_grayscale_chw(&mut img, h, w, 0.5, handle.rng_mut())
            .expect("value should be present");
        // Augmented pixels must remain in the [0, 1] range.
        for v in &img {
            assert!((0.0..=1.0).contains(v), "pixel = {v}");
        }
    }

    #[test]
    fn e2e_ptx_kernels_all_sm_versions() {
        for sm in [75_u32, 80, 86, 90, 100, 120] {
            for prog in [
                nt_xent_softmax_ptx(sm),
                momentum_update_ptx(sm),
                byol_cosine_loss_ptx(sm),
                barlow_cross_corr_ptx(sm),
                random_mask_ptx(sm),
                cosine_similarity_ptx(sm),
                gather_features_ptx(sm),
            ] {
                assert!(prog.contains(&format!("sm_{sm}")));
                assert!(prog.contains(".visible .entry"));
            }
        }
        // Smoke-test f32_hex to keep the prelude path live: 1.0_f32 has the
        // IEEE-754 bit pattern 0x3F800000, written as a PTX `0F` hex literal.
        assert_eq!(f32_hex(1.0_f32), "0F3F800000");
    }
}