mlx_native/ops/encode_helpers.rs
1//! Helper utilities for encoding compute dispatches with inline constant
2//! parameters (bytes) alongside buffer bindings.
3//!
4//! The base [`CommandEncoder::encode`] method only supports buffer bindings.
5//! These helpers extend encoding to support Metal `set_bytes` for small
6//! constant parameter structs, which avoids allocating a full Metal buffer
7//! for a few bytes of configuration data.
8//!
9//! `KernelArg` and `as_bytes` are defined in `crate::encoder` and re-exported
10//! here for backward compatibility.
11
12use metal::{ComputePipelineStateRef, MTLSize};
13
14use crate::encoder::CommandEncoder;
15
16// Re-export from encoder module where KernelArg now lives.
17pub use crate::encoder::{KernelArg, as_bytes};
18
19/// Encode a compute pass with mixed buffer and bytes bindings.
20///
21/// This is an extension of [`CommandEncoder::encode`] that additionally
22/// supports `set_bytes` for small constant parameter structs.
23///
24/// # Arguments
25///
26/// * `encoder` — The command encoder to record into.
27/// * `pipeline` — The compiled compute pipeline.
28/// * `bindings` — Slice of `(index, KernelArg)` pairs.
29/// * `grid_size` — Total threads to launch.
30/// * `threadgroup_size` — Threads per threadgroup.
31pub fn encode_with_args(
32 encoder: &mut CommandEncoder,
33 pipeline: &ComputePipelineStateRef,
34 bindings: &[(u64, KernelArg<'_>)],
35 grid_size: MTLSize,
36 threadgroup_size: MTLSize,
37) {
38 // Use the encoder's persistent compute encoder via encode_with_args_dispatch.
39 // This delegates to CommandEncoder's own dispatch methods that reuse the
40 // same compute encoder across calls.
41 encoder.encode_with_args(pipeline, bindings, grid_size, threadgroup_size);
42}
43
44/// Encode a compute pass with threadgroups and mixed buffer/bytes bindings.
45pub fn encode_threadgroups_with_args(
46 encoder: &mut CommandEncoder,
47 pipeline: &ComputePipelineStateRef,
48 bindings: &[(u64, KernelArg<'_>)],
49 threadgroups: MTLSize,
50 threadgroup_size: MTLSize,
51) {
52 encoder.encode_threadgroups_with_args(pipeline, bindings, threadgroups, threadgroup_size);
53}
54
55/// Encode a compute pass with threadgroups, mixed buffer/bytes bindings, and
56/// threadgroup shared memory allocations.
57///
58/// Combines the capabilities of [`encode_threadgroups_with_args`] (inline bytes
59/// via `set_bytes`) and the encoder's `encode_threadgroups_with_shared` (shared
60/// memory allocation). Required by fused kernels that need both constant-struct
61/// parameters and a threadgroup scratch buffer for reduction.
62///
63/// # Arguments
64///
65/// * `encoder` — The command encoder to record into.
66/// * `pipeline` — The compiled compute pipeline.
67/// * `bindings` — Slice of `(index, KernelArg)` pairs.
68/// * `threadgroup_mem` — Slice of `(index, byte_length)` pairs for threadgroup memory.
69/// * `threadgroups` — Number of threadgroups to dispatch.
70/// * `threadgroup_size` — Threads per threadgroup.
71pub fn encode_threadgroups_with_args_and_shared(
72 encoder: &mut CommandEncoder,
73 pipeline: &ComputePipelineStateRef,
74 bindings: &[(u64, KernelArg<'_>)],
75 threadgroup_mem: &[(u64, u64)],
76 threadgroups: MTLSize,
77 threadgroup_size: MTLSize,
78) {
79 encoder.encode_threadgroups_with_args_and_shared(
80 pipeline,
81 bindings,
82 threadgroup_mem,
83 threadgroups,
84 threadgroup_size,
85 );
86}
87
88// as_bytes is re-exported from crate::encoder above.