pub trait Decoder {
    type Cache;

    // Required method
    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<DecoderOutput, BoxedError>;
}
Trait for decoders.
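To make the role of the associated `Cache` type concrete, here is a minimal, self-contained sketch. The crate's `Tensor`, `AttentionMask`, `DecoderOutput` and `BoxedError` are replaced by hypothetical plain-Rust stand-ins (nested `Vec`s and a boxed error), and `RunningSumDecoder` is a hypothetical toy that performs no attention at all. It only shows how an implementation chooses its own cache type and carries state from one `forward_t` call to the next.

```rust
// Hypothetical stand-ins for the crate's real types, so the sketch compiles alone.
pub type Tensor = Vec<Vec<i64>>; // shape: (batch_size, seq_len)
pub type AttentionMask = Vec<Vec<bool>>; // shape: (batch_size, seq_len)
pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;

#[derive(Debug)]
pub struct DecoderOutput {
    // Hypothetical field: one output value per input piece.
    pub hidden: Vec<Vec<i64>>,
}

pub trait Decoder {
    type Cache;

    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<DecoderOutput, BoxedError>;
}

/// Hypothetical toy decoder: for each piece it outputs the running sum of the
/// unmasked piece ids seen so far in that sequence, including pieces from
/// earlier calls, whose sum is kept in the cache.
pub struct RunningSumDecoder;

impl Decoder for RunningSumDecoder {
    // One running sum per sequence in the batch.
    type Cache = Vec<i64>;

    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        _positions: Option<&Tensor>,
        _train: bool,
    ) -> Result<DecoderOutput, BoxedError> {
        if piece_ids.len() != attention_mask.len() {
            return Err("piece_ids and attention_mask batch sizes differ".into());
        }
        // An empty cache means this is the first call: start every sum at zero.
        if cache.is_empty() {
            cache.resize(piece_ids.len(), 0);
        }
        if cache.len() != piece_ids.len() {
            return Err("cache batch size differs from piece_ids".into());
        }
        let mut hidden = Vec::with_capacity(piece_ids.len());
        for (b, (ids, mask)) in piece_ids.iter().zip(attention_mask).enumerate() {
            let mut row = Vec::with_capacity(ids.len());
            for (&id, &keep) in ids.iter().zip(mask) {
                // Masked (false) pieces do not contribute to the state.
                if keep {
                    cache[b] += id;
                }
                row.push(cache[b]);
            }
            hidden.push(row);
        }
        Ok(DecoderOutput { hidden })
    }
}

fn main() -> Result<(), BoxedError> {
    let decoder = RunningSumDecoder;
    let mut cache = Vec::new();
    // First call: two sequences, the second right-padded with a masked piece.
    let out = decoder.forward_t(
        &vec![vec![1, 2, 3], vec![4, 5, 0]],
        &vec![vec![true, true, true], vec![true, true, false]],
        &mut cache,
        None,
        false,
    )?;
    println!("{:?}", out.hidden); // [[1, 3, 6], [4, 9, 9]]
    // Second call: the cache already holds the earlier state, so only new pieces are processed.
    let out = decoder.forward_t(
        &vec![vec![10], vec![1]],
        &vec![vec![true], vec![true]],
        &mut cache,
        None,
        false,
    )?;
    println!("{:?}", out.hidden); // [[16], [10]]
    Ok(())
}
```

Because `Cache` is an associated type, each decoder decides what state it needs to keep (here, a single running sum per sequence), and the caller only has to keep the value alive between calls.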
Required Associated Types
type Cache
Required Methods
fn forward_t(
    &self,
    piece_ids: &Tensor,
    attention_mask: &AttentionMask,
    cache: &mut Self::Cache,
    positions: Option<&Tensor>,
    train: bool,
) -> Result<DecoderOutput, BoxedError>
Decode an input sequence.
Returns the decoder output. The cache is passed by mutable reference, so it is updated in place rather than returned.
- piece_ids: Input sequence. Shape: (batch_size, seq_len)
- attention_mask: Attention mask. Sequence elements for which the corresponding mask element is set to false are ignored during attention calculation. Shape: (batch_size, seq_len)
- cache: Cache to avoid recomputing intermediate values.
- positions: Input positions. Shape: (batch_size, seq_len)
- train: Whether the layer is run in training mode.
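The shape conventions above can be illustrated with plain nested vectors standing in for tensors. The helpers `pad_batch` and `positions` below are hypothetical and not part of this crate: the first right-pads a batch of variable-length sequences and builds a (batch_size, seq_len) mask that is false at padding, and the second builds a (batch_size, seq_len) positions array starting from an arbitrary offset. How the real decoder derives positions when `positions` is `None` is not shown here.

```rust
/// Hypothetical helper: right-pads `seqs` to a common length with `pad_id`
/// and returns (piece_ids, attention_mask), both of shape (batch_size, seq_len).
/// Mask entries are `true` for real pieces and `false` for padding.
fn pad_batch(seqs: &[Vec<i64>], pad_id: i64) -> (Vec<Vec<i64>>, Vec<Vec<bool>>) {
    let seq_len = seqs.iter().map(Vec::len).max().unwrap_or(0);
    let mut piece_ids = Vec::with_capacity(seqs.len());
    let mut mask = Vec::with_capacity(seqs.len());
    for seq in seqs {
        let mut ids = seq.clone();
        ids.resize(seq_len, pad_id);
        let mut m = vec![true; seq.len()];
        m.resize(seq_len, false);
        piece_ids.push(ids);
        mask.push(m);
    }
    (piece_ids, mask)
}

/// Hypothetical helper: positions of shape (batch_size, seq_len) where every
/// row counts up from `offset`.
fn positions(batch_size: usize, seq_len: usize, offset: i64) -> Vec<Vec<i64>> {
    (0..batch_size)
        .map(|_| (offset..offset + seq_len as i64).collect::<Vec<i64>>())
        .collect()
}

fn main() {
    let (piece_ids, attention_mask) = pad_batch(&[vec![5, 6, 7], vec![8]], 0);
    println!("{:?}", piece_ids); // [[5, 6, 7], [8, 0, 0]]
    println!("{:?}", attention_mask); // [[true, true, true], [true, false, false]]
    println!("{:?}", positions(2, 3, 0)); // [[0, 1, 2], [0, 1, 2]]
}
```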