1pub mod augment;
26pub mod clustering;
27pub mod contrastive;
28pub mod error;
29pub mod handle;
30pub mod head;
31pub mod masked;
32pub mod metrics;
33pub mod momentum;
34pub mod non_contrastive;
35pub mod ptx_kernels;
36
37pub mod prelude {
41 pub use crate::augment::color::{color_jitter, random_grayscale_chw};
42 pub use crate::augment::multi_crop::{MultiCropConfig, multi_crop};
43 pub use crate::augment::rand_augment::{
44 AugOp, AutoAugPolicy, AutoAugmentConfig, RandAugmentConfig, SubPolicy, all_aug_ops,
45 apply_aug_op, auto_augment, rand_augment,
46 };
47 pub use crate::augment::solarize_blur::{
48 SimClrBlurSolarConfig, add_gaussian_noise, gaussian_blur_chw, random_gaussian_blur_chw,
49 random_solarize, simclr_blur_solar, solarize,
50 };
51 pub use crate::clustering::deep_cluster::{
52 DeepClusterConfig, DeepClusterResult, DeeperClusterConfig, DeeperClusterResult,
53 deep_cluster, deep_cluster_loss, deeper_cluster, pca_whiten,
54 };
55 pub use crate::clustering::dino::{DinoConfig, dino_loss};
56 pub use crate::clustering::ibot::{
57 IBotCenters, IBotConfig, IBotResult, ibot_centers_init, ibot_cls_loss, ibot_loss,
58 ibot_mim_loss, ibot_random_patch_mask, ibot_update_centers,
59 };
60 pub use crate::clustering::swav::{SwavConfig, sinkhorn_knopp, swav_loss};
61 pub use crate::contrastive::info_nce::info_nce_loss;
62 pub use crate::contrastive::moco::{MocoQueue, moco_loss};
63 pub use crate::contrastive::moco_v3::{
64 MocoV3Config, MocoV3State, moco_v3_loss, moco_v3_symmetric_loss,
65 };
66 pub use crate::contrastive::simclr::{SimClrConfig, simclr_loss};
67 pub use crate::error::{SslError, SslResult};
68 pub use crate::handle::{LcgRng, SmVersion, SslHandle};
69 pub use crate::head::linear_probe::{
70 FittedLinearProbe, LinearProbeConfig, LinearProbeResult, linear_probe_eval,
71 linear_probe_fit, linear_probe_predict,
72 };
73 pub use crate::head::predictor::PredictorHead;
74 pub use crate::head::projector::MlpProjector;
75 pub use crate::masked::beit::{
76 BeitConfig, BeitResult, VqCodebook, beit_block_mask, beit_loss, vq_codebook_init,
77 vq_encode, vq_update_codebook,
78 };
79 pub use crate::masked::data2vec::{
80 Data2VecConfig, Data2VecResult, Data2VecState, data2vec_batch_loss, data2vec_loss,
81 data2vec_mask, huber_loss, normalize_teacher_targets,
82 };
83 pub use crate::masked::mae::{MaeConfig, mae_reconstruction_loss, random_patch_mask};
84 pub use crate::masked::simmim::{
85 SimMimConfig, simmim_block_mask, simmim_l1_loss, simmim_l2_loss, simmim_random_mask,
86 simmim_reconstruction_loss,
87 };
88 pub use crate::metrics::feature_metrics::{
89 alignment_loss, collapse_score, effective_rank, pairwise_cosine_stats, uniformity_loss,
90 };
91 pub use crate::metrics::knn_eval::{KnnEvalConfig, KnnEvalResult, knn_eval};
92 pub use crate::momentum::ema::{EmaUpdater, cosine_momentum};
93 pub use crate::non_contrastive::barlow::{BarlowTwinsConfig, barlow_twins_loss};
94 pub use crate::non_contrastive::byol::{ByolPredictor, byol_loss};
95 pub use crate::non_contrastive::dense_cl::{
96 DenseCLConfig, DenseCLResult, PixProConfig, dense_cl_loss, dense_correspondence,
97 dense_infonce, pixpro_loss,
98 };
99 pub use crate::non_contrastive::msn::{
100 MsnConfig, MsnPrototypes, MsnResult, msn_loss, msn_prototype_init, msn_random_mask,
101 msn_update_prototypes,
102 };
103 pub use crate::non_contrastive::simsiam::{
104 SimSiamConfig, SimSiamPredictor, is_collapsed, simsiam_loss, simsiam_loss_batch,
105 };
106 pub use crate::non_contrastive::vicreg::{VicRegConfig, vicreg_loss};
107 pub use crate::ptx_kernels::{
108 barlow_cross_corr_ptx, byol_cosine_loss_ptx, cosine_similarity_ptx, f32_hex,
109 gather_features_ptx, momentum_update_ptx, nt_xent_softmax_ptx, random_mask_ptx,
110 };
111}
112
/// End-to-end smoke tests that exercise the crate exclusively through
/// `crate::prelude`, so they also verify that the prelude re-exports resolve.
#[cfg(test)]
mod e2e_tests {
    use crate::prelude::*;

