Skip to main content

mlx_native/ops/
dequant_to_f16.rs

1// ADR-029 iter-28 H29 — whole-tensor dequant from block-quantized formats
2// to F16, used at model load to materialize an F16 shadow of dense weights
3// so the runtime dispatch can use `kernel_mul_mm_f16_f32_*` (peer's
4// gemma4 pattern).
5//
6// See `src/shaders/dequant_to_f16.metal` for the kernel design and the
7// per-type instantiation list.  Public Rust API: one entry point
8// `dispatch_dequant_to_f16(...)` that picks the right kernel via the
9// caller-supplied `GgmlType`.
10
11use metal::MTLSize;
12
13use crate::encoder::CommandEncoder;
14use crate::error::{MlxError, Result};
15use crate::kernel_registry::KernelRegistry;
16use crate::buffer::MlxBuffer;
17use crate::ops::encode_helpers::KernelArg;
18use crate::ops::quantized_matmul_ggml::GgmlType;
19use crate::DType;
20
/// Sub-groups of 16 values in one K-quant block (block_q4_K / block_q5_K /
/// block_q6_K hold 256 values each).  Must agree with `QK_NL = 16` in
/// `dequant_to_f16.metal`.
const QK_NL_K: u32 = 256 / 16;

/// Sub-groups of 16 values in one legacy block (block_q4_0 / block_q8_0 /
/// etc. hold 32 values each).  Must agree with `nl = 2` in
/// `dequant_to_f16.metal`.
const QK_NL_LEGACY: u32 = 32 / 16;
28
29/// Dispatch the whole-tensor dequant-to-F16 kernel.
30///
31/// `weight` is the source quantized buffer (caller-allocated, holds the
32/// GGUF-format bytes for `n_rows × n_cols` elements of `ggml_type`).
33/// `f16_shadow` is the destination buffer, must be at least
34/// `n_rows * n_cols * 2` bytes (F16 = 2 bytes/elem).
35///
36/// `n_rows` / `n_cols` are the logical tensor shape; the kernel writes
37/// `n_rows * n_cols` F16 values into `f16_shadow` in row-major order
38/// (matching the row-major dequant layout the matmul kernels expect).
39///
40/// Returns InvalidArgument if `ggml_type` is unsupported (F32 / F16 /
41/// I16 — no dequant needed) or if buffer sizes don't match.
42pub fn dispatch_dequant_to_f16(
43    encoder: &mut CommandEncoder,
44    registry: &mut KernelRegistry,
45    device: &metal::DeviceRef,
46    weight: &MlxBuffer,
47    f16_shadow: &MlxBuffer,
48    n_rows: u32,
49    n_cols: u32,
50    ggml_type: GgmlType,
51) -> Result<()> {
52    // Block size + sub-groups-per-block.
53    let (block_values, qk_nl, kernel_name) = match ggml_type {
54        GgmlType::Q4_0 => (32u32, QK_NL_LEGACY, "hf2q_dequant_q4_0_to_f16"),
55        GgmlType::Q8_0 => (32, QK_NL_LEGACY, "hf2q_dequant_q8_0_to_f16"),
56        GgmlType::Q5_1 => (32, QK_NL_LEGACY, "hf2q_dequant_q5_1_to_f16"),
57        GgmlType::IQ4_NL => (32, QK_NL_LEGACY, "hf2q_dequant_iq4_nl_to_f16"),
58        GgmlType::Q4_K => (256, QK_NL_K, "hf2q_dequant_q4_K_to_f16"),
59        GgmlType::Q5_K => (256, QK_NL_K, "hf2q_dequant_q5_K_to_f16"),
60        GgmlType::Q6_K => (256, QK_NL_K, "hf2q_dequant_q6_K_to_f16"),
61        other => {
62            return Err(MlxError::InvalidArgument(format!(
63                "dispatch_dequant_to_f16: unsupported ggml_type {:?} \
64                 (only Q4_0 / Q8_0 / Q5_1 / IQ4_NL / Q4_K / Q5_K / Q6_K)",
65                other
66            )));
67        }
68    };
69
70    // Validate shapes and buffer sizes.
71    if n_rows == 0 || n_cols == 0 {
72        return Err(MlxError::InvalidArgument(
73            "dispatch_dequant_to_f16: n_rows and n_cols must be > 0".into(),
74        ));
75    }
76    if n_cols % block_values != 0 {
77        return Err(MlxError::InvalidArgument(format!(
78            "dispatch_dequant_to_f16: n_cols ({}) must be divisible by block_values ({}) for {:?}",
79            n_cols, block_values, ggml_type
80        )));
81    }
82    if f16_shadow.dtype() != DType::F16 {
83        return Err(MlxError::InvalidArgument(format!(
84            "dispatch_dequant_to_f16: f16_shadow must be DType::F16, got {:?}",
85            f16_shadow.dtype()
86        )));
87    }
88
89    let n_elements = (n_rows as u64) * (n_cols as u64);
90    let needed_bytes = n_elements * 2;
91    if (f16_shadow.byte_len() as u64) < needed_bytes {
92        return Err(MlxError::InvalidArgument(format!(
93            "dispatch_dequant_to_f16: f16_shadow too small ({} bytes; need {})",
94            f16_shadow.byte_len(),
95            needed_bytes
96        )));
97    }
98
99    // Total threads = n_elements / 16.  Each thread dequants one 16-elem
100    // group (= one (block_idx, il) pair).
101    let n_groups: u32 = (n_elements / 16) as u32;
102    if (n_elements as u64) != (n_groups as u64) * 16 {
103        return Err(MlxError::InvalidArgument(format!(
104            "dispatch_dequant_to_f16: total elements ({}) must be a multiple of 16 \
105             (got n_rows={}, n_cols={})",
106            n_elements, n_rows, n_cols
107        )));
108    }
109
110    let pipeline = registry.get_pipeline(kernel_name, device)?;
111
112    // Threadgroup size: pick 256 for good occupancy.  Total grid = n_groups
113    // threadgroups of 256 threads, rounded up.
114    const TG_SIZE: u64 = 256;
115    let n_tg = ((n_groups as u64) + TG_SIZE - 1) / TG_SIZE;
116
117    let threadgroups = MTLSize::new(n_tg, 1, 1);
118    let threads_per_tg = MTLSize::new(TG_SIZE, 1, 1);
119
120    let n_groups_bytes = n_groups.to_ne_bytes();
121
122    encoder.encode_threadgroups_with_args(
123        pipeline,
124        &[
125            (0, KernelArg::Bytes(&n_groups_bytes)),
126            (1, KernelArg::Buffer(weight)),
127            (2, KernelArg::Buffer(f16_shadow)),
128        ],
129        threadgroups,
130        threads_per_tg,
131    );
132
133    // ABI sanity: silence dead-arg warning when block_values is only used
134    // in the validation arm above.
135    let _ = block_values;
136    let _ = qk_nl;
137    Ok(())
138}
139
140/// One-shot helper: allocate an F16 shadow buffer + dispatch + commit-and-wait.
141///
142/// Intended for use at model load — caller has the source quantized buffer
143/// already on GPU and wants a paired F16 shadow.  Returns the new F16
144/// buffer.  Performs a `commit_and_wait` so the buffer is ready for
145/// downstream use by the time this function returns.
146pub fn materialize_f16_shadow(
147    device: &crate::MlxDevice,
148    registry: &mut KernelRegistry,
149    weight: &MlxBuffer,
150    n_rows: u32,
151    n_cols: u32,
152    ggml_type: GgmlType,
153) -> Result<MlxBuffer> {
154    let n_elements = (n_rows as usize) * (n_cols as usize);
155    let f16_shadow = device
156        .alloc_buffer(n_elements * 2, DType::F16, vec![n_rows as usize, n_cols as usize])
157        .map_err(|e| MlxError::InvalidArgument(format!("materialize_f16_shadow alloc: {e}")))?;
158
159    let mut encoder = device
160        .command_encoder()
161        .map_err(|e| MlxError::InvalidArgument(format!("materialize_f16_shadow encoder: {e}")))?;
162
163    dispatch_dequant_to_f16(
164        &mut encoder,
165        registry,
166        device.metal_device(),
167        weight,
168        &f16_shadow,
169        n_rows,
170        n_cols,
171        ggml_type,
172    )?;
173
174    encoder
175        .commit_and_wait()
176        .map_err(|e| MlxError::InvalidArgument(format!("materialize_f16_shadow commit: {e}")))?;
177
178    Ok(f16_shadow)
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MlxDevice;

    /// Bit pattern of `x` as IEEE-754 binary16.
    ///
    /// Only accepts values that are zero or exactly representable as normal
    /// F16 numbers, and panics otherwise, so a test can never compare against
    /// a silently rounded expectation.
    fn exact_f16_bits(x: f32) -> u16 {
        let bits = x.to_bits();
        let sign = ((bits >> 16) & 0x8000) as u16;
        if x == 0.0 {
            return sign;
        }
        // Rebias the exponent from f32 (127) to f16 (15).
        let exp = ((bits >> 23) & 0xFF) as i32 - 127 + 15;
        let mant = bits & 0x007F_FFFF;
        assert!((1..31).contains(&exp), "{x} is outside the F16 normal range");
        assert_eq!(mant & 0x1FFF, 0, "{x} is not exactly representable in F16");
        sign | ((exp as u16) << 10) | ((mant >> 13) as u16)
    }

    /// Round-trip: build a tiny Q8_0 tensor on the CPU, dequant it with the
    /// kernel, and compare every output element against a CPU-side dequant
    /// (`d * qs[i]`).  Two blocks are used so that the per-block stride and
    /// both 16-value sub-groups (`il = 0, 1`) of each block are exercised.
    #[test]
    fn dequant_q8_0_to_f16_roundtrip() {
        const N_BLOCKS: usize = 2;
        const BLOCK_VALUES: usize = 32;
        const N_ELEMENTS: usize = N_BLOCKS * BLOCK_VALUES;
        // block_q8_0 = half d (2 bytes) + 32 × int8 qs.
        const BLOCK_BYTES: usize = 2 + BLOCK_VALUES;
        const D: f32 = 0.5;

        let device = MlxDevice::new().expect("new device");

        // Each block: d = 0.5 (F16 0x3800, little-endian), and qs holds the
        // global element index, so element k dequantizes to 0.5 * k.
        let mut src = vec![0u8; N_BLOCKS * BLOCK_BYTES];
        for (b, block) in src.chunks_exact_mut(BLOCK_BYTES).enumerate() {
            block[..2].copy_from_slice(&exact_f16_bits(D).to_le_bytes());
            for (i, q) in block[2..].iter_mut().enumerate() {
                *q = ((b * BLOCK_VALUES + i) % 128) as u8;
            }
        }

        // CPU reference dequant: int8 qs times d, in row-major order.
        let expected: Vec<u16> = src
            .chunks_exact(BLOCK_BYTES)
            .flat_map(|block| {
                block[2..]
                    .iter()
                    .map(|&q| exact_f16_bits(D * f32::from(q as i8)))
            })
            .collect();

        let mut weight = device
            .alloc_buffer(src.len(), DType::U8, vec![src.len()])
            .expect("alloc src");
        weight
            .as_mut_slice::<u8>()
            .expect("slice src")
            .copy_from_slice(&src);

        let f16_shadow = device
            .alloc_buffer(N_ELEMENTS * 2, DType::F16, vec![N_ELEMENTS])
            .expect("alloc f16");

        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        // The kernel only cares about total elements = n_rows * n_cols; this
        // is a 1×64 tensor.  Dispatch only queues work; it runs at commit.
        dispatch_dequant_to_f16(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &weight,
            &f16_shadow,
            1,
            N_ELEMENTS as u32,
            GgmlType::Q8_0,
        )
        .expect("dispatch ok");

        encoder.commit_and_wait().expect("commit");

        let out: &[u16] = f16_shadow.as_slice().expect("read f16");
        assert!(
            out.len() >= N_ELEMENTS,
            "f16 shadow holds {} values, expected at least {}",
            out.len(),
            N_ELEMENTS
        );

        // Hard-coded spot checks also guard the reference helper itself.
        assert_eq!(out[0], 0x0000, "out[0] should be F16 0.0, got 0x{:04X}", out[0]);
        assert_eq!(out[1], 0x3800, "out[1] should be F16 0.5, got 0x{:04X}", out[1]);
        assert_eq!(out[2], 0x3C00, "out[2] should be F16 1.0, got 0x{:04X}", out[2]);

        for (k, (&got, &want)) in out[..N_ELEMENTS].iter().zip(&expected).enumerate() {
            assert_eq!(got, want, "out[{k}]: got 0x{got:04X}, want 0x{want:04X}");
        }
    }

    /// `n_cols` must be a whole number of blocks; 48 is not a multiple of the
    /// 32-value Q8_0 block, so the dispatch is rejected before any GPU work.
    #[test]
    fn dequant_rejects_partial_block_columns() {
        const N_COLS: usize = 48;

        let device = MlxDevice::new().expect("new device");
        let weight = device
            .alloc_buffer(2 * 34, DType::U8, vec![2 * 34])
            .expect("alloc src");
        let f16_shadow = device
            .alloc_buffer(N_COLS * 2, DType::F16, vec![N_COLS])
            .expect("alloc f16");

        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        let res = dispatch_dequant_to_f16(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &weight,
            &f16_shadow,
            1,
            N_COLS as u32,
            GgmlType::Q8_0,
        );
        assert!(
            matches!(res, Err(MlxError::InvalidArgument(_))),
            "expected InvalidArgument for n_cols={N_COLS}, got {res:?}"
        );
    }
}