llm_llama/
lib.rs

//! An implementation of [LLaMA](https://huggingface.co/docs/transformers/model_doc/llama) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::{error::Error, path::Path};

use llm_base::{
    ggml,
    model::{common, HyperparametersWriteError},
    util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
    LoadError, LoadProgress, Mmap, ModelParameters, OutputRequest, TensorLoader, TokenId,
    Vocabulary,
};

#[cfg(feature = "convert")]
pub mod convert;

mod old_loader;

/// The LLaMA model. Ref: [Introducing LLaMA](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/)
///
/// # Safety
/// This implements [Send] and [Sync] as it is immutable after construction.
pub struct Llama {
    hyperparameters: Hyperparameters,
    n_context_tokens: usize,

    vocabulary: Vocabulary,

    tok_embeddings: ggml::Tensor,

    norm: ggml::Tensor,
    output: ggml::Tensor,

    layers: Vec<Layer>,

    inference_parameters: InferenceParameters,

    /// Needs to be kept alive while the model is alive
    _mmap: Option<Mmap>,

    // Must be kept alive while the model is alive
    _context: ggml::Context,
}

unsafe impl Send for Llama {}
unsafe impl Sync for Llama {}

impl Llama {
    /// Load a LLaMA model from the `path` and configure it per the `params`. The status
    /// of the loading process will be reported through `load_progress_callback`. This
    /// is a helper function on top of [llm_base::load].
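    ///
    /// # Example
    ///
    /// A minimal usage sketch (not part of the original documentation): the model path
    /// is a placeholder, and this assumes the crate is available as `llm_llama` and that
    /// `ModelParameters` / `InferenceSessionConfig` provide `Default` implementations.
    ///
    /// ```no_run
    /// # use std::path::Path;
    /// # use llm_base::KnownModel;
    /// # use llm_llama::Llama;
    /// # fn main() -> Result<(), llm_base::LoadError> {
    /// let model = Llama::load(
    ///     Path::new("/path/to/model.bin"),
    ///     Default::default(),
    ///     |_progress| {},
    /// )?;
    /// let _session = model.start_session(Default::default());
    /// # Ok(())
    /// # }
    /// ```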
    pub fn load(
        path: &Path,
        params: ModelParameters,
        load_progress_callback: impl FnMut(LoadProgress),
    ) -> Result<Llama, LoadError> {
        llm_base::load(path, params, load_progress_callback)
    }
}

impl KnownModel for Llama {
    type Hyperparameters = Hyperparameters;

    fn new<E: Error>(
        hyperparameters: Self::Hyperparameters,
        params: ModelParameters,
        vocabulary: Vocabulary,
        tensor_loader: impl TensorLoader<E>,
    ) -> Result<Self, E> {
        let mut tl = tensor_loader;

        let tok_embeddings = tl.load("tok_embeddings.weight")?;
        let norm = tl.load("norm.weight")?;
        let output = tl.load("output.weight")?;

        let mut layers = Vec::new();
        for i in 0..hyperparameters.n_layer {
            let layer = Layer {
                attention_norm: tl.load(&format!("layers.{i}.attention_norm.weight"))?,
                wq: tl.load(&format!("layers.{i}.attention.wq.weight"))?,
                wk: tl.load(&format!("layers.{i}.attention.wk.weight"))?,
                wv: tl.load(&format!("layers.{i}.attention.wv.weight"))?,
                wo: tl.load(&format!("layers.{i}.attention.wo.weight"))?,
                ffn_norm: tl.load(&format!("layers.{i}.ffn_norm.weight"))?,
                w1: tl.load(&format!("layers.{i}.feed_forward.w1.weight"))?,
                w2: tl.load(&format!("layers.{i}.feed_forward.w2.weight"))?,
                w3: tl.load(&format!("layers.{i}.feed_forward.w3.weight"))?,
            };

            layers.push(layer);
        }

        let (_context, _tensors, _mmap) = tl.finish();

        let ModelParameters {
            n_context_tokens,
            inference_parameters,
            ..
        } = params;

        Ok(Self {
            hyperparameters,
            n_context_tokens,
            vocabulary,
            tok_embeddings,
            norm,
            output,
            layers,
            inference_parameters,
            _context,
            _mmap,
        })
    }

    /// Starts a new `InferenceSession` for this model.
    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
        InferenceSession::new(
            config,
            self.n_context_tokens,
            self.hyperparameters.n_layer,
            self.hyperparameters.n_embd,
            self.hyperparameters.n_vocab,
        )
    }

    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    ) {
        let n = input_tokens.len();
        let n_past = session.n_past;
        let n_threads = params.n_threads;

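        // `memory_k` / `memory_v` are the session's key/value cache, laid out per layer
        // across the full context window; their element sizes are needed to compute the
        // byte offsets into that cache further down.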
        let memk_elsize = session.memory_k.element_size();
        let memv_elsize = session.memory_v.element_size();

        let Hyperparameters {
            n_vocab,
            n_embd,
            n_mult: _,
            n_head,
            n_layer,
            n_rot,
            file_type: _,
        } = self.hyperparameters;
        let n_ctx = self.n_context_tokens;

        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);

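        // Look up the embedding row for each input token.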
        let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd);

        let mut gf = ggml::ComputationGraph::new(n_threads);

        for il in 0..n_layer {
            let input_self_attention = input_layer.share();
            let mut current: ggml::Tensor;

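            // Route the attention block's temporaries through scratch buffer 0 so they
            // reuse a fixed buffer instead of growing the main context.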
            ctx0.use_scratch(Some(&mut session.scratch[0]));

            // norm
            {
                current = ctx0.op_rms_norm(&input_layer);

                // cur = attention_norm * cur
                current = ctx0.op_mul(
                    &ctx0.op_repeat(&self.layers[il].attention_norm, &current),
                    &current,
                );
            }

            // self-attention
            {
                // compute Q and K and RoPE them
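                // (RoPE rotates pairs of dimensions in Q and K by a position-dependent
                // angle, so relative position is encoded directly in their dot products.)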
                let q_current = ctx0.op_rope(
                    &ctx0.op_reshape_3d(
                        &ctx0.op_mul_mat(&self.layers[il].wq, &current),
                        n_embd / n_head,
                        n_head,
                        n,
                    ),
                    n_past,
                    n_rot,
                    0,
                );
                let k_current = ctx0.op_rope(
                    &ctx0.op_reshape_3d(
                        &ctx0.op_mul_mat(&self.layers[il].wk, &current),
                        n_embd / n_head,
                        n_head,
                        n,
                    ),
                    n_past,
                    n_rot,
                    0,
                );

                // store key and value to memory
                {
                    // compute the transposed [N, n_embd] V matrix
                    let v_current = ctx0.op_transpose(&ctx0.op_reshape_2d(
                        &ctx0.op_mul_mat(&self.layers[il].wv, &current),
                        n_embd,
                        n,
                    ));

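                    // The key cache stores, per layer, n_ctx positions of n_embd elements;
                    // this view selects layer `il`, starting at position `n_past`. The
                    // value cache is stored transposed, so each of its n_embd rows spans
                    // n_ctx positions with a byte stride of `n_ctx * memv_elsize`.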
                    let k = ctx0.op_view_1d(
                        &session.memory_k,
                        n * n_embd,
                        (memk_elsize * n_embd) * (il * n_ctx + n_past),
                    );

                    let v = ctx0.op_view_2d(
                        &session.memory_v,
                        (n, n_embd),
                        n_ctx * memv_elsize,
                        (il * n_ctx) * memv_elsize * n_embd + n_past * memv_elsize,
                    );

                    // important: storing RoPE-ed version of K in the KV cache!
                    gf.build_forward_expand(&ctx0.op_cpy(&k_current, &k));
                    gf.build_forward_expand(&ctx0.op_cpy(&v_current, &v));
                }

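                // Permute so the head dimension leads: each attention head becomes an
                // independent matrix for the batched K * Q matmul below.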
                let q = ctx0.op_permute(&q_current, 0, 2, 1, 3);

                let k = ctx0.op_permute(
                    &ctx0.op_reshape_3d(
                        &ctx0.op_view_1d(
                            &session.memory_k,
                            (n_past + n) * n_embd,
                            il * n_ctx * memk_elsize * n_embd,
                        ),
                        n_embd / n_head,
                        n_head,
                        n_past + n,
                    ),
                    0,
                    2,
                    1,
                    3,
                );

                // K * Q
                let k_q = ctx0.op_mul_mat(&k, &q);

                // KQ_scaled = KQ / sqrt(n_embd/n_head)
                let k_q_scaled = ctx0.op_scale(
                    &k_q,
                    &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)),
                );

                // KQ_masked = mask_past(KQ_scaled)
                let k_q_masked = ctx0.op_diag_mask_inf(&k_q_scaled, n_past);

                // KQ = soft_max(KQ_masked)
                let k_q_soft_max = ctx0.op_soft_max(&k_q_masked);

                // split cached V into n_head heads
                let v = ctx0.op_view_3d(
                    &session.memory_v,
                    (n_past + n, n_embd / n_head, n_head),
                    (n_ctx * memv_elsize, n_ctx * memv_elsize * n_embd / n_head),
                    il * n_ctx * memv_elsize * n_embd,
                );

                let k_q_v = ctx0.op_mul_mat(&v, &k_q_soft_max);

                // KQV_merged = KQV.permute(0, 2, 1, 3)
                let k_q_v_merged = ctx0.op_permute(&k_q_v, 0, 2, 1, 3);

                // cur = KQV_merged.contiguous().view(n_embd, N)
                current = ctx0.op_cpy(
                    &k_q_v_merged,
                    &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n),
                );

                // projection (no bias)
                current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
            }

            ctx0.use_scratch(Some(&mut session.scratch[1]));

            let input_feed_forward = ctx0.op_add(&current, &input_self_attention);

            // feed-forward network
            {
                // norm
                {
                    current = ctx0.op_rms_norm(&input_feed_forward);

                    // cur = ffn_norm * cur
                    current = ctx0.op_mul(
                        &ctx0.op_repeat(&self.layers[il].ffn_norm, &current),
                        &current,
                    );
                }

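                // SwiGLU feed-forward: current = w2(silu(w1 * x) ⊙ (w3 * x)),
                // where ⊙ is element-wise multiplication.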
                let tmp = ctx0.op_mul_mat(&self.layers[il].w3, &current);

                current = ctx0.op_mul_mat(&self.layers[il].w1, &current);

                // SILU activation
                current = ctx0.op_silu(&current);

                current = ctx0.op_mul(&current, &tmp);

                current = ctx0.op_mul_mat(&self.layers[il].w2, &current);
            }

            current = ctx0.op_add(&current, &input_feed_forward);

            // input for next layer
            input_layer = current;
        }

        ctx0.use_scratch(Some(&mut session.scratch[0]));

        // Used at the end to optionally extract the embeddings.

        // norm
        {
            input_layer = ctx0.op_rms_norm(&input_layer);

            // inpL = norm * inpL
            input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer);
        }

        // lm_head: project the final hidden states onto the vocabulary to produce logits
        {
            input_layer = ctx0.op_mul_mat(&self.output, &input_layer);
        }

        ctx0.use_scratch(None);

        // run the computation
        gf.build_forward_expand(&input_layer);
        ctx0.graph_compute(&mut gf);

        // finish evaluation
        common::read_last_token(session, &input_layer, n_vocab, n);
        common::extract_logits(output_request, &input_layer, n_vocab, n);
        common::extract_embeddings(output_request, &embd, n_embd, n);
        common::update_session(session, &ctx0, input_tokens.len(), n);
    }

    /// Returns the vocabulary used by this model.
    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn n_context_tokens(&self) -> usize {
        self.n_context_tokens
    }

    fn bot_token_id(&self) -> Option<TokenId> {
        None
    }

    fn eot_token_id(&self) -> TokenId {
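        // Token id 2 is `</s>` (end of sequence) in the LLaMA SentencePiece vocabulary.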
        2
    }

    fn inference_parameters(&self) -> &InferenceParameters {
        &self.inference_parameters
    }
}

#[cfg(test)]
impl Llama {
    /// This does *not* construct a valid model. All of the tensors are entirely
    /// empty. However, it can be used to determine if some code will compile.
    fn new_empty() -> Self {
        let context = ggml::Context::init(1024 * 1024, true);
        let tok_embeddings = context.new_f32(0.0);
        let norm = context.new_f32(0.0);
        let output = context.new_f32(0.0);

        Self {
            hyperparameters: Default::default(),
            n_context_tokens: 0,
            vocabulary: Default::default(),
            tok_embeddings,
            norm,
            output,
            layers: Default::default(),
            _mmap: Default::default(),
            _context: context,
            inference_parameters: Default::default(),
        }
    }
}

/// LLaMA [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Hyperparameters {
    /// Size of the model's vocabulary
    pub n_vocab: usize,
    /// Size of the model's embedding layer
    pub n_embd: usize,
    /// Multiplier used in calculating the size of the feed-forward network
    pub n_mult: usize,
    /// Number of attention heads
    pub n_head: usize,
    /// Number of layers in the model
    pub n_layer: usize,
    /// Number of rotary positional embedding (RoPE) dimensions
    pub n_rot: usize,
    /// The format (e.g. quantization) of the model's tensor data
    pub file_type: FileType,
}

impl llm_base::Hyperparameters for Hyperparameters {
    fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
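        // The GGML header stores these as seven consecutive i32 values, in exactly this
        // order; `write_ggml` below must mirror it.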
        Ok(Hyperparameters {
            n_vocab: util::read_i32(reader)?.try_into()?,
            n_embd: util::read_i32(reader)?.try_into()?,
            n_mult: util::read_i32(reader)?.try_into()?,
            n_head: util::read_i32(reader)?.try_into()?,
            n_layer: util::read_i32(reader)?.try_into()?,
            n_rot: util::read_i32(reader)?.try_into()?,
            file_type: {
                let ftype = util::read_i32(reader)?;
                FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?
            },
        })
    }

    fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
        util::write_i32(writer, self.n_vocab.try_into()?)?;
        util::write_i32(writer, self.n_embd.try_into()?)?;
        util::write_i32(writer, self.n_mult.try_into()?)?;
        util::write_i32(writer, self.n_head.try_into()?)?;
        util::write_i32(writer, self.n_layer.try_into()?)?;
        util::write_i32(writer, self.n_rot.try_into()?)?;
        util::write_i32(writer, self.file_type.into())?;
        Ok(())
    }

    fn n_vocabulary(&self) -> usize {
        self.n_vocab
    }
}

struct Layer {
    attention_norm: ggml::Tensor,

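    // attention projections (query, key, value, output)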
    wq: ggml::Tensor,
    wk: ggml::Tensor,
    wv: ggml::Tensor,
    wo: ggml::Tensor,

    // feed-forward normalization
    ffn_norm: ggml::Tensor,

    // feed-forward network
    w1: ggml::Tensor,
    w2: ggml::Tensor,
    w3: ggml::Tensor,
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn can_share_model_between_threads() {
        let model = Arc::new(Llama::new_empty());

        for _ in 0..4 {
            let model = model.clone();
            std::thread::spawn(move || {
                let _session = model.start_session(Default::default());
            });
        }

        let session = model.start_session(Default::default());
        std::thread::spawn(move || {
            let _session = session;
        });
    }
}