1use std::ffi::{CStr, CString, c_char};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6use std::sync::Arc;
7use std::sync::OnceLock;
8
9use toktrie::ApproximateTokEnv;
10use toktrie::TokRxInfo;
11use toktrie::TokTrie;
12
13fn truncated_buffer_to_string(
14 mut buffer: Vec<u8>,
15 length: usize,
16) -> Result<String, ApplyChatTemplateError> {
17 buffer.truncate(length);
18
19 Ok(String::from_utf8(buffer)?)
20}
21
22fn validate_string_length_for_tokenizer(length: usize) -> Result<c_int, StringToTokenError> {
23 Ok(c_int::try_from(length)?)
24}
25
26fn cstring_with_validated_len(str: &str) -> Result<(CString, c_int), StringToTokenError> {
27 let c_string = CString::new(str)?;
28 let len = validate_string_length_for_tokenizer(c_string.as_bytes().len())?;
29 Ok((c_string, len))
30}
31use std::ptr::{self, NonNull};
32
33use crate::chat_message_parse_outcome::ChatMessageParseOutcome;
34use crate::ffi_status_to_i32::status_to_i32;
35use crate::llama_backend::LlamaBackend;
36use crate::llama_token_attrs::LlamaTokenAttrs;
37use crate::llama_token_attrs_from_int_error::LlamaTokenAttrsFromIntError;
38use crate::raw_chat_message::RawChatMessage;
39use crate::resolved_tool_call_markers::ResolvedToolCallMarkers;
40use crate::sampled_token::SampledToken;
41use crate::sampled_token_classifier::SampledTokenClassifier;
42use crate::sampled_token_classifier::StreamingMarkers;
43use crate::token::LlamaToken;
44use crate::{
45 ApplyChatTemplateError, ChatTemplateError, LlamaLoraAdapterInitError, LlamaModelLoadError,
46 MarkerDetectionError, MetaValError, ParseChatMessageError, StringToTokenError,
47 TokenToStringError,
48};
49use llama_cpp_bindings_types::ParsedChatMessage;
50use llama_cpp_bindings_types::ParsedToolCall;
51use llama_cpp_bindings_types::ReasoningMarkers;
52use llama_cpp_bindings_types::ToolCallArguments;
53use llama_cpp_bindings_types::ToolCallMarkers;
54
55use crate::tool_call_format;
56use crate::tool_call_format::ToolCallFormatOutcome;
57use crate::tool_call_template_overrides;
58
59pub mod add_bos;
60pub mod llama_chat_message;
61pub mod llama_chat_template;
62pub mod llama_lora_adapter;
63pub mod params;
64pub mod rope_type;
65pub mod split_mode;
66pub mod vocab_type;
67pub mod vocab_type_from_int_error;
68
69pub use add_bos::AddBos;
70pub use llama_chat_message::LlamaChatMessage;
71pub use llama_chat_template::LlamaChatTemplate;
72pub use llama_lora_adapter::LlamaLoraAdapter;
73pub use rope_type::RopeType;
74pub use vocab_type::VocabType;
75pub use vocab_type_from_int_error::VocabTypeFromIntError;
76
77use params::LlamaModelParams;
78
/// Handle to a llama.cpp model.
///
/// The wrapped pointer is passed to `llama_free_model` when the value is dropped.
pub struct LlamaModel {
    /// Non-null pointer to the underlying `llama_model`.
    pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
    /// Tokenizer environment cache, filled on the first `approximate_tok_env` call.
    tok_env: OnceLock<Arc<ApproximateTokEnv>>,
}
85
86impl std::fmt::Debug for LlamaModel {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct("LlamaModel")
89 .field("model", &self.model)
90 .finish_non_exhaustive()
91 }
92}
93
// NOTE(review): asserts the model handle may be moved to another thread; the
// justification is not visible in this file — TODO confirm against llama.cpp.
unsafe impl Send for LlamaModel {}

