1use half::f16;
5
6use super::{AttnConfig, Backend};
7use ferrum_types::{FerrumError, Result};
8
/// Number of weights stored in one Q4_K super-block.
const Q4_K_QK: usize = 256;
/// Bytes occupied by the packed 6-bit scale/min table of one super-block.
const Q4_K_SCALE_SIZE: usize = 12;
/// Serialized size of one super-block: a 4-byte header, the packed scale
/// table, and one nibble per weight.
const Q4_K_BLOCK_BYTES: usize = 4 + Q4_K_SCALE_SIZE + Q4_K_QK / 2;

/// Unpacks the 6-bit `(scale, min)` pair for sub-block `j` from the packed
/// scale table `q`.
///
/// For `j < 4` the pair is the low six bits of `q[j]` and `q[j + 4]`.
/// For `j >= 4` each value is rebuilt from a nibble of `q[j + 4]` (low four
/// bits) plus the top two bits of `q[j - 4]` (scale) or `q[j]` (min).
///
/// Panics if `q` is too short for the indices touched.
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    if j >= 4 {
        let scale = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
        let min = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
        return (scale, min);
    }
    (q[j] & 63, q[j + 4] & 63)
}
28
29fn dequant_q4_k_cpu(bytes: &[u8], n_blocks: usize) -> Vec<f32> {
33 debug_assert_eq!(bytes.len(), n_blocks * Q4_K_BLOCK_BYTES);
34 let mut out = Vec::with_capacity(n_blocks * Q4_K_QK);
35 for b in 0..n_blocks {
36 let off = b * Q4_K_BLOCK_BYTES;
37 let d = f16::from_le_bytes([bytes[off], bytes[off + 1]]).to_f32();
38 let dmin = f16::from_le_bytes([bytes[off + 2], bytes[off + 3]]).to_f32();
39 let scales = &bytes[off + 4..off + 4 + Q4_K_SCALE_SIZE];
40 let qs = &bytes[off + 4 + Q4_K_SCALE_SIZE..off + Q4_K_BLOCK_BYTES];
41
42 let mut is = 0usize;
43 for j in (0..Q4_K_QK).step_by(64) {
44 let q_chunk = &qs[j / 2..j / 2 + 32];
45 let (sc1, mn1) = get_scale_min_k4(is, scales);
46 let d1 = d * sc1 as f32;
47 let m1 = dmin * mn1 as f32;
48 let (sc2, mn2) = get_scale_min_k4(is + 1, scales);
49 let d2 = d * sc2 as f32;
50 let m2 = dmin * mn2 as f32;
51 for q in q_chunk {
52 out.push(d1 * (q & 0xF) as f32 - m1);
53 }
54 for q in q_chunk {
55 out.push(d2 * (q >> 4) as f32 - m2);
56 }
57 is += 2;
58 }
59 }
60 out
61}
62
/// Zero-sized type that carries the CPU implementations of the backend traits
/// defined in this file (buffers are plain `Vec<f32>`, the context is `()`).
63pub struct CpuBackend;
64
// Bindings to Apple's Accelerate framework (CBLAS + vDSP); compiled on macOS only.
#[cfg(target_os = "macos")]
unsafe extern "C" {
    /// CBLAS single-precision GEMM: `C = alpha * op(A) * op(B) + beta * C`.
    ///
    /// `order`, `transa` and `transb` take the CBLAS enum values; `m`, `n`,
    /// `k` are the matrix extents and `lda`/`ldb`/`ldc` the leading dimensions.
    ///
    /// # Safety
    /// `a`, `b` and `c` must be valid for the extents implied by the
    /// dimensions, strides and transpose flags, and `c` must be writable.
    unsafe fn cblas_sgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );

    /// vDSP strided dot product: `*result = sum(a[i * a_stride] * b[i * b_stride])`
    /// for `i` in `0..n`.
    ///
    /// The strides are `vDSP_Stride`, which is C `long` — declaring them as
    /// `i32` leaves the upper half of the 64-bit argument unspecified.
    /// `n` is `vDSP_Length` (C `unsigned long`); `u64` matches it on 64-bit
    /// macOS targets.
    ///
    /// # Safety
    /// `a` and `b` must be valid for `n` strided reads and `result` must point
    /// to a writable `f32`.
    unsafe fn vDSP_dotpr(
        a: *const f32,
        a_stride: core::ffi::c_long,
        b: *const f32,
        b_stride: core::ffi::c_long,
        result: *mut f32,
        n: u64,
    );
}
92
/// GPTQ weights kept on the host after dequantization to `f32`.
pub struct CpuGptqStore {
    /// Flat weight values, addressed as `n` consecutive rows of `k` values.
    pub weight_f32: Vec<f32>,
    /// Length of each row (the reduction dimension of the GEMM).
    pub k: usize,
    /// Total number of rows stored.
    pub n: usize,
}
100
/// Host-side storage for GGUF-quantized weights, one variant per format.
pub enum CpuQuantStore {
    /// Weights that were loaded from Q4_K blocks and expanded to `f32`.
    Q4K {
        /// Dequantized values, `n_rows * n_cols` of them.
        weights: Vec<f32>,
        /// Row count of the weight matrix.
        n_rows: usize,
        /// Column count of the weight matrix.
        n_cols: usize,
    },
}
115
116impl Backend for CpuBackend {
117 type Buffer = Vec<f32>;
118 type Context = ();
119 type Timer = crate::backend::timer::CpuTimer;
123 fn make_timer() -> Self::Timer {
124 crate::backend::timer::CpuTimer::new()
125 }
126
/// Creates the backend context; for the CPU backend `Context` is `()`, so there is nothing to build.
127 fn new_context() -> Self::Context {}
/// Synchronization point; a no-op here because every CPU operation in this impl runs to completion before returning.
128 fn sync(_ctx: &mut Self::Context) {}
129 fn activation_elem_size_bytes() -> usize {
130 std::mem::size_of::<f32>()
131 }
132
133 fn alloc_typed(dtype: crate::backend::Dtype, n: usize) -> Self::Buffer {
137 let bytes = n * dtype.bytes_per_elem();
140 let f32_len = bytes.div_ceil(4);
141 vec![0.0f32; f32_len]
142 }
143
144 fn from_slice_typed<T: crate::backend::HostDtype>(data: &[T]) -> Self::Buffer {
147 let bytes = data.len() * std::mem::size_of::<T>();
148 let f32_len = bytes.div_ceil(4);
149 let mut out = vec![0.0f32; f32_len];
150 unsafe {
151 std::ptr::copy_nonoverlapping(
152 data.as_ptr() as *const u8,
153 out.as_mut_ptr() as *mut u8,
154 bytes,
155 );
156 }
157 out
158 }
159
160 fn write_typed<T: crate::backend::HostDtype>(
163 _ctx: &mut Self::Context,
164 dst: &mut Self::Buffer,
165 data: &[T],
166 ) {
167 let bytes = data.len() * std::mem::size_of::<T>();
168 debug_assert!(
169 bytes <= dst.len() * 4,
170 "CpuBackend::write_typed: src bytes {} > dst bytes {}",
171 bytes,
172 dst.len() * 4
173 );
174 unsafe {
175 std::ptr::copy_nonoverlapping(
176 data.as_ptr() as *const u8,
177 dst.as_mut_ptr() as *mut u8,
178 bytes,
179 );
180 }
181 }
182
183 fn fused_silu_mul_split_strided(
184 _ctx: &mut Self::Context,
185 gate_up: &Self::Buffer,
186 in_row_offset: usize,
187 out: &mut Self::Buffer,
188 out_row_offset: usize,
189 tokens: usize,
190 intermediate: usize,
191 ) {
192 let in_per_row = 2 * intermediate;
193 let in_start = in_row_offset * in_per_row;
194 let out_start = out_row_offset * intermediate;
195 for r in 0..tokens {
196 for c in 0..intermediate {
197 let g = gate_up[in_start + r * in_per_row + c];
198 let u = gate_up[in_start + r * in_per_row + intermediate + c];
199 let silu = g / (1.0 + (-g).exp());
200 out[out_start + r * intermediate + c] = silu * u;
201 }
202 }
203 }
204
205 fn gemm(
206 _ctx: &mut Self::Context,
207 a: &Self::Buffer,
208 b: &Self::Buffer,
209 out: &mut Self::Buffer,
210 m: usize,
211 n: usize,
212 k: usize,
213 ) {
214 assert!(
215 a.len() >= m * k,
216 "gemm: a too small len={} m={m} k={k}",
217 a.len()
218 );
219 assert!(
220 b.len() >= n * k,
221 "gemm: b too small len={} n={n} k={k}",
222 b.len()
223 );
224 assert!(
225 out.len() >= m * n,
226 "gemm: out too small len={} m={m} n={n}",
227 out.len()
228 );
229 #[cfg(target_os = "macos")]
230 unsafe {
231 cblas_sgemm(
232 101,
233 111,
234 112,
235 m as i32,
236 n as i32,
237 k as i32,
238 1.0,
239 a.as_ptr(),
240 k as i32,
241 b.as_ptr(),
242 k as i32,
243 0.0,
244 out.as_mut_ptr(),
245 n as i32,
246 );
247 }
248 #[cfg(not(target_os = "macos"))]
249 {
250 for i in 0..m {
251 for j in 0..n {
252 let mut sum = 0.0f64;
253 for p in 0..k {
254 sum += a[i * k + p] as f64 * b[j * k + p] as f64;
255 }
256 out[i * n + j] = sum as f32;
257 }
258 }
259 }
260 }
261
262 fn rms_norm(
263 _ctx: &mut Self::Context,
264 x: &Self::Buffer,
265 w: &Self::Buffer,
266 eps: f32,
267 out: &mut Self::Buffer,
268 tokens: usize,
269 dim: usize,
270 ) {
271 for t in 0..tokens {
272 let row = &x[t * dim..(t + 1) * dim];
273 let o = &mut out[t * dim..(t + 1) * dim];
274 let sum_sq = dot_product(row, row);
275 let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
276 for i in 0..dim {
277 o[i] = row[i] * inv * w[i];
278 }
279 }
280 }
281
282 fn fused_add_rms_norm(
283 _ctx: &mut Self::Context,
284 residual: &mut Self::Buffer,
285 x: &Self::Buffer,
286 w: &Self::Buffer,
287 eps: f32,
288 out: &mut Self::Buffer,
289 tokens: usize,
290 dim: usize,
291 ) {
292 for t in 0..tokens {
293 let off = t * dim;
294 for i in 0..dim {
295 residual[off + i] += x[off + i];
296 }
297 let row = &residual[off..off + dim];
298 let o = &mut out[off..off + dim];
299 let sum_sq = dot_product(row, row);
300 let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
301 for i in 0..dim {
302 o[i] = row[i] * inv * w[i];
303 }
304 }
305 }
306
307 fn flash_attention(
308 _ctx: &mut Self::Context,
309 q: &Self::Buffer,
310 k: &Self::Buffer,
311 v: &Self::Buffer,
312 out: &mut Self::Buffer,
313 batch: usize,
314 q_len: usize,
315 kv_len: usize,
316 pos_offset: usize,
317 cfg: &AttnConfig,
318 ) {
319 cpu_attention(
320 q, k, v, out, batch, q_len, kv_len, cfg.causal, pos_offset, cfg,
321 );
322 }
323
324 fn copy_slice(
325 _ctx: &mut Self::Context,
326 src: &Self::Buffer,
327 src_offset: usize,
328 dst: &mut Self::Buffer,
329 dst_offset: usize,
330 len: usize,
331 ) {
332 dst[dst_offset..dst_offset + len].copy_from_slice(&src[src_offset..src_offset + len]);
333 }
334
335 fn embedding_lookup(
336 _ctx: &mut Self::Context,
337 table: &Self::Buffer,
338 ids: &[u32],
339 out: &mut Self::Buffer,
340 dim: usize,
341 ) {
342 for (i, &id) in ids.iter().enumerate() {
343 let src = id as usize * dim;
344 out[i * dim..(i + 1) * dim].copy_from_slice(&table[src..src + dim]);
345 }
346 }
347
348 fn split_qkv(
349 _ctx: &mut Self::Context,
350 qkv: &Self::Buffer,
351 q: &mut Self::Buffer,
352 k: &mut Self::Buffer,
353 v: &mut Self::Buffer,
354 tokens: usize,
355 q_dim: usize,
356 kv_dim: usize,
357 ) {
358 let qkv_dim = q_dim + 2 * kv_dim;
359 for t in 0..tokens {
360 let base = t * qkv_dim;
361 q[t * q_dim..(t + 1) * q_dim].copy_from_slice(&qkv[base..base + q_dim]);
362 k[t * kv_dim..(t + 1) * kv_dim]
363 .copy_from_slice(&qkv[base + q_dim..base + q_dim + kv_dim]);
364 v[t * kv_dim..(t + 1) * kv_dim]
365 .copy_from_slice(&qkv[base + q_dim + kv_dim..base + qkv_dim]);
366 }
367 }
368
369 fn fused_silu_mul_split(
370 _ctx: &mut Self::Context,
371 gate_up: &Self::Buffer,
372 out: &mut Self::Buffer,
373 tokens: usize,
374 im: usize,
375 ) {
376 for t in 0..tokens {
377 for i in 0..im {
378 let g = gate_up[t * 2 * im + i];
379 let u = gate_up[t * 2 * im + im + i];
380 out[t * im + i] = (g / (1.0 + (-g).exp())) * u;
381 }
382 }
383 }
384
385 fn qk_norm_rope(
386 _ctx: &mut Self::Context,
387 input: &Self::Buffer,
388 norm_w: &Self::Buffer,
389 cos: &Self::Buffer,
390 sin: &Self::Buffer,
391 output: &mut Self::Buffer,
392 tokens: usize,
393 heads: usize,
394 head_dim: usize,
395 pos_offset: usize,
396 eps: f32,
397 mode: i32,
398 ) {
399 let half = head_dim / 2;
400 let cos_len = cos.len();
401 let sin_len = sin.len();
402 debug_assert_eq!(cos_len, sin_len);
403
404 for t in 0..tokens {
405 let pos = pos_offset + t;
406 for h in 0..heads {
407 let src_off = (t * heads + h) * head_dim;
409 let dst_off = (h * tokens + t) * head_dim;
411
412 if mode == 0 {
414 for i in 0..head_dim {
415 output[dst_off + i] = input[src_off + i];
416 }
417 continue;
418 }
419
420 let scale = if mode == 1 {
422 let mut sum_sq = 0.0f32;
423 for i in 0..head_dim {
424 sum_sq += input[src_off + i] * input[src_off + i];
425 }
426 1.0f32 / (sum_sq / head_dim as f32 + eps).sqrt()
427 } else {
428 1.0
429 };
430
431 if mode == 3 {
432 for i in 0..half {
434 let j = 2 * i;
435 let x0 = input[src_off + j];
436 let x1 = input[src_off + j + 1];
437 let c = cos[pos * half + i];
438 let s = sin[pos * half + i];
439 output[dst_off + j] = x0 * c - x1 * s;
440 output[dst_off + j + 1] = x1 * c + x0 * s;
441 }
442 } else {
443 for i in 0..half {
445 let (x0_raw, x1_raw) = (input[src_off + i], input[src_off + i + half]);
446 let (x0, x1) = if mode == 1 {
447 (
448 x0_raw * scale * norm_w[i],
449 x1_raw * scale * norm_w[i + half],
450 )
451 } else {
452 (x0_raw, x1_raw)
453 };
454 let c = cos[pos * half + i];
455 let s = sin[pos * half + i];
456 output[dst_off + i] = x0 * c - x1 * s;
457 output[dst_off + i + half] = x1 * c + x0 * s;
458 }
459 }
460 }
461 }
462 }
463
464 fn kv_cache_append_head_major(
465 _ctx: &mut Self::Context,
466 cache_k: &mut Self::Buffer,
467 cache_v: &mut Self::Buffer,
468 cache_len: usize,
469 cache_capacity: usize,
470 new_k_head_major: &Self::Buffer,
471 new_v_head_major: &Self::Buffer,
472 new_tokens: usize,
473 nkv: usize,
474 hd: usize,
475 ) {
476 debug_assert!(cache_len + new_tokens <= cache_capacity);
477 debug_assert_eq!(cache_k.len(), nkv * cache_capacity * hd);
478 debug_assert_eq!(cache_v.len(), nkv * cache_capacity * hd);
479 debug_assert!(new_k_head_major.len() >= nkv * new_tokens * hd);
484 debug_assert!(new_v_head_major.len() >= nkv * new_tokens * hd);
485
486 for h in 0..nkv {
487 let dst_base = h * cache_capacity * hd + cache_len * hd;
488 let src_base = h * new_tokens * hd;
489 cache_k[dst_base..dst_base + new_tokens * hd]
490 .copy_from_slice(&new_k_head_major[src_base..src_base + new_tokens * hd]);
491 cache_v[dst_base..dst_base + new_tokens * hd]
492 .copy_from_slice(&new_v_head_major[src_base..src_base + new_tokens * hd]);
493 }
494 }
495
496 fn transpose_head_to_token(
497 _ctx: &mut Self::Context,
498 src: &Self::Buffer,
499 dst: &mut Self::Buffer,
500 tokens: usize,
501 heads: usize,
502 dim: usize,
503 ) {
504 for h in 0..heads {
505 for t in 0..tokens {
506 let s = (h * tokens + t) * dim;
507 let d = (t * heads + h) * dim;
508 dst[d..d + dim].copy_from_slice(&src[s..s + dim]);
509 }
510 }
511 }
512
513 fn add_inplace(
514 _ctx: &mut Self::Context,
515 residual: &mut Self::Buffer,
516 x: &Self::Buffer,
517 len: usize,
518 ) {
519 for i in 0..len {
520 residual[i] += x[i];
521 }
522 }
523
524 fn scaled_add_inplace(
525 _ctx: &mut Self::Context,
526 dst: &mut Self::Buffer,
527 src: &Self::Buffer,
528 scale: f32,
529 len: usize,
530 ) {
531 for i in 0..len {
532 dst[i] += scale * src[i];
533 }
534 }
535
536 fn add_bias(
537 _ctx: &mut Self::Context,
538 data: &mut Self::Buffer,
539 bias: &Self::Buffer,
540 rows: usize,
541 cols: usize,
542 ) {
543 debug_assert_eq!(bias.len(), cols);
544 for r in 0..rows {
545 let off = r * cols;
546 for c in 0..cols {
547 data[off + c] += bias[c];
548 }
549 }
550 }
551
552 fn layer_norm(
553 _ctx: &mut Self::Context,
554 x: &Self::Buffer,
555 gamma: &Self::Buffer,
556 beta: &Self::Buffer,
557 eps: f32,
558 out: &mut Self::Buffer,
559 tokens: usize,
560 dim: usize,
561 ) {
562 debug_assert_eq!(gamma.len(), dim);
563 debug_assert_eq!(beta.len(), dim);
564 for t in 0..tokens {
565 let off = t * dim;
566 let mut mean = 0.0f64;
568 for i in 0..dim {
569 mean += x[off + i] as f64;
570 }
571 mean /= dim as f64;
572 let mut var = 0.0f64;
573 for i in 0..dim {
574 let d = x[off + i] as f64 - mean;
575 var += d * d;
576 }
577 var /= dim as f64;
578 let inv = 1.0f32 / ((var as f32) + eps).sqrt();
579 let mean_f32 = mean as f32;
580 for i in 0..dim {
581 out[off + i] = (x[off + i] - mean_f32) * inv * gamma[i] + beta[i];
582 }
583 }
584 }
585
586 fn gelu(_ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
587 for i in 0..len {
590 let xi = x[i];
591 out[i] = 0.5 * xi * (1.0 + libm_erf(xi / std::f32::consts::SQRT_2));
592 }
593 }
594
595 fn alloc(len: usize) -> Self::Buffer {
596 vec![0.0f32; len]
597 }
598 fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
599 buf[..len].to_vec()
600 }
601 fn from_slice(data: &[f32]) -> Self::Buffer {
602 data.to_vec()
603 }
604}
605
/// Dot product of `a` and `b` over their common prefix.
///
/// On macOS the sum is computed by Accelerate's `vDSP_dotpr` with unit
/// strides; elsewhere by a zipped iterator. Both paths stop at the shorter
/// slice, so a length mismatch can never read past the end of either input.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_os = "macos")]
    {
        // Clamp to the shorter slice: vDSP trusts the element count blindly,
        // and the portable path below already zips to the shorter length.
        let n = a.len().min(b.len());
        let mut result = 0.0f32;
        // SAFETY: `a` and `b` are each valid for `n` contiguous `f32` reads
        // (n <= both lengths) and `result` is a live, writable `f32`.
        unsafe {
            vDSP_dotpr(a.as_ptr(), 1, b.as_ptr(), 1, &mut result, n as u64);
        }
        result
    }
    #[cfg(not(target_os = "macos"))]
    {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }
}
622
/// Applies rotary position embedding in place to token-major data.
///
/// `data` is addressed as `[tokens, heads, head_dim]`. For every head the
/// pairs `(i, i + half)` with `i < half` are rotated by the angle whose
/// cosine/sine sit at `positions[t] * half + i` in `cos`/`sin`.
#[allow(dead_code)]
fn apply_rope_impl(
    data: &mut [f32],
    tokens: usize,
    heads: usize,
    head_dim: usize,
    half: usize,
    cos: &[f32],
    sin: &[f32],
    positions: &[u32],
) {
    for t in 0..tokens {
        let trig_row = positions[t] as usize * half;
        for h in 0..heads {
            let head_start = t * heads * head_dim + h * head_dim;
            for i in 0..half {
                let (c, s) = (cos[trig_row + i], sin[trig_row + i]);
                let lo_idx = head_start + i;
                let hi_idx = head_start + half + i;
                let (x0, x1) = (data[lo_idx], data[hi_idx]);
                data[lo_idx] = x0 * c - x1 * s;
                data[hi_idx] = x1 * c + x0 * s;
            }
        }
    }
}
649
650fn cpu_attention(
651 q: &[f32],
652 k: &[f32],
653 v: &[f32],
654 out: &mut [f32],
655 batch: usize,
656 q_len: usize,
657 kv_len: usize,
658 causal: bool,
659 pos_offset: usize,
660 cfg: &AttnConfig,
661) {
662 let nh = cfg.num_heads;
663 let nkv = cfg.num_kv_heads;
664 let d = cfg.head_dim;
665 let n_rep = nh / nkv;
666 let scale = cfg.scale;
667 let kv_stride = if cfg.kv_seq_stride > 0 {
672 cfg.kv_seq_stride
673 } else {
674 kv_len
675 };
676
677 for b in 0..batch {
678 for h in 0..nh {
679 let kv_h = h / n_rep;
680 let q_off = (b * nh + h) * q_len * d;
681 let k_off = (b * nkv + kv_h) * kv_stride * d;
682 let v_off = (b * nkv + kv_h) * kv_stride * d;
683 let o_off = (b * nh + h) * q_len * d;
684
685 for qi in 0..q_len {
686 let attend_end = if causal {
687 (pos_offset + qi + 1).min(kv_len)
688 } else {
689 kv_len
690 };
691 let attend_start = if causal && cfg.sliding_window > 0 {
692 attend_end.saturating_sub(cfg.sliding_window)
693 } else {
694 0
695 };
696 let mut max_score = f32::NEG_INFINITY;
697 let mut sum_exp = 0.0f32;
698 let mut acc = vec![0.0f32; d];
699
700 for ki in attend_start..attend_end {
701 let mut dot = 0.0f32;
702 for di in 0..d {
703 dot += q[q_off + qi * d + di] * k[k_off + ki * d + di];
704 }
705 let score = dot * scale;
706 if score > max_score {
707 let correction = (max_score - score).exp();
708 for di in 0..d {
709 acc[di] *= correction;
710 }
711 sum_exp *= correction;
712 max_score = score;
713 }
714 let w = (score - max_score).exp();
715 sum_exp += w;
716 for di in 0..d {
717 acc[di] += w * v[v_off + ki * d + di];
718 }
719 }
720
721 if sum_exp > 0.0 {
722 let inv = 1.0 / sum_exp;
723 for di in 0..d {
724 out[o_off + qi * d + di] = acc[di] * inv;
725 }
726 }
727 }
728 }
729 }
730}
731
/// Polynomial approximation of the error function `erf(x)`.
///
/// Evaluates a degree-5 polynomial in `t = 1 / (1 + p * |x|)` multiplied by
/// `exp(-x^2)`, and restores the sign of `x` afterwards (the function is odd).
fn libm_erf(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    const A1: f32 = 0.254_829_6;
    const A2: f32 = 0.284_496_72;
    const A3: f32 = 1.421_413_8;
    const A4: f32 = 1.453_152_1;
    const A5: f32 = 1.061_405_4;

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + P * x);

    // Horner evaluation, innermost coefficient first.
    let mut poly = A5 * t - A4;
    poly = poly * t + A3;
    poly = poly * t - A2;
    poly = poly * t + A1;

    let y = 1.0 - poly * t * (-x * x).exp();
    sign * y
}
745
// Empty impl: `CpuBackend` overrides nothing and uses whatever default items `BackendGraph` provides.
746impl crate::backend::BackendGraph for CpuBackend {}
748
// Empty impl: `CpuBackend` overrides nothing and uses whatever default items `BackendCollective` provides.
749impl crate::backend::BackendCollective for CpuBackend {}
751
752fn cpu_dequant_gptq(
755 qweight: &[i32],
756 scales: &[f32],
757 qzeros: &[i32],
758 bits: u32,
759 group_size: usize,
760 k: usize,
761 n: usize,
762) -> Result<Vec<f32>> {
763 if bits != 4 {
764 return Err(FerrumError::unsupported(format!(
765 "CPU GPTQ: only bits=4 supported (got {bits})"
766 )));
767 }
768 let mut w = vec![0.0f32; n * k];
769 let packed_rows = k / 8;
770 for pr in 0..packed_rows {
771 for col in 0..n {
772 let packed = qweight[pr * n + col] as u32;
773 for bi in 0..8 {
774 let ki = pr * 8 + bi;
775 let q = ((packed >> (bi * 4)) & 0xF) as i32;
776 let grp = ki / group_size;
777 let scale = scales[grp * n + col];
778 let z_packed = qzeros[grp * (n / 8) + (col / 8)] as u32;
779 let zero = (((z_packed >> ((col % 8) * 4)) & 0xF) as i32) + 1;
780 let val = (q - zero) as f32 * scale;
781 w[col * k + ki] = val;
782 }
783 }
784 }
785 Ok(w)
786}
787
788impl crate::backend::BackendQuantMarlin for CpuBackend {
789 fn load_gptq(
790 qweight: &[i32],
791 scales: &[f32],
792 qzeros: &[i32],
793 _g_idx: Option<&[i32]>,
794 bias_host: Option<&[f32]>,
795 bits: u32,
796 group_size: usize,
797 k: usize,
798 n: usize,
799 ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
800 let w = cpu_dequant_gptq(qweight, scales, qzeros, bits, group_size, k, n)?;
801 Ok(Box::new(crate::quant_linear::cpu_dequant::CpuGptqLinear {
805 weight_f32: w,
806 bias: bias_host.map(|b| b.to_vec()),
807 in_features: k,
808 out_features: n,
809 }))
810 }
811 fn load_gptq_stacked(
812 qweights: &[&[i32]],
813 scales: &[&[f32]],
814 qzeros: &[&[i32]],
815 _g_idx: Option<&[i32]>,
816 bits: u32,
817 group_size: usize,
818 k: usize,
819 n_per_expert: usize,
820 ) -> Result<std::sync::Arc<dyn crate::MarlinExpertStack<Self>>> {
821 let num_experts = qweights.len();
824 if scales.len() != num_experts || qzeros.len() != num_experts {
825 return Err(FerrumError::model(format!(
826 "load_gptq_stacked: input slice lengths disagree (qw {num_experts}, sc {}, qz {})",
827 scales.len(),
828 qzeros.len()
829 )));
830 }
831 let total_n = num_experts * n_per_expert;
832 let mut all_w = Vec::with_capacity(total_n * k);
833 for ((qw_e, sc_e), qz_e) in qweights.iter().zip(scales.iter()).zip(qzeros.iter()) {
834 let w_e = cpu_dequant_gptq(qw_e, sc_e, qz_e, bits, group_size, k, n_per_expert)?;
835 all_w.extend_from_slice(&w_e);
836 }
837 let store = std::sync::Arc::new(CpuGptqStore {
838 weight_f32: all_w,
839 k,
840 n: total_n,
841 });
842 Ok(std::sync::Arc::new(
843 crate::quant_linear::cpu_marlin_stack::CpuMarlinExpertStack::new(
844 store,
845 num_experts,
846 n_per_expert,
847 k,
848 ),
849 ))
850 }
851 }
858
859#[allow(clippy::too_many_arguments)]
863pub(crate) fn cpu_gemm_gptq_with_offset_strided(
864 _ctx: &mut <CpuBackend as Backend>::Context,
865 input: &<CpuBackend as Backend>::Buffer,
866 in_row_offset: usize,
867 weight: &CpuGptqStore,
868 expert_offset: usize,
869 expert_n: usize,
870 output: &mut <CpuBackend as Backend>::Buffer,
871 out_row_offset: usize,
872 m: usize,
873 k: usize,
874) -> Result<()> {
875 if expert_offset + expert_n > weight.n {
876 return Err(FerrumError::model(format!(
877 "cpu_gemm_gptq_with_offset_strided OOB: offset {expert_offset} + n {expert_n} > stacked_n {}",
878 weight.n
879 )));
880 }
881 if k != weight.k {
882 return Err(FerrumError::model(format!(
883 "cpu_gemm_gptq_with_offset_strided k mismatch: arg {k} vs weight.k {}",
884 weight.k
885 )));
886 }
887 let in_start = in_row_offset * k;
888 let in_end = (in_row_offset + m) * k;
889 let out_start = out_row_offset * expert_n;
890 let out_end = (out_row_offset + m) * expert_n;
891 let row_start = expert_offset * k;
892 let row_end = (expert_offset + expert_n) * k;
893 let weight_slice = weight.weight_f32[row_start..row_end].to_vec();
894 let in_slice = input[in_start..in_end].to_vec();
895 let mut out_slice = vec![0.0f32; m * expert_n];
896 let mut ctx_local = ();
897 CpuBackend::gemm(
898 &mut ctx_local,
899 &in_slice,
900 &weight_slice,
901 &mut out_slice,
902 m,
903 expert_n,
904 k,
905 );
906 output[out_start..out_end].copy_from_slice(&out_slice);
907 Ok(())
908}
909
910impl crate::backend::BackendQuantGguf for CpuBackend {
911 fn load_quant(
912 kind: super::GgufQuantType,
913 bytes: &[u8],
914 n_rows: usize,
915 n_cols: usize,
916 ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
917 use super::GgufQuantType;
918 let store = match kind {
919 GgufQuantType::Q4K => {
920 let total_elems = n_rows * n_cols;
921 if total_elems % Q4_K_QK != 0 {
922 return Err(FerrumError::model(format!(
923 "load_quant Q4K: elements {total_elems} not a multiple of {Q4_K_QK}"
924 )));
925 }
926 let n_blocks = total_elems / Q4_K_QK;
927 let expected = n_blocks * Q4_K_BLOCK_BYTES;
928 if bytes.len() != expected {
929 return Err(FerrumError::model(format!(
930 "load_quant Q4K: bytes {} != expected {} ({n_blocks} × {Q4_K_BLOCK_BYTES})",
931 bytes.len(),
932 expected
933 )));
934 }
935 CpuQuantStore::Q4K {
936 weights: dequant_q4_k_cpu(bytes, n_blocks),
937 n_rows,
938 n_cols,
939 }
940 }
941 other => {
942 return Err(FerrumError::unsupported(format!(
943 "CPU load_quant: {other:?} not yet implemented"
944 )));
945 }
946 };
947 Ok(Box::new(crate::quant_linear::cpu_gguf::CpuGgufLinear {
950 store,
951 in_features: n_cols,
952 out_features: n_rows,
953 }))
954 }
955}
956
// Empty impl: `CpuBackend` overrides nothing and uses whatever default items `BackendPagedKv` provides.
957impl crate::backend::BackendPagedKv for CpuBackend {}
959
// Empty impl: `CpuBackend` overrides nothing and uses whatever default items `BackendMoeFused` provides.
960impl crate::backend::BackendMoeFused for CpuBackend {}
962
// Implements `BackendKvDtype<KvFp16>` for `CpuBackend` by choosing its two associated types.
963impl crate::backend::BackendKvDtype<crate::backend::KvFp16> for CpuBackend {
// KV storage reuses the backend's ordinary buffer type (`Vec<f32>` per the `Backend` impl in this file).
965 type KvBuffer = <Self as crate::backend::Backend>::Buffer;
// The scales type is the unit type, i.e. no scale data is carried for this dtype.
966 type KvScales = ();
967}