use crate::riir::moe::moe_router::ExpertBuckets;
pub mod buftype;
pub use buftype::*;
/// Locates one weight tensor by three `u64` offsets plus a bit width.
///
/// The type is plain `Copy` data; it is embedded by value in the
/// `MatvecNTokens` and `EmbedGatherNTokens` variants of `Op`.
///
/// NOTE(review): the field names suggest weight / scale / bias offsets into a
/// shared weight blob and a quantization bit width. Confirm against the
/// backends that consume this struct.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub struct WeightRef {
/// Offset of the `w` component (presumably the packed weights; TODO confirm).
pub w_off: u64,
/// Offset of the `s` component (presumably the scales; TODO confirm).
pub s_off: u64,
/// Offset of the `b` component (presumably the biases; TODO confirm).
pub b_off: u64,
/// Bit width associated with the weights (presumably quantization bits; TODO confirm).
pub bits: u32,
}
/// Errors reported by graph and buffer handling.
///
/// `Display` text comes from the `#[error(...)]` attributes below.
#[derive(Debug, thiserror::Error)]
pub enum GraphError {
/// A raw buffer id was out of range; carries the offending id.
#[error("buffer id {0} out of range")]
BadBufId(u32),
/// The buffer identified by `label` had `actual` bytes where `expected` were required.
#[error("buffer size mismatch for {label:?}: expected {expected} bytes, got {actual}")]
SizeMismatch { label: &'static str, expected: usize, actual: usize },
/// Wraps an arbitrary boxed error coming from a backend.
#[error("backend error: {0}")]
Backend(Box<dyn std::error::Error + Send + Sync + 'static>),
}
/// Storage abstraction for the buffers that `Op`s refer to through typed `BufId<B>` ids.
///
/// Only signatures are declared here; the exact semantics belong to the
/// implementors.
///
/// NOTE(review): the per-method descriptions below are inferred from names and
/// signatures only. Confirm against the concrete pools.
pub trait BufferPool {
/// Backend-specific buffer handle, borrowed out by `handle`.
type Handle;
/// Error type returned by the fallible pool operations.
type Error: std::error::Error + Send + Sync + 'static;
/// Allocates a buffer of `bytes` bytes tagged with `label` and returns its typed id.
///
/// NOTE(review): `persistent` presumably exempts the buffer from
/// `reset_transient`; confirm.
fn alloc<B: Buf>(
&mut self,
bytes: usize,
label: &'static str,
persistent: bool,
) -> Result<BufId<B>, Self::Error>;
/// Borrows the backend handle for `id`.
fn handle<B: Buf>(&self, id: BufId<B>) -> &Self::Handle;
/// Copies `host` into the buffer `id` (presumably starting at offset 0; confirm).
fn upload<B: Buf>(
&mut self,
id: BufId<B>,
host: &[u8],
) -> Result<(), Self::Error>;
/// Copies `host` into the buffer `id` at `offset` (presumably a byte offset; confirm).
fn upload_at<B: Buf>(
&mut self,
id: BufId<B>,
offset: usize,
host: &[u8],
) -> Result<(), Self::Error>;
/// Copies the contents of buffer `id` into `host`.
fn download<B: Buf>(
&self,
id: BufId<B>,
host: &mut [u8],
) -> Result<(), Self::Error>;
/// Resets the pool's transient buffers (presumably those allocated with
/// `persistent == false`; confirm).
fn reset_transient(&mut self);
/// Returns the label associated with `id`.
fn label<B: Buf>(&self, id: BufId<B>) -> &'static str;
/// Hook that receives a graph before execution; the default implementation does nothing.
fn commit_plan(&mut self, _graph: &Graph) {}
}
/// Execution backend: owns a [`BufferPool`] and encodes [`Op`]s into an
/// encoding context that is then submitted.
///
/// Implementors supply `open`, the pool accessors, `begin_encoding`,
/// `encode_op` and `submit_and_wait`; the remaining methods have default
/// bodies built on top of those.
pub trait Backend {
    /// Buffer pool used by this backend.
    type Pool: BufferPool;
    /// Per-submission encoding state produced by `begin_encoding`.
    type EncodeCtx;
    /// Configuration consumed by `open`.
    type Config;
    /// Error type returned by `open`, `submit_and_wait` and `execute`.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Builds a backend from `config`.
    fn open(config: Self::Config) -> Result<Self, Self::Error>
    where
        Self: Sized;

    /// Shared access to the backend's buffer pool.
    fn pool(&self) -> &Self::Pool;

    /// Exclusive access to the backend's buffer pool.
    fn pool_mut(&mut self) -> &mut Self::Pool;

    /// Creates a fresh encoding context.
    fn begin_encoding(&self) -> Self::EncodeCtx;

    /// Encodes a single `op` into `ctx`.
    fn encode_op(&self, op: &Op, ctx: &mut Self::EncodeCtx);

    /// Encodes every op of `graph` into `ctx`, in `graph.ops` order.
    fn encode_graph(&self, graph: &Graph, ctx: &mut Self::EncodeCtx) {
        graph.ops.iter().for_each(|op| self.encode_op(op, ctx));
    }

    /// Consumes `ctx`, submitting it under `label`.
    ///
    /// NOTE(review): the name implies this blocks until the submitted work
    /// completes; confirm per backend.
    fn submit_and_wait(
        &self,
        ctx: Self::EncodeCtx,
        label: &'static str,
    ) -> Result<(), Self::Error>;

    /// Convenience wrapper: begins encoding, encodes the whole `graph`, then
    /// hands the context to `submit_and_wait` with `label`.
    fn execute(
        &self,
        graph: &Graph,
        label: &'static str,
    ) -> Result<(), Self::Error> {
        let mut encode_ctx = self.begin_encoding();
        self.encode_graph(graph, &mut encode_ctx);
        self.submit_and_wait(encode_ctx, label)
    }

    /// Per-layer hook; the default implementation does nothing.
    fn begin_layer(&mut self, _chunk_idx: usize, _layer_idx: usize) {}
}
/// One operation in a `Graph`.
///
/// Every variant carries a `label: &'static str` (returned by `Op::label`)
/// plus typed `BufId<_>` operands and scalar parameters. Which operands count
/// as inputs and which as outputs is defined by `Op::reads_raw` and
/// `Op::writes_raw`; the per-variant notes below restate what those two
/// methods report.
///
/// NOTE(review): the variant names suggest kernels of a transformer / MoE /
/// gated-delta-net forward pass. The numerical semantics are defined by the
/// backends, not by this file.
#[derive(Debug)]
pub enum Op {
/// Reads `x`, writes `out`. `weight_off` is a `u64` offset (presumably into
/// the weight blob; confirm).
RmsNormBf16NTokens {
label: &'static str,
x: BufId<RmsNormIn>,
weight_off: u64,
out: BufId<RmsNormOut>,
dim: u32,
n_tokens: u32,
eps: f32,
},
/// In-place on `x`: it is reported by both `reads_raw` and `writes_raw`.
RmsNormQkNTokens {
label: &'static str,
x: BufId<ConvOutBuf>,
num_k_heads: u32,
key_dim: u32,
key_offset_per_token: u32,
per_token_total: u32,
n_tokens: u32,
},
/// In-place on `x`: it is reported by both `reads_raw` and `writes_raw`.
RmsNormPerHeadNTokens {
label: &'static str,
x: BufId<RmsNormIn>,
weight_off: u64,
num_heads: u32,
head_dim: u32,
n_tokens: u32,
eps: f32,
},
/// Reads `x` and `inv_freq`, writes `x` (in-place on `x`).
RopeNTokens {
label: &'static str,
x: BufId<RmsNormIn>,
inv_freq: BufId<RopeInvFreqBuf>,
n_tokens: u32,
num_heads: u32,
head_dim: u32,
rotary_dim: u32,
start_pos: i32,
},
/// Reads `q_proj`, writes `q_out` and `gate_out`.
SplitQGate {
label: &'static str,
q_proj: BufId<QProjOutBuf>,
q_out: BufId<QBuf>,
gate_out: BufId<QGateBuf>,
num_heads: u32,
head_dim: u32,
n_tokens: u32,
},
/// Reads `a` and `b`, writes `out`.
ResidualAddNTokens {
label: &'static str,
a: BufId<OProjOutBuf>,
b: BufId<RmsNormIn>,
out: BufId<ResidualBuf>,
n_tokens: u32,
dim: u32,
},
/// Reads nothing (`reads_raw` is empty), writes `buf`.
ZeroBuffer {
label: &'static str,
buf: BufId<MoeOutSumBuf>,
n_bytes: u32,
},
/// Reads `input`, writes `output`. `weight` is carried by value and is not
/// part of the reported buffer reads.
MatvecNTokens {
label: &'static str,
weight: WeightRef,
input: BufId<MatvecIn>,
input_off: u64,
output: BufId<MatvecOut>,
output_off: u64,
in_dim: u32,
out_dim: u32,
n_tokens: u32,
},
/// Reads `gate` and `up`, writes `out`.
SwigluFusedBatched {
label: &'static str,
gate: BufId<SharedFfnGateBuf>,
up: BufId<SharedFfnUpBuf>,
out: BufId<SharedFfnActBuf>,
total: u32,
},
/// Reads `q`, `k` and `v`, writes `attn_out`.
SdpaCausalTiled {
label: &'static str,
q: BufId<QBuf>,
k: BufId<KvCacheKBuf>,
v: BufId<KvCacheVBuf>,
attn_out: BufId<AttnOutBuf>,
n_tokens: u32,
num_heads: u32,
heads_per_kv: u32,
head_dim: u32,
kv_dim: u32,
kv_start: u32,
kv_len_total: u32,
softmax_scale: f32,
},
/// Reads `x` and `gate`, writes `x` (in-place on `x`).
SigmoidGateNTokens {
label: &'static str,
x: BufId<AttnOutBuf>,
gate: BufId<QGateBuf>,
dim: u32,
n_tokens: u32,
},
/// Reads `k_src` and `v_src`, writes `k_cache` and `v_cache`.
KvCacheAppendNTokens {
label: &'static str,
k_src: BufId<KProjOutBuf>,
v_src: BufId<VProjOutBuf>,
k_cache: BufId<KvCacheKBuf>,
v_cache: BufId<KvCacheVBuf>,
kv_dim: u32,
n_tokens: u32,
kv_start: u32,
},
/// Reads `logits`, writes `indices_out` and `weights_out`.
MoeSoftmaxTopK {
label: &'static str,
logits: BufId<RouterLogitsBuf>,
indices_out: BufId<RouterIdxBuf>,
weights_out: BufId<RouterWeightsBuf>,
n_tokens: u32,
n_experts: u32,
k: u32,
},
/// In-place on `weights`: it is reported by both `reads_raw` and `writes_raw`.
MoeNormalizeWeights {
label: &'static str,
weights: BufId<RouterWeightsBuf>,
n_tokens: u32,
k: u32,
},
/// Reads `expert_base`, `expert_indices`, `bucket_input`, `bucket_token_idx`
/// and `bucket_weights`; writes `bucket_gate`, `bucket_up`, `bucket_act`,
/// `bucket_out` and `out_sum`.
///
/// This is the only variant that owns heap data (`expert_slots` and
/// `buckets`).
MoeBatchedPermuteFuse {
label: &'static str,
expert_base: BufId<ExpertBaseBuf>,
expert_stride: u64,
expert_indices: BufId<ExpertIndicesBuf>,
expert_slots: Vec<u32>,
bucket_input: BufId<BucketInputBuf>,
bucket_gate: BufId<BucketGateBuf>,
bucket_up: BufId<BucketUpBuf>,
bucket_act: BufId<BucketActBuf>,
bucket_out: BufId<BucketOutBuf>,
bucket_token_idx: BufId<BucketTokenIdxBuf>,
bucket_weights: BufId<BucketWeightsBuf>,
out_sum: BufId<MoeOutSumBuf>,
buckets: ExpertBuckets,
},
/// Reads `expert_base`, `indices`, `weights` and `mlp_in`; writes `htpe`,
/// `hids`, `gate_mid`, `up_mid`, `down_mid` and `out_sum`.
MoeGatherIdFuse {
label: &'static str,
expert_base: BufId<ExpertBaseBuf>,
expert_stride: u64,
indices: BufId<RouterIdxBuf>,
weights: BufId<RouterWeightsBuf>,
mlp_in: BufId<MoeInputBuf>,
out_sum: BufId<MoeOutSumBuf>,
htpe: BufId<HtpeBuf>,
hids: BufId<HidsBuf>,
gate_mid: BufId<GateMidBuf>,
up_mid: BufId<UpMidBuf>,
down_mid: BufId<DownMidBuf>,
n_tokens: u32,
n_experts: u32,
k: u32,
},
/// Reads `h_mid`, `moe_sum`, `shared_out` and `shared_gate`; writes `hidden_out`.
MoeCombineResidualNTokens {
label: &'static str,
h_mid: BufId<ResidualBuf>,
moe_sum: BufId<MoeOutSumBuf>,
shared_out: BufId<SharedFfnDownBuf>,
shared_gate: BufId<SharedGateBuf>,
hidden_out: BufId<HiddenBuf>,
n_tokens: u32,
dim: u32,
},
/// Reads `qkv_in` and `conv_state`; writes `conv_state` and `conv_out`
/// (`conv_state` is both read and written).
Conv1dStepNTokens {
label: &'static str,
qkv_in: BufId<QkvStackBuf>,
conv_state: BufId<ConvStateBuf>,
weight_off: u64,
conv_out: BufId<ConvOutBuf>,
conv_dim: u32,
n_tokens: u32,
},
/// Reads `alpha_in` and `beta_in`; writes `g_decay_out` and `beta_gate_out`.
ComputeDecayBetaNTokens {
label: &'static str,
alpha_in: BufId<AlphaStackBuf>,
beta_in: BufId<BetaStackBuf>,
a_log_off: u64,
dt_bias_off: u64,
g_decay_out: BufId<GDecayBuf>,
beta_gate_out: BufId<BetaGateBuf>,
num_v_heads: u32,
n_tokens: u32,
},
/// Reads `state`, `conv_out`, `g_decay` and `beta_gate`; writes `state` and
/// `output` (`state` is both read and written).
GatedDeltaNetStepNTokens {
label: &'static str,
state: BufId<DeltaStateBuf>,
conv_out: BufId<ConvOutBuf>,
g_decay: BufId<GDecayBuf>,
beta_gate: BufId<BetaGateBuf>,
output: BufId<DeltaOutBuf>,
num_v_heads: u32,
value_dim: u32,
k_heads_per_v: u32,
n_tokens: u32,
},
/// Same operands and read/write sets as `GatedDeltaNetStepNTokens`, with an
/// additional `chunk_size` parameter.
GatedDeltaNetChunkwise {
label: &'static str,
state: BufId<DeltaStateBuf>,
conv_out: BufId<ConvOutBuf>,
g_decay: BufId<GDecayBuf>,
beta_gate: BufId<BetaGateBuf>,
output: BufId<DeltaOutBuf>,
num_v_heads: u32,
value_dim: u32,
k_heads_per_v: u32,
n_tokens: u32,
chunk_size: u32,
},
/// Reads `values` and `z`, writes `output`.
GatedRmsNormNTokens {
label: &'static str,
values: BufId<DeltaOutBuf>,
z: BufId<ZStackBuf>,
weight_off: u64,
output: BufId<ValueOutBuf>,
num_v_heads: u32,
value_dim: u32,
n_tokens: u32,
eps: f32,
},
/// Reads `token_ids`, writes `hidden_out`. `weight` is carried by value and
/// is not part of the reported buffer reads.
EmbedGatherNTokens {
label: &'static str,
token_ids: BufId<TokenIdsBuf>,
weight: WeightRef,
hidden_out: BufId<EmbedOutBuf>,
hidden_dim: u32,
n_tokens: u32,
},
}
impl Op {
pub fn label(&self) -> &'static str {
match self {
Op::RmsNormBf16NTokens { label, .. } => label,
Op::RmsNormQkNTokens { label, .. } => label,
Op::RopeNTokens { label, .. } => label,
Op::ResidualAddNTokens { label, .. } => label,
Op::ZeroBuffer { label, .. } => label,
Op::MatvecNTokens { label, .. } => label,
Op::SwigluFusedBatched { label, .. } => label,
Op::SdpaCausalTiled { label, .. } => label,
Op::SigmoidGateNTokens { label, .. } => label,
Op::SplitQGate { label, .. } => label,
Op::RmsNormPerHeadNTokens { label, .. } => label,
Op::KvCacheAppendNTokens { label, .. } => label,
Op::MoeSoftmaxTopK { label, .. } => label,
Op::MoeNormalizeWeights { label, .. } => label,
Op::MoeBatchedPermuteFuse { label, .. } => label,
Op::MoeGatherIdFuse { label, .. } => label,
Op::MoeCombineResidualNTokens { label, .. } => label,
Op::Conv1dStepNTokens { label, .. } => label,
Op::ComputeDecayBetaNTokens { label, .. } => label,
Op::GatedDeltaNetStepNTokens { label, .. } => label,
Op::GatedDeltaNetChunkwise { label, .. } => label,
Op::GatedRmsNormNTokens { label, .. } => label,
Op::EmbedGatherNTokens { label, .. } => label,
}
}
pub fn variant_name(&self) -> &'static str {
match self {
Op::RmsNormBf16NTokens { .. } => "RmsNormBf16NTokens",
Op::RmsNormQkNTokens { .. } => "RmsNormQkNTokens",
Op::RopeNTokens { .. } => "RopeNTokens",
Op::ResidualAddNTokens { .. } => "ResidualAddNTokens",
Op::ZeroBuffer { .. } => "ZeroBuffer",
Op::MatvecNTokens { .. } => "MatvecNTokens",
Op::SwigluFusedBatched { .. } => "SwigluFusedBatched",
Op::SdpaCausalTiled { .. } => "SdpaCausalTiled",
Op::SigmoidGateNTokens { .. } => "SigmoidGateNTokens",
Op::SplitQGate { .. } => "SplitQGate",
Op::RmsNormPerHeadNTokens { .. } => "RmsNormPerHeadNTokens",
Op::KvCacheAppendNTokens { .. } => "KvCacheAppendNTokens",
Op::MoeSoftmaxTopK { .. } => "MoeSoftmaxTopK",
Op::MoeNormalizeWeights { .. } => "MoeNormalizeWeights",
Op::MoeBatchedPermuteFuse { .. } => "MoeBatchedPermuteFuse",
Op::MoeGatherIdFuse { .. } => "MoeGatherIdFuse",
Op::MoeCombineResidualNTokens { .. } => "MoeCombineResidualNTokens",
Op::Conv1dStepNTokens { .. } => "Conv1dStepNTokens",
Op::ComputeDecayBetaNTokens { .. } => "ComputeDecayBetaNTokens",
Op::GatedDeltaNetStepNTokens { .. } => "GatedDeltaNetStepNTokens",
Op::GatedDeltaNetChunkwise { .. } => "GatedDeltaNetChunkwise",
Op::GatedRmsNormNTokens { .. } => "GatedRmsNormNTokens",
Op::EmbedGatherNTokens { .. } => "EmbedGatherNTokens",
}
}
pub fn reads_raw(&self) -> Vec<u32> {
match self {
Op::RmsNormBf16NTokens { x, .. } => vec![x.raw()],
Op::RmsNormQkNTokens { x, .. } => vec![x.raw()],
Op::RopeNTokens { x, inv_freq, .. } => {
vec![x.raw(), inv_freq.raw()]
}
Op::ResidualAddNTokens { a, b, .. } => vec![a.raw(), b.raw()],
Op::ZeroBuffer { .. } => vec![],
Op::MatvecNTokens { input, .. } => vec![input.raw()],
Op::SwigluFusedBatched { gate, up, .. } => {
vec![gate.raw(), up.raw()]
}
Op::SdpaCausalTiled { q, k, v, .. } => {
vec![q.raw(), k.raw(), v.raw()]
}
Op::SigmoidGateNTokens { x, gate, .. } => {
vec![x.raw(), gate.raw()]
}
Op::SplitQGate { q_proj, .. } => vec![q_proj.raw()],
Op::RmsNormPerHeadNTokens { x, .. } => vec![x.raw()],
Op::KvCacheAppendNTokens { k_src, v_src, .. } => {
vec![k_src.raw(), v_src.raw()]
}
Op::MoeSoftmaxTopK { logits, .. } => vec![logits.raw()],
Op::MoeNormalizeWeights { weights, .. } => vec![weights.raw()],
Op::MoeBatchedPermuteFuse {
expert_base,
expert_indices,
bucket_input,
bucket_token_idx,
bucket_weights,
..
} => vec![
expert_base.raw(),
expert_indices.raw(),
bucket_input.raw(),
bucket_token_idx.raw(),
bucket_weights.raw(),
],
Op::MoeGatherIdFuse {
expert_base,
indices,
weights,
mlp_in,
..
} => vec![
expert_base.raw(),
indices.raw(),
weights.raw(),
mlp_in.raw(),
],
Op::MoeCombineResidualNTokens {
h_mid,
moe_sum,
shared_out,
shared_gate,
..
} => vec![
h_mid.raw(),
moe_sum.raw(),
shared_out.raw(),
shared_gate.raw(),
],
Op::Conv1dStepNTokens {
qkv_in,
conv_state,
..
} => vec![qkv_in.raw(), conv_state.raw()],
Op::ComputeDecayBetaNTokens {
alpha_in, beta_in, ..
} => vec![alpha_in.raw(), beta_in.raw()],
Op::GatedDeltaNetStepNTokens {
state,
conv_out,
g_decay,
beta_gate,
..
} => vec![
state.raw(),
conv_out.raw(),
g_decay.raw(),
beta_gate.raw(),
],
Op::GatedDeltaNetChunkwise {
state,
conv_out,
g_decay,
beta_gate,
..
} => vec![
state.raw(),
conv_out.raw(),
g_decay.raw(),
beta_gate.raw(),
],
Op::GatedRmsNormNTokens { values, z, .. } => {
vec![values.raw(), z.raw()]
}
Op::EmbedGatherNTokens { token_ids, .. } => vec![token_ids.raw()],
}
}
pub fn writes_raw(&self) -> Vec<u32> {
match self {
Op::RmsNormBf16NTokens { out, .. } => vec![out.raw()],
Op::RmsNormQkNTokens { x, .. } => vec![x.raw()], Op::RopeNTokens { x, .. } => vec![x.raw()], Op::ResidualAddNTokens { out, .. } => vec![out.raw()],
Op::ZeroBuffer { buf, .. } => vec![buf.raw()],
Op::MatvecNTokens { output, .. } => vec![output.raw()],
Op::SwigluFusedBatched { out, .. } => vec![out.raw()],
Op::SdpaCausalTiled { attn_out, .. } => vec![attn_out.raw()],
Op::SigmoidGateNTokens { x, .. } => vec![x.raw()], Op::SplitQGate {
q_out, gate_out, ..
} => vec![q_out.raw(), gate_out.raw()],
Op::RmsNormPerHeadNTokens { x, .. } => vec![x.raw()], Op::KvCacheAppendNTokens {
k_cache, v_cache, ..
} => vec![k_cache.raw(), v_cache.raw()],
Op::MoeSoftmaxTopK {
indices_out,
weights_out,
..
} => vec![indices_out.raw(), weights_out.raw()],
Op::MoeNormalizeWeights { weights, .. } => {
vec![weights.raw()] }
Op::MoeBatchedPermuteFuse {
bucket_gate,
bucket_up,
bucket_act,
bucket_out,
out_sum,
..
} => vec![
bucket_gate.raw(),
bucket_up.raw(),
bucket_act.raw(),
bucket_out.raw(),
out_sum.raw(),
],
Op::MoeGatherIdFuse {
htpe,
hids,
gate_mid,
up_mid,
down_mid,
out_sum,
..
} => vec![
htpe.raw(),
hids.raw(),
gate_mid.raw(),
up_mid.raw(),
down_mid.raw(),
out_sum.raw(),
],
Op::MoeCombineResidualNTokens { hidden_out, .. } => {
vec![hidden_out.raw()]
}
Op::Conv1dStepNTokens {
conv_state,
conv_out,
..
} => vec![conv_state.raw(), conv_out.raw()], Op::ComputeDecayBetaNTokens {
g_decay_out,
beta_gate_out,
..
} => vec![g_decay_out.raw(), beta_gate_out.raw()],
Op::GatedDeltaNetStepNTokens { state, output, .. } => {
vec![state.raw(), output.raw()] }
Op::GatedDeltaNetChunkwise { state, output, .. } => {
vec![state.raw(), output.raw()] }
Op::GatedRmsNormNTokens { output, .. } => vec![output.raw()],
Op::EmbedGatherNTokens { hidden_out, .. } => vec![hidden_out.raw()],
}
}
}
/// An ordered list of `Op`s.
///
/// `Backend::encode_graph` walks `ops` front to back, so push order is
/// encoding order. The derived `Default` yields an empty graph.
#[derive(Debug, Default)]
pub struct Graph {
/// The operations, in the order they were pushed.
pub ops: Vec<Op>,
}
impl Graph {
pub fn new() -> Self {
Self { ops: Vec::new() }
}
pub fn push(&mut self, op: Op) {
self.ops.push(op);
}
pub fn len(&self) -> usize {
self.ops.len()
}
pub fn is_empty(&self) -> bool {
self.ops.is_empty()
}
pub fn labels(&self) -> impl Iterator<Item = &'static str> + '_ {
self.ops.iter().map(Op::label)
}
pub fn dump(&self) -> String {
use std::fmt::Write as _;
let mut s = String::new();
for (i, op) in self.ops.iter().enumerate() {
let _ = writeln!(
s,
"{i:3} {variant:<28} {label}",
variant = op.variant_name(),
label = op.label(),
);
}
s
}
}
pub mod cpu;
pub mod gpu;
pub mod lifetime;
pub use cpu::{CpuBackend, CpuBufferPool};
pub use gpu::{MetalBackend, MetalBufferPool, MetalConfig, MetalEncodeCtx};
#[cfg(test)]
mod tests {
    //! Unit tests for `Op` and `Graph`.
    //!
    //! The tests only build op values and inspect them; no backend is opened
    //! and nothing is executed.

