Skip to main content

Backend

Trait Backend 

Source
pub trait Backend:
    Send
    + Sync
    + Sized
    + 'static {
    type Buffer: Send + Sync;
    type Context;
    type Timer: BackendTimer<Self>;

Show 38 methods // Required methods fn make_timer() -> Self::Timer; fn new_context() -> Self::Context; fn sync(ctx: &mut Self::Context); fn alloc_typed(dtype: Dtype, n: usize) -> Self::Buffer; fn from_slice_typed<T: HostDtype>(data: &[T]) -> Self::Buffer; fn write_typed<T: HostDtype>( ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[T], ); fn gemm( ctx: &mut Self::Context, a: &Self::Buffer, b: &Self::Buffer, out: &mut Self::Buffer, m: usize, n: usize, k: usize, ); fn rms_norm( ctx: &mut Self::Context, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, ); fn fused_add_rms_norm( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, ); fn flash_attention( ctx: &mut Self::Context, q: &Self::Buffer, k: &Self::Buffer, v: &Self::Buffer, out: &mut Self::Buffer, batch: usize, q_len: usize, kv_len: usize, pos_offset: usize, cfg: &AttnConfig, ); fn copy_slice( ctx: &mut Self::Context, src: &Self::Buffer, src_offset: usize, dst: &mut Self::Buffer, dst_offset: usize, len: usize, ); fn embedding_lookup( ctx: &mut Self::Context, table: &Self::Buffer, ids: &[u32], out: &mut Self::Buffer, dim: usize, ); fn split_qkv( ctx: &mut Self::Context, qkv: &Self::Buffer, q: &mut Self::Buffer, k: &mut Self::Buffer, v: &mut Self::Buffer, tokens: usize, q_dim: usize, kv_dim: usize, ); fn fused_silu_mul_split( ctx: &mut Self::Context, gate_up: &Self::Buffer, out: &mut Self::Buffer, tokens: usize, im: usize, ); fn qk_norm_rope( ctx: &mut Self::Context, input: &Self::Buffer, norm_w: &Self::Buffer, cos: &Self::Buffer, sin: &Self::Buffer, output: &mut Self::Buffer, tokens: usize, heads: usize, head_dim: usize, pos_offset: usize, eps: f32, mode: i32, ); fn kv_cache_append_head_major( ctx: &mut Self::Context, cache_k: &mut Self::Buffer, cache_v: &mut Self::Buffer, cache_len: usize, cache_capacity: usize, new_k_head_major: &Self::Buffer, 
new_v_head_major: &Self::Buffer, new_tokens: usize, nkv: usize, hd: usize, ); fn transpose_head_to_token( ctx: &mut Self::Context, src: &Self::Buffer, dst: &mut Self::Buffer, tokens: usize, heads: usize, dim: usize, ); fn add_inplace( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, len: usize, ); fn add_bias( ctx: &mut Self::Context, data: &mut Self::Buffer, bias: &Self::Buffer, rows: usize, cols: usize, ); fn layer_norm( ctx: &mut Self::Context, x: &Self::Buffer, gamma: &Self::Buffer, beta: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, ); fn gelu( ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize, ); fn alloc(len: usize) -> Self::Buffer; fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>; fn from_slice(data: &[f32]) -> Self::Buffer; // Provided methods fn is_metal_backend() -> bool { ... } fn zero_buffer( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, ) -> Result<()> { ... } fn mla_attention( _ctx: &mut Self::Context, _q: &Self::Buffer, _kv_compressed: &Self::Buffer, _kv_rope: &Self::Buffer, _out: &mut Self::Buffer, _batch: usize, _q_len: usize, _kv_len: usize, _pos_offset: usize, _cfg: &AttnConfig, _kv_lora_rank: usize, _qk_rope_head_dim: usize, ) -> Result<()> { ... } fn embedding_lookup_dev( ctx: &mut Self::Context, table: &Self::Buffer, ids: &Self::Buffer, out: &mut Self::Buffer, batch: usize, dim: usize, ) { ... } fn kv_cache_append_batched_per_cache( _ctx: &mut Self::Context, _caches: &[&Self::Buffer], _new_data: &Self::Buffer, _cache_lens: &Self::Buffer, _capacity: usize, _m: usize, _nkv: usize, _hd: usize, _slot: usize, ) -> Result<()> { ... } fn flash_attention_batched_per_cache( _ctx: &mut Self::Context, _q: &Self::Buffer, _k_caches: &[&Self::Buffer], _v_caches: &[&Self::Buffer], _kv_lens: &Self::Buffer, _out: &mut Self::Buffer, _nq: usize, _nkv: usize, _hd: usize, _scale: f32, _max_valid_kv: usize, _capacity: usize, _slot: usize, ) -> Result<()> { ... 
} fn qk_norm_rope_batched_per_item( _ctx: &mut Self::Context, _input: &Self::Buffer, _norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _output: &mut Self::Buffer, _positions: &Self::Buffer, _m: usize, _heads: usize, _head_dim: usize, _eps: f32, _mode: i32, ) -> Result<()> { ... } fn split_qkv_norm_rope( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _k_out: &mut Self::Buffer, _v_out: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, ) -> Result<()> { ... } fn split_qkv_norm_rope_into_cache( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _cache_k: &mut Self::Buffer, _cache_v: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, _cache_len: usize, _cache_capacity: usize, ) -> Result<()> { ... } fn transpose_token_to_head( _ctx: &mut Self::Context, _src: &Self::Buffer, _dst: &mut Self::Buffer, _tokens: usize, _heads: usize, _dim: usize, ) { ... } fn scaled_add_inplace( _ctx: &mut Self::Context, dst: &mut Self::Buffer, src: &Self::Buffer, scale: f32, len: usize, ) { ... } fn fused_silu_mul_split_strided( _ctx: &mut Self::Context, _gate_up: &Self::Buffer, _in_row_offset: usize, _out: &mut Self::Buffer, _out_row_offset: usize, _tokens: usize, _intermediate: usize, ) { ... } fn argmax_rows_f16( _ctx: &mut Self::Context, logits: &Self::Buffer, m: usize, n: usize, ) -> Result<Vec<u32>> { ... } fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer { ... }
}
Expand description