// NOTE(review): asserts `&LlamaModel` may be shared between threads; the
// justification is not visible in this file — TODO confirm against llama.cpp.
unsafe impl Sync for LlamaModel {}
97
98impl LlamaModel {
99 #[must_use]
101 pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
102 unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
103 }
104
105 pub fn n_ctx_train(&self) -> Result<u32, std::num::TryFromIntError> {
111 let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
112
113 u32::try_from(n_ctx_train)
114 }
115
116 pub fn tokens(
118 &self,
119 decode_special: bool,
120 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
121 (0..self.n_vocab())
122 .map(LlamaToken::new)
123 .map(move |llama_token| {
124 let mut decoder = encoding_rs::UTF_8.new_decoder();
125 (
126 llama_token,
127 self.token_to_piece(
128 &SampledToken::Content(llama_token),
129 &mut decoder,
130 decode_special,
131 None,
132 ),
133 )
134 })
135 }
136
137 #[must_use]
139 pub fn token_bos(&self) -> LlamaToken {
140 let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
141 LlamaToken(token)
142 }
143
144 #[must_use]
146 pub fn token_eos(&self) -> LlamaToken {
147 let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
148 LlamaToken(token)
149 }
150
151 #[must_use]
153 pub fn token_nl(&self) -> LlamaToken {
154 let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
155 LlamaToken(token)
156 }
157
158 #[must_use]
160 pub fn is_eog_token(&self, token: &SampledToken) -> bool {
161 let (SampledToken::Content(LlamaToken(id))
162 | SampledToken::Reasoning(LlamaToken(id))
163 | SampledToken::ToolCall(LlamaToken(id))
164 | SampledToken::Undeterminable(LlamaToken(id))) = *token;
165
166 unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), id) }
167 }
168
169 #[must_use]
171 pub fn decode_start_token(&self) -> LlamaToken {
172 let token =
173 unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
174 LlamaToken(token)
175 }
176
177 #[must_use]
179 pub fn token_sep(&self) -> LlamaToken {
180 let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
181 LlamaToken(token)
182 }
183
184 pub fn str_to_token(
204 &self,
205 str: &str,
206 add_bos: AddBos,
207 ) -> Result<Vec<LlamaToken>, StringToTokenError> {
208 let add_bos = match add_bos {
209 AddBos::Always => true,
210 AddBos::Never => false,
211 };
212
213 let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
214 let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);
215
216 let (c_string, c_string_len) = cstring_with_validated_len(str)?;
217 let buffer_capacity = c_int::try_from(buffer.capacity())?;
218
219 let size = unsafe {
220 llama_cpp_bindings_sys::llama_tokenize(
221 self.vocab_ptr(),
222 c_string.as_ptr(),
223 c_string_len,
224 buffer
225 .as_mut_ptr()
226 .cast::<llama_cpp_bindings_sys::llama_token>(),
227 buffer_capacity,
228 add_bos,
229 true,
230 )
231 };
232
233 let size = if size.is_negative() {
234 buffer.reserve_exact(usize::try_from(-size)?);
235 unsafe {
236 llama_cpp_bindings_sys::llama_tokenize(
237 self.vocab_ptr(),
238 c_string.as_ptr(),
239 c_string_len,
240 buffer
241 .as_mut_ptr()
242 .cast::<llama_cpp_bindings_sys::llama_token>(),
243 -size,
244 add_bos,
245 true,
246 )
247 }
248 } else {
249 size
250 };
251
252 let size = usize::try_from(size)?;
253
254 unsafe { buffer.set_len(size) }
256
257 Ok(buffer)
258 }
259
260 pub fn token_attr(
266 &self,
267 LlamaToken(id): LlamaToken,
268 ) -> Result<LlamaTokenAttrs, LlamaTokenAttrsFromIntError> {
269 let token_type =
270 unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
271
272 LlamaTokenAttrs::try_from(token_type)
273 }
274
275 pub fn token_to_piece(
291 &self,
292 token: &SampledToken,
293 decoder: &mut encoding_rs::Decoder,
294 special: bool,
295 lstrip: Option<NonZeroU16>,
296 ) -> Result<String, TokenToStringError> {
297 let (SampledToken::Content(inner)
298 | SampledToken::Reasoning(inner)
299 | SampledToken::ToolCall(inner)
300 | SampledToken::Undeterminable(inner)) = *token;
301 let bytes = match self.token_to_piece_bytes(inner, 8, special, lstrip) {
302 Err(TokenToStringError::InsufficientBufferSpace(required_size)) => {
303 let buffer_size: usize = (-required_size).try_into()?;
304
305 self.token_to_piece_bytes(inner, buffer_size, special, lstrip)
306 }
307 other => other,
308 }?;
309
310 let mut output_piece = String::with_capacity(bytes.len());
311 let (_result, _decoded_size, _had_replacements) =
312 decoder.decode_to_string(&bytes, &mut output_piece, false);
313
314 Ok(output_piece)
315 }
316
317 pub fn token_to_piece_bytes(
329 &self,
330 token: LlamaToken,
331 buffer_size: usize,
332 special: bool,
333 lstrip: Option<NonZeroU16>,
334 ) -> Result<Vec<u8>, TokenToStringError> {
335 let mut buffer: Vec<u8> = vec![0u8; buffer_size];
336 let buffer_len = c_int::try_from(buffer.len())?;
337 let lstrip = lstrip.map_or(0, |strip_count| i32::from(strip_count.get()));
338 let size = unsafe {
339 llama_cpp_bindings_sys::llama_token_to_piece(
340 self.vocab_ptr(),
341 token.0,
342 buffer.as_mut_ptr().cast::<c_char>(),
343 buffer_len,
344 lstrip,
345 special,
346 )
347 };
348
349 match size {
350 0 => Err(TokenToStringError::UnknownTokenType),
351 error_code if error_code.is_negative() => {
352 Err(TokenToStringError::InsufficientBufferSpace(error_code))
353 }
354 size => {
355 let written = usize::try_from(size)?;
356 buffer.truncate(written);
357
358 Ok(buffer)
359 }
360 }
361 }
362
363 #[must_use]
368 pub fn n_vocab(&self) -> i32 {
369 unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
370 }
371
372 pub fn vocab_type(&self) -> Result<VocabType, VocabTypeFromIntError> {
378 let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
379
380 VocabType::try_from(vocab_type)
381 }
382
383 #[must_use]
386 pub fn n_embd(&self) -> c_int {
387 unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
388 }
389
390 #[must_use]
392 pub fn size(&self) -> u64 {
393 unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
394 }
395
396 #[must_use]
398 pub fn n_params(&self) -> u64 {
399 unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
400 }
401
402 #[must_use]
404 pub fn is_recurrent(&self) -> bool {
405 unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
406 }
407
408 pub fn n_layer(&self) -> Result<u32, std::num::TryFromIntError> {
414 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
415 }
416
417 pub fn n_head(&self) -> Result<u32, std::num::TryFromIntError> {
423 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
424 }
425
426 pub fn n_head_kv(&self) -> Result<u32, std::num::TryFromIntError> {
432 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
433 }
434
435 #[must_use]
439 pub fn is_hybrid(&self) -> bool {
440 unsafe { llama_cpp_bindings_sys::llama_model_is_hybrid(self.model.as_ptr()) }
441 }
442
443 pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
448 let key_cstring = CString::new(key)?;
449 let key_ptr = key_cstring.as_ptr();
450
451 extract_meta_string(
452 |buf_ptr, buf_len| unsafe {
453 llama_cpp_bindings_sys::llama_model_meta_val_str(
454 self.model.as_ptr(),
455 key_ptr,
456 buf_ptr,
457 buf_len,
458 )
459 },
460 256,
461 )
462 }
463
464 #[must_use]
466 pub fn meta_count(&self) -> i32 {
467 unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
468 }
469
470 pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
475 extract_meta_string(
476 |buf_ptr, buf_len| unsafe {
477 llama_cpp_bindings_sys::llama_model_meta_key_by_index(
478 self.model.as_ptr(),
479 index,
480 buf_ptr,
481 buf_len,
482 )
483 },
484 256,
485 )
486 }
487
488 pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
493 extract_meta_string(
494 |buf_ptr, buf_len| unsafe {
495 llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
496 self.model.as_ptr(),
497 index,
498 buf_ptr,
499 buf_len,
500 )
501 },
502 256,
503 )
504 }
505
506 #[must_use]
508 pub fn rope_type(&self) -> Option<RopeType> {
509 let raw = unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) };
510
511 rope_type::rope_type_from_raw(raw)
512 }
513
514 pub fn chat_template(
532 &self,
533 name: Option<&str>,
534 ) -> Result<LlamaChatTemplate, ChatTemplateError> {
535 let name_cstr = name.map(CString::new);
536 let name_ptr = match name_cstr {
537 Some(Ok(name)) => name.as_ptr(),
538 _ => ptr::null(),
539 };
540 let result = unsafe {
541 llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
542 };
543
544 if result.is_null() {
545 Err(ChatTemplateError::MissingTemplate)
546 } else {
547 let chat_template_cstr = unsafe { CStr::from_ptr(result) };
548
549 Ok(LlamaChatTemplate(chat_template_cstr.to_owned()))
550 }
551 }
552
553 #[tracing::instrument(skip_all, fields(params))]
563 pub fn load_from_file(
564 _: &LlamaBackend,
565 path: impl AsRef<Path>,
566 params: &LlamaModelParams,
567 ) -> Result<Self, LlamaModelLoadError> {
568 let path = path.as_ref();
569
570 let path_str = path
571 .to_str()
572 .ok_or_else(|| LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
573
574 if !path.exists() {
575 return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
576 }
577
578 let cstr = CString::new(path_str)?;
579 let llama_model = unsafe {
580 llama_cpp_bindings_sys::llama_load_model_from_file(cstr.as_ptr(), params.params)
581 };
582
583 let model = match NonNull::new(llama_model) {
584 Some(ptr) => ptr,
585 None if !path.exists() => {
586 return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
587 }
588 None => return Err(LlamaModelLoadError::NullResult),
589 };
590
591 Ok(Self {
592 model,
593 tok_env: OnceLock::new(),
594 })
595 }
596
597 pub fn lora_adapter_init(
603 &self,
604 path: impl AsRef<Path>,
605 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
606 let path = path.as_ref();
607
608 let path_str = path
609 .to_str()
610 .ok_or_else(|| LlamaLoraAdapterInitError::PathToStrError(path.to_path_buf()))?;
611
612 if !path.exists() {
613 return Err(LlamaLoraAdapterInitError::FileNotFound(path.to_path_buf()));
614 }
615
616 let cstr = CString::new(path_str)?;
617 let raw_adapter = unsafe {
618 llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
619 };
620
621 let Some(adapter) = NonNull::new(raw_adapter) else {
622 return Err(LlamaLoraAdapterInitError::NullResult);
623 };
624
625 Ok(LlamaLoraAdapter {
626 lora_adapter: adapter,
627 })
628 }
629
630 #[tracing::instrument(skip_all)]
648 pub fn apply_chat_template(
649 &self,
650 tmpl: &LlamaChatTemplate,
651 chat: &[LlamaChatMessage],
652 add_ass: bool,
653 ) -> Result<String, ApplyChatTemplateError> {
654 let message_length = chat.iter().fold(0, |acc, chat_message| {
655 acc + chat_message.role.to_bytes().len() + chat_message.content.to_bytes().len()
656 });
657 let mut buff: Vec<u8> = vec![0; message_length * 2];
658
659 let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
660 .iter()
661 .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
662 role: chat_message.role.as_ptr(),
663 content: chat_message.content.as_ptr(),
664 })
665 .collect();
666
667 let tmpl_ptr = tmpl.0.as_ptr();
668
669 let buff_len: i32 = buff.len().try_into()?;
670
671 let res = unsafe {
672 llama_cpp_bindings_sys::llama_chat_apply_template(
673 tmpl_ptr,
674 chat.as_ptr(),
675 chat.len(),
676 add_ass,
677 buff.as_mut_ptr().cast::<c_char>(),
678 buff_len,
679 )
680 };
681
682 if res > buff_len {
683 let required_size: usize = res.try_into()?;
684 buff.resize(required_size, 0);
685
686 let new_buff_len: i32 = buff.len().try_into()?;
687
688 let res = unsafe {
689 llama_cpp_bindings_sys::llama_chat_apply_template(
690 tmpl_ptr,
691 chat.as_ptr(),
692 chat.len(),
693 add_ass,
694 buff.as_mut_ptr().cast::<c_char>(),
695 new_buff_len,
696 )
697 };
698 let final_size: usize = res.try_into()?;
699
700 return truncated_buffer_to_string(buff, final_size);
701 }
702
703 let final_size: usize = res.try_into()?;
704
705 truncated_buffer_to_string(buff, final_size)
706 }
707
708 pub fn sampled_token_classifier(&self) -> SampledTokenClassifier<'_> {
720 let markers = match self.streaming_markers() {
721 Ok(markers) => markers,
722 Err(detection_error) => {
723 tracing::warn!(
724 "streaming markers detection failed; classifier will run blind: {detection_error}"
725 );
726 StreamingMarkers::default()
727 }
728 };
729
730 SampledTokenClassifier::new(self, markers)
731 }
732
733 pub fn streaming_markers(&self) -> Result<StreamingMarkers, MarkerDetectionError> {
742 let (reasoning_open_str, reasoning_close_str) =
743 invoke_ffi_string_pair_detector(|first, second, error| unsafe {
744 llama_cpp_bindings_sys::llama_rs_detect_reasoning_markers(
745 self.model.as_ptr(),
746 first,
747 second,
748 error,
749 )
750 })?;
751
752 let tool_call_haystack = invoke_ffi_single_string_detector(|haystack, error| unsafe {
753 llama_cpp_bindings_sys::llama_rs_compute_tool_call_haystack(
754 self.model.as_ptr(),
755 haystack,
756 error,
757 )
758 })?;
759
760 let autoparser_pair = tool_call_haystack.as_deref().and_then(
761 crate::extract_tool_call_markers_from_haystack::extract_tool_call_markers_from_haystack,
762 );
763
764 let (autoparser_open, autoparser_close) = match autoparser_pair {
765 Some(crate::tool_call_marker_pair::ToolCallMarkerPair { open, close }) => {
766 (Some(open), Some(close))
767 }
768 None => (None, None),
769 };
770
771 let resolved_tool_call_markers =
772 self.resolve_tool_call_marker_strings(autoparser_open, autoparser_close);
773
774 Ok(StreamingMarkers {
775 reasoning_open: self.tokenize_marker(reasoning_open_str.as_deref()),
776 reasoning_close: self.tokenize_marker(reasoning_close_str.as_deref()),
777 tool_call_open: self.tokenize_marker(resolved_tool_call_markers.open.as_deref()),
778 tool_call_close: self.tokenize_marker(resolved_tool_call_markers.close.as_deref()),
779 })
780 }
781
782 fn resolve_tool_call_marker_strings(
786 &self,
787 autoparser_open: Option<String>,
788 autoparser_close: Option<String>,
789 ) -> ResolvedToolCallMarkers {
790 if autoparser_open
791 .as_deref()
792 .is_some_and(|raw| !raw.trim().is_empty())
793 {
794 return ResolvedToolCallMarkers {
795 open: autoparser_open,
796 close: autoparser_close,
797 };
798 }
799 let Some(markers) = self.tool_call_markers() else {
800 return ResolvedToolCallMarkers {
801 open: autoparser_open,
802 close: autoparser_close,
803 };
804 };
805 let close = if markers.close.is_empty() {
806 None
807 } else {
808 Some(markers.close)
809 };
810 ResolvedToolCallMarkers {
811 open: Some(markers.open),
812 close,
813 }
814 }
815
816 pub fn reasoning_markers(&self) -> Result<Option<ReasoningMarkers>, MarkerDetectionError> {
819 let (open, close) = invoke_ffi_string_pair_detector(|first, second, error| unsafe {
820 llama_cpp_bindings_sys::llama_rs_detect_reasoning_markers(
821 self.model.as_ptr(),
822 first,
823 second,
824 error,
825 )
826 })?;
827
828 match (open, close) {
829 (Some(open), Some(close)) if !open.is_empty() && !close.is_empty() => {
830 Ok(Some(ReasoningMarkers { open, close }))
831 }
832 _ => Ok(None),
833 }
834 }
835
836 #[must_use]
842 pub fn tool_call_markers(&self) -> Option<ToolCallMarkers> {
843 let template = match self.chat_template(None) {
844 Ok(template) => template,
845 Err(error) => {
846 tracing::debug!(
847 "tool-call markers unavailable: chat template missing or invalid: {error}"
848 );
849 return None;
850 }
851 };
852 let template_str = match template.to_str() {
853 Ok(template_str) => template_str,
854 Err(error) => {
855 tracing::debug!(
856 "tool-call markers unavailable: chat template is not valid UTF-8: {error}"
857 );
858 return None;
859 }
860 };
861 tool_call_template_overrides::detect(template_str)
862 }
863
864 fn tokenize_marker(&self, marker: Option<&str>) -> Option<Vec<LlamaToken>> {
865 let marker = marker?.trim();
866 if marker.is_empty() {
867 return None;
868 }
869 match self.str_to_token(marker, AddBos::Never) {
870 Ok(tokens) if !tokens.is_empty() => Some(tokens),
871 Ok(_) => None,
872 Err(tokenize_error) => {
873 tracing::debug!(
874 "marker {marker:?} failed to tokenise; classifier will ignore it: {tokenize_error}"
875 );
876 None
877 }
878 }
879 }
880
881 pub fn parse_chat_message(
908 &self,
909 tools_json: &str,
910 input: &str,
911 is_partial: bool,
912 ) -> Result<ChatMessageParseOutcome, ParseChatMessageError> {
913 let tools_value: serde_json::Value =
914 serde_json::from_str(tools_json).map_err(ParseChatMessageError::ToolsJsonInvalid)?;
915 if !tools_value.is_array() {
916 return Err(ParseChatMessageError::ToolsJsonNotArray);
917 }
918
919 let reasoning_markers = self.reasoning_markers().ok().flatten();
920
921 for candidate in tool_call_template_overrides::known_marker_candidates() {
922 if let ToolCallFormatOutcome::Parsed(calls) =
923 tool_call_format::try_parse(input, &candidate)
924 {
925 let split =
926 split_reasoning_prefix(input, reasoning_markers.as_ref(), &candidate.open);
927 let mut parsed = ParsedChatMessage::new(split.content, split.reasoning, calls);
928 synthesize_missing_tool_call_ids(&mut parsed.tool_calls);
929 return Ok(ChatMessageParseOutcome::Recognized(parsed));
930 }
931 }
932
933 match self.parse_chat_message_via_ffi(tools_json, input, is_partial) {
934 Ok(mut parsed) => {
935 synthesize_missing_tool_call_ids(&mut parsed.tool_calls);
936 Ok(ChatMessageParseOutcome::Recognized(parsed))
937 }
938 Err(ParseChatMessageError::ParseException(ffi_error_message)) => {
939 Ok(ChatMessageParseOutcome::Unrecognized(RawChatMessage {
940 tools_json: tools_json.to_owned(),
941 text: input.to_owned(),
942 is_partial,
943 ffi_error_message,
944 }))
945 }
946 Err(other) => Err(other),
947 }
948 }
949
950 fn parse_chat_message_via_ffi(
951 &self,
952 tools_json: &str,
953 input: &str,
954 is_partial: bool,
955 ) -> Result<ParsedChatMessage, ParseChatMessageError> {
956 let tools_cstring = CString::new(tools_json)
957 .map_err(|err| ParseChatMessageError::ToolsSerialization(err.to_string()))?;
958 let input_cstring = CString::new(input)
959 .map_err(|err| ParseChatMessageError::ToolsSerialization(err.to_string()))?;
960
961 let mut handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat = ptr::null_mut();
962 let mut out_error: *mut c_char = ptr::null_mut();
963
964 let status = unsafe {
965 llama_cpp_bindings_sys::llama_rs_parse_chat_message(
966 self.model.as_ptr(),
967 tools_cstring.as_ptr(),
968 input_cstring.as_ptr(),
969 i32::from(is_partial),
970 &raw mut handle,
971 &raw mut out_error,
972 )
973 };
974
975 let parsed = match status {
976 llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => collect_parsed_chat_message(handle),
977 llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION => {
978 let message = read_optional_owned_cstr_lossy(out_error);
979 Err(ParseChatMessageError::ParseException(message))
980 }
981 other => Err(ParseChatMessageError::FfiError(status_to_i32(other))),
982 };
983
984 unsafe { llama_cpp_bindings_sys::llama_rs_parsed_chat_free(handle) };
985 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
986
987 parsed
988 }
989
990 pub fn diagnose_tool_call_synthetic_renders(
1000 &self,
1001 ) -> Result<(String, String), MarkerDetectionError> {
1002 let (no_tools, with_tools) =
1003 invoke_ffi_string_pair_detector(|first, second, error| unsafe {
1004 llama_cpp_bindings_sys::llama_rs_diagnose_tool_call_synthetic_renders(
1005 self.model.as_ptr(),
1006 first,
1007 second,
1008 error,
1009 )
1010 })?;
1011
1012 Ok((no_tools.unwrap_or_default(), with_tools.unwrap_or_default()))
1013 }
1014}
1015
1016impl LlamaModel {
1017 pub fn approximate_tok_env(&self) -> Arc<ApproximateTokEnv> {
1022 Arc::clone(self.tok_env.get_or_init(|| build_approximate_tok_env(self)))
1023 }
1024}
1025
1026fn build_approximate_tok_env(model: &LlamaModel) -> Arc<ApproximateTokEnv> {
1027 let n_vocab = model.n_vocab().cast_unsigned();
1028 let tok_eos = {
1029 let eot = unsafe { llama_cpp_bindings_sys::llama_vocab_eot(model.vocab_ptr()) };
1030 if eot == -1 {
1031 model.token_eos().0.cast_unsigned()
1032 } else {
1033 eot.cast_unsigned()
1034 }
1035 };
1036 let info = TokRxInfo::new(n_vocab, tok_eos);
1037
1038 let mut words = Vec::with_capacity(n_vocab as usize);
1039
1040 for token_id in 0..n_vocab.cast_signed() {
1041 let token = LlamaToken(token_id);
1042 let bytes = model
1043 .token_to_piece_bytes(token, 32, false, None)
1044 .unwrap_or_default();
1045 if bytes.is_empty() {
1046 let special_bytes = model
1047 .token_to_piece_bytes(token, 32, true, None)
1048 .unwrap_or_default();
1049 if special_bytes.is_empty() {
1050 words.push(vec![]);
1051 } else {
1052 let mut marked = Vec::with_capacity(special_bytes.len() + 1);
1053 marked.push(0xFF);
1054 marked.extend(special_bytes);
1055 words.push(marked);
1056 }
1057 } else {
1058 words.push(bytes);
1059 }
1060 }
1061
1062 let trie = TokTrie::from(&info, &words);
1063 Arc::new(ApproximateTokEnv::new(trie))
1064}
1065
1066fn collect_parsed_chat_message(
1067 handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat,
1068) -> Result<ParsedChatMessage, ParseChatMessageError> {
1069 if handle.is_null() {
1070 return Ok(ParsedChatMessage::default());
1071 }
1072
1073 let content = read_owned_cstr_for_parse(unsafe {
1074 llama_cpp_bindings_sys::llama_rs_parsed_chat_content(handle)
1075 })?;
1076 let reasoning_content = read_owned_cstr_for_parse(unsafe {
1077 llama_cpp_bindings_sys::llama_rs_parsed_chat_reasoning_content(handle)
1078 })?;
1079
1080 let count = unsafe { llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_count(handle) };
1081
1082 let mut tool_calls = Vec::with_capacity(count);
1083 for index in 0..count {
1084 let id = read_owned_cstr_for_parse(unsafe {
1085 llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_id(handle, index)
1086 })?;
1087 let name = read_owned_cstr_for_parse(unsafe {
1088 llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_name(handle, index)
1089 })?;
1090 let arguments_json = read_owned_cstr_for_parse(unsafe {
1091 llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_arguments(handle, index)
1092 })?;
1093
1094 let arguments = ToolCallArguments::from_string(arguments_json);
1095 tool_calls.push(ParsedToolCall::new(id, name, arguments));
1096 }
1097
1098 Ok(ParsedChatMessage::new(
1099 content,
1100 reasoning_content,
1101 tool_calls,
1102 ))
1103}
1104
/// Result of `split_reasoning_prefix`: the reasoning text and the plain
/// content extracted from a model response.
struct ReasoningSplit {
    /// Text found between the reasoning markers; empty when no complete
    /// open/close pair was located.
    reasoning: String,
    /// Text that precedes the tool-call opening marker.
    content: String,
}
1109
1110fn split_reasoning_prefix(
1111 input: &str,
1112 reasoning_markers: Option<&ReasoningMarkers>,
1113 tool_call_open: &str,
1114) -> ReasoningSplit {
1115 let content_only = || ReasoningSplit {
1116 reasoning: String::new(),
1117 content: prefix_before(input, tool_call_open),
1118 };
1119
1120 let Some(reasoning_markers) = reasoning_markers else {
1121 return content_only();
1122 };
1123 let Some(open_pos) = input.find(&reasoning_markers.open) else {
1124 return content_only();
1125 };
1126
1127 let after_open = &input[open_pos + reasoning_markers.open.len()..];
1128 let Some(close_offset) = after_open.find(&reasoning_markers.close) else {
1129 return content_only();
1130 };
1131
1132 let reasoning = after_open[..close_offset].to_owned();
1133 let after_close = &after_open[close_offset + reasoning_markers.close.len()..];
1134
1135 ReasoningSplit {
1136 reasoning,
1137 content: prefix_before(after_close, tool_call_open),
1138 }
1139}
1140
/// Returns the part of `text` before the first occurrence of `marker`, or all
/// of `text` when `marker` does not occur.
fn prefix_before(text: &str, marker: &str) -> String {
    let head = match text.split_once(marker) {
        Some((before_marker, _)) => before_marker,
        None => text,
    };

    head.to_owned()
}
1145
1146fn synthesize_missing_tool_call_ids(tool_calls: &mut [ParsedToolCall]) {
1147 for (index, call) in tool_calls.iter_mut().enumerate() {
1148 if call.id.is_empty() {
1149 call.id = format!("call_{index}");
1150 }
1151 }
1152}
1153
1154fn parse_single_string_status(
1155 status: llama_cpp_bindings_sys::llama_rs_status,
1156 out_value: *mut c_char,
1157 out_error: *mut c_char,
1158) -> Result<Option<String>, MarkerDetectionError> {
1159 match status {
1160 llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => read_optional_owned_cstr(out_value),
1161 llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION => {
1162 let message = read_optional_owned_cstr_lossy(out_error);
1163
1164 Err(MarkerDetectionError::AnalyzeException(message))
1165 }
1166 other => Err(MarkerDetectionError::FfiError(status_to_i32(other))),
1167 }
1168}
1169
1170fn invoke_ffi_single_string_detector<TInvoke>(
1171 invoke: TInvoke,
1172) -> Result<Option<String>, MarkerDetectionError>
1173where
1174 TInvoke: FnOnce(*mut *mut c_char, *mut *mut c_char) -> llama_cpp_bindings_sys::llama_rs_status,
1175{
1176 let mut out_value: *mut c_char = ptr::null_mut();
1177 let mut out_error: *mut c_char = ptr::null_mut();
1178
1179 let status = invoke(&raw mut out_value, &raw mut out_error);
1180 let parsed = parse_single_string_status(status, out_value, out_error);
1181
1182 unsafe {
1183 if !out_value.is_null() {
1184 llama_cpp_bindings_sys::llama_rs_string_free(out_value);
1185 }
1186 if !out_error.is_null() {
1187 llama_cpp_bindings_sys::llama_rs_string_free(out_error);
1188 }
1189 }
1190
1191 parsed
1192}
1193
1194fn invoke_ffi_string_pair_detector<TInvoke>(
1195 invoke: TInvoke,
1196) -> Result<(Option<String>, Option<String>), MarkerDetectionError>
1197where
1198 TInvoke: FnOnce(
1199 *mut *mut c_char,
1200 *mut *mut c_char,
1201 *mut *mut c_char,
1202 ) -> llama_cpp_bindings_sys::llama_rs_status,
1203{
1204 let mut out_first: *mut c_char = ptr::null_mut();
1205 let mut out_second: *mut c_char = ptr::null_mut();
1206 let mut out_error: *mut c_char = ptr::null_mut();
1207
1208 let status = invoke(&raw mut out_first, &raw mut out_second, &raw mut out_error);
1209
1210 let parsed = (|| match status {
1211 llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => {
1212 let first = read_optional_owned_cstr(out_first)?;
1213 let second = read_optional_owned_cstr(out_second)?;
1214
1215 Ok((first, second))
1216 }
1217 llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION => {
1218 let message = read_optional_owned_cstr_lossy(out_error);
1219
1220 Err(MarkerDetectionError::AnalyzeException(message))
1221 }
1222 other => Err(MarkerDetectionError::FfiError(status_to_i32(other))),
1223 })();
1224
1225 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_first) };
1226 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_second) };
1227 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
1228
1229 parsed
1230}
1231
1232fn read_owned_cstr_for_parse(ptr: *mut c_char) -> Result<String, ParseChatMessageError> {
1233 if ptr.is_null() {
1234 return Ok(String::new());
1235 }
1236
1237 let bytes = unsafe { CStr::from_ptr(ptr) }.to_bytes().to_vec();
1238 unsafe { llama_cpp_bindings_sys::llama_rs_string_free(ptr) };
1239
1240 Ok(String::from_utf8(bytes)?)
1241}
1242
1243fn read_optional_owned_cstr(ptr: *const c_char) -> Result<Option<String>, MarkerDetectionError> {
1244 if ptr.is_null() {
1245 return Ok(None);
1246 }
1247
1248 let bytes = unsafe { CStr::from_ptr(ptr) }.to_bytes().to_vec();
1249
1250 Ok(Some(String::from_utf8(bytes)?))
1251}
1252
/// Copies a C string into a `String`, replacing invalid UTF-8 sequences.
///
/// A null `ptr` yields an empty string. The pointer is not released.
fn read_optional_owned_cstr_lossy(ptr: *const c_char) -> String {
    if ptr.is_null() {
        String::new()
    } else {
        let borrowed = unsafe { CStr::from_ptr(ptr) };

        borrowed.to_string_lossy().into_owned()
    }
}
1262
1263fn extract_meta_string<TCFunction>(
1264 c_function: TCFunction,
1265 capacity: usize,
1266) -> Result<String, MetaValError>
1267where
1268 TCFunction: Fn(*mut c_char, usize) -> i32,
1269{
1270 let mut buffer = vec![0u8; capacity];
1271 let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
1272
1273 if result < 0 {
1274 return Err(MetaValError::NegativeReturn(result));
1275 }
1276
1277 let returned_len = result.cast_unsigned() as usize;
1278
1279 if returned_len >= capacity {
1280 return extract_meta_string(c_function, returned_len + 1);
1281 }
1282
1283 if buffer.get(returned_len) != Some(&0) {
1284 return Err(MetaValError::NegativeReturn(-1));
1285 }
1286
1287 buffer.truncate(returned_len);
1288
1289 Ok(String::from_utf8(buffer)?)
1290}
1291
1292impl Drop for LlamaModel {
1293 fn drop(&mut self) {
1294 unsafe { llama_cpp_bindings_sys::llama_free_model(self.model.as_ptr()) }
1295 }
1296}
1297
/// Unit tests for `extract_meta_string` and the small string/length helpers
/// defined at the top of this file.
#[cfg(test)]
mod extract_meta_string_tests {
    use super::extract_meta_string;
    use crate::MetaValError;

