use burn::{
    module::Module,
    nn::{Gelu, Linear, LinearConfig},
    tensor::{backend::Backend, Tensor},
};

/// Position-wise feed-forward network: two linear projections with a GELU
/// activation in between, as used in a standard transformer block.
#[derive(Module, Debug)]
pub struct FeedForward<B: Backend> {
    fc1: Linear<B>,
    activation: Gelu,
    fc2: Linear<B>,
}

impl<B: Backend> FeedForward<B> {
    /// Creates a feed-forward block that projects from `d_model` up to
    /// `ffn_dim` and back down to `d_model`.
    pub fn new(d_model: usize, ffn_dim: usize, device: &B::Device) -> Self {
        let fc1 = LinearConfig::new(d_model, ffn_dim).init(device);
        let fc2 = LinearConfig::new(ffn_dim, d_model).init(device);
        let activation = Gelu::new();
        Self {
            fc1,
            activation,
            fc2,
        }
    }

    /// Applies fc1 -> GELU -> fc2. Generic over the tensor rank `D`, so the
    /// same block accepts `[batch, seq, d_model]` or `[batch, d_model]`
    /// inputs; only the last dimension must equal `d_model`.
    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
        let x = self.fc1.forward(x);
        let x = self.activation.forward(x);
        self.fc2.forward(x)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type TestBackend = NdArray;

    #[test]
    fn test_feed_forward_shape() {
        let device = Default::default();
        let ffn = FeedForward::<TestBackend>::new(384, 1536, &device);
        let input = Tensor::<TestBackend, 3>::random(
            [2, 10, 384],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device,
        );
        // The feed-forward block must preserve the input shape end to end.
        let output = ffn.forward(input);
        assert_eq!(output.shape().dims, [2, 10, 384]);
    }
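
    // Sketch of an additional check, not in the original: exercises the
    // const-generic `forward` on a rank-2 tensor to confirm the block is
    // not tied to a specific tensor rank.
    #[test]
    fn test_feed_forward_rank_2() {
        let device = Default::default();
        let ffn = FeedForward::<TestBackend>::new(384, 1536, &device);
        let input = Tensor::<TestBackend, 2>::random(
            [4, 384],
            burn::tensor::Distribution::Normal(0.0, 1.0),
            &device,
        );
        let output = ffn.forward(input);
        assert_eq!(output.shape().dims, [4, 384]);
    }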
}