use crate::context::LlamaModelContext;
use crate::token::LlamaToken;
use std::ffi::{CString, NulError};
use std::path::{Path, PathBuf};
/// Errors returned by [`LlamaModelContext::save_session_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum SaveSessionError {
/// The underlying native save call returned `false`.
#[error("Failed to save session file")]
FailedToSave,
/// The path contained a NUL byte, so it could not be turned into a `CString`.
#[error("null byte in string {0}")]
NullError(#[from] NulError),
/// The path was not valid UTF-8 and could not be viewed as a `&str`; carries the offending path.
#[error("failed to convert path {0} to str")]
PathToStrError(PathBuf),
}
/// Errors returned by [`LlamaModelContext::load_session_file`].
#[derive(Debug, Eq, PartialEq, thiserror::Error)]
pub enum LoadSessionError {
/// The underlying native load call returned `false`.
#[error("Failed to load session file")]
FailedToLoad,
/// The path contained a NUL byte, so it could not be turned into a `CString`.
#[error("null byte in string {0}")]
NullError(#[from] NulError),
/// The path was not valid UTF-8 and could not be viewed as a `&str`; carries the offending path.
#[error("failed to convert path {0} to str")]
PathToStrError(PathBuf),
/// The native call reported more tokens than the caller-provided limit allows.
#[error("max_length is not large enough to hold {n_out} (was {max_tokens})")]
InsufficientMaxLength {
// Token count reported back by the native load call.
n_out: usize,
// The `max_tokens` limit that was passed to `load_session_file`.
max_tokens: usize,
},
}
impl LlamaModelContext<'_> {
    /// Saves this context's session, together with `tokens`, to the file at `path_session`.
    ///
    /// Thin wrapper around `llama_save_session_file`.
    ///
    /// # Errors
    ///
    /// * [`SaveSessionError::PathToStrError`] if `path_session` is not valid UTF-8.
    /// * [`SaveSessionError::NullError`] if `path_session` contains a NUL byte.
    /// * [`SaveSessionError::FailedToSave`] if the native call returns `false`.
    pub fn save_session_file(
        &self,
        path_session: impl AsRef<Path>,
        tokens: &[LlamaToken],
    ) -> Result<(), SaveSessionError> {
        let path = path_session.as_ref();
        let path = path
            .to_str()
            .ok_or_else(|| SaveSessionError::PathToStrError(path.to_path_buf()))?;
        let cstr = CString::new(path)?;

        // SAFETY: `cstr` is a NUL-terminated string that outlives the call, and the token
        // pointer/length pair comes from a single live slice. This relies on `LlamaToken`
        // being layout-compatible with `llama_token` and on `self.context` pointing at a
        // live context (NOTE(review): both are defined outside this file; confirm there).
        let saved = unsafe {
            infrastructure_llama_bindings::llama_save_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens
                    .as_ptr()
                    .cast::<infrastructure_llama_bindings::llama_token>(),
                tokens.len(),
            )
        };

        if saved {
            Ok(())
        } else {
            Err(SaveSessionError::FailedToSave)
        }
    }

    /// Loads a session from the file at `path_session` into this context and returns the
    /// tokens stored alongside it.
    ///
    /// `max_tokens` is the capacity of the token buffer handed to `llama_load_session_file`;
    /// the returned vector holds as many tokens as the native call reports.
    ///
    /// # Errors
    ///
    /// * [`LoadSessionError::PathToStrError`] if `path_session` is not valid UTF-8.
    /// * [`LoadSessionError::NullError`] if `path_session` contains a NUL byte.
    /// * [`LoadSessionError::FailedToLoad`] if the native call returns `false`.
    /// * [`LoadSessionError::InsufficientMaxLength`] if the native call reports more than
    ///   `max_tokens` tokens.
    pub fn load_session_file(
        &mut self,
        path_session: impl AsRef<Path>,
        max_tokens: usize,
    ) -> Result<Vec<LlamaToken>, LoadSessionError> {
        let path = path_session.as_ref();
        // Build the error lazily so the `PathBuf` is only allocated on the failure path
        // (mirrors `save_session_file`).
        let path = path
            .to_str()
            .ok_or_else(|| LoadSessionError::PathToStrError(path.to_path_buf()))?;
        let cstr = CString::new(path)?;

        let mut tokens: Vec<LlamaToken> = Vec::with_capacity(max_tokens);
        let mut n_out: usize = 0;
        let tokens_out = tokens
            .as_mut_ptr()
            .cast::<infrastructure_llama_bindings::llama_token>();

        // SAFETY: `cstr` is a NUL-terminated string that outlives the call, `tokens_out`
        // points at an allocation with room for at least `max_tokens` elements (the capacity
        // we pass along), and `n_out` is a live local. This relies on `LlamaToken` being
        // layout-compatible with `llama_token` (NOTE(review): defined outside this file).
        let load_session_success = unsafe {
            infrastructure_llama_bindings::llama_load_session_file(
                self.context.as_ptr(),
                cstr.as_ptr(),
                tokens_out,
                max_tokens,
                &raw mut n_out,
            )
        };

        if !load_session_success {
            return Err(LoadSessionError::FailedToLoad);
        }
        // Guard before `set_len`: a count above the capacity must never become the length.
        if n_out > max_tokens {
            return Err(LoadSessionError::InsufficientMaxLength { n_out, max_tokens });
        }

        // SAFETY: `n_out <= max_tokens <= tokens.capacity()` was just checked, and the
        // successful native call is expected to have written the first `n_out` elements.
        unsafe {
            tokens.set_len(n_out);
        }
        Ok(tokens)
    }

    /// Returns the state size reported by `llama_get_state_size` for this context.
    ///
    /// Use it to size the buffer passed to [`Self::copy_state_data`].
    #[must_use]
    pub fn get_state_size(&self) -> usize {
        // SAFETY: only passes the context pointer owned by `self` to the native call.
        unsafe { infrastructure_llama_bindings::llama_get_state_size(self.context.as_ptr()) }
    }

    /// Copies this context's state into the buffer at `dest` via `llama_copy_state_data`
    /// and returns the value that call reports (per llama.h, the number of bytes written).
    ///
    /// # Safety
    ///
    /// `dest` must be non-null and valid for writes of at least [`Self::get_state_size`]
    /// bytes; the native call is given no length and cannot check the buffer size.
    pub unsafe fn copy_state_data(&self, dest: *mut u8) -> usize {
        // SAFETY: the caller guarantees `dest` is large enough (see `# Safety`).
        unsafe { infrastructure_llama_bindings::llama_copy_state_data(self.context.as_ptr(), dest) }
    }

    /// Restores this context's state from `src` via `llama_set_state_data` and returns the
    /// value that call reports (per llama.h, the number of bytes read).
    ///
    /// # Safety
    ///
    /// Only `src.as_ptr()` is forwarded; `src.len()` is not passed to the native call.
    /// `src` must therefore hold a complete state blob, such as one produced by
    /// [`Self::copy_state_data`] on a compatible context, or the call may read out of bounds.
    pub unsafe fn set_state_data(&mut self, src: &[u8]) -> usize {
        // SAFETY: the caller guarantees `src` holds a complete state blob (see `# Safety`).
        unsafe {
            infrastructure_llama_bindings::llama_set_state_data(self.context.as_ptr(), src.as_ptr())
        }
    }
}