#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::{
collections::{HashMap, HashSet},
sync::{Arc, LazyLock, Mutex},
};
use hanzo_ml::{Device, Error, Result, Tensor};
#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;
use rand::distr::{weighted::WeightedIndex, Distribution};
use rand_isaac::Isaac64Rng;
use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer;
/// Default DRY sequence breakers: newline, colon, double quote and asterisk.
static DRY_SEQUENCE_BREAKERS: LazyLock<Vec<String>> = LazyLock::new(|| {
    ["\n", ":", "\"", "*"]
        .iter()
        .map(|breaker| breaker.to_string())
        .collect()
});
/// Generation defaults that ship with a model (presumably read from its
/// `generation_config.json` — TODO confirm). Every field is `None` when unspecified.
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct ModelGenerationDefaults {
    /// `Some(false)` requests greedy decoding (see `SamplingParams::apply_model_defaults`).
    pub do_sample: Option<bool>,
    /// Sampling temperature; `0.0` is treated as "unset" by `apply_model_defaults`.
    pub temperature: Option<f64>,
    /// Top-k cutoff; `0` is treated as "unset" by `apply_model_defaults`.
    pub top_k: Option<usize>,
    /// Nucleus (top-p) cutoff.
    pub top_p: Option<f64>,
    /// Min-p cutoff.
    pub min_p: Option<f64>,
    /// Repetition penalty.
    pub repetition_penalty: Option<f32>,
    /// Copied into `SamplingParams::max_len` by `apply_model_defaults`.
    pub max_new_tokens: Option<usize>,
    /// Not consumed by `apply_model_defaults`.
    pub max_length: Option<usize>,
}
impl ModelGenerationDefaults {
    /// Returns `true` when the model supplied no generation defaults at all.
    pub fn is_empty(&self) -> bool {
        // Destructure exhaustively so adding a field forces this check to be revisited.
        let Self {
            do_sample,
            temperature,
            top_k,
            top_p,
            min_p,
            repetition_penalty,
            max_new_tokens,
            max_length,
        } = self;
        do_sample.is_none()
            && temperature.is_none()
            && top_k.is_none()
            && top_p.is_none()
            && min_p.is_none()
            && repetition_penalty.is_none()
            && max_new_tokens.is_none()
            && max_length.is_none()
    }
}
/// Stop conditions, given either as text sequences (`Seqs`) or as token ids (`Ids`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum StopTokens {
    /// Stop sequences expressed as strings.
    Seqs(Vec<String>),
    /// Stop conditions expressed as token ids.
    Ids(Vec<u32>),
}
/// Per-request sampling configuration. An `Option` left as `None` means the corresponding
/// setting is unset.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SamplingParams {
    /// Softmax temperature.
    pub temperature: Option<f64>,
    /// Keep only the `top_k` most likely tokens.
    pub top_k: Option<usize>,
    /// Nucleus (top-p) cutoff.
    pub top_p: Option<f64>,
    /// Min-p cutoff, relative to the most likely token.
    pub min_p: Option<f64>,
    /// Number of alternative tokens to report log-probabilities for.
    pub top_n_logprobs: usize,
    /// Frequency penalty.
    pub frequency_penalty: Option<f32>,
    /// Presence penalty.
    pub presence_penalty: Option<f32>,
    /// Repetition penalty.
    pub repetition_penalty: Option<f32>,
    /// Stop conditions (strings or token ids).
    pub stop_toks: Option<StopTokens>,
    /// Generation length limit; `apply_model_defaults` fills it from `max_new_tokens`,
    /// so presumably a budget of new tokens — TODO confirm against the caller.
    pub max_len: Option<usize>,
    /// Per-token bias keyed by token id; presumably added to the logits elsewhere
    /// (not applied in this file).
    pub logits_bias: Option<HashMap<u32, f32>>,
    /// Number of choices to produce (1 in `neutral`/`deterministic`).
    pub n_choices: usize,
    /// DRY penalty settings.
    pub dry_params: Option<DrySamplingParams>,
}
impl SamplingParams {
    /// Parameters with every optional setting unset, no logprobs and a single choice.
    pub fn neutral() -> Self {
        Self {
            temperature: None,
            top_k: None,
            top_p: None,
            min_p: None,
            top_n_logprobs: 0,
            frequency_penalty: None,
            presence_penalty: None,
            repetition_penalty: None,
            stop_toks: None,
            max_len: None,
            logits_bias: None,
            n_choices: 1,
            dry_params: None,
        }
    }

    /// Same as [`SamplingParams::neutral`], but restricted to the single most likely token
    /// (`top_k = Some(1)`).
    pub fn deterministic() -> Self {
        Self {
            top_k: Some(1),
            ..Self::neutral()
        }
    }

