Skip to main content

llama_cpp_bindings/
model.rs

1//! A safe wrapper around `llama_model`.
2use std::ffi::{CStr, CString, c_char};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6use std::sync::Arc;
7use std::sync::OnceLock;
8
9use toktrie::ApproximateTokEnv;
10use toktrie::TokRxInfo;
11use toktrie::TokTrie;
12
13fn truncated_buffer_to_string(
14    mut buffer: Vec<u8>,
15    length: usize,
16) -> Result<String, ApplyChatTemplateError> {
17    buffer.truncate(length);
18
19    Ok(String::from_utf8(buffer)?)
20}
21
22fn validate_string_length_for_tokenizer(length: usize) -> Result<c_int, StringToTokenError> {
23    Ok(c_int::try_from(length)?)
24}
25
26fn cstring_with_validated_len(str: &str) -> Result<(CString, c_int), StringToTokenError> {
27    let c_string = CString::new(str)?;
28    let len = validate_string_length_for_tokenizer(c_string.as_bytes().len())?;
29    Ok((c_string, len))
30}
31use std::ptr::{self, NonNull};
32
33use crate::chat_message_parse_outcome::ChatMessageParseOutcome;
34use crate::ffi_status_to_i32::status_to_i32;
35use crate::llama_backend::LlamaBackend;
36use crate::llama_token_attrs::LlamaTokenAttrs;
37use crate::llama_token_attrs_from_int_error::LlamaTokenAttrsFromIntError;
38use crate::raw_chat_message::RawChatMessage;
39use crate::resolved_tool_call_markers::ResolvedToolCallMarkers;
40use crate::sampled_token::SampledToken;
41use crate::sampled_token_classifier::SampledTokenClassifier;
42use crate::sampled_token_classifier::StreamingMarkers;
43use crate::token::LlamaToken;
44use crate::{
45    ApplyChatTemplateError, ChatTemplateError, LlamaLoraAdapterInitError, LlamaModelLoadError,
46    MarkerDetectionError, MetaValError, ParseChatMessageError, StringToTokenError,
47    TokenToStringError,
48};
49use llama_cpp_bindings_types::ParsedChatMessage;
50use llama_cpp_bindings_types::ParsedToolCall;
51use llama_cpp_bindings_types::ReasoningMarkers;
52use llama_cpp_bindings_types::ToolCallArguments;
53use llama_cpp_bindings_types::ToolCallMarkers;
54
55use crate::tool_call_format;
56use crate::tool_call_format::ToolCallFormatOutcome;
57use crate::tool_call_template_overrides;
58
59pub mod add_bos;
60pub mod llama_chat_message;
61pub mod llama_chat_template;
62pub mod llama_lora_adapter;
63pub mod params;
64pub mod rope_type;
65pub mod split_mode;
66pub mod vocab_type;
67pub mod vocab_type_from_int_error;
68
69pub use add_bos::AddBos;
70pub use llama_chat_message::LlamaChatMessage;
71pub use llama_chat_template::LlamaChatTemplate;
72pub use llama_lora_adapter::LlamaLoraAdapter;
73pub use rope_type::RopeType;
74pub use vocab_type::VocabType;
75pub use vocab_type_from_int_error::VocabTypeFromIntError;
76
77use params::LlamaModelParams;
78
79/// A safe wrapper around `llama_model`.
80pub struct LlamaModel {
81    /// Raw pointer to the underlying `llama_model`.
82    pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
83    tok_env: OnceLock<Arc<ApproximateTokEnv>>,
84}
85
86impl std::fmt::Debug for LlamaModel {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct("LlamaModel")
89            .field("model", &self.model)
90            .finish_non_exhaustive()
91    }
92}
93
94unsafe impl Send for LlamaModel {}
95
96unsafe impl Sync for LlamaModel {}
97
98impl LlamaModel {
99    /// Returns a raw pointer to the model's vocabulary.
100    #[must_use]
101    pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
102        unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
103    }
104
105    /// Get the number of tokens the model was trained on.
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if the value returned by llama.cpp does not fit into a `u32`.
110    pub fn n_ctx_train(&self) -> Result<u32, std::num::TryFromIntError> {
111        let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
112
113        u32::try_from(n_ctx_train)
114    }
115
116    /// Get all tokens in the model.
117    pub fn tokens(
118        &self,
119        decode_special: bool,
120    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
121        (0..self.n_vocab())
122            .map(LlamaToken::new)
123            .map(move |llama_token| {
124                let mut decoder = encoding_rs::UTF_8.new_decoder();
125                (
126                    llama_token,
127                    self.token_to_piece(
128                        &SampledToken::Content(llama_token),
129                        &mut decoder,
130                        decode_special,
131                        None,
132                    ),
133                )
134            })
135    }
136
137    /// Get the beginning of stream token.
138    #[must_use]
139    pub fn token_bos(&self) -> LlamaToken {
140        let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
141        LlamaToken(token)
142    }
143
144    /// Get the end of stream token.
145    #[must_use]
146    pub fn token_eos(&self) -> LlamaToken {
147        let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
148        LlamaToken(token)
149    }
150
151    /// Get the newline token.
152    #[must_use]
153    pub fn token_nl(&self) -> LlamaToken {
154        let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
155        LlamaToken(token)
156    }
157
158    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.)
159    #[must_use]
160    pub fn is_eog_token(&self, token: &SampledToken) -> bool {
161        let (SampledToken::Content(LlamaToken(id))
162        | SampledToken::Reasoning(LlamaToken(id))
163        | SampledToken::ToolCall(LlamaToken(id))
164        | SampledToken::Undeterminable(LlamaToken(id))) = *token;
165
166        unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), id) }
167    }
168
169    /// Get the decoder start token.
170    #[must_use]
171    pub fn decode_start_token(&self) -> LlamaToken {
172        let token =
173            unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
174        LlamaToken(token)
175    }
176
177    /// Get the separator token (SEP).
178    #[must_use]
179    pub fn token_sep(&self) -> LlamaToken {
180        let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
181        LlamaToken(token)
182    }
183
184    /// Convert a string to a Vector of tokens.
185    ///
186    /// # Errors
187    ///
188    /// - if [`str`] contains a null byte
189    /// - if an integer conversion fails during tokenization
190    ///
191    ///
192    /// ```no_run
193    /// use llama_cpp_bindings::model::LlamaModel;
194    ///
195    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
196    /// use std::path::Path;
197    /// use llama_cpp_bindings::model::AddBos;
198    /// let backend = llama_cpp_bindings::llama_backend::LlamaBackend::init()?;
199    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
200    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
201    /// # Ok(())
202    /// # }
203    pub fn str_to_token(
204        &self,
205        str: &str,
206        add_bos: AddBos,
207    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
208        let add_bos = match add_bos {
209            AddBos::Always => true,
210            AddBos::Never => false,
211        };
212
213        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
214        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);
215
216        let (c_string, c_string_len) = cstring_with_validated_len(str)?;
217        let buffer_capacity = c_int::try_from(buffer.capacity())?;
218
219        let size = unsafe {
220            llama_cpp_bindings_sys::llama_tokenize(
221                self.vocab_ptr(),
222                c_string.as_ptr(),
223                c_string_len,
224                buffer
225                    .as_mut_ptr()
226                    .cast::<llama_cpp_bindings_sys::llama_token>(),
227                buffer_capacity,
228                add_bos,
229                true,
230            )
231        };
232
233        let size = if size.is_negative() {
234            buffer.reserve_exact(usize::try_from(-size)?);
235            unsafe {
236                llama_cpp_bindings_sys::llama_tokenize(
237                    self.vocab_ptr(),
238                    c_string.as_ptr(),
239                    c_string_len,
240                    buffer
241                        .as_mut_ptr()
242                        .cast::<llama_cpp_bindings_sys::llama_token>(),
243                    -size,
244                    add_bos,
245                    true,
246                )
247            }
248        } else {
249            size
250        };
251
252        let size = usize::try_from(size)?;
253
254        // SAFETY: `size` < `capacity` and llama-cpp has initialized elements up to `size`
255        unsafe { buffer.set_len(size) }
256
257        Ok(buffer)
258    }
259
260    /// Get the type of a token.
261    ///
262    /// # Errors
263    ///
264    /// Returns an error if the token type is not known to this library.
265    pub fn token_attr(
266        &self,
267        LlamaToken(id): LlamaToken,
268    ) -> Result<LlamaTokenAttrs, LlamaTokenAttrsFromIntError> {
269        let token_type =
270            unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
271
272        LlamaTokenAttrs::try_from(token_type)
273    }
274
275    /// Convert a token to a string using the underlying llama.cpp `llama_token_to_piece` function.
276    ///
277    /// This is the new default function for token decoding and provides direct access to
278    /// the llama.cpp token decoding functionality without any special logic or filtering.
279    ///
280    /// Decoding raw string requires using an decoder, tokens from language models may not always map
281    /// to full characters depending on the encoding so stateful decoding is required, otherwise partial strings may be lost!
282    /// Invalid characters are mapped to REPLACEMENT CHARACTER making the method safe to use even if the model inherently produces
283    /// garbage.
284    ///
285    /// # Errors
286    ///
287    /// - if the token type is unknown
288    ///
289    /// - if the returned size from llama.cpp does not fit into a `usize`
290    pub fn token_to_piece(
291        &self,
292        token: &SampledToken,
293        decoder: &mut encoding_rs::Decoder,
294        special: bool,
295        lstrip: Option<NonZeroU16>,
296    ) -> Result<String, TokenToStringError> {
297        let (SampledToken::Content(inner)
298        | SampledToken::Reasoning(inner)
299        | SampledToken::ToolCall(inner)
300        | SampledToken::Undeterminable(inner)) = *token;
301        let bytes = match self.token_to_piece_bytes(inner, 8, special, lstrip) {
302            Err(TokenToStringError::InsufficientBufferSpace(required_size)) => {
303                let buffer_size: usize = (-required_size).try_into()?;
304
305                self.token_to_piece_bytes(inner, buffer_size, special, lstrip)
306            }
307            other => other,
308        }?;
309
310        let mut output_piece = String::with_capacity(bytes.len());
311        let (_result, _decoded_size, _had_replacements) =
312            decoder.decode_to_string(&bytes, &mut output_piece, false);
313
314        Ok(output_piece)
315    }
316
317    /// Raw token decoding to bytes, use if you want to handle the decoding model output yourself
318    ///
319    /// Convert a token to bytes using the underlying llama.cpp `llama_token_to_piece` function. This is mostly
320    /// a thin wrapper around `llama_token_to_piece` function, that handles rust <-> c type conversions while
321    /// letting the caller handle errors. For a safer interface returning rust strings directly use `token_to_piece` instead!
322    ///
323    /// # Errors
324    ///
325    /// - if the token type is unknown
326    /// - the resultant token is larger than `buffer_size`.
327    /// - if an integer conversion fails
328    pub fn token_to_piece_bytes(
329        &self,
330        token: LlamaToken,
331        buffer_size: usize,
332        special: bool,
333        lstrip: Option<NonZeroU16>,
334    ) -> Result<Vec<u8>, TokenToStringError> {
335        let mut buffer: Vec<u8> = vec![0u8; buffer_size];
336        let buffer_len = c_int::try_from(buffer.len())?;
337        let lstrip = lstrip.map_or(0, |strip_count| i32::from(strip_count.get()));
338        let size = unsafe {
339            llama_cpp_bindings_sys::llama_token_to_piece(
340                self.vocab_ptr(),
341                token.0,
342                buffer.as_mut_ptr().cast::<c_char>(),
343                buffer_len,
344                lstrip,
345                special,
346            )
347        };
348
349        match size {
350            0 => Err(TokenToStringError::UnknownTokenType),
351            error_code if error_code.is_negative() => {
352                Err(TokenToStringError::InsufficientBufferSpace(error_code))
353            }
354            size => {
355                let written = usize::try_from(size)?;
356                buffer.truncate(written);
357
358                Ok(buffer)
359            }
360        }
361    }
362
363    /// The number of tokens the model was trained on.
364    ///
365    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
366    /// without issue.
367    #[must_use]
368    pub fn n_vocab(&self) -> i32 {
369        unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
370    }
371
372    /// The type of vocab the model was trained on.
373    ///
374    /// # Errors
375    ///
376    /// Returns an error if llama.cpp emits a vocab type that is not known to this library.
377    pub fn vocab_type(&self) -> Result<VocabType, VocabTypeFromIntError> {
378        let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
379
380        VocabType::try_from(vocab_type)
381    }
382
383    /// This returns a `c_int` for maximum compatibility. Most of the time it can be cast to an i32
384    /// without issue.
385    #[must_use]
386    pub fn n_embd(&self) -> c_int {
387        unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
388    }
389
390    /// Returns the total size of all the tensors in the model in bytes.
391    #[must_use]
392    pub fn size(&self) -> u64 {
393        unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
394    }
395
396    /// Returns the number of parameters in the model.
397    #[must_use]
398    pub fn n_params(&self) -> u64 {
399        unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
400    }
401
402    /// Returns whether the model is a recurrent network (Mamba, RWKV, etc)
403    #[must_use]
404    pub fn is_recurrent(&self) -> bool {
405        unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
406    }
407
408    /// Returns the number of layers within the model.
409    ///
410    /// # Errors
411    ///
412    /// Returns an error if the layer count returned by llama.cpp does not fit into a `u32`.
413    pub fn n_layer(&self) -> Result<u32, std::num::TryFromIntError> {
414        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
415    }
416
417    /// Returns the number of attention heads within the model.
418    ///
419    /// # Errors
420    ///
421    /// Returns an error if the head count returned by llama.cpp does not fit into a `u32`.
422    pub fn n_head(&self) -> Result<u32, std::num::TryFromIntError> {
423        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
424    }
425
426    /// Returns the number of KV attention heads.
427    ///
428    /// # Errors
429    ///
430    /// Returns an error if the KV head count returned by llama.cpp does not fit into a `u32`.
431    pub fn n_head_kv(&self) -> Result<u32, std::num::TryFromIntError> {
432        u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
433    }
434
435    /// Returns whether the model is a hybrid network (Jamba, Granite, Qwen3xx, etc.)
436    ///
437    /// Hybrid models have both attention layers and recurrent/SSM layers.
438    #[must_use]
439    pub fn is_hybrid(&self) -> bool {
440        unsafe { llama_cpp_bindings_sys::llama_model_is_hybrid(self.model.as_ptr()) }
441    }
442
443    /// Get metadata value as a string by key name
444    ///
445    /// # Errors
446    /// Returns an error if the key is not found or the value is not valid UTF-8.
447    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
448        let key_cstring = CString::new(key)?;
449        let key_ptr = key_cstring.as_ptr();
450
451        extract_meta_string(
452            |buf_ptr, buf_len| unsafe {
453                llama_cpp_bindings_sys::llama_model_meta_val_str(
454                    self.model.as_ptr(),
455                    key_ptr,
456                    buf_ptr,
457                    buf_len,
458                )
459            },
460            256,
461        )
462    }
463
464    /// Get the number of metadata key/value pairs
465    #[must_use]
466    pub fn meta_count(&self) -> i32 {
467        unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
468    }
469
470    /// Get metadata key name by index
471    ///
472    /// # Errors
473    /// Returns an error if the index is out of range or the key is not valid UTF-8.
474    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
475        extract_meta_string(
476            |buf_ptr, buf_len| unsafe {
477                llama_cpp_bindings_sys::llama_model_meta_key_by_index(
478                    self.model.as_ptr(),
479                    index,
480                    buf_ptr,
481                    buf_len,
482                )
483            },
484            256,
485        )
486    }
487
488    /// Get metadata value as a string by index
489    ///
490    /// # Errors
491    /// Returns an error if the index is out of range or the value is not valid UTF-8.
492    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
493        extract_meta_string(
494            |buf_ptr, buf_len| unsafe {
495                llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
496                    self.model.as_ptr(),
497                    index,
498                    buf_ptr,
499                    buf_len,
500                )
501            },
502            256,
503        )
504    }
505
506    /// Returns the rope type of the model.
507    #[must_use]
508    pub fn rope_type(&self) -> Option<RopeType> {
509        let raw = unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) };
510
511        rope_type::rope_type_from_raw(raw)
512    }
513
514    /// Get chat template from model by name. If the name parameter is None, the default chat template will be returned.
515    ///
516    /// You supply this into [`Self::apply_chat_template`] to get back a string with the appropriate template
517    /// substitution applied to convert a list of messages into a prompt the LLM can use to complete
518    /// the chat.
519    ///
520    /// You could also use an external jinja parser, like [minijinja](https://github.com/mitsuhiko/minijinja),
521    /// to parse jinja templates not supported by the llama.cpp template engine.
522    ///
523    /// # Errors
524    ///
525    /// * If the model has no chat template by that name
526    ///
527    /// # Panics
528    ///
529    /// Panics if the C-returned chat template string contains interior null bytes
530    /// (should never happen with valid model data).
531    pub fn chat_template(
532        &self,
533        name: Option<&str>,
534    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
535        let name_cstr = name.map(CString::new);
536        let name_ptr = match name_cstr {
537            Some(Ok(name)) => name.as_ptr(),
538            _ => ptr::null(),
539        };
540        let result = unsafe {
541            llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
542        };
543
544        if result.is_null() {
545            Err(ChatTemplateError::MissingTemplate)
546        } else {
547            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
548
549            Ok(LlamaChatTemplate(chat_template_cstr.to_owned()))
550        }
551    }
552
553    /// Loads a model from a file.
554    ///
555    /// # Errors
556    ///
557    /// See [`LlamaModelLoadError`] for more information.
558    ///
559    /// # Panics
560    ///
561    /// Panics if a valid UTF-8 path somehow contains interior null bytes (should never happen).
562    #[tracing::instrument(skip_all, fields(params))]
563    pub fn load_from_file(
564        _: &LlamaBackend,
565        path: impl AsRef<Path>,
566        params: &LlamaModelParams,
567    ) -> Result<Self, LlamaModelLoadError> {
568        let path = path.as_ref();
569
570        let path_str = path
571            .to_str()
572            .ok_or_else(|| LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
573
574        if !path.exists() {
575            return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
576        }
577
578        let cstr = CString::new(path_str)?;
579        let llama_model = unsafe {
580            llama_cpp_bindings_sys::llama_load_model_from_file(cstr.as_ptr(), params.params)
581        };
582
583        let model = match NonNull::new(llama_model) {
584            Some(ptr) => ptr,
585            None if !path.exists() => {
586                return Err(LlamaModelLoadError::FileNotFound(path.to_path_buf()));
587            }
588            None => return Err(LlamaModelLoadError::NullResult),
589        };
590
591        Ok(Self {
592            model,
593            tok_env: OnceLock::new(),
594        })
595    }
596
597    /// Initializes a lora adapter from a file.
598    ///
599    /// # Errors
600    ///
601    /// See [`LlamaLoraAdapterInitError`] for more information.
602    pub fn lora_adapter_init(
603        &self,
604        path: impl AsRef<Path>,
605    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
606        let path = path.as_ref();
607
608        let path_str = path
609            .to_str()
610            .ok_or_else(|| LlamaLoraAdapterInitError::PathToStrError(path.to_path_buf()))?;
611
612        if !path.exists() {
613            return Err(LlamaLoraAdapterInitError::FileNotFound(path.to_path_buf()));
614        }
615
616        let cstr = CString::new(path_str)?;
617        let raw_adapter = unsafe {
618            llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
619        };
620
621        let Some(adapter) = NonNull::new(raw_adapter) else {
622            return Err(LlamaLoraAdapterInitError::NullResult);
623        };
624
625        Ok(LlamaLoraAdapter {
626            lora_adapter: adapter,
627        })
628    }
629
630    /// Apply the models chat template to some messages.
631    /// See <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
632    ///
633    /// Unlike the llama.cpp `apply_chat_template` which just randomly uses the `ChatML` template when given
634    /// a null pointer for the template, this requires an explicit template to be specified. If you want to
635    /// use "chatml", then just do `LlamaChatTemplate::new("chatml")` or any other model name or template
636    /// string.
637    ///
638    /// Use [`Self::chat_template`] to retrieve the template baked into the model (this is the preferred
639    /// mechanism as using the wrong chat template can result in really unexpected responses from the LLM).
640    ///
641    /// You probably want to set `add_ass` to true so that the generated template string ends with a the
642    /// opening tag of the assistant. If you fail to leave a hanging chat tag, the model will likely generate
643    /// one into the output and the output may also have unexpected output aside from that.
644    ///
645    /// # Errors
646    /// There are many ways this can fail. See [`ApplyChatTemplateError`] for more information.
647    #[tracing::instrument(skip_all)]
648    pub fn apply_chat_template(
649        &self,
650        tmpl: &LlamaChatTemplate,
651        chat: &[LlamaChatMessage],
652        add_ass: bool,
653    ) -> Result<String, ApplyChatTemplateError> {
654        let message_length = chat.iter().fold(0, |acc, chat_message| {
655            acc + chat_message.role.to_bytes().len() + chat_message.content.to_bytes().len()
656        });
657        let mut buff: Vec<u8> = vec![0; message_length * 2];
658
659        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
660            .iter()
661            .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
662                role: chat_message.role.as_ptr(),
663                content: chat_message.content.as_ptr(),
664            })
665            .collect();
666
667        let tmpl_ptr = tmpl.0.as_ptr();
668
669        let buff_len: i32 = buff.len().try_into()?;
670
671        let res = unsafe {
672            llama_cpp_bindings_sys::llama_chat_apply_template(
673                tmpl_ptr,
674                chat.as_ptr(),
675                chat.len(),
676                add_ass,
677                buff.as_mut_ptr().cast::<c_char>(),
678                buff_len,
679            )
680        };
681
682        if res > buff_len {
683            let required_size: usize = res.try_into()?;
684            buff.resize(required_size, 0);
685
686            let new_buff_len: i32 = buff.len().try_into()?;
687
688            let res = unsafe {
689                llama_cpp_bindings_sys::llama_chat_apply_template(
690                    tmpl_ptr,
691                    chat.as_ptr(),
692                    chat.len(),
693                    add_ass,
694                    buff.as_mut_ptr().cast::<c_char>(),
695                    new_buff_len,
696                )
697            };
698            let final_size: usize = res.try_into()?;
699
700            return truncated_buffer_to_string(buff, final_size);
701        }
702
703        let final_size: usize = res.try_into()?;
704
705        truncated_buffer_to_string(buff, final_size)
706    }
707
708    /// Build a streaming [`SampledTokenClassifier`] for this model.
709    ///
710    /// At construction the bindings detect reasoning markers (via the
711    /// autoparser, with a chunked-thinking fallback for templates that consume
712    /// thoughts via content blocks), tool-call markers, and the trailing
713    /// generation-prompt slice. The classifier then runs a state machine over
714    /// the decoded token stream — no per-model branches.
715    ///
716    /// If the model has no usable chat template the classifier is built in a
717    /// blind mode that classifies every token as
718    /// [`SampledToken::Undeterminable`].
719    pub fn sampled_token_classifier(&self) -> SampledTokenClassifier<'_> {
720        let markers = match self.streaming_markers() {
721            Ok(markers) => markers,
722            Err(detection_error) => {
723                tracing::warn!(
724                    "streaming markers detection failed; classifier will run blind: {detection_error}"
725                );
726                StreamingMarkers::default()
727            }
728        };
729
730        SampledTokenClassifier::new(self, markers)
731    }
732
733    /// Detect reasoning / tool-call markers (as token-ID sequences) and the
734    /// trailing generation-prompt slice for this model's chat template. The
735    /// returned `StreamingMarkers` carry tokenised markers — never raw strings
736    /// — so the classifier matches by `LlamaToken` equality rather than text
737    /// scanning.
738    ///
739    /// # Errors
740    /// Returns [`MarkerDetectionError`] when any underlying FFI call fails.
741    pub fn streaming_markers(&self) -> Result<StreamingMarkers, MarkerDetectionError> {
742        let (reasoning_open_str, reasoning_close_str) =
743            invoke_ffi_string_pair_detector(|first, second, error| unsafe {
744                llama_cpp_bindings_sys::llama_rs_detect_reasoning_markers(
745                    self.model.as_ptr(),
746                    first,
747                    second,
748                    error,
749                )
750            })?;
751
752        let tool_call_haystack = invoke_ffi_single_string_detector(|haystack, error| unsafe {
753            llama_cpp_bindings_sys::llama_rs_compute_tool_call_haystack(
754                self.model.as_ptr(),
755                haystack,
756                error,
757            )
758        })?;
759
760        let autoparser_pair = tool_call_haystack.as_deref().and_then(
761            crate::extract_tool_call_markers_from_haystack::extract_tool_call_markers_from_haystack,
762        );
763
764        let (autoparser_open, autoparser_close) = match autoparser_pair {
765            Some(crate::tool_call_marker_pair::ToolCallMarkerPair { open, close }) => {
766                (Some(open), Some(close))
767            }
768            None => (None, None),
769        };
770
771        let resolved_tool_call_markers =
772            self.resolve_tool_call_marker_strings(autoparser_open, autoparser_close);
773
774        Ok(StreamingMarkers {
775            reasoning_open: self.tokenize_marker(reasoning_open_str.as_deref()),
776            reasoning_close: self.tokenize_marker(reasoning_close_str.as_deref()),
777            tool_call_open: self.tokenize_marker(resolved_tool_call_markers.open.as_deref()),
778            tool_call_close: self.tokenize_marker(resolved_tool_call_markers.close.as_deref()),
779        })
780    }
781
782    /// When the autoparser-driven FFI returned no tool-call markers, consult the
783    /// per-template override registry so wrapper-known templates (Gemma 4,
784    /// Mistral 3, ...) still drive the classifier.
785    fn resolve_tool_call_marker_strings(
786        &self,
787        autoparser_open: Option<String>,
788        autoparser_close: Option<String>,
789    ) -> ResolvedToolCallMarkers {
790        if autoparser_open
791            .as_deref()
792            .is_some_and(|raw| !raw.trim().is_empty())
793        {
794            return ResolvedToolCallMarkers {
795                open: autoparser_open,
796                close: autoparser_close,
797            };
798        }
799        let Some(markers) = self.tool_call_markers() else {
800            return ResolvedToolCallMarkers {
801                open: autoparser_open,
802                close: autoparser_close,
803            };
804        };
805        let close = if markers.close.is_empty() {
806            None
807        } else {
808            Some(markers.close)
809        };
810        ResolvedToolCallMarkers {
811            open: Some(markers.open),
812            close,
813        }
814    }
815
816    /// # Errors
817    /// Returns [`MarkerDetectionError`] when the underlying FFI call fails.
818    pub fn reasoning_markers(&self) -> Result<Option<ReasoningMarkers>, MarkerDetectionError> {
819        let (open, close) = invoke_ffi_string_pair_detector(|first, second, error| unsafe {
820            llama_cpp_bindings_sys::llama_rs_detect_reasoning_markers(
821                self.model.as_ptr(),
822                first,
823                second,
824                error,
825            )
826        })?;
827
828        match (open, close) {
829            (Some(open), Some(close)) if !open.is_empty() && !close.is_empty() => {
830                Ok(Some(ReasoningMarkers { open, close }))
831            }
832            _ => Ok(None),
833        }
834    }
835
836    /// Returns the rich tool-call marker bundle (open / separator / close /
837    /// optional value-quote pair) for this model's chat template, sourced from
838    /// the wrapper's per-template override registry. Returns `None` when no
839    /// registered override matches — callers in that case fall back to
840    /// llama.cpp's autoparser via [`Self::parse_chat_message`].
841    #[must_use]
842    pub fn tool_call_markers(&self) -> Option<ToolCallMarkers> {
843        let template = match self.chat_template(None) {
844            Ok(template) => template,
845            Err(error) => {
846                tracing::debug!(
847                    "tool-call markers unavailable: chat template missing or invalid: {error}"
848                );
849                return None;
850            }
851        };
852        let template_str = match template.to_str() {
853            Ok(template_str) => template_str,
854            Err(error) => {
855                tracing::debug!(
856                    "tool-call markers unavailable: chat template is not valid UTF-8: {error}"
857                );
858                return None;
859            }
860        };
861        tool_call_template_overrides::detect(template_str)
862    }
863
864    fn tokenize_marker(&self, marker: Option<&str>) -> Option<Vec<LlamaToken>> {
865        let marker = marker?.trim();
866        if marker.is_empty() {
867            return None;
868        }
869        match self.str_to_token(marker, AddBos::Never) {
870            Ok(tokens) if !tokens.is_empty() => Some(tokens),
871            Ok(_) => None,
872            Err(tokenize_error) => {
873                tracing::debug!(
874                    "marker {marker:?} failed to tokenise; classifier will ignore it: {tokenize_error}"
875                );
876                None
877            }
878        }
879    }
880
881    /// Parse the assistant's output text into structured content, reasoning,
882    /// and tool calls.
883    ///
884    /// Two passes, in order:
885    /// 1. Duck-type the wrapper-side parsers across every known shape
886    ///    (Qwen XML, GLM key-value, Gemma paired-quote, Mistral bracketed-JSON).
887    ///    First match wins. The shapes are ordered so that more restrictive
888    ///    shapes run first, which keeps the duck-type pass safe for inputs
889    ///    that share an open marker but differ in inner structure.
890    /// 2. Delegate to llama.cpp's `common_chat_parse`. If it succeeds the
891    ///    result is `Recognized`; if it throws `ParseException` the result is
892    ///    `Unrecognized` with the raw input plus the FFI's diagnostic, so the
893    ///    caller can pass the unstructured tokens to the client.
894    ///
895    /// Empty tool-call `id` fields are filled with `call_{index}` before
896    /// returning, so callers always see well-formed identifiers.
897    ///
898    /// `tools_json` is a JSON-array string of OpenAI-style tool definitions
899    /// (use `"[]"` when no tools are in scope). `is_partial` switches between
900    /// mid-stream (lenient) and final (strict) parses for the FFI step.
901    ///
902    /// # Errors
903    ///
904    /// Returns [`ParseChatMessageError`] when `tools_json` is not valid JSON,
905    /// the FFI returns a non-OK status other than `ParseException`, or
906    /// accessor strings are not valid UTF-8.
907    pub fn parse_chat_message(
908        &self,
909        tools_json: &str,
910        input: &str,
911        is_partial: bool,
912    ) -> Result<ChatMessageParseOutcome, ParseChatMessageError> {
913        let tools_value: serde_json::Value =
914            serde_json::from_str(tools_json).map_err(ParseChatMessageError::ToolsJsonInvalid)?;
915        if !tools_value.is_array() {
916            return Err(ParseChatMessageError::ToolsJsonNotArray);
917        }
918
919        let reasoning_markers = self.reasoning_markers().ok().flatten();
920
921        for candidate in tool_call_template_overrides::known_marker_candidates() {
922            if let ToolCallFormatOutcome::Parsed(calls) =
923                tool_call_format::try_parse(input, &candidate)
924            {
925                let split =
926                    split_reasoning_prefix(input, reasoning_markers.as_ref(), &candidate.open);
927                let mut parsed = ParsedChatMessage::new(split.content, split.reasoning, calls);
928                synthesize_missing_tool_call_ids(&mut parsed.tool_calls);
929                return Ok(ChatMessageParseOutcome::Recognized(parsed));
930            }
931        }
932
933        match self.parse_chat_message_via_ffi(tools_json, input, is_partial) {
934            Ok(mut parsed) => {
935                synthesize_missing_tool_call_ids(&mut parsed.tool_calls);
936                Ok(ChatMessageParseOutcome::Recognized(parsed))
937            }
938            Err(ParseChatMessageError::ParseException(ffi_error_message)) => {
939                Ok(ChatMessageParseOutcome::Unrecognized(RawChatMessage {
940                    tools_json: tools_json.to_owned(),
941                    text: input.to_owned(),
942                    is_partial,
943                    ffi_error_message,
944                }))
945            }
946            Err(other) => Err(other),
947        }
948    }
949
950    fn parse_chat_message_via_ffi(
951        &self,
952        tools_json: &str,
953        input: &str,
954        is_partial: bool,
955    ) -> Result<ParsedChatMessage, ParseChatMessageError> {
956        let tools_cstring = CString::new(tools_json)
957            .map_err(|err| ParseChatMessageError::ToolsSerialization(err.to_string()))?;
958        let input_cstring = CString::new(input)
959            .map_err(|err| ParseChatMessageError::ToolsSerialization(err.to_string()))?;
960
961        let mut handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat = ptr::null_mut();
962        let mut out_error: *mut c_char = ptr::null_mut();
963
964        let status = unsafe {
965            llama_cpp_bindings_sys::llama_rs_parse_chat_message(
966                self.model.as_ptr(),
967                tools_cstring.as_ptr(),
968                input_cstring.as_ptr(),
969                i32::from(is_partial),
970                &raw mut handle,
971                &raw mut out_error,
972            )
973        };
974
975        let parsed = match status {
976            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => collect_parsed_chat_message(handle),
977            llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION => {
978                let message = read_optional_owned_cstr_lossy(out_error);
979                Err(ParseChatMessageError::ParseException(message))
980            }
981            other => Err(ParseChatMessageError::FfiError(status_to_i32(other))),
982        };
983
984        unsafe { llama_cpp_bindings_sys::llama_rs_parsed_chat_free(handle) };
985        unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
986
987        parsed
988    }
989
990    /// Render the model's chat template with the autoparser's synthetic
991    /// no-tools and with-tools inputs. Returns `(output_no_tools,
992    /// output_with_tools)`. Either side can be empty when the template throws
993    /// during rendering. Useful for debugging tool-call marker detection.
994    ///
995    /// # Errors
996    ///
997    /// Returns [`MarkerDetectionError`] when the C++ analyzer throws or the FFI
998    /// returns a non-OK status.
999    pub fn diagnose_tool_call_synthetic_renders(
1000        &self,
1001    ) -> Result<(String, String), MarkerDetectionError> {
1002        let (no_tools, with_tools) =
1003            invoke_ffi_string_pair_detector(|first, second, error| unsafe {
1004                llama_cpp_bindings_sys::llama_rs_diagnose_tool_call_synthetic_renders(
1005                    self.model.as_ptr(),
1006                    first,
1007                    second,
1008                    error,
1009                )
1010            })?;
1011
1012        Ok((no_tools.unwrap_or_default(), with_tools.unwrap_or_default()))
1013    }
1014}
1015
1016impl LlamaModel {
1017    /// Returns a process-cached, approximate token environment built from this model's vocabulary.
1018    ///
1019    /// The first call iterates the full vocabulary and constructs the trie; subsequent calls
1020    /// return the cached `Arc` without further FFI work.
1021    pub fn approximate_tok_env(&self) -> Arc<ApproximateTokEnv> {
1022        Arc::clone(self.tok_env.get_or_init(|| build_approximate_tok_env(self)))
1023    }
1024}
1025
1026fn build_approximate_tok_env(model: &LlamaModel) -> Arc<ApproximateTokEnv> {
1027    let n_vocab = model.n_vocab().cast_unsigned();
1028    let tok_eos = {
1029        let eot = unsafe { llama_cpp_bindings_sys::llama_vocab_eot(model.vocab_ptr()) };
1030        if eot == -1 {
1031            model.token_eos().0.cast_unsigned()
1032        } else {
1033            eot.cast_unsigned()
1034        }
1035    };
1036    let info = TokRxInfo::new(n_vocab, tok_eos);
1037
1038    let mut words = Vec::with_capacity(n_vocab as usize);
1039
1040    for token_id in 0..n_vocab.cast_signed() {
1041        let token = LlamaToken(token_id);
1042        let bytes = model
1043            .token_to_piece_bytes(token, 32, false, None)
1044            .unwrap_or_default();
1045        if bytes.is_empty() {
1046            let special_bytes = model
1047                .token_to_piece_bytes(token, 32, true, None)
1048                .unwrap_or_default();
1049            if special_bytes.is_empty() {
1050                words.push(vec![]);
1051            } else {
1052                let mut marked = Vec::with_capacity(special_bytes.len() + 1);
1053                marked.push(0xFF);
1054                marked.extend(special_bytes);
1055                words.push(marked);
1056            }
1057        } else {
1058            words.push(bytes);
1059        }
1060    }
1061
1062    let trie = TokTrie::from(&info, &words);
1063    Arc::new(ApproximateTokEnv::new(trie))
1064}
1065
1066fn collect_parsed_chat_message(
1067    handle: *mut llama_cpp_bindings_sys::llama_rs_parsed_chat,
1068) -> Result<ParsedChatMessage, ParseChatMessageError> {
1069    if handle.is_null() {
1070        return Ok(ParsedChatMessage::default());
1071    }
1072
1073    let content = read_owned_cstr_for_parse(unsafe {
1074        llama_cpp_bindings_sys::llama_rs_parsed_chat_content(handle)
1075    })?;
1076    let reasoning_content = read_owned_cstr_for_parse(unsafe {
1077        llama_cpp_bindings_sys::llama_rs_parsed_chat_reasoning_content(handle)
1078    })?;
1079
1080    let count = unsafe { llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_count(handle) };
1081
1082    let mut tool_calls = Vec::with_capacity(count);
1083    for index in 0..count {
1084        let id = read_owned_cstr_for_parse(unsafe {
1085            llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_id(handle, index)
1086        })?;
1087        let name = read_owned_cstr_for_parse(unsafe {
1088            llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_name(handle, index)
1089        })?;
1090        let arguments_json = read_owned_cstr_for_parse(unsafe {
1091            llama_cpp_bindings_sys::llama_rs_parsed_chat_tool_call_arguments(handle, index)
1092        })?;
1093
1094        let arguments = ToolCallArguments::from_string(arguments_json);
1095        tool_calls.push(ParsedToolCall::new(id, name, arguments));
1096    }
1097
1098    Ok(ParsedChatMessage::new(
1099        content,
1100        reasoning_content,
1101        tool_calls,
1102    ))
1103}
1104
1105struct ReasoningSplit {
1106    reasoning: String,
1107    content: String,
1108}
1109
1110fn split_reasoning_prefix(
1111    input: &str,
1112    reasoning_markers: Option<&ReasoningMarkers>,
1113    tool_call_open: &str,
1114) -> ReasoningSplit {
1115    let content_only = || ReasoningSplit {
1116        reasoning: String::new(),
1117        content: prefix_before(input, tool_call_open),
1118    };
1119
1120    let Some(reasoning_markers) = reasoning_markers else {
1121        return content_only();
1122    };
1123    let Some(open_pos) = input.find(&reasoning_markers.open) else {
1124        return content_only();
1125    };
1126
1127    let after_open = &input[open_pos + reasoning_markers.open.len()..];
1128    let Some(close_offset) = after_open.find(&reasoning_markers.close) else {
1129        return content_only();
1130    };
1131
1132    let reasoning = after_open[..close_offset].to_owned();
1133    let after_close = &after_open[close_offset + reasoning_markers.close.len()..];
1134
1135    ReasoningSplit {
1136        reasoning,
1137        content: prefix_before(after_close, tool_call_open),
1138    }
1139}
1140
/// Returns an owned copy of `text` up to (not including) the first occurrence
/// of `marker`, or all of `text` when `marker` does not occur.
fn prefix_before(text: &str, marker: &str) -> String {
    let head = match text.split_once(marker) {
        Some((head, _)) => head,
        None => text,
    };
    head.to_owned()
}
1145
1146fn synthesize_missing_tool_call_ids(tool_calls: &mut [ParsedToolCall]) {
1147    for (index, call) in tool_calls.iter_mut().enumerate() {
1148        if call.id.is_empty() {
1149            call.id = format!("call_{index}");
1150        }
1151    }
1152}
1153
1154fn parse_single_string_status(
1155    status: llama_cpp_bindings_sys::llama_rs_status,
1156    out_value: *mut c_char,
1157    out_error: *mut c_char,
1158) -> Result<Option<String>, MarkerDetectionError> {
1159    match status {
1160        llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => read_optional_owned_cstr(out_value),
1161        llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION => {
1162            let message = read_optional_owned_cstr_lossy(out_error);
1163
1164            Err(MarkerDetectionError::AnalyzeException(message))
1165        }
1166        other => Err(MarkerDetectionError::FfiError(status_to_i32(other))),
1167    }
1168}
1169
1170fn invoke_ffi_single_string_detector<TInvoke>(
1171    invoke: TInvoke,
1172) -> Result<Option<String>, MarkerDetectionError>
1173where
1174    TInvoke: FnOnce(*mut *mut c_char, *mut *mut c_char) -> llama_cpp_bindings_sys::llama_rs_status,
1175{
1176    let mut out_value: *mut c_char = ptr::null_mut();
1177    let mut out_error: *mut c_char = ptr::null_mut();
1178
1179    let status = invoke(&raw mut out_value, &raw mut out_error);
1180    let parsed = parse_single_string_status(status, out_value, out_error);
1181
1182    unsafe {
1183        if !out_value.is_null() {
1184            llama_cpp_bindings_sys::llama_rs_string_free(out_value);
1185        }
1186        if !out_error.is_null() {
1187            llama_cpp_bindings_sys::llama_rs_string_free(out_error);
1188        }
1189    }
1190
1191    parsed
1192}
1193
1194fn invoke_ffi_string_pair_detector<TInvoke>(
1195    invoke: TInvoke,
1196) -> Result<(Option<String>, Option<String>), MarkerDetectionError>
1197where
1198    TInvoke: FnOnce(
1199        *mut *mut c_char,
1200        *mut *mut c_char,
1201        *mut *mut c_char,
1202    ) -> llama_cpp_bindings_sys::llama_rs_status,
1203{
1204    let mut out_first: *mut c_char = ptr::null_mut();
1205    let mut out_second: *mut c_char = ptr::null_mut();
1206    let mut out_error: *mut c_char = ptr::null_mut();
1207
1208    let status = invoke(&raw mut out_first, &raw mut out_second, &raw mut out_error);
1209
1210    let parsed = (|| match status {
1211        llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK => {
1212            let first = read_optional_owned_cstr(out_first)?;
1213            let second = read_optional_owned_cstr(out_second)?;
1214
1215            Ok((first, second))
1216        }
1217        llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION => {
1218            let message = read_optional_owned_cstr_lossy(out_error);
1219
1220            Err(MarkerDetectionError::AnalyzeException(message))
1221        }
1222        other => Err(MarkerDetectionError::FfiError(status_to_i32(other))),
1223    })();
1224
1225    unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_first) };
1226    unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_second) };
1227    unsafe { llama_cpp_bindings_sys::llama_rs_string_free(out_error) };
1228
1229    parsed
1230}
1231
1232fn read_owned_cstr_for_parse(ptr: *mut c_char) -> Result<String, ParseChatMessageError> {
1233    if ptr.is_null() {
1234        return Ok(String::new());
1235    }
1236
1237    let bytes = unsafe { CStr::from_ptr(ptr) }.to_bytes().to_vec();
1238    unsafe { llama_cpp_bindings_sys::llama_rs_string_free(ptr) };
1239
1240    Ok(String::from_utf8(bytes)?)
1241}
1242
1243fn read_optional_owned_cstr(ptr: *const c_char) -> Result<Option<String>, MarkerDetectionError> {
1244    if ptr.is_null() {
1245        return Ok(None);
1246    }
1247
1248    let bytes = unsafe { CStr::from_ptr(ptr) }.to_bytes().to_vec();
1249
1250    Ok(Some(String::from_utf8(bytes)?))
1251}
1252
/// Copies a C string into a `String`, replacing invalid UTF-8 sequences with
/// U+FFFD. A null pointer yields an empty string. The pointer is not freed.
fn read_optional_owned_cstr_lossy(ptr: *const c_char) -> String {
    if ptr.is_null() {
        String::new()
    } else {
        // SAFETY: `ptr` is non-null and, per the callers, points at a
        // NUL-terminated string that stays alive for the duration of this call.
        let bytes = unsafe { CStr::from_ptr(ptr) }.to_bytes();
        String::from_utf8_lossy(bytes).into_owned()
    }
}
1262
1263fn extract_meta_string<TCFunction>(
1264    c_function: TCFunction,
1265    capacity: usize,
1266) -> Result<String, MetaValError>
1267where
1268    TCFunction: Fn(*mut c_char, usize) -> i32,
1269{
1270    let mut buffer = vec![0u8; capacity];
1271    let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
1272
1273    if result < 0 {
1274        return Err(MetaValError::NegativeReturn(result));
1275    }
1276
1277    let returned_len = result.cast_unsigned() as usize;
1278
1279    if returned_len >= capacity {
1280        return extract_meta_string(c_function, returned_len + 1);
1281    }
1282
1283    if buffer.get(returned_len) != Some(&0) {
1284        return Err(MetaValError::NegativeReturn(-1));
1285    }
1286
1287    buffer.truncate(returned_len);
1288
1289    Ok(String::from_utf8(buffer)?)
1290}
1291
1292impl Drop for LlamaModel {
1293    fn drop(&mut self) {
1294        unsafe { llama_cpp_bindings_sys::llama_free_model(self.model.as_ptr()) }
1295    }
1296}
1297
// Unit tests for `extract_meta_string`, plus a few checks for the small
// string/length helpers defined at the top of this file
// (`cstring_with_validated_len`, `validate_string_length_for_tokenizer`,
// `truncated_buffer_to_string`).
#[cfg(test)]
mod extract_meta_string_tests {
    use super::extract_meta_string;
    use crate::MetaValError;

