1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::{
4 collections::{HashMap, HashSet},
5 sync::{Arc, LazyLock, Mutex},
6};
7
8use hanzo_ml::{Device, Error, Result, Tensor};
9#[cfg(feature = "pyo3_macros")]
10use pyo3::pyclass;
11
12use rand::distr::{weighted::WeightedIndex, Distribution};
13use rand_isaac::Isaac64Rng;
14use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
15use serde::{Deserialize, Serialize};
16use tokenizers::Tokenizer;
17
/// Default DRY sequence breakers: newline, colon, double quote and asterisk.
///
/// Used whenever DRY parameters are created without an explicit breaker list; the
/// strings are resolved to token ids by `DrySamplingParamsInner::from`.
static DRY_SEQUENCE_BREAKERS: LazyLock<Vec<String>> = LazyLock::new(|| {
    ["\n", ":", "\"", "*"]
        .into_iter()
        .map(str::to_owned)
        .collect()
});
20
/// Generation defaults shipped with a model (presumably read from a Hugging Face style
/// `generation_config.json` — TODO confirm against the loader).
///
/// Every field is optional; `None` means "the model does not specify this". The values
/// are merged into request parameters by `SamplingParams::apply_model_defaults`.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct ModelGenerationDefaults {
    /// `Some(false)` makes `apply_model_defaults` switch to a greedy preset.
    pub do_sample: Option<bool>,
    /// Sampling temperature; `apply_model_defaults` maps `0.0` to "no temperature".
    pub temperature: Option<f64>,
    /// Top-k cut-off; `apply_model_defaults` maps `0` to "no top-k".
    pub top_k: Option<usize>,
    /// Nucleus (top-p) threshold.
    pub top_p: Option<f64>,
    /// Min-p ratio.
    pub min_p: Option<f64>,
    /// Repetition penalty.
    pub repetition_penalty: Option<f32>,
    /// Copied into `SamplingParams::max_len` by `apply_model_defaults`.
    pub max_new_tokens: Option<usize>,
    /// Only inspected by `is_empty`; `apply_model_defaults` does not use it.
    pub max_length: Option<usize>,
}
36
37impl ModelGenerationDefaults {
38 pub fn is_empty(&self) -> bool {
39 self.do_sample.is_none()
40 && self.temperature.is_none()
41 && self.top_k.is_none()
42 && self.top_p.is_none()
43 && self.min_p.is_none()
44 && self.repetition_penalty.is_none()
45 && self.max_new_tokens.is_none()
46 && self.max_length.is_none()
47 }
48}
49
/// Stop conditions for a request, expressed either as text or as token ids.
///
/// Not consumed anywhere in this file; presumably generation halts once one of these
/// is produced — TODO confirm against the caller.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum StopTokens {
    /// Stop strings.
    Seqs(Vec<String>),
    /// Stop token ids.
    Ids(Vec<u32>),
}
56
/// Per-request sampling configuration.
///
/// `None` leaves a knob unset. `neutral` and `deterministic` provide presets, and
/// `apply_model_defaults` merges model-provided defaults into an instance. How these
/// values are turned into a `Sampler` is not shown in this file.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SamplingParams {
    /// Softmax temperature; `None` when unset.
    pub temperature: Option<f64>,
    /// Keep only the `k` most likely tokens; `deterministic` sets `Some(1)`.
    pub top_k: Option<usize>,
    /// Nucleus (top-p) threshold.
    pub top_p: Option<f64>,
    /// Min-p ratio.
    pub min_p: Option<f64>,
    /// Number of alternative tokens to report with each sampled token (0 in both presets).
    pub top_n_logprobs: usize,
    /// Frequency penalty.
    pub frequency_penalty: Option<f32>,
    /// Presence penalty.
    pub presence_penalty: Option<f32>,
    /// Repetition penalty.
    pub repetition_penalty: Option<f32>,
    /// Stop strings or token ids.
    pub stop_toks: Option<StopTokens>,
    /// Generation length limit; `apply_model_defaults` fills it from `max_new_tokens`.
    pub max_len: Option<usize>,
    /// Per-token-id bias (presumably added to the matching logit — not applied in this file).
    pub logits_bias: Option<HashMap<u32, f32>>,
    /// Number of choices to generate (1 in both presets).
    pub n_choices: usize,
    /// DRY ("don't repeat yourself") penalty settings.
    pub dry_params: Option<DrySamplingParams>,
}
74
75impl SamplingParams {
76 pub fn neutral() -> Self {
83 Self {
84 temperature: None,
85 top_k: None,
86 top_p: None,
87 min_p: None,
88 top_n_logprobs: 0,
89 frequency_penalty: None,
90 presence_penalty: None,
91 repetition_penalty: None,
92 stop_toks: None,
93 max_len: None,
94 logits_bias: None,
95 n_choices: 1,
96 dry_params: None,
97 }
98 }
99
100 pub fn deterministic() -> Self {
105 Self {
106 temperature: None,
107 top_k: Some(1),
108 top_p: None,
109 min_p: None,
110 top_n_logprobs: 0,
111 frequency_penalty: None,
112 presence_penalty: None,
113 repetition_penalty: None,
114 stop_toks: None,
115 max_len: None,
116 logits_bias: None,
117 n_choices: 1,
118 dry_params: None,
119 }
120 }
121
122 pub fn apply_model_defaults(&mut self, defaults: &ModelGenerationDefaults) {
126 if defaults.do_sample == Some(false) {
127 self.temperature = None;
128 self.top_k = Some(1);
129 self.top_p = None;
130 self.min_p = None;
131 }
132
133 if let Some(temperature) = defaults.temperature {
134 self.temperature = if temperature == 0.0 {
135 None
136 } else {
137 Some(temperature)
138 };
139 }
140 if let Some(top_k) = defaults.top_k {
141 self.top_k = if top_k == 0 { None } else { Some(top_k) };
142 }
143 if let Some(top_p) = defaults.top_p {
144 self.top_p = Some(top_p);
145 }
146 if let Some(min_p) = defaults.min_p {
147 self.min_p = Some(min_p);
148 }
149 if let Some(repetition_penalty) = defaults.repetition_penalty {
150 self.repetition_penalty = Some(repetition_penalty);
151 }
152 if let Some(max_new_tokens) = defaults.max_new_tokens {
153 self.max_len = Some(max_new_tokens);
154 }
155 }
156}
157
/// User-facing DRY ("don't repeat yourself") penalty settings.
///
/// Sequence breakers are given as text here; `Sampler::new` resolves them to token ids
/// (and drops DRY entirely when no tokenizer is available).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DrySamplingParams {
    /// Strings that a repetition match may not extend across.
    pub sequence_breakers: Vec<String>,
    /// Penalty scale; `0.0` disables DRY.
    pub multiplier: f32,
    /// Exponential growth base for longer matches.
    pub base: f32,
    /// Matches shorter than this are not penalised.
    pub allowed_length: usize,
}
166
167impl DrySamplingParams {
168 pub fn new_with_defaults(
169 multiplier: f32,
170 sequence_breakers: Option<Vec<String>>,
171 base: Option<f32>,
172 allowed_length: Option<usize>,
173 ) -> anyhow::Result<Self> {
174 Ok(Self {
175 base: base.unwrap_or(1.75),
176 allowed_length: allowed_length.unwrap_or(2),
177 sequence_breakers: sequence_breakers.unwrap_or(DRY_SEQUENCE_BREAKERS.clone()),
178 multiplier,
179 })
180 }
181}
182
183impl Default for DrySamplingParams {
184 fn default() -> Self {
185 Self {
186 multiplier: 0.0,
187 base: 1.75,
188 allowed_length: 2,
189 sequence_breakers: DRY_SEQUENCE_BREAKERS.clone(),
190 }
191 }
192}
193
/// DRY settings with sequence breakers already resolved to token ids.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids that stop match extension (and skip occurrences they follow).
    pub sequence_breakers: HashSet<u32>,
    /// Penalty scale; `0.0` makes `apply_dry_penalty` a no-op.
    pub multiplier: f32,
    /// Exponential growth base for longer matches.
    pub base: f32,
    /// Minimum match length that is penalised.
    pub allowed_length: usize,
}
201
202impl DrySamplingParamsInner {
203 pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
204 Ok(Self {
205 base: other.base,
206 allowed_length: other.allowed_length,
207 sequence_breakers: HashSet::from_iter(
208 other
209 .sequence_breakers
210 .into_iter()
211 .map(|breaker| {
212 tokenizer
213 .encode_fast(["a", &breaker].concat(), true)
219 .map_err(anyhow::Error::msg)
220 .map(|enc| {
221 let ids = enc.get_ids();
222 if !ids.is_empty() {
223 Some(ids[ids.len() - 1])
224 } else {
225 None
226 }
227 })
228 })
229 .collect::<anyhow::Result<Vec<_>>>()?
230 .into_iter()
231 .flatten()
232 .collect::<Vec<_>>(),
233 ),
234 multiplier: other.multiplier,
235 })
236 }
237}
238
/// A user-supplied transformation of the logits that runs before token selection.
///
/// On the CPU path `Sampler` applies every registered processor in order, after its
/// built-in penalties and before temperature scaling. Registering any processor also
/// disables the CUDA/Metal on-device top-k path.
pub trait CustomLogitsProcessor: Send + Sync {
    /// Returns the transformed logits, given the current `logits` and the token `context`.
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}
262
263impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
264 fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
265 self(logits, context)
266 }
267}
268
/// Next-token sampler.
///
/// Applies penalties and custom logits processors, then selects a token either greedily
/// (no temperature) or from the temperature-scaled softmax after top-k / top-p / min-p
/// filtering.
#[derive(Clone)]
pub struct Sampler {
    /// `None` means greedy (argmax); `Sampler::new` maps values below 1e-7 to `None`.
    temperature: Option<f64>,
    /// Number of alternatives reported when logprobs are requested.
    top_n_logprobs: usize,
    /// Decodes sampled tokens to text and resolves DRY breakers; optional.
    tokenizer: Option<Arc<Tokenizer>>,
    /// Subtracted once per occurrence of a token in the context.
    frequency_penalty: Option<f32>,
    /// Subtracted once from every token that occurs in the context.
    presence_penalty: Option<f32>,
    /// Divides positive / multiplies non-positive logits of tokens seen in the context.
    repetition_penalty: Option<f32>,
    /// DRY settings with breakers as token ids; `None` when no tokenizer was given.
    dry_params: Option<DrySamplingParamsInner>,
    /// Number of best tokens kept; values `<= 0` keep the whole vocabulary.
    top_k: i64,
    /// Nucleus threshold; the regular sampling paths treat values outside (0, 1) as off.
    top_p: f64,
    /// Min-p ratio; the regular sampling paths only apply it when top-p is active and
    /// the value lies in (0, 1).
    min_p: f64,
    /// Extra logits transforms, run in order after the built-in penalties.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
}
284
/// One alternative in the "top logprobs" list attached to a sampled token.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TopLogprob {
    /// Token id.
    pub token: u32,
    /// Base-10 logarithm of the value the sampler ranked this token by.
    pub logprob: f32,
    /// The token decoded by the sampler's tokenizer, when one is configured.
    pub bytes: Option<String>,
}
294
/// The sampled token together with its score and, optionally, the top alternatives.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
    /// Sampled token id.
    pub token: u32,
    /// Base-10 logarithm of the selected token's value. On the greedy path this value is
    /// a processed logit, not a normalised probability.
    pub logprob: f32,
    /// The token decoded by the sampler's tokenizer, when one is configured.
    pub bytes: Option<String>,
    /// Top alternatives; `None` unless logprobs were requested.
    pub top_logprobs: Option<Vec<TopLogprob>>,
}
302
/// Comparator that orders `(token, prob)` pairs by descending probability.
///
/// Uses `f32::total_cmp` so the ordering is total even when NaNs are present: positive
/// NaNs sort first and negative NaNs last. The previous
/// `partial_cmp(..).unwrap_or(Equal)` treated NaN as equal to every value, which breaks
/// transitivity. Since Rust 1.81, `sort_unstable_by` / `select_nth_unstable_by` may panic
/// on such comparators, which would pre-empt the descriptive NaN error in
/// `sample_multinomial`.
#[inline]
fn cmp_desc_by_prob(a: &(u32, f32), b: &(u32, f32)) -> std::cmp::Ordering {
    b.1.total_cmp(&a.1)
}
308
309fn partial_sort_top_k(probs: &mut [f32], k: usize, zero_rest: bool) -> Vec<(u32, f32)> {
315 let n = probs.len();
316 if n == 0 || k == 0 {
317 return Vec::new();
318 }
319
320 let mut idx_probs: Vec<(u32, f32)> = (0..n as u32).map(|i| (i, probs[i as usize])).collect();
322
323 let k = k.min(n);
324
325 if k < n {
326 idx_probs.select_nth_unstable_by(k - 1, cmp_desc_by_prob);
330
331 if zero_rest {
332 for (idx, _) in idx_probs[k..].iter() {
334 probs[*idx as usize] = 0.0;
335 }
336 }
337
338 idx_probs.truncate(k);
340 }
341
342 idx_probs.sort_unstable_by(cmp_desc_by_prob);
344
345 idx_probs
346}
347
/// Index of the largest value in `values`, or 0 for an empty slice.
///
/// Ties go to the last maximal element. A NaN never "beats" the current best, but the
/// element following it replaces it, mirroring `Iterator::max_by` with a
/// `partial_cmp`-or-`Equal` comparator.
#[inline]
fn argmax_f32(values: &[f32]) -> u32 {
    let mut best: Option<(usize, f32)> = None;
    for (position, &value) in values.iter().enumerate() {
        match best {
            Some((_, current)) if current > value => {}
            _ => best = Some((position, value)),
        }
    }
    best.map_or(0, |(position, _)| position as u32)
}
358
359impl Sampler {
360 #[allow(clippy::too_many_arguments)]
361 pub fn new(
362 temperature: Option<f64>,
363 top_n_logprobs: usize,
364 tokenizer: Option<Arc<Tokenizer>>,
365 frequency_penalty: Option<f32>,
366 presence_penalty: Option<f32>,
367 repetition_penalty: Option<f32>,
368 dry_params: Option<DrySamplingParams>,
369 top_k: i64,
370 top_p: f64,
371 min_p: f64,
372 logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
373 ) -> anyhow::Result<Self> {
374 let temperature = if temperature.is_none_or(|v| v < 1e-7) {
375 None
376 } else {
377 temperature
378 };
379 let dry_params = if let Some(ref tokenizer) = tokenizer {
380 dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
381 } else {
382 None
383 };
384 let dry_params = match dry_params {
385 Some(fallible) => Some(fallible?),
386 None => None,
387 };
388 Ok(Self {
389 temperature,
390 top_n_logprobs,
391 tokenizer,
392 frequency_penalty,
393 presence_penalty,
394 repetition_penalty,
395 dry_params,
396 top_k,
397 top_p,
398 min_p,
399 logits_processors,
400 })
401 }
402
403 pub fn is_argmax(&self) -> bool {
404 self.temperature.is_none()
405 }
406
407 fn get_top_logprobs(&self, probs: &[f32]) -> Result<Vec<TopLogprob>> {
408 let k = self.top_n_logprobs.min(probs.len());
409 if k == 0 {
410 return Ok(Vec::new());
411 }
412
413 let mut probs_copy = probs.to_vec();
415 let top_k = partial_sort_top_k(&mut probs_copy, k, false);
416
417 let mut result = Vec::with_capacity(k);
419 if let Some(tokenizer) = &self.tokenizer {
420 for (token, prob) in top_k {
421 let decoded = tokenizer
422 .decode(&[token], false)
423 .map_err(|e| Error::Msg(e.to_string()))?;
424 result.push(TopLogprob {
425 token,
426 logprob: prob.log(10.0),
427 bytes: Some(decoded),
428 });
429 }
430 } else {
431 for (token, prob) in top_k {
432 result.push(TopLogprob {
433 token,
434 logprob: prob.log(10.0),
435 bytes: None,
436 });
437 }
438 }
439 Ok(result)
440 }
441
442 fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
443 let probs: Vec<f32> = logits.to_vec1()?;
444 let next_token = argmax_f32(&probs);
445 let logprob = probs[next_token as usize].log(10.0);
446
447 let top_logprobs = if return_logprobs {
448 Some(self.get_top_logprobs(&probs)?)
449 } else {
450 None
451 };
452
453 let bytes = if let Some(tokenizer) = &self.tokenizer {
454 Some(
455 tokenizer
456 .decode(&[next_token], false)
457 .map_err(|x| Error::Msg(x.to_string()))?,
458 )
459 } else {
460 None
461 };
462
463 Ok(Logprobs {
464 token: next_token,
465 logprob,
466 top_logprobs,
467 bytes,
468 })
469 }
470
471 fn sample_speculative_top_kp_min_p(
472 &self,
473 logits: Tensor,
474 return_logprobs: bool,
475 top_k: i64,
476 top_p: f32,
477 min_p: f32,
478 ) -> Result<Logprobs> {
479 let mut probs: Vec<f32> = logits.to_vec1()?;
480
481 let k = if top_k > 0 {
483 top_k as usize
484 } else {
485 probs.len()
486 };
487
488 let idx_probs = partial_sort_top_k(&mut probs, k, true);
490
491 let mut cumsum = 0.;
498 for (index, prob) in &idx_probs {
499 if cumsum >= top_p {
500 probs[*index as usize] = 0.0;
501 } else {
502 cumsum += prob;
503 }
504 }
505
506 let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
508
509 let min_p_threshold = max_p * min_p;
515 for (index, prob) in &idx_probs {
516 if min_p_threshold >= *prob {
517 probs[*index as usize] = 0.0;
518 }
519 }
520
521 let next_token = argmax_f32(&probs);
523 let logprob = probs[next_token as usize].log(10.0);
524
525 let top_logprobs = if return_logprobs {
526 Some(self.get_top_logprobs(&probs)?)
527 } else {
528 None
529 };
530
531 let bytes = if let Some(tokenizer) = &self.tokenizer {
532 Some(
533 tokenizer
534 .decode(&[next_token], false)
535 .map_err(|x| Error::Msg(x.to_string()))?,
536 )
537 } else {
538 None
539 };
540
541 Ok(Logprobs {
542 token: next_token,
543 logprob,
544 top_logprobs,
545 bytes,
546 })
547 }
548
549 fn sample_multinomial(
550 &self,
551 probs: &[f32],
552 return_logprobs: bool,
553 rng: Arc<Mutex<Isaac64Rng>>,
554 ) -> Result<Logprobs> {
555 let distr = match WeightedIndex::new(probs) {
556 Ok(distr) => distr,
557 Err(e) => {
558 if let Some((idx, prob)) = probs
559 .iter()
560 .enumerate()
561 .find(|(_, prob)| !prob.is_finite() || **prob < 0.0)
562 {
563 return Err(Error::Msg(format!(
564 "Invalid sampling probability at index {idx}: {prob}. The model likely produced NaN/Inf logits."
565 )));
566 }
567
568 let positive_weight_sum: f64 = probs
569 .iter()
570 .copied()
571 .filter(|prob| prob.is_finite() && *prob > 0.0)
572 .map(f64::from)
573 .sum();
574
575 if positive_weight_sum == 0.0 {
576 return Err(Error::Msg(
577 "All sampling probabilities are zero after filtering (top-k/top-p/min-p)."
578 .to_string(),
579 ));
580 }
581
582 return Err(Error::Msg(format!(
583 "Failed to construct multinomial sampler: {e}"
584 )));
585 }
586 };
587
588 let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
589 let next_token = distr.sample(&mut mut_ref_rng); let logprob = probs[next_token].log(10.0);
591
592 let top_logprobs = if return_logprobs {
593 Some(self.get_top_logprobs(probs)?)
594 } else {
595 None
596 };
597
598 let bytes = if let Some(tokenizer) = &self.tokenizer {
599 Some(
600 tokenizer
601 .decode(&[next_token.try_into().unwrap()], false)
602 .map_err(|x| Error::Msg(x.to_string()))?,
603 )
604 } else {
605 None
606 };
607
608 Ok(Logprobs {
609 token: next_token as u32,
610 logprob,
611 top_logprobs,
612 bytes,
613 })
614 }
615
616 #[cfg(any(feature = "cuda", feature = "metal"))]
617 fn can_sample_topk_on_device(
618 &self,
619 return_logprobs: bool,
620 sample_speculative: bool,
621 multiple_sequences: bool,
622 ) -> bool {
623 const MAX_DEVICE_TOP_K: i64 = 128;
624
625 !return_logprobs
626 && !sample_speculative
627 && !multiple_sequences
628 && self.temperature.is_some()
629 && self.top_k > 0
630 && self.top_k <= MAX_DEVICE_TOP_K
631 && self.logits_processors.is_empty()
632 && self
633 .dry_params
634 .as_ref()
635 .is_none_or(|params| params.multiplier.abs() <= f32::EPSILON)
636 }
637
    /// CUDA counterpart of the CPU frequency/presence/repetition penalties.
    ///
    /// Returns `logits` unchanged when every penalty is neutral (within `f32::EPSILON` of
    /// 0, 0 and 1 respectively) or when no context token falls inside the vocabulary.
    /// Otherwise it builds sparse `(token_id, count)` tensors on the logits' device and
    /// hands them to `crate::ops::cuda_apply_sparse_penalties_f32`.
    ///
    /// # Errors
    /// Fails when a penalty is needed but `context` is empty.
    /// NOTE(review): the Metal variant returns the logits unchanged on an empty context
    /// instead of failing — confirm which behaviour is intended.
    #[cfg(feature = "cuda")]
    fn apply_device_sparse_penalties_if_needed(
        &self,
        logits: Tensor,
        context: &[u32],
    ) -> Result<Tensor> {
        let frequency_penalty = self.frequency_penalty.unwrap_or(0.0);
        let presence_penalty = self.presence_penalty.unwrap_or(0.0);
        let repetition_penalty = self.repetition_penalty.unwrap_or(1.0);
        let needs_penalty = frequency_penalty.abs() > f32::EPSILON
            || presence_penalty.abs() > f32::EPSILON
            || (repetition_penalty - 1.0).abs() > f32::EPSILON;

        if !needs_penalty {
            return Ok(logits);
        }
        if context.is_empty() {
            hanzo_ml::bail!("Penalty context is empty, this should not happen.");
        }

        // Occurrence count per in-vocabulary context token; out-of-range ids are ignored.
        let vocab_size = logits.elem_count();
        let mut counts = HashMap::<u32, f32>::with_capacity(context.len().min(vocab_size));
        for &token_id in context {
            if token_id as usize >= vocab_size {
                continue;
            }
            *counts.entry(token_id).or_insert(0.0) += 1.0;
        }

        if counts.is_empty() {
            return Ok(logits);
        }

        // Split into parallel id / count vectors (HashMap iteration order, so unordered).
        let n_tokens = counts.len();
        let mut token_ids = Vec::with_capacity(n_tokens);
        let mut token_counts = Vec::with_capacity(n_tokens);
        for (token_id, count) in counts {
            token_ids.push(token_id);
            token_counts.push(count);
        }

        let device = logits.device();
        let token_ids = Tensor::from_vec(token_ids, n_tokens, device)?;
        let token_counts = Tensor::from_vec(token_counts, n_tokens, device)?;
        crate::ops::cuda_apply_sparse_penalties_f32(
            &logits,
            &token_ids,
            &token_counts,
            frequency_penalty,
            presence_penalty,
            repetition_penalty,
        )
    }
