Skip to main content

ferrum_kernels/backend/
buffer.rs

1//! Typed buffer wrappers — Phase B foundation.
2//!
//! Each backend grows a wrapper type that carries a runtime `Dtype`
//! tag alongside the raw storage (for the `CpuBuf` / `CudaBuf` enums below,
//! the tag is the active variant). The wrappers are **not yet wired** to
5//! `Backend::Buffer` (that's Phase B-2: switch `type Buffer = ...` per
6//! backend + migrate all callsites). For now they exist as separate
7//! types so Phase B-2 can be staged without breaking everything in
8//! one go.
9//!
10//! Pattern mirrors the existing `MetalBuf` (see `backend::metal::MetalBuf`)
11//! which has been carrying a dtype tag since the Metal backend shipped.
//! Generalising it to CPU + CUDA lets the migration below retire the `from_slice_i32` /
13//! `alloc_u32` / `write_u32` family of helpers that tunnel integer data
14//! through the FP-typed buffer.
15//!
16//! Migration plan (next PR):
17//! 1. Replace each backend's `type Buffer = <concrete>` with the typed
18//!    wrapper.
19//! 2. Add forwarder methods so existing call sites compile unchanged
20//!    (`buf.as_f16()` / `buf.as_f32_slice()` / `buf.as_u32_mut()` etc.).
21//! 3. Replace `alloc_u32` / `from_slice_i32` / `write_u32` with one
22//!    `Self::alloc(Dtype, n)` and one `Self::write_typed(buf, &[T])`.
23//! 4. Delete the legacy helpers + their i32 bit-cast tunnels.
24
25use crate::backend::dtype::Dtype;
26use half::f16;
27
28/// CPU-side typed buffer. Variants per dtype keep storage typed (no
29/// `unsafe` bytemuck casting, no alignment concerns). Compared to a
30/// `(Vec<u8>, Dtype, n)` triple this trades 16 bytes of discriminant
31/// for type safety on a non-hot-path (CPU is the slow path; if you
32/// care about throughput here, switch backends).
33pub enum CpuBuf {
34    F32(Vec<f32>),
35    F16(Vec<f16>),
36    U32(Vec<u32>),
37    I32(Vec<i32>),
38    I8(Vec<i8>),
39}
40
41impl CpuBuf {
42    pub fn alloc(dtype: Dtype, n: usize) -> Self {
43        match dtype {
44            Dtype::F32 => CpuBuf::F32(vec![0.0; n]),
45            Dtype::F16 => CpuBuf::F16(vec![f16::ZERO; n]),
46            Dtype::U32 => CpuBuf::U32(vec![0u32; n]),
47            Dtype::I32 => CpuBuf::I32(vec![0i32; n]),
48            Dtype::I8 => CpuBuf::I8(vec![0i8; n]),
49        }
50    }
51
52    pub fn dtype(&self) -> Dtype {
53        match self {
54            CpuBuf::F32(_) => Dtype::F32,
55            CpuBuf::F16(_) => Dtype::F16,
56            CpuBuf::U32(_) => Dtype::U32,
57            CpuBuf::I32(_) => Dtype::I32,
58            CpuBuf::I8(_) => Dtype::I8,
59        }
60    }
61
62    pub fn len(&self) -> usize {
63        match self {
64            CpuBuf::F32(v) => v.len(),
65            CpuBuf::F16(v) => v.len(),
66            CpuBuf::U32(v) => v.len(),
67            CpuBuf::I32(v) => v.len(),
68            CpuBuf::I8(v) => v.len(),
69        }
70    }
71
72    pub fn is_empty(&self) -> bool {
73        self.len() == 0
74    }
75
76    /// Typed accessor — panics on dtype mismatch. Catches the silent
77    /// type-tunnel bugs the old `from_slice_i32` route would have
78    /// papered over.
79    pub fn as_f32(&self) -> &[f32] {
80        match self {
81            CpuBuf::F32(v) => v,
82            _ => panic!("CpuBuf::as_f32 on dtype {}", self.dtype().name()),
83        }
84    }
85    pub fn as_f32_mut(&mut self) -> &mut [f32] {
86        match self {
87            CpuBuf::F32(v) => v,
88            _ => panic!("CpuBuf::as_f32_mut on dtype {}", self.dtype().name()),
89        }
90    }
91    pub fn as_f16(&self) -> &[f16] {
92        match self {
93            CpuBuf::F16(v) => v,
94            _ => panic!("CpuBuf::as_f16 on dtype {}", self.dtype().name()),
95        }
96    }
97    pub fn as_u32(&self) -> &[u32] {
98        match self {
99            CpuBuf::U32(v) => v,
100            _ => panic!("CpuBuf::as_u32 on dtype {}", self.dtype().name()),
101        }
102    }
103    pub fn as_u32_mut(&mut self) -> &mut [u32] {
104        match self {
105            CpuBuf::U32(v) => v,
106            _ => panic!("CpuBuf::as_u32_mut on dtype {}", self.dtype().name()),
107        }
108    }
109    pub fn as_i32(&self) -> &[i32] {
110        match self {
111            CpuBuf::I32(v) => v,
112            _ => panic!("CpuBuf::as_i32 on dtype {}", self.dtype().name()),
113        }
114    }
115    pub fn as_i32_mut(&mut self) -> &mut [i32] {
116        match self {
117            CpuBuf::I32(v) => v,
118            _ => panic!("CpuBuf::as_i32_mut on dtype {}", self.dtype().name()),
119        }
120    }
121    pub fn as_i8(&self) -> &[i8] {
122        match self {
123            CpuBuf::I8(v) => v,
124            _ => panic!("CpuBuf::as_i8 on dtype {}", self.dtype().name()),
125        }
126    }
127
128    pub fn from_f32(data: Vec<f32>) -> Self {
129        CpuBuf::F32(data)
130    }
131    pub fn from_u32(data: Vec<u32>) -> Self {
132        CpuBuf::U32(data)
133    }
134    pub fn from_i32(data: Vec<i32>) -> Self {
135        CpuBuf::I32(data)
136    }
137}
138
/// Typed device buffer for the CUDA backend.
///
/// This is an enum over typed cudarc `CudaSlice<T>` variants. Once a
/// consumer has unwrapped the variant with `buf.as_f16()`,
/// `buf.as_u32_mut()`, etc., cudarc's typed alloc / dtoh / htod APIs keep
/// working unchanged.
///
/// Status: Phase B-1 defines the type but does not yet wire it to
/// `CudaBackend::Buffer`. Phase B-2 sets `CudaBackend::Buffer = CudaBuf`
/// and moves every `CudaSlice<f16>` consumer onto the typed accessors.
#[cfg(feature = "cuda")]
pub enum CudaBuf {
    /// `f32` elements.
    F32(cudarc::driver::CudaSlice<f32>),
    /// `half::f16` elements.
    F16(cudarc::driver::CudaSlice<f16>),
    /// `u32` elements.
    U32(cudarc::driver::CudaSlice<u32>),
    /// `i32` elements.
    I32(cudarc::driver::CudaSlice<i32>),
    /// `i8` elements.
    I8(cudarc::driver::CudaSlice<i8>),
}
154
#[cfg(feature = "cuda")]
impl CudaBuf {
    /// Runtime dtype tag, derived from the active variant.
    pub fn dtype(&self) -> Dtype {
        match self {
            CudaBuf::F32(_) => Dtype::F32,
            CudaBuf::F16(_) => Dtype::F16,
            CudaBuf::U32(_) => Dtype::U32,
            CudaBuf::I32(_) => Dtype::I32,
            CudaBuf::I8(_) => Dtype::I8,
        }
    }

