1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4 collections::{HashMap, HashSet},
5 sync::{Arc, LazyLock, Mutex},
6};
7
8use candle_core::{DType, Device, Error, Result, Tensor, D};
9use mistralrs_quant::{CumSumOp, SortOp};
10#[cfg(feature = "pyo3_macros")]
11use pyo3::pyclass;
12
13use rand::distr::{weighted::WeightedIndex, Distribution};
14use rand_isaac::Isaac64Rng;
15use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
16use serde::{Deserialize, Serialize};
17use tokenizers::Tokenizer;
18
/// Default DRY sequence breakers: characters at whose tokens repetition
/// matching is reset (newline, colon, double quote, asterisk).
static DRY_SEQUENCE_BREAKERS: LazyLock<Vec<String>> = LazyLock::new(|| {
    vec![
        "\n".to_string(),
        ":".to_string(),
        "\"".to_string(),
        "*".to_string(),
    ]
});
21
/// Per-model default generation settings, overlaid onto user-provided
/// [`SamplingParams`] via [`SamplingParams::apply_model_defaults`].
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct ModelGenerationDefaults {
    // `Some(false)` requests greedy decoding (see `apply_model_defaults`).
    pub do_sample: Option<bool>,
    // A value of 0.0 is treated as "no temperature" (greedy).
    pub temperature: Option<f64>,
    // A value of 0 is treated as "top-k disabled".
    pub top_k: Option<usize>,
    pub top_p: Option<f64>,
    pub min_p: Option<f64>,
    pub repetition_penalty: Option<f32>,
    // Maps onto `SamplingParams::max_len` when applied.
    pub max_new_tokens: Option<usize>,
    // NOTE(review): not consumed by `apply_model_defaults` — presumably used
    // elsewhere; confirm at the call sites.
    pub max_length: Option<usize>,
}
37
38impl ModelGenerationDefaults {
39 pub fn is_empty(&self) -> bool {
40 self.do_sample.is_none()
41 && self.temperature.is_none()
42 && self.top_k.is_none()
43 && self.top_p.is_none()
44 && self.min_p.is_none()
45 && self.repetition_penalty.is_none()
46 && self.max_new_tokens.is_none()
47 && self.max_length.is_none()
48 }
49}
50
/// Stop criteria for generation, given either as literal strings or as
/// vocabulary token ids.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum StopTokens {
    // Stop sequences as raw strings.
    Seqs(Vec<String>),
    // Stop tokens as vocabulary ids.
    Ids(Vec<u32>),
}
57
/// User-facing sampling configuration for a generation request.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SamplingParams {
    // `None` means greedy/argmax decoding.
    pub temperature: Option<f64>,
    pub top_k: Option<usize>,
    pub top_p: Option<f64>,
    pub min_p: Option<f64>,
    // How many top tokens to report when logprobs are requested.
    pub top_n_logprobs: usize,
    pub frequency_penalty: Option<f32>,
    pub presence_penalty: Option<f32>,
    pub repetition_penalty: Option<f32>,
    pub stop_toks: Option<StopTokens>,
    // Maximum number of tokens to generate.
    pub max_len: Option<usize>,
    // Additive bias per token id applied to the logits.
    pub logits_bias: Option<HashMap<u32, f32>>,
    // Number of completions to generate for this request.
    pub n_choices: usize,
    // DRY ("don't repeat yourself") repetition penalty configuration.
    pub dry_params: Option<DrySamplingParams>,
}
75
76impl SamplingParams {
77 pub fn neutral() -> Self {
84 Self {
85 temperature: None,
86 top_k: None,
87 top_p: None,
88 min_p: None,
89 top_n_logprobs: 0,
90 frequency_penalty: None,
91 presence_penalty: None,
92 repetition_penalty: None,
93 stop_toks: None,
94 max_len: None,
95 logits_bias: None,
96 n_choices: 1,
97 dry_params: None,
98 }
99 }
100
101 pub fn deterministic() -> Self {
106 Self {
107 temperature: None,
108 top_k: Some(1),
109 top_p: None,
110 min_p: None,
111 top_n_logprobs: 0,
112 frequency_penalty: None,
113 presence_penalty: None,
114 repetition_penalty: None,
115 stop_toks: None,
116 max_len: None,
117 logits_bias: None,
118 n_choices: 1,
119 dry_params: None,
120 }
121 }
122
123 pub fn apply_model_defaults(&mut self, defaults: &ModelGenerationDefaults) {
127 if defaults.do_sample == Some(false) {
128 self.temperature = None;
129 self.top_k = Some(1);
130 self.top_p = None;
131 self.min_p = None;
132 }
133
134 if let Some(temperature) = defaults.temperature {
135 self.temperature = if temperature == 0.0 {
136 None
137 } else {
138 Some(temperature)
139 };
140 }
141 if let Some(top_k) = defaults.top_k {
142 self.top_k = if top_k == 0 { None } else { Some(top_k) };
143 }
144 if let Some(top_p) = defaults.top_p {
145 self.top_p = Some(top_p);
146 }
147 if let Some(min_p) = defaults.min_p {
148 self.min_p = Some(min_p);
149 }
150 if let Some(repetition_penalty) = defaults.repetition_penalty {
151 self.repetition_penalty = Some(repetition_penalty);
152 }
153 if let Some(max_new_tokens) = defaults.max_new_tokens {
154 self.max_len = Some(max_new_tokens);
155 }
156 }
157}
158
/// User-facing DRY ("don't repeat yourself") sampling configuration, with
/// sequence breakers given as strings.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DrySamplingParams {
    // Strings whose trailing token resets repetition matching.
    pub sequence_breakers: Vec<String>,
    // Penalty strength; 0.0 disables the DRY penalty entirely.
    pub multiplier: f32,
    // Exponential base: the penalty grows as base^(match_len - allowed_length).
    pub base: f32,
    // Repeated runs up to this length are not penalized.
    pub allowed_length: usize,
}
167
168impl DrySamplingParams {
169 pub fn new_with_defaults(
170 multiplier: f32,
171 sequence_breakers: Option<Vec<String>>,
172 base: Option<f32>,
173 allowed_length: Option<usize>,
174 ) -> anyhow::Result<Self> {
175 Ok(Self {
176 base: base.unwrap_or(1.75),
177 allowed_length: allowed_length.unwrap_or(2),
178 sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
179 multiplier,
180 })
181 }
182}
183
184impl Default for DrySamplingParams {
185 fn default() -> Self {
186 Self {
187 multiplier: 0.0,
188 base: 1.75,
189 allowed_length: 2,
190 sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
191 }
192 }
193}
194
/// Internal form of [`DrySamplingParams`] used by the sampler: sequence
/// breakers resolved from strings into token ids via the tokenizer.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    // Token ids that terminate/reset repetition matching.
    pub sequence_breakers: HashSet<u32>,
    pub multiplier: f32,
    pub base: f32,
    pub allowed_length: usize,
}
202
203impl DrySamplingParamsInner {
204 pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
205 Ok(Self {
206 base: other.base,
207 allowed_length: other.allowed_length,
208 sequence_breakers: HashSet::from_iter(
209 other
210 .sequence_breakers
211 .into_iter()
212 .map(|breaker| {
213 tokenizer
214 .encode_fast(["a", &breaker].concat(), true)
220 .map_err(anyhow::Error::msg)
221 .map(|enc| {
222 let ids = enc.get_ids();
223 if !ids.is_empty() {
224 Some(ids[ids.len() - 1])
225 } else {
226 None
227 }
228 })
229 })
230 .collect::<anyhow::Result<Vec<_>>>()?
231 .into_iter()
232 .flatten()
233 .collect::<Vec<_>>(),
234 ),
235 multiplier: other.multiplier,
236 })
237 }
238}
239
/// A user-supplied logits transformation applied before sampling.
pub trait CustomLogitsProcessor: Send + Sync {
    /// Transform `logits` given the `context` of generated token ids,
    /// returning the modified logits tensor.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}
263
/// Any `Fn(&Tensor, &[u32]) -> Result<Tensor>` closure can be used directly
/// as a [`CustomLogitsProcessor`].
impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
        self(logits, context)
    }
}
269
/// Token sampler: holds penalty configuration, truncation parameters,
/// custom logits processors, and the tokenizer used to decode sampled ids.
#[derive(Clone)]
pub struct Sampler {
    // `None` selects the greedy/argmax path (normalized in `Sampler::new`).
    temperature: Option<f64>,
    // How many top tokens to report when logprobs are requested.
    top_n_logprobs: usize,
    // Used to decode token ids into text for `Logprobs::bytes` and to
    // resolve DRY sequence breakers; optional.
    tokenizer: Option<Arc<Tokenizer>>,
    frequency_penalty: Option<f32>,
    presence_penalty: Option<f32>,
    repetition_penalty: Option<f32>,
    dry_params: Option<DrySamplingParamsInner>,
    // <= 0 means top-k is disabled.
    top_k: i64,
    top_p: f64,
    min_p: f64,
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
    // Gumbel noise tensor cached by `sample_fast`; generated once per sampler.
    gumbel_cache: Arc<Mutex<Option<Tensor>>>,
}
287
/// One entry of a top-logprobs report for a sampled position.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TopLogprob {
    // Vocabulary id of the token.
    pub token: u32,
    // Base-10 log of the token's (post-filtering) probability.
    pub logprob: f32,
    // Decoded text for the token, when a tokenizer was available.
    pub bytes: Option<String>,
}
297
/// The result of sampling one token, with optional logprob details.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
    // The sampled token id.
    pub token: u32,
    // Base-10 log probability of the sampled token.
    pub logprob: f32,
    // Decoded text for the sampled token, when a tokenizer was available.
    pub bytes: Option<String>,
    // Top-N alternatives, populated only when logprobs were requested.
    pub top_logprobs: Option<Vec<TopLogprob>>,
}
305
/// Compare two `(token, prob)` pairs so that higher probabilities sort first.
///
/// Uses `f32::total_cmp` (IEEE-754 totalOrder), which is a true total order:
/// the previous `partial_cmp(..).unwrap_or(Equal)` fallback is non-transitive
/// when NaN probabilities slip in and can corrupt the results of
/// `select_nth_unstable_by` / `sort_unstable_by`.
#[inline]
fn cmp_desc_by_prob(a: &(u32, f32), b: &(u32, f32)) -> std::cmp::Ordering {
    b.1.total_cmp(&a.1)
}
311
/// Return the `k` highest-probability `(index, prob)` pairs in descending
/// order. When `zero_rest` is set, every entry of `probs` outside the top-k
/// is zeroed in place. `k` is clamped to the slice length.
fn partial_sort_top_k(probs: &mut [f32], k: usize, zero_rest: bool) -> Vec<(u32, f32)> {
    let len = probs.len();
    if len == 0 || k == 0 {
        return Vec::new();
    }

    // Descending-by-probability comparator (NaN treated as equal).
    let desc =
        |a: &(u32, f32), b: &(u32, f32)| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal);

    let mut entries: Vec<(u32, f32)> = probs
        .iter()
        .enumerate()
        .map(|(i, &p)| (i as u32, p))
        .collect();

    let k = k.min(len);

    if k < len {
        // O(n) partition: the k largest entries end up in entries[..k].
        entries.select_nth_unstable_by(k - 1, desc);

        if zero_rest {
            for &(idx, _) in &entries[k..] {
                probs[idx as usize] = 0.0;
            }
        }

        entries.truncate(k);
    }

    // Only the surviving k entries need a full sort.
    entries.sort_unstable_by(desc);

    entries
}
350
/// Index of the maximum value in `values` (0 for an empty slice).
/// Ties resolve to the *last* maximal index, and NaN comparisons are treated
/// as equal — both matching `Iterator::max_by` semantics.
#[inline]
fn argmax_f32(values: &[f32]) -> u32 {
    let mut best: Option<(usize, f32)> = None;
    for (idx, &val) in values.iter().enumerate() {
        best = match best {
            // Keep the incumbent only when it is strictly greater.
            Some((best_idx, best_val))
                if best_val.partial_cmp(&val).unwrap_or(std::cmp::Ordering::Equal)
                    == std::cmp::Ordering::Greater =>
            {
                Some((best_idx, best_val))
            }
            _ => Some((idx, val)),
        };
    }
    best.map(|(idx, _)| idx as u32).unwrap_or(0)
}
361
362impl Sampler {
363 #[allow(clippy::too_many_arguments)]
364 pub fn new(
365 temperature: Option<f64>,
366 top_n_logprobs: usize,
367 tokenizer: Option<Arc<Tokenizer>>,
368 frequency_penalty: Option<f32>,
369 presence_penalty: Option<f32>,
370 repetition_penalty: Option<f32>,
371 dry_params: Option<DrySamplingParams>,
372 top_k: i64,
373 top_p: f64,
374 min_p: f64,
375 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
376 ) -> anyhow::Result<Self> {
377 let temperature = if temperature.is_none_or(|v| v < 1e-7) {
378 None
379 } else {
380 temperature
381 };
382 let dry_params = if let Some(ref tokenizer) = tokenizer {
383 dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
384 } else {
385 None
386 };
387 let dry_params = match dry_params {
388 Some(fallible) => Some(fallible?),
389 None => None,
390 };
391 Ok(Self {
392 temperature,
393 top_n_logprobs,
394 tokenizer,
395 frequency_penalty,
396 presence_penalty,
397 repetition_penalty,
398 dry_params,
399 top_k,
400 top_p,
401 min_p,
402 logits_processors,
403 gumbel_cache: Arc::new(Mutex::new(None)),
404 })
405 }
406
407 fn get_top_logprobs(&self, probs: &[f32]) -> Result<Vec<TopLogprob>> {
408 let k = self.top_n_logprobs.min(probs.len());
409 if k == 0 {
410 return Ok(Vec::new());
411 }
412
413 let mut probs_copy = probs.to_vec();
415 let top_k = partial_sort_top_k(&mut probs_copy, k, false);
416
417 let mut result = Vec::with_capacity(k);
419 if let Some(tokenizer) = &self.tokenizer {
420 for (token, prob) in top_k {
421 let decoded = tokenizer
422 .decode(&[token], false)
423 .map_err(|e| Error::Msg(e.to_string()))?;
424 result.push(TopLogprob {
425 token,
426 logprob: prob.log(10.0),
427 bytes: Some(decoded),
428 });
429 }
430 } else {
431 for (token, prob) in top_k {
432 result.push(TopLogprob {
433 token,
434 logprob: prob.log(10.0),
435 bytes: None,
436 });
437 }
438 }
439 Ok(result)
440 }
441
442 fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
443 let probs: Vec<f32> = logits.to_vec1()?;
444 let next_token = argmax_f32(&probs);
445 let logprob = probs[next_token as usize].log(10.0);
446
447 let top_logprobs = if return_logprobs {
448 Some(self.get_top_logprobs(&probs)?)
449 } else {
450 None
451 };
452
453 let bytes = if let Some(tokenizer) = &self.tokenizer {
454 Some(
455 tokenizer
456 .decode(&[next_token], false)
457 .map_err(|x| Error::Msg(x.to_string()))?,
458 )
459 } else {
460 None
461 };
462
463 Ok(Logprobs {
464 token: next_token,
465 logprob,
466 top_logprobs,
467 bytes,
468 })
469 }
470
    /// GPU-resident sampling path: penalties, temperature, top-k/top-p/min-p
    /// filtering, and Gumbel-max token selection are all done with tensor ops
    /// (no host round-trip until the final token is read back).
    ///
    /// NOTE(review): assumes `logits` is 1-D over the vocabulary — confirm
    /// against callers (`to_scalar` below requires a single result).
    #[allow(unused)]
    fn sample_fast(
        &self,
        logits: Tensor,
        context: &[u32],
        return_logprobs: bool,
        top_k: i64,
        top_p: f64,
        min_p: f64,
    ) -> Result<Logprobs> {
        let mut probs = logits.to_dtype(DType::F32)?;

        // Custom processors run first, on the raw (pre-penalty) logits.
        for processor in &self.logits_processors {
            probs = processor.apply(&probs, context)?;
        }

        // Per-token occurrence counts of the context, built via scatter_add.
        let context = Tensor::new(context, logits.device())?;
        let mut counts = logits.zeros_like()?;
        counts = counts.scatter_add(
            &context,
            &context.ones_like()?.to_dtype(counts.dtype())?,
            D::Minus1,
        )?;

        // 1.0 where a token appeared at least once in the context, else 0.0.
        let presence = counts
            .gt(0.)?
            .where_cond(&counts.ones_like()?, &counts.zeros_like()?)?;

        // Frequency penalty: subtract count-proportional mass.
        match self.frequency_penalty {
            Some(freq_penalty) if freq_penalty != 0. => {
                probs = (probs - (freq_penalty as f64 * counts)?)?;
            }
            _ => (),
        }

        // Presence penalty: flat subtraction for any seen token.
        match self.presence_penalty {
            Some(pres_penalty) if pres_penalty != 0. => {
                probs = (probs - (pres_penalty as f64 * &presence)?)?;
            }
            _ => (),
        }

        // Repetition penalty: divide positive / multiply negative logits of
        // seen tokens (the usual CTRL-style formulation).
        match self.repetition_penalty {
            Some(rep_penalty) if rep_penalty != 1. => {
                let pos_mask = probs.gt(0.)?;
                let scaled_pos = (&probs / (rep_penalty as f64))?;
                let scaled_neg = (&probs * (rep_penalty as f64))?;
                let modified = pos_mask.where_cond(&scaled_pos, &scaled_neg)?;

                let pres_mask = presence.gt(0.)?;
                probs = pres_mask.where_cond(&modified, &probs)?;
            }
            _ => (),
        }

        // Temperature scaling followed by softmax; no temperature means 1.0.
        probs = candle_nn::ops::softmax_last_dim(&(probs / self.temperature.unwrap_or(1.))?)?;

        // Top-k: zero everything below the k-th largest probability.
        // NOTE(review): `top_k as usize` assumes top_k <= vocab size —
        // confirm callers guarantee this (narrow would fail otherwise).
        if top_k > 0 {
            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
            let topk_values = sorted_values.narrow(
                D::Minus1,
                sorted_values.dim(D::Minus1)? - top_k as usize,
                top_k as usize,
            )?;

            // Smallest value inside the top-k window is the cutoff.
            let threshold = topk_values.get_on_dim(D::Minus1, 0)?.unsqueeze(0)?;
            let mask_topk = probs.broadcast_ge(&threshold)?;
            probs = mask_topk.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
        }

        // Top-p (nucleus): keep the smallest set of tokens whose cumulative
        // (ascending) mass stays within top_p, realized as a threshold mask.
        if top_p > 0.0 && top_p < 1.0 {
            let sorted_probs = probs.fast_sort_asc(D::Minus1)?;

            let cumsum = sorted_probs.fast_cumsum(D::Minus1)?;

            let mask_topp = cumsum.le(top_p)?;

            let masked_sorted =
                mask_topp.where_cond(&sorted_probs, &Tensor::zeros_like(&sorted_probs)?)?;

            // Largest surviving sorted value becomes the per-row threshold.
            let threshold = masked_sorted.max(D::Minus1)?;
            let threshold = threshold.unsqueeze(D::Minus1)?;
            let mask_full = probs.broadcast_ge(&threshold)?;
            probs = mask_full.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
        }

        // Min-p: drop tokens not strictly above min_p × the max probability.
        if min_p > 0.0 && min_p < 1.0 {
            let max_vals = probs.max(D::Minus1)?;
            let threshold_min = (max_vals.unsqueeze(D::Minus1)? * min_p)?;
            let mask_minp = probs.broadcast_gt(&threshold_min)?;
            probs = mask_minp.where_cond(&probs, &Tensor::zeros_like(&probs)?)?;
        }

        // Gumbel-max trick: argmax(log p + Gumbel noise) ~ Categorical(p).
        let log_probs = probs.log()?;
        let gumbel = {
            let mut guard = self.gumbel_cache.lock().unwrap();
            if guard.is_none() {
                let uniform = Tensor::rand(0f32, 1f32, log_probs.shape(), log_probs.device())?;
                // -log(-log(U)) with U clamped away from 0 for stability.
                let noise = uniform.clamp(1e-20, 1.0)?.log()?.neg()?.log()?.neg()?;
                *guard = Some(noise);
            }
            // NOTE(review): the noise tensor is generated once and reused for
            // every subsequent call on this sampler — sampling is therefore
            // deterministic per Sampler instance. Confirm this is intended.
            guard.as_ref().unwrap().clone()
        };

        let gumbel_logits = (&log_probs + &gumbel)?;
        let next_token = gumbel_logits.argmax(D::Minus1)?.to_scalar::<u32>()?;

        let (top_logprobs, logprob) = if return_logprobs {
            let k = self.top_n_logprobs;

            // NOTE(review): the windows below are sized by `top_k`, not `k`,
            // and would underflow for top_k <= 0 — confirm callers only
            // request logprobs with a positive top_k.
            let sorted_values = probs.fast_sort_asc(D::Minus1)?;
            let topk_values = sorted_values
                .narrow(
                    D::Minus1,
                    sorted_values.dim(D::Minus1)? - top_k as usize,
                    top_k as usize,
                )?
                .to_vec1::<f32>()?;

            let sorted_idxs = probs.fast_argsort_asc(D::Minus1)?;
            let topk_idxs = sorted_idxs
                .narrow(
                    D::Minus1,
                    sorted_values.dim(D::Minus1)? - top_k as usize,
                    top_k as usize,
                )?
                .to_vec1::<u32>()?;

            let mut result = Vec::with_capacity(k);
            if let Some(tokenizer) = &self.tokenizer {
                for (prob, token) in topk_values.iter().zip(topk_idxs) {
                    let decoded = tokenizer
                        .decode(&[token], false)
                        .map_err(|e| Error::Msg(e.to_string()))?;
                    result.push(TopLogprob {
                        token,
                        logprob: prob.log(10.0),
                        bytes: Some(decoded),
                    });
                }
            } else {
                for (prob, token) in topk_values.iter().zip(topk_idxs) {
                    result.push(TopLogprob {
                        token,
                        logprob: prob.log(10.0),
                        bytes: None,
                    });
                }
            }

            // Values are ascending, so the last entry is the most probable.
            let logprob = result.last().map(|res| res.logprob).unwrap_or(1.);

            (Some(result), logprob)
        } else {
            (None, 1.)
        };

        let bytes = if let Some(tokenizer) = &self.tokenizer {
            Some(
                tokenizer
                    .decode(&[next_token], false)
                    .map_err(|x| Error::Msg(x.to_string()))?,
            )
        } else {
            None
        };

        Ok(Logprobs {
            token: next_token,
            logprob,
            top_logprobs,
            bytes,
        })
    }
    /// Speculative-decoding sampling path: applies top-k, top-p, and min-p
    /// filtering on the CPU, then deterministically picks the argmax of the
    /// surviving probabilities (no RNG involved).
    ///
    /// NOTE(review): assumes `logits` is a 1-D vector of probabilities.
    fn sample_speculative_top_kp_min_p(
        &self,
        logits: Tensor,
        return_logprobs: bool,
        top_k: i64,
        top_p: f32,
        min_p: f32,
    ) -> Result<Logprobs> {
        let mut probs: Vec<f32> = logits.to_vec1()?;

        // top_k <= 0 means "disabled": keep the whole vocabulary.
        let k = if top_k > 0 {
            top_k as usize
        } else {
            probs.len()
        };

        // The k best entries, descending; everything else in `probs` is
        // zeroed in place.
        let idx_probs = partial_sort_top_k(&mut probs, k, true);

        // Top-p (nucleus): zero entries past the point where the cumulative
        // mass reaches top_p. The entry that crosses the threshold is kept.
        let mut cumsum = 0.;
        for (index, prob) in &idx_probs {
            if cumsum >= top_p {
                probs[*index as usize] = 0.0;
            } else {
                cumsum += prob;
            }
        }

        // Min-p: drop tokens whose probability is at most min_p × the top
        // probability.
        let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);

        let min_p_threshold = max_p * min_p;
        for (index, prob) in &idx_probs {
            if min_p_threshold >= *prob {
                probs[*index as usize] = 0.0;
            }
        }

        // Deterministic pick: highest surviving probability.
        let next_token = argmax_f32(&probs);
        let logprob = probs[next_token as usize].log(10.0);

        let top_logprobs = if return_logprobs {
            Some(self.get_top_logprobs(&probs)?)
        } else {
            None
        };

        let bytes = if let Some(tokenizer) = &self.tokenizer {
            Some(
                tokenizer
                    .decode(&[next_token], false)
                    .map_err(|x| Error::Msg(x.to_string()))?,
            )
        } else {
            None
        };

        Ok(Logprobs {
            token: next_token,
            logprob,
            top_logprobs,
            bytes,
        })
    }