    use super::*;

    /// Shorthand for building a typed buffer id from a raw number.
    fn buf<B: Buf>(n: u32) -> BufId<B> {
        BufId::from_raw(n)
    }

    /// Builds a graph holding 19 distinct `Op` variants.
    ///
    /// Despite the name, four variants are not included: `RopeNTokens`,
    /// `ZeroBuffer`, `MoeGatherIdFuse` and `GatedDeltaNetChunkwise`.
    /// `ZeroBuffer` in particular has an empty `reads_raw()`, so adding it
    /// would break `reads_and_writes_are_non_empty_for_every_variant`.
    ///
    /// Raw ids are arbitrary and some repeat across ops; no test here depends
    /// on them being unique.
    fn one_of_each() -> Graph {
        let mut g = Graph::new();
        g.push(Op::RmsNormBf16NTokens {
            label: "rms_in",
            x: buf::<EmbedOutBuf>(0).into(),
            weight_off: 0,
            out: buf::<AttnInputBuf>(1).into(),
            dim: 4096,
            n_tokens: 8,
            eps: 1e-6,
        });
        g.push(Op::RmsNormQkNTokens {
            label: "qk_norm",
            x: buf::<ConvOutBuf>(2),
            num_k_heads: 4,
            key_dim: 128,
            key_offset_per_token: 512,
            per_token_total: 1024,
            n_tokens: 8,
        });
        g.push(Op::ResidualAddNTokens {
            label: "resid",
            a: buf::<OProjOutBuf>(3),
            b: buf::<HiddenBuf>(4).into(),
            out: buf::<ResidualBuf>(5),
            n_tokens: 8,
            dim: 4096,
        });
        g.push(Op::MatvecNTokens {
            label: "q_proj",
            weight: WeightRef { w_off: 0, s_off: 0, b_off: 0, bits: 4 },
            input: buf::<AttnInputBuf>(6).into(),
            input_off: 0,
            output: buf::<QProjOutBuf>(7).into(),
            output_off: 0,
            in_dim: 4096,
            out_dim: 4096,
            n_tokens: 8,
        });
        g.push(Op::SwigluFusedBatched {
            label: "ffn_swiglu",
            gate: buf::<SharedFfnGateBuf>(8),
            up: buf::<SharedFfnUpBuf>(9),
            out: buf::<SharedFfnActBuf>(10),
            total: 8 * 1024,
        });
        g.push(Op::SdpaCausalTiled {
            label: "sdpa",
            q: buf::<QBuf>(11),
            k: buf::<KvCacheKBuf>(12),
            v: buf::<KvCacheVBuf>(13),
            attn_out: buf::<AttnOutBuf>(14),
            n_tokens: 8,
            num_heads: 16,
            heads_per_kv: 2,
            head_dim: 128,
            kv_dim: 1024,
            kv_start: 0,
            kv_len_total: 8,
            softmax_scale: 0.088_388_35,
        });
        g.push(Op::SigmoidGateNTokens {
            label: "sigmoid_gate",
            x: buf::<AttnOutBuf>(15),
            gate: buf::<QGateBuf>(16),
            dim: 1024,
            n_tokens: 8,
        });
        g.push(Op::SplitQGate {
            label: "split_q_gate",
            q_proj: buf::<QProjOutBuf>(40),
            q_out: buf::<QBuf>(41),
            gate_out: buf::<QGateBuf>(42),
            num_heads: 16,
            head_dim: 128,
            n_tokens: 8,
        });
        g.push(Op::RmsNormPerHeadNTokens {
            label: "rms_norm_per_head",
            x: buf::<QBuf>(43).into(),
            weight_off: 0,
            num_heads: 16,
            head_dim: 128,
            n_tokens: 8,
            eps: 1e-6,
        });
        g.push(Op::KvCacheAppendNTokens {
            label: "kv_cache_append",
            k_src: buf::<KProjOutBuf>(44),
            v_src: buf::<VProjOutBuf>(45),
            k_cache: buf::<KvCacheKBuf>(46),
            v_cache: buf::<KvCacheVBuf>(47),
            kv_dim: 1024,
            n_tokens: 8,
            kv_start: 0,
        });
        g.push(Op::MoeSoftmaxTopK {
            label: "moe_topk",
            logits: buf::<RouterLogitsBuf>(18),
            indices_out: buf::<RouterIdxBuf>(19),
            weights_out: buf::<RouterWeightsBuf>(20),
            n_tokens: 8,
            n_experts: 128,
            k: 8,
        });
        g.push(Op::MoeNormalizeWeights {
            label: "moe_norm",
            weights: buf::<RouterWeightsBuf>(20),
            n_tokens: 8,
            k: 8,
        });
        g.push(Op::MoeBatchedPermuteFuse {
            label: "moe_pf",
            expert_base: buf::<ExpertBaseBuf>(21),
            expert_stride: 0,
            expert_indices: buf::<ExpertIndicesBuf>(35),
            expert_slots: vec![0],
            bucket_input: buf::<BucketInputBuf>(22),
            bucket_gate: buf::<BucketGateBuf>(23),
            bucket_up: buf::<BucketUpBuf>(24),
            bucket_act: buf::<BucketActBuf>(25),
            bucket_out: buf::<BucketOutBuf>(26),
            bucket_token_idx: buf::<BucketTokenIdxBuf>(27),
            bucket_weights: buf::<BucketWeightsBuf>(28),
            out_sum: buf::<MoeOutSumBuf>(29),
            buckets: ExpertBuckets {
                expert_ids: vec![0],
                offsets: vec![0, 8],
                token_idx: vec![0, 1, 2, 3, 4, 5, 6, 7],
                weights: vec![0.125; 8],
            },
        });
        g.push(Op::MoeCombineResidualNTokens {
            label: "moe_combine",
            h_mid: buf::<ResidualBuf>(30),
            moe_sum: buf::<MoeOutSumBuf>(31),
            shared_out: buf::<SharedFfnDownBuf>(32),
            shared_gate: buf::<SharedGateBuf>(33),
            hidden_out: buf::<HiddenBuf>(34),
            n_tokens: 8,
            dim: 4096,
        });
        g.push(Op::Conv1dStepNTokens {
            label: "conv1d",
            qkv_in: buf::<QkvStackBuf>(35),
            conv_state: buf::<ConvStateBuf>(36),
            weight_off: 0,
            conv_out: buf::<ConvOutBuf>(37),
            conv_dim: 5120,
            n_tokens: 8,
        });
        g.push(Op::ComputeDecayBetaNTokens {
            label: "decay_beta",
            alpha_in: buf::<AlphaStackBuf>(38),
            beta_in: buf::<BetaStackBuf>(39),
            a_log_off: 0,
            dt_bias_off: 0,
            g_decay_out: buf::<GDecayBuf>(40),
            beta_gate_out: buf::<BetaGateBuf>(41),
            num_v_heads: 16,
            n_tokens: 8,
        });
        g.push(Op::GatedDeltaNetStepNTokens {
            label: "delta_net",
            state: buf::<DeltaStateBuf>(42),
            conv_out: buf::<ConvOutBuf>(43),
            g_decay: buf::<GDecayBuf>(44),
            beta_gate: buf::<BetaGateBuf>(45),
            output: buf::<DeltaOutBuf>(46),
            num_v_heads: 16,
            value_dim: 128,
            k_heads_per_v: 2,
            n_tokens: 8,
        });
        g.push(Op::GatedRmsNormNTokens {
            label: "gated_rms",
            values: buf::<DeltaOutBuf>(47),
            z: buf::<ZStackBuf>(48),
            weight_off: 0,
            output: buf::<ValueOutBuf>(49),
            num_v_heads: 16,
            value_dim: 128,
            n_tokens: 8,
            eps: 1e-6,
        });
        g.push(Op::EmbedGatherNTokens {
            label: "embed_gather",
            token_ids: buf::<TokenIdsBuf>(50),
            weight: WeightRef { w_off: 0, s_off: 0, b_off: 0, bits: 4 },
            hidden_out: buf::<EmbedOutBuf>(51),
            hidden_dim: 2048,
            n_tokens: 8,
        });
        g
    }

