#![allow(dead_code, unused_imports, unused_variables, unused_mut, unused_parens)]
use std::sync::OnceLock;
/// CPU backend; always compiled.
pub mod cpu;
/// Metal backend; compiled only when the `metal` cargo feature is enabled.
#[cfg(feature = "metal")]
pub mod metal;
/// Buffer handle used by the GPU-facing entry points.
/// With the `metal` feature this is a Metal buffer from the external `metal` crate.
#[cfg(feature = "metal")]
pub type GpuBuffer = ::metal::Buffer;
/// Without the `metal` feature the "GPU buffer" is a plain host-side `Vec<f32>`.
#[cfg(not(feature = "metal"))]
pub type GpuBuffer = Vec<f32>;
/// Backend override switches read from environment variables.
///
/// A switch is on only when its variable holds exactly the string `"1"`;
/// every other value (including `"true"`) leaves it off.
#[derive(Debug, Clone, PartialEq, Eq)]
struct AttentionRuntimeEnv {
    /// Set by `FERRUM_FUSED_CPU=1`.
    fused_cpu: bool,
    /// Set by `FERRUM_FUSED_METAL=1`.
    fused_metal: bool,
}

impl AttentionRuntimeEnv {
    /// Reads the switches from the current process environment.
    fn from_env() -> Self {
        let process_vars = std::env::vars();
        Self::from_env_vars(process_vars)
    }

