1use half::f16;
5
6use super::{AttnConfig, Backend};
7use ferrum_types::{FerrumError, Result};
8
/// Number of weights in one Q4_K super-block.
const Q4_K_QK: usize = 256;
/// Bytes of packed 6-bit scales and mins in one Q4_K super-block.
const Q4_K_SCALE_SIZE: usize = 12;
/// Size of one Q4_K super-block in bytes: two f16 values (`d`, `dmin`), the
/// packed scale array, and `Q4_K_QK` 4-bit quants stored two per byte.
const Q4_K_BLOCK_BYTES: usize = 4 + Q4_K_SCALE_SIZE + Q4_K_QK / 2;

/// Unpacks the 6-bit `(scale, min)` pair for sub-block `j` from the packed
/// Q4_K scale array `q`.
///
/// Sub-blocks 0..4 keep both values in the low 6 bits of `q[j]` and
/// `q[j + 4]`. Higher sub-blocks take their low 4 bits from the two nibbles
/// of `q[j + 4]`, and their top 2 bits from the high bits of `q[j - 4]`
/// (scale) and `q[j]` (min).
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    const LOW6: u8 = 0x3F;
    match j {
        0..=3 => (q[j] & LOW6, q[j + 4] & LOW6),
        _ => {
            let packed = q[j + 4];
            let scale = (packed & 0x0F) | ((q[j - 4] >> 6) << 4);
            let min = (packed >> 4) | ((q[j] >> 6) << 4);
            (scale, min)
        }
    }
}
28
29fn dequant_q4_k_cpu(bytes: &[u8], n_blocks: usize) -> Vec<f32> {
33 debug_assert_eq!(bytes.len(), n_blocks * Q4_K_BLOCK_BYTES);
34 let mut out = Vec::with_capacity(n_blocks * Q4_K_QK);
35 for b in 0..n_blocks {
36 let off = b * Q4_K_BLOCK_BYTES;
37 let d = f16::from_le_bytes([bytes[off], bytes[off + 1]]).to_f32();
38 let dmin = f16::from_le_bytes([bytes[off + 2], bytes[off + 3]]).to_f32();
39 let scales = &bytes[off + 4..off + 4 + Q4_K_SCALE_SIZE];
40 let qs = &bytes[off + 4 + Q4_K_SCALE_SIZE..off + Q4_K_BLOCK_BYTES];
41
42 let mut is = 0usize;
43 for j in (0..Q4_K_QK).step_by(64) {
44 let q_chunk = &qs[j / 2..j / 2 + 32];
45 let (sc1, mn1) = get_scale_min_k4(is, scales);
46 let d1 = d * sc1 as f32;
47 let m1 = dmin * mn1 as f32;
48 let (sc2, mn2) = get_scale_min_k4(is + 1, scales);
49 let d2 = d * sc2 as f32;
50 let m2 = dmin * mn2 as f32;
51 for q in q_chunk {
52 out.push(d1 * (q & 0xF) as f32 - m1);
53 }
54 for q in q_chunk {
55 out.push(d2 * (q >> 4) as f32 - m2);
56 }
57 is += 2;
58 }
59 }
60 out
61}
62
/// Host (CPU) implementation of the [`Backend`] trait and its companion
/// capability traits.
///
/// Buffers are plain `Vec<f32>` and the execution context is `()`. On macOS,
/// GEMM and dot products are delegated to Accelerate (`cblas_sgemm`,
/// `vDSP_dotpr`); on other targets they fall back to scalar loops.
pub struct CpuBackend;
64
// Accelerate.framework entry points used by the CPU backend on macOS.
#[cfg(target_os = "macos")]
unsafe extern "C" {
    /// CBLAS single-precision GEMM: `C = alpha * op(A) * op(B) + beta * C`.
    ///
    /// `order`, `transa` and `transb` take the raw CBLAS enum values
    /// (`CblasRowMajor = 101`, `CblasNoTrans = 111`, `CblasTrans = 112`).
    unsafe fn cblas_sgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );
    /// `*result = sum(a[i * a_stride] * b[i * b_stride])` for `i in 0..n`.
    ///
    /// The C prototype uses `vDSP_Stride` (`long`) for the strides and
    /// `vDSP_Length` (`unsigned long`) for the count. Both are 64-bit on
    /// macOS, so the strides must not be declared as 32-bit integers.
    fn vDSP_dotpr(
        a: *const f32,
        a_stride: std::ffi::c_long,
        b: *const f32,
        b_stride: std::ffi::c_long,
        result: *mut f32,
        n: std::ffi::c_ulong,
    );
}
92
/// Dense, already-dequantized GPTQ weights for the CPU backend. One store may
/// hold several experts stacked along the output dimension.
///
/// Built by `load_gptq_stacked` and consumed by
/// `cpu_gemm_gptq_with_offset_strided`.
pub struct CpuGptqStore {
    /// Dequantized weights, row-major `[n, k]`: output row `r` occupies
    /// `weight_f32[r * k..(r + 1) * k]`.
    pub weight_f32: Vec<f32>,
    /// Input features (reduction dimension) per row.
    pub k: usize,
    /// Total number of output rows across all stacked experts.
    pub n: usize,
}
100
/// Storage behind a GGUF-quantized linear layer on the CPU backend.
pub enum CpuQuantStore {
    /// Q4_K weights, dequantized to `f32` once at load time.
    Q4K {
        /// `n_rows * n_cols` values in the order `dequant_q4_k_cpu` emits
        /// them (block after block). Presumably row-major `[n_rows, n_cols]`;
        /// TODO confirm against `CpuGgufLinear`.
        weights: Vec<f32>,
        /// Output features (`out_features` of the wrapping linear layer).
        n_rows: usize,
        /// Input features (`in_features` of the wrapping linear layer).
        n_cols: usize,
    },
}
115
116impl Backend for CpuBackend {
117 type Buffer = Vec<f32>;
118 type Context = ();
119 type Timer = crate::backend::timer::CpuTimer;
123 fn make_timer() -> Self::Timer {
124 crate::backend::timer::CpuTimer::new()
125 }
126
127 fn new_context() -> Self::Context {}
128 fn sync(_ctx: &mut Self::Context) {}
129
130 fn alloc_typed(dtype: crate::backend::Dtype, n: usize) -> Self::Buffer {
134 let bytes = n * dtype.bytes_per_elem();
137 let f32_len = bytes.div_ceil(4);
138 vec![0.0f32; f32_len]
139 }
140
141 fn from_slice_typed<T: crate::backend::HostDtype>(data: &[T]) -> Self::Buffer {
144 let bytes = data.len() * std::mem::size_of::<T>();
145 let f32_len = bytes.div_ceil(4);
146 let mut out = vec![0.0f32; f32_len];
147 unsafe {
148 std::ptr::copy_nonoverlapping(
149 data.as_ptr() as *const u8,
150 out.as_mut_ptr() as *mut u8,
151 bytes,
152 );
153 }
154 out
155 }
156
157 fn write_typed<T: crate::backend::HostDtype>(
160 _ctx: &mut Self::Context,
161 dst: &mut Self::Buffer,
162 data: &[T],
163 ) {
164 let bytes = data.len() * std::mem::size_of::<T>();
165 debug_assert!(
166 bytes <= dst.len() * 4,
167 "CpuBackend::write_typed: src bytes {} > dst bytes {}",
168 bytes,
169 dst.len() * 4
170 );
171 unsafe {
172 std::ptr::copy_nonoverlapping(
173 data.as_ptr() as *const u8,
174 dst.as_mut_ptr() as *mut u8,
175 bytes,
176 );
177 }
178 }
179
180 fn fused_silu_mul_split_strided(
181 _ctx: &mut Self::Context,
182 gate_up: &Self::Buffer,
183 in_row_offset: usize,
184 out: &mut Self::Buffer,
185 out_row_offset: usize,
186 tokens: usize,
187 intermediate: usize,
188 ) {
189 let in_per_row = 2 * intermediate;
190 let in_start = in_row_offset * in_per_row;
191 let out_start = out_row_offset * intermediate;
192 for r in 0..tokens {
193 for c in 0..intermediate {
194 let g = gate_up[in_start + r * in_per_row + c];
195 let u = gate_up[in_start + r * in_per_row + intermediate + c];
196 let silu = g / (1.0 + (-g).exp());
197 out[out_start + r * intermediate + c] = silu * u;
198 }
199 }
200 }
201
202 fn gemm(
203 _ctx: &mut Self::Context,
204 a: &Self::Buffer,
205 b: &Self::Buffer,
206 out: &mut Self::Buffer,
207 m: usize,
208 n: usize,
209 k: usize,
210 ) {
211 assert!(
212 a.len() >= m * k,
213 "gemm: a too small len={} m={m} k={k}",
214 a.len()
215 );
216 assert!(
217 b.len() >= n * k,
218 "gemm: b too small len={} n={n} k={k}",
219 b.len()
220 );
221 assert!(
222 out.len() >= m * n,
223 "gemm: out too small len={} m={m} n={n}",
224 out.len()
225 );
226 #[cfg(target_os = "macos")]
227 unsafe {
228 cblas_sgemm(
229 101,
230 111,
231 112,
232 m as i32,
233 n as i32,
234 k as i32,
235 1.0,
236 a.as_ptr(),
237 k as i32,
238 b.as_ptr(),
239 k as i32,
240 0.0,
241 out.as_mut_ptr(),
242 n as i32,
243 );
244 }
245 #[cfg(not(target_os = "macos"))]
246 {
247 for i in 0..m {
248 for j in 0..n {
249 let mut sum = 0.0f64;
250 for p in 0..k {
251 sum += a[i * k + p] as f64 * b[j * k + p] as f64;
252 }
253 out[i * n + j] = sum as f32;
254 }
255 }
256 }
257 }
258
259 fn rms_norm(
260 _ctx: &mut Self::Context,
261 x: &Self::Buffer,
262 w: &Self::Buffer,
263 eps: f32,
264 out: &mut Self::Buffer,
265 tokens: usize,
266 dim: usize,
267 ) {
268 for t in 0..tokens {
269 let row = &x[t * dim..(t + 1) * dim];
270 let o = &mut out[t * dim..(t + 1) * dim];
271 let sum_sq = dot_product(row, row);
272 let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
273 for i in 0..dim {
274 o[i] = row[i] * inv * w[i];
275 }
276 }
277 }
278
279 fn fused_add_rms_norm(
280 _ctx: &mut Self::Context,
281 residual: &mut Self::Buffer,
282 x: &Self::Buffer,
283 w: &Self::Buffer,
284 eps: f32,
285 out: &mut Self::Buffer,
286 tokens: usize,
287 dim: usize,
288 ) {
289 for t in 0..tokens {
290 let off = t * dim;
291 for i in 0..dim {
292 residual[off + i] += x[off + i];
293 }
294 let row = &residual[off..off + dim];
295 let o = &mut out[off..off + dim];
296 let sum_sq = dot_product(row, row);
297 let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
298 for i in 0..dim {
299 o[i] = row[i] * inv * w[i];
300 }
301 }
302 }
303
304 fn flash_attention(
305 _ctx: &mut Self::Context,
306 q: &Self::Buffer,
307 k: &Self::Buffer,
308 v: &Self::Buffer,
309 out: &mut Self::Buffer,
310 batch: usize,
311 q_len: usize,
312 kv_len: usize,
313 pos_offset: usize,
314 cfg: &AttnConfig,
315 ) {
316 cpu_attention(
317 q, k, v, out, batch, q_len, kv_len, cfg.causal, pos_offset, cfg,
318 );
319 }
320
321 fn copy_slice(
322 _ctx: &mut Self::Context,
323 src: &Self::Buffer,
324 src_offset: usize,
325 dst: &mut Self::Buffer,
326 dst_offset: usize,
327 len: usize,
328 ) {
329 dst[dst_offset..dst_offset + len].copy_from_slice(&src[src_offset..src_offset + len]);
330 }
331
332 fn embedding_lookup(
333 _ctx: &mut Self::Context,
334 table: &Self::Buffer,
335 ids: &[u32],
336 out: &mut Self::Buffer,
337 dim: usize,
338 ) {
339 for (i, &id) in ids.iter().enumerate() {
340 let src = id as usize * dim;
341 out[i * dim..(i + 1) * dim].copy_from_slice(&table[src..src + dim]);
342 }
343 }
344
345 fn split_qkv(
346 _ctx: &mut Self::Context,
347 qkv: &Self::Buffer,
348 q: &mut Self::Buffer,
349 k: &mut Self::Buffer,
350 v: &mut Self::Buffer,
351 tokens: usize,
352 q_dim: usize,
353 kv_dim: usize,
354 ) {
355 let qkv_dim = q_dim + 2 * kv_dim;
356 for t in 0..tokens {
357 let base = t * qkv_dim;
358 q[t * q_dim..(t + 1) * q_dim].copy_from_slice(&qkv[base..base + q_dim]);
359 k[t * kv_dim..(t + 1) * kv_dim]
360 .copy_from_slice(&qkv[base + q_dim..base + q_dim + kv_dim]);
361 v[t * kv_dim..(t + 1) * kv_dim]
362 .copy_from_slice(&qkv[base + q_dim + kv_dim..base + qkv_dim]);
363 }
364 }
365
366 fn fused_silu_mul_split(
367 _ctx: &mut Self::Context,
368 gate_up: &Self::Buffer,
369 out: &mut Self::Buffer,
370 tokens: usize,
371 im: usize,
372 ) {
373 for t in 0..tokens {
374 for i in 0..im {
375 let g = gate_up[t * 2 * im + i];
376 let u = gate_up[t * 2 * im + im + i];
377 out[t * im + i] = (g / (1.0 + (-g).exp())) * u;
378 }
379 }
380 }
381
382 fn qk_norm_rope(
383 _ctx: &mut Self::Context,
384 input: &Self::Buffer,
385 norm_w: &Self::Buffer,
386 cos: &Self::Buffer,
387 sin: &Self::Buffer,
388 output: &mut Self::Buffer,
389 tokens: usize,
390 heads: usize,
391 head_dim: usize,
392 pos_offset: usize,
393 eps: f32,
394 mode: i32,
395 ) {
396 let half = head_dim / 2;
397 let cos_len = cos.len();
398 let sin_len = sin.len();
399 debug_assert_eq!(cos_len, sin_len);
400
401 for t in 0..tokens {
402 let pos = pos_offset + t;
403 for h in 0..heads {
404 let src_off = (t * heads + h) * head_dim;
406 let dst_off = (h * tokens + t) * head_dim;
408
409 if mode == 0 {
411 for i in 0..head_dim {
412 output[dst_off + i] = input[src_off + i];
413 }
414 continue;
415 }
416
417 let scale = if mode == 1 {
419 let mut sum_sq = 0.0f32;
420 for i in 0..head_dim {
421 sum_sq += input[src_off + i] * input[src_off + i];
422 }
423 1.0f32 / (sum_sq / head_dim as f32 + eps).sqrt()
424 } else {
425 1.0
426 };
427
428 for i in 0..half {
430 let (x0_raw, x1_raw) = (input[src_off + i], input[src_off + i + half]);
431 let (x0, x1) = if mode == 1 {
432 (
433 x0_raw * scale * norm_w[i],
434 x1_raw * scale * norm_w[i + half],
435 )
436 } else {
437 (x0_raw, x1_raw)
438 };
439 let c = cos[pos * half + i];
440 let s = sin[pos * half + i];
441 output[dst_off + i] = x0 * c - x1 * s;
442 output[dst_off + i + half] = x1 * c + x0 * s;
443 }
444 }
445 }
446 }
447
448 fn kv_cache_append_head_major(
449 _ctx: &mut Self::Context,
450 cache_k: &mut Self::Buffer,
451 cache_v: &mut Self::Buffer,
452 cache_len: usize,
453 cache_capacity: usize,
454 new_k_head_major: &Self::Buffer,
455 new_v_head_major: &Self::Buffer,
456 new_tokens: usize,
457 nkv: usize,
458 hd: usize,
459 ) {
460 debug_assert!(cache_len + new_tokens <= cache_capacity);
461 debug_assert_eq!(cache_k.len(), nkv * cache_capacity * hd);
462 debug_assert_eq!(cache_v.len(), nkv * cache_capacity * hd);
463 debug_assert!(new_k_head_major.len() >= nkv * new_tokens * hd);
468 debug_assert!(new_v_head_major.len() >= nkv * new_tokens * hd);
469
470 for h in 0..nkv {
471 let dst_base = h * cache_capacity * hd + cache_len * hd;
472 let src_base = h * new_tokens * hd;
473 cache_k[dst_base..dst_base + new_tokens * hd]
474 .copy_from_slice(&new_k_head_major[src_base..src_base + new_tokens * hd]);
475 cache_v[dst_base..dst_base + new_tokens * hd]
476 .copy_from_slice(&new_v_head_major[src_base..src_base + new_tokens * hd]);
477 }
478 }
479
480 fn transpose_head_to_token(
481 _ctx: &mut Self::Context,
482 src: &Self::Buffer,
483 dst: &mut Self::Buffer,
484 tokens: usize,
485 heads: usize,
486 dim: usize,
487 ) {
488 for h in 0..heads {
489 for t in 0..tokens {
490 let s = (h * tokens + t) * dim;
491 let d = (t * heads + h) * dim;
492 dst[d..d + dim].copy_from_slice(&src[s..s + dim]);
493 }
494 }
495 }
496
497 fn add_inplace(
498 _ctx: &mut Self::Context,
499 residual: &mut Self::Buffer,
500 x: &Self::Buffer,
501 len: usize,
502 ) {
503 for i in 0..len {
504 residual[i] += x[i];
505 }
506 }
507
508 fn scaled_add_inplace(
509 _ctx: &mut Self::Context,
510 dst: &mut Self::Buffer,
511 src: &Self::Buffer,
512 scale: f32,
513 len: usize,
514 ) {
515 for i in 0..len {
516 dst[i] += scale * src[i];
517 }
518 }
519
520 fn add_bias(
521 _ctx: &mut Self::Context,
522 data: &mut Self::Buffer,
523 bias: &Self::Buffer,
524 rows: usize,
525 cols: usize,
526 ) {
527 debug_assert_eq!(bias.len(), cols);
528 for r in 0..rows {
529 let off = r * cols;
530 for c in 0..cols {
531 data[off + c] += bias[c];
532 }
533 }
534 }
535
536 fn layer_norm(
537 _ctx: &mut Self::Context,
538 x: &Self::Buffer,
539 gamma: &Self::Buffer,
540 beta: &Self::Buffer,
541 eps: f32,
542 out: &mut Self::Buffer,
543 tokens: usize,
544 dim: usize,
545 ) {
546 debug_assert_eq!(gamma.len(), dim);
547 debug_assert_eq!(beta.len(), dim);
548 for t in 0..tokens {
549 let off = t * dim;
550 let mut mean = 0.0f64;
552 for i in 0..dim {
553 mean += x[off + i] as f64;
554 }
555 mean /= dim as f64;
556 let mut var = 0.0f64;
557 for i in 0..dim {
558 let d = x[off + i] as f64 - mean;
559 var += d * d;
560 }
561 var /= dim as f64;
562 let inv = 1.0f32 / ((var as f32) + eps).sqrt();
563 let mean_f32 = mean as f32;
564 for i in 0..dim {
565 out[off + i] = (x[off + i] - mean_f32) * inv * gamma[i] + beta[i];
566 }
567 }
568 }
569
570 fn gelu(_ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
571 for i in 0..len {
574 let xi = x[i];
575 out[i] = 0.5 * xi * (1.0 + libm_erf(xi / std::f32::consts::SQRT_2));
576 }
577 }
578
579 fn alloc(len: usize) -> Self::Buffer {
580 vec![0.0f32; len]
581 }
582 fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
583 buf[..len].to_vec()
584 }
585 fn from_slice(data: &[f32]) -> Self::Buffer {
586 data.to_vec()
587 }
588}
589
/// Inner product of `a` and `b` over their common prefix.
///
/// Only the first `min(a.len(), b.len())` elements are used, matching the
/// `zip` semantics of the portable path. On macOS the work goes to
/// Accelerate's `vDSP_dotpr`, which reads `n` elements from *both* pointers,
/// so clamping `n` there is required for memory safety.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());
    let (a, b) = (&a[..n], &b[..n]);
    #[cfg(target_os = "macos")]
    {
        if n == 0 {
            return 0.0;
        }
        let mut result = 0.0f32;
        // SAFETY: `a` and `b` each hold exactly `n` contiguous f32s (stride 1),
        // and `result` is a valid, writable f32.
        unsafe {
            vDSP_dotpr(a.as_ptr(), 1, b.as_ptr(), 1, &mut result, n as u64);
        }
        result
    }
    #[cfg(not(target_os = "macos"))]
    {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }
}
606
/// Applies rotary position embedding in place to token-major data laid out
/// as `[tokens, heads, head_dim]`.
///
/// Element `i` of each head is rotated together with element `i + half` by
/// the angle whose cosine/sine are `cos`/`sin[positions[t] * half + i]`.
#[allow(dead_code)]
fn apply_rope_impl(
    data: &mut [f32],
    tokens: usize,
    heads: usize,
    head_dim: usize,
    half: usize,
    cos: &[f32],
    sin: &[f32],
    positions: &[u32],
) {
    let token_stride = heads * head_dim;
    for t in 0..tokens {
        let trig_start = positions[t] as usize * half;
        for h in 0..heads {
            let lo = t * token_stride + h * head_dim;
            let hi = lo + half;
            for i in 0..half {
                let (c, s) = (cos[trig_start + i], sin[trig_start + i]);
                let (a, b) = (data[lo + i], data[hi + i]);
                data[lo + i] = a * c - b * s;
                data[hi + i] = b * c + a * s;
            }
        }
    }
}
633
634fn cpu_attention(
635 q: &[f32],
636 k: &[f32],
637 v: &[f32],
638 out: &mut [f32],
639 batch: usize,
640 q_len: usize,
641 kv_len: usize,
642 causal: bool,
643 pos_offset: usize,
644 cfg: &AttnConfig,
645) {
646 let nh = cfg.num_heads;
647 let nkv = cfg.num_kv_heads;
648 let d = cfg.head_dim;
649 let n_rep = nh / nkv;
650 let scale = cfg.scale;
651 let kv_stride = if cfg.kv_seq_stride > 0 {
656 cfg.kv_seq_stride
657 } else {
658 kv_len
659 };
660
661 for b in 0..batch {
662 for h in 0..nh {
663 let kv_h = h / n_rep;
664 let q_off = (b * nh + h) * q_len * d;
665 let k_off = (b * nkv + kv_h) * kv_stride * d;
666 let v_off = (b * nkv + kv_h) * kv_stride * d;
667 let o_off = (b * nh + h) * q_len * d;
668
669 for qi in 0..q_len {
670 let attend_end = if causal {
671 (pos_offset + qi + 1).min(kv_len)
672 } else {
673 kv_len
674 };
675 let attend_start = if causal && cfg.sliding_window > 0 {
676 attend_end.saturating_sub(cfg.sliding_window)
677 } else {
678 0
679 };
680 let mut max_score = f32::NEG_INFINITY;
681 let mut sum_exp = 0.0f32;
682 let mut acc = vec![0.0f32; d];
683
684 for ki in attend_start..attend_end {
685 let mut dot = 0.0f32;
686 for di in 0..d {
687 dot += q[q_off + qi * d + di] * k[k_off + ki * d + di];
688 }
689 let score = dot * scale;
690 if score > max_score {
691 let correction = (max_score - score).exp();
692 for di in 0..d {
693 acc[di] *= correction;
694 }
695 sum_exp *= correction;
696 max_score = score;
697 }
698 let w = (score - max_score).exp();
699 sum_exp += w;
700 for di in 0..d {
701 acc[di] += w * v[v_off + ki * d + di];
702 }
703 }
704
705 if sum_exp > 0.0 {
706 let inv = 1.0 / sum_exp;
707 for di in 0..d {
708 out[o_off + qi * d + di] = acc[di] * inv;
709 }
710 }
711 }
712 }
713 }
714}
715
/// Approximates the error function `erf(x)` in `f32`.
///
/// Despite its name this does not call libm. It evaluates the 5-term
/// Abramowitz & Stegun 7.1.26 rational approximation on `|x|` and mirrors
/// the result for negative inputs.
fn libm_erf(x: f32) -> f32 {
    // Coefficients a5..a1, highest power first, for Horner evaluation.
    const COEFFS: [f32; 5] = [
        1.061_405_4,
        -1.453_152_1,
        1.421_413_8,
        -0.284_496_72,
        0.254_829_6,
    ];
    const P: f32 = 0.3275911;

    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);
    let poly = COEFFS[1..]
        .iter()
        .fold(COEFFS[0], |acc, &c| acc * t + c);
    let y = 1.0 - poly * t * (-ax * ax).exp();
    if x < 0.0 {
        -y
    } else {
        y
    }
}
729
// Empty impls: the CPU backend takes every method of these traits from the
// trait defaults (presumably no-op or "unsupported" stubs; confirm in
// `crate::backend`).
impl crate::backend::BackendGraph for CpuBackend {}

