llama_cpp_4/model.rs
1//! A safe wrapper around `llama_model`.
2use std::ffi::CStr;
3use std::ffi::CString;
4use std::fmt;
5use std::num::NonZeroU16;
6use std::os::raw::{c_char, c_int};
7use std::path::Path;
8use std::ptr::NonNull;
9
10use llama_cpp_sys_4::{
11 llama_adapter_lora, llama_adapter_lora_init, llama_add_bos_token, llama_add_eos_token,
12 llama_chat_apply_template, llama_chat_builtin_templates, llama_chat_message,
13 llama_detokenize, llama_free_model, llama_load_model_from_file, llama_model,
14 llama_model_cls_label, llama_model_decoder_start_token, llama_model_desc,
15 llama_model_get_vocab, llama_model_has_decoder, llama_model_has_encoder,
16 llama_model_is_diffusion, llama_model_is_hybrid, llama_model_is_recurrent,
17 llama_model_load_from_splits, llama_model_meta_count, llama_model_meta_key_by_index,
18 llama_model_meta_val_str, llama_model_meta_val_str_by_index, llama_model_n_cls_out,
19 llama_model_n_embd_inp, llama_model_n_embd_out, llama_model_n_head_kv, llama_model_n_params,
20 llama_model_n_swa, llama_model_rope_freq_scale_train, llama_model_rope_type,
21 llama_model_save_to_file, llama_model_size, llama_n_ctx_train, llama_n_embd, llama_n_head,
22 llama_n_layer, llama_n_vocab, llama_new_context_with_model, llama_split_path,
23 llama_split_prefix, llama_token_bos, llama_token_cls, llama_token_eos, llama_token_eot,
24 llama_token_fim_mid, llama_token_fim_pad, llama_token_fim_pre, llama_token_fim_rep,
25 llama_token_fim_sep, llama_token_fim_suf, llama_token_get_attr, llama_token_get_score,
26 llama_token_get_text, llama_token_is_control, llama_token_is_eog, llama_token_nl,
27 llama_token_pad, llama_token_sep, llama_token_to_piece, llama_tokenize, llama_vocab,
28 llama_vocab_type, LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
29};
30
31use crate::context::params::LlamaContextParams;
32use crate::context::LlamaContext;
33use crate::llama_backend::LlamaBackend;
34use crate::model::params::LlamaModelParams;
35use crate::token::LlamaToken;
36use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
37use crate::{
38 ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
39 LlamaModelLoadError, NewLlamaChatMessageError, StringFromModelError, StringToTokenError,
40 TokenToStringError,
41};
42
43pub mod params;
44
/// A safe wrapper around `llama_model`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    // Non-null pointer to the underlying C `llama_model`.
    // NOTE(review): no `Drop` impl is visible in this chunk; `llama_free_model`
    // is imported at the top of the file — confirm where the model is freed.
    pub(crate) model: NonNull<llama_model>,
}
52
/// A safe wrapper around `llama_vocab`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaVocab {
    // Non-null pointer to the vocab, which is obtained from a model
    // (see `LlamaModel::get_vocab`).
    // NOTE(review): this struct carries no lifetime parameter tying it to the
    // model that owns the vocab — confirm a `LlamaVocab` cannot outlive its model.
    pub(crate) vocab: NonNull<llama_vocab>,
}
60
impl LlamaVocab {
    // NOTE(review): every accessor below dereferences `self.vocab` via
    // `NonNull::as_ref` inside `unsafe`. This is sound as long as the vocab
    // pointer (owned by the parent model) is still alive — see the lifetime
    // note on the struct; TODO confirm.

    /// Get the number of tokens in the vocabulary.
    #[must_use]
    pub fn n_tokens(&self) -> i32 {
        unsafe { llama_cpp_sys_4::llama_vocab_n_tokens(self.vocab.as_ref()) }
    }

    /// Get the vocabulary type.
    ///
    /// # Panics
    ///
    /// Panics if the value returned by the C side cannot be converted into a
    /// `u32` (i.e. it is negative).
    #[must_use]
    pub fn vocab_type(&self) -> u32 {
        unsafe { llama_cpp_sys_4::llama_vocab_type(self.vocab.as_ref()).try_into().unwrap() }
    }

