Skip to main content

ferrum_kernels/attention/
mod.rs

1//! ferrum-attention: Fused flash attention and transformer for Metal, CUDA, and CPU.
2//!
3//! Single-kernel attention (QK^T + softmax + attn@V) with no intermediate buffer
4//! materialization. Full transformer layer with all ops fused on GPU.
5
6#![allow(dead_code, unused_imports, unused_variables, unused_mut, unused_parens)]
7
8use std::sync::OnceLock;
9
10pub mod cpu;
11
12#[cfg(feature = "metal")]
13pub mod metal;
14
/// Opaque GPU buffer handed out by `FusedTransformer::create_gpu_buffer`.
///
/// Without the `metal` feature there is no device, so the "GPU buffer" is a plain host vector.
#[cfg(not(feature = "metal"))]
pub type GpuBuffer = Vec<f32>;
/// Opaque GPU buffer handed out by `FusedTransformer::create_gpu_buffer`.
///
/// With the `metal` feature enabled this is a Metal buffer.
#[cfg(feature = "metal")]
pub type GpuBuffer = ::metal::Buffer;
20
/// Backend overrides for the fused attention / transformer kernels, read from the process
/// environment.
///
/// A flag is enabled only when its variable is set to exactly `"1"`; any other value, or an
/// unset variable, leaves it disabled.
#[derive(Debug, Clone, PartialEq, Eq)]
struct AttentionRuntimeEnv {
    /// `FERRUM_FUSED_CPU=1`: force the CPU path.
    fused_cpu: bool,
    /// `FERRUM_FUSED_METAL=1`: request the Metal path.
    fused_metal: bool,
}

impl AttentionRuntimeEnv {
    /// Environment variable that forces the CPU backend.
    const CPU_VAR: &'static str = "FERRUM_FUSED_CPU";
    /// Environment variable that requests the Metal backend.
    const METAL_VAR: &'static str = "FERRUM_FUSED_METAL";

    /// Reads the override variables from the current process environment.
    ///
    /// Only the two relevant keys are looked up. Iterating `std::env::vars()` would panic as
    /// soon as *any* variable in the environment, even an unrelated one, is not valid Unicode.
    /// A non-Unicode value of one of our own keys is simply treated as "not `1`".
    fn from_env() -> Self {
        let lookup = |key: &'static str| std::env::var(key).ok().map(|value| (key, value));
        Self::from_env_vars(
            lookup(Self::CPU_VAR)
                .into_iter()
                .chain(lookup(Self::METAL_VAR)),
        )
    }

