1use std::ffi::{c_char, CStr, CString};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6use std::ptr::{self, NonNull};
7use std::slice;
8use std::str::Utf8Error;
9
10use crate::context::params::LlamaContextParams;
11use crate::context::LlamaContext;
12use crate::llama_backend::LlamaBackend;
13use crate::model::params::LlamaModelParams;
14use crate::openai::{ChatParseStateOaicompat, OpenAIChatTemplateParams};
15use crate::token::LlamaToken;
16use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
17use crate::{
18 status_is_ok, status_to_i32, ApplyChatTemplateError, ChatParseError, ChatTemplateError,
19 LlamaContextLoadError, LlamaLoraAdapterInitError, LlamaModelLoadError, MetaValError,
20 NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
21};
22
23pub mod params;
24
/// A safe wrapper around a raw llama.cpp model handle.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    // Non-null pointer to the model allocated by llama.cpp.
    pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
}
32
/// A safe wrapper around a raw llama.cpp LoRA adapter handle.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    // Non-null pointer to the adapter allocated by llama.cpp.
    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}
40
/// A chat template stored as a NUL-terminated C string, ready to pass over FFI.
#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);
47
48impl LlamaChatTemplate {
49 pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
52 Ok(Self(CString::new(template)?))
53 }
54
55 pub fn as_c_str(&self) -> &CStr {
57 &self.0
58 }
59
60 pub fn to_str(&self) -> Result<&str, Utf8Error> {
62 self.0.to_str()
63 }
64
65 pub fn to_string(&self) -> Result<String, Utf8Error> {
67 self.to_str().map(str::to_string)
68 }
69}
70
// Forward `Debug` to the inner `CString` so the template prints as its raw bytes.
impl std::fmt::Debug for LlamaChatTemplate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}
76
/// A single chat message: a role and its content.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
    // Stored as C strings so they can be handed to llama.cpp without copying.
    role: CString,
    content: CString,
}
83
84impl LlamaChatMessage {
85 pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
90 Ok(Self {
91 role: CString::new(role)?,
92 content: CString::new(content)?,
93 })
94 }
95}
96
/// The kind of grammar trigger; discriminants mirror the integer values
/// produced by the FFI chat-template result parsing below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GrammarTriggerType {
    /// Triggered by a specific token id.
    Token = 0,
    /// Triggered by a literal word.
    Word = 1,
    /// Triggered by a pattern.
    Pattern = 2,
    /// Triggered by a pattern matched against the full text.
    PatternFull = 3,
}
109
/// A lazy-grammar trigger parsed from a chat-template result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GrammarTrigger {
    /// Which kind of trigger this is.
    pub trigger_type: GrammarTriggerType,
    /// The trigger's text value.
    pub value: String,
    /// The trigger token; `Some` only when `trigger_type` is `Token`.
    pub token: Option<LlamaToken>,
}
120
/// The result of applying a chat template, including the rendered prompt and
/// any grammar/parsing metadata returned by llama.cpp.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatTemplateResult {
    /// The rendered prompt text.
    pub prompt: String,
    /// Optional grammar constraining generation.
    pub grammar: Option<String>,
    /// Whether the grammar is applied lazily (activated by triggers).
    pub grammar_lazy: bool,
    /// Triggers that activate a lazy grammar.
    pub grammar_triggers: Vec<GrammarTrigger>,
    /// Tokens that must be preserved verbatim during parsing.
    pub preserved_tokens: Vec<String>,
    /// Extra stop strings requested by the template.
    pub additional_stops: Vec<String>,
    /// llama.cpp chat-format identifier, passed back when parsing responses.
    pub chat_format: i32,
    /// Optional parser description used when parsing responses.
    pub parser: Option<String>,
    /// The generation prompt suffix, if any (empty string when absent).
    pub generation_prompt: String,
    /// Whether the template left a "thinking" section forced open.
    pub thinking_forced_open: bool,
    /// Whether tool calls should be parsed out of the response.
    pub parse_tool_calls: bool,
}
147
/// The RoPE (rotary position embedding) variant reported by the model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeType {
    Norm,
    NeoX,
    MRope,
    Vision,
}
156
/// Whether to prepend a beginning-of-sequence token when tokenizing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Always prepend the BOS token.
    Always,
    /// Never prepend the BOS token.
    Never,
}
165
/// Legacy switch controlling whether special tokens are rendered as text.
#[deprecated(
    since = "0.1.0",
    note = "This enum is a mixture of options for llama cpp providing less flexibility it only used with deprecated methods and will be removed in the future."
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Render special tokens as their text representation.
    Tokenize,
    /// Do not render special tokens.
    Plaintext,
}
178
// SAFETY: NOTE(review): these impls assert the model handle may be moved to and
// shared across threads. That relies on llama.cpp not mutating the model through
// the read-only calls used here — confirm against upstream thread-safety notes.
unsafe impl Send for LlamaModel {}

