1use super::utils::{
2 get_scale_min_k4, group_for_dequantization, group_for_quantization, make_q3_quants,
3 make_qkx1_quants, make_qx_quants, nearest_int,
4};
5use super::GgmlDType;
6use crate::quantized::utils::{make_qkx3_quants, make_qp_quants};
7use crate::Result;
8use byteorder::{ByteOrder, LittleEndian};
9use half::{bf16, f16, slice::HalfFloatSliceExt};
10use rayon::prelude::*;
11
/// Number of f32 elements represented by one "K-quant" super-block (Q2K..Q8K).
pub const QK_K: usize = 256;
/// Byte size of the packed scale/min table used by the Q4K/Q5K blocks.
pub const K_SCALE_SIZE: usize = 12;

// Elements per block for the legacy (non-K) quantization formats; all are 32.
pub const QK4_0: usize = 32;
pub const QK4_1: usize = 32;
pub const QK5_0: usize = 32;
pub const QK5_1: usize = 32;
pub const QK8_0: usize = 32;
pub const QK8_1: usize = 32;
22
/// Common interface implemented by every GGML element type in this module
/// (quantized blocks as well as plain floats).
pub trait GgmlType: Sized + Clone + Send + Sync {
    const DTYPE: GgmlDType;
    /// Number of f32 elements represented by one `Self` value.
    const BLCK_SIZE: usize;
    /// When true, `direct_copy` may be used instead of quantize/dequantize.
    const DIRECT_COPY: bool = false;
    /// Block type expected on the right-hand side of `vec_dot`.
    type VecDotType: GgmlType;

    fn zeros() -> Self {
        // SAFETY: the implementors in this file are #[repr(C)] structs of
        // integers and f16/f32 fields, for which the all-zero bit pattern is a
        // valid value. NOTE(review): any future implementor must uphold the
        // same property (no references, no niches) — confirm before adding one.
        unsafe { std::mem::MaybeUninit::zeroed().assume_init() }
    }
    /// Dequantize `xs` into `ys`; `ys.len()` must be `xs.len() * BLCK_SIZE`.
    fn to_float(xs: &[Self], ys: &mut [f32]);
    /// Quantize `xs` into `ys`; `xs.len()` must be `ys.len() * BLCK_SIZE`.
    fn from_float(xs: &[f32], ys: &mut [Self]);
    /// Importance-matrix weighted quantization. Panics for types that do not
    /// override it.
    fn from_float_imatrix(
        _xs: &[f32],
        _ys: &mut [Self],
        _imatrix_weights: &[f32],
        _n_per_row: usize,
    ) {
        panic!(
            "`from_float_imatrix` is unimplemented for {:?}",
            Self::DTYPE
        );
    }

    /// No-op by default; only meaningful when `DIRECT_COPY` is true.
    fn direct_copy(_xs: &[f32], _ys: &mut [Self]) {}

    /// Dot product over `n` elements; implementations may dispatch to SIMD
    /// kernels at compile time.
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32;

    /// Scalar reference implementation of `vec_dot`.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32;
}
56
/// Q4_0: 32 elements as 4-bit quants with a single f16 scale.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_0 {
    /// Per-block scale.
    pub(crate) d: f16,
    /// Two 4-bit quants per byte (low nibble = element j, high = j + 16).
    pub(crate) qs: [u8; QK4_0 / 2],
}
// Compile-time layout check: the struct must stay byte-compatible with the
// GGML on-disk block format.
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
64
/// Q4_1: 32 elements as 4-bit quants with an f16 scale and f16 minimum.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_1 {
    /// Per-block scale.
    pub(crate) d: f16,
    /// Per-block minimum (affine offset).
    pub(crate) m: f16,
    /// Two 4-bit quants per byte.
    pub(crate) qs: [u8; QK4_1 / 2],
}
// Compile-time layout check against the GGML block format.
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
73
/// Q5_0: 32 elements as 5-bit quants (4 low bits in `qs`, 5th bits in `qh`).
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_0 {
    /// Per-block scale.
    pub(crate) d: f16,
    /// 32 high bits, one per element, little-endian packed.
    pub(crate) qh: [u8; 4],
    /// Low 4 bits of each quant, two per byte.
    pub(crate) qs: [u8; QK5_0 / 2],
}
// Compile-time layout check against the GGML block format.
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
82
/// Q5_1: 32 elements as 5-bit quants with an f16 scale and f16 minimum.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_1 {
    /// Per-block scale.
    pub(crate) d: f16,
    /// Per-block minimum (affine offset).
    pub(crate) m: f16,
    /// 32 high bits, one per element, little-endian packed.
    pub(crate) qh: [u8; 4],
    /// Low 4 bits of each quant, two per byte.
    pub(crate) qs: [u8; QK5_1 / 2],
}
// Compile-time layout check against the GGML block format.
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
92
/// Q8_0: 32 elements as signed 8-bit quants with a single f16 scale.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_0 {
    /// Per-block scale.
    pub(crate) d: f16,
    /// One signed 8-bit quant per element.
    pub(crate) qs: [i8; QK8_0],
}
// Compile-time layout check against the GGML block format.
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
100
/// Q8_1: 32 elements as signed 8-bit quants plus a cached scaled sum `s`,
/// used as the right-hand side of the affine (Q4_1/Q5_1) dot products.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_1 {
    /// Per-block scale.
    pub(crate) d: f16,
    /// Cached `sum(qs) * d`, consumed by `vec_dot` of the _1 formats.
    pub(crate) s: f16,
    /// One signed 8-bit quant per element.
    pub(crate) qs: [i8; QK8_1],
}
// Compile-time layout check against the GGML block format.
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
109
/// Q2K: 256-element super-block of 2-bit quants; 16 sub-blocks of 16 elements,
/// each with a 4-bit scale and 4-bit min packed into one byte of `scales`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ2K {
    /// Per-sub-block scale (low nibble) and min (high nibble).
    pub(crate) scales: [u8; QK_K / 16],
    /// 2-bit quants, four per byte.
    pub(crate) qs: [u8; QK_K / 4],
    /// Super-block scale applied to the 4-bit scales.
    pub(crate) d: f16,
    /// Super-block scale applied to the 4-bit mins.
    pub(crate) dmin: f16,
}
// Compile-time layout check against the GGML block format.
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
119
/// Q3K: 256-element super-block of 3-bit quants; low 2 bits in `qs`, the
/// third bit in `hmask`, with 6-bit sub-block scales packed into 12 bytes.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ3K {
    /// High (3rd) bit of each quant, one bit per element.
    pub(crate) hmask: [u8; QK_K / 8],
    /// Low 2 bits of each quant, four per byte.
    pub(crate) qs: [u8; QK_K / 4],
    /// 16 six-bit sub-block scales, bit-packed.
    pub(crate) scales: [u8; 12],
    /// Super-block scale.
    pub(crate) d: f16,
}
// Compile-time layout check against the GGML block format.
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
129
/// Q4K: 256-element super-block of 4-bit quants with packed 6-bit sub-block
/// scales/mins.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4K {
    /// Super-block scale applied to the sub-block scales.
    pub(crate) d: f16,
    /// Super-block scale applied to the sub-block mins.
    pub(crate) dmin: f16,
    /// Packed 6-bit sub-block scales and mins.
    pub(crate) scales: [u8; K_SCALE_SIZE],
    /// 4-bit quants, two per byte.
    pub(crate) qs: [u8; QK_K / 2],
}
// Compile-time layout check against the GGML block format.
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
140
/// Q5K: 256-element super-block of 5-bit quants (low 4 bits in `qs`, 5th bit
/// in `qh`) with packed 6-bit sub-block scales/mins.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5K {
    /// Super-block scale applied to the sub-block scales.
    pub(crate) d: f16,
    /// Super-block scale applied to the sub-block mins.
    pub(crate) dmin: f16,
    /// Packed 6-bit sub-block scales and mins.
    pub(crate) scales: [u8; K_SCALE_SIZE],
    /// High (5th) bit of each quant, one bit per element.
    pub(crate) qh: [u8; QK_K / 8],
    /// Low 4 bits of each quant, two per byte.
    pub(crate) qs: [u8; QK_K / 2],
}
// Compile-time layout check against the GGML block format.
const _: () =
    assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
152
/// Q6K: 256-element super-block of 6-bit quants (low 4 bits in `ql`, high 2
/// bits in `qh`) with signed 8-bit sub-block scales.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ6K {
    /// Low 4 bits of each quant, two per byte.
    pub(crate) ql: [u8; QK_K / 2],
    /// High 2 bits of each quant, four per byte.
    pub(crate) qh: [u8; QK_K / 4],
    /// Signed scale per 16-element sub-block.
    pub(crate) scales: [i8; QK_K / 16],
    /// Super-block scale.
    pub(crate) d: f16,
}
// Compile-time layout check against the GGML block format.
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
162
/// Q8K: 256-element super-block of signed 8-bit quants with an f32 scale and
/// cached per-16 sums; used as the right-hand side of the K-quant dot kernels.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8K {
    /// Per-block scale (full f32 precision here, unlike the legacy formats).
    pub(crate) d: f32,
    /// One signed 8-bit quant per element.
    pub(crate) qs: [i8; QK_K],
    /// Sum of each 16-element group of `qs`; lets lhs mins be folded in cheaply.
    pub(crate) bsums: [i16; QK_K / 16],
}
// Compile-time layout check against the GGML block format.
const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::<BlockQ8K>());
171
impl GgmlType for BlockQ4_0 {
    const DTYPE: GgmlDType = GgmlDType::Q4_0;
    const BLCK_SIZE: usize = QK4_0;
    type VecDotType = BlockQ8_0;

    /// Dequantize: each nibble is re-centered to [-8, 7] and scaled by `d`.
    /// Byte `j` holds element `j` (low nibble) and element `j + qk/2` (high).
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        let k = ys.len();
        let qk = Self::BLCK_SIZE;
        debug_assert!(
            k.is_multiple_of(qk),
            "dequantize_row_q4_0: {k} is not divisible by {qk}"
        );

        let nb = k / qk;
        for i in 0..nb {
            let d = xs[i].d.to_f32();

            for j in 0..(qk / 2) {
                let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8;
                let x1 = (xs[i].qs[j] >> 4) as i16 - 8;

                ys[i * qk + j] = (x0 as f32) * d;
                ys[i * qk + j + qk / 2] = (x1 as f32) * d;
            }
        }
    }

    /// Quantize 32 floats per block to 4 bits with one signed f16 scale.
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        let qk = Self::BLCK_SIZE;
        let k = xs.len();
        debug_assert!(k.is_multiple_of(qk), "{k} is not divisible by {qk}");
        debug_assert_eq!(
            ys.len(),
            k / qk,
            "size mismatch {} {} {}",
            xs.len(),
            ys.len(),
            qk,
        );
        for (i, ys) in ys.iter_mut().enumerate() {
            // Find the value with the largest magnitude, keeping its sign so
            // the extreme maps exactly to quant level -8.
            let mut amax = 0f32;
            let mut max = 0f32;

            let xs = &xs[i * qk..(i + 1) * qk];
            for &x in xs.iter() {
                if amax < x.abs() {
                    amax = x.abs();
                    max = x;
                }
            }
            let d = max / -8.0;
            let id = if d != 0f32 { 1. / d } else { 0. };
            ys.d = f16::from_f32(d);

            for (j, q) in ys.qs.iter_mut().enumerate() {
                let x0 = xs[j] * id;
                let x1 = xs[qk / 2 + j] * id;
                // +8.5 re-centers to [0, 16) with round-to-nearest; the `as u8`
                // cast saturates negatives to 0 and `min` clamps the top end.
                let xi0 = u8::min(15, (x0 + 8.5) as u8);
                let xi1 = u8::min(15, (x1 + 8.5) as u8);
                *q = xi0 | (xi1 << 4)
            }
        }
    }

    // A SIMD kernel is selected at compile time when the matching target
    // feature is enabled; the scalar fallback then becomes dead code, which
    // the `allow` silences.
    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        #[cfg(target_feature = "avx2")]
        return super::avx::vec_dot_q4_0_q8_0(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Scalar reference kernel: per-block integer dot product, scaled by the
    /// product of the two f16 block scales.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        debug_assert!(
            n.is_multiple_of(QK8_0),
            "vec_dot_q4_0_q8_0: {n} is not divisible by {QK8_0}"
        );
        let mut sumf = 0f32;
        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let mut sum_i = 0;
            for j in 0..QK8_0 / 2 {
                let v0 = (xs.qs[j] & 0x0F) as i32 - 8;
                let v1 = (xs.qs[j] >> 4) as i32 - 8;
                sum_i += v0 * ys.qs[j] as i32 + v1 * ys.qs[j + QK8_0 / 2] as i32
            }
            sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
        }
        sumf
    }
}
272
impl GgmlType for BlockQ4_1 {
    const DTYPE: GgmlDType = GgmlDType::Q4_1;
    const BLCK_SIZE: usize = QK4_1;
    type VecDotType = BlockQ8_1;