    /// The plain success path: the callback's text fits in the first buffer.
    #[test]
    fn returns_string_when_buffer_is_large_enough() {
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let buffer =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                buffer[0] = b'o';
                buffer[1] = b'k';
                buffer[2] = 0;
                2
            },
            4,
        );

        assert_eq!(result.unwrap(), "ok");
    }

    /// The callback reports 2 bytes but leaves a non-NUL byte at index 2.
    #[test]
    fn returns_error_when_null_terminator_missing() {
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let buffer =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                buffer[0] = b'a';
                buffer[1] = b'b';
                buffer[2] = b'c';
                2
            },
            4,
        );

        assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-1));
    }

    /// A negative status is passed through unchanged in the error.
    #[test]
    fn returns_error_for_negative_return_value() {
        let result = extract_meta_string(|_buf_ptr, _buf_len| -5, 4);

        assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-5));
    }

    /// Correctly terminated but non-UTF-8 bytes surface the conversion error.
    #[test]
    fn returns_error_for_invalid_utf8_data() {
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let buffer =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                buffer[0] = 0xFF;
                buffer[1] = 0xFE;
                buffer[2] = 0;
                2
            },
            4,
        );

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("FromUtf8Error"));
    }

    /// The first call reports a length larger than the buffer; the helper must
    /// call again exactly once with a buffer big enough for that length plus
    /// the terminator.
    #[test]
    fn triggers_buffer_resize_when_returned_len_exceeds_capacity() {
        let initial_capacity: usize = 4;
        let length_exceeding_initial_capacity = 10;
        let written_length = 2;
        let call_count = std::cell::Cell::new(0);
        let retry_capacity = std::cell::Cell::new(0usize);
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let count = call_count.get();
                call_count.set(count + 1);
                if count == 0 {
                    length_exceeding_initial_capacity
                } else {
                    retry_capacity.set(buf_len);
                    let buffer =
                        unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                    buffer[0] = b'h';
                    buffer[1] = b'i';
                    buffer[2] = 0;
                    written_length
                }
            },
            initial_capacity,
        );

        assert_eq!(result.unwrap(), "hi");
        // One undersized call plus exactly one retry.
        assert_eq!(call_count.get(), 2);
        // The retry buffer must hold the reported 10 bytes and a terminator.
        assert!(retry_capacity.get() > 10);
    }

    /// An interior NUL cannot be represented as a C string.
    #[test]
    fn cstring_with_validated_len_null_byte_returns_error() {
        let result = super::cstring_with_validated_len("null\0byte");

        assert!(result.is_err());
    }

    /// `usize::MAX` does not fit in a `c_int`.
    #[test]
    fn validate_string_length_overflow_returns_error() {
        let result = super::validate_string_length_for_tokenizer(usize::MAX);

        assert!(result.is_err());
    }

    /// Invalid UTF-8 is rejected even when no truncation takes place.
    #[test]
    fn truncated_buffer_to_string_with_invalid_utf8_returns_error() {
        let invalid_utf8 = vec![0xff, 0xfe, 0xfd];
        let result = super::truncated_buffer_to_string(invalid_utf8, 3);

        assert!(result.is_err());
    }
}
1394
/// Unit tests for the FFI string-reading and status-parsing helpers.
///
/// The detector tests drive the helpers with closures standing in for the
/// native functions, so no model needs to be loaded.
#[cfg(test)]
mod ffi_helper_tests {
    use std::ffi::CString;
    use std::ptr;