unsafe impl Sync for LlamaModel {}
182
183impl LlamaModel {
    /// Returns the raw vocabulary pointer for use in FFI calls.
    pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
        // SAFETY: `self.model` is a valid, non-null model pointer.
        unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
    }
187
188 #[must_use]
195 pub fn n_ctx_train(&self) -> u32 {
196 let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
197 u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
198 }
199
200 pub fn tokens(
202 &self,
203 decode_special: bool,
204 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
205 (0..self.n_vocab())
206 .map(LlamaToken::new)
207 .map(move |llama_token| {
208 let mut decoder = encoding_rs::UTF_8.new_decoder();
209 (
210 llama_token,
211 self.token_to_piece(llama_token, &mut decoder, decode_special, None),
212 )
213 })
214 }
215
    /// Returns the beginning-of-sequence token.
    #[must_use]
    pub fn token_bos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
        LlamaToken(token)
    }
222
    /// Returns the end-of-sequence token.
    #[must_use]
    pub fn token_eos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
        LlamaToken(token)
    }
229
    /// Returns the newline token.
    #[must_use]
    pub fn token_nl(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
        LlamaToken(token)
    }
236
    /// Returns `true` if `token` signifies the end of generation.
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
    }
242
    /// Returns the decoder start token.
    #[must_use]
    pub fn decode_start_token(&self) -> LlamaToken {
        let token =
            unsafe { llama_cpp_sys_2::llama_model_decoder_start_token(self.model.as_ptr()) };
        LlamaToken(token)
    }
250
    /// Returns the separator token.
    #[must_use]
    pub fn token_sep(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_vocab_sep(self.vocab_ptr()) };
        LlamaToken(token)
    }
257
258 #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
264 pub fn token_to_str(
265 &self,
266 token: LlamaToken,
267 special: Special,
268 ) -> Result<String, TokenToStringError> {
269 let mut decoder = encoding_rs::UTF_8.new_decoder();
271 Ok(self.token_to_piece(
272 token,
273 &mut decoder,
274 matches!(special, Special::Tokenize),
275 None,
276 )?)
277 }
278
    /// Converts a single token to raw bytes.
    ///
    /// Tries with a small 8-byte buffer first, then retries once with the
    /// exact size reported by the failed attempt.
    #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
    pub fn token_to_bytes(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<Vec<u8>, TokenToStringError> {
        match self.token_to_piece_bytes(token, 8, matches!(special, Special::Tokenize), None) {
            // The error carries the negated required size; retry with exactly that.
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
                token,
                (-i).try_into().expect("Error buffer size is positive"),
                matches!(special, Special::Tokenize),
                None,
            ),
            x => x,
        }
    }
305
306 #[deprecated(
312 since = "0.1.0",
313 note = "Use `token_to_piece` for each token individually instead"
314 )]
315 pub fn tokens_to_str(
316 &self,
317 tokens: &[LlamaToken],
318 special: Special,
319 ) -> Result<String, TokenToStringError> {
320 let mut builder: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
321 for piece in tokens
322 .iter()
323 .copied()
324 .map(|t| self.token_to_piece_bytes(t, 8, matches!(special, Special::Tokenize), None))
325 {
326 builder.extend_from_slice(&piece?);
327 }
328 Ok(String::from_utf8(builder)?)
329 }
330
    /// Tokenizes `str` into model tokens, optionally prepending BOS.
    ///
    /// Makes one FFI call with an estimated buffer; if llama.cpp reports the
    /// buffer was too small (negative return), retries once with the exact
    /// required size.
    ///
    /// # Errors
    /// Fails if `str` contains a NUL byte or its length exceeds `c_int`.
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Rough guess: about one token per two bytes, plus room for BOS.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        // SAFETY: the pointer/capacity pair describes `buffer`'s allocation.
        let size = unsafe {
            llama_cpp_sys_2::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // A negative return is the negated number of tokens required; grow the
        // buffer to exactly that size and tokenize again.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
            unsafe {
                llama_cpp_sys_2::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and usize ");

        // SAFETY: llama.cpp initialized `size` elements of the buffer.
        unsafe { buffer.set_len(size) }
        Ok(buffer)
    }
407
    /// Returns the attributes of a given token id.
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }
418
    /// Decodes a single token to text using the caller's streaming decoder.
    ///
    /// Starts with an 8-byte buffer and retries once with the exact size the
    /// failed attempt reported. The decoder is stateful, so partial UTF-8
    /// sequences can be stitched across consecutive calls by the caller.
    ///
    /// # Errors
    /// Propagates any error from [`Self::token_to_piece_bytes`].
    pub fn token_to_piece(
        &self,
        token: LlamaToken,
        decoder: &mut encoding_rs::Decoder,
        special: bool,
        lstrip: Option<NonZeroU16>,
    ) -> Result<String, TokenToStringError> {
        let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
            // The error carries the negated required size; retry with exactly that.
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
                token,
                (-i).try_into().expect("Error buffer size is positive"),
                special,
                lstrip,
            ),
            x => x,
        }?;
        let mut output_piece = String::with_capacity(bytes.len());
        // Lossy, non-final decode: invalid sequences become replacement chars.
        let (_result, _somesize, _truthy) =
            decoder.decode_to_string(&bytes, &mut output_piece, false);
        Ok(output_piece)
    }