691
    /// CUDA fast path: top-k on the device, then top-p / min-p and the multinomial draw
    /// on the CPU over just the `k` candidates.
    ///
    /// The kernel output is checked to have length `2 * k + 2`, laid out as `k` top
    /// values, `k` token indices (stored as `f32`), then `[denom, global_max]`.
    /// Probabilities are recovered as `exp(value / temperature - global_max) / denom`.
    /// This presumes the values are unscaled logits in descending order and that `denom`
    /// normalises over the full vocabulary — TODO confirm against the kernel. Min-p is
    /// only applied when top-p is active (it is nested inside the top-p branch).
    ///
    /// NOTE(review): token indices round-trip through `f32`, which is exact only up to
    /// 2^24; confirm vocabularies stay below that.
    ///
    /// # Errors
    /// Fails on a malformed kernel output, a non-finite or non-positive normaliser, or
    /// when every candidate is filtered out. Also propagates tensor and tokenizer errors.
    #[cfg(feature = "cuda")]
    fn sample_topk_on_device(
        &self,
        logits: Tensor,
        temperature: f64,
        rng: Arc<Mutex<Isaac64Rng>>,
    ) -> Result<Logprobs> {
        let topk =
            crate::ops::cuda_topk_logits_f32_packed(&logits, self.top_k as usize, temperature)?;
        let packed = topk.packed.to_vec1::<f32>()?;
        let k = topk.k;
        if packed.len() != 2 * k + 2 {
            hanzo_ml::bail!(
                "invalid CUDA top-k packed output length {}, expected {}",
                packed.len(),
                2 * k + 2
            );
        }
        let top_values = &packed[..k];
        let top_indices = packed[k..2 * k]
            .iter()
            .map(|idx| *idx as u32)
            .collect::<Vec<_>>();
        let softmax_info = &packed[2 * k..2 * k + 2];

        let denom = softmax_info[0];
        let global_max = softmax_info[1];
        if denom <= 0.0 || !denom.is_finite() || !global_max.is_finite() {
            hanzo_ml::bail!("invalid CUDA top-k softmax normalizer");
        }

        let inv_temperature = (1.0 / temperature) as f32;
        let mut probs = top_values
            .iter()
            .map(|value| ((*value * inv_temperature - global_max).exp()) / denom)
            .collect::<Vec<_>>();

        // Assumes the candidates arrive best-first, so a running sum implements top-p.
        if self.top_p > 0.0 && self.top_p < 1.0 {
            let mut cumsum = 0.0f32;
            for prob in &mut probs {
                if cumsum >= self.top_p as f32 {
                    *prob = 0.0;
                } else {
                    cumsum += *prob;
                }
            }

            if self.min_p > 0.0 && self.min_p < 1.0 {
                let max_p = probs.first().copied().unwrap_or(0.0);
                let min_p_threshold = max_p * self.min_p as f32;
                for prob in &mut probs {
                    if min_p_threshold >= *prob {
                        *prob = 0.0;
                    }
                }
            }
        }

        let distr = match WeightedIndex::new(&probs) {
            Ok(distr) => distr,
            Err(e) => {
                let positive_weight_sum: f64 = probs
                    .iter()
                    .copied()
                    .filter(|prob| prob.is_finite() && *prob > 0.0)
                    .map(f64::from)
                    .sum();
                if positive_weight_sum == 0.0 {
                    return Err(Error::Msg(
                        "All sampling probabilities are zero after CUDA top-k filtering."
                            .to_string(),
                    ));
                }

                return Err(Error::Msg(format!(
                    "Failed to construct CUDA top-k multinomial sampler: {e}"
                )));
            }
        };

        // `selected` indexes the candidate list; map it back to a vocabulary id.
        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
        let selected = distr.sample(&mut mut_ref_rng);
        let next_token = top_indices[selected];
        let logprob = probs[selected].log(10.0);

        let bytes = if let Some(tokenizer) = &self.tokenizer {
            Some(
                tokenizer
                    .decode(&[next_token], false)
                    .map_err(|x| Error::Msg(x.to_string()))?,
            )
        } else {
            None
        };

        Ok(Logprobs {
            token: next_token,
            logprob,
            top_logprobs: None,
            bytes,
        })
    }
