kaio_runtime/buffer.rs
1//! Typed device memory buffers.
2
3use cudarc::driver::CudaSlice;
4
5use crate::device::KaioDevice;
6use crate::error::Result;
7
/// A typed buffer in GPU device memory, wrapping cudarc's [`CudaSlice<T>`].
///
/// Created via [`KaioDevice::alloc_from`] or [`KaioDevice::alloc_zeros`].
///
/// # Memory management
///
/// `GpuBuffer` does **not** implement [`Drop`] manually — cudarc's
/// [`CudaSlice`] handles device memory deallocation automatically when
/// the buffer is dropped. The `CudaSlice` holds an `Arc<CudaContext>`
/// internally, ensuring the CUDA context outlives the allocation.
///
/// # Representation — load-bearing
///
/// `#[repr(transparent)]` guarantees this newtype has identical memory
/// layout, size, and alignment to its sole field [`CudaSlice<T>`]. The
/// `kaio-candle` bridge crate relies on this to cast `&CudaSlice<T>`
/// (borrowed from candle's `CudaStorage`) to `&GpuBuffer<T>` for passing
/// into `kaio-ops` kernel entry points without round-tripping through an
/// owned clone.
///
/// **Do not remove `#[repr(transparent)]`, add a second non-zero-sized
/// field, or change the inner type without coordinating with
/// `kaio-candle`.** The soundness-assertion tests at the bottom of this
/// module will fail at compile time if the layout diverges.
#[repr(transparent)]
pub struct GpuBuffer<T> {
    // Sole field — must remain the only field for the transparent-layout
    // contract documented above to hold.
    inner: CudaSlice<T>,
}
35
36impl<T> GpuBuffer<T> {
37 /// Wrap an existing cudarc [`CudaSlice`] as a [`GpuBuffer`].
38 ///
39 /// Takes ownership of the slice. The returned `GpuBuffer` drops the
40 /// underlying device allocation via cudarc's normal `Drop` on its own
41 /// drop.
42 ///
43 /// Used by bridge crates (e.g. `kaio-candle`) to consume a
44 /// fresh-allocated slice back into the KAIO buffer type after a kernel
45 /// produces its output.
46 pub fn from_cuda_slice(inner: CudaSlice<T>) -> Self {
47 Self { inner }
48 }
49
50 /// Consume the buffer and return the underlying cudarc [`CudaSlice`].
51 ///
52 /// Used by bridge crates to hand the owned output slice back to the
53 /// host framework (e.g. wrapping into `candle_core::CudaStorage`) after
54 /// a KAIO kernel has written into the buffer.
55 pub fn into_cuda_slice(self) -> CudaSlice<T> {
56 self.inner
57 }
58
59 /// Number of elements in the buffer.
60 pub fn len(&self) -> usize {
61 self.inner.len()
62 }
63
64 /// Returns `true` if the buffer contains no elements.
65 pub fn is_empty(&self) -> bool {
66 self.inner.len() == 0
67 }
68
69 /// Access the underlying [`CudaSlice`] for passing to cudarc launch
70 /// operations.
71 ///
72 /// This is the escape hatch for Sprint 1.7's launch builder — the
73 /// caller pushes `&buf.inner()` as a kernel argument.
74 pub fn inner(&self) -> &CudaSlice<T> {
75 &self.inner
76 }
77
78 /// Mutable access to the underlying [`CudaSlice`].
79 pub fn inner_mut(&mut self) -> &mut CudaSlice<T> {
80 &mut self.inner
81 }
82}
83
84impl<T: cudarc::driver::DeviceRepr + Default + Clone + Unpin> GpuBuffer<T> {
85 /// Transfer buffer contents from device to host.
86 ///
87 /// Requires a reference to the [`KaioDevice`] that created this buffer
88 /// (for stream access). The device is borrowed, not consumed.
89 ///
90 /// # Example
91 ///
92 /// ```ignore
93 /// let device = KaioDevice::new(0)?;
94 /// let buf = device.alloc_from(&[1.0f32, 2.0, 3.0])?;
95 /// let host_data = buf.to_host(&device)?;
96 /// assert_eq!(host_data, vec![1.0, 2.0, 3.0]);
97 /// ```
98 pub fn to_host(&self, device: &KaioDevice) -> Result<Vec<T>> {
99 Ok(device.stream().clone_dtoh(&self.inner)?)
100 }
101}
102
// Soundness assertions for the `#[repr(transparent)]` contract above.
// Compile-time: any future change to `GpuBuffer`'s layout (adding a field,
// removing `#[repr(transparent)]`, changing the inner type) fails the build
// here instead of producing UB at the `kaio-candle` transmute site.
// Placed at end-of-file to satisfy the `clippy::items_after_test_module`
// lint.
#[cfg(test)]
mod repr_soundness {
    use super::GpuBuffer;
    use cudarc::driver::CudaSlice;
    use half::f16;
    use static_assertions::{assert_eq_align, assert_eq_size};

    // Assert both size and alignment equality between the newtype and the
    // raw slice for every element type the bridge crates cast through.
    macro_rules! assert_transparent_for {
        ($($elem:ty),+ $(,)?) => {
            $(
                assert_eq_size!(GpuBuffer<$elem>, CudaSlice<$elem>);
                assert_eq_align!(GpuBuffer<$elem>, CudaSlice<$elem>);
            )+
        };
    }

    assert_transparent_for!(f32, f16, i8, u32);
}