use metal::{
Buffer, CommandBufferRef, ComputePipelineState, Device,
MTLResourceOptions, MTLSize, NSUInteger,
};
use super::deferred::{
gpu_batched_experts_begin, gpu_batched_experts_begin_pre_staged,
DeferredError,
};
use super::expert_forward::{ChainToNormed, MoeBuffers};
use super::expert_io::ExpertFiles;
use super::gpu_attn::{
encode_attn_scores_batched_into, encode_attn_softmax_batched_into,
encode_attn_values_batched_into, encode_sigmoid_gate_into,
GpuAttnPipelines,
};
use super::gpu_linear_attn::{
encode_compute_decay_beta, encode_conv1d_step, encode_delta_net_step,
encode_gated_rms_norm, encode_rms_norm_qk, LinearAttnPipelines,
};
use super::gpu_matvec::{encode_matvec, MatvecPipelines, MatvecSpec};
use super::gpu_norm::{encode_rms_norm_bf16_into, RmsNormBf16Pipelines};
use super::layer_weight_cache::LayerWeightCache;
use super::metal::{MetalBackend, MetalError};
use super::moe_router::moe_router_cpu;
use super::mtl_weight_buf::MtlWeightBuf;
use super::state::LinearAttnState;
use super::variants::{Variant, RMS_NORM_EPS, VARIANT};
use super::weight_file::WeightFile;
/// Everything that can go wrong while running one decoder layer's forward
/// pass; sub-module errors are wrapped via `From` so `?` works throughout.
#[derive(Debug, thiserror::Error)]
pub enum LayerForwardError {
    /// A required tensor (or layer-index mapping) was unavailable for `layer`.
    #[error("missing tensor for layer {layer}: {tensor}")]
    MissingTensor { layer: usize, tensor: &'static str },

    /// The hidden-state input had `actual` floats instead of `expected`.
    #[error("hidden_in must be HIDDEN_DIM={expected} floats, got {actual}")]
    BadHiddenLen {
        expected: usize,
        actual: usize,
    },

    /// Wrapped [`MetalError`].
    #[error("Metal: {0}")]
    Metal(#[from] MetalError),

    /// Wrapped MoE router error.
    #[error("MoE router: {0}")]
    Router(#[from] super::moe_router::MoeRouterError),

    /// Wrapped expert FFN error.
    #[error("expert FFN: {0}")]
    Expert(#[from] super::expert_forward::ExpertForwardError),

    /// Wrapped expert-file I/O error.
    #[error("expert I/O: {0}")]
    ExpertIo(#[from] super::expert_io::ExpertIoError),

    /// Wrapped RoPE error.
    #[error("RoPE: {0}")]
    Rope(#[from] super::rope::RopeError),

    /// Wrapped scaled-dot-product-attention error.
    #[error("SDPA: {0}")]
    Sdpa(#[from] super::sdpa::SdpaError),

    /// Wrapped RMSNorm error.
    #[error("RMSNorm: {0}")]
    RmsNorm(#[from] super::rms_norm::RmsNormError),

    /// Wrapped [`DeferredError`] from deferred expert dispatch.
    #[error("deferred experts: {0}")]
    Deferred(#[from] DeferredError),
}

/// Alias of [`LayerForwardError`] used by the linear-attention entry point.
pub type LinearAttnForwardError = LayerForwardError;
/// Metal buffers used by one decoder-layer forward pass, plus the per-layer
/// recurrent and KV-mirror state that persists across calls.
///
/// Every buffer holds `f32` values in `StorageModeShared` memory, as
/// allocated by [`LayerForwardBuffers::new`]; sizes below are element counts.
pub struct LayerForwardBuffers {
    /// Layer input hidden state (`hidden_dim`). Also an operand of the
    /// post-attention residual add and the combine destination when the
    /// expert combine is chained into the next layer.
    pub input: Buffer,
    /// RMS-normed activation fed to the projections (`hidden_dim`).
    pub normed: Buffer,
    /// `hidden_dim` floats; not referenced by the forward code in this file.
    pub residual: Buffer,
    /// Result of the `residual_add` of `output` and `input` (`hidden_dim`).
    pub h_mid: Buffer,
    /// o_proj output (`hidden_dim`).
    pub output: Buffer,
    /// Matvec output slots:
    /// `[0]` qkv projection (`linear_conv_dim`),
    /// `[1]` z projection (`linear_total_value`),
    /// `[2]` beta and `[3]` alpha projections (`linear_num_v_heads` each),
    /// `[4]` router logits (`num_experts`),
    /// `[5]` shared-expert gate score (1),
    /// `[6]` linear-attention o_proj input
    /// (`max(linear_total_value, num_attn_heads * head_dim)`).
    pub batch_out: [Buffer; 7],
    /// Conv1d state, one per linear-attention layer
    /// (`(CONV_KERNEL_SIZE - 1) * linear_conv_dim`).
    pub conv_state: Vec<Buffer>,
    /// Delta-net state, one per linear-attention layer
    /// (`linear_num_v_heads * LINEAR_VALUE_DIM * LINEAR_KEY_DIM`).
    pub delta_state: Vec<Buffer>,
    /// Conv1d step output (`linear_conv_dim`).
    pub conv_output: Buffer,
    /// Per-v-head decay written by `compute_decay_beta` (`linear_num_v_heads`).
    pub delta_g_decay: Buffer,
    /// Per-v-head beta written by `compute_decay_beta` (`linear_num_v_heads`).
    pub delta_beta: Buffer,
    /// Delta-net step output (`linear_total_value`).
    pub delta_output: Buffer,
    /// Single-float scratch for the RMSNorm sum-of-squares reduction.
    pub sum_sq: Buffer,
    /// Shared-expert gate_proj output (`shared_intermediate`).
    pub shared_gate_out: Buffer,
    /// Shared-expert up_proj output (`shared_intermediate`).
    pub shared_up_out: Buffer,
    /// Shared-expert SwiGLU activation (`shared_intermediate`).
    pub shared_act: Buffer,
    /// Shared-expert down_proj output (`hidden_dim`).
    pub shared_out: Buffer,
    /// `2 * num_attn_heads * head_dim` floats; not referenced by the forward
    /// code in this file (presumably the full-attention q+gate projection).
    pub q_proj_out: Buffer,
    /// `num_kv_heads * head_dim` floats; not referenced by the forward code
    /// in this file (presumably the K projection).
    pub k_out: Buffer,
    /// `num_kv_heads * head_dim` floats; not referenced by the forward code
    /// in this file (presumably the V projection).
    pub v_out: Buffer,
    /// GPU key mirror, one per full-attention layer
    /// (`GPU_KV_SEQ * num_kv_heads * head_dim`); read by the scores kernel.
    pub gpu_kv_k: Vec<Buffer>,
    /// GPU value mirror, one per full-attention layer
    /// (`GPU_KV_SEQ * num_kv_heads * head_dim`); read by the values kernel.
    pub gpu_kv_v: Vec<Buffer>,
    /// Query input to the GPU attention kernels (`num_attn_heads * head_dim`).
    pub gpu_attn_q: Buffer,
    /// Attention scores, softmaxed in place (`num_attn_heads * GPU_KV_SEQ`).
    pub gpu_attn_scores: Buffer,
    /// Attention output; the o_proj input for full-attention layers
    /// (`num_attn_heads * head_dim`).
    pub gpu_attn_out: Buffer,
    /// Gate operand for the sigmoid-gate kernel (`num_attn_heads * head_dim`).
    pub gpu_attn_gate: Buffer,
}
/// Alias of [`LayerForwardBuffers`] used by the linear-attention path.
pub type LinearAttnBuffers = LayerForwardBuffers;
impl LayerForwardBuffers {
pub fn new(device: &Device) -> Self {
let v = VARIANT;
let f32_buf = |n: usize| {
let b = device.new_buffer(
(n * std::mem::size_of::<f32>()) as NSUInteger,
MTLResourceOptions::StorageModeShared,
);
unsafe {
std::ptr::write_bytes(
b.contents() as *mut u8,
0,
n * std::mem::size_of::<f32>(),
);
}
b
};
let q_dim_full = v.num_attn_heads * v.head_dim;
let q_proj_dim_full = q_dim_full * 2;
let kv_dim_full = v.num_kv_heads * v.head_dim;
let oproj_in_max = v.linear_total_value().max(q_dim_full);
let batch_sizes = [
v.linear_conv_dim(),
v.linear_total_value(),
v.linear_num_v_heads,
v.linear_num_v_heads,
v.num_experts,
1,
oproj_in_max,
];
let batch_out: [Buffer; 7] =
std::array::from_fn(|i| f32_buf(batch_sizes[i]));
let num_linear = v.num_layers - num_full_attn_layers(&v);
let conv_state = (0..num_linear)
.map(|_| {
f32_buf((Variant::CONV_KERNEL_SIZE - 1) * v.linear_conv_dim())
})
.collect();
let delta_state = (0..num_linear)
.map(|_| {
f32_buf(
v.linear_num_v_heads
* Variant::LINEAR_VALUE_DIM
* Variant::LINEAR_KEY_DIM,
)
})
.collect();
let num_full_attn = num_full_attn_layers(&v);
let gpu_kv_floats =
super::variants::GPU_KV_SEQ * kv_dim_full;
let gpu_kv_k =
(0..num_full_attn).map(|_| f32_buf(gpu_kv_floats)).collect();
let gpu_kv_v =
(0..num_full_attn).map(|_| f32_buf(gpu_kv_floats)).collect();
Self {
input: f32_buf(v.hidden_dim),
normed: f32_buf(v.hidden_dim),
residual: f32_buf(v.hidden_dim),
h_mid: f32_buf(v.hidden_dim),
output: f32_buf(v.hidden_dim),
batch_out,
conv_state,
delta_state,
conv_output: f32_buf(v.linear_conv_dim()),
delta_g_decay: f32_buf(v.linear_num_v_heads),
delta_beta: f32_buf(v.linear_num_v_heads),
delta_output: f32_buf(v.linear_total_value()),
sum_sq: f32_buf(1),
shared_gate_out: f32_buf(v.shared_intermediate),
shared_up_out: f32_buf(v.shared_intermediate),
shared_act: f32_buf(v.shared_intermediate),
shared_out: f32_buf(v.hidden_dim),
q_proj_out: f32_buf(q_proj_dim_full),
k_out: f32_buf(kv_dim_full),
v_out: f32_buf(kv_dim_full),
gpu_kv_k,
gpu_kv_v,
gpu_attn_q: f32_buf(q_dim_full),
gpu_attn_scores: f32_buf(
v.num_attn_heads * super::variants::GPU_KV_SEQ,
),
gpu_attn_out: f32_buf(q_dim_full),
gpu_attn_gate: f32_buf(q_dim_full),
}
}
pub fn reset_recurrence(&mut self) {
for b in &self.conv_state {
zero_f32_buffer(b);
}
for b in &self.delta_state {
zero_f32_buffer(b);
}
}
pub fn reset_gpu_attn_kv_mirrors(&mut self) {
for b in &self.gpu_kv_k {
zero_f32_buffer(b);
}
for b in &self.gpu_kv_v {
zero_f32_buffer(b);
}
}
}
/// Overwrites every byte of `b` (its full `length()`) with zero.
fn zero_f32_buffer(b: &Buffer) {
    let byte_len = b.length() as usize;
    let dst = b.contents().cast::<u8>();
    // SAFETY: assumes `b` is CPU-visible (all buffers in this module are
    // allocated with `StorageModeShared`) and `byte_len` is exactly its size.
    unsafe { dst.write_bytes(0, byte_len) };
}
/// Maps a global layer index to its position among linear-attention layers,
/// or `None` when `layer_idx` is a full-attention layer.
///
/// The arithmetic subtracts the full-attention layers that precede
/// `layer_idx`. It assumes those sit where `(layer_idx + 1)` is a multiple
/// of `full_attn_interval` — TODO confirm against `Variant::layer_kind`.
pub fn linear_layer_idx_for(layer_idx: usize) -> Option<usize> {
    use super::variants::LayerKind;
    let is_linear = VARIANT.layer_kind(layer_idx) != LayerKind::FullAttn;
    is_linear.then(|| {
        let full_attn_before = (layer_idx + 1) / VARIANT.full_attn_interval;
        layer_idx - full_attn_before
    })
}
/// Maps a global layer index to its position among full-attention layers,
/// or `None` when `layer_idx` is a linear-attention layer.
///
/// Like [`linear_layer_idx_for`], this assumes full-attention layers sit
/// where `(layer_idx + 1)` is a multiple of `full_attn_interval` — TODO
/// confirm.
pub fn full_attn_layer_idx_for(layer_idx: usize) -> Option<usize> {
    use super::variants::LayerKind;
    if VARIANT.layer_kind(layer_idx) != LayerKind::FullAttn {
        return None;
    }
    let ordinal = (layer_idx + 1) / VARIANT.full_attn_interval;
    Some(ordinal - 1)
}
/// Number of full-attention layers in `v`: `num_layers / full_attn_interval`
/// (integer division).
pub(super) fn num_full_attn_layers(v: &Variant) -> usize {
    let interval = v.full_attn_interval;
    v.num_layers / interval
}
/// Quantization bit width recorded for tensor `name` in `wf`.
///
/// Returns 4 when the tensor has no entry, and never returns less than 4:
/// any smaller recorded width is raised to 4.
pub(super) fn bits_of(wf: &WeightFile, name: &str) -> u32 {
    let recorded = wf.tensor_info(name).map_or(4, |info| info.bits as u32);
    recorded.max(4)
}
/// Output-projection (o_proj) weight descriptor handed to
/// [`post_attention_tail`]: offsets into the mapped weight buffer plus the
/// quantization width and input dimension.
///
/// The output dimension is always `hidden_dim`.
#[derive(Debug, Clone, Copy)]
pub(super) struct OProj {
    /// Offset of the quantized weight tensor within the weight buffer.
    pub w_off: u64,
    /// Offset of the `_s` companion tensor (scales, by naming convention).
    pub s_off: u64,
    /// Offset of the `_b` companion tensor (biases, by naming convention).
    pub b_off: u64,
    /// Quantization bit width, as returned by `bits_of`.
    pub bits: u32,
    /// Length of the projection's input vector.
    pub in_dim: u32,
}
/// Arguments for the GPU full-attention kernels that [`post_attention_tail`]
/// encodes ahead of o_proj.
#[derive(Debug, Clone, Copy)]
pub(super) struct GpuAttnEncodeArgs {
    /// Full-attention layer ordinal; selects `gpu_kv_k[fa_idx]` and
    /// `gpu_kv_v[fa_idx]`.
    pub fa_idx: usize,
    /// Number of KV positions the kernels attend over (presumably the
    /// current cache length, including this token — confirm with caller).
    pub kv_len: u32,
}
/// Runs one linear-attention decoder layer, then the shared post-attention
/// tail.
///
/// The mixer is encoded into a single command buffer, in this order:
/// 1. input RMSNorm (skipped when `prev_layer_chained`);
/// 2. the qkv/z/beta/alpha projections;
/// 3. the conv1d step;
/// 4. q/k RMSNorm;
/// 5. decay/beta computation;
/// 6. the delta-net step;
/// 7. gated RMSNorm.
///
/// The function waits for that buffer to complete. It then calls
/// [`post_attention_tail`] with this layer's out_proj as o_proj, which uses
/// `buffers.batch_out[6]` as the o_proj input.
///
/// `prev_layer_chained` means `buffers.normed` already holds this layer's
/// normed input. Presumably the previous layer's chained expert combine
/// wrote it (see `chain_next_norm_off` in `post_attention_tail`).
/// `_layer_state` is unused.
///
/// # Errors
/// Returns [`LayerForwardError::MissingTensor`] when `layer_idx` is not a
/// linear-attention layer. Also propagates pipeline-lookup errors and every
/// error from [`post_attention_tail`].
#[allow(clippy::too_many_arguments)]
pub fn linear_attn_layer_forward(
    metal: &mut MetalBackend,
    wf: &WeightFile,
    wf_buf: &MtlWeightBuf,
    layer_cache: &LayerWeightCache,
    buffers: &mut LayerForwardBuffers,
    moe: &mut MoeBuffers,
    deferred: &mut super::DeferredRing,
    layer_idx: usize,
    k_active: usize,
    expert_files: &ExpertFiles,
    pool: &rayon::ThreadPool,
    prefetch: &mut super::PrefetchState,
    prefetch_set: usize,
    _layer_state: &mut LinearAttnState,
    gpu_combine: bool,
    prev_layer_chained: bool,
    chain_next_norm_off: Option<u64>,
) -> Result<(), LayerForwardError> {
    let v = VARIANT;
    // Index into the per-linear-layer state vectors (conv_state / delta_state).
    let linear_layer_idx = linear_layer_idx_for(layer_idx).ok_or(
        LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "linear_layer_idx (called on full-attn layer)",
        },
    )?;
    // Per-tensor quantization widths (bits_of: missing => 4, floor of 4).
    let qkv_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.linear_attn.in_proj_qkv.weight"),
    );
    let z_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.linear_attn.in_proj_z.weight"),
    );
    let alpha_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.linear_attn.in_proj_a.weight"),
    );
    let beta_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.linear_attn.in_proj_b.weight"),
    );
    let o_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.linear_attn.out_proj.weight"),
    );
    // Cached weight offsets for this layer's linear-attention block.
    let attn = layer_cache.attn.linear().ok_or(
        LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "linear_attn weights (called on full-attn layer)",
        },
    )?;
    let qkv_w = attn.qkv_w;
    let qkv_s = attn.qkv_s;
    let qkv_b = attn.qkv_b;
    let z_w = attn.z_w;
    let z_s = attn.z_s;
    let z_b = attn.z_b;
    let beta_w = attn.beta_w;
    let beta_s = attn.beta_s;
    let beta_b = attn.beta_b;
    let alpha_w = attn.alpha_w;
    let alpha_s = attn.alpha_s;
    let alpha_b = attn.alpha_b;
    let conv1d_w = attn.conv1d_w;
    let a_log = attn.a_log;
    let dt_bias = attn.dt_bias;
    let gnorm_w = attn.gated_norm_w;
    let o_w = attn.o_proj_w;
    let o_s = attn.o_proj_s;
    let o_b = attn.o_proj_b;
    let lp = LinearAttnPipelines::fetch(metal)?;
    let mv = MatvecPipelines::fetch(metal)?;
    let rms_pipes = RmsNormBf16Pipelines::fetch(metal)?;
    {
        // One command buffer for the whole linear-attention mixer.
        let cmdbuf = metal.queue().new_command_buffer();
        // Input RMSNorm input -> normed, unless the previous layer already
        // produced `normed` for us.
        if !prev_layer_chained {
            encode_rms_norm_bf16_into(
                cmdbuf,
                &rms_pipes,
                &buffers.input,
                wf_buf.buffer(),
                layer_cache.input_layernorm_w,
                &buffers.sum_sq,
                &buffers.normed,
                v.hidden_dim as u32,
                RMS_NORM_EPS,
            );
        }
        // Four input projections from `normed`:
        // qkv -> batch_out[0], z -> batch_out[1],
        // beta -> batch_out[2], alpha -> batch_out[3].
        let specs = [
            MatvecSpec {
                w_off: qkv_w,
                s_off: qkv_s,
                b_off: qkv_b,
                input: &buffers.normed,
                output: &buffers.batch_out[0],
                out_dim: v.linear_conv_dim() as u32,
                in_dim: v.hidden_dim as u32,
                bits: qkv_bits,
            },
            MatvecSpec {
                w_off: z_w,
                s_off: z_s,
                b_off: z_b,
                input: &buffers.normed,
                output: &buffers.batch_out[1],
                out_dim: v.linear_total_value() as u32,
                in_dim: v.hidden_dim as u32,
                bits: z_bits,
            },
            MatvecSpec {
                w_off: beta_w,
                s_off: beta_s,
                b_off: beta_b,
                input: &buffers.normed,
                output: &buffers.batch_out[2],
                out_dim: v.linear_num_v_heads as u32,
                in_dim: v.hidden_dim as u32,
                bits: beta_bits,
            },
            MatvecSpec {
                w_off: alpha_w,
                s_off: alpha_s,
                b_off: alpha_b,
                input: &buffers.normed,
                output: &buffers.batch_out[3],
                out_dim: v.linear_num_v_heads as u32,
                in_dim: v.hidden_dim as u32,
                bits: alpha_bits,
            },
        ];
        for s in &specs {
            encode_matvec(cmdbuf, &mv, wf_buf, s);
        }
        // Conv1d step over the qkv projection using this layer's conv_state;
        // result in conv_output.
        encode_conv1d_step(
            cmdbuf,
            &lp.conv1d_step,
            &buffers.conv_state[linear_layer_idx],
            &buffers.batch_out[0],
            wf_buf.buffer(),
            conv1d_w,
            &buffers.conv_output,
            v.linear_conv_dim() as u32,
        );
        // Per-head RMSNorm applied in place on conv_output (presumably its
        // q/k portion, given the k-head count and key dim passed).
        encode_rms_norm_qk(
            cmdbuf,
            &lp.rms_norm_qk,
            &buffers.conv_output,
            v.linear_num_k_heads as u32,
            Variant::LINEAR_KEY_DIM as u32,
        );
        // Per-v-head decay and beta from the alpha (batch_out[3]) and beta
        // (batch_out[2]) projections plus the A_log / dt_bias weights.
        encode_compute_decay_beta(
            cmdbuf,
            &lp.compute_decay_beta,
            &buffers.batch_out[3],
            &buffers.batch_out[2],
            wf_buf.buffer(),
            a_log,
            dt_bias,
            &buffers.delta_g_decay,
            &buffers.delta_beta,
            v.linear_num_v_heads as u32,
        );
        // NOTE(review): despite its name this is v-heads per k-head
        // (num_v_heads / num_k_heads).
        let k_heads_per_v =
            (v.linear_num_v_heads / v.linear_num_k_heads) as u32;
        // Delta-net recurrence step on this layer's delta_state;
        // result in delta_output.
        encode_delta_net_step(
            cmdbuf,
            &lp.delta_net_step,
            &buffers.delta_state[linear_layer_idx],
            &buffers.conv_output,
            &buffers.delta_g_decay,
            &buffers.delta_beta,
            &buffers.delta_output,
            v.linear_num_v_heads as u32,
            Variant::LINEAR_VALUE_DIM as u32,
            k_heads_per_v,
        );
        // Gated RMSNorm of delta_output with z (batch_out[1]); writes
        // batch_out[6], which post_attention_tail uses as the o_proj input.
        encode_gated_rms_norm(
            cmdbuf,
            &lp.gated_rms_norm,
            &buffers.delta_output,
            &buffers.batch_out[1],
            wf_buf.buffer(),
            gnorm_w,
            &buffers.batch_out[6],
            v.linear_num_v_heads as u32,
            Variant::LINEAR_VALUE_DIM as u32,
        );
        // Block until the mixer finishes before encoding the tail.
        cmdbuf.commit();
        cmdbuf.wait_until_completed();
    }
    post_attention_tail(
        metal,
        wf,
        wf_buf,
        layer_cache,
        buffers,
        moe,
        deferred,
        layer_idx,
        k_active,
        expert_files,
        pool,
        prefetch,
        prefetch_set,
        OProj {
            w_off: o_w,
            s_off: o_s,
            b_off: o_b,
            bits: o_bits,
            in_dim: v.linear_total_value() as u32,
        },
        gpu_combine,
        // No GPU full-attention work on a linear-attention layer.
        None,
        chain_next_norm_off,
    )
}
/// Shared tail of every decoder layer: output projection, residual add,
/// post-attention RMSNorm, router and shared-expert work, then routed-expert
/// dispatch.
///
/// When `gpu_attn_args` is `Some`, the GPU full-attention kernels run first
/// (scores, softmax, values, sigmoid gate), and `buffers.gpu_attn_out` feeds
/// o_proj. Otherwise `buffers.batch_out[6]` (the linear-attention output)
/// feeds o_proj. All GPU work goes into one command buffer that is waited on.
///
/// Routing then runs on the CPU from the router logits in `batch_out[4]`.
/// The selected experts are dispatched in one of two ways:
/// - `gpu_combine`: experts are loaded into `moe`'s staging slots, reusing
///   slots whose prefetched expert matches the routed one, and dispatched via
///   `gpu_batched_experts_begin_pre_staged`. When `chain_next_norm_off` is
///   set, that call is also asked to write the next layer's input and normed
///   buffers.
/// - otherwise: experts are read into a host buffer and dispatched via
///   `gpu_batched_experts_begin`.
///
/// # Errors
/// Propagates pipeline-lookup, router, expert I/O and deferred-dispatch
/// errors.
#[allow(clippy::too_many_arguments)]
pub(super) fn post_attention_tail(
    metal: &mut MetalBackend,
    wf: &WeightFile,
    wf_buf: &MtlWeightBuf,
    layer_cache: &LayerWeightCache,
    buffers: &mut LayerForwardBuffers,
    moe: &mut MoeBuffers,
    deferred: &mut super::DeferredRing,
    layer_idx: usize,
    k_active: usize,
    expert_files: &ExpertFiles,
    pool: &rayon::ThreadPool,
    prefetch: &mut super::PrefetchState,
    prefetch_set: usize,
    o_proj: OProj,
    gpu_combine: bool,
    gpu_attn_args: Option<GpuAttnEncodeArgs>,
    chain_next_norm_off: Option<u64>,
) -> Result<(), LayerForwardError> {
    let v = VARIANT;
    // Per-tensor quantization widths for the MLP projections.
    let gate_bits =
        bits_of(wf, &format!("model.layers.{layer_idx}.mlp.gate.weight"));
    let seg_bits = bits_of(
        wf,
        &format!(
            "model.layers.{layer_idx}.mlp.shared_expert_gate.weight"
        ),
    );
    let s_gate_bits = bits_of(
        wf,
        &format!(
            "model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight"
        ),
    );
    let s_up_bits = bits_of(
        wf,
        &format!(
            "model.layers.{layer_idx}.mlp.shared_expert.up_proj.weight"
        ),
    );
    let s_down_bits = bits_of(
        wf,
        &format!(
            "model.layers.{layer_idx}.mlp.shared_expert.down_proj.weight"
        ),
    );
    // Cached weight offsets for the norm, router gate and shared expert.
    let post_attn_norm_w = layer_cache.post_attention_layernorm_w;
    let gate_w = layer_cache.gate.w;
    let gate_s = layer_cache.gate.s;
    let gate_b = layer_cache.gate.b;
    let shared_up_w = layer_cache.shared.up_w;
    let shared_up_s = layer_cache.shared.up_s;
    let shared_up_b = layer_cache.shared.up_b;
    let shared_gate_w = layer_cache.shared.gate_w;
    let shared_gate_s = layer_cache.shared.gate_s;
    let shared_gate_b = layer_cache.shared.gate_b;
    let shared_down_w = layer_cache.shared.down_w;
    let shared_down_s = layer_cache.shared.down_s;
    let shared_down_b = layer_cache.shared.down_b;
    let seg_w = layer_cache.shared.seg_w;
    let seg_s = layer_cache.shared.seg_s;
    let seg_b = layer_cache.shared.seg_b;
    let mv = MatvecPipelines::fetch(metal)?;
    // Pipelines are cloned out of `metal` (presumably so they outlive the
    // borrow taken by `metal.pipeline`).
    let sum_sq = metal.pipeline("rms_norm_sum_sq")?.clone();
    let apply = metal.pipeline("rms_norm_apply_bf16")?.clone();
    let resid_add = metal.pipeline("residual_add")?.clone();
    let swiglu = metal.pipeline("swiglu_fused")?.clone();
    let attn_pipes = if gpu_attn_args.is_some() {
        Some(GpuAttnPipelines::fetch(metal)?)
    } else {
        None
    };
    {
        let cmdbuf = metal.queue().new_command_buffer();
        // Full-attention layers: scores -> softmax -> values -> sigmoid gate
        // over this layer's GPU KV mirrors.
        if let (Some(args), Some(attn_pipes)) =
            (gpu_attn_args.as_ref(), attn_pipes.as_ref())
        {
            let head_dim = v.head_dim as u32;
            let kv_dim = (v.num_kv_heads * v.head_dim) as u32;
            let num_heads = v.num_attn_heads as u32;
            // Query heads sharing each KV head.
            let heads_per_kv = (v.num_attn_heads / v.num_kv_heads) as u32;
            // 1/sqrt(head_dim) attention scaling.
            let scale = 1.0f32 / (head_dim as f32).sqrt();
            // The KV mirrors and score rows are allocated with GPU_KV_SEQ
            // positions, so that is the row stride.
            let seq_stride = super::variants::GPU_KV_SEQ as u32;
            encode_attn_scores_batched_into(
                cmdbuf,
                &attn_pipes.scores,
                &buffers.gpu_attn_q,
                &buffers.gpu_kv_k[args.fa_idx],
                &buffers.gpu_attn_scores,
                num_heads,
                head_dim,
                kv_dim,
                args.kv_len,
                seq_stride,
                heads_per_kv,
                scale,
            );
            encode_attn_softmax_batched_into(
                cmdbuf,
                &attn_pipes.softmax,
                &buffers.gpu_attn_scores,
                num_heads,
                args.kv_len,
                seq_stride,
            );
            encode_attn_values_batched_into(
                cmdbuf,
                &attn_pipes.values,
                &buffers.gpu_attn_scores,
                &buffers.gpu_kv_v[args.fa_idx],
                &buffers.gpu_attn_out,
                num_heads,
                head_dim,
                kv_dim,
                args.kv_len,
                seq_stride,
                heads_per_kv,
            );
            encode_sigmoid_gate_into(
                cmdbuf,
                &attn_pipes.gate,
                &buffers.gpu_attn_out,
                &buffers.gpu_attn_gate,
                num_heads * head_dim,
            );
        }
        // o_proj input: GPU attention output for full-attention layers,
        // gated-RMSNorm output (batch_out[6]) for linear-attention layers.
        let oproj_input = if gpu_attn_args.is_some() {
            &buffers.gpu_attn_out
        } else {
            &buffers.batch_out[6]
        };
        // o_proj -> output (hidden_dim).
        encode_matvec(
            cmdbuf,
            &mv,
            wf_buf,
            &MatvecSpec {
                w_off: o_proj.w_off,
                s_off: o_proj.s_off,
                b_off: o_proj.b_off,
                input: oproj_input,
                output: &buffers.output,
                out_dim: v.hidden_dim as u32,
                in_dim: o_proj.in_dim,
                bits: o_proj.bits,
            },
        );
        // h_mid = residual_add(output, input).
        encode_residual_add(
            cmdbuf,
            &resid_add,
            &buffers.output,
            &buffers.input,
            &buffers.h_mid,
            v.hidden_dim as u32,
        );
        // post_attention_layernorm(h_mid) -> normed.
        encode_rms_norm_pair(
            cmdbuf,
            &sum_sq,
            &apply,
            &buffers.h_mid,
            wf_buf.buffer(),
            post_attn_norm_w,
            &buffers.normed,
            &buffers.sum_sq,
            v.hidden_dim as u32,
        );
        // Router logits, one per expert -> batch_out[4].
        encode_matvec(
            cmdbuf,
            &mv,
            wf_buf,
            &MatvecSpec {
                w_off: gate_w,
                s_off: gate_s,
                b_off: gate_b,
                input: &buffers.normed,
                output: &buffers.batch_out[4],
                out_dim: v.num_experts as u32,
                in_dim: v.hidden_dim as u32,
                bits: gate_bits,
            },
        );
        // Shared-expert gate scalar -> batch_out[5].
        encode_matvec(
            cmdbuf,
            &mv,
            wf_buf,
            &MatvecSpec {
                w_off: seg_w,
                s_off: seg_s,
                b_off: seg_b,
                input: &buffers.normed,
                output: &buffers.batch_out[5],
                out_dim: 1,
                in_dim: v.hidden_dim as u32,
                bits: seg_bits,
            },
        );
        // Shared expert: gate_proj and up_proj, SwiGLU, then down_proj.
        encode_matvec(
            cmdbuf,
            &mv,
            wf_buf,
            &MatvecSpec {
                w_off: shared_gate_w,
                s_off: shared_gate_s,
                b_off: shared_gate_b,
                input: &buffers.normed,
                output: &buffers.shared_gate_out,
                out_dim: v.shared_intermediate as u32,
                in_dim: v.hidden_dim as u32,
                bits: s_gate_bits,
            },
        );
        encode_matvec(
            cmdbuf,
            &mv,
            wf_buf,
            &MatvecSpec {
                w_off: shared_up_w,
                s_off: shared_up_s,
                b_off: shared_up_b,
                input: &buffers.normed,
                output: &buffers.shared_up_out,
                out_dim: v.shared_intermediate as u32,
                in_dim: v.hidden_dim as u32,
                bits: s_up_bits,
            },
        );
        encode_swiglu_buf(
            cmdbuf,
            &swiglu,
            &buffers.shared_gate_out,
            &buffers.shared_up_out,
            &buffers.shared_act,
            v.shared_intermediate as u32,
        );
        encode_matvec(
            cmdbuf,
            &mv,
            wf_buf,
            &MatvecSpec {
                w_off: shared_down_w,
                s_off: shared_down_s,
                b_off: shared_down_b,
                input: &buffers.shared_act,
                output: &buffers.shared_out,
                out_dim: v.hidden_dim as u32,
                in_dim: v.shared_intermediate as u32,
                bits: s_down_bits,
            },
        );
        // The CPU reads router/gate results below, so wait for completion.
        cmdbuf.commit();
        cmdbuf.wait_until_completed();
    }
    // CPU routing: top-`k_active` expert indices and weights from the logits.
    let mut scores =
        read_buffer_to_vec(&buffers.batch_out[4], v.num_experts);
    let mut indices = vec![0i32; k_active];
    let mut weights = vec![0f32; k_active];
    moe_router_cpu(&mut scores, k_active, &mut indices, &mut weights)?;
    let shared_gate_score = {
        let s = read_buffer_to_vec(&buffers.batch_out[5], 1);
        s[0]
    };
    let k = k_active;
    if gpu_combine {
        use rayon::prelude::*;
        use super::prefetch::SlotSource;
        // NOTE(review): k > MAX_K panics on the slicing below.
        const MAX_K: usize = super::expert_forward::MAX_K;
        let prefetch_status = prefetch.wait_for(layer_idx);
        // A slot uses prefetched data only if the prefetcher loaded exactly
        // the expert the router picked for that same slot.
        let mut data_set_per_slot: [SlotSource; MAX_K] =
            [SlotSource::Synced; MAX_K];
        if let Some(status) = prefetch_status {
            for slot in 0..k.min(status.k) {
                if status.loaded_indices[slot] == indices[slot] {
                    data_set_per_slot[slot] = SlotSource::Prefetched;
                }
            }
        }
        // Read the remaining experts into their staging slots in parallel.
        let mut dsts = moe.data_synced_slots_mut_array();
        let active = &mut dsts[..k];
        pool.install(|| -> Result<(), super::expert_io::ExpertIoError> {
            active
                .par_iter_mut()
                .enumerate()
                .try_for_each(|(slot, dst)| {
                    if data_set_per_slot[slot] == SlotSource::Synced {
                        let expert_idx = indices[slot] as usize;
                        expert_files
                            .read_expert(layer_idx, expert_idx, *dst)
                    } else {
                        Ok(())
                    }
                })
        })?;
        // Report the experts actually routed (zero-padded to MAX_K).
        let mut actuals: [i32; MAX_K] = [0; MAX_K];
        actuals[..k].copy_from_slice(&indices[..k]);
        prefetch.record_actual(layer_idx, actuals);
        // When the next layer's norm offset is known, ask the combine to
        // also write the next layer's input (`buffers.input`) and its
        // RMSNorm (`buffers.normed`).
        let chain_rms_pipes = chain_next_norm_off.map(|_| {
            super::gpu_norm::RmsNormBf16Pipelines {
                sum: sum_sq.clone(),
                apply: apply.clone(),
            }
        });
        let chain = chain_next_norm_off.and_then(|off| {
            chain_rms_pipes.as_ref().map(|pipes| ChainToNormed {
                pipes,
                wf_buf: wf_buf.buffer(),
                next_norm_off: off,
                combine_out: &buffers.input,
                chain_sum_sq: &buffers.sum_sq,
                chain_normed: &buffers.normed,
                eps: RMS_NORM_EPS,
            })
        });
        gpu_batched_experts_begin_pre_staged(
            metal,
            moe,
            deferred,
            k as i32,
            &buffers.normed,
            &buffers.h_mid,
            &buffers.shared_out,
            &weights,
            shared_gate_score,
            layer_idx as i32,
            &data_set_per_slot,
            prefetch_set,
            chain,
        )?;
    } else {
        // Host path: read experts into one contiguous host buffer (sized by
        // expert_size_4bit) and copy activations back from the GPU buffers.
        // `chain_next_norm_off` is not used on this path.
        let expert_size = v.expert_size_4bit();
        let mut expert_data = vec![0u8; k * expert_size];
        for slot in 0..k {
            let expert_idx = indices[slot] as usize;
            let dst = &mut expert_data
                [slot * expert_size..(slot + 1) * expert_size];
            expert_files.read_expert(layer_idx, expert_idx, dst)?;
        }
        let h_mid_host = read_buffer_to_vec(&buffers.h_mid, v.hidden_dim);
        let shared_out_host =
            read_buffer_to_vec(&buffers.shared_out, v.hidden_dim);
        let normed_host = read_buffer_to_vec(&buffers.normed, v.hidden_dim);
        gpu_batched_experts_begin(
            metal,
            moe,
            deferred,
            k as i32,
            &expert_data,
            &normed_host,
            &h_mid_host,
            &shared_out_host,
            &weights,
            shared_gate_score,
            layer_idx as i32,
            false,
        )?;
    }
    Ok(())
}
/// Copies the first `len` `f32` values out of a CPU-visible Metal buffer.
///
/// Callers must ensure any GPU work writing `b` has completed. The forward
/// passes in this file call `wait_until_completed` before reading.
///
/// # Panics
///
/// Panics if `b` holds fewer than `len` `f32`s. This is checked up front
/// rather than risking an out-of-bounds read.
pub(super) fn read_buffer_to_vec(b: &Buffer, len: usize) -> Vec<f32> {
    if len == 0 {
        // `slice::from_raw_parts` requires a non-null, aligned pointer even
        // for empty slices; don't touch `contents()` at all.
        return Vec::new();
    }
    let available = b.length() as usize;
    let needed = len.checked_mul(std::mem::size_of::<f32>());
    assert!(
        matches!(needed, Some(bytes) if bytes <= available),
        "read_buffer_to_vec: requested {len} f32s but buffer holds only {available} bytes",
    );
    let ptr = b.contents() as *const f32;
    // SAFETY: the assertion above guarantees `len` f32s lie inside the
    // allocation. Assumes `b` is CPU-visible (every buffer in this module is
    // allocated with `StorageModeShared`) and not concurrently written by
    // the GPU.
    unsafe { std::slice::from_raw_parts(ptr, len).to_vec() }
}
/// Encodes a two-pass RMSNorm of `input` into `output` on `cmdbuf`.
///
/// Pass 1 runs `sum_pipe` as a single 256-thread threadgroup over `input`,
/// writing into `sum_sq`. Pass 2 runs `apply_pipe` over `ceil(dim / 256)`
/// threadgroups of 256 threads. It reads `input`, the norm weight at
/// `weight_off` inside `weight_buf`, and `sum_sq`, with `dim` and
/// `RMS_NORM_EPS` as constants. Each pass uses its own compute encoder.
#[allow(clippy::too_many_arguments)]
fn encode_rms_norm_pair(
    cmdbuf: &CommandBufferRef,
    sum_pipe: &ComputePipelineState,
    apply_pipe: &ComputePipelineState,
    input: &Buffer,
    weight_buf: &Buffer,
    weight_off: u64,
    output: &Buffer,
    sum_sq: &Buffer,
    dim: u32,
) {
    // Pass 1: reduction into `sum_sq` with one threadgroup.
    let reduce = cmdbuf.new_compute_command_encoder();
    reduce.set_compute_pipeline_state(sum_pipe);
    reduce.set_buffer(0, Some(input), 0);
    reduce.set_buffer(1, Some(sum_sq), 0);
    reduce.set_bytes(2, 4, (&dim as *const u32).cast());
    reduce.dispatch_thread_groups(MTLSize::new(1, 1, 1), MTLSize::new(256, 1, 1));
    reduce.end_encoding();

    // Pass 2: normalize `input` into `output` using the weight and `sum_sq`.
    let normalize = cmdbuf.new_compute_command_encoder();
    normalize.set_compute_pipeline_state(apply_pipe);
    normalize.set_buffer(0, Some(input), 0);
    normalize.set_buffer(1, Some(weight_buf), weight_off as NSUInteger);
    normalize.set_buffer(2, Some(sum_sq), 0);
    normalize.set_buffer(3, Some(output), 0);
    let epsilon = RMS_NORM_EPS;
    normalize.set_bytes(4, 4, (&dim as *const u32).cast());
    normalize.set_bytes(5, 4, (&epsilon as *const f32).cast());
    let group_count = (dim + 255) / 256;
    normalize.dispatch_thread_groups(
        MTLSize::new(group_count as NSUInteger, 1, 1),
        MTLSize::new(256, 1, 1),
    );
    normalize.end_encoding();
}
/// Encodes the fused SwiGLU kernel over `dim` elements: reads `gate` and
/// `up`, writes `act`.
///
/// Buffers bind at indices 0, 1 and 2, and `dim` is a 4-byte constant at
/// index 3. Dispatches `ceil(dim / 256)` threadgroups of 256 threads.
fn encode_swiglu_buf(
    cmdbuf: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    gate: &Buffer,
    up: &Buffer,
    act: &Buffer,
    dim: u32,
) {
    let encoder = cmdbuf.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);
    let bindings: [(NSUInteger, &Buffer); 3] = [(0, gate), (1, up), (2, act)];
    for (index, buffer) in bindings {
        encoder.set_buffer(index, Some(buffer), 0);
    }
    encoder.set_bytes(3, 4, (&dim as *const u32).cast());
    let groups = (dim + 255) / 256;
    encoder.dispatch_thread_groups(
        MTLSize::new(groups as NSUInteger, 1, 1),
        MTLSize::new(256, 1, 1),
    );
    encoder.end_encoding();
}
/// Encodes the `residual_add` kernel over `dim` elements, with operands
/// `a` and `b` and destination `out`.
///
/// `a`, `b` and `out` bind at indices 0, 1 and 2, and `dim` is a 4-byte
/// constant at index 3. Dispatches `ceil(dim / 256)` threadgroups of
/// 256 threads.
fn encode_residual_add(
    cmdbuf: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    a: &Buffer,
    b: &Buffer,
    out: &Buffer,
    dim: u32,
) {
    let pass = cmdbuf.new_compute_command_encoder();
    pass.set_compute_pipeline_state(pipeline);
    for (index, operand) in [a, b, out].iter().enumerate() {
        pass.set_buffer(index as NSUInteger, Some(*operand), 0);
    }
    pass.set_bytes(3, 4, (&dim as *const u32).cast());
    let threadgroups = (dim + 255) / 256;
    pass.dispatch_thread_groups(
        MTLSize::new(threadgroups as NSUInteger, 1, 1),
        MTLSize::new(256, 1, 1),
    );
    pass.end_encoding();
}