llama_cpp_bindings/
token_type.rs1use enumflags2::{BitFlags, bitflags};
3use std::ops::{Deref, DerefMut};
4
5#[derive(Eq, PartialEq, Debug, Clone, Copy)]
7#[bitflags]
8#[repr(u32)]
9pub enum LlamaTokenAttr {
10 Unknown = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_UNKNOWN as _,
12 Unused = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_UNUSED as _,
14 Normal = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMAL as _,
16 Control = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_CONTROL as _,
18 UserDefined = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_USER_DEFINED as _,
20 Byte = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_BYTE as _,
22 Normalized = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMALIZED as _,
24 LStrip = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_LSTRIP as _,
26 RStrip = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_RSTRIP as _,
28 SingleWord = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_SINGLE_WORD as _,
30}
31
32#[derive(Debug, Clone, Copy, PartialEq, Eq)]
34pub struct LlamaTokenAttrs(pub BitFlags<LlamaTokenAttr>);
35
36impl Deref for LlamaTokenAttrs {
37 type Target = BitFlags<LlamaTokenAttr>;
38
39 fn deref(&self) -> &Self::Target {
40 &self.0
41 }
42}
43
44impl DerefMut for LlamaTokenAttrs {
45 fn deref_mut(&mut self) -> &mut Self::Target {
46 &mut self.0
47 }
48}
49
50impl TryFrom<llama_cpp_bindings_sys::llama_token_type> for LlamaTokenAttrs {
51 type Error = LlamaTokenTypeFromIntError;
52
53 fn try_from(value: llama_cpp_bindings_sys::llama_vocab_type) -> Result<Self, Self::Error> {
54 Ok(Self(BitFlags::from_bits(value as _).map_err(|e| {
55 LlamaTokenTypeFromIntError::UnknownValue(e.invalid_bits())
56 })?))
57 }
58}
59
60#[derive(thiserror::Error, Debug, Eq, PartialEq)]
62pub enum LlamaTokenTypeFromIntError {
63 #[error("Unknown Value {0}")]
65 UnknownValue(std::ffi::c_uint),
66}
67
#[cfg(test)]
mod tests {
    use enumflags2::BitFlags;

    use super::{LlamaTokenAttr, LlamaTokenAttrs, LlamaTokenTypeFromIntError};

    #[test]
    fn try_from_valid_single_attribute() {
        let attrs = LlamaTokenAttrs::try_from(llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMAL)
            .expect("a single known attribute bit must convert");

        assert!(attrs.contains(LlamaTokenAttr::Normal));
        // Exactly the requested flag must be set, nothing else.
        assert_eq!(attrs.0, BitFlags::from_flag(LlamaTokenAttr::Normal));
    }

    #[test]
    fn try_from_zero_produces_empty_flags() {
        let attrs = LlamaTokenAttrs::try_from(0u32).expect("zero is a valid, empty bitmask");

        assert!(attrs.is_empty());
    }

    #[test]
    fn try_from_invalid_bits_returns_error() {
        let invalid_value = 0xFFFF_FFFFu32;
        // Every bit outside the declared flags should be reported back in the error.
        let expected_invalid_bits = invalid_value & !BitFlags::<LlamaTokenAttr>::all().bits();

        let err = LlamaTokenAttrs::try_from(invalid_value)
            .expect_err("bits outside the declared flags must be rejected");

        // Compare the whole error so both the variant and its payload are checked;
        // a bare `matches!` only yields a bool and asserts nothing on its own.
        assert_eq!(
            err,
            LlamaTokenTypeFromIntError::UnknownValue(expected_invalid_bits)
        );
    }

    #[test]
    fn deref_exposes_bitflags_methods() {
        let attrs = LlamaTokenAttrs(BitFlags::from_flag(LlamaTokenAttr::Control));

        assert!(attrs.contains(LlamaTokenAttr::Control));
        assert!(!attrs.contains(LlamaTokenAttr::Normal));
    }

    #[test]
    fn deref_mut_allows_modification() {
        let mut attrs = LlamaTokenAttrs(BitFlags::empty());

        attrs.insert(LlamaTokenAttr::Byte);

        assert!(attrs.contains(LlamaTokenAttr::Byte));
        assert_eq!(attrs.0, BitFlags::from_flag(LlamaTokenAttr::Byte));
    }
}