1use super::{AttnConfig, Backend};
5use ferrum_types::{FerrumError, Result};
6
/// Zero-sized marker type that implements `Backend` on the host CPU, using
/// plain `Vec<f32>` buffers and a unit context.
pub struct CpuBackend;
8
// Native numeric routines used on macOS only.
// NOTE(review): no `#[link]` attribute is present here, so linking against the
// providing framework is presumably configured elsewhere — confirm.
#[cfg(target_os = "macos")]
extern "C" {
    /// CBLAS single-precision general matrix multiply:
    /// `c = alpha * op(a) * op(b) + beta * c`.
    ///
    /// `order`, `transa` and `transb` take the integer values of the CBLAS
    /// enums; all dimensions and leading dimensions are C `int`s.
    fn cblas_sgemm(
        order: i32,
        transa: i32,
        transb: i32,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    );

    /// vDSP single-precision dot product of `n` elements, written to `result`.
    ///
    /// The C prototype declares the strides as `vDSP_Stride` (`long`) and the
    /// length as `vDSP_Length` (`unsigned long`). Declaring the strides as
    /// 32-bit integers would leave the upper half of the argument registers
    /// unspecified on 64-bit targets, so the C-sized aliases are used here.
    fn vDSP_dotpr(
        a: *const f32,
        a_stride: std::os::raw::c_long,
        b: *const f32,
        b_stride: std::os::raw::c_long,
        result: *mut f32,
        n: std::os::raw::c_ulong,
    );
}
36
/// GPTQ weights after dequantization to dense `f32`, as produced by
/// `CpuBackend::load_gptq` and consumed by `CpuBackend::gemm_gptq`.
pub struct CpuGptqStore {
    /// Dequantized weights: `n * k` values, with the element for output
    /// column `col` and input index `ki` stored at `col * k + ki`.
    pub weight_f32: Vec<f32>,
    /// Input (reduction) dimension of the matrix.
    pub k: usize,
    /// Output dimension of the matrix.
    pub n: usize,
}
44
45impl Backend for CpuBackend {
46 type Buffer = Vec<f32>;
47 type Context = ();
48 type GptqStore = CpuGptqStore;
49
/// Creates the backend context; for the CPU backend this is the unit value.
fn new_context() -> Self::Context {}
/// No-op: the body is empty, so there is nothing to wait for on this backend.
fn sync(_ctx: &mut Self::Context) {}
52
53 fn load_gptq(
54 qweight: &[i32],
55 scales: &[f32],
56 qzeros: &[i32],
57 _g_idx: Option<&[i32]>,
58 bits: u32,
59 group_size: usize,
60 k: usize,
61 n: usize,
62 ) -> Result<Self::GptqStore> {
63 if bits != 4 {
64 return Err(FerrumError::unsupported(format!(
65 "CPU GPTQ: only bits=4 supported (got {bits})"
66 )));
67 }
68 let num_groups = k / group_size;
69 let mut w = vec![0.0f32; n * k];
73 let packed_rows = k / 8;
74 for pr in 0..packed_rows {
75 for col in 0..n {
76 let packed = qweight[pr * n + col] as u32;
77 for bi in 0..8 {
78 let ki = pr * 8 + bi;
79 let q = ((packed >> (bi * 4)) & 0xF) as i32;
80 let grp = ki / group_size;
81 let scale = scales[grp * n + col];
82 let z_packed = qzeros[grp * (n / 8) + (col / 8)] as u32;
84 let zero = (((z_packed >> ((col % 8) * 4)) & 0xF) as i32) + 1;
85 let val = (q - zero) as f32 * scale;
86 w[col * k + ki] = val;
87 }
88 }
89 }
90 let _ = num_groups; Ok(CpuGptqStore {
92 weight_f32: w,
93 k,
94 n,
95 })
96 }
97
98 fn gemm_gptq(
99 ctx: &mut Self::Context,
100 a: &Self::Buffer,
101 weight: &Self::GptqStore,
102 out: &mut Self::Buffer,
103 m: usize,
104 ) -> Result<()> {
105 Self::gemm(ctx, a, &weight.weight_f32, out, m, weight.n, weight.k);
108 Ok(())
109 }
110
111 fn gemm(
112 _ctx: &mut Self::Context,
113 a: &Self::Buffer,
114 b: &Self::Buffer,
115 out: &mut Self::Buffer,
116 m: usize,
117 n: usize,
118 k: usize,
119 ) {
120 assert!(
121 a.len() >= m * k,
122 "gemm: a too small len={} m={m} k={k}",
123 a.len()
124 );
125 assert!(
126 b.len() >= n * k,
127 "gemm: b too small len={} n={n} k={k}",
128 b.len()
129 );
130 assert!(
131 out.len() >= m * n,
132 "gemm: out too small len={} m={m} n={n}",
133 out.len()
134 );
135 #[cfg(target_os = "macos")]
136 unsafe {
137 cblas_sgemm(
138 101,
139 111,
140 112,
141 m as i32,
142 n as i32,
143 k as i32,
144 1.0,
145 a.as_ptr(),
146 k as i32,
147 b.as_ptr(),
148 k as i32,
149 0.0,
150 out.as_mut_ptr(),
151 n as i32,
152 );
153 }
154 #[cfg(not(target_os = "macos"))]
155 {
156 for i in 0..m {
157 for j in 0..n {
158 let mut sum = 0.0f64;
159 for p in 0..k {
160 sum += a[i * k + p] as f64 * b[j * k + p] as f64;
161 }
162 out[i * n + j] = sum as f32;
163 }
164 }
165 }
166 }
167
168 fn rms_norm(
169 _ctx: &mut Self::Context,
170 x: &Self::Buffer,
171 w: &Self::Buffer,
172 eps: f32,
173 out: &mut Self::Buffer,
174 tokens: usize,
175 dim: usize,
176 ) {
177 for t in 0..tokens {
178 let row = &x[t * dim..(t + 1) * dim];
179 let o = &mut out[t * dim..(t + 1) * dim];
180 let sum_sq = dot_product(row, row);
181 let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
182 for i in 0..dim {
183 o[i] = row[i] * inv * w[i];
184 }
185 }
186 }
187
188 fn fused_add_rms_norm(
189 _ctx: &mut Self::Context,
190 residual: &mut Self::Buffer,
191 x: &Self::Buffer,
192 w: &Self::Buffer,
193 eps: f32,
194 out: &mut Self::Buffer,
195 tokens: usize,
196 dim: usize,
197 ) {
198 for t in 0..tokens {
199 let off = t * dim;
200 for i in 0..dim {
201 residual[off + i] += x[off + i];
202 }
203 let row = &residual[off..off + dim];
204 let o = &mut out[off..off + dim];
205 let sum_sq = dot_product(row, row);
206 let inv = 1.0f32 / (sum_sq / dim as f32 + eps).sqrt();
207 for i in 0..dim {
208 o[i] = row[i] * inv * w[i];
209 }
210 }
211 }
212
213 fn flash_attention(
214 _ctx: &mut Self::Context,
215 q: &Self::Buffer,
216 k: &Self::Buffer,
217 v: &Self::Buffer,
218 out: &mut Self::Buffer,
219 batch: usize,
220 q_len: usize,
221 kv_len: usize,
222 pos_offset: usize,
223 cfg: &AttnConfig,
224 ) {
225 cpu_attention(
226 q, k, v, out, batch, q_len, kv_len, cfg.causal, pos_offset, cfg,
227 );
228 }
229
230 fn copy_slice(
231 _ctx: &mut Self::Context,
232 src: &Self::Buffer,
233 src_offset: usize,
234 dst: &mut Self::Buffer,
235 dst_offset: usize,
236 len: usize,
237 ) {
238 dst[dst_offset..dst_offset + len].copy_from_slice(&src[src_offset..src_offset + len]);
239 }
240
241 fn embedding_lookup(
242 _ctx: &mut Self::Context,
243 table: &Self::Buffer,
244 ids: &[u32],
245 out: &mut Self::Buffer,
246 dim: usize,
247 ) {
248 for (i, &id) in ids.iter().enumerate() {
249 let src = id as usize * dim;
250 out[i * dim..(i + 1) * dim].copy_from_slice(&table[src..src + dim]);
251 }
252 }
253
254 fn split_qkv(
255 _ctx: &mut Self::Context,
256 qkv: &Self::Buffer,
257 q: &mut Self::Buffer,
258 k: &mut Self::Buffer,
259 v: &mut Self::Buffer,
260 tokens: usize,
261 q_dim: usize,
262 kv_dim: usize,
263 ) {
264 let qkv_dim = q_dim + 2 * kv_dim;
265 for t in 0..tokens {
266 let base = t * qkv_dim;
267 q[t * q_dim..(t + 1) * q_dim].copy_from_slice(&qkv[base..base + q_dim]);
268 k[t * kv_dim..(t + 1) * kv_dim]
269 .copy_from_slice(&qkv[base + q_dim..base + q_dim + kv_dim]);
270 v[t * kv_dim..(t + 1) * kv_dim]
271 .copy_from_slice(&qkv[base + q_dim + kv_dim..base + qkv_dim]);
272 }
273 }
274
275 fn fused_silu_mul_split(
276 _ctx: &mut Self::Context,
277 gate_up: &Self::Buffer,
278 out: &mut Self::Buffer,
279 tokens: usize,
280 im: usize,
281 ) {
282 for t in 0..tokens {
283 for i in 0..im {
284 let g = gate_up[t * 2 * im + i];
285 let u = gate_up[t * 2 * im + im + i];
286 out[t * im + i] = (g / (1.0 + (-g).exp())) * u;
287 }
288 }
289 }
290
291 fn qk_norm_rope(
292 _ctx: &mut Self::Context,
293 input: &Self::Buffer,
294 norm_w: &Self::Buffer,
295 cos: &Self::Buffer,
296 sin: &Self::Buffer,
297 output: &mut Self::Buffer,
298 tokens: usize,
299 heads: usize,
300 head_dim: usize,
301 pos_offset: usize,
302 eps: f32,
303 mode: i32,
304 ) {
305 let half = head_dim / 2;
306 let cos_len = cos.len();
307 let sin_len = sin.len();
308 debug_assert_eq!(cos_len, sin_len);
309
310 for t in 0..tokens {
311 let pos = pos_offset + t;
312 for h in 0..heads {
313 let src_off = (t * heads + h) * head_dim;
315 let dst_off = (h * tokens + t) * head_dim;
317
318 if mode == 0 {
320 for i in 0..head_dim {
321 output[dst_off + i] = input[src_off + i];
322 }
323 continue;
324 }
325
326 let scale = if mode == 1 {
328 let mut sum_sq = 0.0f32;
329 for i in 0..head_dim {
330 sum_sq += input[src_off + i] * input[src_off + i];
331 }
332 1.0f32 / (sum_sq / head_dim as f32 + eps).sqrt()
333 } else {
334 1.0
335 };
336
337 for i in 0..half {
339 let (x0_raw, x1_raw) = (input[src_off + i], input[src_off + i + half]);
340 let (x0, x1) = if mode == 1 {
341 (
342 x0_raw * scale * norm_w[i],
343 x1_raw * scale * norm_w[i + half],
344 )
345 } else {
346 (x0_raw, x1_raw)
347 };
348 let c = cos[pos * half + i];
349 let s = sin[pos * half + i];
350 output[dst_off + i] = x0 * c - x1 * s;
351 output[dst_off + i + half] = x1 * c + x0 * s;
352 }
353 }
354 }
355 }
356
357 fn kv_cache_append_head_major(
358 _ctx: &mut Self::Context,
359 cache_k: &mut Self::Buffer,
360 cache_v: &mut Self::Buffer,
361 cache_len: usize,
362 cache_capacity: usize,
363 new_k_head_major: &Self::Buffer,
364 new_v_head_major: &Self::Buffer,
365 new_tokens: usize,
366 nkv: usize,
367 hd: usize,
368 ) {
369 debug_assert!(cache_len + new_tokens <= cache_capacity);
370 debug_assert_eq!(cache_k.len(), nkv * cache_capacity * hd);
371 debug_assert_eq!(cache_v.len(), nkv * cache_capacity * hd);
372 debug_assert_eq!(new_k_head_major.len(), nkv * new_tokens * hd);
373 debug_assert_eq!(new_v_head_major.len(), nkv * new_tokens * hd);
374
375 for h in 0..nkv {
376 let dst_base = h * cache_capacity * hd + cache_len * hd;
377 let src_base = h * new_tokens * hd;
378 cache_k[dst_base..dst_base + new_tokens * hd]
379 .copy_from_slice(&new_k_head_major[src_base..src_base + new_tokens * hd]);
380 cache_v[dst_base..dst_base + new_tokens * hd]
381 .copy_from_slice(&new_v_head_major[src_base..src_base + new_tokens * hd]);
382 }
383 }
384
385 fn transpose_head_to_token(
386 _ctx: &mut Self::Context,
387 src: &Self::Buffer,
388 dst: &mut Self::Buffer,
389 tokens: usize,
390 heads: usize,
391 dim: usize,
392 ) {
393 for h in 0..heads {
394 for t in 0..tokens {
395 let s = (h * tokens + t) * dim;
396 let d = (t * heads + h) * dim;
397 dst[d..d + dim].copy_from_slice(&src[s..s + dim]);
398 }
399 }
400 }
401
402 fn add_inplace(
403 _ctx: &mut Self::Context,
404 residual: &mut Self::Buffer,
405 x: &Self::Buffer,
406 len: usize,
407 ) {
408 for i in 0..len {
409 residual[i] += x[i];
410 }
411 }
412
413 fn add_bias(
414 _ctx: &mut Self::Context,
415 data: &mut Self::Buffer,
416 bias: &Self::Buffer,
417 rows: usize,
418 cols: usize,
419 ) {
420 debug_assert_eq!(bias.len(), cols);
421 for r in 0..rows {
422 let off = r * cols;
423 for c in 0..cols {
424 data[off + c] += bias[c];
425 }
426 }
427 }
428
429 fn layer_norm(
430 _ctx: &mut Self::Context,
431 x: &Self::Buffer,
432 gamma: &Self::Buffer,
433 beta: &Self::Buffer,
434 eps: f32,
435 out: &mut Self::Buffer,
436 tokens: usize,
437 dim: usize,
438 ) {
439 debug_assert_eq!(gamma.len(), dim);
440 debug_assert_eq!(beta.len(), dim);
441 for t in 0..tokens {
442 let off = t * dim;
443 let mut mean = 0.0f64;
445 for i in 0..dim {
446 mean += x[off + i] as f64;
447 }
448 mean /= dim as f64;
449 let mut var = 0.0f64;
450 for i in 0..dim {
451 let d = x[off + i] as f64 - mean;
452 var += d * d;
453 }
454 var /= dim as f64;
455 let inv = 1.0f32 / ((var as f32) + eps).sqrt();
456 let mean_f32 = mean as f32;
457 for i in 0..dim {
458 out[off + i] = (x[off + i] - mean_f32) * inv * gamma[i] + beta[i];
459 }
460 }
461 }
462
463 fn gelu(_ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize) {
464 for i in 0..len {
467 let xi = x[i];
468 out[i] = 0.5 * xi * (1.0 + libm_erf(xi / std::f32::consts::SQRT_2));
469 }
470 }
471
/// Allocates a zero-filled buffer of `len` values.
fn alloc(len: usize) -> Self::Buffer {
    vec![0.0f32; len]
}
475 fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32> {
476 buf[..len].to_vec()
477 }
478 fn from_slice(data: &[f32]) -> Self::Buffer {
479 data.to_vec()
480 }
481}
482
/// Returns the dot product of `a` and `b` over their common prefix.
///
/// If the slices differ in length, only the first `min(a.len(), b.len())`
/// elements take part. The macOS path previously passed `a.len()` to the
/// native routine regardless of `b`, which could read past the end of `b`;
/// both paths now use the same element count.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len().min(b.len());

    #[cfg(target_os = "macos")]
    {
        let mut result = 0.0f32;
        // SAFETY: `n` does not exceed the length of either slice, so with
        // unit strides the routine only reads valid elements of `a` and `b`,
        // and `result` is a valid, exclusive destination for the output.
        unsafe {
            vDSP_dotpr(a.as_ptr(), 1, b.as_ptr(), 1, &mut result, n as _);
        }
        result
    }
    #[cfg(not(target_os = "macos"))]
    {
        a[..n].iter().zip(&b[..n]).map(|(x, y)| x * y).sum()
    }
}
499
/// Rotates token-major head vectors in place using per-token positions.
///
/// For token `t` and head `h`, the vector starts at
/// `t * heads * head_dim + h * head_dim`. Element `i` is paired with element
/// `i + half` and rotated by the angle whose cosine and sine are stored at
/// `positions[t] * half + i` in `cos` and `sin`.
fn apply_rope_impl(
    data: &mut [f32],
    tokens: usize,
    heads: usize,
    head_dim: usize,
    half: usize,
    cos: &[f32],
    sin: &[f32],
    positions: &[u32],
) {
    for (t, &position) in positions[..tokens].iter().enumerate() {
        // Start of this position's row in the cos/sin tables.
        let trig_row = position as usize * half;

        for h in 0..heads {
            let lo = t * heads * head_dim + h * head_dim;
            let hi = lo + half;

            for i in 0..half {
                let (c, s) = (cos[trig_row + i], sin[trig_row + i]);
                let (first, second) = (data[lo + i], data[hi + i]);
                data[lo + i] = first * c - second * s;
                data[hi + i] = second * c + first * s;
            }
        }
    }
}
525
526fn cpu_attention(
527 q: &[f32],
528 k: &[f32],
529 v: &[f32],
530 out: &mut [f32],
531 batch: usize,
532 q_len: usize,
533 kv_len: usize,
534 causal: bool,
535 pos_offset: usize,
536 cfg: &AttnConfig,
537) {
538 let nh = cfg.num_heads;
539 let nkv = cfg.num_kv_heads;
540 let d = cfg.head_dim;
541 let n_rep = nh / nkv;
542 let scale = cfg.scale;
543 let kv_stride = if cfg.kv_seq_stride > 0 {
548 cfg.kv_seq_stride
549 } else {
550 kv_len
551 };
552
553 for b in 0..batch {
554 for h in 0..nh {
555 let kv_h = h / n_rep;
556 let q_off = (b * nh + h) * q_len * d;
557 let k_off = (b * nkv + kv_h) * kv_stride * d;
558 let v_off = (b * nkv + kv_h) * kv_stride * d;
559 let o_off = (b * nh + h) * q_len * d;
560
561 for qi in 0..q_len {
562 let attend_end = if causal {
563 (pos_offset + qi + 1).min(kv_len)
564 } else {
565 kv_len
566 };
567 let attend_start = if causal && cfg.sliding_window > 0 {
568 attend_end.saturating_sub(cfg.sliding_window)
569 } else {
570 0
571 };
572 let mut max_score = f32::NEG_INFINITY;
573 let mut sum_exp = 0.0f32;
574 let mut acc = vec![0.0f32; d];
575
576 for ki in attend_start..attend_end {
577 let mut dot = 0.0f32;
578 for di in 0..d {
579 dot += q[q_off + qi * d + di] * k[k_off + ki * d + di];
580 }
581 let score = dot * scale;
582 if score > max_score {
583 let correction = (max_score - score).exp();
584 for di in 0..d {
585 acc[di] *= correction;
586 }
587 sum_exp *= correction;
588 max_score = score;
589 }
590 let w = (score - max_score).exp();
591 sum_exp += w;
592 for di in 0..d {
593 acc[di] += w * v[v_off + ki * d + di];
594 }
595 }
596
597 if sum_exp > 0.0 {
598 let inv = 1.0 / sum_exp;
599 for di in 0..d {
600 out[o_off + qi * d + di] = acc[di] * inv;
601 }
602 }
603 }
604 }
605 }
606}
607
/// Approximates the error function `erf(x)` in single precision.
///
/// Evaluates a five-term polynomial in `t = 1 / (1 + p * |x|)` multiplied by
/// `exp(-x^2)`, and restores the sign at the end so the result is odd in `x`.
fn libm_erf(x: f32) -> f32 {
    const P: f32 = 0.3275911;
    // Polynomial coefficients, highest power first, for Horner evaluation.
    const COEFFS: [f32; 5] = [
        1.061405429,
        -1.453152027,
        1.421413741,
        -0.284496736,
        0.254829592,
    ];

    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let magnitude = x.abs();
    let t = 1.0 / (1.0 + P * magnitude);

    let poly = COEFFS.iter().fold(0.0f32, |acc, &c| acc * t + c);
    let y = 1.0 - poly * t * (-magnitude * magnitude).exp();
    sign * y
}