    // No SIMD specialization for Q4_1; always use the scalar kernel.
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Scalar kernel for the affine format: the unsigned nibble dot is scaled
    /// by `d_x * d_y`, and the minimum `m` is folded in via the rhs cached sum
    /// `ys.s` (`= sum(q_y) * d_y`), avoiding a re-summation of `q_y`.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        let qk = QK8_1;
        debug_assert!(
            n.is_multiple_of(qk),
            "vec_dot_q4_1_q8_1: {n} is not divisible by {qk}"
        );
        debug_assert!(
            (n / qk).is_multiple_of(2),
            "vec_dot_q4_1_q8_1: {n}, nb is not divisible by 2"
        );

        let mut sumf = 0f32;

        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let mut sumi = 0i32;

            for j in 0..qk / 2 {
                // Nibbles stay unsigned here; the offset is carried by `m`.
                let v0 = xs.qs[j] as i32 & 0x0F;
                let v1 = xs.qs[j] as i32 >> 4;
                sumi += (v0 * ys.qs[j] as i32) + (v1 * ys.qs[j + qk / 2] as i32);
            }

            sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
                + f16::to_f32(xs.m) * f16::to_f32(ys.s)
        }
        sumf
    }

    /// Quantize with an affine map: q = round((x - min) / d), 15 levels.
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        let qk = Self::BLCK_SIZE;

        debug_assert_eq!(
            ys.len() * qk,
            xs.len(),
            "size mismatch {} {} {}",
            xs.len(),
            ys.len(),
            qk,
        );
        for (i, ys) in ys.iter_mut().enumerate() {
            let xs = &xs[i * qk..(i + 1) * qk];

            let mut min = f32::INFINITY;
            let mut max = f32::NEG_INFINITY;
            for &x in xs.iter() {
                min = f32::min(x, min);
                max = f32::max(x, max);
            }
            // 15 = 2^4 - 1 quant steps span [min, max].
            let d = (max - min) / ((1 << 4) - 1) as f32;
            let id = if d != 0f32 { 1. / d } else { 0. };
            ys.d = f16::from_f32(d);
            ys.m = f16::from_f32(min);

            for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() {
                let x0 = (xs[j] - min) * id;
                let x1 = (xs[qk / 2 + j] - min) * id;

                // +0.5 rounds to nearest; `min` clamps the top level.
                let xi0 = u8::min(15, (x0 + 0.5) as u8);
                let xi1 = u8::min(15, (x1 + 0.5) as u8);

                *q = xi0 | (xi1 << 4);
            }
        }
    }

    /// Dequantize: y = q * d + m for both nibbles of each byte.
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        let k = ys.len();
        debug_assert!(
            k.is_multiple_of(QK4_1),
            "dequantize_row_q4_1: {k} is not divisible by {QK4_1}"
        );

        let nb = k / QK4_1;
        for i in 0..nb {
            let d = xs[i].d.to_f32();
            let m = xs[i].m.to_f32();

            for j in 0..(QK4_1 / 2) {
                let x0 = xs[i].qs[j] & 0x0F;
                let x1 = xs[i].qs[j] >> 4;

                ys[i * QK4_1 + j] = (x0 as f32) * d + m;
                ys[i * QK4_1 + j + QK4_1 / 2] = (x1 as f32) * d + m;
            }
        }
    }
}
373
impl GgmlType for BlockQ5_0 {
    const DTYPE: GgmlDType = GgmlDType::Q5_0;
    const BLCK_SIZE: usize = QK5_0;
    type VecDotType = BlockQ8_0;

    // No SIMD specialization; validates sizes then calls the scalar kernel.
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        let qk = Self::BLCK_SIZE;

        debug_assert!(
            n.is_multiple_of(qk),
            "vec_dot_q5_0_q8_0: {n} is not divisible by {qk}"
        );
        debug_assert!(
            (n / qk).is_multiple_of(2),
            "vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2"
        );
        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Scalar kernel: reconstructs 5-bit quants (4 bits from `qs`, 5th bit
    /// from `qh`), re-centers to [-16, 15] and accumulates an integer dot.
    fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        let mut sumf = 0f32;

        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let qh = LittleEndian::read_u32(&xs.qh);
            let mut sumi = 0i32;

            for j in 0..Self::BLCK_SIZE / 2 {
                // Bit j of qh is the 5th bit of element j; bit j+16 belongs to
                // element j + 16 (`>> (j + 12)` parks it at bit position 4).
                let xh_0 = (((qh & (1u32 << j)) >> j) << 4) as u8;
                let xh_1 = ((qh & (1u32 << (j + 16))) >> (j + 12)) as u8;

                let x0 = ((xs.qs[j] & 0x0F) as i32 | xh_0 as i32) - 16;
                let x1 = ((xs.qs[j] >> 4) as i32 | xh_1 as i32) - 16;

                sumi += (x0 * ys.qs[j] as i32) + (x1 * ys.qs[j + Self::BLCK_SIZE / 2] as i32);
            }

            sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
        }
        sumf
    }

    /// Quantize to 5 bits with one signed scale; the extreme maps to -16.
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        debug_assert_eq!(
            ys.len() * Self::BLCK_SIZE,
            xs.len(),
            "size mismatch {} {} {}",
            xs.len(),
            ys.len(),
            Self::BLCK_SIZE,
        );
        for (i, ys) in ys.iter_mut().enumerate() {
            let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];

            // Signed max-magnitude search, as in Q4_0.
            let mut amax = 0f32;
            let mut max = 0f32;
            for &x in xs.iter() {
                if amax < x.abs() {
                    amax = x.abs();
                    max = x;
                }
            }
            let d = max / -16.;
            let id = if d != 0f32 { 1. / d } else { 0. };
            ys.d = f16::from_f32(d);
            let mut qh = 0u32;
            for j in 0..Self::BLCK_SIZE / 2 {
                let x0 = xs[j] * id;
                let x1 = xs[j + Self::BLCK_SIZE / 2] * id;
                // +16.5 re-centers to [0, 32) with rounding; clamp to 31.
                let xi0 = ((x0 + 16.5) as i8).min(31) as u8;
                let xi1 = ((x1 + 16.5) as i8).min(31) as u8;
                // Low 4 bits go into qs; the 5th bit is collected in qh.
                ys.qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
                qh |= ((xi0 as u32 & 0x10) >> 4) << j;
                qh |= ((xi1 as u32 & 0x10) >> 4) << (j + Self::BLCK_SIZE / 2);
            }
            LittleEndian::write_u32(&mut ys.qh, qh)
        }
    }

    /// Dequantize: merge the 5th bit back in, re-center to [-16, 15], scale.
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        let k = ys.len();
        debug_assert!(
            k.is_multiple_of(QK5_0),
            "dequantize_row_q5_0: {k} is not divisible by {QK5_0}"
        );
        let nb = k / QK5_0;
        for i in 0..nb {
            let d = xs[i].d.to_f32();
            let qh: u32 = LittleEndian::read_u32(&xs[i].qh);

            for j in 0..(QK5_0 / 2) {
                let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
                let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;

                let x0 = ((xs[i].qs[j] & 0x0F) | xh_0) as i32 - 16;
                let x1 = ((xs[i].qs[j] >> 4) | xh_1) as i32 - 16;

                ys[i * QK5_0 + j] = (x0 as f32) * d;
                ys[i * QK5_0 + j + QK5_0 / 2] = (x1 as f32) * d;
            }
        }
    }
}
479
impl GgmlType for BlockQ5_1 {
    const DTYPE: GgmlDType = GgmlDType::Q5_1;
    const BLCK_SIZE: usize = QK5_1;
    type VecDotType = BlockQ8_1;

    // No SIMD specialization for Q5_1; always use the scalar kernel.
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Scalar kernel for the affine 5-bit format: unsigned 5-bit quants
    /// (4 bits from `qs` plus the `qh` bit), the minimum `m` folded in via the
    /// rhs cached sum `ys.s`.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        let qk = Self::BLCK_SIZE;
        debug_assert!(
            n.is_multiple_of(qk),
            "vec_dot_q5_1_q8_1: {n} is not divisible by {qk}"
        );
        debug_assert!(
            (n / qk).is_multiple_of(2),
            "vec_dot_q5_1_q8_1: {n}, nb is not divisible by 2"
        );

        let mut sumf = 0f32;

