use std::ffi::{CStr, CString, c_char};
use std::num::NonZeroU16;
use std::os::raw::c_int;
use std::path::Path;
use std::sync::Arc;
use std::sync::OnceLock;
use toktrie::ApproximateTokEnv;
use toktrie::TokRxInfo;
use toktrie::TokTrie;
/// Keeps the first `length` bytes of `buffer` and interprets them as UTF-8.
///
/// A `length` larger than the buffer keeps every byte.
///
/// # Errors
///
/// Returns an `ApplyChatTemplateError` (converted from `FromUtf8Error`) when the kept bytes are
/// not valid UTF-8.
fn truncated_buffer_to_string(
    buffer: Vec<u8>,
    length: usize,
) -> Result<String, ApplyChatTemplateError> {
    let mut kept_bytes = buffer;
    kept_bytes.truncate(length);
    let rendered = String::from_utf8(kept_bytes)?;
    Ok(rendered)
}
/// Converts a byte length to the C `int` expected by llama.cpp's tokenizer.
///
/// # Errors
///
/// Returns a `StringToTokenError` (converted from `TryFromIntError`) when `length` exceeds
/// `c_int::MAX`.
fn validate_string_length_for_tokenizer(length: usize) -> Result<c_int, StringToTokenError> {
    let as_c_int = c_int::try_from(length)?;
    Ok(as_c_int)
}
/// Builds a NUL-terminated copy of `str` together with its byte length as a C `int`.
///
/// The length excludes the terminating NUL.
///
/// # Errors
///
/// Fails when `str` contains an interior NUL byte, or when its length does not fit in `c_int`.
fn cstring_with_validated_len(str: &str) -> Result<(CString, c_int), StringToTokenError> {
    let owned = CString::new(str)?;
    let byte_count = owned.as_bytes().len();
    validate_string_length_for_tokenizer(byte_count).map(|c_len| (owned, c_len))
}
use std::ptr::{self, NonNull};
use crate::chat_message_parse_outcome::ChatMessageParseOutcome;
use crate::ffi_status_to_i32::status_to_i32;
use crate::llama_backend::LlamaBackend;
use crate::llama_token_attrs::LlamaTokenAttrs;
use crate::llama_token_attrs_from_int_error::LlamaTokenAttrsFromIntError;
use crate::raw_chat_message::RawChatMessage;
use crate::resolved_tool_call_markers::ResolvedToolCallMarkers;
use crate::sampled_token::SampledToken;
use crate::sampled_token_classifier::SampledTokenClassifier;
use crate::sampled_token_classifier::StreamingMarkers;
use crate::token::LlamaToken;
use crate::{
ApplyChatTemplateError, ChatTemplateError, LlamaLoraAdapterInitError, LlamaModelLoadError,
MarkerDetectionError, MetaValError, ParseChatMessageError, StringToTokenError,
TokenToStringError,
};
use llama_cpp_bindings_types::ParsedChatMessage;
use llama_cpp_bindings_types::ParsedToolCall;
use llama_cpp_bindings_types::ReasoningMarkers;
use llama_cpp_bindings_types::ToolCallArguments;
use llama_cpp_bindings_types::ToolCallMarkers;
use crate::tool_call_format;
use crate::tool_call_format::ToolCallFormatOutcome;
use crate::tool_call_template_overrides;
pub mod add_bos;
pub mod llama_chat_message;
pub mod llama_chat_template;
pub mod llama_lora_adapter;
pub mod params;
pub mod rope_type;
pub mod split_mode;
pub mod vocab_type;
pub mod vocab_type_from_int_error;
pub use add_bos::AddBos;
pub use llama_chat_message::LlamaChatMessage;
pub use llama_chat_template::LlamaChatTemplate;
pub use llama_lora_adapter::LlamaLoraAdapter;
pub use rope_type::RopeType;
pub use vocab_type::VocabType;
pub use vocab_type_from_int_error::VocabTypeFromIntError;
use params::LlamaModelParams;
/// Owning handle to a model created by `llama_load_model_from_file`.
///
/// The underlying model is released with `llama_free_model` when this handle is dropped.
pub struct LlamaModel {
/// Non-null pointer to the llama.cpp model owned by this handle.
pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
/// Lazily built approximate tokenizer environment, cached after the first call to
/// `approximate_tok_env`.
tok_env: OnceLock<Arc<ApproximateTokEnv>>,
}
impl std::fmt::Debug for LlamaModel {
    /// Prints the raw model pointer only; the cached tokenizer environment is omitted, which
    /// `finish_non_exhaustive` signals with a trailing `..`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("LlamaModel");
        builder.field("model", &self.model);
        builder.finish_non_exhaustive()
    }
}
// SAFETY (NOTE(review)): these impls assert that the wrapped llama.cpp model may be moved to and
// shared between threads. Nothing in this file synchronises access to `model`; soundness
// presumably relies on llama.cpp treating a loaded model as read-only after loading — confirm
// against the llama.cpp version being bound. `tok_env` is an `OnceLock<Arc<_>>` and handles its
// own synchronisation.
unsafe impl Send for LlamaModel {}
unsafe impl Sync for LlamaModel {}
impl LlamaModel {
#[must_use]
pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
}
pub fn n_ctx_train(&self) -> Result<u32, std::num::TryFromIntError> {
let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
u32::try_from(n_ctx_train)
}
pub fn tokens(
&self,
decode_special: bool,
) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
(0..self.n_vocab())
.map(LlamaToken::new)
.map(move |llama_token| {
let mut decoder = encoding_rs::UTF_8.new_decoder();
(
llama_token,
self.token_to_piece(
&SampledToken::Content(llama_token),
&mut decoder,
decode_special,
None,
),
)
})
}
#[must_use]
pub fn token_bos(&self) -> LlamaToken {
let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
LlamaToken(token)
}
#[must_use]
pub fn token_eos(&self) -> LlamaToken {
let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
LlamaToken(token)
}
#[must_use]
pub fn token_nl(&self) -> LlamaToken {
let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
LlamaToken(token)
}
#[must_use]
pub fn is_eog_token(&self, token: &SampledToken) -> bool {
let (SampledToken::Content(LlamaToken(id))
| SampledToken::Reasoning(LlamaToken(id))
| SampledToken::ToolCall(LlamaToken(id))
| SampledToken::Undeterminable(LlamaToken(id))) = *token;
unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), id) }
}
#[must_use]
pub fn decode_start_token(&self) -> LlamaToken {
let token =
unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
LlamaToken(token)
}
#[must_use]
pub fn token_sep(&self) -> LlamaToken {
let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
LlamaToken(token)
}
pub fn str_to_token(
&self,
str: &str,
add_bos: AddBos,
) -> Result<Vec<LlamaToken>, StringToTokenError> {
let add_bos = match add_bos {
AddBos::Always => true,
AddBos::Never => false,
};
let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);
let (c_string, c_string_len) = cstring_with_validated_len(str)?;
let buffer_capacity = c_int::try_from(buffer.capacity())?;
let size = unsafe {
llama_cpp_bindings_sys::llama_tokenize(
self.vocab_ptr(),
c_string.as_ptr(),
c_string_len,
buffer
.as_mut_ptr()
.cast::<llama_cpp_bindings_sys::llama_token>(),
buffer_capacity,
add_bos,
true,
)
};
let size = if size.is_negative() {
buffer.reserve_exact(usize::try_from(-size)?);
unsafe {
llama_cpp_bindings_sys::llama_tokenize(
self.vocab_ptr(),
c_string.as_ptr(),
c_string_len,
buffer
.as_mut_ptr()
.cast::<llama_cpp_bindings_sys::llama_token>(),
-size,
add_bos,
true,
)
}
} else {
size
};
let size = usize::try_from(size)?;
unsafe { buffer.set_len(size) }
Ok(buffer)
}
pub fn token_attr(
&self,
LlamaToken(id): LlamaToken,
) -> Result<LlamaTokenAttrs, LlamaTokenAttrsFromIntError> {
let token_type =
unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
LlamaTokenAttrs::try_from(token_type)
}
pub fn token_to_piece(
&self,
token: &SampledToken,
decoder: &mut encoding_rs::Decoder,
special: bool,
lstrip: Option<NonZeroU16>,
) -> Result<String, TokenToStringError> {
let (SampledToken::Content(inner)
| SampledToken::Reasoning(inner)
| SampledToken::ToolCall(inner)
| SampledToken::Undeterminable(inner)) = *token;
let bytes = match self.token_to_piece_bytes(inner, 8, special, lstrip) {
Err(TokenToStringError::InsufficientBufferSpace(required_size)) => {
let buffer_size: usize = (-required_size).try_into()?;
self.token_to_piece_bytes(inner, buffer_size, special, lstrip)
}
other => other,
}?;
let mut output_piece = String::with_capacity(bytes.len());
let (_result, _decoded_size, _had_replacements) =
decoder.decode_to_string(&bytes, &mut output_piece, false);
Ok(output_piece)
}
pub fn token_to_piece_bytes(
&self,
token: LlamaToken,
buffer_size: usize,
special: bool,
lstrip: Option<NonZeroU16>,
) -> Result<Vec<u8>, TokenToStringError> {
let mut buffer: Vec<u8> = vec![0u8; buffer_size];
let buffer_len = c_int::try_from(buffer.len())?;
let lstrip = lstrip.map_or(0, |strip_count| i32::from(strip_count.get()));
let size = unsafe {
llama_cpp_bindings_sys::llama_token_to_piece(
self.vocab_ptr(),
token.0,
buffer.as_mut_ptr().cast::<c_char>(),
buffer_len,
lstrip,
special,
)
};
match size {
0 => Err(TokenToStringError::UnknownTokenType),
error_code if error_code.is_negative() => {
Err(TokenToStringError::InsufficientBufferSpace(error_code))
}
size => {
let written = usize::try_from(size)?;
buffer.truncate(written);
Ok(buffer)
}
}
}
#[must_use]
pub fn n_vocab(&self) -> i32 {
unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
}
pub fn vocab_type(&self) -> Result<VocabType, VocabTypeFromIntError> {
let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
VocabType::try_from(vocab_type)
}
#[must_use]
pub fn n_embd(&self) -> c_int {
unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
}
#[must_use]
pub fn size(&self) -> u64 {
unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
}
#[must_use]
pub fn n_params(&self) -> u64 {
unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
}
#[must_use]
pub fn is_recurrent(&self) -> bool {
unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
}
pub fn n_layer(&self) -> Result<u32, std::num::TryFromIntError> {
u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
}
pub fn n_head(&self) -> Result<u32, std::num::TryFromIntError> {
u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
}
pub fn n_head_kv(&self) -> Result<u32, std::num::TryFromIntError> {
u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
}
#[must_use]
pub fn is_hybrid(&self) -> bool {
unsafe { llama_cpp_bindings_sys::llama_model_is_hybrid(self.model.as_ptr()) }
}
pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
let key_cstring = CString::new(key)?;
let key_ptr = key_cstring.as_ptr();
extract_meta_string(
|buf_ptr, buf_len| unsafe {
llama_cpp_bindings_sys::llama_model_meta_val_str(
self.model.as_ptr(),
key_ptr,
buf_ptr,
buf_len,
)
},
256,
)
}
#[must_use]
pub fn meta_count(&self) -> i32 {
unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
}
pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
extract_meta_string(
|buf_ptr, buf_len| unsafe {
llama_cpp_bindings_sys::llama_model_meta_key_by_index(
self.model.as_ptr(),
index,
buf_ptr,
buf_len,
)
},
256,
)
}
pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
extract_meta_string(
|buf_ptr, buf_len| unsafe {
llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
self.model.as_ptr(),
index,
buf_ptr,
buf_len,
)
},
256,
)
}
#[must_use]
pub fn rope_type(&self) -> Option<RopeType> {
let raw = unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) };
rope_type::rope_type_from_raw(raw)
}
pub fn chat_template(
&self,
name: Option<&str>,
) -> Result<LlamaChatTemplate, ChatTemplateError> {
let name_cstr = name.map(CString::new);
let name_ptr = match name_cstr {
Some(Ok(name)) => name.as_ptr(),
_ => ptr::null(),
};
let result = unsafe {
llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
};
if result.is_null() {
Err(ChatTemplateError::MissingTemplate)
} else {
let chat_template_cstr = unsafe { CStr::from_ptr(result) };
Ok(LlamaChatTemplate(chat_template_cstr.to_owned()))
}
}
#[tracing::instrument(skip_all, fields(params))]
pub fn load_from_file(
_: &LlamaBackend,
path: impl AsRef<Path>,
params: &LlamaModelParams,
) -> Result<Self, LlamaModelLoadError> {
let path = path.as_ref();
let path_str = path
.to_str()
.ok_or_else(|| LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
if !path.exists() {
return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
}
let cstr = CString::new(path_str)?;
let llama_model = unsafe {
llama_cpp_bindings_sys::llama_load_model_from_file(cstr.as_ptr(), params.params)
};
let model = match NonNull::new(llama_model) {
Some(ptr) => ptr,
None if !path.exists() => {
return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
}
None => return Err(LlamaModelLoadError::NullResult),
};
Ok(Self {
model,
tok_env: OnceLock::new(),
})
}
pub fn lora_adapter_init(
&self,
path: impl AsRef<Path>,
) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
let path = path.as_ref();
let path_str = path
.to_str()
.ok_or_else(|| LlamaLoraAdapterInitError::PathToStrError(path.to_path_buf()))?;
if !path.exists() {
return Err(LlamaLoraAdapterInitError::FileNotFound(path.to_path_buf()));
}
let cstr = CString::new(path_str)?;
let raw_adapter = unsafe {
llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
};
let Some(adapter) = NonNull::new(raw_adapter) else {
return Err(LlamaLoraAdapterInitError::NullResult);
};
Ok(LlamaLoraAdapter {
lora_adapter: adapter,
})
}
#[tracing::instrument(skip_all)]
pub fn apply_chat_template(
&self,
tmpl: &LlamaChatTemplate,
chat: &[LlamaChatMessage],
add_ass: bool,
) -> Result<String, ApplyChatTemplateError> {
let message_length = chat.iter().fold(0, |acc, chat_message| {
acc + chat_message.role.to_bytes().len() + chat_message.content.to_bytes().len()
});
let mut buff: Vec<u8> = vec![0; message_length * 2];
let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
.iter()
.map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
role: chat_message.role.as_ptr(),
content: chat_message.content.as_ptr(),
})
.collect();
let tmpl_ptr = tmpl.0.as_ptr();
let buff_len: i32 = buff.len().try_into()?;
let res = unsafe {
llama_cpp_bindings_sys::llama_chat_apply_template(
tmpl_ptr,
chat.as_ptr(),
chat.len(),
add_ass,
buff.as_mut_ptr().cast::<c_char>(),
buff_len,
)
};
if res > buff_len {
let required_size: usize = res.try_into()?;
buff.resize(required_size, 0);
let new_buff_len: i32 = buff.len().try_into()?;
let res = unsafe {
llama_cpp_bindings_sys::llama_chat_apply_template(
tmpl_ptr,
chat.as_ptr(),
chat.len(),
add_ass,
buff.as_mut_ptr().cast::<c_char>(),
new_buff_len,
)
};
let final_size: usize = res.try_into()?;
return truncated_buffer_to_string(buff, final_size);
}
let final_size: usize = res.try_into()?;
truncated_buffer_to_string(buff, final_size)
}
pub fn sampled_token_classifier(&self) -> SampledTokenClassifier<'_> {
let markers = match self.streaming_markers() {
Ok(markers) => markers,
Err(detection_error) => {
tracing::warn!(
"streaming markers detection failed; classifier will run blind: {detection_error}"
);
StreamingMarkers::default()
}
};
SampledTokenClassifier::new(self, markers)
}
pub fn streaming_markers(&self) -> Result<StreamingMarkers, MarkerDetectionError> {
let (reasoning_open_str, reasoning_close_str) =
invoke_ffi_string_pair_detector(|first, second, error| unsafe {
llama_cpp_bindings_sys::llama_rs_detect_reasoning_markers(
self.model.as_ptr(),
first,
second,
error,
)
})?;
let tool_call_haystack = invoke_ffi_single_string_detector(|haystack, error| unsafe {
llama_cpp_bindings_sys::llama_rs_compute_tool_call_haystack(
self.model.as_ptr(),
haystack,
error,
)
})?;
let autoparser_pair = tool_call_haystack.as_deref().and_then(
crate::extract_tool_call_markers_from_haystack::extract_tool_call_markers_from_haystack,
);
let (autoparser_open, autoparser_close) = match autoparser_pair {
Some(crate::tool_call_marker_pair::ToolCallMarkerPair { open, close }) => {
(Some(open), Some(close))
}
None => (None, None),
};
let resolved_tool_call_markers =
self.resolve_tool_call_marker_strings(autoparser_open, autoparser_close);
Ok(StreamingMarkers {
reasoning_open: self.tokenize_marker(reasoning_open_str.as_deref()),
reasoning_close: self.tokenize_marker(reasoning_close_str.as_deref()),
tool_call_open: self.tokenize_marker(resolved_tool_call_markers.open.as_deref()),
tool_call_close: self.tokenize_marker(resolved_tool_call_markers.close.as_deref()),
})
}
fn resolve_tool_call_marker_strings(
&self,
autoparser_open: Option<String>,
autoparser_close: Option<String>,
) -> ResolvedToolCallMarkers {
if autoparser_open
.as_deref()
.is_some_and(|raw| !raw.trim().is_empty())
{
return ResolvedToolCallMarkers {
open: autoparser_open,
close: autoparser_close,
};
}
let Some(markers) = self.tool_call_markers() else {
return ResolvedToolCallMarkers {
open: autoparser_open,
close: autoparser_close,
};
};
let close = if markers.close.is_empty() {
None
} else {
Some(markers.close)
};
ResolvedToolCallMarkers {
open: Some(markers.open),
close,
}
}
pub fn reasoning_markers(&self) -> Result<Option<ReasoningMarkers>, MarkerDetectionError> {
let (open, close) = invoke_ffi_string_pair_detector(|first, second, error| unsafe {
llama_cpp_bindings_sys::llama_rs_detect_reasoning_markers(
self.model.as_ptr(),
first,
second,
error,
)
})?;
match (open, close) {
(Some(open), Some(close)) if !open.is_empty() && !close.is_empty() => {
Ok(Some(ReasoningMarkers { open, close }))
}
_ => Ok(None),
}
}
#[must_use]
pub fn tool_call_markers(&self) -> Option<ToolCallMarkers> {
let template = match self.chat_template(None) {
Ok(template) => template,
Err(error) => {
tracing::debug!(
"tool-call markers unavailable: chat template missing or invalid: {error}"
);
return None;
}
};
let template_str = match template.to_str() {
Ok(template_str) => template_str,
Err(error) => {
tracing::debug!(
"tool-call markers unavailable: chat template is not valid UTF-8: {error}"
);
return None;
}
};
tool_call_template_overrides::detect(template_str)
}
fn tokenize_marker(&self, marker: Option<&str>) -> Option<Vec<LlamaToken>> {
let marker = marker?.trim();
if marker.is_empty() {
return None;
}
match self.str_to_token(marker, AddBos::Never) {
Ok(tokens) if !tokens.is_empty() => Some(tokens),
Ok(_) => None,
Err(tokenize_error) => {
tracing::debug!(
"marker {marker:?} failed to tokenise; classifier will ignore it: {tokenize_error}"
);
None
}
}
}
pub fn parse_chat_message(
&self,
tools_json: &str,
input: &str,
is_partial: bool,
) -> Result<ChatMessageParseOutcome, ParseChatMessageError> {
let tools_value: serde_json::Value =
serde_json::from_str(tools_json).map_err(ParseChatMessageError::ToolsJsonInvalid)?;
if !tools_value.is_array() {
return Err(ParseChatMessageError::ToolsJsonNotArray);
}
let reasoning_markers = self.reasoning_markers().ok().flatten();
for candidate in tool_call_template_overrides::known_marker_candidates() {
if let ToolCallFormatOutcome::Parsed(calls) =
tool_call_format::try_parse(input, &candidate)
{
let split =
split_reasoning_prefix(input, reasoning_markers.as_ref(), &candidate.open);
let mut parsed = ParsedChatMessage::new(split.content, split.reasoning, calls);
synthesize_missing_tool_call_ids(&mut parsed.tool_calls);
return Ok(ChatMessageParseOutcome::Recognized(parsed));
}
}
match self.parse_chat_message_via_ffi(tools_json, input, is_partial) {
Ok(mut parsed) => {
synthesize_missing_tool_call_ids(&mut parsed.tool_calls);
Ok(ChatMessageParseOutcome::Recognized(parsed))
}
Err(ParseChatMessageError::ParseException(ffi_error_message)) => {
Ok(ChatMessageParseOutcome::Unrecognized(RawChatMessage {
tools_json: tools_json.to_owned(),
text: input.to_owned(),
is_partial,
ffi_error_message,
}))
}
Err(other) => Err(other),
}
}
fn parse_chat_message_via_ffi(
&self,
tools_json: &str,
input: &str,
is_partial: bool,
) -> Result<ParsedChatMessage, ParseChatMessageError> {
let tools_cstring = CString::new(tools_json)
.map_err(|err| ParseChatMessageError::ToolsSerialization(err.to_string()))?;
let input_cstring = CString::new(input)
.map_err(|err| ParseChatMessageError::ToolsSerialization(err.to_string()))?;
let mut handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat = ptr::null_mut();
let mut out_error: *mut c_char = ptr::null_mut();
let status = unsafe {
llama_cpp_bindings_sys::llama_rs_parse_chat_message(
self.model.as_ptr(),
tools_cstring.as_ptr(),
input_cstring.as_ptr(),
i32::from(is_partial),
&raw mut handle,
&raw mut out_error,
)
};
let parsed = match status {
llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => collect_parsed_chat_message(handle),
llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION => {
let message = read_optional_owned_cstr_lossy(out_error);
Err(ParseChatMessageError::ParseException(message))
}
other => Err(ParseChatMessageError::FfiError(status_to_i32(other))),
};
unsafe { llama_cpp_bindings_sys::llama_rs_parsed_chat_free(handle) };
unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
parsed
}
pub fn diagnose_tool_call_synthetic_renders(
&self,
) -> Result<(String, String), MarkerDetectionError> {
let (no_tools, with_tools) =
invoke_ffi_string_pair_detector(|first, second, error| unsafe {
llama_cpp_bindings_sys::llama_rs_diagnose_tool_call_synthetic_renders(
self.model.as_ptr(),
first,
second,
error,
)
})?;
Ok((no_tools.unwrap_or_default(), with_tools.unwrap_or_default()))
}
}
impl LlamaModel {
    /// Shared approximate tokenizer environment for this model.
    ///
    /// It is built by `build_approximate_tok_env` on the first call and cached in `tok_env`.
    /// Later calls return another handle to the same `Arc`.
    pub fn approximate_tok_env(&self) -> Arc<ApproximateTokEnv> {
        let cached = self.tok_env.get_or_init(|| build_approximate_tok_env(self));
        Arc::clone(cached)
    }
}
fn build_approximate_tok_env(model: &LlamaModel) -> Arc<ApproximateTokEnv> {
let n_vocab = model.n_vocab().cast_unsigned();
let tok_eos = {
let eot = unsafe { llama_cpp_bindings_sys::llama_vocab_eot(model.vocab_ptr()) };
if eot == -1 {
model.token_eos().0.cast_unsigned()
} else {
eot.cast_unsigned()
}
};
let info = TokRxInfo::new(n_vocab, tok_eos);
let mut words = Vec::with_capacity(n_vocab as usize);
for token_id in 0..n_vocab.cast_signed() {
let token = LlamaToken(token_id);
let bytes = model
.token_to_piece_bytes(token, 32, false, None)
.unwrap_or_default();
if bytes.is_empty() {
let special_bytes = model
.token_to_piece_bytes(token, 32, true, None)
.unwrap_or_default();
if special_bytes.is_empty() {
words.push(vec![]);
} else {
let mut marked = Vec::with_capacity(special_bytes.len() + 1);
marked.push(0xFF);
marked.extend(special_bytes);
words.push(marked);
}
} else {
words.push(bytes);
}
}
let trie = TokTrie::from(&info, &words);
Arc::new(ApproximateTokEnv::new(trie))
}
fn collect_parsed_chat_message(
handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat,
) -> Result<ParsedChatMessage, ParseChatMessageError> {
if handle.is_null() {
return Ok(ParsedChatMessage::default());
}
let content = read_owned_cstr_for_parse(unsafe {
llama_cpp_bindings_sys::llama_rs_parsed_chat_content(handle)
})?;
let reasoning_content = read_owned_cstr_for_parse(unsafe {
llama_cpp_bindings_sys::llama_rs_parsed_chat_reasoning_content(handle)
})?;
let count = unsafe { llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_count(handle) };
let mut tool_calls = Vec::with_capacity(count);
for index in 0..count {
let id = read_owned_cstr_for_parse(unsafe {
llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_id(handle, index)
})?;
let name = read_owned_cstr_for_parse(unsafe {
llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_name(handle, index)
})?;
let arguments_json = read_owned_cstr_for_parse(unsafe {
llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_arguments(handle, index)
})?;
let arguments = ToolCallArguments::from_string(arguments_json);
tool_calls.push(ParsedToolCall::new(id, name, arguments));
}
Ok(ParsedChatMessage::new(
content,
reasoning_content,
tool_calls,
))
}
/// Result of `split_reasoning_prefix`: text found between the reasoning markers, and the
/// remaining content up to the tool-call open marker.
struct ReasoningSplit {
/// Text between the reasoning open and close markers; empty when none was found.
reasoning: String,
/// Text before the tool-call open marker, taken after the reasoning close marker when
/// reasoning was found.
content: String,
}
/// Splits the text in front of a tool call into reasoning and content.
///
/// When both reasoning markers are present (the close marker after the first open marker), the
/// text between them becomes `reasoning`. `content` is then what follows the close marker, up to
/// `tool_call_open`. Otherwise `reasoning` is empty and `content` is `input` up to
/// `tool_call_open`.
fn split_reasoning_prefix(
    input: &str,
    reasoning_markers: Option<&ReasoningMarkers>,
    tool_call_open: &str,
) -> ReasoningSplit {
    let extracted = reasoning_markers.and_then(|markers| {
        let open_at = input.find(&markers.open)?;
        let after_open = &input[open_at + markers.open.len()..];
        let close_at = after_open.find(&markers.close)?;
        let thinking = after_open[..close_at].to_owned();
        let remainder = &after_open[close_at + markers.close.len()..];
        Some((thinking, remainder))
    });
    match extracted {
        Some((reasoning, remainder)) => ReasoningSplit {
            reasoning,
            content: prefix_before(remainder, tool_call_open),
        },
        None => ReasoningSplit {
            reasoning: String::new(),
            content: prefix_before(input, tool_call_open),
        },
    }
}
/// Owned copy of `text` up to (not including) the first occurrence of `marker`.
///
/// Returns all of `text` when `marker` does not occur. An empty `marker` matches at position 0,
/// giving an empty string.
fn prefix_before(text: &str, marker: &str) -> String {
    let cut = text.find(marker).unwrap_or(text.len());
    String::from(&text[..cut])
}
/// Gives every tool call with an empty id the placeholder `call_{index}`, where `index` is its
/// position in `tool_calls`. Calls that already have an id are left untouched.
fn synthesize_missing_tool_call_ids(tool_calls: &mut [ParsedToolCall]) {
    tool_calls
        .iter_mut()
        .enumerate()
        .filter(|(_, call)| call.id.is_empty())
        .for_each(|(position, call)| call.id = format!("call_{position}"));
}
/// Interprets the status of a single-string FFI detector.
///
/// - `OK`: the optional value is returned (`None` when null).
/// - `EXCEPTION`: the error text (read lossily) is wrapped in `AnalyzeException`.
/// - Any other status: becomes `FfiError`.
///
/// Neither pointer is freed here.
fn parse_single_string_status(
    status: llama_cpp_bindings_sys::llama_rs_status,
    out_value: *mut c_char,
    out_error: *mut c_char,
) -> Result<Option<String>, MarkerDetectionError> {
    match status {
        llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => read_optional_owned_cstr(out_value),
        llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION => Err(
            MarkerDetectionError::AnalyzeException(read_optional_owned_cstr_lossy(out_error)),
        ),
        unexpected => Err(MarkerDetectionError::FfiError(status_to_i32(unexpected))),
    }
}
/// Calls an FFI detector that reports one string plus an error string, and converts the outcome.
///
/// `invoke` receives pointers to two null-initialised out-slots (value, error). Any non-null
/// string left in a slot is released with `llama_rs_string_free` after parsing.
fn invoke_ffi_single_string_detector<TInvoke>(
    invoke: TInvoke,
) -> Result<Option<String>, MarkerDetectionError>
where
    TInvoke: FnOnce(*mut *mut c_char, *mut *mut c_char) -> llama_cpp_bindings_sys::llama_rs_status,
{
    let mut value_slot: *mut c_char = ptr::null_mut();
    let mut error_slot: *mut c_char = ptr::null_mut();
    let status = invoke(&raw mut value_slot, &raw mut error_slot);
    let outcome = parse_single_string_status(status, value_slot, error_slot);
    for owned in [value_slot, error_slot] {
        if !owned.is_null() {
            // SAFETY: a non-null slot holds a string allocated by the FFI side; it is freed
            // exactly once here.
            unsafe { llama_cpp_bindings_sys::llama_rs_string_free(owned) };
        }
    }
    outcome
}
/// Calls an FFI detector that reports two strings plus an error string, and converts the
/// outcome.
///
/// `invoke` receives pointers to three null-initialised out-slots (first, second, error).
///
/// - `OK`: both optional strings are returned.
/// - `EXCEPTION`: the error text becomes `AnalyzeException`.
/// - Any other status: becomes `FfiError`.
///
/// Every non-null slot is released with `llama_rs_string_free` afterwards. Null slots are skipped,
/// consistent with `invoke_ffi_single_string_detector`.
///
/// # Errors
///
/// Fails on a non-OK status, or when an OK string is not valid UTF-8.
fn invoke_ffi_string_pair_detector<TInvoke>(
    invoke: TInvoke,
) -> Result<(Option<String>, Option<String>), MarkerDetectionError>
where
    TInvoke: FnOnce(
        *mut *mut c_char,
        *mut *mut c_char,
        *mut *mut c_char,
    ) -> llama_cpp_bindings_sys::llama_rs_status,
{
    let mut out_first: *mut c_char = ptr::null_mut();
    let mut out_second: *mut c_char = ptr::null_mut();
    let mut out_error: *mut c_char = ptr::null_mut();
    let status = invoke(&raw mut out_first, &raw mut out_second, &raw mut out_error);
    let parsed = match status {
        llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => read_optional_owned_cstr(out_first)
            .and_then(|first| read_optional_owned_cstr(out_second).map(|second| (first, second))),
        llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION => Err(
            MarkerDetectionError::AnalyzeException(read_optional_owned_cstr_lossy(out_error)),
        ),
        other => Err(MarkerDetectionError::FfiError(status_to_i32(other))),
    };
    for owned in [out_first, out_second, out_error] {
        if !owned.is_null() {
            // SAFETY: a non-null slot holds a string allocated by the FFI side; it is freed
            // exactly once here.
            unsafe { llama_cpp_bindings_sys::llama_rs_string_free(owned) };
        }
    }
    parsed
}
/// Copies an FFI-owned C string into a `String`, then frees it with `llama_rs_string_free`.
///
/// A null pointer yields an empty string. The pointer is freed before UTF-8 validation, so it is
/// released even when conversion fails.
///
/// # Errors
///
/// Fails when the bytes are not valid UTF-8.
fn read_owned_cstr_for_parse(ptr: *mut c_char) -> Result<String, ParseChatMessageError> {
    if ptr.is_null() {
        return Ok(String::new());
    }
    // SAFETY: `ptr` is a non-null, NUL-terminated string owned by the FFI side.
    let copied: Vec<u8> = unsafe { CStr::from_ptr(ptr) }.to_bytes().to_owned();
    // SAFETY: the contents were copied above; `ptr` is not used again.
    unsafe { llama_cpp_bindings_sys::llama_rs_string_free(ptr) };
    let text = String::from_utf8(copied)?;
    Ok(text)
}
fn read_optional_owned_cstr(ptr: *const c_char) -> Result<Option<String>, MarkerDetectionError> {
if ptr.is_null() {
return Ok(None);
}
let bytes = unsafe { CStr::from_ptr(ptr) }.to_bytes().to_vec();
Ok(Some(String::from_utf8(bytes)?))
}
/// Copies a possibly-null C string into a `String`, replacing invalid UTF-8 with U+FFFD.
///
/// A null pointer yields an empty string. The pointer is not freed.
fn read_optional_owned_cstr_lossy(ptr: *const c_char) -> String {
    if ptr.is_null() {
        String::new()
    } else {
        // SAFETY: `ptr` is non-null and, per callers, NUL-terminated.
        let borrowed = unsafe { CStr::from_ptr(ptr) };
        String::from_utf8_lossy(borrowed.to_bytes()).into_owned()
    }
}
/// Reads a string from a snprintf-style C accessor, growing the buffer when it is too small.
///
/// `c_function` receives a buffer pointer and its length, and returns the full string length (or
/// a negative value on error). A reported length of at least the buffer size triggers a retry
/// with `length + 1` bytes.
///
/// # Errors
///
/// - `NegativeReturn(n)` when the accessor returns `n < 0`.
/// - `NegativeReturn(-1)` when the byte at the reported length is not a NUL terminator.
/// - A UTF-8 conversion error for invalid text.
fn extract_meta_string<TCFunction>(
    c_function: TCFunction,
    capacity: usize,
) -> Result<String, MetaValError>
where
    TCFunction: Fn(*mut c_char, usize) -> i32,
{
    let mut buffer_size = capacity;
    loop {
        let mut scratch = vec![0u8; buffer_size];
        let status = c_function(scratch.as_mut_ptr().cast::<c_char>(), scratch.len());
        if status < 0 {
            return Err(MetaValError::NegativeReturn(status));
        }
        let reported_len = status.cast_unsigned() as usize;
        if reported_len >= buffer_size {
            // Room is needed for the string plus its NUL terminator.
            buffer_size = reported_len + 1;
            continue;
        }
        if scratch.get(reported_len) != Some(&0) {
            return Err(MetaValError::NegativeReturn(-1));
        }
        scratch.truncate(reported_len);
        return Ok(String::from_utf8(scratch)?);
    }
}
impl Drop for LlamaModel {
    /// Releases the underlying llama.cpp model with `llama_free_model`.
    fn drop(&mut self) {
        let raw_model = self.model.as_ptr();
        // SAFETY: `raw_model` is non-null, owned exclusively by this handle, and never used
        // after this point.
        unsafe { llama_cpp_bindings_sys::llama_free_model(raw_model) };
    }
}
// Unit tests for `extract_meta_string` and, further down, for the small string/length helpers
// `cstring_with_validated_len`, `validate_string_length_for_tokenizer` and
// `truncated_buffer_to_string`.
#[cfg(test)]
mod extract_meta_string_tests {
use super::extract_meta_string;
use crate::MetaValError;
// The accessor writes "abc" without a NUL and reports length 2, so the byte at index 2 is 'c'
// rather than 0 and the missing-terminator error (`NegativeReturn(-1)`) is expected.
#[test]
fn returns_error_when_null_terminator_missing() {
let result = extract_meta_string(
|buf_ptr, buf_len| {
let buffer =
unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
buffer[0] = b'a';
buffer[1] = b'b';
buffer[2] = b'c';
2
},
4,
);
assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-1));
}
// A negative accessor return is passed through unchanged in `NegativeReturn`.
#[test]
fn returns_error_for_negative_return_value() {
let result = extract_meta_string(|_buf_ptr, _buf_len| -5, 4);
assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-5));
}
// A correctly terminated but non-UTF-8 payload (0xFF 0xFE) must surface the UTF-8 error.
#[test]
fn returns_error_for_invalid_utf8_data() {
let result = extract_meta_string(
|buf_ptr, buf_len| {
let buffer =
unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
buffer[0] = 0xFF;
buffer[1] = 0xFE;
buffer[2] = 0;
2
},
4,
);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("FromUtf8Error"));
}
// The first call reports a length (10) above the initial capacity (4), forcing a retry. The
// second call writes "hi\0" and reports 2, so the retried read must succeed with "hi".
#[test]
fn triggers_buffer_resize_when_returned_len_exceeds_capacity() {
let initial_capacity: usize = 4;
let length_exceeding_initial_capacity = 10;
let written_length = 2;
let call_count = std::cell::Cell::new(0);
let result = extract_meta_string(
|buf_ptr, buf_len| {
let count = call_count.get();
call_count.set(count + 1);
if count == 0 {
length_exceeding_initial_capacity
} else {
let buffer =
unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
buffer[0] = b'h';
buffer[1] = b'i';
buffer[2] = 0;
written_length
}
},
initial_capacity,
);
assert_eq!(result.unwrap(), "hi");
}
// An interior NUL cannot be represented in a C string.
#[test]
fn cstring_with_validated_len_null_byte_returns_error() {
let result = super::cstring_with_validated_len("null\0byte");
assert!(result.is_err());
}
// `usize::MAX` does not fit in `c_int`.
#[test]
fn validate_string_length_overflow_returns_error() {
let result = super::validate_string_length_for_tokenizer(usize::MAX);
assert!(result.is_err());
}
// Bytes 0xFF 0xFE 0xFD are not valid UTF-8, so conversion must fail.
#[test]
fn truncated_buffer_to_string_with_invalid_utf8_returns_error() {
let invalid_utf8 = vec![0xff, 0xfe, 0xfd];
let result = super::truncated_buffer_to_string(invalid_utf8, 3);
assert!(result.is_err());
}
}
// Unit tests for the FFI string helpers. Rust-owned `CString`s stand in for FFI-allocated
// strings wherever the helper under test only reads, and never frees, the pointer.
#[cfg(test)]
mod ffi_helper_tests {
use std::ffi::CString;
use std::ptr;
use super::invoke_ffi_single_string_detector;
use super::invoke_ffi_string_pair_detector;
use super::parse_single_string_status;
use super::read_optional_owned_cstr_lossy;
use crate::MarkerDetectionError;
// A null pointer reads as an empty string.
#[test]
fn read_optional_owned_cstr_lossy_returns_empty_for_null() {
let result = read_optional_owned_cstr_lossy(ptr::null());
assert!(result.is_empty());
}
// Valid UTF-8 is copied unchanged.
#[test]
fn read_optional_owned_cstr_lossy_returns_string_for_valid_pointer() {
let owned = CString::new("hello").expect("static literal has no nuls");
let result = read_optional_owned_cstr_lossy(owned.as_ptr());
assert_eq!(result, "hello");
}
// The invalid 0xFF byte is replaced, while the surrounding valid bytes survive.
#[test]
fn read_optional_owned_cstr_lossy_handles_invalid_utf8_via_replacement() {
let owned = CString::new(vec![b'a', 0xFF, b'b']).expect("no interior nul");
let result = read_optional_owned_cstr_lossy(owned.as_ptr());
assert!(result.starts_with('a'));
assert!(result.ends_with('b'));
}
// OK status with a null value maps to `Ok(None)`.
#[test]
fn parse_single_string_status_returns_none_for_ok_with_null() {
let result = parse_single_string_status(
llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
ptr::null_mut(),
ptr::null_mut(),
);
assert_eq!(result.expect("OK + null returns Ok(None)"), None);
}
// OK status with a value maps to `Ok(Some(value))`. `parse_single_string_status` does not
// free, so a Rust-owned `CString` is safe to pass here.
#[test]
fn parse_single_string_status_returns_some_for_ok_with_value() {
let owned = CString::new("present").expect("no nul");
let result = parse_single_string_status(
llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
owned.as_ptr().cast_mut(),
ptr::null_mut(),
);
assert_eq!(
result.expect("OK + value returns Ok(Some)"),
Some("present".to_owned())
);
}
// EXCEPTION status carries the error text in `AnalyzeException`.
#[test]
fn parse_single_string_status_returns_analyze_exception() {
let owned = CString::new("boom").expect("no nul");
let result = parse_single_string_status(
llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION,
ptr::null_mut(),
owned.as_ptr().cast_mut(),
);
match result.expect_err("EXCEPTION must yield Err") {
MarkerDetectionError::AnalyzeException(message) => assert_eq!(message, "boom"),
other => panic!("expected AnalyzeException, got {other:?}"),
}
}
// Any status other than OK/EXCEPTION maps to `FfiError`.
#[test]
fn parse_single_string_status_returns_ffi_error_for_other_status() {
let result = parse_single_string_status(
llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT,
ptr::null_mut(),
ptr::null_mut(),
);
match result.expect_err("invalid status must yield Err") {
MarkerDetectionError::FfiError(_) => {}
other => panic!("expected FfiError, got {other:?}"),
}
}
// The invoker never writes its out-slots, so they stay null and nothing needs freeing.
#[test]
fn invoke_ffi_single_string_detector_propagates_invalid_argument_status() {
let result = invoke_ffi_single_string_detector(|_value, _error| {
llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT
});
assert!(matches!(result, Err(MarkerDetectionError::FfiError(_))));
}
// Explicitly writing null into the value slot with OK status yields `Ok(None)`.
#[test]
fn invoke_ffi_single_string_detector_returns_none_for_ok_with_null() {
let result = invoke_ffi_single_string_detector(|value, _error| {
unsafe {
*value = ptr::null_mut();
}
llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK
});
assert_eq!(result.expect("OK + null returns Ok(None)"), None);
}
// The pair detector maps non-OK/non-EXCEPTION statuses to `FfiError`.
#[test]
fn invoke_ffi_string_pair_detector_propagates_invalid_argument_status() {
let result = invoke_ffi_string_pair_detector(|_first, _second, _error| {
llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT
});
assert!(matches!(result, Err(MarkerDetectionError::FfiError(_))));
}
// OK status with both slots null yields `(None, None)`.
#[test]
fn invoke_ffi_string_pair_detector_returns_pair_of_none_for_ok_with_nulls() {
let result = invoke_ffi_string_pair_detector(|first, second, _error| {
unsafe {
*first = ptr::null_mut();
*second = ptr::null_mut();
}
llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK
});
assert_eq!(
result.expect("OK with both null returns Ok((None, None))"),
(None, None)
);
}
// ALLOCATION_FAILED is reported as `FfiError` with a non-zero code.
#[test]
fn invoke_ffi_string_pair_detector_propagates_invalid_status_codes() {
let result = invoke_ffi_string_pair_detector(|_first, _second, _error| {
llama_cpp_bindings_sys::LLAMA_RS_STATUS_ALLOCATION_FAILED
});
match result.expect_err("non-OK status yields Err") {
MarkerDetectionError::FfiError(code) => assert!(code != 0),
other => panic!("expected FfiError, got {other:?}"),
}
}
}