1use std::ffi::{c_char, CStr, CString};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6use std::ptr::{self, NonNull};
7use std::slice;
8use std::str::Utf8Error;
9
10use crate::context::params::LlamaContextParams;
11use crate::context::LlamaContext;
12use crate::llama_backend::LlamaBackend;
13use crate::model::params::LlamaModelParams;
14use crate::openai::{ChatParseStateOaicompat, OpenAIChatTemplateParams};
15use crate::token::LlamaToken;
16use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
17use crate::{
18 status_is_ok, status_to_i32, ApplyChatTemplateError, ChatParseError, ChatTemplateError,
19 LlamaContextLoadError, LlamaLoraAdapterInitError, LlamaModelLoadError, MetaValError,
20 NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
21};
22
23pub mod params;
24
/// A safe wrapper around a llama.cpp `llama_model` pointer.
///
/// Owns the underlying allocation; it is released in [`Drop`] via
/// `llama_free_model`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    // Invariant: non-null pointer returned by `llama_load_model_from_file`,
    // valid until `Drop` runs.
    pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
}
32
/// A safe wrapper around a llama.cpp `llama_adapter_lora` pointer,
/// created by [`LlamaModel::lora_adapter_init`].
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    // Invariant: non-null pointer returned by `llama_adapter_lora_init`.
    // NOTE(review): no `Drop` impl is visible in this file — confirm the
    // adapter's lifetime is managed elsewhere.
    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}
40
/// A chat template string stored as a `CString` so it can be handed to
/// llama.cpp without re-allocation on every call.
#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);
47
48impl LlamaChatTemplate {
49 pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
52 Ok(Self(CString::new(template)?))
53 }
54
55 pub fn as_c_str(&self) -> &CStr {
57 &self.0
58 }
59
60 pub fn to_str(&self) -> Result<&str, Utf8Error> {
62 self.0.to_str()
63 }
64
65 pub fn to_string(&self) -> Result<String, Utf8Error> {
67 self.to_str().map(str::to_string)
68 }
69}
70
71impl std::fmt::Debug for LlamaChatTemplate {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 self.0.fmt(f)
74 }
75}
76
/// A single chat message (role + content) stored as C strings, ready to be
/// passed to llama.cpp's chat-template APIs.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
    // e.g. "system", "user", "assistant" — passed through verbatim to llama.cpp.
    role: CString,
    content: CString,
}
83
84impl LlamaChatMessage {
85 pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
90 Ok(Self {
91 role: CString::new(role)?,
92 content: CString::new(content)?,
93 })
94 }
95}
96
/// Kind of lazy-grammar trigger returned by the chat-template FFI.
///
/// The discriminants mirror the integer `type_` field of the C result struct
/// (see the `match trigger.type_` parsing below) — keep them in sync with the
/// C side.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GrammarTriggerType {
    /// Trigger on a specific token id.
    Token = 0,
    /// Trigger on a literal word.
    Word = 1,
    /// Trigger on a pattern match.
    Pattern = 2,
    /// Trigger on a full-text pattern match.
    PatternFull = 3,
}
109
/// A single lazy-grammar trigger parsed out of the chat-template result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GrammarTrigger {
    /// What kind of trigger this is.
    pub trigger_type: GrammarTriggerType,
    /// The trigger text (word or pattern).
    pub value: String,
    /// Set only when `trigger_type` is [`GrammarTriggerType::Token`].
    pub token: Option<LlamaToken>,
}
120
/// Owned, parsed copy of the C `llama_rs_chat_template_result`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatTemplateResult {
    /// The fully rendered prompt text.
    pub prompt: String,
    /// GBNF grammar constraining generation, if the template produced one.
    pub grammar: Option<String>,
    /// When true, the grammar is only activated once a trigger fires.
    pub grammar_lazy: bool,
    /// Triggers that activate a lazy grammar.
    pub grammar_triggers: Vec<GrammarTrigger>,
    /// Token strings the sampler must keep intact (not split/merged).
    pub preserved_tokens: Vec<String>,
    /// Extra stop strings requested by the template.
    pub additional_stops: Vec<String>,
    /// Opaque llama.cpp chat-format id, fed back into the parse functions.
    pub chat_format: i32,
    /// Serialized parser description for response parsing, if any.
    pub parser: Option<String>,
    /// True when the template opened a "thinking" block that generation must close.
    pub thinking_forced_open: bool,
    /// Whether tool calls should be parsed out of the model's response.
    pub parse_tool_calls: bool,
}
145
/// Rotary positional embedding (RoPE) variant reported by llama.cpp
/// (`llama_model_rope_type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeType {
    Norm,
    NeoX,
    MRope,
    Vision,
}
154
/// Whether to prepend the beginning-of-stream token when tokenizing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Always prepend the BOS token.
    Always,
    /// Never prepend the BOS token.
    Never,
}
163
/// Whether to render special tokens (BOS/EOS/control) as text when detokenizing.
#[deprecated(
    since = "0.1.0",
    note = "This enum is a mixture of options for llama cpp providing less flexibility it only used with deprecated methods and will be removed in the future."
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Render special tokens (maps to `special = true` in the FFI calls).
    Tokenize,
    /// Skip special tokens (maps to `special = false`).
    Plaintext,
}
176
// SAFETY: the model pointer is owned exclusively by this struct and freed only
// in `Drop`; all FFI calls made through it in this file are reads.
// NOTE(review): this relies on llama.cpp permitting concurrent read access to a
// loaded `llama_model` — confirm against the pinned llama.cpp revision.
unsafe impl Send for LlamaModel {}

