llama_cpp_bindings/
llama_token_attrs.rs1use std::ops::{Deref, DerefMut};
2
3use enumflags2::BitFlags;
4
5use crate::llama_token_attr::LlamaTokenAttr;
6use crate::llama_token_attrs_from_int_error::LlamaTokenAttrsFromIntError;
7
8#[cfg(target_env = "msvc")]
9const fn llama_token_type_to_u32(value: llama_cpp_bindings_sys::llama_token_type) -> u32 {
10 value.cast_unsigned()
11}
12
13#[cfg(not(target_env = "msvc"))]
14const fn llama_token_type_to_u32(value: llama_cpp_bindings_sys::llama_token_type) -> u32 {
15 value
16}
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq)]
19pub struct LlamaTokenAttrs(pub BitFlags<LlamaTokenAttr>);
20
21impl Deref for LlamaTokenAttrs {
22 type Target = BitFlags<LlamaTokenAttr>;
23
24 fn deref(&self) -> &Self::Target {
25 &self.0
26 }
27}
28
29impl DerefMut for LlamaTokenAttrs {
30 fn deref_mut(&mut self) -> &mut Self::Target {
31 &mut self.0
32 }
33}
34
35impl TryFrom<llama_cpp_bindings_sys::llama_token_type> for LlamaTokenAttrs {
36 type Error = LlamaTokenAttrsFromIntError;
37
38 fn try_from(value: llama_cpp_bindings_sys::llama_vocab_type) -> Result<Self, Self::Error> {
39 Ok(Self(
40 BitFlags::from_bits(llama_token_type_to_u32(value)).map_err(|bit_flag_error| {
41 LlamaTokenAttrsFromIntError::UnknownValue(bit_flag_error.invalid_bits())
42 })?,
43 ))
44 }
45}
46
#[cfg(test)]
mod tests {
    use enumflags2::BitFlags;

    use super::{LlamaTokenAttr, LlamaTokenAttrs, LlamaTokenAttrsFromIntError};

    /// A single known flag converts successfully and is reported as set.
    #[test]
    fn try_from_valid_single_attribute() {
        let normal = LlamaTokenAttrs::try_from(llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMAL)
            .expect("valid attribute");

        assert!(normal.contains(LlamaTokenAttr::Normal));
    }

    /// A zero mask converts to an empty flag set.
    #[test]
    fn try_from_zero_produces_empty_flags() {
        let none = LlamaTokenAttrs::try_from(0).expect("valid attribute");

        assert!(none.is_empty());
    }

    /// An all-ones mask is rejected and reports the unrecognised bits.
    #[test]
    fn try_from_invalid_bits_returns_error() {
        let LlamaTokenAttrsFromIntError::UnknownValue(invalid_bits) =
            LlamaTokenAttrs::try_from(!0).unwrap_err();

        assert!(
            invalid_bits > 0,
            "passing !0 must produce at least one unknown bit"
        );
    }

    /// `Deref` forwards read-only `BitFlags` queries to the inner set.
    #[test]
    fn deref_exposes_bitflags_methods() {
        let control_only = LlamaTokenAttrs(BitFlags::from_flag(LlamaTokenAttr::Control));

        assert!(control_only.contains(LlamaTokenAttr::Control));
        assert!(!control_only.contains(LlamaTokenAttr::Normal));
    }

    /// `DerefMut` forwards mutating `BitFlags` calls to the inner set.
    #[test]
    fn deref_mut_allows_modification() {
        let mut flags = LlamaTokenAttrs(BitFlags::empty());
        flags.insert(LlamaTokenAttr::Byte);

        assert!(flags.contains(LlamaTokenAttr::Byte));
    }
}