1use crate::{f32_to_f16, F16_MIN_NORMAL};
6
7pub(crate) fn compute_sub_block_stats(padded: &[f32; 256], quant_max: f32) -> ([f32; 8], [f32; 8]) {
16 const SUB_BLOCK_SIZE: usize = 32;
17 let mut sub_scales = [0.0f32; 8];
18 let mut sub_mins = [0.0f32; 8];
19
20 for (j, sub_block) in padded.chunks(SUB_BLOCK_SIZE).enumerate().take(8) {
21 let min = sub_block.iter().fold(f32::INFINITY, |a, &b| a.min(b));
22 let max = sub_block.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
23 let range = max - min;
24
25 sub_scales[j] = if range > F16_MIN_NORMAL {
26 range / quant_max
27 } else {
28 F16_MIN_NORMAL
29 };
30 sub_mins[j] = (-min).max(0.0);
31 }
32
33 (sub_scales, sub_mins)
34}
35
36pub(crate) fn compute_global_scales(
38 sub_scales: &[f32; 8],
39 sub_mins: &[f32; 8],
40) -> (f32, f32, [u8; 8], [u8; 8]) {
41 let max_scale = sub_scales.iter().fold(0.0f32, |a, &b| a.max(b));
42 let max_min = sub_mins.iter().fold(0.0f32, |a, &b| a.max(b));
43
44 let d = if max_scale > F16_MIN_NORMAL {
45 max_scale / 63.0
46 } else {
47 F16_MIN_NORMAL
48 };
49 let dmin = if max_min > F16_MIN_NORMAL {
50 max_min / 63.0
51 } else {
52 F16_MIN_NORMAL
53 };
54
55 let mut scales_6bit = [0u8; 8];
56 let mut mins_6bit = [0u8; 8];
57 for j in 0..8 {
58 scales_6bit[j] = ((sub_scales[j] / d).round() as u8).min(63);
59 mins_6bit[j] = ((sub_mins[j] / dmin).round() as u8).min(63);
60 }
61
62 (d, dmin, scales_6bit, mins_6bit)
63}
64
65pub(crate) fn write_kquant_header(
67 result: &mut Vec<u8>,
68 d: f32,
69 dmin: f32,
70 scales_6bit: &[u8; 8],
71 mins_6bit: &[u8; 8],
72) {
73 result.extend_from_slice(&f32_to_f16(d).to_le_bytes());
74 result.extend_from_slice(&f32_to_f16(dmin).to_le_bytes());
75
76 let mut scales_packed = [0u8; 12];
77 for i in 0..4 {
78 scales_packed[i] = (scales_6bit[i] & 0x3F) | ((scales_6bit[i + 4] & 0x30) << 2);
79 scales_packed[i + 4] = (mins_6bit[i] & 0x3F) | ((mins_6bit[i + 4] & 0x30) << 2);
80 }
81 for i in 0..4 {
82 scales_packed[i + 8] = (scales_6bit[i + 4] & 0x0F) | ((mins_6bit[i + 4] & 0x0F) << 4);
83 }
84 result.extend_from_slice(&scales_packed);
85}
86
/// Quantizes a single value: shift by `min_val`, divide by `scale`, round,
/// and clamp the code into `[0, max_q]`.
///
/// A vanishing (or non-positive) scale marks an effectively constant
/// sub-block; every value then maps to code 0.
#[inline]
pub(crate) fn quantize_one(value: f32, min_val: f32, scale: f32, max_q: f32) -> u8 {
    if scale <= 1e-10 {
        return 0;
    }
    let code = ((value + min_val) / scale).round();
    code.clamp(0.0, max_q) as u8
}
96
97#[must_use]
110pub fn quantize_q4_k(data: &[f32]) -> Vec<u8> {
111 const SUPER_BLOCK_SIZE: usize = 256;
112 const SUPER_BLOCK_BYTES: usize = 144;
113
114 if data.is_empty() {
115 return vec![];
116 }
117
118 let num_blocks = data.len().div_ceil(SUPER_BLOCK_SIZE);
119 let mut result = Vec::with_capacity(num_blocks * SUPER_BLOCK_BYTES);
120
121 for block_idx in 0..num_blocks {
122 let block_start = block_idx * SUPER_BLOCK_SIZE;
123 let block_end = (block_start + SUPER_BLOCK_SIZE).min(data.len());
124 let block_data = &data[block_start..block_end];
125
126 let mut padded = [0.0f32; SUPER_BLOCK_SIZE];
127 padded[..block_data.len()].copy_from_slice(block_data);
128
129 let (sub_scales, sub_mins) = compute_sub_block_stats(&padded, 15.0);
130 let (d, dmin, scales_6bit, mins_6bit) = compute_global_scales(&sub_scales, &sub_mins);
131 write_kquant_header(&mut result, d, dmin, &scales_6bit, &mins_6bit);
132
133 let mut qs = [0u8; 128];
135 for chunk in 0..4 {
136 let chunk_start = chunk * 64;
137 let is = chunk * 2;
138 let scale_lo = d * f32::from(scales_6bit[is]);
139 let min_lo = dmin * f32::from(mins_6bit[is]);
140 let scale_hi = d * f32::from(scales_6bit[is + 1]);
141 let min_hi = dmin * f32::from(mins_6bit[is + 1]);
142
143 for l in 0..32 {
144 let q_lo = quantize_one(padded[chunk_start + l], min_lo, scale_lo, 15.0);
145 let q_hi = quantize_one(padded[chunk_start + l + 32], min_hi, scale_hi, 15.0);
146 qs[chunk * 32 + l] = (q_lo & 0x0F) | ((q_hi & 0x0F) << 4);
147 }
148 }
149 result.extend_from_slice(&qs);
150 }
151
152 result
153}
154
155#[must_use]
159pub fn quantize_q4_k_matrix(data: &[f32], shape: &[usize]) -> Vec<u8> {
160 const SUPER_BLOCK_SIZE: usize = 256;
161 const SUPER_BLOCK_BYTES: usize = 144;
162
163 if shape.len() != 2 {
164 return quantize_q4_k(data);
165 }
166
167 let rows = shape[0];
168 let cols = shape[1];
169
170 let super_blocks_per_row = cols.div_ceil(SUPER_BLOCK_SIZE);
171 let padded_cols = super_blocks_per_row * SUPER_BLOCK_SIZE;
172
173 let mut result = Vec::with_capacity(rows * super_blocks_per_row * SUPER_BLOCK_BYTES);
174
175 for row_idx in 0..rows {
176 let mut padded_row = vec![0.0f32; padded_cols];
177 let row_start = row_idx * cols;
178 let row_end = row_start + cols;
179 if row_end <= data.len() {
180 padded_row[..cols].copy_from_slice(&data[row_start..row_end]);
181 }
182
183 let row_q4k = quantize_q4_k(&padded_row);
184 result.extend_from_slice(&row_q4k);
185 }
186
187 result
188}
189
190#[must_use]
199pub fn quantize_q5_k(data: &[f32]) -> Vec<u8> {
200 const SUPER_BLOCK_SIZE: usize = 256;
201 const SUPER_BLOCK_BYTES: usize = 176;
202
203 if data.is_empty() {
204 return vec![];
205 }
206
207 let num_blocks = data.len().div_ceil(SUPER_BLOCK_SIZE);
208 let mut result = Vec::with_capacity(num_blocks * SUPER_BLOCK_BYTES);
209
210 for block_idx in 0..num_blocks {
211 let block_start = block_idx * SUPER_BLOCK_SIZE;
212 let block_end = (block_start + SUPER_BLOCK_SIZE).min(data.len());
213 let block_data = &data[block_start..block_end];
214
215 let mut padded = [0.0f32; SUPER_BLOCK_SIZE];
216 padded[..block_data.len()].copy_from_slice(block_data);
217
218 let (sub_scales, sub_mins) = compute_sub_block_stats(&padded, 31.0);
219 let (d, dmin, scales_6bit, mins_6bit) = compute_global_scales(&sub_scales, &sub_mins);
220 write_kquant_header(&mut result, d, dmin, &scales_6bit, &mins_6bit);
221
222 let mut q5_vals = [0u8; 256];
224 for j in 0..8 {
225 let scale = d * f32::from(scales_6bit[j]);
226 let min_val = dmin * f32::from(mins_6bit[j]);
227 for k in 0..32 {
228 q5_vals[j * 32 + k] = quantize_one(padded[j * 32 + k], min_val, scale, 31.0);
229 }
230 }
231
232 result.extend_from_slice(&pack_q5k_high_bits(&q5_vals));
234
235 result.extend_from_slice(&pack_q5k_low_nibbles(&q5_vals));
237 }
238
239 result
240}
241
/// Gathers the fifth (high) bit of all 256 5-bit codes into 32 bytes.
///
/// Bit `j` of output byte `i` is the high bit of the code at index
/// `j * 32 + i`, i.e. each output byte collects one bit from the same lane
/// of each of the eight sub-blocks.
fn pack_q5k_high_bits(q5_vals: &[u8; 256]) -> [u8; 32] {
    std::array::from_fn(|i| {
        (0..8).fold(0u8, |acc, j| {
            acc | (((q5_vals[j * 32 + i] >> 4) & 1) << j)
        })
    })
}
254
/// Packs the low 4 bits of the 256 5-bit codes into 128 bytes.
///
/// Follows the GGML Q5_K `qs` layout: within each 64-value chunk, output
/// byte `l` holds value `l` in the low nibble and value `l + 32` in the
/// high nibble. This matches the (l, l+32) pairing used by
/// `pack_q5k_high_bits` and by the Q4_K packer above.
///
/// The previous implementation paired values (k, k+16) within each 32-value
/// sub-block, which disagreed with the high-bit plane and produced blocks
/// that a GGML-style Q5_K dequantizer would reassemble incorrectly.
fn pack_q5k_low_nibbles(q5_vals: &[u8; 256]) -> [u8; 128] {
    let mut qs = [0u8; 128];
    for chunk in 0..4 {
        let base = chunk * 64;
        for l in 0..32 {
            qs[chunk * 32 + l] =
                (q5_vals[base + l] & 0x0F) | ((q5_vals[base + l + 32] & 0x0F) << 4);
        }
    }
    qs
}
267
268#[must_use]
270pub fn quantize_q5_k_matrix(data: &[f32], shape: &[usize]) -> Vec<u8> {
271 const SUPER_BLOCK_SIZE: usize = 256;
272 const SUPER_BLOCK_BYTES: usize = 176;
273
274 if shape.len() != 2 {
275 return quantize_q5_k(data);
276 }
277
278 let rows = shape[0];
279 let cols = shape[1];
280 let super_blocks_per_row = cols.div_ceil(SUPER_BLOCK_SIZE);
281 let padded_cols = super_blocks_per_row * SUPER_BLOCK_SIZE;
282
283 let mut result = Vec::with_capacity(rows * super_blocks_per_row * SUPER_BLOCK_BYTES);
284
285 for row_idx in 0..rows {
286 let mut padded_row = vec![0.0f32; padded_cols];
287 let row_start = row_idx * cols;
288 let row_end = row_start + cols;
289 if row_end <= data.len() {
290 padded_row[..cols].copy_from_slice(&data[row_start..row_end]);
291 }
292
293 let row_q5k = quantize_q5_k(&padded_row);
294 result.extend_from_slice(&row_q5k);
295 }
296
297 result
298}
299
300#[must_use]
311pub fn quantize_q6_k(data: &[f32]) -> Vec<u8> {
312 const SUPER_BLOCK_SIZE: usize = 256;
313 const SUPER_BLOCK_BYTES: usize = 210;
314
315 if data.is_empty() {
316 return vec![];
317 }
318
319 let num_blocks = data.len().div_ceil(SUPER_BLOCK_SIZE);
320 let mut result = Vec::with_capacity(num_blocks * SUPER_BLOCK_BYTES);
321
322 for block_idx in 0..num_blocks {
323 let block_start = block_idx * SUPER_BLOCK_SIZE;
324 let block_end = (block_start + SUPER_BLOCK_SIZE).min(data.len());
325 let block_data = &data[block_start..block_end];
326
327 let mut padded = [0.0f32; SUPER_BLOCK_SIZE];
328 padded[..block_data.len()].copy_from_slice(block_data);
329
330 let (d, scales_i8) = compute_q6k_scales(&padded);
331 let q6_vals = quantize_q6k_values(&padded, d, &scales_i8);
332 let (ql, qh) = pack_q6k_bits(&q6_vals);
333
334 result.extend_from_slice(&ql);
336 result.extend_from_slice(&qh);
337 for s in &scales_i8 {
338 result.push(*s as u8);
339 }
340 result.extend_from_slice(&f32_to_f16(d).to_le_bytes());
341 }
342
343 result
344}
345
346fn compute_q6k_scales(padded: &[f32; 256]) -> (f32, [i8; 16]) {
348 let mut sub_scales = [0.0f32; 16];
349 for (j, sub_block) in padded.chunks(16).enumerate().take(16) {
350 let max_abs = sub_block.iter().fold(0.0f32, |a, &b| a.max(b.abs()));
351 sub_scales[j] = if max_abs > F16_MIN_NORMAL {
352 max_abs / 31.0
353 } else {
354 F16_MIN_NORMAL
355 };
356 }
357
358 let max_scale = sub_scales.iter().fold(0.0f32, |a, &b| a.max(b));
359 let d = if max_scale > F16_MIN_NORMAL {
360 max_scale / 127.0
361 } else {
362 F16_MIN_NORMAL
363 };
364
365 let mut scales_i8 = [0i8; 16];
366 for j in 0..16 {
367 scales_i8[j] = (sub_scales[j] / d).round().clamp(-127.0, 127.0) as i8;
368 }
369
370 (d, scales_i8)
371}
372
/// Quantizes 256 values to 6-bit codes stored biased by +32 (range 0..=63).
///
/// Each 16-value sub-block uses the effective scale `d * scales_i8[j]`;
/// a (near-)zero effective scale maps the whole sub-block to the zero code.
fn quantize_q6k_values(padded: &[f32; 256], d: f32, scales_i8: &[i8; 16]) -> [u8; 256] {
    let mut q6_vals = [0u8; 256];
    let blocks = q6_vals.chunks_exact_mut(16).zip(padded.chunks_exact(16));
    for (j, (codes, values)) in blocks.enumerate() {
        let scale = d * f32::from(scales_i8[j]);
        // Multiply by the reciprocal; 0.0 sends everything to code 0 (+32).
        let inv = if scale.abs() > 1e-10 { scale.recip() } else { 0.0 };
        for (code, &v) in codes.iter_mut().zip(values.iter()) {
            let q = (v * inv).round().clamp(-32.0, 31.0) as i8;
            *code = (q + 32) as u8;
        }
    }
    q6_vals
}
391
/// Splits 256 6-bit codes into the Q6_K bit planes.
///
/// For each 128-value half: `ql` pairs low nibbles as (l, l+64) and
/// (l+32, l+96); `qh` packs the top 2 bits of the four values
/// (l, l+32, l+64, l+96) into one byte, two bits each.
fn pack_q6k_bits(q6_vals: &[u8; 256]) -> ([u8; 128], [u8; 64]) {
    let mut ql = [0u8; 128];
    let mut qh = [0u8; 64];

    for (half, group) in q6_vals.chunks_exact(128).enumerate() {
        let ql_out = &mut ql[half * 64..half * 64 + 64];
        let qh_out = &mut qh[half * 32..half * 32 + 32];

        for l in 0..32 {
            let (a, b, c, e) = (group[l], group[l + 32], group[l + 64], group[l + 96]);

            ql_out[l] = (a & 0x0F) | ((c & 0x0F) << 4);
            ql_out[l + 32] = (b & 0x0F) | ((e & 0x0F) << 4);

            qh_out[l] = ((a >> 4) & 0x03)
                | (((b >> 4) & 0x03) << 2)
                | (((c >> 4) & 0x03) << 4)
                | (((e >> 4) & 0x03) << 6);
        }
    }

    (ql, qh)
}
420
421#[must_use]
423pub fn quantize_q6_k_matrix(data: &[f32], shape: &[usize]) -> Vec<u8> {
424 const SUPER_BLOCK_SIZE: usize = 256;
425 const SUPER_BLOCK_BYTES: usize = 210;
426
427 if shape.len() != 2 {
428 return quantize_q6_k(data);
429 }
430
431 let rows = shape[0];
432 let cols = shape[1];
433 let super_blocks_per_row = cols.div_ceil(SUPER_BLOCK_SIZE);
434 let padded_cols = super_blocks_per_row * SUPER_BLOCK_SIZE;
435
436 let mut result = Vec::with_capacity(rows * super_blocks_per_row * SUPER_BLOCK_BYTES);
437
438 for row_idx in 0..rows {
439 let mut padded_row = vec![0.0f32; padded_cols];
440 let row_start = row_idx * cols;
441 let row_end = row_start + cols;
442 if row_end <= data.len() {
443 padded_row[..cols].copy_from_slice(&data[row_start..row_end]);
444 }
445
446 let row_q6k = quantize_q6_k(&padded_row);
447 result.extend_from_slice(&row_q6k);
448 }
449
450 result
451}