//! cndrv 0.1.2
//!
//! Safe Cambricon driver API.
//!
//! Documentation
use crate::{bindings::CNaddr, impl_spore, AsRaw, Blob, CurrentCtx, Queue};
use std::{
    alloc::Layout,
    ffi::c_void,
    marker::PhantomData,
    mem::size_of_val,
    ops::{Deref, DerefMut},
    ptr::null_mut,
    slice::{from_raw_parts, from_raw_parts_mut},
};

/// Placeholder element type for slices that describe device memory.
///
/// It is `#[repr(transparent)]` over a single `u8`, so a `[DevByte]` slice has exactly one
/// byte of size per element. Its length therefore equals its size in bytes. The inner byte
/// is never read on the host; the field only exists to give the type that layout.
#[repr(transparent)]
pub struct DevByte(#[allow(dead_code)] u8);

/// Copies the device bytes in `src` into the host slice `dst` via `cnMemcpyDtoH`.
///
/// # Panics
///
/// Panics if `dst` (measured in bytes) and `src` have different sizes. Driver failures
/// are reported by `cndrv!` (presumably by panicking).
#[inline]
pub fn memcpy_d2h<T: Copy>(dst: &mut [T], src: &[DevByte]) {
    let bytes = size_of_val(dst);
    assert_eq!(bytes, size_of_val(src));
    let host = dst.as_mut_ptr().cast();
    cndrv!(cnMemcpyDtoH(host, src.as_ptr() as _, bytes as _));
}

/// Copies the host values in `src` into the device bytes `dst` via `cnMemcpyHtoD`.
///
/// # Panics
///
/// Panics if `src` (measured in bytes) and `dst` have different sizes. Driver failures
/// are reported by `cndrv!` (presumably by panicking).
#[inline]
pub fn memcpy_h2d<T: Copy>(dst: &mut [DevByte], src: &[T]) {
    let bytes = size_of_val(src);
    assert_eq!(bytes, size_of_val(dst));
    let host = src.as_ptr().cast();
    cndrv!(cnMemcpyHtoD(dst.as_ptr() as _, host, bytes as _));
}

/// Copies device bytes from `src` to `dst` via `cnMemcpyDtoD`.
///
/// # Panics
///
/// Panics if the two slices have different sizes. Driver failures are reported by
/// `cndrv!` (presumably by panicking).
#[inline]
pub fn memcpy_d2d(dst: &mut [DevByte], src: &[DevByte]) {
    let bytes = size_of_val(src);
    assert_eq!(bytes, size_of_val(dst));
    let (to, from) = (dst.as_ptr(), src.as_ptr());
    cndrv!(cnMemcpyDtoD(to as _, from as _, bytes as _));
}

impl Queue<'_> {
    /// Enqueues a host-to-device copy of `src` into `dst` on this queue
    /// (`cnMemcpyHtoDAsync_V2`).
    ///
    /// NOTE(review): `src` is only borrowed for the duration of this call. If the driver
    /// reads it after returning, the caller must keep it alive and unmodified until the
    /// queue has been synchronized. TODO: confirm against the CNDrv documentation.
    ///
    /// # Panics
    ///
    /// Panics if `src` (measured in bytes) and `dst` have different sizes.
    #[inline]
    pub fn memcpy_h2d<T: Copy>(&self, dst: &mut [DevByte], src: &[T]) {
        let bytes = size_of_val(src);
        assert_eq!(bytes, size_of_val(dst));
        let host = src.as_ptr().cast();
        cndrv!(cnMemcpyHtoDAsync_V2(
            dst.as_ptr() as _,
            host,
            bytes as _,
            self.as_raw()
        ));
    }

    /// Enqueues a device-to-device copy from `src` to `dst` on this queue
    /// (`cnMemcpyDtoDAsync`).
    ///
    /// # Panics
    ///
    /// Panics if the two slices have different sizes.
    #[inline]
    pub fn memcpy_d2d(&self, dst: &mut [DevByte], src: &[DevByte]) {
        let bytes = size_of_val(src);
        assert_eq!(bytes, size_of_val(dst));
        let (to, from) = (dst.as_ptr(), src.as_ptr());
        cndrv!(cnMemcpyDtoDAsync(
            to as _,
            from as _,
            bytes as _,
            self.as_raw()
        ));
    }
}

