1use half::f16;
5
6use super::{AttnConfig, Backend};
7use ferrum_types::{FerrumError, Result};
8
/// Number of weights decoded from one Q4_K block.
const Q4_K_QK: usize = 256;
/// Bytes occupied by the packed per-sub-block scales and minimums.
const Q4_K_SCALE_SIZE: usize = 12;
/// Encoded size of one block: 4 header bytes (two little-endian f16 values),
/// the packed scale table, and one 4-bit nibble per weight.
const Q4_K_BLOCK_BYTES: usize = 4 + Q4_K_SCALE_SIZE + Q4_K_QK / 2;

/// Unpacks the 6-bit `(scale, min)` pair for sub-block `j` from the packed
/// scale table `q`.
///
/// For `j < 4` both values sit in the low six bits of `q[j]` and `q[j + 4]`.
/// For larger `j` the low four bits of each value come from the two nibbles of
/// `q[j + 4]`, and the high two bits come from the top two bits of `q[j - 4]`
/// (scale) and `q[j]` (min).
///
/// Panics if an accessed index is outside `q`.
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    match j {
        0..=3 => (q[j] & 63, q[j + 4] & 63),
        _ => {
            let packed = q[j + 4];
            let scale = (packed & 0xF) | ((q[j - 4] >> 6) << 4);
            let min = (packed >> 4) | ((q[j] >> 6) << 4);
            (scale, min)
        }
    }
}
28
29fn dequant_q4_k_cpu(bytes: &[u8], n_blocks: usize) -> Vec<f32> {
33 debug_assert_eq!(bytes.len(), n_blocks * Q4_K_BLOCK_BYTES);
34 let mut out = Vec::with_capacity(n_blocks * Q4_K_QK);
35 for b in 0..n_blocks {
36 let off = b * Q4_K_BLOCK_BYTES;
37 let d = f16::from_le_bytes([bytes[off], bytes[off + 1]]).to_f32();
38 let dmin = f16::from_le_bytes([bytes[off + 2], bytes[off + 3]]).to_f32();
39 let scales = &bytes[off + 4..off + 4 + Q4_K_SCALE_SIZE];
40 let qs = &bytes[off + 4 + Q4_K_SCALE_SIZE..off + Q4_K_BLOCK_BYTES];
41
42 let mut is = 0usize;
43 for j in (0..Q4_K_QK).step_by(64) {
44 let q_chunk = &qs[j / 2..j / 2 + 32];
45 let (sc1, mn1) = get_scale_min_k4(is, scales);
46 let d1 = d * sc1 as f32;
47 let m1 = dmin * mn1 as f32;
48 let (sc2, mn2) = get_scale_min_k4(is + 1, scales);
49 let d2 = d * sc2 as f32;
50 let m2 = dmin * mn2 as f32;
51 for q in q_chunk {
52 out.push(d1 * (q & 0xF) as f32 - m1);
53 }
54 for q in q_chunk {
55 out.push(d2 * (q >> 4) as f32 - m2);
56 }
57 is += 2;
58 }
59 }
60 out
61}
62
/// Marker type for the CPU backend; it carries no data of its own.
pub struct CpuBackend;
64
// Foreign routines used on macOS only (see the `cfg` gate).
#[cfg(target_os = "macos")]
unsafe extern "C" {
    /// Single-precision general matrix multiply (CBLAS interface).
    ///
    /// # Safety
    /// `a`, `b` and `c` must point to buffers large enough for the given
    /// dimensions and leading dimensions.
    unsafe fn cblas_sgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );

    /// Single-precision dot product of two strided vectors (vDSP).
    ///
    /// The strides are declared as C `long` and the length as C `unsigned long`,
    /// matching the `vDSP_Stride` / `vDSP_Length` typedefs; a 32-bit stride would
    /// not match the callee's 64-bit parameter on macOS targets.
    ///
    /// # Safety
    /// `a` and `b` must each be readable for `n` elements at the given stride,
    /// and `result` must be valid for one `f32` write.
    unsafe fn vDSP_dotpr(
        a: *const f32,
        a_stride: std::ffi::c_long,
        b: *const f32,
        b_stride: std::ffi::c_long,
        result: *mut f32,
        n: std::ffi::c_ulong,
    );
}
92
/// GPTQ weight held by the CPU backend in dense f32 form.
pub struct CpuGptqStore {
    /// Dense weight values; passed to `gemm` as the `b` operand.
    pub weight_f32: Vec<f32>,
    /// Passed to `gemm` as the reduction length `k`.
    pub k: usize,
    /// Passed to `gemm` as the output width `n`.
    pub n: usize,
}
100
/// Quantized weight formats the CPU backend can hold.
pub enum CpuQuantStore {
    /// A Q4_K tensor already expanded to f32.
    Q4K {
        /// Dequantized values.
        weights: Vec<f32>,
        /// Row count; `gemm_quant` uses it as the output width `n`.
        n_rows: usize,
        /// Column count; `gemm_quant` uses it as the reduction length `k`.
        n_cols: usize,
    },
}
115
116impl Backend for CpuBackend {
117 type Buffer = Vec<f32>;
118 type Context = ();
119 type GptqStore = CpuGptqStore;
120 type QuantStore = CpuQuantStore;
121
122 fn new_context() -> Self::Context {}
123 fn sync(_ctx: &mut Self::Context) {}
124
125 fn load_gptq(
126 qweight: &[i32],
127 scales: &[f32],
128 qzeros: &[i32],
129 _g_idx: Option<&[i32]>,
130 bits: u32,
131 group_size: usize,
132 k: usize,
133 n: usize,
134 ) -> Result<Self::GptqStore> {
135 if bits != 4 {
136 return Err(FerrumError::unsupported(format!(
137 "CPU GPTQ: only bits=4 supported (got {bits})"
138 )));
139 }
140 let num_groups = k / group_size;
141 let mut w = vec![0.0f32; n * k];
145 let packed_rows = k / 8;
146 for pr in 0..packed_rows {
147 for col in 0..n {
148 let packed = qweight[pr * n + col] as u32;
149 for bi in 0..8 {
150 let ki = pr * 8 + bi;
151 let q = ((packed >> (bi * 4)) & 0xF) as i32;
152 let grp = ki / group_size;
153 let scale = scales[grp * n + col];
154 let z_packed = qzeros[grp * (n / 8) + (col / 8)] as u32;
156 let zero = (((z_packed >> ((col % 8) * 4)) & 0xF) as i32) + 1;
157 let val = (q - zero) as f32 * scale;
158 w[col * k + ki] = val;
159 }
160 }
161 }
162 let _ = num_groups; Ok(CpuGptqStore {
164 weight_f32: w,
165 k,
166 n,
167 })
168 }
169
170 fn gemm_gptq(
171 ctx: &mut Self::Context,
172 a: &Self::Buffer,
173 weight: &Self::GptqStore,
174 out: &mut Self::Buffer,
175 m: usize,
176 ) -> Result<()> {
177 Self::gemm(ctx, a, &weight.weight_f32, out, m, weight.n, weight.k);
180 Ok(())
181 }
182
183 fn load_quant(
184 kind: super::GgufQuantType,
185 bytes: &[u8],
186 n_rows: usize,
187 n_cols: usize,
188 ) -> Result<Self::QuantStore> {
189 use super::GgufQuantType;
190 match kind {
191 GgufQuantType::Q4K => {
192 let total_elems = n_rows * n_cols;
193 if total_elems % Q4_K_QK != 0 {
194 return Err(FerrumError::model(format!(
195 "load_quant Q4K: elements {total_elems} not a multiple of {Q4_K_QK}"
196 )));
197 }
198 let n_blocks = total_elems / Q4_K_QK;
199 let expected = n_blocks * Q4_K_BLOCK_BYTES;
200 if bytes.len() != expected {
201 return Err(FerrumError::model(format!(
202 "load_quant Q4K: bytes {} != expected {} ({n_blocks} × {Q4_K_BLOCK_BYTES})",
203 bytes.len(),
204 expected
205 )));
206 }
207 Ok(CpuQuantStore::Q4K {
208 weights: dequant_q4_k_cpu(bytes, n_blocks),
209 n_rows,
210 n_cols,
211 })
212 }
213 other => Err(FerrumError::unsupported(format!(
214 "CPU load_quant: {other:?} not yet implemented"
215 ))),
216 }
217 }
218
219 fn gemm_quant(
220 ctx: &mut Self::Context,
221 a: &Self::Buffer,
222 weight: &Self::QuantStore,
223 out: &mut Self::Buffer,
224 m: usize,
225 ) -> Result<()> {
226 match weight {
227 CpuQuantStore::Q4K {
228 weights,
229 n_rows,
230 n_cols,
231 } => {
232 Self::gemm(ctx, a, weights, out, m, *n_rows, *n_cols);
233 Ok(())
234 }
235 }
236 }
237
238 fn gemm(
239 _ctx: &mut Self::Context,
240 a: &Self::Buffer,
241 b: &Self::Buffer,
242 out: &mut Self::Buffer,
243 m: usize,
244 n: usize,
245 k: usize,
246 ) {
247 assert!(
248 a.len() >= m * k,
249 "gemm: a too small len={} m={m} k={k}",
250 a.len()
251 );
252 assert!(
253 b.len() >= n * k,
254 "gemm: b too small len={} n={n} k={k}",
255 b.len()
256 );
257 assert!(
258 out.len() >= m * n,
259 "gemm: out too small len={} m={m} n={n}",
260 out.len()
261 );
262 #[cfg(target_os = "macos")]
263 unsafe {
264 cblas_sgemm(
265 101,
266 111,
267 112,
268 m as i32,
269 n as i32,
270 k as i32,
271 1.0,
272 a.as_ptr(),
273 k as i32,
274 b.as_ptr(),
275 k as i32,
276 0.0,
277 out.as_mut_ptr(),
278 n as i32,
279 );
280 }
281 #[cfg(not(target_os = "macos"))]
282 {
283 for i in 0..m {
284 for j in 0..n {
285 let mut sum = 0.0f64;
286 for p in 0..k {
287 sum += a[i * k + p] as f64 * b[j * k + p] as f64;
288 }
289 out[i * n + j] = sum as f32;
290 }
291 }
292 }
293 }
294
295 fn rms_norm(
296 _ctx: &mut Self::Context,
297 x: &Self::Buffer,
298 w: &Self::Buffer,
299 eps: f32,
300 out: &mut Self::Buffer,
301 tokens: usize,
302 dim: usize,
303 ) {
304 for t in 0..tokens {
305 let row = &x[t * dim..(t + 1) * dim];
306 let o = &mut out[t * dim..(t + 1) * dim];
307 let sum_sq = dot_product(row, row);
308 let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
309 for i in 0..dim {
310 o[i] = row[i] * inv * w[i];
311 }
312 }
313 }
314
315 fn fused_add_rms_norm(
316 _ctx: &mut Self::Context,
317 residual: &mut Self::Buffer,
318 x: &Self::Buffer,
319 w: &Self::Buffer,
320 eps: f32,
321 out: &mut Self::Buffer,
322 tokens: usize,
323 dim: usize,
324 ) {
325 for t in 0..tokens {
326 let off = t * dim;
327 for i in 0..dim {
328 residual[off + i] += x[off + i];
329 }
330 let row = &residual[off..off + dim];
331 let o = &mut out[off..off + dim];
332 let sum_sq = dot_product(row, row);
333 let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
334 for i in 0..dim {
335 o[i] = row[i] * inv * w[i];
336 }
337 }
338 }
339
340 fn flash_attention(
341 _ctx: &mut Self::Context,
342 q: &Self::Buffer,
343 k: &Self::Buffer,
344 v: &Self::Buffer,
345 out: &mut Self::Buffer,
346 batch: usize,
347 q_len: usize,
348 kv_len: usize,
349 pos_offset: usize,
350 cfg: &AttnConfig,
351 ) {
352 cpu_attention(
353 q, k, v, out, batch, q_len, kv_len, cfg.causal, pos_offset, cfg,
354 );
355 }
356
357 fn copy_slice(
358 _ctx: &mut Self::Context,
359 src: &Self::Buffer,
360 src_offset: usize,
361 dst: &mut Self::Buffer,
362 dst_offset: usize,
363 len: usize,
364 ) {
365 dst[dst_offset..dst_offset + len].copy_from_slice(&src[src_offset..src_offset + len]);
366 }
367
368 fn embedding_lookup(
369 _ctx: &mut Self::Context,
370 table: &Self::Buffer,
371 ids: &[u32],
372 out: &mut Self::Buffer,
373 dim: usize,
374 ) {
375 for (i, &id) in ids.iter().enumerate() {
376 let src = id as usize * dim;
377 out[i * dim..(i + 1) * dim].copy_from_slice(&table[src..src + dim]);
378 }
379 }
380
381 fn split_qkv(
382 _ctx: &mut Self::Context,
383 qkv: &Self::Buffer,
384 q: &mut Self::Buffer,
385 k: &mut Self::Buffer,
386 v: &mut Self::Buffer,
387 tokens: usize,
388 q_dim: usize,
389 kv_dim: usize,
390 ) {
391 let qkv_dim = q_dim + 2 * kv_dim;
392 for t in 0..tokens {
393 let base = t * qkv_dim;
394 q[t * q_dim..(t + 1) * q_dim].copy_from_slice(&qkv[base..base + q_dim]);
395 k[t * kv_dim..(t + 1) * kv_dim]
396 .copy_from_slice(&qkv[base + q_dim..base + q_dim + kv_dim]);
397 v[t * kv_dim..(t + 1) * kv_dim]
398 .copy_from_slice(&qkv[base + q_dim + kv_dim..base + qkv_dim]);
399 }
400 }
401
402 fn fused_silu_mul_split(
403 _ctx: &mut Self::Context,
404 gate_up: &Self::Buffer,
405 out: &mut Self::Buffer,
406 tokens: usize,
407 im: usize,
408 ) {
409 for t in 0..tokens {
410 for i in 0..im {
411 let g = gate_up[t * 2 * im + i];
412 let u = gate_up[t * 2 * im + im + i];
413 out[t * im + i] = (g / (1.0 + (-g).exp())) * u;
414 }
415 }
416 }
417
418 fn qk_norm_rope(
419 _ctx: &mut Self::Context,
420 input: &Self::Buffer,
421 norm_w: &Self::Buffer,
422 cos: &Self::Buffer,
423 sin: &Self::Buffer,
424 output: &mut Self::Buffer,
425 tokens: usize,
426 heads: usize,
427 head_dim: usize,
428 pos_offset: usize,
429 eps: f32,
430 mode: i32,
431 ) {
432 let half = head_dim / 2;
433 let cos_len = cos.len();
434 let sin_len = sin.len();
435 debug_assert_eq!(cos_len, sin_len);
436
437 for t in 0..tokens {
438 let pos = pos_offset + t;
439 for h in 0..heads {
440 let src_off = (t * heads + h) * head_dim;
442 let dst_off = (h * tokens + t) * head_dim;
444
445 if mode == 0 {
447 for i in 0..head_dim {
448 output[dst_off + i] = input[src_off + i];
449 }
450 continue;
451 }
452
453 let scale = if mode == 1 {
455 let mut sum_sq = 0.0f32;
456 for i in 0..head_dim {
457 sum_sq += input[src_off + i] * input[src_off + i];
458 }
459 1.0f32 / (sum_sq / head_dim as f32 + eps).sqrt()
460 } else {
461 1.0
462 };
463
464 for i in 0..half {
466 let (x0_raw, x1_raw) = (input[src_off + i], input[src_off + i + half]);
467 let (x0, x1) = if mode == 1 {
468 (
469 x0_raw * scale * norm_w[i],
470 x1_raw * scale * norm_w[i + half],
471 )
472 } else {
473 (x0_raw, x1_raw)
474 };
475 let c = cos[pos * half + i];
476 let s = sin[pos * half + i];
477 output[dst_off + i] = x0 * c - x1 * s;
478 output[dst_off + i + half] = x1 * c + x0 * s;
479 }
480 }
481 }
482 }
483
484 fn kv_cache_append_head_major(
485 _ctx: &mut Self::Context,
486 cache_k: &mut Self::Buffer,
487 cache_v: &mut Self::Buffer,
488 cache_len: usize,
489 cache_capacity: usize,
490 new_k_head_major: &Self::Buffer,
491 new_v_head_major: &Self::Buffer,
492 new_tokens: usize,
493 nkv: usize,
494 hd: usize,
495 ) {
496 debug_assert!(cache_len + new_tokens <= cache_capacity);
497 debug_assert_eq!(cache_k.len(), nkv * cache_capacity * hd);
498 debug_assert_eq!(cache_v.len(), nkv * cache_capacity * hd);
499 debug_assert!(new_k_head_major.len() >= nkv * new_tokens * hd);
504 debug_assert!(new_v_head_major.len() >= nkv * new_tokens * hd);
505
506 for h in 0..nkv {
507 let dst_base = h * cache_capacity * hd + cache_len * hd;
508 let src_base = h * new_tokens * hd;
509 cache_k[dst_base..dst_base + new_tokens * hd]
510 .copy_from_slice(&new_k_head_major[src_base..src_base + new_tokens * hd]);
511 cache_v[dst_base..dst_base + new_tokens * hd]
512 .copy_from_slice(&new_v_head_major[src_base..src_base + new_tokens * hd]);
513 }
514 }
515
516 fn transpose_head_to_token(
517 _ctx: &mut Self::Context,
518 src: &Self::Buffer,
519 dst: &mut Self::Buffer,
520 tokens: usize,
521 heads: usize,
522 dim: usize,
523 ) {
524 for h in 0..heads {
525 for t in 0..tokens {
526 let s = (h * tokens + t) * dim;
527 let d = (t * heads + h) * dim;
528 dst[d..d + dim].copy_from_slice(&src[s..s + dim]);
529 }
530 }
531 }
532
533 fn add_inplace(
534 _ctx: &mut Self::Context,
535 residual: &mut Self::Buffer,
536 x: &Self::Buffer,
537 len: usize,
538 ) {
539 for i in 0..len {
540 residual[i] += x[i];
541 }
542 }
543
544 fn scaled_add_inplace(
545 _ctx: &mut Self::Context,
546 dst: &mut Self::Buffer,
547 src: &Self::Buffer,
548 scale: f32,
549 len: usize,
550 ) {
551 for i in 0..len {
552 dst[i] += scale * src[i];
553 }
554 }
555
556 fn add_bias(
557 _ctx: &mut Self::Context,
558 data: &mut Self::Buffer,
559 bias: &Self::Buffer,
560 rows: usize,
561 cols: usize,
562 ) {
563 debug_assert_eq!(bias.len(), cols);
564 for r in 0..rows {
565 let off = r * cols;
566 for c in 0..cols {
567 data[off + c] += bias[c];
568 }
569 }
570 }
571
572 fn layer_norm(
573 _ctx: &mut Self::Context,
574 x: &Self::Buffer,
575 gamma: &Self::Buffer,
576 beta: &Self::Buffer,
577 eps: f32,
578 out: &mut Self::Buffer,
579 tokens: usize,
580 dim: usize,
581 ) {
582 debug_assert_eq!(gamma.len(), dim);
583 debug_assert_eq!(beta.len(), dim);
584 for t in 0..tokens {
585 let off = t * dim;
586 let mut mean = 0.0f64;
588 for i in 0..dim {
589 mean += x[off + i] as f64;
590 }
591 mean /= dim as f64;
592 let mut var = 0.0f64;
593 for i in 0..dim {
594 let d = x[off + i] as f64 - mean;
595 var += d * d;
596 }
597 var /= dim as f64;
598 let inv = 1.0f32 / ((var as f32) + eps).sqrt();
599 let mean_f32 = mean as f32;
600 for i in 0..dim {
601 out[off + i] = (x[off + i] - mean_f32) * inv * gamma[i] + beta[i];
602 }
603 }
604 }
605
606 fn gelu(_ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
607 for i in 0..len {
610 let xi = x[i];
611 out[i] = 0.5 * xi * (1.0 + libm_erf(xi / std::f32::consts::SQRT_2));
612 }
613 }
614
615 fn alloc(len: usize) -> Self::Buffer {
616 vec![0.0f32; len]
617 }
618 fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
619 buf[..len].to_vec()
620 }
621 fn from_slice(data: &[f32]) -> Self::Buffer {
622 data.to_vec()
623 }
624}
625
/// Dot product of `a` and `b` over their common prefix.
///
/// If the slices differ in length, only the first `min(a.len(), b.len())`
/// elements contribute on every platform. On macOS the sum is computed by
/// `vDSP_dotpr`; elsewhere it is a plain f32 multiply-and-sum.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_os = "macos")]
    {
        // Bound the length by the shorter slice so the foreign routine never
        // reads past the end of either buffer.
        let n = a.len().min(b.len());
        let mut result = 0.0f32;
        if n > 0 {
            // SAFETY: both pointers are valid for `n` contiguous f32 reads
            // (n <= a.len() and n <= b.len(), stride 1), and `result` is a
            // live local valid for one f32 write.
            unsafe {
                vDSP_dotpr(a.as_ptr(), 1, b.as_ptr(), 1, &mut result, n as u64);
            }
        }
        result
    }
    #[cfg(not(target_os = "macos"))]
    {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }
}
642
/// Rotates each head vector of `data` in place.
///
/// `data` is indexed as `[tokens, heads, head_dim]`. For token `t` the
/// rotation angles come from row `positions[t]` of the `cos` / `sin` tables
/// (row length `half`). Element `i` is paired with element `i + half`:
/// `(x0, x1) -> (x0*c - x1*s, x1*c + x0*s)`.
// Kept although no caller is visible in this file.
#[allow(dead_code)]
fn apply_rope_impl(
    data: &mut [f32],
    tokens: usize,
    heads: usize,
    head_dim: usize,
    half: usize,
    cos: &[f32],
    sin: &[f32],
    positions: &[u32],
) {
    let token_stride = heads * head_dim;
    for t in 0..tokens {
        let table_base = positions[t] as usize * half;
        for h in 0..heads {
            let lo = t * token_stride + h * head_dim;
            let hi = lo + half;
            let cos_row = &cos[table_base..table_base + half];
            let sin_row = &sin[table_base..table_base + half];
            for (i, (&c, &s)) in cos_row.iter().zip(sin_row).enumerate() {
                let (x0, x1) = (data[lo + i], data[hi + i]);
                data[lo + i] = x0 * c - x1 * s;
                data[hi + i] = x1 * c + x0 * s;
            }
        }
    }
}
669
670fn cpu_attention(
671 q: &[f32],
672 k: &[f32],
673 v: &[f32],
674 out: &mut [f32],
675 batch: usize,
676 q_len: usize,
677 kv_len: usize,
678 causal: bool,
679 pos_offset: usize,
680 cfg: &AttnConfig,
681) {
682 let nh = cfg.num_heads;
683 let nkv = cfg.num_kv_heads;
684 let d = cfg.head_dim;
685 let n_rep = nh / nkv;
686 let scale = cfg.scale;
687 let kv_stride = if cfg.kv_seq_stride > 0 {
692 cfg.kv_seq_stride
693 } else {
694 kv_len
695 };
696
697 for b in 0..batch {
698 for h in 0..nh {
699 let kv_h = h / n_rep;
700 let q_off = (b * nh + h) * q_len * d;
701 let k_off = (b * nkv + kv_h) * kv_stride * d;
702 let v_off = (b * nkv + kv_h) * kv_stride * d;
703 let o_off = (b * nh + h) * q_len * d;
704
705 for qi in 0..q_len {
706 let attend_end = if causal {
707 (pos_offset + qi + 1).min(kv_len)
708 } else {
709 kv_len
710 };
711 let attend_start = if causal && cfg.sliding_window > 0 {
712 attend_end.saturating_sub(cfg.sliding_window)
713 } else {
714 0
715 };
716 let mut max_score = f32::NEG_INFINITY;
717 let mut sum_exp = 0.0f32;
718 let mut acc = vec![0.0f32; d];
719
720 for ki in attend_start..attend_end {
721 let mut dot = 0.0f32;
722 for di in 0..d {
723 dot += q[q_off + qi * d + di] * k[k_off + ki * d + di];
724 }
725 let score = dot * scale;
726 if score > max_score {
727 let correction = (max_score - score).exp();
728 for di in 0..d {
729 acc[di] *= correction;
730 }
731 sum_exp *= correction;
732 max_score = score;
733 }
734 let w = (score - max_score).exp();
735 sum_exp += w;
736 for di in 0..d {
737 acc[di] += w * v[v_off + ki * d + di];
738 }
739 }
740
741 if sum_exp > 0.0 {
742 let inv = 1.0 / sum_exp;
743 for di in 0..d {
744 out[o_off + qi * d + di] = acc[di] * inv;
745 }
746 }
747 }
748 }
749 }
750}
751
/// Polynomial approximation of the error function.
///
/// Evaluates a degree-5 polynomial in `t = 1 / (1 + 0.3275911 * |x|)`,
/// multiplies it by `exp(-x²)`, subtracts the product from 1, and restores
/// the sign of `x`, so the result is odd in `x`.
fn libm_erf(x: f32) -> f32 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);

    // Horner evaluation, highest-order coefficient first.
    let mut poly = 1.061_405_4 * t - 1.453_152_1;
    poly = poly * t + 1.421_413_8;
    poly = poly * t - 0.284_496_72;
    poly = poly * t + 0.254_829_6;

    let y = 1.0 - poly * t * (-x * x).exp();
    sign * y
}