mlx_native/ops/
row_sum.rs1use metal::MTLSize;
9
10use crate::buffer::MlxBuffer;
11use crate::dtypes::DType;
12use crate::encoder::CommandEncoder;
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
16pub static ROW_SUM_SHADER_SOURCE: &str = include_str!("../shaders/row_sum.metal");
17
18pub fn register(registry: &mut KernelRegistry) {
19 registry.register_source("row_sum_f32", ROW_SUM_SHADER_SOURCE);
20 registry.register_source("row_sum_backward_f32", ROW_SUM_SHADER_SOURCE);
21}
22
23#[allow(clippy::too_many_arguments)]
25pub fn dispatch_row_sum_f32(
26 encoder: &mut CommandEncoder,
27 registry: &mut KernelRegistry,
28 device: &metal::DeviceRef,
29 input: &MlxBuffer,
30 output: &MlxBuffer,
31 params_buf: &MlxBuffer,
32 rows: u32,
33 cols: u32,
34) -> Result<()> {
35 if rows == 0 || cols == 0 {
36 return Err(MlxError::InvalidArgument(
37 "row_sum_f32: rows and cols must be > 0".into(),
38 ));
39 }
40 let in_expected = (rows as usize) * (cols as usize);
41 if input.element_count() != in_expected {
42 return Err(MlxError::InvalidArgument(format!(
43 "row_sum_f32: input element count {} != rows({}) * cols({})",
44 input.element_count(),
45 rows,
46 cols
47 )));
48 }
49 if output.element_count() != rows as usize {
50 return Err(MlxError::InvalidArgument(format!(
51 "row_sum_f32: output element count {} != rows({})",
52 output.element_count(),
53 rows
54 )));
55 }
56 if input.dtype() != DType::F32 || output.dtype() != DType::F32 {
57 return Err(MlxError::InvalidArgument(format!(
58 "row_sum_f32: only f32 supported; got input={} output={}",
59 input.dtype(),
60 output.dtype()
61 )));
62 }
63 if params_buf.byte_len() < 8 {
64 return Err(MlxError::InvalidArgument(format!(
65 "row_sum_f32: params_buf too small (need 8 bytes, got {})",
66 params_buf.byte_len()
67 )));
68 }
69
70 let pipeline = registry.get_pipeline("row_sum_f32", device)?;
71 let tg_size = std::cmp::min(256, cols.next_power_of_two()) as u64;
72 let shared_mem_bytes = tg_size * 4;
73
74 encoder.encode_threadgroups_with_shared(
75 pipeline,
76 &[(0, input), (1, output), (2, params_buf)],
77 &[(0, shared_mem_bytes)],
78 MTLSize::new(rows as u64, 1, 1),
79 MTLSize::new(tg_size, 1, 1),
80 );
81
82 Ok(())
83}
84
85#[allow(clippy::too_many_arguments)]
88pub fn dispatch_row_sum_backward_f32(
89 encoder: &mut CommandEncoder,
90 registry: &mut KernelRegistry,
91 device: &metal::DeviceRef,
92 d_out: &MlxBuffer,
93 dx: &MlxBuffer,
94 params_buf: &MlxBuffer,
95 rows: u32,
96 cols: u32,
97) -> Result<()> {
98 if rows == 0 || cols == 0 {
99 return Err(MlxError::InvalidArgument(
100 "row_sum_backward_f32: rows and cols must be > 0".into(),
101 ));
102 }
103 if d_out.element_count() != rows as usize {
104 return Err(MlxError::InvalidArgument(format!(
105 "row_sum_backward_f32: d_out element count {} != rows({})",
106 d_out.element_count(),
107 rows
108 )));
109 }
110 let dx_expected = (rows as usize) * (cols as usize);
111 if dx.element_count() != dx_expected {
112 return Err(MlxError::InvalidArgument(format!(
113 "row_sum_backward_f32: dx element count {} != rows({}) * cols({})",
114 dx.element_count(),
115 rows,
116 cols
117 )));
118 }
119 if d_out.dtype() != DType::F32 || dx.dtype() != DType::F32 {
120 return Err(MlxError::InvalidArgument(format!(
121 "row_sum_backward_f32: only f32; d_out={} dx={}",
122 d_out.dtype(),
123 dx.dtype()
124 )));
125 }
126 if params_buf.byte_len() < 8 {
127 return Err(MlxError::InvalidArgument(format!(
128 "row_sum_backward_f32: params_buf too small (need 8 bytes, got {})",
129 params_buf.byte_len()
130 )));
131 }
132
133 let pipeline = registry.get_pipeline("row_sum_backward_f32", device)?;
134 let tg_size = std::cmp::min(256, cols.next_power_of_two()) as u64;
135
136 encoder.encode_threadgroups(
137 pipeline,
138 &[(0, d_out), (1, dx), (2, params_buf)],
139 MTLSize::new(rows as u64, 1, 1),
140 MTLSize::new(tg_size, 1, 1),
141 );
142
143 Ok(())
144}