use burn::config::Config;
use burn::module::Module;
use burn::nn::{Dropout, DropoutConfig, Gelu, Linear, LinearConfig};
use burn::prelude::*;
use crate::config::AttnResConfig;
#[derive(Config, Debug)]
pub struct FeedForwardConfig {
pub d_model: usize,
pub d_ff: usize,
#[config(default = 0.0)]
pub dropout: f64,
}
impl FeedForwardConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> FeedForward<B> {
FeedForward {
linear1: LinearConfig::new(self.d_model, self.d_ff).init(device),
linear2: LinearConfig::new(self.d_ff, self.d_model).init(device),
gelu: Gelu::new(),
dropout: DropoutConfig::new(self.dropout).init(),
}
}
}
impl AttnResConfig {
pub fn feed_forward_config(&self) -> FeedForwardConfig {
FeedForwardConfig::new(self.d_model, self.effective_d_ff()).with_dropout(self.dropout)
}
}
#[derive(Module, Debug)]
pub struct FeedForward<B: Backend> {
linear1: Linear<B>,
linear2: Linear<B>,
gelu: Gelu,
dropout: Dropout,
}
impl<B: Backend> FeedForward<B> {
pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
let x = self.linear1.forward(x);
let x = self.gelu.forward(x);
let x = self.dropout.forward(x);
let x = self.linear2.forward(x);
self.dropout.forward(x)
}
}
#[cfg(test)]
mod tests {
use super::*;
use burn::backend::NdArray;
use burn::tensor::Distribution;
type TestBackend = NdArray;
#[test]
fn test_ff_output_shape() {
let device = Default::default();
let ff = FeedForwardConfig::new(64, 256).init::<TestBackend>(&device);
let x = Tensor::random([2, 16, 64], Distribution::Normal(0.0, 1.0), &device);
let out = ff.forward(x);
assert_eq!(out.dims(), [2, 16, 64]);
}
}