    /// Overwrites fields with the values given in a model's generation defaults.
    ///
    /// * `do_sample == Some(false)` forces greedy decoding (`temperature = None`,
    ///   `top_k = Some(1)`, no top-p / min-p). The config's own temperature / top-k / top-p /
    ///   min-p are then ignored, so a greedy config that also lists sampling values stays
    ///   greedy. This mirrors Hugging Face `GenerationConfig`, where those values only take
    ///   effect when sampling.
    /// * Otherwise a temperature of `0.0` or a `top_k` of `0` clears the field, and any other
    ///   given value replaces it.
    /// * `repetition_penalty` and `max_new_tokens` (stored in `max_len`) are always applied.
    /// * `max_length` is not used here.
    pub fn apply_model_defaults(&mut self, defaults: &ModelGenerationDefaults) {
        if defaults.do_sample == Some(false) {
            self.temperature = None;
            self.top_k = Some(1);
            self.top_p = None;
            self.min_p = None;
        } else {
            if let Some(temperature) = defaults.temperature {
                self.temperature = (temperature != 0.0).then_some(temperature);
            }
            if let Some(top_k) = defaults.top_k {
                self.top_k = (top_k != 0).then_some(top_k);
            }
            if let Some(top_p) = defaults.top_p {
                self.top_p = Some(top_p);
            }
            if let Some(min_p) = defaults.min_p {
                self.min_p = Some(min_p);
            }
        }
        if let Some(repetition_penalty) = defaults.repetition_penalty {
            self.repetition_penalty = Some(repetition_penalty);
        }
        if let Some(max_new_tokens) = defaults.max_new_tokens {
            self.max_len = Some(max_new_tokens);
        }
    }
}
/// User-facing DRY penalty settings. Sequence breakers are given as strings and are
/// resolved to token ids when a `Sampler` is built.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DrySamplingParams {
    /// Strings that interrupt a repeated sequence.
    pub sequence_breakers: Vec<String>,
    /// Penalty scale; `0.0` disables the DRY penalty.
    pub multiplier: f32,
    /// Base of the exponential penalty growth.
    pub base: f32,
    /// Match length tolerated before a penalty is applied.
    pub allowed_length: usize,
}
impl DrySamplingParams {
    /// Builds DRY settings, filling unspecified values with the defaults:
    /// `base = 1.75`, `allowed_length = 2` and the module's default breaker list.
    ///
    /// # Errors
    /// Never fails today; the `Result` is kept for API compatibility.
    pub fn new_with_defaults(
        multiplier: f32,
        sequence_breakers: Option<Vec<String>>,
        base: Option<f32>,
        allowed_length: Option<usize>,
    ) -> anyhow::Result<Self> {
        Ok(Self {
            base: base.unwrap_or(1.75),
            allowed_length: allowed_length.unwrap_or(2),
            // Clone the shared default list only when the caller did not supply breakers.
            sequence_breakers: sequence_breakers
                .unwrap_or_else(|| DRY_SEQUENCE_BREAKERS.clone()),
            multiplier,
        })
    }
}
impl Default for DrySamplingParams {
    /// DRY disabled (`multiplier = 0.0`), with the same `base` (1.75), `allowed_length` (2)
    /// and breaker list that `new_with_defaults` falls back to.
    fn default() -> Self {
        let sequence_breakers = DRY_SEQUENCE_BREAKERS.to_vec();
        Self {
            sequence_breakers,
            allowed_length: 2,
            base: 1.75,
            multiplier: 0.0,
        }
    }
}
/// DRY settings with the sequence breakers resolved to token ids.
#[derive(Clone, Debug)]
struct DrySamplingParamsInner {
    /// Token ids that interrupt a repeated sequence.
    pub sequence_breakers: HashSet<u32>,
    /// Penalty scale; `0.0` disables the DRY penalty.
    pub multiplier: f32,
    /// Base of the exponential penalty growth.
    pub base: f32,
    /// Match length tolerated before a penalty is applied.
    pub allowed_length: usize,
}
impl DrySamplingParamsInner {
pub fn from(other: DrySamplingParams, tokenizer: &Tokenizer) -> anyhow::Result<Self> {
Ok(Self {
base: other.base,
allowed_length: other.allowed_length,
sequence_breakers: HashSet::from_iter(
other
.sequence_breakers
.into_iter()
.map(|breaker| {
tokenizer
.encode_fast(["a", &breaker].concat(), true)
.map_err(anyhow::Error::msg)
.map(|enc| {
let ids = enc.get_ids();
if !ids.is_empty() {
Some(ids[ids.len() - 1])
} else {
None
}
})
})
.collect::<anyhow::Result<Vec<_>>>()?
.into_iter()
.flatten()
.collect::<Vec<_>>(),
),
multiplier: other.multiplier,
})
}
}
/// A user-supplied transformation applied to the logits before sampling.
///
/// `apply` receives the current logits and the token-id context the sampler was given, and
/// returns the logits to continue with. Closures with a matching signature implement this
/// trait through the blanket impl below.
pub trait CustomLogitsProcessor: Send + Sync {
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor>;
}
/// Lets any thread-safe closure `Fn(&Tensor, &[u32]) -> Result<Tensor>` act as a processor.
impl<T: Fn(&Tensor, &[u32]) -> Result<Tensor> + Send + Sync> CustomLogitsProcessor for T {
    fn apply(&self, logits: &Tensor, context: &[u32]) -> Result<Tensor> {
        (*self)(logits, context)
    }
}
/// Turns a logits vector into the next token. It applies penalties, custom logits
/// processors, and greedy or temperature + top-k / top-p / min-p sampling.
#[derive(Clone)]
pub struct Sampler {
    /// Softmax temperature; `None` selects greedy (argmax) decoding.
    temperature: Option<f64>,
    /// Number of alternatives reported in `Logprobs::top_logprobs`.
    top_n_logprobs: usize,
    /// Decodes sampled tokens to text and resolves DRY breakers; optional.
    tokenizer: Option<Arc<Tokenizer>>,
    /// Frequency penalty (`None` = neutral).
    frequency_penalty: Option<f32>,
    /// Presence penalty (`None` = neutral).
    presence_penalty: Option<f32>,
    /// Repetition penalty (`None` = neutral).
    repetition_penalty: Option<f32>,
    /// DRY settings with breakers as token ids; `None` when DRY was not requested or no
    /// tokenizer was provided.
    dry_params: Option<DrySamplingParamsInner>,
    /// Top-k cutoff; values `<= 0` keep every token.
    top_k: i64,
    /// Nucleus (top-p) cutoff.
    top_p: f64,
    /// Min-p cutoff, relative to the most likely token's probability.
    min_p: f64,
    /// Custom processors applied, in order, after the built-in penalties on the CPU path.
    logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
}
/// One alternative in a top-N log-probability listing.
#[cfg_attr(feature = "pyo3_macros", pyclass)]
#[cfg_attr(feature = "pyo3_macros", pyo3(get_all))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TopLogprob {
    /// Token id.
    pub token: u32,
    /// Base-10 logarithm of the token's score.
    pub logprob: f32,
    /// Decoded text of the token, when a tokenizer was available.
    pub bytes: Option<String>,
}
/// The sampled token together with its log-probability information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Logprobs {
    /// Sampled token id.
    pub token: u32,
    /// Base-10 logarithm of the value the token was selected from. On the greedy path this
    /// is the raw logit, not a normalized probability.
    pub logprob: f32,
    /// Decoded text of the token, when a tokenizer was available.
    pub bytes: Option<String>,
    /// Top alternatives; `Some` only when logprobs were requested.
    pub top_logprobs: Option<Vec<TopLogprob>>,
}
/// Orders `(token, probability)` pairs by descending probability.
///
/// Uses [`f32::total_cmp`] so the comparator is a genuine total order even when NaNs are
/// present. `select_nth_unstable_by` and `sort_unstable_by` may panic on comparators that are
/// not total orders. A NaN probability should instead reach the explicit NaN/Inf check done
/// before multinomial sampling. Under this ordering, positive NaN ranks above every finite
/// value.
#[inline]
fn cmp_desc_by_prob(a: &(u32, f32), b: &(u32, f32)) -> std::cmp::Ordering {
    b.1.total_cmp(&a.1)
}
/// Returns the `k` highest-probability `(index, prob)` pairs of `probs`, sorted descending.
///
/// `k` is clamped to `probs.len()`; an empty slice or `k == 0` yields an empty vector. When
/// `zero_rest` is set, every entry of `probs` outside the returned set is overwritten with
/// `0.0`; otherwise `probs` is left untouched.
fn partial_sort_top_k(probs: &mut [f32], k: usize, zero_rest: bool) -> Vec<(u32, f32)> {
    let len = probs.len();
    if len == 0 || k == 0 {
        return Vec::new();
    }
    let keep = k.min(len);
    let mut ranked: Vec<(u32, f32)> = probs
        .iter()
        .enumerate()
        .map(|(i, &p)| (i as u32, p))
        .collect();
    if keep < len {
        // Partition so the `keep` largest entries occupy the front (in arbitrary order).
        ranked.select_nth_unstable_by(keep - 1, cmp_desc_by_prob);
        if zero_rest {
            let (_, dropped) = ranked.split_at(keep);
            dropped
                .iter()
                .for_each(|&(index, _)| probs[index as usize] = 0.0);
        }
        ranked.truncate(keep);
    }
    ranked.sort_unstable_by(cmp_desc_by_prob);
    ranked
}
/// Index of the largest value in `values`, or `0` for an empty slice.
///
/// On ties the last maximal index wins. Incomparable pairs (NaN) are treated as equal, so
/// the later element replaces the current best.
#[inline]
fn argmax_f32(values: &[f32]) -> u32 {
    let mut best: Option<(usize, f32)> = None;
    for (index, &value) in values.iter().enumerate() {
        match best {
            Some((_, current)) if current.partial_cmp(&value) == Some(std::cmp::Ordering::Greater) => {}
            _ => best = Some((index, value)),
        }
    }
    best.map_or(0, |(index, _)| index as u32)
}
impl Sampler {
#[allow(clippy::too_many_arguments)]
pub fn new(
temperature: Option<f64>,
top_n_logprobs: usize,
tokenizer: Option<Arc<Tokenizer>>,
frequency_penalty: Option<f32>,
presence_penalty: Option<f32>,
repetition_penalty: Option<f32>,
dry_params: Option<DrySamplingParams>,
top_k: i64,
top_p: f64,
min_p: f64,
logits_processors: Vec<Arc<dyn CustomLogitsProcessor>>,
) -> anyhow::Result<Self> {
let temperature = if temperature.is_none_or(|v| v < 1e-7) {
None
} else {
temperature
};
let dry_params = if let Some(ref tokenizer) = tokenizer {
dry_params.map(|params| DrySamplingParamsInner::from(params, tokenizer))
} else {
None
};
let dry_params = match dry_params {
Some(fallible) => Some(fallible?),
None => None,
};
Ok(Self {
temperature,
top_n_logprobs,
tokenizer,
frequency_penalty,
presence_penalty,
repetition_penalty,
dry_params,
top_k,
top_p,
min_p,
logits_processors,
})
}
pub fn is_argmax(&self) -> bool {
self.temperature.is_none()
}
fn get_top_logprobs(&self, probs: &[f32]) -> Result<Vec<TopLogprob>> {
let k = self.top_n_logprobs.min(probs.len());
if k == 0 {
return Ok(Vec::new());
}
let mut probs_copy = probs.to_vec();
let top_k = partial_sort_top_k(&mut probs_copy, k, false);
let mut result = Vec::with_capacity(k);
if let Some(tokenizer) = &self.tokenizer {
for (token, prob) in top_k {
let decoded = tokenizer
.decode(&[token], false)
.map_err(|e| Error::Msg(e.to_string()))?;
result.push(TopLogprob {
token,
logprob: prob.log(10.0),
bytes: Some(decoded),
});
}
} else {
for (token, prob) in top_k {
result.push(TopLogprob {
token,
logprob: prob.log(10.0),
bytes: None,
});
}
}
Ok(result)
}
fn sample_argmax(&self, logits: Tensor, return_logprobs: bool) -> Result<Logprobs> {
let probs: Vec<f32> = logits.to_vec1()?;
let next_token = argmax_f32(&probs);
let logprob = probs[next_token as usize].log(10.0);
let top_logprobs = if return_logprobs {
Some(self.get_top_logprobs(&probs)?)
} else {
None
};
let bytes = if let Some(tokenizer) = &self.tokenizer {
Some(
tokenizer
.decode(&[next_token], false)
.map_err(|x| Error::Msg(x.to_string()))?,
)
} else {
None
};
Ok(Logprobs {
token: next_token,
logprob,
top_logprobs,
bytes,
})
}
fn sample_speculative_top_kp_min_p(
&self,
logits: Tensor,
return_logprobs: bool,
top_k: i64,
top_p: f32,
min_p: f32,
) -> Result<Logprobs> {
let mut probs: Vec<f32> = logits.to_vec1()?;
let k = if top_k > 0 {
top_k as usize
} else {
probs.len()
};
let idx_probs = partial_sort_top_k(&mut probs, k, true);
let mut cumsum = 0.;
for (index, prob) in &idx_probs {
if cumsum >= top_p {
probs[*index as usize] = 0.0;
} else {
cumsum += prob;
}
}
let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
let min_p_threshold = max_p * min_p;
for (index, prob) in &idx_probs {
if min_p_threshold >= *prob {
probs[*index as usize] = 0.0;
}
}
let next_token = argmax_f32(&probs);
let logprob = probs[next_token as usize].log(10.0);
let top_logprobs = if return_logprobs {
Some(self.get_top_logprobs(&probs)?)
} else {
None
};
let bytes = if let Some(tokenizer) = &self.tokenizer {
Some(
tokenizer
.decode(&[next_token], false)
.map_err(|x| Error::Msg(x.to_string()))?,
)
} else {
None
};
Ok(Logprobs {
token: next_token,
logprob,
top_logprobs,
bytes,
})
}
fn sample_multinomial(
&self,
probs: &[f32],
return_logprobs: bool,
rng: Arc<Mutex<Isaac64Rng>>,
) -> Result<Logprobs> {
let distr = match WeightedIndex::new(probs) {
Ok(distr) => distr,
Err(e) => {
if let Some((idx, prob)) = probs
.iter()
.enumerate()
.find(|(_, prob)| !prob.is_finite() || **prob < 0.0)
{
return Err(Error::Msg(format!(
"Invalid sampling probability at index {idx}: {prob}. The model likely produced NaN/Inf logits."
)));
}
let positive_weight_sum: f64 = probs
.iter()
.copied()
.filter(|prob| prob.is_finite() && *prob > 0.0)
.map(f64::from)
.sum();
if positive_weight_sum == 0.0 {
return Err(Error::Msg(
"All sampling probabilities are zero after filtering (top-k/top-p/min-p)."
.to_string(),
));
}
return Err(Error::Msg(format!(
"Failed to construct multinomial sampler: {e}"
)));
}
};
let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
let next_token = distr.sample(&mut mut_ref_rng); let logprob = probs[next_token].log(10.0);
let top_logprobs = if return_logprobs {
Some(self.get_top_logprobs(probs)?)
} else {
None
};
let bytes = if let Some(tokenizer) = &self.tokenizer {
Some(
tokenizer
.decode(&[next_token.try_into().unwrap()], false)
.map_err(|x| Error::Msg(x.to_string()))?,
)
} else {
None
};
Ok(Logprobs {
token: next_token as u32,
logprob,
top_logprobs,
bytes,
})
}
#[cfg(any(feature = "cuda", feature = "metal"))]
fn can_sample_topk_on_device(
&self,
return_logprobs: bool,
sample_speculative: bool,
multiple_sequences: bool,
) -> bool {
const MAX_DEVICE_TOP_K: i64 = 128;
!return_logprobs
&& !sample_speculative
&& !multiple_sequences
&& self.temperature.is_some()
&& self.top_k > 0
&& self.top_k <= MAX_DEVICE_TOP_K
&& self.logits_processors.is_empty()
&& self
.dry_params
.as_ref()
.is_none_or(|params| params.multiplier.abs() <= f32::EPSILON)
}
#[cfg(feature = "cuda")]
fn apply_device_sparse_penalties_if_needed(
&self,
logits: Tensor,
context: &[u32],
) -> Result<Tensor> {
let frequency_penalty = self.frequency_penalty.unwrap_or(0.0);
let presence_penalty = self.presence_penalty.unwrap_or(0.0);
let repetition_penalty = self.repetition_penalty.unwrap_or(1.0);
let needs_penalty = frequency_penalty.abs() > f32::EPSILON
|| presence_penalty.abs() > f32::EPSILON
|| (repetition_penalty - 1.0).abs() > f32::EPSILON;
if !needs_penalty {
return Ok(logits);
}
if context.is_empty() {
hanzo_ml::bail!("Penalty context is empty, this should not happen.");
}
let vocab_size = logits.elem_count();
let mut counts = HashMap::<u32, f32>::with_capacity(context.len().min(vocab_size));
for &token_id in context {
if token_id as usize >= vocab_size {
continue;
}
*counts.entry(token_id).or_insert(0.0) += 1.0;
}
if counts.is_empty() {
return Ok(logits);
}
let n_tokens = counts.len();
let mut token_ids = Vec::with_capacity(n_tokens);
let mut token_counts = Vec::with_capacity(n_tokens);
for (token_id, count) in counts {
token_ids.push(token_id);
token_counts.push(count);
}
let device = logits.device();
let token_ids = Tensor::from_vec(token_ids, n_tokens, device)?;
let token_counts = Tensor::from_vec(token_counts, n_tokens, device)?;
crate::ops::cuda_apply_sparse_penalties_f32(
&logits,
&token_ids,
&token_counts,
frequency_penalty,
presence_penalty,
repetition_penalty,
)
}
#[cfg(feature = "cuda")]
fn sample_topk_on_device(
&self,
logits: Tensor,
temperature: f64,
rng: Arc<Mutex<Isaac64Rng>>,
) -> Result<Logprobs> {
let topk =
crate::ops::cuda_topk_logits_f32_packed(&logits, self.top_k as usize, temperature)?;
let packed = topk.packed.to_vec1::<f32>()?;
let k = topk.k;
if packed.len() != 2 * k + 2 {
hanzo_ml::bail!(
"invalid CUDA top-k packed output length {}, expected {}",
packed.len(),
2 * k + 2
);
}
let top_values = &packed[..k];
let top_indices = packed[k..2 * k]
.iter()
.map(|idx| *idx as u32)
.collect::<Vec<_>>();
let softmax_info = &packed[2 * k..2 * k + 2];
let denom = softmax_info[0];
let global_max = softmax_info[1];
if denom <= 0.0 || !denom.is_finite() || !global_max.is_finite() {
hanzo_ml::bail!("invalid CUDA top-k softmax normalizer");
}
let inv_temperature = (1.0 / temperature) as f32;
let mut probs = top_values
.iter()
.map(|value| ((*value * inv_temperature - global_max).exp()) / denom)
.collect::<Vec<_>>();
if self.top_p > 0.0 && self.top_p < 1.0 {
let mut cumsum = 0.0f32;
for prob in &mut probs {
if cumsum >= self.top_p as f32 {
*prob = 0.0;
} else {
cumsum += *prob;
}
}
if self.min_p > 0.0 && self.min_p < 1.0 {
let max_p = probs.first().copied().unwrap_or(0.0);
let min_p_threshold = max_p * self.min_p as f32;
for prob in &mut probs {
if min_p_threshold >= *prob {
*prob = 0.0;
}
}
}
}
let distr = match WeightedIndex::new(&probs) {
Ok(distr) => distr,
Err(e) => {
let positive_weight_sum: f64 = probs
.iter()
.copied()
.filter(|prob| prob.is_finite() && *prob > 0.0)
.map(f64::from)
.sum();
if positive_weight_sum == 0.0 {
return Err(Error::Msg(
"All sampling probabilities are zero after CUDA top-k filtering."
.to_string(),
));
}
return Err(Error::Msg(format!(
"Failed to construct CUDA top-k multinomial sampler: {e}"
)));
}
};
let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
let selected = distr.sample(&mut mut_ref_rng);
let next_token = top_indices[selected];
let logprob = probs[selected].log(10.0);
let bytes = if let Some(tokenizer) = &self.tokenizer {
Some(
tokenizer
.decode(&[next_token], false)
.map_err(|x| Error::Msg(x.to_string()))?,
)
} else {
None
};
Ok(Logprobs {
token: next_token,
logprob,
top_logprobs: None,
bytes,
})
}
#[cfg(feature = "metal")]
fn apply_device_sparse_penalties_if_needed_metal(
&self,
logits: Tensor,
context: &[u32],
) -> Result<Tensor> {
let frequency_penalty = self.frequency_penalty.unwrap_or(0.0);
let presence_penalty = self.presence_penalty.unwrap_or(0.0);
let repetition_penalty = self.repetition_penalty.unwrap_or(1.0);
let needs_penalty = frequency_penalty.abs() > f32::EPSILON
|| presence_penalty.abs() > f32::EPSILON
|| (repetition_penalty - 1.0).abs() > f32::EPSILON;
if !needs_penalty || context.is_empty() {
return Ok(logits);
}
let vocab_size = logits.elem_count();
let mut counts = HashMap::<u32, f32>::with_capacity(context.len().min(vocab_size));
for &tid in context {
if (tid as usize) >= vocab_size {
continue;
}
*counts.entry(tid).or_insert(0.0) += 1.0;
}
if counts.is_empty() {
return Ok(logits);
}
let n_tokens = counts.len();
let mut token_ids = Vec::with_capacity(n_tokens);
let mut token_counts = Vec::with_capacity(n_tokens);
for (tid, c) in counts {
token_ids.push(tid);
token_counts.push(c);
}
let device = logits.device();
let token_ids = Tensor::from_vec(token_ids, n_tokens, device)?;
let token_counts = Tensor::from_vec(token_counts, n_tokens, device)?;
crate::ops::metal_apply_sparse_penalties(
&logits,
&token_ids,
&token_counts,
frequency_penalty,
presence_penalty,
repetition_penalty,
)
}
#[cfg(feature = "metal")]
fn sample_topk_on_device_metal(
&self,
logits: Tensor,
temperature: f64,
rng: Arc<Mutex<Isaac64Rng>>,
) -> Result<Logprobs> {
let topk = crate::ops::metal_topk_logits_packed(&logits, self.top_k as usize, temperature)?;
let packed = topk.packed.to_vec1::<f32>()?;
let k = topk.k;
if packed.len() != 2 * k + 2 {
hanzo_ml::bail!(
"invalid Metal top-k packed output length {}, expected {}",
packed.len(),
2 * k + 2
);
}
let top_values = &packed[..k];
let top_indices = packed[k..2 * k]
.iter()
.map(|idx| *idx as u32)
.collect::<Vec<_>>();
let softmax_info = &packed[2 * k..2 * k + 2];
let denom = softmax_info[0];
let global_max = softmax_info[1];
if denom <= 0.0 || !denom.is_finite() || !global_max.is_finite() {
hanzo_ml::bail!("invalid Metal top-k softmax normalizer");
}
let inv_temperature = (1.0 / temperature) as f32;
let mut probs = top_values
.iter()
.map(|value| ((*value * inv_temperature - global_max).exp()) / denom)
.collect::<Vec<_>>();
if self.top_p > 0.0 && self.top_p < 1.0 {
let mut cumsum = 0.0f32;
for prob in &mut probs {
if cumsum >= self.top_p as f32 {
*prob = 0.0;
} else {
cumsum += *prob;
}
}
if self.min_p > 0.0 && self.min_p < 1.0 {
let max_p = probs.first().copied().unwrap_or(0.0);
let min_p_threshold = max_p * self.min_p as f32;
for prob in &mut probs {
if min_p_threshold >= *prob {
*prob = 0.0;
}
}
}
}
let distr = match WeightedIndex::new(&probs) {
Ok(distr) => distr,
Err(e) => {
let positive_weight_sum: f64 = probs
.iter()
.copied()
.filter(|prob| prob.is_finite() && *prob > 0.0)
.map(f64::from)
.sum();
if positive_weight_sum == 0.0 {
return Err(Error::Msg(
"All sampling probabilities are zero after Metal top-k filtering."
.to_string(),
));
}
return Err(Error::Msg(format!(
"Failed to construct Metal top-k multinomial sampler: {e}"
)));
}
};
let mut mut_ref_rng = &mut *rng.lock().expect("could not lock rng mutex");
let selected = distr.sample(&mut mut_ref_rng);
let next_token = top_indices[selected];
let logprob = probs[selected].log(10.0);
let bytes = if let Some(tokenizer) = &self.tokenizer {
Some(
tokenizer
.decode(&[next_token], false)
.map_err(|x| Error::Msg(x.to_string()))?,
)
} else {
None
};
Ok(Logprobs {
token: next_token,
logprob,
top_logprobs: None,
bytes,
})
}
fn filter_top_kp_min_p(&self, probs: &mut [f32]) {
let k = if self.top_k > 0 {
self.top_k as usize
} else {
probs.len()
};
let idx_probs = partial_sort_top_k(probs, k, true);
if self.top_p <= 0.0 || self.top_p >= 1.0 {
return;
}
let mut cumsum = 0.0f32;
for (index, prob) in &idx_probs {
if cumsum >= self.top_p as f32 {
probs[*index as usize] = 0.0;
} else {
cumsum += prob;
}
}
if self.min_p <= 0.0 || self.min_p >= 1.0 {
return;
}
let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
let min_p_threshold = max_p * self.min_p as f32;
for (index, prob) in &idx_probs {
if min_p_threshold >= *prob {
probs[*index as usize] = 0.0;
}
}
}
fn normalize_probs(probs: &mut [f32]) -> Result<()> {
let sum: f32 = probs
.iter()
.copied()
.filter(|prob| prob.is_finite() && *prob > 0.0)
.sum();
if sum <= 0.0 {
hanzo_ml::bail!("all probabilities are zero in speculative sampling");
}
for prob in probs.iter_mut() {
if prob.is_finite() && *prob > 0.0 {
*prob /= sum;
} else {
*prob = 0.0;
}
}
Ok(())
}
pub(crate) fn speculative_target_probs(
&self,
logits: Tensor,
context: &[u32],
) -> Result<Vec<f32>> {
self.speculative_probs(logits, context)
}
pub(crate) fn speculative_candidate_probs(
&self,
logits: Tensor,
context: &[u32],
) -> Result<Vec<f32>> {
self.speculative_probs(logits, context)
}
fn speculative_probs(&self, logits: Tensor, context: &[u32]) -> Result<Vec<f32>> {
let logits = logits.to_vec1()?;
let mut logits = self.apply_penalties(logits, context)?;
for processor in &self.logits_processors {
logits = processor.apply(&logits, context)?;
}
let mut probs = match self.temperature {
None => {
let logits = logits.to_vec1::<f32>()?;
let mut probs = vec![0.0; logits.len()];
probs[argmax_f32(&logits) as usize] = 1.0;
probs
}
Some(temperature) => {
let logits = (&logits / temperature)?;
hanzo_nn::ops::softmax_last_dim(&logits)?.to_vec1::<f32>()?
}
};
self.filter_top_kp_min_p(&mut probs);
Self::normalize_probs(&mut probs)?;
Ok(probs)
}
pub(crate) fn logprobs_from_probs(
&self,
token: u32,
probs: &[f32],
return_logprobs: bool,
) -> Result<Logprobs> {
let prob = probs.get(token as usize).copied().unwrap_or(0.0);
let logprob = if prob > 0.0 {
prob.log(10.0)
} else {
f32::NEG_INFINITY
};
let top_logprobs = if return_logprobs {
Some(self.get_top_logprobs(probs)?)
} else {
None
};
let bytes = if let Some(tokenizer) = &self.tokenizer {
Some(
tokenizer
.decode(&[token], false)
.map_err(|x| Error::Msg(x.to_string()))?,
)
} else {
None
};
Ok(Logprobs {
token,
logprob,
top_logprobs,
bytes,
})
}
pub(crate) fn sample_from_probs(
&self,
probs: &[f32],
return_logprobs: bool,
rng: Arc<Mutex<Isaac64Rng>>,
) -> Result<Logprobs> {
self.sample_multinomial(probs, return_logprobs, rng)
}
#[allow(clippy::too_many_arguments)]
fn sample_top_kp_min_p(
&self,
probs: &mut [f32],
top_k: i64,
top_p: f32,
min_p: f32,
return_logprobs: bool,
rng: Arc<Mutex<Isaac64Rng>>,
) -> Result<Logprobs> {
let k = if top_k > 0 {
top_k as usize
} else {
probs.len()
};
let idx_probs = partial_sort_top_k(probs, k, true);
if top_p <= 0.0 || top_p >= 1.0 {
return self.sample_multinomial(probs, return_logprobs, rng);
}
let mut cumsum = 0.;
for (index, prob) in &idx_probs {
if cumsum >= top_p {
probs[*index as usize] = 0.0;
} else {
cumsum += prob;
}
}
if min_p <= 0.0 || min_p >= 1.0 {
return self.sample_multinomial(probs, return_logprobs, rng);
}
let max_p = idx_probs.first().map(|(_, p)| *p).unwrap_or(0.0);
let min_p_threshold = max_p * min_p;
for (index, prob) in &idx_probs {
if min_p_threshold >= *prob {
probs[*index as usize] = 0.0;
}
}
self.sample_multinomial(probs, return_logprobs, rng)
}
fn apply_penalties(&self, mut logits: Vec<f32>, context: &[u32]) -> Result<Tensor> {
if context.is_empty() {
hanzo_ml::bail!("Penalty context is empty, this should not happen.");
}
self.apply_dry_penalty(&mut logits, context)?;
self.apply_freq_pres_rep_penalty(&mut logits, context)?;
let vocab_size = logits.len();
Tensor::from_vec(logits, vocab_size, &Device::Cpu)
}
fn apply_freq_pres_rep_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
if self.frequency_penalty.is_some()
|| self.presence_penalty.is_some()
|| self.repetition_penalty.is_some()
{
let frequency_penalty = self.frequency_penalty.unwrap_or(0.);
let presence_penalty = self.presence_penalty.unwrap_or(0.);
let repetition_penalty = self.repetition_penalty.unwrap_or(1.);
let mut counts = vec![0.0f32; logits.len()];
for ctx in context.iter() {
if *ctx as usize >= logits.len() {
continue;
}
counts[*ctx as usize] += 1.0;
}
for (token_id, logit) in logits.iter_mut().enumerate() {
let count = counts[token_id];
*logit = *logit
- count * frequency_penalty
- if count > 0.0 { 1. } else { 0. } * presence_penalty;
if repetition_penalty != 1.0 && count > 0.0 {
if *logit > 0.0 {
*logit /= repetition_penalty;
} else {
*logit *= repetition_penalty;
}
}
}
}
Ok(())
}
const DRY_PENALTY_PAR_THRESHOLD: usize = 1024;
fn apply_dry_penalty(&self, logits: &mut [f32], context: &[u32]) -> Result<()> {
if let Some(ref params) = self.dry_params {
if params.multiplier == 0. {
return Ok(());
}
let last_token = *context.last().unwrap();
let match_indices: Vec<usize> = if context.len() > Self::DRY_PENALTY_PAR_THRESHOLD {
context
.par_iter()
.enumerate()
.take(context.len() - 1)
.filter(|(_i, x)| last_token == **x)
.map(|(i, _)| i)
.collect()
} else {
context
.iter()
.enumerate()
.take(context.len() - 1)
.filter(|(_i, x)| last_token == **x)
.map(|(i, _)| i)
.collect()
};
let mut match_lengths = HashMap::new();
for i in match_indices {
let next_token = context[i + 1];
if params.sequence_breakers.contains(&next_token) {
continue;
}
let mut match_length = 1;
while match_length < 50 {
if match_length > i {
break;
}
let j = i - match_length;
let prev_tok = context[context.len() - (match_length + 1)];
if context[j] != prev_tok {
break;
}
if params.sequence_breakers.contains(&prev_tok) {
break;
}
match_length += 1;
}
#[allow(clippy::map_entry)]
if match_lengths.contains_key(&next_token) {
match_lengths.insert(next_token, match_length.max(match_lengths[&next_token]));
} else {
match_lengths.insert(next_token, match_length);
}
}
for (tok, match_len) in match_lengths {
if match_len >= params.allowed_length {
if tok as usize >= logits.len() {
continue;
}
let penalty = params.multiplier
* params.base.powf((match_len - params.allowed_length) as f32);
logits[tok as usize] -= penalty;
}
}
}
Ok(())
}
#[allow(unused)]
pub fn sample(
&self,
logits: Tensor,
context: &[u32],
return_logprobs: bool,
rng: Arc<Mutex<Isaac64Rng>>,
sample_speculative: bool,
multiple_sequences: bool,
) -> Result<Logprobs> {
#[cfg(feature = "cuda")]
if logits.device().is_cuda()
&& self.can_sample_topk_on_device(
return_logprobs,
sample_speculative,
multiple_sequences,
)
{
if let Some(temperature) = self.temperature {
let logits = self.apply_device_sparse_penalties_if_needed(logits, context)?;
return self.sample_topk_on_device(logits, temperature, rng);
}
}
#[cfg(feature = "metal")]
if logits.device().is_metal()
&& self.can_sample_topk_on_device(
return_logprobs,
sample_speculative,
multiple_sequences,
)
{
if let Some(temperature) = self.temperature {
let logits = self.apply_device_sparse_penalties_if_needed_metal(logits, context)?;
return self.sample_topk_on_device_metal(logits, temperature, rng);
}
}
let logits = logits.to_vec1()?;
let mut logits = self.apply_penalties(logits, context)?;
for processor in &self.logits_processors {
logits = processor.apply(&logits, context)?;
}
let next_token = if sample_speculative {
match self.temperature {
None => self.sample_speculative_top_kp_min_p(
logits,
return_logprobs,
self.top_k,
self.top_p as f32,
self.min_p as f32,
)?,
Some(temperature) => {
let logits = (&logits / temperature)?;
let probs = hanzo_nn::ops::softmax_last_dim(&logits)?;
self.sample_speculative_top_kp_min_p(
probs,
return_logprobs,
self.top_k,
self.top_p as f32,
self.min_p as f32,
)?
}
}
} else {
match self.temperature {
None => self.sample_argmax(logits, return_logprobs)?,
Some(temperature) => {
let logits = (&logits / temperature)?;
let probs = hanzo_nn::ops::softmax_last_dim(&logits)?;
let mut probs: Vec<f32> = probs.to_vec1()?;
self.sample_top_kp_min_p(
&mut probs,
self.top_k,
self.top_p as f32,
self.min_p as f32,
return_logprobs,
rng,
)?
}
}
};
Ok(next_token)
}
}
#[cfg(test)]
mod tests {
    use super::{Logprobs, ModelGenerationDefaults, Sampler, SamplingParams};
    use hanzo_ml::{Device, Tensor};
    use rand::SeedableRng;
    use rand_isaac::Isaac64Rng;
    use std::sync::{Arc, Mutex};

