use std::collections::HashMap;
use burn::module::{Module, Param};
use burn::nn::{LayerNorm, LayerNormConfig, LayerNormRecord, Linear, LinearConfig, LinearRecord};
use burn::prelude::*;
use burn::tensor::backend::Backend;
use burn::tensor::TensorData;
use jepa_core::ema::Ema;
use jepa_core::types::Representation;
use jepa_core::Encoder;
use crate::patch::{PatchEmbedding, PatchEmbeddingConfig};
use crate::rope::{RotaryPositionEncoding2D, RotaryPositionEncoding2DConfig};
use crate::token_ops::gather_token_sequence;
#[derive(Debug, Clone, thiserror::Error, PartialEq, Eq)]
pub enum VitLoadError {
#[error("missing checkpoint tensor `{0}`")]
MissingKey(String),
#[error(
"shape mismatch for `{key}`: checkpoint {checkpoint_shape:?} vs model {model_shape:?}"
)]
ShapeMismatch {
key: String,
checkpoint_shape: Vec<usize>,
model_shape: Vec<usize>,
},
}
/// Hyper-parameters for a [`VitEncoder`]; build the encoder with [`VitConfig::init`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct VitConfig {
    /// Number of channels in each input image.
    pub in_channels: usize,
    /// Input image height in pixels.
    pub image_height: usize,
    /// Input image width in pixels.
    pub image_width: usize,
    /// Patch size as `(height, width)`; `.0` divides `image_height`, `.1` divides `image_width`.
    pub patch_size: (usize, usize),
    /// Token embedding width, shared by every transformer block.
    pub embed_dim: usize,
    /// Number of transformer blocks.
    pub num_layers: usize,
    /// Number of attention heads per block.
    pub num_heads: usize,
    /// Hidden width of each block's MLP.
    pub mlp_dim: usize,
    /// Dropout probability. NOTE(review): `VitConfig::init` does not currently read this field.
    pub dropout: f64,
}
impl VitConfig {
    /// Shared constructor for the RGB presets.
    ///
    /// Each preset takes square `image_size`×`image_size` inputs, square patches of side
    /// `patch`, and no dropout.
    fn rgb_preset(
        image_size: usize,
        patch: usize,
        embed_dim: usize,
        num_layers: usize,
        num_heads: usize,
        mlp_dim: usize,
    ) -> Self {
        Self {
            in_channels: 3,
            image_height: image_size,
            image_width: image_size,
            patch_size: (patch, patch),
            embed_dim,
            num_layers,
            num_heads,
            mlp_dim,
            dropout: 0.0,
        }
    }

    /// ViT-Base: 224×224 RGB, 16×16 patches, width 768, 12 layers, 12 heads, MLP 3072.
    pub fn vit_base_patch16() -> Self {
        Self::rgb_preset(224, 16, 768, 12, 12, 3072)
    }

    /// ViT-Small: 224×224 RGB, 16×16 patches, width 384, 12 layers, 6 heads, MLP 1536.
    pub fn vit_small_patch16() -> Self {
        Self::rgb_preset(224, 16, 384, 12, 6, 1536)
    }

    /// ViT-Large: 224×224 RGB, 16×16 patches, width 1024, 24 layers, 16 heads, MLP 4096.
    pub fn vit_large_patch16() -> Self {
        Self::rgb_preset(224, 16, 1024, 24, 16, 4096)
    }

    /// ViT-Huge: 224×224 RGB, 14×14 patches, width 1280, 32 layers, 16 heads, MLP 5120.
    pub fn vit_huge_patch14() -> Self {
        Self::rgb_preset(224, 14, 1280, 32, 16, 5120)
    }

    /// ViT-Huge at 448×448 RGB, 16×16 patches, width 1280, 32 layers, 16 heads, MLP 5120.
    pub fn vit_huge_patch16_448() -> Self {
        Self::rgb_preset(448, 16, 1280, 32, 16, 5120)
    }

    /// ViT-Giant: 224×224 RGB, 16×16 patches, width 1408, 40 layers, 16 heads, MLP 6144.
    pub fn vit_giant_patch16() -> Self {
        Self::rgb_preset(224, 16, 1408, 40, 16, 6144)
    }

