use half::{bf16, f16};
use super::traits::{Backend, BackendKvDtype};
use ferrum_interfaces::kv_dtype::{KvDtypeKind, KvFp16};
/// Element type of a raw source byte buffer.
///
/// The inherent `impl` on this type reports the per-element byte width
/// (`bytes_per_elem`) and decodes little-endian bytes into `f32` values
/// (`to_f32_vec`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SrcDtype {
/// 32-bit float; 4 bytes per element.
F32,
/// 16-bit float decoded via `half::f16`; 2 bytes per element.
F16,
/// 16-bit float decoded via `half::bf16`; 2 bytes per element.
BF16,
}
impl SrcDtype {
    /// Returns the size in bytes of a single element of this dtype.
    pub const fn bytes_per_elem(self) -> usize {
        match self {
            SrcDtype::F32 => 4,
            SrcDtype::F16 | SrcDtype::BF16 => 2,
        }
    }

    /// Decodes `raw` as tightly packed little-endian elements of this dtype
    /// and widens each one to `f32`.
    ///
    /// Returns `raw.len() / self.bytes_per_elem()` values in input order; an
    /// empty slice yields an empty vector.
    ///
    /// # Panics
    ///
    /// When debug assertions are enabled, panics if `raw.len()` is not a
    /// multiple of the element width. With debug assertions disabled, a
    /// trailing partial element is silently ignored.
    pub fn to_f32_vec(self, raw: &[u8]) -> Vec<f32> {
        // One alignment check for all arms, derived from the same width table
        // that `bytes_per_elem` exposes so the two cannot drift apart.
        debug_assert_eq!(raw.len() % self.bytes_per_elem(), 0);

        // `chunks_exact` yields only full-width chunks (dropping any partial
        // tail), and the mapped iterator has an exact length, so `collect`
        // allocates once without zero-filling first.
        match self {
            SrcDtype::F32 => raw
                .chunks_exact(4)
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect(),
            SrcDtype::F16 => raw
                .chunks_exact(2)
                .map(|c| f16::from_le_bytes([c[0], c[1]]).to_f32())
                .collect(),
            SrcDtype::BF16 => raw
                .chunks_exact(2)
                .map(|c| bf16::from_le_bytes([c[0], c[1]]).to_f32())
                .collect(),
        }
    }
}
/// Quantization scheme selector, carrying the parameters specific to each scheme.
#[derive(Clone, Debug)]
pub enum QuantKind {
/// GPTQ scheme.
Gptq {
/// NOTE(review): presumably the bit-width of each quantized value — confirm.
bits: u32,
/// NOTE(review): presumably the number of weights sharing one scale/zero entry — confirm.
group_size: usize,
/// NOTE(review): presumably GPTQ's activation-order (`desc_act`) flag — confirm against the loader.
desc_act: bool,
},
/// AWQ scheme. NOTE(review): `bits` / `group_size` presumably mean the same as in `Gptq` — confirm.
Awq { bits: u32, group_size: usize },
/// GGUF scheme, identified by a [`GgufQuantType`].
Gguf { quant_type: GgufQuantType },
}
/// GGUF quantization type tag, carried by [`QuantKind::Gguf`].
///
/// NOTE(review): the variant names appear to mirror the GGML/GGUF block
/// formats of the same names (e.g. `Q4_0`, `Q4_K`) — confirm against the
/// code that maps file metadata to these variants.
#[derive(Clone, Copy, Debug)]
pub enum GgufQuantType {
Q4_0,
Q4_1,
Q4K,
Q5K,
Q6K,
Q8_0,
}
/// Borrowed view over the backend buffers that belong to one quantized weight.
///
/// Only `qweight` is mandatory; the remaining buffers are optional and are
/// borrowed for the same lifetime `'a`.
pub struct QuantWeights<'a, B: Backend> {
/// Quantized weight buffer (always present).
pub qweight: &'a B::Buffer,
/// Optional scales buffer. NOTE(review): presumably dequantization scales — confirm.
pub scales: Option<&'a B::Buffer>,
/// Optional zeros buffer. NOTE(review): presumably quantization zero-points — confirm.
pub zeros: Option<&'a B::Buffer>,
/// Optional `g_idx` buffer. NOTE(review): presumably the GPTQ group-index table — confirm.
pub g_idx: Option<&'a B::Buffer>,
}
/// Selector for a reduction operator.
///
/// NOTE(review): the code that consumes this enum is outside this view, so
/// the element types and axes it applies to are not documented here.
#[derive(Clone, Copy, Debug)]
pub enum ReduceOp {
/// Reduce by summation.
Sum,
/// Reduce by taking the maximum.
Max,
/// Reduce by taking the minimum.
Min,
}
/// Parameters describing one attention invocation.
///
/// Construct with struct-update syntax on top of [`Default`], e.g.
/// `AttnConfig { num_heads: 8, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct AttnConfig {
    /// Number of attention heads. NOTE(review): presumably query heads — confirm.
    pub num_heads: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Per-head dimension.
    pub head_dim: usize,
    /// Whether causal masking is requested.
    pub causal: bool,
    /// Scale factor. NOTE(review): presumably applied to the attention scores — confirm.
    pub scale: f32,
    /// NOTE(review): presumably the stride between consecutive sequence
    /// positions in the KV buffers; the meaning of `0` is not visible here.
    pub kv_seq_stride: usize,
    /// NOTE(review): presumably the sliding-window length, with `0` acting as
    /// "no window" — confirm against the kernels that read it.
    pub sliding_window: usize,
}

