mlx_native/ops/
scale_mask_softmax.rs1use crate::buffer::MlxBuffer;
10use crate::device::MlxDevice;
11use crate::dtypes::DType;
12use crate::encoder::{as_bytes, CommandEncoder, KernelArg};
13use crate::error::{MlxError, Result};
14use crate::kernel_registry::KernelRegistry;
15
/// Shape and scaling parameters for `dispatch_scale_mask_softmax_f32`.
#[derive(Clone, Copy, Debug)]
pub struct ScaleMaskSoftmaxParams {
    /// Number of rows in `input`/`output`; must be > 0 and a multiple of `seq_q`.
    pub rows: u32,
    /// Number of values per row; also the width of each mask row. Must be > 0.
    pub cols: u32,
    /// Number of rows in the `[seq_q x cols]` bf16 mask. Must be > 0.
    pub seq_q: u32,
    /// Scalar forwarded verbatim to the kernel (presumably applied to the input before the
    /// mask is added — confirm against the Metal source).
    pub scale: f32,
}
30
31#[repr(C)]
32#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
33struct ScaleMaskSoftmaxGpuParams {
34 cols: u32,
35 seq_q: u32,
36 scale: f32,
37 _pad: u32,
38}
39
/// Upper bound on threads per threadgroup for the scale-mask-softmax dispatch; rows with
/// fewer than this many columns get one thread per column instead.
const THREADS_PER_TG: u64 = 1 << 8;
45
46pub fn dispatch_scale_mask_softmax_f32(
52 encoder: &mut CommandEncoder,
53 registry: &mut KernelRegistry,
54 device: &MlxDevice,
55 input: &MlxBuffer,
56 output: &MlxBuffer,
57 mask_bf16: &MlxBuffer,
58 params: &ScaleMaskSoftmaxParams,
59) -> Result<()> {
60 if params.rows == 0 || params.cols == 0 || params.seq_q == 0 {
61 return Err(MlxError::InvalidArgument(
62 "scale_mask_softmax_f32: rows/cols/seq_q must be > 0".into(),
63 ));
64 }
65 if params.rows % params.seq_q != 0 {
66 return Err(MlxError::InvalidArgument(format!(
67 "scale_mask_softmax_f32: rows ({}) must be a multiple of seq_q ({})",
68 params.rows, params.seq_q
69 )));
70 }
71
72 let f32_sz = DType::F32.size_of();
73 let bf16_sz = DType::BF16.size_of();
74
75 let expected_input_bytes = (params.rows as usize) * (params.cols as usize) * f32_sz;
76 if input.byte_len() < expected_input_bytes {
77 return Err(MlxError::InvalidArgument(format!(
78 "scale_mask_softmax_f32: input too small: expected {} bytes, got {}",
79 expected_input_bytes, input.byte_len()
80 )));
81 }
82 if output.byte_len() < expected_input_bytes {
83 return Err(MlxError::InvalidArgument(format!(
84 "scale_mask_softmax_f32: output too small: expected {} bytes, got {}",
85 expected_input_bytes, output.byte_len()
86 )));
87 }
88 let expected_mask_bytes = (params.seq_q as usize) * (params.cols as usize) * bf16_sz;
89 if mask_bf16.byte_len() < expected_mask_bytes {
90 return Err(MlxError::InvalidArgument(format!(
91 "scale_mask_softmax_f32: mask too small: expected {} bytes for [{}x{}] bf16, got {}",
92 expected_mask_bytes, params.seq_q, params.cols, mask_bf16.byte_len()
93 )));
94 }
95
96 let pipeline = registry
97 .get_pipeline("scale_mask_softmax_f32", device.metal_device())?;
98
99 let gpu_params = ScaleMaskSoftmaxGpuParams {
100 cols: params.cols,
101 seq_q: params.seq_q,
102 scale: params.scale,
103 _pad: 0,
104 };
105
106 let threadgroups = metal::MTLSize::new(params.rows as u64, 1, 1);
107 let tg_size = std::cmp::min(THREADS_PER_TG, params.cols as u64);
108 let threads_per_tg = metal::MTLSize::new(tg_size, 1, 1);
109 let shmem_bytes = tg_size * (f32_sz as u64);
110
111 encoder.encode_threadgroups_with_args_and_shared(
112 pipeline,
113 &[
114 (0, KernelArg::Buffer(input)),
115 (1, KernelArg::Buffer(output)),
116 (2, KernelArg::Buffer(mask_bf16)),
117 (3, KernelArg::Bytes(as_bytes(&gpu_params))),
118 ],
119 &[(0, shmem_bytes)],
120 threadgroups,
121 threads_per_tg,
122 );
123
124 Ok(())
125}