    /// Sampler with `top_n_logprobs = 10` and no tokenizer, penalties, DRY or processors.
    fn plain_sampler(temperature: Option<f64>, top_k: i64, top_p: f64, min_p: f64) -> Sampler {
        Sampler::new(
            temperature,
            10,
            None,
            None,
            None,
            None,
            None,
            top_k,
            top_p,
            min_p,
            vec![],
        )
        .unwrap()
    }

    /// Samples once from the logits `0, 1, ..., 1023` with context `0..1024` and seed 42.
    fn sample_ramp(sampler: &Sampler, sample_speculative: bool) -> Logprobs {
        let logits = Tensor::arange(0f32, 1024f32, &Device::Cpu).unwrap();
        let context: Vec<u32> = (0..1024).collect();
        let rng = Arc::new(Mutex::new(Isaac64Rng::seed_from_u64(42)));
        sampler
            .sample(logits, &context, false, rng, sample_speculative, false)
            .unwrap()
    }

    #[test]
    fn test_argmax() {
        let res = sample_ramp(&plain_sampler(None, 32, 0.1, 0.05), false);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_gumbel_speculative() {
        let res = sample_ramp(&plain_sampler(None, 32, 0.1, 0.05), true);
        assert_eq!(res.token, 1023);
        assert_eq!(res.top_logprobs, None);
        assert_eq!(res.logprob, 1023f64.log(10.) as f32)
    }

    #[test]
    fn test_speculative_candidate_probs_use_sampling_filters() {
        let sampler = plain_sampler(Some(1.0), 1, 1.0, 0.0);
        let logits = Tensor::from_vec(vec![0.0f32, 1.0, 2.0], 3, &Device::Cpu).unwrap();
        let context = [0u32];
        let target_probs = sampler
            .speculative_target_probs(logits.clone(), &context)
            .unwrap();
        let candidate_probs = sampler
            .speculative_candidate_probs(logits, &context)
            .unwrap();
        assert_eq!(candidate_probs, target_probs);
        assert_eq!(candidate_probs, vec![0.0, 0.0, 1.0]);
    }

    #[test]
    fn test_apply_model_defaults() {
        let defaults = ModelGenerationDefaults {
            do_sample: Some(true),
            temperature: Some(1.0),
            top_k: Some(32),
            top_p: Some(0.9),
            min_p: Some(0.05),
            repetition_penalty: Some(1.1),
            max_new_tokens: Some(256),
            max_length: None,
        };
        let mut params = SamplingParams::neutral();
        params.apply_model_defaults(&defaults);
        assert_eq!(params.temperature, Some(1.0));
        assert_eq!(params.top_k, Some(32));
        assert_eq!(params.top_p, Some(0.9));
        assert_eq!(params.min_p, Some(0.05));
        assert_eq!(params.repetition_penalty, Some(1.1));
        assert_eq!(params.max_len, Some(256));
    }

    #[test]
    fn test_apply_model_defaults_disables_sampling_when_requested() {
        let greedy = ModelGenerationDefaults {
            do_sample: Some(false),
            ..Default::default()
        };
        let mut params = SamplingParams {
            temperature: Some(0.7),
            top_k: Some(40),
            top_p: Some(0.9),
            min_p: Some(0.1),
            ..SamplingParams::neutral()
        };
        params.apply_model_defaults(&greedy);
        assert_eq!(params.temperature, None);
        assert_eq!(params.top_k, Some(1));
        assert_eq!(params.top_p, None);
        assert_eq!(params.min_p, None);
    }
}