llm_gptj/lib.rs

//! An implementation of [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::{error::Error, path::Path};

use ggml::Tensor;
use llm_base::{
    ggml,
    model::{common, HyperparametersWriteError},
    util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
    LoadError, LoadProgress, Mmap, ModelParameters, OutputRequest, TensorLoader, TokenId,
    Vocabulary,
};

/// The GPT-J model. Ref: [GitHub](https://github.com/kingoflolz/mesh-transformer-jax/#gpt-j-6b)
///
/// # Safety
/// This implements [Send] and [Sync] as it is immutable after construction.
pub struct GptJ {
    hyperparameters: Hyperparameters,
    n_context_tokens: usize,

    vocabulary: Vocabulary,

    // normalization
    ln_f_g: Tensor,
    ln_f_b: Tensor,

    // token embedding
    wte: Tensor,

    // language model head & bias
    lmh_g: Tensor,
    lmh_b: Tensor,

    layers: Vec<Layer>,

    inference_parameters: InferenceParameters,

    /// Needs to be kept alive while the model is alive
    _mmap: Option<Mmap>,

    // Must be kept alive for the model
    _context: ggml::Context,
}

unsafe impl Send for GptJ {}
unsafe impl Sync for GptJ {}

impl GptJ {
    /// Load a GPT-J model from the `path` and configure it per the `params`. The status
    /// of the loading process will be reported through `load_progress_callback`. This
    /// is a helper function on top of [llm_base::load].
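    ///
    /// A minimal usage sketch (assuming `ModelParameters` provides a usable `Default`
    /// and that a GGML-format GPT-J file exists at the hypothetical path below); the
    /// progress callback is simply ignored here:
    ///
    /// ```no_run
    /// use std::path::Path;
    /// use llm_gptj::GptJ;
    ///
    /// let model = GptJ::load(
    ///     Path::new("path/to/gpt-j.bin"),
    ///     Default::default(),
    ///     |_progress| {},
    /// )
    /// .expect("failed to load GPT-J model");
    /// ```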
    pub fn load(
        path: &Path,
        params: ModelParameters,
        load_progress_callback: impl FnMut(LoadProgress),
    ) -> Result<GptJ, LoadError> {
        llm_base::load(path, params, load_progress_callback)
    }
}

impl KnownModel for GptJ {
    type Hyperparameters = Hyperparameters;

    fn new<E: Error>(
        hyperparameters: Self::Hyperparameters,
        params: ModelParameters,
        vocabulary: Vocabulary,
        tensor_loader: impl TensorLoader<E>,
    ) -> Result<Self, E>
    where
        Self: Sized,
    {
        let mut tl = tensor_loader;

        // prepare memory for weights
        let wte = tl.load("transformer.wte.weight")?;
        let ln_f_g = tl.load("transformer.ln_f.weight")?;
        let ln_f_b = tl.load("transformer.ln_f.bias")?;
        let lmh_g = tl.load("lm_head.weight")?;
        let lmh_b = tl.load("lm_head.bias")?;

        let mut layers = Vec::new();
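        // Tensor names follow the Hugging Face GPT-J checkpoint layout
        // (`transformer.h.{i}.*` for the per-layer weights).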
        for i in 0..hyperparameters.n_layer {
            let layer = Layer {
                ln_1_g: tl.load(&format!("transformer.h.{i}.ln_1.weight"))?,
                ln_1_b: tl.load(&format!("transformer.h.{i}.ln_1.bias"))?,
                c_attn_q_proj_w: tl.load(&format!("transformer.h.{i}.attn.q_proj.weight"))?,
                c_attn_k_proj_w: tl.load(&format!("transformer.h.{i}.attn.k_proj.weight"))?,
                c_attn_v_proj_w: tl.load(&format!("transformer.h.{i}.attn.v_proj.weight"))?,
                c_attn_proj_w: tl.load(&format!("transformer.h.{i}.attn.out_proj.weight"))?,
                c_mlp_fc_w: tl.load(&format!("transformer.h.{i}.mlp.fc_in.weight"))?,
                c_mlp_fc_b: tl.load(&format!("transformer.h.{i}.mlp.fc_in.bias"))?,
                c_mlp_proj_w: tl.load(&format!("transformer.h.{i}.mlp.fc_out.weight"))?,
                c_mlp_proj_b: tl.load(&format!("transformer.h.{i}.mlp.fc_out.bias"))?,
            };

            layers.push(layer);
        }

        let (_context, _, _mmap) = tl.finish();

        let ModelParameters {
            n_context_tokens,
            inference_parameters,
            ..
        } = params;

        Ok(GptJ {
            hyperparameters,
            n_context_tokens,
            vocabulary,
            ln_f_g,
            ln_f_b,
            wte,
            lmh_g,
            lmh_b,
            layers,
            inference_parameters,
            _mmap,
            _context,
        })
    }

    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
        InferenceSession::new(
            config,
            self.hyperparameters.n_ctx,
            self.hyperparameters.n_layer,
            self.hyperparameters.n_embd,
            self.hyperparameters.n_vocab,
        )
    }

    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    ) {
        let n = input_tokens.len();
        let n_threads = params.n_threads;

        let Hyperparameters {
            n_embd,
            n_head,
            n_vocab,
            n_layer,
            n_rot,
            ..
        } = self.hyperparameters;
        let n_ctx = self.n_context_tokens;

        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);

        let n_past = session.n_past;

        // wte
        let mut input_layer = ctx0.op_get_rows(&self.wte, &embd);

        let memory_k = &session.memory_k;
        let memory_k_size = memory_k.element_size();

        let memory_v = &session.memory_v;
        let memory_v_size = memory_v.element_size();

        let mut gf = ggml::ComputationGraph::new(n_threads);

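        // Each GPT-J block applies a single layer norm whose output feeds both the
        // self-attention branch and the feed-forward (MLP) branch; the two branch
        // outputs and the residual stream are summed at the end of the block.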
        for il in 0..n_layer {
            // norm
            let mut current = ctx0.op_norm(&input_layer);
            current = ctx0.op_add(
                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_1_g, &current), &current),
                &ctx0.op_repeat(&self.layers[il].ln_1_b, &current),
            );

            let input_sa = current.share();

            // self-attention
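            // GPT-J applies rotary position embeddings (RoPE): the first `n_rot`
            // dimensions of every query/key head are rotated by an angle derived
            // from the token's absolute position (starting at `n_past`).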
            let qcur = ctx0.op_rope(
                &ctx0.op_reshape_3d(
                    &ctx0.op_mul_mat(&self.layers[il].c_attn_q_proj_w, &current),
                    n_embd / n_head,
                    n_head,
                    n,
                ),
                n_past,
                n_rot,
                0,
            );
            let kcur = ctx0.op_rope(
                &ctx0.op_reshape_3d(
                    &ctx0.op_mul_mat(&self.layers[il].c_attn_k_proj_w, &current),
                    n_embd / n_head,
                    n_head,
                    n,
                ),
                n_past,
                n_rot,
                0,
            );

            // self-attention store key and value to memory
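            // Each layer owns an n_ctx * n_embd slice of the key/value cache; the views
            // below select this layer's slice, offset by `n_past` so the keys/values for
            // the new tokens are appended after those already cached.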
            let vcur =
                ctx0.op_transpose(&ctx0.op_mul_mat(&self.layers[il].c_attn_v_proj_w, &current));

            let k = ctx0.op_view_1d(
                memory_k,
                n * n_embd,
                (memory_k_size * n_embd) * (il * n_ctx + n_past),
            );
            let v = ctx0.op_view_2d(
                memory_v,
                (n, n_embd),
                n_ctx * memory_v_size,
                (il * n_ctx) * memory_v_size * n_embd + n_past * memory_v_size,
            );

            gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
            gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v));

            let q = ctx0.op_permute(&qcur, 0, 2, 1, 3);
            let big_k = ctx0.op_permute(
                &ctx0.op_reshape_3d(
                    &ctx0.op_view_1d(
                        memory_k,
                        (n_past + n) * n_embd,
                        il * n_ctx * memory_k_size * n_embd,
                    ),
                    n_embd / n_head,
                    n_head,
                    n_past + n,
                ),
                0,
                2,
                1,
                3,
            );

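            // Scaled dot-product attention over the cached keys: Q·Kᵀ is scaled by
            // 1/sqrt(d_head), causally masked so tokens cannot attend to later
            // positions, and normalised with a softmax.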
            let kq = ctx0.op_mul_mat(&big_k, &q);
            let kq_scaled = ctx0.op_scale(
                &kq,
                &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
            );

            let kq_masked = ctx0.op_diag_mask_inf(&kq_scaled, n_past);
            let kq_softmax = ctx0.op_soft_max(&kq_masked);

            let big_v = ctx0.op_view_3d(
                memory_v,
                (n_past + n, n_embd / n_head, n_head),
                (
                    n_ctx * memory_v_size,
                    n_ctx * memory_v_size * n_embd / n_head,
                ),
                il * n_ctx * memory_v_size * n_embd,
            );

            let kqv = ctx0.op_mul_mat(&big_v, &kq_softmax);
            let kqv_merged = ctx0.op_permute(&kqv, 0, 2, 1, 3);

            current = ctx0.op_cpy(&kqv_merged, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n));

            // self-attention projection
            current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, &current);

            // feed-forward
            let ff_in = current.share();

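            // Note: per GPT-J's parallel formulation, the MLP reads from `input_sa`
            // (the layer-norm output) rather than from the attention output; `ff_in`
            // holds the attention output and is added back in below.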
            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, &input_sa);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_fc_b, &current),
                &current,
            );

            current = ctx0.op_gelu(&current);

            // feed-forward projection
            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_proj_b, &current),
                &current,
            );

            current = ctx0.op_add(&current, &ff_in);

            // input for next layer
            input_layer = ctx0.op_add(&current, &input_layer);
        }

        // norm
        input_layer = ctx0.op_norm(&input_layer);
        input_layer = ctx0.op_add(
            &ctx0.op_mul(&ctx0.op_repeat(&self.ln_f_g, &input_layer), &input_layer),
            &ctx0.op_repeat(&self.ln_f_b, &input_layer),
        );

        // lm_head
        input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer);
        input_layer = ctx0.op_add(&ctx0.op_repeat(&self.lmh_b, &input_layer), &input_layer);

        // run the computation
        gf.build_forward_expand(&input_layer);
        ctx0.graph_compute(&mut gf);

        // finish evaluation
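        // `input_layer` now holds a set of `n_vocab` logits per input token; the last
        // token's logits are saved on the session for sampling, and logits/embeddings
        // are copied into `output_request` if they were asked for.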
        common::read_last_token(session, &input_layer, n_vocab, n);
        common::extract_logits(output_request, &input_layer, n_vocab, n);
        common::extract_embeddings(output_request, &embd, n_embd, n);
        common::update_session(session, &ctx0, input_tokens.len(), n);
    }

    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn n_context_tokens(&self) -> usize {
        self.hyperparameters.n_ctx
    }

    fn bot_token_id(&self) -> Option<TokenId> {
        None
    }

    fn eot_token_id(&self) -> TokenId {
        self.vocabulary
            .token_to_id
            .get("<|endoftext|>".as_bytes())
            .copied()
            .unwrap()
    }

    fn inference_parameters(&self) -> &InferenceParameters {
        &self.inference_parameters
    }
}

/// GPT-J [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Hyperparameters {
    /// Size of the model's vocabulary
    pub n_vocab: usize,
    /// Size of the model's context
    pub n_ctx: usize,
    /// Size of the model's embedding layer
    pub n_embd: usize,
    /// Number of attention heads
    pub n_head: usize,
    /// Number of layers in the model
    pub n_layer: usize,
    /// Number of dimensions used for rotary position embeddings
    pub n_rot: usize,
    /// The file type of the model's weights
    pub file_type: FileType,
}
impl llm_base::Hyperparameters for Hyperparameters {
    fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
        let hyperparameters = Hyperparameters {
            n_vocab: util::read_i32(reader)?.try_into()?,
            n_ctx: util::read_i32(reader)?.try_into()?,
            n_embd: util::read_i32(reader)?.try_into()?,
            n_head: util::read_i32(reader)?.try_into()?,
            n_layer: util::read_i32(reader)?.try_into()?,
            n_rot: util::read_i32(reader)?.try_into()?,
            file_type: {
                let ftype = util::read_i32(reader)?;
                FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?
            },
        };

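        // The file stores the vocabulary length a second time after the hyperparameters;
        // it must agree with the `n_vocab` value read above.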
        let n_vocab = util::read_i32(reader)? as usize;
        if hyperparameters.n_vocab != n_vocab {
            return Err(LoadError::InvariantBroken {
                path: None,
                invariant: format!(
                    "GPT-J model expected n_vocab {} found {}",
                    hyperparameters.n_vocab, n_vocab
                ),
            });
        }

        Ok(hyperparameters)
    }

    fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
        util::write_i32(writer, self.n_vocab.try_into()?)?;
        util::write_i32(writer, self.n_ctx.try_into()?)?;
        util::write_i32(writer, self.n_embd.try_into()?)?;
        util::write_i32(writer, self.n_head.try_into()?)?;
        util::write_i32(writer, self.n_layer.try_into()?)?;
        util::write_i32(writer, self.n_rot.try_into()?)?;
        util::write_i32(writer, self.file_type.into())?;
        Ok(())
    }

    fn n_vocabulary(&self) -> usize {
        self.n_vocab
    }
}

struct Layer {
    // normalization
    ln_1_g: Tensor,
    ln_1_b: Tensor,

    // attention
    c_attn_q_proj_w: Tensor,
    c_attn_k_proj_w: Tensor,
    c_attn_v_proj_w: Tensor,

    c_attn_proj_w: Tensor,

    // feed-forward
    c_mlp_fc_w: Tensor,
    c_mlp_fc_b: Tensor,

    c_mlp_proj_w: Tensor,
    c_mlp_proj_b: Tensor,
}

#[cfg(test)]
impl GptJ {
    /// This does *not* construct a valid model. All of the tensors are entirely
    /// empty. However, it can be used to determine if some code will compile.
    fn new_empty() -> Self {
        let context = ggml::Context::init(1024 * 1024, true);

        Self {
            hyperparameters: Default::default(),
            n_context_tokens: 0,
            vocabulary: Default::default(),
            ln_f_g: context.new_f32(0.0),
            ln_f_b: context.new_f32(0.0),
            wte: context.new_f32(0.0),
            lmh_g: context.new_f32(0.0),
            lmh_b: context.new_f32(0.0),
            layers: Default::default(),
            inference_parameters: Default::default(),
            _mmap: Default::default(),
            _context: context,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn can_share_model_between_threads() {
        let model = Arc::new(GptJ::new_empty());

        for _ in 0..4 {
            let model = model.clone();
            std::thread::spawn(move || {
                let _session = model.start_session(Default::default());
            });
        }

        let session = model.start_session(Default::default());
        std::thread::spawn(move || {
            let _session = session;
        });
    }
}