// burn_dragon_vision 0.4.0
//
// Foveation and vision sampling utilities for burn dragon
// Documentation
use crate::train::prelude::*;
use burn::optim::Optimizer;
use burn::tensor::Distribution;
use burn_autodiff::Autodiff;
use burn_dragon_core::{
    FusedKernelConfig, ManifoldHyperConnectionsConfig, SpatialPositionalEncodingKind,
    VisionAttentionMode, VisionLatentActivation, VisionPatchEmbedMode,
};
use burn_ndarray::NdArray;

/// Checks that `lejepa_invariance_loss` yields a finite first element when fed
/// a randomly initialised rank-3 tensor of shape `[2, 4, 8]`.
#[test]
fn lejepa_invariance_loss_is_finite() {
    type TestBackend = NdArray<f32>;
    let dev = <TestBackend as BackendTrait>::Device::default();

    let projections = Tensor::<TestBackend, 3>::random([2, 4, 8], Distribution::Default, &dev);

    // Pull the loss back to host memory as plain `f32` values.
    let loss_values = lejepa_invariance_loss(projections)
        .to_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .expect("loss vec");
    let scalar = loss_values[0];

    assert!(scalar.is_finite());
}

/// Checks that `lejepa_sigreg_loss`, evaluated with the default LeJEPA loss
/// settings, yields a finite first element on a random `[2, 4, 8]` tensor.
#[test]
fn lejepa_sigreg_loss_is_finite() {
    type TestBackend = NdArray<f32>;
    let dev = <TestBackend as BackendTrait>::Device::default();
    let defaults = VisionLejepaConfig::default();

    let projections = Tensor::<TestBackend, 3>::random([2, 4, 8], Distribution::Default, &dev);

    // Pull the loss back to host memory as plain `f32` values.
    let loss_values = lejepa_sigreg_loss(projections, &defaults.loss.lejepa)
        .to_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .expect("loss vec");
    let scalar = loss_values[0];

    assert!(scalar.is_finite());
}

/// Round-trips a 1x3x4x4 image through `patchify` / `unpatchify` with a patch
/// size that divides both sides evenly, and expects the exact input back.
#[test]
fn patchify_roundtrip() {
    type TestBackend = NdArray<f32>;
    let dev = <TestBackend as BackendTrait>::Device::default();

    let batch: usize = 1;
    let channels: usize = 3;
    let (height, width): (usize, usize) = (4, 4);
    let patch_size: usize = 2;

    // A 0, 1, 2, ... ramp makes every element distinct, so any misplaced
    // value shows up in the final comparison.
    let element_count = batch * channels * height * width;
    let pixels: Vec<f32> = (0..element_count).map(|i| i as f32).collect();

    let input = Tensor::<TestBackend, 4>::from_data(
        TensorData::new(pixels.clone(), [batch, channels, height, width]),
        &dev,
    );

    let patches = patchify(input.clone(), patch_size);
    let [got_batch, got_tokens, got_patch_dim] = patches.shape().dims::<3>();
    let expected_tokens = (height / patch_size) * (width / patch_size);
    let expected_patch_dim = channels * patch_size * patch_size;
    assert_eq!(got_batch, batch);
    assert_eq!(got_tokens, expected_tokens);
    assert_eq!(got_patch_dim, expected_patch_dim);

    let restored = unpatchify(patches, patch_size, height, width, channels)
        .to_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .expect("recon vec");
    assert_eq!(pixels, restored);
}

/// Round-trips a 1x3x5x6 image through `patchify` / `unpatchify` with a patch
/// size of 4, which divides neither side. The token count is expected to match
/// the rounded-up patch grid, and the reconstruction must equal the input.
#[test]
fn patchify_roundtrip_with_padding() {
    type TestBackend = NdArray<f32>;
    let dev = <TestBackend as BackendTrait>::Device::default();

    let batch: usize = 1;
    let channels: usize = 3;
    let (height, width): (usize, usize) = (5, 6);
    let patch_size: usize = 4;

    // A 0, 1, 2, ... ramp makes every element distinct.
    let element_count = batch * channels * height * width;
    let pixels: Vec<f32> = (0..element_count).map(|i| i as f32).collect();

    let input = Tensor::<TestBackend, 4>::from_data(
        TensorData::new(pixels.clone(), [batch, channels, height, width]),
        &dev,
    );

    let patches = patchify(input.clone(), patch_size);
    let [got_batch, got_tokens, got_patch_dim] = patches.shape().dims::<3>();

    // Number of patches along each side, rounding partial patches up.
    let rows = height.div_ceil(patch_size);
    let cols = width.div_ceil(patch_size);
    let expected_patch_dim = channels * patch_size * patch_size;
    assert_eq!(got_batch, batch);
    assert_eq!(got_tokens, rows * cols);
    assert_eq!(got_patch_dim, expected_patch_dim);

    let restored = unpatchify(patches, patch_size, height, width, channels)
        .to_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .expect("recon vec");
    assert_eq!(pixels, restored);
}