// Declares the `DevMem` / `DevMemSpore` pair over a `Blob<CNaddr>`, which holds the
// device address (`ptr`) and the allocation size in bytes (`len`). `DevMem` is built
// below as `DevMem(ctx.wrap_raw(blob), PhantomData)`. For the exact items generated,
// see `impl_spore!`. Presumably the spore is the context-detached form; TODO confirm.
impl_spore!(DevMem and DevMemSpore by Blob<CNaddr>);

impl CurrentCtx {
    /// Allocates device memory large enough for `len` values of `T` via `cnMalloc`.
    ///
    /// The memory is not initialized by this function.
    ///
    /// # Panics
    ///
    /// Panics if the byte size of `len` values of `T` overflows. Driver failures are
    /// reported by `cndrv!` (presumably by panicking).
    pub fn malloc<T: Copy>(&self, len: usize) -> DevMem<'_> {
        let len = Layout::array::<T>(len)
            .expect("requested device allocation size overflows")
            .size();
        let mut ptr = 0;
        cndrv!(cnMalloc(&mut ptr, len as _));
        DevMem(unsafe { self.wrap_raw(Blob { ptr, len }) }, PhantomData)
    }

    /// Allocates device memory the size of `slice` and copies `slice` into it.
    ///
    /// # Panics
    ///
    /// Driver failures in either the allocation or the copy are reported by `cndrv!`
    /// (presumably by panicking).
    pub fn from_host<T: Copy>(&self, slice: &[T]) -> DevMem<'_> {
        let len = size_of_val(slice);
        let src = slice.as_ptr().cast();
        let mut ptr = 0;
        cndrv!(cnMalloc(&mut ptr, len as _));
        // Take ownership of the allocation *before* copying. If the copy fails and
        // `cndrv!` unwinds, `DevMem`'s `Drop` frees the device memory instead of leaking it.
        let mem = DevMem(unsafe { self.wrap_raw(Blob { ptr, len }) }, PhantomData);
        cndrv!(cnMemcpyHtoD(ptr, src, len as _));
        mem
    }
}

impl Drop for DevMem<'_> {
    /// Hands the device allocation back to the driver with `cnFree`.
    #[inline]
    fn drop(&mut self) {
        let addr = self.0.raw.ptr;
        cndrv!(cnFree(addr));
    }
}

impl Deref for DevMem<'_> {
    type Target = [DevByte];

    /// Views the allocation as a slice of device bytes.
    ///
    /// A zero-length allocation yields `&[]` without touching its address.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let raw = &self.0.raw;
        match raw.len {
            0 => &[],
            // SAFETY: the address and byte length come from the owned allocation.
            // NOTE(review): the address is a device address, so the slice must never be
            // read on the host; it only carries address + length to the copy APIs.
            len => unsafe { from_raw_parts(raw.ptr as _, len) },
        }
    }
}

impl DerefMut for DevMem<'_> {
    /// Mutable view of the allocation as device bytes.
    ///
    /// A zero-length allocation yields `&mut []`.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let raw = &self.0.raw;
        match raw.len {
            0 => &mut [],
            // SAFETY: same address/length as `deref`; `&mut self` guarantees exclusivity.
            len => unsafe { from_raw_parts_mut(raw.ptr as _, len) },
        }
    }
}

impl AsRaw for DevMemSpore {
    type Raw = CNaddr;

    /// Returns the device address stored in this spore's blob.
    #[inline]
    unsafe fn as_raw(&self) -> Self::Raw {
        let raw = &self.0.raw;
        raw.ptr
    }
}

impl DevMemSpore {
    /// Size of the referenced device allocation, in bytes.
    #[inline]
    pub const fn len(&self) -> usize {
        let raw = &self.0.raw;
        raw.len
    }

    /// `true` when the referenced device allocation has zero bytes.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

// Declares the `HostMem` / `HostMemSpore` pair over a `Blob<*mut c_void>`, which holds the
// host pointer returned by `cnMallocHost` (`ptr`) and the allocation size in bytes (`len`).
// For the exact items generated, see `impl_spore!`. Presumably the spore is the
// context-detached form; TODO confirm.
impl_spore!(HostMem and HostMemSpore by Blob<*mut c_void>);

impl CurrentCtx {
    /// Allocates host memory via `cnMallocHost`, sized for `len` values of `T`.
    ///
    /// The returned [`HostMem`] dereferences to a byte slice of that size. The contents
    /// are not initialized by this function.
    ///
    /// # Panics
    ///
    /// Panics if the byte size of `len` values of `T` overflows. Driver failures are
    /// reported by `cndrv!` (presumably by panicking).
    // The lifetime is spelled out (`HostMem<'_>`, same type as before) so the borrow of
    // `self` is visible, matching `DevMem<'_>` returned by `malloc`.
    pub fn malloc_host<T: Copy>(&self, len: usize) -> HostMem<'_> {
        let len = Layout::array::<T>(len)
            .expect("requested host allocation size overflows")
            .size();
        let mut ptr = null_mut();
        cndrv!(cnMallocHost(&mut ptr, len as _));
        HostMem(unsafe { self.wrap_raw(Blob { ptr, len }) }, PhantomData)
    }
}

