// ai00_core/sampler/mirostat.rs

use derivative::Derivative;
use itertools::Itertools;
use salvo::oapi::ToSchema;
use serde::{Deserialize, Serialize};
use voracious_radix_sort::RadixSort;

use super::{radix, Sampler};

8#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
9#[derivative(Default)]
10#[serde(default)]
11pub struct MirostatParams {
12 #[derivative(Default(value = "3.0"))]
13 pub tau: f32,
14 #[derivative(Default(value = "0.1"))]
15 #[serde(alias = "learning_rate")]
16 pub rate: f32,
17}
18
/// Feedback state that the Mirostat sampler carries from one sampled token to the next.
#[derive(Clone, Debug, Default)]
pub struct MirostatState {
    /// Current surprise threshold: candidates whose surprise `-log2(p)` exceeds it are
    /// truncated before sampling. `Default` leaves it at `0.0`; `MirostatSampler::new`
    /// starts it at `2 * tau` instead.
    pub max_surprise: f32,
}

24#[derive(Debug, Clone, Default)]
25pub struct MirostatSampler {
26 pub params: MirostatParams,
27 pub state: MirostatState,
28}
29
30impl MirostatSampler {
31 pub fn new(params: MirostatParams) -> Self {
32 let state = MirostatState {
33 max_surprise: params.tau * 2.0,
34 };
35 Self { params, state }
36 }
37}
38
39impl Sampler for MirostatSampler {
40 fn init(&mut self, _model_tokens: &[u32]) {}
41
42 fn transform(&self, _output: &mut [f32]) {}
43
44 fn sample(&mut self, probs: &[f32]) -> u32 {
45 let MirostatSampler { params, state } = self;
46
47 let mut sorted = probs
49 .iter()
50 .copied()
51 .enumerate()
52 .map(|(id, x)| radix::F32WithIndex(id, x))
53 .collect_vec();
54 sorted.voracious_sort();
55 let sorted = sorted
56 .into_iter()
57 .rev()
58 .scan((0, 0.0, 0.0), |(_, cum, _), radix::F32WithIndex(id, x)| {
59 *cum += x;
66 Some((id, *cum, x))
67 })
68 .collect_vec();
69 let k = sorted
70 .iter()
71 .find_position(|&(_, _, x)| -x.log2() > state.max_surprise)
72 .map(|(k, _)| k + 1)
73 .unwrap_or(sorted.len());
74 let sorted = sorted.into_iter().take(k).collect_vec();
75
76 let sum = sorted.last().map(|(_, x, _)| *x).unwrap();
78 let rand = fastrand::f32() * sum;
79 let (token, _, prob) = sorted
80 .into_iter()
81 .find_or_first(|&(_, cum, _)| rand <= cum)
82 .unwrap();
83
84 let token_surprise = sum.log2() - prob.log2();
85 let error_surprise = token_surprise - params.tau;
86 state.max_surprise -= params.rate * error_surprise;
87 state.max_surprise = state.max_surprise.min(4.0 * params.tau);
88
89 token as u32
90 }
91}