use ferrum_types::{FerrumError, Result};
/// Quantization scheme, with its parameters, describing how a weight tensor is stored.
///
/// Passed alongside [`QuantWeights`] to [`Backend::gemm_quant`].
#[derive(Debug, Clone)]
pub enum QuantKind {
    /// GPTQ-packed weights.
    Gptq {
        /// Bits per quantized element.
        bits: u32,
        /// Quantization group size (presumably input elements sharing one scale/zero).
        group_size: usize,
        /// NOTE(review): presumably the GPTQ "act-order" flag — confirm against the loader.
        desc_act: bool,
    },
    /// AWQ-packed weights.
    Awq {
        /// Bits per quantized element.
        bits: u32,
        /// Quantization group size.
        group_size: usize,
    },
    /// A GGUF block-quantized tensor.
    Gguf {
        /// The concrete GGUF block format.
        quant_type: GgufQuantType,
    },
}
/// GGUF block-quantization formats, carried by [`QuantKind::Gguf`].
///
/// Fieldless and `Copy`. It also derives `PartialEq`/`Eq`/`Hash` so a format can be compared
/// directly or used as a map key (e.g. to cache per-format kernels) without matching by hand.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GgufQuantType {
    /// GGML `Q4_0`.
    Q4_0,
    /// GGML `Q4_1`.
    Q4_1,
    /// Presumably the GGML k-quant `Q4_K` format.
    Q4K,
    /// Presumably the GGML k-quant `Q5_K` format.
    Q5K,
    /// Presumably the GGML k-quant `Q6_K` format.
    Q6K,
    /// GGML `Q8_0`.
    Q8_0,
}
/// Borrowed backend buffers that together make up one quantized weight, as consumed by
/// [`Backend::gemm_quant`].
///
/// Only `qweight` is mandatory. NOTE(review): which optional buffers must be present
/// presumably depends on the accompanying [`QuantKind`] — confirm per backend.
pub struct QuantWeights<'a, B>
where
    B: Backend,
{
    /// Packed quantized weight data.
    pub qweight: &'a B::Buffer,
    /// Scales, when stored as a separate buffer.
    pub scales: Option<&'a B::Buffer>,
    /// Zero points, when stored as a separate buffer.
    pub zeros: Option<&'a B::Buffer>,
    /// NOTE(review): presumably the GPTQ group-index table — confirm.
    pub g_idx: Option<&'a B::Buffer>,
}
/// Element-wise reduction selected for [`Backend::all_reduce`].
///
/// Fieldless and `Copy`; it also derives `PartialEq`/`Eq`/`Hash` so callers and backends can
/// compare ops directly or key caches on them.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Element-wise sum.
    Sum,
    /// Element-wise maximum.
    Max,
    /// Element-wise minimum.
    Min,
}
/// Head layout, scaling and masking parameters for [`Backend::flash_attention`] and
/// [`Backend::mla_attention`].
///
/// Derives `PartialEq` (not `Eq`, because of the `f32` `scale`) so configurations can be
/// compared, e.g. in tests or when deciding whether cached state is still valid.
#[derive(Clone, Debug, PartialEq)]
pub struct AttnConfig {
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads (presumably may be fewer than `num_heads` for grouped-query
    /// attention — confirm per backend).
    pub num_kv_heads: usize,
    /// Per-head dimension.
    pub head_dim: usize,
    /// Whether a causal mask is applied.
    pub causal: bool,
    /// Multiplier applied to attention scores. NOTE(review): callers presumably set this to
    /// `1 / sqrt(head_dim)` — confirm.
    pub scale: f32,
    /// NOTE(review): presumably the token stride of the K/V buffers (e.g. cache capacity);
    /// what `0` means is not visible here — confirm.
    pub kv_seq_stride: usize,
    /// NOTE(review): presumably the sliding-window length in tokens, with `0` meaning
    /// disabled — confirm.
    pub sliding_window: usize,
}
impl Default for AttnConfig {
fn default() -> Self {
Self {
num_heads: 0,
num_kv_heads: 0,
head_dim: 0,
causal: false,
scale: 1.0,
kv_seq_stride: 0,
sliding_window: 0,
}
}
}
/// Key/value cache buffers together with their fill level and head layout.
pub struct KvCache<B>
where
    B: Backend,
{
    /// Key buffer.
    pub k: B::Buffer,
    /// Value buffer.
    pub v: B::Buffer,
    /// Number of entries currently stored (presumably token positions — confirm).
    pub len: usize,
    /// Maximum number of entries the buffers were sized for (presumably token positions).
    pub capacity: usize,
    /// Number of key/value heads held in the buffers.
    pub num_kv_heads: usize,
    /// Per-head dimension.
    pub head_dim: usize,
}
/// Compute-backend abstraction through which model code dispatches every kernel.
///
/// A backend supplies three associated types and the required kernels: GEMM, norms,
/// attention, KV-cache append, element-wise ops and host/buffer transfer.
///
/// Optional capabilities have defaults:
/// - graph capture, GPTQ, quantized GEMM and MLA attention return
///   `FerrumError::unsupported` (`replay_last_graph` returns `Ok(false)`);
/// - the collective ops implement single-rank behaviour (`world_size() == 1`, `rank() == 0`).
///
/// Several kernel entry points mirror the underlying kernel signatures and so exceed clippy's
/// default argument limit. Each such method opts out with
/// `#[allow(clippy::too_many_arguments)]`.
pub trait Backend: Send + Sync + Sized + 'static {
    /// Buffer handle passed to every kernel. NOTE(review): the element type looks like `f32`,
    /// given `from_slice(&[f32])` and `to_vec -> Vec<f32>` — confirm for quantized data.
    type Buffer: Send + Sync;
    /// Mutable execution state threaded through kernel calls.
    type Context;
    /// Backend-specific form of a GPTQ weight, produced by `load_gptq` and consumed by
    /// `gemm_gptq`.
    type GptqStore: Send + Sync;

