#[cfg(any(feature = "cpu", feature = "cpu-accelerate"))]
pub mod cpu;
#[cfg(feature = "cuda")]
pub mod cuda;
#[cfg(feature = "metal")]
pub mod metal;
#[cfg(feature = "mlx")]
pub mod mlx;
use super::Encoding;
pub trait Driver: Send + Sync {
type Tensor;
fn name(&self) -> &'static str;
fn new_for_clone() -> crate::Result<Self>
where
Self: Sized,
{
Err(crate::Error::Other(anyhow::anyhow!(
"this driver does not support cloning"
)))
}
fn begin_batch(&self) -> crate::Result<()> {
Ok(())
}
fn end_batch(&self) -> crate::Result<()> {
Ok(())
}
fn flush_batch(&self) -> crate::Result<()> {
Ok(())
}
fn segment_encoder(&self) {
}
fn save_pool_cursor(&self) -> usize {
0
}
fn restore_pool_cursor(&self, _saved: usize) {}
fn alloc_zeros(&self, n: usize) -> crate::Result<Self::Tensor>;
fn clone_tensor(&self, tensor: &Self::Tensor, n: usize) -> crate::Result<Self::Tensor>;
fn prepare_batch(
&self,
encodings: &[Encoding],
max_seq: usize,
) -> crate::Result<BatchInputs<Self::Tensor>>;
fn prepare_batch_unpadded(
&self,
encodings: &[Encoding],
) -> crate::Result<BatchInputs<Self::Tensor>> {
let max_seq = encodings
.iter()
.map(|e| e.input_ids.len())
.max()
.unwrap_or(0)
.next_multiple_of(8);
self.prepare_batch(encodings, max_seq)
}
fn pad_to_batch(
&self,
flat: &Self::Tensor,
padded: &mut Self::Tensor,
seq_lengths: &[usize],
max_seq: usize,
dim: usize,
) -> crate::Result<()>;
fn unpad_from_batch(
&self,
padded: &Self::Tensor,
flat: &mut Self::Tensor,
seq_lengths: &[usize],
max_seq: usize,
dim: usize,
) -> crate::Result<()>;
fn embedding_lookup(
&self,
word_ids: &Self::Tensor,
embedding_table: &Self::Tensor,
seq_len: usize,
hidden: usize,
) -> crate::Result<Self::Tensor>;
fn add_embeddings(
&self,
hidden: &mut Self::Tensor,
table: &Self::Tensor,
ids: &Self::Tensor,
seq_len: usize,
hidden_dim: usize,
) -> crate::Result<()>;
fn layer_norm(
&self,
output: &mut Self::Tensor,
input: &Self::Tensor,
weight: &Self::Tensor,
bias: &Self::Tensor,
rows: usize,
cols: usize,
eps: f32,
) -> crate::Result<()>;
fn gemm(
&self,
a: &Self::Tensor,
b: &Self::Tensor,
output: &mut Self::Tensor,
m: usize,
n: usize,
k: usize,
transpose_b: bool,
) -> crate::Result<()>;
fn gemm_batched(
&self,
a: &Self::Tensor,
b: &Self::Tensor,
output: &mut Self::Tensor,
m: usize,
n: usize,
k: usize,
transpose_b: bool,
stride_a: usize,
stride_b: usize,
stride_c: usize,
batch_count: usize,
) -> crate::Result<()>;
fn fused_scale_mask_softmax(
&self,
scores: &mut Self::Tensor,
mask: &Self::Tensor,
batch: usize,
num_heads: usize,
seq_len: usize,
scale: f32,
) -> crate::Result<()>;
fn fused_scale_mask_softmax_windowed(
&self,
scores: &mut Self::Tensor,
mask: &Self::Tensor,
batch: usize,
num_heads: usize,
seq_len: usize,
scale: f32,
window_size: usize,
) -> crate::Result<()>;
fn build_attn_mask(
&self,
output: &mut Self::Tensor,
int_mask: &Self::Tensor,
n: usize,
) -> crate::Result<()>;
fn qkv_split(
&self,
q: &mut Self::Tensor,
k: &mut Self::Tensor,
v: &mut Self::Tensor,
qkv: &Self::Tensor,
batch: usize,
seq: usize,
hidden: usize,
num_heads: usize,
head_dim: usize,
) -> crate::Result<()>;
fn banded_qk(
&self,
q: &Self::Tensor,
k: &Self::Tensor,
scores: &mut Self::Tensor,
batch_heads: usize,
seq: usize,
head_dim: usize,
window: usize,
stride_qk: usize,
stride_scores: usize,
) -> crate::Result<()>;
fn banded_sv(
&self,
scores: &Self::Tensor,
v: &Self::Tensor,
output: &mut Self::Tensor,
batch_heads: usize,
seq: usize,
head_dim: usize,
window: usize,
stride_scores: usize,
stride_v: usize,
stride_out: usize,
) -> crate::Result<()>;
fn banded_softmax(
&self,
scores: &mut Self::Tensor,
total_rows: usize,
window: usize,
scale: f32,
) -> crate::Result<()>;
fn attn_reshape(
&self,
output: &mut Self::Tensor,
input: &Self::Tensor,
batch: usize,
seq: usize,
num_heads: usize,
head_dim: usize,
) -> crate::Result<()>;
fn apply_rope(
&self,
qk: &mut Self::Tensor,
cos: &Self::Tensor,
sin: &Self::Tensor,
num_rows: usize,
seq_len: usize,
head_dim: usize,
num_heads: usize,
) -> crate::Result<()>;
fn split_gate_value(
&self,
first: &mut Self::Tensor,
second: &mut Self::Tensor,
input: &Self::Tensor,
rows: usize,
cols: usize,
) -> crate::Result<()>;
fn gelu(&self, x: &mut Self::Tensor, n: usize) -> crate::Result<()>;
fn swiglu(
&self,
value: &Self::Tensor,
gate: &Self::Tensor,
output: &mut Self::Tensor,
n: usize,
) -> crate::Result<()>;
fn geglu(
&self,
value: &Self::Tensor,
gate: &Self::Tensor,
output: &mut Self::Tensor,
n: usize,
) -> crate::Result<()>;
fn fused_bias_gelu(
&self,
x: &mut Self::Tensor,
bias: &Self::Tensor,
rows: usize,
cols: usize,
) -> crate::Result<()>;
fn fused_bias_residual(
&self,
output: &mut Self::Tensor,
input: &Self::Tensor,
bias: &Self::Tensor,
residual: &Self::Tensor,
n: usize,
cols: usize,
) -> crate::Result<()>;
fn fused_residual_layernorm(
&self,
output: &mut Self::Tensor,
hidden: &Self::Tensor,
residual: &Self::Tensor,
weight: &Self::Tensor,
bias: &Self::Tensor,
rows: usize,
cols: usize,
eps: f32,
) -> crate::Result<()>;
fn residual_add(
&self,
output: &mut Self::Tensor,
hidden: &Self::Tensor,
residual: &Self::Tensor,
n: usize,
) -> crate::Result<()>;
fn add_bias(
&self,
x: &mut Self::Tensor,
bias: &Self::Tensor,
rows: usize,
cols: usize,
) -> crate::Result<()>;
fn cls_pool(
&self,
output: &mut Self::Tensor,
hidden: &Self::Tensor,
batch: usize,
seq: usize,
hidden_dim: usize,
) -> crate::Result<()>;
fn mean_pool(
&self,
output: &mut Self::Tensor,
hidden: &Self::Tensor,
mask: &Self::Tensor,
batch: usize,
seq: usize,
hidden_dim: usize,
) -> crate::Result<()>;
fn l2_normalize(&self, data: &mut Self::Tensor, rows: usize, cols: usize) -> crate::Result<()>;
fn to_host(
&self,
tensor: &Self::Tensor,
batch: usize,
dim: usize,
) -> crate::Result<Vec<Vec<f32>>>;
fn debug_tensor(
&self,
_label: &str,
_tensor: &Self::Tensor,
_rows: usize,
_cols: usize,
) -> crate::Result<()> {
Ok(())
}
fn debug_tensors_enabled(&self) -> bool {
false
}
fn alloc_zeros_f16(&self, _n: usize) -> crate::Result<Self::Tensor> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn f32_to_f16(
&self,
_output: &mut Self::Tensor,
_input: &Self::Tensor,
_n: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn f16_to_f32(
&self,
_output: &mut Self::Tensor,
_input: &Self::Tensor,
_n: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn gemm_mixed(
&self,
_a_f16: &Self::Tensor,
_b_f16: &Self::Tensor,
_output_f32: &mut Self::Tensor,
_m: usize,
_n: usize,
_k: usize,
_transpose_b: bool,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"gemm_mixed not supported by this driver".into(),
))
}
fn gemm_f16(
&self,
_a: &Self::Tensor,
_b: &Self::Tensor,
_output: &mut Self::Tensor,
_m: usize,
_n: usize,
_k: usize,
_transpose_b: bool,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
#[expect(
clippy::too_many_arguments,
reason = "matches FP32 gemm_batched signature"
)]
fn gemm_batched_f16(
&self,
_a: &Self::Tensor,
_b: &Self::Tensor,
_output: &mut Self::Tensor,
_m: usize,
_n: usize,
_k: usize,
_transpose_b: bool,
_stride_a: usize,
_stride_b: usize,
_stride_c: usize,
_batch_count: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn layer_norm_f16(
&self,
_output: &mut Self::Tensor,
_input: &Self::Tensor,
_weight: &Self::Tensor,
_bias: &Self::Tensor,
_rows: usize,
_cols: usize,
_eps: f32,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn fused_scale_mask_softmax_f16(
&self,
_scores: &mut Self::Tensor,
_mask: &Self::Tensor,
_batch: usize,
_num_heads: usize,
_seq_len: usize,
_scale: f32,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn fused_scale_mask_softmax_windowed_f16(
&self,
_scores: &mut Self::Tensor,
_mask: &Self::Tensor,
_batch: usize,
_num_heads: usize,
_seq_len: usize,
_scale: f32,
_window_size: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn qkv_split_f16(
&self,
_q: &mut Self::Tensor,
_k: &mut Self::Tensor,
_v: &mut Self::Tensor,
_qkv: &Self::Tensor,
_batch: usize,
_seq: usize,
_hidden: usize,
_num_heads: usize,
_head_dim: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn attn_reshape_f16(
&self,
_output: &mut Self::Tensor,
_input: &Self::Tensor,
_batch: usize,
_seq: usize,
_num_heads: usize,
_head_dim: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn pad_to_batch_f16(
&self,
_flat: &Self::Tensor,
_padded: &mut Self::Tensor,
_seq_lengths: &[usize],
_max_seq: usize,
_dim: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn unpad_from_batch_f16(
&self,
_padded: &Self::Tensor,
_flat: &mut Self::Tensor,
_seq_lengths: &[usize],
_max_seq: usize,
_dim: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn rope_encode_f16(
&self,
_qk: &mut Self::Tensor,
_cos: &Self::Tensor,
_sin: &Self::Tensor,
_num_rows: usize,
_seq_len: usize,
_head_dim: usize,
_num_heads: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn geglu_f16(
&self,
_value: &Self::Tensor,
_gate: &Self::Tensor,
_output: &mut Self::Tensor,
_n: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn fused_residual_layernorm_f16(
&self,
_output: &mut Self::Tensor,
_hidden: &Self::Tensor,
_residual: &Self::Tensor,
_weight: &Self::Tensor,
_bias: &Self::Tensor,
_rows: usize,
_cols: usize,
_eps: f32,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn residual_add_f16(
&self,
_output: &mut Self::Tensor,
_hidden: &Self::Tensor,
_residual: &Self::Tensor,
_n: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn split_gate_value_f16(
&self,
_first: &mut Self::Tensor,
_second: &mut Self::Tensor,
_input: &Self::Tensor,
_rows: usize,
_cols: usize,
) -> crate::Result<()> {
Err(crate::Error::Metal(
"FP16 not supported by this driver".into(),
))
}
fn fused_split_geglu_f16(
&self,
output: &mut Self::Tensor,
input: &Self::Tensor,
rows: usize,
cols: usize,
) -> crate::Result<()> {
let n = rows * cols;
let mut value = self.alloc_zeros_f16(n)?;
let mut gate = self.alloc_zeros_f16(n)?;
self.split_gate_value_f16(&mut value, &mut gate, input, rows, cols)?;
self.geglu_f16(&value, &gate, output, n)
}
#[expect(clippy::too_many_arguments, reason = "mirrors pad + qkv_split args")]
fn fused_pad_qkv_split_f16(
&self,
q: &mut Self::Tensor,
k: &mut Self::Tensor,
v: &mut Self::Tensor,
qkv_flat: &Self::Tensor,
seq_lengths: &[usize],
max_seq: usize,
batch: usize,
hidden: usize,
num_heads: usize,
head_dim: usize,
) -> crate::Result<()> {
let padded_tokens = batch * max_seq;
let mut qkv_padded = self.alloc_zeros_f16(padded_tokens * 3 * hidden)?;
self.pad_to_batch_f16(qkv_flat, &mut qkv_padded, seq_lengths, max_seq, 3 * hidden)?;
self.qkv_split_f16(
q,
k,
v,
&qkv_padded,
batch,
max_seq,
hidden,
num_heads,
head_dim,
)
}
fn fused_reshape_unpad_f16(
&self,
flat: &mut Self::Tensor,
heads: &Self::Tensor,
seq_lengths: &[usize],
max_seq: usize,
batch: usize,
num_heads: usize,
head_dim: usize,
) -> crate::Result<()> {
let hidden = num_heads * head_dim;
let padded_tokens = batch * max_seq;
let mut context = self.alloc_zeros_f16(padded_tokens * hidden)?;
self.attn_reshape_f16(&mut context, heads, batch, max_seq, num_heads, head_dim)?;
self.unpad_from_batch_f16(&context, flat, seq_lengths, max_seq, hidden)
}
}
/// Bundle of per-batch inputs returned by `Driver::prepare_batch` and
/// `Driver::prepare_batch_unpadded`, generic over the backend tensor type `T`.
///
/// NOTE(review): nothing in this file populates these fields, so the field
/// notes below are read off the names — confirm layouts against a backend.
pub struct BatchInputs<T> {
/// Token-id tensor for the batch.
pub input_ids: T,
/// Attention-mask tensor for the batch.
pub attention_mask: T,
/// Token-type-id tensor for the batch.
pub token_type_ids: T,
/// Position-id tensor for the batch.
pub position_ids: T,
/// Presumably a floating-point form of the mask — TODO confirm.
pub float_mask: T,
/// Presumably the mask consumed by the pooling step — TODO confirm.
pub pooling_mask: T,
/// Batch size; presumably the number of encodings — TODO confirm.
pub batch: usize,
/// Sequence length the batch was prepared for.
pub max_seq: usize,
/// Presumably the sum of all sequence lengths — TODO confirm.
pub total_tokens: usize,
/// Per-sequence lengths.
pub seq_lengths: Vec<usize>,
/// Optional; presumably cumulative sequence offsets — TODO confirm when it is `Some`.
pub cu_seqlens: Option<Vec<usize>>,
}