1use std::ffi::{c_char, CStr, CString};
3use std::num::NonZeroU16;
4use std::os::raw::c_int;
5use std::path::Path;
6use std::ptr::{self, NonNull};
7use std::slice;
8use std::str::Utf8Error;
9
10use crate::context::params::LlamaContextParams;
11use crate::context::LlamaContext;
12use crate::llama_backend::LlamaBackend;
13use crate::model::params::LlamaModelParams;
14use crate::token::LlamaToken;
15use crate::token_type::{LlamaTokenAttr, LlamaTokenAttrs};
16use crate::{
17 ApplyChatTemplateError, ChatTemplateError, LlamaContextLoadError, LlamaLoraAdapterInitError,
18 LlamaModelLoadError, MetaValError, NewLlamaChatMessageError, StringToTokenError,
19 TokenToStringError,
20};
21
22pub mod params;
23
24#[derive(Debug)]
26#[repr(transparent)]
27#[allow(clippy::module_name_repetitions)]
28pub struct LlamaModel {
29 pub(crate) model: NonNull<llama_cpp_sys_2::llama_model>,
30}
31
32#[derive(Debug)]
34#[repr(transparent)]
35#[allow(clippy::module_name_repetitions)]
36pub struct LlamaLoraAdapter {
37 pub(crate) lora_adapter: NonNull<llama_cpp_sys_2::llama_adapter_lora>,
38}
39
/// A chat template kept as a NUL-terminated C string so it can be handed to llama.cpp
/// without further conversion.
#[derive(Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct LlamaChatTemplate(CString);

impl LlamaChatTemplate {
    /// Builds a template from its textual form.
    ///
    /// # Errors
    ///
    /// Returns [`std::ffi::NulError`] if `template` contains an interior NUL byte.
    pub fn new(template: &str) -> Result<Self, std::ffi::NulError> {
        CString::new(template).map(LlamaChatTemplate)
    }

    /// Borrows the template as a C string.
    pub fn as_c_str(&self) -> &CStr {
        self.0.as_c_str()
    }

    /// Borrows the template as UTF-8 text.
    ///
    /// # Errors
    ///
    /// Returns [`Utf8Error`] if the template bytes are not valid UTF-8.
    pub fn to_str(&self) -> Result<&str, Utf8Error> {
        self.0.to_str()
    }

    /// Copies the template into an owned [`String`].
    ///
    /// # Errors
    ///
    /// Returns [`Utf8Error`] if the template bytes are not valid UTF-8.
    pub fn to_string(&self) -> Result<String, Utf8Error> {
        let text = self.to_str()?;
        Ok(text.to_owned())
    }
}

