use std::fmt::Debug;
use crate::architectures::BuildArchitecture;
use candle_core::Tensor;
use candle_nn::VarBuilder;
use crate::architectures::output::LayerOutputs;
use crate::error::BoxedError;
use crate::layers::attention::AttentionMask;
/// Output of a decoder.
///
/// Wraps the tensors passed to [`DecoderOutput::new`]. Through the
/// [`LayerOutputs`] implementation the whole list is exposed as the layer
/// outputs and its first element as the embedding layer output.
// `Debug` is derived so that this public type can be inspected in logs,
// assertions and `{:?}` formatting of types that contain it.
#[derive(Debug)]
pub struct DecoderOutput {
    /// Output tensors, in the order they were given to the constructor.
    all_outputs: Vec<Tensor>,
}
impl DecoderOutput {
pub fn new(all_outputs: Vec<Tensor>) -> Self {
Self { all_outputs }
}
}
impl LayerOutputs for DecoderOutput {
    /// All stored output tensors, in storage order.
    fn layer_outputs(&self) -> &[Tensor] {
        self.all_outputs.as_slice()
    }

    /// The first stored tensor, or `None` when no outputs were stored.
    fn embedding_layer_output(&self) -> Option<&Tensor> {
        match self.all_outputs.as_slice() {
            [embeddings, ..] => Some(embeddings),
            [] => None,
        }
    }
}
/// Trait for types that can construct a decoder.
pub trait BuildDecoder: Debug {
    /// The decoder type produced by [`BuildDecoder::build`].
    type Decoder: Decoder;

    /// Build a decoder.
    ///
    /// * `vb` - Variable builder handed to the decoder construction
    ///   (presumably the source of the model parameters; confirm against
    ///   implementors).
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when construction fails.
    fn build(&self, vb: VarBuilder) -> Result<Self::Decoder, BoxedError>;
}
/// Blanket implementation: any architecture builder whose architecture
/// implements [`Decoder`] is also a decoder builder.
impl<D> BuildDecoder for D
where
    D: BuildArchitecture + Debug,
    D::Architecture: Decoder,
{
    type Decoder = D::Architecture;

    /// Build the decoder by delegating to the underlying architecture
    /// builder.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the architecture builder.
    fn build(&self, vb: VarBuilder) -> Result<Self::Decoder, BoxedError> {
        // Name the trait explicitly. A plain `self.build(vb)` only reaches
        // `BuildArchitecture::build` because of method-probing precedence for
        // bounds on type parameters; spelled out, the call can neither be
        // mistaken for nor silently turn into recursion on this method.
        <D as BuildArchitecture>::build(self, vb)
    }
}
/// Trait implemented by decoder models.
pub trait Decoder {
    /// Cache type that is passed mutably through [`Decoder::forward_t`].
    type Cache;

    /// Run the decoder on the given input.
    ///
    /// * `piece_ids` - Input tensor (by its name, piece identifiers; shape
    ///   and dtype are not fixed by this trait — confirm with implementors).
    /// * `attention_mask` - Attention mask for the input.
    /// * `cache` - Cache that the implementation may read and update.
    /// * `positions` - Optional tensor of positions; `None` leaves the choice
    ///   to the implementation.
    /// * `train` - Flag forwarded to the implementation (presumably selects
    ///   training-time behavior — confirm with implementors).
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when the forward pass fails.
    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<DecoderOutput, BoxedError>;
}
/// Trait implemented by a single decoder layer.
pub trait DecoderLayer {
    /// Cache type that is passed mutably through [`DecoderLayer::forward_t`].
    type Cache;

    /// Apply the layer to the given input.
    ///
    /// * `input` - Input tensor of the layer (shape is not fixed by this
    ///   trait — confirm with implementors).
    /// * `attention_mask` - Attention mask for the input.
    /// * `cache` - Cache that the implementation may read and update.
    /// * `positions` - Optional tensor of positions; `None` leaves the choice
    ///   to the implementation.
    /// * `train` - Flag forwarded to the implementation (presumably selects
    ///   training-time behavior — confirm with implementors).
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when the layer cannot be applied.
    fn forward_t(
        &self,
        input: &Tensor,
        attention_mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, BoxedError>;
}
/// Trait for types that can construct a decoder layer.
pub trait BuildDecoderLayer: Debug {
    /// Cache type used by the layers that this builder produces.
    type Cache;

    /// Build a decoder layer as a boxed trait object.
    ///
    /// * `vb` - Variable builder handed to the layer construction
    ///   (presumably the source of the layer parameters; confirm against
    ///   implementors).
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when construction fails.
    fn build_decoder_layer(
        &self,
        vb: VarBuilder,
    ) -> Result<Box<dyn DecoderLayer<Cache = Self::Cache>>, BoxedError>;
}