use std::fmt::{Debug, Formatter};
use std::num::NonZeroI32;
use crate::llama_backend::LlamaBackend;
use crate::llama_batch::LlamaBatch;
use crate::model::LlamaModel;
use crate::timing::LlamaTimings;
use crate::token::data::LlamaTokenData;
use crate::token::LlamaToken;
use crate::{DecodeError, LlamaContextLoadError};
use params::LlamaContextParams;
use std::os::raw::c_int;
use std::ptr::NonNull;
use std::slice;
pub mod kv_cache;
pub mod params;
pub mod sample;
/// Safe(r) wrapper around a `llama.cpp` context, borrowed from the [`LlamaModel`]
/// it was created from (the `'a` lifetime ensures the model outlives the context).
#[allow(clippy::module_name_repetitions)]
pub struct LlamaContext<'a> {
    // Raw non-null pointer to the underlying llama.cpp context; freed in `Drop`.
    pub(crate) context: NonNull<llama_cpp_sys_2::llama_context>,
    /// The model this context was created from.
    pub model: &'a LlamaModel,
    // Batch positions whose logits were initialized by the most recent
    // successful `decode` call; used to guard logit reads.
    initialized_logits: Vec<i32>,
}
impl Debug for LlamaContext<'_> {
    /// Formats the context for debugging, showing only the raw context pointer.
    ///
    /// `finish_non_exhaustive` renders a trailing `..` so the output makes it
    /// explicit that the remaining fields (`model`, `initialized_logits`) are
    /// intentionally omitted, instead of silently pretending they don't exist.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaContext")
            .field("context", &self.context)
            .finish_non_exhaustive()
    }
}
impl<'model> LlamaContext<'model> {
    /// Wraps a raw `llama.cpp` context pointer, tying it to the model it was
    /// created from. No logits are initialized until the first [`Self::decode`].
    pub(crate) fn new(
        llama_model: &'model LlamaModel,
        llama_context: NonNull<llama_cpp_sys_2::llama_context>,
    ) -> Self {
        Self {
            context: llama_context,
            model: llama_model,
            initialized_logits: Vec::new(),
        }
    }

    /// Returns the size of the context window (max number of tokens), as
    /// reported by `llama_n_ctx`.
    #[must_use]
    pub fn n_ctx(&self) -> c_int {
        unsafe { llama_cpp_sys_2::llama_n_ctx(self.context.as_ptr()) }
    }