        for (xs, ys) in xs.iter().zip(ys.iter()) {
            let qh = LittleEndian::read_u32(&xs.qh);
            let mut sumi = 0i32;

            for j in 0..Self::BLCK_SIZE / 2 {
                // Bit j of qh is the 5th bit of element j; bit j+16 belongs to
                // element j + 16 (shift by j+12 parks it at bit position 4).
                let xh_0 = ((qh >> j) << 4) & 0x10;
                let xh_1 = (qh >> (j + 12)) & 0x10;

                let x0 = (xs.qs[j] as i32 & 0xF) | xh_0 as i32;
                let x1 = (xs.qs[j] as i32 >> 4) | xh_1 as i32;

                sumi += (x0 * ys.qs[j] as i32) + (x1 * ys.qs[j + Self::BLCK_SIZE / 2] as i32);
            }

            sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
                + f16::to_f32(xs.m) * f16::to_f32(ys.s)
        }
        sumf
    }

    /// Quantize with an affine map: q = round((x - min) / d), 31 levels.
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        let qk = Self::BLCK_SIZE;
        debug_assert_eq!(
            ys.len() * qk,
            xs.len(),
            "size mismatch {} {} {}",
            xs.len(),
            ys.len(),
            qk,
        );
        for (i, ys) in ys.iter_mut().enumerate() {
            let xs = &xs[i * qk..(i + 1) * qk];

            let mut min = f32::INFINITY;
            let mut max = f32::NEG_INFINITY;
            for &x in xs.iter() {
                min = f32::min(x, min);
                max = f32::max(x, max);
            }
            // 31 = 2^5 - 1 quant steps span [min, max].
            let d = (max - min) / ((1 << 5) - 1) as f32;
            let id = if d != 0f32 { 1. / d } else { 0. };
            ys.d = f16::from_f32(d);
            ys.m = f16::from_f32(min);

            let mut qh = 0u32;
            for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() {
                let x0 = (xs[j] - min) * id;
                let x1 = (xs[qk / 2 + j] - min) * id;

                // +0.5 rounds to nearest; x is already in [0, 31] by scaling.
                let xi0 = (x0 + 0.5) as u8;
                let xi1 = (x1 + 0.5) as u8;

                // Low 4 bits go into qs; the 5th bit is collected in qh.
                *q = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
                qh |= ((xi0 as u32 & 0x10) >> 4) << j;
                qh |= ((xi1 as u32 & 0x10) >> 4) << (j + qk / 2);
            }
            LittleEndian::write_u32(&mut ys.qh, qh);
        }
    }

    /// Dequantize: y = q * d + m with the 5th bit merged back from `qh`.
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        let k = ys.len();
        debug_assert!(
            k.is_multiple_of(QK5_1),
            "dequantize_row_q5_1: {k} is not divisible by {QK5_1}"
        );

        let nb = k / QK5_1;
        for i in 0..nb {
            let d = xs[i].d.to_f32();
            let m = xs[i].m.to_f32();
            let qh: u32 = LittleEndian::read_u32(&xs[i].qh);

            for j in 0..(QK5_1 / 2) {
                let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
                let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;

                let x0 = (xs[i].qs[j] & 0x0F) | xh_0;
                let x1 = (xs[i].qs[j] >> 4) | xh_1;

                ys[i * QK5_1 + j] = (x0 as f32) * d + m;
                ys[i * QK5_1 + j + QK5_1 / 2] = (x1 as f32) * d + m;
            }
        }
    }
}
592
593impl GgmlType for BlockQ8_0 {
594 const DTYPE: GgmlDType = GgmlDType::Q8_0;
595 const BLCK_SIZE: usize = QK8_0;
596 type VecDotType = BlockQ8_0;
597
598 fn to_float(xs: &[Self], ys: &mut [f32]) {
600 let k = ys.len();
601 debug_assert!(
602 k.is_multiple_of(QK8_0),
603 "dequantize_row_q8_0: {k} is not divisible by {QK8_0}"
604 );
605
606 let nb = k / QK8_0;
607
608 for i in 0..nb {
609 let d = xs[i].d.to_f32();
610
611 for j in 0..QK8_0 {
612 ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
613 }
614 }
615 }
616
617 fn from_float(xs: &[f32], ys: &mut [Self]) {
618 let k = xs.len();
620 debug_assert!(
621 k.is_multiple_of(Self::BLCK_SIZE),
622 "{k} is not divisible by {}",
623 Self::BLCK_SIZE
624 );
625 debug_assert_eq!(
626 ys.len(),
627 k / Self::BLCK_SIZE,
628 "size mismatch {} {} {}",
629 xs.len(),
630 ys.len(),
631 Self::BLCK_SIZE
632 );
633 for (i, ys) in ys.iter_mut().enumerate() {
634 let mut amax = 0f32;
635 let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
636 for &x in xs.iter() {
637 amax = amax.max(x.abs())
638 }
639 let d = amax / ((1 << 7) - 1) as f32;
640 let id = if d != 0f32 { 1. / d } else { 0. };
641 ys.d = f16::from_f32(d);
642 for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
643 *y = f32::round(x * id) as i8
644 }
645 }
646 }
647
648 #[allow(unreachable_code)]
649 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
650 #[cfg(target_feature = "avx2")]
651 return super::avx::vec_dot_q8_0_q8_0(n, xs, ys);
652
653 #[cfg(target_feature = "neon")]
654 return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);
655
656 #[cfg(target_feature = "simd128")]
657 return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);
658
659 Self::vec_dot_unopt(n, xs, ys)
660 }
661
662 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
663 debug_assert!(
664 n.is_multiple_of(QK8_0),
665 "vec_dot_q8_0_q8_0: {n} is not divisible by {QK8_0}"
666 );
667
668 let mut sumf = 0f32;
670 for (xs, ys) in xs.iter().zip(ys.iter()) {
671 let sum_i = xs
672 .qs
673 .iter()
674 .zip(ys.qs.iter())
675 .map(|(&x, &y)| x as i32 * y as i32)
676 .sum::<i32>();
677 sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
678 }
679 sumf
680 }
681}
682
683impl GgmlType for BlockQ8_1 {
684 const DTYPE: GgmlDType = GgmlDType::Q8_1;
685 const BLCK_SIZE: usize = QK8_1;
686 type VecDotType = BlockQ8_1;
687
688 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
689 Self::vec_dot_unopt(n, xs, ys)
690 }
691
692 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
693 debug_assert!(
694 n.is_multiple_of(QK8_1),
695 "vec_dot_q8_1_q8_1: {n} is not divisible by {QK8_1}"
696 );
697
698 let mut sumf = 0f32;
700 for (xs, ys) in xs.iter().zip(ys.iter()) {
701 let sum_i = xs
702 .qs
703 .iter()
704 .zip(ys.qs.iter())
705 .map(|(&x, &y)| x as i32 * y as i32)
706 .sum::<i32>();
707 sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
708 }
709 sumf
710 }
711
712 fn from_float(xs: &[f32], ys: &mut [Self]) {
713 debug_assert_eq!(
715 ys.len() * Self::BLCK_SIZE,
716 xs.len(),
717 "size mismatch {} {} {}",
718 xs.len(),
719 ys.len(),
720 Self::BLCK_SIZE
721 );
722 for (i, ys) in ys.iter_mut().enumerate() {
723 let mut amax = 0f32;
724 let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
725 for &x in xs.iter() {
726 amax = amax.max(x.abs())
727 }
728 let d = amax / ((1 << 7) - 1) as f32;
729 let id = if d != 0f32 { 1. / d } else { 0. };
730 ys.d = f16::from_f32(d);
731 let mut sum = 0i32;
732 for j in 0..Self::BLCK_SIZE / 2 {
733 let v0 = xs[j] * id;
734 let v1 = xs[j + Self::BLCK_SIZE / 2] * id;
735 ys.qs[j] = f32::round(v0) as i8;
736 ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8;
737 sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32;
738 }
739 ys.s = f16::from_f32(sum as f32) * ys.d;
740 }
741 }
742
743 fn to_float(_xs: &[Self], _ys: &mut [f32]) {
744 unimplemented!("no support for vec-dot on Q8_1")
745 }
746}
747
impl GgmlType for BlockQ2K {
    const DTYPE: GgmlDType = GgmlDType::Q2K;
    const BLCK_SIZE: usize = QK_K;
    type VecDotType = BlockQ8K;

    // A SIMD kernel is selected at compile time when the matching target
    // feature is enabled; the scalar fallback then becomes dead code, which
    // the `allow` silences.
    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        #[cfg(target_feature = "avx2")]
        return super::avx::vec_dot_q2k_q8k(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q2k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q2k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Scalar kernel: result = dall * sum(scale_j * <q2_j, q8_j>)
    ///                        - dmin * sum(min_j * bsum_j).
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        debug_assert!(
            n.is_multiple_of(QK_K),
            "vec_dot_q2k_q8k: {n} is not divisible by {QK_K}"
        );

        let mut sumf = 0.0;
        for (x, y) in xs.iter().zip(ys.iter()) {
            let mut q2: &[_] = &x.qs;
            let mut q8: &[_] = &y.qs;
            let sc = &x.scales;

            // Each scale byte packs a 4-bit scale (low nibble) and a 4-bit
            // min (high nibble); the mins are folded in via the cached per-16
            // q8 sums instead of re-summing the quants.
            let mut summs = 0;
            for (bsum, scale) in y.bsums.iter().zip(sc) {
                summs += *bsum as i32 * ((scale >> 4) as i32);
            }

            let dall = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let mut isum = 0;
            let mut is = 0;
            for _ in 0..(QK_K / 128) {
                // 32 q2 bytes carry four 2-bit planes; `shift` walks the
                // planes, each covering two 16-element sub-blocks.
                let mut shift = 0;
                for _ in 0..4 {
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    let mut isuml = 0;
                    for l in 0..16 {
                        isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
                    }
                    isum += d * isuml;
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    isuml = 0;
                    for l in 16..32 {
                        isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
                    }
                    isum += d * isuml;
                    shift += 2;
                    q8 = &q8[32..];
                }
                q2 = &q2[32..];
            }
            sumf += dall * isum as f32 - dmin * summs as f32;
        }

        sumf
    }

    /// Quantize a super-block: per-16 (scale, min) pairs are themselves
    /// quantized to 4 bits against `d`/`dmin`, then per-element 2-bit levels
    /// are packed four-per-byte in 128-element strides.
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        // 4-bit range for the sub-block scales and mins.
        const Q4SCALE: f32 = 15.0;

        for (block, x) in group_for_quantization(xs, ys) {
            let mut mins: [f32; QK_K / 16] = [0.0; QK_K / 16];
            let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];

