1use std::ffi::{CStr, CString, c_char};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6use std::ptr::{self, NonNull};
7
8use crate::context::LlamaContext;
9use crate::context::params::LlamaContextParams;
10use crate::llama_backend::LlamaBackend;
11use crate::openai::OpenAIChatTemplateParams;
12use crate::token::LlamaToken;
13use crate::token_type::LlamaTokenAttrs;
14use crate::{
15 ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
16 LlamaModelLoadError, MetaValError, StringToTokenError, TokenToStringError,
17};
18
19pub mod add_bos;
20pub mod chat_template_result;
21pub mod grammar_trigger;
22pub mod llama_chat_message;
23pub mod llama_chat_template;
24pub mod llama_lora_adapter;
25pub mod params;
26pub mod rope_type;
27pub mod split_mode;
28pub mod vocab_type;
29
30pub use add_bos::AddBos;
31pub use chat_template_result::ChatTemplateResult;
32pub use grammar_trigger::{GrammarTrigger, GrammarTriggerType};
33pub use llama_chat_message::LlamaChatMessage;
34pub use llama_chat_template::LlamaChatTemplate;
35pub use llama_lora_adapter::LlamaLoraAdapter;
36pub use rope_type::RopeType;
37pub use vocab_type::{LlamaTokenTypeFromIntError, VocabType};
38
39use chat_template_result::{new_empty_chat_template_raw_result, parse_chat_template_raw_result};
40use params::LlamaModelParams;
41
/// Safe owning wrapper around a llama.cpp `llama_model` handle.
///
/// The model is freed via `llama_free_model` when this wrapper is dropped.
#[derive(Debug)]
#[repr(transparent)]
pub struct LlamaModel {
    // Non-null pointer to the C-side model; ownership is exclusive to this
    // wrapper (released in `Drop`).
    pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
}
49
// SAFETY: NOTE(review): assumes a loaded llama.cpp model may be moved to and
// read from other threads (this wrapper exposes only read-style FFI calls) —
// confirm against llama.cpp's documented threading guarantees.
unsafe impl Send for LlamaModel {}

// SAFETY: NOTE(review): same assumption as `Send` — shared `&self` access only
// performs reads through the model pointer.
unsafe impl Sync for LlamaModel {}
53
54impl LlamaModel {
    /// Raw pointer to the vocabulary owned by this model.
    ///
    /// The pointer is valid for as long as `self` is alive.
    #[must_use]
    pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
        unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
    }
60
61 #[must_use]
68 pub fn n_ctx_train(&self) -> u32 {
69 let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
70 u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
71 }
72
73 pub fn tokens(
75 &self,
76 decode_special: bool,
77 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
78 (0..self.n_vocab())
79 .map(LlamaToken::new)
80 .map(move |llama_token| {
81 let mut decoder = encoding_rs::UTF_8.new_decoder();
82 (
83 llama_token,
84 self.token_to_piece(llama_token, &mut decoder, decode_special, None),
85 )
86 })
87 }
88
89 #[must_use]
91 pub fn token_bos(&self) -> LlamaToken {
92 let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
93 LlamaToken(token)
94 }
95
96 #[must_use]
98 pub fn token_eos(&self) -> LlamaToken {
99 let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
100 LlamaToken(token)
101 }
102
103 #[must_use]
105 pub fn token_nl(&self) -> LlamaToken {
106 let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
107 LlamaToken(token)
108 }
109
    /// Returns `true` when `token` marks end-of-generation (e.g. EOS/EOT).
    #[must_use]
    pub fn is_eog_token(&self, token: LlamaToken) -> bool {
        unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), token.0) }
    }
