use crate::data::TensorData;
use crate::error::{Error, Result};
use crate::ffi;
use crate::graph::Tensor;
use apple_metal::MetalDevice;
use core::ffi::c_void;
use core::ptr;
fn release_handle(ptr: &mut *mut c_void) {
if !ptr.is_null() {
unsafe { ffi::mpsgraph_object_release(*ptr) };
*ptr = ptr::null_mut();
}
}
fn copy_optional_signed_shape(
handle: *mut c_void,
has_shape: unsafe extern "C" fn(*mut c_void) -> bool,
shape_len: unsafe extern "C" fn(*mut c_void) -> usize,
copy_shape: unsafe extern "C" fn(*mut c_void, *mut isize),
) -> Option<Vec<isize>> {
if unsafe { !has_shape(handle) } {
return None;
}
let len = unsafe { shape_len(handle) };
let mut shape = vec![0_isize; len];
if len > 0 {
unsafe { copy_shape(handle, shape.as_mut_ptr()) };
}
Some(shape)
}
fn collect_tensor_array_box(handle: *mut c_void) -> Vec<Tensor> {
if handle.is_null() {
return Vec::new();
}
let len = unsafe { ffi::mpsgraph_tensor_array_box_len(handle) };
let mut tensors = Vec::with_capacity(len);
for index in 0..len {
let tensor = unsafe { ffi::mpsgraph_tensor_array_box_get(handle, index) };
if !tensor.is_null() {
tensors.push(Tensor::from_raw(tensor));
}
}
let mut box_handle = handle;
release_handle(&mut box_handle);
tensors
}
pub(crate) fn collect_tensor_data_array_box(handle: *mut c_void) -> Vec<TensorData> {
if handle.is_null() {
return Vec::new();
}
let len = unsafe { ffi::mpsgraph_tensor_data_array_box_len(handle) };
let mut values = Vec::with_capacity(len);
for index in 0..len {
let value = unsafe { ffi::mpsgraph_tensor_data_array_box_get(handle, index) };
if !value.is_null() {
values.push(TensorData::from_raw(value));
}
}
let mut box_handle = handle;
release_handle(&mut box_handle);
values
}
pub(crate) fn collect_shaped_type_array_box(handle: *mut c_void) -> Vec<ShapedType> {
if handle.is_null() {
return Vec::new();
}
let len = unsafe { ffi::mpsgraph_shaped_type_array_box_len(handle) };
let mut values = Vec::with_capacity(len);
for index in 0..len {
let value = unsafe { ffi::mpsgraph_shaped_type_array_box_get(handle, index) };
if !value.is_null() {
values.push(ShapedType { ptr: value });
}
}
let mut box_handle = handle;
release_handle(&mut box_handle);
values
}
pub mod graph_device_type {
pub const METAL: u32 = 0;
}
pub struct GraphDevice {
ptr: *mut c_void,
}
unsafe impl Send for GraphDevice {}
unsafe impl Sync for GraphDevice {}
impl Drop for GraphDevice {
fn drop(&mut self) {
release_handle(&mut self.ptr);
}
}
impl GraphDevice {
#[must_use]
pub fn from_metal_device(device: &MetalDevice) -> Option<Self> {
let ptr = unsafe { ffi::mpsgraph_device_new_with_metal_device(device.as_ptr()) };
if ptr.is_null() {
None
} else {
Some(Self { ptr })
}
}
#[must_use]
pub const fn as_ptr(&self) -> *mut c_void {
self.ptr
}
#[must_use]
pub fn device_type(&self) -> u32 {
unsafe { ffi::mpsgraph_device_type(self.ptr) }
}
}
pub struct ShapedType {
ptr: *mut c_void,
}
unsafe impl Send for ShapedType {}
unsafe impl Sync for ShapedType {}
impl Drop for ShapedType {
fn drop(&mut self) {
release_handle(&mut self.ptr);
}
}
impl ShapedType {
#[must_use]
pub fn new(shape: Option<&[isize]>, data_type: u32) -> Option<Self> {
let (shape_ptr, shape_len) = shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
let ptr = unsafe { ffi::mpsgraph_shaped_type_new(shape_ptr, shape_len, data_type) };
if ptr.is_null() {
None
} else {
Some(Self { ptr })
}
}
#[must_use]
pub(crate) const fn as_ptr(&self) -> *mut c_void {
self.ptr
}
#[must_use]
pub fn shape(&self) -> Option<Vec<isize>> {
copy_optional_signed_shape(
self.ptr,
ffi::mpsgraph_shaped_type_has_shape,
ffi::mpsgraph_shaped_type_shape_len,
ffi::mpsgraph_shaped_type_copy_shape,
)
}
#[must_use]
pub fn data_type(&self) -> u32 {
unsafe { ffi::mpsgraph_shaped_type_data_type(self.ptr) }
}
pub fn set_shape(&self, shape: Option<&[isize]>) -> Result<()> {
let (shape_ptr, shape_len) = shape.map_or((ptr::null(), 0), |shape| (shape.as_ptr(), shape.len()));
let ok = unsafe { ffi::mpsgraph_shaped_type_set_shape(self.ptr, shape_ptr, shape_len) };
if ok {
Ok(())
} else {
Err(Error::OperationFailed("failed to set shaped type shape"))
}
}
pub fn set_data_type(&self, data_type: u32) -> Result<()> {
let ok = unsafe { ffi::mpsgraph_shaped_type_set_data_type(self.ptr, data_type) };
if ok {
Ok(())
} else {
Err(Error::OperationFailed("failed to set shaped type data type"))
}
}
#[must_use]
pub fn is_equal(&self, other: Option<&Self>) -> bool {
let other_ptr = other.map_or(ptr::null_mut(), Self::as_ptr);
unsafe { ffi::mpsgraph_shaped_type_is_equal(self.ptr, other_ptr) }
}
}
pub struct Operation {
ptr: *mut c_void,
}
unsafe impl Send for Operation {}
unsafe impl Sync for Operation {}
impl Drop for Operation {
fn drop(&mut self) {
release_handle(&mut self.ptr);
}
}
impl Operation {
#[must_use]
pub const fn as_ptr(&self) -> *mut c_void {
self.ptr
}
}
impl Tensor {
#[must_use]
pub fn shape(&self) -> Option<Vec<isize>> {
copy_optional_signed_shape(
self.as_ptr(),
ffi::mpsgraph_tensor_has_shape,
ffi::mpsgraph_tensor_shape_len,
ffi::mpsgraph_tensor_copy_shape,
)
}
#[must_use]
pub fn data_type(&self) -> u32 {
unsafe { ffi::mpsgraph_tensor_data_type(self.as_ptr()) }
}
#[must_use]
pub fn operation(&self) -> Option<Operation> {
let ptr = unsafe { ffi::mpsgraph_tensor_operation(self.as_ptr()) };
if ptr.is_null() {
None
} else {
Some(Operation { ptr })
}
}
}
impl TensorData {
#[must_use]
pub fn graph_device_type(&self) -> Option<u32> {
let ptr = unsafe { ffi::mpsgraph_tensor_data_device(self.as_ptr()) };
if ptr.is_null() {
return None;
}
let device = GraphDevice { ptr };
Some(device.device_type())
}
}
pub(crate) fn collect_owned_tensors(handle: *mut c_void) -> Vec<Tensor> {
collect_tensor_array_box(handle)
}