use ggml::{Context, Tensor};
use crate::{EvaluateOutputRequest, InferenceSession, TokenId};
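
/// Set up the temporary ggml [`Context`] used for a single evaluation call,
/// and upload `input_tokens` into a 1D `i32` tensor ready for the model's
/// graph to consume.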
pub fn prepare_for_evaluate(
    n_layer: usize,
    session: &mut InferenceSession,
    input_tokens: &[TokenId],
) -> (Context, Tensor) {
    // For the first run we have to guess an upper bound for the scratch
    // buffer, since `mem_per_token` has not been measured yet. These sizes
    // follow the per-layer-count heuristics used by llama.cpp.
    let mut buf_size = {
        let buf_size_mb = if n_layer >= 80 {
            1536
        } else if n_layer >= 60 {
            1280
        } else {
            1024
        };
        buf_size_mb * 1024 * 1024
    };
    let n = input_tokens.len();
    if session.mem_per_token > 0 && session.mem_per_token * n > buf_size {
        // Once a measurement exists, size the buffer to fit, with 10%
        // headroom to account for ggml object overhead.
        buf_size = (1.1f64 * session.mem_per_token as f64 * n as f64) as usize;
    }

    let ctx0 = ggml::Context::init(buf_size, true);

    // Copy the token ids into a 1D i32 tensor for the graph to embed.
    let embd = ctx0.new_tensor_1d(ggml::Type::I32, n);
    unsafe { embd.write_data(bytemuck::cast_slice(input_tokens)) };

    (ctx0, embd)
}
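
/// Copy the logits for the final token of the batch out of `input_layer`
/// into `session.last_logits`. `input_layer` is expected to hold `n_vocab`
/// `f32` logits for each of the `n` evaluated tokens.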
pub fn read_last_token(
    session: &mut InferenceSession,
    input_layer: &Tensor,
    n_vocab: usize,
    n: usize,
) {
    assert_eq!(session.last_logits.len(), n_vocab);
    unsafe {
        // Skip the first `n - 1` rows and read only the last token's logits.
        input_layer.read_data(
            n_vocab * (n - 1) * std::mem::size_of::<f32>(),
            bytemuck::cast_slice_mut(&mut session.last_logits),
        )
    };
}
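
/// If the caller requested them via `output_request.all_logits`, copy the
/// full `n_vocab * n` logits matrix out of `input_layer`.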
pub fn extract_logits(
    output_request: &mut EvaluateOutputRequest,
    input_layer: &Tensor,
    n_vocab: usize,
    n: usize,
) {
    if let Some(all_logits) = &mut output_request.all_logits {
        all_logits.resize(n_vocab * n, 0.0);
        assert_eq!(input_layer.nelements(), n_vocab * n);
        unsafe {
            input_layer.read_data(0, bytemuck::cast_slice_mut(all_logits));
        }
    }
}
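
/// If the caller requested them via `output_request.embeddings`, copy the
/// `n_embd * n` embedding values out of `embd`.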
pub fn extract_embeddings(
    output_request: &mut EvaluateOutputRequest,
    embd: &Tensor,
    n_embd: usize,
    n: usize,
) {
    if let Some(embeddings) = &mut output_request.embeddings {
        embeddings.resize(n_embd * n, 0.0);
        assert_eq!(embd.nelements(), n_embd * n);
        unsafe {
            embd.read_data(0, bytemuck::cast_slice_mut(embeddings));
        }
    }
}
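
/// Record bookkeeping after an evaluation pass: measure `mem_per_token` on
/// the first run and advance the session's token position.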
pub fn update_session(session: &mut InferenceSession, ctx0: &Context, n_input: usize, n: usize) {
    // On the first evaluation, derive the per-token memory cost from the
    // scratch context so later calls can size their buffers exactly.
    if session.mem_per_token == 0 {
        session.mem_per_token = ctx0.used_mem() / n;
    }
    session.n_past += n_input;
}
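
// A minimal sketch of how these helpers are intended to compose inside a
// model's `evaluate` implementation. `build_graph`, `logits_tensor`,
// `embeddings_tensor`, and the hyperparameter names below are placeholders
// for the model-specific graph code, not part of this module:
//
//     let (ctx0, embd) = prepare_for_evaluate(n_layer, session, input_tokens);
//     let n = input_tokens.len();
//     // ... build and compute the ggml graph from `embd` (hypothetical) ...
//     let (logits_tensor, embeddings_tensor) = build_graph(&ctx0, &embd, session);
//     read_last_token(session, &logits_tensor, n_vocab, n);
//     extract_logits(output_request, &logits_tensor, n_vocab, n);
//     extract_embeddings(output_request, &embeddings_tensor, n_embd, n);
//     update_session(session, &ctx0, input_tokens.len(), n);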