use burn::backend::NdArray;
use burn::prelude::*;
type B = NdArray;
fn device() -> burn::backend::ndarray::NdArrayDevice {
    burn::backend::ndarray::NdArrayDevice::Cpu
}
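// FlexiPatchEmbed should emit n_rois * (signal_length / patch_size) patch embeddings of size embed_dim.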
#[test]
fn patch_embed_output_shape() {
    let patch_size = 48;
    let n_rois = 10;
    let signal_length = 48 * 4;
    let embed_dim = 768;
    let pe = brainharmony::model::patch_embed::FlexiPatchEmbed::<B>::new(
        (n_rois, signal_length),
        patch_size,
        1,
        embed_dim,
        &device(),
    );
    let input = Tensor::<B, 4>::zeros([1, 1, n_rois, signal_length], &device());
    let output = pe.forward(input, None);
    let expected_patches = n_rois * (signal_length / patch_size);
    assert_eq!(output.dims(), [1, expected_patches, embed_dim]);
}
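// For a (400, 864) input grid with patch size 48: 864 / 48 = 18 temporal patches per ROI.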
#[test]
fn patch_embed_num_patches() {
    let pe = brainharmony::model::patch_embed::FlexiPatchEmbed::<B>::new(
        (400, 864),
        48,
        1,
        768,
        &device(),
    );
    assert_eq!(pe.num_patches, 400 * 18);
    assert_eq!(pe.num_patches_2d, (400, 18));
}
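// The leading batch dimension should pass through the patch embedding unchanged.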
#[test]
fn patch_embed_batch_forward() {
    let pe = brainharmony::model::patch_embed::FlexiPatchEmbed::<B>::new(
        (8, 96),
        48,
        1,
        128,
        &device(),
    );
    let input = Tensor::<B, 4>::zeros([4, 1, 8, 96], &device());
    let output = pe.forward(input, None);
    assert_eq!(output.dims(), [4, 8 * 2, 128]);
}
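// Block::forward should preserve the [batch, tokens, dim] shape of its input.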
#[test]
fn block_forward_preserves_shape() {
    let dim = 64;
    let num_heads = 4;
    let block = brainharmony::model::block::Block::<B>::new(
        dim, num_heads, 4.0, true, 1e-6, &device(),
    );
    let input = Tensor::<B, 3>::zeros([1, 100, dim], &device());
    let output = block.forward(input, None);
    assert_eq!(output.dims(), [1, 100, dim]);
}
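// Same shape check for a batch of two sequences.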
#[test]
fn block_forward_batch() {
    let dim = 64;
    let block = brainharmony::model::block::Block::<B>::new(
        dim, 4, 4.0, true, 1e-6, &device(),
    );
    let input = Tensor::<B, 3>::zeros([2, 50, dim], &device());
    let output = block.forward(input, None);
    assert_eq!(output.dims(), [2, 50, dim]);
}
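// Attention::forward should also preserve the [batch, tokens, dim] shape of its input.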
#[test]
fn attention_forward_preserves_shape() {
    let dim = 64;
    let num_heads = 4;
    let attn = brainharmony::model::attention::Attention::<B>::new(
        dim, num_heads, true, &device(),
    );
    let input = Tensor::<B, 3>::zeros([1, 100, dim], &device());
    let output = attn.forward(input, None);
    assert_eq!(output.dims(), [1, 100, dim]);
}
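// head_dim = dim / num_heads = 128 / 8 = 16, and scale = head_dim^(-1/2), the usual scaled dot-product factor.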
#[test]
fn attention_head_dim_and_scale() {
    let dim = 128;
    let num_heads = 8;
    let attn = brainharmony::model::attention::Attention::<B>::new(
        dim, num_heads, true, &device(),
    );
    assert_eq!(attn.num_heads, 8);
    assert_eq!(attn.head_dim, 16);
    let expected_scale = (16.0f64).powf(-0.5) as f32;
    assert!((attn.scale - expected_scale).abs() < 1e-6);
}
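// With qkv_bias = false the fused QKV projection should carry no bias parameter.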
#[test]
fn attention_no_qkv_bias() {
    let attn = brainharmony::model::attention::Attention::<B>::new(
        64, 4, false, &device(),
    );
    assert!(attn.qkv.bias.is_none());
}
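// The MLP projects dim -> hidden -> dim, so the token shape is preserved.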
#[test]
fn mlp_forward_preserves_shape() {
    let dim = 64;
    let hidden = 256;
    let mlp = brainharmony::model::feedforward::MLP::<B>::new(dim, hidden, &device());
    let input = Tensor::<B, 3>::zeros([1, 100, dim], &device());
    let output = mlp.forward(input);
    assert_eq!(output.dims(), [1, 100, dim]);
}
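// An all-zero input should produce an all-zero output, checked via the maximum absolute value.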
#[test]
fn mlp_zeros_produce_zeros() {
    let dim = 16;
    let mlp = brainharmony::model::feedforward::MLP::<B>::new(dim, 64, &device());
    let input = Tensor::<B, 3>::zeros([1, 5, dim], &device());
    let output = mlp.forward(input);
    use burn::prelude::ElementConversion;
    let max_abs: f32 = output.abs().max().into_scalar().elem();
    assert!(max_abs < 1e-6, "expected all zeros, max abs = {max_abs}");
}
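// MLPHead maps a [batch, 768] feature vector to [batch, 2] outputs.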
#[test]
fn mlp_head_forward_shape() {
    let head = brainharmony::MLPHead::<B>::new(768, 384, 2, &device());
    let input = Tensor::<B, 2>::zeros([2, 768], &device());
    let output = head.forward(input);
    assert_eq!(output.dims(), [2, 2]);
}