Trait Decoder

pub trait Decoder {
    type Cache;

    // Required method
    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<DecoderOutput, BoxedError>;
}

Trait for decoders.

Required Associated Types

type Cache

Cache type for the decoder.

Required Methods

fn forward_t(&self, piece_ids: &Tensor, attention_mask: &AttentionMask, cache: &mut Self::Cache, positions: Option<&Tensor>, train: bool) -> Result<DecoderOutput, BoxedError>

Decode an input sequence.

Returns the decoder output. The cache is not part of the return value; it is updated in place through the mutable cache argument.

  • piece_ids - Input sequence. Shape: (batch_size, seq_len)
  • attention_mask - Attention mask. Sequence elements for which the corresponding mask element is set to false are ignored during attention calculation. Shape: (batch_size, seq_len)
  • cache - Cache to avoid recomputing intermediate values.
  • positions - Input positions. Shape: (batch_size, seq_len)
  • train - Whether to run the decoder in training mode.
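
The following is a minimal, std-only sketch of how an implementor and a caller might use this signature for incremental decoding. It is not the crate's implementation. Tensor, AttentionMask, DecoderOutput and BoxedError are toy stand-ins defined locally so the example compiles on its own, and CountingDecoder and decode_in_two_steps are hypothetical names introduced only for illustration. The toy cache records how many positions have been processed, so that a later call without explicit positions continues where the previous call stopped.

```rust
use std::error::Error;

// Stand-ins for the crate's types (assumptions, for illustration only).
pub type BoxedError = Box<dyn Error + Send + Sync>;

/// Toy tensor: rows of integers, shape (batch_size, seq_len).
pub struct Tensor(pub Vec<Vec<i64>>);

/// Toy mask with the same shape as the input; `false` marks ignored elements.
pub struct AttentionMask(pub Vec<Vec<bool>>);

/// Toy output: the positions used and the number of unmasked elements per row.
pub struct DecoderOutput {
    pub positions: Vec<Vec<i64>>,
    pub attended: Vec<usize>,
}

pub trait Decoder {
    type Cache;

    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<DecoderOutput, BoxedError>;
}

/// Hypothetical decoder whose cache only counts the positions seen so far.
pub struct CountingDecoder;

impl Decoder for CountingDecoder {
    type Cache = usize;

    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        _train: bool,
    ) -> Result<DecoderOutput, BoxedError> {
        // Both inputs are documented as (batch_size, seq_len).
        let same_shape = piece_ids.0.len() == attention_mask.0.len()
            && piece_ids
                .0
                .iter()
                .zip(&attention_mask.0)
                .all(|(ids, mask)| ids.len() == mask.len());
        if !same_shape {
            return Err("piece_ids and attention_mask must have the same shape".into());
        }

        let seq_len = piece_ids.0.first().map_or(0, |row| row.len());

        // Use explicit positions if given, otherwise continue after the cache.
        let positions = match positions {
            Some(explicit) => explicit.0.clone(),
            None => {
                let start = *cache as i64;
                let row: Vec<i64> = (start..start + seq_len as i64).collect();
                vec![row; piece_ids.0.len()]
            }
        };

        // Count the elements that are not masked out in each row.
        let attended = attention_mask
            .0
            .iter()
            .map(|mask| mask.iter().filter(|&&keep| keep).count())
            .collect();

        // The cache is updated in place rather than returned.
        *cache += seq_len;

        Ok(DecoderOutput { positions, attended })
    }
}

/// Feed a three-piece prompt, then a single new piece, reusing the cache.
pub fn decode_in_two_steps() -> Result<(Vec<i64>, usize), BoxedError> {
    let decoder = CountingDecoder;
    let mut cache = 0usize;

    let prompt = Tensor(vec![vec![11, 12, 13]]);
    let prompt_mask = AttentionMask(vec![vec![true, true, true]]);
    decoder.forward_t(&prompt, &prompt_mask, &mut cache, None, false)?;

    let next = Tensor(vec![vec![14]]);
    let next_mask = AttentionMask(vec![vec![true]]);
    let output = decoder.forward_t(&next, &next_mask, &mut cache, None, false)?;

    Ok((output.positions[0].clone(), cache))
}
```

In this sketch the second call passes only the new piece, and the position it receives is derived from the cache. A real implementor would typically store per-layer intermediate values in Cache rather than a counter.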

Implementors