// SAFETY: see the note on `Send` above — `&LlamaModel` is only used for reads.
unsafe impl Sync for LlamaModel {}
180
181impl LlamaModel {
    /// Returns the raw vocabulary pointer for this model, used by the
    /// token-level FFI calls below. Valid for as long as `self` is alive.
    pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
        unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
    }
185
    /// Returns the context size the model was trained with.
    ///
    /// # Panics
    /// Panics if the value reported by llama.cpp does not fit into a `u32`.
    #[must_use]
    pub fn n_ctx_train(&self) -> u32 {
        let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
        u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
    }
197
198 pub fn tokens(
200 &self,
201 decode_special: bool,
202 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
203 (0..self.n_vocab())
204 .map(LlamaToken::new)
205 .map(move |llama_token| {
206 let mut decoder = encoding_rs::UTF_8.new_decoder();
207 (
208 llama_token,
209 self.token_to_piece(llama_token, &mut decoder, decode_special, None),
210 )
211 })
212 }
213
    /// Returns the beginning-of-stream (BOS) token.
    #[must_use]
    pub fn token_bos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Returns the end-of-stream (EOS) token.
    #[must_use]
    pub fn token_eos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Returns the newline token.
    #[must_use]
    pub fn token_nl(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Returns true if `token` signals end-of-generation (EOS, EOT, …).
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
    }

    /// Returns the decoder start token (relevant for encoder-decoder models).
    #[must_use]
    pub fn decode_start_token(&self) -> LlamaToken {
        let token =
            unsafe { llama_cpp_sys_2::llama_model_decoder_start_token(self.model.as_ptr()) };
        LlamaToken(token)
    }

    /// Returns the separator token.
    #[must_use]
    pub fn token_sep(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_vocab_sep(self.vocab_ptr()) };
        LlamaToken(token)
    }
255
256 #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
262 pub fn token_to_str(
263 &self,
264 token: LlamaToken,
265 special: Special,
266 ) -> Result<String, TokenToStringError> {
267 let mut decoder = encoding_rs::UTF_8.new_decoder();
269 Ok(self.token_to_piece(
270 token,
271 &mut decoder,
272 matches!(special, Special::Tokenize),
273 None,
274 )?)
275 }
276
    /// Converts a single token to its raw bytes.
    ///
    /// Starts with a small 8-byte scratch buffer and retries once with the
    /// exact required size when llama.cpp reports the buffer was too small
    /// (the negative return value encodes the needed length).
    ///
    /// # Errors
    /// Returns any [`TokenToStringError`] other than the first
    /// `InsufficientBufferSpace`, which is handled by the retry.
    #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
    pub fn token_to_bytes(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<Vec<u8>, TokenToStringError> {
        match self.token_to_piece_bytes(token, 8, matches!(special, Special::Tokenize), None) {
            // `-i` is the size llama.cpp said it actually needs.
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
                token,
                (-i).try_into().expect("Error buffer size is positive"),
                matches!(special, Special::Tokenize),
                None,
            ),
            x => x,
        }
    }
