// sp1-gpu-cudart 6.2.2
//
// High-level Rust API for SP1 GPU operations.
// Documentation
use std::{
    borrow::{Borrow, BorrowMut},
    ffi::c_void,
    mem::MaybeUninit,
    ptr::{self, NonNull},
};

use slop_alloc::{AllocError, Allocator, RawBuffer};
use sp1_gpu_sys::runtime::{cuda_free_host, cuda_malloc_host};

use crate::CudaError;

/// Ready-made [`PinnedAllocator`] value, for call sites that want to pass the allocator by
/// name instead of constructing the unit struct themselves.
pub const PINNED_ALLOCATOR: PinnedAllocator = PinnedAllocator;

/// Stateless allocator handle whose [`Allocator`] implementation obtains memory from
/// `cuda_malloc_host` and releases it with `cuda_free_host`.
///
/// The type carries no data, so every value of it is interchangeable.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct PinnedAllocator;

/// Allocator backed by the `cuda_malloc_host` / `cuda_free_host` runtime bindings
/// (presumably CUDA page-locked host memory, going by the names — confirm against
/// `sp1_gpu_sys`).
unsafe impl Allocator for PinnedAllocator {
    /// Requests `layout.size()` bytes from `cuda_malloc_host`.
    ///
    /// # Errors
    ///
    /// Returns [`AllocError`] when the CUDA call reports a failure, or when it reports
    /// success but leaves the out-pointer null. The latter is checked explicitly so that a
    /// null pointer can never be wrapped in a `NonNull`, which would be undefined behaviour.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of [`Allocator::allocate`].
    ///
    /// NOTE(review): `layout.align()` is not forwarded to the CUDA call; this relies on the
    /// runtime returning memory that is aligned strictly enough for the request — confirm
    /// for over-aligned types.
    unsafe fn allocate(
        &self,
        layout: std::alloc::Layout,
    ) -> Result<std::ptr::NonNull<[u8]>, slop_alloc::AllocError> {
        let mut raw: *mut c_void = ptr::null_mut();
        // SAFETY: `raw` is a live local, so the out-pointer handed to the FFI call is valid
        // for a single pointer-sized write.
        unsafe {
            CudaError::result_from_ffi(cuda_malloc_host(
                &mut raw as *mut *mut c_void,
                layout.size(),
            ))
            .map_err(|_| AllocError)?;
        }
        // A success status alone does not prove the pointer is non-null, so use the checked
        // constructor and surface a null result as an allocation failure.
        let base = NonNull::new(raw.cast::<u8>()).ok_or(AllocError)?;
        Ok(NonNull::slice_from_raw_parts(base, layout.size()))
    }

    /// Returns the block starting at `ptr` to the CUDA runtime via `cuda_free_host`.
    ///
    /// The layout is ignored; only the base address is passed on.
    ///
    /// # Panics
    ///
    /// Panics if `cuda_free_host` reports an error.
    ///
    /// # Safety
    ///
    /// The caller must uphold the contract of [`Allocator::deallocate`]; in particular `ptr`
    /// must have come from [`allocate`](Allocator::allocate) on this allocator and must not
    /// have been freed already.
    unsafe fn deallocate(&self, ptr: std::ptr::NonNull<u8>, _layout: std::alloc::Layout) {
        // SAFETY: per this function's contract, `ptr` is a live block previously returned by
        // `cuda_malloc_host`.
        let status = unsafe { cuda_free_host(ptr.as_ptr() as *mut c_void) };
        CudaError::result_from_ffi(status).unwrap()
    }
}

/// Fixed-capacity buffer of `T` whose storage is a [`RawBuffer`] allocated through
/// [`PinnedAllocator`].
///
/// The contents are exposed as `MaybeUninit<T>`; the buffer itself does not track which
/// slots have been initialized.
pub struct PinnedBuffer<T> {
    // Backing storage; supplies the base pointer and capacity used by every accessor.
    buf: RawBuffer<T, PinnedAllocator>,
}

