pub mod augment;
pub mod clustering;
pub mod contrastive;
pub mod error;
pub mod handle;
pub mod head;
pub mod masked;
pub mod momentum;
pub mod non_contrastive;
pub mod ptx_kernels;
/// Convenience re-exports of the crate's commonly used items.
///
/// `use crate::prelude::*;` brings every loss function, config type, head,
/// augmentation helper, handle/RNG type and PTX kernel generator listed below
/// into scope at once.
pub mod prelude {
    // Data augmentation.
    pub use crate::augment::{
        color::{color_jitter, random_grayscale_chw},
        multi_crop::{MultiCropConfig, multi_crop},
    };
    // Clustering-based objectives.
    pub use crate::clustering::{
        dino::{DinoConfig, dino_loss},
        swav::{SwavConfig, sinkhorn_knopp, swav_loss},
    };
    // Contrastive objectives.
    pub use crate::contrastive::{
        info_nce::info_nce_loss,
        moco::{MocoQueue, moco_loss},
        simclr::{SimClrConfig, simclr_loss},
    };
    // Error types, handle and RNG.
    pub use crate::error::{SslError, SslResult};
    pub use crate::handle::{LcgRng, SmVersion, SslHandle};
    // Projector / predictor heads.
    pub use crate::head::{predictor::PredictorHead, projector::MlpProjector};
    // Masked modelling.
    pub use crate::masked::mae::{MaeConfig, mae_reconstruction_loss, random_patch_mask};
    // Momentum (EMA) updates.
    pub use crate::momentum::ema::{EmaUpdater, cosine_momentum};
    // Non-contrastive objectives.
    pub use crate::non_contrastive::{
        barlow::{BarlowTwinsConfig, barlow_twins_loss},
        byol::{ByolPredictor, byol_loss},
        vicreg::{VicRegConfig, vicreg_loss},
    };
    // PTX kernel generators.
    pub use crate::ptx_kernels::{
        barlow_cross_corr_ptx, byol_cosine_loss_ptx, cosine_similarity_ptx, f32_hex,
        gather_features_ptx, momentum_update_ptx, nt_xent_softmax_ptx, random_mask_ptx,
    };
}
#[cfg(test)]
mod e2e_tests {
    //! End-to-end tests that drive each component only through the re-exports
    //! in [`crate::prelude`], so a missing re-export or a broken cross-module
    //! contract fails here.

    use crate::prelude::*;

    /// Builds an `n x d` row-major matrix whose row `i` is one-hot with a
    /// single `1.0` at column `i % d`.
    ///
    /// For `n <= d` the rows are pairwise orthogonal; for `n > d` rows repeat
    /// with period `d`.
    fn aligned_projections(n: usize, d: usize) -> Vec<f32> {
        let mut z = vec![0.0_f32; n * d];
        for i in 0..n {
            z[i * d + i % d] = 1.0;
        }
        z
    }

    #[test]
    fn e2e_simclr_loss_drops_with_aligned_pairs() {
        // Both views are identical and (n <= d) every sample is orthogonal to
        // every other sample, so each positive pair should win its softmax.
        let n = 8;
        let d = 16;
        let z = aligned_projections(n, d);
        let cfg = SimClrConfig::default();
        let (loss, acc) = simclr_loss(&z, &z, n, d, &cfg).unwrap();
        assert!(loss.is_finite() && loss < 1.0, "loss = {loss}");
        assert!((acc - 1.0).abs() < 1e-6, "accuracy = {acc}");
    }

    #[test]
    fn e2e_moco_queue_lifecycle_fifo() {
        // Each enqueue pushes one 4-element one-hot key into the (8, 4) queue.
        let mut q = MocoQueue::new(8, 4).unwrap();
        for batch_id in 0..6 {
            let mut batch = vec![0.0_f32; 4];
            batch[batch_id % 4] = 1.0;
            q.enqueue(&batch).unwrap();
        }
        assert_eq!(q.len(), 6);
        // The positive key equals the query.
        let q_vec = vec![1.0_f32, 0.0, 0.0, 0.0];
        let k_vec = q_vec.clone();
        let l = moco_loss(&q_vec, &k_vec, 1, 4, &q, 0.1).unwrap();
        assert!(l.is_finite(), "moco loss = {l}");
    }

    #[test]
    fn e2e_byol_loss_zero_for_identical_inputs() {
        let z = vec![1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0];
        let l = byol_loss(&z, &z, 2, 3).unwrap();
        assert!(l.abs() < 1e-4, "byol loss = {l}");
    }

    #[test]
    fn e2e_barlow_twins_low_for_identical_inputs() {
        let n = 16;
        let d = 4;
        let mut z = vec![0.0_f32; n * d];
        for i in 0..n {
            for j in 0..d {
                z[i * d + j] = (i as f32) * 0.1 + (j as f32) * 0.7;
            }
        }
        let cfg = BarlowTwinsConfig::default();
        // Only finiteness is checked; both views are the same matrix.
        let l = barlow_twins_loss(&z, &z, n, d, &cfg).unwrap();
        assert!(l.is_finite(), "barlow twins loss = {l}");
    }

