ferrum_kernels/backend/types.rs
1//! Backend data types shared by the capability traits and model code.
2
3use half::{bf16, f16};
4
5use super::traits::{Backend, BackendKvDtype};
6use ferrum_interfaces::kv_dtype::{KvDtypeKind, KvFp16};
7
/// Element type of a weight tensor as it sits in the safetensors mmap.
///
/// Handed to `Backend::from_weight_bytes`, which lets every backend decide
/// between upcasting to its compute dtype and keeping the data as stored.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SrcDtype {
    /// 32-bit float elements (4 bytes each on disk).
    F32,
    /// Half-precision float elements (2 bytes each on disk).
    F16,
    /// bfloat16 elements (2 bytes each on disk).
    BF16,
}
18
19impl SrcDtype {
20 /// Number of bytes per element in the raw on-disk representation.
21 pub const fn bytes_per_elem(self) -> usize {
22 match self {
23 SrcDtype::F32 => 4,
24 SrcDtype::F16 | SrcDtype::BF16 => 2,
25 }
26 }
27
28 /// Materialise the raw byte slice into a `Vec<f32>`. Used by the default
29 /// `Backend::from_weight_bytes` impl; fp16-preferring backends bypass it.
30 pub fn to_f32_vec(self, raw: &[u8]) -> Vec<f32> {
31 match self {
32 SrcDtype::F32 => {
33 debug_assert_eq!(raw.len() % 4, 0);
34 let n = raw.len() / 4;
35 let mut out = vec![0f32; n];
36 for i in 0..n {
37 let b = [raw[i * 4], raw[i * 4 + 1], raw[i * 4 + 2], raw[i * 4 + 3]];
38 out[i] = f32::from_le_bytes(b);
39 }
40 out
41 }
42 SrcDtype::F16 => {
43 debug_assert_eq!(raw.len() % 2, 0);
44 let n = raw.len() / 2;
45 let mut out = vec![0f32; n];
46 for i in 0..n {
47 out[i] = f16::from_le_bytes([raw[i * 2], raw[i * 2 + 1]]).to_f32();
48 }
49 out
50 }
51 SrcDtype::BF16 => {
52 debug_assert_eq!(raw.len() % 2, 0);
53 let n = raw.len() / 2;
54 let mut out = vec![0f32; n];
55 for i in 0..n {
56 out[i] = bf16::from_le_bytes([raw[i * 2], raw[i * 2 + 1]]).to_f32();
57 }
58 out
59 }
60 }
61 }
62}
63
64/// Quantization flavour discriminator for `Backend::gemm_quant`.
65///
66/// Distinct schemes need distinct kernels. Carried as a parameter so the
67/// Backend trait does not explode with one method per quantization type.
68#[derive(Clone, Debug)]
69pub enum QuantKind {
70 /// GPTQ: group-wise int4/int8 with scales + zeros (asymmetric) + optional g_idx.
71 Gptq {
72 bits: u32,
73 group_size: usize,
74 desc_act: bool,
75 },
76 /// AWQ: activation-aware int4 with scales + zeros, different packing from GPTQ.
77 Awq { bits: u32, group_size: usize },
78 /// GGUF: one of k-quants / legacy quants, fully specified by the inner type.
79 Gguf { quant_type: GgufQuantType },
80}
81
/// GGUF quantization sub-type (expand as kernels are added).
///
/// A plain fieldless discriminator; it derives `PartialEq` / `Eq` like the
/// other dtype tags in this module so values can be compared with `==`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum GgufQuantType {
    Q4_0,
    Q4_1,
    Q4K,
    Q5K,
    Q6K,
    Q8_0,
}
92
93/// Packed quantized weight buffers passed to `Backend::gemm_quant`.
94///
95/// Not every field is used by every `QuantKind` — e.g. GGUF packs scales
96/// inside `qweight`, so `scales` / `zeros` may be dummies. The Backend
97/// implementation is expected to validate the shape for the kind it handles.
98pub struct QuantWeights<'a, B: Backend> {
99 pub qweight: &'a B::Buffer,
100 pub scales: Option<&'a B::Buffer>,
101 pub zeros: Option<&'a B::Buffer>,
102 pub g_idx: Option<&'a B::Buffer>,
103}
104
/// Collective-op reduction kind for TP all_reduce.
///
/// A plain fieldless discriminator; it derives `PartialEq` / `Eq` like the
/// other tag enums in this module so values can be compared with `==`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ReduceOp {
    Sum,
    Max,
    Min,
}
112
/// Parameters that steer attention dispatch.
#[derive(Debug, Clone)]
pub struct AttnConfig {
    pub num_heads: usize,
    pub num_kv_heads: usize,
    pub head_dim: usize,
    pub causal: bool,
    pub scale: f32,
    /// Row stride between head blocks inside the KV buffer.
    ///
    /// `0` selects the contiguous, legacy behaviour (`kv_len` is used).
    /// Use `cache_capacity` when flashing against a pre-allocated cache in
    /// which only `kv_len` of the `cache_capacity` slots hold valid data.
    pub kv_seq_stride: usize,
    /// Sliding-window attention width (Mistral v0.1, Gemma).
    ///
    /// `0` turns the window off (full causal attention). For `w > 0` every
    /// query position attends to the previous `w` KV positions, with the
    /// upper end still bounded by `causal` + `pos_offset + qi + 1`.
    pub sliding_window: usize,
}
132
133impl Default for AttnConfig {
134 fn default() -> Self {
135 Self {
136 num_heads: 0,
137 num_kv_heads: 0,
138 head_dim: 0,
139 causal: false,
140 scale: 1.0,
141 kv_seq_stride: 0,
142 sliding_window: 0,
143 }
144 }
145}
146
147/// Per-layer KV cache. Each model owns its own `Vec<KvCache<B, K>>` per
148/// sequence. The `K: KvDtypeKind` parameter selects the cache element
149/// type — defaults to [`KvFp16`] so existing call sites that wrote
150/// `KvCache<B>` keep compiling unchanged.
151///
152/// Two layouts are supported, selected at allocation time:
153/// 1. **Contiguous** (default): `k`/`v` are `[num_kv_heads, capacity, head_dim]`
154/// f32 buffers. `block_size == 0` and `block_table` / `context_lens` are
155/// `None`. Original ferrum layout — used when `FERRUM_METAL_PAGED_KV` is
156/// unset.
157/// 2. **Paged** (vLLM-style): `k`/`v` are `[num_blocks, num_kv_heads,
158/// block_size, head_dim]` block pools. `block_size > 0` and
159/// `block_table` (`u32[max_num_blocks_per_seq]`) + `context_lens`
160/// (`u32[1]` single-seq for now) are populated. Multi-seq sharing
161/// is a Phase 4 concern; today every paged cache_id has its own
162/// pool but the kernel-level indirection works.
163///
164/// The `K` parameter is currently a phantom-type marker — the buffer
165/// fields stay `B::Buffer` regardless. Future PRs will switch backends
166/// to `BackendKvDtype<KvInt8>` etc. and the kernel dispatch will read
167/// `K::NAME` / `K::BYTES_PER_ELEM` to pick the right append / attention
168/// kernel without any `KvCache` struct change.
169pub struct KvCache<B: Backend, K: KvDtypeKind = KvFp16> {
170 pub k: B::Buffer,
171 pub v: B::Buffer,
172 pub len: usize,
173 pub capacity: usize,
174 pub num_kv_heads: usize,
175 pub head_dim: usize,
176 /// Paged: KV positions per physical block. `0` => contiguous layout.
177 pub block_size: usize,
178 /// Paged: `[max_num_blocks_per_seq]` u32 — logical → physical block.
179 pub block_table: Option<B::Buffer>,
180 /// Paged: `[1]` u32 — current context length for the kernel to read.
181 pub context_lens: Option<B::Buffer>,
182 /// Paged: host-side mirror of the physical block indices owned by
183 /// this cache. Lets the model's release path return blocks to the
184 /// shared allocator without reading them back from device.
185 pub paged_block_indices: Vec<u32>,
186 /// Marker — KV cache element type. Zero-sized.
187 pub _kv_dtype: std::marker::PhantomData<K>,
188}
189
190/// Quantized-KV cache (Dim 5 INT8 / future FP8 paths). Sibling of
191/// [`KvCache`] for backends that store K/V in a non-FP16 element type
192/// plus per-token per-kv-head scales.
193///
194/// Why a separate struct: the FP16 `KvCache<B, K>` uses `B::Buffer`
195/// uniformly, which is FP16 on every concrete backend. Stuffing INT8
196/// storage into that buffer would require unsafe transmutes; making
197/// the FP16 struct generic over the storage type would force every
198/// existing call site (4 model files, ~20 functions) to pick up an
199/// equality-bound on the associated type. Keeping a parallel struct
200/// for INT8 is the cheaper trade — the kernel launchers in
201/// [`crate::int8_kv`] take cudarc primitives directly anyway.
202///
203/// `KStorage` and `ScaleStorage` come from `BackendKvDtype<K>::KvBuffer`
204/// and `BackendKvDtype<K>::KvScales`. On CUDA they wrap `CudaSlice<i8>`
205/// and `CudaSlice<f16>`.
206pub struct KvCacheQuant<B: BackendKvDtype<K>, K: KvDtypeKind> {
207 pub k: <B as BackendKvDtype<K>>::KvBuffer,
208 pub v: <B as BackendKvDtype<K>>::KvBuffer,
209 pub k_scales: <B as BackendKvDtype<K>>::KvScales,
210 pub v_scales: <B as BackendKvDtype<K>>::KvScales,
211 pub len: usize,
212 pub capacity: usize,
213 pub num_kv_heads: usize,
214 pub head_dim: usize,
215 pub block_size: usize,
216 pub block_table: Option<B::Buffer>,
217 pub context_lens: Option<B::Buffer>,
218 pub paged_block_indices: Vec<u32>,
219 pub _kv_dtype: std::marker::PhantomData<K>,
220}
221
222/// Routing buffers consumed by `moe_gemm_phase_vllm` — held by the
223/// caller across phase 1 and phase 3 of one MoE forward. All three
224/// fields are i32 device tensors in disguise (`Self::Buffer = fp16` on
225/// CUDA; the backend reinterprets the underlying device pointer).
226pub struct MoeRouting<B: Backend + ?Sized> {
227 pub sorted_token_ids: B::Buffer,
228 pub expert_ids: B::Buffer,
229 pub num_tokens_past_padded: B::Buffer,
230}