use ferrum_types::{FerrumError, Result};
use half::{bf16, f16};
/// Element type of raw weight bytes as stored on disk, before upload to a backend.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SrcDtype {
    /// 32-bit IEEE-754 float.
    F32,
    /// 16-bit IEEE-754 half-precision float.
    F16,
    /// 16-bit bfloat16.
    BF16,
}
impl SrcDtype {
    /// Size in bytes of one element stored in this dtype.
    pub const fn bytes_per_elem(self) -> usize {
        match self {
            SrcDtype::F32 => 4,
            SrcDtype::F16 | SrcDtype::BF16 => 2,
        }
    }

    /// Decodes a little-endian byte buffer of this dtype into `f32` values.
    ///
    /// `raw.len()` is expected to be a multiple of [`SrcDtype::bytes_per_elem`]. This is
    /// only checked with `debug_assert_eq!`; in release builds a trailing partial element
    /// is silently ignored.
    pub fn to_f32_vec(self, raw: &[u8]) -> Vec<f32> {
        let width = self.bytes_per_elem();
        debug_assert_eq!(raw.len() % width, 0);
        // `chunks_exact` yields only whole elements and lets the compiler drop the
        // per-byte bounds checks the manual index loop needed.
        let elems = raw.chunks_exact(width);
        match self {
            SrcDtype::F32 => elems
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect(),
            SrcDtype::F16 => elems
                .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
                .collect(),
            SrcDtype::BF16 => elems
                .map(|c| bf16::from_le_bytes([c[0], c[1]]).to_f32())
                .collect(),
        }
    }
}
/// Quantization scheme of a weight tensor together with its scheme-specific parameters.
#[derive(Debug, Clone)]
pub enum QuantKind {
    /// GPTQ-quantized weights.
    Gptq {
        /// Bits per quantized value.
        bits: u32,
        /// Number of input elements sharing one scale/zero group.
        group_size: usize,
        /// Presumably the GPTQ "act-order" flag — confirm against the loader.
        desc_act: bool,
    },
    /// AWQ-quantized weights.
    Awq {
        /// Bits per quantized value.
        bits: u32,
        /// Number of input elements sharing one scale/zero group.
        group_size: usize,
    },
    /// GGUF block-quantized weights.
    Gguf {
        /// Concrete GGUF block format.
        quant_type: GgufQuantType,
    },
}
/// GGUF quantization block formats accepted by the `Backend::load_quant*` loaders.
#[derive(Debug, Clone, Copy)]
pub enum GgufQuantType {
    /// GGUF `Q4_0`.
    Q4_0,
    /// GGUF `Q4_1`.
    Q4_1,
    /// GGUF `Q4_K`.
    Q4K,
    /// GGUF `Q5_K`.
    Q5K,
    /// GGUF `Q6_K`.
    Q6K,
    /// GGUF `Q8_0`.
    Q8_0,
}
/// Borrowed view of the buffers that make up one quantized weight tensor.
///
/// Only `qweight` is mandatory; which optional buffers are present presumably depends
/// on the quantization scheme (see `QuantKind`) — confirm against the loaders.
pub struct QuantWeights<'a, B>
where
    B: Backend,
{
    /// Packed quantized values.
    pub qweight: &'a B::Buffer,
    /// Per-group scales, if the scheme has them.
    pub scales: Option<&'a B::Buffer>,
    /// Per-group zero points, if the scheme has them.
    pub zeros: Option<&'a B::Buffer>,
    /// Group index mapping, if the scheme has one.
    pub g_idx: Option<&'a B::Buffer>,
}
/// Elementwise combining operation applied by `Backend::all_reduce`.
#[derive(Debug, Clone, Copy)]
pub enum ReduceOp {
    /// Elementwise sum across ranks.
    Sum,
    /// Elementwise maximum across ranks.
    Max,
    /// Elementwise minimum across ranks.
    Min,
}
/// Static parameters of an attention call, shared by `Backend::flash_attention` and
/// `Backend::mla_attention`.
#[derive(Clone, Debug)]
pub struct AttnConfig {
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads (fewer than `num_heads` presumably means grouped-query).
    pub num_kv_heads: usize,
    /// Per-head feature dimension.
    pub head_dim: usize,
    /// Whether a causal mask is applied.
    pub causal: bool,
    /// Score scale factor (presumably multiplied into `QK^T` before softmax).
    pub scale: f32,
    /// K/V sequence stride; its units and the meaning of `0` are backend-defined —
    /// confirm against the kernels.
    pub kv_seq_stride: usize,
    /// Sliding-window length; presumably `0` disables windowing — confirm against kernels.
    pub sliding_window: usize,
}

