// llama_cpp_4/model.rs
1//! A safe wrapper around `llama_model`.
2use std::ffi::CStr;
3use std::ffi::CString;
4use std::fmt;
5use std::num::NonZeroU16;
6use std::os::raw::{c_char, c_int};
7use std::path::Path;
8use std::ptr::NonNull;
9
10use llama_cpp_sys_4::{
11 llama_adapter_lora, llama_adapter_lora_init, llama_add_bos_token, llama_add_eos_token,
12 llama_chat_apply_template, llama_chat_builtin_templates, llama_chat_message,
13 llama_detokenize, llama_free_model, llama_load_model_from_file, llama_model,
14 llama_model_cls_label, llama_model_decoder_start_token, llama_model_desc,
15 llama_model_get_vocab, llama_model_has_decoder, llama_model_has_encoder,
16 llama_model_is_diffusion, llama_model_is_hybrid, llama_model_is_recurrent,
17 llama_model_load_from_splits, llama_model_meta_count, llama_model_meta_key_by_index,
18 llama_model_meta_val_str, llama_model_meta_val_str_by_index, llama_model_n_cls_out,
19 llama_model_n_embd_inp, llama_model_n_embd_out, llama_model_n_head_kv, llama_model_n_params,
20 llama_model_n_swa, llama_model_rope_freq_scale_train, llama_model_rope_type,
21 llama_model_save_to_file, llama_model_size, llama_n_ctx_train, llama_n_embd, llama_n_head,
22 llama_n_layer, llama_n_vocab, llama_new_context_with_model, llama_split_path,
23 llama_split_prefix, llama_token_bos, llama_token_cls, llama_token_eos, llama_token_eot,
24 llama_token_fim_mid, llama_token_fim_pad, llama_token_fim_pre, llama_token_fim_rep,
25 llama_token_fim_sep, llama_token_fim_suf, llama_token_get_attr, llama_token_get_score,
26 llama_token_get_text, llama_token_is_control, llama_token_is_eog, llama_token_nl,
27 llama_token_pad, llama_token_sep, llama_token_to_piece, llama_tokenize, llama_vocab,
28 llama_vocab_type, LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
29};
30
31use crate::context::params::LlamaContextParams;
32use crate::context::LlamaContext;
33use crate::llama_backend::LlamaBackend;
34use crate::model::params::LlamaModelParams;
35use crate::token::LlamaToken;
36use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
37use crate::{
38 ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
39 LlamaModelLoadError, NewLlamaChatMessageError, StringFromModelError, StringToTokenError,
40 TokenToStringError,
41};
42
43pub mod params;
44
/// A safe wrapper around `llama_model`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    /// Owning pointer to the C-side model. Presumably freed via `llama_free_model`
    /// in a `Drop` impl elsewhere in this file — TODO confirm (the symbol is imported above).
    pub(crate) model: NonNull<llama_model>,
}
52
/// A safe wrapper around `llama_vocab`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaVocab {
    /// Pointer to the model's vocabulary. Obtained via `llama_model_get_vocab`
    /// (see [`LlamaModel::get_vocab`]); the vocab is owned by the model, so this
    /// is only valid while the originating model is alive — no `Drop` is needed here.
    pub(crate) vocab: NonNull<llama_vocab>,
}
60
impl LlamaVocab {
    // NOTE(review): the special-token accessors below forward the raw id from
    // llama.cpp; for models that do not define a given special token the C API
    // presumably returns the "null token" sentinel (-1) — callers should check
    // before using the id. TODO confirm against llama.h.

    /// Get the number of tokens in the vocabulary.
    #[must_use]
    pub fn n_tokens(&self) -> i32 {
        unsafe { llama_cpp_sys_4::llama_vocab_n_tokens(self.vocab.as_ref()) }
    }

    /// Get the vocabulary type.
    ///
    /// # Panics
    ///
    /// Panics if the value returned by `llama_vocab_type` cannot be converted to
    /// a `u32` (i.e. is negative), which should not happen for a valid model.
    #[must_use]
    pub fn vocab_type(&self) -> u32 {
        unsafe { llama_cpp_sys_4::llama_vocab_type(self.vocab.as_ref()).try_into().unwrap() }
    }

