1use super::utils::{
2 get_scale_min_k4, group_for_dequantization, group_for_quantization, make_q3_quants,
3 make_qkx1_quants, make_qx_quants, nearest_int,
4};
5use super::GgmlDType;
6use crate::Result;
7use byteorder::{ByteOrder, LittleEndian};
8use half::f16;
9use rayon::prelude::*;
10
/// Number of values in one k-quant super-block (`BlockQ2K` ... `BlockQ8K`).
pub const QK_K: usize = 256;
/// Number of bytes holding the packed scales (and mins) of `BlockQ4K` and `BlockQ5K`.
pub const K_SCALE_SIZE: usize = 12;

// Values per block for the legacy (non k-quant) formats; every one of them packs 32.
/// Values per `BlockQ4_0`.
pub const QK4_0: usize = 32;
/// Values per `BlockQ4_1`.
pub const QK4_1: usize = 32;
/// Values per `BlockQ5_0`.
pub const QK5_0: usize = 32;
/// Values per `BlockQ5_1`.
pub const QK5_1: usize = 32;
/// Values per `BlockQ8_0`.
pub const QK8_0: usize = 32;
/// Values per `BlockQ8_1`.
pub const QK8_1: usize = 32;
21
22pub trait GgmlType: Sized + Clone + Send + Sync {
23 const DTYPE: GgmlDType;
24 const BLCK_SIZE: usize;
25 type VecDotType: GgmlType;
26
27 fn zeros() -> Self {
29 unsafe { std::mem::MaybeUninit::zeroed().assume_init() }
30 }
31 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()>;
32 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()>;
33
34 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
37
38 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32>;
40}
41
/// Block of `QK4_0` values quantized to 4 bits with one shared scale.
/// Each value decodes as `(nibble - 8) * d`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_0 {
    /// Scale shared by the whole block.
    pub(crate) d: f16,
    /// Two 4-bit quants per byte: the low nibble of `qs[j]` is element `j`,
    /// the high nibble is element `j + QK4_0 / 2`.
    pub(crate) qs: [u8; QK4_0 / 2],
}
// Compile-time check that the `repr(C)` layout occupies exactly 18 bytes.
const _: () = assert!(std::mem::size_of::<BlockQ4_0>() == 18);
49
/// Block of `QK4_1` values quantized to 4 bits with a scale and an offset.
/// Each value decodes as `nibble * d + m`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4_1 {
    /// Scale shared by the whole block.
    pub(crate) d: f16,
    /// Offset (the block minimum at quantization time) added to every value.
    pub(crate) m: f16,
    /// Two 4-bit quants per byte: low nibble is element `j`, high nibble is
    /// element `j + QK4_1 / 2`.
    pub(crate) qs: [u8; QK4_1 / 2],
}
// Compile-time check that the `repr(C)` layout occupies exactly 20 bytes.
const _: () = assert!(std::mem::size_of::<BlockQ4_1>() == 20);
58
/// Block of `QK5_0` values quantized to 5 bits with one shared scale.
/// Each value decodes as `(q5 - 16) * d`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_0 {
    /// Scale shared by the whole block.
    pub(crate) d: f16,
    /// Fifth (high) bit of every quant, read as a little-endian `u32`:
    /// bit `j` belongs to element `j`, bit `j + 16` to element `j + 16`.
    pub(crate) qh: [u8; 4],
    /// Low four bits of the quants, two per byte (low nibble = element `j`,
    /// high nibble = element `j + QK5_0 / 2`).
    pub(crate) qs: [u8; QK5_0 / 2],
}
// Compile-time check that the `repr(C)` layout occupies exactly 22 bytes.
const _: () = assert!(std::mem::size_of::<BlockQ5_0>() == 22);
67
/// Block of `QK5_1` values quantized to 5 bits with a scale and an offset.
/// Each value decodes as `q5 * d + m`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5_1 {
    /// Scale shared by the whole block.
    pub(crate) d: f16,
    /// Offset (the block minimum at quantization time) added to every value.
    pub(crate) m: f16,
    /// Fifth (high) bit of every quant, read as a little-endian `u32`:
    /// bit `j` belongs to element `j`, bit `j + 16` to element `j + 16`.
    pub(crate) qh: [u8; 4],
    /// Low four bits of the quants, two per byte.
    pub(crate) qs: [u8; QK5_1 / 2],
}
// Compile-time check that the `repr(C)` layout occupies exactly 24 bytes.
const _: () = assert!(std::mem::size_of::<BlockQ5_1>() == 24);
77
/// Block of `QK8_0` values quantized to signed 8 bits with one shared scale.
/// Each value decodes as `qs[j] * d`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_0 {
    /// Scale shared by the whole block.
    pub(crate) d: f16,
    /// One signed 8-bit quant per element.
    pub(crate) qs: [i8; QK8_0],
}
// Compile-time check that the `repr(C)` layout occupies exactly 34 bytes.
const _: () = assert!(std::mem::size_of::<BlockQ8_0>() == 34);
85
/// Block of `QK8_1` values quantized to signed 8 bits, with a cached
/// `d * sum(qs)` used by the Q4_1/Q5_1 dot products.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8_1 {
    /// Scale shared by the whole block.
    pub(crate) d: f16,
    /// `d` times the sum of all quants in `qs`.
    pub(crate) s: f16,
    /// One signed 8-bit quant per element.
    pub(crate) qs: [i8; QK8_1],
}
// Compile-time check that the `repr(C)` layout occupies exactly 36 bytes.
const _: () = assert!(std::mem::size_of::<BlockQ8_1>() == 36);
94
/// Super-block of `QK_K` values quantized to 2 bits, split into groups of 16
/// values that each carry a 4-bit scale and a 4-bit min.
/// A value decodes as `d * scale * q - dmin * min`.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ2K {
    /// Per-group (16 values) quantized scale in the low nibble and quantized
    /// min in the high nibble.
    pub(crate) scales: [u8; QK_K / 16],
    /// Four 2-bit quants per byte, at bit shifts 0, 2, 4 and 6.
    pub(crate) qs: [u8; QK_K / 4],
    /// Super-block scale applied to the 4-bit group scales.
    pub(crate) d: f16,
    /// Super-block scale applied to the 4-bit group mins.
    pub(crate) dmin: f16,
}
// Compile-time check that the `repr(C)` layout has no padding.
const _: () = assert!(QK_K / 16 + QK_K / 4 + 2 * 2 == std::mem::size_of::<BlockQ2K>());
104
/// Super-block of `QK_K` values quantized to 3 bits, in groups of 16 values
/// that each carry a 6-bit scale (stored with an offset of 32).
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ3K {
    /// Third bit of each quant; when the bit is clear the decoded quant is
    /// reduced by 4.
    pub(crate) hmask: [u8; QK_K / 8],
    /// Low two bits of each quant, four quants per byte.
    pub(crate) qs: [u8; QK_K / 4],
    /// 16 six-bit group scales packed into 12 bytes.
    pub(crate) scales: [u8; 12],
    /// Super-block scale.
    pub(crate) d: f16,
}
// Compile-time check that the `repr(C)` layout has no padding.
const _: () = assert!(QK_K / 8 + QK_K / 4 + 12 + 2 == std::mem::size_of::<BlockQ3K>());
114
/// Super-block of `QK_K` values quantized to 4 bits, in groups of 32 values
/// that each carry a 6-bit scale and a 6-bit min.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ4K {
    /// Super-block scale applied to the group scales.
    pub(crate) d: f16,
    /// Super-block scale applied to the group mins.
    pub(crate) dmin: f16,
    /// Eight 6-bit scales and eight 6-bit mins packed into `K_SCALE_SIZE` bytes.
    pub(crate) scales: [u8; K_SCALE_SIZE],
    /// Two 4-bit quants per byte.
    pub(crate) qs: [u8; QK_K / 2],
}
// Compile-time check that the `repr(C)` layout has no padding.
const _: () = assert!(QK_K / 2 + K_SCALE_SIZE + 2 * 2 == std::mem::size_of::<BlockQ4K>());
125
/// Super-block of `QK_K` values quantized to 5 bits.
// NOTE(review): the code that decodes this layout is not in this chunk; the
// field roles below mirror BlockQ4K/BlockQ5_0 naming and should be confirmed.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ5K {
    /// Super-block scale (presumably applied to the group scales).
    pub(crate) d: f16,
    /// Super-block scale (presumably applied to the group mins).
    pub(crate) dmin: f16,
    /// Packed group scales/mins, `K_SCALE_SIZE` bytes.
    pub(crate) scales: [u8; K_SCALE_SIZE],
    /// Presumably the fifth (high) bit of each quant, one bit per value.
    pub(crate) qh: [u8; QK_K / 8],
    /// Presumably the low four bits of each quant, two per byte.
    pub(crate) qs: [u8; QK_K / 2],
}
// Compile-time check that the `repr(C)` layout has no padding.
const _: () =
    assert!(QK_K / 8 + QK_K / 2 + 2 * 2 + K_SCALE_SIZE == std::mem::size_of::<BlockQ5K>());
