use metal::NSUInteger;
use crate::riir::moe::expert_forward::MoeBuffers;
use crate::riir::backend::buftype::{
AttnInputBuf, AttnOutBuf, HiddenBuf, KProjOutBuf, OProjOutBuf, QBuf,
QGateBuf, QProjOutBuf, RopeInvFreqBuf, RouterLogitsBuf, VProjOutBuf,
};
use crate::riir::backend::{
Backend, BufId, BufferPool, MetalBufferPool,
};
use crate::riir::io::expert_io::ExpertFiles;
use crate::riir::io::layer_weight_cache::LayerWeightCache;
use crate::riir::io::weight_file::WeightFile;
use crate::riir::backend::gpu::gpu_matvec::{encode_matvec, MatvecPipelines, MatvecSpec};
use crate::riir::backend::gpu::gpu_norm::{encode_rms_norm_bf16_into, RmsNormBf16Pipelines};
use crate::riir::backend::gpu::gpu_ctx::GpuLayerCtx;
use crate::riir::attn::linear_attn_forward::{
bits_of, full_attn_layer_idx_for, moe_block_forward,
moe_dispatch_per_token, post_attention_pre_moe, read_buffer_to_vec,
GpuAttnEncodeArgs, LayerForwardError, MoeGraphScratch, OProj,
PostAttnIntermediates,
};
use crate::riir::backend::gpu::metal::MetalContext;
use crate::riir::attn::rms_norm::rms_norm_per_head_cpu;
use crate::riir::attn::rope::apply_rotary_emb;
use crate::riir::attn::sdpa::sdpa_cpu;
use crate::riir::snapshot::state::KvCache;
use crate::riir::variants::VARIANT;
/// Device scratch buffers for the batched full-attention graph built by
/// `batched_full_attn_layer_forward`.
///
/// Every per-token buffer is allocated for `crate::riir::BATCHED_CHUNK_SIZE` tokens of `f32`,
/// so a single instance can serve any chunk of up to that many tokens.
pub struct FullAttnGraphScratch {
    /// Input-RMS-normed hidden states (`hidden_dim` per token); input to the Q/K/V projections.
    pub normed: BufId<AttnInputBuf>,
    /// Fused q_proj output, `2 * num_attn_heads * head_dim` per token, later split into the
    /// query half (`q_stack`) and the gate half (`q_gate_stack`).
    pub q_proj_stack: BufId<QProjOutBuf>,
    /// k_proj output (`num_kv_heads * head_dim` per token); per-head normed and rotated in place.
    pub k_proj_stack: BufId<KProjOutBuf>,
    /// v_proj output (`num_kv_heads * head_dim` per token).
    pub v_proj_stack: BufId<VProjOutBuf>,
    /// Query half of `q_proj_stack` (`num_attn_heads * head_dim` per token).
    pub q_stack: BufId<QBuf>,
    /// Gate half of `q_proj_stack`, applied to the attention output through a sigmoid gate.
    pub q_gate_stack: BufId<QGateBuf>,
    /// Attention output (`num_attn_heads * head_dim` per token); input to o_proj.
    pub attn_out_stack: BufId<AttnOutBuf>,
    /// o_proj output (`hidden_dim` per token); added to the residual stream.
    pub o_proj_stack: BufId<OProjOutBuf>,
    /// Router logits (`num_experts` per token).
    pub gate_logits: BufId<RouterLogitsBuf>,
    /// RoPE inverse-frequency table (`rotary_dim / 2` `f32`s), uploaded once in [`Self::new`].
    pub inv_freq: BufId<RopeInvFreqBuf>,
    /// Whether the pool's `commit_plan` has already been run for the batched graph; it is
    /// only planned on the first call that uses this scratch.
    pub commit_planned: std::cell::Cell<bool>,
}
impl FullAttnGraphScratch {
    /// Allocates every scratch buffer from `pool` and uploads the RoPE inverse-frequency table
    /// `inv_freq[i] = 1 / ROPE_THETA^(2i / rotary_dim)` for `i in 0..rotary_dim / 2`.
    ///
    /// # Panics
    /// Panics if any pool allocation or the `inv_freq` upload fails.
    pub fn new(pool: &mut MetalBufferPool) -> Self {
        let v = VARIANT;
        let chunk = crate::riir::BATCHED_CHUNK_SIZE;
        let f32_sz = std::mem::size_of::<f32>();
        let hidden = v.hidden_dim;
        let q_dim = v.num_attn_heads * v.head_dim;
        let kv_dim = v.num_kv_heads * v.head_dim;
        // Bytes for `elems` f32 values per token across a full batched chunk.
        let bytes_of = |elems: usize| chunk * elems * f32_sz;
        let normed: BufId<AttnInputBuf> = pool
            .alloc(bytes_of(hidden), "fags.normed", false)
            .expect("FullAttnGraphScratch::new pool alloc");
        let q_proj_stack: BufId<QProjOutBuf> = pool
            .alloc(bytes_of(q_dim * 2), "fags.q_proj_stack", false)
            .expect("FullAttnGraphScratch::new pool alloc");
        let k_proj_stack: BufId<KProjOutBuf> = pool
            .alloc(bytes_of(kv_dim), "fags.k_proj_stack", false)
            .expect("FullAttnGraphScratch::new pool alloc");
        let v_proj_stack: BufId<VProjOutBuf> = pool
            .alloc(bytes_of(kv_dim), "fags.v_proj_stack", false)
            .expect("FullAttnGraphScratch::new pool alloc");
        let q_stack: BufId<QBuf> = pool
            .alloc(bytes_of(q_dim), "fags.q_stack", false)
            .expect("FullAttnGraphScratch::new pool alloc");
        let q_gate_stack: BufId<QGateBuf> = pool
            .alloc(bytes_of(q_dim), "fags.q_gate_stack", false)
            .expect("FullAttnGraphScratch::new pool alloc");
        let attn_out_stack: BufId<AttnOutBuf> = pool
            .alloc(bytes_of(q_dim), "fags.attn_out_stack", false)
            .expect("FullAttnGraphScratch::new pool alloc");
        let o_proj_stack: BufId<OProjOutBuf> = pool
            .alloc(bytes_of(hidden), "fags.o_proj_stack", false)
            .expect("FullAttnGraphScratch::new pool alloc");
        let gate_logits: BufId<RouterLogitsBuf> = pool
            .alloc(bytes_of(v.num_experts), "fags.gate_logits", false)
            .expect("FullAttnGraphScratch::new pool alloc");
        let rotary_dim = v.rotary_dim();
        let half = rotary_dim / 2;
        let theta = crate::riir::variants::ROPE_THETA;
        let inv_freq_host: Vec<f32> = (0..half)
            .map(|i| 1.0f32 / theta.powf((2 * i) as f32 / rotary_dim as f32))
            .collect();
        // Serialize in native byte order: exactly the bytes a raw view of the f32 slice would
        // expose, but without an `unsafe` pointer reinterpretation.
        let inv_freq_bytes: Vec<u8> = inv_freq_host
            .iter()
            .flat_map(|f| f.to_ne_bytes())
            .collect();
        debug_assert_eq!(inv_freq_bytes.len(), half * f32_sz);
        let inv_freq: BufId<RopeInvFreqBuf> = pool
            .alloc(inv_freq_bytes.len(), "fags.inv_freq", true)
            .expect("FullAttnGraphScratch::new pool alloc");
        pool.upload(inv_freq, inv_freq_bytes.as_slice())
            .expect("FullAttnGraphScratch::new inv_freq upload");
        Self {
            normed,
            q_proj_stack,
            k_proj_stack,
            v_proj_stack,
            q_stack,
            q_gate_stack,
            attn_out_stack,
            o_proj_stack,
            gate_logits,
            inv_freq,
            commit_planned: std::cell::Cell::new(false),
        }
    }
}
/// Runs the per-token (single position) attention half of a full-attention layer.
///
/// Steps:
/// 1. Input RMS norm (unless `prev_layer_chained`) and Q/K/V projections on the GPU.
/// 2. Split of q_proj into query and gate halves, per-head q/k RMS norm, and RoPE on the CPU.
/// 3. Append of one K/V row to `kv_state`, mirrored into the GPU KV buffers when this layer
///    has a full-attn index and the row fits in `GPU_KV_SEQ`.
/// 4. Attention: staged for the GPU kernel (returned as `GpuAttnEncodeArgs`) once the cache
///    is long enough; otherwise computed with `sdpa_cpu` and written into
///    `buffers.o_proj_stack`.
/// 5. `post_attention_pre_moe`, which produces the intermediates for MoE dispatch.
///
/// # Arguments
/// * `metal` — Metal context (command queue and pipelines).
/// * `gpu` — per-layer GPU context: weight file, weight buffer, layer cache, buffers, pool.
/// * `layer_idx` — model layer index; must be a full-attention layer.
/// * `pos` — rotary position of the current token.
/// * `k_active` — top-k routed expert count, forwarded to `post_attention_pre_moe`.
/// * `kv_state` — this layer's KV cache; one row is appended and `len` is incremented.
/// * `prev_layer_chained` — when `true`, the input RMS norm is not encoded here
///   (NOTE(review): presumably `buffers.normed` was already produced by the previous
///   layer — confirm).
///
/// # Errors
/// Returns `LayerForwardError::MissingTensor` in any of these cases:
/// * `layer_idx` is not a full-attention layer.
/// * Its attention weights are missing.
/// * The KV cache is full.
///
/// Also propagates errors from pipeline fetches, per-head norm, RoPE, `sdpa_cpu` and
/// `post_attention_pre_moe`.
///
/// # Panics
/// Panics if a GPU buffer receiving a host→GPU raw copy is smaller than the data written,
/// instead of writing out of bounds.
#[allow(clippy::too_many_arguments)]
pub(in crate::riir) fn full_attn_pre_moe_layer_forward(
    metal: &mut MetalContext,
    gpu: &GpuLayerCtx<'_>,
    layer_idx: usize,
    pos: i32,
    k_active: usize,
    kv_state: &mut KvCache,
    prev_layer_chained: bool,
) -> Result<PostAttnIntermediates, LayerForwardError> {
    // Minimum cached length at which attention is dispatched to the GPU kernel rather than
    // `sdpa_cpu`. NOTE(review): the rationale for 32 is not visible here — confirm.
    const GPU_ATTN_MIN_KV_LEN: usize = 32;
    let f32_sz = std::mem::size_of::<f32>();

    let GpuLayerCtx { wf, wf_buf, layer_cache, buffers, buffer_pool } =
        *gpu;
    let v = VARIANT;
    if v.layer_kind(layer_idx) != crate::riir::variants::LayerKind::FullAttn {
        return Err(LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "full_attn_layer_forward called on linear-attn layer",
        });
    }
    let q_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.q_proj.weight"),
    );
    let k_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.k_proj.weight"),
    );
    let v_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.v_proj.weight"),
    );
    let o_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.o_proj.weight"),
    );
    let attn = layer_cache.attn.full().ok_or(
        LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "full_attn weights (called on linear-attn layer)",
        },
    )?;
    let q_w = attn.q_proj_w;
    let q_s = attn.q_proj_s;
    let q_b = attn.q_proj_b;
    let k_w = attn.k_proj_w;
    let k_s = attn.k_proj_s;
    let k_b = attn.k_proj_b;
    let v_w = attn.v_proj_w;
    let v_s = attn.v_proj_s;
    let v_b = attn.v_proj_b;
    let o_w = attn.o_proj_w;
    let o_s = attn.o_proj_s;
    let o_b = attn.o_proj_b;
    let q_dim = v.num_attn_heads * v.head_dim;
    // q_proj emits query and gate interleaved per head, hence twice the query width.
    let q_proj_dim = q_dim * 2;
    let kv_dim = v.num_kv_heads * v.head_dim;
    let mv = MatvecPipelines::fetch(metal)?;
    let rms_pipes = RmsNormBf16Pipelines::fetch(metal)?;

    // --- GPU: input norm + Q/K/V projections, synchronously committed. ---
    {
        let cmdbuf = metal.queue().new_command_buffer();
        if !prev_layer_chained {
            encode_rms_norm_bf16_into(
                cmdbuf,
                &rms_pipes,
                buffer_pool.handle(buffers.input),
                wf_buf.buffer(),
                layer_cache.input_layernorm_w,
                buffer_pool.handle(buffers.sum_sq),
                buffer_pool.handle(buffers.normed),
                v.hidden_dim as u32,
                crate::riir::variants::RMS_NORM_EPS,
            );
        }
        let specs = [
            MatvecSpec {
                w_off: q_w,
                s_off: q_s,
                b_off: q_b,
                input: buffer_pool.handle(buffers.normed),
                output: buffer_pool.handle(buffers.q_proj_out),
                out_dim: q_proj_dim as u32,
                in_dim: v.hidden_dim as u32,
                bits: q_bits,
            },
            MatvecSpec {
                w_off: k_w,
                s_off: k_s,
                b_off: k_b,
                input: buffer_pool.handle(buffers.normed),
                output: buffer_pool.handle(buffers.k_out),
                out_dim: kv_dim as u32,
                in_dim: v.hidden_dim as u32,
                bits: k_bits,
            },
            MatvecSpec {
                w_off: v_w,
                s_off: v_s,
                b_off: v_b,
                input: buffer_pool.handle(buffers.normed),
                output: buffer_pool.handle(buffers.v_out),
                out_dim: kv_dim as u32,
                in_dim: v.hidden_dim as u32,
                bits: v_bits,
            },
        ];
        for s in &specs {
            encode_matvec(cmdbuf, &mv, wf_buf, s);
        }
        metal.commit_and_wait_labeled(cmdbuf, "full_attn.cmd1");
    }

    // --- CPU: split q/gate, per-head norms, RoPE. ---
    let q_proj_host = read_buffer_to_vec(buffer_pool.handle(buffers.q_proj_out), q_proj_dim);
    let mut k_host = read_buffer_to_vec(buffer_pool.handle(buffers.k_out), kv_dim);
    let v_host = read_buffer_to_vec(buffer_pool.handle(buffers.v_out), kv_dim);
    let mut q_host = vec![0.0f32; q_dim];
    let mut q_gate_host = vec![0.0f32; q_dim];
    // Each head's 2*head_dim slice of q_proj is [query | gate].
    for h in 0..v.num_attn_heads {
        let src_off = h * (2 * v.head_dim);
        let dst_off = h * v.head_dim;
        q_host[dst_off..dst_off + v.head_dim].copy_from_slice(
            &q_proj_host[src_off..src_off + v.head_dim],
        );
        q_gate_host[dst_off..dst_off + v.head_dim].copy_from_slice(
            &q_proj_host[src_off + v.head_dim..src_off + 2 * v.head_dim],
        );
    }
    let q_norm_name =
        format!("model.layers.{layer_idx}.self_attn.q_norm.weight");
    rms_norm_per_head_cpu(
        wf,
        &q_norm_name,
        v.num_attn_heads,
        v.head_dim,
        &mut q_host,
    )?;
    let k_norm_name =
        format!("model.layers.{layer_idx}.self_attn.k_norm.weight");
    rms_norm_per_head_cpu(
        wf,
        &k_norm_name,
        v.num_kv_heads,
        v.head_dim,
        &mut k_host,
    )?;
    apply_rotary_emb(pos, &mut q_host, &mut k_host)?;

    // --- KV-cache append (host cache + optional GPU mirror). ---
    let cache_pos = kv_state.len as usize;
    if cache_pos + 1 > crate::riir::variants::MAX_SEQ_LEN {
        return Err(LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "kv cache overflow",
        });
    }
    // SAFETY: `cache_pos + 1 <= MAX_SEQ_LEN` was checked just above. NOTE(review): this relies
    // on `KvCache::{k,v}_slice_mut`'s contract (storage of MAX_SEQ_LEN rows, no aliasing of row
    // `cache_pos` during this call) — confirm against their definitions.
    unsafe {
        kv_state
            .k_slice_mut(buffer_pool, cache_pos, cache_pos + 1)
            .copy_from_slice(&k_host);
        kv_state
            .v_slice_mut(buffer_pool, cache_pos, cache_pos + 1)
            .copy_from_slice(&v_host);
    }
    kv_state.len += 1;
    let fa_idx = full_attn_layer_idx_for(layer_idx);
    if let Some(fa_idx) = fa_idx {
        if cache_pos < crate::riir::variants::GPU_KV_SEQ {
            let row_start = cache_pos * kv_dim;
            let row_end_bytes = ((row_start + kv_dim) * f32_sz) as NSUInteger;
            let k_buf = buffer_pool.handle(buffers.gpu_kv_k[fa_idx]);
            let v_buf = buffer_pool.handle(buffers.gpu_kv_v[fa_idx]);
            assert!(
                k_buf.length() >= row_end_bytes && v_buf.length() >= row_end_bytes,
                "gpu_kv mirror for full-attn layer {fa_idx} too small for row {cache_pos} \
                 (need {row_end_bytes} bytes)",
            );
            let k_dst = k_buf.contents() as *mut f32;
            let v_dst = v_buf.contents() as *mut f32;
            // SAFETY: both destination buffers were just checked to hold at least
            // `row_start + kv_dim` f32s; `k_host`/`v_host` were read back with `kv_dim`
            // elements; host Vec storage cannot overlap GPU buffer contents.
            unsafe {
                std::ptr::copy_nonoverlapping(
                    k_host.as_ptr(),
                    k_dst.add(row_start),
                    kv_dim,
                );
                std::ptr::copy_nonoverlapping(
                    v_host.as_ptr(),
                    v_dst.add(row_start),
                    kv_dim,
                );
            }
        }
    }

    // --- Attention: GPU kernel (deferred to post_attention_pre_moe) or CPU SDPA. ---
    let kv_len = kv_state.len;
    // The strict `<` on GPU_KV_SEQ keeps the original conservative bound.
    let gpu_fa_idx = fa_idx.filter(|_| {
        let len = kv_len as usize;
        len >= GPU_ATTN_MIN_KV_LEN && len < crate::riir::variants::GPU_KV_SEQ
    });
    let gpu_attn_args = if let Some(fa_idx) = gpu_fa_idx {
        let q_buf = buffer_pool.handle(buffers.gpu_attn_q);
        let g_buf = buffer_pool.handle(buffers.gpu_attn_gate);
        let q_bytes = (q_dim * f32_sz) as NSUInteger;
        assert!(
            q_buf.length() >= q_bytes && g_buf.length() >= q_bytes,
            "gpu_attn_q/gpu_attn_gate sized below the {q_bytes} bytes needed",
        );
        let q_dst = q_buf.contents() as *mut f32;
        let g_dst = g_buf.contents() as *mut f32;
        // SAFETY: both buffers were just checked to hold `q_dim` f32s, `q_host`/`q_gate_host`
        // have exactly `q_dim` elements, and host storage cannot overlap GPU buffers.
        unsafe {
            std::ptr::copy_nonoverlapping(q_host.as_ptr(), q_dst, q_dim);
            std::ptr::copy_nonoverlapping(
                q_gate_host.as_ptr(),
                g_dst,
                q_dim,
            );
        }
        Some(GpuAttnEncodeArgs {
            fa_idx,
            kv_len: kv_len as u32,
        })
    } else {
        let mut attn_out = vec![0.0f32; q_dim];
        // SAFETY: `kv_len` rows were written (this call appended the last one) and
        // `kv_len <= MAX_SEQ_LEN` by the overflow check above. NOTE(review): relies on
        // `KvCache::{k,v}_slice`'s contract — confirm.
        let (k_prefix, v_prefix) = unsafe {
            (
                kv_state.k_slice(buffer_pool, kv_len as usize),
                kv_state.v_slice(buffer_pool, kv_len as usize),
            )
        };
        sdpa_cpu(
            kv_len,
            &q_host,
            &q_gate_host,
            k_prefix,
            v_prefix,
            &mut attn_out,
        )?;
        // The CPU attention output is staged into `o_proj_stack` for the o_proj stage.
        // Check capacity BEFORE the raw write; this is a soundness guard, so it also runs in
        // release builds.
        let o_in = buffer_pool.handle(buffers.o_proj_stack);
        let needed = (q_dim * f32_sz) as NSUInteger;
        assert!(
            o_in.length() >= needed,
            "o_proj_stack sized {} bytes, need {} for full-attn o_proj input",
            o_in.length(),
            needed,
        );
        let dst = o_in.contents() as *mut f32;
        // SAFETY: `dst` was just checked to hold `q_dim` f32s; `attn_out` has `q_dim` elements
        // and is host memory, so the ranges cannot overlap.
        unsafe {
            std::ptr::copy_nonoverlapping(
                attn_out.as_ptr(),
                dst,
                q_dim,
            );
        }
        None
    };

    // --- o_proj / residual / pre-MoE stage. ---
    // Use an owned queue handle so `cmdbuf` does not borrow `metal`, which is passed mutably
    // below.
    let queue = metal.queue_clone();
    let cmdbuf = queue.new_command_buffer();
    let intermediates = post_attention_pre_moe(
        metal,
        cmdbuf,
        gpu,
        layer_idx,
        k_active,
        OProj {
            w_off: o_w,
            s_off: o_s,
            b_off: o_b,
            bits: o_bits,
            in_dim: q_dim as u32,
        },
        gpu_attn_args,
    )?;
    Ok(intermediates)
}
/// Complete per-token forward pass for a full-attention layer.
///
/// The pass has two stages:
/// 1. [`full_attn_pre_moe_layer_forward`] runs attention and the KV-cache append.
/// 2. `moe_dispatch_per_token` dispatches the experts using the intermediates that stage 1
///    returned.
///
/// `prev_layer_chained` is passed to the attention stage. `gpu_combine` and
/// `chain_next_norm_off` are passed to the MoE stage.
///
/// # Errors
/// Propagates the first error returned by either stage. The MoE stage is not run if the
/// attention stage fails.
#[allow(clippy::too_many_arguments)]
pub(in crate::riir) fn full_attn_layer_forward(
    metal: &mut MetalContext,
    gpu: &GpuLayerCtx<'_>,
    moe: &mut MoeBuffers,
    deferred: &mut crate::riir::moe::deferred::DeferredRing,
    layer_idx: usize,
    pos: i32,
    k_active: usize,
    expert_files: &ExpertFiles,
    pool: &rayon::ThreadPool,
    prefetch: &mut crate::riir::io::prefetch::PrefetchState,
    prefetch_set: usize,
    kv_state: &mut KvCache,
    gpu_combine: bool,
    prev_layer_chained: bool,
    chain_next_norm_off: Option<u64>,
) -> Result<(), LayerForwardError> {
    // Stage 1: attention for this token, up to the point where experts are needed.
    let post_attn = full_attn_pre_moe_layer_forward(
        metal,
        gpu,
        layer_idx,
        pos,
        k_active,
        kv_state,
        prev_layer_chained,
    )?;

    // Stage 2: per-token expert dispatch; its result is this function's result.
    moe_dispatch_per_token(
        metal,
        gpu.wf_buf,
        gpu.buffers,
        gpu.buffer_pool,
        moe,
        deferred,
        layer_idx,
        expert_files,
        pool,
        prefetch,
        prefetch_set,
        &post_attn,
        gpu_combine,
        chain_next_norm_off,
    )
}
/// Batched (multi-token) forward pass of a full-attention layer, expressed as a backend graph.
///
/// For `n_tokens` consecutive tokens starting at rotary position `start_pos`, one graph runs
/// these ops:
/// 1. Input RMS norm, then Q/K/V projections.
/// 2. Split of q_proj into query and gate halves, per-head q/k norm, and RoPE.
/// 3. Append of the new K/V rows at `kv_state.len`, then causal tiled SDPA over the cache.
/// 4. Sigmoid gating of the attention output, then o_proj.
/// 5. Residual add into `moe.h_mid`, then post-attention norm into `moe.h_post`.
/// 6. Router logits and the shared-expert gate, then softmax top-k and weight normalization.
///
/// After the graph, `kv_state.len` advances by `n_tokens` and `moe_block_forward` writes the
/// layer output to `hidden_out_id`.
///
/// # Errors
/// Returns `LayerForwardError::MissingTensor` in any of these cases:
/// * `n_tokens` exceeds `crate::riir::BATCHED_CHUNK_SIZE`, the capacity the scratch buffers
///   are allocated for.
/// * The KV cache would overflow `MAX_SEQ_LEN`.
/// * The layer has no full-attention weights.
///
/// Also propagates backend, pool and `moe_block_forward` errors.
///
/// # Panics
/// Panics if `kv_state` has no registered K/V cache ids.
#[allow(clippy::too_many_arguments)]
pub(in crate::riir) fn batched_full_attn_layer_forward<B>(
    backend: &mut B,
    wf: &WeightFile,
    layer_cache: &LayerWeightCache,
    layer_idx: usize,
    start_pos: i32,
    n_tokens: usize,
    k_active: usize,
    expert_files: &ExpertFiles,
    moe_buffers: &mut MoeBuffers,
    kv_state: &mut KvCache,
    prefetch: Option<crate::riir::attn::linear_attn_forward::PrefetchEnv<'_>>,
    hidden_in_id: BufId<HiddenBuf>,
    hidden_out_id: BufId<HiddenBuf>,
    scratch: &FullAttnGraphScratch,
    moe: &MoeGraphScratch,
) -> Result<(), LayerForwardError>
where
    B: Backend,
    LayerForwardError: From<B::Error>,
    LayerForwardError: From<<B::Pool as BufferPool>::Error>,
{
    use crate::riir::moe::expert_forward::MAX_K;
    use crate::riir::backend::{Graph, Op, WeightRef};
    let v = VARIANT;
    debug_assert!(k_active <= MAX_K);
    // `FullAttnGraphScratch::new` sizes every per-token buffer for BATCHED_CHUNK_SIZE tokens.
    // A larger chunk would make the GPU ops write past them, so reject it up front.
    if n_tokens > crate::riir::BATCHED_CHUNK_SIZE {
        return Err(LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "batched chunk exceeds BATCHED_CHUNK_SIZE",
        });
    }
    let hidden_dim = v.hidden_dim;
    let q_dim = v.num_attn_heads * v.head_dim;
    let kv_dim = v.num_kv_heads * v.head_dim;
    let q_proj_dim = q_dim * 2;
    let num_attn_heads = v.num_attn_heads as u32;
    let num_kv_heads = v.num_kv_heads as u32;
    let head_dim = v.head_dim as u32;
    let rotary_dim = v.rotary_dim() as u32;
    let eps = crate::riir::variants::RMS_NORM_EPS;
    let k_cache_id = kv_state
        .k_id
        .expect("kv cache registered by ensure_linear_resources");
    let v_cache_id = kv_state
        .v_id
        .expect("kv cache registered by ensure_linear_resources");
    let kv_start = kv_state.len;
    if (kv_start as usize) + n_tokens > crate::riir::variants::MAX_SEQ_LEN {
        return Err(LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "kv cache overflow",
        });
    }
    let attn = layer_cache.attn.full().ok_or(
        LayerForwardError::MissingTensor {
            layer: layer_idx,
            tensor: "full_attn weights (batched graph path)",
        },
    )?;
    let q_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.q_proj.weight"),
    );
    let k_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.k_proj.weight"),
    );
    let v_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.v_proj.weight"),
    );
    let o_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.self_attn.o_proj.weight"),
    );
    let gate_bits =
        bits_of(wf, &format!("model.layers.{layer_idx}.mlp.gate.weight"));
    let seg_bits = bits_of(
        wf,
        &format!("model.layers.{layer_idx}.mlp.shared_expert_gate.weight"),
    );
    let softmax_scale = 1.0f32 / (v.head_dim as f32).sqrt();
    // Grouped-query attention: each KV head serves this many query heads.
    let heads_per_kv = num_attn_heads / num_kv_heads;
    let n = n_tokens as u32;
    let graph = {
        let mut g = Graph::new();
        // --- Input norm + Q/K/V projections ---
        g.push(Op::RmsNormBf16NTokens {
            label: "full_attn.input_norm",
            x: hidden_in_id.into(),
            weight_off: layer_cache.input_layernorm_w,
            out: scratch.normed.into(),
            dim: hidden_dim as u32,
            n_tokens: n,
            eps,
        });
        g.push(Op::MatvecNTokens {
            label: "full_attn.q_proj",
            weight: WeightRef {
                w_off: attn.q_proj_w,
                s_off: attn.q_proj_s,
                b_off: attn.q_proj_b,
                bits: q_bits,
            },
            input: scratch.normed.into(),
            input_off: 0,
            output: scratch.q_proj_stack.into(),
            output_off: 0,
            in_dim: hidden_dim as u32,
            out_dim: q_proj_dim as u32,
            n_tokens: n,
        });
        g.push(Op::MatvecNTokens {
            label: "full_attn.k_proj",
            weight: WeightRef {
                w_off: attn.k_proj_w,
                s_off: attn.k_proj_s,
                b_off: attn.k_proj_b,
                bits: k_bits,
            },
            input: scratch.normed.into(),
            input_off: 0,
            output: scratch.k_proj_stack.into(),
            output_off: 0,
            in_dim: hidden_dim as u32,
            out_dim: kv_dim as u32,
            n_tokens: n,
        });
        g.push(Op::MatvecNTokens {
            label: "full_attn.v_proj",
            weight: WeightRef {
                w_off: attn.v_proj_w,
                s_off: attn.v_proj_s,
                b_off: attn.v_proj_b,
                bits: v_bits,
            },
            input: scratch.normed.into(),
            input_off: 0,
            output: scratch.v_proj_stack.into(),
            output_off: 0,
            in_dim: hidden_dim as u32,
            out_dim: kv_dim as u32,
            n_tokens: n,
        });
        // --- q/gate split, per-head norms, RoPE (in place on q_stack / k_proj_stack) ---
        g.push(Op::SplitQGate {
            label: "full_attn.split_q_gate",
            q_proj: scratch.q_proj_stack,
            q_out: scratch.q_stack,
            gate_out: scratch.q_gate_stack,
            num_heads: num_attn_heads,
            head_dim,
            n_tokens: n,
        });
        g.push(Op::RmsNormPerHeadNTokens {
            label: "full_attn.q_norm",
            x: scratch.q_stack.into(),
            weight_off: attn.q_norm_w,
            num_heads: num_attn_heads,
            head_dim,
            n_tokens: n,
            eps,
        });
        g.push(Op::RmsNormPerHeadNTokens {
            label: "full_attn.k_norm",
            x: scratch.k_proj_stack.into(),
            weight_off: attn.k_norm_w,
            num_heads: num_kv_heads,
            head_dim,
            n_tokens: n,
            eps,
        });
        g.push(Op::RopeNTokens {
            label: "full_attn.q_rope",
            x: scratch.q_stack.into(),
            inv_freq: scratch.inv_freq,
            n_tokens: n,
            num_heads: num_attn_heads,
            head_dim,
            rotary_dim,
            start_pos,
        });
        g.push(Op::RopeNTokens {
            label: "full_attn.k_rope",
            x: scratch.k_proj_stack.into(),
            inv_freq: scratch.inv_freq,
            n_tokens: n,
            num_heads: num_kv_heads,
            head_dim,
            rotary_dim,
            start_pos,
        });
        // --- KV append + causal attention + output gate + o_proj ---
        g.push(Op::KvCacheAppendNTokens {
            label: "full_attn.kv_append",
            k_src: scratch.k_proj_stack,
            v_src: scratch.v_proj_stack,
            k_cache: k_cache_id,
            v_cache: v_cache_id,
            kv_dim: kv_dim as u32,
            n_tokens: n,
            kv_start: kv_start as u32,
        });
        g.push(Op::SdpaCausalTiled {
            label: "full_attn.sdpa",
            q: scratch.q_stack,
            k: k_cache_id,
            v: v_cache_id,
            attn_out: scratch.attn_out_stack,
            n_tokens: n,
            num_heads: num_attn_heads,
            heads_per_kv,
            head_dim,
            kv_dim: kv_dim as u32,
            kv_start: kv_start as u32,
            kv_len_total: kv_start as u32 + n,
            softmax_scale,
        });
        g.push(Op::SigmoidGateNTokens {
            label: "full_attn.sigmoid_gate",
            x: scratch.attn_out_stack,
            gate: scratch.q_gate_stack,
            dim: q_dim as u32,
            n_tokens: n,
        });
        g.push(Op::MatvecNTokens {
            label: "full_attn.o_proj",
            weight: WeightRef {
                w_off: attn.o_proj_w,
                s_off: attn.o_proj_s,
                b_off: attn.o_proj_b,
                bits: o_bits,
            },
            input: scratch.attn_out_stack.into(),
            input_off: 0,
            output: scratch.o_proj_stack.into(),
            output_off: 0,
            in_dim: q_dim as u32,
            out_dim: hidden_dim as u32,
            n_tokens: n,
        });
        // --- Residual + post-attention norm + routing ---
        g.push(Op::ResidualAddNTokens {
            label: "full_attn.residual_add",
            a: scratch.o_proj_stack,
            b: hidden_in_id.into(),
            out: moe.h_mid,
            n_tokens: n,
            dim: hidden_dim as u32,
        });
        g.push(Op::RmsNormBf16NTokens {
            label: "full_attn.post_attn_norm",
            x: moe.h_mid.into(),
            weight_off: layer_cache.post_attention_layernorm_w,
            out: moe.h_post.into(),
            dim: hidden_dim as u32,
            n_tokens: n,
            eps,
        });
        g.push(Op::MatvecNTokens {
            label: "full_attn.gate_router",
            weight: WeightRef {
                w_off: layer_cache.gate.w,
                s_off: layer_cache.gate.s,
                b_off: layer_cache.gate.b,
                bits: gate_bits,
            },
            input: moe.h_post.into(),
            input_off: 0,
            output: scratch.gate_logits.into(),
            output_off: 0,
            in_dim: hidden_dim as u32,
            out_dim: v.num_experts as u32,
            n_tokens: n,
        });
        g.push(Op::MatvecNTokens {
            label: "full_attn.shared_gate",
            weight: WeightRef {
                w_off: layer_cache.shared.seg_w,
                s_off: layer_cache.shared.seg_s,
                b_off: layer_cache.shared.seg_b,
                bits: seg_bits,
            },
            input: moe.h_post.into(),
            input_off: 0,
            output: moe.shared_gate.into(),
            output_off: 0,
            in_dim: hidden_dim as u32,
            out_dim: 1,
            n_tokens: n,
        });
        g.push(Op::MoeSoftmaxTopK {
            label: "full_attn.router_softmax_topk",
            logits: scratch.gate_logits,
            indices_out: moe.routing_indices,
            weights_out: moe.routing_weights,
            n_tokens: n,
            n_experts: v.num_experts as u32,
            k: k_active as u32,
        });
        g.push(Op::MoeNormalizeWeights {
            label: "full_attn.router_normalize",
            weights: moe.routing_weights,
            n_tokens: n,
            k: k_active as u32,
        });
        g
    };
    // Plan pool commits only on the first call that uses this scratch.
    if !scratch.commit_planned.get() {
        backend.pool_mut().commit_plan(&graph);
        scratch.commit_planned.set(true);
    }
    backend.execute(&graph, "graph_full_attn")?;
    // Advance the cache length only after the graph (which appended the rows) succeeded.
    kv_state.len += n_tokens as i32;
    moe_block_forward(
        backend,
        moe,
        wf,
        layer_cache,
        layer_idx,
        n_tokens,
        k_active,
        expert_files,
        moe_buffers,
        prefetch,
        hidden_out_id,
    )
}