    /// Get the BOS token.
    #[must_use]
    pub fn bos(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_bos(self.vocab.as_ref()) })
    }

    /// Get the EOS token.
    #[must_use]
    pub fn eos(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eos(self.vocab.as_ref()) })
    }

    /// Get the EOT (end of turn) token.
    #[must_use]
    pub fn eot(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eot(self.vocab.as_ref()) })
    }

    /// Get the CLS (classification) token.
    #[must_use]
    pub fn cls(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_cls(self.vocab.as_ref()) })
    }

    /// Get the SEP (separator) token.
    #[must_use]
    pub fn sep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_sep(self.vocab.as_ref()) })
    }

    /// Get the NL (newline) token.
    #[must_use]
    pub fn nl(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_nl(self.vocab.as_ref()) })
    }

    /// Get the PAD (padding) token.
    #[must_use]
    pub fn pad(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_pad(self.vocab.as_ref()) })
    }

    /// Get the FIM (fill-in-the-middle) prefix token.
    #[must_use]
    pub fn fim_pre(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pre(self.vocab.as_ref()) })
    }

    /// Get the FIM suffix token.
    #[must_use]
    pub fn fim_suf(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_suf(self.vocab.as_ref()) })
    }

    /// Get the FIM middle token.
    #[must_use]
    pub fn fim_mid(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_mid(self.vocab.as_ref()) })
    }

    /// Get the FIM padding token.
    #[must_use]
    pub fn fim_pad(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pad(self.vocab.as_ref()) })
    }

    /// Get the FIM repository token.
    #[must_use]
    pub fn fim_rep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_rep(self.vocab.as_ref()) })
    }

    /// Get the FIM separator token.
    #[must_use]
    pub fn fim_sep(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_sep(self.vocab.as_ref()) })
    }

    /// Check whether BOS should be added.
    #[must_use]
    pub fn get_add_bos(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_bos(self.vocab.as_ref()) }
    }

    /// Check whether EOS should be added.
    #[must_use]
    pub fn get_add_eos(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_eos(self.vocab.as_ref()) }
    }

    /// Check whether SEP should be added.
    #[must_use]
    pub fn get_add_sep(&self) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_get_add_sep(self.vocab.as_ref()) }
    }

    /// Get the text representation of a token.
    ///
    /// # Errors
    ///
    /// Returns an error if the text pointer is null or not valid UTF-8.
    pub fn get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
        let ptr = unsafe { llama_cpp_sys_4::llama_vocab_get_text(self.vocab.as_ref(), token.0) };
        if ptr.is_null() {
            // There is no dedicated "null text" variant; reuse the generic
            // -1 error code to signal it.
            return Err(StringFromModelError::ReturnedError(-1));
        }
        let cstr = unsafe { CStr::from_ptr(ptr) };
        cstr.to_str().map_err(StringFromModelError::Utf8Error)
    }

    /// Get the score of a token.
    #[must_use]
    pub fn get_score(&self, token: LlamaToken) -> f32 {
        unsafe { llama_cpp_sys_4::llama_vocab_get_score(self.vocab.as_ref(), token.0) }
    }

    /// Get the attributes of a token (raw bitflags).
    ///
    /// # Panics
    ///
    /// Panics if the attribute value returned by the C side cannot be
    /// converted into a `u32`.
    #[must_use]
    pub fn get_attr(&self, token: LlamaToken) -> u32 {
        unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.vocab.as_ref(), token.0).try_into().unwrap() }
    }

    /// Check if a token is a control token.
    #[must_use]
    pub fn is_control(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_is_control(self.vocab.as_ref(), token.0) }
    }

    /// Check if a token is an end-of-generation token.
    #[must_use]
    pub fn is_eog(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_4::llama_vocab_is_eog(self.vocab.as_ref(), token.0) }
    }

    /// Get the token mask value for the vocabulary.
    #[must_use]
    pub fn mask(&self) -> LlamaToken {
        LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_mask(self.vocab.as_ref()) })
    }
}
214
/// A safe wrapper around `llama_adapter_lora`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    // Owning, non-null pointer to the C adapter; released in the `Drop` impl
    // via `llama_adapter_lora_free`.
    pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
}
222
223impl LlamaLoraAdapter {
224 /// Get the number of metadata key-value pairs in the adapter.
225 #[must_use]
226 pub fn meta_count(&self) -> i32 {
227 unsafe { llama_cpp_sys_4::llama_adapter_meta_count(self.lora_adapter.as_ptr()) }
228 }
229
230 /// Get a metadata key by index.
231 ///
232 /// # Errors
233 ///
234 /// Returns an error if the index is out of range or the key is not valid UTF-8.
235 #[allow(clippy::cast_sign_loss)]
236 pub fn meta_key_by_index(
237 &self,
238 index: i32,
239 buf_size: usize,
240 ) -> Result<String, StringFromModelError> {
241 let mut buf = vec![0u8; buf_size];
242 let ret = unsafe {
243 llama_cpp_sys_4::llama_adapter_meta_key_by_index(
244 self.lora_adapter.as_ptr(),
245 index,
246 buf.as_mut_ptr().cast::<c_char>(),
247 buf_size,
248 )
249 };
250 if ret < 0 {
251 return Err(StringFromModelError::ReturnedError(ret));
252 }
253 let len = ret as usize;
254 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
255 Ok(s.to_owned())
256 }
257
258 /// Get a metadata value by key name.
259 ///
260 /// # Errors
261 ///
262 /// Returns an error if the key is not found or the value is not valid UTF-8.
263 #[allow(clippy::cast_sign_loss)]
264 pub fn meta_val_str(
265 &self,
266 key: &str,
267 buf_size: usize,
268 ) -> Result<String, StringFromModelError> {
269 let c_key =
270 CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
271 let mut buf = vec![0u8; buf_size];
272 let ret = unsafe {
273 llama_cpp_sys_4::llama_adapter_meta_val_str(
274 self.lora_adapter.as_ptr(),
275 c_key.as_ptr(),
276 buf.as_mut_ptr().cast::<c_char>(),
277 buf_size,
278 )
279 };
280 if ret < 0 {
281 return Err(StringFromModelError::ReturnedError(ret));
282 }
283 let len = ret as usize;
284 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
285 Ok(s.to_owned())
286 }
287
288 /// Get a metadata value by index.
289 ///
290 /// # Errors
291 ///
292 /// Returns an error if the index is out of range or the value is not valid UTF-8.
293 #[allow(clippy::cast_sign_loss)]
294 pub fn meta_val_str_by_index(
295 &self,
296 index: i32,
297 buf_size: usize,
298 ) -> Result<String, StringFromModelError> {
299 let mut buf = vec![0u8; buf_size];
300 let ret = unsafe {
301 llama_cpp_sys_4::llama_adapter_meta_val_str_by_index(
302 self.lora_adapter.as_ptr(),
303 index,
304 buf.as_mut_ptr().cast::<c_char>(),
305 buf_size,
306 )
307 };
308 if ret < 0 {
309 return Err(StringFromModelError::ReturnedError(ret));
310 }
311 let len = ret as usize;
312 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
313 Ok(s.to_owned())
314 }
315
316 /// Get all metadata as a list of `(key, value)` pairs.
317 ///
318 /// # Errors
319 ///
320 /// Returns an error if any key or value cannot be read or is not valid UTF-8.
321 #[allow(clippy::cast_sign_loss)]
322 pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
323 let count = self.meta_count();
324 let mut result = Vec::with_capacity(count as usize);
325 for i in 0..count {
326 let key = self.meta_key_by_index(i, 256)?;
327 let val = self.meta_val_str_by_index(i, 4096)?;
328 result.push((key, val));
329 }
330 Ok(result)
331 }
332
333 /// Get the number of invocation tokens for this adapter.
334 #[must_use]
335 pub fn n_invocation_tokens(&self) -> u64 {
336 unsafe {
337 llama_cpp_sys_4::llama_adapter_get_alora_n_invocation_tokens(
338 self.lora_adapter.as_ptr(),
339 )
340 }
341 }
342
343 /// Get the invocation tokens for this adapter.
344 ///
345 /// Returns an empty slice if there are no invocation tokens.
346 #[must_use]
347 #[allow(clippy::cast_possible_truncation)]
348 pub fn invocation_tokens(&self) -> &[LlamaToken] {
349 let n = self.n_invocation_tokens() as usize;
350 if n == 0 {
351 return &[];
352 }
353 let ptr = unsafe {
354 llama_cpp_sys_4::llama_adapter_get_alora_invocation_tokens(
355 self.lora_adapter.as_ptr(),
356 )
357 };
358 if ptr.is_null() {
359 return &[];
360 }
361 // LlamaToken is repr(transparent) over llama_token (i32), so this cast is safe
362 unsafe { std::slice::from_raw_parts(ptr.cast::<LlamaToken>(), n) }
363 }
364}
365
impl Drop for LlamaLoraAdapter {
    /// Releases the underlying C adapter.
    fn drop(&mut self) {
        // SAFETY: `lora_adapter` is non-null (`NonNull`) and this is the only
        // place this wrapper frees it; the pointer is not used afterwards.
        unsafe {
            llama_cpp_sys_4::llama_adapter_lora_free(self.lora_adapter.as_ptr());
        }
    }
}
373
374/// A Safe wrapper around `llama_chat_message`
375#[derive(Debug, Eq, PartialEq, Clone)]
376pub struct LlamaChatMessage {
377 role: CString,
378 content: CString,
379}
380
381impl LlamaChatMessage {
382 /// Create a new `LlamaChatMessage`.
383 ///
384 /// # Errors
385 ///
386 /// Returns [`NewLlamaChatMessageError`] if the role or content contains a null byte.
387 pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
388 Ok(Self {
389 role: CString::new(role)?,
390 content: CString::new(content)?,
391 })
392 }
393}
394
/// How to determine if we should prepend a bos token to tokens
///
/// Mapped to the boolean `add_bos` argument of `llama_tokenize` in
/// [`LlamaModel::str_to_token`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Add the beginning of stream token to the start of the string.
    Always,
    /// Do not add the beginning of stream token to the start of the string.
    Never,
}
403
/// How to determine if we should tokenize special tokens
///
/// Mapped to the boolean `special` argument of `llama_token_to_piece` in
/// [`LlamaModel::token_to_bytes_with_size`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space.
    Tokenize,
    /// Treat special and/or control tokens as plaintext.
    Plaintext,
}
412
// SAFETY(review): assumes the loaded `llama_model` is not mutated through
// shared references and that llama.cpp's model accessors are safe to call
// from multiple threads — TODO confirm against llama.cpp's threading
// guarantees before relying on cross-thread sharing.
unsafe impl Send for LlamaModel {}

