use ferrum_types::{FerrumError, Result};
pub use super::capabilities::{
BackendCollective, BackendGraph, BackendMoeFused, BackendQuantGguf, BackendQuantMarlin,
};
pub use super::types::MoeRouting;
use super::types::{AttnConfig, KvCacheQuant, SrcDtype};
/// Layer-count limit (64) associated with the graph execution path.
///
/// NOTE(review): the code that enforces this limit lives outside this file — confirm
/// whether it is an inclusive maximum or a table size.
pub const MAX_LAYERS_FOR_GRAPH: usize = 1 << 6;
/// Core compute-backend abstraction: buffer management plus the kernel entry points
/// used by model code.
///
/// Methods without a body must be provided by every backend. Methods with a default
/// body either fall back to a host round-trip (`to_vec` / `from_slice`), report that the
/// operation is unsupported (`Err`), or panic. Each method's docs say which.
pub trait Backend: Send + Sync + Sized + 'static {
    /// Opaque storage handle for tensor data.
    type Buffer: Send + Sync;
    /// Execution context passed mutably to every operation.
    type Context;
    /// Timer implementation for this backend.
    type Timer: super::timer::BackendTimer<Self>;

    /// Creates a new timer.
    fn make_timer() -> Self::Timer;

    /// Creates a fresh execution context.
    fn new_context() -> Self::Context;

    /// Runs `body` scoped to `device_ordinal`.
    ///
    /// The default ignores the ordinal and simply calls `body`. See
    /// [`Backend::supports_device_ordinal_scope`].
    fn with_device_ordinal<R>(_device_ordinal: Option<usize>, body: impl FnOnce() -> R) -> R {
        body()
    }

    /// Whether [`Backend::with_device_ordinal`] actually honours the ordinal.
    /// Default: `false`.
    fn supports_device_ordinal_scope() -> bool {
        false
    }

    /// Synchronizes outstanding work on `ctx` (exact semantics are backend-defined).
    fn sync(ctx: &mut Self::Context);

    /// Synchronization hook for reading buffers back to the host. The default does
    /// nothing.
    fn sync_before_host_readback(_ctx: &mut Self::Context) {}

    /// Size in bytes of one activation element. Default: `size_of::<f16>()` (2 bytes).
    fn activation_elem_size_bytes() -> usize {
        std::mem::size_of::<half::f16>()
    }

    /// Whether the Llama-family batched decode path may be used. Default: `true`.
    fn supports_llama_family_batched_decode() -> bool {
        true
    }

    /// Zero-fills `len` elements of `buf`.
    ///
    /// # Errors
    /// The default implementation always returns an unsupported error.
    fn zero_buffer(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize) -> Result<()> {
        Err(FerrumError::unsupported(
            "zero_buffer not implemented for this backend",
        ))
    }

    /// Allocates a buffer of `n` elements of `dtype`.
    fn alloc_typed(dtype: super::Dtype, n: usize) -> Self::Buffer;

    /// Creates a buffer initialized from host data of element type `T`.
    fn from_slice_typed<T: super::HostDtype>(data: &[T]) -> Self::Buffer;

    /// Writes host data of element type `T` into `dst`.
    fn write_typed<T: super::HostDtype>(
        ctx: &mut Self::Context,
        dst: &mut Self::Buffer,
        data: &[T],
    );

    /// Matrix multiply of `a` and `b` into `out` with problem sizes `m`, `n`, `k`.
    ///
    /// NOTE(review): operand layout (row-major? `b` transposed?) is backend-defined and
    /// not visible here — confirm against an implementation.
    fn gemm(
        ctx: &mut Self::Context,
        a: &Self::Buffer,
        b: &Self::Buffer,
        out: &mut Self::Buffer,
        m: usize,
        n: usize,
        k: usize,
    );

    /// RMS normalization of `tokens` rows of width `dim` from `x`, scaled by weight `w`,
    /// written to `out`. `eps` is the numerical-stability epsilon.
    fn rms_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        w: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );

    /// Fused residual-add + RMS norm: presumably `residual += x`, then
    /// `out = rms_norm(residual, w, eps)` over `tokens` rows of width `dim`
    /// (TODO confirm against an implementation).
    #[allow(clippy::too_many_arguments)]
    fn fused_add_rms_norm(
        ctx: &mut Self::Context,
        residual: &mut Self::Buffer,
        x: &Self::Buffer,
        w: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );

    /// Attention of `q` over `k`/`v` into `out` for `batch` sequences with `q_len`
    /// query positions and `kv_len` key/value positions.
    ///
    /// `pos_offset` is the position of the first query token. `cfg` carries the
    /// remaining attention configuration.
    #[allow(clippy::too_many_arguments)]
    fn flash_attention(
        ctx: &mut Self::Context,
        q: &Self::Buffer,
        k: &Self::Buffer,
        v: &Self::Buffer,
        out: &mut Self::Buffer,
        batch: usize,
        q_len: usize,
        kv_len: usize,
        pos_offset: usize,
        cfg: &AttnConfig,
    );

    /// Multi-head latent attention over a compressed KV (`kv_compressed`, rank
    /// `kv_lora_rank`) plus its rotary part (`kv_rope`, `qk_rope_head_dim` wide).
    ///
    /// # Errors
    /// The default implementation always returns an unsupported error. The operation is
    /// required by DeepSeek V2/V3.
    #[allow(clippy::too_many_arguments)]
    fn mla_attention(
        _ctx: &mut Self::Context,
        _q: &Self::Buffer,
        _kv_compressed: &Self::Buffer,
        _kv_rope: &Self::Buffer,
        _out: &mut Self::Buffer,
        _batch: usize,
        _q_len: usize,
        _kv_len: usize,
        _pos_offset: usize,
        _cfg: &AttnConfig,
        _kv_lora_rank: usize,
        _qk_rope_head_dim: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "mla_attention not implemented for this backend; required by \
             DeepSeek V2/V3 (Phase D/E)",
        ))
    }

    /// Copies `len` elements from `src` starting at `src_offset` into `dst` starting at
    /// `dst_offset`.
    fn copy_slice(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        src_offset: usize,
        dst: &mut Self::Buffer,
        dst_offset: usize,
        len: usize,
    );

    /// Gathers the `table` rows (each `dim` wide) selected by host-side `ids` into `out`.
    fn embedding_lookup(
        ctx: &mut Self::Context,
        table: &Self::Buffer,
        ids: &[u32],
        out: &mut Self::Buffer,
        dim: usize,
    );

    /// Like [`Backend::embedding_lookup`], but with `batch` token ids stored in a
    /// backend buffer.
    ///
    /// The default reads `ids` back with `to_vec` and reinterprets each returned f32's
    /// bit pattern as a `u32` id. It then delegates to `embedding_lookup`.
    /// NOTE(review): this assumes the ids buffer holds raw u32 bit patterns and that
    /// `to_vec` preserves bits for it — confirm per backend.
    fn embedding_lookup_dev(
        ctx: &mut Self::Context,
        table: &Self::Buffer,
        ids: &Self::Buffer,
        out: &mut Self::Buffer,
        batch: usize,
        dim: usize,
    ) {
        let ids_host_f32 = Self::to_vec(ids, batch);
        let ids_host_u32: Vec<u32> = ids_host_f32.iter().map(|x| x.to_bits()).collect();
        Self::embedding_lookup(ctx, table, &ids_host_u32, out, dim);
    }

    /// Splits a fused `qkv` projection of `tokens` rows into `q` (`q_dim` wide) and
    /// `k`/`v` (`kv_dim` wide each).
    #[allow(clippy::too_many_arguments)]
    fn split_qkv(
        ctx: &mut Self::Context,
        qkv: &Self::Buffer,
        q: &mut Self::Buffer,
        k: &mut Self::Buffer,
        v: &mut Self::Buffer,
        tokens: usize,
        q_dim: usize,
        kv_dim: usize,
    );

    /// Gated activation from the fused `gate_up` buffer into `out` for `tokens` rows with
    /// intermediate size `im` (presumably `silu(gate) * up` — TODO confirm gate/up order).
    fn fused_silu_mul_split(
        ctx: &mut Self::Context,
        gate_up: &Self::Buffer,
        out: &mut Self::Buffer,
        tokens: usize,
        im: usize,
    );

    /// Per-head normalization (`norm_w`, `eps`) followed by rotary embedding using the
    /// `cos`/`sin` tables, from `input` into `output`.
    ///
    /// NOTE(review): the meaning of `mode` values is not visible here.
    #[allow(clippy::too_many_arguments)]
    fn qk_norm_rope(
        ctx: &mut Self::Context,
        input: &Self::Buffer,
        norm_w: &Self::Buffer,
        cos: &Self::Buffer,
        sin: &Self::Buffer,
        output: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        head_dim: usize,
        pos_offset: usize,
        eps: f32,
        mode: i32,
    );

    /// Batched KV-cache append where each of the `m` items has its own cache.
    ///
    /// # Errors
    /// The default implementation always returns an unsupported error.
    #[allow(clippy::too_many_arguments)]
    fn kv_cache_append_batched_per_cache(
        _ctx: &mut Self::Context,
        _caches: &[&Self::Buffer],
        _new_data: &Self::Buffer,
        _cache_lens: &Self::Buffer,
        _capacity: usize,
        _m: usize,
        _nkv: usize,
        _hd: usize,
        _slot: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "kv_cache_append_batched_per_cache not implemented for this backend",
        ))
    }

    /// Batched attention where each item reads from its own K/V cache.
    ///
    /// # Errors
    /// The default implementation always returns an unsupported error.
    #[allow(clippy::too_many_arguments)]
    fn flash_attention_batched_per_cache(
        _ctx: &mut Self::Context,
        _q: &Self::Buffer,
        _k_caches: &[&Self::Buffer],
        _v_caches: &[&Self::Buffer],
        _kv_lens: &Self::Buffer,
        _out: &mut Self::Buffer,
        _nq: usize,
        _nkv: usize,
        _hd: usize,
        _scale: f32,
        _max_valid_kv: usize,
        _capacity: usize,
        _slot: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "flash_attention_batched_per_cache not implemented for this backend",
        ))
    }

    /// Batched variant of [`Backend::qk_norm_rope`] with a per-item position buffer.
    ///
    /// # Errors
    /// The default implementation always returns an unsupported error.
    #[allow(clippy::too_many_arguments)]
    fn qk_norm_rope_batched_per_item(
        _ctx: &mut Self::Context,
        _input: &Self::Buffer,
        _norm_w: &Self::Buffer,
        _cos: &Self::Buffer,
        _sin: &Self::Buffer,
        _output: &mut Self::Buffer,
        _positions: &Self::Buffer,
        _m: usize,
        _heads: usize,
        _head_dim: usize,
        _eps: f32,
        _mode: i32,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "qk_norm_rope_batched_per_item not implemented for this backend",
        ))
    }

    /// Fused QKV split + Q/K norm + rotary embedding into separate `q`/`k`/`v` outputs.
    ///
    /// # Errors
    /// The default implementation always returns an unsupported error.
    #[allow(clippy::too_many_arguments)]
    fn split_qkv_norm_rope(
        _ctx: &mut Self::Context,
        _qkv: &Self::Buffer,
        _q_norm_w: &Self::Buffer,
        _k_norm_w: &Self::Buffer,
        _cos: &Self::Buffer,
        _sin: &Self::Buffer,
        _q_out: &mut Self::Buffer,
        _k_out: &mut Self::Buffer,
        _v_out: &mut Self::Buffer,
        _tokens: usize,
        _q_heads: usize,
        _kv_heads: usize,
        _head_dim: usize,
        _pos_offset: usize,
        _eps: f32,
        _qk_mode: i32,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "split_qkv_norm_rope not implemented for this backend",
        ))
    }

    /// Like [`Backend::split_qkv_norm_rope`], but K/V are written directly into the
    /// caches at `cache_len` (capacity `cache_capacity`).
    ///
    /// # Errors
    /// The default implementation always returns an unsupported error.
    #[allow(clippy::too_many_arguments)]
    fn split_qkv_norm_rope_into_cache(
        _ctx: &mut Self::Context,
        _qkv: &Self::Buffer,
        _q_norm_w: &Self::Buffer,
        _k_norm_w: &Self::Buffer,
        _cos: &Self::Buffer,
        _sin: &Self::Buffer,
        _q_out: &mut Self::Buffer,
        _cache_k: &mut Self::Buffer,
        _cache_v: &mut Self::Buffer,
        _tokens: usize,
        _q_heads: usize,
        _kv_heads: usize,
        _head_dim: usize,
        _pos_offset: usize,
        _eps: f32,
        _qk_mode: i32,
        _cache_len: usize,
        _cache_capacity: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "split_qkv_norm_rope_into_cache not implemented for this backend",
        ))
    }

    /// Appends `new_tokens` tokens of head-major K/V (`nkv` heads × `hd`) to the caches
    /// starting at `cache_len`, within `cache_capacity`.
    #[allow(clippy::too_many_arguments)]
    fn kv_cache_append_head_major(
        ctx: &mut Self::Context,
        cache_k: &mut Self::Buffer,
        cache_v: &mut Self::Buffer,
        cache_len: usize,
        cache_capacity: usize,
        new_k_head_major: &Self::Buffer,
        new_v_head_major: &Self::Buffer,
        new_tokens: usize,
        nkv: usize,
        hd: usize,
    );

    /// Re-lays out `src` from head-major to token-major order into `dst`.
    fn transpose_head_to_token(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        dst: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        dim: usize,
    );

    /// Re-lays out `src` from token-major to head-major order into `dst`.
    ///
    /// # Panics
    /// The default implementation always panics.
    fn transpose_token_to_head(
        _ctx: &mut Self::Context,
        _src: &Self::Buffer,
        _dst: &mut Self::Buffer,
        _tokens: usize,
        _heads: usize,
        _dim: usize,
    ) {
        panic!("transpose_token_to_head not implemented for this backend");
    }

    /// Element-wise `residual += x` over `len` elements.
    fn add_inplace(
        ctx: &mut Self::Context,
        residual: &mut Self::Buffer,
        x: &Self::Buffer,
        len: usize,
    );

    /// Element-wise `dst[i] += scale * src[i]` for the first `len` elements.
    ///
    /// The default does a host round-trip. It reads both buffers with `to_vec`,
    /// accumulates in f32, then writes the result back into `dst` in place through
    /// [`Backend::write_f32_to_activation`]. `dst` keeps its allocation. Elements past
    /// `len` are untouched, and `len == 0` is a no-op.
    fn scaled_add_inplace(
        ctx: &mut Self::Context,
        dst: &mut Self::Buffer,
        src: &Self::Buffer,
        scale: f32,
        len: usize,
    ) {
        let mut dst_v = Self::to_vec(dst, len);
        let src_v = Self::to_vec(src, len);
        for (d, s) in dst_v.iter_mut().zip(&src_v) {
            *d += scale * s;
        }
        // Write back in place. Replacing `*dst` with `Self::from_slice(&dst_v)` would
        // shrink the caller's buffer to exactly `len` elements and drop its allocation.
        Self::write_f32_to_activation(ctx, dst, &dst_v);
    }

    /// Row-offset variant of [`Backend::fused_silu_mul_split`].
    ///
    /// # Panics
    /// The default implementation always panics (`unimplemented!`).
    #[allow(clippy::too_many_arguments)]
    fn fused_silu_mul_split_strided(
        _ctx: &mut Self::Context,
        _gate_up: &Self::Buffer,
        _in_row_offset: usize,
        _out: &mut Self::Buffer,
        _out_row_offset: usize,
        _tokens: usize,
        _intermediate: usize,
    ) {
        unimplemented!("fused_silu_mul_split_strided default impl missing");
    }

    /// Adds `bias` to `data`, which has `rows` × `cols` elements (presumably `bias` is
    /// `cols` wide and broadcast over rows — TODO confirm).
    fn add_bias(
        ctx: &mut Self::Context,
        data: &mut Self::Buffer,
        bias: &Self::Buffer,
        rows: usize,
        cols: usize,
    );

    /// Layer normalization of `tokens` rows of width `dim` with scale `gamma`, shift
    /// `beta` and epsilon `eps`, written to `out`.
    #[allow(clippy::too_many_arguments)]
    fn layer_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        gamma: &Self::Buffer,
        beta: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );

    /// Element-wise GELU of `len` elements of `x` into `out`.
    fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize);

    /// Allocates a buffer holding `len` elements.
    fn alloc(len: usize) -> Self::Buffer;

    /// Reads `len` elements of `buf` back to the host as f32.
    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;

    /// Creates a buffer from f32 host data.
    fn from_slice(data: &[f32]) -> Self::Buffer;

    /// Writes f32 host `data` into the start of `dst` (`data.len()` elements).
    ///
    /// The default uploads via `from_slice` and then `copy_slice` into `dst`. Empty
    /// `data` is a no-op.
    fn write_f32_to_activation(ctx: &mut Self::Context, dst: &mut Self::Buffer, data: &[f32]) {
        if data.is_empty() {
            return;
        }
        let src = Self::from_slice(data);
        Self::copy_slice(ctx, &src, 0, dst, 0, data.len());
    }

    /// Row-wise argmax over an `m` × `n` logits buffer, returning one column index per
    /// row.
    ///
    /// The default reads the logits to the host with `to_vec`. The first maximum wins
    /// ties. NaN values never compare greater, so they are skipped. A row whose values
    /// are all `-inf`/NaN, or an empty row (`n == 0`), yields index 0.
    ///
    /// # Errors
    /// The default implementation never returns an error.
    fn argmax_rows_f16(
        _ctx: &mut Self::Context,
        logits: &Self::Buffer,
        m: usize,
        n: usize,
    ) -> Result<Vec<u32>> {
        let host = Self::to_vec(logits, m * n);
        let mut out = Vec::with_capacity(m);
        for row in 0..m {
            let slice = &host[row * n..(row + 1) * n];
            let mut max_idx = 0usize;
            let mut max_val = f32::NEG_INFINITY;
            for (i, &v) in slice.iter().enumerate() {
                if v > max_val {
                    max_val = v;
                    max_idx = i;
                }
            }
            out.push(max_idx as u32);
        }
        Ok(out)
    }

    /// Builds a buffer from raw weight bytes stored as `src_dtype`.
    ///
    /// The default converts to f32 on the host (`SrcDtype::to_f32_vec`) and uploads with
    /// `from_slice`.
    fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer {
        let data = src_dtype.to_f32_vec(raw);
        Self::from_slice(&data)
    }
}
pub trait BackendPagedKv: Backend {
fn supports_paged_kv() -> bool {
false
}
fn populate_batched_pointers(
_ctx: &mut Self::Context,
_k_caches: &[&Self::Buffer],
_v_caches: &[&Self::Buffer],
_num_layers: usize,
_m: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"populate_batched_pointers not implemented for this backend",
))
}
#[allow(clippy::too_many_arguments)]
fn split_qkv_norm_rope_into_paged_cache(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_qkv_byte_offset: u64,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_q_out_byte_offset: u64,
_cache_k: &mut Self::Buffer,
_cache_v: &mut Self::Buffer,
_block_table: &Self::Buffer,
_tokens: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_pos_offset: usize,
_eps: f32,
_qk_mode: i32,
_cache_len: usize,
_block_size: usize,
_max_num_blocks_per_seq: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"split_qkv_norm_rope_into_paged_cache not implemented for this backend",
))
}
#[allow(clippy::too_many_arguments)]
fn paged_decode_attention(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_pool: &Self::Buffer,
_v_pool: &Self::Buffer,
_out: &mut Self::Buffer,
_block_tables: &Self::Buffer,
_context_lens: &Self::Buffer,
_num_seqs: usize,
_num_heads: usize,
_num_kv_heads: usize,
_head_dim: usize,
_block_size: usize,
_max_num_blocks_per_seq: usize,
_q_len: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"paged_decode_attention not implemented for this backend",
))
}
fn supports_varlen_qkv() -> bool {
false
}
#[allow(clippy::too_many_arguments)]
fn split_qkv_norm_rope_into_paged_cache_varlen(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_cache_k: &mut Self::Buffer,
_cache_v: &mut Self::Buffer,
_cu_seqlens_q: &Self::Buffer,
_pos_offsets: &Self::Buffer,
_block_tables: &Self::Buffer,
_num_seqs: usize,
_m_total: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_eps: f32,
_qk_mode: i32,
_block_size: usize,
_max_blocks_per_seq: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"split_qkv_norm_rope_into_paged_cache_varlen not implemented for this backend",
))
}
#[allow(clippy::too_many_arguments)]
fn paged_varlen_attention(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_pool: &Self::Buffer,
_v_pool: &Self::Buffer,
_out: &mut Self::Buffer,
_cu_seqlens_q: &Self::Buffer,
_pos_offsets: &Self::Buffer,
_block_tables: &Self::Buffer,
_num_seqs: usize,
_total_q_tokens: usize,
_max_kv_len: usize,
_num_heads: usize,
_num_kv_heads: usize,
_head_dim: usize,
_block_size: usize,
_max_num_blocks_per_seq: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"paged_varlen_attention not implemented for this backend",
))
}
#[allow(clippy::too_many_arguments)]
fn paged_varlen_attention_fa2_ffi(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_pool: &Self::Buffer,
_v_pool: &Self::Buffer,
_out: &mut Self::Buffer,
_lse: &mut Self::Buffer,
_cu_seqlens_q: &Self::Buffer,
_seq_lens: &Self::Buffer,
_block_tables: &Self::Buffer,
_num_seqs: usize,
_total_q_tokens: usize,
_max_q_len: usize,
_max_kv_len: usize,
_num_heads: usize,
_num_kv_heads: usize,
_head_dim: usize,
_block_size: usize,
_max_num_blocks_per_seq: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"paged_varlen_attention_fa2_ffi not implemented for this backend",
))
}
#[allow(clippy::too_many_arguments)]
fn paged_batched_decode_attention(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_pool: &Self::Buffer,
_v_pool: &Self::Buffer,
_out: &mut Self::Buffer,
_block_tables: &Self::Buffer,
_valid_kv_lens: &Self::Buffer,
_num_seqs: usize,
_max_kv_len: usize,
_num_heads: usize,
_num_kv_heads: usize,
_head_dim: usize,
_block_size: usize,
_max_num_blocks_per_seq: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"paged_batched_decode_attention not implemented for this backend",
))
}
fn supports_vllm_paged_attn() -> bool {
false
}
#[allow(clippy::too_many_arguments)]
fn split_qkv_norm_rope_into_paged_cache_vllm(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_qkv_byte_offset: u64,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_q_out_byte_offset: u64,
_cache_k: &mut Self::Buffer,
_cache_v: &mut Self::Buffer,
_block_table: &Self::Buffer,
_tokens: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_pos_offset: usize,
_eps: f32,
_qk_mode: i32,
_cache_len: usize,
_block_size: usize,
_max_num_blocks_per_seq: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"split_qkv_norm_rope_into_paged_cache_vllm not implemented for this backend",
))
}
#[allow(clippy::too_many_arguments)]
fn split_qkv_norm_rope_into_paged_cache_varlen_vllm(
_ctx: &mut Self::Context,
_qkv: &Self::Buffer,
_q_norm_w: &Self::Buffer,
_k_norm_w: &Self::Buffer,
_cos: &Self::Buffer,
_sin: &Self::Buffer,
_q_out: &mut Self::Buffer,
_cache_k: &mut Self::Buffer,
_cache_v: &mut Self::Buffer,
_cu_seqlens_q: &Self::Buffer,
_pos_offsets: &Self::Buffer,
_block_tables: &Self::Buffer,
_num_seqs: usize,
_m_total: usize,
_q_heads: usize,
_kv_heads: usize,
_head_dim: usize,
_eps: f32,
_qk_mode: i32,
_block_size: usize,
_max_blocks_per_seq: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"split_qkv_norm_rope_into_paged_cache_varlen_vllm not implemented for this backend",
))
}
#[allow(clippy::too_many_arguments)]
fn paged_decode_attention_v2(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_pool: &Self::Buffer,
_v_pool: &Self::Buffer,
_out: &mut Self::Buffer,
_block_tables: &Self::Buffer,
_context_lens: &Self::Buffer,
_num_seqs: usize,
_num_heads: usize,
_num_kv_heads: usize,
_head_dim: usize,
_block_size: usize,
_max_num_blocks_per_seq: usize,
_max_seq_len: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"paged_decode_attention_v2 not implemented for this backend",
))
}
#[allow(clippy::too_many_arguments)]
fn paged_varlen_attention_vllm_layout(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_pool: &Self::Buffer,
_v_pool: &Self::Buffer,
_out: &mut Self::Buffer,
_block_tables: &Self::Buffer,
_context_lens: &Self::Buffer,
_num_seqs: usize,
_num_heads: usize,
_num_kv_heads: usize,
_head_dim: usize,
_block_size: usize,
_max_num_blocks_per_seq: usize,
_q_len: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"paged_varlen_attention_vllm_layout not implemented for this backend",
))
}
#[allow(clippy::too_many_arguments)]
fn paged_varlen_attention_vllm(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_pool: &Self::Buffer,
_v_pool: &Self::Buffer,
_out: &mut Self::Buffer,
_cu_seqlens_q: &Self::Buffer,
_pos_offsets: &Self::Buffer,
_block_tables: &Self::Buffer,
_num_seqs: usize,
_total_q_tokens: usize,
_max_kv_len: usize,
_num_heads: usize,
_num_kv_heads: usize,
_head_dim: usize,
_block_size: usize,
_max_num_blocks_per_seq: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"paged_varlen_attention_vllm not implemented for this backend",
))
}
#[allow(clippy::too_many_arguments)]
fn paged_varlen_attention_vllm_tiled_q4(
_ctx: &mut Self::Context,
_q: &Self::Buffer,
_k_pool: &Self::Buffer,
_v_pool: &Self::Buffer,
_out: &mut Self::Buffer,
_cu_seqlens_q: &Self::Buffer,
_pos_offsets: &Self::Buffer,
_block_tables: &Self::Buffer,
_tile_seqs: &Self::Buffer,
_tile_starts: &Self::Buffer,
_num_tiles: usize,
_max_kv_len: usize,
_num_heads: usize,
_num_kv_heads: usize,
_head_dim: usize,
_block_size: usize,
_max_num_blocks_per_seq: usize,
) -> Result<()> {
Err(FerrumError::unsupported(
"paged_varlen_attention_vllm_tiled_q4 not implemented for this backend",
))
}
}
/// Convenience bound: [`Backend`] + [`BackendGraph`] + [`BackendPagedKv`].
/// Blanket-implemented for every type that satisfies all three.
pub trait LlmBackend: Backend + BackendGraph + BackendPagedKv {}
impl<B: Backend + BackendGraph + BackendPagedKv> LlmBackend for B {}