    #[test]
    fn e2e_vicreg_three_terms_combine() {
        let n = 16;
        let d = 4;
        let z_a: Vec<f32> = (0..n * d).map(|i| (i as f32 * 0.013).sin()).collect();
        // Second view is the first shifted by a small constant offset.
        let z_b: Vec<f32> = (0..n * d)
            .map(|i| (i as f32 * 0.013).sin() + 0.01)
            .collect();
        let cfg = VicRegConfig::default();
        let l = vicreg_loss(&z_a, &z_b, n, d, &cfg).unwrap();
        assert!(l.is_finite() && l > 0.0, "vicreg loss = {l}");
    }

    #[test]
    fn e2e_mae_mask_ratio_respected() {
        let mut handle = SslHandle::default_handle();
        let mask = random_patch_mask(196, 0.75, handle.rng_mut()).unwrap();
        // This test counts 0.0 entries as masked: 0.75 * 196 = 147.
        let n_masked = mask.iter().filter(|&&v| v == 0.0).count();
        assert_eq!(n_masked, 147);

        // A perfect reconstruction should give (numerically) zero loss.
        let target = vec![1.5_f32; 196 * 4];
        let pred = target.clone();
        let l = mae_reconstruction_loss(&target, &pred, &mask, 196, 4).unwrap();
        assert!(l.abs() < 1e-7, "mae loss = {l}");
    }

    #[test]
    fn e2e_swav_sinkhorn_normalises_uniform() {
        let n = 8;
        let k = 4;
        let mut q = vec![1.0_f32; n * k];
        sinkhorn_knopp(&mut q, n, k, 5).unwrap();
        // Every row of the balanced assignment must sum to 1.
        for i in 0..n {
            let s: f32 = q[i * k..(i + 1) * k].iter().sum();
            assert!((s - 1.0).abs() < 1e-4, "row sum = {s}");
        }
    }

    #[test]
    fn e2e_dino_centred_softmax_returns_finite() {
        let n = 4;
        let k = 8;
        let mut handle = SslHandle::default_handle();
        let mut s = vec![0.0_f32; n * k];
        let mut t = vec![0.0_f32; n * k];
        handle.rng_mut().fill_normal(&mut s);
        handle.rng_mut().fill_normal(&mut t);
        let centre = vec![0.0_f32; k];
        let cfg = DinoConfig::default();
        let l = dino_loss(&s, &t, &centre, n, k, &cfg).unwrap();
        assert!(l.is_finite() && l > 0.0, "dino loss = {l}");
    }

    #[test]
    fn e2e_ema_converges_to_online_when_momentum_zero() {
        let mut updater = EmaUpdater::new();
        let mut target = vec![5.0_f32; 8];
        let online = vec![10.0_f32; 8];
        // With momentum 0 the target is expected to become the online weights.
        updater.update(&mut target, &online, 0.0).unwrap();
        for &v in &target {
            assert!((v - 10.0).abs() < 1e-6, "target weight = {v}");
        }
    }

    #[test]
    fn e2e_cosine_momentum_increases_over_schedule() {
        let m_start = cosine_momentum(0, 100, 0.5, 1.0).unwrap();
        let m_end = cosine_momentum(100, 100, 0.5, 1.0).unwrap();
        assert!(m_start < m_end, "start = {m_start}, end = {m_end}");
    }

    #[test]
    fn e2e_multi_crop_returns_n_crops() {
        let cfg = MultiCropConfig::default();
        let crops = multi_crop(&cfg).unwrap();
        assert_eq!(crops.len(), cfg.n_crops());
        // The default config is expected to lead with two global crops.
        assert!(crops[0].is_global);
        assert!(crops[1].is_global);
    }

    #[test]
    fn e2e_color_augmentations_stay_in_unit_range() {
        let mut handle = SslHandle::default_handle();
        let h = 8;
        let w = 8;
        // 3-channel image buffer of h * w pixels per channel.
        let mut img = vec![0.5_f32; 3 * h * w];
        color_jitter(&mut img, h, w, 0.5, handle.rng_mut()).unwrap();
        let _converted = random_grayscale_chw(&mut img, h, w, 0.5, handle.rng_mut()).unwrap();
        for v in &img {
            assert!((0.0..=1.0).contains(v), "pixel out of [0, 1]: {v}");
        }
    }

    #[test]
    fn e2e_ptx_kernels_all_sm_versions() {
        for sm in [75_u32, 80, 86, 90, 100, 120] {
            let target = format!("sm_{sm}");
            for prog in [
                nt_xent_softmax_ptx(sm),
                momentum_update_ptx(sm),
                byol_cosine_loss_ptx(sm),
                barlow_cross_corr_ptx(sm),
                random_mask_ptx(sm),
                cosine_similarity_ptx(sm),
                gather_features_ptx(sm),
            ] {
                assert!(prog.contains(&target), "PTX missing `{target}` target");
                assert!(
                    prog.contains(".visible .entry"),
                    "PTX for {target} has no `.visible .entry`"
                );
            }
        }
        // PTX single-precision hex literal: `0F` followed by the IEEE-754 bits.
        assert_eq!(f32_hex(1.0_f32), "0F3F800000");
    }
}