1use half::f16;
5
6use super::{AttnConfig, Backend};
7use ferrum_types::{FerrumError, Result};
8
/// Number of weights covered by one Q4_K super-block.
const Q4_K_QK: usize = 256;
/// Bytes of packed 6-bit scale/min data inside one Q4_K super-block.
const Q4_K_SCALE_SIZE: usize = 12;
/// Total bytes of one Q4_K super-block: two f16 header values (`d`, `dmin`),
/// the packed scales, then `Q4_K_QK / 2` bytes of 4-bit quants.
const Q4_K_BLOCK_BYTES: usize = 2 * 2 + Q4_K_SCALE_SIZE + Q4_K_QK / 2;

/// Unpacks the 6-bit `(scale, min)` pair for sub-block `j` (0..8) from the
/// 12-byte packed scale array `q`.
///
/// Sub-blocks 0..4 keep both values in the low 6 bits of `q[j]` / `q[j + 4]`.
/// Sub-blocks 4..8 assemble each value from one nibble of `q[j + 4]` plus the
/// two spare high bits of `q[j - 4]` (scale) or `q[j]` (min).
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    const LOW6: u8 = 0x3F;
    match j {
        0..=3 => (q[j] & LOW6, q[j + 4] & LOW6),
        _ => {
            let packed = q[j + 4];
            let scale = (packed & 0x0F) | ((q[j - 4] >> 6) << 4);
            let min = (packed >> 4) | ((q[j] >> 6) << 4);
            (scale, min)
        }
    }
}
28
29fn dequant_q4_k_cpu(bytes: &[u8], n_blocks: usize) -> Vec<f32> {
33 debug_assert_eq!(bytes.len(), n_blocks * Q4_K_BLOCK_BYTES);
34 let mut out = Vec::with_capacity(n_blocks * Q4_K_QK);
35 for b in 0..n_blocks {
36 let off = b * Q4_K_BLOCK_BYTES;
37 let d = f16::from_le_bytes([bytes[off], bytes[off + 1]]).to_f32();
38 let dmin = f16::from_le_bytes([bytes[off + 2], bytes[off + 3]]).to_f32();
39 let scales = &bytes[off + 4..off + 4 + Q4_K_SCALE_SIZE];
40 let qs = &bytes[off + 4 + Q4_K_SCALE_SIZE..off + Q4_K_BLOCK_BYTES];
41
42 let mut is = 0usize;
43 for j in (0..Q4_K_QK).step_by(64) {
44 let q_chunk = &qs[j / 2..j / 2 + 32];
45 let (sc1, mn1) = get_scale_min_k4(is, scales);
46 let d1 = d * sc1 as f32;
47 let m1 = dmin * mn1 as f32;
48 let (sc2, mn2) = get_scale_min_k4(is + 1, scales);
49 let d2 = d * sc2 as f32;
50 let m2 = dmin * mn2 as f32;
51 for q in q_chunk {
52 out.push(d1 * (q & 0xF) as f32 - m1);
53 }
54 for q in q_chunk {
55 out.push(d2 * (q >> 4) as f32 - m2);
56 }
57 is += 2;
58 }
59 }
60 out
61}
62
/// Reference CPU implementation of the backend traits.
///
/// Buffers are plain `Vec<f32>` and the execution context is `()`. Every
/// operation has finished by the time it returns, which is why `sync` is a no-op.
pub struct CpuBackend;
64
// Apple Accelerate entry points used by the macOS fast paths.
//
// Parameter types must match the C prototypes exactly. vDSP strides are
// `vDSP_Stride` (C `long`: signed and pointer-sized, i.e. `isize`), and lengths
// are `vDSP_Length` (C `unsigned long`, 64-bit on macOS). Declaring a stride as
// `i32` leaves the upper half of the 64-bit argument register unspecified at
// the call boundary.
#[cfg(target_os = "macos")]
unsafe extern "C" {
    /// CBLAS single-precision GEMM: `C = alpha * op(A) * op(B) + beta * C`.
    /// `order` / `transa` / `transb` take CBLAS enum values
    /// (101 = row-major, 111 = no transpose, 112 = transpose).
    unsafe fn cblas_sgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );
    /// vDSP dot product: `*result = sum(a[i * a_stride] * b[i * b_stride])` for `i < n`.
    unsafe fn vDSP_dotpr(
        a: *const f32,
        a_stride: isize,
        b: *const f32,
        b_stride: isize,
        result: *mut f32,
        n: u64,
    );
}
92
/// Dense f32 copy of one or more GPTQ weight matrices, stacked along the
/// output dimension.
///
/// `weight_f32` is row-major `[n, k]`: output channel `j` occupies
/// `weight_f32[j * k..(j + 1) * k]`.
pub struct CpuGptqStore {
    /// Dequantized weights (`n * k` values).
    pub weight_f32: Vec<f32>,
    /// Input features, i.e. values per row.
    pub k: usize,
    /// Total output features (rows) across all stacked matrices.
    pub n: usize,
}
100
/// Host-side storage for a GGUF-quantized weight after CPU dequantization.
pub enum CpuQuantStore {
    /// A Q4_K tensor expanded to f32 at load time.
    Q4K {
        /// Dequantized values in GGUF tensor order (`n_rows * n_cols` entries).
        weights: Vec<f32>,
        /// Output features.
        n_rows: usize,
        /// Input features.
        n_cols: usize,
    },
}
115
116impl Backend for CpuBackend {
117 type Buffer = Vec<f32>;
118 type Context = ();
119 type Timer = crate::backend::timer::CpuTimer;
123 fn make_timer() -> Self::Timer {
124 crate::backend::timer::CpuTimer::new()
125 }
126
127 fn new_context() -> Self::Context {}
128 fn sync(_ctx: &mut Self::Context) {}
129
130 fn alloc_typed(dtype: crate::backend::Dtype, n: usize) -> Self::Buffer {
134 let bytes = n * dtype.bytes_per_elem();
137 let f32_len = bytes.div_ceil(4);
138 vec![0.0f32; f32_len]
139 }
140
141 fn from_slice_typed<T: crate::backend::HostDtype>(data: &[T]) -> Self::Buffer {
144 let bytes = data.len() * std::mem::size_of::<T>();
145 let f32_len = bytes.div_ceil(4);
146 let mut out = vec![0.0f32; f32_len];
147 unsafe {
148 std::ptr::copy_nonoverlapping(
149 data.as_ptr() as *const u8,
150 out.as_mut_ptr() as *mut u8,
151 bytes,
152 );
153 }
154 out
155 }
156
157 fn write_typed<T: crate::backend::HostDtype>(
160 _ctx: &mut Self::Context,
161 dst: &mut Self::Buffer,
162 data: &[T],
163 ) {
164 let bytes = data.len() * std::mem::size_of::<T>();
165 debug_assert!(
166 bytes <= dst.len() * 4,
167 "CpuBackend::write_typed: src bytes {} > dst bytes {}",
168 bytes,
169 dst.len() * 4
170 );
171 unsafe {
172 std::ptr::copy_nonoverlapping(
173 data.as_ptr() as *const u8,
174 dst.as_mut_ptr() as *mut u8,
175 bytes,
176 );
177 }
178 }
179
180 fn fused_silu_mul_split_strided(
181 _ctx: &mut Self::Context,
182 gate_up: &Self::Buffer,
183 in_row_offset: usize,
184 out: &mut Self::Buffer,
185 out_row_offset: usize,
186 tokens: usize,
187 intermediate: usize,
188 ) {
189 let in_per_row = 2 * intermediate;
190 let in_start = in_row_offset * in_per_row;
191 let out_start = out_row_offset * intermediate;
192 for r in 0..tokens {
193 for c in 0..intermediate {
194 let g = gate_up[in_start + r * in_per_row + c];
195 let u = gate_up[in_start + r * in_per_row + intermediate + c];
196 let silu = g / (1.0 + (-g).exp());
197 out[out_start + r * intermediate + c] = silu * u;
198 }
199 }
200 }
201
202 fn gemm(
203 _ctx: &mut Self::Context,
204 a: &Self::Buffer,
205 b: &Self::Buffer,
206 out: &mut Self::Buffer,
207 m: usize,
208 n: usize,
209 k: usize,
210 ) {
211 assert!(
212 a.len() >= m * k,
213 "gemm: a too small len={} m={m} k={k}",
214 a.len()
215 );
216 assert!(
217 b.len() >= n * k,
218 "gemm: b too small len={} n={n} k={k}",
219 b.len()
220 );
221 assert!(
222 out.len() >= m * n,
223 "gemm: out too small len={} m={m} n={n}",
224 out.len()
225 );
226 #[cfg(target_os = "macos")]
227 unsafe {
228 cblas_sgemm(
229 101,
230 111,
231 112,
232 m as i32,
233 n as i32,
234 k as i32,
235 1.0,
236 a.as_ptr(),
237 k as i32,
238 b.as_ptr(),
239 k as i32,
240 0.0,
241 out.as_mut_ptr(),
242 n as i32,
243 );
244 }
245 #[cfg(not(target_os = "macos"))]
246 {
247 for i in 0..m {
248 for j in 0..n {
249 let mut sum = 0.0f64;
250 for p in 0..k {
251 sum += a[i * k + p] as f64 * b[j * k + p] as f64;
252 }
253 out[i * n + j] = sum as f32;
254 }
255 }
256 }
257 }
258
259 fn rms_norm(
260 _ctx: &mut Self::Context,
261 x: &Self::Buffer,
262 w: &Self::Buffer,
263 eps: f32,
264 out: &mut Self::Buffer,
265 tokens: usize,
266 dim: usize,
267 ) {
268 for t in 0..tokens {
269 let row = &x[t * dim..(t + 1) * dim];
270 let o = &mut out[t * dim..(t + 1) * dim];
271 let sum_sq = dot_product(row, row);
272 let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
273 for i in 0..dim {
274 o[i] = row[i] * inv * w[i];
275 }
276 }
277 }
278
279 fn fused_add_rms_norm(
280 _ctx: &mut Self::Context,
281 residual: &mut Self::Buffer,
282 x: &Self::Buffer,
283 w: &Self::Buffer,
284 eps: f32,
285 out: &mut Self::Buffer,
286 tokens: usize,
287 dim: usize,
288 ) {
289 for t in 0..tokens {
290 let off = t * dim;
291 for i in 0..dim {
292 residual[off + i] += x[off + i];
293 }
294 let row = &residual[off..off + dim];
295 let o = &mut out[off..off + dim];
296 let sum_sq = dot_product(row, row);
297 let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
298 for i in 0..dim {
299 o[i] = row[i] * inv * w[i];
300 }
301 }
302 }
303
304 fn flash_attention(
305 _ctx: &mut Self::Context,
306 q: &Self::Buffer,
307 k: &Self::Buffer,
308 v: &Self::Buffer,
309 out: &mut Self::Buffer,
310 batch: usize,
311 q_len: usize,
312 kv_len: usize,
313 pos_offset: usize,
314 cfg: &AttnConfig,
315 ) {
316 cpu_attention(
317 q, k, v, out, batch, q_len, kv_len, cfg.causal, pos_offset, cfg,
318 );
319 }
320
321 fn copy_slice(
322 _ctx: &mut Self::Context,
323 src: &Self::Buffer,
324 src_offset: usize,
325 dst: &mut Self::Buffer,
326 dst_offset: usize,
327 len: usize,
328 ) {
329 dst[dst_offset..dst_offset + len].copy_from_slice(&src[src_offset..src_offset + len]);
330 }
331
332 fn embedding_lookup(
333 _ctx: &mut Self::Context,
334 table: &Self::Buffer,
335 ids: &[u32],
336 out: &mut Self::Buffer,
337 dim: usize,
338 ) {
339 for (i, &id) in ids.iter().enumerate() {
340 let src = id as usize * dim;
341 out[i * dim..(i + 1) * dim].copy_from_slice(&table[src..src + dim]);
342 }
343 }
344
345 fn split_qkv(
346 _ctx: &mut Self::Context,
347 qkv: &Self::Buffer,
348 q: &mut Self::Buffer,
349 k: &mut Self::Buffer,
350 v: &mut Self::Buffer,
351 tokens: usize,
352 q_dim: usize,
353 kv_dim: usize,
354 ) {
355 let qkv_dim = q_dim + 2 * kv_dim;
356 for t in 0..tokens {
357 let base = t * qkv_dim;
358 q[t * q_dim..(t + 1) * q_dim].copy_from_slice(&qkv[base..base + q_dim]);
359 k[t * kv_dim..(t + 1) * kv_dim]
360 .copy_from_slice(&qkv[base + q_dim..base + q_dim + kv_dim]);
361 v[t * kv_dim..(t + 1) * kv_dim]
362 .copy_from_slice(&qkv[base + q_dim + kv_dim..base + qkv_dim]);
363 }
364 }
365
366 fn fused_silu_mul_split(
367 _ctx: &mut Self::Context,
368 gate_up: &Self::Buffer,
369 out: &mut Self::Buffer,
370 tokens: usize,
371 im: usize,
372 ) {
373 for t in 0..tokens {
374 for i in 0..im {
375 let g = gate_up[t * 2 * im + i];
376 let u = gate_up[t * 2 * im + im + i];
377 out[t * im + i] = (g / (1.0 + (-g).exp())) * u;
378 }
379 }
380 }
381
382 fn qk_norm_rope(
383 _ctx: &mut Self::Context,
384 input: &Self::Buffer,
385 norm_w: &Self::Buffer,
386 cos: &Self::Buffer,
387 sin: &Self::Buffer,
388 output: &mut Self::Buffer,
389 tokens: usize,
390 heads: usize,
391 head_dim: usize,
392 pos_offset: usize,
393 eps: f32,
394 mode: i32,
395 ) {
396 let half = head_dim / 2;
397 let cos_len = cos.len();
398 let sin_len = sin.len();
399 debug_assert_eq!(cos_len, sin_len);
400
401 for t in 0..tokens {
402 let pos = pos_offset + t;
403 for h in 0..heads {
404 let src_off = (t * heads + h) * head_dim;
406 let dst_off = (h * tokens + t) * head_dim;
408
409 if mode == 0 {
411 for i in 0..head_dim {
412 output[dst_off + i] = input[src_off + i];
413 }
414 continue;
415 }
416
417 let scale = if mode == 1 {
419 let mut sum_sq = 0.0f32;
420 for i in 0..head_dim {
421 sum_sq += input[src_off + i] * input[src_off + i];
422 }
423 1.0f32 / (sum_sq / head_dim as f32 + eps).sqrt()
424 } else {
425 1.0
426 };
427
428 if mode == 3 {
429 for i in 0..half {
431 let j = 2 * i;
432 let x0 = input[src_off + j];
433 let x1 = input[src_off + j + 1];
434 let c = cos[pos * half + i];
435 let s = sin[pos * half + i];
436 output[dst_off + j] = x0 * c - x1 * s;
437 output[dst_off + j + 1] = x1 * c + x0 * s;
438 }
439 } else {
440 for i in 0..half {
442 let (x0_raw, x1_raw) = (input[src_off + i], input[src_off + i + half]);
443 let (x0, x1) = if mode == 1 {
444 (
445 x0_raw * scale * norm_w[i],
446 x1_raw * scale * norm_w[i + half],
447 )
448 } else {
449 (x0_raw, x1_raw)
450 };
451 let c = cos[pos * half + i];
452 let s = sin[pos * half + i];
453 output[dst_off + i] = x0 * c - x1 * s;
454 output[dst_off + i + half] = x1 * c + x0 * s;
455 }
456 }
457 }
458 }
459 }
460
461 fn kv_cache_append_head_major(
462 _ctx: &mut Self::Context,
463 cache_k: &mut Self::Buffer,
464 cache_v: &mut Self::Buffer,
465 cache_len: usize,
466 cache_capacity: usize,
467 new_k_head_major: &Self::Buffer,
468 new_v_head_major: &Self::Buffer,
469 new_tokens: usize,
470 nkv: usize,
471 hd: usize,
472 ) {
473 debug_assert!(cache_len + new_tokens <= cache_capacity);
474 debug_assert_eq!(cache_k.len(), nkv * cache_capacity * hd);
475 debug_assert_eq!(cache_v.len(), nkv * cache_capacity * hd);
476 debug_assert!(new_k_head_major.len() >= nkv * new_tokens * hd);
481 debug_assert!(new_v_head_major.len() >= nkv * new_tokens * hd);
482
483 for h in 0..nkv {
484 let dst_base = h * cache_capacity * hd + cache_len * hd;
485 let src_base = h * new_tokens * hd;
486 cache_k[dst_base..dst_base + new_tokens * hd]
487 .copy_from_slice(&new_k_head_major[src_base..src_base + new_tokens * hd]);
488 cache_v[dst_base..dst_base + new_tokens * hd]
489 .copy_from_slice(&new_v_head_major[src_base..src_base + new_tokens * hd]);
490 }
491 }
492
493 fn transpose_head_to_token(
494 _ctx: &mut Self::Context,
495 src: &Self::Buffer,
496 dst: &mut Self::Buffer,
497 tokens: usize,
498 heads: usize,
499 dim: usize,
500 ) {
501 for h in 0..heads {
502 for t in 0..tokens {
503 let s = (h * tokens + t) * dim;
504 let d = (t * heads + h) * dim;
505 dst[d..d + dim].copy_from_slice(&src[s..s + dim]);
506 }
507 }
508 }
509
510 fn add_inplace(
511 _ctx: &mut Self::Context,
512 residual: &mut Self::Buffer,
513 x: &Self::Buffer,
514 len: usize,
515 ) {
516 for i in 0..len {
517 residual[i] += x[i];
518 }
519 }
520
521 fn scaled_add_inplace(
522 _ctx: &mut Self::Context,
523 dst: &mut Self::Buffer,
524 src: &Self::Buffer,
525 scale: f32,
526 len: usize,
527 ) {
528 for i in 0..len {
529 dst[i] += scale * src[i];
530 }
531 }
532
533 fn add_bias(
534 _ctx: &mut Self::Context,
535 data: &mut Self::Buffer,
536 bias: &Self::Buffer,
537 rows: usize,
538 cols: usize,
539 ) {
540 debug_assert_eq!(bias.len(), cols);
541 for r in 0..rows {
542 let off = r * cols;
543 for c in 0..cols {
544 data[off + c] += bias[c];
545 }
546 }
547 }
548
549 fn layer_norm(
550 _ctx: &mut Self::Context,
551 x: &Self::Buffer,
552 gamma: &Self::Buffer,
553 beta: &Self::Buffer,
554 eps: f32,
555 out: &mut Self::Buffer,
556 tokens: usize,
557 dim: usize,
558 ) {
559 debug_assert_eq!(gamma.len(), dim);
560 debug_assert_eq!(beta.len(), dim);
561 for t in 0..tokens {
562 let off = t * dim;
563 let mut mean = 0.0f64;
565 for i in 0..dim {
566 mean += x[off + i] as f64;
567 }
568 mean /= dim as f64;
569 let mut var = 0.0f64;
570 for i in 0..dim {
571 let d = x[off + i] as f64 - mean;
572 var += d * d;
573 }
574 var /= dim as f64;
575 let inv = 1.0f32 / ((var as f32) + eps).sqrt();
576 let mean_f32 = mean as f32;
577 for i in 0..dim {
578 out[off + i] = (x[off + i] - mean_f32) * inv * gamma[i] + beta[i];
579 }
580 }
581 }
582
583 fn gelu(_ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
584 for i in 0..len {
587 let xi = x[i];
588 out[i] = 0.5 * xi * (1.0 + libm_erf(xi / std::f32::consts::SQRT_2));
589 }
590 }
591
592 fn alloc(len: usize) -> Self::Buffer {
593 vec![0.0f32; len]
594 }
595 fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
596 buf[..len].to_vec()
597 }
598 fn from_slice(data: &[f32]) -> Self::Buffer {
599 data.to_vec()
600 }
601}
602
/// Dot product of `a` and `b` over their common prefix (`min(a.len(), b.len())`).
///
/// Uses Accelerate's `vDSP_dotpr` on macOS and a scalar f32 loop elsewhere.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    // Clamp to the shorter slice. The macOS path hands raw pointers to vDSP,
    // which would otherwise read past the end of `b` whenever it is shorter
    // than `a`. Clamping also matches the portable path's `zip` truncation.
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    #[cfg(target_os = "macos")]
    {
        if n == 0 {
            return 0.0;
        }
        let mut result = 0.0f32;
        // SAFETY: `a` and `b` each hold exactly `n` contiguous f32s (unit stride),
        // and `result` is a valid, exclusively borrowed f32 out-pointer.
        unsafe {
            vDSP_dotpr(a.as_ptr(), 1, b.as_ptr(), 1, &mut result, n as u64);
        }
        result
    }
    #[cfg(not(target_os = "macos"))]
    {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }
}
619
/// Rotates each head of token-major `data` (`[tokens, heads, head_dim]`) in place,
/// pairing element `i` with `i + half` for `i < half`.
///
/// Token `t` uses row `positions[t]` of the `cos` / `sin` tables, each row being
/// `half` entries wide.
#[allow(dead_code)]
fn apply_rope_impl(
    data: &mut [f32],
    tokens: usize,
    heads: usize,
    head_dim: usize,
    half: usize,
    cos: &[f32],
    sin: &[f32],
    positions: &[u32],
) {
    for t in 0..tokens {
        let table_off = positions[t] as usize * half;
        for h in 0..heads {
            let lo = (t * heads + h) * head_dim;
            let hi = lo + half;
            for i in 0..half {
                let (c, s) = (cos[table_off + i], sin[table_off + i]);
                let (a, b) = (data[lo + i], data[hi + i]);
                data[lo + i] = a * c - b * s;
                data[hi + i] = b * c + a * s;
            }
        }
    }
}
646
647fn cpu_attention(
648 q: &[f32],
649 k: &[f32],
650 v: &[f32],
651 out: &mut [f32],
652 batch: usize,
653 q_len: usize,
654 kv_len: usize,
655 causal: bool,
656 pos_offset: usize,
657 cfg: &AttnConfig,
658) {
659 let nh = cfg.num_heads;
660 let nkv = cfg.num_kv_heads;
661 let d = cfg.head_dim;
662 let n_rep = nh / nkv;
663 let scale = cfg.scale;
664 let kv_stride = if cfg.kv_seq_stride > 0 {
669 cfg.kv_seq_stride
670 } else {
671 kv_len
672 };
673
674 for b in 0..batch {
675 for h in 0..nh {
676 let kv_h = h / n_rep;
677 let q_off = (b * nh + h) * q_len * d;
678 let k_off = (b * nkv + kv_h) * kv_stride * d;
679 let v_off = (b * nkv + kv_h) * kv_stride * d;
680 let o_off = (b * nh + h) * q_len * d;
681
682 for qi in 0..q_len {
683 let attend_end = if causal {
684 (pos_offset + qi + 1).min(kv_len)
685 } else {
686 kv_len
687 };
688 let attend_start = if causal && cfg.sliding_window > 0 {
689 attend_end.saturating_sub(cfg.sliding_window)
690 } else {
691 0
692 };
693 let mut max_score = f32::NEG_INFINITY;
694 let mut sum_exp = 0.0f32;
695 let mut acc = vec![0.0f32; d];
696
697 for ki in attend_start..attend_end {
698 let mut dot = 0.0f32;
699 for di in 0..d {
700 dot += q[q_off + qi * d + di] * k[k_off + ki * d + di];
701 }
702 let score = dot * scale;
703 if score > max_score {
704 let correction = (max_score - score).exp();
705 for di in 0..d {
706 acc[di] *= correction;
707 }
708 sum_exp *= correction;
709 max_score = score;
710 }
711 let w = (score - max_score).exp();
712 sum_exp += w;
713 for di in 0..d {
714 acc[di] += w * v[v_off + ki * d + di];
715 }
716 }
717
718 if sum_exp > 0.0 {
719 let inv = 1.0 / sum_exp;
720 for di in 0..d {
721 out[o_off + qi * d + di] = acc[di] * inv;
722 }
723 }
724 }
725 }
726 }
727}
728
/// Polynomial approximation of the error function `erf(x)`.
///
/// Abramowitz & Stegun formula 7.1.26 (absolute error below roughly 1.5e-7),
/// evaluated in f32 and extended to negative inputs by odd symmetry. Despite
/// the name, this does not call the system `libm`.
fn libm_erf(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    // Polynomial coefficients from highest degree (a5) down to a1, in Horner order.
    const A5: f32 = 1.061_405_4;
    const A4: f32 = -1.453_152_1;
    const A3: f32 = 1.421_413_8;
    const A2: f32 = -0.284_496_72;
    const A1: f32 = 0.254_829_6;

    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);
    let mut poly = A5;
    for coeff in [A4, A3, A2, A1] {
        poly = poly * t + coeff;
    }
    let y = 1.0 - poly * t * (-magnitude * magnitude).exp();
    if x < 0.0 {
        -y
    } else {
        y
    }
}
742
743impl crate::backend::BackendGraph for CpuBackend {}
745
746impl crate::backend::BackendCollective for CpuBackend {}
748
749fn cpu_dequant_gptq(
752 qweight: &[i32],
753 scales: &[f32],
754 qzeros: &[i32],
755 bits: u32,
756 group_size: usize,
757 k: usize,
758 n: usize,
759) -> Result<Vec<f32>> {
760 if bits != 4 {
761 return Err(FerrumError::unsupported(format!(
762 "CPU GPTQ: only bits=4 supported (got {bits})"
763 )));
764 }
765 let mut w = vec![0.0f32; n * k];
766 let packed_rows = k / 8;
767 for pr in 0..packed_rows {
768 for col in 0..n {
769 let packed = qweight[pr * n + col] as u32;
770 for bi in 0..8 {
771 let ki = pr * 8 + bi;
772 let q = ((packed >> (bi * 4)) & 0xF) as i32;
773 let grp = ki / group_size;
774 let scale = scales[grp * n + col];
775 let z_packed = qzeros[grp * (n / 8) + (col / 8)] as u32;
776 let zero = (((z_packed >> ((col % 8) * 4)) & 0xF) as i32) + 1;
777 let val = (q - zero) as f32 * scale;
778 w[col * k + ki] = val;
779 }
780 }
781 }
782 Ok(w)
783}
784
785impl crate::backend::BackendQuantMarlin for CpuBackend {
786 fn load_gptq(
787 qweight: &[i32],
788 scales: &[f32],
789 qzeros: &[i32],
790 _g_idx: Option<&[i32]>,
791 bias_host: Option<&[f32]>,
792 bits: u32,
793 group_size: usize,
794 k: usize,
795 n: usize,
796 ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
797 let w = cpu_dequant_gptq(qweight, scales, qzeros, bits, group_size, k, n)?;
798 Ok(Box::new(crate::quant_linear::cpu_dequant::CpuGptqLinear {
802 weight_f32: w,
803 bias: bias_host.map(|b| b.to_vec()),
804 in_features: k,
805 out_features: n,
806 }))
807 }
808 fn load_gptq_stacked(
809 qweights: &[&[i32]],
810 scales: &[&[f32]],
811 qzeros: &[&[i32]],
812 _g_idx: Option<&[i32]>,
813 bits: u32,
814 group_size: usize,
815 k: usize,
816 n_per_expert: usize,
817 ) -> Result<std::sync::Arc<dyn crate::MarlinExpertStack<Self>>> {
818 let num_experts = qweights.len();
821 if scales.len() != num_experts || qzeros.len() != num_experts {
822 return Err(FerrumError::model(format!(
823 "load_gptq_stacked: input slice lengths disagree (qw {num_experts}, sc {}, qz {})",
824 scales.len(),
825 qzeros.len()
826 )));
827 }
828 let total_n = num_experts * n_per_expert;
829 let mut all_w = Vec::with_capacity(total_n * k);
830 for ((qw_e, sc_e), qz_e) in qweights.iter().zip(scales.iter()).zip(qzeros.iter()) {
831 let w_e = cpu_dequant_gptq(qw_e, sc_e, qz_e, bits, group_size, k, n_per_expert)?;
832 all_w.extend_from_slice(&w_e);
833 }
834 let store = std::sync::Arc::new(CpuGptqStore {
835 weight_f32: all_w,
836 k,
837 n: total_n,
838 });
839 Ok(std::sync::Arc::new(
840 crate::quant_linear::cpu_marlin_stack::CpuMarlinExpertStack::new(
841 store,
842 num_experts,
843 n_per_expert,
844 k,
845 ),
846 ))
847 }
848 }
855
856#[allow(clippy::too_many_arguments)]
860pub(crate) fn cpu_gemm_gptq_with_offset_strided(
861 _ctx: &mut <CpuBackend as Backend>::Context,
862 input: &<CpuBackend as Backend>::Buffer,
863 in_row_offset: usize,
864 weight: &CpuGptqStore,
865 expert_offset: usize,
866 expert_n: usize,
867 output: &mut <CpuBackend as Backend>::Buffer,
868 out_row_offset: usize,
869 m: usize,
870 k: usize,
871) -> Result<()> {
872 if expert_offset + expert_n > weight.n {
873 return Err(FerrumError::model(format!(
874 "cpu_gemm_gptq_with_offset_strided OOB: offset {expert_offset} + n {expert_n} > stacked_n {}",
875 weight.n
876 )));
877 }
878 if k != weight.k {
879 return Err(FerrumError::model(format!(
880 "cpu_gemm_gptq_with_offset_strided k mismatch: arg {k} vs weight.k {}",
881 weight.k
882 )));
883 }
884 let in_start = in_row_offset * k;
885 let in_end = (in_row_offset + m) * k;
886 let out_start = out_row_offset * expert_n;
887 let out_end = (out_row_offset + m) * expert_n;
888 let row_start = expert_offset * k;
889 let row_end = (expert_offset + expert_n) * k;
890 let weight_slice = weight.weight_f32[row_start..row_end].to_vec();
891 let in_slice = input[in_start..in_end].to_vec();
892 let mut out_slice = vec![0.0f32; m * expert_n];
893 let mut ctx_local = ();
894 CpuBackend::gemm(
895 &mut ctx_local,
896 &in_slice,
897 &weight_slice,
898 &mut out_slice,
899 m,
900 expert_n,
901 k,
902 );
903 output[out_start..out_end].copy_from_slice(&out_slice);
904 Ok(())
905}
906
907impl crate::backend::BackendQuantGguf for CpuBackend {
908 fn load_quant(
909 kind: super::GgufQuantType,
910 bytes: &[u8],
911 n_rows: usize,
912 n_cols: usize,
913 ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
914 use super::GgufQuantType;
915 let store = match kind {
916 GgufQuantType::Q4K => {
917 let total_elems = n_rows * n_cols;
918 if total_elems % Q4_K_QK != 0 {
919 return Err(FerrumError::model(format!(
920 "load_quant Q4K: elements {total_elems} not a multiple of {Q4_K_QK}"
921 )));
922 }
923 let n_blocks = total_elems / Q4_K_QK;
924 let expected = n_blocks * Q4_K_BLOCK_BYTES;
925 if bytes.len() != expected {
926 return Err(FerrumError::model(format!(
927 "load_quant Q4K: bytes {} != expected {} ({n_blocks} × {Q4_K_BLOCK_BYTES})",
928 bytes.len(),
929 expected
930 )));
931 }
932 CpuQuantStore::Q4K {
933 weights: dequant_q4_k_cpu(bytes, n_blocks),
934 n_rows,
935 n_cols,
936 }
937 }
938 other => {
939 return Err(FerrumError::unsupported(format!(
940 "CPU load_quant: {other:?} not yet implemented"
941 )));
942 }
943 };
944 Ok(Box::new(crate::quant_linear::cpu_gguf::CpuGgufLinear {
947 store,
948 in_features: n_cols,
949 out_features: n_rows,
950 }))
951 }
952}
953
954impl crate::backend::BackendPagedKv for CpuBackend {}
956
957impl crate::backend::BackendMoeFused for CpuBackend {}
959
960impl crate::backend::BackendKvDtype<crate::backend::KvFp16> for CpuBackend {
962 type KvBuffer = <Self as crate::backend::Backend>::Buffer;
963 type KvScales = ();
964}