

Function dispatch_scalar_mul_bf16_with_encoder 

pub fn dispatch_scalar_mul_bf16_with_encoder(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: u32,
    scalar: f32,
) -> Result<()>

Scale bf16 values by a scalar using an externally provided encoder (no commit).

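Records a dispatch computing output[i] = input[i] * scalar for the first n_elements bf16 values into the given encoder, without committing the command buffer or waiting for it to complete. Use this to chain the scale into a larger "mega-encoder" alongside other GPU work and avoid CPU round-trips between dispatches; the caller remains responsible for committing the encoder.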
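The batching pattern described above can be illustrated with a small mock. This is a hypothetical sketch: `MockEncoder`, `scale_with_encoder`, and `build_forward_pass` are illustrative stand-ins, not this crate's `CommandEncoder` or its API. It only shows that a `*_with_encoder` style helper records work without committing, so several dispatches share a single commit.

```rust
/// Hypothetical stand-in for a command encoder: it only records labels.
/// This is NOT the crate's CommandEncoder.
#[derive(Default)]
struct MockEncoder {
    recorded: Vec<String>,
    commits: usize,
}

impl MockEncoder {
    /// Record one dispatch; nothing is submitted yet.
    fn encode(&mut self, label: &str) {
        self.recorded.push(label.to_string());
    }

    /// Submit everything recorded so far (one CPU/GPU sync point).
    fn commit_and_wait(&mut self) {
        self.commits += 1;
    }
}

/// Hypothetical helper in the spirit of dispatch_scalar_mul_bf16_with_encoder:
/// it records the scale into the caller's encoder and never commits.
fn scale_with_encoder(enc: &mut MockEncoder, scalar: f32) {
    enc.encode(&format!("scalar_mul_bf16 x{scalar}"));
}

/// Chain three dispatches into one encoder, then commit exactly once.
fn build_forward_pass(hidden_size: u32) -> MockEncoder {
    let mut enc = MockEncoder::default();
    enc.encode("embedding_lookup");
    scale_with_encoder(&mut enc, (hidden_size as f32).sqrt());
    enc.encode("rms_norm");
    enc.commit_and_wait();
    enc
}
```

With the real API, the same shape applies: pass one encoder to every `*_with_encoder` call, then commit it once at the end.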

Arguments

  • encoder - Command encoder to record the dispatch into.
  • registry - Kernel registry (mutable for lazy pipeline compilation).
  • device - Metal device for pipeline compilation.
  • input - Input buffer (bf16).
  • output - Output buffer (bf16, same size as input).
  • n_elements - Number of elements to process.
  • scalar - The f32 scalar to multiply by (e.g. sqrt(hidden_size)).
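Given the arguments above, a CPU reference for the per-element result can help when checking GPU output in tests. This is a sketch under stated assumptions: bf16 values are handled as raw `u16` bit patterns, each element is widened to f32, multiplied by the f32 `scalar`, and rounded back to bf16 with round-to-nearest-even. The actual Metal kernel may compute or round differently, so treat small mismatches accordingly. The function names here are hypothetical.

```rust
/// Widen a bf16 bit pattern to f32. This is exact: bf16 is the top 16 bits of an f32.
fn bf16_to_f32(bits: u16) -> f32 {
    f32::from_bits((bits as u32) << 16)
}

/// Narrow an f32 to bf16 using round-to-nearest-even (assumed rounding mode).
fn f32_to_bf16(x: f32) -> u16 {
    let bits = x.to_bits();
    if x.is_nan() {
        // Keep sign/exponent and force a quiet-NaN mantissa bit so NaN stays NaN.
        return ((bits >> 16) as u16) | 0x0040;
    }
    // Add 0x7FFF plus the lowest kept bit, then truncate: ties round to even.
    let lsb = (bits >> 16) & 1;
    (bits.wrapping_add(0x7FFF + lsb) >> 16) as u16
}

/// Hypothetical CPU reference: output[i] = bf16(f32(input[i]) * scalar) for i < n_elements.
/// Panics if either slice holds fewer than n_elements values.
fn scalar_mul_bf16_reference(input: &[u16], output: &mut [u16], n_elements: usize, scalar: f32) {
    for (out, &inp) in output[..n_elements].iter_mut().zip(&input[..n_elements]) {
        *out = f32_to_bf16(bf16_to_f32(inp) * scalar);
    }
}
```

For example, with `scalar = 64.0` (that is, `sqrt(4096)`), a bf16 input of 1.5 maps to 96.0, which is exactly representable in bf16.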

Errors

Returns MlxError::InvalidArgument if n_elements is zero or if either buffer is too small to hold n_elements bf16 values.
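The documented error conditions can be mirrored as a pre-flight check on the caller side. This is a hypothetical sketch: `ArgError` stands in for `MlxError::InvalidArgument`, and it assumes "too small" means a buffer's byte length is below `n_elements * 2` (2 bytes per bf16 value). The crate's actual checks may differ.

```rust
/// Size of one bf16 value in bytes.
const BF16_BYTES: usize = 2;

/// Hypothetical local stand-in for MlxError::InvalidArgument.
#[derive(Debug, PartialEq)]
enum ArgError {
    InvalidArgument(String),
}

/// Hypothetical mirror of the documented checks: rejects n_elements == 0 and
/// buffers (given as byte lengths) that cannot hold n_elements bf16 values.
fn check_scalar_mul_args(
    input_len_bytes: usize,
    output_len_bytes: usize,
    n_elements: u32,
) -> Result<(), ArgError> {
    if n_elements == 0 {
        return Err(ArgError::InvalidArgument("n_elements must be > 0".into()));
    }
    // checked_mul guards against overflow on targets with a 32-bit usize.
    let needed = (n_elements as usize)
        .checked_mul(BF16_BYTES)
        .ok_or_else(|| ArgError::InvalidArgument("n_elements too large".into()))?;
    if input_len_bytes < needed {
        return Err(ArgError::InvalidArgument(format!(
            "input buffer has {input_len_bytes} bytes, need {needed}"
        )));
    }
    if output_len_bytes < needed {
        return Err(ArgError::InvalidArgument(format!(
            "output buffer has {output_len_bytes} bytes, need {needed}"
        )));
    }
    Ok(())
}
```

Running such a check before encoding surfaces the error on the CPU, before any other work has been recorded into the shared encoder.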