mlx_native/ops/
softmax.rs1use metal::MTLSize;
8
9use crate::buffer::MlxBuffer;
10use crate::dtypes::DType;
11use crate::encoder::{CapturedOpKind, CommandEncoder};
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
/// Metal shader source for the softmax kernels, embedded into the binary at
/// compile time from `../shaders/softmax.metal` (resolved relative to this
/// source file by `include_str!`).
pub static SOFTMAX_SHADER_SOURCE: &str = include_str!("../shaders/softmax.metal");
17
18pub fn register(registry: &mut KernelRegistry) {
20 registry.register_source("softmax_f32", SOFTMAX_SHADER_SOURCE);
21 registry.register_source("softmax_f16", SOFTMAX_SHADER_SOURCE);
22 registry.register_source("softmax_bf16", SOFTMAX_SHADER_SOURCE);
23}
24
25pub fn dispatch_softmax(
44 encoder: &mut CommandEncoder,
45 registry: &mut KernelRegistry,
46 device: &metal::DeviceRef,
47 input: &MlxBuffer,
48 output: &MlxBuffer,
49 params_buf: &MlxBuffer,
50 rows: u32,
51 cols: u32,
52) -> Result<()> {
53 if rows == 0 || cols == 0 {
54 return Err(MlxError::InvalidArgument(
55 "Softmax rows and cols must be > 0".into(),
56 ));
57 }
58
59 let expected = (rows as usize) * (cols as usize);
60 if input.element_count() != expected {
61 return Err(MlxError::InvalidArgument(format!(
62 "Softmax input element count {} != rows({}) * cols({})",
63 input.element_count(),
64 rows,
65 cols
66 )));
67 }
68 if output.element_count() != expected {
69 return Err(MlxError::InvalidArgument(format!(
70 "Softmax output element count {} != rows({}) * cols({})",
71 output.element_count(),
72 rows,
73 cols
74 )));
75 }
76
77 let kernel_name = match input.dtype() {
78 DType::F32 => "softmax_f32",
79 DType::F16 => "softmax_f16",
80 DType::BF16 => "softmax_bf16",
81 _ => {
82 return Err(MlxError::InvalidArgument(format!(
83 "Softmax unsupported dtype: {}",
84 input.dtype()
85 )));
86 }
87 };
88
89 let pipeline = registry.get_pipeline(kernel_name, device)?;
90
91 let tg_size = std::cmp::min(256, cols.next_power_of_two()) as u64;
94
95 let shared_mem_bytes = tg_size * 4; encoder.set_op_kind(CapturedOpKind::Softmax);
100
101 encoder.encode_threadgroups_with_shared(
102 pipeline,
103 &[(0, input), (1, output), (2, params_buf)],
104 &[(0, shared_mem_bytes)],
105 MTLSize::new(rows as u64, 1, 1),
106 MTLSize::new(tg_size, 1, 1),
107 );
108
109 Ok(())
110}