llama_cpp_bindings/token/
data_array.rs1use std::ptr;
3
4use crate::error::TokenSamplingError;
5use crate::sampling::LlamaSampler;
6use crate::token::data::LlamaTokenData;
7
8use super::LlamaToken;
9
10#[derive(Debug, Clone, PartialEq)]
12pub struct LlamaTokenDataArray {
13 pub data: Vec<LlamaTokenData>,
15 pub selected: Option<usize>,
17 pub sorted: bool,
19}
20
21impl LlamaTokenDataArray {
22 #[must_use]
36 pub const fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
37 Self {
38 data,
39 selected: None,
40 sorted,
41 }
42 }
43
44 pub fn from_iter<TIterator>(data: TIterator, sorted: bool) -> Self
56 where
57 TIterator: IntoIterator<Item = LlamaTokenData>,
58 {
59 Self::new(data.into_iter().collect(), sorted)
60 }
61
62 #[must_use]
64 pub fn selected_token(&self) -> Option<LlamaToken> {
65 self.data.get(self.selected?).map(LlamaTokenData::id)
66 }
67}
68
69impl LlamaTokenDataArray {
70 pub unsafe fn modify_as_c_llama_token_data_array<TResult>(
84 &mut self,
85 modify: impl FnOnce(&mut llama_cpp_bindings_sys::llama_token_data_array) -> TResult,
86 ) -> TResult {
87 let size = self.data.len();
88 let data = self
89 .data
90 .as_mut_ptr()
91 .cast::<llama_cpp_bindings_sys::llama_token_data>();
92
93 let mut c_llama_token_data_array = llama_cpp_bindings_sys::llama_token_data_array {
94 data,
95 size,
96 selected: self
97 .selected
98 .and_then(|selected_index| selected_index.try_into().ok())
99 .unwrap_or(-1),
100 sorted: self.sorted,
101 };
102
103 let result = modify(&mut c_llama_token_data_array);
104
105 assert!(c_llama_token_data_array.size <= self.data.capacity());
106 unsafe {
108 if !ptr::eq(c_llama_token_data_array.data, data) {
109 ptr::copy(
110 c_llama_token_data_array.data,
111 data,
112 c_llama_token_data_array.size,
113 );
114 }
115 self.data.set_len(c_llama_token_data_array.size);
116 }
117
118 self.sorted = c_llama_token_data_array.sorted;
119 self.selected = c_llama_token_data_array
120 .selected
121 .try_into()
122 .ok()
123 .filter(|&s| s < self.data.len());
124
125 result
126 }
127
128 pub fn apply_sampler(&mut self, sampler: &LlamaSampler) {
136 unsafe {
137 self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
138 let mut out_error: *mut std::os::raw::c_char = ptr::null_mut();
139 let status = llama_cpp_bindings_sys::llama_rs_sampler_apply(
140 sampler.sampler,
141 c_llama_token_data_array,
142 &raw mut out_error,
143 );
144 if status != llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_APPLY_OK {
145 let message = crate::ffi_error_reader::read_and_free_cpp_error(out_error);
146 panic!("llama_rs_sampler_apply returned status {status}: {message}");
147 }
148 });
149 }
150 }
151
152 #[must_use]
154 pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
155 self.apply_sampler(sampler);
156 self
157 }
158
159 pub fn sample_token(&mut self, seed: u32) -> Result<LlamaToken, TokenSamplingError> {
164 self.apply_sampler(&LlamaSampler::dist(seed));
165 self.selected_token()
166 .ok_or(TokenSamplingError::NoTokenSelected)
167 }
168
169 pub fn sample_token_greedy(&mut self) -> Result<LlamaToken, TokenSamplingError> {
174 self.apply_sampler(&LlamaSampler::greedy());
175 self.selected_token()
176 .ok_or(TokenSamplingError::NoTokenSelected)
177 }
178}
179
180#[cfg(test)]
181mod tests {
182 use crate::token::LlamaToken;
183 use crate::token::data::LlamaTokenData;
184
185 use super::LlamaTokenDataArray;
186
187 #[test]
188 fn apply_greedy_sampler_selects_highest_logit() {
189 use crate::sampling::LlamaSampler;
190
191 let mut array = LlamaTokenDataArray::new(
192 vec![
193 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
194 LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
195 LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
196 ],
197 false,
198 );
199
200 array.apply_sampler(&LlamaSampler::greedy());
201
202 assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
203 }
204
205 #[test]
206 fn with_sampler_builder_pattern() {
207 use crate::sampling::LlamaSampler;
208
209 let array = LlamaTokenDataArray::new(
210 vec![
211 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
212 LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
213 ],
214 false,
215 )
216 .with_sampler(&mut LlamaSampler::greedy());
217
218 assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
219 }
220
221 #[test]
222 fn sample_token_greedy_returns_highest() {
223 let mut array = LlamaTokenDataArray::new(
224 vec![
225 LlamaTokenData::new(LlamaToken::new(10), 0.1, 0.0),
226 LlamaTokenData::new(LlamaToken::new(20), 9.9, 0.0),
227 ],
228 false,
229 );
230
231 let token = array
232 .sample_token_greedy()
233 .expect("test: greedy sampler should select a token");
234
235 assert_eq!(token, LlamaToken::new(20));
236 }
237
238 #[test]
239 fn from_iter_creates_array_from_iterator() {
240 let array = LlamaTokenDataArray::from_iter(
241 [
242 LlamaTokenData::new(LlamaToken::new(0), 0.0, 0.0),
243 LlamaTokenData::new(LlamaToken::new(1), 1.0, 0.0),
244 LlamaTokenData::new(LlamaToken::new(2), 2.0, 0.0),
245 ],
246 false,
247 );
248
249 assert_eq!(array.data.len(), 3);
250 assert!(!array.sorted);
251 assert!(array.selected.is_none());
252 }
253
254 #[test]
255 fn sample_token_with_seed_selects_a_token() {
256 let mut array = LlamaTokenDataArray::new(
257 vec![
258 LlamaTokenData::new(LlamaToken::new(10), 1.0, 0.0),
259 LlamaTokenData::new(LlamaToken::new(20), 1.0, 0.0),
260 ],
261 false,
262 );
263
264 let token = array
265 .sample_token(42)
266 .expect("test: dist sampler should select a token");
267
268 assert!(token == LlamaToken::new(10) || token == LlamaToken::new(20));
269 }
270
271 #[test]
272 fn selected_token_returns_none_when_no_selection() {
273 let array = LlamaTokenDataArray::new(
274 vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
275 false,
276 );
277
278 assert!(array.selected_token().is_none());
279 }
280
281 #[test]
282 fn selected_token_returns_none_when_index_out_of_bounds() {
283 let array = LlamaTokenDataArray {
284 data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
285 selected: Some(5),
286 sorted: false,
287 };
288
289 assert!(array.selected_token().is_none());
290 }
291
292 #[test]
293 fn modify_as_c_llama_token_data_array_copies_when_data_pointer_changes() {
294 let mut array = LlamaTokenDataArray::new(
295 vec![
296 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
297 LlamaTokenData::new(LlamaToken::new(1), 2.0, 0.0),
298 LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
299 ],
300 false,
301 );
302
303 let replacement = [
304 llama_cpp_bindings_sys::llama_token_data {
305 id: 10,
306 logit: 5.0,
307 p: 0.0,
308 },
309 llama_cpp_bindings_sys::llama_token_data {
310 id: 20,
311 logit: 6.0,
312 p: 0.0,
313 },
314 ];
315
316 unsafe {
317 array.modify_as_c_llama_token_data_array(|c_array| {
318 c_array.data = replacement.as_ptr().cast_mut();
319 c_array.size = replacement.len();
320 c_array.selected = 0;
321 });
322 }
323
324 assert_eq!(array.data.len(), 2);
325 assert_eq!(array.data[0].id(), LlamaToken::new(10));
326 assert_eq!(array.data[1].id(), LlamaToken::new(20));
327 assert_eq!(array.selected, Some(0));
328 }
329
330 #[test]
331 fn selected_overflow_uses_negative_one() {
332 let mut array = LlamaTokenDataArray {
333 data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
334 selected: Some(usize::MAX),
335 sorted: false,
336 };
337
338 unsafe {
339 array.modify_as_c_llama_token_data_array(|c_array| {
340 assert_eq!(c_array.selected, -1);
341 });
342 }
343 }
344}