use std::ffi::c_void;
use std::fmt::{Debug, Formatter};
use std::num::NonZeroI32;
use std::ptr::NonNull;
use std::slice;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use crate::context::params::LlamaContextParams;
use crate::llama_backend::LlamaBackend;
use crate::llama_batch::LlamaBatch;
use crate::model::{LlamaLoraAdapter, LlamaModel};
use crate::timing::LlamaTimings;
use crate::token::LlamaToken;
use crate::token::data::LlamaTokenData;
use crate::token::data_array::LlamaTokenDataArray;
use crate::{
DecodeError, EmbeddingsError, EncodeError, LlamaContextLoadError, LlamaLoraAdapterRemoveError,
LlamaLoraAdapterSetError, LogitsError,
};
/// Maps the integer status returned by `llama_set_adapters_lora` when installing an
/// adapter into a `Result`.
///
/// Zero means success; any other value is surfaced verbatim inside
/// `LlamaLoraAdapterSetError::ErrorResult`.
const fn check_lora_set_result(err_code: i32) -> Result<(), LlamaLoraAdapterSetError> {
    match err_code {
        0 => Ok(()),
        failure => Err(LlamaLoraAdapterSetError::ErrorResult(failure)),
    }
}
/// Maps the integer status returned by `llama_set_adapters_lora` when clearing adapters
/// into a `Result`.
///
/// Zero means success; any other value is surfaced verbatim inside
/// `LlamaLoraAdapterRemoveError::ErrorResult`.
const fn check_lora_remove_result(err_code: i32) -> Result<(), LlamaLoraAdapterRemoveError> {
    if err_code == 0 {
        Ok(())
    } else {
        Err(LlamaLoraAdapterRemoveError::ErrorResult(err_code))
    }
}
pub mod kv_cache;
pub mod kv_cache_type;
pub mod llama_attention_type;
pub mod llama_pooling_type;
pub mod llama_state_seq_flags;
pub mod load_seq_state_error;
pub mod load_session_error;
pub mod params;
pub mod rope_scaling_type;
pub mod save_seq_state_error;
pub mod save_session_error;
pub mod session;
/// C-ABI abort callback handed to `llama_set_abort_callback` by `LlamaContext::set_abort_flag`.
///
/// Returns the current value of the `AtomicBool` that `data` points to; `true` asks
/// llama.cpp to abort the running computation.
///
/// # Safety
///
/// `data` must be a non-null pointer to an `AtomicBool` that stays alive for the whole call.
unsafe extern "C" fn abort_callback_trampoline(data: *mut c_void) -> bool {
    // SAFETY: the only registration site passes `Arc::as_ptr` of a flag whose `Arc` is kept in
    // `LlamaContext::abort_flag` while the callback is registered, so the pointee is live.
    let abort_requested: &AtomicBool = unsafe { &*data.cast::<AtomicBool>() };
    abort_requested.load(Ordering::Relaxed)
}
/// A llama.cpp inference context bound to the [`LlamaModel`] it was created from.
///
/// The raw context is released with `llama_free` when this value is dropped.
pub struct LlamaContext<'model> {
    /// Raw handle to the underlying llama.cpp context.
    pub context: NonNull<llama_cpp_bindings_sys::llama_context>,
    /// Model this context was created from, borrowed for `'model`.
    pub model: &'model LlamaModel,
    /// Whether the `embeddings_*` accessors may query llama.cpp.
    embeddings_enabled: bool,
    /// Batch indices whose logits may be read through `get_logits_ith`.
    initialized_logits: Vec<i32>,
    /// Keeps the flag read by the registered abort callback alive; `None` while no flag is
    /// registered through `set_abort_flag`.
    abort_flag: Option<Arc<AtomicBool>>,
}
/// Debug output shows the raw context pointer and whether embeddings are enabled.
///
/// The model reference, abort flag and logits bookkeeping are deliberately omitted;
/// `finish_non_exhaustive` renders a trailing `..` so readers can tell fields were skipped.
impl Debug for LlamaContext<'_> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LlamaContext")
            .field("context", &self.context)
            .field("embeddings_enabled", &self.embeddings_enabled)
            .finish_non_exhaustive()
    }
}
impl<'model> LlamaContext<'model> {
    /// Wraps an already-created llama.cpp context pointer.
    ///
    /// The returned value frees `llama_context` with `llama_free` when dropped, so the caller
    /// must not free it elsewhere. `embeddings_enabled` gates the `embeddings_*` accessors.
    /// Most callers want [`LlamaContext::from_model`] instead.
    #[must_use]
    pub const fn new(
        llama_model: &'model LlamaModel,
        llama_context: NonNull<llama_cpp_bindings_sys::llama_context>,
        embeddings_enabled: bool,
    ) -> Self {
        Self {
            context: llama_context,
            model: llama_model,
            abort_flag: None,
            initialized_logits: Vec::new(),
            embeddings_enabled,
        }
    }