    // The callee reports a length of 2 but leaves `c` (not a NUL) at index 2,
    // so the terminator check must reject the buffer.
    #[test]
    fn returns_error_when_null_terminator_missing() {
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let buffer =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                buffer[0] = b'a';
                buffer[1] = b'b';
                buffer[2] = b'c';
                2
            },
            4,
        );

        assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-1));
    }

    // A negative callee return value is surfaced unchanged.
    #[test]
    fn returns_error_for_negative_return_value() {
        let result = extract_meta_string(|_buf_ptr, _buf_len| -5, 4);

        assert_eq!(result.unwrap_err(), MetaValError::NegativeReturn(-5));
    }

    // Two non-UTF-8 bytes followed by a NUL: the terminator check passes and
    // the UTF-8 conversion fails. The assertion relies on the error's Display
    // text mentioning `FromUtf8Error`.
    #[test]
    fn returns_error_for_invalid_utf8_data() {
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let buffer =
                    unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                buffer[0] = 0xFF;
                buffer[1] = 0xFE;
                buffer[2] = 0;
                2
            },
            4,
        );

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("FromUtf8Error"));
    }

    // The first call reports a length (10) that does not fit the 4-byte buffer,
    // which forces a retry. The retry writes "hi" plus a NUL terminator.
    #[test]
    fn triggers_buffer_resize_when_returned_len_exceeds_capacity() {
        let initial_capacity: usize = 4;
        let length_exceeding_initial_capacity = 10;
        let written_length = 2;
        let call_count = std::cell::Cell::new(0);
        let result = extract_meta_string(
            |buf_ptr, buf_len| {
                let count = call_count.get();
                call_count.set(count + 1);
                if count == 0 {
                    length_exceeding_initial_capacity
                } else {
                    let buffer =
                        unsafe { std::slice::from_raw_parts_mut(buf_ptr.cast::<u8>(), buf_len) };
                    buffer[0] = b'h';
                    buffer[1] = b'i';
                    buffer[2] = 0;
                    written_length
                }
            },
            initial_capacity,
        );

        assert_eq!(result.unwrap(), "hi");
    }

    // An interior NUL byte cannot be represented as a `CString`.
    #[test]
    fn cstring_with_validated_len_null_byte_returns_error() {
        let result = super::cstring_with_validated_len("null\0byte");

        assert!(result.is_err());
    }

    // `usize::MAX` does not fit in a `c_int`.
    #[test]
    fn validate_string_length_overflow_returns_error() {
        let result = super::validate_string_length_for_tokenizer(usize::MAX);

        assert!(result.is_err());
    }

    // The bytes 0xff 0xfe 0xfd are not valid UTF-8.
    #[test]
    fn truncated_buffer_to_string_with_invalid_utf8_returns_error() {
        let invalid_utf8 = vec![0xff, 0xfe, 0xfd];
        let result = super::truncated_buffer_to_string(invalid_utf8, 3);

        assert!(result.is_err());
    }
}
1394
// Unit tests for the FFI string helpers that need no loaded model: the lossy
// C-string reader, `parse_single_string_status`, and the two
// `invoke_ffi_*_detector` wrappers driven by stub closures.
#[cfg(test)]
mod ffi_helper_tests {
    use std::ffi::CString;
    use std::ptr;

