1pub(super) fn nearest_int(v: f32) -> i32 {
2 v.round() as i32
3}
4
5pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
9 xs: &'b [f32],
10 ys: &'a mut [T],
11) -> Vec<(&'a mut T, &'b [f32])> {
12 let block_size = T::BLCK_SIZE;
13 let dtype = T::DTYPE;
14
15 let expected_blocks = xs.len() / block_size;
16 let actual_blocks = ys.len();
17
18 debug_assert_eq!(
20 expected_blocks,
21 actual_blocks,
22 "quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!");
23
24 ys.iter_mut().zip(xs.chunks_exact(block_size)).collect()
25}
26
27pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
31 xs: &'a [T],
32 ys: &'b mut [f32],
33) -> Vec<(&'a T, &'b mut [f32])> {
34 let block_size = T::BLCK_SIZE;
35 let dtype = T::DTYPE;
36
37 let actual_output_len = ys.len();
38 let expected_output_len = xs.len() * block_size;
39 debug_assert_eq!(
41 expected_output_len,
42 actual_output_len,
43 "dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!"
44 );
45
46 xs.iter().zip(ys.chunks_exact_mut(block_size)).collect()
48}
49
50pub(super) fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
51 if j < 4 {
52 let d = q[j] & 63;
53 let m = q[j + 4] & 63;
54 (d, m)
55 } else {
56 let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
57 let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
58 (d, m)
59 }
60}
61
62pub(super) unsafe fn make_qx_quants(
63 n: usize,
64 nmax: i32,
65 x: *const f32,
66 ls: *mut i8,
67 rmse_type: i32,
68 qw: *const f32,
69) -> f32 {
70 let mut max = 0f32;
71 let mut amax = 0f32;
72 for i in 0..n {
73 let x = *x.add(i);
74 let ax = x.abs();
75 if ax > amax {
76 amax = ax;
77 max = x;
78 }
79 }
80 if amax == 0. {
81 for i in 0..n {
83 *ls.add(i) = 0;
84 }
85 return 0.;
86 }
87 let mut iscale = -(nmax as f32) / max;
88 if rmse_type == 0 {
89 for i in 0..n {
90 let x = *x.add(i);
91 let l = nearest_int(iscale * x);
92 *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
93 }
94 return 1.0 / iscale;
95 }
96 let weight_type = rmse_type % 2;
97 let mut sumlx = 0f32;
98 let mut suml2 = 0f32;
99 for i in 0..n {
100 let x = *x.add(i);
101 let l = nearest_int(iscale * x);
102 let l = l.clamp(-nmax, nmax - 1);
103 *ls.add(i) = (l + nmax) as i8;
104 let w = if !qw.is_null() {
105 *qw.add(i)
106 } else if weight_type == 1 {
107 x * x
108 } else {
109 1.0
110 };
111 let l = l as f32;
112 sumlx += w * x * l;
113 suml2 += w * l * l;
114 }
115 let mut scale = sumlx / suml2;
116 let mut best = scale * sumlx;
117 for _itry in 0..3 {
118 let iscale = 1.0 / scale;
119 let mut slx = 0f32;
120 let mut sl2 = 0f32;
121 let mut changed = false;
122 for i in 0..n {
123 let x = *x.add(i);
124 let l = nearest_int(iscale * x);
125 let l = l.clamp(-nmax, nmax - 1);
126 if l + nmax != *ls.add(i) as i32 {
127 changed = true;
128 }
129 let w = if !qw.is_null() {
130 *qw.add(i)
131 } else if weight_type == 1 {
132 x * x
133 } else {
134 1.0
135 };
136 let l = l as f32;
137 slx += w * x * l;
138 sl2 += w * l * l;
139 }
140 if !changed || sl2 == 0.0 || slx * slx <= best * sl2 {
141 break;
142 }
143 for i in 0..n {
144 let x = *x.add(i);
145 let l = nearest_int(iscale * x);
146 *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
147 }
148 sumlx = slx;
149 suml2 = sl2;
150 scale = sumlx / suml2;
151 best = scale * sumlx;
152 }
153 for _itry in 0..5 {
154 let mut n_changed = 0;
155 for i in 0..n {
156 let x = *x.add(i);
157 let w = if !qw.is_null() {
158 *qw.add(i)
159 } else if weight_type == 1 {
160 x * x
161 } else {
162 1.0
163 };
164 let l = *ls.add(i) as i32 - nmax;
165 let mut slx = sumlx - w * x * l as f32;
166 if slx > 0. {
167 let mut sl2 = suml2 - w * l as f32 * l as f32;
168 let new_l = nearest_int(x * sl2 / slx);
169 let new_l = new_l.clamp(-nmax, nmax - 1);
170 if new_l != l {
171 slx += w * x * new_l as f32;
172 sl2 += w * new_l as f32 * new_l as f32;
173 if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 {
174 *ls.add(i) = (nmax + new_l) as i8;
175 sumlx = slx;
176 suml2 = sl2;
177 scale = sumlx / suml2;
178 best = scale * sumlx;
179 n_changed += 1;
180 }
181 }
182 }
183 }
184 if n_changed == 0 {
185 break;
186 }
187 }
188 if rmse_type < 3 {
189 return scale;
190 }
191 for is in -4..4 {
192 if is == 0 {
193 continue;
194 }
195 iscale = -(nmax as f32 + 0.1f32 * is as f32) / max;
196 let mut sumlx = 0.;
197 let mut suml2 = 0.;
198 for i in 0..n {
199 let x = *x.add(i);
200 let l = nearest_int(iscale * x);
201 let l = l.clamp(-nmax, nmax - 1);
202 let w = if !qw.is_null() {
203 *qw.add(i)
204 } else if weight_type == 1 {
205 x * x
206 } else {
207 1.0
208 };
209 let l = l as f32;
210 sumlx += w * x * l;
211 suml2 += w * l * l;
212 }
213 if suml2 > 0. && sumlx * sumlx > best * suml2 {
214 for i in 0..n {
215 let x = *x.add(i);
216 let l = nearest_int(iscale * x);
217 *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
218 }
219 scale = sumlx / suml2;
220 best = scale * sumlx;
221 }
222 }
223 scale
224}
225
226pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) {
228 let n = x.len();
229 let mut l = vec![0; n];
230 let min = *x
232 .iter()
233 .take(n)
234 .min_by(|a, b| a.total_cmp(b))
235 .unwrap_or(&x[0]);
236 let max = *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&x[0]);
237
238 if max == min {
240 return (0.0, 0.0);
241 }
242
243 let mut min = min.min(0.);
245
246 let mut iscale = nmax as f32 / (max - min);
248 let mut scale = 1.0 / iscale;
249
250 for _ in 0..ntry {
251 let mut sumlx = 0.0;
252 let mut suml2 = 0;
253 let mut did_change = false;
254
255 for (i, value) in x.iter().enumerate().take(n) {
256 let li = nearest_int(iscale * (value - min)).clamp(0, nmax);
257 let clamped_li = li as u8;
258 if clamped_li != l[i] {
259 l[i] = clamped_li;
260 did_change = true;
261 }
262 sumlx += (value - min) * li as f32;
263 suml2 += li * li;
264 }
265 scale = sumlx / suml2 as f32;
266
267 let sum: f32 = x
268 .iter()
269 .take(n)
270 .zip(l.iter().take(n))
271 .map(|(xi, &li)| xi - scale * li as f32)
272 .sum();
273
274 min = sum / n as f32;
275 if min > 0.0 {
276 min = 0.0;
277 }
278 iscale = 1.0 / scale;
279 if !did_change {
280 break;
281 }
282 }
283 (scale, -min)
284}
285
286pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
288 let n = x.len();
289 let mut l = vec![0i8; n];
290
291 let mut max = 0.0;
292 let mut amax = 0.0;
293 for &xi in x.iter().take(n) {
294 let ax = xi.abs();
295 if ax > amax {
296 amax = ax;
297 max = xi;
298 }
299 }
300
301 if amax == 0.0 {
302 return 0.0;
303 }
304
305 let iscale = -(nmax as f32) / max;
306 if do_rmse {
307 let mut sumlx = 0.0;
308 let mut suml2 = 0.0;
309 for i in 0..n {
310 let li = (iscale * x[i]).round() as i32;
311 let li = li.clamp(-nmax, nmax - 1);
312 l[i] = li as i8;
313 let w = x[i] * x[i];
314 sumlx += w * x[i] * li as f32;
315 suml2 += w * (li * li) as f32;
316 }
317 for _ in 0..5 {
318 let mut n_changed = 0;
319 for i in 0..n {
320 let w = x[i] * x[i];
321 let mut slx = sumlx - w * x[i] * l[i] as f32;
322 if slx > 0.0 {
323 let mut sl2 = suml2 - w * (l[i] as i32 * l[i] as i32) as f32;
324 let mut new_l = (x[i] * sl2 / slx).round() as i32;
325 new_l = new_l.clamp(-nmax, nmax - 1);
326 if new_l != l[i] as i32 {
327 slx += w * x[i] * new_l as f32;
328 sl2 += w * (new_l * new_l) as f32;
329 if sl2 > 0.0 && slx * slx * suml2 > sumlx * sumlx * sl2 {
330 l[i] = new_l as i8;
331 sumlx = slx;
332 suml2 = sl2;
333 n_changed += 1;
334 }
335 }
336 }
337 }
338 if n_changed == 0 {
339 break;
340 }
341 }
342 for li in l.iter_mut() {
343 *li += nmax as i8;
344 }
345 return sumlx / suml2;
346 }
347 for i in 0..n {
348 let li = (iscale * x[i]).round() as i32;
349 l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8;
350 }
351 1.0 / iscale
352}
353
354pub(super) fn make_qkx3_quants(
357 nmax: i32,
358 x: &[f32],
359 weights: Option<&[f32]>,
360 rmin: f32,
361 rdelta: f32,
362 nstep: usize,
363 use_mad: bool,
364) -> (f32, f32) {
365 let n = x.len();
366 let mut l: [u8; 32] = [0; 32];
367 let mut l_aux: [u8; 32] = [0; 32];
368
369 let mut min_val = x[0];
370 let mut max_val = x[0];
371 let mut sum_w = match weights {
372 Some(w) => w[0],
373 None => x[0] * x[0],
374 };
375 let mut sum_x = sum_w * x[0];
376
377 for i in 1..n {
378 if x[i] < min_val {
379 min_val = x[i];
380 }
381 if x[i] > max_val {
382 max_val = x[i];
383 }
384 let w = match weights {
385 Some(w) => w[i],
386 None => x[i] * x[i],
387 };
388 sum_w += w;
389 sum_x += w * x[i];
390 }
391
392 if min_val > 0.0 {
393 min_val = 0.0;
394 }
395
396 if max_val <= min_val {
397 return (0.0, -min_val);
398 }
399
400 let mut iscale = nmax as f32 / (max_val - min_val);
401 let mut scale = 1.0 / iscale;
402 let mut best_mad = 0.0;
403
404 for i in 0..n {
405 let l_val = nearest_int(iscale * (x[i] - min_val)).clamp(0, nmax) as u8;
406 l[i] = l_val;
407 let diff = scale * (l_val as f32) + min_val - x[i];
408 let diff = if use_mad { diff.abs() } else { diff * diff };
409 let w = match weights {
410 Some(w) => w[i],
411 None => x[i] * x[i],
412 };
413 best_mad += w * diff;
414 }
415
416 if nstep < 1 {
417 return (scale, -min_val);
418 }
419
420 for is in 0..=nstep {
421 iscale = (rmin + rdelta * is as f32 + nmax as f32) / (max_val - min_val);
422 let (mut sum_l, mut sum_l2, mut sum_xl) = (0.0, 0.0, 0.0);
423
424 for i in 0..n {
425 let l_val = nearest_int(iscale * (x[i] - min_val)).clamp(0, nmax) as u8;
426 l_aux[i] = l_val;
427 let w = match weights {
428 Some(w) => w[i],
429 None => x[i] * x[i],
430 };
431 sum_l += w * l_val as f32;
432 sum_l2 += w * (l_val as f32).powi(2);
433 sum_xl += w * l_val as f32 * x[i];
434 }
435
436 let d = sum_w * sum_l2 - sum_l * sum_l;
437 if d > 0.0 {
438 let mut this_scale = (sum_w * sum_xl - sum_x * sum_l) / d;
439 let mut this_min = (sum_l2 * sum_x - sum_l * sum_xl) / d;
440
441 if this_min > 0.0 {
442 this_min = 0.0;
443 this_scale = sum_xl / sum_l2;
444 }
445
446 let mut mad = 0.0;
447 for i in 0..n {
448 let diff = this_scale * (l_aux[i] as f32) + this_min - x[i];
449 let diff = if use_mad { diff.abs() } else { diff * diff };
450 let w = match weights {
451 Some(w) => w[i],
452 None => x[i] * x[i],
453 };
454 mad += w * diff;
455 }
456
457 if mad < best_mad {
458 l.copy_from_slice(&l_aux);
459 best_mad = mad;
460 scale = this_scale;
461 min_val = this_min;
462 }
463 }
464 }
465
466 (scale, -min_val)
467}
468
469pub(super) fn make_qp_quants(
471 n: usize,
472 nmax: u8,
473 x: &[f32],
474 l: &mut [u8],
475 quant_weights: &[f32],
476) -> f32 {
477 assert_eq!(x.len(), n);
478 assert_eq!(l.len(), n);
479 assert_eq!(quant_weights.len(), n);
480
481 let max = x.iter().copied().fold(0.0, f32::max);
482 if max == 0.0 {
483 l.iter_mut().for_each(|li| *li = 0);
484 return 0.0;
485 }
486
487 let mut iscale = nmax as f32 / max;
488 for (xi, li) in x.iter().zip(l.iter_mut()) {
489 *li = nearest_int(iscale * xi) as u8;
490 }
491
492 let scale = 1.0 / iscale;
493 let mut best_mse = x
494 .iter()
495 .zip(l.iter())
496 .zip(quant_weights.iter())
497 .map(|((&xi, &li), &w)| {
498 let diff = xi - scale * li as f32;
499 w * diff * diff
500 })
501 .sum::<f32>();
502
503 for is in -4..=4 {
504 if is == 0 {
505 continue;
506 }
507 let iscale_is = (0.1 * is as f32 + nmax as f32) / max;
508 let scale_is = 1.0 / iscale_is;
509
510 let mse = x
511 .iter()
512 .zip(quant_weights.iter())
513 .map(|(&xi, &w)| {
514 let mut li = nearest_int(iscale_is * xi) as u8;
515 li = li.min(nmax);
516 let diff = xi - scale_is * li as f32;
517 w * diff * diff
518 })
519 .sum::<f32>();
520
521 if mse < best_mse {
522 best_mse = mse;
523 iscale = iscale_is;
524 }
525 }
526
527 let mut sumlx = 0.0;
528 let mut suml2 = 0.0;
529 for ((xi, li), &w) in x.iter().zip(l.iter_mut()).zip(quant_weights.iter()) {
530 let mut li_new = (iscale * xi).round() as u8;
531 li_new = li_new.min(nmax);
532 *li = li_new;
533 sumlx += w * xi * li_new as f32;
534 suml2 += w * (li_new as f32).powi(2);
535 }
536
537 for _ in 0..5 {
538 let mut n_changed = 0;
539 for ((xi, li), &w) in x.iter().zip(l.iter_mut()).zip(quant_weights.iter()) {
540 let mut slx = sumlx - w * xi * *li as f32;
541 let mut sl2 = suml2 - w * (*li as f32).powi(2);
542 if slx > 0.0 && sl2 > 0.0 {
543 let new_li = (nearest_int(xi * sl2 / slx) as u8).min(nmax);
544 if new_li != *li {
545 slx += w * xi * new_li as f32;
546 sl2 += w * (new_li as f32).powi(2);
547 if slx.powi(2) * suml2 > sumlx.powi(2) * sl2 {
548 *li = new_li;
549 sumlx = slx;
550 suml2 = sl2;
551 n_changed += 1;
552 }
553 }
554 }
555 }
556 if n_changed == 0 {
557 break;
558 }
559 }
560
561 sumlx / suml2
562}