Skip to main content

llama_cpp_bindings/
token_type.rs

//! Utilities for working with llama.cpp token attributes (the `LLAMA_TOKEN_ATTR_*` bit flags).
2use enumflags2::{BitFlags, bitflags};
3use std::ops::{Deref, DerefMut};
4
5/// A rust flavored equivalent of `llama_token_type`.
6#[derive(Eq, PartialEq, Debug, Clone, Copy)]
7#[bitflags]
8#[repr(u32)]
9pub enum LlamaTokenAttr {
10    /// Unknown token attribute.
11    Unknown = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_UNKNOWN as _,
12    /// Unused token attribute.
13    Unused = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_UNUSED as _,
14    /// Normal text token.
15    Normal = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMAL as _,
16    /// Control token (e.g. BOS, EOS).
17    Control = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_CONTROL as _,
18    /// User-defined token.
19    UserDefined = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_USER_DEFINED as _,
20    /// Byte-level fallback token.
21    Byte = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_BYTE as _,
22    /// Token with normalized text.
23    Normalized = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMALIZED as _,
24    /// Token with left-stripped whitespace.
25    LStrip = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_LSTRIP as _,
26    /// Token with right-stripped whitespace.
27    RStrip = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_RSTRIP as _,
28    /// Token representing a single word.
29    SingleWord = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_SINGLE_WORD as _,
30}
31
32/// A set of `LlamaTokenAttrs`
33#[derive(Debug, Clone, Copy, PartialEq, Eq)]
34pub struct LlamaTokenAttrs(pub BitFlags<LlamaTokenAttr>);
35
36impl Deref for LlamaTokenAttrs {
37    type Target = BitFlags<LlamaTokenAttr>;
38
39    fn deref(&self) -> &Self::Target {
40        &self.0
41    }
42}
43
44impl DerefMut for LlamaTokenAttrs {
45    fn deref_mut(&mut self) -> &mut Self::Target {
46        &mut self.0
47    }
48}
49
50impl TryFrom<llama_cpp_bindings_sys::llama_token_type> for LlamaTokenAttrs {
51    type Error = LlamaTokenTypeFromIntError;
52
53    fn try_from(value: llama_cpp_bindings_sys::llama_vocab_type) -> Result<Self, Self::Error> {
54        Ok(Self(BitFlags::from_bits(value as _).map_err(|e| {
55            LlamaTokenTypeFromIntError::UnknownValue(e.invalid_bits())
56        })?))
57    }
58}
59
60/// An error type for `LlamaTokenType::try_from`.
61#[derive(thiserror::Error, Debug, Eq, PartialEq)]
62pub enum LlamaTokenTypeFromIntError {
63    /// The value is not a valid `llama_token_type`.
64    #[error("Unknown Value {0}")]
65    UnknownValue(std::ffi::c_uint),
66}
67
#[cfg(test)]
mod tests {
    use enumflags2::BitFlags;

    use super::{LlamaTokenAttr, LlamaTokenAttrs, LlamaTokenTypeFromIntError};

    #[test]
    fn try_from_valid_single_attribute() {
        let attrs = LlamaTokenAttrs::try_from(llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMAL)
            .expect("NORMAL is a known attribute bit");

        assert!(attrs.contains(LlamaTokenAttr::Normal));
        // Exactly the requested flag must be set, nothing else.
        assert_eq!(attrs.0, BitFlags::from_flag(LlamaTokenAttr::Normal));
    }

    #[test]
    fn try_from_combined_attributes() {
        let raw = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_CONTROL
            | llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_LSTRIP;

        let attrs = LlamaTokenAttrs::try_from(raw).expect("both bits are known attributes");

        assert!(attrs.contains(LlamaTokenAttr::Control));
        assert!(attrs.contains(LlamaTokenAttr::LStrip));
        assert!(!attrs.contains(LlamaTokenAttr::Normal));
    }

    #[test]
    fn try_from_zero_produces_empty_flags() {
        let attrs = LlamaTokenAttrs::try_from(0u32).expect("zero has no invalid bits");

        assert!(attrs.is_empty());
    }

    #[test]
    fn try_from_invalid_bits_returns_error() {
        let invalid_value = 0xFFFF_FFFFu32;
        // Only the bits outside the known attribute set are reported back.
        let expected_invalid_bits = invalid_value & !BitFlags::<LlamaTokenAttr>::all().bits();

        let result = LlamaTokenAttrs::try_from(invalid_value);

        // Compare the whole result: a bare `matches!` only produces a bool and asserts nothing.
        assert_eq!(
            result,
            Err(LlamaTokenTypeFromIntError::UnknownValue(expected_invalid_bits))
        );
    }

    #[test]
    fn deref_exposes_bitflags_methods() {
        let attrs = LlamaTokenAttrs(BitFlags::from_flag(LlamaTokenAttr::Control));

        assert!(attrs.contains(LlamaTokenAttr::Control));
        assert!(!attrs.contains(LlamaTokenAttr::Normal));
    }

    #[test]
    fn deref_mut_allows_modification() {
        let mut attrs = LlamaTokenAttrs(BitFlags::empty());

        attrs.insert(LlamaTokenAttr::Byte);

        assert!(attrs.contains(LlamaTokenAttr::Byte));
    }
}