1use std::ffi::{c_char, CStr, CString};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6use std::ptr::{self, NonNull};
7use std::slice;
8use std::str::Utf8Error;
9
10use crate::context::params::LlamaContextParams;
11use crate::context::LlamaContext;
12use crate::llama_backend::LlamaBackend;
13use crate::model::params::LlamaModelParams;
14use crate::openai::{ChatParseStateOaicompat, OpenAIChatTemplateParams};
15use crate::token::LlamaToken;
16use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
17use crate::{
18 status_is_ok, status_to_i32, ApplyChatTemplateError, ChatParseError, ChatTemplateError,
19 LlamaContextLoadError, LlamaLoraAdapterInitError, LlamaModelLoadError, MetaValError,
20 NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
21};
22
23pub mod params;
24
/// A safe wrapper around a `llama.cpp` model handle.
///
/// Owns the underlying C pointer; it is released via `llama_free_model`
/// in this type's `Drop` implementation.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    // Non-null, owned pointer to the C-side model; freed on drop.
    pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
}
32
/// A wrapper around a `llama.cpp` LoRA adapter handle, created by
/// [`LlamaModel::lora_adapter_init`].
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    // Non-null pointer to the C-side adapter.
    // NOTE(review): no `Drop` impl is visible in this file — confirm the
    // adapter's lifetime is managed elsewhere in the crate.
    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}
40
/// A chat template stored as a `CString` so it can be handed directly to the
/// `llama.cpp` C API without re-allocation.
#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);
47
48impl LlamaChatTemplate {
49 pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
52 Ok(Self(CString::new(template)?))
53 }
54
55 pub fn as_c_str(&self) -> &CStr {
57 &self.0
58 }
59
60 pub fn to_str(&self) -> Result<&str, Utf8Error> {
62 self.0.to_str()
63 }
64
65 pub fn to_string(&self) -> Result<String, Utf8Error> {
67 self.to_str().map(str::to_string)
68 }
69}
70
71impl std::fmt::Debug for LlamaChatTemplate {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 self.0.fmt(f)
74 }
75}
76
/// A single chat message, stored as `CString`s ready to pass over FFI.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
    // The message role (NUL-free by construction in `new`).
    role: CString,
    // The message body (NUL-free by construction in `new`).
    content: CString,
}
83
84impl LlamaChatMessage {
85 pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
90 Ok(Self {
91 role: CString::new(role)?,
92 content: CString::new(content)?,
93 })
94 }
95}
96
/// Kind of a lazy-grammar trigger, mirroring the integer `type_` values
/// produced by the `llama_rs_*` chat-template FFI (see the parsing in
/// `apply_chat_template_with_tools_oaicompat`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GrammarTriggerType {
    /// Triggered by a specific token id.
    Token = 0,
    /// Triggered by a literal word.
    Word = 1,
    /// Triggered by a pattern match.
    Pattern = 2,
    /// Triggered by a full-input pattern match.
    PatternFull = 3,
}
109
/// A single grammar trigger parsed out of a chat-template FFI result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GrammarTrigger {
    /// What kind of trigger this is.
    pub trigger_type: GrammarTriggerType,
    /// The trigger's textual value (word or pattern).
    pub value: String,
    /// Set only when `trigger_type` is [`GrammarTriggerType::Token`].
    pub token: Option<LlamaToken>,
}
120
/// Everything produced by rendering a chat template via the OpenAI-compatible
/// FFI helpers: the prompt itself plus grammar/parser metadata used later for
/// constrained sampling and response parsing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatTemplateResult {
    /// The fully rendered prompt text.
    pub prompt: String,
    /// Optional grammar to constrain generation.
    pub grammar: Option<String>,
    /// Whether the grammar only activates once a trigger fires.
    pub grammar_lazy: bool,
    /// Triggers that activate a lazy grammar.
    pub grammar_triggers: Vec<GrammarTrigger>,
    /// Token strings that must survive detokenization untouched.
    pub preserved_tokens: Vec<String>,
    /// Extra stop strings requested by the template.
    pub additional_stops: Vec<String>,
    /// Opaque chat-format id, passed back to the C-side parse functions.
    pub chat_format: i32,
    /// Optional parser definition used by `parse_response_oaicompat`.
    pub parser: Option<String>,
    /// The generation-prompt suffix, if any (empty string when absent).
    pub generation_prompt: String,
    /// Whether responses should be parsed for tool calls.
    pub parse_tool_calls: bool,
}
145
/// RoPE (rotary position embedding) flavor reported by
/// [`LlamaModel::rope_type`]; mirrors the `LLAMA_ROPE_TYPE_*` constants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeType {
    Norm,
    NeoX,
    MRope,
    Vision,
}
154
/// Whether [`LlamaModel::str_to_token`] should prepend a BOS token
/// (mapped to the boolean `add_bos` argument of `llama_tokenize`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Prepend a BOS token.
    Always,
    /// Do not prepend a BOS token.
    Never,
}
163
/// Legacy control for whether special tokens are rendered when converting
/// tokens to text (mapped to the boolean `special` argument of the
/// `token_to_piece*` methods).
#[deprecated(
    since = "0.1.0",
    note = "This enum is a mixture of options for llama cpp providing less flexibility it only used with deprecated methods and will be removed in the future."
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Render special tokens (`special = true`).
    Tokenize,
    /// Skip special tokens (`special = false`).
    Plaintext,
}
176
// SAFETY: NOTE(review): assumes the llama.cpp model handle may be moved
// across threads — confirm against upstream thread-safety guarantees.
unsafe impl Send for LlamaModel {}

