llm_bloom/lib.rs

//! An implementation of [BLOOM](https://huggingface.co/docs/transformers/model_doc/bloom)
//! for the `llm` ecosystem.
#![deny(missing_docs)]

use std::path::Path;

use llm_base::{
    ggml,
    model::{common, HyperparametersWriteError},
    util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
    LoadError, LoadProgress, Mmap, ModelParameters, OutputRequest, TokenId, Vocabulary,
};

/// The BLOOM model. Ref: [Introducing BLOOM](https://bigscience.huggingface.co/blog/bloom)
///
/// # Safety
/// This implements [Send] and [Sync] as it is immutable after construction.
pub struct Bloom {
    hyperparameters: Hyperparameters,
    n_context_tokens: usize,

    vocabulary: Vocabulary,
    tok_embeddings: ggml::Tensor,
    norm: ggml::Tensor,
    norm_b: ggml::Tensor,
    output_norm: ggml::Tensor,
    output_norm_b: ggml::Tensor,
    output: ggml::Tensor,
    layers: Vec<Layer>,

    inference_parameters: InferenceParameters,

    // Must be kept alive for the model
    _context: ggml::Context,
    _mmap: Option<Mmap>,
}

unsafe impl Send for Bloom {}
unsafe impl Sync for Bloom {}

impl Bloom {
    /// Load a BLOOM model from the `path` and configure it per the `params`. The status
    /// of the loading process will be reported through `load_progress_callback`. This
    /// is a helper function on top of [llm_base::load].
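    ///
    /// A minimal usage sketch (the model path is illustrative, and this
    /// assumes `ModelParameters` implements `Default`):
    ///
    /// ```no_run
    /// use std::path::Path;
    /// use llm_base::ModelParameters;
    ///
    /// // Load the model with default parameters, ignoring progress updates.
    /// let model = llm_bloom::Bloom::load(
    ///     Path::new("path/to/bloom-model.bin"),
    ///     ModelParameters::default(),
    ///     |_progress| {},
    /// )
    /// .expect("failed to load model");
    /// ```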
    pub fn load(
        path: &Path,
        params: ModelParameters,
        load_progress_callback: impl FnMut(LoadProgress),
    ) -> Result<Bloom, LoadError> {
        llm_base::load(path, params, load_progress_callback)
    }
}

impl KnownModel for Bloom {
    type Hyperparameters = Hyperparameters;

    fn new<E: std::error::Error>(
        hyperparameters: Self::Hyperparameters,
        params: ModelParameters,
        vocabulary: Vocabulary,
        tensor_loader: impl llm_base::TensorLoader<E>,
    ) -> Result<Self, E> {
        let mut tl = tensor_loader;

        let tok_embeddings = tl.load("tok_embeddings.weight")?;

        let norm = tl.load("norm.weight")?;
        let norm_b = tl.load("norm.bias")?;

        let output_norm = tl.load("output_norm.weight")?;
        let output_norm_b = tl.load("output_norm.bias")?;

        let output = tl.load("output.weight")?;

        let mut layers = Vec::new();
        for i in 0..hyperparameters.n_layer {
            let layer = Layer {
                attention_norm: tl.load(&format!("layers.{i}.attention_norm.weight"))?,
                attention_norm_b: tl.load(&format!("layers.{i}.attention_norm.bias"))?,

                query_key_value: tl
                    .load(&format!("layers.{i}.attention.query_key_value.weight"))?,
                query_key_value_b: tl
                    .load(&format!("layers.{i}.attention.query_key_value.bias"))?,

                wo: tl.load(&format!("layers.{i}.attention.wo.weight"))?,
                wo_b: tl.load(&format!("layers.{i}.attention.wo.bias"))?,

                ffn_norm: tl.load(&format!("layers.{i}.ffn_norm.weight"))?,
                ffn_norm_b: tl.load(&format!("layers.{i}.ffn_norm.bias"))?,

                w1: tl.load(&format!("layers.{i}.feed_forward.w1.weight"))?,
                w1_b: tl.load(&format!("layers.{i}.feed_forward.w1.bias"))?,
                w2: tl.load(&format!("layers.{i}.feed_forward.w2.weight"))?,
                w2_b: tl.load(&format!("layers.{i}.feed_forward.w2.bias"))?,
            };

            layers.push(layer);
        }

        let (_context, _, _mmap) = tl.finish();

        let ModelParameters {
            n_context_tokens,
            inference_parameters,
            ..
        } = params;

        Ok(Bloom {
            hyperparameters,
            n_context_tokens,
            vocabulary,
            tok_embeddings,
            norm,
            norm_b,
            output_norm,
            output_norm_b,
            output,
            layers,
            inference_parameters,
            _context,
            _mmap,
        })
    }

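    // An inference session holds per-conversation state (notably the
    // key/value memory used below), sized from this model's context length
    // and shape hyperparameters.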
    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
        InferenceSession::new(
            config,
            self.n_context_tokens,
            self.hyperparameters.n_layer,
            self.hyperparameters.n_embd,
            self.hyperparameters.n_vocab,
        )
    }

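    // Runs the forward pass over `input_tokens`: builds the ggml graph for
    // all layers, appends this batch's keys and values to the session
    // memory, and extracts logits/embeddings into `output_request`.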
    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    ) {
        let n = input_tokens.len();
        let n_past = session.n_past;
        let n_threads = params.n_threads;

        let Hyperparameters {
            n_vocab,
            n_embd,
            n_mult: _,
            n_head,
            n_layer,
            file_type: _,
        } = self.hyperparameters;
        let n_ctx = self.n_context_tokens;

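        // Set up the per-evaluation ggml context and an `embd` tensor holding
        // the input token ids, which index into the embedding table below.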
        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);

        let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd);

        // word embedding normalization
        {
            input_layer = ctx0.op_norm(&input_layer);
            input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer);
            input_layer = ctx0.op_add(&ctx0.op_repeat(&self.norm_b, &input_layer), &input_layer);
        }

        let mut gf = ggml::ComputationGraph::new(n_threads);

        for il in 0..n_layer {
            let input_self_attention = input_layer.share();
            let mut current: ggml::Tensor;

            // norm
            {
                current = ctx0.op_norm(&input_layer);

                // cur = attention_norm * cur + attention_norm_b
                current = ctx0.op_mul(
                    &ctx0.op_repeat(&self.layers[il].attention_norm, &current),
                    &current,
                );
                current = ctx0.op_add(
                    &ctx0.op_repeat(&self.layers[il].attention_norm_b, &current),
                    &current,
                );
            }

            // query/key/value projection
            {
                current = ctx0.op_mul_mat(&self.layers[il].query_key_value, &current);
                current = ctx0.op_add(
                    &ctx0.op_repeat(&self.layers[il].query_key_value_b, &current),
                    &current,
                );
            }

            // self-attention
            {
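                // The fused query_key_value projection produces rows of
                // 3 * n_embd values: within each row the first n_embd elements
                // are Q, the next n_embd are K, and the last n_embd are V. The
                // views below slice them apart using `nb` (the row stride in
                // bytes) and byte offsets of 0, n_embd and 2 * n_embd f32s.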
                let nb = current.get_nb()[1];
                let q_current = ctx0.op_view_2d(&current, (n_embd, n), nb, 0);
                let k_current = ctx0.op_view_2d(
                    &current,
                    (n_embd, n),
                    nb,
                    std::mem::size_of::<f32>() * n_embd,
                );
                let v_current = ctx0.op_view_2d(
                    &current,
                    (n_embd, n),
                    nb,
                    2 * std::mem::size_of::<f32>() * n_embd,
                );

                // store this batch's keys and values in the session memory
                if n >= 1 {
                    let k = ctx0.op_view_1d(
                        &session.memory_k,
                        n * n_embd,
                        (session.memory_k.element_size() * n_embd) * (il * n_ctx + n_past),
                    );

                    let v = ctx0.op_view_1d(
                        &session.memory_v,
                        n * n_embd,
                        (session.memory_v.element_size() * n_embd) * (il * n_ctx + n_past),
                    );

                    gf.build_forward_expand(&ctx0.op_cpy(&k_current, &k));
                    gf.build_forward_expand(&ctx0.op_cpy(&v_current, &v));
                }

                // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
                let big_q = ctx0.op_permute(
                    &ctx0.op_cpy(
                        &q_current,
                        &ctx0.new_tensor_3d(ggml::Type::F32, n_embd / n_head, n_head, n),
                    ),
                    0,
                    2,
                    1,
                    3,
                );

                // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3)
                let big_k = ctx0.op_permute(
                    &ctx0.op_reshape_3d(
                        &ctx0.op_view_1d(
                            &session.memory_k,
                            (n_past + n) * n_embd,
                            il * n_ctx * session.memory_k.element_size() * n_embd,
                        ),
                        n_embd / n_head,
                        n_head,
                        n_past + n,
                    ),
                    0,
                    2,
                    1,
                    3,
                );

                // K * Q
                let k_q = ctx0.op_mul_mat(&big_k, &big_q);

                // KQ_scaled = KQ / sqrt(n_embd/n_head)
                let k_q_scaled = ctx0.op_scale(
                    &k_q,
                    &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)),
                );

                // ALiBi: BLOOM uses attention with linear biases instead of
                // positional embeddings.
                // KQ_scaled_alibi = KQ_scaled + alibi_bias
                let k_q_scaled_alibi = ctx0.op_alibi(&k_q_scaled, n_past, n_head);

                // KQ_masked = mask_past(KQ_scaled)
                let k_q_masked = ctx0.op_diag_mask_inf(&k_q_scaled_alibi, n_past);

                // KQ = soft_max(KQ_masked)
                let k_q_soft_max = ctx0.op_soft_max(&k_q_masked);

                let memv_elsize = session.memory_v.element_size();

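                // Copy the cached values into a transposed layout so the
                // (n_past + n) dimension leads, ready for the V * KQ matmul
                // below.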
                let v_trans = ctx0.op_cpy(
                    &ctx0.op_permute(
                        &ctx0.op_reshape_3d(
                            &ctx0.op_view_1d(
                                &session.memory_v,
                                (n_past + n) * n_embd,
                                il * n_ctx * memv_elsize * n_embd,
                            ),
                            n_embd / n_head,
                            n_head,
                            n_past + n,
                        ),
                        1,
                        2,
                        0,
                        3,
                    ),
                    &ctx0.new_tensor_3d(
                        session.memory_v.get_type(),
                        n_past + n,
                        n_embd / n_head,
                        n_head,
                    ),
                );

                let k_q_v = ctx0.op_mul_mat(&v_trans, &k_q_soft_max);

                // KQV_merged = KQV.permute(0, 2, 1, 3)
                let k_q_v_merged = ctx0.op_permute(&k_q_v, 0, 2, 1, 3);

                // cur = KQV_merged.contiguous().view(n_embd, N)
                current = ctx0.op_cpy(
                    &k_q_v_merged,
                    &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n),
                );

                // projection
                current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
                current = ctx0.op_add(&ctx0.op_repeat(&self.layers[il].wo_b, &current), &current);
            }

            let input_feed_forward = ctx0.op_add(&current, &input_self_attention);

            // feed-forward network
            {
                // norm
                {
                    current = ctx0.op_norm(&input_feed_forward);

                    // cur = ffn_norm*cur + ffn_norm_b
                    current = ctx0.op_mul(
                        &ctx0.op_repeat(&self.layers[il].ffn_norm, &current),
                        &current,
                    );

                    current = ctx0.op_add(
                        &ctx0.op_repeat(&self.layers[il].ffn_norm_b, &current),
                        &current,
                    );
                }

                current = ctx0.op_mul_mat(&self.layers[il].w1, &current);

                current = ctx0.op_add(&ctx0.op_repeat(&self.layers[il].w1_b, &current), &current);

                // GELU activation (BLOOM's feed-forward nonlinearity)
                current = ctx0.op_gelu(&current);

                current = ctx0.op_mul_mat(&self.layers[il].w2, &current);

                current = ctx0.op_add(&ctx0.op_repeat(&self.layers[il].w2_b, &current), &current);
            }

            current = ctx0.op_add(&current, &input_feed_forward);

            // input for next layer
            input_layer = current;
        }

        // norm
        {
            input_layer = ctx0.op_norm(&input_layer);

            // inpL = norm*inpL + norm_b
            input_layer = ctx0.op_mul(
                &ctx0.op_repeat(&self.output_norm, &input_layer),
                &input_layer,
            );

            input_layer = ctx0.op_add(
                &ctx0.op_repeat(&self.output_norm_b, &input_layer),
                &input_layer,
            );
        }

        // lm_head
        {
            input_layer = ctx0.op_mul_mat(&self.output, &input_layer);
        }

        // run the computation
        gf.build_forward_expand(&input_layer);
        ctx0.graph_compute(&mut gf);

        // finish evaluation
        common::read_last_token(session, &input_layer, n_vocab, n);
        common::extract_logits(output_request, &input_layer, n_vocab, n);
        common::extract_embeddings(output_request, &embd, n_embd, n);
        common::update_session(session, &ctx0, input_tokens.len(), n);
    }

    /// Returns the vocabulary used by this model.
    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn n_context_tokens(&self) -> usize {
        self.n_context_tokens
    }

    fn bot_token_id(&self) -> Option<TokenId> {
        self.vocabulary.token_to_id.get("<s>".as_bytes()).copied()
    }

    fn eot_token_id(&self) -> TokenId {
        self.vocabulary
            .token_to_id
            .get("</s>".as_bytes())
            .copied()
            .unwrap()
    }

    fn inference_parameters(&self) -> &InferenceParameters {
        &self.inference_parameters
    }
}

/// BLOOM [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning))
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Hyperparameters {
    /// Size of the model's vocabulary
    pub n_vocab: usize,
    /// Size of the model's embedding layer
    pub n_embd: usize,
    /// n_mult, as stored in the model file (unused during inference)
    pub n_mult: usize,
    /// Number of attention heads
    pub n_head: usize,
    /// Number of layers in the model
    pub n_layer: usize,
    /// The file type (weight format) of the model
    pub file_type: FileType,
}

impl llm_base::Hyperparameters for Hyperparameters {
    fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, llm_base::LoadError> {
        // NOTE: Field order matters! Data is laid out in the file exactly
        // in this order.
        Ok(Hyperparameters {
            n_vocab: util::read_i32(reader)?.try_into()?,
            n_embd: util::read_i32(reader)?.try_into()?,
            n_mult: util::read_i32(reader)?.try_into()?,
            n_head: util::read_i32(reader)?.try_into()?,
            n_layer: util::read_i32(reader)?.try_into()?,
            file_type: {
                let ftype = util::read_i32(reader)?;
                FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?
            },
        })
    }

    fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
        util::write_i32(writer, self.n_vocab.try_into()?)?;
        util::write_i32(writer, self.n_embd.try_into()?)?;
        util::write_i32(writer, self.n_mult.try_into()?)?;
        util::write_i32(writer, self.n_head.try_into()?)?;
        util::write_i32(writer, self.n_layer.try_into()?)?;
        util::write_i32(writer, self.file_type.into())?;
        Ok(())
    }

    fn n_vocabulary(&self) -> usize {
        self.n_vocab
    }
}

struct Layer {
    // attention
    pub attention_norm: ggml::Tensor,
    pub attention_norm_b: ggml::Tensor,
    pub wo: ggml::Tensor,
    pub wo_b: ggml::Tensor,
    pub query_key_value: ggml::Tensor,
    pub query_key_value_b: ggml::Tensor,
    // feed-forward normalization
    pub ffn_norm: ggml::Tensor,
    pub ffn_norm_b: ggml::Tensor,
    // feed-forward
    pub w1: ggml::Tensor,
    pub w1_b: ggml::Tensor,
    pub w2: ggml::Tensor,
    pub w2_b: ggml::Tensor,
}