

Trait Driver 

pub trait Driver: Send + Sync {
    type Tensor;

    // 32 required methods and 29 provided methods; each is documented below.
}

Hardware-agnostic compute primitives for BERT inference.

Each method corresponds to one operation in the forward pass. Drivers handle memory allocation, kernel dispatch, and synchronization. Architectures compose these primitives via the super::arch::ModelArch trait.
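
For orientation, here is a rough sketch of how an architecture might chain these primitives for a single feed-forward block. It is illustrative only: the helper name, the weight layout (assumed [in, out], so transpose_b is false), and the use of the crate's Result alias are assumptions, not code from super::arch.

// Hypothetical composition of Driver primitives for one MLP block.
// `D::Tensor` stays opaque; only trait methods are used.
fn mlp_block<D: Driver>(
    d: &D,
    hidden: &mut D::Tensor, // [rows, hidden_dim], updated in place
    wi: &D::Tensor,         // [hidden_dim, inter_dim] weights (assumed [in, out] layout)
    bi: &D::Tensor,         // [inter_dim] bias
    wo: &D::Tensor,         // [inter_dim, hidden_dim] weights
    bo: &D::Tensor,         // [hidden_dim] bias
    rows: usize,
    hidden_dim: usize,
    inter_dim: usize,
) -> Result<()> {
    let residual = d.clone_tensor(hidden, rows * hidden_dim)?;
    let mut inter = d.alloc_zeros(rows * inter_dim)?;
    d.gemm(hidden, wi, &mut inter, rows, inter_dim, hidden_dim, false)?;
    d.fused_bias_gelu(&mut inter, bi, rows, inter_dim)?;
    let mut out = d.alloc_zeros(rows * hidden_dim)?;
    d.gemm(&inter, wo, &mut out, rows, hidden_dim, inter_dim, false)?;
    // output = out + bo + residual, written back into `hidden`
    d.fused_bias_residual(hidden, &out, bo, &residual, rows * hidden_dim, hidden_dim)?;
    Ok(())
}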

Required Associated Types§


type Tensor

Opaque tensor handle.

Metal: MTLBuffer + byte offset. CUDA: CUdeviceptr. CPU: Array2<f32>.

Required Methods§


fn alloc_zeros(&self, n: usize) -> Result<Self::Tensor>

Allocate a zero-initialized tensor with n float elements on device.

Used by architectures to create workspace buffers (QKV projections, attention scores, intermediate activations, etc.).

§Errors

Returns an error if device memory allocation fails.


fn clone_tensor(&self, tensor: &Self::Tensor, n: usize) -> Result<Self::Tensor>

Clone a tensor, producing an independent copy of the data.

Used when an operation needs both the original and a mutable output referencing the same logical data (e.g., in-place layer normalization where input == output).

§Errors

Returns an error if device memory allocation or the copy fails.


fn prepare_batch( &self, encodings: &[Encoding], max_seq: usize, ) -> Result<BatchInputs<Self::Tensor>>

Prepare a batch of encodings for inference, returning input tensors on device.

Pads all sequences to max_seq and uploads input_ids, attention_mask, token_type_ids, position_ids, and a float attention mask to device memory.


fn pad_to_batch( &self, flat: &Self::Tensor, padded: &mut Self::Tensor, seq_lengths: &[usize], max_seq: usize, dim: usize, ) -> Result<()>

Scatter flat [total_tokens, dim] tensor into padded [batch, max_seq, dim].

Used before attention: linear layers produce unpadded output, but the QKV split + batched attention GEMM need aligned [batch*heads, seq, head_dim]. Padding positions are zeroed.
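
A minimal CPU-style sketch of the scatter, written over plain f32 slices instead of the opaque Tensor handle; it only illustrates the indexing, not an actual driver implementation.

// Illustrative scatter from flat [total_tokens, dim] into padded
// [batch, max_seq, dim]; plain slices stand in for Self::Tensor.
fn pad_to_batch_ref(
    flat: &[f32],
    padded: &mut [f32],
    seq_lengths: &[usize],
    max_seq: usize,
    dim: usize,
) {
    padded.fill(0.0); // padding positions are zeroed
    let mut src_row = 0;
    for (b, &len) in seq_lengths.iter().enumerate() {
        for t in 0..len {
            let src = src_row * dim;
            let dst = (b * max_seq + t) * dim;
            padded[dst..dst + dim].copy_from_slice(&flat[src..src + dim]);
            src_row += 1;
        }
    }
}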


fn unpad_from_batch( &self, padded: &Self::Tensor, flat: &mut Self::Tensor, seq_lengths: &[usize], max_seq: usize, dim: usize, ) -> Result<()>

Gather padded [batch, max_seq, dim] back to flat [total_tokens, dim].

Used after attention: extracts only the real tokens, discarding padding.


fn embedding_lookup( &self, word_ids: &Self::Tensor, embedding_table: &Self::Tensor, seq_len: usize, hidden: usize, ) -> Result<Self::Tensor>

Word/position/token-type embedding lookup via gather.

Reads seq_len token IDs from word_ids, gathers rows from embedding_table, and writes [seq_len, hidden] floats to the result.


fn add_embeddings( &self, hidden: &mut Self::Tensor, table: &Self::Tensor, ids: &Self::Tensor, seq_len: usize, hidden_dim: usize, ) -> Result<()>

Element-wise add an embedding table lookup into hidden.

Used for position and token-type embeddings: hidden[i] += table[ids[i]] for each token position.


fn layer_norm( &self, output: &mut Self::Tensor, input: &Self::Tensor, weight: &Self::Tensor, bias: &Self::Tensor, rows: usize, cols: usize, eps: f32, ) -> Result<()>

Layer normalization: output = (input - mean) / sqrt(var + eps) * weight + bias.
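
A reference implementation over plain slices, matching the formula above (illustrative only; real drivers dispatch a kernel):

// Reference layer norm; shapes: input/output [rows, cols], weight/bias [cols].
fn layer_norm_ref(
    output: &mut [f32], input: &[f32],
    weight: &[f32], bias: &[f32],
    rows: usize, cols: usize, eps: f32,
) {
    for r in 0..rows {
        let row = &input[r * cols..(r + 1) * cols];
        let mean = row.iter().sum::<f32>() / cols as f32;
        let var = row.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
        let inv = 1.0 / (var + eps).sqrt();
        for c in 0..cols {
            output[r * cols + c] = (row[c] - mean) * inv * weight[c] + bias[c];
        }
    }
}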


fn gemm( &self, a: &Self::Tensor, b: &Self::Tensor, output: &mut Self::Tensor, m: usize, n: usize, k: usize, transpose_b: bool, ) -> Result<()>

General matrix multiply: output = A * B (or A * B^T if transpose_b).

Dimensions: A is [m, k], B is [k, n] (or [n, k] if transposed), output is [m, n].
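
A naive reference over plain slices showing the indexing and the transpose_b convention (illustrative only; real drivers call an optimized GEMM):

// Reference GEMM: A is [m, k], B is [k, n] (or [n, k] when transpose_b), C is [m, n].
fn gemm_ref(a: &[f32], b: &[f32], c: &mut [f32],
            m: usize, n: usize, k: usize, transpose_b: bool) {
    for i in 0..m {
        for j in 0..n {
            let mut acc = 0.0f32;
            for p in 0..k {
                let b_val = if transpose_b { b[j * k + p] } else { b[p * n + j] };
                acc += a[i * k + p] * b_val;
            }
            c[i * n + j] = acc;
        }
    }
}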


fn gemm_batched( &self, a: &Self::Tensor, b: &Self::Tensor, output: &mut Self::Tensor, m: usize, n: usize, k: usize, transpose_b: bool, stride_a: usize, stride_b: usize, stride_c: usize, batch_count: usize, ) -> Result<()>

Batched GEMM for multi-head attention.

Performs batch_count independent GEMMs with strided access into contiguous buffers. Used for the per-head QK^T and scores × V products.
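
As a sketch of how the strides might be chosen for per-head QK^T, assuming Q and K are laid out contiguously as [batch*heads, seq, head_dim] and scores as [batch*heads, seq, seq] (the helper and layout are assumptions, not crate code):

// Hypothetical per-head QK^T using gemm_batched.
fn qk_scores<D: Driver>(
    d: &D, q: &D::Tensor, k: &D::Tensor, scores: &mut D::Tensor,
    batch: usize, num_heads: usize, seq: usize, head_dim: usize,
) -> Result<()> {
    d.gemm_batched(
        q, k, scores,
        seq, seq, head_dim,   // m, n, k for each head
        true,                 // Q [seq, d] * K^T -> [seq, seq]
        seq * head_dim,       // stride_a: one head of Q
        seq * head_dim,       // stride_b: one head of K
        seq * seq,            // stride_c: one head of scores
        batch * num_heads,
    )
}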


fn fused_scale_mask_softmax( &self, scores: &mut Self::Tensor, mask: &Self::Tensor, batch: usize, num_heads: usize, seq_len: usize, scale: f32, ) -> Result<()>

Fused scale + mask + softmax for attention scores.

scores = softmax(scores * scale + mask) computed per-head.
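
A reference version over plain slices; the [batch, heads, seq, seq] score layout and the [batch, seq] additive mask (0.0 / -10000.0, as produced by build_attn_mask) are assumptions for illustration:

// Reference scale + mask + softmax, one row at a time.
fn scale_mask_softmax_ref(
    scores: &mut [f32], mask: &[f32],
    batch: usize, num_heads: usize, seq_len: usize, scale: f32,
) {
    for b in 0..batch {
        for h in 0..num_heads {
            for i in 0..seq_len {
                let base = ((b * num_heads + h) * seq_len + i) * seq_len;
                let row = &mut scores[base..base + seq_len];
                for j in 0..seq_len {
                    row[j] = row[j] * scale + mask[b * seq_len + j];
                }
                let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
                let mut sum = 0.0f32;
                for v in row.iter_mut() {
                    *v = (*v - max).exp();
                    sum += *v;
                }
                for v in row.iter_mut() {
                    *v /= sum;
                }
            }
        }
    }
}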


fn fused_scale_mask_softmax_windowed( &self, scores: &mut Self::Tensor, mask: &Self::Tensor, batch: usize, num_heads: usize, seq_len: usize, scale: f32, window_size: usize, ) -> Result<()>

Fused scale + mask + sliding window + softmax for attention scores.

Like fused_scale_mask_softmax but additionally masks out positions where |query_pos - key_pos| > window_size / 2. Used by ModernBERT’s local attention layers.


fn build_attn_mask( &self, output: &mut Self::Tensor, int_mask: &Self::Tensor, n: usize, ) -> Result<()>

Build a float attention mask from an integer mask.

Converts [batch * seq] int mask (0/1) to [batch * seq] float mask (0.0 / -10000.0) for use with fused_scale_mask_softmax.


fn qkv_split( &self, q: &mut Self::Tensor, k: &mut Self::Tensor, v: &mut Self::Tensor, qkv: &Self::Tensor, batch: usize, seq: usize, hidden: usize, num_heads: usize, head_dim: usize, ) -> Result<()>

Split a fused QKV projection into separate Q, K, V tensors.


fn banded_qk( &self, q: &Self::Tensor, k: &Self::Tensor, scores: &mut Self::Tensor, batch_heads: usize, seq: usize, head_dim: usize, window: usize, stride_qk: usize, stride_scores: usize, ) -> Result<()>

Banded Q@K^T: compute attention scores only within a sliding window.

Output shape: [batch * num_heads, seq, window] (NOT [seq, seq]). scores[h, i, w] = dot(Q[h, i, :], K[h, i - window/2 + w, :]) where out-of-bounds positions are set to -inf (masked in softmax).

Reduces attention compute from O(seq²) to O(seq × window). For seq=512, window=128: 4× less compute per local layer.
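
A CPU-style sketch of the indexing, treating stride_qk and stride_scores as per-head element strides (an assumption) and plain slices in place of Tensor:

// Illustrative banded Q@K^T: Q/K are [batch_heads, seq, head_dim],
// scores are [batch_heads, seq, window].
fn banded_qk_ref(
    q: &[f32], k: &[f32], scores: &mut [f32],
    batch_heads: usize, seq: usize, head_dim: usize, window: usize,
    stride_qk: usize, stride_scores: usize,
) {
    let half = (window / 2) as isize;
    for h in 0..batch_heads {
        for i in 0..seq {
            for w in 0..window {
                let j = i as isize - half + w as isize;
                let out = h * stride_scores + i * window + w;
                if j < 0 || j >= seq as isize {
                    scores[out] = f32::NEG_INFINITY; // masked in softmax
                    continue;
                }
                let qi = h * stride_qk + i * head_dim;
                let kj = h * stride_qk + j as usize * head_dim;
                scores[out] = (0..head_dim).map(|d| q[qi + d] * k[kj + d]).sum();
            }
        }
    }
}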


fn banded_sv( &self, scores: &Self::Tensor, v: &Self::Tensor, output: &mut Self::Tensor, batch_heads: usize, seq: usize, head_dim: usize, window: usize, stride_scores: usize, stride_v: usize, stride_out: usize, ) -> Result<()>

Banded scores@V: weighted sum using banded attention scores.

Input scores: [batch * num_heads, seq, window] (from banded_qk). Output: [batch * num_heads, seq, head_dim]. output[h, i, d] = sum_w scores[h, i, w] * V[h, i - window/2 + w, d]


fn banded_softmax( &self, scores: &mut Self::Tensor, total_rows: usize, window: usize, scale: f32, ) -> Result<()>

Fused scale + softmax over the window dimension (no padding mask needed).

Operates on [batch * num_heads * seq, window] rows.


fn attn_reshape( &self, output: &mut Self::Tensor, input: &Self::Tensor, batch: usize, seq: usize, num_heads: usize, head_dim: usize, ) -> Result<()>

Reshape attention output from [batch, num_heads, seq, head_dim] to [batch * seq, hidden].


fn apply_rope( &self, qk: &mut Self::Tensor, cos: &Self::Tensor, sin: &Self::Tensor, num_rows: usize, seq_len: usize, head_dim: usize, num_heads: usize, ) -> Result<()>

Apply Rotary Position Embedding (RoPE) to Q/K tensors.

Used by ModernBERT (not ClassicBert, which uses learned position embeddings).
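
A sketch of one common RoPE convention (rotate-half pairing, [num_rows, num_heads * head_dim] Q/K rows, [seq_len, head_dim/2] cos/sin tables); the actual pairing and table layout used by the drivers may differ:

// Illustrative rotate-half RoPE: pair element d with element d + head_dim/2.
fn apply_rope_ref(
    qk: &mut [f32], cos: &[f32], sin: &[f32],
    num_rows: usize, seq_len: usize, head_dim: usize, num_heads: usize,
) {
    let half = head_dim / 2;
    for row in 0..num_rows {
        let pos = row % seq_len; // position of this token within its sequence
        for h in 0..num_heads {
            let base = row * num_heads * head_dim + h * head_dim;
            for d in 0..half {
                let (c, s) = (cos[pos * half + d], sin[pos * half + d]);
                let (x0, x1) = (qk[base + d], qk[base + half + d]);
                qk[base + d] = x0 * c - x1 * s;
                qk[base + half + d] = x0 * s + x1 * c;
            }
        }
    }
}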


fn split_gate_value( &self, first: &mut Self::Tensor, second: &mut Self::Tensor, input: &Self::Tensor, rows: usize, cols: usize, ) -> Result<()>

Split a [rows, 2*cols] matrix into two [rows, cols] halves.

Each row of input is [first_half | second_half]. The first cols elements go to first, the remaining cols to second. Used by ModernBERT for gated MLP splits.


fn gelu(&self, x: &mut Self::Tensor, n: usize) -> Result<()>

GELU activation (Gaussian Error Linear Unit), applied in-place.


fn swiglu( &self, value: &Self::Tensor, gate: &Self::Tensor, output: &mut Self::Tensor, n: usize, ) -> Result<()>

SwiGLU gated activation: output = value * silu(gate).

The gate and value come from splitting the intermediate projection.


fn geglu( &self, value: &Self::Tensor, gate: &Self::Tensor, output: &mut Self::Tensor, n: usize, ) -> Result<()>

GeGLU gated activation: output = gelu(value) * gate.

Used by ModernBERT. The value and gate come from splitting the MLP Wi projection output in half.
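
A slice-level reference using the tanh approximation of GELU (whether the drivers use this or the exact erf form is an assumption; this is illustrative only):

// Reference GeGLU: output[i] = gelu(value[i]) * gate[i].
fn geglu_ref(value: &[f32], gate: &[f32], output: &mut [f32], n: usize) {
    for i in 0..n {
        let x = value[i];
        // 0.7978845608 = sqrt(2 / pi), tanh approximation of GELU
        let g = 0.5 * x * (1.0 + (0.7978845608_f32 * (x + 0.044715 * x * x * x)).tanh());
        output[i] = g * gate[i];
    }
}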


fn fused_bias_gelu( &self, x: &mut Self::Tensor, bias: &Self::Tensor, rows: usize, cols: usize, ) -> Result<()>

Fused bias + GELU: x = gelu(x + bias) row-wise.


fn fused_bias_residual( &self, output: &mut Self::Tensor, input: &Self::Tensor, bias: &Self::Tensor, residual: &Self::Tensor, n: usize, cols: usize, ) -> Result<()>

Fused bias + residual add: output = input + bias + residual.

Bias is broadcast row-wise (cols-wide) across n / cols rows.


fn fused_residual_layernorm( &self, output: &mut Self::Tensor, hidden: &Self::Tensor, residual: &Self::Tensor, weight: &Self::Tensor, bias: &Self::Tensor, rows: usize, cols: usize, eps: f32, ) -> Result<()>

Fused residual add + layer normalization.

output = layer_norm(hidden + residual, weight, bias, eps).


fn residual_add( &self, output: &mut Self::Tensor, hidden: &Self::Tensor, residual: &Self::Tensor, n: usize, ) -> Result<()>

Residual add without bias: output = hidden + residual.

Used by ModernBERT which has no bias terms.


fn add_bias( &self, x: &mut Self::Tensor, bias: &Self::Tensor, rows: usize, cols: usize, ) -> Result<()>

Add bias to a matrix row-wise: x[row] += bias for each row.


fn cls_pool( &self, output: &mut Self::Tensor, hidden: &Self::Tensor, batch: usize, seq: usize, hidden_dim: usize, ) -> Result<()>

CLS pooling: extract the first token’s hidden state per batch element.


fn mean_pool( &self, output: &mut Self::Tensor, hidden: &Self::Tensor, mask: &Self::Tensor, batch: usize, seq: usize, hidden_dim: usize, ) -> Result<()>

Mean pooling: attention-mask-weighted average of hidden states.
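
A slice-level sketch, assuming hidden is [batch, seq, hidden_dim], mask is [batch, seq] with 1.0 for real tokens and 0.0 for padding, and output is [batch, hidden_dim]; the actual mask layout is not specified here.

// Illustrative masked mean pooling.
fn mean_pool_ref(
    output: &mut [f32], hidden: &[f32], mask: &[f32],
    batch: usize, seq: usize, hidden_dim: usize,
) {
    for b in 0..batch {
        let out = &mut output[b * hidden_dim..(b + 1) * hidden_dim];
        out.fill(0.0);
        let mut count = 0.0f32;
        for t in 0..seq {
            let m = mask[b * seq + t];
            count += m;
            let row = &hidden[(b * seq + t) * hidden_dim..][..hidden_dim];
            for d in 0..hidden_dim {
                out[d] += m * row[d];
            }
        }
        let denom = count.max(1e-9);
        for d in 0..hidden_dim {
            out[d] /= denom;
        }
    }
}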


fn l2_normalize( &self, data: &mut Self::Tensor, rows: usize, cols: usize, ) -> Result<()>

L2-normalize each row vector in-place.


fn to_host( &self, tensor: &Self::Tensor, batch: usize, dim: usize, ) -> Result<Vec<Vec<f32>>>

Copy tensor data back to host memory as Vec<Vec<f32>>.

Returns one Vec<f32> of length dim per batch element.

Provided Methods§


fn new_for_clone() -> Result<Self>
where Self: Sized,

Create a new driver instance for a cloned worker thread.

CPU drivers are zero-sized and always succeed. GPU drivers typically cannot be cloned this way (they share device state) and should keep the default implementation, which panics.


fn begin_batch(&self) -> Result<()>

Begin batched mode: all subsequent operations encode into one dispatch.

GPU drivers accumulate work into a single command buffer; for the CPU driver this is a no-op. Call end_batch to commit. This eliminates per-call overhead.
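
A hypothetical usage pattern (the helper is not part of the crate):

// Encode the whole forward pass as one dispatch.
fn run_encoder<D: Driver>(d: &D, num_layers: usize) -> Result<()> {
    d.begin_batch()?; // start accumulating into one command buffer
    for _layer in 0..num_layers {
        // ... encode this layer's gemm / layer_norm / softmax calls ...
    }
    d.end_batch() // commit everything and wait for completion
}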


fn end_batch(&self) -> Result<()>

End batched mode: commit all accumulated operations and wait.


fn flush_batch(&self) -> Result<()>

Flush the current command buffer and start a new one, preserving pool state. Use mid-forward-pass to prevent GPU timeouts on deep models.


fn segment_encoder(&self)

Close and reopen the compute encoder within the same command buffer.

This segments a long sequence of compute dispatches into multiple encoders without committing or waiting. Metal processes encoders from the same command buffer back-to-back, with zero synchronization overhead.

Call it every few layers to prevent encoder state overflow (more than roughly 60 dispatches per encoder can cause hangs on some Apple Silicon GPUs).


fn save_pool_cursor(&self) -> usize

Save the current pool cursor position. Call BEFORE a layer’s work.


fn restore_pool_cursor(&self, _saved: usize)

Restore the pool cursor to a previously saved position. Call AFTER a layer’s transient tensors have been dropped (out of scope).

The architecture must ensure only the output tensor (hidden_states) survives — all layer-internal tensors (qkv, scores, context, etc.) must be dropped before this call so their pool slots can be recycled.
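
A hypothetical per-layer pattern showing the intended pairing with save_pool_cursor (the helper is illustrative, not crate code):

// Snapshot the cursor, run the layer so its temporaries drop, then rewind.
fn run_layer_scoped<D: Driver>(d: &D) -> Result<()> {
    let saved = d.save_pool_cursor();
    {
        // allocate qkv, scores, context, ... and run the layer;
        // all transient tensors are dropped at the end of this block
    }
    d.restore_pool_cursor(saved); // their pool slots can now be recycled
    Ok(())
}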


fn prepare_batch_unpadded( &self, encodings: &[Encoding], ) -> Result<BatchInputs<Self::Tensor>>

Prepare a batch WITHOUT padding — concatenate all tokens flat.

Returns BatchInputs with total_tokens actual tokens (no padding), cu_seqlens for attention boundaries, and per-token position IDs. Linear layers (GEMM, LN, GELU) process total_tokens rows. Attention must pad/unpad around the per-head operations.
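
For example, two sequences of lengths 3 and 5 would yield total_tokens = 8, cu_seqlens = [0, 3, 8] (assuming the usual cumulative-lengths convention), and per-token position IDs [0, 1, 2, 0, 1, 2, 3, 4].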


fn alloc_zeros_f16(&self, _n: usize) -> Result<Self::Tensor>

Allocate a zero-initialized FP16 tensor with n half-precision elements.

§Errors

Returns an error if device memory allocation fails or FP16 is unsupported.


fn f32_to_f16( &self, _output: &mut Self::Tensor, _input: &Self::Tensor, _n: usize, ) -> Result<()>

Convert FP32 tensor to FP16 (element-wise narrowing).


fn f16_to_f32( &self, _output: &mut Self::Tensor, _input: &Self::Tensor, _n: usize, ) -> Result<()>

Convert FP16 tensor back to FP32 (element-wise widening).


fn gemm_mixed( &self, _a_f16: &Self::Tensor, _b_f16: &Self::Tensor, _output_f32: &mut Self::Tensor, _m: usize, _n: usize, _k: usize, _transpose_b: bool, ) -> Result<()>

Mixed-precision GEMM: FP16 inputs → FP32 output via native simdgroup ops.


fn gemm_f16( &self, _a: &Self::Tensor, _b: &Self::Tensor, _output: &mut Self::Tensor, _m: usize, _n: usize, _k: usize, _transpose_b: bool, ) -> Result<()>

FP16 GEMM: output = A * B (or A * B^T). All tensors are half.


fn gemm_batched_f16( &self, _a: &Self::Tensor, _b: &Self::Tensor, _output: &mut Self::Tensor, _m: usize, _n: usize, _k: usize, _transpose_b: bool, _stride_a: usize, _stride_b: usize, _stride_c: usize, _batch_count: usize, ) -> Result<()>

FP16 batched GEMM for multi-head attention. All tensors are half.


fn layer_norm_f16( &self, _output: &mut Self::Tensor, _input: &Self::Tensor, _weight: &Self::Tensor, _bias: &Self::Tensor, _rows: usize, _cols: usize, _eps: f32, ) -> Result<()>

FP16 layer normalization. Half I/O, FP32 reductions.


fn fused_scale_mask_softmax_f16( &self, _scores: &mut Self::Tensor, _mask: &Self::Tensor, _batch: usize, _num_heads: usize, _seq_len: usize, _scale: f32, ) -> Result<()>

FP16 fused scale + mask + softmax. Half scores, FP32 reductions.


fn fused_scale_mask_softmax_windowed_f16( &self, _scores: &mut Self::Tensor, _mask: &Self::Tensor, _batch: usize, _num_heads: usize, _seq_len: usize, _scale: f32, _window_size: usize, ) -> Result<()>

FP16 fused scale + mask + sliding window + softmax.


fn qkv_split_f16( &self, _q: &mut Self::Tensor, _k: &mut Self::Tensor, _v: &mut Self::Tensor, _qkv: &Self::Tensor, _batch: usize, _seq: usize, _hidden: usize, _num_heads: usize, _head_dim: usize, ) -> Result<()>

FP16 QKV split: [batch*seq, 3*hidden] into Q, K, V per-head layout.


fn attn_reshape_f16( &self, _output: &mut Self::Tensor, _input: &Self::Tensor, _batch: usize, _seq: usize, _num_heads: usize, _head_dim: usize, ) -> Result<()>

FP16 attention output reshape: [batch*num_heads, seq, head_dim] to [batch*seq, hidden].


fn pad_to_batch_f16( &self, _flat: &Self::Tensor, _padded: &mut Self::Tensor, _seq_lengths: &[usize], _max_seq: usize, _dim: usize, ) -> Result<()>

FP16 scatter flat [total_tokens, dim] to padded [batch, max_seq, dim].


fn unpad_from_batch_f16( &self, _padded: &Self::Tensor, _flat: &mut Self::Tensor, _seq_lengths: &[usize], _max_seq: usize, _dim: usize, ) -> Result<()>

FP16 gather padded [batch, max_seq, dim] back to flat [total_tokens, dim].


fn rope_encode_f16( &self, _qk: &mut Self::Tensor, _cos: &Self::Tensor, _sin: &Self::Tensor, _num_rows: usize, _seq_len: usize, _head_dim: usize, _num_heads: usize, ) -> Result<()>

FP16 RoPE: apply rotary position embedding. Half Q/K, float cos/sin tables.


fn geglu_f16( &self, _value: &Self::Tensor, _gate: &Self::Tensor, _output: &mut Self::Tensor, _n: usize, ) -> Result<()>

FP16 GeGLU gated activation: output = gelu(value) * gate. Half I/O.


fn fused_residual_layernorm_f16( &self, _output: &mut Self::Tensor, _hidden: &Self::Tensor, _residual: &Self::Tensor, _weight: &Self::Tensor, _bias: &Self::Tensor, _rows: usize, _cols: usize, _eps: f32, ) -> Result<()>

FP16 fused residual add + layer normalization.


fn residual_add_f16( &self, _output: &mut Self::Tensor, _hidden: &Self::Tensor, _residual: &Self::Tensor, _n: usize, ) -> Result<()>

FP16 residual add (no bias): output = hidden + residual.


fn split_gate_value_f16( &self, _first: &mut Self::Tensor, _second: &mut Self::Tensor, _input: &Self::Tensor, _rows: usize, _cols: usize, ) -> Result<()>

FP16 split [rows, 2*cols] into two [rows, cols] halves.


fn fused_split_geglu_f16( &self, output: &mut Self::Tensor, input: &Self::Tensor, rows: usize, cols: usize, ) -> Result<()>

Fused split + GeGLU: read [rows, 2*cols], write [rows, cols].

Combines split_gate_value_f16 and geglu_f16 into a single kernel, eliminating two intermediate [rows, cols] buffers and halving HBM round-trips.

Default falls back to separate split + geglu calls.


fn fused_pad_qkv_split_f16( &self, q: &mut Self::Tensor, k: &mut Self::Tensor, v: &mut Self::Tensor, qkv_flat: &Self::Tensor, seq_lengths: &[usize], max_seq: usize, batch: usize, hidden: usize, num_heads: usize, head_dim: usize, ) -> Result<()>

Fused pad + QKV split: flat [total_tokens, 3*hidden] → Q, K, V each [batch*heads, max_seq, head_dim].

Eliminates the padded intermediate buffer. Default calls pad then split.


fn fused_reshape_unpad_f16( &self, flat: &mut Self::Tensor, heads: &Self::Tensor, seq_lengths: &[usize], max_seq: usize, batch: usize, num_heads: usize, head_dim: usize, ) -> Result<()>

Fused attn_reshape + unpad: [batch*heads, max_seq, head_dim] → [total_tokens, hidden].

Eliminates the padded context intermediate. Default calls reshape then unpad.

Implementors§