    /// `push` grows the graph by one op per call.
    #[test]
    fn push_round_trips() {
        let g = one_of_each();
        assert_eq!(g.len(), 19);
        assert!(!g.is_empty());
    }

    /// `labels()` yields labels in the order the ops were pushed.
    #[test]
    fn labels_iter_matches_push_order() {
        let g = one_of_each();
        let labels: Vec<&str> = g.labels().collect();
        assert_eq!(
            labels,
            vec![
                "rms_in",
                "qk_norm",
                "resid",
                "q_proj",
                "ffn_swiglu",
                "sdpa",
                "sigmoid_gate",
                "split_q_gate",
                "rms_norm_per_head",
                "kv_cache_append",
                "moe_topk",
                "moe_norm",
                "moe_pf",
                "moe_combine",
                "conv1d",
                "decay_beta",
                "delta_net",
                "gated_rms",
                "embed_gather",
            ]
        );
    }

    /// Spot-checks that `variant_name()` and `label()` pair up per op.
    #[test]
    fn variant_name_matches_label_for_each_variant() {
        let g = one_of_each();
        let pairs: Vec<(&str, &str)> = g
            .ops
            .iter()
            .map(|op| (op.variant_name(), op.label()))
            .collect();
        assert!(pairs.contains(&("RmsNormBf16NTokens", "rms_in")));
        assert!(pairs.contains(&("MoeBatchedPermuteFuse", "moe_pf")));
        assert!(pairs.contains(&("EmbedGatherNTokens", "embed_gather")));
        assert_eq!(pairs.len(), 19);
    }

