use crate as burn;
use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
use crate::nn::Initializer;
use crate::{
config::Config,
nn::{Dropout, DropoutConfig, Gelu, Linear, LinearConfig},
tensor::{backend::Backend, Tensor},
};
/// Configuration used to build a [`PositionWiseFeedForward`] module via
/// [`PositionWiseFeedForwardConfig::init`].
///
/// `d_model` and `d_ff` have no default and are therefore the required
/// arguments of the derived constructor (`new(d_model, d_ff)`); the remaining
/// fields fall back to the defaults declared below.
#[derive(Config)]
pub struct PositionWiseFeedForwardConfig {
/// Input size of the inner linear layer and output size of the outer one.
pub d_model: usize,
/// Output size of the inner linear layer and input size of the outer one.
pub d_ff: usize,
/// Value handed to `DropoutConfig::new` when the module is built. Default: 0.1.
#[config(default = 0.1)]
pub dropout: f64,
/// Initializer applied to both linear layers.
/// Default: `KaimingUniform` with gain `1 / sqrt(3)` and `fan_out_only = false`.
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
)]
pub initializer: Initializer,
}
/// Feed-forward block made of two linear layers with a GELU activation and
/// dropout in between: `linear_inner -> gelu -> dropout -> linear_outer`.
///
/// Instances are created with [`PositionWiseFeedForwardConfig::init`].
/// The module provides its own display output (see the `ModuleDisplay` impl).
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct PositionWiseFeedForward<B: Backend> {
/// First linear layer, configured as `d_model -> d_ff`.
pub linear_inner: Linear<B>,
/// Second linear layer, configured as `d_ff -> d_model`.
pub linear_outer: Linear<B>,
/// Dropout applied to the activated output of `linear_inner`.
pub dropout: Dropout,
/// Activation applied directly after `linear_inner`.
pub gelu: Gelu,
}
impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
    /// Display settings for this module: attributes are not followed by a new
    /// line, so the whole module is rendered on a single line.
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let settings = DisplaySettings::new().with_new_line_after_attribute(false);
        settings.optional()
    }

    /// Reports `d_model`, `d_ff` and the dropout probability `prob`.
    ///
    /// Both sizes are read back from the weight shape of `linear_inner`
    /// rather than being stored separately on the module.
    fn custom_content(&self, content: Content) -> Option<Content> {
        let inner_weight_dims = self.linear_inner.weight.shape().dims;
        let [model_size, hidden_size] = inner_weight_dims;

        let content = content.add("d_model", &model_size);
        let content = content.add("d_ff", &hidden_size);
        let content = content.add("prob", &self.dropout.prob);
        content.optional()
    }
}
impl PositionWiseFeedForwardConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> PositionWiseFeedForward<B> {
PositionWiseFeedForward {
linear_inner: LinearConfig::new(self.d_model, self.d_ff)
.with_initializer(self.initializer.clone())
.init(device),
linear_outer: LinearConfig::new(self.d_ff, self.d_model)
.with_initializer(self.initializer.clone())
.init(device),
dropout: DropoutConfig::new(self.dropout).init(),
gelu: Gelu::new(),
}
}
}
impl<B: Backend> PositionWiseFeedForward<B> {
    /// Runs the forward pass: `linear_inner -> gelu -> dropout -> linear_outer`.
    ///
    /// The input and output tensors have the same rank `D`.
    /// NOTE(review): presumably the last dimension of `input` must equal
    /// `d_model` — confirm against `Linear::forward`.
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let projected = self.linear_inner.forward(input);
        let activated = self.gelu.forward(projected);
        let regularized = self.dropout.forward(activated);

        self.linear_outer.forward(regularized)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// The custom display renders the module on one line with its sizes,
    /// dropout probability and parameter count.
    #[test]
    fn display() {
        let device = Default::default();
        let module = PositionWiseFeedForwardConfig::new(2, 4).init::<TestBackend>(&device);

        let rendered = alloc::format!("{}", module);

        assert_eq!(
            rendered,
            "PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}"
        );
    }
}