//! K-Quant quantization functions (`Q4_K`, `Q5_K`, `Q6_K`)
//!
//! Shared helpers and all quantize functions extracted from lib.rs.
use crate::{f32_to_f16, F16_MIN_NORMAL};

// ============================================================================
// Shared K-Quant Helpers (extracted for cognitive complexity reduction)
// ============================================================================

11/// Compute per-sub-block scale and min values from padded data.
12///
13/// Returns (`sub_scales`, `sub_mins`) for 8 sub-blocks of 32 elements each.
14/// `quant_max` is the maximum quantized value (15 for `Q4_K`, 31 for `Q5_K`).
15pub(crate) fn compute_sub_block_stats(padded: &[f32; 256], quant_max: f32) -> ([f32; 8], [f32; 8]) {
16    const SUB_BLOCK_SIZE: usize = 32;
17    let mut sub_scales = [0.0f32; 8];
18    let mut sub_mins = [0.0f32; 8];
19
20    for (j, sub_block) in padded.chunks(SUB_BLOCK_SIZE).enumerate().take(8) {
21        let min = sub_block.iter().fold(f32::INFINITY, |a, &b| a.min(b));
22        let max = sub_block.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
23        let range = max - min;
24
25        sub_scales[j] = if range > F16_MIN_NORMAL {
26            range / quant_max
27        } else {
28            F16_MIN_NORMAL
29        };
30        sub_mins[j] = (-min).max(0.0);
31    }
32
33    (sub_scales, sub_mins)
34}
36/// Compute global d and dmin from sub-block statistics, plus quantized 6-bit scales/mins.
37pub(crate) fn compute_global_scales(
38    sub_scales: &[f32; 8],
39    sub_mins: &[f32; 8],
40) -> (f32, f32, [u8; 8], [u8; 8]) {
41    let max_scale = sub_scales.iter().fold(0.0f32, |a, &b| a.max(b));
42    let max_min = sub_mins.iter().fold(0.0f32, |a, &b| a.max(b));
43
44    let d = if max_scale > F16_MIN_NORMAL {
45        max_scale / 63.0
46    } else {
47        F16_MIN_NORMAL
48    };
49    let dmin = if max_min > F16_MIN_NORMAL {
50        max_min / 63.0
51    } else {
52        F16_MIN_NORMAL
53    };
54
55    let mut scales_6bit = [0u8; 8];
56    let mut mins_6bit = [0u8; 8];
57    for j in 0..8 {
58        scales_6bit[j] = ((sub_scales[j] / d).round() as u8).min(63);
59        mins_6bit[j] = ((sub_mins[j] / dmin).round() as u8).min(63);
60    }
61
62    (d, dmin, scales_6bit, mins_6bit)
63}
/// Write the K-quant header: d (f16) + dmin (f16) + packed 12-byte scales.
///
/// The 12-byte packing follows the GGUF `Q4_K`/`Q5_K` scale layout:
/// - bytes 0..4:  scale[i] low 6 bits, plus scale[i+4] bits 4-5 in bits 6-7
/// - bytes 4..8:  min[i]   low 6 bits, plus min[i+4]   bits 4-5 in bits 6-7
/// - bytes 8..12: scale[i+4] low 4 bits in the low nibble, min[i+4] low 4
///   bits in the high nibble
pub(crate) fn write_kquant_header(
    result: &mut Vec<u8>,
    d: f32,
    dmin: f32,
    scales_6bit: &[u8; 8],
    mins_6bit: &[u8; 8],
) {
    // Super-block steps, stored as little-endian f16.
    result.extend_from_slice(&f32_to_f16(d).to_le_bytes());
    result.extend_from_slice(&f32_to_f16(dmin).to_le_bytes());

    let mut scales_packed = [0u8; 12];
    for i in 0..4 {
        // Low 6 bits of entry i; the high 2 bits (bits 4-5) of entry i+4
        // are stashed in bits 6-7 of the same byte.
        scales_packed[i] = (scales_6bit[i] & 0x3F) | ((scales_6bit[i + 4] & 0x30) << 2);
        scales_packed[i + 4] = (mins_6bit[i] & 0x3F) | ((mins_6bit[i + 4] & 0x30) << 2);
    }
    for i in 0..4 {
        // Remaining low 4 bits of scale[i+4] and min[i+4] share one byte.
        scales_packed[i + 8] = (scales_6bit[i + 4] & 0x0F) | ((mins_6bit[i + 4] & 0x0F) << 4);
    }
    result.extend_from_slice(&scales_packed);
}
/// Quantize a single value: (value + `min_val`) / scale, clamped to [0, `max_q`].
///
/// A vanishing (or non-finite) scale yields the zero level instead of
/// dividing by ~0.
#[inline]
pub(crate) fn quantize_one(value: f32, min_val: f32, scale: f32, max_q: f32) -> u8 {
    if !(scale > 1e-10) {
        return 0;
    }
    let level = ((value + min_val) / scale).round();
    level.clamp(0.0, max_q) as u8
}

