Skip to main content

llama_cpp_bindings/token/
data.rs

1//! Safe wrapper around `llama_token_data`.
2use crate::token::LlamaToken;
3
4/// A transparent wrapper around `llama_token_data`.
5///
6/// Do not rely on `repr(transparent)` for this type. It should be considered an implementation
7/// detail and may change across minor versions.
8#[derive(Clone, Copy, Debug, PartialEq)]
9#[repr(transparent)]
10pub struct LlamaTokenData {
11    data: llama_cpp_bindings_sys::llama_token_data,
12}
13
14impl LlamaTokenData {
15    /// Create a new token data from a token, logit, and probability.
16    /// ```
17    /// # use llama_cpp_bindings::token::LlamaToken;
18    /// # use llama_cpp_bindings::token::data::LlamaTokenData;
19    /// let token = LlamaToken::new(1);
20    /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
21    #[must_use]
22    pub const fn new(LlamaToken(id): LlamaToken, logit: f32, p: f32) -> Self {
23        Self {
24            data: llama_cpp_bindings_sys::llama_token_data { id, logit, p },
25        }
26    }
27    /// Get the token's id
28    /// ```
29    /// # use llama_cpp_bindings::token::LlamaToken;
30    /// # use llama_cpp_bindings::token::data::LlamaTokenData;
31    /// let token = LlamaToken::new(1);
32    /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
33    /// assert_eq!(token_data.id(), token);
34    /// ```
35    #[must_use]
36    pub const fn id(&self) -> LlamaToken {
37        LlamaToken(self.data.id)
38    }
39
40    /// Get the token's logit
41    /// ```
42    /// # use llama_cpp_bindings::token::LlamaToken;
43    /// # use llama_cpp_bindings::token::data::LlamaTokenData;
44    /// let token = LlamaToken::new(1);
45    /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
46    /// assert_eq!(token_data.logit(), 1.0);
47    /// ```
48    #[must_use]
49    pub const fn logit(&self) -> f32 {
50        self.data.logit
51    }
52
53    /// Get the token's probability
54    /// ```
55    /// # use llama_cpp_bindings::token::LlamaToken;
56    /// # use llama_cpp_bindings::token::data::LlamaTokenData;
57    /// let token = LlamaToken::new(1);
58    /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
59    /// assert_eq!(token_data.p(), 1.0);
60    /// ```
61    #[must_use]
62    pub const fn p(&self) -> f32 {
63        self.data.p
64    }
65
66    /// Set the token's id
67    /// ```
68    /// # use llama_cpp_bindings::token::LlamaToken;
69    /// # use llama_cpp_bindings::token::data::LlamaTokenData;
70    /// let token = LlamaToken::new(1);
71    /// let mut token_data = LlamaTokenData::new(token, 1.0, 1.0);
72    /// token_data.set_id(LlamaToken::new(2));
73    /// assert_eq!(token_data.id(), LlamaToken::new(2));
74    /// ```
75    pub const fn set_id(&mut self, id: LlamaToken) {
76        self.data.id = id.0;
77    }
78
79    /// Set the token's logit
80    /// ```
81    /// # use llama_cpp_bindings::token::LlamaToken;
82    /// # use llama_cpp_bindings::token::data::LlamaTokenData;
83    /// let token = LlamaToken::new(1);
84    /// let mut token_data = LlamaTokenData::new(token, 1.0, 1.0);
85    /// token_data.set_logit(2.0);
86    /// assert_eq!(token_data.logit(), 2.0);
87    /// ```
88    pub const fn set_logit(&mut self, logit: f32) {
89        self.data.logit = logit;
90    }
91
92    /// Set the token's probability
93    /// ```
94    /// # use llama_cpp_bindings::token::LlamaToken;
95    /// # use llama_cpp_bindings::token::data::LlamaTokenData;
96    /// let token = LlamaToken::new(1);
97    /// let mut token_data = LlamaTokenData::new(token, 1.0, 1.0);
98    /// token_data.set_p(2.0);
99    /// assert_eq!(token_data.p(), 2.0);
100    /// ```
101    pub const fn set_p(&mut self, p: f32) {
102        self.data.p = p;
103    }
104}
105
106#[cfg(test)]
107mod tests {
108    use super::LlamaTokenData;
109    use crate::token::LlamaToken;
110
111    #[test]
112    fn new_stores_all_fields() {
113        let token = LlamaToken::new(7);
114        let data = LlamaTokenData::new(token, 2.5, 0.8);
115        assert_eq!(data.id(), token);
116        assert!((data.logit() - 2.5_f32).abs() < f32::EPSILON);
117        assert!((data.p() - 0.8_f32).abs() < f32::EPSILON);
118    }
119
120    #[test]
121    fn set_id_updates_token() {
122        let mut data = LlamaTokenData::new(LlamaToken::new(1), 0.0, 0.0);
123        data.set_id(LlamaToken::new(42));
124        assert_eq!(data.id(), LlamaToken::new(42));
125    }
126
127    #[test]
128    fn set_logit_updates_logit() {
129        let mut data = LlamaTokenData::new(LlamaToken::new(1), 0.0, 0.0);
130        data.set_logit(-1.5);
131        assert!((data.logit() - (-1.5_f32)).abs() < f32::EPSILON);
132    }
133
134    #[test]
135    fn set_p_updates_probability() {
136        let mut data = LlamaTokenData::new(LlamaToken::new(1), 0.0, 0.0);
137        data.set_p(0.95);
138        assert!((data.p() - 0.95_f32).abs() < f32::EPSILON);
139    }
140}