ai00_core/sampler/
typical.rs1use std::collections::HashMap;
2
3use derivative::Derivative;
4use itertools::Itertools;
5use salvo::oapi::ToSchema;
6use serde::{Deserialize, Serialize};
7use voracious_radix_sort::RadixSort;
8
9use super::{radix, Sampler};
10
11#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
12#[derivative(Default)]
13#[serde(default)]
14pub struct TypicalParams {
15 #[derivative(Default(value = "0.5"))]
16 pub tau: f32,
17 #[derivative(Default(value = "128"))]
18 pub top_k: usize,
19 #[derivative(Default(value = "1.0"))]
20 pub temperature: f32,
21 #[derivative(Default(value = "0.3"))]
22 pub presence_penalty: f32,
23 #[derivative(Default(value = "0.3"))]
24 pub frequency_penalty: f32,
25 #[derivative(Default(value = "0.99654026"))]
26 pub penalty_decay: f32,
27}
28
/// Mutable state that a [`TypicalSampler`] carries between calls.
#[derive(Clone, Debug, Default)]
pub struct TypicalState {
    /// Accumulated repetition penalty keyed by token id. The sampler's
    /// `transform` subtracts each value from the matching output entry.
    pub penalties: HashMap<u32, f32>,
}
33
34#[derive(Debug, Default, Clone)]
35pub struct TypicalSampler {
36 pub params: TypicalParams,
37 pub state: TypicalState,
38}
39
40impl TypicalSampler {
41 pub fn new(params: TypicalParams) -> Self {
42 Self {
43 params,
44 state: Default::default(),
45 }
46 }
47}
48
49impl Sampler for TypicalSampler {
50 fn init(&mut self, model_tokens: &[u32]) {
51 let TypicalSampler { params, state } = self;
52 for (index, token) in model_tokens.iter().rev().enumerate() {
53 let ap = params.presence_penalty;
54 let af = params.frequency_penalty;
55 let ad = params.penalty_decay;
56 let mut penalty = state.penalties.remove(token).unwrap_or(ap);
57 penalty += af * ad.powf(index as f32);
58 state.penalties.insert(*token, penalty);
59 }
60 }
61
62 fn transform(&self, output: &mut [f32]) {
63 self.state
64 .penalties
65 .iter()
66 .for_each(|(token, penalty)| output[*token as usize] -= penalty)
68 }
69
70 fn sample(&mut self, probs: &[f32]) -> u32 {
71 let TypicalSampler { params, state } = self;
72
73 let probs = probs
74 .iter()
75 .enumerate()
76 .filter(|(_, &x)| x > 0.0)
77 .map(|(id, &x)| (id, x, -x.ln()))
78 .collect_vec();
79 let entropy = probs.iter().map(|(_, x, y)| x * y).sum::<f32>();
80 let mut sorted = probs
81 .into_iter()
82 .map(|(id, x, y)| radix::DoubleF32WithIndex(id, x, (y - entropy).abs()))
83 .collect_vec();
84 sorted.voracious_sort();
85 let sorted = sorted
86 .into_iter()
87 .map(|radix::DoubleF32WithIndex(id, x, _)| (id, x))
88 .take(params.top_k)
89 .scan((0, 0.0, 0.0), |(_, cum, _), (id, x)| {
90 if *cum > params.tau {
91 None
92 } else {
93 *cum += x;
94 Some((id, *cum, x))
95 }
96 })
97 .map(|(id, _, x)| (id, x.powf(1.0 / params.temperature)))
98 .collect_vec();
99
100 let sum: f32 = sorted.iter().map(|(_, x)| x).sum();
101 let sorted = sorted
102 .into_iter()
103 .map(|(id, x)| (id, x / sum))
104 .scan((0, 0.0), |(_, cum), (id, x)| {
105 *cum += x;
106 Some((id, *cum))
107 })
108 .collect_vec();
109
110 let rand = fastrand::f32();
111 let token = sorted
112 .into_iter()
113 .find_or_first(|&(_, cum)| rand <= cum)
114 .map(|(id, _)| id)
115 .unwrap_or_default();
116 let token = token as u32;
117
118 state
119 .penalties
120 .iter_mut()
121 .for_each(|(_, penalty)| *penalty *= params.penalty_decay);
122
123 let penalty = match state.penalties.get(&token) {
124 Some(penalty) => penalty + params.frequency_penalty,
125 None => params.presence_penalty,
126 };
127 state.penalties.insert(token, penalty);
128
129 token
130 }
131}