oxidized_transformers/architectures/
decoder.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
use std::fmt::Debug;

use crate::architectures::BuildArchitecture;
use candle_core::Tensor;
use candle_nn::VarBuilder;

use crate::architectures::output::LayerOutputs;
use crate::error::BoxedError;
use crate::layers::attention::AttentionMask;

/// Decoder output.
///
/// Wraps the tensors that are exposed through the [`LayerOutputs`]
/// implementation of this type. The first tensor is the one reported as the
/// embedding layer output.
pub struct DecoderOutput {
    /// Output tensors, in the order they were passed to [`DecoderOutput::new`].
    all_outputs: Vec<Tensor>,
}

impl DecoderOutput {
    pub fn new(all_outputs: Vec<Tensor>) -> Self {
        Self { all_outputs }
    }
}

impl LayerOutputs for DecoderOutput {
    /// All stored outputs, in insertion order.
    fn layer_outputs(&self) -> &[Tensor] {
        self.all_outputs.as_slice()
    }

    /// The first stored output, or `None` when no outputs were stored.
    fn embedding_layer_output(&self) -> Option<&Tensor> {
        match self.all_outputs.as_slice() {
            [embedding_output, ..] => Some(embedding_output),
            [] => None,
        }
    }
}

/// Trait for building decoders.
///
/// Implementors describe how to construct a concrete [`Decoder`] from the
/// variables provided by a `VarBuilder`.
pub trait BuildDecoder: Debug {
    /// Decoder type that is produced by [`BuildDecoder::build`].
    type Decoder: Decoder;

    /// Build a decoder.
    ///
    /// * `vb` - Variable builder that is handed to the decoder construction.
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when the decoder cannot be constructed.
    fn build(&self, vb: VarBuilder) -> Result<Self::Decoder, BoxedError>;
}

/// Blanket implementation: every architecture builder whose architecture is
/// a [`Decoder`] can be used as a decoder builder.
impl<D> BuildDecoder for D
where
    D: BuildArchitecture + Debug,
    D::Architecture: Decoder,
{
    type Decoder = <D as BuildArchitecture>::Architecture;

    /// Build the decoder by delegating to the architecture builder.
    fn build(&self, vb: VarBuilder) -> Result<Self::Decoder, BoxedError> {
        // Name the trait explicitly so that it is obvious this delegates to
        // `BuildArchitecture::build` and is not a call to this very method.
        <D as BuildArchitecture>::build(self, vb)
    }
}

/// Trait for decoders.
pub trait Decoder {
    /// Cache type for the decoder.
    ///
    /// A value of this type is passed mutably to [`Decoder::forward_t`].
    type Cache;

    /// Decode an input sequence.
    ///
    /// Returns the decoder output. The cache is not returned; it is passed
    /// as a mutable reference so that an implementation can update it in
    /// place.
    ///
    /// * `piece_ids` - Input sequence.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `attention_mask` - Attention mask. Sequence elements for which the
    ///   corresponding mask element is set to `false` are ignored during
    ///   attention calculation.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `cache` - Cache to avoid recomputing intermediate values.
    /// * `positions` - Optional input positions. The behavior for `None` is
    ///   up to the implementation.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `train` - Whether to train the layer.
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when decoding fails.
    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<DecoderOutput, BoxedError>;
}

/// Trait for decoder layers.
pub trait DecoderLayer {
    /// Cache type for the decoder layer.
    ///
    /// The cache can store the intermediate values of the decoder layer,
    /// avoiding recomputation when calling the decoder again for generating
    /// another output.
    type Cache;

    /// Apply the decoder layer to the given hidden representations.
    ///
    /// * `input` - Hidden representations to apply the layer to.
    ///   *Shape:* `(batch_size, seq_len, width)`
    /// * `attention_mask` - Attention mask. Sequence elements for which the
    ///   corresponding mask element is set to `false` are ignored
    ///   during attention calculation.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `cache` - Cache to avoid recomputing intermediate values. It is
    ///   passed as a mutable reference so that an implementation can update
    ///   it in place.
    /// * `positions` - Optional input positions. The behavior for `None` is
    ///   up to the implementation.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `train` - Whether to train the layer.
    ///
    /// Returns the layer output. The cache is not part of the return value.
    /// *Shape:* `(batch_size, seq_len, width)`
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when applying the layer fails.
    fn forward_t(
        &self,
        input: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, BoxedError>;
}

/// Trait for building decoder layers.
pub trait BuildDecoderLayer: Debug {
    /// Cache type for the decoder layer that is built.
    ///
    /// The cache can store the intermediate values of the decoder layer,
    /// avoiding recomputation when calling the decoder again for generating
    /// another output.
    type Cache;

    /// Build a decoder layer.
    ///
    /// The layer is returned as a boxed trait object whose cache type is
    /// the [`BuildDecoderLayer::Cache`] of this builder.
    ///
    /// * `vb` - Variable builder that is handed to the layer construction.
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when the layer cannot be constructed.
    fn build_decoder_layer(
        &self,
        vb: VarBuilder,
    ) -> Result<Box<dyn DecoderLayer<Cache = Self::Cache>>, BoxedError>;
}