pub trait Driver: Send + Sync {
type Tensor;
// Required methods
fn alloc_zeros(&self, n: usize) -> Result<Self::Tensor>;
fn clone_tensor(
&self,
tensor: &Self::Tensor,
n: usize,
) -> Result<Self::Tensor>;
fn prepare_batch(
&self,
encodings: &[Encoding],
max_seq: usize,
) -> Result<BatchInputs<Self::Tensor>>;
fn pad_to_batch(
&self,
flat: &Self::Tensor,
padded: &mut Self::Tensor,
seq_lengths: &[usize],
max_seq: usize,
dim: usize,
) -> Result<()>;
fn unpad_from_batch(
&self,
padded: &Self::Tensor,
flat: &mut Self::Tensor,
seq_lengths: &[usize],
max_seq: usize,
dim: usize,
) -> Result<()>;
fn embedding_lookup(
&self,
word_ids: &Self::Tensor,
embedding_table: &Self::Tensor,
seq_len: usize,
hidden: usize,
) -> Result<Self::Tensor>;
fn add_embeddings(
&self,
hidden: &mut Self::Tensor,
table: &Self::Tensor,
ids: &Self::Tensor,
seq_len: usize,
hidden_dim: usize,
) -> Result<()>;
fn layer_norm(
&self,
output: &mut Self::Tensor,
input: &Self::Tensor,
weight: &Self::Tensor,
bias: &Self::Tensor,
rows: usize,
cols: usize,
eps: f32,
) -> Result<()>;
fn gemm(
&self,
a: &Self::Tensor,
b: &Self::Tensor,
output: &mut Self::Tensor,
m: usize,
n: usize,
k: usize,
transpose_b: bool,
) -> Result<()>;
fn gemm_batched(
&self,
a: &Self::Tensor,
b: &Self::Tensor,
output: &mut Self::Tensor,
m: usize,
n: usize,
k: usize,
transpose_b: bool,
stride_a: usize,
stride_b: usize,
stride_c: usize,
batch_count: usize,
) -> Result<()>;
fn fused_scale_mask_softmax(
&self,
scores: &mut Self::Tensor,
mask: &Self::Tensor,
batch: usize,
num_heads: usize,
seq_len: usize,
scale: f32,
) -> Result<()>;
fn fused_scale_mask_softmax_windowed(
&self,
scores: &mut Self::Tensor,
mask: &Self::Tensor,
batch: usize,
num_heads: usize,
seq_len: usize,
scale: f32,
window_size: usize,
) -> Result<()>;
fn build_attn_mask(
&self,
output: &mut Self::Tensor,
int_mask: &Self::Tensor,
n: usize,
) -> Result<()>;
fn qkv_split(
&self,
q: &mut Self::Tensor,
k: &mut Self::Tensor,
v: &mut Self::Tensor,
qkv: &Self::Tensor,
batch: usize,
seq: usize,
hidden: usize,
num_heads: usize,
head_dim: usize,
) -> Result<()>;
fn banded_qk(
&self,
q: &Self::Tensor,
k: &Self::Tensor,
scores: &mut Self::Tensor,
batch_heads: usize,
seq: usize,
head_dim: usize,
window: usize,
stride_qk: usize,
stride_scores: usize,
) -> Result<()>;
fn banded_sv(
&self,
scores: &Self::Tensor,
v: &Self::Tensor,
output: &mut Self::Tensor,
batch_heads: usize,
seq: usize,
head_dim: usize,
window: usize,
stride_scores: usize,
stride_v: usize,
stride_out: usize,
) -> Result<()>;
fn banded_softmax(
&self,
scores: &mut Self::Tensor,
total_rows: usize,
window: usize,
scale: f32,
) -> Result<()>;
fn attn_reshape(
&self,
output: &mut Self::Tensor,
input: &Self::Tensor,
batch: usize,
seq: usize,
num_heads: usize,
head_dim: usize,
) -> Result<()>;
fn apply_rope(
&self,
qk: &mut Self::Tensor,
cos: &Self::Tensor,
sin: &Self::Tensor,
num_rows: usize,
seq_len: usize,
head_dim: usize,
num_heads: usize,
) -> Result<()>;
fn split_gate_value(
&self,
first: &mut Self::Tensor,
second: &mut Self::Tensor,
input: &Self::Tensor,
rows: usize,
cols: usize,
) -> Result<()>;
fn gelu(&self, x: &mut Self::Tensor, n: usize) -> Result<()>;
fn swiglu(
&self,
value: &Self::Tensor,
gate: &Self::Tensor,
output: &mut Self::Tensor,
n: usize,
) -> Result<()>;
fn geglu(
&self,
value: &Self::Tensor,
gate: &Self::Tensor,
output: &mut Self::Tensor,
n: usize,
) -> Result<()>;
fn fused_bias_gelu(
&self,
x: &mut Self::Tensor,
bias: &Self::Tensor,
rows: usize,
cols: usize,
) -> Result<()>;
fn fused_bias_residual(
&self,
output: &mut Self::Tensor,
input: &Self::Tensor,
bias: &Self::Tensor,
residual: &Self::Tensor,
n: usize,
cols: usize,
) -> Result<()>;
fn fused_residual_layernorm(
&self,
output: &mut Self::Tensor,
hidden: &Self::Tensor,
residual: &Self::Tensor,
weight: &Self::Tensor,
bias: &Self::Tensor,
rows: usize,
cols: usize,
eps: f32,
) -> Result<()>;
fn residual_add(
&self,
output: &mut Self::Tensor,
hidden: &Self::Tensor,
residual: &Self::Tensor,
n: usize,
) -> Result<()>;
fn add_bias(
&self,
x: &mut Self::Tensor,
bias: &Self::Tensor,
rows: usize,
cols: usize,
) -> Result<()>;
fn cls_pool(
&self,
output: &mut Self::Tensor,
hidden: &Self::Tensor,
batch: usize,
seq: usize,
hidden_dim: usize,
) -> Result<()>;
fn mean_pool(
&self,
output: &mut Self::Tensor,
hidden: &Self::Tensor,
mask: &Self::Tensor,
batch: usize,
seq: usize,
hidden_dim: usize,
) -> Result<()>;
fn l2_normalize(
&self,
data: &mut Self::Tensor,
rows: usize,
cols: usize,
) -> Result<()>;
fn to_host(
&self,
tensor: &Self::Tensor,
batch: usize,
dim: usize,
) -> Result<Vec<Vec<f32>>>;
// Provided methods
fn new_for_clone() -> Result<Self>
where Self: Sized { ... }
fn begin_batch(&self) -> Result<()> { ... }
fn end_batch(&self) -> Result<()> { ... }
fn flush_batch(&self) -> Result<()> { ... }
fn segment_encoder(&self) { ... }
fn save_pool_cursor(&self) -> usize { ... }
fn restore_pool_cursor(&self, _saved: usize) { ... }
fn prepare_batch_unpadded(
&self,
encodings: &[Encoding],
) -> Result<BatchInputs<Self::Tensor>> { ... }
fn alloc_zeros_f16(&self, _n: usize) -> Result<Self::Tensor> { ... }
fn f32_to_f16(
&self,
_output: &mut Self::Tensor,
_input: &Self::Tensor,
_n: usize,
) -> Result<()> { ... }
fn f16_to_f32(
&self,
_output: &mut Self::Tensor,
_input: &Self::Tensor,
_n: usize,
) -> Result<()> { ... }
fn gemm_mixed(
&self,
_a_f16: &Self::Tensor,
_b_f16: &Self::Tensor,
_output_f32: &mut Self::Tensor,
_m: usize,
_n: usize,
_k: usize,
_transpose_b: bool,
) -> Result<()> { ... }
fn gemm_f16(
&self,
_a: &Self::Tensor,
_b: &Self::Tensor,
_output: &mut Self::Tensor,
_m: usize,
_n: usize,
_k: usize,
_transpose_b: bool,
) -> Result<()> { ... }
fn gemm_batched_f16(
&self,
_a: &Self::Tensor,
_b: &Self::Tensor,
_output: &mut Self::Tensor,
_m: usize,
_n: usize,
_k: usize,
_transpose_b: bool,
_stride_a: usize,
_stride_b: usize,
_stride_c: usize,
_batch_count: usize,
) -> Result<()> { ... }
fn layer_norm_f16(
&self,
_output: &mut Self::Tensor,
_input: &Self::Tensor,
_weight: &Self::Tensor,
_bias: &Self::Tensor,
_rows: usize,
_cols: usize,
_eps: f32,
) -> Result<()> { ... }
fn fused_scale_mask_softmax_f16(
&self,
_scores: &mut Self::Tensor,
_mask: &Self::Tensor,
_batch: usize,
_num_heads: usize,
_seq_len: usize,
_scale: f32,
) -> Result<()> { ... }
fn fused_scale_mask_softmax_windowed_f16(
&self,
_scores: &mut Self::Tensor,
_mask: &Self::Tensor,
_batch: usize,
_num_heads: usize,
_seq_len: usize,
_scale: f32,
_window_size: usize,
) -> Result<()> { ... }
fn qkv_split_f16(
&self,
_q: &mut Self::Tensor,
_k: &mut Self::Tensor,
_v: &mut Self::Tensor,
_qkv: &Self::Tensor,
_batch: usize,
_seq: usize,
_hidden: usize,
_num_heads: usize,
_head_dim: usize,
) -> Result<()> { ... }
fn attn_reshape_f16(
&self,
_output: &mut Self::Tensor,
_input: &Self::Tensor,
_batch: usize,
_seq: usize,
_num_heads: usize,
_head_dim: usize,
) -> Result<()> { ... }
fn pad_to_batch_f16(
&self,
_flat: &Self::Tensor,
_padded: &mut Self::Tensor,
_seq_lengths: &[usize],
_max_seq: usize,
_dim: usize,
) -> Result<()> { ... }
fn unpad_from_batch_f16(
&self,
_padded: &Self::Tensor,
_flat: &mut Self::Tensor,
_seq_lengths: &[usize],
_max_seq: usize,
_dim: usize,
) -> Result<()> { ... }
fn rope_encode_f16(
&self,
_qk: &mut Self::Tensor,
_cos: &Self::Tensor,
_sin: &Self::Tensor,
_num_rows: usize,
_seq_len: usize,
_head_dim: usize,
_num_heads: usize,
) -> Result<()> { ... }
fn geglu_f16(
&self,
_value: &Self::Tensor,
_gate: &Self::Tensor,
_output: &mut Self::Tensor,
_n: usize,
) -> Result<()> { ... }
fn fused_residual_layernorm_f16(
&self,
_output: &mut Self::Tensor,
_hidden: &Self::Tensor,
_residual: &Self::Tensor,
_weight: &Self::Tensor,
_bias: &Self::Tensor,
_rows: usize,
_cols: usize,
_eps: f32,
) -> Result<()> { ... }
fn residual_add_f16(
&self,
_output: &mut Self::Tensor,
_hidden: &Self::Tensor,
_residual: &Self::Tensor,
_n: usize,
) -> Result<()> { ... }
fn split_gate_value_f16(
&self,
_first: &mut Self::Tensor,
_second: &mut Self::Tensor,
_input: &Self::Tensor,
_rows: usize,
_cols: usize,
) -> Result<()> { ... }
fn fused_split_geglu_f16(
&self,
output: &mut Self::Tensor,
input: &Self::Tensor,
rows: usize,
cols: usize,
) -> Result<()> { ... }
fn fused_pad_qkv_split_f16(
&self,
q: &mut Self::Tensor,
k: &mut Self::Tensor,
v: &mut Self::Tensor,
qkv_flat: &Self::Tensor,
seq_lengths: &[usize],
max_seq: usize,
batch: usize,
hidden: usize,
num_heads: usize,
head_dim: usize,
) -> Result<()> { ... }
fn fused_reshape_unpad_f16(
&self,
flat: &mut Self::Tensor,
heads: &Self::Tensor,
seq_lengths: &[usize],
max_seq: usize,
batch: usize,
num_heads: usize,
head_dim: usize,
) -> Result<()> { ... }
}
Hardware-agnostic compute primitives for BERT inference.
Each method corresponds to one operation in the forward pass. Drivers handle
memory allocation, kernel dispatch, and synchronization. Architectures
compose these primitives via the super::arch::ModelArch trait.
Required Associated Types
type Tensor
Required Methods
fn alloc_zeros(&self, n: usize) -> Result<Self::Tensor>
Allocate a zero-initialized tensor with n float elements on device.
Used by architectures to create workspace buffers (QKV projections, attention scores, intermediate activations, etc.).
Errors
Returns an error if device memory allocation fails.
fn clone_tensor(&self, tensor: &Self::Tensor, n: usize) -> Result<Self::Tensor>
Clone a tensor, producing an independent copy of the data.
Used when an operation needs both the original and a mutable output referencing the same logical data (e.g., in-place layer normalization where input == output).
Errors
Returns an error if device memory allocation or the copy fails.
fn prepare_batch(&self, encodings: &[Encoding], max_seq: usize) -> Result<BatchInputs<Self::Tensor>>
Prepare a batch of encodings for inference, returning input tensors on device.
Pads all sequences to max_seq and uploads input_ids, attention_mask,
token_type_ids, position_ids, and a float attention mask to device memory.
fn pad_to_batch(&self, flat: &Self::Tensor, padded: &mut Self::Tensor, seq_lengths: &[usize], max_seq: usize, dim: usize) -> Result<()>
Scatter flat [total_tokens, dim] tensor into padded [batch, max_seq, dim].
Used before attention: linear layers produce unpadded output, but the
QKV split + batched attention GEMM need aligned [batch*heads, seq, head_dim].
Padding positions are zeroed.
fn unpad_from_batch(&self, padded: &Self::Tensor, flat: &mut Self::Tensor, seq_lengths: &[usize], max_seq: usize, dim: usize) -> Result<()>
Gather padded [batch, max_seq, dim] back to flat [total_tokens, dim].
Used after attention: extracts only the real tokens, discarding padding.
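As an illustration, the pad/unpad round trip around attention could look roughly like this. This is a minimal sketch, not code from the crate; driver, flat, seq_lengths, max_seq, and hidden are hypothetical names, and total_tokens is the sum of seq_lengths.

let batch = seq_lengths.len();
let mut padded = driver.alloc_zeros(batch * max_seq * hidden)?;
driver.pad_to_batch(&flat, &mut padded, &seq_lengths, max_seq, hidden)?;
// ... per-head attention runs on the aligned [batch, max_seq, hidden] layout ...
let mut unpadded = driver.alloc_zeros(total_tokens * hidden)?;
driver.unpad_from_batch(&padded, &mut unpadded, &seq_lengths, max_seq, hidden)?;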
fn embedding_lookup(&self, word_ids: &Self::Tensor, embedding_table: &Self::Tensor, seq_len: usize, hidden: usize) -> Result<Self::Tensor>
Word/position/token-type embedding lookup via gather.
Reads seq_len token IDs from word_ids, gathers rows from
embedding_table, and writes [seq_len, hidden] floats to the result.
fn add_embeddings(&self, hidden: &mut Self::Tensor, table: &Self::Tensor, ids: &Self::Tensor, seq_len: usize, hidden_dim: usize) -> Result<()>
Element-wise add an embedding table lookup into hidden.
Used for position and token-type embeddings:
hidden[i] += table[ids[i]] for each token position.
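For example, an architecture could assemble the full embedding sum from these two primitives roughly as follows. This is a sketch; the table and id tensor names are illustrative, not the crate's own.

let mut hidden = driver.embedding_lookup(&input_ids, &word_table, seq_len, hidden_dim)?;
driver.add_embeddings(&mut hidden, &position_table, &position_ids, seq_len, hidden_dim)?;
driver.add_embeddings(&mut hidden, &token_type_table, &token_type_ids, seq_len, hidden_dim)?;
// followed by the embedding layer_norm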
fn layer_norm(&self, output: &mut Self::Tensor, input: &Self::Tensor, weight: &Self::Tensor, bias: &Self::Tensor, rows: usize, cols: usize, eps: f32) -> Result<()>
Layer normalization: output = (input - mean) / sqrt(var + eps) * weight + bias.
fn gemm(&self, a: &Self::Tensor, b: &Self::Tensor, output: &mut Self::Tensor, m: usize, n: usize, k: usize, transpose_b: bool) -> Result<()>
General matrix multiply: output = A * B (or A * B^T if transpose_b).
Dimensions: A is [m, k], B is [k, n] (or [n, k] if transposed),
output is [m, n].
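For instance, a dense projection whose weights are stored [out_dim, in_dim] (the usual checkpoint layout, hence transpose_b = true) might be dispatched like this. Names and dimensions are illustrative, not taken from the crate.

let mut proj = driver.alloc_zeros(rows * out_dim)?;
driver.gemm(&hidden, &weight, &mut proj, rows, out_dim, in_dim, true)?;
driver.add_bias(&mut proj, &bias, rows, out_dim)?;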
fn gemm_batched(&self, a: &Self::Tensor, b: &Self::Tensor, output: &mut Self::Tensor, m: usize, n: usize, k: usize, transpose_b: bool, stride_a: usize, stride_b: usize, stride_c: usize, batch_count: usize) -> Result<()>
Batched GEMM for multi-head attention.
Performs batch_count independent GEMMs with strided access into
contiguous buffers. Used for per-head QK^T and attn·V.
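A sketch of the per-head QK^T dispatch, assuming Q and K are laid out as [batch*heads, seq, head_dim] and scores as [batch*heads, seq, seq]; all names and the contiguous stride choices are illustrative.

let heads = batch * num_heads;
driver.gemm_batched(
    &q, &k, &mut scores,
    seq, seq, head_dim, // each head multiplies [seq, head_dim] by [seq, head_dim]^T
    true,               // scores = Q * K^T
    seq * head_dim,     // stride between heads in Q
    seq * head_dim,     // stride between heads in K
    seq * seq,          // stride between heads in scores
    heads,
)?;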
fn fused_scale_mask_softmax(&self, scores: &mut Self::Tensor, mask: &Self::Tensor, batch: usize, num_heads: usize, seq_len: usize, scale: f32) -> Result<()>
Fused scale + mask + softmax for attention scores.
scores = softmax(scores * scale + mask) computed per-head.
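For standard scaled dot-product attention the scale passed here is typically 1 / sqrt(head_dim); a minimal sketch:

let scale = 1.0 / (head_dim as f32).sqrt();
// attn_mask is the float mask produced by build_attn_mask
driver.fused_scale_mask_softmax(&mut scores, &attn_mask, batch, num_heads, seq_len, scale)?;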
fn fused_scale_mask_softmax_windowed(&self, scores: &mut Self::Tensor, mask: &Self::Tensor, batch: usize, num_heads: usize, seq_len: usize, scale: f32, window_size: usize) -> Result<()>
Fused scale + mask + sliding window + softmax for attention scores.
Like fused_scale_mask_softmax but
additionally masks out positions where |query_pos - key_pos| > window_size / 2.
Used by ModernBERT’s local attention layers.
fn build_attn_mask(&self, output: &mut Self::Tensor, int_mask: &Self::Tensor, n: usize) -> Result<()>
Build a float attention mask from an integer mask.
Converts [batch * seq] int mask (0/1) to [batch * seq] float mask
(0.0 / -10000.0) for use with fused_scale_mask_softmax.
fn qkv_split(&self, q: &mut Self::Tensor, k: &mut Self::Tensor, v: &mut Self::Tensor, qkv: &Self::Tensor, batch: usize, seq: usize, hidden: usize, num_heads: usize, head_dim: usize) -> Result<()>
Split a fused QKV projection into separate Q, K, V tensors.
fn banded_qk(&self, q: &Self::Tensor, k: &Self::Tensor, scores: &mut Self::Tensor, batch_heads: usize, seq: usize, head_dim: usize, window: usize, stride_qk: usize, stride_scores: usize) -> Result<()>
Banded Q@K^T: compute attention scores only within a sliding window.
Output shape: [batch * num_heads, seq, window] (NOT [seq, seq]).
scores[h, i, w] = dot(Q[h, i, :], K[h, i - window/2 + w, :])
where out-of-bounds positions are set to -inf (masked in softmax).
Reduces attention compute from O(seq²) to O(seq × window).
For seq=512, window=128: 4× less compute per local layer.
fn banded_sv(&self, scores: &Self::Tensor, v: &Self::Tensor, output: &mut Self::Tensor, batch_heads: usize, seq: usize, head_dim: usize, window: usize, stride_scores: usize, stride_v: usize, stride_out: usize) -> Result<()>
Banded scores@V: weighted sum using banded attention scores.
Input scores: [batch * num_heads, seq, window] (from banded_qk).
Output: [batch * num_heads, seq, head_dim].
output[h, i, d] = sum_w scores[h, i, w] * V[h, i - window/2 + w, d]
fn banded_softmax(&self, scores: &mut Self::Tensor, total_rows: usize, window: usize, scale: f32) -> Result<()>
Fused scale + softmax over the window dimension (no padding mask needed).
Operates on [batch * num_heads * seq, window] rows.
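Taken together, the banded primitives replace the dense QK^T, softmax, and scores@V chain for a local layer. A minimal sketch, assuming a contiguous per-head layout; buffer names and stride choices are illustrative.

let bh = batch * num_heads;
let mut scores = driver.alloc_zeros(bh * seq * window)?;
let mut context = driver.alloc_zeros(bh * seq * head_dim)?;
driver.banded_qk(&q, &k, &mut scores, bh, seq, head_dim, window,
    seq * head_dim, seq * window)?;
driver.banded_softmax(&mut scores, bh * seq, window, 1.0 / (head_dim as f32).sqrt())?;
driver.banded_sv(&scores, &v, &mut context, bh, seq, head_dim, window,
    seq * window, seq * head_dim, seq * head_dim)?;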
fn attn_reshape(&self, output: &mut Self::Tensor, input: &Self::Tensor, batch: usize, seq: usize, num_heads: usize, head_dim: usize) -> Result<()>
Reshape attention output from [batch, num_heads, seq, head_dim] to
[batch * seq, hidden].
fn apply_rope(&self, qk: &mut Self::Tensor, cos: &Self::Tensor, sin: &Self::Tensor, num_rows: usize, seq_len: usize, head_dim: usize, num_heads: usize) -> Result<()>
Apply Rotary Position Embedding (RoPE) to Q/K tensors.
Used by ModernBERT (not ClassicBert which uses learned position embeddings).
fn split_gate_value(&self, first: &mut Self::Tensor, second: &mut Self::Tensor, input: &Self::Tensor, rows: usize, cols: usize) -> Result<()>
Split a [rows, 2*cols] matrix into two [rows, cols] halves.
Each row of input is [first_half | second_half]. The first cols
elements go to first, the remaining cols to second.
Used by ModernBERT for gated MLP splits.
fn gelu(&self, x: &mut Self::Tensor, n: usize) -> Result<()>
GELU activation (Gaussian Error Linear Unit), applied in-place.
fn swiglu(&self, value: &Self::Tensor, gate: &Self::Tensor, output: &mut Self::Tensor, n: usize) -> Result<()>
SwiGLU gated activation: output = value * silu(gate).
The gate and value come from splitting the intermediate projection.
fn geglu(&self, value: &Self::Tensor, gate: &Self::Tensor, output: &mut Self::Tensor, n: usize) -> Result<()>
GeGLU gated activation: output = gelu(value) * gate.
Used by ModernBERT. The value and gate come from splitting the
MLP Wi projection output in half.
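For example, the split-plus-activation half of a gated MLP could be composed like this. A sketch only: which half is the value and which is the gate depends on the checkpoint layout, and all names are illustrative.

let mut value = driver.alloc_zeros(rows * cols)?;
let mut gate = driver.alloc_zeros(rows * cols)?;
driver.split_gate_value(&mut value, &mut gate, &wi_output, rows, cols)?;
let mut activated = driver.alloc_zeros(rows * cols)?;
driver.geglu(&value, &gate, &mut activated, rows * cols)?;
// activated then feeds the Wo (down) projection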
fn fused_bias_gelu(&self, x: &mut Self::Tensor, bias: &Self::Tensor, rows: usize, cols: usize) -> Result<()>
Fused bias + GELU: x = gelu(x + bias) row-wise.
fn fused_bias_residual(&self, output: &mut Self::Tensor, input: &Self::Tensor, bias: &Self::Tensor, residual: &Self::Tensor, n: usize, cols: usize) -> Result<()>
Fused bias + residual add: output = input + bias + residual.
Bias is broadcast row-wise (cols-wide) across n / cols rows.
fn fused_residual_layernorm(&self, output: &mut Self::Tensor, hidden: &Self::Tensor, residual: &Self::Tensor, weight: &Self::Tensor, bias: &Self::Tensor, rows: usize, cols: usize, eps: f32) -> Result<()>
Fused residual add + layer normalization.
output = layer_norm(hidden + residual, weight, bias, eps).
fn residual_add(&self, output: &mut Self::Tensor, hidden: &Self::Tensor, residual: &Self::Tensor, n: usize) -> Result<()>
Residual add without bias: output = hidden + residual.
Used by ModernBERT which has no bias terms.
fn add_bias(&self, x: &mut Self::Tensor, bias: &Self::Tensor, rows: usize, cols: usize) -> Result<()>
Add bias to a matrix row-wise: x[row] += bias for each row.
fn cls_pool(&self, output: &mut Self::Tensor, hidden: &Self::Tensor, batch: usize, seq: usize, hidden_dim: usize) -> Result<()>
CLS pooling: extract the first token’s hidden state per batch element.
fn mean_pool(&self, output: &mut Self::Tensor, hidden: &Self::Tensor, mask: &Self::Tensor, batch: usize, seq: usize, hidden_dim: usize) -> Result<()>
Mean pooling: attention-mask-weighted average of hidden states.
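A sketch of a typical sentence-embedding tail that these pooling primitives support, combined with l2_normalize and to_host from the required-method list above; buffer names are illustrative.

let mut pooled = driver.alloc_zeros(batch * hidden_dim)?;
driver.mean_pool(&mut pooled, &hidden, &attn_mask, batch, seq, hidden_dim)?;
driver.l2_normalize(&mut pooled, batch, hidden_dim)?;
let embeddings: Vec<Vec<f32>> = driver.to_host(&pooled, batch, hidden_dim)?;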
Provided Methods
fn new_for_clone() -> Result<Self> where Self: Sized
Create a new driver instance for a cloned worker thread.
CPU drivers are zero-sized and always succeed. GPU drivers typically cannot be cloned this way (they share device state) and should keep the default panicking implementation.
fn begin_batch(&self) -> Result<()>
Begin batched mode: all subsequent operations encode into one dispatch.
GPU drivers accumulate into a single command buffer; CPU is a no-op.
Call Self::end_batch to commit. This eliminates per-call overhead.
fn end_batch(&self) -> Result<()>
End batched mode: commit all accumulated operations and wait.
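Usage is bracket-style; a minimal sketch (the calls in between are whatever forward-pass operations the architecture issues):

driver.begin_batch()?;
// ... encode the whole forward pass as driver calls; GPU drivers defer execution ...
driver.end_batch()?; // commit and wait for completion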
fn flush_batch(&self) -> Result<()>
Flush the current command buffer and start a new one, preserving pool state. Use mid-forward-pass to prevent GPU timeouts on deep models.
fn segment_encoder(&self)
Close and reopen the compute encoder within the same command buffer.
This segments a long sequence of compute dispatches into multiple encoders without committing or waiting. Metal processes encoders from the same command buffer back-to-back, so there is no synchronization overhead.
Use every few layers to prevent encoder state overflow (>~60 dispatches per encoder can cause hangs on some Apple Silicon GPUs).
fn save_pool_cursor(&self) -> usize
Save the current pool cursor position. Call BEFORE a layer’s work.
fn restore_pool_cursor(&self, _saved: usize)
Restore the pool cursor to a previously saved position. Call AFTER a layer’s transient tensors have been dropped (out of scope).
The architecture must ensure only the output tensor (hidden_states)
survives — all layer-internal tensors (qkv, scores, context, etc.)
must be dropped before this call so their pool slots can be recycled.
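A sketch of the per-layer pattern this pair supports. run_layer is a hypothetical helper standing in for an architecture's layer body; only its output survives the inner scope.

for layer in &layers {
    let cursor = driver.save_pool_cursor();
    hidden = {
        // qkv, scores, context, and other transients live only inside this block
        run_layer(&driver, layer, &hidden)?
    };
    driver.restore_pool_cursor(cursor); // transient pool slots can now be reused
}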
fn prepare_batch_unpadded(&self, encodings: &[Encoding]) -> Result<BatchInputs<Self::Tensor>>
Prepare a batch WITHOUT padding — concatenate all tokens flat.
Returns BatchInputs with total_tokens actual tokens (no padding),
cu_seqlens for attention boundaries, and per-token position IDs.
Linear layers (GEMM, LN, GELU) process total_tokens rows.
Attention must pad/unpad around the per-head operations.
fn alloc_zeros_f16(&self, _n: usize) -> Result<Self::Tensor>
Allocate a zero-initialized FP16 tensor with n half-precision elements.
Errors
Returns an error if device memory allocation fails or FP16 is unsupported.
fn f32_to_f16(&self, _output: &mut Self::Tensor, _input: &Self::Tensor, _n: usize) -> Result<()>
Convert FP32 tensor to FP16 (element-wise narrowing).
fn f16_to_f32(&self, _output: &mut Self::Tensor, _input: &Self::Tensor, _n: usize) -> Result<()>
Convert FP16 tensor back to FP32 (element-wise widening).
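A sketch of the conversion round trip over n elements; tensor names are illustrative.

let mut half = driver.alloc_zeros_f16(n)?;
driver.f32_to_f16(&mut half, &full_precision, n)?;
// ... FP16 kernels operate on half here ...
let mut restored = driver.alloc_zeros(n)?;
driver.f16_to_f32(&mut restored, &half, n)?;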
fn gemm_mixed(&self, _a_f16: &Self::Tensor, _b_f16: &Self::Tensor, _output_f32: &mut Self::Tensor, _m: usize, _n: usize, _k: usize, _transpose_b: bool) -> Result<()>
Mixed-precision GEMM: FP16 inputs → FP32 output via native simdgroup ops.
fn gemm_f16(&self, _a: &Self::Tensor, _b: &Self::Tensor, _output: &mut Self::Tensor, _m: usize, _n: usize, _k: usize, _transpose_b: bool) -> Result<()>
FP16 GEMM: output = A * B (or A * B^T). All tensors are half.
fn gemm_batched_f16(&self, _a: &Self::Tensor, _b: &Self::Tensor, _output: &mut Self::Tensor, _m: usize, _n: usize, _k: usize, _transpose_b: bool, _stride_a: usize, _stride_b: usize, _stride_c: usize, _batch_count: usize) -> Result<()>
FP16 batched GEMM for multi-head attention. All tensors are half.
fn layer_norm_f16(&self, _output: &mut Self::Tensor, _input: &Self::Tensor, _weight: &Self::Tensor, _bias: &Self::Tensor, _rows: usize, _cols: usize, _eps: f32) -> Result<()>
FP16 layer normalization. Half I/O, FP32 reductions.
fn fused_scale_mask_softmax_f16(&self, _scores: &mut Self::Tensor, _mask: &Self::Tensor, _batch: usize, _num_heads: usize, _seq_len: usize, _scale: f32) -> Result<()>
FP16 fused scale + mask + softmax. Half scores, FP32 reductions.
fn fused_scale_mask_softmax_windowed_f16(&self, _scores: &mut Self::Tensor, _mask: &Self::Tensor, _batch: usize, _num_heads: usize, _seq_len: usize, _scale: f32, _window_size: usize) -> Result<()>
FP16 fused scale + mask + sliding window + softmax.
fn qkv_split_f16(&self, _q: &mut Self::Tensor, _k: &mut Self::Tensor, _v: &mut Self::Tensor, _qkv: &Self::Tensor, _batch: usize, _seq: usize, _hidden: usize, _num_heads: usize, _head_dim: usize) -> Result<()>
FP16 QKV split: [batch*seq, 3*hidden] into Q, K, V per-head layout.
fn attn_reshape_f16(&self, _output: &mut Self::Tensor, _input: &Self::Tensor, _batch: usize, _seq: usize, _num_heads: usize, _head_dim: usize) -> Result<()>
FP16 attention output reshape: [batch*num_heads, seq, head_dim] to
[batch*seq, hidden].
fn pad_to_batch_f16(&self, _flat: &Self::Tensor, _padded: &mut Self::Tensor, _seq_lengths: &[usize], _max_seq: usize, _dim: usize) -> Result<()>
FP16 scatter flat [total_tokens, dim] to padded [batch, max_seq, dim].
fn unpad_from_batch_f16(&self, _padded: &Self::Tensor, _flat: &mut Self::Tensor, _seq_lengths: &[usize], _max_seq: usize, _dim: usize) -> Result<()>
FP16 gather padded [batch, max_seq, dim] back to flat [total_tokens, dim].
fn rope_encode_f16(&self, _qk: &mut Self::Tensor, _cos: &Self::Tensor, _sin: &Self::Tensor, _num_rows: usize, _seq_len: usize, _head_dim: usize, _num_heads: usize) -> Result<()>
FP16 RoPE: apply rotary position embedding. Half Q/K, float cos/sin tables.
fn geglu_f16(&self, _value: &Self::Tensor, _gate: &Self::Tensor, _output: &mut Self::Tensor, _n: usize) -> Result<()>
FP16 GeGLU gated activation: output = gelu(value) * gate. Half I/O.
fn fused_residual_layernorm_f16(&self, _output: &mut Self::Tensor, _hidden: &Self::Tensor, _residual: &Self::Tensor, _weight: &Self::Tensor, _bias: &Self::Tensor, _rows: usize, _cols: usize, _eps: f32) -> Result<()>
FP16 fused residual add + layer normalization.
fn residual_add_f16(&self, _output: &mut Self::Tensor, _hidden: &Self::Tensor, _residual: &Self::Tensor, _n: usize) -> Result<()>
FP16 residual add (no bias): output = hidden + residual.
fn split_gate_value_f16(&self, _first: &mut Self::Tensor, _second: &mut Self::Tensor, _input: &Self::Tensor, _rows: usize, _cols: usize) -> Result<()>
FP16 split [rows, 2*cols] into two [rows, cols] halves.
fn fused_split_geglu_f16(&self, output: &mut Self::Tensor, input: &Self::Tensor, rows: usize, cols: usize) -> Result<()>
Fused split + GeGLU: read [rows, 2*cols], write [rows, cols].
Combines split_gate_value_f16 and
geglu_f16 into a single kernel, eliminating
two intermediate [rows, cols] buffers and halving HBM round-trips.
Default falls back to separate split + geglu calls.
fn fused_pad_qkv_split_f16(&self, q: &mut Self::Tensor, k: &mut Self::Tensor, v: &mut Self::Tensor, qkv_flat: &Self::Tensor, seq_lengths: &[usize], max_seq: usize, batch: usize, hidden: usize, num_heads: usize, head_dim: usize) -> Result<()>
Fused pad + QKV split: flat [total_tokens, 3*hidden] → Q, K, V
each [batch*heads, max_seq, head_dim].
Eliminates the padded intermediate buffer. Default calls pad then split.
fn fused_reshape_unpad_f16(&self, flat: &mut Self::Tensor, heads: &Self::Tensor, seq_lengths: &[usize], max_seq: usize, batch: usize, num_heads: usize, head_dim: usize) -> Result<()>
Fused attn_reshape + unpad: [batch*heads, max_seq, head_dim] →
[total_tokens, hidden].
Eliminates the padded context intermediate. Default calls reshape then unpad.