115
116 #[must_use]
118 pub fn decode_start_token(&self) -> LlamaToken {
119 let token =
120 unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
121 LlamaToken(token)
122 }
123
124 #[must_use]
126 pub fn token_sep(&self) -> LlamaToken {
127 let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
128 LlamaToken(token)
129 }
130
    /// Tokenizes `str` into model tokens, optionally prepending the BOS token.
    ///
    /// Uses a size-estimated buffer first; if llama.cpp reports the buffer is
    /// too small (negative return encoding the required size), retries once
    /// with exactly the required capacity.
    ///
    /// # Errors
    /// Returns [`StringToTokenError`] when `str` contains an interior NUL byte
    /// or its length does not fit into a `c_int`.
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Heuristic pre-allocation: roughly one token per two bytes of input,
        // with a small floor of 8.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        // First attempt with the estimated capacity.
        let size = unsafe {
            llama_cpp_bindings_sys::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer
                    .as_mut_ptr()
                    .cast::<llama_cpp_bindings_sys::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        // A negative return value is the negated number of tokens required:
        // grow the buffer to exactly that size and tokenize again.
        let size = if size.is_negative() {
            buffer.reserve_exact(usize::try_from(-size).expect("negated size fits into usize"));
            unsafe {
                llama_cpp_bindings_sys::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer
                        .as_mut_ptr()
                        .cast::<llama_cpp_bindings_sys::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and fits into usize");

        // SAFETY: llama_tokenize wrote `size` initialized tokens into the
        // buffer, and `size` does not exceed the capacity passed to it.
        unsafe { buffer.set_len(size) }

        Ok(buffer)
    }
210
211 #[must_use]
217 pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
218 let token_type =
219 unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
220 LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
221 }
222
    /// Decodes `token` into a UTF-8 string using the caller's streaming
    /// `decoder`.
    ///
    /// Starts with a small 8-byte buffer; when llama.cpp reports insufficient
    /// space (the error carries the negated required size), retries once with
    /// exactly the reported size.
    ///
    /// `special` renders special/control tokens; `lstrip` requests stripping
    /// of up to that many leading spaces.
    ///
    /// # Errors
    /// Propagates [`TokenToStringError`] from the piece lookup.
    pub fn token_to_piece(
        &self,
        token: LlamaToken,
        decoder: &mut encoding_rs::Decoder,
        special: bool,
        lstrip: Option<NonZeroU16>,
    ) -> Result<String, TokenToStringError> {
        let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
            // Retry with the exact size encoded (negated) in the error.
            Err(TokenToStringError::InsufficientBufferSpace(required_size)) => self
                .token_to_piece_bytes(
                    token,
                    (-required_size)
                        .try_into()
                        .expect("Error buffer size is positive"),
                    special,
                    lstrip,
                ),
            other => other,
        }?;

        // Feed the raw bytes through the stateful decoder; invalid sequences
        // become replacement characters rather than errors.
        let mut output_piece = String::with_capacity(bytes.len());
        let (_result, _decoded_size, _had_replacements) =
            decoder.decode_to_string(&bytes, &mut output_piece, false);

        Ok(output_piece)
    }
266
267 pub fn token_to_piece_bytes(
283 &self,
284 token: LlamaToken,
285 buffer_size: usize,
286 special: bool,
287 lstrip: Option<NonZeroU16>,
288 ) -> Result<Vec<u8>, TokenToStringError> {
289 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
291 let len = string.as_bytes().len();
292 let len = c_int::try_from(len).expect("length fits into c_int");
293 let buf = string.into_raw();
294 let lstrip = lstrip.map_or(0, |strip_count| i32::from(strip_count.get()));
295 let size = unsafe {
296 llama_cpp_bindings_sys::llama_token_to_piece(
297 self.vocab_ptr(),
298 token.0,
299 buf,
300 len,
301 lstrip,
302 special,
303 )
304 };
305
306 match size {
307 0 => Err(TokenToStringError::UnknownTokenType),
308 error_code if error_code.is_negative() => {
309 Err(TokenToStringError::InsufficientBufferSpace(error_code))
310 }
311 size => {
312 let string = unsafe { CString::from_raw(buf) };
313 let mut bytes = string.into_bytes();
314 let len = usize::try_from(size).expect("size is positive and fits into usize");
315 bytes.truncate(len);
316
317 Ok(bytes)
318 }
319 }
320 }
321
    /// Number of tokens in the model's vocabulary.
    #[must_use]
    pub fn n_vocab(&self) -> i32 {
        unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
    }