// ============================================================================
// Q4_K Quantization
// ============================================================================

101/// Quantize F32 data to `Q4_K` format (llama.cpp/candle compatible)
102///
103/// `Q4_K` format: 256 elements per super-block, 144 bytes per block
104/// Layout: d (2B) + dmin (2B) + scales (12B) + qs (128B)
105///
106/// Value packing (candle/llama.cpp layout):
107/// - For each 64-value chunk: 32 bytes store low nibbles first, then high nibbles
108/// - Low nibbles use scale[is], high nibbles use scale[is+1]
109#[must_use]
110pub fn quantize_q4_k(data: &[f32]) -> Vec<u8> {
111    const SUPER_BLOCK_SIZE: usize = 256;
112    const SUPER_BLOCK_BYTES: usize = 144;
113
114    if data.is_empty() {
115        return vec![];
116    }
117
118    let num_blocks = data.len().div_ceil(SUPER_BLOCK_SIZE);
119    let mut result = Vec::with_capacity(num_blocks * SUPER_BLOCK_BYTES);
120
121    for block_idx in 0..num_blocks {
122        let block_start = block_idx * SUPER_BLOCK_SIZE;
123        let block_end = (block_start + SUPER_BLOCK_SIZE).min(data.len());
124        let block_data = &data[block_start..block_end];
125
126        let mut padded = [0.0f32; SUPER_BLOCK_SIZE];
127        padded[..block_data.len()].copy_from_slice(block_data);
128
129        let (sub_scales, sub_mins) = compute_sub_block_stats(&padded, 15.0);
130        let (d, dmin, scales_6bit, mins_6bit) = compute_global_scales(&sub_scales, &sub_mins);
131        write_kquant_header(&mut result, d, dmin, &scales_6bit, &mins_6bit);
132
133        // Quantize values into 4-bit packed nibbles
134        let mut qs = [0u8; 128];
135        for chunk in 0..4 {
136            let chunk_start = chunk * 64;
137            let is = chunk * 2;
138            let scale_lo = d * f32::from(scales_6bit[is]);
139            let min_lo = dmin * f32::from(mins_6bit[is]);
140            let scale_hi = d * f32::from(scales_6bit[is + 1]);
141            let min_hi = dmin * f32::from(mins_6bit[is + 1]);
142
143            for l in 0..32 {
144                let q_lo = quantize_one(padded[chunk_start + l], min_lo, scale_lo, 15.0);
145                let q_hi = quantize_one(padded[chunk_start + l + 32], min_hi, scale_hi, 15.0);
146                qs[chunk * 32 + l] = (q_lo & 0x0F) | ((q_hi & 0x0F) << 4);
147            }
148        }
149        result.extend_from_slice(&qs);
150    }
151
152    result
153}
155/// Quantize F32 matrix to `Q4_K` format with proper row layout
156///
157/// Processes each row independently to maintain row-major layout.
158#[must_use]
159pub fn quantize_q4_k_matrix(data: &[f32], shape: &[usize]) -> Vec<u8> {
160    const SUPER_BLOCK_SIZE: usize = 256;
161    const SUPER_BLOCK_BYTES: usize = 144;
162
163    if shape.len() != 2 {
164        return quantize_q4_k(data);
165    }
166
167    let rows = shape[0];
168    let cols = shape[1];
169
170    let super_blocks_per_row = cols.div_ceil(SUPER_BLOCK_SIZE);
171    let padded_cols = super_blocks_per_row * SUPER_BLOCK_SIZE;
172
173    let mut result = Vec::with_capacity(rows * super_blocks_per_row * SUPER_BLOCK_BYTES);
174
175    for row_idx in 0..rows {
176        let mut padded_row = vec![0.0f32; padded_cols];
177        let row_start = row_idx * cols;
178        let row_end = row_start + cols;
179        if row_end <= data.len() {
180            padded_row[..cols].copy_from_slice(&data[row_start..row_end]);
181        }
182
183        let row_q4k = quantize_q4_k(&padded_row);
184        result.extend_from_slice(&row_q4k);
185    }
186
187    result
188}