/// Convenience bound: [`LlmBackend`] plus both quantized-weight capabilities
/// ([`BackendQuantMarlin`], [`BackendQuantGguf`]). Blanket-implemented.
pub trait QuantLlmBackend: LlmBackend + BackendQuantMarlin + BackendQuantGguf {}
impl<B: LlmBackend + BackendQuantMarlin + BackendQuantGguf> QuantLlmBackend for B {}

/// Convenience bound: [`QuantLlmBackend`] plus [`BackendMoeFused`]. Blanket-implemented.
pub trait MoeLlmBackend: QuantLlmBackend + BackendMoeFused {}
impl<B: QuantLlmBackend + BackendMoeFused> MoeLlmBackend for B {}
pub use ferrum_interfaces::kv_dtype::{KvBf16, KvDtypeKind, KvFp16, KvFp8, KvInt8};
/// Associates a paged-KV backend with the storage types it uses for KV caches whose
/// element kind is `K`.
pub trait BackendKvDtype<K>: BackendPagedKv
where
    K: KvDtypeKind,
{
    /// Buffer type holding KV-cache data for kind `K`.
    type KvBuffer: Send + Sync;
    /// Companion per-cache scale storage for kind `K` (presumably quantization scales —
    /// TODO confirm). It must be constructible via `Default`.
    type KvScales: Send + Sync + Default;
}
// Shorthands for the INT8 KV associated types. They keep the signatures below readable
// and expand to exactly the same projections.
type Int8KvBuf<B> = <B as BackendKvDtype<KvInt8>>::KvBuffer;
type Int8KvScales<B> = <B as BackendKvDtype<KvInt8>>::KvScales;

