llama_cpp_4/model.rs
1//! A safe wrapper around `llama_model`.
2use std::ffi::CStr;
3use std::ffi::CString;
4use std::fmt;
5use std::num::NonZeroU16;
6use std::os::raw::{c_char, c_int};
7use std::path::Path;
8use std::ptr::NonNull;
9use std::slice;
10
11use llama_cpp_sys_4::{
12 llama_adapter_lora, llama_adapter_lora_init, llama_chat_apply_template,
13 llama_chat_builtin_templates, llama_chat_message, llama_detokenize, llama_init_from_model,
14 llama_model, llama_model_cls_label, llama_model_decoder_start_token, llama_model_desc,
15 llama_model_free, llama_model_get_device, llama_model_get_vocab, llama_model_has_decoder,
16 llama_model_has_encoder, llama_model_is_diffusion, llama_model_is_hybrid,
17 llama_model_is_recurrent, llama_model_load_from_file, llama_model_load_from_splits,
18 llama_model_meta_count, llama_model_meta_key_by_index, llama_model_meta_val_str,
19 llama_model_meta_val_str_by_index, llama_model_n_cls_out, llama_model_n_ctx_train,
20 llama_model_n_devices, llama_model_n_embd, llama_model_n_embd_inp, llama_model_n_embd_out,
21 llama_model_n_expert, llama_model_n_head, llama_model_n_head_kv, llama_model_n_layer,
22 llama_model_n_layer_nextn, llama_model_n_params, llama_model_n_swa,
23 llama_model_rope_freq_scale_train, llama_model_rope_type, llama_model_save_to_file,
24 llama_model_size, llama_model_target_layer_ids, llama_model_target_layer_ids_n,
25 llama_split_path, llama_split_prefix, llama_token_to_piece, llama_tokenize, llama_vocab,
26 llama_vocab_type, LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
27};
28
29use crate::context::params::LlamaContextParams;
30use crate::context::LlamaContext;
31use crate::llama_backend::LlamaBackend;
32use crate::model::params::LlamaModelParams;
33use crate::token::LlamaToken;
34use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
35use crate::{
36 ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
37 LlamaModelLoadError, NewLlamaChatMessageError, StringFromModelError, StringToTokenError,
38 TokenToStringError,
39};
40
41pub mod params;
42
43/// Opaque ggml backend device handle returned by [`LlamaModel::get_device`].
44///
45/// Use [`Self::name`], [`Self::description`], [`Self::device_type`], and
46/// [`Self::memory`] to inspect the device. The handle is valid for the lifetime
47/// of the parent [`LlamaModel`].
48#[derive(Debug, Copy, Clone, PartialEq, Eq)]
49pub struct LlamaBackendDevice {
50 pub(crate) dev: llama_cpp_sys_4::ggml_backend_dev_t,
51}
52
53/// Backend device class (CPU, discrete GPU, integrated GPU, …).
54#[repr(i32)]
55#[derive(Copy, Clone, Debug, PartialEq, Eq)]
56pub enum LlamaBackendDeviceType {
57 /// Host CPU backend.
58 Cpu = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_CPU.cast_signed(),
59 /// Discrete GPU.
60 Gpu = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_GPU.cast_signed(),
61 /// Integrated GPU.
62 IntegratedGpu = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_IGPU.cast_signed(),
63 /// Accelerator device (e.g. BLAS / Hexagon).
64 Accel = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_ACCEL.cast_signed(),
65 /// Meta / placeholder device entry.
66 Meta = llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_META.cast_signed(),
67}
68
69impl From<llama_cpp_sys_4::ggml_backend_dev_type> for LlamaBackendDeviceType {
70 fn from(value: llama_cpp_sys_4::ggml_backend_dev_type) -> Self {
71 match value {
72 llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_CPU => Self::Cpu,
73 llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_GPU => Self::Gpu,
74 llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_IGPU => Self::IntegratedGpu,
75 llama_cpp_sys_4::GGML_BACKEND_DEVICE_TYPE_ACCEL => Self::Accel,
76 _ => Self::Meta,
77 }
78 }
79}
80
81impl LlamaBackendDevice {
82 /// Human-readable device name (e.g. `CUDA0`, `Metal`).
83 ///
84 /// # Errors
85 ///
86 /// Returns an error when the name pointer is null or not valid UTF-8.
87 pub fn name(&self) -> Result<&str, StringFromModelError> {
88 let ptr = unsafe { llama_cpp_sys_4::ggml_backend_dev_name(self.dev) };
89 if ptr.is_null() {
90 return Err(StringFromModelError::ReturnedError(-1));
91 }
92 let cstr = unsafe { CStr::from_ptr(ptr) };
93 cstr.to_str().map_err(StringFromModelError::Utf8Error)
94 }
95
96 /// Longer device description (often includes hardware name).
97 ///
98 /// # Errors
99 ///
100 /// Returns an error when the description pointer is null or not valid UTF-8.
101 pub fn description(&self) -> Result<&str, StringFromModelError> {
102 let ptr = unsafe { llama_cpp_sys_4::ggml_backend_dev_description(self.dev) };
103 if ptr.is_null() {
104 return Err(StringFromModelError::ReturnedError(-1));
105 }
106 let cstr = unsafe { CStr::from_ptr(ptr) };
107 cstr.to_str().map_err(StringFromModelError::Utf8Error)
108 }
109
110 /// Device class (CPU, GPU, integrated GPU, …).
111 #[must_use]
112 pub fn device_type(&self) -> LlamaBackendDeviceType {
113 unsafe { llama_cpp_sys_4::ggml_backend_dev_type(self.dev).into() }
114 }
115
116 /// Device memory `(free_bytes, total_bytes)`.
117 #[must_use]
118 pub fn memory(&self) -> (usize, usize) {
119 let mut free = 0usize;
120 let mut total = 0usize;
121 unsafe {
122 llama_cpp_sys_4::ggml_backend_dev_memory(self.dev, &raw mut free, &raw mut total);
123 }
124 (free, total)
125 }
126}
127
128/// Iterator over [`LlamaBackendDevice`] handles for a loaded model.
129///
130/// # Examples
131///
132/// ```no_run
133/// use llama_cpp_4::prelude::*;
134///
135/// fn main() {
136/// let backend = LlamaBackend::init().unwrap();
137/// let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default()).unwrap();
138/// for dev in model.devices() {
139/// let (free, total) = dev.memory();
140/// println!("{}: {} / {} bytes free", dev.name().unwrap(), free, total);
141/// }
142/// }
143/// ```
144#[derive(Debug, Clone, Copy)]
145pub struct LlamaBackendDevices<'a> {
146 model: &'a LlamaModel,
147 next: i32,
148}
149
150#[allow(clippy::copy_iterator)]
151impl Iterator for LlamaBackendDevices<'_> {
152 type Item = LlamaBackendDevice;
153
154 fn next(&mut self) -> Option<Self::Item> {
155 let dev = self.model.get_device(self.next)?;
156 self.next += 1;
157 Some(dev)
158 }
159
160 fn size_hint(&self) -> (usize, Option<usize>) {
161 let remaining = usize::try_from((self.model.n_devices() - self.next).max(0)).unwrap_or(0);
162 (remaining, Some(remaining))
163 }
164}
165
// Relies on `size_hint` returning identical lower and upper bounds.
// NOTE(review): exactness assumes `get_device` succeeds for every index below
// `n_devices()` — confirm against the model implementation.
impl ExactSizeIterator for LlamaBackendDevices<'_> {}
167
168/// A safe wrapper around `llama_model`.
169#[derive(Debug)]
170#[repr(transparent)]
171#[allow(clippy::module_name_repetitions)]
172pub struct LlamaModel {
173 pub(crate) model: NonNull<llama_model>,
174}
175
176/// A safe wrapper around `llama_vocab`.
177#[derive(Debug)]
178#[repr(transparent)]
179#[allow(clippy::module_name_repetitions)]
180pub struct LlamaVocab {
181 pub(crate) vocab: NonNull<llama_vocab>,
182}
183
184impl LlamaVocab {
185 /// Get the number of tokens in the vocabulary.
186 #[must_use]
187 pub fn n_tokens(&self) -> i32 {
188 unsafe { llama_cpp_sys_4::llama_vocab_n_tokens(self.vocab.as_ref()) }
189 }
190
191 /// Get the vocabulary type.
192 ///
193 /// # Panics
194 ///
195 /// Panics if the C API returns a vocabulary type that does not fit in `u32`.
196 #[must_use]
197 pub fn vocab_type(&self) -> u32 {
198 unsafe { llama_cpp_sys_4::llama_vocab_type(self.vocab.as_ref()) as u32 }
199 }
200
201 /// Get the BOS token.
202 #[must_use]
203 pub fn bos(&self) -> LlamaToken {
204 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_bos(self.vocab.as_ref()) })
205 }
206
207 /// Get the EOS token.
208 #[must_use]
209 pub fn eos(&self) -> LlamaToken {
210 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eos(self.vocab.as_ref()) })
211 }
212
213 /// Get the EOT (end of turn) token.
214 #[must_use]
215 pub fn eot(&self) -> LlamaToken {
216 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_eot(self.vocab.as_ref()) })
217 }
218
219 /// Get the CLS (classification) token.
220 #[must_use]
221 pub fn cls(&self) -> LlamaToken {
222 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_cls(self.vocab.as_ref()) })
223 }
224
225 /// Get the SEP (separator) token.
226 #[must_use]
227 pub fn sep(&self) -> LlamaToken {
228 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_sep(self.vocab.as_ref()) })
229 }
230
231 /// Get the NL (newline) token.
232 #[must_use]
233 pub fn nl(&self) -> LlamaToken {
234 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_nl(self.vocab.as_ref()) })
235 }
236
237 /// Get the PAD (padding) token.
238 #[must_use]
239 pub fn pad(&self) -> LlamaToken {
240 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_pad(self.vocab.as_ref()) })
241 }
242
243 /// Get the FIM prefix token.
244 #[must_use]
245 pub fn fim_pre(&self) -> LlamaToken {
246 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pre(self.vocab.as_ref()) })
247 }
248
249 /// Get the FIM suffix token.
250 #[must_use]
251 pub fn fim_suf(&self) -> LlamaToken {
252 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_suf(self.vocab.as_ref()) })
253 }
254
255 /// Get the FIM middle token.
256 #[must_use]
257 pub fn fim_mid(&self) -> LlamaToken {
258 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_mid(self.vocab.as_ref()) })
259 }
260
261 /// Get the FIM padding token.
262 #[must_use]
263 pub fn fim_pad(&self) -> LlamaToken {
264 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_pad(self.vocab.as_ref()) })
265 }
266
267 /// Get the FIM repository token.
268 #[must_use]
269 pub fn fim_rep(&self) -> LlamaToken {
270 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_rep(self.vocab.as_ref()) })
271 }
272
273 /// Get the FIM separator token.
274 #[must_use]
275 pub fn fim_sep(&self) -> LlamaToken {
276 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_fim_sep(self.vocab.as_ref()) })
277 }
278
279 /// Check whether BOS should be added.
280 #[must_use]
281 pub fn get_add_bos(&self) -> bool {
282 unsafe { llama_cpp_sys_4::llama_vocab_get_add_bos(self.vocab.as_ref()) }
283 }
284
285 /// Check whether EOS should be added.
286 #[must_use]
287 pub fn get_add_eos(&self) -> bool {
288 unsafe { llama_cpp_sys_4::llama_vocab_get_add_eos(self.vocab.as_ref()) }
289 }
290
291 /// Check whether SEP should be added.
292 #[must_use]
293 pub fn get_add_sep(&self) -> bool {
294 unsafe { llama_cpp_sys_4::llama_vocab_get_add_sep(self.vocab.as_ref()) }
295 }
296
297 /// Get the text representation of a token.
298 ///
299 /// # Errors
300 ///
301 /// Returns an error if the text pointer is null or not valid UTF-8.
302 pub fn get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
303 let ptr = unsafe { llama_cpp_sys_4::llama_vocab_get_text(self.vocab.as_ref(), token.0) };
304 if ptr.is_null() {
305 return Err(StringFromModelError::ReturnedError(-1));
306 }
307 let cstr = unsafe { CStr::from_ptr(ptr) };
308 cstr.to_str().map_err(StringFromModelError::Utf8Error)
309 }
310
311 /// Get the score of a token.
312 #[must_use]
313 pub fn get_score(&self, token: LlamaToken) -> f32 {
314 unsafe { llama_cpp_sys_4::llama_vocab_get_score(self.vocab.as_ref(), token.0) }
315 }
316
317 /// Get the attributes of a token.
318 ///
319 /// # Panics
320 ///
321 /// Panics if the C API returns attributes that do not fit in `u32`.
322 #[must_use]
323 pub fn get_attr(&self, token: LlamaToken) -> u32 {
324 unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.vocab.as_ref(), token.0) as u32 }
325 }
326
327 /// Check if a token is a control token.
328 #[must_use]
329 pub fn is_control(&self, token: LlamaToken) -> bool {
330 unsafe { llama_cpp_sys_4::llama_vocab_is_control(self.vocab.as_ref(), token.0) }
331 }
332
333 /// Check if a token is an end-of-generation token.
334 #[must_use]
335 pub fn is_eog(&self, token: LlamaToken) -> bool {
336 unsafe { llama_cpp_sys_4::llama_vocab_is_eog(self.vocab.as_ref(), token.0) }
337 }
338
339 /// Get the token mask value for the vocabulary.
340 #[must_use]
341 pub fn mask(&self) -> LlamaToken {
342 LlamaToken(unsafe { llama_cpp_sys_4::llama_vocab_mask(self.vocab.as_ref()) })
343 }
344}
345
346/// A safe wrapper around `llama_adapter_lora`.
347#[derive(Debug)]
348#[repr(transparent)]
349#[allow(clippy::module_name_repetitions)]
350pub struct LlamaLoraAdapter {
351 pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
352}
353
354impl LlamaLoraAdapter {
355 /// Get the number of metadata key-value pairs in the adapter.
356 #[must_use]
357 pub fn meta_count(&self) -> i32 {
358 unsafe { llama_cpp_sys_4::llama_adapter_meta_count(self.lora_adapter.as_ptr()) }
359 }
360
361 /// Get a metadata key by index.
362 ///
363 /// # Errors
364 ///
365 /// Returns an error if the index is out of range or the key is not valid UTF-8.
366 #[allow(clippy::cast_sign_loss)]
367 pub fn meta_key_by_index(
368 &self,
369 index: i32,
370 buf_size: usize,
371 ) -> Result<String, StringFromModelError> {
372 let mut buf = vec![0u8; buf_size];
373 let ret = unsafe {
374 llama_cpp_sys_4::llama_adapter_meta_key_by_index(
375 self.lora_adapter.as_ptr(),
376 index,
377 buf.as_mut_ptr().cast::<c_char>(),
378 buf_size,
379 )
380 };
381 if ret < 0 {
382 return Err(StringFromModelError::ReturnedError(ret));
383 }
384 let len = ret as usize;
385 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
386 Ok(s.to_owned())
387 }
388
389 /// Get a metadata value by key name.
390 ///
391 /// # Errors
392 ///
393 /// Returns an error if the key is not found or the value is not valid UTF-8.
394 #[allow(clippy::cast_sign_loss)]
395 pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
396 let c_key = CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
397 let mut buf = vec![0u8; buf_size];
398 let ret = unsafe {
399 llama_cpp_sys_4::llama_adapter_meta_val_str(
400 self.lora_adapter.as_ptr(),
401 c_key.as_ptr(),
402 buf.as_mut_ptr().cast::<c_char>(),
403 buf_size,
404 )
405 };
406 if ret < 0 {
407 return Err(StringFromModelError::ReturnedError(ret));
408 }
409 let len = ret as usize;
410 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
411 Ok(s.to_owned())
412 }
413
414 /// Get a metadata value by index.
415 ///
416 /// # Errors
417 ///
418 /// Returns an error if the index is out of range or the value is not valid UTF-8.
419 #[allow(clippy::cast_sign_loss)]
420 pub fn meta_val_str_by_index(
421 &self,
422 index: i32,
423 buf_size: usize,
424 ) -> Result<String, StringFromModelError> {
425 let mut buf = vec![0u8; buf_size];
426 let ret = unsafe {
427 llama_cpp_sys_4::llama_adapter_meta_val_str_by_index(
428 self.lora_adapter.as_ptr(),
429 index,
430 buf.as_mut_ptr().cast::<c_char>(),
431 buf_size,
432 )
433 };
434 if ret < 0 {
435 return Err(StringFromModelError::ReturnedError(ret));
436 }
437 let len = ret as usize;
438 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
439 Ok(s.to_owned())
440 }
441
442 /// Get all metadata as a list of `(key, value)` pairs.
443 ///
444 /// # Errors
445 ///
446 /// Returns an error if any key or value cannot be read or is not valid UTF-8.
447 #[allow(clippy::cast_sign_loss)]
448 pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
449 let count = self.meta_count();
450 let mut result = Vec::with_capacity(count as usize);
451 for i in 0..count {
452 let key = self.meta_key_by_index(i, 256)?;
453 let val = self.meta_val_str_by_index(i, 4096)?;
454 result.push((key, val));
455 }
456 Ok(result)
457 }
458
459 /// Get the number of invocation tokens for this adapter.
460 #[must_use]
461 pub fn n_invocation_tokens(&self) -> u64 {
462 unsafe {
463 llama_cpp_sys_4::llama_adapter_get_alora_n_invocation_tokens(self.lora_adapter.as_ptr())
464 }
465 }
466
467 /// Get the invocation tokens for this adapter.
468 ///
469 /// Returns an empty slice if there are no invocation tokens.
470 #[must_use]
471 #[allow(clippy::cast_possible_truncation)]
472 pub fn invocation_tokens(&self) -> &[LlamaToken] {
473 let n = self.n_invocation_tokens() as usize;
474 if n == 0 {
475 return &[];
476 }
477 let ptr = unsafe {
478 llama_cpp_sys_4::llama_adapter_get_alora_invocation_tokens(self.lora_adapter.as_ptr())
479 };
480 if ptr.is_null() {
481 return &[];
482 }
483 // LlamaToken is repr(transparent) over llama_token (i32), so this cast is safe
484 unsafe { std::slice::from_raw_parts(ptr.cast::<LlamaToken>(), n) }
485 }
486}
487
488impl Drop for LlamaLoraAdapter {
489 fn drop(&mut self) {
490 unsafe {
491 llama_cpp_sys_4::llama_adapter_lora_free(self.lora_adapter.as_ptr());
492 }
493 }
494}
495
/// Owned, C-compatible storage for one chat message (the safe counterpart of
/// `llama_chat_message`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct LlamaChatMessage {
    /// Role of the message, stored as a NUL-terminated C string.
    role: CString,
    /// Content of the message, stored as a NUL-terminated C string.
    content: CString,
}
502
503impl LlamaChatMessage {
504 /// Create a new `LlamaChatMessage`.
505 ///
506 /// # Errors
507 ///
508 /// Returns [`NewLlamaChatMessageError`] if the role or content contains a null byte.
509 pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
510 Ok(Self {
511 role: CString::new(role)?,
512 content: CString::new(content)?,
513 })
514 }
515}
516
/// Controls whether a BOS token is prepended when tokenizing.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AddBos {
    /// Prepend the beginning-of-stream token to the tokenized string.
    Always,
    /// Leave the tokenized string without a beginning-of-stream token.
    Never,
}
525
/// Controls how special tokens are handled.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Special {
    /// Allow tokenizing special and/or control tokens, which are otherwise not
    /// exposed and are treated as plaintext. Does not insert a leading space.
    Tokenize,
    /// Treat special and/or control tokens as plaintext.
    Plaintext,
}
534
// `LlamaModel` holds a raw `NonNull<llama_model>`, so `Send` is not derived
// automatically and is asserted manually here.
// NOTE(review): assumes a `llama_model` may be moved to and freed on another
// thread — confirm against llama.cpp's threading guarantees.
unsafe impl Send for LlamaModel {}