    /// Very small single-channel 8×8 configuration for fast unit tests (4×4 = 16 tokens).
    pub fn tiny_test() -> Self {
        Self {
            in_channels: 1,
            image_height: 8,
            image_width: 8,
            patch_size: (2, 2),
            embed_dim: 32,
            num_layers: 2,
            num_heads: 4,
            mlp_dim: 64,
            dropout: 0.0,
        }
    }

    /// Number of patch rows (integer division, so a remainder is dropped).
    fn grid_height(&self) -> usize {
        let (patch_h, _) = self.patch_size;
        self.image_height / patch_h
    }

    /// Number of patch columns (integer division, so a remainder is dropped).
    fn grid_width(&self) -> usize {
        let (_, patch_w) = self.patch_size;
        self.image_width / patch_w
    }

    /// Instantiates a freshly initialized [`VitEncoder`] on `device`.
    ///
    /// Sub-modules are created in this order: patch embedding, rotary positional encoding,
    /// the `num_layers` transformer blocks, then the final layer norm.
    pub fn init<B: Backend>(&self, device: &B::Device) -> VitEncoder<B> {
        let (patch_h, patch_w) = self.patch_size;

        let patch_embed =
            PatchEmbeddingConfig::new(self.in_channels, patch_h, patch_w, self.embed_dim)
                .init(device);

        let positional_encoding = RotaryPositionEncoding2DConfig::new(
            self.embed_dim,
            self.grid_height(),
            self.grid_width(),
        )
        .init(device);

        let block_config = TransformerBlockConfig {
            embed_dim: self.embed_dim,
            num_heads: self.num_heads,
            mlp_dim: self.mlp_dim,
        };
        let mut blocks: Vec<TransformerBlock<B>> = Vec::with_capacity(self.num_layers);
        for _ in 0..self.num_layers {
            blocks.push(block_config.init(device));
        }

        let norm = LayerNormConfig::new(self.embed_dim).init(device);

        VitEncoder {
            patch_embed,
            positional_encoding,
            blocks,
            norm,
            embed_dim: self.embed_dim,
        }
    }
}
/// Vision Transformer encoder.
///
/// Pipeline: patch embedding → 2D rotary positional encoding → transformer blocks → layer norm.
#[derive(Module, Debug)]
pub struct VitEncoder<B: Backend> {
    /// Turns image patches into token embeddings.
    patch_embed: PatchEmbedding<B>,
    /// 2D rotary positional encoding applied to the patch tokens.
    positional_encoding: RotaryPositionEncoding2D<B>,
    /// Transformer blocks, applied in order.
    blocks: Vec<TransformerBlock<B>>,
    /// Final layer norm applied after the last block.
    norm: LayerNorm<B>,
    /// Token embedding width reported through `Encoder::embed_dim`.
    embed_dim: usize,
}
impl<B: Backend> VitEncoder<B> {
    /// Embeds `images` into patch tokens and applies the 2D rotary positional encoding.
    fn positioned_patch_tokens(&self, images: &Tensor<B, 4>) -> Tensor<B, 3> {
        let x = self.patch_embed.forward(images.clone());
        self.positional_encoding.forward(x)
    }

    /// Runs already-positioned tokens through every transformer block, then the final norm.
    fn encode_positioned_tokens(&self, mut x: Tensor<B, 3>) -> Representation<B> {
        for block in &self.blocks {
            x = block.forward(x);
        }
        x = self.norm.forward(x);
        Representation::new(x)
    }

    /// Encodes every patch of `images` into a token representation.
    ///
    /// The tests feed images shaped `[batch, in_channels, height, width]` and get one token per
    /// patch back.
    pub fn forward(&self, images: &Tensor<B, 4>) -> Representation<B> {
        let x = self.positioned_patch_tokens(images);
        self.encode_positioned_tokens(x)
    }

    /// Encodes only the patch tokens selected by `visible_indices`.
    ///
    /// Positional encoding is applied *before* the tokens are gathered, so each kept token
    /// carries the position of its original patch. The gather is delegated to
    /// `gather_token_sequence`, presumably along the token axis (TODO confirm).
    pub fn forward_visible_tokens(
        &self,
        images: &Tensor<B, 4>,
        visible_indices: &[usize],
    ) -> Representation<B> {
        let x = self.positioned_patch_tokens(images);
        let x = gather_token_sequence(x, visible_indices);
        self.encode_positioned_tokens(x)
    }