impl<T> PinnedBuffer<T> {
    /// Builds a buffer with room for `capacity` elements of `T`, allocated through
    /// [`PINNED_ALLOCATOR`].
    ///
    /// Allocation-failure behaviour is whatever `RawBuffer::with_capacity_in` does
    /// (presumably a panic or abort — confirm against `slop_alloc`).
    pub fn with_capacity(capacity: usize) -> Self {
        let buf = RawBuffer::with_capacity_in(capacity, PINNED_ALLOCATOR);
        Self { buf }
    }

    /// Views the whole capacity as a shared slice of possibly-uninitialized elements.
    pub fn as_slice(&self) -> &[MaybeUninit<T>] {
        <Self as Borrow<[MaybeUninit<T>]>>::borrow(self)
    }

    /// Views the whole capacity as an exclusive slice of possibly-uninitialized elements.
    pub fn as_mut_slice(&mut self) -> &mut [MaybeUninit<T>] {
        <Self as BorrowMut<[MaybeUninit<T>]>>::borrow_mut(self)
    }

    /// Number of `T` slots the backing buffer reports.
    #[inline]
    pub fn capacity(&self) -> usize {
        let slots = self.buf.capacity();
        slots
    }

    /// Read-only raw pointer to the first slot.
    #[inline]
    pub fn as_ptr(&self) -> *const T {
        self.buf.ptr() as *const T
    }

    /// Mutable raw pointer to the first slot.
    #[inline]
    pub fn as_mut_ptr(&mut self) -> *mut T {
        let base = self.buf.ptr();
        base
    }
}

/// Lets a [`PinnedBuffer`] be borrowed as a slice spanning its full capacity.
impl<T> Borrow<[MaybeUninit<T>]> for PinnedBuffer<T> {
    fn borrow(&self) -> &[MaybeUninit<T>] {
        let len = self.buf.capacity();
        let data = self.buf.ptr() as *const MaybeUninit<T>;
        // SAFETY: relies on `RawBuffer` keeping `ptr()` valid for `capacity()` elements for
        // as long as `self` is borrowed (TODO confirm against `slop_alloc`). Viewing the
        // slots as `MaybeUninit<T>` means no element has to be initialized.
        unsafe { std::slice::from_raw_parts(data, len) }
    }
}

/// Lets a [`PinnedBuffer`] be mutably borrowed as a slice spanning its full capacity.
impl<T> BorrowMut<[MaybeUninit<T>]> for PinnedBuffer<T> {
    fn borrow_mut(&mut self) -> &mut [MaybeUninit<T>] {
        let len = self.buf.capacity();
        let data = self.buf.ptr() as *mut MaybeUninit<T>;
        // SAFETY: relies on `RawBuffer` keeping `ptr()` valid for `capacity()` elements
        // (TODO confirm against `slop_alloc`). The `&mut self` receiver makes this the only
        // live view of the storage, and `MaybeUninit<T>` slots need not be initialized.
        unsafe { std::slice::from_raw_parts_mut(data, len) }
    }
}

// Manually opts `PinnedBuffer<T>` into cross-thread transfer and sharing.
// NOTE(review): both impls are unconditional in `T` (no `T: Send` / `T: Sync` bound).
// Presumably this is considered acceptable because the safe API only hands out
// `MaybeUninit<T>` slots and raw pointers, so reading a `T` back always goes through
// `unsafe` — confirm this reasoning holds for every intended element type.
unsafe impl<T> Send for PinnedBuffer<T> {}
unsafe impl<T> Sync for PinnedBuffer<T> {}

#[cfg(test)]
mod tests {
    use super::*;

    /// Writes one value into a freshly allocated buffer and reads it back, also checking
    /// that the reported capacity matches the request.
    #[test]
    fn test_pinned_buffer() {
        const CAPACITY: usize = 10;
        let mut pinned = PinnedBuffer::<u32>::with_capacity(CAPACITY);

        let slots = pinned.as_mut_slice();
        slots[0].write(1);

        assert_eq!(pinned.capacity(), CAPACITY);

        // SAFETY: slot 0 was initialized by the `write` call above.
        let first = unsafe { pinned.as_slice()[0].assume_init() };
        assert_eq!(first, 1);
    }
}