Skip to main content

hanzo_engine/
sampler.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4    collections::{HashMap, HashSet},
5    sync::{Arc, LazyLock, Mutex},
6};
7
8use hanzo_ml::{Device, Error, Result, Tensor};
9#[cfg(feature = "pyo3_macros")]
10use pyo3::pyclass;
11
12use rand::distr::{weighted::WeightedIndex, Distribution};
13use rand_isaac::Isaac64Rng;
14use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
15use serde::{Deserialize, Serialize};
16use tokenizers::Tokenizer;
17
/// Default DRY sequence breakers (newline, colon, double quote, asterisk),
/// built once on first access.
static DRY_SEQUENCE_BREAKERS: LazyLock<Vec<String>> = LazyLock::new(|| {
    ["\n", ":", "\"", "*"]
        .iter()
        .map(|breaker| String::from(*breaker))
        .collect()
});
20
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
/// Optional generation defaults parsed from a model's `generation_config.json`.
///
/// These defaults are descriptive and opt-in: consumers may choose to apply them,
/// partially apply them, or ignore them entirely.
///
/// Every field is optional; [`ModelGenerationDefaults::is_empty`] reports whether
/// all of them are `None`.
pub struct ModelGenerationDefaults {
    /// Sampling switch. `Some(false)` makes `SamplingParams::apply_model_defaults`
    /// clear temperature/top-p/min-p and force `top_k = 1`.
    pub do_sample: Option<bool>,
    /// Default temperature. A value of `0.0` is translated to "no temperature"
    /// by `SamplingParams::apply_model_defaults`.
    pub temperature: Option<f64>,
    /// Default top-k. A value of `0` is translated to "no top-k"
    /// by `SamplingParams::apply_model_defaults`.
    pub top_k: Option<usize>,
    /// Default top-p, copied as-is when applied.
    pub top_p: Option<f64>,
    /// Default min-p, copied as-is when applied.
    pub min_p: Option<f64>,
    /// Default repetition penalty, copied as-is when applied.
    pub repetition_penalty: Option<f32>,
    /// Default generation budget; applied to `SamplingParams::max_len`.
    pub max_new_tokens: Option<usize>,
    /// Default maximum length.
    // NOTE(review): `SamplingParams::apply_model_defaults` does not read this field.
    pub max_length: Option<usize>,
}
36
37impl ModelGenerationDefaults {
38    pub fn is_empty(&self) -> bool {
39        self.do_sample.is_none()
40            && self.temperature.is_none()
41            && self.top_k.is_none()
42            && self.top_p.is_none()
43            && self.min_p.is_none()
44            && self.repetition_penalty.is_none()
45            && self.max_new_tokens.is_none()
46            && self.max_length.is_none()
47    }
48}
49
#[derive(Clone, Debug, Serialize, Deserialize)]
/// Stop sequences or ids.
///
/// A request specifies its stop condition either as text or as token ids.
pub enum StopTokens {
    /// Stop conditions given as strings.
    Seqs(Vec<String>),
    /// Stop conditions given as token ids.
    Ids(Vec<u32>),
}
56
#[derive(Clone, Debug, Serialize, Deserialize)]
/// Sampling params are used to control sampling.
///
/// Optional fields use `None` for "not set". See [`SamplingParams::neutral`] and
/// [`SamplingParams::deterministic`] for ready-made baselines.
pub struct SamplingParams {
    /// Sampling temperature, if any.
    pub temperature: Option<f64>,
    /// Top-k limit, if any.
    pub top_k: Option<usize>,
    /// Top-p (nucleus) threshold, if any.
    pub top_p: Option<f64>,
    /// Min-p threshold, if any.
    pub min_p: Option<f64>,
    /// Number of top log-probabilities requested (`0` in both baselines).
    pub top_n_logprobs: usize,
    /// Frequency penalty, if any.
    pub frequency_penalty: Option<f32>,
    /// Presence penalty, if any.
    pub presence_penalty: Option<f32>,
    /// Repetition penalty, if any.
    pub repetition_penalty: Option<f32>,
    /// Stop condition, as strings or token ids.
    pub stop_toks: Option<StopTokens>,
    /// Maximum length, if any. `apply_model_defaults` fills this from
    /// `ModelGenerationDefaults::max_new_tokens`.
    pub max_len: Option<usize>,
    /// Per-token logit bias keyed by token id.
    pub logits_bias: Option<HashMap<u32, f32>>,
    /// Number of choices (`1` in both baselines).
    // NOTE(review): presumably the number of completions generated per request — confirm.
    pub n_choices: usize,
    /// DRY (Don't Repeat Yourself) sampling parameters, if enabled.
    pub dry_params: Option<DrySamplingParams>,
}
74
75impl SamplingParams {
76    /// This sets up the parameters so that there is:
77    /// - No temperature, topk, topp, minp
78    /// - No penalties, stop tokens, or logit bias
79    /// - No maximum length
80    ///
81    /// Unlike [`Self::deterministic`], this does not force `top_k = 1`.
82    pub fn neutral() -> Self {
83        Self {
84            temperature: None,
85            top_k: None,
86            top_p: None,
87            min_p: None,
88            top_n_logprobs: 0,
89            frequency_penalty: None,
90            presence_penalty: None,
91            repetition_penalty: None,
92            stop_toks: None,
93            max_len: None,
94            logits_bias: None,
95            n_choices: 1,
96            dry_params: None,
97        }
98    }
99
100    /// This sets up the parameters so that there is:
101    /// - No temperature, topk, topp, minp
102    /// - No penalties, stop tokens, or logit bias
103    /// - No maximum length
104    pub fn deterministic() -> Self {
105        Self {
106            temperature: None,
107            top_k: Some(1),
108            top_p: None,
109            min_p: None,
110            top_n_logprobs: 0,
111            frequency_penalty: None,
112            presence_penalty: None,
113            repetition_penalty: None,
114            stop_toks: None,
115            max_len: None,
116            logits_bias: None,
117            n_choices: 1,
118            dry_params: None,
119        }
120    }
121
122    /// Applies model-level generation defaults onto this request-local sampler config.
123    ///
124    /// This is opt-in and only updates fields that the model default explicitly provides.
125    pub fn apply_model_defaults(&mut self, defaults: &ModelGenerationDefaults) {
126        if defaults.do_sample == Some(false) {
127            self.temperature = None;
128            self.top_k = Some(1);
129            self.top_p = None;
130            self.min_p = None;
131        }
132
133        if let Some(temperature) = defaults.temperature {
134            self.temperature = if temperature == 0.0 {
135                None
136            } else {
137                Some(temperature)
138            };
139        }
140        if let Some(top_k) = defaults.top_k {
141            self.top_k = if top_k == 0 { None } else { Some(top_k) };
142        }
143        if let Some(top_p) = defaults.top_p {
144            self.top_p = Some(top_p);
145        }
146        if let Some(min_p) = defaults.min_p {
147            self.min_p = Some(min_p);
148        }
149        if let Some(repetition_penalty) = defaults.repetition_penalty {
150            self.repetition_penalty = Some(repetition_penalty);
151        }
152        if let Some(max_new_tokens) = defaults.max_new_tokens {
153            self.max_len = Some(max_new_tokens);
154        }
155    }
156}
157
/// Parameters for DRY (Don't Repeat Yourself) sampling to reduce repetition.
///
/// This is the user-facing form: sequence breakers are plain strings. They are
/// converted to token ids when a `Sampler` is built with a tokenizer.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DrySamplingParams {
    /// Strings acting as sequence breakers. Defaults to `"\n"`, `":"`, `"\""`, `"*"`.
    pub sequence_breakers: Vec<String>,
    /// Penalty multiplier. The `Default` value is `0.0`.
    pub multiplier: f32,
    /// Penalty base. Defaults to `1.75`.
    pub base: f32,
    /// Allowed length. Defaults to `2`.
    // NOTE(review): presumably the longest repeated run that is not penalised — confirm.
    pub allowed_length: usize,
}
166
167impl DrySamplingParams {
168    pub fn new_with_defaults(
169        multiplier: f32,
170        sequence_breakers: Option<Vec<String>>,
171        base: Option<f32>,
172        allowed_length: Option<usize>,
173    ) -> anyhow::Result<Self> {
174        Ok(Self {
175            base: base.unwrap_or(1.75),
176            allowed_length: allowed_length.unwrap_or(2),
177            sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
178            multiplier,
179        })
180    }
181}
182
183impl Default for DrySamplingParams {
184    fn default() -> Self {
185        Self {
186            multiplier: 0.0,
187            base: 1.75,
188            allowed_length: 2,
189            sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
190        }
191    }
192}
193
/// Tokenised form of [`DrySamplingParams`] stored inside a `Sampler`.
///
/// Identical to the public parameters except that the sequence breakers are
/// token ids rather than strings.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids acting as sequence breakers (one id at most per breaker string).
    pub sequence_breakers: HashSet<u32>,
    /// Penalty multiplier, copied from [`DrySamplingParams`].
    pub multiplier: f32,
    /// Penalty base, copied from [`DrySamplingParams`].
    pub base: f32,
    /// Allowed length, copied from [`DrySamplingParams`].
    pub allowed_length: usize,
}
201
202impl DrySamplingParamsInner {
203    pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
204        Ok(Self {
205            base: other.base,
206            allowed_length: other.allowed_length,
207            sequence_breakers: HashSet::from_iter(
208                other
209                    .sequence_breakers
210                    .into_iter()
211                    .map(|breaker| {
212                        tokenizer
213                            // Prefix with 'a' to get the correct encoding of the token at the end of a text.
214                            //
215                            // FIXME: This is a hack. See https://github.com/LostRuins/koboldcpp/pull/982
216                            //        for the correct solution which covers multi-token sequence breakers
217                            //        and ambiguous encodings.
218                            .encode_fast(["a", &breaker].concat(), true)
219                            .map_err(anyhow::Error::msg)
220                            .map(|enc| {
221                                let ids = enc.get_ids();
222                                if !ids.is_empty() {
223                                    Some(ids[ids.len() - 1])
224                                } else {
225                                    None
226                                }
227                            })
228                    })
229                    .collect::<anyhow::Result<Vec<_>>>()?
230                    .into_iter()
231                    .flatten()
232                    .collect::<Vec<_>>(),
233            ),
234            multiplier: other.multiplier,
235        })
236    }
237}
238
/// Customizable logits processor.
///
/// Implementors must be `Send + Sync`. Any closure or function with the signature
/// `Fn(&Tensor, &[u32]) -> Result<Tensor>` that is `Send + Sync` implements this
/// trait automatically.
///
/// # Example
/// ```rust
/// use std::{sync::Arc, ops::Mul};
/// use hanzo_engine::CustomLogitsProcessor;
/// use hanzo_ml::{Result, Tensor};
///
/// struct ThresholdLogitsProcessor;
/// impl CustomLogitsProcessor for ThresholdLogitsProcessor {
///     fn apply(&self, logits: &Tensor, _context: &[u32]) -> Result<Tensor> {
///         // Mask is 1 for true, 0 for false.
///         let mask = logits.ge(0.5)?;
///         logits.broadcast_mul(&mask.to_dtype(logits.dtype())?)
///     }
/// }
/// let processor1: Arc<dyn CustomLogitsProcessor> = Arc::new(|logits: &Tensor, _context: &[u32]| logits * 1.23);
/// let processor2: Arc<dyn CustomLogitsProcessor> = Arc::new(ThresholdLogitsProcessor);
/// ```
pub trait CustomLogitsProcessor: Send + Sync {
    /// Takes the logits and the sequence context (prompt and generated tokens) and
    /// returns the modified logits.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}
262
263impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
264    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
265        self(logits, context)
266    }
267}
268
/// Sampler for sampling.
///
/// Holds the per-sequence sampling configuration. Construct it with `Sampler::new`.
#[derive(Clone)]
pub struct Sampler {
    /// Sampling temperature. `None` means argmax; `Sampler::new` stores values
    /// below `1e-7` as `None`.
    temperature: Option<f64>,
    /// Upper bound on the number of entries returned as top log-probabilities.
    top_n_logprobs: usize,
    /// Tokenizer used to decode sampled tokens into text and to tokenise DRY
    /// sequence breakers. Without it, decoded text is `None` and DRY is disabled.
    tokenizer: Option<Arc<Tokenizer>>,
    /// Frequency penalty (treated as `0.0` by the device penalty paths when `None`).
    frequency_penalty: Option<f32>,
    /// Presence penalty (treated as `0.0` by the device penalty paths when `None`).
    presence_penalty: Option<f32>,
    /// Repetition penalty (treated as `1.0` by the device penalty paths when `None`).
    repetition_penalty: Option<f32>,
    /// Tokenised DRY parameters; `None` when DRY was not requested or no tokenizer
    /// was supplied.
    dry_params: Option<DrySamplingParamsInner>,
    /// Top-k limit. Values `<= 0` mean the full vocabulary is considered.
    top_k: i64,
    /// Top-p threshold; the filtering paths only apply it when strictly between 0 and 1.
    top_p: f64,
    /// Min-p threshold; the device paths only apply it when strictly between 0 and 1.
    min_p: f64,
    /// User-supplied logits processors.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
}
284
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// Top-n logprobs element
pub struct TopLogprob {
    /// Token id.
    pub token: u32,
    /// Log-probability of the token. The sampler computes it as the base-10
    /// logarithm of the probability.
    pub logprob: f32,
    /// Decoded text of the token; `None` when the sampler has no tokenizer.
    pub bytes: Option<String>,
}
294
/// Result of sampling one token: the chosen token plus optional alternatives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
    /// Id of the sampled token.
    pub token: u32,
    /// Log-probability of the sampled token. The sampler computes it as the
    /// base-10 logarithm of the probability.
    pub logprob: f32,
    /// Decoded text of the sampled token; `None` when the sampler has no tokenizer.
    pub bytes: Option<String>,
    /// Most likely alternatives, present only when log-probabilities were requested.
    pub top_logprobs: Option<Vec<TopLogprob>>,
}
302
/// Comparator for descending order by probability (second element of tuple).
///
/// Uses the IEEE 754 total order (`f32::total_cmp`) so the comparison is a genuine
/// total order even when NaN is present. A `partial_cmp` fallback to `Equal` is not
/// transitive with NaN, and `sort_unstable_by` / `select_nth_unstable_by` may panic
/// on such comparators. Under the total order a positive NaN ranks ahead of every
/// finite value, so invalid probabilities surface at the front of a descending sort.
#[inline]
fn cmp_desc_by_prob(a: &(u32, f32), b: &(u32, f32)) -> std::cmp::Ordering {
    b.1.total_cmp(&a.1)
}
308
309/// Returns the top-k (index, probability) pairs from `probs`, sorted in descending order.
310/// Uses partial sort (O(n) + O(k log k)) instead of full sort (O(n log n)).
311///
312/// If `k >= probs.len()`, returns all elements sorted.
313/// Also zeros out elements in `probs` beyond top-k if `zero_rest` is true.
314fn partial_sort_top_k(probs: &mut [f32], k: usize, zero_rest: bool) -> Vec<(u32, f32)> {
315    let n = probs.len();
316    if n == 0 || k == 0 {
317        return Vec::new();
318    }
319
320    // Build (index, probability) pairs
321    let mut idx_probs: Vec<(u32, f32)> = (0..n as u32).map(|i| (i, probs[i as usize])).collect();
322
323    let k = k.min(n);
324
325    if k < n {
326        // Partial sort: partition so top k elements are in first k positions
327        // select_nth_unstable_by places the k-1th largest at position k-1,
328        // with all larger elements before it (unsorted) and smaller after
329        idx_probs.select_nth_unstable_by(k - 1, cmp_desc_by_prob);
330
331        if zero_rest {
332            // Zero out elements beyond top-k
333            for (idx, _) in idx_probs[k..].iter() {
334                probs[*idx as usize] = 0.0;
335            }
336        }
337
338        // Truncate to top k
339        idx_probs.truncate(k);
340    }
341
342    // Sort just the top k elements (descending by probability)
343    idx_probs.sort_unstable_by(cmp_desc_by_prob);
344
345    idx_probs
346}
347
/// Returns the index of the largest element of `values` with a single linear scan.
///
/// Returns `0` for an empty slice. When several elements tie for the maximum
/// (or are incomparable, e.g. NaN), the later element wins.
#[inline]
fn argmax_f32(values: &[f32]) -> u32 {
    let Some(&first) = values.first() else {
        return 0;
    };

    let mut best_idx = 0usize;
    let mut best = first;
    for (idx, &candidate) in values.iter().enumerate().skip(1) {
        // The current best survives only when it is strictly greater.
        match best.partial_cmp(&candidate) {
            Some(std::cmp::Ordering::Greater) => {}
            _ => {
                best_idx = idx;
                best = candidate;
            }
        }
    }
    best_idx as u32
}
358
359impl Sampler {
360    #[allow(clippy::too_many_arguments)]
361    pub fn new(
362        temperature: Option<f64>,
363        top_n_logprobs: usize,
364        tokenizer: Option<Arc<Tokenizer>>,
365        frequency_penalty: Option<f32>,
366        presence_penalty: Option<f32>,
367        repetition_penalty: Option<f32>,
368        dry_params: Option<DrySamplingParams>,
369        top_k: i64,
370        top_p: f64,
371        min_p: f64,
372        logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
373    ) -> anyhow::Result<Self> {
374        let temperature = if temperature.is_none_or(|v| v < 1e-7) {
375            None
376        } else {
377            temperature
378        };
379        let dry_params = if let Some(ref tokenizer) = tokenizer {
380            dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
381        } else {
382            None
383        };
384        let dry_params = match dry_params {
385            Some(fallible) => Some(fallible?),
386            None => None,
387        };
388        Ok(Self {
389            temperature,
390            top_n_logprobs,
391            tokenizer,
392            frequency_penalty,
393            presence_penalty,
394            repetition_penalty,
395            dry_params,
396            top_k,
397            top_p,
398            min_p,
399            logits_processors,
400        })
401    }
402
403    pub fn is_argmax(&self) -> bool {
404        self.temperature.is_none()
405    }
406
407    fn get_top_logprobs(&self, probs: &[f32]) -> Result<Vec<TopLogprob>> {
408        let k = self.top_n_logprobs.min(probs.len());
409        if k == 0 {
410            return Ok(Vec::new());
411        }
412
413        // Use partial sort helper (doesn't modify probs since we pass a copy)
414        let mut probs_copy = probs.to_vec();
415        let top_k = partial_sort_top_k(&mut probs_copy, k, false);
416
417        // Build the result vector with log10 of probabilities and optional decoding
418        let mut result = Vec::with_capacity(k);
419        if let Some(tokenizer) = &self.tokenizer {
420            for (token, prob) in top_k {
421                let decoded = tokenizer
422                    .decode(&[token], false)
423                    .map_err(|e| Error::Msg(e.to_string()))?;
424                result.push(TopLogprob {
425                    token,
426                    logprob: prob.log(10.0),
427                    bytes: Some(decoded),
428                });
429            }
430        } else {
431            for (token, prob) in top_k {
432                result.push(TopLogprob {
433                    token,
434                    logprob: prob.log(10.0),
435                    bytes: None,
436                });
437            }
438        }
439        Ok(result)
440    }
441
442    fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
443        let probs: Vec<f32> = logits.to_vec1()?;
444        let next_token = argmax_f32(&probs);
445        let logprob = probs[next_token as usize].log(10.0);
446
447        let top_logprobs = if return_logprobs {
448            Some(self.get_top_logprobs(&probs)?)
449        } else {
450            None
451        };
452
453        let bytes = if let Some(tokenizer) = &self.tokenizer {
454            Some(
455                tokenizer
456                    .decode(&[next_token], false)
457                    .map_err(|x| Error::Msg(x.to_string()))?,
458            )
459        } else {
460            None
461        };
462
463        Ok(Logprobs {
464            token: next_token,
465            logprob,
466            top_logprobs,
467            bytes,
468        })
469    }
470
471    fn sample_speculative_top_kp_min_p(
472        &self,
473        logits: Tensor,
474        return_logprobs: bool,
475        top_k: i64,
476        top_p: f32,
477        min_p: f32,
478    ) -> Result<Logprobs> {
479        let mut probs: Vec<f32> = logits.to_vec1()?;
480
481        // Determine how many elements we need for partial sort
482        let k = if top_k > 0 {
483            top_k as usize
484        } else {
485            probs.len()
486        };
487
488        // Get sorted top-k indices with partial sort, zeroing out rest
489        let idx_probs = partial_sort_top_k(&mut probs, k, true);
490
491        // TOP P
492        // top-p sampling (or "nucleus sampling") samples from the smallest set of
493        // tokens that exceed probability top_p. This way we never sample tokens that
494        // have very low probabilities and are less likely to go "off the rails".
495
496        // Clamp smaller probabilities to zero.
497        let mut cumsum = 0.;
498        for (index, prob) in &idx_probs {
499            if cumsum >= top_p {
500                probs[*index as usize] = 0.0;
501            } else {
502                cumsum += prob;
503            }
504        }
505
506        // Get max_p from first sorted element
507        let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
508
509        // MIN P
510        // min-p sampling samples from the tokens whose prob are greater than
511        // (max prob of token in dist) * min_p
512
513        // Clamp smaller probabilities to zero.
514        let min_p_threshold = max_p * min_p;
515        for (index, prob) in &idx_probs {
516            if min_p_threshold >= *prob {
517                probs[*index as usize] = 0.0;
518            }
519        }
520
521        // Find argmax directly on the Vec (O(n) scan, no Tensor creation)
522        let next_token = argmax_f32(&probs);
523        let logprob = probs[next_token as usize].log(10.0);
524
525        let top_logprobs = if return_logprobs {
526            Some(self.get_top_logprobs(&probs)?)
527        } else {
528            None
529        };
530
531        let bytes = if let Some(tokenizer) = &self.tokenizer {
532            Some(
533                tokenizer
534                    .decode(&[next_token], false)
535                    .map_err(|x| Error::Msg(x.to_string()))?,
536            )
537        } else {
538            None
539        };
540
541        Ok(Logprobs {
542            token: next_token,
543            logprob,
544            top_logprobs,
545            bytes,
546        })
547    }
548
549    fn sample_multinomial(
550        &self,
551        probs: &[f32],
552        return_logprobs: bool,
553        rng: Arc<Mutex<Isaac64Rng>>,
554    ) -> Result<Logprobs> {
555        let distr = match WeightedIndex::new(probs) {
556            Ok(distr) => distr,
557            Err(e) => {
558                if let Some((idx, prob)) = probs
559                    .iter()
560                    .enumerate()
561                    .find(|(_, prob)| !prob.is_finite() || **prob < 0.0)
562                {
563                    return Err(Error::Msg(format!(
564                        "Invalid sampling probability at index {idx}: {prob}. The model likely produced NaN/Inf logits."
565                    )));
566                }
567
568                let positive_weight_sum: f64 = probs
569                    .iter()
570                    .copied()
571                    .filter(|prob| prob.is_finite() && *prob > 0.0)
572                    .map(f64::from)
573                    .sum();
574
575                if positive_weight_sum == 0.0 {
576                    return Err(Error::Msg(
577                        "All sampling probabilities are zero after filtering (top-k/top-p/min-p)."
578                            .to_string(),
579                    ));
580                }
581
582                return Err(Error::Msg(format!(
583                    "Failed to construct multinomial sampler: {e}"
584                )));
585            }
586        };
587
588        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
589        let next_token = distr.sample(&mut mut_ref_rng); // "Find the first item which has a weight *higher* than the chosen weight."
590        let logprob = probs[next_token].log(10.0);
591
592        let top_logprobs = if return_logprobs {
593            Some(self.get_top_logprobs(probs)?)
594        } else {
595            None
596        };
597
598        let bytes = if let Some(tokenizer) = &self.tokenizer {
599            Some(
600                tokenizer
601                    .decode(&[next_token.try_into().unwrap()], false)
602                    .map_err(|x| Error::Msg(x.to_string()))?,
603            )
604        } else {
605            None
606        };
607
608        Ok(Logprobs {
609            token: next_token as u32,
610            logprob,
611            top_logprobs,
612            bytes,
613        })
614    }
615
616    #[cfg(any(feature = "cuda", feature = "metal"))]
617    fn can_sample_topk_on_device(
618        &self,
619        return_logprobs: bool,
620        sample_speculative: bool,
621        multiple_sequences: bool,
622    ) -> bool {
623        const MAX_DEVICE_TOP_K: i64 = 128;
624
625        !return_logprobs
626            && !sample_speculative
627            && !multiple_sequences
628            && self.temperature.is_some()
629            && self.top_k > 0
630            && self.top_k <= MAX_DEVICE_TOP_K
631            && self.logits_processors.is_empty()
632            && self
633                .dry_params
634                .as_ref()
635                .is_none_or(|params| params.multiplier.abs() <= f32::EPSILON)
636    }
637
638    #[cfg(feature = "cuda")]
639    fn apply_device_sparse_penalties_if_needed(
640        &self,
641        logits: Tensor,
642        context: &[u32],
643    ) -> Result<Tensor> {
644        let frequency_penalty = self.frequency_penalty.unwrap_or(0.0);
645        let presence_penalty = self.presence_penalty.unwrap_or(0.0);
646        let repetition_penalty = self.repetition_penalty.unwrap_or(1.0);
647        let needs_penalty = frequency_penalty.abs() > f32::EPSILON
648            || presence_penalty.abs() > f32::EPSILON
649            || (repetition_penalty - 1.0).abs() > f32::EPSILON;
650
651        if !needs_penalty {
652            return Ok(logits);
653        }
654        if context.is_empty() {
655            hanzo_ml::bail!("Penalty context is empty, this should not happen.");
656        }
657
658        let vocab_size = logits.elem_count();
659        let mut counts = HashMap::<u32, f32>::with_capacity(context.len().min(vocab_size));
660        for &token_id in context {
661            if token_id as usize >= vocab_size {
662                continue;
663            }
664            *counts.entry(token_id).or_insert(0.0) += 1.0;
665        }
666
667        if counts.is_empty() {
668            return Ok(logits);
669        }
670
671        let n_tokens = counts.len();
672        let mut token_ids = Vec::with_capacity(n_tokens);
673        let mut token_counts = Vec::with_capacity(n_tokens);
674        for (token_id, count) in counts {
675            token_ids.push(token_id);
676            token_counts.push(count);
677        }
678
679        let device = logits.device();
680        let token_ids = Tensor::from_vec(token_ids, n_tokens, device)?;
681        let token_counts = Tensor::from_vec(token_counts, n_tokens, device)?;
682        crate::ops::cuda_apply_sparse_penalties_f32(
683            &logits,
684            &token_ids,
685            &token_counts,
686            frequency_penalty,
687            presence_penalty,
688            repetition_penalty,
689        )
690    }
691
692    #[cfg(feature = "cuda")]
693    fn sample_topk_on_device(
694        &self,
695        logits: Tensor,
696        temperature: f64,
697        rng: Arc<Mutex<Isaac64Rng>>,
698    ) -> Result<Logprobs> {
699        let topk =
700            crate::ops::cuda_topk_logits_f32_packed(&logits, self.top_k as usize, temperature)?;
701        let packed = topk.packed.to_vec1::<f32>()?;
702        let k = topk.k;
703        if packed.len() != 2 * k + 2 {
704            hanzo_ml::bail!(
705                "invalid CUDA top-k packed output length {}, expected {}",
706                packed.len(),
707                2 * k + 2
708            );
709        }
710        let top_values = &packed[..k];
711        let top_indices = packed[k..2 * k]
712            .iter()
713            .map(|idx| *idx as u32)
714            .collect::<Vec<_>>();
715        let softmax_info = &packed[2 * k..2 * k + 2];
716
717        let denom = softmax_info[0];
718        let global_max = softmax_info[1];
719        if denom <= 0.0 || !denom.is_finite() || !global_max.is_finite() {
720            hanzo_ml::bail!("invalid CUDA top-k softmax normalizer");
721        }
722
723        let inv_temperature = (1.0 / temperature) as f32;
724        let mut probs = top_values
725            .iter()
726            .map(|value| ((*value * inv_temperature - global_max).exp()) / denom)
727            .collect::<Vec<_>>();
728
729        if self.top_p > 0.0 && self.top_p < 1.0 {
730            let mut cumsum = 0.0f32;
731            for prob in &mut probs {
732                if cumsum >= self.top_p as f32 {
733                    *prob = 0.0;
734                } else {
735                    cumsum += *prob;
736                }
737            }
738
739            if self.min_p > 0.0 && self.min_p < 1.0 {
740                let max_p = probs.first().copied().unwrap_or(0.0);
741                let min_p_threshold = max_p * self.min_p as f32;
742                for prob in &mut probs {
743                    if min_p_threshold >= *prob {
744                        *prob = 0.0;
745                    }
746                }
747            }
748        }
749
750        let distr = match WeightedIndex::new(&probs) {
751            Ok(distr) => distr,
752            Err(e) => {
753                let positive_weight_sum: f64 = probs
754                    .iter()
755                    .copied()
756                    .filter(|prob| prob.is_finite() && *prob > 0.0)
757                    .map(f64::from)
758                    .sum();
759                if positive_weight_sum == 0.0 {
760                    return Err(Error::Msg(
761                        "All sampling probabilities are zero after CUDA top-k filtering."
762                            .to_string(),
763                    ));
764                }
765
766                return Err(Error::Msg(format!(
767                    "Failed to construct CUDA top-k multinomial sampler: {e}"
768                )));
769            }
770        };
771
772        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
773        let selected = distr.sample(&mut mut_ref_rng);
774        let next_token = top_indices[selected];
775        let logprob = probs[selected].log(10.0);
776
777        let bytes = if let Some(tokenizer) = &self.tokenizer {
778            Some(
779                tokenizer
780                    .decode(&[next_token], false)
781                    .map_err(|x| Error::Msg(x.to_string()))?,
782            )
783        } else {
784            None
785        };
786
787        Ok(Logprobs {
788            token: next_token,
789            logprob,
790            top_logprobs: None,
791            bytes,
792        })
793    }
794
795    #[cfg(feature = "metal")]
796    fn apply_device_sparse_penalties_if_needed_metal(
797        &self,
798        logits: Tensor,
799        context: &[u32],
800    ) -> Result<Tensor> {
801        let frequency_penalty = self.frequency_penalty.unwrap_or(0.0);
802        let presence_penalty = self.presence_penalty.unwrap_or(0.0);
803        let repetition_penalty = self.repetition_penalty.unwrap_or(1.0);
804        let needs_penalty = frequency_penalty.abs() > f32::EPSILON
805            || presence_penalty.abs() > f32::EPSILON
806            || (repetition_penalty - 1.0).abs() > f32::EPSILON;
807        if !needs_penalty || context.is_empty() {
808            return Ok(logits);
809        }
810        let vocab_size = logits.elem_count();
811        let mut counts = HashMap::<u32, f32>::with_capacity(context.len().min(vocab_size));
812        for &tid in context {
813            if (tid as usize) >= vocab_size {
814                continue;
815            }
816            *counts.entry(tid).or_insert(0.0) += 1.0;
817        }
818        if counts.is_empty() {
819            return Ok(logits);
820        }
821        let n_tokens = counts.len();
822        let mut token_ids = Vec::with_capacity(n_tokens);
823        let mut token_counts = Vec::with_capacity(n_tokens);
824        for (tid, c) in counts {
825            token_ids.push(tid);
826            token_counts.push(c);
827        }
828        let device = logits.device();
829        let token_ids = Tensor::from_vec(token_ids, n_tokens, device)?;
830        let token_counts = Tensor::from_vec(token_counts, n_tokens, device)?;
831        crate::ops::metal_apply_sparse_penalties(
832            &logits,
833            &token_ids,
834            &token_counts,
835            frequency_penalty,
836            presence_penalty,
837            repetition_penalty,
838        )
839    }
840
841    #[cfg(feature = "metal")]
842    fn sample_topk_on_device_metal(
843        &self,
844        logits: Tensor,
845        temperature: f64,
846        rng: Arc<Mutex<Isaac64Rng>>,
847    ) -> Result<Logprobs> {
848        let topk = crate::ops::metal_topk_logits_packed(&logits, self.top_k as usize, temperature)?;
849        let packed = topk.packed.to_vec1::<f32>()?;
850        let k = topk.k;
851        if packed.len() != 2 * k + 2 {
852            hanzo_ml::bail!(
853                "invalid Metal top-k packed output length {}, expected {}",
854                packed.len(),
855                2 * k + 2
856            );
857        }
858        let top_values = &packed[..k];
859        let top_indices = packed[k..2 * k]
860            .iter()
861            .map(|idx| *idx as u32)
862            .collect::<Vec<_>>();
863        let softmax_info = &packed[2 * k..2 * k + 2];
864        let denom = softmax_info[0];
865        let global_max = softmax_info[1];
866        if denom <= 0.0 || !denom.is_finite() || !global_max.is_finite() {
867            hanzo_ml::bail!("invalid Metal top-k softmax normalizer");
868        }
869
870        let inv_temperature = (1.0 / temperature) as f32;
871        let mut probs = top_values
872            .iter()
873            .map(|value| ((*value * inv_temperature - global_max).exp()) / denom)
874            .collect::<Vec<_>>();
875
876        if self.top_p > 0.0 && self.top_p < 1.0 {
877            let mut cumsum = 0.0f32;
878            for prob in &mut probs {
879                if cumsum >= self.top_p as f32 {
880                    *prob = 0.0;
881                } else {
882                    cumsum += *prob;
883                }
884            }
885            if self.min_p > 0.0 && self.min_p < 1.0 {
886                let max_p = probs.first().copied().unwrap_or(0.0);
887                let min_p_threshold = max_p * self.min_p as f32;
888                for prob in &mut probs {
889                    if min_p_threshold >= *prob {
890                        *prob = 0.0;
891                    }
892                }
893            }
894        }
895
896        let distr = match WeightedIndex::new(&probs) {
897            Ok(distr) => distr,
898            Err(e) => {
899                let positive_weight_sum: f64 = probs
900                    .iter()
901                    .copied()
902                    .filter(|prob| prob.is_finite() && *prob > 0.0)
903                    .map(f64::from)
904                    .sum();
905                if positive_weight_sum == 0.0 {
906                    return Err(Error::Msg(
907                        "All sampling probabilities are zero after Metal top-k filtering."
908                            .to_string(),
909                    ));
910                }
911                return Err(Error::Msg(format!(
912                    "Failed to construct Metal top-k multinomial sampler: {e}"
913                )));
914            }
915        };
916
917        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
918        let selected = distr.sample(&mut mut_ref_rng);
919        let next_token = top_indices[selected];
920        let logprob = probs[selected].log(10.0);
921        let bytes = if let Some(tokenizer) = &self.tokenizer {
922            Some(
923                tokenizer
924                    .decode(&[next_token], false)
925                    .map_err(|x| Error::Msg(x.to_string()))?,
926            )
927        } else {
928            None
929        };
930        Ok(Logprobs {
931            token: next_token,
932            logprob,
933            top_logprobs: None,
934            bytes,
935        })
936    }
937
938    fn filter_top_kp_min_p(&self, probs: &mut [f32]) {
939        let k = if self.top_k > 0 {
940            self.top_k as usize
941        } else {
942            probs.len()
943        };
944
945        let idx_probs = partial_sort_top_k(probs, k, true);
946
947        if self.top_p <= 0.0 || self.top_p >= 1.0 {
948            return;
949        }
950
951        let mut cumsum = 0.0f32;
952        for (index, prob) in &idx_probs {
953            if cumsum >= self.top_p as f32 {
954                probs[*index as usize] = 0.0;
955            } else {
956                cumsum += prob;
957            }
958        }
959
960        if self.min_p <= 0.0 || self.min_p >= 1.0 {
961            return;
962        }
963
964        let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
965        let min_p_threshold = max_p * self.min_p as f32;
966        for (index, prob) in &idx_probs {
967            if min_p_threshold >= *prob {
968                probs[*index as usize] = 0.0;
969            }
970        }
971    }
972
973    fn normalize_probs(probs: &mut [f32]) -> Result<()> {
974        let sum: f32 = probs
975            .iter()
976            .copied()
977            .filter(|prob| prob.is_finite() && *prob > 0.0)
978            .sum();
979        if sum <= 0.0 {
980            hanzo_ml::bail!("all probabilities are zero in speculative sampling");
981        }
982        for prob in probs.iter_mut() {
983            if prob.is_finite() && *prob > 0.0 {
984                *prob /= sum;
985            } else {
986                *prob = 0.0;
987            }
988        }
989        Ok(())
990    }
991
    /// Crate-internal entry point that forwards `logits` and `context` unchanged to
    /// `speculative_probs` and returns its result.
    ///
    /// The body is identical to `speculative_candidate_probs`; presumably the separate name
    /// lets call sites state that they are computing the target side of speculative decoding
    /// (TODO confirm against callers).
    pub(crate) fn speculative_target_probs(
        &self,
        logits: Tensor,
        context: &[u32],
    ) -> Result<Vec<f32>> {
        self.speculative_probs(logits, context)
    }
