burn_core/nn/transformer/
pwff.rs1use crate as burn;
2
3use crate::module::{Content, DisplaySettings, Module, ModuleDisplay};
4use crate::nn::Initializer;
5use crate::{
6 config::Config,
7 nn::{Dropout, DropoutConfig, Gelu, Linear, LinearConfig},
8 tensor::{Tensor, backend::Backend},
9};
10
11#[derive(Config)]
13pub struct PositionWiseFeedForwardConfig {
14 pub d_model: usize,
16 pub d_ff: usize,
18 #[config(default = 0.1)]
20 pub dropout: f64,
21 #[config(
23 default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
24 )]
25 pub initializer: Initializer,
26}
27
28#[derive(Module, Debug)]
39#[module(custom_display)]
40pub struct PositionWiseFeedForward<B: Backend> {
41 pub linear_inner: Linear<B>,
43 pub linear_outer: Linear<B>,
45 pub dropout: Dropout,
47 pub gelu: Gelu,
49}
50
51impl<B: Backend> ModuleDisplay for PositionWiseFeedForward<B> {
52 fn custom_settings(&self) -> Option<DisplaySettings> {
53 DisplaySettings::new()
54 .with_new_line_after_attribute(false)
55 .optional()
56 }
57
58 fn custom_content(&self, content: Content) -> Option<Content> {
59 let [d_model, dff] = self.linear_inner.weight.shape().dims();
60
61 content
62 .add("d_model", &d_model)
63 .add("d_ff", &dff)
64 .add("prob", &self.dropout.prob)
65 .optional()
66 }
67}
68
69impl PositionWiseFeedForwardConfig {
70 pub fn init<B: Backend>(&self, device: &B::Device) -> PositionWiseFeedForward<B> {
72 PositionWiseFeedForward {
73 linear_inner: LinearConfig::new(self.d_model, self.d_ff)
74 .with_initializer(self.initializer.clone())
75 .init(device),
76 linear_outer: LinearConfig::new(self.d_ff, self.d_model)
77 .with_initializer(self.initializer.clone())
78 .init(device),
79 dropout: DropoutConfig::new(self.dropout).init(),
80 gelu: Gelu::new(),
81 }
82 }
83}
84
85impl<B: Backend> PositionWiseFeedForward<B> {
86 pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
93 let x = self.linear_inner.forward(input);
94 let x = self.gelu.forward(x);
95 let x = self.dropout.forward(x);
96
97 self.linear_outer.forward(x)
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    #[test]
    fn display() {
        let module = PositionWiseFeedForwardConfig::new(2, 4)
            .init::<TestBackend>(&Default::default());

        // Expected parameter count, assuming both linear layers carry a bias:
        // inner 2*4 + 4, outer 4*2 + 2, total 22.
        let expected = "PositionWiseFeedForward {d_model: 2, d_ff: 4, prob: 0.1, params: 22}";
        let rendered = alloc::format!("{}", module);

        assert_eq!(rendered, expected);
    }
}