137
/// Super-block of `QK_K` values quantized to 6 bits.
// NOTE(review): the code that decodes this layout is not in this chunk; the
// field roles below are inferred from names and sizes and should be confirmed.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ6K {
    /// Presumably the low four bits of each quant, two per byte.
    pub(crate) ql: [u8; QK_K / 2],
    /// Presumably the high two bits of each quant, four per byte.
    pub(crate) qh: [u8; QK_K / 4],
    /// One signed scale per group of 16 values.
    pub(crate) scales: [i8; QK_K / 16],
    /// Super-block scale.
    pub(crate) d: f16,
}
// Compile-time check that the `repr(C)` layout has no padding.
const _: () = assert!(3 * QK_K / 4 + QK_K / 16 + 2 == std::mem::size_of::<BlockQ6K>());
147
/// Super-block of `QK_K` signed 8-bit quants; the right-hand operand of the
/// k-quant dot products.
#[derive(Debug, Clone, PartialEq)]
#[repr(C)]
pub struct BlockQ8K {
    /// Scale of the whole super-block (kept as `f32`, unlike the other blocks).
    pub(crate) d: f32,
    /// One signed 8-bit quant per element.
    pub(crate) qs: [i8; QK_K],
    /// One entry per group of 16 quants, consumed by the Q2K/Q4K dot products
    /// to apply group mins. NOTE(review): presumably the sum of `qs` over each
    /// group; the code producing it is not in this chunk.
    pub(crate) bsums: [i16; QK_K / 16],
}
// Compile-time check that the `repr(C)` layout has no padding.
const _: () = assert!(4 + QK_K + QK_K / 16 * 2 == std::mem::size_of::<BlockQ8K>());
156
157impl GgmlType for BlockQ4_0 {
158 const DTYPE: GgmlDType = GgmlDType::Q4_0;
159 const BLCK_SIZE: usize = QK4_0;
160 type VecDotType = BlockQ8_0;
161
162 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
164 let k = ys.len();
165 let qk = Self::BLCK_SIZE;
166 if k % qk != 0 {
167 crate::bail!("dequantize_row_q4_0: {k} is not divisible by {qk}")
168 }
169
170 let nb = k / qk;
171 for i in 0..nb {
172 let d = xs[i].d.to_f32();
173
174 for j in 0..(qk / 2) {
175 let x0 = (xs[i].qs[j] & 0x0F) as i16 - 8;
176 let x1 = (xs[i].qs[j] >> 4) as i16 - 8;
177
178 ys[i * qk + j] = (x0 as f32) * d;
179 ys[i * qk + j + qk / 2] = (x1 as f32) * d;
180 }
181 }
182 Ok(())
183 }
184
185 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
186 let qk = Self::BLCK_SIZE;
188 let k = xs.len();
189 if k % qk != 0 {
190 crate::bail!("{k} is not divisible by {}", qk);
191 };
192 let nb = k / qk;
193 if ys.len() != nb {
194 crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
195 }
196 for (i, ys) in ys.iter_mut().enumerate() {
197 let mut amax = 0f32;
198 let mut max = 0f32;
199
200 let xs = &xs[i * qk..(i + 1) * qk];
201 for &x in xs.iter() {
202 if amax < x.abs() {
203 amax = x.abs();
204 max = x;
205 }
206 }
207 let d = max / -8.0;
208 let id = if d != 0f32 { 1. / d } else { 0. };
209 ys.d = f16::from_f32(d);
210
211 for (j, q) in ys.qs.iter_mut().enumerate() {
212 let x0 = xs[j] * id;
213 let x1 = xs[qk / 2 + j] * id;
214 let xi0 = u8::min(15, (x0 + 8.5) as u8);
215 let xi1 = u8::min(15, (x1 + 8.5) as u8);
216 *q = xi0 | (xi1 << 4)
217 }
218 }
219 Ok(())
220 }
221
222 #[allow(unreachable_code)]
224 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
225 #[cfg(target_feature = "avx")]
226 return super::avx::vec_dot_q4_0_q8_0(n, xs, ys);
227
228 #[cfg(target_feature = "neon")]
229 return super::neon::vec_dot_q4_0_q8_0(n, xs, ys);
230
231 #[cfg(target_feature = "simd128")]
232 return super::simd128::vec_dot_q4_0_q8_0(n, xs, ys);
233
234 Self::vec_dot_unopt(n, xs, ys)
235 }
236
237 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
238 let qk = QK8_0;
239 if n % QK8_0 != 0 {
240 crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
241 }
242 let mut sumf = 0f32;
244 for (xs, ys) in xs.iter().zip(ys.iter()) {
245 let mut sum_i = 0;
246 for j in 0..qk / 2 {
247 let v0 = (xs.qs[j] & 0x0F) as i32 - 8;
248 let v1 = (xs.qs[j] >> 4) as i32 - 8;
249 sum_i += v0 * ys.qs[j] as i32 + v1 * ys.qs[j + qk / 2] as i32
250 }
251 sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
252 }
253 Ok(sumf)
254 }
255}
256
257impl GgmlType for BlockQ4_1 {
258 const DTYPE: GgmlDType = GgmlDType::Q4_1;
259 const BLCK_SIZE: usize = QK4_1;
260 type VecDotType = BlockQ8_1;
261
262 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
263 Self::vec_dot_unopt(n, xs, ys)
264 }
265
266 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
267 let qk = QK8_1;
269 if n % qk != 0 {
270 crate::bail!("vec_dot_q4_1_q8_1: {n} is not divisible by {qk}")
271 }
272 let nb = n / qk;
273 if nb % 2 != 0 {
274 crate::bail!("vec_dot_q4_1_q8_1: {n}, nb is not divisible by 2")
275 }
276
277 let mut sumf = 0f32;
279
280 for (xs, ys) in xs.iter().zip(ys.iter()) {
281 let mut sumi = 0i32;
282
283 for j in 0..qk / 2 {
284 let v0 = xs.qs[j] as i32 & 0x0F;
285 let v1 = xs.qs[j] as i32 >> 4;
286 sumi += (v0 * ys.qs[j] as i32) + (v1 * ys.qs[j + qk / 2] as i32);
287 }
288
289 sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
290 + f16::to_f32(xs.m) * f16::to_f32(ys.s)
291 }
292 Ok(sumf)
293 }
294
295 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
296 let qk = Self::BLCK_SIZE;
298 if ys.len() * qk != xs.len() {
299 crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
300 }
301 for (i, ys) in ys.iter_mut().enumerate() {
302 let xs = &xs[i * qk..(i + 1) * qk];
303
304 let mut min = f32::INFINITY;
305 let mut max = f32::NEG_INFINITY;
306 for &x in xs.iter() {
307 min = f32::min(x, min);
308 max = f32::max(x, max);
309 }
310 let d = (max - min) / ((1 << 4) - 1) as f32;
311 let id = if d != 0f32 { 1. / d } else { 0. };
312 ys.d = f16::from_f32(d);
313 ys.m = f16::from_f32(min);
314
315 for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() {
316 let x0 = (xs[j] - min) * id;
317 let x1 = (xs[qk / 2 + j] - min) * id;
318
319 let xi0 = u8::min(15, (x0 + 0.5) as u8);
320 let xi1 = u8::min(15, (x1 + 0.5) as u8);
321
322 *q = xi0 | (xi1 << 4);
323 }
324 }
325 Ok(())
326 }
327
328 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
330 let k = ys.len();
331 if k % QK4_1 != 0 {
332 crate::bail!("dequantize_row_q4_1: {k} is not divisible by {QK4_1}");
333 }
334
335 let nb = k / QK4_1;
336 for i in 0..nb {
337 let d = xs[i].d.to_f32();
338 let m = xs[i].m.to_f32();
339
340 for j in 0..(QK4_1 / 2) {
341 let x0 = xs[i].qs[j] & 0x0F;
342 let x1 = xs[i].qs[j] >> 4;
343
344 ys[i * QK4_1 + j] = (x0 as f32) * d + m;
345 ys[i * QK4_1 + j + QK4_1 / 2] = (x1 as f32) * d + m;
346 }
347 }
348 Ok(())
349 }
350}
351
352impl GgmlType for BlockQ5_0 {
353 const DTYPE: GgmlDType = GgmlDType::Q5_0;
354 const BLCK_SIZE: usize = QK5_0;
355 type VecDotType = BlockQ8_0;
356
357 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
358 let qk = Self::BLCK_SIZE;
359 if n % Self::BLCK_SIZE != 0 {
360 crate::bail!("vec_dot_q5_0_q8_0: {n} is not divisible by {qk}")
361 }
362 let nb = n / qk;
363 if nb % 2 != 0 {
364 crate::bail!("vec_dot_q5_0_q8_0: {n}, nb is not divisible by 2")
365 }
366 Self::vec_dot_unopt(n, xs, ys)
367 }
368
369 fn vec_dot_unopt(_n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
370 let mut sumf = 0f32;
372
373 for (xs, ys) in xs.iter().zip(ys.iter()) {
374 let qh = LittleEndian::read_u32(&xs.qh);
375 let mut sumi = 0i32;
376
377 for j in 0..Self::BLCK_SIZE / 2 {
378 let xh_0 = (((qh & (1u32 << j)) >> j) << 4) as u8;
379 let xh_1 = ((qh & (1u32 << (j + 16))) >> (j + 12)) as u8;
380
381 let x0 = ((xs.qs[j] & 0x0F) as i32 | xh_0 as i32) - 16;
382 let x1 = ((xs.qs[j] >> 4) as i32 | xh_1 as i32) - 16;
383
384 sumi += (x0 * ys.qs[j] as i32) + (x1 * ys.qs[j + Self::BLCK_SIZE / 2] as i32);
385 }
386
387 sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
388 }
389 Ok(sumf)
390 }
391
392 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
393 let k = xs.len();
395 if ys.len() * Self::BLCK_SIZE != k {
396 crate::bail!("size mismatch {k} {} {}", ys.len(), Self::BLCK_SIZE)
397 }
398 for (i, ys) in ys.iter_mut().enumerate() {
399 let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
400
401 let mut amax = 0f32;
402 let mut max = 0f32;
403 for &x in xs.iter() {
404 if amax < x.abs() {
405 amax = x.abs();
406 max = x;
407 }
408 }
409 let d = max / -16.;
410 let id = if d != 0f32 { 1. / d } else { 0. };
411 ys.d = f16::from_f32(d);
412 let mut qh = 0u32;
413 for j in 0..Self::BLCK_SIZE / 2 {
414 let x0 = xs[j] * id;
415 let x1 = xs[j + Self::BLCK_SIZE / 2] * id;
416 let xi0 = ((x0 + 16.5) as i8).min(31) as u8;
417 let xi1 = ((x1 + 16.5) as i8).min(31) as u8;
418 ys.qs[j] = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
419 qh |= ((xi0 as u32 & 0x10) >> 4) << j;
420 qh |= ((xi1 as u32 & 0x10) >> 4) << (j + Self::BLCK_SIZE / 2);
421 }
422 LittleEndian::write_u32(&mut ys.qh, qh)
423 }
424 Ok(())
425 }
426
427 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
429 let k = ys.len();
430 if k % QK5_0 != 0 {
431 crate::bail!("dequantize_row_q5_0: {k} is not divisible by {QK5_0}");
432 }
433
434 let nb = k / QK5_0;
435 for i in 0..nb {
436 let d = xs[i].d.to_f32();
437 let qh: u32 = LittleEndian::read_u32(&xs[i].qh);
438
439 for j in 0..(QK5_0 / 2) {
440 let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
441 let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
442
443 let x0 = ((xs[i].qs[j] & 0x0F) | xh_0) as i32 - 16;
444 let x1 = ((xs[i].qs[j] >> 4) | xh_1) as i32 - 16;
445
446 ys[i * QK5_0 + j] = (x0 as f32) * d;
447 ys[i * QK5_0 + j + QK5_0 / 2] = (x1 as f32) * d;
448 }
449 }
450 Ok(())
451 }
452}
453
454impl GgmlType for BlockQ5_1 {
455 const DTYPE: GgmlDType = GgmlDType::Q5_1;
456 const BLCK_SIZE: usize = QK5_1;
457 type VecDotType = BlockQ8_1;
458
459 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
460 Self::vec_dot_unopt(n, xs, ys)
461 }
462
463 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
464 let qk = Self::BLCK_SIZE;
465 if n % Self::BLCK_SIZE != 0 {
466 crate::bail!("vec_dot_q5_1_q8_1: {n} is not divisible by {qk}")
467 }
468 let nb = n / qk;
469 if nb % 2 != 0 {
470 crate::bail!("vec_dot_q5_1_q8_1: {n}, nb is not divisible by 2")
471 }
472
473 let mut sumf = 0f32;
475
476 for (xs, ys) in xs.iter().zip(ys.iter()) {
477 let qh = LittleEndian::read_u32(&xs.qh);
478 let mut sumi = 0i32;
479
480 for j in 0..Self::BLCK_SIZE / 2 {
481 let xh_0 = ((qh >> j) << 4) & 0x10;
482 let xh_1 = (qh >> (j + 12)) & 0x10;
483
484 let x0 = (xs.qs[j] as i32 & 0xF) | xh_0 as i32;
485 let x1 = (xs.qs[j] as i32 >> 4) | xh_1 as i32;
486
487 sumi += (x0 * ys.qs[j] as i32) + (x1 * ys.qs[j + Self::BLCK_SIZE / 2] as i32);
488 }
489
490 sumf += sumi as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
491 + f16::to_f32(xs.m) * f16::to_f32(ys.s)
492 }
493 Ok(sumf)
494 }
495
496 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
497 let qk = Self::BLCK_SIZE;
499 if ys.len() * qk != xs.len() {
500 crate::bail!("size mismatch {} {} {}", xs.len(), ys.len(), qk,)
501 }
502 for (i, ys) in ys.iter_mut().enumerate() {
503 let xs = &xs[i * qk..(i + 1) * qk];
504
505 let mut min = f32::INFINITY;
506 let mut max = f32::NEG_INFINITY;
507 for &x in xs.iter() {
508 min = f32::min(x, min);
509 max = f32::max(x, max);
510 }
511 let d = (max - min) / ((1 << 5) - 1) as f32;
512 let id = if d != 0f32 { 1. / d } else { 0. };
513 ys.d = f16::from_f32(d);
514 ys.m = f16::from_f32(min);
515
516 let mut qh = 0u32;
517 for (j, q) in ys.qs.iter_mut().take(qk / 2).enumerate() {
518 let x0 = (xs[j] - min) * id;
519 let x1 = (xs[qk / 2 + j] - min) * id;
520
521 let xi0 = (x0 + 0.5) as u8;
522 let xi1 = (x1 + 0.5) as u8;
523
524 *q = (xi0 & 0x0F) | ((xi1 & 0x0F) << 4);
525 qh |= ((xi0 as u32 & 0x10) >> 4) << j;
527 qh |= ((xi1 as u32 & 0x10) >> 4) << (j + qk / 2);
528 }
529 LittleEndian::write_u32(&mut ys.qh, qh);
530 }
531 Ok(())
532 }
533
534 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
536 let k = ys.len();
537 if k % QK5_1 != 0 {
538 crate::bail!("dequantize_row_q5_1: {k} is not divisible by {QK5_1}");
539 }
540
541 let nb = k / QK5_1;
542 for i in 0..nb {
543 let d = xs[i].d.to_f32();
544 let m = xs[i].m.to_f32();
545 let qh: u32 = LittleEndian::read_u32(&xs[i].qh);
546
547 for j in 0..(QK5_1 / 2) {
548 let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
549 let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
550
551 let x0 = (xs[i].qs[j] & 0x0F) | xh_0;
552 let x1 = (xs[i].qs[j] >> 4) | xh_1;
553
554 ys[i * QK5_1 + j] = (x0 as f32) * d + m;
555 ys[i * QK5_1 + j + QK5_1 / 2] = (x1 as f32) * d + m;
556 }
557 }
558 Ok(())
559 }
560}
561
562impl GgmlType for BlockQ8_0 {
563 const DTYPE: GgmlDType = GgmlDType::Q8_0;
564 const BLCK_SIZE: usize = QK8_0;
565 type VecDotType = BlockQ8_0;
566
567 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
569 let k = ys.len();
570 if k % QK8_0 != 0 {
571 crate::bail!("dequantize_row_q8_0: {k} is not divisible by {QK8_0}");
572 }
573
574 let nb = k / QK8_0;
575
576 for i in 0..nb {
577 let d = xs[i].d.to_f32();
578
579 for j in 0..QK8_0 {
580 ys[i * QK8_0 + j] = xs[i].qs[j] as f32 * d;
581 }
582 }
583 Ok(())
584 }
585
586 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
587 let k = xs.len();
589 if k % Self::BLCK_SIZE != 0 {
590 crate::bail!("{k} is not divisible by {}", Self::BLCK_SIZE);
591 };
592 let nb = k / Self::BLCK_SIZE;
593 if ys.len() != nb {
594 crate::bail!(
595 "size mismatch {} {} {}",
596 xs.len(),
597 ys.len(),
598 Self::BLCK_SIZE
599 )
600 }
601 for (i, ys) in ys.iter_mut().enumerate() {
602 let mut amax = 0f32;
603 let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
604 for &x in xs.iter() {
605 amax = amax.max(x.abs())
606 }
607 let d = amax / ((1 << 7) - 1) as f32;
608 let id = if d != 0f32 { 1. / d } else { 0. };
609 ys.d = f16::from_f32(d);
610 for (y, &x) in ys.qs.iter_mut().zip(xs.iter()) {
611 *y = f32::round(x * id) as i8
612 }
613 }
614 Ok(())
615 }
616
617 #[allow(unreachable_code)]
618 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
619 #[cfg(target_feature = "avx")]
620 return super::avx::vec_dot_q8_0_q8_0(n, xs, ys);
621
622 #[cfg(target_feature = "neon")]
623 return super::neon::vec_dot_q8_0_q8_0(n, xs, ys);
624
625 #[cfg(target_feature = "simd128")]
626 return super::simd128::vec_dot_q8_0_q8_0(n, xs, ys);
627
628 Self::vec_dot_unopt(n, xs, ys)
629 }
630
631 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
632 let qk = QK8_0;
633 if n % QK8_0 != 0 {
634 crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
635 }
636
637 let mut sumf = 0f32;
639 for (xs, ys) in xs.iter().zip(ys.iter()) {
640 let sum_i = xs
641 .qs
642 .iter()
643 .zip(ys.qs.iter())
644 .map(|(&x, &y)| x as i32 * y as i32)
645 .sum::<i32>();
646 sumf += sum_i as f32 * f16::to_f32(xs.d) * f16::to_f32(ys.d)
647 }
648 Ok(sumf)
649 }
650}
651
652impl GgmlType for BlockQ8_1 {
653 const DTYPE: GgmlDType = GgmlDType::Q8_1;
654 const BLCK_SIZE: usize = QK8_1;
655 type VecDotType = BlockQ8_1;
656
657 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
658 Self::vec_dot_unopt(n, xs, ys)
659 }
660
661 fn vec_dot_unopt(_n: usize, _xs: &[Self], _ys: &[Self::VecDotType]) -> Result<f32> {
662 unimplemented!("no support for vec-dot on Q8_1")
663 }
664
665 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
666 let k = xs.len();
668 if ys.len() * Self::BLCK_SIZE != k {
669 crate::bail!("size mismatch {k} {} {}", ys.len(), Self::BLCK_SIZE)
670 }
671 for (i, ys) in ys.iter_mut().enumerate() {
672 let mut amax = 0f32;
673 let xs = &xs[i * Self::BLCK_SIZE..(i + 1) * Self::BLCK_SIZE];
674 for &x in xs.iter() {
675 amax = amax.max(x.abs())
676 }
677 let d = amax / ((1 << 7) - 1) as f32;
678 let id = if d != 0f32 { 1. / d } else { 0. };
679 ys.d = f16::from_f32(d);
680 let mut sum = 0i32;
681 for j in 0..Self::BLCK_SIZE / 2 {
682 let v0 = xs[j] * id;
683 let v1 = xs[j + Self::BLCK_SIZE / 2] * id;
684 ys.qs[j] = f32::round(v0) as i8;
685 ys.qs[j + Self::BLCK_SIZE / 2] = f32::round(v1) as i8;
686 sum += ys.qs[j] as i32 + ys.qs[j + Self::BLCK_SIZE / 2] as i32;
687 }
688 ys.s = f16::from_f32(sum as f32) * ys.d;
689 }
690 Ok(())
691 }
692
693 fn to_float(_xs: &[Self], _ys: &mut [f32]) -> Result<()> {
694 unimplemented!("no support for vec-dot on Q8_1")
695 }
696}
697
impl GgmlType for BlockQ2K {
    const DTYPE: GgmlDType = GgmlDType::Q2K;
    const BLCK_SIZE: usize = QK_K;
    type VecDotType = BlockQ8K;