999
    /// Crate-internal entry point that forwards `logits` and `context` unchanged to
    /// `speculative_probs` and returns its result.
    ///
    /// The body is identical to `speculative_target_probs`; presumably the separate name lets
    /// call sites state that they are computing the candidate (draft) side of speculative
    /// decoding (TODO confirm against callers).
    pub(crate) fn speculative_candidate_probs(
        &self,
        logits: Tensor,
        context: &[u32],
    ) -> Result<Vec<f32>> {
        self.speculative_probs(logits, context)
    }
1007
1008    fn speculative_probs(&self, logits: Tensor, context: &[u32]) -> Result<Vec<f32>> {
1009        let logits = logits.to_vec1()?;
1010        let mut logits = self.apply_penalties(logits, context)?;
1011        for processor in &self.logits_processors {
1012            logits = processor.apply(&logits, context)?;
1013        }
1014
1015        let mut probs = match self.temperature {
1016            None => {
1017                let logits = logits.to_vec1::<f32>()?;
1018                let mut probs = vec![0.0; logits.len()];
1019                probs[argmax_f32(&logits) as usize] = 1.0;
1020                probs
1021            }
1022            Some(temperature) => {
1023                let logits = (&logits / temperature)?;
1024                hanzo_nn::ops::softmax_last_dim(&logits)?.to_vec1::<f32>()?
1025            }
1026        };
1027        self.filter_top_kp_min_p(&mut probs);
1028        Self::normalize_probs(&mut probs)?;
1029        Ok(probs)
1030    }
1031
1032    pub(crate) fn logprobs_from_probs(
1033        &self,
1034        token: u32,
1035        probs: &[f32],
1036        return_logprobs: bool,
1037    ) -> Result<Logprobs> {
1038        let prob = probs.get(token as usize).copied().unwrap_or(0.0);
1039        let logprob = if prob > 0.0 {
1040            prob.log(10.0)
1041        } else {
1042            f32::NEG_INFINITY
1043        };
1044        let top_logprobs = if return_logprobs {
1045            Some(self.get_top_logprobs(probs)?)
1046        } else {
1047            None
1048        };
1049        let bytes = if let Some(tokenizer) = &self.tokenizer {
1050            Some(
1051                tokenizer
1052                    .decode(&[token], false)
1053                    .map_err(|x| Error::Msg(x.to_string()))?,
1054            )
1055        } else {
1056            None
1057        };
1058        Ok(Logprobs {
1059            token,
1060            logprob,
1061            top_logprobs,
1062            bytes,
1063        })
1064    }
1065
    /// Crate-internal entry point that draws a token from `probs` by handing `probs`,
    /// `return_logprobs` and `rng` unchanged to `sample_multinomial`.
    ///
    /// No filtering or normalisation happens in this wrapper; whatever `sample_multinomial`
    /// returns (including its error) is returned as is.
    pub(crate) fn sample_from_probs(
        &self,
        probs: &[f32],
        return_logprobs: bool,
        rng: Arc<Mutex<Isaac64Rng>>,
    ) -> Result<Logprobs> {
        self.sample_multinomial(probs, return_logprobs, rng)
    }
