llama_cpp_bindings/token/
data_array.rs1use std::ptr;
3
4use crate::error::TokenSamplingError;
5use crate::sampling::LlamaSampler;
6use crate::token::data::LlamaTokenData;
7
8use super::LlamaToken;
9
10#[derive(Debug, Clone, PartialEq)]
12pub struct LlamaTokenDataArray {
13 pub data: Vec<LlamaTokenData>,
15 pub selected: Option<usize>,
17 pub sorted: bool,
19}
20
21impl LlamaTokenDataArray {
22 #[must_use]
36 pub const fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
37 Self {
38 data,
39 selected: None,
40 sorted,
41 }
42 }
43
44 pub fn from_iter<TIterator>(data: TIterator, sorted: bool) -> Self
56 where
57 TIterator: IntoIterator<Item = LlamaTokenData>,
58 {
59 Self::new(data.into_iter().collect(), sorted)
60 }
61
62 #[must_use]
64 pub fn selected_token(&self) -> Option<LlamaToken> {
65 self.data.get(self.selected?).map(LlamaTokenData::id)
66 }
67}
68
69impl LlamaTokenDataArray {
70 pub unsafe fn modify_as_c_llama_token_data_array<TResult>(
84 &mut self,
85 modify: impl FnOnce(&mut llama_cpp_bindings_sys::llama_token_data_array) -> TResult,
86 ) -> TResult {
87 let size = self.data.len();
88 let data = self
89 .data
90 .as_mut_ptr()
91 .cast::<llama_cpp_bindings_sys::llama_token_data>();
92
93 let mut c_llama_token_data_array = llama_cpp_bindings_sys::llama_token_data_array {
94 data,
95 size,
96 selected: self
97 .selected
98 .and_then(|selected_index| selected_index.try_into().ok())
99 .unwrap_or(-1),
100 sorted: self.sorted,
101 };
102
103 let result = modify(&mut c_llama_token_data_array);
104
105 assert!(c_llama_token_data_array.size <= self.data.capacity());
106 unsafe {
108 if !ptr::eq(c_llama_token_data_array.data, data) {
109 ptr::copy(
110 c_llama_token_data_array.data,
111 data,
112 c_llama_token_data_array.size,
113 );
114 }
115 self.data.set_len(c_llama_token_data_array.size);
116 }
117
118 self.sorted = c_llama_token_data_array.sorted;
119 self.selected = c_llama_token_data_array
120 .selected
121 .try_into()
122 .ok()
123 .filter(|&s| s < self.data.len());
124
125 result
126 }
127
128 pub fn apply_sampler(&mut self, sampler: &LlamaSampler) {
130 unsafe {
131 self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
132 llama_cpp_bindings_sys::llama_sampler_apply(
133 sampler.sampler,
134 c_llama_token_data_array,
135 );
136 });
137 }
138 }
139
140 #[must_use]
142 pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
143 self.apply_sampler(sampler);
144 self
145 }
146
147 pub fn sample_token(&mut self, seed: u32) -> Result<LlamaToken, TokenSamplingError> {
152 self.apply_sampler(&LlamaSampler::dist(seed));
153 self.selected_token()
154 .ok_or(TokenSamplingError::NoTokenSelected)
155 }
156
157 pub fn sample_token_greedy(&mut self) -> Result<LlamaToken, TokenSamplingError> {
162 self.apply_sampler(&LlamaSampler::greedy());
163 self.selected_token()
164 .ok_or(TokenSamplingError::NoTokenSelected)
165 }
166}
167
168#[cfg(test)]
169mod tests {
170 use crate::token::LlamaToken;
171 use crate::token::data::LlamaTokenData;
172
173 use super::LlamaTokenDataArray;
174
175 #[test]
176 fn apply_greedy_sampler_selects_highest_logit() {
177 use crate::sampling::LlamaSampler;
178
179 let mut array = LlamaTokenDataArray::new(
180 vec![
181 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
182 LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
183 LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
184 ],
185 false,
186 );
187
188 array.apply_sampler(&LlamaSampler::greedy());
189
190 assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
191 }
192
193 #[test]
194 fn with_sampler_builder_pattern() {
195 use crate::sampling::LlamaSampler;
196
197 let array = LlamaTokenDataArray::new(
198 vec![
199 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
200 LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
201 ],
202 false,
203 )
204 .with_sampler(&mut LlamaSampler::greedy());
205
206 assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
207 }
208
209 #[test]
210 fn sample_token_greedy_returns_highest() {
211 let mut array = LlamaTokenDataArray::new(
212 vec![
213 LlamaTokenData::new(LlamaToken::new(10), 0.1, 0.0),
214 LlamaTokenData::new(LlamaToken::new(20), 9.9, 0.0),
215 ],
216 false,
217 );
218
219 let token = array
220 .sample_token_greedy()
221 .expect("test: greedy sampler should select a token");
222
223 assert_eq!(token, LlamaToken::new(20));
224 }
225
226 #[test]
227 fn from_iter_creates_array_from_iterator() {
228 let array = LlamaTokenDataArray::from_iter(
229 [
230 LlamaTokenData::new(LlamaToken::new(0), 0.0, 0.0),
231 LlamaTokenData::new(LlamaToken::new(1), 1.0, 0.0),
232 LlamaTokenData::new(LlamaToken::new(2), 2.0, 0.0),
233 ],
234 false,
235 );
236
237 assert_eq!(array.data.len(), 3);
238 assert!(!array.sorted);
239 assert!(array.selected.is_none());
240 }
241
242 #[test]
243 fn sample_token_with_seed_selects_a_token() {
244 let mut array = LlamaTokenDataArray::new(
245 vec![
246 LlamaTokenData::new(LlamaToken::new(10), 1.0, 0.0),
247 LlamaTokenData::new(LlamaToken::new(20), 1.0, 0.0),
248 ],
249 false,
250 );
251
252 let token = array
253 .sample_token(42)
254 .expect("test: dist sampler should select a token");
255
256 assert!(token == LlamaToken::new(10) || token == LlamaToken::new(20));
257 }
258
259 #[test]
260 fn selected_token_returns_none_when_no_selection() {
261 let array = LlamaTokenDataArray::new(
262 vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
263 false,
264 );
265
266 assert!(array.selected_token().is_none());
267 }
268
269 #[test]
270 fn selected_token_returns_none_when_index_out_of_bounds() {
271 let array = LlamaTokenDataArray {
272 data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
273 selected: Some(5),
274 sorted: false,
275 };
276
277 assert!(array.selected_token().is_none());
278 }
279
280 #[test]
281 fn modify_as_c_llama_token_data_array_copies_when_data_pointer_changes() {
282 let mut array = LlamaTokenDataArray::new(
283 vec![
284 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
285 LlamaTokenData::new(LlamaToken::new(1), 2.0, 0.0),
286 LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
287 ],
288 false,
289 );
290
291 let replacement = [
292 llama_cpp_bindings_sys::llama_token_data {
293 id: 10,
294 logit: 5.0,
295 p: 0.0,
296 },
297 llama_cpp_bindings_sys::llama_token_data {
298 id: 20,
299 logit: 6.0,
300 p: 0.0,
301 },
302 ];
303
304 unsafe {
305 array.modify_as_c_llama_token_data_array(|c_array| {
306 c_array.data = replacement.as_ptr().cast_mut();
307 c_array.size = replacement.len();
308 c_array.selected = 0;
309 });
310 }
311
312 assert_eq!(array.data.len(), 2);
313 assert_eq!(array.data[0].id(), LlamaToken::new(10));
314 assert_eq!(array.data[1].id(), LlamaToken::new(20));
315 assert_eq!(array.selected, Some(0));
316 }
317
318 #[test]
319 fn selected_overflow_uses_negative_one() {
320 let mut array = LlamaTokenDataArray {
321 data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
322 selected: Some(usize::MAX),
323 sorted: false,
324 };
325
326 unsafe {
327 array.modify_as_c_llama_token_data_array(|c_array| {
328 assert_eq!(c_array.selected, -1);
329 });
330 }
331 }
332}