pub mod embeddings;
pub mod kv_cache;
pub mod params;
pub mod sampling_state;
pub mod session;
use std::ptr::NonNull;
use llama_crab_sys as sys;
use crate::batch::LlamaBatch;
use crate::error::{LlamaError, Result};
use crate::model::LlamaModel;
/// Wrapper around a raw `llama_context` pointer together with the
/// [`LlamaModel`] it borrows.
///
/// The wrapped pointer is passed to `llama_free` when this value is dropped.
#[derive(Debug)]
pub struct LlamaContext<'a> {
/// Non-null pointer to the underlying native `llama_context`.
pub(crate) handle: NonNull<sys::llama_context>,
/// Model borrowed for `'a`; the borrow keeps the context from outliving it.
pub(crate) model: &'a LlamaModel,
}
impl<'a> LlamaContext<'a> {
pub(crate) fn from_raw(handle: NonNull<sys::llama_context>, model: &'a LlamaModel) -> Self {
Self { handle, model }
}
#[must_use]
pub fn n_ctx(&self) -> u32 {
unsafe { sys::llama_n_ctx(self.handle.as_ptr()) as u32 }
}
#[must_use]
pub fn n_batch(&self) -> u32 {
unsafe { sys::llama_n_batch(self.handle.as_ptr()) as u32 }
}
#[must_use]
pub fn n_ubatch(&self) -> u32 {
unsafe { sys::llama_n_ubatch(self.handle.as_ptr()) as u32 }
}
#[must_use]
pub fn n_seq_max(&self) -> u32 {
unsafe { sys::llama_n_seq_max(self.handle.as_ptr()) as u32 }
}
#[must_use]
pub fn raw_handle(&self) -> *mut sys::llama_context {
self.handle.as_ptr()
}
pub fn decode(&mut self, batch: &LlamaBatch) -> Result<()> {
let rc = unsafe { sys::llama_decode(self.handle.as_ptr(), *batch.raw()) };
if rc != 0 {
return Err(LlamaError::Decode(rc));
}
Ok(())
}
pub fn encode(&mut self, batch: &LlamaBatch) -> Result<()> {
let rc = unsafe { sys::llama_encode(self.handle.as_ptr(), *batch.raw()) };
if rc != 0 {
return Err(LlamaError::Encode(rc));
}
Ok(())
}
#[must_use]
pub const fn model(&self) -> &'a LlamaModel {
self.model
}
pub(crate) fn raw(&self) -> *mut sys::llama_context {
self.handle.as_ptr()
}
}
// `NonNull` is neither `Send` nor `Sync`, so these impls opt the wrapper back
// in to both explicitly.
// SAFETY: NOTE(review): soundness depends on the native `llama_context` being
// safe to move to, and to access through `&self` from, other threads. Nothing
// in this file establishes that — confirm against llama.cpp's threading
// guarantees and every `&self` method defined on this type.
unsafe impl Send for LlamaContext<'_> {}
unsafe impl Sync for LlamaContext<'_> {}
impl Drop for LlamaContext<'_> {
    /// Releases the native context by passing the wrapped pointer to
    /// `llama_free`.
    fn drop(&mut self) {
        let ctx = self.handle.as_ptr();
        // SAFETY: `ctx` is non-null (`NonNull`) and `drop` runs at most once,
        // so the pointer has not been freed by this wrapper before.
        unsafe {
            sys::llama_free(ctx);
        }
    }
}
pub use self::params::LlamaContextParams;