concision_transformer/model/
sublayer.rs

1/*
2    Appellation: sublayer <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5#![cfg(feature = "rand")]
6use concision::nn::DropoutLayer;
7use concision::Forward;
8use linear::{Biased, LayerNorm, ParamMode, Unbiased};
9use nd::prelude::*;
10use nd::{DataOwned, RemoveAxis, ScalarOperand};
11use num::traits::{Float, FromPrimitive};
12
13/// A residual connection followed by a [layer norm](LayerNorm)
14/// [Transformer](crate::Transformer)
15pub struct Sublayer<A = f64, K = Biased, D = Ix2>
16where
17    D: Dimension,
18{
19    pub(crate) dropout: DropoutLayer,
20    pub(crate) norm: LayerNorm<A, K, D>,
21}
22
23impl<A, K, D> Sublayer<A, K, D>
24where
25    D: RemoveAxis,
26{
27    pub fn new<Sh>(shape: Sh, dropout: f64) -> Self
28    where
29        A: Default,
30        K: ParamMode,
31        Sh: ShapeBuilder<Dim = D>,
32    {
33        Self {
34            dropout: DropoutLayer::new(dropout),
35            norm: LayerNorm::new(shape),
36        }
37    }
38
39    pub fn dropout(&self) -> &DropoutLayer {
40        &self.dropout
41    }
42
43    pub fn norm(&self) -> &LayerNorm<A, K, D> {
44        &self.norm
45    }
46}
47
48impl<A, S, D> Forward<ArrayBase<S, D>> for Sublayer<A, Biased, D>
49where
50    A: Float + FromPrimitive + ScalarOperand,
51    D: RemoveAxis,
52    S: DataOwned<Elem = A>,
53{
54    type Output = Array<A, D>;
55
56    fn forward(&self, input: &ArrayBase<S, D>) -> Self::Output {
57        let normal = self.norm().forward(input);
58        input + self.dropout().forward(&normal)
59    }
60}
61
62impl<A, S, D> Forward<ArrayBase<S, D>> for Sublayer<A, Unbiased, D>
63where
64    A: Float + FromPrimitive + ScalarOperand,
65    D: RemoveAxis,
66    S: DataOwned<Elem = A>,
67{
68    type Output = Array<A, D>;
69
70    fn forward(&self, input: &ArrayBase<S, D>) -> Self::Output {
71        let normal = self.norm().forward(input);
72        input + self.dropout().forward(&normal)
73    }
74}