Skip to main content

ferrum_kernels/backend/
types.rs

1//! Backend data types shared by the capability traits and model code.
2
3use half::{bf16, f16};
4
5use super::traits::{Backend, BackendKvDtype};
6use ferrum_interfaces::kv_dtype::{KvDtypeKind, KvFp16};
7
/// On-disk element type of a weight tensor, as read directly from a
/// safetensors mmap.
///
/// Handed to `Backend::from_weight_bytes` so that each backend can decide
/// whether to upcast the bytes to its compute dtype or keep them as-is.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SrcDtype {
    /// 32-bit float.
    F32,
    /// 16-bit IEEE half precision.
    F16,
    /// 16-bit bfloat16.
    BF16,
}
18
19impl SrcDtype {
20    /// Number of bytes per element in the raw on-disk representation.
21    pub const fn bytes_per_elem(self) -> usize {
22        match self {
23            SrcDtype::F32 => 4,
24            SrcDtype::F16 | SrcDtype::BF16 => 2,
25        }
26    }
27
28    /// Materialise the raw byte slice into a `Vec<f32>`. Used by the default
29    /// `Backend::from_weight_bytes` impl; fp16-preferring backends bypass it.
30    pub fn to_f32_vec(self, raw: &[u8]) -> Vec<f32> {
31        match self {
32            SrcDtype::F32 => {
33                debug_assert_eq!(raw.len() % 4, 0);
34                let n = raw.len() / 4;
35                let mut out = vec![0f32; n];
36                for i in 0..n {
37                    let b = [raw[i * 4], raw[i * 4 + 1], raw[i * 4 + 2], raw[i * 4 + 3]];
38                    out[i] = f32::from_le_bytes(b);
39                }
40                out
41            }
42            SrcDtype::F16 => {
43                debug_assert_eq!(raw.len() % 2, 0);
44                let n = raw.len() / 2;
45                let mut out = vec![0f32; n];
46                for i in 0..n {
47                    out[i] = f16::from_le_bytes([raw[i * 2], raw[i * 2 + 1]]).to_f32();
48                }
49                out
50            }
51            SrcDtype::BF16 => {
52                debug_assert_eq!(raw.len() % 2, 0);
53                let n = raw.len() / 2;
54                let mut out = vec![0f32; n];
55                for i in 0..n {
56                    out[i] = bf16::from_le_bytes([raw[i * 2], raw[i * 2 + 1]]).to_f32();
57                }
58                out
59            }
60        }
61    }
62}
63
64/// Quantization flavour discriminator for `Backend::gemm_quant`.
65///
66/// Distinct schemes need distinct kernels. Carried as a parameter so the
67/// Backend trait does not explode with one method per quantization type.
68#[derive(Clone, Debug)]
69pub enum QuantKind {
70    /// GPTQ: group-wise int4/int8 with scales + zeros (asymmetric) + optional g_idx.
71    Gptq {
72        bits: u32,
73        group_size: usize,
74        desc_act: bool,
75    },
76    /// AWQ: activation-aware int4 with scales + zeros, different packing from GPTQ.
77    Awq { bits: u32, group_size: usize },
78    /// GGUF: one of k-quants / legacy quants, fully specified by the inner type.
79    Gguf { quant_type: GgufQuantType },
80}
81
/// GGUF quantization sub-type (extend as kernels are added).
///
/// Derives `PartialEq`, `Eq` and `Hash`, like the other small `Copy`
/// discriminators in this module. Callers can then compare sub-types with `==`
/// and use them as map or set keys, for example when caching per-format
/// kernels.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum GgufQuantType {
    // Legacy block quants.
    Q4_0,
    Q4_1,
    // k-quants.
    Q4K,
    Q5K,
    Q6K,
    // Legacy 8-bit block quant.
    Q8_0,
}
92
93/// Packed quantized weight buffers passed to `Backend::gemm_quant`.
94///
95/// Not every field is used by every `QuantKind` — e.g. GGUF packs scales
96/// inside `qweight`, so `scales` / `zeros` may be dummies. The Backend
97/// implementation is expected to validate the shape for the kind it handles.
98pub struct QuantWeights<'a, B: Backend> {
99    pub qweight: &'a B::Buffer,
100    pub scales: Option<&'a B::Buffer>,
101    pub zeros: Option<&'a B::Buffer>,
102    pub g_idx: Option<&'a B::Buffer>,
103}
104
/// Reduction applied by the tensor-parallel `all_reduce` collective.
///
/// Derives `PartialEq`, `Eq` and `Hash` so callers can compare or key on the
/// op (e.g. to pick a kernel) without a `match`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Element-wise sum across participants.
    Sum,
    /// Element-wise maximum across participants.
    Max,
    /// Element-wise minimum across participants.
    Min,
}
112
/// Parameters that drive attention-kernel dispatch.
///
/// `AttnConfig::default()` gives zero heads and head size, non-causal
/// masking, a `scale` of `1.0`, a contiguous KV layout (`kv_seq_stride == 0`)
/// and no sliding window (`sliding_window == 0`).
#[derive(Clone, Debug)]
pub struct AttnConfig {
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads.
    pub num_kv_heads: usize,
    /// Per-head feature dimension.
    pub head_dim: usize,
    /// Whether causal masking is applied.
    pub causal: bool,
    /// Scale factor handed to the attention kernel.
    pub scale: f32,
    /// Row stride between consecutive head blocks in the KV buffer.
    ///
    /// `0` selects the legacy contiguous layout, where the stride is `kv_len`.
    /// Set it to `cache_capacity` when running flash attention over a
    /// pre-allocated cache in which only `kv_len` of `cache_capacity` slots are
    /// valid.
    pub kv_seq_stride: usize,
    /// Sliding-window attention width (Mistral v0.1, Gemma).
    ///
    /// `0` disables the window (plain causal attention). With `w > 0`, each
    /// query position attends to the previous `w` KV positions; the upper end
    /// is still bounded by `causal` and `pos_offset + qi + 1`.
    pub sliding_window: usize,
}