794
    /// Metal counterpart of the CPU frequency/presence/repetition penalties.
    ///
    /// Returns `logits` unchanged when every penalty is neutral (within `f32::EPSILON`),
    /// when `context` is empty, or when no context token falls inside the vocabulary.
    /// Otherwise it builds sparse `(token_id, count)` tensors on the logits' device and
    /// calls `crate::ops::metal_apply_sparse_penalties`.
    /// NOTE(review): unlike the CUDA variant, an empty context is not treated as an
    /// error here — confirm which behaviour is intended.
    #[cfg(feature = "metal")]
    fn apply_device_sparse_penalties_if_needed_metal(
        &self,
        logits: Tensor,
        context: &[u32],
    ) -> Result<Tensor> {
        let frequency_penalty = self.frequency_penalty.unwrap_or(0.0);
        let presence_penalty = self.presence_penalty.unwrap_or(0.0);
        let repetition_penalty = self.repetition_penalty.unwrap_or(1.0);
        let needs_penalty = frequency_penalty.abs() > f32::EPSILON
            || presence_penalty.abs() > f32::EPSILON
            || (repetition_penalty - 1.0).abs() > f32::EPSILON;
        if !needs_penalty || context.is_empty() {
            return Ok(logits);
        }
        // Occurrence count per in-vocabulary context token; out-of-range ids are ignored.
        let vocab_size = logits.elem_count();
        let mut counts = HashMap::<u32, f32>::with_capacity(context.len().min(vocab_size));
        for &tid in context {
            if (tid as usize) >= vocab_size {
                continue;
            }
            *counts.entry(tid).or_insert(0.0) += 1.0;
        }
        if counts.is_empty() {
            return Ok(logits);
        }
        // Split into parallel id / count vectors (HashMap iteration order, so unordered).
        let n_tokens = counts.len();
        let mut token_ids = Vec::with_capacity(n_tokens);
        let mut token_counts = Vec::with_capacity(n_tokens);
        for (tid, c) in counts {
            token_ids.push(tid);
            token_counts.push(c);
        }
        let device = logits.device();
        let token_ids = Tensor::from_vec(token_ids, n_tokens, device)?;
        let token_counts = Tensor::from_vec(token_counts, n_tokens, device)?;
        crate::ops::metal_apply_sparse_penalties(
            &logits,
            &token_ids,
            &token_counts,
            frequency_penalty,
            presence_penalty,
            repetition_penalty,
        )
    }
