use std::fmt::Debug;
use candle_core::Tensor;
use candle_nn::VarBuilder;
use crate::architectures::{BuildArchitecture, DecoderOutput, LayerOutputs};
use crate::error::BoxedError;
use crate::layers::attention::AttentionMask;
/// Output of a [`CausalLM::forward_t`] call.
///
/// Bundles a [`DecoderOutput`] together with a logits tensor.
/// NOTE(review): presumably the logits are scores over the vocabulary, shaped
/// `(batch, seq_len, vocab_size)` — confirm against the implementations.
pub struct CausalLMOutput {
/// Output of the decoder; its layer outputs are exposed through the
/// [`LayerOutputs`] implementation of this type.
decoder_output: DecoderOutput,
/// Logits produced by the model.
logits: Tensor,
}
impl CausalLMOutput {
pub fn new(decoder_output: DecoderOutput, logits: Tensor) -> Self {
Self {
decoder_output,
logits,
}
}
pub fn decoder_output(&self) -> &DecoderOutput {
&self.decoder_output
}
pub fn logits(&self) -> &Tensor {
&self.logits
}
}
/// Layer outputs of a causal LM are exactly those of its decoder.
impl LayerOutputs for CausalLMOutput {
    /// Forwards to the wrapped decoder output's per-layer outputs.
    fn layer_outputs(&self) -> &[Tensor] {
        let inner = &self.decoder_output;
        inner.layer_outputs()
    }

    /// Forwards to the wrapped decoder output's embedding layer output.
    fn embedding_layer_output(&self) -> Option<&Tensor> {
        let inner = &self.decoder_output;
        inner.embedding_layer_output()
    }
}
/// Builder for a [`CausalLM`].
///
/// Every [`BuildArchitecture`] whose architecture implements [`CausalLM`]
/// implements this trait through the blanket implementation in this module.
pub trait BuildCausalLM: Debug {
/// The causal language model type produced by this builder.
type CausalLM: CausalLM;
/// Build the causal language model using the variables provided by `vb`.
///
/// # Errors
///
/// Returns an error when the model cannot be constructed.
fn build(&self, vb: VarBuilder) -> Result<Self::CausalLM, BoxedError>;
}
/// Blanket implementation: any [`BuildArchitecture`] whose architecture is a
/// [`CausalLM`] can be used directly as a [`BuildCausalLM`].
impl<C> BuildCausalLM for C
where
    C: BuildArchitecture + Debug,
    C::Architecture: CausalLM,
{
    type CausalLM = C::Architecture;

    /// Builds the causal LM by delegating to [`BuildArchitecture::build`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`BuildArchitecture::build`].
    fn build(&self, vb: VarBuilder) -> Result<Self::CausalLM, BoxedError> {
        // Both `BuildCausalLM` and `BuildArchitecture` provide a `build`
        // method for `C`. A bare `self.build(vb)` only resolves to
        // `BuildArchitecture::build` because where-clause bounds on a type
        // parameter are probed before traits in scope. That subtle rule makes
        // the call read like unbounded self-recursion, so name the trait
        // explicitly.
        <C as BuildArchitecture>::build(self, vb)
    }
}
/// A causal language model.
pub trait CausalLM {
/// Cache type passed mutably to [`CausalLM::forward_t`].
///
/// NOTE(review): presumably holds per-layer state (e.g. key/value caches)
/// for incremental decoding — confirm against the implementations.
type Cache;
/// Run the model on `piece_ids`, returning the decoder output and logits.
///
/// * `piece_ids` - Input piece identifiers. NOTE(review): presumably shaped
///   `(batch, seq_len)` — confirm.
/// * `attention_mask` - Attention mask applied to the input.
/// * `cache` - Model cache; taken as `&mut` so implementations can update it.
/// * `positions` - Optional input positions. NOTE(review): presumably the
///   implementation derives positions itself when `None` — confirm.
/// * `train` - Whether the model runs in training mode. NOTE(review):
///   presumably toggles train-only behavior such as dropout — confirm.
///
/// # Errors
///
/// Returns an error when the forward pass fails.
fn forward_t(
&self,
piece_ids: &Tensor,
attention_mask: &AttentionMask,
cache: &mut Self::Cache,
positions: Option<&Tensor>,
train: bool,
) -> Result<CausalLMOutput, BoxedError>;
}