mlx_native/ops/
softmax_backward.rs1use crate::buffer::MlxBuffer;
15use crate::dtypes::DType;
16use crate::encoder::CommandEncoder;
17use crate::error::{MlxError, Result};
18use crate::kernel_registry::KernelRegistry;
19
20use metal::MTLSize;
21
22pub fn register(registry: &mut KernelRegistry) {
23 registry.register_source(
24 "softmax_backward_f32",
25 include_str!("../shaders/softmax_backward.metal"),
26 );
27}
28
29#[allow(clippy::too_many_arguments)]
48pub fn dispatch_softmax_backward(
49 encoder: &mut CommandEncoder,
50 registry: &mut KernelRegistry,
51 device: &metal::DeviceRef,
52 y: &MlxBuffer,
53 dy: &MlxBuffer,
54 dx: &MlxBuffer,
55 params_buf: &MlxBuffer,
56 rows: u32,
57 cols: u32,
58) -> Result<()> {
59 if rows == 0 || cols == 0 {
60 return Err(MlxError::InvalidArgument(
61 "softmax_backward: rows and cols must be > 0".into(),
62 ));
63 }
64 let expected = (rows as usize) * (cols as usize);
65 for (label, buf) in [("y", y), ("dy", dy), ("dx", dx)] {
66 if buf.element_count() != expected {
67 return Err(MlxError::InvalidArgument(format!(
68 "softmax_backward: {label} element count {} != rows({}) * cols({})",
69 buf.element_count(),
70 rows,
71 cols
72 )));
73 }
74 if buf.dtype() != DType::F32 {
75 return Err(MlxError::InvalidArgument(format!(
76 "softmax_backward: {label} dtype {} not f32",
77 buf.dtype()
78 )));
79 }
80 }
81 if params_buf.byte_len() < 8 {
82 return Err(MlxError::InvalidArgument(format!(
83 "softmax_backward: params_buf too small (need 8 bytes for float2, got {})",
84 params_buf.byte_len()
85 )));
86 }
87
88 let pipeline = registry.get_pipeline("softmax_backward_f32", device)?;
89
90 let tg_size = std::cmp::min(256, cols.next_power_of_two()) as u64;
93 let shared_mem_bytes = tg_size * 4; encoder.encode_threadgroups_with_shared(
96 pipeline,
97 &[(0, y), (1, dy), (2, dx), (3, params_buf)],
98 &[(0, shared_mem_bytes)],
99 MTLSize::new(rows as u64, 1, 1),
100 MTLSize::new(tg_size, 1, 1),
101 );
102
103 Ok(())
104}