impl Default for AttnConfig {
    /// Returns a configuration with every count and stride at `0`, `causal`
    /// off, and `scale` at `1.0`.
    fn default() -> Self {
        // `scale` is the one field whose default differs from its type's own
        // default (`0.0`), which is why this impl is written by hand rather
        // than derived.
        AttnConfig {
            scale: 1.0,
            causal: false,
            num_heads: 0,
            num_kv_heads: 0,
            head_dim: 0,
            kv_seq_stride: 0,
            sliding_window: 0,
        }
    }
}
/// Key/value cache storage for backend `B`, plus its bookkeeping fields.
///
/// `K` is a type-level tag for the cache dtype (defaulting to `KvFp16`).
/// Within this struct it is used only by the `_kv_dtype` marker; `k` and `v`
/// are plain `B::Buffer`s regardless of `K`.
pub struct KvCache<B: Backend, K: KvDtypeKind = KvFp16> {
/// Buffer holding cached keys.
pub k: B::Buffer,
/// Buffer holding cached values.
pub v: B::Buffer,
/// NOTE(review): presumably the number of positions currently stored — confirm units.
pub len: usize,
/// NOTE(review): presumably the maximum number of positions the buffers can hold — confirm units.
pub capacity: usize,
/// Number of key/value heads.
pub num_kv_heads: usize,
/// Per-head dimension.
pub head_dim: usize,
/// NOTE(review): presumably the page size for paged storage; meaning for non-paged caches not visible here.
pub block_size: usize,
/// Optional backend buffer. NOTE(review): presumably the paged-attention block table — confirm.
pub block_table: Option<B::Buffer>,
/// Optional backend buffer. NOTE(review): presumably per-sequence context lengths — confirm.
pub context_lens: Option<B::Buffer>,
/// Host-side list of `u32` block indices. NOTE(review): presumably the pages assigned to this cache — confirm.
pub paged_block_indices: Vec<u32>,
/// Zero-sized marker that ties the struct to `K` without storing a value of it.
pub _kv_dtype: std::marker::PhantomData<K>,
}
/// Key/value cache whose storage types are chosen by the backend per dtype tag `K`.
///
/// Unlike [`KvCache`], `k` / `v` use `<B as BackendKvDtype<K>>::KvBuffer` and
/// are accompanied by separate `KvScales` values for keys and values. The
/// remaining bookkeeping fields have the same names and types as in `KvCache`.
pub struct KvCacheQuant<B: BackendKvDtype<K>, K: KvDtypeKind> {
/// Key storage in the backend's `K`-specific buffer type.
pub k: <B as BackendKvDtype<K>>::KvBuffer,
/// Value storage in the backend's `K`-specific buffer type.
pub v: <B as BackendKvDtype<K>>::KvBuffer,
/// Scales associated with `k`. NOTE(review): granularity (per-head / per-token / per-block) not visible here.
pub k_scales: <B as BackendKvDtype<K>>::KvScales,
/// Scales associated with `v`. NOTE(review): granularity not visible here.
pub v_scales: <B as BackendKvDtype<K>>::KvScales,
/// NOTE(review): presumably the number of positions currently stored — confirm units.
pub len: usize,
/// NOTE(review): presumably the maximum number of positions the buffers can hold — confirm units.
pub capacity: usize,
/// Number of key/value heads.
pub num_kv_heads: usize,
/// Per-head dimension.
pub head_dim: usize,
/// NOTE(review): presumably the page size for paged storage — confirm.
pub block_size: usize,
/// Optional backend buffer. NOTE(review): presumably the paged-attention block table — confirm.
pub block_table: Option<B::Buffer>,
/// Optional backend buffer. NOTE(review): presumably per-sequence context lengths — confirm.
pub context_lens: Option<B::Buffer>,
/// Host-side list of `u32` block indices. NOTE(review): presumably the pages assigned to this cache — confirm.
pub paged_block_indices: Vec<u32>,
/// Zero-sized marker for `K`.
pub _kv_dtype: std::marker::PhantomData<K>,
}
/// Three backend buffers describing a mixture-of-experts routing result.
///
/// `B` may be unsized (`?Sized`).
///
/// NOTE(review): the field names match the outputs of a "MoE align block
/// size" style kernel (token ids sorted by expert, per-block expert ids,
/// padded token count); the producer is outside this view — confirm the
/// layout and element types against it.
pub struct MoeRouting<B: Backend + ?Sized> {
/// NOTE(review): presumably token indices grouped/sorted by assigned expert — confirm.
pub sorted_token_ids: B::Buffer,
/// NOTE(review): presumably the expert id for each block of sorted tokens — confirm.
pub expert_ids: B::Buffer,
/// NOTE(review): presumably the total token count after padding — confirm.
pub num_tokens_past_padded: B::Buffer,
}