#![allow(non_upper_case_globals, non_snake_case, unused)]
use std::ffi::{c_char, c_void, CString};
use std::ptr;
use std::sync::OnceLock;
/// Untyped Objective-C object pointer (`id`).
pub type Id = *mut c_void;
/// Untyped Objective-C class pointer (`Class`).
pub type Class = *mut c_void;
/// Untyped Objective-C selector (`SEL`).
pub type Sel = *mut c_void;

/// Foundation `NSInteger`, mapped to the pointer-sized signed integer.
pub type NSInteger = isize;
/// Foundation `NSUInteger`, mapped to the pointer-sized unsigned integer.
pub type NSUInteger = usize;

/// The Objective-C `nil` object pointer.
pub const NIL: Id = std::ptr::null_mut::<c_void>();
// Objective-C runtime functions (libobjc).
extern "C" {
    /// Looks up a registered class by NUL-terminated name; null if not found.
    pub fn objc_getClass(name: *const c_char) -> Class;
    /// Returns the selector registered for a NUL-terminated name.
    pub fn sel_registerName(name: *const c_char) -> Sel;
    /// Message-send trampoline. Declared without a signature on purpose: every
    /// call site transmutes it to the exact `extern "C" fn` type of the message
    /// being sent and calls that instead.
    pub fn objc_msgSend();
    /// Allocates an uninitialized instance of `cls` (the `alloc` half of alloc/init).
    pub fn objc_alloc(cls: Class) -> Id;
    /// Retains `obj` and returns it.
    pub fn objc_retain(obj: Id) -> Id;
    /// Releases `obj`.
    pub fn objc_release(obj: Id);
    /// Pushes an autorelease pool; the returned token is passed to the matching pop.
    pub fn objc_autoreleasePoolPush() -> *mut c_void;
    /// Pops the autorelease pool identified by `ctx`.
    pub fn objc_autoreleasePoolPop(ctx: *mut c_void);
}
// Metal framework entry point.
// NOTE(review): there is no `#[link]` attribute here, so the Metal framework is
// presumably linked by the build configuration — confirm.
extern "C" {
/// Returns the system default Metal device, or null when there is none
/// (`metal_device` treats null as "No Metal device found").
pub fn MTLCreateSystemDefaultDevice() -> Id;
}
/// Looks up an Objective-C class by name via `objc_getClass`.
///
/// # Panics
/// Panics if `name` contains an interior NUL byte, or if no class with that
/// name is registered with the runtime.
#[inline]
pub fn class(name: &str) -> Class {
    let c_name = CString::new(name).unwrap();
    let found = unsafe { objc_getClass(c_name.as_ptr()) };
    if found.is_null() {
        panic!("ObjC class '{}' not found", name);
    }
    found
}
/// Returns the runtime selector for `name` via `sel_registerName`.
///
/// # Panics
/// Panics if `name` contains an interior NUL byte.
#[inline]
pub fn sel(name: &str) -> Sel {
    CString::new(name)
        .map(|c_name| unsafe { sel_registerName(c_name.as_ptr()) })
        .unwrap()
}
/// Retains `obj` with `objc_retain`; a null pointer is returned unchanged.
///
/// # Safety
/// `obj` must be null or a valid Objective-C object pointer.
#[inline]
pub unsafe fn retain(obj: Id) -> Id {
    if obj.is_null() {
        return obj;
    }
    objc_retain(obj)
}
/// Releases `obj` with `objc_release`; a null pointer is ignored.
///
/// # Safety
/// `obj` must be null or a valid Objective-C object pointer that the caller owns.
#[inline]
pub unsafe fn release(obj: Id) {
    if obj.is_null() {
        return;
    }
    objc_release(obj);
}
/// Sends an Objective-C message by calling `objc_msgSend` through a
/// transmuted, fully typed `extern "C"` function pointer.
///
/// Forms:
/// - `msg_send!(RET; receiver, "selector")`
/// - `msg_send!(RET; receiver, "selector:", (ArgTy, ...), arg, ...)`. This form
///   is only available for `RET` = `Id`, `usize` and `void`. There is no trailing
///   comma after the last argument.
///
/// `RET` is one of `Id`, `usize`, `isize`, `bool`, `void` (no return value),
/// `f64`, `u32`. The selector expression is passed to `$crate::ffi::sel`.
///
/// The expansion calls `core::mem::transmute` and an `unsafe extern "C" fn`
/// without its own `unsafe` block, so it must be expanded in an unsafe context.
/// It names `Id` and `Sel` unqualified, so both must be in scope at the call
/// site. It reaches `objc_msgSend` and `sel` via `$crate::ffi::`, so it assumes
/// this file is the crate's `ffi` module.
///
/// The caller must make the declared argument and return types match the real
/// method signature. Calling through a mismatched transmuted signature is
/// undefined behavior.
macro_rules! msg_send {
// Object-pointer (`Id`) result, without and with arguments.
(Id; $obj:expr, $sel:expr) => {{
let f: unsafe extern "C" fn(Id, Sel) -> Id =
core::mem::transmute($crate::ffi::objc_msgSend as *const ());
f($obj, $crate::ffi::sel($sel))
}};
(Id; $obj:expr, $sel:expr, ($($at:ty),+), $($av:expr),+) => {{
let f: unsafe extern "C" fn(Id, Sel, $($at),+) -> Id =
core::mem::transmute($crate::ffi::objc_msgSend as *const ());
f($obj, $crate::ffi::sel($sel), $($av),+)
}};
// Unsigned (`NSUInteger`-sized) result, without and with arguments.
(usize; $obj:expr, $sel:expr) => {{
let f: unsafe extern "C" fn(Id, Sel) -> usize =
core::mem::transmute($crate::ffi::objc_msgSend as *const ());
f($obj, $crate::ffi::sel($sel))
}};
(usize; $obj:expr, $sel:expr, ($($at:ty),+), $($av:expr),+) => {{
let f: unsafe extern "C" fn(Id, Sel, $($at),+) -> usize =
core::mem::transmute($crate::ffi::objc_msgSend as *const ());
f($obj, $crate::ffi::sel($sel), $($av),+)
}};
// Signed (`NSInteger`-sized) result, no arguments.
(isize; $obj:expr, $sel:expr) => {{
let f: unsafe extern "C" fn(Id, Sel) -> isize =
core::mem::transmute($crate::ffi::objc_msgSend as *const ());
f($obj, $crate::ffi::sel($sel))
}};
// Boolean result, no arguments.
(bool; $obj:expr, $sel:expr) => {{
let f: unsafe extern "C" fn(Id, Sel) -> bool =
core::mem::transmute($crate::ffi::objc_msgSend as *const ());
f($obj, $crate::ffi::sel($sel))
}};
// No result (`void`), without and with arguments.
(void; $obj:expr, $sel:expr) => {{
let f: unsafe extern "C" fn(Id, Sel) =
core::mem::transmute($crate::ffi::objc_msgSend as *const ());
f($obj, $crate::ffi::sel($sel))
}};
(void; $obj:expr, $sel:expr, ($($at:ty),+), $($av:expr),+) => {{
let f: unsafe extern "C" fn(Id, Sel, $($at),+) =
core::mem::transmute($crate::ffi::objc_msgSend as *const ());
f($obj, $crate::ffi::sel($sel), $($av),+)
}};
// `f64` result, no arguments.
(f64; $obj:expr, $sel:expr) => {{
let f: unsafe extern "C" fn(Id, Sel) -> f64 =
core::mem::transmute($crate::ffi::objc_msgSend as *const ());
f($obj, $crate::ffi::sel($sel))
}};
// `u32` result (e.g. an `MPSDataType`), no arguments.
(u32; $obj:expr, $sel:expr) => {{
let f: unsafe extern "C" fn(Id, Sel) -> u32 =
core::mem::transmute($crate::ffi::objc_msgSend as *const ());
f($obj, $crate::ffi::sel($sel))
}};
}
// Make the (non-exported) macro importable elsewhere in the crate with `use`.
pub(crate) use msg_send;
/// `MPSDataType` raw values: type-class flag bits OR'ed with the element
/// width in bits.
pub mod MPSDataType {
    // Flag bits, mirroring `MPSDataTypeFloatBit`, `MPSDataTypeSignedBit` and
    // `MPSDataTypeAlternateEncodingBit` from the MPS headers.
    const FLOAT_BIT: u32 = 0x1000_0000;
    const SIGNED_BIT: u32 = 0x2000_0000;
    const ALTERNATE_ENCODING_BIT: u32 = 0x8000_0000;

