pub fn dispatch_scalar_mul_bf16_with_encoder(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: u32,
    scalar: f32,
) -> Result<()>
Scale bf16 values by a scalar using an externally provided encoder (no commit).
Encodes output[i] = input[i] * scalar (bf16) into the given encoder
without committing or waiting. Use this to chain the scale into a
mega-encoder alongside other GPU work, avoiding CPU round-trips.
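To make the "encode now, commit later" idea concrete, here is a toy model in plain Rust. It is not this crate's API: ToyEncoder, encode_scale, encode_add, and commit are hypothetical names, and the "GPU work" is just closures over an f32 slice. It only illustrates that several recorded operations can be executed by a single commit instead of one round trip per operation.

```rust
/// Hypothetical toy encoder: records operations and runs them only on commit.
/// It is NOT the real CommandEncoder; it stands in for GPU command recording.
#[derive(Default)]
struct ToyEncoder {
    // Each recorded op mutates a shared f32 buffer when it is finally executed.
    ops: Vec<Box<dyn Fn(&mut [f32])>>,
}

impl ToyEncoder {
    /// Record "buf[i] *= scalar" without executing it (analogous to encoding with no commit).
    fn encode_scale(&mut self, scalar: f32) {
        self.ops.push(Box::new(move |buf: &mut [f32]| {
            for x in buf.iter_mut() {
                *x *= scalar;
            }
        }));
    }

    /// Record "buf[i] += bias" without executing it (stands in for "other GPU work").
    fn encode_add(&mut self, bias: f32) {
        self.ops.push(Box::new(move |buf: &mut [f32]| {
            for x in buf.iter_mut() {
                *x += bias;
            }
        }));
    }

    /// Execute every recorded op in order as one "commit"; returns how many ops ran.
    fn commit(self, buf: &mut [f32]) -> usize {
        for op in &self.ops {
            // Explicit reborrow so `buf` stays usable across loop iterations.
            op(&mut *buf);
        }
        self.ops.len()
    }
}
```

The documented function plays the role of encode_scale here: it records the scale into the caller's encoder, and committing remains the caller's responsibility.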
§Arguments
- encoder - Command encoder to record the dispatch into.
- registry - Kernel registry (mutable for lazy pipeline compilation).
- device - Metal device for pipeline compilation.
- input - Input buffer (bf16).
- output - Output buffer (bf16, same size as input).
- n_elements - Number of elements to process.
- scalar - The f32 scalar to multiply by (e.g. sqrt(hidden_size)).
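A CPU reference for the operation can help when checking GPU output in tests. This sketch assumes each bf16 input is widened to f32, multiplied by scalar in f32, and rounded back to bf16 with round-to-nearest-even; the Metal kernel's actual rounding is not specified here, so treat that as an assumption. bf16 values are held as raw u16 bit patterns, and all helper names are hypothetical.

```rust
/// Convert f32 to bf16 bits using round-to-nearest, ties-to-even.
/// NaN inputs stay NaN (sign kept, quiet bit forced on).
fn f32_to_bf16(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit, then truncate: this rounds the
    // 16 discarded low bits to nearest, breaking ties toward an even result.
    let lsb = (bits >> 16) & 1;
    (bits.wrapping_add(0x7FFF + lsb) >> 16) as u16
}

/// Widen bf16 bits to f32 (exact: bf16 is the top half of an f32).
fn bf16_to_f32(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}

/// Hypothetical CPU reference: output[i] = bf16(f32(input[i]) * scalar) for i < n_elements.
/// Elements of output at or beyond n_elements are left untouched.
fn scalar_mul_bf16_reference(input: &[u16], output: &mut [u16], n_elements: usize, scalar: f32) {
    for (out, &inp) in output.iter_mut().zip(input).take(n_elements) {
        *out = f32_to_bf16(bf16_to_f32(inp) * scalar);
    }
}
```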
§Errors
Returns MlxError::InvalidArgument if n_elements is zero or if either buffer
is too small to hold n_elements bf16 values.
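The documented error conditions can be expressed as a pre-dispatch check. This is a hypothetical sketch: ArgError stands in for MlxError::InvalidArgument, and it assumes "too small" means a buffer holds fewer than n_elements * 2 bytes (2 bytes per bf16 element).

```rust
/// Hypothetical stand-in for MlxError::InvalidArgument.
#[derive(Debug, PartialEq)]
enum ArgError {
    InvalidArgument(String),
}

/// Size of one bf16 element in bytes.
const BF16_BYTES: usize = 2;

/// Hypothetical validation mirroring the documented errors: reject a zero
/// element count, and reject any buffer shorter than n_elements bf16 values.
fn validate_scalar_mul_args(
    input_bytes: usize,
    output_bytes: usize,
    n_elements: u32,
) -> Result<(), ArgError> {
    if n_elements == 0 {
        return Err(ArgError::InvalidArgument("n_elements must be > 0".to_string()));
    }
    let needed = n_elements as usize * BF16_BYTES;
    if input_bytes < needed {
        return Err(ArgError::InvalidArgument(format!(
            "input buffer has {input_bytes} bytes, need {needed}"
        )));
    }
    if output_bytes < needed {
        return Err(ArgError::InvalidArgument(format!(
            "output buffer has {output_bytes} bytes, need {needed}"
        )));
    }
    Ok(())
}
```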