            // Per 16-element sub-block: fit a (scale, min) pair over 4 levels.
            for (j, x_scale_slice) in x.chunks(16).enumerate() {
                (scales[j], mins[j]) = make_qkx1_quants(3, 5, x_scale_slice);
            }
            let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
            let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));

            // Quantize sub-block scales to 4 bits (low nibble of `scales[j]`).
            if max_scale > 0.0 {
                let iscale = Q4SCALE / max_scale;
                for (j, scale) in scales.iter().enumerate().take(QK_K / 16) {
                    block.scales[j] = nearest_int(iscale * scale) as u8;
                }
                block.d = f16::from_f32(max_scale / Q4SCALE);
            } else {
                for j in 0..QK_K / 16 {
                    block.scales[j] = 0;
                }
                block.d = f16::from_f32(0.0);
            }

            // Quantize sub-block mins to 4 bits (high nibble of `scales[j]`).
            if max_min > 0.0 {
                let iscale = Q4SCALE / max_min;
                for (j, scale) in block.scales.iter_mut().enumerate() {
                    let l = nearest_int(iscale * mins[j]) as u8;
                    *scale |= l << 4;
                }
                block.dmin = f16::from_f32(max_min / Q4SCALE);
            } else {
                block.dmin = f16::from_f32(0.0);
            }

            // Per-element 2-bit levels, computed with the reconstructed
            // (dequantized) scale/min so round-trip error is minimized.
            let mut big_l: [u8; QK_K] = [0; QK_K];

            for j in 0..QK_K / 16 {
                let d = block.d.to_f32() * (block.scales[j] & 0xF) as f32;
                if d == 0.0 {
                    continue;
                }
                let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32;
                for ii in 0..16 {
                    let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3);
                    big_l[16 * j + ii] = ll as u8;
                }
            }

            // Pack four 2-bit planes per byte across each 128-element stride.
            for j in (0..QK_K).step_by(128) {
                for ll in 0..32 {
                    block.qs[j / 4 + ll] = big_l[j + ll]
                        | (big_l[j + ll + 32] << 2)
                        | (big_l[j + ll + 64] << 4)
                        | (big_l[j + ll + 96] << 6);
                }
            }
        }
    }

    /// Importance-matrix weighted variant of `from_float`: element weights
    /// derived from the imatrix bias the sub-block scale/min search.
    fn from_float_imatrix(xs: &[f32], ys: &mut [Self], imatrix_weights: &[f32], n_per_row: usize) {
        for (sblk_idx, (block, x)) in group_for_quantization(xs, ys).into_iter().enumerate() {
            let mut mins: [f32; QK_K / 16] = [0.0; QK_K / 16];
            let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];
            let mut weights: [f32; 16] = [0.0; 16];
            let mut sw: [f32; QK_K / 16] = [0.0; QK_K / 16];
            let mut ls: [u8; QK_K / 16] = [0; QK_K / 16];
            let mut lm: [u8; QK_K / 16] = [0; QK_K / 16];

            let sum_x2 = x.iter().map(|x| x * x).sum::<f32>();
            let sigma2 = sum_x2 / QK_K as f32;
            for (j, x_scale_slice) in x.chunks_exact(16).enumerate() {
                // weight = imatrix importance * sqrt(sigma2 + x^2).
                // NOTE(review): assumes `imatrix_weights` holds `n_per_row`
                // entries laid out row-major — confirm against the caller.
                for (l, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() {
                    let imatrix_row = sblk_idx % (n_per_row / QK_K);
                    let imatrix_w = imatrix_weights[imatrix_row * QK_K + 16 * j + l];
                    *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt();
                }
                let sumw = weights.iter().sum::<f32>();
                sw[j] = sumw;
                (scales[j], mins[j]) =
                    make_qkx3_quants(3, x_scale_slice, Some(&weights), -0.9, 0.05, 36, false);
            }

            // Quantize the per-16 scales and mins to 4 bits, weighted by sw.
            let d_block = make_qp_quants(QK_K / 16, 15, &scales, &mut ls, &sw);
            let m_block = make_qp_quants(QK_K / 16, 15, &mins, &mut lm, &sw);

            block.d = f16::from_f32(d_block);
            block.dmin = f16::from_f32(m_block);

            for j in 0..QK_K / 16 {
                block.scales[j] = ls[j] | (lm[j] << 4);
            }

            // Per-element 2-bit levels from the reconstructed scale/min,
            // then the same 4-plane packing as `from_float`.
            let mut big_l: [u8; QK_K] = [0; QK_K];

            for j in 0..QK_K / 16 {
                let d = block.d.to_f32() * (block.scales[j] & 0xF) as f32;
                if d == 0.0 {
                    continue;
                }
                let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32;
                for ii in 0..16 {
                    let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3);
                    big_l[16 * j + ii] = ll as u8;
                }
            }

            for j in (0..QK_K).step_by(128) {
                for ll in 0..32 {
                    block.qs[j / 4 + ll] = big_l[j + ll]
                        | (big_l[j + ll + 32] << 2)
                        | (big_l[j + ll + 64] << 4)
                        | (big_l[j + ll + 96] << 6);
                }
            }
        }
    }

    /// Dequantize: y = d * scale_j * q - dmin * min_j, walking the four 2-bit
    /// planes of each 32-byte qs chunk (covering 128 output elements).
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        for (block, y) in group_for_dequantization(xs, ys) {
            let d = block.d.to_f32();
            let min = block.dmin.to_f32();

            let mut is = 0;

            for (y_block, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) {
                let mut shift = 0;
                let mut y_block_index = 0;
                for _j in 0..4 {
                    // First 16-element sub-block of this plane.
                    let sc = block.scales[is];
                    is += 1;
                    let dl = d * (sc & 0xF) as f32;
                    let ml = min * (sc >> 4) as f32;
                    for q in &qs[..16] {
                        let y = dl * ((q >> shift) & 3) as f32 - ml;
                        y_block[y_block_index] = y;
                        y_block_index += 1;
                    }

                    // Second 16-element sub-block of this plane.
                    let sc = block.scales[is];
                    is += 1;
                    let dl = d * (sc & 0xF) as f32;
                    let ml = min * (sc >> 4) as f32;
                    for q in &qs[16..] {
                        let y = dl * ((q >> shift) & 3) as f32 - ml;
                        y_block[y_block_index] = y;
                        y_block_index += 1;
                    }

                    shift += 2;
                }
            }
        }
    }
}
980
981impl GgmlType for BlockQ3K {
982 const DTYPE: GgmlDType = GgmlDType::Q3K;
983 const BLCK_SIZE: usize = QK_K;
984 type VecDotType = BlockQ8K;
985
    // A SIMD kernel is selected at compile time when the matching target
    // feature is enabled; the scalar fallback then becomes dead code, which
    // the `allow` silences. (No simd128 kernel exists for Q3K.)
    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        #[cfg(target_feature = "avx2")]
        return super::avx::vec_dot_q3k_q8k(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q3k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }
996
    /// Scalar kernel: expand the 3-bit quants into `aux8` (2 bits from `qs`,
    /// minus 4 unless the matching `hmask` bit supplies the third bit), unpack
    /// the 16 six-bit scales from the 12 packed bytes, then accumulate
    /// scale-weighted 8-lane partial dots.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        debug_assert!(
            n.is_multiple_of(QK_K),
            "vec_dot_q3k_q8k: {n} is not divisible by {QK_K}"
        );

        // Bit masks used to unpack the 6-bit scales: low 4 bits live in the
        // first 8 bytes, the top 2 bits in the last 4 bytes of `x.scales`.
        const KMASK1: u32 = 0x03030303;
        const KMASK2: u32 = 0x0f0f0f0f;

        let mut aux8: [i8; QK_K] = [0; QK_K];
        let mut aux16: [i16; 8] = [0; 8];
        let mut sums: [f32; 8] = [0.0; 8];
        let mut aux32: [i32; 8] = [0; 8];

        let mut auxs: [u32; 4] = [0; 4];

        for (x, y) in xs.iter().zip(ys.iter()) {
            let mut q3: &[u8] = &x.qs;
            let hmask: &[u8] = &x.hmask;
            let mut q8: &[i8] = &y.qs;

            aux32.fill(0);
            let mut a = &mut aux8[..];

            // `m` selects which hmask bit plane belongs to the current group
            // of 32 elements; a cleared bit means the quant lacks its third
            // bit and must be shifted down by 4.
            let mut m = 1;
            for _ in 0..QK_K / 128 {
                a.iter_mut()
                    .take(32)
                    .zip(q3)
                    .for_each(|(a_val, q3_val)| *a_val = (q3_val & 3) as i8);
                a.iter_mut()
                    .take(32)
                    .zip(hmask)
                    .for_each(|(a_val, hmask_val)| {
                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
                    });
                a = &mut a[32..];
                m <<= 1;

                a.iter_mut()
                    .take(32)
                    .zip(q3)
                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 2) & 3) as i8);
                a.iter_mut()
                    .take(32)
                    .zip(hmask)
                    .for_each(|(a_val, hmask_val)| {
                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
                    });
                a = &mut a[32..];
                m <<= 1;

                a.iter_mut()
                    .take(32)
                    .zip(q3)
                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 4) & 3) as i8);
                a.iter_mut()
                    .take(32)
                    .zip(hmask)
                    .for_each(|(a_val, hmask_val)| {
                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
                    });
                a = &mut a[32..];
                m <<= 1;

                a.iter_mut()
                    .take(32)
                    .zip(q3)
                    .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 6) & 3) as i8);
                a.iter_mut()
                    .take(32)
                    .zip(hmask)
                    .for_each(|(a_val, hmask_val)| {
                        *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
                    });
                a = &mut a[32..];
                m <<= 1;
                q3 = &q3[32..];
            }

            a = &mut aux8[..];

            // Unpack the 12 scale bytes into 16 six-bit values (stored one
            // per byte of `auxs`): low nibbles come from auxs[0..2], the top
            // two bits from the third word.
            LittleEndian::read_u32_into(&x.scales, &mut auxs[0..3]);

            let tmp = auxs[2];
            auxs[2] = ((auxs[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4);
            auxs[3] = ((auxs[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4);
            auxs[0] = (auxs[0] & KMASK2) | (((tmp) & KMASK1) << 4);
            auxs[1] = (auxs[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4);

            for aux in auxs {
                for scale in aux.to_le_bytes() {
                    // Single-byte reinterpret to i8 (from_be_bytes on one byte
                    // is byte-order independent).
                    let scale = i8::from_be_bytes([scale]);
                    for l in 0..8 {
                        aux16[l] = q8[l] as i16 * a[l] as i16;
                    }
                    for l in 0..8 {
                        // Scales are stored with a +32 bias.
                        aux32[l] += (scale as i32 - 32) * aux16[l] as i32;
                    }
                    q8 = &q8[8..];
                    a = &mut a[8..];

                    for l in 0..8 {
                        aux16[l] = q8[l] as i16 * a[l] as i16;
                    }
                    for l in 0..8 {
                        aux32[l] += (scale as i32 - 32) * aux16[l] as i32;
                    }
                    q8 = &q8[8..];
                    a = &mut a[8..];
                }
            }
            let d = x.d.to_f32() * y.d;
            for l in 0..8 {
                sums[l] += d * aux32[l] as f32;
            }
        }

        sums.iter().sum()
    }
1118
    /// Quantize a super-block to 3 bits: per-16 scales are fit, quantized to
    /// 6 bits (+32 bias) and bit-packed into the 12 `scales` bytes; element
    /// levels in [-4, 3] are stored as 2 bits in `qs` plus a third bit in
    /// `hmask`.
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        for (block, x) in group_for_quantization(xs, ys) {
            let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];
            for (j, x_scale_slice) in x.chunks_exact(16).enumerate() {
                scales[j] = make_q3_quants(x_scale_slice, 4, true);
            }

            // Signed max-magnitude scale: the extreme maps to -32.
            let mut max_scale: f32 = 0.0;
            for &scale in scales.iter() {
                if scale.abs() > max_scale.abs() {
                    max_scale = scale;
                }
            }

            block.scales.fill(0);

            if max_scale != 0.0 {
                let iscale = -32.0 / max_scale;
                for (j, scale) in scales.iter().enumerate() {
                    let l_val = nearest_int(iscale * scale);
                    let l_val = l_val.clamp(-32, 31) + 32;
                    // Low 4 bits: nibbles of scales[0..8]; top 2 bits: packed
                    // four-per-byte into scales[8..12].
                    if j < 8 {
                        block.scales[j] = (l_val & 0xF) as u8;
                    } else {
                        block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8;
                    }
                    let l_val = l_val >> 4;
                    block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8;
                }
                block.d = f16::from_f32(1.0 / iscale);
            } else {
                block.d = f16::from_f32(0.0);
            }

            let mut l: [i8; QK_K] = [0; QK_K];

            // Re-read the quantized 6-bit scale (so element levels are chosen
            // against the values a dequantizer will actually see).
            for j in 0..QK_K / 16 {
                let sc = if j < 8 {
                    block.scales[j] & 0xF
                } else {
                    block.scales[j - 8] >> 4
                };
                let sc = (sc | (((block.scales[8 + j % 4] >> (2 * (j / 4))) & 3) << 4)) as i8 - 32;
                let d = block.d.to_f32() * sc as f32;
                if d != 0.0 {
                    for ii in 0..16 {
                        let l_val = nearest_int(x[16 * j + ii] / d);
                        // Levels stored biased by +4 into [0, 7].
                        l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8;
                    }
                }
            }

            // Levels > 3 need the third bit: record it in hmask and keep only
            // the low 2 bits in `l`.
            block.hmask.fill(0);
            let mut m = 0;
            let mut hm = 1;

            for ll in l.iter_mut() {
                if *ll > 3 {
                    block.hmask[m] |= hm;
                    *ll -= 4;
                }
                m += 1;
                if m == QK_K / 8 {
                    m = 0;
                    hm <<= 1;
                }
            }

            // Pack four 2-bit planes per byte across each 128-element stride.
            for j in (0..QK_K).step_by(128) {
                for l_val in 0..32 {
                    block.qs[j / 4 + l_val] = (l[j + l_val]
                        | (l[j + l_val + 32] << 2)
                        | (l[j + l_val + 64] << 4)
                        | (l[j + l_val + 96] << 6))
                        as u8;
                }
            }
        }
    }
1199
    /// Quantize `xs` into Q3_K blocks using importance-matrix weighted scale search.
    ///
    /// Same bit layout as the plain `from_float` path (16-element sub-blocks,
    /// 6-bit sub-block scales split into low-nibble + 2-bit-high parts, low 2
    /// quant bits in `qs`, high bit in `hmask`), but the per-element weights
    /// fed to `make_qx_quants` combine the importance-matrix entry with a
    /// `sqrt(sigma2 + x^2)` term instead of being uniform.
    ///
    /// `imatrix_weights` holds one weight per element of a row; `n_per_row`
    /// maps each super-block index onto its row slice of the matrix.
    fn from_float_imatrix(xs: &[f32], ys: &mut [Self], imatrix_weights: &[f32], n_per_row: usize) {
        for (sblk_idx, (block, x)) in group_for_quantization(xs, ys).into_iter().enumerate() {
            // Per 16-element sub-block: float scale, scratch weights, weight
            // sums, quantized scales; `l` holds the 3-bit quants for the
            // whole 256-element super-block.
            let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];
            let mut weights: [f32; 16] = [0.0; 16];
            let mut sw: [f32; QK_K / 16] = [0.0; QK_K / 16];
            let mut ls: [i8; QK_K / 16] = [0; QK_K / 16];
            let mut l: [i8; QK_K] = [0; QK_K];

            let sum_x2 = x.iter().map(|x| x * x).sum::<f32>();
            let sigma2 = 2. * sum_x2 / QK_K as f32;

            for (j, x_scale_slice) in x.chunks_exact(16).enumerate() {
                for (l_idx, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() {
                    // Wrap around when the imatrix covers fewer rows than the
                    // super-block index (one row per n_per_row elements).
                    let imatrix_row = sblk_idx % (n_per_row / QK_K);
                    let imatrix_w = imatrix_weights[imatrix_row * QK_K + 16 * j + l_idx];
                    *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt();
                }
                let sumw = weights.iter().sum::<f32>();
                sw[j] = sumw;
                // SAFETY: `x_scale_slice` and `weights` are 16 elements and
                // `l.as_mut_ptr().add(16 * j)` leaves 16 writable slots, which
                // matches the `n = 16` passed as the first argument.
                scales[j] = unsafe {
                    make_qx_quants(
                        16,
                        4,
                        x_scale_slice.as_ptr(),
                        l.as_mut_ptr().add(16 * j),
                        1,
                        weights.as_ptr(),
                    )
                };
            }

            block.scales.fill(0);
            // Second-level quantization: quantize the 16 sub-block scales
            // themselves to 6 bits, weighted by the sub-block weight sums.
            // SAFETY: `scales`, `ls` and `sw` all have QK_K / 16 elements,
            // matching the `n` argument.
            let d_block = unsafe {
                make_qx_quants(
                    QK_K / 16,
                    32,
                    scales.as_ptr(),
                    ls.as_mut_ptr(),
                    1,
                    sw.as_ptr(),
                )
            };
            block.d = f16::from_f32(d_block);
            // Pack each 6-bit scale: low nibble into scales[0..8] (low or
            // high half-byte), top 2 bits into scales[8..12].
            for (j, l_val) in ls.iter().enumerate().take(QK_K / 16) {
                if j < 8 {
                    block.scales[j] = (l_val & 0xF) as u8;
                } else {
                    block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8;
                }
                let l_val = l_val >> 4;
                block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8;
            }

            // Re-quantize the raw values against the *decoded* scales so the
            // stored quants match what dequantization will reconstruct.
            for j in 0..QK_K / 16 {
                let sc = if j < 8 {
                    block.scales[j] & 0xF
                } else {
                    block.scales[j - 8] >> 4
                };
                let sc = (sc | (((block.scales[8 + j % 4] >> (2 * (j / 4))) & 3) << 4)) as i8 - 32;
                let d = block.d.to_f32() * sc as f32;
                if d != 0.0 {
                    for ii in 0..16 {
                        let l_val = nearest_int(x[16 * j + ii] / d);
                        l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8;
                    }
                }
            }

            // Split each 3-bit quant: the high bit goes into `hmask` (one bit
            // plane per pass over the 32 mask bytes), the low 2 bits stay in `l`.
            block.hmask.fill(0);
            let mut m = 0;
            let mut hm = 1;

            for ll in l.iter_mut() {
                if *ll > 3 {
                    block.hmask[m] |= hm;
                    *ll -= 4;
                }
                m += 1;
                if m == QK_K / 8 {
                    m = 0;
                    hm <<= 1;
                }
            }

            // Pack the remaining 2-bit quants four-per-byte, interleaving the
            // four 32-element groups of each 128-element half.
            for j in (0..QK_K).step_by(128) {
                for l_val in 0..32 {
                    block.qs[j / 4 + l_val] = (l[j + l_val]
                        | (l[j + l_val + 32] << 2)
                        | (l[j + l_val + 64] << 4)
                        | (l[j + l_val + 96] << 6))
                        as u8;
                }
            }
        }
    }
1296
    /// Dequantize Q3_K blocks into `ys`.
    ///
    /// Unpacks the 12 packed scale bytes into sixteen 6-bit signed scales,
    /// then reconstructs each value as `d * scale * (2-bit quant - 4*!high_bit)`.
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        // Masks for splitting the packed scale bytes: 2-bit high parts and
        // 4-bit low parts respectively.
        const KMASK1: u32 = 0x03030303;
        const KMASK2: u32 = 0x0f0f0f0f;

        for (block, y) in group_for_dequantization(xs, ys) {
            let mut aux = [0; 4];
            LittleEndian::read_u32_into(&block.scales, &mut aux[0..3]);

            // Rearrange the three packed words into four words of sixteen
            // 6-bit scales (low nibble | high-2-bits << 4).
            let tmp = aux[2];
            aux[2] = ((aux[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4);
            aux[3] = ((aux[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4);
            aux[0] = (aux[0] & KMASK2) | (((tmp) & KMASK1) << 4);
            aux[1] = (aux[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4);

            // SAFETY: `aux` is 4 u32s = 16 bytes, viewed in place as 16 i8
            // scale values; the slice lives only within this iteration.
            let scales: &mut [i8] =
                unsafe { std::slice::from_raw_parts_mut(aux.as_mut_ptr() as *mut i8, 16) };

            let d_all = block.d.to_f32();
            // `m` selects the current hmask bit plane, `is` the current scale.
            let mut m = 1;
            let mut is = 0;

            // 128 output values per 32-byte qs chunk: 4 shift positions x
            // 2 scale groups x 16 values.
            for (y, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) {
                let mut shift = 0;
                for shift_scoped_y in y.chunks_exact_mut(32) {
                    for (scale_index, scale_scoped_y) in
                        shift_scoped_y.chunks_exact_mut(16).enumerate()
                    {
                        let dl = d_all * (scales[is] as f32 - 32.0);
                        for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
                            // A clear hmask bit means the stored 2-bit quant
                            // must be offset by -4.
                            let new_y = dl
                                * (((qs[i + 16 * scale_index] >> shift) & 3) as i8
                                    - if (block.hmask[i + 16 * scale_index] & m) == 0 {
                                        4
                                    } else {
                                        0
                                    }) as f32;
                            *inner_y = new_y;
                        }
                        is += 1;
                    }
                    shift += 2;
                    m <<= 1;
                }
            }
        }
    }
1351}
1352
/// Q4_K: 4-bit k-quantization. Each 256-element super-block holds eight
/// 32-element sub-blocks, each with a 6-bit scale and a 6-bit min packed into
/// the 12-byte `scales` array, plus f16 super-block multipliers `d` / `dmin`.
impl GgmlType for BlockQ4K {
    const DTYPE: GgmlDType = GgmlDType::Q4K;
    const BLCK_SIZE: usize = QK_K;
    type VecDotType = BlockQ8K;

    // Dispatches to a SIMD kernel when one was compiled in; the scalar
    // fallback is unreachable on those targets, hence the allow.
    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        #[cfg(target_feature = "avx2")]
        return super::avx::vec_dot_q4k_q8k(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q4k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q4k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Scalar reference dot product between a Q4_K row and a Q8_K row.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        debug_assert!(
            n.is_multiple_of(QK_K),
            "vec_dot_q4k_q8k: {n} is not divisible by {QK_K}"
        );

        // Masks used when unpacking the 12 packed scale/min bytes into
        // eight 6-bit scales and eight 6-bit mins.
        const KMASK1: u32 = 0x3f3f3f3f;
        const KMASK2: u32 = 0x0f0f0f0f;
        const KMASK3: u32 = 0x03030303;

        let mut utmp: [u32; 4] = [0; 4];
        let mut scales: [u8; 8] = [0; 8];
        let mut mins: [u8; 8] = [0; 8];

        let mut aux8: [i8; QK_K] = [0; QK_K];
        let mut aux16: [i16; 8] = [0; 8];
        let mut sums: [f32; 8] = [0.0; 8];
        let mut aux32: [i32; 8] = [0; 8];

        let mut sumf = 0.0;
        for (y, x) in ys.iter().zip(xs.iter()) {
            let q4 = &x.qs;
            let q8 = &y.qs;
            aux32.fill(0);

            // Expand the packed nibbles into aux8: per 64-value group, the 32
            // low nibbles come first, then the 32 high nibbles.
            let mut a = &mut aux8[..];
            let mut q4 = &q4[..];
            for _ in 0..QK_K / 64 {
                for l in 0..32 {
                    a[l] = (q4[l] & 0xF) as i8;
                }
                a = &mut a[32..];
                for l in 0..32 {
                    a[l] = (q4[l] >> 4) as i8;
                }
                a = &mut a[32..];
                q4 = &q4[32..];
            }

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            // Shuffle the three packed words into (scales, mins) word pairs.
            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);

            // The min contribution is linear in q8, so it folds into the
            // precomputed per-16 block sums (`bsums`).
            let mut sumi = 0;
            for j in 0..QK_K / 16 {
                sumi += y.bsums[j] as i32 * mins[j / 2] as i32;
            }

            let mut a = &mut aux8[..];
            let mut q8 = &q8[..];

            // Accumulate scale * q4 * q8 per 32-value sub-block
            // (4 groups of 8 lanes per scale).
            for scale in scales {
                let scale = scale as i32;
                for _ in 0..4 {
                    for l in 0..8 {
                        aux16[l] = q8[l] as i16 * a[l] as i16;
                    }
                    for l in 0..8 {
                        aux32[l] += scale * aux16[l] as i32;
                    }
                    q8 = &q8[8..];
                    a = &mut a[8..];
                }
            }
            let d = x.d.to_f32() * y.d;
            for l in 0..8 {
                sums[l] += d * aux32[l] as f32;
            }
            let dmin = x.dmin.to_f32() * y.d;
            sumf -= dmin * sumi as f32;
        }
        sumf + sums.iter().sum::<f32>()
    }

    /// Quantize `xs` into Q4_K blocks (uniform per-element weighting).
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        for (block, x) in group_for_quantization(xs, ys) {
            let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
            let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];

            for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
                (scales[j], mins[j]) = make_qkx1_quants(15, 5, x_scale_slice);
            }

            let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
            let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));

            // Map scales/mins onto 6 bits relative to the block maxima.
            let inv_scale = if max_scale > 0.0 {
                63.0 / max_scale
            } else {
                0.0
            };
            let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };

            // Pack 6-bit (scale, min) pairs: first 4 sub-blocks use whole
            // bytes, the rest split low nibbles / high 2-bit parts.
            for j in 0..QK_K / 32 {
                let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
                let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
                if j < 4 {
                    block.scales[j] = ls;
                    block.scales[j + 4] = lm;
                } else {
                    block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4);
                    block.scales[j - 4] |= (ls >> 4) << 6;
                    block.scales[j] |= (lm >> 4) << 6;
                }
            }

            block.d = f16::from_f32(max_scale / 63.0);
            block.dmin = f16::from_f32(max_min / 63.0);

            let mut l: [u8; QK_K] = [0; QK_K];

            // Quantize against the decoded scales/mins so round-tripping is
            // consistent with `to_float`.
            for j in 0..QK_K / 32 {
                let (sc, m) = get_scale_min_k4(j, &block.scales);
                let d = block.d.to_f32() * sc as f32;
                if d != 0.0 {
                    let dm = block.dmin.to_f32() * m as f32;
                    for ii in 0..32 {
                        let l_val = nearest_int((x[32 * j + ii] + dm) / d);
                        l[32 * j + ii] = l_val.clamp(0, 15) as u8;
                    }
                }
            }

            // Two 4-bit quants per byte: low nibble from the first 32 values
            // of each 64-group, high nibble from the second 32.
            let q = &mut block.qs;
            for j in (0..QK_K).step_by(64) {
                for l_val in 0..32 {
                    let offset_index = (j / 64) * 32 + l_val;
                    q[offset_index] = l[j + l_val] | (l[j + l_val + 32] << 4);
                }
            }
        }
    }

    /// Quantize `xs` into Q4_K blocks with importance-matrix weighting.
    ///
    /// Same packing as `from_float`, but scales/mins come from the weighted
    /// `make_qkx3_quants` / `make_qp_quants` search.
    fn from_float_imatrix(xs: &[f32], ys: &mut [Self], imatrix_weights: &[f32], n_per_row: usize) {
        for (sblk_idx, (block, x)) in group_for_quantization(xs, ys).into_iter().enumerate() {
            let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
            let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];
            let mut weights: [f32; 32] = [0.0; 32];
            let mut sw: [f32; QK_K / 32] = [0.0; QK_K / 32];
            let mut ls: [u8; QK_K / 32] = [0; QK_K / 32];
            let mut lm: [u8; QK_K / 32] = [0; QK_K / 32];

            let sum_x2 = x.iter().map(|x| x * x).sum::<f32>();
            let sigma2 = 2. * sum_x2 / QK_K as f32;

            for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
                for (l, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() {
                    // Wrap around when the imatrix covers fewer rows than the
                    // super-block index.
                    let imatrix_row = sblk_idx % (n_per_row / QK_K);
                    let imatrix_w = imatrix_weights[imatrix_row * QK_K + 32 * j + l];
                    *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt();
                }
                let sumw = weights.iter().sum::<f32>();
                sw[j] = sumw;
                (scales[j], mins[j]) =
                    make_qkx3_quants(15, x_scale_slice, Some(&weights), -0.9, 0.05, 36, false);
            }

            // Quantize the per-sub-block scales and mins to 6 bits, weighted
            // by the sub-block weight sums.
            let d_block = make_qp_quants(QK_K / 32, 63, &scales, &mut ls, &sw);
            let m_block = make_qp_quants(QK_K / 32, 63, &mins, &mut lm, &sw);
            // NOTE(review): unlike the Q5_K imatrix path, ls/lm are not
            // clamped with `.min(63)` here — presumably `make_qp_quants`
            // already bounds them by its `63` argument; confirm.
            for j in 0..QK_K / 32 {
                let ls_val = ls[j];
                let lm_val = lm[j];
                if j < 4 {
                    block.scales[j] = ls_val;
                    block.scales[j + 4] = lm_val;
                } else {
                    block.scales[j + 4] = (ls_val & 0xF) | ((lm_val & 0xF) << 4);
                    block.scales[j - 4] |= (ls_val >> 4) << 6;
                    block.scales[j] |= (lm_val >> 4) << 6;
                }
            }

            block.d = f16::from_f32(d_block);
            block.dmin = f16::from_f32(m_block);

            // Quantize against the decoded scales/mins (round-trip
            // consistency with `to_float`).
            let mut l: [u8; QK_K] = [0; QK_K];
            for j in 0..QK_K / 32 {
                let (sc, m) = get_scale_min_k4(j, &block.scales);
                let d = block.d.to_f32() * sc as f32;
                if d != 0.0 {
                    let dm = block.dmin.to_f32() * m as f32;
                    for ii in 0..32 {
                        let l_val = nearest_int((x[32 * j + ii] + dm) / d);
                        l[32 * j + ii] = l_val.clamp(0, 15) as u8;
                    }
                }
            }

            // Pack two 4-bit quants per byte.
            let q = &mut block.qs;
            for j in (0..QK_K).step_by(64) {
                for l_val in 0..32 {
                    let offset_index = (j / 64) * 32 + l_val;
                    q[offset_index] = l[j + l_val] | (l[j + l_val + 32] << 4);
                }
            }
        }
    }

    /// Dequantize Q4_K blocks: each value is `d * scale * nibble - dmin * min`.
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        for (block, y) in group_for_dequantization(xs, ys) {
            let d = block.d.to_f32();
            let min = block.dmin.to_f32();
            let q = &block.qs;
            let mut is = 0;
            let mut ys_index = 0;

            // 64 outputs per 32-byte chunk: low nibbles use (d1, m1), high
            // nibbles use the next sub-block's (d2, m2).
            for j in (0..QK_K).step_by(64) {
                let q = &q[j / 2..j / 2 + 32];
                let (sc, m) = get_scale_min_k4(is, &block.scales);
                let d1 = d * sc as f32;
                let m1 = min * m as f32;
                let (sc, m) = get_scale_min_k4(is + 1, &block.scales);
                let d2 = d * sc as f32;
                let m2 = min * m as f32;
                for q in q {
                    y[ys_index] = d1 * (q & 0xF) as f32 - m1;
                    ys_index += 1;
                }
                for q in q {
                    y[ys_index] = d2 * (q >> 4) as f32 - m2;
                    ys_index += 1;
                }
                is += 2;
            }
        }
    }
}
1608
/// Q5_K: 5-bit k-quantization. Layout mirrors Q4_K (eight 32-element
/// sub-blocks with packed 6-bit scales/mins and f16 `d`/`dmin`), with the
/// fifth quant bit stored separately in the 32-byte `qh` bit planes.
impl GgmlType for BlockQ5K {
    const DTYPE: GgmlDType = GgmlDType::Q5K;
    const BLCK_SIZE: usize = QK_K;
    type VecDotType = BlockQ8K;

