//! An implementation of [LLaMA](https://github.com/facebookresearch/llama) for the `llm` ecosystem.
#![deny(missing_docs)]

use std::{error::Error, path::Path};

use llm_base::{
    ggml,
    model::{common, HyperparametersWriteError},
    util, FileType, InferenceParameters, InferenceSession, InferenceSessionConfig, KnownModel,
    LoadError, LoadProgress, Mmap, ModelParameters, OutputRequest, TensorLoader, TokenId,
    Vocabulary,
};

/// Conversion of LLaMA models (only available with the `convert` feature).
#[cfg(feature = "convert")]
pub mod convert;

mod old_loader;

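/// The LLaMA model. The weights are immutable once loaded; all mutable
/// inference state lives in the [InferenceSession]s created from the model.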
pub struct Llama {
    hyperparameters: Hyperparameters,
    n_context_tokens: usize,

    vocabulary: Vocabulary,

    tok_embeddings: ggml::Tensor,

    norm: ggml::Tensor,
    output: ggml::Tensor,

    layers: Vec<Layer>,

    inference_parameters: InferenceParameters,

    /// Needs to be kept alive while the model exists: the tensors may be
    /// backed by this memory map.
    _mmap: Option<Mmap>,

    /// Needs to be kept alive while the model exists: it owns the tensors.
    _context: ggml::Context,
}

// SAFETY: the weights are never mutated after loading, and all mutable
// inference state lives in `InferenceSession`, so sharing the model across
// threads is sound.
unsafe impl Send for Llama {}
unsafe impl Sync for Llama {}

impl Llama {
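    /// Load the model from `path`, reporting progress through
    /// `load_progress_callback`.
    ///
    /// A minimal usage sketch; the path is a placeholder, and this assumes
    /// the crate is named `llm_llama` and that `ModelParameters` implements
    /// `Default`:
    ///
    /// ```no_run
    /// use std::path::Path;
    /// use llm_llama::Llama;
    ///
    /// let model = Llama::load(
    ///     Path::new("models/llama.bin"), // placeholder path
    ///     Default::default(),
    ///     |_progress| {}, // ignore progress updates
    /// )
    /// .expect("failed to load model");
    /// ```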
    pub fn load(
        path: &Path,
        params: ModelParameters,
        load_progress_callback: impl FnMut(LoadProgress),
    ) -> Result<Llama, LoadError> {
        llm_base::load(path, params, load_progress_callback)
    }
}

impl KnownModel for Llama {
    type Hyperparameters = Hyperparameters;

    fn new<E: Error>(
        hyperparameters: Self::Hyperparameters,
        params: ModelParameters,
        vocabulary: Vocabulary,
        tensor_loader: impl TensorLoader<E>,
    ) -> Result<Self, E> {
        let mut tl = tensor_loader;

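        // Shared weights: the token embedding matrix, the final RMSNorm, and
        // the output (LM head) projection.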
        let tok_embeddings = tl.load("tok_embeddings.weight")?;
        let norm = tl.load("norm.weight")?;
        let output = tl.load("output.weight")?;

        // Per-layer weights, loaded by their GGML tensor names.
        let mut layers = Vec::new();
        for i in 0..hyperparameters.n_layer {
            let layer = Layer {
                attention_norm: tl.load(&format!("layers.{i}.attention_norm.weight"))?,
                wq: tl.load(&format!("layers.{i}.attention.wq.weight"))?,
                wk: tl.load(&format!("layers.{i}.attention.wk.weight"))?,
                wv: tl.load(&format!("layers.{i}.attention.wv.weight"))?,
                wo: tl.load(&format!("layers.{i}.attention.wo.weight"))?,
                ffn_norm: tl.load(&format!("layers.{i}.ffn_norm.weight"))?,
                w1: tl.load(&format!("layers.{i}.feed_forward.w1.weight"))?,
                w2: tl.load(&format!("layers.{i}.feed_forward.w2.weight"))?,
                w3: tl.load(&format!("layers.{i}.feed_forward.w3.weight"))?,
            };

            layers.push(layer);
        }

        let (_context, _tensors, _mmap) = tl.finish();

        let ModelParameters {
            n_context_tokens,
            inference_parameters,
            ..
        } = params;

        Ok(Self {
            hyperparameters,
            n_context_tokens,
            vocabulary,
            tok_embeddings,
            norm,
            output,
            layers,
            inference_parameters,
            _context,
            _mmap,
        })
    }

    fn start_session(&self, config: InferenceSessionConfig) -> InferenceSession {
        InferenceSession::new(
            config,
            self.n_context_tokens,
            self.hyperparameters.n_layer,
            self.hyperparameters.n_embd,
            self.hyperparameters.n_vocab,
        )
    }

    fn evaluate(
        &self,
        session: &mut InferenceSession,
        params: &InferenceParameters,
        input_tokens: &[TokenId],
        output_request: &mut OutputRequest,
    ) {
        let n = input_tokens.len();
        let n_past = session.n_past;
        let n_threads = params.n_threads;

        let memk_elsize = session.memory_k.element_size();
        let memv_elsize = session.memory_v.element_size();

        let Hyperparameters {
            n_vocab,
            n_embd,
            n_mult: _,
            n_head,
            n_layer,
            n_rot,
            file_type: _,
        } = self.hyperparameters;
        let n_ctx = self.n_context_tokens;

        let (ctx0, embd) = common::prepare_for_evaluate(n_layer, session, input_tokens);

        let mut input_layer = ctx0.op_get_rows(&self.tok_embeddings, &embd);

        let mut gf = ggml::ComputationGraph::new(n_threads);

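        // Build the graph one transformer layer at a time: RMSNorm ->
        // self-attention (RoPE + KV cache) -> residual add -> RMSNorm ->
        // SwiGLU feed-forward -> residual add.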
        for il in 0..n_layer {
            let input_self_attention = input_layer.share();
            let mut current: ggml::Tensor;

            ctx0.use_scratch(Some(&mut session.scratch[0]));

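            // norm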
            {
                current = ctx0.op_rms_norm(&input_layer);

                // cur = attention_norm * cur
                current = ctx0.op_mul(
                    &ctx0.op_repeat(&self.layers[il].attention_norm, &current),
                    &current,
                );
            }

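            // self-attention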
            {
                // compute the query and key vectors and apply RoPE to them
                let q_current = ctx0.op_rope(
                    &ctx0.op_reshape_3d(
                        &ctx0.op_mul_mat(&self.layers[il].wq, &current),
                        n_embd / n_head,
                        n_head,
                        n,
                    ),
                    n_past,
                    n_rot,
                    0,
                );
                let k_current = ctx0.op_rope(
                    &ctx0.op_reshape_3d(
                        &ctx0.op_mul_mat(&self.layers[il].wk, &current),
                        n_embd / n_head,
                        n_head,
                        n,
                    ),
                    n_past,
                    n_rot,
                    0,
                );

                // store the key and value vectors for this batch in the
                // session's KV memory
                {
                    let v_current = ctx0.op_transpose(&ctx0.op_reshape_2d(
                        &ctx0.op_mul_mat(&self.layers[il].wv, &current),
                        n_embd,
                        n,
                    ));

                    let k = ctx0.op_view_1d(
                        &session.memory_k,
                        n * n_embd,
                        (memk_elsize * n_embd) * (il * n_ctx + n_past),
                    );

                    let v = ctx0.op_view_2d(
                        &session.memory_v,
                        (n, n_embd),
                        n_ctx * memv_elsize,
                        (il * n_ctx) * memv_elsize * n_embd + n_past * memv_elsize,
                    );

                    gf.build_forward_expand(&ctx0.op_cpy(&k_current, &k));
                    gf.build_forward_expand(&ctx0.op_cpy(&v_current, &v));
                }

                // Q = Q_cur.permute(0, 2, 1, 3)
                let q = ctx0.op_permute(&q_current, 0, 2, 1, 3);

                // K = K_mem.view(n_embd/n_head, n_head, n_past + n).permute(0, 2, 1, 3)
                let k = ctx0.op_permute(
                    &ctx0.op_reshape_3d(
                        &ctx0.op_view_1d(
                            &session.memory_k,
                            (n_past + n) * n_embd,
                            il * n_ctx * memk_elsize * n_embd,
                        ),
                        n_embd / n_head,
                        n_head,
                        n_past + n,
                    ),
                    0,
                    2,
                    1,
                    3,
                );

                // K * Q
                let k_q = ctx0.op_mul_mat(&k, &q);

                // KQ_scaled = KQ / sqrt(n_embd / n_head)
                let k_q_scaled = ctx0.op_scale(
                    &k_q,
                    &ctx0.new_f32(1.0 / f32::sqrt(n_embd as f32 / n_head as f32)),
                );

                // KQ_masked = mask_past(KQ_scaled)
                let k_q_masked = ctx0.op_diag_mask_inf(&k_q_scaled, n_past);

                // KQ = soft_max(KQ_masked)
                let k_q_soft_max = ctx0.op_soft_max(&k_q_masked);

                // split the cached V into n_head heads
                let v = ctx0.op_view_3d(
                    &session.memory_v,
                    (n_past + n, n_embd / n_head, n_head),
                    (n_ctx * memv_elsize, n_ctx * memv_elsize * n_embd / n_head),
                    il * n_ctx * memv_elsize * n_embd,
                );

                // KQV = transpose(V) * KQ_soft_max
                let k_q_v = ctx0.op_mul_mat(&v, &k_q_soft_max);

                // KQV_merged = KQV.permute(0, 2, 1, 3)
                let k_q_v_merged = ctx0.op_permute(&k_q_v, 0, 2, 1, 3);

                // cur = KQV_merged.contiguous().view(n_embd, n)
                current = ctx0.op_cpy(
                    &k_q_v_merged,
                    &ctx0.new_tensor_2d(ggml::Type::F32, n_embd, n),
                );

                // projection (no bias)
                current = ctx0.op_mul_mat(&self.layers[il].wo, &current);
            }

            ctx0.use_scratch(Some(&mut session.scratch[1]));

            // first residual connection: add the attention output back to the
            // block's input
            let input_feed_forward = ctx0.op_add(&current, &input_self_attention);

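            // feed-forward network (SwiGLU): w2(silu(w1(x)) * w3(x))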
            {
                // norm
                {
                    current = ctx0.op_rms_norm(&input_feed_forward);

                    // cur = ffn_norm * cur
                    current = ctx0.op_mul(
                        &ctx0.op_repeat(&self.layers[il].ffn_norm, &current),
                        &current,
                    );
                }

                let tmp = ctx0.op_mul_mat(&self.layers[il].w3, &current);

                current = ctx0.op_mul_mat(&self.layers[il].w1, &current);

                // SILU activation
                current = ctx0.op_silu(&current);

                current = ctx0.op_mul(&current, &tmp);

                current = ctx0.op_mul_mat(&self.layers[il].w2, &current);
            }

            // second residual connection
            current = ctx0.op_add(&current, &input_feed_forward);

            // input for the next layer
            input_layer = current;
        }

        ctx0.use_scratch(Some(&mut session.scratch[0]));

        // final norm
        {
            input_layer = ctx0.op_rms_norm(&input_layer);

            input_layer = ctx0.op_mul(&ctx0.op_repeat(&self.norm, &input_layer), &input_layer);
        }

        // lm_head: project back to vocabulary logits
        {
            input_layer = ctx0.op_mul_mat(&self.output, &input_layer);
        }

        ctx0.use_scratch(None);

        // run the computation
        gf.build_forward_expand(&input_layer);
        ctx0.graph_compute(&mut gf);

        // read the last token's logits, extract any requested outputs, and
        // advance the session past the tokens that were just evaluated
        common::read_last_token(session, &input_layer, n_vocab, n);
        common::extract_logits(output_request, &input_layer, n_vocab, n);
        common::extract_embeddings(output_request, &embd, n_embd, n);
        common::update_session(session, &ctx0, input_tokens.len(), n);
    }

    fn vocabulary(&self) -> &Vocabulary {
        &self.vocabulary
    }

    fn n_context_tokens(&self) -> usize {
        self.n_context_tokens
    }

    fn bot_token_id(&self) -> Option<TokenId> {
        None
    }

    fn eot_token_id(&self) -> TokenId {
        // LLaMA's end-of-text token, `</s>`, has ID 2
        2
    }

    fn inference_parameters(&self) -> &InferenceParameters {
        &self.inference_parameters
    }
}

#[cfg(test)]
impl Llama {
    /// Creates an empty model with default hyperparameters, for tests only.
    fn new_empty() -> Self {
        let context = ggml::Context::init(1024 * 1024, true);
        let tok_embeddings = context.new_f32(0.0);
        let norm = context.new_f32(0.0);
        let output = context.new_f32(0.0);

        Self {
            hyperparameters: Default::default(),
            n_context_tokens: 0,
            vocabulary: Default::default(),
            tok_embeddings,
            norm,
            output,
            layers: Default::default(),
            _mmap: Default::default(),
            _context: context,
            inference_parameters: Default::default(),
        }
    }
}

/// The hyperparameters of the model.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
pub struct Hyperparameters {
    /// Size of the model's vocabulary
    pub n_vocab: usize,
    /// Size of the model's embedding layer
    pub n_embd: usize,
    /// Multiplier used when computing the feed-forward layer size
    pub n_mult: usize,
    /// Number of attention heads
    pub n_head: usize,
    /// Number of transformer layers
    pub n_layer: usize,
    /// Number of rotary positional embedding dimensions per head
    pub n_rot: usize,
    /// The format of the model's weights
    pub file_type: FileType,
}

impl llm_base::Hyperparameters for Hyperparameters {
    fn read_ggml(reader: &mut dyn std::io::BufRead) -> Result<Self, LoadError> {
        // Each field is stored in the GGML header as an `i32`, in the same
        // order that `write_ggml` emits them.
        Ok(Hyperparameters {
            n_vocab: util::read_i32(reader)?.try_into()?,
            n_embd: util::read_i32(reader)?.try_into()?,
            n_mult: util::read_i32(reader)?.try_into()?,
            n_head: util::read_i32(reader)?.try_into()?,
            n_layer: util::read_i32(reader)?.try_into()?,
            n_rot: util::read_i32(reader)?.try_into()?,
            file_type: {
                let ftype = util::read_i32(reader)?;
                FileType::try_from(ftype).map_err(|_| LoadError::UnsupportedFileType(ftype))?
            },
        })
    }

    fn write_ggml(&self, writer: &mut dyn std::io::Write) -> Result<(), HyperparametersWriteError> {
        util::write_i32(writer, self.n_vocab.try_into()?)?;
        util::write_i32(writer, self.n_embd.try_into()?)?;
        util::write_i32(writer, self.n_mult.try_into()?)?;
        util::write_i32(writer, self.n_head.try_into()?)?;
        util::write_i32(writer, self.n_layer.try_into()?)?;
        util::write_i32(writer, self.n_rot.try_into()?)?;
        util::write_i32(writer, self.file_type.into())?;
        Ok(())
    }

    fn n_vocabulary(&self) -> usize {
        self.n_vocab
    }
}

/// The weights for one transformer layer.
struct Layer {
    attention_norm: ggml::Tensor,

    // attention weights
    wq: ggml::Tensor,
    wk: ggml::Tensor,
    wv: ggml::Tensor,
    wo: ggml::Tensor,

    ffn_norm: ggml::Tensor,

    // feed-forward weights
    w1: ggml::Tensor,
    w2: ggml::Tensor,
    w3: ggml::Tensor,
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn can_share_model_between_threads() {
        let model = Arc::new(Llama::new_empty());

        for _ in 0..4 {
            let model = model.clone();
            std::thread::spawn(move || {
                let _session = model.start_session(Default::default());
            });
        }

        let session = model.start_session(Default::default());
        std::thread::spawn(move || {
            let _session = session;
        });
    }
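
    // An illustrative round-trip over the GGML header encoding. The
    // hyperparameter values mirror LLaMA-7B; `file_type` is left at its
    // default rather than assuming a specific `FileType` variant.
    #[test]
    fn hyperparameters_roundtrip_through_ggml() {
        use llm_base::Hyperparameters as _;

        let hyperparameters = Hyperparameters {
            n_vocab: 32_000,
            n_embd: 4_096,
            n_mult: 256,
            n_head: 32,
            n_layer: 32,
            n_rot: 128,
            file_type: Default::default(),
        };

        // `write_ggml` stores each field as an `i32`; `read_ggml` must
        // recover identical values from those bytes.
        let mut buffer = Vec::new();
        hyperparameters.write_ggml(&mut buffer).unwrap();
        let read_back = Hyperparameters::read_ggml(&mut buffer.as_slice()).unwrap();
        assert_eq!(hyperparameters, read_back);
    }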
}