// SAFETY: NOTE(review): assumes concurrent `&self` access to the model is
// read-only on the C side — confirm against upstream thread-safety guarantees.
unsafe impl Sync for LlamaModel {}
180
181impl LlamaModel {
    /// Returns the raw vocabulary pointer for this model, used by the
    /// token-related FFI calls below.
    pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
        // SAFETY: `self.model` is non-null and valid for the lifetime of `self`.
        unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
    }
185
186 #[must_use]
193 pub fn n_ctx_train(&self) -> u32 {
194 let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
195 u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
196 }
197
198 pub fn tokens(
200 &self,
201 decode_special: bool,
202 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
203 (0..self.n_vocab())
204 .map(LlamaToken::new)
205 .map(move |llama_token| {
206 let mut decoder = encoding_rs::UTF_8.new_decoder();
207 (
208 llama_token,
209 self.token_to_piece(llama_token, &mut decoder, decode_special, None),
210 )
211 })
212 }
213
214 #[must_use]
216 pub fn token_bos(&self) -> LlamaToken {
217 let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
218 LlamaToken(token)
219 }
220
221 #[must_use]
223 pub fn token_eos(&self) -> LlamaToken {
224 let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
225 LlamaToken(token)
226 }
227
228 #[must_use]
230 pub fn token_nl(&self) -> LlamaToken {
231 let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
232 LlamaToken(token)
233 }
234
    /// Returns `true` if `token` is an end-of-generation token for this model.
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        // SAFETY: the vocab pointer is valid for the lifetime of `self`.
        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
    }
240
241 #[must_use]
243 pub fn decode_start_token(&self) -> LlamaToken {
244 let token =
245 unsafe { llama_cpp_sys_2::llama_model_decoder_start_token(self.model.as_ptr()) };
246 LlamaToken(token)
247 }
248
249 #[must_use]
251 pub fn token_sep(&self) -> LlamaToken {
252 let token = unsafe { llama_cpp_sys_2::llama_vocab_sep(self.vocab_ptr()) };
253 LlamaToken(token)
254 }
255
256 #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
262 pub fn token_to_str(
263 &self,
264 token: LlamaToken,
265 special: Special,
266 ) -> Result<String, TokenToStringError> {
267 let mut decoder = encoding_rs::UTF_8.new_decoder();
269 Ok(self.token_to_piece(
270 token,
271 &mut decoder,
272 matches!(special, Special::Tokenize),
273 None,
274 )?)
275 }
276
    /// Converts a token to its raw bytes, first trying a small 8-byte buffer
    /// and retrying with the exact size reported by llama.cpp if that was
    /// too small.
    ///
    /// # Errors
    /// Returns [`TokenToStringError`] if the conversion fails.
    #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
    pub fn token_to_bytes(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<Vec<u8>, TokenToStringError> {
        match self.token_to_piece_bytes(token, 8, matches!(special, Special::Tokenize), None) {
            // A negative size `i` means "buffer too small; -i bytes needed".
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
                token,
                (-i).try_into().expect("Error buffer size is positive"),
                matches!(special, Special::Tokenize),
                None,
            ),
            x => x,
        }
    }
303
304 #[deprecated(
310 since = "0.1.0",
311 note = "Use `token_to_piece` for each token individually instead"
312 )]
313 pub fn tokens_to_str(
314 &self,
315 tokens: &[LlamaToken],
316 special: Special,
317 ) -> Result<String, TokenToStringError> {
318 let mut builder: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
319 for piece in tokens
320 .iter()
321 .copied()
322 .map(|t| self.token_to_piece_bytes(t, 8, matches!(special, Special::Tokenize), None))
323 {
324 builder.extend_from_slice(&piece?);
325 }
326 Ok(String::from_utf8(builder)?)
327 }
328
    /// Tokenizes `str` into model tokens via `llama_tokenize`.
    ///
    /// `add_bos` controls whether a BOS token is prepended.
    ///
    /// # Errors
    /// Returns [`StringToTokenError`] if `str` contains an interior NUL byte
    /// or its length does not fit into a `c_int`.
    ///
    /// # Panics
    /// Panics if internal size conversions fail (only possible for inputs
    /// whose sizes exceed the platform's integer limits).
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Heuristic initial capacity: roughly one token per two bytes of
        // input, with room for the optional BOS; at least 8.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        // SAFETY: `buffer` has capacity for `buffer_capacity` tokens and the
        // C string outlives the call. `LlamaToken` is a transparent wrapper
        // over `llama_token`, so the pointer cast is layout-compatible.
        let size = unsafe {
            llama_cpp_sys_2::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // A negative size means the buffer was too small and `-size` tokens
        // are required: grow exactly that much and tokenize again.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
            // SAFETY: same invariants as above, now with capacity >= -size.
            unsafe {
                llama_cpp_sys_2::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and usize ");

        // SAFETY: llama_tokenize initialized the first `size` elements.
        unsafe { buffer.set_len(size) }
        Ok(buffer)
    }
405
    /// Returns the attribute flags for a token id.
    ///
    /// # Panics
    /// Panics if llama.cpp reports an attribute value that cannot be
    /// represented by [`LlamaTokenAttrs`].
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        // SAFETY: the vocab pointer is valid for the lifetime of `self`.
        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }
