1use metal::MTLSize;
12
13use crate::encoder::CommandEncoder;
14use crate::error::{MlxError, Result};
15use crate::kernel_registry::KernelRegistry;
16use crate::buffer::MlxBuffer;
17use crate::ops::encode_helpers::KernelArg;
18use crate::ops::quantized_matmul_ggml::GgmlType;
19use crate::DType;
20
/// Per-format value selected by `dispatch_dequant_to_f16` for the 256-value block
/// formats (Q4_K / Q5_K / Q6_K).
///
/// Equals 256 / 16 — presumably the number of 16-element groups per block;
/// TODO confirm against the kernel source. The visible dispatch code selects
/// this value but does not pass it to the kernel.
const QK_NL_K: u32 = 16;

/// Per-format value selected by `dispatch_dequant_to_f16` for the 32-value block
/// formats (Q4_0 / Q8_0 / Q5_1 / IQ4_NL).
///
/// Equals 32 / 16 — presumably the number of 16-element groups per block;
/// TODO confirm against the kernel source. The visible dispatch code selects
/// this value but does not pass it to the kernel.
const QK_NL_LEGACY: u32 = 2;
28
29pub fn dispatch_dequant_to_f16(
43 encoder: &mut CommandEncoder,
44 registry: &mut KernelRegistry,
45 device: &metal::DeviceRef,
46 weight: &MlxBuffer,
47 f16_shadow: &MlxBuffer,
48 n_rows: u32,
49 n_cols: u32,
50 ggml_type: GgmlType,
51) -> Result<()> {
52 let (block_values, qk_nl, kernel_name) = match ggml_type {
54 GgmlType::Q4_0 => (32u32, QK_NL_LEGACY, "hf2q_dequant_q4_0_to_f16"),
55 GgmlType::Q8_0 => (32, QK_NL_LEGACY, "hf2q_dequant_q8_0_to_f16"),
56 GgmlType::Q5_1 => (32, QK_NL_LEGACY, "hf2q_dequant_q5_1_to_f16"),
57 GgmlType::IQ4_NL => (32, QK_NL_LEGACY, "hf2q_dequant_iq4_nl_to_f16"),
58 GgmlType::Q4_K => (256, QK_NL_K, "hf2q_dequant_q4_K_to_f16"),
59 GgmlType::Q5_K => (256, QK_NL_K, "hf2q_dequant_q5_K_to_f16"),
60 GgmlType::Q6_K => (256, QK_NL_K, "hf2q_dequant_q6_K_to_f16"),
61 other => {
62 return Err(MlxError::InvalidArgument(format!(
63 "dispatch_dequant_to_f16: unsupported ggml_type {:?} \
64 (only Q4_0 / Q8_0 / Q5_1 / IQ4_NL / Q4_K / Q5_K / Q6_K)",
65 other
66 )));
67 }
68 };
69
70 if n_rows == 0 || n_cols == 0 {
72 return Err(MlxError::InvalidArgument(
73 "dispatch_dequant_to_f16: n_rows and n_cols must be > 0".into(),
74 ));
75 }
76 if n_cols % block_values != 0 {
77 return Err(MlxError::InvalidArgument(format!(
78 "dispatch_dequant_to_f16: n_cols ({}) must be divisible by block_values ({}) for {:?}",
79 n_cols, block_values, ggml_type
80 )));
81 }
82 if f16_shadow.dtype() != DType::F16 {
83 return Err(MlxError::InvalidArgument(format!(
84 "dispatch_dequant_to_f16: f16_shadow must be DType::F16, got {:?}",
85 f16_shadow.dtype()
86 )));
87 }
88
89 let n_elements = (n_rows as u64) * (n_cols as u64);
90 let needed_bytes = n_elements * 2;
91 if (f16_shadow.byte_len() as u64) < needed_bytes {
92 return Err(MlxError::InvalidArgument(format!(
93 "dispatch_dequant_to_f16: f16_shadow too small ({} bytes; need {})",
94 f16_shadow.byte_len(),
95 needed_bytes
96 )));
97 }
98
99 let n_groups: u32 = (n_elements / 16) as u32;
102 if (n_elements as u64) != (n_groups as u64) * 16 {
103 return Err(MlxError::InvalidArgument(format!(
104 "dispatch_dequant_to_f16: total elements ({}) must be a multiple of 16 \
105 (got n_rows={}, n_cols={})",
106 n_elements, n_rows, n_cols
107 )));
108 }
109
110 let pipeline = registry.get_pipeline(kernel_name, device)?;
111
112 const TG_SIZE: u64 = 256;
115 let n_tg = ((n_groups as u64) + TG_SIZE - 1) / TG_SIZE;
116
117 let threadgroups = MTLSize::new(n_tg, 1, 1);
118 let threads_per_tg = MTLSize::new(TG_SIZE, 1, 1);
119
120 let n_groups_bytes = n_groups.to_ne_bytes();
121
122 encoder.encode_threadgroups_with_args(
123 pipeline,
124 &[
125 (0, KernelArg::Bytes(&n_groups_bytes)),
126 (1, KernelArg::Buffer(weight)),
127 (2, KernelArg::Buffer(f16_shadow)),
128 ],
129 threadgroups,
130 threads_per_tg,
131 );
132
133 let _ = block_values;
136 let _ = qk_nl;
137 Ok(())
138}
139
140pub fn materialize_f16_shadow(
147 device: &crate::MlxDevice,
148 registry: &mut KernelRegistry,
149 weight: &MlxBuffer,
150 n_rows: u32,
151 n_cols: u32,
152 ggml_type: GgmlType,
153) -> Result<MlxBuffer> {
154 let n_elements = (n_rows as usize) * (n_cols as usize);
155 let f16_shadow = device
156 .alloc_buffer(n_elements * 2, DType::F16, vec![n_rows as usize, n_cols as usize])
157 .map_err(|e| MlxError::InvalidArgument(format!("materialize_f16_shadow alloc: {e}")))?;
158
159 let mut encoder = device
160 .command_encoder()
161 .map_err(|e| MlxError::InvalidArgument(format!("materialize_f16_shadow encoder: {e}")))?;
162
163 dispatch_dequant_to_f16(
164 &mut encoder,
165 registry,
166 device.metal_device(),
167 weight,
168 &f16_shadow,
169 n_rows,
170 n_cols,
171 ggml_type,
172 )?;
173
174 encoder
175 .commit_and_wait()
176 .map_err(|e| MlxError::InvalidArgument(format!("materialize_f16_shadow commit: {e}")))?;
177
178 Ok(f16_shadow)
179}
180
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MlxDevice;

    /// Dequantizes two hand-built Q8_0 blocks and checks the first three F16 outputs.
    ///
    /// Each block is a 2-byte scale followed by 32 quant bytes. The scale bytes
    /// `[0x00, 0x38]` are the F16 bit pattern 0x3800 (0.5) in little-endian
    /// order, and the quants count up from 0, so the expected outputs start
    /// 0.0, 0.5, 1.0.
    #[test]
    fn dequant_q8_0_to_f16_roundtrip() {
        const N_BLOCKS: usize = 2;
        const QUANTS_PER_BLOCK: usize = 32;
        const SCALE_BYTES: usize = 2;
        const BLOCK_BYTES: usize = SCALE_BYTES + QUANTS_PER_BLOCK;
        const N_ELEMENTS: usize = N_BLOCKS * QUANTS_PER_BLOCK;

        let device = MlxDevice::new().expect("new device");

        // Build the quantized source block by block.
        let mut src = vec![0u8; N_BLOCKS * BLOCK_BYTES];
        for (block_idx, block) in src.chunks_exact_mut(BLOCK_BYTES).enumerate() {
            let (scale, quants) = block.split_at_mut(SCALE_BYTES);
            scale.copy_from_slice(&[0x00, 0x38]);
            for (i, q) in quants.iter_mut().enumerate() {
                *q = ((block_idx * QUANTS_PER_BLOCK + i) % 128) as u8;
            }
        }

        let mut weight = device
            .alloc_buffer(src.len(), DType::U8, vec![src.len()])
            .expect("alloc src");
        weight
            .as_mut_slice::<u8>()
            .expect("slice src")
            .copy_from_slice(&src);

        let f16_shadow = device
            .alloc_buffer(N_ELEMENTS * 2, DType::F16, vec![N_ELEMENTS])
            .expect("alloc f16");

        let mut registry = KernelRegistry::new();
        let mut encoder = device.command_encoder().expect("encoder");

        // Treat the two blocks as a single row of N_ELEMENTS columns.
        dispatch_dequant_to_f16(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &weight,
            &f16_shadow,
            1,
            N_ELEMENTS as u32,
            GgmlType::Q8_0,
        )
        .expect("dispatch ok");

        encoder.commit_and_wait().expect("commit");

        // Compare raw F16 bit patterns to avoid any float conversion in the test.
        let out: &[u16] = f16_shadow.as_slice().expect("read f16");
        assert_eq!(out[0], 0x0000, "out[0] should be F16 0.0, got 0x{:04X}", out[0]);
        assert_eq!(out[1], 0x3800, "out[1] should be F16 0.5, got 0x{:04X}", out[1]);
        assert_eq!(out[2], 0x3C00, "out[2] should be F16 1.0, got 0x{:04X}", out[2]);
    }
}