use metal::{
Buffer, CommandBufferRef, ComputePipelineState, MTLSize, NSUInteger,
};
/// Reports whether the `MOEFLUX_MOE_GATHER_ID` switch is enabled.
///
/// The switch is on by default: only the exact values `"0"`, `"false"` and
/// `"off"` turn it off. An unset variable, a non-Unicode value, or any other
/// string leaves it on.
///
/// The environment is consulted once per process; the answer is cached, so
/// later changes to the variable are not observed.
fn moe_gather_id_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        let raw = std::env::var("MOEFLUX_MOE_GATHER_ID");
        !matches!(raw.as_deref(), Ok("0" | "false" | "off"))
    })
}
/// Reports whether the `MOEFLUX_SDPA_VB` switch is enabled.
///
/// The switch is opt-in: it is on only when the variable is exactly `"1"`,
/// `"true"` or `"on"`. An unset variable, a non-Unicode value, or any other
/// string leaves it off.
///
/// The environment is read once per process and the result is cached.
pub(in crate::riir) fn sdpa_vb_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| match std::env::var("MOEFLUX_SDPA_VB").as_deref() {
        Ok("1") | Ok("true") | Ok("on") => true,
        _ => false,
    })
}
/// Reports whether the `MOEFLUX_SDPA_GQA` switch is enabled.
///
/// The switch is opt-in: it is on only when the variable is exactly `"1"`,
/// `"true"` or `"on"`. An unset variable, a non-Unicode value, or any other
/// string leaves it off.
///
/// The environment is read once per process and the result is cached.
pub(in crate::riir) fn sdpa_gqa_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| {
        std::env::var("MOEFLUX_SDPA_GQA")
            .map(|value| ["1", "true", "on"].contains(&value.as_str()))
            .unwrap_or(false)
    })
}
/// Reports whether the `MOEFLUX_DELTA_NET_VB` switch is enabled.
///
/// The switch is on by default: only the exact values `"0"`, `"false"` and
/// `"off"` turn it off. An unset variable, a non-Unicode value, or any other
/// string leaves it on.
///
/// The environment is read once per process and the result is cached.
pub(in crate::riir) fn delta_net_vb_enabled() -> bool {
    use std::sync::OnceLock;
    static FLAG: OnceLock<bool> = OnceLock::new();
    *FLAG.get_or_init(|| match std::env::var("MOEFLUX_DELTA_NET_VB") {
        Ok(value) => !["0", "false", "off"].contains(&value.as_str()),
        Err(_) => true,
    })
}
use crate::riir::backend::buftype::{
AlphaStackBuf, AttnInputBuf, AttnOutBuf, BetaGateBuf, BetaStackBuf,
BucketActBuf, BucketGateBuf, BucketInputBuf, BucketOutBuf,
BucketTokenIdxBuf, BucketUpBuf, BucketWeightsBuf, ConvOutBuf,
ConvStateBuf, DeltaOutBuf, DeltaStateBuf, ExpertBaseBuf,
ExpertIndicesBuf, GDecayBuf, HiddenBuf, HidsBuf, HtpeBuf, KProjOutBuf,
KvCacheKBuf, KvCacheVBuf, LogitsBuf, MoeInputBuf, MoeOutSumBuf,
OProjOutBuf, QBuf, QGateBuf, QProjOutBuf, QkvStackBuf, ResidualBuf,
RouterIdxBuf, RouterLogitsBuf, RouterWeightsBuf, SharedFfnActBuf,
SharedFfnDownBuf, SharedFfnGateBuf, SharedFfnUpBuf, SharedGateBuf,
TokenIdsBuf, ValueOutBuf, VProjOutBuf, ZStackBuf,
};
use crate::riir::backend::{Backend, BufId, BufferPool, MetalBufferPool};
use crate::riir::moe::deferred::{
gpu_batched_experts_begin, gpu_batched_experts_begin_mmap,
DeferredError,
};
use crate::riir::moe::expert_forward::{ChainToNormed, ExpertPayload, MoeBuffers};
use crate::riir::io::expert_io::ExpertFiles;
use crate::riir::attn::gpu_attn::{
encode_attn_scores_batched_into, encode_attn_softmax_batched_into,
encode_attn_values_batched_into, encode_sigmoid_gate_into,
GpuAttnPipelines,
};
use crate::riir::attn::gpu_linear_attn::{
encode_compute_decay_beta, encode_conv1d_step, encode_delta_net_step,
encode_gated_rms_norm, encode_rms_norm_qk, LinearAttnPipelines,
};
use crate::riir::backend::gpu::gpu_matvec::{encode_matvec, MatvecPipelines, MatvecSpec};
use crate::riir::backend::gpu::gpu_norm::{encode_rms_norm_bf16_into, RmsNormBf16Pipelines};
use crate::riir::backend::gpu::gpu_ctx::GpuLayerCtx;
use crate::riir::io::layer_weight_cache::LayerWeightCache;
use crate::riir::backend::gpu::metal::{MetalContext, MetalError};
use crate::riir::moe::moe_router::moe_router_cpu;
use crate::riir::io::mtl_weight_buf::MtlWeightBuf;
use crate::riir::snapshot::state::LinearAttnState;
use crate::riir::variants::{Variant, RMS_NORM_EPS, VARIANT};
use crate::riir::io::weight_file::WeightFile;
/// Error type returned by the layer-forward entry points in this file.
///
/// Apart from the first two variants, every variant wraps a lower-level
/// error through `#[from]`, so `?` converts those errors automatically.
#[derive(Debug, thiserror::Error)]
pub enum LayerForwardError {
/// A per-layer lookup failed. `linear_attn_layer_forward` returns this when
/// it is invoked for a layer that has no linear-attention index or weights.
#[error("missing tensor for layer {layer}: {tensor}")]
MissingTensor {
/// Index of the layer the lookup was made for.
layer: usize,
/// Static description of what was missing.
tensor: &'static str,
},
/// A hidden-state input did not have the expected number of floats.
#[error("hidden_in must be HIDDEN_DIM={expected} floats, got {actual}")]
BadHiddenLen { expected: usize, actual: usize },
/// Error reported by the Metal context (e.g. pipeline lookup).
#[error("Metal: {0}")]
Metal(#[from] MetalError),
/// Error reported by the MoE router.
#[error("MoE router: {0}")]
Router(#[from] crate::riir::moe::moe_router::MoeRouterError),
/// Error reported by the expert FFN forward path.
#[error("expert FFN: {0}")]
Expert(#[from] crate::riir::moe::expert_forward::ExpertForwardError),
/// Error reported by expert I/O.
#[error("expert I/O: {0}")]
ExpertIo(#[from] crate::riir::io::expert_io::ExpertIoError),
/// Error reported by the RoPE module.
#[error("RoPE: {0}")]
Rope(#[from] crate::riir::attn::rope::RopeError),
/// Error reported by the SDPA module.
#[error("SDPA: {0}")]
Sdpa(#[from] crate::riir::attn::sdpa::SdpaError),
/// Error reported by the RMSNorm module.
#[error("RMSNorm: {0}")]
RmsNorm(#[from] crate::riir::attn::rms_norm::RmsNormError),
/// Error reported by the deferred-experts machinery.
#[error("deferred experts: {0}")]
Deferred(#[from] DeferredError),
/// Error reported by the backend graph layer.
#[error("graph: {0}")]
Graph(#[from] crate::riir::backend::GraphError),
}
/// Alias of [`LayerForwardError`]; both names refer to the same type.
pub type LinearAttnForwardError = LayerForwardError;
/// Typed buffer-pool ids for the single-token layer-forward path.
///
/// Every id is allocated by `LayerForwardBuffers::new`; the sizes quoted on
/// the fields are the `f32` element counts requested there (taken from
/// `VARIANT`).
pub struct LayerForwardBuffers {
/// Layer input; `hidden_dim` floats.
pub input: BufId<HiddenBuf>,
/// RMS-normed activations fed to the projection matvecs; `hidden_dim` floats.
pub normed: BufId<AttnInputBuf>,
/// `hidden_dim` floats (not used by the code visible in this part of the file).
pub residual: BufId<ResidualBuf>,
/// `hidden_dim` floats; passed to the residual add and then to the post-attention norm.
pub h_mid: BufId<ResidualBuf>,
/// Output of the o-projection matvec; `hidden_dim` floats.
pub output: BufId<HiddenBuf>,
/// Output of the linear-attention QKV projection; `linear_conv_dim()` floats.
pub q_stack: BufId<QkvStackBuf>,
/// Output of the Z projection; `linear_total_value()` floats.
pub z_stack: BufId<ZStackBuf>,
/// Output of the beta projection; `linear_num_v_heads` floats.
pub beta_stack: BufId<BetaStackBuf>,
/// Output of the alpha projection; `linear_num_v_heads` floats.
pub alpha_stack: BufId<AlphaStackBuf>,
/// Router gate logits; `num_experts` floats.
pub gate_logits: BufId<RouterLogitsBuf>,
/// Shared-expert gate score; one float.
pub shared_gate: BufId<SharedGateBuf>,
/// Input to the o-projection on the linear-attention path;
/// `max(linear_total_value(), num_attn_heads * head_dim)` floats.
pub o_proj_stack: BufId<OProjOutBuf>,
/// One buffer per linear-attention layer;
/// `(CONV_KERNEL_SIZE - 1) * linear_conv_dim()` floats each.
pub conv_state: Vec<BufId<ConvStateBuf>>,
/// One buffer per linear-attention layer;
/// `linear_num_v_heads * LINEAR_VALUE_DIM * LINEAR_KEY_DIM` floats each.
pub delta_state: Vec<BufId<DeltaStateBuf>>,
/// Conv1d step output; `linear_conv_dim()` floats.
pub conv_output: BufId<ConvOutBuf>,
/// `linear_num_v_heads` floats.
pub delta_g_decay: BufId<GDecayBuf>,
/// `linear_num_v_heads` floats.
pub delta_beta: BufId<BetaGateBuf>,
/// Delta-net step output; `linear_total_value()` floats.
pub delta_output: BufId<DeltaOutBuf>,
/// Scratch for the RMS-norm kernels; one float.
pub sum_sq: BufId<HiddenBuf>,
/// Shared-expert gate projection output; `shared_intermediate` floats.
pub shared_gate_out: BufId<SharedFfnGateBuf>,
/// Shared-expert up projection output; `shared_intermediate` floats.
pub shared_up_out: BufId<SharedFfnUpBuf>,
/// SwiGLU output for the shared expert; `shared_intermediate` floats.
pub shared_act: BufId<SharedFfnActBuf>,
/// Shared-expert down projection output; `hidden_dim` floats.
pub shared_out: BufId<SharedFfnDownBuf>,
/// `2 * num_attn_heads * head_dim` floats.
pub q_proj_out: BufId<QProjOutBuf>,
/// `num_kv_heads * head_dim` floats.
pub k_out: BufId<KProjOutBuf>,
/// `num_kv_heads * head_dim` floats.
pub v_out: BufId<VProjOutBuf>,
/// One buffer per full-attention layer; `GPU_KV_SEQ * num_kv_heads * head_dim` floats each.
pub gpu_kv_k: Vec<BufId<KvCacheKBuf>>,
/// One buffer per full-attention layer; `GPU_KV_SEQ * num_kv_heads * head_dim` floats each.
pub gpu_kv_v: Vec<BufId<KvCacheVBuf>>,
/// `num_attn_heads * head_dim` floats.
pub gpu_attn_q: BufId<QBuf>,
/// `num_attn_heads * GPU_KV_SEQ` floats.
pub gpu_attn_scores: BufId<HiddenBuf>,
/// `num_attn_heads * head_dim` floats; o-projection input on the GPU-attention path.
pub gpu_attn_out: BufId<AttnOutBuf>,
/// `num_attn_heads * head_dim` floats.
pub gpu_attn_gate: BufId<QGateBuf>,
}
/// Alias of [`LayerForwardBuffers`]; both names refer to the same type.
pub type LinearAttnBuffers = LayerForwardBuffers;
/// Typed buffer-pool ids for the chunked linear-attention scratch stacks.
///
/// `LinearAttnGraphScratch::new` sizes every buffer as
/// `BATCHED_CHUNK_SIZE * N` floats, where `N` is the per-row element count
/// quoted on each field.
pub struct LinearAttnGraphScratch {
/// `N = hidden_dim`.
pub normed: BufId<AttnInputBuf>,
/// `N = linear_conv_dim()`.
pub qkv_stack: BufId<QkvStackBuf>,
/// `N = linear_total_value()`.
pub z_stack: BufId<ZStackBuf>,
/// `N = linear_num_v_heads`.
pub beta_stack: BufId<BetaStackBuf>,
/// `N = linear_num_v_heads`.
pub alpha_stack: BufId<AlphaStackBuf>,
/// `N = linear_conv_dim()`.
pub conv_out_stack: BufId<ConvOutBuf>,
/// `N = linear_num_v_heads`.
pub g_decay_stack: BufId<GDecayBuf>,
/// `N = linear_num_v_heads`.
pub beta_gate_stack: BufId<BetaGateBuf>,
/// `N = linear_num_v_heads * LINEAR_VALUE_DIM`.
pub delta_out_stack: BufId<DeltaOutBuf>,
/// `N = linear_total_value()`.
pub value_out_stack: BufId<ValueOutBuf>,
/// `N = hidden_dim`.
pub o_proj_stack: BufId<OProjOutBuf>,
/// `N = num_experts`.
pub gate_logits: BufId<RouterLogitsBuf>,
/// Interior-mutable flag, initialised to `false` by `new`.
/// NOTE(review): its consumers are outside this part of the file.
pub commit_planned: std::cell::Cell<bool>,
}
impl LinearAttnGraphScratch {
    /// Allocates every scratch stack from `pool`.
    ///
    /// Each buffer holds `BATCHED_CHUNK_SIZE` rows of `f32`, with the per-row
    /// width taken from `VARIANT`. All allocations pass `false` as the third
    /// `alloc` argument.
    ///
    /// # Panics
    ///
    /// Panics if any pool allocation fails.
    pub fn new(pool: &mut MetalBufferPool) -> Self {
        const ALLOC_FAILED: &str = "LinearAttnGraphScratch::new pool alloc";

        /// Bytes for one chunk of rows carrying `elems` floats each.
        fn stack_bytes(elems: usize) -> usize {
            crate::riir::BATCHED_CHUNK_SIZE * elems * std::mem::size_of::<f32>()
        }

        let variant = VARIANT;
        let heads = variant.linear_num_v_heads;
        let hidden_bytes = stack_bytes(variant.hidden_dim);
        let conv_bytes = stack_bytes(variant.linear_conv_dim());
        let value_bytes = stack_bytes(variant.linear_total_value());
        let head_bytes = stack_bytes(heads);
        let delta_out_bytes = stack_bytes(heads * Variant::LINEAR_VALUE_DIM);
        let logits_bytes = stack_bytes(variant.num_experts);

        // Field order below is the allocation order; keep it stable.
        Self {
            normed: pool
                .alloc(hidden_bytes, "lags.normed", false)
                .expect(ALLOC_FAILED),
            qkv_stack: pool
                .alloc(conv_bytes, "lags.qkv_stack", false)
                .expect(ALLOC_FAILED),
            z_stack: pool
                .alloc(value_bytes, "lags.z_stack", false)
                .expect(ALLOC_FAILED),
            beta_stack: pool
                .alloc(head_bytes, "lags.beta_stack", false)
                .expect(ALLOC_FAILED),
            alpha_stack: pool
                .alloc(head_bytes, "lags.alpha_stack", false)
                .expect(ALLOC_FAILED),
            conv_out_stack: pool
                .alloc(conv_bytes, "lags.conv_out_stack", false)
                .expect(ALLOC_FAILED),
            g_decay_stack: pool
                .alloc(head_bytes, "lags.g_decay_stack", false)
                .expect(ALLOC_FAILED),
            beta_gate_stack: pool
                .alloc(head_bytes, "lags.beta_gate_stack", false)
                .expect(ALLOC_FAILED),
            delta_out_stack: pool
                .alloc(delta_out_bytes, "lags.delta_out_stack", false)
                .expect(ALLOC_FAILED),
            value_out_stack: pool
                .alloc(value_bytes, "lags.value_out_stack", false)
                .expect(ALLOC_FAILED),
            o_proj_stack: pool
                .alloc(hidden_bytes, "lags.o_proj_stack", false)
                .expect(ALLOC_FAILED),
            gate_logits: pool
                .alloc(logits_bytes, "lags.gate_logits", false)
                .expect(ALLOC_FAILED),
            commit_planned: std::cell::Cell::new(false),
        }
    }
}
/// A pair of equally sized hidden-state buffers.
///
/// `HiddenDoubleBuffer::new` sizes each as
/// `BATCHED_CHUNK_SIZE * hidden_dim` floats.
pub struct HiddenDoubleBuffer {
/// First buffer of the pair.
pub hidden_a: BufId<HiddenBuf>,
/// Second buffer of the pair.
pub hidden_b: BufId<HiddenBuf>,
}
impl HiddenDoubleBuffer {
pub fn new(pool: &mut MetalBufferPool) -> Self {
let v = VARIANT;
let chunk = crate::riir::BATCHED_CHUNK_SIZE;
let bytes = chunk * v.hidden_dim * std::mem::size_of::<f32>();
Self {
hidden_a: pool
.alloc(bytes, "hdb.hidden_a", true)
.expect("HiddenDoubleBuffer::new pool alloc"),
hidden_b: pool
.alloc(bytes, "hdb.hidden_b", true)
.expect("HiddenDoubleBuffer::new pool alloc"),
}
}
}
/// Buffer-pool ids for token ids and output logits.
pub struct HeadTailScratch {
/// `BATCHED_CHUNK_SIZE` `i32` values (size requested in `new`).
pub token_ids: BufId<TokenIdsBuf>,
/// `vocab_size` floats (size requested in `new`).
pub logits: BufId<LogitsBuf>,
}
impl HeadTailScratch {
    /// Allocates the token-id and logits buffers from `pool`.
    ///
    /// `token_ids` holds `BATCHED_CHUNK_SIZE` `i32` values and `logits` holds
    /// `vocab_size` floats; both pass `true` as the third `alloc` argument.
    ///
    /// # Panics
    ///
    /// Panics if either pool allocation fails.
    pub fn new(pool: &mut MetalBufferPool) -> Self {
        let variant = VARIANT;
        let token_bytes = crate::riir::BATCHED_CHUNK_SIZE * std::mem::size_of::<i32>();
        let logit_bytes = variant.vocab_size * std::mem::size_of::<f32>();

        // Struct-literal fields are evaluated top to bottom, so `token_ids`
        // is still allocated before `logits`.
        Self {
            token_ids: pool
                .alloc(token_bytes, "hts.token_ids", true)
                .expect("HeadTailScratch::new token_ids alloc"),
            logits: pool
                .alloc(logit_bytes, "hts.logits", true)
                .expect("HeadTailScratch::new logits alloc"),
        }
    }
}
/// Typed buffer-pool ids for the chunked MoE scratch buffers.
///
/// Sizes quoted below are those requested by `MoeGraphScratch::new`, where
/// `chunk` is `BATCHED_CHUNK_SIZE` and `k_active` is the constructor argument.
pub struct MoeGraphScratch {
/// `chunk * hidden_dim` floats.
pub h_mid: BufId<ResidualBuf>,
/// `chunk * hidden_dim` floats.
pub h_post: BufId<MoeInputBuf>,
/// `chunk` floats.
pub shared_gate: BufId<SharedGateBuf>,
/// `chunk * MAX_K` four-byte elements.
pub routing_indices: BufId<RouterIdxBuf>,
/// `chunk * MAX_K` floats.
pub routing_weights: BufId<RouterWeightsBuf>,
/// `chunk * shared_intermediate` floats.
pub shared_ffn_gate: BufId<SharedFfnGateBuf>,
/// `chunk * shared_intermediate` floats.
pub shared_up: BufId<SharedFfnUpBuf>,
/// `chunk * shared_intermediate` floats.
pub shared_act: BufId<SharedFfnActBuf>,
/// `chunk * hidden_dim` floats.
pub shared_down: BufId<SharedFfnDownBuf>,
/// `chunk * k_active * hidden_dim` floats.
pub bucket_input: BufId<BucketInputBuf>,
/// `chunk * k_active * moe_intermediate` floats.
pub bucket_gate: BufId<BucketGateBuf>,
/// `chunk * k_active * moe_intermediate` floats.
pub bucket_up: BufId<BucketUpBuf>,
/// `chunk * k_active * moe_intermediate` floats.
pub bucket_act: BufId<BucketActBuf>,
/// `chunk * k_active * hidden_dim` floats.
pub bucket_out: BufId<BucketOutBuf>,
/// `chunk * k_active` four-byte elements.
pub bucket_token_idx: BufId<BucketTokenIdxBuf>,
/// `chunk * k_active` floats.
pub bucket_weights: BufId<BucketWeightsBuf>,
/// `chunk * hidden_dim` floats.
pub out_sum: BufId<MoeOutSumBuf>,
/// Present only when the I/O mode reports `is_pread()`;
/// `num_experts * expert_size_4bit()` bytes.
pub expert_base: Option<BufId<ExpertBaseBuf>>,
/// `chunk * k_active` `u32` values.
pub expert_indices: BufId<ExpertIndicesBuf>,
/// `max(num_experts, 1)` `u32` values.
pub htpe: BufId<HtpeBuf>,
/// `max(num_experts, 1) * chunk` `i32` values.
pub hids: BufId<HidsBuf>,
/// Interior-mutable flag, initialised to `false` by `new`.
/// NOTE(review): its consumers are outside this part of the file.
pub commit_planned: std::cell::Cell<bool>,
}
impl MoeGraphScratch {
pub fn new(
pool: &mut MetalBufferPool,
k_active: usize,
mode: crate::riir::io::expert_io_mode::ExpertIoMode,
) -> Self {
let v = VARIANT;
let chunk = crate::riir::BATCHED_CHUNK_SIZE;
let f32_sz = std::mem::size_of::<f32>();
let hidden = v.hidden_dim;
let shared_inter = v.shared_intermediate;
let moe_inter = v.moe_intermediate;
let chk = |elems: usize| chunk * elems * f32_sz;
let bkt = |per: usize| chunk * k_active * per * f32_sz;
let max_k = crate::riir::moe::expert_forward::MAX_K;
let shared_ffn_gate: BufId<SharedFfnGateBuf> = pool
.alloc(chk(shared_inter), "mgs.shared_ffn_gate", false)
.expect("MoeGraphScratch::new pool alloc");
let shared_up: BufId<SharedFfnUpBuf> = pool
.alloc(chk(shared_inter), "mgs.shared_up", false)
.expect("MoeGraphScratch::new pool alloc");
let shared_act: BufId<SharedFfnActBuf> = pool
.alloc(chk(shared_inter), "mgs.shared_act", false)
.expect("MoeGraphScratch::new pool alloc");
let shared_down: BufId<SharedFfnDownBuf> = pool
.alloc(chk(hidden), "mgs.shared_down", false)
.expect("MoeGraphScratch::new pool alloc");
let bucket_gate: BufId<BucketGateBuf> = pool
.alloc(chk(k_active * moe_inter), "mgs.bucket_gate", false)
.expect("MoeGraphScratch::new pool alloc");
let bucket_up: BufId<BucketUpBuf> = pool
.alloc(chk(k_active * moe_inter), "mgs.bucket_up", false)
.expect("MoeGraphScratch::new pool alloc");
let bucket_act: BufId<BucketActBuf> = pool
.alloc(chk(k_active * moe_inter), "mgs.bucket_act", false)
.expect("MoeGraphScratch::new pool alloc");
let bucket_out: BufId<BucketOutBuf> = pool
.alloc(chk(k_active * hidden), "mgs.bucket_out", false)
.expect("MoeGraphScratch::new pool alloc");
let out_sum: BufId<MoeOutSumBuf> = pool
.alloc(chk(hidden), "mgs.out_sum", false)
.expect("MoeGraphScratch::new pool alloc");
let h_mid: BufId<ResidualBuf> = pool
.alloc(chk(hidden), "mgs.h_mid", true)
.expect("MoeGraphScratch::new pool alloc");
let h_post: BufId<MoeInputBuf> = pool
.alloc(chk(hidden), "mgs.h_post", true)
.expect("MoeGraphScratch::new pool alloc");
let shared_gate: BufId<SharedGateBuf> = pool
.alloc(chk(1), "mgs.shared_gate", true)
.expect("MoeGraphScratch::new pool alloc");
let routing_indices: BufId<RouterIdxBuf> = pool
.alloc(chk(max_k), "mgs.routing_indices", true)
.expect("MoeGraphScratch::new pool alloc");
let routing_weights: BufId<RouterWeightsBuf> = pool
.alloc(chk(max_k), "mgs.routing_weights", true)
.expect("MoeGraphScratch::new pool alloc");
let bucket_input: BufId<BucketInputBuf> = pool
.alloc(bkt(hidden), "mgs.bucket_input", true)
.expect("MoeGraphScratch::new pool alloc");
let bucket_token_idx: BufId<BucketTokenIdxBuf> = pool
.alloc(chunk * k_active * f32_sz, "mgs.bucket_token_idx", true)
.expect("MoeGraphScratch::new pool alloc");
let bucket_weights: BufId<BucketWeightsBuf> = pool
.alloc(chunk * k_active * f32_sz, "mgs.bucket_weights", true)
.expect("MoeGraphScratch::new pool alloc");
let expert_indices: BufId<ExpertIndicesBuf> = pool
.alloc(
chunk * k_active * std::mem::size_of::<u32>(),
"mgs.expert_indices",
true,
)
.expect("MoeGraphScratch::new pool alloc");
let expert_base: Option<BufId<ExpertBaseBuf>> = if mode.is_pread() {
Some(
pool.alloc(
v.num_experts * v.expert_size_4bit(),
"mgs.expert_base",
true,
)
.expect("MoeGraphScratch::new pool alloc (expert_base)"),
)
} else {
None
};
let n_experts = v.num_experts.max(1);
let htpe: BufId<HtpeBuf> = pool
.alloc(n_experts * std::mem::size_of::<u32>(), "mgs.htpe", true)
.expect("MoeGraphScratch::new pool alloc");
let hids: BufId<HidsBuf> = pool
.alloc(
n_experts * chunk * std::mem::size_of::<i32>(),
"mgs.hids",
true,
)
.expect("MoeGraphScratch::new pool alloc");
Self {
h_mid,
h_post,
shared_gate,
routing_indices,
routing_weights,
shared_ffn_gate,
shared_up,
shared_act,
shared_down,
bucket_input,
bucket_gate,
bucket_up,
bucket_act,
bucket_out,
bucket_token_idx,
bucket_weights,
out_sum,
expert_indices,
expert_base,
htpe,
hids,
commit_planned: std::cell::Cell::new(false),
}
}
}
impl LayerForwardBuffers {
    /// Allocates every per-token forward buffer from `pool`.
    ///
    /// Sizes come from `VARIANT`. The per-layer state vectors are allocated
    /// first (conv state, delta state, K mirror, V mirror), followed by the
    /// single-instance buffers. Every allocation passes `true` as the third
    /// `alloc` argument.
    ///
    /// # Panics
    ///
    /// Panics if any pool allocation fails.
    pub fn new(pool: &mut MetalBufferPool) -> Self {
        const F32: usize = std::mem::size_of::<f32>();
        const ALLOC_FAILED: &str = "LayerForwardBuffers::new pool alloc";

        let v = VARIANT;
        let kv_seq = crate::riir::variants::GPU_KV_SEQ;
        let q_dim = v.num_attn_heads * v.head_dim;
        let kv_dim = v.num_kv_heads * v.head_dim;
        let conv_dim = v.linear_conv_dim();
        let total_value = v.linear_total_value();
        let full_attn_layers = num_full_attn_layers(&v);
        let linear_layers = v.num_layers - full_attn_layers;

        // Byte sizes shared by several buffers.
        let hidden_bytes = v.hidden_dim * F32;
        let conv_bytes = conv_dim * F32;
        let value_bytes = total_value * F32;
        let v_head_bytes = v.linear_num_v_heads * F32;
        let scalar_bytes = 1 * F32;
        let shared_bytes = v.shared_intermediate * F32;
        let q_bytes = q_dim * F32;
        let kv_bytes = kv_dim * F32;
        let q_proj_bytes = q_dim * 2 * F32;
        // The o-projection input must fit either attention flavour.
        let oproj_in_bytes = total_value.max(q_dim) * F32;
        let conv_state_bytes = (Variant::CONV_KERNEL_SIZE - 1) * conv_dim * F32;
        let delta_state_bytes =
            v.linear_num_v_heads * Variant::LINEAR_VALUE_DIM * Variant::LINEAR_KEY_DIM * F32;
        let kv_cache_bytes = kv_seq * kv_dim * F32;
        let scores_bytes = v.num_attn_heads * kv_seq * F32;

        // Per-layer state: one buffer per linear-attention layer.
        let mut conv_state: Vec<BufId<ConvStateBuf>> = Vec::with_capacity(linear_layers);
        for _ in 0..linear_layers {
            conv_state.push(
                pool.alloc(conv_state_bytes, "lfb.conv_state", true)
                    .expect(ALLOC_FAILED),
            );
        }
        let mut delta_state: Vec<BufId<DeltaStateBuf>> = Vec::with_capacity(linear_layers);
        for _ in 0..linear_layers {
            delta_state.push(
                pool.alloc(delta_state_bytes, "lfb.delta_state", true)
                    .expect(ALLOC_FAILED),
            );
        }

        // KV mirrors: one K and one V buffer per full-attention layer.
        let mut gpu_kv_k: Vec<BufId<KvCacheKBuf>> = Vec::with_capacity(full_attn_layers);
        for _ in 0..full_attn_layers {
            gpu_kv_k.push(
                pool.alloc(kv_cache_bytes, "lfb.gpu_kv_k", true)
                    .expect(ALLOC_FAILED),
            );
        }
        let mut gpu_kv_v: Vec<BufId<KvCacheVBuf>> = Vec::with_capacity(full_attn_layers);
        for _ in 0..full_attn_layers {
            gpu_kv_v.push(
                pool.alloc(kv_cache_bytes, "lfb.gpu_kv_v", true)
                    .expect(ALLOC_FAILED),
            );
        }

        // Single-instance buffers; field order is the allocation order.
        Self {
            input: pool
                .alloc(hidden_bytes, "lfb.input", true)
                .expect(ALLOC_FAILED),
            normed: pool
                .alloc(hidden_bytes, "lfb.normed", true)
                .expect(ALLOC_FAILED),
            residual: pool
                .alloc(hidden_bytes, "lfb.residual", true)
                .expect(ALLOC_FAILED),
            h_mid: pool
                .alloc(hidden_bytes, "lfb.h_mid", true)
                .expect(ALLOC_FAILED),
            output: pool
                .alloc(hidden_bytes, "lfb.output", true)
                .expect(ALLOC_FAILED),
            q_stack: pool
                .alloc(conv_bytes, "lfb.batch_out[0:qkv]", true)
                .expect(ALLOC_FAILED),
            z_stack: pool
                .alloc(value_bytes, "lfb.batch_out[1:z]", true)
                .expect(ALLOC_FAILED),
            beta_stack: pool
                .alloc(v_head_bytes, "lfb.batch_out[2:beta]", true)
                .expect(ALLOC_FAILED),
            alpha_stack: pool
                .alloc(v_head_bytes, "lfb.batch_out[3:alpha]", true)
                .expect(ALLOC_FAILED),
            gate_logits: pool
                .alloc(v.num_experts * F32, "lfb.batch_out[4:router_gate]", true)
                .expect(ALLOC_FAILED),
            shared_gate: pool
                .alloc(scalar_bytes, "lfb.batch_out[5:shared_gate]", true)
                .expect(ALLOC_FAILED),
            o_proj_stack: pool
                .alloc(oproj_in_bytes, "lfb.batch_out[6:oproj_in]", true)
                .expect(ALLOC_FAILED),
            conv_state,
            delta_state,
            conv_output: pool
                .alloc(conv_bytes, "lfb.conv_output", true)
                .expect(ALLOC_FAILED),
            delta_g_decay: pool
                .alloc(v_head_bytes, "lfb.delta_g_decay", true)
                .expect(ALLOC_FAILED),
            delta_beta: pool
                .alloc(v_head_bytes, "lfb.delta_beta", true)
                .expect(ALLOC_FAILED),
            delta_output: pool
                .alloc(value_bytes, "lfb.delta_output", true)
                .expect(ALLOC_FAILED),
            sum_sq: pool
                .alloc(scalar_bytes, "lfb.sum_sq", true)
                .expect(ALLOC_FAILED),
            shared_gate_out: pool
                .alloc(shared_bytes, "lfb.shared_gate_out", true)
                .expect(ALLOC_FAILED),
            shared_up_out: pool
                .alloc(shared_bytes, "lfb.shared_up_out", true)
                .expect(ALLOC_FAILED),
            shared_act: pool
                .alloc(shared_bytes, "lfb.shared_act", true)
                .expect(ALLOC_FAILED),
            shared_out: pool
                .alloc(hidden_bytes, "lfb.shared_out", true)
                .expect(ALLOC_FAILED),
            q_proj_out: pool
                .alloc(q_proj_bytes, "lfb.q_proj_out", true)
                .expect(ALLOC_FAILED),
            k_out: pool
                .alloc(kv_bytes, "lfb.k_out", true)
                .expect(ALLOC_FAILED),
            v_out: pool
                .alloc(kv_bytes, "lfb.v_out", true)
                .expect(ALLOC_FAILED),
            gpu_kv_k,
            gpu_kv_v,
            gpu_attn_q: pool
                .alloc(q_bytes, "lfb.gpu_attn_q", true)
                .expect(ALLOC_FAILED),
            gpu_attn_scores: pool
                .alloc(scores_bytes, "lfb.gpu_attn_scores", true)
                .expect(ALLOC_FAILED),
            gpu_attn_out: pool
                .alloc(q_bytes, "lfb.gpu_attn_out", true)
                .expect(ALLOC_FAILED),
            gpu_attn_gate: pool
                .alloc(q_bytes, "lfb.gpu_attn_gate", true)
                .expect(ALLOC_FAILED),
        }
    }

    /// Zero-fills every conv-state buffer, then every delta-state buffer.
    pub fn reset_recurrence(&self, pool: &MetalBufferPool) {
        self.conv_state
            .iter()
            .for_each(|&state| zero_f32_buffer(pool.handle(state)));
        self.delta_state
            .iter()
            .for_each(|&state| zero_f32_buffer(pool.handle(state)));
    }

    /// Zero-fills every K mirror buffer, then every V mirror buffer.
    pub fn reset_gpu_attn_kv_mirrors(&self, pool: &MetalBufferPool) {
        self.gpu_kv_k
            .iter()
            .for_each(|&mirror| zero_f32_buffer(pool.handle(mirror)));
        self.gpu_kv_v
            .iter()
            .for_each(|&mirror| zero_f32_buffer(pool.handle(mirror)));
    }
}
/// Overwrites the whole of `b` (all `b.length()` bytes) with zero bytes.
fn zero_f32_buffer(b: &Buffer) {
    let byte_len = b.length() as usize;
    let base = b.contents().cast::<u8>();
    // SAFETY: relies on `contents()` being a valid, CPU-writable mapping of
    // at least `length()` bytes.
    // NOTE(review): that presumably holds only for CPU-visible storage modes;
    // confirm every caller passes such a buffer.
    unsafe { base.write_bytes(0, byte_len) };
}
/// Maps a global layer index to its index among the non-full-attention layers.
///
/// Returns `None` when `VARIANT.layer_kind(layer_idx)` is
/// `LayerKind::FullAttn`. Otherwise returns
/// `layer_idx - (layer_idx + 1) / full_attn_interval`.
/// NOTE(review): this presumably subtracts the number of full-attention
/// layers preceding `layer_idx`; confirm against `Variant::layer_kind`.
pub fn linear_layer_idx_for(layer_idx: usize) -> Option<usize> {
    use crate::riir::variants::LayerKind;
    if VARIANT.layer_kind(layer_idx) == LayerKind::FullAttn {
        return None;
    }
    let skipped = (layer_idx + 1) / VARIANT.full_attn_interval;
    Some(layer_idx - skipped)
}
/// Maps a global layer index to its index among the full-attention layers.
///
/// Returns `Some((layer_idx + 1) / full_attn_interval - 1)` when
/// `VARIANT.layer_kind(layer_idx)` is `LayerKind::FullAttn`, and `None` for
/// every other layer kind.
pub fn full_attn_layer_idx_for(layer_idx: usize) -> Option<usize> {
    use crate::riir::variants::LayerKind;
    let is_full_attn = VARIANT.layer_kind(layer_idx) == LayerKind::FullAttn;
    // The closure runs only for full-attention layers, as the arithmetic did before.
    is_full_attn.then(|| (layer_idx + 1) / VARIANT.full_attn_interval - 1)
}
/// Number of full-attention layers in `v`: `num_layers / full_attn_interval`
/// using integer division.
pub(in crate::riir) fn num_full_attn_layers(v: &Variant) -> usize {
    let total_layers = v.num_layers;
    let interval = v.full_attn_interval;
    total_layers / interval
}
/// Looks up the quantisation bit width recorded for tensor `name` in `wf`.
///
/// Falls back to 4 when the lookup yields no tensor info, and never returns
/// less than 4 even when a smaller width is recorded.
pub(in crate::riir) fn bits_of(wf: &WeightFile, name: &str) -> u32 {
    wf.tensor_info(name)
        .map_or(4, |info| (info.bits as u32).max(4))
}
/// Wrapper around a mutable borrow of the prefetch state.
pub(in crate::riir) struct PrefetchEnv<'a> {
/// The borrowed prefetch state.
pub prefetch: &'a mut crate::riir::io::prefetch::PrefetchState,
}
/// Parameters of the output projection, as consumed by `post_attention_pre_moe`.
///
/// Each field is copied into the matching field of the `MatvecSpec` used for
/// the o-projection.
pub(in crate::riir) struct OProj {
/// Copied into `MatvecSpec::w_off`.
pub w_off: u64,
/// Copied into `MatvecSpec::s_off`.
pub s_off: u64,
/// Copied into `MatvecSpec::b_off`.
pub b_off: u64,
/// Copied into `MatvecSpec::bits`.
pub bits: u32,
/// Copied into `MatvecSpec::in_dim`.
pub in_dim: u32,
}
/// Arguments that enable GPU attention encoding in `post_attention_pre_moe`.
pub(in crate::riir) struct GpuAttnEncodeArgs {
/// Index into `LayerForwardBuffers::gpu_kv_k` / `gpu_kv_v` selecting the
/// KV buffers to use.
pub fa_idx: usize,
/// Forwarded as the `kv_len` argument of the scores, softmax and values encoders.
pub kv_len: u32,
}
/// Encodes one linear-attention layer and hands off to [`post_attention_tail`].
///
/// Work is encoded on a fresh command buffer in this order:
/// 1. input RMS norm, skipped when `prev_layer_chained` is `true`;
/// 2. the QKV, Z, beta and alpha projection matvecs, all reading `normed`;
/// 3. conv1d step, QK RMS norm, decay/beta computation, delta-net step and
///    gated RMS norm.
///
/// The command buffer is then passed to `post_attention_tail` together with
/// the `out_proj` weights and no GPU-attention arguments.
///
/// `_layer_state` is accepted but not used by this function.
///
/// # Errors
///
/// Returns `LayerForwardError::MissingTensor` when `layer_idx` has no
/// linear-attention index or the layer cache holds no linear-attention
/// weights. Errors from pipeline lookup and from `post_attention_tail` are
/// propagated.
#[allow(clippy::too_many_arguments)]
pub fn linear_attn_layer_forward(
    metal: &mut MetalContext,
    gpu: &GpuLayerCtx<'_>,
    moe: &mut MoeBuffers,
    deferred: &mut crate::riir::moe::deferred::DeferredRing,
    layer_idx: usize,
    k_active: usize,
    expert_files: &ExpertFiles,
    pool: &rayon::ThreadPool,
    prefetch: &mut crate::riir::io::prefetch::PrefetchState,
    prefetch_set: usize,
    _layer_state: &mut LinearAttnState,
    gpu_combine: bool,
    prev_layer_chained: bool,
    chain_next_norm_off: Option<u64>,
) -> Result<(), LayerForwardError> {
    let GpuLayerCtx { wf, wf_buf, layer_cache, buffers, buffer_pool } = *gpu;
    let v = VARIANT;

    let Some(linear_layer_idx) = linear_layer_idx_for(layer_idx) else {
        return Err(LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "linear_layer_idx (called on full-attn layer)",
        });
    };

    // Quantisation width of each linear-attention projection of this layer.
    let proj_bits = |proj: &str| {
        bits_of(
            wf,
            &format!("model.layers.{layer_idx}.linear_attn.{proj}.weight"),
        )
    };
    let qkv_bits = proj_bits("in_proj_qkv");
    let z_bits = proj_bits("in_proj_z");
    let alpha_bits = proj_bits("in_proj_a");
    let beta_bits = proj_bits("in_proj_b");
    let o_bits = proj_bits("out_proj");

    let Some(attn) = layer_cache.attn.linear() else {
        return Err(LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "linear_attn weights (called on full-attn layer)",
        });
    };

    let lp = LinearAttnPipelines::fetch(metal)?;
    let mv = MatvecPipelines::fetch(metal)?;
    let rms_pipes = RmsNormBf16Pipelines::fetch(metal)?;
    let queue = metal.queue_clone();
    let cmdbuf = queue.new_command_buffer();

    let hidden = v.hidden_dim as u32;
    let conv_dim = v.linear_conv_dim() as u32;
    let num_v_heads = v.linear_num_v_heads as u32;
    let value_dim = Variant::LINEAR_VALUE_DIM as u32;

    // When the previous layer was chained, the input norm is not encoded here.
    if !prev_layer_chained {
        encode_rms_norm_bf16_into(
            cmdbuf,
            &rms_pipes,
            buffer_pool.handle(buffers.input),
            wf_buf.buffer(),
            layer_cache.input_layernorm_w,
            buffer_pool.handle(buffers.sum_sq),
            buffer_pool.handle(buffers.normed),
            hidden,
            RMS_NORM_EPS,
        );
    }

    // Input projections, encoded in the order qkv, z, beta, alpha.
    let projections = [
        MatvecSpec {
            w_off: attn.qkv_w,
            s_off: attn.qkv_s,
            b_off: attn.qkv_b,
            input: buffer_pool.handle(buffers.normed),
            output: buffer_pool.handle(buffers.q_stack),
            out_dim: conv_dim,
            in_dim: hidden,
            bits: qkv_bits,
        },
        MatvecSpec {
            w_off: attn.z_w,
            s_off: attn.z_s,
            b_off: attn.z_b,
            input: buffer_pool.handle(buffers.normed),
            output: buffer_pool.handle(buffers.z_stack),
            out_dim: v.linear_total_value() as u32,
            in_dim: hidden,
            bits: z_bits,
        },
        MatvecSpec {
            w_off: attn.beta_w,
            s_off: attn.beta_s,
            b_off: attn.beta_b,
            input: buffer_pool.handle(buffers.normed),
            output: buffer_pool.handle(buffers.beta_stack),
            out_dim: num_v_heads,
            in_dim: hidden,
            bits: beta_bits,
        },
        MatvecSpec {
            w_off: attn.alpha_w,
            s_off: attn.alpha_s,
            b_off: attn.alpha_b,
            input: buffer_pool.handle(buffers.normed),
            output: buffer_pool.handle(buffers.alpha_stack),
            out_dim: num_v_heads,
            in_dim: hidden,
            bits: alpha_bits,
        },
    ];
    projections
        .iter()
        .for_each(|spec| encode_matvec(cmdbuf, &mv, wf_buf, spec));

    encode_conv1d_step(
        cmdbuf,
        &lp.conv1d_step,
        &lp.conv1d_state_update,
        buffer_pool.handle(buffers.conv_state[linear_layer_idx]),
        buffer_pool.handle(buffers.q_stack),
        0,
        wf_buf.buffer(),
        attn.conv1d_w,
        buffer_pool.handle(buffers.conv_output),
        conv_dim,
    );
    encode_rms_norm_qk(
        cmdbuf,
        &lp.rms_norm_qk,
        buffer_pool.handle(buffers.conv_output),
        v.linear_num_k_heads as u32,
        Variant::LINEAR_KEY_DIM as u32,
    );
    encode_compute_decay_beta(
        cmdbuf,
        &lp.compute_decay_beta,
        buffer_pool.handle(buffers.alpha_stack),
        0,
        buffer_pool.handle(buffers.beta_stack),
        0,
        wf_buf.buffer(),
        attn.a_log,
        attn.dt_bias,
        buffer_pool.handle(buffers.delta_g_decay),
        buffer_pool.handle(buffers.delta_beta),
        num_v_heads,
    );

    // Ratio `linear_num_v_heads / linear_num_k_heads`, passed as the last
    // argument of the delta-net step.
    let head_ratio = (v.linear_num_v_heads / v.linear_num_k_heads) as u32;
    encode_delta_net_step(
        cmdbuf,
        &lp.delta_net_step,
        buffer_pool.handle(buffers.delta_state[linear_layer_idx]),
        buffer_pool.handle(buffers.conv_output),
        buffer_pool.handle(buffers.delta_g_decay),
        buffer_pool.handle(buffers.delta_beta),
        buffer_pool.handle(buffers.delta_output),
        num_v_heads,
        value_dim,
        head_ratio,
    );
    encode_gated_rms_norm(
        cmdbuf,
        &lp.gated_rms_norm,
        buffer_pool.handle(buffers.delta_output),
        buffer_pool.handle(buffers.z_stack),
        0,
        wf_buf.buffer(),
        attn.gated_norm_w,
        buffer_pool.handle(buffers.o_proj_stack),
        0,
        num_v_heads,
        value_dim,
    );

    let out_proj = OProj {
        w_off: attn.o_proj_w,
        s_off: attn.o_proj_s,
        b_off: attn.o_proj_b,
        bits: o_bits,
        in_dim: v.linear_total_value() as u32,
    };
    post_attention_tail(
        metal,
        cmdbuf,
        gpu,
        moe,
        deferred,
        layer_idx,
        k_active,
        expert_files,
        pool,
        prefetch,
        prefetch_set,
        out_proj,
        gpu_combine,
        None,
        chain_next_norm_off,
    )
}
/// CPU-side results produced by `post_attention_pre_moe` and consumed by the
/// MoE dispatch.
pub(in crate::riir) struct PostAttnIntermediates {
/// `k_active` entries filled in by `moe_router_cpu`.
pub routing_indices: Vec<i32>,
/// `k_active` entries filled in by `moe_router_cpu`.
pub routing_weights: Vec<f32>,
/// First float read back from the `shared_gate` buffer.
pub shared_gate_score: f32,
}
/// Runs the post-attention part of a layer: [`post_attention_pre_moe`]
/// followed by `moe_dispatch_per_token`.
///
/// `cmdbuf`, `o_proj` and `gpu_attn_args` are forwarded to the first stage.
/// The routing results it returns are passed, together with the remaining
/// arguments, to the MoE dispatch.
///
/// # Errors
///
/// Propagates any error from either stage.
#[allow(clippy::too_many_arguments)]
pub(in crate::riir) fn post_attention_tail(
    metal: &mut MetalContext,
    cmdbuf: &CommandBufferRef,
    gpu: &GpuLayerCtx<'_>,
    moe: &mut MoeBuffers,
    deferred: &mut crate::riir::moe::deferred::DeferredRing,
    layer_idx: usize,
    k_active: usize,
    expert_files: &ExpertFiles,
    pool: &rayon::ThreadPool,
    prefetch: &mut crate::riir::io::prefetch::PrefetchState,
    prefetch_set: usize,
    o_proj: OProj,
    gpu_combine: bool,
    gpu_attn_args: Option<GpuAttnEncodeArgs>,
    chain_next_norm_off: Option<u64>,
) -> Result<(), LayerForwardError> {
    // Only these three context fields are needed by the dispatch stage.
    let GpuLayerCtx { wf_buf, buffers, buffer_pool, .. } = *gpu;

    let routed = post_attention_pre_moe(
        metal,
        cmdbuf,
        gpu,
        layer_idx,
        k_active,
        o_proj,
        gpu_attn_args,
    )?;

    moe_dispatch_per_token(
        metal,
        wf_buf,
        buffers,
        buffer_pool,
        moe,
        deferred,
        layer_idx,
        expert_files,
        pool,
        prefetch,
        prefetch_set,
        &routed,
        gpu_combine,
        chain_next_norm_off,
    )
}
#[allow(clippy::too_many_arguments)]
pub(in crate::riir) fn post_attention_pre_moe(
metal: &mut MetalContext,
cmdbuf: &CommandBufferRef,
gpu: &GpuLayerCtx<'_>,
layer_idx: usize,
k_active: usize,
o_proj: OProj,
gpu_attn_args: Option<GpuAttnEncodeArgs>,
) -> Result<PostAttnIntermediates, LayerForwardError> {
let GpuLayerCtx { wf, wf_buf, layer_cache, buffers, buffer_pool } =
*gpu;
let v = VARIANT;
let gate_bits =
bits_of(wf, &format!("model.layers.{layer_idx}.mlp.gate.weight"));
let seg_bits = bits_of(
wf,
&format!(
"model.layers.{layer_idx}.mlp.shared_expert_gate.weight"
),
);
let s_gate_bits = bits_of(
wf,
&format!(
"model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight"
),
);
let s_up_bits = bits_of(
wf,
&format!(
"model.layers.{layer_idx}.mlp.shared_expert.up_proj.weight"
),
);
let s_down_bits = bits_of(
wf,
&format!(
"model.layers.{layer_idx}.mlp.shared_expert.down_proj.weight"
),
);
let post_attn_norm_w = layer_cache.post_attention_layernorm_w;
let gate_w = layer_cache.gate.w;
let gate_s = layer_cache.gate.s;
let gate_b = layer_cache.gate.b;
let shared_up_w = layer_cache.shared.up_w;
let shared_up_s = layer_cache.shared.up_s;
let shared_up_b = layer_cache.shared.up_b;
let shared_gate_w = layer_cache.shared.gate_w;
let shared_gate_s = layer_cache.shared.gate_s;
let shared_gate_b = layer_cache.shared.gate_b;
let shared_down_w = layer_cache.shared.down_w;
let shared_down_s = layer_cache.shared.down_s;
let shared_down_b = layer_cache.shared.down_b;
let seg_w = layer_cache.shared.seg_w;
let seg_s = layer_cache.shared.seg_s;
let seg_b = layer_cache.shared.seg_b;
let mv = MatvecPipelines::fetch(metal)?;
let sum_sq = metal.pipeline("rms_norm_sum_sq")?.clone();
let apply = metal.pipeline("rms_norm_apply_bf16")?.clone();
let resid_add = metal.pipeline("residual_add")?.clone();
let swiglu = metal.pipeline("swiglu_fused")?.clone();
let attn_pipes = if gpu_attn_args.is_some() {
Some(GpuAttnPipelines::fetch(metal)?)
} else {
None
};
{
if let (Some(args), Some(attn_pipes)) =
(gpu_attn_args.as_ref(), attn_pipes.as_ref())
{
let head_dim = v.head_dim as u32;
let kv_dim = (v.num_kv_heads * v.head_dim) as u32;
let num_heads = v.num_attn_heads as u32;
let heads_per_kv = (v.num_attn_heads / v.num_kv_heads) as u32;
let scale = 1.0f32 / (head_dim as f32).sqrt();
let seq_stride = crate::riir::variants::GPU_KV_SEQ as u32;
encode_attn_scores_batched_into(
cmdbuf,
&attn_pipes.scores,
buffer_pool.handle(buffers.gpu_attn_q),
buffer_pool.handle(buffers.gpu_kv_k[args.fa_idx]),
buffer_pool.handle(buffers.gpu_attn_scores),
num_heads,
head_dim,
kv_dim,
args.kv_len,
seq_stride,
heads_per_kv,
scale,
);
encode_attn_softmax_batched_into(
cmdbuf,
&attn_pipes.softmax,
buffer_pool.handle(buffers.gpu_attn_scores),
num_heads,
args.kv_len,
seq_stride,
);
encode_attn_values_batched_into(
cmdbuf,
&attn_pipes.values,
buffer_pool.handle(buffers.gpu_attn_scores),
buffer_pool.handle(buffers.gpu_kv_v[args.fa_idx]),
buffer_pool.handle(buffers.gpu_attn_out),
num_heads,
head_dim,
kv_dim,
args.kv_len,
seq_stride,
heads_per_kv,
);
encode_sigmoid_gate_into(
cmdbuf,
&attn_pipes.gate,
buffer_pool.handle(buffers.gpu_attn_out),
buffer_pool.handle(buffers.gpu_attn_gate),
num_heads * head_dim,
);
}
let oproj_input = if gpu_attn_args.is_some() {
buffer_pool.handle(buffers.gpu_attn_out)
} else {
buffer_pool.handle(buffers.o_proj_stack)
};
encode_matvec(
cmdbuf,
&mv,
wf_buf,
&MatvecSpec {
w_off: o_proj.w_off,
s_off: o_proj.s_off,
b_off: o_proj.b_off,
input: oproj_input,
output: buffer_pool.handle(buffers.output),
out_dim: v.hidden_dim as u32,
in_dim: o_proj.in_dim,
bits: o_proj.bits,
},
);
encode_residual_add(
cmdbuf,
&resid_add,
buffer_pool.handle(buffers.output),
buffer_pool.handle(buffers.input), buffer_pool.handle(buffers.h_mid),
v.hidden_dim as u32,
);
encode_rms_norm_pair(
cmdbuf,
&sum_sq,
&apply,
buffer_pool.handle(buffers.h_mid),
wf_buf.buffer(),
post_attn_norm_w,
buffer_pool.handle(buffers.normed),
buffer_pool.handle(buffers.sum_sq),
v.hidden_dim as u32,
);
encode_matvec(
cmdbuf,
&mv,
wf_buf,
&MatvecSpec {
w_off: gate_w,
s_off: gate_s,
b_off: gate_b,
input: buffer_pool.handle(buffers.normed),
output: buffer_pool.handle(buffers.gate_logits),
out_dim: v.num_experts as u32,
in_dim: v.hidden_dim as u32,
bits: gate_bits,
},
);
encode_matvec(
cmdbuf,
&mv,
wf_buf,
&MatvecSpec {
w_off: seg_w,
s_off: seg_s,
b_off: seg_b,
input: buffer_pool.handle(buffers.normed),
output: buffer_pool.handle(buffers.shared_gate),
out_dim: 1,
in_dim: v.hidden_dim as u32,
bits: seg_bits,
},
);
encode_matvec(
cmdbuf,
&mv,
wf_buf,
&MatvecSpec {
w_off: shared_gate_w,
s_off: shared_gate_s,
b_off: shared_gate_b,
input: buffer_pool.handle(buffers.normed),
output: buffer_pool.handle(buffers.shared_gate_out),
out_dim: v.shared_intermediate as u32,
in_dim: v.hidden_dim as u32,
bits: s_gate_bits,
},
);
encode_matvec(
cmdbuf,
&mv,
wf_buf,
&MatvecSpec {
w_off: shared_up_w,
s_off: shared_up_s,
b_off: shared_up_b,
input: buffer_pool.handle(buffers.normed),
output: buffer_pool.handle(buffers.shared_up_out),
out_dim: v.shared_intermediate as u32,
in_dim: v.hidden_dim as u32,
bits: s_up_bits,
},
);
encode_swiglu_buf(
cmdbuf,
&swiglu,
buffer_pool.handle(buffers.shared_gate_out),
buffer_pool.handle(buffers.shared_up_out),
buffer_pool.handle(buffers.shared_act),
v.shared_intermediate as u32,
);
encode_matvec(
cmdbuf,
&mv,
wf_buf,
&MatvecSpec {
w_off: shared_down_w,
s_off: shared_down_s,
b_off: shared_down_b,
input: buffer_pool.handle(buffers.shared_act),
output: buffer_pool.handle(buffers.shared_out),
out_dim: v.hidden_dim as u32,
in_dim: v.shared_intermediate as u32,
bits: s_down_bits,
},
);
metal.commit_and_wait_labeled(cmdbuf, "post_attn_tail.cmd2_3");
}
let mut scores =
read_buffer_to_vec(buffer_pool.handle(buffers.gate_logits), v.num_experts);
let mut routing_indices = vec![0i32; k_active];
let mut routing_weights = vec![0f32; k_active];
moe_router_cpu(
&mut scores,
k_active,
&mut routing_indices,
&mut routing_weights,
)?;
let shared_gate_score = {
let s = read_buffer_to_vec(buffer_pool.handle(buffers.shared_gate), 1);
s[0]
};
Ok(PostAttnIntermediates {
routing_indices,
routing_weights,
shared_gate_score,
})
}
/// Encodes the post-`o_proj` tail of a layer onto `cmdbuf`, waits for the GPU,
/// then performs expert routing on the CPU.
///
/// Work encoded, in order:
/// 1. residual add of `buffers.output` and `buffers.input` into `buffers.h_mid`;
/// 2. RMS norm of `h_mid` into `buffers.normed` with the post-attention
///    layernorm weights;
/// 3. four matvecs that all read `normed`: router gate logits, the
///    single-output shared-expert gate, and the shared expert's gate and up
///    projections;
/// 4. the `swiglu_fused` pass into `buffers.shared_act`;
/// 5. the shared expert's down projection into `buffers.shared_out`.
///
/// Once `commit_and_wait_labeled` returns, the gate logits are copied to the
/// host and handed to `moe_router_cpu`, which fills `k_active` routing
/// indices and weights. The shared-gate scalar is read back as well.
///
/// # Errors
/// Propagates pipeline lookup failures and any error from `moe_router_cpu`.
#[allow(dead_code, clippy::too_many_arguments)]
pub(in crate::riir) fn post_attention_post_o_proj_to_intermediates(
    metal: &mut MetalContext,
    cmdbuf: &CommandBufferRef,
    wf: &WeightFile,
    wf_buf: &MtlWeightBuf,
    layer_cache: &LayerWeightCache,
    buffers: &LayerForwardBuffers,
    buffer_pool: &MetalBufferPool,
    layer_idx: usize,
    k_active: usize,
) -> Result<PostAttnIntermediates, LayerForwardError> {
    let v = VARIANT;
    let hidden = v.hidden_dim as u32;
    let shared_inter = v.shared_intermediate as u32;

    // Per-tensor quantization widths, looked up by tensor name.
    let gate_bits =
        bits_of(wf, &format!("model.layers.{layer_idx}.mlp.gate.weight"));
    let seg_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.mlp.shared_expert_gate.weight"),
    );
    let s_gate_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight"),
    );
    let s_up_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.mlp.shared_expert.up_proj.weight"),
    );
    let s_down_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.mlp.shared_expert.down_proj.weight"),
    );

    let router = &layer_cache.gate;
    let shared = &layer_cache.shared;

    // Pipeline lookups stay in this order so the first failure reported is
    // unchanged.
    let mv = MatvecPipelines::fetch(metal)?;
    let sum_sq_pipe = metal.pipeline("rms_norm_sum_sq")?.clone();
    let apply_pipe = metal.pipeline("rms_norm_apply_bf16")?.clone();
    let residual_pipe = metal.pipeline("residual_add")?.clone();
    let swiglu_pipe = metal.pipeline("swiglu_fused")?.clone();

    // Buffers that are touched by more than one pass.
    let h_mid = buffer_pool.handle(buffers.h_mid);
    let normed = buffer_pool.handle(buffers.normed);
    let shared_gate_out = buffer_pool.handle(buffers.shared_gate_out);
    let shared_up_out = buffer_pool.handle(buffers.shared_up_out);
    let shared_act = buffer_pool.handle(buffers.shared_act);

    encode_residual_add(
        cmdbuf,
        &residual_pipe,
        buffer_pool.handle(buffers.output),
        buffer_pool.handle(buffers.input),
        h_mid,
        hidden,
    );
    encode_rms_norm_pair(
        cmdbuf,
        &sum_sq_pipe,
        &apply_pipe,
        h_mid,
        wf_buf.buffer(),
        layer_cache.post_attention_layernorm_w,
        normed,
        buffer_pool.handle(buffers.sum_sq),
        hidden,
    );

    // Every projection that consumes `normed`, in encode order.
    let normed_projections = [
        MatvecSpec {
            w_off: router.w,
            s_off: router.s,
            b_off: router.b,
            input: normed,
            output: buffer_pool.handle(buffers.gate_logits),
            out_dim: v.num_experts as u32,
            in_dim: hidden,
            bits: gate_bits,
        },
        MatvecSpec {
            w_off: shared.seg_w,
            s_off: shared.seg_s,
            b_off: shared.seg_b,
            input: normed,
            output: buffer_pool.handle(buffers.shared_gate),
            out_dim: 1,
            in_dim: hidden,
            bits: seg_bits,
        },
        MatvecSpec {
            w_off: shared.gate_w,
            s_off: shared.gate_s,
            b_off: shared.gate_b,
            input: normed,
            output: shared_gate_out,
            out_dim: shared_inter,
            in_dim: hidden,
            bits: s_gate_bits,
        },
        MatvecSpec {
            w_off: shared.up_w,
            s_off: shared.up_s,
            b_off: shared.up_b,
            input: normed,
            output: shared_up_out,
            out_dim: shared_inter,
            in_dim: hidden,
            bits: s_up_bits,
        },
    ];
    for spec in &normed_projections {
        encode_matvec(cmdbuf, &mv, wf_buf, spec);
    }

    encode_swiglu_buf(
        cmdbuf,
        &swiglu_pipe,
        shared_gate_out,
        shared_up_out,
        shared_act,
        shared_inter,
    );
    encode_matvec(
        cmdbuf,
        &mv,
        wf_buf,
        &MatvecSpec {
            w_off: shared.down_w,
            s_off: shared.down_s,
            b_off: shared.down_b,
            input: shared_act,
            output: buffer_pool.handle(buffers.shared_out),
            out_dim: hidden,
            in_dim: shared_inter,
            bits: s_down_bits,
        },
    );
    metal.commit_and_wait_labeled(cmdbuf, "post_attn_post_oproj.cmd");

    // The GPU work has completed; route on the host from the gate logits.
    let mut scores = read_buffer_to_vec(
        buffer_pool.handle(buffers.gate_logits),
        v.num_experts,
    );
    let mut routing_indices = vec![0i32; k_active];
    let mut routing_weights = vec![0f32; k_active];
    moe_router_cpu(
        &mut scores,
        k_active,
        &mut routing_indices,
        &mut routing_weights,
    )?;
    let shared_gate_score =
        read_buffer_to_vec(buffer_pool.handle(buffers.shared_gate), 1)[0];

    Ok(PostAttnIntermediates {
        routing_indices,
        routing_weights,
        shared_gate_score,
    })
}
/// Starts the expert forward pass for one token from already-computed routing
/// results.
///
/// * `gpu_combine == false`: every routed expert is read into one host blob,
///   `h_mid`, `shared_out` and `normed` are copied to the host, and
///   `gpu_batched_experts_begin` is called with that payload.
/// * `gpu_combine == true`: one `(buffer, offset)` binding is resolved per
///   routed expert and `gpu_batched_experts_begin_mmap` is called. In mmap
///   prefetch mode the bindings come from `expert_files`. Otherwise each slot
///   uses a prefetched buffer when the prefetch status lists the expert, and
///   is read synchronously on `pool` when it does not.
///
/// When `chain_next_norm_off` is `Some`, a `ChainToNormed` descriptor built
/// from the two RMS-norm pipelines is passed along with the call.
///
/// # Errors
/// Propagates expert read failures, pipeline lookup failures and errors from
/// the `gpu_batched_experts_begin*` calls.
///
/// # Panics
/// Panics with "mmap buffer missing for expert" if, in mmap mode,
/// `expert_files` has no buffer for a routed expert.
#[allow(clippy::too_many_arguments)]
pub(in crate::riir) fn moe_dispatch_per_token(
    metal: &mut MetalContext,
    wf_buf: &MtlWeightBuf,
    buffers: &LayerForwardBuffers,
    buffer_pool: &MetalBufferPool,
    moe: &mut MoeBuffers,
    deferred: &mut crate::riir::moe::deferred::DeferredRing,
    layer_idx: usize,
    expert_files: &ExpertFiles,
    pool: &rayon::ThreadPool,
    prefetch: &mut crate::riir::io::prefetch::PrefetchState,
    prefetch_set: usize,
    intermediates: &PostAttnIntermediates,
    gpu_combine: bool,
    chain_next_norm_off: Option<u64>,
) -> Result<(), LayerForwardError> {
    let indices = &intermediates.routing_indices;
    let weights = &intermediates.routing_weights;
    let shared_gate_score = intermediates.shared_gate_score;
    let k = indices.len();

    if !gpu_combine {
        // Host-staged path: one contiguous blob holding every routed expert.
        let v = VARIANT;
        let expert_size = v.expert_size_4bit();
        let mut expert_data = vec![0u8; k * expert_size];
        for (slot, &expert) in indices.iter().enumerate() {
            let begin = slot * expert_size;
            expert_files.read_expert(
                layer_idx,
                expert as usize,
                &mut expert_data[begin..begin + expert_size],
            )?;
        }

        let h_mid_host =
            read_buffer_to_vec(buffer_pool.handle(buffers.h_mid), v.hidden_dim);
        let shared_out_host = read_buffer_to_vec(
            buffer_pool.handle(buffers.shared_out),
            v.hidden_dim,
        );
        let normed_host =
            read_buffer_to_vec(buffer_pool.handle(buffers.normed), v.hidden_dim);

        gpu_batched_experts_begin(
            metal,
            moe,
            buffer_pool,
            deferred,
            k as i32,
            &expert_data,
            ExpertPayload {
                h_post: &normed_host,
                h_mid: &h_mid_host,
                shared_out: &shared_out_host,
                expert_weights: weights,
                shared_gate_score,
            },
            layer_idx as i32,
            false,
        )?;
        return Ok(());
    }

    // Declared outside the branch and assigned only on the non-mmap path.
    let bindings_synced;
    let bindings: Vec<(&metal::Buffer, u64)> = if prefetch.mode().is_mmap() {
        indices
            .iter()
            .map(|&expert| {
                expert_files
                    .mmap_buffer_for_expert(layer_idx, expert as u32)
                    .expect("mmap buffer missing for expert")
            })
            .collect()
    } else {
        use crate::riir::io::prefetch::SlotSource;
        use rayon::prelude::*;
        const MAX_K: usize = crate::riir::moe::expert_forward::MAX_K;

        // Decide, slot by slot, whether a prefetched buffer already holds the
        // routed expert. Slots default to a synchronous read.
        let prefetch_status = prefetch.wait_for(layer_idx);
        let mut slot_sources: [SlotSource; MAX_K] = [SlotSource::Synced; MAX_K];
        let mut hits: u64 = 0;
        if let Some(status) = prefetch_status {
            for (slot, &wanted) in indices.iter().enumerate() {
                let found = (0..status.k)
                    .find(|&buf_idx| status.loaded_indices[buf_idx] == wanted);
                if let Some(buf_idx) = found {
                    slot_sources[slot] = SlotSource::Prefetched(buf_idx);
                    hits += 1;
                }
            }
        }
        prefetch.record_outcome(hits, k as u64 - hits);

        // Misses are read in parallel straight into the synced slot buffers.
        let mut dsts = moe.data_synced_slots_mut_array(buffer_pool);
        pool.install(|| -> Result<(), crate::riir::io::expert_io::ExpertIoError> {
            dsts[..k]
                .par_iter_mut()
                .enumerate()
                .try_for_each(|(slot, dst)| match slot_sources[slot] {
                    SlotSource::Synced => expert_files.read_expert(
                        layer_idx,
                        indices[slot] as usize,
                        *dst,
                    ),
                    SlotSource::Prefetched(_) => Ok(()),
                })
        })?;

        let mut actuals: [i32; MAX_K] = [0; MAX_K];
        actuals[..k].copy_from_slice(indices.as_slice());
        prefetch.record_actual(layer_idx, actuals);

        bindings_synced = slot_sources[..k]
            .iter()
            .enumerate()
            .map(|(slot, source)| match *source {
                SlotSource::Synced => moe.data_synced_id(slot),
                SlotSource::Prefetched(buf_idx) => {
                    moe.data_prefetch_id(prefetch_set, buf_idx)
                }
            })
            .collect::<Vec<_>>();
        bindings_synced
            .iter()
            .map(|id| (buffer_pool.handle(*id), 0u64))
            .collect()
    };

    // The pipelines are fetched only when a chained norm was requested.
    let chain_rms_pipes = match chain_next_norm_off {
        Some(_) => {
            Some(crate::riir::backend::gpu::gpu_norm::RmsNormBf16Pipelines {
                sum: metal.pipeline("rms_norm_sum_sq")?.clone(),
                apply: metal.pipeline("rms_norm_apply_bf16")?.clone(),
            })
        }
        None => None,
    };
    let chain = match (chain_next_norm_off, chain_rms_pipes.as_ref()) {
        (Some(off), Some(pipes)) => Some(ChainToNormed {
            pipes,
            wf_buf: wf_buf.buffer(),
            next_norm_off: off,
            combine_out: buffer_pool.handle(buffers.input),
            chain_sum_sq: buffer_pool.handle(buffers.sum_sq),
            chain_normed: buffer_pool.handle(buffers.normed),
            eps: RMS_NORM_EPS,
        }),
        _ => None,
    };

    gpu_batched_experts_begin_mmap(
        metal,
        moe,
        buffer_pool,
        deferred,
        k as i32,
        buffer_pool.handle(buffers.normed),
        buffer_pool.handle(buffers.h_mid),
        buffer_pool.handle(buffers.shared_out),
        weights,
        shared_gate_score,
        layer_idx as i32,
        &bindings,
        chain,
    )?;
    Ok(())
}
pub(in crate::riir) fn read_buffer_to_vec(b: &Buffer, len: usize) -> Vec<f32> {
let ptr = b.contents() as *const f32;
unsafe { std::slice::from_raw_parts(ptr, len).to_vec() }
}
/// Encodes the two compute passes of an RMS norm onto `cmdbuf`.
///
/// Pass 1 runs `sum_pipe` with `input` at buffer index 0, `sum_sq` at index 1
/// and `dim` as inline bytes at index 2, dispatched as a single threadgroup
/// of 256 threads.
///
/// Pass 2 runs `apply_pipe` with `input` (0), `weight_buf` at byte offset
/// `weight_off` (1), `sum_sq` (2), `output` (3), `dim` (4) and
/// `RMS_NORM_EPS` (5), dispatched as `ceil(dim / 256)` threadgroups of 256
/// threads.
#[allow(clippy::too_many_arguments)]
fn encode_rms_norm_pair(
    cmdbuf: &CommandBufferRef,
    sum_pipe: &ComputePipelineState,
    apply_pipe: &ComputePipelineState,
    input: &Buffer,
    weight_buf: &Buffer,
    weight_off: u64,
    output: &Buffer,
    sum_sq: &Buffer,
    dim: u32,
) {
    const THREADS_PER_GROUP: u32 = 256;
    let u32_bytes = std::mem::size_of::<u32>() as NSUInteger;
    let f32_bytes = std::mem::size_of::<f32>() as NSUInteger;

    // Pass 1: reduction into `sum_sq`, one threadgroup for the whole vector.
    let reduce = cmdbuf.new_compute_command_encoder();
    reduce.set_compute_pipeline_state(sum_pipe);
    reduce.set_buffer(0, Some(input), 0);
    reduce.set_buffer(1, Some(sum_sq), 0);
    reduce.set_bytes(
        2,
        u32_bytes,
        (&dim as *const u32).cast::<std::ffi::c_void>(),
    );
    reduce.dispatch_thread_groups(
        MTLSize::new(1, 1, 1),
        MTLSize::new(THREADS_PER_GROUP as NSUInteger, 1, 1),
    );
    reduce.end_encoding();

    // Pass 2: apply the norm, enough threadgroups to cover `dim` elements.
    let eps: f32 = RMS_NORM_EPS;
    let group_count = (dim + (THREADS_PER_GROUP - 1)) / THREADS_PER_GROUP;

    let scale = cmdbuf.new_compute_command_encoder();
    scale.set_compute_pipeline_state(apply_pipe);
    scale.set_buffer(0, Some(input), 0);
    scale.set_buffer(1, Some(weight_buf), weight_off as NSUInteger);
    scale.set_buffer(2, Some(sum_sq), 0);
    scale.set_buffer(3, Some(output), 0);
    scale.set_bytes(
        4,
        u32_bytes,
        (&dim as *const u32).cast::<std::ffi::c_void>(),
    );
    scale.set_bytes(
        5,
        f32_bytes,
        (&eps as *const f32).cast::<std::ffi::c_void>(),
    );
    scale.dispatch_thread_groups(
        MTLSize::new(group_count as NSUInteger, 1, 1),
        MTLSize::new(THREADS_PER_GROUP as NSUInteger, 1, 1),
    );
    scale.end_encoding();
}
/// Encodes one compute pass running `pipeline` over `dim` elements.
///
/// Bindings: `gate` at buffer index 0, `up` at index 1, `act` at index 2 and
/// `dim` as inline bytes at index 3. The dispatch uses `ceil(dim / 256)`
/// threadgroups of 256 threads each.
fn encode_swiglu_buf(
    cmdbuf: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    gate: &Buffer,
    up: &Buffer,
    act: &Buffer,
    dim: u32,
) {
    const THREADS_PER_GROUP: u32 = 256;
    let group_count = (dim + (THREADS_PER_GROUP - 1)) / THREADS_PER_GROUP;
    let dim_bytes = std::mem::size_of::<u32>() as NSUInteger;

    let encoder = cmdbuf.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);
    encoder.set_buffer(0, Some(gate), 0);
    encoder.set_buffer(1, Some(up), 0);
    encoder.set_buffer(2, Some(act), 0);
    encoder.set_bytes(
        3,
        dim_bytes,
        (&dim as *const u32).cast::<std::ffi::c_void>(),
    );
    encoder.dispatch_thread_groups(
        MTLSize::new(group_count as NSUInteger, 1, 1),
        MTLSize::new(THREADS_PER_GROUP as NSUInteger, 1, 1),
    );
    encoder.end_encoding();
}
/// Encodes one compute pass running `pipeline` over `dim` elements.
///
/// Bindings: `a` at buffer index 0, `b` at index 1, `out` at index 2 and
/// `dim` as inline bytes at index 3. The dispatch uses `ceil(dim / 256)`
/// threadgroups of 256 threads each.
fn encode_residual_add(
    cmdbuf: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    a: &Buffer,
    b: &Buffer,
    out: &Buffer,
    dim: u32,
) {
    const THREADS_PER_GROUP: u32 = 256;
    let group_count = (dim + (THREADS_PER_GROUP - 1)) / THREADS_PER_GROUP;
    let dim_bytes = std::mem::size_of::<u32>() as NSUInteger;

    let encoder = cmdbuf.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);
    encoder.set_buffer(0, Some(a), 0);
    encoder.set_buffer(1, Some(b), 0);
    encoder.set_buffer(2, Some(out), 0);
    encoder.set_bytes(
        3,
        dim_bytes,
        (&dim as *const u32).cast::<std::ffi::c_void>(),
    );
    encoder.dispatch_thread_groups(
        MTLSize::new(group_count as NSUInteger, 1, 1),
        MTLSize::new(THREADS_PER_GROUP as NSUInteger, 1, 1),
    );
    encoder.end_encoding();
}
#[allow(clippy::too_many_arguments)]
pub(in crate::riir) fn batched_linear_attn_layer_forward<B>(
backend: &mut B,
wf: &WeightFile,
layer_cache: &LayerWeightCache,
buffers: &LayerForwardBuffers,
layer_idx: usize,
n_tokens: usize,
k_active: usize,
expert_files: &ExpertFiles,
moe_buffers: &mut crate::riir::moe::expert_forward::MoeBuffers,
_layer_state: &mut LinearAttnState,
prefetch: Option<PrefetchEnv<'_>>,
hidden_in_id: BufId<HiddenBuf>,
hidden_out_id: BufId<HiddenBuf>,
scratch: &LinearAttnGraphScratch,
moe: &MoeGraphScratch,
) -> Result<(), LayerForwardError>
where
B: Backend,
LayerForwardError: From<B::Error>,
LayerForwardError: From<<B::Pool as BufferPool>::Error>,
{
use crate::riir::moe::expert_forward::MAX_K;
let v = VARIANT;
debug_assert!(k_active <= MAX_K);
let linear_layer_idx = linear_layer_idx_for(layer_idx).ok_or(
LayerForwardError::MissingTensor {
layer: layer_idx,
tensor: "linear_layer_idx (batched called on full-attn layer)",
},
)?;
let hidden_dim = v.hidden_dim;
let conv_dim = v.linear_conv_dim();
let total_value = v.linear_total_value();
let num_v_heads = v.linear_num_v_heads;
let qkv_bits = bits_of(
wf,
&format!("model.layers.{layer_idx}.linear_attn.in_proj_qkv.weight"),
);
let z_bits = bits_of(
wf,
&format!("model.layers.{layer_idx}.linear_attn.in_proj_z.weight"),
);
let alpha_bits = bits_of(
wf,
&format!("model.layers.{layer_idx}.linear_attn.in_proj_a.weight"),
);
let beta_bits = bits_of(
wf,
&format!("model.layers.{layer_idx}.linear_attn.in_proj_b.weight"),
);
let o_bits = bits_of(
wf,
&format!("model.layers.{layer_idx}.linear_attn.out_proj.weight"),
);
let attn = layer_cache.attn.linear().ok_or(
LayerForwardError::MissingTensor {
layer: layer_idx,
tensor: "linear_attn weights (batched called on full-attn layer)",
},
)?;
let qkv_w = attn.qkv_w;
let qkv_s = attn.qkv_s;
let qkv_b = attn.qkv_b;
let z_w = attn.z_w;
let z_s = attn.z_s;
let z_b = attn.z_b;
let beta_w = attn.beta_w;
let beta_s = attn.beta_s;
let beta_b = attn.beta_b;
let alpha_w = attn.alpha_w;
let alpha_s = attn.alpha_s;
let alpha_b = attn.alpha_b;
let conv1d_w = attn.conv1d_w;
let a_log = attn.a_log;
let dt_bias = attn.dt_bias;
let gnorm_w = attn.gated_norm_w;
let o_w = attn.o_proj_w;
let o_s = attn.o_proj_s;
let o_b = attn.o_proj_b;
use crate::riir::backend::{Graph, Op, WeightRef};
let gate_bits =
bits_of(wf, &format!("model.layers.{layer_idx}.mlp.gate.weight"));
let seg_bits = bits_of(
wf,
&format!(
"model.layers.{layer_idx}.mlp.shared_expert_gate.weight"
),
);
let value_dim = Variant::LINEAR_VALUE_DIM as u32;
let k_heads_per_v =
(v.linear_num_v_heads / v.linear_num_k_heads) as u32;
let key_offset_per_token =
(v.linear_num_k_heads * Variant::LINEAR_KEY_DIM) as u32;
let graph = {
let mut g = Graph::new();
let normed_id = scratch.normed;
let qkv_stack_id = scratch.qkv_stack;
let z_stack_id = scratch.z_stack;
let beta_stack_id = scratch.beta_stack;
let alpha_stack_id = scratch.alpha_stack;
let conv_out_stack_id = scratch.conv_out_stack;
let g_decay_stack_id = scratch.g_decay_stack;
let beta_gate_stack_id = scratch.beta_gate_stack;
let delta_out_stack_id = scratch.delta_out_stack;
let value_out_stack_id = scratch.value_out_stack;
let o_proj_stack_id = scratch.o_proj_stack;
let gate_logits_id = scratch.gate_logits;
let h_mid_id = moe.h_mid;
let h_post_id = moe.h_post;
let shared_gate_id = moe.shared_gate;
let routing_indices_id = moe.routing_indices;
let routing_weights_id = moe.routing_weights;
g.push(Op::RmsNormBf16NTokens {
label: "linear_attn.input_norm",
x: hidden_in_id.into(),
weight_off: layer_cache.input_layernorm_w,
out: normed_id.into(),
dim: hidden_dim as u32,
n_tokens: n_tokens as u32,
eps: RMS_NORM_EPS,
});
g.push(Op::MatvecNTokens {
label: "linear_attn.qkv_proj",
weight: WeightRef { w_off: qkv_w, s_off: qkv_s, b_off: qkv_b, bits: qkv_bits },
input: normed_id.into(),
input_off: 0,
output: qkv_stack_id.into(),
output_off: 0,
in_dim: hidden_dim as u32,
out_dim: conv_dim as u32,
n_tokens: n_tokens as u32,
});
g.push(Op::MatvecNTokens {
label: "linear_attn.z_proj",
weight: WeightRef { w_off: z_w, s_off: z_s, b_off: z_b, bits: z_bits },
input: normed_id.into(),
input_off: 0,
output: z_stack_id.into(),
output_off: 0,
in_dim: hidden_dim as u32,
out_dim: total_value as u32,
n_tokens: n_tokens as u32,
});
g.push(Op::MatvecNTokens {
label: "linear_attn.beta_proj",
weight: WeightRef { w_off: beta_w, s_off: beta_s, b_off: beta_b, bits: beta_bits },
input: normed_id.into(),
input_off: 0,
output: beta_stack_id.into(),
output_off: 0,
in_dim: hidden_dim as u32,
out_dim: num_v_heads as u32,
n_tokens: n_tokens as u32,
});
g.push(Op::MatvecNTokens {
label: "linear_attn.alpha_proj",
weight: WeightRef { w_off: alpha_w, s_off: alpha_s, b_off: alpha_b, bits: alpha_bits },
input: normed_id.into(),
input_off: 0,
output: alpha_stack_id.into(),
output_off: 0,
in_dim: hidden_dim as u32,
out_dim: num_v_heads as u32,
n_tokens: n_tokens as u32,
});
g.push(Op::Conv1dStepNTokens {
label: "linear_attn.conv1d_step",
qkv_in: qkv_stack_id,
conv_state: buffers.conv_state[linear_layer_idx],
weight_off: conv1d_w,
conv_out: conv_out_stack_id,
conv_dim: conv_dim as u32,
n_tokens: n_tokens as u32,
});
g.push(Op::RmsNormQkNTokens {
label: "linear_attn.rms_norm_qk",
x: conv_out_stack_id,
num_k_heads: v.linear_num_k_heads as u32,
key_dim: Variant::LINEAR_KEY_DIM as u32,
key_offset_per_token,
per_token_total: conv_dim as u32,
n_tokens: n_tokens as u32,
});
g.push(Op::ComputeDecayBetaNTokens {
label: "linear_attn.compute_decay_beta",
alpha_in: alpha_stack_id,
beta_in: beta_stack_id,
a_log_off: a_log,
dt_bias_off: dt_bias,
g_decay_out: g_decay_stack_id,
beta_gate_out: beta_gate_stack_id,
num_v_heads: num_v_heads as u32,
n_tokens: n_tokens as u32,
});
g.push(Op::GatedDeltaNetChunkwise {
label: "linear_attn.gated_delta_net_step",
state: buffers.delta_state[linear_layer_idx],
conv_out: conv_out_stack_id,
g_decay: g_decay_stack_id,
beta_gate: beta_gate_stack_id,
output: delta_out_stack_id,
num_v_heads: num_v_heads as u32,
value_dim,
k_heads_per_v,
n_tokens: n_tokens as u32,
chunk_size: 16,
});
g.push(Op::GatedRmsNormNTokens {
label: "linear_attn.gated_rms_norm",
values: delta_out_stack_id,
z: z_stack_id,
weight_off: gnorm_w,
output: value_out_stack_id,
num_v_heads: num_v_heads as u32,
value_dim,
n_tokens: n_tokens as u32,
eps: RMS_NORM_EPS,
});
g.push(Op::MatvecNTokens {
label: "linear_attn.o_proj",
weight: WeightRef { w_off: o_w, s_off: o_s, b_off: o_b, bits: o_bits },
input: value_out_stack_id.into(),
input_off: 0,
output: o_proj_stack_id.into(),
output_off: 0,
in_dim: total_value as u32,
out_dim: hidden_dim as u32,
n_tokens: n_tokens as u32,
});
g.push(Op::ResidualAddNTokens {
label: "linear_attn.residual_add",
a: o_proj_stack_id,
b: hidden_in_id.into(),
out: h_mid_id,
n_tokens: n_tokens as u32,
dim: hidden_dim as u32,
});
g.push(Op::RmsNormBf16NTokens {
label: "linear_attn.post_attn_norm",
x: h_mid_id.into(),
weight_off: layer_cache.post_attention_layernorm_w,
out: h_post_id.into(),
dim: hidden_dim as u32,
n_tokens: n_tokens as u32,
eps: RMS_NORM_EPS,
});
g.push(Op::MatvecNTokens {
label: "linear_attn.gate_router",
weight: WeightRef {
w_off: layer_cache.gate.w,
s_off: layer_cache.gate.s,
b_off: layer_cache.gate.b,
bits: gate_bits,
},
input: h_post_id.into(),
input_off: 0,
output: gate_logits_id.into(),
output_off: 0,
in_dim: hidden_dim as u32,
out_dim: v.num_experts as u32,
n_tokens: n_tokens as u32,
});
g.push(Op::MatvecNTokens {
label: "linear_attn.shared_gate",
weight: WeightRef {
w_off: layer_cache.shared.seg_w,
s_off: layer_cache.shared.seg_s,
b_off: layer_cache.shared.seg_b,
bits: seg_bits,
},
input: h_post_id.into(),
input_off: 0,
output: shared_gate_id.into(),
output_off: 0,
in_dim: hidden_dim as u32,
out_dim: 1,
n_tokens: n_tokens as u32,
});
g.push(Op::MoeSoftmaxTopK {
label: "linear_attn.router_softmax_topk",
logits: gate_logits_id,
indices_out: routing_indices_id,
weights_out: routing_weights_id,
n_tokens: n_tokens as u32,
n_experts: v.num_experts as u32,
k: k_active as u32,
});
g.push(Op::MoeNormalizeWeights {
label: "linear_attn.router_normalize",
weights: routing_weights_id,
n_tokens: n_tokens as u32,
k: k_active as u32,
});
g
};
if !scratch.commit_planned.get() {
backend.pool_mut().commit_plan(&graph);
scratch.commit_planned.set(true);
}
backend.execute(&graph, "graph_linear_attn")?;
moe_block_forward(
backend,
moe,
wf,
layer_cache,
layer_idx,
n_tokens,
k_active,
expert_files,
moe_buffers,
prefetch,
hidden_out_id,
)
}
/// Runs the MoE half of a layer for `n_tokens` stacked tokens.
///
/// Steps:
/// 1. Download `moe.h_post`, `moe.routing_indices` and `moe.routing_weights`
///    from the backend pool.
/// 2. Group the token/expert assignments with `build_expert_buckets`. When
///    the `MOEFLUX_LOG_HTPE` environment variable is set, print the
///    per-expert assignment counts to stderr.
/// 3. Resolve the expert weight base buffer: the mmap id of expert 0 when
///    `moe.expert_base` is `None`; otherwise read every bucketed expert from
///    `expert_files` and upload it into `moe.expert_base` at
///    `expert_id * expert_size`.
/// 4. Upload the bucket token indices, bucket weights and per-assignment
///    expert ids. The gathered bucket inputs are uploaded too, but only when
///    `moe_gather_id_enabled()` is false.
/// 5. Execute a graph labelled `"graph_moe"`: shared-expert FFN, zeroing of
///    `moe.out_sum`, the gather-id or permute fuse op, and the final combine
///    into `hidden_out_id`.
/// 6. If a prefetch environment is present, record up to `MAX_K` leading
///    entries of the downloaded routing indices as the layer's actual experts.
///
/// # Errors
/// Propagates pool transfer errors, expert read errors and graph execution
/// errors.
///
/// # Panics
/// Panics with "mmap layer present in Mmap mode" if `moe.expert_base` is
/// `None` and `expert_files` has no mmap id for this layer.
pub(in crate::riir) fn moe_block_forward<B>(
    backend: &mut B,
    moe: &MoeGraphScratch,
    wf: &WeightFile,
    layer_cache: &LayerWeightCache,
    layer_idx: usize,
    n_tokens: usize,
    k_active: usize,
    expert_files: &ExpertFiles,
    moe_buffers: &mut crate::riir::moe::expert_forward::MoeBuffers,
    mut prefetch: Option<PrefetchEnv<'_>>,
    hidden_out_id: BufId<HiddenBuf>,
) -> Result<(), LayerForwardError>
where
    B: Backend,
    LayerForwardError: From<B::Error>,
    LayerForwardError: From<<B::Pool as BufferPool>::Error>,
{
    use crate::riir::backend::{Graph, Op, WeightRef};
    use crate::riir::moe::moe_router::build_expert_buckets;

    const F32_BYTES: usize = std::mem::size_of::<f32>();
    const I32_BYTES: usize = std::mem::size_of::<i32>();
    const U32_BYTES: usize = std::mem::size_of::<u32>();

    let v = VARIANT;
    let hidden_dim = v.hidden_dim;
    let n_routes = n_tokens * k_active;

    // Host copies of the normed activations and the routing decisions.
    let mut h_post_stack = vec![0.0f32; n_tokens * hidden_dim];
    let mut all_routing_indices = vec![0i32; n_routes];
    let mut all_routing_weights = vec![0.0f32; n_routes];
    {
        let pool = backend.pool();

        // SAFETY: the byte length equals the vector's element count times
        // the element size, so the view covers exactly the vector's storage.
        let h_post_bytes = unsafe {
            std::slice::from_raw_parts_mut(
                h_post_stack.as_mut_ptr() as *mut u8,
                n_tokens * hidden_dim * F32_BYTES,
            )
        };
        pool.download(moe.h_post, h_post_bytes)?;

        // SAFETY: as above, for the `n_routes` `i32` entries.
        let index_bytes = unsafe {
            std::slice::from_raw_parts_mut(
                all_routing_indices.as_mut_ptr() as *mut u8,
                n_tokens * k_active * I32_BYTES,
            )
        };
        pool.download(moe.routing_indices, index_bytes)?;

        // SAFETY: as above, for the `n_routes` `f32` entries.
        let weight_bytes = unsafe {
            std::slice::from_raw_parts_mut(
                all_routing_weights.as_mut_ptr() as *mut u8,
                n_tokens * k_active * F32_BYTES,
            )
        };
        pool.download(moe.routing_weights, weight_bytes)?;
    }

    let s_gate_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.mlp.shared_expert.gate_proj.weight"),
    );
    let s_up_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.mlp.shared_expert.up_proj.weight"),
    );
    let s_down_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.mlp.shared_expert.down_proj.weight"),
    );

    let buckets = build_expert_buckets(
        &all_routing_indices,
        &all_routing_weights,
        n_tokens,
        k_active,
        v.num_experts,
    );
    let total_assignments = buckets.token_idx.len();
    debug_assert_eq!(total_assignments, n_tokens * k_active);

    // Optional diagnostic: how many assignments each expert received.
    if std::env::var_os("MOEFLUX_LOG_HTPE").is_some() {
        let mut counts = vec![0u32; v.num_experts];
        for (bi, &expert) in buckets.expert_ids.iter().enumerate() {
            counts[expert as usize] =
                buckets.offsets[bi + 1] - buckets.offsets[bi];
        }
        let line = counts
            .iter()
            .map(|count| count.to_string())
            .collect::<Vec<_>>()
            .join(",");
        eprintln!(
            "HTPE layer={layer_idx} n_tokens={n_tokens} k_active={k_active} \
             num_experts={} counts=[{line}]",
            v.num_experts,
        );
    }

    // Not used on this path; consumed explicitly to mark that as deliberate.
    let _ = moe_buffers;

    let expert_base_id: BufId<ExpertBaseBuf> =
        if let Some(base_id) = moe.expert_base {
            // Stage every bucketed expert into its slot of the base buffer.
            let expert_size = v.expert_size_4bit();
            let mut blob_scratch = vec![0u8; expert_size];
            let pool = backend.pool_mut();
            for &expert_id in buckets.expert_ids.iter() {
                expert_files.read_expert(
                    layer_idx,
                    expert_id as usize,
                    &mut blob_scratch,
                )?;
                pool.upload_at(
                    base_id,
                    expert_id as usize * expert_size,
                    &blob_scratch,
                )?;
            }
            base_id
        } else {
            expert_files
                .mmap_id_for_expert(layer_idx, 0)
                .expect("mmap layer present in Mmap mode")
                .0
        };

    // One expert id per assignment, laid out bucket by bucket.
    let expert_slots: Vec<u32> =
        buckets.expert_ids.iter().map(|&e| e as u32).collect();
    let mut expert_indices_host = vec![0u32; total_assignments];
    for (bi, &slot) in expert_slots.iter().enumerate() {
        let begin = buckets.offsets[bi] as usize;
        let end = buckets.offsets[bi + 1] as usize;
        expert_indices_host[begin..end].fill(slot);
    }

    if !moe_gather_id_enabled() {
        // Permute path: gather each assignment's token row on the host.
        let mut bucket_input_host =
            vec![0.0f32; total_assignments * hidden_dim];
        for (assignment_idx, &token) in buckets.token_idx.iter().enumerate() {
            let t = token as usize;
            let dst_off = assignment_idx * hidden_dim;
            bucket_input_host[dst_off..dst_off + hidden_dim].copy_from_slice(
                &h_post_stack[t * hidden_dim..(t + 1) * hidden_dim],
            );
        }
        // SAFETY: the vector holds `total_assignments * hidden_dim` `f32`s,
        // matching the byte length passed here.
        let input_bytes = unsafe {
            std::slice::from_raw_parts(
                bucket_input_host.as_ptr() as *const u8,
                total_assignments * hidden_dim * F32_BYTES,
            )
        };
        backend.pool_mut().upload(moe.bucket_input, input_bytes)?;
    }

    {
        let pool = backend.pool_mut();

        // SAFETY: `total_assignments` is `buckets.token_idx.len()`.
        // NOTE(review): assumes 4-byte `token_idx` elements — confirm
        // against `build_expert_buckets`.
        let token_idx_bytes = unsafe {
            std::slice::from_raw_parts(
                buckets.token_idx.as_ptr() as *const u8,
                total_assignments * I32_BYTES,
            )
        };
        pool.upload(moe.bucket_token_idx, token_idx_bytes)?;

        // SAFETY: NOTE(review): assumes `buckets.weights` holds at least
        // `total_assignments` `f32` values — confirm against
        // `build_expert_buckets`.
        let bucket_weight_bytes = unsafe {
            std::slice::from_raw_parts(
                buckets.weights.as_ptr() as *const u8,
                total_assignments * F32_BYTES,
            )
        };
        pool.upload(moe.bucket_weights, bucket_weight_bytes)?;

        // SAFETY: `expert_indices_host` was allocated above with exactly
        // `total_assignments` `u32` elements.
        let expert_index_bytes = unsafe {
            std::slice::from_raw_parts(
                expert_indices_host.as_ptr() as *const u8,
                total_assignments * U32_BYTES,
            )
        };
        pool.upload(moe.expert_indices, expert_index_bytes)?;
    }

    let hidden = hidden_dim as u32;
    let shared_inter = v.shared_intermediate as u32;
    let tokens = n_tokens as u32;
    let expert_stride = v.expert_size_4bit() as u64;

    let mut graph2 = Graph::new();

    // Shared expert FFN: gate/up projections, SwiGLU, down projection.
    graph2.push(Op::MatvecNTokens {
        label: "moe.shared_gate_proj",
        weight: WeightRef {
            w_off: layer_cache.shared.gate_w,
            s_off: layer_cache.shared.gate_s,
            b_off: layer_cache.shared.gate_b,
            bits: s_gate_bits,
        },
        input: moe.h_post.into(),
        input_off: 0,
        output: moe.shared_ffn_gate.into(),
        output_off: 0,
        in_dim: hidden,
        out_dim: shared_inter,
        n_tokens: tokens,
    });
    graph2.push(Op::MatvecNTokens {
        label: "moe.shared_up_proj",
        weight: WeightRef {
            w_off: layer_cache.shared.up_w,
            s_off: layer_cache.shared.up_s,
            b_off: layer_cache.shared.up_b,
            bits: s_up_bits,
        },
        input: moe.h_post.into(),
        input_off: 0,
        output: moe.shared_up.into(),
        output_off: 0,
        in_dim: hidden,
        out_dim: shared_inter,
        n_tokens: tokens,
    });
    graph2.push(Op::SwigluFusedBatched {
        label: "moe.shared_swiglu",
        gate: moe.shared_ffn_gate,
        up: moe.shared_up,
        out: moe.shared_act,
        total: (n_tokens * v.shared_intermediate) as u32,
    });
    graph2.push(Op::MatvecNTokens {
        label: "moe.shared_down_proj",
        weight: WeightRef {
            w_off: layer_cache.shared.down_w,
            s_off: layer_cache.shared.down_s,
            b_off: layer_cache.shared.down_b,
            bits: s_down_bits,
        },
        input: moe.shared_act.into(),
        input_off: 0,
        output: moe.shared_down.into(),
        output_off: 0,
        in_dim: shared_inter,
        out_dim: hidden,
        n_tokens: tokens,
    });
    graph2.push(Op::ZeroBuffer {
        label: "moe.out_sum_zero",
        buf: moe.out_sum,
        n_bytes: (n_tokens * hidden_dim * F32_BYTES) as u32,
    });

    // Routed experts: exactly one of the two fused ops is pushed.
    let routed_experts = if moe_gather_id_enabled() {
        Op::MoeGatherIdFuse {
            label: "moe.gather_id_fuse",
            expert_base: expert_base_id,
            expert_stride,
            indices: moe.routing_indices,
            weights: moe.routing_weights,
            mlp_in: moe.h_post,
            out_sum: moe.out_sum,
            htpe: moe.htpe,
            hids: moe.hids,
            gate_mid: moe.bucket_gate.into(),
            up_mid: moe.bucket_up.into(),
            down_mid: moe.bucket_out.into(),
            n_tokens: tokens,
            n_experts: v.num_experts as u32,
            k: k_active as u32,
        }
    } else {
        Op::MoeBatchedPermuteFuse {
            label: "moe.permute_fuse",
            expert_base: expert_base_id,
            expert_stride,
            expert_indices: moe.expert_indices,
            expert_slots,
            bucket_input: moe.bucket_input,
            bucket_gate: moe.bucket_gate,
            bucket_up: moe.bucket_up,
            bucket_act: moe.bucket_act,
            bucket_out: moe.bucket_out,
            bucket_token_idx: moe.bucket_token_idx,
            bucket_weights: moe.bucket_weights,
            out_sum: moe.out_sum,
            buckets,
        }
    };
    graph2.push(routed_experts);

    graph2.push(Op::MoeCombineResidualNTokens {
        label: "moe.combine",
        h_mid: moe.h_mid,
        moe_sum: moe.out_sum,
        shared_out: moe.shared_down,
        shared_gate: moe.shared_gate,
        hidden_out: hidden_out_id,
        n_tokens: tokens,
        dim: hidden,
    });

    // The plan is committed once per scratch; later calls only execute.
    if !moe.commit_planned.get() {
        backend.pool_mut().commit_plan(&graph2);
        moe.commit_planned.set(true);
    }
    backend.execute(&graph2, "graph_moe")?;

    if let Some(pe) = prefetch.as_mut() {
        use crate::riir::moe::expert_forward::MAX_K;
        // Only the leading `k_active` routing entries are recorded.
        // NOTE(review): presumably these are the first token's experts —
        // confirm the layout of the routing index buffer.
        let mut actuals = [0i32; MAX_K];
        let len = k_active.min(MAX_K).min(all_routing_indices.len());
        actuals[..len].copy_from_slice(&all_routing_indices[..len]);
        pe.prefetch.record_actual(layer_idx, actuals);
    }
    Ok(())
}