Skip to main content

kaio_runtime/
buffer.rs

1//! Typed device memory buffers.
2
3use cudarc::driver::CudaSlice;
4
5use crate::device::KaioDevice;
6use crate::error::Result;
7
8/// A typed buffer in GPU device memory, wrapping cudarc's [`CudaSlice<T>`].
9///
10/// Created via [`KaioDevice::alloc_from`] or [`KaioDevice::alloc_zeros`].
11///
12/// # Memory management
13///
14/// `GpuBuffer` does **not** implement [`Drop`] manually — cudarc's
15/// [`CudaSlice`] handles device memory deallocation automatically when
16/// the buffer is dropped. The `CudaSlice` holds an `Arc<CudaContext>`
17/// internally, ensuring the CUDA context outlives the allocation.
18pub struct GpuBuffer<T> {
19    inner: CudaSlice<T>,
20}
21
22impl<T> GpuBuffer<T> {
23    /// Create a `GpuBuffer` from an existing `CudaSlice`.
24    pub(crate) fn from_raw(inner: CudaSlice<T>) -> Self {
25        Self { inner }
26    }
27
28    /// Number of elements in the buffer.
29    pub fn len(&self) -> usize {
30        self.inner.len()
31    }
32
33    /// Returns `true` if the buffer contains no elements.
34    pub fn is_empty(&self) -> bool {
35        self.inner.len() == 0
36    }
37
38    /// Access the underlying [`CudaSlice`] for passing to cudarc launch
39    /// operations.
40    ///
41    /// This is the escape hatch for Sprint 1.7's launch builder — the
42    /// caller pushes `&buf.inner()` as a kernel argument.
43    pub fn inner(&self) -> &CudaSlice<T> {
44        &self.inner
45    }
46
47    /// Mutable access to the underlying [`CudaSlice`].
48    pub fn inner_mut(&mut self) -> &mut CudaSlice<T> {
49        &mut self.inner
50    }
51}
52
53impl<T: cudarc::driver::DeviceRepr + Default + Clone + Unpin> GpuBuffer<T> {
54    /// Transfer buffer contents from device to host.
55    ///
56    /// Requires a reference to the [`KaioDevice`] that created this buffer
57    /// (for stream access). The device is borrowed, not consumed.
58    ///
59    /// # Example
60    ///
61    /// ```ignore
62    /// let device = KaioDevice::new(0)?;
63    /// let buf = device.alloc_from(&[1.0f32, 2.0, 3.0])?;
64    /// let host_data = buf.to_host(&device)?;
65    /// assert_eq!(host_data, vec![1.0, 2.0, 3.0]);
66    /// ```
67    pub fn to_host(&self, device: &KaioDevice) -> Result<Vec<T>> {
68        Ok(device.stream().clone_dtoh(&self.inner)?)
69    }
70}