// llama_cpp_4/model.rs
1//! A safe wrapper around `llama_model`.
2use std::ffi::CStr;
3use std::ffi::CString;
4use std::num::NonZeroU16;
5use std::os::raw::{c_char, c_int};
6use std::path::Path;
7use std::ptr::NonNull;
8
9use llama_cpp_sys_4::{
10 llama_adapter_lora, llama_adapter_lora_init, llama_chat_apply_template, llama_chat_message,
11 llama_free_model, llama_load_model_from_file, llama_model, llama_model_decoder_start_token,
12 llama_model_get_vocab, llama_model_load_from_splits, llama_model_meta_val_str,
13 llama_n_ctx_train, llama_n_embd, llama_n_vocab, llama_new_context_with_model, llama_split_path,
14 llama_split_prefix, llama_token_bos, llama_token_eos, llama_token_get_attr, llama_token_is_eog,
15 llama_token_nl, llama_token_to_piece, llama_tokenize, llama_vocab, llama_vocab_type,
16 LLAMA_VOCAB_TYPE_BPE, LLAMA_VOCAB_TYPE_SPM,
17};
18
19use crate::context::params::LlamaContextParams;
20use crate::context::LlamaContext;
21use crate::llama_backend::LlamaBackend;
22use crate::model::params::LlamaModelParams;
23use crate::token::LlamaToken;
24use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
25use crate::{
26 ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
27 LlamaModelLoadError, NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
28};
29
30pub mod params;
31
/// A safe wrapper around `llama_model`.
///
/// `#[repr(transparent)]`: this struct is layout-compatible with the raw
/// pointer it wraps. `Send`/`Sync` are implemented further down in this module.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    // Non-null handle to the C-side model. The import of `llama_free_model`
    // suggests a `Drop` impl elsewhere in this file frees it — TODO confirm.
    pub(crate) model: NonNull<llama_model>,
}
39
/// A safe wrapper around `llama_vocab`.
///
/// Obtained via [`LlamaModel::get_vocab`]; the pointer is owned by the model
/// on the C side, so this wrapper does not free it.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaVocab {
    // Non-null vocabulary handle returned by `llama_model_get_vocab`.
    pub(crate) vocab: NonNull<llama_vocab>,
}
47
/// A safe wrapper around `llama_adapter_lora`.
///
/// Created via [`LlamaModel::lora_adapter_init`].
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    // Non-null adapter handle returned by `llama_adapter_lora_init`.
    pub(crate) lora_adapter: NonNull<llama_adapter_lora>,
}
55
/// A safe wrapper around `llama_chat_message`.
///
/// Both fields are stored as `CString` so raw pointers can be handed to
/// `llama_chat_apply_template` without re-allocation.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
    // Speaker role (e.g. "user", "assistant"); guaranteed nul-free.
    role: CString,
    // Message body; guaranteed nul-free.
    content: CString,
}
62
63impl LlamaChatMessage {
64 /// Create a new `LlamaChatMessage`.
65 ///
66 /// # Errors
67 ///
68 /// Returns [`NewLlamaChatMessageError`] if the role or content contains a null byte.
69 pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
70 Ok(Self {
71 role: CString::new(role)?,
72 content: CString::new(content)?,
73 })
74 }
75}
76
/// How to determine if we should prepend a beginning-of-stream (BOS) token
/// when tokenizing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Add the beginning of stream token to the start of the string.
    Always,
    /// Do not add the beginning of stream token to the start of the string.
    Never,
}
85
/// How special/control tokens are handled during tokenization and detokenization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Allow tokenizing special and/or control tokens which otherwise are not
    /// exposed and treated as plaintext. Does not insert a leading space.
    Tokenize,
    /// Treat special and/or control tokens as plaintext.
    Plaintext,
}
94
// SAFETY: NOTE(review): assumes llama.cpp treats a loaded `llama_model` as
// immutable, so moving the handle across threads is sound — confirm against
// llama.cpp's threading guarantees.
unsafe impl Send for LlamaModel {}