462
463 pub fn token_to_piece_bytes(
479 &self,
480 token: LlamaToken,
481 buffer_size: usize,
482 special: bool,
483 lstrip: Option<NonZeroU16>,
484 ) -> Result<Vec<u8>, TokenToStringError> {
485 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
486 let len = string.as_bytes().len();
487 let len = c_int::try_from(len).expect("length fits into c_int");
488 let buf = string.into_raw();
489 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
490 let size = unsafe {
491 llama_cpp_sys_2::llama_token_to_piece(
492 self.vocab_ptr(),
493 token.0,
494 buf,
495 len,
496 lstrip,
497 special,
498 )
499 };
500
501 match size {
502 0 => Err(TokenToStringError::UnknownTokenType),
503 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
504 size => {
505 let string = unsafe { CString::from_raw(buf) };
506 let mut bytes = string.into_bytes();
507 let len = usize::try_from(size).expect("size is positive and fits into usize");
508 bytes.truncate(len);
509 Ok(bytes)
510 }
511 }
512 }
513
514 #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
530 pub fn token_to_str_with_size(
531 &self,
532 token: LlamaToken,
533 buffer_size: usize,
534 special: Special,
535 ) -> Result<String, TokenToStringError> {
536 let bytes = self.token_to_piece_bytes(
537 token,
538 buffer_size,
539 matches!(special, Special::Tokenize),
540 None,
541 )?;
542 Ok(String::from_utf8(bytes)?)
543 }
544
545 #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
560 pub fn token_to_bytes_with_size(
561 &self,
562 token: LlamaToken,
563 buffer_size: usize,
564 special: Special,
565 lstrip: Option<NonZeroU16>,
566 ) -> Result<Vec<u8>, TokenToStringError> {
567 if token == self.token_nl() {
568 return Ok(b"\n".to_vec());
569 }
570
571 let attrs = self.token_attr(token);
573 if attrs.is_empty()
574 || attrs
575 .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
576 || attrs.contains(LlamaTokenAttr::Control)
577 && (token == self.token_bos() || token == self.token_eos())
578 {
579 return Ok(Vec::new());
580 }
581
582 let special = match special {
583 Special::Tokenize => true,
584 Special::Plaintext => false,
585 };
586
587 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
588 let len = string.as_bytes().len();
589 let len = c_int::try_from(len).expect("length fits into c_int");
590 let buf = string.into_raw();
591 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
592 let size = unsafe {
593 llama_cpp_sys_2::llama_token_to_piece(
594 self.vocab_ptr(),
595 token.0,
596 buf,
597 len,
598 lstrip,
599 special,
600 )
601 };
602
603 match size {
604 0 => Err(TokenToStringError::UnknownTokenType),
605 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
606 size => {
607 let string = unsafe { CString::from_raw(buf) };
608 let mut bytes = string.into_bytes();
609 let len = usize::try_from(size).expect("size is positive and fits into usize");
610 bytes.truncate(len);
611 Ok(bytes)
612 }
613 }
614 }
    /// Returns the number of tokens in the vocabulary.
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
    }
623
    /// Returns the vocabulary type the model uses.
    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
        VocabType::try_from(vocab_type).expect("invalid vocab type")
    }
635
    /// Returns the embedding dimension of the model.
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
    }
