1use rten_simd::ops::{MaskOps, NumOps};
7use rten_simd::{Isa, Simd, SimdIterable, SimdOp};
8use rten_vecmath::Softmax;
9
10use crate::Logits;
11use crate::generator::TokenId;
12
/// A transformation applied to a set of logits.
///
/// Implementations take ownership of the incoming [`Logits`] and return the
/// transformed set, which may contain fewer entries, rescaled values or a
/// different ordering.
pub trait LogitsFilter {
    /// Transform `logits` and return the result.
    ///
    /// `prev_tokens` is passed to every filter. NOTE(review): presumably the
    /// token IDs generated so far — confirm against the caller.
    fn filter(&self, logits: Logits, prev_tokens: &[TokenId]) -> Logits;
}
26
/// Filter that keeps only the tokens whose ID is accepted by a predicate.
struct TokenIdFilter<F: Fn(TokenId) -> bool> {
    /// Returns `true` for token IDs that should be kept.
    predicate: F,
}
30
31impl<F: Fn(TokenId) -> bool> LogitsFilter for TokenIdFilter<F> {
32 fn filter(&self, logits: Logits, _prev_tokens: &[TokenId]) -> Logits {
33 let (logits, indices) = logits.into_logits_indices();
34 let (new_logits, new_indices) = logits
35 .into_iter()
36 .zip(indices)
37 .filter(|(_logit, token_id)| (self.predicate)(*token_id))
38 .unzip();
39 Logits::sparse(new_logits, new_indices)
40 }
41}
42
/// Create a filter which keeps only the tokens for which `predicate` returns
/// `true` when called with the token's ID.
pub fn token_id_filter<F: Fn(TokenId) -> bool>(predicate: F) -> impl LogitsFilter {
    TokenIdFilter { predicate }
}
48
/// Filter which scales every logit by `1 / temperature`.
pub struct Temperature {
    /// Divisor applied to the logits; validated as non-negative on construction.
    temperature: f32,
}
56
impl Temperature {
    /// Create a filter which scales logits by `1 / temperature`.
    ///
    /// # Panics
    ///
    /// Panics if `temperature` is negative or NaN. A temperature of `0.` is
    /// accepted; the resulting scale factor is then infinite.
    pub fn new(temperature: f32) -> Self {
        assert!(temperature >= 0.);
        Self { temperature }
    }
}
65
66impl LogitsFilter for Temperature {
67 fn filter(&self, logits: Logits, _prev_tokens: &[TokenId]) -> Logits {
68 if self.temperature == 1.0 {
69 return logits;
70 }
71 let (mut logits, indices) = logits.into_logits_indices();
72 let inv_temp = 1. / self.temperature;
73 for x in &mut logits {
74 *x *= inv_temp;
75 }
76 Logits::sparse(logits, indices)
77 }
78}
79
/// A sequence of filters applied one after another.
pub struct Chain {
    /// Filters in application order; each receives the previous filter's output.
    filters: Vec<Box<dyn LogitsFilter>>,
}
84
impl Default for Chain {
    /// Equivalent to [`Chain::new`]: a chain with no filters.
    fn default() -> Self {
        Self::new()
    }
}
90
91impl Chain {
92 pub fn new() -> Self {
96 Self {
97 filters: Vec::new(),
98 }
99 }
100
101 pub fn append<F: LogitsFilter + 'static>(mut self, filter: F) -> Self {
103 self.filters.push(Box::new(filter));
104 self
105 }
106
107 pub fn temperature(self, temp: f32) -> Self {
109 self.append(Temperature::new(temp))
110 }
111
112 pub fn top_p(self, p: f32) -> Self {
114 self.append(TopP::new(p))
115 }
116
117 pub fn top_k(self, k: usize) -> Self {
119 self.append(TopK::new(k))
120 }
121}
122
123impl LogitsFilter for Chain {
124 fn filter(&self, logits: Logits, prev_tokens: &[TokenId]) -> Logits {
125 self.filters
126 .iter()
127 .fold(logits, |logits, f| f.filter(logits, prev_tokens))
128 }
129}
130
/// Filter which keeps the `k` tokens with the highest logits.
pub struct TopK {
    /// Maximum number of tokens to keep.
    k: usize,
}
135
impl TopK {
    /// Create a filter which keeps the `k` highest-valued logits.
    pub fn new(k: usize) -> Self {
        Self { k }
    }
}
141
142impl LogitsFilter for TopK {
143 fn filter(&self, logits: Logits, _prev_tokens: &[TokenId]) -> Logits {
144 if logits.is_empty() {
145 return logits;
146 }
147
148 let (logits, indices) = logits.into_logits_indices();
149
150 let topk = SimdTopK {
151 k: self.k,
152 indices: &indices,
153 logits: &logits,
154 }
155 .dispatch();
156
157 let (indices, logits) = topk.into_iter().unzip();
158 Logits::sparse(logits, indices)
159 }
160}
161
/// SIMD operation which selects the `k` largest logits and their token IDs.
struct SimdTopK<'a> {
    /// Number of entries to select.
    k: usize,
    /// Logit values; entry `i` pairs with `indices[i]`.
    logits: &'a [f32],
    /// Token IDs; entry `i` pairs with `logits[i]`.
    indices: &'a [u32],
}
170
171impl<'a> SimdOp for SimdTopK<'a> {
172 type Output = Vec<(u32, f32)>;
173
174 #[inline(always)]
175 fn eval<I: Isa>(self, isa: I) -> Self::Output {
176 let SimdTopK { logits, indices, k } = self;
177
178 let ops = isa.f32();
179 let mask_ops = isa.m32();
180 let compare_gt = |a: f32, b: f32| a.total_cmp(&b).reverse();
181
182 let mut topk: Vec<(u32, f32)> = indices
184 .iter()
185 .zip(logits)
186 .take(k)
187 .map(|(i, logit)| (*i, *logit))
188 .collect();
189 topk.sort_by(|a, b| compare_gt(a.1, b.1));
190
191 if k == 0 || logits.len() == k {
192 return topk;
193 }
194
195 let mut kth_logit = topk.last().unwrap().1;
196 let mut kth_logit_vec = ops.splat(kth_logit);
197
198 let mut update_topk = |kth_logit: &mut f32, index: u32, logit: f32| {
199 if logit > *kth_logit {
200 *topk.last_mut().unwrap() = (index, logit);
201 topk.sort_by(|a, b| compare_gt(a.1, b.1));
202 *kth_logit = topk.last().unwrap().1;
203 }
204 };
205
206 let indices = &indices[k..];
207 let logits = &logits[k..];
208
209 let mut indices_iter = indices.chunks_exact(ops.len());
212 let mut logits_iter = logits.simd_iter(ops);
213 for (index_chunk, logits_vec) in indices_iter.by_ref().zip(logits_iter.by_ref()) {
214 if mask_ops.any(ops.gt(logits_vec, kth_logit_vec)) {
215 for (&index, logit) in index_chunk.iter().zip(logits_vec.to_array()) {
216 update_topk(&mut kth_logit, index, logit);
217 }
218 kth_logit_vec = ops.splat(kth_logit);
219 }
220 }
221
222 if let Some((logits_tail, _mask)) = logits_iter.tail() {
224 let indices_tail = indices_iter.remainder();
225 for (&index, logit) in indices_tail.iter().zip(logits_tail.to_array()) {
226 update_topk(&mut kth_logit, index, logit);
227 }
228 }
229
230 topk
231 }
232}
233
/// Filter which keeps the highest-valued tokens until their cumulative
/// probability reaches a threshold.
pub struct TopP {
    /// Cumulative probability threshold at which selection stops.
    cumulative_prob: f32,
    /// Whether the input is converted to probabilities with softmax before
    /// the threshold is applied.
    normalize: bool,
}