330
331 #[must_use]
337 pub fn vocab_type(&self) -> VocabType {
338 let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
339 VocabType::try_from(vocab_type).expect("invalid vocab type")
340 }
341
    /// Size of the model's embedding vectors.
    #[must_use]
    pub fn n_embd(&self) -> c_int {
        unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
    }
348
    /// Total size of the model's tensors, in bytes.
    #[must_use]
    pub fn size(&self) -> u64 {
        unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
    }
354
    /// Total number of parameters in the model.
    #[must_use]
    pub fn n_params(&self) -> u64 {
        unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
    }
360
    /// Returns `true` when the model architecture is recurrent (e.g. Mamba-style).
    #[must_use]
    pub fn is_recurrent(&self) -> bool {
        unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
    }
366
367 #[must_use]
372 pub fn n_layer(&self) -> u32 {
373 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
375 .expect("llama.cpp returns a positive value for n_layer")
376 }
377
378 #[must_use]
383 pub fn n_head(&self) -> u32 {
384 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
386 .expect("llama.cpp returns a positive value for n_head")
387 }
388
389 #[must_use]
394 pub fn n_head_kv(&self) -> u32 {
395 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
397 .expect("llama.cpp returns a positive value for n_head_kv")
398 }
399
400 pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
405 let key_cstring = CString::new(key)?;
406 let key_ptr = key_cstring.as_ptr();
407
408 extract_meta_string(
409 |buf_ptr, buf_len| unsafe {
410 llama_cpp_bindings_sys::llama_model_meta_val_str(
411 self.model.as_ptr(),
412 key_ptr,
413 buf_ptr,
414 buf_len,
415 )
416 },
417 256,
418 )
419 }
420
    /// Number of metadata key/value pairs stored in the model.
    #[must_use]
    pub fn meta_count(&self) -> i32 {
        unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
    }
426
427 pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
432 extract_meta_string(
433 |buf_ptr, buf_len| unsafe {
434 llama_cpp_bindings_sys::llama_model_meta_key_by_index(
435 self.model.as_ptr(),
436 index,
437 buf_ptr,
438 buf_len,
439 )
440 },
441 256,
442 )
443 }
444
445 pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
450 extract_meta_string(
451 |buf_ptr, buf_len| unsafe {
452 llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
453 self.model.as_ptr(),
454 index,
455 buf_ptr,
456 buf_len,
457 )
458 },
459 256,
460 )
461 }
462
    /// Returns the model's RoPE (rotary position embedding) flavor, or `None`
    /// when the model uses none.
    pub fn rope_type(&self) -> Option<RopeType> {
        match unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) } {
            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NONE => None,
            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
            llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
            // Unknown values from a newer llama.cpp are logged and mapped to
            // None rather than panicking.
            rope_type => {
                tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
                None
            }
        }
    }
