1#![allow(dead_code, unused_imports, unused_variables, unused_mut, unused_parens)]
7
8use std::sync::OnceLock;
9
10pub mod cpu;
11
12#[cfg(feature = "metal")]
13pub mod metal;
14
/// Buffer handle used by the GPU-facing entry points.
///
/// With the `metal` feature this is a Metal buffer from the external `metal` crate.
#[cfg(feature = "metal")]
pub type GpuBuffer = ::metal::Buffer;

/// Buffer handle used by the GPU-facing entry points.
///
/// Without the `metal` feature a plain host vector stands in for the GPU buffer.
#[cfg(not(feature = "metal"))]
pub type GpuBuffer = Vec<f32>;
20
/// Backend overrides read from the `FERRUM_FUSED_*` environment variables.
///
/// A flag is `true` only when its variable is set to exactly `"1"`; any other value
/// (including `"true"`) leaves it `false`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct AttentionRuntimeEnv {
    /// Set by `FERRUM_FUSED_CPU=1`.
    fused_cpu: bool,
    /// Set by `FERRUM_FUSED_METAL=1`.
    fused_metal: bool,
}

impl AttentionRuntimeEnv {
    /// Reads the override flags from the process environment.
    ///
    /// Only the two relevant keys are looked up. Iterating `std::env::vars()` panics as
    /// soon as *any* environment entry is not valid Unicode, which would let an unrelated
    /// variable crash start-up. A relevant variable whose value is not valid Unicode is
    /// treated as unset.
    fn from_env() -> Self {
        const KEYS: [&str; 2] = ["FERRUM_FUSED_CPU", "FERRUM_FUSED_METAL"];
        Self::from_env_vars(
            KEYS.iter()
                .filter_map(|&key| std::env::var(key).ok().map(|value| (key, value))),
        )
    }

