Skip to main content

llama_cpp_bindings/
token_type.rs

1//! Utilities for working with `llama_token_type` values.
2use enumflags2::{BitFlags, bitflags};
3use std::ops::{Deref, DerefMut};
4
5/// A rust flavored equivalent of `llama_token_type`.
6#[derive(Eq, PartialEq, Debug, Clone, Copy)]
7#[bitflags]
8#[repr(u32)]
9pub enum LlamaTokenAttr {
10    /// Unknown token attribute.
11    Unknown = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_UNKNOWN as _,
12    /// Unused token attribute.
13    Unused = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_UNUSED as _,
14    /// Normal text token.
15    Normal = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMAL as _,
16    /// Control token (e.g. BOS, EOS).
17    Control = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_CONTROL as _,
18    /// User-defined token.
19    UserDefined = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_USER_DEFINED as _,
20    /// Byte-level fallback token.
21    Byte = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_BYTE as _,
22    /// Token with normalized text.
23    Normalized = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMALIZED as _,
24    /// Token with left-stripped whitespace.
25    LStrip = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_LSTRIP as _,
26    /// Token with right-stripped whitespace.
27    RStrip = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_RSTRIP as _,
28    /// Token representing a single word.
29    SingleWord = llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_SINGLE_WORD as _,
30}
31
32/// A set of `LlamaTokenAttrs`
33#[derive(Debug, Clone, Copy, PartialEq, Eq)]
34pub struct LlamaTokenAttrs(pub BitFlags<LlamaTokenAttr>);
35
36impl Deref for LlamaTokenAttrs {
37    type Target = BitFlags<LlamaTokenAttr>;
38
39    fn deref(&self) -> &Self::Target {
40        &self.0
41    }
42}
43
44impl DerefMut for LlamaTokenAttrs {
45    fn deref_mut(&mut self) -> &mut Self::Target {
46        &mut self.0
47    }
48}
49
50impl TryFrom<llama_cpp_bindings_sys::llama_token_type> for LlamaTokenAttrs {
51    type Error = LlamaTokenTypeFromIntError;
52
53    fn try_from(value: llama_cpp_bindings_sys::llama_vocab_type) -> Result<Self, Self::Error> {
54        Ok(Self(BitFlags::from_bits(value as _).map_err(|e| {
55            LlamaTokenTypeFromIntError::UnknownValue(e.invalid_bits())
56        })?))
57    }
58}
59
60/// An error type for `LlamaTokenType::try_from`.
61#[derive(thiserror::Error, Debug, Eq, PartialEq)]
62pub enum LlamaTokenTypeFromIntError {
63    /// The value is not a valid `llama_token_type`.
64    #[error("Unknown Value {0}")]
65    UnknownValue(std::ffi::c_uint),
66}
67
/// Unit tests for the raw-integer conversion and the `Deref` / `DerefMut`
/// forwarding of `LlamaTokenAttrs`.
#[cfg(test)]
mod tests {
    use enumflags2::BitFlags;

    use super::{LlamaTokenAttr, LlamaTokenAttrs, LlamaTokenTypeFromIntError};

    /// A single known attribute bit converts into a set containing that flag.
    #[test]
    fn try_from_valid_single_attribute() {
        let attrs =
            LlamaTokenAttrs::try_from(llama_cpp_bindings_sys::LLAMA_TOKEN_ATTR_NORMAL as u32);

        assert!(attrs.is_ok());
        assert!(
            attrs
                .expect("valid attribute")
                .contains(LlamaTokenAttr::Normal)
        );
    }

    /// Zero has no bits set and therefore yields an empty set.
    #[test]
    fn try_from_zero_produces_empty_flags() {
        let attrs = LlamaTokenAttrs::try_from(0u32);

        assert!(attrs.is_ok());
        assert!(attrs.expect("valid attribute").is_empty());
    }

    /// A value with bits outside the known flags is rejected.
    #[test]
    fn try_from_invalid_bits_returns_error() {
        let invalid_value = 0xFFFF_FFFFu32;
        let result = LlamaTokenAttrs::try_from(invalid_value);

        assert!(result.is_err());
        // `matches!` only yields a bool; it must be wrapped in `assert!` for
        // the variant check to actually be enforced.
        assert!(matches!(
            result.expect_err("should fail"),
            LlamaTokenTypeFromIntError::UnknownValue(_)
        ));
    }

    /// Read-only `BitFlags` methods are reachable through `Deref`.
    #[test]
    fn deref_exposes_bitflags_methods() {
        let attrs = LlamaTokenAttrs(BitFlags::from_flag(LlamaTokenAttr::Control));

        assert!(attrs.contains(LlamaTokenAttr::Control));
        assert!(!attrs.contains(LlamaTokenAttr::Normal));
    }

    /// Mutating `BitFlags` methods are reachable through `DerefMut`.
    #[test]
    fn deref_mut_allows_modification() {
        let mut attrs = LlamaTokenAttrs(BitFlags::empty());

        attrs.insert(LlamaTokenAttr::Byte);

        assert!(attrs.contains(LlamaTokenAttr::Byte));
    }
}
119        attrs.insert(LlamaTokenAttr::Byte);
120
121        assert!(attrs.contains(LlamaTokenAttr::Byte));
122    }
123}