burn_nn/modules/transformer/
pwff.rs1use burn_core as burn;
2
3use crate::{Dropout, DropoutConfig, Gelu, Linear, LinearConfig};
4use burn::config::Config;
5use burn::module::{Content, DisplaySettings, Initializer, Module, ModuleDisplay};
6use burn::tensor::{Tensor, backend::Backend};
7
8#[derive(Config, Debug)]
10pub struct PositionWiseFeedForwardConfig {
11 pub d_model: usize,
13 pub d_ff: usize,
15 #[config(default = 0.1)]
17 pub dropout: f64,
18 #[config(
20 default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
21 )]
22 pub initializer: Initializer,
23}
24
25#[derive(Module, Debug)]
36#[module(custom_display)]
37pub struct PositionWiseFeedForward<B: Backend> {
38 pub linear_inner: Linear<B>,
40 pub linear_outer: Linear<B>,
42 pub dropout: Dropout,
44 pub gelu: Gelu,
46}
47
48impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
49 fn custom_settings(&self) -> Option<DisplaySettings> {
50 DisplaySettings::new()
51 .with_new_line_after_attribute(false)
52 .optional()
53 }
54
55 fn custom_content(&self, content: Content) -> Option<Content> {
56 let [d_model, dff] = self.linear_inner.weight.shape().dims();
57
58 content
59 .add("d_model", &d_model)
60 .add("d_ff", &dff)
61 .add("prob", &self.dropout.prob)
62 .optional()
63 }
64}
65
66impl PositionWiseFeedForwardConfig {
67 pub fn init<B: Backend>(&self, device: &B::Device) -> PositionWiseFeedForward<B> {
69 PositionWiseFeedForward {
70 linear_inner: LinearConfig::new(self.d_model, self.d_ff)
71 .with_initializer(self.initializer.clone())
72 .init(device),
73 linear_outer: LinearConfig::new(self.d_ff, self.d_model)
74 .with_initializer(self.initializer.clone())
75 .init(device),
76 dropout: DropoutConfig::new(self.dropout).init(),
77 gelu: Gelu::new(),
78 }
79 }
80}
81
82impl<B: Backend> PositionWiseFeedForward<B> {
83 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
90 let x = self.linear_inner.forward(input);
91 let x = self.gelu.forward(x);
92 let x = self.dropout.forward(x);
93
94 self.linear_outer.forward(x)
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// The custom display prints the layer sizes, dropout probability and
    /// parameter count on a single line.
    #[test]
    fn display() {
        let module =
            PositionWiseFeedForwardConfig::new(2, 4).init::<TestBackend>(&Default::default());

        let rendered = alloc::format!("{module}");
        let expected = "PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}";

        assert_eq!(rendered, expected);
    }
}