llama_cpp_bindings/token/
data_array.rs1use std::ptr;
3
4use crate::error::TokenSamplingError;
5use crate::sampling::LlamaSampler;
6use crate::token::data::LlamaTokenData;
7
8use super::LlamaToken;
9
10#[derive(Debug, Clone, PartialEq)]
12pub struct LlamaTokenDataArray {
13 pub data: Vec<LlamaTokenData>,
15 pub selected: Option<usize>,
17 pub sorted: bool,
19}
20
21impl LlamaTokenDataArray {
22 #[must_use]
36 pub const fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
37 Self {
38 data,
39 selected: None,
40 sorted,
41 }
42 }
43
44 pub fn from_iter<TIterator>(data: TIterator, sorted: bool) -> Self
56 where
57 TIterator: IntoIterator<Item = LlamaTokenData>,
58 {
59 Self::new(data.into_iter().collect(), sorted)
60 }
61
62 #[must_use]
64 pub fn selected_token(&self) -> Option<LlamaToken> {
65 self.data.get(self.selected?).map(LlamaTokenData::id)
66 }
67}
68
69impl LlamaTokenDataArray {
70 pub unsafe fn modify_as_c_llama_token_data_array<TResult>(
84 &mut self,
85 modify: impl FnOnce(&mut llama_cpp_bindings_sys::llama_token_data_array) -> TResult,
86 ) -> TResult {
87 let size = self.data.len();
88 let data = self
89 .data
90 .as_mut_ptr()
91 .cast::<llama_cpp_bindings_sys::llama_token_data>();
92
93 let mut c_llama_token_data_array = llama_cpp_bindings_sys::llama_token_data_array {
94 data,
95 size,
96 selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1),
97 sorted: self.sorted,
98 };
99
100 let result = modify(&mut c_llama_token_data_array);
101
102 assert!(c_llama_token_data_array.size <= self.data.capacity());
103 unsafe {
105 if !ptr::eq(c_llama_token_data_array.data, data) {
106 ptr::copy(
107 c_llama_token_data_array.data,
108 data,
109 c_llama_token_data_array.size,
110 );
111 }
112 self.data.set_len(c_llama_token_data_array.size);
113 }
114
115 self.sorted = c_llama_token_data_array.sorted;
116 self.selected = c_llama_token_data_array
117 .selected
118 .try_into()
119 .ok()
120 .filter(|&s| s < self.data.len());
121
122 result
123 }
124
125 pub fn apply_sampler(&mut self, sampler: &LlamaSampler) {
127 unsafe {
128 self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
129 llama_cpp_bindings_sys::llama_sampler_apply(
130 sampler.sampler,
131 c_llama_token_data_array,
132 );
133 });
134 }
135 }
136
137 #[must_use]
139 pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
140 self.apply_sampler(sampler);
141 self
142 }
143
144 pub fn sample_token(&mut self, seed: u32) -> Result<LlamaToken, TokenSamplingError> {
149 self.apply_sampler(&LlamaSampler::dist(seed));
150 self.selected_token()
151 .ok_or(TokenSamplingError::NoTokenSelected)
152 }
153
154 pub fn sample_token_greedy(&mut self) -> Result<LlamaToken, TokenSamplingError> {
159 self.apply_sampler(&LlamaSampler::greedy());
160 self.selected_token()
161 .ok_or(TokenSamplingError::NoTokenSelected)
162 }
163}
164
165#[cfg(test)]
166mod tests {
167 use crate::token::LlamaToken;
168 use crate::token::data::LlamaTokenData;
169
170 use super::LlamaTokenDataArray;
171
172 #[test]
173 fn apply_greedy_sampler_selects_highest_logit() {
174 use crate::sampling::LlamaSampler;
175
176 let mut array = LlamaTokenDataArray::new(
177 vec![
178 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
179 LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
180 LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
181 ],
182 false,
183 );
184
185 array.apply_sampler(&LlamaSampler::greedy());
186
187 assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
188 }
189
190 #[test]
191 fn with_sampler_builder_pattern() {
192 use crate::sampling::LlamaSampler;
193
194 let array = LlamaTokenDataArray::new(
195 vec![
196 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
197 LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
198 ],
199 false,
200 )
201 .with_sampler(&mut LlamaSampler::greedy());
202
203 assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
204 }
205
206 #[test]
207 fn sample_token_greedy_returns_highest() {
208 let mut array = LlamaTokenDataArray::new(
209 vec![
210 LlamaTokenData::new(LlamaToken::new(10), 0.1, 0.0),
211 LlamaTokenData::new(LlamaToken::new(20), 9.9, 0.0),
212 ],
213 false,
214 );
215
216 let token = array
217 .sample_token_greedy()
218 .expect("test: greedy sampler should select a token");
219
220 assert_eq!(token, LlamaToken::new(20));
221 }
222
223 #[test]
224 fn from_iter_creates_array_from_iterator() {
225 let array = LlamaTokenDataArray::from_iter(
226 [
227 LlamaTokenData::new(LlamaToken::new(0), 0.0, 0.0),
228 LlamaTokenData::new(LlamaToken::new(1), 1.0, 0.0),
229 LlamaTokenData::new(LlamaToken::new(2), 2.0, 0.0),
230 ],
231 false,
232 );
233
234 assert_eq!(array.data.len(), 3);
235 assert!(!array.sorted);
236 assert!(array.selected.is_none());
237 }
238
239 #[test]
240 fn sample_token_with_seed_selects_a_token() {
241 let mut array = LlamaTokenDataArray::new(
242 vec![
243 LlamaTokenData::new(LlamaToken::new(10), 1.0, 0.0),
244 LlamaTokenData::new(LlamaToken::new(20), 1.0, 0.0),
245 ],
246 false,
247 );
248
249 let token = array
250 .sample_token(42)
251 .expect("test: dist sampler should select a token");
252
253 assert!(token == LlamaToken::new(10) || token == LlamaToken::new(20));
254 }
255
256 #[test]
257 fn selected_token_returns_none_when_no_selection() {
258 let array = LlamaTokenDataArray::new(
259 vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
260 false,
261 );
262
263 assert!(array.selected_token().is_none());
264 }
265
266 #[test]
267 fn selected_token_returns_none_when_index_out_of_bounds() {
268 let array = LlamaTokenDataArray {
269 data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
270 selected: Some(5),
271 sorted: false,
272 };
273
274 assert!(array.selected_token().is_none());
275 }
276
277 #[test]
278 fn modify_as_c_llama_token_data_array_copies_when_data_pointer_changes() {
279 let mut array = LlamaTokenDataArray::new(
280 vec![
281 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
282 LlamaTokenData::new(LlamaToken::new(1), 2.0, 0.0),
283 LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
284 ],
285 false,
286 );
287
288 let replacement = [
289 llama_cpp_bindings_sys::llama_token_data {
290 id: 10,
291 logit: 5.0,
292 p: 0.0,
293 },
294 llama_cpp_bindings_sys::llama_token_data {
295 id: 20,
296 logit: 6.0,
297 p: 0.0,
298 },
299 ];
300
301 unsafe {
302 array.modify_as_c_llama_token_data_array(|c_array| {
303 c_array.data = replacement.as_ptr().cast_mut();
304 c_array.size = replacement.len();
305 c_array.selected = 0;
306 });
307 }
308
309 assert_eq!(array.data.len(), 2);
310 assert_eq!(array.data[0].id(), LlamaToken::new(10));
311 assert_eq!(array.data[1].id(), LlamaToken::new(20));
312 assert_eq!(array.selected, Some(0));
313 }
314
315 #[test]
316 fn selected_overflow_uses_negative_one() {
317 let mut array = LlamaTokenDataArray {
318 data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
319 selected: Some(usize::MAX),
320 sorted: false,
321 };
322
323 unsafe {
324 array.modify_as_c_llama_token_data_array(|c_array| {
325 assert_eq!(c_array.selected, -1);
326 });
327 }
328 }
329}