#![deny(missing_docs)]
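//! An implementation of the [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj)
//! model for the `llm` ecosystem.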

use std::{error::Error, path::Path};

use ggml::Tensor;
use llm_base::{
    ggml,
    model::{common, HyperparametersWriteError},
    util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
    LoadError, LoadProgress, Mmap, ModelParameters, OutputRequest, TensorLoader, TokenId,
    Vocabulary,
};

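/// The GPT-J model.
///
/// # Safety
/// This implements [Send] and [Sync] as it is immutable after construction.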
pub struct GptJ {
    hyperparameters: Hyperparameters,
    n_context_tokens: usize,

    vocabulary: Vocabulary,

    // final normalization
    ln_f_g: Tensor,
    ln_f_b: Tensor,

    // token embedding
    wte: Tensor,

    // language model head
    lmh_g: Tensor,
    lmh_b: Tensor,

    layers: Vec<Layer>,

    inference_parameters: InferenceParameters,

    // needs to be kept alive while the model is alive
    _mmap: Option<Mmap>,

    // needs to be kept alive while the model is alive
    _context: ggml::Context,
}

unsafe impl Send for GptJ {}
unsafe impl Sync for GptJ {}

impl GptJ {
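    /// Load the model from `path` with `params`.
    ///
    /// The status of the loading process is reported through `load_progress_callback`.
    ///
    /// A minimal usage sketch; the crate name `llm_gptj`, the model path, and the
    /// use of `Default::default()` for the parameters are illustrative assumptions:
    ///
    /// ```no_run
    /// # use std::path::Path;
    /// # use llm_gptj::GptJ;
    /// // Hypothetical model path; progress updates are ignored here.
    /// let model = GptJ::load(
    ///     Path::new("/path/to/gptj-model.bin"),
    ///     Default::default(),
    ///     |_progress| {},
    /// )
    /// .expect("failed to load model");
    /// ```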
    pub fn load(
        path: &Path,
        params: ModelParameters,
        load_progress_callback: impl FnMut(LoadProgress),
    ) -> Result<GptJ, LoadError> {
        llm_base::load(path, params, load_progress_callback)
    }
}

impl KnownModel for GptJ {
    type Hyperparameters = Hyperparameters;

    fn new<E: Error>(
        hyperparameters: Self::Hyperparameters,
        params: ModelParameters,
        vocabulary: Vocabulary,
        tensor_loader: impl TensorLoader<E>,
    ) -> Result<Self, E>
    where
        Self: Sized,
    {
        let mut tl = tensor_loader;
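        // Model-wide weights: the token embedding, the final layer norm, and the
        // language model head.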
        let wte = tl.load("transformer.wte.weight")?;
        let ln_f_g = tl.load("transformer.ln_f.weight")?;
        let ln_f_b = tl.load("transformer.ln_f.bias")?;
        let lmh_g = tl.load("lm_head.weight")?;
        let lmh_b = tl.load("lm_head.bias")?;

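        // Per-layer transformer block weights.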
        let mut layers = Vec::new();
        for i in 0..hyperparameters.n_layer {
            let layer = Layer {
                ln_1_g: tl.load(&format!("transformer.h.{i}.ln_1.weight"))?,
                ln_1_b: tl.load(&format!("transformer.h.{i}.ln_1.bias"))?,
                c_attn_q_proj_w: tl.load(&format!("transformer.h.{i}.attn.q_proj.weight"))?,
                c_attn_k_proj_w: tl.load(&format!("transformer.h.{i}.attn.k_proj.weight"))?,
                c_attn_v_proj_w: tl.load(&format!("transformer.h.{i}.attn.v_proj.weight"))?,
                c_attn_proj_w: tl.load(&format!("transformer.h.{i}.attn.out_proj.weight"))?,
                c_mlp_fc_w: tl.load(&format!("transformer.h.{i}.mlp.fc_in.weight"))?,
                c_mlp_fc_b: tl.load(&format!("transformer.h.{i}.mlp.fc_in.bias"))?,
                c_mlp_proj_w: tl.load(&format!("transformer.h.{i}.mlp.fc_out.weight"))?,
                c_mlp_proj_b: tl.load(&format!("transformer.h.{i}.mlp.fc_out.bias"))?,
            };

            layers.push(layer);
        }

        let (_context, _, _mmap) = tl.finish();

        let ModelParameters {
            n_context_tokens,
            inference_parameters,
            ..
        } = params;

        Ok(GptJ {
            hyperparameters,
            n_context_tokens,
            vocabulary,
            ln_f_g,
            ln_f_b,
            wte,
            lmh_g,
            lmh_b,
            layers,
            inference_parameters,
            _mmap,
            _context,
        })
    }

    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
        InferenceSession::new(
            config,
            self.hyperparameters.n_ctx,
            self.hyperparameters.n_layer,
            self.hyperparameters.n_embd,
            self.hyperparameters.n_vocab,
        )
    }

    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    ) {
        let n = input_tokens.len();
        let n_threads = params.n_threads;

        let Hyperparameters {
            n_embd,
            n_head,
            n_vocab,
            n_layer,
            n_rot,
            ..
        } = self.hyperparameters;
        let n_ctx = self.n_context_tokens;

        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);

        let n_past = session.n_past;

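        // Fetch the embedding for each input token from the token embedding matrix.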
        let mut input_layer = ctx0.op_get_rows(&self.wte, &embd);

        let memory_k = &session.memory_k;
        let memory_k_size = memory_k.element_size();

        let memory_v = &session.memory_v;
        let memory_v_size = memory_v.element_size();

        let mut gf = ggml::ComputationGraph::new(n_threads);

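        // GPT-J uses a "parallel" block: the attention and feed-forward paths are
        // both computed from the same layer-norm output (`input_sa` below) and
        // summed at the end of the block, rather than applied sequentially.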
        for il in 0..n_layer {
            let mut current = ctx0.op_norm(&input_layer);
            current = ctx0.op_add(
                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_1_g, &current), &current),
                &ctx0.op_repeat(&self.layers[il].ln_1_b, &current),
            );

            let input_sa = current.share();

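            // Self-attention: project the hidden state to queries and keys, and
            // apply rotary position embeddings (RoPE) to both.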
            let qcur = ctx0.op_rope(
                &ctx0.op_reshape_3d(
                    &ctx0.op_mul_mat(&self.layers[il].c_attn_q_proj_w, &current),
                    n_embd / n_head,
                    n_head,
                    n,
                ),
                n_past,
                n_rot,
                0,
            );
            let kcur = ctx0.op_rope(
                &ctx0.op_reshape_3d(
                    &ctx0.op_mul_mat(&self.layers[il].c_attn_k_proj_w, &current),
                    n_embd / n_head,
                    n_head,
                    n,
                ),
                n_past,
                n_rot,
                0,
            );

            let vcur =
                ctx0.op_transpose(&ctx0.op_mul_mat(&self.layers[il].c_attn_v_proj_w, &current));

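            // Write this step's keys and values into the per-layer KV memory so
            // that later tokens can attend to them.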
            let k = ctx0.op_view_1d(
                memory_k,
                n * n_embd,
                (memory_k_size * n_embd) * (il * n_ctx + n_past),
            );
            let v = ctx0.op_view_2d(
                memory_v,
                (n, n_embd),
                n_ctx * memory_v_size,
                (il * n_ctx) * memory_v_size * n_embd + n_past * memory_v_size,
            );

            gf.build_forward_expand(&ctx0.op_cpy(&kcur, &k));
            gf.build_forward_expand(&ctx0.op_cpy(&vcur, &v));

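            // Attention scores: multiply the queries against all n_past + n cached
            // keys, scale by 1/sqrt(d_head), apply the causal mask, then softmax.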
            let q = ctx0.op_permute(&qcur, 0, 2, 1, 3);
            let big_k = ctx0.op_permute(
                &ctx0.op_reshape_3d(
                    &ctx0.op_view_1d(
                        memory_k,
                        (n_past + n) * n_embd,
                        il * n_ctx * memory_k_size * n_embd,
                    ),
                    n_embd / n_head,
                    n_head,
                    n_past + n,
                ),
                0,
                2,
                1,
                3,
            );

            let kq = ctx0.op_mul_mat(&big_k, &q);
            let kq_scaled = ctx0.op_scale(
                &kq,
                &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
            );

            let kq_masked = ctx0.op_diag_mask_inf(&kq_scaled, n_past);
            let kq_softmax = ctx0.op_soft_max(&kq_masked);

            let big_v = ctx0.op_view_3d(
                memory_v,
                (n_past + n, n_embd / n_head, n_head),
                (
                    n_ctx * memory_v_size,
                    n_ctx * memory_v_size * n_embd / n_head,
                ),
                il * n_ctx * memory_v_size * n_embd,
            );

            // weighted sum over the values, then merge the heads back together
            let kqv = ctx0.op_mul_mat(&big_v, &kq_softmax);
            let kqv_merged = ctx0.op_permute(&kqv, 0, 2, 1, 3);

            current = ctx0.op_cpy(&kqv_merged, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n));

            // attention output projection (no bias)
            current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, &current);

            let ff_in = current.share();

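            // Feed-forward network, fed from the pre-attention layer-norm output
            // (`input_sa`): the parallel half of the GPT-J block.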
            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, &input_sa);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_fc_b, &current),
                &current,
            );

            current = ctx0.op_gelu(&current);

            // feed-forward output projection
            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_proj_b, &current),
                &current,
            );

            // sum the feed-forward and attention outputs, then add the residual
            current = ctx0.op_add(&current, &ff_in);

            input_layer = ctx0.op_add(&current, &input_layer);
        }
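        // Final layer norm.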
        input_layer = ctx0.op_norm(&input_layer);
        input_layer = ctx0.op_add(
            &ctx0.op_mul(&ctx0.op_repeat(&self.ln_f_g, &input_layer), &input_layer),
            &ctx0.op_repeat(&self.ln_f_b, &input_layer),
        );

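        // Language model head: project the hidden states back to vocabulary logits.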
        input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer);
        input_layer = ctx0.op_add(&ctx0.op_repeat(&self.lmh_b, &input_layer), &input_layer);

        // run the computation graph
        gf.build_forward_expand(&input_layer);
        ctx0.graph_compute(&mut gf);

        common::read_last_token(session, &input_layer, n_vocab, n);
        common::extract_logits(output_request, &input_layer, n_vocab, n);
        common::extract_embeddings(output_request, &embd, n_embd, n);
        common::update_session(session, &ctx0, input_tokens.len(), n);
    }

    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn n_context_tokens(&self) -> usize {
        self.hyperparameters.n_ctx
    }

    fn bot_token_id(&self) -> Option<TokenId> {
        None
    }

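    // GPT-J uses the GPT-2 tokenizer, whose end-of-text token is `<|endoftext|>`.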
    fn eot_token_id(&self) -> TokenId {
        self.vocabulary
            .token_to_id
            .get("<|endoftext|>".as_bytes())
            .copied()
            .unwrap()
    }

    fn inference_parameters(&self) -> &InferenceParameters {
        &self.inference_parameters
    }
}

/// The hyperparameters of the GPT-J model.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Hyperparameters {
    /// Size of the model's vocabulary
    pub n_vocab: usize,
    /// Size of the model's context window
    pub n_ctx: usize,
    /// Size of the model's embedding layer
    pub n_embd: usize,
    /// Number of attention heads
    pub n_head: usize,
    /// Number of transformer layers
    pub n_layer: usize,
    /// Number of rotary embedding dimensions
    pub n_rot: usize,
    /// The file type of the model weights
    pub file_type: FileType,
}
impl llm_base::Hyperparameters for Hyperparameters {
    fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
        let hyperparameters = Hyperparameters {
            n_vocab: util::read_i32(reader)?.try_into()?,
            n_ctx: util::read_i32(reader)?.try_into()?,
            n_embd: util::read_i32(reader)?.try_into()?,
            n_head: util::read_i32(reader)?.try_into()?,
            n_layer: util::read_i32(reader)?.try_into()?,
            n_rot: util::read_i32(reader)?.try_into()?,
            file_type: {
                let ftype = util::read_i32(reader)?;
                FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?
            },
        };

        // the vocabulary size is stored twice; verify that both copies agree
        let n_vocab = util::read_i32(reader)? as usize;
        if hyperparameters.n_vocab != n_vocab {
            return Err(LoadError::InvariantBroken {
                path: None,
                invariant: format!(
                    "GPT-J model expected n_vocab {} found {}",
                    hyperparameters.n_vocab, n_vocab
                ),
            });
        }

        Ok(hyperparameters)
    }

    fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
        util::write_i32(writer, self.n_vocab.try_into()?)?;
        util::write_i32(writer, self.n_ctx.try_into()?)?;
        util::write_i32(writer, self.n_embd.try_into()?)?;
        util::write_i32(writer, self.n_head.try_into()?)?;
        util::write_i32(writer, self.n_layer.try_into()?)?;
        util::write_i32(writer, self.n_rot.try_into()?)?;
        util::write_i32(writer, self.file_type.into())?;
        Ok(())
    }

    fn n_vocabulary(&self) -> usize {
        self.n_vocab
    }
}

struct Layer {
    // normalization
    ln_1_g: Tensor,
    ln_1_b: Tensor,

    // attention
    c_attn_q_proj_w: Tensor,
    c_attn_k_proj_w: Tensor,
    c_attn_v_proj_w: Tensor,

    c_attn_proj_w: Tensor,

    // feed-forward
    c_mlp_fc_w: Tensor,
    c_mlp_fc_b: Tensor,

    c_mlp_proj_w: Tensor,
    c_mlp_proj_b: Tensor,
}

#[cfg(test)]
impl GptJ {
    /// This does *not* construct a valid model. All of the tensors are
    /// empty. However, it is useful for testing that the model can be
    /// sent and shared between threads.
    fn new_empty() -> Self {
        let context = ggml::Context::init(1024 * 1024, true);

        Self {
            hyperparameters: Default::default(),
            n_context_tokens: 0,
            vocabulary: Default::default(),
            ln_f_g: context.new_f32(0.0),
            ln_f_b: context.new_f32(0.0),
            wte: context.new_f32(0.0),
            lmh_g: context.new_f32(0.0),
            lmh_b: context.new_f32(0.0),
            layers: Default::default(),
            inference_parameters: Default::default(),
            _mmap: Default::default(),
            _context: context,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn can_share_model_between_threads() {
        let model = Arc::new(GptJ::new_empty());

        for _ in 0..4 {
            let model = model.clone();
            std::thread::spawn(move || {
                let _session = model.start_session(Default::default());
            });
        }

        let session = model.start_session(Default::default());
        std::thread::spawn(move || {
            let _session = session;
        });
    }
}