use std::fmt::Debug;
use candle_core::Tensor;
use candle_nn::VarBuilder;
use crate::architectures::output::LayerOutputs;
use crate::architectures::BuildArchitecture;
use crate::error::BoxedError;
use crate::layers::attention::AttentionMask;
/// Output of an [`Encoder`].
///
/// Stores the output tensor of every layer. Through its [`LayerOutputs`]
/// implementation, all stored tensors are exposed as layer outputs and the
/// first one is exposed as the embedding layer output.
#[derive(Clone, Debug)]
pub struct EncoderOutput {
    /// Per-layer outputs; the first entry is treated as the embedding output.
    all_outputs: Vec<Tensor>,
}
impl EncoderOutput {
pub fn new(all_outputs: Vec<Tensor>) -> Self {
Self { all_outputs }
}
}
impl LayerOutputs for EncoderOutput {
    /// Borrows the outputs of all layers in the order they were stored.
    fn layer_outputs(&self) -> &[Tensor] {
        self.all_outputs.as_slice()
    }

    /// Borrows the first stored output, which is treated as the embedding
    /// layer output; `None` when no outputs were stored.
    fn embedding_layer_output(&self) -> Option<&Tensor> {
        match self.all_outputs.as_slice() {
            [embeddings, ..] => Some(embeddings),
            [] => None,
        }
    }
}
/// Trait for encoders that map input pieces to per-layer hidden representations.
pub trait Encoder {
    /// Applies the encoder to the input.
    ///
    /// * `piece_ids` - Input piece identifiers
    ///   (NOTE(review): presumably shaped `(batch, seq_len)` — confirm).
    /// * `attention_mask` - Attention mask for the input pieces.
    /// * `positions` - Optional input positions. `None` leaves the choice to
    ///   the implementation.
    /// * `type_ids` - Optional type identifiers for the input pieces.
    /// * `train` - Whether the encoder is applied in training mode (what this
    ///   toggles, e.g. dropout, is implementation-dependent).
    ///
    /// Returns the outputs of the encoder's layers as an [`EncoderOutput`].
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when the implementation fails.
    fn forward_t(
        &self,
        piece_ids: &Tensor,
        attention_mask: &AttentionMask,
        positions: Option<&Tensor>,
        type_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<EncoderOutput, BoxedError>;
}
/// Builder for [`Encoder`]s.
///
/// A blanket implementation makes every [`BuildArchitecture`] whose
/// architecture implements [`Encoder`] usable as a `BuildEncoder`.
pub trait BuildEncoder: Debug {
    /// The type of encoder that this builder constructs.
    type Encoder: Encoder;

    /// Builds the encoder using the variable builder `vb`.
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when the encoder cannot be built.
    fn build(&self, vb: VarBuilder) -> Result<Self::Encoder, BoxedError>;
}
/// Every architecture builder whose architecture is an [`Encoder`] is also an
/// encoder builder.
impl<D> BuildEncoder for D
where
    D: BuildArchitecture + Debug,
    D::Architecture: Encoder,
{
    type Encoder = D::Architecture;

    /// Builds the encoder by delegating to [`BuildArchitecture::build`].
    fn build(&self, vb: VarBuilder) -> Result<Self::Encoder, BoxedError> {
        // Call the architecture builder explicitly. The bare form
        // `self.build(vb)` reads like unconditional recursion into this
        // method. It only resolves to `BuildArchitecture::build` because
        // where-clause bounds take precedence in method resolution.
        <D as BuildArchitecture>::build(self, vb)
    }
}
/// Trait for a single layer of an encoder.
pub trait EncoderLayer {
    /// Applies the layer to `input` and returns the layer's output tensor.
    ///
    /// * `input` - Input hidden representations
    ///   (NOTE(review): shape is not fixed by this trait — confirm per implementation).
    /// * `mask` - Attention mask for the input.
    /// * `positions` - Optional input positions.
    /// * `train` - Whether the layer is applied in training mode.
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when the implementation fails.
    fn forward_t(
        &self,
        input: &Tensor,
        mask: &AttentionMask,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, BoxedError>;
}
/// Builder for [`EncoderLayer`]s.
///
/// The layer is returned as a boxed trait object, so builders of different
/// layer implementations share this interface.
pub trait BuildEncoderLayer: Debug {
    /// Builds an encoder layer using the variable builder `vb`.
    ///
    /// # Errors
    ///
    /// Returns a [`BoxedError`] when the layer cannot be built.
    fn build_encoder_layer(&self, vb: VarBuilder) -> Result<Box<dyn EncoderLayer>, BoxedError>;
}