llm_gpt2/lib.rs

//! An implementation of [GPT-2](https://huggingface.co/docs/transformers/model_doc/gpt2) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::path::Path;

use ggml::Tensor;
use llm_base::{
    ggml,
    model::{common, HyperparametersWriteError},
    util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
    LoadError, LoadProgress, ModelParameters, OutputRequest, TokenId, Vocabulary,
};

/// The GPT-2 model. Ref: [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/)
///
/// # Safety
/// This implements [Send] and [Sync] as it is immutable after construction.
pub struct Gpt2 {
    hyperparameters: Hyperparameters,
    n_context_tokens: usize,
    vocabulary: Vocabulary,

    // final layer norm gain and bias
    ln_f_g: Tensor,
    ln_f_b: Tensor,
    // token embeddings
    wte: Tensor,
    // position embeddings
    wpe: Tensor,
    // language-model head, projecting hidden states to vocabulary logits
    lm_head: Tensor,
    // transformer blocks
    layers: Vec<Layer>,

    inference_params: InferenceParameters,

    // the ggml context owns the tensor memory above, so it must be kept alive
    _context: ggml::Context,
}

unsafe impl Send for Gpt2 {}
unsafe impl Sync for Gpt2 {}

impl Gpt2 {
    /// Load a GPT-2 model from the `path` and configure it per the `params`. The status
    /// of the loading process will be reported through `load_progress_callback`. This
    /// is a helper function on top of [llm_base::load].
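    ///
    /// # Example
    ///
    /// A minimal sketch of loading a model. The model path is hypothetical, and
    /// this assumes `ModelParameters` implements `Default` (as in `llm_base`):
    ///
    /// ```no_run
    /// use std::path::Path;
    /// use llm_base::ModelParameters;
    /// use llm_gpt2::Gpt2;
    ///
    /// let gpt2 = Gpt2::load(
    ///     Path::new("path/to/gpt2.bin"), // hypothetical model path
    ///     ModelParameters::default(),
    ///     |_progress| {}, // ignore load-progress updates
    /// )
    /// .expect("failed to load model");
    /// ```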
    pub fn load(
        path: &Path,
        params: ModelParameters,
        load_progress_callback: impl FnMut(LoadProgress),
    ) -> Result<Gpt2, LoadError> {
        llm_base::load(path, params, load_progress_callback)
    }
}

impl KnownModel for Gpt2 {
    type Hyperparameters = Hyperparameters;

    fn new<E: std::error::Error>(
        hyperparameters: Self::Hyperparameters,
        params: ModelParameters,
        vocabulary: Vocabulary,
        tensor_loader: impl llm_base::TensorLoader<E>,
    ) -> Result<Self, E> {
        let mut tl = tensor_loader;

        // load the model-global weights
        let ln_f_g = tl.load("model/ln_f/g")?;
        let ln_f_b = tl.load("model/ln_f/b")?;
        let wte = tl.load("model/wte")?;
        let wpe = tl.load("model/wpe")?;
        let lm_head = tl.load("model/lm_head")?;

        // load the weights for each transformer block
        let mut layers = Vec::new();
        for i in 0..hyperparameters.n_layer {
            let layer = Layer {
                ln_1_g: tl.load(&format!("model/h{i}/ln_1/g"))?,
                ln_1_b: tl.load(&format!("model/h{i}/ln_1/b"))?,
                ln_2_g: tl.load(&format!("model/h{i}/ln_2/g"))?,
                ln_2_b: tl.load(&format!("model/h{i}/ln_2/b"))?,
                c_attn_attn_w: tl.load(&format!("model/h{i}/attn/c_attn/w"))?,
                c_attn_attn_b: tl.load(&format!("model/h{i}/attn/c_attn/b"))?,
                c_attn_proj_w: tl.load(&format!("model/h{i}/attn/c_proj/w"))?,
                c_attn_proj_b: tl.load(&format!("model/h{i}/attn/c_proj/b"))?,
                c_mlp_fc_w: tl.load(&format!("model/h{i}/mlp/c_fc/w"))?,
                c_mlp_fc_b: tl.load(&format!("model/h{i}/mlp/c_fc/b"))?,
                c_mlp_proj_w: tl.load(&format!("model/h{i}/mlp/c_proj/w"))?,
                c_mlp_proj_b: tl.load(&format!("model/h{i}/mlp/c_proj/b"))?,
            };

            layers.push(layer);
        }

        let (_context, _, _mmap) = tl.finish();

        let ModelParameters {
            n_context_tokens,
            inference_parameters: inference_params,
            ..
        } = params;

        Ok(Gpt2 {
            hyperparameters,
            n_context_tokens,
            vocabulary,
            layers,
            ln_f_g,
            ln_f_b,
            wte,
            wpe,
            lm_head,
            inference_params,
            _context,
        })
    }

    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
        InferenceSession::new(
            config,
            self.hyperparameters.n_ctx,
            self.hyperparameters.n_layer,
            self.hyperparameters.n_embd,
            self.hyperparameters.n_vocab,
        )
    }

    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    ) {
        let n = input_tokens.len();
        let n_threads = params.n_threads;

        let Hyperparameters {
            n_embd,
            n_head,
            n_vocab,
            n_layer,
            ..
        } = self.hyperparameters;
        let n_ctx = self.n_context_tokens;

        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);

        let n_past = session.n_past;

        // Positions run from n_past to n_past + n. The tensor is I32, so the
        // buffer must hold i32 values (not usize) for the byte-level copy
        // below to match the tensor's size.
        let mut position_buf: Vec<i32> = Vec::with_capacity(n);
        for position_idx in 0..n {
            position_buf.push((n_past + position_idx) as i32);
        }

        let mut position = ctx0.new_tensor_1d(ggml::Type::I32, n);
        unsafe { position.write_data(bytemuck::cast_slice(&position_buf)) };

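        // The input to the first block is the sum of the token embedding (wte)
        // and the learned absolute position embedding (wpe), as in the
        // original GPT-2.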
        let mut input_layer = ctx0.op_add(
            &ctx0.op_get_rows(&self.wte, &embd),
            &ctx0.op_get_rows(&self.wpe, &position),
        );

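        // Handles into the session's per-layer key/value cache; each layer's
        // keys and values for past tokens are appended here as inference runs.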
        let memory_k = &session.memory_k;
        let memory_k_size = memory_k.element_size();

        let memory_v = &session.memory_v;
        let memory_v_size = memory_v.element_size();

        let mut gf = ggml::ComputationGraph::new(n_threads);

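        // Each transformer block: layer norm -> self-attention -> residual
        // add, then layer norm -> MLP (fc, GELU, projection) -> residual add.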
        for il in 0..n_layer {
            // norm
            let mut current = ctx0.op_norm(&input_layer);
            current = ctx0.op_add(
                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_1_g, &current), &current),
                &ctx0.op_repeat(&self.layers[il].ln_1_b, &current),
            );

            // attn
            current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_attn_attn_b, &current),
                &current,
            );

            // self-attn: c_attn produced one (3 * n_embd, n) activation with
            // Q, K and V concatenated along the rows; slice it into three
            // views without copying
            let nb = current.get_nb()[1];
            let f32_size = std::mem::size_of::<f32>();
            let qcur = ctx0.op_view_2d(&current, (n_embd, n), nb, 0);
            let kcur = ctx0.op_view_2d(&current, (n_embd, n), nb, f32_size * n_embd);
            let vcur = ctx0.op_view_2d(&current, (n_embd, n), nb, f32_size * n_embd * 2);

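            // Store this batch's keys and values into the cache. The cache is
            // laid out per layer as n_ctx rows of n_embd values, so the write
            // offset for layer `il` starts at row `il * n_ctx + n_past`.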
            if n >= 1 {
                let k = ctx0.op_view_1d(
                    memory_k,
                    n * n_embd,
                    (memory_k_size * n_embd) * (il * n_ctx + n_past),
                );
                let v = ctx0.op_view_1d(
                    memory_v,
                    n * n_embd,
                    (memory_v_size * n_embd) * (il * n_ctx + n_past),
                );

                gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
                gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v));
            }

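            // Split Q into heads: copy the (n_embd, n) view into an
            // (n_embd / n_head, n_head, n) tensor and permute it so each
            // head's queries form their own matrix. K is rebuilt the same way
            // from the cache, covering all n_past + n positions.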
            let q = ctx0.op_permute(
                &ctx0.op_cpy(
                    &qcur,
                    &ctx0.new_tensor_3d(ggml::Type::F32, n_embd / n_head, n_head, n),
                ),
                0,
                2,
                1,
                3,
            );

            let k = ctx0.op_permute(
                &ctx0.op_reshape_3d(
                    &ctx0.op_view_1d(
                        &session.memory_k,
                        (n_past + n) * n_embd,
                        il * n_ctx * memory_k_size * n_embd,
                    ),
                    n_embd / n_head,
                    n_head,
                    n_past + n,
                ),
                0,
                2,
                1,
                3,
            );

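            // Scaled dot-product attention: score each query against every
            // key, scale by 1 / sqrt(head_dim), mask out future positions
            // (causal attention), and normalize with a softmax.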
            let kq = ctx0.op_mul_mat(&k, &q);
            let kq_scaled = ctx0.op_scale(
                &kq,
                &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
            );

            let kq_masked = ctx0.op_diag_mask_inf(&kq_scaled, n_past);
            let kq_softmax = ctx0.op_soft_max(&kq_masked);

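            // Read V back from the cache and transpose it so the weighted sum
            // over all n_past + n positions (kqv below) is a single matrix
            // multiply per head; the result is then permuted back and
            // flattened to (n_embd, n).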
            let v_trans = ctx0.op_cpy(
                &ctx0.op_permute(
                    &ctx0.op_reshape_3d(
                        &ctx0.op_view_1d(
                            memory_v,
                            (n_past + n) * n_embd,
                            il * n_ctx * memory_v_size * n_embd,
                        ),
                        n_embd / n_head,
                        n_head,
                        n_past + n,
                    ),
                    1,
                    2,
                    0,
                    3,
                ),
                &ctx0.new_tensor_3d(memory_v.get_type(), n_past + n, n_embd / n_head, n_head),
            );

            let kqv = ctx0.op_mul_mat(&v_trans, &kq_softmax);
            let kqv_merged = ctx0.op_permute(&kqv, 0, 2, 1, 3);

            current = ctx0.op_cpy(&kqv_merged, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n));

            // projection
            current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_attn_proj_b, &current),
                &current,
            );

            // residual connection: add the block's input back in
            current = ctx0.op_add(&current, &input_layer);

            // feed-forward
            let ff_in = current.share();

            // feed-forward normalization
            current = ctx0.op_norm(&ff_in);
            current = ctx0.op_add(
                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_2_g, &current), &current),
                &ctx0.op_repeat(&self.layers[il].ln_2_b, &current),
            );

            // feed-forward fully connected
            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_fc_b, &current),
                &current,
            );

            // feed-forward activation
            current = ctx0.op_gelu(&current);

            // feed-forward projection
            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_proj_b, &current),
                &current,
            );

            // input for the next layer: feed-forward output plus its residual
            input_layer = ctx0.op_add(&current, &ff_in);
        }

        // final layer normalization
        input_layer = ctx0.op_norm(&input_layer);
        input_layer = ctx0.op_add(
            &ctx0.op_mul(&ctx0.op_repeat(&self.ln_f_g, &input_layer), &input_layer),
            &ctx0.op_repeat(&self.ln_f_b, &input_layer),
        );

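        // Project the final hidden states to vocabulary logits. The ggml
        // conversion stores the language-model head as its own tensor, even
        // though the reference GPT-2 ties it to `wte`.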
        input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer);

        // run the computation
        gf.build_forward_expand(&input_layer);
        ctx0.graph_compute(&mut gf);

        // finish evaluation
        common::read_last_token(session, &input_layer, n_vocab, n);
        common::extract_logits(output_request, &input_layer, n_vocab, n);
        common::extract_embeddings(output_request, &embd, n_embd, n);
        common::update_session(session, &ctx0, input_tokens.len(), n);
    }

    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn n_context_tokens(&self) -> usize {
        self.hyperparameters.n_ctx
    }

    fn bot_token_id(&self) -> Option<TokenId> {
        None
    }

    fn eot_token_id(&self) -> TokenId {
        self.vocabulary
            .token_to_id
            .get("<|endoftext|>".as_bytes())
            .copied()
            .unwrap()
    }

    fn inference_parameters(&self) -> &InferenceParameters {
        &self.inference_params
    }
}

/// GPT-2 [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Hyperparameters {
    /// Size of the model's vocabulary
    n_vocab: usize,
    /// Size of the model's context
    n_ctx: usize,
    /// Size of the model's embedding layer
    n_embd: usize,
    /// Number of attention heads
    n_head: usize,
    /// Number of layers in the model
    n_layer: usize,
    /// The file type (tensor data format) of the model
    file_type: FileType,
}
impl llm_base::Hyperparameters for Hyperparameters {
    fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
        let hyperparameters = Hyperparameters {
            n_vocab: util::read_i32(reader)?.try_into()?,
            n_ctx: util::read_i32(reader)?.try_into()?,
            n_embd: util::read_i32(reader)?.try_into()?,
            n_head: util::read_i32(reader)?.try_into()?,
            n_layer: util::read_i32(reader)?.try_into()?,
            file_type: {
                let ftype = util::read_i32(reader)?;
                FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?
            },
        };

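        // ggml GPT-2 files repeat the vocabulary size after the hyperparameter
        // block; check that it agrees with the `n_vocab` read above.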
        let n_vocab = util::read_i32(reader)? as usize;
        if hyperparameters.n_vocab != n_vocab {
            return Err(LoadError::InvariantBroken {
                path: None,
                invariant: format!(
                    "GPT2 model expected n_vocab {}, found {}",
                    hyperparameters.n_vocab, n_vocab
                ),
            });
        }

        Ok(hyperparameters)
    }

    fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
        util::write_i32(writer, self.n_vocab.try_into()?)?;
        util::write_i32(writer, self.n_ctx.try_into()?)?;
        util::write_i32(writer, self.n_embd.try_into()?)?;
        util::write_i32(writer, self.n_head.try_into()?)?;
        util::write_i32(writer, self.n_layer.try_into()?)?;
        util::write_i32(writer, self.file_type.into())?;
        // mirror the duplicated n_vocab that read_ggml expects after the
        // hyperparameter block
        util::write_i32(writer, self.n_vocab.try_into()?)?;

        Ok(())
    }

    fn n_vocabulary(&self) -> usize {
        self.n_vocab
    }
}

struct Layer {
    // normalization
    ln_1_g: Tensor,
    ln_1_b: Tensor,

    ln_2_g: Tensor,
    ln_2_b: Tensor,

    // attention
    c_attn_attn_w: Tensor,
    c_attn_attn_b: Tensor,

    c_attn_proj_w: Tensor,
    c_attn_proj_b: Tensor,

    // mlp
    c_mlp_fc_w: Tensor,
    c_mlp_fc_b: Tensor,

    c_mlp_proj_w: Tensor,
    c_mlp_proj_b: Tensor,
}