use std::ffi::{c_char, CStr, CString};
use std::num::NonZeroU16;
use std::os::raw::c_int;
use std::path::Path;
use std::ptr::{self, NonNull};
use std::slice;
use std::str::Utf8Error;

use crate::context::params::LlamaContextParams;
use crate::context::LlamaContext;
use crate::llama_backend::LlamaBackend;
use crate::model::params::LlamaModelParams;
use crate::openai::{ChatParseStateOaicompat, OpenAIChatTemplateParams};
use crate::token::LlamaToken;
use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
use crate::{
    status_is_ok, status_to_i32, ApplyChatTemplateError, ChatParseError, ChatTemplateError,
    LlamaContextLoadError, LlamaLoraAdapterInitError, LlamaModelLoadError, MetaValError,
    NewLlamaChatMessageError, StringToTokenError, TokenToStringError,
};

pub mod params;

/// A safe wrapper around `llama_model`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaModel {
    pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
}

/// A safe wrapper around `llama_adapter_lora`.
#[derive(Debug)]
#[repr(transparent)]
#[allow(clippy::module_name_repetitions)]
pub struct LlamaLoraAdapter {
    pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
}

/// A chat template, stored as a null-terminated string for cheap reuse across
/// FFI calls.
#[derive(Eq, PartialEq, Clone, PartialOrd, Ord, Hash)]
pub struct LlamaChatTemplate(CString);

impl LlamaChatTemplate {
    /// Creates a new chat template. Fails if the string contains an interior
    /// null byte.
    pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
        Ok(Self(CString::new(template)?))
    }

    /// Accesses the template as a C string.
    pub fn as_c_str(&self) -> &CStr {
        &self.0
    }

    /// Attempts to view the template as a `&str`.
    pub fn to_str(&self) -> Result<&str, Utf8Error> {
        self.0.to_str()
    }

    /// Attempts to convert the template into an owned `String`.
    pub fn to_string(&self) -> Result<String, Utf8Error> {
        self.to_str().map(str::to_string)
    }
}

impl std::fmt::Debug for LlamaChatTemplate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0.fmt(f)
    }
}

/// A single message in a chat conversation: a role (e.g. "user") and its
/// content.
#[derive(Debug, Eq, PartialEq, Clone)]
pub struct LlamaChatMessage {
    role: CString,
    content: CString,
}

impl LlamaChatMessage {
    /// Creates a new chat message. Fails if the role or content contains an
    /// interior null byte.
    pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
        Ok(Self {
            role: CString::new(role)?,
            content: CString::new(content)?,
        })
    }
}
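
// Usage sketch (not from the original source): building a short conversation
// to feed into `apply_chat_template` further below.
//
//     let chat = vec![
//         LlamaChatMessage::new("system".into(), "You are a helpful assistant.".into())?,
//         LlamaChatMessage::new("user".into(), "Hello!".into())?,
//     ];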

/// The type of a lazy grammar trigger. The discriminants mirror llama.cpp's
/// `common_grammar_trigger_type` values.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GrammarTriggerType {
    /// Triggered by a specific token id.
    Token = 0,
    /// Triggered by a literal word.
    Word = 1,
    /// Triggered by a regex pattern match.
    Pattern = 2,
    /// Triggered only when the pattern matches the full output.
    PatternFull = 3,
}

/// A trigger that activates a lazy grammar during sampling.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GrammarTrigger {
    /// How the trigger is matched.
    pub trigger_type: GrammarTriggerType,
    /// The word or pattern to match.
    pub value: String,
    /// The token id, present only for [`GrammarTriggerType::Token`] triggers.
    pub token: Option<LlamaToken>,
}

/// The result of applying a chat template: the rendered prompt plus the
/// grammar and parser state needed to constrain and parse the response.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatTemplateResult {
    /// The rendered prompt to feed to the model.
    pub prompt: String,
    /// An optional grammar constraining the model's output.
    pub grammar: Option<String>,
    /// Whether the grammar is applied lazily (only after a trigger fires).
    pub grammar_lazy: bool,
    /// Triggers that activate a lazy grammar.
    pub grammar_triggers: Vec<GrammarTrigger>,
    /// Tokens that must be preserved verbatim during sampling.
    pub preserved_tokens: Vec<String>,
    /// Extra stop strings requested by the template.
    pub additional_stops: Vec<String>,
    /// The chat format id reported by llama.cpp.
    pub chat_format: i32,
    /// An optional parser definition for extracting structured output.
    pub parser: Option<String>,
    /// Whether the template left a "thinking" block open.
    pub thinking_forced_open: bool,
    /// Whether tool calls should be parsed from the response.
    pub parse_tool_calls: bool,
}

/// The rotary position embedding (RoPE) type used by a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RopeType {
    Norm,
    NeoX,
    MRope,
    Vision,
}

/// Whether to prepend the beginning-of-stream (BOS) token when tokenizing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AddBos {
    /// Add the BOS token to the start of the string.
    Always,
    /// Do not add the BOS token.
    Never,
}

/// How to handle special tokens when converting between tokens and text.
#[deprecated(
    since = "0.1.0",
    note = "This enum conflates several llama.cpp options and offers less flexibility; it is only used by deprecated methods and will be removed in the future."
)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Special {
    /// Allow handling special and/or control tokens which would otherwise be
    /// treated as plaintext.
    Tokenize,
    /// Treat special and/or control tokens as plaintext.
    Plaintext,
}

// SAFETY: llama.cpp treats a loaded model as read-only, so sharing it across
// threads is sound.
unsafe impl Send for LlamaModel {}

unsafe impl Sync for LlamaModel {}

impl LlamaModel {
    /// Returns the vocabulary pointer used by token-level FFI calls.
    pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
        unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
    }

    /// Gets the context size the model was trained with.
    #[must_use]
    pub fn n_ctx_train(&self) -> u32 {
        let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
        u32::try_from(n_ctx_train).expect("n_ctx_train fits into a u32")
    }

    /// Iterates over all tokens in the vocabulary together with their decoded
    /// string pieces.
    pub fn tokens(
        &self,
        decode_special: bool,
    ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
        (0..self.n_vocab())
            .map(LlamaToken::new)
            .map(move |llama_token| {
                let mut decoder = encoding_rs::UTF_8.new_decoder();
                (
                    llama_token,
                    self.token_to_piece(llama_token, &mut decoder, decode_special, None),
                )
            })
    }
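
    // Usage sketch (assumes a loaded `model`; not from the original source):
    // dump every vocabulary entry, skipping pieces that fail to decode.
    //
    //     for (token, piece) in model.tokens(false) {
    //         if let Ok(piece) = piece {
    //             println!("{token:?} -> {piece:?}");
    //         }
    //     }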

    /// Gets the beginning-of-stream (BOS) token.
    #[must_use]
    pub fn token_bos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Gets the end-of-stream (EOS) token.
    #[must_use]
    pub fn token_eos(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Gets the newline token.
    #[must_use]
    pub fn token_nl(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Checks whether a token marks the end of generation (end of turn, end
    /// of sequence, etc.).
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
    }

    /// Gets the decoder start token for encoder-decoder models.
    #[must_use]
    pub fn decode_start_token(&self) -> LlamaToken {
        let token =
            unsafe { llama_cpp_sys_2::llama_model_decoder_start_token(self.model.as_ptr()) };
        LlamaToken(token)
    }

    /// Gets the separator (SEP) token.
    #[must_use]
    pub fn token_sep(&self) -> LlamaToken {
        let token = unsafe { llama_cpp_sys_2::llama_vocab_sep(self.vocab_ptr()) };
        LlamaToken(token)
    }

    /// Converts a single token to a string.
    #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
    pub fn token_to_str(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let mut decoder = encoding_rs::UTF_8.new_decoder();
        Ok(self.token_to_piece(
            token,
            &mut decoder,
            matches!(special, Special::Tokenize),
            None,
        )?)
    }

    /// Converts a single token to its raw bytes.
    #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
    pub fn token_to_bytes(
        &self,
        token: LlamaToken,
        special: Special,
    ) -> Result<Vec<u8>, TokenToStringError> {
        // Try a small buffer first; llama.cpp reports the required size as a
        // negative number, so retry once with exactly that capacity.
        match self.token_to_piece_bytes(token, 8, matches!(special, Special::Tokenize), None) {
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
                token,
                (-i).try_into().expect("required buffer size is positive"),
                matches!(special, Special::Tokenize),
                None,
            ),
            x => x,
        }
    }

    /// Converts a slice of tokens to a single string.
    #[deprecated(
        since = "0.1.0",
        note = "Use `token_to_piece` for each token individually instead"
    )]
    pub fn tokens_to_str(
        &self,
        tokens: &[LlamaToken],
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let mut builder: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
        for token in tokens.iter().copied() {
            // Retry with the reported size when the initial 8-byte buffer is
            // too small, rather than failing on longer pieces.
            let piece = match self.token_to_piece_bytes(
                token,
                8,
                matches!(special, Special::Tokenize),
                None,
            ) {
                Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
                    token,
                    (-i).try_into().expect("required buffer size is positive"),
                    matches!(special, Special::Tokenize),
                    None,
                ),
                x => x,
            }?;
            builder.extend_from_slice(&piece);
        }
        Ok(String::from_utf8(builder)?)
    }

    /// Converts a string to a vector of tokens.
    ///
    /// # Errors
    ///
    /// Fails if the string contains an interior null byte or its length does
    /// not fit into a `c_int`.
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Rough estimate: about one token per two bytes of input, plus the
        // optional BOS token.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        let size = unsafe {
            llama_cpp_sys_2::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // A negative return value is the negated number of tokens required;
        // grow the buffer to that size and tokenize again.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("required size fits into usize"));
            unsafe {
                llama_cpp_sys_2::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and fits into usize");

        // SAFETY: `llama_tokenize` initialized exactly `size` tokens.
        unsafe { buffer.set_len(size) }
        Ok(buffer)
    }
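
    // Usage sketch (assumes a loaded `model`; not from the original source):
    //
    //     let tokens = model.str_to_token("Hello, world!", AddBos::Always)?;
    //     println!("tokenized into {} tokens", tokens.len());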

    /// Gets the attributes of a token.
    #[must_use]
    pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
        let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
        LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
    }

    /// Converts a single token to its string piece, using `decoder` to handle
    /// multi-byte UTF-8 sequences that may span several tokens.
    pub fn token_to_piece(
        &self,
        token: LlamaToken,
        decoder: &mut encoding_rs::Decoder,
        special: bool,
        lstrip: Option<NonZeroU16>,
    ) -> Result<String, TokenToStringError> {
        // Try a small buffer first; retry with the exact size llama.cpp
        // reports if it was insufficient.
        let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
            Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
                token,
                (-i).try_into().expect("required buffer size is positive"),
                special,
                lstrip,
            ),
            x => x,
        }?;
        let mut output_piece = String::with_capacity(bytes.len());
        let (_result, _bytes_read, _had_errors) =
            decoder.decode_to_string(&bytes, &mut output_piece, false);
        Ok(output_piece)
    }
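
    // Usage sketch (assumes a loaded `model` and `tokens` from `str_to_token`;
    // not from the original source). A single decoder must be reused across
    // tokens so multi-byte UTF-8 sequences split across pieces decode
    // correctly.
    //
    //     let mut decoder = encoding_rs::UTF_8.new_decoder();
    //     let mut text = String::new();
    //     for token in &tokens {
    //         text += &model.token_to_piece(*token, &mut decoder, false, None)?;
    //     }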

    /// Converts a single token to its raw bytes, using a caller-provided
    /// `buffer_size`.
    ///
    /// # Errors
    ///
    /// Returns [`TokenToStringError::InsufficientBufferSpace`] (carrying the
    /// negated required size) if `buffer_size` is too small.
    pub fn token_to_piece_bytes(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: bool,
        lstrip: Option<NonZeroU16>,
    ) -> Result<Vec<u8>, TokenToStringError> {
        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
        let len = string.as_bytes().len();
        let len = c_int::try_from(len).expect("length fits into c_int");
        let buf = string.into_raw();
        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
        let size = unsafe {
            llama_cpp_sys_2::llama_token_to_piece(
                self.vocab_ptr(),
                token.0,
                buf,
                len,
                lstrip,
                special,
            )
        };

        match size {
            0 => Err(TokenToStringError::UnknownTokenType),
            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
            size => {
                // Reclaim ownership of the buffer and truncate it to the
                // number of bytes actually written.
                let string = unsafe { CString::from_raw(buf) };
                let mut bytes = string.into_bytes();
                let len = usize::try_from(size).expect("size is positive and fits into usize");
                bytes.truncate(len);
                Ok(bytes)
            }
        }
    }

    /// Converts a single token to a string, with a caller-provided buffer size.
    #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
    pub fn token_to_str_with_size(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: Special,
    ) -> Result<String, TokenToStringError> {
        let bytes = self.token_to_piece_bytes(
            token,
            buffer_size,
            matches!(special, Special::Tokenize),
            None,
        )?;
        Ok(String::from_utf8(bytes)?)
    }

    /// Converts a single token to bytes, with a caller-provided buffer size.
    /// Returns `b"\n"` for the newline token and an empty vector for tokens
    /// with no printable representation.
    #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
    pub fn token_to_bytes_with_size(
        &self,
        token: LlamaToken,
        buffer_size: usize,
        special: Special,
        lstrip: Option<NonZeroU16>,
    ) -> Result<Vec<u8>, TokenToStringError> {
        if token == self.token_nl() {
            return Ok(b"\n".to_vec());
        }

        // Unknown, byte, unused, and BOS/EOS control tokens have no printable
        // representation.
        let attrs = self.token_attr(token);
        if attrs.is_empty()
            || attrs
                .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
            || attrs.contains(LlamaTokenAttr::Control)
                && (token == self.token_bos() || token == self.token_eos())
        {
            return Ok(Vec::new());
        }

        let special = match special {
            Special::Tokenize => true,
            Special::Plaintext => false,
        };

        let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
        let len = string.as_bytes().len();
        let len = c_int::try_from(len).expect("length fits into c_int");
        let buf = string.into_raw();
        let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
        let size = unsafe {
            llama_cpp_sys_2::llama_token_to_piece(
                self.vocab_ptr(),
                token.0,
                buf,
                len,
                lstrip,
                special,
            )
        };

        match size {
            0 => Err(TokenToStringError::UnknownTokenType),
            i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
            size => {
                let string = unsafe { CString::from_raw(buf) };
                let mut bytes = string.into_bytes();
                let len = usize::try_from(size).expect("size is positive and fits into usize");
                bytes.truncate(len);
                Ok(bytes)
            }
        }
    }

    /// Gets the number of tokens in the model's vocabulary.
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
    }

    /// Gets the type of the model's vocabulary.
    #[must_use]
    pub fn vocab_type(&self) -> VocabType {
        let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
        VocabType::try_from(vocab_type).expect("invalid vocab type")
    }

    /// Gets the embedding dimension of the model.
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
    }

    /// Gets the size of the model in bytes.
    #[must_use]
    pub fn size(&self) -> u64 {
        unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
    }

    /// Gets the number of parameters in the model.
    #[must_use]
    pub fn n_params(&self) -> u64 {
        unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
    }

    /// Returns whether the model is recurrent (e.g. Mamba or RWKV).
    #[must_use]
    pub fn is_recurrent(&self) -> bool {
        unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
    }

    /// Gets the number of layers in the model.
    #[must_use]
    pub fn n_layer(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
    }

    /// Gets the number of attention heads.
    #[must_use]
    pub fn n_head(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
    }

    /// Gets the number of key/value heads.
    #[must_use]
    pub fn n_head_kv(&self) -> u32 {
        u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
            .unwrap()
    }

    /// Gets a metadata value as a string, by key name.
    pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
        let key_cstring = CString::new(key)?;
        let key_ptr = key_cstring.as_ptr();

        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str(
                    self.model.as_ptr(),
                    key_ptr,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Gets the number of metadata key/value pairs.
    pub fn meta_count(&self) -> i32 {
        unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) }
    }

    /// Gets a metadata key name by index.
    pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_key_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }

    /// Gets a metadata value as a string, by index.
    pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
        extract_meta_string(
            |buf_ptr, buf_len| unsafe {
                llama_cpp_sys_2::llama_model_meta_val_str_by_index(
                    self.model.as_ptr(),
                    index,
                    buf_ptr,
                    buf_len,
                )
            },
            256,
        )
    }
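
    // Usage sketch (assumes a loaded `model`; not from the original source):
    // enumerate all GGUF metadata key/value pairs.
    //
    //     for i in 0..model.meta_count() {
    //         let key = model.meta_key_by_index(i)?;
    //         let val = model.meta_val_str_by_index(i)?;
    //         println!("{key} = {val}");
    //     }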

    /// Gets the RoPE type used by the model, or `None` if it uses none.
    pub fn rope_type(&self) -> Option<RopeType> {
        match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
            llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
            rope_type => {
                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
                None
            }
        }
    }

    /// Gets the chat template embedded in the model's metadata. Pass `None`
    /// for the default template, or a name to select an alternative one.
    pub fn chat_template(
        &self,
        name: Option<&str>,
    ) -> Result<LlamaChatTemplate, ChatTemplateError> {
        let name_cstr = name.map(CString::new);
        // Borrow rather than move, so the CString stays alive for the FFI
        // call below; matching by value would drop it and leave `name_ptr`
        // dangling.
        let name_ptr = match &name_cstr {
            Some(Ok(name)) => name.as_ptr(),
            _ => std::ptr::null(),
        };
        let result =
            unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) };

        if result.is_null() {
            Err(ChatTemplateError::MissingTemplate)
        } else {
            // Copy the template out of the model's memory so it can outlive
            // the model.
            let chat_template_cstr = unsafe { CStr::from_ptr(result) };
            let chat_template = CString::new(chat_template_cstr.to_bytes())?;
            Ok(LlamaChatTemplate(chat_template))
        }
    }

    /// Loads a model from a file.
    ///
    /// # Errors
    ///
    /// Fails if the path cannot be converted to a string, contains an
    /// interior null byte, or llama.cpp fails to load the model.
    #[tracing::instrument(skip_all, fields(params))]
    pub fn load_from_file(
        _: &LlamaBackend,
        path: impl AsRef<Path>,
        params: &LlamaModelParams,
    ) -> Result<Self, LlamaModelLoadError> {
        let path = path.as_ref();
        debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
        let path = path
            .to_str()
            .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;

        let cstr = CString::new(path)?;
        let llama_model =
            unsafe { llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) };

        let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;

        tracing::debug!(?path, "Loaded model");
        Ok(LlamaModel { model })
    }
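
    // Usage sketch (not from the original source; "model.gguf" is a
    // hypothetical path and `LlamaBackend::init` is assumed to be this
    // crate's backend entry point):
    //
    //     let backend = LlamaBackend::init()?;
    //     let params = LlamaModelParams::default();
    //     let model = LlamaModel::load_from_file(&backend, "model.gguf", &params)?;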

    /// Initializes a LoRA adapter from a file for use with this model.
    pub fn lora_adapter_init(
        &self,
        path: impl AsRef<Path>,
    ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
        let path = path.as_ref();
        debug_assert!(Path::new(path).exists(), "{path:?} does not exist");

        let path = path
            .to_str()
            .ok_or(LlamaLoraAdapterInitError::PathToStrError(
                path.to_path_buf(),
            ))?;

        let cstr = CString::new(path)?;
        let adapter =
            unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };

        let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;

        tracing::debug!(?path, "Initialized lora adapter");
        Ok(LlamaLoraAdapter {
            lora_adapter: adapter,
        })
    }

    /// Creates a new context for this model. The context borrows the model,
    /// so it cannot outlive it.
    #[allow(clippy::needless_pass_by_value)]
    pub fn new_context<'a>(
        &'a self,
        _: &LlamaBackend,
        params: LlamaContextParams,
    ) -> Result<LlamaContext<'a>, LlamaContextLoadError> {
        let context_params = params.context_params;
        let context = unsafe {
            llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
        };
        let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;

        Ok(LlamaContext::new(self, context, params.embeddings()))
    }
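
    // Usage sketch (continues the loading example above; not from the
    // original source, and `with_n_ctx` is assumed from this crate's
    // `LlamaContextParams` builder, taking an `Option<std::num::NonZeroU32>`):
    //
    //     let ctx_params = LlamaContextParams::default()
    //         .with_n_ctx(NonZeroU32::new(2048));
    //     let mut ctx = model.new_context(&backend, ctx_params)?;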

    /// Applies the model's chat template to the given messages and returns
    /// the rendered prompt. If `add_ass` is true, the assistant turn is
    /// opened at the end of the prompt.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template(
        &self,
        tmpl: &LlamaChatTemplate,
        chat: &[LlamaChatMessage],
        add_ass: bool,
    ) -> Result<String, ApplyChatTemplateError> {
        // Guess an output buffer size; llama.cpp reports the required size if
        // this turns out to be too small.
        let message_length = chat.iter().fold(0, |acc, c| {
            acc + c.role.to_bytes().len() + c.content.to_bytes().len()
        });
        let mut buff: Vec<u8> = vec![0; message_length * 2];

        // Build the FFI view of the messages; the pointers borrow from
        // `chat`, which outlives the FFI calls below.
        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
            .iter()
            .map(|c| llama_cpp_sys_2::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tmpl_ptr = tmpl.0.as_ptr();

        let res = unsafe {
            llama_cpp_sys_2::llama_chat_apply_template(
                tmpl_ptr,
                chat.as_ptr(),
                chat.len(),
                add_ass,
                buff.as_mut_ptr().cast::<c_char>(),
                buff.len().try_into().expect("buffer length fits into i32"),
            )
        };

        // A return value larger than the buffer is the required size; resize
        // and render again.
        if res > buff.len().try_into().expect("buffer length fits into i32") {
            buff.resize(res.try_into().expect("res is non-negative"), 0);

            let res = unsafe {
                llama_cpp_sys_2::llama_chat_apply_template(
                    tmpl_ptr,
                    chat.as_ptr(),
                    chat.len(),
                    add_ass,
                    buff.as_mut_ptr().cast::<c_char>(),
                    buff.len().try_into().expect("buffer length fits into i32"),
                )
            };
            assert_eq!(Ok(res), buff.len().try_into());
        }
        buff.truncate(res.try_into().expect("res is non-negative"));
        Ok(String::from_utf8(buff)?)
    }
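
    // Usage sketch (assumes `model` and the `chat` built earlier; not from
    // the original source):
    //
    //     let tmpl = model.chat_template(None)?;
    //     let prompt = model.apply_chat_template(&tmpl, &chat, true)?;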

    /// Applies the model's chat template with optional OpenAI-style tool
    /// definitions and/or a JSON schema, returning the rendered prompt plus
    /// the grammar, triggers, and parser state needed to constrain and parse
    /// the model's reply.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_with_tools_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        messages: &[LlamaChatMessage],
        tools_json: Option<&str>,
        json_schema: Option<&str>,
        add_generation_prompt: bool,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        let chat: Vec<llama_cpp_sys_2::llama_chat_message> = messages
            .iter()
            .map(|c| llama_cpp_sys_2::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tools_cstr = tools_json.map(CString::new).transpose()?;
        let json_schema_cstr = json_schema.map(CString::new).transpose()?;

        let mut raw_result = llama_cpp_sys_2::llama_rs_chat_template_result {
            prompt: ptr::null_mut(),
            grammar: ptr::null_mut(),
            parser: ptr::null_mut(),
            chat_format: 0,
            thinking_forced_open: false,
            grammar_lazy: false,
            grammar_triggers: ptr::null_mut(),
            grammar_triggers_count: 0,
            preserved_tokens: ptr::null_mut(),
            preserved_tokens_count: 0,
            additional_stops: ptr::null_mut(),
            additional_stops_count: 0,
        };

        let rc = unsafe {
            llama_cpp_sys_2::llama_rs_apply_chat_template_with_tools_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                chat.as_ptr(),
                chat.len(),
                tools_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                json_schema_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                add_generation_prompt,
                &mut raw_result,
            )
        };

        // Parse inside a closure so the C-side allocations in `raw_result`
        // are freed exactly once on every path, success or error.
        let result = (|| {
            if !status_is_ok(rc) {
                return Err(ApplyChatTemplateError::FfiError(status_to_i32(rc)));
            }
            if raw_result.prompt.is_null() {
                return Err(ApplyChatTemplateError::NullResult);
            }
            let prompt_bytes = unsafe { CStr::from_ptr(raw_result.prompt) }
                .to_bytes()
                .to_vec();
            let prompt = String::from_utf8(prompt_bytes)?;
            let grammar_lazy = raw_result.grammar_lazy;
            let grammar = if raw_result.grammar.is_null() {
                None
            } else {
                let grammar_bytes = unsafe { CStr::from_ptr(raw_result.grammar) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(grammar_bytes)?)
            };
            let parser = if raw_result.parser.is_null() {
                None
            } else {
                let parser_bytes = unsafe { CStr::from_ptr(raw_result.parser) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(parser_bytes)?)
            };
            let grammar_triggers = if raw_result.grammar_triggers_count == 0 {
                Vec::new()
            } else if raw_result.grammar_triggers.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let triggers = unsafe {
                    slice::from_raw_parts(
                        raw_result.grammar_triggers,
                        raw_result.grammar_triggers_count,
                    )
                };
                let mut parsed = Vec::with_capacity(triggers.len());
                for trigger in triggers {
                    let trigger_type = match trigger.type_ {
                        0 => GrammarTriggerType::Token,
                        1 => GrammarTriggerType::Word,
                        2 => GrammarTriggerType::Pattern,
                        3 => GrammarTriggerType::PatternFull,
                        _ => return Err(ApplyChatTemplateError::InvalidGrammarTriggerType),
                    };
                    let value = if trigger.value.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    } else {
                        let bytes = unsafe { CStr::from_ptr(trigger.value) }.to_bytes().to_vec();
                        String::from_utf8(bytes)?
                    };
                    let token = if trigger_type == GrammarTriggerType::Token {
                        Some(LlamaToken(trigger.token))
                    } else {
                        None
                    };
                    parsed.push(GrammarTrigger {
                        trigger_type,
                        value,
                        token,
                    });
                }
                parsed
            };
            let preserved_tokens = if raw_result.preserved_tokens_count == 0 {
                Vec::new()
            } else if raw_result.preserved_tokens.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let tokens = unsafe {
                    slice::from_raw_parts(
                        raw_result.preserved_tokens,
                        raw_result.preserved_tokens_count,
                    )
                };
                let mut parsed = Vec::with_capacity(tokens.len());
                for token in tokens {
                    if token.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*token) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            let additional_stops = if raw_result.additional_stops_count == 0 {
                Vec::new()
            } else if raw_result.additional_stops.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let stops = unsafe {
                    slice::from_raw_parts(
                        raw_result.additional_stops,
                        raw_result.additional_stops_count,
                    )
                };
                let mut parsed = Vec::with_capacity(stops.len());
                for stop in stops {
                    if stop.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*stop) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            // Only parse tool calls out of the reply when tools were provided.
            let parse_tool_calls = tools_json.map_or(false, |tools| !tools.is_empty());
            Ok(ChatTemplateResult {
                prompt,
                grammar,
                grammar_lazy,
                grammar_triggers,
                preserved_tokens,
                additional_stops,
                chat_format: raw_result.chat_format,
                parser,
                thinking_forced_open: raw_result.thinking_forced_open,
                parse_tool_calls,
            })
        })();

        unsafe { llama_cpp_sys_2::llama_rs_chat_template_result_free(&mut raw_result) };
        result
    }
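
    // Usage sketch (not from the original source; the tool schema below is a
    // hypothetical OpenAI-style definition):
    //
    //     let tools = r#"[{"type":"function","function":{"name":"get_weather",
    //         "parameters":{"type":"object","properties":{"city":{"type":"string"}}}}}]"#;
    //     let result = model.apply_chat_template_with_tools_oaicompat(
    //         &tmpl, &chat, Some(tools), None, true,
    //     )?;
    //     if let Some(grammar) = &result.grammar {
    //         // feed `grammar` (and `result.grammar_triggers`) to the sampler
    //     }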

    /// Applies the model's chat template from OpenAI-compatible JSON inputs
    /// (messages, tools, tool choice, reasoning format, template kwargs),
    /// returning the rendered prompt plus grammar and parser state.
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        params: &OpenAIChatTemplateParams<'_>,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        let parse_tool_calls = params.parse_tool_calls;
        let messages_cstr = CString::new(params.messages_json)?;
        let tools_cstr = params.tools_json.map(CString::new).transpose()?;
        let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
        let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
        let grammar_cstr = params.grammar.map(CString::new).transpose()?;
        let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
        let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;

        let mut raw_result = llama_cpp_sys_2::llama_rs_chat_template_result {
            prompt: ptr::null_mut(),
            grammar: ptr::null_mut(),
            parser: ptr::null_mut(),
            chat_format: 0,
            thinking_forced_open: false,
            grammar_lazy: false,
            grammar_triggers: ptr::null_mut(),
            grammar_triggers_count: 0,
            preserved_tokens: ptr::null_mut(),
            preserved_tokens_count: 0,
            additional_stops: ptr::null_mut(),
            additional_stops_count: 0,
        };

        let ffi_params = llama_cpp_sys_2::llama_rs_chat_template_oaicompat_params {
            messages: messages_cstr.as_ptr(),
            tools: tools_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            tool_choice: tool_choice_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            json_schema: json_schema_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            grammar: grammar_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            reasoning_format: reasoning_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            chat_template_kwargs: kwargs_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            add_generation_prompt: params.add_generation_prompt,
            use_jinja: params.use_jinja,
            parallel_tool_calls: params.parallel_tool_calls,
            enable_thinking: params.enable_thinking,
            add_bos: params.add_bos,
            add_eos: params.add_eos,
        };

        let rc = unsafe {
            llama_cpp_sys_2::llama_rs_apply_chat_template_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                &ffi_params,
                &mut raw_result,
            )
        };

        // As above, parse inside a closure so the C-side allocations are
        // freed exactly once on every path.
        let result = (|| {
            if !status_is_ok(rc) {
                return Err(ApplyChatTemplateError::FfiError(status_to_i32(rc)));
            }
            if raw_result.prompt.is_null() {
                return Err(ApplyChatTemplateError::NullResult);
            }
            let prompt_bytes = unsafe { CStr::from_ptr(raw_result.prompt) }
                .to_bytes()
                .to_vec();
            let prompt = String::from_utf8(prompt_bytes)?;
            let grammar_lazy = raw_result.grammar_lazy;
            let grammar = if raw_result.grammar.is_null() {
                None
            } else {
                let grammar_bytes = unsafe { CStr::from_ptr(raw_result.grammar) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(grammar_bytes)?)
            };
            let parser = if raw_result.parser.is_null() {
                None
            } else {
                let parser_bytes = unsafe { CStr::from_ptr(raw_result.parser) }
                    .to_bytes()
                    .to_vec();
                Some(String::from_utf8(parser_bytes)?)
            };
            let grammar_triggers = if raw_result.grammar_triggers_count == 0 {
                Vec::new()
            } else if raw_result.grammar_triggers.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let triggers = unsafe {
                    slice::from_raw_parts(
                        raw_result.grammar_triggers,
                        raw_result.grammar_triggers_count,
                    )
                };
                let mut parsed = Vec::with_capacity(triggers.len());
                for trigger in triggers {
                    let trigger_type = match trigger.type_ {
                        0 => GrammarTriggerType::Token,
                        1 => GrammarTriggerType::Word,
                        2 => GrammarTriggerType::Pattern,
                        3 => GrammarTriggerType::PatternFull,
                        _ => return Err(ApplyChatTemplateError::InvalidGrammarTriggerType),
                    };
                    let value = if trigger.value.is_null() {
                        String::new()
                    } else {
                        let bytes = unsafe { CStr::from_ptr(trigger.value) }.to_bytes().to_vec();
                        String::from_utf8(bytes)?
                    };
                    let token = if trigger_type == GrammarTriggerType::Token {
                        Some(LlamaToken(trigger.token))
                    } else {
                        None
                    };
                    parsed.push(GrammarTrigger {
                        trigger_type,
                        value,
                        token,
                    });
                }
                parsed
            };
            let preserved_tokens = if raw_result.preserved_tokens_count == 0 {
                Vec::new()
            } else if raw_result.preserved_tokens.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let tokens = unsafe {
                    slice::from_raw_parts(
                        raw_result.preserved_tokens,
                        raw_result.preserved_tokens_count,
                    )
                };
                let mut parsed = Vec::with_capacity(tokens.len());
                for token in tokens {
                    if token.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*token) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };
            let additional_stops = if raw_result.additional_stops_count == 0 {
                Vec::new()
            } else if raw_result.additional_stops.is_null() {
                return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
            } else {
                let stops = unsafe {
                    slice::from_raw_parts(
                        raw_result.additional_stops,
                        raw_result.additional_stops_count,
                    )
                };
                let mut parsed = Vec::with_capacity(stops.len());
                for stop in stops {
                    if stop.is_null() {
                        return Err(ApplyChatTemplateError::InvalidGrammarTriggerType);
                    }
                    let bytes = unsafe { CStr::from_ptr(*stop) }.to_bytes().to_vec();
                    parsed.push(String::from_utf8(bytes)?);
                }
                parsed
            };

            Ok(ChatTemplateResult {
                prompt,
                grammar,
                grammar_lazy,
                grammar_triggers,
                preserved_tokens,
                additional_stops,
                chat_format: raw_result.chat_format,
                parser,
                thinking_forced_open: raw_result.thinking_forced_open,
                parse_tool_calls,
            })
        })();

        unsafe { llama_cpp_sys_2::llama_rs_chat_template_result_free(&mut raw_result) };
        result
    }
}

impl ChatTemplateResult {
    /// Parses model output into an OpenAI-compatible response message,
    /// returned as a JSON string. Set `is_partial` when parsing an incomplete
    /// (still-streaming) response.
    pub fn parse_response_oaicompat(
        &self,
        text: &str,
        is_partial: bool,
    ) -> Result<String, ChatParseError> {
        let text_cstr = CString::new(text)?;
        let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
        let mut out_json: *mut c_char = ptr::null_mut();
        let rc = unsafe {
            llama_cpp_sys_2::llama_rs_chat_parse_to_oaicompat(
                text_cstr.as_ptr(),
                is_partial,
                self.chat_format,
                self.parse_tool_calls,
                parser_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                self.thinking_forced_open,
                &mut out_json,
            )
        };

        // Copy the result out inside a closure so the C-side string is freed
        // exactly once on every path.
        let result = (|| {
            if !status_is_ok(rc) {
                return Err(ChatParseError::FfiError(status_to_i32(rc)));
            }
            if out_json.is_null() {
                return Err(ChatParseError::NullResult);
            }
            let bytes = unsafe { CStr::from_ptr(out_json) }.to_bytes().to_vec();
            Ok(String::from_utf8(bytes)?)
        })();

        unsafe { llama_cpp_sys_2::llama_rs_string_free(out_json) };
        result
    }
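
    // Usage sketch (assumes `result` from one of the template calls above and
    // `generated` text produced by the model; not from the original source):
    //
    //     let parsed_json = result.parse_response_oaicompat(&generated, false)?;
    //     // `parsed_json` is an OpenAI-compatible message object as a JSON string.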

    /// Creates a stateful parser for incrementally parsing a streamed
    /// response into OpenAI-compatible output.
    pub fn streaming_state_oaicompat(&self) -> Result<ChatParseStateOaicompat, ChatParseError> {
        let parser_cstr = self.parser.as_deref().map(CString::new).transpose()?;
        let state = unsafe {
            llama_cpp_sys_2::llama_rs_chat_parse_state_init_oaicompat(
                self.chat_format,
                self.parse_tool_calls,
                parser_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                self.thinking_forced_open,
            )
        };
        let state = NonNull::new(state).ok_or(ChatParseError::NullResult)?;
        Ok(ChatParseStateOaicompat { state })
    }
}

/// Calls a llama.cpp metadata getter that follows the "write into a buffer,
/// return the required length" convention, growing the buffer and retrying
/// if the initial `capacity` was too small.
fn extract_meta_string<F>(c_function: F, capacity: usize) -> Result<String, MetaValError>
where
    F: Fn(*mut c_char, usize) -> i32,
{
    let mut buffer = vec![0u8; capacity];

    let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
    if result < 0 {
        return Err(MetaValError::NegativeReturn(result));
    }

    // A return value >= capacity means the value was truncated; retry with a
    // buffer large enough for the value plus its null terminator.
    let returned_len = result as usize;
    if returned_len >= capacity {
        return extract_meta_string(c_function, returned_len + 1);
    }

    debug_assert_eq!(
        buffer.get(returned_len),
        Some(&0),
        "should end with null byte"
    );

    buffer.truncate(returned_len);
    Ok(String::from_utf8(buffer)?)
}

impl Drop for LlamaModel {
    fn drop(&mut self) {
        unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
    }
}

/// The vocabulary type used by a model.
#[repr(u32)]
#[derive(Debug, Eq, Copy, Clone, PartialEq)]
pub enum VocabType {
    /// Byte pair encoding.
    BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
    /// SentencePiece.
    SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
}

/// An error that can occur when converting a raw `llama_vocab_type` into a
/// [`VocabType`].
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum LlamaTokenTypeFromIntError {
    /// The value was not a recognized vocabulary type.
    #[error("Unknown Value {0}")]
    UnknownValue(llama_cpp_sys_2::llama_vocab_type),
}

impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
    type Error = LlamaTokenTypeFromIntError;

    fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
        match value {
            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
            llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
            unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
        }
    }
}