303
304 #[deprecated(
310 since = "0.1.0",
311 note = "Use `token_to_piece` for each token individually instead"
312 )]
313 pub fn tokens_to_str(
314 &self,
315 tokens: &[LlamaToken],
316 special: Special,
317 ) -> Result<String, TokenToStringError> {
318 let mut builder: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
319 for piece in tokens
320 .iter()
321 .copied()
322 .map(|t| self.token_to_piece_bytes(t, 8, matches!(special, Special::Tokenize), None))
323 {
324 builder.extend_from_slice(&piece?);
325 }
326 Ok(String::from_utf8(builder)?)
327 }
328
    /// Tokenizes `str` into model tokens, optionally prepending BOS.
    ///
    /// Uses the standard llama.cpp two-phase protocol: tokenize into an
    /// estimated buffer first; if the return value is negative its magnitude is
    /// the required token count, so the buffer is grown and the call repeated.
    ///
    /// # Errors
    /// Fails if `str` contains an interior nul byte or its length does not fit
    /// into a `c_int`.
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Heuristic: roughly one token per two bytes of input, minimum 8.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        // SAFETY: `buffer` has capacity for `buffer_capacity` tokens and
        // `LlamaToken` is a transparent wrapper over `llama_token`.
        let size = unsafe {
            llama_cpp_sys_2::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // Negative size means the estimate was too small; `-size` is the exact
        // number of tokens required, so grow and tokenize again.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
            unsafe {
                llama_cpp_sys_2::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and usize ");

        // SAFETY: llama.cpp initialized exactly `size` tokens, and `size` is
        // within the buffer's capacity.
        unsafe { buffer.set_len(size) }
        Ok(buffer)
    }
405
    /// Returns the attribute flags (control/byte/unknown/…) for a token id.
    ///
    /// # Panics
    /// Panics if llama.cpp reports attribute bits this crate does not know.
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }
416
    /// Decodes one token into a `String`, feeding the bytes through a caller
    /// supplied streaming decoder so multi-token UTF-8 sequences can be
    /// reassembled across calls.
    ///
    /// `lstrip` limits how many leading spaces are stripped; `special` renders
    /// special tokens as text.
    ///
    /// # Errors
    /// Returns any [`TokenToStringError`] other than the first
    /// `InsufficientBufferSpace`, which is retried with the exact size.
    pub fn token_to_piece(
        &self,
        token: LlamaToken,
        decoder: &mut encoding_rs::Decoder,
        special: bool,
        lstrip: Option<NonZeroU16>,
    ) -> Result<String, TokenToStringError> {
        // Try a small 8-byte buffer first; retry once with the exact size
        // llama.cpp reports (encoded as the negated error value).
        let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
                token,
                (-i).try_into().expect("Error buffer size is positive"),
                special,
                lstrip,
            ),
            x => x,
        }?;
        let mut output_piece = String::with_capacity(bytes.len());
        // Lossy streaming decode: invalid/partial sequences are carried in the
        // decoder's state rather than failing the call.
        let (_result, _somesize, _truthy) =
            decoder.decode_to_string(&bytes, &mut output_piece, false);
        Ok(output_piece)
    }
460
461 pub fn token_to_piece_bytes(
477 &self,
478 token: LlamaToken,
479 buffer_size: usize,
480 special: bool,
481 lstrip: Option<NonZeroU16>,
482 ) -> Result<Vec<u8>, TokenToStringError> {
483 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
484 let len = string.as_bytes().len();
485 let len = c_int::try_from(len).expect("length fits into c_int");
486 let buf = string.into_raw();
487 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
488 let size = unsafe {
489 llama_cpp_sys_2::llama_token_to_piece(
490 self.vocab_ptr(),
491 token.0,
492 buf,
493 len,
494 lstrip,
495 special,
496 )
497 };
498
499 match size {
500 0 => Err(TokenToStringError::UnknownTokenType),
501 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
502 size => {
503 let string = unsafe { CString::from_raw(buf) };
504 let mut bytes = string.into_bytes();
505 let len = usize::try_from(size).expect("size is positive and fits into usize");
506 bytes.truncate(len);
507 Ok(bytes)
508 }
509 }
510 }
511
512 #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
528 pub fn token_to_str_with_size(
529 &self,
530 token: LlamaToken,
531 buffer_size: usize,
532 special: Special,
533 ) -> Result<String, TokenToStringError> {
534 let bytes = self.token_to_piece_bytes(
535 token,
536 buffer_size,
537 matches!(special, Special::Tokenize),
538 None,
539 )?;
540 Ok(String::from_utf8(bytes)?)
541 }
542
543 #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
558 pub fn token_to_bytes_with_size(
559 &self,
560 token: LlamaToken,
561 buffer_size: usize,
562 special: Special,
563 lstrip: Option<NonZeroU16>,
564 ) -> Result<Vec<u8>, TokenToStringError> {
565 if token == self.token_nl() {
566 return Ok(b"\n".to_vec());
567 }
568
569 let attrs = self.token_attr(token);
571 if attrs.is_empty()
572 || attrs
573 .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
574 || attrs.contains(LlamaTokenAttr::Control)
575 && (token == self.token_bos() || token == self.token_eos())
576 {
577 return Ok(Vec::new());
578 }
579
580 let special = match special {
581 Special::Tokenize => true,
582 Special::Plaintext => false,
583 };
584
585 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
586 let len = string.as_bytes().len();
587 let len = c_int::try_from(len).expect("length fits into c_int");
588 let buf = string.into_raw();
589 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
590 let size = unsafe {
591 llama_cpp_sys_2::llama_token_to_piece(
592 self.vocab_ptr(),
593 token.0,
594 buf,
595 len,
596 lstrip,
597 special,
598 )
599 };
600
601 match size {
602 0 => Err(TokenToStringError::UnknownTokenType),
603 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
604 size => {
605 let string = unsafe { CString::from_raw(buf) };
606 let mut bytes = string.into_bytes();
607 let len = usize::try_from(size).expect("size is positive and fits into usize");
608 bytes.truncate(len);
609 Ok(bytes)
610 }
611 }
612 }
    /// Returns the number of tokens in the model's vocabulary.
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
    }

    /// Returns the vocabulary type (BPE/SPM) used by the model.
    ///
    /// # Panics
    /// Panics if llama.cpp reports a vocab type this crate does not model.
    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
        VocabType::try_from(vocab_type).expect("invalid vocab type")
    }

    /// Returns the embedding dimension of the model.
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
    }
