use crate as burn;
use crate::{
config::Config,
module::Module,
nn::{Dropout, DropoutConfig, Linear, LinearConfig, GELU},
tensor::{backend::Backend, Tensor},
};
/// Configuration for building a [`PositionWiseFeedForward`] module.
#[derive(Config)]
pub struct PositionWiseFeedForwardConfig {
/// Input/output feature dimension of the block (size of the outer linear layers).
pub d_model: usize,
/// Inner (hidden) dimension of the feed-forward expansion.
pub d_ff: usize,
/// Dropout probability applied after the GELU activation. Defaults to 0.1.
#[config(default = 0.1)]
pub dropout: f64,
}
/// Position-wise feed-forward module: `linear_inner -> GELU -> dropout -> linear_outer`.
///
/// Built from a [`PositionWiseFeedForwardConfig`]; `linear_inner` maps
/// `d_model -> d_ff` and `linear_outer` maps `d_ff -> d_model`.
#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
// Expansion layer: d_model -> d_ff.
linear_inner: Linear<B>,
// Projection layer: d_ff -> d_model.
linear_outer: Linear<B>,
// Applied between the activation and the output projection.
dropout: Dropout,
gelu: GELU,
}
impl PositionWiseFeedForwardConfig {
    /// Initialize a new [`PositionWiseFeedForward`] module with freshly
    /// initialized parameters.
    ///
    /// The inner layer maps `d_model -> d_ff` and the outer layer maps
    /// `d_ff -> d_model`, so input and output feature sizes match.
    pub fn init<B: Backend>(&self) -> PositionWiseFeedForward<B> {
        let linear_inner = LinearConfig::new(self.d_model, self.d_ff).init();
        let linear_outer = LinearConfig::new(self.d_ff, self.d_model).init();
        let dropout = DropoutConfig::new(self.dropout).init();
        let gelu = GELU::new();

        PositionWiseFeedForward {
            linear_inner,
            linear_outer,
            dropout,
            gelu,
        }
    }

    /// Initialize a new [`PositionWiseFeedForward`] module, loading the linear
    /// layers' parameters from the given `record`.
    ///
    /// Dropout and GELU carry no learnable state, so they are rebuilt from the
    /// config rather than the record.
    pub fn init_with<B: Backend>(
        &self,
        record: PositionWiseFeedForwardRecord<B>,
    ) -> PositionWiseFeedForward<B> {
        let linear_inner =
            LinearConfig::new(self.d_model, self.d_ff).init_with(record.linear_inner);
        let linear_outer =
            LinearConfig::new(self.d_ff, self.d_model).init_with(record.linear_outer);
        let dropout = DropoutConfig::new(self.dropout).init();
        let gelu = GELU::new();

        PositionWiseFeedForward {
            linear_inner,
            linear_outer,
            dropout,
            gelu,
        }
    }
}
impl<B: Backend> PositionWiseFeedForward<B> {
    /// Applies the feed-forward transformation to `input`.
    ///
    /// The pipeline is `linear_inner -> GELU -> dropout -> linear_outer`;
    /// the output tensor has the same rank `D` as the input, with the last
    /// feature dimension mapped `d_model -> d_ff -> d_model`.
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        // Expand, activate, regularize, then project back to d_model.
        let hidden = self.gelu.forward(self.linear_inner.forward(input));
        self.linear_outer.forward(self.dropout.forward(hidden))
    }
}