416
    /// Converts a token to a `String` piece using the caller-supplied
    /// streaming `decoder` (so partial UTF-8 sequences can carry over between
    /// consecutive calls).
    ///
    /// `lstrip` is forwarded to llama.cpp as the number of leading spaces to
    /// strip from the piece.
    ///
    /// # Errors
    /// Returns [`TokenToStringError`] if the byte conversion fails.
    pub fn token_to_piece(
        &self,
        token: LlamaToken,
        decoder: &mut encoding_rs::Decoder,
        special: bool,
        lstrip: Option<NonZeroU16>,
    ) -> Result<String, TokenToStringError> {
        // Try a small 8-byte buffer first; on InsufficientBufferSpace the
        // error carries the negated required size, so retry with exactly that.
        let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
                token,
                (-i).try_into().expect("Error buffer size is positive"),
                special,
                lstrip,
            ),
            x => x,
        }?;
        let mut output_piece = String::with_capacity(bytes.len());
        // `false` = not the last chunk: the decoder keeps incomplete UTF-8
        // state for the next call.
        let (_result, _somesize, _truthy) =
            decoder.decode_to_string(&bytes, &mut output_piece, false);
        Ok(output_piece)
    }
460
461 pub fn token_to_piece_bytes(
477 &self,
478 token: LlamaToken,
479 buffer_size: usize,
480 special: bool,
481 lstrip: Option<NonZeroU16>,
482 ) -> Result<Vec<u8>, TokenToStringError> {
483 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
484 let len = string.as_bytes().len();
485 let len = c_int::try_from(len).expect("length fits into c_int");
486 let buf = string.into_raw();
487 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
488 let size = unsafe {
489 llama_cpp_sys_2::llama_token_to_piece(
490 self.vocab_ptr(),
491 token.0,
492 buf,
493 len,
494 lstrip,
495 special,
496 )
497 };
498
499 match size {
500 0 => Err(TokenToStringError::UnknownTokenType),
501 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
502 size => {
503 let string = unsafe { CString::from_raw(buf) };
504 let mut bytes = string.into_bytes();
505 let len = usize::try_from(size).expect("size is positive and fits into usize");
506 bytes.truncate(len);
507 Ok(bytes)
508 }
509 }
510 }
511
512 #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
528 pub fn token_to_str_with_size(
529 &self,
530 token: LlamaToken,
531 buffer_size: usize,
532 special: Special,
533 ) -> Result<String, TokenToStringError> {
534 let bytes = self.token_to_piece_bytes(
535 token,
536 buffer_size,
537 matches!(special, Special::Tokenize),
538 None,
539 )?;
540 Ok(String::from_utf8(bytes)?)
541 }
542
543 #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
558 pub fn token_to_bytes_with_size(
559 &self,
560 token: LlamaToken,
561 buffer_size: usize,
562 special: Special,
563 lstrip: Option<NonZeroU16>,
564 ) -> Result<Vec<u8>, TokenToStringError> {
565 if token == self.token_nl() {
566 return Ok(b"\n".to_vec());
567 }
568
569 let attrs = self.token_attr(token);
571 if attrs.is_empty()
572 || attrs
573 .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
574 || attrs.contains(LlamaTokenAttr::Control)
575 && (token == self.token_bos() || token == self.token_eos())
576 {
577 return Ok(Vec::new());
578 }
579
580 let special = match special {
581 Special::Tokenize => true,
582 Special::Plaintext => false,
583 };
584
585 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
586 let len = string.as_bytes().len();
587 let len = c_int::try_from(len).expect("length fits into c_int");
588 let buf = string.into_raw();
589 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
590 let size = unsafe {
591 llama_cpp_sys_2::llama_token_to_piece(
592 self.vocab_ptr(),
593 token.0,
594 buf,
595 len,
596 lstrip,
597 special,
598 )
599 };
600
601 match size {
602 0 => Err(TokenToStringError::UnknownTokenType),
603 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
604 size => {
605 let string = unsafe { CString::from_raw(buf) };
606 let mut bytes = string.into_bytes();
607 let len = usize::try_from(size).expect("size is positive and fits into usize");
608 bytes.truncate(len);
609 Ok(bytes)
610 }
611 }
612 }
    /// Returns the number of tokens in the model's vocabulary.
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        // SAFETY: the vocab pointer is valid for the lifetime of `self`.
        unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
    }
621
    /// Returns the vocabulary type of the model.
    ///
    /// # Panics
    /// Panics if llama.cpp reports a vocab type this wrapper does not know.
    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        // SAFETY: the vocab pointer is valid for the lifetime of `self`.
        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
        VocabType::try_from(vocab_type).expect("invalid vocab type")
    }
633
    /// Returns the embedding dimension of the model.
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        // SAFETY: the model pointer is valid for the lifetime of `self`.
        unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
    }
640
    /// Returns the model's size in bytes as reported by llama.cpp.
    pub fn size(&self) -> u64 {
        // SAFETY: the model pointer is valid for the lifetime of `self`.
        unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
    }
645
    /// Returns the number of parameters in the model.
    pub fn n_params(&self) -> u64 {
        // SAFETY: the model pointer is valid for the lifetime of `self`.
        unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
    }