impl Default for AttnConfig {
    /// Zero-sized, non-causal configuration with a unit `scale` and every
    /// stride/window field set to `0`.
    fn default() -> Self {
        AttnConfig {
            scale: 1.0,
            causal: false,
            num_heads: 0,
            num_kv_heads: 0,
            head_dim: 0,
            kv_seq_stride: 0,
            sliding_window: 0,
        }
    }
}
/// Key/value cache storage for one attention layer, plus optional paged-attention metadata.
pub struct KvCache<B>
where
    B: Backend,
{
    /// Cached keys.
    pub k: B::Buffer,
    /// Cached values.
    pub v: B::Buffer,
    /// Number of positions currently filled.
    pub len: usize,
    /// Maximum number of positions the cache can hold.
    pub capacity: usize,
    /// Number of key/value heads stored.
    pub num_kv_heads: usize,
    /// Per-head feature dimension.
    pub head_dim: usize,
    /// Tokens per block when the cache is paged (meaning for non-paged caches is
    /// backend-defined — confirm).
    pub block_size: usize,
    /// Device block table for paged attention, if used.
    pub block_table: Option<B::Buffer>,
    /// Device per-sequence context lengths for paged attention, if used.
    pub context_lens: Option<B::Buffer>,
    /// Host-side copy of the block indices (presumably mirrors `block_table` — verify).
    pub paged_block_indices: Vec<u32>,
}
/// Compute backend abstraction: buffers, dense and quantized kernels, attention,
/// KV-cache helpers and collective communication.
///
/// Methods without a default body must be provided by every backend. Optional kernels
/// default to returning `FerrumError::unsupported(..)`, so a backend can opt in
/// incrementally. A few helpers (`from_slice_i32`, `write_i32_into`, `write_f32_into`,
/// `alloc_u32`, `scaled_add_inplace`, `from_weight_bytes`) have host-side defaults built
/// on `to_vec` / `from_slice`. The collective ops default to single-rank behaviour
/// (`world_size() == 1`, `rank() == 0`).
pub trait Backend: Send + Sync + Sized + 'static {
    /// Storage handle for tensor data. Host-side element access is `f32` (see
    /// [`Backend::to_vec`] / [`Backend::from_slice`]); device layout is backend-defined.
    type Buffer: Send + Sync;
    /// Mutable execution state passed to kernel calls; its contents are backend-defined.
    type Context;
    /// Weight storage produced by [`Backend::load_gptq`].
    type GptqStore: Send + Sync;
    /// Weight storage produced by [`Backend::load_quant`], [`Backend::load_quant_fused`]
    /// and [`Backend::load_quant_experts`].
    type QuantStore: Send + Sync;

    /// Creates a new execution context.
    fn new_context() -> Self::Context;

    /// Synchronizes `ctx` (backend-defined; presumably waits for queued work to finish).
    fn sync(ctx: &mut Self::Context);

    /// Optional hook receiving the current decode `token` and `step`. Default: no-op.
    fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32) {}

    /// Optional hook enabling/disabling a backend-specific "dev state" mode.
    /// Default: no-op.
    fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool) {}

    /// Starts capturing work issued on `ctx` into a replayable graph.
    /// The default implementation returns an `unsupported` error.
    fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
        Err(FerrumError::unsupported("graph capture not supported"))
    }

    /// Finishes a capture started with [`Backend::begin_graph_capture`].
    /// The default implementation returns an `unsupported` error.
    fn end_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
        Err(FerrumError::unsupported("graph capture not supported"))
    }

    /// Replays the most recently captured graph, presumably returning `Ok(true)` when a
    /// replay happened. The default returns `Ok(false)` (nothing replayed).
    fn replay_last_graph(_ctx: &mut Self::Context) -> Result<bool> {
        Ok(false)
    }

    /// Drops any captured graph state. Default: no-op.
    fn reset_graph(_ctx: &mut Self::Context) {}

    /// Builds a GPTQ weight store from host tensors (`qweight`, `scales`, `qzeros`,
    /// optional `g_idx`) with the given `bits` and `group_size`; `k` / `n` are presumably
    /// the input/output dimensions. Packing conventions are backend-defined.
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn load_gptq(
        _qweight: &[i32],
        _scales: &[f32],
        _qzeros: &[i32],
        _g_idx: Option<&[i32]>,
        _bits: u32,
        _group_size: usize,
        _k: usize,
        _n: usize,
    ) -> Result<Self::GptqStore> {
        Err(FerrumError::unsupported(
            "load_gptq not implemented for this backend",
        ))
    }

    /// Multiplies `m` activation rows in `a` by a GPTQ `weight`, writing into `out`.
    /// The default implementation returns an `unsupported` error.
    fn gemm_gptq(
        _ctx: &mut Self::Context,
        _a: &Self::Buffer,
        _weight: &Self::GptqStore,
        _out: &mut Self::Buffer,
        _m: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "gemm_gptq not implemented for this backend",
        ))
    }

    /// Builds an `n_rows` x `n_cols` quantized weight store from raw GGUF bytes of type
    /// `kind`. The default implementation returns an `unsupported` error.
    fn load_quant(
        _kind: GgufQuantType,
        _bytes: &[u8],
        _n_rows: usize,
        _n_cols: usize,
    ) -> Result<Self::QuantStore> {
        Err(FerrumError::unsupported(
            "load_quant not implemented for this backend",
        ))
    }

    /// Builds one store from several quantized `parts` sharing `n_cols`. Each part is
    /// `(kind, bytes, usize)` where the `usize` is presumably that part's row count.
    /// The default implementation returns an `unsupported` error.
    fn load_quant_fused(
        _parts: &[(GgufQuantType, &[u8], usize)],
        _n_cols: usize,
    ) -> Result<Self::QuantStore> {
        Err(FerrumError::unsupported(
            "load_quant_fused not implemented for this backend",
        ))
    }

    /// Multiplies `m` activation rows in `a` by a quantized `weight`, writing into `out`.
    /// The default implementation returns an `unsupported` error.
    fn gemm_quant(
        _ctx: &mut Self::Context,
        _a: &Self::Buffer,
        _weight: &Self::QuantStore,
        _out: &mut Self::Buffer,
        _m: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "gemm_quant not implemented for this backend",
        ))
    }

    /// Builds a store holding `num_experts` quantized matrices of `n_rows` x `n_cols`
    /// from a single byte slice. The default implementation returns an `unsupported` error.
    fn load_quant_experts(
        _kind: GgufQuantType,
        _bytes: &[u8],
        _num_experts: usize,
        _n_rows: usize,
        _n_cols: usize,
    ) -> Result<Self::QuantStore> {
        Err(FerrumError::unsupported(
            "load_quant_experts not implemented for this backend",
        ))
    }

    /// Mixture-of-experts GEMM that selects expert matrices via `ids` / `tpe`
    /// (argument semantics are backend-defined — confirm against implementations).
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn gemm_quant_moe_id(
        _ctx: &mut Self::Context,
        _a: &Self::Buffer,
        _weight: &Self::QuantStore,
        _ids: &Self::Buffer,
        _tpe: &Self::Buffer,
        _out: &mut Self::Buffer,
        _ne11: usize,
        _top_k: usize,
        _max_per_expert: usize,
        _batch: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "gemm_quant_moe_id not implemented for this backend",
        ))
    }

    /// Picks the top-`top_k` of `num_experts` router `logits` for each of `batch` tokens,
    /// writing expert ids and weights; `norm_topk_prob` presumably renormalizes the
    /// selected weights. The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn route_topk_softmax(
        _ctx: &mut Self::Context,
        _logits: &Self::Buffer,
        _out_ids: &mut Self::Buffer,
        _out_weights: &mut Self::Buffer,
        _batch: usize,
        _num_experts: usize,
        _top_k: usize,
        _norm_topk_prob: bool,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "route_topk_softmax not implemented for this backend",
        ))
    }

    /// Derives the routing tables (`tpe`, `ids`) and the launch-argument buffers for the
    /// gate/up and down projections from `selected_ids`.
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn compute_ids_tpe_gpu(
        _ctx: &mut Self::Context,
        _selected_ids: &Self::Buffer,
        _tpe: &mut Self::Buffer,
        _ids: &mut Self::Buffer,
        _gate_up_args: &mut Self::Buffer,
        _down_args: &mut Self::Buffer,
        _batch: usize,
        _num_experts: usize,
        _top_k: usize,
        _m_gate_up: usize,
        _m_down: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "compute_ids_tpe_gpu not implemented for this backend",
        ))
    }

    /// Variant of [`Backend::gemm_quant_moe_id`] that takes launch arguments from
    /// `args_buf` (presumably produced by [`Backend::compute_ids_tpe_gpu`]).
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn gemm_quant_moe_id_indirect(
        _ctx: &mut Self::Context,
        _src1: &Self::Buffer,
        _weights: &Self::QuantStore,
        _ids: &Self::Buffer,
        _tpe: &Self::Buffer,
        _out: &mut Self::Buffer,
        _args_buf: &Self::Buffer,
        _ne11: usize,
        _top_k: usize,
        _max_per_expert: usize,
        _batch: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "gemm_quant_moe_id_indirect not implemented for this backend",
        ))
    }

    /// SiLU-gated multiply (presumably `silu(gate) * up`) over `total_pairs` rows of
    /// width `ffn`. The default implementation returns an `unsupported` error.
    fn silu_mul_batched(
        _ctx: &mut Self::Context,
        _gate: &Self::Buffer,
        _up: &Self::Buffer,
        _out: &mut Self::Buffer,
        _total_pairs: usize,
        _ffn: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "silu_mul_batched not implemented for this backend",
        ))
    }

    /// Adds the `weights`-weighted sum of `n_slots` stacked rows of width `hidden` into
    /// `residual`. The default implementation returns an `unsupported` error.
    fn weighted_sum_residual_stacked(
        _ctx: &mut Self::Context,
        _slots: &Self::Buffer,
        _weights: &Self::Buffer,
        _residual: &mut Self::Buffer,
        _n_slots: usize,
        _hidden: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "weighted_sum_residual_stacked not implemented for this backend",
        ))
    }

    /// Like [`Backend::weighted_sum_residual_stacked`], and additionally writes a
    /// normalized copy (via `next_norm_w` and `eps`) into `normed_out`.
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn weighted_sum_residual_norm_stacked(
        _ctx: &mut Self::Context,
        _slots: &Self::Buffer,
        _weights: &Self::Buffer,
        _residual: &mut Self::Buffer,
        _next_norm_w: &Self::Buffer,
        _normed_out: &mut Self::Buffer,
        _n_slots: usize,
        _hidden: usize,
        _eps: f32,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "weighted_sum_residual_norm_stacked not implemented for this backend",
        ))
    }

    /// Per-token weighted sum over `top_k` slots for `batch` tokens of width `hidden`.
    /// The default implementation returns an `unsupported` error.
    fn weighted_sum_batched(
        _ctx: &mut Self::Context,
        _slots: &Self::Buffer,
        _weights: &Self::Buffer,
        _out: &mut Self::Buffer,
        _batch: usize,
        _top_k: usize,
        _hidden: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "weighted_sum_batched not implemented for this backend",
        ))
    }

    /// [`Backend::weighted_sum_batched`] reading `weights` from `weights_offset` and
    /// writing `out` at `out_offset` (offset units are backend-defined).
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn weighted_sum_batched_offset(
        _ctx: &mut Self::Context,
        _slots: &Self::Buffer,
        _weights: &Self::Buffer,
        _weights_offset: usize,
        _out: &mut Self::Buffer,
        _out_offset: usize,
        _batch: usize,
        _top_k: usize,
        _hidden: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "weighted_sum_batched_offset not implemented for this backend",
        ))
    }

    /// MoE matrix-vector product for `n_selected` expert ids in `ids`; `src1_stride` is
    /// presumably the activation stride between selected slots.
    /// The default implementation returns an `unsupported` error.
    fn gemv_quant_moe_id(
        _ctx: &mut Self::Context,
        _a: &Self::Buffer,
        _weight: &Self::QuantStore,
        _ids: &Self::Buffer,
        _out: &mut Self::Buffer,
        _n_selected: usize,
        _src1_stride: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "gemv_quant_moe_id not implemented for this backend",
        ))
    }

    /// [`Backend::gemv_quant_moe_id`] reading `a` from `a_offset` and `ids` from
    /// `ids_offset` (offset units are backend-defined).
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn gemv_quant_moe_id_offset(
        _ctx: &mut Self::Context,
        _a: &Self::Buffer,
        _a_offset: usize,
        _weight: &Self::QuantStore,
        _ids: &Self::Buffer,
        _ids_offset: usize,
        _out: &mut Self::Buffer,
        _n_selected: usize,
        _src1_stride: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "gemv_quant_moe_id_offset not implemented for this backend",
        ))
    }

    /// Builds a buffer whose `f32` lanes hold the raw bits of `data`.
    ///
    /// The values are bit-cast (`f32::from_bits`), not numerically converted, so they
    /// round-trip exactly. Backends with native integer buffers may override this.
    fn from_slice_i32(data: &[i32]) -> Self::Buffer {
        let f: Vec<f32> = data.iter().map(|&i| f32::from_bits(i as u32)).collect();
        Self::from_slice(&f)
    }

    /// Stores `data` (bit-cast as in [`Backend::from_slice_i32`]) into `buf`.
    /// The default replaces `*buf` with a newly built buffer rather than writing in place.
    fn write_i32_into(buf: &mut Self::Buffer, data: &[i32]) {
        *buf = Self::from_slice_i32(data);
    }

    /// Stores `data` into `buf`.
    /// The default replaces `*buf` with a newly built buffer rather than writing in place.
    fn write_f32_into(buf: &mut Self::Buffer, data: &[f32]) {
        *buf = Self::from_slice(data);
    }

    /// SiLU-gated multiply (presumably `silu(gate) * up`) over `n_slots` rows of width
    /// `ffn`. The default implementation returns an `unsupported` error.
    fn silu_mul_stacked(
        _ctx: &mut Self::Context,
        _gate: &Self::Buffer,
        _up: &Self::Buffer,
        _out: &mut Self::Buffer,
        _n_slots: usize,
        _ffn: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "silu_mul_stacked not implemented for this backend",
        ))
    }

    /// Fused gate/up MoE GEMV followed by the SiLU-gated multiply, into `silu_out`.
    /// Check [`Backend::supports_fused_moe_gate_up_silu`] first.
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn gemv_quant_moe_id_gate_up_silu(
        _ctx: &mut Self::Context,
        _a: &Self::Buffer,
        _gate_w: &Self::QuantStore,
        _up_w: &Self::QuantStore,
        _ids: &Self::Buffer,
        _silu_out: &mut Self::Buffer,
        _n_selected: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "gemv_quant_moe_id_gate_up_silu not implemented for this backend",
        ))
    }

    /// Whether [`Backend::gemv_quant_moe_id_gate_up_silu`] is implemented. Default: `false`.
    fn supports_fused_moe_gate_up_silu() -> bool {
        false
    }

    /// Batched MoE GEMV over `m` tokens with `top_k` experts each, using the given outer
    /// and inner activation strides. Check [`Backend::supports_batched_moe_gemv`] first.
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn gemv_quant_moe_id_batched(
        _ctx: &mut Self::Context,
        _a: &Self::Buffer,
        _weight: &Self::QuantStore,
        _ids: &Self::Buffer,
        _out: &mut Self::Buffer,
        _m: usize,
        _top_k: usize,
        _src1_outer_stride: usize,
        _src1_inner_stride: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "gemv_quant_moe_id_batched not implemented for this backend",
        ))
    }

    /// Whether [`Backend::gemv_quant_moe_id_batched`] is implemented. Default: `false`.
    fn supports_batched_moe_gemv() -> bool {
        false
    }

    /// Batched form of [`Backend::gemv_quant_moe_id_gate_up_silu`].
    /// Check [`Backend::supports_batched_moe_gate_up_silu`] first.
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn gemv_quant_moe_id_gate_up_silu_batched(
        _ctx: &mut Self::Context,
        _a: &Self::Buffer,
        _gate_w: &Self::QuantStore,
        _up_w: &Self::QuantStore,
        _ids: &Self::Buffer,
        _silu_out: &mut Self::Buffer,
        _m: usize,
        _top_k: usize,
        _src1_outer_stride: usize,
        _src1_inner_stride: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "gemv_quant_moe_id_gate_up_silu_batched not implemented for this backend",
        ))
    }

    /// Whether [`Backend::gemv_quant_moe_id_gate_up_silu_batched`] is implemented.
    /// Default: `false`.
    fn supports_batched_moe_gate_up_silu() -> bool {
        false
    }

    /// Writes the `weights`-weighted sum of `n_slots` stacked rows of width `hidden` to
    /// `out`. The default implementation returns an `unsupported` error.
    fn weighted_sum_stacked(
        _ctx: &mut Self::Context,
        _slots: &Self::Buffer,
        _weights: &Self::Buffer,
        _out: &mut Self::Buffer,
        _n_slots: usize,
        _hidden: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "weighted_sum_stacked not implemented for this backend",
        ))
    }

    /// Dense matrix multiply of `a` and `b` into `out` with GEMM dimensions `m`, `n`, `k`
    /// (operand layout/transposition is backend-defined — confirm per implementation).
    fn gemm(
        ctx: &mut Self::Context,
        a: &Self::Buffer,
        b: &Self::Buffer,
        out: &mut Self::Buffer,
        m: usize,
        n: usize,
        k: usize,
    );

    /// RMS-normalizes `tokens` rows of width `dim` in `x` with weight `w` and `eps`,
    /// writing into `out`.
    fn rms_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        w: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );

    /// Adds `x` into `residual`, then RMS-normalizes the result (weight `w`, `eps`) into
    /// `out`, over `tokens` rows of width `dim`.
    #[allow(clippy::too_many_arguments)]
    fn fused_add_rms_norm(
        ctx: &mut Self::Context,
        residual: &mut Self::Buffer,
        x: &Self::Buffer,
        w: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );

    /// Attention of `q` against `k` / `v` into `out` for `batch` sequences of `q_len`
    /// queries over `kv_len` keys; `pos_offset` is the position of the first query.
    /// Head counts, scale, causality and windowing come from `cfg`.
    #[allow(clippy::too_many_arguments)]
    fn flash_attention(
        ctx: &mut Self::Context,
        q: &Self::Buffer,
        k: &Self::Buffer,
        v: &Self::Buffer,
        out: &mut Self::Buffer,
        batch: usize,
        q_len: usize,
        kv_len: usize,
        pos_offset: usize,
        cfg: &AttnConfig,
    );

    /// Multi-head latent attention over a compressed KV (`kv_lora_rank`) plus a RoPE part
    /// (`qk_rope_head_dim`). The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn mla_attention(
        _ctx: &mut Self::Context,
        _q: &Self::Buffer,
        _kv_compressed: &Self::Buffer,
        _kv_rope: &Self::Buffer,
        _out: &mut Self::Buffer,
        _batch: usize,
        _q_len: usize,
        _kv_len: usize,
        _pos_offset: usize,
        _cfg: &AttnConfig,
        _kv_lora_rank: usize,
        _qk_rope_head_dim: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "mla_attention not implemented for this backend; required by \
             DeepSeek V2/V3 (Phase D/E)",
        ))
    }

    /// Copies `len` elements from `src` starting at `src_offset` into `dst` starting at
    /// `dst_offset`.
    fn copy_slice(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        src_offset: usize,
        dst: &mut Self::Buffer,
        dst_offset: usize,
        len: usize,
    );

    /// Gathers rows of width `dim` from `table` for each token id in `ids` into `out`.
    fn embedding_lookup(
        ctx: &mut Self::Context,
        table: &Self::Buffer,
        ids: &[u32],
        out: &mut Self::Buffer,
        dim: usize,
    );

    /// Splits a fused `qkv` projection of `tokens` rows into `q` (`q_dim` wide) and
    /// `k` / `v` (`kv_dim` wide each).
    #[allow(clippy::too_many_arguments)]
    fn split_qkv(
        ctx: &mut Self::Context,
        qkv: &Self::Buffer,
        q: &mut Self::Buffer,
        k: &mut Self::Buffer,
        v: &mut Self::Buffer,
        tokens: usize,
        q_dim: usize,
        kv_dim: usize,
    );

    /// Applies the SiLU-gated multiply to a fused `gate_up` buffer (intermediate size
    /// `im`) for `tokens` rows, writing into `out`.
    fn fused_silu_mul_split(
        ctx: &mut Self::Context,
        gate_up: &Self::Buffer,
        out: &mut Self::Buffer,
        tokens: usize,
        im: usize,
    );

    /// Per-head normalization (`norm_w`, `eps`) followed by rotary embedding (`cos`,
    /// `sin`, starting at `pos_offset`). `mode` selects a backend-defined variant.
    #[allow(clippy::too_many_arguments)]
    fn qk_norm_rope(
        ctx: &mut Self::Context,
        input: &Self::Buffer,
        norm_w: &Self::Buffer,
        cos: &Self::Buffer,
        sin: &Self::Buffer,
        output: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        head_dim: usize,
        pos_offset: usize,
        eps: f32,
        mode: i32,
    );

    /// Fused [`Backend::split_qkv`] + [`Backend::qk_norm_rope`] for Q and K; V is split out
    /// unchanged (presumably). The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn split_qkv_norm_rope(
        _ctx: &mut Self::Context,
        _qkv: &Self::Buffer,
        _q_norm_w: &Self::Buffer,
        _k_norm_w: &Self::Buffer,
        _cos: &Self::Buffer,
        _sin: &Self::Buffer,
        _q_out: &mut Self::Buffer,
        _k_out: &mut Self::Buffer,
        _v_out: &mut Self::Buffer,
        _tokens: usize,
        _q_heads: usize,
        _kv_heads: usize,
        _head_dim: usize,
        _pos_offset: usize,
        _eps: f32,
        _qk_mode: i32,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "split_qkv_norm_rope not implemented for this backend",
        ))
    }

    /// Like [`Backend::split_qkv_norm_rope`], but writes K/V directly into a contiguous
    /// cache at `cache_len` (capacity `cache_capacity`).
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn split_qkv_norm_rope_into_cache(
        _ctx: &mut Self::Context,
        _qkv: &Self::Buffer,
        _q_norm_w: &Self::Buffer,
        _k_norm_w: &Self::Buffer,
        _cos: &Self::Buffer,
        _sin: &Self::Buffer,
        _q_out: &mut Self::Buffer,
        _cache_k: &mut Self::Buffer,
        _cache_v: &mut Self::Buffer,
        _tokens: usize,
        _q_heads: usize,
        _kv_heads: usize,
        _head_dim: usize,
        _pos_offset: usize,
        _eps: f32,
        _qk_mode: i32,
        _cache_len: usize,
        _cache_capacity: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "split_qkv_norm_rope_into_cache not implemented for this backend",
        ))
    }

    /// Like [`Backend::split_qkv_norm_rope_into_cache`], but targets a paged cache
    /// addressed through `block_table`; `qkv` and `q_out` take byte offsets.
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn split_qkv_norm_rope_into_paged_cache(
        _ctx: &mut Self::Context,
        _qkv: &Self::Buffer,
        _qkv_byte_offset: u64,
        _q_norm_w: &Self::Buffer,
        _k_norm_w: &Self::Buffer,
        _cos: &Self::Buffer,
        _sin: &Self::Buffer,
        _q_out: &mut Self::Buffer,
        _q_out_byte_offset: u64,
        _cache_k: &mut Self::Buffer,
        _cache_v: &mut Self::Buffer,
        _block_table: &Self::Buffer,
        _tokens: usize,
        _q_heads: usize,
        _kv_heads: usize,
        _head_dim: usize,
        _pos_offset: usize,
        _eps: f32,
        _qk_mode: i32,
        _cache_len: usize,
        _block_size: usize,
        _max_num_blocks_per_seq: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "split_qkv_norm_rope_into_paged_cache not implemented for this backend",
        ))
    }

    /// Attention of `q` over paged K/V pools addressed through `block_tables` and
    /// `context_lens`, for `num_seqs` sequences.
    /// The default implementation returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn paged_decode_attention(
        _ctx: &mut Self::Context,
        _q: &Self::Buffer,
        _k_pool: &Self::Buffer,
        _v_pool: &Self::Buffer,
        _out: &mut Self::Buffer,
        _block_tables: &Self::Buffer,
        _context_lens: &Self::Buffer,
        _num_seqs: usize,
        _num_heads: usize,
        _num_kv_heads: usize,
        _head_dim: usize,
        _block_size: usize,
        _max_num_blocks_per_seq: usize,
        _q_len: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "paged_decode_attention not implemented for this backend",
        ))
    }

    /// Allocates `n` zeroed 32-bit slots, using the same bit-cast representation as
    /// [`Backend::from_slice_i32`].
    fn alloc_u32(n: usize) -> Self::Buffer {
        Self::from_slice_i32(&vec![0i32; n])
    }

    /// Writes `data` into the u32 buffer `dst`.
    ///
    /// The default is a no-op: `data` is discarded and `dst` is left unchanged.
    /// NOTE(review): backends that read u32 buffers (e.g. ones from `alloc_u32` used as
    /// block tables) must override this. A safer default would stage via
    /// `from_slice_i32` + `copy_slice`, but that could panic on undersized buffers that
    /// callers currently rely on being ignored.
    fn write_u32(_ctx: &mut Self::Context, _dst: &mut Self::Buffer, _data: &[u32]) {}

    /// Appends `new_tokens` head-major K/V entries (`nkv` heads of width `hd`) to the
    /// caches at position `cache_len`, within capacity `cache_capacity`.
    #[allow(clippy::too_many_arguments)]
    fn kv_cache_append_head_major(
        ctx: &mut Self::Context,
        cache_k: &mut Self::Buffer,
        cache_v: &mut Self::Buffer,
        cache_len: usize,
        cache_capacity: usize,
        new_k_head_major: &Self::Buffer,
        new_v_head_major: &Self::Buffer,
        new_tokens: usize,
        nkv: usize,
        hd: usize,
    );

    /// Transposes `src` from head-major to token-major order into `dst`.
    fn transpose_head_to_token(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        dst: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        dim: usize,
    );

    /// Elementwise `residual += x` over `len` elements.
    fn add_inplace(
        ctx: &mut Self::Context,
        residual: &mut Self::Buffer,
        x: &Self::Buffer,
        len: usize,
    );

    /// Elementwise `dst += scale * src` over the first `len` elements.
    ///
    /// The default round-trips both buffers through the host (`to_vec`, add,
    /// `from_slice`) and replaces `*dst`; backends should override it with a device kernel.
    fn scaled_add_inplace(
        _ctx: &mut Self::Context,
        dst: &mut Self::Buffer,
        src: &Self::Buffer,
        scale: f32,
        len: usize,
    ) {
        let mut acc = Self::to_vec(dst, len);
        let addend = Self::to_vec(src, len);
        // Slicing to `len` keeps the original's panic-on-short-readback behaviour.
        for (d, &s) in acc[..len].iter_mut().zip(&addend[..len]) {
            *d += scale * s;
        }
        *dst = Self::from_slice(&acc);
    }

    /// Adds the `cols`-wide `bias` to each of the `rows` rows of `data`.
    fn add_bias(
        ctx: &mut Self::Context,
        data: &mut Self::Buffer,
        bias: &Self::Buffer,
        rows: usize,
        cols: usize,
    );

    /// Layer normalization of `tokens` rows of width `dim` with `gamma`, `beta` and
    /// `eps`, writing into `out`.
    #[allow(clippy::too_many_arguments)]
    fn layer_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        gamma: &Self::Buffer,
        beta: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );

    /// Applies GELU elementwise to `len` elements of `x`, writing into `out`.
    fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize);

    /// Allocates a buffer of `len` elements (initial contents are backend-defined).
    fn alloc(len: usize) -> Self::Buffer;

    /// Reads the first `len` elements of `buf` back to the host as `f32`.
    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;

    /// Creates a buffer initialized from host `data`.
    fn from_slice(data: &[f32]) -> Self::Buffer;

    /// Creates a buffer from raw little-endian weight bytes in `src_dtype`.
    ///
    /// The default decodes to `f32` on the host via `SrcDtype::to_f32_vec` and uploads the
    /// result with [`Backend::from_slice`].
    fn from_weight_bytes(raw: &[u8], src_dtype: SrcDtype) -> Self::Buffer {
        let data = src_dtype.to_f32_vec(raw);
        Self::from_slice(&data)
    }

    /// Number of participating ranks. Default: `1`.
    fn world_size(_ctx: &Self::Context) -> usize {
        1
    }

    /// This process's rank. Default: `0`.
    fn rank(_ctx: &Self::Context) -> usize {
        0
    }

    /// Reduces the first `len` elements of `buf` across ranks with `op`.
    /// Default: no-op, which is correct only for a single rank.
    fn all_reduce(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _op: ReduceOp) {}

    /// Gathers each rank's `local_len`-element `local` shard into `global`.
    ///
    /// The default implements the single-rank case, matching the default
    /// `world_size() == 1`: `local` is copied to the start of `global` with
    /// [`Backend::copy_slice`]. Multi-rank backends must override this.
    fn all_gather(
        ctx: &mut Self::Context,
        local: &Self::Buffer,
        global: &mut Self::Buffer,
        local_len: usize,
    ) {
        // With one rank, the gathered tensor is exactly the local shard; a no-op would
        // leave `global` holding stale data.
        Self::copy_slice(ctx, local, 0, global, 0, local_len);
    }

    /// Broadcasts the first `len` elements of `buf` from `src_rank` to all ranks.
    /// Default: no-op, which is correct only for a single rank.
    fn broadcast(_ctx: &mut Self::Context, _buf: &mut Self::Buffer, _len: usize, _src_rank: usize) {}
}