ai00_core/sampler/
nucleus.rs1use std::collections::HashMap;
2
3use super::{radix, Sampler};
4use derivative::Derivative;
5use itertools::Itertools;
6use salvo::oapi::ToSchema;
7use serde::{Deserialize, Serialize};
8use voracious_radix_sort::RadixSort;
9
10#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
11#[derivative(Default)]
12#[serde(default)]
13pub struct NucleusParams {
14 #[derivative(Default(value = "0.5"))]
15 pub top_p: f32,
16 #[derivative(Default(value = "128"))]
17 pub top_k: usize,
18 #[derivative(Default(value = "1.0"))]
19 pub temperature: f32,
20 #[derivative(Default(value = "0.3"))]
21 pub presence_penalty: f32,
22 #[derivative(Default(value = "0.3"))]
23 pub frequency_penalty: f32,
24 #[derivative(Default(value = "0.99654026"))]
25 pub penalty_decay: f32,
26}
27
/// Mutable penalty bookkeeping of [`NucleusSampler`].
#[derive(Clone, Debug, Default)]
pub struct NucleusState {
    /// Accumulated penalty per token id, subtracted from the model output before sampling.
    pub penalties: HashMap<u32, f32>,
}
32
33#[derive(Debug, Default, Clone)]
34pub struct NucleusSampler {
35 pub params: NucleusParams,
36 pub state: NucleusState,
37}
38
39impl NucleusSampler {
40 pub fn new(params: NucleusParams) -> Self {
41 Self {
42 params,
43 state: Default::default(),
44 }
45 }
46}
47
48impl Sampler for NucleusSampler {
49 fn init(&mut self, model_tokens: &[u32]) {
50 let NucleusSampler { params, state } = self;
51 for (index, token) in model_tokens.iter().rev().enumerate() {
52 let ap = params.presence_penalty;
53 let af = params.frequency_penalty;
54 let ad = params.penalty_decay;
55 let mut penalty = state.penalties.remove(token).unwrap_or(ap);
56 penalty += af * ad.powf(index as f32);
57 state.penalties.insert(*token, penalty);
58 }
59 }
60
61 fn transform(&self, output: &mut [f32]) {
62 self.state
63 .penalties
64 .iter()
65 .for_each(|(token, penalty)| output[*token as usize] -= penalty)
67 }
68
69 fn sample(&mut self, probs: &[f32]) -> u32 {
70 let NucleusSampler { params, state } = self;
71 let mut sorted = probs
72 .iter()
73 .copied()
74 .enumerate()
75 .map(|(id, x)| radix::F32WithIndex(id, x))
76 .collect_vec();
77 sorted.voracious_sort();
78 let sorted = sorted
79 .into_iter()
80 .rev()
81 .take(params.top_k)
82 .scan((0, 0.0, 0.0), |(_, cum, _), radix::F32WithIndex(id, x)| {
83 if *cum > params.top_p {
84 None
85 } else {
86 *cum += x;
87 Some((id, *cum, x))
88 }
89 })
90 .map(|(id, _, x)| (id, x.powf(1.0 / params.temperature)))
91 .collect_vec();
92
93 let sum: f32 = sorted.iter().map(|(_, x)| x).sum();
94 let sorted = sorted
95 .into_iter()
96 .map(|(id, x)| (id, x / sum))
97 .scan((0, 0.0), |(_, cum), (id, x)| {
98 *cum += x;
99 Some((id, *cum))
100 })
101 .collect_vec();
102 let rand = fastrand::f32();
103 let token = sorted
104 .into_iter()
105 .find_or_first(|&(_, cum)| rand <= cum)
106 .map(|(id, _)| id)
107 .unwrap_or_default();
108 let token = token as u32;
109
110 state
111 .penalties
112 .iter_mut()
113 .for_each(|(_, penalty)| *penalty *= params.penalty_decay);
114
115 let penalty = match state.penalties.get(&token) {
116 Some(penalty) => penalty + params.frequency_penalty,
117 None => params.presence_penalty,
118 };
119 state.penalties.insert(token, penalty);
120
121 token
122 }
123}