// ============================================================================
// Q5_K Quantization
// ============================================================================

194/// Quantize F32 data to `Q5_K` format
195///
196/// `Q5_K`: 256 elements per super-block, 176 bytes per block
197/// Layout: d (2B) + dmin (2B) + scales (12B) + qh (32B) + qs (128B)
198#[must_use]
199pub fn quantize_q5_k(data: &[f32]) -> Vec<u8> {
200    const SUPER_BLOCK_SIZE: usize = 256;
201    const SUPER_BLOCK_BYTES: usize = 176;
202
203    if data.is_empty() {
204        return vec![];
205    }
206
207    let num_blocks = data.len().div_ceil(SUPER_BLOCK_SIZE);
208    let mut result = Vec::with_capacity(num_blocks * SUPER_BLOCK_BYTES);
209
210    for block_idx in 0..num_blocks {
211        let block_start = block_idx * SUPER_BLOCK_SIZE;
212        let block_end = (block_start + SUPER_BLOCK_SIZE).min(data.len());
213        let block_data = &data[block_start..block_end];
214
215        let mut padded = [0.0f32; SUPER_BLOCK_SIZE];
216        padded[..block_data.len()].copy_from_slice(block_data);
217
218        let (sub_scales, sub_mins) = compute_sub_block_stats(&padded, 31.0);
219        let (d, dmin, scales_6bit, mins_6bit) = compute_global_scales(&sub_scales, &sub_mins);
220        write_kquant_header(&mut result, d, dmin, &scales_6bit, &mins_6bit);
221
222        // Quantize all 256 values to 5-bit
223        let mut q5_vals = [0u8; 256];
224        for j in 0..8 {
225            let scale = d * f32::from(scales_6bit[j]);
226            let min_val = dmin * f32::from(mins_6bit[j]);
227            for k in 0..32 {
228                q5_vals[j * 32 + k] = quantize_one(padded[j * 32 + k], min_val, scale, 31.0);
229            }
230        }
231
232        // Pack high bits (qh)
233        result.extend_from_slice(&pack_q5k_high_bits(&q5_vals));
234
235        // Pack low 4 bits (qs)
236        result.extend_from_slice(&pack_q5k_low_nibbles(&q5_vals));
237    }
238
239    result
240}
/// Pack `Q5_K` high bits: bit 4 of each of the 256 codes, 8 per output byte.
///
/// Output byte `i` collects bit 4 of values `i`, `32 + i`, ..., `224 + i`,
/// with sub-block `j` contributing bit position `j`.
fn pack_q5k_high_bits(q5_vals: &[u8; 256]) -> [u8; 32] {
    let mut qh = [0u8; 32];
    for (i, byte) in qh.iter_mut().enumerate() {
        *byte = (0..8).fold(0u8, |acc, j| acc | (((q5_vals[j * 32 + i] >> 4) & 1) << j));
    }
    qh
}
/// Pack `Q5_K` low nibbles: within each 32-value sub-block, value `k` and
/// value `k + 16` share one output byte (low nibble / high nibble).
fn pack_q5k_low_nibbles(q5_vals: &[u8; 256]) -> [u8; 128] {
    let mut qs = [0u8; 128];
    for (sub, packed) in q5_vals.chunks(32).zip(qs.chunks_mut(16)) {
        for (k, byte) in packed.iter_mut().enumerate() {
            *byte = (sub[k] & 0x0F) | ((sub[k + 16] & 0x0F) << 4);
        }
    }
    qs
}
268/// Quantize F32 matrix to `Q5_K` format with proper row layout
269#[must_use]
270pub fn quantize_q5_k_matrix(data: &[f32], shape: &[usize]) -> Vec<u8> {
271    const SUPER_BLOCK_SIZE: usize = 256;
272    const SUPER_BLOCK_BYTES: usize = 176;
273
274    if shape.len() != 2 {
275        return quantize_q5_k(data);
276    }
277
278    let rows = shape[0];
279    let cols = shape[1];
280    let super_blocks_per_row = cols.div_ceil(SUPER_BLOCK_SIZE);
281    let padded_cols = super_blocks_per_row * SUPER_BLOCK_SIZE;
282
283    let mut result = Vec::with_capacity(rows * super_blocks_per_row * SUPER_BLOCK_BYTES);
284
285    for row_idx in 0..rows {
286        let mut padded_row = vec![0.0f32; padded_cols];
287        let row_start = row_idx * cols;
288        let row_end = row_start + cols;
289        if row_end <= data.len() {
290            padded_row[..cols].copy_from_slice(&data[row_start..row_end]);
291        }
292
293        let row_q5k = quantize_q5_k(&padded_row);
294        result.extend_from_slice(&row_q5k);
295    }
296
297    result
298}

