Skip to main content

ferrum_kernels/backend/
dtype.rs

1//! Runtime element type tag for typed device buffers.
2//!
3//! Used by the typed buffer handles (`MetalBuf`, `CpuBuf`, `CudaBuf`)
4//! to carry a dtype alongside the raw byte storage. This lets the
5//! existing Backend trait surface (which today exposes a monomorphic
6//! `Self::Buffer` and a pile of `from_slice_i32` / `alloc_u32` /
7//! `write_u32` / `from_slice_f32` etc. helpers) move toward a single
8//! `Self::alloc(Dtype, n)` / `Self::write(buf, &[T])` API without
9//! breaking callers in one PR.
10//!
11//! Phase A (PR A): defined here + Metal `MetalBuf` re-uses it (was
12//! Metal-internal previously). CPU `CpuBuf` and CUDA `CudaBuf` not
13//! yet wired — see Phase B for the migration.
14
/// Runtime element-type tag carried alongside a typed buffer's raw bytes.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Dtype {
    /// 32-bit IEEE float. Default activation / weight dtype on the
    /// CPU path; fallback for backends without F16 hw support.
    F32,
    /// 16-bit IEEE half. Hot-path dtype on CUDA + Metal (decode q,
    /// K/V, GEMM outputs).
    F16,
    /// 32-bit unsigned integer. Block tables, context lens, sorted
    /// token ids, args buffers — anything previously tunneled through
    /// an FP buffer via `alloc_u32` / `write_u32`.
    U32,
    /// 32-bit signed integer. Expert ids, position offsets,
    /// `cu_seqlens_q`, `tpe` (tokens-per-expert). Same byte width as
    /// `U32`; separate variant so kernel signatures
    /// (`device const int*` vs `device const uint*`) can stay
    /// type-honest at runtime.
    I32,
    /// 8-bit signed integer. INT8 quantized KV cache cells. Used by
    /// `KvCacheQuant<B, KvInt8>`'s paged stores.
    I8,
}

impl Dtype {
    /// Size in bytes of one element of this dtype.
    pub const fn bytes_per_elem(self) -> usize {
        match self {
            Dtype::F32 | Dtype::U32 | Dtype::I32 => 4,
            Dtype::F16 => 2,
            Dtype::I8 => 1,
        }
    }

    /// Total size in bytes of `n` elements of this dtype, or `None` if
    /// the product does not fit in `usize`.
    ///
    /// Prefer this over `n * dtype.bytes_per_elem()` when `n` is a
    /// caller-supplied element count: the plain multiplication panics in
    /// debug builds but wraps silently in release builds, which would
    /// produce an undersized allocation/copy length.
    pub const fn checked_byte_len(self, n: usize) -> Option<usize> {
        n.checked_mul(self.bytes_per_elem())
    }

    /// Human-readable tag for log lines / panic messages.
    pub const fn name(self) -> &'static str {
        match self {
            Dtype::F32 => "f32",
            Dtype::F16 => "f16",
            Dtype::U32 => "u32",
            Dtype::I32 => "i32",
            Dtype::I8 => "i8",
        }
    }
}
58
59/// Marker trait connecting a host element type `T` to its runtime
60/// `Dtype` tag. Used by the typed Backend allocator + uploader so the
61/// trait surface has ONE `alloc_typed(Dtype, n)` /
62/// `from_slice_typed<T>(&[T])` / `write_typed<T>(buf, &[T])` instead
63/// of the per-dtype-named family (`alloc_u32`, `from_slice_i32`,
64/// `write_i32_into`, `write_f32_into`, ...).
65///
66/// Implemented for all dtypes that backends store in `Self::Buffer`.
67pub trait HostDtype: Copy + Send + Sync + 'static {
68    const DTYPE: Dtype;
69}
70
71impl HostDtype for u32 {
72    const DTYPE: Dtype = Dtype::U32;
73}
74impl HostDtype for i32 {
75    const DTYPE: Dtype = Dtype::I32;
76}
77impl HostDtype for f32 {
78    const DTYPE: Dtype = Dtype::F32;
79}
80impl HostDtype for half::f16 {
81    const DTYPE: Dtype = Dtype::F16;
82}
83impl HostDtype for i8 {
84    const DTYPE: Dtype = Dtype::I8;
85}