1use std::ffi::{CStr, CString, c_char};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6use std::ptr::{self, NonNull};
7
8use crate::context::LlamaContext;
9use crate::context::params::LlamaContextParams;
10use crate::llama_backend::LlamaBackend;
11use crate::openai::OpenAIChatTemplateParams;
12use crate::token::LlamaToken;
13use crate::token_type::LlamaTokenAttrs;
14use crate::{
15 ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
16 LlamaModelLoadError, MetaValError, StringToTokenError, TokenToStringError,
17};
18
19pub mod add_bos;
20pub mod chat_template_result;
21pub mod grammar_trigger;
22pub mod llama_chat_message;
23pub mod llama_chat_template;
24pub mod llama_lora_adapter;
25pub mod params;
26pub mod rope_type;
27pub mod split_mode;
28pub mod vocab_type;
29
30pub use add_bos::AddBos;
31pub use chat_template_result::ChatTemplateResult;
32pub use grammar_trigger::{GrammarTrigger, GrammarTriggerType};
33pub use llama_chat_message::LlamaChatMessage;
34pub use llama_chat_template::LlamaChatTemplate;
35pub use llama_lora_adapter::LlamaLoraAdapter;
36pub use rope_type::RopeType;
37pub use vocab_type::{LlamaTokenTypeFromIntError, VocabType};
38
39use chat_template_result::{new_empty_chat_template_raw_result, parse_chat_template_raw_result};
40use params::LlamaModelParams;
41
#[derive(Debug)]
#[repr(transparent)]
/// Safe wrapper around a llama.cpp model handle.
///
/// Owns the underlying `llama_model` pointer and frees it in `Drop`.
pub struct LlamaModel {
    /// Non-null pointer to the C-side model; valid for the lifetime of `self`.
    pub model: NonNull<llama_cpp_bindings_sys::llama_model>,
}
49
// SAFETY(review): moving the model pointer to another thread assumes llama.cpp
// model accessors may be called from any thread — NOTE(review): confirm against
// llama.cpp's documented thread-safety guarantees.
unsafe impl Send for LlamaModel {}

// SAFETY(review): shared `&LlamaModel` access only performs read-style FFI
// calls in this file; assumes the C model is not mutated concurrently — TODO
// confirm.
unsafe impl Sync for LlamaModel {}
53
54impl LlamaModel {
    #[must_use]
    /// Raw vocabulary pointer for this model, borrowed from llama.cpp.
    ///
    /// The pointer stays valid for as long as `self` is alive.
    pub fn vocab_ptr(&self) -> *const llama_cpp_bindings_sys::llama_vocab {
        // SAFETY: `self.model` is a valid, live model pointer (struct invariant).
        unsafe { llama_cpp_bindings_sys::llama_model_get_vocab(self.model.as_ptr()) }
    }
60
61 #[must_use]
68 pub fn n_ctx_train(&self) -> u32 {
69 let n_ctx_train = unsafe { llama_cpp_bindings_sys::llama_n_ctx_train(self.model.as_ptr()) };
70 u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
71 }
72
73 pub fn tokens(
75 &self,
76 decode_special: bool,
77 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
78 (0..self.n_vocab())
79 .map(LlamaToken::new)
80 .map(move |llama_token| {
81 let mut decoder = encoding_rs::UTF_8.new_decoder();
82 (
83 llama_token,
84 self.token_to_piece(llama_token, &mut decoder, decode_special, None),
85 )
86 })
87 }
88
89 #[must_use]
91 pub fn token_bos(&self) -> LlamaToken {
92 let token = unsafe { llama_cpp_bindings_sys::llama_token_bos(self.vocab_ptr()) };
93 LlamaToken(token)
94 }
95
96 #[must_use]
98 pub fn token_eos(&self) -> LlamaToken {
99 let token = unsafe { llama_cpp_bindings_sys::llama_token_eos(self.vocab_ptr()) };
100 LlamaToken(token)
101 }
102
103 #[must_use]
105 pub fn token_nl(&self) -> LlamaToken {
106 let token = unsafe { llama_cpp_bindings_sys::llama_token_nl(self.vocab_ptr()) };
107 LlamaToken(token)
108 }
109
110 #[must_use]
112 pub fn is_eog_token(&self, token: LlamaToken) -> bool {
113 unsafe { llama_cpp_bindings_sys::llama_token_is_eog(self.vocab_ptr(), token.0) }
114 }
115
116 #[must_use]
118 pub fn decode_start_token(&self) -> LlamaToken {
119 let token =
120 unsafe { llama_cpp_bindings_sys::llama_model_decoder_start_token(self.model.as_ptr()) };
121 LlamaToken(token)
122 }
123
124 #[must_use]
126 pub fn token_sep(&self) -> LlamaToken {
127 let token = unsafe { llama_cpp_bindings_sys::llama_vocab_sep(self.vocab_ptr()) };
128 LlamaToken(token)
129 }
130
    /// Tokenizes `str` into model tokens.
    ///
    /// `add_bos` controls whether a beginning-of-sequence token is prepended.
    ///
    /// # Errors
    /// Returns [`StringToTokenError`] when `str` contains an interior NUL byte
    /// or a length does not fit into a `c_int`.
    pub fn str_to_token(
        &self,
        str: &str,
        add_bos: AddBos,
    ) -> Result<Vec<LlamaToken>, StringToTokenError> {
        let add_bos = match add_bos {
            AddBos::Always => true,
            AddBos::Never => false,
        };

        // Heuristic capacity guess (~2 bytes per token) so the common case
        // avoids the resize-and-retry path below.
        let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
        let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);

        let c_string = CString::new(str)?;
        let buffer_capacity =
            c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");

        // First pass: llama_tokenize returns the token count on success, or the
        // negated required count when the buffer is too small.
        let size = unsafe {
            llama_cpp_bindings_sys::llama_tokenize(
                self.vocab_ptr(),
                c_string.as_ptr(),
                c_int::try_from(c_string.as_bytes().len())?,
                buffer
                    .as_mut_ptr()
                    .cast::<llama_cpp_bindings_sys::llama_token>(),
                buffer_capacity,
                add_bos,
                true,
            )
        };

        let size = if size.is_negative() {
            // Grow to the exact required capacity and retry once.
            buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
            unsafe {
                llama_cpp_bindings_sys::llama_tokenize(
                    self.vocab_ptr(),
                    c_string.as_ptr(),
                    c_int::try_from(c_string.as_bytes().len())?,
                    buffer
                        .as_mut_ptr()
                        .cast::<llama_cpp_bindings_sys::llama_token>(),
                    -size,
                    add_bos,
                    true,
                )
            }
        } else {
            size
        };

        let size = usize::try_from(size).expect("size is positive and usize ");

        // SAFETY: llama_tokenize initialized exactly `size` tokens in `buffer`,
        // and `size` does not exceed the buffer's capacity.
        unsafe { buffer.set_len(size) }

        Ok(buffer)
    }