    // Dispatches to a SIMD kernel when compiled in; the scalar fallback is
    // unreachable on those targets, hence the allow.
    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        #[cfg(target_feature = "avx2")]
        return super::avx::vec_dot_q5k_q8k(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q5k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Scalar reference dot product between a Q5_K row and a Q8_K row.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        debug_assert!(
            n.is_multiple_of(QK_K),
            "vec_dot_q5k_q8k: {n} is not divisible by {QK_K}"
        );

        // Masks used when unpacking the 12 packed scale/min bytes.
        const KMASK1: u32 = 0x3f3f3f3f;
        const KMASK2: u32 = 0x0f0f0f0f;
        const KMASK3: u32 = 0x03030303;

        let mut utmp: [u32; 4] = [0; 4];
        let mut scales: [u8; 8] = [0; 8];
        let mut mins: [u8; 8] = [0; 8];

        let mut aux8: [i8; QK_K] = [0; QK_K];
        let mut aux16: [i16; 8] = [0; 8];
        let mut sums: [f32; 8] = [0.0; 8];
        let mut aux32: [i32; 8] = [0; 8];

        let mut sumf = 0.0;
        for (y, x) in ys.iter().zip(xs.iter()) {
            let q5 = &x.qs;
            let hm = &x.qh;
            let q8 = &y.qs;
            aux32.fill(0);

            // Expand to full 5-bit values: nibble plus 16 when the matching
            // qh bit plane (selected by `m`) is set.
            let mut a = &mut aux8[..];
            let mut q5 = &q5[..];
            let mut m = 1u8;

            for _ in 0..QK_K / 64 {
                for l in 0..32 {
                    a[l] = (q5[l] & 0xF) as i8;
                    a[l] += if hm[l] & m != 0 { 16 } else { 0 };
                }
                a = &mut a[32..];
                m <<= 1;
                for l in 0..32 {
                    a[l] = (q5[l] >> 4) as i8;
                    a[l] += if hm[l] & m != 0 { 16 } else { 0 };
                }
                a = &mut a[32..];
                m <<= 1;
                q5 = &q5[32..];
            }

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            // Shuffle the three packed words into (scales, mins) word pairs.
            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);

            // Min contribution folds into the precomputed q8 block sums.
            let mut sumi = 0;
            for j in 0..QK_K / 16 {
                sumi += y.bsums[j] as i32 * mins[j / 2] as i32;
            }

            let mut a = &mut aux8[..];
            let mut q8 = &q8[..];

            // Accumulate scale * q5 * q8 per 32-value sub-block.
            for scale in scales {
                let scale = scale as i32;
                for _ in 0..4 {
                    for l in 0..8 {
                        aux16[l] = q8[l] as i16 * a[l] as i16;
                    }
                    for l in 0..8 {
                        aux32[l] += scale * aux16[l] as i32;
                    }
                    q8 = &q8[8..];
                    a = &mut a[8..];
                }
            }
            let d = x.d.to_f32() * y.d;
            for l in 0..8 {
                sums[l] += d * aux32[l] as f32;
            }
            let dmin = x.dmin.to_f32() * y.d;
            sumf -= dmin * sumi as f32;
        }
        sumf + sums.iter().sum::<f32>()
    }

    /// Quantize `xs` into Q5_K blocks (uniform per-element weighting).
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        for (block, x) in group_for_quantization(xs, ys) {
            let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
            let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];

            for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
                (scales[j], mins[j]) = make_qkx1_quants(31, 5, x_scale_slice);
            }

            let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
            let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));

            // Map scales/mins onto 6 bits relative to the block maxima.
            let inv_scale = if max_scale > 0.0 {
                63.0 / max_scale
            } else {
                0.0
            };
            let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
            // Pack 6-bit (scale, min) pairs exactly like Q4_K.
            for j in 0..QK_K / 32 {
                let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
                let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
                if j < 4 {
                    block.scales[j] = ls;
                    block.scales[j + 4] = lm;
                } else {
                    block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4);
                    block.scales[j - 4] |= (ls >> 4) << 6;
                    block.scales[j] |= (lm >> 4) << 6;
                }
            }
            block.d = f16::from_f32(max_scale / 63.0);
            block.dmin = f16::from_f32(max_min / 63.0);

            // Quantize against the decoded scales/mins (round-trip
            // consistency with `to_float`).
            let mut l: [u8; QK_K] = [0; QK_K];
            for j in 0..QK_K / 32 {
                let (sc, m) = get_scale_min_k4(j, &block.scales);
                let d = block.d.to_f32() * sc as f32;
                if d == 0.0 {
                    continue;
                }
                let dm = block.dmin.to_f32() * m as f32;
                for ii in 0..32 {
                    let ll = nearest_int((x[32 * j + ii] + dm) / d);
                    l[32 * j + ii] = ll.clamp(0, 31) as u8;
                }
            }

            let qh = &mut block.qh;
            let ql = &mut block.qs;
            qh.fill(0);

            // Low 4 bits go into ql nibbles; the fifth bit goes into the qh
            // bit plane selected by m1/m2 (two planes per 64-value group).
            let mut m1 = 1;
            let mut m2 = 2;
            for n in (0..QK_K).step_by(64) {
                let offset = (n / 64) * 32;
                for j in 0..32 {
                    let mut l1 = l[n + j];
                    if l1 > 15 {
                        l1 -= 16;
                        qh[j] |= m1;
                    }
                    let mut l2 = l[n + j + 32];
                    if l2 > 15 {
                        l2 -= 16;
                        qh[j] |= m2;
                    }
                    ql[offset + j] = l1 | (l2 << 4);
                }
                m1 <<= 2;
                m2 <<= 2;
            }
        }
    }

    /// Quantize `xs` into Q5_K blocks with importance-matrix weighting.
    ///
    /// Same packing as `from_float`; scales/mins come from the weighted
    /// `make_qkx3_quants` / `make_qp_quants` search.
    fn from_float_imatrix(xs: &[f32], ys: &mut [Self], imatrix_weights: &[f32], n_per_row: usize) {
        for (sblk_idx, (block, x)) in group_for_quantization(xs, ys).into_iter().enumerate() {
            let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
            let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];
            let mut weights: [f32; 32] = [0.0; 32];
            let mut sw: [f32; QK_K / 32] = [0.0; QK_K / 32];
            let mut ls: [u8; QK_K / 32] = [0; QK_K / 32];
            let mut lm: [u8; QK_K / 32] = [0; QK_K / 32];

            let sum_x2 = x.iter().map(|x| x * x).sum::<f32>();
            let sigma2 = 2. * sum_x2 / QK_K as f32;

            for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
                for (l, (w_elem, x_elem)) in weights.iter_mut().zip(x_scale_slice).enumerate() {
                    // Wrap around when the imatrix covers fewer rows than the
                    // super-block index.
                    let imatrix_row = sblk_idx % (n_per_row / QK_K);
                    let imatrix_w = imatrix_weights[imatrix_row * QK_K + 32 * j + l];
                    *w_elem = imatrix_w * (sigma2 + x_elem * x_elem).sqrt();
                }
                let sumw = weights.iter().sum::<f32>();
                sw[j] = sumw;
                (scales[j], mins[j]) =
                    make_qkx3_quants(31, x_scale_slice, Some(&weights), -0.9, 0.05, 36, false);
            }

            // Quantize the per-sub-block scales/mins to 6 bits.
            let d_block = make_qp_quants(QK_K / 32, 63, &scales, &mut ls, &sw);
            let m_block = make_qp_quants(QK_K / 32, 63, &mins, &mut lm, &sw);
            for j in 0..QK_K / 32 {
                let ls_val = ls[j].min(63);
                let lm_val = lm[j].min(63);
                if j < 4 {
                    block.scales[j] = ls_val;
                    block.scales[j + 4] = lm_val;
                } else {
                    block.scales[j + 4] = (ls_val & 0xF) | ((lm_val & 0xF) << 4);
                    block.scales[j - 4] |= (ls_val >> 4) << 6;
                    block.scales[j] |= (lm_val >> 4) << 6;
                }
            }

            block.d = f16::from_f32(d_block);
            block.dmin = f16::from_f32(m_block);

            // Quantize against the decoded scales/mins.
            let mut l: [u8; QK_K] = [0; QK_K];
            for j in 0..QK_K / 32 {
                let (sc, m) = get_scale_min_k4(j, &block.scales);
                let d = block.d.to_f32() * sc as f32;
                if d != 0.0 {
                    let dm = block.dmin.to_f32() * m as f32;
                    for ii in 0..32 {
                        let l_val = nearest_int((x[32 * j + ii] + dm) / d);
                        l[32 * j + ii] = l_val.clamp(0, 31) as u8;
                    }
                }
            }

            let qh = &mut block.qh;
            let ql = &mut block.qs;
            qh.fill(0);

            // Low nibbles into ql, fifth bit into the qh bit planes.
            let mut m1 = 1;
            let mut m2 = 2;
            for n in (0..QK_K).step_by(64) {
                let offset = (n / 64) * 32;
                for j in 0..32 {
                    let mut l1 = l[n + j];
                    if l1 > 15 {
                        l1 -= 16;
                        qh[j] |= m1;
                    }
                    let mut l2 = l[n + j + 32];
                    if l2 > 15 {
                        l2 -= 16;
                        qh[j] |= m2;
                    }
                    ql[offset + j] = l1 | (l2 << 4);
                }
                m1 <<= 2;
                m2 <<= 2;
            }
        }
    }

    /// Dequantize Q5_K blocks: `d * scale * (nibble + 16*high_bit) - dmin * min`.
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        for (block, y) in group_for_dequantization(xs, ys) {
            let d = block.d.to_f32();
            let min = block.dmin.to_f32();
            let ql = &block.qs;
            let qh = &block.qh;
            let mut is = 0;
            // u1/u2 select the qh bit planes for the low/high nibbles.
            let mut u1 = 1;
            let mut u2 = 2;
            let mut ys_index = 0;

            for j in (0..QK_K).step_by(64) {
                let ql = &ql[j / 2..j / 2 + 32];
                let (sc, m) = get_scale_min_k4(is, &block.scales);
                let d1 = d * sc as f32;
                let m1 = min * m as f32;
                let (sc, m) = get_scale_min_k4(is + 1, &block.scales);
                let d2 = d * sc as f32;
                let m2 = min * m as f32;
                for (ql, qh) in ql.iter().zip(qh) {
                    let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
                    y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
                    ys_index += 1;
                }
                for (ql, qh) in ql.iter().zip(qh) {
                    let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
                    y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
                    ys_index += 1;
                }
                is += 2;
                u1 <<= 2;
                u2 <<= 2;
            }
        }
    }
}
1910
/// Q6_K: 6-bit k-quantization. Sixteen 16-element sub-blocks per 256-element
/// super-block, each with a signed 8-bit scale; the low 4 quant bits live in
/// `ql`, the high 2 bits in `qh`, and values are offset by 32.
impl GgmlType for BlockQ6K {
    const DTYPE: GgmlDType = GgmlDType::Q6K;
    const BLCK_SIZE: usize = QK_K;
    type VecDotType = BlockQ8K;

