// llama_cpp_bindings/llama_token_attrs.rs

1use std::ops::{Deref, DerefMut};
2
3use enumflags2::BitFlags;
4
5use crate::llama_token_attr::LlamaTokenAttr;
6use crate::llama_token_attrs_from_int_error::LlamaTokenAttrsFromIntError;
7
8#[cfg(target_env = "msvc")]
9const fn llama_token_type_to_u32(value: llama_cpp_bindings_sys::llama_token_type) -> u32 {
10    value.cast_unsigned()
11}
12
13#[cfg(not(target_env = "msvc"))]
14const fn llama_token_type_to_u32(value: llama_cpp_bindings_sys::llama_token_type) -> u32 {
15    value
16}
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq)]
19pub struct LlamaTokenAttrs(pub BitFlags<LlamaTokenAttr>);
20
21impl Deref for LlamaTokenAttrs {
22    type Target = BitFlags<LlamaTokenAttr>;
23
24    fn deref(&self) -> &Self::Target {
25        &self.0
26    }
27}
28
29impl DerefMut for LlamaTokenAttrs {
30    fn deref_mut(&mut self) -> &mut Self::Target {
31        &mut self.0
32    }
33}
34
35impl TryFrom<llama_cpp_bindings_sys::llama_token_type> for LlamaTokenAttrs {
36    type Error = LlamaTokenAttrsFromIntError;
37
38    fn try_from(value: llama_cpp_bindings_sys::llama_vocab_type) -> Result<Self, Self::Error> {
39        Ok(Self(
40            BitFlags::from_bits(llama_token_type_to_u32(value)).map_err(|bit_flag_error| {
41                LlamaTokenAttrsFromIntError::UnknownValue(bit_flag_error.invalid_bits())
42            })?,
43        ))
44    }
45}
46
#[cfg(test)]
mod tests {
    use enumflags2::BitFlags;

    use super::{LlamaTokenAttr, LlamaTokenAttrs, LlamaTokenAttrsFromIntError};

    #[test]
    fn try_from_valid_single_attribute() {
        let attrs = LlamaTokenAttrs::try_from(llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMAL)
            .expect("LLAMA_TOKEN_ATTR_NORMAL is a known attribute bit");

        // Exact equality also catches stray extra bits, which `contains` alone would miss.
        assert_eq!(
            attrs,
            LlamaTokenAttrs(BitFlags::from_flag(LlamaTokenAttr::Normal))
        );
    }

    #[test]
    fn try_from_combined_attributes() {
        let raw = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMAL
            | llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_CONTROL;

        let attrs = LlamaTokenAttrs::try_from(raw).expect("both bits are known attributes");

        assert_eq!(
            attrs,
            LlamaTokenAttrs(BitFlags::from_flag(LlamaTokenAttr::Normal) | LlamaTokenAttr::Control)
        );
    }

    #[test]
    fn try_from_zero_produces_empty_flags() {
        let attrs = LlamaTokenAttrs::try_from(0).expect("zero is a valid, empty attribute set");

        assert!(attrs.is_empty());
    }

    #[test]
    fn try_from_invalid_bits_returns_error() {
        let err = LlamaTokenAttrs::try_from(!0).unwrap_err();
        let LlamaTokenAttrsFromIntError::UnknownValue(invalid_bits) = err;

        // Every bit outside the known attribute set must be reported, and no known bit may be.
        assert_eq!(
            invalid_bits,
            !BitFlags::<LlamaTokenAttr>::all().bits(),
            "passing !0 must report exactly the bits that are not attributes"
        );
    }

    #[test]
    fn deref_exposes_bitflags_methods() {
        let attrs = LlamaTokenAttrs(BitFlags::from_flag(LlamaTokenAttr::Control));

        assert!(attrs.contains(LlamaTokenAttr::Control));
        assert!(!attrs.contains(LlamaTokenAttr::Normal));
    }

    #[test]
    fn deref_mut_allows_modification() {
        let mut attrs = LlamaTokenAttrs(BitFlags::empty());

        attrs.insert(LlamaTokenAttr::Byte);

        assert_eq!(*attrs, BitFlags::from_flag(LlamaTokenAttr::Byte));
    }
}