    use super::invoke_ffi_single_string_detector;
    use super::invoke_ffi_string_pair_detector;
    use super::parse_single_string_status;
    use super::read_optional_owned_cstr_lossy;
    use crate::MarkerDetectionError;

    #[test]
    fn read_optional_owned_cstr_lossy_returns_empty_for_null() {
        let result = read_optional_owned_cstr_lossy(ptr::null());

        assert!(result.is_empty());
    }

    #[test]
    fn read_optional_owned_cstr_lossy_returns_string_for_valid_pointer() {
        let owned = CString::new("hello").expect("static literal has no nuls");
        let result = read_optional_owned_cstr_lossy(owned.as_ptr());

        assert_eq!(result, "hello");
    }

    /// The single invalid byte must become exactly one U+FFFD, with the valid
    /// bytes on either side preserved.
    #[test]
    fn read_optional_owned_cstr_lossy_handles_invalid_utf8_via_replacement() {
        let owned = CString::new(vec![b'a', 0xFF, b'b']).expect("no interior nul");
        let result = read_optional_owned_cstr_lossy(owned.as_ptr());

        assert_eq!(result, "a\u{FFFD}b");
    }

    #[test]
    fn parse_single_string_status_returns_none_for_ok_with_null() {
        let result = parse_single_string_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
            ptr::null_mut(),
            ptr::null_mut(),
        );