    use super::invoke_ffi_single_string_detector;
    use super::invoke_ffi_string_pair_detector;
    use super::parse_single_string_status;
    use super::read_optional_owned_cstr_lossy;
    use crate::MarkerDetectionError;

    #[test]
    fn read_optional_owned_cstr_lossy_returns_empty_for_null() {
        let result = read_optional_owned_cstr_lossy(ptr::null());

        assert!(result.is_empty());
    }

    #[test]
    fn read_optional_owned_cstr_lossy_returns_string_for_valid_pointer() {
        let owned = CString::new("hello").expect("static literal has no nuls");
        let result = read_optional_owned_cstr_lossy(owned.as_ptr());

        assert_eq!(result, "hello");
    }

    // The 0xFF byte is replaced, so only the surrounding ASCII is checked.
    #[test]
    fn read_optional_owned_cstr_lossy_handles_invalid_utf8_via_replacement() {
        let owned = CString::new(vec![b'a', 0xFF, b'b']).expect("no interior nul");
        let result = read_optional_owned_cstr_lossy(owned.as_ptr());

        assert!(result.starts_with('a'));
        assert!(result.ends_with('b'));
    }

    #[test]
    fn parse_single_string_status_returns_none_for_ok_with_null() {
        let result = parse_single_string_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
            ptr::null_mut(),
            ptr::null_mut(),
        );