840
    /// Metal fast path: top-k on the device, then top-p / min-p and the multinomial draw
    /// on the CPU over the `k` candidates. This mirrors the CUDA version.
    ///
    /// The kernel output is checked to have length `2 * k + 2`: `k` values, `k` indices
    /// stored as `f32`, then `[denom, global_max]`. Probabilities are
    /// `exp(value / temperature - global_max) / denom`, which presumes unscaled logits in
    /// descending order — TODO confirm against the kernel. Min-p only runs when top-p is
    /// active.
    ///
    /// NOTE(review): indices round-trip through `f32` (exact only up to 2^24).
    ///
    /// # Errors
    /// Fails on a malformed kernel output, an invalid normaliser, or when every candidate
    /// is filtered out. Also propagates tensor and tokenizer errors.
    #[cfg(feature = "metal")]
    fn sample_topk_on_device_metal(
        &self,
        logits: Tensor,
        temperature: f64,
        rng: Arc<Mutex<Isaac64Rng>>,
    ) -> Result<Logprobs> {
        let topk = crate::ops::metal_topk_logits_packed(&logits, self.top_k as usize, temperature)?;
        let packed = topk.packed.to_vec1::<f32>()?;
        let k = topk.k;
        if packed.len() != 2 * k + 2 {
            hanzo_ml::bail!(
                "invalid Metal top-k packed output length {}, expected {}",
                packed.len(),
                2 * k + 2
            );
        }
        let top_values = &packed[..k];
        let top_indices = packed[k..2 * k]
            .iter()
            .map(|idx| *idx as u32)
            .collect::<Vec<_>>();
        let softmax_info = &packed[2 * k..2 * k + 2];
        let denom = softmax_info[0];
        let global_max = softmax_info[1];
        if denom <= 0.0 || !denom.is_finite() || !global_max.is_finite() {
            hanzo_ml::bail!("invalid Metal top-k softmax normalizer");
        }

        let inv_temperature = (1.0 / temperature) as f32;
        let mut probs = top_values
            .iter()
            .map(|value| ((*value * inv_temperature - global_max).exp()) / denom)
            .collect::<Vec<_>>();

        // Assumes the candidates arrive best-first, so a running sum implements top-p.
        if self.top_p > 0.0 && self.top_p < 1.0 {
            let mut cumsum = 0.0f32;
            for prob in &mut probs {
                if cumsum >= self.top_p as f32 {
                    *prob = 0.0;
                } else {
                    cumsum += *prob;
                }
            }
            if self.min_p > 0.0 && self.min_p < 1.0 {
                let max_p = probs.first().copied().unwrap_or(0.0);
                let min_p_threshold = max_p * self.min_p as f32;
                for prob in &mut probs {
                    if min_p_threshold >= *prob {
                        *prob = 0.0;
                    }
                }
            }
        }

        let distr = match WeightedIndex::new(&probs) {
            Ok(distr) => distr,
            Err(e) => {
                let positive_weight_sum: f64 = probs
                    .iter()
                    .copied()
                    .filter(|prob| prob.is_finite() && *prob > 0.0)
                    .map(f64::from)
                    .sum();
                if positive_weight_sum == 0.0 {
                    return Err(Error::Msg(
                        "All sampling probabilities are zero after Metal top-k filtering."
                            .to_string(),
                    ));
                }
                return Err(Error::Msg(format!(
                    "Failed to construct Metal top-k multinomial sampler: {e}"
                )));
            }
        };

        // `selected` indexes the candidate list; map it back to a vocabulary id.
        let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
        let selected = distr.sample(&mut mut_ref_rng);
        let next_token = top_indices[selected];
        let logprob = probs[selected].log(10.0);
        let bytes = if let Some(tokenizer) = &self.tokenizer {
            Some(
                tokenizer
                    .decode(&[next_token], false)
                    .map_err(|x| Error::Msg(x.to_string()))?,
            )
        } else {
            None
        };
        Ok(Logprobs {
            token: next_token,
            logprob,
            top_logprobs: None,
            bytes,
        })
    }
