llama_cpp_bindings/token/data.rs
1//! Safe wrapper around `llama_token_data`.
2use crate::token::LlamaToken;
3
4/// A transparent wrapper around `llama_token_data`.
5///
6/// Do not rely on `repr(transparent)` for this type. It should be considered an implementation
7/// detail and may change across minor versions.
8#[derive(Clone, Copy, Debug, PartialEq)]
9#[repr(transparent)]
10pub struct LlamaTokenData {
11 data: llama_cpp_bindings_sys::llama_token_data,
12}
13
14impl LlamaTokenData {
15 /// Create a new token data from a token, logit, and probability.
16 /// ```
17 /// # use llama_cpp_bindings::token::LlamaToken;
18 /// # use llama_cpp_bindings::token::data::LlamaTokenData;
19 /// let token = LlamaToken::new(1);
20 /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
21 #[must_use]
22 pub const fn new(LlamaToken(id): LlamaToken, logit: f32, p: f32) -> Self {
23 Self {
24 data: llama_cpp_bindings_sys::llama_token_data { id, logit, p },
25 }
26 }
27 /// Get the token's id
28 /// ```
29 /// # use llama_cpp_bindings::token::LlamaToken;
30 /// # use llama_cpp_bindings::token::data::LlamaTokenData;
31 /// let token = LlamaToken::new(1);
32 /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
33 /// assert_eq!(token_data.id(), token);
34 /// ```
35 #[must_use]
36 pub const fn id(&self) -> LlamaToken {
37 LlamaToken(self.data.id)
38 }
39
40 /// Get the token's logit
41 /// ```
42 /// # use llama_cpp_bindings::token::LlamaToken;
43 /// # use llama_cpp_bindings::token::data::LlamaTokenData;
44 /// let token = LlamaToken::new(1);
45 /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
46 /// assert_eq!(token_data.logit(), 1.0);
47 /// ```
48 #[must_use]
49 pub const fn logit(&self) -> f32 {
50 self.data.logit
51 }
52
53 /// Get the token's probability
54 /// ```
55 /// # use llama_cpp_bindings::token::LlamaToken;
56 /// # use llama_cpp_bindings::token::data::LlamaTokenData;
57 /// let token = LlamaToken::new(1);
58 /// let token_data = LlamaTokenData::new(token, 1.0, 1.0);
59 /// assert_eq!(token_data.p(), 1.0);
60 /// ```
61 #[must_use]
62 pub const fn p(&self) -> f32 {
63 self.data.p
64 }
65
66 /// Set the token's id
67 /// ```
68 /// # use llama_cpp_bindings::token::LlamaToken;
69 /// # use llama_cpp_bindings::token::data::LlamaTokenData;
70 /// let token = LlamaToken::new(1);
71 /// let mut token_data = LlamaTokenData::new(token, 1.0, 1.0);
72 /// token_data.set_id(LlamaToken::new(2));
73 /// assert_eq!(token_data.id(), LlamaToken::new(2));
74 /// ```
75 pub const fn set_id(&mut self, id: LlamaToken) {
76 self.data.id = id.0;
77 }
78
79 /// Set the token's logit
80 /// ```
81 /// # use llama_cpp_bindings::token::LlamaToken;
82 /// # use llama_cpp_bindings::token::data::LlamaTokenData;
83 /// let token = LlamaToken::new(1);
84 /// let mut token_data = LlamaTokenData::new(token, 1.0, 1.0);
85 /// token_data.set_logit(2.0);
86 /// assert_eq!(token_data.logit(), 2.0);
87 /// ```
88 pub const fn set_logit(&mut self, logit: f32) {
89 self.data.logit = logit;
90 }
91
92 /// Set the token's probability
93 /// ```
94 /// # use llama_cpp_bindings::token::LlamaToken;
95 /// # use llama_cpp_bindings::token::data::LlamaTokenData;
96 /// let token = LlamaToken::new(1);
97 /// let mut token_data = LlamaTokenData::new(token, 1.0, 1.0);
98 /// token_data.set_p(2.0);
99 /// assert_eq!(token_data.p(), 2.0);
100 /// ```
101 pub const fn set_p(&mut self, p: f32) {
102 self.data.p = p;
103 }
104}
105
106#[cfg(test)]
107mod tests {
108 use super::LlamaTokenData;
109 use crate::token::LlamaToken;
110
111 #[test]
112 fn new_stores_all_fields() {
113 let token = LlamaToken::new(7);
114 let data = LlamaTokenData::new(token, 2.5, 0.8);
115 assert_eq!(data.id(), token);
116 assert!((data.logit() - 2.5_f32).abs() < f32::EPSILON);
117 assert!((data.p() - 0.8_f32).abs() < f32::EPSILON);
118 }
119
120 #[test]
121 fn set_id_updates_token() {
122 let mut data = LlamaTokenData::new(LlamaToken::new(1), 0.0, 0.0);
123 data.set_id(LlamaToken::new(42));
124 assert_eq!(data.id(), LlamaToken::new(42));
125 }
126
127 #[test]
128 fn set_logit_updates_logit() {
129 let mut data = LlamaTokenData::new(LlamaToken::new(1), 0.0, 0.0);
130 data.set_logit(-1.5);
131 assert!((data.logit() - (-1.5_f32)).abs() < f32::EPSILON);
132 }
133
134 #[test]
135 fn set_p_updates_probability() {
136 let mut data = LlamaTokenData::new(LlamaToken::new(1), 0.0, 0.0);
137 data.set_p(0.95);
138 assert!((data.p() - 0.95_f32).abs() < f32::EPSILON);
139 }
140}