use crate::TensorRef;
use ferrum_types::Result;
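
/// Rotary position embedding (RoPE) configuration.
///
/// A minimal sketch of the cos/sin caches this config implies, assuming the
/// conventional RoPE frequencies `theta^(-2i/head_dim)` (illustrative only;
/// the cache layout actually consumed by a backend may differ):
///
/// ```
/// fn build_rope_cache(head_dim: usize, max_seq_len: usize, theta: f32) -> Vec<(f32, f32)> {
///     let mut cache = Vec::with_capacity(max_seq_len * head_dim / 2);
///     for pos in 0..max_seq_len {
///         for i in 0..head_dim / 2 {
///             // Conventional RoPE frequency for dimension pair i.
///             let freq = theta.powf(-2.0 * i as f32 / head_dim as f32);
///             let angle = pos as f32 * freq;
///             cache.push((angle.cos(), angle.sin()));
///         }
///     }
///     cache
/// }
/// let cache = build_rope_cache(8, 4, 10000.0);
/// assert_eq!(cache.len(), 4 * 8 / 2);
/// assert_eq!(cache[0], (1.0, 0.0)); // position 0 is never rotated
/// ```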
#[derive(Debug, Clone)]
pub struct RoPEConfig {
    /// Dimension of a single attention head.
    pub head_dim: usize,
    /// Maximum sequence length the position caches must cover.
    pub max_seq_len: usize,
    /// Frequency base of the rotary embedding.
    pub theta: f32,
}
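
/// Shape and masking parameters for the attention kernels.
///
/// With grouped-query attention (GQA), `num_kv_heads` divides `num_heads`
/// and each KV head is shared by a group of query heads. A sketch of the
/// conventional query-head to KV-head mapping (an assumption about layout,
/// not something this crate fixes):
///
/// ```
/// fn kv_head_for(q_head: usize, num_heads: usize, num_kv_heads: usize) -> usize {
///     q_head / (num_heads / num_kv_heads)
/// }
/// assert_eq!(kv_head_for(5, 32, 8), 1); // query heads 4..8 share KV head 1
/// ```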
#[derive(Debug, Clone)]
pub struct AttentionParams {
    pub num_heads: usize,
    pub num_kv_heads: usize,
    pub head_dim: usize,
    /// Scale applied to the attention scores, typically `1.0 / sqrt(head_dim)`.
    pub softmax_scale: f32,
    /// Apply a causal (lower-triangular) mask.
    pub causal: bool,
}
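
/// Grouped weight-quantization schemes for `LinearOps::quantized_linear`.
///
/// In both schemes one f32 scale covers `group_size` consecutive weights.
/// A sketch of Q8_0-style quantization for a single group (illustrative;
/// the packed byte layout of the real weight tensor is backend-defined):
///
/// ```
/// fn quantize_q8_0(group: &[f32]) -> (f32, Vec<i8>) {
///     let amax = group.iter().fold(0.0f32, |m, &v| m.max(v.abs()));
///     let scale = amax / 127.0;
///     let q = group
///         .iter()
///         .map(|&v| if scale == 0.0 { 0 } else { (v / scale).round() as i8 })
///         .collect();
///     (scale, q) // dequantize: w[i] ~= scale * q[i] as f32
/// }
/// let (_scale, q) = quantize_q8_0(&[0.5, -1.0, 0.25, 1.0]);
/// assert_eq!(q[1], -127);
/// ```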
#[derive(Debug, Clone)]
pub enum QuantScheme {
    /// 4-bit weights, one scale per `group_size` values.
    Q4_0 { group_size: usize },
    /// 8-bit weights, one scale per `group_size` values.
    Q8_0 { group_size: usize },
}
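
/// Knobs for sampling a token from the logits distribution.
///
/// A sketch of how `temperature` and `top_k` are conventionally applied
/// before the softmax/sample step (the exact order and semantics here are
/// an assumption, not a contract of `SamplingOps`):
///
/// ```
/// fn scale_and_top_k(logits: &mut Vec<(usize, f32)>, temperature: f32, top_k: usize) {
///     for (_, l) in logits.iter_mut() {
///         *l /= temperature; // temperature < 1.0 sharpens, > 1.0 flattens
///     }
///     logits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
///     logits.truncate(top_k); // keep only the top_k most likely candidates
/// }
/// let mut cands: Vec<(usize, f32)> = vec![(0, 1.0), (1, 3.0), (2, 2.0)];
/// scale_and_top_k(&mut cands, 1.0, 2);
/// assert_eq!(cands, vec![(1, 3.0), (2, 2.0)]);
/// ```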
#[derive(Debug, Clone)]
pub struct SamplingParams {
    pub temperature: f32,
    pub top_k: usize,
    pub top_p: f32,
    pub repetition_penalty: f32,
    /// Token ids to penalize, paired element-wise with `repetition_token_freqs`.
    pub repetition_token_ids: Vec<u32>,
    pub repetition_token_freqs: Vec<u32>,
    /// Seed for the sampler's random number generator.
    pub rng_seed: u32,
}
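
/// Fused normalization kernels.
///
/// A scalar sketch of what `rms_norm` computes per hidden vector, assuming
/// the conventional RMSNorm definition (the kernel itself operates on whole
/// tensors):
///
/// ```
/// fn rms_norm_scalar(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
///     let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
///     let inv_rms = 1.0 / (mean_sq + eps).sqrt();
///     x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
/// }
/// let y = rms_norm_scalar(&[1.0, 1.0], &[2.0, 2.0], 0.0);
/// assert_eq!(y, vec![2.0, 2.0]);
/// ```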
pub trait NormOps: Send + Sync {
    fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef>;

    /// Fused residual-add + RMSNorm. The default reports the operation as
    /// unsupported so that backends without a fused kernel need not
    /// implement it.
    fn rms_norm_residual(
        &self,
        input: &TensorRef,
        residual: &TensorRef,
        weight: &TensorRef,
        eps: f32,
    ) -> Result<(TensorRef, TensorRef)> {
        let _ = (input, residual, weight, eps);
        Err(ferrum_types::FerrumError::unsupported(
            "rms_norm_residual not implemented",
        ))
    }
}
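
/// Positional-encoding kernels.
///
/// A sketch of the rotation `rotary_embedding` applies to one dimension
/// pair at one position (whether pairs are interleaved or split in halves
/// is a backend layout choice this sketch does not fix):
///
/// ```
/// fn rotate_pair(x0: f32, x1: f32, cos: f32, sin: f32) -> (f32, f32) {
///     (x0 * cos - x1 * sin, x0 * sin + x1 * cos)
/// }
/// // Rotating (1, 0) by 90 degrees yields (0, 1).
/// assert_eq!(rotate_pair(1.0, 0.0, 0.0, 1.0), (0.0, 1.0));
/// ```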
pub trait PositionOps: Send + Sync {
    /// Applies rotary embeddings to `x`, reading the precomputed
    /// `cos_cache` / `sin_cache` rows selected by `position_ids`.
    fn rotary_embedding(
        &self,
        x: &TensorRef,
        cos_cache: &TensorRef,
        sin_cache: &TensorRef,
        position_ids: &[usize],
    ) -> Result<TensorRef>;
}
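
/// Attention kernels.
///
/// A naive single-query sketch of what `attention` computes,
/// `softmax(q K^T * softmax_scale) V`, ignoring heads, batching, and
/// masking (the fused kernels exist precisely to avoid this formulation):
///
/// ```
/// fn attend_one(q: &[f32], k: &[Vec<f32>], v: &[Vec<f32>], scale: f32) -> Vec<f32> {
///     // Scaled dot-product scores against every key row.
///     let mut scores: Vec<f32> = k
///         .iter()
///         .map(|ki| ki.iter().zip(q).map(|(a, b)| a * b).sum::<f32>() * scale)
///         .collect();
///     // Numerically stable softmax.
///     let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
///     let denom: f32 = scores.iter_mut().map(|s| { *s = (*s - max).exp(); *s }).sum();
///     // Probability-weighted sum of the value rows.
///     let mut out = vec![0.0; v[0].len()];
///     for (p, vi) in scores.iter().zip(v) {
///         for (o, x) in out.iter_mut().zip(vi) {
///             *o += p / denom * x;
///         }
///     }
///     out
/// }
/// assert_eq!(attend_one(&[1.0], &[vec![1.0]], &[vec![5.0]], 1.0), vec![5.0]);
/// ```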
pub trait AttentionOps: Send + Sync {
    /// Scaled dot-product attention over contiguous `q`, `k`, `v` tensors.
    fn attention(
        &self,
        q: &TensorRef,
        k: &TensorRef,
        v: &TensorRef,
        params: &AttentionParams,
    ) -> Result<TensorRef>;

    /// Attention over a paged KV cache addressed through `block_table`.
    /// The default reports the operation as unsupported.
    fn paged_attention(
        &self,
        q: &TensorRef,
        k_cache: &TensorRef,
        v_cache: &TensorRef,
        block_table: &[u32],
        params: &AttentionParams,
    ) -> Result<TensorRef> {
        let _ = (q, k_cache, v_cache, block_table, params);
        Err(ferrum_types::FerrumError::unsupported(
            "paged_attention not implemented",
        ))
    }
}
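
/// Activation kernels.
///
/// `silu_mul` is the fused gate used in SwiGLU-style MLPs; a scalar sketch
/// of the same computation (this matches the unfused fallback in
/// `KernelOpsDispatch::silu_mul`):
///
/// ```
/// fn silu_mul_scalar(gate: f32, up: f32) -> f32 {
///     let silu = gate / (1.0 + (-gate).exp()); // silu(x) = x * sigmoid(x)
///     silu * up
/// }
/// assert_eq!(silu_mul_scalar(0.0, 3.0), 0.0); // silu(0) = 0
/// ```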
pub trait ActivationOps: Send + Sync {
    /// Fused `silu(gate) * up`.
    fn silu_mul(&self, gate: &TensorRef, up: &TensorRef) -> Result<TensorRef>;
    fn gelu(&self, input: &TensorRef) -> Result<TensorRef>;
}

/// Linear-layer kernels.
pub trait LinearOps: Send + Sync {
    fn linear(&self, input: &TensorRef, weight: &TensorRef) -> Result<TensorRef>;

    /// Linear layer over a packed, quantized weight tensor. The default
    /// reports the operation as unsupported.
    fn quantized_linear(
        &self,
        input: &TensorRef,
        packed_weight: &TensorRef,
        scheme: &QuantScheme,
    ) -> Result<TensorRef> {
        let _ = (input, packed_weight, scheme);
        Err(ferrum_types::FerrumError::unsupported(
            "quantized_linear not implemented",
        ))
    }
}
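
/// Token-sampling kernels.
///
/// A sketch of greedy `argmax` over a plain logits slice; `f32` is not
/// `Ord`, hence the `partial_cmp` dance (the real kernel reads a
/// `TensorRef` instead):
///
/// ```
/// fn argmax_scalar(logits: &[f32]) -> u32 {
///     logits
///         .iter()
///         .enumerate()
///         .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
///         .map(|(i, _)| i as u32)
///         .expect("empty logits")
/// }
/// assert_eq!(argmax_scalar(&[0.1, 2.5, -1.0]), 1);
/// ```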
pub trait SamplingOps: Send + Sync {
    /// Samples a token id from `logits` according to `params`.
    fn sample_token(&self, logits: &TensorRef, params: &SamplingParams) -> Result<u32>;
    /// Greedy decoding: returns the index of the largest logit.
    fn argmax(&self, logits: &TensorRef) -> Result<u32>;
}

/// The bundle of optional fused-kernel families a backend exposes. Every
/// accessor defaults to `None`, so a backend only overrides the families
/// it actually accelerates.
pub trait KernelOps: Send + Sync {
    fn norm_ops(&self) -> Option<&dyn NormOps> {
        None
    }
    fn position_ops(&self) -> Option<&dyn PositionOps> {
        None
    }
    fn attention_ops(&self) -> Option<&dyn AttentionOps> {
        None
    }
    fn activation_ops(&self) -> Option<&dyn ActivationOps> {
        None
    }
    fn linear_ops(&self) -> Option<&dyn LinearOps> {
        None
    }
    fn sampling_ops(&self) -> Option<&dyn SamplingOps> {
        None
    }
    fn backend_name(&self) -> &str;
}
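
/// Dispatch layer that prefers a backend's fused kernels and falls back to
/// generic `TensorOps` composition when a kernel family is absent.
///
/// A usage sketch (`my_backend` and `cpu_tensor_ops` are hypothetical
/// stand-ins for real implementations):
///
/// ```ignore
/// let dispatch = KernelOpsDispatch::new(Some(&my_backend), &cpu_tensor_ops);
/// // Fused path if my_backend.norm_ops() is Some, generic fallback otherwise.
/// let normed = dispatch.rms_norm(&hidden, &norm_weight, 1e-5)?;
/// ```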
pub struct KernelOpsDispatch<'a> {
    kernel_ops: Option<&'a dyn KernelOps>,
    tensor_ops: &'a dyn crate::TensorOps,
}

impl<'a> KernelOpsDispatch<'a> {
    pub fn new(
        kernel_ops: Option<&'a dyn KernelOps>,
        tensor_ops: &'a dyn crate::TensorOps,
    ) -> Self {
        Self {
            kernel_ops,
            tensor_ops,
        }
    }

    pub fn rms_norm(&self, input: &TensorRef, weight: &TensorRef, eps: f32) -> Result<TensorRef> {
        if let Some(ko) = self.kernel_ops {
            if let Some(norm) = ko.norm_ops() {
                return norm.rms_norm(input, weight, eps);
            }
        }
        self.tensor_ops.rms_norm(input, weight, eps)
    }

    pub fn gelu(&self, input: &TensorRef) -> Result<TensorRef> {
        if let Some(ko) = self.kernel_ops {
            if let Some(act) = ko.activation_ops() {
                return act.gelu(input);
            }
        }
        self.tensor_ops.gelu(input)
    }

    pub fn silu(&self, input: &TensorRef) -> Result<TensorRef> {
        self.tensor_ops.silu(input)
    }

    pub fn silu_mul(&self, gate: &TensorRef, up: &TensorRef) -> Result<TensorRef> {
        if let Some(ko) = self.kernel_ops {
            if let Some(act) = ko.activation_ops() {
                return act.silu_mul(gate, up);
            }
        }
        // Unfused fallback: silu(gate) * up composed from generic tensor ops.
        let activated = self.tensor_ops.silu(gate)?;
        self.tensor_ops.mul(&activated, up)
    }

    pub fn linear(&self, input: &TensorRef, weight: &TensorRef) -> Result<TensorRef> {
        if let Some(ko) = self.kernel_ops {
            if let Some(lin) = ko.linear_ops() {
                return lin.linear(input, weight);
            }
        }
        self.tensor_ops.matmul(input, weight)
    }

    pub fn softmax(&self, input: &TensorRef, dim: i32) -> Result<TensorRef> {
        self.tensor_ops.softmax(input, dim)
    }

    pub fn kernel_ops(&self) -> Option<&'a dyn KernelOps> {
        self.kernel_ops
    }

    pub fn tensor_ops(&self) -> &'a dyn crate::TensorOps {
        self.tensor_ops
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    struct EmptyKernelOps;

    impl KernelOps for EmptyKernelOps {
        fn backend_name(&self) -> &str {
            "empty"
        }
    }

    #[test]
    fn test_empty_kernel_ops_returns_none() {
        let ops = EmptyKernelOps;
        assert!(ops.norm_ops().is_none());
        assert!(ops.position_ops().is_none());
        assert!(ops.attention_ops().is_none());
        assert!(ops.activation_ops().is_none());
        assert!(ops.linear_ops().is_none());
        assert!(ops.sampling_ops().is_none());
        assert_eq!(ops.backend_name(), "empty");
    }

    #[test]
    fn test_rope_config_construction() {
        let cfg = RoPEConfig {
            head_dim: 128,
            max_seq_len: 2048,
            theta: 10000.0,
        };
        assert_eq!(cfg.head_dim, 128);
    }

    #[test]
    fn test_attention_params() {
        let params = AttentionParams {
            num_heads: 32,
            num_kv_heads: 8,
            head_dim: 128,
            softmax_scale: 1.0 / (128.0_f32).sqrt(),
            causal: true,
        };
        assert!(params.causal);
        // 32 query heads over 8 KV heads: 4 query heads share each KV head.
        assert_eq!(params.num_heads / params.num_kv_heads, 4);
    }

    #[test]
    fn test_quant_scheme_variants() {
        let q4 = QuantScheme::Q4_0 { group_size: 32 };
        let q8 = QuantScheme::Q8_0 { group_size: 128 };
        match q4 {
            QuantScheme::Q4_0 { group_size } => assert_eq!(group_size, 32),
            _ => panic!("expected Q4_0"),
        }
        match q8 {
            QuantScheme::Q8_0 { group_size } => assert_eq!(group_size, 128),
            _ => panic!("expected Q8_0"),
        }
    }
}