1pub mod add_bos;
2pub mod llama_chat_message;
3pub mod llama_chat_template;
4pub mod llama_lora_adapter;
5pub mod llama_split_mode_parse_error;
6pub mod params;
7pub mod rope_type;
8pub mod split_mode;
9pub mod vocab_type;
10pub mod vocab_type_from_int_error;
11
12use std::ffi::{CStr, CString, c_char};
13use std::num::NonZeroU16;
14use std::os::raw::c_int;
15use std::path::Path;
16use std::ptr;
17use std::ptr::NonNull;
18use std::sync::Arc;
19use std::sync::OnceLock;
20
21use toktrie::ApproximateTokEnv;
22use toktrie::TokRxInfo;
23use toktrie::TokTrie;
24
25use llama_cpp_bindings_types::ParsedChatMessage;
26use llama_cpp_bindings_types::ParsedToolCall;
27use llama_cpp_bindings_types::ReasoningMarkers;
28use llama_cpp_bindings_types::ToolCallArguments;
29use llama_cpp_bindings_types::ToolCallMarkers;
30
31use crate::chat_message_parse_outcome::ChatMessageParseOutcome;
32use crate::llama_backend::LlamaBackend;
33use crate::llama_token_attrs::LlamaTokenAttrs;
34use crate::llama_token_attrs_from_int_error::LlamaTokenAttrsFromIntError;
35use crate::raw_chat_message::RawChatMessage;
36use crate::resolved_tool_call_markers::ResolvedToolCallMarkers;
37use crate::sampled_token::SampledToken;
38use crate::sampled_token_classifier::SampledTokenClassifier;
39use crate::streaming_markers::StreamingMarkers;
40use crate::token::LlamaToken;
41use crate::tool_call_format;
42use crate::tool_call_format::ToolCallFormatOutcome;
43use crate::tool_call_template_overrides;
44use crate::{
45 ApplyChatTemplateError, ChatTemplateError, LlamaLoraAdapterInitError, LlamaModelLoadError,
46 MarkerDetectionError, MetaValError, ParseChatMessageError, StringToTokenError,
47 TokenToStringError,
48};
49
50pub use add_bos::AddBos;
51pub use llama_chat_message::LlamaChatMessage;
52pub use llama_chat_template::LlamaChatTemplate;
53pub use llama_lora_adapter::LlamaLoraAdapter;
54pub use rope_type::RopeType;
55pub use vocab_type::VocabType;
56pub use vocab_type_from_int_error::VocabTypeFromIntError;
57
58use params::LlamaModelParams;
59
60fn truncated_buffer_to_string(
61 mut buffer: Vec<u8>,
62 length: usize,
63) -> Result<String, ApplyChatTemplateError> {
64 buffer.truncate(length);
65
66 Ok(String::from_utf8(buffer)?)
67}
68
69fn validate_string_length_for_tokenizer(length: usize) -> Result<c_int, StringToTokenError> {
70 Ok(c_int::try_from(length)?)
71}
72
73fn cstring_with_validated_len(str: &str) -> Result<(CString, c_int), StringToTokenError> {
74 let c_string = CString::new(str)?;
75 let len = validate_string_length_for_tokenizer(c_string.as_bytes().len())?;
76 Ok((c_string, len))
77}
78
79pub struct LlamaModel {
80 pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
81 tok_env: OnceLock<Arc<ApproximateTokEnv>>,
82}
83
84impl std::fmt::Debug for LlamaModel {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct("LlamaModel")
87 .field("model", &self.model)
88 .finish_non_exhaustive()
89 }
90}
91
92unsafe impl Send for LlamaModel {}
93
94unsafe impl Sync for LlamaModel {}
95
96impl LlamaModel {
97 #[must_use]
98 pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
99 unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
100 }
101
102 pub fn n_ctx_train(&self) -> Result<u32, std::num::TryFromIntError> {
106 let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
107
108 u32::try_from(n_ctx_train)
109 }
110
111 pub fn tokens(
112 &self,
113 decode_special: bool,
114 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
115 (0..self.n_vocab())
116 .map(LlamaToken::new)
117 .map(move |llama_token| {
118 let mut decoder = encoding_rs::UTF_8.new_decoder();
119 (
120 llama_token,
121 self.token_to_piece(
122 &SampledToken::Content(llama_token),
123 &mut decoder,
124 decode_special,
125 None,
126 ),
127 )
128 })
129 }
130
131 #[must_use]
132 pub fn token_bos(&self) -> LlamaToken {
133 let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
134 LlamaToken(token)
135 }
136
137 #[must_use]
138 pub fn token_eos(&self) -> LlamaToken {
139 let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
140 LlamaToken(token)
141 }
142
143 #[must_use]
144 pub fn token_nl(&self) -> LlamaToken {
145 let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
146 LlamaToken(token)
147 }
148
149 #[must_use]
150 pub fn is_eog_token(&self, token: &SampledToken) -> bool {
151 let (SampledToken::Content(LlamaToken(id))
152 | SampledToken::Reasoning(LlamaToken(id))
153 | SampledToken::ToolCall(LlamaToken(id))
154 | SampledToken::Undeterminable(LlamaToken(id))) = *token;
155
156 unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), id) }
157 }
158
159 #[must_use]
160 pub fn decode_start_token(&self) -> LlamaToken {
161 let token =
162 unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
163 LlamaToken(token)
164 }
165
166 #[must_use]
167 pub fn token_sep(&self) -> LlamaToken {
168 let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
169 LlamaToken(token)
170 }
171
172 pub fn str_to_token(
182 &self,
183 str: &str,
184 add_bos: AddBos,
185 ) -> Result<Vec<LlamaToken>, StringToTokenError> {
186 let add_bos = match add_bos {
187 AddBos::Always => true,
188 AddBos::Never => false,
189 };
190
191 let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
192 let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);
193
194 let (c_string, c_string_len) = cstring_with_validated_len(str)?;
195 let buffer_capacity = c_int::try_from(buffer.capacity())?;
196
197 let size = invoke_rs_tokenize(
198 self.vocab_ptr(),
199 c_string.as_ptr(),
200 c_string_len,
201 buffer
202 .as_mut_ptr()
203 .cast::<llama_cpp_bindings_sys::llama_token>(),
204 buffer_capacity,
205 add_bos,
206 )?;
207
208 let size = if size.is_negative() {
209 buffer.reserve_exact(usize::try_from(-size)?);
210 invoke_rs_tokenize(
211 self.vocab_ptr(),
212 c_string.as_ptr(),
213 c_string_len,
214 buffer
215 .as_mut_ptr()
216 .cast::<llama_cpp_bindings_sys::llama_token>(),
217 -size,
218 add_bos,
219 )?
220 } else {
221 size
222 };
223
224 let size = usize::try_from(size)?;
225
226 unsafe { buffer.set_len(size) }
228
229 Ok(buffer)
230 }
231
232 pub fn token_attr(
236 &self,
237 LlamaToken(id): LlamaToken,
238 ) -> Result<LlamaTokenAttrs, LlamaTokenAttrsFromIntError> {
239 let token_type =
240 unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
241
242 LlamaTokenAttrs::try_from(token_type)
243 }
244
245 pub fn token_to_piece(
251 &self,
252 token: &SampledToken,
253 decoder: &mut encoding_rs::Decoder,
254 special: bool,
255 lstrip: Option<NonZeroU16>,
256 ) -> Result<String, TokenToStringError> {
257 let (SampledToken::Content(inner)
258 | SampledToken::Reasoning(inner)
259 | SampledToken::ToolCall(inner)
260 | SampledToken::Undeterminable(inner)) = *token;
261 let bytes = match self.token_to_piece_bytes(inner, 8, special, lstrip) {
262 Err(TokenToStringError::InsufficientBufferSpace(required_size)) => {
263 let buffer_size: usize = (-required_size).try_into()?;
264
265 self.token_to_piece_bytes(inner, buffer_size, special, lstrip)
266 }
267 other => other,
268 }?;
269
270 let mut output_piece = String::with_capacity(bytes.len());
271 let (_result, _decoded_size, _had_replacements) =
272 decoder.decode_to_string(&bytes, &mut output_piece, false);
273
274 Ok(output_piece)
275 }
276
277 pub fn token_to_piece_bytes(
283 &self,
284 token: LlamaToken,
285 buffer_size: usize,
286 special: bool,
287 lstrip: Option<NonZeroU16>,
288 ) -> Result<Vec<u8>, TokenToStringError> {
289 let mut buffer: Vec<u8> = vec![0u8; buffer_size];
290 let buffer_len = c_int::try_from(buffer.len())?;
291 let lstrip = lstrip.map_or(0, |strip_count| i32::from(strip_count.get()));
292 let size = unsafe {
293 llama_cpp_bindings_sys::llama_token_to_piece(
294 self.vocab_ptr(),
295 token.0,
296 buffer.as_mut_ptr().cast::<c_char>(),
297 buffer_len,
298 lstrip,
299 special,
300 )
301 };
302
303 match size {
304 0 => Err(TokenToStringError::UnknownTokenType),
305 error_code if error_code.is_negative() => {
306 Err(TokenToStringError::InsufficientBufferSpace(error_code))
307 }
308 size => {
309 let written = usize::try_from(size)?;
310 buffer.truncate(written);
311
312 Ok(buffer)
313 }
314 }
315 }
316
317 #[must_use]
318 pub fn n_vocab(&self) -> i32 {
319 unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
320 }
321
322 pub fn vocab_type(&self) -> Result<VocabType, VocabTypeFromIntError> {
326 let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
327
328 VocabType::try_from(vocab_type)
329 }
330
331 #[must_use]
332 pub fn n_embd(&self) -> c_int {
333 unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
334 }
335
336 #[must_use]
337 pub fn size(&self) -> u64 {
338 unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
339 }
340
341 #[must_use]
342 pub fn n_params(&self) -> u64 {
343 unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
344 }
345
346 #[must_use]
347 pub fn is_recurrent(&self) -> bool {
348 unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
349 }
350
351 pub fn n_layer(&self) -> Result<u32, std::num::TryFromIntError> {
355 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
356 }
357
358 pub fn n_head(&self) -> Result<u32, std::num::TryFromIntError> {
362 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
363 }
364
365 pub fn n_head_kv(&self) -> Result<u32, std::num::TryFromIntError> {
369 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
370 }
371
372 #[must_use]
373 pub fn is_hybrid(&self) -> bool {
374 unsafe { llama_cpp_bindings_sys::llama_model_is_hybrid(self.model.as_ptr()) }
375 }
376
377 pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
380 let key_cstring = CString::new(key)?;
381 let key_ptr = key_cstring.as_ptr();
382
383 extract_meta_string(
384 |buf_ptr, buf_len| unsafe {
385 llama_cpp_bindings_sys::llama_model_meta_val_str(
386 self.model.as_ptr(),
387 key_ptr,
388 buf_ptr,
389 buf_len,
390 )
391 },
392 256,
393 )
394 }
395
396 #[must_use]
397 pub fn meta_count(&self) -> i32 {
398 unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
399 }
400
401 pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
404 extract_meta_string(
405 |buf_ptr, buf_len| unsafe {
406 llama_cpp_bindings_sys::llama_model_meta_key_by_index(
407 self.model.as_ptr(),
408 index,
409 buf_ptr,
410 buf_len,
411 )
412 },
413 256,
414 )
415 }
416
417 pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
420 extract_meta_string(
421 |buf_ptr, buf_len| unsafe {
422 llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
423 self.model.as_ptr(),
424 index,
425 buf_ptr,
426 buf_len,
427 )
428 },
429 256,
430 )
431 }
432
433 #[must_use]
434 pub fn rope_type(&self) -> Option<RopeType> {
435 let raw = unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) };
436
437 rope_type::rope_type_from_raw(raw)
438 }
439
440 pub fn chat_template(
449 &self,
450 name: Option<&str>,
451 ) -> Result<LlamaChatTemplate, ChatTemplateError> {
452 let name_cstr = name.map(CString::new);
453 let name_ptr = match name_cstr {
454 Some(Ok(name)) => name.as_ptr(),
455 _ => ptr::null(),
456 };
457 let result = unsafe {
458 llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
459 };
460
461 if result.is_null() {
462 Err(ChatTemplateError::MissingTemplate)
463 } else {
464 let chat_template_cstr = unsafe { CStr::from_ptr(result) };
465
466 Ok(LlamaChatTemplate(chat_template_cstr.to_owned()))
467 }
468 }
469
470 pub fn load_from_file(
478 _: &LlamaBackend,
479 path: impl AsRef<Path>,
480 params: &LlamaModelParams,
481 ) -> Result<Self, LlamaModelLoadError> {
482 let path = path.as_ref();
483
484 let path_str = path
485 .to_str()
486 .ok_or_else(|| LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
487
488 if !path.exists() {
489 return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
490 }
491
492 let cstr = CString::new(path_str)?;
493 let mut out_model: *mut llama_cpp_bindings_sys::llama_model = ptr::null_mut();
494 let mut out_error: *mut c_char = ptr::null_mut();
495 let status = unsafe {
496 llama_cpp_bindings_sys::llama_rs_load_model_from_file(
497 cstr.as_ptr(),
498 params.params,
499 &raw mut out_model,
500 &raw mut out_error,
501 )
502 };
503 match status {
504 llama_cpp_bindings_sys::LLAMA_RS_LOAD_MODEL_FROM_FILE_OK => {
505 let model = NonNull::new(out_model)
506 .ok_or(LlamaModelLoadError::Unloadable)?;
507 Ok(Self {
508 model,
509 tok_env: OnceLock::new(),
510 })
511 }
512 llama_cpp_bindings_sys::LLAMA_RS_LOAD_MODEL_FROM_FILE_VENDORED_RETURNED_NULL => {
513 if path.exists() {
514 Err(LlamaModelLoadError::Unloadable)
515 } else {
516 Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()))
517 }
518 }
519 llama_cpp_bindings_sys::LLAMA_RS_LOAD_MODEL_FROM_FILE_ERROR_STRING_ALLOCATION_FAILED => {
520 Err(LlamaModelLoadError::NotEnoughMemory)
521 }
522 llama_cpp_bindings_sys::LLAMA_RS_LOAD_MODEL_FROM_FILE_VENDORED_THREW_CXX_EXCEPTION => {
523 let message = unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
524 Err(LlamaModelLoadError::Reported { message })
525 }
526 other => unreachable!(
527 "llama_rs_load_model_from_file returned unrecognized status {other}"
528 ),
529 }
530 }
531
532 pub fn lora_adapter_init(
536 &self,
537 path: impl AsRef<Path>,
538 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
539 let path = path.as_ref();
540
541 let path_str = path
542 .to_str()
543 .ok_or_else(|| LlamaLoraAdapterInitError::PathToStrError(path.to_path_buf()))?;
544
545 if !path.exists() {
546 return Err(LlamaLoraAdapterInitError::FileNotFound(path.to_path_buf()));
547 }
548
549 let cstr = CString::new(path_str)?;
550 let raw_adapter = unsafe {
551 llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
552 };
553
554 let Some(adapter) = NonNull::new(raw_adapter) else {
555 return Err(LlamaLoraAdapterInitError::Unloadable);
556 };
557
558 Ok(LlamaLoraAdapter {
559 lora_adapter: adapter,
560 })
561 }
562
563 pub fn apply_chat_template(
566 &self,
567 tmpl: &LlamaChatTemplate,
568 chat: &[LlamaChatMessage],
569 add_ass: bool,
570 ) -> Result<String, ApplyChatTemplateError> {
571 let message_length = chat.iter().fold(0, |acc, chat_message| {
572 acc + chat_message.role.to_bytes().len() + chat_message.content.to_bytes().len()
573 });
574 let mut buff: Vec<u8> = vec![0; message_length * 2];
575
576 let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
577 .iter()
578 .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
579 role: chat_message.role.as_ptr(),
580 content: chat_message.content.as_ptr(),
581 })
582 .collect();
583
584 let tmpl_ptr = tmpl.0.as_ptr();
585
586 let buff_len: i32 = buff.len().try_into()?;
587
588 let res = unsafe {
589 llama_cpp_bindings_sys::llama_chat_apply_template(
590 tmpl_ptr,
591 chat.as_ptr(),
592 chat.len(),
593 add_ass,
594 buff.as_mut_ptr().cast::<c_char>(),
595 buff_len,
596 )
597 };
598
599 if res > buff_len {
600 let required_size: usize = res.try_into()?;
601 buff.resize(required_size, 0);
602
603 let new_buff_len: i32 = buff.len().try_into()?;
604
605 let res = unsafe {
606 llama_cpp_bindings_sys::llama_chat_apply_template(
607 tmpl_ptr,
608 chat.as_ptr(),
609 chat.len(),
610 add_ass,
611 buff.as_mut_ptr().cast::<c_char>(),
612 new_buff_len,
613 )
614 };
615 let final_size: usize = res.try_into()?;
616
617 return truncated_buffer_to_string(buff, final_size);
618 }
619
620 let final_size: usize = res.try_into()?;
621
622 truncated_buffer_to_string(buff, final_size)
623 }
624
625 pub fn sampled_token_classifier(&self) -> SampledTokenClassifier<'_> {
626 let markers = match self.streaming_markers() {
627 Ok(markers) => markers,
628 Err(detection_error) => {
629 log::warn!(
630 "streaming markers detection failed; classifier will run blind: {detection_error}",
631 );
632 StreamingMarkers::default()
633 }
634 };
635
636 SampledTokenClassifier::new(self, markers)
637 }
638
639 pub fn streaming_markers(&self) -> Result<StreamingMarkers, MarkerDetectionError> {
642 let (reasoning_open_str, reasoning_close_str) =
643 invoke_detect_reasoning_markers(self.model.as_ptr())?;
644
645 let tool_call_haystack = invoke_compute_tool_call_haystack(self.model.as_ptr())?;
646
647 let autoparser_pair = tool_call_haystack.as_deref().and_then(
648 crate::extract_tool_call_markers_from_haystack::extract_tool_call_markers_from_haystack,
649 );
650
651 let (autoparser_open, autoparser_close) = match autoparser_pair {
652 Some(crate::tool_call_marker_pair::ToolCallMarkerPair { open, close }) => {
653 (Some(open), Some(close))
654 }
655 None => (None, None),
656 };
657
658 let resolved_tool_call_markers =
659 self.resolve_tool_call_marker_strings(autoparser_open, autoparser_close);
660
661 Ok(StreamingMarkers {
662 reasoning_open: self.tokenize_marker(reasoning_open_str.as_deref()),
663 reasoning_close: self.tokenize_marker(reasoning_close_str.as_deref()),
664 tool_call_open: self.tokenize_marker(resolved_tool_call_markers.open.as_deref()),
665 tool_call_close: self.tokenize_marker(resolved_tool_call_markers.close.as_deref()),
666 })
667 }
668
669 fn resolve_tool_call_marker_strings(
670 &self,
671 autoparser_open: Option<String>,
672 autoparser_close: Option<String>,
673 ) -> ResolvedToolCallMarkers {
674 if autoparser_open
675 .as_deref()
676 .is_some_and(|raw| !raw.trim().is_empty())
677 {
678 return ResolvedToolCallMarkers {
679 open: autoparser_open,
680 close: autoparser_close,
681 };
682 }
683 let Some(markers) = self.tool_call_markers() else {
684 return ResolvedToolCallMarkers {
685 open: autoparser_open,
686 close: autoparser_close,
687 };
688 };
689 let close = if markers.close.is_empty() {
690 None
691 } else {
692 Some(markers.close)
693 };
694 ResolvedToolCallMarkers {
695 open: Some(markers.open),
696 close,
697 }
698 }
699
700 pub fn reasoning_markers(&self) -> Result<Option<ReasoningMarkers>, MarkerDetectionError> {
703 let (open, close) = invoke_detect_reasoning_markers(self.model.as_ptr())?;
704
705 match (open, close) {
706 (Some(open), Some(close)) if !open.is_empty() && !close.is_empty() => {
707 Ok(Some(ReasoningMarkers { open, close }))
708 }
709 _ => Ok(None),
710 }
711 }
712
713 #[must_use]
714 pub fn tool_call_markers(&self) -> Option<ToolCallMarkers> {
715 let template = match self.chat_template(None) {
716 Ok(template) => template,
717 Err(error) => {
718 log::debug!(
719 "tool-call markers unavailable: chat template missing or invalid: {error}",
720 );
721 return None;
722 }
723 };
724 let template_str = match template.to_str() {
725 Ok(template_str) => template_str,
726 Err(error) => {
727 log::debug!(
728 "tool-call markers unavailable: chat template is not valid UTF-8: {error}",
729 );
730 return None;
731 }
732 };
733 tool_call_template_overrides::detect(template_str)
734 }
735
736 fn tokenize_marker(&self, marker: Option<&str>) -> Option<Vec<LlamaToken>> {
737 let marker = marker?.trim();
738 if marker.is_empty() {
739 return None;
740 }
741 match self.str_to_token(marker, AddBos::Never) {
742 Ok(tokens) if !tokens.is_empty() => Some(tokens),
743 Ok(_) => None,
744 Err(tokenize_error) => {
745 log::debug!(
746 "marker {marker:?} failed to tokenise; classifier will ignore it: {tokenize_error}",
747 );
748 None
749 }
750 }
751 }
752
753 pub fn parse_chat_message(
759 &self,
760 tools_json: &str,
761 input: &str,
762 is_partial: bool,
763 ) -> Result<ChatMessageParseOutcome, ParseChatMessageError> {
764 let tools_value: serde_json::Value =
765 serde_json::from_str(tools_json).map_err(ParseChatMessageError::ToolsJsonInvalid)?;
766 if !tools_value.is_array() {
767 return Err(ParseChatMessageError::ToolsJsonNotArray);
768 }
769
770 let reasoning_markers = self.reasoning_markers().ok().flatten();
771
772 for candidate in tool_call_template_overrides::known_marker_candidates() {
773 if let ToolCallFormatOutcome::Parsed(calls) =
774 tool_call_format::try_parse(input, &candidate)
775 {
776 let split =
777 split_reasoning_prefix(input, reasoning_markers.as_ref(), &candidate.open);
778 let mut parsed = ParsedChatMessage::new(split.content, split.reasoning, calls);
779 synthesize_missing_tool_call_ids(&mut parsed.tool_calls);
780 return Ok(ChatMessageParseOutcome::Recognized(parsed));
781 }
782 }
783
784 match self.parse_chat_message_via_ffi(tools_json, input, is_partial) {
785 Ok(mut parsed) => {
786 synthesize_missing_tool_call_ids(&mut parsed.tool_calls);
787 Ok(ChatMessageParseOutcome::Recognized(parsed))
788 }
789 Err(ParseChatMessageError::ParseFailed { message }) => {
790 Ok(ChatMessageParseOutcome::Unrecognized(RawChatMessage {
791 tools_json: tools_json.to_owned(),
792 text: input.to_owned(),
793 is_partial,
794 ffi_error_message: message,
795 }))
796 }
797 Err(other) => Err(other),
798 }
799 }
800
801 fn parse_chat_message_via_ffi(
802 &self,
803 tools_json: &str,
804 input: &str,
805 is_partial: bool,
806 ) -> Result<ParsedChatMessage, ParseChatMessageError> {
807 let tools_cstring = CString::new(tools_json)
808 .map_err(|err| ParseChatMessageError::ToolsSerialization(err.to_string()))?;
809 let input_cstring = CString::new(input)
810 .map_err(|err| ParseChatMessageError::ToolsSerialization(err.to_string()))?;
811
812 let mut handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat = ptr::null_mut();
813 let mut out_error: *mut c_char = ptr::null_mut();
814
815 let status = unsafe {
816 llama_cpp_bindings_sys::llama_rs_parse_chat_message(
817 self.model.as_ptr(),
818 tools_cstring.as_ptr(),
819 input_cstring.as_ptr(),
820 i32::from(is_partial),
821 &raw mut handle,
822 &raw mut out_error,
823 )
824 };
825
826 let parsed = match status {
827 llama_cpp_bindings_sys::LLAMA_RS_PARSE_CHAT_MESSAGE_OK => {
828 collect_parsed_chat_message(handle)
829 }
830 llama_cpp_bindings_sys::LLAMA_RS_PARSE_CHAT_MESSAGE_MODEL_HAS_NO_CHAT_TEMPLATE => {
831 Err(ParseChatMessageError::NoChatTemplate)
832 }
833 llama_cpp_bindings_sys::LLAMA_RS_PARSE_CHAT_MESSAGE_MODEL_HAS_NO_VOCAB => {
834 Err(ParseChatMessageError::NoVocab)
835 }
836 llama_cpp_bindings_sys::LLAMA_RS_PARSE_CHAT_MESSAGE_ERROR_STRING_ALLOCATION_FAILED => {
837 Err(ParseChatMessageError::NotEnoughMemory)
838 }
839 llama_cpp_bindings_sys::LLAMA_RS_PARSE_CHAT_MESSAGE_VENDORED_THREW_CXX_EXCEPTION => {
840 let message =
841 unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
842 out_error = ptr::null_mut();
843 Err(ParseChatMessageError::ParseFailed { message })
844 }
845 other => {
846 unreachable!("llama_rs_parse_chat_message returned unrecognized status {other}")
847 }
848 };
849
850 let mut free_error: *mut c_char = ptr::null_mut();
851 let free_status = unsafe {
852 llama_cpp_bindings_sys::llama_rs_parsed_chat_free(handle, &raw mut free_error)
853 };
854 match (parsed, free_status) {
855 (Ok(value), llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_FREE_OK) => {
856 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
857 Ok(value)
858 }
859 (
860 Ok(_),
861 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_FREE_DESTRUCTOR_THREW_CXX_EXCEPTION,
862 ) => {
863 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
864 let message =
865 unsafe { crate::ffi_error_reader::read_and_free_cpp_error(free_error) };
866 Err(ParseChatMessageError::DestructorFailed { message })
867 }
868 (
869 Ok(_),
870 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_FREE_ERROR_STRING_ALLOCATION_FAILED,
871 ) => {
872 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
873 Err(ParseChatMessageError::NotEnoughMemory)
874 }
875 (Ok(_), other) => {
876 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
877 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(free_error) };
878 unreachable!("llama_rs_parsed_chat_free returned unrecognized status {other}")
879 }
880 (Err(parse_err), _) => {
881 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
882 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(free_error) };
883 Err(parse_err)
884 }
885 }
886 }
887
888 pub fn diagnose_tool_call_synthetic_renders(
893 &self,
894 ) -> Result<(String, String), MarkerDetectionError> {
895 let (no_tools, with_tools) =
896 invoke_diagnose_tool_call_synthetic_renders(self.model.as_ptr())?;
897
898 Ok((no_tools.unwrap_or_default(), with_tools.unwrap_or_default()))
899 }
900}
901
902impl LlamaModel {
903 pub fn approximate_tok_env(&self) -> Arc<ApproximateTokEnv> {
904 Arc::clone(self.tok_env.get_or_init(|| build_approximate_tok_env(self)))
905 }
906}
907
908fn build_approximate_tok_env(model: &LlamaModel) -> Arc<ApproximateTokEnv> {
909 let n_vocab = model.n_vocab().cast_unsigned();
910 let tok_eos = {
911 let eot = unsafe { llama_cpp_bindings_sys::llama_vocab_eot(model.vocab_ptr()) };
912 if eot == -1 {
913 model.token_eos().0.cast_unsigned()
914 } else {
915 eot.cast_unsigned()
916 }
917 };
918 let info = TokRxInfo::new(n_vocab, tok_eos);
919
920 let mut words = Vec::with_capacity(n_vocab as usize);
921
922 for token_id in 0..n_vocab.cast_signed() {
923 let token = LlamaToken(token_id);
924 let bytes = model
925 .token_to_piece_bytes(token, 32, false, None)
926 .unwrap_or_default();
927 if bytes.is_empty() {
928 let special_bytes = model
929 .token_to_piece_bytes(token, 32, true, None)
930 .unwrap_or_default();
931 if special_bytes.is_empty() {
932 words.push(vec![]);
933 } else {
934 let mut marked = Vec::with_capacity(special_bytes.len() + 1);
935 marked.push(0xFF);
936 marked.extend(special_bytes);
937 words.push(marked);
938 }
939 } else {
940 words.push(bytes);
941 }
942 }
943
944 let trie = TokTrie::from(&info, &words);
945 Arc::new(ApproximateTokEnv::new(trie))
946}
947
948fn collect_parsed_chat_message(
949 handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat,
950) -> Result<ParsedChatMessage, ParseChatMessageError> {
951 if handle.is_null() {
952 return Ok(ParsedChatMessage::default());
953 }
954
955 let content = read_parsed_chat_content(handle)?;
956 let reasoning_content = read_parsed_chat_reasoning_content(handle)?;
957 let count = read_parsed_chat_tool_call_count(handle)?;
958
959 let mut tool_calls = Vec::with_capacity(count);
960 for index in 0..count {
961 let id = read_parsed_chat_tool_call_id(handle, index)?;
962 let name = read_parsed_chat_tool_call_name(handle, index)?;
963 let arguments_json = read_parsed_chat_tool_call_arguments(handle, index)?;
964
965 let arguments = ToolCallArguments::from_string(arguments_json);
966 tool_calls.push(ParsedToolCall::new(id, name, arguments));
967 }
968
969 Ok(ParsedChatMessage::new(
970 content,
971 reasoning_content,
972 tool_calls,
973 ))
974}
975
976fn read_parsed_chat_content(
977 handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat,
978) -> Result<String, ParseChatMessageError> {
979 let mut out_string: *mut c_char = ptr::null_mut();
980 let mut out_error: *mut c_char = ptr::null_mut();
981 let status = unsafe {
982 llama_cpp_bindings_sys::llama_rs_parsed_chat_content(
983 handle,
984 &raw mut out_string,
985 &raw mut out_error,
986 )
987 };
988 match status {
989 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_CONTENT_OK => {
990 consume_accessor_string(out_string)
991 }
992 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_CONTENT_ERROR_STRING_ALLOCATION_FAILED => {
993 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
994 Err(ParseChatMessageError::NotEnoughMemory)
995 }
996 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_CONTENT_VENDORED_THREW_CXX_EXCEPTION => {
997 let message = unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
998 Err(ParseChatMessageError::Reported { message })
999 }
1000 other => unreachable!("llama_rs_parsed_chat_content returned unrecognized status {other}"),
1001 }
1002}
1003
1004fn read_parsed_chat_reasoning_content(
1005 handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat,
1006) -> Result<String, ParseChatMessageError> {
1007 let mut out_string: *mut c_char = ptr::null_mut();
1008 let mut out_error: *mut c_char = ptr::null_mut();
1009 let status = unsafe {
1010 llama_cpp_bindings_sys::llama_rs_parsed_chat_reasoning_content(
1011 handle,
1012 &raw mut out_string,
1013 &raw mut out_error,
1014 )
1015 };
1016 match status {
1017 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_REASONING_CONTENT_OK => {
1018 consume_accessor_string(out_string)
1019 }
1020 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_REASONING_CONTENT_ERROR_STRING_ALLOCATION_FAILED => {
1021 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
1022 Err(ParseChatMessageError::NotEnoughMemory)
1023 }
1024 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_REASONING_CONTENT_VENDORED_THREW_CXX_EXCEPTION => {
1025 let message =
1026 unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
1027 Err(ParseChatMessageError::Reported { message })
1028 }
1029 other => unreachable!(
1030 "llama_rs_parsed_chat_reasoning_content returned unrecognized status {other}"
1031 ),
1032 }
1033}
1034
1035fn read_parsed_chat_tool_call_count(
1036 handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat,
1037) -> Result<usize, ParseChatMessageError> {
1038 let mut out_count: usize = 0;
1039 let mut out_error: *mut c_char = ptr::null_mut();
1040 let status = unsafe {
1041 llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_count(
1042 handle,
1043 &raw mut out_count,
1044 &raw mut out_error,
1045 )
1046 };
1047 match status {
1048 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_COUNT_OK => Ok(out_count),
1049 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_COUNT_ERROR_STRING_ALLOCATION_FAILED => {
1050 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
1051 Err(ParseChatMessageError::NotEnoughMemory)
1052 }
1053 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_COUNT_VENDORED_THREW_CXX_EXCEPTION => {
1054 let message =
1055 unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
1056 Err(ParseChatMessageError::Reported { message })
1057 }
1058 other => unreachable!(
1059 "llama_rs_parsed_chat_tool_call_count returned unrecognized status {other}"
1060 ),
1061 }
1062}
1063
1064fn read_parsed_chat_tool_call_id(
1065 handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat,
1066 index: usize,
1067) -> Result<String, ParseChatMessageError> {
1068 let mut out_string: *mut c_char = ptr::null_mut();
1069 let mut out_error: *mut c_char = ptr::null_mut();
1070 let status = unsafe {
1071 llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_id(
1072 handle,
1073 index,
1074 &raw mut out_string,
1075 &raw mut out_error,
1076 )
1077 };
1078 match status {
1079 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_ID_OK => {
1080 consume_accessor_string(out_string)
1081 }
1082 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_ID_INDEX_OUT_OF_BOUNDS => {
1083 Err(ParseChatMessageError::ToolCallIdIndexOutOfBounds { index })
1084 }
1085 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_ID_ERROR_STRING_ALLOCATION_FAILED => {
1086 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
1087 Err(ParseChatMessageError::NotEnoughMemory)
1088 }
1089 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_ID_VENDORED_THREW_CXX_EXCEPTION => {
1090 let message =
1091 unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
1092 Err(ParseChatMessageError::Reported { message })
1093 }
1094 other => unreachable!(
1095 "llama_rs_parsed_chat_tool_call_id returned unrecognized status {other}"
1096 ),
1097 }
1098}
1099
1100fn read_parsed_chat_tool_call_name(
1101 handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat,
1102 index: usize,
1103) -> Result<String, ParseChatMessageError> {
1104 let mut out_string: *mut c_char = ptr::null_mut();
1105 let mut out_error: *mut c_char = ptr::null_mut();
1106 let status = unsafe {
1107 llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_name(
1108 handle,
1109 index,
1110 &raw mut out_string,
1111 &raw mut out_error,
1112 )
1113 };
1114 match status {
1115 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_NAME_OK => {
1116 consume_accessor_string(out_string)
1117 }
1118 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_NAME_INDEX_OUT_OF_BOUNDS => {
1119 Err(ParseChatMessageError::ToolCallNameIndexOutOfBounds { index })
1120 }
1121 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_NAME_ERROR_STRING_ALLOCATION_FAILED => {
1122 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
1123 Err(ParseChatMessageError::NotEnoughMemory)
1124 }
1125 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_NAME_VENDORED_THREW_CXX_EXCEPTION => {
1126 let message =
1127 unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
1128 Err(ParseChatMessageError::Reported { message })
1129 }
1130 other => unreachable!(
1131 "llama_rs_parsed_chat_tool_call_name returned unrecognized status {other}"
1132 ),
1133 }
1134}
1135
1136fn read_parsed_chat_tool_call_arguments(
1137 handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat,
1138 index: usize,
1139) -> Result<String, ParseChatMessageError> {
1140 let mut out_string: *mut c_char = ptr::null_mut();
1141 let mut out_error: *mut c_char = ptr::null_mut();
1142 let status = unsafe {
1143 llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_arguments(
1144 handle,
1145 index,
1146 &raw mut out_string,
1147 &raw mut out_error,
1148 )
1149 };
1150 match status {
1151 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_ARGUMENTS_OK => {
1152 consume_accessor_string(out_string)
1153 }
1154 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_ARGUMENTS_INDEX_OUT_OF_BOUNDS => {
1155 Err(ParseChatMessageError::ToolCallArgumentsIndexOutOfBounds { index })
1156 }
1157 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_ARGUMENTS_ERROR_STRING_ALLOCATION_FAILED => {
1158 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
1159 Err(ParseChatMessageError::NotEnoughMemory)
1160 }
1161 llama_cpp_bindings_sys::LLAMA_RS_PARSED_CHAT_TOOL_CALL_ARGUMENTS_VENDORED_THREW_CXX_EXCEPTION => {
1162 let message =
1163 unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
1164 Err(ParseChatMessageError::Reported { message })
1165 }
1166 other => unreachable!(
1167 "llama_rs_parsed_chat_tool_call_arguments returned unrecognized status {other}"
1168 ),
1169 }
1170}
1171
1172fn consume_accessor_string(ptr: *mut c_char) -> Result<String, ParseChatMessageError> {
1173 if ptr.is_null() {
1174 return Ok(String::new());
1175 }
1176 let bytes = unsafe { CStr::from_ptr(ptr) }.to_bytes().to_vec();
1177 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(ptr) };
1178 Ok(String::from_utf8(bytes)?)
1179}
1180
/// Output of `split_reasoning_prefix`: the reasoning text and the visible
/// content extracted from one model response.
struct ReasoningSplit {
    /// Text found between the reasoning open and close markers; empty when no
    /// markers were supplied or no complete open/close pair was found.
    reasoning: String,
    /// Remaining text, cut off at the first tool-call open marker if one occurs.
    content: String,
}
1185
1186fn split_reasoning_prefix(
1187 input: &str,
1188 reasoning_markers: Option<&ReasoningMarkers>,
1189 tool_call_open: &str,
1190) -> ReasoningSplit {
1191 let content_only = || ReasoningSplit {
1192 reasoning: String::new(),
1193 content: prefix_before(input, tool_call_open),
1194 };
1195
1196 let Some(reasoning_markers) = reasoning_markers else {
1197 return content_only();
1198 };
1199 let Some(open_pos) = input.find(&reasoning_markers.open) else {
1200 return content_only();
1201 };
1202
1203 let after_open = &input[open_pos + reasoning_markers.open.len()..];
1204 let Some(close_offset) = after_open.find(&reasoning_markers.close) else {
1205 return content_only();
1206 };
1207
1208 let reasoning = after_open[..close_offset].to_owned();
1209 let after_close = &after_open[close_offset + reasoning_markers.close.len()..];
1210
1211 ReasoningSplit {
1212 reasoning,
1213 content: prefix_before(after_close, tool_call_open),
1214 }
1215}
1216
/// Returns the part of `text` that precedes the first occurrence of `marker`.
///
/// If `marker` does not occur, the whole of `text` is returned. An empty
/// `marker` is treated as "no marker": `str::find("")` matches at offset 0,
/// which would otherwise discard the entire text.
fn prefix_before(text: &str, marker: &str) -> String {
    if marker.is_empty() {
        return text.to_owned();
    }

    text.split_once(marker)
        .map_or(text, |(head, _)| head)
        .to_owned()
}
1221
1222fn synthesize_missing_tool_call_ids(tool_calls: &mut [ParsedToolCall]) {
1223 for (index, call) in tool_calls.iter_mut().enumerate() {
1224 if call.id.is_empty() {
1225 call.id = format!("call_{index}");
1226 }
1227 }
1228}
1229
1230fn invoke_detect_reasoning_markers(
1231 model: *const llama_cpp_bindings_sys::llama_model,
1232) -> Result<(Option<String>, Option<String>), MarkerDetectionError> {
1233 let mut out_open: *mut c_char = ptr::null_mut();
1234 let mut out_close: *mut c_char = ptr::null_mut();
1235 let mut out_error: *mut c_char = ptr::null_mut();
1236
1237 let status = unsafe {
1238 llama_cpp_bindings_sys::llama_rs_detect_reasoning_markers(
1239 model,
1240 &raw mut out_open,
1241 &raw mut out_close,
1242 &raw mut out_error,
1243 )
1244 };
1245
1246 let parsed = match status {
1247 llama_cpp_bindings_sys::LLAMA_RS_DETECT_REASONING_MARKERS_OK => {
1248 collect_optional_cstr_pair(out_open, out_close)
1249 }
1250 llama_cpp_bindings_sys::LLAMA_RS_DETECT_REASONING_MARKERS_ERROR_STRING_ALLOCATION_FAILED => {
1251 Err(MarkerDetectionError::NotEnoughMemory)
1252 }
1253 llama_cpp_bindings_sys::LLAMA_RS_DETECT_REASONING_MARKERS_VENDORED_THREW_CXX_EXCEPTION => {
1254 let message = unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
1255 Err(MarkerDetectionError::ReasoningMarkerDetectionFailed { message })
1256 }
1257 other => unreachable!(
1258 "llama_rs_detect_reasoning_markers returned unrecognized status {other}"
1259 ),
1260 };
1261
1262 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_open) };
1263 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_close) };
1264 if !matches!(
1265 parsed,
1266 Err(MarkerDetectionError::ReasoningMarkerDetectionFailed { .. })
1267 ) {
1268 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
1269 }
1270
1271 parsed
1272}
1273
1274fn invoke_compute_tool_call_haystack(
1275 model: *const llama_cpp_bindings_sys::llama_model,
1276) -> Result<Option<String>, MarkerDetectionError> {
1277 let mut out_haystack: *mut c_char = ptr::null_mut();
1278 let mut out_error: *mut c_char = ptr::null_mut();
1279
1280 let status = unsafe {
1281 llama_cpp_bindings_sys::llama_rs_compute_tool_call_haystack(
1282 model,
1283 &raw mut out_haystack,
1284 &raw mut out_error,
1285 )
1286 };
1287
1288 let parsed = match status {
1289 llama_cpp_bindings_sys::LLAMA_RS_COMPUTE_TOOL_CALL_HAYSTACK_OK => {
1290 read_optional_owned_cstr(out_haystack)
1291 }
1292 llama_cpp_bindings_sys::LLAMA_RS_COMPUTE_TOOL_CALL_HAYSTACK_ERROR_STRING_ALLOCATION_FAILED => {
1293 Err(MarkerDetectionError::NotEnoughMemory)
1294 }
1295 llama_cpp_bindings_sys::LLAMA_RS_COMPUTE_TOOL_CALL_HAYSTACK_VENDORED_THREW_CXX_EXCEPTION => {
1296 let message = unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
1297 Err(MarkerDetectionError::ToolCallHaystackComputationFailed { message })
1298 }
1299 other => unreachable!(
1300 "llama_rs_compute_tool_call_haystack returned unrecognized status {other}"
1301 ),
1302 };
1303
1304 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_haystack) };
1305 if !matches!(
1306 parsed,
1307 Err(MarkerDetectionError::ToolCallHaystackComputationFailed { .. })
1308 ) {
1309 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
1310 }
1311
1312 parsed
1313}
1314
1315fn invoke_diagnose_tool_call_synthetic_renders(
1316 model: *const llama_cpp_bindings_sys::llama_model,
1317) -> Result<(Option<String>, Option<String>), MarkerDetectionError> {
1318 let mut out_no_tools: *mut c_char = ptr::null_mut();
1319 let mut out_with_tools: *mut c_char = ptr::null_mut();
1320 let mut out_error: *mut c_char = ptr::null_mut();
1321
1322 let status = unsafe {
1323 llama_cpp_bindings_sys::llama_rs_diagnose_tool_call_synthetic_renders(
1324 model,
1325 &raw mut out_no_tools,
1326 &raw mut out_with_tools,
1327 &raw mut out_error,
1328 )
1329 };
1330
1331 let parsed = match status {
1332 llama_cpp_bindings_sys::LLAMA_RS_DIAGNOSE_TOOL_CALL_SYNTHETIC_RENDERS_OK => {
1333 collect_optional_cstr_pair(out_no_tools, out_with_tools)
1334 }
1335 llama_cpp_bindings_sys::LLAMA_RS_DIAGNOSE_TOOL_CALL_SYNTHETIC_RENDERS_ERROR_STRING_ALLOCATION_FAILED => {
1336 Err(MarkerDetectionError::NotEnoughMemory)
1337 }
1338 llama_cpp_bindings_sys::LLAMA_RS_DIAGNOSE_TOOL_CALL_SYNTHETIC_RENDERS_VENDORED_THREW_CXX_EXCEPTION => {
1339 let message = unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
1340 Err(MarkerDetectionError::ToolCallSyntheticRenderDiagnosisFailed { message })
1341 }
1342 other => unreachable!(
1343 "llama_rs_diagnose_tool_call_synthetic_renders returned unrecognized status {other}"
1344 ),
1345 };
1346
1347 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_no_tools) };
1348 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_with_tools) };
1349 if !matches!(
1350 parsed,
1351 Err(MarkerDetectionError::ToolCallSyntheticRenderDiagnosisFailed { .. })
1352 ) {
1353 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
1354 }
1355
1356 parsed
1357}
1358
1359fn read_optional_owned_cstr(ptr: *const c_char) -> Result<Option<String>, MarkerDetectionError> {
1360 if ptr.is_null() {
1361 return Ok(None);
1362 }
1363
1364 let bytes = unsafe { CStr::from_ptr(ptr) }.to_bytes().to_vec();
1365
1366 Ok(Some(String::from_utf8(bytes)?))
1367}
1368
1369fn invoke_rs_tokenize(
1370 vocab: *const llama_cpp_bindings_sys::llama_vocab,
1371 text: *const c_char,
1372 text_len: c_int,
1373 tokens: *mut llama_cpp_bindings_sys::llama_token,
1374 n_tokens_max: c_int,
1375 add_bos: bool,
1376) -> Result<c_int, StringToTokenError> {
1377 let mut out_count: i32 = 0;
1378 let mut out_error: *mut c_char = ptr::null_mut();
1379 let status = unsafe {
1380 llama_cpp_bindings_sys::llama_rs_tokenize(
1381 vocab,
1382 text,
1383 text_len,
1384 tokens,
1385 n_tokens_max,
1386 add_bos,
1387 true,
1388 &raw mut out_count,
1389 &raw mut out_error,
1390 )
1391 };
1392 match status {
1393 llama_cpp_bindings_sys::LLAMA_RS_TOKENIZE_OK => Ok(out_count),
1394 llama_cpp_bindings_sys::LLAMA_RS_TOKENIZE_ERROR_STRING_ALLOCATION_FAILED => {
1395 Err(StringToTokenError::NotEnoughMemory)
1396 }
1397 llama_cpp_bindings_sys::LLAMA_RS_TOKENIZE_VENDORED_THREW_CXX_EXCEPTION => {
1398 let message = unsafe { crate::ffi_error_reader::read_and_free_cpp_error(out_error) };
1399 Err(StringToTokenError::Reported { message })
1400 }
1401 other => unreachable!("llama_rs_tokenize returned unrecognized status {other}"),
1402 }
1403}
1404
1405fn collect_optional_cstr_pair(
1406 first_ptr: *const c_char,
1407 second_ptr: *const c_char,
1408) -> Result<(Option<String>, Option<String>), MarkerDetectionError> {
1409 let first = read_optional_owned_cstr(first_ptr)?;
1410 let second = read_optional_owned_cstr(second_ptr)?;
1411 Ok((first, second))
1412}
1413
1414fn extract_meta_string<TCFunction>(
1415 c_function: TCFunction,
1416 capacity: usize,
1417) -> Result<String, MetaValError>
1418where
1419 TCFunction: Fn(*mut c_char, usize) -> i32,
1420{
1421 let mut buffer = vec![0u8; capacity];
1422 let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
1423
1424 if result < 0 {
1425 return Err(MetaValError::NegativeReturn(result));
1426 }
1427
1428 let returned_len = result.cast_unsigned() as usize;
1429
1430 if returned_len >= capacity {
1431 return extract_meta_string(c_function, returned_len + 1);
1432 }
1433
1434 if buffer.get(returned_len) != Some(&0) {
1435 return Err(MetaValError::NegativeReturn(-1));
1436 }
1437
1438 buffer.truncate(returned_len);
1439
1440 Ok(String::from_utf8(buffer)?)
1441}
1442
1443impl Drop for LlamaModel {
1444 fn drop(&mut self) {
1445 unsafe { llama_cpp_bindings_sys::llama_free_model(self.model.as_ptr()) }
1446 }
1447}
1448
/// Unit tests for `extract_meta_string`, plus a few error-path checks for
/// sibling string helpers defined in the parent module.
#[cfg(test)]
mod extract_meta_string_tests {
    use super::extract_meta_string;
    use crate::MetaValError;