    /// Length of the active `CudaSlice<T>`, as reported by `CudaSlice::len`.
    pub fn len(&self) -> usize {
        match self {
            CudaBuf::F32(s) => s.len(),
            CudaBuf::F16(s) => s.len(),
            CudaBuf::U32(s) => s.len(),
            CudaBuf::I32(s) => s.len(),
            CudaBuf::I8(s) => s.len(),
        }
    }

    /// `true` when [`len`](Self::len) is zero.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    // Typed accessors: each one borrows the inner `CudaSlice<T>` and panics
    // on a dtype mismatch. `#[track_caller]` makes the panic report the
    // location of the code that asked for the wrong dtype.

    /// Borrows the `F32` storage.
    ///
    /// # Panics
    /// If the active variant is not `F32`.
    #[track_caller]
    pub fn as_f32(&self) -> &cudarc::driver::CudaSlice<f32> {
        match self {
            CudaBuf::F32(s) => s,
            _ => self.dtype_mismatch("as_f32"),
        }
    }
    /// Mutably borrows the `F32` storage.
    ///
    /// # Panics
    /// If the active variant is not `F32`.
    #[track_caller]
    pub fn as_f32_mut(&mut self) -> &mut cudarc::driver::CudaSlice<f32> {
        match self {
            CudaBuf::F32(s) => s,
            _ => self.dtype_mismatch("as_f32_mut"),
        }
    }