650
    /// Returns `true` if llama.cpp reports the model as recurrent.
    pub fn is_recurrent(&self) -> bool {
        // SAFETY: the model pointer is valid for the lifetime of `self`.
        unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
    }
655
    /// Returns `true` if llama.cpp reports the model as hybrid.
    pub fn is_hybrid(&self) -> bool {
        // SAFETY: the model pointer is valid for the lifetime of `self`.
        unsafe { llama_cpp_sys_2::llama_model_is_hybrid(self.model.as_ptr()) }
    }
663
664 pub fn n_layer(&self) -> u32 {
666 u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
669 }
670
671 pub fn n_head(&self) -> u32 {
673 u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
676 }
677
678 pub fn n_head_kv(&self) -> u32 {
680 u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
683 .unwrap()
684 }
685
    /// Looks up a model metadata value by key.
    ///
    /// # Errors
    /// Returns [`MetaValError`] if `key` contains a NUL byte, the C call
    /// fails, or the value is not valid UTF-8.
    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
        // `key_cstring` outlives the closure's use of `key_ptr` below.
        let key_cstring = CString::new(key)?;
        let key_ptr = key_cstring.as_ptr();

        extract_meta_string(
            // SAFETY: the model pointer and `key_ptr` remain valid for every
            // invocation of this closure.
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str(
                    self.model.as_ptr(),
                    key_ptr,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }
703
    /// Returns the number of metadata entries in the model.
    pub fn meta_count(&self) -> i32 {
        // SAFETY: the model pointer is valid for the lifetime of `self`.
        unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) }
    }
708
    /// Returns the metadata key at `index`.
    ///
    /// # Errors
    /// Returns [`MetaValError`] if the C call fails (e.g. index out of range)
    /// or the key is not valid UTF-8.
    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            // SAFETY: the model pointer remains valid for every invocation.
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_key_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }
723
    /// Returns the metadata value at `index`.
    ///
    /// # Errors
    /// Returns [`MetaValError`] if the C call fails (e.g. index out of range)
    /// or the value is not valid UTF-8.
    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            // SAFETY: the model pointer remains valid for every invocation.
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }
738
    /// Returns the model's RoPE type, or `None` when the model uses no RoPE
    /// (or llama.cpp reports a value this wrapper does not recognize, which
    /// is logged as an error).
    pub fn rope_type(&self) -> Option<RopeType> {
        // SAFETY: the model pointer is valid for the lifetime of `self`.
        match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
            // Unknown value: log and degrade to None rather than panic.
            rope_type => {
                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
                None
            }
        }
    }
753
754 pub fn chat_template(
768 &self,
769 name: Option<&str>,
770 ) -> Result<LlamaChatTemplate, ChatTemplateError> {
771 let name_cstr = name.map(CString::new);
772 let name_ptr = match name_cstr {
773 Some(Ok(name)) => name.as_ptr(),
774 _ => std::ptr::null(),
775 };
776 let result =
777 unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) };
778
779 if result.is_null() {
781 Err(ChatTemplateError::MissingTemplate)
782 } else {
783 let chat_template_cstr = unsafe { CStr::from_ptr(result) };
784 let chat_template = CString::new(chat_template_cstr.to_bytes())?;
785 Ok(LlamaChatTemplate(chat_template))
786 }
787 }
788
    /// Loads a model from a GGUF file on disk.
    ///
    /// The [`LlamaBackend`] argument is unused at runtime; requiring it
    /// ensures the backend has been initialized before loading.
    ///
    /// # Errors
    /// Returns [`LlamaModelLoadError`] if the path is not valid UTF-8,
    /// contains a NUL byte, or llama.cpp fails to load the model.
    #[tracing::instrument(skip_all, fields(params))]
    pub fn load_from_file(
        _: &LlamaBackend,
        path: impl AsRef<Path>,
        params: &LlamaModelParams,
    ) -> Result<Self, LlamaModelLoadError> {
        let path = path.as_ref();
        debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
        let path = path
            .to_str()
            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        // SAFETY: `cstr` outlives the call and `params.params` is a valid
        // llama.cpp parameter struct.
        let llama_model =
            unsafe { llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) };

        // A null return indicates the load failed on the C side.
        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;

        tracing::debug!(?path, "Loaded model");
        Ok(LlamaModel { model })
    }
815
    /// Initializes a LoRA adapter for this model from a file on disk.
    ///
    /// # Errors
    /// Returns [`LlamaLoraAdapterInitError`] if the path is not valid UTF-8,
    /// contains a NUL byte, or llama.cpp fails to initialize the adapter.
    pub fn lora_adapter_init(
        &self,
        path: impl AsRef<Path>,
    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
        let path = path.as_ref();
        debug_assert!(Path::new(path).exists(), "{path:?} does not exist");

        let path = path
            .to_str()
            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
                path.to_path_buf(),
            ))?;

        let cstr = CString::new(path)?;
        // SAFETY: the model pointer and `cstr` are valid for the call.
        let adapter =
            unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };

        // A null return indicates initialization failed on the C side.
        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;

        tracing::debug!(?path, "Initialized lora adapter");
        Ok(LlamaLoraAdapter {
            lora_adapter: adapter,
        })
    }