    /// Builds the switches from an arbitrary sequence of `(name, value)` pairs.
    ///
    /// Names other than `FERRUM_FUSED_CPU` / `FERRUM_FUSED_METAL` are ignored.
    /// When a name occurs more than once, the last occurrence decides.
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        let mut parsed = Self {
            fused_cpu: false,
            fused_metal: false,
        };
        for (name, raw) in vars {
            let name = name.as_ref();
            let enabled = raw.as_ref() == "1";
            if name == "FERRUM_FUSED_CPU" {
                parsed.fused_cpu = enabled;
            } else if name == "FERRUM_FUSED_METAL" {
                parsed.fused_metal = enabled;
            }
        }
        parsed
    }
}
/// Returns the process-wide backend switches.
///
/// The environment is read on the first call only; later calls return the
/// same cached value.
fn attention_runtime_env() -> &'static AttentionRuntimeEnv {
    static CACHED_ENV: OnceLock<AttentionRuntimeEnv> = OnceLock::new();
    CACHED_ENV.get_or_init(|| AttentionRuntimeEnv::from_env())
}
/// Parameter bundle handed to the fused attention kernels
/// (`cpu::fused_attention` and, with the `metal` feature, `metal::fused_attention`).
///
/// NOTE(review): the kernels are defined in other modules, so the exact meaning
/// of each field (tensor layout, what `sliding_window == 0` means, how
/// `pos_offset` interacts with `causal`) is not visible here — confirm there.
/// `Default` yields all-zero sizes and `causal == false`.
#[derive(Clone, Debug, Default)]
pub struct AttentionParams {
pub batch: usize,
pub num_heads: usize,
pub num_kv_heads: usize,
pub q_len: usize,
pub kv_len: usize,
pub head_dim: usize,
pub causal: bool,
pub pos_offset: usize,
pub sliding_window: usize,
}
/// Runs fused attention on the CPU backend unconditionally.
///
/// Forwards all arguments unchanged to `cpu::fused_attention`; `out` is the
/// buffer that call is given to write into.
pub fn attention_cpu(q: &[f32], k: &[f32], v: &[f32], out: &mut [f32], params: &AttentionParams) {
cpu::fused_attention(q, k, v, out, params);
}
/// Runs fused attention on the best backend compiled into this build.
///
/// With the `metal` feature enabled and `metal::is_available()` returning
/// true, the Metal kernel handles the call; in every other case the CPU
/// kernel does.
pub fn attention(q: &[f32], k: &[f32], v: &[f32], out: &mut [f32], params: &AttentionParams) {
    // Metal takes priority whenever it is compiled in and usable at runtime.
    #[cfg(feature = "metal")]
    if metal::is_available() {
        metal::fused_attention(q, k, v, out, params);
        return;
    }

    // CPU fallback.
    cpu::fused_attention(q, k, v, out, params);
}
/// Model hyper-parameters consumed by `FusedTransformer`.
#[derive(Clone)]
pub struct TransformerConfig {
/// Width of one hidden-state row; `forward` expects `tokens * hidden_size` input values.
pub hidden_size: usize,
pub intermediate_size: usize,
pub num_heads: usize,
pub num_kv_heads: usize,
/// Per-head width; the rotary tables hold `head_dim / 2` entries per position.
pub head_dim: usize,
// NOTE(review): not read by the code in this file, which uses the length of
// the supplied layer vector instead — confirm whether other modules rely on it.
pub num_layers: usize,
/// Epsilon added inside the square root of the final RMS norm (narrowed to `f32` when used).
pub rms_norm_eps: f64,
/// Base of the rotary frequency schedule: `1 / rope_theta^(2i / head_dim)`.
pub rope_theta: f64,
/// Upper bound on positions; the rotary tables are additionally capped at 32768
/// and the Metal KV cache at 4096 positions.
pub max_position_embeddings: usize,
}
/// Host-side weights of one transformer layer, stored as flat `f32` vectors.
///
/// NOTE(review): the matrix layouts (row/column major, dimensions) are defined
/// by the CPU/Metal layer kernels in other modules — confirm there.
pub struct LayerWeights {
pub input_ln_w: Vec<f32>,
pub q_proj_w: Vec<f32>,
pub k_proj_w: Vec<f32>,
pub v_proj_w: Vec<f32>,
pub o_proj_w: Vec<f32>,
/// Optional Q norm weight. An empty vector means "no QK norm": the Metal
/// upload substitutes a one-element placeholder and sets `has_qk_norm = false`.
pub q_norm_w: Vec<f32>,
/// Optional K norm weight; on the Metal path an empty vector is replaced by a
/// one-element placeholder (the `has_qk_norm` flag is derived from `q_norm_w` only).
pub k_norm_w: Vec<f32>,
pub post_ln_w: Vec<f32>,
pub gate_proj_w: Vec<f32>,
pub up_proj_w: Vec<f32>,
pub down_proj_w: Vec<f32>,
/// Optional per-layer scale for the attention branch; uploaded to the GPU only when present.
pub attn_layer_scale: Option<Vec<f32>>,
/// Optional per-layer scale for the MLP branch; uploaded to the GPU only when present.
pub mlp_layer_scale: Option<Vec<f32>>,
}
/// A stack of transformer layers plus a final RMS norm, runnable on the CPU
/// and, when built with the `metal` feature and a device exists, on Metal.
pub struct FusedTransformer {
cfg: TransformerConfig,
/// Rotary cosine table, laid out as `[position][head_dim / 2]`.
cos: Vec<f32>,
/// Rotary sine table, same layout as `cos`.
sin: Vec<f32>,
/// Weight of the final RMS norm (indexed `0..hidden_size`).
norm_w: Vec<f32>,
/// GPU-side state; `None` when no system default Metal device was found.
#[cfg(feature = "metal")]
metal_state: Option<MetalTransformerState>,
/// Host copies of the layer weights, used by the CPU path.
cpu_layers: Vec<LayerWeights>,
/// One CPU KV cache per layer.
cpu_kv: Vec<cpu::transformer::CpuKvCache>,
/// Number of positions processed so far; passed to the layers as the position offset.
tokens_generated: usize,
/// True when `FERRUM_FUSED_CPU=1` forces the CPU path.
#[allow(dead_code)]
use_cpu: bool,
}
/// Everything the Metal path keeps alive between forward calls.
#[cfg(feature = "metal")]
struct MetalTransformerState {
/// Pipelines, device and command queue wrapper.
pipes: metal::pipelines::MetalPipelines,
/// Per-layer weights uploaded to GPU buffers.
weights: Vec<metal::transformer::MetalLayerWeights>,
/// One GPU KV cache per layer.
kv: Vec<metal::transformer::MetalKvCache>,
/// GPU copy of the rotary cosine table.
cos_buf: ::metal::Buffer,
/// GPU copy of the rotary sine table.
sin_buf: ::metal::Buffer,
metal_cfg: metal::transformer::MetalTransformerConfig,
/// Per-layer scratch buffers; allocated on first use and re-created when a
/// call asks for more tokens than `max_scratch_tokens`.
scratch: Option<metal::transformer::LayerScratch>,
/// Token count the current `scratch` was created for.
max_scratch_tokens: usize,
/// Staging buffer holding the input of the layer about to run.
input_buf: Option<::metal::Buffer>,
/// Size that was passed to `buffer_empty` for `input_buf`
/// (presumably a count of `f32` elements — confirm against `MetalPipelines`).
input_buf_size: usize,
/// GPU copy of the final RMS norm weight.
norm_w_buf: ::metal::Buffer,
/// Output of the final RMS norm; allocated lazily by the GPU forward methods.
norm_out_buf: Option<::metal::Buffer>,
}
impl FusedTransformer {
/// Builds a transformer from its configuration, per-layer weights and final norm weight.
///
/// * Precomputes the rotary `cos`/`sin` tables for
///   `min(max_position_embeddings, 32768)` positions, `head_dim / 2` entries each.
/// * Creates one empty CPU KV cache per layer.
/// * With the `metal` feature, and if a system default device exists, uploads all
///   weights and tables to GPU buffers and creates one GPU KV cache per layer
///   (capped at `min(max_position_embeddings, 4096)` positions).
/// * Logs the backend that `forward` will actually use.
///
/// The CPU path is forced when `FERRUM_FUSED_CPU=1` is set in the environment.
pub fn new(cfg: TransformerConfig, layers: Vec<LayerWeights>, norm_w: Vec<f32>) -> Self {
    let hd = cfg.head_dim;
    let half = hd / 2;
    let max_seq = cfg.max_position_embeddings.min(32768);

    // The inverse frequency depends only on the pair index, so evaluate the
    // `powf` once per index instead of once per (position, index) pair.
    let inv_freq: Vec<f64> = (0..half)
        .map(|i| 1.0f64 / cfg.rope_theta.powf((2 * i) as f64 / hd as f64))
        .collect();
    let mut cos: Vec<f32> = Vec::with_capacity(max_seq * half);
    let mut sin: Vec<f32> = Vec::with_capacity(max_seq * half);
    for pos in 0..max_seq {
        // Entries are pushed in `pos * half + i` order.
        for &freq in &inv_freq {
            let angle = pos as f64 * freq;
            cos.push(angle.cos() as f32);
            sin.push(angle.sin() as f32);
        }
    }

    let n = layers.len();
    let cpu_kv = (0..n)
        .map(|_| cpu::transformer::CpuKvCache::new())
        .collect();

    // Only FERRUM_FUSED_CPU changes the selection: FERRUM_FUSED_METAL and
    // "nothing set" both leave the CPU override off.
    let use_cpu = attention_runtime_env().fused_cpu;

    #[cfg(feature = "metal")]
    let metal_state = {
        if let Some(device) = ::metal::Device::system_default() {
            let pipes = metal::pipelines::MetalPipelines::new(&device);
            let weights: Vec<_> = layers
                .iter()
                .map(|lw| {
                    metal::transformer::MetalLayerWeights {
                        input_ln_w: pipes.buffer_from_data(&lw.input_ln_w),
                        q_proj_w: pipes.buffer_from_data(&lw.q_proj_w),
                        k_proj_w: pipes.buffer_from_data(&lw.k_proj_w),
                        v_proj_w: pipes.buffer_from_data(&lw.v_proj_w),
                        o_proj_w: pipes.buffer_from_data(&lw.o_proj_w),
                        // An empty norm weight means "no QK norm"; upload a
                        // one-element placeholder so the field is always a buffer.
                        q_norm_w: if lw.q_norm_w.is_empty() {
                            pipes.buffer_from_data(&[1.0f32])
                        } else {
                            pipes.buffer_from_data(&lw.q_norm_w)
                        },
                        k_norm_w: if lw.k_norm_w.is_empty() {
                            pipes.buffer_from_data(&[1.0f32])
                        } else {
                            pipes.buffer_from_data(&lw.k_norm_w)
                        },
                        post_ln_w: pipes.buffer_from_data(&lw.post_ln_w),
                        gate_proj_w: pipes.buffer_from_data(&lw.gate_proj_w),
                        up_proj_w: pipes.buffer_from_data(&lw.up_proj_w),
                        down_proj_w: pipes.buffer_from_data(&lw.down_proj_w),
                        has_qk_norm: !lw.q_norm_w.is_empty(),
                        attn_scale: lw
                            .attn_layer_scale
                            .as_ref()
                            .map(|s| pipes.buffer_from_data(s)),
                        mlp_scale: lw
                            .mlp_layer_scale
                            .as_ref()
                            .map(|s| pipes.buffer_from_data(s)),
                    }
                })
                .collect();
            let kv_max_len = cfg.max_position_embeddings.min(4096);
            let kv = (0..n)
                .map(|_| {
                    metal::transformer::MetalKvCache::new(
                        &pipes,
                        cfg.num_kv_heads,
                        cfg.head_dim,
                        kv_max_len,
                    )
                })
                .collect();
            let metal_cfg = metal::transformer::MetalTransformerConfig {
                hidden_size: cfg.hidden_size,
                intermediate_size: cfg.intermediate_size,
                num_heads: cfg.num_heads,
                num_kv_heads: cfg.num_kv_heads,
                head_dim: cfg.head_dim,
                rms_norm_eps: cfg.rms_norm_eps as f32,
            };
            let cos_buf = pipes.buffer_from_data(&cos);
            let sin_buf = pipes.buffer_from_data(&sin);
            let norm_w_buf = pipes.buffer_from_data(&norm_w);
            Some(MetalTransformerState {
                pipes,
                weights,
                kv,
                cos_buf,
                sin_buf,
                metal_cfg,
                scratch: None,
                max_scratch_tokens: 0,
                input_buf: None,
                input_buf_size: 0,
                norm_w_buf,
                norm_out_buf: None,
            })
        } else {
            None
        }
    };

    #[cfg(feature = "metal")]
    {
        // Without a Metal device `forward` falls back to the CPU path, so
        // report the CPU backend in that case as well.
        let backend = if use_cpu || metal_state.is_none() {
            "CPU (Accelerate)"
        } else {
            "Metal+Accelerate"
        };
        tracing::info!(
            "FusedTransformer: backend={backend}, hidden={}, layers={n}",
            cfg.hidden_size
        );
    }
    #[cfg(not(feature = "metal"))]
    tracing::info!(
        "FusedTransformer: backend=CPU, hidden={}, layers={n}",
        cfg.hidden_size
    );

    FusedTransformer {
        cfg,
        cos,
        sin,
        norm_w,
        #[cfg(feature = "metal")]
        metal_state,
        cpu_layers: layers,
        cpu_kv,
        tokens_generated: 0,
        use_cpu,
    }
}
/// Runs `tokens` new positions through every layer and returns the final
/// RMS-normed hidden states as `tokens * hidden_size` values.
///
/// `input` supplies the hidden states of the new positions, one row of
/// `hidden_size` values per token. The Metal path is used when the crate is
/// built with the `metal` feature, the CPU override is off and a Metal state
/// exists; otherwise the CPU layers run. Either way the internal position
/// counter advances by `tokens`.
///
/// # Panics
///
/// On the Metal path, panics if `input` holds fewer than
/// `tokens * hidden_size` values, or if the model has no layers.
pub fn forward(&mut self, input: &[f32], tokens: usize) -> Vec<f32> {
    let pos_offset = self.tokens_generated;
    #[cfg(feature = "metal")]
    let h = self.cfg.hidden_size;
    #[cfg(feature = "metal")]
    if !self.use_cpu {
        if let Some(ref mut ms) = self.metal_state {
            let needed = tokens * h;
            // The raw copy below reads `needed` floats from `input`; reject
            // short inputs up front instead of reading out of bounds.
            assert!(
                input.len() >= needed,
                "forward: input holds {} values but tokens * hidden_size is {}",
                input.len(),
                needed
            );

            // (Re)create the scratch buffers when this call needs more tokens
            // than the current set was built for.
            if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
                ms.scratch = Some(metal::transformer::LayerScratch::new(
                    &ms.pipes,
                    tokens,
                    h,
                    ms.metal_cfg.intermediate_size,
                    ms.metal_cfg.num_heads,
                    ms.metal_cfg.num_kv_heads,
                    ms.metal_cfg.head_dim,
                ));
                ms.max_scratch_tokens = tokens;
            }
            let scratch = ms.scratch.as_ref().expect("scratch allocated above");

            // Grow the staging buffer on demand, never below 128 rows.
            if ms.input_buf.is_none() || ms.input_buf_size < needed {
                let capacity = needed.max(128 * h);
                ms.input_buf = Some(ms.pipes.buffer_empty(capacity));
                ms.input_buf_size = capacity;
            }
            let input_buf = ms.input_buf.as_ref().expect("input buffer allocated above");

            // SAFETY: `input` has at least `needed` elements (asserted above).
            // The destination was created by `buffer_empty` with at least
            // `needed` as its size argument (assumed to be an `f32` element
            // count — confirm against `MetalPipelines::buffer_empty`), and a
            // Metal buffer's storage does not alias the borrowed `input` slice.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    input.as_ptr(),
                    input_buf.contents() as *mut f32,
                    needed,
                );
            }