640
    /// Returns the total size of all tensors in the model, in bytes.
    pub fn size(&self) -> u64 {
        unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
    }

    /// Returns the total number of parameters in the model.
    pub fn n_params(&self) -> u64 {
        unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
    }

    /// Returns true if the model is recurrent (e.g. Mamba/RWKV-style).
    pub fn is_recurrent(&self) -> bool {
        unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
    }

    /// Returns true if the model mixes attention and recurrent layers.
    pub fn is_hybrid(&self) -> bool {
        unsafe { llama_cpp_sys_2::llama_model_is_hybrid(self.model.as_ptr()) }
    }

    /// Returns the number of layers. Panics if the FFI value is negative.
    pub fn n_layer(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
    }

    /// Returns the number of attention heads. Panics if the FFI value is negative.
    pub fn n_head(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
    }

    /// Returns the number of KV heads. Panics if the FFI value is negative.
    pub fn n_head_kv(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
            .unwrap()
    }
685
686 pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
688 let key_cstring = CString::new(key)?;
689 let key_ptr = key_cstring.as_ptr();
690
691 extract_meta_string(
692 |buf_ptr, buf_len| unsafe {
693 llama_cpp_sys_2::llama_model_meta_val_str(
694 self.model.as_ptr(),
695 key_ptr,
696 buf_ptr,
697 buf_len,
698 )
699 },
700 256,
701 )
702 }
703
    /// Returns the number of metadata key/value pairs in the model.
    pub fn meta_count(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) }
    }

    /// Returns the metadata key at `index` (see [`Self::meta_count`]).
    ///
    /// # Errors
    /// Fails on a negative FFI return value or non-UTF-8 metadata.
    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_key_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Returns the metadata value at `index` (see [`Self::meta_count`]).
    ///
    /// # Errors
    /// Fails on a negative FFI return value or non-UTF-8 metadata.
    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }
738
739 pub fn rope_type(&self) -> Option<RopeType> {
741 match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
742 llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
743 llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
744 llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
745 llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
746 llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
747 rope_type => {
748 tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
749 None
750 }
751 }
752 }
753
754 pub fn chat_template(
768 &self,
769 name: Option<&str>,
770 ) -> Result<LlamaChatTemplate, ChatTemplateError> {
771 let name_cstr = name.map(CString::new);
772 let name_ptr = match name_cstr {
773 Some(Ok(name)) => name.as_ptr(),
774 _ => std::ptr::null(),
775 };
776 let result =
777 unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) };
778
779 if result.is_null() {
781 Err(ChatTemplateError::MissingTemplate)
782 } else {
783 let chat_template_cstr = unsafe { CStr::from_ptr(result) };
784 let chat_template = CString::new(chat_template_cstr.to_bytes())?;
785 Ok(LlamaChatTemplate(chat_template))
786 }
787 }
788
789 #[tracing::instrument(skip_all, fields(params))]
795 pub fn load_from_file(
796 _: &LlamaBackend,
797 path: impl AsRef<Path>,
798 params: &LlamaModelParams,
799 ) -> Result<Self, LlamaModelLoadError> {
800 let path = path.as_ref();
801 debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
802 let path = path
803 .to_str()
804 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
805
806 let cstr = CString::new(path)?;
807 let llama_model =
808 unsafe { llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) };
809
810 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
811
812 tracing::debug!(?path, "Loaded model");
813 Ok(LlamaModel { model })
814 }
815
816 pub fn lora_adapter_init(
822 &self,
823 path: impl AsRef<Path>,
824 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
825 let path = path.as_ref();
826 debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
827
828 let path = path
829 .to_str()
830 .ok_or(LlamaLoraAdapterInitError::PathToStrError(
831 path.to_path_buf(),
832 ))?;
833
834 let cstr = CString::new(path)?;
835 let adapter =
836 unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
837
838 let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
839
840 tracing::debug!(?path, "Initialized lora adapter");
841 Ok(LlamaLoraAdapter {
842 lora_adapter: adapter,
843 })
844 }
845
    /// Creates an inference context for this model.
    ///
    /// The returned context borrows `self`, so the model cannot be dropped
    /// while the context is alive. The unused `&LlamaBackend` proves the
    /// backend was initialized first.
    ///
    /// # Errors
    /// Returns [`LlamaContextLoadError::NullReturn`] when llama.cpp fails to
    /// create the context.
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context<'a>(
        &'a self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext<'a>, LlamaContextLoadError> {
        let context_params = params.context_params;
        let context = unsafe {
            llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
        };
        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;

        Ok(LlamaContext::new(self, context, params.embeddings()))
    }
