1use crate::error::Result;
4use crate::inference::StreamToken;
5use candle_core::Tensor;
6use rand::Rng;
7
/// Strategy used to choose the next token from a vector of logits.
///
/// The default strategy is [`SamplingStrategy::Greedy`].
#[derive(Debug, Clone, PartialEq, Default)]
pub enum SamplingStrategy {
    /// Always pick the token with the highest logit.
    #[default]
    Greedy,

    /// Keep only the `k` highest-logit tokens, apply a softmax over them and
    /// sample from the result.
    TopK {
        /// Number of highest-logit candidates kept before sampling.
        k: usize,
    },

    /// Nucleus sampling: keep the most probable tokens until their cumulative
    /// probability reaches `p`, renormalise, then sample.
    TopP {
        /// Cumulative probability mass at which the nucleus is closed.
        p: f64,
    },

    /// Divide every logit by `temperature`, apply a softmax over the whole
    /// vocabulary and sample from the result.
    Temperature {
        /// Divisor applied to the logits before the softmax.
        temperature: f64,
    },
}
33
34pub fn apply_repetition_penalty(
63 logits: &mut Tensor,
64 generated_ids: &[u32],
65 penalty: f32,
66) -> Result<()> {
67 if generated_ids.is_empty() || (penalty - 1.0).abs() < 1e-7 {
68 return Ok(()); }
70
71 let mut logits_vec = logits.to_vec1::<f32>()?;
72
73 for &token_id in generated_ids {
75 let idx = token_id as usize;
76 if idx < logits_vec.len() {
77 logits_vec[idx] /= penalty;
78 }
79 }
80
81 *logits = Tensor::new(&logits_vec[..], logits.device())?;
83
84 Ok(())
85}
86
87pub fn sample_token(
119 logits: &Tensor,
120 strategy: &SamplingStrategy,
121 generated_ids: &[u32],
122 repetition_penalty: f32,
123) -> Result<u32> {
124 let mut logits = logits.clone();
126 apply_repetition_penalty(&mut logits, generated_ids, repetition_penalty)?;
127
128 match strategy {
129 SamplingStrategy::Greedy => sample_greedy(&logits),
130 SamplingStrategy::TopK { k } => sample_top_k(&logits, *k),
131 SamplingStrategy::TopP { p } => sample_top_p(&logits, *p),
132 SamplingStrategy::Temperature { temperature } => sample_temperature(&logits, *temperature),
133 }
134}
135
136pub fn sample_token_with_metadata(
174 logits: &Tensor,
175 strategy: &SamplingStrategy,
176 generated_ids: &[u32],
177 repetition_penalty: f32,
178 eos_token_id: Option<u32>,
179) -> Result<StreamToken> {
180 let mut penalized_logits = logits.clone();
182 apply_repetition_penalty(&mut penalized_logits, generated_ids, repetition_penalty)?;
183
184 let token_id = match strategy {
186 SamplingStrategy::Greedy => sample_greedy(&penalized_logits)?,
187 SamplingStrategy::TopK { k } => sample_top_k(&penalized_logits, *k)?,
188 SamplingStrategy::TopP { p } => sample_top_p(&penalized_logits, *p)?,
189 SamplingStrategy::Temperature { temperature } => {
190 sample_temperature(&penalized_logits, *temperature)?
191 }
192 };
193
194 let logits_vec = penalized_logits.to_vec1::<f32>()?;
196 let logit = logits_vec
197 .get(token_id as usize)
198 .copied()
199 .unwrap_or(f32::NEG_INFINITY);
200
201 let max_logit = logits_vec
203 .iter()
204 .copied()
205 .max_by(|a, b| a.partial_cmp(b).unwrap())
206 .unwrap_or(0.0);
207 let exp_sum: f32 = logits_vec.iter().map(|l| (l - max_logit).exp()).sum();
208 let probability = if exp_sum > 0.0 {
209 (logit - max_logit).exp() / exp_sum
210 } else {
211 0.0
212 };
213
214 let is_eos = eos_token_id == Some(token_id);
216
217 Ok(StreamToken {
218 token_id,
219 text: None, logit,
221 probability,
222 is_eos,
223 })
224}
225
226fn sample_greedy(logits: &Tensor) -> Result<u32> {
228 let logits_vec = logits.to_vec1::<f32>()?;
229 let token = logits_vec
230 .iter()
231 .enumerate()
232 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
233 .map(|(idx, _)| u32::try_from(idx).unwrap_or(u32::MAX))
234 .ok_or_else(|| crate::error::InferenceError::SamplingError {
235 reason: "Empty logits".to_string(),
236 })?;
237 Ok(token)
238}
239
240fn sample_top_k(logits: &Tensor, k: usize) -> Result<u32> {
242 let logits_vec = logits.to_vec1::<f32>()?;
243
244 let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
246 indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
247 indexed.truncate(k);
248
249 let max_logit = indexed[0].1;
251 let exp_sum: f32 = indexed.iter().map(|(_, l)| (l - max_logit).exp()).sum();
252 let probs: Vec<f64> = indexed
253 .iter()
254 .map(|(_, l)| f64::from((l - max_logit).exp() / exp_sum))
255 .collect();
256
257 let mut rng = rand::thread_rng();
259 let r: f64 = rng.gen();
260 let mut cumsum = 0.0;
261 for (i, &p) in probs.iter().enumerate() {
262 cumsum += p;
263 if r <= cumsum {
264 return Ok(u32::try_from(indexed[i].0).unwrap_or(u32::MAX));
265 }
266 }
267
268 Ok(u32::try_from(indexed[0].0).unwrap_or(u32::MAX))
269}
270
271fn sample_top_p(logits: &Tensor, p: f64) -> Result<u32> {
273 let logits_vec = logits.to_vec1::<f32>()?;
274
275 let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
277 indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
278
279 let max_logit = indexed[0].1;
281 let exp_sum: f32 = indexed.iter().map(|(_, l)| (l - max_logit).exp()).sum();
282 let probs: Vec<(usize, f64)> = indexed
283 .iter()
284 .map(|(idx, l)| (*idx, f64::from((l - max_logit).exp() / exp_sum)))
285 .collect();
286
287 let mut cumsum = 0.0;
289 let mut nucleus = Vec::new();
290 for (idx, prob) in probs {
291 nucleus.push((idx, prob));
292 cumsum += prob;
293 if cumsum >= p {
294 break;
295 }
296 }
297
298 let mut rng = rand::thread_rng();
300 let r: f64 = rng.gen();
301 let nucleus_sum: f64 = nucleus.iter().map(|(_, p)| p).sum();
302 let mut cumsum = 0.0;
303 for (idx, prob) in &nucleus {
304 cumsum += prob / nucleus_sum;
305 if r <= cumsum {
306 return Ok(u32::try_from(*idx).unwrap_or(u32::MAX));
307 }
308 }
309
310 Ok(u32::try_from(nucleus[0].0).unwrap_or(u32::MAX))
311}
312
313fn sample_temperature(logits: &Tensor, temperature: f64) -> Result<u32> {
315 let logits_vec = logits.to_vec1::<f32>()?;
316
317 #[allow(clippy::cast_possible_truncation)]
319 let scaled: Vec<f32> = logits_vec.iter().map(|l| l / temperature as f32).collect();
321
322 let max_logit = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
324 let exp_sum: f32 = scaled.iter().map(|l| (l - max_logit).exp()).sum();
325 let probs: Vec<f64> = scaled
326 .iter()
327 .map(|l| f64::from((l - max_logit).exp() / exp_sum))
328 .collect();
329
330 let mut rng = rand::thread_rng();
332 let r: f64 = rng.gen();
333 let mut cumsum = 0.0;
334 for (idx, &p) in probs.iter().enumerate() {
335 cumsum += p;
336 if r <= cumsum {
337 return Ok(u32::try_from(idx).unwrap_or(u32::MAX));
338 }
339 }
340
341 Ok(u32::try_from(probs.len() - 1).unwrap_or(u32::MAX))
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    /// Builds a one-dimensional `f32` tensor on the CPU from `values`.
    fn cpu_logits(values: &[f32]) -> Tensor {
        Tensor::new(values, &Device::Cpu).unwrap()
    }

    /// Greedy selection returns the index of the largest logit.
    #[test]
    fn test_greedy_sampling() {
        let logits = cpu_logits(&[1.0, 3.0, 2.0, 0.5]);

        assert_eq!(sample_greedy(&logits).unwrap(), 1);
    }

    /// With `k = 2` only the two largest logits can be drawn.
    #[test]
    fn test_top_k_sampling() {
        let logits = cpu_logits(&[1.0, 3.0, 2.0, 0.5]);

        let token = sample_top_k(&logits, 2).unwrap();
        assert!(matches!(token, 1 | 2));
    }

    /// The default strategy is greedy.
    #[test]
    fn test_sampling_strategy_default() {
        assert!(matches!(
            SamplingStrategy::default(),
            SamplingStrategy::Greedy
        ));
    }

    /// Only the logits of generated tokens are scaled by the penalty.
    #[test]
    fn test_apply_repetition_penalty() {
        let mut logits = cpu_logits(&[1.0, 2.0, 3.0, 4.0]);
        let generated = [1u32, 3];

        apply_repetition_penalty(&mut logits, &generated, 2.0).unwrap();

        let result = logits.to_vec1::<f32>().unwrap();
        let expected = [1.0f32, 1.0, 3.0, 2.0];
        for (slot, want) in expected.iter().enumerate() {
            assert!((result[slot] - want).abs() < 1e-6);
        }
    }

    /// An empty history leaves the logits unchanged.
    #[test]
    fn test_apply_repetition_penalty_empty() {
        let mut logits = cpu_logits(&[1.0, 2.0, 3.0]);
        let before = logits.to_vec1::<f32>().unwrap();

        apply_repetition_penalty(&mut logits, &[], 2.0).unwrap();

        let after = logits.to_vec1::<f32>().unwrap();
        assert_eq!(after, before);
    }

    /// A penalty of `1.0` leaves the logits unchanged.
    #[test]
    fn test_apply_repetition_penalty_no_penalty() {
        let mut logits = cpu_logits(&[1.0, 2.0, 3.0]);
        let before = logits.to_vec1::<f32>().unwrap();

        apply_repetition_penalty(&mut logits, &[0, 1], 1.0).unwrap();

        let after = logits.to_vec1::<f32>().unwrap();
        assert_eq!(after, before);
    }

    /// A strong penalty on the top token makes greedy pick the runner-up.
    #[test]
    fn test_sample_token_with_penalty() {
        let logits = cpu_logits(&[1.0, 5.0, 2.0, 0.5]);
        let greedy = SamplingStrategy::Greedy;

        let unpenalized = sample_token(&logits, &greedy, &[], 1.0).unwrap();
        assert_eq!(unpenalized, 1);

        let penalized = sample_token(&logits, &greedy, &[1], 10.0).unwrap();
        assert_eq!(penalized, 2);
    }
}