mistralrs_core/sampler.rs

#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]

use std::{
    collections::{HashMap, HashSet},
    sync::{Arc, LazyLock, Mutex},
};

use candle_core::{DType, Device, Error, Result, Tensor, D};
use mistralrs_quant::{CumSumOp, SortOp};
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use rand::distr::{weighted::WeightedIndex, Distribution};
use rand_isaac::Isaac64Rng;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;

static DRY_SEQUENCE_BREAKERS: LazyLock<Vec<String>> =
    LazyLock::new(|| ["\n", ":", "\"", "*"].map(String::from).to_vec());

#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
/// Optional generation defaults parsed from a model's `generation_config.json`.
///
/// These defaults are descriptive and opt-in: consumers may choose to apply them,
/// partially apply them, or ignore them entirely.
pub struct ModelGenerationDefaults {
    pub do_sample: Option<bool>,
    pub temperature: Option<f64>,
    pub top_k: Option<usize>,
    pub top_p: Option<f64>,
    pub min_p: Option<f64>,
    pub repetition_penalty: Option<f32>,
    pub max_new_tokens: Option<usize>,
    pub max_length: Option<usize>,
}

impl ModelGenerationDefaults {
    pub fn is_empty(&self) -> bool {
        self.do_sample.is_none()
            && self.temperature.is_none()
            && self.top_k.is_none()
            && self.top_p.is_none()
            && self.min_p.is_none()
            && self.repetition_penalty.is_none()
            && self.max_new_tokens.is_none()
            && self.max_length.is_none()
    }
}

#[derive(Clone, Debug, Serialize, Deserialize)]
/// Stop sequences or ids.
pub enum StopTokens {
    Seqs(Vec<String>),
    Ids(Vec<u32>),
}

#[derive(Clone, Debug, Serialize, Deserialize)]
/// Parameters controlling how the next token is sampled.
pub struct SamplingParams {
    pub temperature: Option<f64>,
    pub top_k: Option<usize>,
    pub top_p: Option<f64>,
    pub min_p: Option<f64>,
    pub top_n_logprobs: usize,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
    pub repetition_penalty: Option<f32>,
    pub stop_toks: Option<StopTokens>,
    pub max_len: Option<usize>,
    pub logits_bias: Option<HashMap<u32, f32>>,
    pub n_choices: usize,
    pub dry_params: Option<DrySamplingParams>,
}

impl SamplingParams {
    /// This sets up the parameters so that there is:
    /// - No temperature, topk, topp, minp
    /// - No penalties, stop tokens, or logit bias
    /// - No maximum length
    ///
    /// Unlike [`Self::deterministic`], this does not force `top_k = 1`.
    pub fn neutral() -> Self {
        Self {
            temperature: None,
            top_k: None,
            top_p: None,
            min_p: None,
            top_n_logprobs: 0,
            frequency_penalty: None,
            presence_penalty: None,
            repetition_penalty: None,
            stop_toks: None,
            max_len: None,
            logits_bias: None,
            n_choices: 1,
            dry_params: None,
        }
    }

    /// This sets up the parameters so that there is:
    /// - No temperature, topp, minp
    /// - `top_k = 1`, so selection is deterministic (greedy)
    /// - No penalties, stop tokens, or logit bias
    /// - No maximum length
    pub fn deterministic() -> Self {
        Self {
            temperature: None,
            top_k: Some(1),
            top_p: None,
            min_p: None,
            top_n_logprobs: 0,
            frequency_penalty: None,
            presence_penalty: None,
            repetition_penalty: None,
            stop_toks: None,
            max_len: None,
            logits_bias: None,
            n_choices: 1,
            dry_params: None,
        }
    }

    /// Applies model-level generation defaults onto this request-local sampler config.
    ///
    /// This is opt-in and only updates fields that the model default explicitly provides.
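    ///
    /// A minimal sketch of the merge behavior (illustrative; mirrors the unit tests below):
    /// ```rust,ignore
    /// let mut params = SamplingParams::neutral();
    /// let defaults = ModelGenerationDefaults {
    ///     temperature: Some(0.7),
    ///     top_k: Some(40),
    ///     ..Default::default()
    /// };
    /// params.apply_model_defaults(&defaults);
    /// assert_eq!(params.temperature, Some(0.7));
    /// assert_eq!(params.top_k, Some(40));
    /// ```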
    pub fn apply_model_defaults(&mut self, defaults: &ModelGenerationDefaults) {
        if defaults.do_sample == Some(false) {
            self.temperature = None;
            self.top_k = Some(1);
            self.top_p = None;
            self.min_p = None;
        }

        if let Some(temperature) = defaults.temperature {
            self.temperature = if temperature == 0.0 {
                None
            } else {
                Some(temperature)
            };
        }
        if let Some(top_k) = defaults.top_k {
            self.top_k = if top_k == 0 { None } else { Some(top_k) };
        }
        if let Some(top_p) = defaults.top_p {
            self.top_p = Some(top_p);
        }
        if let Some(min_p) = defaults.min_p {
            self.min_p = Some(min_p);
        }
        if let Some(repetition_penalty) = defaults.repetition_penalty {
            self.repetition_penalty = Some(repetition_penalty);
        }
        if let Some(max_new_tokens) = defaults.max_new_tokens {
            self.max_len = Some(max_new_tokens);
        }
    }
}

/// Parameters for DRY (Don't Repeat Yourself) sampling to reduce repetition.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DrySamplingParams {
    pub sequence_breakers: Vec<String>,
    pub multiplier: f32,
    pub base: f32,
    pub allowed_length: usize,
}

impl DrySamplingParams {
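    /// Builds DRY params from a `multiplier`, filling any unspecified field with its default.
    ///
    /// A minimal usage sketch (illustrative):
    /// ```rust,ignore
    /// // A multiplier of 0.8 enables DRY; the other fields fall back to the defaults.
    /// let dry = DrySamplingParams::new_with_defaults(0.8, None, None, None)?;
    /// assert_eq!(dry.base, 1.75);
    /// assert_eq!(dry.allowed_length, 2);
    /// ```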
    pub fn new_with_defaults(
        multiplier: f32,
        sequence_breakers: Option<Vec<String>>,
        base: Option<f32>,
        allowed_length: Option<usize>,
    ) -> anyhow::Result<Self> {
        Ok(Self {
            base: base.unwrap_or(1.75),
            allowed_length: allowed_length.unwrap_or(2),
            sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
            multiplier,
        })
    }
}

impl Default for DrySamplingParams {
    fn default() -> Self {
        Self {
            multiplier: 0.0,
            base: 1.75,
            allowed_length: 2,
            sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
        }
    }
}

#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    pub sequence_breakers: HashSet<u32>,
    pub multiplier: f32,
    pub base: f32,
    pub allowed_length: usize,
}

impl DrySamplingParamsInner {
    pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
        Ok(Self {
            base: other.base,
            allowed_length: other.allowed_length,
            sequence_breakers: HashSet::from_iter(
                other
                    .sequence_breakers
                    .into_iter()
                    .map(|breaker| {
                        tokenizer
                            // Prefix with 'a' to get the correct encoding of the token at the end of a text.
                            //
                            // FIXME: This is a hack. See https://github.com/LostRuins/koboldcpp/pull/982
                            //        for the correct solution which covers multi-token sequence breakers
                            //        and ambiguous encodings.
                            .encode_fast(["a", &breaker].concat(), true)
                            .map_err(anyhow::Error::msg)
                            .map(|enc| {
                                let ids = enc.get_ids();
                                if !ids.is_empty() {
                                    Some(ids[ids.len() - 1])
                                } else {
                                    None
                                }
                            })
                    })
                    .collect::<anyhow::Result<Vec<_>>>()?
                    .into_iter()
                    .flatten()
                    .collect::<Vec<_>>(),
            ),
            multiplier: other.multiplier,
        })
    }
}

/// Customizable logits processor.
///
/// # Example
/// ```rust
/// use std::{sync::Arc, ops::Mul};
/// use mistralrs_core::CustomLogitsProcessor;
/// use candle_core::{Result, Tensor};
///
/// struct ThresholdLogitsProcessor;
/// impl CustomLogitsProcessor for ThresholdLogitsProcessor {
///     fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
///         // Mask is 1 for true, 0 for false.
///         let mask = logits.ge(0.5)?;
///         logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
///     }
/// }
/// let processor1: Arc<dyn CustomLogitsProcessor> = Arc::new(|logits: &Tensor, _context: &[u32]| logits * 1.23);
/// let processor2: Arc<dyn CustomLogitsProcessor> = Arc::new(ThresholdLogitsProcessor);
/// ```
pub trait CustomLogitsProcessor: Send + Sync {
    /// Apply this processor to the logits, given the sequence context (prompt and generated tokens). Returns the modified logits.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}

impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
        self(logits, context)
    }
}

/// Sampler for selecting the next token from processed logits.
#[derive(Clone)]
pub struct Sampler {
    temperature: Option<f64>,
    top_n_logprobs: usize,
    tokenizer: Option<Arc<Tokenizer>>,
    frequency_penalty: Option<f32>,
    presence_penalty: Option<f32>,
    repetition_penalty: Option<f32>,
    dry_params: Option<DrySamplingParamsInner>,
    top_k: i64,
    top_p: f64,
    min_p: f64,
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
    /// Cached Gumbel noise tensor to avoid reallocating it.
    gumbel_cache: Arc<Mutex<Option<Tensor>>>,
}

#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Top-n logprobs element
pub struct TopLogprob {
    pub token: u32,
    pub logprob: f32,
    pub bytes: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
    pub token: u32,
    pub logprob: f32,
    pub bytes: Option<String>,
    pub top_logprobs: Option<Vec<TopLogprob>>,
}

/// Comparator for descending order by probability (second element of tuple).
#[inline]
fn cmp_desc_by_prob(a: &(u32, f32), b: &(u32, f32)) -> std::cmp::Ordering {
    b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
}

/// Returns the top-k (index, probability) pairs from `probs`, sorted in descending order.
/// Uses partial sort (O(n) + O(k log k)) instead of full sort (O(n log n)).
///
/// If `k >= probs.len()`, returns all elements sorted.
/// Also zeros out elements in `probs` beyond top-k if `zero_rest` is true.
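///
/// Illustrative example: with `probs = [0.1, 0.4, 0.2, 0.3]` and `k = 2`, the result
/// is `[(1, 0.4), (3, 0.3)]`; with `zero_rest = true`, `probs` is additionally left
/// as `[0.0, 0.4, 0.0, 0.3]`.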
fn partial_sort_top_k(probs: &mut [f32], k: usize, zero_rest: bool) -> Vec<(u32, f32)> {
    let n = probs.len();
    if n == 0 || k == 0 {
        return Vec::new();
    }

    // Build (index, probability) pairs
    let mut idx_probs: Vec<(u32, f32)> = (0..n as u32).map(|i| (i, probs[i as usize])).collect();

    let k = k.min(n);

    if k < n {
        // Partial sort: partition so top k elements are in first k positions
        // select_nth_unstable_by places the k-1th largest at position k-1,
        // with all larger elements before it (unsorted) and smaller after
        idx_probs.select_nth_unstable_by(k - 1, cmp_desc_by_prob);

        if zero_rest {
            // Zero out elements beyond top-k
            for (idx, _) in idx_probs[k..].iter() {
                probs[*idx as usize] = 0.0;
            }
        }

        // Truncate to top k
        idx_probs.truncate(k);
    }

    // Sort just the top k elements (descending by probability)
    idx_probs.sort_unstable_by(cmp_desc_by_prob);

    idx_probs
}

/// Find the index of the maximum element in a slice. O(n) scan.
#[inline]
fn argmax_f32(values: &[f32]) -> u32 {
    values
        .iter()
        .enumerate()
        .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
        .map(|(i, _)| i as u32)
        .unwrap_or(0)
}

impl Sampler {
    #[allow(clippy::too_many_arguments)]
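    /// A minimal construction sketch (illustrative; mirrors the unit tests below):
    /// ```rust,ignore
    /// // No temperature (argmax path), top_n_logprobs = 10, top_k = 32,
    /// // top_p = 0.1, min_p = 0.05, and no penalties or custom processors.
    /// let sampler = Sampler::new(
    ///     None, 10, None, None, None, None, None, 32, 0.1, 0.05, vec![],
    /// )?;
    /// ```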
    pub fn new(
        temperature: Option<f64>,
        top_n_logprobs: usize,
        tokenizer: Option<Arc<Tokenizer>>,
        frequency_penalty: Option<f32>,
        presence_penalty: Option<f32>,
        repetition_penalty: Option<f32>,
        dry_params: Option<DrySamplingParams>,
        top_k: i64,
        top_p: f64,
        min_p: f64,
        logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
    ) -> anyhow::Result<Self> {
        let temperature = if temperature.is_none_or(|v| v < 1e-7) {
            None
        } else {
            temperature
        };
        let dry_params = if let Some(ref tokenizer) = tokenizer {
            dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
        } else {
            None
        };
        let dry_params = match dry_params {
            Some(fallible) => Some(fallible?),
            None => None,
        };
        Ok(Self {
            temperature,
            top_n_logprobs,
            tokenizer,
            frequency_penalty,
            presence_penalty,
            repetition_penalty,
            dry_params,
            top_k,
            top_p,
            min_p,
            logits_processors,
            gumbel_cache: Arc::new(Mutex::new(None)),
        })
    }

    fn get_top_logprobs(&self, probs: &[f32]) -> Result<Vec<TopLogprob>> {
        let k = self.top_n_logprobs.min(probs.len());
        if k == 0 {
            return Ok(Vec::new());
        }

        // Use partial sort helper (doesn't modify probs since we pass a copy)
        let mut probs_copy = probs.to_vec();
        let top_k = partial_sort_top_k(&mut probs_copy, k, false);

        // Build the result vector with log10 of probabilities and optional decoding
        let mut result = Vec::with_capacity(k);
        if let Some(tokenizer) = &self.tokenizer {
            for (token, prob) in top_k {
                let decoded = tokenizer
                    .decode(&[token], false)
                    .map_err(|e| Error::Msg(e.to_string()))?;
                result.push(TopLogprob {
                    token,
                    logprob: prob.log(10.0),
                    bytes: Some(decoded),
                });
            }
        } else {
            for (token, prob) in top_k {
                result.push(TopLogprob {
                    token,
                    logprob: prob.log(10.0),
                    bytes: None,
                });
            }
        }
        Ok(result)
    }

    fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
        let probs: Vec<f32> = logits.to_vec1()?;
        let next_token = argmax_f32(&probs);
        let logprob = probs[next_token as usize].log(10.0);

        let top_logprobs = if return_logprobs {
            Some(self.get_top_logprobs(&probs)?)
        } else {
            None
        };

        let bytes = if let Some(tokenizer) = &self.tokenizer {
            Some(
                tokenizer
                    .decode(&[next_token], false)
                    .map_err(|x| Error::Msg(x.to_string()))?,
            )
        } else {
            None
        };

        Ok(Logprobs {
            token: next_token,
            logprob,
            top_logprobs,
            bytes,
        })
    }

    #[allow(unused)]
    fn sample_fast(
        &self,
        logits: Tensor,
        context: &[u32],
        return_logprobs: bool,
        top_k: i64,
        top_p: f64,
        min_p: f64,
    ) -> Result<Logprobs> {
        let mut probs = logits.to_dtype(DType::F32)?;

        for processor in &self.logits_processors {
            probs = processor.apply(&probs, context)?;
        }

        let context = Tensor::new(context, logits.device())?;
        let mut counts = logits.zeros_like()?;
        counts = counts.scatter_add(
            &context,
            &context.ones_like()?.to_dtype(counts.dtype())?,
            D::Minus1,
        )?;

        let presence = counts
            .gt(0.)?
            .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;

        match self.frequency_penalty {
            Some(freq_penalty) if freq_penalty != 0. => {
                probs = (probs - (freq_penalty as f64 * counts)?)?;
            }
            _ => (),
        }

        match self.presence_penalty {
            Some(pres_penalty) if pres_penalty != 0. => {
                probs = (probs - (pres_penalty as f64 * &presence)?)?;
            }
            _ => (),
        }

        match self.repetition_penalty {
            Some(rep_penalty) if rep_penalty != 1. => {
                let pos_mask = probs.gt(0.)?;
                let scaled_pos = (&probs / (rep_penalty as f64))?;
                let scaled_neg = (&probs * (rep_penalty as f64))?;
                let modified = pos_mask.where_cond(&scaled_pos, &scaled_neg)?;

                let pres_mask = presence.gt(0.)?;
                probs = pres_mask.where_cond(&modified, &probs)?;
            }
            _ => (),
        }

        probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;

        // Top-K
        if top_k > 0 {
            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
            let topk_values = sorted_values.narrow(
                D::Minus1,
                sorted_values.dim(D::Minus1)? - top_k as usize,
                top_k as usize,
            )?;

            // select the kth largest value as threshold
            let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
            let mask_topk = probs.broadcast_ge(&threshold)?;
            probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
        }

        // Top-P (nucleus)
        if top_p > 0.0 && top_p < 1.0 {
            let sorted_probs = probs.fast_sort_asc(D::Minus1)?;

            let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;

            let mask_topp = cumsum.le(top_p)?;

            let masked_sorted =
                mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;

            let threshold = masked_sorted.max(D::Minus1)?;
            let threshold = threshold.unsqueeze(D::Minus1)?;
            let mask_full = probs.broadcast_ge(&threshold)?;
            probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
        }

        // Min-P
        if min_p > 0.0 && min_p < 1.0 {
            let max_vals = probs.max(D::Minus1)?;
            let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
            let mask_minp = probs.broadcast_gt(&threshold_min)?;
            probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
        }

        // Sample using the Gumbel-max trick fully on-device.
        let log_probs = probs.log()?;
        // Generate cached Gumbel noise (-log(-log(u))) once.
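        // Gumbel-max trick: if g_i = -ln(-ln(u_i)) with u_i ~ Uniform(0, 1), then
        // argmax_i (ln p_i + g_i) is distributed as Categorical(p), so sampling
        // reduces to an argmax that stays entirely on-device.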
        let gumbel = {
            let mut guard = self.gumbel_cache.lock().unwrap();
            if guard.is_none() {
                let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
                let noise = uniform
                    .clamp(1e-20, 1.0)?
                    .log()? // ln(u)
                    .neg()? // -ln(u)
                    .log()? // ln(-ln(u))
                    .neg()?; // -ln(-ln(u))
                *guard = Some(noise);
            }
            guard.as_ref().unwrap().clone()
        };

        let gumbel_logits = (&log_probs + &gumbel)?;
        let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;

        // Extract the top‑n log‑probs if the caller asked for them.
        let (top_logprobs, logprob) = if return_logprobs {
            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
            let dim = sorted_values.dim(D::Minus1)?;
            // Take `top_n_logprobs` entries (clamped to the vocab size), not `top_k`:
            // the two are independent settings, and `top_k` may be <= 0 (disabled).
            let k = self.top_n_logprobs.min(dim);

            let topk_values = sorted_values
                .narrow(D::Minus1, dim - k, k)?
                .to_vec1::<f32>()?;

            let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
            let topk_idxs = sorted_idxs
                .narrow(D::Minus1, dim - k, k)?
                .to_vec1::<u32>()?;

            let mut result = Vec::with_capacity(k);
            if let Some(tokenizer) = &self.tokenizer {
                for (prob, token) in topk_values.iter().zip(topk_idxs) {
                    let decoded = tokenizer
                        .decode(&[token], false)
                        .map_err(|e| Error::Msg(e.to_string()))?;
                    result.push(TopLogprob {
                        token,
                        logprob: prob.log(10.0),
                        bytes: Some(decoded),
                    });
                }
            } else {
                for (prob, token) in topk_values.iter().zip(topk_idxs) {
                    result.push(TopLogprob {
                        token,
                        logprob: prob.log(10.0),
                        bytes: None,
                    });
                }
            }

            let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);

            (Some(result), logprob)
        } else {
            (None, 1.)
        };

        let bytes = if let Some(tokenizer) = &self.tokenizer {
            Some(
                tokenizer
                    .decode(&[next_token], false)
                    .map_err(|x| Error::Msg(x.to_string()))?,
            )
        } else {
            None
        };

        Ok(Logprobs {
            token: next_token,
            logprob,
            top_logprobs,
            bytes,
        })
    }

    fn sample_speculative_top_kp_min_p(
        &self,
        logits: Tensor,
        return_logprobs: bool,
        top_k: i64,
        top_p: f32,
        min_p: f32,
    ) -> Result<Logprobs> {
        let mut probs: Vec<f32> = logits.to_vec1()?;

        // Determine how many elements we need for partial sort
        let k = if top_k > 0 {
            top_k as usize
        } else {
            probs.len()
        };

        // Get sorted top-k indices with partial sort, zeroing out rest
        let idx_probs = partial_sort_top_k(&mut probs, k, true);

        // TOP P
        // top-p sampling (or "nucleus sampling") samples from the smallest set of
        // tokens whose cumulative probability exceeds top_p. This way we never
        // sample tokens with very low probabilities, which are more likely to go
        // "off the rails".

        // Clamp smaller probabilities to zero.
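        // e.g. with sorted probs [0.5, 0.3, 0.2] and top_p = 0.7: after keeping 0.5
        // and 0.3 the running sum is 0.8 >= 0.7, so the trailing 0.2 is zeroed.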
        let mut cumsum = 0.;
        for (index, prob) in &idx_probs {
            if cumsum >= top_p {
                probs[*index as usize] = 0.0;
            } else {
                cumsum += prob;
            }
        }

        // Get max_p from first sorted element
        let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);

        // MIN P
        // min-p sampling samples from the tokens whose probability is greater than
        // (max prob of any token in the dist) * min_p

        // Clamp smaller probabilities to zero.
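        // e.g. with max_p = 0.5 and min_p = 0.1, the threshold is 0.05, so every
        // token with probability at or below 0.05 is zeroed.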
        let min_p_threshold = max_p * min_p;
        for (index, prob) in &idx_probs {
            if min_p_threshold >= *prob {
                probs[*index as usize] = 0.0;
            }
        }

        // Find argmax directly on the Vec (O(n) scan, no Tensor creation)
        let next_token = argmax_f32(&probs);
        let logprob = probs[next_token as usize].log(10.0);

        let top_logprobs = if return_logprobs {
            Some(self.get_top_logprobs(&probs)?)
        } else {
            None
        };

        let bytes = if let Some(tokenizer) = &self.tokenizer {
            Some(
                tokenizer
                    .decode(&[next_token], false)
                    .map_err(|x| Error::Msg(x.to_string()))?,
            )
        } else {
            None
        };

        Ok(Logprobs {
            token: next_token,
            logprob,
            top_logprobs,
            bytes,
        })
    }

    fn sample_multinomial(
        &self,
        probs: &[f32],
        return_logprobs: bool,
        rng: Arc<Mutex<Isaac64Rng>>,
    ) -> Result<Logprobs> {
        let distr = match WeightedIndex::new(probs) {
            Ok(distr) => distr,
            Err(e) => {
                if let Some((idx, prob)) = probs
                    .iter()
                    .enumerate()
                    .find(|(_, prob)| !prob.is_finite() || **prob < 0.0)
                {
                    return Err(Error::Msg(format!(
                        "Invalid sampling probability at index {idx}: {prob}. The model likely produced NaN/Inf logits."
                    )));
                }

                let positive_weight_sum: f64 = probs
                    .iter()
                    .copied()
                    .filter(|prob| prob.is_finite() && *prob > 0.0)
                    .map(f64::from)
                    .sum();

                if positive_weight_sum == 0.0 {
                    return Err(Error::Msg(
                        "All sampling probabilities are zero after filtering (top-k/top-p/min-p)."
                            .to_string(),
                    ));
                }

                return Err(Error::Msg(format!(
                    "Failed to construct multinomial sampler: {e}"
                )));
            }
        };

        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
        let next_token = distr.sample(&mut mut_ref_rng); // "Find the first item which has a weight *higher* than the chosen weight."
        let logprob = probs[next_token].log(10.0);

        let top_logprobs = if return_logprobs {
            Some(self.get_top_logprobs(probs)?)
        } else {
            None
        };

        let bytes = if let Some(tokenizer) = &self.tokenizer {
            Some(
                tokenizer
                    .decode(&[next_token.try_into().unwrap()], false)
                    .map_err(|x| Error::Msg(x.to_string()))?,
            )
        } else {
            None
        };

        Ok(Logprobs {
            token: next_token as u32,
            logprob,
            top_logprobs,
            bytes,
        })
    }

    #[allow(clippy::too_many_arguments)]
    fn sample_top_kp_min_p(
        &self,
        probs: &mut [f32],
        top_k: i64,
        top_p: f32,
        min_p: f32,
        return_logprobs: bool,
        rng: Arc<Mutex<Isaac64Rng>>,
    ) -> Result<Logprobs> {
        // Determine how many elements we need for partial sort
        let k = if top_k > 0 {
            top_k as usize
        } else {
            probs.len()
        };

        // Get sorted top-k indices with partial sort, zeroing out rest
        let idx_probs = partial_sort_top_k(probs, k, true);

        if top_p <= 0.0 || top_p >= 1.0 {
            return self.sample_multinomial(probs, return_logprobs, rng);
        }

        // TOP P

        // top-p sampling (or "nucleus sampling") samples from the smallest set of
        // tokens whose cumulative probability exceeds top_p. This way we never
        // sample tokens with very low probabilities, which are more likely to go
        // "off the rails".

        // Clamp smaller probabilities to zero.
        let mut cumsum = 0.;
        for (index, prob) in &idx_probs {
            if cumsum >= top_p {
                probs[*index as usize] = 0.0;
            } else {
                cumsum += prob;
            }
        }

        if min_p <= 0.0 || min_p >= 1.0 {
            return self.sample_multinomial(probs, return_logprobs, rng);
        }

        // Get max_p from first sorted element
        let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);

        // MIN P

        // min-p sampling samples from the tokens whose probability is greater than
        // (max prob of any token in the dist) * min_p

        // Clamp smaller probabilities to zero.
        let min_p_threshold = max_p * min_p;
        for (index, prob) in &idx_probs {
            if min_p_threshold >= *prob {
                probs[*index as usize] = 0.0;
            }
        }

        // Sample with clamped probabilities.
        self.sample_multinomial(probs, return_logprobs, rng)
    }

    fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
        if context.is_empty() {
            candle_core::bail!("Penalty context is empty, this should not happen.");
        }

        // Dry penalty
        self.apply_dry_penalty(&mut logits, context)?;

        // Frequency, presence, repetition penalty
        self.apply_freq_pres_rep_penalty(&mut logits, context)?;

        let vocab_size = logits.len();
        Tensor::from_vec(logits, vocab_size, &Device::Cpu)
    }

    fn apply_freq_pres_rep_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
        if self.frequency_penalty.is_some()
            || self.presence_penalty.is_some()
            || self.repetition_penalty.is_some()
        {
            let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
            let presence_penalty = self.presence_penalty.unwrap_or(0.);
            let repetition_penalty = self.repetition_penalty.unwrap_or(1.);

            //mu[j] -> mu[j] - c[j] * alpha_frequency - float(c[j] > 0) * alpha_presence

            let mut counts = vec![0.0f32; logits.len()];
            for ctx in context.iter() {
                // Llama 3.2 uses a hack that can yield token ids past the vocab size,
                // which would index out of bounds here... we wouldn't want a weight on them anyway
                if *ctx as usize >= logits.len() {
                    continue;
                }
                counts[*ctx as usize] += 1.0;
            }

            for (token_id, logit) in logits.iter_mut().enumerate() {
                let count = counts[token_id];
                *logit = *logit
                    - count * frequency_penalty
                    - if count > 0.0 { 1. } else { 0. } * presence_penalty;

                if repetition_penalty != 1.0 && count > 0.0 {
                    if *logit > 0.0 {
                        *logit /= repetition_penalty;
                    } else {
                        *logit *= repetition_penalty;
                    }
                }
            }
        }
        Ok(())
    }

    /// Threshold for using parallel iteration in dry penalty.
    /// Below this, sequential is faster due to parallel overhead.
    const DRY_PENALTY_PAR_THRESHOLD: usize = 1024;

    fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
        if let Some(ref params) = self.dry_params {
            if params.multiplier == 0. {
                return Ok(());
            }

            let last_token = *context.last().unwrap();

            // Use parallel iteration only for large contexts
            let match_indices: Vec<usize> = if context.len() > Self::DRY_PENALTY_PAR_THRESHOLD {
                context
                    .par_iter()
                    .enumerate()
                    .take(context.len() - 1)
                    .filter(|(_i, x)| last_token == **x)
                    .map(|(i, _)| i)
                    .collect()
            } else {
                context
                    .iter()
                    .enumerate()
                    .take(context.len() - 1)
                    .filter(|(_i, x)| last_token == **x)
                    .map(|(i, _)| i)
                    .collect()
            };

            let mut match_lengths = HashMap::new();

            for i in match_indices {
                let next_token = context[i + 1];

                if params.sequence_breakers.contains(&next_token) {
                    continue;
                }

                let mut match_length = 1;

                // Limit match length to avoid quadratic runtime and potential DoS with adversarial inputs.
                while match_length < 50 {
                    if match_length > i {
                        // Start of input
                        break;
                    }

                    let j = i - match_length;

                    let prev_tok = context[context.len() - (match_length + 1)];
                    if context[j] != prev_tok {
                        // Start of match reached
                        break;
                    }

                    if params.sequence_breakers.contains(&prev_tok) {
                        // Seq breaking tok reached
                        break;
                    }

                    match_length += 1;
                }

                #[allow(clippy::map_entry)]
                if match_lengths.contains_key(&next_token) {
                    match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
                } else {
                    match_lengths.insert(next_token, match_length);
                }
            }

            // Actually apply penalties
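            // The penalty grows exponentially with the overshoot past the allowed
            // repeat length: multiplier * base^(match_len - allowed_length).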
            for (tok, match_len) in match_lengths {
                if match_len >= params.allowed_length {
                    // Llama 3.2 uses a hack that can yield token ids past the vocab size,
                    // which would index out of bounds here... we wouldn't want a weight on them anyway
                    if tok as usize >= logits.len() {
                        continue;
                    }
                    let penalty = params.multiplier
                        * params.base.powf((match_len - params.allowed_length) as f32);
                    logits[tok as usize] -= penalty;
                }
            }
        }
        Ok(())
    }

    #[allow(unused)]
    /// Sample a token from the provided logits.
    ///
    /// If the temperature is `None`, argmax sampling is used. Otherwise, the configured
    /// top-k/top-p/min-p sampling is used. With `top-p` sampling, if the `top-p` value is
    /// `<= 0.0` or `>= 1.0`, plain multinomial sampling is used.
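    ///
    /// A minimal call sketch (illustrative; mirrors the unit tests below):
    /// ```rust,ignore
    /// let logprobs = sampler.sample(
    ///     logits,      // 1-D vocab-sized tensor
    ///     &context,    // prompt + generated token ids
    ///     false,       // return_logprobs
    ///     rng.clone(), // Arc<Mutex<Isaac64Rng>>
    ///     false,       // sample_speculative
    ///     false,       // multiple_sequences
    /// )?;
    /// ```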
    pub fn sample(
        &self,
        logits: Tensor,
        context: &[u32],
        return_logprobs: bool,
        rng: Arc<Mutex<Isaac64Rng>>,
        sample_speculative: bool,
        multiple_sequences: bool,
    ) -> Result<Logprobs> {
        // if cfg!(feature = "metal") && !multiple_sequences {
        //     return self.sample_fast(
        //         logits,
        //         context,
        //         return_logprobs,
        //         self.top_k,
        //         self.top_p,
        //         self.min_p,
        //     );
        // }

        let logits = logits.to_vec1()?;
        let mut logits = self.apply_penalties(logits, context)?;
        for processor in &self.logits_processors {
            logits = processor.apply(&logits, context)?;
        }
        let next_token = if sample_speculative {
            match self.temperature {
                None => self.sample_speculative_top_kp_min_p(
                    logits,
                    return_logprobs,
                    self.top_k,
                    self.top_p as f32,
                    self.min_p as f32,
                )?,
                Some(temperature) => {
                    let logits = (&logits / temperature)?;
                    let probs = candle_nn::ops::softmax_last_dim(&logits)?;

                    self.sample_speculative_top_kp_min_p(
                        probs,
                        return_logprobs,
                        self.top_k,
                        self.top_p as f32,
                        self.min_p as f32,
                    )?
                }
            }
        } else {
            match self.temperature {
                None => self.sample_argmax(logits, return_logprobs)?,
                Some(temperature) => {
                    let logits = (&logits / temperature)?;
                    let probs = candle_nn::ops::softmax_last_dim(&logits)?;
                    let mut probs: Vec<f32> = probs.to_vec1()?;

                    self.sample_top_kp_min_p(
                        &mut probs,
                        self.top_k,
                        self.top_p as f32,
                        self.min_p as f32,
                        return_logprobs,
                        rng,
                    )?
                }
            }
        };
        Ok(next_token)
    }
}

#[cfg(test)]
mod tests {
    use super::{ModelGenerationDefaults, SamplingParams};

    #[test]
    fn test_argmax() {
        use super::Sampler;
        use candle_core::{Device, Tensor};
        use rand::SeedableRng;
        use rand_isaac::Isaac64Rng;
        use std::sync::Arc;
        use std::sync::Mutex;

        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        let res = sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                false,
                false,
            )
            .unwrap();
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_gumbel_speculative() {
        use super::Sampler;
        use candle_core::{Device, Tensor};
        use rand::SeedableRng;
        use rand_isaac::Isaac64Rng;
        use std::sync::Arc;
        use std::sync::Mutex;

        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        let res = sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                true,
                false,
            )
            .unwrap();
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_apply_model_defaults() {
        let mut params = SamplingParams::neutral();
        params.apply_model_defaults(&ModelGenerationDefaults {
            do_sample: Some(true),
            temperature: Some(1.0),
            top_k: Some(32),
            top_p: Some(0.9),
            min_p: Some(0.05),
            repetition_penalty: Some(1.1),
            max_new_tokens: Some(256),
            max_length: None,
        });

        assert_eq!(params.temperature, Some(1.0));
        assert_eq!(params.top_k, Some(32));
        assert_eq!(params.top_p, Some(0.9));
        assert_eq!(params.min_p, Some(0.05));
        assert_eq!(params.repetition_penalty, Some(1.1));
        assert_eq!(params.max_len, Some(256));
    }

    #[test]
    fn test_apply_model_defaults_disables_sampling_when_requested() {
        let mut params = SamplingParams {
            temperature: Some(0.7),
            top_k: Some(40),
            top_p: Some(0.9),
            min_p: Some(0.1),
            ..SamplingParams::neutral()
        };
        params.apply_model_defaults(&ModelGenerationDefaults {
            do_sample: Some(false),
            ..Default::default()
        });

        assert_eq!(params.temperature, None);
        assert_eq!(params.top_k, Some(1));
        assert_eq!(params.top_p, None);
        assert_eq!(params.min_p, None);
    }
}