734
735 fn sample_multinomial(
736 &self,
737 probs: &[f32],
738 return_logprobs: bool,
739 rng: Arc<Mutex<Isaac64Rng>>,
740 ) -> Result<Logprobs> {
741 let distr = match WeightedIndex::new(probs) {
742 Ok(distr) => distr,
743 Err(e) => {
744 if let Some((idx, prob)) = probs
745 .iter()
746 .enumerate()
747 .find(|(_, prob)| !prob.is_finite() || **prob < 0.0)
748 {
749 return Err(Error::Msg(format!(
750 "Invalid sampling probability at index {idx}: {prob}. The model likely produced NaN/Inf logits."
751 )));
752 }
753
754 let positive_weight_sum: f64 = probs
755 .iter()
756 .copied()
757 .filter(|prob| prob.is_finite() && *prob > 0.0)
758 .map(f64::from)
759 .sum();
760
761 if positive_weight_sum == 0.0 {
762 return Err(Error::Msg(
763 "All sampling probabilities are zero after filtering (top-k/top-p/min-p)."
764 .to_string(),
765 ));
766 }
767
768 return Err(Error::Msg(format!(
769 "Failed to construct multinomial sampler: {e}"
770 )));
771 }
772 };
773
774 let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
775 let next_token = distr.sample(&mut mut_ref_rng); let logprob = probs[next_token].log(10.0);
777
778 let top_logprobs = if return_logprobs {
779 Some(self.get_top_logprobs(probs)?)
780 } else {
781 None
782 };
783
784 let bytes = if let Some(tokenizer) = &self.tokenizer {
785 Some(
786 tokenizer
787 .decode(&[next_token.try_into().unwrap()], false)
788 .map_err(|x| Error::Msg(x.to_string()))?,
789 )
790 } else {
791 None
792 };
793
794 Ok(Logprobs {
795 token: next_token as u32,
796 logprob,
797 top_logprobs,
798 bytes,
799 })
800 }
801
    /// Sample a token after applying top-k, then top-p, then min-p filtering
    /// (each filter zeroes entries of `probs` in place), drawing from the
    /// surviving distribution with the provided RNG.
    ///
    /// Note the early returns: a `top_p` outside (0, 1) disables top-p and
    /// min-p; a `min_p` outside (0, 1) disables min-p only.
    #[allow(clippy::too_many_arguments)]
    fn sample_top_kp_min_p(
        &self,
        probs: &mut [f32],
        top_k: i64,
        top_p: f32,
        min_p: f32,
        return_logprobs: bool,
        rng: Arc<Mutex<Isaac64Rng>>,
    ) -> Result<Logprobs> {
        // top_k <= 0 means "disabled": keep the whole vocabulary.
        let k = if top_k > 0 {
            top_k as usize
        } else {
            probs.len()
        };

        // The k best entries, descending; the rest of `probs` is zeroed.
        let idx_probs = partial_sort_top_k(probs, k, true);

        if top_p <= 0.0 || top_p >= 1.0 {
            // Top-p disabled: sample from the top-k-filtered distribution.
            return self.sample_multinomial(probs, return_logprobs, rng);
        }

        // Top-p (nucleus): zero entries past the point where the cumulative
        // mass reaches top_p (the entry crossing the threshold is kept).
        let mut cumsum = 0.;
        for (index, prob) in &idx_probs {
            if cumsum >= top_p {
                probs[*index as usize] = 0.0;
            } else {
                cumsum += prob;
            }
        }

        if min_p <= 0.0 || min_p >= 1.0 {
            // Min-p disabled: sample from the top-k + top-p distribution.
            return self.sample_multinomial(probs, return_logprobs, rng);
        }

        // Min-p: drop tokens whose probability is at most min_p × the top
        // probability.
        let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);

        let min_p_threshold = max_p * min_p;
        for (index, prob) in &idx_probs {
            if min_p_threshold >= *prob {
                probs[*index as usize] = 0.0;
            }
        }

        self.sample_multinomial(probs, return_logprobs, rng)
    }
