1use ggml::{Context, Tensor};
2
3use crate::{InferenceSession, OutputRequest, TokenId};
4
5pub fn prepare_for_evaluate(
7 n_layer: usize,
8 session: &mut InferenceSession,
9 input_tokens: &[TokenId],
10) -> (Context, Tensor) {
11 let mut buf_size = {
16 let buf_size_mb = if n_layer >= 80 {
17 1536
18 } else if n_layer >= 60 {
19 1280
20 } else {
21 1024
22 };
23 buf_size_mb * 1024 * 1024
24 };
25
26 let n = input_tokens.len();
27 if session.mem_per_token > 0 && session.mem_per_token * n > buf_size {
28 buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize;
30 };
31 let ctx0 = ggml::Context::init(buf_size, true);
32
33 let mut embd = ctx0.new_tensor_1d(ggml::Type::I32, n);
34 unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) };
35
36 (ctx0, embd)
37}
38
39pub fn read_last_token(
41 session: &mut InferenceSession,
42 input_layer: &Tensor,
43 n_vocab: usize,
44 n: usize,
45) {
46 assert_eq!(session.last_logits.len(), n_vocab);
47 unsafe {
48 input_layer.read_data(
49 n_vocab * (n - 1) * std::mem::size_of::<f32>(),
50 bytemuck::cast_slice_mut(&mut session.last_logits),
51 )
52 };
53}
54
55pub fn extract_logits(
57 output_request: &mut OutputRequest,
58 input_layer: &Tensor,
59 n_vocab: usize,
60 n: usize,
61) {
62 if let Some(all_logits) = &mut output_request.all_logits {
63 all_logits.resize(n_vocab * n, 0.0);
64 assert_eq!(input_layer.nelements(), n_vocab * n);
68 unsafe {
69 input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits));
70 }
71 }
72}
73
74pub fn extract_embeddings(
76 output_request: &mut OutputRequest,
77 embd: &Tensor,
78 n_embd: usize,
79 n: usize,
80) {
81 if let Some(embeddings) = &mut output_request.embeddings {
83 embeddings.resize(n_embd * n, 0.0);
84 assert_eq!(embd.nelements(), n_embd * n);
86 unsafe {
87 embd.read_data(0, bytemuck::cast_slice_mut(embeddings));
88 }
89 }
90}
91
92pub fn update_session(session: &mut InferenceSession, ctx0: &Context, n_input: usize, n: usize) {
94 if session.mem_per_token == 0 {
96 session.mem_per_token = ctx0.used_mem() / n;
97 }
98
99 session.n_past += n_input;
101}