use anyhow::{Result, ensure};
use burn::module::{Module, Param};
use burn::nn::{Linear, LinearConfig};
use burn::tensor::activation;
use burn::tensor::backend::Backend;
use burn::tensor::{Distribution as TensorDistribution, Tensor};
use serde::{Deserialize, Serialize};
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
use std::any::TypeId;
#[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
mod chunk_wy;
/// CubeCL backend on the CUDA runtime, parameterised with `f32`, `i32` and `u8`
/// (presumably the float / int / bool element types — confirm against `burn_cubecl`).
/// Only compiled with the `cuda` feature.
#[cfg(feature = "cuda")]
pub type Gdn2CudaKernelBackend =
burn_cubecl::CubeBackend<burn_cubecl::cubecl::cuda::CudaRuntime, f32, i32, u8>;
/// Same CUDA backend as [`Gdn2CudaKernelBackend`] but with `bf16` in place of `f32`.
/// `supports_gdn2_chunk_wy_backend` rejects any backend whose float element is not `f32`.
#[cfg(feature = "cuda")]
pub type Gdn2CudaBf16KernelBackend =
burn_cubecl::CubeBackend<burn_cubecl::cubecl::cuda::CudaRuntime, burn::tensor::bf16, i32, u8>;
// Float tensor primitive of the plain CUDA kernel backend; used for `TypeId` comparisons.
#[cfg(feature = "cuda")]
type Gdn2CudaKernelTensor = burn::tensor::ops::FloatTensor<Gdn2CudaKernelBackend>;
// Autodiff wrapper around the CUDA kernel backend.
#[cfg(feature = "cuda")]
type Gdn2CudaAutodiffBackend = burn_autodiff::Autodiff<Gdn2CudaKernelBackend>;
// Float tensor primitive of the autodiff CUDA backend; used for `TypeId` comparisons.
#[cfg(feature = "cuda")]
type Gdn2CudaAutodiffTensor = burn::tensor::ops::FloatTensor<Gdn2CudaAutodiffBackend>;
/// CubeCL backend on the wgpu runtime, parameterised with `f32`, `i32` and `u32`
/// (presumably the float / int / bool element types — confirm against `burn_cubecl`).
/// Only compiled with the `wgpu-kernel` feature.
#[cfg(feature = "wgpu-kernel")]
pub type Gdn2WgpuKernelBackend =
burn_cubecl::CubeBackend<burn_cubecl::cubecl::wgpu::WgpuRuntime, f32, i32, u32>;
// Float tensor primitive of the plain wgpu kernel backend; used for `TypeId` comparisons.
#[cfg(feature = "wgpu-kernel")]
type Gdn2WgpuKernelTensor = burn::tensor::ops::FloatTensor<Gdn2WgpuKernelBackend>;
// Autodiff wrapper around the wgpu kernel backend.
#[cfg(feature = "wgpu-kernel")]
type Gdn2WgpuAutodiffBackend = burn_autodiff::Autodiff<Gdn2WgpuKernelBackend>;
// Float tensor primitive of the autodiff wgpu backend; used for `TypeId` comparisons.
#[cfg(feature = "wgpu-kernel")]
type Gdn2WgpuAutodiffTensor = burn::tensor::ops::FloatTensor<Gdn2WgpuAutodiffBackend>;
/// Selects which implementation runs the GDN2 recurrence.
///
/// Serialized in `snake_case` (`reference`, `chunk_wy`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GatedDeltaNet2Executor {
/// Token-by-token recurrence implemented with ordinary tensor ops
/// (`gated_deltanet2_reference`). This is the default.
#[default]
Reference,
/// Try the chunked WY kernel first and fall back to the reference
/// recurrence when the kernel declines to run.
ChunkWy,
}
/// Granularity of a gate (erase, write or decay).
///
/// Serialized in `snake_case` (`channel`, `scalar`, `disabled`).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GatedDeltaNet2GateMode {
/// One independent gate value per channel of the last dimension. This is the default.
#[default]
Channel,
/// Logits are averaged over the last dimension and the single resulting
/// gate value is repeated across every channel.
Scalar,
/// Gate is turned off: erase/write gates become all ones and the decay
/// gate yields a zero log-decay (i.e. no decay).
Disabled,
}
/// Execution path that `gdn2_kernel_path` reports for a backend/executor pair.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum GatedDeltaNet2KernelPath {
/// The reference executor was requested explicitly.
Reference,
/// The ChunkWy executor was requested but the backend is not supported.
ReferenceFallback,
/// ChunkWy is supported but the custom backward is not enabled.
ForwardKernel,
/// ChunkWy is supported and the custom backward is enabled.
CustomBackward,
/// As `CustomBackward`, on a CUDA tensor type with the tensor-core backward enabled.
CudaTensorCoreBackward,
}
/// User-facing configuration of a GDN2 memory block.
///
/// `#[serde(default)]` means fields missing from the serialized form are
/// filled in from the `Default` impl of this struct.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(default)]
pub struct GatedDeltaNet2Config {
/// Number of heads; must be nonzero and divide the hidden dimension.
pub heads: usize,
/// Width of the query/key (latent) space of each head; must be nonzero.
pub latent_per_head: usize,
/// Chunk length handed to the executors; must be nonzero.
pub chunk_size: usize,
/// When true, queries and keys are L2-normalized along their last dimension.
pub qk_l2_norm: bool,
/// When true, the (non-disabled) erase gate is multiplied by 2.
pub allow_neg_eigval: bool,
/// Granularity of the erase gate applied to the key.
pub erase_gate: GatedDeltaNet2GateMode,
/// Granularity of the write gate applied to the value.
pub write_gate: GatedDeltaNet2GateMode,
/// Granularity of the decay gate applied to the state.
pub decay_gate: GatedDeltaNet2GateMode,
/// Epsilon added under the square root during L2 normalization; must be finite and positive.
pub state_epsilon: f32,
/// Multiplier applied to the output projection; must be finite.
pub output_scale: f32,
/// Which implementation runs the recurrence.
pub executor: GatedDeltaNet2Executor,
}
impl Default for GatedDeltaNet2Config {
/// Default configuration: 12 heads, 64 latent channels per head, chunks of 64,
/// L2-normalized queries/keys, per-channel gates, unit output scale and the
/// reference executor.
fn default() -> Self {
Self {
heads: 12,
latent_per_head: 64,
chunk_size: 64,
qk_l2_norm: true,
allow_neg_eigval: false,
erase_gate: GatedDeltaNet2GateMode::Channel,
write_gate: GatedDeltaNet2GateMode::Channel,
decay_gate: GatedDeltaNet2GateMode::Channel,
state_epsilon: 1.0e-6,
output_scale: 1.0,
executor: GatedDeltaNet2Executor::Reference,
}
}
}
/// Validated configuration bound to a concrete hidden dimension.
///
/// Produced by `GatedDeltaNet2Config::resolve`; the floating-point settings
/// are widened to `f64` and `head_dim` is derived from the hidden dimension.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct ResolvedGatedDeltaNet2Config {
/// Number of heads.
pub heads: usize,
/// Value width of each head (`hidden_dim / heads`).
pub head_dim: usize,
/// Query/key (latent) width of each head.
pub latent_per_head: usize,
/// Chunk length handed to the executors (at least 1).
pub chunk_size: usize,
/// Whether queries and keys are L2-normalized.
pub qk_l2_norm: bool,
/// Whether the erase gate is doubled.
pub allow_neg_eigval: bool,
/// Granularity of the erase gate.
pub erase_gate: GatedDeltaNet2GateMode,
/// Granularity of the write gate.
pub write_gate: GatedDeltaNet2GateMode,
/// Granularity of the decay gate.
pub decay_gate: GatedDeltaNet2GateMode,
/// Normalization epsilon (clamped to at least `1e-12` during resolution).
pub state_epsilon: f64,
/// Multiplier applied to the output projection.
pub output_scale: f64,
/// Which implementation runs the recurrence.
pub executor: GatedDeltaNet2Executor,
}
impl GatedDeltaNet2Config {
    /// Checks that this configuration is usable with the given hidden dimension.
    ///
    /// # Errors
    ///
    /// Fails when `hidden_dim`, `heads`, `latent_per_head` or `chunk_size` is
    /// zero, when `hidden_dim` is not a multiple of `heads`, when
    /// `state_epsilon` is not a finite positive number, or when
    /// `output_scale` is not finite.
    pub fn validate(&self, hidden_dim: usize) -> Result<()> {
        ensure!(hidden_dim != 0, "GDN2 hidden_dim must be nonzero");
        ensure!(self.heads != 0, "gdn2.heads must be nonzero");
        // `heads` is known to be nonzero here, so the divisibility test is well defined.
        let divisible = hidden_dim.is_multiple_of(self.heads);
        ensure!(divisible, "GDN2 hidden_dim must be divisible by gdn2.heads");
        ensure!(
            self.latent_per_head != 0,
            "gdn2.latent_per_head must be nonzero"
        );
        ensure!(self.chunk_size != 0, "gdn2.chunk_size must be nonzero");
        let epsilon = self.state_epsilon;
        ensure!(
            epsilon > 0.0 && epsilon.is_finite(),
            "gdn2.state_epsilon must be finite and positive"
        );
        let scale_is_finite = self.output_scale.is_finite();
        ensure!(scale_is_finite, "gdn2.output_scale must be finite");
        Ok(())
    }