212
213 #[must_use]
219 pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
220 let token_type =
221 unsafe { llama_cpp_bindings_sys::llama_token_get_attr(self.vocab_ptr(), id) };
222 LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
223 }
224
225 pub fn token_to_piece(
243 &self,
244 token: LlamaToken,
245 decoder: &mut encoding_rs::Decoder,
246 special: bool,
247 lstrip: Option<NonZeroU16>,
248 ) -> Result<String, TokenToStringError> {
249 let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
250 Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
253 token,
254 (-i).try_into().expect("Error buffer size is positive"),
255 special,
256 lstrip,
257 ),
258 x => x,
259 }?;
260 let mut output_piece = String::with_capacity(bytes.len());
262 let (_result, _somesize, _truthy) =
265 decoder.decode_to_string(&bytes, &mut output_piece, false);
266
267 Ok(output_piece)
268 }
269
270 pub fn token_to_piece_bytes(
286 &self,
287 token: LlamaToken,
288 buffer_size: usize,
289 special: bool,
290 lstrip: Option<NonZeroU16>,
291 ) -> Result<Vec<u8>, TokenToStringError> {
292 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
294 let len = string.as_bytes().len();
295 let len = c_int::try_from(len).expect("length fits into c_int");
296 let buf = string.into_raw();
297 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
298 let size = unsafe {
299 llama_cpp_bindings_sys::llama_token_to_piece(
300 self.vocab_ptr(),
301 token.0,
302 buf,
303 len,
304 lstrip,
305 special,
306 )
307 };
308
309 match size {
310 0 => Err(TokenToStringError::UnknownTokenType),
311 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
312 size => {
313 let string = unsafe { CString::from_raw(buf) };
314 let mut bytes = string.into_bytes();
315 let len = usize::try_from(size).expect("size is positive and fits into usize");
316 bytes.truncate(len);
317
318 Ok(bytes)
319 }
320 }
321 }
322
323 #[must_use]
328 pub fn n_vocab(&self) -> i32 {
329 unsafe { llama_cpp_bindings_sys::llama_n_vocab(self.vocab_ptr()) }
330 }
331
332 #[must_use]
338 pub fn vocab_type(&self) -> VocabType {
339 let vocab_type = unsafe { llama_cpp_bindings_sys::llama_vocab_type(self.vocab_ptr()) };
340 VocabType::try_from(vocab_type).expect("invalid vocab type")
341 }
342
343 #[must_use]
346 pub fn n_embd(&self) -> c_int {
347 unsafe { llama_cpp_bindings_sys::llama_n_embd(self.model.as_ptr()) }
348 }
349
350 #[must_use]
352 pub fn size(&self) -> u64 {
353 unsafe { llama_cpp_bindings_sys::llama_model_size(self.model.as_ptr()) }
354 }
355
356 #[must_use]
358 pub fn n_params(&self) -> u64 {
359 unsafe { llama_cpp_bindings_sys::llama_model_n_params(self.model.as_ptr()) }
360 }
361
362 #[must_use]
364 pub fn is_recurrent(&self) -> bool {
365 unsafe { llama_cpp_bindings_sys::llama_model_is_recurrent(self.model.as_ptr()) }
366 }
367
368 #[must_use]
373 pub fn n_layer(&self) -> u32 {
374 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_layer(self.model.as_ptr()) })
376 .expect("llama.cpp returns a positive value for n_layer")
377 }
378
379 #[must_use]
384 pub fn n_head(&self) -> u32 {
385 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head(self.model.as_ptr()) })
387 .expect("llama.cpp returns a positive value for n_head")
388 }
389
390 #[must_use]
395 pub fn n_head_kv(&self) -> u32 {
396 u32::try_from(unsafe { llama_cpp_bindings_sys::llama_model_n_head_kv(self.model.as_ptr()) })
398 .expect("llama.cpp returns a positive value for n_head_kv")
399 }
400
401 pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
406 let key_cstring = CString::new(key)?;
407 let key_ptr = key_cstring.as_ptr();
408
409 extract_meta_string(
410 |buf_ptr, buf_len| unsafe {
411 llama_cpp_bindings_sys::llama_model_meta_val_str(
412 self.model.as_ptr(),
413 key_ptr,
414 buf_ptr,
415 buf_len,
416 )
417 },
418 256,
419 )
420 }
421
422 #[must_use]
424 pub fn meta_count(&self) -> i32 {
425 unsafe { llama_cpp_bindings_sys::llama_model_meta_count(self.model.as_ptr()) }
426 }
427
428 pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
433 extract_meta_string(
434 |buf_ptr, buf_len| unsafe {
435 llama_cpp_bindings_sys::llama_model_meta_key_by_index(
436 self.model.as_ptr(),
437 index,
438 buf_ptr,
439 buf_len,
440 )
441 },
442 256,
443 )
444 }
445
446 pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
451 extract_meta_string(
452 |buf_ptr, buf_len| unsafe {
453 llama_cpp_bindings_sys::llama_model_meta_val_str_by_index(
454 self.model.as_ptr(),
455 index,
456 buf_ptr,
457 buf_len,
458 )
459 },
460 256,
461 )
462 }
463
464 pub fn rope_type(&self) -> Option<RopeType> {
466 match unsafe { llama_cpp_bindings_sys::llama_model_rope_type(self.model.as_ptr()) } {
467 llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NONE => None,
468 llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
469 llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
470 llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
471 llama_cpp_bindings_sys::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
472 rope_type => {
473 tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
474 None
475 }
476 }
477 }
478
479 pub fn chat_template(
493 &self,
494 name: Option<&str>,
495 ) -> Result<LlamaChatTemplate, ChatTemplateError> {
496 let name_cstr = name.map(CString::new);
497 let name_ptr = match name_cstr {
498 Some(Ok(name)) => name.as_ptr(),
499 _ => std::ptr::null(),
500 };
501 let result = unsafe {
502 llama_cpp_bindings_sys::llama_model_chat_template(self.model.as_ptr(), name_ptr)
503 };
504
505 if result.is_null() {
507 Err(ChatTemplateError::MissingTemplate)
508 } else {
509 let chat_template_cstr = unsafe { CStr::from_ptr(result) };
510 let chat_template = CString::new(chat_template_cstr.to_bytes())?;
511
512 Ok(LlamaChatTemplate(chat_template))
513 }
514 }
515
516 #[tracing::instrument(skip_all, fields(params))]
522 pub fn load_from_file(
523 _: &LlamaBackend,
524 path: impl AsRef<Path>,
525 params: &LlamaModelParams,
526 ) -> Result<Self, LlamaModelLoadError> {
527 let path = path.as_ref();
528 debug_assert!(
529 Path::new(path).exists(),
530 "{} does not exist",
531 path.display()
532 );
533 let path = path
534 .to_str()
535 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
536
537 let cstr = CString::new(path)?;
538 let llama_model = unsafe {
539 llama_cpp_bindings_sys::llama_load_model_from_file(cstr.as_ptr(), params.params)
540 };
541
542 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
543
544 tracing::debug!(?path, "Loaded model");
545
546 Ok(LlamaModel { model })
547 }
548
549 pub fn lora_adapter_init(
555 &self,
556 path: impl AsRef<Path>,
557 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
558 let path = path.as_ref();
559 debug_assert!(
560 Path::new(path).exists(),
561 "{} does not exist",
562 path.display()
563 );
564
565 let path = path
566 .to_str()
567 .ok_or(LlamaLoraAdapterInitError::PathToStrError(
568 path.to_path_buf(),
569 ))?;
570
571 let cstr = CString::new(path)?;
572 let adapter = unsafe {
573 llama_cpp_bindings_sys::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr())
574 };
575
576 let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
577
578 tracing::debug!(?path, "Initialized lora adapter");
579
580 Ok(LlamaLoraAdapter {
581 lora_adapter: adapter,
582 })
583 }
584
585 pub fn new_context<'model>(
592 &'model self,
593 _: &LlamaBackend,
594 params: LlamaContextParams,
595 ) -> Result<LlamaContext<'model>, LlamaContextLoadError> {
596 let context_params = params.context_params;
597 let context = unsafe {
598 llama_cpp_bindings_sys::llama_new_context_with_model(
599 self.model.as_ptr(),
600 context_params,
601 )
602 };
603 let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
604
605 Ok(LlamaContext::new(self, context, params.embeddings()))
606 }
607
608 #[tracing::instrument(skip_all)]
629 pub fn apply_chat_template(
630 &self,
631 tmpl: &LlamaChatTemplate,
632 chat: &[LlamaChatMessage],
633 add_ass: bool,
634 ) -> Result<String, ApplyChatTemplateError> {
635 let message_length = chat.iter().fold(0, |acc, c| {
637 acc + c.role.to_bytes().len() + c.content.to_bytes().len()
638 });
639 let mut buff: Vec<u8> = vec![0; message_length * 2];
640
641 let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = chat
643 .iter()
644 .map(|c| llama_cpp_bindings_sys::llama_chat_message {
645 role: c.role.as_ptr(),
646 content: c.content.as_ptr(),
647 })
648 .collect();
649
650 let tmpl_ptr = tmpl.0.as_ptr();
651
652 let res = unsafe {
653 llama_cpp_bindings_sys::llama_chat_apply_template(
654 tmpl_ptr,
655 chat.as_ptr(),
656 chat.len(),
657 add_ass,
658 buff.as_mut_ptr().cast::<c_char>(),
659 buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
660 )
661 };
662
663 if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
664 buff.resize(res.try_into().expect("res is negative"), 0);
665
666 let res = unsafe {
667 llama_cpp_bindings_sys::llama_chat_apply_template(
668 tmpl_ptr,
669 chat.as_ptr(),
670 chat.len(),
671 add_ass,
672 buff.as_mut_ptr().cast::<c_char>(),
673 buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
674 )
675 };
676 assert_eq!(Ok(res), buff.len().try_into());
677 }
678 buff.truncate(res.try_into().expect("res is negative"));
679
680 Ok(String::from_utf8(buff)?)
681 }
682
    #[tracing::instrument(skip_all)]
    /// Applies the chat template to `messages` with optional OpenAI-compatible
    /// tool definitions (`tools_json`) and a response `json_schema`.
    ///
    /// # Errors
    /// Fails when an input contains an interior NUL byte or the native side
    /// reports an error via the returned status code.
    pub fn apply_chat_template_with_tools_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        messages: &[LlamaChatMessage],
        tools_json: Option<&str>,
        json_schema: Option<&str>,
        add_generation_prompt: bool,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        // Pointers borrow the CStrings owned by `messages`, which outlive the call.
        let chat: Vec<llama_cpp_bindings_sys::llama_chat_message> = messages
            .iter()
            .map(|c| llama_cpp_bindings_sys::llama_chat_message {
                role: c.role.as_ptr(),
                content: c.content.as_ptr(),
            })
            .collect();

        let tools_cstr = tools_json.map(CString::new).transpose()?;
        let json_schema_cstr = json_schema.map(CString::new).transpose()?;

        let mut raw_result = new_empty_chat_template_raw_result();

        // SAFETY: every pointer is either null or points into a CString that
        // outlives the call, and `raw_result` is a valid out-parameter.
        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_apply_chat_template_with_tools_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                chat.as_ptr(),
                chat.len(),
                tools_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                json_schema_cstr
                    .as_ref()
                    .map_or(ptr::null(), |cstr| cstr.as_ptr()),
                add_generation_prompt,
                &raw mut raw_result,
            )
        };

        // Tool-call parsing is only meaningful when tools were actually supplied.
        let parse_tool_calls = tools_json.is_some_and(|tools| !tools.is_empty());

        // SAFETY: `raw_result` was populated by the call above.
        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
    }