impl Drop for HostMem<'_> {
    /// Hands the host allocation back to the driver with `cnFreeHost`.
    #[inline]
    fn drop(&mut self) {
        let host_ptr = self.0.raw.ptr;
        cndrv!(cnFreeHost(host_ptr));
    }
}

impl AsRaw for HostMem<'_> {
    type Raw = *mut c_void;

    /// Returns the raw host pointer obtained from `cnMallocHost`.
    #[inline]
    unsafe fn as_raw(&self) -> Self::Raw {
        let raw = &self.0.raw;
        raw.ptr
    }
}

impl Deref for HostMem<'_> {
    type Target = [u8];

    /// Views the host allocation as a byte slice.
    ///
    /// A zero-length allocation yields `&[]` without touching its pointer.
    #[inline]
    fn deref(&self) -> &Self::Target {
        if self.0.raw.len == 0 {
            // `from_raw_parts` requires a non-null pointer even for an empty slice, and a
            // zero-size host allocation may come back as null. This mirrors `DevMem`'s guard.
            &[]
        } else {
            // SAFETY: `ptr` points to `len` bytes of host memory owned by `self`.
            unsafe { from_raw_parts(self.0.raw.ptr.cast(), self.0.raw.len) }
        }
    }
}

impl DerefMut for HostMem<'_> {
    /// Mutable byte view of the host allocation.
    ///
    /// A zero-length allocation yields `&mut []` without touching its pointer.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        if self.0.raw.len == 0 {
            // `from_raw_parts_mut` requires a non-null pointer even for an empty slice, and
            // a zero-size host allocation may come back as null.
            &mut []
        } else {
            // SAFETY: `ptr` points to `len` bytes owned by `self`; `&mut self` is exclusive.
            unsafe { from_raw_parts_mut(self.0.raw.ptr.cast(), self.0.raw.len) }
        }
    }
}

impl Deref for HostMemSpore {
    type Target = [u8];

    /// Views the spore's host allocation as a byte slice.
    ///
    /// A zero-length allocation yields `&[]` without touching its pointer.
    #[inline]
    fn deref(&self) -> &Self::Target {
        if self.0.raw.len == 0 {
            // `from_raw_parts` requires a non-null pointer even for an empty slice, and a
            // zero-size host allocation may come back as null.
            &[]
        } else {
            // SAFETY: `ptr` points to `len` bytes of host memory held by this spore.
            unsafe { from_raw_parts(self.0.raw.ptr.cast(), self.0.raw.len) }
        }
    }
}

impl DerefMut for HostMemSpore {
    /// Mutable byte view of the spore's host allocation.
    ///
    /// A zero-length allocation yields `&mut []` without touching its pointer.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        if self.0.raw.len == 0 {
            // `from_raw_parts_mut` requires a non-null pointer even for an empty slice, and
            // a zero-size host allocation may come back as null.
            &mut []
        } else {
            // SAFETY: `ptr` points to `len` bytes held by this spore; `&mut self` is exclusive.
            unsafe { from_raw_parts_mut(self.0.raw.ptr.cast(), self.0.raw.len) }
        }
    }
}

/// Probes how the driver handles `cnFreeHost`.
///
/// 1. Inside `dev.context().apply(..)`: allocate 128 bytes with `cnMallocHost` and free them.
/// 2. After `apply` has returned: reset the pointer to null and call `cnFreeHost` on it.
///
/// The test is skipped when `Device::fetch()` returns `None`. Assuming `cndrv!` panics on a
/// non-success status (TODO confirm), a pass shows that freeing a null host pointer at
/// that point is accepted.
#[test]
fn test_behavior() {
    crate::init();
    // Presumably no device is available: nothing to probe.
    let Some(dev) = crate::Device::fetch() else {
        return;
    };
    let mut ptr = null_mut();
    dev.context().apply(|_| {
        cndrv!(cnMallocHost(&mut ptr, 128));
        cndrv!(cnFreeHost(ptr));
    });
    // Free a null pointer after the context closure has finished.
    ptr = null_mut();
    cndrv!(cnFreeHost(ptr));
}