#![deny(missing_docs)]
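//! An implementation of [GPT-2](https://en.wikipedia.org/wiki/GPT-2) for the
//! `llm` ecosystem.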

use std::path::Path;

use ggml::Tensor;
use llm_base::{
    ggml,
    model::{common, HyperparametersWriteError},
    util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
    LoadError, LoadProgress, ModelParameters, OutputRequest, TokenId, Vocabulary,
};

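/// The GPT-2 model.
///
/// # Safety
/// This implements [Send] and [Sync] as it is immutable once constructed.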
pub struct Gpt2 {
    hyperparameters: Hyperparameters,
    n_context_tokens: usize,
    vocabulary: Vocabulary,

    // normalization gain and bias applied after the final layer
    ln_f_g: Tensor,
    ln_f_b: Tensor,

    // token and positional embeddings
    wte: Tensor,
    wpe: Tensor,

    // projection from the final hidden state to vocabulary logits
    lm_head: Tensor,

    layers: Vec<Layer>,

    inference_params: InferenceParameters,

    // the ggml context that owns the memory behind the tensors above
    _context: ggml::Context,
}

unsafe impl Send for Gpt2 {}
unsafe impl Sync for Gpt2 {}

impl Gpt2 {
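    /// Load a GPT-2 model from `path`, configured with `params`. The status
    /// of the loading process is reported through `load_progress_callback`.
    ///
    /// A minimal usage sketch (marked `ignore` so it is not compiled as a
    /// doctest; it assumes the crate is named `llm_gpt2`, that
    /// `ModelParameters` implements [`Default`], and that a GGML-format model
    /// file exists at the given path):
    ///
    /// ```ignore
    /// use std::path::Path;
    ///
    /// let gpt2 = llm_gpt2::Gpt2::load(
    ///     Path::new("ggml-gpt2.bin"), // hypothetical model path
    ///     Default::default(),         // assumed ModelParameters::default()
    ///     |_progress| {},             // ignore load-progress updates
    /// )
    /// .expect("failed to load model");
    /// ```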
    pub fn load(
        path: &Path,
        params: ModelParameters,
        load_progress_callback: impl FnMut(LoadProgress),
    ) -> Result<Gpt2, LoadError> {
        llm_base::load(path, params, load_progress_callback)
    }
}

impl KnownModel for Gpt2 {
    type Hyperparameters = Hyperparameters;

    fn new<E: std::error::Error>(
        hyperparameters: Self::Hyperparameters,
        params: ModelParameters,
        vocabulary: Vocabulary,
        tensor_loader: impl llm_base::TensorLoader<E>,
    ) -> Result<Self, E> {
        let mut tl = tensor_loader;
        // model-global weights
        let ln_f_g = tl.load("model/ln_f/g")?;
        let ln_f_b = tl.load("model/ln_f/b")?;
        let wte = tl.load("model/wte")?;
        let wpe = tl.load("model/wpe")?;
        let lm_head = tl.load("model/lm_head")?;

        let mut layers = Vec::new();
        for i in 0..hyperparameters.n_layer {
            let layer = Layer {
                ln_1_g: tl.load(&format!("model/h{i}/ln_1/g"))?,
                ln_1_b: tl.load(&format!("model/h{i}/ln_1/b"))?,
                ln_2_g: tl.load(&format!("model/h{i}/ln_2/g"))?,
                ln_2_b: tl.load(&format!("model/h{i}/ln_2/b"))?,
                c_attn_attn_w: tl.load(&format!("model/h{i}/attn/c_attn/w"))?,
                c_attn_attn_b: tl.load(&format!("model/h{i}/attn/c_attn/b"))?,
                c_attn_proj_w: tl.load(&format!("model/h{i}/attn/c_proj/w"))?,
                c_attn_proj_b: tl.load(&format!("model/h{i}/attn/c_proj/b"))?,
                c_mlp_fc_w: tl.load(&format!("model/h{i}/mlp/c_fc/w"))?,
                c_mlp_fc_b: tl.load(&format!("model/h{i}/mlp/c_fc/b"))?,
                c_mlp_proj_w: tl.load(&format!("model/h{i}/mlp/c_proj/w"))?,
                c_mlp_proj_b: tl.load(&format!("model/h{i}/mlp/c_proj/b"))?,
            };

            layers.push(layer);
        }

        let (_context, _, _mmap) = tl.finish();

        let ModelParameters {
            n_context_tokens,
            inference_parameters: inference_params,
            ..
        } = params;

        Ok(Gpt2 {
            hyperparameters,
            n_context_tokens,
            vocabulary,
            layers,
            ln_f_g,
            ln_f_b,
            wte,
            wpe,
            lm_head,
            inference_params,
            _context,
        })
    }

    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
        InferenceSession::new(
            config,
            self.hyperparameters.n_ctx,
            self.hyperparameters.n_layer,
            self.hyperparameters.n_embd,
            self.hyperparameters.n_vocab,
        )
    }

    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    ) {
        let n = input_tokens.len();
        let n_threads = params.n_threads;

        let Hyperparameters {
            n_embd,
            n_head,
            n_vocab,
            n_layer,
            ..
        } = self.hyperparameters;
        let n_ctx = self.n_context_tokens;

        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);

        let n_past = session.n_past;

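        // Each token in this batch sits at absolute position `n_past + i`;
        // collect those positions for the positional-embedding lookup below.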
        let mut position_buf: Vec<i32> = vec![];
        for position_idx in 0..n {
            // stored as i32 to match the I32 tensor this buffer is copied into
            position_buf.push((n_past + position_idx) as i32);
        }

        let mut position = ctx0.new_tensor_1d(ggml::Type::I32, n);
        unsafe { position.write_data(bytemuck::cast_slice(&position_buf)) };

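        // the input to the first block: wte[token] + wpe[position]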
        let mut input_layer = ctx0.op_add(
            &ctx0.op_get_rows(&self.wte, &embd),
            &ctx0.op_get_rows(&self.wpe, &position),
        );

        let memory_k = &session.memory_k;
        let memory_k_size = memory_k.element_size();

        let memory_v = &session.memory_v;
        let memory_v_size = memory_v.element_size();

        let mut gf = ggml::ComputationGraph::new(n_threads);

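        // build the computation graph, one transformer block at a time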
        for il in 0..n_layer {
            // norm
            let mut current = ctx0.op_norm(&input_layer);
            current = ctx0.op_add(
                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_1_g, &current), &current),
                &ctx0.op_repeat(&self.layers[il].ln_1_b, &current),
            );

            // attn: fused projection of the input into query/key/value rows
            current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_attn_attn_b, &current),
                &current,
            );

            // split the fused QKV rows into separate Q, K, V views
            let nb = current.get_nb()[1];
            let f32_size = std::mem::size_of::<f32>();
            let qcur = ctx0.op_view_2d(&current, (n_embd, n), nb, 0);
            let kcur = ctx0.op_view_2d(&current, (n_embd, n), nb, f32_size * n_embd);
            let vcur = ctx0.op_view_2d(&current, (n_embd, n), nb, f32_size * n_embd * 2);

            // store the new keys and values in the session's attention memory
            if n >= 1 {
                let k = ctx0.op_view_1d(
                    memory_k,
                    n * n_embd,
                    (memory_k_size * n_embd) * (il * n_ctx + n_past),
                );
                let v = ctx0.op_view_1d(
                    memory_v,
                    n * n_embd,
                    (memory_v_size * n_embd) * (il * n_ctx + n_past),
                );

                gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
                gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v));
            }

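            // reshape Q and the cached K to [head_dim, n_head, tokens] and
            // permute them so attention is computed independently per head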
            let q = ctx0.op_permute(
                &ctx0.op_cpy(
                    &qcur,
                    &ctx0.new_tensor_3d(ggml::Type::F32, n_embd / n_head, n_head, n),
                ),
                0,
                2,
                1,
                3,
            );

            let k = ctx0.op_permute(
                &ctx0.op_reshape_3d(
                    &ctx0.op_view_1d(
                        &session.memory_k,
                        (n_past + n) * n_embd,
                        il * n_ctx * memory_k_size * n_embd,
                    ),
                    n_embd / n_head,
                    n_head,
                    n_past + n,
                ),
                0,
                2,
                1,
                3,
            );

            // attention scores: softmax(mask(K^T Q / sqrt(head_dim)))
            let kq = ctx0.op_mul_mat(&k, &q);
            let kq_scaled = ctx0.op_scale(
                &kq,
                &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
            );

            let kq_masked = ctx0.op_diag_mask_inf(&kq_scaled, n_past);
            let kq_softmax = ctx0.op_soft_max(&kq_masked);

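            // weight the cached values by the attention scores; V is copied
            // into a [tokens, head_dim, n_head] layout so the matmul
            // contracts over the token dimension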
            let v_trans = ctx0.op_cpy(
                &ctx0.op_permute(
                    &ctx0.op_reshape_3d(
                        &ctx0.op_view_1d(
                            memory_v,
                            (n_past + n) * n_embd,
                            il * n_ctx * memory_v_size * n_embd,
                        ),
                        n_embd / n_head,
                        n_head,
                        n_past + n,
                    ),
                    1,
                    2,
                    0,
                    3,
                ),
                &ctx0.new_tensor_3d(memory_v.get_type(), n_past + n, n_embd / n_head, n_head),
            );

            let kqv = ctx0.op_mul_mat(&v_trans, &kq_softmax);
            let kqv_merged = ctx0.op_permute(&kqv, 0, 2, 1, 3);

            // merge the heads back into a single [n_embd, n] matrix
            current = ctx0.op_cpy(&kqv_merged, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n));

            // projection of the attention output
            current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_attn_proj_b, &current),
                &current,
            );

            // residual connection around the attention block
            current = ctx0.op_add(&current, &input_layer);

            // keep the feed-forward input for the next residual connection
            let ff_in = current.share();

            // feed-forward: norm
            current = ctx0.op_norm(&ff_in);
            current = ctx0.op_add(
                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_2_g, &current), &current),
                &ctx0.op_repeat(&self.layers[il].ln_2_b, &current),
            );

            // feed-forward: up-projection
            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_fc_b, &current),
                &current,
            );

            // feed-forward: GELU activation
            current = ctx0.op_gelu(&current);

            // feed-forward: down-projection
            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_proj_b, &current),
                &current,
            );

            // residual connection; the result feeds the next block
            input_layer = ctx0.op_add(&current, &ff_in);
        }

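        // final norm, then project the hidden states to vocabulary logits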
        input_layer = ctx0.op_norm(&input_layer);
        input_layer = ctx0.op_add(
            &ctx0.op_mul(&ctx0.op_repeat(&self.ln_f_g, &input_layer), &input_layer),
            &ctx0.op_repeat(&self.ln_f_b, &input_layer),
        );

        input_layer = ctx0.op_mul_mat(&self.lm_head, &input_layer);

        // run the graph, then read the requested results back out
        gf.build_forward_expand(&input_layer);
        ctx0.graph_compute(&mut gf);

        common::read_last_token(session, &input_layer, n_vocab, n);
        common::extract_logits(output_request, &input_layer, n_vocab, n);
        common::extract_embeddings(output_request, &embd, n_embd, n);
        common::update_session(session, &ctx0, input_tokens.len(), n);
    }

    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn n_context_tokens(&self) -> usize {
        self.hyperparameters.n_ctx
    }

    fn bot_token_id(&self) -> Option<TokenId> {
        // GPT-2 does not use a beginning-of-text token
        None
    }

    fn eot_token_id(&self) -> TokenId {
        self.vocabulary
            .token_to_id
            .get("<|endoftext|>".as_bytes())
            .copied()
            .unwrap()
    }

    fn inference_parameters(&self) -> &InferenceParameters {
        &self.inference_params
    }
}

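/// GPT-2 [hyperparameters](https://en.wikipedia.org/wiki/Hyperparameter_(machine_learning)).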
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Hyperparameters {
    /// Size of the model's vocabulary
    n_vocab: usize,
    /// Size of the model's context window
    n_ctx: usize,
    /// Size of the model's embedding layer
    n_embd: usize,
    /// Number of attention heads
    n_head: usize,
    /// Number of layers in the model
    n_layer: usize,
    /// The format of the model's tensor data
    file_type: FileType,
}

impl llm_base::Hyperparameters for Hyperparameters {
    fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
        let hyperparameters = Hyperparameters {
            n_vocab: util::read_i32(reader)?.try_into()?,
            n_ctx: util::read_i32(reader)?.try_into()?,
            n_embd: util::read_i32(reader)?.try_into()?,
            n_head: util::read_i32(reader)?.try_into()?,
            n_layer: util::read_i32(reader)?.try_into()?,
            file_type: {
                let ftype = util::read_i32(reader)?;
                FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?
            },
        };

        // the GPT-2 GGML format repeats n_vocab after the hyperparameters;
        // check that it matches the value read above
        let n_vocab = util::read_i32(reader)? as usize;
        if hyperparameters.n_vocab != n_vocab {
            return Err(LoadError::InvariantBroken {
                path: None,
                invariant: format!(
                    "GPT-2 model expected n_vocab of {}, but found {}",
                    hyperparameters.n_vocab, n_vocab
                ),
            });
        }

        Ok(hyperparameters)
    }

    fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
        util::write_i32(writer, self.n_vocab.try_into()?)?;
        util::write_i32(writer, self.n_ctx.try_into()?)?;
        util::write_i32(writer, self.n_embd.try_into()?)?;
        util::write_i32(writer, self.n_head.try_into()?)?;
        util::write_i32(writer, self.n_layer.try_into()?)?;
        util::write_i32(writer, self.file_type.into())?;
        // mirror the duplicated n_vocab that `read_ggml` expects
        util::write_i32(writer, self.n_vocab.try_into()?)?;

        Ok(())
    }

    fn n_vocabulary(&self) -> usize {
        self.n_vocab
    }
}

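/// The weights of a single GPT-2 transformer block.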
struct Layer {
    // pre-attention layer norm
    ln_1_g: Tensor,
    ln_1_b: Tensor,

    // pre-feed-forward layer norm
    ln_2_g: Tensor,
    ln_2_b: Tensor,

    // fused query/key/value projection
    c_attn_attn_w: Tensor,
    c_attn_attn_b: Tensor,

    // attention output projection
    c_attn_proj_w: Tensor,
    c_attn_proj_b: Tensor,

    // feed-forward up-projection
    c_mlp_fc_w: Tensor,
    c_mlp_fc_b: Tensor,

    // feed-forward down-projection
    c_mlp_proj_w: Tensor,
    c_mlp_proj_b: Tensor,
}