ai00_core/sampler/
mirostat.rs

1use super::{radix, Sampler};
2use derivative::Derivative;
3use itertools::Itertools;
4use salvo::oapi::ToSchema;
5use serde::{Deserialize, Serialize};
6use voracious_radix_sort::RadixSort;
7
8#[derive(Debug, Clone, Derivative, Serialize, Deserialize, ToSchema)]
9#[derivative(Default)]
10#[serde(default)]
11pub struct MirostatParams {
12    #[derivative(Default(value = "3.0"))]
13    pub tau: f32,
14    #[derivative(Default(value = "0.1"))]
15    #[serde(alias = "learning_rate")]
16    pub rate: f32,
17}
18
/// Running state carried by the Mirostat sampler between `sample` calls.
#[derive(Clone, Debug, Default)]
pub struct MirostatState {
    /// Current surprise cap in bits: candidates are cut off at the first token
    /// whose surprise (`-log2(p)`) exceeds this value. Adjusted after every
    /// sampled token.
    pub max_surprise: f32,
}
23
24#[derive(Debug, Clone, Default)]
25pub struct MirostatSampler {
26    pub params: MirostatParams,
27    pub state: MirostatState,
28}
29
30impl MirostatSampler {
31    pub fn new(params: MirostatParams) -> Self {
32        let state = MirostatState {
33            max_surprise: params.tau * 2.0,
34        };
35        Self { params, state }
36    }
37}
38
39impl Sampler for MirostatSampler {
40    fn init(&mut self, _model_tokens: &[u32]) {}
41
42    fn transform(&self, _output: &mut [f32]) {}
43
44    fn sample(&mut self, probs: &[f32]) -> u32 {
45        let MirostatSampler { params, state } = self;
46
47        // sort the surprise values and truncate
48        let mut sorted = probs
49            .iter()
50            .copied()
51            .enumerate()
52            .map(|(id, x)| radix::F32WithIndex(id, x))
53            .collect_vec();
54        sorted.voracious_sort();
55        let sorted = sorted
56            .into_iter()
57            .rev()
58            .scan((0, 0.0, 0.0), |(_, cum, _), radix::F32WithIndex(id, x)| {
59                // if *cum > params.top_p {
60                //     None
61                // } else {
62                //     *cum += x;
63                //     Some((id, *cum, *x))
64                // }
65                *cum += x;
66                Some((id, *cum, x))
67            })
68            .collect_vec();
69        let k = sorted
70            .iter()
71            .find_position(|&(_, _, x)| -x.log2() > state.max_surprise)
72            .map(|(k, _)| k + 1)
73            .unwrap_or(sorted.len());
74        let sorted = sorted.into_iter().take(k).collect_vec();
75
76        // normalize the probs
77        let sum = sorted.last().map(|(_, x, _)| *x).unwrap();
78        let rand = fastrand::f32() * sum;
79        let (token, _, prob) = sorted
80            .into_iter()
81            .find_or_first(|&(_, cum, _)| rand <= cum)
82            .unwrap();
83
84        let token_surprise = sum.log2() - prob.log2();
85        let error_surprise = token_surprise - params.tau;
86        state.max_surprise -= params.rate * error_surprise;
87        state.max_surprise = state.max_surprise.min(4.0 * params.tau);
88
89        token as u32
90    }
91}