The core abstraction over CUDA / Metal / CPU.

Key design: operations take a &mut Self::Context which accumulates work.

  • CPU: Context is () — ops execute immediately.
  • Metal: Context is a CommandBuffer — ops encode into it, flushed on sync().
  • CUDA: Context is a CudaStream — ops launch on the stream, synced on sync().

layer_forward passes the context through all ops in a layer. ModelRunner calls sync() only when it needs results (e.g., reading logits).

Required Associated Types§

Source

type Buffer: Send + Sync

Source

type Context

Execution context that accumulates GPU work.

  • CPU: () (no-op, ops execute inline)
  • Metal: wraps a CommandBuffer
  • CUDA: wraps a CudaStream
Source

type Timer: BackendTimer<Self>

GPU-side timer scoped to this backend. See super::timer — CPU: Instant; Metal: sync-wrap; CUDA: cuEvent. PLAYBOOK § 1.1.

Required Methods§

Source

fn make_timer() -> Self::Timer

Factory for Self::Timer — exists so call sites that have a <B: Backend> parameter can spawn a timer without importing the concrete impl. PLAYBOOK § 1.2.

Source

fn new_context() -> Self::Context

Opaque per-backend GPTQ weight representation.

  • CPU: dequantized f32 weights (run as regular GEMM)
  • Metal: () — unsupported; gemm_gptq errors

Create a new execution context (begin accumulating work).
Source

fn sync(ctx: &mut Self::Context)

Flush accumulated work and wait for completion. CPU: no-op. Metal: commit + waitUntilCompleted. CUDA: stream sync.

Source

fn alloc_typed(dtype: Dtype, n: usize) -> Self::Buffer

Phase D step 2+3: unified typed allocator. Replaces per-dtype alloc_u32 / alloc_typed_i32 / etc. The buffer is dtype-tagged at the wrapper level (CudaBuf::U32, MetalBuf with Dtype::U32, CpuBuf::U32), so reads/writes through .as_<T>() accessors get the correct byte count automatically.