/// Builds a deterministic synthetic image tensor of shape
/// `[batch, channels, height, width]` for use in tests.
///
/// Pixel values depend only on the channel index and the pixel position, so
/// every batch entry is identical:
///
/// * channel 0 is a horizontal ramp `x / max(width - 1, 1)`;
/// * channel 1 is a vertical ramp `y / max(height - 1, 1)`;
/// * every further channel is `0.55 * gx + 0.35 * gy + 0.1 * checker`, where
///   `checker` is `(x / 2 + y / 3 + c) % 2`.
///
/// The ramp denominators use saturating subtraction, so a `height` or `width`
/// of zero no longer underflows `usize` (which panics in debug builds).
fn toy_images<B: BackendTrait>(
    batch: usize,
    channels: usize,
    height: usize,
    width: usize,
    device: &B::Device,
) -> Tensor<B, 4> {
    // Clamp to at least 1 so a single-pixel (or empty) axis never divides by zero.
    let denom_w = width.saturating_sub(1).max(1) as f32;
    let denom_h = height.saturating_sub(1).max(1) as f32;

    let mut data = Vec::with_capacity(batch * channels * height * width);
    for _ in 0..batch {
        for c in 0..channels {
            for y in 0..height {
                let gy = y as f32 / denom_h;
                for x in 0..width {
                    let gx = x as f32 / denom_w;
                    let value = match c {
                        0 => gx,
                        1 => gy,
                        _ => {
                            // Only the blended channels need the checker term.
                            let checker = ((x / 2 + y / 3 + c) % 2) as f32;
                            0.55 * gx + 0.35 * gy + 0.1 * checker
                        }
                    };
                    data.push(value);
                }
            }
        }
    }

    Tensor::<B, 4>::from_data(
        TensorData::new(data, [batch, channels, height, width]),
        device,
    )
}

/// Optimises a small `VisionLejepaModel` on one fixed synthetic batch, with the
/// LeJEPA loss term disabled and the reconstruction weight set to 1.0, then
/// checks that the reported reconstruction PSNR is finite, higher than before
/// training, and above 24.0.
#[test]
fn lejepa_recon_psnr_improves_on_toy_batch() {
    type TestBackend = Autodiff<NdArray<f32>>;
    let dev = <TestBackend as BackendTrait>::Device::default();

    // Number of optimiser updates applied between the two PSNR readings.
    const TRAIN_ITERATIONS: usize = 40;
    // Lower bound the post-training PSNR has to clear.
    const MIN_FINAL_PSNR: f32 = 24.0;

    let image_size: usize = 8;
    let patch_size: usize = 4;
    // Patches per image side; sizes the positional-encoding table below.
    let grid = image_size.div_ceil(patch_size);

    let vision_config = VisionDragonHatchlingConfig {
        image_size,
        patch_size,
        patch_embed_mode: VisionPatchEmbedMode::default(),
        in_channels: 3,
        embed_dim: 32,
        steps: 1,
        n_head: 4,
        mlp_internal_dim_multiplier: 2,
        dropout: 0.0,
        projection_dim: 32,
        projection_hidden_dim: 64,
        use_cls_token: true,
        cls_sync_alpha: 0.0,
        num_eyes: 1,
        cross_eye_steps: 0,
        token_state_norm: true,
        latent_activation: VisionLatentActivation::default(),
        pos_encoding: SpatialPositionalEncodingKind::Learned2d,
        pos_max_height: grid,
        pos_max_width: grid,
        attention_mode: VisionAttentionMode::RowL1,
        use_alibi: true,
        fused_kernels: FusedKernelConfig::default(),
        mhc: ManifoldHyperConnectionsConfig::default(),
    };

    let mut train_config = VisionLejepaConfig {
        views: 1,
        global_views: 0,
        local_views: 0,
        artifact_every: 0,
        artifact_max_images: 0,
        artifact_max_views: 0,
        ..Default::default()
    };
    train_config.loss.lejepa.enabled = false;
    train_config.loss.recon.weight = 1.0;
    train_config.loss.recon.mask_ratio = 0.5;
    train_config.loss.recon.hidden_dim = 64;

    let rollout = VisionRollout {
        min_steps: 1,
        max_steps: 1,
        backprop_steps: 1,
    };

    // Read the scalars needed later before handing the config to the encoder.
    let embed_dim = vision_config.embed_dim;
    let recon_patch_dim =
        vision_config.patch_size * vision_config.patch_size * vision_config.in_channels;

    let encoder = VisionDragonHatchling::<TestBackend>::new(vision_config, &dev);
    let mut model = VisionLejepaModel::new(
        encoder,
        train_config,
        embed_dim,
        1,
        rollout,
        recon_patch_dim,
        &dev,
    );

    let batch_size = 2;
    let images = toy_images::<TestBackend>(batch_size, 3, image_size, image_size, &dev);
    let labels = Tensor::<TestBackend, 1, Int>::zeros([batch_size], &dev);
    let batch = ImageNetBatch::new(images, None, None, None, None, None, labels, None, None);
    let steps = 1;
    let backprop_steps = 1;

    // PSNR reported by the untrained model.
    let initial_psnr = {
        let losses = model.forward_losses(batch.clone(), steps, backprop_steps, false);
        let values = losses
            .recon_psnr
            .to_data()
            .convert::<f32>()
            .into_vec::<f32>()
            .expect("psnr vec");
        values[0]
    };

    let mut optimizer = AdamWConfig::new()
        .with_weight_decay(0.0)
        .init::<TestBackend, VisionLejepaModel<TestBackend>>();
    let lr = 0.02;
    for _ in 0..TRAIN_ITERATIONS {
        let step_losses = model.forward_losses(batch.clone(), steps, backprop_steps, false);
        let objective = step_losses.total.clone() + step_losses.probe_loss.clone();
        let grads = GradientsParams::from_grads(objective.backward(), &model);
        // `step` consumes the model and returns the updated one.
        model = optimizer.step(lr, model, grads);
    }

    // PSNR reported on the same batch after training.
    let final_psnr = {
        let losses = model.forward_losses(batch, steps, backprop_steps, false);
        let values = losses
            .recon_psnr
            .to_data()
            .convert::<f32>()
            .into_vec::<f32>()
            .expect("psnr vec");
        values[0]
    };

    assert!(final_psnr.is_finite());
    assert!(final_psnr > initial_psnr);
    assert!(final_psnr > MIN_FINAL_PSNR);
}