642
643 pub fn size(&self) -> u64 {
645 unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
646 }
647
648 pub fn n_params(&self) -> u64 {
650 unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
651 }
652
653 pub fn is_recurrent(&self) -> bool {
655 unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
656 }
657
658 pub fn is_hybrid(&self) -> bool {
663 unsafe { llama_cpp_sys_2::llama_model_is_hybrid(self.model.as_ptr()) }
664 }
665
666 pub fn n_layer(&self) -> u32 {
668 u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
671 }
672
673 pub fn n_head(&self) -> u32 {
675 u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
678 }
679
680 pub fn n_head_kv(&self) -> u32 {
682 u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
685 .unwrap()
686 }
687
    /// Returns the metadata value for `key` from the model's GGUF metadata.
    ///
    /// # Errors
    /// Fails if `key` contains a NUL byte or the key is not present.
    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
        let key_cstring = CString::new(key)?;
        let key_ptr = key_cstring.as_ptr();

        // `extract_meta_string` handles the grow-and-retry buffer protocol;
        // 256 is the initial buffer size.
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str(
                    self.model.as_ptr(),
                    key_ptr,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }
705
    /// Returns the number of metadata entries in the model.
    pub fn meta_count(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) }
    }
710
    /// Returns the metadata key at `index`.
    ///
    /// # Errors
    /// Fails if the index is out of range.
    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_key_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }
725
    /// Returns the metadata value at `index`.
    ///
    /// # Errors
    /// Fails if the index is out of range.
    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }
740
    /// Returns the model's RoPE type, or `None` when it has none (or reports a
    /// value this wrapper does not know about, which is logged).
    pub fn rope_type(&self) -> Option<RopeType> {
        match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
            rope_type => {
                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
                None
            }
        }
    }
755
756 pub fn chat_template(
770 &self,
771 name: Option<&str>,
772 ) -> Result<LlamaChatTemplate, ChatTemplateError> {
773 let name_cstr = name.map(CString::new);
774 let name_ptr = match name_cstr {
775 Some(Ok(name)) => name.as_ptr(),
776 _ => std::ptr::null(),
777 };
778 let result =
779 unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) };
780
781 if result.is_null() {
783 Err(ChatTemplateError::MissingTemplate)
784 } else {
785 let chat_template_cstr = unsafe { CStr::from_ptr(result) };
786 let chat_template = CString::new(chat_template_cstr.to_bytes())?;
787 Ok(LlamaChatTemplate(chat_template))
788 }
789 }
790
791 #[tracing::instrument(skip_all, fields(params))]
797 pub fn load_from_file(
798 _: &LlamaBackend,
799 path: impl AsRef<Path>,
800 params: &LlamaModelParams,
801 ) -> Result<Self, LlamaModelLoadError> {
802 let path = path.as_ref();
803 debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
804 let path = path
805 .to_str()
806 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
807
808 let cstr = CString::new(path)?;
809 let llama_model =
810 unsafe { llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) };
811
812 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
813
814 tracing::debug!(?path, "Loaded model");
815 Ok(LlamaModel { model })
816 }
817
    /// Initializes a LoRA adapter from a file for this model.
    ///
    /// # Errors
    /// Fails when the path is not valid UTF-8, contains a NUL byte, or
    /// llama.cpp returns a null adapter.
    pub fn lora_adapter_init(
        &self,
        path: impl AsRef<Path>,
    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
        let path = path.as_ref();
        debug_assert!(Path::new(path).exists(), "{path:?} does not exist");

        let path = path
            .to_str()
            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
                path.to_path_buf(),
            ))?;

        let cstr = CString::new(path)?;
        let adapter =
            unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };

        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;

        tracing::debug!(?path, "Initialized lora adapter");
        Ok(LlamaLoraAdapter {
            lora_adapter: adapter,
        })
    }
847
    /// Creates a new inference context for this model.
    ///
    /// The context borrows the model for its lifetime `'a`.
    ///
    /// # Errors
    /// Fails when llama.cpp returns a null context.
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context<'a>(
        &'a self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext<'a>, LlamaContextLoadError> {
        let context_params = params.context_params;
        let context = unsafe {
            llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
        };
        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;

        Ok(LlamaContext::new(self, context, params.embeddings()))
    }