    /// Borrows the `F16` storage.
    ///
    /// # Panics
    /// If the active variant is not `F16`.
    #[track_caller]
    pub fn as_f16(&self) -> &cudarc::driver::CudaSlice<f16> {
        match self {
            CudaBuf::F16(s) => s,
            _ => self.dtype_mismatch("as_f16"),
        }
    }
    /// Mutably borrows the `F16` storage.
    ///
    /// # Panics
    /// If the active variant is not `F16`.
    #[track_caller]
    pub fn as_f16_mut(&mut self) -> &mut cudarc::driver::CudaSlice<f16> {
        match self {
            CudaBuf::F16(s) => s,
            _ => self.dtype_mismatch("as_f16_mut"),
        }
    }

    /// Borrows the `U32` storage.
    ///
    /// # Panics
    /// If the active variant is not `U32`.
    #[track_caller]
    pub fn as_u32(&self) -> &cudarc::driver::CudaSlice<u32> {
        match self {
            CudaBuf::U32(s) => s,
            _ => self.dtype_mismatch("as_u32"),
        }
    }
    /// Mutably borrows the `U32` storage.
    ///
    /// # Panics
    /// If the active variant is not `U32`.
    #[track_caller]
    pub fn as_u32_mut(&mut self) -> &mut cudarc::driver::CudaSlice<u32> {
        match self {
            CudaBuf::U32(s) => s,
            _ => self.dtype_mismatch("as_u32_mut"),
        }
    }

    /// Borrows the `I32` storage.
    ///
    /// # Panics
    /// If the active variant is not `I32`.
    #[track_caller]
    pub fn as_i32(&self) -> &cudarc::driver::CudaSlice<i32> {
        match self {
            CudaBuf::I32(s) => s,
            _ => self.dtype_mismatch("as_i32"),
        }
    }
    /// Mutably borrows the `I32` storage.
    ///
    /// # Panics
    /// If the active variant is not `I32`.
    #[track_caller]
    pub fn as_i32_mut(&mut self) -> &mut cudarc::driver::CudaSlice<i32> {
        match self {
            CudaBuf::I32(s) => s,
            _ => self.dtype_mismatch("as_i32_mut"),
        }
    }

    /// Borrows the `I8` storage.
    ///
    /// # Panics
    /// If the active variant is not `I8`.
    #[track_caller]
    pub fn as_i8(&self) -> &cudarc::driver::CudaSlice<i8> {
        match self {
            CudaBuf::I8(s) => s,
            _ => self.dtype_mismatch("as_i8"),
        }
    }
    /// Mutably borrows the `I8` storage.
    ///
    /// # Panics
    /// If the active variant is not `I8`.
    #[track_caller]
    pub fn as_i8_mut(&mut self) -> &mut cudarc::driver::CudaSlice<i8> {
        match self {
            CudaBuf::I8(s) => s,
            _ => self.dtype_mismatch("as_i8_mut"),
        }
    }

    /// Constructors — used by `Backend::alloc` etc. Each wraps an existing
    /// slice without copying.
    pub fn from_f16(s: cudarc::driver::CudaSlice<f16>) -> Self {
        CudaBuf::F16(s)
    }
    pub fn from_f32(s: cudarc::driver::CudaSlice<f32>) -> Self {
        CudaBuf::F32(s)
    }
    pub fn from_u32(s: cudarc::driver::CudaSlice<u32>) -> Self {
        CudaBuf::U32(s)
    }
    pub fn from_i32(s: cudarc::driver::CudaSlice<i32>) -> Self {
        CudaBuf::I32(s)
    }
    pub fn from_i8(s: cudarc::driver::CudaSlice<i8>) -> Self {
        CudaBuf::I8(s)
    }