865
866 fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
867 if context.is_empty() {
868 candle_core::bail!("Penalty context is empty, this should not happen.");
869 }
870
871 self.apply_dry_penalty(&mut logits, context)?;
873
874 self.apply_freq_pres_rep_penalty(&mut logits, context)?;
876
877 let vocab_size = logits.len();
878 Tensor::from_vec(logits, vocab_size, &Device::Cpu)
879 }
880
881 fn apply_freq_pres_rep_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
882 if self.frequency_penalty.is_some()
883 || self.presence_penalty.is_some()
884 || self.repetition_penalty.is_some()
885 {
886 let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
887 let presence_penalty = self.presence_penalty.unwrap_or(0.);
888 let repetition_penalty = self.repetition_penalty.unwrap_or(1.);
889
890 let mut counts = vec![0.0f32; logits.len()];
893 for ctx in context.iter() {
894 if *ctx as usize >= logits.len() {
896 continue;
897 }
898 counts[*ctx as usize] += 1.0;
899 }
900
901 for (token_id, logit) in logits.iter_mut().enumerate() {
902 let count = counts[token_id];
903 *logit = *logit
904 - count * frequency_penalty
905 - if count > 0.0 { 1. } else { 0. } * presence_penalty;
906
907 if repetition_penalty != 1.0 && count > 0.0 {
908 if *logit > 0.0 {
909 *logit /= repetition_penalty;
910 } else {
911 *logit *= repetition_penalty;
912 }
913 }
914 }
915 }
916 Ok(())
917 }
918
    // Context length above which the DRY occurrence scan runs via rayon.
    const DRY_PENALTY_PAR_THRESHOLD: usize = 1024;

    /// Apply the DRY ("don't repeat yourself") repetition penalty in place.
    ///
    /// For every earlier occurrence of the context's last token, the match
    /// length between the current context suffix and the window preceding
    /// that occurrence is measured; the token that followed the occurrence
    /// is then penalized by `multiplier * base^(match_len - allowed_length)`.
    ///
    /// No-op when DRY params are absent or `multiplier == 0`. Assumes a
    /// non-empty `context` (`apply_penalties` bails on empty context first).
    fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
        if let Some(ref params) = self.dry_params {
            if params.multiplier == 0. {
                return Ok(());
            }

            // Safe: callers guarantee a non-empty context.
            let last_token = *context.last().unwrap();

            // Positions (excluding the final one) where the last token
            // occurred before; parallelized for long contexts.
            let match_indices: Vec<usize> = if context.len() > Self::DRY_PENALTY_PAR_THRESHOLD {
                context
                    .par_iter()
                    .enumerate()
                    .take(context.len() - 1)
                    .filter(|(_i, x)| last_token == **x)
                    .map(|(i, _)| i)
                    .collect()
            } else {
                context
                    .iter()
                    .enumerate()
                    .take(context.len() - 1)
                    .filter(|(_i, x)| last_token == **x)
                    .map(|(i, _)| i)
                    .collect()
            };

            // Longest match length seen per candidate next token.
            let mut match_lengths = HashMap::new();

            for i in match_indices {
                // The token that followed the earlier occurrence — the
                // candidate to be penalized.
                let next_token = context[i + 1];

                if params.sequence_breakers.contains(&next_token) {
                    continue;
                }

                // Walk backwards comparing the context suffix against the
                // tokens preceding position `i`; capped at 50 to bound cost.
                let mut match_length = 1;

                while match_length < 50 {
                    if match_length > i {
                        // Reached the start of the context.
                        break;
                    }

                    let j = i - match_length;

                    let prev_tok = context[context.len() - (match_length + 1)];
                    if context[j] != prev_tok {
                        // Suffix and window diverge; stop extending.
                        break;
                    }

                    if params.sequence_breakers.contains(&prev_tok) {
                        // Sequence breakers terminate the match window.
                        break;
                    }

                    match_length += 1;
                }

                // Keep only the longest match length per candidate token.
                #[allow(clippy::map_entry)]
                if match_lengths.contains_key(&next_token) {
                    match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
                } else {
                    match_lengths.insert(next_token, match_length);
                }
            }

            // Subtract the exponential penalty from each penalized token.
            for (tok, match_len) in match_lengths {
                if match_len >= params.allowed_length {
                    // Skip ids outside the vocab (e.g. special ids beyond
                    // the logits length).
                    if tok as usize >= logits.len() {
                        continue;
                    }
                    let penalty = params.multiplier
                        * params.base.powf((match_len - params.allowed_length) as f32);
                    logits[tok as usize] -= penalty;
                }
            }
        }
        Ok(())
    }
