pub trait Backend:
Send
+ Sync
+ Sized
+ 'static {
type Buffer: Send + Sync;
type Context;
type Timer: BackendTimer<Self>;
Show 43 methods
// Required methods
fn make_timer() -> Self::Timer;
fn new_context() -> Self::Context;
fn sync(ctx: &mut Self::Context);
fn alloc_typed(dtype: Dtype, n: usize) -> Self::Buffer;
fn from_slice_typed<T: HostDtype>(data: &[T]) -> Self::Buffer;
fn write_typed<T: HostDtype>(
ctx: &mut Self::Context,
dst: &mut Self::Buffer,
data: &[T],
);
fn gemm(
ctx: &mut Self::Context,
a: &Self::Buffer,
b: &Self::Buffer,
out: &mut Self::Buffer,
m: usize,
n: usize,
k: usize,
);
fn rms_norm(
ctx: &mut Self::Context,
x: &Self::Buffer,
w: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
);
fn fused_add_rms_norm(
ctx: &mut Self::Context,
residual: &mut Self::Buffer,
x: &Self::Buffer,
w: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
);
fn flash_attention(
ctx: &mut Self::Context,
q: &Self::Buffer,
k: &Self::Buffer,
v: &Self::Buffer,
out: &mut Self::Buffer,
batch: usize,
q_len: usize,
kv_len: usize,
pos_offset: usize,
cfg: &AttnConfig,
);
fn copy_slice(
ctx: &mut Self::Context,
src: &Self::Buffer,
src_offset: usize,
dst: &mut Self::Buffer,
dst_offset: usize,
len: usize,
);
fn embedding_lookup(
ctx: &mut Self::Context,
table: &Self::Buffer,
ids: &[u32],
out: &mut Self::Buffer,
dim: usize,
);
fn split_qkv(
ctx: &mut Self::Context,
qkv: &Self::Buffer,
q: &mut Self::Buffer,
k: &mut Self::Buffer,
v: &mut Self::Buffer,
tokens: usize,
q_dim: usize,
kv_dim: usize,
);
fn fused_silu_mul_split(
ctx: &mut Self::Context,
gate_up: &Self::Buffer,
out: &mut Self::Buffer,
tokens: usize,
im: usize,
);
fn qk_norm_rope(
ctx: &mut Self::Context,
input: &Self::Buffer,
norm_w: &Self::Buffer,
cos: &Self::Buffer,
sin: &Self::Buffer,
output: &mut Self::Buffer,
tokens: usize,
heads: usize,
head_dim: usize,
pos_offset: usize,
eps: f32,
mode: i32,
);
fn kv_cache_append_head_major(
ctx: &mut Self::Context,
cache_k: &mut Self::Buffer,
cache_v: &mut Self::Buffer,
cache_len: usize,
cache_capacity: usize,
new_k_head_major: &Self::Buffer,
new_v_head_major: &Self::Buffer,
new_tokens: usize,
nkv: usize,
hd: usize,
);
fn transpose_head_to_token(
ctx: &mut Self::Context,
src: &Self::Buffer,
dst: &mut Self::Buffer,
tokens: usize,
heads: usize,
dim: usize,
);
fn add_inplace(
ctx: &mut Self::Context,
residual: &mut Self::Buffer,
x: &Self::Buffer,
len: usize,
);
fn add_bias(
ctx: &mut Self::Context,
data: &mut Self::Buffer,
bias: &Self::Buffer,
rows: usize,
cols: usize,
);
fn layer_norm(
ctx: &mut Self::Context,
x: &Self::Buffer,
gamma: &Self::Buffer,
beta: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
);
fn gelu(
ctx: &mut Self::Context,
x: &Self::Buffer,
out: &mut Self::Buffer,
len: usize,
);
fn alloc(len: usize) -> Self::Buffer;
fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;
fn from_slice(data: &[f32]) -> Self::Buffer;
// Provided methods
fn with_device_ordinal<R>(
_device_ordinal: Option<usize>,
body: impl FnOnce() -> R,
) -> R { ... }
fn supports_device_ordinal_scope() -> bool { ... }
fn sync_before_host_readback(_ctx: &mut Self::Context) { ... }
fn activation_elem_size_bytes() -> usize { ... }
fn supports_llama_family_batched_decode() -> bool { ... }
fn zero_buffer(
_ctx: &mut Self::Context,
_buf: &mut Self::Buffer,
_len: usize,
) -> Result<()> { ... }
fn mla_attention(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_kv_compressed: &Self::Buffer,
_kv_rope: &Self::Buffer,
_out: &mut Self::Buffer,
_batch: usize,
_q_len: usize,
_kv_len: usize,
_pos_offset: usize,
_cfg: &AttnConfig,
_kv_lora_rank: usize,
_qk_rope_head_dim: usize,
) -> Result<()> { ... }
fn embedding_lookup_dev(
ctx: &mut Self::Context,
table: &Self::Buffer,
ids: &Self::Buffer,
out: &mut Self::Buffer,
batch: usize,
dim: usize,
) { ... }
fn kv_cache_append_batched_per_cache(
_ctx: &mut Self::Context,
_caches: &[&Self::Buffer],
_new_data: &Self::Buffer,
_cache_lens: &Self::Buffer,
_capacity: usize,
_m: usize,
_nkv: usize,
_hd: usize,
_slot: usize,
) -> Result<()> { ... }
fn flash_attention_batched_per_cache(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_caches: &[&Self::Buffer],
_v_caches: &[&Self::Buffer],
_kv_lens: &Self::Buffer,
_out: &mut Self::Buffer,
_nq: usize,
_nkv: usize,
_hd: usize,
_scale: f32,
_max_valid_kv: usize,
_capacity: usize,
_slot: usize,
) -> Result<()> { ... }
fn qk_norm_rope_batched_per_item(
_ctx: &mut Self::Context,
_input: &Self::Buffer,
_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_output: &mut Self::Buffer,
_positions: &Self::Buffer,
_m: usize,
_heads: usize,
_head_dim: usize,
_eps: f32,
_mode: i32,
) -> Result<()> { ... }
fn split_qkv_norm_rope(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_k_out: &mut Self::Buffer,
_v_out: &mut Self::Buffer,
_tokens: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_pos_offset: usize,
_eps: f32,
_qk_mode: i32,
) -> Result<()> { ... }
fn split_qkv_norm_rope_into_cache(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_cache_k: &mut Self::Buffer,
_cache_v: &mut Self::Buffer,
_tokens: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_pos_offset: usize,
_eps: f32,
_qk_mode: i32,
_cache_len: usize,
_cache_capacity: usize,
) -> Result<()> { ... }
fn transpose_token_to_head(
_ctx: &mut Self::Context,
_src: &Self::Buffer,
_dst: &mut Self::Buffer,
_tokens: usize,
_heads: usize,
_dim: usize,
) { ... }
fn scaled_add_inplace(
_ctx: &mut Self::Context,
dst: &mut Self::Buffer,
src: &Self::Buffer,
scale: f32,
len: usize,
) { ... }
fn fused_silu_mul_split_strided(
_ctx: &mut Self::Context,
_gate_up: &Self::Buffer,
_in_row_offset: usize,
_out: &mut Self::Buffer,
_out_row_offset: usize,
_tokens: usize,
_intermediate: usize,
) { ... }
fn write_f32_to_activation(
ctx: &mut Self::Context,
dst: &mut Self::Buffer,
data: &[f32],
) { ... }
fn argmax_rows_f16(
_ctx: &mut Self::Context,
logits: &Self::Buffer,
m: usize,
n: usize,
) -> Result<Vec<u32>> { ... }
fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer { ... }
}Expand description
The core abstraction over CUDA / Metal / CPU.
Key design: operations take a &mut Self::Context which accumulates work.
- CPU: Context is `()` — ops execute immediately.
- Metal: Context is a `CommandBuffer` — ops encode into it, flushed on `sync()`.
- CUDA: Context is a `CudaStream` — ops launch on the stream, synced on `sync()`.
layer_forward passes the context through all ops in a layer.
ModelRunner calls sync() only when it needs results (e.g., reading logits).
Required Associated Types§
type Buffer: Send + Sync
Sourcetype Context
type Context
Execution context that accumulates GPU work.
- CPU: `()` (no-op, ops execute inline)
- Metal: wraps a `CommandBuffer`
- CUDA: wraps a `CudaStream`
Sourcetype Timer: BackendTimer<Self>
type Timer: BackendTimer<Self>
GPU-side timer scoped to this backend. See super::timer —
CPU: Instant; Metal: sync-wrap; CUDA: cuEvent.
PLAYBOOK § 1.1.
Required Methods§
Sourcefn make_timer() -> Self::Timer
fn make_timer() -> Self::Timer
Factory for Self::Timer — exists so call sites that have a
<B: Backend> parameter can spawn a timer without importing the
concrete impl. PLAYBOOK § 1.2.
Sourcefn new_context() -> Self::Context
fn new_context() -> Self::Context
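Create a new execution context (begin accumulating work).
NOTE(review): the rendered docs for this method also carried a stray fragment that
appears to belong to a GPTQ weight type rather than to `new_context`: "Opaque
per-backend GPTQ weight representation. CPU: dequantized f32 weights (run as regular
GEMM). Metal: `()` — unsupported; `gemm_gptq` errors." Confirm where it belongs.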
Sourcefn sync(ctx: &mut Self::Context)
fn sync(ctx: &mut Self::Context)
Flush accumulated work and wait for completion. CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.
Sourcefn alloc_typed(dtype: Dtype, n: usize) -> Self::Buffer
fn alloc_typed(dtype: Dtype, n: usize) -> Self::Buffer
Phase D step 2+3: unified typed allocator. Replaces per-dtype
alloc_u32 / alloc_typed_i32 / etc. The buffer is dtype-
tagged at the wrapper level (CudaBuf::U32, MetalBuf with
Dtype::U32, CpuBuf::U32), so reads/writes through .as_<T>()
accessors get the correct byte count automatically.
Sourcefn from_slice_typed<T: HostDtype>(data: &[T]) -> Self::Buffer
fn from_slice_typed<T: HostDtype>(data: &[T]) -> Self::Buffer
Upload typed host data — replaces from_slice_i32 /
from_slice_u32 etc. The host element type T carries its
Dtype via the HostDtype marker so dispatch in the impl
is a one-line match T::DTYPE.
Sourcefn write_typed<T: HostDtype>(
ctx: &mut Self::Context,
dst: &mut Self::Buffer,
data: &[T],
)
fn write_typed<T: HostDtype>( ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[T], )
In-place typed write — replaces write_u32 / write_i32_into
/ write_f32_into. The buffer must already be dtype-tagged
matching T::DTYPE (typically alloc’d via alloc_typed or
from_slice_typed).
fn gemm( ctx: &mut Self::Context, a: &Self::Buffer, b: &Self::Buffer, out: &mut Self::Buffer, m: usize, n: usize, k: usize, )
fn rms_norm( ctx: &mut Self::Context, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )
fn fused_add_rms_norm( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )
fn flash_attention( ctx: &mut Self::Context, q: &Self::Buffer, k: &Self::Buffer, v: &Self::Buffer, out: &mut Self::Buffer, batch: usize, q_len: usize, kv_len: usize, pos_offset: usize, cfg: &AttnConfig, )
Sourcefn copy_slice(
ctx: &mut Self::Context,
src: &Self::Buffer,
src_offset: usize,
dst: &mut Self::Buffer,
dst_offset: usize,
len: usize,
)
fn copy_slice( ctx: &mut Self::Context, src: &Self::Buffer, src_offset: usize, dst: &mut Self::Buffer, dst_offset: usize, len: usize, )
Copy `len` elements from `src[src_offset..]` to `dst[dst_offset..]`.
Needed for `Qwen3Model::prefill` to pluck the last token's hidden state
out of `residual[seq_len, h]` without round-tripping through host RAM.
Supports non-zero source and destination offsets; pass 0 for both to copy from the
start of each buffer. (This trait defines no separate offset-free `Backend::copy`
method, so `copy_slice` also covers that case.)
fn embedding_lookup( ctx: &mut Self::Context, table: &Self::Buffer, ids: &[u32], out: &mut Self::Buffer, dim: usize, )
Sourcefn split_qkv(
ctx: &mut Self::Context,
qkv: &Self::Buffer,
q: &mut Self::Buffer,
k: &mut Self::Buffer,
v: &mut Self::Buffer,
tokens: usize,
q_dim: usize,
kv_dim: usize,
)
fn split_qkv( ctx: &mut Self::Context, qkv: &Self::Buffer, q: &mut Self::Buffer, k: &mut Self::Buffer, v: &mut Self::Buffer, tokens: usize, q_dim: usize, kv_dim: usize, )
Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers. Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]
Sourcefn fused_silu_mul_split(
ctx: &mut Self::Context,
gate_up: &Self::Buffer,
out: &mut Self::Buffer,
tokens: usize,
im: usize,
)
fn fused_silu_mul_split( ctx: &mut Self::Context, gate_up: &Self::Buffer, out: &mut Self::Buffer, tokens: usize, im: usize, )
Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im], then compute SiLU(gate) * up → out [tokens, im].
Sourcefn qk_norm_rope(
ctx: &mut Self::Context,
input: &Self::Buffer,
norm_w: &Self::Buffer,
cos: &Self::Buffer,
sin: &Self::Buffer,
output: &mut Self::Buffer,
tokens: usize,
heads: usize,
head_dim: usize,
pos_offset: usize,
eps: f32,
mode: i32,
)
fn qk_norm_rope( ctx: &mut Self::Context, input: &Self::Buffer, norm_w: &Self::Buffer, cos: &Self::Buffer, sin: &Self::Buffer, output: &mut Self::Buffer, tokens: usize, heads: usize, head_dim: usize, pos_offset: usize, eps: f32, mode: i32, )
Fused QK-norm + RoPE + transpose-to-head-major.
mode selects the operation:
0 = transpose only (typical for V, which needs no norm and no RoPE)
1 = per-head RMS norm + RoPE + transpose (Q/K with QK-norm, Qwen3)
2 = RoPE + transpose (Q/K without QK-norm, Llama/Mistral)
input: [tokens, heads, head_dim] (token-major, output of split_qkv)
output: [heads, tokens, head_dim] (head-major, ready for flash_attn / kv_cache_append)
pos_offset is the position of token 0 (decode uses current seq len;
prefill uses 0). Within the batch, positions are taken as pos_offset + i.
This is the primary attention-input preparation op. Backends that have a
fused kernel (Metal’s qk_norm_rope_transpose_f32) will be dramatically
faster than composing norm + rope + transpose separately; the CPU
fallback lowers to the individual ops.
Sourcefn kv_cache_append_head_major(
ctx: &mut Self::Context,
cache_k: &mut Self::Buffer,
cache_v: &mut Self::Buffer,
cache_len: usize,
cache_capacity: usize,
new_k_head_major: &Self::Buffer,
new_v_head_major: &Self::Buffer,
new_tokens: usize,
nkv: usize,
hd: usize,
)
fn kv_cache_append_head_major( ctx: &mut Self::Context, cache_k: &mut Self::Buffer, cache_v: &mut Self::Buffer, cache_len: usize, cache_capacity: usize, new_k_head_major: &Self::Buffer, new_v_head_major: &Self::Buffer, new_tokens: usize, nkv: usize, hd: usize, )
Append new K/V into a pre-allocated head-major cache buffer.
cache_k / cache_v: [nkv, capacity, hd] (head-major, pre-allocated)
new_k_head_major / new_v_head_major: [nkv, new_tokens, hd]
— produced directly by qk_norm_rope, no extra transpose needed.
In-place append at slot [nkv, cache_len..cache_len+new_tokens, hd].
Caller owns cache_len bookkeeping.
Sourcefn transpose_head_to_token(
ctx: &mut Self::Context,
src: &Self::Buffer,
dst: &mut Self::Buffer,
tokens: usize,
heads: usize,
dim: usize,
)
fn transpose_head_to_token( ctx: &mut Self::Context, src: &Self::Buffer, dst: &mut Self::Buffer, tokens: usize, heads: usize, dim: usize, )
Transpose [heads, tokens, dim] → [tokens, heads, dim].
Called after flash_attention to restore token-major layout for O-proj.
Sourcefn add_inplace(
ctx: &mut Self::Context,
residual: &mut Self::Buffer,
x: &Self::Buffer,
len: usize,
)
fn add_inplace( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, len: usize, )
residual[i] += x[i] (in-place)
Sourcefn add_bias(
ctx: &mut Self::Context,
data: &mut Self::Buffer,
bias: &Self::Buffer,
rows: usize,
cols: usize,
)
fn add_bias( ctx: &mut Self::Context, data: &mut Self::Buffer, bias: &Self::Buffer, rows: usize, cols: usize, )
Broadcast bias add: data[r, c] += bias[c] for every row.
Required by Bert / Clip / Whisper whose linear projections carry a bias.
Sourcefn layer_norm(
ctx: &mut Self::Context,
x: &Self::Buffer,
gamma: &Self::Buffer,
beta: &Self::Buffer,
eps: f32,
out: &mut Self::Buffer,
tokens: usize,
dim: usize,
)
fn layer_norm( ctx: &mut Self::Context, x: &Self::Buffer, gamma: &Self::Buffer, beta: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )
Full LayerNorm (mean + variance normalisation + affine), distinct from
the rms_norm used by Llama-family decoders.
out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]
Where mean and var are reduced over the last dim (cols).
Sourcefn gelu(
ctx: &mut Self::Context,
x: &Self::Buffer,
out: &mut Self::Buffer,
len: usize,
)
fn gelu( ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize, )
Element-wise GELU activation (erf-based, matches PyTorch default).
fn alloc(len: usize) -> Self::Buffer
fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>
fn from_slice(data: &[f32]) -> Self::Buffer
Provided Methods§
Sourcefn with_device_ordinal<R>(
_device_ordinal: Option<usize>,
body: impl FnOnce() -> R,
) -> R
fn with_device_ordinal<R>( _device_ordinal: Option<usize>, body: impl FnOnce() -> R, ) -> R
Run body while binding context-free backend operations to an
explicit device ordinal when the backend supports multi-device scopes.
Most backends have no per-ordinal concept and use the default no-op implementation. CUDA overrides this once its stream/context caches are device-keyed, allowing layer-split stages to load and execute on their selected GPU instead of relying on process-global defaults.
Sourcefn supports_device_ordinal_scope() -> bool
fn supports_device_ordinal_scope() -> bool
Whether Self::with_device_ordinal actually switches backend
execution to the requested ordinal.
Sourcefn sync_before_host_readback(_ctx: &mut Self::Context)
fn sync_before_host_readback(_ctx: &mut Self::Context)
Prepare pending GPU work for a following host readback.
Most backends either execute eagerly or synchronize as part of their
device-to-host copy. Metal shared-buffer reads use the CPU pointer
directly, so Metal must flush its command buffer before to_vec.
Sourcefn activation_elem_size_bytes() -> usize
fn activation_elem_size_bytes() -> usize
Byte width of buffers returned by Self::alloc.
CUDA activation scratch is fp16, while Metal and CPU scratch are fp32. Generic model code uses this for byte offsets into batched scratch buffers without checking concrete backend types.
Sourcefn supports_llama_family_batched_decode() -> bool
fn supports_llama_family_batched_decode() -> bool
Whether LlamaFamilyModel::decode_batch_internal may use its optimized
batched decode path on this backend.
Backends that do not yet produce correct follow-up logits under concurrent dense decode should override this to force the per-item fallback until the optimized path is fixed.
Sourcefn zero_buffer(
_ctx: &mut Self::Context,
_buf: &mut Self::Buffer,
_len: usize,
) -> Result<()>
fn zero_buffer( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, ) -> Result<()>
Zero the first len elements of a Self::Buffer. CUDA path uses
cuMemsetD16Async; default returns unsupported.
Sourcefn mla_attention(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_kv_compressed: &Self::Buffer,
_kv_rope: &Self::Buffer,
_out: &mut Self::Buffer,
_batch: usize,
_q_len: usize,
_kv_len: usize,
_pos_offset: usize,
_cfg: &AttnConfig,
_kv_lora_rank: usize,
_qk_rope_head_dim: usize,
) -> Result<()>
fn mla_attention( _ctx: &mut Self::Context, _q: &Self::Buffer, _kv_compressed: &Self::Buffer, _kv_rope: &Self::Buffer, _out: &mut Self::Buffer, _batch: usize, _q_len: usize, _kv_len: usize, _pos_offset: usize, _cfg: &AttnConfig, _kv_lora_rank: usize, _qk_rope_head_dim: usize, ) -> Result<()>
Multi-Head Latent Attention — DeepSeek V2 / V3’s compressed-KV attention variant. Extension point only; no backend implements it yet. DeepSeek V3 landing in Phase D/E will fill this in.
q: full Q [batch, num_heads, q_len, head_dim]
kv_compressed: latent KV [batch, kv_len, kv_lora_rank]
kv_rope: per-position rope-applied key heads [batch, kv_len, qk_rope_head_dim]
out: [batch, num_heads, q_len, head_dim]
Sourcefn embedding_lookup_dev(
ctx: &mut Self::Context,
table: &Self::Buffer,
ids: &Self::Buffer,
out: &mut Self::Buffer,
batch: usize,
dim: usize,
)
fn embedding_lookup_dev( ctx: &mut Self::Context, table: &Self::Buffer, ids: &Self::Buffer, out: &mut Self::Buffer, batch: usize, dim: usize, )
Device-buffer variant of embedding_lookup for graph-capturable
MoE routing — the gather step before phase-1 GEMM in
moe_forward_bucketed. The host-slice embedding_lookup does
clone_htod(ids) internally, which records stale host pointers
under CUDA Graph capture replay.
ids: &Self::Buffer must be a device I32 buffer of batch
elements (e.g. Qwen3MoeScratch::route_packed_idx_dev).
batch is passed explicitly since a typed CudaBuf carries
its element count but the caller often wants a partial gather.
Default impl: round-trip via to_vec + dispatch the host-slice
variant. CUDA overrides.
Sourcefn kv_cache_append_batched_per_cache(
_ctx: &mut Self::Context,
_caches: &[&Self::Buffer],
_new_data: &Self::Buffer,
_cache_lens: &Self::Buffer,
_capacity: usize,
_m: usize,
_nkv: usize,
_hd: usize,
_slot: usize,
) -> Result<()>
fn kv_cache_append_batched_per_cache( _ctx: &mut Self::Context, _caches: &[&Self::Buffer], _new_data: &Self::Buffer, _cache_lens: &Self::Buffer, _capacity: usize, _m: usize, _nkv: usize, _hd: usize, _slot: usize, ) -> Result<()>
Batched kv_cache_append across M caches in one launch. Each item
writes its (head-major) K-or-V row into its own cache at offset
read from cache_lens[i]. Replaces M sequential
kv_cache_append_head_major calls with a single dispatch.
new_data layout: [m, nkv, hd] item-major (each item’s slice
is contiguous, identical to the k/v_normed_batched produced by
qk_norm_rope_batched_per_item).
caches: per-cache [nkv, capacity, hd] head-major.
cache_lens: device buffer (u32 storage, length ≥ m). The caller fills it via
`B::write_typed::<u32>` (the typed replacement for the old `write_u32*` helpers)
BEFORE the call. Required for CUDA-graph capture: the kernel reads from this stable
device buffer, so a captured graph can be replayed with new lens by just rewriting
the buffer between launches.
Sourcefn flash_attention_batched_per_cache(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_caches: &[&Self::Buffer],
_v_caches: &[&Self::Buffer],
_kv_lens: &Self::Buffer,
_out: &mut Self::Buffer,
_nq: usize,
_nkv: usize,
_hd: usize,
_scale: f32,
_max_valid_kv: usize,
_capacity: usize,
_slot: usize,
) -> Result<()>
fn flash_attention_batched_per_cache( _ctx: &mut Self::Context, _q: &Self::Buffer, _k_caches: &[&Self::Buffer], _v_caches: &[&Self::Buffer], _kv_lens: &Self::Buffer, _out: &mut Self::Buffer, _nq: usize, _nkv: usize, _hd: usize, _scale: f32, _max_valid_kv: usize, _capacity: usize, _slot: usize, ) -> Result<()>
Batched flash_attention across M decode caches in one launch.
Replaces the per-item flash_attention(q_len=1, ...) × M
loop in the non-paged batched-decode path.
The API takes slices of buffer references (`&[&Self::Buffer]`) for the per-cache
K/V buffers (each [nkv, capacity, hd] head-major) plus a device-side `kv_lens` buffer.
Backends that implement it must extract per-cache device
pointers, build the device arrays the kernel needs, and launch
one kernel covering all M items.
q layout: [m, nq, hd] item-major (matches the
qk_norm_rope_batched_per_item output for q_len=1).
out layout: [m, nq, hd] item-major — written directly into
the caller’s batched attn_out buffer, no per-item copy needed.
CUDA-only for now (kernel batched_decode_attention exists in
kernels/batched_decode_attention.cu).
kv_lens: device buffer (u32 storage, length ≥ m) — same
design as kv_cache_append_batched_per_cache::cache_lens.
Sourcefn qk_norm_rope_batched_per_item(
_ctx: &mut Self::Context,
_input: &Self::Buffer,
_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_output: &mut Self::Buffer,
_positions: &Self::Buffer,
_m: usize,
_heads: usize,
_head_dim: usize,
_eps: f32,
_mode: i32,
) -> Result<()>
fn qk_norm_rope_batched_per_item( _ctx: &mut Self::Context, _input: &Self::Buffer, _norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _output: &mut Self::Buffer, _positions: &Self::Buffer, _m: usize, _heads: usize, _head_dim: usize, _eps: f32, _mode: i32, ) -> Result<()>
Batched per-item-position variant of qk_norm_rope for the
non-paged batched-decode path. Each of the m items has its own
absolute RoPE position (read from a device i32 buffer of length
m). Layout is item-major in both input and output:
input: [m, heads, head_dim]
output: [m, heads, head_dim] (no head-major transpose)
Item-major output keeps the per-item flash_attention slice
contiguous (output[i * heads * head_dim ..] is item i’s whole
Q tensor in head-major-equivalent layout for q_len=1).
Replaces the M sequential single-item launches in the existing
forward_layer_batched_decode path with one batched dispatch.
CUDA-only for now; on other backends the default implementation returns
unsupported, and the caller falls back to the per-item loop.
Sourcefn split_qkv_norm_rope(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_k_out: &mut Self::Buffer,
_v_out: &mut Self::Buffer,
_tokens: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_pos_offset: usize,
_eps: f32,
_qk_mode: i32,
) -> Result<()>
fn split_qkv_norm_rope( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _k_out: &mut Self::Buffer, _v_out: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, ) -> Result<()>
Fused split-QKV + QK-norm + RoPE + head-major transpose.
Single-dispatch replacement for the (split_qkv → 3× qk_norm_rope)
chain on the decode-attention prelude. Reads the linear-layer
fused-QKV output once and writes head-major Q/K/V directly into
attention scratch.
qkv layout: [tokens, q_heads*hd + 2*kv_heads*hd].
q_out: [q_heads, tokens, hd]. k_out/v_out: [kv_heads, tokens, hd].
qk_mode: 1 = norm + half-split RoPE for Q/K (Qwen3 with QK-norm),
2 = half-split RoPE only for Q/K,
3 = interleaved RoPE only for Q/K (GGUF LLaMA / llama.cpp layout).
V always falls through to transpose-only.
Default returns Unsupported. Backends that implement it are expected to be dramatically faster than the four-dispatch chain.
Sourcefn split_qkv_norm_rope_into_cache(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_cache_k: &mut Self::Buffer,
_cache_v: &mut Self::Buffer,
_tokens: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_pos_offset: usize,
_eps: f32,
_qk_mode: i32,
_cache_len: usize,
_cache_capacity: usize,
) -> Result<()>
fn split_qkv_norm_rope_into_cache( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _cache_k: &mut Self::Buffer, _cache_v: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, _cache_len: usize, _cache_capacity: usize, ) -> Result<()>
Variant of Backend::split_qkv_norm_rope that writes the new
K and V directly into pre-allocated head-major KV cache buffers
at slot [kv_heads, cache_len .. cache_len + tokens, hd].
Eliminates the trailing kv_cache_append_head_major dispatch on
the decode hot path. Q still lands in the head-major `q_out` scratch
buffer, as in `split_qkv_norm_rope` (flash attention reads it as the query).
Default returns Unsupported. Backends without the fused kernel
can keep using split_qkv_norm_rope + kv_cache_append_head_major.
Sourcefn transpose_token_to_head(
_ctx: &mut Self::Context,
_src: &Self::Buffer,
_dst: &mut Self::Buffer,
_tokens: usize,
_heads: usize,
_dim: usize,
)
fn transpose_token_to_head( _ctx: &mut Self::Context, _src: &Self::Buffer, _dst: &mut Self::Buffer, _tokens: usize, _heads: usize, _dim: usize, )
Inverse of transpose_head_to_token: [tokens, heads, dim] →
[heads, tokens, dim]. Used by the CUDA paged_decode_attention
wrapper to convert paged_varlen_attention’s token-major output
back to the head-major layout that Qwen3MoeModel expects.
Default panics — backends without a paged-KV CUDA path don’t
hit this code.
Sourcefn scaled_add_inplace(
_ctx: &mut Self::Context,
dst: &mut Self::Buffer,
src: &Self::Buffer,
scale: f32,
len: usize,
)
fn scaled_add_inplace( _ctx: &mut Self::Context, dst: &mut Self::Buffer, src: &Self::Buffer, scale: f32, len: usize, )
dst[i] += scale * src[i] — scalar-broadcast scaled add, in place.
MoE per-token combine writes out[b] += weight_k * expert_k(x[b])
for each top-K expert; this primitive is the per-call accumulate.
Backends without a dedicated kernel can fall back to the default
implementation, which round-trips through host memory — correct,
but slow on a hot path. Override on any backend you actually
dispatch MoE on.
Sourcefn fused_silu_mul_split_strided(
_ctx: &mut Self::Context,
_gate_up: &Self::Buffer,
_in_row_offset: usize,
_out: &mut Self::Buffer,
_out_row_offset: usize,
_tokens: usize,
_intermediate: usize,
)
fn fused_silu_mul_split_strided( _ctx: &mut Self::Context, _gate_up: &Self::Buffer, _in_row_offset: usize, _out: &mut Self::Buffer, _out_row_offset: usize, _tokens: usize, _intermediate: usize, )
Strided variant of Backend::fused_silu_mul_split for the
bucketed MoE path: reads gate_up rows starting at
in_row_offset, writes out rows starting at out_row_offset.
fn write_f32_to_activation( ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[f32], )
Sourcefn argmax_rows_f16(
_ctx: &mut Self::Context,
logits: &Self::Buffer,
m: usize,
n: usize,
) -> Result<Vec<u32>>
fn argmax_rows_f16( _ctx: &mut Self::Context, logits: &Self::Buffer, m: usize, n: usize, ) -> Result<Vec<u32>>
Greedy-decode fast path: GPU argmax over each row of a
[m, n] FP16 logits buffer, returning the m token indices on the
host. Saves m × n × 2 bytes of D2H per call (e.g. ≈9.7 MB at
m=32, vocab=152064) and the host-side argmax scan (~150 µs × m).
Default impl falls back to the slow path: full to_vec + host
argmax. CUDA overrides with a native kernel + tiny D2H (m × 4 B).
Backends that don’t override pay the same cost as
to_vec + host argmax, so callers can call this unconditionally.
Sourcefn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer
fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer
Load a weight tensor straight from its on-disk byte representation, letting the backend pick its preferred storage dtype.
Default impl upcasts bf16/f16 to f32 via an intermediate Vec, matching
pre-existing loader behaviour. Backends override this to go straight
from raw bytes into a native half-precision buffer (e.g. Metal with
FERRUM_METAL_DTYPE=f16), avoiding the transient 2× RAM spike.
Dyn Compatibility§
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".