    /// Replaces this encoder's learnable weights with tensors from a flat name → data map.
    ///
    /// The following names are read:
    /// - `patch_embed.projection`
    /// - `blocks.{i}.norm1`, `blocks.{i}.attn.qkv`, `blocks.{i}.attn.out_proj`,
    ///   `blocks.{i}.norm2`, `blocks.{i}.mlp.fc1` and `blocks.{i}.mlp.fc2`
    /// - `norm`
    ///
    /// Each name takes a `.weight` suffix. A `.bias` suffix is read only when the corresponding
    /// model layer has a bias. Extra entries in `tensors` are ignored. The positional encoding
    /// is not loaded.
    ///
    /// # Errors
    ///
    /// - [`VitLoadError::MissingKey`] if a required tensor is absent.
    /// - [`VitLoadError::ShapeMismatch`] if a tensor's dimensions differ from the parameter's.
    ///
    /// The new record is only applied after every tensor has loaded. On error, the encoder
    /// (taken by value) is dropped and nothing is returned.
    pub fn load_named_tensors(
        self,
        tensors: &HashMap<String, TensorData>,
    ) -> Result<Self, VitLoadError> {
        let mut record = self.clone().into_record();
        load_linear_record(
            &mut record.patch_embed.projection,
            "patch_embed.projection",
            tensors,
        )?;
        for (index, block) in record.blocks.iter_mut().enumerate() {
            load_layer_norm_record(&mut block.norm1, &format!("blocks.{index}.norm1"), tensors)?;
            load_linear_record(
                &mut block.attn.qkv,
                &format!("blocks.{index}.attn.qkv"),
                tensors,
            )?;
            load_linear_record(
                &mut block.attn.out_proj,
                &format!("blocks.{index}.attn.out_proj"),
                tensors,
            )?;
            load_layer_norm_record(&mut block.norm2, &format!("blocks.{index}.norm2"), tensors)?;
            load_linear_record(
                &mut block.mlp.fc1,
                &format!("blocks.{index}.mlp.fc1"),
                tensors,
            )?;
            load_linear_record(
                &mut block.mlp.fc2,
                &format!("blocks.{index}.mlp.fc2"),
                tensors,
            )?;
        }
        load_layer_norm_record(&mut record.norm, "norm", tensors)?;
        Ok(self.load_record(record))
    }