1074
1075    #[allow(clippy::too_many_arguments)]
1076    fn sample_top_kp_min_p(
1077        &self,
1078        probs: &mut [f32],
1079        top_k: i64,
1080        top_p: f32,
1081        min_p: f32,
1082        return_logprobs: bool,
1083        rng: Arc<Mutex<Isaac64Rng>>,
1084    ) -> Result<Logprobs> {
1085        // Determine how many elements we need for partial sort
1086        let k = if top_k > 0 {
1087            top_k as usize
1088        } else {
1089            probs.len()
1090        };
1091
1092        // Get sorted top-k indices with partial sort, zeroing out rest
1093        let idx_probs = partial_sort_top_k(probs, k, true);
1094
1095        if top_p <= 0.0 || top_p >= 1.0 {
1096            return self.sample_multinomial(probs, return_logprobs, rng);
1097        }
1098
1099        // TOP P
1100
1101        // top-p sampling (or "nucleus sampling") samples from the smallest set of
1102        // tokens that exceed probability top_p. This way we never sample tokens that
1103        // have very low probabilities and are less likely to go "off the rails".
1104
1105        // Clamp smaller probabilities to zero.
1106        let mut cumsum = 0.;
1107        for (index, prob) in &idx_probs {
1108            if cumsum >= top_p {
1109                probs[*index as usize] = 0.0;
1110            } else {
1111                cumsum += prob;
1112            }
1113        }
1114
1115        if min_p <= 0.0 || min_p >= 1.0 {
1116            return self.sample_multinomial(probs, return_logprobs, rng);
1117        }
1118
1119        // Get max_p from first sorted element
1120        let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
1121
1122        // MIN P
1123
1124        // min-p sampling samples from the tokens whose prob are greater than
1125        // (max prob of token in dist) * min_p
1126
1127        // Clamp smaller probabilities to zero.
1128        let min_p_threshold = max_p * min_p;
1129        for (index, prob) in &idx_probs {
1130            if min_p_threshold >= *prob {
1131                probs[*index as usize] = 0.0;
1132            }
1133        }
1134
1135        // Sample with clamped probabilities.
1136        self.sample_multinomial(probs, return_logprobs, rng)
1137    }
1138
1139    fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
1140        if context.is_empty() {
1141            hanzo_ml::bail!("Penalty context is empty, this should not happen.");
1142        }
1143
1144        // Dry penalty
1145        self.apply_dry_penalty(&mut logits, context)?;
1146
1147        // Frequency, presence, repetition penalty
1148        self.apply_freq_pres_rep_penalty(&mut logits, context)?;
1149
1150        let vocab_size = logits.len();
1151        Tensor::from_vec(logits, vocab_size, &Device::Cpu)
1152    }
1153
1154    fn apply_freq_pres_rep_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
1155        if self.frequency_penalty.is_some()
1156            || self.presence_penalty.is_some()
1157            || self.repetition_penalty.is_some()
1158        {
1159            let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
1160            let presence_penalty = self.presence_penalty.unwrap_or(0.);
1161            let repetition_penalty = self.repetition_penalty.unwrap_or(1.);
1162
1163            //mu[j] -> mu[j] - c[j] * alpha_frequency - float(c[j] > 0) * alpha_presence
1164
1165            let mut counts = vec![0.0f32; logits.len()];
1166            for ctx in context.iter() {
1167                // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
1168                if *ctx as usize >= logits.len() {
1169                    continue;
1170                }
1171                counts[*ctx as usize] += 1.0;
1172            }
1173
1174            for (token_id, logit) in logits.iter_mut().enumerate() {
1175                let count = counts[token_id];
1176                *logit = *logit
1177                    - count * frequency_penalty
1178                    - if count > 0.0 { 1. } else { 0. } * presence_penalty;
1179
1180                if repetition_penalty != 1.0 && count > 0.0 {
1181                    if *logit > 0.0 {
1182                        *logit /= repetition_penalty;
1183                    } else {
1184                        *logit *= repetition_penalty;
1185                    }
1186                }
1187            }
1188        }
1189        Ok(())
1190    }
1191
1192    /// Threshold for using parallel iteration in dry penalty.
1193    /// Below this, sequential is faster due to parallel overhead.
1194    const DRY_PENALTY_PAR_THRESHOLD: usize = 1024;
1195
1196    fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
1197        if let Some(ref params) = self.dry_params {
1198            if params.multiplier == 0. {
1199                return Ok(());
1200            }
1201
1202            let last_token = *context.last().unwrap();
1203
1204            // Use parallel iteration only for large contexts
1205            let match_indices: Vec<usize> = if context.len() > Self::DRY_PENALTY_PAR_THRESHOLD {
1206                context
1207                    .par_iter()
1208                    .enumerate()
1209                    .take(context.len() - 1)
1210                    .filter(|(_i, x)| last_token == **x)
1211                    .map(|(i, _)| i)
1212                    .collect()
1213            } else {
1214                context
1215                    .iter()
1216                    .enumerate()
1217                    .take(context.len() - 1)
1218                    .filter(|(_i, x)| last_token == **x)
1219                    .map(|(i, _)| i)
1220                    .collect()
1221            };
1222
1223            let mut match_lengths = HashMap::new();
1224
1225            for i in match_indices {
1226                let next_token = context[i + 1];
1227
1228                if params.sequence_breakers.contains(&next_token) {
1229                    continue;
1230                }
1231
1232                let mut match_length = 1;
1233
1234                // Limit match length to avoid quadratic runtime and potential DoS with adversarial inputs.
1235                while match_length < 50 {
1236                    if match_length > i {
1237                        // Start of input
1238                        break;
1239                    }
1240
1241                    let j = i - match_length;
1242
1243                    let prev_tok = context[context.len() - (match_length + 1)];
1244                    if context[j] != prev_tok {
1245                        // Start of match reached
1246                        break;
1247                    }
1248
1249                    if params.sequence_breakers.contains(&prev_tok) {
1250                        // Seq breaking tok reached
1251                        break;
1252                    }
1253
1254                    match_length += 1;
1255                }
1256
1257                #[allow(clippy::map_entry)]
1258                if match_lengths.contains_key(&next_token) {
1259                    match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
1260                } else {
1261                    match_lengths.insert(next_token, match_length);
1262                }
1263            }
1264
1265            // Actually apply penalties
1266            for (tok, match_len) in match_lengths {
1267                if match_len >= params.allowed_length {
1268                    // Llama 3.2 uses a hack triggering this error... we wouldn't want a weight on it anyway
1269                    if tok as usize >= logits.len() {
1270                        continue;
1271                    }
1272                    let penalty = params.multiplier
1273                        * params.base.powf((match_len - params.allowed_length) as f32);
1274                    logits[tok as usize] -= penalty;
1275                }
1276            }
1277        }
1278        Ok(())
1279    }
1280
1281    #[allow(unused)]
1282    /// Sample the provided tokens.
1283    ///
1284    /// If the temperature is `None`, argmax sampling is used. Otherwise, the selected sampling is used.
1285    /// With `top-p` sampling, if the `top-p` value is `<= 0.0` or `>= 1.0`, multinomial sampling is used.
1286    pub fn sample(
1287        &self,
1288        logits: Tensor,
1289        context: &[u32],
1290        return_logprobs: bool,
1291        rng: Arc<Mutex<Isaac64Rng>>,
1292        sample_speculative: bool,
1293        multiple_sequences: bool,
1294    ) -> Result<Logprobs> {
1295        #[cfg(feature = "cuda")]
1296        if logits.device().is_cuda()
1297            && self.can_sample_topk_on_device(
1298                return_logprobs,
1299                sample_speculative,
1300                multiple_sequences,
1301            )
1302        {
1303            if let Some(temperature) = self.temperature {
1304                let logits = self.apply_device_sparse_penalties_if_needed(logits, context)?;
1305                return self.sample_topk_on_device(logits, temperature, rng);
1306            }
1307        }
1308
1309        #[cfg(feature = "metal")]
1310        if logits.device().is_metal()
1311            && self.can_sample_topk_on_device(
1312                return_logprobs,
1313                sample_speculative,
1314                multiple_sequences,
1315            )
1316        {
1317            if let Some(temperature) = self.temperature {
1318                let logits = self.apply_device_sparse_penalties_if_needed_metal(logits, context)?;
1319                return self.sample_topk_on_device_metal(logits, temperature, rng);
1320            }
1321        }
1322
1323        let logits = logits.to_vec1()?;
1324        let mut logits = self.apply_penalties(logits, context)?;
1325        for processor in &self.logits_processors {
1326            logits = processor.apply(&logits, context)?;
1327        }
1328        let next_token = if sample_speculative {
1329            match self.temperature {
1330                None => self.sample_speculative_top_kp_min_p(
1331                    logits,
1332                    return_logprobs,
1333                    self.top_k,
1334                    self.top_p as f32,
1335                    self.min_p as f32,
1336                )?,
1337                Some(temperature) => {
1338                    let logits = (&logits / temperature)?;
1339                    let probs = hanzo_nn::ops::softmax_last_dim(&logits)?;
1340
1341                    self.sample_speculative_top_kp_min_p(
1342                        probs,
1343                        return_logprobs,
1344                        self.top_k,
1345                        self.top_p as f32,
1346                        self.min_p as f32,
1347                    )?
1348                }
1349            }
1350        } else {
1351            match self.temperature {
1352                None => self.sample_argmax(logits, return_logprobs)?,
1353                Some(temperature) => {
1354                    let logits = (&logits / temperature)?;
1355                    let probs = hanzo_nn::ops::softmax_last_dim(&logits)?;
1356                    let mut probs: Vec<f32> = probs.to_vec1()?;
1357
1358                    self.sample_top_kp_min_p(
1359                        &mut probs,
1360                        self.top_k,
1361                        self.top_p as f32,
1362                        self.min_p as f32,
1363                        return_logprobs,
1364                        rng,
1365                    )?
1366                }
1367            }
1368        };
1369        Ok(next_token)
1370    }
1371}
1372
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use hanzo_ml::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;

    use super::{ModelGenerationDefaults, Sampler, SamplingParams};

    /// Number of logits (and context tokens) used by the temperature-less tests.
    const VOCAB: u32 = 1024;

    /// Samples once from a sampler built without a temperature over the logits
    /// `0.0, 1.0, …, VOCAB - 1` and checks that the last (largest) token is chosen.
    ///
    /// `sample_speculative` selects between the regular and the speculative code path of
    /// `Sampler::sample`; both are expected to give the same answer.
    fn assert_picks_largest_logit(sample_speculative: bool) {
        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();
        let logits = Tensor::arange(0f32, VOCAB as f32, &Device::Cpu).unwrap();
        let context: Vec<u32> = (0..VOCAB).collect();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));

        let res = sampler
            .sample(logits, &context, false, rng, sample_speculative, false)
            .unwrap();

        assert_eq!(res.token, VOCAB - 1);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, f64::from(VOCAB - 1).log(10.) as f32)
    }

    #[test]
    fn test_argmax() {
        assert_picks_largest_logit(false);
    }

    #[test]
    fn test_gumbel_speculative() {
        assert_picks_largest_logit(true);
    }

    /// Target and candidate distributions must agree and must reflect the sampling filters:
    /// with these settings only the largest of the three logits keeps any probability.
    #[test]
    fn test_speculative_candidate_probs_use_sampling_filters() {
        let sampler = Sampler::new(
            Some(1.0),
            10,
            None,
            None,
            None,
            None,
            None,
            1,
            1.0,
            0.0,
            vec![],
        )
        .unwrap();
        let logits = Tensor::from_vec(vec![0.0f32, 1.0, 2.0], 3, &Device::Cpu).unwrap();
        let context = [0u32];

        let target_probs = sampler
            .speculative_target_probs(logits.clone(), &context)
            .unwrap();
        let candidate_probs = sampler
            .speculative_candidate_probs(logits, &context)
            .unwrap();

        assert_eq!(candidate_probs, target_probs);
        assert_eq!(candidate_probs, vec![0.0, 0.0, 1.0]);
    }

    /// Every populated model default is copied into neutral sampling parameters.
    #[test]
    fn test_apply_model_defaults() {
        let defaults = ModelGenerationDefaults {
            do_sample: Some(true),
            temperature: Some(1.0),
            top_k: Some(32),
            top_p: Some(0.9),
            min_p: Some(0.05),
            repetition_penalty: Some(1.1),
            max_new_tokens: Some(256),
            max_length: None,
        };

        let mut params = SamplingParams::neutral();
        params.apply_model_defaults(&defaults);

        assert_eq!(params.temperature, Some(1.0));
        assert_eq!(params.top_k, Some(32));
        assert_eq!(params.top_p, Some(0.9));
        assert_eq!(params.min_p, Some(0.05));
        assert_eq!(params.repetition_penalty, Some(1.1));
        assert_eq!(params.max_len, Some(256));
    }

    /// `do_sample: false` clears the temperature, top-p and min-p and sets `top_k` to one.
    #[test]
    fn test_apply_model_defaults_disables_sampling_when_requested() {
        let greedy_defaults = ModelGenerationDefaults {
            do_sample: Some(false),
            ..Default::default()
        };

        let mut params = SamplingParams {
            temperature: Some(0.7),
            top_k: Some(40),
            top_p: Some(0.9),
            min_p: Some(0.1),
            ..SamplingParams::neutral()
        };
        params.apply_model_defaults(&greedy_defaults);

        assert_eq!(params.temperature, None);
        assert_eq!(params.top_k, Some(1));
        assert_eq!(params.top_p, None);
        assert_eq!(params.min_p, None);
    }
}