        assert_eq!(result.expect("OK + null returns Ok(None)"), None);
    }

    #[test]
    fn parse_single_string_status_returns_some_for_ok_with_value() {
        let owned = CString::new("present").expect("no nul");
        let result = parse_single_string_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
            owned.as_ptr().cast_mut(),
            ptr::null_mut(),
        );

        assert_eq!(
            result.expect("OK + value returns Ok(Some)"),
            Some("present".to_owned())
        );
    }

    #[test]
    fn parse_single_string_status_returns_analyze_exception() {
        let owned = CString::new("boom").expect("no nul");
        let result = parse_single_string_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION,
            ptr::null_mut(),
            owned.as_ptr().cast_mut(),
        );

        match result.expect_err("EXCEPTION must yield Err") {
            MarkerDetectionError::AnalyzeException(message) => assert_eq!(message, "boom"),
            other => panic!("expected AnalyzeException, got {other:?}"),
        }
    }

    #[test]
    fn parse_single_string_status_returns_ffi_error_for_other_status() {
        let result = parse_single_string_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT,
            ptr::null_mut(),
            ptr::null_mut(),
        );

        match result.expect_err("invalid status must yield Err") {
            MarkerDetectionError::FfiError(_) => {}
            other => panic!("expected FfiError, got {other:?}"),
        }
    }

    #[test]
    fn invoke_ffi_single_string_detector_propagates_invalid_argument_status() {
        let result = invoke_ffi_single_string_detector(|_value, _error| {
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT
        });

        assert!(matches!(result, Err(MarkerDetectionError::FfiError(_))));
    }

    #[test]
    fn invoke_ffi_single_string_detector_returns_none_for_ok_with_null() {
        let result = invoke_ffi_single_string_detector(|value, _error| {
            unsafe {
                *value = ptr::null_mut();
            }
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK
        });

        assert_eq!(result.expect("OK + null returns Ok(None)"), None);
    }

    #[test]
    fn invoke_ffi_string_pair_detector_propagates_invalid_argument_status() {
        let result = invoke_ffi_string_pair_detector(|_first, _second, _error| {
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT
        });

        assert!(matches!(result, Err(MarkerDetectionError::FfiError(_))));
    }

    #[test]
    fn invoke_ffi_string_pair_detector_returns_pair_of_none_for_ok_with_nulls() {
        let result = invoke_ffi_string_pair_detector(|first, second, _error| {
            unsafe {
                *first = ptr::null_mut();
                *second = ptr::null_mut();
            }
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK
        });

        assert_eq!(
            result.expect("OK with both null returns Ok((None, None))"),
            (None, None)
        );
    }

    /// Any non-OK status other than the exception status maps to `FfiError`
    /// carrying a non-zero code.
    #[test]
    fn invoke_ffi_string_pair_detector_propagates_invalid_status_codes() {
        let result = invoke_ffi_string_pair_detector(|_first, _second, _error| {
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_ALLOCATION_FAILED
        });

        match result.expect_err("non-OK status yields Err") {
            MarkerDetectionError::FfiError(code) => assert!(code != 0),
            other => panic!("expected FfiError, got {other:?}"),
        }
    }
}