    /// Returns this (target) encoder with every learnable parameter moved toward `online`.
    ///
    /// Each parameter is updated with `ema.update_tensor(target, online, step)`. The positional
    /// encoding is not touched.
    ///
    /// # Panics
    ///
    /// Panics if `online` has a different number of transformer blocks. Without this check,
    /// the block-wise zip below would silently leave the surplus target blocks un-averaged.
    pub fn ema_update_from(self, online: &Self, ema: &Ema, step: usize) -> Self {
        assert_eq!(
            self.blocks.len(),
            online.blocks.len(),
            "EMA target and online encoders must have the same number of transformer blocks"
        );
        let mut target_record = self.clone().into_record();
        let online_record = online.clone().into_record();
        ema_update_linear_record(
            &mut target_record.patch_embed.projection,
            &online_record.patch_embed.projection,
            ema,
            step,
        );
        for (target_block, online_block) in target_record
            .blocks
            .iter_mut()
            .zip(online_record.blocks.iter())
        {
            ema_update_layer_norm_record(&mut target_block.norm1, &online_block.norm1, ema, step);
            ema_update_linear_record(
                &mut target_block.attn.qkv,
                &online_block.attn.qkv,
                ema,
                step,
            );
            ema_update_linear_record(
                &mut target_block.attn.out_proj,
                &online_block.attn.out_proj,
                ema,
                step,
            );
            ema_update_layer_norm_record(&mut target_block.norm2, &online_block.norm2, ema, step);
            ema_update_linear_record(&mut target_block.mlp.fc1, &online_block.mlp.fc1, ema, step);
            ema_update_linear_record(&mut target_block.mlp.fc2, &online_block.mlp.fc2, ema, step);
        }
        ema_update_layer_norm_record(&mut target_record.norm, &online_record.norm, ema, step);
        self.load_record(target_record)
    }
}
impl<B: Backend> Encoder<B> for VitEncoder<B> {
type Input = Tensor<B, 4>;
fn encode(&self, input: &Self::Input) -> Representation<B> {
self.forward(input)
}
fn embed_dim(&self) -> usize {
self.embed_dim
}
}
/// Loads a `Linear` layer's weight and, if the layer has one, its bias.
///
/// The tensors are read from `{prefix}.weight` and `{prefix}.bias`.
fn load_linear_record<B: Backend>(
    record: &mut LinearRecord<B>,
    prefix: &str,
    tensors: &HashMap<String, TensorData>,
) -> Result<(), VitLoadError> {
    let weight_key = format!("{prefix}.weight");
    let bias_key = format!("{prefix}.bias");
    load_param_from_tensors(&mut record.weight, &weight_key, tensors)?;
    load_optional_param_from_tensors(&mut record.bias, &bias_key, tensors)
}
/// Loads a `LayerNorm`'s scale (`gamma`) and, if present, its shift (`beta`).
///
/// The checkpoint stores them under `{prefix}.weight` and `{prefix}.bias` respectively.
fn load_layer_norm_record<B: Backend>(
    record: &mut LayerNormRecord<B>,
    prefix: &str,
    tensors: &HashMap<String, TensorData>,
) -> Result<(), VitLoadError> {
    let gamma_key = format!("{prefix}.weight");
    let beta_key = format!("{prefix}.bias");
    load_param_from_tensors(&mut record.gamma, &gamma_key, tensors)?;
    load_optional_param_from_tensors(&mut record.beta, &beta_key, tensors)
}
fn load_param_from_tensors<B: Backend, const D: usize>(
param: &mut Param<Tensor<B, D>>,
key: &str,
tensors: &HashMap<String, TensorData>,
) -> Result<(), VitLoadError> {
let tensor = tensors
.get(key)
.ok_or_else(|| VitLoadError::MissingKey(key.to_string()))?;
let expected_shape = param.lazy_shape().dims;
if tensor.shape != expected_shape {
return Err(VitLoadError::ShapeMismatch {
key: key.to_string(),
checkpoint_shape: tensor.shape.clone(),
model_shape: expected_shape,
});
}
*param = param
.clone()
.load_record(Param::from_data(tensor.clone(), ¶m.lazy_device()));
Ok(())
}
/// Loads `key` into an optional parameter.
///
/// If the model has no such parameter (`None`), the checkpoint is not consulted and this
/// succeeds. If the parameter exists, the key is required and is loaded like any other
/// parameter.
fn load_optional_param_from_tensors<B: Backend, const D: usize>(
    param: &mut Option<Param<Tensor<B, D>>>,
    key: &str,
    tensors: &HashMap<String, TensorData>,
) -> Result<(), VitLoadError> {
    match param.as_mut() {
        Some(inner) => load_param_from_tensors(inner, key, tensors),
        None => Ok(()),
    }
}
fn ema_update_linear_record<B: Backend>(
target: &mut LinearRecord<B>,
online: &LinearRecord<B>,
ema: &Ema,
step: usize,
) {
ema_update_param(&mut target.weight, &online.weight, ema, step);
ema_update_optional_param(&mut target.bias, &online.bias, ema, step);
}
fn ema_update_layer_norm_record<B: Backend>(
target: &mut LayerNormRecord<B>,
online: &LayerNormRecord<B>,
ema: &Ema,
step: usize,
) {
ema_update_param(&mut target.gamma, &online.gamma, ema, step);
ema_update_optional_param(&mut target.beta, &online.beta, ema, step);
}
/// Replaces `target` with `ema.update_tensor(target, online, step)`.
///
/// Both inputs are detached, so no autodiff graph links the target to the online parameter.
fn ema_update_param<B: Backend, const D: usize>(
    target: &mut Param<Tensor<B, D>>,
    online: &Param<Tensor<B, D>>,
    ema: &Ema,
    step: usize,
) {
    // Reuse the target's ParamId for the record that replaces it.
    let (id, current) = target.clone().consume();
    let source = online.val().detach();
    let blended = ema.update_tensor(current.detach(), &source, step).detach();
    *target = target.clone().load_record(Param::initialized(id, blended));
}
/// EMA-updates an optional parameter.
///
/// This is a no-op unless both `target` and `online` are `Some`. In particular, a parameter
/// present on only one side is left as-is.
fn ema_update_optional_param<B: Backend, const D: usize>(
    target: &mut Option<Param<Tensor<B, D>>>,
    online: &Option<Param<Tensor<B, D>>>,
    ema: &Ema,
    step: usize,
) {
    if let (Some(dst), Some(src)) = (target.as_mut(), online.as_ref()) {
        ema_update_param(dst, src, ema, step);
    }
}
/// Dimensions of a single transformer block.
#[derive(Debug, Clone)]
struct TransformerBlockConfig {
    /// Token embedding width (input and output width of the block).
    embed_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Hidden width of the block's MLP.
    mlp_dim: usize,
}
impl TransformerBlockConfig {
    /// Creates a block on `device`.
    ///
    /// Sub-modules are initialized in the order norm1, attention, norm2, MLP.
    fn init<B: Backend>(&self, device: &B::Device) -> TransformerBlock<B> {
        let Self {
            embed_dim,
            num_heads,
            mlp_dim,
        } = *self;

        let norm1 = LayerNormConfig::new(embed_dim).init(device);
        let attn = MultiHeadSelfAttentionConfig {
            embed_dim,
            num_heads,
        }
        .init(device);
        let norm2 = LayerNormConfig::new(embed_dim).init(device);
        let mlp = MlpConfig {
            in_dim: embed_dim,
            hidden_dim: mlp_dim,
        }
        .init(device);

        TransformerBlock {
            norm1,
            attn,
            norm2,
            mlp,
        }
    }
}
/// Pre-norm transformer block: a residual self-attention sub-layer, then a residual MLP.
#[derive(Module, Debug)]
struct TransformerBlock<B: Backend> {
    /// Norm applied before attention.
    norm1: LayerNorm<B>,
    /// Multi-head self-attention.
    attn: MultiHeadSelfAttention<B>,
    /// Norm applied before the MLP.
    norm2: LayerNorm<B>,
    /// Position-wise feed-forward network.
    mlp: Mlp<B>,
}
impl<B: Backend> TransformerBlock<B> {
    /// Computes `h = x + attn(norm1(x))`, then returns `h + mlp(norm2(h))`.
    ///
    /// The output has the same shape as `x`.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let h = x.clone() + self.attn.forward(self.norm1.forward(x));
        h.clone() + self.mlp.forward(self.norm2.forward(h))
    }
}
/// Dimensions of a multi-head self-attention layer.
#[derive(Debug, Clone)]
struct MultiHeadSelfAttentionConfig {
    /// Model width; split evenly across heads.
    embed_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
}
impl MultiHeadSelfAttentionConfig {
    /// Builds a [`MultiHeadSelfAttention`] layer.
    ///
    /// The layer has a fused `embed_dim → 3·embed_dim` Q/K/V projection and an
    /// `embed_dim → embed_dim` output projection.
    ///
    /// # Panics
    ///
    /// Panics if `num_heads` is zero or does not evenly divide `embed_dim`. The head width is
    /// `embed_dim / num_heads`. With a remainder, the per-head split in `forward` would use
    /// `num_heads * head_dim < embed_dim` columns. That would slice Q/K/V at the wrong offsets
    /// and only fail later inside `out_proj`.
    fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadSelfAttention<B> {
        assert!(
            self.num_heads > 0 && self.embed_dim % self.num_heads == 0,
            "num_heads ({}) must be non-zero and evenly divide embed_dim ({})",
            self.num_heads,
            self.embed_dim
        );
        let head_dim = self.embed_dim / self.num_heads;
        MultiHeadSelfAttention {
            qkv: LinearConfig::new(self.embed_dim, 3 * self.embed_dim).init(device),
            out_proj: LinearConfig::new(self.embed_dim, self.embed_dim).init(device),
            num_heads: self.num_heads,
            head_dim,
        }
    }
}
/// Multi-head scaled dot-product self-attention with a fused Q/K/V projection.
#[derive(Module, Debug)]
struct MultiHeadSelfAttention<B: Backend> {
    /// Fused projection whose output is laid out as `[q | k | v]` along the last axis.
    qkv: Linear<B>,
    /// Projection applied after the heads are merged back together.
    out_proj: Linear<B>,
    /// Number of attention heads.
    num_heads: usize,
    /// Width of each head.
    head_dim: usize,
}
impl<B: Backend> MultiHeadSelfAttention<B> {
    /// Applies multi-head self-attention to `x` of shape `[batch, seq_len, embed_dim]`.
    ///
    /// Returns a tensor of the same shape. Scores are `q·kᵀ / sqrt(head_dim)`, normalized with
    /// a softmax over the key axis.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch, seq_len, _embed_dim] = x.dims();
        let embed_dim = self.num_heads * self.head_dim;
        let qkv = self.qkv.forward(x);