477
478 pub fn chat_template(
492 &self,
493 name: Option<&str>,
494 ) -> Result<LlamaChatTemplate, ChatTemplateError> {
495 let name_cstr = name.map(CString::new);
496 let name_ptr = match name_cstr {
497 Some(Ok(name)) => name.as_ptr(),
498 _ => std::ptr::null(),
499 };
500 let result = unsafe {
501 llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
502 };
503
504 if result.is_null() {
505 Err(ChatTemplateError::MissingTemplate)
506 } else {
507 let chat_template_cstr = unsafe { CStr::from_ptr(result) };
508 let chat_template = CString::new(chat_template_cstr.to_bytes())?;
509
510 Ok(LlamaChatTemplate(chat_template))
511 }
512 }
513
514 #[tracing::instrument(skip_all, fields(params))]
520 pub fn load_from_file(
521 _: &LlamaBackend,
522 path: impl AsRef<Path>,
523 params: &LlamaModelParams,
524 ) -> Result<Self, LlamaModelLoadError> {
525 let path = path.as_ref();
526 debug_assert!(
527 Path::new(path).exists(),
528 "{} does not exist",
529 path.display()
530 );
531 let path = path
532 .to_str()
533 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
534
535 let cstr = CString::new(path)?;
536 let llama_model = unsafe {
537 llama_cpp_bindings_sys::llama_load_model_from_file(cstr.as_ptr(), params.params)
538 };
539
540 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
541
542 tracing::debug!(?path, "Loaded model");
543
544 Ok(LlamaModel { model })
545 }
546
547 pub fn lora_adapter_init(
553 &self,
554 path: impl AsRef<Path>,
555 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
556 let path = path.as_ref();
557 debug_assert!(
558 Path::new(path).exists(),
559 "{} does not exist",
560 path.display()
561 );
562
563 let path = path
564 .to_str()
565 .ok_or(LlamaLoraAdapterInitError::PathToStrError(
566 path.to_path_buf(),
567 ))?;
568
569 let cstr = CString::new(path)?;
570 let adapter = unsafe {
571 llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
572 };
573
574 let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
575
576 tracing::debug!(?path, "Initialized lora adapter");
577
578 Ok(LlamaLoraAdapter {
579 lora_adapter: adapter,
580 })
581 }
582
583 pub fn new_context<'model>(
590 &'model self,
591 _: &LlamaBackend,
592 params: LlamaContextParams,
593 ) -> Result<LlamaContext<'model>, LlamaContextLoadError> {
594 let context_params = params.context_params;
595 let context = unsafe {
596 llama_cpp_bindings_sys::llama_new_context_with_model(
597 self.model.as_ptr(),
598 context_params,
599 )
600 };
601 let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
602
603 Ok(LlamaContext::new(self, context, params.embeddings()))
604 }
605
606 #[tracing::instrument(skip_all)]
627 pub fn apply_chat_template(
628 &self,
629 tmpl: &LlamaChatTemplate,
630 chat: &[LlamaChatMessage],
631 add_ass: bool,
632 ) -> Result<String, ApplyChatTemplateError> {
633 let message_length = chat.iter().fold(0, |acc, chat_message| {
634 acc + chat_message.role.to_bytes().len() + chat_message.content.to_bytes().len()
635 });
636 let mut buff: Vec<u8> = vec![0; message_length * 2];
637
638 let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
639 .iter()
640 .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
641 role: chat_message.role.as_ptr(),
642 content: chat_message.content.as_ptr(),
643 })
644 .collect();
645
646 let tmpl_ptr = tmpl.0.as_ptr();
647
648 let res = unsafe {
649 llama_cpp_bindings_sys::llama_chat_apply_template(
650 tmpl_ptr,
651 chat.as_ptr(),
652 chat.len(),
653 add_ass,
654 buff.as_mut_ptr().cast::<c_char>(),
655 buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
656 )
657 };
658
659 if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
660 buff.resize(res.try_into().expect("res is negative"), 0);
661
662 let res = unsafe {
663 llama_cpp_bindings_sys::llama_chat_apply_template(
664 tmpl_ptr,
665 chat.as_ptr(),
666 chat.len(),
667 add_ass,
668 buff.as_mut_ptr().cast::<c_char>(),
669 buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
670 )
671 };
672 assert_eq!(Ok(res), buff.len().try_into());
673 }
674 buff.truncate(res.try_into().expect("res is negative"));
675
676 Ok(String::from_utf8(buff)?)
677 }
678
    /// Renders `messages` through `tmpl` with optional OpenAI-style tool
    /// definitions and a JSON schema constraint.
    ///
    /// `tools_json` and `json_schema` are raw JSON strings passed through to
    /// llama.cpp; `add_generation_prompt` appends the assistant-turn prefix.
    /// Tool-call parsing of the result is enabled only when `tools_json` is a
    /// non-empty string.
    ///
    /// # Errors
    /// Returns [`ApplyChatTemplateError`] when an input contains a NUL byte or
    /// the C side reports a failure (decoded by `parse_chat_template_raw_result`).
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_with_tools_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        messages: &[LlamaChatMessage],
        tools_json: Option<&str>,
        json_schema: Option<&str>,
        add_generation_prompt: bool,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        // Borrowed C views into `messages`; the owning CStrings outlive the
        // FFI call below.
        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = messages
            .iter()
            .map(|chat_message| llama_cpp_bindings_sys::llama_chat_message {
                role: chat_message.role.as_ptr(),
                content: chat_message.content.as_ptr(),
            })
            .collect();

        // Kept as locals so the pointers passed below remain valid.
        let tools_cstr = tools_json.map(CString::new).transpose()?;
        let json_schema_cstr = json_schema.map(CString::new).transpose()?;

        // Out-parameter the C side fills in; parsed (and freed) by
        // `parse_chat_template_raw_result`.
        let mut raw_result = new_empty_chat_template_raw_result();

        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_apply_chat_template_with_tools_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                chat.as_ptr(),
                chat.len(),
                tools_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                json_schema_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                add_generation_prompt,
                &raw mut raw_result,
            )
        };

        let parse_tool_calls = tools_json.is_some_and(|tools| !tools.is_empty());

        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
    }
