//! rage_quant/ggml_quant.rs — GGML block-quantization helpers (Q8_0, Q4_K, Q6_K).
1use anyhow::Result;
2use half::f16;
3
// Elements per Q8_0 block (ggml `QK8_0`).
const QK8_0: usize = 32;
// Elements per K-quant super-block (ggml `QK_K`).
const QK_K: usize = 256;
// Bytes of packed 6-bit sub-scales/mins in a Q4_K block (ggml `K_SCALE_SIZE`).
const K_SCALE_SIZE: usize = 12;
7
8pub fn decode_f16(bytes: [u8; 2]) -> f32 {
9    f16::from_bits(u16::from_le_bytes(bytes)).to_f32()
10}
11
12pub fn dequantize_q8_0_block(block: &[u8]) -> Result<Vec<f32>> {
13    if block.len() != 2 + QK8_0 {
14        anyhow::bail!("Bloque Q8_0 invalido: {} bytes", block.len());
15    }
16    let d = decode_f16([block[0], block[1]]);
17    let mut out = Vec::with_capacity(QK8_0);
18    for quant in &block[2..] {
19        out.push(d * (*quant as i8) as f32);
20    }
21    Ok(out)
22}
23
/// Unpacks the `j`-th 6-bit (scale, min) pair (j in 0..8) from the 12-byte
/// packed scales layout shared by K-quant blocks: the first four pairs live
/// in the low 6 bits of bytes 0..8; the last four combine the low/high
/// nibbles of bytes 8..12 with the top 2 bits of the earlier bytes.
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    if j < 4 {
        return (scales[j] & 63, scales[j + 4] & 63);
    }
    let packed = scales[j + 4];
    let sc = (packed & 0x0F) | ((scales[j - 4] >> 6) << 4);
    let mn = (packed >> 4) | ((scales[j] >> 6) << 4);
    (sc, mn)
}
33
34pub fn dequantize_q4_k_block(block: &[u8]) -> Result<Vec<f32>> {
35    let expected = 2 * 2 + K_SCALE_SIZE + QK_K / 2;
36    if block.len() != expected {
37        anyhow::bail!("Bloque Q4_K invalido: {} bytes", block.len());
38    }
39
40    let d = decode_f16([block[0], block[1]]);
41    let dmin = decode_f16([block[2], block[3]]);
42    let scales = &block[4..4 + K_SCALE_SIZE];
43    let quants = &block[4 + K_SCALE_SIZE..];
44
45    let mut out = Vec::with_capacity(QK_K);
46    let mut is = 0usize;
47    let mut q_offset = 0usize;
48
49    for _ in (0..QK_K).step_by(64) {
50        let (sc1, m1) = get_scale_min_k4(is, scales);
51        let d1 = d * f32::from(sc1);
52        let min1 = dmin * f32::from(m1);
53        let (sc2, m2) = get_scale_min_k4(is + 1, scales);
54        let d2 = d * f32::from(sc2);
55        let min2 = dmin * f32::from(m2);
56
57        for l in 0..32 {
58            out.push(d1 * f32::from(quants[q_offset + l] & 0x0F) - min1);
59        }
60        for l in 0..32 {
61            out.push(d2 * f32::from(quants[q_offset + l] >> 4) - min2);
62        }
63
64        q_offset += 32;
65        is += 2;
66    }
67
68    Ok(out)
69}
70
/// Dequantizes one Q6_K super-block into 256 `f32` values.
///
/// Layout: 128 bytes of low 4-bit quants (`ql`), 64 bytes carrying the high
/// 2 bits of each quant (`qh`), 16 signed 8-bit sub-scales, then a trailing
/// f16 super-scale `d`. Returns an error if `block` is not exactly 210 bytes.
pub fn dequantize_q6_k_block(block: &[u8]) -> Result<Vec<f32>> {
    let expected = 2 + QK_K / 16 + (3 * QK_K) / 4;
    if block.len() != expected {
        anyhow::bail!("Bloque Q6_K invalido: {} bytes", block.len());
    }

    let ql_len = QK_K / 2;
    let qh_len = QK_K / 4;
    let scales_len = QK_K / 16;

    let ql = &block[..ql_len];
    let qh = &block[ql_len..ql_len + qh_len];
    let scales = &block[ql_len + qh_len..ql_len + qh_len + scales_len];
    // The f16 super-scale sits at the very end, after ql/qh/scales.
    let d = decode_f16([
        block[ql_len + qh_len + scales_len],
        block[ql_len + qh_len + scales_len + 1],
    ]);

    let mut out = vec![0.0f32; QK_K];
    let mut ql_offset = 0usize;
    let mut qh_offset = 0usize;
    let mut scales_offset = 0usize;
    let mut y_offset = 0usize;

    // Process the super-block in two halves of 128 values each.
    for _ in (0..QK_K).step_by(128) {
        for l in 0..32 {
            // Each run of 16 consecutive outputs shares one sub-scale.
            let is = l / 16;
            let qh_byte = qh[qh_offset + l];
            // Reassemble four 6-bit quants from a nibble of ql plus 2 bits of
            // qh each, then re-center from [0, 63] to [-32, 31].
            let q1 = (((ql[ql_offset + l] & 0x0F) | (((qh_byte >> 0) & 3) << 4)) as i8) - 32;
            let q2 = (((ql[ql_offset + 32 + l] & 0x0F) | (((qh_byte >> 2) & 3) << 4)) as i8) - 32;
            let q3 = (((ql[ql_offset + l] >> 4) | (((qh_byte >> 4) & 3) << 4)) as i8) - 32;
            let q4 = (((ql[ql_offset + 32 + l] >> 4) | (((qh_byte >> 6) & 3) << 4)) as i8) - 32;

            // The four quants land 32 outputs apart within the 128-value half.
            out[y_offset + l] = d * f32::from(scales[scales_offset + is] as i8) * f32::from(q1);
            out[y_offset + 32 + l] = d * f32::from(scales[scales_offset + 2 + is] as i8) * f32::from(q2);
            out[y_offset + 64 + l] = d * f32::from(scales[scales_offset + 4 + is] as i8) * f32::from(q3);
            out[y_offset + 96 + l] = d * f32::from(scales[scales_offset + 6 + is] as i8) * f32::from(q4);
        }
        ql_offset += 64;
        qh_offset += 32;
        scales_offset += 8;
        y_offset += 128;
    }

    Ok(out)
}
117
118/// Dot product directo sobre bloques Q8_0 contra un vector f32.
119/// Evita dequantizar todo el tensor: procesa bloque a bloque acumulando.
120pub fn dot_q8_0_f32(qdata: &[u8], vec_f32: &[f32], num_elements: usize) -> f32 {
121    let block_size = QK8_0;
122    let type_size = 2 + QK8_0;
123    let num_blocks = num_elements / block_size;
124
125    #[cfg(target_arch = "x86_64")]
126    {
127        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
128            // SAFETY: verificamos AVX2+FMA; los slices cuadran por la precondicion de bloques Q8_0
129            return unsafe { dot_q8_0_f32_avx2(qdata, vec_f32, num_blocks, type_size, block_size) };
130        }
131    }
132
133    dot_q8_0_f32_scalar(qdata, vec_f32, num_blocks, type_size, block_size)
134}
135
136fn dot_q8_0_f32_scalar(qdata: &[u8], vec_f32: &[f32], num_blocks: usize, type_size: usize, block_size: usize) -> f32 {
137    let mut acc = 0.0f32;
138    for block_idx in 0..num_blocks {
139        let block_start = block_idx * type_size;
140        let block = &qdata[block_start..block_start + type_size];
141        let d = decode_f16([block[0], block[1]]);
142        let quants = &block[2..];
143        let vec_offset = block_idx * block_size;
144
145        let mut block_acc = 0.0f32;
146        for i in 0..block_size {
147            block_acc += (quants[i] as i8) as f32 * vec_f32[vec_offset + i];
148        }
149        acc += d * block_acc;
150    }
151    acc
152}
153
/// AVX2+FMA backend for the Q8_0 dot product.
///
/// Widens each group of 8 signed quants to f32 lanes, scales them by the
/// block's f16 delta and fused-multiply-adds against the dense vector.
///
/// # Safety
/// The caller must guarantee that the CPU supports AVX2 and FMA, that
/// `qdata` holds at least `num_blocks * type_size` bytes, and that `vec_f32`
/// holds at least `num_blocks * block_size` elements — the loads below go
/// through raw pointers with no bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
unsafe fn dot_q8_0_f32_avx2(
    qdata: &[u8],
    vec_f32: &[f32],
    num_blocks: usize,
    type_size: usize,
    block_size: usize,
) -> f32 {
    use std::arch::x86_64::*;

    let mut acc = _mm256_setzero_ps();

    for block_idx in 0..num_blocks {
        let block_start = block_idx * type_size;
        let block = &qdata[block_start..block_start + type_size];
        let d = decode_f16([block[0], block[1]]);
        let quants = block.as_ptr().add(2);
        let vec_offset = block_idx * block_size;
        let vec_ptr = vec_f32.as_ptr().add(vec_offset);

        let d_vec = _mm256_set1_ps(d);

        // Process the 32 i8 quants in 4 groups of 8.
        // Group 0: quants[0..8]
        let q_bytes = _mm_loadl_epi64(quants as *const __m128i);
        let q_i32 = _mm256_cvtepi8_epi32(q_bytes);
        let q_f32 = _mm256_cvtepi32_ps(q_i32);
        let x_f32 = _mm256_loadu_ps(vec_ptr);
        acc = _mm256_fmadd_ps(_mm256_mul_ps(d_vec, q_f32), x_f32, acc);

        // Group 1: quants[8..16]
        let q_bytes = _mm_loadl_epi64(quants.add(8) as *const __m128i);
        let q_i32 = _mm256_cvtepi8_epi32(q_bytes);
        let q_f32 = _mm256_cvtepi32_ps(q_i32);
        let x_f32 = _mm256_loadu_ps(vec_ptr.add(8));
        acc = _mm256_fmadd_ps(_mm256_mul_ps(d_vec, q_f32), x_f32, acc);

        // Group 2: quants[16..24]
        let q_bytes = _mm_loadl_epi64(quants.add(16) as *const __m128i);
        let q_i32 = _mm256_cvtepi8_epi32(q_bytes);
        let q_f32 = _mm256_cvtepi32_ps(q_i32);
        let x_f32 = _mm256_loadu_ps(vec_ptr.add(16));
        acc = _mm256_fmadd_ps(_mm256_mul_ps(d_vec, q_f32), x_f32, acc);

        // Group 3: quants[24..32]
        let q_bytes = _mm_loadl_epi64(quants.add(24) as *const __m128i);
        let q_i32 = _mm256_cvtepi8_epi32(q_bytes);
        let q_f32 = _mm256_cvtepi32_ps(q_i32);
        let x_f32 = _mm256_loadu_ps(vec_ptr.add(24));
        acc = _mm256_fmadd_ps(_mm256_mul_ps(d_vec, q_f32), x_f32, acc);
    }

    // Horizontal reduction of the 8-lane AVX2 accumulator to a single f32.
    let hi = _mm256_extractf128_ps(acc, 1);
    let lo = _mm256_castps256_ps128(acc);
    let sum128 = _mm_add_ps(lo, hi);
    let shuf = _mm_movehdup_ps(sum128);
    let sums = _mm_add_ps(sum128, shuf);
    let shuf2 = _mm_movehl_ps(sums, sums);
    let result = _mm_add_ss(sums, shuf2);
    _mm_cvtss_f32(result)
}
217
/// Direct dot product of Q6_K-quantized data against an `f32` vector.
///
/// Avoids dequantizing the whole tensor: each super-block's values are
/// reconstructed on the fly (same bit layout as `dequantize_q6_k_block`) and
/// folded into the accumulator. Only complete 256-element blocks are
/// processed; a trailing partial block is ignored.
pub fn dot_q6_k_f32(qdata: &[u8], vec_f32: &[f32], num_elements: usize) -> f32 {
    let ql_per_block = QK_K / 2;
    let qh_per_block = QK_K / 4;
    let scales_per_block = QK_K / 16;
    let type_size = ql_per_block + qh_per_block + scales_per_block + 2;
    let num_blocks = num_elements / QK_K;
    let mut acc = 0.0f32;

    for block_idx in 0..num_blocks {
        let block_start = block_idx * type_size;
        let block = &qdata[block_start..block_start + type_size];

        // Super-block layout: ql, qh, sub-scales, trailing f16 super-scale.
        let ql = &block[..ql_per_block];
        let qh = &block[ql_per_block..ql_per_block + qh_per_block];
        let scales = &block[ql_per_block + qh_per_block..ql_per_block + qh_per_block + scales_per_block];
        let d = decode_f16([
            block[ql_per_block + qh_per_block + scales_per_block],
            block[ql_per_block + qh_per_block + scales_per_block + 1],
        ]);

        let vec_offset = block_idx * QK_K;
        let mut ql_off = 0usize;
        let mut qh_off = 0usize;
        let mut sc_off = 0usize;
        let mut y_off = 0usize;

        // Two halves of 128 values per super-block.
        for _ in (0..QK_K).step_by(128) {
            for l in 0..32 {
                // One sub-scale per run of 16 outputs.
                let is = l / 16;
                let qh_byte = qh[qh_off + l];
                // Rebuild four 6-bit quants, re-centered to [-32, 31].
                let q1 = (((ql[ql_off + l] & 0x0F) | (((qh_byte >> 0) & 3) << 4)) as i8) - 32;
                let q2 = (((ql[ql_off + 32 + l] & 0x0F) | (((qh_byte >> 2) & 3) << 4)) as i8) - 32;
                let q3 = (((ql[ql_off + l] >> 4) | (((qh_byte >> 4) & 3) << 4)) as i8) - 32;
                let q4 = (((ql[ql_off + 32 + l] >> 4) | (((qh_byte >> 6) & 3) << 4)) as i8) - 32;

                let s1 = d * (scales[sc_off + is] as i8) as f32;
                let s2 = d * (scales[sc_off + 2 + is] as i8) as f32;
                let s3 = d * (scales[sc_off + 4 + is] as i8) as f32;
                let s4 = d * (scales[sc_off + 6 + is] as i8) as f32;

                // The four quants map to vector positions 32 apart.
                acc += s1 * q1 as f32 * vec_f32[vec_offset + y_off + l];
                acc += s2 * q2 as f32 * vec_f32[vec_offset + y_off + 32 + l];
                acc += s3 * q3 as f32 * vec_f32[vec_offset + y_off + 64 + l];
                acc += s4 * q4 as f32 * vec_f32[vec_offset + y_off + 96 + l];
            }
            ql_off += 64;
            qh_off += 32;
            sc_off += 8;
            y_off += 128;
        }
    }
    acc
}
272
/// Direct dot product of Q4_K-quantized data against an `f32` vector.
///
/// Mirrors `dequantize_q4_k_block` but folds each reconstructed value into
/// the accumulator instead of materializing the block. Only complete
/// 256-element blocks are processed; a trailing partial block is ignored.
pub fn dot_q4_k_f32(qdata: &[u8], vec_f32: &[f32], num_elements: usize) -> f32 {
    let type_size = 2 * 2 + K_SCALE_SIZE + QK_K / 2;
    let num_blocks = num_elements / QK_K;
    let mut acc = 0.0f32;

    for block_idx in 0..num_blocks {
        let block_start = block_idx * type_size;
        let block = &qdata[block_start..block_start + type_size];

        // Super-block layout: f16 d, f16 dmin, packed scales, 4-bit quants.
        let d = decode_f16([block[0], block[1]]);
        let dmin = decode_f16([block[2], block[3]]);
        let scales = &block[4..4 + K_SCALE_SIZE];
        let quants = &block[4 + K_SCALE_SIZE..];

        let vec_offset = block_idx * QK_K;
        let mut is = 0usize;
        let mut q_off = 0usize;
        let mut y_off = 0usize;

        // 64 outputs per iteration: 32 low nibbles then 32 high nibbles,
        // each half with its own (scale, min) pair.
        for _ in (0..QK_K).step_by(64) {
            let (sc1, m1) = get_scale_min_k4(is, scales);
            let d1 = d * f32::from(sc1);
            let min1 = dmin * f32::from(m1);
            let (sc2, m2) = get_scale_min_k4(is + 1, scales);
            let d2 = d * f32::from(sc2);
            let min2 = dmin * f32::from(m2);

            for l in 0..32 {
                let val = d1 * f32::from(quants[q_off + l] & 0x0F) - min1;
                acc += val * vec_f32[vec_offset + y_off + l];
            }
            for l in 0..32 {
                let val = d2 * f32::from(quants[q_off + l] >> 4) - min2;
                acc += val * vec_f32[vec_offset + y_off + 32 + l];
            }

            q_off += 32;
            is += 2;
            y_off += 64;
        }
    }
    acc
}
317
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn q8_0_block_dequantizes_expected_values() {
        // 2-byte f16 scale (0.5) followed by 32 signed quants.
        let mut block = vec![0u8; 34];
        block[0..2].copy_from_slice(&f16::from_f32(0.5).to_bits().to_le_bytes());
        for (index, value) in [2i8, -4, 6, -8].into_iter().enumerate() {
            block[2 + index] = value as u8;
        }
        let out = dequantize_q8_0_block(&block).unwrap();
        assert_eq!(out[0], 1.0);
        assert_eq!(out[1], -2.0);
        assert_eq!(out[2], 3.0);
        assert_eq!(out[3], -4.0);
    }

    #[test]
    fn q4_k_block_length_is_enforced() {
        let err = dequantize_q4_k_block(&[0u8; 10]).unwrap_err().to_string();
        assert!(err.contains("Bloque Q4_K invalido"));
    }

    #[test]
    fn dot_q8_0_matches_dequantize_then_dot() {
        // Build one Q8_0 block with known values.
        let mut block = vec![0u8; 34]; // 2 + 32
        block[0..2].copy_from_slice(&f16::from_f32(0.25).to_bits().to_le_bytes());
        for i in 0..32u8 {
            block[2 + i as usize] = (i as i8 - 16) as u8;
        }
        // Reference f32 vector.
        let vec_f32: Vec<f32> = (0..32).map(|i| i as f32 * 0.1).collect();

        // Method 1: dequantize, then dot.
        let dequant = dequantize_q8_0_block(&block).unwrap();
        let expected: f32 = dequant.iter().zip(vec_f32.iter()).map(|(a, b)| a * b).sum();

        // Method 2: direct quantized dot.
        let actual = dot_q8_0_f32(&block, &vec_f32, 32);

        let diff = (expected - actual).abs();
        assert!(diff < 1e-3, "dot_q8_0 diverge: expected={expected}, actual={actual}, diff={diff}");
    }
}
363}