    /// Dot product of Q2K blocks with Q8K blocks.
    ///
    /// Uses a SIMD kernel when the `avx`, `neon` or `simd128` target feature is
    /// enabled at compile time, and `vec_dot_unopt` otherwise.
    // The fallback call is unreachable whenever a SIMD `return` is compiled in.
    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        #[cfg(target_feature = "avx")]
        return super::avx::vec_dot_q2k_q8k(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q2k_q8k(n, xs, ys);

        #[cfg(target_feature = "simd128")]
        return super::simd128::vec_dot_q2k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Scalar dot product of Q2K blocks with Q8K blocks.
    ///
    /// Per block pair this accumulates
    /// `x.d * y.d * Σ_g scale_g * Σ(q2 * q8) - x.dmin * y.d * Σ_g min_g * y.bsums[g]`.
    /// Here `g` runs over groups of 16 values, and `scale_g` / `min_g` are the
    /// low / high nibbles of `x.scales[g]`.
    ///
    /// # Errors
    /// Returns an error if `n` is not a multiple of `QK_K`.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q2k_q8k: {n} is not divisible by {QK_K}")
        }

        let mut sumf = 0.0;
        for (x, y) in xs.iter().zip(ys.iter()) {
            let mut q2: &[_] = &x.qs;
            let mut q8: &[_] = &y.qs;
            let sc = &x.scales;

            // Min contribution: each group's bsum times the min nibble of its scale byte.
            let mut summs = 0;
            for (bsum, scale) in y.bsums.iter().zip(sc) {
                summs += *bsum as i32 * ((scale >> 4) as i32);
            }

            let dall = y.d * x.d.to_f32();
            let dmin = y.d * x.dmin.to_f32();

            let mut isum = 0;
            let mut is = 0;
            // Each 32-byte slice of q2 packs four 32-value chunks at bit shifts
            // 0, 2, 4 and 6; each chunk uses two scales (one per 16 values).
            for _ in 0..(QK_K / 128) {
                let mut shift = 0;
                for _ in 0..4 {
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    let mut isuml = 0;
                    for l in 0..16 {
                        isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
                    }
                    isum += d * isuml;
                    let d = (sc[is] & 0xF) as i32;
                    is += 1;
                    isuml = 0;
                    for l in 16..32 {
                        isuml += q8[l] as i32 * (((q2[l] >> shift) & 3) as i32);
                    }
                    isum += d * isuml;
                    shift += 2;
                    // Advance to the next 32 Q8K values.
                    q8 = &q8[32..];
                }
                // Advance to the next 32 bytes of packed 2-bit quants.
                q2 = &q2[32..];
            }
            sumf += dall * isum as f32 - dmin * summs as f32;
        }

