Skip to main content

mlx_native/ops/
softmax_backward.rs

//! Backward pass for row-wise softmax.
//!
//! Given `y = softmax(x)` along the last dimension and an upstream gradient
//! `dy` of the same shape, computes
//!
//!   `dx[b, i] = y[b, i] · (dy[b, i] − Σ_j y[b, j] · dy[b, j])`
//!
//! Companion to [`crate::ops::softmax::dispatch_softmax`].  Used by
//! reverse-mode autograd in hf2q's calibrate module (ADR-020 Track 1).
//!
//! Threadgroup-per-row layout matches softmax forward: one threadgroup
//! processes one full row, doing a tree reduction over the columns.
13
14use crate::buffer::MlxBuffer;
15use crate::dtypes::DType;
16use crate::encoder::CommandEncoder;
17use crate::error::{MlxError, Result};
18use crate::kernel_registry::KernelRegistry;
19
20use metal::MTLSize;
21
22pub fn register(registry: &mut KernelRegistry) {
23    registry.register_source(
24        "softmax_backward_f32",
25        include_str!("../shaders/softmax_backward.metal"),
26    );
27}
28
29/// Encode the softmax backward kernel.
30///
31/// # Arguments
32///
33/// * `encoder`    — Command encoder.
34/// * `registry`   — Kernel registry (must have softmax_backward source registered).
35/// * `device`     — Metal device.
36/// * `y`          — Forward softmax output `[rows, cols]`, f32.
37/// * `dy`         — Upstream gradient `[rows, cols]`, f32.
38/// * `dx`         — Output gradient `[rows, cols]`, f32 (must be pre-allocated).
39/// * `params_buf` — Params buffer containing `[cols, 0]` as f32.
40/// * `rows`       — Row count (one threadgroup per row).
41/// * `cols`       — Column count.
42///
43/// # Errors
44///
45/// Returns `MlxError::InvalidArgument` if shapes are inconsistent or
46/// any buffer is too small.
47#[allow(clippy::too_many_arguments)]
48pub fn dispatch_softmax_backward(
49    encoder: &mut CommandEncoder,
50    registry: &mut KernelRegistry,
51    device: &metal::DeviceRef,
52    y: &MlxBuffer,
53    dy: &MlxBuffer,
54    dx: &MlxBuffer,
55    params_buf: &MlxBuffer,
56    rows: u32,
57    cols: u32,
58) -> Result<()> {
59    if rows == 0 || cols == 0 {
60        return Err(MlxError::InvalidArgument(
61            "softmax_backward: rows and cols must be > 0".into(),
62        ));
63    }
64    let expected = (rows as usize) * (cols as usize);
65    for (label, buf) in [("y", y), ("dy", dy), ("dx", dx)] {
66        if buf.element_count() != expected {
67            return Err(MlxError::InvalidArgument(format!(
68                "softmax_backward: {label} element count {} != rows({}) * cols({})",
69                buf.element_count(),
70                rows,
71                cols
72            )));
73        }
74        if buf.dtype() != DType::F32 {
75            return Err(MlxError::InvalidArgument(format!(
76                "softmax_backward: {label} dtype {} not f32",
77                buf.dtype()
78            )));
79        }
80    }
81    if params_buf.byte_len() < 8 {
82        return Err(MlxError::InvalidArgument(format!(
83            "softmax_backward: params_buf too small (need 8 bytes for float2, got {})",
84            params_buf.byte_len()
85        )));
86    }
87
88    let pipeline = registry.get_pipeline("softmax_backward_f32", device)?;
89
90    // One threadgroup per row.  Threadgroup size must be a power of 2
91    // for the tree reduction (matches softmax forward convention).
92    let tg_size = std::cmp::min(256, cols.next_power_of_two()) as u64;
93    let shared_mem_bytes = tg_size * 4; // sizeof(float) = 4
94
95    encoder.encode_threadgroups_with_shared(
96        pipeline,
97        &[(0, y), (1, dy), (2, dx), (3, params_buf)],
98        &[(0, shared_mem_bytes)],
99        MTLSize::new(rows as u64, 1, 1),
100        MTLSize::new(tg_size, 1, 1),
101    );
102
103    Ok(())
104}