//! Encoder architecture: output container and the traits used to build
//! encoders and encoder layers.
use std::fmt::Debug;
use candle_core::Tensor;
use candle_nn::VarBuilder;
use crate::architectures::output::LayerOutputs;
use crate::architectures::BuildArchitecture;
use crate::error::BoxedError;
use crate::layers::attention::AttentionMask;
/// Encoder output.
/// Encoder output.
pub struct EncoderOutput {
    // Hidden states of every layer. The embedding layer output comes
    // first (see `embedding_layer_output`), followed by each encoder
    // layer's output in order.
    all_outputs: Vec<Tensor>,
}
impl EncoderOutput {
/// Create an encoder output.
pub fn new(all_outputs: Vec<Tensor>) -> Self {
Self { all_outputs }
}
}
impl LayerOutputs for EncoderOutput {
    /// Outputs of all layers.
    fn layer_outputs(&self) -> &[Tensor] {
        self.all_outputs.as_slice()
    }

    /// Output of the embedding layer, or `None` when there are no outputs.
    fn embedding_layer_output(&self) -> Option<&Tensor> {
        // The embedding layer output is stored as the first element.
        self.all_outputs.split_first().map(|(first, _)| first)
    }
}
/// Trait for encoders.
/// Trait for encoders.
pub trait Encoder {
    /// Encode an input sequence.
    ///
    /// Returns the encoder output, containing the hidden states of all
    /// layers.
    ///
    /// * `piece_ids` - Input sequence.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `attention_mask` - Attention mask. Sequence elements for which the
    ///   corresponding mask element is set to `false` are ignored during
    ///   attention calculation.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `positions` - Input positions. When `None`, the implementation
    ///   chooses the positions itself.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `type_ids` - Input type ids. When `None`, the implementation
    ///   chooses the type ids itself.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `train` - Whether to train the layer.
    ///
    /// # Errors
    ///
    /// Returns an error when encoding fails.
    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        positions: Option<&Tensor>,
        type_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<EncoderOutput, BoxedError>;
}
/// Trait for building encoders.
/// Trait for building encoders.
pub trait BuildEncoder: Debug {
    /// The type of encoder produced by [`BuildEncoder::build`].
    type Encoder: Encoder;

    /// Build an encoder, loading its parameters from the given variable
    /// builder.
    ///
    /// # Errors
    ///
    /// Returns an error when the encoder cannot be built.
    fn build(&self, vb: VarBuilder) -> Result<Self::Encoder, BoxedError>;
}
impl<D> BuildEncoder for D
where
D: BuildArchitecture + Debug,
D::Architecture: Encoder,
{
type Encoder = D::Architecture;
fn build(&self, vb: VarBuilder) -> Result<Self::Encoder, BoxedError> {
self.build(vb)
}
}
/// Trait for encoder layers.
/// Trait for encoder layers.
pub trait EncoderLayer {
    /// Apply the encoder layer to the given hidden representations.
    ///
    /// * `input` - Hidden representations to apply the layer to.
    ///   *Shape:* `(batch_size, seq_len, width)`
    /// * `mask` - Attention mask. Sequence elements for which the
    ///   corresponding mask element is set to `false` are ignored
    ///   during attention calculation.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `positions` - Input positions. When `None`, the implementation
    ///   chooses the positions itself.
    ///   *Shape:* `(batch_size, seq_len)`
    /// * `train` - Whether to train the layer.
    ///
    /// Returns the layer output.
    /// *Shape:* `(batch_size, seq_len, width)`
    ///
    /// # Errors
    ///
    /// Returns an error when the layer cannot be applied.
    fn forward_t(
        &self,
        input: &Tensor,
        mask: &AttentionMask,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, BoxedError>;
}
/// Trait for building encoder layers.
/// Trait for building encoder layers.
pub trait BuildEncoderLayer: Debug {
    /// Build an encoder layer, loading its parameters from the given
    /// variable builder.
    ///
    /// # Errors
    ///
    /// Returns an error when the layer cannot be built.
    fn build_encoder_layer(&self, vb: VarBuilder) -> Result<Box<dyn EncoderLayer>, BoxedError>;
}