    /// Every op built by `one_of_each` reports at least one read and one write.
    #[test]
    fn reads_and_writes_are_non_empty_for_every_variant() {
        let g = one_of_each();
        for op in &g.ops {
            assert!(
                !op.reads_raw().is_empty(),
                "{} produced empty reads_raw()",
                op.variant_name()
            );
            assert!(
                !op.writes_raw().is_empty(),
                "{} produced empty writes_raw()",
                op.variant_name()
            );
        }
    }

    /// Operands that are updated in place must show up on both sides.
    #[test]
    fn in_place_ops_appear_in_both_reads_and_writes() {
        let g = Graph {
            ops: vec![
                Op::RmsNormQkNTokens {
                    label: "qk",
                    x: buf::<ConvOutBuf>(2),
                    num_k_heads: 4,
                    key_dim: 128,
                    key_offset_per_token: 512,
                    per_token_total: 1024,
                    n_tokens: 8,
                },
                Op::MoeNormalizeWeights {
                    label: "moe_norm",
                    weights: buf::<RouterWeightsBuf>(20),
                    n_tokens: 8,
                    k: 8,
                },
                Op::Conv1dStepNTokens {
                    label: "conv1d",
                    qkv_in: buf::<QkvStackBuf>(35),
                    conv_state: buf::<ConvStateBuf>(36),
                    weight_off: 0,
                    conv_out: buf::<ConvOutBuf>(37),
                    conv_dim: 5120,
                    n_tokens: 8,
                },
            ],
        };
        assert!(g.ops[0].reads_raw().contains(&2));
        assert!(g.ops[0].writes_raw().contains(&2));
        assert!(g.ops[1].reads_raw().contains(&20));
        assert!(g.ops[1].writes_raw().contains(&20));
        assert!(g.ops[2].reads_raw().contains(&36));
        assert!(g.ops[2].writes_raw().contains(&36));
    }