    // Dispatches to a SIMD kernel when compiled in; the scalar fallback is
    // unreachable on those targets, hence the allow.
    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        #[cfg(target_feature = "avx2")]
        return super::avx::vec_dot_q6k_q8k(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q6k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q6k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Scalar reference dot product between a Q6_K row and a Q8_K row.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        debug_assert!(
            n.is_multiple_of(QK_K),
            "vec_dot_q6k_q8k: {n} is not divisible by {QK_K}"
        );

        let mut aux8 = [0i8; QK_K];
        let mut aux16 = [0i16; 8];
        let mut sums = [0f32; 8];
        let mut aux32 = [0f32; 8];

        for (x, y) in xs.iter().zip(ys.iter()) {
            let q4 = &x.ql;
            let qh = &x.qh;
            let q8 = &y.qs;
            aux32.fill(0f32);

            // Reassemble the signed 6-bit values (low nibble | 2 high bits,
            // minus the 32 offset) for each 128-value half.
            for j in (0..QK_K).step_by(128) {
                let aux8 = &mut aux8[j..];
                let q4 = &q4[j / 2..];
                let qh = &qh[j / 4..];
                for l in 0..32 {
                    aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
                    aux8[l + 32] =
                        (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
                    aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
                    aux8[l + 96] =
                        (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
                }
            }

            // One signed scale per 16-value sub-block, applied over two
            // 8-lane groups.
            for (j, &scale) in x.scales.iter().enumerate() {
                let scale = scale as f32;
                let q8 = &q8[16 * j..];
                let aux8 = &aux8[16 * j..];
                for l in 0..8 {
                    aux16[l] = q8[l] as i16 * aux8[l] as i16;
                }
                for l in 0..8 {
                    aux32[l] += scale * aux16[l] as f32
                }
                let q8 = &q8[8..];
                let aux8 = &aux8[8..];
                for l in 0..8 {
                    aux16[l] = q8[l] as i16 * aux8[l] as i16;
                }
                for l in 0..8 {
                    aux32[l] += scale * aux16[l] as f32
                }
            }

            let d = x.d.to_f32() * y.d;
            for (sum, &a) in sums.iter_mut().zip(aux32.iter()) {
                *sum += a * d;
            }
        }
        sums.iter().sum()
    }

    /// Quantize `xs` into Q6_K blocks (uniform per-element weighting).
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        debug_assert_eq!(
            xs.len(),
            ys.len() * Self::BLCK_SIZE,
            "quantize_row_q6k: size mismatch {} {} {}",
            xs.len(),
            ys.len(),
            Self::BLCK_SIZE
        );
        let mut l = [0i8; QK_K];
        let mut scales = [0f32; QK_K / 16];
        let mut x = xs.as_ptr();
        let l = l.as_mut_ptr();
        // SAFETY: the assert above guarantees `x` covers ys.len() * QK_K
        // floats; all pointer offsets below stay within one QK_K super-block
        // (`x` advances by QK_K per iteration) and `l`/`scales` are QK_K /
        // QK_K/16 elements respectively.
        unsafe {
            for y in ys.iter_mut() {
                // Find the per-sub-block float scales and the one with the
                // largest magnitude; the super-block delta derives from it.
                let mut max_scale = 0f32;
                let mut max_abs_scale = 0f32;
                for (ib, scale_) in scales.iter_mut().enumerate() {
                    let scale =
                        make_qx_quants(16, 32, x.add(16 * ib), l.add(16 * ib), 1, std::ptr::null());
                    *scale_ = scale;
                    let abs_scale = scale.abs();
                    if abs_scale > max_abs_scale {
                        max_abs_scale = abs_scale;
                        max_scale = scale
                    }
                }

                let iscale = -128f32 / max_scale;
                y.d = f16::from_f32(1.0 / iscale);

                // Quantize the sub-block scales to signed 8 bits.
                for (y_scale, scale) in y.scales.iter_mut().zip(scales.iter()) {
                    *y_scale = nearest_int(iscale * scale).min(127) as i8
                }

                // Re-quantize the raw values against the decoded scales into
                // [0, 64) (6 bits plus the +32 offset).
                for (j, &y_scale) in y.scales.iter().enumerate() {
                    let d = y.d.to_f32() * y_scale as f32;
                    if d == 0. {
                        continue;
                    }
                    for ii in 0..16 {
                        let ll = nearest_int(*x.add(16 * j + ii) / d).clamp(-32, 31);
                        *l.add(16 * j + ii) = (ll + 32) as i8
                    }
                }

                let mut ql = y.ql.as_mut_ptr();
                let mut qh = y.qh.as_mut_ptr();

                // Pack per 128-value half: low nibbles pairwise into ql, the
                // four 2-bit high parts of each quartet into one qh byte.
                for j in (0..QK_K).step_by(128) {
                    for l_idx in 0..32 {
                        let q1 = *l.add(j + l_idx) & 0xF;
                        let q2 = *l.add(j + l_idx + 32) & 0xF;
                        let q3 = *l.add(j + l_idx + 64) & 0xF;
                        let q4 = *l.add(j + l_idx + 96) & 0xF;
                        *ql.add(l_idx) = (q1 | (q3 << 4)) as u8;
                        *ql.add(l_idx + 32) = (q2 | (q4 << 4)) as u8;
                        *qh.add(l_idx) = ((*l.add(j + l_idx) >> 4)
                            | ((*l.add(j + l_idx + 32) >> 4) << 2)
                            | ((*l.add(j + l_idx + 64) >> 4) << 4)
                            | ((*l.add(j + l_idx + 96) >> 4) << 6))
                            as u8;
                    }
                    ql = ql.add(64);
                    qh = qh.add(32);
                }

                x = x.add(QK_K)
            }
        }
    }

    /// Quantize `xs` into Q6_K blocks with importance-matrix weighting.
    ///
    /// Identical packing to `from_float`; the only difference is that
    /// `make_qx_quants` receives the imatrix row slice as weights instead of
    /// a null pointer.
    fn from_float_imatrix(xs: &[f32], ys: &mut [Self], imatrix_weights: &[f32], n_per_row: usize) {
        debug_assert_eq!(
            xs.len(),
            ys.len() * Self::BLCK_SIZE,
            "quantize_row_q6k imatrix: size mismatch {} {} {}",
            xs.len(),
            ys.len(),
            Self::BLCK_SIZE
        );
        let mut l = [0i8; QK_K];
        let mut scales = [0f32; QK_K / 16];
        let mut x = xs.as_ptr();
        let imatrix_weights = imatrix_weights.as_ptr();
        let l = l.as_mut_ptr();
        // SAFETY: as in `from_float`; additionally each imatrix offset is
        // `QK_K * (sblk_idx % (n_per_row / QK_K)) + 16 * ib`, which assumes
        // `imatrix_weights` holds at least n_per_row entries — the caller's
        // contract.
        unsafe {
            for (sblk_idx, y) in ys.iter_mut().enumerate() {
                let mut max_scale = 0f32;
                let mut max_abs_scale = 0f32;
                for (ib, scale_) in scales.iter_mut().enumerate() {
                    // Wrap around when the imatrix covers fewer rows than the
                    // super-block index.
                    let imatrix_row = sblk_idx % (n_per_row / QK_K);
                    let scale = make_qx_quants(
                        16,
                        32,
                        x.add(16 * ib),
                        l.add(16 * ib),
                        1,
                        imatrix_weights.add(QK_K * imatrix_row + 16 * ib),
                    );
                    *scale_ = scale;
                    let abs_scale = scale.abs();
                    if abs_scale > max_abs_scale {
                        max_abs_scale = abs_scale;
                        max_scale = scale
                    }
                }

                let iscale = -128f32 / max_scale;
                y.d = f16::from_f32(1.0 / iscale);

                for (y_scale, scale) in y.scales.iter_mut().zip(scales.iter()) {
                    *y_scale = nearest_int(iscale * scale).min(127) as i8
                }

                // Re-quantize against the decoded scales.
                for (j, &y_scale) in y.scales.iter().enumerate() {
                    let d = y.d.to_f32() * y_scale as f32;
                    if d == 0. {
                        continue;
                    }
                    for ii in 0..16 {
                        let ll = nearest_int(*x.add(16 * j + ii) / d).clamp(-32, 31);
                        *l.add(16 * j + ii) = (ll + 32) as i8
                    }
                }

                let mut ql = y.ql.as_mut_ptr();
                let mut qh = y.qh.as_mut_ptr();

                // Pack low nibbles into ql and 2-bit high parts into qh.
                for j in (0..QK_K).step_by(128) {
                    for l_idx in 0..32 {
                        let q1 = *l.add(j + l_idx) & 0xF;
                        let q2 = *l.add(j + l_idx + 32) & 0xF;
                        let q3 = *l.add(j + l_idx + 64) & 0xF;
                        let q4 = *l.add(j + l_idx + 96) & 0xF;
                        *ql.add(l_idx) = (q1 | (q3 << 4)) as u8;
                        *ql.add(l_idx + 32) = (q2 | (q4 << 4)) as u8;
                        *qh.add(l_idx) = ((*l.add(j + l_idx) >> 4)
                            | ((*l.add(j + l_idx + 32) >> 4) << 2)
                            | ((*l.add(j + l_idx + 64) >> 4) << 4)
                            | ((*l.add(j + l_idx + 96) >> 4) << 6))
                            as u8;
                    }
                    ql = ql.add(64);
                    qh = qh.add(32);
                }

                x = x.add(QK_K)
            }
        }
    }

    /// Dequantize Q6_K blocks: each value is `d * scale * (6-bit quant - 32)`.
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        let k = ys.len();
        debug_assert!(
            k.is_multiple_of(QK_K),
            "dequantize_row_q6k: {k} is not divisible by {QK_K}"
        );

        for (idx_x, x) in xs.iter().enumerate() {
            let d = x.d.to_f32();
            let ql = &x.ql;
            let qh = &x.qh;
            let sc = &x.scales;
            // Per 128-value half: reassemble four interleaved 32-value groups
            // from the nibble/high-bit layout produced by `from_float`.
            for n in (0..QK_K).step_by(128) {
                let idx = n / 128;
                let ys = &mut ys[idx_x * QK_K + n..];
                let sc = &sc[8 * idx..];
                let ql = &ql[64 * idx..];
                let qh = &qh[32 * idx..];
                for l in 0..32 {
                    let is = l / 16;
                    let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
                    let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
                    let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
                    let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
                    ys[l] = d * sc[is] as f32 * q1 as f32;
                    ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
                    ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
                    ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
                }
            }
        }
    }
}
2175
2176impl GgmlType for BlockQ8K {
2177 const DTYPE: GgmlDType = GgmlDType::Q8K;
2178 const BLCK_SIZE: usize = QK_K;
2179 type VecDotType = BlockQ8K;
2180
2181 #[allow(unreachable_code)]
2182 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
2183 #[cfg(target_feature = "avx2")]
2184 return super::avx::vec_dot_q8k_q8k(n, xs, ys);
2185
2186 #[cfg(target_feature = "neon")]
2187 return super::neon::vec_dot_q8k_q8k(n, xs, ys);
2188
2189 #[cfg(target_feature = "simd128")]
2190 return super::simd128::vec_dot_q8k_q8k(n, xs, ys);
2191
2192 Self::vec_dot_unopt(n, xs, ys)
2193 }
2194
2195 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
2196 debug_assert!(
2197 n.is_multiple_of(QK_K),
2198 "vec_dot_q8k_q8k: {n} is not divisible by {QK_K}"
2199 );
2200 let mut sumf = 0f32;
2202 for (xs, ys) in xs.iter().zip(ys.iter()) {
2203 let sum_i = xs
2204 .qs
2205 .iter()
2206 .zip(ys.qs.iter())
2207 .map(|(&x, &y)| x as i32 * y as i32)
2208 .sum::<i32>();
2209 sumf += sum_i as f32 * xs.d * ys.d
2210 }
2211 sumf
2212 }
2213
2214 fn from_float(xs: &[f32], ys: &mut [Self]) {
2215 let k = xs.len();
2216 debug_assert!(
2217 k.is_multiple_of(QK_K),
2218 "quantize_row_q8k: {k} is not divisible by {QK_K}"
2219 );
2220 for (i, y) in ys.iter_mut().enumerate() {
2221 let mut max = 0f32;
2222 let mut amax = 0f32;
2223 let xs = &xs[i * QK_K..(i + 1) * QK_K];
2224 for &x in xs.iter() {
2225 if amax < x.abs() {
2226 amax = x.abs();
2227 max = x;
2228 }
2229 }
2230 if amax == 0f32 {
2231 y.d = 0f32;
2232 y.qs.fill(0)
2233 } else {
2234 let iscale = -128f32 / max;
2235 for (j, q) in y.qs.iter_mut().enumerate() {
2236 let v = (iscale * xs[j]).round();
2239 *q = v.min(127.) as i8
2240 }
2241 for j in 0..QK_K / 16 {
2242 let mut sum = 0i32;
2243 for ii in 0..16 {
2244 sum += y.qs[j * 16 + ii] as i32
2245 }
2246 y.bsums[j] = sum as i16
2247 }
2248 y.d = 1.0 / iscale
2249 }
2250 }
2251 }
2252
2253 fn to_float(xs: &[Self], ys: &mut [f32]) {
2254 let k = ys.len();
2255 debug_assert!(
2256 k.is_multiple_of(QK_K),
2257 "dequantize_row_q8k: {k} is not divisible by {QK_K}"
2258 );
2259 for (i, x) in xs.iter().enumerate() {
2260 for (j, &q) in x.qs.iter().enumerate() {
2261 ys[i * QK_K + j] = x.d * q as f32
2262 }
2263 }
2264 }
2265}
2266
2267pub fn matmul<T: GgmlType>(
2269 (m, k, n): (usize, usize, usize),
2270 lhs: &[f32],
2271 rhs_t: &[T],
2272 dst: &mut [f32],
2273) -> Result<()> {
2274 debug_assert_eq!(
2275 T::BLCK_SIZE,
2276 T::VecDotType::BLCK_SIZE,
2277 "Mismatched block sizes"
2278 );
2279 debug_assert_eq!(
2280 m * k,
2281 lhs.len(),
2282 "unexpected lhs length {} ({m},{k},{n})",
2283 lhs.len()
2284 );
2285 let k_in_blocks = k.div_ceil(T::BLCK_SIZE);
2286
2287 let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_blocks];
2289 if T::DIRECT_COPY {
2291 T::VecDotType::direct_copy(lhs, &mut lhs_b);
2292 } else {
2293 for row_idx in 0..m {
2294 let lhs_b_mut = &mut lhs_b[row_idx * k_in_blocks..(row_idx + 1) * k_in_blocks];
2295 let lhs = &lhs[row_idx * k..(row_idx + 1) * k];
2296 T::VecDotType::from_float(lhs, lhs_b_mut)
2297 }
2298 }
2299
2300 for row_idx in 0..m {
2301 let lhs_row = &lhs_b[row_idx * k_in_blocks..(row_idx + 1) * k_in_blocks];
2302 let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
2303
2304 dst_row
2305 .into_par_iter()
2306 .enumerate()
2307 .with_min_len(128)
2308 .with_max_len(512)
2309 .for_each(|(col_idx, dst)| {
2310 let rhs_col = &rhs_t[col_idx * k_in_blocks..(col_idx + 1) * k_in_blocks];
2311 *dst = T::vec_dot(k, rhs_col, lhs_row);
2312 });
2313 }
2314 Ok(())
2315}
2316
2317pub fn matmul_f16<T: GgmlType>(
2318 mkn: (usize, usize, usize),
2319 lhs: &[f16],
2320 rhs_t: &[T],
2321 dst: &mut [f16],
2322) -> Result<()> {
2323 let (m, k, n) = mkn;
2324 if m * k != lhs.len() {
2325 crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
2326 }
2327
2328 let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE);
2329 let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE);
2330 let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];
2331 for row_idx in 0..m {
2332 let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
2333 let lhs = &lhs[row_idx * k..(row_idx + 1) * k];
2334 let lhs_f32: Vec<_> = lhs.iter().map(|&x| x.to_f32()).collect();
2335 T::VecDotType::from_float(&lhs_f32, lhs_b);
2336 }
2337 let lhs_b = lhs_b.as_slice();
2338
2339 for row_idx in 0..m {
2340 let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
2341 let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
2342
2343 for (col_idx, dst) in dst_row.iter_mut().enumerate() {
2344 let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
2345 let value = T::vec_dot(k, rhs_col, lhs_row);
2346 *dst = f16::from_f32(value);
2347 }
2348 }
2349 Ok(())
2350}
2351
/// Identity "quantization" for `f32`: one element per block, so the generic
/// quantized matmul paths can also operate on plain float tensors.
impl GgmlType for f32 {
    const DTYPE: GgmlDType = GgmlDType::F32;
    const BLCK_SIZE: usize = 1;
    // `from_float` is a plain copy, so callers may skip per-row conversion
    // and use `direct_copy` instead.
    const DIRECT_COPY: bool = true;
    type VecDotType = f32;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        // No separately optimized path for f32; defer to the generic kernel.
        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Dot product of the first `n` elements of `xs` and `ys`.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        debug_assert!(xs.len() >= n, "size mismatch xs {} < {n}", xs.len());
        debug_assert!(ys.len() >= n, "size mismatch ys {} < {n}", ys.len());
        let mut res = 0f32;
        // SAFETY: both slices hold at least `n` elements (checked above in
        // debug builds; release builds rely on callers passing consistent
        // lengths) and `res` is a valid output location. Assumes
        // `crate::cpu::vec_dot_f32` reads exactly `n` elements from each
        // pointer — contract defined in `crate::cpu`, not visible here.
        unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
        res
    }

    /// "Quantization" for f32 is a straight element-wise copy.
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        debug_assert_eq!(
            xs.len(),
            ys.len(),
            "size mismatch xs {} != ys {}",
            xs.len(),
            ys.len()
        );
        ys.copy_from_slice(xs);
    }

    /// "Dequantization" for f32 is a straight element-wise copy.
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        debug_assert_eq!(
            xs.len(),
            ys.len(),
            "size mismatch xs {} != ys {}",
            xs.len(),
            ys.len()
        );
        ys.copy_from_slice(xs);
    }

    // Fast path used when `DIRECT_COPY` is true; identical to `from_float`.
    fn direct_copy(xs: &[f32], ys: &mut [Self]) {
        Self::from_float(xs, ys)
    }
}
2396
/// `f16` as a GGML type: one element per block; conversion to/from `f32`
/// goes through `half`'s slice conversion helpers.
impl GgmlType for f16 {
    const DTYPE: GgmlDType = GgmlDType::F16;
    const BLCK_SIZE: usize = 1;
    // `from_float` is a pure element-wise conversion, so the `direct_copy`
    // fast path is available to callers.
    const DIRECT_COPY: bool = true;
    type VecDotType = f16;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        // No separately optimized path for f16; defer to the generic kernel.
        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Dot product of the first `n` elements of `xs` and `ys`, accumulated
    /// as `f32`.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        debug_assert!(xs.len() >= n, "size mismatch xs {} < {n}", xs.len());
        debug_assert!(ys.len() >= n, "size mismatch ys {} < {n}", ys.len());
        let mut res = 0f32;
        // SAFETY: both slices hold at least `n` elements (checked above in
        // debug builds; release builds rely on callers passing consistent
        // lengths) and `res` is a valid output location. Assumes
        // `crate::cpu::vec_dot_f16` reads exactly `n` elements from each
        // pointer — contract defined in `crate::cpu`, not visible here.
        unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
        res
    }

    /// Narrow each `f32` to `f16` (via `half::slice::HalfFloatSliceExt`).
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        debug_assert_eq!(
            xs.len(),
            ys.len(),
            "size mismatch xs {} != ys {}",
            xs.len(),
            ys.len()
        );
        ys.convert_from_f32_slice(xs);
    }

    /// Widen each `f16` back to `f32`.
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        debug_assert_eq!(
            xs.len(),
            ys.len(),
            "size mismatch xs {} != ys {}",
            xs.len(),
            ys.len()
        );
        xs.convert_to_f32_slice(ys);
    }

    // Fast path used when `DIRECT_COPY` is true; identical to `from_float`.
    fn direct_copy(xs: &[f32], ys: &mut [Self]) {
        Self::from_float(xs, ys)
    }
}
2441
/// `bf16` as a GGML type: one element per block; conversion to/from `f32`
/// goes through `half`'s slice conversion helpers.
impl GgmlType for bf16 {
    const DTYPE: GgmlDType = GgmlDType::BF16;
    const BLCK_SIZE: usize = 1;
    // `from_float` is a pure element-wise conversion, so the `direct_copy`
    // fast path is available to callers.
    const DIRECT_COPY: bool = true;
    type VecDotType = bf16;

    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        // No separately optimized path for bf16; defer to the generic kernel.
        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Dot product of the first `n` elements of `xs` and `ys`, accumulated
    /// as `f32`.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> f32 {
        debug_assert!(xs.len() >= n, "size mismatch xs {} < {n}", xs.len());
        debug_assert!(ys.len() >= n, "size mismatch ys {} < {n}", ys.len());
        let mut res = 0f32;
        // SAFETY: both slices hold at least `n` elements (checked above in
        // debug builds; release builds rely on callers passing consistent
        // lengths) and `res` is a valid output location. Assumes
        // `crate::cpu::vec_dot_bf16` reads exactly `n` elements from each
        // pointer — contract defined in `crate::cpu`, not visible here.
        unsafe { crate::cpu::vec_dot_bf16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
        res
    }