    /// Get the BOS token.
    #[must_use]
    pub fn bos(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_bos(self.vocab.as_ref()) })
    }

    /// Get the EOS token.
    #[must_use]
    pub fn eos(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eos(self.vocab.as_ref()) })
    }

    /// Get the EOT (end of turn) token.
    #[must_use]
    pub fn eot(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eot(self.vocab.as_ref()) })
    }

    /// Get the CLS (classification) token.
    #[must_use]
    pub fn cls(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_cls(self.vocab.as_ref()) })
    }

    /// Get the SEP (separator) token.
    #[must_use]
    pub fn sep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_sep(self.vocab.as_ref()) })
    }

    /// Get the NL (newline) token.
    #[must_use]
    pub fn nl(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_nl(self.vocab.as_ref()) })
    }

    /// Get the PAD (padding) token.
    #[must_use]
    pub fn pad(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_pad(self.vocab.as_ref()) })
    }

    /// Get the FIM prefix token.
    #[must_use]
    pub fn fim_pre(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pre(self.vocab.as_ref()) })
    }

    /// Get the FIM suffix token.
    #[must_use]
    pub fn fim_suf(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_suf(self.vocab.as_ref()) })
    }

    /// Get the FIM middle token.
    #[must_use]
    pub fn fim_mid(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_mid(self.vocab.as_ref()) })
    }

    /// Get the FIM padding token.
    #[must_use]
    pub fn fim_pad(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pad(self.vocab.as_ref()) })
    }

    /// Get the FIM repository token.
    #[must_use]
    pub fn fim_rep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_rep(self.vocab.as_ref()) })
    }

    /// Get the FIM separator token.
    #[must_use]
    pub fn fim_sep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_sep(self.vocab.as_ref()) })
    }

    /// Check whether BOS should be added.
    #[must_use]
    pub fn get_add_bos(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_bos(self.vocab.as_ref()) }
    }

    /// Check whether EOS should be added.
    #[must_use]
    pub fn get_add_eos(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_eos(self.vocab.as_ref()) }
    }

    /// Check whether SEP should be added.
    #[must_use]
    pub fn get_add_sep(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_sep(self.vocab.as_ref()) }
    }

    /// Get the text representation of a token.
    ///
    /// The returned `&str` borrows C-owned vocabulary memory; the elided
    /// lifetime ties it to `&self`, which keeps it from outliving the wrapper.
    ///
    /// # Errors
    ///
    /// Returns an error if the text pointer is null or not valid UTF-8.
    pub fn get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
        let ptr = unsafe { llama_cpp_sys_4::llama_vocab_get_text(self.vocab.as_ref(), token.0) };
        if ptr.is_null() {
            // -1 is a generic sentinel: the C call reports no specific error
            // code for a null text pointer.
            return Err(StringFromModelError::ReturnedError(-1));
        }
        let cstr = unsafe { CStr::from_ptr(ptr) };
        cstr.to_str().map_err(StringFromModelError::Utf8Error)
    }

    /// Get the score of a token.
    #[must_use]
    pub fn get_score(&self, token: LlamaToken) -> f32 {
        unsafe { llama_cpp_sys_4::llama_vocab_get_score(self.vocab.as_ref(), token.0) }
    }

    /// Get the attributes of a token.
    ///
    /// # Panics
    ///
    /// Panics if the raw attribute value cannot be converted to a `u32`
    /// (i.e. is negative), which should not happen for a valid model.
    #[must_use]
    pub fn get_attr(&self, token: LlamaToken) -> u32 {
        unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.vocab.as_ref(), token.0).try_into().unwrap() }
    }

    /// Check if a token is a control token.
    #[must_use]
    pub fn is_control(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_is_control(self.vocab.as_ref(), token.0) }
    }

    /// Check if a token is an end-of-generation token.
    #[must_use]
    pub fn is_eog(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_is_eog(self.vocab.as_ref(), token.0) }
    }

    /// Get the token mask value for the vocabulary.
    #[must_use]
    pub fn mask(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_mask(self.vocab.as_ref()) })
    }
}
214
/// A safe wrapper around `llama_adapter_lora`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    /// Owning pointer to the C-side adapter, presumably created via
    /// `llama_adapter_lora_init` (imported above); freed in the `Drop` impl below.
    pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
}
222
223impl LlamaLoraAdapter {
224 /// Get the number of metadata key-value pairs in the adapter.
225 #[must_use]
226 pub fn meta_count(&self) -> i32 {
227 unsafe { llama_cpp_sys_4::llama_adapter_meta_count(self.lora_adapter.as_ptr()) }
228 }
229
230 /// Get a metadata key by index.
231 ///
232 /// # Errors
233 ///
234 /// Returns an error if the index is out of range or the key is not valid UTF-8.
235 #[allow(clippy::cast_sign_loss)]
236 pub fn meta_key_by_index(
237 &self,
238 index: i32,
239 buf_size: usize,
240 ) -> Result<String, StringFromModelError> {
241 let mut buf = vec![0u8; buf_size];
242 let ret = unsafe {
243 llama_cpp_sys_4::llama_adapter_meta_key_by_index(
244 self.lora_adapter.as_ptr(),
245 index,
246 buf.as_mut_ptr().cast::<c_char>(),
247 buf_size,
248 )
249 };
250 if ret < 0 {
251 return Err(StringFromModelError::ReturnedError(ret));
252 }
253 let len = ret as usize;
254 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
255 Ok(s.to_owned())
256 }
257
258 /// Get a metadata value by key name.
259 ///
260 /// # Errors
261 ///
262 /// Returns an error if the key is not found or the value is not valid UTF-8.
263 #[allow(clippy::cast_sign_loss)]
264 pub fn meta_val_str(
265 &self,
266 key: &str,
267 buf_size: usize,
268 ) -> Result<String, StringFromModelError> {
269 let c_key =
270 CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
271 let mut buf = vec![0u8; buf_size];
272 let ret = unsafe {
273 llama_cpp_sys_4::llama_adapter_meta_val_str(
274 self.lora_adapter.as_ptr(),
275 c_key.as_ptr(),
276 buf.as_mut_ptr().cast::<c_char>(),
277 buf_size,
278 )
279 };
280 if ret < 0 {
281 return Err(StringFromModelError::ReturnedError(ret));
282 }
283 let len = ret as usize;
284 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
285 Ok(s.to_owned())
286 }
287
288 /// Get a metadata value by index.
289 ///
290 /// # Errors
291 ///
292 /// Returns an error if the index is out of range or the value is not valid UTF-8.
293 #[allow(clippy::cast_sign_loss)]
294 pub fn meta_val_str_by_index(
295 &self,
296 index: i32,
297 buf_size: usize,
298 ) -> Result<String, StringFromModelError> {
299 let mut buf = vec![0u8; buf_size];
300 let ret = unsafe {
301 llama_cpp_sys_4::llama_adapter_meta_val_str_by_index(
302 self.lora_adapter.as_ptr(),
303 index,
304 buf.as_mut_ptr().cast::<c_char>(),
305 buf_size,
306 )
307 };
308 if ret < 0 {
309 return Err(StringFromModelError::ReturnedError(ret));
310 }
311 let len = ret as usize;
312 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
313 Ok(s.to_owned())
314 }
315
316 /// Get all metadata as a list of `(key, value)` pairs.
317 ///
318 /// # Errors
319 ///
320 /// Returns an error if any key or value cannot be read or is not valid UTF-8.
321 #[allow(clippy::cast_sign_loss)]
322 pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
323 let count = self.meta_count();
324 let mut result = Vec::with_capacity(count as usize);
325 for i in 0..count {
326 let key = self.meta_key_by_index(i, 256)?;
327 let val = self.meta_val_str_by_index(i, 4096)?;
328 result.push((key, val));
329 }
330 Ok(result)
331 }
332
333 /// Get the number of invocation tokens for this adapter.
334 #[must_use]
335 pub fn n_invocation_tokens(&self) -> u64 {
336 unsafe {
337 llama_cpp_sys_4::llama_adapter_get_alora_n_invocation_tokens(
338 self.lora_adapter.as_ptr(),
339 )
340 }
341 }
342
343 /// Get the invocation tokens for this adapter.
344 ///
345 /// Returns an empty slice if there are no invocation tokens.
346 #[must_use]
347 #[allow(clippy::cast_possible_truncation)]
348 pub fn invocation_tokens(&self) -> &[LlamaToken] {
349 let n = self.n_invocation_tokens() as usize;
350 if n == 0 {
351 return &[];
352 }
353 let ptr = unsafe {
354 llama_cpp_sys_4::llama_adapter_get_alora_invocation_tokens(
355 self.lora_adapter.as_ptr(),
356 )
357 };
358 if ptr.is_null() {
359 return &[];
360 }
361 // LlamaToken is repr(transparent) over llama_token (i32), so this cast is safe
362 unsafe { std::slice::from_raw_parts(ptr.cast::<LlamaToken>(), n) }
363 }
364}
365
impl Drop for LlamaLoraAdapter {
    fn drop(&mut self) {
        // SAFETY: `lora_adapter` is a valid pointer held exclusively by this
        // wrapper, and is not used again after this call.
        // NOTE(review): llama.cpp also frees adapters when the model is freed;
        // presumably double-free is avoided because this wrapper uniquely owns
        // adapters created through `llama_adapter_lora_init` — TODO confirm.
        unsafe {
            llama_cpp_sys_4::llama_adapter_lora_free(self.lora_adapter.as_ptr());
        }
    }
}
373
/// A safe wrapper around `llama_chat_message`
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
    /// Message role (e.g. "user", "assistant"), stored NUL-terminated for FFI.
    role: CString,
    /// Message body text, stored NUL-terminated for FFI.
    content: CString,
}
380
381impl LlamaChatMessage {
382 /// Create a new `LlamaChatMessage`.
383 ///
384 /// # Errors
385 ///
386 /// Returns [`NewLlamaChatMessageError`] if the role or content contains a null byte.
387 pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
388 Ok(Self {
389 role: CString::new(role)?,
390 content: CString::new(content)?,
391 })
392 }
393}
394
/// How to determine if we should prepend a bos token to tokens
///
/// Passed to [`LlamaModel::str_to_token`] to control the `add_special`
/// behavior of `llama_tokenize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Add the beginning of stream token to the start of the string.
    Always,
    /// Do not add the beginning of stream token to the start of the string.
    Never,
}
403
/// How to determine if we should tokenize special tokens
///
/// Used both when tokenizing and when rendering tokens back to text
/// (see [`LlamaModel::token_to_str`] and friends).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
    Tokenize,
    /// Treat special and/or control tokens as plaintext.
    Plaintext,
}
412
// SAFETY(review): `LlamaModel` uniquely owns its `llama_model` pointer, so moving
// it between threads is sound as long as llama.cpp does not tie the model to the
// creating thread — presumably it does not. TODO confirm.
unsafe impl Send for LlamaModel {}

