candle_core/quantized/
utils.rs

1pub(super) fn nearest_int(v: f32) -> i32 {
2    v.round() as i32
3}
4
5/// Validates that the input and output are the right size and returns an iterator which maps each
6/// input region `xs` to its corresponding output block in `ys`. Each output region is guaranteed
7/// to be `T::BLCK_SIZE` long.
8pub(super) fn group_for_quantization<'a, 'b, T: super::k_quants::GgmlType>(
9    xs: &'b [f32],
10    ys: &'a mut [T],
11) -> Vec<(&'a mut T, &'b [f32])> {
12    let block_size = T::BLCK_SIZE;
13    let dtype = T::DTYPE;
14
15    let expected_blocks = xs.len() / block_size;
16    let actual_blocks = ys.len();
17
18    // Validate that the input is the right size
19    debug_assert_eq!(
20        expected_blocks,
21        actual_blocks,
22        "quantize {dtype:?}: expected {expected_blocks} blocks but only {actual_blocks} were provided!");
23
24    ys.iter_mut().zip(xs.chunks_exact(block_size)).collect()
25}
26
27/// Validates that the input and output are the right size and returns an iterator which maps each
28/// input block `xs` to its corresponding output region in `ys`. Each output region is guaranteed
29/// to be `T::BLCK_SIZE` long.
30pub(super) fn group_for_dequantization<'a, 'b, T: super::k_quants::GgmlType>(
31    xs: &'a [T],
32    ys: &'b mut [f32],
33) -> Vec<(&'a T, &'b mut [f32])> {
34    let block_size = T::BLCK_SIZE;
35    let dtype = T::DTYPE;
36
37    let actual_output_len = ys.len();
38    let expected_output_len = xs.len() * block_size;
39    // Validate that the output is the right size
40    debug_assert_eq!(
41        expected_output_len,
42        actual_output_len,
43        "dequantize {dtype:?}: ys (len = {actual_output_len}) does not match the expected length of {expected_output_len}!"
44    );
45
46    // Zip the blocks and outputs together
47    xs.iter().zip(ys.chunks_exact_mut(block_size)).collect()
48}
49
50pub(super) fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
51    if j < 4 {
52        let d = q[j] & 63;
53        let m = q[j + 4] & 63;
54        (d, m)
55    } else {
56        let d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
57        let m = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
58        (d, m)
59    }
60}
61
62pub(super) unsafe fn make_qx_quants(
63    n: usize,
64    nmax: i32,
65    x: *const f32,
66    ls: *mut i8,
67    rmse_type: i32,
68    qw: *const f32,
69) -> f32 {
70    let mut max = 0f32;
71    let mut amax = 0f32;
72    for i in 0..n {
73        let x = *x.add(i);
74        let ax = x.abs();
75        if ax > amax {
76            amax = ax;
77            max = x;
78        }
79    }
80    if amax == 0. {
81        // all zero
82        for i in 0..n {
83            *ls.add(i) = 0;
84        }
85        return 0.;
86    }
87    let mut iscale = -(nmax as f32) / max;
88    if rmse_type == 0 {
89        for i in 0..n {
90            let x = *x.add(i);
91            let l = nearest_int(iscale * x);
92            *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
93        }
94        return 1.0 / iscale;
95    }
96    let weight_type = rmse_type % 2;
97    let mut sumlx = 0f32;
98    let mut suml2 = 0f32;
99    for i in 0..n {
100        let x = *x.add(i);
101        let l = nearest_int(iscale * x);
102        let l = l.clamp(-nmax, nmax - 1);
103        *ls.add(i) = (l + nmax) as i8;
104        let w = if !qw.is_null() {
105            *qw.add(i)
106        } else if weight_type == 1 {
107            x * x
108        } else {
109            1.0
110        };
111        let l = l as f32;
112        sumlx += w * x * l;
113        suml2 += w * l * l;
114    }
115    let mut scale = sumlx / suml2;
116    let mut best = scale * sumlx;
117    for _itry in 0..3 {
118        let iscale = 1.0 / scale;
119        let mut slx = 0f32;
120        let mut sl2 = 0f32;
121        let mut changed = false;
122        for i in 0..n {
123            let x = *x.add(i);
124            let l = nearest_int(iscale * x);
125            let l = l.clamp(-nmax, nmax - 1);
126            if l + nmax != *ls.add(i) as i32 {
127                changed = true;
128            }
129            let w = if !qw.is_null() {
130                *qw.add(i)
131            } else if weight_type == 1 {
132                x * x
133            } else {
134                1.0
135            };
136            let l = l as f32;
137            slx += w * x * l;
138            sl2 += w * l * l;
139        }
140        if !changed || sl2 == 0.0 || slx * slx <= best * sl2 {
141            break;
142        }
143        for i in 0..n {
144            let x = *x.add(i);
145            let l = nearest_int(iscale * x);
146            *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
147        }
148        sumlx = slx;
149        suml2 = sl2;
150        scale = sumlx / suml2;
151        best = scale * sumlx;
152    }
153    for _itry in 0..5 {
154        let mut n_changed = 0;
155        for i in 0..n {
156            let x = *x.add(i);
157            let w = if !qw.is_null() {
158                *qw.add(i)
159            } else if weight_type == 1 {
160                x * x
161            } else {
162                1.0
163            };
164            let l = *ls.add(i) as i32 - nmax;
165            let mut slx = sumlx - w * x * l as f32;
166            if slx > 0. {
167                let mut sl2 = suml2 - w * l as f32 * l as f32;
168                let new_l = nearest_int(x * sl2 / slx);
169                let new_l = new_l.clamp(-nmax, nmax - 1);
170                if new_l != l {
171                    slx += w * x * new_l as f32;
172                    sl2 += w * new_l as f32 * new_l as f32;
173                    if sl2 > 0. && slx * slx * suml2 > sumlx * sumlx * sl2 {
174                        *ls.add(i) = (nmax + new_l) as i8;
175                        sumlx = slx;
176                        suml2 = sl2;
177                        scale = sumlx / suml2;
178                        best = scale * sumlx;
179                        n_changed += 1;
180                    }
181                }
182            }
183        }
184        if n_changed == 0 {
185            break;
186        }
187    }
188    if rmse_type < 3 {
189        return scale;
190    }
191    for is in -4..4 {
192        if is == 0 {
193            continue;
194        }
195        iscale = -(nmax as f32 + 0.1f32 * is as f32) / max;
196        let mut sumlx = 0.;
197        let mut suml2 = 0.;
198        for i in 0..n {
199            let x = *x.add(i);
200            let l = nearest_int(iscale * x);
201            let l = l.clamp(-nmax, nmax - 1);
202            let w = if !qw.is_null() {
203                *qw.add(i)
204            } else if weight_type == 1 {
205                x * x
206            } else {
207                1.0
208            };
209            let l = l as f32;
210            sumlx += w * x * l;
211            suml2 += w * l * l;
212        }
213        if suml2 > 0. && sumlx * sumlx > best * suml2 {
214            for i in 0..n {
215                let x = *x.add(i);
216                let l = nearest_int(iscale * x);
217                *ls.add(i) = (nmax + l.clamp(-nmax, nmax - 1)) as i8;
218            }
219            scale = sumlx / suml2;
220            best = scale * sumlx;
221        }
222    }
223    scale
224}
225
226// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L224
227pub(super) fn make_qkx1_quants(nmax: i32, ntry: usize, x: &[f32]) -> (f32, f32) {
228    let n = x.len();
229    let mut l = vec![0; n];
230    // Get min/max
231    let min = *x
232        .iter()
233        .take(n)
234        .min_by(|a, b| a.total_cmp(b))
235        .unwrap_or(&x[0]);
236    let max = *x.iter().max_by(|a, b| a.total_cmp(b)).unwrap_or(&x[0]);
237
238    // If min == max, all values are the same => nothing to do here
239    if max == min {
240        return (0.0, 0.0);
241    }
242
243    // Ensure min <= 0.0
244    let mut min = min.min(0.);
245
246    // Compute scale and inverse scale
247    let mut iscale = nmax as f32 / (max - min);
248    let mut scale = 1.0 / iscale;
249
250    for _ in 0..ntry {
251        let mut sumlx = 0.0;
252        let mut suml2 = 0;
253        let mut did_change = false;
254
255        for (i, value) in x.iter().enumerate().take(n) {
256            let li = nearest_int(iscale * (value - min)).clamp(0, nmax);
257            let clamped_li = li as u8;
258            if clamped_li != l[i] {
259                l[i] = clamped_li;
260                did_change = true;
261            }
262            sumlx += (value - min) * li as f32;
263            suml2 += li * li;
264        }
265        scale = sumlx / suml2 as f32;
266
267        let sum: f32 = x
268            .iter()
269            .take(n)
270            .zip(l.iter().take(n))
271            .map(|(xi, &li)| xi - scale * li as f32)
272            .sum();
273
274        min = sum / n as f32;
275        if min > 0.0 {
276            min = 0.0;
277        }
278        iscale = 1.0 / scale;
279        if !did_change {
280            break;
281        }
282    }
283    (scale, -min)
284}
285
286// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L165
287pub(super) fn make_q3_quants(x: &[f32], nmax: i32, do_rmse: bool) -> f32 {
288    let n = x.len();
289    let mut l = vec![0i8; n];
290
291    let mut max = 0.0;
292    let mut amax = 0.0;
293    for &xi in x.iter().take(n) {
294        let ax = xi.abs();
295        if ax > amax {
296            amax = ax;
297            max = xi;
298        }
299    }
300
301    if amax == 0.0 {
302        return 0.0;
303    }
304
305    let iscale = -(nmax as f32) / max;
306    if do_rmse {
307        let mut sumlx = 0.0;
308        let mut suml2 = 0.0;
309        for i in 0..n {
310            let li = (iscale * x[i]).round() as i32;
311            let li = li.clamp(-nmax, nmax - 1);
312            l[i] = li as i8;
313            let w = x[i] * x[i];
314            sumlx += w * x[i] * li as f32;
315            suml2 += w * (li * li) as f32;
316        }
317        for _ in 0..5 {
318            let mut n_changed = 0;
319            for i in 0..n {
320                let w = x[i] * x[i];
321                let mut slx = sumlx - w * x[i] * l[i] as f32;
322                if slx > 0.0 {
323                    let mut sl2 = suml2 - w * (l[i] as i32 * l[i] as i32) as f32;
324                    let mut new_l = (x[i] * sl2 / slx).round() as i32;
325                    new_l = new_l.clamp(-nmax, nmax - 1);
326                    if new_l != l[i] as i32 {
327                        slx += w * x[i] * new_l as f32;
328                        sl2 += w * (new_l * new_l) as f32;
329                        if sl2 > 0.0 && slx * slx * suml2 > sumlx * sumlx * sl2 {
330                            l[i] = new_l as i8;
331                            sumlx = slx;
332                            suml2 = sl2;
333                            n_changed += 1;
334                        }
335                    }
336                }
337            }
338            if n_changed == 0 {
339                break;
340            }
341        }
342        for li in l.iter_mut() {
343            *li += nmax as i8;
344        }
345        return sumlx / suml2;
346    }
347    for i in 0..n {
348        let li = (iscale * x[i]).round() as i32;
349        l[i] = (li.clamp(-nmax, nmax - 1) + nmax) as i8;
350    }
351    1.0 / iscale
352}
353
354// https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/ggml/src/ggml-quants.c#L744
355/// (scale, min)
356pub(super) fn make_qkx3_quants(
357    nmax: i32,
358    x: &[f32],
359    weights: Option<&[f32]>,
360    rmin: f32,
361    rdelta: f32,
362    nstep: usize,
363    use_mad: bool,
364) -> (f32, f32) {
365    let n = x.len();
366    let mut l: [u8; 32] = [0; 32];
367    let mut l_aux: [u8; 32] = [0; 32];
368
369    let mut min_val = x[0];
370    let mut max_val = x[0];
371    let mut sum_w = match weights {
372        Some(w) => w[0],
373        None => x[0] * x[0],
374    };
375    let mut sum_x = sum_w * x[0];
376
377    for i in 1..n {
378        if x[i] < min_val {
379            min_val = x[i];
380        }
381        if x[i] > max_val {
382            max_val = x[i];
383        }
384        let w = match weights {
385            Some(w) => w[i],
386            None => x[i] * x[i],
387        };
388        sum_w += w;
389        sum_x += w * x[i];
390    }
391
392    if min_val > 0.0 {
393        min_val = 0.0;
394    }
395
396    if max_val <= min_val {
397        return (0.0, -min_val);
398    }
399
400    let mut iscale = nmax as f32 / (max_val - min_val);
401    let mut scale = 1.0 / iscale;
402    let mut best_mad = 0.0;
403
404    for i in 0..n {
405        let l_val = nearest_int(iscale * (x[i] - min_val)).clamp(0, nmax) as u8;
406        l[i] = l_val;
407        let diff = scale * (l_val as f32) + min_val - x[i];
408        let diff = if use_mad { diff.abs() } else { diff * diff };
409        let w = match weights {
410            Some(w) => w[i],
411            None => x[i] * x[i],
412        };
413        best_mad += w * diff;
414    }
415
416    if nstep < 1 {
417        return (scale, -min_val);
418    }
419
420    for is in 0..=nstep {
421        iscale = (rmin + rdelta * is as f32 + nmax as f32) / (max_val - min_val);
422        let (mut sum_l, mut sum_l2, mut sum_xl) = (0.0, 0.0, 0.0);
423
424        for i in 0..n {
425            let l_val = nearest_int(iscale * (x[i] - min_val)).clamp(0, nmax) as u8;
426            l_aux[i] = l_val;
427            let w = match weights {
428                Some(w) => w[i],
429                None => x[i] * x[i],
430            };
431            sum_l += w * l_val as f32;
432            sum_l2 += w * (l_val as f32).powi(2);
433            sum_xl += w * l_val as f32 * x[i];
434        }
435
436        let d = sum_w * sum_l2 - sum_l * sum_l;
437        if d > 0.0 {
438            let mut this_scale = (sum_w * sum_xl - sum_x * sum_l) / d;
439            let mut this_min = (sum_l2 * sum_x - sum_l * sum_xl) / d;
440
441            if this_min > 0.0 {
442                this_min = 0.0;
443                this_scale = sum_xl / sum_l2;
444            }
445
446            let mut mad = 0.0;
447            for i in 0..n {
448                let diff = this_scale * (l_aux[i] as f32) + this_min - x[i];
449                let diff = if use_mad { diff.abs() } else { diff * diff };
450                let w = match weights {
451                    Some(w) => w[i],
452                    None => x[i] * x[i],
453                };
454                mad += w * diff;
455            }
456
457            if mad < best_mad {
458                l.copy_from_slice(&l_aux);
459                best_mad = mad;
460                scale = this_scale;
461                min_val = this_min;
462            }
463        }
464    }
465
466    (scale, -min_val)
467}
468
469// https://github.com/ggerganov/llama.cpp/blob/678d7994f4da0af3d29046be99950ac999ee9762/ggml/src/ggml-quants.c#L827
470pub(super) fn make_qp_quants(
471    n: usize,
472    nmax: u8,
473    x: &[f32],
474    l: &mut [u8],
475    quant_weights: &[f32],
476) -> f32 {
477    assert_eq!(x.len(), n);
478    assert_eq!(l.len(), n);
479    assert_eq!(quant_weights.len(), n);
480
481    let max = x.iter().copied().fold(0.0, f32::max);
482    if max == 0.0 {
483        l.iter_mut().for_each(|li| *li = 0);
484        return 0.0;
485    }
486
487    let mut iscale = nmax as f32 / max;
488    for (xi, li) in x.iter().zip(l.iter_mut()) {
489        *li = nearest_int(iscale * xi) as u8;
490    }
491
492    let scale = 1.0 / iscale;
493    let mut best_mse = x
494        .iter()
495        .zip(l.iter())
496        .zip(quant_weights.iter())
497        .map(|((&xi, &li), &w)| {
498            let diff = xi - scale * li as f32;
499            w * diff * diff
500        })
501        .sum::<f32>();
502
503    for is in -4..=4 {
504        if is == 0 {
505            continue;
506        }
507        let iscale_is = (0.1 * is as f32 + nmax as f32) / max;
508        let scale_is = 1.0 / iscale_is;
509
510        let mse = x
511            .iter()
512            .zip(quant_weights.iter())
513            .map(|(&xi, &w)| {
514                let mut li = nearest_int(iscale_is * xi) as u8;
515                li = li.min(nmax);
516                let diff = xi - scale_is * li as f32;
517                w * diff * diff
518            })
519            .sum::<f32>();
520
521        if mse < best_mse {
522            best_mse = mse;
523            iscale = iscale_is;
524        }
525    }
526
527    let mut sumlx = 0.0;
528    let mut suml2 = 0.0;
529    for ((xi, li), &w) in x.iter().zip(l.iter_mut()).zip(quant_weights.iter()) {
530        let mut li_new = (iscale * xi).round() as u8;
531        li_new = li_new.min(nmax);
532        *li = li_new;
533        sumlx += w * xi * li_new as f32;
534        suml2 += w * (li_new as f32).powi(2);
535    }
536
537    for _ in 0..5 {
538        let mut n_changed = 0;
539        for ((xi, li), &w) in x.iter().zip(l.iter_mut()).zip(quant_weights.iter()) {
540            let mut slx = sumlx - w * xi * *li as f32;
541            let mut sl2 = suml2 - w * (*li as f32).powi(2);
542            if slx > 0.0 && sl2 > 0.0 {
543                let new_li = (nearest_int(xi * sl2 / slx) as u8).min(nmax);
544                if new_li != *li {
545                    slx += w * xi * new_li as f32;
546                    sl2 += w * (new_li as f32).powi(2);
547                    if slx.powi(2) * suml2 > sumlx.powi(2) * sl2 {
548                        *li = new_li;
549                        sumlx = slx;
550                        suml2 = sl2;
551                        n_changed += 1;
552                    }
553                }
554            }
555        }
556        if n_changed == 0 {
557            break;
558        }
559    }
560
561    sumlx / suml2
562}