gpt_sovits/
logits_sampler.rs1use {
2 rand::{
3 SeedableRng,
4 distr::{Distribution, weighted::WeightedIndex},
5 rngs::StdRng,
6 },
7 std::{cmp::Ordering, collections::HashSet},
8};
9
/// Returns the index of the largest value in `logits`.
///
/// Ties resolve to the earliest index and NaN entries are never selected. If nothing beats
/// `-inf` (an empty slice, or one holding only `-inf`/NaN) the result is `0`. For an empty
/// slice that is not a valid index, so callers must not index with it blindly.
pub fn argmax(logits: &[f32]) -> i64 {
    let (best_index, _best_value) = logits.iter().copied().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(best_index, best_value), (index, value)| {
            // Strict `>` keeps the first of equal maxima and rejects NaN.
            if value > best_value {
                (index, value)
            } else {
                (best_index, best_value)
            }
        },
    );
    best_index as i64
}
23
/// Knobs controlling how `Sampler::sample` turns logits into a token id.
///
/// Prefer constructing this through [`SamplingParams::builder`], which replaces invalid inputs
/// with neutral values. The fields are public, so values assigned directly are used as-is.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct SamplingParams {
    /// Softmax temperature. Exactly `0.0` selects greedy (argmax) decoding; negative or NaN
    /// values leave the logits unscaled.
    pub temperature: f32,
    /// Keep only the `k` most probable tokens. `None`, `Some(0)`, or `k` at least the
    /// vocabulary size disables the filter.
    pub top_k: Option<usize>,
    /// Nucleus-sampling threshold. `None`, values `>= 1.0`, and NaN disable the filter.
    pub top_p: Option<f32>,
    /// Penalty applied to tokens that were already generated; `1.0` disables it.
    pub repetition_penalty: f32,
}

/// Builder for [`SamplingParams`]; out-of-range inputs fall back to neutral defaults.
#[derive(Clone, Copy, Debug)]
#[must_use]
pub struct SamplingParamsBuilder {
    temperature: f32,
    top_k: Option<usize>,
    top_p: Option<f32>,
    repetition_penalty: f32,
}

impl SamplingParams {
    /// Starts a builder with temperature `1.0`, no top-k/top-p filtering, and no repetition
    /// penalty.
    pub fn builder() -> SamplingParamsBuilder {
        SamplingParamsBuilder::new()
    }
}

impl Default for SamplingParams {
    /// Same values as `SamplingParams::builder().build()`.
    fn default() -> Self {
        Self::builder().build()
    }
}

impl Default for SamplingParamsBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl SamplingParamsBuilder {
    fn new() -> Self {
        SamplingParamsBuilder {
            temperature: 1.0,
            top_k: None,
            top_p: None,
            repetition_penalty: 1.0,
        }
    }

    /// Sets the temperature. `0.0` means greedy decoding. Negative, infinite, or NaN values are
    /// replaced by `1.0`: an infinite temperature would scale logits by zero and turn masked
    /// `-inf` logits into NaN.
    pub fn temperature(mut self, temperature: f32) -> Self {
        self.temperature = if temperature.is_finite() && temperature >= 0.0 {
            temperature
        } else {
            1.0
        };
        self
    }

    /// Keeps only the `top_k` most probable tokens; `0` disables the filter.
    pub fn top_k(mut self, top_k: usize) -> Self {
        self.top_k = Some(top_k);
        self
    }

    /// Sets the nucleus threshold; values `>= 1.0` (or NaN) disable the filter.
    pub fn top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Sets the repetition penalty. Values that are not finite and strictly positive are
    /// replaced by `1.0` (no penalty).
    pub fn repetition_penalty(mut self, repetition_penalty: f32) -> Self {
        self.repetition_penalty = if repetition_penalty.is_finite() && repetition_penalty > 0.0 {
            repetition_penalty
        } else {
            1.0
        };
        self
    }