impl Default for AttnConfig {
    fn default() -> Self {
        // Every field is zero/false except `scale`, which is why this impl is
        // written by hand instead of derived.
        AttnConfig {
            scale: 1.0,
            causal: false,
            num_heads: 0,
            num_kv_heads: 0,
            head_dim: 0,
            kv_seq_stride: 0,
            sliding_window: 0,
        }
    }
}
146
147/// Per-layer KV cache. Each model owns its own `Vec<KvCache<B, K>>` per
148/// sequence. The `K: KvDtypeKind` parameter selects the cache element
149/// type — defaults to [`KvFp16`] so existing call sites that wrote
150/// `KvCache<B>` keep compiling unchanged.
151///
152/// Two layouts are supported, selected at allocation time:
153/// 1. **Contiguous** (default): `k`/`v` are `[num_kv_heads, capacity, head_dim]`
154///    f32 buffers. `block_size == 0` and `block_table` / `context_lens` are
155///    `None`. Original ferrum layout — used when `FERRUM_METAL_PAGED_KV` is
156///    unset.
157/// 2. **Paged** (vLLM-style): `k`/`v` are `[num_blocks, num_kv_heads,
158///    block_size, head_dim]` block pools. `block_size > 0` and
159///    `block_table` (`u32[max_num_blocks_per_seq]`) + `context_lens`
160///    (`u32[1]` single-seq for now) are populated. Multi-seq sharing
161///    is a Phase 4 concern; today every paged cache_id has its own
162///    pool but the kernel-level indirection works.
163///
164/// The `K` parameter is currently a phantom-type marker — the buffer
165/// fields stay `B::Buffer` regardless. Future PRs will switch backends
166/// to `BackendKvDtype<KvInt8>` etc. and the kernel dispatch will read
167/// `K::NAME` / `K::BYTES_PER_ELEM` to pick the right append / attention
168/// kernel without any `KvCache` struct change.
169pub struct KvCache<B: Backend, K: KvDtypeKind = KvFp16> {
170    pub k: B::Buffer,
171    pub v: B::Buffer,
172    pub len: usize,
173    pub capacity: usize,
174    pub num_kv_heads: usize,
175    pub head_dim: usize,
176    /// Paged: KV positions per physical block. `0` => contiguous layout.
177    pub block_size: usize,
178    /// Paged: `[max_num_blocks_per_seq]` u32 — logical → physical block.
179    pub block_table: Option<B::Buffer>,
180    /// Paged: `[1]` u32 — current context length for the kernel to read.
181    pub context_lens: Option<B::Buffer>,
182    /// Paged: host-side mirror of the physical block indices owned by
183    /// this cache. Lets the model's release path return blocks to the
184    /// shared allocator without reading them back from device.
185    pub paged_block_indices: Vec<u32>,
186    /// Marker — KV cache element type. Zero-sized.
187    pub _kv_dtype: std::marker::PhantomData<K>,
188}
189
190/// Quantized-KV cache (Dim 5 INT8 / future FP8 paths). Sibling of
191/// [`KvCache`] for backends that store K/V in a non-FP16 element type
192/// plus per-token per-kv-head scales.
193///
194/// Why a separate struct: the FP16 `KvCache<B, K>` uses `B::Buffer`
195/// uniformly, which is FP16 on every concrete backend. Stuffing INT8
196/// storage into that buffer would require unsafe transmutes; making
197/// the FP16 struct generic over the storage type would force every
198/// existing call site (4 model files, ~20 functions) to pick up an
199/// equality-bound on the associated type. Keeping a parallel struct
200/// for INT8 is the cheaper trade — the kernel launchers in
201/// [`crate::int8_kv`] take cudarc primitives directly anyway.
202///
203/// `KStorage` and `ScaleStorage` come from `BackendKvDtype<K>::KvBuffer`
204/// and `BackendKvDtype<K>::KvScales`. On CUDA they wrap `CudaSlice<i8>`
205/// and `CudaSlice<f16>`.
206pub struct KvCacheQuant<B: BackendKvDtype<K>, K: KvDtypeKind> {
207    pub k: <B as BackendKvDtype<K>>::KvBuffer,
208    pub v: <B as BackendKvDtype<K>>::KvBuffer,
209    pub k_scales: <B as BackendKvDtype<K>>::KvScales,
210    pub v_scales: <B as BackendKvDtype<K>>::KvScales,
211    pub len: usize,
212    pub capacity: usize,
213    pub num_kv_heads: usize,
214    pub head_dim: usize,
215    pub block_size: usize,
216    pub block_table: Option<B::Buffer>,
217    pub context_lens: Option<B::Buffer>,
218    pub paged_block_indices: Vec<u32>,
219    pub _kv_dtype: std::marker::PhantomData<K>,
220}
221
222/// Routing buffers consumed by `moe_gemm_phase_vllm` — held by the
223/// caller across phase 1 and phase 3 of one MoE forward. All three
224/// fields are i32 device tensors in disguise (`Self::Buffer = fp16` on
225/// CUDA; the backend reinterprets the underlying device pointer).
226pub struct MoeRouting<B: Backend + ?Sized> {
227    pub sorted_token_ids: B::Buffer,
228    pub expert_ids: B::Buffer,
229    pub num_tokens_past_padded: B::Buffer,
230}