mlx_native/ops/
softmax.rs1use metal::MTLSize;
8
9use crate::buffer::MlxBuffer;
10use crate::dtypes::DType;
11use crate::encoder::{CapturedOpKind, CommandEncoder};
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
15pub static SOFTMAX_SHADER_SOURCE: &str = include_str!("../shaders/softmax.metal");
17
18pub fn register(registry: &mut KernelRegistry) {
20 registry.register_source("softmax_f32", SOFTMAX_SHADER_SOURCE);
21 registry.register_source("softmax_f16", SOFTMAX_SHADER_SOURCE);
22 registry.register_source("softmax_bf16", SOFTMAX_SHADER_SOURCE);
23}
24
25pub fn dispatch_softmax(
44 encoder: &mut CommandEncoder,
45 registry: &mut KernelRegistry,
46 device: &metal::DeviceRef,
47 input: &MlxBuffer,
48 output: &MlxBuffer,
49 params_buf: &MlxBuffer,
50 rows: u32,
51 cols: u32,
52) -> Result<()> {
53 if rows == 0 || cols == 0 {
54 return Err(MlxError::InvalidArgument(
55 "Softmax rows and cols must be > 0".into(),
56 ));
57 }
58
59 let expected = (rows as usize) * (cols as usize);
60 if input.element_count() != expected {
61 return Err(MlxError::InvalidArgument(format!(
62 "Softmax input element count {} != rows({}) * cols({})",
63 input.element_count(),
64 rows,
65 cols
66 )));
67 }
68 if output.element_count() != expected {
69 return Err(MlxError::InvalidArgument(format!(
70 "Softmax output element count {} != rows({}) * cols({})",
71 output.element_count(),
72 rows,
73 cols
74 )));
75 }
76 if input.dtype() != output.dtype() {
80 return Err(MlxError::InvalidArgument(format!(
81 "Softmax dtype mismatch: input={} != output={}",
82 input.dtype(), output.dtype(),
83 )));
84 }
85
86 let kernel_name = match input.dtype() {
87 DType::F32 => "softmax_f32",
88 DType::F16 => "softmax_f16",
89 DType::BF16 => "softmax_bf16",
90 _ => {
91 return Err(MlxError::InvalidArgument(format!(
92 "Softmax unsupported dtype: {}",
93 input.dtype()
94 )));
95 }
96 };
97
98 let pipeline = registry.get_pipeline(kernel_name, device)?;
99
100 let tg_size = std::cmp::min(256, cols.next_power_of_two()) as u64;
103
104 let shared_mem_bytes = tg_size * 4; encoder.set_op_kind(CapturedOpKind::Softmax);
109
110 encoder.encode_threadgroups_with_shared(
111 pipeline,
112 &[(0, input), (1, output), (2, params_buf)],
113 &[(0, shared_mem_bytes)],
114 MTLSize::new(rows as u64, 1, 1),
115 MTLSize::new(tg_size, 1, 1),
116 );
117
118 Ok(())
119}