Source

fn from_slice_typed<T: HostDtype>(data: &[T]) -> Self::Buffer

Upload typed host data — replaces from_slice_i32 / from_slice_u32 etc. The host element type T carries its Dtype via the HostDtype marker so dispatch in the impl is a one-line match T::DTYPE.

Source

fn write_typed<T: HostDtype>( ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[T], )

In-place typed write — replaces write_u32 / write_i32_into / write_f32_into. The buffer must already be dtype-tagged matching T::DTYPE (typically alloc’d via alloc_typed or from_slice_typed).

Source

fn gemm( ctx: &mut Self::Context, a: &Self::Buffer, b: &Self::Buffer, out: &mut Self::Buffer, m: usize, n: usize, k: usize, )

Source

fn rms_norm( ctx: &mut Self::Context, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )

Source

fn fused_add_rms_norm( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, w: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )

Source

fn flash_attention( ctx: &mut Self::Context, q: &Self::Buffer, k: &Self::Buffer, v: &Self::Buffer, out: &mut Self::Buffer, batch: usize, q_len: usize, kv_len: usize, pos_offset: usize, cfg: &AttnConfig, )

Source

fn copy_slice( ctx: &mut Self::Context, src: &Self::Buffer, src_offset: usize, dst: &mut Self::Buffer, dst_offset: usize, len: usize, )

Copy len floats from src[src_offset..] to dst[dst_offset..].

Needed for Qwen3Model::prefill to pluck the last token’s hidden state out of residual[seq_len, h] without round-tripping through host RAM. Backend::copy is the offset-free variant; copy_slice additionally supports non-zero source and destination offsets.

Source

fn embedding_lookup( ctx: &mut Self::Context, table: &Self::Buffer, ids: &[u32], out: &mut Self::Buffer, dim: usize, )

Source

fn split_qkv( ctx: &mut Self::Context, qkv: &Self::Buffer, q: &mut Self::Buffer, k: &mut Self::Buffer, v: &mut Self::Buffer, tokens: usize, q_dim: usize, kv_dim: usize, )

Split fused QKV [tokens, q_dim+2*kv_dim] into separate Q, K, V buffers. Q: [tokens, q_dim], K: [tokens, kv_dim], V: [tokens, kv_dim]

Source

fn fused_silu_mul_split( ctx: &mut Self::Context, gate_up: &Self::Buffer, out: &mut Self::Buffer, tokens: usize, im: usize, )

Split fused gate_up [tokens, 2*im] into gate [tokens, im] and up [tokens, im], then compute SiLU(gate) * up → out [tokens, im].

Source

fn qk_norm_rope( ctx: &mut Self::Context, input: &Self::Buffer, norm_w: &Self::Buffer, cos: &Self::Buffer, sin: &Self::Buffer, output: &mut Self::Buffer, tokens: usize, heads: usize, head_dim: usize, pos_offset: usize, eps: f32, mode: i32, )

Fused QK-norm + RoPE + transpose-to-head-major.

mode selects the operation:

  • 0 = transpose only (typical for V, which needs no norm and no RoPE)
  • 1 = per-head RMS norm + RoPE + transpose (Q/K with QK-norm, Qwen3)
  • 2 = RoPE + transpose (Q/K without QK-norm, Llama/Mistral)

  • input: [tokens, heads, head_dim] (token-major, output of split_qkv)
  • output: [heads, tokens, head_dim] (head-major, ready for flash_attn / kv_cache_append)

pos_offset is the position of token 0 (decode uses current seq len; prefill uses 0). Within the batch, positions are taken as pos_offset + i.

This is the primary attention-input preparation op. Backends that have a fused kernel (Metal’s qk_norm_rope_transpose_f32) will be dramatically faster than composing norm + rope + transpose separately; the CPU fallback lowers to the individual ops.

Source

fn kv_cache_append_head_major( ctx: &mut Self::Context, cache_k: &mut Self::Buffer, cache_v: &mut Self::Buffer, cache_len: usize, cache_capacity: usize, new_k_head_major: &Self::Buffer, new_v_head_major: &Self::Buffer, new_tokens: usize, nkv: usize, hd: usize, )

