//! Fused scale-mask-softmax for non-flash-attention prefill.
//!
//! Replaces three sequential dispatches (scale by 1/sqrt(hd), add bf16
//! mask, row-softmax) with one kernel.  Intended for hf2q's HF2Q_NO_FA
//! path; not used elsewhere.
//!
//! See `src/shaders/scale_mask_softmax.metal` for the kernel contract.

use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::{as_bytes, CommandEncoder, KernelArg};
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;

/// Host-side launch parameters for the `scale_mask_softmax_f32` kernel.
///
/// Shapes: the input/output are treated as a `[rows x cols]` f32 matrix
/// and the mask as a `[seq_q x cols]` bf16 matrix shared across heads.
#[derive(Debug, Clone, Copy)]
pub struct ScaleMaskSoftmaxParams {
    /// Row count of the scores matrix, `n_heads * seq_q`; the dispatch
    /// launches one threadgroup per row.
    pub rows: u32,
    /// Reduction-axis length (`seq_k`); must equal the mask row length.
    pub cols: u32,
    /// Query rows per head (`seq_q`); the kernel derives the shared mask
    /// row as `row_idx % seq_q`.
    pub seq_q: u32,
    /// Multiplicative scale applied before softmax, typically
    /// `1.0 / sqrt(head_dim)`.
    pub scale: f32,
}

31#[repr(C)]
32#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
33struct ScaleMaskSoftmaxGpuParams {
34    cols: u32,
35    seq_q: u32,
36    scale: f32,
37    _pad: u32,
38}
39
/// Maximum threads per threadgroup for the softmax reduction.
///
/// Follows the same convention as `softmax::dispatch_softmax_f32`:
/// 256 threads gives sufficient parallelism at cols = seq_k = 2455
/// while keeping the threadgroup-shared memory allocation modest.
const THREADS_PER_TG: u64 = 256;

46/// Dispatches `scale_mask_softmax_f32`.
47///
48/// # Errors
49///
50/// `MlxError::InvalidArgument` on buffer-size or shape inconsistencies.
51pub fn dispatch_scale_mask_softmax_f32(
52    encoder: &mut CommandEncoder,
53    registry: &mut KernelRegistry,
54    device: &MlxDevice,
55    input: &MlxBuffer,
56    output: &MlxBuffer,
57    mask_bf16: &MlxBuffer,
58    params: &ScaleMaskSoftmaxParams,
59) -> Result<()> {
60    if params.rows == 0 || params.cols == 0 || params.seq_q == 0 {
61        return Err(MlxError::InvalidArgument(
62            "scale_mask_softmax_f32: rows/cols/seq_q must be > 0".into(),
63        ));
64    }
65    if params.rows % params.seq_q != 0 {
66        return Err(MlxError::InvalidArgument(format!(
67            "scale_mask_softmax_f32: rows ({}) must be a multiple of seq_q ({})",
68            params.rows, params.seq_q
69        )));
70    }
71
72    let f32_sz = DType::F32.size_of();
73    let bf16_sz = DType::BF16.size_of();
74
75    let expected_input_bytes = (params.rows as usize) * (params.cols as usize) * f32_sz;
76    if input.byte_len() < expected_input_bytes {
77        return Err(MlxError::InvalidArgument(format!(
78            "scale_mask_softmax_f32: input too small: expected {} bytes, got {}",
79            expected_input_bytes, input.byte_len()
80        )));
81    }
82    if output.byte_len() < expected_input_bytes {
83        return Err(MlxError::InvalidArgument(format!(
84            "scale_mask_softmax_f32: output too small: expected {} bytes, got {}",
85            expected_input_bytes, output.byte_len()
86        )));
87    }
88    let expected_mask_bytes = (params.seq_q as usize) * (params.cols as usize) * bf16_sz;
89    if mask_bf16.byte_len() < expected_mask_bytes {
90        return Err(MlxError::InvalidArgument(format!(
91            "scale_mask_softmax_f32: mask too small: expected {} bytes for [{}x{}] bf16, got {}",
92            expected_mask_bytes, params.seq_q, params.cols, mask_bf16.byte_len()
93        )));
94    }
95
96    // ADR-029 iter-93 H71: route to float4-vectorized v4 kernel when env
97    // HF2Q_SOFTMAX_V4=1 AND cols % 4 == 0 (vector alignment requirement).
98    // Default OFF until coherence + thermal-fair bench parity proven.
99    let use_v4 = (params.cols % 4 == 0) && match std::env::var("HF2Q_SOFTMAX_V4").as_deref() {
100        Ok("1") | Ok("true") | Ok("True") | Ok("TRUE") | Ok("yes") | Ok("YES") => true,
101        _ => false,
102    };
103    let kernel_name = if use_v4 {
104        "scale_mask_softmax_f32_v4"
105    } else {
106        "scale_mask_softmax_f32"
107    };
108    let pipeline = registry
109        .get_pipeline(kernel_name, device.metal_device())?;
110
111    let gpu_params = ScaleMaskSoftmaxGpuParams {
112        cols: params.cols,
113        seq_q: params.seq_q,
114        scale: params.scale,
115        _pad: 0,
116    };
117
118    let threadgroups = metal::MTLSize::new(params.rows as u64, 1, 1);
119    let tg_size = std::cmp::min(THREADS_PER_TG, params.cols as u64);
120    let threads_per_tg = metal::MTLSize::new(tg_size, 1, 1);
121    let shmem_bytes = tg_size * (f32_sz as u64);
122
123    encoder.encode_threadgroups_with_args_and_shared(
124        pipeline,
125        &[
126            (0, KernelArg::Buffer(input)),
127            (1, KernelArg::Buffer(output)),
128            (2, KernelArg::Buffer(mask_bf16)),
129            (3, KernelArg::Bytes(as_bytes(&gpu_params))),
130        ],
131        &[(0, shmem_bytes)],
132        threadgroups,
133        threads_per_tg,
134    );
135
136    Ok(())
137}