// SAFETY(review): shared access is presumed sound because the wrapper only reads
// the model through `const`-style accessors after loading — TODO confirm that no
// C call made through `&self` mutates model state.
unsafe impl Sync for LlamaModel {}
416
417impl LlamaModel {
418 /// Retrieves the vocabulary associated with the current Llama model.
419 ///
420 /// This method fetches the vocabulary from the underlying model using an unsafe
421 /// FFI call. The returned `LlamaVocab` struct contains a non-null pointer to
422 /// the vocabulary data, which is wrapped in a `NonNull` for safety.
423 ///
424 /// # Safety
425 /// This method uses an unsafe block to call a C function (`llama_model_get_vocab`),
426 /// which is assumed to return a valid pointer to the vocabulary. The caller should
427 /// ensure that the model object is properly initialized and valid before calling
428 /// this method, as dereferencing invalid pointers can lead to undefined behavior.
429 ///
430 /// # Returns
431 /// A `LlamaVocab` struct containing the vocabulary of the model.
432 ///
433 /// # Panics
434 ///
435 /// Panics if the underlying C function returns a null pointer.
436 ///
437 /// # Example
438 /// ```rust,ignore
439 /// let vocab = model.get_vocab();
440 /// ```
441 #[must_use]
442 pub fn get_vocab(&self) -> LlamaVocab {
443 let llama_vocab = unsafe { llama_model_get_vocab(self.model.as_ptr()) }.cast_mut();
444
445 LlamaVocab {
446 vocab: NonNull::new(llama_vocab).unwrap(),
447 }
448 }
    /// Get the size of the context window (in tokens) the model was trained with.
    ///
    /// Note: this is the training-time context length (`n_ctx_train`), not the
    /// vocabulary size or a count of training tokens.
    ///
    /// # Panics
    ///
    /// This function will panic if the value reported by llama.cpp does not fit into a `u32`.
    /// This should be impossible on most platforms since llama.cpp returns a `c_int` (i32 on most platforms),
    /// which is almost certainly positive.
    #[must_use]
    pub fn n_ctx_train(&self) -> u32 {
        let n_ctx_train = unsafe { llama_n_ctx_train(self.model.as_ptr()) };
        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
    }
463
464 /// Get all tokens in the model.
465 ///
466 /// This function returns an iterator over all the tokens in the model. Each item in the iterator is a tuple
467 /// containing a `LlamaToken` and its corresponding string representation (or an error if the conversion fails).
468 ///
469 /// # Parameters
470 ///
471 /// - `special`: The `Special` value that determines how special tokens (like BOS, EOS, etc.) are handled.
472 pub fn tokens(
473 &self,
474 special: Special,
475 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
476 (0..self.n_vocab())
477 .map(LlamaToken::new)
478 .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
479 }
480
481 /// Get the beginning of stream token.
482 ///
483 /// This function returns the token that represents the beginning of a stream (BOS token).
484 #[must_use]
485 pub fn token_bos(&self) -> LlamaToken {
486 let token = unsafe { llama_token_bos(self.get_vocab().vocab.as_ref()) };
487 LlamaToken(token)
488 }
489
490 /// Get the end of stream token.
491 ///
492 /// This function returns the token that represents the end of a stream (EOS token).
493 #[must_use]
494 pub fn token_eos(&self) -> LlamaToken {
495 let token = unsafe { llama_token_eos(self.get_vocab().vocab.as_ref()) };
496 LlamaToken(token)
497 }
498
499 /// Get the newline token.
500 ///
501 /// This function returns the token that represents a newline character.
502 #[must_use]
503 pub fn token_nl(&self) -> LlamaToken {
504 let token = unsafe { llama_token_nl(self.get_vocab().vocab.as_ref()) };
505 LlamaToken(token)
506 }
507
508 /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
509 ///
510 /// This function returns `true` if the provided token signifies the end of generation or end of sequence,
511 /// such as EOS or other special tokens.
512 ///
513 /// # Parameters
514 ///
515 /// - `token`: The `LlamaToken` to check.
516 ///
517 /// # Returns
518 ///
519 /// - `true` if the token is an end-of-generation token, otherwise `false`.
520 #[must_use]
521 pub fn is_eog_token(&self, token: LlamaToken) -> bool {
522 unsafe { llama_token_is_eog(self.get_vocab().vocab.as_ref(), token.0) }
523 }
524
525 /// Get the classification token.
526 #[must_use]
527 pub fn token_cls(&self) -> LlamaToken {
528 let token = unsafe { llama_token_cls(self.get_vocab().vocab.as_ref()) };
529 LlamaToken(token)
530 }
531
532 /// Get the end-of-turn token.
533 #[must_use]
534 pub fn token_eot(&self) -> LlamaToken {
535 let token = unsafe { llama_token_eot(self.get_vocab().vocab.as_ref()) };
536 LlamaToken(token)
537 }
538
539 /// Get the padding token.
540 #[must_use]
541 pub fn token_pad(&self) -> LlamaToken {
542 let token = unsafe { llama_token_pad(self.get_vocab().vocab.as_ref()) };
543 LlamaToken(token)
544 }
545
546 /// Get the separator token.
547 #[must_use]
548 pub fn token_sep(&self) -> LlamaToken {
549 let token = unsafe { llama_token_sep(self.get_vocab().vocab.as_ref()) };
550 LlamaToken(token)
551 }
552
553 /// Get the fill-in-the-middle prefix token.
554 #[must_use]
555 pub fn token_fim_pre(&self) -> LlamaToken {
556 let token = unsafe { llama_token_fim_pre(self.get_vocab().vocab.as_ref()) };
557 LlamaToken(token)
558 }
559
560 /// Get the fill-in-the-middle suffix token.
561 #[must_use]
562 pub fn token_fim_suf(&self) -> LlamaToken {
563 let token = unsafe { llama_token_fim_suf(self.get_vocab().vocab.as_ref()) };
564 LlamaToken(token)
565 }
566
567 /// Get the fill-in-the-middle middle token.
568 #[must_use]
569 pub fn token_fim_mid(&self) -> LlamaToken {
570 let token = unsafe { llama_token_fim_mid(self.get_vocab().vocab.as_ref()) };
571 LlamaToken(token)
572 }
573
574 /// Get the fill-in-the-middle padding token.
575 #[must_use]
576 pub fn token_fim_pad(&self) -> LlamaToken {
577 let token = unsafe { llama_token_fim_pad(self.get_vocab().vocab.as_ref()) };
578 LlamaToken(token)
579 }
580
581 /// Get the fill-in-the-middle repository token.
582 #[must_use]
583 pub fn token_fim_rep(&self) -> LlamaToken {
584 let token = unsafe { llama_token_fim_rep(self.get_vocab().vocab.as_ref()) };
585 LlamaToken(token)
586 }
587
588 /// Get the fill-in-the-middle separator token.
589 #[must_use]
590 pub fn token_fim_sep(&self) -> LlamaToken {
591 let token = unsafe { llama_token_fim_sep(self.get_vocab().vocab.as_ref()) };
592 LlamaToken(token)
593 }
594
595 /// Check if a token is a control token.
596 #[must_use]
597 pub fn token_is_control(&self, token: LlamaToken) -> bool {
598 unsafe { llama_token_is_control(self.get_vocab().vocab.as_ref(), token.0) }
599 }
600
601 /// Get the score of a token.
602 #[must_use]
603 pub fn token_get_score(&self, token: LlamaToken) -> f32 {
604 unsafe { llama_token_get_score(self.get_vocab().vocab.as_ref(), token.0) }
605 }
606
    /// Get the raw text of a token.
    ///
    /// The returned `&str` borrows C-owned vocabulary memory; the elided
    /// lifetime ties it to `&self`. (This cannot simply delegate to
    /// [`LlamaVocab::get_text`]: that would borrow from the temporary
    /// `LlamaVocab` returned by `get_vocab` and fail to compile.)
    ///
    /// # Errors
    ///
    /// Returns an error if the token text is null or not valid UTF-8.
    pub fn token_get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
        let ptr =
            unsafe { llama_token_get_text(self.get_vocab().vocab.as_ref(), token.0) };
        if ptr.is_null() {
            // -1 is a generic sentinel; the C call reports no specific code.
            return Err(StringFromModelError::ReturnedError(-1));
        }
        let cstr = unsafe { CStr::from_ptr(ptr) };
        cstr.to_str()
            .map_err(StringFromModelError::Utf8Error)
    }
622
623 /// Check if a BOS token should be added when tokenizing.
624 #[must_use]
625 pub fn add_bos_token(&self) -> bool {
626 unsafe { llama_add_bos_token(self.get_vocab().vocab.as_ref()) }
627 }
628
629 /// Check if an EOS token should be added when tokenizing.
630 #[must_use]
631 pub fn add_eos_token(&self) -> bool {
632 unsafe { llama_add_eos_token(self.get_vocab().vocab.as_ref()) }
633 }
634
635 /// Get the decoder start token.
636 ///
637 /// This function returns the token used to signal the start of decoding (i.e., the token used at the start
638 /// of a sequence generation).
639 #[must_use]
640 pub fn decode_start_token(&self) -> LlamaToken {
641 let token = unsafe { llama_model_decoder_start_token(self.model.as_ptr()) };
642 LlamaToken(token)
643 }
644
    /// Convert a single token to a string.
    ///
    /// This function converts a `LlamaToken` into its string representation,
    /// using a 32-byte scratch buffer. If a token's piece exceeds 32 bytes the
    /// conversion errors; callers can retry with
    /// [`Self::token_to_str_with_size`] and a larger buffer.
    ///
    /// # Errors
    ///
    /// This function returns an error if the token cannot be converted to a string. For more details, refer to
    /// [`TokenToStringError`].
    ///
    /// # Parameters
    ///
    /// - `token`: The `LlamaToken` to convert.
    /// - `special`: The `Special` value used to handle special tokens.
    pub fn token_to_str(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        self.token_to_str_with_size(token, 32, special)
    }
