pub trait Backend:
Send
+ Sync
+ Sized
+ 'static {
type Buffer: Send + Sync;
type Context;
type GptqStore: Send + Sync;
Show 35 methods
// Required methods
fn new_context() -> Self::Context;
fn sync(ctx: &mut Self::Context);
fn gemm(
ctx: &mut Self::Context,
a: &Self::Buffer,
b: &Self::Buffer,
out: &mut Self::Buffer,
m: usize,
n: usize,
k: usize,
);
fn rms_norm(
ctx: &mut Self::Context,
x: &Self::Buffer,
w: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
);
fn fused_add_rms_norm(
ctx: &mut Self::Context,
residual: &mut Self::Buffer,
x: &Self::Buffer,
w: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
);
fn flash_attention(
ctx: &mut Self::Context,
q: &Self::Buffer,
k: &Self::Buffer,
v: &Self::Buffer,
out: &mut Self::Buffer,
batch: usize,
q_len: usize,
kv_len: usize,
pos_offset: usize,
cfg: &AttnConfig,
);
fn copy_slice(
ctx: &mut Self::Context,
src: &Self::Buffer,
src_offset: usize,
dst: &mut Self::Buffer,
dst_offset: usize,
len: usize,
);
fn embedding_lookup(
ctx: &mut Self::Context,
table: &Self::Buffer,
ids: &[u32],
out: &mut Self::Buffer,
dim: usize,
);
fn split_qkv(
ctx: &mut Self::Context,
qkv: &Self::Buffer,
q: &mut Self::Buffer,
k: &mut Self::Buffer,
v: &mut Self::Buffer,
tokens: usize,
q_dim: usize,
kv_dim: usize,
);
fn fused_silu_mul_split(
ctx: &mut Self::Context,
gate_up: &Self::Buffer,
out: &mut Self::Buffer,
tokens: usize,
im: usize,
);
fn qk_norm_rope(
ctx: &mut Self::Context,
input: &Self::Buffer,
norm_w: &Self::Buffer,
cos: &Self::Buffer,
sin: &Self::Buffer,
output: &mut Self::Buffer,
tokens: usize,
heads: usize,
head_dim: usize,
pos_offset: usize,
eps: f32,
mode: i32,
);
fn kv_cache_append_head_major(
ctx: &mut Self::Context,
cache_k: &mut Self::Buffer,
cache_v: &mut Self::Buffer,
cache_len: usize,
cache_capacity: usize,
new_k_head_major: &Self::Buffer,
new_v_head_major: &Self::Buffer,
new_tokens: usize,
nkv: usize,
hd: usize,
);
fn transpose_head_to_token(
ctx: &mut Self::Context,
src: &Self::Buffer,
dst: &mut Self::Buffer,
tokens: usize,
heads: usize,
dim: usize,
);
fn add_inplace(
ctx: &mut Self::Context,
residual: &mut Self::Buffer,
x: &Self::Buffer,
len: usize,
);
fn add_bias(
ctx: &mut Self::Context,
data: &mut Self::Buffer,
bias: &Self::Buffer,
rows: usize,
cols: usize,
);
fn layer_norm(
ctx: &mut Self::Context,
x: &Self::Buffer,
gamma: &Self::Buffer,
beta: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
);
fn gelu(
ctx: &mut Self::Context,
x: &Self::Buffer,
out: &mut Self::Buffer,
len: usize,
);
fn alloc(len: usize) -> Self::Buffer;
fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;
fn from_slice(data: &[f32]) -> Self::Buffer;
// Provided methods
fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32) { ... }
fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool) { ... }
fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()> { ... }
fn end_graph_capture(_ctx: &mut Self::Context) -> Result<()> { ... }
fn replay_last_graph(_ctx: &mut Self::Context) -> Result<bool> { ... }
fn reset_graph(_ctx: &mut Self::Context) { ... }
fn load_gptq(
_qweight: &[i32],
_scales: &[f32],
_qzeros: &[i32],
_g_idx: Option<&[i32]>,
_bits: u32,
_group_size: usize,
_k: usize,
_n: usize,
) -> Result<Self::GptqStore> { ... }
fn gemm_gptq(
_ctx: &mut Self::Context,
_a: &Self::Buffer,
_weight: &Self::GptqStore,
_out: &mut Self::Buffer,
_m: usize,
) -> Result<()> { ... }
fn mla_attention(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_kv_compressed: &Self::Buffer,
_kv_rope: &Self::Buffer,
_out: &mut Self::Buffer,
_batch: usize,
_q_len: usize,
_kv_len: usize,
_pos_offset: usize,
_cfg: &AttnConfig,
_kv_lora_rank: usize,
_qk_rope_head_dim: usize,
) -> Result<()> { ... }
fn gemm_quant(
_ctx: &mut Self::Context,
_a: &Self::Buffer,
_weights: &QuantWeights<'_, Self>,
_out: &mut Self::Buffer,
_m: usize,
_n: usize,
_k: usize,
kind: &QuantKind,
) -> Result<()> { ... }
fn world_size(_ctx: &Self::Context) -> usize { ... }
fn rank(_ctx: &Self::Context) -> usize { ... }
fn all_reduce(
_ctx: &mut Self::Context,
_buf: &mut Self::Buffer,
_len: usize,
_op: ReduceOp,
) { ... }
fn all_gather(
_ctx: &mut Self::Context,
_local: &Self::Buffer,
_global: &mut Self::Buffer,
_local_len: usize,
) { ... }
fn broadcast(
_ctx: &mut Self::Context,
_buf: &mut Self::Buffer,
_len: usize,
_src_rank: usize,
) { ... }
}Expand description
The core abstraction over CUDA / Metal / CPU.
Key design: operations take a &mut Self::Context which accumulates work.
- CPU: Context is `()` — ops execute immediately.
- Metal: Context is a `CommandBuffer` — ops encode into it, flushed on `sync()`.
- CUDA: Context is a `CudaStream` — ops launch on the stream, synced on `sync()`.
layer_forward passes the context through all ops in a layer.
ModelRunner calls sync() only when it needs results (e.g., reading logits).
Required Associated Types§
type Buffer: Send + Sync
Sourcetype Context
type Context
Execution context that accumulates GPU work.
- CPU: `()` (no-op, ops execute inline)
- Metal: wraps a `CommandBuffer`
- CUDA: wraps a `CudaStream`
Sourcetype GptqStore: Send + Sync
type GptqStore: Send + Sync
Opaque per-backend GPTQ weight representation.
- CPU: dequantized f32 weights (run as regular GEMM)
- Metal: `()` — unsupported; `gemm_gptq` errors
- CUDA: `MarlinWeight` — pre-repacked tiles + permuted scales
Each backend repacks raw GPTQ tensors (qweight/scales/qzeros, all i32/f16) into its preferred format at model load time, so inference doesn’t pay the repack cost per forward pass.
Required Methods§
Sourcefn new_context() -> Self::Context
fn new_context() -> Self::Context
Create a new execution context (begin accumulating work).
Sourcefn sync(ctx: &mut Self::Context)
fn sync(ctx: &mut Self::Context)
Flush accumulated work and wait for completion. CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.
fn gemm( ctx: &mut Self::Context, a: &Self::Buffer, b: &Self::Buffer, out: &mut Self::Buffer, m: usize, n: usize, k: usize, )
fn rms_norm( ctx: &mut Self::Context, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )
fn fused_add_rms_norm( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )
fn flash_attention( ctx: &mut Self::Context, q: &Self::Buffer, k: &Self::Buffer, v: &Self::Buffer, out: &mut Self::Buffer, batch: usize, q_len: usize, kv_len: usize, pos_offset: usize, cfg: &AttnConfig, )
Sourcefn copy_slice(
ctx: &mut Self::Context,
src: &Self::Buffer,
src_offset: usize,
dst: &mut Self::Buffer,
dst_offset: usize,
len: usize,
)
fn copy_slice( ctx: &mut Self::Context, src: &Self::Buffer, src_offset: usize, dst: &mut Self::Buffer, dst_offset: usize, len: usize, )
Copy len floats from src[src_offset..] to dst[dst_offset..].
Needed for Qwen3Model::prefill to pluck the last token’s hidden state
out of residual[seq_len, h] without round-tripping through host RAM.
This trait declares no separate offset-free copy method; copy_slice covers
that case with src_offset = 0 and dst_offset = 0, and additionally
supports non-zero source and destination offsets.
fn embedding_lookup( ctx: &mut Self::Context, table: &Self::Buffer, ids: &[u32], out: &mut Self::Buffer, dim: usize, )
Sourcefn split_qkv(
ctx: &mut Self::Context,
qkv: &Self::Buffer,
q: &mut Self::Buffer,
k: &mut Self::Buffer,
v: &mut Self::Buffer,
tokens: usize,
q_dim: usize,
kv_dim: usize,
)
fn split_qkv( ctx: &mut Self::Context, qkv: &Self::Buffer, q: &mut Self::Buffer, k: &mut Self::Buffer, v: &mut Self::Buffer, tokens: usize, q_dim: usize, kv_dim: usize, )
Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers. Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]
Sourcefn fused_silu_mul_split(
ctx: &mut Self::Context,
gate_up: &Self::Buffer,
out: &mut Self::Buffer,
tokens: usize,
im: usize,
)
fn fused_silu_mul_split( ctx: &mut Self::Context, gate_up: &Self::Buffer, out: &mut Self::Buffer, tokens: usize, im: usize, )
Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im], then compute SiLU(gate) * up → out [tokens, im].
Sourcefn qk_norm_rope(
ctx: &mut Self::Context,
input: &Self::Buffer,
norm_w: &Self::Buffer,
cos: &Self::Buffer,
sin: &Self::Buffer,
output: &mut Self::Buffer,
tokens: usize,
heads: usize,
head_dim: usize,
pos_offset: usize,
eps: f32,
mode: i32,
)
fn qk_norm_rope( ctx: &mut Self::Context, input: &Self::Buffer, norm_w: &Self::Buffer, cos: &Self::Buffer, sin: &Self::Buffer, output: &mut Self::Buffer, tokens: usize, heads: usize, head_dim: usize, pos_offset: usize, eps: f32, mode: i32, )
Fused QK-norm + RoPE + transpose-to-head-major.
mode selects the operation:
0 = transpose only (typical for V, which needs no norm and no RoPE)
1 = per-head RMS norm + RoPE + transpose (Q/K with QK-norm, Qwen3)
2 = RoPE + transpose (Q/K without QK-norm, Llama/Mistral)
input: [tokens, heads, head_dim] (token-major, output of split_qkv)
output: [heads, tokens, head_dim] (head-major, ready for flash_attention / kv_cache_append_head_major)
pos_offset is the position of token 0 (decode uses current seq len;
prefill uses 0). Within the batch, positions are taken as pos_offset + i.
This is the primary attention-input preparation op. Backends that have a
fused kernel (Metal’s qk_norm_rope_transpose_f32) will be dramatically
faster than composing norm + rope + transpose separately; the CPU
fallback lowers to the individual ops.
Sourcefn kv_cache_append_head_major(
ctx: &mut Self::Context,
cache_k: &mut Self::Buffer,
cache_v: &mut Self::Buffer,
cache_len: usize,
cache_capacity: usize,
new_k_head_major: &Self::Buffer,
new_v_head_major: &Self::Buffer,
new_tokens: usize,
nkv: usize,
hd: usize,
)
fn kv_cache_append_head_major( ctx: &mut Self::Context, cache_k: &mut Self::Buffer, cache_v: &mut Self::Buffer, cache_len: usize, cache_capacity: usize, new_k_head_major: &Self::Buffer, new_v_head_major: &Self::Buffer, new_tokens: usize, nkv: usize, hd: usize, )
Append new K/V into a pre-allocated head-major cache buffer.
cache_k / cache_v: [nkv, capacity, hd] (head-major, pre-allocated)
new_k_head_major / new_v_head_major: [nkv, new_tokens, hd]
— produced directly by qk_norm_rope, no extra transpose needed.
In-place append at slot [nkv, cache_len..cache_len+new_tokens, hd].
Caller owns cache_len bookkeeping.
Sourcefn transpose_head_to_token(
ctx: &mut Self::Context,
src: &Self::Buffer,
dst: &mut Self::Buffer,
tokens: usize,
heads: usize,
dim: usize,
)
fn transpose_head_to_token( ctx: &mut Self::Context, src: &Self::Buffer, dst: &mut Self::Buffer, tokens: usize, heads: usize, dim: usize, )
Transpose [heads, tokens, dim] → [tokens, heads, dim].
Called after flash_attention to restore token-major layout for O-proj.
Sourcefn add_inplace(
ctx: &mut Self::Context,
residual: &mut Self::Buffer,
x: &Self::Buffer,
len: usize,
)
fn add_inplace( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, len: usize, )
residual[i] += x[i] (in-place)
Sourcefn add_bias(
ctx: &mut Self::Context,
data: &mut Self::Buffer,
bias: &Self::Buffer,
rows: usize,
cols: usize,
)
fn add_bias( ctx: &mut Self::Context, data: &mut Self::Buffer, bias: &Self::Buffer, rows: usize, cols: usize, )
Broadcast bias add: data[r, c] += bias[c] for every row.
Required by Bert / Clip / Whisper whose linear projections carry a bias.
Sourcefn layer_norm(
ctx: &mut Self::Context,
x: &Self::Buffer,
gamma: &Self::Buffer,
beta: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
)
fn layer_norm( ctx: &mut Self::Context, x: &Self::Buffer, gamma: &Self::Buffer, beta: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )
Full LayerNorm (mean + variance normalisation + affine), distinct from
the rms_norm used by Llama-family decoders.
out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]
Where mean and var are reduced over the last dim (cols).
Sourcefn gelu(
ctx: &mut Self::Context,
x: &Self::Buffer,
out: &mut Self::Buffer,
len: usize,
)
fn gelu( ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize, )
Element-wise GELU activation (erf-based, matches PyTorch default).
fn alloc(len: usize) -> Self::Buffer
fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>
fn from_slice(data: &[f32]) -> Self::Buffer
Provided Methods§
Sourcefn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32)
fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32)
Update per-step dynamic state (token id, step/pos). Fast (3x memcpy).
Sourcefn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool)
fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool)
Toggle between scalar-arg kernels (normal) and _dyn kernels that
read their dynamic scalar args from device memory (graph-friendly).
Sourcefn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()>
fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()>
Begin stream capture. Subsequent kernel launches are recorded into a pending graph instead of executing eagerly.
Sourcefn end_graph_capture(_ctx: &mut Self::Context) -> Result<()>
fn end_graph_capture(_ctx: &mut Self::Context) -> Result<()>
End stream capture and install the captured graph as this context’s
“last graph” for future replay_last_graph calls.
Sourcefn replay_last_graph(_ctx: &mut Self::Context) -> Result<bool>
fn replay_last_graph(_ctx: &mut Self::Context) -> Result<bool>
Replay the last captured graph. Returns Ok(false) if no graph
is cached; caller should run eager.
Sourcefn reset_graph(_ctx: &mut Self::Context)
fn reset_graph(_ctx: &mut Self::Context)
Drop the cached decode graph — required when the KV cache it was captured against is about to be freed (e.g. request release), since the graph holds raw device pointers into that cache.
Sourcefn load_gptq(
_qweight: &[i32],
_scales: &[f32],
_qzeros: &[i32],
_g_idx: Option<&[i32]>,
_bits: u32,
_group_size: usize,
_k: usize,
_n: usize,
) -> Result<Self::GptqStore>
fn load_gptq( _qweight: &[i32], _scales: &[f32], _qzeros: &[i32], _g_idx: Option<&[i32]>, _bits: u32, _group_size: usize, _k: usize, _n: usize, ) -> Result<Self::GptqStore>
Repack raw GPTQ tensors into the backend’s preferred format. Called once per layer at model load time.
Inputs are host-side slices (CPU memory) — the loader reads from
safetensors and hands them off; each backend uploads + repacks
per its own strategy. bits is typically 4; group_size is
typically 128.
Sourcefn gemm_gptq(
_ctx: &mut Self::Context,
_a: &Self::Buffer,
_weight: &Self::GptqStore,
_out: &mut Self::Buffer,
_m: usize,
) -> Result<()>
fn gemm_gptq( _ctx: &mut Self::Context, _a: &Self::Buffer, _weight: &Self::GptqStore, _out: &mut Self::Buffer, _m: usize, ) -> Result<()>
GEMM with pre-loaded GPTQ weights.
out[m, n] = a[m, k] @ dequant(weight)^T
Sourcefn mla_attention(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_kv_compressed: &Self::Buffer,
_kv_rope: &Self::Buffer,
_out: &mut Self::Buffer,
_batch: usize,
_q_len: usize,
_kv_len: usize,
_pos_offset: usize,
_cfg: &AttnConfig,
_kv_lora_rank: usize,
_qk_rope_head_dim: usize,
) -> Result<()>
fn mla_attention( _ctx: &mut Self::Context, _q: &Self::Buffer, _kv_compressed: &Self::Buffer, _kv_rope: &Self::Buffer, _out: &mut Self::Buffer, _batch: usize, _q_len: usize, _kv_len: usize, _pos_offset: usize, _cfg: &AttnConfig, _kv_lora_rank: usize, _qk_rope_head_dim: usize, ) -> Result<()>
Multi-Head Latent Attention — DeepSeek V2 / V3’s compressed-KV attention variant. Extension point only; no backend implements it yet. DeepSeek V3 landing in Phase D/E will fill this in.
q: full Q [batch, num_heads, q_len, head_dim]
kv_compressed: latent KV [batch, kv_len, kv_lora_rank]
kv_rope: per-position rope-applied key heads [batch, kv_len, qk_rope_head_dim]
out: [batch, num_heads, q_len, head_dim]
Sourcefn gemm_quant(
_ctx: &mut Self::Context,
_a: &Self::Buffer,
_weights: &QuantWeights<'_, Self>,
_out: &mut Self::Buffer,
_m: usize,
_n: usize,
_k: usize,
kind: &QuantKind,
) -> Result<()>
fn gemm_quant( _ctx: &mut Self::Context, _a: &Self::Buffer, _weights: &QuantWeights<'_, Self>, _out: &mut Self::Buffer, _m: usize, _n: usize, _k: usize, kind: &QuantKind, ) -> Result<()>
GEMM with packed-quantized B matrix. m/n/k describe the dense
equivalent ([m,n] = [m,k] @ [n,k]^T).
fn world_size(_ctx: &Self::Context) -> usize
fn rank(_ctx: &Self::Context) -> usize
fn all_reduce( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _op: ReduceOp, )
fn all_gather( _ctx: &mut Self::Context, _local: &Self::Buffer, _global: &mut Self::Buffer, _local_len: usize, )
fn broadcast( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _src_rank: usize, )
Dyn Compatibility§
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.