#![allow(dead_code)]
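//! Feed-forward blocks: a SwiGLU gated FFN and a Nemotron-style squared-ReLU MLP.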
use burn::nn::{Linear, LinearConfig};
use burn::prelude::*;
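/// Configuration for [`FeedForward`], a gated (SwiGLU-style) feed-forward block.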
#[derive(Config, Debug)]
pub struct FeedForwardConfig {
    /// Model (embedding) dimension.
    pub d_model: usize,
    /// Inner width of the feed-forward block.
    pub d_intermediate: usize,
}
impl FeedForwardConfig {
    /// Initialize a [`FeedForward`] module on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> FeedForward<B> {
        // All three projections are bias-free, as is typical for gated FFNs.
        let no_bias = |d_in, d_out| {
            LinearConfig::new(d_in, d_out)
                .with_bias(false)
                .init(device)
        };
        FeedForward {
            gate_proj: no_bias(self.d_model, self.d_intermediate),
            up_proj: no_bias(self.d_model, self.d_intermediate),
            down_proj: no_bias(self.d_intermediate, self.d_model),
        }
    }
}
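/// Gated feed-forward block with three bias-free linear projections.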
#[derive(Module, Debug)]
pub struct FeedForward<B: Backend> {
    pub(crate) gate_proj: Linear<B>,
    pub(crate) up_proj: Linear<B>,
    pub(crate) down_proj: Linear<B>,
}
impl<B: Backend> FeedForward<B> {
    /// SwiGLU forward pass: `down_proj(silu(gate_proj(x)) * up_proj(x))`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        // The gate branch goes through SiLU; the up branch stays linear.
        // Their element-wise product is projected back down to `d_model`.
        let gate = burn::tensor::activation::silu(self.gate_proj.forward(x.clone()));
        let up = self.up_proj.forward(x);
        self.down_proj.forward(gate * up)
    }
}
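/// Configuration for [`NemotronMlp`], a squared-ReLU MLP.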
#[derive(Config, Debug)]
pub struct NemotronMlpConfig {
    /// Model (embedding) dimension.
    pub d_model: usize,
    /// Inner width of the MLP.
    pub d_intermediate: usize,
}
impl NemotronMlpConfig {
    /// Initialize a [`NemotronMlp`] module on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> NemotronMlp<B> {
        let no_bias = |d_in, d_out| {
            LinearConfig::new(d_in, d_out)
                .with_bias(false)
                .init(device)
        };
        NemotronMlp {
            up_proj: no_bias(self.d_model, self.d_intermediate),
            down_proj: no_bias(self.d_intermediate, self.d_model),
        }
    }
}
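/// Two-projection MLP (`up_proj`, `down_proj`) with a squared-ReLU activation.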
#[derive(Module, Debug)]
pub struct NemotronMlp<B: Backend> {
    pub(crate) up_proj: Linear<B>,
    pub(crate) down_proj: Linear<B>,
}
impl<B: Backend> NemotronMlp<B> {
    /// Squared-ReLU forward pass: `down_proj(relu(up_proj(x))²)`.
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let h = burn::tensor::activation::relu(self.up_proj.forward(x));
        // relu² activation: square by multiplying the ReLU output with itself.
        self.down_proj.forward(h.clone() * h)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type B = NdArray<f32>;

    #[test]
    fn feed_forward_preserves_shape() {
        let device = Default::default();
        let ffn = FeedForwardConfig {
            d_model: 64,
            d_intermediate: 128,
        }
        .init::<B>(&device);
        let x = Tensor::<B, 3>::zeros([2, 8, 64], &device);
        let out = ffn.forward(x);
        assert_eq!(out.dims(), [2, 8, 64]);
    }

    #[test]
    fn nemotron_mlp_preserves_shape() {
        let device = Default::default();
        let mlp = NemotronMlpConfig {
            d_model: 64,
            d_intermediate: 128,
        }
        .init::<B>(&device);
        let x = Tensor::<B, 3>::zeros([2, 8, 64], &device);
        let out = mlp.forward(x);
        assert_eq!(out.dims(), [2, 8, 64]);
    }

    #[test]
    fn relu_squared_intermediate_is_non_negative() {
        let device = Default::default();
        let mlp = NemotronMlpConfig {
            d_model: 16,
            d_intermediate: 32,
        }
        .init::<B>(&device);
        let x = Tensor::<B, 3>::random(
            [1, 4, 16],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device,
        );
        let h = burn::tensor::activation::relu(mlp.up_proj.forward(x));
        let h_sq = h.clone() * h;
        let min_val: f32 = h_sq.min().into_scalar().elem();
        assert!(
            min_val >= 0.0,
            "relu² intermediate should be non-negative, got {min_val}"
        );
    }

    #[test]
    fn silu_gating_produces_nonzero_for_nonzero_input() {
        let device = Default::default();
        let ffn = FeedForwardConfig {
            d_model: 16,
            d_intermediate: 32,
        }
        .init::<B>(&device);
        let x = Tensor::<B, 3>::random(
            [1, 4, 16],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device,
        );
        let out = ffn.forward(x);
        let abs_sum: f32 = out.abs().sum().into_scalar().elem();
        assert!(abs_sum > 0.0, "SiLU gating should produce non-zero output");
    }
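
    // A sketch of an extra consistency check (not in the original tests): the
    // module's forward should match composing its own projections by hand.
    // The 1e-6 tolerance is an assumption, not a value taken from the source.
    #[test]
    fn nemotron_forward_matches_manual_composition() {
        let device = Default::default();
        let mlp = NemotronMlpConfig {
            d_model: 16,
            d_intermediate: 32,
        }
        .init::<B>(&device);
        let x = Tensor::<B, 3>::random(
            [1, 4, 16],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device,
        );
        let h = burn::tensor::activation::relu(mlp.up_proj.forward(x.clone()));
        let manual = mlp.down_proj.forward(h.clone() * h);
        let diff: f32 = (mlp.forward(x) - manual).abs().max().into_scalar().elem();
        assert!(diff < 1e-6, "forward should match manual composition, max diff {diff}");
    }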
}