732
    #[tracing::instrument(skip_all)]
    /// Applies the chat template using the full OpenAI-compatible parameter set
    /// (tools, tool choice, JSON schema, grammar, reasoning format, kwargs).
    ///
    /// # Errors
    /// Fails when any parameter string contains an interior NUL byte or the
    /// native side reports an error via the returned status code.
    pub fn apply_chat_template_oaicompat(
        &self,
        tmpl: &LlamaChatTemplate,
        params: &OpenAIChatTemplateParams<'_>,
    ) -> Result<ChatTemplateResult, ApplyChatTemplateError> {
        let parse_tool_calls = params.parse_tool_calls;
        // Convert each string to an owned CString; all of them must stay alive
        // until after the FFI call below, since `ffi_params` holds raw pointers
        // into them.
        let messages_cstr = CString::new(params.messages_json)?;
        let tools_cstr = params.tools_json.map(CString::new).transpose()?;
        let tool_choice_cstr = params.tool_choice.map(CString::new).transpose()?;
        let json_schema_cstr = params.json_schema.map(CString::new).transpose()?;
        let grammar_cstr = params.grammar.map(CString::new).transpose()?;
        let reasoning_cstr = params.reasoning_format.map(CString::new).transpose()?;
        let kwargs_cstr = params.chat_template_kwargs.map(CString::new).transpose()?;

        let mut raw_result = new_empty_chat_template_raw_result();

        // A null pointer signals "absent" for each optional field.
        let ffi_params = llama_cpp_bindings_sys::llama_rs_chat_template_oaicompat_params {
            messages: messages_cstr.as_ptr(),
            tools: tools_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            tool_choice: tool_choice_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            json_schema: json_schema_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            grammar: grammar_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            reasoning_format: reasoning_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            chat_template_kwargs: kwargs_cstr
                .as_ref()
                .map_or(ptr::null(), |cstr| cstr.as_ptr()),
            add_generation_prompt: params.add_generation_prompt,
            use_jinja: params.use_jinja,
            parallel_tool_calls: params.parallel_tool_calls,
            enable_thinking: params.enable_thinking,
            add_bos: params.add_bos,
            add_eos: params.add_eos,
        };

        // SAFETY: `ffi_params` only holds pointers into CStrings that are still
        // alive, and `raw_result` is a valid out-parameter.
        let rc = unsafe {
            llama_cpp_bindings_sys::llama_rs_apply_chat_template_oaicompat(
                self.model.as_ptr(),
                tmpl.0.as_ptr(),
                &raw const ffi_params,
                &raw mut raw_result,
            )
        };

        // SAFETY: `raw_result` was populated by the call above.
        unsafe { parse_chat_template_raw_result(rc, &raw mut raw_result, parse_tool_calls) }
    }
793}
794
795fn extract_meta_string<TCFunction>(
801 c_function: TCFunction,
802 capacity: usize,
803) -> Result<String, MetaValError>
804where
805 TCFunction: Fn(*mut c_char, usize) -> i32,
806{
807 let mut buffer = vec![0u8; capacity];
808
809 let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
811 if result < 0 {
812 return Err(MetaValError::NegativeReturn(result));
813 }
814
815 let returned_len = result as usize;
817 if returned_len >= capacity {
818 return extract_meta_string(c_function, returned_len + 1);
820 }
821
822 debug_assert_eq!(
824 buffer.get(returned_len),
825 Some(&0),
826 "should end with null byte"
827 );
828
829 buffer.truncate(returned_len);
831
832 Ok(String::from_utf8(buffer)?)
833}
834
impl Drop for LlamaModel {
    /// Frees the underlying llama.cpp model.
    fn drop(&mut self) {
        // SAFETY: `self.model` is a valid model pointer uniquely owned by this
        // wrapper; it is never used again after this call.
        unsafe { llama_cpp_bindings_sys::llama_free_model(self.model.as_ptr()) }
    }
}