    /// Creates a new inference context for `model` configured by `params`.
    ///
    /// `_backend` is not read here; presumably it is required only as proof that the backend
    /// was initialized — confirm against `LlamaBackend`.
    ///
    /// # Errors
    ///
    /// - [`LlamaContextLoadError::Unconstructible`] if the shim reports that llama.cpp returned
    ///   null, or reports success but leaves the output pointer null.
    /// - [`LlamaContextLoadError::NotEnoughMemory`] if the shim could not allocate its error
    ///   string.
    /// - [`LlamaContextLoadError::Reported`] carrying the message of a C++ exception thrown
    ///   during creation.
    ///
    /// # Panics
    ///
    /// Panics if the shim returns a status code not handled below.
    #[expect(
        clippy::needless_pass_by_value,
        reason = "LlamaContextParams may become non-trivially copyable upstream"
    )]
    pub fn from_model(
        model: &'model LlamaModel,
        _backend: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<Self, LlamaContextLoadError> {
        let context_params = params.context_params;
        let mut out_ctx: *mut llama_cpp_bindings_sys::llama_context = std::ptr::null_mut();
        let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
        // SAFETY: the model pointer is valid for `'model`, and both out-pointers refer to live
        // locals that outlive the call.
        let status = unsafe {
            llama_cpp_bindings_sys::llama_rs_new_context_with_model(
                model.model.as_ptr(),
                context_params,
                &raw mut out_ctx,
                &raw mut out_error,
            )
        };
        match status {
            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_OK => {
                let context =
                    NonNull::new(out_ctx).ok_or(LlamaContextLoadError::Unconstructible)?;
                Ok(Self::new(model, context, params.embeddings()))
            }
            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_VENDORED_RETURNED_NULL => {
                Err(LlamaContextLoadError::Unconstructible)
            }
            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_ERROR_STRING_ALLOCATION_FAILED => {
                Err(LlamaContextLoadError::NotEnoughMemory)
            }
            llama_cpp_bindings_sys::LLAMA_RS_NEW_CONTEXT_WITH_MODEL_VENDORED_THREW_CXX_EXCEPTION => {
                // SAFETY: assumes the shim stores an allocated message in `out_error` exactly
                // when it reports a thrown exception — TODO confirm against the shim.
                let message = unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
                Err(LlamaContextLoadError::Reported { message })
            }
            other => unreachable!(
                "llama_rs_new_context_with_model returned unrecognized status {other}"
            ),
        }
    }

    /// Returns the value reported by `llama_n_batch` for this context.
    #[must_use]
    pub fn n_batch(&self) -> u32 {
        // SAFETY: `self.context` is a live context for as long as `self` exists.
        unsafe { llama_cpp_bindings_sys::llama_n_batch(self.context.as_ptr()) }
    }

    /// Returns the value reported by `llama_n_ubatch` for this context.
    #[must_use]
    pub fn n_ubatch(&self) -> u32 {
        // SAFETY: `self.context` is a live context for as long as `self` exists.
        unsafe { llama_cpp_bindings_sys::llama_n_ubatch(self.context.as_ptr()) }
    }

    /// Returns the context size reported by `llama_n_ctx`.
    #[must_use]
    pub fn n_ctx(&self) -> u32 {
        // SAFETY: `self.context` is a live context for as long as `self` exists.
        unsafe { llama_cpp_bindings_sys::llama_n_ctx(self.context.as_ptr()) }
    }

    /// Registers `flag` as the abort signal for this context.
    ///
    /// While registered, llama.cpp polls the flag through `abort_callback_trampoline`; storing
    /// `true` requests an abort. Any previously registered flag is replaced.
    #[expect(unsafe_code, reason = "required for FFI abort callback registration")]
    pub fn set_abort_flag(&mut self, flag: Arc<AtomicBool>) {
        let raw_ptr = Arc::as_ptr(&flag) as *mut c_void;
        // SAFETY: `raw_ptr` points into `flag`'s allocation, which is kept alive by storing the
        // `Arc` in `self.abort_flag` until the callback is replaced, cleared, or the context is
        // freed in `Drop`.
        unsafe {
            llama_cpp_bindings_sys::llama_set_abort_callback(
                self.context.as_ptr(),
                Some(abort_callback_trampoline),
                raw_ptr,
            );
        }
        // Replace the stored flag only after llama.cpp points at the new one, so a previous
        // flag is never released while it is still registered.
        self.abort_flag = Some(flag);
    }

    /// Unregisters the abort callback and releases the flag set by [`Self::set_abort_flag`].
    #[expect(unsafe_code, reason = "required for FFI abort callback deregistration")]
    pub fn clear_abort_callback(&mut self) {
        // SAFETY: `self.context` is live; passing no callback and a null data pointer detaches
        // llama.cpp from the stored flag.
        unsafe {
            llama_cpp_bindings_sys::llama_set_abort_callback(
                self.context.as_ptr(),
                None,
                std::ptr::null_mut(),
            );
        }
        // Drop the flag only once llama.cpp no longer references it.
        self.abort_flag = None;
    }

    /// Calls `llama_synchronize` on this context.
    #[expect(unsafe_code, reason = "required for FFI synchronization call")]
    pub fn synchronize(&self) {
        // SAFETY: `self.context` is a live context for as long as `self` exists.
        unsafe { llama_cpp_bindings_sys::llama_synchronize(self.context.as_ptr()) }
    }

    /// Calls `llama_detach_threadpool` on this context.
    #[expect(unsafe_code, reason = "required for FFI threadpool detachment")]
    pub fn detach_threadpool(&self) {
        // SAFETY: `self.context` is a live context for as long as `self` exists.
        unsafe { llama_cpp_bindings_sys::llama_detach_threadpool(self.context.as_ptr()) }
    }

    /// Replaces the set of readable logit indices with exactly `token_index`.
    ///
    /// After this call, [`Self::get_logits_ith`] accepts only `token_index`.
    pub fn mark_logits_initialized(&mut self, token_index: i32) {
        // Reuse the existing allocation instead of building a fresh one-element Vec.
        self.initialized_logits.clear();
        self.initialized_logits.push(token_index);
    }

    /// Runs `llama_rs_decode` on `batch`.
    ///
    /// On success, the batch's `initialized_logits` become the indices that
    /// [`Self::get_logits_ith`] will accept. On failure they are left unchanged.
    ///
    /// # Errors
    ///
    /// Returns a [`DecodeError`] variant describing the failure reported by the shim.
    ///
    /// # Panics
    ///
    /// Panics if the shim returns an unrecognized status, or reports a nonzero vendored
    /// return code whose value is zero.
    pub fn decode(&mut self, batch: &mut LlamaBatch) -> Result<(), DecodeError> {
        let mut out_vendored_return_code: i32 = 0;
        let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
        // SAFETY: `self.context` is live, and both out-pointers refer to live locals.
        let status = unsafe {
            llama_cpp_bindings_sys::llama_rs_decode(
                self.context.as_ptr(),
                batch.llama_batch,
                &raw mut out_vendored_return_code,
                &raw mut out_error,
            )
        };
        match status {
            llama_cpp_bindings_sys::LLAMA_RS_DECODE_OK => {
                self.initialized_logits.clone_from(&batch.initialized_logits);
                Ok(())
            }
            llama_cpp_bindings_sys::LLAMA_RS_DECODE_VENDORED_RETURNED_NONZERO_CODE => {
                let code = NonZeroI32::new(out_vendored_return_code).unwrap_or_else(|| {
                    unreachable!(
                        "llama_rs_decode reported a nonzero return code but the value was zero"
                    )
                });
                Err(DecodeError::from(code))
            }
            llama_cpp_bindings_sys::LLAMA_RS_DECODE_OUT_OF_MEMORY => {
                Err(DecodeError::DecodeOutOfMemory)
            }
            llama_cpp_bindings_sys::LLAMA_RS_DECODE_COMPUTE_FAILED => {
                Err(DecodeError::ComputeFailed)
            }
            llama_cpp_bindings_sys::LLAMA_RS_DECODE_ERROR_STRING_ALLOCATION_FAILED => {
                Err(DecodeError::NotEnoughMemory)
            }
            llama_cpp_bindings_sys::LLAMA_RS_DECODE_VENDORED_THREW_CXX_EXCEPTION => {
                // SAFETY: assumes the shim stores an allocated message in `out_error` exactly
                // when it reports a thrown exception — TODO confirm against the shim.
                let message =
                    unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
                Err(DecodeError::Reported { message })
            }
            other => unreachable!("llama_rs_decode returned unrecognized status {other}"),
        }
    }

    /// Runs `llama_rs_encode` on `batch`.
    ///
    /// On success, the batch's `initialized_logits` become the indices that
    /// [`Self::get_logits_ith`] will accept. On failure they are left unchanged.
    ///
    /// # Errors
    ///
    /// Returns an [`EncodeError`] variant describing the failure reported by the shim,
    /// including [`EncodeError::ModelHasNoEncoder`].
    ///
    /// # Panics
    ///
    /// Panics if the shim returns an unrecognized status, or reports a nonzero vendored
    /// return code whose value is zero.
    pub fn encode(&mut self, batch: &mut LlamaBatch) -> Result<(), EncodeError> {
        let mut out_vendored_return_code: i32 = 0;
        let mut out_error: *mut std::os::raw::c_char = std::ptr::null_mut();
        // SAFETY: `self.context` is live, and both out-pointers refer to live locals.
        let status = unsafe {
            llama_cpp_bindings_sys::llama_rs_encode(
                self.context.as_ptr(),
                batch.llama_batch,
                &raw mut out_vendored_return_code,
                &raw mut out_error,
            )
        };
        match status {
            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_OK => {
                self.initialized_logits.clone_from(&batch.initialized_logits);
                Ok(())
            }
            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_MODEL_HAS_NO_ENCODER => {
                Err(EncodeError::ModelHasNoEncoder)
            }
            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_VENDORED_RETURNED_NONZERO_CODE => {
                let code = NonZeroI32::new(out_vendored_return_code).unwrap_or_else(|| {
                    unreachable!(
                        "llama_rs_encode reported a nonzero return code but the value was zero"
                    )
                });
                Err(EncodeError::from(code))
            }
            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_OUT_OF_MEMORY => {
                Err(EncodeError::EncodeOutOfMemory)
            }
            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_COMPUTE_FAILED => {
                Err(EncodeError::ComputeFailed)
            }
            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_ERROR_STRING_ALLOCATION_FAILED => {
                Err(EncodeError::NotEnoughMemory)
            }
            llama_cpp_bindings_sys::LLAMA_RS_ENCODE_VENDORED_THREW_CXX_EXCEPTION => {
                // SAFETY: assumes the shim stores an allocated message in `out_error` exactly
                // when it reports a thrown exception — TODO confirm against the shim.
                let message =
                    unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
                Err(EncodeError::Reported { message })
            }
            other => unreachable!("llama_rs_encode returned unrecognized status {other}"),
        }
    }

    /// Returns the pooled embedding for sequence `sequence_index` as `n_embd` floats.
    ///
    /// # Errors
    ///
    /// - [`EmbeddingsError::NotEnabled`] if the context was created without embeddings.
    /// - [`EmbeddingsError::InvalidEmbeddingDimension`] if `n_embd` does not fit in `usize`.
    /// - [`EmbeddingsError::NonePoolType`] if `llama_get_embeddings_seq` returns null.
    pub fn embeddings_seq_ith(&self, sequence_index: i32) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }
        let n_embd = usize::try_from(self.model.n_embd())
            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
        // SAFETY: `self.context` is a live context for as long as `self` exists.
        let embedding = unsafe {
            llama_cpp_bindings_sys::llama_get_embeddings_seq(self.context.as_ptr(), sequence_index)
        };
        if embedding.is_null() {
            return Err(EmbeddingsError::NonePoolType);
        }
        // SAFETY: non-null; assumes llama.cpp exposes at least `n_embd` floats here that stay
        // valid while `self` is immutably borrowed — TODO confirm for every pooling type.
        Ok(unsafe { slice::from_raw_parts(embedding, n_embd) })
    }

    /// Returns the embedding at batch index `token_index` as `n_embd` floats.
    ///
    /// # Errors
    ///
    /// - [`EmbeddingsError::NotEnabled`] if the context was created without embeddings.
    /// - [`EmbeddingsError::InvalidEmbeddingDimension`] if `n_embd` does not fit in `usize`.
    /// - [`EmbeddingsError::LogitsNotEnabled`] if `llama_get_embeddings_ith` returns null.
    pub fn embeddings_ith(&self, token_index: i32) -> Result<&[f32], EmbeddingsError> {
        if !self.embeddings_enabled {
            return Err(EmbeddingsError::NotEnabled);
        }
        let n_embd = usize::try_from(self.model.n_embd())
            .map_err(EmbeddingsError::InvalidEmbeddingDimension)?;
        // SAFETY: `self.context` is a live context for as long as `self` exists.
        let embedding = unsafe {
            llama_cpp_bindings_sys::llama_get_embeddings_ith(self.context.as_ptr(), token_index)
        };
        if embedding.is_null() {
            return Err(EmbeddingsError::LogitsNotEnabled);
        }
        // SAFETY: non-null; assumes llama.cpp exposes at least `n_embd` floats here that stay
        // valid while `self` is immutably borrowed — TODO confirm.
        Ok(unsafe { slice::from_raw_parts(embedding, n_embd) })
    }

    /// Yields one [`LlamaTokenData`] per entry of [`Self::get_logits`], using the entry's
    /// position as the token id.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::get_logits`].
    pub fn candidates(&self) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
        let logits = self.get_logits()?;
        Ok(candidates_from_logits(logits))
    }

    /// Collects [`Self::candidates`] into a [`LlamaTokenDataArray`] (with `false` passed
    /// as its second argument).
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::get_logits`].
    pub fn token_data_array(&self) -> Result<LlamaTokenDataArray, LogitsError> {
        Ok(LlamaTokenDataArray::from_iter(self.candidates()?, false))
    }

    /// Returns the first `n_vocab` floats of the buffer from `llama_get_logits`.
    ///
    /// Per llama.h, output rows are stored contiguously, so this is the first output row.
    /// Unlike [`Self::get_logits_ith`], this does not consult the initialized-index list.
    ///
    /// # Errors
    ///
    /// - [`LogitsError::NullLogits`] if llama.cpp returns a null pointer.
    /// - [`LogitsError::VocabSizeOverflow`] if `n_vocab` does not fit in `usize`.
    pub fn get_logits(&self) -> Result<&[f32], LogitsError> {
        // SAFETY: `self.context` is a live context for as long as `self` exists.
        let data = unsafe { llama_cpp_bindings_sys::llama_get_logits(self.context.as_ptr()) };
        if data.is_null() {
            return Err(LogitsError::NullLogits);
        }
        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
        // SAFETY: non-null, and assumes the buffer holds at least one `n_vocab`-sized row while
        // `self` is immutably borrowed — TODO confirm.
        Ok(unsafe { slice::from_raw_parts(data, len) })
    }

    /// Like [`Self::candidates`], but for the logits at batch index `token_index`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::get_logits_ith`].
    pub fn candidates_ith(
        &self,
        token_index: i32,
    ) -> Result<impl Iterator<Item = LlamaTokenData> + '_, LogitsError> {
        let logits = self.get_logits_ith(token_index)?;
        Ok(candidates_from_logits(logits))
    }

    /// Collects [`Self::candidates_ith`] into a [`LlamaTokenDataArray`] (with `false` passed
    /// as its second argument).
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::get_logits_ith`].
    pub fn token_data_array_ith(
        &self,
        token_index: i32,
    ) -> Result<LlamaTokenDataArray, LogitsError> {
        Ok(LlamaTokenDataArray::from_iter(
            self.candidates_ith(token_index)?,
            false,
        ))
    }

    /// Returns the `n_vocab` logits for batch index `token_index`.
    ///
    /// Only indices recorded by the last successful `decode`/`encode` (or by
    /// [`Self::mark_logits_initialized`]) are accepted. Negative indices skip the context
    /// bound check and are forwarded to llama.cpp unchanged.
    ///
    /// # Errors
    ///
    /// - [`LogitsError::TokenNotInitialized`] if `token_index` was not recorded.
    /// - [`LogitsError::TokenIndexExceedsContext`] if a non-negative index is `>= n_ctx`.
    /// - [`LogitsError::NullLogits`] if llama.cpp rejects the index and returns null.
    /// - [`LogitsError::VocabSizeOverflow`] if `n_vocab` does not fit in `usize`.
    pub fn get_logits_ith(&self, token_index: i32) -> Result<&[f32], LogitsError> {
        if !self.initialized_logits.contains(&token_index) {
            return Err(LogitsError::TokenNotInitialized(token_index));
        }
        if token_index >= 0 {
            let token_index_u32 =
                u32::try_from(token_index).map_err(LogitsError::TokenIndexOverflow)?;
            let context_size = self.n_ctx();
            if context_size <= token_index_u32 {
                return Err(LogitsError::TokenIndexExceedsContext {
                    token_index: token_index_u32,
                    context_size,
                });
            }
        }
        // SAFETY: `self.context` is a live context for as long as `self` exists.
        let data = unsafe {
            llama_cpp_bindings_sys::llama_get_logits_ith(self.context.as_ptr(), token_index)
        };
        // llama.h documents that this returns NULL for invalid ids; building a slice from a
        // null pointer is undefined behavior, so mirror the check in `get_logits`.
        if data.is_null() {
            return Err(LogitsError::NullLogits);
        }
        let len = usize::try_from(self.model.n_vocab()).map_err(LogitsError::VocabSizeOverflow)?;
        // SAFETY: non-null, and assumes llama.cpp returns a pointer to an `n_vocab`-long row
        // that stays valid while `self` is immutably borrowed — TODO confirm.
        Ok(unsafe { slice::from_raw_parts(data, len) })
    }

    /// Calls `llama_perf_context_reset` on this context.
    pub fn reset_timings(&mut self) {
        // SAFETY: `self.context` is a live context for as long as `self` exists.
        unsafe { llama_cpp_bindings_sys::llama_perf_context_reset(self.context.as_ptr()) }
    }

    /// Returns the performance data reported by `llama_perf_context`, wrapped in
    /// [`LlamaTimings`].
    pub fn timings(&mut self) -> LlamaTimings {
        // SAFETY: `self.context` is a live context for as long as `self` exists.
        let timings = unsafe { llama_cpp_bindings_sys::llama_perf_context(self.context.as_ptr()) };
        LlamaTimings { timings }
    }

    /// Calls `llama_set_adapters_lora` with a one-element list containing `adapter` at `scale`.
    ///
    /// Presumably this replaces any previously set adapters rather than adding to them —
    /// confirm against the vendored llama.cpp.
    ///
    /// # Errors
    ///
    /// Returns [`LlamaLoraAdapterSetError::ErrorResult`] with the status code when llama.cpp
    /// reports a nonzero result.
    pub fn lora_adapter_set(
        &self,
        adapter: &mut LlamaLoraAdapter,
        scale: f32,
    ) -> Result<(), LlamaLoraAdapterSetError> {
        let mut adapters = [adapter.lora_adapter.as_ptr()];
        let mut scales = [scale];
        // SAFETY: `self.context` is live, and both arrays hold exactly the one element the
        // count argument announces and outlive the call.
        let err_code = unsafe {
            llama_cpp_bindings_sys::llama_set_adapters_lora(
                self.context.as_ptr(),
                adapters.as_mut_ptr(),
                1,
                scales.as_mut_ptr(),
            )
        };
        check_lora_set_result(err_code)?;
        log::debug!("Set lora adapter");
        Ok(())
    }

    /// Calls `llama_set_adapters_lora` with an empty list, clearing every adapter.
    ///
    /// `_adapter` is not inspected: all adapters on the context are cleared regardless of
    /// which one is passed.
    ///
    /// # Errors
    ///
    /// Returns [`LlamaLoraAdapterRemoveError::ErrorResult`] with the status code when
    /// llama.cpp reports a nonzero result.
    pub fn lora_adapter_remove(
        &self,
        _adapter: &mut LlamaLoraAdapter,
    ) -> Result<(), LlamaLoraAdapterRemoveError> {
        // SAFETY: `self.context` is live; a zero count is paired with null array pointers.
        let err_code = unsafe {
            llama_cpp_bindings_sys::llama_set_adapters_lora(
                self.context.as_ptr(),
                std::ptr::null_mut(),
                0,
                std::ptr::null_mut(),
            )
        };
        check_lora_remove_result(err_code)?;
        log::debug!("Remove lora adapter");
        Ok(())
    }
}

/// Pairs each logit with its position (used as the token id), producing [`LlamaTokenData`]
/// entries whose third field is `0.0`.
fn candidates_from_logits(logits: &[f32]) -> impl Iterator<Item = LlamaTokenData> + '_ {
    (0_i32..)
        .zip(logits)
        .map(|(token_id, logit)| LlamaTokenData::new(LlamaToken::new(token_id), *logit, 0_f32))
}
impl Drop for LlamaContext<'_> {
    /// Releases the llama.cpp context with `llama_free`.
    ///
    /// Struct fields, including the `Arc` behind any registered abort callback, are dropped
    /// only after this body runs, so the callback's data outlives the context.
    fn drop(&mut self) {
        let raw_context = self.context.as_ptr();
        // SAFETY: assumes this value is the sole owner of the context (nothing else in this
        // module frees it), so it is freed exactly once here.
        unsafe { llama_cpp_bindings_sys::llama_free(raw_context) };
    }
}
#[cfg(test)]
mod unit_tests {
    use super::{check_lora_remove_result, check_lora_set_result};
    use crate::{LlamaLoraAdapterRemoveError, LlamaLoraAdapterSetError};

    /// A zero status from llama.cpp means the adapter was set.
    #[test]
    fn check_lora_set_result_ok_for_zero() {
        assert_eq!(check_lora_set_result(0), Ok(()));
    }

    /// Any nonzero status is passed through unchanged in the error.
    #[test]
    fn check_lora_set_result_error_for_nonzero() {
        assert_eq!(
            check_lora_set_result(-1),
            Err(LlamaLoraAdapterSetError::ErrorResult(-1))
        );
    }

    /// A zero status from llama.cpp means the adapters were cleared.
    #[test]
    fn check_lora_remove_result_ok_for_zero() {
        assert_eq!(check_lora_remove_result(0), Ok(()));
    }

    /// Any nonzero status is passed through unchanged in the error.
    #[test]
    fn check_lora_remove_result_error_for_nonzero() {
        assert_eq!(
            check_lora_remove_result(-1),
            Err(LlamaLoraAdapterRemoveError::ErrorResult(-1))
        );
    }
}