oxidized_transformers/models/transformer/
causal_lm.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use candle_core::{ModuleT, Tensor};
use candle_nn::{linear_no_bias, VarBuilder};
use snafu::{ResultExt, Snafu};

use crate::architectures::BuildArchitecture;
use crate::architectures::{BuildDecoder, CausalLM, CausalLMOutput, Decoder, LayerOutputs};
use crate::error::BoxedError;
use crate::kv_cache::KeyValueCache;
use crate::layers::attention::AttentionMask;
use crate::models::transformer::{TransformerDecoder, TransformerDecoderConfig};

/// Transformer causal language model configuration.
///
/// The fields are private; they are set through the builder-style methods
/// of this type and start out with the values provided by its `Default`
/// implementation.
#[derive(Debug)]
pub struct TransformerCausalLMConfig {
    /// Configuration used to build the decoder.
    decoder: Box<dyn BuildDecoder<Decoder = TransformerDecoder>>,
    /// Hidden size; passed as the input dimension of the output embeddings.
    hidden_size: usize,
    /// Number of pieces in the output vocabulary; passed as the output
    /// dimension of the output embeddings.
    n_pieces: usize,
}

impl TransformerCausalLMConfig {
    /// Set the configuration that is used to build the decoder.
    ///
    /// Consumes the configuration and returns it with the decoder replaced.
    ///
    /// Default: `TransformerDecoderConfig`.
    pub fn decoder(self, decoder: Box<dyn BuildDecoder<Decoder = TransformerDecoder>>) -> Self {
        Self { decoder, ..self }
    }

    /// Set the hidden size.
    ///
    /// Consumes the configuration and returns it with the hidden size
    /// replaced.
    ///
    /// Default: 4096.
    pub fn hidden_size(self, hidden_size: usize) -> Self {
        Self {
            hidden_size,
            ..self
        }
    }

    /// Set the number of pieces in the output vocabulary.
    ///
    /// Consumes the configuration and returns it with the number of pieces
    /// replaced.
    ///
    /// Default: 32000.
    pub fn n_pieces(self, n_pieces: usize) -> Self {
        Self { n_pieces, ..self }
    }
}

impl Default for TransformerCausalLMConfig {
    /// Create the default configuration: a default `TransformerDecoderConfig`
    /// as the decoder, a hidden size of 4096 and 32000 vocabulary pieces.
    fn default() -> Self {
        // Default width of the hidden representations.
        const DEFAULT_HIDDEN_SIZE: usize = 4096;
        // Default size of the output vocabulary.
        const DEFAULT_N_PIECES: usize = 32000;

        let decoder = Box::<TransformerDecoderConfig>::default();

        Self {
            decoder,
            hidden_size: DEFAULT_HIDDEN_SIZE,
            n_pieces: DEFAULT_N_PIECES,
        }
    }
}

impl BuildArchitecture for TransformerCausalLMConfig {
    type Architecture = TransformerCausalLM;

    /// Build a `TransformerCausalLM` from this configuration.
    ///
    /// The decoder is built with the `decoder` prefix pushed onto `vb`, the
    /// output embeddings with the `output_embeddings` prefix.
    ///
    /// # Errors
    ///
    /// Returns `TransformerCausalLMError::BuildDecoder` when the decoder
    /// cannot be built and `TransformerCausalLMError::BuildEmbeddings` when
    /// the output embedding layer cannot be built.
    fn build(&self, vb: VarBuilder) -> Result<Self::Architecture, BoxedError> {
        // The decoder is built first, so its failure is reported before any
        // attempt to build the output embeddings.
        let decoder = self
            .decoder
            .build(vb.push_prefix("decoder"))
            .context(BuildDecoderSnafu)?;

        let output_embeddings = linear_no_bias(
            self.hidden_size,
            self.n_pieces,
            vb.push_prefix("output_embeddings"),
        )
        .context(BuildEmbeddingsSnafu)?;

        Ok(TransformerCausalLM {
            decoder: Box::new(decoder),
            output_embeddings: Box::new(output_embeddings),
        })
    }
}

/// `TransformerCausalLM` errors.
///
/// Covers failures while building the model as well as failures during the
/// forward pass.
#[derive(Debug, Snafu)]
pub enum TransformerCausalLMError {
    /// The decoder could not be built from its configuration.
    #[snafu(display("Cannot build decoder"))]
    BuildDecoder { source: BoxedError },

    /// The output embedding layer could not be built.
    #[snafu(display("Cannot build output embeddings"))]
    BuildEmbeddings { source: candle_core::Error },

    /// The decoder output did not contain any layer outputs, so there is
    /// nothing to compute logits from.
    #[snafu(display("Nothing to decode, the decoder does not have any layer outputs"))]
    NoLayers,

    /// Applying the output embeddings to the last layer output failed.
    #[snafu(display("Cannot compute logits for the output vocabulary"))]
    ComputeLogits { source: candle_core::Error },
}

/// Transformer-based causal language model.
///
/// Combines a decoder with an output embedding module that is applied to the
/// decoder's last layer output to compute logits.
pub struct TransformerCausalLM {
    /// Decoder that produces the layer outputs; uses `KeyValueCache` as its
    /// cache type.
    decoder: Box<dyn Decoder<Cache = KeyValueCache>>,
    /// Module applied to the last layer output to compute the logits.
    output_embeddings: Box<dyn ModuleT>,
}

impl CausalLM for TransformerCausalLM {
    type Cache = KeyValueCache;

    /// Run the decoder and compute logits from its last layer output.
    ///
    /// All arguments are forwarded to the decoder unchanged; `train` is also
    /// passed on to the output embeddings.
    ///
    /// # Errors
    ///
    /// Propagates any error from the decoder. Returns
    /// `TransformerCausalLMError::NoLayers` when the decoder output has no
    /// layer outputs and `TransformerCausalLMError::ComputeLogits` when
    /// applying the output embeddings fails.
    fn forward_t(
        &self,
        piece_ids: &Tensor,
        mask: &AttentionMask,
        cache: &mut Self::Cache,
        positions: Option<&Tensor>,
        train: bool,
    ) -> Result<CausalLMOutput, BoxedError> {
        let decoder_output = self
            .decoder
            .forward_t(piece_ids, mask, cache, positions, train)?;

        // The logits are computed from the final layer output only.
        let Some(final_hidden) = decoder_output.layer_outputs().last() else {
            return Err(TransformerCausalLMError::NoLayers.into());
        };

        let logits = self
            .output_embeddings
            .forward_t(final_hidden, train)
            .context(ComputeLogitsSnafu)?;

        Ok(CausalLMOutput::new(decoder_output, logits))
    }
}