impl TopP {
    /// Create a filter with the cumulative probability threshold
    /// `cumulative_prob`.
    ///
    /// Softmax normalization is enabled by default, so the filter can be
    /// applied directly to raw logits (as [`Chain::top_p`] does). Without
    /// it, raw logit values would be summed and compared against a probability
    /// threshold, which is not meaningful. Call
    /// [`normalize(false)`](TopP::normalize) if the inputs are already
    /// probabilities.
    pub fn new(cumulative_prob: f32) -> Self {
        Self {
            cumulative_prob,
            normalize: true,
        }
    }

    /// Set whether the input values are normalized with softmax before the
    /// threshold is applied.
    ///
    /// Disable this only when the inputs are already probabilities.
    pub fn normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }
}
260
261impl LogitsFilter for TopP {
262 fn filter(&self, logits: Logits, _prev_tokens: &[TokenId]) -> Logits {
263 if self.cumulative_prob == 1.0 {
264 return logits;
265 }
266
267 let (mut logits, indices) = logits.into_logits_indices();
268
269 if self.normalize {
271 Softmax::new_mut(&mut logits).dispatch();
272 }
273
274 let mut pairs: Vec<(f32, TokenId)> = logits.into_iter().zip(indices).collect();
277 pairs.sort_by(|a, b| {
278 let (a_prob, _a_id) = a;
279 let (b_prob, _b_id) = b;
280 a_prob.total_cmp(b_prob).reverse()
281 });
282
283 let mut cum_prob = 0.;
287 let mut k = 0;
288 let threshold = self.cumulative_prob.max(f32::MIN_POSITIVE);
289 while cum_prob < threshold && k < pairs.len() {
290 cum_prob += pairs[k].0;
291 k += 1;
292 }
293 pairs.truncate(k);
294
295 let (logits, indices) = pairs.into_iter().unzip();
297 Logits::sparse(logits, indices)
298 }
299}
300
/// Filter which sorts tokens in descending order of logit value.
#[derive(Default)]
pub struct Sort {
    // Private unit field: code outside this module cannot build a `Sort` with
    // struct-literal syntax and must use `Sort::new` or `Default`.
    _private: (),
}
306
impl Sort {
    /// Create a sorting filter. The filter has no configuration.
    pub fn new() -> Self {
        Sort { _private: () }
    }
}
312
313impl LogitsFilter for Sort {
314 fn filter(&self, logits: Logits, _prev_tokens: &[TokenId]) -> Logits {
315 let (logits, indices) = logits.into_logits_indices();
316
317 let mut pairs: Vec<(f32, TokenId)> = logits.into_iter().zip(indices).collect();
318 pairs.sort_by(|(a_val, _), (b_val, _)| a_val.total_cmp(b_val).reverse());
319
320 let (logits, indices) = pairs.into_iter().unzip();
321 Logits::sparse(logits, indices)
322 }
323}
324
// Unit tests for the logits filters defined in this module.
#[cfg(test)]
mod tests {
    use super::{Chain, Logits, LogitsFilter, Sort, Temperature, TopK, TopP, token_id_filter};