    /// The callee reports length 2 but leaves a non-NUL byte at index 2.
    #[test]
    fn returns_error_when_null_terminator_missing() {
        let outcome = extract_meta_string(
            |buf_ptr, buf_len| {
                let bytes =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                bytes[..3].copy_from_slice(b"abc");
                2
            },
            4,
        );

        assert_eq!(outcome, Err(MetaValError::NegativeReturn(-1)));
    }

    /// A negative status is passed through unchanged.
    #[test]
    fn returns_error_for_negative_return_value() {
        let outcome = extract_meta_string(|_buf_ptr, _buf_len| -5, 4);

        assert_eq!(outcome, Err(MetaValError::NegativeReturn(-5)));
    }

    /// Properly terminated payload whose bytes are not valid UTF-8.
    #[test]
    fn returns_error_for_invalid_utf8_data() {
        let outcome = extract_meta_string(
            |buf_ptr, buf_len| {
                let bytes =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                bytes[..3].copy_from_slice(&[0xFF, 0xFE, 0]);
                2
            },
            4,
        );

        assert!(outcome.is_err());
        let rendered = outcome.unwrap_err().to_string();
        assert!(rendered.contains("FromUtf8Error"));
    }

    /// First call reports a length larger than the buffer; the retry succeeds.
    #[test]
    fn triggers_buffer_resize_when_returned_len_exceeds_capacity() {
        let initial_capacity: usize = 4;
        let length_exceeding_initial_capacity = 10;
        let written_length = 2;
        let invocations = std::cell::Cell::new(0);

        let outcome = extract_meta_string(
            |buf_ptr, buf_len| {
                let previous_calls = invocations.replace(invocations.get() + 1);
                if previous_calls == 0 {
                    return length_exceeding_initial_capacity;
                }

                let bytes =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                bytes[..3].copy_from_slice(b"hi\0");
                written_length
            },
            initial_capacity,
        );

        assert_eq!(outcome.unwrap(), "hi");
    }

    /// An interior NUL byte cannot be turned into a C string.
    #[test]
    fn cstring_with_validated_len_null_byte_returns_error() {
        assert!(super::cstring_with_validated_len("null\0byte").is_err());
    }

    /// A length of `usize::MAX` is rejected by the tokenizer length check.
    #[test]
    fn validate_string_length_overflow_returns_error() {
        assert!(super::validate_string_length_for_tokenizer(usize::MAX).is_err());
    }

    /// Bytes that are not valid UTF-8 make the buffer conversion fail.
    #[test]
    fn truncated_buffer_to_string_with_invalid_utf8_returns_error() {
        let invalid_utf8 = vec![0xff, 0xfe, 0xfd];

        assert!(super::truncated_buffer_to_string(invalid_utf8, 3).is_err());
    }
}