    pub const FLOAT32: u32 = FLOAT_BIT | 32;
    pub const FLOAT16: u32 = FLOAT_BIT | 16;
    /// 16-bit float layout with the alternate-encoding bit set.
    pub const BFLOAT16: u32 = ALTERNATE_ENCODING_BIT | FLOAT16;

    pub const INT64: u32 = SIGNED_BIT | 64;
    pub const INT32: u32 = SIGNED_BIT | 32;
    pub const INT16: u32 = SIGNED_BIT | 16;
    pub const INT8: u32 = SIGNED_BIT | 8;

    // Unsigned integers carry no flag bits: the value is just the width.
    pub const UINT64: u32 = 64;
    pub const UINT32: u32 = 32;
    pub const UINT16: u32 = 16;
    pub const UINT8: u32 = 8;

    /// 8-bit unsigned storage with the alternate-encoding bit set.
    pub const BOOL: u32 = ALTERNATE_ENCODING_BIT | UINT8;
}
/// `MPSGraphPaddingStyle` raw values, passed to descriptors as `usize`.
pub mod MPSGraphPaddingStyle {
    /// Use the explicit left/right/top/bottom padding given to the descriptor.
    pub const EXPLICIT: usize = 0x0;
}
/// `MPSGraphTensorNamedDataLayout` raw values, passed as `usize`.
pub mod MPSGraphTensorNamedDataLayout {
    /// Activations: batch, channels, height, width.
    pub const NCHW: usize = 0x0;
    /// Activations: batch, height, width, channels.
    pub const NHWC: usize = 0x1;
    /// Weights: output channels, input channels, height, width.
    pub const OIHW: usize = 0x2;
    /// Weights: height, width, input channels, output channels.
    pub const HWIO: usize = 0x3;
}
/// `MPSGraphScatterMode` raw values (`NSInteger`), as declared in the
/// MPSGraph headers.
///
/// `Add` is raw value 0. The previous value of 1 is `Sub`, which made every
/// "scatter add" subtract instead.
pub mod MPSGraphScatterMode {
    /// Accumulate updates into the destination with `+`.
    pub const ADD: isize = 0;
    /// Combine updates with `-`.
    pub const SUB: isize = 1;
    /// Combine updates with `*`.
    pub const MUL: isize = 2;
    /// Combine updates with `/`.
    pub const DIV: isize = 3;
    /// Keep the minimum of destination and update.
    pub const MIN: isize = 4;
    /// Keep the maximum of destination and update.
    pub const MAX: isize = 5;
    /// Overwrite the destination with the update.
    pub const SET: isize = 6;
}
/// `MPSGraphPoolingReturnIndicesMode` raw values (`NSInteger`).
pub mod MPSGraphPoolingReturnIndicesMode {
    /// `MPSGraphPoolingReturnIndicesGlobalFlatten2D`.
    pub const GLOBAL_FLATTEN_2D: isize = 0x2;
}
/// `MPSGraphResizeMode` raw values, passed as `usize`.
pub mod MPSGraphResizeMode {
    /// Nearest-neighbour sampling.
    pub const NEAREST: usize = 0x0;
    /// Bilinear interpolation.
    pub const BILINEAR: usize = 0x1;
}
pub unsafe fn ns_number_isize(v: isize) -> Id {
msg_send!(Id; class("NSNumber"), "numberWithLong:", (isize), v)
}
pub unsafe fn ns_number_usize(v: usize) -> Id {
msg_send!(Id; class("NSNumber"), "numberWithUnsignedLong:", (usize), v)
}
/// Unboxes an `NSNumber` as a signed integer via `-longValue`.
///
/// # Safety
/// `n` must be a valid `NSNumber`.
pub unsafe fn ns_number_to_isize(n: Id) -> isize {
    msg_send!(
        isize; n,
        "longValue"
    )
}
/// Unboxes an `NSNumber` as an unsigned integer via `-unsignedLongValue`.
///
/// # Safety
/// `n` must be a valid `NSNumber`.
pub unsafe fn ns_number_to_usize(n: Id) -> usize {
    msg_send!(
        usize; n,
        "unsignedLongValue"
    )
}
/// Builds an `NSArray` from `objs` via `+[NSArray arrayWithObjects:count:]`.
///
/// # Safety
/// Every element of `objs` must be a valid Objective-C object.
pub unsafe fn ns_array(objs: &[Id]) -> Id {
    let count = objs.len();
    let elements = objs.as_ptr() as Id;
    msg_send!(
        Id; class("NSArray"),
        "arrayWithObjects:count:",
        (Id, usize),
        elements, count
    )
}
/// Returns `-[NSArray count]` for `arr`.
///
/// # Safety
/// `arr` must be a valid `NSArray`.
pub unsafe fn ns_array_count(arr: Id) -> usize {
    msg_send!(
        usize; arr,
        "count"
    )
}
/// Returns the element at `idx` via `-[NSArray objectAtIndex:]`.
///
/// # Safety
/// `arr` must be a valid `NSArray` and `idx` must be in bounds.
pub unsafe fn ns_array_get(arr: Id, idx: usize) -> Id {
    msg_send!(
        Id; arr,
        "objectAtIndex:",
        (usize),
        idx
    )
}
/// Copies `len` bytes from `bytes` into a new `NSData` via
/// `+[NSData dataWithBytes:length:]`.
///
/// # Safety
/// `bytes` must be readable for `len` bytes.
pub unsafe fn ns_data(bytes: *const u8, len: usize) -> Id {
    let source = bytes as Id;
    msg_send!(
        Id; class("NSData"),
        "dataWithBytes:length:",
        (Id, usize),
        source, len
    )
}
/// Builds an `NSDictionary` mapping `keys[i]` to `vals[i]` via
/// `+[NSDictionary dictionaryWithObjects:forKeys:count:]`.
///
/// # Panics
/// Panics if `keys` and `vals` differ in length.
///
/// # Safety
/// Every element of both slices must be a valid Objective-C object.
pub unsafe fn ns_dictionary(keys: &[Id], vals: &[Id]) -> Id {
    let count = keys.len();
    assert_eq!(count, vals.len());
    let objects = vals.as_ptr() as Id;
    let key_list = keys.as_ptr() as Id;
    msg_send!(
        Id; class("NSDictionary"),
        "dictionaryWithObjects:forKeys:count:",
        (Id, Id, usize),
        objects, key_list, count
    )
}
/// Looks up `key` via `-[NSDictionary objectForKey:]`.
///
/// # Safety
/// `dict` must be a valid `NSDictionary` and `key` a valid object.
pub unsafe fn ns_dict_get(dict: Id, key: Id) -> Id {
    msg_send!(
        Id; dict,
        "objectForKey:",
        (Id),
        key
    )
}
/// Builds an `NSArray` of `NSNumber`s (via `ns_number_isize`) from `vals`.
///
/// # Safety
/// Performs raw Objective-C message sends.
pub unsafe fn ns_isize_array(vals: &[isize]) -> Id {
    let mut numbers = Vec::with_capacity(vals.len());
    for &v in vals {
        numbers.push(ns_number_isize(v));
    }
    ns_array(&numbers)
}
/// Builds an `NSArray` of `NSNumber`s (via `ns_number_usize`) from `vals`.
///
/// # Safety
/// Performs raw Objective-C message sends.
pub unsafe fn ns_usize_array(vals: &[usize]) -> Id {
    let mut numbers = Vec::with_capacity(vals.len());
    for &v in vals {
        numbers.push(ns_number_usize(v));
    }
    ns_array(&numbers)
}
/// Returns the process-wide default Metal device, created on first call.
///
/// The pointer is cached as a `usize` because raw pointers are not
/// `Send`/`Sync` and so cannot live in a `static OnceLock` directly. It is
/// never released.
///
/// # Panics
/// Panics if `MTLCreateSystemDefaultDevice` returns null.
pub fn metal_device() -> Id {
    static DEVICE: OnceLock<usize> = OnceLock::new();
    let addr = *DEVICE.get_or_init(|| {
        let device = unsafe { MTLCreateSystemDefaultDevice() };
        assert!(!device.is_null(), "No Metal device found");
        device as usize
    });
    addr as Id
}
/// Returns the process-wide command queue on `metal_device()`, created on
/// first call via `newCommandQueue` and never released.
///
/// # Panics
/// Panics if the device returns a null queue.
pub fn metal_queue() -> Id {
    static QUEUE: OnceLock<usize> = OnceLock::new();
    let addr = *QUEUE.get_or_init(|| {
        let queue = unsafe { msg_send!(Id; metal_device(), "newCommandQueue") };
        assert!(!queue.is_null(), "Failed to create command queue");
        queue as usize
    });
    addr as Id
}
/// Copies `bytes` into a new Metal buffer on the default device via
/// `newBufferWithBytes:length:options:` with raw resource options `0`.
///
/// By Cocoa's `new…` naming convention the caller owns the returned buffer.
///
/// # Safety
/// Performs a raw Objective-C message send; the result is not null-checked.
pub unsafe fn mtl_buffer_from_bytes(bytes: &[u8]) -> Id {
    const OPTIONS: usize = 0;
    let source = bytes.as_ptr() as Id;
    let length = bytes.len();
    msg_send!(
        Id; metal_device(),
        "newBufferWithBytes:length:options:",
        (Id, usize, usize),
        source, length, OPTIONS
    )
}
/// Allocates a `len`-byte Metal buffer on the default device via
/// `newBufferWithLength:options:` (raw options `0`) and fills it with zeros.
///
/// For `len == 0` nothing is written, and the (possibly null) result is
/// returned as-is.
///
/// # Panics
/// Panics if `len > 0` and either the buffer or its `contents` pointer is
/// null. Writing zeros through a null pointer would be undefined behavior.
///
/// # Safety
/// Performs raw Objective-C message sends.
pub unsafe fn mtl_buffer_zeroed(len: usize) -> Id {
    let buf = msg_send!(Id; metal_device(), "newBufferWithLength:options:",
        (usize, usize), len, 0usize);
    // Nothing to clear; skip the contents lookup so a null result for an
    // empty request is not dereferenced.
    if len == 0 {
        return buf;
    }
    assert!(!buf.is_null(), "Failed to allocate a {len}-byte Metal buffer");
    let ptr = mtl_buffer_contents(buf);
    assert!(!ptr.is_null(), "Metal buffer of {len} bytes has no CPU-visible contents");
    std::ptr::write_bytes(ptr as *mut u8, 0, len);
    buf
}
/// Returns the pointer from `-[MTLBuffer contents]`.
///
/// # Safety
/// `buf` must be a valid `MTLBuffer`.
pub unsafe fn mtl_buffer_contents(buf: Id) -> *mut c_void {
    type ContentsFn = unsafe extern "C" fn(Id, Sel) -> *mut c_void;
    let selector = sel("contents");
    let send: ContentsFn = core::mem::transmute(objc_msgSend as *const ());
    send(buf, selector)
}
/// Returns `-[MTLBuffer length]` in bytes.
///
/// # Safety
/// `buf` must be a valid `MTLBuffer`.
pub unsafe fn mtl_buffer_length(buf: Id) -> usize {
    msg_send!(
        usize; buf,
        "length"
    )
}
/// Creates an empty `MPSGraph` via `+[MPSGraph new]`.
///
/// # Panics
/// Panics if the runtime returns null.
///
/// # Safety
/// Performs a raw Objective-C message send.
pub unsafe fn mpsgraph_new() -> Id {
    let graph = msg_send!(Id; class("MPSGraph"), "new");
    if graph.is_null() {
        panic!("Failed to create MPSGraph");
    }
    graph
}
pub unsafe fn mpsgraph_device() -> Id {
msg_send!(Id; class("MPSGraphDevice"), "deviceWithMTLDevice:", (Id), metal_device())
}
/// Creates an `MPSGraphTensorData` backed by the Metal buffer `buf`, via
/// `alloc` + `initWithMTLBuffer:shape:dataType:`.
///
/// Being alloc/init, the caller owns the returned object.
///
/// # Safety
/// `buf` must be a valid `MTLBuffer` and `shape` a valid shape array.
pub unsafe fn tensor_data_from_buffer(buf: Id, shape: Id, dtype: u32) -> Id {
    let uninit = objc_alloc(class("MPSGraphTensorData"));
    msg_send!(
        Id; uninit,
        "initWithMTLBuffer:shape:dataType:",
        (Id, Id, u32),
        buf, shape, dtype
    )
}
/// Creates an `MPSGraphTensorData` from `nsdata` on `mpsgraph_device()`, via
/// `alloc` + `initWithDevice:data:shape:dataType:`.
///
/// Being alloc/init, the caller owns the returned object.
///
/// # Safety
/// `nsdata` must be a valid `NSData` and `shape` a valid shape array.
pub unsafe fn tensor_data_from_nsdata(nsdata: Id, shape: Id, dtype: u32) -> Id {
    let uninit = objc_alloc(class("MPSGraphTensorData"));
    msg_send!(
        Id; uninit,
        "initWithDevice:data:shape:dataType:",
        (Id, Id, Id, u32),
        mpsgraph_device(), nsdata, shape, dtype
    )
}
/// Returns `-[MPSGraphTensorData shape]`.
///
/// # Safety
/// `td` must be a valid `MPSGraphTensorData`.
pub unsafe fn tensor_data_shape(td: Id) -> Id {
    msg_send!(
        Id; td,
        "shape"
    )
}
/// Returns `-[MPSGraphTensorData dataType]` as a raw `MPSDataType`.
///
/// # Safety
/// `td` must be a valid `MPSGraphTensorData`.
pub unsafe fn tensor_data_dtype(td: Id) -> u32 {
    msg_send!(
        u32; td,
        "dataType"
    )
}
/// Copies the contents of the `MPSGraphTensorData` `td` into `buf`.
///
/// Goes through `-[MPSGraphTensorData mpsndarray]` and
/// `-[MPSNDArray readBytes:strideBytes:]`. The stride argument is null, so the
/// bytes are written densely packed starting at `buf[0]`.
///
/// # Panics
/// Panics if `buf` is shorter than the tensor's byte size. That size is the
/// product of `td`'s `shape` times the element width encoded in its
/// `dataType`, rounded up to whole bytes. Without this check the runtime
/// would write past the end of `buf`.
///
/// # Safety
/// `td` must be a valid `MPSGraphTensorData`.
pub unsafe fn tensor_data_read_bytes(td: Id, buf: &mut [u8]) {
    // Element count = product of the dimensions in `shape` (an array of
    // numbers). Saturate so a nonsensical shape fails the size check below
    // instead of wrapping around.
    let shape = msg_send!(Id; td, "shape");
    let rank = msg_send!(usize; shape, "count");
    let mut elems: usize = 1;
    for i in 0..rank {
        let dim = msg_send!(Id; shape, "objectAtIndex:", (usize), i);
        elems = elems.saturating_mul(msg_send!(usize; dim, "unsignedLongValue"));
    }
    // The low 16 bits of an MPSDataType value hold the element width in bits;
    // every MPSDataType constant in this file is flag bits OR'ed with the width.
    let dtype = msg_send!(u32; td, "dataType");
    let bits = (dtype & 0xFFFF) as usize;
    let total_bits = elems.saturating_mul(bits);
    let needed = total_bits / 8 + usize::from(total_bits % 8 != 0);
    assert!(
        buf.len() >= needed,
        "tensor_data_read_bytes: buffer holds {} bytes but the tensor needs {}",
        buf.len(),
        needed
    );

    let ndarray = msg_send!(Id; td, "mpsndarray");
    let f: unsafe extern "C" fn(Id, Sel, *mut c_void, *mut isize) =
        core::mem::transmute(objc_msgSend as *const ());
    f(
        ndarray,
        sel("readBytes:strideBytes:"),
        buf.as_mut_ptr() as *mut c_void,
        ptr::null_mut(),
    );
}
pub unsafe fn graph_placeholder(graph: Id, shape: Id, dtype: u32) -> Id {
msg_send!(Id; graph, "placeholderWithShape:dataType:name:",
(Id, u32, Id), shape, dtype, NIL)
}
/// Adds a scalar constant of type `dtype` via `constantWithScalar:dataType:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`.
pub unsafe fn graph_constant_scalar(graph: Id, value: f64, dtype: u32) -> Id {
    msg_send!(
        Id; graph,
        "constantWithScalar:dataType:",
        (f64, u32),
        value, dtype
    )
}
/// Adds a constant of `shape` filled with `value` via
/// `constantWithScalar:shape:dataType:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `shape` a valid shape array.
pub unsafe fn graph_constant_scalar_shape(graph: Id, value: f64, shape: Id, dtype: u32) -> Id {
    msg_send!(
        Id; graph,
        "constantWithScalar:shape:dataType:",
        (f64, Id, u32),
        value, shape, dtype
    )
}
pub unsafe fn graph_run(
graph: Id,
feeds: Id, targets: Id, ) -> Id {
msg_send!(Id; graph,
"runWithMTLCommandQueue:feeds:targetTensors:targetOperations:",
(Id, Id, Id, Id), metal_queue(), feeds, targets, NIL)
}
pub unsafe fn graph_unary(graph: Id, sel_name: &str, t: Id) -> Id {
msg_send!(Id; graph, sel_name, (Id, Id), t, NIL)
}
pub unsafe fn graph_binary(graph: Id, sel_name: &str, a: Id, b: Id) -> Id {
msg_send!(Id; graph, sel_name, (Id, Id, Id), a, b, NIL)
}
/// Casts `t` to the raw `MPSDataType` `dtype` via `castTensor:toType:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_cast(graph: Id, t: Id, dtype: u32) -> Id {
    msg_send!(
        Id; graph,
        "castTensor:toType:name:",
        (Id, u32, Id),
        t, dtype, NIL
    )
}
/// Reshapes `t` to `shape` via `reshapeTensor:withShape:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `t` a valid tensor and `shape` a valid
/// shape array.
pub unsafe fn graph_reshape(graph: Id, t: Id, shape: Id) -> Id {
    msg_send!(
        Id; graph,
        "reshapeTensor:withShape:name:",
        (Id, Id, Id),
        t, shape, NIL
    )
}
/// Swaps dimensions `dim1` and `dim2` of `t` via
/// `transposeTensor:dimension:withDimension:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_transpose(graph: Id, t: Id, dim1: usize, dim2: usize) -> Id {
    msg_send!(
        Id; graph,
        "transposeTensor:dimension:withDimension:name:",
        (Id, usize, usize, Id),
        t, dim1, dim2, NIL
    )
}
/// Permutes the dimensions of `t` by the array `perm` via
/// `transposeTensor:permutation:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `t` a valid tensor and `perm` a valid
/// array of numbers.
pub unsafe fn graph_permute(graph: Id, t: Id, perm: Id) -> Id {
    msg_send!(
        Id; graph,
        "transposeTensor:permutation:name:",
        (Id, Id, Id),
        t, perm, NIL
    )
}
/// Broadcasts `t` to `shape` via `broadcastTensor:toShape:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `t` a valid tensor and `shape` a valid
/// shape array.
pub unsafe fn graph_broadcast(graph: Id, t: Id, shape: Id) -> Id {
    msg_send!(
        Id; graph,
        "broadcastTensor:toShape:name:",
        (Id, Id, Id),
        t, shape, NIL
    )
}
/// Strided slice of `t` via `sliceTensor:starts:ends:strides:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `t` a valid tensor, and
/// `starts`/`ends`/`strides` valid arrays of numbers.
pub unsafe fn graph_slice(graph: Id, t: Id, starts: Id, ends: Id, strides: Id) -> Id {
    msg_send!(
        Id; graph,
        "sliceTensor:starts:ends:strides:name:",
        (Id, Id, Id, Id, Id),
        t, starts, ends, strides, NIL
    )
}
/// Strided slice of `t` with start/end/squeeze masks via
/// `sliceTensor:starts:ends:strides:startMask:endMask:squeezeMask:name:`.
///
/// The three masks are passed through unchanged as 32-bit bitfields.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `t` a valid tensor, and
/// `starts`/`ends`/`strides` valid arrays of numbers.
pub unsafe fn graph_slice_masked(
    graph: Id, t: Id, starts: Id, ends: Id, strides: Id,
    start_mask: u32, end_mask: u32, squeeze_mask: u32,
) -> Id {
    type MaskedSliceFn =
        unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id, u32, u32, u32, Id) -> Id;
    let selector = sel("sliceTensor:starts:ends:strides:startMask:endMask:squeezeMask:name:");
    let send: MaskedSliceFn = core::mem::transmute(objc_msgSend as *const ());
    send(
        graph, selector,
        t, starts, ends, strides,
        start_mask, end_mask, squeeze_mask,
        NIL,
    )
}
/// Writes `update` into the strided slice of `data` via
/// `sliceUpdateDataTensor:updateTensor:starts:ends:strides:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `data`/`update` valid tensors, and
/// `starts`/`ends`/`strides` valid arrays of numbers.
pub unsafe fn graph_slice_update(
    graph: Id, data: Id, update: Id, starts: Id, ends: Id, strides: Id,
) -> Id {
    type SliceUpdateFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id, Id, Id) -> Id;
    let selector = sel("sliceUpdateDataTensor:updateTensor:starts:ends:strides:name:");
    let send: SliceUpdateFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, data, update, starts, ends, strides, NIL)
}
/// Concatenates the array `tensors` along `dim` via
/// `concatTensors:dimension:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `tensors` a valid array of tensors.
pub unsafe fn graph_concat(graph: Id, tensors: Id, dim: isize) -> Id {
    msg_send!(
        Id; graph,
        "concatTensors:dimension:name:",
        (Id, isize, Id),
        tensors, dim, NIL
    )
}
/// Sums `t` over the array of `axes` via `reductionSumWithTensor:axes:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `t` a valid tensor and `axes` a valid
/// array of numbers.
pub unsafe fn graph_reduction_sum(graph: Id, t: Id, axes: Id) -> Id {
    msg_send!(
        Id; graph,
        "reductionSumWithTensor:axes:name:",
        (Id, Id, Id),
        t, axes, NIL
    )
}
/// Sums `t` over the single `axis` via `reductionSumWithTensor:axis:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_reduction_sum_axis(graph: Id, t: Id, axis: isize) -> Id {
    msg_send!(
        Id; graph,
        "reductionSumWithTensor:axis:name:",
        (Id, isize, Id),
        t, axis, NIL
    )
}
/// Maximum of `t` over the single `axis` via
/// `reductionMaximumWithTensor:axis:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_reduction_max_axis(graph: Id, t: Id, axis: isize) -> Id {
    msg_send!(
        Id; graph,
        "reductionMaximumWithTensor:axis:name:",
        (Id, isize, Id),
        t, axis, NIL
    )
}
/// Minimum of `t` over the single `axis` via
/// `reductionMinimumWithTensor:axis:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_reduction_min_axis(graph: Id, t: Id, axis: isize) -> Id {
    msg_send!(
        Id; graph,
        "reductionMinimumWithTensor:axis:name:",
        (Id, isize, Id),
        t, axis, NIL
    )
}
/// Maximum of `t` over the array of `axes` via
/// `reductionMaximumWithTensor:axes:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `t` a valid tensor and `axes` a valid
/// array of numbers.
pub unsafe fn graph_reduction_max(graph: Id, t: Id, axes: Id) -> Id {
    msg_send!(
        Id; graph,
        "reductionMaximumWithTensor:axes:name:",
        (Id, Id, Id),
        t, axes, NIL
    )
}
/// Minimum of `t` over the array of `axes` via
/// `reductionMinimumWithTensor:axes:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `t` a valid tensor and `axes` a valid
/// array of numbers.
pub unsafe fn graph_reduction_min(graph: Id, t: Id, axes: Id) -> Id {
    msg_send!(
        Id; graph,
        "reductionMinimumWithTensor:axes:name:",
        (Id, Id, Id),
        t, axes, NIL
    )
}
/// Product of `t` over the array of `axes` via
/// `reductionProductWithTensor:axes:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `t` a valid tensor and `axes` a valid
/// array of numbers.
pub unsafe fn graph_reduction_prod(graph: Id, t: Id, axes: Id) -> Id {
    msg_send!(
        Id; graph,
        "reductionProductWithTensor:axes:name:",
        (Id, Id, Id),
        t, axes, NIL
    )
}
/// Product of `t` over the single `axis` via
/// `reductionProductWithTensor:axis:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_reduction_prod_axis(graph: Id, t: Id, axis: isize) -> Id {
    msg_send!(
        Id; graph,
        "reductionProductWithTensor:axis:name:",
        (Id, isize, Id),
        t, axis, NIL
    )
}
/// Index of the maximum of `t` along `axis` via
/// `reductionArgMaximumWithTensor:axis:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_argmax(graph: Id, t: Id, axis: isize) -> Id {
    msg_send!(
        Id; graph,
        "reductionArgMaximumWithTensor:axis:name:",
        (Id, isize, Id),
        t, axis, NIL
    )
}
/// Index of the minimum of `t` along `axis` via
/// `reductionArgMinimumWithTensor:axis:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_argmin(graph: Id, t: Id, axis: isize) -> Id {
    msg_send!(
        Id; graph,
        "reductionArgMinimumWithTensor:axis:name:",
        (Id, isize, Id),
        t, axis, NIL
    )
}
/// Cumulative sum of `t` along `axis` via `cumulativeSumWithTensor:axis:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_cumsum(graph: Id, t: Id, axis: isize) -> Id {
    msg_send!(
        Id; graph,
        "cumulativeSumWithTensor:axis:name:",
        (Id, isize, Id),
        t, axis, NIL
    )
}
/// Cumulative product of `t` along `axis` via
/// `cumulativeProductWithTensor:axis:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_cumprod(graph: Id, t: Id, axis: isize) -> Id {
    msg_send!(
        Id; graph,
        "cumulativeProductWithTensor:axis:name:",
        (Id, isize, Id),
        t, axis, NIL
    )
}
/// Running minimum of `t` along `axis` via
/// `cumulativeMinimumWithTensor:axis:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_cummin(graph: Id, t: Id, axis: isize) -> Id {
    msg_send!(
        Id; graph,
        "cumulativeMinimumWithTensor:axis:name:",
        (Id, isize, Id),
        t, axis, NIL
    )
}
/// Running maximum of `t` along `axis` via
/// `cumulativeMaximumWithTensor:axis:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_cummax(graph: Id, t: Id, axis: isize) -> Id {
    msg_send!(
        Id; graph,
        "cumulativeMaximumWithTensor:axis:name:",
        (Id, isize, Id),
        t, axis, NIL
    )
}
/// Gathers slices of `updates` at `indices` along `axis` via
/// `gatherWithUpdatesTensor:indicesTensor:axis:batchDimensions:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `updates`/`indices` valid tensors.
pub unsafe fn graph_gather(graph: Id, updates: Id, indices: Id, axis: usize, batch_dims: usize) -> Id {
    type GatherFn = unsafe extern "C" fn(Id, Sel, Id, Id, usize, usize, Id) -> Id;
    let selector = sel("gatherWithUpdatesTensor:indicesTensor:axis:batchDimensions:name:");
    let send: GatherFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, updates, indices, axis, batch_dims, NIL)
}
/// Scatters `updates` into `data` at `indices` along `axis` via
/// `scatterAlongAxis:withDataTensor:updatesTensor:indicesTensor:mode:name:`.
///
/// `mode` is a raw `MPSGraphScatterMode` value.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `data`/`updates`/`indices` valid
/// tensors.
pub unsafe fn graph_scatter_along(
    graph: Id, axis: isize, data: Id, updates: Id, indices: Id, mode: isize,
) -> Id {
    type ScatterAlongFn = unsafe extern "C" fn(Id, Sel, isize, Id, Id, Id, isize, Id) -> Id;
    let selector =
        sel("scatterAlongAxis:withDataTensor:updatesTensor:indicesTensor:mode:name:");
    let send: ScatterAlongFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, axis, data, updates, indices, mode, NIL)
}
/// Element-wise `pred ? true_t : false_t` via
/// `selectWithPredicateTensor:truePredicateTensor:falsePredicateTensor:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and all tensor arguments valid tensors.
pub unsafe fn graph_select(graph: Id, pred: Id, true_t: Id, false_t: Id) -> Id {
    type SelectFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id) -> Id;
    let selector =
        sel("selectWithPredicateTensor:truePredicateTensor:falsePredicateTensor:name:");
    let send: SelectFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, pred, true_t, false_t, NIL)
}
/// Sorts `t` along `axis` via `sortWithTensor:axis:descending:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_sort(graph: Id, t: Id, axis: isize, descending: bool) -> Id {
    type SortFn = unsafe extern "C" fn(Id, Sel, Id, isize, bool, Id) -> Id;
    let selector = sel("sortWithTensor:axis:descending:name:");
    let send: SortFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, t, axis, descending, NIL)
}
/// Sorting permutation of `t` along `axis` via
/// `argSortWithTensor:axis:descending:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `t` a valid tensor.
pub unsafe fn graph_argsort(graph: Id, t: Id, axis: isize, descending: bool) -> Id {
    type ArgSortFn = unsafe extern "C" fn(Id, Sel, Id, isize, bool, Id) -> Id;
    let selector = sel("argSortWithTensor:axis:descending:name:");
    let send: ArgSortFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, t, axis, descending, NIL)
}
/// Matrix product `a × b` via
/// `matrixMultiplicationWithPrimaryTensor:secondaryTensor:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `a`/`b` valid tensors.
pub unsafe fn graph_matmul(graph: Id, a: Id, b: Id) -> Id {
    const SELECTOR: &str = "matrixMultiplicationWithPrimaryTensor:secondaryTensor:name:";
    graph_binary(graph, SELECTOR, a, b)
}
/// Builds an `MPSGraphConvolution2DOpDescriptor` with explicit padding, NCHW
/// data layout and OIHW weights layout.
///
/// Arguments: strides `sx`/`sy`, dilations `dx`/`dy`, `groups`, and padding
/// left/right/top/bottom `pl`/`pr`/`pt`/`pb`.
///
/// # Panics
/// Panics if the runtime returns a null descriptor.
///
/// # Safety
/// Performs a raw Objective-C message send.
pub unsafe fn conv2d_desc(
    sx: usize, sy: usize, dx: usize, dy: usize, groups: usize,
    pl: usize, pr: usize, pt: usize, pb: usize,
) -> Id {
    // Receiver, selector, then 12 `NSUInteger` arguments.
    type DescriptorCtor = unsafe extern "C" fn(
        Id, Sel,
        usize, usize, usize, usize, usize,
        usize, usize, usize, usize,
        usize, usize, usize,
    ) -> Id;
    let ctor: DescriptorCtor = core::mem::transmute(objc_msgSend as *const ());
    let desc = ctor(
        class("MPSGraphConvolution2DOpDescriptor"),
        sel("descriptorWithStrideInX:strideInY:dilationRateInX:dilationRateInY:groups:paddingLeft:paddingRight:paddingTop:paddingBottom:paddingStyle:dataLayout:weightsLayout:"),
        sx, sy, dx, dy, groups,
        pl, pr, pt, pb,
        MPSGraphPaddingStyle::EXPLICIT,
        MPSGraphTensorNamedDataLayout::NCHW,
        MPSGraphTensorNamedDataLayout::OIHW,
    );
    if desc.is_null() {
        panic!("Failed to create conv2d descriptor (sx={sx},sy={sy},dx={dx},dy={dy},g={groups})");
    }
    desc
}
/// 2-D convolution of `src` with `weights` via
/// `convolution2DWithSourceTensor:weightsTensor:descriptor:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `src`/`weights` valid tensors and
/// `desc` a valid convolution descriptor.
pub unsafe fn graph_conv2d(graph: Id, src: Id, weights: Id, desc: Id) -> Id {
    type Conv2dFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id) -> Id;
    let selector = sel("convolution2DWithSourceTensor:weightsTensor:descriptor:name:");
    let send: Conv2dFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, src, weights, desc, NIL)
}
/// Transposed 2-D convolution producing `out_shape` via
/// `convolutionTranspose2DWithSourceTensor:weightsTensor:outputShape:descriptor:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `src`/`weights` valid tensors,
/// `out_shape` a valid shape array and `desc` a valid convolution descriptor.
pub unsafe fn graph_conv_transpose2d(graph: Id, src: Id, weights: Id, out_shape: Id, desc: Id) -> Id {
    type ConvTranspose2dFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id, Id) -> Id;
    let selector = sel(
        "convolutionTranspose2DWithSourceTensor:weightsTensor:outputShape:descriptor:name:",
    );
    let send: ConvTranspose2dFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, src, weights, out_shape, desc, NIL)
}
/// Builds an `MPSGraphPooling2DOpDescriptor` with explicit padding and NCHW
/// data layout.
///
/// Arguments: kernel size `kw`/`kh`, strides `sx`/`sy`, dilations `dx`/`dy`,
/// and padding left/right/top/bottom `pl`/`pr`/`pt`/`pb`.
///
/// # Panics
/// Panics if the runtime returns a null descriptor. This matches
/// `conv2d_desc`: a nil descriptor would otherwise flow silently into the
/// pooling ops and their setters.
///
/// # Safety
/// Performs a raw Objective-C message send.
pub unsafe fn pool2d_desc(
    kw: usize, kh: usize, sx: usize, sy: usize,
    dx: usize, dy: usize,
    pl: usize, pr: usize, pt: usize, pb: usize,
) -> Id {
    // Receiver, selector, then 12 `NSUInteger` arguments.
    let f: unsafe extern "C" fn(Id, Sel, usize,usize,usize,usize,usize,usize,usize,usize,usize,usize,usize,usize) -> Id =
        core::mem::transmute(objc_msgSend as *const ());
    let desc = f(class("MPSGraphPooling2DOpDescriptor"),
        sel("descriptorWithKernelWidth:kernelHeight:strideInX:strideInY:dilationRateInX:dilationRateInY:paddingLeft:paddingRight:paddingTop:paddingBottom:paddingStyle:dataLayout:"),
        kw, kh, sx, sy, dx, dy, pl, pr, pt, pb,
        MPSGraphPaddingStyle::EXPLICIT,
        MPSGraphTensorNamedDataLayout::NCHW);
    assert!(
        !desc.is_null(),
        "Failed to create pool2d descriptor (kw={kw},kh={kh},sx={sx},sy={sy},dx={dx},dy={dy})"
    );
    desc
}
/// Sets `includeZeroPadToAverage` on a pooling descriptor.
///
/// # Safety
/// `desc` must be a valid `MPSGraphPooling2DOpDescriptor`.
pub unsafe fn pool_desc_set_include_zero_pad(desc: Id, val: bool) {
    msg_send!(
        void; desc,
        "setIncludeZeroPadToAverage:",
        (bool),
        val
    );
}
pub unsafe fn pool_desc_set_return_indices(desc: Id) {
msg_send!(void; desc, "setReturnIndicesMode:", (isize),
MPSGraphPoolingReturnIndicesMode::GLOBAL_FLATTEN_2D);
msg_send!(void; desc, "setReturnIndicesDataType:", (u32), MPSDataType::INT32);
}
/// 2-D average pooling of `src` via
/// `avgPooling2DWithSourceTensor:descriptor:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `src` a valid tensor and `desc` a valid
/// pooling descriptor.
pub unsafe fn graph_avg_pool2d(graph: Id, src: Id, desc: Id) -> Id {
    msg_send!(
        Id; graph,
        "avgPooling2DWithSourceTensor:descriptor:name:",
        (Id, Id, Id),
        src, desc, NIL
    )
}
/// Gradient of 2-D average pooling via
/// `avgPooling2DGradientWithGradientTensor:sourceTensor:descriptor:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `grad`/`src` valid tensors and `desc` a
/// valid pooling descriptor.
pub unsafe fn graph_avg_pool2d_grad(graph: Id, grad: Id, src: Id, desc: Id) -> Id {
    type AvgPoolGradFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id) -> Id;
    let selector =
        sel("avgPooling2DGradientWithGradientTensor:sourceTensor:descriptor:name:");
    let send: AvgPoolGradFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, grad, src, desc, NIL)
}
/// 2-D max pooling of `src` via `maxPooling2DWithSourceTensor:descriptor:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `src` a valid tensor and `desc` a valid
/// pooling descriptor.
pub unsafe fn graph_max_pool2d(graph: Id, src: Id, desc: Id) -> Id {
    msg_send!(
        Id; graph,
        "maxPooling2DWithSourceTensor:descriptor:name:",
        (Id, Id, Id),
        src, desc, NIL
    )
}
/// 2-D max pooling that also returns indices, via
/// `maxPooling2DReturnIndicesWithSourceTensor:descriptor:name:`.
///
/// The returned object is presumably an array of [values, indices] tensors,
/// per the MPSGraph API. Confirm before indexing into it.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `src` a valid tensor and `desc` a
/// pooling descriptor configured to return indices.
pub unsafe fn graph_max_pool2d_return_indices(graph: Id, src: Id, desc: Id) -> Id {
    msg_send!(
        Id; graph,
        "maxPooling2DReturnIndicesWithSourceTensor:descriptor:name:",
        (Id, Id, Id),
        src, desc, NIL
    )
}
/// Gradient of 2-D max pooling from saved `indices`, via
/// `maxPooling2DReturnIndicesGradientWithGradientTensor:indicesTensor:sourceTensor:descriptor:name:`.
///
/// NOTE(review): confirm this selector exists in the targeted MPSGraph SDK.
/// Newer headers expose `maxPooling2DGradientWithGradientTensor:indicesTensor:outputShape:descriptor:name:`.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `grad`/`indices`/`src` valid tensors and
/// `desc` a valid pooling descriptor.
pub unsafe fn graph_max_pool2d_indices_grad(graph: Id, grad: Id, indices: Id, src: Id, desc: Id) -> Id {
    type MaxPoolIndicesGradFn = unsafe extern "C" fn(Id, Sel, Id, Id, Id, Id, Id) -> Id;
    let selector = sel(
        "maxPooling2DReturnIndicesGradientWithGradientTensor:indicesTensor:sourceTensor:descriptor:name:",
    );
    let send: MaxPoolIndicesGradFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, grad, indices, src, desc, NIL)
}
/// Resizes the NCHW tensor `t` to `size` via
/// `resizeTensor:size:mode:centerResult:alignCorners:layout:name:`.
///
/// `mode` is a raw `MPSGraphResizeMode` value.
///
/// # Safety
/// `graph` must be a valid `MPSGraph`, `t` a valid tensor and `size` a valid
/// shape array.
pub unsafe fn graph_resize(
    graph: Id, t: Id, size: Id, mode: usize,
    center: bool, align_corners: bool,
) -> Id {
    type ResizeFn = unsafe extern "C" fn(Id, Sel, Id, Id, usize, bool, bool, usize, Id) -> Id;
    let selector = sel("resizeTensor:size:mode:centerResult:alignCorners:layout:name:");
    let layout = MPSGraphTensorNamedDataLayout::NCHW;
    let send: ResizeFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, t, size, mode, center, align_corners, layout, NIL)
}
/// Gradient of an NCHW resize with respect to `input`, via
/// `resizeWithGradientTensor:input:mode:centerResult:alignCorners:layout:name:`.
///
/// `mode` is a raw `MPSGraphResizeMode` value.
///
/// # Safety
/// `graph` must be a valid `MPSGraph` and `grad`/`input` valid tensors.
pub unsafe fn graph_resize_grad(
    graph: Id, grad: Id, input: Id, mode: usize,
    center: bool, align_corners: bool,
) -> Id {
    type ResizeGradFn = unsafe extern "C" fn(Id, Sel, Id, Id, usize, bool, bool, usize, Id) -> Id;
    let selector =
        sel("resizeWithGradientTensor:input:mode:centerResult:alignCorners:layout:name:");
    let layout = MPSGraphTensorNamedDataLayout::NCHW;
    let send: ResizeGradFn = core::mem::transmute(objc_msgSend as *const ());
    send(graph, selector, grad, input, mode, center, align_corners, layout, NIL)
}