// mlx_native/ops/softmax.rs
1//! Numerically stable softmax GPU dispatch.
2//!
3//! Computes softmax along the last dimension of a 2D tensor using the
4//! subtract-max trick for numerical stability.  All accumulations use f32
5//! even when inputs are f16 to prevent overflow.
6
7use metal::MTLSize;
8
9use crate::buffer::MlxBuffer;
10use crate::dtypes::DType;
11use crate::encoder::{CapturedOpKind, CommandEncoder};
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
/// MSL source for the softmax kernels (embedded at compile time).
///
/// A single Metal translation unit; [`register`] registers this same source
/// under the per-dtype kernel names (`softmax_f32` / `softmax_f16` /
/// `softmax_bf16`), so all dtype variants are expected to live in this file.
pub static SOFTMAX_SHADER_SOURCE: &str = include_str!("../shaders/softmax.metal");
17
18/// Register softmax shader sources with the given kernel registry.
19pub fn register(registry: &mut KernelRegistry) {
20    registry.register_source("softmax_f32", SOFTMAX_SHADER_SOURCE);
21    registry.register_source("softmax_f16", SOFTMAX_SHADER_SOURCE);
22    registry.register_source("softmax_bf16", SOFTMAX_SHADER_SOURCE);
23}
24
25/// Dispatch a softmax operation on the GPU.
26///
27/// # Arguments
28///
29/// * `encoder`    - Command encoder to record the dispatch into.
30/// * `registry`   - Kernel registry (must have softmax sources registered).
31/// * `device`     - Metal device for pipeline compilation.
32/// * `input`      - Input buffer of shape `[rows, cols]` (f32, f16, or bf16).
33/// * `output`     - Output buffer (same dtype and shape as input).
34/// * `params_buf` - Params buffer containing `[cols, 0]` as f32.
35/// * `rows`       - Number of rows.
36/// * `cols`       - Number of columns (softmax dimension).
37///
38/// # Errors
39///
40/// Returns `MlxError::InvalidArgument` if:
41/// - Input dtype is not f32, f16, or bf16.
42/// - Input element count does not match rows * cols.
43pub fn dispatch_softmax(
44    encoder: &mut CommandEncoder,
45    registry: &mut KernelRegistry,
46    device: &metal::DeviceRef,
47    input: &MlxBuffer,
48    output: &MlxBuffer,
49    params_buf: &MlxBuffer,
50    rows: u32,
51    cols: u32,
52) -> Result<()> {
53    if rows == 0 || cols == 0 {
54        return Err(MlxError::InvalidArgument(
55            "Softmax rows and cols must be > 0".into(),
56        ));
57    }
58
59    let expected = (rows as usize) * (cols as usize);
60    if input.element_count() != expected {
61        return Err(MlxError::InvalidArgument(format!(
62            "Softmax input element count {} != rows({}) * cols({})",
63            input.element_count(),
64            rows,
65            cols
66        )));
67    }
68    if output.element_count() != expected {
69        return Err(MlxError::InvalidArgument(format!(
70            "Softmax output element count {} != rows({}) * cols({})",
71            output.element_count(),
72            rows,
73            cols
74        )));
75    }
76
77    let kernel_name = match input.dtype() {
78        DType::F32 => "softmax_f32",
79        DType::F16 => "softmax_f16",
80        DType::BF16 => "softmax_bf16",
81        _ => {
82            return Err(MlxError::InvalidArgument(format!(
83                "Softmax unsupported dtype: {}",
84                input.dtype()
85            )));
86        }
87    };
88
89    let pipeline = registry.get_pipeline(kernel_name, device)?;
90
91    // One threadgroup per row.  Threadgroup size must be a power of 2
92    // for the tree reduction to work correctly.
93    let tg_size = std::cmp::min(256, cols.next_power_of_two()) as u64;
94
95    // Threadgroup shared memory: tg_size floats for the reduction.
96    let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
97
98    // Tag for the reorder pass (Phase 4e.3): Softmax is NOT reorderable.
99    encoder.set_op_kind(CapturedOpKind::Softmax);
100
101    encoder.encode_threadgroups_with_shared(
102        pipeline,
103        &[(0, input), (1, output), (2, params_buf)],
104        &[(0, shared_mem_bytes)],
105        MTLSize::new(rows as u64, 1, 1),
106        MTLSize::new(tg_size, 1, 1),
107    );
108
109    Ok(())
110}