use std::ffi::{c_char, CStr, CString};
use std::num::NonZeroU16;
use std::os::raw::c_int;
use std::path::Path;
use std::ptr::NonNull;
use std::str::Utf8Error;

use crate::context::params::LlamaContextParams;
use crate::context::LlamaContext;
use crate::llama_backend::LlamaBackend;
use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
use crate::{
    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
    LlamaModelLoadError, MetaValError, NewLlamaChatMessageError, StringToTokenError,
    TokenToStringError,
};

pub mod params;
/// A safe wrapper around `llama_model`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
}

/// A safe wrapper around `llama_adapter_lora`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}

/// An owned chat template string, stored as a `CString` so it can be passed across the FFI boundary.
#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);

impl LlamaChatTemplate {
    /// Creates a new chat template from a string slice.
    ///
    /// # Errors
    /// Fails if the template contains an interior null byte.
    pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
        Ok(Self(CString::new(template)?))
    }

    /// Accesses the template as a C string.
    pub fn as_c_str(&self) -> &CStr {
        &self.0
    }

    /// Attempts to view the template as a `&str`.
    ///
    /// # Errors
    /// Fails if the template is not valid UTF-8.
    pub fn to_str(&self) -> Result<&str, Utf8Error> {
        self.0.to_str()
    }

    /// Attempts to convert the template into an owned `String`.
    ///
    /// # Errors
    /// Fails if the template is not valid UTF-8.
    pub fn to_string(&self) -> Result<String, Utf8Error> {
        self.to_str().map(str::to_string)
    }
}

impl std::fmt::Debug for LlamaChatTemplate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}

/// A chat message for use with [`LlamaModel::apply_chat_template`], holding a
/// role (e.g. "user") and its content.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
    role: CString,
    content: CString,
}

impl LlamaChatMessage {
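    /// Creates a new `LlamaChatMessage` from a role and its content.
    ///
    /// A minimal sketch, assuming this crate is consumed as `llama_cpp_2` (the
    /// message text is illustrative):
    ///
    /// ```no_run
    /// # fn demo() -> Result<(), Box<dyn std::error::Error>> {
    /// use llama_cpp_2::model::LlamaChatMessage;
    ///
    /// let message = LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Fails if the role or content contains an interior null byte.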
    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
        Ok(Self {
            role: CString::new(role)?,
            content: CString::new(content)?,
        })
    }
}

/// The rotary position embedding (RoPE) type used by a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeType {
    Norm,
    NeoX,
    MRope,
    Vision,
}

/// Controls whether a beginning-of-stream (BOS) token is prepended during tokenization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Add the beginning-of-stream token to the start of the string.
    Always,
    /// Do not add the beginning-of-stream token.
    Never,
}

/// Controls how special tokens are handled when converting tokens to text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Allow rendering special and/or control tokens, which otherwise are not exposed and are treated as plaintext.
    Tokenize,
    /// Treat special and/or control tokens as plaintext.
    Plaintext,
}

// SAFETY: a loaded llama_model is not mutated through this wrapper and
// llama.cpp permits concurrent reads, so sending and sharing across threads is sound.
unsafe impl Send for LlamaModel {}

unsafe impl Sync for LlamaModel {}

impl LlamaModel {
    /// Returns a pointer to the model's vocabulary.
    pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
        unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
    }

    /// Gets the context size the model was trained on.
    #[must_use]
    pub fn n_ctx_train(&self) -> u32 {
        let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
        u32::try_from(n_ctx_train).expect("n_ctx_train fits into a u32")
    }

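    /// Iterates over every token in the vocabulary, pairing each token with its
    /// string representation.
    ///
    /// A minimal sketch of dumping part of the vocabulary, assuming `model` is
    /// a loaded `LlamaModel` and this crate is consumed as `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn dump(model: &llama_cpp_2::model::LlamaModel) {
    /// use llama_cpp_2::model::Special;
    ///
    /// for (token, text) in model.tokens(Special::Tokenize).take(10) {
    ///     println!("{token:?} => {text:?}");
    /// }
    /// # }
    /// ```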
    pub fn tokens(
        &self,
        special: Special,
    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
        (0..self.n_vocab())
            .map(LlamaToken::new)
            .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
    }

    /// Gets the beginning-of-stream (BOS) token.
    #[must_use]
    pub fn token_bos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Gets the end-of-stream (EOS) token.
    #[must_use]
    pub fn token_eos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Gets the newline token.
    #[must_use]
    pub fn token_nl(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Checks whether the token marks the end of generation (e.g. EOS or EOT).
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
    }

    /// Gets the decoder start token, used by encoder-decoder models.
    #[must_use]
    pub fn decode_start_token(&self) -> LlamaToken {
        let token =
            unsafe { llama_cpp_sys_2::llama_model_decoder_start_token(self.model.as_ptr()) };
        LlamaToken(token)
    }

    /// Gets the separator (SEP) token.
    #[must_use]
    pub fn token_sep(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_vocab_sep(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Converts a single token to a string.
    ///
    /// # Errors
    /// See [`TokenToStringError`].
    pub fn token_to_str(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let bytes = self.token_to_bytes(token, special)?;
        Ok(String::from_utf8(bytes)?)
    }

    /// Converts a single token to bytes.
    ///
    /// First tries a small 8-byte buffer; if llama.cpp reports (via a negative
    /// return value) that more space is required, retries once with the exact size.
    ///
    /// # Errors
    /// See [`TokenToStringError`].
    pub fn token_to_bytes(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<Vec<u8>, TokenToStringError> {
        match self.token_to_bytes_with_size(token, 8, special, None) {
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_bytes_with_size(
                token,
                (-i).try_into().expect("error buffer size is positive"),
                special,
                None,
            ),
            x => x,
        }
    }

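    /// Converts a slice of tokens to a string, concatenating the pieces.
    ///
    /// A minimal round-trip sketch, assuming `model` is a loaded `LlamaModel`
    /// and this crate is consumed as `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn round_trip(model: &llama_cpp_2::model::LlamaModel) -> Result<(), Box<dyn std::error::Error>> {
    /// use llama_cpp_2::model::{AddBos, Special};
    ///
    /// let tokens = model.str_to_token("Hello, world!", AddBos::Always)?;
    /// let text = model.tokens_to_str(&tokens, Special::Tokenize)?;
    /// assert!(text.contains("Hello"));
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// See [`TokenToStringError`].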
    pub fn tokens_to_str(
        &self,
        tokens: &[LlamaToken],
        special: Special,
    ) -> Result<String, TokenToStringError> {
        // ~4 bytes per token is a reasonable initial guess.
        let mut builder: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
        for piece in tokens
            .iter()
            .copied()
            .map(|t| self.token_to_bytes(t, special))
        {
            builder.extend_from_slice(&piece?);
        }
        Ok(String::from_utf8(builder)?)
    }

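    /// Converts a string to a vector of tokens.
    ///
    /// A minimal sketch, assuming `model` is a loaded `LlamaModel` and this
    /// crate is consumed as `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn tokenize(model: &llama_cpp_2::model::LlamaModel) -> Result<(), Box<dyn std::error::Error>> {
    /// use llama_cpp_2::model::AddBos;
    ///
    /// let tokens = model.str_to_token("The quick brown fox", AddBos::Always)?;
    /// println!("{tokens:?}");
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Fails if the text contains an interior null byte or its length does not fit into a `c_int`.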
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Rough estimate to avoid a second FFI call in the common case.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        let size = unsafe {
            llama_cpp_sys_2::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // A negative return value is the required token count; grow the buffer
        // to exactly that size and tokenize again.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("negative size fits into a usize"));
            unsafe {
                llama_cpp_sys_2::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and fits into a usize");

        // SAFETY: llama_tokenize initialized exactly `size` tokens in the buffer.
        unsafe { buffer.set_len(size) };
        Ok(buffer)
    }

    /// Gets the attributes of a token.
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }

    /// Converts a single token to a string, using a caller-specified buffer size.
    ///
    /// # Errors
    /// See [`TokenToStringError`].
    pub fn token_to_str_with_size(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
        Ok(String::from_utf8(bytes)?)
    }

    /// Converts a single token to bytes, using a caller-specified buffer size.
    ///
    /// # Errors
    /// Returns [`TokenToStringError::InsufficientBufferSpace`] (carrying the
    /// negated required size) if `buffer_size` is too small.
    pub fn token_to_bytes_with_size(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: Special,
        lstrip: Option<NonZeroU16>,
    ) -> Result<Vec<u8>, TokenToStringError> {
        if token == self.token_nl() {
            return Ok(b"\n".to_vec());
        }

        // Skip tokens that have no printable representation.
        let attrs = self.token_attr(token);
        if attrs.is_empty()
            || attrs
                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
            || (attrs.contains(LlamaTokenAttr::Control)
                && (token == self.token_bos() || token == self.token_eos()))
        {
            return Ok(Vec::new());
        }

        let special = match special {
            Special::Tokenize => true,
            Special::Plaintext => false,
        };

        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
        let len = string.as_bytes().len();
        let len = c_int::try_from(len).expect("length fits into c_int");
        let buf = string.into_raw();
        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
        let size = unsafe {
            llama_cpp_sys_2::llama_token_to_piece(
                self.vocab_ptr(),
                token.0,
                buf,
                len,
                lstrip,
                special,
            )
        };

        match size {
            0 => Err(TokenToStringError::UnknownTokenType),
            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
            size => {
                // Reclaim the buffer and trim it to the bytes llama.cpp wrote.
                let string = unsafe { CString::from_raw(buf) };
                let mut bytes = string.into_bytes();
                let len = usize::try_from(size).expect("size is positive and fits into usize");
                bytes.truncate(len);
                Ok(bytes)
            }
        }
    }

    /// Gets the size of the vocabulary.
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
    }

    /// Gets the type of the vocabulary.
    ///
    /// # Panics
    /// Panics if llama.cpp reports a vocabulary type this crate does not know about.
    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
        VocabType::try_from(vocab_type).expect("vocab type is valid")
    }

    /// Gets the embedding dimension of the model.
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
    }

    /// Gets the total size of all tensors in the model, in bytes.
    #[must_use]
    pub fn size(&self) -> u64 {
        unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
    }

    /// Gets the number of parameters in the model.
    #[must_use]
    pub fn n_params(&self) -> u64 {
        unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
    }

    /// Returns `true` if the model is recurrent (e.g. Mamba or RWKV).
    #[must_use]
    pub fn is_recurrent(&self) -> bool {
        unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
    }

    /// Gets the number of layers in the model.
    #[must_use]
    pub fn n_layer(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) })
            .expect("n_layer fits into a u32")
    }

    /// Gets the number of attention heads in the model.
    #[must_use]
    pub fn n_head(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) })
            .expect("n_head fits into a u32")
    }

    /// Gets the number of key-value attention heads in the model.
    #[must_use]
    pub fn n_head_kv(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
            .expect("n_head_kv fits into a u32")
    }

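    /// Gets a metadata value from the model's GGUF header as a `String`, by key.
    ///
    /// A minimal sketch, assuming `model` is a loaded `LlamaModel`, that its
    /// GGUF file carries this key, and that this crate is consumed as
    /// `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn meta(model: &llama_cpp_2::model::LlamaModel) -> Result<(), Box<dyn std::error::Error>> {
    /// let arch = model.meta_val_str("general.architecture")?;
    /// println!("architecture: {arch}");
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Fails if the key contains a null byte, the lookup fails, or the value is not valid UTF-8.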
    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
        let key_cstring = CString::new(key)?;
        let key_ptr = key_cstring.as_ptr();

        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str(
                    self.model.as_ptr(),
                    key_ptr,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Gets the number of metadata key/value pairs in the model's GGUF header.
    #[must_use]
    pub fn meta_count(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) }
    }

    /// Gets the metadata key at `index` as a `String`.
    ///
    /// # Errors
    /// Fails if the lookup fails or the key is not valid UTF-8.
    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_key_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Gets the metadata value at `index` as a `String`.
    ///
    /// # Errors
    /// Fails if the lookup fails or the value is not valid UTF-8.
    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Gets the RoPE type used by the model, or `None` if the model does not use RoPE.
    #[must_use]
    pub fn rope_type(&self) -> Option<RopeType> {
        match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
            rope_type => {
                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
                None
            }
        }
    }

    /// Gets the chat template stored in the model's metadata, if any.
    ///
    /// `name` selects a named template; pass `None` for the default template.
    /// A `name` containing an interior null byte is treated as `None`.
    ///
    /// # Errors
    /// Returns [`ChatTemplateError::MissingTemplate`] if no matching template exists.
    pub fn chat_template(
        &self,
        name: Option<&str>,
    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
        let name_cstr = name.map(CString::new);
        // Match on a reference so the CString outlives the FFI call below; a
        // by-value match would drop it and leave `name_ptr` dangling.
        let name_ptr = match &name_cstr {
            Some(Ok(name)) => name.as_ptr(),
            _ => std::ptr::null(),
        };
        let result =
            unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) };

        if result.is_null() {
            Err(ChatTemplateError::MissingTemplate)
        } else {
            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
            let chat_template = CString::new(chat_template_cstr.to_bytes())?;
            Ok(LlamaChatTemplate(chat_template))
        }
    }

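    /// Loads a model from a GGUF file.
    ///
    /// A minimal loading sketch; the `model.gguf` path is a placeholder, and
    /// the `llama_cpp_2` crate name and `LlamaBackend::init` constructor are
    /// assumed from this crate's layout:
    ///
    /// ```no_run
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::params::LlamaModelParams;
    /// use llama_cpp_2::model::LlamaModel;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let backend = LlamaBackend::init()?;
    /// let params = LlamaModelParams::default();
    /// let model = LlamaModel::load_from_file(&backend, "model.gguf", &params)?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Fails if the path is not valid UTF-8, contains a null byte, or llama.cpp cannot load the model.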
    #[tracing::instrument(skip_all, fields(params))]
    pub fn load_from_file(
        _: &LlamaBackend,
        path: impl AsRef<Path>,
        params: &LlamaModelParams,
    ) -> Result<Self, LlamaModelLoadError> {
        let path = path.as_ref();
        debug_assert!(path.exists(), "{path:?} does not exist");
        let path = path
            .to_str()
            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        let llama_model =
            unsafe { llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) };

        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;

        tracing::debug!(?path, "Loaded model");
        Ok(LlamaModel { model })
    }

    /// Initializes a LoRA adapter from a file so it can be applied to a context.
    ///
    /// # Errors
    /// Fails if the path is not valid UTF-8, contains a null byte, or llama.cpp cannot load the adapter.
    pub fn lora_adapter_init(
        &self,
        path: impl AsRef<Path>,
    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
        let path = path.as_ref();
        debug_assert!(path.exists(), "{path:?} does not exist");

        let path = path
            .to_str()
            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
                path.to_path_buf(),
            ))?;

        let cstr = CString::new(path)?;
        let adapter =
            unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };

        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;

        tracing::debug!(?path, "Initialized lora adapter");
        Ok(LlamaLoraAdapter {
            lora_adapter: adapter,
        })
    }

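    /// Creates a new context for this model.
    ///
    /// A minimal sketch, assuming `model` and `backend` were created as in
    /// [`Self::load_from_file`] and this crate is consumed as `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn ctx(model: &llama_cpp_2::model::LlamaModel, backend: &llama_cpp_2::llama_backend::LlamaBackend) -> Result<(), Box<dyn std::error::Error>> {
    /// use llama_cpp_2::context::params::LlamaContextParams;
    ///
    /// let params = LlamaContextParams::default();
    /// let _ctx = model.new_context(backend, params)?;
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Fails if llama.cpp cannot create the context.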
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context(
        &self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext, LlamaContextLoadError> {
        let context_params = params.context_params;
        let context = unsafe {
            llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
        };
        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;

        Ok(LlamaContext::new(self, context, params.embeddings()))
    }

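    /// Applies a chat template to a list of messages, producing a single prompt string.
    ///
    /// Set `add_ass` to `true` to append the assistant turn header so the model
    /// continues as the assistant.
    ///
    /// A minimal sketch, assuming `model` is a loaded `LlamaModel` with a
    /// built-in template and this crate is consumed as `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn prompt(model: &llama_cpp_2::model::LlamaModel) -> Result<(), Box<dyn std::error::Error>> {
    /// use llama_cpp_2::model::LlamaChatMessage;
    ///
    /// let template = model.chat_template(None)?;
    /// let messages = vec![LlamaChatMessage::new(
    ///     "user".to_string(),
    ///     "Why is the sky blue?".to_string(),
    /// )?];
    /// let prompt = model.apply_chat_template(&template, &messages, true)?;
    /// println!("{prompt}");
    /// # Ok(())
    /// # }
    /// ```
    ///
    /// # Errors
    /// Fails if llama.cpp rejects the template or the result is not valid UTF-8.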
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: &LlamaChatTemplate,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        // Guess the buffer size from the total message length; llama.cpp
        // reports the real size if this turns out to be too small.
        let message_length = chat.iter().fold(0, |acc, c| {
            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
        });
        let mut buff: Vec<u8> = vec![0; message_length * 2];

        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
            .iter()
            .map(|c| llama_cpp_sys_2::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tmpl_ptr = tmpl.0.as_ptr();

        let res = unsafe {
            llama_cpp_sys_2::llama_chat_apply_template(
                tmpl_ptr,
                chat.as_ptr(),
                chat.len(),
                add_ass,
                buff.as_mut_ptr().cast::<c_char>(),
                buff.len().try_into().expect("buffer size fits into an i32"),
            )
        };

        // A return value larger than the buffer is the required size; resize
        // and apply the template again.
        if res > buff.len().try_into().expect("buffer size fits into an i32") {
            buff.resize(res.try_into().expect("res is positive"), 0);

            let res = unsafe {
                llama_cpp_sys_2::llama_chat_apply_template(
                    tmpl_ptr,
                    chat.as_ptr(),
                    chat.len(),
                    add_ass,
                    buff.as_mut_ptr().cast::<c_char>(),
                    buff.len().try_into().expect("buffer size fits into an i32"),
                )
            };
            assert_eq!(Ok(res), buff.len().try_into());
        }
        buff.truncate(res.try_into().expect("res is positive"));
        Ok(String::from_utf8(buff)?)
    }
}

/// Calls a C function that writes a string into a caller-provided buffer,
/// growing the buffer and retrying if the reported length did not fit.
fn extract_meta_string<F>(c_function: F, capacity: usize) -> Result<String, MetaValError>
where
    F: Fn(*mut c_char, usize) -> i32,
{
    let mut buffer = vec![0u8; capacity];

    // The return value is the string length excluding the null terminator,
    // or negative on failure.
    let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
    if result < 0 {
        return Err(MetaValError::NegativeReturn(result));
    }

    // The string was truncated; retry with a buffer of exactly the right size.
    let returned_len = result as usize;
    if returned_len >= capacity {
        return extract_meta_string(c_function, returned_len + 1);
    }

    debug_assert_eq!(
        buffer.get(returned_len),
        Some(&0),
        "should end with null byte"
    );

    buffer.truncate(returned_len);
    Ok(String::from_utf8(buffer)?)
}

impl Drop for LlamaModel {
    fn drop(&mut self) {
        unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
    }
}

/// The vocabulary type used by a model.
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
    /// Byte pair encoding.
    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
    /// SentencePiece.
    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
}

/// An error converting a `llama_vocab_type` into a [`VocabType`].
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
    /// The value was not a recognized vocabulary type.
    #[error("Unknown Value {0}")]
    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
}

impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
    type Error = LlamaTokenTypeFromIntError;

    fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
        match value {
            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
        }
    }
}