866
    /// Renders `chat` through `tmpl` into a single prompt string via
    /// llama.cpp's `llama_chat_apply_template`.
    ///
    /// `add_ass` appends the assistant-turn prefix so the model continues as
    /// the assistant.
    ///
    /// # Errors
    /// Fails when the rendered prompt is not valid UTF-8.
    ///
    /// # Panics
    /// Panics if the buffer size exceeds `i32::MAX`, or if the retried call
    /// does not fill the buffer exactly.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: &LlamaChatTemplate,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        // Heuristic initial buffer: twice the combined role+content byte length.
        let message_length = chat.iter().fold(0, |acc, c| {
            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
        });
        let mut buff: Vec<u8> = vec![0; message_length * 2];

        // Build the FFI view of the messages. The CStrings inside `chat` (the
        // slice parameter) outlive this Vec of raw pointers, keeping them valid
        // for both FFI calls below.
        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
            .iter()
            .map(|c| llama_cpp_sys_2::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tmpl_ptr = tmpl.0.as_ptr();

        let res = unsafe {
            llama_cpp_sys_2::llama_chat_apply_template(
                tmpl_ptr,
                chat.as_ptr(),
                chat.len(),
                add_ass,
                buff.as_mut_ptr().cast::<c_char>(),
                buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
            )
        };

        // A return value larger than the buffer is llama.cpp reporting the
        // required size, so grow to exactly that and render again.
        if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
            buff.resize(res.try_into().expect("res is negative"), 0);

            let res = unsafe {
                llama_cpp_sys_2::llama_chat_apply_template(
                    tmpl_ptr,
                    chat.as_ptr(),
                    chat.len(),
                    add_ass,
                    buff.as_mut_ptr().cast::<c_char>(),
                    buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
                )
            };
            // The retry must fill the resized buffer exactly.
            assert_eq!(Ok(res), buff.len().try_into());
        }
        // NOTE(review): a negative `res` (template failure) reaches this
        // `try_into` and panics instead of returning an error — confirm whether
        // the C API can return negative values for a template stored in a model.
        buff.truncate(res.try_into().expect("res is negative"));
        Ok(String::from_utf8(buff)?)
    }
