brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with Burn ML
Documentation
use burn::backend::NdArray;
use burn::prelude::*;

type B = NdArray;

/// Returns the CPU variant of the NdArray device.
///
/// The tests below pass a reference to this value to every module
/// constructor and tensor factory they call.
fn device() -> burn::backend::ndarray::NdArrayDevice {
    burn::backend::ndarray::NdArrayDevice::Cpu
}

// -- FlexiPatchEmbed --------------------------------------------------------------

/// Feeds a single zero-filled sample of `n_rois` rows, each exactly four
/// patches long, through `FlexiPatchEmbed` and checks that the result has
/// dims `[1, n_rois * 4, embed_dim]`.
#[test]
fn patch_embed_output_shape() {
    let dev = device();
    let (patch_size, n_rois, embed_dim) = (48, 10, 768);
    // The signal is sized as a whole number of patches so the division is exact.
    let patches_per_roi = 4;
    let signal_length = patch_size * patches_per_roi;

    let embedder = brainharmony::model::patch_embed::FlexiPatchEmbed::<B>::new(
        (n_rois, signal_length),
        patch_size,
        1,
        embed_dim,
        &dev,
    );

    let signal = Tensor::<B, 4>::zeros([1, 1, n_rois, signal_length], &dev);
    let tokens = embedder.forward(signal, None);

    assert_eq!(tokens.dims(), [1, n_rois * patches_per_roi, embed_dim]);
}

/// Builds a `FlexiPatchEmbed` for a `(400, 864)` input with patch size 48 and
/// checks the `num_patches` and `num_patches_2d` fields it exposes.
#[test]
fn patch_embed_num_patches() {
    // 864 / 48 = 18 patches along the second axis of the input size.
    let expected_total = 400 * 18;
    let expected_grid = (400, 18);

    let embedder = brainharmony::model::patch_embed::FlexiPatchEmbed::<B>::new(
        (400, 864),
        48,
        1,
        768,
        &device(),
    );

    assert_eq!(embedder.num_patches, expected_total);
    assert_eq!(embedder.num_patches_2d, expected_grid);
}

/// Runs a batch of four zero-filled `8 x 96` inputs through a
/// `FlexiPatchEmbed` with patch size 48 and embedding width 128, and checks
/// the output dims are `[4, 16, 128]`.
#[test]
fn patch_embed_batch_forward() {
    let dev = device();
    let embedder = brainharmony::model::patch_embed::FlexiPatchEmbed::<B>::new(
        (8, 96),
        48,
        1,
        128,
        &dev,
    );

    let batch = Tensor::<B, 4>::zeros([4, 1, 8, 96], &dev);
    let tokens = embedder.forward(batch, None);

    // 96 / 48 = 2, so the middle dimension is 8 * 2.
    assert_eq!(tokens.dims(), [4, 8 * 2, 128]);
}

// -- Block ------------------------------------------------------------------------

/// A `Block` applied to a zero tensor of dims `[1, 100, dim]` must return a
/// tensor with the same dims.
#[test]
fn block_forward_preserves_shape() {
    let dev = device();
    let (dim, num_heads) = (64, 4);
    let shape = [1, 100, dim];

    let layer = brainharmony::model::block::Block::<B>::new(
        dim,
        num_heads,
        4.0,
        true,
        1e-6,
        &dev,
    );
    let result = layer.forward(Tensor::<B, 3>::zeros(shape, &dev), None);

    assert_eq!(result.dims(), shape);
}

/// Same shape check as the single-sample `Block` test, but with a leading
/// dimension of 2: dims `[2, 50, dim]` in, `[2, 50, dim]` out.
#[test]
fn block_forward_batch() {
    let dev = device();
    let dim = 64;
    let shape = [2, 50, dim];

    let layer = brainharmony::model::block::Block::<B>::new(dim, 4, 4.0, true, 1e-6, &dev);
    let result = layer.forward(Tensor::<B, 3>::zeros(shape, &dev), None);

    assert_eq!(result.dims(), shape);
}