// ============================================================================
// Q6_K Quantization
// ============================================================================

304/// Quantize F32 data to `Q6_K` format (candle/GGUF compatible)
305///
306/// `Q6_K` format: 256-element super-blocks
307/// Each super block: ql (128 bytes) + qh (64 bytes) + scales (16 bytes) + d (f16) = 210 bytes
308/// - 6-bit values stored split: low 4 bits in ql, high 2 bits in qh
309/// - 16 sub-blocks of 16 elements each, with int8 scale per sub-block
310#[must_use]
311pub fn quantize_q6_k(data: &[f32]) -> Vec<u8> {
312    const SUPER_BLOCK_SIZE: usize = 256;
313    const SUPER_BLOCK_BYTES: usize = 210;
314
315    if data.is_empty() {
316        return vec![];
317    }
318
319    let num_blocks = data.len().div_ceil(SUPER_BLOCK_SIZE);
320    let mut result = Vec::with_capacity(num_blocks * SUPER_BLOCK_BYTES);
321
322    for block_idx in 0..num_blocks {
323        let block_start = block_idx * SUPER_BLOCK_SIZE;
324        let block_end = (block_start + SUPER_BLOCK_SIZE).min(data.len());
325        let block_data = &data[block_start..block_end];
326
327        let mut padded = [0.0f32; SUPER_BLOCK_SIZE];
328        padded[..block_data.len()].copy_from_slice(block_data);
329
330        let (d, scales_i8) = compute_q6k_scales(&padded);
331        let q6_vals = quantize_q6k_values(&padded, d, &scales_i8);
332        let (ql, qh) = pack_q6k_bits(&q6_vals);
333
334        // Write in candle order: ql, qh, scales, d
335        result.extend_from_slice(&ql);
336        result.extend_from_slice(&qh);
337        for s in &scales_i8 {
338            result.push(*s as u8);
339        }
340        result.extend_from_slice(&f32_to_f16(d).to_le_bytes());
341    }
342
343    result
344}
346/// Compute `Q6_K` global scale and per-sub-block int8 scales.
347fn compute_q6k_scales(padded: &[f32; 256]) -> (f32, [i8; 16]) {
348    let mut sub_scales = [0.0f32; 16];
349    for (j, sub_block) in padded.chunks(16).enumerate().take(16) {
350        let max_abs = sub_block.iter().fold(0.0f32, |a, &b| a.max(b.abs()));
351        sub_scales[j] = if max_abs > F16_MIN_NORMAL {
352            max_abs / 31.0
353        } else {
354            F16_MIN_NORMAL
355        };
356    }
357
358    let max_scale = sub_scales.iter().fold(0.0f32, |a, &b| a.max(b));
359    let d = if max_scale > F16_MIN_NORMAL {
360        max_scale / 127.0
361    } else {
362        F16_MIN_NORMAL
363    };
364
365    let mut scales_i8 = [0i8; 16];
366    for j in 0..16 {
367        scales_i8[j] = (sub_scales[j] / d).round().clamp(-127.0, 127.0) as i8;
368    }
369
370    (d, scales_i8)
371}
/// Map 256 padded floats to unsigned 6-bit `Q6_K` codes (0..=63, bias +32).
///
/// Each 16-value sub-block is scaled by `d * scales_i8[j]`; results are
/// rounded, clamped to the signed range [-32, 31], then offset by +32.
fn quantize_q6k_values(padded: &[f32; 256], d: f32, scales_i8: &[i8; 16]) -> [u8; 256] {
    let mut q6_vals = [0u8; 256];
    for (j, (sub_in, sub_out)) in padded.chunks(16).zip(q6_vals.chunks_mut(16)).enumerate() {
        let scale = d * f32::from(scales_i8[j]);
        // Guard against a zero/denormal sub-block scale.
        let inv = if scale.abs() > 1e-10 { scale.recip() } else { 0.0 };
        for (&v, q) in sub_in.iter().zip(sub_out.iter_mut()) {
            let signed = (v * inv).round().clamp(-32.0, 31.0) as i8;
            *q = (signed + 32) as u8;
        }
    }
    q6_vals
}
/// Pack 256 `Q6_K` codes into ql (128 B) and qh (64 B), candle/GGUF layout.
///
/// The block is processed as two 128-value halves. Within a half, values
/// `l`, `l+32`, `l+64`, `l+96` share packing slots: their low nibbles fill
/// two ql bytes and their high 2 bits are packed four-per-byte into qh.
fn pack_q6k_bits(q6_vals: &[u8; 256]) -> ([u8; 128], [u8; 64]) {
    let mut ql = [0u8; 128];
    let mut qh = [0u8; 64];

    for half in 0..2 {
        let n = half * 128;
        let ql_base = half * 64;
        let qh_base = half * 32;

        for l in 0..32 {
            // Four values spaced 32 apart within this half.
            let (q1, q2, q3, q4) = (
                q6_vals[n + l],
                q6_vals[n + l + 32],
                q6_vals[n + l + 64],
                q6_vals[n + l + 96],
            );

            // Low nibbles: q1/q3 share one byte, q2/q4 another.
            ql[ql_base + l] = (q1 & 0x0F) | ((q3 & 0x0F) << 4);
            ql[ql_base + l + 32] = (q2 & 0x0F) | ((q4 & 0x0F) << 4);

            // High 2 bits of each value, packed into one byte.
            qh[qh_base + l] = ((q1 >> 4) & 0x03)
                | (((q2 >> 4) & 0x03) << 2)
                | (((q3 >> 4) & 0x03) << 4)
                | (((q4 >> 4) & 0x03) << 6);
        }
    }

    (ql, qh)
}
421/// Quantize F32 matrix to `Q6_K` format with proper row layout
422#[must_use]
423pub fn quantize_q6_k_matrix(data: &[f32], shape: &[usize]) -> Vec<u8> {
424    const SUPER_BLOCK_SIZE: usize = 256;
425    const SUPER_BLOCK_BYTES: usize = 210;
426
427    if shape.len() != 2 {
428        return quantize_q6_k(data);
429    }
430
431    let rows = shape[0];
432    let cols = shape[1];
433    let super_blocks_per_row = cols.div_ceil(SUPER_BLOCK_SIZE);
434    let padded_cols = super_blocks_per_row * SUPER_BLOCK_SIZE;
435
436    let mut result = Vec::with_capacity(rows * super_blocks_per_row * SUPER_BLOCK_BYTES);
437
438    for row_idx in 0..rows {
439        let mut padded_row = vec![0.0f32; padded_cols];
440        let row_start = row_idx * cols;
441        let row_end = row_start + cols;
442        if row_end <= data.len() {
443            padded_row[..cols].copy_from_slice(&data[row_start..row_end]);
444        }
445
446        let row_q6k = quantize_q6_k(&padded_row);
447        result.extend_from_slice(&row_q6k);
448    }
449
450    result
451}