// mlx_native/ops/elementwise.rs

//! GPU-accelerated elementwise operations: add, multiply, scalar multiply,
//! embedding gather + scale, and dtype casts.
//!
//! These kernels are used for residual connections (add), scaling (multiply /
//! scalar multiply), token embedding lookup, and dtype conversion (cast) in
//! the inference pipeline.
5
6use metal::MTLSize;
7
8use crate::buffer::MlxBuffer;
9use crate::dtypes::DType;
10use crate::encoder::{CapturedOpKind, CommandEncoder};
11use crate::error::{MlxError, Result};
12use crate::kernel_registry::KernelRegistry;
13
14use super::encode_helpers::{as_bytes, encode_with_args, KernelArg};
15
/// MSL-compatible params struct for elementwise kernels.
///
/// Must match `ElementwiseParams` in `elementwise.metal`.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuElementwiseParams {
    /// Total number of elements the kernel processes.
    n_elements: u32,
}
24
/// Threadgroup size for elementwise dispatches.
///
/// Every dispatch site clamps this to the element count (`min(256, n)`) so
/// tiny inputs never request more threads per group than elements.
const ELEMENTWISE_TG_SIZE: u64 = 256;
27
28/// Select the correct kernel name for a dtype-specific elementwise op.
29fn elementwise_kernel_name(op: &str, dtype: DType) -> Result<&'static str> {
30    match (op, dtype) {
31        ("add", DType::F32) => Ok("elementwise_add_f32"),
32        ("add", DType::F16) => Ok("elementwise_add_f16"),
33        ("add", DType::BF16) => Ok("elementwise_add_bf16"),
34        ("mul", DType::F32) => Ok("elementwise_mul_f32"),
35        ("mul", DType::F16) => Ok("elementwise_mul_f16"),
36        ("mul", DType::BF16) => Ok("elementwise_mul_bf16"),
37        _ => Err(MlxError::InvalidArgument(format!(
38            "elementwise_{op}: unsupported dtype {dtype}"
39        ))),
40    }
41}
42
/// Encode an elementwise binary operation (add or multiply).
///
/// Both inputs and the output must have the same dtype and at least
/// `n_elements * dtype.size_of()` bytes.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if:
/// * `n_elements` is zero
/// * the `op`/`dtype` pair has no kernel (supported: `add`/`mul` over
///   F32, F16, and BF16 — see `elementwise_kernel_name`)
/// * Buffers are too small
#[allow(clippy::too_many_arguments)]
fn elementwise_binary(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    a: &MlxBuffer,
    b: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
    op: &str,
    dtype: DType,
) -> Result<()> {
    if n_elements == 0 {
        return Err(MlxError::InvalidArgument(format!(
            "elementwise_{op}: n_elements must be > 0"
        )));
    }

    // All three buffers must hold at least `n_elements` values of `dtype`.
    let elem_bytes = n_elements * dtype.size_of();
    if a.byte_len() < elem_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "elementwise_{op}: input 'a' buffer too small: need {} bytes, have {}",
            elem_bytes,
            a.byte_len()
        )));
    }
    if b.byte_len() < elem_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "elementwise_{op}: input 'b' buffer too small: need {} bytes, have {}",
            elem_bytes,
            b.byte_len()
        )));
    }
    if output.byte_len() < elem_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "elementwise_{op}: output buffer too small: need {} bytes, have {}",
            elem_bytes,
            output.byte_len()
        )));
    }

    let kernel_name = elementwise_kernel_name(op, dtype)?;
    let pipeline = registry.get_pipeline(kernel_name, device)?;

    let gpu_params = GpuElementwiseParams {
        n_elements: n_elements as u32,
    };

    // One thread per element; threadgroup clamped so n < 256 still works.
    let grid = MTLSize::new(n_elements as u64, 1, 1);
    let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);

    // Tag for the fusion pass (Phase 4e.2): elementwise mul/add can fuse
    // with a preceding RMS norm.
    let op_tag = match op {
        "mul" => CapturedOpKind::ElemMul,
        "add" => CapturedOpKind::ElemAdd,
        _ => CapturedOpKind::Other,
    };
    encoder.set_op_kind(op_tag);

    // Buffer indices follow the kernel's argument order: a, b, output, params.
    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(a)),
            (1, KernelArg::Buffer(b)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg,
    );

    Ok(())
}
129
/// Encode elementwise addition: `output = a + b`.
///
/// Both inputs and output must have the same dtype (F32, F16, or BF16).
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if `n_elements` is zero, the dtype is
/// unsupported, or any buffer is too small.
#[allow(clippy::too_many_arguments)]
pub fn elementwise_add(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    a: &MlxBuffer,
    b: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
    dtype: DType,
) -> Result<()> {
    elementwise_binary(encoder, registry, device, a, b, output, n_elements, "add", dtype)
}
146
/// Encode elementwise multiplication: `output = a * b`.
///
/// Both inputs and output must have the same dtype (F32, F16, or BF16).
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if `n_elements` is zero, the dtype is
/// unsupported, or any buffer is too small.
#[allow(clippy::too_many_arguments)]
pub fn elementwise_mul(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    a: &MlxBuffer,
    b: &MlxBuffer,
    output: &MlxBuffer,
    n_elements: usize,
    dtype: DType,
) -> Result<()> {
    elementwise_binary(encoder, registry, device, a, b, output, n_elements, "mul", dtype)
}
163
164/// MSL-compatible params struct for scalar multiplication.
165///
166/// Must match `ScalarMulParams` in `elementwise.metal`.
167#[repr(C)]
168#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
169struct GpuScalarMulParams {
170    scalar: f32,
171    count: u32,
172}
173
174/// Encode scalar multiplication: `output[i] = input[i] * scalar` (bf16).
175///
176/// # Arguments
177///
178/// * `encoder`    - Command encoder to record the dispatch into.
179/// * `registry`   - Kernel registry.
180/// * `device`     - Metal device for pipeline compilation.
181/// * `input`      - Input buffer (bf16).
182/// * `output`     - Output buffer (bf16, same size as input).
183/// * `n_elements` - Number of elements to process.
184/// * `scalar`     - The f32 scalar to multiply by.
185///
186/// # Errors
187///
188/// Returns `MlxError::InvalidArgument` if `n_elements` is zero or buffers are too small.
189pub fn scalar_mul_bf16(
190    encoder: &mut CommandEncoder,
191    registry: &mut KernelRegistry,
192    device: &metal::DeviceRef,
193    input: &MlxBuffer,
194    output: &MlxBuffer,
195    n_elements: usize,
196    scalar: f32,
197) -> Result<()> {
198    if n_elements == 0 {
199        return Err(MlxError::InvalidArgument(
200            "scalar_mul_bf16: n_elements must be > 0".into(),
201        ));
202    }
203
204    let elem_bytes = n_elements * DType::BF16.size_of();
205    if input.byte_len() < elem_bytes {
206        return Err(MlxError::InvalidArgument(format!(
207            "scalar_mul_bf16: input buffer too small: need {} bytes, have {}",
208            elem_bytes,
209            input.byte_len()
210        )));
211    }
212    if output.byte_len() < elem_bytes {
213        return Err(MlxError::InvalidArgument(format!(
214            "scalar_mul_bf16: output buffer too small: need {} bytes, have {}",
215            elem_bytes,
216            output.byte_len()
217        )));
218    }
219
220    let pipeline = registry.get_pipeline("scalar_mul_bf16", device)?;
221
222    let gpu_params = GpuScalarMulParams {
223        scalar,
224        count: n_elements as u32,
225    };
226
227    let grid = MTLSize::new(n_elements as u64, 1, 1);
228    let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);
229
230    encode_with_args(
231        encoder,
232        pipeline,
233        &[
234            (0, KernelArg::Buffer(input)),
235            (1, KernelArg::Buffer(output)),
236            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
237        ],
238        grid,
239        tg,
240    );
241
242    Ok(())
243}
244
245/// Encode scalar multiplication: `output[i] = input[i] * scalar` (f32).
246///
247/// # Arguments
248///
249/// * `encoder`    - Command encoder to record the dispatch into.
250/// * `registry`   - Kernel registry.
251/// * `device`     - Metal device for pipeline compilation.
252/// * `input`      - Input buffer (f32).
253/// * `output`     - Output buffer (f32, same size as input).
254/// * `n_elements` - Number of elements to process.
255/// * `scalar`     - The f32 scalar to multiply by.
256///
257/// # Errors
258///
259/// Returns `MlxError::InvalidArgument` if `n_elements` is zero or buffers are too small.
260pub fn scalar_mul_f32(
261    encoder: &mut CommandEncoder,
262    registry: &mut KernelRegistry,
263    device: &metal::DeviceRef,
264    input: &MlxBuffer,
265    output: &MlxBuffer,
266    n_elements: usize,
267    scalar: f32,
268) -> Result<()> {
269    if n_elements == 0 {
270        return Err(MlxError::InvalidArgument(
271            "scalar_mul_f32: n_elements must be > 0".into(),
272        ));
273    }
274
275    let elem_bytes = n_elements * DType::F32.size_of();
276    if input.byte_len() < elem_bytes {
277        return Err(MlxError::InvalidArgument(format!(
278            "scalar_mul_f32: input buffer too small: need {} bytes, have {}",
279            elem_bytes,
280            input.byte_len()
281        )));
282    }
283    if output.byte_len() < elem_bytes {
284        return Err(MlxError::InvalidArgument(format!(
285            "scalar_mul_f32: output buffer too small: need {} bytes, have {}",
286            elem_bytes,
287            output.byte_len()
288        )));
289    }
290
291    let pipeline = registry.get_pipeline("scalar_mul_f32", device)?;
292
293    let gpu_params = GpuScalarMulParams {
294        scalar,
295        count: n_elements as u32,
296    };
297
298    let grid = MTLSize::new(n_elements as u64, 1, 1);
299    let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);
300
301    encode_with_args(
302        encoder,
303        pipeline,
304        &[
305            (0, KernelArg::Buffer(input)),
306            (1, KernelArg::Buffer(output)),
307            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
308        ],
309        grid,
310        tg,
311    );
312
313    Ok(())
314}
315
/// MSL-compatible params struct for embedding_gather_scale_f32 kernel.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuEmbedGatherScaleParams {
    /// Scale factor applied to each gathered embedding value.
    scale: f32,
    /// Embedding dimension (row length of the table).
    hidden_size: u32,
    /// Row index into the embedding table.
    token_id: u32,
}
324
325/// Encode an embedding gather + scale: `output[i] = embed[token_id * hs + i] * scale`.
326///
327/// # Arguments
328///
329/// * `encoder`     - Command encoder.
330/// * `registry`    - Kernel registry.
331/// * `device`      - Metal device.
332/// * `embed_table` - f32 `[vocab_size * hidden_size]`.
333/// * `output`      - f32 `[hidden_size]`.
334/// * `token_id`    - Token index into the embedding table.
335/// * `hidden_size` - Embedding dimension.
336/// * `scale`       - Scale factor (e.g. sqrt(hidden_size)).
337pub fn embedding_gather_scale_f32(
338    encoder: &mut CommandEncoder,
339    registry: &mut KernelRegistry,
340    device: &metal::DeviceRef,
341    embed_table: &MlxBuffer,
342    output: &MlxBuffer,
343    token_id: u32,
344    hidden_size: usize,
345    scale: f32,
346) -> Result<()> {
347    if hidden_size == 0 {
348        return Err(MlxError::InvalidArgument(
349            "embedding_gather_scale_f32: hidden_size must be > 0".into(),
350        ));
351    }
352    let out_bytes = hidden_size * std::mem::size_of::<f32>();
353    if output.byte_len() < out_bytes {
354        return Err(MlxError::InvalidArgument(format!(
355            "embedding_gather_scale_f32: output too small: need {} bytes, have {}",
356            out_bytes, output.byte_len()
357        )));
358    }
359
360    let pipeline = registry.get_pipeline("embedding_gather_scale_f32", device)?;
361
362    let gpu_params = GpuEmbedGatherScaleParams {
363        scale,
364        hidden_size: hidden_size as u32,
365        token_id,
366    };
367
368    let grid = MTLSize::new(hidden_size as u64, 1, 1);
369    let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, hidden_size as u64), 1, 1);
370
371    encode_with_args(
372        encoder,
373        pipeline,
374        &[
375            (0, KernelArg::Buffer(embed_table)),
376            (1, KernelArg::Buffer(output)),
377            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
378        ],
379        grid,
380        tg,
381    );
382
383    Ok(())
384}
385
/// GPU params struct for batched embedding gather + scale.
/// Must match `EmbedGatherScaleBatchParams` in the MSL shader.
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuEmbedGatherScaleBatchParams {
    /// Scale factor applied to each gathered embedding value.
    scale: f32,
    /// Embedding dimension (row length of the table).
    hidden_size: u32,
    /// Number of tokens in the batch.
    n_tokens: u32,
}
395
/// Batched embedding gather + scale for prefill (f32).
///
/// Reads `token_ids[tok]` for each `tok in 0..n_tokens`, gathers the
/// embedding row from `embed_table`, multiplies by `scale`, and writes to
/// `output[tok * hidden_size + i]`.
///
/// * `embed_table` — f32 `[vocab_size * hidden_size]`
/// * `token_ids`   — u32 `[n_tokens]`
/// * `output`      — f32 `[n_tokens * hidden_size]`
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if `hidden_size` or `n_tokens` is zero,
/// or the `output` / `token_ids` buffers are too small.
///
/// NOTE(review): the token ids live in GPU memory, so their values cannot be
/// range-checked here; presumably callers must guarantee every id is
/// `< vocab_size` or the kernel reads out of bounds — verify against the shader.
#[allow(clippy::too_many_arguments)]
pub fn embedding_gather_scale_batch_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    embed_table: &MlxBuffer,
    token_ids: &MlxBuffer,
    output: &MlxBuffer,
    hidden_size: usize,
    n_tokens: usize,
    scale: f32,
) -> Result<()> {
    if hidden_size == 0 || n_tokens == 0 {
        return Err(MlxError::InvalidArgument(
            "embedding_gather_scale_batch_f32: hidden_size and n_tokens must be > 0".into(),
        ));
    }
    let out_bytes = n_tokens * hidden_size * std::mem::size_of::<f32>();
    if output.byte_len() < out_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "embedding_gather_scale_batch_f32: output too small: need {} bytes, have {}",
            out_bytes, output.byte_len()
        )));
    }
    let ids_bytes = n_tokens * std::mem::size_of::<u32>();
    if token_ids.byte_len() < ids_bytes {
        return Err(MlxError::InvalidArgument(format!(
            "embedding_gather_scale_batch_f32: token_ids too small: need {} bytes, have {}",
            ids_bytes, token_ids.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("embedding_gather_scale_batch_f32", device)?;

    let gpu_params = GpuEmbedGatherScaleBatchParams {
        scale,
        hidden_size: hidden_size as u32,
        n_tokens: n_tokens as u32,
    };

    // 2-D grid: x covers the hidden dimension, y covers the tokens.
    let grid = MTLSize::new(hidden_size as u64, n_tokens as u64, 1);
    let tg = MTLSize::new(
        std::cmp::min(ELEMENTWISE_TG_SIZE, hidden_size as u64),
        1, 1,
    );

    encode_with_args(
        encoder,
        pipeline,
        &[
            (0, KernelArg::Buffer(embed_table)),
            (1, KernelArg::Buffer(token_ids)),
            (2, KernelArg::Buffer(output)),
            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
        ],
        grid,
        tg,
    );

    Ok(())
}
466
467/// Cast f32 to bf16 using an externally-provided encoder (no commit).
468///
469/// Encodes the `cast_f32_to_bf16` kernel into the given encoder without
470/// committing or waiting.  Use this to chain the cast into a mega-encoder
471/// alongside other GPU work, avoiding CPU round-trips.
472///
473/// # Arguments
474///
475/// * `encoder`    - Command encoder to record the dispatch into.
476/// * `registry`   - Kernel registry (mutable for lazy pipeline compilation).
477/// * `device`     - Metal device for pipeline compilation.
478/// * `input`      - Input buffer (f32).
479/// * `output`     - Output buffer (bf16, pre-allocated with `n_elements * 2` bytes).
480/// * `n_elements` - Number of elements to cast.
481///
482/// # Errors
483///
484/// Returns `MlxError::InvalidArgument` if `n_elements` is zero or buffers are
485/// too small.
486pub fn dispatch_cast_f32_to_bf16_with_encoder(
487    encoder: &mut CommandEncoder,
488    registry: &mut KernelRegistry,
489    device: &metal::DeviceRef,
490    input: &MlxBuffer,
491    output: &MlxBuffer,
492    n_elements: u32,
493) -> Result<()> {
494    cast(
495        encoder,
496        registry,
497        device,
498        input,
499        output,
500        n_elements as usize,
501        CastDirection::F32ToBF16,
502    )
503}
504
505/// Cast bf16 to f32 using an externally-provided encoder (no commit).
506///
507/// Encodes the `cast_bf16_to_f32` kernel into the given encoder without
508/// committing or waiting.  Use this to chain the cast into a mega-encoder
509/// alongside other GPU work, avoiding CPU round-trips.
510///
511/// # Arguments
512///
513/// * `encoder`    - Command encoder to record the dispatch into.
514/// * `registry`   - Kernel registry (mutable for lazy pipeline compilation).
515/// * `device`     - Metal device for pipeline compilation.
516/// * `input`      - Input buffer (bf16).
517/// * `output`     - Output buffer (f32, pre-allocated with `n_elements * 4` bytes).
518/// * `n_elements` - Number of elements to cast.
519///
520/// # Errors
521///
522/// Returns `MlxError::InvalidArgument` if `n_elements` is zero or buffers are
523/// too small.
524pub fn dispatch_cast_bf16_to_f32_with_encoder(
525    encoder: &mut CommandEncoder,
526    registry: &mut KernelRegistry,
527    device: &metal::DeviceRef,
528    input: &MlxBuffer,
529    output: &MlxBuffer,
530    n_elements: u32,
531) -> Result<()> {
532    cast(
533        encoder,
534        registry,
535        device,
536        input,
537        output,
538        n_elements as usize,
539        CastDirection::BF16ToF32,
540    )
541}
542
543/// Scale bf16 values by a scalar using an externally-provided encoder (no commit).
544///
545/// Encodes `output[i] = input[i] * scalar` (bf16) into the given encoder
546/// without committing or waiting.  Use this to chain the scale into a
547/// mega-encoder alongside other GPU work, avoiding CPU round-trips.
548///
549/// # Arguments
550///
551/// * `encoder`    - Command encoder to record the dispatch into.
552/// * `registry`   - Kernel registry (mutable for lazy pipeline compilation).
553/// * `device`     - Metal device for pipeline compilation.
554/// * `input`      - Input buffer (bf16).
555/// * `output`     - Output buffer (bf16, same size as input).
556/// * `n_elements` - Number of elements to process.
557/// * `scalar`     - The f32 scalar to multiply by (e.g. `sqrt(hidden_size)`).
558///
559/// # Errors
560///
561/// Returns `MlxError::InvalidArgument` if `n_elements` is zero or buffers are
562/// too small.
563pub fn dispatch_scalar_mul_bf16_with_encoder(
564    encoder: &mut CommandEncoder,
565    registry: &mut KernelRegistry,
566    device: &metal::DeviceRef,
567    input: &MlxBuffer,
568    output: &MlxBuffer,
569    n_elements: u32,
570    scalar: f32,
571) -> Result<()> {
572    scalar_mul_bf16(
573        encoder,
574        registry,
575        device,
576        input,
577        output,
578        n_elements as usize,
579        scalar,
580    )
581}
582
/// Cast direction for dtype conversion.
pub enum CastDirection {
    /// f16 -> f32
    F16ToF32,
    /// f32 -> f16
    F32ToF16,
    /// bf16 -> f32
    BF16ToF32,
    /// f32 -> bf16
    F32ToBF16,
}

impl CastDirection {
    /// Single source of truth for each direction:
    /// `(kernel name, input element bytes, output element bytes)`.
    fn descriptor(&self) -> (&'static str, usize, usize) {
        match self {
            Self::F16ToF32 => ("cast_f16_to_f32", 2, 4),
            Self::F32ToF16 => ("cast_f32_to_f16", 4, 2),
            Self::BF16ToF32 => ("cast_bf16_to_f32", 2, 4),
            Self::F32ToBF16 => ("cast_f32_to_bf16", 4, 2),
        }
    }

    /// Name of the MSL kernel implementing this cast.
    fn kernel_name(&self) -> &'static str {
        self.descriptor().0
    }

    /// Byte size of one input element.
    fn input_elem_size(&self) -> usize {
        self.descriptor().1
    }

    /// Byte size of one output element.
    fn output_elem_size(&self) -> usize {
        self.descriptor().2
    }
}
619
620/// Encode a dtype cast operation.
621///
622/// # Errors
623///
624/// Returns `MlxError::InvalidArgument` if `n_elements` is zero or buffers
625/// are too small.
626pub fn cast(
627    encoder: &mut CommandEncoder,
628    registry: &mut KernelRegistry,
629    device: &metal::DeviceRef,
630    input: &MlxBuffer,
631    output: &MlxBuffer,
632    n_elements: usize,
633    direction: CastDirection,
634) -> Result<()> {
635    if n_elements == 0 {
636        return Err(MlxError::InvalidArgument(
637            "cast: n_elements must be > 0".into(),
638        ));
639    }
640
641    let input_bytes = n_elements * direction.input_elem_size();
642    if input.byte_len() < input_bytes {
643        return Err(MlxError::InvalidArgument(format!(
644            "cast: input buffer too small: need {} bytes, have {}",
645            input_bytes,
646            input.byte_len()
647        )));
648    }
649
650    let output_bytes = n_elements * direction.output_elem_size();
651    if output.byte_len() < output_bytes {
652        return Err(MlxError::InvalidArgument(format!(
653            "cast: output buffer too small: need {} bytes, have {}",
654            output_bytes,
655            output.byte_len()
656        )));
657    }
658
659    let pipeline = registry.get_pipeline(direction.kernel_name(), device)?;
660
661    let gpu_params = GpuElementwiseParams {
662        n_elements: n_elements as u32,
663    };
664
665    let grid = MTLSize::new(n_elements as u64, 1, 1);
666    let tg = MTLSize::new(std::cmp::min(ELEMENTWISE_TG_SIZE, n_elements as u64), 1, 1);
667
668    encode_with_args(
669        encoder,
670        pipeline,
671        &[
672            (0, KernelArg::Buffer(input)),
673            (1, KernelArg::Buffer(output)),
674            (2, KernelArg::Bytes(as_bytes(&gpu_params))),
675        ],
676        grid,
677        tg,
678    );
679
680    Ok(())
681}