845
    /// Creates a new inference context for this model.
    ///
    /// The [`LlamaBackend`] argument is unused at runtime; requiring it
    /// ensures the backend has been initialized first.
    ///
    /// # Errors
    /// Returns [`LlamaContextLoadError`] if llama.cpp returns a null context.
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context<'a>(
        &'a self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext<'a>, LlamaContextLoadError> {
        let context_params = params.context_params;
        // SAFETY: the model pointer is valid; llama.cpp copies what it needs
        // from `context_params`.
        let context = unsafe {
            llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
        };
        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;

        Ok(LlamaContext::new(self, context, params.embeddings()))
    }
866
    /// Renders `chat` into a prompt string using `llama_chat_apply_template`.
    ///
    /// `add_ass` asks the template to append the assistant turn opener.
    ///
    /// # Errors
    /// Returns [`ApplyChatTemplateError`] if the C call fails or the rendered
    /// bytes are not valid UTF-8.
    ///
    /// # Panics
    /// Panics if the buffer size or result length cannot be converted between
    /// `usize` and `i32`.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: &LlamaChatTemplate,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        // First-guess buffer: twice the total role+content byte length.
        let message_length = chat.iter().fold(0, |acc, c| {
            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
        });
        let mut buff: Vec<u8> = vec![0; message_length * 2];

        // Borrow role/content pointers; the CStrings in `chat` (the slice
        // parameter) stay alive for the whole call.
        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
            .iter()
            .map(|c| llama_cpp_sys_2::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tmpl_ptr = tmpl.0.as_ptr();

        // SAFETY: all pointers are valid for the duration of the call and
        // `buff` is writable for `buff.len()` bytes.
        let res = unsafe {
            llama_cpp_sys_2::llama_chat_apply_template(
                tmpl_ptr,
                chat.as_ptr(),
                chat.len(),
                add_ass,
                buff.as_mut_ptr().cast::<c_char>(),
                buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
            )
        };

        if res < 0 {
            return Err(ApplyChatTemplateError::FfiError(res));
        }

        // A result larger than the buffer means truncation: `res` is the
        // required length, so resize to exactly that and render again.
        if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
            buff.resize(res.try_into().expect("res is negative"), 0);

            // SAFETY: same invariants as above, with the enlarged buffer.
            let res = unsafe {
                llama_cpp_sys_2::llama_chat_apply_template(
                    tmpl_ptr,
                    chat.as_ptr(),
                    chat.len(),
                    add_ass,
                    buff.as_mut_ptr().cast::<c_char>(),
                    buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
                )
            };
            if res < 0 {
                return Err(ApplyChatTemplateError::FfiError(res));
            }
            // The retry must fill the buffer exactly; this also guarantees
            // the outer `res` (== buff.len()) used below is the true length.
            assert_eq!(Ok(res), buff.len().try_into());
        }
        // Note: this is the OUTER `res`; after a retry it equals buff.len().
        buff.truncate(res.try_into().expect("res is negative"));
        Ok(String::from_utf8(buff)?)
    }
