use burn_core as burn;

use crate::activation::{Activation, ActivationConfig};
use crate::{Dropout, DropoutConfig, Linear, LinearConfig};
use burn::config::Config;
use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
use burn::tensor::{Tensor, backend::Backend};

9#[derive(Config, Debug)]
11pub struct PositionWiseFeedForwardConfig {
12 pub d_model: usize,
14 pub d_ff: usize,
16 #[config(default = 0.1)]
18 pub dropout: f64,
19 #[config(
21 default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
22 )]
23 pub initializer: Initializer,
24 #[config(default = "ActivationConfig::Gelu")]
26 pub activation: ActivationConfig,
27}
28
29#[derive(Module, Debug)]
51#[module(custom_display)]
52pub struct PositionWiseFeedForward<B: Backend> {
53 pub linear_inner: Linear<B>,
55 pub linear_outer: Linear<B>,
57 pub dropout: Dropout,
59 #[module(skip)] pub activation: Activation<B>,
62}
63
64impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
65 fn custom_settings(&self) -> Option<DisplaySettings> {
66 DisplaySettings::new()
67 .with_new_line_after_attribute(false)
68 .optional()
69 }
70
71 fn custom_content(&self, content: Content) -> Option<Content> {
72 let [d_model, dff] = self.linear_inner.weight.shape().dims();
73
74 content
75 .add("d_model", &d_model)
76 .add("d_ff", &dff)
77 .add("prob", &self.dropout.prob)
78 .optional()
79 }
80}
81
82impl PositionWiseFeedForwardConfig {
83 pub fn init<B: Backend>(&self, device: &B::Device) -> PositionWiseFeedForward<B> {
85 PositionWiseFeedForward {
86 linear_inner: LinearConfig::new(self.d_model, self.d_ff)
87 .with_initializer(self.initializer.clone())
88 .init(device),
89 linear_outer: LinearConfig::new(self.d_ff, self.d_model)
90 .with_initializer(self.initializer.clone())
91 .init(device),
92 dropout: DropoutConfig::new(self.dropout).init(),
93 activation: self.activation.init(device),
94 }
95 }
96}
97
98impl<B: Backend> PositionWiseFeedForward<B> {
99 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
106 let x = self.linear_inner.forward(input);
107 let x = self.activation.forward(x);
108 let x = self.dropout.forward(x);
109
110 self.linear_outer.forward(x)
111 }
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Checks the custom display output, including the parameter count
    /// (2*4 + 4 weights/bias inner, 4*2 + 2 outer = 22 params).
    #[test]
    fn display() {
        let config = PositionWiseFeedForwardConfig::new(2, 4);
        let pwff = config.init::<TestBackend>(&Default::default());

        assert_eq!(
            alloc::format!("{pwff}"),
            "PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}"
        );
    }
}