Skip to main content

llama_cpp_bindings/token/
data_array.rs

//! A Rust equivalent of `llama_token_data_array`.
2use std::ptr;
3
4use crate::error::TokenSamplingError;
5use crate::sampling::LlamaSampler;
6use crate::token::data::LlamaTokenData;
7
8use super::LlamaToken;
9
10/// a safe wrapper around `llama_token_data_array`.
11#[derive(Debug, Clone, PartialEq)]
12pub struct LlamaTokenDataArray {
13    /// the underlying data
14    pub data: Vec<LlamaTokenData>,
15    /// the index of the selected token in ``data``
16    pub selected: Option<usize>,
17    /// is the data sorted?
18    pub sorted: bool,
19}
20
21impl LlamaTokenDataArray {
22    /// Create a new `LlamaTokenDataArray` from a vector and whether or not the data is sorted.
23    ///
24    /// ```
25    /// # use llama_cpp_bindings::token::data::LlamaTokenData;
26    /// # use llama_cpp_bindings::token::data_array::LlamaTokenDataArray;
27    /// # use llama_cpp_bindings::token::LlamaToken;
28    /// let array = LlamaTokenDataArray::new(vec![
29    ///         LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
30    ///         LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
31    ///    ], false);
32    /// assert_eq!(array.data.len(), 2);
33    /// assert_eq!(array.sorted, false);
34    /// ```
35    #[must_use]
36    pub const fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
37        Self {
38            data,
39            selected: None,
40            sorted,
41        }
42    }
43
44    /// Create a new `LlamaTokenDataArray` from an iterator and whether or not the data is sorted.
45    /// ```
46    /// # use llama_cpp_bindings::token::data::LlamaTokenData;
47    /// # use llama_cpp_bindings::token::data_array::LlamaTokenDataArray;
48    /// # use llama_cpp_bindings::token::LlamaToken;
49    /// let array = LlamaTokenDataArray::from_iter([
50    ///     LlamaTokenData::new(LlamaToken(0), 0.0, 0.0),
51    ///     LlamaTokenData::new(LlamaToken(1), 0.1, 0.1)
52    /// ], false);
53    /// assert_eq!(array.data.len(), 2);
54    /// assert_eq!(array.sorted, false);
55    pub fn from_iter<TIterator>(data: TIterator, sorted: bool) -> Self
56    where
57        TIterator: IntoIterator<Item = LlamaTokenData>,
58    {
59        Self::new(data.into_iter().collect(), sorted)
60    }
61
62    /// Returns the current selected token, if one exists.
63    #[must_use]
64    pub fn selected_token(&self) -> Option<LlamaToken> {
65        self.data.get(self.selected?).map(LlamaTokenData::id)
66    }
67}
68
69impl LlamaTokenDataArray {
70    /// Modify the underlying data as a `llama_token_data_array`. and reconstruct the `LlamaTokenDataArray`.
71    ///
72    /// # Panics
73    ///
74    /// Panics if some of the safety conditions are not met. (we cannot check all of them at
75    /// runtime so breaking them is UB)
76    ///
77    /// # Safety
78    ///
79    /// The returned array formed by the data pointer and the length must entirely consist of
80    /// initialized token data and the length must be less than the capacity of this array's data
81    /// buffer.
82    /// If the data is not sorted, sorted must be false.
83    pub unsafe fn modify_as_c_llama_token_data_array<TResult>(
84        &mut self,
85        modify: impl FnOnce(&mut llama_cpp_bindings_sys::llama_token_data_array) -> TResult,
86    ) -> TResult {
87        let size = self.data.len();
88        let data = self
89            .data
90            .as_mut_ptr()
91            .cast::<llama_cpp_bindings_sys::llama_token_data>();
92
93        let mut c_llama_token_data_array = llama_cpp_bindings_sys::llama_token_data_array {
94            data,
95            size,
96            selected: self
97                .selected
98                .and_then(|selected_index| selected_index.try_into().ok())
99                .unwrap_or(-1),
100            sorted: self.sorted,
101        };
102
103        let result = modify(&mut c_llama_token_data_array);
104
105        assert!(c_llama_token_data_array.size <= self.data.capacity());
106        // SAFETY: caller guarantees the returned data and size are valid.
107        unsafe {
108            if !ptr::eq(c_llama_token_data_array.data, data) {
109                ptr::copy(
110                    c_llama_token_data_array.data,
111                    data,
112                    c_llama_token_data_array.size,
113                );
114            }
115            self.data.set_len(c_llama_token_data_array.size);
116        }
117
118        self.sorted = c_llama_token_data_array.sorted;
119        self.selected = c_llama_token_data_array
120            .selected
121            .try_into()
122            .ok()
123            .filter(|&s| s < self.data.len());
124
125        result
126    }
127
128    /// Modifies the data array by applying a sampler to it.
129    ///
130    /// # Panics
131    ///
132    /// Panics if the vendored sampler throws a C++ exception. `llama_sampler_apply` is
133    /// documented to be a pure logit transform and is not expected to throw; if it does
134    /// the failure is propagated as a panic per the crash-fast invariant.
135    pub fn apply_sampler(&mut self, sampler: &LlamaSampler) {
136        unsafe {
137            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
138                let mut out_error: *mut std::os::raw::c_char = ptr::null_mut();
139                let status = llama_cpp_bindings_sys::llama_rs_sampler_apply(
140                    sampler.sampler,
141                    c_llama_token_data_array,
142                    &raw mut out_error,
143                );
144                if status != llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_APPLY_OK {
145                    let message = crate::ffi_error_reader::read_and_free_cpp_error(out_error);
146                    panic!("llama_rs_sampler_apply returned status {status}: {message}");
147                }
148            });
149        }
150    }
151
152    /// Modifies the data array by applying a sampler to it
153    #[must_use]
154    pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
155        self.apply_sampler(sampler);
156        self
157    }
158
159    /// Randomly selects a token from the candidates based on their probabilities.
160    ///
161    /// # Errors
162    /// Returns [`TokenSamplingError::NoTokenSelected`] if the sampler fails to select a token.
163    pub fn sample_token(&mut self, seed: u32) -> Result<LlamaToken, TokenSamplingError> {
164        self.apply_sampler(&LlamaSampler::dist(seed));
165        self.selected_token()
166            .ok_or(TokenSamplingError::NoTokenSelected)
167    }
168
169    /// Selects the token with the highest probability.
170    ///
171    /// # Errors
172    /// Returns [`TokenSamplingError::NoTokenSelected`] if the sampler fails to select a token.
173    pub fn sample_token_greedy(&mut self) -> Result<LlamaToken, TokenSamplingError> {
174        self.apply_sampler(&LlamaSampler::greedy());
175        self.selected_token()
176            .ok_or(TokenSamplingError::NoTokenSelected)
177    }
178}
179
#[cfg(test)]
mod tests {
    use crate::token::LlamaToken;
    use crate::token::data::LlamaTokenData;

