mlx_native/ops/cumsum.rs
1//! Cumulative sum (inclusive prefix sum) along the last axis.
2//!
3//! Computes `out[r, i] = sum(x[r, 0..=i])` for every row `r`.
4//!
5//! Used by the chunked Gated DeltaNet path to produce the decay-mask base
6//! (ADR-013 Decision 4). Spec derived from the definition of an inclusive
7//! prefix scan.
8//!
9//! # Algorithm
10//!
11//! Hillis-Steele scan in threadgroup shared memory. Each thread owns a
12//! contiguous chunk of the row; it scans locally first, exchanges chunk
13//! totals via a shared-memory scan, then adds the exclusive prefix of
14//! preceding threads to its chunk.
15//!
16//! # Threadgroup shape
17//!
18//! One threadgroup per row. `tg_size = min(256, next_power_of_two(dim))`
19//! and each thread handles `ceil_div(dim, tg_size)` elements. The shader
20//! caps the per-thread chunk at `CUMSUM_MAX_CHUNK = 32`, so for `tg_size =
21//! 256` the maximum `dim` handled in a single dispatch is 8192.
22//!
23//! Reduction and prefix arithmetic are performed in f32 regardless of input
24//! dtype for numerical stability.
25use metal::MTLSize;
26
27use crate::buffer::MlxBuffer;
28use crate::dtypes::DType;
29use crate::encoder::CommandEncoder;
30use crate::error::{MlxError, Result};
31use crate::kernel_registry::KernelRegistry;
32
/// Metal shader source for the cumsum kernels, embedded at compile time from
/// `../shaders/cumsum.metal` (path relative to this file).
pub static CUMSUM_SHADER_SOURCE: &str = include_str!("../shaders/cumsum.metal");
34
35/// Register cumsum shader sources with the given kernel registry.
36pub fn register(registry: &mut KernelRegistry) {
37 registry.register_source("cumsum_f32", CUMSUM_SHADER_SOURCE);
38 registry.register_source("cumsum_bf16", CUMSUM_SHADER_SOURCE);
39}
40
/// Maximum number of row elements a single thread may own.
///
/// Keep in sync with `CUMSUM_MAX_CHUNK` in `cumsum.metal`. Threads must not
/// exceed this chunk size or they'll overrun their private buffer, so
/// `dispatch_cumsum` rejects any `dim` whose per-thread chunk
/// (`ceil_div(dim, tg_size)`) would be larger than this value.
const SHADER_MAX_CHUNK: u32 = 32;
45
46/// Dispatch an inclusive prefix sum along the last axis of `[rows, dim]`.
47///
48/// # Arguments
49///
50/// * `input` - shape `[rows, dim]`, f32 or bf16.
51/// * `output` - same shape + dtype as `input`.
52/// * `params_buf` - one u32: `[dim]`. The kernel also reads `tg_size` from
53/// the Metal dispatch so only `dim` needs to be in-buffer.
54/// * `rows` - number of independent rows.
55/// * `dim` - length of each row; must satisfy
56/// `ceil_div(dim, tg_size) <= 32`.
57///
58/// # Errors
59///
60/// Returns `MlxError::InvalidArgument` if:
61/// - `rows == 0` or `dim == 0`.
62/// - input/output element counts disagree with `rows * dim`.
63/// - input and output dtypes differ, or are not f32/bf16.
64/// - `dim` exceeds `tg_size * SHADER_MAX_CHUNK` (8192 for the default
65/// `tg_size = 256`).
66pub fn dispatch_cumsum(
67 encoder: &mut CommandEncoder,
68 registry: &mut KernelRegistry,
69 device: &metal::DeviceRef,
70 input: &MlxBuffer,
71 output: &MlxBuffer,
72 params_buf: &MlxBuffer,
73 rows: u32,
74 dim: u32,
75) -> Result<()> {
76 if rows == 0 || dim == 0 {
77 return Err(MlxError::InvalidArgument(
78 "cumsum rows and dim must be > 0".into(),
79 ));
80 }
81
82 let expected = (rows as usize) * (dim as usize);
83 if input.element_count() != expected {
84 return Err(MlxError::InvalidArgument(format!(
85 "cumsum input element count {} != rows({}) * dim({})",
86 input.element_count(),
87 rows,
88 dim
89 )));
90 }
91 if output.element_count() != expected {
92 return Err(MlxError::InvalidArgument(format!(
93 "cumsum output element count {} != rows({}) * dim({})",
94 output.element_count(),
95 rows,
96 dim
97 )));
98 }
99 if input.dtype() != output.dtype() {
100 return Err(MlxError::InvalidArgument(format!(
101 "cumsum input/output dtype mismatch: {} vs {}",
102 input.dtype(),
103 output.dtype()
104 )));
105 }
106
107 let kernel_name = match input.dtype() {
108 DType::F32 => "cumsum_f32",
109 DType::BF16 => "cumsum_bf16",
110 _ => {
111 return Err(MlxError::InvalidArgument(format!(
112 "cumsum unsupported dtype: {}",
113 input.dtype()
114 )));
115 }
116 };
117
118 // Choose threadgroup size: smallest power of two >= dim, capped at 256.
119 // Must be a power of two for the Hillis-Steele offset loop. Must be >= 1
120 // (true because we rejected dim == 0 above).
121 let tg_size = std::cmp::min(256u32, dim.next_power_of_two());
122 let tg_size = std::cmp::max(tg_size, 1u32);
123
124 let chunk = dim.div_ceil(tg_size);
125 if chunk > SHADER_MAX_CHUNK {
126 return Err(MlxError::InvalidArgument(format!(
127 "cumsum dim {} exceeds supported limit: tg_size {} * chunk {} < dim",
128 dim, tg_size, SHADER_MAX_CHUNK
129 )));
130 }
131
132 let pipeline = registry.get_pipeline(kernel_name, device)?;
133
134 // Shared memory: tg_size floats for the cross-thread scan.
135 let shared_mem_bytes = (tg_size as u64) * 4;
136
137 encoder.encode_threadgroups_with_shared(
138 pipeline,
139 &[(0, input), (1, output), (2, params_buf)],
140 &[(0, shared_mem_bytes)],
141 MTLSize::new(rows as u64, 1, 1),
142 MTLSize::new(tg_size as u64, 1, 1),
143 );
144
145 Ok(())
146}