Skip to main content

llama_cpp_bindings/token/
data_array.rs

1use std::ptr;
2
3use crate::error::TokenSamplingError;
4use crate::sampling::LlamaSampler;
5use crate::token::data::LlamaTokenData;
6
7use super::LlamaToken;
8
9#[derive(Debug, Clone, PartialEq)]
10pub struct LlamaTokenDataArray {
11    pub data: Vec<LlamaTokenData>,
12    pub selected: Option<usize>,
13    pub sorted: bool,
14}
15
16impl LlamaTokenDataArray {
17    #[must_use]
18    pub const fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
19        Self {
20            data,
21            selected: None,
22            sorted,
23        }
24    }
25
26    pub fn from_iter<TIterator>(data: TIterator, sorted: bool) -> Self
27    where
28        TIterator: IntoIterator<Item = LlamaTokenData>,
29    {
30        Self::new(data.into_iter().collect(), sorted)
31    }
32
33    #[must_use]
34    pub fn selected_token(&self) -> Option<LlamaToken> {
35        self.data.get(self.selected?).map(LlamaTokenData::id)
36    }
37}
38
39impl LlamaTokenDataArray {
40    /// # Panics
41    ///
42    /// Panics if some of the safety conditions are not met. (we cannot check all of them at
43    /// runtime so breaking them is UB)
44    ///
45    /// # Safety
46    ///
47    /// The returned array formed by the data pointer and the length must entirely consist of
48    /// initialized token data and the length must be less than the capacity of this array's data
49    /// buffer.
50    /// If the data is not sorted, sorted must be false.
51    pub unsafe fn modify_as_c_llama_token_data_array<TResult>(
52        &mut self,
53        modify: impl FnOnce(&mut llama_cpp_bindings_sys::llama_token_data_array) -> TResult,
54    ) -> TResult {
55        let size = self.data.len();
56        let data = self
57            .data
58            .as_mut_ptr()
59            .cast::<llama_cpp_bindings_sys::llama_token_data>();
60
61        let mut c_llama_token_data_array = llama_cpp_bindings_sys::llama_token_data_array {
62            data,
63            size,
64            selected: self
65                .selected
66                .and_then(|selected_index| selected_index.try_into().ok())
67                .unwrap_or(-1),
68            sorted: self.sorted,
69        };
70
71        let result = modify(&mut c_llama_token_data_array);
72
73        assert!(c_llama_token_data_array.size <= self.data.capacity());
74        // SAFETY: caller guarantees the returned data and size are valid.
75        unsafe {
76            if !ptr::eq(c_llama_token_data_array.data, data) {
77                ptr::copy(
78                    c_llama_token_data_array.data,
79                    data,
80                    c_llama_token_data_array.size,
81                );
82            }
83            self.data.set_len(c_llama_token_data_array.size);
84        }
85
86        self.sorted = c_llama_token_data_array.sorted;
87        self.selected = c_llama_token_data_array
88            .selected
89            .try_into()
90            .ok()
91            .filter(|&s| s < self.data.len());
92
93        result
94    }
95
96    /// # Panics
97    ///
98    /// Panics if the vendored sampler throws a C++ exception. `llama_sampler_apply` is
99    /// documented to be a pure logit transform and is not expected to throw; if it does
100    /// the failure is propagated as a panic per the crash-fast invariant.
101    pub fn apply_sampler(&mut self, sampler: &LlamaSampler) {
102        unsafe {
103            self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
104                let mut out_error: *mut std::os::raw::c_char = ptr::null_mut();
105                let status = llama_cpp_bindings_sys::llama_rs_sampler_apply(
106                    sampler.sampler,
107                    c_llama_token_data_array,
108                    &raw mut out_error,
109                );
110                if status != llama_cpp_bindings_sys::LLAMA_RS_SAMPLER_APPLY_OK {
111                    let message = crate::ffi_error_reader::read_and_free_cpp_error(out_error);
112                    panic!("llama_rs_sampler_apply returned status {status}: {message}");
113                }
114            });
115        }
116    }
117
118    #[must_use]
119    pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
120        self.apply_sampler(sampler);
121        self
122    }
123
124    /// # Errors
125    /// Returns [`TokenSamplingError::NoTokenSelected`] if the sampler fails to select a token.
126    pub fn sample_token(&mut self, seed: u32) -> Result<LlamaToken, TokenSamplingError> {
127        self.apply_sampler(&LlamaSampler::dist(seed));
128        self.selected_token()
129            .ok_or(TokenSamplingError::NoTokenSelected)
130    }
131
132    /// # Errors
133    /// Returns [`TokenSamplingError::NoTokenSelected`] if the sampler fails to select a token.
134    pub fn sample_token_greedy(&mut self) -> Result<LlamaToken, TokenSamplingError> {
135        self.apply_sampler(&LlamaSampler::greedy());
136        self.selected_token()
137            .ok_or(TokenSamplingError::NoTokenSelected)
138    }
139}
140
#[cfg(test)]
mod tests {
    use crate::sampling::LlamaSampler;
    use crate::token::LlamaToken;
    use crate::token::data::LlamaTokenData;

