// ai00_core/sampler/nucleus.rs

1use std::collections::HashMap;
2
3use super::{radix, Sampler};
4use derivative::Derivative;
5use itertools::Itertools;
6use salvo::oapi::ToSchema;
7use serde::{Deserialize, Serialize};
8use voracious_radix_sort::RadixSort;
9
10#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
11#[derivative(Default)]
12#[serde(default)]
13pub struct NucleusParams {
14    #[derivative(Default(value = "0.5"))]
15    pub top_p: f32,
16    #[derivative(Default(value = "128"))]
17    pub top_k: usize,
18    #[derivative(Default(value = "1.0"))]
19    pub temperature: f32,
20    #[derivative(Default(value = "0.3"))]
21    pub presence_penalty: f32,
22    #[derivative(Default(value = "0.3"))]
23    pub frequency_penalty: f32,
24    #[derivative(Default(value = "0.99654026"))]
25    pub penalty_decay: f32,
26}
27
/// Mutable bookkeeping carried by a nucleus sampler between sampling steps.
#[derive(Clone, Debug, Default)]
pub struct NucleusState {
    /// Current penalty per token id; starts out empty.
    pub penalties: HashMap<u32, f32>,
}
32
33#[derive(Debug, Default, Clone)]
34pub struct NucleusSampler {
35    pub params: NucleusParams,
36    pub state: NucleusState,
37}
38
39impl NucleusSampler {
40    pub fn new(params: NucleusParams) -> Self {
41        Self {
42            params,
43            state: Default::default(),
44        }
45    }
46}
47
48impl Sampler for NucleusSampler {
49    fn init(&mut self, model_tokens: &[u32]) {
50        let NucleusSampler { params, state } = self;
51        for (index, token) in model_tokens.iter().rev().enumerate() {
52            let ap = params.presence_penalty;
53            let af = params.frequency_penalty;
54            let ad = params.penalty_decay;
55            let mut penalty = state.penalties.remove(token).unwrap_or(ap);
56            penalty += af * ad.powf(index as f32);
57            state.penalties.insert(*token, penalty);
58        }
59    }
60
61    fn transform(&self, output: &mut [f32]) {
62        self.state
63            .penalties
64            .iter()
65            // .filter(|(token, _)| !penalty_free_tokens.contains(token))
66            .for_each(|(token, penalty)| output[*token as usize] -= penalty)
67    }
68
69    fn sample(&mut self, probs: &[f32]) -> u32 {
70        let NucleusSampler { params, state } = self;
71        let mut sorted = probs
72            .iter()
73            .copied()
74            .enumerate()
75            .map(|(id, x)| radix::F32WithIndex(id, x))
76            .collect_vec();
77        sorted.voracious_sort();
78        let sorted = sorted
79            .into_iter()
80            .rev()
81            .take(params.top_k)
82            .scan((0, 0.0, 0.0), |(_, cum, _), radix::F32WithIndex(id, x)| {
83                if *cum > params.top_p {
84                    None
85                } else {
86                    *cum += x;
87                    Some((id, *cum, x))
88                }
89            })
90            .map(|(id, _, x)| (id, x.powf(1.0 / params.temperature)))
91            .collect_vec();
92
93        let sum: f32 = sorted.iter().map(|(_, x)| x).sum();
94        let sorted = sorted
95            .into_iter()
96            .map(|(id, x)| (id, x / sum))
97            .scan((0, 0.0), |(_, cum), (id, x)| {
98                *cum += x;
99                Some((id, *cum))
100            })
101            .collect_vec();
102        let rand = fastrand::f32();
103        let token = sorted
104            .into_iter()
105            .find_or_first(|&(_, cum)| rand <= cum)
106            .map(|(id, _)| id)
107            .unwrap_or_default();
108        let token = token as u32;
109
110        state
111            .penalties
112            .iter_mut()
113            .for_each(|(_, penalty)| *penalty *= params.penalty_decay);
114
115        let penalty = match state.penalties.get(&token) {
116            Some(penalty) => penalty + params.frequency_penalty,
117            None => params.presence_penalty,
118        };
119        state.penalties.insert(token, penalty);
120
121        token
122    }
123}