llm_base/model/
common.rs

1use ggml::{Context, Tensor};
2
3use crate::{InferenceSession, OutputRequest, TokenId};
4
5/// Common code to prepare a model to evaluate input
6pub fn prepare_for_evaluate(
7    n_layer: usize,
8    session: &mut InferenceSession,
9    input_tokens: &[TokenId],
10) -> (Context, Tensor) {
11    // For the first run, we need to guess a maximum buffer size so we can measure
12    // the actual memory consumption of the temporary ggml context.
13    //
14    // These numbers are from `llama.cpp`, and could potentially be more efficient.
15    let mut buf_size = {
16        let buf_size_mb = if n_layer >= 80 {
17            1536
18        } else if n_layer >= 60 {
19            1280
20        } else {
21            1024
22        };
23        buf_size_mb * 1024 * 1024
24    };
25
26    let n = input_tokens.len();
27    if session.mem_per_token > 0 && session.mem_per_token * n > buf_size {
28        // add 10% to account for ggml object overhead
29        buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize;
30    };
31    let ctx0 = ggml::Context::init(buf_size, true);
32
33    let mut embd = ctx0.new_tensor_1d(ggml::Type::I32, n);
34    unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) };
35
36    (ctx0, embd)
37}
38
39/// Return result for just the last token
40pub fn read_last_token(
41    session: &mut InferenceSession,
42    input_layer: &Tensor,
43    n_vocab: usize,
44    n: usize,
45) {
46    assert_eq!(session.last_logits.len(), n_vocab);
47    unsafe {
48        input_layer.read_data(
49            n_vocab * (n - 1) * std::mem::size_of::<f32>(),
50            bytemuck::cast_slice_mut(&mut session.last_logits),
51        )
52    };
53}
54
55/// Extract logits from [OutputRequest] evaluation
56pub fn extract_logits(
57    output_request: &mut OutputRequest,
58    input_layer: &Tensor,
59    n_vocab: usize,
60    n: usize,
61) {
62    if let Some(all_logits) = &mut output_request.all_logits {
63        all_logits.resize(n_vocab * n, 0.0);
64        // SAFETY: Tensor data can be read (properly aligned, initialized,
65        // data will not be mutated or otherwise aliased during the copy),
66        // and we're not reading past the end of the tensor data.
67        assert_eq!(input_layer.nelements(), n_vocab * n);
68        unsafe {
69            input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits));
70        }
71    }
72}
73
74/// Extract embeddings from [OutputRequest] evaluation
75pub fn extract_embeddings(
76    output_request: &mut OutputRequest,
77    embd: &Tensor,
78    n_embd: usize,
79    n: usize,
80) {
81    // Extract embeddings
82    if let Some(embeddings) = &mut output_request.embeddings {
83        embeddings.resize(n_embd * n, 0.0);
84        // SAFETY: Same rationale as for the "Extract logits" section applies.
85        assert_eq!(embd.nelements(), n_embd * n);
86        unsafe {
87            embd.read_data(0, bytemuck::cast_slice_mut(embeddings));
88        }
89    }
90}
91
92/// Update an [InferenceSession] after evaluation
93pub fn update_session(session: &mut InferenceSession, ctx0: &Context, n_input: usize, n: usize) {
94    // Adjust the required memory per token if we didn't know that already
95    if session.mem_per_token == 0 {
96        session.mem_per_token = ctx0.used_mem() / n;
97    }
98
99    // Adjust n_past to new length.
100    session.n_past += n_input;
101}