Append new K/V into a pre-allocated head-major cache buffer.

cache_k / cache_v: [nkv, capacity, hd] (head-major, pre-allocated) new_k_head_major / new_v_head_major: [nkv, new_tokens, hd] — produced directly by qk_norm_rope, no extra transpose needed.

In-place append at slot [nkv, cache_len..cache_len+new_tokens, hd]. Caller owns cache_len bookkeeping.

Source

fn transpose_head_to_token( ctx: &mut Self::Context, src: &Self::Buffer, dst: &mut Self::Buffer, tokens: usize, heads: usize, dim: usize, )

Transpose [heads, tokens, dim] → [tokens, heads, dim]. Called after flash_attention to restore token-major layout for O-proj.

Source

fn add_inplace( ctx: &mut Self::Context, residual: &mut Self::Buffer, x: &Self::Buffer, len: usize, )

residual[i] += x[i] (in-place)

Source

fn add_bias( ctx: &mut Self::Context, data: &mut Self::Buffer, bias: &Self::Buffer, rows: usize, cols: usize, )

Broadcast bias add: data[r, c] += bias[c] for every row. Required by Bert / Clip / Whisper whose linear projections carry a bias.

Source

fn layer_norm( ctx: &mut Self::Context, x: &Self::Buffer, gamma: &Self::Buffer, beta: &Self::Buffer, eps: f32, out: &mut Self::Buffer, tokens: usize, dim: usize, )

Full LayerNorm (mean + variance normalisation + affine), distinct from the rms_norm used by Llama-family decoders.

out[r, c] = ((x[r, c] - mean) / sqrt(var + eps)) * gamma[c] + beta[c]

where mean and var are reduced over the last dim (cols).

Source

fn gelu( ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize, )

Element-wise GELU activation (erf-based, matches PyTorch default).

Source

fn alloc(len: usize) -> Self::Buffer

Source

fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>

Source

fn from_slice(data: &[f32]) -> Self::Buffer

Provided Methods§

Source

fn is_metal_backend() -> bool

True for the Apple Metal backend.

Keep this as a backend capability instead of matching on type names in model code. It is used only for backend-specific safety fallbacks where a generic optimized path is known to be incorrect on one backend.

Source

fn zero_buffer( _ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, ) -> Result<()>

Zero the first len elements of a Self::Buffer. CUDA path uses cuMemsetD16Async; default returns unsupported.

Source

fn mla_attention( _ctx: &mut Self::Context, _q: &Self::Buffer, _kv_compressed: &Self::Buffer, _kv_rope: &Self::Buffer, _out: &mut Self::Buffer, _batch: usize, _q_len: usize, _kv_len: usize, _pos_offset: usize, _cfg: &AttnConfig, _kv_lora_rank: usize, _qk_rope_head_dim: usize, ) -> Result<()>

Multi-Head Latent Attention — DeepSeek V2 / V3’s compressed-KV attention variant. Extension point only; no backend implements it yet. DeepSeek V3 landing in Phase D/E will fill this in.

  • q: full Q [batch, num_heads, q_len, head_dim]
  • kv_compressed: latent KV [batch, kv_len, kv_lora_rank]
  • kv_rope: per-position rope-applied key heads [batch, kv_len, qk_rope_head_dim]
  • out: [batch, num_heads, q_len, head_dim]

Source

fn embedding_lookup_dev( ctx: &mut Self::Context, table: &Self::Buffer, ids: &Self::Buffer, out: &mut Self::Buffer, batch: usize, dim: usize, )

Device-buffer variant of embedding_lookup for graph-capturable MoE routing — the gather step before phase-1 GEMM in moe_forward_bucketed. The host-slice embedding_lookup does clone_htod(ids) internally, which records stale host pointers under CUDA Graph capture replay.

ids: &Self::Buffer must be a device I32 buffer of batch elements (e.g. Qwen3MoeScratch::route_packed_idx_dev). batch is passed explicitly since a typed CudaBuf carries its element count but the caller often wants a partial gather.

