1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
//! Numerically stable softmax GPU dispatch.
//!
//! Computes softmax along the last dimension of a 2D tensor using the
//! subtract-max trick for numerical stability. All accumulations use f32
//! even when inputs are f16 to prevent overflow.
use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::{CapturedOpKind, CommandEncoder};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// MSL source for the softmax kernels (embedded at compile time).
///
/// The text is pulled in with `include_str!`, so the path is resolved
/// relative to this source file when the crate is built. [`register`]
/// associates this one source string with each softmax kernel name.
pub static SOFTMAX_SHADER_SOURCE: &str = include_str!("../shaders/softmax.metal");
/// Make the softmax kernels known to `registry`.
///
/// Every supported dtype variant (`softmax_f32`, `softmax_f16`,
/// `softmax_bf16`) is registered against the same embedded shader source,
/// [`SOFTMAX_SHADER_SOURCE`].
pub fn register(registry: &mut KernelRegistry) {
    for kernel_name in ["softmax_f32", "softmax_f16", "softmax_bf16"] {
        registry.register_source(kernel_name, SOFTMAX_SHADER_SOURCE);
    }
}
/// Dispatch a softmax operation on the GPU.
///
/// Validates the arguments, looks up the pipeline matching the input dtype,
/// and records a dispatch of one threadgroup per row into `encoder`.
///
/// # Arguments
///
/// * `encoder` - Command encoder to record the dispatch into.
/// * `registry` - Kernel registry (must have softmax sources registered).
/// * `device` - Metal device for pipeline compilation.
/// * `input` - Input buffer of shape `[rows, cols]` (f32, f16, or bf16).
/// * `output` - Output buffer (same dtype and shape as input).
/// * `params_buf` - Params buffer containing `[cols, 0]` as f32.
/// * `rows` - Number of rows.
/// * `cols` - Number of columns (softmax dimension).
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` if:
/// - `rows` or `cols` is zero.
/// - `rows * cols` does not fit in `usize`.
/// - Input or output element count does not match rows * cols.
/// - Input and output dtypes differ.
/// - Input dtype is not f32, f16, or bf16.
///
/// Any error returned by `registry.get_pipeline` is propagated unchanged.
pub fn dispatch_softmax(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    rows: u32,
    cols: u32,
) -> Result<()> {
    // Upper bound on threads per threadgroup. Must itself be a power of two
    // so that clamping before rounding still yields a power of two.
    const MAX_THREADGROUP_SIZE: u32 = 256;

    if rows == 0 || cols == 0 {
        return Err(MlxError::InvalidArgument(
            "Softmax rows and cols must be > 0".into(),
        ));
    }

    // Checked multiply: on targets where `usize` is 32 bits the product of
    // two u32 values can wrap, which would let a mis-sized buffer pass the
    // element-count checks below.
    let expected = (rows as usize)
        .checked_mul(cols as usize)
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "Softmax rows({}) * cols({}) overflows usize",
                rows, cols
            ))
        })?;

    if input.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "Softmax input element count {} != rows({}) * cols({})",
            input.element_count(),
            rows,
            cols
        )));
    }
    if output.element_count() != expected {
        return Err(MlxError::InvalidArgument(format!(
            "Softmax output element count {} != rows({}) * cols({})",
            output.element_count(),
            rows,
            cols
        )));
    }

    // hf2q ADR-030 iter-113 — defense-in-depth: kernel selected by
    // input.dtype() writes output at the same stride. Mismatched output
    // dtype would mis-stride writes. See ADR-030 iter-110/111/112.
    if input.dtype() != output.dtype() {
        return Err(MlxError::InvalidArgument(format!(
            "Softmax dtype mismatch: input={} != output={}",
            input.dtype(), output.dtype(),
        )));
    }

    let kernel_name = match input.dtype() {
        DType::F32 => "softmax_f32",
        DType::F16 => "softmax_f16",
        DType::BF16 => "softmax_bf16",
        _ => {
            return Err(MlxError::InvalidArgument(format!(
                "Softmax unsupported dtype: {}",
                input.dtype()
            )));
        }
    };

    let pipeline = registry.get_pipeline(kernel_name, device)?;

    // One threadgroup per row. Threadgroup size must be a power of 2
    // for the tree reduction to work correctly.
    //
    // Clamp *before* rounding up: `u32::next_power_of_two` overflows for
    // values above 2^31 (panic in debug builds, 0 in release builds), which
    // would otherwise produce a zero-sized threadgroup. Because the cap is a
    // power of two, this gives the same result as rounding first for every
    // `cols` that does not overflow.
    let tg_size = u64::from(cols.min(MAX_THREADGROUP_SIZE).next_power_of_two());

    // Threadgroup shared memory: tg_size floats for the reduction.
    let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4

    // Tag for the reorder pass (Phase 4e.3): Softmax is NOT reorderable.
    encoder.set_op_kind(CapturedOpKind::Softmax);

    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, input), (1, output), (2, params_buf)],
        &[(0, shared_mem_bytes)],
        MTLSize::new(rows as u64, 1, 1),
        MTLSize::new(tg_size, 1, 1),
    );

    Ok(())
}