    use super::LlamaTokenDataArray;

    /// Applying the greedy sampler selects the entry with the largest logit.
    #[test]
    fn apply_greedy_sampler_selects_highest_logit() {
        let candidates = vec![
            LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
            LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
            LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
        ];
        let mut array = LlamaTokenDataArray::new(candidates, false);

        let greedy = LlamaSampler::greedy();
        array.apply_sampler(&greedy);

        assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
    }

    /// `with_sampler` can be chained directly onto the constructor.
    #[test]
    fn with_sampler_builder_pattern() {
        let candidates = vec![
            LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
            LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
        ];
        let mut greedy = LlamaSampler::greedy();

        let array = LlamaTokenDataArray::new(candidates, false).with_sampler(&mut greedy);

        assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
    }

    /// `sample_token_greedy` returns the token carrying the largest logit.
    #[test]
    fn sample_token_greedy_returns_highest() {
        let candidates = vec![
            LlamaTokenData::new(LlamaToken::new(10), 0.1, 0.0),
            LlamaTokenData::new(LlamaToken::new(20), 9.9, 0.0),
        ];
        let mut array = LlamaTokenDataArray::new(candidates, false);

        let picked = array
            .sample_token_greedy()
            .expect("test: greedy sampler should select a token");

        assert_eq!(picked, LlamaToken::new(20));
    }

    /// `from_iter` keeps every item, stores the `sorted` flag and starts unselected.
    #[test]
    fn from_iter_creates_array_from_iterator() {
        let items = [
            LlamaTokenData::new(LlamaToken::new(0), 0.0, 0.0),
            LlamaTokenData::new(LlamaToken::new(1), 1.0, 0.0),
            LlamaTokenData::new(LlamaToken::new(2), 2.0, 0.0),
        ];

        let array = LlamaTokenDataArray::from_iter(items, false);

        assert_eq!(array.data.len(), 3);
        assert!(!array.sorted);
        assert_eq!(array.selected, None);
    }

    /// Seeded sampling over two equally weighted candidates returns one of them.
    #[test]
    fn sample_token_with_seed_selects_a_token() {
        let candidates = vec![
            LlamaTokenData::new(LlamaToken::new(10), 1.0, 0.0),
            LlamaTokenData::new(LlamaToken::new(20), 1.0, 0.0),
        ];
        let mut array = LlamaTokenDataArray::new(candidates, false);

        let picked = array
            .sample_token(42)
            .expect("test: dist sampler should select a token");

        let allowed = [LlamaToken::new(10), LlamaToken::new(20)];
        assert!(allowed.contains(&picked));
    }

    /// A freshly constructed array has no selected token.
    #[test]
    fn selected_token_returns_none_when_no_selection() {
        let candidates = vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)];

        let array = LlamaTokenDataArray::new(candidates, false);

        assert_eq!(array.selected_token(), None);
    }

    /// A selection index past the end of `data` is reported as no selection.
    #[test]
    fn selected_token_returns_none_when_index_out_of_bounds() {
        let array = LlamaTokenDataArray {
            data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
            selected: Some(5),
            sorted: false,
        };

        assert_eq!(array.selected_token(), None);
    }

    /// When the closure redirects the view to other storage, the entries found there are
    /// copied back into the array.
    #[test]
    fn modify_as_c_llama_token_data_array_copies_when_data_pointer_changes() {
        let candidates = vec![
            LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
            LlamaTokenData::new(LlamaToken::new(1), 2.0, 0.0),
            LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
        ];
        let mut array = LlamaTokenDataArray::new(candidates, false);

        let substitute = [
            llama_cpp_bindings_sys::llama_token_data {
                id: 10,
                logit: 5.0,
                p: 0.0,
            },
            llama_cpp_bindings_sys::llama_token_data {
                id: 20,
                logit: 6.0,
                p: 0.0,
            },
        ];

        // SAFETY: the view ends up describing two initialized entries, which fits the
        // three-entry buffer of `array`.
        unsafe {
            array.modify_as_c_llama_token_data_array(|view| {
                view.data = substitute.as_ptr().cast_mut();
                view.size = substitute.len();
                view.selected = 0;
            });
        }

        assert_eq!(array.data.len(), 2);
        assert_eq!(array.data[0].id(), LlamaToken::new(10));
        assert_eq!(array.data[1].id(), LlamaToken::new(20));
        assert_eq!(array.selected, Some(0));
    }

    /// A selection index too large for the C field is handed to the closure as `-1`.
    #[test]
    fn selected_overflow_uses_negative_one() {
        let mut array = LlamaTokenDataArray {
            data: vec![LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0)],
            selected: Some(usize::MAX),
            sorted: false,
        };

        // SAFETY: the closure only inspects the view and leaves data and size untouched.
        unsafe {
            array.modify_as_c_llama_token_data_array(|view| {
                assert_eq!(view.selected, -1);
            });
        }
    }
}