#![cfg(feature = "lm")]
use mlxrs::{Array, Dtype, lm::sample};
/// Shared `[1, 4]` f32 fixture `[-3.0, -2.0, -1.0, 0.0]`; its largest entry sits at index 3.
fn logprobs() -> Array {
    let data: [f32; 4] = [-3.0, -2.0, -1.0, 0.0];
    Array::from_slice::<f32>(&data, &[1, 4]).unwrap()
}
/// `argmax_sample` over the fixture must return a shape-`[1]` array holding index 3.
#[test]
fn argmax_sample_picks_max_index() {
    let fixture = logprobs();
    let mut picked = sample::argmax_sample(&fixture).unwrap();
    assert_eq!(picked.shape(), vec![1]);
    let ids = picked.to_vec::<u32>().unwrap();
    assert_eq!(ids, vec![3]);
}
/// With `k = 1`, index 3 keeps its value 0.0 and indices 0..=2 are masked to `-inf`.
#[test]
fn top_k_1_keeps_only_the_max() {
    let fixture = logprobs();
    let mut filtered = sample::apply_top_k(&fixture, 1).unwrap();
    let values = filtered.to_vec::<f32>().unwrap();
    assert_eq!(values[3], 0.0);
    // Infinite and negative is exactly negative infinity.
    for pruned in [0usize, 1, 2] {
        assert!(values[pruned] == f32::NEG_INFINITY);
    }
}
/// `apply_top_k` must return an error for `k = 0` and for `k = 4` on the 4-entry fixture.
#[test]
fn top_k_out_of_range_errors() {
    let fixture = logprobs();
    for bad_k in [0, 4] {
        assert!(sample::apply_top_k(&fixture, bad_k).is_err());
    }
}
/// With `top_p = 0.999` the output keeps shape `[1, 4]` and every entry stays finite.
#[test]
fn top_p_full_mass_keeps_all() {
    let fixture = logprobs();
    let mut filtered = sample::apply_top_p(&fixture, 0.999).unwrap();
    assert_eq!(filtered.shape(), vec![1, 4]);
    let values = filtered.to_vec::<f32>().unwrap();
    assert!(!values.iter().any(|x| !x.is_finite()));
}
/// Even with `top_p = 0.05`, the entry at index 3 (the fixture's maximum) must remain finite.
#[test]
fn top_p_aggressive_keeps_at_least_the_top() {
    let fixture = logprobs();
    let mut filtered = sample::apply_top_p(&fixture, 0.05).unwrap();
    let values = filtered.to_vec::<f32>().unwrap();
    let top = values[3];
    assert!(top.is_finite());
}
/// `apply_min_p(0.5, 1)` keeps index 3 finite and masks index 0 to `-inf`.
#[test]
fn min_p_keeps_top_and_prunes_tail() {
    let fixture = logprobs();
    let mut filtered = sample::apply_min_p(&fixture, 0.5, 1).unwrap();
    let values = filtered.to_vec::<f32>().unwrap();
    let (top, tail) = (values[3], values[0]);
    assert!(top.is_finite());
    assert!(tail == f32::NEG_INFINITY);
}
/// Each `(min_p, min_tokens_to_keep)` pair below must be rejected by `apply_min_p`.
#[test]
fn min_p_invalid_params_error() {
    let fixture = logprobs();
    let rejected = [(1.5, 1), (0.1, 0), (0.9, 5)];
    for (min_p, min_keep) in rejected {
        assert!(sample::apply_min_p(&fixture, min_p, min_keep).is_err());
    }
}
/// A categorical draw at temperature 0.8 yields a shape-`[1]` result whose single index is `< 4`.
#[test]
fn categorical_sampling_shape_and_bounds() {
    let fixture = logprobs();
    let rng_key = mlxrs::ops::random::key(0).unwrap();
    let mut drawn = sample::categorical_sampling(&fixture, 0.8, &rng_key).unwrap();
    assert_eq!(drawn.shape(), vec![1]);
    let ids = drawn.to_vec::<u32>().unwrap();
    assert_eq!(ids.len(), 1);
    let first = ids[0];
    assert!(first < 4);
}
/// On the unsorted input `[-1, -3, -2, 0]`, `top_p = 0.7` must keep indices 0 and 3 at their
/// original positions and values, and mask indices 1 and 2 to `-inf`.
#[test]
fn top_p_inverse_permutation_preserves_original_order() {
    let shuffled = Array::from_slice::<f32>(&[-1.0, -3.0, -2.0, 0.0], &[1, 4]).unwrap();
    let mut filtered = sample::apply_top_p(&shuffled, 0.7).unwrap();
    let values = filtered.to_vec::<f32>().unwrap();
    assert_eq!(values[0], -1.0, "orig idx 0 kept with its logprob");
    assert!(values[1] == f32::NEG_INFINITY, "orig idx 1 pruned");
    assert!(values[2] == f32::NEG_INFINITY, "orig idx 2 pruned");
    assert_eq!(values[3], 0.0, "orig idx 3 kept with its logprob");
}
/// For F16 and BF16 inputs: top-k, top-p and min-p must return the input dtype and still
/// mask as in the f32 tests, and categorical sampling must draw a single index `< 4`.
#[test]
fn half_and_bfloat_preserve_dtype_and_mask() {
    for dt in [Dtype::F16, Dtype::BF16] {
        let base = Array::from_slice::<f32>(&[-3.0, -2.0, -1.0, 0.0], &[1, 4]).unwrap();
        let lp = base.astype(dt).unwrap();

        // top-k: dtype preserved, argmax kept, tail pruned.
        let top_k = sample::apply_top_k(&lp, 1).unwrap();
        assert_eq!(top_k.dtype().unwrap(), dt, "top_k preserves {dt:?}");
        let top_k_vals = top_k.astype(Dtype::F32).unwrap().to_vec::<f32>().unwrap();
        assert_eq!(
            top_k_vals[3], 0.0,
            "{dt:?} top_k keeps argmax with its logprob"
        );
        assert!(
            top_k_vals[0] == f32::NEG_INFINITY,
            "{dt:?} top_k prunes tail"
        );

        // top-p: dtype preserved, argmax stays finite.
        let top_p = sample::apply_top_p(&lp, 0.7).unwrap();
        assert_eq!(top_p.dtype().unwrap(), dt, "top_p preserves {dt:?}");
        let top_p_vals = top_p.astype(Dtype::F32).unwrap().to_vec::<f32>().unwrap();
        assert!(top_p_vals[3].is_finite(), "{dt:?} top_p keeps argmax");

        // min-p: dtype preserved, top kept, tail pruned.
        let min_p = sample::apply_min_p(&lp, 0.5, 1).unwrap();
        assert_eq!(min_p.dtype().unwrap(), dt, "min_p preserves {dt:?}");
        let min_p_vals = min_p.astype(Dtype::F32).unwrap().to_vec::<f32>().unwrap();
        assert!(min_p_vals[3].is_finite(), "{dt:?} min_p keeps top");
        assert!(
            min_p_vals[0] == f32::NEG_INFINITY,
            "{dt:?} min_p prunes tail"
        );

        // Categorical draw directly on the reduced-precision input.
        let rng_key = mlxrs::ops::random::key(0).unwrap();
        let mut drawn = sample::categorical_sampling(&lp, 0.8, &rng_key).unwrap();
        let ids = drawn.to_vec::<u32>().unwrap();
        assert_eq!(ids.len(), 1);
        assert!(ids[0] < 4, "{dt:?} categorical draws an in-range index");
    }
}
/// `apply_top_p` must reject 0, negatives, values above 1, NaN and infinity, but accept 1.0.
#[test]
fn top_p_out_of_range_errors() {
    let fixture = logprobs();
    for bad_p in [0.0, -0.1, 1.5, f32::NAN, f32::INFINITY] {
        assert!(sample::apply_top_p(&fixture, bad_p).is_err());
    }
    let at_one = sample::apply_top_p(&fixture, 1.0);
    assert!(at_one.is_ok(), "top_p == 1.0 is a valid no-op");
}
/// Zero, negative, NaN and infinite temperatures must all make `categorical_sampling` fail.
#[test]
fn categorical_sampling_invalid_temp_errors() {
    let fixture = logprobs();
    let rng_key = mlxrs::ops::random::key(0).unwrap();
    for bad_temp in [0.0, -1.0, f32::NAN, f32::INFINITY] {
        assert!(sample::categorical_sampling(&fixture, bad_temp, &rng_key).is_err());
    }
}
/// Sampling with a tiny (`1e-7`, F16 input) and a subnormal (`1e-40`, F32/F16/BF16 inputs)
/// temperature must still produce exactly one index inside `[0, 4)`.
#[test]
fn categorical_sampling_tiny_and_subnormal_temp_stays_finite() {
    let key = mlxrs::ops::random::key(0).unwrap();
    let fixture: [f32; 4] = [-3.0, -2.0, -1.0, 0.0];

    // Part 1: tiny-but-normal temperature on an F16 input.
    let half = Array::from_slice::<f32>(&fixture, &[1, 4])
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let mut drawn = sample::categorical_sampling(&half, 1e-7_f32, &key).unwrap();
    let ids = drawn.to_vec::<u32>().unwrap();
    assert_eq!(ids.len(), 1, "f16 tiny-temp shape");
    assert!(
        ids[0] < 4,
        "f16 tiny-temp must draw within [0, vocab); got {}",
        ids[0]
    );

    // Part 2: subnormal temperature whose f32 reciprocal overflows to infinity.
    let subnormal: f32 = 1e-40;
    assert!(subnormal.is_finite() && subnormal > 0.0);
    let reciprocal = 1.0_f32 / subnormal;
    assert!(
        reciprocal.is_infinite(),
        "test premise: 1/temp overflows f32 reciprocal"
    );
    for dt in [Dtype::F32, Dtype::F16, Dtype::BF16] {
        let lp = Array::from_slice::<f32>(&fixture, &[1, 4])
            .unwrap()
            .astype(dt)
            .unwrap();
        let mut drawn = sample::categorical_sampling(&lp, subnormal, &key).unwrap();
        let ids = drawn.to_vec::<u32>().unwrap();
        assert_eq!(ids.len(), 1, "{dt:?} subnormal-temp shape");
        assert!(
            ids[0] < 4,
            "{dt:?} subnormal-temp must draw within [0, vocab); got {}",
            ids[0]
        );
    }
}
/// With a subnormal temperature (`1e-40`), `scale_logits_by_temp` must preserve dtype and
/// length, produce no NaN, and leave zero logits at exactly zero; a categorical draw on the
/// same input must return one index `< 10`. Checked for F32, F16 and BF16.
#[test]
fn categorical_sampling_tiny_temp_produces_finite_scaled_logits() {
    let temp: f32 = 1e-40;
    assert!(
        temp > 0.0 && temp.is_finite(),
        "test premise: temp is finite +ve"
    );
    // Only index 3 is non-zero.
    let raw: [f32; 10] = [0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0];
    for dt in [Dtype::F32, Dtype::F16, Dtype::BF16] {
        let lp = Array::from_slice::<f32>(&raw, &[1, 10])
            .unwrap()
            .astype(dt)
            .unwrap();

        let scaled = sample::scale_logits_by_temp(&lp, temp).unwrap();
        assert_eq!(scaled.dtype().unwrap(), dt, "{dt:?} preserves dtype");
        let sv = scaled.astype(Dtype::F32).unwrap().to_vec::<f32>().unwrap();
        assert_eq!(sv.len(), 10, "{dt:?} scaled shape preserved");

        // Pass 1: nothing may be NaN.
        for (i, x) in sv.iter().enumerate() {
            assert!(
                !x.is_nan(),
                "{dt:?} scaled[{i}] must NOT be NaN under sub-min-subnormal temp; got {x} in {sv:?}"
            );
        }
        // Pass 2: zero inputs stay exactly zero; the max position is only required to be non-NaN.
        for (i, x) in sv.iter().enumerate() {
            match i {
                3 => assert!(!x.is_nan(), "{dt:?} max position must NOT be NaN"),
                _ => assert!(
                    x.is_finite() && *x == 0.0,
                    "{dt:?} zero position {i} must stay finite zero, got {x}"
                ),
            }
        }

        let rng_key = mlxrs::ops::random::key(0).unwrap();
        let mut drawn = sample::categorical_sampling(&lp, temp, &rng_key).unwrap();
        let ids = drawn.to_vec::<u32>().unwrap();
        assert_eq!(ids.len(), 1, "{dt:?} categorical shape");
        assert!(
            ids[0] < 10,
            "{dt:?} categorical draws in-range index; got {}",
            ids[0]
        );
    }
}
/// Shared `[1, 4]` f32 fixture `[2.0, -4.0, 1.0, -1.0]` with mixed signs, used by the
/// penalty and bias tests.
fn penalty_logits() -> Array {
    let data: [f32; 4] = [2.0, -4.0, 1.0, -1.0];
    Array::from_slice::<f32>(&data, &[1, 4]).unwrap()
}
/// Casts `a` to `dt`, then to F32, and returns the resulting elements as a `Vec<f32>`.
/// Panics if any cast or the read-back fails.
fn vals(a: &Array, dt: Dtype) -> Vec<f32> {
    let in_dt = a.astype(dt).unwrap();
    let mut as_f32 = in_dt.astype(Dtype::F32).unwrap();
    as_f32.to_vec::<f32>().unwrap()
}
/// For log-probabilities of `[0.5, 0.3, 0.15, 0.05]`, `apply_xtc(1.0, 0.1, &[])` must mask
/// indices 0 and 1 to `-inf` and leave indices 2 and 3 unchanged.
#[test]
fn xtc_excludes_top_choices_above_cutoff() {
    let log_probs = [0.5f32.ln(), 0.3f32.ln(), 0.15f32.ln(), 0.05f32.ln()];
    let input = Array::from_slice::<f32>(&log_probs, &[1, 4]).unwrap();
    let rng_key = mlxrs::ops::random::key(0).unwrap();
    let mut filtered = sample::apply_xtc(&input, 1.0, 0.1, &[], &rng_key).unwrap();
    let values = filtered.to_vec::<f32>().unwrap();
    assert!(values[0] == f32::NEG_INFINITY, "top prob 0.5 excluded");
    assert!(values[1] == f32::NEG_INFINITY, "prob 0.3 excluded");
    assert_eq!(
        values[2],
        0.15f32.ln(),
        "boundary prob 0.15 kept (strict >)"
    );
    assert_eq!(values[3], 0.05f32.ln(), "tail prob 0.05 kept");
}
/// Listing index 0 as a special token must keep its value under XTC, while index 1 is
/// still masked to `-inf`.
#[test]
fn xtc_special_tokens_are_preserved() {
    let log_probs = [0.5f32.ln(), 0.3f32.ln(), 0.15f32.ln(), 0.05f32.ln()];
    let input = Array::from_slice::<f32>(&log_probs, &[1, 4]).unwrap();
    let rng_key = mlxrs::ops::random::key(0).unwrap();
    let mut filtered = sample::apply_xtc(&input, 1.0, 0.1, &[0], &rng_key).unwrap();
    let values = filtered.to_vec::<f32>().unwrap();
    assert_eq!(values[0], 0.5f32.ln(), "special idx0 kept despite mask");
    assert!(
        values[1] == f32::NEG_INFINITY,
        "non-special idx1 excluded"
    );
}
/// On an all-zero input with threshold 0.4, `apply_xtc` must return the input values unchanged.
#[test]
fn xtc_no_token_above_threshold_is_identity() {
    let flat = Array::from_slice::<f32>(&[0.0, 0.0, 0.0, 0.0], &[1, 4]).unwrap();
    let rng_key = mlxrs::ops::random::key(0).unwrap();
    let mut filtered = sample::apply_xtc(&flat, 1.0, 0.4, &[], &rng_key).unwrap();
    let got = filtered.to_vec::<f32>().unwrap();
    assert_eq!(got, vec![0.0, 0.0, 0.0, 0.0]);
}
/// `apply_xtc` must reject out-of-range or NaN probability/threshold values, and accept
/// probability 0.5 together with threshold 0.5.
#[test]
fn xtc_invalid_params_error() {
    let fixture = logprobs();
    let rng_key = mlxrs::ops::random::key(0).unwrap();
    // (probability, threshold, failure label)
    let rejected = [
        (0.5, 0.6, "thr>0.5"),
        (0.5, -0.1, "thr<0"),
        (1.5, 0.3, "prob>1"),
        (f32::NAN, 0.3, "prob NaN"),
        (0.5, f32::NAN, "thr NaN"),
    ];
    for (prob, thr, why) in rejected {
        assert!(
            sample::apply_xtc(&fixture, prob, thr, &[], &rng_key).is_err(),
            "{why}"
        );
    }
    let boundary = sample::apply_xtc(&fixture, 0.5, 0.5, &[], &rng_key);
    assert!(boundary.is_ok(), "thr==0.5 / prob==0.5 valid");
}
/// With penalty 2.0 on ids 0 and 1, the positive logit 2.0 becomes 1.0 and the negative
/// logit -4.0 becomes -8.0; the other entries are unchanged.
#[test]
fn repetition_penalty_sign_aware() {
    let logits = penalty_logits();
    let mut penalized = sample::apply_repetition_penalty(&logits, &[0, 1], 2.0).unwrap();
    let got = penalized.to_vec::<f32>().unwrap();
    assert_eq!(got, vec![1.0, -8.0, 1.0, -1.0]);
}
/// An empty token list leaves the logits unchanged; a negative or NaN penalty is an error.
#[test]
fn repetition_penalty_empty_tokens_is_identity_and_bad_penalty_errors() {
    let logits = penalty_logits();
    let mut untouched = sample::apply_repetition_penalty(&logits, &[], 2.0).unwrap();
    let got = untouched.to_vec::<f32>().unwrap();
    assert_eq!(got, vec![2.0, -4.0, 1.0, -1.0]);
    for bad_penalty in [-1.0, f32::NAN] {
        assert!(sample::apply_repetition_penalty(&logits, &[0], bad_penalty).is_err());
    }
}
/// A presence penalty of 1.5 on ids 0 and 2 lowers exactly those two logits by 1.5.
#[test]
fn presence_penalty_subtracts_once() {
    let logits = penalty_logits();
    let mut penalized = sample::apply_presence_penalty(&logits, &[0, 2], 1.5).unwrap();
    let got = penalized.to_vec::<f32>().unwrap();
    assert_eq!(got, vec![0.5, -4.0, -0.5, -1.0]);
}
/// Repeating id 1 in the token list still lowers its logit by only one penalty (−4.0 → −5.5).
#[test]
fn presence_penalty_duplicate_ids_penalized_once() {
    let logits = penalty_logits();
    let mut penalized = sample::apply_presence_penalty(&logits, &[1, 1], 1.5).unwrap();
    let got = penalized.to_vec::<f32>().unwrap();
    assert_eq!(got, vec![2.0, -5.5, 1.0, -1.0]);
}
/// With penalty 0.5, id 1 (seen twice) drops by 1.0 and id 2 (seen once) drops by 0.5.
#[test]
fn frequency_penalty_scales_with_count() {
    let logits = penalty_logits();
    let mut penalized = sample::apply_frequency_penalty(&logits, &[1, 1, 2], 0.5).unwrap();
    let got = penalized.to_vec::<f32>().unwrap();
    assert_eq!(got, vec![2.0, -5.0, 0.5, -1.0]);
}
/// An empty token list must leave the logits unchanged under the frequency penalty.
#[test]
fn frequency_penalty_empty_tokens_is_identity() {
    let logits = penalty_logits();
    let mut untouched = sample::apply_frequency_penalty(&logits, &[], 0.5).unwrap();
    let got = untouched.to_vec::<f32>().unwrap();
    assert_eq!(got, vec![2.0, -4.0, 1.0, -1.0]);
}
/// Bias values `[1.0, -2.0]` at indices `[0, 3]` are added to exactly those logits.
#[test]
fn logit_bias_adds_at_indices() {
    let logits = penalty_logits();
    let bias_values = Array::from_slice::<f32>(&[1.0, -2.0], &[2]).unwrap();
    let mut biased = sample::apply_logit_bias(&logits, &[0, 3], &bias_values).unwrap();
    let got = biased.to_vec::<f32>().unwrap();
    assert_eq!(got, vec![3.0, -4.0, 1.0, -3.0]);
}
/// Two bias entries aimed at index 2 (1.0 and 0.5) must both be applied: 1.0 → 2.5.
#[test]
fn logit_bias_duplicate_indices_accumulate() {
    let logits = penalty_logits();
    let bias_values = Array::from_slice::<f32>(&[1.0, 0.5], &[2]).unwrap();
    let mut biased = sample::apply_logit_bias(&logits, &[2, 2], &bias_values).unwrap();
    let got = biased.to_vec::<f32>().unwrap();
    assert_eq!(got, vec![2.0, -4.0, 2.5, -1.0]);
}
/// Index/value count mismatches (3 vs 2 and 0 vs 2) must error; zero indices with a
/// zero-length value array must return the logits unchanged.
#[test]
fn logit_bias_length_mismatch_errors_and_empty_is_identity() {
    let logits = penalty_logits();
    let two_values = Array::from_slice::<f32>(&[1.0, -2.0], &[2]).unwrap();

    let too_many_indices = sample::apply_logit_bias(&logits, &[0, 1, 2], &two_values);
    assert!(too_many_indices.is_err(), "3 indices vs 2 values");

    let no_indices = sample::apply_logit_bias(&logits, &[], &two_values);
    assert!(
        no_indices.is_err(),
        "0 indices vs 2 values must error, not silently drop the bias"
    );

    let no_values = Array::from_slice::<f32>(&[], &[0i32]).unwrap();
    let mut untouched = sample::apply_logit_bias(&logits, &[], &no_values).unwrap();
    let got = untouched.to_vec::<f32>().unwrap();
    assert_eq!(got, vec![2.0, -4.0, 1.0, -1.0]);
}
/// For F16 and BF16 inputs, the repetition/presence/frequency penalties, logit bias and XTC
/// must each return the input dtype and the same values as the f32 tests expect.
#[test]
fn new_transforms_preserve_half_and_bfloat_dtype() {
    for dt in [Dtype::F16, Dtype::BF16] {
        let logits = penalty_logits().astype(dt).unwrap();

        // Repetition penalty.
        let repetition = sample::apply_repetition_penalty(&logits, &[0, 1], 2.0).unwrap();
        assert_eq!(repetition.dtype().unwrap(), dt, "rep preserves {dt:?}");
        assert_eq!(
            vals(&repetition, dt),
            vec![1.0, -8.0, 1.0, -1.0],
            "{dt:?} rep"
        );

        // Presence penalty.
        let presence = sample::apply_presence_penalty(&logits, &[0, 2], 1.5).unwrap();
        assert_eq!(presence.dtype().unwrap(), dt, "pres preserves {dt:?}");
        assert_eq!(
            vals(&presence, dt),
            vec![0.5, -4.0, -0.5, -1.0],
            "{dt:?} pres"
        );

        // Frequency penalty.
        let frequency = sample::apply_frequency_penalty(&logits, &[1, 1, 2], 0.5).unwrap();
        assert_eq!(frequency.dtype().unwrap(), dt, "freq preserves {dt:?}");
        assert_eq!(
            vals(&frequency, dt),
            vec![2.0, -5.0, 0.5, -1.0],
            "{dt:?} freq"
        );

        // Logit bias, with the bias values cast to the same dtype.
        let bias_values = Array::from_slice::<f32>(&[1.0, -2.0], &[2])
            .unwrap()
            .astype(dt)
            .unwrap();
        let biased = sample::apply_logit_bias(&logits, &[0, 3], &bias_values).unwrap();
        assert_eq!(biased.dtype().unwrap(), dt, "bias preserves {dt:?}");
        assert_eq!(
            vals(&biased, dt),
            vec![3.0, -4.0, 1.0, -3.0],
            "{dt:?} bias"
        );

        // XTC on log-probabilities of [0.5, 0.3, 0.15, 0.05].
        let log_probs = [0.5f32.ln(), 0.3f32.ln(), 0.15f32.ln(), 0.05f32.ln()];
        let xtc_input = Array::from_slice::<f32>(&log_probs, &[1, 4])
            .unwrap()
            .astype(dt)
            .unwrap();
        let rng_key = mlxrs::ops::random::key(0).unwrap();
        let xtc = sample::apply_xtc(&xtc_input, 1.0, 0.1, &[], &rng_key).unwrap();
        assert_eq!(xtc.dtype().unwrap(), dt, "xtc preserves {dt:?}");
        let xtc_vals = vals(&xtc, dt);
        assert!(xtc_vals[0] == f32::NEG_INFINITY, "{dt:?} xtc excl idx0");
        assert!(xtc_vals[1] == f32::NEG_INFINITY, "{dt:?} xtc excl idx1");
        assert!(
            xtc_vals[2].is_finite() && xtc_vals[3].is_finite(),
            "{dt:?} xtc keeps tail"
        );
    }
}
/// A frequency penalty of 70000 on an F16 input must drive only the selected column to
/// `-inf` and leave every other column exactly as it was (no NaN, `-0.0` keeps its sign bit).
/// Also checks an f32 case where id 2 appears twice with penalty 5.0.
#[test]
fn apply_frequency_penalty_f16_large_penalty_no_nan_bleed() {
    // Case 1: F16 input, huge penalty on column 1.
    let half_logits = Array::from_slice::<f32>(&[2.0, -4.0, 1.0, -1.0], &(1, 4))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let penalized = sample::apply_frequency_penalty(&half_logits, &[1], 70000.0).unwrap();
    assert_eq!(penalized.dtype().unwrap(), Dtype::F16, "dtype preserved");
    let v = vals(&penalized, Dtype::F16);
    assert_eq!(v[0], 2.0, "untouched col 0 bit-exact, not NaN: {v:?}");
    assert!(
        v[1] == f32::NEG_INFINITY,
        "selected col 1 suppressed: {v:?}"
    );
    assert_eq!(v[2], 1.0, "untouched col 2 bit-exact, not NaN: {v:?}");
    assert_eq!(v[3], -1.0, "untouched col 3 bit-exact, not NaN: {v:?}");

    // Case 2: f32 input, id 2 counted twice.
    let full_logits = Array::from_slice::<f32>(&[10.0, 20.0, 30.0], &(1, 3)).unwrap();
    let full_out = sample::apply_frequency_penalty(&full_logits, &[2, 2], 5.0).unwrap();
    assert_eq!(
        vals(&full_out, Dtype::F32),
        vec![10.0, 20.0, 20.0],
        "id 2 twice → -10; others exact"
    );

    // Case 3: an untouched negative zero must keep its bit pattern.
    let signed_zero = Array::from_slice::<f32>(&[-0.0, 5.0], &(1, 2))
        .unwrap()
        .astype(Dtype::F16)
        .unwrap();
    let zero_out = sample::apply_frequency_penalty(&signed_zero, &[1], 70000.0).unwrap();
    let sv = vals(&zero_out, Dtype::F16);
    let negative_zero_bits = (-0.0_f32).to_bits();
    assert_eq!(
        sv[0].to_bits(),
        negative_zero_bits,
        "untouched -0.0 must stay raw -0.0 (no signed-zero canonicalization): {sv:?}"
    );
    assert!(
        sv[1] == f32::NEG_INFINITY,
        "selected col suppressed: {sv:?}"
    );
}
/// Source-level guard: reads `src/lm/sample.rs`, finds the `scale_logits_by_temp` signature,
/// and requires the first non-blank, non-`//` line after it to be the
/// `ensure_handler_installed()` call. Up to three such lines are collected for the message.
#[test]
fn scale_logits_by_temp_ensures_handler_installed() {
    let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("src/lm/sample.rs");
    let src = std::fs::read_to_string(&path).expect("read sample.rs");
    let sig = "pub fn scale_logits_by_temp(logits: &Array, temp: f32) -> Result<Array> {";
    let sig_idx = src
        .find(sig)
        .expect("scale_logits_by_temp signature must exist in sample.rs");
    let body = &src[sig_idx + sig.len()..];

    // Skip blank lines and line comments; keep at most the first three remaining lines.
    let executable: Vec<&str> = body
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty() && !line.starts_with("//"))
        .take(3)
        .collect();

    assert!(
        !executable.is_empty(),
        "scale_logits_by_temp body must contain at least one executable line"
    );
    assert_eq!(
        executable[0], "crate::error::ensure_handler_installed();",
        "scale_logits_by_temp MUST call ensure_handler_installed() as the FIRST \
         executable statement; first 3 executable lines were: {executable:?}"
    );
}
/// F64 logits must be rejected by `scale_logits_by_temp` with a message that names F64 and
/// tells the caller to cast; `categorical_sampling` must fail on the same input. The input is
/// first shown to be strictly increasing in f64 yet to collapse neighbours after an f32 round-trip.
#[test]
fn scale_logits_by_temp_rejects_f64() {
    let base: f64 = 1.0;
    let step: f64 = 1e-9;
    let input: Vec<f64> = (0..10).map(|i| base + (i as f64) * step).collect();
    for pair in input.windows(2) {
        assert!(
            pair[0] < pair[1],
            "test premise: input strictly monotonic in f64"
        );
    }

    let f32_roundtripped: Vec<f64> = input.iter().map(|x| *x as f32 as f64).collect();
    let collapses = f32_roundtripped.windows(2).any(|pair| pair[0] == pair[1]);
    assert!(
        collapses,
        "test premise: f32 roundtrip collapses some adjacent f64 values to equal — \
         this is the silent precision loss the F64 rejection protects against; \
         got {f32_roundtripped:?}"
    );

    let lp = Array::from_slice::<f64>(&input, &[1, 10]).unwrap();
    assert_eq!(lp.dtype().unwrap(), Dtype::F64);

    let err = sample::scale_logits_by_temp(&lp, 0.5).unwrap_err();
    let msg = format!("{err}");
    assert!(
        msg.contains("F64"),
        "F64 rejection message must mention F64 explicitly: {msg}"
    );
    let mentions_cast = msg.contains("astype") || msg.contains("Cast");
    assert!(
        mentions_cast,
        "F64 rejection message must tell the caller to cast: {msg}"
    );

    let rng_key = mlxrs::ops::random::key(0).unwrap();
    let sampled = sample::categorical_sampling(&lp, 0.5, &rng_key);
    assert!(
        sampled.is_err(),
        "categorical_sampling on F64 must error via scale_logits_by_temp"
    );
}
/// I32 logits must be rejected with `Error::UnsupportedDtype` whose payload reports I32, the
/// supported list `[F32, F16, BF16]`, and a context containing "logits dtype";
/// `categorical_sampling` must fail on the same input.
#[test]
fn scale_logits_by_temp_rejects_integer_dtype() {
    let lp = Array::from_slice::<i32>(&[1, 2, 3, 4], &[1, 4]).unwrap();
    assert_eq!(lp.dtype().unwrap(), Dtype::I32);

    let err = sample::scale_logits_by_temp(&lp, 0.8).unwrap_err();
    // Any other error variant fails the test.
    let p = match err {
        mlxrs::Error::UnsupportedDtype(p) => p,
        other => panic!("expected Error::UnsupportedDtype for i32 logits, got {other:?}"),
    };
    assert_eq!(
        p.dtype(),
        Dtype::I32,
        "rejection payload names rejected dtype"
    );
    assert_eq!(p.supported(), &[Dtype::F32, Dtype::F16, Dtype::BF16]);
    assert!(
        p.context().contains("logits dtype"),
        "context names the logits-dtype site: {}",
        p.context()
    );

    let rng_key = mlxrs::ops::random::key(0).unwrap();
    let sampled = sample::categorical_sampling(&lp, 0.8, &rng_key);
    assert!(
        sampled.is_err(),
        "categorical_sampling on i32 logits must error via scale_logits_by_temp"
    );
}