rifft 1.1.2

RIFFT FFT/DLPack/FFI bridge
Documentation
use crate::types::{Result, RifftError};
use libc::c_void;
use num_complex::Complex32;
use std::os::raw::{c_int, c_ulonglong};

#[repr(C)]
#[derive(Clone, Copy)]
pub struct DLDevice {
    pub device_type: c_int,
    pub device_id: c_int,
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct DLDataType {
    pub code: u8,
    pub bits: u8,
    pub lanes: u16,
}

#[repr(C)]
#[derive(Clone, Copy)]
pub struct DLTensor {
    pub data: *mut c_void,
    pub device: DLDevice,
    pub ndim: c_int,
    pub dtype: DLDataType,
    pub shape: *mut i64,
    pub strides: *mut i64,
    pub byte_offset: c_ulonglong,
}

pub type DlDeleter = Option<unsafe extern "C" fn(*mut DLManagedTensor)>;

#[repr(C)]
pub struct DLManagedTensor {
    pub dl_tensor: DLTensor,
    pub manager_ctx: *mut c_void,
    pub deleter: DlDeleter,
}

unsafe impl Send for DLManagedTensor {}
unsafe impl Sync for DLManagedTensor {}

pub struct DlManagedTensorHandle {
    ptr: *mut DLManagedTensor,
    len: usize,
    height: usize,
    width: usize,
}

unsafe impl Send for DlManagedTensorHandle {}
unsafe impl Sync for DlManagedTensorHandle {}

impl DlManagedTensorHandle {
    /// # Safety
    /// `ptr` must be a valid `DLManagedTensor` allocated by a DLPack producer.
    pub unsafe fn from_raw(ptr: *mut DLManagedTensor) -> Result<Self> {
        if ptr.is_null() {
            return Err(RifftError::DlPack("received null pointer".into()));
        }
        let tensor = &*ptr;
        let dl = &tensor.dl_tensor;
        if dl.ndim < 2 {
            return Err(RifftError::DlPack("expected at least 2 dimensions".into()));
        }
        let dtype = dl.dtype;
        let is_complex32 = dtype.code == 5 && dtype.bits == 64;
        if !is_complex32 {
            return Err(RifftError::UnsupportedDType);
        }
        let ndim = dl.ndim as isize;
        let shape_slice = std::slice::from_raw_parts(dl.shape, ndim as usize);
        let height = shape_slice[shape_slice.len() - 2] as usize;
        let width = shape_slice[shape_slice.len() - 1] as usize;
        let len = shape_slice
            .iter()
            .fold(1usize, |acc, dim| acc * (*dim as usize));
        if !dl.strides.is_null() {
            let strides = std::slice::from_raw_parts(dl.strides, ndim as usize);
            if strides[strides.len() - 1] != 1 {
                return Err(RifftError::NonContiguous);
            }
            if strides[strides.len() - 2] as usize != width {
                return Err(RifftError::NonContiguous);
            }
        }
        Ok(Self {
            ptr,
            len,
            height,
            width,
        })
    }

    /// # Safety
    /// Caller must ensure the underlying tensor memory is unique for mutation.
    /// # Safety
    /// Caller must ensure the underlying tensor memory is uniquely owned.
    pub unsafe fn as_mut_slice(&mut self) -> &mut [Complex32] {
        let tensor = &*self.ptr;
        let dl = &tensor.dl_tensor;
        let data = (dl.data as *mut Complex32).add(dl.byte_offset as usize / 8);
        std::slice::from_raw_parts_mut(data, self.len)
    }

    pub fn len(&self) -> usize {
        self.len
    }

    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    pub fn height(&self) -> usize {
        self.height
    }

    pub fn width(&self) -> usize {
        self.width
    }

    pub fn into_raw(self) -> *mut DLManagedTensor {
        let ptr = self.ptr;
        std::mem::forget(self);
        ptr
    }
}

impl Drop for DlManagedTensorHandle {
    fn drop(&mut self) {
        unsafe {
            if let Some(deleter) = (*self.ptr).deleter {
                deleter(self.ptr);
            }
        }
    }
}

/// # Safety
/// `ptr` must be a valid DLPack capsule pointer obtained from Python.
pub unsafe fn from_dlpack_capsule(ptr: *mut c_void) -> Result<DlManagedTensorHandle> {
    DlManagedTensorHandle::from_raw(ptr as *mut DLManagedTensor)
}

/// # Safety
/// The caller must guarantee the slice stays alive for the lifetime of the capsule.
pub unsafe fn to_dlpack_capsule(
    slice: &mut [Complex32],
    height: usize,
    width: usize,
) -> *mut DLManagedTensor {
    extern "C" fn release(ptr: *mut DLManagedTensor) {
        unsafe {
            if !(*ptr).manager_ctx.is_null() {
                let _ = Box::from_raw((*ptr).manager_ctx as *mut [i64; 2]);
            }
            let _ = Box::from_raw(ptr);
        }
    }

    let shape_box = Box::new([height as i64, width as i64]);
    let shape_ptr = Box::into_raw(shape_box);
    let tensor = Box::new(DLManagedTensor {
        dl_tensor: DLTensor {
            data: slice.as_mut_ptr() as *mut c_void,
            device: DLDevice {
                device_type: 1,
                device_id: 0,
            },
            ndim: 2,
            dtype: DLDataType {
                code: 5,
                bits: 64,
                lanes: 1,
            },
            shape: shape_ptr as *mut i64,
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        },
        manager_ctx: shape_ptr as *mut c_void,
        deleter: Some(release),
    });
    Box::into_raw(tensor)
}