// -- Attention --------------------------------------------------------------------

/// An `Attention` module applied to a zero tensor of dims `[1, 100, dim]`
/// must return a tensor with the same dims.
#[test]
fn attention_forward_preserves_shape() {
    let dev = device();
    let (dim, num_heads) = (64, 4);
    let shape = [1, 100, dim];

    let attention =
        brainharmony::model::attention::Attention::<B>::new(dim, num_heads, true, &dev);
    let result = attention.forward(Tensor::<B, 3>::zeros(shape, &dev), None);

    assert_eq!(result.dims(), shape);
}

/// Checks the `num_heads`, `head_dim` and `scale` fields of an `Attention`
/// built with `dim = 128` and eight heads.
#[test]
fn attention_head_dim_and_scale() {
    let (dim, num_heads) = (128, 8);
    // 128 / 8 = 16, and the scale is compared against 16^-0.5.
    let expected_scale = (16.0f64).powf(-0.5) as f32;

    let attn = brainharmony::model::attention::Attention::<B>::new(
        dim,
        num_heads,
        true,
        &device(),
    );

    assert_eq!(attn.num_heads, 8);
    assert_eq!(attn.head_dim, 16);
    // Tolerance comparison rather than equality, since `scale` is a float.
    assert!((attn.scale - expected_scale).abs() < 1e-6);
}

/// Constructing `Attention` with `false` as its third argument must leave
/// `qkv.bias` as `None`.
#[test]
fn attention_no_qkv_bias() {
    let qkv_bias = false;
    let attn = brainharmony::model::attention::Attention::<B>::new(64, 4, qkv_bias, &device());

    assert!(attn.qkv.bias.is_none());
}

// -- MLP --------------------------------------------------------------------------

/// An `MLP` built with `(dim, hidden) = (64, 256)` must map a zero tensor of
/// dims `[1, 100, dim]` to a tensor with the same dims.
#[test]
fn mlp_forward_preserves_shape() {
    let dev = device();
    let (dim, hidden) = (64, 256);
    let shape = [1, 100, dim];

    let feedforward = brainharmony::model::feedforward::MLP::<B>::new(dim, hidden, &dev);
    let result = feedforward.forward(Tensor::<B, 3>::zeros(shape, &dev));

    assert_eq!(result.dims(), shape);
}

/// Feeds an all-zero tensor through a freshly constructed `MLP` and requires
/// the largest absolute output value to stay below `1e-6`.
///
/// NOTE(review): this can only hold if every bias inside the MLP is zero (or
/// absent) at construction and its activation maps 0 to 0 — confirm against
/// `MLP::new`, which is not visible from this file.
#[test]
fn mlp_zeros_produce_zeros() {
    use burn::prelude::ElementConversion;

    let dev = device();
    let dim = 16;
    let mlp = brainharmony::model::feedforward::MLP::<B>::new(dim, 64, &dev);

    let zeros = Tensor::<B, 3>::zeros([1, 5, dim], &dev);
    let peak = mlp.forward(zeros).abs().max();

    let max_abs = peak.into_scalar().elem::<f32>();
    assert!(max_abs < 1e-6, "expected all zeros, max abs = {max_abs}");
}

// -- MLPHead ----------------------------------------------------------------------

/// An `MLPHead` constructed with `(768, 384, 2)` must turn a zero tensor of
/// dims `[2, 768]` into a tensor of dims `[2, 2]`.
#[test]
fn mlp_head_forward_shape() {
    let dev = device();
    let (batch, in_features, out_features) = (2, 768, 2);

    let head = brainharmony::MLPHead::<B>::new(in_features, 384, out_features, &dev);
    let features = Tensor::<B, 2>::zeros([batch, in_features], &dev);
    let logits = head.forward(features);

    assert_eq!(logits.dims(), [batch, out_features]);
}