llama_cpp_bindings/token/
data_array.rs1use std::ptr;
2
3use crate::error::TokenSamplingError;
4use crate::sampling::LlamaSampler;
5use crate::token::data::LlamaTokenData;
6
7use super::LlamaToken;
8
9#[derive(Debug, Clone, PartialEq)]
10pub struct LlamaTokenDataArray {
11 pub data: Vec<LlamaTokenData>,
12 pub selected: Option<usize>,
13 pub sorted: bool,
14}
15
16impl LlamaTokenDataArray {
17 #[must_use]
18 pub const fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
19 Self {
20 data,
21 selected: None,
22 sorted,
23 }
24 }
25
26 pub fn from_iter<TIterator>(data: TIterator, sorted: bool) -> Self
27 where
28 TIterator: IntoIterator<Item = LlamaTokenData>,
29 {
30 Self::new(data.into_iter().collect(), sorted)
31 }
32
33 #[must_use]
34 pub fn selected_token(&self) -> Option<LlamaToken> {
35 self.data.get(self.selected?).map(LlamaTokenData::id)
36 }
37}
38
39impl LlamaTokenDataArray {
40 pub unsafe fn modify_as_c_llama_token_data_array<TResult>(
52 &mut self,
53 modify: impl FnOnce(&mut llama_cpp_bindings_sys::llama_token_data_array) -> TResult,
54 ) -> TResult {
55 let size = self.data.len();
56 let data = self
57 .data
58 .as_mut_ptr()
59 .cast::<llama_cpp_bindings_sys::llama_token_data>();
60
61 let mut c_llama_token_data_array = llama_cpp_bindings_sys::llama_token_data_array {
62 data,
63 size,
64 selected: self
65 .selected
66 .and_then(|selected_index| selected_index.try_into().ok())
67 .unwrap_or(-1),
68 sorted: self.sorted,
69 };
70
71 let result = modify(&mut c_llama_token_data_array);
72
73 assert!(c_llama_token_data_array.size <= self.data.capacity());
74 unsafe {
76 if !ptr::eq(c_llama_token_data_array.data, data) {
77 ptr::copy(
78 c_llama_token_data_array.data,
79 data,
80 c_llama_token_data_array.size,
81 );
82 }
83 self.data.set_len(c_llama_token_data_array.size);
84 }
85
86 self.sorted = c_llama_token_data_array.sorted;
87 self.selected = c_llama_token_data_array
88 .selected
89 .try_into()
90 .ok()
91 .filter(|&s| s < self.data.len());
92
93 result
94 }
95
96 pub fn apply_sampler(&mut self, sampler: &LlamaSampler) {
102 unsafe {
103 self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
104 let mut out_error: *mut std::os::raw::c_char = ptr::null_mut();
105 let status = llama_cpp_bindings_sys::llama_rs_sampler_apply(
106 sampler.sampler,
107 c_llama_token_data_array,
108 &raw mut out_error,
109 );
110 if status != llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_APPLY_OK {
111 let message = crate::ffi_error_reader::read_and_free_cpp_error(out_error);
112 panic!("llama_rs_sampler_apply returned status {status}: {message}");
113 }
114 });
115 }
116 }
117
118 #[must_use]
119 pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
120 self.apply_sampler(sampler);
121 self
122 }
123
124 pub fn sample_token(&mut self, seed: u32) -> Result<LlamaToken, TokenSamplingError> {
127 self.apply_sampler(&LlamaSampler::dist(seed));
128 self.selected_token()
129 .ok_or(TokenSamplingError::NoTokenSelected)
130 }
131
132 pub fn sample_token_greedy(&mut self) -> Result<LlamaToken, TokenSamplingError> {
135 self.apply_sampler(&LlamaSampler::greedy());
136 self.selected_token()
137 .ok_or(TokenSamplingError::NoTokenSelected)
138 }
139}
140
141#[cfg(test)]
142mod tests {
143 use crate::token::LlamaToken;
144 use crate::token::data::LlamaTokenData;
145
146 use super::LlamaTokenDataArray;
147
148 #[test]
149 fn apply_greedy_sampler_selects_highest_logit() {
150 use crate::sampling::LlamaSampler;
151
152 let mut array = LlamaTokenDataArray::new(
153 vec![
154 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
155 LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
156 LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
157 ],
158 false,
159 );
160
161 array.apply_sampler(&LlamaSampler::greedy());
162
163 assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
164 }
165
166 #[test]
167 fn with_sampler_builder_pattern() {
168 use crate::sampling::LlamaSampler;
169
170 let array = LlamaTokenDataArray::new(
171 vec![
172 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
173 LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
174 ],
175 false,
176 )
177 .with_sampler(&mut LlamaSampler::greedy());
178
179 assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
180 }
181
182 #[test]
183 fn sample_token_greedy_returns_highest() {
184 let mut array = LlamaTokenDataArray::new(
185 vec![
186 LlamaTokenData::new(LlamaToken::new(10), 0.1, 0.0),
187 LlamaTokenData::new(LlamaToken::new(20), 9.9, 0.0),
188 ],
189 false,
190 );
191
192 let token = array
193 .sample_token_greedy()
194 .expect("test: greedy sampler should select a token");
195
196 assert_eq!(token, LlamaToken::new(20));
197 }
198
199 #[test]
200 fn from_iter_creates_array_from_iterator() {
201 let array = LlamaTokenDataArray::from_iter(
202 [
203 LlamaTokenData::new(LlamaToken::new(0), 0.0, 0.0),
204 LlamaTokenData::new(LlamaToken::new(1), 1.0, 0.0),
205 LlamaTokenData::new(LlamaToken::new(2), 2.0, 0.0),
206 ],
207 false,
208 );
209
210 assert_eq!(array.data.len(), 3);
211 assert!(!array.sorted);
212 assert!(array.selected.is_none());
213 }
214
215 #[test]
216 fn sample_token_with_seed_selects_a_token() {
217 let mut array = LlamaTokenDataArray::new(
218 vec![
219 LlamaTokenData::new(LlamaToken::new(10), 1.0, 0.0),
220 LlamaTokenData::new(LlamaToken::new(20), 1.0, 0.0),
221 ],
222 false,
223 );
224
225 let token = array
226 .sample_token(42)
227 .expect("test: dist sampler should select a token");
228
229 assert!(token == LlamaToken::new(10) || token == LlamaToken::new(20));
230 }
231
232 #[test]
233 fn selected_token_returns_none_when_no_selection() {
234 let array = LlamaTokenDataArray::new(
235 vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
236 false,
237 );
238
239 assert!(array.selected_token().is_none());
240 }
241
242 #[test]
243 fn selected_token_returns_none_when_index_out_of_bounds() {
244 let array = LlamaTokenDataArray {
245 data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
246 selected: Some(5),
247 sorted: false,
248 };
249
250 assert!(array.selected_token().is_none());
251 }
252
253 #[test]
254 fn modify_as_c_llama_token_data_array_copies_when_data_pointer_changes() {
255 let mut array = LlamaTokenDataArray::new(
256 vec![
257 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
258 LlamaTokenData::new(LlamaToken::new(1), 2.0, 0.0),
259 LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
260 ],
261 false,
262 );
263
264 let replacement = [
265 llama_cpp_bindings_sys::llama_token_data {
266 id: 10,
267 logit: 5.0,
268 p: 0.0,
269 },
270 llama_cpp_bindings_sys::llama_token_data {
271 id: 20,
272 logit: 6.0,
273 p: 0.0,
274 },
275 ];
276
277 unsafe {
278 array.modify_as_c_llama_token_data_array(|c_array| {
279 c_array.data = replacement.as_ptr().cast_mut();
280 c_array.size = replacement.len();
281 c_array.selected = 0;
282 });
283 }
284
285 assert_eq!(array.data.len(), 2);
286 assert_eq!(array.data[0].id(), LlamaToken::new(10));
287 assert_eq!(array.data[1].id(), LlamaToken::new(20));
288 assert_eq!(array.selected, Some(0));
289 }
290
291 #[test]
292 fn selected_overflow_uses_negative_one() {
293 let mut array = LlamaTokenDataArray {
294 data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
295 selected: Some(usize::MAX),
296 sorted: false,
297 };
298
299 unsafe {
300 array.modify_as_c_llama_token_data_array(|c_array| {
301 assert_eq!(c_array.selected, -1);
302 });
303 }
304 }
305}