llama_cpp_4/model.rs
1//! A safe wrapper around `llama_model`.
2use std::ffi::CStr;
3use std::ffi::CString;
4use std::fmt;
5use std::num::NonZeroU16;
6use std::os::raw::{c_char, c_int};
7use std::path::Path;
8use std::ptr::NonNull;
9
10use llama_cpp_sys_4::{
11 llama_adapter_lora, llama_adapter_lora_init, llama_chat_apply_template,
12 llama_chat_builtin_templates, llama_chat_message, llama_detokenize, llama_init_from_model,
13 llama_model, llama_model_cls_label, llama_model_decoder_start_token, llama_model_desc,
14 llama_model_free, llama_model_get_vocab, llama_model_has_decoder, llama_model_has_encoder,
15 llama_model_is_diffusion, llama_model_is_hybrid, llama_model_is_recurrent,
16 llama_model_load_from_file, llama_model_load_from_splits, llama_model_meta_count,
17 llama_model_meta_key_by_index, llama_model_meta_val_str, llama_model_meta_val_str_by_index,
18 llama_model_n_cls_out, llama_model_n_ctx_train, llama_model_n_embd, llama_model_n_embd_inp,
19 llama_model_n_embd_out, llama_model_n_head, llama_model_n_head_kv, llama_model_n_layer,
20 llama_model_n_params, llama_model_n_swa, llama_model_rope_freq_scale_train,
21 llama_model_rope_type, llama_model_save_to_file, llama_model_size, llama_split_path,
22 llama_split_prefix, llama_token_to_piece, llama_tokenize, llama_vocab, llama_vocab_type,
23 LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
24};
25
26use crate::context::params::LlamaContextParams;
27use crate::context::LlamaContext;
28use crate::llama_backend::LlamaBackend;
29use crate::model::params::LlamaModelParams;
30use crate::token::LlamaToken;
31use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
32use crate::{
33 ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
34 LlamaModelLoadError, NewLlamaChatMessageError, StringFromModelError, StringToTokenError,
35 TokenToStringError,
36};
37
38pub mod params;
39
40/// A safe wrapper around `llama_model`.
41#[derive(Debug)]
42#[repr(transparent)]
43#[allow(clippy::module_name_repetitions)]
44pub struct LlamaModel {
45 pub(crate) model: NonNull<llama_model>,
46}
47
48/// A safe wrapper around `llama_vocab`.
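///
/// Obtained from [`LlamaModel::get_vocab`]. A minimal sketch of typical usage; the model
/// path below is a placeholder.
///
/// ```no_run
/// use llama_cpp_4::model::LlamaModel;
/// use llama_cpp_4::model::params::LlamaModelParams;
/// use llama_cpp_4::llama_backend::LlamaBackend;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let backend = LlamaBackend::init()?;
/// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
/// let vocab = model.get_vocab();
/// let _bos = vocab.bos();
/// println!("vocab has {} tokens", vocab.n_tokens());
/// # Ok(())
/// # }
/// ```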
49#[derive(Debug)]
50#[repr(transparent)]
51#[allow(clippy::module_name_repetitions)]
52pub struct LlamaVocab {
53 pub(crate) vocab: NonNull<llama_vocab>,
54}
55
56impl LlamaVocab {
57 /// Get the number of tokens in the vocabulary.
58 #[must_use]
59 pub fn n_tokens(&self) -> i32 {
60 unsafe { llama_cpp_sys_4::llama_vocab_n_tokens(self.vocab.as_ref()) }
61 }
62
    /// Get the raw vocabulary type as reported by `llama.cpp`.
    ///
    /// # Panics
    ///
    /// Panics if the underlying vocabulary type value does not fit into a `u32`.
    #[must_use]
    pub fn vocab_type(&self) -> u32 {
        unsafe {
            llama_cpp_sys_4::llama_vocab_type(self.vocab.as_ref())
                .try_into()
                .expect("vocab type fits into a u32")
        }
    }
68
69 /// Get the BOS token.
70 #[must_use]
71 pub fn bos(&self) -> LlamaToken {
72 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_bos(self.vocab.as_ref()) })
73 }
74
75 /// Get the EOS token.
76 #[must_use]
77 pub fn eos(&self) -> LlamaToken {
78 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eos(self.vocab.as_ref()) })
79 }
80
81 /// Get the EOT (end of turn) token.
82 #[must_use]
83 pub fn eot(&self) -> LlamaToken {
84 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eot(self.vocab.as_ref()) })
85 }
86
87 /// Get the CLS (classification) token.
88 #[must_use]
89 pub fn cls(&self) -> LlamaToken {
90 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_cls(self.vocab.as_ref()) })
91 }
92
93 /// Get the SEP (separator) token.
94 #[must_use]
95 pub fn sep(&self) -> LlamaToken {
96 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_sep(self.vocab.as_ref()) })
97 }
98
99 /// Get the NL (newline) token.
100 #[must_use]
101 pub fn nl(&self) -> LlamaToken {
102 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_nl(self.vocab.as_ref()) })
103 }
104
105 /// Get the PAD (padding) token.
106 #[must_use]
107 pub fn pad(&self) -> LlamaToken {
108 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_pad(self.vocab.as_ref()) })
109 }
110
111 /// Get the FIM prefix token.
112 #[must_use]
113 pub fn fim_pre(&self) -> LlamaToken {
114 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pre(self.vocab.as_ref()) })
115 }
116
117 /// Get the FIM suffix token.
118 #[must_use]
119 pub fn fim_suf(&self) -> LlamaToken {
120 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_suf(self.vocab.as_ref()) })
121 }
122
123 /// Get the FIM middle token.
124 #[must_use]
125 pub fn fim_mid(&self) -> LlamaToken {
126 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_mid(self.vocab.as_ref()) })
127 }
128
129 /// Get the FIM padding token.
130 #[must_use]
131 pub fn fim_pad(&self) -> LlamaToken {
132 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pad(self.vocab.as_ref()) })
133 }
134
135 /// Get the FIM repository token.
136 #[must_use]
137 pub fn fim_rep(&self) -> LlamaToken {
138 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_rep(self.vocab.as_ref()) })
139 }
140
141 /// Get the FIM separator token.
142 #[must_use]
143 pub fn fim_sep(&self) -> LlamaToken {
144 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_sep(self.vocab.as_ref()) })
145 }
146
147 /// Check whether BOS should be added.
148 #[must_use]
149 pub fn get_add_bos(&self) -> bool {
150 unsafe { llama_cpp_sys_4::llama_vocab_get_add_bos(self.vocab.as_ref()) }
151 }
152
153 /// Check whether EOS should be added.
154 #[must_use]
155 pub fn get_add_eos(&self) -> bool {
156 unsafe { llama_cpp_sys_4::llama_vocab_get_add_eos(self.vocab.as_ref()) }
157 }
158
159 /// Check whether SEP should be added.
160 #[must_use]
161 pub fn get_add_sep(&self) -> bool {
162 unsafe { llama_cpp_sys_4::llama_vocab_get_add_sep(self.vocab.as_ref()) }
163 }
164
165 /// Get the text representation of a token.
166 ///
167 /// # Errors
168 ///
169 /// Returns an error if the text pointer is null or not valid UTF-8.
170 pub fn get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
171 let ptr = unsafe { llama_cpp_sys_4::llama_vocab_get_text(self.vocab.as_ref(), token.0) };
172 if ptr.is_null() {
173 return Err(StringFromModelError::ReturnedError(-1));
174 }
175 let cstr = unsafe { CStr::from_ptr(ptr) };
176 cstr.to_str().map_err(StringFromModelError::Utf8Error)
177 }
178
179 /// Get the score of a token.
180 #[must_use]
181 pub fn get_score(&self, token: LlamaToken) -> f32 {
182 unsafe { llama_cpp_sys_4::llama_vocab_get_score(self.vocab.as_ref(), token.0) }
183 }
184
    /// Get the raw attribute bitfield of a token.
    ///
    /// # Panics
    ///
    /// Panics if the underlying attribute value does not fit into a `u32`.
    #[must_use]
    pub fn get_attr(&self, token: LlamaToken) -> u32 {
        unsafe {
            llama_cpp_sys_4::llama_vocab_get_attr(self.vocab.as_ref(), token.0)
                .try_into()
                .expect("token attributes fit into a u32")
        }
    }
190
191 /// Check if a token is a control token.
192 #[must_use]
193 pub fn is_control(&self, token: LlamaToken) -> bool {
194 unsafe { llama_cpp_sys_4::llama_vocab_is_control(self.vocab.as_ref(), token.0) }
195 }
196
197 /// Check if a token is an end-of-generation token.
198 #[must_use]
199 pub fn is_eog(&self, token: LlamaToken) -> bool {
200 unsafe { llama_cpp_sys_4::llama_vocab_is_eog(self.vocab.as_ref(), token.0) }
201 }
202
    /// Get the MASK token (used by masked-language models).
204 #[must_use]
205 pub fn mask(&self) -> LlamaToken {
206 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_mask(self.vocab.as_ref()) })
207 }
208}
209
210/// A safe wrapper around `llama_adapter_lora`.
211#[derive(Debug)]
212#[repr(transparent)]
213#[allow(clippy::module_name_repetitions)]
214pub struct LlamaLoraAdapter {
215 pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
216}
217
218impl LlamaLoraAdapter {
219 /// Get the number of metadata key-value pairs in the adapter.
220 #[must_use]
221 pub fn meta_count(&self) -> i32 {
222 unsafe { llama_cpp_sys_4::llama_adapter_meta_count(self.lora_adapter.as_ptr()) }
223 }
224
225 /// Get a metadata key by index.
226 ///
227 /// # Errors
228 ///
229 /// Returns an error if the index is out of range or the key is not valid UTF-8.
230 #[allow(clippy::cast_sign_loss)]
231 pub fn meta_key_by_index(
232 &self,
233 index: i32,
234 buf_size: usize,
235 ) -> Result<String, StringFromModelError> {
236 let mut buf = vec![0u8; buf_size];
237 let ret = unsafe {
238 llama_cpp_sys_4::llama_adapter_meta_key_by_index(
239 self.lora_adapter.as_ptr(),
240 index,
241 buf.as_mut_ptr().cast::<c_char>(),
242 buf_size,
243 )
244 };
245 if ret < 0 {
246 return Err(StringFromModelError::ReturnedError(ret));
247 }
248 let len = ret as usize;
249 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
250 Ok(s.to_owned())
251 }
252
253 /// Get a metadata value by key name.
254 ///
255 /// # Errors
256 ///
    /// Returns an error if the key is not found, contains a null byte, or the value is not valid UTF-8.
258 #[allow(clippy::cast_sign_loss)]
259 pub fn meta_val_str(
260 &self,
261 key: &str,
262 buf_size: usize,
263 ) -> Result<String, StringFromModelError> {
264 let c_key =
265 CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
266 let mut buf = vec![0u8; buf_size];
267 let ret = unsafe {
268 llama_cpp_sys_4::llama_adapter_meta_val_str(
269 self.lora_adapter.as_ptr(),
270 c_key.as_ptr(),
271 buf.as_mut_ptr().cast::<c_char>(),
272 buf_size,
273 )
274 };
275 if ret < 0 {
276 return Err(StringFromModelError::ReturnedError(ret));
277 }
278 let len = ret as usize;
279 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
280 Ok(s.to_owned())
281 }
282
283 /// Get a metadata value by index.
284 ///
285 /// # Errors
286 ///
287 /// Returns an error if the index is out of range or the value is not valid UTF-8.
288 #[allow(clippy::cast_sign_loss)]
289 pub fn meta_val_str_by_index(
290 &self,
291 index: i32,
292 buf_size: usize,
293 ) -> Result<String, StringFromModelError> {
294 let mut buf = vec![0u8; buf_size];
295 let ret = unsafe {
296 llama_cpp_sys_4::llama_adapter_meta_val_str_by_index(
297 self.lora_adapter.as_ptr(),
298 index,
299 buf.as_mut_ptr().cast::<c_char>(),
300 buf_size,
301 )
302 };
303 if ret < 0 {
304 return Err(StringFromModelError::ReturnedError(ret));
305 }
306 let len = ret as usize;
307 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
308 Ok(s.to_owned())
309 }
310
311 /// Get all metadata as a list of `(key, value)` pairs.
312 ///
313 /// # Errors
314 ///
315 /// Returns an error if any key or value cannot be read or is not valid UTF-8.
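    ///
    /// # Example
    ///
    /// A minimal sketch; the model and adapter paths are placeholders.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
    /// for (key, value) in adapter.metadata()? {
    ///     println!("{key} = {value}");
    /// }
    /// # Ok(())
    /// # }
    /// ```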
316 #[allow(clippy::cast_sign_loss)]
317 pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
318 let count = self.meta_count();
319 let mut result = Vec::with_capacity(count as usize);
320 for i in 0..count {
321 let key = self.meta_key_by_index(i, 256)?;
322 let val = self.meta_val_str_by_index(i, 4096)?;
323 result.push((key, val));
324 }
325 Ok(result)
326 }
327
328 /// Get the number of invocation tokens for this adapter.
329 #[must_use]
330 pub fn n_invocation_tokens(&self) -> u64 {
331 unsafe {
332 llama_cpp_sys_4::llama_adapter_get_alora_n_invocation_tokens(
333 self.lora_adapter.as_ptr(),
334 )
335 }
336 }
337
338 /// Get the invocation tokens for this adapter.
339 ///
340 /// Returns an empty slice if there are no invocation tokens.
341 #[must_use]
342 #[allow(clippy::cast_possible_truncation)]
343 pub fn invocation_tokens(&self) -> &[LlamaToken] {
344 let n = self.n_invocation_tokens() as usize;
345 if n == 0 {
346 return &[];
347 }
348 let ptr = unsafe {
349 llama_cpp_sys_4::llama_adapter_get_alora_invocation_tokens(
350 self.lora_adapter.as_ptr(),
351 )
352 };
353 if ptr.is_null() {
354 return &[];
355 }
356 // LlamaToken is repr(transparent) over llama_token (i32), so this cast is safe
357 unsafe { std::slice::from_raw_parts(ptr.cast::<LlamaToken>(), n) }
358 }
359}
360
361impl Drop for LlamaLoraAdapter {
362 fn drop(&mut self) {
363 unsafe {
364 llama_cpp_sys_4::llama_adapter_lora_free(self.lora_adapter.as_ptr());
365 }
366 }
367}
368
/// A safe wrapper around `llama_chat_message`.
370#[derive(Debug, Eq, PartialEq, Clone)]
371pub struct LlamaChatMessage {
372 role: CString,
373 content: CString,
374}
375
376impl LlamaChatMessage {
377 /// Create a new `LlamaChatMessage`.
378 ///
379 /// # Errors
380 ///
381 /// Returns [`NewLlamaChatMessageError`] if the role or content contains a null byte.
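    ///
    /// # Example
    ///
    /// Creating a simple user message:
    ///
    /// ```
    /// use llama_cpp_4::model::LlamaChatMessage;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let message = LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?;
    /// # Ok(())
    /// # }
    /// ```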
382 pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
383 Ok(Self {
384 role: CString::new(role)?,
385 content: CString::new(content)?,
386 })
387 }
388}
389
/// Determines whether a beginning-of-stream (BOS) token is prepended to the tokenized output.
391#[derive(Debug, Clone, Copy, PartialEq, Eq)]
392pub enum AddBos {
393 /// Add the beginning of stream token to the start of the string.
394 Always,
395 /// Do not add the beginning of stream token to the start of the string.
396 Never,
397}
398
/// Determines how special tokens are handled during tokenization.
400#[derive(Debug, Clone, Copy, PartialEq, Eq)]
401pub enum Special {
402 /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
403 Tokenize,
404 /// Treat special and/or control tokens as plaintext.
405 Plaintext,
406}
407
408unsafe impl Send for LlamaModel {}
409
410unsafe impl Sync for LlamaModel {}
411
412impl LlamaModel {
    /// Retrieves the vocabulary associated with the current Llama model.
    ///
    /// This method fetches the vocabulary from the underlying model using an FFI call and
    /// wraps the returned pointer in a `NonNull` for safety. The vocabulary data is owned by
    /// the model, so the returned `LlamaVocab` must not be used after the model has been
    /// dropped.
    ///
    /// # Returns
    /// A `LlamaVocab` struct containing the vocabulary of the model.
    ///
    /// # Panics
    ///
    /// Panics if the underlying C function returns a null pointer.
    ///
    /// # Example
    /// ```rust,ignore
    /// let vocab = model.get_vocab();
    /// ```
436 #[must_use]
437 pub fn get_vocab(&self) -> LlamaVocab {
438 let llama_vocab = unsafe { llama_model_get_vocab(self.model.as_ptr()) }.cast_mut();
439
440 LlamaVocab {
441 vocab: NonNull::new(llama_vocab).unwrap(),
442 }
    }

    /// Get the size of the context window (in tokens) that the model was trained with.
    ///
    /// This function returns the training context length, represented as a `u32`.
    ///
    /// # Panics
    ///
    /// This function will panic if the value does not fit into a `u32`. This should be impossible
    /// on most platforms since llama.cpp returns a `c_int` (`i32` on most platforms), which is
    /// almost certainly non-negative.
    #[must_use]
    pub fn n_ctx_train(&self) -> u32 {
        let n_ctx_train = unsafe { llama_model_n_ctx_train(self.model.as_ptr()) };
        u32::try_from(n_ctx_train).expect("n_ctx_train fits into a u32")
    }
458
459 /// Get all tokens in the model.
460 ///
461 /// This function returns an iterator over all the tokens in the model. Each item in the iterator is a tuple
462 /// containing a `LlamaToken` and its corresponding string representation (or an error if the conversion fails).
463 ///
464 /// # Parameters
465 ///
466 /// - `special`: The `Special` value that determines how special tokens (like BOS, EOS, etc.) are handled.
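    ///
    /// # Example
    ///
    /// A minimal sketch; the model path is a placeholder.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{LlamaModel, Special};
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// for (_token, text) in model.tokens(Special::Plaintext).take(10) {
    ///     println!("{}", text.unwrap_or_default());
    /// }
    /// # Ok(())
    /// # }
    /// ```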
467 pub fn tokens(
468 &self,
469 special: Special,
470 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
471 (0..self.n_vocab())
472 .map(LlamaToken::new)
473 .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
474 }
475
476 /// Get the beginning of stream token.
477 ///
478 /// This function returns the token that represents the beginning of a stream (BOS token).
479 #[must_use]
480 pub fn token_bos(&self) -> LlamaToken {
481 self.get_vocab().bos()
482 }
483
484 /// Get the end of stream token.
485 ///
486 /// This function returns the token that represents the end of a stream (EOS token).
487 #[must_use]
488 pub fn token_eos(&self) -> LlamaToken {
489 self.get_vocab().eos()
490 }
491
492 /// Get the newline token.
493 ///
494 /// This function returns the token that represents a newline character.
495 #[must_use]
496 pub fn token_nl(&self) -> LlamaToken {
497 self.get_vocab().nl()
498 }
499
500 /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
501 ///
502 /// This function returns `true` if the provided token signifies the end of generation or end of sequence,
503 /// such as EOS or other special tokens.
504 ///
505 /// # Parameters
506 ///
507 /// - `token`: The `LlamaToken` to check.
508 ///
509 /// # Returns
510 ///
511 /// - `true` if the token is an end-of-generation token, otherwise `false`.
512 #[must_use]
513 pub fn is_eog_token(&self, token: LlamaToken) -> bool {
514 self.get_vocab().is_eog(token)
515 }
516
517 /// Get the classification token.
518 #[must_use]
519 pub fn token_cls(&self) -> LlamaToken {
520 self.get_vocab().cls()
521 }
522
523 /// Get the end-of-turn token.
524 #[must_use]
525 pub fn token_eot(&self) -> LlamaToken {
526 self.get_vocab().eot()
527 }
528
529 /// Get the padding token.
530 #[must_use]
531 pub fn token_pad(&self) -> LlamaToken {
532 self.get_vocab().pad()
533 }
534
535 /// Get the separator token.
536 #[must_use]
537 pub fn token_sep(&self) -> LlamaToken {
538 self.get_vocab().sep()
539 }
540
541 /// Get the fill-in-the-middle prefix token.
542 #[must_use]
543 pub fn token_fim_pre(&self) -> LlamaToken {
544 self.get_vocab().fim_pre()
545 }
546
547 /// Get the fill-in-the-middle suffix token.
548 #[must_use]
549 pub fn token_fim_suf(&self) -> LlamaToken {
550 self.get_vocab().fim_suf()
551 }
552
553 /// Get the fill-in-the-middle middle token.
554 #[must_use]
555 pub fn token_fim_mid(&self) -> LlamaToken {
556 self.get_vocab().fim_mid()
557 }
558
559 /// Get the fill-in-the-middle padding token.
560 #[must_use]
561 pub fn token_fim_pad(&self) -> LlamaToken {
562 self.get_vocab().fim_pad()
563 }
564
565 /// Get the fill-in-the-middle repository token.
566 #[must_use]
567 pub fn token_fim_rep(&self) -> LlamaToken {
568 self.get_vocab().fim_rep()
569 }
570
571 /// Get the fill-in-the-middle separator token.
572 #[must_use]
573 pub fn token_fim_sep(&self) -> LlamaToken {
574 self.get_vocab().fim_sep()
575 }
576
577 /// Check if a token is a control token.
578 #[must_use]
579 pub fn token_is_control(&self, token: LlamaToken) -> bool {
580 self.get_vocab().is_control(token)
581 }
582
583 /// Get the score of a token.
584 #[must_use]
585 pub fn token_get_score(&self, token: LlamaToken) -> f32 {
586 self.get_vocab().get_score(token)
587 }
588
589 /// Get the raw text of a token.
590 ///
591 /// # Errors
592 ///
593 /// Returns an error if the token text is null or not valid UTF-8.
594 pub fn token_get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
595 let ptr = unsafe {
596 llama_cpp_sys_4::llama_vocab_get_text(self.get_vocab().vocab.as_ref(), token.0)
597 };
598 if ptr.is_null() {
599 return Err(StringFromModelError::ReturnedError(-1));
600 }
601 let cstr = unsafe { CStr::from_ptr(ptr) };
602 cstr.to_str().map_err(StringFromModelError::Utf8Error)
603 }
604
605 /// Check if a BOS token should be added when tokenizing.
606 #[must_use]
607 pub fn add_bos_token(&self) -> bool {
608 self.get_vocab().get_add_bos()
609 }
610
611 /// Check if an EOS token should be added when tokenizing.
612 #[must_use]
613 pub fn add_eos_token(&self) -> bool {
614 self.get_vocab().get_add_eos()
615 }
616
617 /// Get the decoder start token.
618 ///
619 /// This function returns the token used to signal the start of decoding (i.e., the token used at the start
620 /// of a sequence generation).
621 #[must_use]
622 pub fn decode_start_token(&self) -> LlamaToken {
623 let token = unsafe { llama_model_decoder_start_token(self.model.as_ptr()) };
624 LlamaToken(token)
625 }
626
627 /// Convert a single token to a string.
628 ///
629 /// This function converts a `LlamaToken` into its string representation.
630 ///
631 /// # Errors
632 ///
633 /// This function returns an error if the token cannot be converted to a string. For more details, refer to
634 /// [`TokenToStringError`].
635 ///
636 /// # Parameters
637 ///
638 /// - `token`: The `LlamaToken` to convert.
639 /// - `special`: The `Special` value used to handle special tokens.
640 pub fn token_to_str(
641 &self,
642 token: LlamaToken,
643 special: Special,
644 ) -> Result<String, TokenToStringError> {
645 self.token_to_str_with_size(token, 32, special)
646 }
647
648 /// Convert a single token to bytes.
649 ///
650 /// This function converts a `LlamaToken` into a byte representation.
651 ///
652 /// # Errors
653 ///
654 /// This function returns an error if the token cannot be converted to bytes. For more details, refer to
655 /// [`TokenToStringError`].
656 ///
657 /// # Parameters
658 ///
659 /// - `token`: The `LlamaToken` to convert.
660 /// - `special`: The `Special` value used to handle special tokens.
661 pub fn token_to_bytes(
662 &self,
663 token: LlamaToken,
664 special: Special,
665 ) -> Result<Vec<u8>, TokenToStringError> {
666 self.token_to_bytes_with_size(token, 32, special, None)
667 }
668
669 /// Convert a vector of tokens to a single string.
670 ///
671 /// This function takes a slice of `LlamaToken`s and converts them into a single string, concatenating their
672 /// string representations.
673 ///
674 /// # Errors
675 ///
676 /// This function returns an error if any token cannot be converted to a string. For more details, refer to
677 /// [`TokenToStringError`].
678 ///
679 /// # Parameters
680 ///
681 /// - `tokens`: A slice of `LlamaToken`s to convert.
682 /// - `special`: The `Special` value used to handle special tokens.
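    ///
    /// # Example
    ///
    /// A minimal sketch; the model path is a placeholder.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{AddBos, LlamaModel, Special};
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Never)?;
    /// let text = model.tokens_to_str(&tokens, Special::Plaintext)?;
    /// # Ok(())
    /// # }
    /// ```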
683 pub fn tokens_to_str(
684 &self,
685 tokens: &[LlamaToken],
686 special: Special,
687 ) -> Result<String, TokenToStringError> {
688 let mut builder = String::with_capacity(tokens.len() * 4);
689 for str in tokens
690 .iter()
691 .copied()
692 .map(|t| self.token_to_str(t, special))
693 {
694 builder += &str?;
695 }
696 Ok(builder)
697 }
698
699 /// Convert a string to a vector of tokens.
700 ///
701 /// This function converts a string into a vector of `LlamaToken`s. The function will tokenize the string
702 /// and return the corresponding tokens.
703 ///
704 /// # Errors
705 ///
706 /// - This function will return an error if the input string contains a null byte.
707 ///
708 /// # Panics
709 ///
710 /// - This function will panic if the number of tokens exceeds `usize::MAX`.
711 ///
712 /// # Example
713 ///
714 /// ```no_run
715 /// use llama_cpp_4::model::LlamaModel;
716 ///
717 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
718 /// use std::path::Path;
719 /// use llama_cpp_4::model::AddBos;
720 /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
721 /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
722 /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
723 /// # Ok(())
724 /// # }
725 /// ```
726 pub fn str_to_token(
727 &self,
728 str: &str,
729 add_bos: AddBos,
730 ) -> Result<Vec<LlamaToken>, StringToTokenError> {
731 let add_bos = match add_bos {
732 AddBos::Always => true,
733 AddBos::Never => false,
734 };
735
736 let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
737 let mut buffer = Vec::with_capacity(tokens_estimation);
738
739 let c_string = CString::new(str)?;
740 let buffer_capacity =
741 c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
742
743 let size = unsafe {
744 llama_tokenize(
745 self.get_vocab().vocab.as_ref(),
746 c_string.as_ptr(),
747 c_int::try_from(c_string.as_bytes().len())?,
748 buffer.as_mut_ptr(),
749 buffer_capacity,
750 add_bos,
751 true,
752 )
753 };
754
        // If the buffer was too small, the call returns the negated number of tokens required.
        // Resize the buffer and tokenize again; the second call should always succeed, so `size`
        // is guaranteed to be non-negative afterwards.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("-size fits into a usize"));
759 unsafe {
760 llama_tokenize(
761 self.get_vocab().vocab.as_ref(),
762 c_string.as_ptr(),
763 c_int::try_from(c_string.as_bytes().len())?,
764 buffer.as_mut_ptr(),
765 -size,
766 add_bos,
767 true,
768 )
769 }
770 } else {
771 size
772 };
773
        let size = usize::try_from(size).expect("size is non-negative and fits into a usize");

        // Safety: `size` <= `capacity` and llama-cpp has initialized elements up to `size`
777 unsafe { buffer.set_len(size) }
778 Ok(buffer.into_iter().map(LlamaToken).collect())
779 }
780
781 /// Get the type of a token.
782 ///
783 /// This function retrieves the attributes associated with a given token. The attributes are typically used to
784 /// understand whether the token represents a special type of token (e.g., beginning-of-sequence (BOS), end-of-sequence (EOS),
785 /// control tokens, etc.).
786 ///
787 /// # Panics
788 ///
789 /// - This function will panic if the token type is unknown or cannot be converted to a valid `LlamaTokenAttrs`.
790 ///
791 /// # Example
792 ///
793 /// ```no_run
794 /// use llama_cpp_4::model::LlamaModel;
795 /// use llama_cpp_4::model::params::LlamaModelParams;
796 /// use llama_cpp_4::llama_backend::LlamaBackend;
797 /// use llama_cpp_4::token::LlamaToken;
798 ///
799 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
800 /// let backend = LlamaBackend::init()?;
801 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
802 /// let token = LlamaToken::new(42);
803 /// let token_attrs = model.token_attr(token);
804 /// # Ok(())
805 /// # }
806 /// ```
807 #[must_use]
808 pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
809 let token_type =
810 unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.get_vocab().vocab.as_ref(), id) };
811 LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
812 }
813
814 /// Detokenize a slice of tokens into a string.
815 ///
816 /// This is the inverse of [`str_to_token`](Self::str_to_token).
817 ///
818 /// # Parameters
819 ///
820 /// - `tokens`: The tokens to detokenize.
821 /// - `remove_special`: If `true`, special tokens are removed from the output.
822 /// - `unparse_special`: If `true`, special tokens are rendered as their text representation.
823 ///
824 /// # Errors
825 ///
826 /// Returns an error if the detokenized text is not valid UTF-8.
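    ///
    /// # Example
    ///
    /// A minimal round-trip sketch; the model path is a placeholder and exact whitespace
    /// handling depends on the tokenizer.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{AddBos, LlamaModel};
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Never)?;
    /// let text = model.detokenize(&tokens, false, false)?;
    /// # Ok(())
    /// # }
    /// ```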
827 #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap, clippy::cast_sign_loss)]
828 pub fn detokenize(
829 &self,
830 tokens: &[LlamaToken],
831 remove_special: bool,
832 unparse_special: bool,
833 ) -> Result<String, StringFromModelError> {
834 // First call with empty buffer to get required size
835 let n_tokens = tokens.len() as i32;
836 let token_ptr = tokens.as_ptr().cast::<llama_cpp_sys_4::llama_token>();
837 let needed = unsafe {
838 llama_detokenize(
839 self.get_vocab().vocab.as_ref(),
840 token_ptr,
841 n_tokens,
842 std::ptr::null_mut(),
843 0,
844 remove_special,
845 unparse_special,
846 )
847 };
848 // llama_detokenize returns negative required size when buffer is too small
849 let buf_size = if needed < 0 { (-needed) as usize } else { needed as usize };
850 let mut buf = vec![0u8; buf_size];
851 let ret = unsafe {
852 llama_detokenize(
853 self.get_vocab().vocab.as_ref(),
854 token_ptr,
855 n_tokens,
856 buf.as_mut_ptr().cast::<c_char>(),
857 buf_size as i32,
858 remove_special,
859 unparse_special,
860 )
861 };
862 if ret < 0 {
863 return Err(StringFromModelError::ReturnedError(ret));
864 }
865 let len = ret as usize;
866 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
867 Ok(s.to_owned())
868 }
869
870 /// Convert a token to a string with a specified buffer size.
871 ///
    /// This function allows you to convert a token into a string, with the ability to specify a buffer size for the operation.
    /// It is generally recommended to use [`LlamaModel::token_to_str`] instead, since its default buffer of 32 bytes is
    /// sufficient for most tokens and the extra buffer size rarely matters.
875 ///
876 /// # Errors
877 ///
878 /// - If the token type is unknown, an error will be returned.
879 /// - If the resultant token exceeds the provided `buffer_size`, an error will occur.
880 /// - If the token string returned by `llama-cpp` is not valid UTF-8, it will return an error.
881 ///
882 /// # Panics
883 ///
884 /// - This function will panic if the `buffer_size` does not fit into a `c_int`.
885 /// - It will also panic if the size returned from `llama-cpp` does not fit into a `usize`, which should typically never happen.
886 ///
887 /// # Example
888 ///
889 /// ```no_run
890 /// use llama_cpp_4::model::{LlamaModel, Special};
891 /// use llama_cpp_4::model::params::LlamaModelParams;
892 /// use llama_cpp_4::llama_backend::LlamaBackend;
893 /// use llama_cpp_4::token::LlamaToken;
894 ///
895 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
896 /// let backend = LlamaBackend::init()?;
897 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
898 /// let token = LlamaToken::new(42);
899 /// let token_string = model.token_to_str_with_size(token, 32, Special::Plaintext)?;
900 /// # Ok(())
901 /// # }
902 /// ```
903 pub fn token_to_str_with_size(
904 &self,
905 token: LlamaToken,
906 buffer_size: usize,
907 special: Special,
908 ) -> Result<String, TokenToStringError> {
909 let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
910 Ok(String::from_utf8(bytes)?)
911 }
912
913 /// Convert a token to bytes with a specified buffer size.
914 ///
    /// Generally you should use [`LlamaModel::token_to_bytes`] instead, since its default buffer of 32 bytes is enough
    /// for most tokens and the extra bytes rarely matter.
917 ///
918 /// # Errors
919 ///
920 /// - if the token type is unknown
921 /// - the resultant token is larger than `buffer_size`.
922 ///
923 /// # Panics
924 ///
925 /// - This function will panic if `buffer_size` cannot fit into a `c_int`.
926 /// - It will also panic if the size returned from `llama-cpp` cannot be converted to `usize` (which should not happen).
927 ///
928 /// # Example
929 ///
930 /// ```no_run
931 /// use llama_cpp_4::model::{LlamaModel, Special};
932 /// use llama_cpp_4::model::params::LlamaModelParams;
933 /// use llama_cpp_4::llama_backend::LlamaBackend;
934 /// use llama_cpp_4::token::LlamaToken;
935 ///
936 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
937 /// let backend = LlamaBackend::init()?;
938 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
939 /// let token = LlamaToken::new(42);
940 /// let token_bytes = model.token_to_bytes_with_size(token, 32, Special::Plaintext, None)?;
941 /// # Ok(())
942 /// # }
943 /// ```
944 pub fn token_to_bytes_with_size(
945 &self,
946 token: LlamaToken,
947 buffer_size: usize,
948 special: Special,
949 lstrip: Option<NonZeroU16>,
950 ) -> Result<Vec<u8>, TokenToStringError> {
951 if token == self.token_nl() {
952 return Ok(String::from("\n").into_bytes());
953 }
954
955 // unsure what to do with this in the face of the 'special' arg + attr changes
956 let attrs = self.token_attr(token);
957 if (attrs.contains(LlamaTokenAttr::Control)
958 && (token == self.token_bos() || token == self.token_eos()))
959 || attrs.is_empty()
960 || attrs
961 .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
962 {
963 return Ok(Vec::new());
964 }
965
966 let special = match special {
967 Special::Tokenize => true,
968 Special::Plaintext => false,
969 };
970
971 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
972 let len = string.as_bytes().len();
973 let len = c_int::try_from(len).expect("length fits into c_int");
974 let buf = string.into_raw();
975 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
976 let size = unsafe {
977 llama_token_to_piece(
978 self.get_vocab().vocab.as_ref(),
979 token.0,
980 buf,
981 len,
982 lstrip,
983 special,
984 )
985 };
986
987 match size {
988 0 => Err(TokenToStringError::UnknownTokenType),
989 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
990 size => {
991 let string = unsafe { CString::from_raw(buf) };
992 let mut bytes = string.into_bytes();
993 let len = usize::try_from(size).expect("size is positive and fits into usize");
994 bytes.truncate(len);
995 Ok(bytes)
996 }
997 }
    }

    /// The number of tokens in the model's vocabulary.
    ///
    /// This function returns the size of the vocabulary the model was trained with, as an `i32`.
1003 ///
1004 /// # Example
1005 ///
1006 /// ```no_run
1007 /// use llama_cpp_4::model::LlamaModel;
1008 /// use llama_cpp_4::model::params::LlamaModelParams;
1009 /// use llama_cpp_4::llama_backend::LlamaBackend;
1010 ///
1011 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1012 /// let backend = LlamaBackend::init()?;
1013 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1014 /// let n_vocab = model.n_vocab();
1015 /// # Ok(())
1016 /// # }
1017 /// ```
1018 #[must_use]
1019 pub fn n_vocab(&self) -> i32 {
1020 self.get_vocab().n_tokens()
1021 }
1022
1023 /// The type of vocab the model was trained on.
1024 ///
    /// This function returns the type of vocabulary used by the model, such as byte-pair encoding (BPE),
    /// SentencePiece (SPM), or another tokenization scheme.
1027 ///
1028 /// # Panics
1029 ///
1030 /// - This function will panic if `llama-cpp` emits a vocab type that is not recognized or is invalid for this library.
1031 ///
1032 /// # Example
1033 ///
1034 /// ```no_run
1035 /// use llama_cpp_4::model::LlamaModel;
1036 /// use llama_cpp_4::model::params::LlamaModelParams;
1037 /// use llama_cpp_4::llama_backend::LlamaBackend;
1038 ///
1039 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1040 /// let backend = LlamaBackend::init()?;
1041 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1042 /// let vocab_type = model.vocab_type();
1043 /// # Ok(())
1044 /// # }
1045 /// ```
1046 #[must_use]
1047 pub fn vocab_type(&self) -> VocabType {
1048 let vocab_type = unsafe { llama_vocab_type(self.get_vocab().vocab.as_ref()) };
1049 VocabType::try_from(vocab_type).expect("invalid vocab type")
1050 }
1051
1052 /// Returns the number of embedding dimensions for the model.
1053 ///
1054 /// This function retrieves the number of embeddings (or embedding dimensions) used by the model. It is typically
1055 /// used for analyzing model architecture and setting up context parameters or other model configuration aspects.
1056 ///
1061 /// # Example
1062 ///
1063 /// ```no_run
1064 /// use llama_cpp_4::model::LlamaModel;
1065 /// use llama_cpp_4::model::params::LlamaModelParams;
1066 /// use llama_cpp_4::llama_backend::LlamaBackend;
1067 ///
1068 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1069 /// let backend = LlamaBackend::init()?;
1070 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1071 /// let n_embd = model.n_embd();
1072 /// # Ok(())
1073 /// # }
1074 /// ```
1075 #[must_use]
1076 pub fn n_embd(&self) -> c_int {
1077 unsafe { llama_model_n_embd(self.model.as_ptr()) }
1078 }
1079
1080 /// Get the number of transformer layers in the model.
1081 #[must_use]
1082 pub fn n_layer(&self) -> c_int {
1083 unsafe { llama_model_n_layer(self.model.as_ptr()) }
1084 }
1085
1086 /// Get the number of attention heads in the model.
1087 #[must_use]
1088 pub fn n_head(&self) -> c_int {
1089 unsafe { llama_model_n_head(self.model.as_ptr()) }
1090 }
1091
1092 /// Get the number of key-value attention heads in the model.
1093 #[must_use]
1094 pub fn n_head_kv(&self) -> c_int {
1095 unsafe { llama_model_n_head_kv(self.model.as_ptr()) }
1096 }
1097
1098 /// Get the input embedding size of the model.
1099 #[must_use]
1100 pub fn n_embd_inp(&self) -> c_int {
1101 unsafe { llama_model_n_embd_inp(self.model.as_ptr()) }
1102 }
1103
1104 /// Get the output embedding size of the model.
1105 #[must_use]
1106 pub fn n_embd_out(&self) -> c_int {
1107 unsafe { llama_model_n_embd_out(self.model.as_ptr()) }
1108 }
1109
1110 /// Get the sliding window attention size of the model.
1111 /// Returns 0 if the model does not use sliding window attention.
1112 #[must_use]
1113 pub fn n_swa(&self) -> c_int {
1114 unsafe { llama_model_n_swa(self.model.as_ptr()) }
1115 }
1116
1117 /// Get the `RoPE` type used by the model.
1118 #[must_use]
1119 pub fn rope_type(&self) -> i32 {
1120 unsafe { llama_model_rope_type(self.model.as_ptr()) }
1121 }
1122
1123 /// Get the `RoPE` frequency scale used during training.
1124 #[must_use]
1125 pub fn rope_freq_scale_train(&self) -> f32 {
1126 unsafe { llama_model_rope_freq_scale_train(self.model.as_ptr()) }
1127 }
1128
1129 /// Get the model size in bytes.
1130 #[must_use]
1131 pub fn model_size(&self) -> u64 {
1132 unsafe { llama_model_size(self.model.as_ptr()) }
1133 }
1134
1135 /// Get the number of parameters in the model.
1136 #[must_use]
1137 pub fn n_params(&self) -> u64 {
1138 unsafe { llama_model_n_params(self.model.as_ptr()) }
1139 }
1140
1141 /// Get the number of classification outputs.
1142 #[must_use]
1143 pub fn n_cls_out(&self) -> u32 {
1144 unsafe { llama_model_n_cls_out(self.model.as_ptr()) }
1145 }
1146
1147 /// Get the classification label for the given index.
1148 ///
1149 /// # Errors
1150 ///
1151 /// Returns an error if the label is null or not valid UTF-8.
1152 pub fn cls_label(&self, index: u32) -> Result<&str, StringFromModelError> {
1153 let ptr = unsafe { llama_model_cls_label(self.model.as_ptr(), index) };
1154 if ptr.is_null() {
1155 return Err(StringFromModelError::ReturnedError(-1));
1156 }
1157 let cstr = unsafe { CStr::from_ptr(ptr) };
1158 cstr.to_str().map_err(StringFromModelError::Utf8Error)
1159 }
1160
1161 /// Get the number of metadata key-value pairs.
1162 #[must_use]
1163 pub fn meta_count(&self) -> c_int {
1164 unsafe { llama_model_meta_count(self.model.as_ptr()) }
1165 }
1166
1167 /// Get a model description string.
1168 ///
1169 /// The `buf_size` parameter specifies the maximum buffer size for the description.
1170 /// A default of 256 bytes is usually sufficient.
1171 ///
1172 /// # Errors
1173 ///
1174 /// Returns an error if the description could not be retrieved or is not valid UTF-8.
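    ///
    /// # Example
    ///
    /// A minimal sketch; the model path is a placeholder.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let description = model.desc(256)?;
    /// println!("{description}");
    /// # Ok(())
    /// # }
    /// ```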
1175 #[allow(clippy::cast_sign_loss)]
1176 pub fn desc(&self, buf_size: usize) -> Result<String, StringFromModelError> {
1177 let mut buf = vec![0u8; buf_size];
1178 let ret = unsafe {
1179 llama_model_desc(
1180 self.model.as_ptr(),
1181 buf.as_mut_ptr().cast::<c_char>(),
1182 buf_size,
1183 )
1184 };
1185 if ret < 0 {
1186 return Err(StringFromModelError::ReturnedError(ret));
1187 }
1188 let len = ret as usize;
1189 let s = std::str::from_utf8(&buf[..len])
1190 .map_err(StringFromModelError::Utf8Error)?;
1191 Ok(s.to_owned())
1192 }
1193
1194 /// Get a metadata key by index.
1195 ///
1196 /// The `buf_size` parameter specifies the maximum buffer size for the key.
1197 /// A default of 256 bytes is usually sufficient.
1198 ///
1199 /// # Errors
1200 ///
1201 /// Returns an error if the index is out of range or the key is not valid UTF-8.
1202 #[allow(clippy::cast_sign_loss)]
1203 pub fn meta_key_by_index(&self, index: i32, buf_size: usize) -> Result<String, StringFromModelError> {
1204 let mut buf = vec![0u8; buf_size];
1205 let ret = unsafe {
1206 llama_model_meta_key_by_index(
1207 self.model.as_ptr(),
1208 index,
1209 buf.as_mut_ptr().cast::<c_char>(),
1210 buf_size,
1211 )
1212 };
1213 if ret < 0 {
1214 return Err(StringFromModelError::ReturnedError(ret));
1215 }
1216 let len = ret as usize;
1217 let s = std::str::from_utf8(&buf[..len])
1218 .map_err(StringFromModelError::Utf8Error)?;
1219 Ok(s.to_owned())
1220 }
1221
1222 /// Get a metadata value string by index.
1223 ///
1224 /// The `buf_size` parameter specifies the maximum buffer size for the value.
1225 /// Values can be large (e.g. chat templates, token lists), so 4096+ may be needed.
1226 ///
1227 /// # Errors
1228 ///
1229 /// Returns an error if the index is out of range or the value is not valid UTF-8.
1230 #[allow(clippy::cast_sign_loss)]
1231 pub fn meta_val_str_by_index(&self, index: i32, buf_size: usize) -> Result<String, StringFromModelError> {
1232 let mut buf = vec![0u8; buf_size];
1233 let ret = unsafe {
1234 llama_model_meta_val_str_by_index(
1235 self.model.as_ptr(),
1236 index,
1237 buf.as_mut_ptr().cast::<c_char>(),
1238 buf_size,
1239 )
1240 };
1241 if ret < 0 {
1242 return Err(StringFromModelError::ReturnedError(ret));
1243 }
1244 let len = ret as usize;
1245 let s = std::str::from_utf8(&buf[..len])
1246 .map_err(StringFromModelError::Utf8Error)?;
1247 Ok(s.to_owned())
1248 }
1249
1250 /// Get a metadata value by key name.
1251 ///
1252 /// This is more convenient than iterating metadata by index when you know the key.
1253 /// The `buf_size` parameter specifies the maximum buffer size for the value.
1254 ///
1255 /// # Errors
1256 ///
1257 /// Returns an error if the key is not found, contains a null byte, or the value is not valid UTF-8.
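    ///
    /// # Example
    ///
    /// A minimal sketch; the model path is a placeholder and it is assumed the GGUF metadata
    /// contains the standard `general.architecture` key.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let arch = model.meta_val_str("general.architecture", 256)?;
    /// # Ok(())
    /// # }
    /// ```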
1258 #[allow(clippy::cast_sign_loss)]
1259 pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
1260 let c_key = CString::new(key)
1261 .map_err(|_| StringFromModelError::ReturnedError(-1))?;
1262 let mut buf = vec![0u8; buf_size];
1263 let ret = unsafe {
1264 llama_model_meta_val_str(
1265 self.model.as_ptr(),
1266 c_key.as_ptr(),
1267 buf.as_mut_ptr().cast::<c_char>(),
1268 buf_size,
1269 )
1270 };
1271 if ret < 0 {
1272 return Err(StringFromModelError::ReturnedError(ret));
1273 }
1274 let len = ret as usize;
1275 let s = std::str::from_utf8(&buf[..len])
1276 .map_err(StringFromModelError::Utf8Error)?;
1277 Ok(s.to_owned())
1278 }
1279
1280 /// Get all metadata as a list of `(key, value)` pairs.
1281 ///
1282 /// This is a convenience method that iterates over all metadata entries.
1283 /// Keys use a buffer of 256 bytes and values use 4096 bytes.
1284 /// For values that may be larger (e.g. token lists), use
1285 /// [`meta_val_str_by_index`](Self::meta_val_str_by_index) directly with a larger buffer.
1286 ///
1287 /// # Errors
1288 ///
1289 /// Returns an error if any key or value cannot be read or is not valid UTF-8.
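    ///
    /// # Example
    ///
    /// A minimal sketch; the model path is a placeholder.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// for (key, value) in model.metadata()? {
    ///     println!("{key} = {value}");
    /// }
    /// # Ok(())
    /// # }
    /// ```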
1290 #[allow(clippy::cast_sign_loss)]
1291 pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
1292 let count = self.meta_count();
1293 let mut result = Vec::with_capacity(count as usize);
1294 for i in 0..count {
1295 let key = self.meta_key_by_index(i, 256)?;
1296 let val = self.meta_val_str_by_index(i, 4096)?;
1297 result.push((key, val));
1298 }
1299 Ok(result)
1300 }
1301
1302 /// Check if the model has an encoder.
1303 #[must_use]
1304 pub fn has_encoder(&self) -> bool {
1305 unsafe { llama_model_has_encoder(self.model.as_ptr()) }
1306 }
1307
1308 /// Check if the model has a decoder.
1309 #[must_use]
1310 pub fn has_decoder(&self) -> bool {
1311 unsafe { llama_model_has_decoder(self.model.as_ptr()) }
1312 }
1313
1314 /// Check if the model is recurrent (e.g. Mamba, RWKV).
1315 #[must_use]
1316 pub fn is_recurrent(&self) -> bool {
1317 unsafe { llama_model_is_recurrent(self.model.as_ptr()) }
1318 }
1319
1320 /// Check if the model is a hybrid model.
1321 #[must_use]
1322 pub fn is_hybrid(&self) -> bool {
1323 unsafe { llama_model_is_hybrid(self.model.as_ptr()) }
1324 }
1325
1326 /// Check if the model is a diffusion model.
1327 #[must_use]
1328 pub fn is_diffusion(&self) -> bool {
1329 unsafe { llama_model_is_diffusion(self.model.as_ptr()) }
1330 }
1331
1332 /// Get chat template from model.
1333 ///
1334 /// # Errors
1335 ///
1336 /// - If the model does not have a chat template, it will return an error.
1337 /// - If the chat template is not a valid `CString`, it will return an error.
1338 ///
1339 /// # Example
1340 ///
1341 /// ```no_run
1342 /// use llama_cpp_4::model::LlamaModel;
1343 /// use llama_cpp_4::model::params::LlamaModelParams;
1344 /// use llama_cpp_4::llama_backend::LlamaBackend;
1345 ///
1346 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1347 /// let backend = LlamaBackend::init()?;
1348 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1349 /// let chat_template = model.get_chat_template(1024)?;
1350 /// # Ok(())
1351 /// # }
1352 /// ```
1353 #[allow(clippy::missing_panics_doc)] // We statically know this will not panic as long as the buffer size is sufficient
    pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
        // longest known template is about 1200 bytes from llama.cpp
        let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
        let mut buf = vec![0u8; buf_size];

        let ret = unsafe {
            llama_model_meta_val_str(
                self.model.as_ptr(),
                chat_name.as_ptr(),
                buf.as_mut_ptr().cast::<c_char>(),
                buf_size,
            )
        };

        if ret < 0 {
            return Err(ChatTemplateError::MissingTemplate(ret));
        }

        // `ret` is the full length of the template value; if it does not fit in the buffer
        // (leaving room for the trailing nul byte) the value was truncated.
        let required: usize = ret.try_into().unwrap();
        if required + 1 > buf_size {
            return Err(ChatTemplateError::BuffSizeError(required + 1));
        }

        let template = std::str::from_utf8(&buf[..required])?;
        Ok(template.to_owned())
    }
1378
1379 /// Loads a model from a file.
1380 ///
1381 /// This function loads a model from a specified file path and returns the corresponding `LlamaModel` instance.
1382 ///
1383 /// # Errors
1384 ///
1385 /// - If the path cannot be converted to a string or if the model file does not exist, it will return an error.
1386 /// - If the model cannot be loaded (e.g., due to an invalid or corrupted model file), it will return a `LlamaModelLoadError`.
1387 ///
1388 /// # Example
1389 ///
1390 /// ```no_run
1391 /// use llama_cpp_4::model::LlamaModel;
1392 /// use llama_cpp_4::model::params::LlamaModelParams;
1393 /// use llama_cpp_4::llama_backend::LlamaBackend;
1394 ///
1395 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1396 /// let backend = LlamaBackend::init()?;
1397 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1398 /// # Ok(())
1399 /// # }
1400 /// ```
1401 #[tracing::instrument(skip_all, fields(params))]
1402 pub fn load_from_file(
1403 _: &LlamaBackend,
1404 path: impl AsRef<Path>,
1405 params: &LlamaModelParams,
1406 ) -> Result<Self, LlamaModelLoadError> {
1407 let path = path.as_ref();
1408 debug_assert!(
1409 Path::new(path).exists(),
1410 "{} does not exist",
1411 path.display()
1412 );
1413 let path = path
1414 .to_str()
1415 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1416
1417 let cstr = CString::new(path)?;
1418 let llama_model = unsafe { llama_model_load_from_file(cstr.as_ptr(), params.params) };
1419
1420 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1421
1422 tracing::debug!(?path, "Loaded model");
1423 Ok(LlamaModel { model })
1424 }
1425
1426 /// Load a model from multiple split files.
1427 ///
1428 /// This function loads a model that has been split across multiple files. This is useful for
1429 /// very large models that exceed filesystem limitations or need to be distributed across
1430 /// multiple storage devices.
1431 ///
1432 /// # Arguments
1433 ///
1434 /// * `paths` - A slice of paths to the split model files
1435 /// * `params` - The model parameters
1436 ///
1437 /// # Errors
1438 ///
1439 /// Returns an error if:
1440 /// - Any of the paths cannot be converted to a C string
1441 /// - The model fails to load from the splits
1442 /// - Any path doesn't exist or isn't accessible
1443 ///
1444 /// # Example
1445 ///
1446 /// ```no_run
1447 /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
1448 /// use llama_cpp_4::llama_backend::LlamaBackend;
1449 ///
1450 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1451 /// let backend = LlamaBackend::init()?;
1452 /// let params = LlamaModelParams::default();
1453 ///
1454 /// let paths = vec![
1455 /// "model-00001-of-00003.gguf",
1456 /// "model-00002-of-00003.gguf",
1457 /// "model-00003-of-00003.gguf",
1458 /// ];
1459 ///
    /// let model = LlamaModel::load_from_splits(&backend, &paths, &params)?;
1461 /// # Ok(())
1462 /// # }
1463 /// ```
1464 #[tracing::instrument(skip_all)]
1465 pub fn load_from_splits(
1466 _: &LlamaBackend,
1467 paths: &[impl AsRef<Path>],
1468 params: &LlamaModelParams,
1469 ) -> Result<Self, LlamaModelLoadError> {
1470 // Convert paths to C strings
1471 let c_strings: Vec<CString> = paths
1472 .iter()
1473 .map(|p| {
1474 let path = p.as_ref();
1475 debug_assert!(path.exists(), "{} does not exist", path.display());
1476 let path_str = path
1477 .to_str()
1478 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1479 CString::new(path_str).map_err(LlamaModelLoadError::from)
1480 })
1481 .collect::<Result<Vec<_>, _>>()?;
1482
1483 // Create array of pointers to C strings
1484 let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
1485
1486 // Load the model from splits
1487 let llama_model = unsafe {
1488 llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
1489 };
1490
1491 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1492
1493 tracing::debug!("Loaded model from {} splits", paths.len());
1494 Ok(LlamaModel { model })
1495 }
1496
1497 /// Load a model from a `FILE` pointer.
1498 ///
1499 /// # Safety
1500 ///
1501 /// The `file` pointer must be a valid, open `FILE*`.
1502 ///
1503 /// # Errors
1504 ///
1505 /// Returns an error if the model cannot be loaded.
1506 pub unsafe fn load_from_file_ptr(
1507 file: *mut llama_cpp_sys_4::FILE,
1508 params: &LlamaModelParams,
1509 ) -> Result<Self, LlamaModelLoadError> {
1510 let model = llama_cpp_sys_4::llama_model_load_from_file_ptr(file, params.params);
1511 let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1512 Ok(LlamaModel { model })
1513 }
1514
1515 /// Initialize a model from user-provided data.
1516 ///
1517 /// # Safety
1518 ///
1519 /// The metadata, callback, and user data must be valid.
1520 ///
1521 /// # Errors
1522 ///
1523 /// Returns an error if the model cannot be initialized.
1524 pub unsafe fn init_from_user(
1525 metadata: *mut llama_cpp_sys_4::gguf_context,
1526 set_tensor_data: llama_cpp_sys_4::llama_model_set_tensor_data_t,
1527 set_tensor_data_ud: *mut std::ffi::c_void,
1528 params: &LlamaModelParams,
1529 ) -> Result<Self, LlamaModelLoadError> {
1530 let model = llama_cpp_sys_4::llama_model_init_from_user(
1531 metadata,
1532 set_tensor_data,
1533 set_tensor_data_ud,
1534 params.params,
1535 );
1536 let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1537 Ok(LlamaModel { model })
1538 }
1539
1540 /// Save the model to a file.
1541 ///
1542 /// # Panics
1543 ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
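    ///
    /// # Example
    ///
    /// A minimal sketch; both paths are placeholders.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// model.save_to_file("path/to/output.gguf");
    /// # Ok(())
    /// # }
    /// ```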
1545 pub fn save_to_file(&self, path: impl AsRef<Path>) {
1546 let path = path.as_ref();
1547 let path_str = path.to_str().expect("path is not valid UTF-8");
1548 let c_path = CString::new(path_str).expect("path contains null bytes");
1549 unsafe {
1550 llama_model_save_to_file(self.model.as_ptr(), c_path.as_ptr());
1551 }
1552 }
1553
1554 /// Get the list of built-in chat templates.
1555 ///
1556 /// Returns the names of all chat templates that are built into llama.cpp.
1557 ///
1558 /// # Panics
1559 ///
1560 /// Panics if any template name is not valid UTF-8.
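    ///
    /// # Example
    ///
    /// Listing the built-in template names:
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// let templates = LlamaModel::chat_builtin_templates();
    /// println!("{} built-in chat templates", templates.len());
    /// ```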
1561 #[allow(clippy::cast_sign_loss)]
1562 #[must_use]
1563 pub fn chat_builtin_templates() -> Vec<String> {
1564 // First call to get count
1565 let count = unsafe { llama_chat_builtin_templates(std::ptr::null_mut(), 0) };
1566 if count <= 0 {
1567 return Vec::new();
1568 }
1569 let count = count as usize;
1570 let mut ptrs: Vec<*const c_char> = vec![std::ptr::null(); count];
1571 unsafe {
1572 llama_chat_builtin_templates(ptrs.as_mut_ptr(), count);
1573 }
1574 ptrs.iter()
1575 .map(|&p| {
1576 let cstr = unsafe { CStr::from_ptr(p) };
1577 cstr.to_str()
1578 .expect("template name is not valid UTF-8")
1579 .to_owned()
1580 })
1581 .collect()
1582 }
1583
1584 /// Initializes a lora adapter from a file.
1585 ///
1586 /// This function initializes a Lora adapter, which is a model extension used to adapt or fine-tune the existing model
1587 /// to a specific domain or task. The adapter file is typically in the form of a binary or serialized file that can be applied
1588 /// to the model for improved performance on specialized tasks.
1589 ///
1590 /// # Errors
1591 ///
1592 /// - If the adapter file path cannot be converted to a string or if the adapter cannot be initialized, it will return an error.
1593 ///
1594 /// # Example
1595 ///
1596 /// ```no_run
1597 /// use llama_cpp_4::model::LlamaModel;
1598 /// use llama_cpp_4::model::params::LlamaModelParams;
1599 /// use llama_cpp_4::llama_backend::LlamaBackend;
1600 ///
1601 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1602 /// let backend = LlamaBackend::init()?;
1603 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1604 /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
1605 /// # Ok(())
1606 /// # }
1607 /// ```
1608 pub fn lora_adapter_init(
1609 &self,
1610 path: impl AsRef<Path>,
1611 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
1612 let path = path.as_ref();
1613 debug_assert!(
1614 Path::new(path).exists(),
1615 "{} does not exist",
1616 path.display()
1617 );
1618
1619 let path = path
1620 .to_str()
1621 .ok_or(LlamaLoraAdapterInitError::PathToStrError(
1622 path.to_path_buf(),
1623 ))?;
1624
1625 let cstr = CString::new(path)?;
1626 let adapter = unsafe { llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
1627
1628 let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
1629
1630 tracing::debug!(?path, "Initialized lora adapter");
1631 Ok(LlamaLoraAdapter {
1632 lora_adapter: adapter,
1633 })
1634 }
1635
1636 /// Create a new context from this model.
1637 ///
1638 /// This function creates a new context for the model, which is used to manage and perform computations for inference,
1639 /// including token generation, embeddings, and other tasks that the model can perform. The context allows fine-grained
1640 /// control over model parameters for a specific task.
1641 ///
1642 /// # Errors
1643 ///
1644 /// - There are various potential failures such as invalid parameters or a failure to allocate the context. See [`LlamaContextLoadError`]
1645 /// for more detailed error descriptions.
1646 ///
1647 /// # Example
1648 ///
1649 /// ```no_run
1650 /// use llama_cpp_4::model::LlamaModel;
1651 /// use llama_cpp_4::model::params::LlamaModelParams;
1652 /// use llama_cpp_4::context::params::LlamaContextParams;
1653 /// use llama_cpp_4::llama_backend::LlamaBackend;
1654 ///
1655 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1656 /// let backend = LlamaBackend::init()?;
1657 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1658 /// let context = model.new_context(&backend, LlamaContextParams::default())?;
1659 /// # Ok(())
1660 /// # }
1661 /// ```
1662 #[allow(clippy::needless_pass_by_value)]
1663 pub fn new_context(
1664 &self,
1665 _: &LlamaBackend,
1666 params: LlamaContextParams,
1667 ) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
1668 // Apply TurboQuant attn-rotation preference before the KV cache is
1669 // initialised inside llama_init_from_model.
1670 let prev_rot_var = std::env::var("LLAMA_ATTN_ROT_DISABLE").ok();
1671 if params.attn_rot_disabled {
1672 // SAFETY: we restore the value right after the call.
1673 #[allow(unused_unsafe)]
1674 unsafe {
1675 std::env::set_var("LLAMA_ATTN_ROT_DISABLE", "1");
1676 }
        }
        // When attn-rotation stays enabled, any LLAMA_ATTN_ROT_DISABLE value set by the user
        // is left untouched so their explicit choice is respected.
1681
1682 let context_params = params.context_params;
1683 let context = unsafe { llama_init_from_model(self.model.as_ptr(), context_params) };
1684
1685 // Restore the env-var to its previous state.
1686 #[allow(unused_unsafe)]
1687 match prev_rot_var {
1688 Some(v) => unsafe { std::env::set_var("LLAMA_ATTN_ROT_DISABLE", v) },
1689 None if params.attn_rot_disabled => unsafe {
1690 std::env::remove_var("LLAMA_ATTN_ROT_DISABLE");
1691 },
1692 None => {}
1693 }
1694
1695 let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
1696 Ok(LlamaContext::new(self, context, params.embeddings()))
1697 }
1698
1699 /// Apply the model's chat template to a sequence of messages.
1700 ///
1701 /// This function applies the model's chat template to the provided chat messages, formatting them accordingly. The chat
1702 /// template determines the structure or style of conversation between the system and user, such as token formatting,
1703 /// role separation, and more. The template can be customized by providing an optional template string, or if `None`
1704 /// is provided, the default template used by `llama.cpp` will be applied.
1705 ///
1706 /// For more information on supported templates, visit:
1707 /// <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
1708 ///
1709 /// # Arguments
1710 ///
1711 /// - `tmpl`: An optional custom template string. If `None`, the default template will be used.
1712 /// - `chat`: A vector of `LlamaChatMessage` instances, which represent the conversation between the system and user.
1713 /// - `add_ass`: Whether to append the tokens that open an assistant turn to the end of the formatted prompt, so that generation continues as the assistant.
1714 ///
1715 /// # Errors
1716 ///
1717 /// There are several possible points of failure when applying the chat template:
1718 /// - Insufficient buffer size to hold the formatted chat (this will return `ApplyChatTemplateError::BuffSizeError`).
1719 /// - If the template or messages cannot be processed properly, various errors from `ApplyChatTemplateError` may occur.
1720 ///
1721 /// # Example
1722 ///
1723 /// ```no_run
1724 /// use llama_cpp_4::model::{LlamaModel, LlamaChatMessage};
1725 /// use llama_cpp_4::model::params::LlamaModelParams;
1726 /// use llama_cpp_4::llama_backend::LlamaBackend;
1727 ///
1728 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1729 /// let backend = LlamaBackend::init()?;
1730 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1731 /// let chat = vec![
1732 /// LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
1733 /// LlamaChatMessage::new("assistant".to_string(), "Hi! How can I assist you today?".to_string())?,
1734 /// ];
1735 /// let formatted_chat = model.apply_chat_template(None, &chat, true)?;
1736 /// # Ok(())
1737 /// # }
1738 /// ```
1739 ///
1740 /// # Notes
1741 ///
1742 /// The internal buffer starts at four times the total message length, with a 4 KiB minimum; if `llama_chat_apply_template` reports that more space was needed, the call is retried once with the exact required size.
1743 /// # Panics
1744 ///
1745 /// Panics if the buffer length exceeds `i32::MAX`.
1746 #[tracing::instrument(skip_all)]
1747 pub fn apply_chat_template(
1748 &self,
1749 tmpl: Option<&str>,
1750 chat: &[LlamaChatMessage],
1751 add_ass: bool,
1752 ) -> Result<String, ApplyChatTemplateError> {
1753 // Compute the total byte length of the raw messages from the original
1754 // LlamaChatMessage slice before building the sys-type vec below.
1755 let message_length = chat.iter().fold(0usize, |acc, c| {
1756 acc + c.role.to_bytes().len() + c.content.to_bytes().len()
1757 });
1758
1759 // Build our llama_cpp_sys chat messages (raw pointers into CStrings).
1760 let chat_sys: Vec<llama_chat_message> = chat
1761 .iter()
1762 .map(|c| llama_chat_message {
1763 role: c.role.as_ptr(),
1764 content: c.content.as_ptr(),
1765 })
1766 .collect();
1767
1768 // Set the tmpl pointer.
1769 let tmpl_cstring = tmpl.map(CString::new).transpose()?;
1770 let tmpl_ptr = tmpl_cstring
1771 .as_ref()
1772 .map_or(std::ptr::null(), |s| s.as_ptr());
1773
1774 // `message_length * 4` is far too small for models whose built-in chat
1775 // template adds a long default system prompt (e.g. Qwen3.5 prepends
1776 // ~80+ chars of markup even for a one-word user message). Start with
1777 // at least 4 KiB so short inputs like "hi" always have room.
1778 //
1779 // `llama_chat_apply_template` returns the number of bytes it *actually*
1780 // needed when the buffer was too small, so we retry exactly once with
1781 // that precise size rather than giving up immediately.
1782 let mut buf_size = message_length.saturating_mul(4).max(4096);
1783
1784 for _ in 0..2 {
1785 // Use u8 so that as_mut_ptr()/as_ptr() match the binding (*mut u8 / *const u8).
1786 let mut buff = vec![0u8; buf_size];
1787 let res = unsafe {
1788 llama_chat_apply_template(
1789 tmpl_ptr,
1790 chat_sys.as_ptr(),
1791 chat_sys.len(),
1792 add_ass,
1793 buff.as_mut_ptr().cast(),
1794 i32::try_from(buff.len()).expect("buffer length fits in i32"),
1795 )
1796 };
1797
1798 if res < 0 {
1799 return Err(ApplyChatTemplateError::BuffSizeError);
1800 }
1801
1802 #[allow(clippy::cast_sign_loss)]
1803 let needed = res as usize;
1804 if needed >= buf_size {
1805 // Too small, or exactly full with no room for a NUL terminator: retry with the exact size llama.cpp reported.
1806 buf_size = needed + 1; // +1 for null terminator
1807 continue;
1808 }
1809
1810 // SAFETY: `needed < buf_size`, so at most `needed` bytes were written and the
1811 // remaining zero-initialised bytes NUL-terminate the string in `buff`.
1812 let formatted = unsafe {
1813 CStr::from_ptr(buff.as_ptr().cast())
1814 .to_string_lossy()
1815 .into_owned()
1816 };
1817 return Ok(formatted);
1818 }
1819
1820 Err(ApplyChatTemplateError::BuffSizeError)
1821 }
1822
1823 /// Build a split GGUF file path for a specific chunk.
1824 ///
1825 /// This utility function creates the standardized filename for a split model chunk
1826 /// following the pattern `{prefix}-{split_no + 1:05d}-of-{split_count:05d}.gguf` (the underlying `llama_split_path` treats `split_no` as zero-based).
1827 ///
1828 /// # Arguments
1829 ///
1830 /// * `path_prefix` - The base path and filename prefix
1831 /// * `split_no` - The zero-based split index; the generated filename encodes `split_no + 1`
1832 /// * `split_count` - The total number of splits
1833 ///
1834 /// # Returns
1835 ///
1836 /// Returns the formatted split path as a String
1837 ///
1838 /// # Example
1839 ///
1840 /// ```
1841 /// use llama_cpp_4::model::LlamaModel;
1842 ///
1843 /// let path = LlamaModel::split_path("/models/llama", 1, 4);
1844 /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
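///
/// // `split_no` is zero-based, so the first chunk maps to "-00001-of-...".
/// let first = LlamaModel::split_path("/models/llama", 0, 4);
/// assert_eq!(first, "/models/llama-00001-of-00004.gguf");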
1845 /// ```
1846 ///
1847 /// # Panics
1848 ///
1849 /// Panics if the path prefix contains a null byte.
1850 #[must_use]
1851 pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
1852 let mut buffer = vec![0u8; 1024];
1853 let len = unsafe {
1854 llama_split_path(
1855 buffer.as_mut_ptr().cast::<c_char>(),
1856 buffer.len(),
1857 CString::new(path_prefix).unwrap().as_ptr(),
1858 split_no,
1859 split_count,
1860 )
1861 };
1862
1863 let len = usize::try_from(len).expect("split_path length fits in usize");
1864 buffer.truncate(len);
1865 String::from_utf8(buffer).unwrap_or_default()
1866 }
1867
1868 /// Extract the path prefix from a split filename.
1869 ///
1870 /// This function extracts the base path prefix from a split model filename,
1871 /// but only if the `split_no` and `split_count` match the pattern in the filename.
1872 ///
1873 /// # Arguments
1874 ///
1875 /// * `split_path` - The full path to the split file
1876 /// * `split_no` - The expected zero-based split index (the filename encodes `split_no + 1`)
1877 /// * `split_count` - The expected total number of splits
1878 ///
1879 /// # Returns
1880 ///
1881 /// Returns the path prefix if the pattern matches, or None if it doesn't
1882 ///
1883 /// # Example
1884 ///
1885 /// ```
1886 /// use llama_cpp_4::model::LlamaModel;
1887 ///
1888 /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 1, 4);
1889 /// assert_eq!(prefix, Some("/models/llama".to_string()));
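///
/// // A split_no / split_count pair that doesn't match the filename yields None.
/// assert_eq!(LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 2, 4), None);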
1890 /// ```
1891 ///
1892 /// # Panics
1893 ///
1894 /// Panics if the split path contains a null byte.
1895 #[must_use]
1896 pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
1897 let mut buffer = vec![0u8; 1024];
1898 let len = unsafe {
1899 llama_split_prefix(
1900 buffer.as_mut_ptr().cast::<c_char>(),
1901 buffer.len(),
1902 CString::new(split_path).unwrap().as_ptr(),
1903 split_no,
1904 split_count,
1905 )
1906 };
1907
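// A zero length means the filename did not end with the expected
// "-{split_no + 1:05d}-of-{split_count:05d}.gguf" suffix.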
1908 if len > 0 {
1909 let len = usize::try_from(len).expect("split_prefix length fits in usize");
1910 buffer.truncate(len);
1911 String::from_utf8(buffer).ok()
1912 } else {
1913 None
1914 }
1915 }
1916}
1917
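// Render a one-line human-readable summary of the model: description, layer /
// head / embedding counts, parameter count, and size in MiB.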
1918#[allow(clippy::cast_precision_loss)]
1919impl fmt::Display for LlamaModel {
1920 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1921 let desc = self.desc(256).unwrap_or_else(|_| "unknown".to_string());
1922 write!(
1923 f,
1924 "{desc} | {layers}L {heads}H {embd}E | {params} params | {size:.1} MiB",
1925 layers = self.n_layer(),
1926 heads = self.n_head(),
1927 embd = self.n_embd(),
1928 params = self.n_params(),
1929 size = self.model_size() as f64 / (1024.0 * 1024.0),
1930 )
1931 }
1932}
1933
1934impl Drop for LlamaModel {
1935 fn drop(&mut self) {
1936 unsafe { llama_model_free(self.model.as_ptr()) }
1937 }
1938}
1939
1940/// Defines the possible types of vocabulary used by the model.
1941///
1942/// The model may use different types of vocabulary depending on the tokenization method chosen during training.
1943/// This enum represents these types, specifically `BPE` (Byte Pair Encoding) and `SPM` (`SentencePiece`).
1944///
1945/// # Variants
1946///
1947/// - `BPE`: Byte Pair Encoding, a common tokenization method used in NLP tasks.
1948/// - `SPM`: `SentencePiece`, another popular tokenization method for NLP models.
1949///
1950/// # Example
1951///
1952/// ```no_run
1953/// use llama_cpp_4::model::VocabType;
1954///
1955/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1956/// let vocab_type = VocabType::BPE;
1957/// match vocab_type {
1958/// VocabType::BPE => println!("The model uses Byte Pair Encoding (BPE)"),
1959/// VocabType::SPM => println!("The model uses SentencePiece (SPM)"),
1960/// }
1961/// # Ok(())
1962/// # }
1963/// ```
1964#[repr(u32)]
1965#[derive(Debug, Eq, Copy, Clone, PartialEq)]
1966pub enum VocabType {
1967 /// Byte Pair Encoding
1968 BPE = LLAMA_VOCAB_TYPE_BPE as _,
1969 /// Sentence Piece Tokenizer
1970 SPM = LLAMA_VOCAB_TYPE_SPM as _,
1971}
1972
1973/// Error that occurs when trying to convert a `llama_vocab_type` to a `VocabType`.
1974///
1975/// This error is raised when the integer value returned by the system does not correspond to a known vocabulary type.
1976///
1977/// # Variants
1978///
1979/// - `UnknownValue`: The error is raised when the value is not a valid `llama_vocab_type`. The invalid value is returned with the error.
1980///
1981/// # Example
1982///
1983/// ```no_run
1984/// use llama_cpp_4::model::LlamaTokenTypeFromIntError;
1985///
1986/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1987/// let invalid_value = 999; // Not a valid vocabulary type
1988/// let error = LlamaTokenTypeFromIntError::UnknownValue(invalid_value);
1989/// println!("Error: {}", error);
1990/// # Ok(())
1991/// # }
1992/// ```
1993#[derive(thiserror::Error, Debug, Eq, PartialEq)]
1994pub enum LlamaTokenTypeFromIntError {
1995 /// The value is not a valid `llama_vocab_type`. Contains the int value that was invalid.
1996 #[error("Unknown Value {0}")]
1997 UnknownValue(llama_vocab_type),
1998}
1999
2000impl TryFrom<llama_vocab_type> for VocabType {
2001 type Error = LlamaTokenTypeFromIntError;
2002
2003 fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
2004 match value {
2005 LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
2006 LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
2007 unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
2008 }
2009 }
2010}