// oxidized_transformers/architectures/causal_lm.rs
use std::fmt::Debug;
use candle_core::Tensor;
use candle_nn::VarBuilder;
use crate::architectures::{BuildArchitecture, DecoderOutput, LayerOutputs};
use crate::error::BoxedError;
use crate::layers::attention::AttentionMask;
/// Causal language model output.
///
/// Pairs the output of the decoder used by the causal language model with
/// the logits of the next predicted token.
pub struct CausalLMOutput {
/// Output of the decoder used by the causal language model.
decoder_output: DecoderOutput,
/// Logits of the next predicted token.
logits: Tensor,
}
impl CausalLMOutput {
/// Create a causal language model output.
pub fn new(decoder_output: DecoderOutput, logits: Tensor) -> Self {
Self {
decoder_output,
logits,
}
}
/// Get the output of the decoder used by the causal language model.
pub fn decoder_output(&self) -> &DecoderOutput {
&self.decoder_output
}
/// Get the logits of the next predicted token.
///
/// The logits are the unnormalized probabilities. Applying softmax to the
/// logits will give the probability distribution of the next token over the
/// vocabulary.
pub fn logits(&self) -> &Tensor {
&self.logits
}
}
/// The layer outputs of a causal language model output are those of the
/// decoder output it wraps.
impl LayerOutputs for CausalLMOutput {
    /// Forward to the wrapped decoder output's layer outputs.
    fn layer_outputs(&self) -> &[Tensor] {
        let decoder = self.decoder_output();
        decoder.layer_outputs()
    }

    /// Forward to the wrapped decoder output's embedding layer output.
    fn embedding_layer_output(&self) -> Option<&Tensor> {
        let decoder = self.decoder_output();
        decoder.embedding_layer_output()
    }
}
/// Trait for building causal language models.
///
/// This file also provides a blanket implementation for every
/// `BuildArchitecture` implementor whose `Architecture` is a `CausalLM`.
pub trait BuildCausalLM: Debug {
/// Type of the causal language model produced by `build`.
type CausalLM: CausalLM;
/// Build a causal language model.
///
/// * `vb` - `VarBuilder` handed to the builder (presumably the source of
/// the model parameters; confirm against implementors).
///
/// Returns `Err` with a `BoxedError` when the implementor fails to build
/// the model.
fn build(&self, vb: VarBuilder) -> Result<Self::CausalLM, BoxedError>;
}
/// Blanket implementation: every architecture builder whose architecture is
/// a causal language model is also a causal language model builder.
impl<C> BuildCausalLM for C
where
    C: BuildArchitecture + Debug,
    C::Architecture: CausalLM,
{
    type CausalLM = C::Architecture;

    /// Build the causal language model by delegating to
    /// `BuildArchitecture::build`.
    ///
    /// * `vb` - `VarBuilder` forwarded unchanged to the architecture builder.
    ///
    /// Any error returned by the architecture builder is returned as-is.
    fn build(&self, vb: VarBuilder) -> Result<Self::CausalLM, BoxedError> {
        // Name the trait explicitly. A bare `self.build(vb)` reads like a
        // recursive call to this very method and only reaches
        // `BuildArchitecture::build` through method-lookup precedence for
        // bounds on type parameters.
        <C as BuildArchitecture>::build(self, vb)
    }
}
/// Trait for causal language models.
pub trait CausalLM {
/// Cache type for the causal language model.
type Cache;
/// Predict the next token using a causal language model.
///
/// Returns a `CausalLMOutput`, which holds the decoder output and the
/// logits of the next predicted token. The cache is not part of the
/// return value: it is passed in as `&mut Self::Cache` (presumably
/// updated in place by implementors; confirm against the implementations).
///
/// * `piece_ids` - Input sequence.
/// *Shape:* `(batch_size, seq_len)`
/// * `attention_mask` - Attention mask. Sequence elements for which the
/// corresponding mask element is set to `false` are ignored during
/// attention calculation.
/// *Shape:* `(batch_size, seq_len)`
/// * `cache` - Cache to avoid recomputing intermediate values.
/// * `positions` - Optional input positions.
/// *Shape:* `(batch_size, seq_len)`
/// * `train` - Whether to train the layer.
///
/// Returns `Err` with a `BoxedError` when the implementor's forward pass
/// fails.
fn forward_t(
&self,
piece_ids: &Tensor,
attention_mask: &AttentionMask,
cache: &mut Self::Cache,
positions: Option<&Tensor>,
train: bool,
) -> Result<CausalLMOutput, BoxedError>;
}