    /// Builds a row-major `n × d` buffer where row `i` is all zeros except for
    /// a single `1.0` at column `i % d`.
    fn aligned_projections(n: usize, d: usize) -> Vec<f32> {
        let mut z = vec![0.0_f32; n * d];
        for i in 0..n {
            z[i * d + i % d] = 1.0;
        }
        z
    }

    /// Passing the same one-hot projections as both views must give a finite
    /// loss below 1.0 and a second return value (`acc`) of 1.0.
    #[test]
    fn e2e_simclr_loss_drops_with_aligned_pairs() {
        let n = 8;
        let d = 16;
        let z = aligned_projections(n, d);
        let cfg = SimClrConfig::default();
        let (loss, acc) = simclr_loss(&z, &z, n, d, &cfg).unwrap();
        assert!(loss.is_finite() && loss < 1.0, "loss = {loss}");
        assert!((acc - 1.0).abs() < 1e-6, "acc = {acc}");
    }

    /// Enqueues six one-hot batches, checks the reported length, then computes
    /// a MoCo loss against the populated queue.
    #[test]
    fn e2e_moco_queue_lifecycle_fifo() {
        // Presumably (capacity, feature dim) — confirm against `MocoQueue::new`.
        let mut queue = MocoQueue::new(8, 4).unwrap();
        for batch_id in 0..6 {
            let mut batch = vec![0.0_f32; 4];
            batch[batch_id % 4] = 1.0;
            queue.enqueue(&batch).unwrap();
        }
        assert_eq!(queue.len(), 6);

        // Query and key are the same vector, so one buffer is borrowed twice.
        let query = vec![1.0_f32, 0.0, 0.0, 0.0];
        let l = moco_loss(&query, &query, 1, 4, &queue, 0.1).unwrap();
        assert!(l.is_finite(), "loss = {l}");
    }