    /// Bit-reinterprets the underlying buffer as a view of another type.
    ///
    /// This dispatches on the active variant, so callers don't have to know
    /// which inner `CudaSlice<T>` holds the bytes. That helps when integer
    /// data was allocated as `CudaBuf::U32` (post-B-2) but a kernel expects
    /// an `&CudaView<i32>` (same byte pattern, signed view).
    ///
    /// `len` is in elements of the target type `T`; cudarc returns `None` if
    /// `len * size_of::<T>()` doesn't fit the source bytes.
    ///
    /// # Safety
    ///
    /// This forwards to cudarc's `CudaSlice::transmute` and inherits its
    /// contract unchanged. The caller must ensure the device bytes are a
    /// valid representation of `T` for however the returned view is used.
    /// This wrapper adds no checks of its own beyond what cudarc performs.
    pub unsafe fn transmute<T>(&self, len: usize) -> Option<cudarc::driver::CudaView<'_, T>> {
        // SAFETY: the caller upholds `CudaSlice::transmute`'s contract (see
        // `# Safety` above). Every arm forwards to it with the same `len`.
        unsafe {
            match self {
                CudaBuf::F32(s) => s.transmute(len),
                CudaBuf::F16(s) => s.transmute(len),
                CudaBuf::U32(s) => s.transmute(len),
                CudaBuf::I32(s) => s.transmute(len),
                CudaBuf::I8(s) => s.transmute(len),
            }
        }
    }

    /// Shared cold path for the typed accessors. It panics with
    /// `CudaBuf::<accessor> on dtype <actual>`, reported at the caller's
    /// location because of `#[track_caller]`.
    #[cold]
    #[track_caller]
    fn dtype_mismatch(&self, accessor: &str) -> ! {
        panic!("CudaBuf::{} on dtype {}", accessor, self.dtype().name())
    }
}
276
// Existing `launch_builder.arg(&buf)` callsites should keep compiling once
// buffers become `CudaBuf`, without each one switching to
// `.arg(buf.as_f16())`. This impl and its `&mut` twin make `CudaBuf` a valid
// kernel argument. Each forwards to the inner CudaSlice's own
// `PushKernelArg<&CudaSlice<T>>` impl for whichever variant is active.
#[cfg(feature = "cuda")]
// SAFETY: no new argument encoding is introduced. Each arm hands the inner
// `&'b CudaSlice<T>` to cudarc's existing `PushKernelArg` impl, so that
// impl's guarantees carry over unchanged. The `'b: 'a` bound requires the
// borrowed buffer to outlive the `LaunchArgs<'a>` that records it.
unsafe impl<'a, 'b: 'a> cudarc::driver::PushKernelArg<&'b CudaBuf>
    for cudarc::driver::LaunchArgs<'a>
{
    fn arg(&mut self, buf: &'b CudaBuf) -> &mut Self {
        match buf {
            CudaBuf::F32(slice) => self.arg(slice),
            CudaBuf::F16(slice) => self.arg(slice),
            CudaBuf::U32(slice) => self.arg(slice),
            CudaBuf::I32(slice) => self.arg(slice),
            CudaBuf::I8(slice) => self.arg(slice),
        }
    }
}
295
/// Mutable counterpart of the `&CudaBuf` kernel-argument impl. It lets
/// `launch_builder.arg(&mut buf)` forward to the inner CudaSlice's
/// `PushKernelArg<&mut CudaSlice<T>>` impl for the active variant.
#[cfg(feature = "cuda")]
// SAFETY: pure delegation. Each arm passes the inner `&'b mut CudaSlice<T>`
// to cudarc's existing `PushKernelArg` impl, so its guarantees carry over
// unchanged. The `'b: 'a` bound requires the exclusive borrow to outlive
// the `LaunchArgs<'a>` that records it.
unsafe impl<'a, 'b: 'a> cudarc::driver::PushKernelArg<&'b mut CudaBuf>
    for cudarc::driver::LaunchArgs<'a>
{
    fn arg(&mut self, buf: &'b mut CudaBuf) -> &mut Self {
        match buf {
            CudaBuf::F32(slice) => self.arg(slice),
            CudaBuf::F16(slice) => self.arg(slice),
            CudaBuf::U32(slice) => self.arg(slice),
            CudaBuf::I32(slice) => self.arg(slice),
            CudaBuf::I8(slice) => self.arg(slice),
        }
    }
}