1use anyhow::Result;
2use half::f16;
3
/// Weights per Q8_0 block: one f16 scale followed by 32 signed 8-bit quants.
const QK8_0: usize = 32;
/// Weights per K-quant "super-block" (shared by the Q4_K and Q6_K layouts).
const QK_K: usize = 256;
/// Bytes of packed 6-bit scales/mins stored inside a Q4_K super-block.
const K_SCALE_SIZE: usize = 12;
7
8pub fn decode_f16(bytes: [u8; 2]) -> f32 {
9 f16::from_bits(u16::from_le_bytes(bytes)).to_f32()
10}
11
12pub fn dequantize_q8_0_block(block: &[u8]) -> Result<Vec<f32>> {
13 if block.len() != 2 + QK8_0 {
14 anyhow::bail!("Bloque Q8_0 invalido: {} bytes", block.len());
15 }
16 let d = decode_f16([block[0], block[1]]);
17 let mut out = Vec::with_capacity(QK8_0);
18 for quant in &block[2..] {
19 out.push(d * (*quant as i8) as f32);
20 }
21 Ok(out)
22}
23
/// Unpacks the 6-bit (scale, min) pair for sub-block `j` (0..=7) from the
/// packed 12-byte Q4_K scales array.
///
/// Pairs 0..4 live directly in the low 6 bits of bytes `j` and `j + 4`;
/// pairs 4..8 are split: low nibble in byte `j + 4`, remaining two bits in
/// the high bits of bytes `j - 4` and `j`.
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    if j < 4 {
        return (scales[j] & 63, scales[j + 4] & 63);
    }
    let packed = scales[j + 4];
    let scale = (packed & 0x0F) | ((scales[j - 4] >> 6) << 4);
    let min = (packed >> 4) | ((scales[j] >> 6) << 4);
    (scale, min)
}
33
34pub fn dequantize_q4_k_block(block: &[u8]) -> Result<Vec<f32>> {
35 let expected = 2 * 2 + K_SCALE_SIZE + QK_K / 2;
36 if block.len() != expected {
37 anyhow::bail!("Bloque Q4_K invalido: {} bytes", block.len());
38 }
39
40 let d = decode_f16([block[0], block[1]]);
41 let dmin = decode_f16([block[2], block[3]]);
42 let scales = &block[4..4 + K_SCALE_SIZE];
43 let quants = &block[4 + K_SCALE_SIZE..];
44
45 let mut out = Vec::with_capacity(QK_K);
46 let mut is = 0usize;
47 let mut q_offset = 0usize;
48
49 for _ in (0..QK_K).step_by(64) {
50 let (sc1, m1) = get_scale_min_k4(is, scales);
51 let d1 = d * f32::from(sc1);
52 let min1 = dmin * f32::from(m1);
53 let (sc2, m2) = get_scale_min_k4(is + 1, scales);
54 let d2 = d * f32::from(sc2);
55 let min2 = dmin * f32::from(m2);
56
57 for l in 0..32 {
58 out.push(d1 * f32::from(quants[q_offset + l] & 0x0F) - min1);
59 }
60 for l in 0..32 {
61 out.push(d2 * f32::from(quants[q_offset + l] >> 4) - min2);
62 }
63
64 q_offset += 32;
65 is += 2;
66 }
67
68 Ok(out)
69}
70
71pub fn dequantize_q6_k_block(block: &[u8]) -> Result<Vec<f32>> {
72 let expected = 2 + QK_K / 16 + (3 * QK_K) / 4;
73 if block.len() != expected {
74 anyhow::bail!("Bloque Q6_K invalido: {} bytes", block.len());
75 }
76
77 let ql_len = QK_K / 2;
78 let qh_len = QK_K / 4;
79 let scales_len = QK_K / 16;
80
81 let ql = &block[..ql_len];
82 let qh = &block[ql_len..ql_len + qh_len];
83 let scales = &block[ql_len + qh_len..ql_len + qh_len + scales_len];
84 let d = decode_f16([
85 block[ql_len + qh_len + scales_len],
86 block[ql_len + qh_len + scales_len + 1],
87 ]);
88
89 let mut out = vec![0.0f32; QK_K];
90 let mut ql_offset = 0usize;
91 let mut qh_offset = 0usize;
92 let mut scales_offset = 0usize;
93 let mut y_offset = 0usize;
94
95 for _ in (0..QK_K).step_by(128) {
96 for l in 0..32 {
97 let is = l / 16;
98 let qh_byte = qh[qh_offset + l];
99 let q1 = (((ql[ql_offset + l] & 0x0F) | (((qh_byte >> 0) & 3) << 4)) as i8) - 32;
100 let q2 = (((ql[ql_offset + 32 + l] & 0x0F) | (((qh_byte >> 2) & 3) << 4)) as i8) - 32;
101 let q3 = (((ql[ql_offset + l] >> 4) | (((qh_byte >> 4) & 3) << 4)) as i8) - 32;
102 let q4 = (((ql[ql_offset + 32 + l] >> 4) | (((qh_byte >> 6) & 3) << 4)) as i8) - 32;
103
104 out[y_offset + l] = d * f32::from(scales[scales_offset + is] as i8) * f32::from(q1);
105 out[y_offset + 32 + l] = d * f32::from(scales[scales_offset + 2 + is] as i8) * f32::from(q2);
106 out[y_offset + 64 + l] = d * f32::from(scales[scales_offset + 4 + is] as i8) * f32::from(q3);
107 out[y_offset + 96 + l] = d * f32::from(scales[scales_offset + 6 + is] as i8) * f32::from(q4);
108 }
109 ql_offset += 64;
110 qh_offset += 32;
111 scales_offset += 8;
112 y_offset += 128;
113 }
114
115 Ok(out)
116}
117
118pub fn dot_q8_0_f32(qdata: &[u8], vec_f32: &[f32], num_elements: usize) -> f32 {
121 let block_size = QK8_0;
122 let type_size = 2 + QK8_0;
123 let num_blocks = num_elements / block_size;
124
125 #[cfg(target_arch = "x86_64")]
126 {
127 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
128 return unsafe { dot_q8_0_f32_avx2(qdata, vec_f32, num_blocks, type_size, block_size) };
130 }
131 }
132
133 dot_q8_0_f32_scalar(qdata, vec_f32, num_blocks, type_size, block_size)
134}
135
136fn dot_q8_0_f32_scalar(qdata: &[u8], vec_f32: &[f32], num_blocks: usize, type_size: usize, block_size: usize) -> f32 {
137 let mut acc = 0.0f32;
138 for block_idx in 0..num_blocks {
139 let block_start = block_idx * type_size;
140 let block = &qdata[block_start..block_start + type_size];
141 let d = decode_f16([block[0], block[1]]);
142 let quants = &block[2..];
143 let vec_offset = block_idx * block_size;
144
145 let mut block_acc = 0.0f32;
146 for i in 0..block_size {
147 block_acc += (quants[i] as i8) as f32 * vec_f32[vec_offset + i];
148 }
149 acc += d * block_acc;
150 }
151 acc
152}
153
/// AVX2/FMA kernel for the Q8_0 dot product: per block, broadcasts the f16
/// scale and processes the 32 quants as four groups of 8 SIMD lanes, then
/// horizontally reduces the 8-lane accumulator to a single `f32`.
///
/// # Safety
/// The caller must guarantee that the CPU supports AVX2 and FMA (checked via
/// `is_x86_feature_detected!` in `dot_q8_0_f32`), that `qdata` holds at least
/// `num_blocks * type_size` bytes, and that `vec_f32` holds at least
/// `num_blocks * block_size` elements.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_q8_0_f32_avx2(
    qdata: &[u8],
    vec_f32: &[f32],
    num_blocks: usize,
    type_size: usize,
    block_size: usize,
) -> f32 {
    use std::arch::x86_64::*;

    // Eight running partial sums, one per SIMD lane.
    let mut acc = _mm256_setzero_ps();

    for block_idx in 0..num_blocks {
        let block_start = block_idx * type_size;
        let block = &qdata[block_start..block_start + type_size];
        let d = decode_f16([block[0], block[1]]);
        let quants = block.as_ptr().add(2);
        let vec_offset = block_idx * block_size;
        let vec_ptr = vec_f32.as_ptr().add(vec_offset);

        let d_vec = _mm256_set1_ps(d);

        // Group 0: load 8 signed quants, widen i8 -> i32 -> f32, then
        // acc += (d * q) * x via a fused multiply-add.
        let q_bytes = _mm_loadl_epi64(quants as *const __m128i);
        let q_i32 = _mm256_cvtepi8_epi32(q_bytes);
        let q_f32 = _mm256_cvtepi32_ps(q_i32);
        let x_f32 = _mm256_loadu_ps(vec_ptr);
        acc = _mm256_fmadd_ps(_mm256_mul_ps(d_vec, q_f32), x_f32, acc);

        // Group 1: quants 8..16.
        let q_bytes = _mm_loadl_epi64(quants.add(8) as *const __m128i);
        let q_i32 = _mm256_cvtepi8_epi32(q_bytes);
        let q_f32 = _mm256_cvtepi32_ps(q_i32);
        let x_f32 = _mm256_loadu_ps(vec_ptr.add(8));
        acc = _mm256_fmadd_ps(_mm256_mul_ps(d_vec, q_f32), x_f32, acc);

        // Group 2: quants 16..24.
        let q_bytes = _mm_loadl_epi64(quants.add(16) as *const __m128i);
        let q_i32 = _mm256_cvtepi8_epi32(q_bytes);
        let q_f32 = _mm256_cvtepi32_ps(q_i32);
        let x_f32 = _mm256_loadu_ps(vec_ptr.add(16));
        acc = _mm256_fmadd_ps(_mm256_mul_ps(d_vec, q_f32), x_f32, acc);

        // Group 3: quants 24..32.
        let q_bytes = _mm_loadl_epi64(quants.add(24) as *const __m128i);
        let q_i32 = _mm256_cvtepi8_epi32(q_bytes);
        let q_f32 = _mm256_cvtepi32_ps(q_i32);
        let x_f32 = _mm256_loadu_ps(vec_ptr.add(24));
        acc = _mm256_fmadd_ps(_mm256_mul_ps(d_vec, q_f32), x_f32, acc);
    }

    // Horizontal reduction: fold 256 bits -> 128 -> 64 -> 32.
    let hi = _mm256_extractf128_ps(acc, 1);
    let lo = _mm256_castps256_ps128(acc);
    let sum128 = _mm_add_ps(lo, hi);
    let shuf = _mm_movehdup_ps(sum128);
    let sums = _mm_add_ps(sum128, shuf);
    let shuf2 = _mm_movehl_ps(sums, sums);
    let result = _mm_add_ss(sums, shuf2);
    _mm_cvtss_f32(result)
}
217
218pub fn dot_q6_k_f32(qdata: &[u8], vec_f32: &[f32], num_elements: usize) -> f32 {
220 let ql_per_block = QK_K / 2;
221 let qh_per_block = QK_K / 4;
222 let scales_per_block = QK_K / 16;
223 let type_size = ql_per_block + qh_per_block + scales_per_block + 2;
224 let num_blocks = num_elements / QK_K;
225 let mut acc = 0.0f32;
226
227 for block_idx in 0..num_blocks {
228 let block_start = block_idx * type_size;
229 let block = &qdata[block_start..block_start + type_size];
230
231 let ql = &block[..ql_per_block];
232 let qh = &block[ql_per_block..ql_per_block + qh_per_block];
233 let scales = &block[ql_per_block + qh_per_block..ql_per_block + qh_per_block + scales_per_block];
234 let d = decode_f16([
235 block[ql_per_block + qh_per_block + scales_per_block],
236 block[ql_per_block + qh_per_block + scales_per_block + 1],
237 ]);
238
239 let vec_offset = block_idx * QK_K;
240 let mut ql_off = 0usize;
241 let mut qh_off = 0usize;
242 let mut sc_off = 0usize;
243 let mut y_off = 0usize;
244
245 for _ in (0..QK_K).step_by(128) {
246 for l in 0..32 {
247 let is = l / 16;
248 let qh_byte = qh[qh_off + l];
249 let q1 = (((ql[ql_off + l] & 0x0F) | (((qh_byte >> 0) & 3) << 4)) as i8) - 32;
250 let q2 = (((ql[ql_off + 32 + l] & 0x0F) | (((qh_byte >> 2) & 3) << 4)) as i8) - 32;
251 let q3 = (((ql[ql_off + l] >> 4) | (((qh_byte >> 4) & 3) << 4)) as i8) - 32;
252 let q4 = (((ql[ql_off + 32 + l] >> 4) | (((qh_byte >> 6) & 3) << 4)) as i8) - 32;
253
254 let s1 = d * (scales[sc_off + is] as i8) as f32;
255 let s2 = d * (scales[sc_off + 2 + is] as i8) as f32;
256 let s3 = d * (scales[sc_off + 4 + is] as i8) as f32;
257 let s4 = d * (scales[sc_off + 6 + is] as i8) as f32;
258
259 acc += s1 * q1 as f32 * vec_f32[vec_offset + y_off + l];
260 acc += s2 * q2 as f32 * vec_f32[vec_offset + y_off + 32 + l];
261 acc += s3 * q3 as f32 * vec_f32[vec_offset + y_off + 64 + l];
262 acc += s4 * q4 as f32 * vec_f32[vec_offset + y_off + 96 + l];
263 }
264 ql_off += 64;
265 qh_off += 32;
266 sc_off += 8;
267 y_off += 128;
268 }
269 }
270 acc
271}
272
273pub fn dot_q4_k_f32(qdata: &[u8], vec_f32: &[f32], num_elements: usize) -> f32 {
275 let type_size = 2 * 2 + K_SCALE_SIZE + QK_K / 2;
276 let num_blocks = num_elements / QK_K;
277 let mut acc = 0.0f32;
278
279 for block_idx in 0..num_blocks {
280 let block_start = block_idx * type_size;
281 let block = &qdata[block_start..block_start + type_size];
282
283 let d = decode_f16([block[0], block[1]]);
284 let dmin = decode_f16([block[2], block[3]]);
285 let scales = &block[4..4 + K_SCALE_SIZE];
286 let quants = &block[4 + K_SCALE_SIZE..];
287
288 let vec_offset = block_idx * QK_K;
289 let mut is = 0usize;
290 let mut q_off = 0usize;
291 let mut y_off = 0usize;
292
293 for _ in (0..QK_K).step_by(64) {
294 let (sc1, m1) = get_scale_min_k4(is, scales);
295 let d1 = d * f32::from(sc1);
296 let min1 = dmin * f32::from(m1);
297 let (sc2, m2) = get_scale_min_k4(is + 1, scales);
298 let d2 = d * f32::from(sc2);
299 let min2 = dmin * f32::from(m2);
300
301 for l in 0..32 {
302 let val = d1 * f32::from(quants[q_off + l] & 0x0F) - min1;
303 acc += val * vec_f32[vec_offset + y_off + l];
304 }
305 for l in 0..32 {
306 let val = d2 * f32::from(quants[q_off + l] >> 4) - min2;
307 acc += val * vec_f32[vec_offset + y_off + 32 + l];
308 }
309
310 q_off += 32;
311 is += 2;
312 y_off += 64;
313 }
314 }
315 acc
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// A Q8_0 block with scale 0.5 decodes each stored quant to 0.5 * quant.
    #[test]
    fn q8_0_block_dequantizes_expected_values() {
        let mut block = vec![0u8; 2 + QK8_0];
        block[0..2].copy_from_slice(&f16::from_f32(0.5).to_bits().to_le_bytes());
        for (index, value) in [2i8, -4, 6, -8].into_iter().enumerate() {
            block[2 + index] = value as u8;
        }

        let out = dequantize_q8_0_block(&block).unwrap();

        assert_eq!(&out[..4], &[1.0, -2.0, 3.0, -4.0]);
    }

    /// A wrongly-sized Q4_K block must be rejected with the Spanish message.
    #[test]
    fn q4_k_block_length_is_enforced() {
        let message = dequantize_q4_k_block(&[0u8; 10]).unwrap_err().to_string();
        assert!(message.contains("Bloque Q4_K invalido"));
    }

    /// The fused dot product must agree with dequantize-then-dot.
    #[test]
    fn dot_q8_0_matches_dequantize_then_dot() {
        let mut block = vec![0u8; 2 + QK8_0];
        block[0..2].copy_from_slice(&f16::from_f32(0.25).to_bits().to_le_bytes());
        for i in 0..32u8 {
            block[2 + i as usize] = (i as i8 - 16) as u8;
        }
        let vec_f32: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();

        let dequant = dequantize_q8_0_block(&block).unwrap();
        let expected: f32 = dequant.iter().zip(vec_f32.iter()).map(|(a, b)| a * b).sum();

        let actual = dot_q8_0_f32(&block, &vec_f32, 32);

        let diff = (expected - actual).abs();
        assert!(diff < 1e-3, "dot_q8_0 diverge: expected={expected}, actual={actual}, diff={diff}");
    }
}