    /// Validates the configuration and binds it to `hidden_dim`.
    ///
    /// The result carries the derived `head_dim`, a chunk size of at least 1,
    /// an epsilon of at least `1e-12`, and `f64` copies of the float settings.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::validate`].
    pub fn resolve(&self, hidden_dim: usize) -> Result<ResolvedGatedDeltaNet2Config> {
        self.validate(hidden_dim)?;
        let head_dim = hidden_dim / self.heads;
        let chunk_size = self.chunk_size.max(1);
        // Clamp in f32 first, then widen, so the floor is applied before conversion.
        let state_epsilon = f64::from(self.state_epsilon.max(1.0e-12));
        let output_scale = f64::from(self.output_scale);
        Ok(ResolvedGatedDeltaNet2Config {
            heads: self.heads,
            head_dim,
            latent_per_head: self.latent_per_head,
            chunk_size,
            qk_l2_norm: self.qk_l2_norm,
            allow_neg_eigval: self.allow_neg_eigval,
            erase_gate: self.erase_gate,
            write_gate: self.write_gate,
            decay_gate: self.decay_gate,
            state_epsilon,
            output_scale,
            executor: self.executor,
        })
    }
}
/// GDN2 memory block: input projections, a learned per-head decay parameter
/// and an output projection around a recurrent matrix-valued state.
#[derive(Module, Debug)]
pub struct GatedDeltaNet2Memory<B: Backend> {
/// Projects the input to per-head queries (`heads * latent_per_head` wide).
pub query: Linear<B>,
/// Projects the input to per-head keys (`heads * latent_per_head` wide).
pub key: Linear<B>,
/// Projects the input to per-head values (`hidden_dim` wide).
pub value: Linear<B>,
/// Produces erase-gate logits (`heads * latent_per_head` wide).
pub erase: Linear<B>,
/// Produces write-gate logits (`hidden_dim` wide).
pub write: Linear<B>,
/// Produces decay-gate logits (`heads * latent_per_head` wide).
pub decay: Linear<B>,
/// Learned log decay rate of shape `[heads, latent_per_head]`; exponentiated before use.
pub decay_log: Param<Tensor<B, 2>>,
/// Bias-free output projection (`hidden_dim` to `hidden_dim`).
pub out: Linear<B>,
/// Resolved configuration; skipped by the `Module` derive.
#[module(skip)]
pub config: ResolvedGatedDeltaNet2Config,
}
impl<B: Backend> GatedDeltaNet2Memory<B> {
pub fn new(
hidden_dim: usize,
config: GatedDeltaNet2Config,
device: &B::Device,
) -> Result<Self> {
let resolved = config.resolve(hidden_dim)?;
let latent_width = resolved.heads * resolved.latent_per_head;
let mut out = LinearConfig::new(hidden_dim, hidden_dim)
.with_bias(false)
.init(device);
out.bias = None;
Ok(Self {
query: LinearConfig::new(hidden_dim, latent_width).init(device),
key: LinearConfig::new(hidden_dim, latent_width).init(device),
value: LinearConfig::new(hidden_dim, hidden_dim).init(device),
erase: LinearConfig::new(hidden_dim, latent_width).init(device),
write: LinearConfig::new(hidden_dim, hidden_dim).init(device),
decay: LinearConfig::new(hidden_dim, latent_width).init(device),
decay_log: Param::from_tensor(Tensor::<B, 2>::random(
[resolved.heads, resolved.latent_per_head],
TensorDistribution::Normal(0.0, 1.0e-5),
device,
)),
out,
config: resolved,
})
}
pub fn forward(
&self,
x: Tensor<B, 3>,
state: &mut Option<Tensor<B, 4>>,
update_state: bool,
) -> Tensor<B, 3> {
let [batch, tokens, hidden_dim] = x.shape().dims::<3>();
debug_assert_eq!(hidden_dim, self.config.heads * self.config.head_dim);
let device = x.device();
if !update_state && state.is_none() {
return Tensor::<B, 3>::zeros([batch, tokens, hidden_dim], &device);
}
let mut query = self.project_latent(self.query.forward(x.clone()));
if self.config.qk_l2_norm {
query = l2_normalize_last(query, self.config.state_epsilon);
}
let memory = state
.take()
.filter(|memory| {
memory.shape().dims::<4>()
== [
batch,
self.config.heads,
self.config.latent_per_head,
self.config.head_dim,
]
})
.unwrap_or_else(|| {
Tensor::<B, 4>::zeros(
[
batch,
self.config.heads,
self.config.latent_per_head,
self.config.head_dim,
],
&device,
)
});
if !update_state {
let output = query
.matmul(memory.clone())
.mul_scalar((self.config.latent_per_head as f64).powf(-0.5))
.permute([0, 2, 1, 3])
.reshape([batch, tokens, hidden_dim]);
*state = Some(memory);
return apply_output_scale(self.out.forward(output), self.config.output_scale);
}
let mut key = self.project_latent(self.key.forward(x.clone()));
if self.config.qk_l2_norm {
key = l2_normalize_last(key, self.config.state_epsilon);
}
let value = self.project_hidden(self.value.forward(x.clone()));
let erase = apply_gate_mode(
self.project_latent(self.erase.forward(x.clone())),
self.config.erase_gate,
self.config.allow_neg_eigval,
);
let write = apply_gate_mode(
self.project_hidden(self.write.forward(x.clone())),
self.config.write_gate,
false,
);
let log_decay = self.log_decay(self.decay.forward(x));
let (output, next_state) = match self.config.executor {
GatedDeltaNet2Executor::Reference => gated_deltanet2_reference(
query,
key,
value,
erase,
write,
log_decay,
memory,
self.config.chunk_size,
),
GatedDeltaNet2Executor::ChunkWy => chunk_wy_or_reference(
query,
key,
value,
erase,
write,
log_decay,
memory,
self.config.chunk_size,
),
};
*state = Some(next_state);
let output = output
.permute([0, 2, 1, 3])
.reshape([batch, tokens, hidden_dim]);
apply_output_scale(self.out.forward(output), self.config.output_scale)
}
fn project_latent(&self, tensor: Tensor<B, 3>) -> Tensor<B, 4> {
let [batch, tokens, _] = tensor.shape().dims::<3>();
tensor
.reshape([
batch,
tokens,
self.config.heads,
self.config.latent_per_head,
])
.permute([0, 2, 1, 3])
}
fn project_hidden(&self, tensor: Tensor<B, 3>) -> Tensor<B, 4> {
let [batch, tokens, _] = tensor.shape().dims::<3>();
tensor
.reshape([batch, tokens, self.config.heads, self.config.head_dim])
.permute([0, 2, 1, 3])
}
fn log_decay(&self, tensor: Tensor<B, 3>) -> Tensor<B, 4> {
let [batch, tokens, _] = tensor.shape().dims::<3>();
if matches!(self.config.decay_gate, GatedDeltaNet2GateMode::Disabled) {
return Tensor::<B, 4>::zeros(
[
batch,
self.config.heads,
tokens,
self.config.latent_per_head,
],
&tensor.device(),
);
}
let logits = match self.config.decay_gate {
GatedDeltaNet2GateMode::Channel => self.project_latent(tensor),
GatedDeltaNet2GateMode::Scalar => {
let latent = self.project_latent(tensor);
latent
.mean_dim(3)
.repeat_dim(3, self.config.latent_per_head)
}
GatedDeltaNet2GateMode::Disabled => unreachable!(),
};
let decay_rate = self.decay_log.val().exp().reshape([
1,
self.config.heads,
1,
self.config.latent_per_head,
]);
activation::softplus(logits, 1.0)
.mul(decay_rate)
.mul_scalar(-1.0)
}
}
/// Multiplies `tensor` by `scale`, skipping the multiplication entirely when
/// `scale` is within `f64::EPSILON` of 1.
fn apply_output_scale<B: Backend>(tensor: Tensor<B, 3>, scale: f64) -> Tensor<B, 3> {
    let is_identity = (scale - 1.0).abs() <= f64::EPSILON;
    if is_identity {
        return tensor;
    }
    tensor.mul_scalar(scale)
}
/// Reference (token-by-token) GDN2 recurrence.
///
/// Shapes, as established by the indexing below: `query`, `key`, `erase` and
/// `log_decay` are `[batch, heads, time, latent]`; `value` and `write` are
/// `[batch, heads, time, dense]`; `state` is `[batch, heads, latent, dense]`.
///
/// For every step the state is decayed by `exp(log_decay)`, the value
/// currently stored under the erase-gated key is subtracted from the
/// write-gated value, the difference is written under the key, and the
/// updated state is read with the query (scaled by `latent^-0.5`).
///
/// Returns the per-step outputs `[batch, heads, time, dense]` and the final
/// state. `chunk_size` only determines how the time axis is walked.
// The signature mirrors the kernel entry points, hence the many arguments.
#[allow(clippy::too_many_arguments)]
pub fn gated_deltanet2_reference<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    erase: Tensor<B, 4>,
    write: Tensor<B, 4>,
    log_decay: Tensor<B, 4>,
    mut state: Tensor<B, 4>,
    chunk_size: usize,
) -> (Tensor<B, 4>, Tensor<B, 4>) {
    let [batch, heads, time, latent] = query.shape().dims::<4>();
    let [.., dense] = value.shape().dims::<4>();
    let readout_scale = (latent as f64).powf(-0.5);
    let stride = chunk_size.max(1);

    // Single time step `t` of a `[batch, heads, time, _]` tensor, keeping the axis.
    let step_of = |tensor: &Tensor<B, 4>, t: usize| tensor.clone().slice_dim(2, t..t + 1);
    // Contracts the latent axis of `memory` against a `[batch, heads, 1, latent]` probe.
    let read = |memory: Tensor<B, 4>, probe: Tensor<B, 4>| {
        (memory * probe.swap_dims(2, 3))
            .sum_dim(2)
            .reshape([batch, heads, 1, dense])
    };

    let mut contexts = Vec::with_capacity(time);
    let mut chunk_start = 0;
    while chunk_start < time {
        let chunk_end = (chunk_start + stride).min(time);
        for t in chunk_start..chunk_end {
            let q_t = step_of(&query, t);
            let k_t = step_of(&key, t);
            let v_t = step_of(&value, t);
            let erase_t = step_of(&erase, t);
            let write_t = step_of(&write, t);
            let decay_t = step_of(&log_decay, t).exp().swap_dims(2, 3);

            state = state * decay_t;
            let erased_key = erase_t * k_t.clone();
            let erased_value = read(state.clone(), erased_key);
            let write_value = write_t * v_t - erased_value;
            state = state + k_t.swap_dims(2, 3) * write_value;

            contexts.push(read(state.clone(), q_t).mul_scalar(readout_scale));
        }
        chunk_start = chunk_end;
    }
    (Tensor::cat(contexts, 2), state)
}
/// Runs the chunked WY kernel when it accepts the inputs, otherwise the
/// reference recurrence. Returns `(context, state)` either way.
// The signature mirrors `gated_deltanet2_reference`, hence the many arguments.
#[allow(clippy::too_many_arguments)]
fn chunk_wy_or_reference<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    erase: Tensor<B, 4>,
    write: Tensor<B, 4>,
    log_decay: Tensor<B, 4>,
    state: Tensor<B, 4>,
    chunk_size: usize,
) -> (Tensor<B, 4>, Tensor<B, 4>) {
    // The kernel attempt consumes its arguments, so it gets clones and the
    // originals stay available for the fallback.
    let attempt = try_gdn2_chunk_wy(
        query.clone(),
        key.clone(),
        value.clone(),
        erase.clone(),
        write.clone(),
        log_decay.clone(),
        state.clone(),
        chunk_size,
    );
    match attempt {
        Some(kernel) => (kernel.context, kernel.state),
        None => gated_deltanet2_reference(
            query, key, value, erase, write, log_decay, state, chunk_size,
        ),
    }
}
/// Result of a successful chunked WY kernel run.
#[derive(Debug)]
pub struct GatedDeltaNet2KernelOutput<B: Backend> {
/// Per-token readout, returned to callers as the recurrence output.
pub context: Tensor<B, 4>,
/// Recurrent state after the last token.
pub state: Tensor<B, 4>,
}
/// Attempts to run the chunked WY kernel.
///
/// With the `cuda` or `wgpu-kernel` feature enabled, the custom-backward
/// variant is tried first and the forward-only variant second; the first one
/// that yields a result wins. Without either feature the function always
/// returns `None`.
// The signature mirrors `gated_deltanet2_reference`, hence the many arguments.
#[allow(clippy::too_many_arguments)]
pub fn try_gdn2_chunk_wy<B: Backend>(
    query: Tensor<B, 4>,
    key: Tensor<B, 4>,
    value: Tensor<B, 4>,
    erase: Tensor<B, 4>,
    write: Tensor<B, 4>,
    log_decay: Tensor<B, 4>,
    initial_state: Tensor<B, 4>,
    chunk_size: usize,
) -> Option<GatedDeltaNet2KernelOutput<B>>
where
    B::FloatTensorPrimitive: 'static,
{
    #[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
    {
        // First attempt works on clones so the originals remain for the second.
        let with_custom_backward = chunk_wy::try_gdn2_chunk_wy_custom_backward(
            query.clone(),
            key.clone(),
            value.clone(),
            erase.clone(),
            write.clone(),
            log_decay.clone(),
            initial_state.clone(),
            chunk_size,
        )
        .map(|kernel| GatedDeltaNet2KernelOutput {
            context: kernel.context,
            state: kernel.state,
        });
        with_custom_backward.or_else(|| {
            chunk_wy::try_gdn2_chunk_wy_forward(
                query,
                key,
                value,
                erase,
                write,
                log_decay,
                initial_state,
                chunk_size,
            )
            .map(|kernel| GatedDeltaNet2KernelOutput {
                context: kernel.context,
                state: kernel.state,
            })
        })
    }
    #[cfg(not(any(feature = "cuda", feature = "wgpu-kernel")))]
    {
        // No kernel compiled in: consume the arguments to avoid unused warnings.
        let _ = (
            query,
            key,
            value,
            erase,
            write,
            log_decay,
            initial_state,
            chunk_size,
        );
        None
    }
}
/// Reports whether backend `B` is one the chunked WY kernel is wired up for.
///
/// True only when a kernel feature is compiled in, the float element is
/// `f32`, and the float tensor primitive is the plain or autodiff tensor of
/// an enabled CUDA / wgpu kernel backend.
pub fn supports_gdn2_chunk_wy_backend<B: Backend>() -> bool
where
    B::FloatTensorPrimitive: 'static,
    B::FloatElem: 'static,
{
    #[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
    {
        let float_is_f32 = TypeId::of::<B::FloatElem>() == TypeId::of::<f32>();
        if !float_is_f32 {
            return false;
        }
        let primitive = TypeId::of::<B::FloatTensorPrimitive>();
        #[cfg(feature = "cuda")]
        {
            let cuda_tensors = [
                TypeId::of::<Gdn2CudaKernelTensor>(),
                TypeId::of::<Gdn2CudaAutodiffTensor>(),
            ];
            if cuda_tensors.contains(&primitive) {
                return true;
            }
        }
        #[cfg(feature = "wgpu-kernel")]
        {
            let wgpu_tensors = [
                TypeId::of::<Gdn2WgpuKernelTensor>(),
                TypeId::of::<Gdn2WgpuAutodiffTensor>(),
            ];
            if wgpu_tensors.contains(&primitive) {
                return true;
            }
        }
        false
    }
    #[cfg(not(any(feature = "cuda", feature = "wgpu-kernel")))]
    {
        false
    }
}
/// Reports which execution path `executor` takes on backend `B`.
pub fn gdn2_kernel_path<B: Backend>(executor: GatedDeltaNet2Executor) -> GatedDeltaNet2KernelPath
where
    B::FloatTensorPrimitive: 'static,
    B::FloatElem: 'static,
{
    match executor {
        GatedDeltaNet2Executor::Reference => GatedDeltaNet2KernelPath::Reference,
        GatedDeltaNet2Executor::ChunkWy => resolve_chunk_wy_kernel_path::<B>(),
    }
}