868
    /// Renders `chat` through the given template, returning the prompt text.
    ///
    /// Allocates a buffer of twice the total message length, calls the FFI
    /// template function, and retries once with the exact size when the first
    /// buffer was too small.
    ///
    /// # Errors
    /// Returns `FfiError` on a negative FFI result, or a UTF-8 error when the
    /// rendered prompt is not valid UTF-8.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: &LlamaChatTemplate,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        let message_length = chat.iter().fold(0, |acc, c| {
            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
        });
        // Heuristic initial size; the retry below handles underestimates.
        let mut buff: Vec<u8> = vec![0; message_length * 2];

        // Build the FFI view of the messages; the pointers borrow from `chat`,
        // which stays alive for the whole call.
        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
            .iter()
            .map(|c| llama_cpp_sys_2::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tmpl_ptr = tmpl.0.as_ptr();

        let res = unsafe {
            llama_cpp_sys_2::llama_chat_apply_template(
                tmpl_ptr,
                chat.as_ptr(),
                chat.len(),
                add_ass,
                buff.as_mut_ptr().cast::<c_char>(),
                buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
            )
        };

        if res < 0 {
            return Err(ApplyChatTemplateError::FfiError(res));
        }

        // A non-negative result is the required length; retry with an exact-size
        // buffer if the first one was too small.
        if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
            buff.resize(res.try_into().expect("res is negative"), 0);

            let res = unsafe {
                llama_cpp_sys_2::llama_chat_apply_template(
                    tmpl_ptr,
                    chat.as_ptr(),
                    chat.len(),
                    add_ass,
                    buff.as_mut_ptr().cast::<c_char>(),
                    buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
                )
            };
            if res < 0 {
                return Err(ApplyChatTemplateError::FfiError(res));
            }
            // The second call must have fit exactly.
            assert_eq!(Ok(res), buff.len().try_into());
        }
        // Trim to the length actually written (outer `res`).
        buff.truncate(res.try_into().expect("res is negative"));
        Ok(String::from_utf8(buff)?)
    }
946
    /// Renders messages (plus optional tools / JSON schema) through the
    /// template via the OpenAI-compatible FFI helper, returning the full
    /// [`ChatTemplateResult`] including grammar and parser metadata.
    ///
    /// # Errors
    /// Returns `FfiError` on a non-OK status, `NullResult` on a null prompt,
    /// `InvalidGrammarTriggerType` on malformed trigger/token/stop arrays, or
    /// a UTF-8 error for non-UTF-8 strings.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_with_tools_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        messages: &[LlamaChatMessage],
        tools_json: Option<&str>,
        json_schema: Option<&str>,
        add_generation_prompt: bool,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        // FFI view of the messages; pointers borrow from `messages`.
        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = messages
            .iter()
            .map(|c| llama_cpp_sys_2::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tools_cstr = tools_json.map(CString::new).transpose()?;
        let json_schema_cstr = json_schema.map(CString::new).transpose()?;

        // Out-parameter struct; llama.cpp fills it and we free it below.
        let mut raw_result = llama_cpp_sys_2::llama_rs_chat_template_result {
            prompt: ptr::null_mut(),
            grammar: ptr::null_mut(),
            parser: ptr::null_mut(),
            generation_prompt: ptr::null_mut(),
            chat_format: 0,
            thinking_forced_open: false,
            grammar_lazy: false,
            grammar_triggers: ptr::null_mut(),
            grammar_triggers_count: 0,
            preserved_tokens: ptr::null_mut(),
            preserved_tokens_count: 0,
            additional_stops: ptr::null_mut(),
            additional_stops_count: 0,
        };

        let rc = unsafe {
            llama_cpp_sys_2::llama_rs_apply_chat_template_with_tools_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                chat.as_ptr(),
                chat.len(),
                tools_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                json_schema_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                add_generation_prompt,
                &mut raw_result,
            )
        };

        // Parse inside a closure so that `raw_result` is freed exactly once
        // regardless of which branch returns early.
        let result = (|| {
            if !status_is_ok(rc) {
                return Err(ApplyChatTemplateError::FfiError(status_to_i32(rc)));
            }
            if raw_result.prompt.is_null() {
                return Err(ApplyChatTemplateError::NullResult);
            }
            let prompt_bytes = unsafe { CStr::from_ptr(raw_result.prompt) }
                .to_bytes()
                .to_vec();
            let prompt = String::from_utf8(prompt_bytes)?;
            let grammar_lazy = raw_result.grammar_lazy;
            let grammar = if raw_result.grammar.is_null() {
                None
            } else {
                let grammar_bytes = unsafe { CStr::from_ptr(raw_result.grammar) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(grammar_bytes)?)
            };
            let parser = if raw_result.parser.is_null() {
                None
            } else {
                let parser_bytes = unsafe { CStr::from_ptr(raw_result.parser) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(parser_bytes)?)
            };
            let generation_prompt = if raw_result.generation_prompt.is_null() {
                String::new()
            } else {
                let generation_prompt_bytes =
                    unsafe { CStr::from_ptr(raw_result.generation_prompt) }
                        .to_bytes()
                        .to_vec();
                String::from_utf8(generation_prompt_bytes)?
            };
            let grammar_triggers = if raw_result.grammar_triggers_count == 0 {
                Vec::new()
            } else if raw_result.grammar_triggers.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let triggers = unsafe {
                    slice::from_raw_parts(
                        raw_result.grammar_triggers,
                        raw_result.grammar_triggers_count,
                    )
                };
                let mut parsed = Vec::with_capacity(triggers.len());
                for trigger in triggers {
                    let trigger_type = match trigger.type_ {
                        0 => GrammarTriggerType::Token,
                        1 => GrammarTriggerType::Word,
                        2 => GrammarTriggerType::Pattern,
                        3 => GrammarTriggerType::PatternFull,
                        _ => return Err(ApplyChatTemplateError::InvalidGrammarTriggerType),
                    };
                    // NOTE(review): a null trigger value is treated as an error
                    // here, but `apply_chat_template_oaicompat` maps it to an
                    // empty string — confirm which behavior is intended.
                    let value = if trigger.value.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    } else {
                        let bytes = unsafe { CStr::from_ptr(trigger.value) }.to_bytes().to_vec();
                        String::from_utf8(bytes)?
                    };
                    let token = if trigger_type == GrammarTriggerType::Token {
                        Some(LlamaToken(trigger.token))
                    } else {
                        None
                    };
                    parsed.push(GrammarTrigger {
                        trigger_type,
                        value,
                        token,
                    });
                }
                parsed
            };
            let preserved_tokens = if raw_result.preserved_tokens_count == 0 {
                Vec::new()
            } else if raw_result.preserved_tokens.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let tokens = unsafe {
                    slice::from_raw_parts(
                        raw_result.preserved_tokens,
                        raw_result.preserved_tokens_count,
                    )
                };
                let mut parsed = Vec::with_capacity(tokens.len());
                for token in tokens {
                    if token.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*token) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            let additional_stops = if raw_result.additional_stops_count == 0 {
                Vec::new()
            } else if raw_result.additional_stops.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let stops = unsafe {
                    slice::from_raw_parts(
                        raw_result.additional_stops,
                        raw_result.additional_stops_count,
                    )
                };
                let mut parsed = Vec::with_capacity(stops.len());
                for stop in stops {
                    if stop.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*stop) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            // Tool calls are parsed only when a non-empty tools JSON was given.
            let parse_tool_calls = tools_json.map_or(false, |tools| !tools.is_empty());
            Ok(ChatTemplateResult {
                prompt,
                grammar,
                grammar_lazy,
                grammar_triggers,
                preserved_tokens,
                additional_stops,
                chat_format: raw_result.chat_format,
                parser,
                generation_prompt,
                thinking_forced_open: raw_result.thinking_forced_open,
                parse_tool_calls,
            })
        })();

        // Free the C-side allocations whether parsing succeeded or failed.
        unsafe { llama_cpp_sys_2::llama_rs_chat_template_result_free(&mut raw_result) };
        result
    }
