llm_neox/lib.rs

//! An implementation of [GPT-NeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::{error::Error, path::Path};

use ggml::Tensor;
use llm_base::{
    ggml,
    model::{common, HyperparametersWriteError},
    util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
    LoadError, LoadProgress, Mmap, ModelParameters, OutputRequest, TensorLoader, TokenId,
    Vocabulary,
};

/// The GPT-NeoX model. Ref: [GitHub](https://github.com/EleutherAI/gpt-neox)
///
/// # Safety
/// This implements [Send] and [Sync] as it is immutable after construction.
pub struct NeoX {
    hyperparameters: Hyperparameters,
    n_context_tokens: usize,

    vocabulary: Vocabulary,

    // normalization
    ln_f_g: Tensor,
    ln_f_b: Tensor,

    // token embedding
    wte: Tensor,

    // language model head
    lmh_g: Tensor,

    layers: Vec<Layer>,

    inference_parameters: InferenceParameters,

    /// Needs to be kept alive while the model is alive
    _mmap: Option<Mmap>,

    // Must be kept alive for the model
    _context: ggml::Context,
}

unsafe impl Send for NeoX {}
unsafe impl Sync for NeoX {}

impl NeoX {
    /// Load a GPT-NeoX model from the `path` and configure it per the `params`. The
    /// status of the loading process will be reported through `load_progress_callback`.
    /// This is a helper function on top of [llm_base::load].
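    ///
    /// # Example
    ///
    /// A minimal sketch of loading a model from disk; the file path is
    /// hypothetical, and this assumes [ModelParameters] provides a usable
    /// `Default` configuration.
    ///
    /// ```no_run
    /// use std::path::Path;
    /// use llm_neox::NeoX;
    ///
    /// // Ignore progress updates here; a real caller could log each `LoadProgress`.
    /// let model = NeoX::load(
    ///     Path::new("/path/to/gpt-neox-model.bin"),
    ///     Default::default(),
    ///     |_progress| {},
    /// )
    /// .expect("failed to load model");
    /// ```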
    pub fn load(
        path: &Path,
        params: ModelParameters,
        load_progress_callback: impl FnMut(LoadProgress),
    ) -> Result<NeoX, LoadError> {
        llm_base::load(path, params, load_progress_callback)
    }
}

impl KnownModel for NeoX {
    type Hyperparameters = Hyperparameters;

    fn new<E: Error>(
        hyperparameters: Self::Hyperparameters,
        params: ModelParameters,
        vocabulary: Vocabulary,
        tensor_loader: impl TensorLoader<E>,
    ) -> Result<Self, E>
    where
        Self: Sized,
    {
        let mut tl = tensor_loader;

        // prepare memory for weights
        let wte = tl.load("gpt_neox.embed_in.weight")?;
        let ln_f_g = tl.load("gpt_neox.final_layer_norm.weight")?;
        let ln_f_b = tl.load("gpt_neox.final_layer_norm.bias")?;
        let lmh_g = tl.load("embed_out.weight")?;

        let mut layers = Vec::new();
        for i in 0..hyperparameters.n_layer {
            let layer = Layer {
                ln_1_g: tl.load(&format!("gpt_neox.layers.{i}.input_layernorm.weight"))?,
                ln_1_b: tl.load(&format!("gpt_neox.layers.{i}.input_layernorm.bias"))?,

                c_attn_attn_w: tl.load(&format!(
                    "gpt_neox.layers.{i}.attention.query_key_value.weight"
                ))?,
                c_attn_attn_b: tl.load(&format!(
                    "gpt_neox.layers.{i}.attention.query_key_value.bias"
                ))?,

                c_attn_proj_w: tl.load(&format!("gpt_neox.layers.{i}.attention.dense.weight"))?,
                c_attn_proj_b: tl.load(&format!("gpt_neox.layers.{i}.attention.dense.bias"))?,

                ln_2_g: tl.load(&format!(
                    "gpt_neox.layers.{i}.post_attention_layernorm.weight"
                ))?,
                ln_2_b: tl.load(&format!(
                    "gpt_neox.layers.{i}.post_attention_layernorm.bias"
                ))?,

                c_mlp_fc_w: tl.load(&format!("gpt_neox.layers.{i}.mlp.dense_h_to_4h.weight"))?,
                c_mlp_fc_b: tl.load(&format!("gpt_neox.layers.{i}.mlp.dense_h_to_4h.bias"))?,

                c_mlp_proj_w: tl.load(&format!("gpt_neox.layers.{i}.mlp.dense_4h_to_h.weight"))?,
                c_mlp_proj_b: tl.load(&format!("gpt_neox.layers.{i}.mlp.dense_4h_to_h.bias"))?,
            };

            layers.push(layer);
        }

        let (_context, _, _mmap) = tl.finish();

        let ModelParameters {
            n_context_tokens,
            inference_parameters,
            ..
        } = params;

        Ok(NeoX {
            hyperparameters,
            n_context_tokens,
            vocabulary,
            ln_f_g,
            ln_f_b,
            wte,
            lmh_g,
            layers,
            inference_parameters,
            _context,
            _mmap,
        })
    }

    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
        InferenceSession::new(
            config,
            self.hyperparameters.n_ctx,
            self.hyperparameters.n_layer,
            self.hyperparameters.n_embd,
            self.hyperparameters.n_vocab,
        )
    }

    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    ) {
        let n = input_tokens.len();
        let n_threads = params.n_threads;

        let Hyperparameters {
            n_embd,
            n_head,
            n_vocab,
            n_layer,
            n_rot,
            ..
        } = self.hyperparameters;
        let n_ctx = self.n_context_tokens;

        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);

        let n_past = session.n_past;

        // wte
        let mut input_layer = ctx0.op_get_rows(&self.wte, &embd);

        let memory_k = &session.memory_k;
        let memory_k_size = memory_k.element_size();

        let memory_v = &session.memory_v;
        let memory_v_size = memory_v.element_size();

        let mut gf = ggml::ComputationGraph::new(n_threads);

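        // Each block below uses GPT-NeoX's parallel-residual formulation: the
        // attention branch and the feed-forward branch each read the layer
        // input through their own layer norm, and both results are added back
        // to that input, i.e. x = x + attn(ln_1(x)) + mlp(ln_2(x)).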
        for il in 0..n_layer {
            // self-attention
            let mut current = ctx0.op_norm(&input_layer);
            current = ctx0.op_add(
                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_1_g, &current), &current),
                &ctx0.op_repeat(&self.layers[il].ln_1_b, &current),
            );

            // self-attention compute QKV
            current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_attn_attn_b, &current),
                &current,
            );

            let nb = current.get_nb()[1];
            let f32_size = std::mem::size_of::<f32>();
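            // The fused QKV projection is laid out per attention head as
            // consecutive [q, k, v] blocks of n_embd / n_head values each; the
            // strided views below slice out Q, K and V, and op_cont makes each
            // slice contiguous.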
            let mut qcur = ctx0.op_cont(&ctx0.op_view_3d(
                &current,
                (n_embd / n_head, n_head, n),
                (nb / n_head, nb),
                0,
            ));
            let mut kcur = ctx0.op_cont(&ctx0.op_view_3d(
                &current,
                (n_embd / n_head, n_head, n),
                (nb / n_head, nb),
                f32_size * n_embd / n_head,
            ));
            let mut vcur = ctx0.op_cont(&ctx0.op_view_3d(
                &current,
                (n_embd / n_head, n_head, n),
                (nb / n_head, nb),
                2 * f32_size * n_embd / n_head,
            ));

            // self-attention: apply rotary position embeddings (RoPE); rope mode 2 is the GPT-NeoX variant
            qcur = ctx0.op_rope(&qcur, n_past, n_rot, 2);
            kcur = ctx0.op_rope(&kcur, n_past, n_rot, 2);

            // self-attention store key and value to memory
            vcur = ctx0.op_transpose(&ctx0.op_reshape_2d(&vcur, n_embd, n));

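            // Append this chunk to the per-layer KV cache: keys are stored one
            // row of n_embd values per token, while values are stored
            // transposed (one row of n_ctx slots per embedding dimension) so
            // the value rows read back during attention are contiguous.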
            let little_k = ctx0.op_view_1d(
                memory_k,
                n * n_embd,
                (memory_k_size * n_embd) * (il * n_ctx + n_past),
            );
            let little_v = ctx0.op_view_2d(
                memory_v,
                (n, n_embd),
                n_ctx * memory_v_size,
                (il * n_ctx) * memory_v_size * n_embd + n_past * memory_v_size,
            );

            gf.build_forward_expand(&ctx0.op_cpy(&kcur, &little_k));
            gf.build_forward_expand(&ctx0.op_cpy(&vcur, &little_v));

            let q = ctx0.op_permute(&qcur, 0, 2, 1, 3);
            let big_k = ctx0.op_permute(
                &ctx0.op_reshape_3d(
                    &ctx0.op_view_1d(
                        memory_k,
                        (n_past + n) * n_embd,
                        il * n_ctx * memory_k_size * n_embd,
                    ),
                    n_embd / n_head,
                    n_head,
                    n_past + n,
                ),
                0,
                2,
                1,
                3,
            );

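            // Attention scores: compare each query against all cached keys,
            // scale by 1 / sqrt(n_embd / n_head), mask out future positions,
            // then apply a softmax over the key dimension.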
            let kq = ctx0.op_mul_mat(&big_k, &q);
            let kq_scaled = ctx0.op_scale(
                &kq,
                &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
            );

            let kq_masked = ctx0.op_diag_mask_inf(&kq_scaled, n_past);
            let kq_softmax = ctx0.op_soft_max(&kq_masked);

            let big_v = ctx0.op_view_3d(
                memory_v,
                (n_past + n, n_embd / n_head, n_head),
                (
                    n_ctx * memory_v_size,
                    n_ctx * memory_v_size * n_embd / n_head,
                ),
                il * n_ctx * memory_v_size * n_embd,
            );

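            // Weighted sum of the cached values per head, then undo the head
            // split: permute back and copy into a contiguous [n_embd, n]
            // activation for the output projection.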
            let kqv = ctx0.op_mul_mat(&big_v, &kq_softmax);
            let kqv_merged = ctx0.op_permute(&kqv, 0, 2, 1, 3);

            current = ctx0.op_cpy(&kqv_merged, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n));

            // self-attention projection
            current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_attn_proj_b, &current),
                &current,
            );

            // feed-forward
            let ff_in = current.share();

            // feed-forward post attention layer norm
            current = ctx0.op_norm(&input_layer);
            current = ctx0.op_add(
                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_2_g, &current), &current),
                &ctx0.op_repeat(&self.layers[il].ln_2_b, &current),
            );

            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_fc_b, &current),
                &current,
            );

            current = ctx0.op_gelu(&current);

            // feed-forward projection
            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_proj_b, &current),
                &current,
            );

            current = ctx0.op_add(&current, &ff_in);

            // input for next layer
            input_layer = ctx0.op_add(&current, &input_layer);
        }

        input_layer = ctx0.op_norm(&input_layer);
        input_layer = ctx0.op_add(
            &ctx0.op_mul(&ctx0.op_repeat(&self.ln_f_g, &input_layer), &input_layer),
            &ctx0.op_repeat(&self.ln_f_b, &input_layer),
        );

        input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer);

        // run the computation
        gf.build_forward_expand(&input_layer);
        ctx0.graph_compute(&mut gf);

        // finish evaluation
        common::read_last_token(session, &input_layer, n_vocab, n);
        common::extract_logits(output_request, &input_layer, n_vocab, n);
        common::extract_embeddings(output_request, &embd, n_embd, n);
        common::update_session(session, &ctx0, input_tokens.len(), n);
    }

    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn n_context_tokens(&self) -> usize {
        self.hyperparameters.n_ctx
    }

    fn bot_token_id(&self) -> Option<TokenId> {
        None
    }

    fn eot_token_id(&self) -> TokenId {
        self.vocabulary
            .token_to_id
            .get("<|endoftext|>".as_bytes())
            .copied()
            .unwrap()
    }

    fn inference_parameters(&self) -> &InferenceParameters {
        &self.inference_parameters
    }
}

/// GPT-NeoX [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
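///
/// # Example
///
/// A minimal sketch of the on-disk layout: the hyperparameters are stored as
/// seven consecutive `i32` values, written and read back here through the
/// `llm_base::Hyperparameters` trait. The `Default` values are purely
/// illustrative.
///
/// ```no_run
/// use std::io::Cursor;
/// // Bring the trait into scope for `read_ggml` / `write_ggml`.
/// use llm_base::Hyperparameters as _;
///
/// let hyperparameters = llm_neox::Hyperparameters::default();
/// let mut buffer = Vec::new();
/// hyperparameters.write_ggml(&mut buffer).unwrap();
/// assert_eq!(buffer.len(), 7 * std::mem::size_of::<i32>());
///
/// let read_back = llm_neox::Hyperparameters::read_ggml(&mut Cursor::new(buffer)).unwrap();
/// assert_eq!(read_back, hyperparameters);
/// ```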
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Hyperparameters {
    /// Size of the model's vocabulary
    pub n_vocab: usize,
    /// Size of the model's context
    pub n_ctx: usize,
    /// Size of the model's embedding layer
    pub n_embd: usize,
    /// Number of attention heads
    pub n_head: usize,
    /// Number of layers in the model
    pub n_layer: usize,
    /// Number of dimensions used for rotary position embeddings (RoPE)
    pub n_rot: usize,
    /// The file type (tensor precision/quantization) of the model
    pub file_type: FileType,
}
impl llm_base::Hyperparameters for Hyperparameters {
    fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
        Ok(Hyperparameters {
            n_vocab: util::read_i32(reader)?.try_into()?,
            n_ctx: util::read_i32(reader)?.try_into()?,
            n_embd: util::read_i32(reader)?.try_into()?,
            n_head: util::read_i32(reader)?.try_into()?,
            n_layer: util::read_i32(reader)?.try_into()?,
            n_rot: util::read_i32(reader)?.try_into()?,
            file_type: {
                let ftype = util::read_i32(reader)?;
                FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?
            },
        })
    }

    fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
        util::write_i32(writer, self.n_vocab.try_into()?)?;
        util::write_i32(writer, self.n_ctx.try_into()?)?;
        util::write_i32(writer, self.n_embd.try_into()?)?;
        util::write_i32(writer, self.n_head.try_into()?)?;
        util::write_i32(writer, self.n_layer.try_into()?)?;
        util::write_i32(writer, self.n_rot.try_into()?)?;
        util::write_i32(writer, self.file_type.into())?;
        Ok(())
    }

    fn n_vocabulary(&self) -> usize {
        self.n_vocab
    }
}

struct Layer {
    // pre-normalization
    ln_1_g: Tensor,
    ln_1_b: Tensor,

    // attention
    c_attn_attn_w: Tensor,
    c_attn_attn_b: Tensor,

    c_attn_proj_w: Tensor,
    c_attn_proj_b: Tensor,

    // post normalization
    ln_2_g: Tensor,
    ln_2_b: Tensor,

    // feed-forward
    c_mlp_fc_w: Tensor,
    c_mlp_fc_b: Tensor,

    c_mlp_proj_w: Tensor,
    c_mlp_proj_b: Tensor,
}

#[cfg(test)]
impl NeoX {
    /// This does *not* construct a valid model. All of the tensors are entirely
    /// empty. However, it can be used to determine if some code will compile.
    fn new_empty() -> Self {
        let context = ggml::Context::init(1024 * 1024, true);

        Self {
            hyperparameters: Default::default(),
            n_context_tokens: 0,
            vocabulary: Default::default(),
            ln_f_g: context.new_f32(0.0),
            ln_f_b: context.new_f32(0.0),
            wte: context.new_f32(0.0),
            lmh_g: context.new_f32(0.0),
            layers: Default::default(),
            inference_parameters: Default::default(),
            _mmap: Default::default(),
            _context: context,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn can_share_model_between_threads() {
        let model = Arc::new(NeoX::new_empty());

        for _ in 0..4 {
            let model = model.clone();
            std::thread::spawn(move || {
                let _session = model.start_session(Default::default());
            });
        }

        let session = model.start_session(Default::default());
        std::thread::spawn(move || {
            let _session = session;
        });
    }
}