gpt_sovits/
logits_sampler.rs

1use {
2    rand::{
3        SeedableRng,
4        distr::{Distribution, weighted::WeightedIndex},
5        rngs::StdRng,
6    },
7    std::{cmp::Ordering, collections::HashSet},
8};
9
/// Returns the index of the largest logit value (greedy/argmax decoding).
///
/// Ties resolve to the earliest index; NaN entries never win a comparison
/// and are skipped. An empty slice yields `0`.
pub fn argmax(logits: &[f32]) -> i64 {
    logits
        .iter()
        .enumerate()
        .fold((0usize, f32::NEG_INFINITY), |best, (idx, &logit)| {
            // Strict `>` keeps the first occurrence on ties and ignores NaN.
            if logit > best.1 { (idx, logit) } else { best }
        })
        .0 as i64
}
23
/// Parameters controlling how a token is sampled from the logits.
#[derive(Clone, Copy, Debug)]
pub struct SamplingParams {
    /// Softmax temperature; `0.0` switches to greedy (argmax) decoding.
    pub temperature: f32,
    /// If set, restrict sampling to the `k` highest-probability tokens.
    pub top_k: Option<usize>,
    /// If set (and < 1.0), restrict sampling to the smallest prefix of
    /// candidates whose cumulative probability reaches `p` (nucleus sampling).
    pub top_p: Option<f32>,
    /// Penalty applied to tokens already generated; values > 1.0 discourage
    /// repetition, `1.0` disables the penalty.
    pub repetition_penalty: f32,
}
32
/// Fluent builder for [`SamplingParams`]; obtain one via
/// [`SamplingParams::builder`].
pub struct SamplingParamsBuilder {
    // Defaults (set in `new()`): temperature 1.0, no top-k, no top-p,
    // repetition penalty 1.0 (disabled).
    temperature: f32,
    top_k: Option<usize>,
    top_p: Option<f32>,
    repetition_penalty: f32,
}
40
impl SamplingParams {
    /// Returns a builder initialized with the defaults
    /// (temperature 1.0, no top-k/top-p, repetition penalty disabled).
    pub fn builder() -> SamplingParamsBuilder {
        SamplingParamsBuilder::new()
    }
}
46
47impl SamplingParamsBuilder {
48    fn new() -> Self {
49        SamplingParamsBuilder {
50            temperature: 1.0,
51            top_k: None,
52            top_p: None,
53            repetition_penalty: 1.0,
54        }
55    }
56
57    pub fn temperature(mut self, temperature: f32) -> Self {
58        self.temperature = if temperature >= 0.0 { temperature } else { 1.0 };
59        self
60    }
61
62    pub fn top_k(mut self, top_k: usize) -> Self {
63        self.top_k = Some(top_k);
64        self
65    }
66
67    pub fn top_p(mut self, top_p: f32) -> Self {
68        self.top_p = Some(top_p);
69        self
70    }
71
72    pub fn repetition_penalty(mut self, repetition_penalty: f32) -> Self {
73        self.repetition_penalty = if repetition_penalty > 0.0 {
74            repetition_penalty
75        } else {
76            1.0
77        };
78        self
79    }
80
81    pub fn build(self) -> SamplingParams {
82        SamplingParams {
83            temperature: self.temperature,
84            top_k: self.top_k,
85            top_p: self.top_p,
86            repetition_penalty: self.repetition_penalty,
87        }
88    }
89}
90
/// Processes logits to sample a token ID, applying various strategies
/// like temperature, repetition penalty, and Top-K/Top-P sampling.
pub struct Sampler {
    /// Random source for weighted sampling; seeded from the OS in `new()`.
    rng: StdRng,
    /// Reusable buffer for probabilities to avoid re-allocation in the sampling loop.
    probs: Vec<f32>,
}
98
99unsafe impl Send for Sampler {}
100
101impl Sampler {
102    /// Creates a new Sampler.
103    ///
104    /// # Arguments
105    /// * `vocab_size`: The size of the vocabulary, used to pre-allocate buffers for efficiency.
106    pub fn new(vocab_size: usize) -> Self {
107        Self {
108            rng: StdRng::from_os_rng(),
109            probs: Vec::with_capacity(vocab_size),
110        }
111    }
112
113    /// Applies a penalty to the logits of repeated tokens.
114    fn apply_repetition_penalty(logits: &mut [f32], prev_tokens: &[i64], penalty: f32) {
115        if penalty == 1.0 {
116            return;
117        }
118        let prev_tokens_set: HashSet<_> = prev_tokens.iter().copied().collect();
119        for (token_id, logit) in logits.iter_mut().enumerate() {
120            if prev_tokens_set.contains(&(token_id as i64)) {
121                if *logit >= 0.0 {
122                    *logit /= penalty;
123                } else {
124                    *logit *= penalty;
125                }
126            }
127        }
128    }
129
130    /// Applies temperature scaling to the logits.
131    fn apply_temperature(logits: &mut [f32], temperature: f32) {
132        if temperature > 0.0 {
133            let inv_temp = 1.0 / temperature;
134            for logit in logits.iter_mut() {
135                *logit *= inv_temp;
136            }
137        }
138    }
139
140    /// Computes the softmax of logits and stores the result in the internal `probs` buffer.
141    fn softmax(&mut self, logits: &[f32]) {
142        self.probs.clear();
143        if logits.is_empty() {
144            return;
145        }
146
147        let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
148
149        let mut sum_exp = 0.0;
150        self.probs.extend(logits.iter().map(|&logit| {
151            let exp_val = (logit - max_logit).exp();
152            sum_exp += exp_val;
153            exp_val
154        }));
155
156        if sum_exp > 0.0 {
157            let inv_sum_exp = 1.0 / sum_exp;
158            for prob in self.probs.iter_mut() {
159                *prob *= inv_sum_exp;
160            }
161        }
162    }
163
164    /// Main sampling method with performance optimizations.
165    pub fn sample(
166        &mut self,
167        logits: &mut [f32],
168        prev_tokens: &[i64],
169        params: &SamplingParams,
170    ) -> i64 {
171        Self::apply_repetition_penalty(logits, prev_tokens, params.repetition_penalty);
172
173        // Optimized path for greedy decoding (argmax).
174        if params.temperature == 0.0 {
175            return argmax(logits);
176        }
177
178        Self::apply_temperature(logits, params.temperature);
179        self.softmax(logits);
180
181        let mut candidates: Vec<(usize, f32)> = self.probs.iter().copied().enumerate().collect();
182
183        if candidates.is_empty() {
184            return argmax(logits);
185        }
186
187        // --- Top-K Filtering (Optimized O(V) selection) ---
188        if let Some(k) = params.top_k {
189            if k > 0 && k < candidates.len() {
190                candidates.select_nth_unstable_by(k - 1, |a, b| {
191                    b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)
192                });
193                candidates.truncate(k);
194            }
195        }
196
197        // --- Top-P (Nucleus) Filtering (on at most K candidates) ---
198        if let Some(p) = params.top_p {
199            if p < 1.0 {
200                candidates
201                    .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
202                let mut cum_prob = 0.0;
203                let mut cutoff = candidates.len();
204                for (i, &(_, prob)) in candidates.iter().enumerate() {
205                    cum_prob += prob;
206                    if cum_prob >= p {
207                        cutoff = i + 1;
208                        break;
209                    }
210                }
211                candidates.truncate(cutoff);
212            }
213        }
214
215        // --- Final Sampling ---
216        let weights = candidates.iter().map(|&(_, p)| p);
217        let dist = match WeightedIndex::new(weights) {
218            Ok(d) => d,
219            Err(_) => {
220                // Fallback if distribution fails (e.g., all probs are 0 after filtering).
221                // Return the highest probability candidate before this step.
222                return candidates
223                    .first()
224                    .map_or_else(|| argmax(logits), |&(idx, _)| idx as i64);
225            }
226        };
227
228        let sampled_candidate_index = dist.sample(&mut self.rng);
229        candidates[sampled_candidate_index].0 as i64
230    }
231}