1140
    /// Renders an OpenAI-compatible chat request (messages, tools, schema,
    /// grammar, reasoning format, kwargs — all as JSON strings) through the
    /// template, returning the full [`ChatTemplateResult`].
    ///
    /// # Errors
    /// Returns `FfiError` on a non-OK status, `NullResult` on a null prompt,
    /// `InvalidGrammarTriggerType` on malformed trigger/token/stop arrays, or
    /// a UTF-8 error for non-UTF-8 strings.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        params: &OpenAIChatTemplateParams<'_>,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        let parse_tool_calls = params.parse_tool_calls;
        // All optional JSON inputs become nullable C strings; the CStrings
        // must outlive the FFI call below.
        let messages_cstr = CString::new(params.messages_json)?;
        let tools_cstr = params.tools_json.map(CString::new).transpose()?;
        let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
        let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
        let grammar_cstr = params.grammar.map(CString::new).transpose()?;
        let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
        let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;

        // Out-parameter struct; llama.cpp fills it and we free it below.
        let mut raw_result = llama_cpp_sys_2::llama_rs_chat_template_result {
            prompt: ptr::null_mut(),
            grammar: ptr::null_mut(),
            parser: ptr::null_mut(),
            generation_prompt: ptr::null_mut(),
            chat_format: 0,
            thinking_forced_open: false,
            grammar_lazy: false,
            grammar_triggers: ptr::null_mut(),
            grammar_triggers_count: 0,
            preserved_tokens: ptr::null_mut(),
            preserved_tokens_count: 0,
            additional_stops: ptr::null_mut(),
            additional_stops_count: 0,
        };

        let ffi_params = llama_cpp_sys_2::llama_rs_chat_template_oaicompat_params {
            messages: messages_cstr.as_ptr(),
            tools: tools_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            tool_choice: tool_choice_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            json_schema: json_schema_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            grammar: grammar_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            reasoning_format: reasoning_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            chat_template_kwargs: kwargs_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            add_generation_prompt: params.add_generation_prompt,
            use_jinja: params.use_jinja,
            parallel_tool_calls: params.parallel_tool_calls,
            enable_thinking: params.enable_thinking,
            add_bos: params.add_bos,
            add_eos: params.add_eos,
        };

        let rc = unsafe {
            llama_cpp_sys_2::llama_rs_apply_chat_template_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                &ffi_params,
                &mut raw_result,
            )
        };

        // Parse inside a closure so that `raw_result` is freed exactly once
        // regardless of which branch returns early.
        let result = (|| {
            if !status_is_ok(rc) {
                return Err(ApplyChatTemplateError::FfiError(status_to_i32(rc)));
            }
            if raw_result.prompt.is_null() {
                return Err(ApplyChatTemplateError::NullResult);
            }
            let prompt_bytes = unsafe { CStr::from_ptr(raw_result.prompt) }
                .to_bytes()
                .to_vec();
            let prompt = String::from_utf8(prompt_bytes)?;
            let grammar_lazy = raw_result.grammar_lazy;
            let grammar = if raw_result.grammar.is_null() {
                None
            } else {
                let grammar_bytes = unsafe { CStr::from_ptr(raw_result.grammar) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(grammar_bytes)?)
            };
            let parser = if raw_result.parser.is_null() {
                None
            } else {
                let parser_bytes = unsafe { CStr::from_ptr(raw_result.parser) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(parser_bytes)?)
            };
            let generation_prompt = if raw_result.generation_prompt.is_null() {
                String::new()
            } else {
                let generation_prompt_bytes =
                    unsafe { CStr::from_ptr(raw_result.generation_prompt) }
                        .to_bytes()
                        .to_vec();
                String::from_utf8(generation_prompt_bytes)?
            };
            let grammar_triggers = if raw_result.grammar_triggers_count == 0 {
                Vec::new()
            } else if raw_result.grammar_triggers.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let triggers = unsafe {
                    slice::from_raw_parts(
                        raw_result.grammar_triggers,
                        raw_result.grammar_triggers_count,
                    )
                };
                let mut parsed = Vec::with_capacity(triggers.len());
                for trigger in triggers {
                    let trigger_type = match trigger.type_ {
                        0 => GrammarTriggerType::Token,
                        1 => GrammarTriggerType::Word,
                        2 => GrammarTriggerType::Pattern,
                        3 => GrammarTriggerType::PatternFull,
                        _ => return Err(ApplyChatTemplateError::InvalidGrammarTriggerType),
                    };
                    // NOTE(review): a null trigger value becomes an empty string
                    // here, but the `_with_tools` variant treats it as an error —
                    // confirm which behavior is intended.
                    let value = if trigger.value.is_null() {
                        String::new()
                    } else {
                        let bytes = unsafe { CStr::from_ptr(trigger.value) }.to_bytes().to_vec();
                        String::from_utf8(bytes)?
                    };
                    let token = if trigger_type == GrammarTriggerType::Token {
                        Some(LlamaToken(trigger.token))
                    } else {
                        None
                    };
                    parsed.push(GrammarTrigger {
                        trigger_type,
                        value,
                        token,
                    });
                }
                parsed
            };
            let preserved_tokens = if raw_result.preserved_tokens_count == 0 {
                Vec::new()
            } else if raw_result.preserved_tokens.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let tokens = unsafe {
                    slice::from_raw_parts(
                        raw_result.preserved_tokens,
                        raw_result.preserved_tokens_count,
                    )
                };
                let mut parsed = Vec::with_capacity(tokens.len());
                for token in tokens {
                    if token.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*token) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            let additional_stops = if raw_result.additional_stops_count == 0 {
                Vec::new()
            } else if raw_result.additional_stops.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let stops = unsafe {
                    slice::from_raw_parts(
                        raw_result.additional_stops,
                        raw_result.additional_stops_count,
                    )
                };
                let mut parsed = Vec::with_capacity(stops.len());
                for stop in stops {
                    if stop.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*stop) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };

            Ok(ChatTemplateResult {
                prompt,
                grammar,
                grammar_lazy,
                grammar_triggers,
                preserved_tokens,
                additional_stops,
                chat_format: raw_result.chat_format,
                parser,
                generation_prompt,
                thinking_forced_open: raw_result.thinking_forced_open,
                parse_tool_calls,
            })
        })();

        // Free the C-side allocations whether parsing succeeded or failed.
        unsafe { llama_cpp_sys_2::llama_rs_chat_template_result_free(&mut raw_result) };
        result
    }