        Ok(sumf)
    }

    /// Quantizes `xs` into Q2K blocks.
    ///
    /// For every group of 16 values a (scale, min) pair comes from
    /// `make_qkx1_quants(3, 5, ..)`. NOTE(review): the argument meanings are
    /// defined in `utils`; presumably the max quant and the iteration count.
    /// Scales and mins are quantized to 4 bits relative to their maxima:
    /// `d = max_scale / 15` and `dmin = max_min / 15`. Each value is then
    /// quantized to `0..=3` and packed four per byte.
    ///
    /// # Errors
    /// Propagates the size validation of `group_for_quantization`.
    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
        // Largest value representable by a 4-bit scale/min nibble.
        const Q4SCALE: f32 = 15.0;

        for (block, x) in group_for_quantization(xs, ys)? {
            let mut mins: [f32; QK_K / 16] = [0.0; QK_K / 16];
            let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];

            for (j, x_scale_slice) in x.chunks(16).enumerate() {
                (scales[j], mins[j]) = make_qkx1_quants(3, 5, x_scale_slice);
            }
            // Maxima start at 0.0, so negative entries never become the maximum.
            let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
            let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));

            // Low nibble: quantized scale. Assigning the whole byte also clears the high nibble.
            if max_scale > 0.0 {
                let iscale = Q4SCALE / max_scale;
                for (j, scale) in scales.iter().enumerate().take(QK_K / 16) {
                    block.scales[j] = nearest_int(iscale * scale) as u8;
                }
                block.d = f16::from_f32(max_scale / Q4SCALE);
            } else {
                for j in 0..QK_K / 16 {
                    block.scales[j] = 0;
                }
                block.d = f16::from_f32(0.0);
            }

            // High nibble: quantized min.
            if max_min > 0.0 {
                let iscale = Q4SCALE / max_min;
                for (j, scale) in block.scales.iter_mut().enumerate() {
                    let l = nearest_int(iscale * mins[j]) as u8;
                    *scale |= l << 4;
                }
                block.dmin = f16::from_f32(max_min / Q4SCALE);
            } else {
                block.dmin = f16::from_f32(0.0);
            }

            let mut big_l: [u8; QK_K] = [0; QK_K];

            // Re-quantize every value against its group's *stored* (rounded) scale and min.
            for j in 0..QK_K / 16 {
                let d = block.d.to_f32() * (block.scales[j] & 0xF) as f32;
                if d == 0.0 {
                    continue;
                }
                let dm = block.dmin.to_f32() * (block.scales[j] >> 4) as f32;
                for ii in 0..16 {
                    let ll = nearest_int((x[16 * j + ii] + dm) / d).clamp(0, 3);
                    big_l[16 * j + ii] = ll as u8;
                }
            }

            // Pack: byte `ll` of each 32-byte slice holds values ll, ll+32, ll+64, ll+96
            // of the corresponding 128-value half, at shifts 0, 2, 4, 6.
            for j in (0..QK_K).step_by(128) {
                for ll in 0..32 {
                    block.qs[j / 4 + ll] = big_l[j + ll]
                        | (big_l[j + ll + 32] << 2)
                        | (big_l[j + ll + 64] << 4)
                        | (big_l[j + ll + 96] << 6);
                }
            }
        }
        Ok(())
    }
    /// Dequantizes Q2K blocks into `ys`.
    ///
    /// Each value decodes as `d * (sc & 0xF) * q - dmin * (sc >> 4)`, where `sc` is
    /// the scale byte of its 16-value group and `q` is the 2-bit quant.
    ///
    /// # Errors
    /// Propagates the size validation of `group_for_dequantization`.
    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
        for (block, y) in group_for_dequantization(xs, ys)? {
            let d = block.d.to_f32();
            let min = block.dmin.to_f32();

            let mut is = 0;

            // Each 128 outputs come from 32 packed bytes, read four times at shifts 0, 2, 4, 6.
            for (y_block, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) {
                let mut shift = 0;
                let mut y_block_index = 0;
                for _j in 0..4 {
                    // First 16 bytes use one scale byte...
                    let sc = block.scales[is];
                    is += 1;
                    let dl = d * (sc & 0xF) as f32;
                    let ml = min * (sc >> 4) as f32;
                    for q in &qs[..16] {
                        let y = dl * ((q >> shift) & 3) as f32 - ml;
                        y_block[y_block_index] = y;
                        y_block_index += 1;
                    }

                    // ...and the second 16 bytes use the next one.
                    let sc = block.scales[is];
                    is += 1;
                    let dl = d * (sc & 0xF) as f32;
                    let ml = min * (sc >> 4) as f32;
                    for q in &qs[16..] {
                        let y = dl * ((q >> shift) & 3) as f32 - ml;
                        y_block[y_block_index] = y;
                        y_block_index += 1;
                    }

                    shift += 2;
                }
            }
        }
        Ok(())
    }
}
873
874impl GgmlType for BlockQ3K {
875 const DTYPE: GgmlDType = GgmlDType::Q3K;
876 const BLCK_SIZE: usize = QK_K;
877 type VecDotType = BlockQ8K;
878
879 #[allow(unreachable_code)]
880 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
881 #[cfg(target_feature = "avx")]
882 return super::avx::vec_dot_q3k_q8k(n, xs, ys);
883
884 #[cfg(target_feature = "neon")]
885 return super::neon::vec_dot_q3k_q8k(n, xs, ys);
886
887 Self::vec_dot_unopt(n, xs, ys)
888 }
889
890 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
891 if n % QK_K != 0 {
892 crate::bail!("vec_dot_q3k_q8k: {n} is not divisible by {QK_K}")
893 }
894
895 const KMASK1: u32 = 0x03030303;
896 const KMASK2: u32 = 0x0f0f0f0f;
897
898 let mut aux8: [i8; QK_K] = [0; QK_K];
899 let mut aux16: [i16; 8] = [0; 8];
900 let mut sums: [f32; 8] = [0.0; 8];
901 let mut aux32: [i32; 8] = [0; 8];
902
903 let mut auxs: [u32; 4] = [0; 4];
904
905 for (x, y) in xs.iter().zip(ys.iter()) {
906 let mut q3: &[u8] = &x.qs;
907 let hmask: &[u8] = &x.hmask;
908 let mut q8: &[i8] = &y.qs;
909
910 aux32.fill(0);
911 let mut a = &mut aux8[..];
912
913 let mut m = 1;
914 for _ in 0..QK_K / 128 {
916 a.iter_mut()
917 .take(32)
918 .zip(q3)
919 .for_each(|(a_val, q3_val)| *a_val = (q3_val & 3) as i8);
920 a.iter_mut()
921 .take(32)
922 .zip(hmask)
923 .for_each(|(a_val, hmask_val)| {
924 *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
925 });
926 a = &mut a[32..];
927 m <<= 1;
928
929 a.iter_mut()
930 .take(32)
931 .zip(q3)
932 .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 2) & 3) as i8);
933 a.iter_mut()
934 .take(32)
935 .zip(hmask)
936 .for_each(|(a_val, hmask_val)| {
937 *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
938 });
939 a = &mut a[32..];
940 m <<= 1;
941
942 a.iter_mut()
943 .take(32)
944 .zip(q3)
945 .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 4) & 3) as i8);
946 a.iter_mut()
947 .take(32)
948 .zip(hmask)
949 .for_each(|(a_val, hmask_val)| {
950 *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
951 });
952 a = &mut a[32..];
953 m <<= 1;
954
955 a.iter_mut()
956 .take(32)
957 .zip(q3)
958 .for_each(|(a_val, q3_val)| *a_val = ((q3_val >> 6) & 3) as i8);
959 a.iter_mut()
960 .take(32)
961 .zip(hmask)
962 .for_each(|(a_val, hmask_val)| {
963 *a_val -= if hmask_val & m != 0 { 0 } else { 4 }
964 });
965 a = &mut a[32..];
966 m <<= 1;
967 q3 = &q3[32..];
968 }
969
970 a = &mut aux8[..];
971
972 LittleEndian::read_u32_into(&x.scales, &mut auxs[0..3]);
973
974 let tmp = auxs[2];
975 auxs[2] = ((auxs[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4);
976 auxs[3] = ((auxs[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4);
977 auxs[0] = (auxs[0] & KMASK2) | (((tmp) & KMASK1) << 4);
978 auxs[1] = (auxs[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4);
979
980 for aux in auxs {
981 for scale in aux.to_le_bytes() {
982 let scale = i8::from_be_bytes([scale]);
983 for l in 0..8 {
984 aux16[l] = q8[l] as i16 * a[l] as i16;
985 }
986 for l in 0..8 {
987 aux32[l] += (scale as i32 - 32) * aux16[l] as i32;
988 }
989 q8 = &q8[8..];
990 a = &mut a[8..];
991
992 for l in 0..8 {
993 aux16[l] = q8[l] as i16 * a[l] as i16;
994 }
995 for l in 0..8 {
996 aux32[l] += (scale as i32 - 32) * aux16[l] as i32;
997 }
998 q8 = &q8[8..];
999 a = &mut a[8..];
1000 }
1001 }
1002 let d = x.d.to_f32() * y.d;
1003 for l in 0..8 {
1004 sums[l] += d * aux32[l] as f32;
1005 }
1006 }
1007
1008 Ok(sums.iter().sum())
1009 }
1010
1011 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
1012 for (block, x) in group_for_quantization(xs, ys)? {
1013 let mut scales: [f32; QK_K / 16] = [0.0; QK_K / 16];
1014 for (j, x_scale_slice) in x.chunks_exact(16).enumerate() {
1015 scales[j] = make_q3_quants(x_scale_slice, 4, true);
1016 }
1017
1018 let mut max_scale: f32 = 0.0;
1020 for &scale in scales.iter() {
1021 if scale.abs() > max_scale.abs() {
1022 max_scale = scale;
1023 }
1024 }
1025
1026 block.scales.fill(0);
1027
1028 if max_scale != 0.0 {
1029 let iscale = -32.0 / max_scale;
1030 for (j, scale) in scales.iter().enumerate() {
1031 let l_val = nearest_int(iscale * scale);
1032 let l_val = l_val.clamp(-32, 31) + 32;
1033 if j < 8 {
1034 block.scales[j] = (l_val & 0xF) as u8;
1035 } else {
1036 block.scales[j - 8] |= ((l_val & 0xF) << 4) as u8;
1037 }
1038 let l_val = l_val >> 4;
1039 block.scales[j % 4 + 8] |= (l_val << (2 * (j / 4))) as u8;
1040 }
1041 block.d = f16::from_f32(1.0 / iscale);
1042 } else {
1043 block.d = f16::from_f32(0.0);
1044 }
1045
1046 let mut l: [i8; QK_K] = [0; QK_K];
1047
1048 for j in 0..QK_K / 16 {
1049 let sc = if j < 8 {
1050 block.scales[j] & 0xF
1051 } else {
1052 block.scales[j - 8] >> 4
1053 };
1054 let sc = (sc | (((block.scales[8 + j % 4] >> (2 * (j / 4))) & 3) << 4)) as i8 - 32;
1055 let d = block.d.to_f32() * sc as f32;
1056 if d != 0.0 {
1057 for ii in 0..16 {
1058 let l_val = nearest_int(x[16 * j + ii] / d);
1059 l[16 * j + ii] = (l_val.clamp(-4, 3) + 4) as i8;
1060 }
1061 }
1062 }
1063
1064 block.hmask.fill(0);
1065 let mut m = 0;
1066 let mut hm = 1;
1067
1068 for ll in l.iter_mut() {
1069 if *ll > 3 {
1070 block.hmask[m] |= hm;
1071 *ll -= 4;
1072 }
1073 m += 1;
1074 if m == QK_K / 8 {
1075 m = 0;
1076 hm <<= 1;
1077 }
1078 }
1079
1080 for j in (0..QK_K).step_by(128) {
1081 for l_val in 0..32 {
1082 block.qs[j / 4 + l_val] = (l[j + l_val]
1083 | (l[j + l_val + 32] << 2)
1084 | (l[j + l_val + 64] << 4)
1085 | (l[j + l_val + 96] << 6))
1086 as u8;
1087 }
1088 }
1089 }
1090
1091 Ok(())
1092 }
1093
1094 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
1096 const KMASK1: u32 = 0x03030303;
1097 const KMASK2: u32 = 0x0f0f0f0f;
1098
1099 for (block, y) in group_for_dequantization(xs, ys)? {
1100 let mut aux = [0; 4];
1102 LittleEndian::read_u32_into(&block.scales, &mut aux[0..3]);
1103
1104 let tmp = aux[2];
1105 aux[2] = ((aux[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4);
1106 aux[3] = ((aux[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4);
1107 aux[0] = (aux[0] & KMASK2) | (((tmp) & KMASK1) << 4);
1108 aux[1] = (aux[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4);
1109
1110 let scales: &mut [i8] =
1112 unsafe { std::slice::from_raw_parts_mut(aux.as_mut_ptr() as *mut i8, 16) };
1113
1114 let d_all = block.d.to_f32();
1115 let mut m = 1;
1116 let mut is = 0;
1117
1118 for (y, qs) in y.chunks_exact_mut(128).zip(block.qs.chunks_exact(32)) {
1122 let mut shift = 0;
1123 for shift_scoped_y in y.chunks_exact_mut(32) {
1124 for (scale_index, scale_scoped_y) in
1125 shift_scoped_y.chunks_exact_mut(16).enumerate()
1126 {
1127 let dl = d_all * (scales[is] as f32 - 32.0);
1128 for (i, inner_y) in scale_scoped_y.iter_mut().enumerate() {
1129 let new_y = dl
1130 * (((qs[i + 16 * scale_index] >> shift) & 3) as i8
1131 - if (block.hmask[i + 16 * scale_index] & m) == 0 {
1132 4
1133 } else {
1134 0
1135 }) as f32;
1136 *inner_y = new_y;
1137 }
1138 is += 1;
1140 }
1141 shift += 2;
1143 m <<= 1;
1144 }
1145 }
1146 }
1147
1148 Ok(())
1149 }
1150}
1151
1152impl GgmlType for BlockQ4K {
1153 const DTYPE: GgmlDType = GgmlDType::Q4K;
1154 const BLCK_SIZE: usize = QK_K;
1155 type VecDotType = BlockQ8K;
1156
1157 #[allow(unreachable_code)]
1158 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1159 #[cfg(target_feature = "avx")]
1160 return super::avx::vec_dot_q4k_q8k(n, xs, ys);
1161
1162 #[cfg(target_feature = "neon")]
1163 return super::neon::vec_dot_q4k_q8k(n, xs, ys);
1164
1165 #[cfg(target_feature = "simd128")]
1166 return super::simd128::vec_dot_q4k_q8k(n, xs, ys);
1167
1168 Self::vec_dot_unopt(n, xs, ys)
1169 }
1170
    /// Scalar dot product of Q4K blocks with Q8K blocks.
    ///
    /// The computation runs in three steps:
    /// 1. Each block's 4-bit quants are unpacked into `aux8`.
    /// 2. These are multiplied with the Q8K quants in groups of 32 values, each
    ///    weighted by its 6-bit scale and then by `x.d * y.d`.
    /// 3. `x.dmin * y.d * Σ bsums[j] * mins[j / 2]` is subtracted; each 6-bit min
    ///    covers two 16-value `bsums` entries.
    ///
    /// # Errors
    /// Returns an error if `n` is not a multiple of `QK_K`.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q4k_q8k: {n} is not divisible by {QK_K}")
        }

        // Masks for unpacking the 12-byte scale table: 6 bits, low nibble, top 2 bits.
        const KMASK1: u32 = 0x3f3f3f3f;
        const KMASK2: u32 = 0x0f0f0f0f;
        const KMASK3: u32 = 0x03030303;

        let mut utmp: [u32; 4] = [0; 4];
        let mut scales: [u8; 8] = [0; 8];
        let mut mins: [u8; 8] = [0; 8];

        let mut aux8: [i8; QK_K] = [0; QK_K];
        let mut aux16: [i16; 8] = [0; 8];
        let mut sums: [f32; 8] = [0.0; 8];
        let mut aux32: [i32; 8] = [0; 8];

        let mut sumf = 0.0;
        for (y, x) in ys.iter().zip(xs.iter()) {
            let q4 = &x.qs;
            let q8 = &y.qs;
            aux32.fill(0);

            // Unpack: each 32 bytes of q4 yield 32 low-nibble values followed by
            // 32 high-nibble values.
            let mut a = &mut aux8[..];
            let mut q4 = &q4[..];
            for _ in 0..QK_K / 64 {
                for l in 0..32 {
                    a[l] = (q4[l] & 0xF) as i8;
                }
                a = &mut a[32..];
                for l in 0..32 {
                    a[l] = (q4[l] >> 4) as i8;
                }
                a = &mut a[32..];
                q4 = &q4[32..];
            }

            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            // Rearrange so utmp[0..2] holds the 8 six-bit scales and utmp[2..4]
            // holds the 8 six-bit mins, one per byte.
            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);

            // Min correction: two 16-value bsums share each 32-value min.
            let mut sumi = 0;
            for j in 0..QK_K / 16 {
                sumi += y.bsums[j] as i32 * mins[j / 2] as i32;
            }

            let mut a = &mut aux8[..];
            let mut q8 = &q8[..];

            // Each scale weights 4 * 8 = 32 consecutive products.
            for scale in scales {
                let scale = scale as i32;
                for _ in 0..4 {
                    for l in 0..8 {
                        aux16[l] = q8[l] as i16 * a[l] as i16;
                    }
                    for l in 0..8 {
                        aux32[l] += scale * aux16[l] as i32;
                    }
                    q8 = &q8[8..];
                    a = &mut a[8..];
                }
            }
            let d = x.d.to_f32() * y.d;
            for l in 0..8 {
                sums[l] += d * aux32[l] as f32;
            }
            let dmin = x.dmin.to_f32() * y.d;
            sumf -= dmin * sumi as f32;
        }
        Ok(sumf + sums.iter().sum::<f32>())
    }
1251
1252 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
1253 for (block, x) in group_for_quantization(xs, ys)? {
1254 let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
1255 let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];
1256
1257 for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
1258 (scales[j], mins[j]) = make_qkx1_quants(15, 5, x_scale_slice);
1259 }
1260
1261 let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
1263 let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));
1264
1265 let inv_scale = if max_scale > 0.0 {
1266 63.0 / max_scale
1267 } else {
1268 0.0
1269 };
1270 let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
1271
1272 for j in 0..QK_K / 32 {
1273 let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
1274 let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
1275 if j < 4 {
1276 block.scales[j] = ls;
1277 block.scales[j + 4] = lm;
1278 } else {
1279 block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4);
1280 block.scales[j - 4] |= (ls >> 4) << 6;
1281 block.scales[j] |= (lm >> 4) << 6;
1282 }
1283 }
1284
1285 block.d = f16::from_f32(max_scale / 63.0);
1286 block.dmin = f16::from_f32(max_min / 63.0);
1287
1288 let mut l: [u8; QK_K] = [0; QK_K];
1289
1290 for j in 0..QK_K / 32 {
1291 let (sc, m) = get_scale_min_k4(j, &block.scales);
1292 let d = block.d.to_f32() * sc as f32;
1293 if d != 0.0 {
1294 let dm = block.dmin.to_f32() * m as f32;
1295 for ii in 0..32 {
1296 let l_val = nearest_int((x[32 * j + ii] + dm) / d);
1297 l[32 * j + ii] = l_val.clamp(0, 15) as u8;
1298 }
1299 }
1300 }
1301
1302 let q = &mut block.qs;
1303 for j in (0..QK_K).step_by(64) {
1304 for l_val in 0..32 {
1305 let offset_index = (j / 64) * 32 + l_val;
1306 q[offset_index] = l[j + l_val] | (l[j + l_val + 32] << 4);
1307 }
1308 }
1309 }
1310 Ok(())
1311 }
1312 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
1314 for (block, y) in group_for_dequantization(xs, ys)? {
1315 let d = block.d.to_f32();
1316 let min = block.dmin.to_f32();
1317 let q = &block.qs;
1318 let mut is = 0;
1319 let mut ys_index = 0;
1320
1321 for j in (0..QK_K).step_by(64) {
1322 let q = &q[j / 2..j / 2 + 32];
1323 let (sc, m) = get_scale_min_k4(is, &block.scales);
1324 let d1 = d * sc as f32;
1325 let m1 = min * m as f32;
1326 let (sc, m) = get_scale_min_k4(is + 1, &block.scales);
1327 let d2 = d * sc as f32;
1328 let m2 = min * m as f32;
1329 for q in q {
1330 y[ys_index] = d1 * (q & 0xF) as f32 - m1;
1331 ys_index += 1;
1332 }
1333 for q in q {
1334 y[ys_index] = d2 * (q >> 4) as f32 - m2;
1335 ys_index += 1;
1336 }
1337 is += 2;
1338 }
1339 }
1340 Ok(())
1341 }
1342}
1343
/// Q5_K: 5-bit k-quantization, dotted against `BlockQ8K`.
///
/// The layout mirrors Q4_K: 8 sub-blocks of 32 values, 6-bit scales/mins packed in
/// `scales`, and f16 super-scales `d` and `dmin`. Each quant's low 4 bits live in
/// `qs` and its 5th bit in `qh`. A value is reconstructed as
/// `d * scale * q - dmin * min`, with `q` in `0..=31`.
impl GgmlType for BlockQ5K {
    const DTYPE: GgmlDType = GgmlDType::Q5K;
    const BLCK_SIZE: usize = QK_K;
    type VecDotType = BlockQ8K;