937
    /// Renders the messages through `tmpl` with optional OpenAI-style tool and
    /// JSON-schema constraints, returning the fully parsed
    /// [`ChatTemplateResult`] (prompt, grammar, triggers, stops, …).
    ///
    /// # Errors
    /// Fails on nul bytes in `tools_json`/`json_schema`, a non-OK FFI status,
    /// a null/invalid result field, or non-UTF-8 strings in the result.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_with_tools_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        messages: &[LlamaChatMessage],
        tools_json: Option<&str>,
        json_schema: Option<&str>,
        add_generation_prompt: bool,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        // FFI view of the messages; the CStrings in `messages` outlive this Vec
        // of raw pointers.
        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = messages
            .iter()
            .map(|c| llama_cpp_sys_2::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tools_cstr = tools_json.map(CString::new).transpose()?;
        let json_schema_cstr = json_schema.map(CString::new).transpose()?;

        // Out-parameter; the C side allocates the string/array fields, which
        // are released below with `llama_rs_chat_template_result_free`.
        let mut raw_result = llama_cpp_sys_2::llama_rs_chat_template_result {
            prompt: ptr::null_mut(),
            grammar: ptr::null_mut(),
            parser: ptr::null_mut(),
            chat_format: 0,
            thinking_forced_open: false,
            grammar_lazy: false,
            grammar_triggers: ptr::null_mut(),
            grammar_triggers_count: 0,
            preserved_tokens: ptr::null_mut(),
            preserved_tokens_count: 0,
            additional_stops: ptr::null_mut(),
            additional_stops_count: 0,
        };

        let rc = unsafe {
            llama_cpp_sys_2::llama_rs_apply_chat_template_with_tools_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                chat.as_ptr(),
                chat.len(),
                tools_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                json_schema_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                add_generation_prompt,
                &mut raw_result,
            )
        };

        // Parsing happens inside a closure so that every early `return Err(…)`
        // still falls through to the single `…_result_free` call below —
        // ensuring the C-side allocations are released exactly once.
        let result = (|| {
            if !status_is_ok(rc) {
                return Err(ApplyChatTemplateError::FfiError(status_to_i32(rc)));
            }
            if raw_result.prompt.is_null() {
                return Err(ApplyChatTemplateError::NullResult);
            }
            let prompt_bytes = unsafe { CStr::from_ptr(raw_result.prompt) }
                .to_bytes()
                .to_vec();
            let prompt = String::from_utf8(prompt_bytes)?;
            let grammar_lazy = raw_result.grammar_lazy;
            let grammar = if raw_result.grammar.is_null() {
                None
            } else {
                let grammar_bytes = unsafe { CStr::from_ptr(raw_result.grammar) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(grammar_bytes)?)
            };
            let parser = if raw_result.parser.is_null() {
                None
            } else {
                let parser_bytes = unsafe { CStr::from_ptr(raw_result.parser) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(parser_bytes)?)
            };
            // NOTE(review): `InvalidGrammarTriggerType` doubles as the error for
            // unexpected null pointers throughout this parse — a dedicated
            // variant would be clearer.
            let grammar_triggers = if raw_result.grammar_triggers_count == 0 {
                Vec::new()
            } else if raw_result.grammar_triggers.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let triggers = unsafe {
                    slice::from_raw_parts(
                        raw_result.grammar_triggers,
                        raw_result.grammar_triggers_count,
                    )
                };
                let mut parsed = Vec::with_capacity(triggers.len());
                for trigger in triggers {
                    // Integer discriminants must match `GrammarTriggerType`.
                    let trigger_type = match trigger.type_ {
                        0 => GrammarTriggerType::Token,
                        1 => GrammarTriggerType::Word,
                        2 => GrammarTriggerType::Pattern,
                        3 => GrammarTriggerType::PatternFull,
                        _ => return Err(ApplyChatTemplateError::InvalidGrammarTriggerType),
                    };
                    // NOTE(review): a null trigger value is an error here, but the
                    // oaicompat variant below maps it to an empty string — confirm
                    // which behavior is intended.
                    let value = if trigger.value.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    } else {
                        let bytes = unsafe { CStr::from_ptr(trigger.value) }.to_bytes().to_vec();
                        String::from_utf8(bytes)?
                    };
                    let token = if trigger_type == GrammarTriggerType::Token {
                        Some(LlamaToken(trigger.token))
                    } else {
                        None
                    };
                    parsed.push(GrammarTrigger {
                        trigger_type,
                        value,
                        token,
                    });
                }
                parsed
            };
            let preserved_tokens = if raw_result.preserved_tokens_count == 0 {
                Vec::new()
            } else if raw_result.preserved_tokens.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let tokens = unsafe {
                    slice::from_raw_parts(
                        raw_result.preserved_tokens,
                        raw_result.preserved_tokens_count,
                    )
                };
                let mut parsed = Vec::with_capacity(tokens.len());
                for token in tokens {
                    if token.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*token) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            let additional_stops = if raw_result.additional_stops_count == 0 {
                Vec::new()
            } else if raw_result.additional_stops.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let stops = unsafe {
                    slice::from_raw_parts(
                        raw_result.additional_stops,
                        raw_result.additional_stops_count,
                    )
                };
                let mut parsed = Vec::with_capacity(stops.len());
                for stop in stops {
                    if stop.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*stop) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            // Only parse tool calls when tools were actually supplied.
            let parse_tool_calls = tools_json.map_or(false, |tools| !tools.is_empty());
            Ok(ChatTemplateResult {
                prompt,
                grammar,
                grammar_lazy,
                grammar_triggers,
                preserved_tokens,
                additional_stops,
                chat_format: raw_result.chat_format,
                parser,
                thinking_forced_open: raw_result.thinking_forced_open,
                parse_tool_calls,
            })
        })();

        // Always release the C-side allocations, success or failure.
        unsafe { llama_cpp_sys_2::llama_rs_chat_template_result_free(&mut raw_result) };
        result
    }
