use crate::{
array::Array,
dtype::Dtype,
error::{Error, LengthMismatchPayload, OutOfRangePayload, Result, UnsupportedDtypePayload},
ops,
};
/// Builds a one-element array of shape `(1,)` holding `value`, converted to the
/// dtype of `like`.
///
/// The value is materialised as `f32` first and then cast, so it is subject to
/// the rounding of the target dtype. The `(1,)` shape is what lets the callers
/// in this file combine it elementwise with larger arrays.
fn scalar_like(value: f32, like: &Array) -> Result<Array> {
    crate::error::ensure_handler_installed();
    let as_f32 = Array::full::<f32>(&(1,), value)?;
    let target_dtype = like.dtype()?;
    ops::misc::astype(&as_f32, target_dtype)
}
/// Slices `a` to the half-open range `[start, end)` along its last axis.
///
/// Every other axis is kept whole, and all strides are 1. For a 0-d array there
/// is no last axis, so the slice is issued with empty start/stop/stride vectors.
fn slice_last_axis(a: &Array, start: i32, end: i32) -> Result<Array> {
    let dims = a.shape();
    let rank = dims.len();
    let mut begin: Vec<i32> = vec![0; rank];
    let mut finish: Vec<i32> = dims.iter().map(|&d| d as i32).collect();
    // Only the trailing axis is narrowed; both vectors are empty for rank 0.
    if let (Some(lo), Some(hi)) = (begin.last_mut(), finish.last_mut()) {
        *lo = start;
        *hi = end;
    }
    let step = vec![1i32; rank];
    ops::indexing::slice(a, &begin, &finish, &step)
}
/// Keeps the `top_k` largest entries along the last axis of `logprobs` and sets
/// every other entry to negative infinity.
///
/// # Errors
/// Returns `Error::OutOfRange` unless `0 < top_k < vocab_size`. Here
/// `vocab_size` is the length of the last axis, or 0 for a 0-d array.
pub fn apply_top_k(logprobs: &Array, top_k: i32) -> Result<Array> {
    let vocab_size = logprobs.shape().last().copied().unwrap_or(0) as i32;
    let in_range = top_k > 0 && top_k < vocab_size;
    if !in_range {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "top_k",
            "must be an integer in the open interval (0, vocab_size)",
            format!("{top_k} (vocab_size={vocab_size})"),
        )));
    }
    // Partitioning the negated scores at kth = top_k - 1 puts the indices of
    // the `top_k` largest values in the leading slots. Everything from slot
    // `top_k` onwards is then overwritten with -inf.
    let negated = ops::arithmetic::negative(logprobs)?;
    let descending_partition = ops::misc::argpartition_axis(&negated, top_k - 1, -1)?;
    let discard = slice_last_axis(&descending_partition, top_k, vocab_size)?;
    let fill = scalar_like(f32::NEG_INFINITY, logprobs)?;
    ops::indexing::put_along_axis(logprobs, &discard, &fill, -1)
}
/// Min-p filtering along the last axis of `logprobs`.
///
/// An entry is set to negative infinity when it is below
/// `row_max + ln(min_p)`, where `row_max` is the maximum along the last axis.
/// The threshold is added in log space, so `logprobs` is expected to hold
/// log-probabilities.
///
/// When `min_tokens_to_keep > 1`, the `min_tokens_to_keep` largest entries of
/// each row are never masked, whatever the threshold says.
///
/// # Errors
/// Returns `Error::OutOfRange` in three cases, checked in this order:
/// - `min_p` is outside `[0, 1]` (NaN included).
/// - `min_tokens_to_keep < 1`.
/// - `min_tokens_to_keep` exceeds the last-axis length.
pub fn apply_min_p(logprobs: &Array, min_p: f32, min_tokens_to_keep: i32) -> Result<Array> {
    let min_p_valid = (0.0..=1.0).contains(&min_p);
    if !min_p_valid {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "min_p",
            "must be a float in the closed interval [0, 1]",
            format!("{min_p}"),
        )));
    }
    if min_tokens_to_keep < 1 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "min_tokens_to_keep",
            "must be a positive integer (>= 1)",
            format!("{min_tokens_to_keep}"),
        )));
    }
    let vocab_size = logprobs.shape().last().copied().unwrap_or(0) as i32;
    if min_tokens_to_keep > vocab_size {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "min_tokens_to_keep",
            "must not exceed the vocabulary size",
            format!("{min_tokens_to_keep} (vocab_size={vocab_size})"),
        )));
    }

    // log(p_max * min_p) == logprob_max + ln(min_p). With min_p == 0 this is
    // -inf, so nothing falls below it.
    let row_max = ops::reduction::max_axes(logprobs, &[-1], true)?;
    let log_min_p = scalar_like(min_p.ln(), &row_max)?;
    let cutoff = ops::arithmetic::add(&row_max, &log_min_p)?;
    let below_cutoff = ops::comparison::less(logprobs, &cutoff)?;

    let drop_mask = if min_tokens_to_keep > 1 {
        // Partition in ascending order. The last `min_tokens_to_keep` slots then
        // hold the largest entries, which are force-unmasked.
        let kth = vocab_size - min_tokens_to_keep;
        let partitioned = ops::misc::argpartition_axis(logprobs, kth, -1)?;
        let guaranteed = slice_last_axis(&partitioned, kth, vocab_size)?;
        let unmask = Array::full::<bool>(&(1,), false)?;
        ops::indexing::put_along_axis(&below_cutoff, &guaranteed, &unmask, -1)?
    } else {
        below_cutoff
    };

    let neg_inf = scalar_like(f32::NEG_INFINITY, logprobs)?;
    ops::logical::select(&drop_mask, &neg_inf, logprobs)
}
/// Nucleus (top-p) filtering along the last axis of `logprobs`.
///
/// Tokens are ranked in ascending order of `logprobs`. A token survives when the
/// cumulative probability (`exp(logprobs)`) of itself plus every token ranked
/// before it is strictly greater than `1 - top_p`. All other tokens are set to
/// negative infinity.
///
/// # Errors
/// Returns `Error::OutOfRange` if `top_p` is non-finite or outside `(0, 1]`.
pub fn apply_top_p(logprobs: &Array, top_p: f32) -> Result<Array> {
    let acceptable = top_p.is_finite() && top_p > 0.0 && top_p <= 1.0;
    if !acceptable {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "top_p",
            "must be a finite float in the half-open interval (0, 1]",
            format!("{top_p}"),
        )));
    }
    let probs = ops::arithmetic::exp(logprobs)?;
    let order = ops::misc::argsort_axis(logprobs, -1)?;
    let probs_in_order = ops::indexing::take_along_axis(&probs, &order, -1)?;
    // The `false, true` flags are passed through unchanged. They are presumably
    // reverse = false and inclusive = true; confirm against `ops::misc::cumsum`.
    let running_mass = ops::misc::cumsum(&probs_in_order, -1, false, true)?;
    // Argsort of a permutation is its inverse. Gathering with it puts each
    // running total back at its token's original vocabulary position.
    let unsort = ops::misc::argsort_axis(&order, -1)?;
    let mass_per_token = ops::indexing::take_along_axis(&running_mass, &unsort, -1)?;
    let cut = scalar_like(1.0 - top_p, &mass_per_token)?;
    let keep = ops::comparison::greater(&mass_per_token, &cut)?;
    let neg_inf = scalar_like(f32::NEG_INFINITY, logprobs)?;
    ops::logical::select(&keep, logprobs, &neg_inf)
}
pub fn scale_logits_by_temp(logits: &Array, temp: f32) -> Result<Array> {
crate::error::ensure_handler_installed();
if !temp.is_finite() || temp <= 0.0 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"temp",
"must be a finite positive float (use argmax_sample for temperature-0 / greedy decoding)",
format!("{temp}"),
)));
}
let temp = temp.max(f32::MIN_POSITIVE);
let dtype = logits.dtype()?;
match dtype {
crate::Dtype::F32 => {
let divisor = Array::full::<f32>(&(1,), temp)?;
ops::arithmetic::divide(logits, &divisor)
}
crate::Dtype::F16 | crate::Dtype::BF16 => {
let logits_f32 = ops::misc::astype(logits, crate::Dtype::F32)?;
let divisor = Array::full::<f32>(&(1,), temp)?;
let scaled_f32 = ops::arithmetic::divide(&logits_f32, &divisor)?;
ops::misc::astype(&scaled_f32, dtype)
}
other => Err(Error::UnsupportedDtype(UnsupportedDtypePayload::new(
"categorical_sampling: logits dtype (MLX's GPU stream does not implement F64; cast with .astype(Dtype::F32) before sampling)",
other,
&[Dtype::F32, Dtype::F16, Dtype::BF16],
))),
}
}
/// Draws a categorical sample along the last axis of `logits / temp`, using
/// `key` as the random key.
///
/// # Errors
/// Propagates the temperature and dtype validation errors of
/// [`scale_logits_by_temp`], plus any error from the sampling op.
pub fn categorical_sampling(logits: &Array, temp: f32, key: &Array) -> Result<Array> {
    scale_logits_by_temp(logits, temp)
        .and_then(|tempered| ops::random::categorical(&tempered, -1, key))
}
/// Greedy decoding: returns the index of the largest logit along the last
/// axis. That axis is dropped from the result (`keepdims = false`).
pub fn argmax_sample(logits: &Array) -> Result<Array> {
    let last_axis = Some(-1);
    let keep_dims = false;
    ops::misc::argmax(logits, last_axis, keep_dims)
}
/// XTC-style filtering of the highest-probability tokens along the last axis.
///
/// A single uniform value `u` is drawn from `key` (shape `[1]`), so one
/// decision covers the whole array:
/// - If `u > xtc_probability`, `logits` is returned unchanged.
/// - Otherwise, find the smallest softmax probability that is still strictly
///   above `xtc_threshold` (computed per row). Every token whose probability is
///   strictly greater than that value is set to negative infinity.
///
/// Ids listed in `xtc_special_tokens` are never masked.
///
/// # Errors
/// Returns `Error::OutOfRange` if `xtc_threshold` is not a finite value in
/// `[0, 0.5]`, or if `xtc_probability` is not a finite value in `[0, 1]`.
/// The threshold is checked first.
pub fn apply_xtc(
    logits: &Array,
    xtc_probability: f32,
    xtc_threshold: f32,
    xtc_special_tokens: &[i32],
    key: &Array,
) -> Result<Array> {
    let threshold_ok = xtc_threshold.is_finite() && (0.0..=0.5).contains(&xtc_threshold);
    if !threshold_ok {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "xtc_threshold",
            "must be a finite float in the closed interval [0, 0.5]",
            format!("{xtc_threshold}"),
        )));
    }
    let probability_ok = xtc_probability.is_finite() && (0.0..=1.0).contains(&xtc_probability);
    if !probability_ok {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "xtc_probability",
            "must be a finite float in the closed interval [0, 1]",
            format!("{xtc_probability}"),
        )));
    }

    // Non-candidates are replaced by +inf so the row minimum picks the least
    // likely candidate. If there are no candidates, the minimum is +inf and
    // nothing is excluded.
    let probs = ops::misc::softmax_axis(logits, -1, false)?;
    let threshold = scalar_like(xtc_threshold, &probs)?;
    let pos_inf = scalar_like(f32::INFINITY, &probs)?;
    let is_candidate = ops::comparison::greater(&probs, &threshold)?;
    let candidate_probs = ops::logical::select(&is_candidate, &probs, &pos_inf)?;
    let weakest_candidate = ops::reduction::min_axes(&candidate_probs, &[-1], true)?;
    let excluded = ops::comparison::greater(&probs, &weakest_candidate)?;
    let excluded = if xtc_special_tokens.is_empty() {
        excluded
    } else {
        let protected = token_index(logits, xtc_special_tokens)?;
        let never = Array::full::<bool>(&(1,), false)?;
        ops::indexing::put_along_axis(&excluded, &protected, &never, -1)?
    };

    let zero = scalar_like(0.0, logits)?;
    let one = scalar_like(1.0, logits)?;
    let draw = ops::random::uniform(&zero, &one, &[1i32], logits.dtype()?, key)?;
    let probability = scalar_like(xtc_probability, logits)?;
    let skip = ops::comparison::greater(&draw, &probability)?;
    let neg_inf = scalar_like(f32::NEG_INFINITY, logits)?;
    let trimmed = ops::logical::select(&excluded, &neg_inf, logits)?;
    ops::logical::select(&skip, logits, &trimmed)
}
fn token_index(like: &Array, ids: &[i32]) -> Result<Array> {
let ndim = like.ndim().max(1);
let mut shape = vec![1i32; ndim];
let last = shape.len() - 1;
shape[last] = ids.len() as i32;
let dims: &[i32] = &shape;
Array::from_slice::<i32>(ids, &dims)
}
/// Applies a multiplicative repetition penalty to the logits of `token_ids`
/// along the last axis.
///
/// For each listed id:
/// - A negative logit is multiplied by `penalty`.
/// - A non-negative logit is divided by `penalty`.
///
/// Either way, `penalty > 1` lowers previously seen tokens. Other logits are
/// left untouched. An empty `token_ids` returns a clone of `logits`.
///
/// # Errors
/// Returns `Error::OutOfRange` if `penalty` is not finite or not strictly
/// positive. Zero is rejected because dividing a non-negative logit by it
/// yields `+inf` (or `NaN` for a zero logit).
pub fn apply_repetition_penalty(logits: &Array, token_ids: &[i32], penalty: f32) -> Result<Array> {
    if !penalty.is_finite() || penalty <= 0.0 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "penalty",
            "must be a finite positive float",
            format!("{penalty}"),
        )));
    }
    if token_ids.is_empty() {
        return logits.try_clone();
    }
    let idx = token_index(logits, token_ids)?;
    let selected = ops::indexing::take_along_axis(logits, &idx, -1)?;
    let p = scalar_like(penalty, &selected)?;
    let times_penalty = ops::arithmetic::multiply(&selected, &p)?;
    let over_penalty = ops::arithmetic::divide(&selected, &p)?;
    let is_neg = ops::comparison::less(&selected, &scalar_like(0.0, &selected)?)?;
    let penalized = ops::logical::select(&is_neg, &times_penalty, &over_penalty)?;
    ops::indexing::put_along_axis(logits, &idx, &penalized, -1)
}
pub fn apply_presence_penalty(logits: &Array, token_ids: &[i32], penalty: f32) -> Result<Array> {
if token_ids.is_empty() {
return logits.try_clone();
}
let idx = token_index(logits, token_ids)?;
let selected = ops::indexing::take_along_axis(logits, &idx, -1)?;
let reduced = ops::arithmetic::subtract(&selected, &scalar_like(penalty, &selected)?)?;
ops::indexing::put_along_axis(logits, &idx, &reduced, -1)
}
pub fn apply_frequency_penalty(logits: &Array, token_ids: &[i32], penalty: f32) -> Result<Array> {
if token_ids.is_empty() {
return logits.try_clone();
}
let idx = token_index(logits, token_ids)?;
let ndim = logits.ndim().max(1);
let mut vshape = vec![1i32; ndim];
let last = vshape.len() - 1;
vshape[last] = token_ids.len() as i32;
let vdims: &[i32] = &vshape;
let neg_pen = ops::shape::reshape(
&ops::misc::astype(
&Array::full::<f32>(&(token_ids.len(),), -penalty)?,
logits.dtype()?,
)?,
&vdims,
)?;
ops::indexing::scatter_add_axis(logits, &idx, &neg_pen, -1)
}
/// Adds `values` to the logits at `indices` along the last axis.
///
/// `values` may have any shape but must contain exactly `indices.len()`
/// elements. It is cast to the logits' dtype and reshaped to
/// `[1, …, 1, indices.len()]` before being scattered with `scatter_add_axis`.
/// An empty `indices` (with an empty `values`) returns a clone of `logits`.
///
/// # Errors
/// Returns `Error::LengthMismatch` when `values.size() != indices.len()`.
pub fn apply_logit_bias(logits: &Array, indices: &[i32], values: &Array) -> Result<Array> {
    let (expected, actual) = (indices.len(), values.size());
    if expected != actual {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            "apply_logit_bias: indices length vs values.size()",
            expected,
            actual,
        )));
    }
    if indices.is_empty() {
        return logits.try_clone();
    }
    let idx = token_index(logits, indices)?;
    let rank = logits.ndim().max(1);
    let mut bias_shape = vec![1i32; rank];
    // `rank >= 1`, so there is always a trailing axis to size.
    if let Some(trailing) = bias_shape.last_mut() {
        *trailing = indices.len() as i32;
    }
    let bias_dims: &[i32] = &bias_shape;
    let cast_values = ops::misc::astype(values, logits.dtype()?)?;
    let bias = ops::shape::reshape(&cast_values, &bias_dims)?;
    ops::indexing::scatter_add_axis(logits, &idx, &bias, -1)
}