use std::path::Path;
use crate::context::LlamaContext;
use crate::error::{LlamaError, Result};
impl LlamaContext<'_> {
    /// Returns the number of bytes llama.cpp needs to serialize this context's state.
    ///
    /// This is the buffer size [`Self::state_to_bytes`] allocates; it is obtained from
    /// `llama_state_get_size`.
    #[must_use]
    pub fn state_size(&self) -> usize {
        // SAFETY: relies on the `LlamaContext` invariant that `raw()` yields a valid,
        // live context pointer for as long as `self` is borrowed.
        unsafe { llama_crab_sys::llama_state_get_size(self.raw()) }
    }

    /// Serializes the full context state into a freshly allocated byte vector.
    ///
    /// The result can later be restored with [`Self::load_state`].
    ///
    /// # Errors
    ///
    /// Returns `LlamaError::Ffi(-1)` if `llama_state_get_data` reports writing a
    /// different number of bytes than [`Self::state_size`] announced.
    pub fn state_to_bytes(&self) -> Result<Vec<u8>> {
        let size = self.state_size();
        let mut buf = vec![0_u8; size];
        // SAFETY: `buf` owns exactly `size` writable bytes and stays alive for the whole
        // call; the context pointer is valid per the `LlamaContext` invariant.
        let written =
            unsafe { llama_crab_sys::llama_state_get_data(self.raw(), buf.as_mut_ptr(), size) };
        if written != size {
            return Err(LlamaError::Ffi(-1));
        }
        Ok(buf)
    }

    /// Restores context state from bytes previously produced by [`Self::state_to_bytes`].
    ///
    /// # Errors
    ///
    /// Returns `LlamaError::Ffi(-1)` if `bytes` is empty, or if llama.cpp reports
    /// consuming a different number of bytes than were supplied.
    pub fn load_state(&mut self, bytes: &[u8]) -> Result<()> {
        // An empty buffer can never hold a serialized state. Reject it here because the
        // length comparison below cannot distinguish "consumed 0 of 0 bytes" from a
        // failed load that reports 0 bytes consumed.
        if bytes.is_empty() {
            return Err(LlamaError::Ffi(-1));
        }
        // SAFETY: `bytes` is a valid, readable slice of `bytes.len()` bytes for the whole
        // call; the context pointer is valid per the `LlamaContext` invariant.
        let consumed = unsafe {
            llama_crab_sys::llama_state_set_data(self.raw(), bytes.as_ptr(), bytes.len())
        };
        if consumed != bytes.len() {
            return Err(LlamaError::Ffi(-1));
        }
        Ok(())
    }

    /// Writes this context's state to a file at `path`.
    ///
    /// No prompt tokens are stored alongside the state (the token count passed to
    /// llama.cpp is 0). This keeps the file loadable by [`Self::state_load_file`], which
    /// supplies no token buffer.
    ///
    /// # Errors
    ///
    /// Fails if `path` contains an interior NUL byte, if `path` is not valid UTF-8 on a
    /// non-Unix platform, or with `LlamaError::Io` if `llama_state_save_file` reports
    /// failure.
    pub fn state_save_file(&self, path: impl AsRef<Path>) -> Result<()> {
        let cpath = Self::path_to_cstring(path.as_ref())?;
        // SAFETY: `cpath` is NUL-terminated and outlives the call. The token pointer is
        // null, so the token count MUST be 0; any non-zero count would make llama.cpp
        // read that many tokens through the null pointer.
        let ok = unsafe {
            llama_crab_sys::llama_state_save_file(self.raw(), cpath.as_ptr(), std::ptr::null(), 0)
        };
        if !ok {
            return Err(LlamaError::Io(std::io::Error::other("state_save_file")));
        }
        Ok(())
    }

    /// Loads context state from a file at `path`, discarding any stored token list.
    ///
    /// No token output buffer is supplied (capacity 0). llama.cpp is therefore expected
    /// to reject files that carry a non-empty token list, such as files written by other
    /// tools with prompt tokens attached.
    ///
    /// # Errors
    ///
    /// Fails if `path` contains an interior NUL byte, if `path` is not valid UTF-8 on a
    /// non-Unix platform, or with `LlamaError::Io` if `llama_state_load_file` reports
    /// failure.
    pub fn state_load_file(&mut self, path: impl AsRef<Path>) -> Result<()> {
        let cpath = Self::path_to_cstring(path.as_ref())?;
        let mut token_count: usize = 0;
        // SAFETY: `cpath` is NUL-terminated and outlives the call. The token output
        // pointer is null with capacity 0, so no tokens are written through it, and
        // `token_count` is a valid, writable `usize` for the duration of the call.
        let ok = unsafe {
            llama_crab_sys::llama_state_load_file(
                self.raw(),
                cpath.as_ptr(),
                std::ptr::null_mut(),
                0,
                &mut token_count,
            )
        };
        if !ok {
            return Err(LlamaError::Io(std::io::Error::other("state_load_file")));
        }
        Ok(())
    }

    /// Converts `path` into the NUL-terminated string passed to the llama.cpp file APIs.
    ///
    /// On Unix the raw path bytes are passed through unchanged, so non-UTF-8 paths work.
    /// Elsewhere the path must be valid UTF-8. Otherwise an `InvalidInput` I/O error is
    /// returned rather than a lossily converted (and therefore different) path.
    fn path_to_cstring(path: &Path) -> Result<std::ffi::CString> {
        #[cfg(unix)]
        let bytes = {
            use std::os::unix::ffi::OsStrExt;
            path.as_os_str().as_bytes().to_vec()
        };
        #[cfg(not(unix))]
        let bytes = path
            .to_str()
            .ok_or_else(|| {
                LlamaError::Io(std::io::Error::new(
                    std::io::ErrorKind::InvalidInput,
                    "path is not valid UTF-8",
                ))
            })?
            .as_bytes()
            .to_vec();
        // Interior NUL bytes cannot be represented in a C string; `CString::new` rejects
        // them and the error is converted into `LlamaError` by `?`.
        Ok(std::ffi::CString::new(bytes)?)
    }
}