    /// Builds the configuration from arbitrary `(key, value)` pairs.
    ///
    /// Unknown keys are ignored; when a key appears more than once the last occurrence wins.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut env = Self {
            fused_cpu: false,
            fused_metal: false,
        };
        for (key, value) in vars {
            let enabled = value.as_ref() == "1";
            match key.as_ref() {
                Self::CPU_VAR => env.fused_cpu = enabled,
                Self::METAL_VAR => env.fused_metal = enabled,
                _ => {}
            }
        }
        env
    }
}
55
56fn attention_runtime_env() -> &'static AttentionRuntimeEnv {
57    static CONFIG: OnceLock<AttentionRuntimeEnv> = OnceLock::new();
58    CONFIG.get_or_init(AttentionRuntimeEnv::from_env)
59}
60
/// Attention configuration passed to the CPU and Metal fused-attention kernels.
///
/// Every field is a plain scalar, so the struct is `Copy`. `Default` yields an all-zero,
/// non-causal configuration with the sliding window disabled (`sliding_window == 0`).
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct AttentionParams {
    /// Batch size.
    pub batch: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads (presumably divides `num_heads` for grouped-query
    /// attention — TODO confirm against the kernels).
    pub num_kv_heads: usize,
    /// Number of query positions.
    pub q_len: usize,
    /// Number of key/value positions.
    pub kv_len: usize,
    /// Per-head dimension.
    pub head_dim: usize,
    /// Whether to apply causal masking.
    pub causal: bool,
    /// Position offset of the queries (presumably the absolute position of the first query
    /// within the key/value sequence — TODO confirm against the kernels).
    pub pos_offset: usize,
    /// Sliding-window size. `0` = full causal (default). Mistral v0.1 and
    /// Gemma use 4096; later Mistral versions set this to 0 to disable.
    pub sliding_window: usize,
}
76
77/// Run fused attention on CPU.
78pub fn attention_cpu(q: &[f32], k: &[f32], v: &[f32], out: &mut [f32], params: &AttentionParams) {
79    cpu::fused_attention(q, k, v, out, params);
80}
81
82/// Run fused attention on best available backend.
83pub fn attention(q: &[f32], k: &[f32], v: &[f32], out: &mut [f32], params: &AttentionParams) {
84    #[cfg(feature = "metal")]
85    {
86        if metal::is_available() {
87            metal::fused_attention(q, k, v, out, params);
88            return;
89        }
90    }
91    cpu::fused_attention(q, k, v, out, params);
92}
93
94// ── Fused Transformer ───────────────────────────────────────────────────
95
/// Transformer layer configuration.
///
/// All dimensions are element counts. `FusedTransformer::new` takes its layer count from the
/// number of `LayerWeights` it receives, not from `num_layers`.
#[derive(Clone, Debug, PartialEq)]
pub struct TransformerConfig {
    /// Width of the residual stream (the `hidden` dimension of every forward call).
    pub hidden_size: usize,
    /// MLP intermediate width.
    pub intermediate_size: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Per-head dimension; the RoPE tables hold `head_dim / 2` frequencies per position.
    pub head_dim: usize,
    /// Declared number of layers.
    pub num_layers: usize,
    /// Epsilon added to the mean square in the RMSNorms.
    pub rms_norm_eps: f64,
    /// RoPE base θ; frequency `i` is `1 / rope_theta^(2i / head_dim)`.
    pub rope_theta: f64,
    /// Maximum sequence length. `FusedTransformer` caps its RoPE tables at 32768 positions
    /// and its Metal KV caches at 4096.
    pub max_position_embeddings: usize,
}
109
/// Per-layer weights as flat f32 vectors (extracted from safetensors once at init).
///
/// Layers without QK-norm leave `q_norm_w` and `k_norm_w` empty. When uploading, the Metal
/// path substitutes a 1-element placeholder buffer for an empty norm weight and derives
/// `has_qk_norm` from `q_norm_w` alone. `Default` gives an all-empty layer without
/// layer scales.
#[derive(Clone, Default)]
pub struct LayerWeights {
    /// Input (pre-attention) norm weight.
    pub input_ln_w: Vec<f32>,
    /// Query projection.
    pub q_proj_w: Vec<f32>,
    /// Key projection.
    pub k_proj_w: Vec<f32>,
    /// Value projection.
    pub v_proj_w: Vec<f32>,
    /// Attention output projection.
    pub o_proj_w: Vec<f32>,
    /// Optional query norm weight; empty when the model has no QK-norm.
    pub q_norm_w: Vec<f32>,
    /// Optional key norm weight; empty when the model has no QK-norm.
    pub k_norm_w: Vec<f32>,
    /// Post-attention norm weight.
    pub post_ln_w: Vec<f32>,
    /// MLP gate projection.
    pub gate_proj_w: Vec<f32>,
    /// MLP up projection.
    pub up_proj_w: Vec<f32>,
    /// MLP down projection.
    pub down_proj_w: Vec<f32>,
    /// Optional layer_scale for attention (vocoder transformer uses this, talker doesn't)
    pub attn_layer_scale: Option<Vec<f32>>,
    /// Optional layer_scale for MLP
    pub mlp_layer_scale: Option<Vec<f32>>,
}
128
129/// A complete N-layer fused transformer. All ops bypass candle.
130pub struct FusedTransformer {
131    cfg: TransformerConfig,
132    cos: Vec<f32>,
133    sin: Vec<f32>,
134    norm_w: Vec<f32>,
135
136    #[cfg(feature = "metal")]
137    metal_state: Option<MetalTransformerState>,
138
139    // CPU state
140    cpu_layers: Vec<LayerWeights>,
141    cpu_kv: Vec<cpu::transformer::CpuKvCache>,
142    tokens_generated: usize,
143    /// true = always use CPU path (skips Metal even if available).
144    /// Auto-set for small models where Metal sync overhead > compute benefit.
145    /// Only read on Metal-enabled builds.
146    #[allow(dead_code)]
147    use_cpu: bool,
148}
149
/// GPU-resident half of a `FusedTransformer` on Metal builds.
///
/// Holds the compiled pipelines, the uploaded weights and RoPE tables, one KV cache per
/// layer, and the scratch/staging buffers that are reused across forward calls.
#[cfg(feature = "metal")]
struct MetalTransformerState {
    /// Compiled pipelines together with the device and command queue they use.
    pipes: metal::pipelines::MetalPipelines,
    /// Per-layer weights uploaded from `LayerWeights`.
    weights: Vec<metal::transformer::MetalLayerWeights>,
    /// One KV cache per layer.
    kv: Vec<metal::transformer::MetalKvCache>,
    /// RoPE cosine table on the GPU.
    cos_buf: ::metal::Buffer,
    /// RoPE sine table on the GPU.
    sin_buf: ::metal::Buffer,
    /// Layer dimensions in the form the Metal kernels take.
    metal_cfg: metal::transformer::MetalTransformerConfig,
    /// Per-layer intermediates; allocated lazily, regrown when more tokens are requested.
    scratch: Option<metal::transformer::LayerScratch>,
    /// Number of tokens `scratch` was allocated for.
    max_scratch_tokens: usize,
    /// Staging buffer for host input and for the inter-layer ping-pong copy.
    input_buf: Option<::metal::Buffer>,
    /// Capacity of `input_buf`, in f32 elements (not bytes).
    input_buf_size: usize,
    /// Final RMSNorm weight, kept on the GPU for the GPU-side norm.
    norm_w_buf: ::metal::Buffer,
    /// Reusable final-norm output buffer, so calls do not allocate one each time.
    norm_out_buf: Option<::metal::Buffer>,
}
167
168impl FusedTransformer {
169    /// Create from pre-extracted layer weights.
170    pub fn new(cfg: TransformerConfig, layers: Vec<LayerWeights>, norm_w: Vec<f32>) -> Self {
171        // Precompute cos/sin
172        let hd = cfg.head_dim;
173        let half = hd / 2;
174        let max_seq = cfg.max_position_embeddings.min(32768);
175        let mut cos = vec![0.0f32; max_seq * half];
176        let mut sin = vec![0.0f32; max_seq * half];
177        for pos in 0..max_seq {
178            for i in 0..half {
179                let freq = 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64);
180                let angle = pos as f64 * freq;
181                cos[pos * half + i] = angle.cos() as f32;
182                sin[pos * half + i] = angle.sin() as f32;
183            }
184        }
185
186        let n = layers.len();
187        let cpu_kv = (0..n)
188            .map(|_| cpu::transformer::CpuKvCache::new())
189            .collect();
190
191        // Backend selection:
192        //   FERRUM_FUSED_CPU=1  → force CPU
193        //   FERRUM_FUSED_CPU=1  → force CPU
194        //   FERRUM_FUSED_METAL=1 → force Metal
195        //   otherwise → auto: Metal for large models (28-layer talker), CPU for small (SubTalker/vocoder)
196        let runtime_env = attention_runtime_env();
197        let use_cpu = if runtime_env.fused_cpu {
198            true
199        } else if runtime_env.fused_metal {
200            false
201        } else {
202            // Metal for talker (28) + vocoder (8), CPU for SubTalker (5) only
203            // All Metal: total pipeline is faster even though SubTalker per-step is slower,
204            // because GPU pipeline amortizes overhead across the full decode loop.
205            false
206        };
207
208        #[cfg(feature = "metal")]
209        let metal_state = {
210            if let Some(device) = ::metal::Device::system_default() {
211                let pipes = metal::pipelines::MetalPipelines::new(&device);
212                let weights: Vec<_> = layers
213                    .iter()
214                    .map(|lw| {
215                        metal::transformer::MetalLayerWeights {
216                            input_ln_w: pipes.buffer_from_data(&lw.input_ln_w),
217                            q_proj_w: pipes.buffer_from_data(&lw.q_proj_w),
218                            k_proj_w: pipes.buffer_from_data(&lw.k_proj_w),
219                            v_proj_w: pipes.buffer_from_data(&lw.v_proj_w),
220                            o_proj_w: pipes.buffer_from_data(&lw.o_proj_w),
221                            q_norm_w: if lw.q_norm_w.is_empty() {
222                                pipes.buffer_from_data(&[1.0f32]) // dummy, won't be used
223                            } else {
224                                pipes.buffer_from_data(&lw.q_norm_w)
225                            },
226                            k_norm_w: if lw.k_norm_w.is_empty() {
227                                pipes.buffer_from_data(&[1.0f32])
228                            } else {
229                                pipes.buffer_from_data(&lw.k_norm_w)
230                            },
231                            post_ln_w: pipes.buffer_from_data(&lw.post_ln_w),
232                            gate_proj_w: pipes.buffer_from_data(&lw.gate_proj_w),
233                            up_proj_w: pipes.buffer_from_data(&lw.up_proj_w),
234                            down_proj_w: pipes.buffer_from_data(&lw.down_proj_w),
235                            has_qk_norm: !lw.q_norm_w.is_empty(),
236                            attn_scale: lw
237                                .attn_layer_scale
238                                .as_ref()
239                                .map(|s| pipes.buffer_from_data(s)),
240                            mlp_scale: lw
241                                .mlp_layer_scale
242                                .as_ref()
243                                .map(|s| pipes.buffer_from_data(s)),
244                        }
245                    })
246                    .collect();
247                let kv_max_len = cfg.max_position_embeddings.min(4096);
248                let kv = (0..n)
249                    .map(|_| {
250                        metal::transformer::MetalKvCache::new(
251                            &pipes,
252                            cfg.num_kv_heads,
253                            cfg.head_dim,
254                            kv_max_len,
255                        )
256                    })
257                    .collect();
258                let metal_cfg = metal::transformer::MetalTransformerConfig {
259                    hidden_size: cfg.hidden_size,
260                    intermediate_size: cfg.intermediate_size,
261                    num_heads: cfg.num_heads,
262                    num_kv_heads: cfg.num_kv_heads,
263                    head_dim: cfg.head_dim,
264                    rms_norm_eps: cfg.rms_norm_eps as f32,
265                };
266                let cos_buf = pipes.buffer_from_data(&cos);
267                let sin_buf = pipes.buffer_from_data(&sin);
268                let norm_w_buf = pipes.buffer_from_data(&norm_w);
269                Some(MetalTransformerState {
270                    pipes,
271                    weights,
272                    kv,
273                    cos_buf,
274                    sin_buf,
275                    metal_cfg,
276                    scratch: None,
277                    max_scratch_tokens: 0,
278                    input_buf: None,
279                    input_buf_size: 0,
280                    norm_w_buf,
281                    norm_out_buf: None,
282                })
283            } else {
284                None
285            }
286        };
287
288        // Log backend selection (visible with RUST_LOG=info)
289        #[cfg(feature = "metal")]
290        {
291            let backend = if use_cpu {
292                "CPU (Accelerate)"
293            } else {
294                "Metal+Accelerate"
295            };
296            tracing::info!(
297                "FusedTransformer: backend={backend}, hidden={}, layers={n}",
298                cfg.hidden_size
299            );
300        }
301        #[cfg(not(feature = "metal"))]
302        tracing::info!(
303            "FusedTransformer: backend=CPU, hidden={}, layers={n}",
304            cfg.hidden_size
305        );
306
307        FusedTransformer {
308            cfg,
309            cos,
310            sin,
311            norm_w,
312            #[cfg(feature = "metal")]
313            metal_state,
314            cpu_layers: layers,
315            cpu_kv,
316            tokens_generated: 0,
317            use_cpu,
318        }
319    }
320
321    /// Forward: input [tokens, hidden] → output [tokens, hidden] (f32 vecs).
322    pub fn forward(&mut self, input: &[f32], tokens: usize) -> Vec<f32> {
323        let pos_offset = self.tokens_generated;
324        #[cfg(feature = "metal")]
325        let h = self.cfg.hidden_size;
326
327        #[cfg(feature = "metal")]
328        if !self.use_cpu {
329            if let Some(ref mut ms) = self.metal_state {
330                // Allocate/resize scratch buffers if needed
331                if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
332                    ms.scratch = Some(metal::transformer::LayerScratch::new(
333                        &ms.pipes,
334                        tokens,
335                        h,
336                        ms.metal_cfg.intermediate_size,
337                        ms.metal_cfg.num_heads,
338                        ms.metal_cfg.num_kv_heads,
339                        ms.metal_cfg.head_dim,
340                    ));
341                    ms.max_scratch_tokens = tokens;
342                }
343                let scratch = ms.scratch.as_ref().unwrap();
344
345                // Reuse or allocate input buffer (shared memory = zero-copy write on Apple Silicon)
346                let needed = tokens * h;
347                if ms.input_buf.is_none() || ms.input_buf_size < needed {
348                    ms.input_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h))); // preallocate for up to 128 tokens
349                    ms.input_buf_size = needed.max(128 * h);
350                }
351                let input_buf = ms.input_buf.as_ref().unwrap();
352                unsafe {
353                    std::ptr::copy_nonoverlapping(
354                        input.as_ptr(),
355                        input_buf.contents() as *mut f32,
356                        needed,
357                    );
358                }
359
360                let cmd = ms.pipes.queue.new_command_buffer();
361
362                // Layer 0: input from input_buf
363                metal::transformer::metal_layer_forward_v2(
364                    cmd,
365                    &ms.pipes,
366                    input_buf,
367                    tokens,
368                    &ms.weights[0],
369                    &ms.metal_cfg,
370                    &mut ms.kv[0],
371                    pos_offset,
372                    &ms.cos_buf,
373                    &ms.sin_buf,
374                    scratch,
375                );
376
377                // Layers 1..N: input from scratch.output (ping-pong via copy)
378                for li in 1..ms.weights.len() {
379                    // Copy scratch.output to input_buf for next layer
380                    let enc = cmd.new_blit_command_encoder();
381                    enc.copy_from_buffer(&scratch.output, 0, input_buf, 0, (tokens * h * 4) as u64);
382                    enc.end_encoding();
383
384                    metal::transformer::metal_layer_forward_v2(
385                        cmd,
386                        &ms.pipes,
387                        input_buf,
388                        tokens,
389                        &ms.weights[li],
390                        &ms.metal_cfg,
391                        &mut ms.kv[li],
392                        pos_offset,
393                        &ms.cos_buf,
394                        &ms.sin_buf,
395                        scratch,
396                    );
397                }
398
399                // Single commit+wait for all layers
400                cmd.commit();
401                cmd.wait_until_completed();
402
403                let hidden =
404                    metal::pipelines::MetalPipelines::read_buffer(&scratch.output, tokens * h);
405                self.tokens_generated += tokens;
406                return self.final_rms_norm(&hidden, tokens);
407            }
408        } // !self.use_cpu
409
410        // CPU path (Accelerate sgemm + SIMD element-wise)
411        let mut hidden = input.to_vec();
412        for li in 0..self.cpu_layers.len() {
413            hidden = cpu::transformer::cpu_layer_forward(
414                &hidden,
415                tokens,
416                &self.cpu_layers[li],
417                &self.cfg,
418                &self.cos,
419                &self.sin,
420                &mut self.cpu_kv[li],
421                pos_offset,
422            );
423        }
424        self.tokens_generated += tokens;
425        self.final_rms_norm(&hidden, tokens)
426    }
427
428    /// Forward on GPU, returning Metal Buffer directly (zero CPU transfer).
429    /// Input: raw f32 slice (will be copied to GPU once).
430    /// Output: Metal Buffer containing normed hidden [tokens, hidden].
431    /// Falls back to CPU path if Metal not available.
432    #[cfg(feature = "metal")]
433    pub fn forward_gpu(
434        &mut self,
435        input: &[f32],
436        tokens: usize,
437    ) -> Option<(::metal::Buffer, usize)> {
438        let pos_offset = self.tokens_generated;
439        let h = self.cfg.hidden_size;
440
441        if self.use_cpu {
442            return None;
443        }
444
445        let ms = self.metal_state.as_mut()?;
446
447        // Allocate scratch
448        if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
449            ms.scratch = Some(metal::transformer::LayerScratch::new(
450                &ms.pipes,
451                tokens,
452                h,
453                ms.metal_cfg.intermediate_size,
454                ms.metal_cfg.num_heads,
455                ms.metal_cfg.num_kv_heads,
456                ms.metal_cfg.head_dim,
457            ));
458            ms.max_scratch_tokens = tokens;
459        }
460        let scratch = ms.scratch.as_ref().unwrap();
461
462        // Input buffer
463        let needed = tokens * h;
464        if ms.input_buf.is_none() || ms.input_buf_size < needed {
465            ms.input_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
466            ms.input_buf_size = needed.max(128 * h);
467        }
468        let input_buf = ms.input_buf.as_ref().unwrap();
469        unsafe {
470            std::ptr::copy_nonoverlapping(input.as_ptr(), input_buf.contents() as *mut f32, needed);
471        }
472
473        let cmd = ms.pipes.queue.new_command_buffer();
474
475        // All transformer layers
476        metal::transformer::metal_layer_forward_v2(
477            cmd,
478            &ms.pipes,
479            input_buf,
480            tokens,
481            &ms.weights[0],
482            &ms.metal_cfg,
483            &mut ms.kv[0],
484            pos_offset,
485            &ms.cos_buf,
486            &ms.sin_buf,
487            scratch,
488        );
489        for li in 1..ms.weights.len() {
490            let enc = cmd.new_blit_command_encoder();
491            enc.copy_from_buffer(&scratch.output, 0, input_buf, 0, (tokens * h * 4) as u64);
492            enc.end_encoding();
493            metal::transformer::metal_layer_forward_v2(
494                cmd,
495                &ms.pipes,
496                input_buf,
497                tokens,
498                &ms.weights[li],
499                &ms.metal_cfg,
500                &mut ms.kv[li],
501                pos_offset,
502                &ms.cos_buf,
503                &ms.sin_buf,
504                scratch,
505            );
506        }
507
508        // Final RMSNorm on GPU
509        if ms.norm_out_buf.is_none() {
510            ms.norm_out_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
511        }
512        let norm_out = ms.norm_out_buf.as_ref().unwrap();
513        {
514            let enc = cmd.new_compute_command_encoder();
515            ms.pipes.rms_norm_enc(
516                enc,
517                &scratch.output,
518                &ms.norm_w_buf,
519                norm_out,
520                tokens,
521                h,
522                self.cfg.rms_norm_eps as f32,
523            );
524            enc.end_encoding();
525        }
526
527        cmd.commit();
528        cmd.wait_until_completed();
529
530        self.tokens_generated += tokens;
531
532        // Return buffer with normed hidden (stays on GPU)
533        let result = ms.pipes.buffer_empty(tokens * h);
534        // Copy norm_out to result (so caller owns it)
535        let cmd2 = ms.pipes.queue.new_command_buffer();
536        let enc = cmd2.new_blit_command_encoder();
537        enc.copy_from_buffer(norm_out, 0, &result, 0, (tokens * h * 4) as u64);
538        enc.end_encoding();
539        cmd2.commit();
540        cmd2.wait_until_completed();
541
542        Some((result, tokens * h))
543    }
544
545    /// Forward + lm_head + argmax in ONE command buffer, ZERO extra allocs.
546    /// Returns (token_index, norm_hidden as Vec<f32>).
547    /// Pre-allocated buffers reused across calls.
548    #[cfg(feature = "metal")]
549    pub fn forward_and_argmax(
550        &mut self,
551        input_buf: &GpuBuffer,
552        tokens: usize,
553        lm_weights_buf: &GpuBuffer,
554        vocab_size: usize,
555    ) -> Option<(u32, Vec<f32>)> {
556        let pos_offset = self.tokens_generated;
557        let h = self.cfg.hidden_size;
558        if self.use_cpu {
559            return None;
560        }
561
562        let ms = self.metal_state.as_mut()?;
563
564        // Ensure scratch allocated
565        if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
566            ms.scratch = Some(metal::transformer::LayerScratch::new(
567                &ms.pipes,
568                tokens,
569                h,
570                ms.metal_cfg.intermediate_size,
571                ms.metal_cfg.num_heads,
572                ms.metal_cfg.num_kv_heads,
573                ms.metal_cfg.head_dim,
574            ));
575            ms.max_scratch_tokens = tokens;
576        }
577        let scratch = ms.scratch.as_ref().unwrap();
578        let needed = tokens * h;
579        if ms.input_buf.is_none() || ms.input_buf_size < needed {
580            ms.input_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
581            ms.input_buf_size = needed.max(128 * h);
582        }
583        let int_buf = ms.input_buf.as_ref().unwrap();
584        if ms.norm_out_buf.is_none() {
585            ms.norm_out_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
586        }
587        let norm_out = ms.norm_out_buf.as_ref().unwrap();
588
589        // === SINGLE command buffer: layers + norm + lm_head + argmax ===
590        let cmd = ms.pipes.queue.new_command_buffer();
591
592        metal::transformer::metal_layer_forward_v2(
593            cmd,
594            &ms.pipes,
595            input_buf,
596            tokens,
597            &ms.weights[0],
598            &ms.metal_cfg,
599            &mut ms.kv[0],
600            pos_offset,
601            &ms.cos_buf,
602            &ms.sin_buf,
603            scratch,
604        );
605        for li in 1..ms.weights.len() {
606            let enc = cmd.new_blit_command_encoder();
607            enc.copy_from_buffer(&scratch.output, 0, int_buf, 0, (needed * 4) as u64);
608            enc.end_encoding();
609            metal::transformer::metal_layer_forward_v2(
610                cmd,
611                &ms.pipes,
612                int_buf,
613                tokens,
614                &ms.weights[li],
615                &ms.metal_cfg,
616                &mut ms.kv[li],
617                pos_offset,
618                &ms.cos_buf,
619                &ms.sin_buf,
620                scratch,
621            );
622        }
623
624        // RMSNorm
625        {
626            let enc = cmd.new_compute_command_encoder();
627            ms.pipes.rms_norm_enc(
628                enc,
629                &scratch.output,
630                &ms.norm_w_buf,
631                norm_out,
632                tokens,
633                h,
634                self.cfg.rms_norm_eps as f32,
635            );
636            enc.end_encoding();
637        }
638
639        // lm_head GEMM (reuse int_buf as logits since it's ≥ vocab_size for small models)
640        // For safety, use a dedicated buffer only if needed
641        let logits_buf = if ms.input_buf_size >= vocab_size {
642            // Can't reuse int_buf — it might be read by norm. Use scratch.ln_out as temp.
643            &scratch.gate_buf // gate_buf is large enough (intermediate_size ≥ vocab_size)
644        } else {
645            &scratch.gate_buf
646        };
647        {
648            let enc = cmd.new_compute_command_encoder();
649            ms.pipes
650                .gemm_v2(enc, norm_out, lm_weights_buf, logits_buf, 1, vocab_size, h);
651            enc.end_encoding();
652        }
653
654        // Argmax (reuse scratch.up_buf for result — just need 1 u32 = 4 bytes)
655        let result_ptr = scratch.up_buf.contents() as *mut u32;
656        {
657            let enc = cmd.new_compute_command_encoder();
658            #[repr(C)]
659            struct P {
660                n: i32,
661            }
662            let p = P {
663                n: vocab_size as i32,
664            };
665            let p_buf = ms.pipes.device.new_buffer_with_data(
666                &p as *const _ as *const std::ffi::c_void,
667                4,
668                ::metal::MTLResourceOptions::StorageModeShared,
669            );
670            enc.set_compute_pipeline_state(ms.pipes.pipeline("argmax_f32"));
671            enc.set_buffer(0, Some(logits_buf), 0);
672            enc.set_buffer(1, Some(&scratch.up_buf), 0);
673            enc.set_buffer(2, Some(&p_buf), 0);
674            enc.dispatch_thread_groups(
675                ::metal::MTLSize::new(1, 1, 1),
676                ::metal::MTLSize::new(256, 1, 1),
677            );
678            enc.end_encoding();
679        }
680
681        cmd.commit();
682        cmd.wait_until_completed();
683        self.tokens_generated += tokens;
684
685        // Read results from shared memory (zero-copy on Apple Silicon)
686        let token = unsafe { *result_ptr };
687        let hidden_vec = metal::pipelines::MetalPipelines::read_buffer(norm_out, needed);
688
689        Some((token, hidden_vec))
690    }
691
692    /// Forward on GPU from a Metal Buffer input. Zero CPU transfer.
693    /// Returns normed hidden as Metal Buffer.
694    #[cfg(feature = "metal")]
695    pub fn forward_gpu_buffer(
696        &mut self,
697        input_buf: &::metal::Buffer,
698        tokens: usize,
699    ) -> Option<::metal::Buffer> {
700        let pos_offset = self.tokens_generated;
701        let h = self.cfg.hidden_size;
702        if self.use_cpu {
703            return None;
704        }
705        let ms = self.metal_state.as_mut()?;
706
707        if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
708            ms.scratch = Some(metal::transformer::LayerScratch::new(
709                &ms.pipes,
710                tokens,
711                h,
712                ms.metal_cfg.intermediate_size,
713                ms.metal_cfg.num_heads,
714                ms.metal_cfg.num_kv_heads,
715                ms.metal_cfg.head_dim,
716            ));
717            ms.max_scratch_tokens = tokens;
718        }
719        let scratch = ms.scratch.as_ref().unwrap();
720
721        let cmd = ms.pipes.queue.new_command_buffer();
722
723        // All transformer layers (input from caller's buffer)
724        metal::transformer::metal_layer_forward_v2(
725            cmd,
726            &ms.pipes,
727            input_buf,
728            tokens,
729            &ms.weights[0],
730            &ms.metal_cfg,
731            &mut ms.kv[0],
732            pos_offset,
733            &ms.cos_buf,
734            &ms.sin_buf,
735            scratch,
736        );
737        // Use ms.input_buf as intermediate for layers 1..N
738        let needed = tokens * h;
739        if ms.input_buf.is_none() || ms.input_buf_size < needed {
740            ms.input_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
741            ms.input_buf_size = needed.max(128 * h);
742        }
743        let int_buf = ms.input_buf.as_ref().unwrap();
744
745        for li in 1..ms.weights.len() {
746            let enc = cmd.new_blit_command_encoder();
747            enc.copy_from_buffer(&scratch.output, 0, int_buf, 0, (tokens * h * 4) as u64);
748            enc.end_encoding();
749            metal::transformer::metal_layer_forward_v2(
750                cmd,
751                &ms.pipes,
752                int_buf,
753                tokens,
754                &ms.weights[li],
755                &ms.metal_cfg,
756                &mut ms.kv[li],
757                pos_offset,
758                &ms.cos_buf,
759                &ms.sin_buf,
760                scratch,
761            );
762        }
763
764        // Final RMSNorm on GPU
765        if ms.norm_out_buf.is_none() {
766            ms.norm_out_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
767        }
768        let norm_out = ms.norm_out_buf.as_ref().unwrap();
769        {
770            let enc = cmd.new_compute_command_encoder();
771            ms.pipes.rms_norm_enc(
772                enc,
773                &scratch.output,
774                &ms.norm_w_buf,
775                norm_out,
776                tokens,
777                h,
778                self.cfg.rms_norm_eps as f32,
779            );
780            enc.end_encoding();
781        }
782
783        cmd.commit();
784        cmd.wait_until_completed();
785        self.tokens_generated += tokens;
786
787        // Return copy of norm output
788        let result = ms.pipes.buffer_empty(tokens * h);
789        let cmd2 = ms.pipes.queue.new_command_buffer();
790        let enc = cmd2.new_blit_command_encoder();
791        enc.copy_from_buffer(norm_out, 0, &result, 0, (tokens * h * 4) as u64);
792        enc.end_encoding();
793        cmd2.commit();
794        cmd2.wait_until_completed();
795
796        Some(result)
797    }
798
799    /// Forward on GPU with GPU-side norm, returns Vec<f32>.
800    /// Avoids CPU-side RMSNorm but still transfers output to CPU.
801    #[cfg(feature = "metal")]
802    pub fn forward_gpu_to_vec(&mut self, input: &[f32], tokens: usize) -> Option<Vec<f32>> {
803        let h = self.cfg.hidden_size;
804        let (buf, _) = self.forward_gpu(input, tokens)?;
805        Some(metal::pipelines::MetalPipelines::read_buffer(
806            &buf,
807            tokens * h,
808        ))
809    }
810
811    fn final_rms_norm(&self, hidden: &[f32], tokens: usize) -> Vec<f32> {
812        let h = self.cfg.hidden_size;
813        let eps = self.cfg.rms_norm_eps as f32;
814        let mut out = vec![0.0f32; tokens * h];
815        for t in 0..tokens {
816            let row = &hidden[t * h..(t + 1) * h];
817            let o = &mut out[t * h..(t + 1) * h];
818            // vDSP_dotpr for sum-of-squares (same SIMD path as PyTorch on macOS)
819            let sum_sq;
820            #[cfg(feature = "metal")]
821            {
822                extern "C" {
823                    fn vDSP_dotpr(
824                        a: *const f32,
825                        a_stride: i32,
826                        b: *const f32,
827                        b_stride: i32,
828                        result: *mut f32,
829                        n: u64,
830                    );
831                }
832                let mut dot = 0.0f32;
833                unsafe {
834                    vDSP_dotpr(row.as_ptr(), 1, row.as_ptr(), 1, &mut dot, h as u64);
835                }
836                sum_sq = dot;
837            }
838            #[cfg(not(feature = "metal"))]
839            {
840                let mut v = 0.0f32;
841                for &val in row {
842                    v += val * val;
843                }
844                sum_sq = v;
845            }
846            let inv = 1.0f32 / (sum_sq / h as f32 + eps).sqrt();
847            for i in 0..h {
848                o[i] = row[i] * inv * self.norm_w[i];
849            }
850        }
851        out
852    }
853
854    /// Create a Metal buffer from f32 data (shared memory, zero-copy on Apple Silicon).
855    /// Returns None if Metal not available.
856    /// Create a GPU buffer from f32 data. Returns None if GPU not available.
857    pub fn create_gpu_buffer(&self, data: &[f32]) -> Option<GpuBuffer> {
858        #[cfg(feature = "metal")]
859        {
860            let ms = self.metal_state.as_ref()?;
861            Some(ms.pipes.buffer_from_data(data))
862        }
863        #[cfg(not(feature = "metal"))]
864        {
865            Some(data.to_vec())
866        }
867    }
868
869    pub fn reset(&mut self) {
870        self.tokens_generated = 0;
871        for kv in &mut self.cpu_kv {
872            *kv = cpu::transformer::CpuKvCache::new();
873        }
874        #[cfg(feature = "metal")]
875        if let Some(ref mut ms) = self.metal_state {
876            for kv in &mut ms.kv {
877                kv.reset();
878            }
879        }
880    }
881}
882
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn attention_runtime_env_parses_forced_backends() {
        let env = AttentionRuntimeEnv::from_env_vars([
            ("FERRUM_FUSED_CPU", "1"),
            ("FERRUM_FUSED_METAL", "0"),
        ]);

        assert!(env.fused_cpu);
        assert!(!env.fused_metal);
    }

    #[test]
    fn attention_runtime_env_only_accepts_one() {
        let env = AttentionRuntimeEnv::from_env_vars([
            ("FERRUM_FUSED_CPU", "true"),
            ("FERRUM_FUSED_METAL", "1"),
        ]);

        assert!(!env.fused_cpu);
        assert!(env.fused_metal);
    }

    /// With no overrides at all, neither backend is forced (automatic selection).
    #[test]
    fn attention_runtime_env_defaults_to_auto_selection() {
        let env = AttentionRuntimeEnv::from_env_vars(std::iter::empty::<(&str, &str)>());

        assert!(!env.fused_cpu);
        assert!(!env.fused_metal);
    }

    /// Unrelated keys are ignored, and a repeated key takes its last value.
    #[test]
    fn attention_runtime_env_ignores_unrelated_keys_and_last_value_wins() {
        let env = AttentionRuntimeEnv::from_env_vars([
            ("FERRUM_FUSED_GPU", "1"),
            ("FERRUM_FUSED_METAL", "1"),
            ("FERRUM_FUSED_METAL", "0"),
        ]);

        assert!(!env.fused_cpu);
        assert!(!env.fused_metal);
    }
}