1120
    /// Renders an OpenAI-compatible chat request (messages/tools/tool choice/
    /// schema/grammar as JSON strings) through `tmpl`, returning the fully
    /// parsed [`ChatTemplateResult`].
    ///
    /// # Errors
    /// Fails on nul bytes in any of the string parameters, a non-OK FFI
    /// status, a null/invalid result field, or non-UTF-8 strings in the result.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        params: &OpenAIChatTemplateParams<'_>,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        let parse_tool_calls = params.parse_tool_calls;
        // Convert every optional string parameter to a CString; all of them
        // must stay alive until the FFI call below returns.
        let messages_cstr = CString::new(params.messages_json)?;
        let tools_cstr = params.tools_json.map(CString::new).transpose()?;
        let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
        let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
        let grammar_cstr = params.grammar.map(CString::new).transpose()?;
        let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
        let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;

        // Out-parameter; the C side allocates the string/array fields, which
        // are released below with `llama_rs_chat_template_result_free`.
        let mut raw_result = llama_cpp_sys_2::llama_rs_chat_template_result {
            prompt: ptr::null_mut(),
            grammar: ptr::null_mut(),
            parser: ptr::null_mut(),
            chat_format: 0,
            thinking_forced_open: false,
            grammar_lazy: false,
            grammar_triggers: ptr::null_mut(),
            grammar_triggers_count: 0,
            preserved_tokens: ptr::null_mut(),
            preserved_tokens_count: 0,
            additional_stops: ptr::null_mut(),
            additional_stops_count: 0,
        };

        // Absent optional strings are passed as null pointers.
        let ffi_params = llama_cpp_sys_2::llama_rs_chat_template_oaicompat_params {
            messages: messages_cstr.as_ptr(),
            tools: tools_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            tool_choice: tool_choice_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            json_schema: json_schema_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            grammar: grammar_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            reasoning_format: reasoning_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            chat_template_kwargs: kwargs_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            add_generation_prompt: params.add_generation_prompt,
            use_jinja: params.use_jinja,
            parallel_tool_calls: params.parallel_tool_calls,
            enable_thinking: params.enable_thinking,
            add_bos: params.add_bos,
            add_eos: params.add_eos,
        };

        let rc = unsafe {
            llama_cpp_sys_2::llama_rs_apply_chat_template_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                &ffi_params,
                &mut raw_result,
            )
        };

        // Parse inside a closure so every early `return Err(…)` still reaches
        // the single `…_result_free` call below, freeing the C allocations
        // exactly once.
        let result = (|| {
            if !status_is_ok(rc) {
                return Err(ApplyChatTemplateError::FfiError(status_to_i32(rc)));
            }
            if raw_result.prompt.is_null() {
                return Err(ApplyChatTemplateError::NullResult);
            }
            let prompt_bytes = unsafe { CStr::from_ptr(raw_result.prompt) }
                .to_bytes()
                .to_vec();
            let prompt = String::from_utf8(prompt_bytes)?;
            let grammar_lazy = raw_result.grammar_lazy;
            let grammar = if raw_result.grammar.is_null() {
                None
            } else {
                let grammar_bytes = unsafe { CStr::from_ptr(raw_result.grammar) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(grammar_bytes)?)
            };
            let parser = if raw_result.parser.is_null() {
                None
            } else {
                let parser_bytes = unsafe { CStr::from_ptr(raw_result.parser) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(parser_bytes)?)
            };
            // NOTE(review): `InvalidGrammarTriggerType` doubles as the error for
            // unexpected null pointers throughout this parse.
            let grammar_triggers = if raw_result.grammar_triggers_count == 0 {
                Vec::new()
            } else if raw_result.grammar_triggers.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let triggers = unsafe {
                    slice::from_raw_parts(
                        raw_result.grammar_triggers,
                        raw_result.grammar_triggers_count,
                    )
                };
                let mut parsed = Vec::with_capacity(triggers.len());
                for trigger in triggers {
                    // Integer discriminants must match `GrammarTriggerType`.
                    let trigger_type = match trigger.type_ {
                        0 => GrammarTriggerType::Token,
                        1 => GrammarTriggerType::Word,
                        2 => GrammarTriggerType::Pattern,
                        3 => GrammarTriggerType::PatternFull,
                        _ => return Err(ApplyChatTemplateError::InvalidGrammarTriggerType),
                    };
                    // NOTE(review): here a null trigger value becomes an empty
                    // string, while the with_tools variant above errors — confirm
                    // which behavior is intended.
                    let value = if trigger.value.is_null() {
                        String::new()
                    } else {
                        let bytes = unsafe { CStr::from_ptr(trigger.value) }.to_bytes().to_vec();
                        String::from_utf8(bytes)?
                    };
                    let token = if trigger_type == GrammarTriggerType::Token {
                        Some(LlamaToken(trigger.token))
                    } else {
                        None
                    };
                    parsed.push(GrammarTrigger {
                        trigger_type,
                        value,
                        token,
                    });
                }
                parsed
            };
            let preserved_tokens = if raw_result.preserved_tokens_count == 0 {
                Vec::new()
            } else if raw_result.preserved_tokens.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let tokens = unsafe {
                    slice::from_raw_parts(
                        raw_result.preserved_tokens,
                        raw_result.preserved_tokens_count,
                    )
                };
                let mut parsed = Vec::with_capacity(tokens.len());
                for token in tokens {
                    if token.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*token) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            let additional_stops = if raw_result.additional_stops_count == 0 {
                Vec::new()
            } else if raw_result.additional_stops.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let stops = unsafe {
                    slice::from_raw_parts(
                        raw_result.additional_stops,
                        raw_result.additional_stops_count,
                    )
                };
                let mut parsed = Vec::with_capacity(stops.len());
                for stop in stops {
                    if stop.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*stop) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };

            Ok(ChatTemplateResult {
                prompt,
                grammar,
                grammar_lazy,
                grammar_triggers,
                preserved_tokens,
                additional_stops,
                chat_format: raw_result.chat_format,
                parser,
                thinking_forced_open: raw_result.thinking_forced_open,
                parse_tool_calls,
            })
        })();

        // Always release the C-side allocations, success or failure.
        unsafe { llama_cpp_sys_2::llama_rs_chat_template_result_free(&mut raw_result) };
        result
    }