    use super::LlamaTokenDataArray;

    #[test]
    fn apply_greedy_sampler_selects_highest_logit() {
        use crate::sampling::LlamaSampler;

        let mut array = LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
            ],
            false,
        );

        array.apply_sampler(&LlamaSampler::greedy());

        assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
    }

    #[test]
    fn with_sampler_builder_pattern() {
        use crate::sampling::LlamaSampler;

        let array = LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
            ],
            false,
        )
        .with_sampler(&mut LlamaSampler::greedy());

        assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
    }

    #[test]
    fn sample_token_greedy_returns_highest() {
        let mut array = LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken::new(10), 0.1, 0.0),
                LlamaTokenData::new(LlamaToken::new(20), 9.9, 0.0),
            ],
            false,
        );

        let token = array
            .sample_token_greedy()
            .expect("test: greedy sampler should select a token");

        assert_eq!(token, LlamaToken::new(20));
    }

    #[test]
    fn from_iter_creates_array_from_iterator() {
        let array = LlamaTokenDataArray::from_iter(
            [
                LlamaTokenData::new(LlamaToken::new(0), 0.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(1), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(2), 2.0, 0.0),
            ],
            false,
        );

        assert_eq!(array.data.len(), 3);
        assert!(!array.sorted);
        assert!(array.selected.is_none());
    }

    #[test]
    fn sample_token_with_seed_selects_a_token() {
        let mut array = LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken::new(10), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(20), 1.0, 0.0),
            ],
            false,
        );

        let token = array
            .sample_token(42)
            .expect("test: dist sampler should select a token");

        assert!(token == LlamaToken::new(10) || token == LlamaToken::new(20));
    }

    #[test]
    fn selected_token_returns_none_when_no_selection() {
        let array = LlamaTokenDataArray::new(
            vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
            false,
        );

        assert!(array.selected_token().is_none());
    }

    #[test]
    fn selected_token_returns_none_when_index_out_of_bounds() {
        let array = LlamaTokenDataArray {
            data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
            selected: Some(5),
            sorted: false,
        };

        assert!(array.selected_token().is_none());
    }

    #[test]
    fn modify_as_c_llama_token_data_array_copies_when_data_pointer_changes() {
        let mut array = LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(1), 2.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
            ],
            false,
        );

        let replacement = [
            llama_cpp_bindings_sys::llama_token_data {
                id: 10,
                logit: 5.0,
                p: 0.0,
            },
            llama_cpp_bindings_sys::llama_token_data {
                id: 20,
                logit: 6.0,
                p: 0.0,
            },
        ];

        unsafe {
            array.modify_as_c_llama_token_data_array(|c_array| {
                c_array.data = replacement.as_ptr().cast_mut();
                c_array.size = replacement.len();
                c_array.selected = 0;
            });
        }

        assert_eq!(array.data.len(), 2);
        assert_eq!(array.data[0].id(), LlamaToken::new(10));
        assert_eq!(array.data[1].id(), LlamaToken::new(20));
        assert_eq!(array.selected, Some(0));
    }

    #[test]
    fn modify_as_c_llama_token_data_array_truncates_in_place() {
        let mut array = LlamaTokenDataArray::new(
            vec![
                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(1), 2.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
            ],
            false,
        );

        // SAFETY: shrinking `size` keeps it within the original, initialized buffer.
        unsafe {
            array.modify_as_c_llama_token_data_array(|c_array| {
                c_array.size = 1;
                // Index 2 is no longer valid after truncation and must be filtered out.
                c_array.selected = 2;
                c_array.sorted = true;
            });
        }

        assert_eq!(array.data.len(), 1);
        assert_eq!(array.data[0].id(), LlamaToken::new(0));
        assert_eq!(array.selected, None);
        assert!(array.sorted);
    }

    #[test]
    fn valid_selected_round_trips_through_c_array() {
        let mut array = LlamaTokenDataArray {
            data: vec![
                LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
                LlamaTokenData::new(LlamaToken::new(1), 2.0, 0.0),
            ],
            selected: Some(1),
            sorted: true,
        };

        // SAFETY: the closure leaves the C array untouched.
        unsafe {
            array.modify_as_c_llama_token_data_array(|c_array| {
                assert_eq!(c_array.selected, 1);
                assert_eq!(c_array.size, 2);
                assert!(c_array.sorted);
            });
        }

        assert_eq!(array.selected, Some(1));
        assert!(array.sorted);
        assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
    }

    #[test]
    fn negative_selected_from_c_becomes_none() {
        let mut array = LlamaTokenDataArray {
            data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
            selected: Some(0),
            sorted: false,
        };

        // SAFETY: only `selected` is changed; data and size stay valid.
        unsafe {
            array.modify_as_c_llama_token_data_array(|c_array| {
                c_array.selected = -1;
            });
        }

        assert_eq!(array.data.len(), 1);
        assert_eq!(array.selected, None);
    }

    #[test]
    #[should_panic]
    fn modify_as_c_llama_token_data_array_panics_when_size_exceeds_capacity() {
        let mut array = LlamaTokenDataArray::new(
            vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
            false,
        );
        let capacity = array.data.capacity();

        // SAFETY: the oversized `size` is rejected by the capacity assertion before any
        // memory is read or the length is changed.
        unsafe {
            array.modify_as_c_llama_token_data_array(|c_array| {
                c_array.size = capacity + 1;
            });
        }
    }

    #[test]
    fn selected_overflow_uses_negative_one() {
        let mut array = LlamaTokenDataArray {
            data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
            selected: Some(usize::MAX),
            sorted: false,
        };

        unsafe {
            array.modify_as_c_llama_token_data_array(|c_array| {
                assert_eq!(c_array.selected, -1);
            });
        }
    }
}