1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
//! Cumulative sum (inclusive prefix sum) along the last axis.
//!
//! Computes `out[r, i] = sum(x[r, 0..=i])` for every row `r`.
//!
//! Used by the chunked Gated DeltaNet path to produce the decay-mask base
//! (ADR-013 Decision 4). Spec derived from the definition of an inclusive
//! prefix scan.
//!
//! # Algorithm
//!
//! Hillis-Steele scan in threadgroup shared memory. Each thread owns a
//! contiguous chunk of the row; it scans locally first, exchanges chunk
//! totals via a shared-memory scan, then adds the exclusive prefix of
//! preceding threads to its chunk.
//!
//! # Threadgroup shape
//!
//! One threadgroup per row. `tg_size = min(256, next_power_of_two(dim))`
//! and each thread handles `ceil_div(dim, tg_size)` elements. The shader
//! caps the per-thread chunk at `CUMSUM_MAX_CHUNK = 32`, so for `tg_size =
//! 256` the maximum `dim` handled in a single dispatch is 8192.
//!
//! Reduction and prefix arithmetic are performed in f32 regardless of input
//! dtype for numerical stability.
use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal source for the cumsum kernels, embedded into the binary at compile
/// time via `include_str!`.
pub static CUMSUM_SHADER_SOURCE: &str = include_str!("../shaders/cumsum.metal");

/// Register the cumsum kernels with `registry`.
///
/// Both the f32 entry point (`cumsum_f32`) and the bf16 entry point
/// (`cumsum_bf16`) are registered against the same
/// [`CUMSUM_SHADER_SOURCE`] text, in that order.
pub fn register(registry: &mut KernelRegistry) {
    const KERNEL_NAMES: [&str; 2] = ["cumsum_f32", "cumsum_bf16"];
    for name in KERNEL_NAMES {
        registry.register_source(name, CUMSUM_SHADER_SOURCE);
    }
}
/// Maximum per-thread chunk; keep in sync with `CUMSUM_MAX_CHUNK` in
/// `cumsum.metal`. Threads must not exceed this chunk size or they'll
/// overrun their private buffer.
const SHADER_MAX_CHUNK: u32 = 32;

/// Upper bound on the threadgroup size chosen by [`dispatch_cumsum`].
///
/// Together with [`SHADER_MAX_CHUNK`] this bounds the longest row a single
/// dispatch can scan: `MAX_TG_SIZE * SHADER_MAX_CHUNK = 8192` elements.
const MAX_TG_SIZE: u32 = 256;

/// Dispatch an inclusive prefix sum along the last axis of `[rows, dim]`.
///
/// One threadgroup is launched per row. The threadgroup size is the smallest
/// power of two `>= dim`, capped at 256. Each thread therefore handles
/// `ceil_div(dim, tg_size)` elements.
///
/// # Arguments
///
/// * `input` - shape `[rows, dim]`, f32 or bf16.
/// * `output` - same shape + dtype as `input`.
/// * `params_buf` - one u32: `[dim]`. The kernel also reads `tg_size` from
///   the Metal dispatch so only `dim` needs to be in-buffer. Its contents are
///   not checked here; the caller must make them agree with `dim`.
/// * `rows` - number of independent rows.
/// * `dim` - length of each row; must satisfy
///   `ceil_div(dim, tg_size) <= 32`, i.e. `dim <= 8192`.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if:
/// - `rows == 0` or `dim == 0`.
/// - input/output element counts disagree with `rows * dim`.
/// - input and output dtypes differ, or are not f32/bf16.
/// - `dim` exceeds `MAX_TG_SIZE * SHADER_MAX_CHUNK` (8192).
///
/// Errors from looking up the kernel pipeline in `registry` are propagated.
pub fn dispatch_cumsum(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    dim: u32,
) -> Result<()> {
    if rows == 0 || dim == 0 {
        return Err(MlxError::InvalidArgument(
            "cumsum rows and dim must be > 0".into(),
        ));
    }

    // u32 * u32 always fits in a 64-bit usize.
    let expected = (rows as usize) * (dim as usize);
    if input.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "cumsum input element count {} != rows({}) * dim({})",
            input.element_count(),
            rows,
            dim
        )));
    }
    if output.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "cumsum output element count {} != rows({}) * dim({})",
            output.element_count(),
            rows,
            dim
        )));
    }
    if input.dtype() != output.dtype() {
        return Err(MlxError::InvalidArgument(format!(
            "cumsum input/output dtype mismatch: {} vs {}",
            input.dtype(),
            output.dtype()
        )));
    }

    let kernel_name = match input.dtype() {
        DType::F32 => "cumsum_f32",
        DType::BF16 => "cumsum_bf16",
        _ => {
            return Err(MlxError::InvalidArgument(format!(
                "cumsum unsupported dtype: {}",
                input.dtype()
            )));
        }
    };

    // Threadgroup size: the smallest power of two >= dim, capped at
    // MAX_TG_SIZE. The Hillis-Steele offset loop requires a power of two.
    // Clamp *before* rounding up: `u32::next_power_of_two` overflows for
    // values above 2^31 (panic in debug builds, 0 in release). Because
    // dim >= 1 here, the result is always in 1..=MAX_TG_SIZE.
    let tg_size = dim.min(MAX_TG_SIZE).next_power_of_two();
    let chunk = dim.div_ceil(tg_size);
    if chunk > SHADER_MAX_CHUNK {
        return Err(MlxError::InvalidArgument(format!(
            "cumsum dim {} exceeds supported limit: tg_size {} * chunk {} < dim",
            dim, tg_size, SHADER_MAX_CHUNK
        )));
    }

    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // Shared memory: one f32 per thread for the cross-thread scan of chunk
    // totals.
    let shared_mem_bytes = u64::from(tg_size) * std::mem::size_of::<f32>() as u64;
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, input), (1, output), (2, params_buf)],
        &[(0, shared_mem_bytes)],
        MTLSize::new(u64::from(rows), 1, 1),
        MTLSize::new(u64::from(tg_size), 1, 1),
    );
    Ok(())
}