    /// Creates a new execution context.
    fn new_context() -> Self::Context;

    /// Synchronizes `ctx`. NOTE(review): presumably blocks until queued work completes —
    /// confirm per backend.
    fn sync(ctx: &mut Self::Context);

    /// Records per-decode-step state (`token`, `step`) for backends that need it.
    /// No-op by default.
    fn set_decode_state(_ctx: &mut Self::Context, _token: u32, _step: u32) {}

    /// Enables or disables a backend-specific "dev state" mode. No-op by default.
    fn set_dev_state_mode(_ctx: &mut Self::Context, _enable: bool) {}

    /// Starts capturing subsequent work on `ctx` into a graph.
    ///
    /// # Errors
    /// The default returns an `unsupported` error.
    fn begin_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
        Err(FerrumError::unsupported("graph capture not supported"))
    }

    /// Ends a capture started by `begin_graph_capture`.
    ///
    /// # Errors
    /// The default returns an `unsupported` error.
    fn end_graph_capture(_ctx: &mut Self::Context) -> Result<()> {
        Err(FerrumError::unsupported("graph capture not supported"))
    }

    /// Replays the most recently captured graph and reports whether a replay happened.
    ///
    /// The default returns `Ok(false)` rather than an error.
    fn replay_last_graph(_ctx: &mut Self::Context) -> Result<bool> {
        Ok(false)
    }

    /// Discards any captured graph. No-op by default.
    fn reset_graph(_ctx: &mut Self::Context) {}

    /// Builds a [`Self::GptqStore`] from host-side GPTQ tensors.
    ///
    /// `k` and `n` are presumably the weight's input and output dimensions — confirm.
    ///
    /// # Errors
    /// The default returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn load_gptq(
        _qweight: &[i32],
        _scales: &[f32],
        _qzeros: &[i32],
        _g_idx: Option<&[i32]>,
        _bits: u32,
        _group_size: usize,
        _k: usize,
        _n: usize,
    ) -> Result<Self::GptqStore> {
        Err(FerrumError::unsupported(
            "load_gptq not implemented for this backend",
        ))
    }

    /// Multiplies activations `a` (presumably `m` rows) by a loaded GPTQ weight into `out`.
    ///
    /// # Errors
    /// The default returns an `unsupported` error.
    fn gemm_gptq(
        _ctx: &mut Self::Context,
        _a: &Self::Buffer,
        _weight: &Self::GptqStore,
        _out: &mut Self::Buffer,
        _m: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "gemm_gptq not implemented for this backend",
        ))
    }

    /// Dense matrix multiply of `a` and `b` into `out`, with GEMM dimensions `m`, `n`, `k`.
    ///
    /// NOTE(review): operand layout (e.g. whether `b` is `[k, n]` or `[n, k]`) is not visible
    /// here — confirm per backend.
    fn gemm(
        ctx: &mut Self::Context,
        a: &Self::Buffer,
        b: &Self::Buffer,
        out: &mut Self::Buffer,
        m: usize,
        n: usize,
        k: usize,
    );

    /// RMS-normalizes `tokens` rows of `dim` elements from `x`, weighted by `w`, into `out`.
    fn rms_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        w: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );

    /// Fused residual add and RMS norm.
    ///
    /// NOTE(review): presumably `residual += x` followed by `out = rms_norm(residual) * w` —
    /// confirm the exact update order per backend.
    #[allow(clippy::too_many_arguments)]
    fn fused_add_rms_norm(
        ctx: &mut Self::Context,
        residual: &mut Self::Buffer,
        x: &Self::Buffer,
        w: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );

    /// Attention of `q` against `k`/`v` into `out`, shaped and masked according to `cfg`.
    ///
    /// `pos_offset` is presumably the absolute position of the first query token, used for
    /// causal/sliding-window masking — confirm.
    #[allow(clippy::too_many_arguments)]
    fn flash_attention(
        ctx: &mut Self::Context,
        q: &Self::Buffer,
        k: &Self::Buffer,
        v: &Self::Buffer,
        out: &mut Self::Buffer,
        batch: usize,
        q_len: usize,
        kv_len: usize,
        pos_offset: usize,
        cfg: &AttnConfig,
    );

    /// Multi-head latent attention over a compressed KV representation, as used by
    /// DeepSeek V2/V3.
    ///
    /// # Errors
    /// The default returns an `unsupported` error.
    #[allow(clippy::too_many_arguments)]
    fn mla_attention(
        _ctx: &mut Self::Context,
        _q: &Self::Buffer,
        _kv_compressed: &Self::Buffer,
        _kv_rope: &Self::Buffer,
        _out: &mut Self::Buffer,
        _batch: usize,
        _q_len: usize,
        _kv_len: usize,
        _pos_offset: usize,
        _cfg: &AttnConfig,
        _kv_lora_rank: usize,
        _qk_rope_head_dim: usize,
    ) -> Result<()> {
        Err(FerrumError::unsupported(
            "mla_attention not implemented for this backend; required by \
             DeepSeek V2/V3 (Phase D/E)",
        ))
    }

    /// Copies `len` elements from `src` starting at `src_offset` into `dst` starting at
    /// `dst_offset` (offsets presumably in elements — confirm).
    fn copy_slice(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        src_offset: usize,
        dst: &mut Self::Buffer,
        dst_offset: usize,
        len: usize,
    );

    /// Gathers the `dim`-wide rows of `table` selected by `ids` into `out`.
    fn embedding_lookup(
        ctx: &mut Self::Context,
        table: &Self::Buffer,
        ids: &[u32],
        out: &mut Self::Buffer,
        dim: usize,
    );

    /// Splits a fused QKV projection into separate `q`, `k`, `v` buffers. `q_dim` and `kv_dim`
    /// are the per-token widths of the Q part and of each K/V part.
    #[allow(clippy::too_many_arguments)]
    fn split_qkv(
        ctx: &mut Self::Context,
        qkv: &Self::Buffer,
        q: &mut Self::Buffer,
        k: &mut Self::Buffer,
        v: &mut Self::Buffer,
        tokens: usize,
        q_dim: usize,
        kv_dim: usize,
    );

    /// Fused SiLU-gated multiply over a combined gate/up buffer into `out`.
    ///
    /// NOTE(review): presumably `gate_up` holds gate and up halves of `im` elements per token
    /// and `out = silu(gate) * up` — confirm.
    fn fused_silu_mul_split(
        ctx: &mut Self::Context,
        gate_up: &Self::Buffer,
        out: &mut Self::Buffer,
        tokens: usize,
        im: usize,
    );

    /// Per-head normalization (weights `norm_w`) followed by rotary embedding using the
    /// `cos`/`sin` tables, starting at `pos_offset`.
    ///
    /// NOTE(review): the meaning of `mode` is backend-defined and not visible here.
    #[allow(clippy::too_many_arguments)]
    fn qk_norm_rope(
        ctx: &mut Self::Context,
        input: &Self::Buffer,
        norm_w: &Self::Buffer,
        cos: &Self::Buffer,
        sin: &Self::Buffer,
        output: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        head_dim: usize,
        pos_offset: usize,
        eps: f32,
        mode: i32,
    );

    /// Appends `new_tokens` head-major K/V entries (`nkv` heads of `hd` elements) to the
    /// caches after the existing `cache_len` entries.
    ///
    /// `cache_capacity` is presumably the per-head stride of the cache — confirm.
    #[allow(clippy::too_many_arguments)]
    fn kv_cache_append_head_major(
        ctx: &mut Self::Context,
        cache_k: &mut Self::Buffer,
        cache_v: &mut Self::Buffer,
        cache_len: usize,
        cache_capacity: usize,
        new_k_head_major: &Self::Buffer,
        new_v_head_major: &Self::Buffer,
        new_tokens: usize,
        nkv: usize,
        hd: usize,
    );

    /// Reorders `src` from head-major to token-major layout in `dst`.
    fn transpose_head_to_token(
        ctx: &mut Self::Context,
        src: &Self::Buffer,
        dst: &mut Self::Buffer,
        tokens: usize,
        heads: usize,
        dim: usize,
    );

    /// Element-wise `residual += x` over `len` elements.
    fn add_inplace(
        ctx: &mut Self::Context,
        residual: &mut Self::Buffer,
        x: &Self::Buffer,
        len: usize,
    );

    /// Adds `bias` to each of the `rows` rows (of `cols` elements) of `data`.
    fn add_bias(
        ctx: &mut Self::Context,
        data: &mut Self::Buffer,
        bias: &Self::Buffer,
        rows: usize,
        cols: usize,
    );

    /// Layer normalization with affine parameters `gamma`/`beta` over `tokens` rows of `dim`.
    #[allow(clippy::too_many_arguments)]
    fn layer_norm(
        ctx: &mut Self::Context,
        x: &Self::Buffer,
        gamma: &Self::Buffer,
        beta: &Self::Buffer,
        eps: f32,
        out: &mut Self::Buffer,
        tokens: usize,
        dim: usize,
    );

    /// Applies GELU to `len` elements of `x`, writing to `out`.
    fn gelu(ctx: &mut Self::Context, x: &Self::Buffer, out: &mut Self::Buffer, len: usize);

    /// Allocates a buffer of `len` elements (initial contents are backend-defined).
    fn alloc(len: usize) -> Self::Buffer;

    /// Reads the first `len` elements of `buf` back into a host `Vec<f32>`.
    fn to_vec(buf: &Self::Buffer, len: usize) -> Vec<f32>;

    /// Creates a buffer initialized from `data`.
    fn from_slice(data: &[f32]) -> Self::Buffer;

    /// Quantized matrix multiply of `a` by `weights` into `out`, for the scheme in `kind`.
    ///
    /// # Errors
    /// The default returns an `unsupported` error naming `kind`.
    #[allow(clippy::too_many_arguments)]
    fn gemm_quant(
        _ctx: &mut Self::Context,
        _a: &Self::Buffer,
        _weights: &QuantWeights<'_, Self>,
        _out: &mut Self::Buffer,
        _m: usize,
        _n: usize,
        _k: usize,
        kind: &QuantKind,
    ) -> Result<()> {
        Err(FerrumError::unsupported(format!(
            "gemm_quant({kind:?}) not implemented for this backend"
        )))
    }

    /// Number of participating ranks. Defaults to a single rank.
    fn world_size(_ctx: &Self::Context) -> usize {
        1
    }

    /// Index of this rank. Defaults to `0`, the only rank in the default topology.
    fn rank(_ctx: &Self::Context) -> usize {
        0
    }

    /// Reduces `buf` in place across all ranks using `op`.
    ///
    /// With the default single rank, the reduction is the identity, so the default does
    /// nothing.
    fn all_reduce(
        _ctx: &mut Self::Context,
        _buf: &mut Self::Buffer,
        _len: usize,
        _op: ReduceOp,
    ) {
    }

    /// Gathers each rank's `local_len`-element `local` buffer into `global`.
    ///
    /// With the default single rank, the gathered result is exactly this rank's data, so the
    /// default copies `local[..local_len]` into the start of `global`. A no-op here would
    /// leave `global` holding whatever it contained before. Multi-rank backends must override
    /// this.
    fn all_gather(
        ctx: &mut Self::Context,
        local: &Self::Buffer,
        global: &mut Self::Buffer,
        local_len: usize,
    ) {
        Self::copy_slice(ctx, local, 0, global, 0, local_len);
    }

    /// Broadcasts `buf` from `src_rank` to all ranks.
    ///
    /// With the default single rank, every rank already holds the data, so the default does
    /// nothing.
    fn broadcast(
        _ctx: &mut Self::Context,
        _buf: &mut Self::Buffer,
        _len: usize,
        _src_rank: usize,
    ) {
    }
}