    /// Decodes `batch` with `llama_decode` and records which batch positions
    /// now have valid logits.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodeError`] derived from the non-zero status code when
    /// `llama_decode` fails.
    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
        let result =
            unsafe { llama_cpp_sys_2::llama_decode(self.context.as_ptr(), batch.llama_batch) };
        // `llama_decode` already returns a `c_int` (i32); a zero status is success.
        match NonZeroI32::new(result) {
            None => {
                // Reuse the existing Vec allocation instead of `= clone()`.
                self.initialized_logits
                    .clone_from(&batch.initialized_logits);
                Ok(())
            }
            Some(error) => Err(DecodeError::from(error)),
        }
    }

    /// Returns token candidates (token id, logit, probability placeholder `0.0`)
    /// built from the logits of batch position `i`.
    ///
    /// # Panics
    ///
    /// Panics if logits for position `i` were not initialized by the last
    /// [`Self::decode`] call (see [`Self::get_logits_ith`] for the full list).
    pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
        assert!(
            self.initialized_logits.contains(&i),
            "logit {i} is not initialized. only {:?} is",
            self.initialized_logits
        );
        (0_i32..).zip(self.get_logits_ith(i)).map(|(i, logit)| {
            let token = LlamaToken::new(i);
            LlamaTokenData::new(token, *logit, 0_f32)
        })
    }

    /// Returns the logits (one `f32` per vocabulary entry) for batch position `i`
    /// of the most recent decode.
    ///
    /// # Panics
    ///
    /// - if logits for position `i` were not initialized by the last
    ///   [`Self::decode`] call — reading them would otherwise be undefined behavior
    /// - if `i` is not smaller than [`Self::n_ctx`]
    /// - if the model's `n_vocab` does not fit into a `usize`
    #[must_use]
    pub fn get_logits_ith(&self, i: i32) -> &[f32] {
        // Guard against reading logits llama.cpp never wrote: previously only
        // `candidates_ith` checked this, leaving direct callers exposed to UB.
        assert!(
            self.initialized_logits.contains(&i),
            "logit {i} is not initialized. only {:?} is",
            self.initialized_logits
        );
        assert!(
            self.n_ctx() > i,
            "n_ctx ({}) must be greater than i ({})",
            self.n_ctx(),
            i
        );
        let data = unsafe { llama_cpp_sys_2::llama_get_logits_ith(self.context.as_ptr(), i) };
        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");
        // SAFETY: for an initialized logit index (asserted above) llama.cpp
        // exposes `n_vocab` contiguous floats at `data`.
        unsafe { slice::from_raw_parts(data, len) }
    }

    /// Resets the context's internal performance timings via `llama_reset_timings`.
    pub fn reset_timings(&mut self) {
        unsafe { llama_cpp_sys_2::llama_reset_timings(self.context.as_ptr()) }
    }

    /// Returns a mutable view over the logits of the last `n_tokens` decoded
    /// tokens (`n_tokens * n_vocab` floats).
    ///
    /// # Panics
    ///
    /// Panics if the model's `n_vocab` does not fit into a `usize`.
    pub fn logits_mut(&mut self, n_tokens: usize) -> &mut [f32] {
        let logits_ptr = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
        let n_vocab = usize::try_from(self.model.n_vocab()).expect("n_vocab should be positive");
        // NOTE(review): assumes the caller passes an `n_tokens` no larger than
        // the number of tokens logits were computed for — llama.cpp does not
        // bounds-check this buffer.
        unsafe { slice::from_raw_parts_mut(logits_ptr, n_vocab * n_tokens) }
    }

    /// Returns an immutable view over the logits of the last `n_tokens` decoded
    /// tokens. Prefer [`Self::get_logits_ith`], which validates initialization.
    ///
    /// # Panics
    ///
    /// Panics if the model's `n_vocab` does not fit into a `usize`.
    #[deprecated]
    #[must_use]
    pub fn logits(&self, n_tokens: usize) -> &[f32] {
        let n_vocab = usize::try_from(self.model.n_vocab()).expect("n_vocab should be positive");
        let logits_ptr = unsafe { llama_cpp_sys_2::llama_get_logits(self.context.as_ptr()) };
        unsafe { slice::from_raw_parts(logits_ptr, n_vocab * n_tokens) }
    }

    /// Returns the context's performance timings from `llama_get_timings`.
    pub fn timings(&mut self) -> LlamaTimings {
        let timings = unsafe { llama_cpp_sys_2::llama_get_timings(self.context.as_ptr()) };
        LlamaTimings { timings }
    }

    /// Creates a context from a model.
    ///
    /// # Errors
    ///
    /// Propagates any [`LlamaContextLoadError`] from [`LlamaModel::new_context`].
    #[deprecated(note = "use `Model::new_context` instead")]
    #[tracing::instrument(skip_all)]
    pub fn new_with_model(
        backend: &LlamaBackend,
        model: &'model mut LlamaModel,
        context_params: &LlamaContextParams,
    ) -> Result<Self, LlamaContextLoadError> {
        model.new_context(backend, context_params)
    }
}
impl Drop for LlamaContext<'_> {
    /// Frees the underlying `llama.cpp` context when the wrapper goes out of scope.
    fn drop(&mut self) {
        let raw_ctx = self.context.as_ptr();
        // SAFETY: `self.context` is a valid context pointer owned exclusively by
        // this wrapper, so it is freed here exactly once.
        unsafe { llama_cpp_sys_2::llama_free(raw_ctx) };
    }
}