944
    /// Renders a chat into a prompt using the OpenAI-compatible FFI helper,
    /// optionally constraining output with a tools JSON array or JSON schema.
    ///
    /// # Errors
    /// Returns [`ApplyChatTemplateError`] if any input contains a NUL byte,
    /// the C call fails, the result is null/inconsistent, or any returned
    /// string is not valid UTF-8.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_with_tools_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        messages: &[LlamaChatMessage],
        tools_json: Option<&str>,
        json_schema: Option<&str>,
        add_generation_prompt: bool,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        // Borrow role/content pointers; the CStrings in `messages` stay
        // alive for the duration of the FFI call.
        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = messages
            .iter()
            .map(|c| llama_cpp_sys_2::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tools_cstr = tools_json.map(CString::new).transpose()?;
        let json_schema_cstr = json_schema.map(CString::new).transpose()?;

        // Out-parameter struct; the C side fills in (and owns) the pointers,
        // which are released by the single free call at the bottom.
        let mut raw_result = llama_cpp_sys_2::llama_rs_chat_template_result {
            prompt: ptr::null_mut(),
            grammar: ptr::null_mut(),
            parser: ptr::null_mut(),
            generation_prompt: ptr::null_mut(),
            chat_format: 0,
            grammar_lazy: false,
            grammar_triggers: ptr::null_mut(),
            grammar_triggers_count: 0,
            preserved_tokens: ptr::null_mut(),
            preserved_tokens_count: 0,
            additional_stops: ptr::null_mut(),
            additional_stops_count: 0,
        };

        // SAFETY: every pointer passed is either null or backed by a value
        // that outlives the call.
        let rc = unsafe {
            llama_cpp_sys_2::llama_rs_apply_chat_template_with_tools_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                chat.as_ptr(),
                chat.len(),
                tools_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                json_schema_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                add_generation_prompt,
                &mut raw_result,
            )
        };

        // Parse inside an immediately-invoked closure so every early `Err`
        // still flows past the `llama_rs_chat_template_result_free` below.
        let result = (|| {
            if !status_is_ok(rc) {
                return Err(ApplyChatTemplateError::FfiError(status_to_i32(rc)));
            }
            if raw_result.prompt.is_null() {
                return Err(ApplyChatTemplateError::NullResult);
            }
            // Copy every C string out before the result struct is freed.
            let prompt_bytes = unsafe { CStr::from_ptr(raw_result.prompt) }
                .to_bytes()
                .to_vec();
            let prompt = String::from_utf8(prompt_bytes)?;
            let grammar_lazy = raw_result.grammar_lazy;
            let grammar = if raw_result.grammar.is_null() {
                None
            } else {
                let grammar_bytes = unsafe { CStr::from_ptr(raw_result.grammar) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(grammar_bytes)?)
            };
            let parser = if raw_result.parser.is_null() {
                None
            } else {
                let parser_bytes = unsafe { CStr::from_ptr(raw_result.parser) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(parser_bytes)?)
            };
            let generation_prompt = if raw_result.generation_prompt.is_null() {
                String::new()
            } else {
                let generation_prompt_bytes =
                    unsafe { CStr::from_ptr(raw_result.generation_prompt) }
                        .to_bytes()
                        .to_vec();
                String::from_utf8(generation_prompt_bytes)?
            };
            // NOTE(review): InvalidGrammarTriggerType doubles as a generic
            // "inconsistent result" error for null arrays/values below.
            let grammar_triggers = if raw_result.grammar_triggers_count == 0 {
                Vec::new()
            } else if raw_result.grammar_triggers.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                // SAFETY: non-null pointer with `grammar_triggers_count`
                // elements, per the checks above.
                let triggers = unsafe {
                    slice::from_raw_parts(
                        raw_result.grammar_triggers,
                        raw_result.grammar_triggers_count,
                    )
                };
                let mut parsed = Vec::with_capacity(triggers.len());
                for trigger in triggers {
                    let trigger_type = match trigger.type_ {
                        0 => GrammarTriggerType::Token,
                        1 => GrammarTriggerType::Word,
                        2 => GrammarTriggerType::Pattern,
                        3 => GrammarTriggerType::PatternFull,
                        _ => return Err(ApplyChatTemplateError::InvalidGrammarTriggerType),
                    };
                    let value = if trigger.value.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    } else {
                        let bytes = unsafe { CStr::from_ptr(trigger.value) }.to_bytes().to_vec();
                        String::from_utf8(bytes)?
                    };
                    // Only token triggers carry a meaningful token id.
                    let token = if trigger_type == GrammarTriggerType::Token {
                        Some(LlamaToken(trigger.token))
                    } else {
                        None
                    };
                    parsed.push(GrammarTrigger {
                        trigger_type,
                        value,
                        token,
                    });
                }
                parsed
            };
            let preserved_tokens = if raw_result.preserved_tokens_count == 0 {
                Vec::new()
            } else if raw_result.preserved_tokens.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                // SAFETY: non-null pointer with `preserved_tokens_count`
                // elements, per the checks above.
                let tokens = unsafe {
                    slice::from_raw_parts(
                        raw_result.preserved_tokens,
                        raw_result.preserved_tokens_count,
                    )
                };
                let mut parsed = Vec::with_capacity(tokens.len());
                for token in tokens {
                    if token.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*token) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            let additional_stops = if raw_result.additional_stops_count == 0 {
                Vec::new()
            } else if raw_result.additional_stops.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                // SAFETY: non-null pointer with `additional_stops_count`
                // elements, per the checks above.
                let stops = unsafe {
                    slice::from_raw_parts(
                        raw_result.additional_stops,
                        raw_result.additional_stops_count,
                    )
                };
                let mut parsed = Vec::with_capacity(stops.len());
                for stop in stops {
                    if stop.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*stop) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            // Tool calls are parsed only when a non-empty tools array was given.
            let parse_tool_calls = tools_json.map_or(false, |tools| !tools.is_empty());
            Ok(ChatTemplateResult {
                prompt,
                grammar,
                grammar_lazy,
                grammar_triggers,
                preserved_tokens,
                additional_stops,
                chat_format: raw_result.chat_format,
                parser,
                generation_prompt,
                parse_tool_calls,
            })
        })();

        // SAFETY: `raw_result` was filled by the FFI call above; freeing is
        // performed exactly once, after all strings were copied out.
        unsafe { llama_cpp_sys_2::llama_rs_chat_template_result_free(&mut raw_result) };
        result
    }