impl crate::backend::BackendCollective for CpuBackend {}
735
736fn cpu_dequant_gptq(
739 qweight: &[i32],
740 scales: &[f32],
741 qzeros: &[i32],
742 bits: u32,
743 group_size: usize,
744 k: usize,
745 n: usize,
746) -> Result<Vec<f32>> {
747 if bits != 4 {
748 return Err(FerrumError::unsupported(format!(
749 "CPU GPTQ: only bits=4 supported (got {bits})"
750 )));
751 }
752 let mut w = vec![0.0f32; n * k];
753 let packed_rows = k / 8;
754 for pr in 0..packed_rows {
755 for col in 0..n {
756 let packed = qweight[pr * n + col] as u32;
757 for bi in 0..8 {
758 let ki = pr * 8 + bi;
759 let q = ((packed >> (bi * 4)) & 0xF) as i32;
760 let grp = ki / group_size;
761 let scale = scales[grp * n + col];
762 let z_packed = qzeros[grp * (n / 8) + (col / 8)] as u32;
763 let zero = (((z_packed >> ((col % 8) * 4)) & 0xF) as i32) + 1;
764 let val = (q - zero) as f32 * scale;
765 w[col * k + ki] = val;
766 }
767 }
768 }
769 Ok(w)
770}
771
772impl crate::backend::BackendQuantMarlin for CpuBackend {
773 fn load_gptq(
774 qweight: &[i32],
775 scales: &[f32],
776 qzeros: &[i32],
777 _g_idx: Option<&[i32]>,
778 bias_host: Option<&[f32]>,
779 bits: u32,
780 group_size: usize,
781 k: usize,
782 n: usize,
783 ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
784 let w = cpu_dequant_gptq(qweight, scales, qzeros, bits, group_size, k, n)?;
785 Ok(Box::new(crate::quant_linear::cpu_dequant::CpuGptqLinear {
789 weight_f32: w,
790 bias: bias_host.map(|b| b.to_vec()),
791 in_features: k,
792 out_features: n,
793 }))
794 }
795 fn load_gptq_stacked(
796 qweights: &[&[i32]],
797 scales: &[&[f32]],
798 qzeros: &[&[i32]],
799 _g_idx: Option<&[i32]>,
800 bits: u32,
801 group_size: usize,
802 k: usize,
803 n_per_expert: usize,
804 ) -> Result<std::sync::Arc<dyn crate::MarlinExpertStack<Self>>> {
805 let num_experts = qweights.len();
808 if scales.len() != num_experts || qzeros.len() != num_experts {
809 return Err(FerrumError::model(format!(
810 "load_gptq_stacked: input slice lengths disagree (qw {num_experts}, sc {}, qz {})",
811 scales.len(),
812 qzeros.len()
813 )));
814 }
815 let total_n = num_experts * n_per_expert;
816 let mut all_w = Vec::with_capacity(total_n * k);
817 for ((qw_e, sc_e), qz_e) in qweights.iter().zip(scales.iter()).zip(qzeros.iter()) {
818 let w_e = cpu_dequant_gptq(qw_e, sc_e, qz_e, bits, group_size, k, n_per_expert)?;
819 all_w.extend_from_slice(&w_e);
820 }
821 let store = std::sync::Arc::new(CpuGptqStore {
822 weight_f32: all_w,
823 k,
824 n: total_n,
825 });
826 Ok(std::sync::Arc::new(
827 crate::quant_linear::cpu_marlin_stack::CpuMarlinExpertStack::new(
828 store,
829 num_experts,
830 n_per_expert,
831 k,
832 ),
833 ))
834 }
835 }
842
843#[allow(clippy::too_many_arguments)]
847pub(crate) fn cpu_gemm_gptq_with_offset_strided(
848 _ctx: &mut <CpuBackend as Backend>::Context,
849 input: &<CpuBackend as Backend>::Buffer,
850 in_row_offset: usize,
851 weight: &CpuGptqStore,
852 expert_offset: usize,
853 expert_n: usize,
854 output: &mut <CpuBackend as Backend>::Buffer,
855 out_row_offset: usize,
856 m: usize,
857 k: usize,
858) -> Result<()> {
859 if expert_offset + expert_n > weight.n {
860 return Err(FerrumError::model(format!(
861 "cpu_gemm_gptq_with_offset_strided OOB: offset {expert_offset} + n {expert_n} > stacked_n {}",
862 weight.n
863 )));
864 }
865 if k != weight.k {
866 return Err(FerrumError::model(format!(
867 "cpu_gemm_gptq_with_offset_strided k mismatch: arg {k} vs weight.k {}",
868 weight.k
869 )));
870 }
871 let in_start = in_row_offset * k;
872 let in_end = (in_row_offset + m) * k;
873 let out_start = out_row_offset * expert_n;
874 let out_end = (out_row_offset + m) * expert_n;
875 let row_start = expert_offset * k;
876 let row_end = (expert_offset + expert_n) * k;
877 let weight_slice = weight.weight_f32[row_start..row_end].to_vec();
878 let in_slice = input[in_start..in_end].to_vec();
879 let mut out_slice = vec![0.0f32; m * expert_n];
880 let mut ctx_local = ();
881 CpuBackend::gemm(
882 &mut ctx_local,
883 &in_slice,
884 &weight_slice,
885 &mut out_slice,
886 m,
887 expert_n,
888 k,
889 );
890 output[out_start..out_end].copy_from_slice(&out_slice);
891 Ok(())
892}
893
894impl crate::backend::BackendQuantGguf for CpuBackend {
895 fn load_quant(
896 kind: super::GgufQuantType,
897 bytes: &[u8],
898 n_rows: usize,
899 n_cols: usize,
900 ) -> Result<Box<dyn crate::Linear<Self> + Send + Sync>> {
901 use super::GgufQuantType;
902 let store = match kind {
903 GgufQuantType::Q4K => {
904 let total_elems = n_rows * n_cols;
905 if total_elems % Q4_K_QK != 0 {
906 return Err(FerrumError::model(format!(
907 "load_quant Q4K: elements {total_elems} not a multiple of {Q4_K_QK}"
908 )));
909 }
910 let n_blocks = total_elems / Q4_K_QK;
911 let expected = n_blocks * Q4_K_BLOCK_BYTES;
912 if bytes.len() != expected {
913 return Err(FerrumError::model(format!(
914 "load_quant Q4K: bytes {} != expected {} ({n_blocks} × {Q4_K_BLOCK_BYTES})",
915 bytes.len(),
916 expected
917 )));
918 }
919 CpuQuantStore::Q4K {
920 weights: dequant_q4_k_cpu(bytes, n_blocks),
921 n_rows,
922 n_cols,
923 }
924 }
925 other => {
926 return Err(FerrumError::unsupported(format!(
927 "CPU load_quant: {other:?} not yet implemented"
928 )));
929 }
930 };
931 Ok(Box::new(crate::quant_linear::cpu_gguf::CpuGgufLinear {
934 store,
935 in_features: n_cols,
936 out_features: n_rows,
937 }))
938 }
939}
940
// Empty impls: these capabilities use the traits' default methods on the CPU
// backend (presumably unsupported / fallback stubs; confirm in
// `crate::backend`).
impl crate::backend::BackendPagedKv for CpuBackend {}

impl crate::backend::BackendMoeFused for CpuBackend {}

// `KvFp16`-tagged KV cache: the cache is the backend's ordinary `Vec<f32>`
// buffer and carries no per-block scales. NOTE(review): whether fp16 values
// are byte-packed into it or stored widened as f32 is not visible here.
impl crate::backend::BackendKvDtype<crate::backend::KvFp16> for CpuBackend {
    type KvBuffer = <Self as crate::backend::Backend>::Buffer;
    type KvScales = ();
}