//! Decoder architectures.
//!
//! (`oxidized_transformers/architectures/decoder.rs`)
use std::fmt::Debug;
use crate::architectures::BuildArchitecture;
use candle_core::Tensor;
use candle_nn::VarBuilder;
use crate::architectures::output::LayerOutputs;
use crate::error::BoxedError;
use crate::layers::attention::AttentionMask;
/// Decoder output.
///
/// Holds the hidden representations produced by a decoder: the embedding
/// layer output followed by the output of each decoder layer.
#[derive(Debug)]
pub struct DecoderOutput {
    // Outputs of all layers; the embedding layer output comes first
    // (see `embedding_layer_output`, which returns the first element).
    all_outputs: Vec<Tensor>,
}
impl DecoderOutput {
pub fn new(all_outputs: Vec<Tensor>) -> Self {
Self { all_outputs }
}
}
impl LayerOutputs for DecoderOutput {
    /// All stored layer outputs, embedding layer output first.
    fn layer_outputs(&self) -> &[Tensor] {
        self.all_outputs.as_slice()
    }

    /// Output of the embedding layer, or `None` when no outputs were stored.
    fn embedding_layer_output(&self) -> Option<&Tensor> {
        self.all_outputs.iter().next()
    }
}
/// Trait for building decoders.
///
/// Implementers construct a [`Decoder`] from a variable store.
pub trait BuildDecoder: Debug {
    /// Decoder type that is built by this builder.
    type Decoder: Decoder;

    /// Build a decoder.
    ///
    /// * `vb` - Variable builder used to retrieve the decoder's parameters.
    fn build(&self, vb: VarBuilder) -> Result<Self::Decoder, BoxedError>;
}
impl<D> BuildDecoder for D
where
D: BuildArchitecture + Debug,
D::Architecture: Decoder,
{
type Decoder = D::Architecture;
fn build(&self, vb: VarBuilder) -> Result<Self::Decoder, BoxedError> {
self.build(vb)
}
}
/// Trait for decoders.
pub trait Decoder {
    /// Cache type for the decoder.
    type Cache;

    /// Decode an input sequence.
    ///
    /// Returns the decoder output; `cache` is updated in place so that
    /// intermediate values can be reused by subsequent calls.
    ///
    /// * `piece_ids` - Input sequence.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `attention_mask` - Attention mask. Sequence elements for which the
    ///   corresponding mask element is set to `false` are ignored during
    ///   attention calculation.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `cache` - Cache to avoid recomputing intermediate values.
    /// * `positions` - Input positions.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `train` - Whether to train the layer.
    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<DecoderOutput, BoxedError>;
}
/// Trait for decoder layers.
pub trait DecoderLayer {
    /// Cache type for the decoder.
    ///
    /// The cache can store the intermediate values of the decoder layer,
    /// avoiding recomputation when calling the decoder again for generating
    /// another output.
    type Cache;

    /// Apply the decoder layer to the given hidden representations.
    ///
    /// * `input` - Hidden representations to apply the layer to.
    ///   *Shape:* `(batch_size, seq_len, width)`
    /// * `attention_mask` - Attention mask. Sequence elements for which the
    ///   corresponding mask element is set to `false` are ignored
    ///   during attention calculation.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `cache` - Cache to avoid recomputing intermediate values.
    /// * `positions` - Input positions.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `train` - Whether to train the layer.
    ///
    /// Returns the layer output; `cache` is updated in place.
    /// *Shape:* ``(batch_size, seq_len, width)``
    fn forward_t(
        &self,
        input: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, BoxedError>;
}
/// Trait for building decoder layers.
pub trait BuildDecoderLayer: Debug {
    /// Cache type for the decoder.
    ///
    /// The cache can store the intermediate values of the decoder layer,
    /// avoiding recomputation when calling the decoder again for generating
    /// another output.
    type Cache;

    /// Build a decoder layer.
    ///
    /// The layer is boxed so that decoders can hold heterogeneous layers
    /// behind a single cache type.
    ///
    /// * `vb` - Variable builder used to retrieve the layer's parameters.
    fn build_decoder_layer(
        &self,
        vb: VarBuilder,
    ) -> Result<Box<dyn DecoderLayer<Cache = Self::Cache>>, BoxedError>;
}