1136
    /// Renders an OpenAI-style chat request (messages/tools/options as JSON)
    /// into a prompt using the OpenAI-compatible FFI helper.
    ///
    /// # Errors
    /// Returns [`ApplyChatTemplateError`] if any input contains a NUL byte,
    /// the C call fails, the result is null/inconsistent, or any returned
    /// string is not valid UTF-8.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        params: &OpenAIChatTemplateParams<'_>,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        let parse_tool_calls = params.parse_tool_calls;
        // Convert all optional string params up front; a NUL byte anywhere
        // aborts early.
        let messages_cstr = CString::new(params.messages_json)?;
        let tools_cstr = params.tools_json.map(CString::new).transpose()?;
        let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
        let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
        let grammar_cstr = params.grammar.map(CString::new).transpose()?;
        let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
        let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;

        // Out-parameter struct; the C side fills in (and owns) the pointers,
        // which are released by the single free call at the bottom.
        let mut raw_result = llama_cpp_sys_2::llama_rs_chat_template_result {
            prompt: ptr::null_mut(),
            grammar: ptr::null_mut(),
            parser: ptr::null_mut(),
            generation_prompt: ptr::null_mut(),
            chat_format: 0,
            grammar_lazy: false,
            grammar_triggers: ptr::null_mut(),
            grammar_triggers_count: 0,
            preserved_tokens: ptr::null_mut(),
            preserved_tokens_count: 0,
            additional_stops: ptr::null_mut(),
            additional_stops_count: 0,
        };

        // All pointers below borrow from the CStrings above, which stay alive
        // until after the FFI call.
        let ffi_params = llama_cpp_sys_2::llama_rs_chat_template_oaicompat_params {
            messages: messages_cstr.as_ptr(),
            tools: tools_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            tool_choice: tool_choice_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            json_schema: json_schema_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            grammar: grammar_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            reasoning_format: reasoning_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            chat_template_kwargs: kwargs_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            add_generation_prompt: params.add_generation_prompt,
            use_jinja: params.use_jinja,
            parallel_tool_calls: params.parallel_tool_calls,
            enable_thinking: params.enable_thinking,
            add_bos: params.add_bos,
            add_eos: params.add_eos,
        };

        // SAFETY: every pointer in `ffi_params` is null or valid for the
        // duration of this call.
        let rc = unsafe {
            llama_cpp_sys_2::llama_rs_apply_chat_template_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                &ffi_params,
                &mut raw_result,
            )
        };

        // Parse inside an immediately-invoked closure so every early `Err`
        // still flows past the `llama_rs_chat_template_result_free` below.
        let result = (|| {
            if !status_is_ok(rc) {
                return Err(ApplyChatTemplateError::FfiError(status_to_i32(rc)));
            }
            if raw_result.prompt.is_null() {
                return Err(ApplyChatTemplateError::NullResult);
            }
            // Copy every C string out before the result struct is freed.
            let prompt_bytes = unsafe { CStr::from_ptr(raw_result.prompt) }
                .to_bytes()
                .to_vec();
            let prompt = String::from_utf8(prompt_bytes)?;
            let grammar_lazy = raw_result.grammar_lazy;
            let grammar = if raw_result.grammar.is_null() {
                None
            } else {
                let grammar_bytes = unsafe { CStr::from_ptr(raw_result.grammar) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(grammar_bytes)?)
            };
            let parser = if raw_result.parser.is_null() {
                None
            } else {
                let parser_bytes = unsafe { CStr::from_ptr(raw_result.parser) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(parser_bytes)?)
            };
            let generation_prompt = if raw_result.generation_prompt.is_null() {
                String::new()
            } else {
                let generation_prompt_bytes =
                    unsafe { CStr::from_ptr(raw_result.generation_prompt) }
                        .to_bytes()
                        .to_vec();
                String::from_utf8(generation_prompt_bytes)?
            };
            // NOTE(review): unlike `apply_chat_template_with_tools_oaicompat`,
            // a null trigger value here maps to an empty string rather than
            // an error — confirm this asymmetry is intentional.
            let grammar_triggers = if raw_result.grammar_triggers_count == 0 {
                Vec::new()
            } else if raw_result.grammar_triggers.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                // SAFETY: non-null pointer with `grammar_triggers_count`
                // elements, per the checks above.
                let triggers = unsafe {
                    slice::from_raw_parts(
                        raw_result.grammar_triggers,
                        raw_result.grammar_triggers_count,
                    )
                };
                let mut parsed = Vec::with_capacity(triggers.len());
                for trigger in triggers {
                    let trigger_type = match trigger.type_ {
                        0 => GrammarTriggerType::Token,
                        1 => GrammarTriggerType::Word,
                        2 => GrammarTriggerType::Pattern,
                        3 => GrammarTriggerType::PatternFull,
                        _ => return Err(ApplyChatTemplateError::InvalidGrammarTriggerType),
                    };
                    let value = if trigger.value.is_null() {
                        String::new()
                    } else {
                        let bytes = unsafe { CStr::from_ptr(trigger.value) }.to_bytes().to_vec();
                        String::from_utf8(bytes)?
                    };
                    // Only token triggers carry a meaningful token id.
                    let token = if trigger_type == GrammarTriggerType::Token {
                        Some(LlamaToken(trigger.token))
                    } else {
                        None
                    };
                    parsed.push(GrammarTrigger {
                        trigger_type,
                        value,
                        token,
                    });
                }
                parsed
            };
            let preserved_tokens = if raw_result.preserved_tokens_count == 0 {
                Vec::new()
            } else if raw_result.preserved_tokens.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                // SAFETY: non-null pointer with `preserved_tokens_count`
                // elements, per the checks above.
                let tokens = unsafe {
                    slice::from_raw_parts(
                        raw_result.preserved_tokens,
                        raw_result.preserved_tokens_count,
                    )
                };
                let mut parsed = Vec::with_capacity(tokens.len());
                for token in tokens {
                    if token.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*token) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            let additional_stops = if raw_result.additional_stops_count == 0 {
                Vec::new()
            } else if raw_result.additional_stops.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                // SAFETY: non-null pointer with `additional_stops_count`
                // elements, per the checks above.
                let stops = unsafe {
                    slice::from_raw_parts(
                        raw_result.additional_stops,
                        raw_result.additional_stops_count,
                    )
                };
                let mut parsed = Vec::with_capacity(stops.len());
                for stop in stops {
                    if stop.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*stop) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };

            Ok(ChatTemplateResult {
                prompt,
                grammar,
                grammar_lazy,
                grammar_triggers,
                preserved_tokens,
                additional_stops,
                chat_format: raw_result.chat_format,
                parser,
                generation_prompt,
                parse_tool_calls,
            })
        })();

        // SAFETY: `raw_result` was filled by the FFI call above; freeing is
        // performed exactly once, after all strings were copied out.
        unsafe { llama_cpp_sys_2::llama_rs_chat_template_result_free(&mut raw_result) };
        result
    }