    /// Q5_K x Q8_K dot product; `n` is the number of values.
    ///
    /// Uses the `avx` or `neon` kernel when that target feature is enabled at
    /// compile time (there is no `simd128` kernel here). Otherwise it falls back to
    /// [`Self::vec_dot_unopt`].
    // An enabled `cfg` branch returns early. That makes the scalar fallback
    // unreachable on that target, hence the `allow`.
    #[allow(unreachable_code)]
    fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        #[cfg(target_feature = "avx")]
        return super::avx::vec_dot_q5k_q8k(n, xs, ys);

        #[cfg(target_feature = "neon")]
        return super::neon::vec_dot_q5k_q8k(n, xs, ys);

        Self::vec_dot_unopt(n, xs, ys)
    }

    /// Portable scalar Q5_K x Q8_K dot product.
    ///
    /// Fails if `n` is not a multiple of `QK_K`. Blocks are paired with `zip`, so
    /// only `min(xs.len(), ys.len())` blocks contribute.
    fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
        if n % QK_K != 0 {
            crate::bail!("vec_dot_q5k_q8k: {n} is not divisible by {QK_K}")
        }

        // Masks that split the 12 packed scale bytes into 8 six-bit scales and
        // 8 six-bit mins.
        const KMASK1: u32 = 0x3f3f3f3f;
        const KMASK2: u32 = 0x0f0f0f0f;
        const KMASK3: u32 = 0x03030303;

        let mut utmp: [u32; 4] = [0; 4];
        let mut scales: [u8; 8] = [0; 8];
        let mut mins: [u8; 8] = [0; 8];

        // `aux8`: one block's unpacked 5-bit quants. `aux32`: per-block lane
        // accumulators. `sums`: lanes accumulated across blocks after applying `d`.
        let mut aux8: [i8; QK_K] = [0; QK_K];
        let mut aux16: [i16; 8] = [0; 8];
        let mut sums: [f32; 8] = [0.0; 8];
        let mut aux32: [i32; 8] = [0; 8];

        // Accumulates the (subtracted) min correction across blocks.
        let mut sumf = 0.0;
        for (y, x) in ys.iter().zip(xs.iter()) {
            let q5 = &x.qs;
            let hm = &x.qh;
            let q8 = &y.qs;
            aux32.fill(0);

            let mut a = &mut aux8[..];
            let mut q5 = &q5[..];
            // Selects the `qh` bit carrying the 5th bit for the current 32 values.
            // It advances one bit per 32 values.
            let mut m = 1u8;

            // Unpack: low nibbles, then high nibbles, of each 32-byte chunk of `qs`.
            // Add 16 when the matching `qh` bit is set.
            for _ in 0..QK_K / 64 {
                for l in 0..32 {
                    a[l] = (q5[l] & 0xF) as i8;
                    a[l] += if hm[l] & m != 0 { 16 } else { 0 };
                }
                a = &mut a[32..];
                m <<= 1;
                for l in 0..32 {
                    a[l] = (q5[l] >> 4) as i8;
                    a[l] += if hm[l] & m != 0 { 16 } else { 0 };
                }
                a = &mut a[32..];
                m <<= 1;
                q5 = &q5[32..];
            }

            // Decode the packed scales/mins (same shuffle as Q4_K). `uaux` must
            // capture `utmp[1]` before it is overwritten. Afterwards `utmp[0..2]` holds
            // the 8 scales and `utmp[2..4]` the 8 mins.
            LittleEndian::read_u32_into(&x.scales, &mut utmp[0..3]);

            utmp[3] = ((utmp[2] >> 4) & KMASK2) | (((utmp[1] >> 6) & KMASK3) << 4);
            let uaux = utmp[1] & KMASK1;
            utmp[1] = (utmp[2] & KMASK2) | (((utmp[0] >> 6) & KMASK3) << 4);
            utmp[2] = uaux;
            utmp[0] &= KMASK1;

            LittleEndian::write_u32_into(&utmp[0..2], &mut scales);
            LittleEndian::write_u32_into(&utmp[2..4], &mut mins);

            // Min correction. Each `bsums[j]` covers 16 q8 values, so two consecutive
            // entries share the min of one 32-value sub-block.
            let mut sumi = 0;
            for j in 0..QK_K / 16 {
                sumi += y.bsums[j] as i32 * mins[j / 2] as i32;
            }

            // Scaled products. Each scale covers 32 values, processed as 4 groups of 8.
            let mut a = &mut aux8[..];
            let mut q8 = &q8[..];

            for scale in scales {
                let scale = scale as i32;
                for _ in 0..4 {
                    for l in 0..8 {
                        aux16[l] = q8[l] as i16 * a[l] as i16;
                    }
                    for l in 0..8 {
                        aux32[l] += scale * aux16[l] as i32;
                    }
                    q8 = &q8[8..];
                    a = &mut a[8..];
                }
            }
            let d = x.d.to_f32() * y.d;
            for l in 0..8 {
                sums[l] += d * aux32[l] as f32;
            }
            let dmin = x.dmin.to_f32() * y.d;
            sumf -= dmin * sumi as f32;
        }
        Ok(sumf + sums.iter().sum::<f32>())
    }

    /// Quantizes `xs` into Q5_K blocks.
    ///
    /// Splitting `xs` into per-block chunks, and any size validation, is delegated
    /// to `group_for_quantization`.
    fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
        for (block, x) in group_for_quantization(xs, ys)? {
            let mut mins: [f32; QK_K / 32] = [0.0; QK_K / 32];
            let mut scales: [f32; QK_K / 32] = [0.0; QK_K / 32];

            // One (scale, min) pair per 32-value sub-block. 31 is the largest 5-bit
            // quant level.
            for (j, x_scale_slice) in x.chunks_exact(32).enumerate() {
                (scales[j], mins[j]) = make_qkx1_quants(31, 5, x_scale_slice);
            }

            // Folding from 0.0 keeps both maxima non-negative.
            let max_scale = scales.iter().fold(0.0, |max, &val| val.max(max));
            let max_min = mins.iter().fold(0.0, |max, &val| val.max(max));

            let inv_scale = if max_scale > 0.0 {
                63.0 / max_scale
            } else {
                0.0
            };
            let inv_min = if max_min > 0.0 { 63.0 / max_min } else { 0.0 };
            // Pack 6-bit scales/mins into the 12 `scales` bytes, using the same layout
            // as Q4_K. The `j < 4` assignments happen before the `|=` updates that
            // target the same bytes.
            for j in 0..QK_K / 32 {
                let ls = nearest_int(inv_scale * scales[j]).min(63) as u8;
                let lm = nearest_int(inv_min * mins[j]).min(63) as u8;
                if j < 4 {
                    block.scales[j] = ls;
                    block.scales[j + 4] = lm;
                } else {
                    block.scales[j + 4] = (ls & 0xF) | ((lm & 0xF) << 4);
                    block.scales[j - 4] |= (ls >> 4) << 6;
                    block.scales[j] |= (lm >> 4) << 6;
                }
            }
            block.d = f16::from_f32(max_scale / 63.0);
            block.dmin = f16::from_f32(max_min / 63.0);

            // Quantize each value to 0..=31 using the stored scale/min, read back via
            // `get_scale_min_k4`. Sub-blocks with an effective scale of 0 keep quant 0.
            let mut l: [u8; QK_K] = [0; QK_K];
            for j in 0..QK_K / 32 {
                let (sc, m) = get_scale_min_k4(j, &block.scales);
                let d = block.d.to_f32() * sc as f32;
                if d == 0.0 {
                    continue;
                }
                let dm = block.dmin.to_f32() * m as f32;
                for ii in 0..32 {
                    let ll = nearest_int((x[32 * j + ii] + dm) / d);
                    l[32 * j + ii] = ll.clamp(0, 31) as u8;
                }
            }

            let qh = &mut block.qh;
            let ql = &mut block.qs;
            qh.fill(0);

            // For each 64-value chunk `c`, the low 4 bits of values `i` and `i + 32` share
            // byte `c * 32 + i` of `qs`. Their 5th bits go to bits `2c` (mask `m1`) and
            // `2c + 1` (mask `m2`) of `qh[i]`.
            let mut m1 = 1;
            let mut m2 = 2;
            for n in (0..QK_K).step_by(64) {
                let offset = (n / 64) * 32;
                for j in 0..32 {
                    let mut l1 = l[n + j];
                    if l1 > 15 {
                        l1 -= 16;
                        qh[j] |= m1;
                    }
                    let mut l2 = l[n + j + 32];
                    if l2 > 15 {
                        l2 -= 16;
                        qh[j] |= m2;
                    }
                    ql[offset + j] = l1 | (l2 << 4);
                }
                m1 <<= 2;
                m2 <<= 2;
            }
        }

        Ok(())
    }

    /// Dequantizes Q5_K blocks into `ys` as `d * scale * (q + 16 * hbit) - dmin * min`.
    ///
    /// Pairing blocks with output chunks, and any size validation, is delegated to
    /// `group_for_dequantization`.
    fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
        for (block, y) in group_for_dequantization(xs, ys)? {
            let d = block.d.to_f32();
            let min = block.dmin.to_f32();
            let ql = &block.qs;
            let qh = &block.qh;
            let mut is = 0;
            // `qh` masks for the low-nibble (u1) and high-nibble (u2) halves of each
            // 64-value chunk. Both shift by 2 per chunk.
            let mut u1 = 1;
            let mut u2 = 2;
            let mut ys_index = 0;

            for j in (0..QK_K).step_by(64) {
                let ql = &ql[j / 2..j / 2 + 32];
                let (sc, m) = get_scale_min_k4(is, &block.scales);
                let d1 = d * sc as f32;
                let m1 = min * m as f32;
                let (sc, m) = get_scale_min_k4(is + 1, &block.scales);
                let d2 = d * sc as f32;
                let m2 = min * m as f32;
                for (ql, qh) in ql.iter().zip(qh) {
                    let to_add = if qh & u1 != 0 { 16f32 } else { 0f32 };
                    y[ys_index] = d1 * ((ql & 0xF) as f32 + to_add) - m1;
                    ys_index += 1;
                }
                for (ql, qh) in ql.iter().zip(qh) {
                    let to_add = if qh & u2 != 0 { 16f32 } else { 0f32 };
                    y[ys_index] = d2 * ((ql >> 4) as f32 + to_add) - m2;
                    ys_index += 1;
                }
                is += 2;
                u1 <<= 2;
                u2 <<= 2;
            }
        }
        Ok(())
    }
}
1565
1566impl GgmlType for BlockQ6K {
1567 const DTYPE: GgmlDType = GgmlDType::Q6K;
1568 const BLCK_SIZE: usize = QK_K;
1569 type VecDotType = BlockQ8K;
1570
1571 #[allow(unreachable_code)]
1572 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1573 #[cfg(target_feature = "avx")]
1574 return super::avx::vec_dot_q6k_q8k(n, xs, ys);
1575
1576 #[cfg(target_feature = "neon")]
1577 return super::neon::vec_dot_q6k_q8k(n, xs, ys);
1578
1579 #[cfg(target_feature = "simd128")]
1580 return super::simd128::vec_dot_q6k_q8k(n, xs, ys);
1581
1582 Self::vec_dot_unopt(n, xs, ys)
1583 }
1584
1585 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1586 if n % QK_K != 0 {
1587 crate::bail!("vec_dot_q6k_q8k: {n} is not divisible by {QK_K}")
1588 }
1589
1590 let mut aux8 = [0i8; QK_K];
1591 let mut aux16 = [0i16; 8];
1592 let mut sums = [0f32; 8];
1593 let mut aux32 = [0f32; 8];
1594
1595 for (x, y) in xs.iter().zip(ys.iter()) {
1596 let q4 = &x.ql;
1597 let qh = &x.qh;
1598 let q8 = &y.qs;
1599 aux32.fill(0f32);
1600
1601 for j in (0..QK_K).step_by(128) {
1602 let aux8 = &mut aux8[j..];
1603 let q4 = &q4[j / 2..];
1604 let qh = &qh[j / 4..];
1605 for l in 0..32 {
1606 aux8[l] = (((q4[l] & 0xF) | ((qh[l] & 3) << 4)) as i32 - 32) as i8;
1607 aux8[l + 32] =
1608 (((q4[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i32 - 32) as i8;
1609 aux8[l + 64] = (((q4[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i32 - 32) as i8;
1610 aux8[l + 96] =
1611 (((q4[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i32 - 32) as i8;
1612 }
1613 }
1614
1615 for (j, &scale) in x.scales.iter().enumerate() {
1616 let scale = scale as f32;
1617 let q8 = &q8[16 * j..];
1618 let aux8 = &aux8[16 * j..];
1619 for l in 0..8 {
1620 aux16[l] = q8[l] as i16 * aux8[l] as i16;
1621 }
1622 for l in 0..8 {
1623 aux32[l] += scale * aux16[l] as f32
1624 }
1625 let q8 = &q8[8..];
1626 let aux8 = &aux8[8..];
1627 for l in 0..8 {
1628 aux16[l] = q8[l] as i16 * aux8[l] as i16;
1629 }
1630 for l in 0..8 {
1631 aux32[l] += scale * aux16[l] as f32
1632 }
1633 }
1634
1635 let d = x.d.to_f32() * y.d;
1636 for (sum, &a) in sums.iter_mut().zip(aux32.iter()) {
1637 *sum += a * d;
1638 }
1639 }
1640 Ok(sums.iter().sum())
1641 }
1642
1643 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
1644 if xs.len() != ys.len() * Self::BLCK_SIZE {
1645 crate::bail!(
1646 "quantize_row_q6k: size mismatch {} {} {}",
1647 xs.len(),
1648 ys.len(),
1649 Self::BLCK_SIZE
1650 )
1651 }
1652 let mut l = [0i8; QK_K];
1653 let mut scales = [0f32; QK_K / 16];
1654 let mut x = xs.as_ptr();
1655 let l = l.as_mut_ptr();
1656 unsafe {
1657 for y in ys.iter_mut() {
1658 let mut max_scale = 0f32;
1659 let mut max_abs_scale = 0f32;
1660 for (ib, scale_) in scales.iter_mut().enumerate() {
1661 let scale = make_qx_quants(16, 32, x.add(16 * ib), l.add(16 * ib), 1);
1662 *scale_ = scale;
1663 let abs_scale = scale.abs();
1664 if abs_scale > max_abs_scale {
1665 max_abs_scale = abs_scale;
1666 max_scale = scale
1667 }
1668 }
1669
1670 let iscale = -128f32 / max_scale;
1671 y.d = f16::from_f32(1.0 / iscale);
1672
1673 for (y_scale, scale) in y.scales.iter_mut().zip(scales.iter()) {
1674 *y_scale = nearest_int(iscale * scale).min(127) as i8
1675 }
1676
1677 for (j, &y_scale) in y.scales.iter().enumerate() {
1678 let d = y.d.to_f32() * y_scale as f32;
1679 if d == 0. {
1680 continue;
1681 }
1682 for ii in 0..16 {
1683 let ll = nearest_int(*x.add(16 * j + ii) / d).clamp(-32, 31);
1684 *l.add(16 * j + ii) = (ll + 32) as i8
1685 }
1686 }
1687
1688 let mut ql = y.ql.as_mut_ptr();
1689 let mut qh = y.qh.as_mut_ptr();
1690
1691 for j in (0..QK_K).step_by(128) {
1692 for l_idx in 0..32 {
1693 let q1 = *l.add(j + l_idx) & 0xF;
1694 let q2 = *l.add(j + l_idx + 32) & 0xF;
1695 let q3 = *l.add(j + l_idx + 64) & 0xF;
1696 let q4 = *l.add(j + l_idx + 96) & 0xF;
1697 *ql.add(l_idx) = (q1 | (q3 << 4)) as u8;
1698 *ql.add(l_idx + 32) = (q2 | (q4 << 4)) as u8;
1699 *qh.add(l_idx) = ((*l.add(j + l_idx) >> 4)
1700 | ((*l.add(j + l_idx + 32) >> 4) << 2)
1701 | ((*l.add(j + l_idx + 64) >> 4) << 4)
1702 | ((*l.add(j + l_idx + 96) >> 4) << 6))
1703 as u8;
1704 }
1705 ql = ql.add(64);
1706 qh = qh.add(32);
1707 }
1708
1709 x = x.add(QK_K)
1710 }
1711 }
1712 Ok(())
1713 }
1714
1715 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
1717 let k = ys.len();
1718 if k % QK_K != 0 {
1719 crate::bail!("dequantize_row_q6k: {k} is not divisible by {QK_K}")
1720 }
1721 for (idx_x, x) in xs.iter().enumerate() {
1722 let d = x.d.to_f32();
1723 let ql = &x.ql;
1724 let qh = &x.qh;
1725 let sc = &x.scales;
1726 for n in (0..QK_K).step_by(128) {
1727 let idx = n / 128;
1728 let ys = &mut ys[idx_x * QK_K + n..];
1729 let sc = &sc[8 * idx..];
1730 let ql = &ql[64 * idx..];
1731 let qh = &qh[32 * idx..];
1732 for l in 0..32 {
1733 let is = l / 16;
1734 let q1 = ((ql[l] & 0xF) | ((qh[l] & 3) << 4)) as i8 - 32;
1735 let q2 = ((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) as i8 - 32;
1736 let q3 = ((ql[l] >> 4) | (((qh[l] >> 4) & 3) << 4)) as i8 - 32;
1737 let q4 = ((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) as i8 - 32;
1738 ys[l] = d * sc[is] as f32 * q1 as f32;
1739 ys[l + 32] = d * sc[is + 2] as f32 * q2 as f32;
1740 ys[l + 64] = d * sc[is + 4] as f32 * q3 as f32;
1741 ys[l + 96] = d * sc[is + 6] as f32 * q4 as f32;
1742 }
1743 }
1744 }
1745 Ok(())
1746 }
1747}
1748
1749impl GgmlType for BlockQ8K {
1750 const DTYPE: GgmlDType = GgmlDType::Q8K;
1751 const BLCK_SIZE: usize = QK_K;
1752 type VecDotType = BlockQ8K;
1753
1754 #[allow(unreachable_code)]
1755 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1756 #[cfg(target_feature = "avx")]
1757 return super::avx::vec_dot_q8k_q8k(n, xs, ys);
1758
1759 #[cfg(target_feature = "neon")]
1760 return super::neon::vec_dot_q8k_q8k(n, xs, ys);
1761
1762 #[cfg(target_feature = "simd128")]
1763 return super::simd128::vec_dot_q8k_q8k(n, xs, ys);
1764
1765 Self::vec_dot_unopt(n, xs, ys)
1766 }
1767
1768 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1769 let qk = QK_K;
1770 if n % QK_K != 0 {
1771 crate::bail!("vec_dot_q8k_q8k: {n} is not divisible by {qk}")
1772 }
1773
1774 let mut sumf = 0f32;
1776 for (xs, ys) in xs.iter().zip(ys.iter()) {
1777 let sum_i = xs
1778 .qs
1779 .iter()
1780 .zip(ys.qs.iter())
1781 .map(|(&x, &y)| x as i32 * y as i32)
1782 .sum::<i32>();
1783 sumf += sum_i as f32 * xs.d * ys.d
1784 }
1785 Ok(sumf)
1786 }
1787
1788 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
1789 let k = xs.len();
1790 if k % QK_K != 0 {
1791 crate::bail!("quantize_row_q8k: {k} is not divisible by {QK_K}")
1792 }
1793 for (i, y) in ys.iter_mut().enumerate() {
1794 let mut max = 0f32;
1795 let mut amax = 0f32;
1796 let xs = &xs[i * QK_K..(i + 1) * QK_K];
1797 for &x in xs.iter() {
1798 if amax < x.abs() {
1799 amax = x.abs();
1800 max = x;
1801 }
1802 }
1803 if amax == 0f32 {
1804 y.d = 0f32;
1805 y.qs.fill(0)
1806 } else {
1807 let iscale = -128f32 / max;
1808 for (j, q) in y.qs.iter_mut().enumerate() {
1809 let v = (iscale * xs[j]).round();
1812 *q = v.min(127.) as i8
1813 }
1814 for j in 0..QK_K / 16 {
1815 let mut sum = 0i32;
1816 for ii in 0..16 {
1817 sum += y.qs[j * 16 + ii] as i32
1818 }
1819 y.bsums[j] = sum as i16
1820 }
1821 y.d = 1.0 / iscale
1822 }
1823 }
1824 Ok(())
1825 }
1826
1827 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
1828 let k = ys.len();
1829 if k % QK_K != 0 {
1830 crate::bail!("dequantize_row_q8k: {k} is not divisible by {QK_K}")
1831 }
1832 for (i, x) in xs.iter().enumerate() {
1833 for (j, &q) in x.qs.iter().enumerate() {
1834 ys[i * QK_K + j] = x.d * q as f32
1835 }
1836 }
1837 Ok(())
1838 }
1839}
1840
1841pub fn matmul<T: GgmlType>(
1843 mkn: (usize, usize, usize),
1844 lhs: &[f32],
1845 rhs_t: &[T],
1846 dst: &mut [f32],
1847) -> Result<()> {
1848 let (m, k, n) = mkn;
1849 if m * k != lhs.len() {
1850 crate::bail!("unexpected lhs length {} {mkn:?}", lhs.len());
1851 }
1852
1853 let k_in_lhs_blocks = k.div_ceil(T::BLCK_SIZE);
1854 let k_in_rhs_blocks = k.div_ceil(T::VecDotType::BLCK_SIZE);
1855 let mut lhs_b = vec![T::VecDotType::zeros(); m * k_in_lhs_blocks];
1858 for row_idx in 0..m {
1859 let lhs_b = &mut lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
1860 let lhs = &lhs[row_idx * k..(row_idx + 1) * k];
1861 T::VecDotType::from_float(lhs, lhs_b)?
1862 }
1863 let lhs_b = lhs_b.as_slice();
1864
1865 for row_idx in 0..m {
1866 let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
1867 let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
1868
1869 let result: Result<Vec<_>> = dst_row
1870 .into_par_iter()
1871 .enumerate()
1872 .with_min_len(128)
1873 .with_max_len(512)
1874 .map(|(col_idx, dst)| {
1875 let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
1876 T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value)
1877 })
1878 .collect();
1879
1880 result?;
1881 }
1882 Ok(())
1883}
1884
1885impl GgmlType for f32 {
1886 const DTYPE: GgmlDType = GgmlDType::F32;
1887 const BLCK_SIZE: usize = 1;
1888 type VecDotType = f32;
1889
1890 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1891 Self::vec_dot_unopt(n, xs, ys)
1892 }
1893
1894 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1895 if xs.len() < n {
1896 crate::bail!("size mismatch {} < {n}", xs.len())
1897 }
1898 if ys.len() < n {
1899 crate::bail!("size mismatch {} < {n}", ys.len())
1900 }
1901 let mut res = 0f32;
1902 unsafe { crate::cpu::vec_dot_f32(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
1903 Ok(res)
1904 }
1905
1906 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
1907 if xs.len() != ys.len() {
1908 crate::bail!("size mismatch {} {}", xs.len(), ys.len());
1909 }
1910 ys.copy_from_slice(xs);
1911 Ok(())
1912 }
1913
1914 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
1915 if xs.len() != ys.len() {
1916 crate::bail!("size mismatch {} {}", xs.len(), ys.len());
1917 }
1918 ys.copy_from_slice(xs);
1919 Ok(())
1920 }
1921}
1922
1923impl GgmlType for f16 {
1924 const DTYPE: GgmlDType = GgmlDType::F16;
1925 const BLCK_SIZE: usize = 1;
1926 type VecDotType = f16;
1927
1928 fn vec_dot(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1929 Self::vec_dot_unopt(n, xs, ys)
1930 }
1931
1932 fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
1933 if xs.len() < n {
1934 crate::bail!("size mismatch {} < {n}", xs.len())
1935 }
1936 if ys.len() < n {
1937 crate::bail!("size mismatch {} < {n}", ys.len())
1938 }
1939 let mut res = 0f32;
1940 unsafe { crate::cpu::vec_dot_f16(xs.as_ptr(), ys.as_ptr(), &mut res, n) };
1941 Ok(res)
1942 }
1943
1944 fn from_float(xs: &[f32], ys: &mut [Self]) -> Result<()> {
1945 if xs.len() != ys.len() {
1946 crate::bail!("size mismatch {} {}", xs.len(), ys.len());
1947 }
1948 for (x, y) in xs.iter().zip(ys.iter_mut()) {
1950 *y = f16::from_f32(*x)
1951 }
1952 Ok(())
1953 }
1954
1955 fn to_float(xs: &[Self], ys: &mut [f32]) -> Result<()> {
1956 if xs.len() != ys.len() {
1957 crate::bail!("size mismatch {} {}", xs.len(), ys.len());
1958 }
1959 for (x, y) in xs.iter().zip(ys.iter_mut()) {
1961 *y = x.to_f32()
1962 }
1963 Ok(())
1964 }
1965}