rten_generate/
filter.rs

1//! Filters for processing model outputs ("logits") prior to sampling.
2//!
3//! This module defines the [`LogitsFilter`] trait implemented by all filters,
4//! plus convenience functions to simplify implementing filters.
5
6use rten_simd::ops::{MaskOps, NumOps};
7use rten_simd::{Isa, Simd, SimdIterable, SimdOp};
8use rten_vecmath::Softmax;
9
10use crate::Logits;
11use crate::generator::TokenId;
12
13/// Filter which modifies the output logits from a model.
14///
15/// Filters can remove tokens or alter their scores. Filters are stateless and
16/// at each step they receive logits from the model or a previous filter, plus
17/// the previously generated token IDs.
18///
19/// Filters can be chained together using [`Chain`].
20pub trait LogitsFilter {
21    /// Filter the model's output and return the modified logits.
22    ///
23    /// `prev_tokens` contains the previously sampled tokens, including the prompt.
24    fn filter(&self, logits: Logits, prev_tokens: &[TokenId]) -> Logits;
25}
26
/// Filter which keeps only the tokens whose ID satisfies `predicate`.
///
/// The `Fn(TokenId) -> bool` bound lives on the `LogitsFilter` impl rather
/// than on the struct definition, following the Rust API guideline of not
/// duplicating trait bounds on data structures.
struct TokenIdFilter<F> {
    /// Returns `true` for token IDs that should be retained.
    predicate: F,
}
30
31impl<F: Fn(TokenId) -> bool> LogitsFilter for TokenIdFilter<F> {
32    fn filter(&self, logits: Logits, _prev_tokens: &[TokenId]) -> Logits {
33        let (logits, indices) = logits.into_logits_indices();
34        let (new_logits, new_indices) = logits
35            .into_iter()
36            .zip(indices)
37            .filter(|(_logit, token_id)| (self.predicate)(*token_id))
38            .unzip();
39        Logits::sparse(new_logits, new_indices)
40    }
41}
42
43/// Create a filter which suppresses all tokens that do not match a predicate by
44/// setting the value to `f32::NEG_INFINITY`.
45pub fn token_id_filter<F: Fn(TokenId) -> bool>(predicate: F) -> impl LogitsFilter {
46    TokenIdFilter { predicate }
47}
48
/// Filter which scales logits uniformly.
///
/// Each input logit is replaced by `logit / temperature`. Temperatures above
/// 1.0 flatten the distribution; temperatures below 1.0 sharpen it.
#[derive(Clone, Debug)]
pub struct Temperature {
    temperature: f32,
}
56
57impl Temperature {
58    /// Create a temperature filter which updates each logit by dividing by
59    /// `temperature`.
60    pub fn new(temperature: f32) -> Self {
61        assert!(temperature >= 0.);
62        Self { temperature }
63    }
64}
65
66impl LogitsFilter for Temperature {
67    fn filter(&self, logits: Logits, _prev_tokens: &[TokenId]) -> Logits {
68        if self.temperature == 1.0 {
69            return logits;
70        }
71        let (mut logits, indices) = logits.into_logits_indices();
72        let inv_temp = 1. / self.temperature;
73        for x in &mut logits {
74            *x *= inv_temp;
75        }
76        Logits::sparse(logits, indices)
77    }
78}
79
80/// Applies a sequence of logit filters in series.
81pub struct Chain {
82    filters: Vec<Box<dyn LogitsFilter>>,
83}
84
85impl Default for Chain {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91impl Chain {
92    /// Create an empty logits filter chain.
93    ///
94    /// An empty chain returns input logits unmodified.
95    pub fn new() -> Self {
96        Self {
97            filters: Vec::new(),
98        }
99    }
100
101    /// Add a new filter to the chain.
102    pub fn append<F: LogitsFilter + 'static>(mut self, filter: F) -> Self {
103        self.filters.push(Box::new(filter));
104        self
105    }
106
107    /// Add a temperature filter to the chain. See [`Temperature`].
108    pub fn temperature(self, temp: f32) -> Self {
109        self.append(Temperature::new(temp))
110    }
111
112    /// Add a top-P (nucleus sampling) filter to the chain. See [`TopP`].
113    pub fn top_p(self, p: f32) -> Self {
114        self.append(TopP::new(p))
115    }
116
117    /// Add a top-K filter to the chain. See [`TopK`].
118    pub fn top_k(self, k: usize) -> Self {
119        self.append(TopK::new(k))
120    }
121}
122
123impl LogitsFilter for Chain {
124    fn filter(&self, logits: Logits, prev_tokens: &[TokenId]) -> Logits {
125        self.filters
126            .iter()
127            .fold(logits, |logits, f| f.filter(logits, prev_tokens))
128    }
129}
130
/// Filter which retains the K logits with the highest values.
#[derive(Clone, Debug)]
pub struct TopK {
    /// Number of logits to keep.
    k: usize,
}
135
136impl TopK {
137    pub fn new(k: usize) -> Self {
138        Self { k }
139    }
140}
141
142impl LogitsFilter for TopK {
143    fn filter(&self, logits: Logits, _prev_tokens: &[TokenId]) -> Logits {
144        if logits.is_empty() {
145            return logits;
146        }
147
148        let (logits, indices) = logits.into_logits_indices();
149
150        let topk = SimdTopK {
151            k: self.k,
152            indices: &indices,
153            logits: &logits,
154        }
155        .dispatch();
156
157        let (indices, logits) = topk.into_iter().unzip();
158        Logits::sparse(logits, indices)
159    }
160}
161
/// SIMD top-K selection over parallel `logits` / `indices` slices.
///
/// Tuned for the case where K is small compared to the input length and the
/// running top-K changes rarely, so most SIMD chunks are skipped after a
/// single vector comparison against the current K-th largest value.
struct SimdTopK<'a> {
    /// Number of highest-scoring entries to keep.
    k: usize,
    /// Scores to select from.
    logits: &'a [f32],
    /// Token ID for each entry of `logits`, at the same position.
    indices: &'a [u32],
}
170
171impl<'a> SimdOp for SimdTopK<'a> {
172    type Output = Vec<(u32, f32)>;
173
174    #[inline(always)]
175    fn eval<I: Isa>(self, isa: I) -> Self::Output {
176        let SimdTopK { logits, indices, k } = self;
177
178        let ops = isa.f32();
179        let mask_ops = isa.m32();
180        let compare_gt = |a: f32, b: f32| a.total_cmp(&b).reverse();
181
182        // Create an initial sorted top-K list from the first K entries.
183        let mut topk: Vec<(u32, f32)> = indices
184            .iter()
185            .zip(logits)
186            .take(k)
187            .map(|(i, logit)| (*i, *logit))
188            .collect();
189        topk.sort_by(|a, b| compare_gt(a.1, b.1));
190
191        if k == 0 || logits.len() == k {
192            return topk;
193        }
194
195        let mut kth_logit = topk.last().unwrap().1;
196        let mut kth_logit_vec = ops.splat(kth_logit);
197
198        let mut update_topk = |kth_logit: &mut f32, index: u32, logit: f32| {
199            if logit > *kth_logit {
200                *topk.last_mut().unwrap() = (index, logit);
201                topk.sort_by(|a, b| compare_gt(a.1, b.1));
202                *kth_logit = topk.last().unwrap().1;
203            }
204        };
205
206        let indices = &indices[k..];
207        let logits = &logits[k..];
208
209        // Iterate over SIMD-sized chunks of remaining logits and update running
210        // top-K.
211        let mut indices_iter = indices.chunks_exact(ops.len());
212        let mut logits_iter = logits.simd_iter(ops);
213        for (index_chunk, logits_vec) in indices_iter.by_ref().zip(logits_iter.by_ref()) {
214            if mask_ops.any(ops.gt(logits_vec, kth_logit_vec)) {
215                for (&index, logit) in index_chunk.iter().zip(logits_vec.to_array()) {
216                    update_topk(&mut kth_logit, index, logit);
217                }
218                kth_logit_vec = ops.splat(kth_logit);
219            }
220        }
221
222        // Handle tail.
223        if let Some((logits_tail, _mask)) = logits_iter.tail() {
224            let indices_tail = indices_iter.remainder();
225            for (&index, logit) in indices_tail.iter().zip(logits_tail.to_array()) {
226                update_topk(&mut kth_logit, index, logit);
227            }
228        }
229
230        topk
231    }
232}
233
/// Filter which retains the smallest set of highest-probability tokens whose
/// cumulative probability reaches a threshold _p_ (nucleus sampling).
///
/// See <https://en.wikipedia.org/wiki/Top-p_sampling>.
#[derive(Clone, Debug)]
pub struct TopP {
    /// Threshold _p_ on the cumulative probability of retained tokens.
    cumulative_prob: f32,
    /// Whether to apply softmax to the inputs before selecting tokens.
    normalize: bool,
}
242
243impl TopP {
244    pub fn new(cumulative_prob: f32) -> Self {
245        Self {
246            cumulative_prob,
247            normalize: false,
248        }
249    }
250
251    /// Set whether input logits are normalized to probabilities using softmax
252    /// before the top-P subset is computed.
253    ///
254    /// This is true by default.
255    pub fn normalize(mut self, normalize: bool) -> Self {
256        self.normalize = normalize;
257        self
258    }
259}
260
261impl LogitsFilter for TopP {
262    fn filter(&self, logits: Logits, _prev_tokens: &[TokenId]) -> Logits {
263        if self.cumulative_prob == 1.0 {
264            return logits;
265        }
266
267        let (mut logits, indices) = logits.into_logits_indices();
268
269        // Normalize logits to probabilities.
270        if self.normalize {
271            Softmax::new_mut(&mut logits).dispatch();
272        }
273
274        // Combine into (logit, token_id) tuples and sort by probability
275        // descending.
276        let mut pairs: Vec<(f32, TokenId)> = logits.into_iter().zip(indices).collect();
277        pairs.sort_by(|a, b| {
278            let (a_prob, _a_id) = a;
279            let (b_prob, _b_id) = b;
280            a_prob.total_cmp(b_prob).reverse()
281        });
282
283        // Find k such that the top-K logits have a cumulative probability >= self.p.
284        //
285        // The threshold is set to be > 0 so the sampled set is non-empty.
286        let mut cum_prob = 0.;
287        let mut k = 0;
288        let threshold = self.cumulative_prob.max(f32::MIN_POSITIVE);
289        while cum_prob < threshold && k < pairs.len() {
290            cum_prob += pairs[k].0;
291            k += 1;
292        }
293        pairs.truncate(k);
294
295        // Return the top-K logits.
296        let (logits, indices) = pairs.into_iter().unzip();
297        Logits::sparse(logits, indices)
298    }
299}
300
/// Filter which sorts logits in descending order of their scores.
#[derive(Clone, Debug, Default)]
pub struct Sort {
    // Private field so the type can only be built via `new` / `default`.
    _private: (),
}
306
307impl Sort {
308    pub fn new() -> Self {
309        Sort { _private: () }
310    }
311}
312
313impl LogitsFilter for Sort {
314    fn filter(&self, logits: Logits, _prev_tokens: &[TokenId]) -> Logits {
315        let (logits, indices) = logits.into_logits_indices();
316
317        let mut pairs: Vec<(f32, TokenId)> = logits.into_iter().zip(indices).collect();
318        pairs.sort_by(|(a_val, _), (b_val, _)| a_val.total_cmp(b_val).reverse());
319
320        let (logits, indices) = pairs.into_iter().unzip();
321        Logits::sparse(logits, indices)
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::{Chain, Logits, LogitsFilter, Sort, Temperature, TopK, TopP, token_id_filter};

    /// Dense logits `[0, 1, 2, 3, 4]`, shared by several tests.
    fn ramp() -> Logits {
        Logits::dense(vec![0., 1., 2., 3., 4.])
    }

    /// Straightforward top-K used as an oracle: sort every `(index, logit)`
    /// pair by descending logit and keep the first `k`.
    fn reference_topk(logits: &Logits, k: usize) -> Logits {
        let mut ranked: Vec<(u32, f32)> = logits
            .indices()
            .iter()
            .copied()
            .zip(logits.logits().iter().copied())
            .collect();
        ranked.sort_by(|a, b| b.1.total_cmp(&a.1));
        ranked.truncate(k);
        let (indices, values) = ranked.into_iter().unzip();
        Logits::sparse(values, indices)
    }

    #[test]
    fn test_token_id_filter() {
        let keep_even = token_id_filter(|id| id % 2 == 0);
        let filtered = keep_even.filter(ramp(), &[]);
        assert_eq!(filtered.logits(), &[0., 2., 4.]);
        assert_eq!(filtered.indices(), &[0, 2, 4]);
    }

    #[test]
    fn test_temperature() {
        let halved = Temperature::new(2.0).filter(ramp(), &[]);
        assert_eq!(halved.logits(), &[0., 0.5, 1., 1.5, 2.0]);
        assert_eq!(halved.indices(), &[0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_chain() {
        let chain = Chain::new()
            .append(token_id_filter(|id| id % 2 == 0))
            .append(token_id_filter(|id| id > 0));
        let filtered = chain.filter(ramp(), &[]);
        assert_eq!(filtered.logits(), &[2., 4.]);
        assert_eq!(filtered.indices(), &[2, 4]);
    }

    #[test]
    fn test_top_k() {
        let logits = Logits::dense(vec![
            -1., 1., 0., 2., -2., 10., -3., 2., 1., 0., 20., -5., 5., 0.1, -0.2, 0.2, 0.1,
        ]);
        // 17 elements: more than the widest common SIMD vector (16 x f32 =
        // 512 bits), so both the vectorized loop and the tail are exercised.
        assert_eq!(logits.len(), 17);

        // K ranging from zero up to and including the input length.
        for k in 0..=logits.len() {
            let topk = TopK::new(k).filter(logits.clone(), &[]);
            let actual = Sort::new().filter(topk, &[]);
            let expected = reference_topk(&logits, k);
            assert_eq!(actual.logits(), expected.logits());
            assert_eq!(actual.indices(), expected.indices());
        }

        // An empty input yields an empty output.
        let empty = TopK::new(1).filter(Logits::dense(vec![]), &[]);
        assert!(empty.is_empty());
    }

    #[test]
    fn test_top_p() {
        // Normalization is turned off throughout, so the inputs below are
        // used directly as probabilities.
        let probs = Logits::dense(vec![0.1, 0.25, 0.15, 0.5]);
        let top_p = |p: f32| TopP::new(p).normalize(false).filter(probs.clone(), &[]);

        // p = 1 keeps the input as-is.
        assert_eq!(probs, top_p(1.0));

        let kept = top_p(0.5);
        assert_eq!(kept.logits(), &[0.5]);
        assert_eq!(kept.indices(), &[3]);

        let kept = top_p(0.75);
        assert_eq!(kept.logits(), &[0.5, 0.25]);
        assert_eq!(kept.indices(), &[3, 1]);

        // p = 0 is clamped to a small positive value, so a non-empty input
        // always yields at least one token.
        let kept = top_p(0.);
        assert_eq!(kept.logits(), &[0.5]);
        assert_eq!(kept.indices(), &[3]);
    }
}
420}