use std::fmt::{Debug, Formatter};
use std::num::NonZeroI32;
use crate::llama_batch::LlamaBatch;
use crate::model::LlamaModel;
use crate::timing::LlamaTimings;
use crate::token::data::LlamaTokenData;
use crate::token::LlamaToken;
use crate::{DecodeError};
use std::ptr::NonNull;
use std::slice;
// Public submodules of this module. Each is declared without a body here,
// so its contents live in a separate file that is not visible from this one.
pub mod kv_cache;
pub mod params;
pub mod sample;
/// Wrapper around a raw `llama_context` pointer, tied to the [`LlamaModel`] it borrows.
///
/// The raw context is released with `llama_free` when this value is dropped.
// The lint is silenced on purpose: presumably the type name intentionally repeats the
// name of the enclosing module — confirm against the crate layout.
#[allow(clippy::module_name_repetitions)]
pub struct LlamaContext<'model> {
    /// Non-null handle to the underlying llama.cpp context.
    pub(crate) context: NonNull<llama_cpp_sys_2::llama_context>,
    /// The model borrowed for the lifetime of this context.
    pub model: &'model LlamaModel,
    /// Logit indices copied from the batch on the last successful `decode`;
    /// `get_logits_ith` refuses any index not listed here.
    initialized_logits: Vec<i32>,
}
impl Debug for LlamaContext<'_> {
    /// Formats the context as `LlamaContext { context: <pointer> }`.
    ///
    /// Only the raw context handle is printed; the `model` and
    /// `initialized_logits` fields are left out of the output.
    fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("LlamaContext");
        builder.field("context", &self.context);
        builder.finish()
    }
}
impl<'model> LlamaContext<'model> {
    /// Wraps an existing raw context handle together with the model it is used with.
    ///
    /// The new value starts with no initialized logits; they are recorded by a
    /// successful call to [`Self::decode`].
    pub(crate) fn new(
        llama_model: &'model LlamaModel,
        llama_context: NonNull<llama_cpp_sys_2::llama_context>,
    ) -> Self {
        Self {
            context: llama_context,
            model: llama_model,
            initialized_logits: Vec::new(),
        }
    }

    /// Returns the value reported by `llama_n_ctx` for this context.
    #[must_use]
    pub fn n_ctx(&self) -> u32 {
        // SAFETY: `context` is non-null by construction (`NonNull`); it is presumed to
        // still point at a live llama context — TODO confirm at the creation site.
        unsafe { llama_cpp_sys_2::llama_n_ctx(self.context.as_ptr()) }
    }

    /// Runs `llama_decode` on `batch`.
    ///
    /// On success the batch's list of initialized logit indices is copied into this
    /// context so that [`Self::get_logits_ith`] can validate later lookups.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodeError`] built from the status code whenever `llama_decode`
    /// returns a non-zero value. The recorded logit indices are left untouched in
    /// that case.
    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
        // SAFETY: `context` is non-null by construction; `batch.llama_batch` is passed
        // through as-is and is presumed to be a valid batch — TODO confirm in `LlamaBatch`.
        let result =
            unsafe { llama_cpp_sys_2::llama_decode(self.context.as_ptr(), batch.llama_batch) };

        match NonZeroI32::new(result) {
            None => {
                // `clone_from` reuses the existing allocation instead of building a
                // fresh `Vec` on every successful decode.
                self.initialized_logits
                    .clone_from(&batch.initialized_logits);
                Ok(())
            }
            Some(error) => Err(DecodeError::from(error)),
        }
    }

    /// Iterates over the logits of output `i` as [`LlamaTokenData`] values.
    ///
    /// Each item pairs a token whose id is the logit's position in the slice with
    /// that logit; the third constructor argument is always `0.0` (presumably the
    /// probability field — confirm against `LlamaTokenData::new`).
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::get_logits_ith`].
    pub fn candidates_ith(&self, i: i32) -> impl Iterator<Item = LlamaTokenData> + '_ {
        (0_i32..)
            .zip(self.get_logits_ith(i))
            .map(|(token_id, logit)| {
                let token = LlamaToken::new(token_id);
                LlamaTokenData::new(token, *logit, 0_f32)
            })
    }

    /// Returns the logits for output `i` as a slice of `model.n_vocab()` floats.
    ///
    /// # Panics
    ///
    /// - If `i` was not among the initialized logits of the last successfully
    ///   decoded batch.
    /// - If `i` does not fit into a `u32`, or is not smaller than [`Self::n_ctx`].
    /// - If `llama_get_logits_ith` returns a null pointer.
    /// - If `n_vocab` does not fit into a `usize`.
    #[must_use]
    pub fn get_logits_ith(&self, i: i32) -> &[f32] {
        assert!(
            self.initialized_logits.contains(&i),
            "logit {i} is not initialized. only {:?} is",
            self.initialized_logits
        );
        assert!(
            self.n_ctx() > u32::try_from(i).expect("i does not fit into a u32"),
            "n_ctx ({}) must be greater than i ({})",
            self.n_ctx(),
            i
        );

        // SAFETY: `context` is non-null by construction, and `i` has been checked
        // against the indices recorded by the last successful decode.
        let data = unsafe { llama_cpp_sys_2::llama_get_logits_ith(self.context.as_ptr(), i) };
        // `slice::from_raw_parts` requires a non-null pointer; fail loudly instead of
        // invoking undefined behaviour if the C side hands back null.
        assert!(
            !data.is_null(),
            "llama_get_logits_ith returned a null pointer for logit {i}"
        );
        let len = usize::try_from(self.model.n_vocab()).expect("n_vocab does not fit into a usize");

        // SAFETY: `data` is non-null (checked above). It is presumed to point at
        // `n_vocab` contiguous, initialized `f32` values that stay valid while `self`
        // is borrowed — TODO confirm against the llama.cpp contract.
        unsafe { slice::from_raw_parts(data, len) }
    }

    /// Calls `llama_reset_timings` on this context.
    pub fn reset_timings(&mut self) {
        // SAFETY: `context` is non-null by construction and presumed live.
        unsafe { llama_cpp_sys_2::llama_reset_timings(self.context.as_ptr()) }
    }

    /// Returns the result of `llama_get_timings` for this context, wrapped in
    /// [`LlamaTimings`].
    pub fn timings(&mut self) -> LlamaTimings {
        // SAFETY: `context` is non-null by construction and presumed live.
        let timings = unsafe { llama_cpp_sys_2::llama_get_timings(self.context.as_ptr()) };
        LlamaTimings { timings }
    }
}
impl Drop for LlamaContext<'_> {
    /// Releases the underlying llama.cpp context by passing it to `llama_free`.
    fn drop(&mut self) {
        let raw_context = self.context.as_ptr();
        // SAFETY: `raw_context` comes from a `NonNull`, so it is never null. This wrapper
        // is presumed to be the sole owner of the context, so it is freed exactly once —
        // TODO confirm no other code frees the same pointer.
        unsafe { llama_cpp_sys_2::llama_free(raw_context) }
    }
}