// candle_core/quantized/k_quants.rs

1use super::utils::{
2    get_scale_min_k4, group_for_dequantization, group_for_quantization, make_q3_quants,
3    make_qkx1_quants, make_qx_quants, nearest_int,
4};
5use super::GgmlDType;
6use crate::Result;
7use byteorder::{ByteOrder, LittleEndian};
8use half::f16;
9use rayon::prelude::*;
10
/// Number of weights in one k-quant super-block.
///
/// The super-block size defaults to 256 here rather than the alternative 64.
pub const QK_K: usize = 256;
/// Size in bytes of the packed `scales` field of the Q4K and Q5K super-blocks.
pub const K_SCALE_SIZE: usize = 12;

/// Weights per block for the legacy `Q4_0` format.
pub const QK4_0: usize = 32;
/// Weights per block for the legacy `Q4_1` format.
pub const QK4_1: usize = 32;
/// Weights per block for the legacy `Q5_0` format.
pub const QK5_0: usize = 32;
/// Weights per block for the legacy `Q5_1` format.
pub const QK5_1: usize = 32;
/// Weights per block for the legacy `Q8_0` format.
pub const QK8_0: usize = 32;
/// Weights per block for the legacy `Q8_1` format.
pub const QK8_1: usize = 32;
21
22pub trait GgmlType: Sized + Clone + Send + Sync {
23    const DTYPE: GgmlDType;
24    const BLCK_SIZE: usize;
25    type VecDotType: GgmlType;
26
27    // This is only safe for types that include immediate values such as float/int/...
28    fn zeros() -> Self {
29        unsafe { std::mem::MaybeUninit::zeroed().assume_init() }
30    }
31    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>;
32    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>;
33
34    /// Dot product used as a building block for quantized mat-mul.
35    /// n is the number of elements to be considered.
36    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
37
38    /// Generic implementation of the dot product without simd optimizations.
39    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
40}
41
42#[derive(Debug, Clone, PartialEq)]
43#[repr(C)]
44pub struct BlockQ4_0 {
45    pub(crate) d: f16,
46    pub(crate) qs: [u8; QK4_0 / 2],
47}
48const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
49
50#[derive(Debug, Clone, PartialEq)]
51#[repr(C)]
52pub struct BlockQ4_1 {
53    pub(crate) d: f16,
54    pub(crate) m: f16,
55    pub(crate) qs: [u8; QK4_1 / 2],
56}
57const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
58
59#[derive(Debug, Clone, PartialEq)]
60#[repr(C)]
61pub struct BlockQ5_0 {
62    pub(crate) d: f16,
63    pub(crate) qh: [u8; 4],
64    pub(crate) qs: [u8; QK5_0 / 2],
65}
66const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
67
68#[derive(Debug, Clone, PartialEq)]
69#[repr(C)]
70pub struct BlockQ5_1 {
71    pub(crate) d: f16,
72    pub(crate) m: f16,
73    pub(crate) qh: [u8; 4],
74    pub(crate) qs: [u8; QK5_1 / 2],
75}
76const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
77
78#[derive(Debug, Clone, PartialEq)]
79#[repr(C)]
80pub struct BlockQ8_0 {
81    pub(crate) d: f16,
82    pub(crate) qs: [i8; QK8_0],
83}
84const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
85
86#[derive(Debug, Clone, PartialEq)]
87#[repr(C)]
88pub struct BlockQ8_1 {
89    pub(crate) d: f16,
90    pub(crate) s: f16,
91    pub(crate) qs: [i8; QK8_1],
92}
93const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
94
95#[derive(Debug, Clone, PartialEq)]
96#[repr(C)]
97pub struct BlockQ2K {
98    pub(crate) scales: [u8; QK_K / 16],
99    pub(crate) qs: [u8; QK_K / 4],
100    pub(crate) d: f16,
101    pub(crate) dmin: f16,
102}
103const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
104
105#[derive(Debug, Clone, PartialEq)]
106#[repr(C)]
107pub struct BlockQ3K {
108    pub(crate) hmask: [u8; QK_K / 8],
109    pub(crate) qs: [u8; QK_K / 4],
110    pub(crate) scales: [u8; 12],
111    pub(crate) d: f16,
112}
113const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
114
115#[derive(Debug, Clone, PartialEq)]
116// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
117#[repr(C)]
118pub struct BlockQ4K {
119    pub(crate) d: f16,
120    pub(crate) dmin: f16,
121    pub(crate) scales: [u8; K_SCALE_SIZE],
122    pub(crate) qs: [u8; QK_K / 2],
123}
124const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
125
126#[derive(Debug, Clone, PartialEq)]
127#[repr(C)]
128pub struct BlockQ5K {
129    pub(crate) d: f16,
130    pub(crate) dmin: f16,
131    pub(crate) scales: [u8; K_SCALE_SIZE],
132    pub(crate) qh: [u8; QK_K / 8],
133    pub(crate) qs: [u8; QK_K / 2],
134}
135const _: () =
136    assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
137
138#[derive(Debug, Clone, PartialEq)]
139#[repr(C)]
140pub struct BlockQ6K {
141    pub(crate) ql: [u8; QK_K / 2],
142    pub(crate) qh: [u8; QK_K / 4],
143    pub(crate) scales: [i8; QK_K / 16],
144    pub(crate) d: f16,
145}
146const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
147
148#[derive(Debug, Clone, PartialEq)]
149#[repr(C)]
150pub struct BlockQ8K {
151    pub(crate) d: f32,
152    pub(crate) qs: [i8; QK_K],
153    pub(crate) bsums: [i16; QK_K / 16],
154}
155const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::<BlockQ8K>());
156
157impl GgmlType for BlockQ4_0 {
158    const DTYPE: GgmlDType = GgmlDType::Q4_0;
159    const BLCK_SIZE: usize = QK4_0;
160    type VecDotType = BlockQ8_0;
161
162    // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525
163    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
164        let k = ys.len();
165        let qk = Self::BLCK_SIZE;
166        if k % qk != 0 {
167            crate::bail!("dequantize_row_q4_0: {k} is not divisible by {qk}")
168        }
169
170        let nb = k / qk;
171        for i in 0..nb {
172            let d = xs[i].d.to_f32();
173
174            for j in 0..(qk / 2) {
175                let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8;
176                let x1 = (xs[i].qs[j] >> 4) as i16 - 8;
177
178                ys[i * qk + j] = (x0 as f32) * d;
179                ys[i * qk + j + qk / 2] = (x1 as f32) * d;
180            }
181        }
182        Ok(())
183    }
184
185    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
186        // quantize_row_q4_0
187        let qk = Self::BLCK_SIZE;
188        let k = xs.len();
189        if k % qk != 0 {
190            crate::bail!("{k} is not divisible by {}", qk);
191        };
192        let nb = k / qk;
193        if ys.len() != nb {
194            crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
195        }
196        for (i, ys) in ys.iter_mut().enumerate() {
197            let mut amax = 0f32;
198            let mut max = 0f32;
199
200            let xs = &xs[i * qk..(i + 1) * qk];
201            for &x in xs.iter() {
202                if amax < x.abs() {
203                    amax = x.abs();
204                    max = x;
205                }
206            }
207            let d = max / -8.0;
208            let id = if d != 0f32 { 1. / d } else { 0. };
209            ys.d = f16::from_f32(d);
210
211            for (j, q) in ys.qs.iter_mut().enumerate() {
212                let x0 = xs[j] * id;
213                let x1 = xs[qk / 2 + j] * id;
214                let xi0 = u8::min(15, (x0 + 8.5) as u8);
215                let xi1 = u8::min(15, (x1 + 8.5) as u8);
216                *q = xi0 | (xi1 << 4)
217            }
218        }
219        Ok(())
220    }
221
222    // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122
223    #[allow(unreachable_code)]
224    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
225        #[cfg(target_feature = "avx")]
226        return super::avx::vec_dot_q4_0_q8_0(n, xs, ys);
227
228        #[cfg(target_feature = "neon")]
229        return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);
230
231        #[cfg(target_feature = "simd128")]
232        return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);
233
234        Self::vec_dot_unopt(n, xs, ys)
235    }
236
237    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
238        let qk = QK8_0;
239        if n % QK8_0 != 0 {
240            crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
241        }
242        // Generic implementation.
243        let mut sumf = 0f32;
244        for (xs, ys) in xs.iter().zip(ys.iter()) {
245            let mut sum_i = 0;
246            for j in 0..qk / 2 {
247                let v0 = (xs.qs[j] & 0x0F) as i32 - 8;
248                let v1 = (xs.qs[j] >> 4) as i32 - 8;
249                sum_i += v0 * ys.qs[j] as i32 + v1 * ys.qs[j + qk / 2] as i32
250            }
251            sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
252        }
253        Ok(sumf)
254    }
255}
256
257impl GgmlType for BlockQ4_1 {
258    const DTYPE: GgmlDType = GgmlDType::Q4_1;
259    const BLCK_SIZE: usize = QK4_1;
260    type VecDotType = BlockQ8_1;
261
262    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
263        Self::vec_dot_unopt(n, xs, ys)
264    }
265
266    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
267        // ggml_vec_dot_q4_1_q8_1
268        let qk = QK8_1;
269        if n % qk != 0 {
270            crate::bail!("vec_dot_q4_1_q8_1: {n} is not divisible by {qk}")
271        }
272        let nb = n / qk;
273        if nb % 2 != 0 {
274            crate::bail!("vec_dot_q4_1_q8_1: {n}, nb is not divisible by 2")
275        }
276
277        // Generic implementation.
278        let mut sumf = 0f32;
279
280        for (xs, ys) in xs.iter().zip(ys.iter()) {
281            let mut sumi = 0i32;
282
283            for j in 0..qk / 2 {
284                let v0 = xs.qs[j] as i32 & 0x0F;
285                let v1 = xs.qs[j] as i32 >> 4;
286                sumi += (v0 * ys.qs[j] as i32) + (v1 * ys.qs[j + qk / 2] as i32);
287            }
288
289            sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
290                + f16::to_f32(xs.m) * f16::to_f32(ys.s)
291        }
292        Ok(sumf)
293    }
294
295    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
296        // quantize_row_q4_1
297        let qk = Self::BLCK_SIZE;
298        if ys.len() * qk != xs.len() {
299            crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
300        }
301        for (i, ys) in ys.iter_mut().enumerate() {
302            let xs = &xs[i * qk..(i + 1) * qk];
303
304            let mut min = f32::INFINITY;
305            let mut max = f32::NEG_INFINITY;
306            for &x in xs.iter() {
307                min = f32::min(x, min);
308                max = f32::max(x, max);
309            }
310            let d = (max - min) / ((1 << 4) - 1) as f32;
311            let id = if d != 0f32 { 1. / d } else { 0. };
312            ys.d = f16::from_f32(d);
313            ys.m = f16::from_f32(min);
314
315            for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() {
316                let x0 = (xs[j] - min) * id;
317                let x1 = (xs[qk / 2 + j] - min) * id;
318
319                let xi0 = u8::min(15, (x0 + 0.5) as u8);
320                let xi1 = u8::min(15, (x1 + 0.5) as u8);
321
322                *q = xi0 | (xi1 << 4);
323            }
324        }
325        Ok(())
326    }
327
328    // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1545
329    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
330        let k = ys.len();
331        if k % QK4_1 != 0 {
332            crate::bail!("dequantize_row_q4_1: {k} is not divisible by {QK4_1}");
333        }
334
335        let nb = k / QK4_1;
336        for i in 0..nb {
337            let d = xs[i].d.to_f32();
338            let m = xs[i].m.to_f32();
339
340            for j in 0..(QK4_1 / 2) {
341                let x0 = xs[i].qs[j] & 0x0F;
342                let x1 = xs[i].qs[j] >> 4;
343
344                ys[i * QK4_1 + j] = (x0 as f32) * d + m;
345                ys[i * QK4_1 + j + QK4_1 / 2] = (x1 as f32) * d + m;
346            }
347        }
348        Ok(())
349    }
350}
351
352impl GgmlType for BlockQ5_0 {
353    const DTYPE: GgmlDType = GgmlDType::Q5_0;
354    const BLCK_SIZE: usize = QK5_0;
355    type VecDotType = BlockQ8_0;
356
357    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
358        let qk = Self::BLCK_SIZE;
359        if n % Self::BLCK_SIZE != 0 {
360            crate::bail!("vec_dot_q5_0_q8_0: {n} is not divisible by {qk}")
361        }
362        let nb = n / qk;
363        if nb % 2 != 0 {
364            crate::bail!("vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2")
365        }
366        Self::vec_dot_unopt(n, xs, ys)
367    }
368
369    fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
370        // Generic implementation.
371        let mut sumf = 0f32;
372
373        for (xs, ys) in xs.iter().zip(ys.iter()) {
374            let qh = LittleEndian::read_u32(&xs.qh);
375            let mut sumi = 0i32;
376
377            for j in 0..Self::BLCK_SIZE / 2 {
378                let xh_0 = (((qh & (1u32 << j)) >> j) << 4) as u8;
379                let xh_1 = ((qh & (1u32 << (j + 16))) >> (j + 12)) as u8;
380
381                let x0 = ((xs.qs[j] & 0x0F) as i32 | xh_0 as i32) - 16;
382                let x1 = ((xs.qs[j] >> 4) as i32 | xh_1 as i32) - 16;
383
384                sumi += (x0 * ys.qs[j] as i32) + (x1 * ys.qs[j + Self::BLCK_SIZE / 2] as i32);
385            }
386
387            sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
388        }
389        Ok(sumf)
390    }
391
392    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
393        // quantize_row_q5_0
394        let k = xs.len();
395        if ys.len() * Self::BLCK_SIZE != k {
396            crate::bail!("size mismatch {k} {} {}", ys.len(), Self::BLCK_SIZE)
397        }
398        for (i, ys) in ys.iter_mut().enumerate() {
399            let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
400
401            let mut amax = 0f32;
402            let mut max = 0f32;
403            for &x in xs.iter() {
404                if amax < x.abs() {
405                    amax = x.abs();
406                    max = x;
407                }
408            }
409            let d = max / -16.;
410            let id = if d != 0f32 { 1. / d } else { 0. };
411            ys.d = f16::from_f32(d);
412            let mut qh = 0u32;
413            for j in 0..Self::BLCK_SIZE / 2 {
414                let x0 = xs[j] * id;
415                let x1 = xs[j + Self::BLCK_SIZE / 2] * id;
416                let xi0 = ((x0 + 16.5) as i8).min(31) as u8;
417                let xi1 = ((x1 + 16.5) as i8).min(31) as u8;
418                ys.qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
419                qh |= ((xi0 as u32 & 0x10) >> 4) << j;
420                qh |= ((xi1 as u32 & 0x10) >> 4) << (j + Self::BLCK_SIZE / 2);
421            }
422            LittleEndian::write_u32(&mut ys.qh, qh)
423        }
424        Ok(())
425    }
426
427    // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1566
428    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
429        let k = ys.len();
430        if k % QK5_0 != 0 {
431            crate::bail!("dequantize_row_q5_0: {k} is not divisible by {QK5_0}");
432        }
433
434        let nb = k / QK5_0;
435        for i in 0..nb {
436            let d = xs[i].d.to_f32();
437            let qh: u32 = LittleEndian::read_u32(&xs[i].qh);
438
439            for j in 0..(QK5_0 / 2) {
440                let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
441                let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
442
443                let x0 = ((xs[i].qs[j] & 0x0F) | xh_0) as i32 - 16;
444                let x1 = ((xs[i].qs[j] >> 4) | xh_1) as i32 - 16;
445
446                ys[i * QK5_0 + j] = (x0 as f32) * d;
447                ys[i * QK5_0 + j + QK5_0 / 2] = (x1 as f32) * d;
448            }
449        }
450        Ok(())
451    }
452}
453
454impl GgmlType for BlockQ5_1 {
455    const DTYPE: GgmlDType = GgmlDType::Q5_1;
456    const BLCK_SIZE: usize = QK5_1;
457    type VecDotType = BlockQ8_1;
458
459    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
460        Self::vec_dot_unopt(n, xs, ys)
461    }
462
463    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
464        let qk = Self::BLCK_SIZE;
465        if n % Self::BLCK_SIZE != 0 {
466            crate::bail!("vec_dot_q5_1_q8_1: {n} is not divisible by {qk}")
467        }
468        let nb = n / qk;
469        if nb % 2 != 0 {
470            crate::bail!("vec_dot_q5_1_q8_1: {n}, nb is not divisible by 2")
471        }
472
473        // Generic implementation.
474        let mut sumf = 0f32;
475
476        for (xs, ys) in xs.iter().zip(ys.iter()) {
477            let qh = LittleEndian::read_u32(&xs.qh);
478            let mut sumi = 0i32;
479
480            for j in 0..Self::BLCK_SIZE / 2 {
481                let xh_0 = ((qh >> j) << 4) & 0x10;
482                let xh_1 = (qh >> (j + 12)) & 0x10;
483
484                let x0 = (xs.qs[j] as i32 & 0xF) | xh_0 as i32;
485                let x1 = (xs.qs[j] as i32 >> 4) | xh_1 as i32;
486
487                sumi += (x0 * ys.qs[j] as i32) + (x1 * ys.qs[j + Self::BLCK_SIZE / 2] as i32);
488            }
489
490            sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
491                + f16::to_f32(xs.m) * f16::to_f32(ys.s)
492        }
493        Ok(sumf)
494    }
495
496    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
497        // quantize_row_q5_1
498        let qk = Self::BLCK_SIZE;
499        if ys.len() * qk != xs.len() {
500            crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
501        }
502        for (i, ys) in ys.iter_mut().enumerate() {
503            let xs = &xs[i * qk..(i + 1) * qk];
504
505            let mut min = f32::INFINITY;
506            let mut max = f32::NEG_INFINITY;
507            for &x in xs.iter() {
508                min = f32::min(x, min);
509                max = f32::max(x, max);
510            }
511            let d = (max - min) / ((1 << 5) - 1) as f32;
512            let id = if d != 0f32 { 1. / d } else { 0. };
513            ys.d = f16::from_f32(d);
514            ys.m = f16::from_f32(min);
515
516            let mut qh = 0u32;
517            for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() {
518                let x0 = (xs[j] - min) * id;
519                let x1 = (xs[qk / 2 + j] - min) * id;
520
521                let xi0 = (x0 + 0.5) as u8;
522                let xi1 = (x1 + 0.5) as u8;
523
524                *q = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
525                // get the 5-th bit and store it in qh at the right position
526                qh |= ((xi0 as u32 & 0x10) >> 4) << j;
527                qh |= ((xi1 as u32 & 0x10) >> 4) << (j + qk / 2);
528            }
529            LittleEndian::write_u32(&mut ys.qh, qh);
530        }
531        Ok(())
532    }
533
534    // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1592
535    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
536        let k = ys.len();
537        if k % QK5_1 != 0 {
538            crate::bail!("dequantize_row_q5_1: {k} is not divisible by {QK5_1}");
539        }
540
541        let nb = k / QK5_1;
542        for i in 0..nb {
543            let d = xs[i].d.to_f32();
544            let m = xs[i].m.to_f32();
545            let qh: u32 = LittleEndian::read_u32(&xs[i].qh);
546
547            for j in 0..(QK5_1 / 2) {
548                let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
549                let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
550
551                let x0 = (xs[i].qs[j] & 0x0F) | xh_0;
552                let x1 = (xs[i].qs[j] >> 4) | xh_1;
553
554                ys[i * QK5_1 + j] = (x0 as f32) * d + m;
555                ys[i * QK5_1 + j + QK5_1 / 2] = (x1 as f32) * d + m;
556            }
557        }
558        Ok(())
559    }
560}
561
562impl GgmlType for BlockQ8_0 {
563    const DTYPE: GgmlDType = GgmlDType::Q8_0;
564    const BLCK_SIZE: usize = QK8_0;
565    type VecDotType = BlockQ8_0;
566
567    // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619
568    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
569        let k = ys.len();
570        if k % QK8_0 != 0 {
571            crate::bail!("dequantize_row_q8_0: {k} is not divisible by {QK8_0}");
572        }
573
574        let nb = k / QK8_0;
575
576        for i in 0..nb {
577            let d = xs[i].d.to_f32();
578
579            for j in 0..QK8_0 {
580                ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
581            }
582        }
583        Ok(())
584    }
585
586    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
587        // quantize_row_q8_0
588        let k = xs.len();
589        if k % Self::BLCK_SIZE != 0 {
590            crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE);
591        };
592        let nb = k / Self::BLCK_SIZE;
593        if ys.len() != nb {
594            crate::bail!(
595                "size mismatch {} {} {}",
596                xs.len(),
597                ys.len(),
598                Self::BLCK_SIZE
599            )
600        }
601        for (i, ys) in ys.iter_mut().enumerate() {
602            let mut amax = 0f32;
603            let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
604            for &x in xs.iter() {
605                amax = amax.max(x.abs())
606            }
607            let d = amax / ((1 << 7) - 1) as f32;
608            let id = if d != 0f32 { 1. / d } else { 0. };
609            ys.d = f16::from_f32(d);
610            for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
611                *y = f32::round(x * id) as i8
612            }
613        }
614        Ok(())
615    }
616
617    #[allow(unreachable_code)]
618    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
619        #[cfg(target_feature = "avx")]
620        return super::avx::vec_dot_q8_0_q8_0(n, xs, ys);
621
622        #[cfg(target_feature = "neon")]
623        return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);
624
625        #[cfg(target_feature = "simd128")]
626        return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);
627
628        Self::vec_dot_unopt(n, xs, ys)
629    }
630
631    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
632        let qk = QK8_0;
633        if n % QK8_0 != 0 {
634            crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
635        }
636
637        // Generic implementation.
638        let mut sumf = 0f32;
639        for (xs, ys) in xs.iter().zip(ys.iter()) {
640            let sum_i = xs
641                .qs
642                .iter()
643                .zip(ys.qs.iter())
644                .map(|(&x, &y)| x as i32 * y as i32)
645                .sum::<i32>();
646            sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
647        }
648        Ok(sumf)
649    }
650}
651
652impl GgmlType for BlockQ8_1 {
653    const DTYPE: GgmlDType = GgmlDType::Q8_1;
654    const BLCK_SIZE: usize = QK8_1;
655    type VecDotType = BlockQ8_1;
656
657    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
658        Self::vec_dot_unopt(n, xs, ys)
659    }
660
661    fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
662        unimplemented!("no support for vec-dot on Q8_1")
663    }
664
665    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
666        // quantize_row_q8_1
667        let k = xs.len();
668        if ys.len() * Self::BLCK_SIZE != k {
669            crate::bail!("size mismatch {k} {} {}", ys.len(), Self::BLCK_SIZE)
670        }
671        for (i, ys) in ys.iter_mut().enumerate() {
672            let mut amax = 0f32;
673            let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
674            for &x in xs.iter() {
675                amax = amax.max(x.abs())
676            }
677            let d = amax / ((1 << 7) - 1) as f32;
678            let id = if d != 0f32 { 1. / d } else { 0. };
679            ys.d = f16::from_f32(d);
680            let mut sum = 0i32;
681            for j in 0..Self::BLCK_SIZE / 2 {
682                let v0 = xs[j] * id;
683                let v1 = xs[j + Self::BLCK_SIZE / 2] * id;
684                ys.qs[j] = f32::round(v0) as i8;
685                ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8;
686                sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32;
687            }
688            ys.s = f16::from_f32(sum as f32) * ys.d;
689        }
690        Ok(())
691    }
692
693    fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
694        unimplemented!("no support for vec-dot on Q8_1")
695    }
696}
697
698impl GgmlType for BlockQ2K {
699    const DTYPE: GgmlDType = GgmlDType::Q2K;
700    const BLCK_SIZE: usize = QK_K;
701    type VecDotType = BlockQ8K;
702
703    #[allow(unreachable_code)]
704    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
705        #[cfg(target_feature = "avx")]
706        return super::avx::vec_dot_q2k_q8k(n, xs, ys);
707
708        #[cfg(target_feature = "neon")]
709        return super::neon::vec_dot_q2k_q8k(n, xs, ys);
710
711        #[cfg(target_feature = "simd128")]
712        return super::simd128::vec_dot_q2k_q8k(n, xs, ys);
713
714        Self::vec_dot_unopt(n, xs, ys)
715    }
716
717    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
718        if n % QK_K != 0 {
719            crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
720        }
721
722        let mut sumf = 0.0;
723        for (x, y) in xs.iter().zip(ys.iter()) {
724            let mut q2: &[_] = &x.qs;
725            let mut q8: &[_] = &y.qs;
726            let sc = &x.scales;
727
728            let mut summs = 0;
729            for (bsum, scale) in y.bsums.iter().zip(sc) {
730                summs += *bsum as i32 * ((scale >> 4) as i32);
731            }
732
733            let dall = y.d * x.d.to_f32();
734            let dmin = y.d * x.dmin.to_f32();
735
736            let mut isum = 0;
737            let mut is = 0;
738            for _ in 0..(QK_K / 128) {
739                let mut shift = 0;
740                for _ in 0..4 {
741                    let d = (sc[is] & 0xF) as i32;
742                    is += 1;
743                    let mut isuml = 0;
744                    for l in 0..16 {
745                        isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
746                    }
747                    isum += d * isuml;
748                    let d = (sc[is] & 0xF) as i32;
749                    is += 1;
750                    isuml = 0;
751                    for l in 16..32 {
752                        isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
753                    }
754                    isum += d * isuml;
755                    shift += 2;
756                    // adjust the indexing
757                    q8 = &q8[32..];
758                }
759                // adjust the indexing
760                q2 = &q2[32..];
761            }
762            sumf += dall * isum as f32 - dmin * summs as f32;
763        }
764
765        Ok(sumf)
766    }
767
768    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L279
769    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
770        const Q4SCALE: f32 = 15.0;
771
772        for (block, x) in group_for_quantization(xs, ys)? {
773            //calculate scales and mins
774            let mut mins: [f32; QK_K / 16] = [0.0; QK_K / 16];
775            let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];
776
777            for (j, x_scale_slice) in x.chunks(16).enumerate() {
778                (scales[j], mins[j]) = make_qkx1_quants(3, 5, x_scale_slice);
779            }
780            // get max scale and max min and ensure they are >= 0.0
781            let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
782            let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));
783
784            if max_scale > 0.0 {
785                let iscale = Q4SCALE / max_scale;
786                for (j, scale) in scales.iter().enumerate().take(QK_K / 16) {
787                    block.scales[j] = nearest_int(iscale * scale) as u8;
788                }
789                block.d = f16::from_f32(max_scale / Q4SCALE);
790            } else {
791                for j in 0..QK_K / 16 {
792                    block.scales[j] = 0;
793                }
794                block.d = f16::from_f32(0.0);
795            }
796
797            if max_min > 0.0 {
798                let iscale = Q4SCALE / max_min;
799                for (j, scale) in block.scales.iter_mut().enumerate() {
800                    let l = nearest_int(iscale * mins[j]) as u8;
801                    *scale |= l << 4;
802                }
803                block.dmin = f16::from_f32(max_min / Q4SCALE);
804            } else {
805                block.dmin = f16::from_f32(0.0);
806            }
807
808            let mut big_l: [u8; QK_K] = [0; QK_K];
809
810            for j in 0..QK_K / 16 {
811                let d = block.d.to_f32() * (block.scales[j] & 0xF) as f32;
812                if d == 0.0 {
813                    continue;
814                }
815                let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32;
816                for ii in 0..16 {
817                    let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3);
818                    big_l[16 * j + ii] = ll as u8;
819                }
820            }
821
822            for j in (0..QK_K).step_by(128) {
823                for ll in 0..32 {
824                    block.qs[j / 4 + ll] = big_l[j + ll]
825                        | (big_l[j + ll + 32] << 2)
826                        | (big_l[j + ll + 64] << 4)
827                        | (big_l[j + ll + 96] << 6);
828                }
829            }
830        }
831        Ok(())
832    }
833    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
834    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
835        for (block, y) in group_for_dequantization(xs, ys)? {
836            let d = block.d.to_f32();
837            let min = block.dmin.to_f32();
838
839            let mut is = 0;
840
841            for (y_block, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) {
842                // Step by 32 over q.
843                let mut shift = 0;
844                let mut y_block_index = 0;
845                for _j in 0..4 {
846                    let sc = block.scales[is];
847                    is += 1;
848                    let dl = d * (sc & 0xF) as f32;
849                    let ml = min * (sc >> 4) as f32;
850                    for q in &qs[..16] {
851                        let y = dl * ((q >> shift) & 3) as f32 - ml;
852                        y_block[y_block_index] = y;
853                        y_block_index += 1;
854                    }
855
856                    let sc = block.scales[is];
857                    is += 1;
858                    let dl = d * (sc & 0xF) as f32;
859                    let ml = min * (sc >> 4) as f32;
860                    for q in &qs[16..] {
861                        let y = dl * ((q >> shift) & 3) as f32 - ml;
862                        y_block[y_block_index] = y;
863                        y_block_index += 1;
864                    }
865
866                    shift += 2;
867                }
868            }
869        }
870        Ok(())
871    }
872}
873
874impl GgmlType for BlockQ3K {
875    const DTYPE: GgmlDType = GgmlDType::Q3K;
876    const BLCK_SIZE: usize = QK_K;
877    type VecDotType = BlockQ8K;
878
879    #[allow(unreachable_code)]
880    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
881        #[cfg(target_feature = "avx")]
882        return super::avx::vec_dot_q3k_q8k(n, xs, ys);
883
884        #[cfg(target_feature = "neon")]
885        return super::neon::vec_dot_q3k_q8k(n, xs, ys);
886
887        Self::vec_dot_unopt(n, xs, ys)
888    }
889
890    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
891        if n % QK_K != 0 {
892            crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
893        }
894
895        const KMASK1: u32 = 0x03030303;
896        const KMASK2: u32 = 0x0f0f0f0f;
897
898        let mut aux8: [i8; QK_K] = [0; QK_K];
899        let mut aux16: [i16; 8] = [0; 8];
900        let mut sums: [f32; 8] = [0.0; 8];
901        let mut aux32: [i32; 8] = [0; 8];
902
903        let mut auxs: [u32; 4] = [0; 4];
904
905        for (x, y) in xs.iter().zip(ys.iter()) {
906            let mut q3: &[u8] = &x.qs;
907            let hmask: &[u8] = &x.hmask;
908            let mut q8: &[i8] = &y.qs;
909
910            aux32.fill(0);
911            let mut a = &mut aux8[..];
912
913            let mut m = 1;
914            //Like the GGML original this is written this way to enable the compiler to vectorize it.
915            for _ in 0..QK_K / 128 {
916                a.iter_mut()
917                    .take(32)
918                    .zip(q3)
919                    .for_each(|(a_val, q3_val)| *a_val = (q3_val & 3) as i8);
920                a.iter_mut()
921                    .take(32)
922                    .zip(hmask)
923                    .for_each(|(a_val, hmask_val)| {
924                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
925                    });
926                a = &mut a[32..];
927                m <<= 1;
928
929                a.iter_mut()
930                    .take(32)
931                    .zip(q3)
932                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 2) & 3) as i8);
933                a.iter_mut()
934                    .take(32)
935                    .zip(hmask)
936                    .for_each(|(a_val, hmask_val)| {
937                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
938                    });
939                a = &mut a[32..];
940                m <<= 1;
941
942                a.iter_mut()
943                    .take(32)
944                    .zip(q3)
945                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 4) & 3) as i8);
946                a.iter_mut()
947                    .take(32)
948                    .zip(hmask)
949                    .for_each(|(a_val, hmask_val)| {
950                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
951                    });
952                a = &mut a[32..];
953                m <<= 1;
954
955                a.iter_mut()
956                    .take(32)
957                    .zip(q3)
958                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 6) & 3) as i8);
959                a.iter_mut()
960                    .take(32)
961                    .zip(hmask)
962                    .for_each(|(a_val, hmask_val)| {
963                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
964                    });
965                a = &mut a[32..];
966                m <<= 1;
967                q3 = &q3[32..];
968            }
969
970            a = &mut aux8[..];
971
972            LittleEndian::read_u32_into(&x.scales, &mut auxs[0..3]);
973
974            let tmp = auxs[2];
975            auxs[2] = ((auxs[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4);
976            auxs[3] = ((auxs[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4);
977            auxs[0] = (auxs[0] & KMASK2) | (((tmp) & KMASK1) << 4);
978            auxs[1] = (auxs[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4);
979
980            for aux in auxs {
981                for scale in aux.to_le_bytes() {
982                    let scale = i8::from_be_bytes([scale]);
983                    for l in 0..8 {
984                        aux16[l] = q8[l] as i16 * a[l] as i16;
985                    }
986                    for l in 0..8 {
987                        aux32[l] += (scale as i32 - 32) * aux16[l] as i32;
988                    }
989                    q8 = &q8[8..];
990                    a = &mut a[8..];
991
992                    for l in 0..8 {
993                        aux16[l] = q8[l] as i16 * a[l] as i16;
994                    }
995                    for l in 0..8 {
996                        aux32[l] += (scale as i32 - 32) * aux16[l] as i32;
997                    }
998                    q8 = &q8[8..];
999                    a = &mut a[8..];
1000                }
1001            }
1002            let d = x.d.to_f32() * y.d;
1003            for l in 0..8 {
1004                sums[l] += d * aux32[l] as f32;
1005            }
1006        }
1007
1008        Ok(sums.iter().sum())
1009    }
1010
1011    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
1012        for (block, x) in group_for_quantization(xs, ys)? {
1013            let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];
1014            for (j, x_scale_slice) in x.chunks_exact(16).enumerate() {
1015                scales[j] = make_q3_quants(x_scale_slice, 4, true);
1016            }
1017
1018            // Get max scale by absolute value.
1019            let mut max_scale: f32 = 0.0;
1020            for &scale in scales.iter() {
1021                if scale.abs() > max_scale.abs() {
1022                    max_scale = scale;
1023                }
1024            }
1025
1026            block.scales.fill(0);
1027
1028            if max_scale != 0.0 {
1029                let iscale = -32.0 / max_scale;
1030                for (j, scale) in scales.iter().enumerate() {
1031                    let l_val = nearest_int(iscale * scale);
1032                    let l_val = l_val.clamp(-32, 31) + 32;
1033                    if j < 8 {
1034                        block.scales[j] = (l_val & 0xF) as u8;
1035                    } else {
1036                        block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8;
1037                    }
1038                    let l_val = l_val >> 4;
1039                    block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8;
1040                }
1041                block.d = f16::from_f32(1.0 / iscale);
1042            } else {
1043                block.d = f16::from_f32(0.0);
1044            }
1045
1046            let mut l: [i8; QK_K] = [0; QK_K];
1047
1048            for j in 0..QK_K / 16 {
1049                let sc = if j < 8 {
1050                    block.scales[j] & 0xF
1051                } else {
1052                    block.scales[j - 8] >> 4
1053                };
1054                let sc = (sc | (((block.scales[8 + j % 4] >> (2 * (j / 4))) & 3) << 4)) as i8 - 32;
1055                let d = block.d.to_f32() * sc as f32;
1056                if d != 0.0 {
1057                    for ii in 0..16 {
1058                        let l_val = nearest_int(x[16 * j + ii] / d);
1059                        l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8;
1060                    }
1061                }
1062            }
1063
1064            block.hmask.fill(0);
1065            let mut m = 0;
1066            let mut hm = 1;
1067
1068            for ll in l.iter_mut() {
1069                if *ll > 3 {
1070                    block.hmask[m] |= hm;
1071                    *ll -= 4;
1072                }
1073                m += 1;
1074                if m == QK_K / 8 {
1075                    m = 0;
1076                    hm <<= 1;
1077                }
1078            }
1079
1080            for j in (0..QK_K).step_by(128) {
1081                for l_val in 0..32 {
1082                    block.qs[j / 4 + l_val] = (l[j + l_val]
1083                        | (l[j + l_val + 32] << 2)
1084                        | (l[j + l_val + 64] << 4)
1085                        | (l[j + l_val + 96] << 6))
1086                        as u8;
1087                }
1088            }
1089        }
1090
1091        Ok(())
1092    }
1093
1094    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
1095    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
1096        const KMASK1: u32 = 0x03030303;
1097        const KMASK2: u32 = 0x0f0f0f0f;
1098
1099        for (block, y) in group_for_dequantization(xs, ys)? {
1100            //Reconstruct the scales
1101            let mut aux = [0; 4];
1102            LittleEndian::read_u32_into(&block.scales, &mut aux[0..3]);
1103
1104            let tmp = aux[2];
1105            aux[2] = ((aux[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4);
1106            aux[3] = ((aux[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4);
1107            aux[0] = (aux[0] & KMASK2) | (((tmp) & KMASK1) << 4);
1108            aux[1] = (aux[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4);
1109
1110            //Transfer the scales into an i8 array
1111            let scales: &mut [i8] =
1112                unsafe { std::slice::from_raw_parts_mut(aux.as_mut_ptr() as *mut i8, 16) };
1113
1114            let d_all = block.d.to_f32();
1115            let mut m = 1;
1116            let mut is = 0;
1117
1118            // Dequantize both 128 long blocks
1119            // 32 qs values per 128 long block
1120            // Each 16 elements get a scale
1121            for (y, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) {
1122                let mut shift = 0;
1123                for shift_scoped_y in y.chunks_exact_mut(32) {
1124                    for (scale_index, scale_scoped_y) in
1125                        shift_scoped_y.chunks_exact_mut(16).enumerate()
1126                    {
1127                        let dl = d_all * (scales[is] as f32 - 32.0);
1128                        for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
1129                            let new_y = dl
1130                                * (((qs[i + 16 * scale_index] >> shift) & 3) as i8
1131                                    - if (block.hmask[i + 16 * scale_index] & m) == 0 {
1132                                        4
1133                                    } else {
1134                                        0
1135                                    }) as f32;
1136                            *inner_y = new_y;
1137                        }
1138                        // 16 block finished => advance scale index
1139                        is += 1;
1140                    }
1141                    // 32 block finished => increase shift and m
1142                    shift += 2;
1143                    m <<= 1;
1144                }
1145            }
1146        }
1147
1148        Ok(())
1149    }
1150}
1151
/// Q4K: 4-bit k-quantization over super-blocks of `QK_K` values.
///
/// Each 32-value sub-block has a 6-bit scale and a 6-bit min, packed into the 12 `scales`
/// bytes. A value decodes as `d * sc * q - dmin * m`, where `q` is a 4-bit code from `qs`.
impl GgmlType for BlockQ4K {
    const DTYPE: GgmlDType = GgmlDType::Q4K;
    const BLCK_SIZE: usize = QK_K;
    type VecDotType = BlockQ8K;

    /// Dot product against a Q8K row.
    ///
    /// Returns through a SIMD kernel when `avx`, `neon` or `simd128` is enabled at compile time,
    /// otherwise falls back to `vec_dot_unopt`.
    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        #[cfg(target_feature = "avx")]
        return super::avx::vec_dot_q4k_q8k(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q4k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q4k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Portable dot product of Q4K blocks with Q8K blocks.
    ///
    /// Blocks are paired with `zip`; `n` is only checked for divisibility by `QK_K`.
    ///
    /// # Errors
    /// Fails if `n` is not a multiple of `QK_K`.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
        }

        // Masks used to unpack the 12 packed bytes into eight 6-bit scales and eight 6-bit mins.
        const KMASK1: u32 = 0x3f3f3f3f;
        const KMASK2: u32 = 0x0f0f0f0f;
        const KMASK3: u32 = 0x03030303;

        let mut utmp: [u32; 4] = [0; 4];
        let mut scales: [u8; 8] = [0; 8];
        let mut mins: [u8; 8] = [0; 8];

        let mut aux8: [i8; QK_K] = [0; QK_K];
        let mut aux16: [i16; 8] = [0; 8];
        let mut sums: [f32; 8] = [0.0; 8];
        let mut aux32: [i32; 8] = [0; 8];

        let mut sumf = 0.0;
        for (y, x) in ys.iter().zip(xs.iter()) {
            let q4 = &x.qs;
            let q8 = &y.qs;
            aux32.fill(0);

            // Unpack the 4-bit codes into `aux8`: each 32-byte run of `qs` yields 32 low-nibble
            // values followed by 32 high-nibble values.
            let mut a = &mut aux8[..];
            let mut q4 = &q4[..];
            for _ in 0..QK_K / 64 {
                for l in 0..32 {
                    a[l] = (q4[l] & 0xF) as i8;
                }
                a = &mut a[32..];
                for l in 0..32 {
                    a[l] = (q4[l] >> 4) as i8;
                }
                a = &mut a[32..];
                q4 = &q4[32..];
            }

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            // Rearrange the packed fields so that the bytes of utmp[0..2] are the scales and the
            // bytes of utmp[2..4] are the mins.
            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            //extract scales and mins
            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);

            // Min correction: 16-value group `j` of q8 is weighted by the min of sub-block j / 2.
            // NOTE(review): assumes `y.bsums[j]` is the sum of those 16 q8 values — confirm
            // against the BlockQ8K quantizer.
            let mut sumi = 0;
            for j in 0..QK_K / 16 {
                sumi += y.bsums[j] as i32 * mins[j / 2] as i32;
            }

            let mut a = &mut aux8[..];
            let mut q8 = &q8[..];

            // Scaled products: each scale covers 32 values (4 rounds of 8 lanes).
            for scale in scales {
                let scale = scale as i32;
                for _ in 0..4 {
                    for l in 0..8 {
                        aux16[l] = q8[l] as i16 * a[l] as i16;
                    }
                    for l in 0..8 {
                        aux32[l] += scale * aux16[l] as i32;
                    }
                    q8 = &q8[8..];
                    a = &mut a[8..];
                }
            }
            let d = x.d.to_f32() * y.d;
            for l in 0..8 {
                sums[l] += d * aux32[l] as f32;
            }
            let dmin = x.dmin.to_f32() * y.d;
            sumf -= dmin * sumi as f32;
        }
        Ok(sumf + sums.iter().sum::<f32>())
    }

    /// Quantizes `xs` into Q4K blocks.
    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
        for (block, x) in group_for_quantization(xs, ys)? {
            let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
            let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];

            // One (scale, min) pair per 32 values.
            for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
                (scales[j], mins[j]) = make_qkx1_quants(15, 5, x_scale_slice);
            }

            // get max scale and max min and ensure they are >= 0.0
            let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
            let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));

            let inv_scale = if max_scale > 0.0 {
                63.0 / max_scale
            } else {
                0.0
            };
            let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };

            // Quantize scales/mins to 6 bits and pack them into the 12 `scales` bytes:
            // - j < 4: stored whole in byte j (scale) and byte j + 4 (min);
            // - j >= 4: low nibbles share byte j + 4, and the top 2 bits go into bits 6-7 of
            //   byte j - 4 (scale) and byte j (min).
            for j in 0..QK_K / 32 {
                let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
                let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
                if j < 4 {
                    block.scales[j] = ls;
                    block.scales[j + 4] = lm;
                } else {
                    block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4);
                    block.scales[j - 4] |= (ls >> 4) << 6;
                    block.scales[j] |= (lm >> 4) << 6;
                }
            }

            block.d = f16::from_f32(max_scale / 63.0);
            block.dmin = f16::from_f32(max_min / 63.0);

            // Re-quantize each value against the decoded scale/min as a code in 0..=15.
            // Codes of sub-blocks whose decoded scale is zero are left at 0.
            let mut l: [u8; QK_K] = [0; QK_K];

            for j in 0..QK_K / 32 {
                let (sc, m) = get_scale_min_k4(j, &block.scales);
                let d = block.d.to_f32() * sc as f32;
                if d != 0.0 {
                    let dm = block.dmin.to_f32() * m as f32;
                    for ii in 0..32 {
                        let l_val = nearest_int((x[32 * j + ii] + dm) / d);
                        l[32 * j + ii] = l_val.clamp(0, 15) as u8;
                    }
                }
            }

            // Pack pairs: for each 64-value group, byte `l_val` holds value `l_val` in the low
            // nibble and value `l_val + 32` in the high nibble.
            let q = &mut block.qs;
            for j in (0..QK_K).step_by(64) {
                for l_val in 0..32 {
                    let offset_index = (j / 64) * 32 + l_val;
                    q[offset_index] = l[j + l_val] | (l[j + l_val + 32] << 4);
                }
            }
        }
        Ok(())
    }
    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
    /// Dequantizes Q4K blocks into `ys` (`QK_K` floats per block).
    ///
    /// For each 64-value group, 32 `qs` bytes produce 32 low-nibble values using scale/min
    /// `is`, then 32 high-nibble values using scale/min `is + 1`.
    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
        for (block, y) in group_for_dequantization(xs, ys)? {
            let d = block.d.to_f32();
            let min = block.dmin.to_f32();
            let q = &block.qs;
            let mut is = 0;
            let mut ys_index = 0;

            for j in (0..QK_K).step_by(64) {
                let q = &q[j / 2..j / 2 + 32];
                let (sc, m) = get_scale_min_k4(is, &block.scales);
                let d1 = d * sc as f32;
                let m1 = min * m as f32;
                let (sc, m) = get_scale_min_k4(is + 1, &block.scales);
                let d2 = d * sc as f32;
                let m2 = min * m as f32;
                for q in q {
                    y[ys_index] = d1 * (q & 0xF) as f32 - m1;
                    ys_index += 1;
                }
                for q in q {
                    y[ys_index] = d2 * (q >> 4) as f32 - m2;
                    ys_index += 1;
                }
                is += 2;
            }
        }
        Ok(())
    }
}
1343
// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
/// Q5K: 5-bit k-quantization over super-blocks of `QK_K` values.
///
/// Each value is a 4-bit code in `qs` plus one high bit in `qh` (worth +16). Each 32-value
/// sub-block has a 6-bit scale and a 6-bit min, packed into `scales` exactly as for Q4K.
/// A value decodes as `d * sc * q - dmin * m`.
impl GgmlType for BlockQ5K {
    const DTYPE: GgmlDType = GgmlDType::Q5K;
    const BLCK_SIZE: usize = QK_K;
    type VecDotType = BlockQ8K;

