ai00_core/sampler/
typical.rs

1use std::collections::HashMap;
2
3use derivative::Derivative;
4use itertools::Itertools;
5use salvo::oapi::ToSchema;
6use serde::{Deserialize, Serialize};
7use voracious_radix_sort::RadixSort;
8
9use super::{radix, Sampler};
10
11#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
12#[derivative(Default)]
13#[serde(default)]
14pub struct TypicalParams {
15    #[derivative(Default(value = "0.5"))]
16    pub tau: f32,
17    #[derivative(Default(value = "128"))]
18    pub top_k: usize,
19    #[derivative(Default(value = "1.0"))]
20    pub temperature: f32,
21    #[derivative(Default(value = "0.3"))]
22    pub presence_penalty: f32,
23    #[derivative(Default(value = "0.3"))]
24    pub frequency_penalty: f32,
25    #[derivative(Default(value = "0.99654026"))]
26    pub penalty_decay: f32,
27}
28
/// Mutable state carried by the typical sampler between calls.
#[derive(Debug, Default, Clone)]
pub struct TypicalState {
    /// Current repetition penalty for each tracked token id.
    pub penalties: HashMap<u32, f32>,
}
33
34#[derive(Debug, Default, Clone)]
35pub struct TypicalSampler {
36    pub params: TypicalParams,
37    pub state: TypicalState,
38}
39
40impl TypicalSampler {
41    pub fn new(params: TypicalParams) -> Self {
42        Self {
43            params,
44            state: Default::default(),
45        }
46    }
47}
48
49impl Sampler for TypicalSampler {
50    fn init(&mut self, model_tokens: &[u32]) {
51        let TypicalSampler { params, state } = self;
52        for (index, token) in model_tokens.iter().rev().enumerate() {
53            let ap = params.presence_penalty;
54            let af = params.frequency_penalty;
55            let ad = params.penalty_decay;
56            let mut penalty = state.penalties.remove(token).unwrap_or(ap);
57            penalty += af * ad.powf(index as f32);
58            state.penalties.insert(*token, penalty);
59        }
60    }
61
62    fn transform(&self, output: &mut [f32]) {
63        self.state
64            .penalties
65            .iter()
66            // .filter(|(token, _)| !penalty_free_tokens.contains(token))
67            .for_each(|(token, penalty)| output[*token as usize] -= penalty)
68    }
69
70    fn sample(&mut self, probs: &[f32]) -> u32 {
71        let TypicalSampler { params, state } = self;
72
73        let probs = probs
74            .iter()
75            .enumerate()
76            .filter(|(_, &x)| x > 0.0)
77            .map(|(id, &x)| (id, x, -x.ln()))
78            .collect_vec();
79        let entropy = probs.iter().map(|(_, x, y)| x * y).sum::<f32>();
80        let mut sorted = probs
81            .into_iter()
82            .map(|(id, x, y)| radix::DoubleF32WithIndex(id, x, (y - entropy).abs()))
83            .collect_vec();
84        sorted.voracious_sort();
85        let sorted = sorted
86            .into_iter()
87            .map(|radix::DoubleF32WithIndex(id, x, _)| (id, x))
88            .take(params.top_k)
89            .scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| {
90                if *cum > params.tau {
91                    None
92                } else {
93                    *cum += x;
94                    Some((id, *cum, x))
95                }
96            })
97            .map(|(id, _, x)| (id, x.powf(1.0 / params.temperature)))
98            .collect_vec();
99
100        let sum: f32 = sorted.iter().map(|(_, x)| x).sum();
101        let sorted = sorted
102            .into_iter()
103            .map(|(id, x)| (id, x / sum))
104            .scan((0, 0.0), |(_, cum), (id, x)| {
105                *cum += x;
106                Some((id, *cum))
107            })
108            .collect_vec();
109
110        let rand = fastrand::f32();
111        let token = sorted
112            .into_iter()
113            .find_or_first(|&(_, cum)| rand <= cum)
114            .map(|(id, _)| id)
115            .unwrap_or_default();
116        let token = token as u32;
117
118        state
119            .penalties
120            .iter_mut()
121            .for_each(|(_, penalty)| *penalty *= params.penalty_decay);
122
123        let penalty = match state.penalties.get(&token) {
124            Some(penalty) => penalty + params.frequency_penalty,
125            None => params.presence_penalty,
126        };
127        state.penalties.insert(token, penalty);
128
129        token
130    }
131}