Skip to main content

mlx_native/ops/
row_sum.rs

//! Per-row sum reduction along the last dimension of a 2-D tensor, and
//! its broadcast-along-cols backward.
//!
//! Used by reverse-mode autograd in downstream crates (hf2q ADR-020
//! Track 1: KL-divergence loss composition needs Σ_j p · (log_p − log_q)
//! per row).
7
8use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::dtypes::DType;
12use crate::encoder::CommandEncoder;
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
16pub static ROW_SUM_SHADER_SOURCE: &str = include_str!("../shaders/row_sum.metal");
17
18pub fn register(registry: &mut KernelRegistry) {
19    registry.register_source("row_sum_f32", ROW_SUM_SHADER_SOURCE);
20    registry.register_source("row_sum_backward_f32", ROW_SUM_SHADER_SOURCE);
21}
22
23/// Encode `output[b] = Σ_j input[b, j]` for a 2-D `[rows, cols]` f32 input.
24#[allow(clippy::too_many_arguments)]
25pub fn dispatch_row_sum_f32(
26    encoder: &mut CommandEncoder,
27    registry: &mut KernelRegistry,
28    device: &metal::DeviceRef,
29    input: &MlxBuffer,
30    output: &MlxBuffer,
31    params_buf: &MlxBuffer,
32    rows: u32,
33    cols: u32,
34) -> Result<()> {
35    if rows == 0 || cols == 0 {
36        return Err(MlxError::InvalidArgument(
37            "row_sum_f32: rows and cols must be > 0".into(),
38        ));
39    }
40    let in_expected = (rows as usize) * (cols as usize);
41    if input.element_count() != in_expected {
42        return Err(MlxError::InvalidArgument(format!(
43            "row_sum_f32: input element count {} != rows({}) * cols({})",
44            input.element_count(),
45            rows,
46            cols
47        )));
48    }
49    if output.element_count() != rows as usize {
50        return Err(MlxError::InvalidArgument(format!(
51            "row_sum_f32: output element count {} != rows({})",
52            output.element_count(),
53            rows
54        )));
55    }
56    if input.dtype() != DType::F32 || output.dtype() != DType::F32 {
57        return Err(MlxError::InvalidArgument(format!(
58            "row_sum_f32: only f32 supported; got input={} output={}",
59            input.dtype(),
60            output.dtype()
61        )));
62    }
63    if params_buf.byte_len() < 8 {
64        return Err(MlxError::InvalidArgument(format!(
65            "row_sum_f32: params_buf too small (need 8 bytes, got {})",
66            params_buf.byte_len()
67        )));
68    }
69
70    let pipeline = registry.get_pipeline("row_sum_f32", device)?;
71    let tg_size = std::cmp::min(256, cols.next_power_of_two()) as u64;
72    let shared_mem_bytes = tg_size * 4;
73
74    encoder.encode_threadgroups_with_shared(
75        pipeline,
76        &[(0, input), (1, output), (2, params_buf)],
77        &[(0, shared_mem_bytes)],
78        MTLSize::new(rows as u64, 1, 1),
79        MTLSize::new(tg_size, 1, 1),
80    );
81
82    Ok(())
83}
84
85/// Encode `dx[b, i] = d_out[b]` (broadcast along the cols dim).  This
86/// is the backward of [`dispatch_row_sum_f32`].
87#[allow(clippy::too_many_arguments)]
88pub fn dispatch_row_sum_backward_f32(
89    encoder: &mut CommandEncoder,
90    registry: &mut KernelRegistry,
91    device: &metal::DeviceRef,
92    d_out: &MlxBuffer,
93    dx: &MlxBuffer,
94    params_buf: &MlxBuffer,
95    rows: u32,
96    cols: u32,
97) -> Result<()> {
98    if rows == 0 || cols == 0 {
99        return Err(MlxError::InvalidArgument(
100            "row_sum_backward_f32: rows and cols must be > 0".into(),
101        ));
102    }
103    if d_out.element_count() != rows as usize {
104        return Err(MlxError::InvalidArgument(format!(
105            "row_sum_backward_f32: d_out element count {} != rows({})",
106            d_out.element_count(),
107            rows
108        )));
109    }
110    let dx_expected = (rows as usize) * (cols as usize);
111    if dx.element_count() != dx_expected {
112        return Err(MlxError::InvalidArgument(format!(
113            "row_sum_backward_f32: dx element count {} != rows({}) * cols({})",
114            dx.element_count(),
115            rows,
116            cols
117        )));
118    }
119    if d_out.dtype() != DType::F32 || dx.dtype() != DType::F32 {
120        return Err(MlxError::InvalidArgument(format!(
121            "row_sum_backward_f32: only f32; d_out={} dx={}",
122            d_out.dtype(),
123            dx.dtype()
124        )));
125    }
126    if params_buf.byte_len() < 8 {
127        return Err(MlxError::InvalidArgument(format!(
128            "row_sum_backward_f32: params_buf too small (need 8 bytes, got {})",
129            params_buf.byte_len()
130        )));
131    }
132
133    let pipeline = registry.get_pipeline("row_sum_backward_f32", device)?;
134    let tg_size = std::cmp::min(256, cols.next_power_of_two()) as u64;
135
136    encoder.encode_threadgroups(
137        pipeline,
138        &[(0, d_out), (1, dx), (2, params_buf)],
139        MTLSize::new(rows as u64, 1, 1),
140        MTLSize::new(tg_size, 1, 1),
141    );
142
143    Ok(())
144}