Skip to main content

candle_core/quantized/
k_quants.rs

1use super::utils::{
2    get_scale_min_k4, group_for_dequantization, group_for_quantization, make_q3_quants,
3    make_qkx1_quants, make_qx_quants, nearest_int,
4};
5use super::GgmlDType;
6use crate::quantized::utils::{make_qkx3_quants, make_qp_quants};
7use crate::Result;
8use byteorder::{ByteOrder, LittleEndian};
9use half::{bf16, f16, slice::HalfFloatSliceExt};
10use rayon::prelude::*;
11
// Default to QK_K 256 rather than 64.
/// Number of values in one k-quant super-block (`BlockQ2K` .. `BlockQ8K` are sized from it).
pub const QK_K: usize = 256;
/// Byte length of the packed `scales` field of `BlockQ4K` and `BlockQ5K`.
pub const K_SCALE_SIZE: usize = 12;

/// Values per block for the legacy (non-k) formats; each is that type's `BLCK_SIZE`.
pub const QK4_0: usize = 32;
pub const QK4_1: usize = 32;
pub const QK5_0: usize = 32;
pub const QK5_1: usize = 32;
pub const QK8_0: usize = 32;
pub const QK8_1: usize = 32;
22
23pub trait GgmlType: Sized + Clone + Send + Sync {
24    const DTYPE: GgmlDType;
25    const BLCK_SIZE: usize;
26    const DIRECT_COPY: bool = false;
27    type VecDotType: GgmlType;
28
29    // This is only safe for types that include immediate values such as float/int/...
30    fn zeros() -> Self {
31        unsafe { std::mem::MaybeUninit::zeroed().assume_init() }
32    }
33    fn to_float(xs: &[Self], ys: &mut [f32]);
34    fn from_float(xs: &[f32], ys: &mut [Self]);
35    fn from_float_imatrix(
36        _xs: &[f32],
37        _ys: &mut [Self],
38        _imatrix_weights: &[f32],
39        _n_per_row: usize,
40    ) {
41        panic!(
42            "`from_float_imatrix` is unimplemented for {:?}",
43            Self::DTYPE
44        );
45    }
46
47    fn direct_copy(_xs: &[f32], _ys: &mut [Self]) {}
48
49    /// Dot product used as a building block for quantized mat-mul.
50    /// n is the number of elements to be considered.
51    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32;
52
53    /// Generic implementation of the dot product without simd optimizations.
54    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32;
55}
56
/// `Q4_0` block: `QK4_0` values stored as 4-bit quants sharing one scale.
///
/// Each value dequantizes to `(q - 8) * d`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_0 {
    /// Block scale.
    pub(crate) d: f16,
    /// Two quants per byte: the low nibble of byte `j` is value `j`, the high nibble is
    /// value `j + QK4_0 / 2`.
    pub(crate) qs: [u8; QK4_0 / 2],
}
// Compile-time layout check: 2-byte scale + 16 bytes of nibbles.
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
64
/// `Q4_1` block: `QK4_1` values stored as unsigned 4-bit quants with a scale and an offset.
///
/// Each value dequantizes to `q * d + m`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_1 {
    /// Block scale.
    pub(crate) d: f16,
    /// Block offset (the minimum value seen when quantizing).
    pub(crate) m: f16,
    /// Two quants per byte: low nibble = value `j`, high nibble = value `j + QK4_1 / 2`.
    pub(crate) qs: [u8; QK4_1 / 2],
}
// Compile-time layout check: two f16 fields + 16 bytes of nibbles.
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
73
/// `Q5_0` block: `QK5_0` values stored as 5-bit quants sharing one scale.
///
/// Each value dequantizes to `(q - 16) * d`, where `q` combines a nibble from `qs` with a
/// fifth bit from `qh`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_0 {
    /// Block scale.
    pub(crate) d: f16,
    /// Fifth bits, read as a little-endian `u32`: bit `j` belongs to value `j`, bit
    /// `j + QK5_0 / 2` to value `j + QK5_0 / 2`.
    pub(crate) qh: [u8; 4],
    /// Low four bits, two per byte: low nibble = value `j`, high nibble = value `j + QK5_0 / 2`.
    pub(crate) qs: [u8; QK5_0 / 2],
}
// Compile-time layout check: 2 + 4 + 16 bytes.
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
82
/// `Q5_1` block: `QK5_1` values stored as unsigned 5-bit quants with a scale and an offset.
///
/// Each value dequantizes to `q * d + m`; `q` is split between `qs` (low 4 bits) and `qh`
/// (fifth bit), laid out as in `BlockQ5_0`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_1 {
    /// Block scale.
    pub(crate) d: f16,
    /// Block offset (the minimum value seen when quantizing).
    pub(crate) m: f16,
    /// Fifth bits, read as a little-endian `u32`.
    pub(crate) qh: [u8; 4],
    /// Low four bits, two per byte.
    pub(crate) qs: [u8; QK5_1 / 2],
}
// Compile-time layout check: 2 + 2 + 4 + 16 bytes.
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
92
/// `Q8_0` block: `QK8_0` values stored as signed 8-bit quants sharing one scale.
///
/// Each value dequantizes to `q * d`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_0 {
    /// Block scale (`max |x| / 127` at quantization time).
    pub(crate) d: f16,
    /// One signed quant per value.
    pub(crate) qs: [i8; QK8_0],
}
// Compile-time layout check: 2-byte scale + 32 quants.
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
100
/// `Q8_1` block: like `BlockQ8_0` plus a precomputed scaled sum of the quants.
///
/// Used as the right-hand operand of the `Q4_1` / `Q5_1` dot products, which consume `s`
/// for their offset term.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_1 {
    /// Block scale.
    pub(crate) d: f16,
    /// Approximately `d * sum(qs)`, filled in by `from_float`.
    pub(crate) s: f16,
    /// One signed quant per value.
    pub(crate) qs: [i8; QK8_1],
}
// Compile-time layout check: two f16 fields + 32 quants.
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
109
/// `Q2K` super-block: `QK_K` values quantized to 2 bits in 16-value sub-blocks.
///
/// A value in sub-block `s` contributes `d * (scales[s] & 0xF) * q - dmin * (scales[s] >> 4)`,
/// matching how the `Q2K` quantizer and dot product treat the fields.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ2K {
    /// One byte per 16-value sub-block: low nibble = scale, high nibble = min.
    pub(crate) scales: [u8; QK_K / 16],
    /// Four 2-bit levels per byte. Within each 128-value group, byte `l` holds values
    /// `l`, `l + 32`, `l + 64`, `l + 96` in bits 0-1, 2-3, 4-5, 6-7.
    pub(crate) qs: [u8; QK_K / 4],
    /// Super-block multiplier for the 4-bit sub-block scales.
    pub(crate) d: f16,
    /// Super-block multiplier for the 4-bit sub-block mins.
    pub(crate) dmin: f16,
}
// Compile-time layout check of the packed size.
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
119
/// `Q3K` super-block of `QK_K` values (3-bit k-quant).
///
/// The field sizes amount to 1 bit (`hmask`) + 2 bits (`qs`) per value, plus 12 bytes of
/// packed sub-block scales and one f16 super-block scale.
/// NOTE(review): the exact bit layout of `hmask`, `qs`, and `scales` is defined by the `Q3K`
/// quantize/dequantize code, not shown here — confirm there.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ3K {
    pub(crate) hmask: [u8; QK_K / 8],
    pub(crate) qs: [u8; QK_K / 4],
    pub(crate) scales: [u8; 12],
    pub(crate) d: f16,
}
// Compile-time layout check of the packed size.
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
129
/// `Q4K` super-block of `QK_K` values (4-bit k-quant): 4 bits per value in `qs`, with
/// sub-block scales/mins packed into `K_SCALE_SIZE` bytes.
/// NOTE(review): `scales` is presumably decoded with `get_scale_min_k4` (imported by this
/// file) — confirm against the `Q4K` implementation.
#[derive(Debug, Clone, PartialEq)]
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/k_quants.h#L82
#[repr(C)]
pub struct BlockQ4K {
    /// Super-block multiplier for the sub-block scales.
    pub(crate) d: f16,
    /// Super-block multiplier for the sub-block mins.
    pub(crate) dmin: f16,
    pub(crate) scales: [u8; K_SCALE_SIZE],
    pub(crate) qs: [u8; QK_K / 2],
}
// Compile-time layout check of the packed size.
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
140
/// `Q5K` super-block of `QK_K` values (5-bit k-quant): 4 low bits per value in `qs` and one
/// high bit per value in `qh`, with sub-block scales/mins packed into `K_SCALE_SIZE` bytes.
/// NOTE(review): bit layout defined by the `Q5K` implementation, not shown here.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5K {
    pub(crate) d: f16,
    pub(crate) dmin: f16,
    pub(crate) scales: [u8; K_SCALE_SIZE],
    pub(crate) qh: [u8; QK_K / 8],
    pub(crate) qs: [u8; QK_K / 2],
}
// Compile-time layout check of the packed size.
const _: () =
    assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
