//! Encoder architectures.
//!
//! Path: `oxidized_transformers/architectures/encoder.rs`
use std::fmt::Debug;

use candle_core::Tensor;
use candle_nn::VarBuilder;

use crate::architectures::output::LayerOutputs;
use crate::architectures::BuildArchitecture;
use crate::error::BoxedError;
use crate::layers::attention::AttentionMask;

/// Encoder output.
///
/// Holds the hidden representations produced by an [`Encoder`].
pub struct EncoderOutput {
    /// Per-layer outputs. The embedding layer output comes first,
    /// followed by the output of each encoder layer.
    all_outputs: Vec<Tensor>,
}

impl EncoderOutput {
    /// Create an encoder output.
    pub fn new(all_outputs: Vec<Tensor>) -> Self {
        Self { all_outputs }
    }
}

impl LayerOutputs for EncoderOutput {
    /// All layer outputs.
    fn layer_outputs(&self) -> &[Tensor] {
        self.all_outputs.as_slice()
    }

    /// Embedding layer output.
    ///
    /// `None` when no outputs were stored.
    fn embedding_layer_output(&self) -> Option<&Tensor> {
        // The embedding layer output is stored ahead of the encoder
        // layer outputs.
        self.all_outputs.iter().next()
    }
}

/// Trait for encoders.
///
/// An encoder transforms a sequence of piece identifiers into a sequence
/// of hidden representations, one set per layer (see [`EncoderOutput`]).
pub trait Encoder {
    /// Encode an input sequence.
    ///
    /// Returns the encoder output.
    ///
    /// * `piece_ids` - Input sequence.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `attention_mask` - Attention mask. Sequence elements for which the
    ///   corresponding mask element is set to `false` are ignored during
    ///   attention calculation.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `positions` - Input positions. Implementation-defined default
    ///   positions are used when `None`.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `type_ids` - Input type ids. Implementation-defined default type
    ///   ids are used when `None`.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `train` - Whether to train the layer.
    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        positions: Option<&Tensor>,
        type_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<EncoderOutput, BoxedError>;
}

/// Trait for building encoders.
///
/// Implemented by configuration types that can construct an [`Encoder`].
pub trait BuildEncoder: Debug {
    /// Encoder type constructed by this builder.
    type Encoder: Encoder;

    /// Build an encoder.
    ///
    /// * `vb` - Variable builder used to look up or initialize the
    ///   encoder's parameters.
    fn build(&self, vb: VarBuilder) -> Result<Self::Encoder, BoxedError>;
}

impl<D> BuildEncoder for D
where
    D: BuildArchitecture + Debug,
    D::Architecture: Encoder,
{
    type Encoder = D::Architecture;

    fn build(&self, vb: VarBuilder) -> Result<Self::Encoder, BoxedError> {
        self.build(vb)
    }
}

/// Trait for encoder layers.
///
/// A single layer of an [`Encoder`], mapping hidden representations to
/// hidden representations of the same shape.
pub trait EncoderLayer {
    /// Apply the encoder layer to the given hidden representations.
    ///
    /// * `input` - Hidden representations to apply the layer to.
    ///   *Shape:* `(batch_size, seq_len, width)`
    /// * `mask` - Attention mask. Sequence elements for which the
    ///   corresponding mask element is set to `false` are ignored
    ///   during attention calculation.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `positions` - Input positions.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `train` - Whether to train the layer.
    ///
    /// Returns the layer output.
    /// *Shape:* ``(batch_size, seq_len, width)``
    fn forward_t(
        &self,
        input: &Tensor,
        mask: &AttentionMask,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, BoxedError>;
}

/// Trait for building encoder layers.
pub trait BuildEncoderLayer: Debug {
    /// Build an encoder layer.
    ///
    /// * `vb` - Variable builder used to look up or initialize the
    ///   layer's parameters.
    fn build_encoder_layer(&self, vb: VarBuilder) -> Result<Box<dyn EncoderLayer>, BoxedError>;
}