Default impl: round-trip via to_vec + dispatch the host-slice variant. CUDA overrides.

Source

fn kv_cache_append_batched_per_cache( _ctx: &mut Self::Context, _caches: &[&Self::Buffer], _new_data: &Self::Buffer, _cache_lens: &Self::Buffer, _capacity: usize, _m: usize, _nkv: usize, _hd: usize, _slot: usize, ) -> Result<()>

Batched kv_cache_append across M caches in one launch. Each item writes its (head-major) K-or-V row into its own cache at offset read from cache_lens[i]. Replaces M sequential kv_cache_append_head_major calls with a single dispatch.

  • new_data layout: [m, nkv, hd] item-major (each item’s slice is contiguous, identical to the k/v_normed_batched produced by qk_norm_rope_batched_per_item).
  • caches: per-cache [nkv, capacity, hd] head-major.
  • cache_lens: device buffer (u32 storage, length ≥ m). Caller fills via B::write_typed::<u32> BEFORE the call. Required for CUDA-graph capture: the kernel reads from this stable device buffer, so a captured graph can be replayed with new lens by just rewriting the buffer between launches.

Source

fn flash_attention_batched_per_cache( _ctx: &mut Self::Context, _q: &Self::Buffer, _k_caches: &[&Self::Buffer], _v_caches: &[&Self::Buffer], _kv_lens: &Self::Buffer, _out: &mut Self::Buffer, _nq: usize, _nkv: usize, _hd: usize, _scale: f32, _max_valid_kv: usize, _capacity: usize, _slot: usize, ) -> Result<()>

Batched flash_attention across M decode caches in one launch. Replaces the per-item flash_attention(q_len=1, ...) × M loop in the non-paged batched-decode path.

API takes a slice of buffer references (&[&Self::Buffer]) for the per-cache K/V buffers (each [nkv, capacity, hd] head-major) plus a device-side kv_lens buffer. Backends that implement it must extract per-cache device pointers, build the device arrays the kernel needs, and launch one kernel covering all M items.

q layout: [m, nq, hd] item-major (matches the qk_norm_rope_batched_per_item output for q_len=1). out layout: [m, nq, hd] item-major — written directly into the caller’s batched attn_out buffer, no per-item copy needed.

CUDA-only for now (kernel batched_decode_attention exists in kernels/batched_decode_attention.cu). kv_lens: device buffer (u32 storage, length ≥ m) — same design as kv_cache_append_batched_per_cache::cache_lens.

Source

fn qk_norm_rope_batched_per_item( _ctx: &mut Self::Context, _input: &Self::Buffer, _norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _output: &mut Self::Buffer, _positions: &Self::Buffer, _m: usize, _heads: usize, _head_dim: usize, _eps: f32, _mode: i32, ) -> Result<()>

Batched per-item-position variant of qk_norm_rope for the non-paged batched-decode path. Each of the m items has its own absolute RoPE position (read from a device i32 buffer of length m). Layout is item-major in both input and output:

  • input: [m, heads, head_dim]
  • output: [m, heads, head_dim] (no head-major transpose)

Item-major output keeps the per-item flash_attention slice contiguous (output[i * heads * head_dim ..] is item i’s whole Q tensor in head-major-equivalent layout for q_len=1).

Replaces the M sequential single-item launches in the existing forward_layer_batched_decode path with one batched dispatch. CUDA-only for now; other backends fall through to the default unsupported and the caller falls back to the per-item loop.

Source

fn split_qkv_norm_rope( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _k_out: &mut Self::Buffer, _v_out: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, ) -> Result<()>

Fused split-QKV + QK-norm + RoPE + head-major transpose.

Single-dispatch replacement for the (split_qkv → 3× qk_norm_rope) chain on the decode-attention prelude. Reads the linear-layer fused-QKV output once and writes head-major Q/K/V directly into attention scratch.

  • qkv layout: [tokens, q_heads*hd + 2*kv_heads*hd].
  • q_out: [q_heads, tokens, hd].
  • k_out/v_out: [kv_heads, tokens, hd].
  • qk_mode:
      1 = norm + half-split RoPE for Q/K (Qwen3 with QK-norm),
      2 = half-split RoPE only for Q/K,
      3 = interleaved RoPE only for Q/K (GGUF LLaMA / llama.cpp layout).
    V always falls through to transpose-only.