/// Path taken when the ChunkWy executor is requested on backend `B`:
/// reference fallback for unsupported backends, forward-only kernel when the
/// custom backward is off, the tensor-core backward on CUDA tensors when
/// enabled, and the generic custom backward otherwise.
fn resolve_chunk_wy_kernel_path<B: Backend>() -> GatedDeltaNet2KernelPath
where
    B::FloatTensorPrimitive: 'static,
    B::FloatElem: 'static,
{
    if !supports_gdn2_chunk_wy_backend::<B>() {
        return GatedDeltaNet2KernelPath::ReferenceFallback;
    }
    #[cfg(any(feature = "cuda", feature = "wgpu-kernel"))]
    if !chunk_wy::gdn2_chunk_wy_custom_backward_enabled() {
        return GatedDeltaNet2KernelPath::ForwardKernel;
    }
    #[cfg(feature = "cuda")]
    {
        let primitive = TypeId::of::<B::FloatTensorPrimitive>();
        let is_cuda_tensor = primitive == TypeId::of::<Gdn2CudaKernelTensor>()
            || primitive == TypeId::of::<Gdn2CudaAutodiffTensor>();
        if chunk_wy::gdn2_cuda_tensor_core_backward_enabled() && is_cuda_tensor {
            return GatedDeltaNet2KernelPath::CudaTensorCoreBackward;
        }
    }
    GatedDeltaNet2KernelPath::CustomBackward
}
/// Short machine-readable label for the path reported by `gdn2_kernel_path`.
///
/// The explicit reference executor is labelled `"available"`; every other
/// path is labelled with its snake_case name.
pub fn gdn2_kernel_status<B: Backend>(executor: GatedDeltaNet2Executor) -> &'static str
where
    B::FloatTensorPrimitive: 'static,
    B::FloatElem: 'static,
{
    let path = gdn2_kernel_path::<B>(executor);
    match path {
        GatedDeltaNet2KernelPath::CudaTensorCoreBackward => "cuda_tensor_core_backward",
        GatedDeltaNet2KernelPath::CustomBackward => "custom_backward",
        GatedDeltaNet2KernelPath::ForwardKernel => "forward_kernel",
        GatedDeltaNet2KernelPath::ReferenceFallback => "reference_fallback",
        GatedDeltaNet2KernelPath::Reference => "available",
    }
}
/// Number of trainable parameters of a GDN2 memory block.
///
/// Counts the four biased latent projections (query, key, erase, decay), the
/// two biased hidden projections (value, write), the `[heads, latent]` decay
/// parameter and the bias-free output projection. Zero `heads` or
/// `latent_per_head` are treated as 1.
///
/// Every step saturates at `usize::MAX` instead of overflowing, so the
/// function never panics or wraps, including on 32-bit targets.
pub fn gdn2_parameter_count(hidden_dim: usize, config: &GatedDeltaNet2Config) -> usize {
    let heads = config.heads.max(1);
    let latent = config.latent_per_head.max(1);
    let latent_width = heads.saturating_mul(latent);

    let latent_projections = linear_params(hidden_dim, latent_width).saturating_mul(4);
    let hidden_projections = linear_params(hidden_dim, hidden_dim).saturating_mul(2);
    // `decay_log` holds one value per head and latent channel.
    let decay_log = latent_width;
    let output_projection = hidden_dim.saturating_mul(hidden_dim);

    // Saturating sums: a plain `+` would panic (debug) or wrap (release) as
    // soon as one of the saturated products above reaches `usize::MAX`.
    latent_projections
        .saturating_add(hidden_projections)
        .saturating_add(decay_log)
        .saturating_add(output_projection)
}
/// Estimated multiply-accumulate operations per token of a GDN2 memory block.
///
/// Sums the four latent projections, three hidden-sized projections, and
/// five passes over the `heads * latent * head_dim` state. Zero `heads` or
/// `latent_per_head` are treated as 1, as is a zero `head_dim`
/// (`hidden_dim / heads`, rounded down) in the state term.
///
/// Every step saturates at `u128::MAX` rather than overflowing.
pub fn gdn2_macs_per_token(hidden_dim: usize, config: &GatedDeltaNet2Config) -> u128 {
    let heads = config.heads.max(1);
    let latent = config.latent_per_head.max(1);
    let head_dim = hidden_dim / heads;
    let latent_width = heads.saturating_mul(latent);

    let latent_projections = linear_macs(hidden_dim, latent_width).saturating_mul(4);
    let hidden_projections = linear_macs(hidden_dim, hidden_dim).saturating_mul(3);
    let state_ops = (heads as u128)
        .saturating_mul(latent as u128)
        .saturating_mul(head_dim.max(1) as u128)
        .saturating_mul(5);

    // Saturating sums keep the additions consistent with the saturating
    // products above; a plain `+` could overflow once a product saturates.
    latent_projections
        .saturating_add(hidden_projections)
        .saturating_add(state_ops)
}
/// Converts gate logits into gate values according to `mode`.
///
/// `Channel` applies a sigmoid per element, `Scalar` applies a sigmoid to the
/// mean over the last dimension and repeats it across that dimension, and
/// `Disabled` yields all ones. With `allow_neg_eigval` a non-disabled gate is
/// additionally multiplied by 2.
fn apply_gate_mode<B: Backend>(
    logits: Tensor<B, 4>,
    mode: GatedDeltaNet2GateMode,
    allow_neg_eigval: bool,
) -> Tensor<B, 4> {
    let gate = match mode {
        GatedDeltaNet2GateMode::Disabled => {
            // A disabled gate is never doubled, so it can be returned directly.
            let shape = logits.shape().dims::<4>();
            return Tensor::<B, 4>::ones(shape, &logits.device());
        }
        GatedDeltaNet2GateMode::Channel => activation::sigmoid(logits),
        GatedDeltaNet2GateMode::Scalar => {
            let [.., channels] = logits.shape().dims::<4>();
            let shared = activation::sigmoid(logits.mean_dim(3));
            shared.repeat_dim(3, channels)
        }
    };
    if allow_neg_eigval {
        gate.mul_scalar(2.0)
    } else {
        gate
    }
}
/// Divides `values` by `sqrt(sum(values^2) + epsilon)` computed along the
/// last dimension.
pub fn l2_normalize_last<B: Backend>(values: Tensor<B, 4>, epsilon: f64) -> Tensor<B, 4> {
    let squared_sum = values.clone().powf_scalar(2.0).sum_dim(3);
    // Epsilon goes under the square root so an all-zero vector stays finite.
    let norm = squared_sum.add_scalar(epsilon).sqrt();
    values.div(norm)
}
/// Parameter count of a biased linear layer (`input * output` weights plus
/// `output` biases), saturating at `usize::MAX`.
fn linear_params(input: usize, output: usize) -> usize {
    input
        .checked_mul(output)
        .and_then(|weights| weights.checked_add(output))
        .unwrap_or(usize::MAX)
}
/// Multiply-accumulate count of a linear layer (`input * output`), computed
/// in `u128` and saturating at `u128::MAX`.
fn linear_macs(input: usize, output: usize) -> u128 {
    // Widening `usize` to `u128` is lossless.
    let (rows, cols) = (input as u128, output as u128);
    rows.saturating_mul(cols)
}
// Unit tests for the reference recurrence, the memory block and kernel status reporting.
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
use burn::tensor::TensorData;
// CPU backend used for every test that needs actual tensors.
type TestBackend = NdArray<f32>;
// Builds a rank-4 test tensor from flat values and a shape.
fn tensor4(values: Vec<f32>, shape: [usize; 4]) -> Tensor<TestBackend, 4> {
Tensor::<TestBackend, 4>::from_data(TensorData::new(values, shape), &Default::default())
}
// Largest element-wise absolute difference between two tensors.
fn max_abs_diff(lhs: Tensor<TestBackend, 4>, rhs: Tensor<TestBackend, 4>) -> f32 {
let lhs = lhs
.into_data()
.convert::<f32>()
.into_vec::<f32>()
.expect("lhs vec");
let rhs = rhs
.into_data()
.convert::<f32>()
.into_vec::<f32>()
.expect("rhs vec");
lhs.into_iter()
.zip(rhs)
.map(|(left, right)| (left - right).abs())
.fold(0.0f32, f32::max)
}
// Checks one recurrence step (latent = 2, dense = 3) against a hand-computed
// decay -> erase -> write -> readout sequence.
#[test]
fn reference_matches_one_step_update() {
let query = tensor4(vec![0.7, -0.3], [1, 1, 1, 2]);
let key = tensor4(vec![0.4, 0.6], [1, 1, 1, 2]);
let value = tensor4(vec![0.8, -0.5, 0.25], [1, 1, 1, 3]);
let erase = tensor4(vec![0.25, 0.75], [1, 1, 1, 2]);
let write = tensor4(vec![0.5, 0.2, 0.9], [1, 1, 1, 3]);
let log_decay = tensor4(vec![0.5f32.ln(), 0.25f32.ln()], [1, 1, 1, 2]);
let state = tensor4(vec![0.2, -0.1, 0.05, 0.3, 0.4, -0.2], [1, 1, 2, 3]);
let (output, next_state) =
gated_deltanet2_reference(query, key, value, erase, write, log_decay, state, 64);
// Initial state with row 0 scaled by 0.5 and row 1 scaled by 0.25.
let decayed_state = [[0.1f32, -0.05, 0.025], [0.075, 0.1, -0.05]];
let erased_key = [0.25f32 * 0.4, 0.75 * 0.6];
// Value currently stored under the erase-gated key.
let erased_value = [
decayed_state[0][0] * erased_key[0] + decayed_state[1][0] * erased_key[1],
decayed_state[0][1] * erased_key[0] + decayed_state[1][1] * erased_key[1],
decayed_state[0][2] * erased_key[0] + decayed_state[1][2] * erased_key[1],
];
// Write-gated value minus what is being erased.
let write_value = [
0.5f32 * 0.8 - erased_value[0],
0.2 * -0.5 - erased_value[1],
0.9 * 0.25 - erased_value[2],
];
// Rank-one update: key (outer product) write_value added to the decayed state.
let expected_state = [
decayed_state[0][0] + 0.4 * write_value[0],
decayed_state[0][1] + 0.4 * write_value[1],
decayed_state[0][2] + 0.4 * write_value[2],
decayed_state[1][0] + 0.6 * write_value[0],
decayed_state[1][1] + 0.6 * write_value[1],
decayed_state[1][2] + 0.6 * write_value[2],
];
// Readout scale is latent^-0.5 with latent = 2.
let output_scale = 2.0f32.sqrt().recip();
let expected_output = [
(0.7 * expected_state[0] - 0.3 * expected_state[3]) * output_scale,
(0.7 * expected_state[1] - 0.3 * expected_state[4]) * output_scale,
(0.7 * expected_state[2] - 0.3 * expected_state[5]) * output_scale,
];
assert!(max_abs_diff(next_state, tensor4(expected_state.to_vec(), [1, 1, 2, 3])) < 1.0e-6);
assert!(max_abs_diff(output, tensor4(expected_output.to_vec(), [1, 1, 1, 3])) < 1.0e-6);
}
// The block keeps the input shape and creates a state of shape
// [batch, heads, latent_per_head, head_dim] when none was supplied.
#[test]
fn memory_block_preserves_expected_shape_and_state() {
let device = Default::default();
let config = GatedDeltaNet2Config {
heads: 3,
latent_per_head: 4,
output_scale: 0.25,
..GatedDeltaNet2Config::default()
};
let block = GatedDeltaNet2Memory::<TestBackend>::new(12, config, &device)
.expect("GDN2 memory block");
let x = Tensor::<TestBackend, 3>::zeros([2, 5, 12], &device);
let mut state = None;
let y = block.forward(x, &mut state, true);
assert_eq!(y.shape().dims::<3>(), [2, 5, 12]);
assert_eq!(state.expect("state").shape().dims::<4>(), [2, 3, 4, 4]);
}
// A bf16 CUDA backend is not an f32 backend, so ChunkWy must report the fallback.
#[cfg(feature = "cuda")]
#[test]
fn cuda_bf16_backend_reports_reference_fallback_for_f32_chunk_wy_kernel() {
assert_eq!(
gdn2_kernel_status::<Gdn2CudaBf16KernelBackend>(GatedDeltaNet2Executor::ChunkWy),
"reference_fallback"
);
}
}