#![deny(missing_docs)]
//! An implementation of the BLOOM language model for the `llm` ecosystem.

use std::path::Path;

use llm_base::{
    ggml,
    model::{common, HyperparametersWriteError},
    util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
    LoadError, LoadProgress, Mmap, ModelParameters, OutputRequest, TokenId, Vocabulary,
};
/// The BLOOM model.
///
/// # Safety
/// This implements [Send] and [Sync] as it is immutable after construction.
pub struct Bloom {
    hyperparameters: Hyperparameters,
    n_context_tokens: usize,

    vocabulary: Vocabulary,
    tok_embeddings: ggml::Tensor,
    norm: ggml::Tensor,
    norm_b: ggml::Tensor,
    output_norm: ggml::Tensor,
    output_norm_b: ggml::Tensor,
    output: ggml::Tensor,
    layers: Vec<Layer>,

    inference_parameters: InferenceParameters,

    // Must be kept alive for the lifetime of the model.
    _context: ggml::Context,
    _mmap: Option<Mmap>,
}

unsafe impl Send for Bloom {}
unsafe impl Sync for Bloom {}

impl Bloom {
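    /// Load a BLOOM model from the GGML file at `path`, configured with `params`,
    /// reporting progress through `load_progress_callback`. This is a thin wrapper
    /// over [llm_base::load].
    ///
    /// A minimal usage sketch. The model path is a placeholder, and it is assumed
    /// here that `ModelParameters` and `InferenceSessionConfig` provide `Default`
    /// implementations:
    ///
    /// ```no_run
    /// use std::path::Path;
    ///
    /// use llm_base::{KnownModel, ModelParameters};
    /// use llm_bloom::Bloom;
    ///
    /// // Hypothetical path; substitute a real GGML-format BLOOM file.
    /// let model = Bloom::load(
    ///     Path::new("/path/to/bloom-ggml-model.bin"),
    ///     ModelParameters::default(),
    ///     |_progress| {},
    /// )
    /// .expect("failed to load model");
    ///
    /// // A session holds the per-conversation state (e.g. the KV cache).
    /// let mut _session = model.start_session(Default::default());
    /// ```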
    pub fn load(
        path: &Path,
        params: ModelParameters,
        load_progress_callback: impl FnMut(LoadProgress),
    ) -> Result<Bloom, LoadError> {
        llm_base::load(path, params, load_progress_callback)
    }
}

impl KnownModel for Bloom {
    type Hyperparameters = Hyperparameters;

    fn new<E: std::error::Error>(
        hyperparameters: Self::Hyperparameters,
        params: ModelParameters,
        vocabulary: Vocabulary,
        tensor_loader: impl llm_base::TensorLoader<E>,
    ) -> Result<Self, E> {
        let mut tl = tensor_loader;

        let tok_embeddings = tl.load("tok_embeddings.weight")?;

        let norm = tl.load("norm.weight")?;
        let norm_b = tl.load("norm.bias")?;

        let output_norm = tl.load("output_norm.weight")?;
        let output_norm_b = tl.load("output_norm.bias")?;

        let output = tl.load("output.weight")?;

        let mut layers = Vec::new();
        for i in 0..hyperparameters.n_layer {
            let layer = Layer {
                attention_norm: tl.load(&format!("layers.{i}.attention_norm.weight"))?,
                attention_norm_b: tl.load(&format!("layers.{i}.attention_norm.bias"))?,

                query_key_value: tl
                    .load(&format!("layers.{i}.attention.query_key_value.weight"))?,
                query_key_value_b: tl
                    .load(&format!("layers.{i}.attention.query_key_value.bias"))?,

                wo: tl.load(&format!("layers.{i}.attention.wo.weight"))?,
                wo_b: tl.load(&format!("layers.{i}.attention.wo.bias"))?,

                ffn_norm: tl.load(&format!("layers.{i}.ffn_norm.weight"))?,
                ffn_norm_b: tl.load(&format!("layers.{i}.ffn_norm.bias"))?,

                w1: tl.load(&format!("layers.{i}.feed_forward.w1.weight"))?,
                w1_b: tl.load(&format!("layers.{i}.feed_forward.w1.bias"))?,
                w2: tl.load(&format!("layers.{i}.feed_forward.w2.weight"))?,
                w2_b: tl.load(&format!("layers.{i}.feed_forward.w2.bias"))?,
            };

            layers.push(layer);
        }

        let (_context, _, _mmap) = tl.finish();

        let ModelParameters {
            n_context_tokens,
            inference_parameters,
            ..
        } = params;

        Ok(Bloom {
            hyperparameters,
            n_context_tokens,
            vocabulary,
            tok_embeddings,
            norm,
            norm_b,
            output_norm,
            output_norm_b,
            output,
            layers,
            inference_parameters,
            _context,
            _mmap,
        })
    }

    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
        InferenceSession::new(
            config,
            self.n_context_tokens,
            self.hyperparameters.n_layer,
            self.hyperparameters.n_embd,
            self.hyperparameters.n_vocab,
        )
    }

    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    ) {
        let n = input_tokens.len();
        let n_past = session.n_past;
        let n_threads = params.n_threads;

        let Hyperparameters {
            n_vocab,
            n_embd,
            n_mult: _,
            n_head,
            n_layer,
            file_type: _,
        } = self.hyperparameters;
        let n_ctx = self.n_context_tokens;

        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);

        let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd);

        // Normalize the token embeddings before the first layer.
        {
            input_layer = ctx0.op_norm(&input_layer);
            input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer);
            input_layer = ctx0.op_add(&ctx0.op_repeat(&self.norm_b, &input_layer), &input_layer);
        }

        let mut gf = ggml::ComputationGraph::new(n_threads);

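        // Each BLOOM transformer layer below follows the same pattern:
        //   1. layer-norm the input, then apply the fused query/key/value projection;
        //   2. split the projection into Q, K and V views, caching K and V in the
        //      session memory so earlier tokens are not recomputed;
        //   3. score Q against K (scaled by 1/sqrt(n_embd/n_head)), add the ALiBi
        //      positional bias, apply the causal mask and softmax, then weight V;
        //   4. project the attention output, add the residual, and run the
        //      layer-normed feed-forward block (w1 -> GELU -> w2) with its own residual.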
        for il in 0..n_layer {
            let input_self_attention = input_layer.share();
            let mut current: ggml::Tensor;

            // attention norm
            {
                current = ctx0.op_norm(&input_layer);

                current = ctx0.op_mul(
                    &ctx0.op_repeat(&self.layers[il].attention_norm, &current),
                    &current,
                );
                current = ctx0.op_add(
                    &ctx0.op_repeat(&self.layers[il].attention_norm_b, &current),
                    &current,
                );
            }

            // fused query/key/value projection
            {
                current = ctx0.op_mul_mat(&self.layers[il].query_key_value, &current);
                current = ctx0.op_add(
                    &ctx0.op_repeat(&self.layers[il].query_key_value_b, &current),
                    &current,
                );
            }

            // self-attention
            {
                let nb = current.get_nb()[1];
                let q_current = ctx0.op_view_2d(&current, (n_embd, n), nb, 0);
                let k_current = ctx0.op_view_2d(
                    &current,
                    (n_embd, n),
                    nb,
                    std::mem::size_of::<f32>() * n_embd,
                );
                let v_current = ctx0.op_view_2d(
                    &current,
                    (n_embd, n),
                    nb,
                    2 * std::mem::size_of::<f32>() * n_embd,
                );

                // store key and value to memory
                if n >= 1 {
                    let k = ctx0.op_view_1d(
                        &session.memory_k,
                        n * n_embd,
                        (session.memory_k.element_size() * n_embd) * (il * n_ctx + n_past),
                    );

                    let v = ctx0.op_view_1d(
                        &session.memory_v,
                        n * n_embd,
                        (session.memory_v.element_size() * n_embd) * (il * n_ctx + n_past),
                    );

                    gf.build_forward_expand(&ctx0.op_cpy(&k_current, &k));
                    gf.build_forward_expand(&ctx0.op_cpy(&v_current, &v));
                }

                // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
                let big_q = ctx0.op_permute(
                    &ctx0.op_cpy(
                        &q_current,
                        &ctx0.new_tensor_3d(ggml::Type::F32, n_embd / n_head, n_head, n),
                    ),
                    0,
                    2,
                    1,
                    3,
                );

                // K = cached keys, viewed as (n_embd/n_head, n_head, n_past + N), permuted
                let big_k = ctx0.op_permute(
                    &ctx0.op_reshape_3d(
                        &ctx0.op_view_1d(
                            &session.memory_k,
                            (n_past + n) * n_embd,
                            il * n_ctx * session.memory_k.element_size() * n_embd,
                        ),
                        n_embd / n_head,
                        n_head,
                        n_past + n,
                    ),
                    0,
                    2,
                    1,
                    3,
                );

                // KQ = K * Q
                let k_q = ctx0.op_mul_mat(&big_k, &big_q);

                // KQ_scaled = KQ / sqrt(n_embd / n_head)
                let k_q_scaled = ctx0.op_scale(
                    &k_q,
                    &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)),
                );

                // BLOOM uses ALiBi position biases instead of positional embeddings.
                let k_q_scaled_alibi = ctx0.op_alibi(&k_q_scaled, n_past, n_head);

                // KQ_masked = mask_past(KQ_scaled_alibi)
                let k_q_masked = ctx0.op_diag_mask_inf(&k_q_scaled_alibi, n_past);

                // KQ = soft_max(KQ_masked)
                let k_q_soft_max = ctx0.op_soft_max(&k_q_masked);

                let memv_elsize = session.memory_v.element_size();

                // split the cached V into n_head heads and transpose
                let v_trans = ctx0.op_cpy(
                    &ctx0.op_permute(
                        &ctx0.op_reshape_3d(
                            &ctx0.op_view_1d(
                                &session.memory_v,
                                (n_past + n) * n_embd,
                                il * n_ctx * memv_elsize * n_embd,
                            ),
                            n_embd / n_head,
                            n_head,
                            n_past + n,
                        ),
                        1,
                        2,
                        0,
                        3,
                    ),
                    &ctx0.new_tensor_3d(
                        session.memory_v.get_type(),
                        n_past + n,
                        n_embd / n_head,
                        n_head,
                    ),
                );

                // KQV = transpose(V) * KQ_soft_max
                let k_q_v = ctx0.op_mul_mat(&v_trans, &k_q_soft_max);

                // KQV_merged = KQV.permute(0, 2, 1, 3)
                let k_q_v_merged = ctx0.op_permute(&k_q_v, 0, 2, 1, 3);

                // cur = KQV_merged.contiguous().view(n_embd, N)
                current = ctx0.op_cpy(
                    &k_q_v_merged,
                    &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n),
                );

                // output projection
                current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
                current = ctx0.op_add(&ctx0.op_repeat(&self.layers[il].wo_b, &current), &current);
            }

            let input_feed_forward = ctx0.op_add(&current, &input_self_attention);

            // feed-forward network
            {
                // norm over the attention output (input_feed_forward) rather than `current`
                {
                    current = ctx0.op_norm(&input_feed_forward);

                    current = ctx0.op_mul(
                        &ctx0.op_repeat(&self.layers[il].ffn_norm, &current),
                        &current,
                    );

                    current = ctx0.op_add(
                        &ctx0.op_repeat(&self.layers[il].ffn_norm_b, &current),
                        &current,
                    );
                }

                current = ctx0.op_mul_mat(&self.layers[il].w1, &current);

                current = ctx0.op_add(&ctx0.op_repeat(&self.layers[il].w1_b, &current), &current);

                current = ctx0.op_gelu(&current);

                current = ctx0.op_mul_mat(&self.layers[il].w2, &current);

                current = ctx0.op_add(&ctx0.op_repeat(&self.layers[il].w2_b, &current), &current);
            }

            current = ctx0.op_add(&current, &input_feed_forward);

            // input for the next layer
            input_layer = current;
        }

        // final norm
        {
            input_layer = ctx0.op_norm(&input_layer);

            input_layer = ctx0.op_mul(
                &ctx0.op_repeat(&self.output_norm, &input_layer),
                &input_layer,
            );

            input_layer = ctx0.op_add(
                &ctx0.op_repeat(&self.output_norm_b, &input_layer),
                &input_layer,
            );
        }

        // language modelling head
        {
            input_layer = ctx0.op_mul_mat(&self.output, &input_layer);
        }

        // run the computation
        gf.build_forward_expand(&input_layer);
        ctx0.graph_compute(&mut gf);

        // finish evaluation
        common::read_last_token(session, &input_layer, n_vocab, n);
        common::extract_logits(output_request, &input_layer, n_vocab, n);
        common::extract_embeddings(output_request, &embd, n_embd, n);
        common::update_session(session, &ctx0, input_tokens.len(), n);
    }

    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn n_context_tokens(&self) -> usize {
        self.n_context_tokens
    }

    fn bot_token_id(&self) -> Option<TokenId> {
        self.vocabulary.token_to_id.get("<s>".as_bytes()).copied()
    }

    fn eot_token_id(&self) -> TokenId {
        self.vocabulary
            .token_to_id
            .get("</s>".as_bytes())
            .copied()
            .unwrap()
    }

    fn inference_parameters(&self) -> &InferenceParameters {
        &self.inference_parameters
    }
}

/// The hyperparameters of the model.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Hyperparameters {
    /// Size of the model's vocabulary
    pub n_vocab: usize,
    /// Size of the model's embedding layer
    pub n_embd: usize,
    /// Multiplier read from the GGML file; not used directly during evaluation
    pub n_mult: usize,
    /// Number of attention heads
    pub n_head: usize,
    /// Number of transformer layers
    pub n_layer: usize,
    /// The file type of the model's tensors
    pub file_type: FileType,
}
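
// The hyperparameters are stored in the GGML header as consecutive i32 values,
// in the same order as the struct fields above; `file_type` round-trips through
// its i32 representation.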
impl llm_base::Hyperparameters for Hyperparameters {
    fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, llm_base::LoadError> {
        Ok(Hyperparameters {
            n_vocab: util::read_i32(reader)?.try_into()?,
            n_embd: util::read_i32(reader)?.try_into()?,
            n_mult: util::read_i32(reader)?.try_into()?,
            n_head: util::read_i32(reader)?.try_into()?,
            n_layer: util::read_i32(reader)?.try_into()?,
            file_type: {
                let ftype = util::read_i32(reader)?;
                FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?
            },
        })
    }

    fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
        util::write_i32(writer, self.n_vocab.try_into()?)?;
        util::write_i32(writer, self.n_embd.try_into()?)?;
        util::write_i32(writer, self.n_mult.try_into()?)?;
        util::write_i32(writer, self.n_head.try_into()?)?;
        util::write_i32(writer, self.n_layer.try_into()?)?;
        util::write_i32(writer, self.file_type.into())?;
        Ok(())
    }

    fn n_vocabulary(&self) -> usize {
        self.n_vocab
    }
}

struct Layer {
    pub attention_norm: ggml::Tensor,
    pub attention_norm_b: ggml::Tensor,
    pub wo: ggml::Tensor,
    pub wo_b: ggml::Tensor,
    pub query_key_value: ggml::Tensor,
    pub query_key_value_b: ggml::Tensor,
    pub ffn_norm: ggml::Tensor,
    pub ffn_norm_b: ggml::Tensor,
    pub w1: ggml::Tensor,
    pub w1_b: ggml::Tensor,
    pub w2: ggml::Tensor,
    pub w2_b: ggml::Tensor,
}