    /// `dump()` produces exactly one line per op and mentions names and labels.
    #[test]
    fn dump_emits_one_line_per_op() {
        let g = one_of_each();
        let dump = g.dump();
        let line_count = dump.lines().count();
        assert_eq!(line_count, 19);
        assert!(dump.contains("RmsNormBf16NTokens"));
        assert!(dump.contains("rms_in"));
        assert!(dump.contains("MoeBatchedPermuteFuse"));
        assert!(dump.contains("moe_pf"));
    }

    /// Pins the exact text layout of `dump()` for a two-op graph.
    #[test]
    fn dump_snapshot_tiny_graph() {
        let mut g = Graph::new();
        g.push(Op::RmsNormBf16NTokens {
            label: "rms_in",
            x: buf::<EmbedOutBuf>(0).into(),
            weight_off: 0,
            out: buf::<AttnInputBuf>(1).into(),
            dim: 64,
            n_tokens: 2,
            eps: 1e-6,
        });
        g.push(Op::ResidualAddNTokens {
            label: "resid",
            a: buf::<OProjOutBuf>(1),
            b: buf::<EmbedOutBuf>(0).into(),
            out: buf::<ResidualBuf>(2),
            n_tokens: 2,
            dim: 64,
        });
        // Whitespace is significant. `dump` formats each line as
        // "{i:3} {variant:<28} {label}": the index is right-aligned to width 3
        // ("  0"), and each 18-character variant name is padded with 10
        // spaces to width 28, followed by the single separator space. That
        // puts 11 spaces between the name and the label.
        let expected = concat!(
            "  0 RmsNormBf16NTokens           rms_in\n",
            "  1 ResidualAddNTokens           resid\n",
        );
        assert_eq!(g.dump(), expected);
    }

    /// `BufId`'s `Display` output is the raw id prefixed with `%`.
    #[test]
    fn bufid_display_uses_percent_prefix() {
        let id: BufId<MoeInputBuf> = BufId::from_raw(42);
        assert_eq!(format!("{id}"), "%42");
    }
}