1007
    /// Sample one token from `logits` given the penalty `context`.
    ///
    /// Pipeline: DRY + frequency/presence/repetition penalties, then the
    /// custom logits processors, then one of the sampling paths — the
    /// speculative (deterministic) path when `sample_speculative` is set,
    /// otherwise argmax (no temperature) or temperature-softmax followed by
    /// top-k/top-p/min-p multinomial sampling.
    ///
    /// NOTE(review): `multiple_sequences` is accepted but unused here —
    /// confirm whether callers rely on it.
    #[allow(unused)]
    pub fn sample(
        &self,
        logits: Tensor,
        context: &[u32],
        return_logprobs: bool,
        rng: Arc<Mutex<Isaac64Rng>>,
        sample_speculative: bool,
        multiple_sequences: bool,
    ) -> Result<Logprobs> {
        // Penalties are applied on the CPU over the raw logits vector.
        let logits = logits.to_vec1()?;
        let mut logits = self.apply_penalties(logits, context)?;
        for processor in &self.logits_processors {
            logits = processor.apply(&logits, context)?;
        }
        let next_token = if sample_speculative {
            match self.temperature {
                None => self.sample_speculative_top_kp_min_p(
                    logits,
                    return_logprobs,
                    self.top_k,
                    self.top_p as f32,
                    self.min_p as f32,
                )?,
                Some(temperature) => {
                    // Temperature-scaled softmax before filtering.
                    let logits = (&logits / temperature)?;
                    let probs = candle_nn::ops::softmax_last_dim(&logits)?;

                    self.sample_speculative_top_kp_min_p(
                        probs,
                        return_logprobs,
                        self.top_k,
                        self.top_p as f32,
                        self.min_p as f32,
                    )?
                }
            }
        } else {
            match self.temperature {
                // No temperature: greedy decoding.
                None => self.sample_argmax(logits, return_logprobs)?,
                Some(temperature) => {
                    let logits = (&logits / temperature)?;
                    let probs = candle_nn::ops::softmax_last_dim(&logits)?;
                    let mut probs: Vec<f32> = probs.to_vec1()?;

                    self.sample_top_kp_min_p(
                        &mut probs,
                        self.top_k,
                        self.top_p as f32,
                        self.min_p as f32,
                        return_logprobs,
                        rng,
                    )?
                }
            }
        };
        Ok(next_token)
    }