unsafe impl Sync for LlamaModel {}
416
417impl LlamaModel {
418 /// Retrieves the vocabulary associated with the current Llama model.
419 ///
420 /// This method fetches the vocabulary from the underlying model using an unsafe
421 /// FFI call. The returned `LlamaVocab` struct contains a non-null pointer to
422 /// the vocabulary data, which is wrapped in a `NonNull` for safety.
423 ///
424 /// # Safety
425 /// This method uses an unsafe block to call a C function (`llama_model_get_vocab`),
426 /// which is assumed to return a valid pointer to the vocabulary. The caller should
427 /// ensure that the model object is properly initialized and valid before calling
428 /// this method, as dereferencing invalid pointers can lead to undefined behavior.
429 ///
430 /// # Returns
431 /// A `LlamaVocab` struct containing the vocabulary of the model.
432 ///
433 /// # Panics
434 ///
435 /// Panics if the underlying C function returns a null pointer.
436 ///
437 /// # Example
438 /// ```rust,ignore
439 /// let vocab = model.get_vocab();
440 /// ```
441 #[must_use]
442 pub fn get_vocab(&self) -> LlamaVocab {
443 let llama_vocab = unsafe { llama_model_get_vocab(self.model.as_ptr()) }.cast_mut();
444
445 LlamaVocab {
446 vocab: NonNull::new(llama_vocab).unwrap(),
447 }
448 }
449 /// Get the number of tokens the model was trained on.
450 ///
451 /// This function returns the number of tokens that the model was trained on, represented as a `u32`.
452 ///
453 /// # Panics
454 ///
455 /// This function will panic if the number of tokens the model was trained on does not fit into a `u32`.
456 /// This should be impossible on most platforms since llama.cpp returns a `c_int` (i32 on most platforms),
457 /// which is almost certainly positive.
458 #[must_use]
459 pub fn n_ctx_train(&self) -> u32 {
460 let n_ctx_train = unsafe { llama_n_ctx_train(self.model.as_ptr()) };
461 u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
462 }
463
464 /// Get all tokens in the model.
465 ///
466 /// This function returns an iterator over all the tokens in the model. Each item in the iterator is a tuple
467 /// containing a `LlamaToken` and its corresponding string representation (or an error if the conversion fails).
468 ///
469 /// # Parameters
470 ///
471 /// - `special`: The `Special` value that determines how special tokens (like BOS, EOS, etc.) are handled.
472 pub fn tokens(
473 &self,
474 special: Special,
475 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
476 (0..self.n_vocab())
477 .map(LlamaToken::new)
478 .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
479 }
480
    /// Get the beginning of stream token.
    ///
    /// This function returns the token that represents the beginning of a stream (BOS token).
    #[must_use]
    pub fn token_bos(&self) -> LlamaToken {
        // `get_vocab()` re-fetches the model-owned vocab handle on each call.
        let token = unsafe { llama_token_bos(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the end of stream token.
    ///
    /// This function returns the token that represents the end of a stream (EOS token).
    #[must_use]
    pub fn token_eos(&self) -> LlamaToken {
        let token = unsafe { llama_token_eos(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the newline token.
    ///
    /// This function returns the token that represents a newline character.
    #[must_use]
    pub fn token_nl(&self) -> LlamaToken {
        let token = unsafe { llama_token_nl(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
    ///
    /// This function returns `true` if the provided token signifies the end of generation or end of sequence,
    /// such as EOS or other special tokens.
    ///
    /// # Parameters
    ///
    /// - `token`: The `LlamaToken` to check.
    ///
    /// # Returns
    ///
    /// - `true` if the token is an end-of-generation token, otherwise `false`.
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        unsafe { llama_token_is_eog(self.get_vocab().vocab.as_ref(), token.0) }
    }
524
    // The accessors below all follow the same pattern: fetch the vocab handle
    // from the model, then query a single special-token id or property via FFI.

    /// Get the classification token.
    #[must_use]
    pub fn token_cls(&self) -> LlamaToken {
        let token = unsafe { llama_token_cls(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the end-of-turn token.
    #[must_use]
    pub fn token_eot(&self) -> LlamaToken {
        let token = unsafe { llama_token_eot(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the padding token.
    #[must_use]
    pub fn token_pad(&self) -> LlamaToken {
        let token = unsafe { llama_token_pad(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the separator token.
    #[must_use]
    pub fn token_sep(&self) -> LlamaToken {
        let token = unsafe { llama_token_sep(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle prefix token.
    #[must_use]
    pub fn token_fim_pre(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_pre(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle suffix token.
    #[must_use]
    pub fn token_fim_suf(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_suf(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle middle token.
    #[must_use]
    pub fn token_fim_mid(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_mid(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle padding token.
    #[must_use]
    pub fn token_fim_pad(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_pad(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle repository token.
    #[must_use]
    pub fn token_fim_rep(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_rep(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Get the fill-in-the-middle separator token.
    #[must_use]
    pub fn token_fim_sep(&self) -> LlamaToken {
        let token = unsafe { llama_token_fim_sep(self.get_vocab().vocab.as_ref()) };
        LlamaToken(token)
    }

    /// Check if a token is a control token.
    #[must_use]
    pub fn token_is_control(&self, token: LlamaToken) -> bool {
        unsafe { llama_token_is_control(self.get_vocab().vocab.as_ref(), token.0) }
    }

    /// Get the score of a token.
    #[must_use]
    pub fn token_get_score(&self, token: LlamaToken) -> f32 {
        unsafe { llama_token_get_score(self.get_vocab().vocab.as_ref(), token.0) }
    }
606
607 /// Get the raw text of a token.
608 ///
609 /// # Errors
610 ///
611 /// Returns an error if the token text is null or not valid UTF-8.
612 pub fn token_get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
613 let ptr =
614 unsafe { llama_token_get_text(self.get_vocab().vocab.as_ref(), token.0) };
615 if ptr.is_null() {
616 return Err(StringFromModelError::ReturnedError(-1));
617 }
618 let cstr = unsafe { CStr::from_ptr(ptr) };
619 cstr.to_str()
620 .map_err(StringFromModelError::Utf8Error)
621 }
622
    /// Check if a BOS token should be added when tokenizing.
    #[must_use]
    pub fn add_bos_token(&self) -> bool {
        unsafe { llama_add_bos_token(self.get_vocab().vocab.as_ref()) }
    }

    /// Check if an EOS token should be added when tokenizing.
    #[must_use]
    pub fn add_eos_token(&self) -> bool {
        unsafe { llama_add_eos_token(self.get_vocab().vocab.as_ref()) }
    }

    /// Get the decoder start token.
    ///
    /// This function returns the token used to signal the start of decoding (i.e., the token used at the start
    /// of a sequence generation).
    #[must_use]
    pub fn decode_start_token(&self) -> LlamaToken {
        // Queried on the model, not the vocab, unlike the accessors above.
        let token = unsafe { llama_model_decoder_start_token(self.model.as_ptr()) };
        LlamaToken(token)
    }
644
    /// Convert a single token to a string.
    ///
    /// This function converts a `LlamaToken` into its string representation.
    /// It delegates to [`Self::token_to_str_with_size`] with a 32-byte initial buffer.
    ///
    /// # Errors
    ///
    /// This function returns an error if the token cannot be converted to a string. For more details, refer to
    /// [`TokenToStringError`].
    ///
    /// # Parameters
    ///
    /// - `token`: The `LlamaToken` to convert.
    /// - `special`: The `Special` value used to handle special tokens.
    pub fn token_to_str(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        self.token_to_str_with_size(token, 32, special)
    }

    /// Convert a single token to bytes.
    ///
    /// This function converts a `LlamaToken` into a byte representation.
    /// It delegates to [`Self::token_to_bytes_with_size`] with a 32-byte initial
    /// buffer and no `lstrip`.
    ///
    /// # Errors
    ///
    /// This function returns an error if the token cannot be converted to bytes. For more details, refer to
    /// [`TokenToStringError`].
    ///
    /// # Parameters
    ///
    /// - `token`: The `LlamaToken` to convert.
    /// - `special`: The `Special` value used to handle special tokens.
    pub fn token_to_bytes(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<Vec<u8>, TokenToStringError> {
        self.token_to_bytes_with_size(token, 32, special, None)
    }
686
687 /// Convert a vector of tokens to a single string.
688 ///
689 /// This function takes a slice of `LlamaToken`s and converts them into a single string, concatenating their
690 /// string representations.
691 ///
692 /// # Errors
693 ///
694 /// This function returns an error if any token cannot be converted to a string. For more details, refer to
695 /// [`TokenToStringError`].
696 ///
697 /// # Parameters
698 ///
699 /// - `tokens`: A slice of `LlamaToken`s to convert.
700 /// - `special`: The `Special` value used to handle special tokens.
701 pub fn tokens_to_str(
702 &self,
703 tokens: &[LlamaToken],
704 special: Special,
705 ) -> Result<String, TokenToStringError> {
706 let mut builder = String::with_capacity(tokens.len() * 4);
707 for str in tokens
708 .iter()
709 .copied()
710 .map(|t| self.token_to_str(t, special))
711 {
712 builder += &str?;
713 }
714 Ok(builder)
715 }
716
    /// Convert a string to a vector of tokens.
    ///
    /// This function converts a string into a vector of `LlamaToken`s. The function will tokenize the string
    /// and return the corresponding tokens.
    ///
    /// # Errors
    ///
    /// - This function will return an error if the input string contains a null byte.
    ///
    /// # Panics
    ///
    /// - This function will panic if the number of tokens exceeds `usize::MAX`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use std::path::Path;
    /// use llama_cpp_4::model::AddBos;
    /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Heuristic first-pass capacity: roughly one token per two bytes of
        // input, plus one slot for the optional BOS, never fewer than 8.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        // SAFETY: `buffer` has room for `buffer_capacity` tokens and the C
        // side writes at most that many; `c_string` stays alive for the call.
        let size = unsafe {
            llama_tokenize(
                self.get_vocab().vocab.as_ref(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // if we fail the first time we can resize the vector to the correct size and try again. This should never fail.
        // as a result - size is guaranteed to be positive here.
        let size = if size.is_negative() {
            // A negative return is the required token count; grow the buffer
            // to exactly that and retry once.
            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
            // SAFETY: after `reserve_exact`, `buffer` has capacity for at
            // least `-size` tokens, matching the capacity we pass to C.
            unsafe {
                llama_tokenize(
                    self.get_vocab().vocab.as_ref(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and usize ");

        // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
        unsafe { buffer.set_len(size) }
        Ok(buffer.into_iter().map(LlamaToken).collect())
    }
798
    /// Get the type of a token.
    ///
    /// This function retrieves the attributes associated with a given token. The attributes are typically used to
    /// understand whether the token represents a special type of token (e.g., beginning-of-sequence (BOS), end-of-sequence (EOS),
    /// control tokens, etc.).
    ///
    /// # Panics
    ///
    /// - This function will panic if the token type is unknown or cannot be converted to a valid `LlamaTokenAttrs`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    /// use llama_cpp_4::token::LlamaToken;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let token = LlamaToken::new(42);
    /// let token_attrs = model.token_attr(token);
    /// # Ok(())
    /// # }
    /// ```
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        // Raw C attribute bitflags, converted to the typed flag set; the
        // conversion failing would mean llama.cpp emitted unknown bits.
        let token_type = unsafe { llama_token_get_attr(self.get_vocab().vocab.as_ref(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }
830
    /// Detokenize a slice of tokens into a string.
    ///
    /// This is the inverse of [`str_to_token`](Self::str_to_token).
    ///
    /// Uses the standard two-call C pattern: a first call with a null buffer
    /// discovers the required size, a second call fills a buffer of that size.
    ///
    /// # Parameters
    ///
    /// - `tokens`: The tokens to detokenize.
    /// - `remove_special`: If `true`, special tokens are removed from the output.
    /// - `unparse_special`: If `true`, special tokens are rendered as their text representation.
    ///
    /// # Errors
    ///
    /// Returns an error if the detokenized text is not valid UTF-8.
    #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap, clippy::cast_sign_loss)]
    pub fn detokenize(
        &self,
        tokens: &[LlamaToken],
        remove_special: bool,
        unparse_special: bool,
    ) -> Result<String, StringFromModelError> {
        // First call with empty buffer to get required size
        let n_tokens = tokens.len() as i32;
        // `LlamaToken` is a transparent wrapper over `llama_token`, so the
        // slice pointer can be reinterpreted for the C call.
        let token_ptr = tokens.as_ptr().cast::<llama_cpp_sys_4::llama_token>();
        // SAFETY: a null buffer with size 0 is the documented size-probe form.
        let needed = unsafe {
            llama_detokenize(
                self.get_vocab().vocab.as_ref(),
                token_ptr,
                n_tokens,
                std::ptr::null_mut(),
                0,
                remove_special,
                unparse_special,
            )
        };
        // llama_detokenize returns negative required size when buffer is too small
        let buf_size = if needed < 0 { (-needed) as usize } else { needed as usize };
        let mut buf = vec![0u8; buf_size];
        // SAFETY: `buf` is valid and `buf_size` bytes long for the call.
        let ret = unsafe {
            llama_detokenize(
                self.get_vocab().vocab.as_ref(),
                token_ptr,
                n_tokens,
                buf.as_mut_ptr().cast::<c_char>(),
                buf_size as i32,
                remove_special,
                unparse_special,
            )
        };
        if ret < 0 {
            return Err(StringFromModelError::ReturnedError(ret));
        }
        let len = ret as usize;
        let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
        Ok(s.to_owned())
    }
886
887 /// Convert a token to a string with a specified buffer size.
888 ///
889 /// This function allows you to convert a token into a string, with the ability to specify a buffer size for the operation.
890 /// It is generally recommended to use `LlamaModel::token_to_str` instead, as 8 bytes is typically sufficient for most tokens,
891 /// and the extra buffer size doesn't usually matter.
892 ///
893 /// # Errors
894 ///
895 /// - If the token type is unknown, an error will be returned.
896 /// - If the resultant token exceeds the provided `buffer_size`, an error will occur.
897 /// - If the token string returned by `llama-cpp` is not valid UTF-8, it will return an error.
898 ///
899 /// # Panics
900 ///
901 /// - This function will panic if the `buffer_size` does not fit into a `c_int`.
902 /// - It will also panic if the size returned from `llama-cpp` does not fit into a `usize`, which should typically never happen.
903 ///
904 /// # Example
905 ///
906 /// ```no_run
907 /// use llama_cpp_4::model::{LlamaModel, Special};
908 /// use llama_cpp_4::model::params::LlamaModelParams;
909 /// use llama_cpp_4::llama_backend::LlamaBackend;
910 /// use llama_cpp_4::token::LlamaToken;
911 ///
912 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
913 /// let backend = LlamaBackend::init()?;
914 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
915 /// let token = LlamaToken::new(42);
916 /// let token_string = model.token_to_str_with_size(token, 32, Special::Plaintext)?;
917 /// # Ok(())
918 /// # }
919 /// ```
920 pub fn token_to_str_with_size(
921 &self,
922 token: LlamaToken,
923 buffer_size: usize,
924 special: Special,
925 ) -> Result<String, TokenToStringError> {
926 let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
927 Ok(String::from_utf8(bytes)?)
928 }
929
930 /// Convert a token to bytes with a specified buffer size.
931 ///
932 /// Generally you should use [`LlamaModel::token_to_bytes`] instead as 8 bytes is enough for most words and
933 /// the extra bytes do not really matter.
934 ///
935 /// # Errors
936 ///
937 /// - if the token type is unknown
938 /// - the resultant token is larger than `buffer_size`.
939 ///
940 /// # Panics
941 ///
942 /// - This function will panic if `buffer_size` cannot fit into a `c_int`.
943 /// - It will also panic if the size returned from `llama-cpp` cannot be converted to `usize` (which should not happen).
944 ///
945 /// # Example
946 ///
947 /// ```no_run
948 /// use llama_cpp_4::model::{LlamaModel, Special};
949 /// use llama_cpp_4::model::params::LlamaModelParams;
950 /// use llama_cpp_4::llama_backend::LlamaBackend;
951 /// use llama_cpp_4::token::LlamaToken;
952 ///
953 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
954 /// let backend = LlamaBackend::init()?;
955 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
956 /// let token = LlamaToken::new(42);
957 /// let token_bytes = model.token_to_bytes_with_size(token, 32, Special::Plaintext, None)?;
958 /// # Ok(())
959 /// # }
960 /// ```
961 pub fn token_to_bytes_with_size(
962 &self,
963 token: LlamaToken,
964 buffer_size: usize,
965 special: Special,
966 lstrip: Option<NonZeroU16>,
967 ) -> Result<Vec<u8>, TokenToStringError> {
968 if token == self.token_nl() {
969 return Ok(String::from("\n").into_bytes());
970 }
971
972 // unsure what to do with this in the face of the 'special' arg + attr changes
973 let attrs = self.token_attr(token);
974 if (attrs.contains(LlamaTokenAttr::Control)
975 && (token == self.token_bos() || token == self.token_eos()))
976 || attrs.is_empty()
977 || attrs
978 .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
979 {
980 return Ok(Vec::new());
981 }
982
983 let special = match special {
984 Special::Tokenize => true,
985 Special::Plaintext => false,
986 };
987
988 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
989 let len = string.as_bytes().len();
990 let len = c_int::try_from(len).expect("length fits into c_int");
991 let buf = string.into_raw();
992 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
993 let size = unsafe {
994 llama_token_to_piece(
995 self.get_vocab().vocab.as_ref(),
996 token.0,
997 buf,
998 len,
999 lstrip,
1000 special,
1001 )
1002 };
1003
1004 match size {
1005 0 => Err(TokenToStringError::UnknownTokenType),
1006 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
1007 size => {
1008 let string = unsafe { CString::from_raw(buf) };
1009 let mut bytes = string.into_bytes();
1010 let len = usize::try_from(size).expect("size is positive and fits into usize");
1011 bytes.truncate(len);
1012 Ok(bytes)
1013 }
1014 }
1015 }
1016 /// The number of tokens the model was trained on.
1017 ///
1018 /// This function returns the number of tokens the model was trained on. It is returned as a `c_int` for maximum
1019 /// compatibility with the underlying llama-cpp library, though it can typically be cast to an `i32` without issue.
1020 ///
1021 /// # Example
1022 ///
1023 /// ```no_run
1024 /// use llama_cpp_4::model::LlamaModel;
1025 /// use llama_cpp_4::model::params::LlamaModelParams;
1026 /// use llama_cpp_4::llama_backend::LlamaBackend;
1027 ///
1028 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1029 /// let backend = LlamaBackend::init()?;
1030 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1031 /// let n_vocab = model.n_vocab();
1032 /// # Ok(())
1033 /// # }
1034 /// ```
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        // SAFETY: the vocab reference is owned by `self.model`, which is valid for `self`'s lifetime.
        unsafe { llama_n_vocab(self.get_vocab().vocab.as_ref()) }
    }
1039
1040 /// The type of vocab the model was trained on.
1041 ///
1042 /// This function returns the type of vocabulary used by the model, such as whether it is based on byte-pair encoding (BPE),
1043 /// word-level tokens, or another tokenization scheme.
1044 ///
1045 /// # Panics
1046 ///
1047 /// - This function will panic if `llama-cpp` emits a vocab type that is not recognized or is invalid for this library.
1048 ///
1049 /// # Example
1050 ///
1051 /// ```no_run
1052 /// use llama_cpp_4::model::LlamaModel;
1053 /// use llama_cpp_4::model::params::LlamaModelParams;
1054 /// use llama_cpp_4::llama_backend::LlamaBackend;
1055 ///
1056 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1057 /// let backend = LlamaBackend::init()?;
1058 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1059 /// let vocab_type = model.vocab_type();
1060 /// # Ok(())
1061 /// # }
1062 /// ```
    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        // SAFETY: the vocab reference is owned by `self.model`, which is valid for `self`'s lifetime.
        let vocab_type = unsafe { llama_vocab_type(self.get_vocab().vocab.as_ref()) };
        // Panics (documented above) if llama.cpp reports a variant this wrapper does not know.
        VocabType::try_from(vocab_type).expect("invalid vocab type")
    }
1068
1069 /// Returns the number of embedding dimensions for the model.
1070 ///
1071 /// This function retrieves the number of embeddings (or embedding dimensions) used by the model. It is typically
1072 /// used for analyzing model architecture and setting up context parameters or other model configuration aspects.
1073 ///
1074 /// # Panics
1075 ///
1076 /// - This function may panic if the underlying `llama-cpp` library returns an invalid embedding dimension value.
1077 ///
1078 /// # Example
1079 ///
1080 /// ```no_run
1081 /// use llama_cpp_4::model::LlamaModel;
1082 /// use llama_cpp_4::model::params::LlamaModelParams;
1083 /// use llama_cpp_4::llama_backend::LlamaBackend;
1084 ///
1085 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1086 /// let backend = LlamaBackend::init()?;
1087 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1088 /// let n_embd = model.n_embd();
1089 /// # Ok(())
1090 /// # }
1091 /// ```
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_n_embd(self.model.as_ptr()) }
    }
1096
    /// Get the number of transformer layers in the model.
    #[must_use]
    pub fn n_layer(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_n_layer(self.model.as_ptr()) }
    }
1102
    /// Get the number of attention heads in the model.
    #[must_use]
    pub fn n_head(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_n_head(self.model.as_ptr()) }
    }
1108
    /// Get the number of key-value attention heads in the model.
    #[must_use]
    pub fn n_head_kv(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_head_kv(self.model.as_ptr()) }
    }
1114
    /// Get the input embedding size of the model.
    #[must_use]
    pub fn n_embd_inp(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_embd_inp(self.model.as_ptr()) }
    }
1120
    /// Get the output embedding size of the model.
    #[must_use]
    pub fn n_embd_out(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_embd_out(self.model.as_ptr()) }
    }
1126
    /// Get the sliding window attention size of the model.
    /// Returns 0 if the model does not use sliding window attention.
    #[must_use]
    pub fn n_swa(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_swa(self.model.as_ptr()) }
    }
1133
    /// Get the `RoPE` type used by the model.
    ///
    /// The value is the raw `llama_rope_type` discriminant from llama.cpp.
    #[must_use]
    pub fn rope_type(&self) -> i32 {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_rope_type(self.model.as_ptr()) }
    }
1139
    /// Get the `RoPE` frequency scale used during training.
    #[must_use]
    pub fn rope_freq_scale_train(&self) -> f32 {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_rope_freq_scale_train(self.model.as_ptr()) }
    }
1145
    /// Get the model size in bytes.
    #[must_use]
    pub fn model_size(&self) -> u64 {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_size(self.model.as_ptr()) }
    }
1151
    /// Get the number of parameters in the model.
    #[must_use]
    pub fn n_params(&self) -> u64 {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_params(self.model.as_ptr()) }
    }
1157
    /// Get the number of classification outputs.
    #[must_use]
    pub fn n_cls_out(&self) -> u32 {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_n_cls_out(self.model.as_ptr()) }
    }
1163
1164 /// Get the classification label for the given index.
1165 ///
1166 /// # Errors
1167 ///
1168 /// Returns an error if the label is null or not valid UTF-8.
1169 pub fn cls_label(&self, index: u32) -> Result<&str, StringFromModelError> {
1170 let ptr = unsafe { llama_model_cls_label(self.model.as_ptr(), index) };
1171 if ptr.is_null() {
1172 return Err(StringFromModelError::ReturnedError(-1));
1173 }
1174 let cstr = unsafe { CStr::from_ptr(ptr) };
1175 cstr.to_str().map_err(StringFromModelError::Utf8Error)
1176 }
1177
    /// Get the number of metadata key-value pairs.
    #[must_use]
    pub fn meta_count(&self) -> c_int {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_meta_count(self.model.as_ptr()) }
    }
1183
1184 /// Get a model description string.
1185 ///
1186 /// The `buf_size` parameter specifies the maximum buffer size for the description.
1187 /// A default of 256 bytes is usually sufficient.
1188 ///
1189 /// # Errors
1190 ///
1191 /// Returns an error if the description could not be retrieved or is not valid UTF-8.
1192 #[allow(clippy::cast_sign_loss)]
1193 pub fn desc(&self, buf_size: usize) -> Result<String, StringFromModelError> {
1194 let mut buf = vec![0u8; buf_size];
1195 let ret = unsafe {
1196 llama_model_desc(
1197 self.model.as_ptr(),
1198 buf.as_mut_ptr().cast::<c_char>(),
1199 buf_size,
1200 )
1201 };
1202 if ret < 0 {
1203 return Err(StringFromModelError::ReturnedError(ret));
1204 }
1205 let len = ret as usize;
1206 let s = std::str::from_utf8(&buf[..len])
1207 .map_err(StringFromModelError::Utf8Error)?;
1208 Ok(s.to_owned())
1209 }
1210
1211 /// Get a metadata key by index.
1212 ///
1213 /// The `buf_size` parameter specifies the maximum buffer size for the key.
1214 /// A default of 256 bytes is usually sufficient.
1215 ///
1216 /// # Errors
1217 ///
1218 /// Returns an error if the index is out of range or the key is not valid UTF-8.
1219 #[allow(clippy::cast_sign_loss)]
1220 pub fn meta_key_by_index(&self, index: i32, buf_size: usize) -> Result<String, StringFromModelError> {
1221 let mut buf = vec![0u8; buf_size];
1222 let ret = unsafe {
1223 llama_model_meta_key_by_index(
1224 self.model.as_ptr(),
1225 index,
1226 buf.as_mut_ptr().cast::<c_char>(),
1227 buf_size,
1228 )
1229 };
1230 if ret < 0 {
1231 return Err(StringFromModelError::ReturnedError(ret));
1232 }
1233 let len = ret as usize;
1234 let s = std::str::from_utf8(&buf[..len])
1235 .map_err(StringFromModelError::Utf8Error)?;
1236 Ok(s.to_owned())
1237 }
1238
1239 /// Get a metadata value string by index.
1240 ///
1241 /// The `buf_size` parameter specifies the maximum buffer size for the value.
1242 /// Values can be large (e.g. chat templates, token lists), so 4096+ may be needed.
1243 ///
1244 /// # Errors
1245 ///
1246 /// Returns an error if the index is out of range or the value is not valid UTF-8.
1247 #[allow(clippy::cast_sign_loss)]
1248 pub fn meta_val_str_by_index(&self, index: i32, buf_size: usize) -> Result<String, StringFromModelError> {
1249 let mut buf = vec![0u8; buf_size];
1250 let ret = unsafe {
1251 llama_model_meta_val_str_by_index(
1252 self.model.as_ptr(),
1253 index,
1254 buf.as_mut_ptr().cast::<c_char>(),
1255 buf_size,
1256 )
1257 };
1258 if ret < 0 {
1259 return Err(StringFromModelError::ReturnedError(ret));
1260 }
1261 let len = ret as usize;
1262 let s = std::str::from_utf8(&buf[..len])
1263 .map_err(StringFromModelError::Utf8Error)?;
1264 Ok(s.to_owned())
1265 }
1266
1267 /// Get a metadata value by key name.
1268 ///
1269 /// This is more convenient than iterating metadata by index when you know the key.
1270 /// The `buf_size` parameter specifies the maximum buffer size for the value.
1271 ///
1272 /// # Errors
1273 ///
1274 /// Returns an error if the key is not found, contains a null byte, or the value is not valid UTF-8.
1275 #[allow(clippy::cast_sign_loss)]
1276 pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
1277 let c_key = CString::new(key)
1278 .map_err(|_| StringFromModelError::ReturnedError(-1))?;
1279 let mut buf = vec![0u8; buf_size];
1280 let ret = unsafe {
1281 llama_model_meta_val_str(
1282 self.model.as_ptr(),
1283 c_key.as_ptr(),
1284 buf.as_mut_ptr().cast::<c_char>(),
1285 buf_size,
1286 )
1287 };
1288 if ret < 0 {
1289 return Err(StringFromModelError::ReturnedError(ret));
1290 }
1291 let len = ret as usize;
1292 let s = std::str::from_utf8(&buf[..len])
1293 .map_err(StringFromModelError::Utf8Error)?;
1294 Ok(s.to_owned())
1295 }
1296
1297 /// Get all metadata as a list of `(key, value)` pairs.
1298 ///
1299 /// This is a convenience method that iterates over all metadata entries.
1300 /// Keys use a buffer of 256 bytes and values use 4096 bytes.
1301 /// For values that may be larger (e.g. token lists), use
1302 /// [`meta_val_str_by_index`](Self::meta_val_str_by_index) directly with a larger buffer.
1303 ///
1304 /// # Errors
1305 ///
1306 /// Returns an error if any key or value cannot be read or is not valid UTF-8.
1307 #[allow(clippy::cast_sign_loss)]
1308 pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
1309 let count = self.meta_count();
1310 let mut result = Vec::with_capacity(count as usize);
1311 for i in 0..count {
1312 let key = self.meta_key_by_index(i, 256)?;
1313 let val = self.meta_val_str_by_index(i, 4096)?;
1314 result.push((key, val));
1315 }
1316 Ok(result)
1317 }
1318
    /// Check if the model has an encoder.
    #[must_use]
    pub fn has_encoder(&self) -> bool {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_has_encoder(self.model.as_ptr()) }
    }
1324
    /// Check if the model has a decoder.
    #[must_use]
    pub fn has_decoder(&self) -> bool {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_has_decoder(self.model.as_ptr()) }
    }
1330
    /// Check if the model is recurrent (e.g. Mamba, RWKV).
    #[must_use]
    pub fn is_recurrent(&self) -> bool {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_is_recurrent(self.model.as_ptr()) }
    }
1336
    /// Check if the model is a hybrid model.
    #[must_use]
    pub fn is_hybrid(&self) -> bool {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_is_hybrid(self.model.as_ptr()) }
    }
1342
    /// Check if the model is a diffusion model.
    #[must_use]
    pub fn is_diffusion(&self) -> bool {
        // SAFETY: `self.model` is a valid, non-null model pointer for the lifetime of `self`.
        unsafe { llama_model_is_diffusion(self.model.as_ptr()) }
    }
1348
1349 /// Get chat template from model.
1350 ///
1351 /// # Errors
1352 ///
1353 /// - If the model does not have a chat template, it will return an error.
1354 /// - If the chat template is not a valid `CString`, it will return an error.
1355 ///
1356 /// # Example
1357 ///
1358 /// ```no_run
1359 /// use llama_cpp_4::model::LlamaModel;
1360 /// use llama_cpp_4::model::params::LlamaModelParams;
1361 /// use llama_cpp_4::llama_backend::LlamaBackend;
1362 ///
1363 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1364 /// let backend = LlamaBackend::init()?;
1365 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1366 /// let chat_template = model.get_chat_template(1024)?;
1367 /// # Ok(())
1368 /// # }
1369 /// ```
1370 #[allow(clippy::missing_panics_doc)] // We statically know this will not panic as long as the buffer size is sufficient
1371 pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
1372 // longest known template is about 1200 bytes from llama.cpp
1373 let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
1374 let chat_ptr = chat_temp.into_raw();
1375 let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
1376
1377 let ret = unsafe {
1378 llama_model_meta_val_str(self.model.as_ptr(), chat_name.as_ptr(), chat_ptr, buf_size)
1379 };
1380
1381 if ret < 0 {
1382 return Err(ChatTemplateError::MissingTemplate(ret));
1383 }
1384
1385 let template_c = unsafe { CString::from_raw(chat_ptr) };
1386 let template = template_c.to_str()?;
1387
1388 let ret: usize = ret.try_into().unwrap();
1389 if template.len() < ret {
1390 return Err(ChatTemplateError::BuffSizeError(ret + 1));
1391 }
1392
1393 Ok(template.to_owned())
1394 }
1395
1396 /// Loads a model from a file.
1397 ///
1398 /// This function loads a model from a specified file path and returns the corresponding `LlamaModel` instance.
1399 ///
1400 /// # Errors
1401 ///
1402 /// - If the path cannot be converted to a string or if the model file does not exist, it will return an error.
1403 /// - If the model cannot be loaded (e.g., due to an invalid or corrupted model file), it will return a `LlamaModelLoadError`.
1404 ///
1405 /// # Example
1406 ///
1407 /// ```no_run
1408 /// use llama_cpp_4::model::LlamaModel;
1409 /// use llama_cpp_4::model::params::LlamaModelParams;
1410 /// use llama_cpp_4::llama_backend::LlamaBackend;
1411 ///
1412 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1413 /// let backend = LlamaBackend::init()?;
1414 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1415 /// # Ok(())
1416 /// # }
1417 /// ```
1418 #[tracing::instrument(skip_all, fields(params))]
1419 pub fn load_from_file(
1420 _: &LlamaBackend,
1421 path: impl AsRef<Path>,
1422 params: &LlamaModelParams,
1423 ) -> Result<Self, LlamaModelLoadError> {
1424 let path = path.as_ref();
1425 debug_assert!(
1426 Path::new(path).exists(),
1427 "{} does not exist",
1428 path.display()
1429 );
1430 let path = path
1431 .to_str()
1432 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1433
1434 let cstr = CString::new(path)?;
1435 let llama_model = unsafe { llama_load_model_from_file(cstr.as_ptr(), params.params) };
1436
1437 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1438
1439 tracing::debug!(?path, "Loaded model");
1440 Ok(LlamaModel { model })
1441 }
1442
1443 /// Load a model from multiple split files.
1444 ///
1445 /// This function loads a model that has been split across multiple files. This is useful for
1446 /// very large models that exceed filesystem limitations or need to be distributed across
1447 /// multiple storage devices.
1448 ///
1449 /// # Arguments
1450 ///
1451 /// * `paths` - A slice of paths to the split model files
1452 /// * `params` - The model parameters
1453 ///
1454 /// # Errors
1455 ///
1456 /// Returns an error if:
1457 /// - Any of the paths cannot be converted to a C string
1458 /// - The model fails to load from the splits
1459 /// - Any path doesn't exist or isn't accessible
1460 ///
1461 /// # Example
1462 ///
1463 /// ```no_run
1464 /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
1465 /// use llama_cpp_4::llama_backend::LlamaBackend;
1466 ///
1467 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1468 /// let backend = LlamaBackend::init()?;
1469 /// let params = LlamaModelParams::default();
1470 ///
1471 /// let paths = vec![
1472 /// "model-00001-of-00003.gguf",
1473 /// "model-00002-of-00003.gguf",
1474 /// "model-00003-of-00003.gguf",
1475 /// ];
1476 ///
1477 /// let model = LlamaModel::load_from_splits(&backend, &paths, ¶ms)?;
1478 /// # Ok(())
1479 /// # }
1480 /// ```
1481 #[tracing::instrument(skip_all)]
1482 pub fn load_from_splits(
1483 _: &LlamaBackend,
1484 paths: &[impl AsRef<Path>],
1485 params: &LlamaModelParams,
1486 ) -> Result<Self, LlamaModelLoadError> {
1487 // Convert paths to C strings
1488 let c_strings: Vec<CString> = paths
1489 .iter()
1490 .map(|p| {
1491 let path = p.as_ref();
1492 debug_assert!(path.exists(), "{} does not exist", path.display());
1493 let path_str = path
1494 .to_str()
1495 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1496 CString::new(path_str).map_err(LlamaModelLoadError::from)
1497 })
1498 .collect::<Result<Vec<_>, _>>()?;
1499
1500 // Create array of pointers to C strings
1501 let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
1502
1503 // Load the model from splits
1504 let llama_model = unsafe {
1505 llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
1506 };
1507
1508 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1509
1510 tracing::debug!("Loaded model from {} splits", paths.len());
1511 Ok(LlamaModel { model })
1512 }
1513
    /// Load a model from a `FILE` pointer.
    ///
    /// # Safety
    ///
    /// The `file` pointer must be a valid, open `FILE*`.
    /// NOTE(review): it is not visible from this wrapper whether llama.cpp takes
    /// ownership of (and closes) `file` — confirm against the pinned `llama.h`.
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be loaded.
    pub unsafe fn load_from_file_ptr(
        file: *mut llama_cpp_sys_4::FILE,
        params: &LlamaModelParams,
    ) -> Result<Self, LlamaModelLoadError> {
        // SAFETY: the caller guarantees `file` is a valid open FILE*;
        // `params.params` is a valid params struct owned by `params`.
        let model = llama_cpp_sys_4::llama_model_load_from_file_ptr(file, params.params);
        // A null return from llama.cpp signals a load failure.
        let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
        Ok(LlamaModel { model })
    }
1531
    /// Initialize a model from user-provided data.
    ///
    /// # Safety
    ///
    /// The metadata, callback, and user data must be valid.
    /// NOTE(review): the required lifetimes of `metadata` and
    /// `set_tensor_data_ud` relative to the returned model are not visible from
    /// this wrapper — confirm against the pinned `llama.h`.
    ///
    /// # Errors
    ///
    /// Returns an error if the model cannot be initialized.
    pub unsafe fn init_from_user(
        metadata: *mut llama_cpp_sys_4::gguf_context,
        set_tensor_data: llama_cpp_sys_4::llama_model_set_tensor_data_t,
        set_tensor_data_ud: *mut std::ffi::c_void,
        params: &LlamaModelParams,
    ) -> Result<Self, LlamaModelLoadError> {
        // SAFETY: the caller guarantees the metadata, callback, and user data
        // pointers satisfy llama.cpp's requirements.
        let model = llama_cpp_sys_4::llama_model_init_from_user(
            metadata,
            set_tensor_data,
            set_tensor_data_ud,
            params.params,
        );
        // A null return from llama.cpp signals an initialization failure.
        let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
        Ok(LlamaModel { model })
    }
1556
    /// Save the model to a file.
    ///
    /// # Panics
    ///
    /// Panics if the path is not valid UTF-8 or contains null bytes.
    pub fn save_to_file(&self, path: impl AsRef<Path>) {
        let path = path.as_ref();
        let path_str = path.to_str().expect("path is not valid UTF-8");
        let c_path = CString::new(path_str).expect("path contains null bytes");
        // SAFETY: `self.model` and `c_path` are valid for the duration of the call.
        unsafe {
            llama_model_save_to_file(self.model.as_ptr(), c_path.as_ptr());
        }
    }
1570
1571 /// Get the list of built-in chat templates.
1572 ///
1573 /// Returns the names of all chat templates that are built into llama.cpp.
1574 ///
1575 /// # Panics
1576 ///
1577 /// Panics if any template name is not valid UTF-8.
1578 #[allow(clippy::cast_sign_loss)]
1579 #[must_use]
1580 pub fn chat_builtin_templates() -> Vec<String> {
1581 // First call to get count
1582 let count = unsafe { llama_chat_builtin_templates(std::ptr::null_mut(), 0) };
1583 if count <= 0 {
1584 return Vec::new();
1585 }
1586 let count = count as usize;
1587 let mut ptrs: Vec<*const c_char> = vec![std::ptr::null(); count];
1588 unsafe {
1589 llama_chat_builtin_templates(ptrs.as_mut_ptr(), count);
1590 }
1591 ptrs.iter()
1592 .map(|&p| {
1593 let cstr = unsafe { CStr::from_ptr(p) };
1594 cstr.to_str()
1595 .expect("template name is not valid UTF-8")
1596 .to_owned()
1597 })
1598 .collect()
1599 }
1600
1601 /// Initializes a lora adapter from a file.
1602 ///
1603 /// This function initializes a Lora adapter, which is a model extension used to adapt or fine-tune the existing model
1604 /// to a specific domain or task. The adapter file is typically in the form of a binary or serialized file that can be applied
1605 /// to the model for improved performance on specialized tasks.
1606 ///
1607 /// # Errors
1608 ///
1609 /// - If the adapter file path cannot be converted to a string or if the adapter cannot be initialized, it will return an error.
1610 ///
1611 /// # Example
1612 ///
1613 /// ```no_run
1614 /// use llama_cpp_4::model::LlamaModel;
1615 /// use llama_cpp_4::model::params::LlamaModelParams;
1616 /// use llama_cpp_4::llama_backend::LlamaBackend;
1617 ///
1618 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1619 /// let backend = LlamaBackend::init()?;
1620 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1621 /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
1622 /// # Ok(())
1623 /// # }
1624 /// ```
1625 pub fn lora_adapter_init(
1626 &self,
1627 path: impl AsRef<Path>,
1628 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
1629 let path = path.as_ref();
1630 debug_assert!(
1631 Path::new(path).exists(),
1632 "{} does not exist",
1633 path.display()
1634 );
1635
1636 let path = path
1637 .to_str()
1638 .ok_or(LlamaLoraAdapterInitError::PathToStrError(
1639 path.to_path_buf(),
1640 ))?;
1641
1642 let cstr = CString::new(path)?;
1643 let adapter = unsafe { llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
1644
1645 let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
1646
1647 tracing::debug!(?path, "Initialized lora adapter");
1648 Ok(LlamaLoraAdapter {
1649 lora_adapter: adapter,
1650 })
1651 }
1652
    /// Create a new context from this model.
    ///
    /// This function creates a new context for the model, which is used to manage and perform computations for inference,
    /// including token generation, embeddings, and other tasks that the model can perform. The context allows fine-grained
    /// control over model parameters for a specific task.
    ///
    /// # Errors
    ///
    /// - There are various potential failures such as invalid parameters or a failure to allocate the context. See [`LlamaContextLoadError`]
    /// for more detailed error descriptions.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::context::params::LlamaContextParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let context = model.new_context(&backend, LlamaContextParams::default())?;
    /// # Ok(())
    /// # }
    /// ```
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context(
        &self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
        // Apply TurboQuant attn-rotation preference before the KV cache is
        // initialised inside llama_new_context_with_model.
        //
        // NOTE(review): communicating this flag through a process-wide env var
        // is not thread-safe — two threads creating contexts concurrently can
        // observe each other's temporary value. Confirm whether a params field
        // could carry it instead.
        let prev_rot_var = std::env::var("LLAMA_ATTN_ROT_DISABLE").ok();
        if params.attn_rot_disabled {
            // SAFETY: we restore the value right after the call.
            #[allow(unused_unsafe)]
            unsafe {
                std::env::set_var("LLAMA_ATTN_ROT_DISABLE", "1");
            }
        } else if std::env::var("LLAMA_ATTN_ROT_DISABLE").is_ok() {
            // params say "enabled" – only clear if it was previously unset
            // (respect explicit user env var).
            // Intentionally a no-op: a user-provided env var is left untouched.
        }

        let context_params = params.context_params;
        // SAFETY: `self.model` is a valid model pointer; llama.cpp copies what
        // it needs from `context_params`.
        let context = unsafe { llama_new_context_with_model(self.model.as_ptr(), context_params) };

        // Restore the env-var to its previous state.
        #[allow(unused_unsafe)]
        match prev_rot_var {
            // It was set before the call (by the user or ourselves): put the old value back.
            Some(v) => unsafe { std::env::set_var("LLAMA_ATTN_ROT_DISABLE", v) },
            // It was unset and we set it above: remove it again.
            None if params.attn_rot_disabled => unsafe {
                std::env::remove_var("LLAMA_ATTN_ROT_DISABLE");
            },
            None => {}
        }

        // A null context means llama.cpp failed to allocate/initialise it.
        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
        Ok(LlamaContext::new(self, context, params.embeddings()))
    }
1715
    /// Apply the model's chat template to a sequence of messages.
    ///
    /// This function applies the model's chat template to the provided chat messages, formatting them accordingly. The chat
    /// template determines the structure or style of conversation between the system and user, such as token formatting,
    /// role separation, and more. The template can be customized by providing an optional template string, or if `None`
    /// is provided, the default template used by `llama.cpp` will be applied.
    ///
    /// For more information on supported templates, visit:
    /// <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
    ///
    /// # Arguments
    ///
    /// - `tmpl`: An optional custom template string. If `None`, the default template will be used.
    /// - `chat`: A vector of `LlamaChatMessage` instances, which represent the conversation between the system and user.
    /// - `add_ass`: A boolean flag indicating whether additional system-specific instructions (like "assistant") should be included.
    ///
    /// # Errors
    ///
    /// There are several possible points of failure when applying the chat template:
    /// - Insufficient buffer size to hold the formatted chat (this will return `ApplyChatTemplateError::BuffSizeError`).
    /// - If the template or messages cannot be processed properly, various errors from `ApplyChatTemplateError` may occur.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{LlamaModel, LlamaChatMessage};
    /// use llama_cpp_4::model::params::LlamaModelParams;
    /// use llama_cpp_4::llama_backend::LlamaBackend;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
    /// let chat = vec![
    ///     LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
    ///     LlamaChatMessage::new("assistant".to_string(), "Hi! How can I assist you today?".to_string())?,
    /// ];
    /// let formatted_chat = model.apply_chat_template(None, &chat, true)?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Notes
    ///
    /// The provided buffer is twice the length of the messages by default, which is recommended by the `llama.cpp` documentation.
    /// # Panics
    ///
    /// Panics if the buffer length exceeds `i32::MAX`.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: Option<&str>,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        // Compute raw message byte total from the original LlamaChatMessage vec
        // *before* we shadow `chat` with the sys-type vec below.
        let message_length = chat.iter().fold(0usize, |acc, c| {
            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
        });

        // Build our llama_cpp_sys chat messages (raw pointers into CStrings).
        // The pointers stay valid because they borrow from `chat`, which
        // outlives this function call.
        let chat_sys: Vec<llama_chat_message> = chat
            .iter()
            .map(|c| llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        // Set the tmpl pointer.
        // `tmpl_cstring` must stay alive in this binding for as long as
        // `tmpl_ptr` is used below.
        let tmpl_cstring = tmpl.map(CString::new).transpose()?;
        let tmpl_ptr = tmpl_cstring
            .as_ref()
            .map_or(std::ptr::null(), |s| s.as_ptr());

        // `message_length * 4` is far too small for models whose built-in chat
        // template adds a long default system prompt (e.g. Qwen3.5 prepends
        // ~80+ chars of markup even for a one-word user message). Start with
        // at least 4 KiB so short inputs like "hi" always have room.
        //
        // `llama_chat_apply_template` returns the number of bytes it *actually*
        // needed when the buffer was too small, so we retry exactly once with
        // that precise size rather than giving up immediately.
        let mut buf_size = message_length.saturating_mul(4).max(4096);

        for _ in 0..2 {
            // Use u8 so that as_mut_ptr()/as_ptr() match the binding (*mut u8 / *const u8).
            let mut buff = vec![0u8; buf_size];
            let res = unsafe {
                llama_chat_apply_template(
                    tmpl_ptr,
                    chat_sys.as_ptr(),
                    chat_sys.len(),
                    add_ass,
                    buff.as_mut_ptr().cast(),
                    i32::try_from(buff.len()).expect("buffer length fits in i32"),
                )
            };

            // NOTE(review): a negative result is treated as "cannot apply"
            // (e.g. unrecognised template) — confirm the full set of negative
            // return meanings against the pinned llama.cpp version.
            if res < 0 {
                return Err(ApplyChatTemplateError::BuffSizeError);
            }

            #[allow(clippy::cast_sign_loss)]
            let needed = res as usize;
            if needed > buf_size {
                // Buffer was too small — retry with the exact size llama.cpp reported.
                buf_size = needed + 1; // +1 for null terminator
                continue;
            }

            // SAFETY: llama_chat_apply_template wrote a NUL-terminated string
            // into `buff`; `needed` bytes were used.
            // Note: `to_string_lossy` silently replaces invalid UTF-8 with
            // U+FFFD rather than returning an error.
            let formatted = unsafe {
                CStr::from_ptr(buff.as_ptr().cast())
                    .to_string_lossy()
                    .into_owned()
            };
            return Ok(formatted);
        }

        // Unreachable in practice: the second iteration uses the exact size
        // reported by the first, so it either succeeds or errors out above.
        Err(ApplyChatTemplateError::BuffSizeError)
    }
1839
1840 /// Build a split GGUF file path for a specific chunk.
1841 ///
1842 /// This utility function creates the standardized filename for a split model chunk
1843 /// following the pattern: `{prefix}-{split_no:05d}-of-{split_count:05d}.gguf`
1844 ///
1845 /// # Arguments
1846 ///
1847 /// * `path_prefix` - The base path and filename prefix
1848 /// * `split_no` - The split number (1-indexed)
1849 /// * `split_count` - The total number of splits
1850 ///
1851 /// # Returns
1852 ///
1853 /// Returns the formatted split path as a String
1854 ///
1855 /// # Example
1856 ///
1857 /// ```
1858 /// use llama_cpp_4::model::LlamaModel;
1859 ///
1860 /// let path = LlamaModel::split_path("/models/llama", 1, 4);
1861 /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
1862 /// ```
1863 ///
1864 /// # Panics
1865 ///
1866 /// Panics if the path prefix contains a null byte.
1867 #[must_use]
1868 pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
1869 let mut buffer = vec![0u8; 1024];
1870 let len = unsafe {
1871 llama_split_path(
1872 buffer.as_mut_ptr().cast::<c_char>(),
1873 buffer.len(),
1874 CString::new(path_prefix).unwrap().as_ptr(),
1875 split_no,
1876 split_count,
1877 )
1878 };
1879
1880 let len = usize::try_from(len).expect("split_path length fits in usize");
1881 buffer.truncate(len);
1882 String::from_utf8(buffer).unwrap_or_default()
1883 }
1884
1885 /// Extract the path prefix from a split filename.
1886 ///
1887 /// This function extracts the base path prefix from a split model filename,
1888 /// but only if the `split_no` and `split_count` match the pattern in the filename.
1889 ///
1890 /// # Arguments
1891 ///
1892 /// * `split_path` - The full path to the split file
1893 /// * `split_no` - The expected split number
1894 /// * `split_count` - The expected total number of splits
1895 ///
1896 /// # Returns
1897 ///
1898 /// Returns the path prefix if the pattern matches, or None if it doesn't
1899 ///
1900 /// # Example
1901 ///
1902 /// ```
1903 /// use llama_cpp_4::model::LlamaModel;
1904 ///
1905 /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 1, 4);
1906 /// assert_eq!(prefix, Some("/models/llama".to_string()));
1907 /// ```
1908 ///
1909 /// # Panics
1910 ///
1911 /// Panics if the split path contains a null byte.
1912 #[must_use]
1913 pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
1914 let mut buffer = vec![0u8; 1024];
1915 let len = unsafe {
1916 llama_split_prefix(
1917 buffer.as_mut_ptr().cast::<c_char>(),
1918 buffer.len(),
1919 CString::new(split_path).unwrap().as_ptr(),
1920 split_no,
1921 split_count,
1922 )
1923 };
1924
1925 if len > 0 {
1926 let len = usize::try_from(len).expect("split_prefix length fits in usize");
1927 buffer.truncate(len);
1928 String::from_utf8(buffer).ok()
1929 } else {
1930 None
1931 }
1932 }
1933}
1934
1935#[allow(clippy::cast_precision_loss)]
1936impl fmt::Display for LlamaModel {
1937 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1938 let desc = self.desc(256).unwrap_or_else(|_| "unknown".to_string());
1939 write!(
1940 f,
1941 "{desc} | {layers}L {heads}H {embd}E | {params} params | {size:.1} MiB",
1942 layers = self.n_layer(),
1943 heads = self.n_head(),
1944 embd = self.n_embd(),
1945 params = self.n_params(),
1946 size = self.model_size() as f64 / (1024.0 * 1024.0),
1947 )
1948 }
1949}
1950
1951impl Drop for LlamaModel {
1952 fn drop(&mut self) {
1953 unsafe { llama_free_model(self.model.as_ptr()) }
1954 }
1955}
1956
/// Defines the possible types of vocabulary used by the model.
///
/// The model may use different types of vocabulary depending on the tokenization method chosen during training.
/// This enum represents these types, specifically `BPE` (Byte Pair Encoding) and `SPM` (`SentencePiece`).
///
/// The discriminants mirror the C-side `llama_vocab_type` constants
/// (`#[repr(u32)]`), so the enum round-trips through the FFI value via the
/// `TryFrom<llama_vocab_type>` impl in this file.
///
/// # Variants
///
/// - `BPE`: Byte Pair Encoding, a common tokenization method used in NLP tasks.
/// - `SPM`: `SentencePiece`, another popular tokenization method for NLP models.
///
/// # Example
///
/// ```no_run
/// use llama_cpp_4::model::VocabType;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let vocab_type = VocabType::BPE;
/// match vocab_type {
///     VocabType::BPE => println!("The model uses Byte Pair Encoding (BPE)"),
///     VocabType::SPM => println!("The model uses SentencePiece (SPM)"),
/// }
/// # Ok(())
/// # }
/// ```
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
    /// Byte Pair Encoding
    BPE = LLAMA_VOCAB_TYPE_BPE as _,
    /// Sentence Piece Tokenizer
    SPM = LLAMA_VOCAB_TYPE_SPM as _,
}
1989
/// Error that occurs when trying to convert a `llama_vocab_type` to a `VocabType`.
///
/// This error is raised when the integer value returned by the system does not correspond to a known vocabulary type.
///
/// NOTE(review): despite the `TokenType` in its name, this error is produced by
/// the `llama_vocab_type` -> `VocabType` conversion (see the payload type
/// below). A rename would be a breaking change, so it is only flagged here.
///
/// # Variants
///
/// - `UnknownValue`: The error is raised when the value is not a valid `llama_vocab_type`. The invalid value is returned with the error.
///
/// # Example
///
/// ```no_run
/// use llama_cpp_4::model::LlamaTokenTypeFromIntError;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// let invalid_value = 999; // Not a valid vocabulary type
/// let error = LlamaTokenTypeFromIntError::UnknownValue(invalid_value);
/// println!("Error: {}", error);
/// # Ok(())
/// # }
/// ```
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
    /// The value is not a valid `llama_vocab_type`. Contains the raw value that was rejected.
    #[error("Unknown Value {0}")]
    UnknownValue(llama_vocab_type),
}
2016
2017impl TryFrom<llama_vocab_type> for VocabType {
2018 type Error = LlamaTokenTypeFromIntError;
2019
2020 fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
2021 match value {
2022 LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
2023 LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
2024 unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
2025 }
2026 }
2027}