1316}
1317
impl ChatTemplateResult {
    /// Parses raw model output into an OpenAI-compatible JSON message using the
    /// chat format/parser captured when the template was applied.
    ///
    /// `is_partial` signals that `text` is an incomplete (streaming) response.
    ///
    /// # Errors
    /// Fails on a nul byte in `text`/parser, a non-OK FFI status, a null
    /// result, or non-UTF-8 output.
    pub fn parse_response_oaicompat(
        &self,
        text: &str,
        is_partial: bool,
    ) -> Result<String, ChatParseError> {
        let text_cstr = CString::new(text)?;
        let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
        // Out-parameter allocated by the C side; freed below.
        let mut out_json: *mut c_char = ptr::null_mut();
        let rc = unsafe {
            llama_cpp_sys_2::llama_rs_chat_parse_to_oaicompat(
                text_cstr.as_ptr(),
                is_partial,
                self.chat_format,
                self.parse_tool_calls,
                parser_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                self.thinking_forced_open,
                &mut out_json,
            )
        };

        // Closure pattern: every early return still reaches the single
        // `llama_rs_string_free` below.
        let result = (|| {
            if !status_is_ok(rc) {
                return Err(ChatParseError::FfiError(status_to_i32(rc)));
            }
            if out_json.is_null() {
                return Err(ChatParseError::NullResult);
            }
            let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
            Ok(String::from_utf8(bytes)?)
        })();

        // NOTE(review): `out_json` may still be null here — assumes the C free
        // function tolerates null (like `free`); confirm in the shim.
        unsafe { llama_cpp_sys_2::llama_rs_string_free(out_json) };
        result
    }

    /// Creates an incremental (streaming) parse state configured from this
    /// template result.
    ///
    /// # Errors
    /// Fails on a nul byte in the parser string or a null state pointer.
    pub fn streaming_state_oaicompat(&self) -> Result<ChatParseStateOaicompat, ChatParseError> {
        let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
        let state = unsafe {
            llama_cpp_sys_2::llama_rs_chat_parse_state_init_oaicompat(
                self.chat_format,
                self.parse_tool_calls,
                parser_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                self.thinking_forced_open,
            )
        };
        let state = NonNull::new(state).ok_or(ChatParseError::NullResult)?;
        Ok(ChatParseStateOaicompat { state })
    }
}
1374
1375fn extract_meta_string<F>(c_function: F, capacity: usize) -> Result<String, MetaValError>
1381where
1382 F: Fn(*mut c_char, usize) -> i32,
1383{
1384 let mut buffer = vec![0u8; capacity];
1385
1386 let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
1388 if result < 0 {
1389 return Err(MetaValError::NegativeReturn(result));
1390 }
1391
1392 let returned_len = result as usize;
1394 if returned_len >= capacity {
1395 return extract_meta_string(c_function, returned_len + 1);
1397 }
1398
1399 debug_assert_eq!(
1401 buffer.get(returned_len),
1402 Some(&0),
1403 "should end with null byte"
1404 );
1405
1406 buffer.truncate(returned_len);
1408 Ok(String::from_utf8(buffer)?)
1409}
1410
impl Drop for LlamaModel {
    /// Releases the llama.cpp model allocation.
    fn drop(&mut self) {
        unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
    }
}
1416
/// Vocabulary types this crate recognizes; discriminants come straight from
/// the llama.cpp constants.
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
    /// Byte-pair encoding (GPT-style).
    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
    /// SentencePiece (LLaMA-style).
    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
}
1426
/// Error for FFI vocab-type values with no [`VocabType`] counterpart.
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
    /// The raw value did not match any known vocabulary type.
    #[error("Unknown Value {0}")]
    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
}
1434
1435impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
1436 type Error = LlamaTokenTypeFromIntError;
1437
1438 fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
1439 match value {
1440 llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
1441 llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
1442 unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
1443 }
1444 }
1445}