    /// Builds the flags from an arbitrary sequence of `(key, value)` pairs.
    ///
    /// Unknown keys are ignored. If a key occurs more than once, the last occurrence wins.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut parsed = Self {
            fused_cpu: false,
            fused_metal: false,
        };

        for (key, value) in vars {
            let enabled = value.as_ref() == "1";
            match key.as_ref() {
                "FERRUM_FUSED_CPU" => parsed.fused_cpu = enabled,
                "FERRUM_FUSED_METAL" => parsed.fused_metal = enabled,
                _ => {}
            }
        }

        parsed
    }
}
55
56fn attention_runtime_env() -> &'static AttentionRuntimeEnv {
57 static CONFIG: OnceLock<AttentionRuntimeEnv> = OnceLock::new();
58 CONFIG.get_or_init(AttentionRuntimeEnv::from_env)
59}
60
/// Parameters handed to the fused attention kernels (`cpu::fused_attention` and, with the
/// `metal` feature, `metal::fused_attention`).
///
/// NOTE(review): the exact meaning of each field is defined by those kernels, which are
/// not visible from this file; the per-field notes below are inferred from the names only.
#[derive(Debug, Clone, Default)]
pub struct AttentionParams {
    /// Presumably the batch size.
    pub batch: usize,
    /// Presumably the number of query heads.
    pub num_heads: usize,
    /// Presumably the number of key/value heads.
    pub num_kv_heads: usize,
    /// Presumably the number of query positions.
    pub q_len: usize,
    /// Presumably the number of key/value positions.
    pub kv_len: usize,
    /// Presumably the per-head feature dimension.
    pub head_dim: usize,
    /// Presumably enables causal masking.
    pub causal: bool,
    /// Presumably the absolute position of the first query token.
    pub pos_offset: usize,
    /// Presumably the sliding-window size; how the default `0` is interpreted is up to the
    /// kernels — TODO confirm.
    pub sliding_window: usize,
}
76
77pub fn attention_cpu(q: &[f32], k: &[f32], v: &[f32], out: &mut [f32], params: &AttentionParams) {
79 cpu::fused_attention(q, k, v, out, params);
80}
81
82pub fn attention(q: &[f32], k: &[f32], v: &[f32], out: &mut [f32], params: &AttentionParams) {
84 #[cfg(feature = "metal")]
85 {
86 if metal::is_available() {
87 metal::fused_attention(q, k, v, out, params);
88 return;
89 }
90 }
91 cpu::fused_attention(q, k, v, out, params);
92}
93
/// Static hyper-parameters of the fused transformer stack.
///
/// `Debug` is derived so the configuration can be logged and inspected; every field is a
/// plain number.
#[derive(Clone, Debug)]
pub struct TransformerConfig {
    /// Width of one hidden-state row; also the row width used by the final RMS norm.
    pub hidden_size: usize,
    /// Forwarded to the Metal config and scratch allocation; presumably the MLP inner
    /// width.
    pub intermediate_size: usize,
    /// Number of attention (query) heads, forwarded to the backends.
    pub num_heads: usize,
    /// Number of key/value heads, forwarded to the backends and the Metal KV cache.
    pub num_kv_heads: usize,
    /// Per-head dimension; rotary tables hold `head_dim / 2` entries per position.
    pub head_dim: usize,
    /// Declared layer count. NOTE(review): the visible code uses the length of the
    /// supplied layer list instead of this field.
    pub num_layers: usize,
    /// Epsilon added inside the RMS-norm square root (narrowed to `f32` where used).
    pub rms_norm_eps: f64,
    /// Base of the rotary-embedding frequency schedule.
    pub rope_theta: f64,
    /// Maximum sequence length; the rotary tables are capped at 32768 positions and the
    /// Metal KV cache at 4096 positions.
    pub max_position_embeddings: usize,
}
109
/// Host-side weights of a single transformer layer, stored as flat `f32` vectors.
///
/// NOTE(review): the memory layout of the projection matrices is defined by the backend
/// kernels and is not visible from this file.
pub struct LayerWeights {
    /// Weight of the norm applied before attention (by name: "input layer norm").
    pub input_ln_w: Vec<f32>,
    /// Query projection weight.
    pub q_proj_w: Vec<f32>,
    /// Key projection weight.
    pub k_proj_w: Vec<f32>,
    /// Value projection weight.
    pub v_proj_w: Vec<f32>,
    /// Attention output projection weight.
    pub o_proj_w: Vec<f32>,
    /// Optional query-norm weight; when empty, the Metal upload marks the layer as having
    /// no Q/K norm and substitutes a one-element placeholder buffer.
    pub q_norm_w: Vec<f32>,
    /// Optional key-norm weight; an empty vector gets the same placeholder on Metal.
    pub k_norm_w: Vec<f32>,
    /// Weight of the norm applied after attention (by name: "post layer norm").
    pub post_ln_w: Vec<f32>,
    /// MLP gate projection weight.
    pub gate_proj_w: Vec<f32>,
    /// MLP up projection weight.
    pub up_proj_w: Vec<f32>,
    /// MLP down projection weight.
    pub down_proj_w: Vec<f32>,
    /// Optional scale for the attention branch; uploaded to Metal only when present.
    pub attn_layer_scale: Option<Vec<f32>>,
    /// Optional scale for the MLP branch; uploaded to Metal only when present.
    pub mlp_layer_scale: Option<Vec<f32>>,
}
128
129pub struct FusedTransformer {
131 cfg: TransformerConfig,
132 cos: Vec<f32>,
133 sin: Vec<f32>,
134 norm_w: Vec<f32>,
135
136 #[cfg(feature = "metal")]
137 metal_state: Option<MetalTransformerState>,
138
139 cpu_layers: Vec<LayerWeights>,
141 cpu_kv: Vec<cpu::transformer::CpuKvCache>,
142 tokens_generated: usize,
143 #[allow(dead_code)]
147 use_cpu: bool,
148}
149
/// Everything the Metal path keeps alive between forward calls.
#[cfg(feature = "metal")]
struct MetalTransformerState {
    /// Pipelines, device and command queue used to allocate buffers and encode work.
    pipes: metal::pipelines::MetalPipelines,
    /// Per-layer weights uploaded to Metal buffers.
    weights: Vec<metal::transformer::MetalLayerWeights>,
    /// One Metal key/value cache per layer.
    kv: Vec<metal::transformer::MetalKvCache>,
    /// Rotary cosine table on the GPU.
    cos_buf: ::metal::Buffer,
    /// Rotary sine table on the GPU.
    sin_buf: ::metal::Buffer,
    /// Subset of the model configuration passed to the Metal layer kernel.
    metal_cfg: metal::transformer::MetalTransformerConfig,
    /// Lazily created per-layer scratch buffers, recreated when more tokens are requested.
    scratch: Option<metal::transformer::LayerScratch>,
    /// Token count the current `scratch` was created for.
    max_scratch_tokens: usize,
    /// Lazily created buffer that carries hidden states between layers.
    input_buf: Option<::metal::Buffer>,
    /// Size of `input_buf` in `f32` elements as requested at allocation time.
    input_buf_size: usize,
    /// Final RMS-norm weight on the GPU.
    norm_w_buf: ::metal::Buffer,
    /// Lazily created output buffer of the GPU-side final RMS norm.
    norm_out_buf: Option<::metal::Buffer>,
}
167
168impl FusedTransformer {
169 pub fn new(cfg: TransformerConfig, layers: Vec<LayerWeights>, norm_w: Vec<f32>) -> Self {
171 let hd = cfg.head_dim;
173 let half = hd / 2;
174 let max_seq = cfg.max_position_embeddings.min(32768);
175 let mut cos = vec![0.0f32; max_seq * half];
176 let mut sin = vec![0.0f32; max_seq * half];
177 for pos in 0..max_seq {
178 for i in 0..half {
179 let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
180 let angle = pos as f64 * freq;
181 cos[pos * half + i] = angle.cos() as f32;
182 sin[pos * half + i] = angle.sin() as f32;
183 }
184 }
185
186 let n = layers.len();
187 let cpu_kv = (0..n)
188 .map(|_| cpu::transformer::CpuKvCache::new())
189 .collect();
190
191 let runtime_env = attention_runtime_env();
197 let use_cpu = if runtime_env.fused_cpu {
198 true
199 } else if runtime_env.fused_metal {
200 false
201 } else {
202 false
206 };
207
208 #[cfg(feature = "metal")]
209 let metal_state = {
210 if let Some(device) = ::metal::Device::system_default() {
211 let pipes = metal::pipelines::MetalPipelines::new(&device);
212 let weights: Vec<_> = layers
213 .iter()
214 .map(|lw| {
215 metal::transformer::MetalLayerWeights {
216 input_ln_w: pipes.buffer_from_data(&lw.input_ln_w),
217 q_proj_w: pipes.buffer_from_data(&lw.q_proj_w),
218 k_proj_w: pipes.buffer_from_data(&lw.k_proj_w),
219 v_proj_w: pipes.buffer_from_data(&lw.v_proj_w),
220 o_proj_w: pipes.buffer_from_data(&lw.o_proj_w),
221 q_norm_w: if lw.q_norm_w.is_empty() {
222 pipes.buffer_from_data(&[1.0f32]) } else {
224 pipes.buffer_from_data(&lw.q_norm_w)
225 },
226 k_norm_w: if lw.k_norm_w.is_empty() {
227 pipes.buffer_from_data(&[1.0f32])
228 } else {
229 pipes.buffer_from_data(&lw.k_norm_w)
230 },
231 post_ln_w: pipes.buffer_from_data(&lw.post_ln_w),
232 gate_proj_w: pipes.buffer_from_data(&lw.gate_proj_w),
233 up_proj_w: pipes.buffer_from_data(&lw.up_proj_w),
234 down_proj_w: pipes.buffer_from_data(&lw.down_proj_w),
235 has_qk_norm: !lw.q_norm_w.is_empty(),
236 attn_scale: lw
237 .attn_layer_scale
238 .as_ref()
239 .map(|s| pipes.buffer_from_data(s)),
240 mlp_scale: lw
241 .mlp_layer_scale
242 .as_ref()
243 .map(|s| pipes.buffer_from_data(s)),
244 }
245 })
246 .collect();
247 let kv_max_len = cfg.max_position_embeddings.min(4096);
248 let kv = (0..n)
249 .map(|_| {
250 metal::transformer::MetalKvCache::new(
251 &pipes,
252 cfg.num_kv_heads,
253 cfg.head_dim,
254 kv_max_len,
255 )
256 })
257 .collect();
258 let metal_cfg = metal::transformer::MetalTransformerConfig {
259 hidden_size: cfg.hidden_size,
260 intermediate_size: cfg.intermediate_size,
261 num_heads: cfg.num_heads,
262 num_kv_heads: cfg.num_kv_heads,
263 head_dim: cfg.head_dim,
264 rms_norm_eps: cfg.rms_norm_eps as f32,
265 };
266 let cos_buf = pipes.buffer_from_data(&cos);
267 let sin_buf = pipes.buffer_from_data(&sin);
268 let norm_w_buf = pipes.buffer_from_data(&norm_w);
269 Some(MetalTransformerState {
270 pipes,
271 weights,
272 kv,
273 cos_buf,
274 sin_buf,
275 metal_cfg,
276 scratch: None,
277 max_scratch_tokens: 0,
278 input_buf: None,
279 input_buf_size: 0,
280 norm_w_buf,
281 norm_out_buf: None,
282 })
283 } else {
284 None
285 }
286 };
287
288 #[cfg(feature = "metal")]
290 {
291 let backend = if use_cpu {
292 "CPU (Accelerate)"
293 } else {
294 "Metal+Accelerate"
295 };
296 tracing::info!(
297 "FusedTransformer: backend={backend}, hidden={}, layers={n}",
298 cfg.hidden_size
299 );
300 }
301 #[cfg(not(feature = "metal"))]
302 tracing::info!(
303 "FusedTransformer: backend=CPU, hidden={}, layers={n}",
304 cfg.hidden_size
305 );
306
307 FusedTransformer {
308 cfg,
309 cos,
310 sin,
311 norm_w,
312 #[cfg(feature = "metal")]
313 metal_state,
314 cpu_layers: layers,
315 cpu_kv,
316 tokens_generated: 0,
317 use_cpu,
318 }
319 }
320
321 pub fn forward(&mut self, input: &[f32], tokens: usize) -> Vec<f32> {
323 let pos_offset = self.tokens_generated;
324 #[cfg(feature = "metal")]
325 let h = self.cfg.hidden_size;
326
327 #[cfg(feature = "metal")]
328 if !self.use_cpu {
329 if let Some(ref mut ms) = self.metal_state {
330 if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
332 ms.scratch = Some(metal::transformer::LayerScratch::new(
333 &ms.pipes,
334 tokens,
335 h,
336 ms.metal_cfg.intermediate_size,
337 ms.metal_cfg.num_heads,
338 ms.metal_cfg.num_kv_heads,
339 ms.metal_cfg.head_dim,
340 ));
341 ms.max_scratch_tokens = tokens;
342 }
343 let scratch = ms.scratch.as_ref().unwrap();
344
345 let needed = tokens * h;
347 if ms.input_buf.is_none() || ms.input_buf_size < needed {
348 ms.input_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h))); ms.input_buf_size = needed.max(128 * h);
350 }
351 let input_buf = ms.input_buf.as_ref().unwrap();
352 unsafe {
353 std::ptr::copy_nonoverlapping(
354 input.as_ptr(),
355 input_buf.contents() as *mut f32,
356 needed,
357 );
358 }
359
360 let cmd = ms.pipes.queue.new_command_buffer();
361
362 metal::transformer::metal_layer_forward_v2(
364 cmd,
365 &ms.pipes,
366 input_buf,
367 tokens,
368 &ms.weights[0],
369 &ms.metal_cfg,
370 &mut ms.kv[0],
371 pos_offset,
372 &ms.cos_buf,
373 &ms.sin_buf,
374 scratch,
375 );
376
377 for li in 1..ms.weights.len() {
379 let enc = cmd.new_blit_command_encoder();
381 enc.copy_from_buffer(&scratch.output, 0, input_buf, 0, (tokens * h * 4) as u64);
382 enc.end_encoding();
383
384 metal::transformer::metal_layer_forward_v2(
385 cmd,
386 &ms.pipes,
387 input_buf,
388 tokens,
389 &ms.weights[li],
390 &ms.metal_cfg,
391 &mut ms.kv[li],
392 pos_offset,
393 &ms.cos_buf,
394 &ms.sin_buf,
395 scratch,
396 );
397 }
398
399 cmd.commit();
401 cmd.wait_until_completed();
402
403 let hidden =
404 metal::pipelines::MetalPipelines::read_buffer(&scratch.output, tokens * h);
405 self.tokens_generated += tokens;
406 return self.final_rms_norm(&hidden, tokens);
407 }
408 } let mut hidden = input.to_vec();
412 for li in 0..self.cpu_layers.len() {
413 hidden = cpu::transformer::cpu_layer_forward(
414 &hidden,
415 tokens,
416 &self.cpu_layers[li],
417 &self.cfg,
418 &self.cos,
419 &self.sin,
420 &mut self.cpu_kv[li],
421 pos_offset,
422 );
423 }
424 self.tokens_generated += tokens;
425 self.final_rms_norm(&hidden, tokens)
426 }
427
428 #[cfg(feature = "metal")]
433 pub fn forward_gpu(
434 &mut self,
435 input: &[f32],
436 tokens: usize,
437 ) -> Option<(::metal::Buffer, usize)> {
438 let pos_offset = self.tokens_generated;
439 let h = self.cfg.hidden_size;
440
441 if self.use_cpu {
442 return None;
443 }
444
445 let ms = self.metal_state.as_mut()?;
446
447 if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
449 ms.scratch = Some(metal::transformer::LayerScratch::new(
450 &ms.pipes,
451 tokens,
452 h,
453 ms.metal_cfg.intermediate_size,
454 ms.metal_cfg.num_heads,
455 ms.metal_cfg.num_kv_heads,
456 ms.metal_cfg.head_dim,
457 ));
458 ms.max_scratch_tokens = tokens;
459 }
460 let scratch = ms.scratch.as_ref().unwrap();
461
462 let needed = tokens * h;
464 if ms.input_buf.is_none() || ms.input_buf_size < needed {
465 ms.input_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
466 ms.input_buf_size = needed.max(128 * h);
467 }
468 let input_buf = ms.input_buf.as_ref().unwrap();
469 unsafe {
470 std::ptr::copy_nonoverlapping(input.as_ptr(), input_buf.contents() as *mut f32, needed);
471 }
472
473 let cmd = ms.pipes.queue.new_command_buffer();
474
475 metal::transformer::metal_layer_forward_v2(
477 cmd,
478 &ms.pipes,
479 input_buf,
480 tokens,
481 &ms.weights[0],
482 &ms.metal_cfg,
483 &mut ms.kv[0],
484 pos_offset,
485 &ms.cos_buf,
486 &ms.sin_buf,
487 scratch,
488 );
489 for li in 1..ms.weights.len() {
490 let enc = cmd.new_blit_command_encoder();
491 enc.copy_from_buffer(&scratch.output, 0, input_buf, 0, (tokens * h * 4) as u64);
492 enc.end_encoding();
493 metal::transformer::metal_layer_forward_v2(
494 cmd,
495 &ms.pipes,
496 input_buf,
497 tokens,
498 &ms.weights[li],
499 &ms.metal_cfg,
500 &mut ms.kv[li],
501 pos_offset,
502 &ms.cos_buf,
503 &ms.sin_buf,
504 scratch,
505 );
506 }
507
508 if ms.norm_out_buf.is_none() {
510 ms.norm_out_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
511 }
512 let norm_out = ms.norm_out_buf.as_ref().unwrap();
513 {
514 let enc = cmd.new_compute_command_encoder();
515 ms.pipes.rms_norm_enc(
516 enc,
517 &scratch.output,
518 &ms.norm_w_buf,
519 norm_out,
520 tokens,
521 h,
522 self.cfg.rms_norm_eps as f32,
523 );
524 enc.end_encoding();
525 }
526
527 cmd.commit();
528 cmd.wait_until_completed();
529
530 self.tokens_generated += tokens;
531
532 let result = ms.pipes.buffer_empty(tokens * h);
534 let cmd2 = ms.pipes.queue.new_command_buffer();
536 let enc = cmd2.new_blit_command_encoder();
537 enc.copy_from_buffer(norm_out, 0, &result, 0, (tokens * h * 4) as u64);
538 enc.end_encoding();
539 cmd2.commit();
540 cmd2.wait_until_completed();
541
542 Some((result, tokens * h))
543 }
544
545 #[cfg(feature = "metal")]
549 pub fn forward_and_argmax(
550 &mut self,
551 input_buf: &GpuBuffer,
552 tokens: usize,
553 lm_weights_buf: &GpuBuffer,
554 vocab_size: usize,
555 ) -> Option<(u32, Vec<f32>)> {
556 let pos_offset = self.tokens_generated;
557 let h = self.cfg.hidden_size;
558 if self.use_cpu {
559 return None;
560 }
561
562 let ms = self.metal_state.as_mut()?;
563
564 if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
566 ms.scratch = Some(metal::transformer::LayerScratch::new(
567 &ms.pipes,
568 tokens,
569 h,
570 ms.metal_cfg.intermediate_size,
571 ms.metal_cfg.num_heads,
572 ms.metal_cfg.num_kv_heads,
573 ms.metal_cfg.head_dim,
574 ));
575 ms.max_scratch_tokens = tokens;
576 }
577 let scratch = ms.scratch.as_ref().unwrap();
578 let needed = tokens * h;
579 if ms.input_buf.is_none() || ms.input_buf_size < needed {
580 ms.input_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
581 ms.input_buf_size = needed.max(128 * h);
582 }
583 let int_buf = ms.input_buf.as_ref().unwrap();
584 if ms.norm_out_buf.is_none() {
585 ms.norm_out_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
586 }
587 let norm_out = ms.norm_out_buf.as_ref().unwrap();
588
589 let cmd = ms.pipes.queue.new_command_buffer();
591
592 metal::transformer::metal_layer_forward_v2(
593 cmd,
594 &ms.pipes,
595 input_buf,
596 tokens,
597 &ms.weights[0],
598 &ms.metal_cfg,
599 &mut ms.kv[0],
600 pos_offset,
601 &ms.cos_buf,
602 &ms.sin_buf,
603 scratch,
604 );
605 for li in 1..ms.weights.len() {
606 let enc = cmd.new_blit_command_encoder();
607 enc.copy_from_buffer(&scratch.output, 0, int_buf, 0, (needed * 4) as u64);
608 enc.end_encoding();
609 metal::transformer::metal_layer_forward_v2(
610 cmd,
611 &ms.pipes,
612 int_buf,
613 tokens,
614 &ms.weights[li],
615 &ms.metal_cfg,
616 &mut ms.kv[li],
617 pos_offset,
618 &ms.cos_buf,
619 &ms.sin_buf,
620 scratch,
621 );
622 }
623
624 {
626 let enc = cmd.new_compute_command_encoder();
627 ms.pipes.rms_norm_enc(
628 enc,
629 &scratch.output,
630 &ms.norm_w_buf,
631 norm_out,
632 tokens,
633 h,
634 self.cfg.rms_norm_eps as f32,
635 );
636 enc.end_encoding();
637 }
638
639 let logits_buf = if ms.input_buf_size >= vocab_size {
642 &scratch.gate_buf } else {
645 &scratch.gate_buf
646 };
647 {
648 let enc = cmd.new_compute_command_encoder();
649 ms.pipes
650 .gemm_v2(enc, norm_out, lm_weights_buf, logits_buf, 1, vocab_size, h);
651 enc.end_encoding();
652 }
653
654 let result_ptr = scratch.up_buf.contents() as *mut u32;
656 {
657 let enc = cmd.new_compute_command_encoder();
658 #[repr(C)]
659 struct P {
660 n: i32,
661 }
662 let p = P {
663 n: vocab_size as i32,
664 };
665 let p_buf = ms.pipes.device.new_buffer_with_data(
666 &p as *const _ as *const std::ffi::c_void,
667 4,
668 ::metal::MTLResourceOptions::StorageModeShared,
669 );
670 enc.set_compute_pipeline_state(ms.pipes.pipeline("argmax_f32"));
671 enc.set_buffer(0, Some(logits_buf), 0);
672 enc.set_buffer(1, Some(&scratch.up_buf), 0);
673 enc.set_buffer(2, Some(&p_buf), 0);
674 enc.dispatch_thread_groups(
675 ::metal::MTLSize::new(1, 1, 1),
676 ::metal::MTLSize::new(256, 1, 1),
677 );
678 enc.end_encoding();
679 }
680
681 cmd.commit();
682 cmd.wait_until_completed();
683 self.tokens_generated += tokens;
684
685 let token = unsafe { *result_ptr };
687 let hidden_vec = metal::pipelines::MetalPipelines::read_buffer(norm_out, needed);
688
689 Some((token, hidden_vec))
690 }
691
692 #[cfg(feature = "metal")]
695 pub fn forward_gpu_buffer(
696 &mut self,
697 input_buf: &::metal::Buffer,
698 tokens: usize,
699 ) -> Option<::metal::Buffer> {
700 let pos_offset = self.tokens_generated;
701 let h = self.cfg.hidden_size;
702 if self.use_cpu {
703 return None;
704 }
705 let ms = self.metal_state.as_mut()?;
706
707 if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
708 ms.scratch = Some(metal::transformer::LayerScratch::new(
709 &ms.pipes,
710 tokens,
711 h,
712 ms.metal_cfg.intermediate_size,
713 ms.metal_cfg.num_heads,
714 ms.metal_cfg.num_kv_heads,
715 ms.metal_cfg.head_dim,
716 ));
717 ms.max_scratch_tokens = tokens;
718 }
719 let scratch = ms.scratch.as_ref().unwrap();
720
721 let cmd = ms.pipes.queue.new_command_buffer();
722
723 metal::transformer::metal_layer_forward_v2(
725 cmd,
726 &ms.pipes,
727 input_buf,
728 tokens,
729 &ms.weights[0],
730 &ms.metal_cfg,
731 &mut ms.kv[0],
732 pos_offset,
733 &ms.cos_buf,
734 &ms.sin_buf,
735 scratch,
736 );
737 let needed = tokens * h;
739 if ms.input_buf.is_none() || ms.input_buf_size < needed {
740 ms.input_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
741 ms.input_buf_size = needed.max(128 * h);
742 }
743 let int_buf = ms.input_buf.as_ref().unwrap();
744
745 for li in 1..ms.weights.len() {
746 let enc = cmd.new_blit_command_encoder();
747 enc.copy_from_buffer(&scratch.output, 0, int_buf, 0, (tokens * h * 4) as u64);
748 enc.end_encoding();
749 metal::transformer::metal_layer_forward_v2(
750 cmd,
751 &ms.pipes,
752 int_buf,
753 tokens,
754 &ms.weights[li],
755 &ms.metal_cfg,
756 &mut ms.kv[li],
757 pos_offset,
758 &ms.cos_buf,
759 &ms.sin_buf,
760 scratch,
761 );
762 }
763
764 if ms.norm_out_buf.is_none() {
766 ms.norm_out_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
767 }
768 let norm_out = ms.norm_out_buf.as_ref().unwrap();
769 {
770 let enc = cmd.new_compute_command_encoder();
771 ms.pipes.rms_norm_enc(
772 enc,
773 &scratch.output,
774 &ms.norm_w_buf,
775 norm_out,
776 tokens,
777 h,
778 self.cfg.rms_norm_eps as f32,
779 );
780 enc.end_encoding();
781 }
782
783 cmd.commit();
784 cmd.wait_until_completed();
785 self.tokens_generated += tokens;
786
787 let result = ms.pipes.buffer_empty(tokens * h);
789 let cmd2 = ms.pipes.queue.new_command_buffer();
790 let enc = cmd2.new_blit_command_encoder();
791 enc.copy_from_buffer(norm_out, 0, &result, 0, (tokens * h * 4) as u64);
792 enc.end_encoding();
793 cmd2.commit();
794 cmd2.wait_until_completed();
795
796 Some(result)
797 }
798
799 #[cfg(feature = "metal")]
802 pub fn forward_gpu_to_vec(&mut self, input: &[f32], tokens: usize) -> Option<Vec<f32>> {
803 let h = self.cfg.hidden_size;
804 let (buf, _) = self.forward_gpu(input, tokens)?;
805 Some(metal::pipelines::MetalPipelines::read_buffer(
806 &buf,
807 tokens * h,
808 ))
809 }
810
811 fn final_rms_norm(&self, hidden: &[f32], tokens: usize) -> Vec<f32> {
812 let h = self.cfg.hidden_size;
813 let eps = self.cfg.rms_norm_eps as f32;
814 let mut out = vec![0.0f32; tokens * h];
815 for t in 0..tokens {
816 let row = &hidden[t * h..(t + 1) * h];
817 let o = &mut out[t * h..(t + 1) * h];
818 let sum_sq;
820 #[cfg(feature = "metal")]
821 {
822 extern "C" {
823 fn vDSP_dotpr(
824 a: *const f32,
825 a_stride: i32,
826 b: *const f32,
827 b_stride: i32,
828 result: *mut f32,
829 n: u64,
830 );
831 }
832 let mut dot = 0.0f32;
833 unsafe {
834 vDSP_dotpr(row.as_ptr(), 1, row.as_ptr(), 1, &mut dot, h as u64);
835 }
836 sum_sq = dot;
837 }
838 #[cfg(not(feature = "metal"))]
839 {
840 let mut v = 0.0f32;
841 for &val in row {
842 v += val * val;
843 }
844 sum_sq = v;
845 }
846 let inv = 1.0f32 / (sum_sq / h as f32 + eps).sqrt();
847 for i in 0..h {
848 o[i] = row[i] * inv * self.norm_w[i];
849 }
850 }
851 out
852 }
853
854 pub fn create_gpu_buffer(&self, data: &[f32]) -> Option<GpuBuffer> {
858 #[cfg(feature = "metal")]
859 {
860 let ms = self.metal_state.as_ref()?;
861 Some(ms.pipes.buffer_from_data(data))
862 }
863 #[cfg(not(feature = "metal"))]
864 {
865 Some(data.to_vec())
866 }
867 }
868
869 pub fn reset(&mut self) {
870 self.tokens_generated = 0;
871 for kv in &mut self.cpu_kv {
872 *kv = cpu::transformer::CpuKvCache::new();
873 }
874 #[cfg(feature = "metal")]
875 if let Some(ref mut ms) = self.metal_state {
876 for kv in &mut ms.kv {
877 kv.reset();
878 }
879 }
880 }
881}
882
#[cfg(test)]
mod tests {
    use super::*;

    /// `"1"` enables a flag and `"0"` leaves it disabled.
    #[test]
    fn attention_runtime_env_parses_forced_backends() {
        let vars = [("FERRUM_FUSED_CPU", "1"), ("FERRUM_FUSED_METAL", "0")];
        let parsed = AttentionRuntimeEnv::from_env_vars(vars);

        assert_eq!(
            parsed,
            AttentionRuntimeEnv {
                fused_cpu: true,
                fused_metal: false,
            }
        );
    }

    /// Only the literal `"1"` counts as enabled; `"true"` does not.
    #[test]
    fn attention_runtime_env_only_accepts_one() {
        let vars = [("FERRUM_FUSED_CPU", "true"), ("FERRUM_FUSED_METAL", "1")];
        let parsed = AttentionRuntimeEnv::from_env_vars(vars);

        assert_eq!(
            parsed,
            AttentionRuntimeEnv {
                fused_cpu: false,
                fused_metal: true,
            }
        );
    }
}