1081}
1082
#[cfg(test)]
mod tests {
    use super::{ModelGenerationDefaults, SamplingParams};

    // With no temperature the sampler must take the greedy path and return
    // the argmax token.
    #[test]
    fn test_argmax() {
        use super::Sampler;
        use candle_core::{Device, Tensor};
        use rand::SeedableRng;
        use rand_isaac::Isaac64Rng;
        use std::sync::Arc;
        use std::sync::Mutex;

        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();
        // Monotonically increasing logits: the last id (1023) is the argmax.
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        let res = sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                false,
                false,
            )
            .unwrap();
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    // The speculative path with no temperature must also be greedy.
    #[test]
    fn test_gumbel_speculative() {
        use super::Sampler;
        use candle_core::{Device, Tensor};
        use rand::SeedableRng;
        use rand_isaac::Isaac64Rng;
        use std::sync::Arc;
        use std::sync::Mutex;

        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        let res = sampler
            .sample(
                logits,
                &(0..1024).collect::<Vec<_>>(),
                false,
                rng,
                true,
                false,
            )
            .unwrap();
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    // Model defaults should fill in every sampling field they carry.
    #[test]
    fn test_apply_model_defaults() {
        let mut params = SamplingParams::neutral();
        params.apply_model_defaults(&ModelGenerationDefaults {
            do_sample: Some(true),
            temperature: Some(1.0),
            top_k: Some(32),
            top_p: Some(0.9),
            min_p: Some(0.05),
            repetition_penalty: Some(1.1),
            max_new_tokens: Some(256),
            max_length: None,
        });

        assert_eq!(params.temperature, Some(1.0));
        assert_eq!(params.top_k, Some(32));
        assert_eq!(params.top_p, Some(0.9));
        assert_eq!(params.min_p, Some(0.05));
        assert_eq!(params.repetition_penalty, Some(1.1));
        assert_eq!(params.max_len, Some(256));
    }

    // `do_sample == Some(false)` must force greedy decoding, clearing any
    // previously configured sampling parameters.
    #[test]
    fn test_apply_model_defaults_disables_sampling_when_requested() {
        let mut params = SamplingParams {
            temperature: Some(0.7),
            top_k: Some(40),
            top_p: Some(0.9),
            min_p: Some(0.1),
            ..SamplingParams::neutral()
        };
        params.apply_model_defaults(&ModelGenerationDefaults {
            do_sample: Some(false),
            ..Default::default()
        });

        assert_eq!(params.temperature, None);
        assert_eq!(params.top_k, Some(1));
        assert_eq!(params.top_p, None);
        assert_eq!(params.min_p, None);
    }
}