937
938 fn filter_top_kp_min_p(&self, probs: &mut [f32]) {
939 let k = if self.top_k > 0 {
940 self.top_k as usize
941 } else {
942 probs.len()
943 };
944
945 let idx_probs = partial_sort_top_k(probs, k, true);
946
947 if self.top_p <= 0.0 || self.top_p >= 1.0 {
948 return;
949 }
950
951 let mut cumsum = 0.0f32;
952 for (index, prob) in &idx_probs {
953 if cumsum >= self.top_p as f32 {
954 probs[*index as usize] = 0.0;
955 } else {
956 cumsum += prob;
957 }
958 }
959
960 if self.min_p <= 0.0 || self.min_p >= 1.0 {
961 return;
962 }
963
964 let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
965 let min_p_threshold = max_p * self.min_p as f32;
966 for (index, prob) in &idx_probs {
967 if min_p_threshold >= *prob {
968 probs[*index as usize] = 0.0;
969 }
970 }
971 }
972
973 fn normalize_probs(probs: &mut [f32]) -> Result<()> {
974 let sum: f32 = probs
975 .iter()
976 .copied()
977 .filter(|prob| prob.is_finite() && *prob > 0.0)
978 .sum();
979 if sum <= 0.0 {
980 hanzo_ml::bail!("all probabilities are zero in speculative sampling");
981 }
982 for prob in probs.iter_mut() {
983 if prob.is_finite() && *prob > 0.0 {
984 *prob /= sum;
985 } else {
986 *prob = 0.0;
987 }
988 }
989 Ok(())
990 }
991
992 pub(crate) fn speculative_target_probs(
993 &self,
994 logits: Tensor,
995 context: &[u32],
996 ) -> Result<Vec<f32>> {
997 self.speculative_probs(logits, context)
998 }
999
1000 pub(crate) fn speculative_candidate_probs(
1001 &self,
1002 logits: Tensor,
1003 context: &[u32],
1004 ) -> Result<Vec<f32>> {
1005 self.speculative_probs(logits, context)
1006 }
1007
1008 fn speculative_probs(&self, logits: Tensor, context: &[u32]) -> Result<Vec<f32>> {
1009 let logits = logits.to_vec1()?;
1010 let mut logits = self.apply_penalties(logits, context)?;
1011 for processor in &self.logits_processors {
1012 logits = processor.apply(&logits, context)?;
1013 }
1014
1015 let mut probs = match self.temperature {
1016 None => {
1017 let logits = logits.to_vec1::<f32>()?;
1018 let mut probs = vec![0.0; logits.len()];
1019 probs[argmax_f32(&logits) as usize] = 1.0;
1020 probs
1021 }
1022 Some(temperature) => {
1023 let logits = (&logits / temperature)?;
1024 hanzo_nn::ops::softmax_last_dim(&logits)?.to_vec1::<f32>()?
1025 }
1026 };
1027 self.filter_top_kp_min_p(&mut probs);
1028 Self::normalize_probs(&mut probs)?;
1029 Ok(probs)
1030 }
1031
1032 pub(crate) fn logprobs_from_probs(
1033 &self,
1034 token: u32,
1035 probs: &[f32],
1036 return_logprobs: bool,
1037 ) -> Result<Logprobs> {
1038 let prob = probs.get(token as usize).copied().unwrap_or(0.0);
1039 let logprob = if prob > 0.0 {
1040 prob.log(10.0)
1041 } else {
1042 f32::NEG_INFINITY
1043 };
1044 let top_logprobs = if return_logprobs {
1045 Some(self.get_top_logprobs(probs)?)
1046 } else {
1047 None
1048 };
1049 let bytes = if let Some(tokenizer) = &self.tokenizer {
1050 Some(
1051 tokenizer
1052 .decode(&[token], false)
1053 .map_err(|x| Error::Msg(x.to_string()))?,
1054 )
1055 } else {
1056 None
1057 };
1058 Ok(Logprobs {
1059 token,
1060 logprob,
1061 top_logprobs,
1062 bytes,
1063 })
1064 }
1065
1066 pub(crate) fn sample_from_probs(
1067 &self,
1068 probs: &[f32],
1069 return_logprobs: bool,
1070 rng: Arc<Mutex<Isaac64Rng>>,
1071 ) -> Result<Logprobs> {
1072 self.sample_multinomial(probs, return_logprobs, rng)
1073 }
1074
1075 #[allow(clippy::too_many_arguments)]
1076 fn sample_top_kp_min_p(
1077 &self,
1078 probs: &mut [f32],
1079 top_k: i64,
1080 top_p: f32,
1081 min_p: f32,
1082 return_logprobs: bool,
1083 rng: Arc<Mutex<Isaac64Rng>>,
1084 ) -> Result<Logprobs> {
1085 let k = if top_k > 0 {
1087 top_k as usize
1088 } else {
1089 probs.len()
1090 };
1091
1092 let idx_probs = partial_sort_top_k(probs, k, true);
1094
1095 if top_p <= 0.0 || top_p >= 1.0 {
1096 return self.sample_multinomial(probs, return_logprobs, rng);
1097 }
1098
1099 let mut cumsum = 0.;
1107 for (index, prob) in &idx_probs {
1108 if cumsum >= top_p {
1109 probs[*index as usize] = 0.0;
1110 } else {
1111 cumsum += prob;
1112 }
1113 }
1114
1115 if min_p <= 0.0 || min_p >= 1.0 {
1116 return self.sample_multinomial(probs, return_logprobs, rng);
1117 }
1118
1119 let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
1121
1122 let min_p_threshold = max_p * min_p;
1129 for (index, prob) in &idx_probs {
1130 if min_p_threshold >= *prob {
1131 probs[*index as usize] = 0.0;
1132 }
1133 }
1134
1135 self.sample_multinomial(probs, return_logprobs, rng)
1137 }
1138
1139 fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
1140 if context.is_empty() {
1141 hanzo_ml::bail!("Penalty context is empty, this should not happen.");
1142 }
1143
1144 self.apply_dry_penalty(&mut logits, context)?;
1146
1147 self.apply_freq_pres_rep_penalty(&mut logits, context)?;
1149
1150 let vocab_size = logits.len();
1151 Tensor::from_vec(logits, vocab_size, &Device::Cpu)
1152 }
1153
1154 fn apply_freq_pres_rep_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
1155 if self.frequency_penalty.is_some()
1156 || self.presence_penalty.is_some()
1157 || self.repetition_penalty.is_some()
1158 {
1159 let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
1160 let presence_penalty = self.presence_penalty.unwrap_or(0.);
1161 let repetition_penalty = self.repetition_penalty.unwrap_or(1.);
1162
1163 let mut counts = vec![0.0f32; logits.len()];
1166 for ctx in context.iter() {
1167 if *ctx as usize >= logits.len() {
1169 continue;
1170 }
1171 counts[*ctx as usize] += 1.0;
1172 }
1173
1174 for (token_id, logit) in logits.iter_mut().enumerate() {
1175 let count = counts[token_id];
1176 *logit = *logit
1177 - count * frequency_penalty
1178 - if count > 0.0 { 1. } else { 0. } * presence_penalty;
1179
1180 if repetition_penalty != 1.0 && count > 0.0 {
1181 if *logit > 0.0 {
1182 *logit /= repetition_penalty;
1183 } else {
1184 *logit *= repetition_penalty;
1185 }
1186 }
1187 }
1188 }
1189 Ok(())
1190 }
1191
1192 const DRY_PENALTY_PAR_THRESHOLD: usize = 1024;
1195
1196 fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
1197 if let Some(ref params) = self.dry_params {
1198 if params.multiplier == 0. {
1199 return Ok(());
1200 }
1201
1202 let last_token = *context.last().unwrap();
1203
1204 let match_indices: Vec<usize> = if context.len() > Self::DRY_PENALTY_PAR_THRESHOLD {
1206 context
1207 .par_iter()
1208 .enumerate()
1209 .take(context.len() - 1)
1210 .filter(|(_i, x)| last_token == **x)
1211 .map(|(i, _)| i)
1212 .collect()
1213 } else {
1214 context
1215 .iter()
1216 .enumerate()
1217 .take(context.len() - 1)
1218 .filter(|(_i, x)| last_token == **x)
1219 .map(|(i, _)| i)
1220 .collect()
1221 };
1222
1223 let mut match_lengths = HashMap::new();
1224
1225 for i in match_indices {
1226 let next_token = context[i + 1];
1227
1228 if params.sequence_breakers.contains(&next_token) {
1229 continue;
1230 }
1231
1232 let mut match_length = 1;
1233
1234 while match_length < 50 {
1236 if match_length > i {
1237 break;
1239 }
1240
1241 let j = i - match_length;
1242
1243 let prev_tok = context[context.len() - (match_length + 1)];
1244 if context[j] != prev_tok {
1245 break;
1247 }
1248
1249 if params.sequence_breakers.contains(&prev_tok) {
1250 break;
1252 }
1253
1254 match_length += 1;
1255 }
1256
1257 #[allow(clippy::map_entry)]
1258 if match_lengths.contains_key(&next_token) {
1259 match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
1260 } else {
1261 match_lengths.insert(next_token, match_length);
1262 }
1263 }
1264
1265 for (tok, match_len) in match_lengths {
1267 if match_len >= params.allowed_length {
1268 if tok as usize >= logits.len() {
1270 continue;
1271 }
1272 let penalty = params.multiplier
1273 * params.base.powf((match_len - params.allowed_length) as f32);
1274 logits[tok as usize] -= penalty;
1275 }
1276 }
1277 }
1278 Ok(())
1279 }
1280
1281 #[allow(unused)]
1282 pub fn sample(
1287 &self,
1288 logits: Tensor,
1289 context: &[u32],
1290 return_logprobs: bool,
1291 rng: Arc<Mutex<Isaac64Rng>>,
1292 sample_speculative: bool,
1293 multiple_sequences: bool,
1294 ) -> Result<Logprobs> {
1295 #[cfg(feature = "cuda")]
1296 if logits.device().is_cuda()
1297 && self.can_sample_topk_on_device(
1298 return_logprobs,
1299 sample_speculative,
1300 multiple_sequences,
1301 )
1302 {
1303 if let Some(temperature) = self.temperature {
1304 let logits = self.apply_device_sparse_penalties_if_needed(logits, context)?;
1305 return self.sample_topk_on_device(logits, temperature, rng);
1306 }
1307 }
1308
1309 #[cfg(feature = "metal")]
1310 if logits.device().is_metal()
1311 && self.can_sample_topk_on_device(
1312 return_logprobs,
1313 sample_speculative,
1314 multiple_sequences,
1315 )
1316 {
1317 if let Some(temperature) = self.temperature {
1318 let logits = self.apply_device_sparse_penalties_if_needed_metal(logits, context)?;
1319 return self.sample_topk_on_device_metal(logits, temperature, rng);
1320 }
1321 }
1322
1323 let logits = logits.to_vec1()?;
1324 let mut logits = self.apply_penalties(logits, context)?;
1325 for processor in &self.logits_processors {
1326 logits = processor.apply(&logits, context)?;
1327 }
1328 let next_token = if sample_speculative {
1329 match self.temperature {
1330 None => self.sample_speculative_top_kp_min_p(
1331 logits,
1332 return_logprobs,
1333 self.top_k,
1334 self.top_p as f32,
1335 self.min_p as f32,
1336 )?,
1337 Some(temperature) => {
1338 let logits = (&logits / temperature)?;
1339 let probs = hanzo_nn::ops::softmax_last_dim(&logits)?;
1340
1341 self.sample_speculative_top_kp_min_p(
1342 probs,
1343 return_logprobs,
1344 self.top_k,
1345 self.top_p as f32,
1346 self.min_p as f32,
1347 )?
1348 }
1349 }
1350 } else {
1351 match self.temperature {
1352 None => self.sample_argmax(logits, return_logprobs)?,
1353 Some(temperature) => {
1354 let logits = (&logits / temperature)?;
1355 let probs = hanzo_nn::ops::softmax_last_dim(&logits)?;
1356 let mut probs: Vec<f32> = probs.to_vec1()?;
1357
1358 self.sample_top_kp_min_p(
1359 &mut probs,
1360 self.top_k,
1361 self.top_p as f32,
1362 self.min_p as f32,
1363 return_logprobs,
1364 rng,
1365 )?
1366 }
1367 }
1368 };
1369 Ok(next_token)
1370 }
1371}
1372
#[cfg(test)]
mod tests {
    //! Unit tests for `Sampler::sample`, the speculative probability helpers,
    //! and `SamplingParams::apply_model_defaults`.