/// INT8-quantized paged KV-cache operations.
///
/// All methods have defaults that either panic (`alloc_paged_int8_layer`) or return an
/// unsupported error. Backends override the ones they implement.
#[allow(clippy::too_many_arguments)]
pub trait BackendInt8KvOps: Backend + BackendKvDtype<KvInt8> {
    /// Allocates one layer of INT8 paged KV cache.
    ///
    /// # Panics
    /// The default implementation always panics (`unimplemented!`).
    fn alloc_paged_int8_layer(
        _max_blocks_per_seq: usize,
        _block_size: usize,
        _num_kv_heads: usize,
        _head_dim: usize,
    ) -> KvCacheQuant<Self, KvInt8> {
        unimplemented!("alloc_paged_int8_layer not supported on this backend")
    }

    /// Appends `tokens` new K/V rows (`k_in`/`v_in`) into the INT8 layer caches and their
    /// scale storage. `paged_block_indices` selects the blocks, and the write starts
    /// after `cache_len_before` existing tokens.
    ///
    /// # Errors
    /// The default implementation always returns an unsupported error.
    fn int8_kv_append_paged(
        _ctx: &mut Self::Context,
        _k_in: &Self::Buffer,
        _v_in: &Self::Buffer,
        _layer_k: &mut Int8KvBuf<Self>,
        _layer_v: &mut Int8KvBuf<Self>,
        _layer_k_scales: &mut Int8KvScales<Self>,
        _layer_v_scales: &mut Int8KvScales<Self>,
        _paged_block_indices: &[u32],
        _cache_len_before: usize,
        _tokens: usize,
        _block_size: usize,
        _num_kv_heads: usize,
        _head_dim: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "int8_kv_append_paged not implemented for this backend",
        ))
    }

    /// Decode attention of `q` against an INT8 paged layer cache (via `block_table`),
    /// over `valid_kv_len` positions with softmax scale `scale`, written to `output`.
    ///
    /// # Errors
    /// The default implementation always returns an unsupported error.
    fn int8_paged_decode_attention(
        _ctx: &mut Self::Context,
        _q: &Self::Buffer,
        _layer_k: &Int8KvBuf<Self>,
        _layer_v: &Int8KvBuf<Self>,
        _layer_k_scales: &Int8KvScales<Self>,
        _layer_v_scales: &Int8KvScales<Self>,
        _block_table: &Self::Buffer,
        _output: &mut Self::Buffer,
        _num_q_heads: usize,
        _num_kv_heads: usize,
        _head_dim: usize,
        _valid_kv_len: usize,
        _block_size: usize,
        _scale: f32,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "int8_paged_decode_attention not implemented for this backend",
        ))
    }
}