    // Only tokens with an even ID should survive, in their original order.
    #[test]
    fn test_token_id_filter() {
        let logits = Logits::dense(vec![0., 1., 2., 3., 4.]);
        let filter = token_id_filter(|id| id % 2 == 0);
        let output = filter.filter(logits, &[]);
        assert_eq!(output.logits(), &[0., 2., 4.]);
        assert_eq!(output.indices(), &[0, 2, 4]);
    }

    // A temperature of 2 halves every logit and leaves the token IDs alone.
    #[test]
    fn test_temperature() {
        let logits = Logits::dense(vec![0., 1., 2., 3., 4.]);
        let filter = Temperature::new(2.0);
        let output = filter.filter(logits, &[]);
        assert_eq!(output.logits(), &[0., 0.5, 1., 1.5, 2.0]);
        assert_eq!(output.indices(), &[0, 1, 2, 3, 4]);
    }

    // Chained filters are applied in sequence: keep even IDs, then drop ID 0.
    #[test]
    fn test_chain() {
        let logits = Logits::dense(vec![0., 1., 2., 3., 4.]);
        let chain = Chain::new()
            .append(token_id_filter(|id| id % 2 == 0))
            .append(token_id_filter(|id| id > 0));
        let output = chain.filter(logits, &[]);
        assert_eq!(output.logits(), &[2., 4.]);
        assert_eq!(output.indices(), &[2, 4]);
    }

    // Straightforward top-K used as the expected result: sort every
    // (token ID, logit) pair in descending order of logit and keep the first `k`.
    fn reference_topk(logits: &Logits, k: usize) -> Logits {
        let mut pairs: Vec<(u32, f32)> = logits
            .indices()
            .iter()
            .zip(logits.logits())
            .map(|(idx, val)| (*idx, *val))
            .collect();
        pairs.sort_by(|a, b| a.1.total_cmp(&b.1).reverse());
        pairs.truncate(k);
        let (indices, logits) = pairs.into_iter().unzip();
        Logits::sparse(logits, indices)
    }

    // `TopK` must agree with the reference implementation for every `k` from
    // zero up to the input length, and must pass empty input through.
    #[test]
    fn test_top_k() {
        // The filter's output is sorted before it is compared with the reference.
        let sort = |logits| Sort::new().filter(logits, &[]);

        let logits = Logits::dense(vec![
            -1., 1., 0., 2., -2., 10., -3., 2., 1., 0., 20., -5., 5., 0.1, -0.2, 0.2, 0.1,
        ]);
        // Guard the fixture size so accidental edits to the data are noticed.
        assert_eq!(logits.len(), 17);

        for k in 0..=logits.len() {
            let topk = TopK::new(k).filter(logits.clone(), &[]);
            let sorted_topk = sort(topk);
            let expected_topk = reference_topk(&logits, k);
            assert_eq!(sorted_topk.logits(), expected_topk.logits());
            assert_eq!(sorted_topk.indices(), expected_topk.indices());
        }

        // Empty input yields empty output.
        let logits = Logits::dense(vec![]);
        let topk = TopK::new(1).filter(logits, &[]);
        assert!(topk.is_empty());
    }

    // The inputs already sum to 1, so normalization is switched off and the
    // values are treated directly as probabilities.
    #[test]
    fn test_top_p() {
        let logits = Logits::dense(vec![0.1, 0.25, 0.15, 0.5]);

        // A threshold of 1.0 returns the input unchanged.
        let all_logits = TopP::new(1.0).normalize(false).filter(logits.clone(), &[]);
        assert_eq!(logits, all_logits);

        // The single largest entry already reaches 0.5.
        let top_p_logits = TopP::new(0.5).normalize(false).filter(logits.clone(), &[]);
        assert_eq!(top_p_logits.logits(), &[0.5]);
        assert_eq!(top_p_logits.indices(), &[3]);

        // The two largest entries are needed to reach 0.75.
        let top_p_logits = TopP::new(0.75).normalize(false).filter(logits.clone(), &[]);
        assert_eq!(top_p_logits.logits(), &[0.5, 0.25]);
        assert_eq!(top_p_logits.indices(), &[3, 1]);

        // A threshold of zero still keeps the single most likely token.
        let top_p_logits = TopP::new(0.).normalize(false).filter(logits.clone(), &[]);
        assert_eq!(top_p_logits.logits(), &[0.5]);
        assert_eq!(top_p_logits.indices(), &[3]);
    }
}