    /// Narrow each `f32` to `bf16` (via `half::slice::HalfFloatSliceExt`).
    fn from_float(xs: &[f32], ys: &mut [Self]) {
        debug_assert_eq!(
            xs.len(),
            ys.len(),
            "size mismatch xs {} != ys {}",
            xs.len(),
            ys.len()
        );
        ys.convert_from_f32_slice(xs);
    }

    /// Widen each `bf16` back to `f32`.
    fn to_float(xs: &[Self], ys: &mut [f32]) {
        debug_assert_eq!(
            xs.len(),
            ys.len(),
            "size mismatch xs {} != ys {}",
            xs.len(),
            ys.len()
        );
        xs.convert_to_f32_slice(ys);
    }

    // Fast path used when `DIRECT_COPY` is true; identical to `from_float`.
    fn direct_copy(xs: &[f32], ys: &mut [Self]) {
        Self::from_float(xs, ys)
    }
}
2486
/// Emits a compile-time assertion that `$block_type`'s `BLCK_SIZE` equals its
/// `VecDotType`'s `BLCK_SIZE`. The matmul routines compute a single per-row
/// block count and use it to slice both operands, which is only sound under
/// this invariant; a mismatch fails the build instead of corrupting indexing.
macro_rules! verify_block_size {
    ( $block_type:ident ) => {
        const _: () =
            assert!($block_type::BLCK_SIZE == <$block_type as GgmlType>::VecDotType::BLCK_SIZE);
    };
}
2493
/// Convenience wrapper: applies `verify_block_size!` to every listed type.
macro_rules! verify_block_sizes {
    ( $( $block_type:ident ),* ) => {
        $(
            verify_block_size!($block_type);
        )*
    };
}
2501
// Compile-time check for every supported GGML type: its BLCK_SIZE must match
// its VecDotType's BLCK_SIZE, the invariant the matmul kernels rely on when
// slicing lhs blocks and rhs blocks with the same block count.
verify_block_sizes!(
    BlockQ4_0, BlockQ4_1, BlockQ5_0, BlockQ5_1, BlockQ8_0, BlockQ8_1, BlockQ2K, BlockQ3K, BlockQ4K,
    BlockQ5K, BlockQ6K, BlockQ8K, f32, f16, bf16
);