        assert_eq!(result.expect("OK + null returns Ok(None)"), None);
    }

    // The pointer borrows a Rust-owned `CString`; `parse_single_string_status`
    // only reads it and does not free it.
    #[test]
    fn parse_single_string_status_returns_some_for_ok_with_value() {
        let owned = CString::new("present").expect("no nul");
        let result = parse_single_string_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK,
            owned.as_ptr().cast_mut(),
            ptr::null_mut(),
        );

        assert_eq!(
            result.expect("OK + value returns Ok(Some)"),
            Some("present".to_owned())
        );
    }

    #[test]
    fn parse_single_string_status_returns_analyze_exception() {
        let owned = CString::new("boom").expect("no nul");
        let result = parse_single_string_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_EXCEPTION,
            ptr::null_mut(),
            owned.as_ptr().cast_mut(),
        );

        match result.expect_err("EXCEPTION must yield Err") {
            MarkerDetectionError::AnalyzeException(message) => assert_eq!(message, "boom"),
            other => panic!("expected AnalyzeException, got {other:?}"),
        }
    }

    #[test]
    fn parse_single_string_status_returns_ffi_error_for_other_status() {
        let result = parse_single_string_status(
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT,
            ptr::null_mut(),
            ptr::null_mut(),
        );

        match result.expect_err("invalid status must yield Err") {
            MarkerDetectionError::FfiError(_) => {}
            other => panic!("expected FfiError, got {other:?}"),
        }
    }

    // The stub never writes its out-pointers, so they stay null.
    #[test]
    fn invoke_ffi_single_string_detector_propagates_invalid_argument_status() {
        let result = invoke_ffi_single_string_detector(|_value, _error| {
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT
        });

        assert!(matches!(result, Err(MarkerDetectionError::FfiError(_))));
    }

    #[test]
    fn invoke_ffi_single_string_detector_returns_none_for_ok_with_null() {
        let result = invoke_ffi_single_string_detector(|value, _error| {
            unsafe {
                *value = ptr::null_mut();
            }
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK
        });

        assert_eq!(result.expect("OK + null returns Ok(None)"), None);
    }

    // The stub never writes its out-pointers, so the helper frees only nulls.
    #[test]
    fn invoke_ffi_string_pair_detector_propagates_invalid_argument_status() {
        let result = invoke_ffi_string_pair_detector(|_first, _second, _error| {
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_INVALID_ARGUMENT
        });

        assert!(matches!(result, Err(MarkerDetectionError::FfiError(_))));
    }

    #[test]
    fn invoke_ffi_string_pair_detector_returns_pair_of_none_for_ok_with_nulls() {
        let result = invoke_ffi_string_pair_detector(|first, second, _error| {
            unsafe {
                *first = ptr::null_mut();
                *second = ptr::null_mut();
            }
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_OK
        });

        assert_eq!(
            result.expect("OK with both null returns Ok((None, None))"),
            (None, None)
        );
    }

    #[test]
    fn invoke_ffi_string_pair_detector_propagates_invalid_status_codes() {
        let result = invoke_ffi_string_pair_detector(|_first, _second, _error| {
            llama_cpp_bindings_sys::LLAMA_RS_STATUS_ALLOCATION_FAILED
        });

        match result.expect_err("non-OK status yields Err") {
            MarkerDetectionError::FfiError(code) => assert!(code != 0),
            other => panic!("expected FfiError, got {other:?}"),
        }
    }
}
1542}