1341}
1342
1343impl ChatTemplateResult {
    /// Parses generated `text` into an OpenAI-compatible response JSON string
    /// using this result's chat format and parser metadata.
    ///
    /// Set `is_partial` when `text` is an incomplete (streaming) response.
    ///
    /// # Errors
    /// Returns [`ChatParseError`] if any input contains a NUL byte, the C
    /// call fails, the output is null, or the JSON is not valid UTF-8.
    pub fn parse_response_oaicompat(
        &self,
        text: &str,
        is_partial: bool,
    ) -> Result<String, ChatParseError> {
        let text_cstr = CString::new(text)?;
        let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
        // An empty generation prompt is passed as null rather than "".
        let generation_prompt_cstr = if self.generation_prompt.is_empty() {
            None
        } else {
            Some(CString::new(self.generation_prompt.as_str())?)
        };
        // Out-parameter: the C side allocates the JSON string; freed below.
        let mut out_json: *mut c_char = ptr::null_mut();
        // SAFETY: every pointer is null or backed by a CString that outlives
        // the call.
        let rc = unsafe {
            llama_cpp_sys_2::llama_rs_chat_parse_to_oaicompat(
                text_cstr.as_ptr(),
                is_partial,
                self.chat_format,
                self.parse_tool_calls,
                parser_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                generation_prompt_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                &mut out_json,
            )
        };

        // Parse inside an immediately-invoked closure so early `Err`s still
        // flow past the free call below.
        let result = (|| {
            if !status_is_ok(rc) {
                return Err(ChatParseError::FfiError(status_to_i32(rc)));
            }
            if out_json.is_null() {
                return Err(ChatParseError::NullResult);
            }
            let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
            Ok(String::from_utf8(bytes)?)
        })();

        // SAFETY: `out_json` is either null or was allocated by the call
        // above; NOTE(review): assumes the free function tolerates null.
        unsafe { llama_cpp_sys_2::llama_rs_string_free(out_json) };
        result
    }
1388
    /// Creates a streaming parse state for incrementally parsing generated
    /// text into OpenAI-compatible chunks.
    ///
    /// # Errors
    /// Returns [`ChatParseError`] if the parser/generation-prompt strings
    /// contain NUL bytes or the C side returns a null state.
    pub fn streaming_state_oaicompat(&self) -> Result<ChatParseStateOaicompat, ChatParseError> {
        let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
        // An empty generation prompt is passed as null rather than "".
        let generation_prompt_cstr = if self.generation_prompt.is_empty() {
            None
        } else {
            Some(CString::new(self.generation_prompt.as_str())?)
        };
        // SAFETY: every pointer is null or backed by a CString that outlives
        // the call.
        let state = unsafe {
            llama_cpp_sys_2::llama_rs_chat_parse_state_init_oaicompat(
                self.chat_format,
                self.parse_tool_calls,
                parser_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                generation_prompt_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            )
        };
        let state = NonNull::new(state).ok_or(ChatParseError::NullResult)?;
        Ok(ChatParseStateOaicompat { state })
    }
1412}
1413
1414fn extract_meta_string<F>(c_function: F, capacity: usize) -> Result<String, MetaValError>
1420where
1421 F: Fn(*mut c_char, usize) -> i32,
1422{
1423 let mut buffer = vec![0u8; capacity];
1424
1425 let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
1427 if result < 0 {
1428 return Err(MetaValError::NegativeReturn(result));
1429 }
1430
1431 let returned_len = result as usize;
1433 if returned_len >= capacity {
1434 return extract_meta_string(c_function, returned_len + 1);
1436 }
1437
1438 debug_assert_eq!(
1440 buffer.get(returned_len),
1441 Some(&0),
1442 "should end with null byte"
1443 );
1444
1445 buffer.truncate(returned_len);
1447 Ok(String::from_utf8(buffer)?)
1448}
1449
impl Drop for LlamaModel {
    /// Releases the underlying C model.
    fn drop(&mut self) {
        // SAFETY: `self.model` is the owned, non-null pointer created at load
        // time and is freed exactly once here.
        unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
    }
}
1455
/// Vocabulary type of a model, mirroring llama.cpp's `LLAMA_VOCAB_TYPE_*`
/// constants (only the variants this wrapper supports).
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
    /// Byte-pair encoding vocabulary.
    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
    /// SentencePiece vocabulary.
    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
}
1465
/// Error converting a raw llama.cpp vocab-type value into [`VocabType`].
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
    /// The raw value did not match any supported vocab type.
    #[error("Unknown Value {0}")]
    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
}
1473
impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
    type Error = LlamaTokenTypeFromIntError;

    /// Maps a raw llama.cpp vocab-type constant to [`VocabType`], rejecting
    /// values this wrapper does not support.
    fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
        match value {
            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
        }
    }
}