152
/// `Q6K` super-block of `QK_K` values (6-bit k-quant): 4 bits per value in `ql`, 2 bits per
/// value in `qh`, one signed scale byte per 16 values, and an f16 super-block scale.
/// NOTE(review): bit layout defined by the `Q6K` implementation, not shown here.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ6K {
    pub(crate) ql: [u8; QK_K / 2],
    pub(crate) qh: [u8; QK_K / 4],
    pub(crate) scales: [i8; QK_K / 16],
    pub(crate) d: f16,
}
// Compile-time layout check of the packed size.
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
162
/// `Q8K` super-block: `QK_K` signed 8-bit quants with an f32 scale. Serves as the
/// right-hand operand of the k-quant dot products.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8K {
    /// Super-block scale (kept as f32, unlike the other blocks).
    pub(crate) d: f32,
    /// One signed quant per value.
    pub(crate) qs: [i8; QK_K],
    /// One entry per 16-value group; the `Q2K` dot product multiplies these by the sub-block
    /// mins. NOTE(review): presumably the sum of `qs` over each group — confirm in `Q8K`
    /// `from_float`.
    pub(crate) bsums: [i16; QK_K / 16],
}
// Compile-time layout check of the packed size.
const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::<BlockQ8K>());
171
172impl GgmlType for BlockQ4_0 {
173    const DTYPE: GgmlDType = GgmlDType::Q4_0;
174    const BLCK_SIZE: usize = QK4_0;
175    type VecDotType = BlockQ8_0;
176
177    // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1525
178    fn to_float(xs: &[Self], ys: &mut [f32]) {
179        let k = ys.len();
180        let qk = Self::BLCK_SIZE;
181        debug_assert!(
182            k.is_multiple_of(qk),
183            "dequantize_row_q4_0: {k} is not divisible by {qk}"
184        );
185
186        let nb = k / qk;
187        for i in 0..nb {
188            let d = xs[i].d.to_f32();
189
190            for j in 0..(qk / 2) {
191                let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8;
192                let x1 = (xs[i].qs[j] >> 4) as i16 - 8;
193
194                ys[i * qk + j] = (x0 as f32) * d;
195                ys[i * qk + j + qk / 2] = (x1 as f32) * d;
196            }
197        }
198    }
199
200    fn from_float(xs: &[f32], ys: &mut [Self]) {
201        // quantize_row_q4_0
202        let qk = Self::BLCK_SIZE;
203        let k = xs.len();
204        debug_assert!(k.is_multiple_of(qk), "{k} is not divisible by {qk}");
205        debug_assert_eq!(
206            ys.len(),
207            k / qk,
208            "size mismatch {} {} {}",
209            xs.len(),
210            ys.len(),
211            qk,
212        );
213        for (i, ys) in ys.iter_mut().enumerate() {
214            let mut amax = 0f32;
215            let mut max = 0f32;
216
217            let xs = &xs[i * qk..(i + 1) * qk];
218            for &x in xs.iter() {
219                if amax < x.abs() {
220                    amax = x.abs();
221                    max = x;
222                }
223            }
224            let d = max / -8.0;
225            let id = if d != 0f32 { 1. / d } else { 0. };
226            ys.d = f16::from_f32(d);
227
228            for (j, q) in ys.qs.iter_mut().enumerate() {
229                let x0 = xs[j] * id;
230                let x1 = xs[qk / 2 + j] * id;
231                let xi0 = u8::min(15, (x0 + 8.5) as u8);
232                let xi1 = u8::min(15, (x1 + 8.5) as u8);
233                *q = xi0 | (xi1 << 4)
234            }
235        }
236    }
237
238    // https://github.com/ggerganov/llama.cpp/blob/b5ffb2849d23afe73647f68eec7b68187af09be6/ggml.c#L2361C10-L2361C122
239    #[allow(unreachable_code)]
240    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
241        #[cfg(target_feature = "avx2")]
242        return super::avx::vec_dot_q4_0_q8_0(n, xs, ys);
243
244        #[cfg(target_feature = "neon")]
245        return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);
246
247        #[cfg(target_feature = "simd128")]
248        return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);
249
250        Self::vec_dot_unopt(n, xs, ys)
251    }
252
253    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
254        debug_assert!(
255            n.is_multiple_of(QK8_0),
256            "vec_dot_q4_0_q8_0: {n} is not divisible by {QK8_0}"
257        );
258        // Generic implementation.
259        let mut sumf = 0f32;
260        for (xs, ys) in xs.iter().zip(ys.iter()) {
261            let mut sum_i = 0;
262            for j in 0..QK8_0 / 2 {
263                let v0 = (xs.qs[j] & 0x0F) as i32 - 8;
264                let v1 = (xs.qs[j] >> 4) as i32 - 8;
265                sum_i += v0 * ys.qs[j] as i32 + v1 * ys.qs[j + QK8_0 / 2] as i32
266            }
267            sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
268        }
269        sumf
270    }
271}
272
273impl GgmlType for BlockQ4_1 {
274    const DTYPE: GgmlDType = GgmlDType::Q4_1;
275    const BLCK_SIZE: usize = QK4_1;
276    type VecDotType = BlockQ8_1;
277
278    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
279        Self::vec_dot_unopt(n, xs, ys)
280    }
281
282    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
283        // ggml_vec_dot_q4_1_q8_1
284        let qk = QK8_1;
285        debug_assert!(
286            n.is_multiple_of(qk),
287            "vec_dot_q4_1_q8_1: {n} is not divisible by {qk}"
288        );
289        debug_assert!(
290            (n / qk).is_multiple_of(2),
291            "vec_dot_q4_1_q8_1: {n}, nb is not divisible by 2"
292        );
293
294        // Generic implementation.
295        let mut sumf = 0f32;
296
297        for (xs, ys) in xs.iter().zip(ys.iter()) {
298            let mut sumi = 0i32;
299
300            for j in 0..qk / 2 {
301                let v0 = xs.qs[j] as i32 & 0x0F;
302                let v1 = xs.qs[j] as i32 >> 4;
303                sumi += (v0 * ys.qs[j] as i32) + (v1 * ys.qs[j + qk / 2] as i32);
304            }
305
306            sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
307                + f16::to_f32(xs.m) * f16::to_f32(ys.s)
308        }
309        sumf
310    }
311
312    fn from_float(xs: &[f32], ys: &mut [Self]) {
313        // quantize_row_q4_1
314        let qk = Self::BLCK_SIZE;
315
316        debug_assert_eq!(
317            ys.len() * qk,
318            xs.len(),
319            "size mismatch {} {} {}",
320            xs.len(),
321            ys.len(),
322            qk,
323        );
324        for (i, ys) in ys.iter_mut().enumerate() {
325            let xs = &xs[i * qk..(i + 1) * qk];
326
327            let mut min = f32::INFINITY;
328            let mut max = f32::NEG_INFINITY;
329            for &x in xs.iter() {
330                min = f32::min(x, min);
331                max = f32::max(x, max);
332            }
333            let d = (max - min) / ((1 << 4) - 1) as f32;
334            let id = if d != 0f32 { 1. / d } else { 0. };
335            ys.d = f16::from_f32(d);
336            ys.m = f16::from_f32(min);
337
338            for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() {
339                let x0 = (xs[j] - min) * id;
340                let x1 = (xs[qk / 2 + j] - min) * id;
341
342                let xi0 = u8::min(15, (x0 + 0.5) as u8);
343                let xi1 = u8::min(15, (x1 + 0.5) as u8);
344
345                *q = xi0 | (xi1 << 4);
346            }
347        }
348    }
349
350    // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1545
351    fn to_float(xs: &[Self], ys: &mut [f32]) {
352        let k = ys.len();
353        debug_assert!(
354            k.is_multiple_of(QK4_1),
355            "dequantize_row_q4_1: {k} is not divisible by {QK4_1}"
356        );
357
358        let nb = k / QK4_1;
359        for i in 0..nb {
360            let d = xs[i].d.to_f32();
361            let m = xs[i].m.to_f32();
362
363            for j in 0..(QK4_1 / 2) {
364                let x0 = xs[i].qs[j] & 0x0F;
365                let x1 = xs[i].qs[j] >> 4;
366
367                ys[i * QK4_1 + j] = (x0 as f32) * d + m;
368                ys[i * QK4_1 + j + QK4_1 / 2] = (x1 as f32) * d + m;
369            }
370        }
371    }
372}
373
374impl GgmlType for BlockQ5_0 {
375    const DTYPE: GgmlDType = GgmlDType::Q5_0;
376    const BLCK_SIZE: usize = QK5_0;
377    type VecDotType = BlockQ8_0;
378
379    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
380        let qk = Self::BLCK_SIZE;
381
382        debug_assert!(
383            n.is_multiple_of(qk),
384            "vec_dot_q5_0_q8_0: {n} is not divisible by {qk}"
385        );
386        debug_assert!(
387            (n / qk).is_multiple_of(2),
388            "vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2"
389        );
390        Self::vec_dot_unopt(n, xs, ys)
391    }
392
393    fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
394        // Generic implementation.
395        let mut sumf = 0f32;
396
397        for (xs, ys) in xs.iter().zip(ys.iter()) {
398            let qh = LittleEndian::read_u32(&xs.qh);
399            let mut sumi = 0i32;
400
401            for j in 0..Self::BLCK_SIZE / 2 {
402                let xh_0 = (((qh & (1u32 << j)) >> j) << 4) as u8;
403                let xh_1 = ((qh & (1u32 << (j + 16))) >> (j + 12)) as u8;
404
405                let x0 = ((xs.qs[j] & 0x0F) as i32 | xh_0 as i32) - 16;
406                let x1 = ((xs.qs[j] >> 4) as i32 | xh_1 as i32) - 16;
407
408                sumi += (x0 * ys.qs[j] as i32) + (x1 * ys.qs[j + Self::BLCK_SIZE / 2] as i32);
409            }
410
411            sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
412        }
413        sumf
414    }
415
416    fn from_float(xs: &[f32], ys: &mut [Self]) {
417        // quantize_row_q5_0
418        debug_assert_eq!(
419            ys.len() * Self::BLCK_SIZE,
420            xs.len(),
421            "size mismatch {} {} {}",
422            xs.len(),
423            ys.len(),
424            Self::BLCK_SIZE,
425        );
426        for (i, ys) in ys.iter_mut().enumerate() {
427            let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
428
429            let mut amax = 0f32;
430            let mut max = 0f32;
431            for &x in xs.iter() {
432                if amax < x.abs() {
433                    amax = x.abs();
434                    max = x;
435                }
436            }
437            let d = max / -16.;
438            let id = if d != 0f32 { 1. / d } else { 0. };
439            ys.d = f16::from_f32(d);
440            let mut qh = 0u32;
441            for j in 0..Self::BLCK_SIZE / 2 {
442                let x0 = xs[j] * id;
443                let x1 = xs[j + Self::BLCK_SIZE / 2] * id;
444                let xi0 = ((x0 + 16.5) as i8).min(31) as u8;
445                let xi1 = ((x1 + 16.5) as i8).min(31) as u8;
446                ys.qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
447                qh |= ((xi0 as u32 & 0x10) >> 4) << j;
448                qh |= ((xi1 as u32 & 0x10) >> 4) << (j + Self::BLCK_SIZE / 2);
449            }
450            LittleEndian::write_u32(&mut ys.qh, qh)
451        }
452    }
453
454    // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1566
455    fn to_float(xs: &[Self], ys: &mut [f32]) {
456        let k = ys.len();
457        debug_assert!(
458            k.is_multiple_of(QK5_0),
459            "dequantize_row_q5_0: {k} is not divisible by {QK5_0}"
460        );
461        let nb = k / QK5_0;
462        for i in 0..nb {
463            let d = xs[i].d.to_f32();
464            let qh: u32 = LittleEndian::read_u32(&xs[i].qh);
465
466            for j in 0..(QK5_0 / 2) {
467                let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
468                let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
469
470                let x0 = ((xs[i].qs[j] & 0x0F) | xh_0) as i32 - 16;
471                let x1 = ((xs[i].qs[j] >> 4) | xh_1) as i32 - 16;
472
473                ys[i * QK5_0 + j] = (x0 as f32) * d;
474                ys[i * QK5_0 + j + QK5_0 / 2] = (x1 as f32) * d;
475            }
476        }
477    }
478}
479
480impl GgmlType for BlockQ5_1 {
481    const DTYPE: GgmlDType = GgmlDType::Q5_1;
482    const BLCK_SIZE: usize = QK5_1;
483    type VecDotType = BlockQ8_1;
484
485    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
486        Self::vec_dot_unopt(n, xs, ys)
487    }
488
489    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
490        let qk = Self::BLCK_SIZE;
491        debug_assert!(
492            n.is_multiple_of(qk),
493            "vec_dot_q5_1_q8_1: {n} is not divisible by {qk}"
494        );
495        debug_assert!(
496            (n / qk).is_multiple_of(2),
497            "vec_dot_q5_1_q8_1: {n}, nb is not divisible by 2"
498        );
499
500        // Generic implementation.
501        let mut sumf = 0f32;
502
503        for (xs, ys) in xs.iter().zip(ys.iter()) {
504            let qh = LittleEndian::read_u32(&xs.qh);
505            let mut sumi = 0i32;
506
507            for j in 0..Self::BLCK_SIZE / 2 {
508                let xh_0 = ((qh >> j) << 4) & 0x10;
509                let xh_1 = (qh >> (j + 12)) & 0x10;
510
511                let x0 = (xs.qs[j] as i32 & 0xF) | xh_0 as i32;
512                let x1 = (xs.qs[j] as i32 >> 4) | xh_1 as i32;
513
514                sumi += (x0 * ys.qs[j] as i32) + (x1 * ys.qs[j + Self::BLCK_SIZE / 2] as i32);
515            }
516
517            sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
518                + f16::to_f32(xs.m) * f16::to_f32(ys.s)
519        }
520        sumf
521    }
522
523    fn from_float(xs: &[f32], ys: &mut [Self]) {
524        // quantize_row_q5_1
525        let qk = Self::BLCK_SIZE;
526        debug_assert_eq!(
527            ys.len() * qk,
528            xs.len(),
529            "size mismatch {} {} {}",
530            xs.len(),
531            ys.len(),
532            qk,
533        );
534        for (i, ys) in ys.iter_mut().enumerate() {
535            let xs = &xs[i * qk..(i + 1) * qk];
536
537            let mut min = f32::INFINITY;
538            let mut max = f32::NEG_INFINITY;
539            for &x in xs.iter() {
540                min = f32::min(x, min);
541                max = f32::max(x, max);
542            }
543            let d = (max - min) / ((1 << 5) - 1) as f32;
544            let id = if d != 0f32 { 1. / d } else { 0. };
545            ys.d = f16::from_f32(d);
546            ys.m = f16::from_f32(min);
547
548            let mut qh = 0u32;
549            for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() {
550                let x0 = (xs[j] - min) * id;
551                let x1 = (xs[qk / 2 + j] - min) * id;
552
553                let xi0 = (x0 + 0.5) as u8;
554                let xi1 = (x1 + 0.5) as u8;
555
556                *q = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
557                // get the 5-th bit and store it in qh at the right position
558                qh |= ((xi0 as u32 & 0x10) >> 4) << j;
559                qh |= ((xi1 as u32 & 0x10) >> 4) << (j + qk / 2);
560            }
561            LittleEndian::write_u32(&mut ys.qh, qh);
562        }
563    }
564
565    // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1592
566    fn to_float(xs: &[Self], ys: &mut [f32]) {
567        let k = ys.len();
568        debug_assert!(
569            k.is_multiple_of(QK5_1),
570            "dequantize_row_q5_1: {k} is not divisible by {QK5_1}"
571        );
572
573        let nb = k / QK5_1;
574        for i in 0..nb {
575            let d = xs[i].d.to_f32();
576            let m = xs[i].m.to_f32();
577            let qh: u32 = LittleEndian::read_u32(&xs[i].qh);
578
579            for j in 0..(QK5_1 / 2) {
580                let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
581                let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
582
583                let x0 = (xs[i].qs[j] & 0x0F) | xh_0;
584                let x1 = (xs[i].qs[j] >> 4) | xh_1;
585
586                ys[i * QK5_1 + j] = (x0 as f32) * d + m;
587                ys[i * QK5_1 + j + QK5_1 / 2] = (x1 as f32) * d + m;
588            }
589        }
590    }
591}
592
593impl GgmlType for BlockQ8_0 {
594    const DTYPE: GgmlDType = GgmlDType::Q8_0;
595    const BLCK_SIZE: usize = QK8_0;
596    type VecDotType = BlockQ8_0;
597
598    // https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/ggml.c#L1619
599    fn to_float(xs: &[Self], ys: &mut [f32]) {
600        let k = ys.len();
601        debug_assert!(
602            k.is_multiple_of(QK8_0),
603            "dequantize_row_q8_0: {k} is not divisible by {QK8_0}"
604        );
605
606        let nb = k / QK8_0;
607
608        for i in 0..nb {
609            let d = xs[i].d.to_f32();
610
611            for j in 0..QK8_0 {
612                ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
613            }
614        }
615    }
616
617    fn from_float(xs: &[f32], ys: &mut [Self]) {
618        // quantize_row_q8_0
619        let k = xs.len();
620        debug_assert!(
621            k.is_multiple_of(Self::BLCK_SIZE),
622            "{k} is not divisible by {}",
623            Self::BLCK_SIZE
624        );
625        debug_assert_eq!(
626            ys.len(),
627            k / Self::BLCK_SIZE,
628            "size mismatch {} {} {}",
629            xs.len(),
630            ys.len(),
631            Self::BLCK_SIZE
632        );
633        for (i, ys) in ys.iter_mut().enumerate() {
634            let mut amax = 0f32;
635            let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
636            for &x in xs.iter() {
637                amax = amax.max(x.abs())
638            }
639            let d = amax / ((1 << 7) - 1) as f32;
640            let id = if d != 0f32 { 1. / d } else { 0. };
641            ys.d = f16::from_f32(d);
642            for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
643                *y = f32::round(x * id) as i8
644            }
645        }
646    }
647
648    #[allow(unreachable_code)]
649    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
650        #[cfg(target_feature = "avx2")]
651        return super::avx::vec_dot_q8_0_q8_0(n, xs, ys);
652
653        #[cfg(target_feature = "neon")]
654        return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);
655
656        #[cfg(target_feature = "simd128")]
657        return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);
658
659        Self::vec_dot_unopt(n, xs, ys)
660    }
661
662    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
663        debug_assert!(
664            n.is_multiple_of(QK8_0),
665            "vec_dot_q8_0_q8_0: {n} is not divisible by {QK8_0}"
666        );
667
668        // Generic implementation.
669        let mut sumf = 0f32;
670        for (xs, ys) in xs.iter().zip(ys.iter()) {
671            let sum_i = xs
672                .qs
673                .iter()
674                .zip(ys.qs.iter())
675                .map(|(&x, &y)| x as i32 * y as i32)
676                .sum::<i32>();
677            sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
678        }
679        sumf
680    }
681}
682
683impl GgmlType for BlockQ8_1 {
684    const DTYPE: GgmlDType = GgmlDType::Q8_1;
685    const BLCK_SIZE: usize = QK8_1;
686    type VecDotType = BlockQ8_1;
687
688    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
689        Self::vec_dot_unopt(n, xs, ys)
690    }
691
692    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
693        debug_assert!(
694            n.is_multiple_of(QK8_1),
695            "vec_dot_q8_1_q8_1: {n} is not divisible by {QK8_1}"
696        );
697
698        // Generic implementation.
699        let mut sumf = 0f32;
700        for (xs, ys) in xs.iter().zip(ys.iter()) {
701            let sum_i = xs
702                .qs
703                .iter()
704                .zip(ys.qs.iter())
705                .map(|(&x, &y)| x as i32 * y as i32)
706                .sum::<i32>();
707            sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
708        }
709        sumf
710    }
711
712    fn from_float(xs: &[f32], ys: &mut [Self]) {
713        // quantize_row_q8_1
714        debug_assert_eq!(
715            ys.len() * Self::BLCK_SIZE,
716            xs.len(),
717            "size mismatch {} {} {}",
718            xs.len(),
719            ys.len(),
720            Self::BLCK_SIZE
721        );
722        for (i, ys) in ys.iter_mut().enumerate() {
723            let mut amax = 0f32;
724            let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
725            for &x in xs.iter() {
726                amax = amax.max(x.abs())
727            }
728            let d = amax / ((1 << 7) - 1) as f32;
729            let id = if d != 0f32 { 1. / d } else { 0. };
730            ys.d = f16::from_f32(d);
731            let mut sum = 0i32;
732            for j in 0..Self::BLCK_SIZE / 2 {
733                let v0 = xs[j] * id;
734                let v1 = xs[j + Self::BLCK_SIZE / 2] * id;
735                ys.qs[j] = f32::round(v0) as i8;
736                ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8;
737                sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32;
738            }
739            ys.s = f16::from_f32(sum as f32) * ys.d;
740        }
741    }
742
743    fn to_float(_xs: &[Self], _ys: &mut [f32]) {
744        unimplemented!("no support for vec-dot on Q8_1")
745    }
746}
747
748impl GgmlType for BlockQ2K {
749    const DTYPE: GgmlDType = GgmlDType::Q2K;
750    const BLCK_SIZE: usize = QK_K;
751    type VecDotType = BlockQ8K;
752
753    #[allow(unreachable_code)]
754    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
755        #[cfg(target_feature = "avx2")]
756        return super::avx::vec_dot_q2k_q8k(n, xs, ys);
757
758        #[cfg(target_feature = "neon")]
759        return super::neon::vec_dot_q2k_q8k(n, xs, ys);
760
761        #[cfg(target_feature = "simd128")]
762        return super::simd128::vec_dot_q2k_q8k(n, xs, ys);
763
764        Self::vec_dot_unopt(n, xs, ys)
765    }
766
767    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
768        debug_assert!(
769            n.is_multiple_of(QK_K),
770            "vec_dot_q2k_q8k: {n} is not divisible by {QK_K}"
771        );
772
773        let mut sumf = 0.0;
774        for (x, y) in xs.iter().zip(ys.iter()) {
775            let mut q2: &[_] = &x.qs;
776            let mut q8: &[_] = &y.qs;
777            let sc = &x.scales;
778
779            let mut summs = 0;
780            for (bsum, scale) in y.bsums.iter().zip(sc) {
781                summs += *bsum as i32 * ((scale >> 4) as i32);
782            }
783
784            let dall = y.d * x.d.to_f32();
785            let dmin = y.d * x.dmin.to_f32();
786
787            let mut isum = 0;
788            let mut is = 0;
789            for _ in 0..(QK_K / 128) {
790                let mut shift = 0;
791                for _ in 0..4 {
792                    let d = (sc[is] & 0xF) as i32;
793                    is += 1;
794                    let mut isuml = 0;
795                    for l in 0..16 {
796                        isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
797                    }
798                    isum += d * isuml;
799                    let d = (sc[is] & 0xF) as i32;
800                    is += 1;
801                    isuml = 0;
802                    for l in 16..32 {
803                        isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
804                    }
805                    isum += d * isuml;
806                    shift += 2;
807                    // adjust the indexing
808                    q8 = &q8[32..];
809                }
810                // adjust the indexing
811                q2 = &q2[32..];
812            }
813            sumf += dall * isum as f32 - dmin * summs as f32;
814        }
815
816        sumf
817    }
818
819    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L279
820    fn from_float(xs: &[f32], ys: &mut [Self]) {
821        const Q4SCALE: f32 = 15.0;
822
823        for (block, x) in group_for_quantization(xs, ys) {
824            //calculate scales and mins
825            let mut mins: [f32; QK_K / 16] = [0.0; QK_K / 16];
826            let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];
827
828            for (j, x_scale_slice) in x.chunks(16).enumerate() {
829                (scales[j], mins[j]) = make_qkx1_quants(3, 5, x_scale_slice);
830            }
831            // get max scale and max min and ensure they are >= 0.0
832            let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
833            let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));
834
835            if max_scale > 0.0 {
836                let iscale = Q4SCALE / max_scale;
837                for (j, scale) in scales.iter().enumerate().take(QK_K / 16) {
838                    block.scales[j] = nearest_int(iscale * scale) as u8;
839                }
840                block.d = f16::from_f32(max_scale / Q4SCALE);
841            } else {
842                for j in 0..QK_K / 16 {
843                    block.scales[j] = 0;
844                }
845                block.d = f16::from_f32(0.0);
846            }
847
848            if max_min > 0.0 {
849                let iscale = Q4SCALE / max_min;
850                for (j, scale) in block.scales.iter_mut().enumerate() {
851                    let l = nearest_int(iscale * mins[j]) as u8;
852                    *scale |= l << 4;
853                }
854                block.dmin = f16::from_f32(max_min / Q4SCALE);
855            } else {
856                block.dmin = f16::from_f32(0.0);
857            }
858
859            let mut big_l: [u8; QK_K] = [0; QK_K];
860
861            for j in 0..QK_K / 16 {
862                let d = block.d.to_f32() * (block.scales[j] & 0xF) as f32;
863                if d == 0.0 {
864                    continue;
865                }
866                let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32;
867                for ii in 0..16 {
868                    let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3);
869                    big_l[16 * j + ii] = ll as u8;
870                }
871            }
872
873            for j in (0..QK_K).step_by(128) {
874                for ll in 0..32 {
875                    block.qs[j / 4 + ll] = big_l[j + ll]
876                        | (big_l[j + ll + 32] << 2)
877                        | (big_l[j + ll + 64] << 4)
878                        | (big_l[j + ll + 96] << 6);
879                }
880            }
881        }
882    }
883
884    fn from_float_imatrix(xs: &[f32], ys: &mut [Self], imatrix_weights: &[f32], n_per_row: usize) {
885        for (sblk_idx, (block, x)) in group_for_quantization(xs, ys).into_iter().enumerate() {
886            let mut mins: [f32; QK_K / 16] = [0.0; QK_K / 16];
887            let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];
888            let mut weights: [f32; 16] = [0.0; 16];
889            let mut sw: [f32; QK_K / 16] = [0.0; QK_K / 16];
890            let mut ls: [u8; QK_K / 16] = [0; QK_K / 16];
891            let mut lm: [u8; QK_K / 16] = [0; QK_K / 16];
892
893            let sum_x2 = x.iter().map(|x| x * x).sum::<f32>();
894            let sigma2 = sum_x2 / QK_K as f32;
895            for (j, x_scale_slice) in x.chunks_exact(16).enumerate() {
896                for (l, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() {
897                    let imatrix_row = sblk_idx % (n_per_row / QK_K);
898                    let imatrix_w = imatrix_weights[imatrix_row * QK_K + 16 * j + l];
899                    *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt();
900                }
901                let sumw = weights.iter().sum::<f32>();
902                sw[j] = sumw;
903                (scales[j], mins[j]) =
904                    make_qkx3_quants(3, x_scale_slice, Some(&weights), -0.9, 0.05, 36, false);
905            }
906
907            let d_block = make_qp_quants(QK_K / 16, 15, &scales, &mut ls, &sw);
908            let m_block = make_qp_quants(QK_K / 16, 15, &mins, &mut lm, &sw);
909
910            block.d = f16::from_f32(d_block);
911            block.dmin = f16::from_f32(m_block);
912
913            for j in 0..QK_K / 16 {
914                block.scales[j] = ls[j] | (lm[j] << 4);
915            }
916
917            let mut big_l: [u8; QK_K] = [0; QK_K];
918
919            for j in 0..QK_K / 16 {
920                let d = block.d.to_f32() * (block.scales[j] & 0xF) as f32;
921                if d == 0.0 {
922                    continue;
923                }
924                let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32;
925                for ii in 0..16 {
926                    let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3);
927                    big_l[16 * j + ii] = ll as u8;
928                }
929            }
930
931            for j in (0..QK_K).step_by(128) {
932                for ll in 0..32 {
933                    block.qs[j / 4 + ll] = big_l[j + ll]
934                        | (big_l[j + ll + 32] << 2)
935                        | (big_l[j + ll + 64] << 4)
936                        | (big_l[j + ll + 96] << 6);
937                }
938            }
939        }
940    }
941    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L354
942    fn to_float(xs: &[Self], ys: &mut [f32]) {
943        for (block, y) in group_for_dequantization(xs, ys) {
944            let d = block.d.to_f32();
945            let min = block.dmin.to_f32();
946
947            let mut is = 0;
948
949            for (y_block, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) {
950                // Step by 32 over q.
951                let mut shift = 0;
952                let mut y_block_index = 0;
953                for _j in 0..4 {
954                    let sc = block.scales[is];
955                    is += 1;
956                    let dl = d * (sc & 0xF) as f32;
957                    let ml = min * (sc >> 4) as f32;
958                    for q in &qs[..16] {
959                        let y = dl * ((q >> shift) & 3) as f32 - ml;
960                        y_block[y_block_index] = y;
961                        y_block_index += 1;
962                    }
963
964                    let sc = block.scales[is];
965                    is += 1;
966                    let dl = d * (sc & 0xF) as f32;
967                    let ml = min * (sc >> 4) as f32;
968                    for q in &qs[16..] {
969                        let y = dl * ((q >> shift) & 3) as f32 - ml;
970                        y_block[y_block_index] = y;
971                        y_block_index += 1;
972                    }
973
974                    shift += 2;
975                }
976            }
977        }
978    }
979}
980
981impl GgmlType for BlockQ3K {
982    const DTYPE: GgmlDType = GgmlDType::Q3K;
983    const BLCK_SIZE: usize = QK_K;
984    type VecDotType = BlockQ8K;
985
986    #[allow(unreachable_code)]
987    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
988        #[cfg(target_feature = "avx2")]
989        return super::avx::vec_dot_q3k_q8k(n, xs, ys);
990
991        #[cfg(target_feature = "neon")]
992        return super::neon::vec_dot_q3k_q8k(n, xs, ys);
993
994        Self::vec_dot_unopt(n, xs, ys)
995    }
996
997    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
998        debug_assert!(
999            n.is_multiple_of(QK_K),
1000            "vec_dot_q3k_q8k: {n} is not divisible by {QK_K}"
1001        );
1002
1003        const KMASK1: u32 = 0x03030303;
1004        const KMASK2: u32 = 0x0f0f0f0f;
1005
1006        let mut aux8: [i8; QK_K] = [0; QK_K];
1007        let mut aux16: [i16; 8] = [0; 8];
1008        let mut sums: [f32; 8] = [0.0; 8];
1009        let mut aux32: [i32; 8] = [0; 8];
1010
1011        let mut auxs: [u32; 4] = [0; 4];
1012
1013        for (x, y) in xs.iter().zip(ys.iter()) {
1014            let mut q3: &[u8] = &x.qs;
1015            let hmask: &[u8] = &x.hmask;
1016            let mut q8: &[i8] = &y.qs;
1017
1018            aux32.fill(0);
1019            let mut a = &mut aux8[..];
1020
1021            let mut m = 1;
1022            //Like the GGML original this is written this way to enable the compiler to vectorize it.
1023            for _ in 0..QK_K / 128 {
1024                a.iter_mut()
1025                    .take(32)
1026                    .zip(q3)
1027                    .for_each(|(a_val, q3_val)| *a_val = (q3_val & 3) as i8);
1028                a.iter_mut()
1029                    .take(32)
1030                    .zip(hmask)
1031                    .for_each(|(a_val, hmask_val)| {
1032                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
1033                    });
1034                a = &mut a[32..];
1035                m <<= 1;
1036
1037                a.iter_mut()
1038                    .take(32)
1039                    .zip(q3)
1040                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 2) & 3) as i8);
1041                a.iter_mut()
1042                    .take(32)
1043                    .zip(hmask)
1044                    .for_each(|(a_val, hmask_val)| {
1045                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
1046                    });
1047                a = &mut a[32..];
1048                m <<= 1;
1049
1050                a.iter_mut()
1051                    .take(32)
1052                    .zip(q3)
1053                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 4) & 3) as i8);
1054                a.iter_mut()
1055                    .take(32)
1056                    .zip(hmask)
1057                    .for_each(|(a_val, hmask_val)| {
1058                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
1059                    });
1060                a = &mut a[32..];
1061                m <<= 1;
1062
1063                a.iter_mut()
1064                    .take(32)
1065                    .zip(q3)
1066                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 6) & 3) as i8);
1067                a.iter_mut()
1068                    .take(32)
1069                    .zip(hmask)
1070                    .for_each(|(a_val, hmask_val)| {
1071                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
1072                    });
1073                a = &mut a[32..];
1074                m <<= 1;
1075                q3 = &q3[32..];
1076            }
1077
1078            a = &mut aux8[..];
1079
1080            LittleEndian::read_u32_into(&x.scales, &mut auxs[0..3]);
1081
1082            let tmp = auxs[2];
1083            auxs[2] = ((auxs[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4);
1084            auxs[3] = ((auxs[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4);
1085            auxs[0] = (auxs[0] & KMASK2) | (((tmp) & KMASK1) << 4);
1086            auxs[1] = (auxs[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4);
1087
1088            for aux in auxs {
1089                for scale in aux.to_le_bytes() {
1090                    let scale = i8::from_be_bytes([scale]);
1091                    for l in 0..8 {
1092                        aux16[l] = q8[l] as i16 * a[l] as i16;
1093                    }
1094                    for l in 0..8 {
1095                        aux32[l] += (scale as i32 - 32) * aux16[l] as i32;
1096                    }
1097                    q8 = &q8[8..];
1098                    a = &mut a[8..];
1099
1100                    for l in 0..8 {
1101                        aux16[l] = q8[l] as i16 * a[l] as i16;
1102                    }
1103                    for l in 0..8 {
1104                        aux32[l] += (scale as i32 - 32) * aux16[l] as i32;
1105                    }
1106                    q8 = &q8[8..];
1107                    a = &mut a[8..];
1108                }
1109            }
1110            let d = x.d.to_f32() * y.d;
1111            for l in 0..8 {
1112                sums[l] += d * aux32[l] as f32;
1113            }
1114        }
1115
1116        sums.iter().sum()
1117    }
1118
1119    fn from_float(xs: &[f32], ys: &mut [Self]) {
1120        for (block, x) in group_for_quantization(xs, ys) {
1121            let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];
1122            for (j, x_scale_slice) in x.chunks_exact(16).enumerate() {
1123                scales[j] = make_q3_quants(x_scale_slice, 4, true);
1124            }
1125
1126            // Get max scale by absolute value.
1127            let mut max_scale: f32 = 0.0;
1128            for &scale in scales.iter() {
1129                if scale.abs() > max_scale.abs() {
1130                    max_scale = scale;
1131                }
1132            }
1133
1134            block.scales.fill(0);
1135
1136            if max_scale != 0.0 {
1137                let iscale = -32.0 / max_scale;
1138                for (j, scale) in scales.iter().enumerate() {
1139                    let l_val = nearest_int(iscale * scale);
1140                    let l_val = l_val.clamp(-32, 31) + 32;
1141                    if j < 8 {
1142                        block.scales[j] = (l_val & 0xF) as u8;
1143                    } else {
1144                        block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8;
1145                    }
1146                    let l_val = l_val >> 4;
1147                    block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8;
1148                }
1149                block.d = f16::from_f32(1.0 / iscale);
1150            } else {
1151                block.d = f16::from_f32(0.0);
1152            }
1153
1154            let mut l: [i8; QK_K] = [0; QK_K];
1155
1156            for j in 0..QK_K / 16 {
1157                let sc = if j < 8 {
1158                    block.scales[j] & 0xF
1159                } else {
1160                    block.scales[j - 8] >> 4
1161                };
1162                let sc = (sc | (((block.scales[8 + j % 4] >> (2 * (j / 4))) & 3) << 4)) as i8 - 32;
1163                let d = block.d.to_f32() * sc as f32;
1164                if d != 0.0 {
1165                    for ii in 0..16 {
1166                        let l_val = nearest_int(x[16 * j + ii] / d);
1167                        l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8;
1168                    }
1169                }
1170            }
1171
1172            block.hmask.fill(0);
1173            let mut m = 0;
1174            let mut hm = 1;
1175
1176            for ll in l.iter_mut() {
1177                if *ll > 3 {
1178                    block.hmask[m] |= hm;
1179                    *ll -= 4;
1180                }
1181                m += 1;
1182                if m == QK_K / 8 {
1183                    m = 0;
1184                    hm <<= 1;
1185                }
1186            }
1187
1188            for j in (0..QK_K).step_by(128) {
1189                for l_val in 0..32 {
1190                    block.qs[j / 4 + l_val] = (l[j + l_val]
1191                        | (l[j + l_val + 32] << 2)
1192                        | (l[j + l_val + 64] << 4)
1193                        | (l[j + l_val + 96] << 6))
1194                        as u8;
1195                }
1196            }
1197        }
1198    }
1199
1200    fn from_float_imatrix(xs: &[f32], ys: &mut [Self], imatrix_weights: &[f32], n_per_row: usize) {
1201        for (sblk_idx, (block, x)) in group_for_quantization(xs, ys).into_iter().enumerate() {
1202            let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];
1203            let mut weights: [f32; 16] = [0.0; 16];
1204            let mut sw: [f32; QK_K / 16] = [0.0; QK_K / 16];
1205            let mut ls: [i8; QK_K / 16] = [0; QK_K / 16];
1206            let mut l: [i8; QK_K] = [0; QK_K];
1207
1208            let sum_x2 = x.iter().map(|x| x * x).sum::<f32>();
1209            let sigma2 = 2. * sum_x2 / QK_K as f32;
1210
1211            for (j, x_scale_slice) in x.chunks_exact(16).enumerate() {
1212                for (l_idx, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() {
1213                    let imatrix_row = sblk_idx % (n_per_row / QK_K);
1214                    let imatrix_w = imatrix_weights[imatrix_row * QK_K + 16 * j + l_idx];
1215                    *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt();
1216                }
1217                let sumw = weights.iter().sum::<f32>();
1218                sw[j] = sumw;
1219                scales[j] = unsafe {
1220                    make_qx_quants(
1221                        16,
1222                        4,
1223                        x_scale_slice.as_ptr(),
1224                        l.as_mut_ptr().add(16 * j),
1225                        1,
1226                        weights.as_ptr(),
1227                    )
1228                };
1229            }
1230
1231            block.scales.fill(0);
1232            let d_block = unsafe {
1233                make_qx_quants(
1234                    QK_K / 16,
1235                    32,
1236                    scales.as_ptr(),
1237                    ls.as_mut_ptr(),
1238                    1,
1239                    sw.as_ptr(),
1240                )
1241            };
1242            block.d = f16::from_f32(d_block);
1243            for (j, l_val) in ls.iter().enumerate().take(QK_K / 16) {
1244                if j < 8 {
1245                    block.scales[j] = (l_val & 0xF) as u8;
1246                } else {
1247                    block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8;
1248                }
1249                let l_val = l_val >> 4;
1250                block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8;
1251            }
1252
1253            for j in 0..QK_K / 16 {
1254                let sc = if j < 8 {
1255                    block.scales[j] & 0xF
1256                } else {
1257                    block.scales[j - 8] >> 4
1258                };
1259                let sc = (sc | (((block.scales[8 + j % 4] >> (2 * (j / 4))) & 3) << 4)) as i8 - 32;
1260                let d = block.d.to_f32() * sc as f32;
1261                if d != 0.0 {
1262                    for ii in 0..16 {
1263                        let l_val = nearest_int(x[16 * j + ii] / d);
1264                        l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8;
1265                    }
1266                }
1267            }
1268
1269            block.hmask.fill(0);
1270            let mut m = 0;
1271            let mut hm = 1;
1272
1273            for ll in l.iter_mut() {
1274                if *ll > 3 {
1275                    block.hmask[m] |= hm;
1276                    *ll -= 4;
1277                }
1278                m += 1;
1279                if m == QK_K / 8 {
1280                    m = 0;
1281                    hm <<= 1;
1282                }
1283            }
1284
1285            for j in (0..QK_K).step_by(128) {
1286                for l_val in 0..32 {
1287                    block.qs[j / 4 + l_val] = (l[j + l_val]
1288                        | (l[j + l_val + 32] << 2)
1289                        | (l[j + l_val + 64] << 4)
1290                        | (l[j + l_val + 96] << 6))
1291                        as u8;
1292                }
1293            }
1294        }
1295    }
1296
1297    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L533
1298    fn to_float(xs: &[Self], ys: &mut [f32]) {
1299        const KMASK1: u32 = 0x03030303;
1300        const KMASK2: u32 = 0x0f0f0f0f;
1301
1302        for (block, y) in group_for_dequantization(xs, ys) {
1303            //Reconstruct the scales
1304            let mut aux = [0; 4];
1305            LittleEndian::read_u32_into(&block.scales, &mut aux[0..3]);
1306
1307            let tmp = aux[2];
1308            aux[2] = ((aux[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4);
1309            aux[3] = ((aux[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4);
1310            aux[0] = (aux[0] & KMASK2) | (((tmp) & KMASK1) << 4);
1311            aux[1] = (aux[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4);
1312
1313            //Transfer the scales into an i8 array
1314            let scales: &mut [i8] =
1315                unsafe { std::slice::from_raw_parts_mut(aux.as_mut_ptr() as *mut i8, 16) };
1316
1317            let d_all = block.d.to_f32();
1318            let mut m = 1;
1319            let mut is = 0;
1320
1321            // Dequantize both 128 long blocks
1322            // 32 qs values per 128 long block
1323            // Each 16 elements get a scale
1324            for (y, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) {
1325                let mut shift = 0;
1326                for shift_scoped_y in y.chunks_exact_mut(32) {
1327                    for (scale_index, scale_scoped_y) in
1328                        shift_scoped_y.chunks_exact_mut(16).enumerate()
1329                    {
1330                        let dl = d_all * (scales[is] as f32 - 32.0);
1331                        for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
1332                            let new_y = dl
1333                                * (((qs[i + 16 * scale_index] >> shift) & 3) as i8
1334                                    - if (block.hmask[i + 16 * scale_index] & m) == 0 {
1335                                        4
1336                                    } else {
1337                                        0
1338                                    }) as f32;
1339                            *inner_y = new_y;
1340                        }
1341                        // 16 block finished => advance scale index
1342                        is += 1;
1343                    }
1344                    // 32 block finished => increase shift and m
1345                    shift += 2;
1346                    m <<= 1;
1347                }
1348            }
1349        }
1350    }
1351}
1352
1353impl GgmlType for BlockQ4K {
1354    const DTYPE: GgmlDType = GgmlDType::Q4K;
1355    const BLCK_SIZE: usize = QK_K;
1356    type VecDotType = BlockQ8K;
1357
1358    #[allow(unreachable_code)]
1359    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
1360        #[cfg(target_feature = "avx2")]
1361        return super::avx::vec_dot_q4k_q8k(n, xs, ys);
1362
1363        #[cfg(target_feature = "neon")]
1364        return super::neon::vec_dot_q4k_q8k(n, xs, ys);
1365
1366        #[cfg(target_feature = "simd128")]
1367        return super::simd128::vec_dot_q4k_q8k(n, xs, ys);
1368
1369        Self::vec_dot_unopt(n, xs, ys)
1370    }
1371
1372    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
1373        debug_assert!(
1374            n.is_multiple_of(QK_K),
1375            "vec_dot_q4k_q8k: {n} is not divisible by {QK_K}"
1376        );
1377
1378        const KMASK1: u32 = 0x3f3f3f3f;
1379        const KMASK2: u32 = 0x0f0f0f0f;
1380        const KMASK3: u32 = 0x03030303;
1381
1382        let mut utmp: [u32; 4] = [0; 4];
1383        let mut scales: [u8; 8] = [0; 8];
1384        let mut mins: [u8; 8] = [0; 8];
1385
1386        let mut aux8: [i8; QK_K] = [0; QK_K];
1387        let mut aux16: [i16; 8] = [0; 8];
1388        let mut sums: [f32; 8] = [0.0; 8];
1389        let mut aux32: [i32; 8] = [0; 8];
1390
1391        let mut sumf = 0.0;
1392        for (y, x) in ys.iter().zip(xs.iter()) {
1393            let q4 = &x.qs;
1394            let q8 = &y.qs;
1395            aux32.fill(0);
1396
1397            let mut a = &mut aux8[..];
1398            let mut q4 = &q4[..];
1399            for _ in 0..QK_K / 64 {
1400                for l in 0..32 {
1401                    a[l] = (q4[l] & 0xF) as i8;
1402                }
1403                a = &mut a[32..];
1404                for l in 0..32 {
1405                    a[l] = (q4[l] >> 4) as i8;
1406                }
1407                a = &mut a[32..];
1408                q4 = &q4[32..];
1409            }
1410
1411            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
1412
1413            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
1414            let uaux = utmp[1] & KMASK1;
1415            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
1416            utmp[2] = uaux;
1417            utmp[0] &= KMASK1;
1418
1419            //extract scales and mins
1420            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
1421            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);
1422
1423            let mut sumi = 0;
1424            for j in 0..QK_K / 16 {
1425                sumi += y.bsums[j] as i32 * mins[j / 2] as i32;
1426            }
1427
1428            let mut a = &mut aux8[..];
1429            let mut q8 = &q8[..];
1430
1431            for scale in scales {
1432                let scale = scale as i32;
1433                for _ in 0..4 {
1434                    for l in 0..8 {
1435                        aux16[l] = q8[l] as i16 * a[l] as i16;
1436                    }
1437                    for l in 0..8 {
1438                        aux32[l] += scale * aux16[l] as i32;
1439                    }
1440                    q8 = &q8[8..];
1441                    a = &mut a[8..];
1442                }
1443            }
1444            let d = x.d.to_f32() * y.d;
1445            for l in 0..8 {
1446                sums[l] += d * aux32[l] as f32;
1447            }
1448            let dmin = x.dmin.to_f32() * y.d;
1449            sumf -= dmin * sumi as f32;
1450        }
1451        sumf + sums.iter().sum::<f32>()
1452    }
1453
1454    fn from_float(xs: &[f32], ys: &mut [Self]) {
1455        for (block, x) in group_for_quantization(xs, ys) {
1456            let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
1457            let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];
1458
1459            for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
1460                (scales[j], mins[j]) = make_qkx1_quants(15, 5, x_scale_slice);
1461            }
1462
1463            // get max scale and max min and ensure they are >= 0.0
1464            let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
1465            let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));
1466
1467            let inv_scale = if max_scale > 0.0 {
1468                63.0 / max_scale
1469            } else {
1470                0.0
1471            };
1472            let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
1473
1474            for j in 0..QK_K / 32 {
1475                let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
1476                let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
1477                if j < 4 {
1478                    block.scales[j] = ls;
1479                    block.scales[j + 4] = lm;
1480                } else {
1481                    block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4);
1482                    block.scales[j - 4] |= (ls >> 4) << 6;
1483                    block.scales[j] |= (lm >> 4) << 6;
1484                }
1485            }
1486
1487            block.d = f16::from_f32(max_scale / 63.0);
1488            block.dmin = f16::from_f32(max_min / 63.0);
1489
1490            let mut l: [u8; QK_K] = [0; QK_K];
1491
1492            for j in 0..QK_K / 32 {
1493                let (sc, m) = get_scale_min_k4(j, &block.scales);
1494                let d = block.d.to_f32() * sc as f32;
1495                if d != 0.0 {
1496                    let dm = block.dmin.to_f32() * m as f32;
1497                    for ii in 0..32 {
1498                        let l_val = nearest_int((x[32 * j + ii] + dm) / d);
1499                        l[32 * j + ii] = l_val.clamp(0, 15) as u8;
1500                    }
1501                }
1502            }
1503
1504            let q = &mut block.qs;
1505            for j in (0..QK_K).step_by(64) {
1506                for l_val in 0..32 {
1507                    let offset_index = (j / 64) * 32 + l_val;
1508                    q[offset_index] = l[j + l_val] | (l[j + l_val + 32] << 4);
1509                }
1510            }
1511        }
1512    }
1513
1514    fn from_float_imatrix(xs: &[f32], ys: &mut [Self], imatrix_weights: &[f32], n_per_row: usize) {
1515        for (sblk_idx, (block, x)) in group_for_quantization(xs, ys).into_iter().enumerate() {
1516            let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
1517            let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];
1518            let mut weights: [f32; 32] = [0.0; 32];
1519            let mut sw: [f32; QK_K / 32] = [0.0; QK_K / 32];
1520            let mut ls: [u8; QK_K / 32] = [0; QK_K / 32];
1521            let mut lm: [u8; QK_K / 32] = [0; QK_K / 32];
1522
1523            let sum_x2 = x.iter().map(|x| x * x).sum::<f32>();
1524            let sigma2 = 2. * sum_x2 / QK_K as f32;
1525
1526            for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
1527                for (l, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() {
1528                    let imatrix_row = sblk_idx % (n_per_row / QK_K);
1529                    let imatrix_w = imatrix_weights[imatrix_row * QK_K + 32 * j + l];
1530                    *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt();
1531                }
1532                let sumw = weights.iter().sum::<f32>();
1533                sw[j] = sumw;
1534                (scales[j], mins[j]) =
1535                    make_qkx3_quants(15, x_scale_slice, Some(&weights), -0.9, 0.05, 36, false);
1536            }
1537
1538            let d_block = make_qp_quants(QK_K / 32, 63, &scales, &mut ls, &sw);
1539            let m_block = make_qp_quants(QK_K / 32, 63, &mins, &mut lm, &sw);
1540            for j in 0..QK_K / 32 {
1541                let ls_val = ls[j];
1542                let lm_val = lm[j];
1543                if j < 4 {
1544                    block.scales[j] = ls_val;
1545                    block.scales[j + 4] = lm_val;
1546                } else {
1547                    block.scales[j + 4] = (ls_val & 0xF) | ((lm_val & 0xF) << 4);
1548                    block.scales[j - 4] |= (ls_val >> 4) << 6;
1549                    block.scales[j] |= (lm_val >> 4) << 6;
1550                }
1551            }
1552
1553            block.d = f16::from_f32(d_block);
1554            block.dmin = f16::from_f32(m_block);
1555
1556            let mut l: [u8; QK_K] = [0; QK_K];
1557            for j in 0..QK_K / 32 {
1558                let (sc, m) = get_scale_min_k4(j, &block.scales);
1559                let d = block.d.to_f32() * sc as f32;
1560                if d != 0.0 {
1561                    let dm = block.dmin.to_f32() * m as f32;
1562                    for ii in 0..32 {
1563                        let l_val = nearest_int((x[32 * j + ii] + dm) / d);
1564                        l[32 * j + ii] = l_val.clamp(0, 15) as u8;
1565                    }
1566                }
1567            }
1568
1569            let q = &mut block.qs;
1570            for j in (0..QK_K).step_by(64) {
1571                for l_val in 0..32 {
1572                    let offset_index = (j / 64) * 32 + l_val;
1573                    q[offset_index] = l[j + l_val] | (l[j + l_val + 32] << 4);
1574                }
1575            }
1576        }
1577    }
1578    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L735
1579    fn to_float(xs: &[Self], ys: &mut [f32]) {
1580        for (block, y) in group_for_dequantization(xs, ys) {
1581            let d = block.d.to_f32();
1582            let min = block.dmin.to_f32();
1583            let q = &block.qs;
1584            let mut is = 0;
1585            let mut ys_index = 0;
1586
1587            for j in (0..QK_K).step_by(64) {
1588                let q = &q[j / 2..j / 2 + 32];
1589                let (sc, m) = get_scale_min_k4(is, &block.scales);
1590                let d1 = d * sc as f32;
1591                let m1 = min * m as f32;
1592                let (sc, m) = get_scale_min_k4(is + 1, &block.scales);
1593                let d2 = d * sc as f32;
1594                let m2 = min * m as f32;
1595                for q in q {
1596                    y[ys_index] = d1 * (q & 0xF) as f32 - m1;
1597                    ys_index += 1;
1598                }
1599                for q in q {
1600                    y[ys_index] = d2 * (q >> 4) as f32 - m2;
1601                    ys_index += 1;
1602                }
1603                is += 2;
1604            }
1605        }
1606    }
1607}
1608
1609// https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
1610impl GgmlType for BlockQ5K {
1611    const DTYPE: GgmlDType = GgmlDType::Q5K;
1612    const BLCK_SIZE: usize = QK_K;
1613    type VecDotType = BlockQ8K;
1614
1615    #[allow(unreachable_code)]
1616    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
1617        #[cfg(target_feature = "avx2")]
1618        return super::avx::vec_dot_q5k_q8k(n, xs, ys);
1619
1620        #[cfg(target_feature = "neon")]
1621        return super::neon::vec_dot_q5k_q8k(n, xs, ys);
1622
1623        Self::vec_dot_unopt(n, xs, ys)
1624    }
1625
1626    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
1627        debug_assert!(
1628            n.is_multiple_of(QK_K),
1629            "vec_dot_q5k_q8k: {n} is not divisible by {QK_K}"
1630        );
1631
1632        const KMASK1: u32 = 0x3f3f3f3f;
1633        const KMASK2: u32 = 0x0f0f0f0f;
1634        const KMASK3: u32 = 0x03030303;
1635
1636        let mut utmp: [u32; 4] = [0; 4];
1637        let mut scales: [u8; 8] = [0; 8];
1638        let mut mins: [u8; 8] = [0; 8];
1639
1640        let mut aux8: [i8; QK_K] = [0; QK_K];
1641        let mut aux16: [i16; 8] = [0; 8];
1642        let mut sums: [f32; 8] = [0.0; 8];
1643        let mut aux32: [i32; 8] = [0; 8];
1644
1645        let mut sumf = 0.0;
1646        for (y, x) in ys.iter().zip(xs.iter()) {
1647            let q5 = &x.qs;
1648            let hm = &x.qh;
1649            let q8 = &y.qs;
1650            aux32.fill(0);
1651
1652            let mut a = &mut aux8[..];
1653            let mut q5 = &q5[..];
1654            let mut m = 1u8;
1655
1656            for _ in 0..QK_K / 64 {
1657                for l in 0..32 {
1658                    a[l] = (q5[l] & 0xF) as i8;
1659                    a[l] += if hm[l] & m != 0 { 16 } else { 0 };
1660                }
1661                a = &mut a[32..];
1662                m <<= 1;
1663                for l in 0..32 {
1664                    a[l] = (q5[l] >> 4) as i8;
1665                    a[l] += if hm[l] & m != 0 { 16 } else { 0 };
1666                }
1667                a = &mut a[32..];
1668                m <<= 1;
1669                q5 = &q5[32..];
1670            }
1671
1672            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);
1673
1674            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
1675            let uaux = utmp[1] & KMASK1;
1676            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
1677            utmp[2] = uaux;
1678            utmp[0] &= KMASK1;
1679
1680            //extract scales and mins
1681            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
1682            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);
1683
1684            let mut sumi = 0;
1685            for j in 0..QK_K / 16 {
1686                sumi += y.bsums[j] as i32 * mins[j / 2] as i32;
1687            }
1688
1689            let mut a = &mut aux8[..];
1690            let mut q8 = &q8[..];
1691
1692            for scale in scales {
1693                let scale = scale as i32;
1694                for _ in 0..4 {
1695                    for l in 0..8 {
1696                        aux16[l] = q8[l] as i16 * a[l] as i16;
1697                    }
1698                    for l in 0..8 {
1699                        aux32[l] += scale * aux16[l] as i32;
1700                    }
1701                    q8 = &q8[8..];
1702                    a = &mut a[8..];
1703                }
1704            }
1705            let d = x.d.to_f32() * y.d;
1706            for l in 0..8 {
1707                sums[l] += d * aux32[l] as f32;
1708            }
1709            let dmin = x.dmin.to_f32() * y.d;
1710            sumf -= dmin * sumi as f32;
1711        }
1712        sumf + sums.iter().sum::<f32>()
1713    }
1714
1715    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L793
1716    fn from_float(xs: &[f32], ys: &mut [Self]) {
1717        for (block, x) in group_for_quantization(xs, ys) {
1718            let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
1719            let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];
1720
1721            for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
1722                (scales[j], mins[j]) = make_qkx1_quants(31, 5, x_scale_slice);
1723            }
1724
1725            // get max scale and max min and ensure they are >= 0.0
1726            let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
1727            let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));
1728
1729            let inv_scale = if max_scale > 0.0 {
1730                63.0 / max_scale
1731            } else {
1732                0.0
1733            };
1734            let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
1735            for j in 0..QK_K / 32 {
1736                let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
1737                let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
1738                if j < 4 {
1739                    block.scales[j] = ls;
1740                    block.scales[j + 4] = lm;
1741                } else {
1742                    block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4);
1743                    block.scales[j - 4] |= (ls >> 4) << 6;
1744                    block.scales[j] |= (lm >> 4) << 6;
1745                }
1746            }
1747            block.d = f16::from_f32(max_scale / 63.0);
1748            block.dmin = f16::from_f32(max_min / 63.0);
1749
1750            let mut l: [u8; QK_K] = [0; QK_K];
1751            for j in 0..QK_K / 32 {
1752                let (sc, m) = get_scale_min_k4(j, &block.scales);
1753                let d = block.d.to_f32() * sc as f32;
1754                if d == 0.0 {
1755                    continue;
1756                }
1757                let dm = block.dmin.to_f32() * m as f32;
1758                for ii in 0..32 {
1759                    let ll = nearest_int((x[32 * j + ii] + dm) / d);
1760                    l[32 * j + ii] = ll.clamp(0, 31) as u8;
1761                }
1762            }
1763
1764            let qh = &mut block.qh;
1765            let ql = &mut block.qs;
1766            qh.fill(0);
1767
1768            let mut m1 = 1;
1769            let mut m2 = 2;
1770            for n in (0..QK_K).step_by(64) {
1771                let offset = (n / 64) * 32;
1772                for j in 0..32 {
1773                    let mut l1 = l[n + j];
1774                    if l1 > 15 {
1775                        l1 -= 16;
1776                        qh[j] |= m1;
1777                    }
1778                    let mut l2 = l[n + j + 32];
1779                    if l2 > 15 {
1780                        l2 -= 16;
1781                        qh[j] |= m2;
1782                    }
1783                    ql[offset + j] = l1 | (l2 << 4);
1784                }
1785                m1 <<= 2;
1786                m2 <<= 2;
1787            }
1788        }
1789    }
1790
1791    fn from_float_imatrix(xs: &[f32], ys: &mut [Self], imatrix_weights: &[f32], n_per_row: usize) {
1792        for (sblk_idx, (block, x)) in group_for_quantization(xs, ys).into_iter().enumerate() {
1793            let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
1794            let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];
1795            let mut weights: [f32; 32] = [0.0; 32];
1796            let mut sw: [f32; QK_K / 32] = [0.0; QK_K / 32];
1797            let mut ls: [u8; QK_K / 32] = [0; QK_K / 32];
1798            let mut lm: [u8; QK_K / 32] = [0; QK_K / 32];
1799
1800            let sum_x2 = x.iter().map(|x| x * x).sum::<f32>();
1801            let sigma2 = 2. * sum_x2 / QK_K as f32;
1802
1803            for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
1804                for (l, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() {
1805                    let imatrix_row = sblk_idx % (n_per_row / QK_K);
1806                    let imatrix_w = imatrix_weights[imatrix_row * QK_K + 32 * j + l];
1807                    *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt();
1808                }
1809                let sumw = weights.iter().sum::<f32>();
1810                sw[j] = sumw;
1811                (scales[j], mins[j]) =
1812                    make_qkx3_quants(31, x_scale_slice, Some(&weights), -0.9, 0.05, 36, false);
1813            }
1814
1815            let d_block = make_qp_quants(QK_K / 32, 63, &scales, &mut ls, &sw);
1816            let m_block = make_qp_quants(QK_K / 32, 63, &mins, &mut lm, &sw);
1817            for j in 0..QK_K / 32 {
1818                let ls_val = ls[j].min(63);
1819                let lm_val = lm[j].min(63);
1820                if j < 4 {
1821                    block.scales[j] = ls_val;
1822                    block.scales[j + 4] = lm_val;
1823                } else {
1824                    block.scales[j + 4] = (ls_val & 0xF) | ((lm_val & 0xF) << 4);
1825                    block.scales[j - 4] |= (ls_val >> 4) << 6;
1826                    block.scales[j] |= (lm_val >> 4) << 6;
1827                }
1828            }
1829
1830            block.d = f16::from_f32(d_block);
1831            block.dmin = f16::from_f32(m_block);
1832
1833            let mut l: [u8; QK_K] = [0; QK_K];
1834            for j in 0..QK_K / 32 {
1835                let (sc, m) = get_scale_min_k4(j, &block.scales);
1836                let d = block.d.to_f32() * sc as f32;
1837                if d != 0.0 {
1838                    let dm = block.dmin.to_f32() * m as f32;
1839                    for ii in 0..32 {
1840                        let l_val = nearest_int((x[32 * j + ii] + dm) / d);
1841                        l[32 * j + ii] = l_val.clamp(0, 31) as u8;
1842                    }
1843                }
1844            }
1845
1846            let qh = &mut block.qh;
1847            let ql = &mut block.qs;
1848            qh.fill(0);
1849
1850            let mut m1 = 1;
1851            let mut m2 = 2;
1852            for n in (0..QK_K).step_by(64) {
1853                let offset = (n / 64) * 32;
1854                for j in 0..32 {
1855                    let mut l1 = l[n + j];
1856                    if l1 > 15 {
1857                        l1 -= 16;
1858                        qh[j] |= m1;
1859                    }
1860                    let mut l2 = l[n + j + 32];
1861                    if l2 > 15 {
1862                        l2 -= 16;
1863                        qh[j] |= m2;
1864                    }
1865                    ql[offset + j] = l1 | (l2 << 4);
1866                }
1867                m1 <<= 2;
1868                m2 <<= 2;
1869            }
1870        }
1871    }
1872
1873    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L928
1874    fn to_float(xs: &[Self], ys: &mut [f32]) {
1875        for (block, y) in group_for_dequantization(xs, ys) {
1876            let d = block.d.to_f32();
1877            let min = block.dmin.to_f32();
1878            let ql = &block.qs;
1879            let qh = &block.qh;
1880            let mut is = 0;
1881            let mut u1 = 1;
1882            let mut u2 = 2;
1883            let mut ys_index = 0;
1884
1885            for j in (0..QK_K).step_by(64) {
1886                let ql = &ql[j / 2..j / 2 + 32];
1887                let (sc, m) = get_scale_min_k4(is, &block.scales);
1888                let d1 = d * sc as f32;
1889                let m1 = min * m as f32;
1890                let (sc, m) = get_scale_min_k4(is + 1, &block.scales);
1891                let d2 = d * sc as f32;
1892                let m2 = min * m as f32;
1893                for (ql, qh) in ql.iter().zip(qh) {
1894                    let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
1895                    y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
1896                    ys_index += 1;
1897                }
1898                for (ql, qh) in ql.iter().zip(qh) {
1899                    let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
1900                    y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
1901                    ys_index += 1;
1902                }
1903                is += 2;
1904                u1 <<= 2;
1905                u2 <<= 2;
1906            }
1907        }
1908    }
1909}
1910
1911impl GgmlType for BlockQ6K {
1912    const DTYPE: GgmlDType = GgmlDType::Q6K;
1913    const BLCK_SIZE: usize = QK_K;
1914    type VecDotType = BlockQ8K;
1915
1916    #[allow(unreachable_code)]
1917    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
1918        #[cfg(target_feature = "avx2")]
1919        return super::avx::vec_dot_q6k_q8k(n, xs, ys);
1920
1921        #[cfg(target_feature = "neon")]
1922        return super::neon::vec_dot_q6k_q8k(n, xs, ys);
1923
1924        #[cfg(target_feature = "simd128")]
1925        return super::simd128::vec_dot_q6k_q8k(n, xs, ys);
1926
1927        Self::vec_dot_unopt(n, xs, ys)
1928    }
1929
1930    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
1931        debug_assert!(
1932            n.is_multiple_of(QK_K),
1933            "vec_dot_q6k_q8k: {n} is not divisible by {QK_K}"
1934        );
1935
1936        let mut aux8 = [0i8; QK_K];
1937        let mut aux16 = [0i16; 8];
1938        let mut sums = [0f32; 8];
1939        let mut aux32 = [0f32; 8];
1940
1941        for (x, y) in xs.iter().zip(ys.iter()) {
1942            let q4 = &x.ql;
1943            let qh = &x.qh;
1944            let q8 = &y.qs;
1945            aux32.fill(0f32);
1946
1947            for j in (0..QK_K).step_by(128) {
1948                let aux8 = &mut aux8[j..];
1949                let q4 = &q4[j / 2..];
1950                let qh = &qh[j / 4..];
1951                for l in 0..32 {
1952                    aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
1953                    aux8[l + 32] =
1954                        (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
1955                    aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
1956                    aux8[l + 96] =
1957                        (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
1958                }
1959            }
1960
1961            for (j, &scale) in x.scales.iter().enumerate() {
1962                let scale = scale as f32;
1963                let q8 = &q8[16 * j..];
1964                let aux8 = &aux8[16 * j..];
1965                for l in 0..8 {
1966                    aux16[l] = q8[l] as i16 * aux8[l] as i16;
1967                }
1968                for l in 0..8 {
1969                    aux32[l] += scale * aux16[l] as f32
1970                }
1971                let q8 = &q8[8..];
1972                let aux8 = &aux8[8..];
1973                for l in 0..8 {
1974                    aux16[l] = q8[l] as i16 * aux8[l] as i16;
1975                }
1976                for l in 0..8 {
1977                    aux32[l] += scale * aux16[l] as f32
1978                }
1979            }
1980
1981            let d = x.d.to_f32() * y.d;
1982            for (sum, &a) in sums.iter_mut().zip(aux32.iter()) {
1983                *sum += a * d;
1984            }
1985        }
1986        sums.iter().sum()
1987    }
1988
1989    fn from_float(xs: &[f32], ys: &mut [Self]) {
1990        debug_assert_eq!(
1991            xs.len(),
1992            ys.len() * Self::BLCK_SIZE,
1993            "quantize_row_q6k: size mismatch {} {} {}",
1994            xs.len(),
1995            ys.len(),
1996            Self::BLCK_SIZE
1997        );
1998        let mut l = [0i8; QK_K];
1999        let mut scales = [0f32; QK_K / 16];
2000        let mut x = xs.as_ptr();
2001        let l = l.as_mut_ptr();
2002        unsafe {
2003            for y in ys.iter_mut() {
2004                let mut max_scale = 0f32;
2005                let mut max_abs_scale = 0f32;
2006                for (ib, scale_) in scales.iter_mut().enumerate() {
2007                    let scale =
2008                        make_qx_quants(16, 32, x.add(16 * ib), l.add(16 * ib), 1, std::ptr::null());
2009                    *scale_ = scale;
2010                    let abs_scale = scale.abs();
2011                    if abs_scale > max_abs_scale {
2012                        max_abs_scale = abs_scale;
2013                        max_scale = scale
2014                    }
2015                }
2016
2017                let iscale = -128f32 / max_scale;
2018                y.d = f16::from_f32(1.0 / iscale);
2019
2020                for (y_scale, scale) in y.scales.iter_mut().zip(scales.iter()) {
2021                    *y_scale = nearest_int(iscale * scale).min(127) as i8
2022                }
2023
2024                for (j, &y_scale) in y.scales.iter().enumerate() {
2025                    let d = y.d.to_f32() * y_scale as f32;
2026                    if d == 0. {
2027                        continue;
2028                    }
2029                    for ii in 0..16 {
2030                        let ll = nearest_int(*x.add(16 * j + ii) / d).clamp(-32, 31);
2031                        *l.add(16 * j + ii) = (ll + 32) as i8
2032                    }
2033                }
2034
2035                let mut ql = y.ql.as_mut_ptr();
2036                let mut qh = y.qh.as_mut_ptr();
2037
2038                for j in (0..QK_K).step_by(128) {
2039                    for l_idx in 0..32 {
2040                        let q1 = *l.add(j + l_idx) & 0xF;
2041                        let q2 = *l.add(j + l_idx + 32) & 0xF;
2042                        let q3 = *l.add(j + l_idx + 64) & 0xF;
2043                        let q4 = *l.add(j + l_idx + 96) & 0xF;
2044                        *ql.add(l_idx) = (q1 | (q3 << 4)) as u8;
2045                        *ql.add(l_idx + 32) = (q2 | (q4 << 4)) as u8;
2046                        *qh.add(l_idx) = ((*l.add(j + l_idx) >> 4)
2047                            | ((*l.add(j + l_idx + 32) >> 4) << 2)
2048                            | ((*l.add(j + l_idx + 64) >> 4) << 4)
2049                            | ((*l.add(j + l_idx + 96) >> 4) << 6))
2050                            as u8;
2051                    }
2052                    ql = ql.add(64);
2053                    qh = qh.add(32);
2054                }
2055
2056                x = x.add(QK_K)
2057            }
2058        }
2059    }
2060
2061    fn from_float_imatrix(xs: &[f32], ys: &mut [Self], imatrix_weights: &[f32], n_per_row: usize) {
2062        debug_assert_eq!(
2063            xs.len(),
2064            ys.len() * Self::BLCK_SIZE,
2065            "quantize_row_q6k imatrix: size mismatch {} {} {}",
2066            xs.len(),
2067            ys.len(),
2068            Self::BLCK_SIZE
2069        );
2070        let mut l = [0i8; QK_K];
2071        let mut scales = [0f32; QK_K / 16];
2072        let mut x = xs.as_ptr();
2073        let imatrix_weights = imatrix_weights.as_ptr();
2074        let l = l.as_mut_ptr();
2075        unsafe {
2076            for (sblk_idx, y) in ys.iter_mut().enumerate() {
2077                let mut max_scale = 0f32;
2078                let mut max_abs_scale = 0f32;
2079                for (ib, scale_) in scales.iter_mut().enumerate() {
2080                    let imatrix_row = sblk_idx % (n_per_row / QK_K);
2081                    let scale = make_qx_quants(
2082                        16,
2083                        32,
2084                        x.add(16 * ib),
2085                        l.add(16 * ib),
2086                        1,
2087                        imatrix_weights.add(QK_K * imatrix_row + 16 * ib),
2088                    );
2089                    *scale_ = scale;
2090                    let abs_scale = scale.abs();
2091                    if abs_scale > max_abs_scale {
2092                        max_abs_scale = abs_scale;
2093                        max_scale = scale
2094                    }
2095                }
2096
2097                let iscale = -128f32 / max_scale;
2098                y.d = f16::from_f32(1.0 / iscale);
2099
2100                for (y_scale, scale) in y.scales.iter_mut().zip(scales.iter()) {
2101                    *y_scale = nearest_int(iscale * scale).min(127) as i8
2102                }
2103
2104                for (j, &y_scale) in y.scales.iter().enumerate() {
2105                    let d = y.d.to_f32() * y_scale as f32;
2106                    if d == 0. {
2107                        continue;
2108                    }
2109                    for ii in 0..16 {
2110                        let ll = nearest_int(*x.add(16 * j + ii) / d).clamp(-32, 31);
2111                        *l.add(16 * j + ii) = (ll + 32) as i8
2112                    }
2113                }
2114
2115                let mut ql = y.ql.as_mut_ptr();
2116                let mut qh = y.qh.as_mut_ptr();
2117
2118                for j in (0..QK_K).step_by(128) {
2119                    for l_idx in 0..32 {
2120                        let q1 = *l.add(j + l_idx) & 0xF;
2121                        let q2 = *l.add(j + l_idx + 32) & 0xF;
2122                        let q3 = *l.add(j + l_idx + 64) & 0xF;
2123                        let q4 = *l.add(j + l_idx + 96) & 0xF;
2124                        *ql.add(l_idx) = (q1 | (q3 << 4)) as u8;
2125                        *ql.add(l_idx + 32) = (q2 | (q4 << 4)) as u8;
2126                        *qh.add(l_idx) = ((*l.add(j + l_idx) >> 4)
2127                            | ((*l.add(j + l_idx + 32) >> 4) << 2)
2128                            | ((*l.add(j + l_idx + 64) >> 4) << 4)
2129                            | ((*l.add(j + l_idx + 96) >> 4) << 6))
2130                            as u8;
2131                    }
2132                    ql = ql.add(64);
2133                    qh = qh.add(32);
2134                }
2135
2136                x = x.add(QK_K)
2137            }
2138        }
2139    }
2140
2141    // https://github.com/ggerganov/llama.cpp/blob/8183159cf3def112f6d1fe94815fce70e1bffa12/k_quants.c#L1067
2142    fn to_float(xs: &[Self], ys: &mut [f32]) {
2143        let k = ys.len();
2144        debug_assert!(
2145            k.is_multiple_of(QK_K),
2146            "dequantize_row_q6k: {k} is not divisible by {QK_K}"
2147        );
2148
2149        for (idx_x, x) in xs.iter().enumerate() {
2150            let d = x.d.to_f32();
2151            let ql = &x.ql;
2152            let qh = &x.qh;
2153            let sc = &x.scales;
2154            for n in (0..QK_K).step_by(128) {
2155                let idx = n / 128;
2156                let ys = &mut ys[idx_x * QK_K + n..];
2157                let sc = &sc[8 * idx..];
2158                let ql = &ql[64 * idx..];
2159                let qh = &qh[32 * idx..];
2160                for l in 0..32 {
2161                    let is = l / 16;
2162                    let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
2163                    let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
2164                    let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
2165                    let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
2166                    ys[l] = d * sc[is] as f32 * q1 as f32;
2167                    ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
2168                    ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
2169                    ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
2170                }
2171            }
2172        }
2173    }
2174}
2175
2176impl GgmlType for BlockQ8K {
2177    const DTYPE: GgmlDType = GgmlDType::Q8K;
2178    const BLCK_SIZE: usize = QK_K;
2179    type VecDotType = BlockQ8K;
2180
2181    #[allow(unreachable_code)]
2182    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
2183        #[cfg(target_feature = "avx2")]
2184        return super::avx::vec_dot_q8k_q8k(n, xs, ys);
2185
2186        #[cfg(target_feature = "neon")]
2187        return super::neon::vec_dot_q8k_q8k(n, xs, ys);
2188
2189        #[cfg(target_feature = "simd128")]
2190        return super::simd128::vec_dot_q8k_q8k(n, xs, ys);
2191
2192        Self::vec_dot_unopt(n, xs, ys)
2193    }
2194
2195    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
2196        debug_assert!(
2197            n.is_multiple_of(QK_K),
2198            "vec_dot_q8k_q8k: {n} is not divisible by {QK_K}"
2199        );
2200        // Generic implementation.
2201        let mut sumf = 0f32;
2202        for (xs, ys) in xs.iter().zip(ys.iter()) {
2203            let sum_i = xs
2204                .qs
2205                .iter()
2206                .zip(ys.qs.iter())
2207                .map(|(&x, &y)| x as i32 * y as i32)
2208                .sum::<i32>();
2209            sumf += sum_i as f32 * xs.d * ys.d
2210        }
2211        sumf
2212    }
2213
2214    fn from_float(xs: &[f32], ys: &mut [Self]) {
2215        let k = xs.len();
2216        debug_assert!(
2217            k.is_multiple_of(QK_K),
2218            "quantize_row_q8k: {k} is not divisible by {QK_K}"
2219        );
2220        for (i, y) in ys.iter_mut().enumerate() {
2221            let mut max = 0f32;
2222            let mut amax = 0f32;
2223            let xs = &xs[i * QK_K..(i + 1) * QK_K];
2224            for &x in xs.iter() {
2225                if amax < x.abs() {
2226                    amax = x.abs();
2227                    max = x;
2228                }
2229            }
2230            if amax == 0f32 {
2231                y.d = 0f32;
2232                y.qs.fill(0)
2233            } else {
2234                let iscale = -128f32 / max;
2235                for (j, q) in y.qs.iter_mut().enumerate() {
2236                    // ggml uses nearest_int with bit magic here, maybe we want the same
2237                    // but we would have to test and benchmark it.
2238                    let v = (iscale * xs[j]).round();
2239                    *q = v.min(127.) as i8
2240                }
2241                for j in 0..QK_K / 16 {
2242                    let mut sum = 0i32;
2243                    for ii in 0..16 {
2244                        sum += y.qs[j * 16 + ii] as i32
2245                    }
2246                    y.bsums[j] = sum as i16
2247                }
2248                y.d = 1.0 / iscale
2249            }
2250        }
2251    }
2252
2253    fn to_float(xs: &[Self], ys: &mut [f32]) {
2254        let k = ys.len();
2255        debug_assert!(
2256            k.is_multiple_of(QK_K),
2257            "dequantize_row_q8k: {k} is not divisible by {QK_K}"
2258        );
2259        for (i, x) in xs.iter().enumerate() {
2260            for (j, &q) in x.qs.iter().enumerate() {
2261                ys[i * QK_K + j] = x.d * q as f32
2262            }
2263        }
2264    }
2265}
2266
2267// https://github.com/ggml-org/llama.cpp/blob/aa3ee0eb0b80efca126cedf9bcb4fb5864b46ce3/ggml/src/ggml-cpu/ggml-cpu.c#L1205
2268pub fn matmul<T: GgmlType>(
2269    (m, k, n): (usize, usize, usize),
2270    lhs: &[f32],
2271    rhs_t: &[T],
2272    dst: &mut [f32],
2273) -> Result<()> {
2274    debug_assert_eq!(
2275        T::BLCK_SIZE,
2276        T::VecDotType::BLCK_SIZE,
2277        "Mismatched block sizes"
2278    );
2279    debug_assert_eq!(
2280        m * k,
2281        lhs.len(),
2282        "unexpected lhs length {} ({m},{k},{n})",
2283        lhs.len()
2284    );
2285    let k_in_blocks = k.div_ceil(T::BLCK_SIZE);
2286
2287    // TODO: Pre-allocate this.
2288    let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_blocks];
2289    // f32, f16, and bf16 support direct copy
2290    if T::DIRECT_COPY {
2291        T::VecDotType::direct_copy(lhs, &mut lhs_b);
2292    } else {
2293        for row_idx in 0..m {
2294            let lhs_b_mut = &mut lhs_b[row_idx * k_in_blocks..(row_idx + 1) * k_in_blocks];
2295            let lhs = &lhs[row_idx * k..(row_idx + 1) * k];
2296            T::VecDotType::from_float(lhs, lhs_b_mut)
2297        }
2298    }
2299
2300    for row_idx in 0..m {
2301        let lhs_row = &lhs_b[row_idx * k_in_blocks..(row_idx + 1) * k_in_blocks];
2302        let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
2303
2304        dst_row
2305            .into_par_iter()
2306            .enumerate()
2307            .with_min_len(128)
2308            .with_max_len(512)
2309            .for_each(|(col_idx, dst)| {
2310                let rhs_col = &rhs_t[col_idx * k_in_blocks..(col_idx + 1) * k_in_blocks];
2311                *dst = T::vec_dot(k, rhs_col, lhs_row);
2312            });
2313    }
2314    Ok(())
2315}
2316
2317pub fn matmul_f16<T: GgmlType>(
2318    mkn: (usize, usize, usize),
2319    lhs: &[f16],
2320    rhs_t: &[T],
2321    dst: &mut [f16],
2322) -> Result<()> {
2323    let (m, k, n) = mkn;
2324    if m * k != lhs.len() {
2325        crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
2326    }
2327
2328    let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE);
2329    let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE);
2330    let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];
2331    for row_idx in 0..m {
2332        let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
2333        let lhs = &lhs[row_idx * k..(row_idx + 1) * k];
2334        let lhs_f32: Vec<_> = lhs.iter().map(|&x| x.to_f32()).collect();
2335        T::VecDotType::from_float(&lhs_f32, lhs_b);
2336    }
2337    let lhs_b = lhs_b.as_slice();
2338
2339    for row_idx in 0..m {
2340        let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
2341        let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
2342
2343        for (col_idx, dst) in dst_row.iter_mut().enumerate() {
2344            let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
2345            let value = T::vec_dot(k, rhs_col, lhs_row);
2346            *dst = f16::from_f32(value);
2347        }
2348    }
2349    Ok(())
2350}
2351
2352impl GgmlType for f32 {
2353    const DTYPE: GgmlDType = GgmlDType::F32;
2354    const BLCK_SIZE: usize = 1;
2355    const DIRECT_COPY: bool = true;
2356    type VecDotType = f32;
2357
2358    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
2359        Self::vec_dot_unopt(n, xs, ys)
2360    }
2361
2362    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
2363        debug_assert!(xs.len() >= n, "size mismatch xs {} < {n}", xs.len());
2364        debug_assert!(ys.len() >= n, "size mismatch ys {} < {n}", ys.len());
2365        let mut res = 0f32;
2366        unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
2367        res
2368    }
2369
2370    fn from_float(xs: &[f32], ys: &mut [Self]) {
2371        debug_assert_eq!(
2372            xs.len(),
2373            ys.len(),
2374            "size mismatch xs {} != ys {}",
2375            xs.len(),
2376            ys.len()
2377        );
2378        ys.copy_from_slice(xs);
2379    }
2380
2381    fn to_float(xs: &[Self], ys: &mut [f32]) {
2382        debug_assert_eq!(
2383            xs.len(),
2384            ys.len(),
2385            "size mismatch xs {} != ys {}",
2386            xs.len(),
2387            ys.len()
2388        );
2389        ys.copy_from_slice(xs);
2390    }
2391
2392    fn direct_copy(xs: &[f32], ys: &mut [Self]) {
2393        Self::from_float(xs, ys)
2394    }
2395}
2396
2397impl GgmlType for f16 {
2398    const DTYPE: GgmlDType = GgmlDType::F16;
2399    const BLCK_SIZE: usize = 1;
2400    const DIRECT_COPY: bool = true;
2401    type VecDotType = f16;
2402
2403    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
2404        Self::vec_dot_unopt(n, xs, ys)
2405    }
2406
2407    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
2408        debug_assert!(xs.len() >= n, "size mismatch xs {} < {n}", xs.len());
2409        debug_assert!(ys.len() >= n, "size mismatch ys {} < {n}", ys.len());
2410        let mut res = 0f32;
2411        unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
2412        res
2413    }
2414
2415    fn from_float(xs: &[f32], ys: &mut [Self]) {
2416        debug_assert_eq!(
2417            xs.len(),
2418            ys.len(),
2419            "size mismatch xs {} != ys {}",
2420            xs.len(),
2421            ys.len()
2422        );
2423        ys.convert_from_f32_slice(xs);
2424    }
2425
2426    fn to_float(xs: &[Self], ys: &mut [f32]) {
2427        debug_assert_eq!(
2428            xs.len(),
2429            ys.len(),
2430            "size mismatch xs {} != ys {}",
2431            xs.len(),
2432            ys.len()
2433        );
2434        xs.convert_to_f32_slice(ys);
2435    }
2436
2437    fn direct_copy(xs: &[f32], ys: &mut [Self]) {
2438        Self::from_float(xs, ys)
2439    }
2440}
2441
2442impl GgmlType for bf16 {
2443    const DTYPE: GgmlDType = GgmlDType::BF16;
2444    const BLCK_SIZE: usize = 1;
2445    const DIRECT_COPY: bool = true;
2446    type VecDotType = bf16;
2447
2448    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
2449        Self::vec_dot_unopt(n, xs, ys)
2450    }
2451
2452    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
2453        debug_assert!(xs.len() >= n, "size mismatch xs {} < {n}", xs.len());
2454        debug_assert!(ys.len() >= n, "size mismatch ys {} < {n}", ys.len());
2455        let mut res = 0f32;
2456        unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
2457        res
2458    }
2459
2460    fn from_float(xs: &[f32], ys: &mut [Self]) {
2461        debug_assert_eq!(
2462            xs.len(),
2463            ys.len(),
2464            "size mismatch xs {} != ys {}",
2465            xs.len(),
2466            ys.len()
2467        );
2468        ys.convert_from_f32_slice(xs);
2469    }
2470
2471    fn to_float(xs: &[Self], ys: &mut [f32]) {
2472        debug_assert_eq!(
2473            xs.len(),
2474            ys.len(),
2475            "size mismatch xs {} != ys {}",
2476            xs.len(),
2477            ys.len()
2478        );
2479        xs.convert_to_f32_slice(ys);
2480    }
2481
2482    fn direct_copy(xs: &[f32], ys: &mut [Self]) {
2483        Self::from_float(xs, ys)
2484    }
2485}
2486
2487macro_rules! verify_block_size {
2488    ( $block_type:ident ) => {
2489        const _: () =
2490            assert!($block_type::BLCK_SIZE == <$block_type as GgmlType>::VecDotType::BLCK_SIZE);
2491    };
2492}
2493
2494macro_rules! verify_block_sizes {
2495    ( $( $block_type:ident ),* ) => {
2496        $(
2497            verify_block_size!($block_type);
2498        )*
2499    };
2500}
2501
2502verify_block_sizes!(
2503    BlockQ4_0, BlockQ4_1, BlockQ5_0, BlockQ5_1, BlockQ8_0, BlockQ8_1, BlockQ2K, BlockQ3K, BlockQ4K,
2504    BlockQ5K, BlockQ6K, BlockQ8K, f32, f16, bf16
2505);