    /// Dot product against a Q8K row.
    ///
    /// Returns through a SIMD kernel when `avx` or `neon` is enabled at compile time, otherwise
    /// falls back to `vec_dot_unopt`.
    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        #[cfg(target_feature = "avx")]
        return super::avx::vec_dot_q5k_q8k(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q5k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Portable dot product of Q5K blocks with Q8K blocks.
    ///
    /// Blocks are paired with `zip`; `n` is only checked for divisibility by `QK_K`.
    ///
    /// # Errors
    /// Fails if `n` is not a multiple of `QK_K`.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
        }

        // Masks used to unpack the 12 packed bytes into eight 6-bit scales and eight 6-bit mins.
        const KMASK1: u32 = 0x3f3f3f3f;
        const KMASK2: u32 = 0x0f0f0f0f;
        const KMASK3: u32 = 0x03030303;

        let mut utmp: [u32; 4] = [0; 4];
        let mut scales: [u8; 8] = [0; 8];
        let mut mins: [u8; 8] = [0; 8];

        let mut aux8: [i8; QK_K] = [0; QK_K];
        let mut aux16: [i16; 8] = [0; 8];
        let mut sums: [f32; 8] = [0.0; 8];
        let mut aux32: [i32; 8] = [0; 8];

        let mut sumf = 0.0;
        for (y, x) in ys.iter().zip(xs.iter()) {
            let q5 = &x.qs;
            let hm = &x.qh;
            let q8 = &y.qs;
            aux32.fill(0);

            // Decode the 5-bit codes into `aux8`: a nibble from `qs` plus 16 when bit `m` of
            // `qh[l]` is set. `m` advances one bit per 32 outputs.
            let mut a = &mut aux8[..];
            let mut q5 = &q5[..];
            let mut m = 1u8;

            for _ in 0..QK_K / 64 {
                for l in 0..32 {
                    a[l] = (q5[l] & 0xF) as i8;
                    a[l] += if hm[l] & m != 0 { 16 } else { 0 };
                }
                a = &mut a[32..];
                m <<= 1;
                for l in 0..32 {
                    a[l] = (q5[l] >> 4) as i8;
                    a[l] += if hm[l] & m != 0 { 16 } else { 0 };
                }
                a = &mut a[32..];
                m <<= 1;
                q5 = &q5[32..];
            }

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            // Rearrange the packed fields so that the bytes of utmp[0..2] are the scales and the
            // bytes of utmp[2..4] are the mins.
            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            //extract scales and mins
            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);