    use std::sync::{Arc, Mutex};

    use hanzo_ml::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;

    use super::{Logprobs, ModelGenerationDefaults, Sampler, SamplingParams};

    /// Shared driver for the two "pick the largest logit" tests.
    ///
    /// Builds the sampler configuration both tests use (first argument `None`),
    /// feeds it the strictly increasing logits `0, 1, ..., 1023` with the
    /// context `0..1024` and an RNG seeded with 42, and returns what
    /// `Sampler::sample` produced. `return_logprobs` and `multiple_sequences`
    /// are both `false`; only `sample_speculative` varies between callers.
    fn sample_ascending_logits(sample_speculative: bool) -> Logprobs {
        let sampler = Sampler::new(
            None,
            10,
            None,
            None,
            None,
            None,
            None,
            32,
            0.1,
            0.05,
            vec![],
        )
        .unwrap();

        let ascending = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let context: Vec<u32> = (0..1024).collect();
        let seeded_rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));

        sampler
            .sample(
                ascending,
                &context,
                false,
                seeded_rng,
                sample_speculative,
                false,
            )
            .unwrap()
    }

    #[test]
    fn test_argmax() {
        // Non-speculative path: the last (largest) logit must win.
        let picked = sample_ascending_logits(false);

        assert_eq!(picked.token, 1023);
        assert_eq!(picked.top_logprobs, None);
        assert_eq!(picked.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_gumbel_speculative() {
        // Speculative path with the same configuration must agree with argmax.
        let picked = sample_ascending_logits(true);

        assert_eq!(picked.token, 1023);
        assert_eq!(picked.top_logprobs, None);
        assert_eq!(picked.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_speculative_candidate_probs_use_sampling_filters() {
        // The draft-side (candidate) distribution must be filtered exactly like
        // the target-side one. With this configuration (the positional `1` is
        // presumably top_k) all probability mass ends up on the largest logit.
        let sampler = Sampler::new(
            Some(1.0),
            10,
            None,
            None,
            None,
            None,
            None,
            1,
            1.0,
            0.0,
            vec![],
        )
        .unwrap();
        let three_logits = Tensor::from_vec(vec![0.0f32, 1.0, 2.0], 3, &Device::Cpu).unwrap();
        let context = [0u32];

        let target = sampler
            .speculative_target_probs(three_logits.clone(), &context)
            .unwrap();
        let candidate = sampler
            .speculative_candidate_probs(three_logits, &context)
            .unwrap();

        assert_eq!(candidate, target);
        assert_eq!(candidate, vec![0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_apply_model_defaults() {
        // Starting from neutral params, every provided default is copied over;
        // `max_new_tokens` lands in `max_len`.
        let defaults = ModelGenerationDefaults {
            do_sample: Some(true),
            temperature: Some(1.0),
            top_k: Some(32),
            top_p: Some(0.9),
            min_p: Some(0.05),
            repetition_penalty: Some(1.1),
            max_new_tokens: Some(256),
            max_length: None,
        };
        let mut params = SamplingParams::neutral();
        params.apply_model_defaults(&defaults);

        assert_eq!(params.temperature, Some(1.0));
        assert_eq!(params.top_k, Some(32));
        assert_eq!(params.top_p, Some(0.9));
        assert_eq!(params.min_p, Some(0.05));
        assert_eq!(params.repetition_penalty, Some(1.1));
        assert_eq!(params.max_len, Some(256));
    }

    #[test]
    fn test_apply_model_defaults_disables_sampling_when_requested() {
        // `do_sample: Some(false)` must override user sampling settings and
        // force greedy decoding (top_k = 1, other filters cleared).
        let greedy_only = ModelGenerationDefaults {
            do_sample: Some(false),
            ..ModelGenerationDefaults::default()
        };
        let mut params = SamplingParams {
            temperature: Some(0.7),
            top_k: Some(40),
            top_p: Some(0.9),
            min_p: Some(0.1),
            ..SamplingParams::neutral()
        };
        params.apply_model_defaults(&greedy_only);

        assert_eq!(params.temperature, None);
        assert_eq!(params.top_k, Some(1));
        assert_eq!(params.top_p, None);
        assert_eq!(params.min_p, None);
    }
}