            let cmd = ms.pipes.queue.new_command_buffer();
            metal::transformer::metal_layer_forward_v2(
                cmd,
                &ms.pipes,
                input_buf,
                tokens,
                &ms.weights[0],
                &ms.metal_cfg,
                &mut ms.kv[0],
                pos_offset,
                &ms.cos_buf,
                &ms.sin_buf,
                scratch,
            );
            for li in 1..ms.weights.len() {
                // Feed the previous layer's output back in as this layer's input.
                let enc = cmd.new_blit_command_encoder();
                enc.copy_from_buffer(&scratch.output, 0, input_buf, 0, (tokens * h * 4) as u64);
                enc.end_encoding();
                metal::transformer::metal_layer_forward_v2(
                    cmd,
                    &ms.pipes,
                    input_buf,
                    tokens,
                    &ms.weights[li],
                    &ms.metal_cfg,
                    &mut ms.kv[li],
                    pos_offset,
                    &ms.cos_buf,
                    &ms.sin_buf,
                    scratch,
                );
            }
            cmd.commit();
            cmd.wait_until_completed();

            let hidden =
                metal::pipelines::MetalPipelines::read_buffer(&scratch.output, tokens * h);
            self.tokens_generated += tokens;
            // The final norm runs on the host for this entry point.
            return self.final_rms_norm(&hidden, tokens);
        }
    }

    // CPU path: thread the hidden states through each layer in order.
    let mut hidden = input.to_vec();
    for li in 0..self.cpu_layers.len() {
        hidden = cpu::transformer::cpu_layer_forward(
            &hidden,
            tokens,
            &self.cpu_layers[li],
            &self.cfg,
            &self.cos,
            &self.sin,
            &mut self.cpu_kv[li],
            pos_offset,
        );
    }
    self.tokens_generated += tokens;
    self.final_rms_norm(&hidden, tokens)
}
/// Runs `tokens` new positions through every layer and the final RMS norm on
/// the GPU, returning a freshly allocated buffer with the normed hidden states
/// together with its element count (`tokens * hidden_size`).
///
/// Returns `None` when the CPU override is active or no Metal state exists.
/// On success the internal position counter advances by `tokens`.
///
/// # Panics
///
/// Panics if `input` holds fewer than `tokens * hidden_size` values, or if the
/// model has no layers.
#[cfg(feature = "metal")]
pub fn forward_gpu(
    &mut self,
    input: &[f32],
    tokens: usize,
) -> Option<(::metal::Buffer, usize)> {
    let pos_offset = self.tokens_generated;
    let h = self.cfg.hidden_size;
    if self.use_cpu {
        return None;
    }
    let ms = self.metal_state.as_mut()?;

    let needed = tokens * h;
    // The raw copy below reads `needed` floats from `input`; reject short
    // inputs up front instead of reading out of bounds.
    assert!(
        input.len() >= needed,
        "forward_gpu: input holds {} values but tokens * hidden_size is {}",
        input.len(),
        needed
    );

    if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
        ms.scratch = Some(metal::transformer::LayerScratch::new(
            &ms.pipes,
            tokens,
            h,
            ms.metal_cfg.intermediate_size,
            ms.metal_cfg.num_heads,
            ms.metal_cfg.num_kv_heads,
            ms.metal_cfg.head_dim,
        ));
        ms.max_scratch_tokens = tokens;
    }
    let scratch = ms.scratch.as_ref().expect("scratch allocated above");

    if ms.input_buf.is_none() || ms.input_buf_size < needed {
        let capacity = needed.max(128 * h);
        ms.input_buf = Some(ms.pipes.buffer_empty(capacity));
        ms.input_buf_size = capacity;
    }
    let input_buf = ms.input_buf.as_ref().expect("input buffer allocated above");

    // SAFETY: `input` has at least `needed` elements (asserted above). The
    // destination was created by `buffer_empty` with at least `needed` as its
    // size argument (assumed to be an `f32` element count — confirm against
    // `MetalPipelines::buffer_empty`), and it does not alias `input`.
    unsafe {
        std::ptr::copy_nonoverlapping(input.as_ptr(), input_buf.contents() as *mut f32, needed);
    }

    let cmd = ms.pipes.queue.new_command_buffer();
    metal::transformer::metal_layer_forward_v2(
        cmd,
        &ms.pipes,
        input_buf,
        tokens,
        &ms.weights[0],
        &ms.metal_cfg,
        &mut ms.kv[0],
        pos_offset,
        &ms.cos_buf,
        &ms.sin_buf,
        scratch,
    );
    for li in 1..ms.weights.len() {
        // Feed the previous layer's output back in as this layer's input.
        let enc = cmd.new_blit_command_encoder();
        enc.copy_from_buffer(&scratch.output, 0, input_buf, 0, (tokens * h * 4) as u64);
        enc.end_encoding();
        metal::transformer::metal_layer_forward_v2(
            cmd,
            &ms.pipes,
            input_buf,
            tokens,
            &ms.weights[li],
            &ms.metal_cfg,
            &mut ms.kv[li],
            pos_offset,
            &ms.cos_buf,
            &ms.sin_buf,
            scratch,
        );
    }

    // The norm output buffer is cached across calls; re-create it whenever the
    // cached one is too small for this call, otherwise the norm kernel would
    // write past its end after a smaller earlier call.
    let norm_bytes = (needed * std::mem::size_of::<f32>()) as u64;
    let norm_too_small = ms
        .norm_out_buf
        .as_ref()
        .map_or(true, |buf| buf.length() < norm_bytes);
    if norm_too_small {
        ms.norm_out_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
    }
    let norm_out = ms.norm_out_buf.as_ref().expect("norm buffer allocated above");
    {
        let enc = cmd.new_compute_command_encoder();
        ms.pipes.rms_norm_enc(
            enc,
            &scratch.output,
            &ms.norm_w_buf,
            norm_out,
            tokens,
            h,
            self.cfg.rms_norm_eps as f32,
        );
        enc.end_encoding();
    }
    cmd.commit();
    cmd.wait_until_completed();
    self.tokens_generated += tokens;

    // Hand the caller its own copy so the cached norm buffer can be reused.
    let result = ms.pipes.buffer_empty(tokens * h);
    let cmd2 = ms.pipes.queue.new_command_buffer();
    let enc = cmd2.new_blit_command_encoder();
    enc.copy_from_buffer(norm_out, 0, &result, 0, (tokens * h * 4) as u64);
    enc.end_encoding();
    cmd2.commit();
    cmd2.wait_until_completed();
    Some((result, tokens * h))
}
/// Runs the layers, the final RMS norm, an LM-head GEMM and an argmax kernel
/// in a single command buffer.
///
/// `input_buf` is a GPU buffer holding the input hidden states and
/// `lm_weights_buf` the LM-head weights. Returns the `u32` written by the
/// `argmax_f32` kernel together with the normed hidden states
/// (`tokens * hidden_size` values) read back to the host.
///
/// Returns `None` when the CPU override is active or no Metal state exists.
/// On success the internal position counter advances by `tokens`.
///
/// # Panics
///
/// Panics if the model has no layers.
#[cfg(feature = "metal")]
pub fn forward_and_argmax(
    &mut self,
    input_buf: &GpuBuffer,
    tokens: usize,
    lm_weights_buf: &GpuBuffer,
    vocab_size: usize,
) -> Option<(u32, Vec<f32>)> {
    let pos_offset = self.tokens_generated;
    let h = self.cfg.hidden_size;
    if self.use_cpu {
        return None;
    }
    let ms = self.metal_state.as_mut()?;

    if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
        ms.scratch = Some(metal::transformer::LayerScratch::new(
            &ms.pipes,
            tokens,
            h,
            ms.metal_cfg.intermediate_size,
            ms.metal_cfg.num_heads,
            ms.metal_cfg.num_kv_heads,
            ms.metal_cfg.head_dim,
        ));
        ms.max_scratch_tokens = tokens;
    }
    let scratch = ms.scratch.as_ref().expect("scratch allocated above");

    let needed = tokens * h;
    // Intermediate buffer that carries one layer's output into the next layer.
    if ms.input_buf.is_none() || ms.input_buf_size < needed {
        let capacity = needed.max(128 * h);
        ms.input_buf = Some(ms.pipes.buffer_empty(capacity));
        ms.input_buf_size = capacity;
    }
    let int_buf = ms.input_buf.as_ref().expect("input buffer allocated above");

    // The norm output buffer is cached across calls; re-create it whenever the
    // cached one is too small for this call, otherwise the norm kernel would
    // write past its end after a smaller earlier call.
    let norm_bytes = (needed * std::mem::size_of::<f32>()) as u64;
    let norm_too_small = ms
        .norm_out_buf
        .as_ref()
        .map_or(true, |buf| buf.length() < norm_bytes);
    if norm_too_small {
        ms.norm_out_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
    }
    let norm_out = ms.norm_out_buf.as_ref().expect("norm buffer allocated above");

    let cmd = ms.pipes.queue.new_command_buffer();
    // The first layer reads straight from the caller's buffer.
    metal::transformer::metal_layer_forward_v2(
        cmd,
        &ms.pipes,
        input_buf,
        tokens,
        &ms.weights[0],
        &ms.metal_cfg,
        &mut ms.kv[0],
        pos_offset,
        &ms.cos_buf,
        &ms.sin_buf,
        scratch,
    );
    for li in 1..ms.weights.len() {
        let enc = cmd.new_blit_command_encoder();
        enc.copy_from_buffer(&scratch.output, 0, int_buf, 0, (needed * 4) as u64);
        enc.end_encoding();
        metal::transformer::metal_layer_forward_v2(
            cmd,
            &ms.pipes,
            int_buf,
            tokens,
            &ms.weights[li],
            &ms.metal_cfg,
            &mut ms.kv[li],
            pos_offset,
            &ms.cos_buf,
            &ms.sin_buf,
            scratch,
        );
    }
    {
        let enc = cmd.new_compute_command_encoder();
        ms.pipes.rms_norm_enc(
            enc,
            &scratch.output,
            &ms.norm_w_buf,
            norm_out,
            tokens,
            h,
            self.cfg.rms_norm_eps as f32,
        );
        enc.end_encoding();
    }

    // The logits are written into the scratch gate buffer.
    // NOTE(review): nothing here checks that `gate_buf` can hold `vocab_size`
    // floats; its capacity is decided by `LayerScratch::new` — confirm.
    let logits_buf = &scratch.gate_buf;
    {
        let enc = cmd.new_compute_command_encoder();
        // NOTE(review): the GEMM is issued with 1 as its first size argument
        // and reads `norm_out` from offset 0, which looks like it scores only
        // the first token row — confirm the intended row when `tokens > 1`.
        ms.pipes
            .gemm_v2(enc, norm_out, lm_weights_buf, logits_buf, 1, vocab_size, h);
        enc.end_encoding();
    }

    // The argmax kernel writes its result into the start of `up_buf`.
    let result_ptr = scratch.up_buf.contents() as *const u32;
    {
        let enc = cmd.new_compute_command_encoder();
        #[repr(C)]
        struct P {
            n: i32,
        }
        let p = P {
            n: vocab_size as i32,
        };
        let p_buf = ms.pipes.device.new_buffer_with_data(
            &p as *const _ as *const std::ffi::c_void,
            4,
            ::metal::MTLResourceOptions::StorageModeShared,
        );
        enc.set_compute_pipeline_state(ms.pipes.pipeline("argmax_f32"));
        enc.set_buffer(0, Some(logits_buf), 0);
        enc.set_buffer(1, Some(&scratch.up_buf), 0);
        enc.set_buffer(2, Some(&p_buf), 0);
        enc.dispatch_thread_groups(
            ::metal::MTLSize::new(1, 1, 1),
            ::metal::MTLSize::new(256, 1, 1),
        );
        enc.end_encoding();
    }
    cmd.commit();
    cmd.wait_until_completed();
    self.tokens_generated += tokens;

    // SAFETY: the command buffer has completed, so the kernel's write is done;
    // `result_ptr` points at the start of `up_buf`, which is assumed to be
    // CPU-visible and at least four bytes long — confirm against `LayerScratch`.
    let token = unsafe { *result_ptr };
    let hidden_vec = metal::pipelines::MetalPipelines::read_buffer(norm_out, needed);
    Some((token, hidden_vec))
}
/// Same as the slice-based GPU forward pass, but takes its input from a
/// caller-provided GPU buffer and returns only the result buffer.
///
/// The result is a freshly allocated buffer holding the final RMS-normed
/// hidden states (`tokens * hidden_size` floats). Returns `None` when the CPU
/// override is active or no Metal state exists. On success the internal
/// position counter advances by `tokens`.
///
/// # Panics
///
/// Panics if the model has no layers.
#[cfg(feature = "metal")]
pub fn forward_gpu_buffer(
    &mut self,
    input_buf: &::metal::Buffer,
    tokens: usize,
) -> Option<::metal::Buffer> {
    let pos_offset = self.tokens_generated;
    let h = self.cfg.hidden_size;
    if self.use_cpu {
        return None;
    }
    let ms = self.metal_state.as_mut()?;

    if ms.scratch.is_none() || ms.max_scratch_tokens < tokens {
        ms.scratch = Some(metal::transformer::LayerScratch::new(
            &ms.pipes,
            tokens,
            h,
            ms.metal_cfg.intermediate_size,
            ms.metal_cfg.num_heads,
            ms.metal_cfg.num_kv_heads,
            ms.metal_cfg.head_dim,
        ));
        ms.max_scratch_tokens = tokens;
    }
    let scratch = ms.scratch.as_ref().expect("scratch allocated above");

    let cmd = ms.pipes.queue.new_command_buffer();
    // The first layer reads straight from the caller's buffer.
    metal::transformer::metal_layer_forward_v2(
        cmd,
        &ms.pipes,
        input_buf,
        tokens,
        &ms.weights[0],
        &ms.metal_cfg,
        &mut ms.kv[0],
        pos_offset,
        &ms.cos_buf,
        &ms.sin_buf,
        scratch,
    );

    let needed = tokens * h;
    // Intermediate buffer that carries one layer's output into the next layer.
    if ms.input_buf.is_none() || ms.input_buf_size < needed {
        let capacity = needed.max(128 * h);
        ms.input_buf = Some(ms.pipes.buffer_empty(capacity));
        ms.input_buf_size = capacity;
    }
    let int_buf = ms.input_buf.as_ref().expect("input buffer allocated above");
    for li in 1..ms.weights.len() {
        let enc = cmd.new_blit_command_encoder();
        enc.copy_from_buffer(&scratch.output, 0, int_buf, 0, (tokens * h * 4) as u64);
        enc.end_encoding();
        metal::transformer::metal_layer_forward_v2(
            cmd,
            &ms.pipes,
            int_buf,
            tokens,
            &ms.weights[li],
            &ms.metal_cfg,
            &mut ms.kv[li],
            pos_offset,
            &ms.cos_buf,
            &ms.sin_buf,
            scratch,
        );
    }

    // The norm output buffer is cached across calls; re-create it whenever the
    // cached one is too small for this call, otherwise the norm kernel would
    // write past its end after a smaller earlier call.
    let norm_bytes = (needed * std::mem::size_of::<f32>()) as u64;
    let norm_too_small = ms
        .norm_out_buf
        .as_ref()
        .map_or(true, |buf| buf.length() < norm_bytes);
    if norm_too_small {
        ms.norm_out_buf = Some(ms.pipes.buffer_empty(needed.max(128 * h)));
    }
    let norm_out = ms.norm_out_buf.as_ref().expect("norm buffer allocated above");
    {
        let enc = cmd.new_compute_command_encoder();
        ms.pipes.rms_norm_enc(
            enc,
            &scratch.output,
            &ms.norm_w_buf,
            norm_out,
            tokens,
            h,
            self.cfg.rms_norm_eps as f32,
        );
        enc.end_encoding();
    }
    cmd.commit();
    cmd.wait_until_completed();
    self.tokens_generated += tokens;

    // Hand the caller its own copy so the cached norm buffer can be reused.
    let result = ms.pipes.buffer_empty(tokens * h);
    let cmd2 = ms.pipes.queue.new_command_buffer();
    let enc = cmd2.new_blit_command_encoder();
    enc.copy_from_buffer(norm_out, 0, &result, 0, (tokens * h * 4) as u64);
    enc.end_encoding();
    cmd2.commit();
    cmd2.wait_until_completed();
    Some(result)
}
/// Convenience wrapper around `forward_gpu` that reads the resulting GPU
/// buffer back into a host vector of `tokens * hidden_size` values.
///
/// Yields `None` exactly when `forward_gpu` does.
#[cfg(feature = "metal")]
pub fn forward_gpu_to_vec(&mut self, input: &[f32], tokens: usize) -> Option<Vec<f32>> {
    let row_width = self.cfg.hidden_size;
    self.forward_gpu(input, tokens).map(|(gpu_out, _)| {
        metal::pipelines::MetalPipelines::read_buffer(&gpu_out, tokens * row_width)
    })
}
/// Applies the final RMS norm to `tokens` rows of `hidden` on the host.
///
/// For every row of `hidden_size` values the output is
/// `x[i] / sqrt(mean(x^2) + eps) * norm_w[i]`. Returns a new vector of
/// `tokens * hidden_size` values.
///
/// # Panics
///
/// Panics if `hidden` holds fewer than `tokens * hidden_size` values, or if
/// `norm_w` is shorter than `hidden_size` while at least one row is processed.
fn final_rms_norm(&self, hidden: &[f32], tokens: usize) -> Vec<f32> {
    // In metal builds the sum of squares comes from Accelerate's vDSP.
    // The C prototype uses `vDSP_Stride` (`long`) for the strides and
    // `vDSP_Length` (`unsigned long`) for the count, so declare them with the
    // matching C integer types rather than fixed-width ones.
    #[cfg(feature = "metal")]
    extern "C" {
        fn vDSP_dotpr(
            a: *const f32,
            a_stride: std::os::raw::c_long,
            b: *const f32,
            b_stride: std::os::raw::c_long,
            result: *mut f32,
            n: std::os::raw::c_ulong,
        );
    }

    let h = self.cfg.hidden_size;
    let eps = self.cfg.rms_norm_eps as f32;
    let mut out = vec![0.0f32; tokens * h];
    for t in 0..tokens {
        let row = &hidden[t * h..(t + 1) * h];
        let o = &mut out[t * h..(t + 1) * h];

        let sum_sq;
        #[cfg(feature = "metal")]
        {
            let mut dot = 0.0f32;
            // SAFETY: both inputs point at `row`, which holds exactly `h`
            // contiguous floats; stride 1 and count `h` stay inside it, and
            // `dot` is a valid, writable `f32`.
            unsafe {
                vDSP_dotpr(
                    row.as_ptr(),
                    1,
                    row.as_ptr(),
                    1,
                    &mut dot,
                    h as std::os::raw::c_ulong,
                );
            }
            sum_sq = dot;
        }
        #[cfg(not(feature = "metal"))]
        {
            // Plain left-to-right accumulation, starting from zero.
            sum_sq = row.iter().fold(0.0f32, |acc, &val| acc + val * val);
        }

        let inv = 1.0f32 / (sum_sq / h as f32 + eps).sqrt();
        for i in 0..h {
            o[i] = row[i] * inv * self.norm_w[i];
        }
    }
    out
}
/// Copies `data` into a new `GpuBuffer`.
///
/// In metal builds this uploads through the Metal pipelines and yields `None`
/// when no Metal state exists; in other builds it always succeeds with a
/// host-side copy of `data`.
pub fn create_gpu_buffer(&self, data: &[f32]) -> Option<GpuBuffer> {
    #[cfg(feature = "metal")]
    {
        self.metal_state
            .as_ref()
            .map(|state| state.pipes.buffer_from_data(data))
    }
    #[cfg(not(feature = "metal"))]
    {
        Some(Vec::from(data))
    }
}
/// Starts a new sequence: zeroes the position counter, replaces every CPU KV
/// cache with a fresh one and, in metal builds with a Metal state, resets
/// every GPU KV cache.
pub fn reset(&mut self) {
    self.tokens_generated = 0;
    self.cpu_kv.fill_with(cpu::transformer::CpuKvCache::new);
    #[cfg(feature = "metal")]
    if let Some(state) = self.metal_state.as_mut() {
        state.kv.iter_mut().for_each(|cache| {
            cache.reset();
        });
    }
}
}
// Unit tests for the environment-variable parsing of `AttentionRuntimeEnv`.
#[cfg(test)]
mod tests {
use super::*;
// "1" enables a switch; "0" leaves it off.
#[test]
fn attention_runtime_env_parses_forced_backends() {
let env = AttentionRuntimeEnv::from_env_vars([
("FERRUM_FUSED_CPU", "1"),
("FERRUM_FUSED_METAL", "0"),
]);
assert!(env.fused_cpu);
assert!(!env.fused_metal);
}
// Only the literal string "1" counts as enabled; "true" does not.
#[test]
fn attention_runtime_env_only_accepts_one() {
let env = AttentionRuntimeEnv::from_env_vars([
("FERRUM_FUSED_CPU", "true"),
("FERRUM_FUSED_METAL", "1"),
]);
assert!(!env.fused_cpu);
assert!(env.fused_metal);
}
}