Skip to main content

Backend

Trait Backend 

Source
pub trait Backend:
    Send
    + Sync
    + Sized
    + 'static {
    type Buffer: Send + Sync;
    type Context;
    type GptqStore: Send + Sync;

Show 35 methods // Required methods fn new_context() -> Self::Context; fn sync(ctx: &mut Self::Context); fn gemm( ctx: &mut Self::Context, a: &Self::Buffer, b: &Self::Buffer, out: &mut Self::Buffer, m: usize, n: usize, k: usize, ); fn rms_norm( ctx: &mut Self::Context, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, ); fn fused_add_rms_norm( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, ); fn flash_attention( ctx: &mut Self::Context, q: &Self::Buffer, k: &Self::Buffer, v: &Self::Buffer, out: &mut Self::Buffer, batch: usize, q_len: usize, kv_len: usize, pos_offset: usize, cfg: &AttnConfig, ); fn copy_slice( ctx: &mut Self::Context, src: &Self::Buffer, src_offset: usize, dst: &mut Self::Buffer, dst_offset: usize, len: usize, ); fn embedding_lookup( ctx: &mut Self::Context, table: &Self::Buffer, ids: &[u32], out: &mut Self::Buffer, dim: usize, ); fn split_qkv( ctx: &mut Self::Context, qkv: &Self::Buffer, q: &mut Self::Buffer, k: &mut Self::Buffer, v: &mut Self::Buffer, tokens: usize, q_dim: usize, kv_dim: usize, ); fn fused_silu_mul_split( ctx: &mut Self::Context, gate_up: &Self::Buffer, out: &mut Self::Buffer, tokens: usize, im: usize, ); fn qk_norm_rope( ctx: &mut Self::Context, input: &Self::Buffer, norm_w: &Self::Buffer, cos: &Self::Buffer, sin: &Self::Buffer, output: &mut Self::Buffer, tokens: usize, heads: usize, head_dim: usize, pos_offset: usize, eps: f32, mode: i32, ); fn kv_cache_append_head_major( ctx: &mut Self::Context, cache_k: &mut Self::Buffer, cache_v: &mut Self::Buffer, cache_len: usize, cache_capacity: usize, new_k_head_major: &Self::Buffer, new_v_head_major: &Self::Buffer, new_tokens: usize, nkv: usize, hd: usize, ); fn transpose_head_to_token( ctx: &mut Self::Context, src: &Self::Buffer, dst: &mut Self::Buffer, tokens: usize, heads: usize, dim: usize, ); fn add_inplace( ctx: &mut Self::Context, 
residual: &mut Self::Buffer, x: &Self::Buffer, len: usize, ); fn add_bias( ctx: &mut Self::Context, data: &mut Self::Buffer, bias: &Self::Buffer, rows: usize, cols: usize, ); fn layer_norm( ctx: &mut Self::Context, x: &Self::Buffer, gamma: &Self::Buffer, beta: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, ); fn gelu( ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize, ); fn alloc(len: usize) -> Self::Buffer; fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>; fn from_slice(data: &[f32]) -> Self::Buffer; // Provided methods fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32) { ... } fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool) { ... } fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()> { ... } fn end_graph_capture(_ctx: &mut Self::Context) -> Result<()> { ... } fn replay_last_graph(_ctx: &mut Self::Context) -> Result<bool> { ... } fn reset_graph(_ctx: &mut Self::Context) { ... } fn load_gptq( _qweight: &[i32], _scales: &[f32], _qzeros: &[i32], _g_idx: Option<&[i32]>, _bits: u32, _group_size: usize, _k: usize, _n: usize, ) -> Result<Self::GptqStore> { ... } fn gemm_gptq( _ctx: &mut Self::Context, _a: &Self::Buffer, _weight: &Self::GptqStore, _out: &mut Self::Buffer, _m: usize, ) -> Result<()> { ... } fn mla_attention( _ctx: &mut Self::Context, _q: &Self::Buffer, _kv_compressed: &Self::Buffer, _kv_rope: &Self::Buffer, _out: &mut Self::Buffer, _batch: usize, _q_len: usize, _kv_len: usize, _pos_offset: usize, _cfg: &AttnConfig, _kv_lora_rank: usize, _qk_rope_head_dim: usize, ) -> Result<()> { ... } fn gemm_quant( _ctx: &mut Self::Context, _a: &Self::Buffer, _weights: &QuantWeights<'_, Self>, _out: &mut Self::Buffer, _m: usize, _n: usize, _k: usize, kind: &QuantKind, ) -> Result<()> { ... } fn world_size(_ctx: &Self::Context) -> usize { ... } fn rank(_ctx: &Self::Context) -> usize { ... 
} fn all_reduce( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _op: ReduceOp, ) { ... } fn all_gather( _ctx: &mut Self::Context, _local: &Self::Buffer, _global: &mut Self::Buffer, _local_len: usize, ) { ... } fn broadcast( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _src_rank: usize, ) { ... }
}
Expand description

The core abstraction over CUDA / Metal / CPU.

Key design: operations take a &mut Self::Context which accumulates work.

  • CPU: Context is () — ops execute immediately.
  • Metal: Context is a CommandBuffer — ops encode into it, flushed on sync().
  • CUDA: Context is a CudaStream — ops launch on the stream, synced on sync().

layer_forward passes the context through all ops in a layer. ModelRunner calls sync() only when it needs results (e.g., reading logits).

Required Associated Types§

Source

type Buffer: Send + Sync

Source

type Context

Execution context that accumulates GPU work.

  • CPU: () (no-op, ops execute inline)
  • Metal: wraps a CommandBuffer
  • CUDA: wraps a CudaStream
Source

type GptqStore: Send + Sync

Opaque per-backend GPTQ weight representation.

  • CPU: dequantized f32 weights (run as regular GEMM)
  • Metal: () — unsupported; gemm_gptq errors
  • CUDA: MarlinWeight — pre-repacked tiles + permuted scales

Each backend repacks the raw GPTQ tensors (qweight and qzeros as i32, scales as f32 — the host-side slice types accepted by load_gptq) into its preferred format at model load time, so inference doesn’t pay the repack cost on every forward pass.

Required Methods§

Source

fn new_context() -> Self::Context

Create a new execution context (begin accumulating work).

Source

fn sync(ctx: &mut Self::Context)

Flush accumulated work and wait for completion. CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.

Source

fn gemm( ctx: &mut Self::Context, a: &Self::Buffer, b: &Self::Buffer, out: &mut Self::Buffer, m: usize, n: usize, k: usize, )

Source

fn rms_norm( ctx: &mut Self::Context, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )

Source

fn fused_add_rms_norm( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )

Source

fn flash_attention( ctx: &mut Self::Context, q: &Self::Buffer, k: &Self::Buffer, v: &Self::Buffer, out: &mut Self::Buffer, batch: usize, q_len: usize, kv_len: usize, pos_offset: usize, cfg: &AttnConfig, )

Source

fn copy_slice( ctx: &mut Self::Context, src: &Self::Buffer, src_offset: usize, dst: &mut Self::Buffer, dst_offset: usize, len: usize, )

Copy len floats from src[src_offset..] to dst[dst_offset..].

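Needed for Qwen3Model::prefill to pluck the last token’s hidden state out of residual[seq_len, h] without round-tripping through host RAM. Unlike an offset-free whole-buffer copy, copy_slice supports non-zero source and destination offsets. (The trait as documented on this page exposes no separate Backend::copy method; pass zero for both offsets to copy from the start of each buffer.)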

Source

fn embedding_lookup( ctx: &mut Self::Context, table: &Self::Buffer, ids: &[u32], out: &mut Self::Buffer, dim: usize, )

Source

fn split_qkv( ctx: &mut Self::Context, qkv: &Self::Buffer, q: &mut Self::Buffer, k: &mut Self::Buffer, v: &mut Self::Buffer, tokens: usize, q_dim: usize, kv_dim: usize, )

Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers. Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]

Source

fn fused_silu_mul_split( ctx: &mut Self::Context, gate_up: &Self::Buffer, out: &mut Self::Buffer, tokens: usize, im: usize, )

Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im], then compute SiLU(gate) * up → out [tokens, im].

Source

fn qk_norm_rope( ctx: &mut Self::Context, input: &Self::Buffer, norm_w: &Self::Buffer, cos: &Self::Buffer, sin: &Self::Buffer, output: &mut Self::Buffer, tokens: usize, heads: usize, head_dim: usize, pos_offset: usize, eps: f32, mode: i32, )

Fused QK-norm + RoPE + transpose-to-head-major.

mode selects the operation:

  • 0 = transpose only (typical for V, which needs no norm and no RoPE)
  • 1 = per-head RMS norm + RoPE + transpose (Q/K with QK-norm, Qwen3)
  • 2 = RoPE + transpose (Q/K without QK-norm, Llama/Mistral)

  • input: [tokens, heads, head_dim] (token-major, output of split_qkv)
  • output: [heads, tokens, head_dim] (head-major, ready for flash_attention / kv_cache_append_head_major)

pos_offset is the position of token 0 (decode uses current seq len; prefill uses 0). Within the batch, positions are taken as pos_offset + i.

This is the primary attention-input preparation op. Backends that have a fused kernel (Metal’s qk_norm_rope_transpose_f32) will be dramatically faster than composing norm + rope + transpose separately; the CPU fallback lowers to the individual ops.

Source

fn kv_cache_append_head_major( ctx: &mut Self::Context, cache_k: &mut Self::Buffer, cache_v: &mut Self::Buffer, cache_len: usize, cache_capacity: usize, new_k_head_major: &Self::Buffer, new_v_head_major: &Self::Buffer, new_tokens: usize, nkv: usize, hd: usize, )

Append new K/V into a pre-allocated head-major cache buffer.

  • cache_k / cache_v: [nkv, capacity, hd] (head-major, pre-allocated)
  • new_k_head_major / new_v_head_major: [nkv, new_tokens, hd] — produced directly by qk_norm_rope, no extra transpose needed.

In-place append at slot [nkv, cache_len..cache_len+new_tokens, hd]. Caller owns cache_len bookkeeping.

Source

fn transpose_head_to_token( ctx: &mut Self::Context, src: &Self::Buffer, dst: &mut Self::Buffer, tokens: usize, heads: usize, dim: usize, )

Transpose [heads, tokens, dim] → [tokens, heads, dim]. Called after flash_attention to restore token-major layout for O-proj.

Source

fn add_inplace( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, len: usize, )

residual[i] += x[i] (in-place)

Source

fn add_bias( ctx: &mut Self::Context, data: &mut Self::Buffer, bias: &Self::Buffer, rows: usize, cols: usize, )

Broadcast bias add: data[r, c] += bias[c] for every row. Required by Bert / Clip / Whisper whose linear projections carry a bias.

Source

fn layer_norm( ctx: &mut Self::Context, x: &Self::Buffer, gamma: &Self::Buffer, beta: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )

Full LayerNorm (mean + variance normalisation + affine), distinct from the rms_norm used by Llama-family decoders:

out[r, c] = ((x[r, c] - mean[r]) / sqrt(var[r] + eps)) * gamma[c] + beta[c]

where r indexes the tokens rows, c indexes the dim columns, and mean and var are reduced over the last dimension (the dim columns) separately for each row.

Source

fn gelu( ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize, )

Element-wise GELU activation (erf-based, matches PyTorch default).

Source

fn alloc(len: usize) -> Self::Buffer

Source

fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>

Source

fn from_slice(data: &[f32]) -> Self::Buffer

Provided Methods§

Source

fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32)

Update per-step dynamic state (token id, step/pos). Fast (3x memcpy).

Source

fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool)

Toggle between scalar-arg kernels (normal) and _dyn kernels that read their dynamic scalar args from device memory (graph-friendly).

Source

fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()>

Begin stream capture. Subsequent kernel launches are recorded into a pending graph instead of executing eagerly.

Source

fn end_graph_capture(_ctx: &mut Self::Context) -> Result<()>

End stream capture and install the captured graph as this context’s “last graph” for future replay_last_graph calls.

Source

fn replay_last_graph(_ctx: &mut Self::Context) -> Result<bool>

Replay the last captured graph. Returns Ok(false) if no graph is cached; caller should run eager.

Source

fn reset_graph(_ctx: &mut Self::Context)

Drop the cached decode graph — required when the KV cache it was captured against is about to be freed (e.g. request release), since the graph holds raw device pointers into that cache.

Source

fn load_gptq( _qweight: &[i32], _scales: &[f32], _qzeros: &[i32], _g_idx: Option<&[i32]>, _bits: u32, _group_size: usize, _k: usize, _n: usize, ) -> Result<Self::GptqStore>

Repack raw GPTQ tensors into the backend’s preferred format. Called once per layer at model load time.

Inputs are host-side slices (CPU memory) — the loader reads from safetensors and hands them off; each backend uploads + repacks per its own strategy. bits is typically 4; group_size is typically 128.

Source

fn gemm_gptq( _ctx: &mut Self::Context, _a: &Self::Buffer, _weight: &Self::GptqStore, _out: &mut Self::Buffer, _m: usize, ) -> Result<()>

GEMM with pre-loaded GPTQ weights. out[m, n] = a[m, k] @ dequant(weight)^T

Source

fn mla_attention( _ctx: &mut Self::Context, _q: &Self::Buffer, _kv_compressed: &Self::Buffer, _kv_rope: &Self::Buffer, _out: &mut Self::Buffer, _batch: usize, _q_len: usize, _kv_len: usize, _pos_offset: usize, _cfg: &AttnConfig, _kv_lora_rank: usize, _qk_rope_head_dim: usize, ) -> Result<()>

Multi-Head Latent Attention — DeepSeek V2 / V3’s compressed-KV attention variant. Extension point only; no backend implements it yet. DeepSeek V3 landing in Phase D/E will fill this in.

  • q: full Q [batch, num_heads, q_len, head_dim]
  • kv_compressed: latent KV [batch, kv_len, kv_lora_rank]
  • kv_rope: per-position rope-applied key heads [batch, kv_len, qk_rope_head_dim]
  • out: [batch, num_heads, q_len, head_dim]

Source

fn gemm_quant( _ctx: &mut Self::Context, _a: &Self::Buffer, _weights: &QuantWeights<'_, Self>, _out: &mut Self::Buffer, _m: usize, _n: usize, _k: usize, kind: &QuantKind, ) -> Result<()>

GEMM with packed-quantized B matrix. m/n/k describe the dense equivalent: out[m, n] = a[m, k] @ dequant(B)^T, where the dequantized B is [n, k] (the same convention as gemm_gptq).

Source

fn world_size(_ctx: &Self::Context) -> usize

Source

fn rank(_ctx: &Self::Context) -> usize

Source

fn all_reduce( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _op: ReduceOp, )

Source

fn all_gather( _ctx: &mut Self::Context, _local: &Self::Buffer, _global: &mut Self::Buffer, _local_len: usize, )

Source

fn broadcast( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _src_rank: usize, )

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementors§