#![deny(missing_docs)]
//! An implementation of the GPT-NeoX model for the `llm` ecosystem.

use std::{error::Error, path::Path};

use ggml::Tensor;
use llm_base::{
    ggml,
    model::{common, HyperparametersWriteError},
    util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
    LoadError, LoadProgress, Mmap, ModelParameters, OutputRequest, TensorLoader, TokenId,
    Vocabulary,
};

/// The GPT-NeoX model.
///
/// # Safety
/// This implements [Send] and [Sync] as it is immutable once loaded.
pub struct NeoX {
    hyperparameters: Hyperparameters,
    n_context_tokens: usize,

    vocabulary: Vocabulary,

    // final layer norm gain and bias
    ln_f_g: Tensor,
    ln_f_b: Tensor,

    // token embeddings
    wte: Tensor,

    // language model head
    lmh_g: Tensor,

    layers: Vec<Layer>,

    inference_parameters: InferenceParameters,

    // must be kept alive while the model is alive
    _mmap: Option<Mmap>,
    _context: ggml::Context,
}

// SAFETY: the model is immutable once loaded; see the struct-level docs.
unsafe impl Send for NeoX {}
unsafe impl Sync for NeoX {}

impl NeoX {
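    /// Load a GPT-NeoX model from the GGML file at `path`, reporting progress
    /// through `load_progress_callback`.
    ///
    /// A minimal usage sketch, not a tested example: the path is a placeholder,
    /// and it assumes `ModelParameters` and `InferenceSessionConfig` implement
    /// `Default`.
    ///
    /// ```ignore
    /// let model = NeoX::load(
    ///     std::path::Path::new("path/to/model.bin"),
    ///     Default::default(),
    ///     |_progress| {},
    /// )?;
    /// let mut session = model.start_session(Default::default());
    /// ```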
    pub fn load(
        path: &Path,
        params: ModelParameters,
        load_progress_callback: impl FnMut(LoadProgress),
    ) -> Result<NeoX, LoadError> {
        llm_base::load(path, params, load_progress_callback)
    }
}

impl KnownModel for NeoX {
    type Hyperparameters = Hyperparameters;

    fn new<E: Error>(
        hyperparameters: Self::Hyperparameters,
        params: ModelParameters,
        vocabulary: Vocabulary,
        tensor_loader: impl TensorLoader<E>,
    ) -> Result<Self, E>
    where
        Self: Sized,
    {
        let mut tl = tensor_loader;

        // model-global weights
        let wte = tl.load("gpt_neox.embed_in.weight")?;
        let ln_f_g = tl.load("gpt_neox.final_layer_norm.weight")?;
        let ln_f_b = tl.load("gpt_neox.final_layer_norm.bias")?;
        let lmh_g = tl.load("embed_out.weight")?;

        let mut layers = Vec::new();
        for i in 0..hyperparameters.n_layer {
            let layer = Layer {
                ln_1_g: tl.load(&format!("gpt_neox.layers.{i}.input_layernorm.weight"))?,
                ln_1_b: tl.load(&format!("gpt_neox.layers.{i}.input_layernorm.bias"))?,

                c_attn_attn_w: tl.load(&format!(
                    "gpt_neox.layers.{i}.attention.query_key_value.weight"
                ))?,
                c_attn_attn_b: tl.load(&format!(
                    "gpt_neox.layers.{i}.attention.query_key_value.bias"
                ))?,

                c_attn_proj_w: tl.load(&format!("gpt_neox.layers.{i}.attention.dense.weight"))?,
                c_attn_proj_b: tl.load(&format!("gpt_neox.layers.{i}.attention.dense.bias"))?,

                ln_2_g: tl.load(&format!(
                    "gpt_neox.layers.{i}.post_attention_layernorm.weight"
                ))?,
                ln_2_b: tl.load(&format!(
                    "gpt_neox.layers.{i}.post_attention_layernorm.bias"
                ))?,

                c_mlp_fc_w: tl.load(&format!("gpt_neox.layers.{i}.mlp.dense_h_to_4h.weight"))?,
                c_mlp_fc_b: tl.load(&format!("gpt_neox.layers.{i}.mlp.dense_h_to_4h.bias"))?,

                c_mlp_proj_w: tl.load(&format!("gpt_neox.layers.{i}.mlp.dense_4h_to_h.weight"))?,
                c_mlp_proj_b: tl.load(&format!("gpt_neox.layers.{i}.mlp.dense_4h_to_h.bias"))?,
            };

            layers.push(layer);
        }

        let (_context, _, _mmap) = tl.finish();

        let ModelParameters {
            n_context_tokens,
            inference_parameters,
            ..
        } = params;

        Ok(NeoX {
            hyperparameters,
            n_context_tokens,
            vocabulary,
            ln_f_g,
            ln_f_b,
            wte,
            lmh_g,
            layers,
            inference_parameters,
            _context,
            _mmap,
        })
    }

    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
        InferenceSession::new(
            config,
            self.hyperparameters.n_ctx,
            self.hyperparameters.n_layer,
            self.hyperparameters.n_embd,
            self.hyperparameters.n_vocab,
        )
    }

    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    ) {
        let n = input_tokens.len();
        let n_threads = params.n_threads;

        let Hyperparameters {
            n_embd,
            n_head,
            n_vocab,
            n_layer,
            n_rot,
            ..
        } = self.hyperparameters;
        let n_ctx = self.n_context_tokens;

        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);

        let n_past = session.n_past;

        // wte: look up the token embedding for each input token
        let mut input_layer = ctx0.op_get_rows(&self.wte, &embd);

        // per-layer key and value memory of the attention cache
        let memory_k = &session.memory_k;
        let memory_k_size = memory_k.element_size();

        let memory_v = &session.memory_v;
        let memory_v_size = memory_v.element_size();

        let mut gf = ggml::ComputationGraph::new(n_threads);

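        // Each transformer block below follows GPT-NeoX's parallel-residual
        // layout: the attention branch and the feed-forward branch are both
        // computed from the layer input and summed back into the residual
        // stream at the end of the block.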
        for il in 0..n_layer {
            // input layer norm
            let mut current = ctx0.op_norm(&input_layer);
            current = ctx0.op_add(
                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_1_g, &current), &current),
                &ctx0.op_repeat(&self.layers[il].ln_1_b, &current),
            );

            // self-attention: fused query/key/value projection
            current = ctx0.op_mul_mat(&self.layers[il].c_attn_attn_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_attn_attn_b, &current),
                &current,
            );

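            // The fused QKV activation stores the query, key and value for
            // each head contiguously, so each is extracted as a strided 3D
            // view (d_head x n_head x n) at a different byte offset.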
            let nb = current.get_nb()[1];
            let f32_size = std::mem::size_of::<f32>();
            let mut qcur = ctx0.op_cont(&ctx0.op_view_3d(
                &current,
                (n_embd / n_head, n_head, n),
                (nb / n_head, nb),
                0,
            ));
            let mut kcur = ctx0.op_cont(&ctx0.op_view_3d(
                &current,
                (n_embd / n_head, n_head, n),
                (nb / n_head, nb),
                f32_size * n_embd / n_head,
            ));
            let mut vcur = ctx0.op_cont(&ctx0.op_view_3d(
                &current,
                (n_embd / n_head, n_head, n),
                (nb / n_head, nb),
                2 * f32_size * n_embd / n_head,
            ));

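            // apply rotary position embeddings (RoPE) to the first n_rot
            // dimensions of each query and key head, offset by n_past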
            qcur = ctx0.op_rope(&qcur, n_past, n_rot, 2);
            kcur = ctx0.op_rope(&kcur, n_past, n_rot, 2);

            vcur = ctx0.op_transpose(&ctx0.op_reshape_2d(&vcur, n_embd, n));

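            // copy the new keys and values into this layer's slot of the
            // session cache, starting at position n_past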
            let little_k = ctx0.op_view_1d(
                memory_k,
                n * n_embd,
                (memory_k_size * n_embd) * (il * n_ctx + n_past),
            );
            let little_v = ctx0.op_view_2d(
                memory_v,
                (n, n_embd),
                n_ctx * memory_v_size,
                (il * n_ctx) * memory_v_size * n_embd + n_past * memory_v_size,
            );

            gf.build_forward_expand(&ctx0.op_cpy(&kcur, &little_k));
            gf.build_forward_expand(&ctx0.op_cpy(&vcur, &little_v));

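            // attention scores: Q . K^T over all n_past + n positions, scaled
            // by 1/sqrt(d_head), causally masked, then softmaxed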
            let q = ctx0.op_permute(&qcur, 0, 2, 1, 3);
            let big_k = ctx0.op_permute(
                &ctx0.op_reshape_3d(
                    &ctx0.op_view_1d(
                        memory_k,
                        (n_past + n) * n_embd,
                        il * n_ctx * memory_k_size * n_embd,
                    ),
                    n_embd / n_head,
                    n_head,
                    n_past + n,
                ),
                0,
                2,
                1,
                3,
            );

            let kq = ctx0.op_mul_mat(&big_k, &q);
            let kq_scaled = ctx0.op_scale(
                &kq,
                &ctx0.new_f32(1f32 / f32::sqrt(n_embd as f32 / n_head as f32)),
            );

            let kq_masked = ctx0.op_diag_mask_inf(&kq_scaled, n_past);
            let kq_softmax = ctx0.op_soft_max(&kq_masked);

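            // weight the cached values by the attention probabilities, then
            // merge the heads back into one n_embd-wide activation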
            let big_v = ctx0.op_view_3d(
                memory_v,
                (n_past + n, n_embd / n_head, n_head),
                (
                    n_ctx * memory_v_size,
                    n_ctx * memory_v_size * n_embd / n_head,
                ),
                il * n_ctx * memory_v_size * n_embd,
            );

            let kqv = ctx0.op_mul_mat(&big_v, &kq_softmax);
            let kqv_merged = ctx0.op_permute(&kqv, 0, 2, 1, 3);

            current = ctx0.op_cpy(&kqv_merged, &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n));

            // attention output projection
            current = ctx0.op_mul_mat(&self.layers[il].c_attn_proj_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_attn_proj_b, &current),
                &current,
            );

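            // GPT-NeoX parallel residual: stash the attention branch, run the
            // feed-forward branch on the same layer input, then add both back
            // into the residual stream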
            let ff_in = current.share();

            // post-attention layer norm, applied to the layer input
            current = ctx0.op_norm(&input_layer);
            current = ctx0.op_add(
                &ctx0.op_mul(&ctx0.op_repeat(&self.layers[il].ln_2_g, &current), &current),
                &ctx0.op_repeat(&self.layers[il].ln_2_b, &current),
            );

            // feed-forward: expand to 4 * n_embd, GELU, project back to n_embd
            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_fc_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_fc_b, &current),
                &current,
            );

            current = ctx0.op_gelu(&current);

            current = ctx0.op_mul_mat(&self.layers[il].c_mlp_proj_w, &current);
            current = ctx0.op_add(
                &ctx0.op_repeat(&self.layers[il].c_mlp_proj_b, &current),
                &current,
            );

            // sum the attention and feed-forward branches...
            current = ctx0.op_add(&current, &ff_in);

            // ...and add the result to the residual stream
            input_layer = ctx0.op_add(&current, &input_layer);
        }
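        // final layer norm followed by the language-model head, producing one
        // logit per vocabulary entry for each token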
        input_layer = ctx0.op_norm(&input_layer);
        input_layer = ctx0.op_add(
            &ctx0.op_mul(&ctx0.op_repeat(&self.ln_f_g, &input_layer), &input_layer),
            &ctx0.op_repeat(&self.ln_f_b, &input_layer),
        );

        input_layer = ctx0.op_mul_mat(&self.lmh_g, &input_layer);

        // run the graph, then read the results back out of it
        gf.build_forward_expand(&input_layer);
        ctx0.graph_compute(&mut gf);

        common::read_last_token(session, &input_layer, n_vocab, n);
        common::extract_logits(output_request, &input_layer, n_vocab, n);
        common::extract_embeddings(output_request, &embd, n_embd, n);
        common::update_session(session, &ctx0, input_tokens.len(), n);
    }

    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn n_context_tokens(&self) -> usize {
        self.hyperparameters.n_ctx
    }

    fn bot_token_id(&self) -> Option<TokenId> {
        None
    }

    fn eot_token_id(&self) -> TokenId {
        self.vocabulary
            .token_to_id
            .get("<|endoftext|>".as_bytes())
            .copied()
            .unwrap()
    }

    fn inference_parameters(&self) -> &InferenceParameters {
        &self.inference_parameters
    }
}

/// The hyperparameters of the model.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Hyperparameters {
    /// Size of the model's vocabulary
    pub n_vocab: usize,
    /// Size of the model's context window
    pub n_ctx: usize,
    /// Size of the model's embedding layer
    pub n_embd: usize,
    /// Number of attention heads
    pub n_head: usize,
    /// Number of layers in the model
    pub n_layer: usize,
    /// Number of rotary dimensions used by RoPE
    pub n_rot: usize,
    /// The file type of the model weights
    pub file_type: FileType,
}

impl llm_base::Hyperparameters for Hyperparameters {
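    // The hyperparameters are serialized as a fixed sequence of little-endian
    // i32 values; read_ggml and write_ggml must agree on the field order.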
    fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
        Ok(Hyperparameters {
            n_vocab: util::read_i32(reader)?.try_into()?,
            n_ctx: util::read_i32(reader)?.try_into()?,
            n_embd: util::read_i32(reader)?.try_into()?,
            n_head: util::read_i32(reader)?.try_into()?,
            n_layer: util::read_i32(reader)?.try_into()?,
            n_rot: util::read_i32(reader)?.try_into()?,
            file_type: {
                let ftype = util::read_i32(reader)?;
                FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?
            },
        })
    }

    fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
        util::write_i32(writer, self.n_vocab.try_into()?)?;
        util::write_i32(writer, self.n_ctx.try_into()?)?;
        util::write_i32(writer, self.n_embd.try_into()?)?;
        util::write_i32(writer, self.n_head.try_into()?)?;
        util::write_i32(writer, self.n_layer.try_into()?)?;
        util::write_i32(writer, self.n_rot.try_into()?)?;
        util::write_i32(writer, self.file_type.into())?;
        Ok(())
    }

    fn n_vocabulary(&self) -> usize {
        self.n_vocab
    }
}

struct Layer {
    // pre-attention layer norm
    ln_1_g: Tensor,
    ln_1_b: Tensor,

    // fused query/key/value projection
    c_attn_attn_w: Tensor,
    c_attn_attn_b: Tensor,

    // attention output projection
    c_attn_proj_w: Tensor,
    c_attn_proj_b: Tensor,

    // post-attention layer norm
    ln_2_g: Tensor,
    ln_2_b: Tensor,

    // feed-forward up-projection
    c_mlp_fc_w: Tensor,
    c_mlp_fc_b: Tensor,

    // feed-forward down-projection
    c_mlp_proj_w: Tensor,
    c_mlp_proj_b: Tensor,
}

#[cfg(test)]
impl NeoX {
    // A test-only constructor that produces a structurally valid model with
    // empty weights.
    fn new_empty() -> Self {
        let context = ggml::Context::init(1024 * 1024, true);

        Self {
            hyperparameters: Default::default(),
            n_context_tokens: 0,
            vocabulary: Default::default(),
            ln_f_g: context.new_f32(0.0),
            ln_f_b: context.new_f32(0.0),
            wte: context.new_f32(0.0),
            lmh_g: context.new_f32(0.0),
            layers: Default::default(),
            inference_parameters: Default::default(),
            _mmap: Default::default(),
            _context: context,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn can_share_model_between_threads() {
        let model = Arc::new(NeoX::new_empty());

        for _ in 0..4 {
            let model = model.clone();
            std::thread::spawn(move || {
                let _session = model.start_session(Default::default());
            });
        }

        let session = model.start_session(Default::default());
        std::thread::spawn(move || {
            let _session = session;
        });
    }
}