    /// Finalizes the parameters.
    #[must_use]
    pub fn build(self) -> SamplingParams {
        SamplingParams {
            temperature: self.temperature,
            top_k: self.top_k,
            top_p: self.top_p,
            repetition_penalty: self.repetition_penalty,
        }
    }
}
90
91pub struct Sampler {
94 rng: StdRng,
95 probs: Vec<f32>,
97}
98
99unsafe impl Send for Sampler {}
100
101impl Sampler {
102 pub fn new(vocab_size: usize) -> Self {
107 Self {
108 rng: StdRng::from_os_rng(),
109 probs: Vec::with_capacity(vocab_size),
110 }
111 }
112
113 fn apply_repetition_penalty(logits: &mut [f32], prev_tokens: &[i64], penalty: f32) {
115 if penalty == 1.0 {
116 return;
117 }
118 let prev_tokens_set: HashSet<_> = prev_tokens.iter().copied().collect();
119 for (token_id, logit) in logits.iter_mut().enumerate() {
120 if prev_tokens_set.contains(&(token_id as i64)) {
121 if *logit >= 0.0 {
122 *logit /= penalty;
123 } else {
124 *logit *= penalty;
125 }
126 }
127 }
128 }
129
130 fn apply_temperature(logits: &mut [f32], temperature: f32) {
132 if temperature > 0.0 {
133 let inv_temp = 1.0 / temperature;
134 for logit in logits.iter_mut() {
135 *logit *= inv_temp;
136 }
137 }
138 }
139
140 fn softmax(&mut self, logits: &[f32]) {
142 self.probs.clear();
143 if logits.is_empty() {
144 return;
145 }
146
147 let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
148
149 let mut sum_exp = 0.0;
150 self.probs.extend(logits.iter().map(|&logit| {
151 let exp_val = (logit - max_logit).exp();
152 sum_exp += exp_val;
153 exp_val
154 }));
155
156 if sum_exp > 0.0 {
157 let inv_sum_exp = 1.0 / sum_exp;
158 for prob in self.probs.iter_mut() {
159 *prob *= inv_sum_exp;
160 }
161 }
162 }
163
164 pub fn sample(
166 &mut self,
167 logits: &mut [f32],
168 prev_tokens: &[i64],
169 params: &SamplingParams,
170 ) -> i64 {
171 Self::apply_repetition_penalty(logits, prev_tokens, params.repetition_penalty);
172
173 if params.temperature == 0.0 {
175 return argmax(logits);
176 }
177
178 Self::apply_temperature(logits, params.temperature);
179 self.softmax(logits);
180
181 let mut candidates: Vec<(usize, f32)> = self.probs.iter().copied().enumerate().collect();
182
183 if candidates.is_empty() {
184 return argmax(logits);
185 }
186
187 if let Some(k) = params.top_k {
189 if k > 0 && k < candidates.len() {
190 candidates.select_nth_unstable_by(k - 1, |a, b| {
191 b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal)
192 });
193 candidates.truncate(k);
194 }
195 }
196
197 if let Some(p) = params.top_p {
199 if p < 1.0 {
200 candidates
201 .sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
202 let mut cum_prob = 0.0;
203 let mut cutoff = candidates.len();
204 for (i, &(_, prob)) in candidates.iter().enumerate() {
205 cum_prob += prob;
206 if cum_prob >= p {
207 cutoff = i + 1;
208 break;
209 }
210 }
211 candidates.truncate(cutoff);
212 }
213 }
214
215 let weights = candidates.iter().map(|&(_, p)| p);
217 let dist = match WeightedIndex::new(weights) {
218 Ok(d) => d,
219 Err(_) => {
220 return candidates
223 .first()
224 .map_or_else(|| argmax(logits), |&(idx, _)| idx as i64);
225 }
226 };
227
228 let sampled_candidate_index = dist.sample(&mut self.rng);
229 candidates[sampled_candidate_index].0 as i64
230 }
231}