impl std::fmt::Debug for LlamaChatTemplate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show the quoted, escaped C-string form rather than a tuple-struct wrapper.
        std::fmt::Debug::fmt(&self.0, f)
    }
}
75
76#[derive(Debug, Eq, PartialEq, Clone)]
78pub struct LlamaChatMessage {
79 role: CString,
80 content: CString,
81}
82
83impl LlamaChatMessage {
84 pub fn new(role: String, content: String) -> Result<Self, NewLlamaChatMessageError> {
89 Ok(Self {
90 role: CString::new(role)?,
91 content: CString::new(content)?,
92 })
93 }
94}
95
/// Rotary position embedding variant reported by a model, as returned by
/// [`LlamaModel::rope_type`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RopeType {
    /// Reported as `LLAMA_ROPE_TYPE_NORM`.
    Norm,
    /// Reported as `LLAMA_ROPE_TYPE_NEOX`.
    NeoX,
    /// Reported as `LLAMA_ROPE_TYPE_MROPE`.
    MRope,
    /// Reported as `LLAMA_ROPE_TYPE_VISION`.
    Vision,
}
104
/// Selects the BOS flag that [`LlamaModel::str_to_token`] passes to `llama_tokenize`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AddBos {
    /// Pass `true` for the BOS flag.
    Always,
    /// Pass `false` for the BOS flag.
    Never,
}
113
/// Legacy switch for the `special` flag of the deprecated token-to-text helpers.
///
/// `Tokenize` maps to `special = true`, `Plaintext` to `special = false`.
#[deprecated(
    since = "0.1.0",
    note = "This enum is a mixture of options for llama cpp providing less flexibility it only used with deprecated methods and will be removed in the future."
)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Special {
    /// Render special tokens (`special = true`).
    Tokenize,
    /// Do not render special tokens (`special = false`).
    Plaintext,
}
126
127unsafe impl Send for LlamaModel {}
128
129unsafe impl Sync for LlamaModel {}
130
131impl LlamaModel {
132 pub(crate) fn vocab_ptr(&self) -> *const llama_cpp_sys_2::llama_vocab {
133 unsafe { llama_cpp_sys_2::llama_model_get_vocab(self.model.as_ptr()) }
134 }
135
136 #[must_use]
143 pub fn n_ctx_train(&self) -> u32 {
144 let n_ctx_train = unsafe { llama_cpp_sys_2::llama_n_ctx_train(self.model.as_ptr()) };
145 u32::try_from(n_ctx_train).expect("n_ctx_train fits into an u32")
146 }
147
148 pub fn tokens(
150 &self,
151 decode_special: bool,
152 ) -> impl Iterator<Item = (LlamaToken, Result<String, TokenToStringError>)> + '_ {
153 (0..self.n_vocab())
154 .map(LlamaToken::new)
155 .map(move |llama_token| {
156 let mut decoder = encoding_rs::UTF_8.new_decoder();
157 (
158 llama_token,
159 self.token_to_piece(llama_token, &mut decoder, decode_special, None),
160 )
161 })
162 }
163
164 #[must_use]
166 pub fn token_bos(&self) -> LlamaToken {
167 let token = unsafe { llama_cpp_sys_2::llama_token_bos(self.vocab_ptr()) };
168 LlamaToken(token)
169 }
170
171 #[must_use]
173 pub fn token_eos(&self) -> LlamaToken {
174 let token = unsafe { llama_cpp_sys_2::llama_token_eos(self.vocab_ptr()) };
175 LlamaToken(token)
176 }
177
178 #[must_use]
180 pub fn token_nl(&self) -> LlamaToken {
181 let token = unsafe { llama_cpp_sys_2::llama_token_nl(self.vocab_ptr()) };
182 LlamaToken(token)
183 }
184
185 #[must_use]
187 pub fn is_eog_token(&self, token: LlamaToken) -> bool {
188 unsafe { llama_cpp_sys_2::llama_token_is_eog(self.vocab_ptr(), token.0) }
189 }
190
191 #[must_use]
193 pub fn decode_start_token(&self) -> LlamaToken {
194 let token =
195 unsafe { llama_cpp_sys_2::llama_model_decoder_start_token(self.model.as_ptr()) };
196 LlamaToken(token)
197 }
198
199 #[must_use]
201 pub fn token_sep(&self) -> LlamaToken {
202 let token = unsafe { llama_cpp_sys_2::llama_vocab_sep(self.vocab_ptr()) };
203 LlamaToken(token)
204 }
205
206 #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
212 pub fn token_to_str(
213 &self,
214 token: LlamaToken,
215 special: Special,
216 ) -> Result<String, TokenToStringError> {
217 let mut decoder = encoding_rs::UTF_8.new_decoder();
219 self.token_to_piece(
220 token,
221 &mut decoder,
222 matches!(special, Special::Tokenize),
223 None,
224 )
225 }
226
227 #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
237 pub fn token_to_bytes(
238 &self,
239 token: LlamaToken,
240 special: Special,
241 ) -> Result<Vec<u8>, TokenToStringError> {
242 match self.token_to_piece_bytes(token, 8, matches!(special, Special::Tokenize), None) {
244 Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
245 token,
246 (-i).try_into().expect("Error buffer size is positive"),
247 matches!(special, Special::Tokenize),
248 None,
249 ),
250 x => x,
251 }
252 }
253
254 #[deprecated(
260 since = "0.1.0",
261 note = "Use `token_to_piece` for each token individually instead"
262 )]
263 pub fn tokens_to_str(
264 &self,
265 tokens: &[LlamaToken],
266 special: Special,
267 ) -> Result<String, TokenToStringError> {
268 let mut builder: Vec<u8> = Vec::with_capacity(tokens.len() * 4);
269 for piece in tokens
270 .iter()
271 .copied()
272 .map(|t| self.token_to_piece_bytes(t, 8, matches!(special, Special::Tokenize), None))
273 {
274 builder.extend_from_slice(&piece?);
275 }
276 Ok(String::from_utf8(builder)?)
277 }
278
279 pub fn str_to_token(
302 &self,
303 str: &str,
304 add_bos: AddBos,
305 ) -> Result<Vec<LlamaToken>, StringToTokenError> {
306 let add_bos = match add_bos {
307 AddBos::Always => true,
308 AddBos::Never => false,
309 };
310
311 let tokens_estimation = std::cmp::max(8, (str.len() / 2) + usize::from(add_bos));
312 let mut buffer: Vec<LlamaToken> = Vec::with_capacity(tokens_estimation);
313
314 let c_string = CString::new(str)?;
315 let buffer_capacity =
316 c_int::try_from(buffer.capacity()).expect("buffer capacity should fit into a c_int");
317
318 let size = unsafe {
319 llama_cpp_sys_2::llama_tokenize(
320 self.vocab_ptr(),
321 c_string.as_ptr(),
322 c_int::try_from(c_string.as_bytes().len())?,
323 buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
324 buffer_capacity,
325 add_bos,
326 true,
327 )
328 };
329
330 let size = if size.is_negative() {
333 buffer.reserve_exact(usize::try_from(-size).expect("usize's are larger "));
334 unsafe {
335 llama_cpp_sys_2::llama_tokenize(
336 self.vocab_ptr(),
337 c_string.as_ptr(),
338 c_int::try_from(c_string.as_bytes().len())?,
339 buffer.as_mut_ptr().cast::<llama_cpp_sys_2::llama_token>(),
340 -size,
341 add_bos,
342 true,
343 )
344 }
345 } else {
346 size
347 };
348
349 let size = usize::try_from(size).expect("size is positive and usize ");
350
351 unsafe { buffer.set_len(size) }
353 Ok(buffer)
354 }
355
356 #[must_use]
362 pub fn token_attr(&self, LlamaToken(id): LlamaToken) -> LlamaTokenAttrs {
363 let token_type = unsafe { llama_cpp_sys_2::llama_token_get_attr(self.vocab_ptr(), id) };
364 LlamaTokenAttrs::try_from(token_type).expect("token type is valid")
365 }
366
367 pub fn token_to_piece(
385 &self,
386 token: LlamaToken,
387 decoder: &mut encoding_rs::Decoder,
388 special: bool,
389 lstrip: Option<NonZeroU16>,
390 ) -> Result<String, TokenToStringError> {
391 let bytes = match self.token_to_piece_bytes(token, 8, special, lstrip) {
392 Err(TokenToStringError::InsufficientBufferSpace(i)) => self.token_to_piece_bytes(
395 token,
396 (-i).try_into().expect("Error buffer size is positive"),
397 special,
398 lstrip,
399 ),
400 x => x,
401 }?;
402 let mut output_piece = String::with_capacity(bytes.len());
404 let (_result, _somesize, _truthy) =
407 decoder.decode_to_string(&bytes, &mut output_piece, false);
408 Ok(output_piece)
409 }
410
411 pub fn token_to_piece_bytes(
427 &self,
428 token: LlamaToken,
429 buffer_size: usize,
430 special: bool,
431 lstrip: Option<NonZeroU16>,
432 ) -> Result<Vec<u8>, TokenToStringError> {
433 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
434 let len = string.as_bytes().len();
435 let len = c_int::try_from(len).expect("length fits into c_int");
436 let buf = string.into_raw();
437 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
438 let size = unsafe {
439 llama_cpp_sys_2::llama_token_to_piece(
440 self.vocab_ptr(),
441 token.0,
442 buf,
443 len,
444 lstrip,
445 special,
446 )
447 };
448
449 match size {
450 0 => Err(TokenToStringError::UnknownTokenType),
451 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
452 size => {
453 let string = unsafe { CString::from_raw(buf) };
454 let mut bytes = string.into_bytes();
455 let len = usize::try_from(size).expect("size is positive and fits into usize");
456 bytes.truncate(len);
457 Ok(bytes)
458 }
459 }
460 }
461
462 #[deprecated(since = "0.1.0", note = "Use `token_to_piece` instead")]
478 pub fn token_to_str_with_size(
479 &self,
480 token: LlamaToken,
481 buffer_size: usize,
482 special: Special,
483 ) -> Result<String, TokenToStringError> {
484 let bytes = self.token_to_piece_bytes(
485 token,
486 buffer_size,
487 matches!(special, Special::Tokenize),
488 None,
489 )?;
490 Ok(String::from_utf8(bytes)?)
491 }
492
493 #[deprecated(since = "0.1.0", note = "Use `token_to_piece_bytes` instead")]
508 pub fn token_to_bytes_with_size(
509 &self,
510 token: LlamaToken,
511 buffer_size: usize,
512 special: Special,
513 lstrip: Option<NonZeroU16>,
514 ) -> Result<Vec<u8>, TokenToStringError> {
515 if token == self.token_nl() {
516 return Ok(b"\n".to_vec());
517 }
518
519 let attrs = self.token_attr(token);
521 if attrs.is_empty()
522 || attrs
523 .intersects(LlamaTokenAttr::Unknown | LlamaTokenAttr::Byte | LlamaTokenAttr::Unused)
524 || attrs.contains(LlamaTokenAttr::Control)
525 && (token == self.token_bos() || token == self.token_eos())
526 {
527 return Ok(Vec::new());
528 }
529
530 let special = match special {
531 Special::Tokenize => true,
532 Special::Plaintext => false,
533 };
534
535 let string = CString::new(vec![b'*'; buffer_size]).expect("no null");
536 let len = string.as_bytes().len();
537 let len = c_int::try_from(len).expect("length fits into c_int");
538 let buf = string.into_raw();
539 let lstrip = lstrip.map_or(0, |it| i32::from(it.get()));
540 let size = unsafe {
541 llama_cpp_sys_2::llama_token_to_piece(
542 self.vocab_ptr(),
543 token.0,
544 buf,
545 len,
546 lstrip,
547 special,
548 )
549 };
550
551 match size {
552 0 => Err(TokenToStringError::UnknownTokenType),
553 i if i.is_negative() => Err(TokenToStringError::InsufficientBufferSpace(i)),
554 size => {
555 let string = unsafe { CString::from_raw(buf) };
556 let mut bytes = string.into_bytes();
557 let len = usize::try_from(size).expect("size is positive and fits into usize");
558 bytes.truncate(len);
559 Ok(bytes)
560 }
561 }
562 }
563 #[must_use]
568 pub fn n_vocab(&self) -> i32 {
569 unsafe { llama_cpp_sys_2::llama_n_vocab(self.vocab_ptr()) }
570 }
571
572 #[must_use]
578 pub fn vocab_type(&self) -> VocabType {
579 let vocab_type = unsafe { llama_cpp_sys_2::llama_vocab_type(self.vocab_ptr()) };
581 VocabType::try_from(vocab_type).expect("invalid vocab type")
582 }
583
584 #[must_use]
587 pub fn n_embd(&self) -> c_int {
588 unsafe { llama_cpp_sys_2::llama_n_embd(self.model.as_ptr()) }
589 }
590
591 pub fn size(&self) -> u64 {
593 unsafe { llama_cpp_sys_2::llama_model_size(self.model.as_ptr()) }
594 }
595
596 pub fn n_params(&self) -> u64 {
598 unsafe { llama_cpp_sys_2::llama_model_n_params(self.model.as_ptr()) }
599 }
600
601 pub fn is_recurrent(&self) -> bool {
603 unsafe { llama_cpp_sys_2::llama_model_is_recurrent(self.model.as_ptr()) }
604 }
605
606 pub fn is_hybrid(&self) -> bool {
611 unsafe { llama_cpp_sys_2::llama_model_is_hybrid(self.model.as_ptr()) }
612 }
613
614 pub fn n_layer(&self) -> u32 {
616 u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_layer(self.model.as_ptr()) }).unwrap()
619 }
620
621 pub fn n_head(&self) -> u32 {
623 u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head(self.model.as_ptr()) }).unwrap()
626 }
627
628 pub fn n_head_kv(&self) -> u32 {
630 u32::try_from(unsafe { llama_cpp_sys_2::llama_model_n_head_kv(self.model.as_ptr()) })
633 .unwrap()
634 }
635
636 pub fn meta_val_str(&self, key: &str) -> Result<String, MetaValError> {
638 let key_cstring = CString::new(key)?;
639 let key_ptr = key_cstring.as_ptr();
640
641 extract_meta_string(
642 |buf_ptr, buf_len| unsafe {
643 llama_cpp_sys_2::llama_model_meta_val_str(
644 self.model.as_ptr(),
645 key_ptr,
646 buf_ptr,
647 buf_len,
648 )
649 },
650 256,
651 )
652 }
653
654 pub fn meta_count(&self) -> i32 {
656 unsafe { llama_cpp_sys_2::llama_model_meta_count(self.model.as_ptr()) }
657 }
658
659 pub fn meta_key_by_index(&self, index: i32) -> Result<String, MetaValError> {
661 extract_meta_string(
662 |buf_ptr, buf_len| unsafe {
663 llama_cpp_sys_2::llama_model_meta_key_by_index(
664 self.model.as_ptr(),
665 index,
666 buf_ptr,
667 buf_len,
668 )
669 },
670 256,
671 )
672 }
673
674 pub fn meta_val_str_by_index(&self, index: i32) -> Result<String, MetaValError> {
676 extract_meta_string(
677 |buf_ptr, buf_len| unsafe {
678 llama_cpp_sys_2::llama_model_meta_val_str_by_index(
679 self.model.as_ptr(),
680 index,
681 buf_ptr,
682 buf_len,
683 )
684 },
685 256,
686 )
687 }
688
689 pub fn rope_type(&self) -> Option<RopeType> {
691 match unsafe { llama_cpp_sys_2::llama_model_rope_type(self.model.as_ptr()) } {
692 llama_cpp_sys_2::LLAMA_ROPE_TYPE_NONE => None,
693 llama_cpp_sys_2::LLAMA_ROPE_TYPE_NORM => Some(RopeType::Norm),
694 llama_cpp_sys_2::LLAMA_ROPE_TYPE_NEOX => Some(RopeType::NeoX),
695 llama_cpp_sys_2::LLAMA_ROPE_TYPE_MROPE => Some(RopeType::MRope),
696 llama_cpp_sys_2::LLAMA_ROPE_TYPE_VISION => Some(RopeType::Vision),
697 rope_type => {
698 tracing::error!(rope_type = rope_type, "Unexpected rope type from llama.cpp");
699 None
700 }
701 }
702 }
703
704 pub fn chat_template(
718 &self,
719 name: Option<&str>,
720 ) -> Result<LlamaChatTemplate, ChatTemplateError> {
721 let name_cstr = name.map(CString::new);
722 let name_ptr = match name_cstr {
723 Some(Ok(name)) => name.as_ptr(),
724 _ => std::ptr::null(),
725 };
726 let result =
727 unsafe { llama_cpp_sys_2::llama_model_chat_template(self.model.as_ptr(), name_ptr) };
728
729 if result.is_null() {
731 Err(ChatTemplateError::MissingTemplate)
732 } else {
733 let chat_template_cstr = unsafe { CStr::from_ptr(result) };
734 let chat_template = CString::new(chat_template_cstr.to_bytes())?;
735 Ok(LlamaChatTemplate(chat_template))
736 }
737 }
738
739 #[tracing::instrument(skip_all, fields(params))]
745 pub fn load_from_file(
746 _: &LlamaBackend,
747 path: impl AsRef<Path>,
748 params: &LlamaModelParams,
749 ) -> Result<Self, LlamaModelLoadError> {
750 let path = path.as_ref();
751 debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
752 let path = path
753 .to_str()
754 .ok_or(LlamaModelLoadError::PathToStrError(path.to_path_buf()))?;
755
756 let cstr = CString::new(path)?;
757 let llama_model =
758 unsafe { llama_cpp_sys_2::llama_load_model_from_file(cstr.as_ptr(), params.params) };
759
760 let model = NonNull::new(llama_model).ok_or(LlamaModelLoadError::NullResult)?;
761
762 tracing::debug!(?path, "Loaded model");
763 Ok(LlamaModel { model })
764 }
765
766 pub fn lora_adapter_init(
772 &self,
773 path: impl AsRef<Path>,
774 ) -> Result<LlamaLoraAdapter, LlamaLoraAdapterInitError> {
775 let path = path.as_ref();
776 debug_assert!(Path::new(path).exists(), "{path:?} does not exist");
777
778 let path = path
779 .to_str()
780 .ok_or(LlamaLoraAdapterInitError::PathToStrError(
781 path.to_path_buf(),
782 ))?;
783
784 let cstr = CString::new(path)?;
785 let adapter =
786 unsafe { llama_cpp_sys_2::llama_adapter_lora_init(self.model.as_ptr(), cstr.as_ptr()) };
787
788 let adapter = NonNull::new(adapter).ok_or(LlamaLoraAdapterInitError::NullResult)?;
789
790 tracing::debug!(?path, "Initialized lora adapter");
791 Ok(LlamaLoraAdapter {
792 lora_adapter: adapter,
793 })
794 }
795
796 #[allow(clippy::needless_pass_by_value)]
803 pub fn new_context<'a>(
804 &'a self,
805 _: &LlamaBackend,
806 params: LlamaContextParams,
807 ) -> Result<LlamaContext<'a>, LlamaContextLoadError> {
808 let context_params = params.context_params;
809 let context = unsafe {
810 llama_cpp_sys_2::llama_new_context_with_model(self.model.as_ptr(), context_params)
811 };
812 let context = NonNull::new(context).ok_or(LlamaContextLoadError::NullReturn)?;
813
814 Ok(LlamaContext::new(self, context, params.embeddings()))
815 }
816
817 #[tracing::instrument(skip_all)]
835 pub fn apply_chat_template(
836 &self,
837 tmpl: &LlamaChatTemplate,
838 chat: &[LlamaChatMessage],
839 add_ass: bool,
840 ) -> Result<String, ApplyChatTemplateError> {
841 let message_length = chat.iter().fold(0, |acc, c| {
843 acc + c.role.to_bytes().len() + c.content.to_bytes().len()
844 });
845 let mut buff: Vec<u8> = vec![0; message_length * 2];
846
847 let chat: Vec<llama_cpp_sys_2::llama_chat_message> = chat
849 .iter()
850 .map(|c| llama_cpp_sys_2::llama_chat_message {
851 role: c.role.as_ptr(),
852 content: c.content.as_ptr(),
853 })
854 .collect();
855
856 let tmpl_ptr = tmpl.0.as_ptr();
857
858 let res = unsafe {
859 llama_cpp_sys_2::llama_chat_apply_template(
860 tmpl_ptr,
861 chat.as_ptr(),
862 chat.len(),
863 add_ass,
864 buff.as_mut_ptr().cast::<c_char>(),
865 buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
866 )
867 };
868
869 if res < 0 {
870 return Err(ApplyChatTemplateError::FfiError(res));
871 }
872
873 if res > buff.len().try_into().expect("Buffer size exceeds i32::MAX") {
874 buff.resize(res.try_into().expect("res is negative"), 0);
875
876 let res = unsafe {
877 llama_cpp_sys_2::llama_chat_apply_template(
878 tmpl_ptr,
879 chat.as_ptr(),
880 chat.len(),
881 add_ass,
882 buff.as_mut_ptr().cast::<c_char>(),
883 buff.len().try_into().expect("Buffer size exceeds i32::MAX"),
884 )
885 };
886 if res < 0 {
887 return Err(ApplyChatTemplateError::FfiError(res));
888 }
889 assert_eq!(Ok(res), buff.len().try_into());
890 }
891 buff.truncate(res.try_into().expect("res is negative"));
892 Ok(String::from_utf8(buff)?)
893 }
894}
895
896fn extract_meta_string<F>(c_function: F, capacity: usize) -> Result<String, MetaValError>
902where
903 F: Fn(*mut c_char, usize) -> i32,
904{
905 let mut buffer = vec![0u8; capacity];
906
907 let result = c_function(buffer.as_mut_ptr().cast::<c_char>(), buffer.len());
909 if result < 0 {
910 return Err(MetaValError::NegativeReturn(result));
911 }
912
913 let returned_len = result as usize;
915 if returned_len >= capacity {
916 return extract_meta_string(c_function, returned_len + 1);
918 }
919
920 debug_assert_eq!(
922 buffer.get(returned_len),
923 Some(&0),
924 "should end with null byte"
925 );
926
927 buffer.truncate(returned_len);
929 Ok(String::from_utf8(buffer)?)
930}
931
932impl Drop for LlamaModel {
933 fn drop(&mut self) {
934 unsafe { llama_cpp_sys_2::llama_free_model(self.model.as_ptr()) }
935 }
936}
937
938#[repr(u32)]
940#[derive(Debug, Eq, Copy, Clone, PartialEq)]
941pub enum VocabType {
942 BPE = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE as _,
944 SPM = llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM as _,
946}
947
948#[derive(thiserror::Error, Debug, Eq, PartialEq)]
950pub enum LlamaTokenTypeFromIntError {
951 #[error("Unknown Value {0}")]
953 UnknownValue(llama_cpp_sys_2::llama_vocab_type),
954}
955
956impl TryFrom<llama_cpp_sys_2::llama_vocab_type> for VocabType {
957 type Error = LlamaTokenTypeFromIntError;
958
959 fn try_from(value: llama_cpp_sys_2::llama_vocab_type) -> Result<Self, Self::Error> {
960 match value {
961 llama_cpp_sys_2::LLAMA_VOCAB_TYPE_BPE => Ok(VocabType::BPE),
962 llama_cpp_sys_2::LLAMA_VOCAB_TYPE_SPM => Ok(VocabType::SPM),
963 unknown => Err(LlamaTokenTypeFromIntError::UnknownValue(unknown)),
964 }
965 }
966}