// NOTE(review): assumes every llama.cpp function reachable through
// `&LlamaModel` may be called concurrently from several threads — confirm.
unsafe impl Sync for LlamaModel {}
538
539impl LlamaModel {
540 /// Retrieves the vocabulary associated with the current Llama model.
541 ///
542 /// This method fetches the vocabulary from the underlying model using an unsafe
543 /// FFI call. The returned `LlamaVocab` struct contains a non-null pointer to
544 /// the vocabulary data, which is wrapped in a `NonNull` for safety.
545 ///
546 /// # Safety
547 /// This method uses an unsafe block to call a C function (`llama_model_get_vocab`),
548 /// which is assumed to return a valid pointer to the vocabulary. The caller should
549 /// ensure that the model object is properly initialized and valid before calling
550 /// this method, as dereferencing invalid pointers can lead to undefined behavior.
551 ///
552 /// # Returns
553 /// A `LlamaVocab` struct containing the vocabulary of the model.
554 ///
555 /// # Panics
556 ///
557 /// Panics if the underlying C function returns a null pointer.
558 ///
559 /// # Example
560 /// ```rust,ignore
561 /// let vocab = model.get_vocab();
562 /// ```
563 #[must_use]
564 pub fn get_vocab(&self) -> LlamaVocab {
565 let llama_vocab = unsafe { llama_model_get_vocab(self.model.as_ptr()) }.cast_mut();
566
567 LlamaVocab {
568 vocab: NonNull::new(llama_vocab).unwrap(),
569 }
570 }
571 /// Get the number of tokens the model was trained on.
572 ///
573 /// This function returns the number of tokens that the model was trained on, represented as a `u32`.
574 ///
575 /// # Panics
576 ///
577 /// This function will panic if the number of tokens the model was trained on does not fit into a `u32`.
578 /// This should be impossible on most platforms since llama.cpp returns a `c_int` (i32 on most platforms),
579 /// which is almost certainly positive.
580 #[must_use]
581 pub fn n_ctx_train(&self) -> u32 {
582 let n_ctx_train = unsafe { llama_model_n_ctx_train(self.model.as_ptr()) };
583 u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
584 }
585
586 /// Get all tokens in the model.
587 ///
588 /// This function returns an iterator over all the tokens in the model. Each item in the iterator is a tuple
589 /// containing a `LlamaToken` and its corresponding string representation (or an error if the conversion fails).
590 ///
591 /// # Parameters
592 ///
593 /// - `special`: The `Special` value that determines how special tokens (like BOS, EOS, etc.) are handled.
594 pub fn tokens(
595 &self,
596 special: Special,
597 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
598 (0..self.n_vocab())
599 .map(LlamaToken::new)
600 .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
601 }
602
603 /// Get the beginning of stream token.
604 ///
605 /// This function returns the token that represents the beginning of a stream (BOS token).
606 #[must_use]
607 pub fn token_bos(&self) -> LlamaToken {
608 self.get_vocab().bos()
609 }
610
611 /// Get the end of stream token.
612 ///
613 /// This function returns the token that represents the end of a stream (EOS token).
614 #[must_use]
615 pub fn token_eos(&self) -> LlamaToken {
616 self.get_vocab().eos()
617 }
618
619 /// Get the newline token.
620 ///
621 /// This function returns the token that represents a newline character.
622 #[must_use]
623 pub fn token_nl(&self) -> LlamaToken {
624 self.get_vocab().nl()
625 }
626
627 /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
628 ///
629 /// This function returns `true` if the provided token signifies the end of generation or end of sequence,
630 /// such as EOS or other special tokens.
631 ///
632 /// # Parameters
633 ///
634 /// - `token`: The `LlamaToken` to check.
635 ///
636 /// # Returns
637 ///
638 /// - `true` if the token is an end-of-generation token, otherwise `false`.
639 #[must_use]
640 pub fn is_eog_token(&self, token: LlamaToken) -> bool {
641 self.get_vocab().is_eog(token)
642 }
643
644 /// Get the classification token.
645 #[must_use]
646 pub fn token_cls(&self) -> LlamaToken {
647 self.get_vocab().cls()
648 }
649
650 /// Get the end-of-turn token.
651 #[must_use]
652 pub fn token_eot(&self) -> LlamaToken {
653 self.get_vocab().eot()
654 }
655
656 /// Get the padding token.
657 #[must_use]
658 pub fn token_pad(&self) -> LlamaToken {
659 self.get_vocab().pad()
660 }
661
662 /// Get the separator token.
663 #[must_use]
664 pub fn token_sep(&self) -> LlamaToken {
665 self.get_vocab().sep()
666 }
667
668 /// Get the fill-in-the-middle prefix token.
669 #[must_use]
670 pub fn token_fim_pre(&self) -> LlamaToken {
671 self.get_vocab().fim_pre()
672 }
673
674 /// Get the fill-in-the-middle suffix token.
675 #[must_use]
676 pub fn token_fim_suf(&self) -> LlamaToken {
677 self.get_vocab().fim_suf()
678 }
679
680 /// Get the fill-in-the-middle middle token.
681 #[must_use]
682 pub fn token_fim_mid(&self) -> LlamaToken {
683 self.get_vocab().fim_mid()
684 }
685
686 /// Get the fill-in-the-middle padding token.
687 #[must_use]
688 pub fn token_fim_pad(&self) -> LlamaToken {
689 self.get_vocab().fim_pad()
690 }
691
692 /// Get the fill-in-the-middle repository token.
693 #[must_use]
694 pub fn token_fim_rep(&self) -> LlamaToken {
695 self.get_vocab().fim_rep()
696 }
697
698 /// Get the fill-in-the-middle separator token.
699 #[must_use]
700 pub fn token_fim_sep(&self) -> LlamaToken {
701 self.get_vocab().fim_sep()
702 }
703
704 /// Check if a token is a control token.
705 #[must_use]
706 pub fn token_is_control(&self, token: LlamaToken) -> bool {
707 self.get_vocab().is_control(token)
708 }
709
710 /// Get the score of a token.
711 #[must_use]
712 pub fn token_get_score(&self, token: LlamaToken) -> f32 {
713 self.get_vocab().get_score(token)
714 }
715
716 /// Get the raw text of a token.
717 ///
718 /// # Errors
719 ///
720 /// Returns an error if the token text is null or not valid UTF-8.
721 pub fn token_get_text(&self, token: LlamaToken) -> Result<&str, StringFromModelError> {
722 let ptr = unsafe {
723 llama_cpp_sys_4::llama_vocab_get_text(self.get_vocab().vocab.as_ref(), token.0)
724 };
725 if ptr.is_null() {
726 return Err(StringFromModelError::ReturnedError(-1));
727 }
728 let cstr = unsafe { CStr::from_ptr(ptr) };
729 cstr.to_str().map_err(StringFromModelError::Utf8Error)
730 }
731
732 /// Check if a BOS token should be added when tokenizing.
733 #[must_use]
734 pub fn add_bos_token(&self) -> bool {
735 self.get_vocab().get_add_bos()
736 }
737
738 /// Check if an EOS token should be added when tokenizing.
739 #[must_use]
740 pub fn add_eos_token(&self) -> bool {
741 self.get_vocab().get_add_eos()
742 }
743
744 /// Get the decoder start token.
745 ///
746 /// This function returns the token used to signal the start of decoding (i.e., the token used at the start
747 /// of a sequence generation).
748 #[must_use]
749 pub fn decode_start_token(&self) -> LlamaToken {
750 let token = unsafe { llama_model_decoder_start_token(self.model.as_ptr()) };
751 LlamaToken(token)
752 }
753
754 /// Convert a single token to a string.
755 ///
756 /// This function converts a `LlamaToken` into its string representation.
757 ///
758 /// # Errors
759 ///
760 /// This function returns an error if the token cannot be converted to a string. For more details, refer to
761 /// [`TokenToStringError`].
762 ///
763 /// # Parameters
764 ///
765 /// - `token`: The `LlamaToken` to convert.
766 /// - `special`: The `Special` value used to handle special tokens.
767 pub fn token_to_str(
768 &self,
769 token: LlamaToken,
770 special: Special,
771 ) -> Result<String, TokenToStringError> {
772 self.token_to_str_with_size(token, 32, special)
773 }
774
775 /// Convert a single token to bytes.
776 ///
777 /// This function converts a `LlamaToken` into a byte representation.
778 ///
779 /// # Errors
780 ///
781 /// This function returns an error if the token cannot be converted to bytes. For more details, refer to
782 /// [`TokenToStringError`].
783 ///
784 /// # Parameters
785 ///
786 /// - `token`: The `LlamaToken` to convert.
787 /// - `special`: The `Special` value used to handle special tokens.
788 pub fn token_to_bytes(
789 &self,
790 token: LlamaToken,
791 special: Special,
792 ) -> Result<Vec<u8>, TokenToStringError> {
793 self.token_to_bytes_with_size(token, 32, special, None)
794 }
795
796 /// Convert a vector of tokens to a single string.
797 ///
798 /// This function takes a slice of `LlamaToken`s and converts them into a single string, concatenating their
799 /// string representations.
800 ///
801 /// # Errors
802 ///
803 /// This function returns an error if any token cannot be converted to a string. For more details, refer to
804 /// [`TokenToStringError`].
805 ///
806 /// # Parameters
807 ///
808 /// - `tokens`: A slice of `LlamaToken`s to convert.
809 /// - `special`: The `Special` value used to handle special tokens.
810 pub fn tokens_to_str(
811 &self,
812 tokens: &[LlamaToken],
813 special: Special,
814 ) -> Result<String, TokenToStringError> {
815 let mut builder = String::with_capacity(tokens.len() * 4);
816 for str in tokens
817 .iter()
818 .copied()
819 .map(|t| self.token_to_str(t, special))
820 {
821 builder += &str?;
822 }
823 Ok(builder)
824 }
825
826 /// Convert a string to a vector of tokens.
827 ///
828 /// This function converts a string into a vector of `LlamaToken`s. The function will tokenize the string
829 /// and return the corresponding tokens.
830 ///
831 /// # Errors
832 ///
833 /// - This function will return an error if the input string contains a null byte.
834 ///
835 /// # Panics
836 ///
837 /// - This function will panic if the number of tokens exceeds `usize::MAX`.
838 ///
839 /// # Example
840 ///
841 /// ```no_run
842 /// use llama_cpp_4::model::LlamaModel;
843 ///
844 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
845 /// use std::path::Path;
846 /// use llama_cpp_4::model::AddBos;
847 /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
848 /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
849 /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
850 /// # Ok(())
851 /// # }
852 /// ```
853 pub fn str_to_token(
854 &self,
855 str: &str,
856 add_bos: AddBos,
857 ) -> Result<Vec<LlamaToken>, StringToTokenError> {
858 let add_bos = match add_bos {
859 AddBos::Always => true,
860 AddBos::Never => false,
861 };
862
863 let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
864 let mut buffer = Vec::with_capacity(tokens_estimation);
865
866 let c_string = CString::new(str)?;
867 let buffer_capacity =
868 c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
869
870 let size = unsafe {
871 llama_tokenize(
872 self.get_vocab().vocab.as_ref(),
873 c_string.as_ptr(),
874 c_int::try_from(c_string.as_bytes().len())?,
875 buffer.as_mut_ptr(),
876 buffer_capacity,
877 add_bos,
878 true,
879 )
880 };
881
882 // if we fail the first time we can resize the vector to the correct size and try again. This should never fail.
883 // as a result - size is guaranteed to be positive here.
884 let size = if size.is_negative() {
885 buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
886 unsafe {
887 llama_tokenize(
888 self.get_vocab().vocab.as_ref(),
889 c_string.as_ptr(),
890 c_int::try_from(c_string.as_bytes().len())?,
891 buffer.as_mut_ptr(),
892 -size,
893 add_bos,
894 true,
895 )
896 }
897 } else {
898 size
899 };
900
901 let size = usize::try_from(size).expect("size is positive and usize ");
902
903 // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
904 unsafe { buffer.set_len(size) }
905 Ok(buffer.into_iter().map(LlamaToken).collect())
906 }
907
908 /// Get the type of a token.
909 ///
910 /// This function retrieves the attributes associated with a given token. The attributes are typically used to
911 /// understand whether the token represents a special type of token (e.g., beginning-of-sequence (BOS), end-of-sequence (EOS),
912 /// control tokens, etc.).
913 ///
914 /// # Panics
915 ///
916 /// - This function will panic if the token type is unknown or cannot be converted to a valid `LlamaTokenAttrs`.
917 ///
918 /// # Example
919 ///
920 /// ```no_run
921 /// use llama_cpp_4::model::LlamaModel;
922 /// use llama_cpp_4::model::params::LlamaModelParams;
923 /// use llama_cpp_4::llama_backend::LlamaBackend;
924 /// use llama_cpp_4::token::LlamaToken;
925 ///
926 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
927 /// let backend = LlamaBackend::init()?;
928 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
929 /// let token = LlamaToken::new(42);
930 /// let token_attrs = model.token_attr(token);
931 /// # Ok(())
932 /// # }
933 /// ```
934 #[must_use]
935 pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
936 let token_type =
937 unsafe { llama_cpp_sys_4::llama_vocab_get_attr(self.get_vocab().vocab.as_ref(), id) };
938 LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
939 }
940
941 /// Detokenize a slice of tokens into a string.
942 ///
943 /// This is the inverse of [`str_to_token`](Self::str_to_token).
944 ///
945 /// # Parameters
946 ///
947 /// - `tokens`: The tokens to detokenize.
948 /// - `remove_special`: If `true`, special tokens are removed from the output.
949 /// - `unparse_special`: If `true`, special tokens are rendered as their text representation.
950 ///
951 /// # Errors
952 ///
953 /// Returns an error if the detokenized text is not valid UTF-8.
954 #[allow(
955 clippy::cast_possible_truncation,
956 clippy::cast_possible_wrap,
957 clippy::cast_sign_loss
958 )]
959 pub fn detokenize(
960 &self,
961 tokens: &[LlamaToken],
962 remove_special: bool,
963 unparse_special: bool,
964 ) -> Result<String, StringFromModelError> {
965 // First call with empty buffer to get required size
966 let n_tokens = tokens.len() as i32;
967 let token_ptr = tokens.as_ptr().cast::<llama_cpp_sys_4::llama_token>();
968 let needed = unsafe {
969 llama_detokenize(
970 self.get_vocab().vocab.as_ref(),
971 token_ptr,
972 n_tokens,
973 std::ptr::null_mut(),
974 0,
975 remove_special,
976 unparse_special,
977 )
978 };
979 // llama_detokenize returns negative required size when buffer is too small
980 let buf_size = if needed < 0 {
981 (-needed) as usize
982 } else {
983 needed as usize
984 };
985 let mut buf = vec![0u8; buf_size];
986 let ret = unsafe {
987 llama_detokenize(
988 self.get_vocab().vocab.as_ref(),
989 token_ptr,
990 n_tokens,
991 buf.as_mut_ptr().cast::<c_char>(),
992 buf_size as i32,
993 remove_special,
994 unparse_special,
995 )
996 };
997 if ret < 0 {
998 return Err(StringFromModelError::ReturnedError(ret));
999 }
1000 let len = ret as usize;
1001 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1002 Ok(s.to_owned())
1003 }
1004
1005 /// Convert a token to a string with a specified buffer size.
1006 ///
1007 /// This function allows you to convert a token into a string, with the ability to specify a buffer size for the operation.
1008 /// It is generally recommended to use `LlamaModel::token_to_str` instead, as 8 bytes is typically sufficient for most tokens,
1009 /// and the extra buffer size doesn't usually matter.
1010 ///
1011 /// # Errors
1012 ///
1013 /// - If the token type is unknown, an error will be returned.
1014 /// - If the resultant token exceeds the provided `buffer_size`, an error will occur.
1015 /// - If the token string returned by `llama-cpp` is not valid UTF-8, it will return an error.
1016 ///
1017 /// # Panics
1018 ///
1019 /// - This function will panic if the `buffer_size` does not fit into a `c_int`.
1020 /// - It will also panic if the size returned from `llama-cpp` does not fit into a `usize`, which should typically never happen.
1021 ///
1022 /// # Example
1023 ///
1024 /// ```no_run
1025 /// use llama_cpp_4::model::{LlamaModel, Special};
1026 /// use llama_cpp_4::model::params::LlamaModelParams;
1027 /// use llama_cpp_4::llama_backend::LlamaBackend;
1028 /// use llama_cpp_4::token::LlamaToken;
1029 ///
1030 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1031 /// let backend = LlamaBackend::init()?;
1032 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1033 /// let token = LlamaToken::new(42);
1034 /// let token_string = model.token_to_str_with_size(token, 32, Special::Plaintext)?;
1035 /// # Ok(())
1036 /// # }
1037 /// ```
1038 pub fn token_to_str_with_size(
1039 &self,
1040 token: LlamaToken,
1041 buffer_size: usize,
1042 special: Special,
1043 ) -> Result<String, TokenToStringError> {
1044 let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
1045 Ok(String::from_utf8(bytes)?)
1046 }
1047
1048 /// Convert a token to bytes with a specified buffer size.
1049 ///
1050 /// Generally you should use [`LlamaModel::token_to_bytes`] instead as 8 bytes is enough for most words and
1051 /// the extra bytes do not really matter.
1052 ///
1053 /// # Errors
1054 ///
1055 /// - if the token type is unknown
1056 /// - the resultant token is larger than `buffer_size`.
1057 ///
1058 /// # Panics
1059 ///
1060 /// - This function will panic if `buffer_size` cannot fit into a `c_int`.
1061 /// - It will also panic if the size returned from `llama-cpp` cannot be converted to `usize` (which should not happen).
1062 ///
1063 /// # Example
1064 ///
1065 /// ```no_run
1066 /// use llama_cpp_4::model::{LlamaModel, Special};
1067 /// use llama_cpp_4::model::params::LlamaModelParams;
1068 /// use llama_cpp_4::llama_backend::LlamaBackend;
1069 /// use llama_cpp_4::token::LlamaToken;
1070 ///
1071 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1072 /// let backend = LlamaBackend::init()?;
1073 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1074 /// let token = LlamaToken::new(42);
1075 /// let token_bytes = model.token_to_bytes_with_size(token, 32, Special::Plaintext, None)?;
1076 /// # Ok(())
1077 /// # }
1078 /// ```
1079 pub fn token_to_bytes_with_size(
1080 &self,
1081 token: LlamaToken,
1082 buffer_size: usize,
1083 special: Special,
1084 lstrip: Option<NonZeroU16>,
1085 ) -> Result<Vec<u8>, TokenToStringError> {
1086 if token == self.token_nl() {
1087 return Ok(String::from("\n").into_bytes());
1088 }
1089
1090 // unsure what to do with this in the face of the 'special' arg + attr changes
1091 let attrs = self.token_attr(token);
1092 if (attrs.contains(LlamaTokenAttr::Control)
1093 && (token == self.token_bos() || token == self.token_eos()))
1094 || attrs.is_empty()
1095 || attrs
1096 .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
1097 {
1098 return Ok(Vec::new());
1099 }
1100
1101 let special = match special {
1102 Special::Tokenize => true,
1103 Special::Plaintext => false,
1104 };
1105
1106 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
1107 let len = string.as_bytes().len();
1108 let len = c_int::try_from(len).expect("length fits into c_int");
1109 let buf = string.into_raw();
1110 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
1111 let size = unsafe {
1112 llama_token_to_piece(
1113 self.get_vocab().vocab.as_ref(),
1114 token.0,
1115 buf,
1116 len,
1117 lstrip,
1118 special,
1119 )
1120 };
1121
1122 match size {
1123 0 => Err(TokenToStringError::UnknownTokenType),
1124 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
1125 size => {
1126 let string = unsafe { CString::from_raw(buf) };
1127 let mut bytes = string.into_bytes();
1128 let len = usize::try_from(size).expect("size is positive and fits into usize");
1129 bytes.truncate(len);
1130 Ok(bytes)
1131 }
1132 }
1133 }
1134 /// The number of tokens the model was trained on.
1135 ///
1136 /// This function returns the number of tokens the model was trained on. It is returned as a `c_int` for maximum
1137 /// compatibility with the underlying llama-cpp library, though it can typically be cast to an `i32` without issue.
1138 ///
1139 /// # Example
1140 ///
1141 /// ```no_run
1142 /// use llama_cpp_4::model::LlamaModel;
1143 /// use llama_cpp_4::model::params::LlamaModelParams;
1144 /// use llama_cpp_4::llama_backend::LlamaBackend;
1145 ///
1146 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1147 /// let backend = LlamaBackend::init()?;
1148 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1149 /// let n_vocab = model.n_vocab();
1150 /// # Ok(())
1151 /// # }
1152 /// ```
1153 #[must_use]
1154 pub fn n_vocab(&self) -> i32 {
1155 self.get_vocab().n_tokens()
1156 }
1157
1158 /// The type of vocab the model was trained on.
1159 ///
1160 /// This function returns the type of vocabulary used by the model, such as whether it is based on byte-pair encoding (BPE),
1161 /// word-level tokens, or another tokenization scheme.
1162 ///
1163 /// # Panics
1164 ///
1165 /// - This function will panic if `llama-cpp` emits a vocab type that is not recognized or is invalid for this library.
1166 ///
1167 /// # Example
1168 ///
1169 /// ```no_run
1170 /// use llama_cpp_4::model::LlamaModel;
1171 /// use llama_cpp_4::model::params::LlamaModelParams;
1172 /// use llama_cpp_4::llama_backend::LlamaBackend;
1173 ///
1174 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1175 /// let backend = LlamaBackend::init()?;
1176 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1177 /// let vocab_type = model.vocab_type();
1178 /// # Ok(())
1179 /// # }
1180 /// ```
1181 #[must_use]
1182 pub fn vocab_type(&self) -> VocabType {
1183 let vocab_type = unsafe { llama_vocab_type(self.get_vocab().vocab.as_ref()) };
1184 VocabType::try_from(vocab_type).expect("invalid vocab type")
1185 }
1186
1187 /// Returns the number of embedding dimensions for the model.
1188 ///
1189 /// This function retrieves the number of embeddings (or embedding dimensions) used by the model. It is typically
1190 /// used for analyzing model architecture and setting up context parameters or other model configuration aspects.
1191 ///
1192 /// # Panics
1193 ///
1194 /// - This function may panic if the underlying `llama-cpp` library returns an invalid embedding dimension value.
1195 ///
1196 /// # Example
1197 ///
1198 /// ```no_run
1199 /// use llama_cpp_4::model::LlamaModel;
1200 /// use llama_cpp_4::model::params::LlamaModelParams;
1201 /// use llama_cpp_4::llama_backend::LlamaBackend;
1202 ///
1203 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1204 /// let backend = LlamaBackend::init()?;
1205 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1206 /// let n_embd = model.n_embd();
1207 /// # Ok(())
1208 /// # }
1209 /// ```
1210 #[must_use]
1211 pub fn n_embd(&self) -> c_int {
1212 unsafe { llama_model_n_embd(self.model.as_ptr()) }
1213 }
1214
1215 /// Get the number of transformer layers in the model.
1216 #[must_use]
1217 pub fn n_layer(&self) -> c_int {
1218 unsafe { llama_model_n_layer(self.model.as_ptr()) }
1219 }
1220
1221 /// Get the number of `NextN` / MTP prediction heads bundled with the model.
1222 ///
1223 /// Returns `0` when the checkpoint has no `NextN` blocks. Multi-head models
1224 /// (e.g. Step3.5) return values greater than `1`; pair with
1225 /// [`crate::context::LlamaContext::set_nextn_layer_offset`] on the draft
1226 /// context. See [`crate::mtp`] for the speculative-decoding workflow.
1227 ///
1228 /// # Examples
1229 ///
1230 /// ```no_run
1231 /// # use llama_cpp_4::llama_backend::LlamaBackend;
1232 /// # use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
1233 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1234 /// # let backend = LlamaBackend::init()?;
1235 /// # let model = LlamaModel::load_from_file(&backend, "model.gguf", &LlamaModelParams::default())?;
1236 /// if model.n_layer_nextn() > 0 {
1237 /// println!("MTP model with {} NextN heads", model.n_layer_nextn());
1238 /// }
1239 /// # Ok(())
1240 /// # }
1241 /// ```
1242 #[must_use]
1243 pub fn n_layer_nextn(&self) -> c_int {
1244 unsafe { llama_model_n_layer_nextn(self.model.as_ptr()) }
1245 }
1246
1247 /// Get the number of mixture-of-experts (`MoE`) layers in the model.
1248 ///
1249 /// Returns `0` for dense (non-MoE) checkpoints.
1250 #[must_use]
1251 pub fn n_expert(&self) -> c_int {
1252 unsafe { llama_model_n_expert(self.model.as_ptr()) }
1253 }
1254
1255 /// Number of backend devices the model tensors are spread across.
1256 ///
1257 /// Use with [`Self::get_device`] to inspect each device. Returns `0` when
1258 /// the model is not yet loaded onto any device.
1259 #[must_use]
1260 pub fn n_devices(&self) -> c_int {
1261 unsafe { llama_model_n_devices(self.model.as_ptr()) }
1262 }
1263
1264 /// Get the backend device at `index`.
1265 ///
1266 /// Valid indices satisfy `0 <= index < n_devices()`. Returns `None` for
1267 /// out-of-range indices or when the device pointer is null.
1268 #[must_use]
1269 pub fn get_device(&self, index: i32) -> Option<LlamaBackendDevice> {
1270 if index < 0 || index >= self.n_devices() {
1271 return None;
1272 }
1273 let dev = unsafe { llama_model_get_device(self.model.as_ptr(), index) };
1274 if dev.is_null() {
1275 None
1276 } else {
1277 Some(LlamaBackendDevice { dev })
1278 }
1279 }
1280
1281 /// Iterate backend devices the model tensors are spread across.
1282 ///
1283 /// Equivalent to calling [`Self::get_device`] for `0..self.n_devices()`.
1284 /// Use [`LlamaBackendDevice::memory`] to inspect free/total bytes per device.
1285 #[must_use]
1286 pub fn devices(&self) -> LlamaBackendDevices<'_> {
1287 LlamaBackendDevices {
1288 model: self,
1289 next: 0,
1290 }
1291 }
1292
1293 /// Target-model layer indices stored in this checkpoint.
1294 ///
1295 /// Populated for EAGLE / distillation draft models that record which target
1296 /// layers they were trained against. Returns an empty slice when the
1297 /// metadata is absent.
1298 #[must_use]
1299 pub fn target_layer_ids(&self) -> &[i32] {
1300 let n = unsafe { llama_model_target_layer_ids_n(self.model.as_ptr()) };
1301 if n == 0 {
1302 return &[];
1303 }
1304 let ptr = unsafe { llama_model_target_layer_ids(self.model.as_ptr()) };
1305 if ptr.is_null() {
1306 &[]
1307 } else {
1308 unsafe { slice::from_raw_parts(ptr, n as usize) }
1309 }
1310 }
1311
1312 /// Get the number of attention heads in the model.
1313 #[must_use]
1314 pub fn n_head(&self) -> c_int {
1315 unsafe { llama_model_n_head(self.model.as_ptr()) }
1316 }
1317
1318 /// Get the number of key-value attention heads in the model.
1319 #[must_use]
1320 pub fn n_head_kv(&self) -> c_int {
1321 unsafe { llama_model_n_head_kv(self.model.as_ptr()) }
1322 }
1323
1324 /// Get the input embedding size of the model.
1325 #[must_use]
1326 pub fn n_embd_inp(&self) -> c_int {
1327 unsafe { llama_model_n_embd_inp(self.model.as_ptr()) }
1328 }
1329
1330 /// Get the output embedding size of the model.
1331 #[must_use]
1332 pub fn n_embd_out(&self) -> c_int {
1333 unsafe { llama_model_n_embd_out(self.model.as_ptr()) }
1334 }
1335
1336 /// Get the sliding window attention size of the model.
1337 /// Returns 0 if the model does not use sliding window attention.
1338 #[must_use]
1339 pub fn n_swa(&self) -> c_int {
1340 unsafe { llama_model_n_swa(self.model.as_ptr()) }
1341 }
1342
1343 /// Get the `RoPE` type used by the model.
1344 #[must_use]
1345 pub fn rope_type(&self) -> i32 {
1346 unsafe { llama_model_rope_type(self.model.as_ptr()) }
1347 }
1348
1349 /// Get the `RoPE` frequency scale used during training.
1350 #[must_use]
1351 pub fn rope_freq_scale_train(&self) -> f32 {
1352 unsafe { llama_model_rope_freq_scale_train(self.model.as_ptr()) }
1353 }
1354
1355 /// Get the model size in bytes.
1356 #[must_use]
1357 pub fn model_size(&self) -> u64 {
1358 unsafe { llama_model_size(self.model.as_ptr()) }
1359 }
1360
1361 /// Get the number of parameters in the model.
1362 #[must_use]
1363 pub fn n_params(&self) -> u64 {
1364 unsafe { llama_model_n_params(self.model.as_ptr()) }
1365 }
1366
1367 /// Get the number of classification outputs.
1368 #[must_use]
1369 pub fn n_cls_out(&self) -> u32 {
1370 unsafe { llama_model_n_cls_out(self.model.as_ptr()) }
1371 }
1372
1373 /// Get the classification label for the given index.
1374 ///
1375 /// # Errors
1376 ///
1377 /// Returns an error if the label is null or not valid UTF-8.
1378 pub fn cls_label(&self, index: u32) -> Result<&str, StringFromModelError> {
1379 let ptr = unsafe { llama_model_cls_label(self.model.as_ptr(), index) };
1380 if ptr.is_null() {
1381 return Err(StringFromModelError::ReturnedError(-1));
1382 }
1383 let cstr = unsafe { CStr::from_ptr(ptr) };
1384 cstr.to_str().map_err(StringFromModelError::Utf8Error)
1385 }
1386
1387 /// Get the number of metadata key-value pairs.
1388 #[must_use]
1389 pub fn meta_count(&self) -> c_int {
1390 unsafe { llama_model_meta_count(self.model.as_ptr()) }
1391 }
1392
1393 /// Get a model description string.
1394 ///
1395 /// The `buf_size` parameter specifies the maximum buffer size for the description.
1396 /// A default of 256 bytes is usually sufficient.
1397 ///
1398 /// # Errors
1399 ///
1400 /// Returns an error if the description could not be retrieved or is not valid UTF-8.
1401 #[allow(clippy::cast_sign_loss)]
1402 pub fn desc(&self, buf_size: usize) -> Result<String, StringFromModelError> {
1403 let mut buf = vec![0u8; buf_size];
1404 let ret = unsafe {
1405 llama_model_desc(
1406 self.model.as_ptr(),
1407 buf.as_mut_ptr().cast::<c_char>(),
1408 buf_size,
1409 )
1410 };
1411 if ret < 0 {
1412 return Err(StringFromModelError::ReturnedError(ret));
1413 }
1414 let len = ret as usize;
1415 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1416 Ok(s.to_owned())
1417 }
1418
1419 /// Get a metadata key by index.
1420 ///
1421 /// The `buf_size` parameter specifies the maximum buffer size for the key.
1422 /// A default of 256 bytes is usually sufficient.
1423 ///
1424 /// # Errors
1425 ///
1426 /// Returns an error if the index is out of range or the key is not valid UTF-8.
1427 #[allow(clippy::cast_sign_loss)]
1428 pub fn meta_key_by_index(
1429 &self,
1430 index: i32,
1431 buf_size: usize,
1432 ) -> Result<String, StringFromModelError> {
1433 let mut buf = vec![0u8; buf_size];
1434 let ret = unsafe {
1435 llama_model_meta_key_by_index(
1436 self.model.as_ptr(),
1437 index,
1438 buf.as_mut_ptr().cast::<c_char>(),
1439 buf_size,
1440 )
1441 };
1442 if ret < 0 {
1443 return Err(StringFromModelError::ReturnedError(ret));
1444 }
1445 let len = ret as usize;
1446 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1447 Ok(s.to_owned())
1448 }
1449
1450 /// Get a metadata value string by index.
1451 ///
1452 /// The `buf_size` parameter specifies the maximum buffer size for the value.
1453 /// Values can be large (e.g. chat templates, token lists), so 4096+ may be needed.
1454 ///
1455 /// # Errors
1456 ///
1457 /// Returns an error if the index is out of range or the value is not valid UTF-8.
1458 #[allow(clippy::cast_sign_loss)]
1459 pub fn meta_val_str_by_index(
1460 &self,
1461 index: i32,
1462 buf_size: usize,
1463 ) -> Result<String, StringFromModelError> {
1464 let mut buf = vec![0u8; buf_size];
1465 let ret = unsafe {
1466 llama_model_meta_val_str_by_index(
1467 self.model.as_ptr(),
1468 index,
1469 buf.as_mut_ptr().cast::<c_char>(),
1470 buf_size,
1471 )
1472 };
1473 if ret < 0 {
1474 return Err(StringFromModelError::ReturnedError(ret));
1475 }
1476 let len = ret as usize;
1477 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1478 Ok(s.to_owned())
1479 }
1480
1481 /// Get a metadata value by key name.
1482 ///
1483 /// This is more convenient than iterating metadata by index when you know the key.
1484 /// The `buf_size` parameter specifies the maximum buffer size for the value.
1485 ///
1486 /// # Errors
1487 ///
1488 /// Returns an error if the key is not found, contains a null byte, or the value is not valid UTF-8.
1489 #[allow(clippy::cast_sign_loss)]
1490 pub fn meta_val_str(&self, key: &str, buf_size: usize) -> Result<String, StringFromModelError> {
1491 let c_key = CString::new(key).map_err(|_| StringFromModelError::ReturnedError(-1))?;
1492 let mut buf = vec![0u8; buf_size];
1493 let ret = unsafe {
1494 llama_model_meta_val_str(
1495 self.model.as_ptr(),
1496 c_key.as_ptr(),
1497 buf.as_mut_ptr().cast::<c_char>(),
1498 buf_size,
1499 )
1500 };
1501 if ret < 0 {
1502 return Err(StringFromModelError::ReturnedError(ret));
1503 }
1504 let len = ret as usize;
1505 let s = std::str::from_utf8(&buf[..len]).map_err(StringFromModelError::Utf8Error)?;
1506 Ok(s.to_owned())
1507 }
1508
1509 /// Get all metadata as a list of `(key, value)` pairs.
1510 ///
1511 /// This is a convenience method that iterates over all metadata entries.
1512 /// Keys use a buffer of 256 bytes and values use 4096 bytes.
1513 /// For values that may be larger (e.g. token lists), use
1514 /// [`meta_val_str_by_index`](Self::meta_val_str_by_index) directly with a larger buffer.
1515 ///
1516 /// # Errors
1517 ///
1518 /// Returns an error if any key or value cannot be read or is not valid UTF-8.
1519 #[allow(clippy::cast_sign_loss)]
1520 pub fn metadata(&self) -> Result<Vec<(String, String)>, StringFromModelError> {
1521 let count = self.meta_count();
1522 let mut result = Vec::with_capacity(count as usize);
1523 for i in 0..count {
1524 let key = self.meta_key_by_index(i, 256)?;
1525 let val = self.meta_val_str_by_index(i, 4096)?;
1526 result.push((key, val));
1527 }
1528 Ok(result)
1529 }
1530
1531 /// Check if the model has an encoder.
1532 #[must_use]
1533 pub fn has_encoder(&self) -> bool {
1534 unsafe { llama_model_has_encoder(self.model.as_ptr()) }
1535 }
1536
1537 /// Check if the model has a decoder.
1538 #[must_use]
1539 pub fn has_decoder(&self) -> bool {
1540 unsafe { llama_model_has_decoder(self.model.as_ptr()) }
1541 }
1542
1543 /// Check if the model is recurrent (e.g. Mamba, RWKV).
1544 #[must_use]
1545 pub fn is_recurrent(&self) -> bool {
1546 unsafe { llama_model_is_recurrent(self.model.as_ptr()) }
1547 }
1548
1549 /// Check if the model is a hybrid model.
1550 #[must_use]
1551 pub fn is_hybrid(&self) -> bool {
1552 unsafe { llama_model_is_hybrid(self.model.as_ptr()) }
1553 }
1554
1555 /// Check if the model is a diffusion model.
1556 #[must_use]
1557 pub fn is_diffusion(&self) -> bool {
1558 unsafe { llama_model_is_diffusion(self.model.as_ptr()) }
1559 }
1560
1561 /// Get chat template from model.
1562 ///
1563 /// # Errors
1564 ///
1565 /// - If the model does not have a chat template, it will return an error.
1566 /// - If the chat template is not a valid `CString`, it will return an error.
1567 ///
1568 /// # Example
1569 ///
1570 /// ```no_run
1571 /// use llama_cpp_4::model::LlamaModel;
1572 /// use llama_cpp_4::model::params::LlamaModelParams;
1573 /// use llama_cpp_4::llama_backend::LlamaBackend;
1574 ///
1575 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1576 /// let backend = LlamaBackend::init()?;
1577 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1578 /// let chat_template = model.get_chat_template(1024)?;
1579 /// # Ok(())
1580 /// # }
1581 /// ```
1582 #[allow(clippy::missing_panics_doc)] // We statically know this will not panic as long as the buffer size is sufficient
1583 pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
1584 // longest known template is about 1200 bytes from llama.cpp
1585 let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
1586 let chat_ptr = chat_temp.into_raw();
1587 let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");
1588
1589 let ret = unsafe {
1590 llama_model_meta_val_str(self.model.as_ptr(), chat_name.as_ptr(), chat_ptr, buf_size)
1591 };
1592
1593 if ret < 0 {
1594 return Err(ChatTemplateError::MissingTemplate(ret));
1595 }
1596
1597 let template_c = unsafe { CString::from_raw(chat_ptr) };
1598 let template = template_c.to_str()?;
1599
1600 let ret: usize = ret.try_into().unwrap();
1601 if template.len() < ret {
1602 return Err(ChatTemplateError::BuffSizeError(ret + 1));
1603 }
1604
1605 Ok(template.to_owned())
1606 }
1607
1608 /// Loads a model from a file.
1609 ///
1610 /// This function loads a model from a specified file path and returns the corresponding `LlamaModel` instance.
1611 ///
1612 /// # Errors
1613 ///
1614 /// - If the path cannot be converted to a string or if the model file does not exist, it will return an error.
1615 /// - If the model cannot be loaded (e.g., due to an invalid or corrupted model file), it will return a `LlamaModelLoadError`.
1616 ///
1617 /// # Example
1618 ///
1619 /// ```no_run
1620 /// use llama_cpp_4::model::LlamaModel;
1621 /// use llama_cpp_4::model::params::LlamaModelParams;
1622 /// use llama_cpp_4::llama_backend::LlamaBackend;
1623 ///
1624 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1625 /// let backend = LlamaBackend::init()?;
1626 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1627 /// # Ok(())
1628 /// # }
1629 /// ```
1630 #[tracing::instrument(skip_all, fields(params))]
1631 pub fn load_from_file(
1632 _: &LlamaBackend,
1633 path: impl AsRef<Path>,
1634 params: &LlamaModelParams,
1635 ) -> Result<Self, LlamaModelLoadError> {
1636 let path = path.as_ref();
1637 debug_assert!(
1638 Path::new(path).exists(),
1639 "{} does not exist",
1640 path.display()
1641 );
1642 let path = path
1643 .to_str()
1644 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1645
1646 let cstr = CString::new(path)?;
1647 let llama_model = unsafe { llama_model_load_from_file(cstr.as_ptr(), params.params) };
1648
1649 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1650
1651 tracing::debug!(?path, "Loaded model");
1652 Ok(LlamaModel { model })
1653 }
1654
1655 /// Load a model from multiple split files.
1656 ///
1657 /// This function loads a model that has been split across multiple files. This is useful for
1658 /// very large models that exceed filesystem limitations or need to be distributed across
1659 /// multiple storage devices.
1660 ///
1661 /// # Arguments
1662 ///
1663 /// * `paths` - A slice of paths to the split model files
1664 /// * `params` - The model parameters
1665 ///
1666 /// # Errors
1667 ///
1668 /// Returns an error if:
1669 /// - Any of the paths cannot be converted to a C string
1670 /// - The model fails to load from the splits
1671 /// - Any path doesn't exist or isn't accessible
1672 ///
1673 /// # Example
1674 ///
1675 /// ```no_run
1676 /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
1677 /// use llama_cpp_4::llama_backend::LlamaBackend;
1678 ///
1679 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1680 /// let backend = LlamaBackend::init()?;
1681 /// let params = LlamaModelParams::default();
1682 ///
1683 /// let paths = vec![
1684 /// "model-00001-of-00003.gguf",
1685 /// "model-00002-of-00003.gguf",
1686 /// "model-00003-of-00003.gguf",
1687 /// ];
1688 ///
1689 /// let model = LlamaModel::load_from_splits(&backend, &paths, ¶ms)?;
1690 /// # Ok(())
1691 /// # }
1692 /// ```
1693 #[tracing::instrument(skip_all)]
1694 pub fn load_from_splits(
1695 _: &LlamaBackend,
1696 paths: &[impl AsRef<Path>],
1697 params: &LlamaModelParams,
1698 ) -> Result<Self, LlamaModelLoadError> {
1699 // Convert paths to C strings
1700 let c_strings: Vec<CString> = paths
1701 .iter()
1702 .map(|p| {
1703 let path = p.as_ref();
1704 debug_assert!(path.exists(), "{} does not exist", path.display());
1705 let path_str = path
1706 .to_str()
1707 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
1708 CString::new(path_str).map_err(LlamaModelLoadError::from)
1709 })
1710 .collect::<Result<Vec<_>, _>>()?;
1711
1712 // Create array of pointers to C strings
1713 let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
1714
1715 // Load the model from splits
1716 let llama_model = unsafe {
1717 llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
1718 };
1719
1720 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
1721
1722 tracing::debug!("Loaded model from {} splits", paths.len());
1723 Ok(LlamaModel { model })
1724 }
1725
1726 /// Load a model from a `FILE` pointer.
1727 ///
1728 /// # Safety
1729 ///
1730 /// The `file` pointer must be a valid, open `FILE*`.
1731 ///
1732 /// # Errors
1733 ///
1734 /// Returns an error if the model cannot be loaded.
1735 pub unsafe fn load_from_file_ptr(
1736 file: *mut llama_cpp_sys_4::FILE,
1737 params: &LlamaModelParams,
1738 ) -> Result<Self, LlamaModelLoadError> {
1739 let model = llama_cpp_sys_4::llama_model_load_from_file_ptr(file, params.params);
1740 let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1741 Ok(LlamaModel { model })
1742 }
1743
1744 /// Initialize a model from user-provided data.
1745 ///
1746 /// # Safety
1747 ///
1748 /// The metadata, callback, and user data must be valid.
1749 ///
1750 /// # Errors
1751 ///
1752 /// Returns an error if the model cannot be initialized.
1753 pub unsafe fn init_from_user(
1754 metadata: *mut llama_cpp_sys_4::gguf_context,
1755 set_tensor_data: llama_cpp_sys_4::llama_model_set_tensor_data_t,
1756 set_tensor_data_ud: *mut std::ffi::c_void,
1757 params: &LlamaModelParams,
1758 ) -> Result<Self, LlamaModelLoadError> {
1759 let model = llama_cpp_sys_4::llama_model_init_from_user(
1760 metadata,
1761 set_tensor_data,
1762 set_tensor_data_ud,
1763 params.params,
1764 );
1765 let model = NonNull::new(model).ok_or(LlamaModelLoadError::NullResult)?;
1766 Ok(LlamaModel { model })
1767 }
1768
1769 /// Save the model to a file.
1770 ///
1771 /// # Panics
1772 ///
1773 /// Panics if the path contains null bytes.
1774 pub fn save_to_file(&self, path: impl AsRef<Path>) {
1775 let path = path.as_ref();
1776 let path_str = path.to_str().expect("path is not valid UTF-8");
1777 let c_path = CString::new(path_str).expect("path contains null bytes");
1778 unsafe {
1779 llama_model_save_to_file(self.model.as_ptr(), c_path.as_ptr());
1780 }
1781 }
1782
1783 /// Get the list of built-in chat templates.
1784 ///
1785 /// Returns the names of all chat templates that are built into llama.cpp.
1786 ///
1787 /// # Panics
1788 ///
1789 /// Panics if any template name is not valid UTF-8.
1790 #[allow(clippy::cast_sign_loss)]
1791 #[must_use]
1792 pub fn chat_builtin_templates() -> Vec<String> {
1793 // First call to get count
1794 let count = unsafe { llama_chat_builtin_templates(std::ptr::null_mut(), 0) };
1795 if count <= 0 {
1796 return Vec::new();
1797 }
1798 let count = count as usize;
1799 let mut ptrs: Vec<*const c_char> = vec![std::ptr::null(); count];
1800 unsafe {
1801 llama_chat_builtin_templates(ptrs.as_mut_ptr(), count);
1802 }
1803 ptrs.iter()
1804 .map(|&p| {
1805 let cstr = unsafe { CStr::from_ptr(p) };
1806 cstr.to_str()
1807 .expect("template name is not valid UTF-8")
1808 .to_owned()
1809 })
1810 .collect()
1811 }
1812
1813 /// Initializes a lora adapter from a file.
1814 ///
1815 /// This function initializes a Lora adapter, which is a model extension used to adapt or fine-tune the existing model
1816 /// to a specific domain or task. The adapter file is typically in the form of a binary or serialized file that can be applied
1817 /// to the model for improved performance on specialized tasks.
1818 ///
1819 /// # Errors
1820 ///
1821 /// - If the adapter file path cannot be converted to a string or if the adapter cannot be initialized, it will return an error.
1822 ///
1823 /// # Example
1824 ///
1825 /// ```no_run
1826 /// use llama_cpp_4::model::LlamaModel;
1827 /// use llama_cpp_4::model::params::LlamaModelParams;
1828 /// use llama_cpp_4::llama_backend::LlamaBackend;
1829 ///
1830 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1831 /// let backend = LlamaBackend::init()?;
1832 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1833 /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
1834 /// # Ok(())
1835 /// # }
1836 /// ```
1837 pub fn lora_adapter_init(
1838 &self,
1839 path: impl AsRef<Path>,
1840 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
1841 let path = path.as_ref();
1842 debug_assert!(
1843 Path::new(path).exists(),
1844 "{} does not exist",
1845 path.display()
1846 );
1847
1848 let path = path
1849 .to_str()
1850 .ok_or(LlamaLoraAdapterInitError::PathToStrError(
1851 path.to_path_buf(),
1852 ))?;
1853
1854 let cstr = CString::new(path)?;
1855 let adapter = unsafe { llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
1856
1857 let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
1858
1859 tracing::debug!(?path, "Initialized lora adapter");
1860 Ok(LlamaLoraAdapter {
1861 lora_adapter: adapter,
1862 })
1863 }
1864
1865 /// Create a new context from this model.
1866 ///
1867 /// This function creates a new context for the model, which is used to manage and perform computations for inference,
1868 /// including token generation, embeddings, and other tasks that the model can perform. The context allows fine-grained
1869 /// control over model parameters for a specific task.
1870 ///
1871 /// # Errors
1872 ///
1873 /// - There are various potential failures such as invalid parameters or a failure to allocate the context. See [`LlamaContextLoadError`]
1874 /// for more detailed error descriptions.
1875 ///
1876 /// # Example
1877 ///
1878 /// ```no_run
1879 /// use llama_cpp_4::model::LlamaModel;
1880 /// use llama_cpp_4::model::params::LlamaModelParams;
1881 /// use llama_cpp_4::context::params::LlamaContextParams;
1882 /// use llama_cpp_4::llama_backend::LlamaBackend;
1883 ///
1884 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1885 /// let backend = LlamaBackend::init()?;
1886 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1887 /// let context = model.new_context(&backend, LlamaContextParams::default())?;
1888 /// # Ok(())
1889 /// # }
1890 /// ```
1891 #[allow(clippy::needless_pass_by_value)]
1892 pub fn new_context(
1893 &self,
1894 _: &LlamaBackend,
1895 params: LlamaContextParams,
1896 ) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
1897 // Apply TurboQuant attn-rotation preference before the KV cache is
1898 // initialised inside llama_init_from_model.
1899 let prev_rot_var = std::env::var("LLAMA_ATTN_ROT_DISABLE").ok();
1900 if params.attn_rot_disabled {
1901 // SAFETY: we restore the value right after the call.
1902 #[allow(unused_unsafe)]
1903 unsafe {
1904 std::env::set_var("LLAMA_ATTN_ROT_DISABLE", "1");
1905 }
1906 } else if std::env::var("LLAMA_ATTN_ROT_DISABLE").is_ok() {
1907 // params say "enabled" – only clear if it was previously unset
1908 // (respect explicit user env var).
1909 }
1910
1911 let context_params = params.context_params;
1912 let context = unsafe { llama_init_from_model(self.model.as_ptr(), context_params) };
1913
1914 // Restore the env-var to its previous state.
1915 #[allow(unused_unsafe)]
1916 match prev_rot_var {
1917 Some(v) => unsafe { std::env::set_var("LLAMA_ATTN_ROT_DISABLE", v) },
1918 None if params.attn_rot_disabled => unsafe {
1919 std::env::remove_var("LLAMA_ATTN_ROT_DISABLE");
1920 },
1921 None => {}
1922 }
1923
1924 let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
1925 Ok(LlamaContext::new(self, context, params.embeddings()))
1926 }
1927
1928 /// Apply the model's chat template to a sequence of messages.
1929 ///
1930 /// This function applies the model's chat template to the provided chat messages, formatting them accordingly. The chat
1931 /// template determines the structure or style of conversation between the system and user, such as token formatting,
1932 /// role separation, and more. The template can be customized by providing an optional template string, or if `None`
1933 /// is provided, the default template used by `llama.cpp` will be applied.
1934 ///
1935 /// For more information on supported templates, visit:
1936 /// <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
1937 ///
1938 /// # Arguments
1939 ///
1940 /// - `tmpl`: An optional custom template string. If `None`, the default template will be used.
1941 /// - `chat`: A vector of `LlamaChatMessage` instances, which represent the conversation between the system and user.
1942 /// - `add_ass`: A boolean flag indicating whether additional system-specific instructions (like "assistant") should be included.
1943 ///
1944 /// # Errors
1945 ///
1946 /// There are several possible points of failure when applying the chat template:
1947 /// - Insufficient buffer size to hold the formatted chat (this will return `ApplyChatTemplateError::BuffSizeError`).
1948 /// - If the template or messages cannot be processed properly, various errors from `ApplyChatTemplateError` may occur.
1949 ///
1950 /// # Example
1951 ///
1952 /// ```no_run
1953 /// use llama_cpp_4::model::{LlamaModel, LlamaChatMessage};
1954 /// use llama_cpp_4::model::params::LlamaModelParams;
1955 /// use llama_cpp_4::llama_backend::LlamaBackend;
1956 ///
1957 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1958 /// let backend = LlamaBackend::init()?;
1959 /// let model = LlamaModel::load_from_file(&backend, "path/to/model", &LlamaModelParams::default())?;
1960 /// let chat = vec![
1961 /// LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
1962 /// LlamaChatMessage::new("assistant".to_string(), "Hi! How can I assist you today?".to_string())?,
1963 /// ];
1964 /// let formatted_chat = model.apply_chat_template(None, &chat, true)?;
1965 /// # Ok(())
1966 /// # }
1967 /// ```
1968 ///
1969 /// # Notes
1970 ///
1971 /// The provided buffer is twice the length of the messages by default, which is recommended by the `llama.cpp` documentation.
1972 /// # Panics
1973 ///
1974 /// Panics if the buffer length exceeds `i32::MAX`.
1975 #[tracing::instrument(skip_all)]
1976 pub fn apply_chat_template(
1977 &self,
1978 tmpl: Option<&str>,
1979 chat: &[LlamaChatMessage],
1980 add_ass: bool,
1981 ) -> Result<String, ApplyChatTemplateError> {
1982 // Compute raw message byte total from the original LlamaChatMessage vec
1983 // *before* we shadow `chat` with the sys-type vec below.
1984 let message_length = chat.iter().fold(0usize, |acc, c| {
1985 acc + c.role.to_bytes().len() + c.content.to_bytes().len()
1986 });
1987
1988 // Build our llama_cpp_sys chat messages (raw pointers into CStrings).
1989 let chat_sys: Vec<llama_chat_message> = chat
1990 .iter()
1991 .map(|c| llama_chat_message {
1992 role: c.role.as_ptr(),
1993 content: c.content.as_ptr(),
1994 })
1995 .collect();
1996
1997 // Set the tmpl pointer.
1998 let tmpl_cstring = tmpl.map(CString::new).transpose()?;
1999 let tmpl_ptr = tmpl_cstring
2000 .as_ref()
2001 .map_or(std::ptr::null(), |s| s.as_ptr());
2002
2003 // `message_length * 4` is far too small for models whose built-in chat
2004 // template adds a long default system prompt (e.g. Qwen3.5 prepends
2005 // ~80+ chars of markup even for a one-word user message). Start with
2006 // at least 4 KiB so short inputs like "hi" always have room.
2007 //
2008 // `llama_chat_apply_template` returns the number of bytes it *actually*
2009 // needed when the buffer was too small, so we retry exactly once with
2010 // that precise size rather than giving up immediately.
2011 let mut buf_size = message_length.saturating_mul(4).max(4096);
2012
2013 for _ in 0..2 {
2014 // Use u8 so that as_mut_ptr()/as_ptr() match the binding (*mut u8 / *const u8).
2015 let mut buff = vec![0u8; buf_size];
2016 let res = unsafe {
2017 llama_chat_apply_template(
2018 tmpl_ptr,
2019 chat_sys.as_ptr(),
2020 chat_sys.len(),
2021 add_ass,
2022 buff.as_mut_ptr().cast(),
2023 i32::try_from(buff.len()).expect("buffer length fits in i32"),
2024 )
2025 };
2026
2027 if res < 0 {
2028 return Err(ApplyChatTemplateError::BuffSizeError);
2029 }
2030
2031 #[allow(clippy::cast_sign_loss)]
2032 let needed = res as usize;
2033 if needed > buf_size {
2034 // Buffer was too small — retry with the exact size llama.cpp reported.
2035 buf_size = needed + 1; // +1 for null terminator
2036 continue;
2037 }
2038
2039 // SAFETY: llama_chat_apply_template wrote a NUL-terminated string
2040 // into `buff`; `needed` bytes were used.
2041 let formatted = unsafe {
2042 CStr::from_ptr(buff.as_ptr().cast())
2043 .to_string_lossy()
2044 .into_owned()
2045 };
2046 return Ok(formatted);
2047 }
2048
2049 Err(ApplyChatTemplateError::BuffSizeError)
2050 }
2051
2052 /// Build a split GGUF file path for a specific chunk.
2053 ///
2054 /// This utility function creates the standardized filename for a split model chunk
2055 /// following the pattern: `{prefix}-{split_no:05d}-of-{split_count:05d}.gguf`
2056 ///
2057 /// # Arguments
2058 ///
2059 /// * `path_prefix` - The base path and filename prefix
2060 /// * `split_no` - The split number (1-indexed)
2061 /// * `split_count` - The total number of splits
2062 ///
2063 /// # Returns
2064 ///
2065 /// Returns the formatted split path as a String
2066 ///
2067 /// # Example
2068 ///
2069 /// ```
2070 /// use llama_cpp_4::model::LlamaModel;
2071 ///
2072 /// let path = LlamaModel::split_path("/models/llama", 1, 4);
2073 /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
2074 /// ```
2075 ///
2076 /// # Panics
2077 ///
2078 /// Panics if the path prefix contains a null byte.
2079 #[must_use]
2080 pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
2081 let mut buffer = vec![0u8; 1024];
2082 let len = unsafe {
2083 llama_split_path(
2084 buffer.as_mut_ptr().cast::<c_char>(),
2085 buffer.len(),
2086 CString::new(path_prefix).unwrap().as_ptr(),
2087 split_no,
2088 split_count,
2089 )
2090 };
2091
2092 let len = usize::try_from(len).expect("split_path length fits in usize");
2093 buffer.truncate(len);
2094 String::from_utf8(buffer).unwrap_or_default()
2095 }
2096
2097 /// Extract the path prefix from a split filename.
2098 ///
2099 /// This function extracts the base path prefix from a split model filename,
2100 /// but only if the `split_no` and `split_count` match the pattern in the filename.
2101 ///
2102 /// # Arguments
2103 ///
2104 /// * `split_path` - The full path to the split file
2105 /// * `split_no` - The expected split number
2106 /// * `split_count` - The expected total number of splits
2107 ///
2108 /// # Returns
2109 ///
2110 /// Returns the path prefix if the pattern matches, or None if it doesn't
2111 ///
2112 /// # Example
2113 ///
2114 /// ```
2115 /// use llama_cpp_4::model::LlamaModel;
2116 ///
2117 /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 1, 4);
2118 /// assert_eq!(prefix, Some("/models/llama".to_string()));
2119 /// ```
2120 ///
2121 /// # Panics
2122 ///
2123 /// Panics if the split path contains a null byte.
2124 #[must_use]
2125 pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
2126 let mut buffer = vec![0u8; 1024];
2127 let len = unsafe {
2128 llama_split_prefix(
2129 buffer.as_mut_ptr().cast::<c_char>(),
2130 buffer.len(),
2131 CString::new(split_path).unwrap().as_ptr(),
2132 split_no,
2133 split_count,
2134 )
2135 };
2136
2137 if len > 0 {
2138 let len = usize::try_from(len).expect("split_prefix length fits in usize");
2139 buffer.truncate(len);
2140 String::from_utf8(buffer).ok()
2141 } else {
2142 None
2143 }
2144 }
2145}
2146
2147#[allow(clippy::cast_precision_loss)]
2148impl fmt::Display for LlamaModel {
2149 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2150 let desc = self.desc(256).unwrap_or_else(|_| "unknown".to_string());
2151 write!(
2152 f,
2153 "{desc} | {layers}L {heads}H {embd}E | {params} params | {size:.1} MiB",
2154 layers = self.n_layer(),
2155 heads = self.n_head(),
2156 embd = self.n_embd(),
2157 params = self.n_params(),
2158 size = self.model_size() as f64 / (1024.0 * 1024.0),
2159 )
2160 }
2161}
2162
2163impl Drop for LlamaModel {
2164 fn drop(&mut self) {
2165 unsafe { llama_model_free(self.model.as_ptr()) }
2166 }
2167}
2168
2169/// Defines the possible types of vocabulary used by the model.
2170///
2171/// The model may use different types of vocabulary depending on the tokenization method chosen during training.
2172/// This enum represents these types, specifically `BPE` (Byte Pair Encoding) and `SPM` (`SentencePiece`).
2173///
2174/// # Variants
2175///
2176/// - `BPE`: Byte Pair Encoding, a common tokenization method used in NLP tasks.
2177/// - `SPM`: `SentencePiece`, another popular tokenization method for NLP models.
2178///
2179/// # Example
2180///
2181/// ```no_run
2182/// use llama_cpp_4::model::VocabType;
2183///
2184/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2185/// let vocab_type = VocabType::BPE;
2186/// match vocab_type {
2187/// VocabType::BPE => println!("The model uses Byte Pair Encoding (BPE)"),
2188/// VocabType::SPM => println!("The model uses SentencePiece (SPM)"),
2189/// }
2190/// # Ok(())
2191/// # }
2192/// ```
2193#[repr(u32)]
2194#[derive(Debug, Eq, Copy, Clone, PartialEq)]
2195pub enum VocabType {
2196 /// Byte Pair Encoding
2197 BPE = LLAMA_VOCAB_TYPE_BPE as _,
2198 /// Sentence Piece Tokenizer
2199 SPM = LLAMA_VOCAB_TYPE_SPM as _,
2200}
2201
2202/// Error that occurs when trying to convert a `llama_vocab_type` to a `VocabType`.
2203///
2204/// This error is raised when the integer value returned by the system does not correspond to a known vocabulary type.
2205///
2206/// # Variants
2207///
2208/// - `UnknownValue`: The error is raised when the value is not a valid `llama_vocab_type`. The invalid value is returned with the error.
2209///
2210/// # Example
2211///
2212/// ```no_run
2213/// use llama_cpp_4::model::LlamaTokenTypeFromIntError;
2214///
2215/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2216/// let invalid_value = 999; // Not a valid vocabulary type
2217/// let error = LlamaTokenTypeFromIntError::UnknownValue(invalid_value);
2218/// println!("Error: {}", error);
2219/// # Ok(())
2220/// # }
2221/// ```
2222#[derive(thiserror::Error, Debug, Eq, PartialEq)]
2223pub enum LlamaTokenTypeFromIntError {
2224 /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
2225 #[error("Unknown Value {0}")]
2226 UnknownValue(llama_vocab_type),
2227}
2228
2229impl TryFrom<llama_vocab_type> for VocabType {
2230 type Error = LlamaTokenTypeFromIntError;
2231
2232 fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
2233 match value {
2234 LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
2235 LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
2236 unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
2237 }
2238 }
2239}