Skip to main content

Backend

Trait Backend 

Source
pub trait Backend:
    Send
    + Sync
    + Sized
    + 'static {
    type Buffer: Send + Sync;
    type Context;
    type Timer: BackendTimer<Self>;

Show 43 methods // Required methods fn make_timer() -> Self::Timer; fn new_context() -> Self::Context; fn sync(ctx: &mut Self::Context); fn alloc_typed(dtype: Dtype, n: usize) -> Self::Buffer; fn from_slice_typed<T: HostDtype>(data: &[T]) -> Self::Buffer; fn write_typed<T: HostDtype>( ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[T], ); fn gemm( ctx: &mut Self::Context, a: &Self::Buffer, b: &Self::Buffer, out: &mut Self::Buffer, m: usize, n: usize, k: usize, ); fn rms_norm( ctx: &mut Self::Context, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, ); fn fused_add_rms_norm( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, ); fn flash_attention( ctx: &mut Self::Context, q: &Self::Buffer, k: &Self::Buffer, v: &Self::Buffer, out: &mut Self::Buffer, batch: usize, q_len: usize, kv_len: usize, pos_offset: usize, cfg: &AttnConfig, ); fn copy_slice( ctx: &mut Self::Context, src: &Self::Buffer, src_offset: usize, dst: &mut Self::Buffer, dst_offset: usize, len: usize, ); fn embedding_lookup( ctx: &mut Self::Context, table: &Self::Buffer, ids: &[u32], out: &mut Self::Buffer, dim: usize, ); fn split_qkv( ctx: &mut Self::Context, qkv: &Self::Buffer, q: &mut Self::Buffer, k: &mut Self::Buffer, v: &mut Self::Buffer, tokens: usize, q_dim: usize, kv_dim: usize, ); fn fused_silu_mul_split( ctx: &mut Self::Context, gate_up: &Self::Buffer, out: &mut Self::Buffer, tokens: usize, im: usize, ); fn qk_norm_rope( ctx: &mut Self::Context, input: &Self::Buffer, norm_w: &Self::Buffer, cos: &Self::Buffer, sin: &Self::Buffer, output: &mut Self::Buffer, tokens: usize, heads: usize, head_dim: usize, pos_offset: usize, eps: f32, mode: i32, ); fn kv_cache_append_head_major( ctx: &mut Self::Context, cache_k: &mut Self::Buffer, cache_v: &mut Self::Buffer, cache_len: usize, cache_capacity: usize, new_k_head_major: &Self::Buffer, 
new_v_head_major: &Self::Buffer, new_tokens: usize, nkv: usize, hd: usize, ); fn transpose_head_to_token( ctx: &mut Self::Context, src: &Self::Buffer, dst: &mut Self::Buffer, tokens: usize, heads: usize, dim: usize, ); fn add_inplace( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, len: usize, ); fn add_bias( ctx: &mut Self::Context, data: &mut Self::Buffer, bias: &Self::Buffer, rows: usize, cols: usize, ); fn layer_norm( ctx: &mut Self::Context, x: &Self::Buffer, gamma: &Self::Buffer, beta: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, ); fn gelu( ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize, ); fn alloc(len: usize) -> Self::Buffer; fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>; fn from_slice(data: &[f32]) -> Self::Buffer; // Provided methods fn with_device_ordinal<R>( _device_ordinal: Option<usize>, body: impl FnOnce() -> R, ) -> R { ... } fn supports_device_ordinal_scope() -> bool { ... } fn sync_before_host_readback(_ctx: &mut Self::Context) { ... } fn activation_elem_size_bytes() -> usize { ... } fn supports_llama_family_batched_decode() -> bool { ... } fn zero_buffer( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, ) -> Result<()> { ... } fn mla_attention( _ctx: &mut Self::Context, _q: &Self::Buffer, _kv_compressed: &Self::Buffer, _kv_rope: &Self::Buffer, _out: &mut Self::Buffer, _batch: usize, _q_len: usize, _kv_len: usize, _pos_offset: usize, _cfg: &AttnConfig, _kv_lora_rank: usize, _qk_rope_head_dim: usize, ) -> Result<()> { ... } fn embedding_lookup_dev( ctx: &mut Self::Context, table: &Self::Buffer, ids: &Self::Buffer, out: &mut Self::Buffer, batch: usize, dim: usize, ) { ... } fn kv_cache_append_batched_per_cache( _ctx: &mut Self::Context, _caches: &[&Self::Buffer], _new_data: &Self::Buffer, _cache_lens: &Self::Buffer, _capacity: usize, _m: usize, _nkv: usize, _hd: usize, _slot: usize, ) -> Result<()> { ... 
} fn flash_attention_batched_per_cache( _ctx: &mut Self::Context, _q: &Self::Buffer, _k_caches: &[&Self::Buffer], _v_caches: &[&Self::Buffer], _kv_lens: &Self::Buffer, _out: &mut Self::Buffer, _nq: usize, _nkv: usize, _hd: usize, _scale: f32, _max_valid_kv: usize, _capacity: usize, _slot: usize, ) -> Result<()> { ... } fn qk_norm_rope_batched_per_item( _ctx: &mut Self::Context, _input: &Self::Buffer, _norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _output: &mut Self::Buffer, _positions: &Self::Buffer, _m: usize, _heads: usize, _head_dim: usize, _eps: f32, _mode: i32, ) -> Result<()> { ... } fn split_qkv_norm_rope( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _k_out: &mut Self::Buffer, _v_out: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, ) -> Result<()> { ... } fn split_qkv_norm_rope_into_cache( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _cache_k: &mut Self::Buffer, _cache_v: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, _cache_len: usize, _cache_capacity: usize, ) -> Result<()> { ... } fn transpose_token_to_head( _ctx: &mut Self::Context, _src: &Self::Buffer, _dst: &mut Self::Buffer, _tokens: usize, _heads: usize, _dim: usize, ) { ... } fn scaled_add_inplace( _ctx: &mut Self::Context, dst: &mut Self::Buffer, src: &Self::Buffer, scale: f32, len: usize, ) { ... } fn fused_silu_mul_split_strided( _ctx: &mut Self::Context, _gate_up: &Self::Buffer, _in_row_offset: usize, _out: &mut Self::Buffer, _out_row_offset: usize, _tokens: usize, _intermediate: usize, ) { ... 
} fn write_f32_to_activation( ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[f32], ) { ... } fn argmax_rows_f16( _ctx: &mut Self::Context, logits: &Self::Buffer, m: usize, n: usize, ) -> Result<Vec<u32>> { ... } fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer { ... }
}
Expand description

The core abstraction over CUDA / Metal / CPU.

Key design: operations take a &mut Self::Context which accumulates work.

  • CPU: Context is () — ops execute immediately.
  • Metal: Context is a CommandBuffer — ops encode into it, flushed on sync().
  • CUDA: Context is a CudaStream — ops launch on the stream, synced on sync().

layer_forward passes the context through all ops in a layer. ModelRunner calls sync() only when it needs results (e.g., reading logits).

Required Associated Types§

Source

type Buffer: Send + Sync

Source

type Context

Execution context that accumulates GPU work.

  • CPU: () (no-op, ops execute inline)
  • Metal: wraps a CommandBuffer
  • CUDA: wraps a CudaStream
Source

type Timer: BackendTimer<Self>

GPU-side timer scoped to this backend. See super::timer — CPU: Instant; Metal: sync-wrap; CUDA: cuEvent. PLAYBOOK § 1.1.

Required Methods§

Source

fn make_timer() -> Self::Timer

Factory for Self::Timer — exists so call sites that have a <B: Backend> parameter can spawn a timer without importing the concrete impl. PLAYBOOK § 1.2.

Source

fn new_context() -> Self::Context

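Opaque per-backend GPTQ weight representation.

  • CPU: dequantized f32 weights (run as regular GEMM)
  • Metal: () — unsupported; gemm_gptq errors

NOTE(review): the GPTQ description above does not describe new_context; it appears to be a leftover doc comment from a per-backend GPTQ weight type — confirm against the source and move or remove it.

Create a new execution context (begin accumulating work).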
Source

fn sync(ctx: &mut Self::Context)

Flush accumulated work and wait for completion. CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.

Source

fn alloc_typed(dtype: Dtype, n: usize) -> Self::Buffer

Phase D step 2+3: unified typed allocator. Replaces per-dtype alloc_u32 / alloc_typed_i32 / etc. The buffer is dtype-tagged at the wrapper level (CudaBuf::U32, MetalBuf with Dtype::U32, CpuBuf::U32), so reads/writes through .as_<T>() accessors get the correct byte count automatically.

Source

fn from_slice_typed<T: HostDtype>(data: &[T]) -> Self::Buffer

Upload typed host data — replaces from_slice_i32 / from_slice_u32 etc. The host element type T carries its Dtype via the HostDtype marker so dispatch in the impl is a one-line match T::DTYPE.

Source

fn write_typed<T: HostDtype>( ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[T], )

In-place typed write — replaces write_u32 / write_i32_into / write_f32_into. The buffer must already be dtype-tagged matching T::DTYPE (typically alloc’d via alloc_typed or from_slice_typed).

Source

fn gemm( ctx: &mut Self::Context, a: &Self::Buffer, b: &Self::Buffer, out: &mut Self::Buffer, m: usize, n: usize, k: usize, )

Source

fn rms_norm( ctx: &mut Self::Context, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )

Source

fn fused_add_rms_norm( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )

Source

fn flash_attention( ctx: &mut Self::Context, q: &Self::Buffer, k: &Self::Buffer, v: &Self::Buffer, out: &mut Self::Buffer, batch: usize, q_len: usize, kv_len: usize, pos_offset: usize, cfg: &AttnConfig, )

Source

fn copy_slice( ctx: &mut Self::Context, src: &Self::Buffer, src_offset: usize, dst: &mut Self::Buffer, dst_offset: usize, len: usize, )

Copy len floats from src[src_offset..] to dst[dst_offset..].

Needed for Qwen3Model::prefill to pluck the last token’s hidden state out of residual[seq_len, h] without round-tripping through host RAM. Backend::copy is the offset-free variant; copy_slice additionally supports non-zero source and destination offsets.

Source

fn embedding_lookup( ctx: &mut Self::Context, table: &Self::Buffer, ids: &[u32], out: &mut Self::Buffer, dim: usize, )

Source

fn split_qkv( ctx: &mut Self::Context, qkv: &Self::Buffer, q: &mut Self::Buffer, k: &mut Self::Buffer, v: &mut Self::Buffer, tokens: usize, q_dim: usize, kv_dim: usize, )

Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers. Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]

Source

fn fused_silu_mul_split( ctx: &mut Self::Context, gate_up: &Self::Buffer, out: &mut Self::Buffer, tokens: usize, im: usize, )

Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im], then compute SiLU(gate) * up → out [tokens, im].

Source

fn qk_norm_rope( ctx: &mut Self::Context, input: &Self::Buffer, norm_w: &Self::Buffer, cos: &Self::Buffer, sin: &Self::Buffer, output: &mut Self::Buffer, tokens: usize, heads: usize, head_dim: usize, pos_offset: usize, eps: f32, mode: i32, )

Fused QK-norm + RoPE + transpose-to-head-major.

mode selects the operation:

  • 0 = transpose only (typical for V, which needs no norm and no RoPE)
  • 1 = per-head RMS norm + RoPE + transpose (Q/K with QK-norm, Qwen3)
  • 2 = RoPE + transpose (Q/K without QK-norm, Llama/Mistral)

  • input: [tokens, heads, head_dim] (token-major, output of split_qkv)
  • output: [heads, tokens, head_dim] (head-major, ready for flash_attention / kv_cache_append_head_major)

pos_offset is the position of token 0 (decode uses current seq len; prefill uses 0). Within the batch, positions are taken as pos_offset + i.

This is the primary attention-input preparation op. Backends that have a fused kernel (Metal’s qk_norm_rope_transpose_f32) will be dramatically faster than composing norm + rope + transpose separately; the CPU fallback lowers to the individual ops.

Source

fn kv_cache_append_head_major( ctx: &mut Self::Context, cache_k: &mut Self::Buffer, cache_v: &mut Self::Buffer, cache_len: usize, cache_capacity: usize, new_k_head_major: &Self::Buffer, new_v_head_major: &Self::Buffer, new_tokens: usize, nkv: usize, hd: usize, )

Append new K/V into a pre-allocated head-major cache buffer.

  • cache_k / cache_v: [nkv, capacity, hd] (head-major, pre-allocated)
  • new_k_head_major / new_v_head_major: [nkv, new_tokens, hd] — produced directly by qk_norm_rope, no extra transpose needed.

In-place append at slot [nkv, cache_len..cache_len+new_tokens, hd]. Caller owns cache_len bookkeeping.

Source

fn transpose_head_to_token( ctx: &mut Self::Context, src: &Self::Buffer, dst: &mut Self::Buffer, tokens: usize, heads: usize, dim: usize, )

Transpose [heads, tokens, dim] → [tokens, heads, dim]. Called after flash_attention to restore token-major layout for O-proj.

Source

fn add_inplace( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, len: usize, )

residual[i] += x[i] (in-place)

Source

fn add_bias( ctx: &mut Self::Context, data: &mut Self::Buffer, bias: &Self::Buffer, rows: usize, cols: usize, )

Broadcast bias add: data[r, c] += bias[c] for every row. Required by Bert / Clip / Whisper whose linear projections carry a bias.

Source

fn layer_norm( ctx: &mut Self::Context, x: &Self::Buffer, gamma: &Self::Buffer, beta: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )

Full LayerNorm (mean + variance normalisation + affine), distinct from the rms_norm used by Llama-family decoders:

out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]

where mean and var are reduced over the last dimension (the dim columns of row r).

Source

fn gelu( ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize, )

Element-wise GELU activation (erf-based, matches PyTorch default).

Source

fn alloc(len: usize) -> Self::Buffer

Source

fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>

Source

fn from_slice(data: &[f32]) -> Self::Buffer

Provided Methods§

Source

fn with_device_ordinal<R>( _device_ordinal: Option<usize>, body: impl FnOnce() -> R, ) -> R

Run body while binding context-free backend operations to an explicit device ordinal when the backend supports multi-device scopes.

Most backends have no per-ordinal concept and use the default no-op implementation. CUDA overrides this once its stream/context caches are device-keyed, allowing layer-split stages to load and execute on their selected GPU instead of relying on process-global defaults.

Source

fn supports_device_ordinal_scope() -> bool

Whether Self::with_device_ordinal actually switches backend execution to the requested ordinal.

Source

fn sync_before_host_readback(_ctx: &mut Self::Context)

Prepare pending GPU work for a following host readback.

Most backends either execute eagerly or synchronize as part of their device-to-host copy. Metal shared-buffer reads use the CPU pointer directly, so Metal must flush its command buffer before to_vec.

Source

fn activation_elem_size_bytes() -> usize

Byte width of each element in buffers returned by Self::alloc.

CUDA activation scratch is fp16, while Metal and CPU scratch are fp32. Generic model code uses this for byte offsets into batched scratch buffers without checking concrete backend types.

Source

fn supports_llama_family_batched_decode() -> bool

Whether LlamaFamilyModel::decode_batch_internal may use its optimized batched decode path on this backend.

Backends that do not yet produce correct follow-up logits under concurrent dense decode should override this to force the per-item fallback until the optimized path is fixed.

Source

fn zero_buffer( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, ) -> Result<()>

Zero the first len elements of a Self::Buffer. CUDA path uses cuMemsetD16Async; default returns unsupported.

Source

fn mla_attention( _ctx: &mut Self::Context, _q: &Self::Buffer, _kv_compressed: &Self::Buffer, _kv_rope: &Self::Buffer, _out: &mut Self::Buffer, _batch: usize, _q_len: usize, _kv_len: usize, _pos_offset: usize, _cfg: &AttnConfig, _kv_lora_rank: usize, _qk_rope_head_dim: usize, ) -> Result<()>

Multi-Head Latent Attention — DeepSeek V2 / V3’s compressed-KV attention variant. Extension point only; no backend implements it yet. DeepSeek V3 landing in Phase D/E will fill this in.

  • q: full Q [batch, num_heads, q_len, head_dim]
  • kv_compressed: latent KV [batch, kv_len, kv_lora_rank]
  • kv_rope: per-position rope-applied key heads [batch, kv_len, qk_rope_head_dim]
  • out: [batch, num_heads, q_len, head_dim]

Source

fn embedding_lookup_dev( ctx: &mut Self::Context, table: &Self::Buffer, ids: &Self::Buffer, out: &mut Self::Buffer, batch: usize, dim: usize, )

Device-buffer variant of embedding_lookup for graph-capturable MoE routing — the gather step before phase-1 GEMM in moe_forward_bucketed. The host-slice embedding_lookup does clone_htod(ids) internally, which records stale host pointers under CUDA Graph capture replay.

ids: &Self::Buffer must be a device I32 buffer of batch elements (e.g. Qwen3MoeScratch::route_packed_idx_dev). batch is passed explicitly since a typed CudaBuf carries its element count but the caller often wants a partial gather.

Default impl: round-trip via to_vec + dispatch the host-slice variant. CUDA overrides.

Source

fn kv_cache_append_batched_per_cache( _ctx: &mut Self::Context, _caches: &[&Self::Buffer], _new_data: &Self::Buffer, _cache_lens: &Self::Buffer, _capacity: usize, _m: usize, _nkv: usize, _hd: usize, _slot: usize, ) -> Result<()>

Batched kv_cache_append across M caches in one launch. Each item writes its (head-major) K-or-V row into its own cache at offset read from cache_lens[i]. Replaces M sequential kv_cache_append_head_major calls with a single dispatch.

  • new_data layout: [m, nkv, hd] item-major (each item’s slice is contiguous, identical to the k/v_normed_batched produced by qk_norm_rope_batched_per_item).
  • caches: per-cache [nkv, capacity, hd] head-major.
  • cache_lens: device buffer (u32 storage, length ≥ m). Caller fills it via B::write_typed (the typed write that replaced the per-dtype write_u32-style helpers) BEFORE the call. Required for CUDA-graph capture: the kernel reads from this stable device buffer, so a captured graph can be replayed with new lens by just rewriting the buffer between launches.

Source

fn flash_attention_batched_per_cache( _ctx: &mut Self::Context, _q: &Self::Buffer, _k_caches: &[&Self::Buffer], _v_caches: &[&Self::Buffer], _kv_lens: &Self::Buffer, _out: &mut Self::Buffer, _nq: usize, _nkv: usize, _hd: usize, _scale: f32, _max_valid_kv: usize, _capacity: usize, _slot: usize, ) -> Result<()>

Batched flash_attention across M decode caches in one launch. Replaces the per-item flash_attention(q_len=1, ...) × M loop in the non-paged batched-decode path.

API takes slices of buffer references (&[&Self::Buffer]) for the per-cache K/V buffers (each [nkv, capacity, hd] head-major) plus a device-side kv_lens buffer. Backends that implement it must extract per-cache device pointers, build the device arrays the kernel needs, and launch one kernel covering all M items.

q layout: [m, nq, hd] item-major (matches the qk_norm_rope_batched_per_item output for q_len=1). out layout: [m, nq, hd] item-major — written directly into the caller’s batched attn_out buffer, no per-item copy needed.

CUDA-only for now (kernel batched_decode_attention exists in kernels/batched_decode_attention.cu).

kv_lens: device buffer (u32 storage, length ≥ m) — same design as the cache_lens parameter of kv_cache_append_batched_per_cache.

Source

fn qk_norm_rope_batched_per_item( _ctx: &mut Self::Context, _input: &Self::Buffer, _norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _output: &mut Self::Buffer, _positions: &Self::Buffer, _m: usize, _heads: usize, _head_dim: usize, _eps: f32, _mode: i32, ) -> Result<()>

Batched per-item-position variant of qk_norm_rope for the non-paged batched-decode path. Each of the m items has its own absolute RoPE position (read from a device i32 buffer of length m). Layout is item-major in both input and output:

  • input: [m, heads, head_dim]
  • output: [m, heads, head_dim] (no head-major transpose)

Item-major output keeps the per-item flash_attention slice contiguous (output[i * heads * head_dim ..] is item i’s whole Q tensor in head-major-equivalent layout for q_len=1).

Replaces the M sequential single-item launches in the existing forward_layer_batched_decode path with one batched dispatch. CUDA-only for now; other backends fall through to the default unsupported and the caller falls back to the per-item loop.

Source

fn split_qkv_norm_rope( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _k_out: &mut Self::Buffer, _v_out: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, ) -> Result<()>

Fused split-QKV + QK-norm + RoPE + head-major transpose.

Single-dispatch replacement for the (split_qkv → 3× qk_norm_rope) chain on the decode-attention prelude. Reads the linear-layer fused-QKV output once and writes head-major Q/K/V directly into attention scratch.

  • qkv layout: [tokens, q_heads*hd + 2*kv_heads*hd].
  • q_out: [q_heads, tokens, hd].
  • k_out / v_out: [kv_heads, tokens, hd].
  • qk_mode:
      – 1 = norm + half-split RoPE for Q/K (Qwen3 with QK-norm)
      – 2 = half-split RoPE only for Q/K
      – 3 = interleaved RoPE only for Q/K (GGUF LLaMA / llama.cpp layout)

V always falls through to transpose-only.

Default returns Unsupported. Backends that implement it are expected to be dramatically faster than the four-dispatch chain.

Source

fn split_qkv_norm_rope_into_cache( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _cache_k: &mut Self::Buffer, _cache_v: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, _cache_len: usize, _cache_capacity: usize, ) -> Result<()>

Variant of Backend::split_qkv_norm_rope that writes the new K and V directly into pre-allocated head-major KV cache buffers at slot [kv_heads, cache_len .. cache_len + tokens, hd]. Eliminates the trailing kv_cache_append_head_major dispatch on the decode hot path. Q still lands in per-token head-major scratch (flash-attention reads it as the query).

Default returns Unsupported. Backends without the fused kernel can keep using split_qkv_norm_rope + kv_cache_append_head_major.

Source

fn transpose_token_to_head( _ctx: &mut Self::Context, _src: &Self::Buffer, _dst: &mut Self::Buffer, _tokens: usize, _heads: usize, _dim: usize, )

Inverse of transpose_head_to_token: [tokens, heads, dim] → [heads, tokens, dim]. Used by the CUDA paged_decode_attention wrapper to convert paged_varlen_attention’s token-major output back to the head-major layout that Qwen3MoeModel expects. Default panics — backends without a paged-KV CUDA path don’t hit this code.

Source

fn scaled_add_inplace( _ctx: &mut Self::Context, dst: &mut Self::Buffer, src: &Self::Buffer, scale: f32, len: usize, )

dst[i] += scale * src[i] — scalar-broadcast scaled add, in place.

MoE per-token combine writes out[b] += weight_k * expert_k(x[b]) for each top-K expert; this primitive is the per-call accumulate. Backends without a dedicated kernel can fall back to the default implementation, which round-trips through host memory — correct, but slow on a hot path. Override on any backend you actually dispatch MoE on.

Source

fn fused_silu_mul_split_strided( _ctx: &mut Self::Context, _gate_up: &Self::Buffer, _in_row_offset: usize, _out: &mut Self::Buffer, _out_row_offset: usize, _tokens: usize, _intermediate: usize, )

Strided variant of Backend::fused_silu_mul_split for the bucketed MoE path: reads gate_up rows starting at in_row_offset, writes out rows starting at out_row_offset.

Source

fn write_f32_to_activation( ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[f32], )

Source

fn argmax_rows_f16( _ctx: &mut Self::Context, logits: &Self::Buffer, m: usize, n: usize, ) -> Result<Vec<u32>>

Greedy-decode fast path: GPU argmax over each row of a [m, n] FP16 logits buffer, returning the m token indices on the host. Saves m × n × 2 bytes of D2H per call (e.g. ≈9.7 MB at m=32, n=152064) and the host-side argmax scan (~150 µs × m).

Default impl falls back to the slow path: full to_vec + host argmax. CUDA overrides with a native kernel + tiny D2H (m × 4 B). Backends that don’t override pay the same cost as to_vec + host argmax, so callers can call this unconditionally.

Source

fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer

Load a weight tensor straight from its on-disk byte representation, letting the backend pick its preferred storage dtype.

Default impl upcasts bf16/f16 to f32 via an intermediate Vec, matching pre-existing loader behaviour. Backends override this to go straight from raw bytes into a native half-precision buffer (e.g. Metal with FERRUM_METAL_DTYPE=f16), avoiding the transient 2× RAM spike.

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety".

Implementors§