mlx_native/ops/
scale_mask_softmax.rs1use crate::buffer::MlxBuffer;
10use crate::device::MlxDevice;
11use crate::dtypes::DType;
12use crate::encoder::{as_bytes, CommandEncoder, KernelArg};
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
/// Caller-facing shape and scale parameters for
/// [`dispatch_scale_mask_softmax_f32`].
///
/// The dispatch function rejects any zero dimension and requires `rows` to be
/// a multiple of `seq_q`.
#[derive(Debug, Clone, Copy)]
pub struct ScaleMaskSoftmaxParams {
    /// Number of rows; the dispatch launches one threadgroup per row.
    pub rows: u32,
    /// Elements per row. `input` and `output` must each hold at least
    /// `rows * cols` f32 values.
    pub cols: u32,
    /// Number of mask rows: the mask buffer must hold at least
    /// `seq_q * cols` bf16 values.
    // NOTE(review): presumably the query sequence length, with `rows`
    // being `heads * seq_q` — confirm against callers.
    pub seq_q: u32,
    /// Scale factor forwarded unchanged to the kernel.
    pub scale: f32,
}
30
31#[repr(C)]
32#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
33struct ScaleMaskSoftmaxGpuParams {
34 cols: u32,
35 seq_q: u32,
36 scale: f32,
37 _pad: u32,
38}
39
/// Upper bound on threads per threadgroup used by the dispatch; the actual
/// size is the smaller of this value and the row length.
const THREADS_PER_TG: u64 = 256;
45
46pub fn dispatch_scale_mask_softmax_f32(
52 encoder: &mut CommandEncoder,
53 registry: &mut KernelRegistry,
54 device: &MlxDevice,
55 input: &MlxBuffer,
56 output: &MlxBuffer,
57 mask_bf16: &MlxBuffer,
58 params: &ScaleMaskSoftmaxParams,
59) -> Result<()> {
60 if params.rows == 0 || params.cols == 0 || params.seq_q == 0 {
61 return Err(MlxError::InvalidArgument(
62 "scale_mask_softmax_f32: rows/cols/seq_q must be > 0".into(),
63 ));
64 }
65 if params.rows % params.seq_q != 0 {
66 return Err(MlxError::InvalidArgument(format!(
67 "scale_mask_softmax_f32: rows ({}) must be a multiple of seq_q ({})",
68 params.rows, params.seq_q
69 )));
70 }
71
72 let f32_sz = DType::F32.size_of();
73 let bf16_sz = DType::BF16.size_of();
74
75 let expected_input_bytes = (params.rows as usize) * (params.cols as usize) * f32_sz;
76 if input.byte_len() < expected_input_bytes {
77 return Err(MlxError::InvalidArgument(format!(
78 "scale_mask_softmax_f32: input too small: expected {} bytes, got {}",
79 expected_input_bytes, input.byte_len()
80 )));
81 }
82 if output.byte_len() < expected_input_bytes {
83 return Err(MlxError::InvalidArgument(format!(
84 "scale_mask_softmax_f32: output too small: expected {} bytes, got {}",
85 expected_input_bytes, output.byte_len()
86 )));
87 }
88 let expected_mask_bytes = (params.seq_q as usize) * (params.cols as usize) * bf16_sz;
89 if mask_bf16.byte_len() < expected_mask_bytes {
90 return Err(MlxError::InvalidArgument(format!(
91 "scale_mask_softmax_f32: mask too small: expected {} bytes for [{}x{}] bf16, got {}",
92 expected_mask_bytes, params.seq_q, params.cols, mask_bf16.byte_len()
93 )));
94 }
95
96 let use_v4 = (params.cols % 4 == 0) && match std::env::var("HF2Q_SOFTMAX_V4").as_deref() {
100 Ok("1") | Ok("true") | Ok("True") | Ok("TRUE") | Ok("yes") | Ok("YES") => true,
101 _ => false,
102 };
103 let kernel_name = if use_v4 {
104 "scale_mask_softmax_f32_v4"
105 } else {
106 "scale_mask_softmax_f32"
107 };
108 let pipeline = registry
109 .get_pipeline(kernel_name, device.metal_device())?;
110
111 let gpu_params = ScaleMaskSoftmaxGpuParams {
112 cols: params.cols,
113 seq_q: params.seq_q,
114 scale: params.scale,
115 _pad: 0,
116 };
117
118 let threadgroups = metal::MTLSize::new(params.rows as u64, 1, 1);
119 let tg_size = std::cmp::min(THREADS_PER_TG, params.cols as u64);
120 let threads_per_tg = metal::MTLSize::new(tg_size, 1, 1);
121 let shmem_bytes = tg_size * (f32_sz as u64);
122
123 encoder.encode_threadgroups_with_args_and_shared(
124 pipeline,
125 &[
126 (0, KernelArg::Buffer(input)),
127 (1, KernelArg::Buffer(output)),
128 (2, KernelArg::Buffer(mask_bf16)),
129 (3, KernelArg::Bytes(as_bytes(&gpu_params))),
130 ],
131 &[(0, shmem_bytes)],
132 threadgroups,
133 threads_per_tg,
134 );
135
136 Ok(())
137}