Skip to main content

mlx_native/ops/
cumsum.rs

//! Cumulative sum (inclusive prefix sum) along the last axis.
//!
//! Computes `out[r, i] = sum(x[r, 0..=i])` for every row `r`.
//!
//! Used by the chunked Gated DeltaNet path to produce the decay-mask base
//! (ADR-013 Decision 4). Spec derived from the definition of an inclusive
//! prefix scan.
//!
//! # Algorithm
//!
//! Hillis-Steele scan in threadgroup shared memory. Each thread owns a
//! contiguous chunk of the row; it scans locally first, exchanges chunk
//! totals via a shared-memory scan, then adds the exclusive prefix of
//! preceding threads to its chunk.
//!
//! # Threadgroup shape
//!
//! One threadgroup per row. `tg_size = min(256, next_power_of_two(dim))`
//! and each thread handles `ceil_div(dim, tg_size)` elements. The shader
//! caps the per-thread chunk at `CUMSUM_MAX_CHUNK = 32`, so for `tg_size =
//! 256` the maximum `dim` handled in a single dispatch is 8192.
//!
//! Reduction and prefix arithmetic are performed in f32 regardless of input
//! dtype for numerical stability.
25use metal::MTLSize;
26
27use crate::buffer::MlxBuffer;
28use crate::dtypes::DType;
29use crate::encoder::CommandEncoder;
30use crate::error::{MlxError, Result};
31use crate::kernel_registry::KernelRegistry;
32
/// Metal source text for the cumsum kernels, embedded at compile time from
/// `../shaders/cumsum.metal`. The same source is registered under every
/// cumsum kernel name by [`register`].
pub static CUMSUM_SHADER_SOURCE: &str = include_str!("../shaders/cumsum.metal");
34
35/// Register cumsum shader sources with the given kernel registry.
36pub fn register(registry: &mut KernelRegistry) {
37    registry.register_source("cumsum_f32", CUMSUM_SHADER_SOURCE);
38    registry.register_source("cumsum_bf16", CUMSUM_SHADER_SOURCE);
39}
40
/// Maximum per-thread chunk; keep in sync with `CUMSUM_MAX_CHUNK` in
/// `cumsum.metal`. Threads must not exceed this chunk size or they'll
/// overrun their private buffer.
///
/// `dispatch_cumsum` enforces this on the host side: it rejects any `dim`
/// for which `ceil_div(dim, tg_size)` would be larger than this value.
const SHADER_MAX_CHUNK: u32 = 32;
45
46/// Dispatch an inclusive prefix sum along the last axis of `[rows, dim]`.
47///
48/// # Arguments
49///
50/// * `input`      - shape `[rows, dim]`, f32 or bf16.
51/// * `output`     - same shape + dtype as `input`.
52/// * `params_buf` - one u32: `[dim]`. The kernel also reads `tg_size` from
53///                  the Metal dispatch so only `dim` needs to be in-buffer.
54/// * `rows`       - number of independent rows.
55/// * `dim`        - length of each row; must satisfy
56///                  `ceil_div(dim, tg_size) <= 32`.
57///
58/// # Errors
59///
60/// Returns `MlxError::InvalidArgument` if:
61/// - `rows == 0` or `dim == 0`.
62/// - input/output element counts disagree with `rows * dim`.
63/// - input and output dtypes differ, or are not f32/bf16.
64/// - `dim` exceeds `tg_size * SHADER_MAX_CHUNK` (8192 for the default
65///   `tg_size = 256`).
66pub fn dispatch_cumsum(
67    encoder: &mut CommandEncoder,
68    registry: &mut KernelRegistry,
69    device: &metal::DeviceRef,
70    input: &MlxBuffer,
71    output: &MlxBuffer,
72    params_buf: &MlxBuffer,
73    rows: u32,
74    dim: u32,
75) -> Result<()> {
76    if rows == 0 || dim == 0 {
77        return Err(MlxError::InvalidArgument(
78            "cumsum rows and dim must be > 0".into(),
79        ));
80    }
81
82    let expected = (rows as usize) * (dim as usize);
83    if input.element_count() != expected {
84        return Err(MlxError::InvalidArgument(format!(
85            "cumsum input element count {} != rows({}) * dim({})",
86            input.element_count(),
87            rows,
88            dim
89        )));
90    }
91    if output.element_count() != expected {
92        return Err(MlxError::InvalidArgument(format!(
93            "cumsum output element count {} != rows({}) * dim({})",
94            output.element_count(),
95            rows,
96            dim
97        )));
98    }
99    if input.dtype() != output.dtype() {
100        return Err(MlxError::InvalidArgument(format!(
101            "cumsum input/output dtype mismatch: {} vs {}",
102            input.dtype(),
103            output.dtype()
104        )));
105    }
106
107    let kernel_name = match input.dtype() {
108        DType::F32 => "cumsum_f32",
109        DType::BF16 => "cumsum_bf16",
110        _ => {
111            return Err(MlxError::InvalidArgument(format!(
112                "cumsum unsupported dtype: {}",
113                input.dtype()
114            )));
115        }
116    };
117
118    // Choose threadgroup size: smallest power of two >= dim, capped at 256.
119    // Must be a power of two for the Hillis-Steele offset loop. Must be >= 1
120    // (true because we rejected dim == 0 above).
121    let tg_size = std::cmp::min(256u32, dim.next_power_of_two());
122    let tg_size = std::cmp::max(tg_size, 1u32);
123
124    let chunk = dim.div_ceil(tg_size);
125    if chunk > SHADER_MAX_CHUNK {
126        return Err(MlxError::InvalidArgument(format!(
127            "cumsum dim {} exceeds supported limit: tg_size {} * chunk {} < dim",
128            dim, tg_size, SHADER_MAX_CHUNK
129        )));
130    }
131
132    let pipeline = registry.get_pipeline(kernel_name, device)?;
133
134    // Shared memory: tg_size floats for the cross-thread scan.
135    let shared_mem_bytes = (tg_size as u64) * 4;
136
137    encoder.encode_threadgroups_with_shared(
138        pipeline,
139        &[(0, input), (1, output), (2, params_buf)],
140        &[(0, shared_mem_bytes)],
141        MTLSize::new(rows as u64, 1, 1),
142        MTLSize::new(tg_size as u64, 1, 1),
143    );
144
145    Ok(())
146}