llama_cpp_bindings/token/
data_array.rs1use std::ptr;
3
4use crate::error::TokenSamplingError;
5use crate::sampling::LlamaSampler;
6use crate::token::data::LlamaTokenData;
7
8use super::LlamaToken;
9
10#[derive(Debug, Clone, PartialEq)]
12pub struct LlamaTokenDataArray {
13 pub data: Vec<LlamaTokenData>,
15 pub selected: Option<usize>,
17 pub sorted: bool,
19}
20
21impl LlamaTokenDataArray {
22 #[must_use]
36 pub fn new(data: Vec<LlamaTokenData>, sorted: bool) -> Self {
37 Self {
38 data,
39 selected: None,
40 sorted,
41 }
42 }
43
44 pub fn from_iter<TIterator>(data: TIterator, sorted: bool) -> LlamaTokenDataArray
56 where
57 TIterator: IntoIterator<Item = LlamaTokenData>,
58 {
59 Self::new(data.into_iter().collect(), sorted)
60 }
61
62 #[must_use]
64 pub fn selected_token(&self) -> Option<LlamaToken> {
65 self.data.get(self.selected?).map(LlamaTokenData::id)
66 }
67}
68
69impl LlamaTokenDataArray {
70 pub unsafe fn modify_as_c_llama_token_data_array<TResult>(
84 &mut self,
85 modify: impl FnOnce(&mut llama_cpp_bindings_sys::llama_token_data_array) -> TResult,
86 ) -> TResult {
87 let size = self.data.len();
88 let data = self
89 .data
90 .as_mut_ptr()
91 .cast::<llama_cpp_bindings_sys::llama_token_data>();
92
93 let mut c_llama_token_data_array = llama_cpp_bindings_sys::llama_token_data_array {
94 data,
95 size,
96 selected: self.selected.and_then(|s| s.try_into().ok()).unwrap_or(-1),
97 sorted: self.sorted,
98 };
99
100 let result = modify(&mut c_llama_token_data_array);
101
102 assert!(
103 c_llama_token_data_array.size <= self.data.capacity(),
104 "Size of the returned array exceeds the data buffer's capacity!"
105 );
106 unsafe {
108 if !ptr::eq(c_llama_token_data_array.data, data) {
109 ptr::copy(
110 c_llama_token_data_array.data,
111 data,
112 c_llama_token_data_array.size,
113 );
114 }
115 self.data.set_len(c_llama_token_data_array.size);
116 }
117
118 self.sorted = c_llama_token_data_array.sorted;
119 self.selected = c_llama_token_data_array
120 .selected
121 .try_into()
122 .ok()
123 .filter(|&s| s < self.data.len());
124
125 result
126 }
127
128 pub fn apply_sampler(&mut self, sampler: &LlamaSampler) {
130 unsafe {
131 self.modify_as_c_llama_token_data_array(|c_llama_token_data_array| {
132 llama_cpp_bindings_sys::llama_sampler_apply(
133 sampler.sampler,
134 c_llama_token_data_array,
135 );
136 });
137 }
138 }
139
140 #[must_use]
142 pub fn with_sampler(mut self, sampler: &mut LlamaSampler) -> Self {
143 self.apply_sampler(sampler);
144 self
145 }
146
147 pub fn sample_token(&mut self, seed: u32) -> Result<LlamaToken, TokenSamplingError> {
152 self.apply_sampler(&LlamaSampler::dist(seed));
153 self.selected_token()
154 .ok_or(TokenSamplingError::NoTokenSelected)
155 }
156
157 pub fn sample_token_greedy(&mut self) -> Result<LlamaToken, TokenSamplingError> {
162 self.apply_sampler(&LlamaSampler::greedy());
163 self.selected_token()
164 .ok_or(TokenSamplingError::NoTokenSelected)
165 }
166}
167
168#[cfg(test)]
169mod tests {
170 use crate::token::LlamaToken;
171 use crate::token::data::LlamaTokenData;
172
173 use super::LlamaTokenDataArray;
174
175 #[test]
176 fn apply_greedy_sampler_selects_highest_logit() {
177 use crate::sampling::LlamaSampler;
178
179 let mut array = LlamaTokenDataArray::new(
180 vec![
181 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
182 LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
183 LlamaTokenData::new(LlamaToken::new(2), 3.0, 0.0),
184 ],
185 false,
186 );
187
188 array.apply_sampler(&LlamaSampler::greedy());
189
190 assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
191 }
192
193 #[test]
194 fn with_sampler_builder_pattern() {
195 use crate::sampling::LlamaSampler;
196
197 let array = LlamaTokenDataArray::new(
198 vec![
199 LlamaTokenData::new(LlamaToken::new(0), 1.0, 0.0),
200 LlamaTokenData::new(LlamaToken::new(1), 5.0, 0.0),
201 ],
202 false,
203 )
204 .with_sampler(&mut LlamaSampler::greedy());
205
206 assert_eq!(array.selected_token(), Some(LlamaToken::new(1)));
207 }
208
209 #[test]
210 fn sample_token_greedy_returns_highest() {
211 let mut array = LlamaTokenDataArray::new(
212 vec![
213 LlamaTokenData::new(LlamaToken::new(10), 0.1, 0.0),
214 LlamaTokenData::new(LlamaToken::new(20), 9.9, 0.0),
215 ],
216 false,
217 );
218
219 let token = array
220 .sample_token_greedy()
221 .expect("test: greedy sampler should select a token");
222
223 assert_eq!(token, LlamaToken::new(20));
224 }
225}