            // Min correction: 16-value group `j` of q8 is weighted by the min of sub-block j / 2.
            // NOTE(review): assumes `y.bsums[j]` is the sum of those 16 q8 values — confirm
            // against the BlockQ8K quantizer.
            let mut sumi = 0;
            for j in 0..QK_K / 16 {
                sumi += y.bsums[j] as i32 * mins[j / 2] as i32;
            }

            let mut a = &mut aux8[..];
            let mut q8 = &q8[..];

            // Scaled products: each scale covers 32 values (4 rounds of 8 lanes).
            for scale in scales {
                let scale = scale as i32;
                for _ in 0..4 {
                    for l in 0..8 {
                        aux16[l] = q8[l] as i16 * a[l] as i16;
                    }
                    for l in 0..8 {
                        aux32[l] += scale * aux16[l] as i32;
                    }
                    q8 = &q8[8..];
                    a = &mut a[8..];
                }
            }
            let d = x.d.to_f32() * y.d;
            for l in 0..8 {
                sums[l] += d * aux32[l] as f32;
            }
            let dmin = x.dmin.to_f32() * y.d;
            sumf -= dmin * sumi as f32;
        }
        Ok(sumf + sums.iter().sum::<f32>())
    }

    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L793
    /// Quantizes `xs` into Q5K blocks.
    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
        for (block, x) in group_for_quantization(xs, ys)? {
            let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
            let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];

            // One (scale, min) pair per 32 values, using 5-bit codes (nmax = 31).
            for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
                (scales[j], mins[j]) = make_qkx1_quants(31, 5, x_scale_slice);
            }

            // get max scale and max min and ensure they are >= 0.0
            let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
            let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));

            let inv_scale = if max_scale > 0.0 {
                63.0 / max_scale
            } else {
                0.0
            };
            let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
            // Quantize scales/mins to 6 bits and pack them into the 12 `scales` bytes:
            // - j < 4: stored whole in byte j (scale) and byte j + 4 (min);
            // - j >= 4: low nibbles share byte j + 4, and the top 2 bits go into bits 6-7 of
            //   byte j - 4 (scale) and byte j (min).
            for j in 0..QK_K / 32 {
                let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
                let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
                if j < 4 {
                    block.scales[j] = ls;
                    block.scales[j + 4] = lm;
                } else {
                    block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4);
                    block.scales[j - 4] |= (ls >> 4) << 6;
                    block.scales[j] |= (lm >> 4) << 6;
                }
            }
            block.d = f16::from_f32(max_scale / 63.0);
            block.dmin = f16::from_f32(max_min / 63.0);

            // Re-quantize each value against the decoded scale/min as a code in 0..=31.
            // Codes of sub-blocks whose decoded scale is zero are left at 0.
            let mut l: [u8; QK_K] = [0; QK_K];
            for j in 0..QK_K / 32 {
                let (sc, m) = get_scale_min_k4(j, &block.scales);
                let d = block.d.to_f32() * sc as f32;
                if d == 0.0 {
                    continue;
                }
                let dm = block.dmin.to_f32() * m as f32;
                for ii in 0..32 {
                    let ll = nearest_int((x[32 * j + ii] + dm) / d);
                    l[32 * j + ii] = ll.clamp(0, 31) as u8;
                }
            }

            let qh = &mut block.qh;
            let ql = &mut block.qs;
            qh.fill(0);

            // For each 64-value group `g`: low nibble = value j, high nibble = value j + 32.
            // Their high bits (16) go into bit 2g (`m1`) and bit 2g + 1 (`m2`) of `qh[j]`.
            let mut m1 = 1;
            let mut m2 = 2;
            for n in (0..QK_K).step_by(64) {
                let offset = (n / 64) * 32;
                for j in 0..32 {
                    let mut l1 = l[n + j];
                    if l1 > 15 {
                        l1 -= 16;
                        qh[j] |= m1;
                    }
                    let mut l2 = l[n + j + 32];
                    if l2 > 15 {
                        l2 -= 16;
                        qh[j] |= m2;
                    }
                    ql[offset + j] = l1 | (l2 << 4);
                }
                m1 <<= 2;
                m2 <<= 2;
            }
        }

        Ok(())
    }

    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
    /// Dequantizes Q5K blocks into `ys` (`QK_K` floats per block).
    ///
    /// For each 64-value group, the low nibbles (plus 16 if bit `u1` of `qh` is set) use
    /// scale/min `is`; the high nibbles (plus 16 if bit `u2` is set) use scale/min `is + 1`.
    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
        for (block, y) in group_for_dequantization(xs, ys)? {
            let d = block.d.to_f32();
            let min = block.dmin.to_f32();
            let ql = &block.qs;
            let qh = &block.qh;
            let mut is = 0;
            let mut u1 = 1;
            let mut u2 = 2;
            let mut ys_index = 0;

            for j in (0..QK_K).step_by(64) {
                let ql = &ql[j / 2..j / 2 + 32];
                let (sc, m) = get_scale_min_k4(is, &block.scales);
                let d1 = d * sc as f32;
                let m1 = min * m as f32;
                let (sc, m) = get_scale_min_k4(is + 1, &block.scales);
                let d2 = d * sc as f32;
                let m2 = min * m as f32;
                for (ql, qh) in ql.iter().zip(qh) {
                    let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
                    y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
                    ys_index += 1;
                }
                for (ql, qh) in ql.iter().zip(qh) {
                    let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
                    y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
                    ys_index += 1;
                }
                is += 2;
                u1 <<= 2;
                u2 <<= 2;
            }
        }
        Ok(())
    }
}
1565
1566impl GgmlType for BlockQ6K {
1567    const DTYPE: GgmlDType = GgmlDType::Q6K;
1568    const BLCK_SIZE: usize = QK_K;
1569    type VecDotType = BlockQ8K;
1570
1571    #[allow(unreachable_code)]
1572    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1573        #[cfg(target_feature = "avx")]
1574        return super::avx::vec_dot_q6k_q8k(n, xs, ys);
1575
1576        #[cfg(target_feature = "neon")]
1577        return super::neon::vec_dot_q6k_q8k(n, xs, ys);
1578
1579        #[cfg(target_feature = "simd128")]
1580        return super::simd128::vec_dot_q6k_q8k(n, xs, ys);
1581
1582        Self::vec_dot_unopt(n, xs, ys)
1583    }
1584
1585    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1586        if n % QK_K != 0 {
1587            crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
1588        }
1589
1590        let mut aux8 = [0i8; QK_K];
1591        let mut aux16 = [0i16; 8];
1592        let mut sums = [0f32; 8];
1593        let mut aux32 = [0f32; 8];
1594
1595        for (x, y) in xs.iter().zip(ys.iter()) {
1596            let q4 = &x.ql;
1597            let qh = &x.qh;
1598            let q8 = &y.qs;
1599            aux32.fill(0f32);
1600
1601            for j in (0..QK_K).step_by(128) {
1602                let aux8 = &mut aux8[j..];
1603                let q4 = &q4[j / 2..];
1604                let qh = &qh[j / 4..];
1605                for l in 0..32 {
1606                    aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
1607                    aux8[l + 32] =
1608                        (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
1609                    aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
1610                    aux8[l + 96] =
1611                        (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
1612                }
1613            }
1614
1615            for (j, &scale) in x.scales.iter().enumerate() {
1616                let scale = scale as f32;
1617                let q8 = &q8[16 * j..];
1618                let aux8 = &aux8[16 * j..];
1619                for l in 0..8 {
1620                    aux16[l] = q8[l] as i16 * aux8[l] as i16;
1621                }
1622                for l in 0..8 {
1623                    aux32[l] += scale * aux16[l] as f32
1624                }
1625                let q8 = &q8[8..];
1626                let aux8 = &aux8[8..];
1627                for l in 0..8 {
1628                    aux16[l] = q8[l] as i16 * aux8[l] as i16;
1629                }
1630                for l in 0..8 {
1631                    aux32[l] += scale * aux16[l] as f32
1632                }
1633            }
1634
1635            let d = x.d.to_f32() * y.d;
1636            for (sum, &a) in sums.iter_mut().zip(aux32.iter()) {
1637                *sum += a * d;
1638            }
1639        }
1640        Ok(sums.iter().sum())
1641    }
1642
1643    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
1644        if xs.len() != ys.len() * Self::BLCK_SIZE {
1645            crate::bail!(
1646                "quantize_row_q6k: size mismatch {} {} {}",
1647                xs.len(),
1648                ys.len(),
1649                Self::BLCK_SIZE
1650            )
1651        }
1652        let mut l = [0i8; QK_K];
1653        let mut scales = [0f32; QK_K / 16];
1654        let mut x = xs.as_ptr();
1655        let l = l.as_mut_ptr();
1656        unsafe {
1657            for y in ys.iter_mut() {
1658                let mut max_scale = 0f32;
1659                let mut max_abs_scale = 0f32;
1660                for (ib, scale_) in scales.iter_mut().enumerate() {
1661                    let scale = make_qx_quants(16, 32, x.add(16 * ib), l.add(16 * ib), 1);
1662                    *scale_ = scale;
1663                    let abs_scale = scale.abs();
1664                    if abs_scale > max_abs_scale {
1665                        max_abs_scale = abs_scale;
1666                        max_scale = scale
1667                    }
1668                }
1669
1670                let iscale = -128f32 / max_scale;
1671                y.d = f16::from_f32(1.0 / iscale);
1672
1673                for (y_scale, scale) in y.scales.iter_mut().zip(scales.iter()) {
1674                    *y_scale = nearest_int(iscale * scale).min(127) as i8
1675                }
1676
1677                for (j, &y_scale) in y.scales.iter().enumerate() {
1678                    let d = y.d.to_f32() * y_scale as f32;
1679                    if d == 0. {
1680                        continue;
1681                    }
1682                    for ii in 0..16 {
1683                        let ll = nearest_int(*x.add(16 * j + ii) / d).clamp(-32, 31);
1684                        *l.add(16 * j + ii) = (ll + 32) as i8
1685                    }
1686                }
1687
1688                let mut ql = y.ql.as_mut_ptr();
1689                let mut qh = y.qh.as_mut_ptr();
1690
1691                for j in (0..QK_K).step_by(128) {
1692                    for l_idx in 0..32 {
1693                        let q1 = *l.add(j + l_idx) & 0xF;
1694                        let q2 = *l.add(j + l_idx + 32) & 0xF;
1695                        let q3 = *l.add(j + l_idx + 64) & 0xF;
1696                        let q4 = *l.add(j + l_idx + 96) & 0xF;
1697                        *ql.add(l_idx) = (q1 | (q3 << 4)) as u8;
1698                        *ql.add(l_idx + 32) = (q2 | (q4 << 4)) as u8;
1699                        *qh.add(l_idx) = ((*l.add(j + l_idx) >> 4)
1700                            | ((*l.add(j + l_idx + 32) >> 4) << 2)
1701                            | ((*l.add(j + l_idx + 64) >> 4) << 4)
1702                            | ((*l.add(j + l_idx + 96) >> 4) << 6))
1703                            as u8;
1704                    }
1705                    ql = ql.add(64);
1706                    qh = qh.add(32);
1707                }
1708
1709                x = x.add(QK_K)
1710            }
1711        }
1712        Ok(())
1713    }
1714
1715    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
1716    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
1717        let k = ys.len();
1718        if k % QK_K != 0 {
1719            crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
1720        }
1721        for (idx_x, x) in xs.iter().enumerate() {
1722            let d = x.d.to_f32();
1723            let ql = &x.ql;
1724            let qh = &x.qh;
1725            let sc = &x.scales;
1726            for n in (0..QK_K).step_by(128) {
1727                let idx = n / 128;
1728                let ys = &mut ys[idx_x * QK_K + n..];
1729                let sc = &sc[8 * idx..];
1730                let ql = &ql[64 * idx..];
1731                let qh = &qh[32 * idx..];
1732                for l in 0..32 {
1733                    let is = l / 16;
1734                    let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
1735                    let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
1736                    let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
1737                    let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
1738                    ys[l] = d * sc[is] as f32 * q1 as f32;
1739                    ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
1740                    ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
1741                    ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
1742                }
1743            }
1744        }
1745        Ok(())
1746    }
1747}
1748
1749impl GgmlType for BlockQ8K {
1750    const DTYPE: GgmlDType = GgmlDType::Q8K;
1751    const BLCK_SIZE: usize = QK_K;
1752    type VecDotType = BlockQ8K;
1753
1754    #[allow(unreachable_code)]
1755    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1756        #[cfg(target_feature = "avx")]
1757        return super::avx::vec_dot_q8k_q8k(n, xs, ys);
1758
1759        #[cfg(target_feature = "neon")]
1760        return super::neon::vec_dot_q8k_q8k(n, xs, ys);
1761
1762        #[cfg(target_feature = "simd128")]
1763        return super::simd128::vec_dot_q8k_q8k(n, xs, ys);
1764
1765        Self::vec_dot_unopt(n, xs, ys)
1766    }
1767
1768    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1769        let qk = QK_K;
1770        if n % QK_K != 0 {
1771            crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
1772        }
1773
1774        // Generic implementation.
1775        let mut sumf = 0f32;
1776        for (xs, ys) in xs.iter().zip(ys.iter()) {
1777            let sum_i = xs
1778                .qs
1779                .iter()
1780                .zip(ys.qs.iter())
1781                .map(|(&x, &y)| x as i32 * y as i32)
1782                .sum::<i32>();
1783            sumf += sum_i as f32 * xs.d * ys.d
1784        }
1785        Ok(sumf)
1786    }
1787
1788    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
1789        let k = xs.len();
1790        if k % QK_K != 0 {
1791            crate::bail!("quantize_row_q8k: {k} is not divisible by {QK_K}")
1792        }
1793        for (i, y) in ys.iter_mut().enumerate() {
1794            let mut max = 0f32;
1795            let mut amax = 0f32;
1796            let xs = &xs[i * QK_K..(i + 1) * QK_K];
1797            for &x in xs.iter() {
1798                if amax < x.abs() {
1799                    amax = x.abs();
1800                    max = x;
1801                }
1802            }
1803            if amax == 0f32 {
1804                y.d = 0f32;
1805                y.qs.fill(0)
1806            } else {
1807                let iscale = -128f32 / max;
1808                for (j, q) in y.qs.iter_mut().enumerate() {
1809                    // ggml uses nearest_int with bit magic here, maybe we want the same
1810                    // but we would have to test and benchmark it.
1811                    let v = (iscale * xs[j]).round();
1812                    *q = v.min(127.) as i8
1813                }
1814                for j in 0..QK_K / 16 {
1815                    let mut sum = 0i32;
1816                    for ii in 0..16 {
1817                        sum += y.qs[j * 16 + ii] as i32
1818                    }
1819                    y.bsums[j] = sum as i16
1820                }
1821                y.d = 1.0 / iscale
1822            }
1823        }
1824        Ok(())
1825    }
1826
1827    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
1828        let k = ys.len();
1829        if k % QK_K != 0 {
1830            crate::bail!("dequantize_row_q8k: {k} is not divisible by {QK_K}")
1831        }
1832        for (i, x) in xs.iter().enumerate() {
1833            for (j, &q) in x.qs.iter().enumerate() {
1834                ys[i * QK_K + j] = x.d * q as f32
1835            }
1836        }
1837        Ok(())
1838    }
1839}
1840
1841// https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L10605
1842pub fn matmul<T: GgmlType>(
1843    mkn: (usize, usize, usize),
1844    lhs: &[f32],
1845    rhs_t: &[T],
1846    dst: &mut [f32],
1847) -> Result<()> {
1848    let (m, k, n) = mkn;
1849    if m * k != lhs.len() {
1850        crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
1851    }
1852
1853    let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE);
1854    let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE);
1855    // TODO: Do not make this copy if the DotType is f32.
1856    // TODO: Pre-allocate this.
1857    let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];
1858    for row_idx in 0..m {
1859        let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
1860        let lhs = &lhs[row_idx * k..(row_idx + 1) * k];
1861        T::VecDotType::from_float(lhs, lhs_b)?
1862    }
1863    let lhs_b = lhs_b.as_slice();
1864
1865    for row_idx in 0..m {
1866        let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
1867        let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
1868
1869        let result: Result<Vec<_>> = dst_row
1870            .into_par_iter()
1871            .enumerate()
1872            .with_min_len(128)
1873            .with_max_len(512)
1874            .map(|(col_idx, dst)| {
1875                let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
1876                T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value)
1877            })
1878            .collect();
1879
1880        result?;
1881    }
1882    Ok(())
1883}
1884
1885impl GgmlType for f32 {
1886    const DTYPE: GgmlDType = GgmlDType::F32;
1887    const BLCK_SIZE: usize = 1;
1888    type VecDotType = f32;
1889
1890    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1891        Self::vec_dot_unopt(n, xs, ys)
1892    }
1893
1894    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1895        if xs.len() < n {
1896            crate::bail!("size mismatch {} < {n}", xs.len())
1897        }
1898        if ys.len() < n {
1899            crate::bail!("size mismatch {} < {n}", ys.len())
1900        }
1901        let mut res = 0f32;
1902        unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
1903        Ok(res)
1904    }
1905
1906    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
1907        if xs.len() != ys.len() {
1908            crate::bail!("size mismatch {} {}", xs.len(), ys.len());
1909        }
1910        ys.copy_from_slice(xs);
1911        Ok(())
1912    }
1913
1914    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
1915        if xs.len() != ys.len() {
1916            crate::bail!("size mismatch {} {}", xs.len(), ys.len());
1917        }
1918        ys.copy_from_slice(xs);
1919        Ok(())
1920    }
1921}
1922
1923impl GgmlType for f16 {
1924    const DTYPE: GgmlDType = GgmlDType::F16;
1925    const BLCK_SIZE: usize = 1;
1926    type VecDotType = f16;
1927
1928    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1929        Self::vec_dot_unopt(n, xs, ys)
1930    }
1931
1932    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1933        if xs.len() < n {
1934            crate::bail!("size mismatch {} < {n}", xs.len())
1935        }
1936        if ys.len() < n {
1937            crate::bail!("size mismatch {} < {n}", ys.len())
1938        }
1939        let mut res = 0f32;
1940        unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
1941        Ok(res)
1942    }
1943
1944    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
1945        if xs.len() != ys.len() {
1946            crate::bail!("size mismatch {} {}", xs.len(), ys.len());
1947        }
1948        // TODO: vectorize
1949        for (x, y) in xs.iter().zip(ys.iter_mut()) {
1950            *y = f16::from_f32(*x)
1951        }
1952        Ok(())
1953    }
1954
1955    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
1956        if xs.len() != ys.len() {
1957            crate::bail!("size mismatch {} {}", xs.len(), ys.len());
1958        }
1959        // TODO: vectorize
1960        for (x, y) in xs.iter().zip(ys.iter_mut()) {
1961            *y = x.to_f32()
1962        }
1963        Ok(())
1964    }
1965}