        // The fused projection is laid out as [q | k | v] along the last axis.
        let q = self.split_heads(qkv.clone().slice([0..batch, 0..seq_len, 0..embed_dim]));
        let k = self.split_heads(
            qkv.clone()
                .slice([0..batch, 0..seq_len, embed_dim..2 * embed_dim]),
        );
        let v = self.split_heads(qkv.slice([0..batch, 0..seq_len, 2 * embed_dim..3 * embed_dim]));

        // Scores are [batch, heads, seq_q, seq_k]; softmax runs over keys (dim 3).
        let scale = (self.head_dim as f64).sqrt();
        let attn_scores = q.matmul(k.transpose()) / scale;
        let attn_weights = burn::tensor::activation::softmax(attn_scores, 3);

        // Merge heads back: [batch, heads, seq, head_dim] -> [batch, seq, embed_dim].
        let out = attn_weights.matmul(v);
        let out = out.swap_dims(1, 2).reshape([batch, seq_len, embed_dim]);
        self.out_proj.forward(out)
    }

    /// Reshapes `[batch, seq_len, num_heads * head_dim]` into
    /// `[batch, num_heads, seq_len, head_dim]`.
    fn split_heads(&self, t: Tensor<B, 3>) -> Tensor<B, 4> {
        let [batch, seq_len, _] = t.dims();
        t.reshape([batch, seq_len, self.num_heads, self.head_dim])
            .swap_dims(1, 2)
    }
}
/// Dimensions of the two-layer feed-forward network used inside each transformer block.
#[derive(Debug, Clone)]
struct MlpConfig {
    /// Input and output width.
    in_dim: usize,
    /// Hidden (expanded) width.
    hidden_dim: usize,
}
impl MlpConfig {
    /// Creates the MLP on `device`.
    ///
    /// `fc1` maps `in_dim → hidden_dim`, and `fc2` maps it back to `in_dim`.
    fn init<B: Backend>(&self, device: &B::Device) -> Mlp<B> {
        let expand = LinearConfig::new(self.in_dim, self.hidden_dim);
        let contract = LinearConfig::new(self.hidden_dim, self.in_dim);
        Mlp {
            fc1: expand.init(device),
            fc2: contract.init(device),
        }
    }
}
/// Two-layer feed-forward network with a GELU between the layers.
#[derive(Module, Debug)]
struct Mlp<B: Backend> {
    /// Expansion layer.
    fc1: Linear<B>,
    /// Projection back to the input width.
    fc2: Linear<B>,
}
impl<B: Backend> Mlp<B> {
    /// Computes `fc2(gelu(fc1(x)))`; the output has the same shape as `x`.
    fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let hidden = burn::tensor::activation::gelu(self.fc1.forward(x));
        self.fc2.forward(hidden)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::ElementConversion;
    use burn_ndarray::NdArray;
    use proptest::prelude::*;
    use std::collections::HashMap;

