pub fn elementwise_mul(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    a: &MlxBuffer,
    b: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
    dtype: DType,
) -> Result<()>
Encode elementwise multiplication: output = a * b.
Both inputs and output must have the same dtype (F32 or F16).
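
The documented contract (output = a * b, with one shared dtype) can be illustrated with a small host-side reference, which may be handy for checking GPU results in tests. This is only a sketch under stated assumptions: RefDType, RefBuffer, and reference_mul are hypothetical names that are not part of this crate, F16 data is held as f32 on the host for simplicity, and the length check is an assumed precondition that the documentation above does not state. Unlike elementwise_mul, which encodes work onto a command encoder, this version computes the result immediately on the CPU.

```rust
/// Hypothetical stand-in for the crate's `DType`, limited to the two
/// variants the documentation lists as supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RefDType {
    F32,
    F16,
}

/// Hypothetical host-side buffer. Values are stored as `f32` regardless of
/// the tagged dtype; this is a simplification for the sketch only.
pub struct RefBuffer {
    pub dtype: RefDType,
    pub data: Vec<f32>,
}

/// CPU reference for `output[i] = a[i] * b[i]` over the first `n_elements`.
///
/// Returns an error if any buffer's dtype differs from `dtype`, or
/// (assumed precondition) if any buffer holds fewer than `n_elements` values.
pub fn reference_mul(
    a: &RefBuffer,
    b: &RefBuffer,
    output: &mut RefBuffer,
    n_elements: usize,
    dtype: RefDType,
) -> Result<(), String> {
    // Both inputs and the output must share the requested dtype.
    if a.dtype != dtype || b.dtype != dtype || output.dtype != dtype {
        return Err(format!(
            "dtype mismatch: expected {:?}, got a={:?}, b={:?}, output={:?}",
            dtype, a.dtype, b.dtype, output.dtype
        ));
    }

    // Assumed precondition: every buffer covers at least `n_elements`.
    if a.data.len() < n_elements
        || b.data.len() < n_elements
        || output.data.len() < n_elements
    {
        return Err(format!("buffer shorter than n_elements ({})", n_elements));
    }

    // Multiply pairwise; elements past `n_elements` are left untouched.
    for ((out, x), y) in output.data[..n_elements]
        .iter_mut()
        .zip(&a.data[..n_elements])
        .zip(&b.data[..n_elements])
    {
        *out = x * y;
    }
    Ok(())
}
```

Calling reference_mul with two F32 buffers holding [1.0, 2.0, 3.0] and [4.0, 5.0, 6.0] and n_elements = 3 writes [4.0, 10.0, 18.0] into the output buffer, while passing an output buffer tagged F16 alongside F32 inputs returns an error.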