    /// Two identical 2×3 inputs must give a BYOL loss within 1e-4 of zero.
    #[test]
    fn e2e_byol_loss_zero_for_identical_inputs() {
        let z = vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0];
        let l = byol_loss(&z, &z, 2, 3).unwrap();
        assert!(l.abs() < 1e-4, "loss = {l}");
    }

    /// Runs the Barlow Twins loss on two identical 16×4 inputs.
    ///
    /// NOTE(review): the test name says "low", but only finiteness is asserted;
    /// consider adding an upper bound once the expected magnitude is known.
    #[test]
    fn e2e_barlow_twins_low_for_identical_inputs() {
        let n = 16;
        let d = 4;
        let mut z = vec![0.0_f32; n * d];
        // Deterministic fill that varies along both the row and column index.
        for i in 0..n {
            for j in 0..d {
                z[i * d + j] = (i as f32) * 0.1 + (j as f32) * 0.7;
            }
        }
        let cfg = BarlowTwinsConfig::default();
        let l = barlow_twins_loss(&z, &z, n, d, &cfg).unwrap();
        assert!(l.is_finite(), "loss = {l}");
    }

    /// Two views that differ by a constant 0.01 offset must give a finite,
    /// strictly positive VICReg loss.
    #[test]
    fn e2e_vicreg_three_terms_combine() {
        let n = 16;
        let d = 4;
        let z_a: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.013).sin()).collect();
        let z_b: Vec<f32> = (0..n * d)
            .map(|i| (i as f32 * 0.013).sin() + 0.01)
            .collect();
        let cfg = VicRegConfig::default();
        let l = vicreg_loss(&z_a, &z_b, n, d, &cfg).unwrap();
        assert!(l.is_finite() && l > 0.0, "loss = {l}");
    }

    /// A 0.75 ratio over 196 patches must produce exactly 147 zero entries in
    /// the mask, and identical target/prediction must reconstruct with ~0 loss.
    #[test]
    fn e2e_mae_mask_ratio_respected() {
        let mut handle = SslHandle::default_handle();
        let mask = random_patch_mask(196, 0.75, handle.rng_mut()).unwrap();
        // This test treats entries equal to 0.0 as the masked ones; 147 = 0.75 * 196.
        let n_masked = mask.iter().filter(|&&v| v == 0.0).count();
        assert_eq!(n_masked, 147);

        // Prediction equals target, so the same buffer is borrowed for both.
        let target = vec![1.5_f32; 196 * 4];
        let l = mae_reconstruction_loss(&target, &target, &mask, 196, 4).unwrap();
        assert!(l.abs() < 1e-7, "loss = {l}");
    }

    /// After Sinkhorn-Knopp on an all-ones 8×4 matrix, every length-`k` row
    /// slice must sum to 1 within 1e-4.
    #[test]
    fn e2e_swav_sinkhorn_normalises_uniform() {
        let n = 8;
        let k = 4;
        let mut q = vec![1.0_f32; n * k];
        sinkhorn_knopp(&mut q, n, k, 5).unwrap();
        for i in 0..n {
            let s: f32 = q[i * k..(i + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-4, "row sum = {s}");
        }
    }

    /// The DINO loss on normally-distributed student/teacher buffers with an
    /// all-zero centre must be finite and strictly positive.
    #[test]
    fn e2e_dino_centred_softmax_returns_finite() {
        let n = 4;
        let k = 8;
        let mut handle = SslHandle::default_handle();
        let mut s = vec![0.0_f32; n * k];
        let mut t = vec![0.0_f32; n * k];
        handle.rng_mut().fill_normal(&mut s);
        handle.rng_mut().fill_normal(&mut t);
        let centre = vec![0.0_f32; k];
        let cfg = DinoConfig::default();
        let l = dino_loss(&s, &t, &centre, n, k, &cfg).unwrap();
        assert!(l.is_finite() && l > 0.0, "loss = {l}");
    }

    /// With momentum 0.0 an EMA update must copy the online values into the
    /// target; the cosine schedule must yield a smaller value at step 0 than
    /// at step 100 of 100.
    #[test]
    fn e2e_ema_converges_to_online_when_momentum_zero() {
        let mut updater = EmaUpdater::new();
        let mut target = vec![5.0_f32; 8];
        let online = vec![10.0_f32; 8];
        updater.update(&mut target, &online, 0.0).unwrap();
        for &v in &target {
            assert!((v - 10.0).abs() < 1e-6, "target value = {v}");
        }

        let m1 = cosine_momentum(0, 100, 0.5, 1.0).unwrap();
        let m2 = cosine_momentum(100, 100, 0.5, 1.0).unwrap();
        assert!(m1 < m2, "m1 = {m1}, m2 = {m2}");
    }

    /// A projector built with `(64, 32, 16)` maps a 64-element input to 16
    /// outputs, and a predictor built with `(16, 32, 16)` keeps it at 16.
    #[test]
    fn e2e_mlp_projector_forward_correct_shape() {
        let mut handle = SslHandle::default_handle();
        let p = MlpProjector::new(64, 32, 16, handle.rng_mut()).unwrap();
        let x = vec![0.1_f32; 64];
        let y = p.forward(&x).unwrap();
        assert_eq!(y.len(), 16);

        let pred = PredictorHead::new(16, 32, 16, handle.rng_mut()).unwrap();
        let y2 = pred.forward(&y).unwrap();
        assert_eq!(y2.len(), 16);
    }

    /// The default multi-crop config must yield `n_crops()` crops with the
    /// first two flagged global; colour jitter and random grayscale on a
    /// mid-grey image must keep every value inside `[0, 1]`.
    #[test]
    fn e2e_multi_crop_returns_n_crops() {
        let cfg = MultiCropConfig::default();
        let crops = multi_crop(&cfg).unwrap();
        assert_eq!(crops.len(), cfg.n_crops());
        assert!(crops[0].is_global);
        assert!(crops[1].is_global);

        let mut handle = SslHandle::default_handle();
        let h = 8;
        let w = 8;
        let mut img = vec![0.5_f32; 3 * h * w];
        color_jitter(&mut img, h, w, 0.5, handle.rng_mut()).unwrap();
        // The returned flag is irrelevant here; only the pixel range is checked.
        let _converted = random_grayscale_chw(&mut img, h, w, 0.5, handle.rng_mut()).unwrap();
        for v in &img {
            assert!((0.0..=1.0).contains(v), "pixel out of range: {v}");
        }
    }

    /// Every PTX generator must emit the requested `sm_<N>` string and a
    /// `.visible .entry` directive for each listed SM version; `f32_hex`
    /// must render 1.0 as `0F3F800000`.
    #[test]
    fn e2e_ptx_kernels_all_sm_versions() {
        for sm in [75_u32, 80, 86, 90, 100, 120] {
            // Built once per SM version instead of once per kernel.
            let target = format!("sm_{sm}");
            for prog in [
                nt_xent_softmax_ptx(sm),
                momentum_update_ptx(sm),
                byol_cosine_loss_ptx(sm),
                barlow_cross_corr_ptx(sm),
                random_mask_ptx(sm),
                cosine_similarity_ptx(sm),
                gather_features_ptx(sm),
            ] {
                assert!(prog.contains(&target), "missing {target}");
                assert!(prog.contains(".visible .entry"), "missing entry point for {target}");
            }
        }
        // 0x3F800000 is the IEEE-754 single-precision bit pattern of 1.0.
        assert_eq!(f32_hex(1.0_f32), "0F3F800000");
    }
}