728
    /// Renders an OpenAI-compatible chat request through `tmpl` using the full
    /// parameter set in [`OpenAIChatTemplateParams`].
    ///
    /// All string parameters are passed to the C side as nullable C strings;
    /// `None` becomes a null pointer.
    ///
    /// # Errors
    /// Returns [`ApplyChatTemplateError`] when an input contains a NUL byte or
    /// the C side reports a failure (decoded by `parse_chat_template_raw_result`).
    #[tracing::instrument(skip_all)]
    pub fn apply_chat_template_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        params: &OpenAIChatTemplateParams<'_>,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        let parse_tool_calls = params.parse_tool_calls;
        // All CStrings are bound to locals so the raw pointers stored in
        // `ffi_params` below stay valid across the FFI call.
        let messages_cstr = CString::new(params.messages_json)?;
        let tools_cstr = params.tools_json.map(CString::new).transpose()?;
        let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
        let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
        let grammar_cstr = params.grammar.map(CString::new).transpose()?;
        let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
        let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;

        // Out-parameter the C side fills in; parsed (and freed) by
        // `parse_chat_template_raw_result`.
        let mut raw_result = new_empty_chat_template_raw_result();

        let ffi_params = llama_cpp_bindings_sys::llama_rs_chat_template_oaicompat_params {
            messages: messages_cstr.as_ptr(),
            tools: tools_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            tool_choice: tool_choice_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            json_schema: json_schema_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            grammar: grammar_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            reasoning_format: reasoning_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            chat_template_kwargs: kwargs_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            add_generation_prompt: params.add_generation_prompt,
            use_jinja: params.use_jinja,
            parallel_tool_calls: params.parallel_tool_calls,
            enable_thinking: params.enable_thinking,
            add_bos: params.add_bos,
            add_eos: params.add_eos,
        };

        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_apply_chat_template_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                &raw const ffi_params,
                &raw mut raw_result,
            )
        };

        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
    }
789}
790
791fn extract_meta_string<TCFunction>(
792 c_function: TCFunction,
793 capacity: usize,
794) -> Result<String, MetaValError>
795where
796 TCFunction: Fn(*mut c_char, usize) -> i32,
797{
798 let mut buffer = vec![0u8; capacity];
799 let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
800
801 if result < 0 {
802 return Err(MetaValError::NegativeReturn(result));
803 }
804
805 let returned_len = result.cast_unsigned() as usize;
806
807 if returned_len >= capacity {
808 return extract_meta_string(c_function, returned_len + 1);
809 }
810
811 debug_assert_eq!(
812 buffer.get(returned_len),
813 Some(&0),
814 "should end with null byte"
815 );
816
817 buffer.truncate(returned_len);
818
819 Ok(String::from_utf8(buffer)?)
820}
821
impl Drop for LlamaModel {
    /// Releases the underlying llama.cpp model.
    fn drop(&mut self) {
        // SAFETY: `self.model` is non-null and exclusively owned by this
        // wrapper, so freeing it exactly once here is sound.
        unsafe { llama_cpp_bindings_sys::llama_free_model(self.model.as_ptr()) }
    }
}