Skip to main content

mlx_native/ops/
encode_helpers.rs

//! Helper utilities for encoding compute dispatches with inline constant
//! parameters (bytes) alongside buffer bindings.
//!
//! The base [`CommandEncoder::encode`] method only supports buffer bindings.
//! The free functions in this module forward to the matching `CommandEncoder`
//! methods (`encode_with_args` and its threadgroup variants), which
//! additionally support Metal `set_bytes` for small constant parameter
//! structs. This avoids allocating a full Metal buffer for a few bytes of
//! configuration data.
//!
//! `KernelArg` and `as_bytes` are defined in `crate::encoder` and re-exported
//! here for backward compatibility.
11
12use metal::{ComputePipelineStateRef, MTLSize};
13
14use crate::encoder::CommandEncoder;
15
16// Re-export from encoder module where KernelArg now lives.
17pub use crate::encoder::{KernelArg, as_bytes};
18
19/// Encode a compute pass with mixed buffer and bytes bindings.
20///
21/// This is an extension of [`CommandEncoder::encode`] that additionally
22/// supports `set_bytes` for small constant parameter structs.
23///
24/// # Arguments
25///
26/// * `encoder`          — The command encoder to record into.
27/// * `pipeline`         — The compiled compute pipeline.
28/// * `bindings`         — Slice of `(index, KernelArg)` pairs.
29/// * `grid_size`        — Total threads to launch.
30/// * `threadgroup_size` — Threads per threadgroup.
31pub fn encode_with_args(
32    encoder: &mut CommandEncoder,
33    pipeline: &ComputePipelineStateRef,
34    bindings: &[(u64, KernelArg<'_>)],
35    grid_size: MTLSize,
36    threadgroup_size: MTLSize,
37) {
38    // Use the encoder's persistent compute encoder via encode_with_args_dispatch.
39    // This delegates to CommandEncoder's own dispatch methods that reuse the
40    // same compute encoder across calls.
41    encoder.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
42}
43
44/// Encode a compute pass with threadgroups and mixed buffer/bytes bindings.
45pub fn encode_threadgroups_with_args(
46    encoder: &mut CommandEncoder,
47    pipeline: &ComputePipelineStateRef,
48    bindings: &[(u64, KernelArg<'_>)],
49    threadgroups: MTLSize,
50    threadgroup_size: MTLSize,
51) {
52    encoder.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
53}
54
55/// Encode a compute pass with threadgroups, mixed buffer/bytes bindings, and
56/// threadgroup shared memory allocations.
57///
58/// Combines the capabilities of [`encode_threadgroups_with_args`] (inline bytes
59/// via `set_bytes`) and the encoder's `encode_threadgroups_with_shared` (shared
60/// memory allocation).  Required by fused kernels that need both constant-struct
61/// parameters and a threadgroup scratch buffer for reduction.
62///
63/// # Arguments
64///
65/// * `encoder`          — The command encoder to record into.
66/// * `pipeline`         — The compiled compute pipeline.
67/// * `bindings`         — Slice of `(index, KernelArg)` pairs.
68/// * `threadgroup_mem`  — Slice of `(index, byte_length)` pairs for threadgroup memory.
69/// * `threadgroups`     — Number of threadgroups to dispatch.
70/// * `threadgroup_size` — Threads per threadgroup.
71pub fn encode_threadgroups_with_args_and_shared(
72    encoder: &mut CommandEncoder,
73    pipeline: &ComputePipelineStateRef,
74    bindings: &[(u64, KernelArg<'_>)],
75    threadgroup_mem: &[(u64, u64)],
76    threadgroups: MTLSize,
77    threadgroup_size: MTLSize,
78) {
79    encoder.encode_threadgroups_with_args_and_shared(
80        pipeline,
81        bindings,
82        threadgroup_mem,
83        threadgroups,
84        threadgroup_size,
85    );
86}
87
88// as_bytes is re-exported from crate::encoder above.