pub trait Backend:
Send
+ Sync
+ Sized
+ 'static {
type Buffer: Send + Sync;
type Context;
type Timer: BackendTimer<Self>;
Show 38 methods
// Required methods
fn make_timer() -> Self::Timer;
fn new_context() -> Self::Context;
fn sync(ctx: &mut Self::Context);
fn alloc_typed(dtype: Dtype, n: usize) -> Self::Buffer;
fn from_slice_typed<T: HostDtype>(data: &[T]) -> Self::Buffer;
fn write_typed<T: HostDtype>(
ctx: &mut Self::Context,
dst: &mut Self::Buffer,
data: &[T],
);
fn gemm(
ctx: &mut Self::Context,
a: &Self::Buffer,
b: &Self::Buffer,
out: &mut Self::Buffer,
m: usize,
n: usize,
k: usize,
);
fn rms_norm(
ctx: &mut Self::Context,
x: &Self::Buffer,
w: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
);
fn fused_add_rms_norm(
ctx: &mut Self::Context,
residual: &mut Self::Buffer,
x: &Self::Buffer,
w: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
);
fn flash_attention(
ctx: &mut Self::Context,
q: &Self::Buffer,
k: &Self::Buffer,
v: &Self::Buffer,
out: &mut Self::Buffer,
batch: usize,
q_len: usize,
kv_len: usize,
pos_offset: usize,
cfg: &AttnConfig,
);
fn copy_slice(
ctx: &mut Self::Context,
src: &Self::Buffer,
src_offset: usize,
dst: &mut Self::Buffer,
dst_offset: usize,
len: usize,
);
fn embedding_lookup(
ctx: &mut Self::Context,
table: &Self::Buffer,
ids: &[u32],
out: &mut Self::Buffer,
dim: usize,
);
fn split_qkv(
ctx: &mut Self::Context,
qkv: &Self::Buffer,
q: &mut Self::Buffer,
k: &mut Self::Buffer,
v: &mut Self::Buffer,
tokens: usize,
q_dim: usize,
kv_dim: usize,
);
fn fused_silu_mul_split(
ctx: &mut Self::Context,
gate_up: &Self::Buffer,
out: &mut Self::Buffer,
tokens: usize,
im: usize,
);
fn qk_norm_rope(
ctx: &mut Self::Context,
input: &Self::Buffer,
norm_w: &Self::Buffer,
cos: &Self::Buffer,
sin: &Self::Buffer,
output: &mut Self::Buffer,
tokens: usize,
heads: usize,
head_dim: usize,
pos_offset: usize,
eps: f32,
mode: i32,
);
fn kv_cache_append_head_major(
ctx: &mut Self::Context,
cache_k: &mut Self::Buffer,
cache_v: &mut Self::Buffer,
cache_len: usize,
cache_capacity: usize,
new_k_head_major: &Self::Buffer,
new_v_head_major: &Self::Buffer,
new_tokens: usize,
nkv: usize,
hd: usize,
);
fn transpose_head_to_token(
ctx: &mut Self::Context,
src: &Self::Buffer,
dst: &mut Self::Buffer,
tokens: usize,
heads: usize,
dim: usize,
);
fn add_inplace(
ctx: &mut Self::Context,
residual: &mut Self::Buffer,
x: &Self::Buffer,
len: usize,
);
fn add_bias(
ctx: &mut Self::Context,
data: &mut Self::Buffer,
bias: &Self::Buffer,
rows: usize,
cols: usize,
);
fn layer_norm(
ctx: &mut Self::Context,
x: &Self::Buffer,
gamma: &Self::Buffer,
beta: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
);
fn gelu(
ctx: &mut Self::Context,
x: &Self::Buffer,
out: &mut Self::Buffer,
len: usize,
);
fn alloc(len: usize) -> Self::Buffer;
fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;
fn from_slice(data: &[f32]) -> Self::Buffer;
// Provided methods
fn is_metal_backend() -> bool { ... }
fn zero_buffer(
_ctx: &mut Self::Context,
_buf: &mut Self::Buffer,
_len: usize,
) -> Result<()> { ... }
fn mla_attention(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_kv_compressed: &Self::Buffer,
_kv_rope: &Self::Buffer,
_out: &mut Self::Buffer,
_batch: usize,
_q_len: usize,
_kv_len: usize,
_pos_offset: usize,
_cfg: &AttnConfig,
_kv_lora_rank: usize,
_qk_rope_head_dim: usize,
) -> Result<()> { ... }
fn embedding_lookup_dev(
ctx: &mut Self::Context,
table: &Self::Buffer,
ids: &Self::Buffer,
out: &mut Self::Buffer,
batch: usize,
dim: usize,
) { ... }
fn kv_cache_append_batched_per_cache(
_ctx: &mut Self::Context,
_caches: &[&Self::Buffer],
_new_data: &Self::Buffer,
_cache_lens: &Self::Buffer,
_capacity: usize,
_m: usize,
_nkv: usize,
_hd: usize,
_slot: usize,
) -> Result<()> { ... }
fn flash_attention_batched_per_cache(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_caches: &[&Self::Buffer],
_v_caches: &[&Self::Buffer],
_kv_lens: &Self::Buffer,
_out: &mut Self::Buffer,
_nq: usize,
_nkv: usize,
_hd: usize,
_scale: f32,
_max_valid_kv: usize,
_capacity: usize,
_slot: usize,
) -> Result<()> { ... }
fn qk_norm_rope_batched_per_item(
_ctx: &mut Self::Context,
_input: &Self::Buffer,
_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_output: &mut Self::Buffer,
_positions: &Self::Buffer,
_m: usize,
_heads: usize,
_head_dim: usize,
_eps: f32,
_mode: i32,
) -> Result<()> { ... }
fn split_qkv_norm_rope(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_k_out: &mut Self::Buffer,
_v_out: &mut Self::Buffer,
_tokens: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_pos_offset: usize,
_eps: f32,
_qk_mode: i32,
) -> Result<()> { ... }
fn split_qkv_norm_rope_into_cache(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_cache_k: &mut Self::Buffer,
_cache_v: &mut Self::Buffer,
_tokens: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_pos_offset: usize,
_eps: f32,
_qk_mode: i32,
_cache_len: usize,
_cache_capacity: usize,
) -> Result<()> { ... }
fn transpose_token_to_head(
_ctx: &mut Self::Context,
_src: &Self::Buffer,
_dst: &mut Self::Buffer,
_tokens: usize,
_heads: usize,
_dim: usize,
) { ... }
fn scaled_add_inplace(
_ctx: &mut Self::Context,
dst: &mut Self::Buffer,
src: &Self::Buffer,
scale: f32,
len: usize,
) { ... }
fn fused_silu_mul_split_strided(
_ctx: &mut Self::Context,
_gate_up: &Self::Buffer,
_in_row_offset: usize,
_out: &mut Self::Buffer,
_out_row_offset: usize,
_tokens: usize,
_intermediate: usize,
) { ... }
fn argmax_rows_f16(
_ctx: &mut Self::Context,
logits: &Self::Buffer,
m: usize,
n: usize,
) -> Result<Vec<u32>> { ... }
fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer { ... }
}Expand description
The core abstraction over CUDA / Metal / CPU.
Key design: operations take a &mut Self::Context which accumulates work.
- CPU: Context is `()` — ops execute immediately.
- Metal: Context is a `CommandBuffer` — ops encode into it, flushed on `sync()`.
- CUDA: Context is a `CudaStream` — ops launch on the stream, synced on `sync()`.
layer_forward passes the context through all ops in a layer.
ModelRunner calls sync() only when it needs results (e.g., reading logits).
Required Associated Types§
type Buffer: Send + Sync
Sourcetype Context
type Context
Execution context that accumulates GPU work.
- CPU: `()` (no-op, ops execute inline)
- Metal: wraps a CommandBuffer
- CUDA: wraps a CudaStream
Sourcetype Timer: BackendTimer<Self>
type Timer: BackendTimer<Self>
GPU-side timer scoped to this backend. See super::timer —
CPU: Instant; Metal: sync-wrap; CUDA: cuEvent.
PLAYBOOK § 1.1.
Required Methods§
Sourcefn make_timer() -> Self::Timer
fn make_timer() -> Self::Timer
Factory for Self::Timer — exists so call sites that have a
<B: Backend> parameter can spawn a timer without importing the
concrete impl. PLAYBOOK § 1.2.
Sourcefn new_context() -> Self::Context
fn new_context() -> Self::Context
Opaque per-backend GPTQ weight representation.
- CPU: dequantized f32 weights (run as regular GEMM)
- Metal: `()` — unsupported; `gemm_gptq` errors
Create a new execution context (begin accumulating work).
Sourcefn sync(ctx: &mut Self::Context)
fn sync(ctx: &mut Self::Context)
Flush accumulated work and wait for completion. CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.
Sourcefn alloc_typed(dtype: Dtype, n: usize) -> Self::Buffer
fn alloc_typed(dtype: Dtype, n: usize) -> Self::Buffer
Phase D step 2+3: unified typed allocator. Replaces per-dtype
alloc_u32 / alloc_typed_i32 / etc. The buffer is dtype-
tagged at the wrapper level (CudaBuf::U32, MetalBuf with
Dtype::U32, CpuBuf::U32), so reads/writes through .as_<T>()
accessors get the correct byte count automatically.
Sourcefn from_slice_typed<T: HostDtype>(data: &[T]) -> Self::Buffer
fn from_slice_typed<T: HostDtype>(data: &[T]) -> Self::Buffer
Upload typed host data — replaces from_slice_i32 /
from_slice_u32 etc. The host element type T carries its
Dtype via the HostDtype marker so dispatch in the impl
is a one-line match T::DTYPE.
Sourcefn write_typed<T: HostDtype>(
ctx: &mut Self::Context,
dst: &mut Self::Buffer,
data: &[T],
)
fn write_typed<T: HostDtype>( ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[T], )
In-place typed write — replaces write_u32 / write_i32_into
/ write_f32_into. The buffer must already be dtype-tagged
matching T::DTYPE (typically alloc’d via alloc_typed or
from_slice_typed).
fn gemm( ctx: &mut Self::Context, a: &Self::Buffer, b: &Self::Buffer, out: &mut Self::Buffer, m: usize, n: usize, k: usize, )
fn rms_norm( ctx: &mut Self::Context, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )
fn fused_add_rms_norm( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )
fn flash_attention( ctx: &mut Self::Context, q: &Self::Buffer, k: &Self::Buffer, v: &Self::Buffer, out: &mut Self::Buffer, batch: usize, q_len: usize, kv_len: usize, pos_offset: usize, cfg: &AttnConfig, )
Sourcefn copy_slice(
ctx: &mut Self::Context,
src: &Self::Buffer,
src_offset: usize,
dst: &mut Self::Buffer,
dst_offset: usize,
len: usize,
)
fn copy_slice( ctx: &mut Self::Context, src: &Self::Buffer, src_offset: usize, dst: &mut Self::Buffer, dst_offset: usize, len: usize, )
Copy len floats from src[src_offset..] to dst[dst_offset..].
Needed for Qwen3Model::prefill to pluck the last token’s hidden state
out of residual[seq_len, h] without round-tripping through host RAM.
Backend::copy is the offset-free variant; copy_slice additionally
supports non-zero source and destination offsets.
fn embedding_lookup( ctx: &mut Self::Context, table: &Self::Buffer, ids: &[u32], out: &mut Self::Buffer, dim: usize, )
Sourcefn split_qkv(
ctx: &mut Self::Context,
qkv: &Self::Buffer,
q: &mut Self::Buffer,
k: &mut Self::Buffer,
v: &mut Self::Buffer,
tokens: usize,
q_dim: usize,
kv_dim: usize,
)
fn split_qkv( ctx: &mut Self::Context, qkv: &Self::Buffer, q: &mut Self::Buffer, k: &mut Self::Buffer, v: &mut Self::Buffer, tokens: usize, q_dim: usize, kv_dim: usize, )
Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers. Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]
Sourcefn fused_silu_mul_split(
ctx: &mut Self::Context,
gate_up: &Self::Buffer,
out: &mut Self::Buffer,
tokens: usize,
im: usize,
)
fn fused_silu_mul_split( ctx: &mut Self::Context, gate_up: &Self::Buffer, out: &mut Self::Buffer, tokens: usize, im: usize, )
Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im], then compute SiLU(gate) * up → out [tokens, im].
Sourcefn qk_norm_rope(
ctx: &mut Self::Context,
input: &Self::Buffer,
norm_w: &Self::Buffer,
cos: &Self::Buffer,
sin: &Self::Buffer,
output: &mut Self::Buffer,
tokens: usize,
heads: usize,
head_dim: usize,
pos_offset: usize,
eps: f32,
mode: i32,
)
fn qk_norm_rope( ctx: &mut Self::Context, input: &Self::Buffer, norm_w: &Self::Buffer, cos: &Self::Buffer, sin: &Self::Buffer, output: &mut Self::Buffer, tokens: usize, heads: usize, head_dim: usize, pos_offset: usize, eps: f32, mode: i32, )
Fused QK-norm + RoPE + transpose-to-head-major.
mode selects the operation:
0 = transpose only (typical for V, which needs no norm and no RoPE)
1 = per-head RMS norm + RoPE + transpose (Q/K with QK-norm, Qwen3)
2 = RoPE + transpose (Q/K without QK-norm, Llama/Mistral)
input: [tokens, heads, head_dim] (token-major, output of split_qkv)
output: [heads, tokens, head_dim] (head-major, ready for flash_attn / kv_cache_append)
pos_offset is the position of token 0 (decode uses current seq len;
prefill uses 0). Within the batch, positions are taken as pos_offset + i.
This is the primary attention-input preparation op. Backends that have a
fused kernel (Metal’s qk_norm_rope_transpose_f32) will be dramatically
faster than composing norm + rope + transpose separately; the CPU
fallback lowers to the individual ops.
Sourcefn kv_cache_append_head_major(
ctx: &mut Self::Context,
cache_k: &mut Self::Buffer,
cache_v: &mut Self::Buffer,
cache_len: usize,
cache_capacity: usize,
new_k_head_major: &Self::Buffer,
new_v_head_major: &Self::Buffer,
new_tokens: usize,
nkv: usize,
hd: usize,
)
fn kv_cache_append_head_major( ctx: &mut Self::Context, cache_k: &mut Self::Buffer, cache_v: &mut Self::Buffer, cache_len: usize, cache_capacity: usize, new_k_head_major: &Self::Buffer, new_v_head_major: &Self::Buffer, new_tokens: usize, nkv: usize, hd: usize, )
Append new K/V into a pre-allocated head-major cache buffer.
cache_k / cache_v: [nkv, capacity, hd] (head-major, pre-allocated)
new_k_head_major / new_v_head_major: [nkv, new_tokens, hd]
— produced directly by qk_norm_rope, no extra transpose needed.
In-place append at slot [nkv, cache_len..cache_len+new_tokens, hd].
Caller owns cache_len bookkeeping.
Sourcefn transpose_head_to_token(
ctx: &mut Self::Context,
src: &Self::Buffer,
dst: &mut Self::Buffer,
tokens: usize,
heads: usize,
dim: usize,
)
fn transpose_head_to_token( ctx: &mut Self::Context, src: &Self::Buffer, dst: &mut Self::Buffer, tokens: usize, heads: usize, dim: usize, )
Transpose [heads, tokens, dim] → [tokens, heads, dim].
Called after flash_attention to restore token-major layout for O-proj.
Sourcefn add_inplace(
ctx: &mut Self::Context,
residual: &mut Self::Buffer,
x: &Self::Buffer,
len: usize,
)
fn add_inplace( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, len: usize, )
residual[i] += x[i] (in-place)
Sourcefn add_bias(
ctx: &mut Self::Context,
data: &mut Self::Buffer,
bias: &Self::Buffer,
rows: usize,
cols: usize,
)
fn add_bias( ctx: &mut Self::Context, data: &mut Self::Buffer, bias: &Self::Buffer, rows: usize, cols: usize, )
Broadcast bias add: data[r, c] += bias[c] for every row.
Required by Bert / Clip / Whisper whose linear projections carry a bias.
Sourcefn layer_norm(
ctx: &mut Self::Context,
x: &Self::Buffer,
gamma: &Self::Buffer,
beta: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
)
fn layer_norm( ctx: &mut Self::Context, x: &Self::Buffer, gamma: &Self::Buffer, beta: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )
Full LayerNorm (mean + variance normalisation + affine), distinct from
the rms_norm used by Llama-family decoders.
out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]
Where mean and var are reduced over the last dim (cols).
Sourcefn gelu(
ctx: &mut Self::Context,
x: &Self::Buffer,
out: &mut Self::Buffer,
len: usize,
)
fn gelu( ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize, )
Element-wise GELU activation (erf-based, matches PyTorch default).
fn alloc(len: usize) -> Self::Buffer
fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>
fn from_slice(data: &[f32]) -> Self::Buffer
Provided Methods§
Sourcefn is_metal_backend() -> bool
fn is_metal_backend() -> bool
True for the Apple Metal backend.
Keep this as a backend capability instead of matching on type names in model code. It is used only for backend-specific safety fallbacks where a generic optimized path is known to be incorrect on one backend.
Sourcefn zero_buffer(
_ctx: &mut Self::Context,
_buf: &mut Self::Buffer,
_len: usize,
) -> Result<()>
fn zero_buffer( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, ) -> Result<()>
Zero the first len elements of a Self::Buffer. CUDA path uses
cuMemsetD16Async; default returns unsupported.
Sourcefn mla_attention(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_kv_compressed: &Self::Buffer,
_kv_rope: &Self::Buffer,
_out: &mut Self::Buffer,
_batch: usize,
_q_len: usize,
_kv_len: usize,
_pos_offset: usize,
_cfg: &AttnConfig,
_kv_lora_rank: usize,
_qk_rope_head_dim: usize,
) -> Result<()>
fn mla_attention( _ctx: &mut Self::Context, _q: &Self::Buffer, _kv_compressed: &Self::Buffer, _kv_rope: &Self::Buffer, _out: &mut Self::Buffer, _batch: usize, _q_len: usize, _kv_len: usize, _pos_offset: usize, _cfg: &AttnConfig, _kv_lora_rank: usize, _qk_rope_head_dim: usize, ) -> Result<()>
Multi-Head Latent Attention — DeepSeek V2 / V3’s compressed-KV attention variant. Extension point only; no backend implements it yet. DeepSeek V3 landing in Phase D/E will fill this in.
q: full Q [batch, num_heads, q_len, head_dim]
kv_compressed: latent KV [batch, kv_len, kv_lora_rank]
kv_rope: per-position rope-applied key heads [batch, kv_len, qk_rope_head_dim]
out: [batch, num_heads, q_len, head_dim]
Sourcefn embedding_lookup_dev(
ctx: &mut Self::Context,
table: &Self::Buffer,
ids: &Self::Buffer,
out: &mut Self::Buffer,
batch: usize,
dim: usize,
)
fn embedding_lookup_dev( ctx: &mut Self::Context, table: &Self::Buffer, ids: &Self::Buffer, out: &mut Self::Buffer, batch: usize, dim: usize, )
Device-buffer variant of embedding_lookup for graph-capturable
MoE routing — the gather step before phase-1 GEMM in
moe_forward_bucketed. The host-slice embedding_lookup does
clone_htod(ids) internally, which records stale host pointers
under CUDA Graph capture replay.
ids: &Self::Buffer must be a device I32 buffer of batch
elements (e.g. Qwen3MoeScratch::route_packed_idx_dev).
batch is passed explicitly since a typed CudaBuf carries
its element count but the caller often wants a partial gather.
Default impl: round-trip via to_vec + dispatch the host-slice
variant. CUDA overrides.
Sourcefn kv_cache_append_batched_per_cache(
_ctx: &mut Self::Context,
_caches: &[&Self::Buffer],
_new_data: &Self::Buffer,
_cache_lens: &Self::Buffer,
_capacity: usize,
_m: usize,
_nkv: usize,
_hd: usize,
_slot: usize,
) -> Result<()>
fn kv_cache_append_batched_per_cache( _ctx: &mut Self::Context, _caches: &[&Self::Buffer], _new_data: &Self::Buffer, _cache_lens: &Self::Buffer, _capacity: usize, _m: usize, _nkv: usize, _hd: usize, _slot: usize, ) -> Result<()>
Batched kv_cache_append across M caches in one launch. Each item
writes its (head-major) K-or-V row into its own cache at offset
read from cache_lens[i]. Replaces M sequential
kv_cache_append_head_major calls with a single dispatch.
new_data layout: [m, nkv, hd] item-major (each item’s slice
is contiguous, identical to the k/v_normed_batched produced by
qk_norm_rope_batched_per_item).
caches: per-cache [nkv, capacity, hd] head-major.
cache_lens: device buffer (u32 storage, length ≥ m). Caller
fills via B::write_typed (with T = u32; the typed replacement for the
older write_u32-style helpers) BEFORE the call. Required for
CUDA-graph capture: the kernel reads from this stable device
buffer, so a captured graph can be replayed with new lens by
just rewriting the buffer between launches.
Sourcefn flash_attention_batched_per_cache(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_caches: &[&Self::Buffer],
_v_caches: &[&Self::Buffer],
_kv_lens: &Self::Buffer,
_out: &mut Self::Buffer,
_nq: usize,
_nkv: usize,
_hd: usize,
_scale: f32,
_max_valid_kv: usize,
_capacity: usize,
_slot: usize,
) -> Result<()>
fn flash_attention_batched_per_cache( _ctx: &mut Self::Context, _q: &Self::Buffer, _k_caches: &[&Self::Buffer], _v_caches: &[&Self::Buffer], _kv_lens: &Self::Buffer, _out: &mut Self::Buffer, _nq: usize, _nkv: usize, _hd: usize, _scale: f32, _max_valid_kv: usize, _capacity: usize, _slot: usize, ) -> Result<()>
Batched flash_attention across M decode caches in one launch.
Replaces the per-item flash_attention(q_len=1, ...) × M
loop in the non-paged batched-decode path.
API takes a slice of buffer references (&[&Self::Buffer]) for the
per-cache K/V buffers (each [nkv, capacity, hd] head-major) plus a
device-side kv_lens buffer (described below).
Backends that implement it must extract per-cache device
pointers, build the device arrays the kernel needs, and launch
one kernel covering all M items.
q layout: [m, nq, hd] item-major (matches the
qk_norm_rope_batched_per_item output for q_len=1).
out layout: [m, nq, hd] item-major — written directly into
the caller’s batched attn_out buffer, no per-item copy needed.
CUDA-only for now (kernel batched_decode_attention exists in
kernels/batched_decode_attention.cu).
kv_lens: device buffer (u32 storage, length ≥ m) — same
design as kv_cache_append_batched_per_cache::cache_lens.
Sourcefn qk_norm_rope_batched_per_item(
_ctx: &mut Self::Context,
_input: &Self::Buffer,
_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_output: &mut Self::Buffer,
_positions: &Self::Buffer,
_m: usize,
_heads: usize,
_head_dim: usize,
_eps: f32,
_mode: i32,
) -> Result<()>
fn qk_norm_rope_batched_per_item( _ctx: &mut Self::Context, _input: &Self::Buffer, _norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _output: &mut Self::Buffer, _positions: &Self::Buffer, _m: usize, _heads: usize, _head_dim: usize, _eps: f32, _mode: i32, ) -> Result<()>
Batched per-item-position variant of qk_norm_rope for the
non-paged batched-decode path. Each of the m items has its own
absolute RoPE position (read from a device i32 buffer of length
m). Layout is item-major in both input and output:
input: [m, heads, head_dim]
output: [m, heads, head_dim] (no head-major transpose)
Item-major output keeps the per-item flash_attention slice
contiguous (output[i * heads * head_dim ..] is item i’s whole
Q tensor in head-major-equivalent layout for q_len=1).
Replaces the M sequential single-item launches in the existing
forward_layer_batched_decode path with one batched dispatch.
CUDA-only for now; other backends fall through to the default
unsupported and the caller falls back to the per-item loop.
Sourcefn split_qkv_norm_rope(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_k_out: &mut Self::Buffer,
_v_out: &mut Self::Buffer,
_tokens: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_pos_offset: usize,
_eps: f32,
_qk_mode: i32,
) -> Result<()>
fn split_qkv_norm_rope( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _k_out: &mut Self::Buffer, _v_out: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, ) -> Result<()>
Fused split-QKV + QK-norm + RoPE + head-major transpose.
Single-dispatch replacement for the (split_qkv → 3× qk_norm_rope)
chain on the decode-attention prelude. Reads the linear-layer
fused-QKV output once and writes head-major Q/K/V directly into
attention scratch.
qkv layout: [tokens, q_heads*hd + 2*kv_heads*hd].
q_out: [q_heads, tokens, hd]. k_out/v_out: [kv_heads, tokens, hd].
qk_mode: 1 = norm + half-split RoPE for Q/K (Qwen3 with QK-norm),
2 = half-split RoPE only for Q/K,
3 = interleaved RoPE only for Q/K (GGUF LLaMA / llama.cpp layout).
V always falls through to transpose-only.
Default returns Unsupported. Backends that implement it are expected to be dramatically faster than the four-dispatch chain.
Sourcefn split_qkv_norm_rope_into_cache(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_cache_k: &mut Self::Buffer,
_cache_v: &mut Self::Buffer,
_tokens: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_pos_offset: usize,
_eps: f32,
_qk_mode: i32,
_cache_len: usize,
_cache_capacity: usize,
) -> Result<()>
fn split_qkv_norm_rope_into_cache( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _cache_k: &mut Self::Buffer, _cache_v: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, _cache_len: usize, _cache_capacity: usize, ) -> Result<()>
Variant of Backend::split_qkv_norm_rope that writes the new
K and V directly into pre-allocated head-major KV cache buffers
at slot [kv_heads, cache_len .. cache_len + tokens, hd].
Eliminates the trailing kv_cache_append_head_major dispatch on
the decode hot path. Q still lands in per-token head-major
scratch (flash-attention reads it as the query).
Default returns Unsupported. Backends without the fused kernel
can keep using split_qkv_norm_rope + kv_cache_append_head_major.
Sourcefn transpose_token_to_head(
_ctx: &mut Self::Context,
_src: &Self::Buffer,
_dst: &mut Self::Buffer,
_tokens: usize,
_heads: usize,
_dim: usize,
)
fn transpose_token_to_head( _ctx: &mut Self::Context, _src: &Self::Buffer, _dst: &mut Self::Buffer, _tokens: usize, _heads: usize, _dim: usize, )
Inverse of transpose_head_to_token: [tokens, heads, dim] →
[heads, tokens, dim]. Used by the CUDA paged_decode_attention
wrapper to convert paged_varlen_attention’s token-major output
back to the head-major layout that Qwen3MoeModel expects.
Default panics — backends without a paged-KV CUDA path don’t
hit this code.
Sourcefn scaled_add_inplace(
_ctx: &mut Self::Context,
dst: &mut Self::Buffer,
src: &Self::Buffer,
scale: f32,
len: usize,
)
fn scaled_add_inplace( _ctx: &mut Self::Context, dst: &mut Self::Buffer, src: &Self::Buffer, scale: f32, len: usize, )
dst[i] += scale * src[i] — scalar-broadcast scaled add, in place.
MoE per-token combine writes out[b] += weight_k * expert_k(x[b])
for each top-K expert; this primitive is the per-call accumulate.
Backends without a dedicated kernel can fall back to the default
implementation, which round-trips through host memory — correct,
but slow on a hot path. Override on any backend you actually
dispatch MoE on.
Sourcefn fused_silu_mul_split_strided(
_ctx: &mut Self::Context,
_gate_up: &Self::Buffer,
_in_row_offset: usize,
_out: &mut Self::Buffer,
_out_row_offset: usize,
_tokens: usize,
_intermediate: usize,
)
fn fused_silu_mul_split_strided( _ctx: &mut Self::Context, _gate_up: &Self::Buffer, _in_row_offset: usize, _out: &mut Self::Buffer, _out_row_offset: usize, _tokens: usize, _intermediate: usize, )
Strided variant of Backend::fused_silu_mul_split for the
bucketed MoE path: reads gate_up rows starting at
in_row_offset, writes out rows starting at out_row_offset.
Sourcefn argmax_rows_f16(
_ctx: &mut Self::Context,
logits: &Self::Buffer,
m: usize,
n: usize,
) -> Result<Vec<u32>>
fn argmax_rows_f16( _ctx: &mut Self::Context, logits: &Self::Buffer, m: usize, n: usize, ) -> Result<Vec<u32>>
Greedy-decode fast path: GPU argmax over each row of a
[m, n] FP16 logits buffer, returning the m token indices on the
host. Saves m × n × 2 bytes of D2H per call (e.g. ≈ 9.7 MB at
c=32, vocab=152064, i.e. 32 × 152064 × 2 B) and the host-side argmax scan (~150 µs × m).
Default impl falls back to the slow path: full to_vec + host
argmax. CUDA overrides with a native kernel + tiny D2H (m × 4 B).
Backends that don’t override pay the same cost as
to_vec + host argmax, so callers can call this unconditionally.
Sourcefn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer
fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer
Load a weight tensor straight from its on-disk byte representation, letting the backend pick its preferred storage dtype.
Default impl upcasts bf16/f16 to f32 via an intermediate Vec, matching
pre-existing loader behaviour. Backends override this to go straight
from raw bytes into a native half-precision buffer (e.g. Metal with
FERRUM_METAL_DTYPE=f16), avoiding the transient 2× RAM spike.
Dyn Compatibility§
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".