use crate::train::prelude::*;
use burn::tensor::Distribution;
use burn::tensor::TensorData;
use burn::tensor::backend::Backend as BackendTrait;
use burn_dragon_core::{
FusedKernelConfig, ManifoldHyperConnectionsConfig, SpatialPositionalEncodingKind,
VisionAttentionMode, VisionLatentActivation, VisionPatchEmbedMode,
};
use burn_dragon_train::VisionLocationEmbeddingMode;
use burn_ndarray::NdArray;
fn make_saccade_model<B: BackendTrait>(
device: &B::Device,
mode: VisionLocationEmbeddingMode,
) -> VisionSaccadeModel<B> {
let vision_config = VisionDragonHatchlingConfig {
image_size: 32,
patch_size: 16,
patch_embed_mode: VisionPatchEmbedMode::default(),
in_channels: 3,
embed_dim: 16,
steps: 1,
n_head: 2,
mlp_internal_dim_multiplier: 2,
dropout: 0.0,
projection_dim: 16,
projection_hidden_dim: 16,
use_cls_token: true,
cls_sync_alpha: 0.0,
num_eyes: 1,
cross_eye_steps: 0,
token_state_norm: true,
latent_activation: VisionLatentActivation::default(),
pos_encoding: SpatialPositionalEncodingKind::Learned2d,
pos_max_height: 2,
pos_max_width: 2,
attention_mode: VisionAttentionMode::RowL1,
use_alibi: true,
fused_kernels: FusedKernelConfig::default(),
mhc: ManifoldHyperConnectionsConfig::default(),
};
let model = VisionDragonHatchling::<B>::new(vision_config.clone(), device);
let mut saccade_config = VisionSaccadeConfig {
num_eyes: vision_config.num_eyes,
mip_levels: 2,
..Default::default()
};
saccade_config.loss.recon.weight = 0.0;
saccade_config.policy.location_embedding.mode = mode;
saccade_config.policy.location_embedding.embed_dim = 8;
let rollout = VisionRollout {
min_steps: 1,
max_steps: 1,
backprop_steps: 1,
};
let recon_patch_dim =
vision_config.patch_size * vision_config.patch_size * vision_config.in_channels;
VisionSaccadeModel::new(
model,
saccade_config,
vision_config.embed_dim,
vision_config.patch_size,
rollout,
recon_patch_dim,
1,
0,
device,
)
}
/// Extracts the first element of a rank-1 tensor as an `f32`.
///
/// Panics (via `expect`/indexing) if the tensor data cannot be converted or
/// is empty — acceptable in test-helper code where that would be a bug.
fn tensor_scalar<B: BackendTrait>(tensor: Tensor<B, 1>) -> f32 {
    let values = tensor
        .to_data()
        .convert::<f32>()
        .into_vec::<f32>()
        .expect("tensor vec");
    values[0]
}
#[test]
fn location_embedding_none_is_invariant() {
type Backend = NdArray<f32>;
let device = <Backend as BackendTrait>::Device::default();
let saccade = make_saccade_model::<Backend>(&device, VisionLocationEmbeddingMode::None);
let input_context = Tensor::<Backend, 3>::random([2, 1, 16], Distribution::Default, &device);
let state_context = Tensor::<Backend, 3>::random([2, 1, 16], Distribution::Default, &device);
let mean_a = Tensor::<Backend, 3>::from_data(
TensorData::new(vec![0.2, 0.3, 0.7, 0.8], [2, 1, 2]),
&device,
);
let sigma_a =
Tensor::<Backend, 3>::from_data(TensorData::new(vec![0.15, 0.35], [2, 1, 1]), &device);
let mean_b = Tensor::<Backend, 3>::from_data(
TensorData::new(vec![0.6, 0.4, 0.1, 0.9], [2, 1, 2]),
&device,
);
let sigma_b =
Tensor::<Backend, 3>::from_data(TensorData::new(vec![0.25, 0.45], [2, 1, 1]), &device);
let tokens_a = saccade.build_input_tokens(
input_context.clone(),
state_context.clone(),
mean_a,
sigma_a,
);
let tokens_b = saccade.build_input_tokens(input_context, state_context, mean_b, sigma_b);
let mse = (tokens_a - tokens_b).powf_scalar(2.0).mean();
assert!(tensor_scalar(mse) < 1e-6);
}
#[test]
fn location_embedding_learned_changes_tokens() {
type Backend = NdArray<f32>;
let device = <Backend as BackendTrait>::Device::default();
let saccade = make_saccade_model::<Backend>(&device, VisionLocationEmbeddingMode::Learned);
let input_context = Tensor::<Backend, 3>::random([2, 1, 16], Distribution::Default, &device);
let state_context = Tensor::<Backend, 3>::random([2, 1, 16], Distribution::Default, &device);
let mean_a = Tensor::<Backend, 3>::from_data(
TensorData::new(vec![0.2, 0.3, 0.7, 0.8], [2, 1, 2]),
&device,
);
let sigma_a =
Tensor::<Backend, 3>::from_data(TensorData::new(vec![0.15, 0.35], [2, 1, 1]), &device);
let mean_b = Tensor::<Backend, 3>::from_data(
TensorData::new(vec![0.6, 0.4, 0.1, 0.9], [2, 1, 2]),
&device,
);
let sigma_b =
Tensor::<Backend, 3>::from_data(TensorData::new(vec![0.25, 0.45], [2, 1, 1]), &device);
let tokens_a = saccade.build_input_tokens(
input_context.clone(),
state_context.clone(),
mean_a,
sigma_a,
);
let tokens_b = saccade.build_input_tokens(input_context, state_context, mean_b, sigma_b);
let mse = (tokens_a - tokens_b).powf_scalar(2.0).mean();
assert!(tensor_scalar(mse) > 1e-6);
}
#[test]
fn location_embedding_rope_changes_tokens() {
type Backend = NdArray<f32>;
let device = <Backend as BackendTrait>::Device::default();
let saccade = make_saccade_model::<Backend>(&device, VisionLocationEmbeddingMode::Rope);
let input_context = Tensor::<Backend, 3>::random([2, 1, 16], Distribution::Default, &device);
let state_context = Tensor::<Backend, 3>::random([2, 1, 16], Distribution::Default, &device);
let mean_a = Tensor::<Backend, 3>::from_data(
TensorData::new(vec![0.2, 0.3, 0.7, 0.8], [2, 1, 2]),
&device,
);
let sigma_a =
Tensor::<Backend, 3>::from_data(TensorData::new(vec![0.15, 0.35], [2, 1, 1]), &device);
let mean_b = Tensor::<Backend, 3>::from_data(
TensorData::new(vec![0.6, 0.4, 0.1, 0.9], [2, 1, 2]),
&device,
);
let sigma_b =
Tensor::<Backend, 3>::from_data(TensorData::new(vec![0.25, 0.45], [2, 1, 1]), &device);
let tokens_a = saccade.build_input_tokens(
input_context.clone(),
state_context.clone(),
mean_a,
sigma_a,
);
let tokens_b = saccade.build_input_tokens(input_context, state_context, mean_b, sigma_b);
let mse = (tokens_a - tokens_b).powf_scalar(2.0).mean();
assert!(tensor_scalar(mse) > 1e-6);
}
#[test]
fn location_embedding_pope_changes_tokens() {
type Backend = NdArray<f32>;
let device = <Backend as BackendTrait>::Device::default();
let saccade = make_saccade_model::<Backend>(&device, VisionLocationEmbeddingMode::Pope);
let input_context = Tensor::<Backend, 3>::random([2, 1, 16], Distribution::Default, &device);
let state_context = Tensor::<Backend, 3>::random([2, 1, 16], Distribution::Default, &device);
let mean_a = Tensor::<Backend, 3>::from_data(
TensorData::new(vec![0.2, 0.3, 0.7, 0.8], [2, 1, 2]),
&device,
);
let sigma_a =
Tensor::<Backend, 3>::from_data(TensorData::new(vec![0.15, 0.35], [2, 1, 1]), &device);
let mean_b = Tensor::<Backend, 3>::from_data(
TensorData::new(vec![0.6, 0.4, 0.1, 0.9], [2, 1, 2]),
&device,
);
let sigma_b =
Tensor::<Backend, 3>::from_data(TensorData::new(vec![0.25, 0.45], [2, 1, 1]), &device);
let tokens_a = saccade.build_input_tokens(
input_context.clone(),
state_context.clone(),
mean_a,
sigma_a,
);
let tokens_b = saccade.build_input_tokens(input_context, state_context, mean_b, sigma_b);
let mse = (tokens_a - tokens_b).powf_scalar(2.0).mean();
assert!(tensor_scalar(mse) > 1e-6);
}
#[test]
fn location_embedding_pope_fills_double_budget() {
    type Backend = NdArray<f32>;
    let device = <Backend as BackendTrait>::Device::default();
    let mut saccade = make_saccade_model::<Backend>(&device, VisionLocationEmbeddingMode::Pope);
    // Shrink the location-embedding budget so PoPE's doubled footprint
    // (2 * embed_dim channels) fits well inside the 16-wide token dimension,
    // leaving a tail of channels that should stay untouched.
    saccade.config.policy.location_embedding.embed_dim = 4;
    // Zero contexts: with both context inputs zeroed, any per-channel
    // difference between the two token sets can only come from the
    // location embedding itself.
    let input_context = Tensor::<Backend, 3>::zeros([2, 1, 16], &device);
    let state_context = Tensor::<Backend, 3>::zeros([2, 1, 16], &device);
    let mean_a = Tensor::<Backend, 3>::from_data(
        TensorData::new(vec![0.2, 0.3, 0.7, 0.8], [2, 1, 2]),
        &device,
    );
    let sigma_a =
        Tensor::<Backend, 3>::from_data(TensorData::new(vec![0.15, 0.35], [2, 1, 1]), &device);
    let mean_b = Tensor::<Backend, 3>::from_data(
        TensorData::new(vec![0.6, 0.4, 0.1, 0.9], [2, 1, 2]),
        &device,
    );
    let sigma_b =
        Tensor::<Backend, 3>::from_data(TensorData::new(vec![0.25, 0.45], [2, 1, 1]), &device);
    let tokens_a = saccade.build_input_tokens(
        input_context.clone(),
        state_context.clone(),
        mean_a,
        sigma_a,
    );
    let tokens_b = saccade.build_input_tokens(input_context, state_context, mean_b, sigma_b);
    // Per-channel difference between the two locations' token sets.
    let delta = tokens_a - tokens_b;
    let embed_dim = delta.shape().dims::<3>()[2];
    // Clamp base_dim so the doubled PoPE span (2 * base_dim) cannot exceed
    // the token width even if the configured budget were too large.
    let base_dim = saccade
        .config
        .policy
        .location_embedding
        .embed_dim
        .min(embed_dim / 2);
    let pope_dim = base_dim * 2;
    // The test's premise (to confirm against build_input_tokens): PoPE
    // occupies channels [0, 2*base_dim). The second half of that span,
    // [base_dim, 2*base_dim), must carry location signal — i.e. PoPE fills
    // its *doubled* budget, not just the first base_dim channels.
    let mid = delta.clone().slice_dim(2, base_dim..pope_dim);
    let mid_norm = mid.abs().sum();
    assert!(tensor_scalar(mid_norm) > 1e-6);
    // Channels past the PoPE span must be unaffected by the location change
    // (zero contexts ⇒ zero delta there).
    if pope_dim < embed_dim {
        let tail = delta.slice_dim(2, pope_dim..embed_dim);
        let tail_norm = tail.abs().sum();
        assert!(tensor_scalar(tail_norm) < 1e-6);
    }
}