Default returns Unsupported. Backends that implement it are expected to be dramatically faster than the four-dispatch chain.

Source

fn split_qkv_norm_rope_into_cache( _ctx: &mut Self::Context, _qkv: &Self::Buffer, _q_norm_w: &Self::Buffer, _k_norm_w: &Self::Buffer, _cos: &Self::Buffer, _sin: &Self::Buffer, _q_out: &mut Self::Buffer, _cache_k: &mut Self::Buffer, _cache_v: &mut Self::Buffer, _tokens: usize, _q_heads: usize, _kv_heads: usize, _head_dim: usize, _pos_offset: usize, _eps: f32, _qk_mode: i32, _cache_len: usize, _cache_capacity: usize, ) -> Result<()>

Variant of Backend::split_qkv_norm_rope that writes the new K and V directly into pre-allocated head-major KV cache buffers at slot [kv_heads, cache_len .. cache_len + tokens, hd]. Eliminates the trailing kv_cache_append_head_major dispatch on the decode hot path. Q still lands in per-token head-major scratch (flash-attention reads it as the query).

Default returns Unsupported. Backends without the fused kernel can keep using split_qkv_norm_rope + kv_cache_append_head_major.

Source

fn transpose_token_to_head( _ctx: &mut Self::Context, _src: &Self::Buffer, _dst: &mut Self::Buffer, _tokens: usize, _heads: usize, _dim: usize, )

Inverse of transpose_head_to_token: [tokens, heads, dim] → [heads, tokens, dim]. Used by the CUDA paged_decode_attention wrapper to convert paged_varlen_attention’s token-major output back to the head-major layout that Qwen3MoeModel expects. Default panics — backends without a paged-KV CUDA path don’t hit this code.

Source

fn scaled_add_inplace( _ctx: &mut Self::Context, dst: &mut Self::Buffer, src: &Self::Buffer, scale: f32, len: usize, )

dst[i] += scale * src[i] — scalar-broadcast scaled add, in place.

MoE per-token combine writes out[b] += weight_k * expert_k(x[b]) for each top-K expert; this primitive is the per-call accumulate. Backends without a dedicated kernel can fall back to the default implementation, which round-trips through host memory — correct, but slow on a hot path. Override on any backend you actually dispatch MoE on.

Source

fn fused_silu_mul_split_strided( _ctx: &mut Self::Context, _gate_up: &Self::Buffer, _in_row_offset: usize, _out: &mut Self::Buffer, _out_row_offset: usize, _tokens: usize, _intermediate: usize, )

Strided variant of Backend::fused_silu_mul_split for the bucketed MoE path: reads gate_up rows starting at in_row_offset, writes out rows starting at out_row_offset.

Source

fn argmax_rows_f16( _ctx: &mut Self::Context, logits: &Self::Buffer, m: usize, n: usize, ) -> Result<Vec<u32>>

Greedy-decode fast path: GPU argmax over each row of a [m, n] FP16 logits buffer, returning the m token indices on the host. Saves m × n × 2 bytes of D2H per call (e.g. ≈ 9.7 MB at c=32, vocab=152064) and the host-side argmax scan (~150 µs × m).

Default impl falls back to the slow path: full to_vec + host argmax. CUDA overrides with a native kernel + tiny D2H (m × 4 B). Backends that don’t override pay the same cost as to_vec + host argmax, so callers can call this unconditionally.

Source

fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer

Load a weight tensor straight from its on-disk byte representation, letting the backend pick its preferred storage dtype.

Default impl upcasts bf16/f16 to f32 via an intermediate Vec, matching pre-existing loader behaviour. Backends override this to go straight from raw bytes into a native half-precision buffer (e.g. Metal with FERRUM_METAL_DTYPE=f16), avoiding the transient 2× RAM spike.

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety".

Implementors§