// SAFETY: NOTE(review): same assumption as `Send` above — shared `&LlamaModel`
// access only performs read-only FFI calls on the underlying model.
unsafe impl Sync for LlamaModel {}
98
99impl LlamaModel {
100 /// Retrieves the vocabulary associated with the current Llama model.
101 ///
102 /// This method fetches the vocabulary from the underlying model using an unsafe
103 /// FFI call. The returned `LlamaVocab` struct contains a non-null pointer to
104 /// the vocabulary data, which is wrapped in a `NonNull` for safety.
105 ///
106 /// # Safety
107 /// This method uses an unsafe block to call a C function (`llama_model_get_vocab`),
108 /// which is assumed to return a valid pointer to the vocabulary. The caller should
109 /// ensure that the model object is properly initialized and valid before calling
110 /// this method, as dereferencing invalid pointers can lead to undefined behavior.
111 ///
112 /// # Returns
113 /// A `LlamaVocab` struct containing the vocabulary of the model.
114 ///
115 /// # Panics
116 ///
117 /// Panics if the underlying C function returns a null pointer.
118 ///
119 /// # Example
120 /// ```rust,ignore
121 /// let vocab = model.get_vocab();
122 /// ```
123 #[must_use]
124 pub fn get_vocab(&self) -> LlamaVocab {
125 let llama_vocab = unsafe { llama_model_get_vocab(self.model.as_ptr()) }.cast_mut();
126
127 LlamaVocab {
128 vocab: NonNull::new(llama_vocab).unwrap(),
129 }
130 }
131 /// Get the number of tokens the model was trained on.
132 ///
133 /// This function returns the number of tokens that the model was trained on, represented as a `u32`.
134 ///
135 /// # Panics
136 ///
137 /// This function will panic if the number of tokens the model was trained on does not fit into a `u32`.
138 /// This should be impossible on most platforms since llama.cpp returns a `c_int` (i32 on most platforms),
139 /// which is almost certainly positive.
140 #[must_use]
141 pub fn n_ctx_train(&self) -> u32 {
142 let n_ctx_train = unsafe { llama_n_ctx_train(self.model.as_ptr()) };
143 u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
144 }
145
146 /// Get all tokens in the model.
147 ///
148 /// This function returns an iterator over all the tokens in the model. Each item in the iterator is a tuple
149 /// containing a `LlamaToken` and its corresponding string representation (or an error if the conversion fails).
150 ///
151 /// # Parameters
152 ///
153 /// - `special`: The `Special` value that determines how special tokens (like BOS, EOS, etc.) are handled.
154 pub fn tokens(
155 &self,
156 special: Special,
157 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
158 (0..self.n_vocab())
159 .map(LlamaToken::new)
160 .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
161 }
162
163 /// Get the beginning of stream token.
164 ///
165 /// This function returns the token that represents the beginning of a stream (BOS token).
166 #[must_use]
167 pub fn token_bos(&self) -> LlamaToken {
168 let token = unsafe { llama_token_bos(self.get_vocab().vocab.as_ref()) };
169 LlamaToken(token)
170 }
171
172 /// Get the end of stream token.
173 ///
174 /// This function returns the token that represents the end of a stream (EOS token).
175 #[must_use]
176 pub fn token_eos(&self) -> LlamaToken {
177 let token = unsafe { llama_token_eos(self.get_vocab().vocab.as_ref()) };
178 LlamaToken(token)
179 }
180
181 /// Get the newline token.
182 ///
183 /// This function returns the token that represents a newline character.
184 #[must_use]
185 pub fn token_nl(&self) -> LlamaToken {
186 let token = unsafe { llama_token_nl(self.get_vocab().vocab.as_ref()) };
187 LlamaToken(token)
188 }
189
190 /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
191 ///
192 /// This function returns `true` if the provided token signifies the end of generation or end of sequence,
193 /// such as EOS or other special tokens.
194 ///
195 /// # Parameters
196 ///
197 /// - `token`: The `LlamaToken` to check.
198 ///
199 /// # Returns
200 ///
201 /// - `true` if the token is an end-of-generation token, otherwise `false`.
202 #[must_use]
203 pub fn is_eog_token(&self, token: LlamaToken) -> bool {
204 unsafe { llama_token_is_eog(self.get_vocab().vocab.as_ref(), token.0) }
205 }
206
207 /// Get the decoder start token.
208 ///
209 /// This function returns the token used to signal the start of decoding (i.e., the token used at the start
210 /// of a sequence generation).
211 #[must_use]
212 pub fn decode_start_token(&self) -> LlamaToken {
213 let token = unsafe { llama_model_decoder_start_token(self.model.as_ptr()) };
214 LlamaToken(token)
215 }
216
217 /// Convert a single token to a string.
218 ///
219 /// This function converts a `LlamaToken` into its string representation.
220 ///
221 /// # Errors
222 ///
223 /// This function returns an error if the token cannot be converted to a string. For more details, refer to
224 /// [`TokenToStringError`].
225 ///
226 /// # Parameters
227 ///
228 /// - `token`: The `LlamaToken` to convert.
229 /// - `special`: The `Special` value used to handle special tokens.
230 pub fn token_to_str(
231 &self,
232 token: LlamaToken,
233 special: Special,
234 ) -> Result<String, TokenToStringError> {
235 self.token_to_str_with_size(token, 32, special)
236 }
237
238 /// Convert a single token to bytes.
239 ///
240 /// This function converts a `LlamaToken` into a byte representation.
241 ///
242 /// # Errors
243 ///
244 /// This function returns an error if the token cannot be converted to bytes. For more details, refer to
245 /// [`TokenToStringError`].
246 ///
247 /// # Parameters
248 ///
249 /// - `token`: The `LlamaToken` to convert.
250 /// - `special`: The `Special` value used to handle special tokens.
251 pub fn token_to_bytes(
252 &self,
253 token: LlamaToken,
254 special: Special,
255 ) -> Result<Vec<u8>, TokenToStringError> {
256 self.token_to_bytes_with_size(token, 32, special, None)
257 }
258
259 /// Convert a vector of tokens to a single string.
260 ///
261 /// This function takes a slice of `LlamaToken`s and converts them into a single string, concatenating their
262 /// string representations.
263 ///
264 /// # Errors
265 ///
266 /// This function returns an error if any token cannot be converted to a string. For more details, refer to
267 /// [`TokenToStringError`].
268 ///
269 /// # Parameters
270 ///
271 /// - `tokens`: A slice of `LlamaToken`s to convert.
272 /// - `special`: The `Special` value used to handle special tokens.
273 pub fn tokens_to_str(
274 &self,
275 tokens: &[LlamaToken],
276 special: Special,
277 ) -> Result<String, TokenToStringError> {
278 let mut builder = String::with_capacity(tokens.len() * 4);
279 for str in tokens
280 .iter()
281 .copied()
282 .map(|t| self.token_to_str(t, special))
283 {
284 builder += &str?;
285 }
286 Ok(builder)
287 }
288
    /// Convert a string to a vector of tokens.
    ///
    /// Tokenizes `str` with the model's vocabulary, optionally prepending the
    /// BOS token. Uses a two-pass strategy: a first `llama_tokenize` call with
    /// an estimated buffer, then — if llama.cpp reports the buffer was too
    /// small (negative return) — a second call with the exact required size.
    ///
    /// # Errors
    ///
    /// - Returns an error if the input string contains a null byte.
    ///
    /// # Panics
    ///
    /// - Panics if the number of tokens exceeds `usize::MAX`.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use std::path::Path;
    /// use llama_cpp_4::model::AddBos;
    /// let backend = llama_cpp_4::llama_backend::LlamaBackend::init()?;
    /// let model = LlamaModel::load_from_file(&backend, Path::new("path/to/model"), &Default::default())?;
    /// let tokens = model.str_to_token("Hello, World!", AddBos::Always)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Heuristic: roughly one token per two bytes of UTF-8, minimum 8 slots.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        // SAFETY: `buffer` has at least `buffer_capacity` slots; llama.cpp
        // writes at most that many tokens and returns the count (or the
        // negated required count if the buffer is too small).
        let size = unsafe {
            llama_tokenize(
                self.get_vocab().vocab.as_ref(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // if we fail the first time we can resize the vector to the correct size and try again. This should never fail.
        // as a result - size is guaranteed to be positive here.
        let size = if size.is_negative() {
            // A negative return is the negated number of tokens required;
            // grow the buffer to exactly that size and retry.
            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
            unsafe {
                llama_tokenize(
                    self.get_vocab().vocab.as_ref(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and usize ");

        // Safety: `size` < `capacity` and llama-cpp has initialized elements up to `size`
        unsafe { buffer.set_len(size) }
        Ok(buffer.into_iter().map(LlamaToken).collect())
    }
370
371 /// Get the type of a token.
372 ///
373 /// This function retrieves the attributes associated with a given token. The attributes are typically used to
374 /// understand whether the token represents a special type of token (e.g., beginning-of-sequence (BOS), end-of-sequence (EOS),
375 /// control tokens, etc.).
376 ///
377 /// # Panics
378 ///
379 /// - This function will panic if the token type is unknown or cannot be converted to a valid `LlamaTokenAttrs`.
380 ///
381 /// # Example
382 ///
383 /// ```no_run
384 /// use llama_cpp_4::model::{LlamaModel, LlamaToken};
385 ///
386 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
387 /// let model = LlamaModel::load_from_file("path/to/model")?;
388 /// let token = LlamaToken(42);
389 /// let token_attrs = model.token_attr(token);
390 /// # Ok(())
391 /// # }
392 /// ```
393 #[must_use]
394 pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
395 let token_type = unsafe { llama_token_get_attr(self.get_vocab().vocab.as_ref(), id) };
396 LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
397 }
398
399 /// Convert a token to a string with a specified buffer size.
400 ///
401 /// This function allows you to convert a token into a string, with the ability to specify a buffer size for the operation.
402 /// It is generally recommended to use `LlamaModel::token_to_str` instead, as 8 bytes is typically sufficient for most tokens,
403 /// and the extra buffer size doesn't usually matter.
404 ///
405 /// # Errors
406 ///
407 /// - If the token type is unknown, an error will be returned.
408 /// - If the resultant token exceeds the provided `buffer_size`, an error will occur.
409 /// - If the token string returned by `llama-cpp` is not valid UTF-8, it will return an error.
410 ///
411 /// # Panics
412 ///
413 /// - This function will panic if the `buffer_size` does not fit into a `c_int`.
414 /// - It will also panic if the size returned from `llama-cpp` does not fit into a `usize`, which should typically never happen.
415 ///
416 /// # Example
417 ///
418 /// ```no_run
419 /// use llama_cpp_4::model::{LlamaModel, LlamaToken};
420 ///
421 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
422 /// let model = LlamaModel::load_from_file("path/to/model")?;
423 /// let token = LlamaToken(42);
424 /// let token_string = model.token_to_str_with_size(token, 32, Special::Plaintext)?;
425 /// # Ok(())
426 /// # }
427 /// ```
428 pub fn token_to_str_with_size(
429 &self,
430 token: LlamaToken,
431 buffer_size: usize,
432 special: Special,
433 ) -> Result<String, TokenToStringError> {
434 let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
435 Ok(String::from_utf8(bytes)?)
436 }
437
    /// Convert a token to bytes with a specified buffer size.
    ///
    /// Generally you should use [`LlamaModel::token_to_bytes`] instead as 8 bytes is enough for most words and
    /// the extra bytes do not really matter.
    ///
    /// # Errors
    ///
    /// - if the token type is unknown
    /// - the resultant token is larger than `buffer_size`.
    ///
    /// # Panics
    ///
    /// - This function will panic if `buffer_size` cannot fit into a `c_int`.
    /// - It will also panic if the size returned from `llama-cpp` cannot be converted to `usize` (which should not happen).
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::{LlamaModel, LlamaToken};
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = LlamaModel::load_from_file("path/to/model")?;
    /// let token = LlamaToken(42);
    /// let token_bytes = model.token_to_bytes_with_size(token, 32, Special::Plaintext, None)?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn token_to_bytes_with_size(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: Special,
        lstrip: Option<NonZeroU16>,
    ) -> Result<Vec<u8>, TokenToStringError> {
        // Newline is special-cased: always rendered as "\n".
        if token == self.token_nl() {
            return Ok(String::from("\n").into_bytes());
        }

        // unsure what to do with this in the face of the 'special' arg + attr changes
        // Control tokens at the stream boundaries, attribute-less tokens, and
        // unknown/byte/unused tokens are rendered as empty output.
        let attrs = self.token_attr(token);
        if (attrs.contains(LlamaTokenAttr::Control)
            && (token == self.token_bos() || token == self.token_eos()))
            || attrs.is_empty()
            || attrs
                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
        {
            return Ok(Vec::new());
        }

        let special = match special {
            Special::Tokenize => true,
            Special::Plaintext => false,
        };

        // Allocate a nul-terminated scratch buffer of `buffer_size` filled
        // with '*', then hand its raw pointer to C.
        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
        let len = string.as_bytes().len();
        let len = c_int::try_from(len).expect("length fits into c_int");
        let buf = string.into_raw();
        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
        // SAFETY: `buf` points to `len + 1` writable bytes (from
        // `CString::into_raw`); llama.cpp writes at most `len` bytes.
        let size = unsafe {
            llama_token_to_piece(
                self.get_vocab().vocab.as_ref(),
                token.0,
                buf,
                len,
                lstrip,
                special,
            )
        };

        match size {
            0 => Err(TokenToStringError::UnknownTokenType),
            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
            size => {
                // SAFETY: `buf` came from `CString::into_raw` above and the
                // terminating nul is still within the allocation, so
                // reclaiming it here is sound (and frees the buffer).
                let string = unsafe { CString::from_raw(buf) };
                let mut bytes = string.into_bytes();
                let len = usize::try_from(size).expect("size is positive and fits into usize");
                // Keep only the bytes llama.cpp actually wrote.
                bytes.truncate(len);
                Ok(bytes)
            }
        }
    }
520 /// The number of tokens the model was trained on.
521 ///
522 /// This function returns the number of tokens the model was trained on. It is returned as a `c_int` for maximum
523 /// compatibility with the underlying llama-cpp library, though it can typically be cast to an `i32` without issue.
524 ///
525 /// # Example
526 ///
527 /// ```no_run
528 /// use llama_cpp_4::model::LlamaModel;
529 ///
530 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
531 /// let model = LlamaModel::load_from_file("path/to/model")?;
532 /// let n_vocab = model.n_vocab();
533 /// # Ok(())
534 /// # }
535 /// ```
536 #[must_use]
537 pub fn n_vocab(&self) -> i32 {
538 unsafe { llama_n_vocab(self.get_vocab().vocab.as_ref()) }
539 }
540
541 /// The type of vocab the model was trained on.
542 ///
543 /// This function returns the type of vocabulary used by the model, such as whether it is based on byte-pair encoding (BPE),
544 /// word-level tokens, or another tokenization scheme.
545 ///
546 /// # Panics
547 ///
548 /// - This function will panic if `llama-cpp` emits a vocab type that is not recognized or is invalid for this library.
549 ///
550 /// # Example
551 ///
552 /// ```no_run
553 /// use llama_cpp_4::model::LlamaModel;
554 ///
555 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
556 /// let model = LlamaModel::load_from_file("path/to/model")?;
557 /// let vocab_type = model.vocab_type();
558 /// # Ok(())
559 /// # }
560 /// ```
561 #[must_use]
562 pub fn vocab_type(&self) -> VocabType {
563 let vocab_type = unsafe { llama_vocab_type(self.get_vocab().vocab.as_ref()) };
564 VocabType::try_from(vocab_type).expect("invalid vocab type")
565 }
566
567 /// Returns the number of embedding dimensions for the model.
568 ///
569 /// This function retrieves the number of embeddings (or embedding dimensions) used by the model. It is typically
570 /// used for analyzing model architecture and setting up context parameters or other model configuration aspects.
571 ///
572 /// # Panics
573 ///
574 /// - This function may panic if the underlying `llama-cpp` library returns an invalid embedding dimension value.
575 ///
576 /// # Example
577 ///
578 /// ```no_run
579 /// use llama_cpp_4::model::LlamaModel;
580 ///
581 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
582 /// let model = LlamaModel::load_from_file("path/to/model")?;
583 /// let n_embd = model.n_embd();
584 /// # Ok(())
585 /// # }
586 /// ```
587 #[must_use]
588 pub fn n_embd(&self) -> c_int {
589 unsafe { llama_n_embd(self.model.as_ptr()) }
590 }
591
    /// Get chat template from model.
    ///
    /// Reads the `tokenizer.chat_template` metadata value into a caller-sized
    /// buffer. The longest known template is about 1200 bytes, so a
    /// `buf_size` of 2048 is comfortably safe.
    ///
    /// # Errors
    ///
    /// - If the model does not have a chat template, it will return an error.
    /// - If the chat template is not a valid `CString`, it will return an error.
    /// - If the template was truncated because `buf_size` was too small,
    ///   returns `BuffSizeError` carrying the size needed.
    ///
    /// # Example
    ///
    /// ```no_run
    /// use llama_cpp_4::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let model = LlamaModel::load_from_file("path/to/model")?;
    /// let chat_template = model.get_chat_template(1024)?;
    /// # Ok(())
    /// # }
    /// ```
    #[allow(clippy::missing_panics_doc)] // We statically know this will not panic as long as the buffer size is sufficient
    pub fn get_chat_template(&self, buf_size: usize) -> Result<String, ChatTemplateError> {
        // longest known template is about 1200 bytes from llama.cpp
        // Scratch buffer of `buf_size` '*' bytes, handed to C as a raw pointer.
        let chat_temp = CString::new(vec![b'*'; buf_size]).expect("no null");
        let chat_ptr = chat_temp.into_raw();
        let chat_name = CString::new("tokenizer.chat_template").expect("no null bytes");

        // `llama_model_meta_val_str` returns the full value length (like
        // snprintf), or a negative value when the key is missing.
        let ret = unsafe {
            llama_model_meta_val_str(self.model.as_ptr(), chat_name.as_ptr(), chat_ptr, buf_size)
        };

        if ret < 0 {
            return Err(ChatTemplateError::MissingTemplate(ret));
        }

        // Reclaim the buffer (and ownership) from the raw pointer.
        let template_c = unsafe { CString::from_raw(chat_ptr) };
        let template = template_c.to_str()?;

        // If the string we got back is shorter than the reported length, the
        // value was truncated — tell the caller how big a buffer to retry with.
        let ret: usize = ret.try_into().unwrap();
        if template.len() < ret {
            return Err(ChatTemplateError::BuffSizeError(ret + 1));
        }

        Ok(template.to_owned())
    }
635
636 /// Loads a model from a file.
637 ///
638 /// This function loads a model from a specified file path and returns the corresponding `LlamaModel` instance.
639 ///
640 /// # Errors
641 ///
642 /// - If the path cannot be converted to a string or if the model file does not exist, it will return an error.
643 /// - If the model cannot be loaded (e.g., due to an invalid or corrupted model file), it will return a `LlamaModelLoadError`.
644 ///
645 /// # Example
646 ///
647 /// ```no_run
648 /// use llama_cpp_4::model::LlamaModel;
649 /// use std::path::Path;
650 ///
651 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
652 /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
653 /// # Ok(())
654 /// # }
655 /// ```
656 #[tracing::instrument(skip_all, fields(params))]
657 pub fn load_from_file(
658 _: &LlamaBackend,
659 path: impl AsRef<Path>,
660 params: &LlamaModelParams,
661 ) -> Result<Self, LlamaModelLoadError> {
662 let path = path.as_ref();
663 debug_assert!(
664 Path::new(path).exists(),
665 "{} does not exist",
666 path.display()
667 );
668 let path = path
669 .to_str()
670 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
671
672 let cstr = CString::new(path)?;
673 let llama_model = unsafe { llama_load_model_from_file(cstr.as_ptr(), params.params) };
674
675 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
676
677 tracing::debug!(?path, "Loaded model");
678 Ok(LlamaModel { model })
679 }
680
681 /// Load a model from multiple split files.
682 ///
683 /// This function loads a model that has been split across multiple files. This is useful for
684 /// very large models that exceed filesystem limitations or need to be distributed across
685 /// multiple storage devices.
686 ///
687 /// # Arguments
688 ///
689 /// * `paths` - A slice of paths to the split model files
690 /// * `params` - The model parameters
691 ///
692 /// # Errors
693 ///
694 /// Returns an error if:
695 /// - Any of the paths cannot be converted to a C string
696 /// - The model fails to load from the splits
697 /// - Any path doesn't exist or isn't accessible
698 ///
699 /// # Example
700 ///
701 /// ```no_run
702 /// use llama_cpp_4::model::{LlamaModel, params::LlamaModelParams};
703 /// use llama_cpp_4::llama_backend::LlamaBackend;
704 ///
705 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
706 /// let backend = LlamaBackend::init()?;
707 /// let params = LlamaModelParams::default();
708 ///
709 /// let paths = vec![
710 /// "model-00001-of-00003.gguf",
711 /// "model-00002-of-00003.gguf",
712 /// "model-00003-of-00003.gguf",
713 /// ];
714 ///
715 /// let model = LlamaModel::load_from_splits(&backend, &paths, ¶ms)?;
716 /// # Ok(())
717 /// # }
718 /// ```
719 #[tracing::instrument(skip_all)]
720 pub fn load_from_splits(
721 _: &LlamaBackend,
722 paths: &[impl AsRef<Path>],
723 params: &LlamaModelParams,
724 ) -> Result<Self, LlamaModelLoadError> {
725 // Convert paths to C strings
726 let c_strings: Vec<CString> = paths
727 .iter()
728 .map(|p| {
729 let path = p.as_ref();
730 debug_assert!(path.exists(), "{} does not exist", path.display());
731 let path_str = path
732 .to_str()
733 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
734 CString::new(path_str).map_err(LlamaModelLoadError::from)
735 })
736 .collect::<Result<Vec<_>, _>>()?;
737
738 // Create array of pointers to C strings
739 let c_ptrs: Vec<*const c_char> = c_strings.iter().map(|s| s.as_ptr()).collect();
740
741 // Load the model from splits
742 let llama_model = unsafe {
743 llama_model_load_from_splits(c_ptrs.as_ptr().cast_mut(), c_ptrs.len(), params.params)
744 };
745
746 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
747
748 tracing::debug!("Loaded model from {} splits", paths.len());
749 Ok(LlamaModel { model })
750 }
751
752 /// Initializes a lora adapter from a file.
753 ///
754 /// This function initializes a Lora adapter, which is a model extension used to adapt or fine-tune the existing model
755 /// to a specific domain or task. The adapter file is typically in the form of a binary or serialized file that can be applied
756 /// to the model for improved performance on specialized tasks.
757 ///
758 /// # Errors
759 ///
760 /// - If the adapter file path cannot be converted to a string or if the adapter cannot be initialized, it will return an error.
761 ///
762 /// # Example
763 ///
764 /// ```no_run
765 /// use llama_cpp_4::model::{LlamaModel, LlamaLoraAdapter};
766 /// use std::path::Path;
767 ///
768 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
769 /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
770 /// let adapter = model.lora_adapter_init("path/to/lora/adapter")?;
771 /// # Ok(())
772 /// # }
773 /// ```
774 pub fn lora_adapter_init(
775 &self,
776 path: impl AsRef<Path>,
777 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
778 let path = path.as_ref();
779 debug_assert!(
780 Path::new(path).exists(),
781 "{} does not exist",
782 path.display()
783 );
784
785 let path = path
786 .to_str()
787 .ok_or(LlamaLoraAdapterInitError::PathToStrError(
788 path.to_path_buf(),
789 ))?;
790
791 let cstr = CString::new(path)?;
792 let adapter = unsafe { llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
793
794 let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
795
796 tracing::debug!(?path, "Initialized lora adapter");
797 Ok(LlamaLoraAdapter {
798 lora_adapter: adapter,
799 })
800 }
801
802 /// Create a new context from this model.
803 ///
804 /// This function creates a new context for the model, which is used to manage and perform computations for inference,
805 /// including token generation, embeddings, and other tasks that the model can perform. The context allows fine-grained
806 /// control over model parameters for a specific task.
807 ///
808 /// # Errors
809 ///
810 /// - There are various potential failures such as invalid parameters or a failure to allocate the context. See [`LlamaContextLoadError`]
811 /// for more detailed error descriptions.
812 ///
813 /// # Example
814 ///
815 /// ```no_run
816 /// use llama_cpp_4::model::{LlamaModel, LlamaContext};
817 /// use llama_cpp_4::LlamaContextParams;
818 ///
819 /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
820 /// let model = LlamaModel::load_from_file("path/to/model", &LlamaModelParams::default())?;
821 /// let context = model.new_context(&LlamaBackend::init()?, LlamaContextParams::default())?;
822 /// # Ok(())
823 /// # }
824 /// ```
825 #[allow(clippy::needless_pass_by_value)]
826 pub fn new_context(
827 &self,
828 _: &LlamaBackend,
829 params: LlamaContextParams,
830 ) -> Result<LlamaContext<'_>, LlamaContextLoadError> {
831 let context_params = params.context_params;
832 let context = unsafe { llama_new_context_with_model(self.model.as_ptr(), context_params) };
833 let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
834
835 Ok(LlamaContext::new(self, context, params.embeddings()))
836 }
837
    /// Apply the model's chat template to a sequence of messages.
    ///
    /// Formats the provided chat messages according to the model's chat
    /// template (role separation, token formatting, etc.). A custom template
    /// may be supplied; with `None`, the default template used by `llama.cpp`
    /// is applied.
    ///
    /// For more information on supported templates, visit:
    /// <https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template>
    ///
    /// # Arguments
    ///
    /// - `tmpl`: An optional custom template string. If `None`, the default template will be used.
    /// - `chat`: The conversation messages to format.
    /// - `add_ass`: Whether to append the assistant-turn prompt suffix.
    ///
    /// # Errors
    ///
    /// - Insufficient buffer size even after one retry, or a negative return
    ///   from `llama_chat_apply_template` (`ApplyChatTemplateError::BuffSizeError`).
    /// - A custom template containing a null byte.
    ///
    /// # Panics
    ///
    /// Panics if the buffer length exceeds `i32::MAX`.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: Option<&str>,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        // Compute raw message byte total from the original LlamaChatMessage vec
        // *before* we shadow `chat` with the sys-type vec below.
        let message_length = chat.iter().fold(0usize, |acc, c| {
            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
        });

        // Build our llama_cpp_sys chat messages (raw pointers into CStrings).
        // NOTE: `chat` must stay alive until after the FFI call — the pointers
        // below borrow its CString storage.
        let chat_sys: Vec<llama_chat_message> = chat
            .iter()
            .map(|c| llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        // Set the tmpl pointer.
        // `tmpl_cstring` must outlive `tmpl_ptr`; a null pointer selects the
        // model's built-in template.
        let tmpl_cstring = tmpl.map(CString::new).transpose()?;
        let tmpl_ptr = tmpl_cstring
            .as_ref()
            .map_or(std::ptr::null(), |s| s.as_ptr());

        // `message_length * 4` is far too small for models whose built-in chat
        // template adds a long default system prompt (e.g. Qwen3.5 prepends
        // ~80+ chars of markup even for a one-word user message). Start with
        // at least 4 KiB so short inputs like "hi" always have room.
        //
        // `llama_chat_apply_template` returns the number of bytes it *actually*
        // needed when the buffer was too small, so we retry exactly once with
        // that precise size rather than giving up immediately.
        let mut buf_size = message_length.saturating_mul(4).max(4096);

        // At most two attempts: initial guess, then the exact reported size.
        for _ in 0..2 {
            // Use u8 so that as_mut_ptr()/as_ptr() match the binding (*mut u8 / *const u8).
            let mut buff = vec![0u8; buf_size];
            let res = unsafe {
                llama_chat_apply_template(
                    tmpl_ptr,
                    chat_sys.as_ptr(),
                    chat_sys.len(),
                    add_ass,
                    buff.as_mut_ptr().cast(),
                    i32::try_from(buff.len()).expect("buffer length fits in i32"),
                )
            };

            if res < 0 {
                return Err(ApplyChatTemplateError::BuffSizeError);
            }

            #[allow(clippy::cast_sign_loss)]
            let needed = res as usize;
            if needed > buf_size {
                // Buffer was too small — retry with the exact size llama.cpp reported.
                buf_size = needed + 1; // +1 for null terminator
                continue;
            }

            // SAFETY: llama_chat_apply_template wrote a NUL-terminated string
            // into `buff`; `needed` bytes were used.
            let formatted = unsafe {
                CStr::from_ptr(buff.as_ptr().cast())
                    .to_string_lossy()
                    .into_owned()
            };
            return Ok(formatted);
        }

        // Both attempts failed to produce a fitting buffer.
        Err(ApplyChatTemplateError::BuffSizeError)
    }
958
959 /// Build a split GGUF file path for a specific chunk.
960 ///
961 /// This utility function creates the standardized filename for a split model chunk
962 /// following the pattern: `{prefix}-{split_no:05d}-of-{split_count:05d}.gguf`
963 ///
964 /// # Arguments
965 ///
966 /// * `path_prefix` - The base path and filename prefix
967 /// * `split_no` - The split number (1-indexed)
968 /// * `split_count` - The total number of splits
969 ///
970 /// # Returns
971 ///
972 /// Returns the formatted split path as a String
973 ///
974 /// # Example
975 ///
976 /// ```
977 /// use llama_cpp_4::model::LlamaModel;
978 ///
979 /// let path = LlamaModel::split_path("/models/llama", 2, 4);
980 /// assert_eq!(path, "/models/llama-00002-of-00004.gguf");
981 /// ```
982 ///
983 /// # Panics
984 ///
985 /// Panics if the path prefix contains a null byte.
986 #[must_use]
987 pub fn split_path(path_prefix: &str, split_no: i32, split_count: i32) -> String {
988 let mut buffer = vec![0u8; 1024];
989 let len = unsafe {
990 llama_split_path(
991 buffer.as_mut_ptr().cast::<c_char>(),
992 buffer.len(),
993 CString::new(path_prefix).unwrap().as_ptr(),
994 split_no,
995 split_count,
996 )
997 };
998
999 let len = usize::try_from(len).expect("split_path length fits in usize");
1000 buffer.truncate(len);
1001 String::from_utf8(buffer).unwrap_or_default()
1002 }
1003
1004 /// Extract the path prefix from a split filename.
1005 ///
1006 /// This function extracts the base path prefix from a split model filename,
1007 /// but only if the `split_no` and `split_count` match the pattern in the filename.
1008 ///
1009 /// # Arguments
1010 ///
1011 /// * `split_path` - The full path to the split file
1012 /// * `split_no` - The expected split number
1013 /// * `split_count` - The expected total number of splits
1014 ///
1015 /// # Returns
1016 ///
1017 /// Returns the path prefix if the pattern matches, or None if it doesn't
1018 ///
1019 /// # Example
1020 ///
1021 /// ```
1022 /// use llama_cpp_4::model::LlamaModel;
1023 ///
1024 /// let prefix = LlamaModel::split_prefix("/models/llama-00002-of-00004.gguf", 2, 4);
1025 /// assert_eq!(prefix, Some("/models/llama".to_string()));
1026 /// ```
1027 ///
1028 /// # Panics
1029 ///
1030 /// Panics if the split path contains a null byte.
1031 #[must_use]
1032 pub fn split_prefix(split_path: &str, split_no: i32, split_count: i32) -> Option<String> {
1033 let mut buffer = vec![0u8; 1024];
1034 let len = unsafe {
1035 llama_split_prefix(
1036 buffer.as_mut_ptr().cast::<c_char>(),
1037 buffer.len(),
1038 CString::new(split_path).unwrap().as_ptr(),
1039 split_no,
1040 split_count,
1041 )
1042 };
1043
1044 if len > 0 {
1045 let len = usize::try_from(len).expect("split_prefix length fits in usize");
1046 buffer.truncate(len);
1047 String::from_utf8(buffer).ok()
1048 } else {
1049 None
1050 }
1051 }
1052}
1053
/// Frees the underlying `llama_model` when the wrapper goes out of scope.
impl Drop for LlamaModel {
    fn drop(&mut self) {
        // SAFETY: `self.model` is a non-null pointer to a model owned by this
        // wrapper, so it is valid here and freed exactly once.
        unsafe { llama_free_model(self.model.as_ptr()) }
    }
}
1059
1060/// Defines the possible types of vocabulary used by the model.
1061///
1062/// The model may use different types of vocabulary depending on the tokenization method chosen during training.
1063/// This enum represents these types, specifically `BPE` (Byte Pair Encoding) and `SPM` (`SentencePiece`).
1064///
1065/// # Variants
1066///
1067/// - `BPE`: Byte Pair Encoding, a common tokenization method used in NLP tasks.
1068/// - `SPM`: `SentencePiece`, another popular tokenization method for NLP models.
1069///
1070/// # Example
1071///
1072/// ```no_run
1073/// use llama_cpp_4::model::VocabType;
1074///
1075/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1076/// let vocab_type = VocabType::BPE;
1077/// match vocab_type {
1078/// VocabType::BPE => println!("The model uses Byte Pair Encoding (BPE)"),
1079/// VocabType::SPM => println!("The model uses SentencePiece (SPM)"),
1080/// }
1081/// # Ok(())
1082/// # }
1083/// ```
1084#[repr(u32)]
1085#[derive(Debug, Eq, Copy, Clone, PartialEq)]
1086pub enum VocabType {
1087 /// Byte Pair Encoding
1088 BPE = LLAMA_VOCAB_TYPE_BPE as _,
1089 /// Sentence Piece Tokenizer
1090 SPM = LLAMA_VOCAB_TYPE_SPM as _,
1091}
1092
1093/// Error that occurs when trying to convert a `llama_vocab_type` to a `VocabType`.
1094///
1095/// This error is raised when the integer value returned by the system does not correspond to a known vocabulary type.
1096///
1097/// # Variants
1098///
1099/// - `UnknownValue`: The error is raised when the value is not a valid `llama_vocab_type`. The invalid value is returned with the error.
1100///
1101/// # Example
1102///
1103/// ```no_run
1104/// use llama_cpp_4::model::LlamaTokenTypeFromIntError;
1105///
1106/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
1107/// let invalid_value = 999; // Not a valid vocabulary type
1108/// let error = LlamaTokenTypeFromIntError::UnknownValue(invalid_value);
1109/// println!("Error: {}", error);
1110/// # Ok(())
1111/// # }
1112/// ```
1113#[derive(thiserror::Error, Debug, Eq, PartialEq)]
1114pub enum LlamaTokenTypeFromIntError {
1115 /// The value is not a valid `llama_token_type`. Contains the int value that was invalid.
1116 #[error("Unknown Value {0}")]
1117 UnknownValue(llama_vocab_type),
1118}
1119
1120impl TryFrom<llama_vocab_type> for VocabType {
1121 type Error = LlamaTokenTypeFromIntError;
1122
1123 fn try_from(value: llama_vocab_type) -> Result<Self, Self::Error> {
1124 match value {
1125 LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
1126 LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
1127 unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
1128 }
1129 }
1130}