use std::ffi::{c_char, CStr, CString};
use std::num::NonZeroU16;
use std::os::raw::c_int;
use std::path::Path;
use std::ptr::NonNull;
use std::str::Utf8Error;

use crate::context::params::LlamaContextParams;
use crate::context::LlamaContext;
use crate::llama_backend::LlamaBackend;
use crate::model::params::LlamaModelParams;
use crate::token::LlamaToken;
use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
use crate::{
    ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
    LlamaModelLoadError, MetaValError, NewLlamaChatMessageError, StringToTokenError,
    TokenToStringError,
};

pub mod params;

/// A safe wrapper around `llama_model`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
}

/// A safe wrapper around `llama_adapter_lora`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}

/// An owned, null-terminated chat template string, as produced by
/// [`LlamaModel::chat_template`] and consumed by [`LlamaModel::apply_chat_template`].
#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);

impl LlamaChatTemplate {
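    /// Create a new template from a string, e.g. a built-in llama.cpp template name
    /// such as "chatml" or a full Jinja template. The string is not validated beyond
    /// checking that it contains no interior null bytes, which the C API cannot represent.
    ///
    /// # Errors
    ///
    /// Fails if the template contains a null byte.
    ///
    /// A minimal sketch (assuming this crate is built as `llama_cpp_2`):
    ///
    /// ```no_run
    /// use llama_cpp_2::model::LlamaChatTemplate;
    ///
    /// let template = LlamaChatTemplate::new("chatml").expect("no null bytes in the literal");
    /// assert_eq!(template.to_str(), Ok("chatml"));
    /// ```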
    pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
        Ok(Self(CString::new(template)?))
    }

    /// Accesses the template as a [`CStr`].
    pub fn as_c_str(&self) -> &CStr {
        &self.0
    }

    /// Attempts to convert the template to a `&str`, failing if it is not valid UTF-8.
    pub fn to_str(&self) -> Result<&str, Utf8Error> {
        self.0.to_str()
    }

    /// Attempts to convert the template to an owned `String`, failing if it is not valid UTF-8.
    pub fn to_string(&self) -> Result<String, Utf8Error> {
        self.to_str().map(str::to_string)
    }
}

impl std::fmt::Debug for LlamaChatTemplate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}

/// A chat message: a role (such as "user" or "assistant") plus its content,
/// stored null-terminated for the C API.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
    role: CString,
    content: CString,
}

impl LlamaChatMessage {
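    /// Create a new `LlamaChatMessage` from a role and its content.
    ///
    /// # Errors
    ///
    /// Fails if `role` or `content` contain a null byte.
    ///
    /// A minimal sketch (assuming this crate is built as `llama_cpp_2`):
    ///
    /// ```no_run
    /// use llama_cpp_2::model::LlamaChatMessage;
    ///
    /// let message = LlamaChatMessage::new("user".to_string(), "Hello!".to_string())
    ///     .expect("no null bytes in the literals");
    /// ```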
    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
        Ok(Self {
            role: CString::new(role)?,
            content: CString::new(content)?,
        })
    }
}

/// The RoPE (rotary positional embedding) variant a model uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeType {
    Norm,
    NeoX,
    MRope,
    Vision,
}

/// How to determine if we should prepend a BOS (beginning-of-sequence) token when tokenizing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Always add the BOS token.
    Always,
    /// Never add the BOS token.
    Never,
}

/// How to treat special tokens when tokenizing or detokenizing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Allow tokenizing special and/or control tokens; otherwise they are treated as plaintext.
    Tokenize,
    /// Treat special and/or control tokens as plaintext.
    Plaintext,
}

unsafe impl Send for LlamaModel {}

unsafe impl Sync for LlamaModel {}

impl LlamaModel {
    /// Returns the vocabulary pointer that the token-related FFI calls expect.
    pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
        unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
    }

    /// Get the size of the context window the model was trained on.
    ///
    /// # Panics
    ///
    /// Panics if `n_ctx_train` does not fit into a `u32`, which should never happen.
    #[must_use]
    pub fn n_ctx_train(&self) -> u32 {
        let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
        u32::try_from(n_ctx_train).expect("n_ctx_train fits into a u32")
    }

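    /// Get all tokens in the vocabulary, paired with their string representations.
    ///
    /// A minimal sketch of dumping the vocabulary, assuming a `model` loaded as in
    /// [`LlamaModel::load_from_file`] and that this crate is built as `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn dump_vocab(model: &llama_cpp_2::model::LlamaModel) {
    /// use llama_cpp_2::model::Special;
    ///
    /// for (token, text) in model.tokens(Special::Tokenize) {
    ///     // Each entry carries a Result because a token may not decode to valid UTF-8.
    ///     println!("{token:?} => {text:?}");
    /// }
    /// # }
    /// ```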
    pub fn tokens(
        &self,
        special: Special,
    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
        (0..self.n_vocab())
            .map(LlamaToken::new)
            .map(move |llama_token| (llama_token, self.token_to_str(llama_token, special)))
    }

    /// Get the beginning-of-sequence (BOS) token.
    #[must_use]
    pub fn token_bos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Get the end-of-sequence (EOS) token.
    #[must_use]
    pub fn token_eos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Get the newline token.
    #[must_use]
    pub fn token_nl(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Check if a token represents the end of generation (end of turn, end of sequence, etc.).
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
    }

    /// Get the decoder start token.
    #[must_use]
    pub fn decode_start_token(&self) -> LlamaToken {
        let token =
            unsafe { llama_cpp_sys_2::llama_model_decoder_start_token(self.model.as_ptr()) };
        LlamaToken(token)
    }

    /// Convert a single token to a string.
    ///
    /// # Errors
    ///
    /// See [`TokenToStringError`] for more information.
    pub fn token_to_str(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let bytes = self.token_to_bytes(token, special)?;
        Ok(String::from_utf8(bytes)?)
    }

    /// Convert a single token to bytes, first trying a small buffer and retrying with
    /// the exact required size if llama.cpp reports that the buffer was too small.
    ///
    /// # Errors
    ///
    /// See [`TokenToStringError`] for more information.
    ///
    /// # Panics
    ///
    /// Panics if the required buffer size reported by llama.cpp does not fit into a `usize`.
    pub fn token_to_bytes(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<Vec<u8>, TokenToStringError> {
        match self.token_to_bytes_with_size(token, 8, special, None) {
            // The error carries the negated required size, so retry with exactly that capacity.
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_bytes_with_size(
                token,
                (-i).try_into().expect("buffer size is positive"),
                special,
                None,
            ),
            x => x,
        }
    }

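    /// Convert a slice of tokens to a string by concatenating the bytes of each token
    /// and validating the result as UTF-8.
    ///
    /// # Errors
    ///
    /// See [`TokenToStringError`] for more information.
    ///
    /// A minimal round-trip sketch, assuming a loaded `model` and that this crate is
    /// built as `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn round_trip(model: &llama_cpp_2::model::LlamaModel) -> Result<(), Box<dyn std::error::Error>> {
    /// use llama_cpp_2::model::{AddBos, Special};
    ///
    /// let tokens = model.str_to_token("Hello, world!", AddBos::Never)?;
    /// let text = model.tokens_to_str(&tokens, Special::Tokenize)?;
    /// println!("decoded: {text}");
    /// # Ok(())
    /// # }
    /// ```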
    pub fn tokens_to_str(
        &self,
        tokens: &[LlamaToken],
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let mut builder: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
        for piece in tokens
            .iter()
            .copied()
            .map(|t| self.token_to_bytes(t, special))
        {
            builder.extend_from_slice(&piece?);
        }
        Ok(String::from_utf8(builder)?)
    }

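    /// Convert a string to a vector of tokens.
    ///
    /// # Errors
    ///
    /// Fails if `str` contains a null byte.
    ///
    /// # Panics
    ///
    /// Panics on internal integer conversions that cannot fail for reasonable inputs,
    /// e.g. if the tokenizer's size estimate does not fit into a `c_int`.
    ///
    /// A minimal sketch, assuming this crate is built as `llama_cpp_2`,
    /// `LlamaBackend::init()` succeeds, and a GGUF model exists at the placeholder path:
    ///
    /// ```no_run
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::params::LlamaModelParams;
    /// use llama_cpp_2::model::{AddBos, LlamaModel};
    ///
    /// let backend = LlamaBackend::init()?;
    /// let params = LlamaModelParams::default();
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model.gguf", &params)?;
    /// let tokens = model.str_to_token("Hello, world!", AddBos::Always)?;
    /// println!("{tokens:?}");
    /// # Ok(())
    /// # }
    /// ```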
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        let size = unsafe {
            llama_cpp_sys_2::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // A negative size means the estimate was too small; its absolute value is the
        // number of tokens required, so grow the buffer and tokenize again.
        let size = if size.is_negative() {
            buffer
                .reserve_exact(usize::try_from(-size).expect("usize can hold the required size"));
            unsafe {
                llama_cpp_sys_2::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and fits into a usize");

        // SAFETY: `llama_tokenize` wrote exactly `size` tokens, which never exceeds the
        // capacity we passed it.
        unsafe { buffer.set_len(size) };
        Ok(buffer)
    }

    /// Get the attributes of a given token.
    ///
    /// # Panics
    ///
    /// Panics if the returned attributes do not map to a known [`LlamaTokenAttrs`].
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }

    /// Convert a single token to a string, with a provided buffer size.
    ///
    /// Generally you should use [`LlamaModel::token_to_str`] instead, which picks the
    /// buffer size for you.
    ///
    /// # Errors
    ///
    /// See [`TokenToStringError`] for more information.
    pub fn token_to_str_with_size(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let bytes = self.token_to_bytes_with_size(token, buffer_size, special, None)?;
        Ok(String::from_utf8(bytes)?)
    }

    /// Convert a token to bytes, with a provided buffer size and optional stripping of
    /// up to `lstrip` leading spaces.
    ///
    /// Generally you should use [`LlamaModel::token_to_bytes`] instead, which picks the
    /// buffer size for you.
    ///
    /// # Errors
    ///
    /// - if the token type is unknown
    /// - if the result does not fit in `buffer_size` bytes, in which case the error
    ///   carries the negated required size
    ///
    /// # Panics
    ///
    /// Panics if `buffer_size` does not fit into a `c_int`, or if the size returned by
    /// llama.cpp does not fit into a `usize` (it always should).
    pub fn token_to_bytes_with_size(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: Special,
        lstrip: Option<NonZeroU16>,
    ) -> Result<Vec<u8>, TokenToStringError> {
        if token == self.token_nl() {
            return Ok(b"\n".to_vec());
        }

        // Unknown, byte, and unused tokens have no printable form, and the BOS/EOS
        // control tokens render as empty strings.
        let attrs = self.token_attr(token);
        if attrs.is_empty()
            || attrs
                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
            || (attrs.contains(LlamaTokenAttr::Control)
                && (token == self.token_bos() || token == self.token_eos()))
        {
            return Ok(Vec::new());
        }

        let special = match special {
            Special::Tokenize => true,
            Special::Plaintext => false,
        };

        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
        let len = string.as_bytes().len();
        let len = c_int::try_from(len).expect("length fits into c_int");
        let buf = string.into_raw();
        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
        let size = unsafe {
            llama_cpp_sys_2::llama_token_to_piece(
                self.vocab_ptr(),
                token.0,
                buf,
                len,
                lstrip,
                special,
            )
        };

        match size {
            0 => {
                // Reclaim the buffer before returning so it is not leaked.
                let _ = unsafe { CString::from_raw(buf) };
                Err(TokenToStringError::UnknownTokenType)
            }
            i if i.is_negative() => {
                let _ = unsafe { CString::from_raw(buf) };
                Err(TokenToStringError::InsufficientBufferSpace(i))
            }
            size => {
                // Reclaim ownership of the buffer and trim it to the bytes actually written.
                let string = unsafe { CString::from_raw(buf) };
                let mut bytes = string.into_bytes();
                let len = usize::try_from(size).expect("size is positive and fits into usize");
                bytes.truncate(len);
                Ok(bytes)
            }
        }
    }

    /// Get the number of tokens the model's vocabulary contains.
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
    }

    /// Get the kind of vocabulary this model uses.
    ///
    /// # Panics
    ///
    /// Panics if llama.cpp reports a vocabulary type unknown to this crate.
    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
        VocabType::try_from(vocab_type).expect("vocab type is valid")
    }

    /// Get the size of the model's embedding vectors.
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
    }

    /// Get the total size of all tensors in the model, in bytes.
    pub fn size(&self) -> u64 {
        unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
    }

    /// Get the total number of parameters in the model.
    pub fn n_params(&self) -> u64 {
        unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
    }

    /// Returns whether the model is recurrent.
    pub fn is_recurrent(&self) -> bool {
        unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
    }

    /// Get the number of layers in the model.
    ///
    /// # Panics
    ///
    /// Panics if the layer count is negative, which llama.cpp should never report.
    pub fn n_layer(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) })
            .expect("n_layer is non-negative")
    }

    /// Get the number of attention heads in the model.
    ///
    /// # Panics
    ///
    /// Panics if the head count is negative, which llama.cpp should never report.
    pub fn n_head(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) })
            .expect("n_head is non-negative")
    }

    /// Get the number of KV attention heads in the model.
    ///
    /// # Panics
    ///
    /// Panics if the head count is negative, which llama.cpp should never report.
    pub fn n_head_kv(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
            .expect("n_head_kv is non-negative")
    }

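    /// Get a metadata value as a string by key name.
    ///
    /// # Errors
    ///
    /// Returns an error if the key contains a null byte or the underlying C call fails.
    ///
    /// A minimal sketch reading a standard GGUF key, assuming a loaded `model` and that
    /// this crate is built as `llama_cpp_2` (not every model defines this key):
    ///
    /// ```no_run
    /// # fn print_arch(model: &llama_cpp_2::model::LlamaModel) -> Result<(), Box<dyn std::error::Error>> {
    /// let arch = model.meta_val_str("general.architecture")?;
    /// println!("architecture: {arch}");
    /// # Ok(())
    /// # }
    /// ```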
    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
        let key_cstring = CString::new(key)?;
        let key_ptr = key_cstring.as_ptr();

        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str(
                    self.model.as_ptr(),
                    key_ptr,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Get the number of metadata key/value pairs in the model.
    pub fn meta_count(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) }
    }

    /// Get a metadata key name by index.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying C call reports failure, e.g. for an
    /// out-of-range index.
    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_key_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Get a metadata value as a string by index.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying C call reports failure, e.g. for an
    /// out-of-range index.
    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Get the model's RoPE (rotary positional embedding) type, or `None` for models
    /// that do not use RoPE.
    pub fn rope_type(&self) -> Option<RopeType> {
        match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
            rope_type => {
                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
                None
            }
        }
    }

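    /// Get the chat template stored in the model's metadata: the default one for a
    /// `name` of `None`, or a named variant otherwise.
    ///
    /// # Errors
    ///
    /// * If the model has no chat template matching the request.
    /// * If the chat template contains a null byte.
    ///
    /// A minimal sketch, assuming a loaded `model` and that this crate is built as
    /// `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn show_template(model: &llama_cpp_2::model::LlamaModel) -> Result<(), Box<dyn std::error::Error>> {
    /// let template = model.chat_template(None)?;
    /// println!("{}", template.to_str()?);
    /// # Ok(())
    /// # }
    /// ```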
    pub fn chat_template(
        &self,
        name: Option<&str>,
    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
        // Borrow the CString rather than moving it into the match, so the pointer
        // stays valid for the duration of the FFI call.
        let name_cstr = name.map(CString::new);
        let name_ptr = match &name_cstr {
            Some(Ok(name)) => name.as_ptr(),
            _ => std::ptr::null(),
        };
        let result =
            unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) };

        // A null pointer means the model has no chat template matching the request.
        if result.is_null() {
            Err(ChatTemplateError::MissingTemplate)
        } else {
            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
            let chat_template = CString::new(chat_template_cstr.to_bytes())?;
            Ok(LlamaChatTemplate(chat_template))
        }
    }

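    /// Load a model from a GGUF file on disk.
    ///
    /// # Errors
    ///
    /// * If the path cannot be represented as a `str` or contains a null byte.
    /// * If llama.cpp fails to load the model (returns a null pointer).
    ///
    /// A minimal sketch, assuming this crate is built as `llama_cpp_2` and the model
    /// path is a placeholder:
    ///
    /// ```no_run
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// use llama_cpp_2::llama_backend::LlamaBackend;
    /// use llama_cpp_2::model::params::LlamaModelParams;
    /// use llama_cpp_2::model::LlamaModel;
    ///
    /// let backend = LlamaBackend::init()?;
    /// let params = LlamaModelParams::default();
    /// let model = LlamaModel::load_from_file(&backend, "path/to/model.gguf", &params)?;
    /// println!("loaded {} parameters", model.n_params());
    /// # Ok(())
    /// # }
    /// ```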
    #[tracing::instrument(skip_all, fields(params))]
    pub fn load_from_file(
        _: &LlamaBackend,
        path: impl AsRef<Path>,
        params: &LlamaModelParams,
    ) -> Result<Self, LlamaModelLoadError> {
        let path = path.as_ref();
        debug_assert!(path.exists(), "{path:?} does not exist");
        let path = path
            .to_str()
            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        let llama_model =
            unsafe { llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) };

        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;

        tracing::debug!(?path, "Loaded model");
        Ok(LlamaModel { model })
    }

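    /// Initialize a LoRA adapter from a GGUF file so it can later be applied to a context.
    ///
    /// # Errors
    ///
    /// See [`LlamaLoraAdapterInitError`] for more information.
    ///
    /// A minimal sketch, assuming a loaded `model`, a LoRA adapter file at the
    /// placeholder path, and that this crate is built as `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn init_adapter(model: &llama_cpp_2::model::LlamaModel) -> Result<(), Box<dyn std::error::Error>> {
    /// let _adapter = model.lora_adapter_init("path/to/adapter.gguf")?;
    /// # Ok(())
    /// # }
    /// ```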
    pub fn lora_adapter_init(
        &self,
        path: impl AsRef<Path>,
    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
        let path = path.as_ref();
        debug_assert!(path.exists(), "{path:?} does not exist");

        let path = path
            .to_str()
            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
                path.to_path_buf(),
            ))?;

        let cstr = CString::new(path)?;
        let adapter =
            unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };

        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;

        tracing::debug!(?path, "Initialized LoRA adapter");
        Ok(LlamaLoraAdapter {
            lora_adapter: adapter,
        })
    }

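    /// Create a new context for this model.
    ///
    /// # Errors
    ///
    /// Fails if llama.cpp cannot allocate the context (returns a null pointer).
    ///
    /// A minimal sketch, assuming a `backend`, a loaded `model`, and that this crate is
    /// built as `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn make_context(backend: &llama_cpp_2::llama_backend::LlamaBackend, model: &llama_cpp_2::model::LlamaModel) -> Result<(), Box<dyn std::error::Error>> {
    /// use llama_cpp_2::context::params::LlamaContextParams;
    ///
    /// let _ctx = model.new_context(backend, LlamaContextParams::default())?;
    /// # Ok(())
    /// # }
    /// ```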
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context(
        &self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext, LlamaContextLoadError> {
        let context_params = params.context_params;
        let context = unsafe {
            llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
        };
        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;

        Ok(LlamaContext::new(self, context, params.embeddings()))
    }

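    /// Apply the model's chat template to a list of messages, producing a single prompt
    /// string. `add_ass` controls whether the template appends the opening of an
    /// assistant turn so the model continues from there.
    ///
    /// # Errors
    ///
    /// See [`ApplyChatTemplateError`] for more information.
    ///
    /// A minimal sketch, assuming a loaded `model` and that this crate is built as
    /// `llama_cpp_2`:
    ///
    /// ```no_run
    /// # fn build_prompt(model: &llama_cpp_2::model::LlamaModel) -> Result<(), Box<dyn std::error::Error>> {
    /// use llama_cpp_2::model::LlamaChatMessage;
    ///
    /// let template = model.chat_template(None)?;
    /// let messages = vec![
    ///     LlamaChatMessage::new("system".to_string(), "You are a helpful assistant.".to_string())?,
    ///     LlamaChatMessage::new("user".to_string(), "Hello!".to_string())?,
    /// ];
    /// let _prompt = model.apply_chat_template(&template, &messages, true)?;
    /// # Ok(())
    /// # }
    /// ```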
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: &LlamaChatTemplate,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        // Guess a buffer size: the template adds markup around each message, so start
        // with twice the total message length, as llama.cpp's own examples do.
        let message_length = chat.iter().fold(0, |acc, c| {
            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
        });
        let mut buff: Vec<u8> = vec![0; message_length * 2];

        // Build the FFI view of the messages; the CStrings in `chat` stay alive for the
        // whole call, so these raw pointers remain valid.
        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
            .iter()
            .map(|c| llama_cpp_sys_2::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tmpl_ptr = tmpl.0.as_ptr();

        let res = unsafe {
            llama_cpp_sys_2::llama_chat_apply_template(
                tmpl_ptr,
                chat.as_ptr(),
                chat.len(),
                add_ass,
                buff.as_mut_ptr().cast::<c_char>(),
                buff.len()
                    .try_into()
                    .expect("buffer size should fit into an i32"),
            )
        };

        // A return value larger than the buffer is the required size: grow and retry once.
        if res > buff
            .len()
            .try_into()
            .expect("buffer size should fit into an i32")
        {
            buff.resize(res.try_into().expect("res should be non-negative"), 0);

            let res = unsafe {
                llama_cpp_sys_2::llama_chat_apply_template(
                    tmpl_ptr,
                    chat.as_ptr(),
                    chat.len(),
                    add_ass,
                    buff.as_mut_ptr().cast::<c_char>(),
                    buff.len()
                        .try_into()
                        .expect("buffer size should fit into an i32"),
                )
            };
            assert_eq!(Ok(res), buff.len().try_into());
        }
        buff.truncate(res.try_into().expect("res should be non-negative"));
        Ok(String::from_utf8(buff)?)
    }
}

/// Generic helper for the llama.cpp metadata getters, which all follow the same
/// convention: write into a caller-provided buffer and return the full value length.
/// Calls `c_function` with a buffer of `capacity` bytes and retries once with the
/// exact required size if the value did not fit.
fn extract_meta_string<F>(c_function: F, capacity: usize) -> Result<String, MetaValError>
where
    F: Fn(*mut c_char, usize) -> i32,
{
    let mut buffer = vec![0u8; capacity];

    // A negative return signals failure; a non-negative return is the length of the
    // value, excluding the trailing null byte.
    let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
    if result < 0 {
        return Err(MetaValError::NegativeReturn(result));
    }

    // A length of `capacity` or more means the value was truncated; retry with enough room.
    let returned_len = usize::try_from(result).expect("result is non-negative");
    if returned_len >= capacity {
        return extract_meta_string(c_function, returned_len + 1);
    }

    debug_assert_eq!(
        buffer.get(returned_len),
        Some(&0),
        "should end with null byte"
    );

    buffer.truncate(returned_len);
    Ok(String::from_utf8(buffer)?)
}

impl Drop for LlamaModel {
    fn drop(&mut self) {
        unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
    }
}

/// The type of vocabulary a model uses.
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
    /// Byte Pair Encoding
    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
    /// SentencePiece
    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
}

/// There was an error converting a `llama_vocab_type` to a `VocabType`.
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
    /// The value was not a recognized vocabulary type.
    #[error("Unknown Value {0}")]
    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
}

impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
    type Error = LlamaTokenTypeFromIntError;

    fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
        match value {
            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
        }
    }
}