use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal shader source for the `softmax_sample_f32` kernel.
///
/// Embedded at compile time via `include_str!`, so the path
/// `../shaders/softmax_sample.metal` is resolved relative to this source file
/// and the build fails if the shader file is missing.
pub static SOFTMAX_SAMPLE_SHADER_SOURCE: &str =
include_str!("../shaders/softmax_sample.metal");
/// Adds the softmax-sample shader source to `registry`.
///
/// The source is registered under the kernel name `softmax_sample_f32`, which
/// is the name used later to fetch the compiled pipeline.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "softmax_sample_f32";
    registry.register_source(kernel_name, SOFTMAX_SAMPLE_SHADER_SOURCE);
}
/// Encodes a single-threadgroup dispatch of the `softmax_sample_f32` kernel.
///
/// The kernel is launched as exactly one threadgroup. Its width is
/// `n_elements` rounded up to a power of two and capped at 1024 threads, with
/// one `f32` of threadgroup memory per thread bound at threadgroup-memory
/// index 0.
///
/// Buffer bindings:
/// - 0: `logits` (must hold exactly `n_elements` elements)
/// - 1: `scratch` (must hold exactly `n_elements` elements)
/// - 2: `out_token` (at least 1 element)
/// - 3: `out_logprob` (at least 1 element)
/// - 4: `params_buf`
///
/// NOTE(review): `temperature` and `random_val` are only validated here; this
/// function never writes them into `params_buf`. Presumably the caller has
/// already packed them (and `n_elements`) into `params_buf` in the layout the
/// shader expects. TODO confirm against `shaders/softmax_sample.metal`.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] in any of these cases:
/// - `n_elements` is 0.
/// - `temperature` is not a finite, strictly positive number.
/// - `random_val` is not in `[0, 1)`; this includes NaN.
/// - The element count of `logits` or `scratch` differs from `n_elements`.
/// - `out_token` or `out_logprob` is empty.
///
/// Any error returned by `registry.get_pipeline` is propagated unchanged.
pub fn dispatch_softmax_sample_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    logits: &MlxBuffer,
    scratch: &MlxBuffer,
    out_token: &MlxBuffer,
    out_logprob: &MlxBuffer,
    params_buf: &MlxBuffer,
    n_elements: u32,
    temperature: f32,
    random_val: f32,
) -> Result<()> {
    /// Upper bound on the threadgroup width used for this kernel.
    const MAX_THREADS_PER_THREADGROUP: u32 = 1024;

    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(
            "softmax_sample_f32: n_elements must be > 0".into(),
        ));
    }
    // A bare `temperature <= 0.0` check lets NaN through, because every
    // comparison with NaN is false. It also accepts +inf. Require a finite,
    // strictly positive value instead.
    if !(temperature.is_finite() && temperature > 0.0) {
        return Err(MlxError::InvalidArgument(format!(
            "softmax_sample_f32: temperature must be finite and > 0, got {}",
            temperature
        )));
    }
    // `Range::contains` is false for NaN, so NaN is rejected here as well.
    if !(0.0..1.0).contains(&random_val) {
        return Err(MlxError::InvalidArgument(format!(
            "softmax_sample_f32: random_val must be in [0, 1), got {}",
            random_val
        )));
    }
    if logits.element_count() != n_elements as usize {
        return Err(MlxError::InvalidArgument(format!(
            "softmax_sample_f32: logits element count {} != n_elements {}",
            logits.element_count(),
            n_elements
        )));
    }
    if scratch.element_count() != n_elements as usize {
        return Err(MlxError::InvalidArgument(format!(
            "softmax_sample_f32: scratch element count {} != n_elements {}",
            scratch.element_count(),
            n_elements
        )));
    }
    if out_token.element_count() == 0 {
        return Err(MlxError::InvalidArgument(
            "softmax_sample_f32: out_token must have at least 1 element".into(),
        ));
    }
    if out_logprob.element_count() == 0 {
        return Err(MlxError::InvalidArgument(
            "softmax_sample_f32: out_logprob must have at least 1 element".into(),
        ));
    }

    let pipeline = registry.get_pipeline("softmax_sample_f32", device)?;

    // Use one thread per element up to the cap, rounded up to a power of two.
    // Clamp *before* rounding: `u32::next_power_of_two` overflows for inputs
    // above 2^31 (panic in debug builds, 0 in release). Clamping first yields
    // the same width for every smaller input.
    let tg_size = u64::from(
        n_elements
            .min(MAX_THREADS_PER_THREADGROUP)
            .next_power_of_two(),
    );
    // One f32 slot of threadgroup memory per thread.
    let shared_mem_bytes = tg_size * std::mem::size_of::<f32>() as u64;

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[
            (0, logits),
            (1, scratch),
            (2, out_token),
            (3, out_logprob),
            (4, params_buf),
        ],
        &[(0, shared_mem_bytes)],
        MTLSize::new(1, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}