llama_cpp_4/model.rs
1//! A safe wrapper around `llama_model`.
2use std::ffi::CStr;
3use std::ffi::CString;
4use std::fmt;
5use std::num::NonZeroU16;
6use std::os::raw::{c_char, c_int};
7use std::path::Path;
8use std::ptr::NonNull;
9
10use llama_cpp_sys_4::{
11 llama_adapter_lora, llama_adapter_lora_init, llama_chat_apply_template,
12 llama_chat_builtin_templates, llama_chat_message, llama_detokenize, llama_init_from_model,
13 llama_model, llama_model_cls_label, llama_model_decoder_start_token, llama_model_desc,
14 llama_model_free, llama_model_get_vocab, llama_model_has_decoder, llama_model_has_encoder,
15 llama_model_is_diffusion, llama_model_is_hybrid, llama_model_is_recurrent,
16 llama_model_load_from_file, llama_model_load_from_splits, llama_model_meta_count,
17 llama_model_meta_key_by_index, llama_model_meta_val_str, llama_model_meta_val_str_by_index,
18 llama_model_n_cls_out, llama_model_n_ctx_train, llama_model_n_embd, llama_model_n_embd_inp,
19 llama_model_n_embd_out, llama_model_n_head, llama_model_n_head_kv, llama_model_n_layer,
20 llama_model_n_params, llama_model_n_swa, llama_model_rope_freq_scale_train,
21 llama_model_rope_type, llama_model_save_to_file, llama_model_size, llama_split_path,
22 llama_split_prefix, llama_token_to_piece, llama_tokenize, llama_vocab, llama_vocab_type,
23 LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
24};
25
26use crate::context::params::LlamaContextParams;
27use crate::context::LlamaContext;
28use crate::llama_backend::LlamaBackend;
29use crate::model::params::LlamaModelParams;
30use crate::token::LlamaToken;
31use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
32use crate::{
33 ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
34 LlamaModelLoadError, NewLlamaChatMessageError, StringFromModelError, StringToTokenError,
35 TokenToStringError,
36};
37
38pub mod params;
39
40/// A safe wrapper around `llama_model`.
41#[derive(Debug)]
42#[repr(transparent)]
43#[allow(clippy::module_name_repetitions)]
44pub struct LlamaModel {
45 pub(crate) model: NonNull<llama_model>,
46}
47
48/// A safe wrapper around `llama_vocab`.
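///
/// # Example
///
/// A minimal sketch of querying the vocabulary, assuming a model has already been loaded
/// (the model path is a placeholder):
///
/// ```no_run
/// use llama_cpp_4::model::LlamaModel;
/// use llama_cpp_4::model::params::LlamaModelParams;
/// use llama_cpp_4::llama_backend::LlamaBackend;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let backend = LlamaBackend::init()?;
/// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
/// let vocab = model.get_vocab();
/// println!("vocab size: {}", vocab.n_tokens());
/// println!("adds BOS automatically: {}", vocab.get_add_bos());
/// # Ok(())
/// # }
/// ```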
49#[derive(Debug)]
50#[repr(transparent)]
51#[allow(clippy::module_name_repetitions)]
52pub struct LlamaVocab {
53 pub(crate) vocab: NonNull<llama_vocab>,
54}
55
56impl LlamaVocab {
57 /// Get the number of tokens in the vocabulary.
58 #[must_use]
59 pub fn n_tokens(&self) -> i32 {
60 unsafe { llama_cpp_sys_4::llama_vocab_n_tokens(self.vocab.as_ref()) }
61 }
62
63 /// Get the vocabulary type.
64 #[must_use]
65 pub fn vocab_type(&self) -> u32 {
66 unsafe {
67 llama_cpp_sys_4::llama_vocab_type(self.vocab.as_ref())
68 .try_into()
69 .unwrap()
70 }
71 }
72
73 /// Get the BOS token.
74 #[must_use]
75 pub fn bos(&self) -> LlamaToken {
76 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_bos(self.vocab.as_ref()) })
77 }
78
79 /// Get the EOS token.
80 #[must_use]
81 pub fn eos(&self) -> LlamaToken {
82 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eos(self.vocab.as_ref()) })
83 }
84
85 /// Get the EOT (end of turn) token.
86 #[must_use]
87 pub fn eot(&self) -> LlamaToken {
88 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eot(self.vocab.as_ref()) })
89 }
90
91 /// Get the CLS (classification) token.
92 #[must_use]
93 pub fn cls(&self) -> LlamaToken {
94 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_cls(self.vocab.as_ref()) })
95 }
96
97 /// Get the SEP (separator) token.
98 #[must_use]
99 pub fn sep(&self) -> LlamaToken {
100 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_sep(self.vocab.as_ref()) })
101 }
102
103 /// Get the NL (newline) token.
104 #[must_use]
105 pub fn nl(&self) -> LlamaToken {
106 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_nl(self.vocab.as_ref()) })
107 }
108
109 /// Get the PAD (padding) token.
110 #[must_use]
111 pub fn pad(&self) -> LlamaToken {
112 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_pad(self.vocab.as_ref()) })
113 }
114
115 /// Get the FIM prefix token.
116 #[must_use]
117 pub fn fim_pre(&self) -> LlamaToken {
118 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pre(self.vocab.as_ref()) })
119 }
120
121 /// Get the FIM suffix token.
122 #[must_use]
123 pub fn fim_suf(&self) -> LlamaToken {
124 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_suf(self.vocab.as_ref()) })
125 }
126
127 /// Get the FIM middle token.
128 #[must_use]
129 pub fn fim_mid(&self) -> LlamaToken {
130 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_mid(self.vocab.as_ref()) })
131 }
132
133 /// Get the FIM padding token.
134 #[must_use]
135 pub fn fim_pad(&self) -> LlamaToken {
136 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pad(self.vocab.as_ref()) })
137 }
138
139 /// Get the FIM repository token.
140 #[must_use]
141 pub fn fim_rep(&self) -> LlamaToken {
142 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_rep(self.vocab.as_ref()) })
143 }
144
145 /// Get the FIM separator token.
146 #[must_use]
147 pub fn fim_sep(&self) -> LlamaToken {
148 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_sep(self.vocab.as_ref()) })
149 }
150
151 /// Check whether BOS should be added.
152 #[must_use]
153 pub fn get_add_bos(&self) -> bool {
154 unsafe { llama_cpp_sys_4::llama_vocab_get_add_bos(self.vocab.as_ref()) }
155 }
156
157 /// Check whether EOS should be added.
158 #[must_use]
159 pub fn get_add_eos(&self) -> bool {
160 unsafe { llama_cpp_sys_4::llama_vocab_get_add_eos(self.vocab.as_ref()) }
161 }
162
163 /// Check whether SEP should be added.
164 #[must_use]
165 pub fn get_add_sep(&self) -> bool {
166 unsafe { llama_cpp_sys_4::llama_vocab_get_add_sep(self.vocab.as_ref()) }
167 }
168
169 /// Get the text representation of a token.
170 ///
171 /// # Errors
172 ///
173 /// Returns an error if the text pointer is null or not valid UTF-8.
174 pub fn get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
175 let ptr = unsafe { llama_cpp_sys_4::llama_vocab_get_text(self.vocab.as_ref(), token.0) };
176 if ptr.is_null() {
177 return Err(StringFromModelError::ReturnedError(-1));
178 }
179 let cstr = unsafe { CStr::from_ptr(ptr) };
180 cstr.to_str().map_err(StringFromModelError::Utf8Error)
181 }
182
183 /// Get the score of a token.
184 #[must_use]
185 pub fn get_score(&self, token: LlamaToken) -> f32 {
186 unsafe { llama_cpp_sys_4::llama_vocab_get_score(self.vocab.as_ref(), token.0) }
187 }
188
189 /// Get the attributes of a token.
190 #[must_use]
191 pub fn get_attr(&self, token: LlamaToken) -> u32 {
192 unsafe {
193 llama_cpp_sys_4::llama_vocab_get_attr(self.vocab.as_ref(), token.0)
194 .try_into()
195 .unwrap()
196 }
197 }
198
199 /// Check if a token is a control token.
200 #[must_use]
201 pub fn is_control(&self, token: LlamaToken) -> bool {
202 unsafe { llama_cpp_sys_4::llama_vocab_is_control(self.vocab.as_ref(), token.0) }
203 }
204
205 /// Check if a token is an end-of-generation token.
206 #[must_use]
207 pub fn is_eog(&self, token: LlamaToken) -> bool {
208 unsafe { llama_cpp_sys_4::llama_vocab_is_eog(self.vocab.as_ref(), token.0) }
209 }
210
    /// Get the MASK token for the vocabulary.
212 #[must_use]
213 pub fn mask(&self) -> LlamaToken {
214 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_mask(self.vocab.as_ref()) })
215 }
216}
217
218/// A safe wrapper around `llama_adapter_lora`.
219#[derive(Debug)]
220#[repr(transparent)]
221#[allow(clippy::module_name_repetitions)]
222pub struct LlamaLoraAdapter {
223 pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
224}
225
226impl LlamaLoraAdapter {
227 /// Get the number of metadata key-value pairs in the adapter.
228 #[must_use]
229 pub fn meta_count(&self) -> i32 {
230 unsafe { llama_cpp_sys_4::llama_adapter_meta_count(self.lora_adapter.as_ptr()) }
231 }
232
233 /// Get a metadata key by index.
234 ///
235 /// # Errors
236 ///
237 /// Returns an error if the index is out of range or the key is not valid UTF-8.
238 #[allow(clippy::cast_sign_loss)]
239 pub fn meta_key_by_index(
240 &self,
241 index: i32,
242 buf_size: usize,
243 ) -> Result<String, StringFromModelError> {
244 let mut buf = vec![0u8; buf_size];
245 let ret = unsafe {
246 llama_cpp_sys_4::llama_adapter_meta_key_by_index(
247 self.lora_adapter.as_ptr(),
248 index,
249 buf.as_mut_ptr().cast::<c_char>(),
250 buf_size,
251 )
252 };
253 if ret < 0 {
254 return Err(StringFromModelError::ReturnedError(ret));
255 }
256 let len = ret as usize;
257 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
258 Ok(s.to_owned())
259 }
260
261 /// Get a metadata value by key name.
262 ///
263 /// # Errors
264 ///
265 /// Returns an error if the key is not found or the value is not valid UTF-8.
266 #[allow(clippy::cast_sign_loss)]
267 pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
268 let c_key = CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
269 let mut buf = vec![0u8; buf_size];
270 let ret = unsafe {
271 llama_cpp_sys_4::llama_adapter_meta_val_str(
272 self.lora_adapter.as_ptr(),
273 c_key.as_ptr(),
274 buf.as_mut_ptr().cast::<c_char>(),
275 buf_size,
276 )
277 };
278 if ret < 0 {
279 return Err(StringFromModelError::ReturnedError(ret));
280 }
281 let len = ret as usize;
282 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
283 Ok(s.to_owned())
284 }
285
286 /// Get a metadata value by index.
287 ///
288 /// # Errors
289 ///
290 /// Returns an error if the index is out of range or the value is not valid UTF-8.
291 #[allow(clippy::cast_sign_loss)]
292 pub fn meta_val_str_by_index(
293 &self,
294 index: i32,
295 buf_size: usize,
296 ) -> Result<String, StringFromModelError> {
297 let mut buf = vec![0u8; buf_size];
298 let ret = unsafe {
299 llama_cpp_sys_4::llama_adapter_meta_val_str_by_index(
300 self.lora_adapter.as_ptr(),
301 index,
302 buf.as_mut_ptr().cast::<c_char>(),
303 buf_size,
304 )
305 };
306 if ret < 0 {
307 return Err(StringFromModelError::ReturnedError(ret));
308 }
309 let len = ret as usize;
310 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
311 Ok(s.to_owned())
312 }
313
314 /// Get all metadata as a list of `(key, value)` pairs.
315 ///
316 /// # Errors
317 ///
318 /// Returns an error if any key or value cannot be read or is not valid UTF-8.
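    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the adapter was initialised via
    /// [`LlamaModel::lora_adapter_init`] (both paths are placeholders):
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
    /// for (key, value) in adapter.metadata()? {
    ///     println!("{key} = {value}");
    /// }
    /// # Ok(())
    /// # }
    /// ```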
319 #[allow(clippy::cast_sign_loss)]
320 pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
321 let count = self.meta_count();
322 let mut result = Vec::with_capacity(count as usize);
323 for i in 0..count {
324 let key = self.meta_key_by_index(i, 256)?;
325 let val = self.meta_val_str_by_index(i, 4096)?;
326 result.push((key, val));
327 }
328 Ok(result)
329 }
330
331 /// Get the number of invocation tokens for this adapter.
332 #[must_use]
333 pub fn n_invocation_tokens(&self) -> u64 {
334 unsafe {
335 llama_cpp_sys_4::llama_adapter_get_alora_n_invocation_tokens(self.lora_adapter.as_ptr())
336 }
337 }
338
339 /// Get the invocation tokens for this adapter.
340 ///
341 /// Returns an empty slice if there are no invocation tokens.
342 #[must_use]
343 #[allow(clippy::cast_possible_truncation)]
344 pub fn invocation_tokens(&self) -> &[LlamaToken] {
345 let n = self.n_invocation_tokens() as usize;
346 if n == 0 {
347 return &[];
348 }
349 let ptr = unsafe {
350 llama_cpp_sys_4::llama_adapter_get_alora_invocation_tokens(self.lora_adapter.as_ptr())
351 };
352 if ptr.is_null() {
353 return &[];
354 }
355 // LlamaToken is repr(transparent) over llama_token (i32), so this cast is safe
356 unsafe { std::slice::from_raw_parts(ptr.cast::<LlamaToken>(), n) }
357 }
358}
359
360impl Drop for LlamaLoraAdapter {
361 fn drop(&mut self) {
362 unsafe {
363 llama_cpp_sys_4::llama_adapter_lora_free(self.lora_adapter.as_ptr());
364 }
365 }
366}
367
/// A safe wrapper around `llama_chat_message`.
369#[derive(Debug, Eq, PartialEq, Clone)]
370pub struct LlamaChatMessage {
371 role: CString,
372 content: CString,
373}
374
375impl LlamaChatMessage {
376 /// Create a new `LlamaChatMessage`.
377 ///
378 /// # Errors
379 ///
380 /// Returns [`NewLlamaChatMessageError`] if the role or content contains a null byte.
381 pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
382 Ok(Self {
383 role: CString::new(role)?,
384 content: CString::new(content)?,
385 })
386 }
387}
388
/// Whether a BOS (beginning-of-stream) token should be prepended when tokenizing a string.
390#[derive(Debug, Clone, Copy, PartialEq, Eq)]
391pub enum AddBos {
392 /// Add the beginning of stream token to the start of the string.
393 Always,
394 /// Do not add the beginning of stream token to the start of the string.
395 Never,
396}
397
/// How special tokens should be handled when converting between tokens and text.
399#[derive(Debug, Clone, Copy, PartialEq, Eq)]
400pub enum Special {
401 /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
402 Tokenize,
403 /// Treat special and/or control tokens as plaintext.
404 Plaintext,
405}
406
407unsafe impl Send for LlamaModel {}
408
409unsafe impl Sync for LlamaModel {}
410
411impl LlamaModel {
412 /// Retrieves the vocabulary associated with the current Llama model.
413 ///
414 /// This method fetches the vocabulary from the underlying model using an unsafe
415 /// FFI call. The returned `LlamaVocab` struct contains a non-null pointer to
416 /// the vocabulary data, which is wrapped in a `NonNull` for safety.
417 ///
418 /// # Safety
419 /// This method uses an unsafe block to call a C function (`llama_model_get_vocab`),
420 /// which is assumed to return a valid pointer to the vocabulary. The caller should
421 /// ensure that the model object is properly initialized and valid before calling
422 /// this method, as dereferencing invalid pointers can lead to undefined behavior.
423 ///
424 /// # Returns
425 /// A `LlamaVocab` struct containing the vocabulary of the model.
426 ///
427 /// # Panics
428 ///
429 /// Panics if the underlying C function returns a null pointer.
430 ///
431 /// # Example
432 /// ```rust,ignore
433 /// let vocab = model.get_vocab();
434 /// ```
435 #[must_use]
436 pub fn get_vocab(&self) -> LlamaVocab {
437 let llama_vocab = unsafe { llama_model_get_vocab(self.model.as_ptr()) }.cast_mut();
438
439 LlamaVocab {
440 vocab: NonNull::new(llama_vocab).unwrap(),
441 }
442 }
    /// Get the context size (in tokens) that the model was trained with.
    ///
    /// This is the maximum context length used during training, returned as a `u32`.
    ///
    /// # Panics
    ///
    /// Panics if the value reported by llama.cpp does not fit into a `u32`. This should not
    /// happen in practice, since llama.cpp returns a non-negative `c_int`.
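    ///
    /// # Example
    ///
    /// A minimal sketch (the model path is a placeholder):
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// println!("trained context size: {}", model.n_ctx_train());
    /// # Ok(())
    /// # }
    /// ```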
452 #[must_use]
453 pub fn n_ctx_train(&self) -> u32 {
454 let n_ctx_train = unsafe { llama_model_n_ctx_train(self.model.as_ptr()) };
455 u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
456 }
457
458 /// Get all tokens in the model.
459 ///
460 /// This function returns an iterator over all the tokens in the model. Each item in the iterator is a tuple
461 /// containing a `LlamaToken` and its corresponding string representation (or an error if the conversion fails).
462 ///
463 /// # Parameters
464 ///
465 /// - `special`: The `Special` value that determines how special tokens (like BOS, EOS, etc.) are handled.
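    ///
    /// # Example
    ///
    /// A minimal sketch that prints the text of the first few vocabulary entries (the model path
    /// is a placeholder):
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{LlamaModel, Special};
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// for (_token, text) in model.tokens(Special::Plaintext).take(10) {
    ///     println!("{}", text.unwrap_or_default());
    /// }
    /// # Ok(())
    /// # }
    /// ```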
466 pub fn tokens(
467 &self,
468 special: Special,
469 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
470 (0..self.n_vocab())
471 .map(LlamaToken::new)
472 .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
473 }
474
475 /// Get the beginning of stream token.
476 ///
477 /// This function returns the token that represents the beginning of a stream (BOS token).
478 #[must_use]
479 pub fn token_bos(&self) -> LlamaToken {
480 self.get_vocab().bos()
481 }
482
483 /// Get the end of stream token.
484 ///
485 /// This function returns the token that represents the end of a stream (EOS token).
486 #[must_use]
487 pub fn token_eos(&self) -> LlamaToken {
488 self.get_vocab().eos()
489 }
490
491 /// Get the newline token.
492 ///
493 /// This function returns the token that represents a newline character.
494 #[must_use]
495 pub fn token_nl(&self) -> LlamaToken {
496 self.get_vocab().nl()
497 }
498
499 /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
500 ///
501 /// This function returns `true` if the provided token signifies the end of generation or end of sequence,
502 /// such as EOS or other special tokens.
503 ///
504 /// # Parameters
505 ///
506 /// - `token`: The `LlamaToken` to check.
507 ///
508 /// # Returns
509 ///
510 /// - `true` if the token is an end-of-generation token, otherwise `false`.
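    ///
    /// # Example
    ///
    /// A minimal sketch of a stop-condition check; `next_token` stands in for a token produced by
    /// sampling, and the model path is a placeholder.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let next_token = model.token_eos(); // placeholder for a sampled token
    /// if model.is_eog_token(next_token) {
    ///     println!("generation finished");
    /// }
    /// # Ok(())
    /// # }
    /// ```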
511 #[must_use]
512 pub fn is_eog_token(&self, token: LlamaToken) -> bool {
513 self.get_vocab().is_eog(token)
514 }
515
516 /// Get the classification token.
517 #[must_use]
518 pub fn token_cls(&self) -> LlamaToken {
519 self.get_vocab().cls()
520 }
521
522 /// Get the end-of-turn token.
523 #[must_use]
524 pub fn token_eot(&self) -> LlamaToken {
525 self.get_vocab().eot()
526 }
527
528 /// Get the padding token.
529 #[must_use]
530 pub fn token_pad(&self) -> LlamaToken {
531 self.get_vocab().pad()
532 }
533
534 /// Get the separator token.
535 #[must_use]
536 pub fn token_sep(&self) -> LlamaToken {
537 self.get_vocab().sep()
538 }
539
540 /// Get the fill-in-the-middle prefix token.
541 #[must_use]
542 pub fn token_fim_pre(&self) -> LlamaToken {
543 self.get_vocab().fim_pre()
544 }
545
546 /// Get the fill-in-the-middle suffix token.
547 #[must_use]
548 pub fn token_fim_suf(&self) -> LlamaToken {
549 self.get_vocab().fim_suf()
550 }
551
552 /// Get the fill-in-the-middle middle token.
553 #[must_use]
554 pub fn token_fim_mid(&self) -> LlamaToken {
555 self.get_vocab().fim_mid()
556 }
557
558 /// Get the fill-in-the-middle padding token.
559 #[must_use]
560 pub fn token_fim_pad(&self) -> LlamaToken {
561 self.get_vocab().fim_pad()
562 }
563
564 /// Get the fill-in-the-middle repository token.
565 #[must_use]
566 pub fn token_fim_rep(&self) -> LlamaToken {
567 self.get_vocab().fim_rep()
568 }
569
570 /// Get the fill-in-the-middle separator token.
571 #[must_use]
572 pub fn token_fim_sep(&self) -> LlamaToken {
573 self.get_vocab().fim_sep()
574 }
575
576 /// Check if a token is a control token.
577 #[must_use]
578 pub fn token_is_control(&self, token: LlamaToken) -> bool {
579 self.get_vocab().is_control(token)
580 }
581
582 /// Get the score of a token.
583 #[must_use]
584 pub fn token_get_score(&self, token: LlamaToken) -> f32 {
585 self.get_vocab().get_score(token)
586 }
587
588 /// Get the raw text of a token.
589 ///
590 /// # Errors
591 ///
592 /// Returns an error if the token text is null or not valid UTF-8.
593 pub fn token_get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
594 let ptr = unsafe {
595 llama_cpp_sys_4::llama_vocab_get_text(self.get_vocab().vocab.as_ref(), token.0)
596 };
597 if ptr.is_null() {
598 return Err(StringFromModelError::ReturnedError(-1));
599 }
600 let cstr = unsafe { CStr::from_ptr(ptr) };
601 cstr.to_str().map_err(StringFromModelError::Utf8Error)
602 }
603
604 /// Check if a BOS token should be added when tokenizing.
605 #[must_use]
606 pub fn add_bos_token(&self) -> bool {
607 self.get_vocab().get_add_bos()
608 }
609
610 /// Check if an EOS token should be added when tokenizing.
611 #[must_use]
612 pub fn add_eos_token(&self) -> bool {
613 self.get_vocab().get_add_eos()
614 }
615
616 /// Get the decoder start token.
617 ///
618 /// This function returns the token used to signal the start of decoding (i.e., the token used at the start
619 /// of a sequence generation).
620 #[must_use]
621 pub fn decode_start_token(&self) -> LlamaToken {
622 let token = unsafe { llama_model_decoder_start_token(self.model.as_ptr()) };
623 LlamaToken(token)
624 }
625
626 /// Convert a single token to a string.
627 ///
628 /// This function converts a `LlamaToken` into its string representation.
629 ///
630 /// # Errors
631 ///
632 /// This function returns an error if the token cannot be converted to a string. For more details, refer to
633 /// [`TokenToStringError`].
634 ///
635 /// # Parameters
636 ///
637 /// - `token`: The `LlamaToken` to convert.
638 /// - `special`: The `Special` value used to handle special tokens.
639 pub fn token_to_str(
640 &self,
641 token: LlamaToken,
642 special: Special,
643 ) -> Result<String, TokenToStringError> {
644 self.token_to_str_with_size(token, 32, special)
645 }
646
647 /// Convert a single token to bytes.
648 ///
649 /// This function converts a `LlamaToken` into a byte representation.
650 ///
651 /// # Errors
652 ///
653 /// This function returns an error if the token cannot be converted to bytes. For more details, refer to
654 /// [`TokenToStringError`].
655 ///
656 /// # Parameters
657 ///
658 /// - `token`: The `LlamaToken` to convert.
659 /// - `special`: The `Special` value used to handle special tokens.
660 pub fn token_to_bytes(
661 &self,
662 token: LlamaToken,
663 special: Special,
664 ) -> Result<Vec<u8>, TokenToStringError> {
665 self.token_to_bytes_with_size(token, 32, special, None)
666 }
667
668 /// Convert a vector of tokens to a single string.
669 ///
670 /// This function takes a slice of `LlamaToken`s and converts them into a single string, concatenating their
671 /// string representations.
672 ///
673 /// # Errors
674 ///
675 /// This function returns an error if any token cannot be converted to a string. For more details, refer to
676 /// [`TokenToStringError`].
677 ///
678 /// # Parameters
679 ///
680 /// - `tokens`: A slice of `LlamaToken`s to convert.
681 /// - `special`: The `Special` value used to handle special tokens.
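    ///
    /// # Example
    ///
    /// A minimal round-trip sketch with [`str_to_token`](Self::str_to_token); the model path is a
    /// placeholder, and depending on the tokenizer the round-tripped text may differ from the
    /// input in leading whitespace.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{AddBos, LlamaModel, Special};
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Never)?;
    /// let text = model.tokens_to_str(&tokens, Special::Plaintext)?;
    /// println!("{text}");
    /// # Ok(())
    /// # }
    /// ```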
682 pub fn tokens_to_str(
683 &self,
684 tokens: &[LlamaToken],
685 special: Special,
686 ) -> Result<String, TokenToStringError> {
687 let mut builder = String::with_capacity(tokens.len() * 4);
688 for str in tokens
689 .iter()
690 .copied()
691 .map(|t| self.token_to_str(t, special))
692 {
693 builder += &str?;
694 }
695 Ok(builder)
696 }
697
698 /// Convert a string to a vector of tokens.
699 ///
700 /// This function converts a string into a vector of `LlamaToken`s. The function will tokenize the string
701 /// and return the corresponding tokens.
702 ///
703 /// # Errors
704 ///
705 /// - This function will return an error if the input string contains a null byte.
706 ///
707 /// # Panics
708 ///
709 /// - This function will panic if the number of tokens exceeds `usize::MAX`.
710 ///
711 /// # Example
712 ///
713 /// ```no_run
714 /// use llama_cpp_4::model::LlamaModel;
715 ///
716 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
717 /// use std::path::Path;
718 /// use llama_cpp_4::model::AddBos;
719 /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
720 /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
721 /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
722 /// # Ok(())
723 /// # }
724 /// ```
725 pub fn str_to_token(
726 &self,
727 str: &str,
728 add_bos: AddBos,
729 ) -> Result<Vec<LlamaToken>, StringToTokenError> {
730 let add_bos = match add_bos {
731 AddBos::Always => true,
732 AddBos::Never => false,
733 };
734
735 let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
736 let mut buffer = Vec::with_capacity(tokens_estimation);
737
738 let c_string = CString::new(str)?;
739 let buffer_capacity =
740 c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
741
742 let size = unsafe {
743 llama_tokenize(
744 self.get_vocab().vocab.as_ref(),
745 c_string.as_ptr(),
746 c_int::try_from(c_string.as_bytes().len())?,
747 buffer.as_mut_ptr(),
748 buffer_capacity,
749 add_bos,
750 true,
751 )
752 };
753
        // If the buffer was too small, llama.cpp returns the negative of the required size.
        // Reserve exactly that much and retry; the second call cannot fail, so `size` is
        // guaranteed to be non-negative afterwards.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("-size fits into usize"));
758 unsafe {
759 llama_tokenize(
760 self.get_vocab().vocab.as_ref(),
761 c_string.as_ptr(),
762 c_int::try_from(c_string.as_bytes().len())?,
763 buffer.as_mut_ptr(),
764 -size,
765 add_bos,
766 true,
767 )
768 }
769 } else {
770 size
771 };
772
        let size = usize::try_from(size).expect("size is non-negative and fits into usize");
774
775 // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
776 unsafe { buffer.set_len(size) }
777 Ok(buffer.into_iter().map(LlamaToken).collect())
778 }
779
780 /// Get the type of a token.
781 ///
782 /// This function retrieves the attributes associated with a given token. The attributes are typically used to
783 /// understand whether the token represents a special type of token (e.g., beginning-of-sequence (BOS), end-of-sequence (EOS),
784 /// control tokens, etc.).
785 ///
786 /// # Panics
787 ///
788 /// - This function will panic if the token type is unknown or cannot be converted to a valid `LlamaTokenAttrs`.
789 ///
790 /// # Example
791 ///
792 /// ```no_run
793 /// use llama_cpp_4::model::LlamaModel;
794 /// use llama_cpp_4::model::params::LlamaModelParams;
795 /// use llama_cpp_4::llama_backend::LlamaBackend;
796 /// use llama_cpp_4::token::LlamaToken;
797 ///
798 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
799 /// let backend = LlamaBackend::init()?;
800 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
801 /// let token = LlamaToken::new(42);
802 /// let token_attrs = model.token_attr(token);
803 /// # Ok(())
804 /// # }
805 /// ```
806 #[must_use]
807 pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
808 let token_type =
809 unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.get_vocab().vocab.as_ref(), id) };
810 LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
811 }
812
813 /// Detokenize a slice of tokens into a string.
814 ///
815 /// This is the inverse of [`str_to_token`](Self::str_to_token).
816 ///
817 /// # Parameters
818 ///
819 /// - `tokens`: The tokens to detokenize.
820 /// - `remove_special`: If `true`, special tokens are removed from the output.
821 /// - `unparse_special`: If `true`, special tokens are rendered as their text representation.
822 ///
823 /// # Errors
824 ///
825 /// Returns an error if the detokenized text is not valid UTF-8.
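    ///
    /// # Example
    ///
    /// A minimal sketch that tokenizes a string and detokenizes it again (the model path is a
    /// placeholder):
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{AddBos, LlamaModel};
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Never)?;
    /// let text = model.detokenize(&tokens, false, false)?;
    /// println!("{text}");
    /// # Ok(())
    /// # }
    /// ```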
826 #[allow(
827 clippy::cast_possible_truncation,
828 clippy::cast_possible_wrap,
829 clippy::cast_sign_loss
830 )]
831 pub fn detokenize(
832 &self,
833 tokens: &[LlamaToken],
834 remove_special: bool,
835 unparse_special: bool,
836 ) -> Result<String, StringFromModelError> {
837 // First call with empty buffer to get required size
838 let n_tokens = tokens.len() as i32;
839 let token_ptr = tokens.as_ptr().cast::<llama_cpp_sys_4::llama_token>();
840 let needed = unsafe {
841 llama_detokenize(
842 self.get_vocab().vocab.as_ref(),
843 token_ptr,
844 n_tokens,
845 std::ptr::null_mut(),
846 0,
847 remove_special,
848 unparse_special,
849 )
850 };
851 // llama_detokenize returns negative required size when buffer is too small
852 let buf_size = if needed < 0 {
853 (-needed) as usize
854 } else {
855 needed as usize
856 };
857 let mut buf = vec![0u8; buf_size];
858 let ret = unsafe {
859 llama_detokenize(
860 self.get_vocab().vocab.as_ref(),
861 token_ptr,
862 n_tokens,
863 buf.as_mut_ptr().cast::<c_char>(),
864 buf_size as i32,
865 remove_special,
866 unparse_special,
867 )
868 };
869 if ret < 0 {
870 return Err(StringFromModelError::ReturnedError(ret));
871 }
872 let len = ret as usize;
873 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
874 Ok(s.to_owned())
875 }
876
877 /// Convert a token to a string with a specified buffer size.
878 ///
879 /// This function allows you to convert a token into a string, with the ability to specify a buffer size for the operation.
    /// It is generally recommended to use [`LlamaModel::token_to_str`] instead, whose default
    /// buffer of 32 bytes is sufficient for most tokens.
882 ///
883 /// # Errors
884 ///
885 /// - If the token type is unknown, an error will be returned.
886 /// - If the resultant token exceeds the provided `buffer_size`, an error will occur.
887 /// - If the token string returned by `llama-cpp` is not valid UTF-8, it will return an error.
888 ///
889 /// # Panics
890 ///
891 /// - This function will panic if the `buffer_size` does not fit into a `c_int`.
892 /// - It will also panic if the size returned from `llama-cpp` does not fit into a `usize`, which should typically never happen.
893 ///
894 /// # Example
895 ///
896 /// ```no_run
897 /// use llama_cpp_4::model::{LlamaModel, Special};
898 /// use llama_cpp_4::model::params::LlamaModelParams;
899 /// use llama_cpp_4::llama_backend::LlamaBackend;
900 /// use llama_cpp_4::token::LlamaToken;
901 ///
902 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
903 /// let backend = LlamaBackend::init()?;
904 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
905 /// let token = LlamaToken::new(42);
906 /// let token_string = model.token_to_str_with_size(token, 32, Special::Plaintext)?;
907 /// # Ok(())
908 /// # }
909 /// ```
910 pub fn token_to_str_with_size(
911 &self,
912 token: LlamaToken,
913 buffer_size: usize,
914 special: Special,
915 ) -> Result<String, TokenToStringError> {
916 let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
917 Ok(String::from_utf8(bytes)?)
918 }
919
920 /// Convert a token to bytes with a specified buffer size.
921 ///
    /// Generally you should use [`LlamaModel::token_to_bytes`] instead, whose default buffer of
    /// 32 bytes is enough for most tokens.
924 ///
925 /// # Errors
926 ///
927 /// - if the token type is unknown
928 /// - the resultant token is larger than `buffer_size`.
929 ///
930 /// # Panics
931 ///
932 /// - This function will panic if `buffer_size` cannot fit into a `c_int`.
933 /// - It will also panic if the size returned from `llama-cpp` cannot be converted to `usize` (which should not happen).
934 ///
935 /// # Example
936 ///
937 /// ```no_run
938 /// use llama_cpp_4::model::{LlamaModel, Special};
939 /// use llama_cpp_4::model::params::LlamaModelParams;
940 /// use llama_cpp_4::llama_backend::LlamaBackend;
941 /// use llama_cpp_4::token::LlamaToken;
942 ///
943 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
944 /// let backend = LlamaBackend::init()?;
945 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
946 /// let token = LlamaToken::new(42);
947 /// let token_bytes = model.token_to_bytes_with_size(token, 32, Special::Plaintext, None)?;
948 /// # Ok(())
949 /// # }
950 /// ```
951 pub fn token_to_bytes_with_size(
952 &self,
953 token: LlamaToken,
954 buffer_size: usize,
955 special: Special,
956 lstrip: Option<NonZeroU16>,
957 ) -> Result<Vec<u8>, TokenToStringError> {
958 if token == self.token_nl() {
959 return Ok(String::from("\n").into_bytes());
960 }
961
962 // unsure what to do with this in the face of the 'special' arg + attr changes
963 let attrs = self.token_attr(token);
964 if (attrs.contains(LlamaTokenAttr::Control)
965 && (token == self.token_bos() || token == self.token_eos()))
966 || attrs.is_empty()
967 || attrs
968 .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
969 {
970 return Ok(Vec::new());
971 }
972
973 let special = match special {
974 Special::Tokenize => true,
975 Special::Plaintext => false,
976 };
977
978 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
979 let len = string.as_bytes().len();
980 let len = c_int::try_from(len).expect("length fits into c_int");
981 let buf = string.into_raw();
982 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
983 let size = unsafe {
984 llama_token_to_piece(
985 self.get_vocab().vocab.as_ref(),
986 token.0,
987 buf,
988 len,
989 lstrip,
990 special,
991 )
992 };
993
994 match size {
995 0 => Err(TokenToStringError::UnknownTokenType),
996 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
997 size => {
998 let string = unsafe { CString::from_raw(buf) };
999 let mut bytes = string.into_bytes();
1000 let len = usize::try_from(size).expect("size is positive and fits into usize");
1001 bytes.truncate(len);
1002 Ok(bytes)
1003 }
1004 }
1005 }
    /// The number of tokens in the model's vocabulary.
    ///
    /// This function returns the vocabulary size as an `i32`, matching the type used by the
    /// underlying llama.cpp library.
1010 ///
1011 /// # Example
1012 ///
1013 /// ```no_run
1014 /// use llama_cpp_4::model::LlamaModel;
1015 /// use llama_cpp_4::model::params::LlamaModelParams;
1016 /// use llama_cpp_4::llama_backend::LlamaBackend;
1017 ///
1018 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1019 /// let backend = LlamaBackend::init()?;
1020 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1021 /// let n_vocab = model.n_vocab();
1022 /// # Ok(())
1023 /// # }
1024 /// ```
1025 #[must_use]
1026 pub fn n_vocab(&self) -> i32 {
1027 self.get_vocab().n_tokens()
1028 }
1029
1030 /// The type of vocab the model was trained on.
1031 ///
1032 /// This function returns the type of vocabulary used by the model, such as whether it is based on byte-pair encoding (BPE),
1033 /// word-level tokens, or another tokenization scheme.
1034 ///
1035 /// # Panics
1036 ///
1037 /// - This function will panic if `llama-cpp` emits a vocab type that is not recognized or is invalid for this library.
1038 ///
1039 /// # Example
1040 ///
1041 /// ```no_run
1042 /// use llama_cpp_4::model::LlamaModel;
1043 /// use llama_cpp_4::model::params::LlamaModelParams;
1044 /// use llama_cpp_4::llama_backend::LlamaBackend;
1045 ///
1046 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1047 /// let backend = LlamaBackend::init()?;
1048 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1049 /// let vocab_type = model.vocab_type();
1050 /// # Ok(())
1051 /// # }
1052 /// ```
1053 #[must_use]
1054 pub fn vocab_type(&self) -> VocabType {
1055 let vocab_type = unsafe { llama_vocab_type(self.get_vocab().vocab.as_ref()) };
1056 VocabType::try_from(vocab_type).expect("invalid vocab type")
1057 }
1058
1059 /// Returns the number of embedding dimensions for the model.
1060 ///
1061 /// This function retrieves the number of embeddings (or embedding dimensions) used by the model. It is typically
1062 /// used for analyzing model architecture and setting up context parameters or other model configuration aspects.
1063 ///
1064 /// # Panics
1065 ///
1066 /// - This function may panic if the underlying `llama-cpp` library returns an invalid embedding dimension value.
1067 ///
1068 /// # Example
1069 ///
1070 /// ```no_run
1071 /// use llama_cpp_4::model::LlamaModel;
1072 /// use llama_cpp_4::model::params::LlamaModelParams;
1073 /// use llama_cpp_4::llama_backend::LlamaBackend;
1074 ///
1075 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1076 /// let backend = LlamaBackend::init()?;
1077 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1078 /// let n_embd = model.n_embd();
1079 /// # Ok(())
1080 /// # }
1081 /// ```
1082 #[must_use]
1083 pub fn n_embd(&self) -> c_int {
1084 unsafe { llama_model_n_embd(self.model.as_ptr()) }
1085 }
1086
1087 /// Get the number of transformer layers in the model.
1088 #[must_use]
1089 pub fn n_layer(&self) -> c_int {
1090 unsafe { llama_model_n_layer(self.model.as_ptr()) }
1091 }
1092
1093 /// Get the number of attention heads in the model.
1094 #[must_use]
1095 pub fn n_head(&self) -> c_int {
1096 unsafe { llama_model_n_head(self.model.as_ptr()) }
1097 }
1098
1099 /// Get the number of key-value attention heads in the model.
1100 #[must_use]
1101 pub fn n_head_kv(&self) -> c_int {
1102 unsafe { llama_model_n_head_kv(self.model.as_ptr()) }
1103 }
1104
1105 /// Get the input embedding size of the model.
1106 #[must_use]
1107 pub fn n_embd_inp(&self) -> c_int {
1108 unsafe { llama_model_n_embd_inp(self.model.as_ptr()) }
1109 }
1110
1111 /// Get the output embedding size of the model.
1112 #[must_use]
1113 pub fn n_embd_out(&self) -> c_int {
1114 unsafe { llama_model_n_embd_out(self.model.as_ptr()) }
1115 }
1116
1117 /// Get the sliding window attention size of the model.
1118 /// Returns 0 if the model does not use sliding window attention.
1119 #[must_use]
1120 pub fn n_swa(&self) -> c_int {
1121 unsafe { llama_model_n_swa(self.model.as_ptr()) }
1122 }
1123
1124 /// Get the `RoPE` type used by the model.
1125 #[must_use]
1126 pub fn rope_type(&self) -> i32 {
1127 unsafe { llama_model_rope_type(self.model.as_ptr()) }
1128 }
1129
1130 /// Get the `RoPE` frequency scale used during training.
1131 #[must_use]
1132 pub fn rope_freq_scale_train(&self) -> f32 {
1133 unsafe { llama_model_rope_freq_scale_train(self.model.as_ptr()) }
1134 }
1135
1136 /// Get the model size in bytes.
1137 #[must_use]
1138 pub fn model_size(&self) -> u64 {
1139 unsafe { llama_model_size(self.model.as_ptr()) }
1140 }
1141
1142 /// Get the number of parameters in the model.
1143 #[must_use]
1144 pub fn n_params(&self) -> u64 {
1145 unsafe { llama_model_n_params(self.model.as_ptr()) }
1146 }
1147
1148 /// Get the number of classification outputs.
1149 #[must_use]
1150 pub fn n_cls_out(&self) -> u32 {
1151 unsafe { llama_model_n_cls_out(self.model.as_ptr()) }
1152 }
1153
1154 /// Get the classification label for the given index.
1155 ///
1156 /// # Errors
1157 ///
1158 /// Returns an error if the label is null or not valid UTF-8.
1159 pub fn cls_label(&self, index: u32) -> Result<&str, StringFromModelError> {
1160 let ptr = unsafe { llama_model_cls_label(self.model.as_ptr(), index) };
1161 if ptr.is_null() {
1162 return Err(StringFromModelError::ReturnedError(-1));
1163 }
1164 let cstr = unsafe { CStr::from_ptr(ptr) };
1165 cstr.to_str().map_err(StringFromModelError::Utf8Error)
1166 }
1167
1168 /// Get the number of metadata key-value pairs.
1169 #[must_use]
1170 pub fn meta_count(&self) -> c_int {
1171 unsafe { llama_model_meta_count(self.model.as_ptr()) }
1172 }
1173
1174 /// Get a model description string.
1175 ///
1176 /// The `buf_size` parameter specifies the maximum buffer size for the description.
1177 /// A default of 256 bytes is usually sufficient.
1178 ///
1179 /// # Errors
1180 ///
1181 /// Returns an error if the description could not be retrieved or is not valid UTF-8.
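    ///
    /// # Example
    ///
    /// A minimal sketch using a 256-byte buffer (the model path is a placeholder):
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let description = model.desc(256)?;
    /// println!("{description}");
    /// # Ok(())
    /// # }
    /// ```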
1182 #[allow(clippy::cast_sign_loss)]
1183 pub fn desc(&self, buf_size: usize) -> Result<String, StringFromModelError> {
1184 let mut buf = vec![0u8; buf_size];
1185 let ret = unsafe {
1186 llama_model_desc(
1187 self.model.as_ptr(),
1188 buf.as_mut_ptr().cast::<c_char>(),
1189 buf_size,
1190 )
1191 };
1192 if ret < 0 {
1193 return Err(StringFromModelError::ReturnedError(ret));
1194 }
1195 let len = ret as usize;
1196 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1197 Ok(s.to_owned())
1198 }
1199
1200 /// Get a metadata key by index.
1201 ///
1202 /// The `buf_size` parameter specifies the maximum buffer size for the key.
1203 /// A default of 256 bytes is usually sufficient.
1204 ///
1205 /// # Errors
1206 ///
1207 /// Returns an error if the index is out of range or the key is not valid UTF-8.
1208 #[allow(clippy::cast_sign_loss)]
1209 pub fn meta_key_by_index(
1210 &self,
1211 index: i32,
1212 buf_size: usize,
1213 ) -> Result<String, StringFromModelError> {
1214 let mut buf = vec![0u8; buf_size];
1215 let ret = unsafe {
1216 llama_model_meta_key_by_index(
1217 self.model.as_ptr(),
1218 index,
1219 buf.as_mut_ptr().cast::<c_char>(),
1220 buf_size,
1221 )
1222 };
1223 if ret < 0 {
1224 return Err(StringFromModelError::ReturnedError(ret));
1225 }
1226 let len = ret as usize;
1227 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1228 Ok(s.to_owned())
1229 }
1230
1231 /// Get a metadata value string by index.
1232 ///
1233 /// The `buf_size` parameter specifies the maximum buffer size for the value.
1234 /// Values can be large (e.g. chat templates, token lists), so 4096+ may be needed.
1235 ///
1236 /// # Errors
1237 ///
1238 /// Returns an error if the index is out of range or the value is not valid UTF-8.
1239 #[allow(clippy::cast_sign_loss)]
1240 pub fn meta_val_str_by_index(
1241 &self,
1242 index: i32,
1243 buf_size: usize,
1244 ) -> Result<String, StringFromModelError> {
1245 let mut buf = vec![0u8; buf_size];
1246 let ret = unsafe {
1247 llama_model_meta_val_str_by_index(
1248 self.model.as_ptr(),
1249 index,
1250 buf.as_mut_ptr().cast::<c_char>(),
1251 buf_size,
1252 )
1253 };
1254 if ret < 0 {
1255 return Err(StringFromModelError::ReturnedError(ret));
1256 }
1257 let len = ret as usize;
1258 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1259 Ok(s.to_owned())
1260 }
1261
1262 /// Get a metadata value by key name.
1263 ///
1264 /// This is more convenient than iterating metadata by index when you know the key.
1265 /// The `buf_size` parameter specifies the maximum buffer size for the value.
1266 ///
1267 /// # Errors
1268 ///
1269 /// Returns an error if the key is not found, contains a null byte, or the value is not valid UTF-8.
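    ///
    /// # Example
    ///
    /// A minimal sketch that looks up a common GGUF key; `general.architecture` is assumed to be
    /// present in the model's metadata, and the model path is a placeholder.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let arch = model.meta_val_str("general.architecture", 256)?;
    /// println!("architecture: {arch}");
    /// # Ok(())
    /// # }
    /// ```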
1270 #[allow(clippy::cast_sign_loss)]
1271 pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
1272 let c_key = CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
1273 let mut buf = vec![0u8; buf_size];
1274 let ret = unsafe {
1275 llama_model_meta_val_str(
1276 self.model.as_ptr(),
1277 c_key.as_ptr(),
1278 buf.as_mut_ptr().cast::<c_char>(),
1279 buf_size,
1280 )
1281 };
1282 if ret < 0 {
1283 return Err(StringFromModelError::ReturnedError(ret));
1284 }
1285 let len = ret as usize;
1286 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1287 Ok(s.to_owned())
1288 }
1289
1290 /// Get all metadata as a list of `(key, value)` pairs.
1291 ///
1292 /// This is a convenience method that iterates over all metadata entries.
1293 /// Keys use a buffer of 256 bytes and values use 4096 bytes.
1294 /// For values that may be larger (e.g. token lists), use
1295 /// [`meta_val_str_by_index`](Self::meta_val_str_by_index) directly with a larger buffer.
1296 ///
1297 /// # Errors
1298 ///
1299 /// Returns an error if any key or value cannot be read or is not valid UTF-8.
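    ///
    /// # Example
    ///
    /// A minimal sketch that prints every metadata entry (the model path is a placeholder):
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// for (key, value) in model.metadata()? {
    ///     println!("{key} = {value}");
    /// }
    /// # Ok(())
    /// # }
    /// ```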
1300 #[allow(clippy::cast_sign_loss)]
1301 pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
1302 let count = self.meta_count();
1303 let mut result = Vec::with_capacity(count as usize);
1304 for i in 0..count {
1305 let key = self.meta_key_by_index(i, 256)?;
1306 let val = self.meta_val_str_by_index(i, 4096)?;
1307 result.push((key, val));
1308 }
1309 Ok(result)
1310 }
1311
1312 /// Check if the model has an encoder.
1313 #[must_use]
1314 pub fn has_encoder(&self) -> bool {
1315 unsafe { llama_model_has_encoder(self.model.as_ptr()) }
1316 }
1317
1318 /// Check if the model has a decoder.
1319 #[must_use]
1320 pub fn has_decoder(&self) -> bool {
1321 unsafe { llama_model_has_decoder(self.model.as_ptr()) }
1322 }
1323
1324 /// Check if the model is recurrent (e.g. Mamba, RWKV).
1325 #[must_use]
1326 pub fn is_recurrent(&self) -> bool {
1327 unsafe { llama_model_is_recurrent(self.model.as_ptr()) }
1328 }
1329
1330 /// Check if the model is a hybrid model.
1331 #[must_use]
1332 pub fn is_hybrid(&self) -> bool {
1333 unsafe { llama_model_is_hybrid(self.model.as_ptr()) }
1334 }
1335
1336 /// Check if the model is a diffusion model.
1337 #[must_use]
1338 pub fn is_diffusion(&self) -> bool {
1339 unsafe { llama_model_is_diffusion(self.model.as_ptr()) }
1340 }
1341
1342 /// Get chat template from model.
1343 ///
1344 /// # Errors
1345 ///
1346 /// - If the model does not have a chat template, it will return an error.
1347 /// - If the chat template is not a valid `CString`, it will return an error.
1348 ///
1349 /// # Example
1350 ///
1351 /// ```no_run
1352 /// use llama_cpp_4::model::LlamaModel;
1353 /// use llama_cpp_4::model::params::LlamaModelParams;
1354 /// use llama_cpp_4::llama_backend::LlamaBackend;
1355 ///
1356 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1357 /// let backend = LlamaBackend::init()?;
1358 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1359 /// let chat_template = model.get_chat_template(1024)?;
1360 /// # Ok(())
1361 /// # }
1362 /// ```
1363 #[allow(clippy::missing_panics_doc)] // We statically know this will not panic as long as the buffer size is sufficient
1364 pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
1365 // longest known template is about 1200 bytes from llama.cpp
1366 let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
1367 let chat_ptr = chat_temp.into_raw();
1368 let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
1369
1370 let ret = unsafe {
1371 llama_model_meta_val_str(self.model.as_ptr(), chat_name.as_ptr(), chat_ptr, buf_size)
1372 };
1373
1374 if ret < 0 {
1375 return Err(ChatTemplateError::MissingTemplate(ret));
1376 }
1377
1378 let template_c = unsafe { CString::from_raw(chat_ptr) };
1379 let template = template_c.to_str()?;
1380
1381 let ret: usize = ret.try_into().unwrap();
1382 if template.len() < ret {
1383 return Err(ChatTemplateError::BuffSizeError(ret + 1));
1384 }
1385
1386 Ok(template.to_owned())
1387 }
1388
1389 /// Loads a model from a file.
1390 ///
1391 /// This function loads a model from a specified file path and returns the corresponding `LlamaModel` instance.
1392 ///
1393 /// # Errors
1394 ///
1395 /// - If the path cannot be converted to a string or if the model file does not exist, it will return an error.
1396 /// - If the model cannot be loaded (e.g., due to an invalid or corrupted model file), it will return a `LlamaModelLoadError`.
1397 ///
1398 /// # Example
1399 ///
1400 /// ```no_run
1401 /// use llama_cpp_4::model::LlamaModel;
1402 /// use llama_cpp_4::model::params::LlamaModelParams;
1403 /// use llama_cpp_4::llama_backend::LlamaBackend;
1404 ///
1405 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1406 /// let backend = LlamaBackend::init()?;
1407 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1408 /// # Ok(())
1409 /// # }
1410 /// ```
1411 #[tracing::instrument(skip_all, fields(params))]
1412 pub fn load_from_file(
1413 _: &LlamaBackend,
1414 path: impl AsRef<Path>,
1415 params: &LlamaModelParams,
1416 ) -> Result<Self, LlamaModelLoadError> {
1417 let path = path.as_ref();
1418 debug_assert!(
1419 Path::new(path).exists(),
1420 "{} does not exist",
1421 path.display()
1422 );
1423 let path = path
1424 .to_str()
1425 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1426
1427 let cstr = CString::new(path)?;
1428 let llama_model = unsafe { llama_model_load_from_file(cstr.as_ptr(), params.params) };
1429
1430 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1431
1432 tracing::debug!(?path, "Loaded model");
1433 Ok(LlamaModel { model })
1434 }
1435
1436 /// Load a model from multiple split files.
1437 ///
1438 /// This function loads a model that has been split across multiple files. This is useful for
1439 /// very large models that exceed filesystem limitations or need to be distributed across
1440 /// multiple storage devices.
1441 ///
1442 /// # Arguments
1443 ///
1444 /// * `paths` - A slice of paths to the split model files
1445 /// * `params` - The model parameters
1446 ///
1447 /// # Errors
1448 ///
1449 /// Returns an error if:
1450 /// - Any of the paths cannot be converted to a C string
1451 /// - The model fails to load from the splits
1452 /// - Any path doesn't exist or isn't accessible
1453 ///
1454 /// # Example
1455 ///
1456 /// ```no_run
1457 /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
1458 /// use llama_cpp_4::llama_backend::LlamaBackend;
1459 ///
1460 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1461 /// let backend = LlamaBackend::init()?;
1462 /// let params = LlamaModelParams::default();
1463 ///
1464 /// let paths = vec![
1465 /// "model-00001-of-00003.gguf",
1466 /// "model-00002-of-00003.gguf",
1467 /// "model-00003-of-00003.gguf",
1468 /// ];
1469 ///
    /// let model = LlamaModel::load_from_splits(&backend, &paths, &params)?;
1471 /// # Ok(())
1472 /// # }
1473 /// ```
1474 #[tracing::instrument(skip_all)]
1475 pub fn load_from_splits(
1476 _: &LlamaBackend,
1477 paths: &[impl AsRef<Path>],
1478 params: &LlamaModelParams,
1479 ) -> Result<Self, LlamaModelLoadError> {
1480 // Convert paths to C strings
1481 let c_strings: Vec<CString> = paths
1482 .iter()
1483 .map(|p| {
1484 let path = p.as_ref();
1485 debug_assert!(path.exists(), "{} does not exist", path.display());
1486 let path_str = path
1487 .to_str()
1488 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1489 CString::new(path_str).map_err(LlamaModelLoadError::from)
1490 })
1491 .collect::<Result<Vec<_>, _>>()?;
1492
1493 // Create array of pointers to C strings
1494 let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
1495
1496 // Load the model from splits
1497 let llama_model = unsafe {
1498 llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
1499 };
1500
1501 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1502
1503 tracing::debug!("Loaded model from {} splits", paths.len());
1504 Ok(LlamaModel { model })
1505 }
1506
1507 /// Load a model from a `FILE` pointer.
1508 ///
1509 /// # Safety
1510 ///
1511 /// The `file` pointer must be a valid, open `FILE*`.
1512 ///
1513 /// # Errors
1514 ///
1515 /// Returns an error if the model cannot be loaded.
1516 pub unsafe fn load_from_file_ptr(
1517 file: *mut llama_cpp_sys_4::FILE,
1518 params: &LlamaModelParams,
1519 ) -> Result<Self, LlamaModelLoadError> {
1520 let model = llama_cpp_sys_4::llama_model_load_from_file_ptr(file, params.params);
1521 let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1522 Ok(LlamaModel { model })
1523 }
1524
1525 /// Initialize a model from user-provided data.
1526 ///
1527 /// # Safety
1528 ///
1529 /// The metadata, callback, and user data must be valid.
1530 ///
1531 /// # Errors
1532 ///
1533 /// Returns an error if the model cannot be initialized.
1534 pub unsafe fn init_from_user(
1535 metadata: *mut llama_cpp_sys_4::gguf_context,
1536 set_tensor_data: llama_cpp_sys_4::llama_model_set_tensor_data_t,
1537 set_tensor_data_ud: *mut std::ffi::c_void,
1538 params: &LlamaModelParams,
1539 ) -> Result<Self, LlamaModelLoadError> {
1540 let model = llama_cpp_sys_4::llama_model_init_from_user(
1541 metadata,
1542 set_tensor_data,
1543 set_tensor_data_ud,
1544 params.params,
1545 );
1546 let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1547 Ok(LlamaModel { model })
1548 }
1549
1550 /// Save the model to a file.
1551 ///
1552 /// # Panics
1553 ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
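    ///
    /// # Example
    ///
    /// A minimal sketch; both paths are placeholders.
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// model.save_to_file("path/to/output.gguf");
    /// # Ok(())
    /// # }
    /// ```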
1555 pub fn save_to_file(&self, path: impl AsRef<Path>) {
1556 let path = path.as_ref();
1557 let path_str = path.to_str().expect("path is not valid UTF-8");
1558 let c_path = CString::new(path_str).expect("path contains null bytes");
1559 unsafe {
1560 llama_model_save_to_file(self.model.as_ptr(), c_path.as_ptr());
1561 }
1562 }
1563
1564 /// Get the list of built-in chat templates.
1565 ///
1566 /// Returns the names of all chat templates that are built into llama.cpp.
1567 ///
1568 /// # Panics
1569 ///
1570 /// Panics if any template name is not valid UTF-8.
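    ///
    /// # Example
    ///
    /// A minimal sketch listing the built-in template names:
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// for name in LlamaModel::chat_builtin_templates() {
    ///     println!("{name}");
    /// }
    /// ```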
1571 #[allow(clippy::cast_sign_loss)]
1572 #[must_use]
1573 pub fn chat_builtin_templates() -> Vec<String> {
1574 // First call to get count
1575 let count = unsafe { llama_chat_builtin_templates(std::ptr::null_mut(), 0) };
1576 if count <= 0 {
1577 return Vec::new();
1578 }
1579 let count = count as usize;
1580 let mut ptrs: Vec<*const c_char> = vec![std::ptr::null(); count];
1581 unsafe {
1582 llama_chat_builtin_templates(ptrs.as_mut_ptr(), count);
1583 }
1584 ptrs.iter()
1585 .map(|&p| {
1586 let cstr = unsafe { CStr::from_ptr(p) };
1587 cstr.to_str()
1588 .expect("template name is not valid UTF-8")
1589 .to_owned()
1590 })
1591 .collect()
1592 }
1593
1594 /// Initializes a lora adapter from a file.
1595 ///
1596 /// This function initializes a Lora adapter, which is a model extension used to adapt or fine-tune the existing model
1597 /// to a specific domain or task. The adapter file is typically in the form of a binary or serialized file that can be applied
1598 /// to the model for improved performance on specialized tasks.
1599 ///
1600 /// # Errors
1601 ///
1602 /// - If the adapter file path cannot be converted to a string or if the adapter cannot be initialized, it will return an error.
1603 ///
1604 /// # Example
1605 ///
1606 /// ```no_run
1607 /// use llama_cpp_4::model::LlamaModel;
1608 /// use llama_cpp_4::model::params::LlamaModelParams;
1609 /// use llama_cpp_4::llama_backend::LlamaBackend;
1610 ///
1611 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1612 /// let backend = LlamaBackend::init()?;
1613 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1614 /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
1615 /// # Ok(())
1616 /// # }
1617 /// ```
1618 pub fn lora_adapter_init(
1619 &self,
1620 path: impl AsRef<Path>,
1621 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
1622 let path = path.as_ref();
1623 debug_assert!(
1624 Path::new(path).exists(),
1625 "{} does not exist",
1626 path.display()
1627 );
1628
1629 let path = path
1630 .to_str()
1631 .ok_or(LlamaLoraAdapterInitError::PathToStrError(
1632 path.to_path_buf(),
1633 ))?;
1634
1635 let cstr = CString::new(path)?;
1636 let adapter = unsafe { llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
1637
1638 let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
1639
1640 tracing::debug!(?path, "Initialized lora adapter");
1641 Ok(LlamaLoraAdapter {
1642 lora_adapter: adapter,
1643 })
1644 }
1645
1646 /// Create a new context from this model.
1647 ///
1648 /// This function creates a new context for the model, which is used to manage and perform computations for inference,
1649 /// including token generation, embeddings, and other tasks that the model can perform. The context allows fine-grained
1650 /// control over model parameters for a specific task.
1651 ///
1652 /// # Errors
1653 ///
1654 /// - There are various potential failures such as invalid parameters or a failure to allocate the context. See [`LlamaContextLoadError`]
1655 /// for more detailed error descriptions.
1656 ///
1657 /// # Example
1658 ///
1659 /// ```no_run
1660 /// use llama_cpp_4::model::LlamaModel;
1661 /// use llama_cpp_4::model::params::LlamaModelParams;
1662 /// use llama_cpp_4::context::params::LlamaContextParams;
1663 /// use llama_cpp_4::llama_backend::LlamaBackend;
1664 ///
1665 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1666 /// let backend = LlamaBackend::init()?;
1667 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1668 /// let context = model.new_context(&backend, LlamaContextParams::default())?;
1669 /// # Ok(())
1670 /// # }
1671 /// ```
1672 #[allow(clippy::needless_pass_by_value)]
1673 pub fn new_context(
1674 &self,
1675 _: &LlamaBackend,
1676 params: LlamaContextParams,
1677 ) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
1678 // Apply TurboQuant attn-rotation preference before the KV cache is
1679 // initialised inside llama_init_from_model.
1680 let prev_rot_var = std::env::var("LLAMA_ATTN_ROT_DISABLE").ok();
1681 if params.attn_rot_disabled {
1682 // SAFETY: we restore the value right after the call.
1683 #[allow(unused_unsafe)]
1684 unsafe {
1685 std::env::set_var("LLAMA_ATTN_ROT_DISABLE", "1");
1686 }
        }
        // When `attn_rot_disabled` is false we leave the env var untouched so that an
        // explicitly set user value keeps taking effect; the match below restores the
        // previous state after the context has been created.
1691
1692 let context_params = params.context_params;
1693 let context = unsafe { llama_init_from_model(self.model.as_ptr(), context_params) };
1694
1695 // Restore the env-var to its previous state.
1696 #[allow(unused_unsafe)]
1697 match prev_rot_var {
1698 Some(v) => unsafe { std::env::set_var("LLAMA_ATTN_ROT_DISABLE", v) },
1699 None if params.attn_rot_disabled => unsafe {
1700 std::env::remove_var("LLAMA_ATTN_ROT_DISABLE");
1701 },
1702 None => {}
1703 }
1704
1705 let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
1706 Ok(LlamaContext::new(self, context, params.embeddings()))
1707 }
1708
1709 /// Apply the model's chat template to a sequence of messages.
1710 ///
1711 /// This function applies the model's chat template to the provided chat messages, formatting them accordingly. The chat
1712 /// template determines the structure or style of conversation between the system and user, such as token formatting,
1713 /// role separation, and more. The template can be customized by providing an optional template string, or if `None`
1714 /// is provided, the default template used by `llama.cpp` will be applied.
1715 ///
1716 /// For more information on supported templates, visit:
1717 /// <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
1718 ///
1719 /// # Arguments
1720 ///
1721 /// - `tmpl`: An optional custom template string. If `None`, the default template will be used.
1722 /// - `chat`: A vector of `LlamaChatMessage` instances, which represent the conversation between the system and user.
    /// - `add_ass`: Whether to append the tokens that open an assistant turn, so the model generates the next reply as the assistant.
1724 ///
1725 /// # Errors
1726 ///
1727 /// There are several possible points of failure when applying the chat template:
1728 /// - Insufficient buffer size to hold the formatted chat (this will return `ApplyChatTemplateError::BuffSizeError`).
1729 /// - If the template or messages cannot be processed properly, various errors from `ApplyChatTemplateError` may occur.
1730 ///
1731 /// # Example
1732 ///
1733 /// ```no_run
1734 /// use llama_cpp_4::model::{LlamaModel, LlamaChatMessage};
1735 /// use llama_cpp_4::model::params::LlamaModelParams;
1736 /// use llama_cpp_4::llama_backend::LlamaBackend;
1737 ///
1738 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1739 /// let backend = LlamaBackend::init()?;
1740 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1741 /// let chat = vec![
1742 /// LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
1743 /// LlamaChatMessage::new("assistant".to_string(), "Hi! How can I assist you today?".to_string())?,
1744 /// ];
1745 /// let formatted_chat = model.apply_chat_template(None, &chat, true)?;
1746 /// # Ok(())
1747 /// # }
1748 /// ```
1749 ///
1750 /// # Notes
1751 ///
    /// The scratch buffer initially holds four times the total message length (with a 4 KiB floor); if
    /// `llama.cpp` reports that more space is needed, the call retries once with the exact reported size.
    ///
1753 /// # Panics
1754 ///
1755 /// Panics if the buffer length exceeds `i32::MAX`.
1756 #[tracing::instrument(skip_all)]
1757 pub fn apply_chat_template(
1758 &self,
1759 tmpl: Option<&str>,
1760 chat: &[LlamaChatMessage],
1761 add_ass: bool,
1762 ) -> Result<String, ApplyChatTemplateError> {
        // Total byte length of all roles and contents, used to size the scratch buffer below.
1765 let message_length = chat.iter().fold(0usize, |acc, c| {
1766 acc + c.role.to_bytes().len() + c.content.to_bytes().len()
1767 });
1768
1769 // Build our llama_cpp_sys chat messages (raw pointers into CStrings).
1770 let chat_sys: Vec<llama_chat_message> = chat
1771 .iter()
1772 .map(|c| llama_chat_message {
1773 role: c.role.as_ptr(),
1774 content: c.content.as_ptr(),
1775 })
1776 .collect();
1777
1778 // Set the tmpl pointer.
1779 let tmpl_cstring = tmpl.map(CString::new).transpose()?;
1780 let tmpl_ptr = tmpl_cstring
1781 .as_ref()
1782 .map_or(std::ptr::null(), |s| s.as_ptr());
1783
1784 // `message_length * 4` is far too small for models whose built-in chat
1785 // template adds a long default system prompt (e.g. Qwen3.5 prepends
1786 // ~80+ chars of markup even for a one-word user message). Start with
1787 // at least 4 KiB so short inputs like "hi" always have room.
1788 //
1789 // `llama_chat_apply_template` returns the number of bytes it *actually*
1790 // needed when the buffer was too small, so we retry exactly once with
1791 // that precise size rather than giving up immediately.
1792 let mut buf_size = message_length.saturating_mul(4).max(4096);
1793
1794 for _ in 0..2 {
1795 // Use u8 so that as_mut_ptr()/as_ptr() match the binding (*mut u8 / *const u8).
1796 let mut buff = vec![0u8; buf_size];
1797 let res = unsafe {
1798 llama_chat_apply_template(
1799 tmpl_ptr,
1800 chat_sys.as_ptr(),
1801 chat_sys.len(),
1802 add_ass,
1803 buff.as_mut_ptr().cast(),
1804 i32::try_from(buff.len()).expect("buffer length fits in i32"),
1805 )
1806 };
1807
1808 if res < 0 {
1809 return Err(ApplyChatTemplateError::BuffSizeError);
1810 }
1811
1812 #[allow(clippy::cast_sign_loss)]
1813 let needed = res as usize;
1814 if needed > buf_size {
1815 // Buffer was too small — retry with the exact size llama.cpp reported.
1816 buf_size = needed + 1; // +1 for null terminator
1817 continue;
1818 }
1819
            // `llama_chat_apply_template` copies the formatted chat into `buff` but is not
            // guaranteed to NUL-terminate it when the output exactly fills the buffer, so take
            // the `needed` bytes it reported rather than scanning for a terminator.
            buff.truncate(needed);
            return Ok(String::from_utf8_lossy(&buff).into_owned());
1828 }
1829
1830 Err(ApplyChatTemplateError::BuffSizeError)
1831 }
1832
1833 /// Build a split GGUF file path for a specific chunk.
1834 ///
1835 /// This utility function creates the standardized filename for a split model chunk
    /// following the pattern `{prefix}-{split_no + 1:05d}-of-{split_count:05d}.gguf` (the index is
    /// 0-based, so chunk 0 is written as `00001`).
1837 ///
1838 /// # Arguments
1839 ///
1840 /// * `path_prefix` - The base path and filename prefix
    /// * `split_no` - The split index (0-based; the filename encodes `split_no + 1`)
1842 /// * `split_count` - The total number of splits
1843 ///
1844 /// # Returns
1845 ///
1846 /// Returns the formatted split path as a String
1847 ///
1848 /// # Example
1849 ///
1850 /// ```
1851 /// use llama_cpp_4::model::LlamaModel;
1852 ///
1853 /// let path = LlamaModel::split_path("/models/llama", 1, 4);
1854 /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
1855 /// ```
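    ///
    /// The index is 0-based, so (assuming the same `+ 1` offset shown above) the first chunk is
    /// requested with `split_no = 0`:
    ///
    /// ```
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// let first = LlamaModel::split_path("/models/llama", 0, 4);
    /// assert_eq!(first, "/models/llama-00001-of-00004.gguf");
    /// ```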
1856 ///
1857 /// # Panics
1858 ///
1859 /// Panics if the path prefix contains a null byte.
1860 #[must_use]
1861 pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
        // Bind the CString to a local so the pointer stays valid for the whole FFI call,
        // and give llama_split_path a fixed 1 KiB scratch buffer to format into.
        let prefix = CString::new(path_prefix).unwrap();
        let mut buffer = vec![0u8; 1024];
        let len = unsafe {
            llama_split_path(
                buffer.as_mut_ptr().cast::<c_char>(),
                buffer.len(),
                prefix.as_ptr(),
                split_no,
                split_count,
            )
        };
1872
1873 let len = usize::try_from(len).expect("split_path length fits in usize");
1874 buffer.truncate(len);
1875 String::from_utf8(buffer).unwrap_or_default()
1876 }
1877
1878 /// Extract the path prefix from a split filename.
1879 ///
1880 /// This function extracts the base path prefix from a split model filename,
1881 /// but only if the `split_no` and `split_count` match the pattern in the filename.
1882 ///
1883 /// # Arguments
1884 ///
1885 /// * `split_path` - The full path to the split file
    /// * `split_no` - The expected split index (0-based, matching [`LlamaModel::split_path`])
1887 /// * `split_count` - The expected total number of splits
1888 ///
1889 /// # Returns
1890 ///
1891 /// Returns the path prefix if the pattern matches, or None if it doesn't
1892 ///
1893 /// # Example
1894 ///
1895 /// ```
1896 /// use llama_cpp_4::model::LlamaModel;
1897 ///
1898 /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 1, 4);
1899 /// assert_eq!(prefix, Some("/models/llama".to_string()));
1900 /// ```
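    ///
    /// For a matching index, `split_prefix` effectively inverts [`LlamaModel::split_path`]; a small
    /// round-trip sketch reusing the values from the examples above:
    ///
    /// ```
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// let path = LlamaModel::split_path("/models/llama", 1, 4);
    /// assert_eq!(LlamaModel::split_prefix(&path, 1, 4), Some("/models/llama".to_string()));
    /// ```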
1901 ///
1902 /// # Panics
1903 ///
1904 /// Panics if the split path contains a null byte.
1905 #[must_use]
1906 pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
        // Bind the CString to a local so the pointer stays valid for the whole FFI call.
        let split_path_c = CString::new(split_path).unwrap();
        let mut buffer = vec![0u8; 1024];
        let len = unsafe {
            llama_split_prefix(
                buffer.as_mut_ptr().cast::<c_char>(),
                buffer.len(),
                split_path_c.as_ptr(),
                split_no,
                split_count,
            )
        };
1917
1918 if len > 0 {
1919 let len = usize::try_from(len).expect("split_prefix length fits in usize");
1920 buffer.truncate(len);
1921 String::from_utf8(buffer).ok()
1922 } else {
1923 None
1924 }
1925 }
1926}
1927
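/// Formats a one-line, human-readable summary of the model: the `llama.cpp` description string,
/// layer/head/embedding counts, the parameter count, and the file size in MiB.
///
/// A minimal sketch of printing the summary, assuming a model loaded as in the examples above:
///
/// ```no_run
/// use llama_cpp_4::model::LlamaModel;
/// use llama_cpp_4::model::params::LlamaModelParams;
/// use llama_cpp_4::llama_backend::LlamaBackend;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let backend = LlamaBackend::init()?;
/// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
/// println!("{model}");
/// # Ok(())
/// # }
/// ```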
1928#[allow(clippy::cast_precision_loss)]
1929impl fmt::Display for LlamaModel {
1930 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1931 let desc = self.desc(256).unwrap_or_else(|_| "unknown".to_string());
1932 write!(
1933 f,
1934 "{desc} | {layers}L {heads}H {embd}E | {params} params | {size:.1} MiB",
1935 layers = self.n_layer(),
1936 heads = self.n_head(),
1937 embd = self.n_embd(),
1938 params = self.n_params(),
1939 size = self.model_size() as f64 / (1024.0 * 1024.0),
1940 )
1941 }
1942}
1943
1944impl Drop for LlamaModel {
1945 fn drop(&mut self) {
1946 unsafe { llama_model_free(self.model.as_ptr()) }
1947 }
1948}
1949
1950/// Defines the possible types of vocabulary used by the model.
1951///
1952/// The model may use different types of vocabulary depending on the tokenization method chosen during training.
1953/// This enum represents these types, specifically `BPE` (Byte Pair Encoding) and `SPM` (`SentencePiece`).
1954///
1955/// # Variants
1956///
1957/// - `BPE`: Byte Pair Encoding, a common tokenization method used in NLP tasks.
1958/// - `SPM`: `SentencePiece`, another popular tokenization method for NLP models.
1959///
1960/// # Example
1961///
1962/// ```no_run
1963/// use llama_cpp_4::model::VocabType;
1964///
1965/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1966/// let vocab_type = VocabType::BPE;
1967/// match vocab_type {
1968/// VocabType::BPE => println!("The model uses Byte Pair Encoding (BPE)"),
1969/// VocabType::SPM => println!("The model uses SentencePiece (SPM)"),
1970/// }
1971/// # Ok(())
1972/// # }
1973/// ```
1974#[repr(u32)]
1975#[derive(Debug, Eq, Copy, Clone, PartialEq)]
1976pub enum VocabType {
1977 /// Byte Pair Encoding
1978 BPE = LLAMA_VOCAB_TYPE_BPE as _,
1979 /// Sentence Piece Tokenizer
1980 SPM = LLAMA_VOCAB_TYPE_SPM as _,
1981}
1982
1983/// Error that occurs when trying to convert a `llama_vocab_type` to a `VocabType`.
1984///
1985/// This error is raised when the integer value returned by the system does not correspond to a known vocabulary type.
1986///
1987/// # Variants
1988///
1989/// - `UnknownValue`: The error is raised when the value is not a valid `llama_vocab_type`. The invalid value is returned with the error.
1990///
1991/// # Example
1992///
1993/// ```no_run
1994/// use llama_cpp_4::model::LlamaTokenTypeFromIntError;
1995///
1996/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1997/// let invalid_value = 999; // Not a valid vocabulary type
1998/// let error = LlamaTokenTypeFromIntError::UnknownValue(invalid_value);
1999/// println!("Error: {}", error);
2000/// # Ok(())
2001/// # }
2002/// ```
2003#[derive(thiserror::Error, Debug, Eq, PartialEq)]
2004pub enum LlamaTokenTypeFromIntError {
2005 /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
    /// The value is not a valid `llama_vocab_type`. Contains the int value that was invalid.
2007 UnknownValue(llama_vocab_type),
2008}
2009
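/// Fallible conversion from the raw `llama_vocab_type` value reported by `llama.cpp` into a
/// [`VocabType`]; any value other than the BPE or SPM constants yields
/// [`LlamaTokenTypeFromIntError::UnknownValue`].
///
/// A small sketch of the error path, reusing the assumed-invalid value from the example above:
///
/// ```no_run
/// use llama_cpp_4::model::{LlamaTokenTypeFromIntError, VocabType};
///
/// let invalid_value = 999; // assumed not to map to any supported vocabulary type
/// assert_eq!(
///     VocabType::try_from(invalid_value),
///     Err(LlamaTokenTypeFromIntError::UnknownValue(invalid_value))
/// );
/// ```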
2010impl TryFrom<llama_vocab_type> for VocabType {
2011 type Error = LlamaTokenTypeFromIntError;
2012
2013 fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
2014 match value {
2015 LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
2016 LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
2017 unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
2018 }
2019 }
2020}