metal_candle/inference/sampling.rs

1//! Sampling strategies for text generation.
2
3use crate::error::Result;
4use crate::inference::StreamToken;
5use candle_core::Tensor;
6use rand::Rng;
7
/// Strategy used to select the next token from a logits distribution.
#[derive(Debug, Clone, Default)]
pub enum SamplingStrategy {
    /// Deterministic argmax over the logits (the default).
    #[default]
    Greedy,

    /// Sample only among the `k` highest-probability tokens.
    TopK {
        /// Number of top tokens to consider
        k: usize,
    },

    /// Nucleus sampling: sample from the smallest set of tokens whose
    /// cumulative probability reaches `p`.
    TopP {
        /// Cumulative probability threshold
        p: f64,
    },

    /// Softmax sampling over a temperature-scaled distribution.
    Temperature {
        /// Temperature value (higher = more random)
        temperature: f64,
    },
}
33
34/// Applies repetition penalty to logits.
35///
36/// Penalizes previously generated tokens by dividing their logits by the penalty factor.
37/// This reduces the likelihood of repetitive text generation.
38///
39/// # Arguments
40///
41/// * `logits` - Mutable logits tensor to modify, shape: `(vocab_size,)`
42/// * `generated_ids` - Previously generated token IDs to penalize
43/// * `penalty` - Penalty factor (> 1.0 = penalize, 1.0 = no penalty)
44///
45/// # Errors
46///
47/// Returns an error if tensor operations fail.
48///
49/// # Examples
50///
51/// ```no_run
52/// use metal_candle::inference::sampling::apply_repetition_penalty;
53/// use candle_core::{Device, Tensor};
54///
55/// let device = Device::Cpu;
56/// let mut logits = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)?;
57/// let generated = vec![1, 3]; // Penalize tokens 1 and 3
58///
59/// apply_repetition_penalty(&mut logits, &generated, 1.2)?;
60/// # Ok::<(), Box<dyn std::error::Error>>(())
61/// ```
62pub fn apply_repetition_penalty(
63    logits: &mut Tensor,
64    generated_ids: &[u32],
65    penalty: f32,
66) -> Result<()> {
67    if generated_ids.is_empty() || (penalty - 1.0).abs() < 1e-7 {
68        return Ok(()); // No penalty needed
69    }
70
71    let mut logits_vec = logits.to_vec1::<f32>()?;
72
73    // Apply penalty to previously generated tokens
74    for &token_id in generated_ids {
75        let idx = token_id as usize;
76        if idx < logits_vec.len() {
77            logits_vec[idx] /= penalty;
78        }
79    }
80
81    // Replace logits with penalized version
82    *logits = Tensor::new(&logits_vec[..], logits.device())?;
83
84    Ok(())
85}
86
87/// Samples a token from logits using the specified strategy.
88///
89/// # Arguments
90///
91/// * `logits` - Logits tensor, shape: `(vocab_size,)`
92/// * `strategy` - Sampling strategy to use
93/// * `generated_ids` - Previously generated token IDs (for repetition penalty)
94/// * `repetition_penalty` - Penalty factor for repeated tokens (1.0 = no penalty)
95///
96/// # Returns
97///
98/// Returns the sampled token ID.
99///
100/// # Errors
101///
102/// Returns an error if sampling fails or tensor operations fail.
103///
104/// # Examples
105///
106/// ```no_run
107/// use metal_candle::inference::sampling::{sample_token, SamplingStrategy};
108/// use candle_core::{Device, Tensor};
109///
110/// let device = Device::Cpu;
111/// let logits = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)?;
112/// let strategy = SamplingStrategy::Greedy;
113/// let generated = vec![1, 2];
114///
115/// let token = sample_token(&logits, &strategy, &generated, 1.2)?;
116/// # Ok::<(), Box<dyn std::error::Error>>(())
117/// ```
118pub fn sample_token(
119    logits: &Tensor,
120    strategy: &SamplingStrategy,
121    generated_ids: &[u32],
122    repetition_penalty: f32,
123) -> Result<u32> {
124    // Apply repetition penalty if needed
125    let mut logits = logits.clone();
126    apply_repetition_penalty(&mut logits, generated_ids, repetition_penalty)?;
127
128    match strategy {
129        SamplingStrategy::Greedy => sample_greedy(&logits),
130        SamplingStrategy::TopK { k } => sample_top_k(&logits, *k),
131        SamplingStrategy::TopP { p } => sample_top_p(&logits, *p),
132        SamplingStrategy::Temperature { temperature } => sample_temperature(&logits, *temperature),
133    }
134}
135
136/// Samples a token and returns rich metadata for streaming.
137///
138/// # Arguments
139///
140/// * `logits` - Logits tensor, shape: `(vocab_size,)`
141/// * `strategy` - Sampling strategy to use
142/// * `generated_ids` - Previously generated token IDs (for repetition penalty)
143/// * `repetition_penalty` - Penalty factor for repeated tokens (1.0 = no penalty)
144/// * `eos_token_id` - Optional EOS token ID to mark in the result
145///
146/// # Returns
147///
148/// Returns a `StreamToken` with the sampled token and metadata.
149///
150/// # Errors
151///
152/// Returns an error if sampling fails or tensor operations fail.
153///
154/// # Panics
155///
156/// This function does not panic under normal circumstances.
157///
158/// # Examples
159///
160/// ```no_run
161/// use metal_candle::inference::sampling::{sample_token_with_metadata, SamplingStrategy};
162/// use candle_core::{Device, Tensor};
163///
164/// let device = Device::Cpu;
165/// let logits = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &device)?;
166/// let strategy = SamplingStrategy::Greedy;
167/// let generated = vec![1, 2];
168///
169/// let stream_token = sample_token_with_metadata(&logits, &strategy, &generated, 1.2, Some(3))?;
170/// println!("Token {}: prob={:.2}", stream_token.token_id, stream_token.probability);
171/// # Ok::<(), Box<dyn std::error::Error>>(())
172/// ```
173pub fn sample_token_with_metadata(
174    logits: &Tensor,
175    strategy: &SamplingStrategy,
176    generated_ids: &[u32],
177    repetition_penalty: f32,
178    eos_token_id: Option<u32>,
179) -> Result<StreamToken> {
180    // Apply repetition penalty if needed
181    let mut penalized_logits = logits.clone();
182    apply_repetition_penalty(&mut penalized_logits, generated_ids, repetition_penalty)?;
183
184    // Sample the token
185    let token_id = match strategy {
186        SamplingStrategy::Greedy => sample_greedy(&penalized_logits)?,
187        SamplingStrategy::TopK { k } => sample_top_k(&penalized_logits, *k)?,
188        SamplingStrategy::TopP { p } => sample_top_p(&penalized_logits, *p)?,
189        SamplingStrategy::Temperature { temperature } => {
190            sample_temperature(&penalized_logits, *temperature)?
191        }
192    };
193
194    // Get logit and probability for the sampled token
195    let logits_vec = penalized_logits.to_vec1::<f32>()?;
196    let logit = logits_vec
197        .get(token_id as usize)
198        .copied()
199        .unwrap_or(f32::NEG_INFINITY);
200
201    // Compute softmax probability for the sampled token
202    let max_logit = logits_vec
203        .iter()
204        .copied()
205        .max_by(|a, b| a.partial_cmp(b).unwrap())
206        .unwrap_or(0.0);
207    let exp_sum: f32 = logits_vec.iter().map(|l| (l - max_logit).exp()).sum();
208    let probability = if exp_sum > 0.0 {
209        (logit - max_logit).exp() / exp_sum
210    } else {
211        0.0
212    };
213
214    // Check if this is an EOS token
215    let is_eos = eos_token_id == Some(token_id);
216
217    Ok(StreamToken {
218        token_id,
219        text: None, // Text decoding happens in generator if tokenizer available
220        logit,
221        probability,
222        is_eos,
223    })
224}
225
226/// Greedy sampling (argmax).
227fn sample_greedy(logits: &Tensor) -> Result<u32> {
228    let logits_vec = logits.to_vec1::<f32>()?;
229    let token = logits_vec
230        .iter()
231        .enumerate()
232        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
233        .map(|(idx, _)| u32::try_from(idx).unwrap_or(u32::MAX))
234        .ok_or_else(|| crate::error::InferenceError::SamplingError {
235            reason: "Empty logits".to_string(),
236        })?;
237    Ok(token)
238}
239
240/// Top-k sampling.
241fn sample_top_k(logits: &Tensor, k: usize) -> Result<u32> {
242    let logits_vec = logits.to_vec1::<f32>()?;
243
244    // Get top-k indices
245    let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
246    indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
247    indexed.truncate(k);
248
249    // Apply softmax to top-k
250    let max_logit = indexed[0].1;
251    let exp_sum: f32 = indexed.iter().map(|(_, l)| (l - max_logit).exp()).sum();
252    let probs: Vec<f64> = indexed
253        .iter()
254        .map(|(_, l)| f64::from((l - max_logit).exp() / exp_sum))
255        .collect();
256
257    // Sample from top-k
258    let mut rng = rand::thread_rng();
259    let r: f64 = rng.gen();
260    let mut cumsum = 0.0;
261    for (i, &p) in probs.iter().enumerate() {
262        cumsum += p;
263        if r <= cumsum {
264            return Ok(u32::try_from(indexed[i].0).unwrap_or(u32::MAX));
265        }
266    }
267
268    Ok(u32::try_from(indexed[0].0).unwrap_or(u32::MAX))
269}
270
271/// Top-p (nucleus) sampling.
272fn sample_top_p(logits: &Tensor, p: f64) -> Result<u32> {
273    let logits_vec = logits.to_vec1::<f32>()?;
274
275    // Sort by probability (descending)
276    let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
277    indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
278
279    // Apply softmax
280    let max_logit = indexed[0].1;
281    let exp_sum: f32 = indexed.iter().map(|(_, l)| (l - max_logit).exp()).sum();
282    let probs: Vec<(usize, f64)> = indexed
283        .iter()
284        .map(|(idx, l)| (*idx, f64::from((l - max_logit).exp() / exp_sum)))
285        .collect();
286
287    // Find nucleus (top-p)
288    let mut cumsum = 0.0;
289    let mut nucleus = Vec::new();
290    for (idx, prob) in probs {
291        nucleus.push((idx, prob));
292        cumsum += prob;
293        if cumsum >= p {
294            break;
295        }
296    }
297
298    // Sample from nucleus
299    let mut rng = rand::thread_rng();
300    let r: f64 = rng.gen();
301    let nucleus_sum: f64 = nucleus.iter().map(|(_, p)| p).sum();
302    let mut cumsum = 0.0;
303    for (idx, prob) in &nucleus {
304        cumsum += prob / nucleus_sum;
305        if r <= cumsum {
306            return Ok(u32::try_from(*idx).unwrap_or(u32::MAX));
307        }
308    }
309
310    Ok(u32::try_from(nucleus[0].0).unwrap_or(u32::MAX))
311}
312
313/// Temperature sampling.
314fn sample_temperature(logits: &Tensor, temperature: f64) -> Result<u32> {
315    let logits_vec = logits.to_vec1::<f32>()?;
316
317    // Apply temperature
318    #[allow(clippy::cast_possible_truncation)]
319    // temperature is user-controlled, truncation acceptable
320    let scaled: Vec<f32> = logits_vec.iter().map(|l| l / temperature as f32).collect();
321
322    // Apply softmax
323    let max_logit = scaled.iter().copied().fold(f32::NEG_INFINITY, f32::max);
324    let exp_sum: f32 = scaled.iter().map(|l| (l - max_logit).exp()).sum();
325    let probs: Vec<f64> = scaled
326        .iter()
327        .map(|l| f64::from((l - max_logit).exp() / exp_sum))
328        .collect();
329
330    // Sample
331    let mut rng = rand::thread_rng();
332    let r: f64 = rng.gen();
333    let mut cumsum = 0.0;
334    for (idx, &p) in probs.iter().enumerate() {
335        cumsum += p;
336        if r <= cumsum {
337            return Ok(u32::try_from(idx).unwrap_or(u32::MAX));
338        }
339    }
340
341    Ok(u32::try_from(probs.len() - 1).unwrap_or(u32::MAX))
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_greedy_sampling() {
        // Argmax should land on index 1 (logit 3.0).
        let logits = Tensor::new(&[1.0f32, 3.0, 2.0, 0.5], &Device::Cpu).unwrap();
        assert_eq!(sample_greedy(&logits).unwrap(), 1);
    }

    #[test]
    fn test_top_k_sampling() {
        // With k = 2 only the two best logits (indices 1 and 2) are eligible.
        let logits = Tensor::new(&[1.0f32, 3.0, 2.0, 0.5], &Device::Cpu).unwrap();
        let token = sample_top_k(&logits, 2).unwrap();
        assert!(matches!(token, 1 | 2));
    }

    #[test]
    fn test_sampling_strategy_default() {
        // Greedy is the documented default strategy.
        assert!(matches!(SamplingStrategy::default(), SamplingStrategy::Greedy));
    }

    #[test]
    fn test_apply_repetition_penalty() {
        let mut logits = Tensor::new(&[1.0f32, 2.0, 3.0, 4.0], &Device::Cpu).unwrap();

        // Penalize tokens 1 and 3 with a factor of 2; tokens 0 and 2 are untouched.
        apply_repetition_penalty(&mut logits, &[1, 3], 2.0).unwrap();

        let expected = [1.0f32, 1.0, 3.0, 2.0];
        let actual = logits.to_vec1::<f32>().unwrap();
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() < 1e-6);
        }
    }

    #[test]
    fn test_apply_repetition_penalty_empty() {
        // No generated tokens: logits must come back unchanged.
        let mut logits = Tensor::new(&[1.0f32, 2.0, 3.0], &Device::Cpu).unwrap();
        let before = logits.to_vec1::<f32>().unwrap();

        apply_repetition_penalty(&mut logits, &[], 2.0).unwrap();

        assert_eq!(logits.to_vec1::<f32>().unwrap(), before);
    }

    #[test]
    fn test_apply_repetition_penalty_no_penalty() {
        // A penalty of exactly 1.0 is a no-op.
        let mut logits = Tensor::new(&[1.0f32, 2.0, 3.0], &Device::Cpu).unwrap();
        let before = logits.to_vec1::<f32>().unwrap();

        apply_repetition_penalty(&mut logits, &[0, 1], 1.0).unwrap();

        assert_eq!(logits.to_vec1::<f32>().unwrap(), before);
    }

    #[test]
    fn test_sample_token_with_penalty() {
        let logits = Tensor::new(&[1.0f32, 5.0, 2.0, 0.5], &Device::Cpu).unwrap();

        // No penalty: greedy picks the highest logit (index 1).
        assert_eq!(
            sample_token(&logits, &SamplingStrategy::Greedy, &[], 1.0).unwrap(),
            1
        );

        // Heavy penalty on token 1 demotes it below token 2.
        assert_eq!(
            sample_token(&logits, &SamplingStrategy::Greedy, &[1], 10.0).unwrap(),
            2
        );
    }
}