1347}
1348
1349impl ChatTemplateResult {
1350 pub fn parse_response_oaicompat(
1352 &self,
1353 text: &str,
1354 is_partial: bool,
1355 ) -> Result<String, ChatParseError> {
1356 let text_cstr = CString::new(text)?;
1357 let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
1358 let generation_prompt_cstr = if self.generation_prompt.is_empty() {
1359 None
1360 } else {
1361 Some(CString::new(self.generation_prompt.as_str())?)
1362 };
1363 let mut out_json: *mut c_char = ptr::null_mut();
1364 let rc = unsafe {
1365 llama_cpp_sys_2::llama_rs_chat_parse_to_oaicompat(
1366 text_cstr.as_ptr(),
1367 is_partial,
1368 self.chat_format,
1369 self.parse_tool_calls,
1370 parser_cstr
1371 .as_ref()
1372 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
1373 generation_prompt_cstr
1374 .as_ref()
1375 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
1376 self.thinking_forced_open,
1377 &mut out_json,
1378 )
1379 };
1380
1381 let result = (|| {
1382 if !status_is_ok(rc) {
1383 return Err(ChatParseError::FfiError(status_to_i32(rc)));
1384 }
1385 if out_json.is_null() {
1386 return Err(ChatParseError::NullResult);
1387 }
1388 let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
1389 Ok(String::from_utf8(bytes)?)
1390 })();
1391
1392 unsafe { llama_cpp_sys_2::llama_rs_string_free(out_json) };
1393 result
1394 }
1395
1396 pub fn streaming_state_oaicompat(&self) -> Result<ChatParseStateOaicompat, ChatParseError> {
1398 let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
1399 let generation_prompt_cstr = if self.generation_prompt.is_empty() {
1400 None
1401 } else {
1402 Some(CString::new(self.generation_prompt.as_str())?)
1403 };
1404 let state = unsafe {
1405 llama_cpp_sys_2::llama_rs_chat_parse_state_init_oaicompat(
1406 self.chat_format,
1407 self.parse_tool_calls,
1408 parser_cstr
1409 .as_ref()
1410 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
1411 generation_prompt_cstr
1412 .as_ref()
1413 .map_or(ptr::null(), |cstr| cstr.as_ptr()),
1414 self.thinking_forced_open,
1415 )
1416 };
1417 let state = NonNull::new(state).ok_or(ChatParseError::NullResult)?;
1418 Ok(ChatParseStateOaicompat { state })
1419 }
1420}
1421
1422fn extract_meta_string<F>(c_function: F, capacity: usize) -> Result<String, MetaValError>
1428where
1429 F: Fn(*mut c_char, usize) -> i32,
1430{
1431 let mut buffer = vec![0u8; capacity];
1432
1433 let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
1435 if result < 0 {
1436 return Err(MetaValError::NegativeReturn(result));
1437 }
1438
1439 let returned_len = result as usize;
1441 if returned_len >= capacity {
1442 return extract_meta_string(c_function, returned_len + 1);
1444 }
1445
1446 debug_assert_eq!(
1448 buffer.get(returned_len),
1449 Some(&0),
1450 "should end with null byte"
1451 );
1452
1453 buffer.truncate(returned_len);
1455 Ok(String::from_utf8(buffer)?)
1456}
1457
impl Drop for LlamaModel {
    /// Frees the underlying `llama_model` when the wrapper is dropped.
    fn drop(&mut self) {
        // SAFETY: `self.model` is a non-null pointer owned exclusively by this
        // struct, so it is valid here and freed exactly once.
        unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
    }
}
1463
/// The vocabulary type of a model, mirroring the supported values of
/// `llama_vocab_type` from the underlying C API.
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
    /// Byte-pair encoding vocabulary.
    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
    /// SentencePiece vocabulary.
    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
}
1473
/// Error produced when a raw `llama_vocab_type` value cannot be converted
/// into a [`VocabType`].
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
    /// The raw value does not correspond to any known vocabulary type.
    #[error("Unknown Value {0}")]
    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
}
1481
1482impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
1483 type Error = LlamaTokenTypeFromIntError;
1484
1485 fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
1486 match value {
1487 llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
1488 llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
1489 unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
1490 }
1491 }
1492}