burn_dragon_vision 0.4.0

Foveation and vision-sampling utilities for burn_dragon
Documentation
#[cfg(not(target_arch = "wasm32"))]
use crate::train::init_wgpu_test_runtime;
use crate::train::prelude::*;
use burn_ndarray::NdArray;

#[cfg(all(feature = "cuda", not(target_arch = "wasm32")))]
use burn_cuda::Cuda;
#[cfg(not(target_arch = "wasm32"))]
use burn_wgpu::Wgpu;

/// Number of patch tokens obtained by tiling a square image of side 128 with
/// square patches of side `patch_size`, i.e. `(128 / patch_size)^2`.
///
/// A `patch_size` of 0 is treated as 1. The per-side patch count is clamped to at
/// least 1, so a patch larger than the image still yields a single token.
fn token_count_for_patch(patch_size: usize) -> usize {
    const IMAGE_SIDE: usize = 128;
    let patches_per_side = match IMAGE_SIDE / patch_size.max(1) {
        0 => 1,
        n => n,
    };
    patches_per_side.pow(2)
}

/// Panics if any element of `values` is NaN or infinite; an empty slice passes.
///
/// The panic message names the first offending index and value (and the slice
/// length). A failing backend test then points directly at the bad output element
/// instead of reporting a bare "assertion failed".
fn assert_finite(values: &[f32]) {
    if let Some(index) = values.iter().position(|value| !value.is_finite()) {
        panic!(
            "non-finite value {} at index {index} of {} values",
            values[index],
            values.len()
        );
    }
}

/// Runs one forward pass of `VisionSaccadeInputProjection` built from `config` on
/// backend `B`. It checks that the output keeps the requested
/// `[batch, tokens, embed_dim]` dims and that every output value is finite.
///
/// The input is a random tensor with `batch = 2`, `embed_dim = 32`, and a token
/// count of `token_count_for_patch(patch_size)`.
fn run_projection_backend<B: BackendTrait>(
    device: &B::Device,
    config: VisionSaccadeInputProjectionConfig,
    patch_size: usize,
) {
    let embed_dim = 32usize;
    let batch = 2usize;
    let tokens = token_count_for_patch(patch_size);
    let expected_dims = [batch, tokens, embed_dim];
    let input = Tensor::<B, 3>::random(expected_dims, TensorDistribution::Default, device);
    let projection = VisionSaccadeInputProjection::new(embed_dim, patch_size, &config, device);
    // The requested dims are already known, so the input can be moved into
    // `forward` rather than cloned only to read its shape afterwards.
    let output = projection.forward(input);
    assert_eq!(
        output.shape().dims::<3>(),
        expected_dims,
        "projection output shape differs from input (patch_size = {patch_size})"
    );
    let data = output
        .to_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .expect("output vec");
    assert_finite(&data);
}

/// Runs `run_projection_backend` at one `patch_size` for the linear, CNN, and
/// radial micro-ViT projection configs, in that order. The CNN and micro-ViT
/// configs use their `Default` settings.
fn run_projection_variants<B: BackendTrait>(device: &B::Device, patch_size: usize) {
    let variants = [
        VisionSaccadeInputProjectionConfig::Linear,
        VisionSaccadeInputProjectionConfig::Cnn(VisionSaccadeInputProjectionCnnConfig::default()),
        VisionSaccadeInputProjectionConfig::RadialMicroVit(
            VisionSaccadeInputProjectionMicroVitConfig::default(),
        ),
    ];
    for variant in variants {
        run_projection_backend::<B>(device, variant, patch_size);
    }
}

/// Patch sizes exercised by every backend's forward-shape test. They are shared so
/// the NdArray, WGPU, and CUDA runs cannot drift apart.
const PROJECTION_TEST_PATCH_SIZES: [usize; 4] = [16, 32, 64, 128];

/// Forward shape and finiteness check on the `NdArray<f32>` backend for every
/// projection variant and every patch size in `PROJECTION_TEST_PATCH_SIZES`.
#[test]
fn input_projection_forward_shapes_across_patch_sizes() {
    type Backend = NdArray<f32>;
    let device = <Backend as BackendTrait>::Device::default();
    for patch_size in PROJECTION_TEST_PATCH_SIZES {
        run_projection_variants::<Backend>(&device, patch_size);
    }
}

/// Same check on the `Wgpu<f32>` backend; compiled only for non-wasm32 targets.
/// `init_wgpu_test_runtime` is called on the default device before any projection
/// runs.
#[cfg(not(target_arch = "wasm32"))]
#[test]
fn input_projection_forward_shapes_wgpu() {
    type Backend = Wgpu<f32>;
    let device = <Backend as BackendTrait>::Device::default();
    init_wgpu_test_runtime(&device);
    for patch_size in PROJECTION_TEST_PATCH_SIZES {
        run_projection_variants::<Backend>(&device, patch_size);
    }
}

/// Same check on the `Cuda<f32>` backend; compiled only with the `cuda` feature on
/// non-wasm32 targets. The device is obtained the same way as in the other backend
/// tests.
#[cfg(all(feature = "cuda", not(target_arch = "wasm32")))]
#[test]
fn input_projection_forward_shapes_cuda() {
    type Backend = Cuda<f32>;
    let device = <Backend as BackendTrait>::Device::default();
    for patch_size in PROJECTION_TEST_PATCH_SIZES {
        run_projection_variants::<Backend>(&device, patch_size);
    }
}

/// Checks that a `Linear` input projection with `embed_dim = 32` and
/// `patch_size = 32` reports `embed_dim^2 + 3 * embed_dim` parameters.
#[test]
fn input_projection_linear_param_count_matches_formula() {
    type Backend = NdArray<f32>;
    let (embed_dim, patch_size) = (32usize, 32usize);
    let device = <Backend as BackendTrait>::Device::default();
    let projection: VisionSaccadeInputProjection<Backend> = VisionSaccadeInputProjection::new(
        embed_dim,
        patch_size,
        &VisionSaccadeInputProjectionConfig::Linear,
        &device,
    );
    // NOTE(review): presumably a d x d weight matrix plus three d-sized vectors
    // (e.g. bias and norm parameters) — confirm against the Linear projection impl.
    let expected = embed_dim * embed_dim + 3 * embed_dim;
    assert_eq!(projection.param_count(), expected);
}