use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal shader source for the L2-norm kernels, embedded at compile time from
/// `../shaders/l2_norm.metal` (path relative to this source file).
pub static L2_NORM_SHADER_SOURCE: &str = include_str!("../shaders/l2_norm.metal");
/// Registers every L2-norm kernel name with `registry`.
///
/// All four kernel names are registered against the same shader source,
/// [`L2_NORM_SHADER_SOURCE`].
pub fn register(registry: &mut KernelRegistry) {
    // Kernel entry points, registered in this order.
    const KERNEL_NAMES: [&str; 4] = [
        "l2_norm_f32",
        "l2_norm_f16",
        "l2_norm_bf16",
        "l2_norm_scale_f32",
    ];

    for kernel_name in KERNEL_NAMES {
        registry.register_source(kernel_name, L2_NORM_SHADER_SOURCE);
    }
}
/// Encodes an L2-norm kernel dispatch over a `rows` x `dim` input.
///
/// The kernel variant (`l2_norm_f32`, `l2_norm_f16` or `l2_norm_bf16`) is
/// selected from the dtype of `input`.
///
/// # Arguments
///
/// * `encoder` - command encoder the dispatch is recorded into.
/// * `registry` - kernel registry used to look up the compute pipeline.
/// * `device` - Metal device passed to the pipeline lookup.
/// * `input` - source buffer; must hold exactly `rows * dim` elements.
/// * `output` - destination buffer; must hold exactly `rows * dim` elements
///   and have the same dtype as `input`.
/// * `params_buf` - kernel parameter buffer, passed at index 2. Its contents
///   are not validated here.
/// * `rows` - number of rows; one threadgroup is dispatched per row.
/// * `dim` - number of elements per row.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when `rows` or `dim` is zero, when
/// either buffer's element count differs from `rows * dim`, when the input and
/// output dtypes differ, or when the dtype is not F32, F16 or BF16. Any error
/// from the pipeline lookup is propagated unchanged.
pub fn dispatch_l2_norm(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "L2 norm rows and dim must be > 0".into(),
        ));
    }

    let expected = (rows as usize) * (dim as usize);
    if input.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "L2 norm input element count {} != rows({}) * dim({})",
            input.element_count(),
            rows,
            dim
        )));
    }
    if output.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "L2 norm output element count {} != rows({}) * dim({})",
            output.element_count(),
            rows,
            dim
        )));
    }
    if input.dtype() != output.dtype() {
        return Err(MlxError::InvalidArgument(format!(
            "L2 norm input/output dtype mismatch: {} vs {}",
            input.dtype(),
            output.dtype()
        )));
    }

    let kernel_name = match input.dtype() {
        DType::F32 => "l2_norm_f32",
        DType::F16 => "l2_norm_f16",
        DType::BF16 => "l2_norm_bf16",
        _ => {
            return Err(MlxError::InvalidArgument(format!(
                "L2 norm unsupported dtype: {}",
                input.dtype()
            )));
        }
    };

    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // Clamp before rounding up: `dim.next_power_of_two()` overflows u32 for
    // dim > 2^31 (panic in debug builds, 0 in release builds), which would
    // produce a zero-sized threadgroup. Clamping first yields the same value
    // for every non-overflowing `dim` and 256 otherwise.
    let tg_size = u64::from(dim.min(256).next_power_of_two());
    // 4 bytes per thread — presumably one f32 scratch slot each; confirm
    // against the shader.
    let shared_mem_bytes = tg_size * 4;

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, input), (1, output), (2, params_buf)],
        &[(0, shared_mem_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}
/// Encodes the `l2_norm_scale_f32` kernel dispatch over a `rows` x `dim` input.
///
/// Only F32 buffers are accepted.
///
/// # Arguments
///
/// * `encoder` - command encoder the dispatch is recorded into.
/// * `registry` - kernel registry used to look up the compute pipeline.
/// * `device` - Metal device passed to the pipeline lookup.
/// * `input` - source buffer; must be F32 and hold exactly `rows * dim`
///   elements.
/// * `output` - destination buffer; must be F32 and hold exactly `rows * dim`
///   elements.
/// * `params_buf` - kernel parameter buffer, passed at index 2. Its contents
///   are not validated here.
/// * `rows` - number of rows; one threadgroup is dispatched per row.
/// * `dim` - number of elements per row.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when `rows` or `dim` is zero, when
/// either buffer's element count differs from `rows * dim`, when the input and
/// output dtypes differ, or when the dtype is not F32. Any error from the
/// pipeline lookup is propagated unchanged.
pub fn dispatch_l2_norm_scale_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "L2 norm scale rows and dim must be > 0".into(),
        ));
    }

    let expected = (rows as usize) * (dim as usize);
    if input.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "L2 norm scale input element count {} != rows({}) * dim({})",
            input.element_count(),
            rows,
            dim
        )));
    }
    if output.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "L2 norm scale output element count {} != rows({}) * dim({})",
            output.element_count(),
            rows,
            dim
        )));
    }
    if input.dtype() != output.dtype() {
        return Err(MlxError::InvalidArgument(format!(
            "L2 norm scale input/output dtype mismatch: {} vs {}",
            input.dtype(),
            output.dtype()
        )));
    }
    if input.dtype() != DType::F32 {
        return Err(MlxError::InvalidArgument(format!(
            "L2 norm scale only supports f32 (got {})",
            input.dtype()
        )));
    }

    let pipeline = registry.get_pipeline("l2_norm_scale_f32", device)?;

    // Clamp before rounding up: `dim.next_power_of_two()` overflows u32 for
    // dim > 2^31 (panic in debug builds, 0 in release builds), which would
    // produce a zero-sized threadgroup. Clamping first yields the same value
    // for every non-overflowing `dim` and 256 otherwise.
    let tg_size = u64::from(dim.min(256).next_power_of_two());
    // 4 bytes per thread — presumably one f32 scratch slot each; confirm
    // against the shader.
    let shared_mem_bytes = tg_size * 4;

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, input), (1, output), (2, params_buf)],
        &[(0, shared_mem_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );
    Ok(())
}