use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Text of `../shaders/softcap.metal`, embedded into the binary at compile
/// time via `include_str!`.
///
/// The same source string is registered under every softcap kernel name
/// (see [`register`]).
pub static SOFTCAP_SHADER_SOURCE: &str = include_str!("../shaders/softcap.metal");
/// Registers the softcap shader source with `registry` under each of the
/// kernel names this module dispatches: `softcap_f32`, `softcap_f16` and
/// `softcap_bf16`.
///
/// All three names map to the same [`SOFTCAP_SHADER_SOURCE`] text.
pub fn register(registry: &mut KernelRegistry) {
    // Registration order matches the order the names are listed here.
    for kernel_name in ["softcap_f32", "softcap_f16", "softcap_bf16"] {
        registry.register_source(kernel_name, SOFTCAP_SHADER_SOURCE);
    }
}
/// Validates the arguments and encodes one softcap kernel dispatch on
/// `encoder`.
///
/// The kernel is chosen from the dtype of `input` (`softcap_f32`,
/// `softcap_f16` or `softcap_bf16`), and its pipeline is obtained from
/// `registry` for `device`. The buffers are bound as `(0, input)`,
/// `(1, output)` and `(2, params_buf)`.
///
/// # Parameters
///
/// * `encoder` - command encoder the dispatch is recorded on.
/// * `registry` - kernel registry used to look up the pipeline.
/// * `device` - Metal device passed to the pipeline lookup.
/// * `input` - source buffer; must contain at least one element.
/// * `output` - destination buffer; must have the same element count as
///   `input`.
/// * `params_buf` - bound at index 2 and not inspected here.
///   NOTE(review): presumably this buffer carries the cap value for the
///   shader, since `cap` itself is never uploaded by this function —
///   confirm against the caller and the shader.
/// * `cap` - validated on the host only; must be a positive number.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when:
///
/// * `cap` is NaN, zero or negative;
/// * `input` has zero elements;
/// * `output` has a different element count than `input`;
/// * the dtype of `input` is not `F32`, `F16` or `BF16`.
///
/// Any error from `registry.get_pipeline` is propagated unchanged.
pub fn dispatch_softcap(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    cap: f32,
) -> Result<()> {
    // NaN compares false against everything, so `cap <= 0.0` on its own
    // would accept it; reject it explicitly.
    if cap.is_nan() || cap <= 0.0 {
        return Err(MlxError::InvalidArgument(format!(
            "Softcap cap must be positive, got {}",
            cap
        )));
    }

    let n = input.element_count();
    if n == 0 {
        return Err(MlxError::InvalidArgument(
            "Softcap input must have at least one element".into(),
        ));
    }
    if output.element_count() != n {
        return Err(MlxError::InvalidArgument(format!(
            "Softcap output element count {} != input element count {}",
            output.element_count(),
            n
        )));
    }

    // NOTE(review): only the input dtype selects the kernel; the output dtype
    // is not checked here. Presumably it must match the input — confirm
    // against the shader.
    let kernel_name = match input.dtype() {
        DType::F32 => "softcap_f32",
        DType::F16 => "softcap_f16",
        DType::BF16 => "softcap_bf16",
        _ => {
            return Err(MlxError::InvalidArgument(format!(
                "Softcap unsupported dtype: {}",
                input.dtype()
            )));
        }
    };

    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // One-dimensional launch: at most 256 threads per threadgroup, and enough
    // threadgroups (ceiling division) to cover all `n` elements. `n > 0` is
    // guaranteed above, so the threadgroup size is never zero.
    let threadgroup_size: u64 = std::cmp::min(256, n as u64);
    let threadgroup_count = (n as u64 + threadgroup_size - 1) / threadgroup_size;

    encoder.encode_threadgroups(
        pipeline,
        &[(0, input), (1, output), (2, params_buf)],
        MTLSize::new(threadgroup_count, 1, 1),
        MTLSize::new(threadgroup_size, 1, 1),
    );

    Ok(())
}