use crate::error::{NeuralError, Result};
use scirs2_core::ndarray::{Array2, Array4};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign, ToPrimitive};
use std::fmt::Debug;
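/// Configuration for the Barlow Twins loss.
///
/// `lambda` weights the off-diagonal (redundancy-reduction) term; the
/// default of `5e-3` follows the original paper. When `scale_loss` is set,
/// the final loss is divided by the feature dimension `D` so its magnitude
/// is comparable across embedding sizes.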
#[derive(Debug, Clone)]
pub struct BarlowTwinsConfig {
pub lambda: f64,
pub scale_loss: bool,
}
impl Default for BarlowTwinsConfig {
fn default() -> Self {
Self {
lambda: 5e-3,
scale_loss: true,
}
}
}
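/// Barlow Twins redundancy-reduction loss (Zbontar et al., 2021).
///
/// Both embedding batches are standardized per feature, and their `D x D`
/// cross-correlation matrix `C` is formed over the batch. The loss is
///
/// `L = sum_i (1 - C_ii)^2 + lambda * sum_{i != j} C_ij^2`,
///
/// optionally scaled by `1 / D` (see [`BarlowTwinsConfig::scale_loss`]).
///
/// A minimal usage sketch; paths and array contents are illustrative:
///
/// ```ignore
/// use scirs2_core::ndarray::Array2;
///
/// let loss_fn = BarlowTwinsLoss::new(BarlowTwinsConfig::default());
/// // Two augmented views embedded to shape (N, D) = (8, 16).
/// let z_a = Array2::<f64>::from_shape_fn((8, 16), |(i, j)| (i + j) as f64 * 0.05);
/// let z_b = &z_a * 1.1;
/// let loss = loss_fn.forward(&z_a, &z_b)?;
/// assert!(loss >= 0.0);
/// ```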
#[derive(Debug, Clone)]
pub struct BarlowTwinsLoss {
config: BarlowTwinsConfig,
}
impl BarlowTwinsLoss {
pub fn new(config: BarlowTwinsConfig) -> Self {
Self { config }
}
pub fn forward<F>(&self, z_a: &Array2<F>, z_b: &Array2<F>) -> Result<F>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
let n = z_a.nrows();
let d = z_a.ncols();
if z_b.shape() != z_a.shape() {
return Err(NeuralError::ShapeMismatch(format!(
"BarlowTwins: z_a shape {:?} != z_b shape {:?}",
z_a.shape(),
z_b.shape()
)));
}
if n == 0 || d == 0 {
return Err(NeuralError::InvalidArgument(
"BarlowTwins: batch and feature dim must be > 0".to_string(),
));
}
let za_norm = batch_norm_features(z_a)?;
let zb_norm = batch_norm_features(z_b)?;
let n_f = F::from_usize(n).ok_or_else(|| {
NeuralError::ComputationError("BarlowTwins: cannot convert N".to_string())
})?;
        // Cross-correlation matrix C = (za_norm^T . zb_norm) / N.
        let mut c = Array2::<F>::zeros((d, d));
for k in 0..n {
for i in 0..d {
for j in 0..d {
c[[i, j]] += za_norm[[k, i]] * zb_norm[[k, j]];
}
}
}
for v in c.iter_mut() {
*v /= n_f;
}
let lambda = F::from_f64(self.config.lambda).ok_or_else(|| {
NeuralError::ComputationError("BarlowTwins: cannot convert lambda".to_string())
})?;
        // Diagonal terms pull C_ii toward 1 (invariance); off-diagonal terms,
        // weighted by lambda, push C_ij toward 0 (redundancy reduction).
        let mut loss = F::zero();
for i in 0..d {
for j in 0..d {
let cij = c[[i, j]];
if i == j {
let diff = F::one() - cij;
loss += diff * diff;
} else {
loss += lambda * cij * cij;
}
}
}
if self.config.scale_loss {
let d_f = F::from_usize(d).ok_or_else(|| {
NeuralError::ComputationError("BarlowTwins: cannot convert D".to_string())
})?;
loss /= d_f;
}
Ok(loss)
}
}
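/// Configuration for the VICReg loss.
///
/// The three weights correspond to the invariance (`invariance_weight`),
/// variance (`variance_weight`), and covariance (`covariance_weight`) terms;
/// the defaults of 25 / 25 / 1 follow the VICReg paper. `variance_target` is
/// the hinge target `gamma` for the per-feature standard deviation, and
/// `eps` stabilizes the square root in the variance term.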
#[derive(Debug, Clone)]
pub struct VICRegConfig {
pub invariance_weight: f64,
pub variance_weight: f64,
pub covariance_weight: f64,
pub variance_target: f64,
pub eps: f64,
}
impl Default for VICRegConfig {
fn default() -> Self {
Self {
invariance_weight: 25.0,
variance_weight: 25.0,
covariance_weight: 1.0,
variance_target: 1.0,
eps: 1e-4,
}
}
}
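/// VICReg loss (Bardes et al., 2022): Variance-Invariance-Covariance
/// regularization.
///
/// `forward` returns `(total, invariance, variance, covariance)` where
///
/// `total = lambda * inv + mu * var + nu * cov`,
///
/// with `inv` the mean squared difference between the two views (averaged
/// over all `N * D` elements here), `var` a hinge on the per-feature
/// standard deviation, and `cov` the mean squared off-diagonal covariance.
/// The variance and covariance terms are averaged over the two views.
///
/// A minimal usage sketch, given two `(N, D)` embedding batches `z_a` and
/// `z_b` (illustrative):
///
/// ```ignore
/// let loss_fn = VICRegLoss::new(VICRegConfig::default());
/// let (total, inv, var, cov) = loss_fn.forward(&z_a, &z_b)?;
/// ```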
#[derive(Debug, Clone)]
pub struct VICRegLoss {
config: VICRegConfig,
}
impl VICRegLoss {
pub fn new(config: VICRegConfig) -> Self {
Self { config }
}
pub fn forward<F>(
&self,
z_a: &Array2<F>,
z_b: &Array2<F>,
) -> Result<(F, F, F, F)>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
let n = z_a.nrows();
let d = z_a.ncols();
if z_b.shape() != z_a.shape() {
return Err(NeuralError::ShapeMismatch(format!(
"VICReg: z_a shape {:?} != z_b shape {:?}",
z_a.shape(),
z_b.shape()
)));
}
if n < 2 {
return Err(NeuralError::InvalidArgument(
"VICReg: batch size must be ≥ 2".to_string(),
));
}
let n_f = F::from_usize(n).ok_or_else(|| {
NeuralError::ComputationError("VICReg: cannot convert N".to_string())
})?;
let d_f = F::from_usize(d).ok_or_else(|| {
NeuralError::ComputationError("VICReg: cannot convert D".to_string())
})?;
let eps = F::from_f64(self.config.eps).ok_or_else(|| {
NeuralError::ComputationError("VICReg: cannot convert eps".to_string())
})?;
let var_target = F::from_f64(self.config.variance_target).ok_or_else(|| {
NeuralError::ComputationError("VICReg: cannot convert variance_target".to_string())
})?;
let mut inv_loss = F::zero();
for i in 0..n {
for j in 0..d {
let diff = z_a[[i, j]] - z_b[[i, j]];
inv_loss += diff * diff;
}
}
inv_loss /= n_f * d_f;
        let two = F::from_f64(2.0).ok_or_else(|| {
            NeuralError::ComputationError("VICReg: cannot convert 2.0".to_string())
        })?;
        // Variance and covariance terms are averaged over the two views.
        let var_loss = (variance_loss(z_a, var_target, eps, n_f, d_f)?
            + variance_loss(z_b, var_target, eps, n_f, d_f)?)
            / two;
        let cov_loss =
            (covariance_loss(z_a, n_f, d_f)? + covariance_loss(z_b, n_f, d_f)?) / two;
let lam = F::from_f64(self.config.invariance_weight).ok_or_else(|| {
NeuralError::ComputationError("VICReg: cannot convert invariance_weight".to_string())
})?;
let mu = F::from_f64(self.config.variance_weight).ok_or_else(|| {
NeuralError::ComputationError("VICReg: cannot convert variance_weight".to_string())
})?;
let nu = F::from_f64(self.config.covariance_weight).ok_or_else(|| {
NeuralError::ComputationError("VICReg: cannot convert covariance_weight".to_string())
})?;
let total = lam * inv_loss + mu * var_loss + nu * cov_loss;
Ok((total, inv_loss, var_loss, cov_loss))
}
}
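/// Reconstruction loss used by the masked autoencoder: mean squared error,
/// mean absolute error, or smooth-L1 (Huber with threshold 1: `0.5 * d^2`
/// for `|d| < 1`, else `|d| - 0.5`).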
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaskedAEReconLoss {
MSE,
MAE,
SmoothL1,
}
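/// Configuration for a ViT-style masked autoencoder (He et al., 2022).
///
/// The defaults mirror the common MAE/ViT-Base setup: 224x224 inputs split
/// into 16x16 patches, a 75% masking ratio, a 768-dim/12-layer encoder, and
/// a lighter 512-dim/8-layer decoder.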
#[derive(Debug, Clone)]
pub struct MaskedAEConfig {
pub masking_ratio: f64,
pub patch_size: usize,
pub input_size: usize,
pub recon_loss: MaskedAEReconLoss,
pub normalize_target: bool,
pub encoder_embed_dim: usize,
pub decoder_embed_dim: usize,
pub encoder_depth: usize,
pub decoder_depth: usize,
pub num_heads: usize,
}
impl Default for MaskedAEConfig {
fn default() -> Self {
Self {
masking_ratio: 0.75,
patch_size: 16,
input_size: 224,
recon_loss: MaskedAEReconLoss::MSE,
normalize_target: true,
encoder_embed_dim: 768,
decoder_embed_dim: 512,
encoder_depth: 12,
decoder_depth: 8,
num_heads: 12,
}
}
}
impl MaskedAEConfig {
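    /// Number of non-overlapping patches per image:
    /// `(input_size / patch_size)^2`. With the defaults,
    /// `(224 / 16)^2 = 196`, of which `round(196 * 0.75) = 147` are masked
    /// (see [`Self::num_masked`]).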
pub fn num_patches(&self) -> usize {
let grid = self.input_size / self.patch_size;
grid * grid
}
pub fn num_masked(&self) -> usize {
let n = self.num_patches() as f64;
(n * self.masking_ratio).round() as usize
}
pub fn validate(&self) -> Result<()> {
if self.patch_size == 0 {
return Err(NeuralError::ConfigError(
"MaskedAEConfig: patch_size must be > 0".to_string(),
));
}
if self.input_size == 0 || self.input_size % self.patch_size != 0 {
return Err(NeuralError::ConfigError(format!(
"MaskedAEConfig: input_size ({}) must be divisible by patch_size ({})",
self.input_size, self.patch_size
)));
}
if !(0.0..1.0).contains(&self.masking_ratio) {
return Err(NeuralError::ConfigError(
"MaskedAEConfig: masking_ratio must be in [0, 1)".to_string(),
));
}
if self.encoder_embed_dim == 0 || self.decoder_embed_dim == 0 {
return Err(NeuralError::ConfigError(
"MaskedAEConfig: embed dims must be > 0".to_string(),
));
}
Ok(())
}
pub fn reconstruction_loss<F>(&self, pred: &Array2<F>, target: &Array2<F>) -> Result<F>
where
F: Float + Debug + NumAssign + FromPrimitive,
{
if pred.shape() != target.shape() {
return Err(NeuralError::ShapeMismatch(format!(
"MaskedAEConfig::reconstruction_loss: pred {:?} != target {:?}",
pred.shape(),
target.shape()
)));
}
let n_elems = pred.len();
if n_elems == 0 {
return Ok(F::zero());
}
let n_f = F::from_usize(n_elems).ok_or_else(|| {
NeuralError::ComputationError("MaskedAEConfig: cannot convert n_elems".to_string())
})?;
let loss = match self.recon_loss {
MaskedAEReconLoss::MSE => {
let s: F = pred
.iter()
.zip(target.iter())
.map(|(p, t)| {
let d = *p - *t;
d * d
})
.fold(F::zero(), |a, b| a + b);
s / n_f
}
MaskedAEReconLoss::MAE => {
let s: F = pred
.iter()
.zip(target.iter())
.map(|(p, t)| (*p - *t).abs())
.fold(F::zero(), |a, b| a + b);
s / n_f
}
MaskedAEReconLoss::SmoothL1 => {
let one = F::one();
let half = F::from_f64(0.5).ok_or_else(|| {
NeuralError::ComputationError(
"MaskedAEConfig: cannot convert 0.5".to_string(),
)
})?;
let s: F = pred
.iter()
.zip(target.iter())
.map(|(p, t)| {
let d = (*p - *t).abs();
if d < one {
half * d * d
} else {
d - half
}
})
.fold(F::zero(), |a, b| a + b);
s / n_f
}
};
Ok(loss)
}
}
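/// A single SSL augmentation over an `(N, C, H, W)` batch.
///
/// Note: the current implementations are deterministic simplifications of
/// their stochastic counterparts. Probabilistic variants fire iff their
/// `probability >= 0.5`, `RandomCrop` performs a center crop of the padded
/// image, and `GaussianBlur` uses a fixed 3x3 box filter.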
#[derive(Debug, Clone)]
pub enum SSLAugmentation {
RandomCrop {
height: usize,
width: usize,
padding: usize,
},
HorizontalFlip {
probability: f64,
},
VerticalFlip {
probability: f64,
},
ColorJitter {
brightness: f64,
contrast: f64,
saturation: f64,
},
GaussianBlur {
sigma: f64,
},
Grayscale {
probability: f64,
},
Normalize {
mean: Vec<f64>,
std: Vec<f64>,
},
}
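/// An ordered pipeline of [`SSLAugmentation`]s applied left to right.
///
/// Built with a fluent API; a minimal sketch (illustrative):
///
/// ```ignore
/// let pipeline = AugmentationPipelineSSL::new()
///     .add(SSLAugmentation::HorizontalFlip { probability: 0.5 })
///     .add(SSLAugmentation::GaussianBlur { sigma: 1.0 });
/// let augmented = pipeline.apply(&batch)?; // batch: Array4<f64>, NCHW
/// ```
///
/// Ready-made presets: [`Self::simclr_cifar10`] and [`Self::simclr_imagenet`].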
#[derive(Debug, Clone)]
pub struct AugmentationPipelineSSL {
augmentations: Vec<SSLAugmentation>,
seed: Option<u64>,
}
impl AugmentationPipelineSSL {
pub fn new() -> Self {
Self {
augmentations: Vec::new(),
seed: None,
}
}
    /// Stores an RNG seed for future stochastic augmentations. The current
    /// augmentation implementations are deterministic, so the seed is
    /// recorded but not yet consumed.
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }
pub fn add(mut self, aug: SSLAugmentation) -> Self {
self.augmentations.push(aug);
self
}
pub fn len(&self) -> usize {
self.augmentations.len()
}
pub fn is_empty(&self) -> bool {
self.augmentations.is_empty()
}
pub fn simclr_cifar10() -> Self {
Self::new()
.add(SSLAugmentation::RandomCrop {
height: 32,
width: 32,
padding: 4,
})
.add(SSLAugmentation::HorizontalFlip { probability: 0.5 })
.add(SSLAugmentation::ColorJitter {
brightness: 0.4,
contrast: 0.4,
saturation: 0.4,
})
.add(SSLAugmentation::GaussianBlur { sigma: 1.0 })
.add(SSLAugmentation::Normalize {
mean: vec![0.4914, 0.4822, 0.4465],
std: vec![0.2023, 0.1994, 0.2010],
})
}
pub fn simclr_imagenet() -> Self {
Self::new()
.add(SSLAugmentation::RandomCrop {
height: 224,
width: 224,
padding: 28,
})
.add(SSLAugmentation::HorizontalFlip { probability: 0.5 })
.add(SSLAugmentation::ColorJitter {
brightness: 0.8,
contrast: 0.8,
saturation: 0.8,
})
.add(SSLAugmentation::GaussianBlur { sigma: 2.0 })
.add(SSLAugmentation::Grayscale { probability: 0.2 })
.add(SSLAugmentation::Normalize {
mean: vec![0.485, 0.456, 0.406],
std: vec![0.229, 0.224, 0.225],
})
}
pub fn apply<F>(&self, batch: &Array4<F>) -> Result<Array4<F>>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
let mut out = batch.to_owned();
for aug in &self.augmentations {
out = apply_augmentation(&out, aug)?;
}
Ok(out)
}
}
impl Default for AugmentationPipelineSSL {
fn default() -> Self {
Self::new()
}
}
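/// Result of a single SSL training step. `components` carries named loss
/// parts where the objective provides them (e.g. VICReg's invariance,
/// variance, and covariance terms); it is empty for single-term losses.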
#[derive(Debug, Clone)]
pub struct SSLStepResult<F: Float + Debug> {
pub total_loss: F,
pub step: usize,
pub components: std::collections::HashMap<String, F>,
}
pub trait SSLStepCallback<F: Float + Debug>: Debug {
fn on_step(&mut self, result: &SSLStepResult<F>);
}
#[derive(Debug, Clone)]
pub enum SSLLossType {
NTXent {
temperature: f64,
},
BarlowTwins(BarlowTwinsConfig),
VICReg(VICRegConfig),
}
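/// Configuration for [`SelfSupervisedTrainer`]: schedule lengths, learning
/// rates for the warmup-plus-cosine schedule, and the SSL objective to use.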
#[derive(Debug, Clone)]
pub struct SSLTrainerConfig {
pub epochs: usize,
pub warmup_epochs: usize,
pub base_lr: f64,
pub min_lr: f64,
pub weight_decay: f64,
pub log_every: usize,
pub loss_type: SSLLossType,
}
impl Default for SSLTrainerConfig {
fn default() -> Self {
Self {
epochs: 100,
warmup_epochs: 10,
base_lr: 3e-4,
min_lr: 1e-6,
weight_decay: 1e-6,
log_every: 50,
loss_type: SSLLossType::BarlowTwins(BarlowTwinsConfig::default()),
}
}
}
impl SSLTrainerConfig {
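    /// Learning rate at `epoch`: linear warmup over `warmup_epochs`, then
    /// cosine decay from `base_lr` down to `min_lr`:
    ///
    /// `lr(e) = min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * t))`
    ///
    /// where `t` is the post-warmup progress in `[0, 1]`. For example, with
    /// `base_lr = 1e-3`, warmup 10, and 100 epochs, epoch 9 yields `1e-3`
    /// (end of warmup) and epoch 55 yields roughly the midpoint of the two
    /// rates.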
pub fn lr_at_epoch(&self, epoch: usize) -> f64 {
if epoch < self.warmup_epochs {
let t = (epoch as f64 + 1.0) / (self.warmup_epochs as f64).max(1.0);
self.base_lr * t
} else {
            // saturating_sub avoids a usize underflow panic when
            // warmup_epochs exceeds epochs.
            let denom = (self.epochs.saturating_sub(self.warmup_epochs) as f64).max(1.0);
            let progress = (epoch - self.warmup_epochs) as f64 / denom;
let cos = (std::f64::consts::PI * progress).cos();
self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1.0 + cos)
}
}
}
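/// Bookkeeping wrapper around an SSL objective: computes the configured
/// loss, counts steps and epochs, and tracks the per-epoch loss history.
///
/// A minimal training-loop sketch, assuming an external model producing
/// `(N, D)` embeddings for two augmented views (`embed_a`, `embed_b`, and
/// `batches` are hypothetical):
///
/// ```ignore
/// let mut trainer = SelfSupervisedTrainer::new(SSLTrainerConfig::default());
/// for _epoch in 0..trainer.config.epochs {
///     let lr = trainer.current_lr();
///     let mut epoch_loss = 0.0;
///     for (views_a, views_b) in &batches {
///         let (z_a, z_b) = (embed_a(views_a), embed_b(views_b));
///         let step = trainer.compute_loss(&z_a, &z_b)?;
///         epoch_loss += step.total_loss;
///         // ... backward pass and optimizer update with `lr` ...
///     }
///     trainer.on_epoch_end(epoch_loss / batches.len() as f64);
/// }
/// ```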
#[derive(Debug, Clone)]
pub struct SelfSupervisedTrainer {
pub config: SSLTrainerConfig,
pub global_step: usize,
pub current_epoch: usize,
pub loss_history: Vec<f64>,
}
impl SelfSupervisedTrainer {
pub fn new(config: SSLTrainerConfig) -> Self {
Self {
config,
global_step: 0,
current_epoch: 0,
loss_history: Vec::new(),
}
}
pub fn compute_loss<F>(&mut self, z_a: &Array2<F>, z_b: &Array2<F>) -> Result<SSLStepResult<F>>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
let mut components = std::collections::HashMap::new();
let total_loss = match &self.config.loss_type {
            SSLLossType::NTXent { temperature } => {
                use crate::training::contrastive::NTXentLoss;
                NTXentLoss::new(*temperature).forward(z_a, z_b)?
            }
SSLLossType::BarlowTwins(cfg) => {
let loss_fn = BarlowTwinsLoss::new(cfg.clone());
loss_fn.forward(z_a, z_b)?
}
SSLLossType::VICReg(cfg) => {
let loss_fn = VICRegLoss::new(cfg.clone());
let (total, inv, var, cov) = loss_fn.forward(z_a, z_b)?;
components.insert("invariance".to_string(), inv);
components.insert("variance".to_string(), var);
components.insert("covariance".to_string(), cov);
total
}
};
let step = self.global_step;
self.global_step += 1;
Ok(SSLStepResult {
total_loss,
step,
components,
})
}
pub fn on_epoch_end(&mut self, epoch_mean_loss: f64) {
self.loss_history.push(epoch_mean_loss);
self.current_epoch += 1;
}
pub fn current_lr(&self) -> f64 {
self.config.lr_at_epoch(self.current_epoch)
}
}
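/// Standardizes each feature column to zero mean and unit variance over the
/// batch (biased variance, `eps = 1e-12` for numerical safety). Used to form
/// the Barlow Twins cross-correlation matrix.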
fn batch_norm_features<F>(x: &Array2<F>) -> Result<Array2<F>>
where
F: Float + Debug + NumAssign + FromPrimitive,
{
let n = x.nrows();
let d = x.ncols();
if n < 2 {
return Err(NeuralError::InvalidArgument(
"batch_norm_features: need at least 2 samples".to_string(),
));
}
let n_f = F::from_usize(n).ok_or_else(|| {
NeuralError::ComputationError("batch_norm_features: cannot convert N".to_string())
})?;
let eps = F::from_f64(1e-12).ok_or_else(|| {
NeuralError::ComputationError("batch_norm_features: cannot convert eps".to_string())
})?;
let mut out = x.to_owned();
for j in 0..d {
let col = x.column(j);
let mean = col.iter().fold(F::zero(), |a, &b| a + b) / n_f;
let var = col
.iter()
.map(|&v| {
let d = v - mean;
d * d
})
.fold(F::zero(), |a, b| a + b)
/ n_f;
let std = (var + eps).sqrt();
for i in 0..n {
out[[i, j]] = (x[[i, j]] - mean) / std;
}
}
Ok(out)
}
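/// VICReg variance term: `mean_j max(0, gamma - std_j)` over the `D`
/// features, where `std_j` uses the unbiased variance estimator and is
/// stabilized by `eps` inside the square root.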
fn variance_loss<F>(
z: &Array2<F>,
gamma: F,
eps: F,
n_f: F,
d_f: F,
) -> Result<F>
where
F: Float + Debug + NumAssign + FromPrimitive,
{
let d = z.ncols();
let mut total = F::zero();
for j in 0..d {
let col = z.column(j);
let mean = col.iter().fold(F::zero(), |a, &b| a + b) / n_f;
let var = col
.iter()
.map(|&v| {
let diff = v - mean;
diff * diff
})
.fold(F::zero(), |a, b| a + b)
/ (n_f - F::one());
let std = (var + eps).sqrt();
let hinge = (gamma - std).max(F::zero());
total += hinge;
}
    Ok(total / d_f)
}
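/// VICReg covariance term: the features are centred, the unbiased `D x D`
/// covariance matrix is formed, and the sum of its squared off-diagonal
/// entries divided by `D` is returned.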
fn covariance_loss<F>(z: &Array2<F>, n_f: F, d_f: F) -> Result<F>
where
F: Float + Debug + NumAssign + FromPrimitive,
{
let n = z.nrows();
let d = z.ncols();
let mut z_centred = z.to_owned();
for j in 0..d {
let col = z.column(j);
let mean = col.iter().fold(F::zero(), |a, &b| a + b) / n_f;
for i in 0..n {
z_centred[[i, j]] -= mean;
}
}
let mut cov = Array2::<F>::zeros((d, d));
for i in 0..n {
for a in 0..d {
for b in 0..d {
cov[[a, b]] += z_centred[[i, a]] * z_centred[[i, b]];
}
}
}
let denom = n_f - F::one();
for v in cov.iter_mut() {
*v /= denom;
}
let mut total = F::zero();
for a in 0..d {
for b in 0..d {
if a != b {
total += cov[[a, b]] * cov[[a, b]];
}
}
}
Ok(total / d_f)
}
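/// Applies one [`SSLAugmentation`] to an `(N, C, H, W)` batch, returning a
/// new array. See the enum docs for the deterministic simplifications made
/// by each variant.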
fn apply_augmentation<F>(
batch: &Array4<F>,
aug: &SSLAugmentation,
) -> Result<Array4<F>>
where
F: Float + Debug + NumAssign + FromPrimitive + ToPrimitive,
{
match aug {
SSLAugmentation::Normalize { mean, std } => {
let n = batch.shape()[0];
let c = batch.shape()[1];
let h = batch.shape()[2];
let w = batch.shape()[3];
if mean.len() != c || std.len() != c {
return Err(NeuralError::ShapeMismatch(format!(
"Normalize augmentation: mean/std length {} but C={}",
mean.len(),
c
)));
}
let mut out = batch.to_owned();
for ni in 0..n {
for ci in 0..c {
let m = F::from_f64(mean[ci]).ok_or_else(|| {
NeuralError::ComputationError("Normalize: cannot convert mean".to_string())
})?;
let s = F::from_f64(std[ci]).ok_or_else(|| {
NeuralError::ComputationError("Normalize: cannot convert std".to_string())
})?;
for hi in 0..h {
for wi in 0..w {
out[[ni, ci, hi, wi]] = (batch[[ni, ci, hi, wi]] - m) / s;
}
}
}
}
Ok(out)
}
        SSLAugmentation::HorizontalFlip { probability } => {
            // Deterministic stand-in for a random flip: the whole batch is
            // flipped iff probability >= 0.5.
            if *probability >= 0.5 {
let w = batch.shape()[3];
let mut out = batch.to_owned();
let n = batch.shape()[0];
let c = batch.shape()[1];
let h = batch.shape()[2];
for ni in 0..n {
for ci in 0..c {
for hi in 0..h {
for wi in 0..w / 2 {
let mirror = w - 1 - wi;
let tmp = out[[ni, ci, hi, wi]];
out[[ni, ci, hi, wi]] = out[[ni, ci, hi, mirror]];
out[[ni, ci, hi, mirror]] = tmp;
}
}
}
}
Ok(out)
} else {
Ok(batch.to_owned())
}
}
        SSLAugmentation::VerticalFlip { probability } => {
            // Deterministic stand-in for a random flip: the whole batch is
            // flipped iff probability >= 0.5.
            if *probability >= 0.5 {
let h = batch.shape()[2];
let mut out = batch.to_owned();
let n = batch.shape()[0];
let c = batch.shape()[1];
let w = batch.shape()[3];
for ni in 0..n {
for ci in 0..c {
for hi in 0..h / 2 {
let mirror = h - 1 - hi;
for wi in 0..w {
let tmp = out[[ni, ci, hi, wi]];
out[[ni, ci, hi, wi]] = out[[ni, ci, mirror, wi]];
out[[ni, ci, mirror, wi]] = tmp;
}
}
}
}
Ok(out)
} else {
Ok(batch.to_owned())
}
}
SSLAugmentation::ColorJitter {
brightness,
contrast,
saturation,
        } => {
            // Simplified jitter: brightness is applied as a constant
            // additive shift; contrast and saturation are accepted for API
            // parity but currently ignored.
            let delta = F::from_f64(*brightness).ok_or_else(|| {
                NeuralError::ComputationError("ColorJitter: cannot convert brightness".to_string())
            })?;
            let _ = contrast;
            let _ = saturation;
            let out = batch.mapv(|v| v + delta);
Ok(out)
}
SSLAugmentation::GaussianBlur { sigma } => {
apply_gaussian_blur(batch, *sigma)
}
        SSLAugmentation::Grayscale { probability } => {
            // Deterministic: grayscale is applied iff probability >= 0.5.
            if *probability >= 0.5 {
apply_grayscale(batch)
} else {
Ok(batch.to_owned())
}
}
SSLAugmentation::RandomCrop {
height,
width,
padding,
} => apply_center_crop(batch, *height, *width, *padding),
}
}
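/// Deterministic center crop backing the `RandomCrop` variant: the image is
/// conceptually edge-padded by `padding` pixels (coordinates are clamped to
/// the source, replicating border pixels) and the central
/// `target_h x target_w` window is taken.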
fn apply_center_crop<F: Float + Debug + FromPrimitive>(
batch: &Array4<F>,
target_h: usize,
target_w: usize,
padding: usize,
) -> Result<Array4<F>> {
let n = batch.shape()[0];
let c = batch.shape()[1];
let h = batch.shape()[2];
let w = batch.shape()[3];
let ph = h + 2 * padding;
let pw = w + 2 * padding;
if target_h > ph || target_w > pw {
return Err(NeuralError::InvalidArgument(format!(
"RandomCrop: target ({},{}) larger than padded image ({},{})",
target_h, target_w, ph, pw
)));
}
let start_h = (ph - target_h) / 2;
let start_w = (pw - target_w) / 2;
let mut out = Array4::zeros((n, c, target_h, target_w));
for ni in 0..n {
for ci in 0..c {
for hi in 0..target_h {
for wi in 0..target_w {
let src_h = (start_h + hi).saturating_sub(padding);
let src_w = (start_w + wi).saturating_sub(padding);
let src_h = src_h.min(h - 1);
let src_w = src_w.min(w - 1);
out[[ni, ci, hi, wi]] = batch[[ni, ci, src_h, src_w]];
}
}
}
}
Ok(out)
}
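/// Cheap stand-in for a Gaussian blur: a fixed 3x3 box filter (uniform 1/9
/// weights) with replicated borders. `_sigma` is currently ignored; images
/// smaller than 3x3 pass through unchanged.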
fn apply_gaussian_blur<F: Float + Debug + NumAssign + FromPrimitive>(
batch: &Array4<F>,
_sigma: f64,
) -> Result<Array4<F>> {
let n = batch.shape()[0];
let c = batch.shape()[1];
let h = batch.shape()[2];
let w = batch.shape()[3];
if h < 3 || w < 3 {
return Ok(batch.to_owned());
}
let w_val = F::from_f64(1.0 / 9.0).ok_or_else(|| {
NeuralError::ComputationError("GaussianBlur: cannot convert weight".to_string())
})?;
let mut out = Array4::zeros((n, c, h, w));
for ni in 0..n {
for ci in 0..c {
for hi in 0..h {
for wi in 0..w {
let mut sum = F::zero();
for dh in 0..3usize {
for dw in 0..3usize {
let src_h = (hi + dh).saturating_sub(1).min(h - 1);
let src_w = (wi + dw).saturating_sub(1).min(w - 1);
sum += batch[[ni, ci, src_h, src_w]] * w_val;
}
}
out[[ni, ci, hi, wi]] = sum;
}
}
}
}
Ok(out)
}
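/// Converts a 3-channel batch to grayscale by averaging the channels with
/// equal 1/3 weights and broadcasting the result to all three channels.
/// Non-RGB batches pass through unchanged.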
fn apply_grayscale<F: Float + Debug + NumAssign + FromPrimitive>(
batch: &Array4<F>,
) -> Result<Array4<F>> {
let n = batch.shape()[0];
let c = batch.shape()[1];
let h = batch.shape()[2];
let w = batch.shape()[3];
if c != 3 {
return Ok(batch.to_owned());
}
let one_third = F::from_f64(1.0 / 3.0).ok_or_else(|| {
NeuralError::ComputationError("apply_grayscale: cannot convert 1/3".to_string())
})?;
let mut out = batch.to_owned();
for ni in 0..n {
for hi in 0..h {
for wi in 0..w {
let gray = (batch[[ni, 0, hi, wi]]
+ batch[[ni, 1, hi, wi]]
+ batch[[ni, 2, hi, wi]])
* one_third;
for ci in 0..c {
out[[ni, ci, hi, wi]] = gray;
}
}
}
}
Ok(out)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
#[test]
fn test_barlow_twins() {
let config = BarlowTwinsConfig::default();
let loss_fn = BarlowTwinsLoss::new(config);
let z_a = Array2::<f64>::from_shape_fn((8, 16), |(i, j)| (i + j) as f64 * 0.05);
let z_b = Array2::<f64>::from_shape_fn((8, 16), |(i, j)| (i + j) as f64 * 0.06);
let loss = loss_fn.forward(&z_a, &z_b).expect("BarlowTwins forward");
assert!(loss.is_finite());
assert!(loss >= 0.0);
}
#[test]
fn test_vicreg() {
let config = VICRegConfig::default();
let loss_fn = VICRegLoss::new(config);
let z_a = Array2::<f64>::from_shape_fn((8, 16), |(i, j)| (i * 2 + j) as f64 * 0.05);
let z_b = Array2::<f64>::from_shape_fn((8, 16), |(i, j)| (i * 2 + j) as f64 * 0.06);
let (total, inv, var, cov) = loss_fn.forward(&z_a, &z_b).expect("VICReg forward");
assert!(total.is_finite());
assert!(inv.is_finite());
assert!(var.is_finite());
assert!(cov.is_finite());
}
#[test]
fn test_masked_ae_config() {
let config = MaskedAEConfig::default();
config.validate().expect("config valid");
        assert_eq!(config.num_patches(), 196);
        assert_eq!(config.num_masked(), 147);
    }
#[test]
fn test_reconstruction_loss() {
let config = MaskedAEConfig::default();
let pred = Array2::<f64>::from_shape_fn((4, 16), |(i, j)| (i + j) as f64 * 0.1);
let target = Array2::<f64>::from_shape_fn((4, 16), |(i, j)| (i + j) as f64 * 0.12);
let loss = config
.reconstruction_loss(&pred, &target)
.expect("reconstruction_loss");
assert!(loss.is_finite());
assert!(loss >= 0.0);
}
#[test]
fn test_augmentation_pipeline_ssl() {
use scirs2_core::ndarray::Array4;
let pipeline = AugmentationPipelineSSL::simclr_cifar10();
assert!(!pipeline.is_empty());
let batch = Array4::<f64>::from_shape_fn((2, 3, 32, 32), |(n, c, h, w)| {
(n + c + h + w) as f64 * 0.01
});
let augmented = pipeline.apply(&batch).expect("pipeline apply");
assert_eq!(augmented.shape(), &[2, 3, 32, 32]);
}
#[test]
fn test_ssl_trainer_compute_loss() {
let config = SSLTrainerConfig {
loss_type: SSLLossType::BarlowTwins(BarlowTwinsConfig::default()),
epochs: 10,
..Default::default()
};
let mut trainer = SelfSupervisedTrainer::new(config);
let z_a = Array2::<f64>::from_shape_fn((8, 16), |(i, j)| (i + j) as f64 * 0.1);
let z_b = Array2::<f64>::from_shape_fn((8, 16), |(i, j)| (i + j) as f64 * 0.11);
let result = trainer.compute_loss(&z_a, &z_b).expect("compute_loss");
assert!(result.total_loss.is_finite());
assert_eq!(result.step, 0);
}
#[test]
fn test_ssl_lr_schedule() {
let config = SSLTrainerConfig {
epochs: 100,
warmup_epochs: 10,
base_lr: 1e-3,
min_lr: 1e-6,
..Default::default()
};
let lr_warmup = config.lr_at_epoch(9);
assert!((lr_warmup - 1e-3).abs() < 1e-10);
let lr_post = config.lr_at_epoch(50);
assert!(lr_post < 1e-3);
assert!(lr_post >= 1e-6);
}
}