665
    /// Convert a single token to bytes.
    ///
    /// This function converts a `LlamaToken` into a byte representation, using
    /// a 32-byte scratch buffer and no left-strip. Pieces longer than 32 bytes
    /// error; use [`Self::token_to_bytes_with_size`] for a custom buffer size.
    ///
    /// # Errors
    ///
    /// This function returns an error if the token cannot be converted to bytes. For more details, refer to
    /// [`TokenToStringError`].
    ///
    /// # Parameters
    ///
    /// - `token`: The `LlamaToken` to convert.
    /// - `special`: The `Special` value used to handle special tokens.
    pub fn token_to_bytes(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<Vec<u8>, TokenToStringError> {
        self.token_to_bytes_with_size(token, 32, special, None)
    }
686
687 /// Convert a vector of tokens to a single string.
688 ///
689 /// This function takes a slice of `LlamaToken`s and converts them into a single string, concatenating their
690 /// string representations.
691 ///
692 /// # Errors
693 ///
694 /// This function returns an error if any token cannot be converted to a string. For more details, refer to
695 /// [`TokenToStringError`].
696 ///
697 /// # Parameters
698 ///
699 /// - `tokens`: A slice of `LlamaToken`s to convert.
700 /// - `special`: The `Special` value used to handle special tokens.
701 pub fn tokens_to_str(
702 &self,
703 tokens: &[LlamaToken],
704 special: Special,
705 ) -> Result<String, TokenToStringError> {
706 let mut builder = String::with_capacity(tokens.len() * 4);
707 for str in tokens
708 .iter()
709 .copied()
710 .map(|t| self.token_to_str(t, special))
711 {
712 builder += &str?;
713 }
714 Ok(builder)
715 }
716
    /// Convert a string to a vector of tokens.
    ///
    /// This function converts a string into a vector of `LlamaToken`s. The function will tokenize the string
    /// and return the corresponding tokens.
    ///
    /// # Errors
    ///
    /// - This function will return an error if the input string contains a null byte.
    ///
    /// # Panics
    ///
    /// - This function will panic if the number of tokens exceeds `usize::MAX`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use std::path::Path;
    /// use llama_cpp_4::model::AddBos;
    /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Heuristic: roughly one token per two bytes of input, minimum 8.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        // First pass: tokenize into the estimated buffer. llama.cpp returns the
        // negated required token count when the buffer is too small.
        let size = unsafe {
            llama_tokenize(
                self.get_vocab().vocab.as_ref(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // if we fail the first time we can resize the vector to the correct size and try again. This should never fail.
        // as a result - size is guaranteed to be positive here.
        let size = if size.is_negative() {
            // Grow to exactly the required capacity reported by the first pass
            // (len is still 0, so reserve_exact sets capacity to at least -size).
            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
            unsafe {
                llama_tokenize(
                    self.get_vocab().vocab.as_ref(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and usize ");

        // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
        unsafe { buffer.set_len(size) }
        Ok(buffer.into_iter().map(LlamaToken).collect())
    }
798
    /// Get the attributes of a token.
    ///
    /// This function retrieves the attributes associated with a given token. The attributes are typically used to
    /// understand whether the token represents a special type of token (e.g., beginning-of-sequence (BOS), end-of-sequence (EOS),
    /// control tokens, etc.).
    ///
    /// # Panics
    ///
    /// - This function will panic if the token type is unknown or cannot be converted to a valid `LlamaTokenAttrs`.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// use llama_cpp_4::model::{LlamaModel, LlamaToken};
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = LlamaModel::load_from_file("path/to/model")?;
    /// let token = LlamaToken(42);
    /// let token_attrs = model.token_attr(token);
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        let token_type = unsafe { llama_token_get_attr(self.get_vocab().vocab.as_ref(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }
826
    /// Detokenize a slice of tokens into a string.
    ///
    /// This is the inverse of [`str_to_token`](Self::str_to_token).
    ///
    /// # Parameters
    ///
    /// - `tokens`: The tokens to detokenize.
    /// - `remove_special`: If `true`, special tokens are removed from the output.
    /// - `unparse_special`: If `true`, special tokens are rendered as their text representation.
    ///
    /// # Errors
    ///
    /// Returns an error if the detokenized text is not valid UTF-8.
    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap, clippy::cast_sign_loss)]
    pub fn detokenize(
        &self,
        tokens: &[LlamaToken],
        remove_special: bool,
        unparse_special: bool,
    ) -> Result<String, StringFromModelError> {
        // First call with empty buffer to get required size
        let n_tokens = tokens.len() as i32;
        // LlamaToken is repr(transparent) over llama_token, so the slice
        // pointer can be handed to C directly.
        let token_ptr = tokens.as_ptr().cast::<llama_cpp_sys_4::llama_token>();
        let needed = unsafe {
            llama_detokenize(
                self.get_vocab().vocab.as_ref(),
                token_ptr,
                n_tokens,
                std::ptr::null_mut(),
                0,
                remove_special,
                unparse_special,
            )
        };
        // llama_detokenize returns negative required size when buffer is too small
        let buf_size = if needed < 0 { (-needed) as usize } else { needed as usize };
        let mut buf = vec![0u8; buf_size];
        // Second pass: the buffer is now exactly large enough, so a negative
        // return here indicates a genuine error rather than truncation.
        let ret = unsafe {
            llama_detokenize(
                self.get_vocab().vocab.as_ref(),
                token_ptr,
                n_tokens,
                buf.as_mut_ptr().cast::<c_char>(),
                buf_size as i32,
                remove_special,
                unparse_special,
            )
        };
        if ret < 0 {
            return Err(StringFromModelError::ReturnedError(ret));
        }
        let len = ret as usize;
        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
        Ok(s.to_owned())
    }
882
883 /// Convert a token to a string with a specified buffer size.
884 ///
885 /// This function allows you to convert a token into a string, with the ability to specify a buffer size for the operation.
886 /// It is generally recommended to use `LlamaModel::token_to_str` instead, as 8 bytes is typically sufficient for most tokens,
887 /// and the extra buffer size doesn't usually matter.
888 ///
889 /// # Errors
890 ///
891 /// - If the token type is unknown, an error will be returned.
892 /// - If the resultant token exceeds the provided `buffer_size`, an error will occur.
893 /// - If the token string returned by `llama-cpp` is not valid UTF-8, it will return an error.
894 ///
895 /// # Panics
896 ///
897 /// - This function will panic if the `buffer_size` does not fit into a `c_int`.
898 /// - It will also panic if the size returned from `llama-cpp` does not fit into a `usize`, which should typically never happen.
899 ///
900 /// # Example
901 ///
902 /// ```no_run
903 /// use llama_cpp_4::model::{LlamaModel, LlamaToken};
904 ///
905 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
906 /// let model = LlamaModel::load_from_file("path/to/model")?;
907 /// let token = LlamaToken(42);
908 /// let token_string = model.token_to_str_with_size(token, 32, Special::Plaintext)?;
909 /// # Ok(())
910 /// # }
911 /// ```
912 pub fn token_to_str_with_size(
913 &self,
914 token: LlamaToken,
915 buffer_size: usize,
916 special: Special,
917 ) -> Result<String, TokenToStringError> {
918 let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
919 Ok(String::from_utf8(bytes)?)
920 }
921
922 /// Convert a token to bytes with a specified buffer size.
923 ///
924 /// Generally you should use [`LlamaModel::token_to_bytes`] instead as 8 bytes is enough for most words and
925 /// the extra bytes do not really matter.
926 ///
927 /// # Errors
928 ///
929 /// - if the token type is unknown
930 /// - the resultant token is larger than `buffer_size`.
931 ///
932 /// # Panics
933 ///
934 /// - This function will panic if `buffer_size` cannot fit into a `c_int`.
935 /// - It will also panic if the size returned from `llama-cpp` cannot be converted to `usize` (which should not happen).
936 ///
937 /// # Example
938 ///
939 /// ```no_run
940 /// use llama_cpp_4::model::{LlamaModel, LlamaToken};
941 ///
942 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
943 /// let model = LlamaModel::load_from_file("path/to/model")?;
944 /// let token = LlamaToken(42);
945 /// let token_bytes = model.token_to_bytes_with_size(token, 32, Special::Plaintext, None)?;
946 /// # Ok(())
947 /// # }
948 /// ```
949 pub fn token_to_bytes_with_size(
950 &self,
951 token: LlamaToken,
952 buffer_size: usize,
953 special: Special,
954 lstrip: Option<NonZeroU16>,
955 ) -> Result<Vec<u8>, TokenToStringError> {
956 if token == self.token_nl() {
957 return Ok(String::from("\n").into_bytes());
958 }
959
960 // unsure what to do with this in the face of the 'special' arg + attr changes
961 let attrs = self.token_attr(token);
962 if (attrs.contains(LlamaTokenAttr::Control)
963 && (token == self.token_bos() || token == self.token_eos()))
964 || attrs.is_empty()
965 || attrs
966 .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
967 {
968 return Ok(Vec::new());
969 }
970
971 let special = match special {
972 Special::Tokenize => true,
973 Special::Plaintext => false,
974 };
975
976 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
977 let len = string.as_bytes().len();
978 let len = c_int::try_from(len).expect("length fits into c_int");
979 let buf = string.into_raw();
980 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
981 let size = unsafe {
982 llama_token_to_piece(
983 self.get_vocab().vocab.as_ref(),
984 token.0,
985 buf,
986 len,
987 lstrip,
988 special,
989 )
990 };
991
992 match size {
993 0 => Err(TokenToStringError::UnknownTokenType),
994 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
995 size => {
996 let string = unsafe { CString::from_raw(buf) };
997 let mut bytes = string.into_bytes();
998 let len = usize::try_from(size).expect("size is positive and fits into usize");
999 bytes.truncate(len);
1000 Ok(bytes)
1001 }
1002 }
1003 }
1004 /// The number of tokens the model was trained on.
1005 ///
1006 /// This function returns the number of tokens the model was trained on. It is returned as a `c_int` for maximum
1007 /// compatibility with the underlying llama-cpp library, though it can typically be cast to an `i32` without issue.
1008 ///
1009 /// # Example
1010 ///
1011 /// ```no_run
1012 /// use llama_cpp_4::model::LlamaModel;
1013 ///
1014 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1015 /// let model = LlamaModel::load_from_file("path/to/model")?;
1016 /// let n_vocab = model.n_vocab();
1017 /// # Ok(())
1018 /// # }
1019 /// ```
1020 #[must_use]
1021 pub fn n_vocab(&self) -> i32 {
1022 unsafe { llama_n_vocab(self.get_vocab().vocab.as_ref()) }
1023 }
1024
1025 /// The type of vocab the model was trained on.
1026 ///
1027 /// This function returns the type of vocabulary used by the model, such as whether it is based on byte-pair encoding (BPE),
1028 /// word-level tokens, or another tokenization scheme.
1029 ///
1030 /// # Panics
1031 ///
1032 /// - This function will panic if `llama-cpp` emits a vocab type that is not recognized or is invalid for this library.
1033 ///
1034 /// # Example
1035 ///
1036 /// ```no_run
1037 /// use llama_cpp_4::model::LlamaModel;
1038 ///
1039 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1040 /// let model = LlamaModel::load_from_file("path/to/model")?;
1041 /// let vocab_type = model.vocab_type();
1042 /// # Ok(())
1043 /// # }
1044 /// ```
1045 #[must_use]
1046 pub fn vocab_type(&self) -> VocabType {
1047 let vocab_type = unsafe { llama_vocab_type(self.get_vocab().vocab.as_ref()) };
1048 VocabType::try_from(vocab_type).expect("invalid vocab type")
1049 }
1050
1051 /// Returns the number of embedding dimensions for the model.
1052 ///
1053 /// This function retrieves the number of embeddings (or embedding dimensions) used by the model. It is typically
1054 /// used for analyzing model architecture and setting up context parameters or other model configuration aspects.
1055 ///
1056 /// # Panics
1057 ///
1058 /// - This function may panic if the underlying `llama-cpp` library returns an invalid embedding dimension value.
1059 ///
1060 /// # Example
1061 ///
1062 /// ```no_run
1063 /// use llama_cpp_4::model::LlamaModel;
1064 ///
1065 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1066 /// let model = LlamaModel::load_from_file("path/to/model")?;
1067 /// let n_embd = model.n_embd();
1068 /// # Ok(())
1069 /// # }
1070 /// ```
1071 #[must_use]
1072 pub fn n_embd(&self) -> c_int {
1073 unsafe { llama_n_embd(self.model.as_ptr()) }
1074 }
1075
1076 /// Get the number of transformer layers in the model.
1077 #[must_use]
1078 pub fn n_layer(&self) -> c_int {
1079 unsafe { llama_n_layer(self.model.as_ptr()) }
1080 }
1081
1082 /// Get the number of attention heads in the model.
1083 #[must_use]
1084 pub fn n_head(&self) -> c_int {
1085 unsafe { llama_n_head(self.model.as_ptr()) }
1086 }
1087
1088 /// Get the number of key-value attention heads in the model.
1089 #[must_use]
1090 pub fn n_head_kv(&self) -> c_int {
1091 unsafe { llama_model_n_head_kv(self.model.as_ptr()) }
1092 }
1093
1094 /// Get the input embedding size of the model.
1095 #[must_use]
1096 pub fn n_embd_inp(&self) -> c_int {
1097 unsafe { llama_model_n_embd_inp(self.model.as_ptr()) }
1098 }
1099
1100 /// Get the output embedding size of the model.
1101 #[must_use]
1102 pub fn n_embd_out(&self) -> c_int {
1103 unsafe { llama_model_n_embd_out(self.model.as_ptr()) }
1104 }
1105
1106 /// Get the sliding window attention size of the model.
1107 /// Returns 0 if the model does not use sliding window attention.
1108 #[must_use]
1109 pub fn n_swa(&self) -> c_int {
1110 unsafe { llama_model_n_swa(self.model.as_ptr()) }
1111 }
1112
1113 /// Get the `RoPE` type used by the model.
1114 #[must_use]
1115 pub fn rope_type(&self) -> i32 {
1116 unsafe { llama_model_rope_type(self.model.as_ptr()) }
1117 }
1118
1119 /// Get the `RoPE` frequency scale used during training.
1120 #[must_use]
1121 pub fn rope_freq_scale_train(&self) -> f32 {
1122 unsafe { llama_model_rope_freq_scale_train(self.model.as_ptr()) }
1123 }
1124
1125 /// Get the model size in bytes.
1126 #[must_use]
1127 pub fn model_size(&self) -> u64 {
1128 unsafe { llama_model_size(self.model.as_ptr()) }
1129 }
1130
1131 /// Get the number of parameters in the model.
1132 #[must_use]
1133 pub fn n_params(&self) -> u64 {
1134 unsafe { llama_model_n_params(self.model.as_ptr()) }
1135 }
1136
1137 /// Get the number of classification outputs.
1138 #[must_use]
1139 pub fn n_cls_out(&self) -> u32 {
1140 unsafe { llama_model_n_cls_out(self.model.as_ptr()) }
1141 }
1142
1143 /// Get the classification label for the given index.
1144 ///
1145 /// # Errors
1146 ///
1147 /// Returns an error if the label is null or not valid UTF-8.
1148 pub fn cls_label(&self, index: u32) -> Result<&str, StringFromModelError> {
1149 let ptr = unsafe { llama_model_cls_label(self.model.as_ptr(), index) };
1150 if ptr.is_null() {
1151 return Err(StringFromModelError::ReturnedError(-1));
1152 }
1153 let cstr = unsafe { CStr::from_ptr(ptr) };
1154 cstr.to_str().map_err(StringFromModelError::Utf8Error)
1155 }
1156
1157 /// Get the number of metadata key-value pairs.
1158 #[must_use]
1159 pub fn meta_count(&self) -> c_int {
1160 unsafe { llama_model_meta_count(self.model.as_ptr()) }
1161 }
1162
1163 /// Get a model description string.
1164 ///
1165 /// The `buf_size` parameter specifies the maximum buffer size for the description.
1166 /// A default of 256 bytes is usually sufficient.
1167 ///
1168 /// # Errors
1169 ///
1170 /// Returns an error if the description could not be retrieved or is not valid UTF-8.
1171 #[allow(clippy::cast_sign_loss)]
1172 pub fn desc(&self, buf_size: usize) -> Result<String, StringFromModelError> {
1173 let mut buf = vec![0u8; buf_size];
1174 let ret = unsafe {
1175 llama_model_desc(
1176 self.model.as_ptr(),
1177 buf.as_mut_ptr().cast::<c_char>(),
1178 buf_size,
1179 )
1180 };
1181 if ret < 0 {
1182 return Err(StringFromModelError::ReturnedError(ret));
1183 }
1184 let len = ret as usize;
1185 let s = std::str::from_utf8(&buf[..len])
1186 .map_err(StringFromModelError::Utf8Error)?;
1187 Ok(s.to_owned())
1188 }
1189
1190 /// Get a metadata key by index.
1191 ///
1192 /// The `buf_size` parameter specifies the maximum buffer size for the key.
1193 /// A default of 256 bytes is usually sufficient.
1194 ///
1195 /// # Errors
1196 ///
1197 /// Returns an error if the index is out of range or the key is not valid UTF-8.
1198 #[allow(clippy::cast_sign_loss)]
1199 pub fn meta_key_by_index(&self, index: i32, buf_size: usize) -> Result<String, StringFromModelError> {
1200 let mut buf = vec![0u8; buf_size];
1201 let ret = unsafe {
1202 llama_model_meta_key_by_index(
1203 self.model.as_ptr(),
1204 index,
1205 buf.as_mut_ptr().cast::<c_char>(),
1206 buf_size,
1207 )
1208 };
1209 if ret < 0 {
1210 return Err(StringFromModelError::ReturnedError(ret));
1211 }
1212 let len = ret as usize;
1213 let s = std::str::from_utf8(&buf[..len])
1214 .map_err(StringFromModelError::Utf8Error)?;
1215 Ok(s.to_owned())
1216 }
1217
1218 /// Get a metadata value string by index.
1219 ///
1220 /// The `buf_size` parameter specifies the maximum buffer size for the value.
1221 /// Values can be large (e.g. chat templates, token lists), so 4096+ may be needed.
1222 ///
1223 /// # Errors
1224 ///
1225 /// Returns an error if the index is out of range or the value is not valid UTF-8.
1226 #[allow(clippy::cast_sign_loss)]
1227 pub fn meta_val_str_by_index(&self, index: i32, buf_size: usize) -> Result<String, StringFromModelError> {
1228 let mut buf = vec![0u8; buf_size];
1229 let ret = unsafe {
1230 llama_model_meta_val_str_by_index(
1231 self.model.as_ptr(),
1232 index,
1233 buf.as_mut_ptr().cast::<c_char>(),
1234 buf_size,
1235 )
1236 };
1237 if ret < 0 {
1238 return Err(StringFromModelError::ReturnedError(ret));
1239 }
1240 let len = ret as usize;
1241 let s = std::str::from_utf8(&buf[..len])
1242 .map_err(StringFromModelError::Utf8Error)?;
1243 Ok(s.to_owned())
1244 }
1245
1246 /// Get a metadata value by key name.
1247 ///
1248 /// This is more convenient than iterating metadata by index when you know the key.
1249 /// The `buf_size` parameter specifies the maximum buffer size for the value.
1250 ///
1251 /// # Errors
1252 ///
1253 /// Returns an error if the key is not found, contains a null byte, or the value is not valid UTF-8.
1254 #[allow(clippy::cast_sign_loss)]
1255 pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
1256 let c_key = CString::new(key)
1257 .map_err(|_| StringFromModelError::ReturnedError(-1))?;
1258 let mut buf = vec![0u8; buf_size];
1259 let ret = unsafe {
1260 llama_model_meta_val_str(
1261 self.model.as_ptr(),
1262 c_key.as_ptr(),
1263 buf.as_mut_ptr().cast::<c_char>(),
1264 buf_size,
1265 )
1266 };
1267 if ret < 0 {
1268 return Err(StringFromModelError::ReturnedError(ret));
1269 }
1270 let len = ret as usize;
1271 let s = std::str::from_utf8(&buf[..len])
1272 .map_err(StringFromModelError::Utf8Error)?;
1273 Ok(s.to_owned())
1274 }
1275
1276 /// Get all metadata as a list of `(key, value)` pairs.
1277 ///
1278 /// This is a convenience method that iterates over all metadata entries.
1279 /// Keys use a buffer of 256 bytes and values use 4096 bytes.
1280 /// For values that may be larger (e.g. token lists), use
1281 /// [`meta_val_str_by_index`](Self::meta_val_str_by_index) directly with a larger buffer.
1282 ///
1283 /// # Errors
1284 ///
1285 /// Returns an error if any key or value cannot be read or is not valid UTF-8.
1286 #[allow(clippy::cast_sign_loss)]
1287 pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
1288 let count = self.meta_count();
1289 let mut result = Vec::with_capacity(count as usize);
1290 for i in 0..count {
1291 let key = self.meta_key_by_index(i, 256)?;
1292 let val = self.meta_val_str_by_index(i, 4096)?;
1293 result.push((key, val));
1294 }
1295 Ok(result)
1296 }
1297
1298 /// Check if the model has an encoder.
1299 #[must_use]
1300 pub fn has_encoder(&self) -> bool {
1301 unsafe { llama_model_has_encoder(self.model.as_ptr()) }
1302 }
1303
1304 /// Check if the model has a decoder.
1305 #[must_use]
1306 pub fn has_decoder(&self) -> bool {
1307 unsafe { llama_model_has_decoder(self.model.as_ptr()) }
1308 }
1309
1310 /// Check if the model is recurrent (e.g. Mamba, RWKV).
1311 #[must_use]
1312 pub fn is_recurrent(&self) -> bool {
1313 unsafe { llama_model_is_recurrent(self.model.as_ptr()) }
1314 }
1315
1316 /// Check if the model is a hybrid model.
1317 #[must_use]
1318 pub fn is_hybrid(&self) -> bool {
1319 unsafe { llama_model_is_hybrid(self.model.as_ptr()) }
1320 }
1321
1322 /// Check if the model is a diffusion model.
1323 #[must_use]
1324 pub fn is_diffusion(&self) -> bool {
1325 unsafe { llama_model_is_diffusion(self.model.as_ptr()) }
1326 }
1327
1328 /// Get chat template from model.
1329 ///
1330 /// # Errors
1331 ///
1332 /// - If the model does not have a chat template, it will return an error.
1333 /// - If the chat template is not a valid `CString`, it will return an error.
1334 ///
1335 /// # Example
1336 ///
1337 /// ```no_run
1338 /// use llama_cpp_4::model::LlamaModel;
1339 ///
1340 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1341 /// let model = LlamaModel::load_from_file("path/to/model")?;
1342 /// let chat_template = model.get_chat_template(1024)?;
1343 /// # Ok(())
1344 /// # }
1345 /// ```
1346 #[allow(clippy::missing_panics_doc)] // We statically know this will not panic as long as the buffer size is sufficient
1347 pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
1348 // longest known template is about 1200 bytes from llama.cpp
1349 let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
1350 let chat_ptr = chat_temp.into_raw();
1351 let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
1352
1353 let ret = unsafe {
1354 llama_model_meta_val_str(self.model.as_ptr(), chat_name.as_ptr(), chat_ptr, buf_size)
1355 };
1356
1357 if ret < 0 {
1358 return Err(ChatTemplateError::MissingTemplate(ret));
1359 }
1360
1361 let template_c = unsafe { CString::from_raw(chat_ptr) };
1362 let template = template_c.to_str()?;
1363
1364 let ret: usize = ret.try_into().unwrap();
1365 if template.len() < ret {
1366 return Err(ChatTemplateError::BuffSizeError(ret + 1));
1367 }
1368
1369 Ok(template.to_owned())
1370 }
1371
1372 /// Loads a model from a file.
1373 ///
1374 /// This function loads a model from a specified file path and returns the corresponding `LlamaModel` instance.
1375 ///
1376 /// # Errors
1377 ///
1378 /// - If the path cannot be converted to a string or if the model file does not exist, it will return an error.
1379 /// - If the model cannot be loaded (e.g., due to an invalid or corrupted model file), it will return a `LlamaModelLoadError`.
1380 ///
1381 /// # Example
1382 ///
1383 /// ```no_run
1384 /// use llama_cpp_4::model::LlamaModel;
1385 /// use std::path::Path;
1386 ///
1387 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1388 /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
1389 /// # Ok(())
1390 /// # }
1391 /// ```
1392 #[tracing::instrument(skip_all, fields(params))]
1393 pub fn load_from_file(
1394 _: &LlamaBackend,
1395 path: impl AsRef<Path>,
1396 params: &LlamaModelParams,
1397 ) -> Result<Self, LlamaModelLoadError> {
1398 let path = path.as_ref();
1399 debug_assert!(
1400 Path::new(path).exists(),
1401 "{} does not exist",
1402 path.display()
1403 );
1404 let path = path
1405 .to_str()
1406 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1407
1408 let cstr = CString::new(path)?;
1409 let llama_model = unsafe { llama_load_model_from_file(cstr.as_ptr(), params.params) };
1410
1411 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1412
1413 tracing::debug!(?path, "Loaded model");
1414 Ok(LlamaModel { model })
1415 }
1416
1417 /// Load a model from multiple split files.
1418 ///
1419 /// This function loads a model that has been split across multiple files. This is useful for
1420 /// very large models that exceed filesystem limitations or need to be distributed across
1421 /// multiple storage devices.
1422 ///
1423 /// # Arguments
1424 ///
1425 /// * `paths` - A slice of paths to the split model files
1426 /// * `params` - The model parameters
1427 ///
1428 /// # Errors
1429 ///
1430 /// Returns an error if:
1431 /// - Any of the paths cannot be converted to a C string
1432 /// - The model fails to load from the splits
1433 /// - Any path doesn't exist or isn't accessible
1434 ///
1435 /// # Example
1436 ///
1437 /// ```no_run
1438 /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
1439 /// use llama_cpp_4::llama_backend::LlamaBackend;
1440 ///
1441 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1442 /// let backend = LlamaBackend::init()?;
1443 /// let params = LlamaModelParams::default();
1444 ///
1445 /// let paths = vec![
1446 /// "model-00001-of-00003.gguf",
1447 /// "model-00002-of-00003.gguf",
1448 /// "model-00003-of-00003.gguf",
1449 /// ];
1450 ///
1451 /// let model = LlamaModel::load_from_splits(&backend, &paths, ¶ms)?;
1452 /// # Ok(())
1453 /// # }
1454 /// ```
1455 #[tracing::instrument(skip_all)]
1456 pub fn load_from_splits(
1457 _: &LlamaBackend,
1458 paths: &[impl AsRef<Path>],
1459 params: &LlamaModelParams,
1460 ) -> Result<Self, LlamaModelLoadError> {
1461 // Convert paths to C strings
1462 let c_strings: Vec<CString> = paths
1463 .iter()
1464 .map(|p| {
1465 let path = p.as_ref();
1466 debug_assert!(path.exists(), "{} does not exist", path.display());
1467 let path_str = path
1468 .to_str()
1469 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1470 CString::new(path_str).map_err(LlamaModelLoadError::from)
1471 })
1472 .collect::<Result<Vec<_>, _>>()?;
1473
1474 // Create array of pointers to C strings
1475 let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
1476
1477 // Load the model from splits
1478 let llama_model = unsafe {
1479 llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
1480 };
1481
1482 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1483
1484 tracing::debug!("Loaded model from {} splits", paths.len());
1485 Ok(LlamaModel { model })
1486 }
1487
1488 /// Load a model from a `FILE` pointer.
1489 ///
1490 /// # Safety
1491 ///
1492 /// The `file` pointer must be a valid, open `FILE*`.
1493 ///
1494 /// # Errors
1495 ///
1496 /// Returns an error if the model cannot be loaded.
1497 pub unsafe fn load_from_file_ptr(
1498 file: *mut llama_cpp_sys_4::FILE,
1499 params: &LlamaModelParams,
1500 ) -> Result<Self, LlamaModelLoadError> {
1501 let model = llama_cpp_sys_4::llama_model_load_from_file_ptr(file, params.params);
1502 let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1503 Ok(LlamaModel { model })
1504 }
1505
1506 /// Initialize a model from user-provided data.
1507 ///
1508 /// # Safety
1509 ///
1510 /// The metadata, callback, and user data must be valid.
1511 ///
1512 /// # Errors
1513 ///
1514 /// Returns an error if the model cannot be initialized.
1515 pub unsafe fn init_from_user(
1516 metadata: *mut llama_cpp_sys_4::gguf_context,
1517 set_tensor_data: llama_cpp_sys_4::llama_model_set_tensor_data_t,
1518 set_tensor_data_ud: *mut std::ffi::c_void,
1519 params: &LlamaModelParams,
1520 ) -> Result<Self, LlamaModelLoadError> {
1521 let model = llama_cpp_sys_4::llama_model_init_from_user(
1522 metadata,
1523 set_tensor_data,
1524 set_tensor_data_ud,
1525 params.params,
1526 );
1527 let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1528 Ok(LlamaModel { model })
1529 }
1530
    /// Save the model to a file.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn save_to_file(&self, path: impl AsRef<Path>) {
        let path = path.as_ref();
        let path_str = path.to_str().expect("path is not valid UTF-8");
        let c_path = CString::new(path_str).expect("path contains null bytes");
        // SAFETY: `c_path` outlives the call and the model pointer is valid.
        unsafe {
            llama_model_save_to_file(self.model.as_ptr(), c_path.as_ptr());
        }
    }
1544
1545 /// Get the list of built-in chat templates.
1546 ///
1547 /// Returns the names of all chat templates that are built into llama.cpp.
1548 ///
1549 /// # Panics
1550 ///
1551 /// Panics if any template name is not valid UTF-8.
1552 #[allow(clippy::cast_sign_loss)]
1553 #[must_use]
1554 pub fn chat_builtin_templates() -> Vec<String> {
1555 // First call to get count
1556 let count = unsafe { llama_chat_builtin_templates(std::ptr::null_mut(), 0) };
1557 if count <= 0 {
1558 return Vec::new();
1559 }
1560 let count = count as usize;
1561 let mut ptrs: Vec<*const c_char> = vec![std::ptr::null(); count];
1562 unsafe {
1563 llama_chat_builtin_templates(ptrs.as_mut_ptr(), count);
1564 }
1565 ptrs.iter()
1566 .map(|&p| {
1567 let cstr = unsafe { CStr::from_ptr(p) };
1568 cstr.to_str()
1569 .expect("template name is not valid UTF-8")
1570 .to_owned()
1571 })
1572 .collect()
1573 }
1574
1575 /// Initializes a lora adapter from a file.
1576 ///
1577 /// This function initializes a Lora adapter, which is a model extension used to adapt or fine-tune the existing model
1578 /// to a specific domain or task. The adapter file is typically in the form of a binary or serialized file that can be applied
1579 /// to the model for improved performance on specialized tasks.
1580 ///
1581 /// # Errors
1582 ///
1583 /// - If the adapter file path cannot be converted to a string or if the adapter cannot be initialized, it will return an error.
1584 ///
1585 /// # Example
1586 ///
1587 /// ```no_run
1588 /// use llama_cpp_4::model::{LlamaModel, LlamaLoraAdapter};
1589 /// use std::path::Path;
1590 ///
1591 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1592 /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
1593 /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
1594 /// # Ok(())
1595 /// # }
1596 /// ```
1597 pub fn lora_adapter_init(
1598 &self,
1599 path: impl AsRef<Path>,
1600 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
1601 let path = path.as_ref();
1602 debug_assert!(
1603 Path::new(path).exists(),
1604 "{} does not exist",
1605 path.display()
1606 );
1607
1608 let path = path
1609 .to_str()
1610 .ok_or(LlamaLoraAdapterInitError::PathToStrError(
1611 path.to_path_buf(),
1612 ))?;
1613
1614 let cstr = CString::new(path)?;
1615 let adapter = unsafe { llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
1616
1617 let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
1618
1619 tracing::debug!(?path, "Initialized lora adapter");
1620 Ok(LlamaLoraAdapter {
1621 lora_adapter: adapter,
1622 })
1623 }
1624
    /// Create a new context from this model.
    ///
    /// The context manages state for inference (token generation, embeddings,
    /// etc.) and is configured by `params`.
    ///
    /// # Errors
    ///
    /// - There are various potential failures such as invalid parameters or a failure to allocate the context. See [`LlamaContextLoadError`]
    ///   for more detailed error descriptions.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// # fn example(model: &LlamaModel, backend: &LlamaBackend) -> Result<(), Box<dyn std::error::Error>> {
    /// let context = model.new_context(backend, LlamaContextParams::default())?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Notes
    ///
    /// The TurboQuant attn-rotation preference is communicated to the C++ side
    /// via the `LLAMA_ATTN_ROT_DISABLE` environment variable, which is set
    /// before the context is created and restored afterwards.
    /// NOTE(review): process-wide env-var mutation is presumably not
    /// thread-safe if contexts are created concurrently — confirm callers
    /// serialize context creation.
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context(
        &self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
        // Apply TurboQuant attn-rotation preference before the KV cache is
        // initialised inside llama_new_context_with_model.
        // Remember the previous value so it can be restored below.
        let prev_rot_var = std::env::var("LLAMA_ATTN_ROT_DISABLE").ok();
        if params.attn_rot_disabled {
            // SAFETY: we restore the value right after the call.
            #[allow(unused_unsafe)]
            unsafe {
                std::env::set_var("LLAMA_ATTN_ROT_DISABLE", "1");
            }
        } else if std::env::var("LLAMA_ATTN_ROT_DISABLE").is_ok() {
            // params say "enabled" – only clear if it was previously unset
            // (respect explicit user env var).
            // NOTE(review): this branch is deliberately a no-op — an
            // explicitly-set user env var is left untouched.
        }

        let context_params = params.context_params;
        // SAFETY: the model pointer is valid; `context_params` is passed by value.
        let context = unsafe { llama_new_context_with_model(self.model.as_ptr(), context_params) };

        // Restore the env-var to its previous state.
        #[allow(unused_unsafe)]
        match prev_rot_var {
            Some(v) => unsafe { std::env::set_var("LLAMA_ATTN_ROT_DISABLE", v) },
            None if params.attn_rot_disabled => unsafe {
                std::env::remove_var("LLAMA_ATTN_ROT_DISABLE");
            },
            None => {}
        }

        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
        Ok(LlamaContext::new(self, context, params.embeddings()))
    }
1684
1685 /// Apply the model's chat template to a sequence of messages.
1686 ///
1687 /// This function applies the model's chat template to the provided chat messages, formatting them accordingly. The chat
1688 /// template determines the structure or style of conversation between the system and user, such as token formatting,
1689 /// role separation, and more. The template can be customized by providing an optional template string, or if `None`
1690 /// is provided, the default template used by `llama.cpp` will be applied.
1691 ///
1692 /// For more information on supported templates, visit:
1693 /// <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
1694 ///
1695 /// # Arguments
1696 ///
1697 /// - `tmpl`: An optional custom template string. If `None`, the default template will be used.
1698 /// - `chat`: A vector of `LlamaChatMessage` instances, which represent the conversation between the system and user.
1699 /// - `add_ass`: A boolean flag indicating whether additional system-specific instructions (like "assistant") should be included.
1700 ///
1701 /// # Errors
1702 ///
1703 /// There are several possible points of failure when applying the chat template:
1704 /// - Insufficient buffer size to hold the formatted chat (this will return `ApplyChatTemplateError::BuffSizeError`).
1705 /// - If the template or messages cannot be processed properly, various errors from `ApplyChatTemplateError` may occur.
1706 ///
1707 /// # Example
1708 ///
1709 /// ```no_run
1710 /// use llama_cpp_4::model::{LlamaModel, LlamaChatMessage};
1711 ///
1712 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1713 /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
1714 /// let chat = vec![
1715 /// LlamaChatMessage::new("user", "Hello!"),
1716 /// LlamaChatMessage::new("assistant", "Hi! How can I assist you today?"),
1717 /// ];
1718 /// let formatted_chat = model.apply_chat_template(None, chat, true)?;
1719 /// # Ok(())
1720 /// # }
1721 /// ```
1722 ///
1723 /// # Notes
1724 ///
1725 /// The provided buffer is twice the length of the messages by default, which is recommended by the `llama.cpp` documentation.
1726 /// # Panics
1727 ///
1728 /// Panics if the buffer length exceeds `i32::MAX`.
1729 #[tracing::instrument(skip_all)]
1730 pub fn apply_chat_template(
1731 &self,
1732 tmpl: Option<&str>,
1733 chat: &[LlamaChatMessage],
1734 add_ass: bool,
1735 ) -> Result<String, ApplyChatTemplateError> {
1736 // Compute raw message byte total from the original LlamaChatMessage vec
1737 // *before* we shadow `chat` with the sys-type vec below.
1738 let message_length = chat.iter().fold(0usize, |acc, c| {
1739 acc + c.role.to_bytes().len() + c.content.to_bytes().len()
1740 });
1741
1742 // Build our llama_cpp_sys chat messages (raw pointers into CStrings).
1743 let chat_sys: Vec<llama_chat_message> = chat
1744 .iter()
1745 .map(|c| llama_chat_message {
1746 role: c.role.as_ptr(),
1747 content: c.content.as_ptr(),
1748 })
1749 .collect();
1750
1751 // Set the tmpl pointer.
1752 let tmpl_cstring = tmpl.map(CString::new).transpose()?;
1753 let tmpl_ptr = tmpl_cstring
1754 .as_ref()
1755 .map_or(std::ptr::null(), |s| s.as_ptr());
1756
1757 // `message_length * 4` is far too small for models whose built-in chat
1758 // template adds a long default system prompt (e.g. Qwen3.5 prepends
1759 // ~80+ chars of markup even for a one-word user message). Start with
1760 // at least 4 KiB so short inputs like "hi" always have room.
1761 //
1762 // `llama_chat_apply_template` returns the number of bytes it *actually*
1763 // needed when the buffer was too small, so we retry exactly once with
1764 // that precise size rather than giving up immediately.
1765 let mut buf_size = message_length.saturating_mul(4).max(4096);
1766
1767 for _ in 0..2 {
1768 // Use u8 so that as_mut_ptr()/as_ptr() match the binding (*mut u8 / *const u8).
1769 let mut buff = vec![0u8; buf_size];
1770 let res = unsafe {
1771 llama_chat_apply_template(
1772 tmpl_ptr,
1773 chat_sys.as_ptr(),
1774 chat_sys.len(),
1775 add_ass,
1776 buff.as_mut_ptr().cast(),
1777 i32::try_from(buff.len()).expect("buffer length fits in i32"),
1778 )
1779 };
1780
1781 if res < 0 {
1782 return Err(ApplyChatTemplateError::BuffSizeError);
1783 }
1784
1785 #[allow(clippy::cast_sign_loss)]
1786 let needed = res as usize;
1787 if needed > buf_size {
1788 // Buffer was too small — retry with the exact size llama.cpp reported.
1789 buf_size = needed + 1; // +1 for null terminator
1790 continue;
1791 }
1792
1793 // SAFETY: llama_chat_apply_template wrote a NUL-terminated string
1794 // into `buff`; `needed` bytes were used.
1795 let formatted = unsafe {
1796 CStr::from_ptr(buff.as_ptr().cast())
1797 .to_string_lossy()
1798 .into_owned()
1799 };
1800 return Ok(formatted);
1801 }
1802
1803 Err(ApplyChatTemplateError::BuffSizeError)
1804 }
1805
1806 /// Build a split GGUF file path for a specific chunk.
1807 ///
1808 /// This utility function creates the standardized filename for a split model chunk
1809 /// following the pattern: `{prefix}-{split_no:05d}-of-{split_count:05d}.gguf`
1810 ///
1811 /// # Arguments
1812 ///
1813 /// * `path_prefix` - The base path and filename prefix
1814 /// * `split_no` - The split number (1-indexed)
1815 /// * `split_count` - The total number of splits
1816 ///
1817 /// # Returns
1818 ///
1819 /// Returns the formatted split path as a String
1820 ///
1821 /// # Example
1822 ///
1823 /// ```
1824 /// use llama_cpp_4::model::LlamaModel;
1825 ///
1826 /// let path = LlamaModel::split_path("/models/llama", 2, 4);
1827 /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
1828 /// ```
1829 ///
1830 /// # Panics
1831 ///
1832 /// Panics if the path prefix contains a null byte.
1833 #[must_use]
1834 pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
1835 let mut buffer = vec![0u8; 1024];
1836 let len = unsafe {
1837 llama_split_path(
1838 buffer.as_mut_ptr().cast::<c_char>(),
1839 buffer.len(),
1840 CString::new(path_prefix).unwrap().as_ptr(),
1841 split_no,
1842 split_count,
1843 )
1844 };
1845
1846 let len = usize::try_from(len).expect("split_path length fits in usize");
1847 buffer.truncate(len);
1848 String::from_utf8(buffer).unwrap_or_default()
1849 }
1850
1851 /// Extract the path prefix from a split filename.
1852 ///
1853 /// This function extracts the base path prefix from a split model filename,
1854 /// but only if the `split_no` and `split_count` match the pattern in the filename.
1855 ///
1856 /// # Arguments
1857 ///
1858 /// * `split_path` - The full path to the split file
1859 /// * `split_no` - The expected split number
1860 /// * `split_count` - The expected total number of splits
1861 ///
1862 /// # Returns
1863 ///
1864 /// Returns the path prefix if the pattern matches, or None if it doesn't
1865 ///
1866 /// # Example
1867 ///
1868 /// ```
1869 /// use llama_cpp_4::model::LlamaModel;
1870 ///
1871 /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 2, 4);
1872 /// assert_eq!(prefix, Some("/models/llama".to_string()));
1873 /// ```
1874 ///
1875 /// # Panics
1876 ///
1877 /// Panics if the split path contains a null byte.
1878 #[must_use]
1879 pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
1880 let mut buffer = vec![0u8; 1024];
1881 let len = unsafe {
1882 llama_split_prefix(
1883 buffer.as_mut_ptr().cast::<c_char>(),
1884 buffer.len(),
1885 CString::new(split_path).unwrap().as_ptr(),
1886 split_no,
1887 split_count,
1888 )
1889 };
1890
1891 if len > 0 {
1892 let len = usize::try_from(len).expect("split_prefix length fits in usize");
1893 buffer.truncate(len);
1894 String::from_utf8(buffer).ok()
1895 } else {
1896 None
1897 }
1898 }
1899}
1900
1901#[allow(clippy::cast_precision_loss)]
1902impl fmt::Display for LlamaModel {
1903 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1904 let desc = self.desc(256).unwrap_or_else(|_| "unknown".to_string());
1905 write!(
1906 f,
1907 "{desc} | {layers}L {heads}H {embd}E | {params} params | {size:.1} MiB",
1908 layers = self.n_layer(),
1909 heads = self.n_head(),
1910 embd = self.n_embd(),
1911 params = self.n_params(),
1912 size = self.model_size() as f64 / (1024.0 * 1024.0),
1913 )
1914 }
1915}
1916
impl Drop for LlamaModel {
    fn drop(&mut self) {
        // SAFETY: `self.model` is a non-null pointer owned by this wrapper
        // (presumably produced by the llama.cpp loader — see the model-loading
        // constructors earlier in this file). `drop` runs at most once, so the
        // model is freed exactly once and never accessed afterwards.
        unsafe { llama_free_model(self.model.as_ptr()) }
    }
}
1922
/// Defines the possible types of vocabulary used by the model.
///
/// The model may use different types of vocabulary depending on the tokenization method chosen during training.
/// This enum represents these types, specifically `BPE` (Byte Pair Encoding) and `SPM` (`SentencePiece`).
///
/// # Variants
///
/// - `BPE`: Byte Pair Encoding, a common tokenization method used in NLP tasks.
/// - `SPM`: `SentencePiece`, another popular tokenization method for NLP models.
///
/// # Example
///
/// ```no_run
/// use llama_cpp_4::model::VocabType;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let vocab_type = VocabType::BPE;
/// match vocab_type {
///     VocabType::BPE => println!("The model uses Byte Pair Encoding (BPE)"),
///     VocabType::SPM => println!("The model uses SentencePiece (SPM)"),
/// }
/// # Ok(())
/// # }
/// ```
// `repr(u32)` with discriminants taken from the C constants keeps this enum
// value-compatible with `llama_vocab_type` on the FFI boundary; the `as _`
// casts adapt whatever integer type bindgen chose for the constants.
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
    /// Byte Pair Encoding
    BPE = LLAMA_VOCAB_TYPE_BPE as _,
    /// Sentence Piece Tokenizer
    SPM = LLAMA_VOCAB_TYPE_SPM as _,
}
1955
/// Error that occurs when trying to convert a `llama_vocab_type` to a [`VocabType`].
///
/// This error is raised when the integer value returned by the system does not correspond to a known vocabulary type.
///
/// # Variants
///
/// - `UnknownValue`: The error is raised when the value is not a valid `llama_vocab_type`. The invalid value is returned with the error.
///
/// # Example
///
/// ```no_run
/// use llama_cpp_4::model::LlamaTokenTypeFromIntError;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let invalid_value = 999; // Not a valid vocabulary type
/// let error = LlamaTokenTypeFromIntError::UnknownValue(invalid_value);
/// println!("Error: {}", error);
/// # Ok(())
/// # }
/// ```
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
    /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
    #[error("Unknown Value {0}")]
    UnknownValue(llama_vocab_type),
}
1982
1983impl TryFrom<llama_vocab_type> for VocabType {
1984 type Error = LlamaTokenTypeFromIntError;
1985
1986 fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
1987 match value {
1988 LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
1989 LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
1990 unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
1991 }
1992 }
1993}