    type TestBackend = NdArray<f32>;

    fn device() -> burn_ndarray::NdArrayDevice {
        burn_ndarray::NdArrayDevice::Cpu
    }

    #[test]
    fn test_vit_encoder_output_shape() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());
        let images: Tensor<TestBackend, 4> = Tensor::zeros([2, 1, 8, 8], &device());
        let repr = encoder.forward(&images);
        assert_eq!(repr.batch_size(), 2);
        assert_eq!(repr.seq_len(), 16);
        assert_eq!(repr.embed_dim(), 32);
    }

    #[test]
    fn test_vit_encoder_trait_impl() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());
        let images: Tensor<TestBackend, 4> = Tensor::zeros([1, 1, 8, 8], &device());
        let repr = Encoder::encode(&encoder, &images);
        assert_eq!(repr.batch_size(), 1);
        assert_eq!(repr.seq_len(), 16);
        assert_eq!(encoder.embed_dim(), 32);
    }

    #[test]
    fn test_vit_encoder_different_inputs_different_outputs() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());
        let a: Tensor<TestBackend, 4> = Tensor::zeros([1, 1, 8, 8], &device());
        let b: Tensor<TestBackend, 4> = Tensor::ones([1, 1, 8, 8], &device());
        let repr_a = encoder.forward(&a);
        let repr_b = encoder.forward(&b);
        let diff: f32 = (repr_a.embeddings - repr_b.embeddings)
            .abs()
            .sum()
            .into_scalar()
            .elem();
        assert!(
            diff > 1e-6,
            "different inputs should produce different representations"
        );
    }

    #[test]
    fn test_transformer_block_residual() {
        let block = TransformerBlockConfig {
            embed_dim: 16,
            num_heads: 2,
            mlp_dim: 32,
        }
        .init::<TestBackend>(&device());
        let x: Tensor<TestBackend, 3> = Tensor::zeros([1, 4, 16], &device());
        let out = block.forward(x);
        assert_eq!(out.dims(), [1, 4, 16]);
    }

    #[test]
    fn test_mhsa_output_shape() {
        let attn = MultiHeadSelfAttentionConfig {
            embed_dim: 16,
            num_heads: 4,
        }
        .init::<TestBackend>(&device());
        let x: Tensor<TestBackend, 3> = Tensor::zeros([2, 8, 16], &device());
        let out = attn.forward(x);
        assert_eq!(out.dims(), [2, 8, 16]);
    }

    #[test]
    fn test_mlp_output_shape() {
        let mlp = MlpConfig {
            in_dim: 16,
            hidden_dim: 64,
        }
        .init::<TestBackend>(&device());
        let x: Tensor<TestBackend, 3> = Tensor::zeros([2, 8, 16], &device());
        let out = mlp.forward(x);
        assert_eq!(out.dims(), [2, 8, 16]);
    }

    /// Exports `encoder`'s learnable weights using the names `load_named_tensors` expects.
    fn checkpoint_tensors_from_encoder(
        encoder: &VitEncoder<TestBackend>,
    ) -> HashMap<String, TensorData> {
        let record = encoder.clone().into_record();
        let mut tensors = HashMap::new();
        insert_linear_tensors(
            &mut tensors,
            "patch_embed.projection",
            &record.patch_embed.projection,
        );
        for (index, block) in record.blocks.iter().enumerate() {
            insert_layer_norm_tensors(&mut tensors, &format!("blocks.{index}.norm1"), &block.norm1);
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.attn.qkv"),
                &block.attn.qkv,
            );
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.attn.out_proj"),
                &block.attn.out_proj,
            );
            insert_layer_norm_tensors(&mut tensors, &format!("blocks.{index}.norm2"), &block.norm2);
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.mlp.fc1"),
                &block.mlp.fc1,
            );
            insert_linear_tensors(
                &mut tensors,
                &format!("blocks.{index}.mlp.fc2"),
                &block.mlp.fc2,
            );
        }
        insert_layer_norm_tensors(&mut tensors, "norm", &record.norm);
        tensors
    }

    fn insert_linear_tensors(
        tensors: &mut HashMap<String, TensorData>,
        prefix: &str,
        record: &LinearRecord<TestBackend>,
    ) {
        tensors.insert(format!("{prefix}.weight"), record.weight.val().to_data());
        if let Some(bias) = &record.bias {
            tensors.insert(format!("{prefix}.bias"), bias.val().to_data());
        }
    }

    fn insert_layer_norm_tensors(
        tensors: &mut HashMap<String, TensorData>,
        prefix: &str,
        record: &LayerNormRecord<TestBackend>,
    ) {
        tensors.insert(format!("{prefix}.weight"), record.gamma.val().to_data());
        if let Some(beta) = &record.beta {
            tensors.insert(format!("{prefix}.bias"), beta.val().to_data());
        }
    }

    #[test]
    fn test_vit_encoder_load_named_tensors_restores_encoder_state() {
        let config = VitConfig::tiny_test();
        let source = config.init::<TestBackend>(&device());
        let target = config.init::<TestBackend>(&device());
        let tensors = checkpoint_tensors_from_encoder(&source);
        let loaded = target
            .load_named_tensors(&tensors)
            .expect("loading tensors exported from a matching encoder should succeed");
        let images: Tensor<TestBackend, 4> = Tensor::random(
            [2, 1, 8, 8],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        );
        let source_repr = source.forward(&images);
        let loaded_repr = loaded.forward(&images);
        let diff: f32 = (source_repr.embeddings - loaded_repr.embeddings)
            .abs()
            .sum()
            .into_scalar()
            .elem();
        assert!(
            diff < 1e-6,
            "loading the exported tensors should restore the encoder exactly, diff={diff}"
        );
    }

    #[test]
    fn test_vit_encoder_load_named_tensors_rejects_shape_mismatch() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());
        let mut tensors = checkpoint_tensors_from_encoder(&encoder);
        tensors.insert(
            "norm.weight".to_string(),
            TensorData::new(vec![1.0f32; 31], [31]),
        );
        let err = config
            .init::<TestBackend>(&device())
            .load_named_tensors(&tensors)
            .expect_err("shape mismatch should be reported");
        assert!(matches!(
            err,
            VitLoadError::ShapeMismatch { key, .. } if key == "norm.weight"
        ));
    }

    #[test]
    fn test_vit_encoder_load_named_tensors_reports_missing_key() {
        let config = VitConfig::tiny_test();
        let encoder = config.init::<TestBackend>(&device());
        let mut tensors = checkpoint_tensors_from_encoder(&encoder);
        let removed_key = "blocks.1.mlp.fc2.weight";
        assert!(
            tensors.remove(removed_key).is_some(),
            "the exported checkpoint should contain `{removed_key}`"
        );
        let err = config
            .init::<TestBackend>(&device())
            .load_named_tensors(&tensors)
            .expect_err("a missing tensor should be reported");
        assert_eq!(err, VitLoadError::MissingKey(removed_key.to_string()));
    }

    #[test]
    fn test_vit_encoder_ema_update_moves_target_toward_online() {
        let config = VitConfig::tiny_test();
        let target = config.init::<TestBackend>(&device());
        let online = config.init::<TestBackend>(&device());
        let ema = Ema::new(0.5);
        let images: Tensor<TestBackend, 4> = Tensor::random(
            [1, 1, 8, 8],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device(),
        );
        let target_before = target.forward(&images);
        let online_before = online.forward(&images);
        let updated = target.clone().ema_update_from(&online, &ema, 0);
        let updated_repr = updated.forward(&images);
        let before_distance: f32 = (target_before.embeddings.clone()
            - online_before.embeddings.clone())
        .abs()
        .sum()
        .into_scalar()
        .elem();
        let after_distance: f32 = (updated_repr.embeddings - online_before.embeddings)
            .abs()
            .sum()
            .into_scalar()
            .elem();
        assert!(
            after_distance < before_distance,
            "EMA update should move target toward online encoder"
        );
    }

    proptest! {
        #[test]
        fn prop_vit_output_is_finite(batch in 1usize..3) {
            let config = VitConfig::tiny_test();
            let encoder = config.init::<TestBackend>(&device());
            let images: Tensor<TestBackend, 4> = Tensor::random(
                [batch, 1, 8, 8],
                burn::tensor::Distribution::Normal(0.0, 1.0),
                &device(),
            );
            let repr = encoder.forward(&images);
            prop_assert_eq!(repr.batch_size(), batch);
            prop_assert_eq!(repr.seq_len(), 16);
            prop_assert_eq!(repr.embed_dim(), 32);
            let total: f32 = repr.embeddings.abs().sum().into_scalar().elem();
            prop_assert!(total.is_finite(), "ViT output should be finite, got {}", total);
        }

        #[test]
        fn prop_vit_is_deterministic(batch in 1usize..3) {
            let config = VitConfig::tiny_test();
            let encoder = config.init::<TestBackend>(&device());
            let images: Tensor<TestBackend, 4> = Tensor::ones([batch, 1, 8, 8], &device());
            let repr1 = encoder.forward(&images);
            let repr2 = encoder.forward(&images);
            let diff: f32 = (repr1.embeddings - repr2.embeddings)
                .abs()
                .sum()
                .into_scalar()
                .elem();
            prop_assert!(diff < 1e-6, "ViT should be deterministic, diff={}", diff);
        }

        #[test]
        fn prop_transformer_block_preserves_shape(
            seq_len in 2usize..8,
            num_heads in proptest::sample::select(vec![2usize, 4]),
        ) {
            let embed_dim = 16;
            let block = TransformerBlockConfig {
                embed_dim,
                num_heads,
                mlp_dim: embed_dim * 4,
            }
            .init::<TestBackend>(&device());
            let x: Tensor<TestBackend, 3> = Tensor::random(
                [1, seq_len, embed_dim],
                burn::tensor::Distribution::Normal(0.0, 1.0),
                &device(),
            );
            let out = block.forward(x);
            prop_assert_eq!(out.dims(), [1, seq_len, embed_dim]);
            let total: f32 = out.abs().sum().into_scalar().elem();
            prop_assert!(total.is_finite(), "block output should be finite");
        }
    }
}