use crate::allocator::Allocator;
use crate::memory::MemoryInfo;
use crate::session_options::SessionOptions;
use crate::tensor::TensorView;
use crate::{api, check, sys, Error, Result};
use std::ffi::{c_char, c_int, c_void, CString};
use std::marker::PhantomData;
use std::ptr;
/// Reads a variable-length string from a two-phase "query size, then fill" C call.
///
/// `fill` is invoked twice. The first (probe) call gets a null buffer so the callee
/// can report the required byte count through the size out-parameter; a status it
/// returns is released and otherwise ignored. The second call gets a zeroed buffer of
/// that many bytes, and its status is propagated. The result is passed through
/// `trim_nul`, which drops one trailing NUL and validates UTF-8.
///
/// # Safety
/// `fill` must forward to a C function that never writes more than `*size` bytes into
/// a non-null buffer.
unsafe fn fetch_sized_string(
    fill: impl Fn(*mut c_char, *mut usize) -> sys::StatusPtr,
) -> Result<String> {
    let mut needed = 0usize;
    // Probe pass: only the reported size matters here.
    let status = fill(ptr::null_mut(), &mut needed);
    if !status.is_null() {
        api().release_status()(status);
    }
    let mut bytes: Vec<u8> = vec![0; needed];
    let dst = bytes.as_mut_ptr().cast::<c_char>();
    check(fill(dst, &mut needed))?;
    trim_nul(&bytes, needed)
}
/// Reads a variable-length array from a two-phase "query count, then fill" C call.
///
/// `fill` is called once with a null buffer so the callee can report the element
/// count; a status from this probe is released and otherwise ignored. It is then
/// called with a buffer of that many `T::default()` elements, and its status is
/// propagated. The returned vector is truncated to the count reported by the second
/// call, so it never carries unfilled trailing elements.
///
/// # Safety
/// `fill` must forward to a C function that writes at most `*count` elements into a
/// non-null buffer.
unsafe fn fetch_sized_array<T: Copy + Default>(
    fill: impl Fn(*mut T, *mut usize) -> sys::StatusPtr,
) -> Result<Vec<T>> {
    let mut count: usize = 0;
    let probe = fill(ptr::null_mut(), &mut count);
    if !probe.is_null() {
        api().release_status()(probe);
    }
    let mut buf = vec![T::default(); count];
    check(fill(buf.as_mut_ptr(), &mut count))?;
    // The second call reports how many elements it actually wrote; drop any slack.
    // `truncate` is a no-op when the reported count is not smaller than the buffer.
    buf.truncate(count);
    Ok(buf)
}
/// Converts the first `size` bytes of `buf` into an owned `String`.
///
/// `size` is a byte count reported by a C API. It is clamped to `buf.len()` so a
/// misreported size can never index out of bounds. A single trailing NUL terminator,
/// if present, is dropped.
///
/// # Errors
/// Returns an error if the bytes are not valid UTF-8.
fn trim_nul(buf: &[u8], size: usize) -> Result<String> {
    let bytes = &buf[..size.min(buf.len())];
    let bytes = bytes.strip_suffix(&[0u8]).unwrap_or(bytes);
    std::str::from_utf8(bytes)
        .map(str::to_owned)
        .map_err(|_| Error::new(-1, "zrt: custom-op string is not valid UTF-8"))
}
fn cstring(s: &str) -> Result<CString> {
CString::new(s).map_err(|_| Error::new(-1, "custom-op string contains a NUL byte"))
}
fn usize_to_c_int(value: usize, what: &'static str) -> Result<c_int> {
c_int::try_from(value).map_err(|_| Error::new(-1, format!("zrt: {what} exceeds c_int::MAX")))
}
/// Owned handle to a native custom-op attribute created through `create_op_attr`.
///
/// The handle is released with `release_op_attr` when the value is dropped.
pub struct OpAttr {
    // Non-null native handle (checked with `ensure_non_null` at construction).
    ptr: *mut sys::OpAttrHandle,
}
impl OpAttr {
pub fn new(name: &str, data: &[u8], len: usize, ty: sys::OpAttrType) -> Result<Self> {
let name = cstring(name)?;
let len = usize_to_c_int(len, "custom-op attribute length")?;
let api = api();
let mut out: *mut sys::OpAttrHandle = ptr::null_mut();
check(unsafe {
api.create_op_attr()(
name.as_ptr(),
data.as_ptr() as *const c_void,
len,
ty,
&mut out,
)
})?;
let out = crate::ensure_non_null(out, "custom-op attribute")?;
Ok(Self { ptr: out })
}
pub fn new_float(name: &str, value: f32) -> Result<Self> {
Self::new(
name,
value.to_ne_bytes().as_slice(),
1usize,
sys::OpAttrType::Float,
)
}
pub fn new_int(name: &str, value: i64) -> Result<Self> {
Self::new(
name,
value.to_ne_bytes().as_slice(),
1usize,
sys::OpAttrType::Int,
)
}
pub fn new_string(name: &str, value: &str) -> Result<Self> {
Self::new(name, value.as_bytes(), value.len(), sys::OpAttrType::String)
}
pub fn new_ints(name: &str, values: &[i64]) -> Result<Self> {
Self::new(
name,
unsafe {
std::slice::from_raw_parts(
values.as_ptr() as *const u8,
std::mem::size_of_val(values),
)
},
values.len(),
sys::OpAttrType::Ints,
)
}
pub fn new_floats(name: &str, values: &[f32]) -> Result<Self> {
Self::new(
name,
unsafe {
std::slice::from_raw_parts(
values.as_ptr() as *const u8,
std::mem::size_of_val(values),
)
},
values.len(),
sys::OpAttrType::Floats,
)
}
pub fn ty(&self) -> Result<sys::OpAttrType> {
let mut out = sys::OpAttrType::Undefined;
check(unsafe { api().op_attr__get_type()(self.ptr, &mut out) })?;
Ok(out)
}
pub fn name(&self) -> Result<String> {
let p: *const c_char = ptr::null();
check(unsafe { api().op_attr__get_name()(self.ptr, &p) })?;
Ok(if p.is_null() {
String::new()
} else {
unsafe { crate::cstr_to_string(p, "custom-op attribute name") }?
})
}
pub fn read_into(&self, ty: sys::OpAttrType, buf: &mut [u8]) -> Result<usize> {
let mut written: usize = 0;
check(unsafe {
api().read_op_attr()(
self.ptr,
ty,
buf.as_mut_ptr() as *mut c_void,
buf.len(),
&mut written,
)
})?;
Ok(written)
}
}
impl Drop for OpAttr {
    // Returns the native attribute handle to the runtime via `release_op_attr`.
    fn drop(&mut self) {
        unsafe { api().release_op_attr()(self.ptr) }
    }
}
// NOTE(review): these impls allow an `OpAttr` to be moved to and shared between
// threads. Thread-safety of the native attribute handle is assumed here, not
// established by anything in this file — confirm against the native API.
unsafe impl Send for OpAttr {}
unsafe impl Sync for OpAttr {}
/// Owned handle to a native custom-op domain (a named group of custom operators).
///
/// The handle is released with `release_custom_op_domain` when the value is dropped.
pub struct CustomOpDomain {
    // Non-null native handle (checked with `ensure_non_null` at construction).
    ptr: *mut sys::CustomOpDomainHandle,
}
impl CustomOpDomain {
pub fn new(domain: &str) -> Result<Self> {
let domain = cstring(domain)?;
let mut out: *mut sys::CustomOpDomainHandle = ptr::null_mut();
check(unsafe { api().create_custom_op_domain()(domain.as_ptr(), &mut out) })?;
let out = crate::ensure_non_null(out, "custom-op domain")?;
Ok(Self { ptr: out })
}
pub unsafe fn add_raw(&self, op: *const sys::CustomOpHandle) -> Result<()> {
check(api().custom_op_domain__add()(self.ptr, op))
}
pub fn add_op(&self, vtable: &'static sys::OrtCustomOp) -> Result<()> {
unsafe { self.add_raw(vtable as *const sys::OrtCustomOp as *const sys::CustomOpHandle) }
}
}
impl Drop for CustomOpDomain {
    // Returns the native domain handle via `release_custom_op_domain`. Any session
    // options still holding this raw handle must not use it afterwards.
    fn drop(&mut self) {
        unsafe { api().release_custom_op_domain()(self.ptr) }
    }
}
// NOTE(review): these impls assume the native domain handle may be moved and shared
// across threads; nothing in this file synchronises access — confirm.
unsafe impl Send for CustomOpDomain {}
unsafe impl Sync for CustomOpDomain {}
impl SessionOptions {
    /// Records `domain`'s native handle so that options built from `self` include it.
    ///
    /// Only the raw handle is copied; `domain` is not borrowed past this call, and
    /// `CustomOpDomain`'s `Drop` releases that handle. The caller must therefore keep
    /// `domain` alive until everything built from these options no longer needs it.
    /// NOTE(review): nothing enforces that lifetime here.
    #[cfg(feature = "custom-ops")]
    pub fn with_custom_op_domain(mut self, domain: &CustomOpDomain) -> Self {
        let raw_domain = domain.ptr;
        self.custom_op_domains.push(raw_domain);
        self
    }
}
/// Static description of one custom-op input or output slot, as reported to the
/// native side by the `__priv` vtable callbacks.
#[derive(Clone, Copy, Debug)]
pub struct OpIoSpec {
    /// Tensor element type reported for this slot.
    pub element_type: sys::ElementType,
    /// Whether the slot is required or optional.
    pub characteristic: sys::CustomOpInputOutputCharacteristic,
    /// Memory type for this slot. The vtable callbacks in this file only report it
    /// for inputs (`get_input_memory_type`).
    pub memory_type: sys::MemType,
}
impl OpIoSpec {
    /// Required slot of element type `ty` in default memory.
    pub const fn required(ty: sys::ElementType) -> Self {
        Self::required_on(ty, sys::MemType::Default)
    }

    /// Optional slot of element type `ty` in default memory.
    pub const fn optional(ty: sys::ElementType) -> Self {
        Self::optional_on(ty, sys::MemType::Default)
    }

    /// Required slot of element type `ty` with memory type `mem`.
    pub const fn required_on(ty: sys::ElementType, mem: sys::MemType) -> Self {
        Self {
            element_type: ty,
            characteristic: sys::CustomOpInputOutputCharacteristic::Required,
            memory_type: mem,
        }
    }

    /// Optional slot of element type `ty` with memory type `mem`.
    ///
    /// This is the optional counterpart of [`OpIoSpec::required_on`], for optional
    /// inputs that live outside default memory.
    pub const fn optional_on(ty: sys::ElementType, mem: sys::MemType) -> Self {
        Self {
            element_type: ty,
            characteristic: sys::CustomOpInputOutputCharacteristic::Optional,
            memory_type: mem,
        }
    }
}
/// A Rust implementation of a native custom operator.
///
/// The associated items describe the op's schema and are read by the `__priv`
/// vtable callbacks that `custom_op!` wires into a `sys::OrtCustomOp`. One `Self`
/// value is built per kernel by [`CustomOp::create`] (boxed by the
/// `create_kernel` callback) and then driven through [`CustomOp::compute`].
pub trait CustomOp: Sized + Send + 'static {
    /// Operator name.
    /// NOTE(review): the vtable's `get_name` returns the literal passed to
    /// `custom_op!`, not this constant — keep the two identical.
    const NAME: &'static str;
    /// Domain the op belongs to; empty by default.
    const DOMAIN: &'static str = "";
    /// Start version reported through the vtable's `get_start_version` callback.
    const SINCE_VERSION: i32 = 1;
    /// End version reported through the vtable's `get_end_version` callback.
    /// NOTE(review): the default is `sys::API_VERSION`. If the native side treats
    /// this as an opset version (as ONNX Runtime does), capping it at the API
    /// version may be unintended — confirm.
    const END_VERSION: i32 = sys::API_VERSION as i32;
    /// Builds the per-kernel state from the node's kernel info.
    fn create(info: &KernelInfo<'_>) -> Result<Self>;
    /// Runs the kernel once against `ctx`.
    fn compute(&mut self, ctx: &KernelContext<'_>) -> Result<()>;
    /// Shape-inference hook; the default does nothing and succeeds.
    fn infer_shapes(_ctx: &ShapeInferContext<'_>) -> Result<()> {
        Ok(())
    }
    /// Declared input slots, in order.
    fn inputs() -> &'static [OpIoSpec];
    /// Declared output slots, in order.
    fn outputs() -> &'static [OpIoSpec];
    /// Execution provider the op is intended for, or `None` (the default).
    fn execution_provider_type() -> Option<&'static str> {
        None
    }
    /// Minimum arity of a variadic input; returned verbatim by the vtable. Default 1.
    fn variadic_input_min_arity() -> std::os::raw::c_int {
        1
    }
    /// Whether variadic inputs must share one type; reported to the vtable as 0/1.
    fn variadic_input_homogeneity() -> bool {
        false
    }
    /// Minimum arity of a variadic output; returned verbatim by the vtable. Default 1.
    fn variadic_output_min_arity() -> std::os::raw::c_int {
        1
    }
    /// Whether variadic outputs must share one type; reported to the vtable as 0/1.
    fn variadic_output_homogeneity() -> bool {
        false
    }
}
/// Borrowed view of a native kernel-info handle, valid for lifetime `'a`.
///
/// It is not released on drop; `KernelInfo::to_owned` produces an owned copy.
pub struct KernelInfo<'a> {
    ptr: *const sys::KernelInfoHandle,
    // Ties the view to the lifetime of whatever owns the native handle.
    _life: PhantomData<&'a ()>,
}
impl<'a> KernelInfo<'a> {
    /// Wraps a native kernel-info handle without taking ownership of it.
    ///
    /// # Safety
    /// `ptr` must be a valid kernel-info handle that outlives `'a`.
    pub unsafe fn from_ptr(ptr: *const sys::KernelInfoHandle) -> Self {
        Self {
            ptr,
            _life: PhantomData,
        }
    }

    /// Number of inputs of the node this kernel was created for.
    pub fn input_count(&self) -> Result<usize> {
        let mut out = 0usize;
        check(unsafe { api().kernel_info__get_input_count()(self.ptr, &mut out) })?;
        Ok(out)
    }

    /// Number of outputs of the node this kernel was created for.
    pub fn output_count(&self) -> Result<usize> {
        let mut out = 0usize;
        check(unsafe { api().kernel_info__get_output_count()(self.ptr, &mut out) })?;
        Ok(out)
    }

    /// Name of input `index`.
    pub fn input_name(&self, index: usize) -> Result<String> {
        unsafe {
            fetch_sized_string(|out, size| {
                api().kernel_info__get_input_name()(self.ptr, index, out, size)
            })
        }
    }

    /// Name of output `index`.
    pub fn output_name(&self, index: usize) -> Result<String> {
        unsafe {
            fetch_sized_string(|out, size| {
                api().kernel_info__get_output_name()(self.ptr, index, out, size)
            })
        }
    }

    /// Type info for input `index`.
    ///
    /// The non-null handle is returned as-is and is not released by this wrapper.
    /// NOTE(review): presumably the caller owns it and must release it — confirm.
    pub fn input_type_info(&self, index: usize) -> Result<*mut sys::TypeInfoHandle> {
        let mut out: *mut sys::TypeInfoHandle = ptr::null_mut();
        check(unsafe { api().kernel_info__get_input_type_info()(self.ptr, index, &mut out) })?;
        crate::ensure_non_null(out, "kernel input type info")
    }

    /// Type info for output `index`.
    ///
    /// The non-null handle is returned as-is and is not released by this wrapper.
    /// NOTE(review): presumably the caller owns it and must release it — confirm.
    pub fn output_type_info(&self, index: usize) -> Result<*mut sys::TypeInfoHandle> {
        let mut out: *mut sys::TypeInfoHandle = ptr::null_mut();
        check(unsafe { api().kernel_info__get_output_type_info()(self.ptr, index, &mut out) })?;
        crate::ensure_non_null(out, "kernel output type info")
    }

    /// Name of the graph node this kernel was created for.
    pub fn node_name(&self) -> Result<String> {
        unsafe {
            fetch_sized_string(|out, size| api().kernel_info__get_node_name()(self.ptr, out, size))
        }
    }

    /// Domain of the operator this kernel implements.
    pub fn operator_domain(&self) -> Result<String> {
        unsafe {
            fetch_sized_string(|out, size| {
                api().kernel_info__get_operator_domain()(self.ptr, out, size)
            })
        }
    }

    /// Operator type (op name) this kernel implements.
    pub fn operator_type(&self) -> Result<String> {
        unsafe {
            fetch_sized_string(|out, size| {
                api().kernel_info__get_operator_type()(self.ptr, out, size)
            })
        }
    }

    /// Since-version the operator was resolved with, as reported by the native side.
    pub fn operator_since_version(&self) -> Result<i32> {
        let mut out: c_int = 0;
        check(unsafe { api().kernel_info__get_operator_since_version()(self.ptr, &mut out) })?;
        Ok(out)
    }

    /// Reads float attribute `name` from the node.
    pub fn attr_float(&self, name: &str) -> Result<f32> {
        let name = cstring(name)?;
        let mut out: f32 = 0.0;
        check(unsafe {
            api().kernel_info_get_attribute_float()(self.ptr, name.as_ptr(), &mut out)
        })?;
        Ok(out)
    }

    /// Reads 64-bit integer attribute `name` from the node.
    pub fn attr_int64(&self, name: &str) -> Result<i64> {
        let name = cstring(name)?;
        let mut out: i64 = 0;
        check(unsafe {
            api().kernel_info_get_attribute_int64()(self.ptr, name.as_ptr(), &mut out)
        })?;
        Ok(out)
    }

    /// Reads string attribute `name` from the node.
    pub fn attr_string(&self, name: &str) -> Result<String> {
        let name = cstring(name)?;
        unsafe {
            fetch_sized_string(|out, size| {
                api().kernel_info_get_attribute_string()(self.ptr, name.as_ptr(), out, size)
            })
        }
    }

    /// Reads float-array attribute `name` from the node.
    pub fn attr_floats(&self, name: &str) -> Result<Vec<f32>> {
        let name = cstring(name)?;
        unsafe {
            fetch_sized_array(|out, size| {
                api().kernel_info_get_attribute_array_float()(self.ptr, name.as_ptr(), out, size)
            })
        }
    }

    /// Reads 64-bit integer-array attribute `name` from the node.
    pub fn attr_int64s(&self, name: &str) -> Result<Vec<i64>> {
        let name = cstring(name)?;
        unsafe {
            fetch_sized_array(|out, size| {
                api().kernel_info_get_attribute_array_int64()(self.ptr, name.as_ptr(), out, size)
            })
        }
    }

    /// Reads tensor attribute `name`, allocating the value with `allocator`.
    ///
    /// The handle is checked for null, like the other handle getters here, but is
    /// not released by this wrapper.
    /// NOTE(review): presumably the caller owns it — confirm.
    pub fn attr_tensor(&self, name: &str, allocator: &Allocator) -> Result<*mut sys::ValueHandle> {
        let name = cstring(name)?;
        let mut out: *mut sys::ValueHandle = ptr::null_mut();
        check(unsafe {
            api().kernel_info_get_attribute_tensor()(
                self.ptr,
                name.as_ptr(),
                allocator.alloc,
                &mut out,
            )
        })?;
        crate::ensure_non_null(out, "kernel attribute tensor")
    }

    /// Queries whether input `index` is constant.
    ///
    /// Returns the flag and the value handle written by the native call. The handle
    /// is not null-checked.
    /// NOTE(review): it is presumably null when the input is not constant — confirm.
    pub fn constant_input_tensor(&self, index: usize) -> Result<(bool, *const sys::ValueHandle)> {
        let mut is_const: c_int = 0;
        // Written by the native call: must be a `mut` binding passed by `&mut`.
        // Writing through a shared borrow of an immutable local is undefined behaviour.
        let mut out: *const sys::ValueHandle = ptr::null();
        check(unsafe {
            api().kernel_info_get_constant_input_tensor()(self.ptr, index, &mut is_const, &mut out)
        })?;
        Ok((is_const != 0, out))
    }

    /// Logger attached to this kernel info. The pointer is not null-checked.
    pub fn logger(&self) -> Result<*const sys::LoggerHandle> {
        // Written by the native call: must be a `mut` binding passed by `&mut`.
        // Writing through a shared borrow of an immutable local is undefined behaviour.
        let mut out: *const sys::LoggerHandle = ptr::null();
        check(unsafe { api().kernel_info__get_logger()(self.ptr, &mut out) })?;
        Ok(out)
    }

    /// Allocator for memory type `mem_type`.
    pub fn allocator(&self, mem_type: sys::MemType) -> Result<*mut sys::AllocatorHandle> {
        let mut out: *mut sys::AllocatorHandle = ptr::null_mut();
        check(unsafe { api().kernel_info_get_allocator()(self.ptr, mem_type, &mut out) })?;
        crate::ensure_non_null(out, "kernel allocator")
    }

    /// Copies this kernel info into an independently owned handle.
    pub fn to_owned(&self) -> Result<OwnedKernelInfo> {
        let mut out: *mut sys::KernelInfoHandle = ptr::null_mut();
        check(unsafe { api().copy_kernel_info()(self.ptr, &mut out) })?;
        let out = crate::ensure_non_null(out, "kernel info")?;
        Ok(OwnedKernelInfo { ptr: out })
    }
}
/// Owned copy of a kernel-info handle, produced by `KernelInfo::to_owned`.
///
/// The handle is released with `release_kernel_info` when the value is dropped.
pub struct OwnedKernelInfo {
    // Non-null native handle (checked with `ensure_non_null` in `to_owned`).
    ptr: *mut sys::KernelInfoHandle,
}
impl OwnedKernelInfo {
pub fn as_ptr(&self) -> *const sys::KernelInfoHandle {
self.ptr
}
pub fn as_ref(&self) -> KernelInfo<'_> {
KernelInfo {
ptr: self.ptr,
_life: PhantomData,
}
}
}
impl Drop for OwnedKernelInfo {
    // Returns the copied kernel-info handle via `release_kernel_info`.
    fn drop(&mut self) {
        unsafe { api().release_kernel_info()(self.ptr) }
    }
}
// NOTE(review): these impls assume the copied kernel-info handle may be moved and
// shared across threads; nothing in this file synchronises access — confirm.
unsafe impl Send for OwnedKernelInfo {}
unsafe impl Sync for OwnedKernelInfo {}
/// Borrowed view of the native context handed to a kernel's compute callback,
/// valid for lifetime `'a`. It is not released on drop.
pub struct KernelContext<'a> {
    ptr: *const sys::KernelContextHandle,
    // Ties the view (and tensor views derived from it) to the native context's lifetime.
    _life: PhantomData<&'a ()>,
}
impl<'a> KernelContext<'a> {
    /// Wraps a native kernel-context handle without taking ownership of it.
    ///
    /// # Safety
    /// `ptr` must be a valid kernel-context handle that outlives `'a`.
    pub unsafe fn from_ptr(ptr: *const sys::KernelContextHandle) -> Self {
        Self {
            ptr,
            _life: PhantomData,
        }
    }

    /// Number of inputs available in this invocation.
    pub fn input_count(&self) -> Result<usize> {
        let mut out = 0usize;
        check(unsafe { api().kernel_context__get_input_count()(self.ptr, &mut out) })?;
        Ok(out)
    }

    /// Number of outputs expected from this invocation.
    pub fn output_count(&self) -> Result<usize> {
        let mut out = 0usize;
        check(unsafe { api().kernel_context__get_output_count()(self.ptr, &mut out) })?;
        Ok(out)
    }

    /// Returns input `index` as a borrowed tensor view.
    ///
    /// Returns `None` when the native call yields a null value handle.
    /// NOTE(review): presumably that means an omitted optional input — confirm.
    pub fn input(&self, index: usize) -> Result<Option<TensorView<'a>>> {
        // The native call writes the handle into `out`, so it must be a `mut`
        // binding passed by `&mut`. Writing through a pointer derived from a shared
        // borrow of an immutable local is undefined behaviour.
        let mut out: *const sys::ValueHandle = ptr::null();
        check(unsafe { api().kernel_context__get_input()(self.ptr, index, &mut out) })?;
        Ok(if out.is_null() {
            None
        } else {
            Some(TensorView {
                value: out as *mut sys::ValueHandle,
                _life: PhantomData,
            })
        })
    }

    /// Obtains output `index` with shape `dims` and passes its data to `f` as `&mut [T]`.
    ///
    /// The slice length is the element count computed from `dims`.
    /// NOTE(review): `T` is not checked against the output tensor's element type
    /// here, so callers must use the type declared for this output. A wider `T`
    /// would make the slice span more bytes than the tensor holds.
    pub fn output_mut<T: crate::element::TensorElement>(
        &self, index: usize, dims: &[i64], f: impl FnOnce(&mut [T]) -> Result<()>,
    ) -> Result<()> {
        let mut out: *mut sys::ValueHandle = ptr::null_mut();
        check(unsafe {
            api().kernel_context__get_output()(
                self.ptr as *mut sys::KernelContextHandle,
                index,
                dims.as_ptr(),
                dims.len(),
                &mut out,
            )
        })?;
        let out = crate::ensure_non_null(out, "custom-op output value")?;
        let mut data: *mut std::ffi::c_void = ptr::null_mut();
        check(unsafe { api().get_tensor_mutable_data()(out, &mut data) })?;
        let count = crate::type_info::checked_element_count(dims)?;
        let data = crate::slice_data_ptr(data as *mut T, count, "custom-op output data")?;
        let slice = unsafe { std::slice::from_raw_parts_mut(data, count) };
        f(slice)
    }

    /// Logger attached to this invocation. The pointer is not null-checked.
    pub fn logger(&self) -> Result<*const sys::LoggerHandle> {
        // Written by the native call: must be a `mut` binding passed by `&mut`.
        // Writing through a shared borrow of an immutable local is undefined behaviour.
        let mut out: *const sys::LoggerHandle = ptr::null();
        check(unsafe { api().kernel_context__get_logger()(self.ptr, &mut out) })?;
        Ok(out)
    }

    /// Allocator for the memory described by `mem`.
    pub fn allocator(&self, mem: &MemoryInfo) -> Result<*mut sys::AllocatorHandle> {
        let mut out: *mut sys::AllocatorHandle = ptr::null_mut();
        check(unsafe { api().kernel_context__get_allocator()(self.ptr, mem.info, &mut out) })?;
        crate::ensure_non_null(out, "kernel context allocator")
    }

    /// Requests a scratch buffer of `count_or_bytes` for memory `mem`.
    ///
    /// The returned pointer is not null-checked.
    pub fn scratch_buffer(&self, mem: &MemoryInfo, count_or_bytes: usize) -> Result<*mut c_void> {
        let mut out: *mut c_void = ptr::null_mut();
        check(unsafe {
            api().kernel_context__get_scratch_buffer()(self.ptr, mem.info, count_or_bytes, &mut out)
        })?;
        Ok(out)
    }

    /// Native GPU compute-stream handle for this invocation. It may be null.
    pub fn gpu_compute_stream(&self) -> Result<*mut c_void> {
        let mut out: *mut c_void = ptr::null_mut();
        check(unsafe { api().kernel_context__get_gpu_compute_stream()(self.ptr, &mut out) })?;
        Ok(out)
    }

    /// Fetches resource `resource_id` at `resource_version`. It is not null-checked.
    pub fn resource(&self, resource_version: c_int, resource_id: c_int) -> Result<*mut c_void> {
        let mut out: *mut c_void = ptr::null_mut();
        check(unsafe {
            api().kernel_context__get_resource()(self.ptr, resource_version, resource_id, &mut out)
        })?;
        Ok(out)
    }

    /// Runs `f(i)` for every `i` in `0..total` via the native thread pool, in
    /// `num_batch` batches.
    ///
    /// `f` is passed to the native side as a type-erased pointer to this stack frame.
    /// That relies on the native call finishing every iteration before it returns,
    /// because `f` is dropped when this function returns.
    pub fn parallel_for<F>(&self, total: usize, num_batch: usize, f: F) -> Result<()>
    where
        F: Fn(usize) + Send + Sync,
    {
        let data = &f as *const F as *mut c_void;
        check(unsafe {
            api().kernel_context__parallel_for()(
                self.ptr,
                Some(parallel_for_trampoline::<F>),
                total,
                num_batch,
                data,
            )
        })
    }
}
/// C-ABI adapter the native `parallel_for` calls once per work index.
///
/// `data` is the type-erased pointer to the closure `F` supplied by
/// `KernelContext::parallel_for`; a null `data` is ignored. If the closure panics,
/// the panic is caught and the process aborts, because unwinding out of an
/// `extern "C"` function into native code is not permitted.
///
/// # Safety
/// `data` must be null or point to a live `F` for the duration of the call.
unsafe extern "C" fn parallel_for_trampoline<F>(data: *mut c_void, index: usize)
where
    F: Fn(usize) + Send + Sync,
{
    if data.is_null() {
        return;
    }
    let f = &*(data as *const F);
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(index)));
    if outcome.is_err() {
        std::process::abort();
    }
}
/// Borrowed view of the native context handed to the shape-inference callback,
/// valid for lifetime `'a`. It is not released on drop.
pub struct ShapeInferContext<'a> {
    ptr: *const sys::ShapeInferContextHandle,
    // Ties the view to the native context's lifetime.
    _life: PhantomData<&'a ()>,
}
impl<'a> ShapeInferContext<'a> {
pub unsafe fn from_ptr(ptr: *const sys::ShapeInferContextHandle) -> Self {
Self {
ptr,
_life: PhantomData,
}
}
pub fn input_count(&self) -> Result<usize> {
let mut out = 0usize;
check(unsafe { api().shape_infer_context__get_input_count()(self.ptr, &mut out) })?;
Ok(out)
}
pub fn input_type_shape(
&self, index: usize,
) -> Result<crate::type_info::TensorTypeAndShapeInfo> {
let mut info: *mut sys::TensorTypeAndShapeInfoHandle = ptr::null_mut();
check(unsafe {
api().shape_infer_context__get_input_type_shape()(self.ptr, index, &mut info)
})?;
let info = crate::ensure_non_null(info, "tensor type and shape info")?;
Ok(unsafe { crate::type_info::TensorTypeAndShapeInfo::from_owning(info) })
}
pub fn set_output_type_shape(
&self, index: usize, info: crate::type_info::TensorTypeAndShapeInfo,
) -> Result<()> {
let info_ptr = info.as_ptr();
let res = check(unsafe {
api().shape_infer_context__set_output_type_shape()(self.ptr, index, info_ptr)
});
if res.is_ok() {
std::mem::forget(info);
}
res
}
}
/// Owned native operator instance created by `Op::create`.
///
/// The handle is released with `release_op` when the value is dropped.
pub struct Op {
    // Non-null native handle (checked with `ensure_non_null` in `create`).
    ptr: *mut sys::OpHandle,
}
impl Op {
    /// Creates a native operator instance through the native `create_op` call.
    ///
    /// `type_constraints` pairs each type-constraint name with its element type.
    /// `attrs` are passed as borrowed attribute handles and stay owned by the caller.
    ///
    /// # Errors
    /// Fails if any of the following holds:
    /// - a name contains a NUL byte;
    /// - a count does not fit in a C `int`;
    /// - the native call fails or returns a null handle.
    // The parameter list mirrors the native `create_op` call one-to-one.
    #[allow(clippy::too_many_arguments)]
    pub fn create(
        info: &KernelInfo<'_>, op_name: &str, domain: &str, version: i32,
        type_constraints: &[(&str, sys::ElementType)], attrs: &[&OpAttr], input_count: usize,
        output_count: usize,
    ) -> Result<Self> {
        let c_op_name = cstring(op_name)?;
        let c_domain = cstring(domain)?;
        let n_constraints =
            usize_to_c_int(type_constraints.len(), "custom-op type constraint count")?;
        let n_attrs = usize_to_c_int(attrs.len(), "custom-op attribute count")?;
        let n_inputs = usize_to_c_int(input_count, "custom-op input count")?;
        let n_outputs = usize_to_c_int(output_count, "custom-op output count")?;

        // The C strings must outlive the native call, so they live in their own
        // vector and only their pointers are handed over.
        let mut constraint_names: Vec<CString> = Vec::with_capacity(type_constraints.len());
        let mut constraint_types: Vec<sys::ElementType> =
            Vec::with_capacity(type_constraints.len());
        for (label, elem_ty) in type_constraints {
            constraint_names.push(cstring(label)?);
            constraint_types.push(*elem_ty);
        }
        let constraint_name_ptrs: Vec<*const c_char> =
            constraint_names.iter().map(|c| c.as_ptr()).collect();
        let attr_handles: Vec<*const sys::OpAttrHandle> = attrs
            .iter()
            .map(|attr| attr.ptr as *const sys::OpAttrHandle)
            .collect();

        let mut handle: *mut sys::OpHandle = ptr::null_mut();
        check(unsafe {
            api().create_op()(
                info.ptr,
                c_op_name.as_ptr(),
                c_domain.as_ptr(),
                version,
                constraint_name_ptrs.as_ptr(),
                constraint_types.as_ptr(),
                n_constraints,
                attr_handles.as_ptr(),
                n_attrs,
                n_inputs,
                n_outputs,
                &mut handle,
            )
        })?;
        let handle = crate::ensure_non_null(handle, "custom-op native op")?;
        Ok(Self { ptr: handle })
    }

    /// Invokes this operator within `ctx`, reading `inputs` and writing `outputs`.
    ///
    /// # Errors
    /// Fails if either count does not fit in a C `int` or the native call fails.
    pub fn invoke(
        &self, ctx: &KernelContext<'_>, inputs: &[&TensorView<'_>],
        outputs: &mut [&mut TensorView<'_>],
    ) -> Result<()> {
        let mut input_handles: Vec<*const sys::ValueHandle> = Vec::with_capacity(inputs.len());
        for view in inputs {
            input_handles.push(view.value as *const sys::ValueHandle);
        }
        let mut output_handles: Vec<*mut sys::ValueHandle> = Vec::with_capacity(outputs.len());
        for view in outputs.iter_mut() {
            output_handles.push(view.value);
        }
        let n_inputs = usize_to_c_int(input_handles.len(), "custom-op invoke input count")?;
        let n_outputs = usize_to_c_int(output_handles.len(), "custom-op invoke output count")?;
        check(unsafe {
            api().invoke_op()(
                ctx.ptr,
                self.ptr,
                input_handles.as_ptr(),
                n_inputs,
                output_handles.as_mut_ptr(),
                n_outputs,
            )
        })
    }
}
impl Drop for Op {
    // Returns the native operator handle via `release_op`.
    fn drop(&mut self) {
        unsafe { api().release_op()(self.ptr) }
    }
}
// NOTE(review): these impls assume the native operator handle may be moved and
// shared across threads; nothing in this file synchronises access — confirm.
unsafe impl Send for Op {}
unsafe impl Sync for Op {}
#[doc(hidden)]
pub mod __priv {
    //! C-ABI adapters used by the `custom_op!` macro to expose a `CustomOp`
    //! implementation through a `sys::OrtCustomOp` vtable. Not a stable API.

    use super::{CustomOp, KernelContext, KernelInfo, ShapeInferContext};
    use crate::{api, sys, Error};
    use std::ffi::{c_char, c_void, CStr, CString};
    use std::os::raw::c_int;
    use std::panic::AssertUnwindSafe;
    use std::ptr;
    use std::sync::Mutex;

    /// Converts a Rust error into a native `Fail` status carrying its message.
    ///
    /// A message containing an interior NUL is replaced by a fixed fallback text.
    pub fn error_to_status(e: Error) -> sys::StatusPtr {
        match CString::new(e.to_string()) {
            Ok(msg) => unsafe {
                api().create_status()(sys::OrtErrorCode::Fail as c_int, msg.as_ptr())
            },
            Err(_) => {
                static FALLBACK: &[u8] = b"st-zrt custom-op error (NUL in message)\0";
                unsafe {
                    api().create_status()(
                        sys::OrtErrorCode::Fail as c_int,
                        FALLBACK.as_ptr() as *const c_char,
                    )
                }
            },
        }
    }

    /// `create_kernel_v2` callback.
    ///
    /// Builds `T` via `T::create` and stores it, boxed, in `*kernel_out`. Errors and
    /// panics are converted into native statuses.
    pub unsafe extern "C" fn create_kernel<T: CustomOp>(
        _op: *const c_void, _api: *const c_void, info: *const c_void, kernel_out: *mut *mut c_void,
    ) -> sys::StatusPtr {
        let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
            let info = unsafe { KernelInfo::from_ptr(info as *const sys::KernelInfoHandle) };
            T::create(&info)
        }));
        match res {
            Ok(Ok(kernel)) => {
                unsafe { *kernel_out = Box::into_raw(Box::new(kernel)) as *mut c_void };
                ptr::null_mut()
            },
            Ok(Err(e)) => error_to_status(e),
            Err(_) => error_to_status(Error::new(
                sys::OrtErrorCode::Fail as i32,
                "st-zrt custom-op create panicked",
            )),
        }
    }

    /// `kernel_compute_v2` callback: runs `T::compute` on the boxed kernel.
    ///
    /// A null `kernel` is treated as success without doing anything. Errors and
    /// panics are converted into native statuses.
    pub unsafe extern "C" fn compute<T: CustomOp>(
        kernel: *mut c_void, ctx: *mut c_void,
    ) -> sys::StatusPtr {
        let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
            if kernel.is_null() {
                return Ok(());
            }
            let k = unsafe { &mut *(kernel as *mut T) };
            let ctx = unsafe { KernelContext::from_ptr(ctx as *const sys::KernelContextHandle) };
            k.compute(&ctx)
        }));
        match res {
            Ok(Ok(())) => ptr::null_mut(),
            Ok(Err(e)) => error_to_status(e),
            Err(_) => error_to_status(Error::new(
                sys::OrtErrorCode::Fail as i32,
                "st-zrt custom-op compute panicked",
            )),
        }
    }

    /// `kernel_destroy` callback: drops the boxed `T`.
    ///
    /// Aborts the process if `T`'s `Drop` panics, since unwinding must not cross
    /// the C ABI.
    pub unsafe extern "C" fn destroy<T: CustomOp>(kernel: *mut c_void) {
        if kernel.is_null() {
            return;
        }
        let res = std::panic::catch_unwind(AssertUnwindSafe(|| unsafe {
            drop(Box::from_raw(kernel as *mut T));
        }));
        if res.is_err() {
            std::process::abort();
        }
    }

    /// `infer_output_shape_fn` callback: runs `T::infer_shapes`.
    ///
    /// Errors and panics are converted into native statuses.
    pub unsafe extern "C" fn infer_output_shape<T: CustomOp>(
        _op: *const c_void, ctx: *mut c_void,
    ) -> sys::StatusPtr {
        let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
            let ctx =
                unsafe { ShapeInferContext::from_ptr(ctx as *const sys::ShapeInferContextHandle) };
            T::infer_shapes(&ctx)
        }));
        match res {
            Ok(Ok(())) => ptr::null_mut(),
            Ok(Err(e)) => error_to_status(e),
            Err(_) => error_to_status(Error::new(
                sys::OrtErrorCode::Fail as i32,
                "st-zrt custom-op infer_shapes panicked",
            )),
        }
    }

    /// Number of declared inputs (`T::inputs().len()`).
    pub unsafe extern "C" fn get_input_type_count<T: CustomOp>(_op: *const c_void) -> usize {
        <T as CustomOp>::inputs().len()
    }

    /// Element type of input `index`, or `Undefined` when `index` is out of range.
    pub unsafe extern "C" fn get_input_type<T: CustomOp>(_op: *const c_void, index: usize) -> i32 {
        <T as CustomOp>::inputs()
            .get(index)
            .map_or(sys::ElementType::Undefined as i32, |s| {
                s.element_type as i32
            })
    }

    /// Number of declared outputs (`T::outputs().len()`).
    pub unsafe extern "C" fn get_output_type_count<T: CustomOp>(_op: *const c_void) -> usize {
        <T as CustomOp>::outputs().len()
    }

    /// Element type of output `index`, or `Undefined` when `index` is out of range.
    pub unsafe extern "C" fn get_output_type<T: CustomOp>(_op: *const c_void, index: usize) -> i32 {
        <T as CustomOp>::outputs()
            .get(index)
            .map_or(sys::ElementType::Undefined as i32, |s| {
                s.element_type as i32
            })
    }

    /// Characteristic of input `index`; `Required` when `index` is out of range.
    pub unsafe extern "C" fn get_input_characteristic<T: CustomOp>(
        _op: *const c_void, index: usize,
    ) -> i32 {
        <T as CustomOp>::inputs().get(index).map_or(
            sys::CustomOpInputOutputCharacteristic::Required as i32,
            |s| s.characteristic as i32,
        )
    }

    /// Characteristic of output `index`; `Required` when `index` is out of range.
    pub unsafe extern "C" fn get_output_characteristic<T: CustomOp>(
        _op: *const c_void, index: usize,
    ) -> i32 {
        <T as CustomOp>::outputs().get(index).map_or(
            sys::CustomOpInputOutputCharacteristic::Required as i32,
            |s| s.characteristic as i32,
        )
    }

    /// Memory type of input `index`; `Default` when `index` is out of range.
    pub unsafe extern "C" fn get_input_memory_type<T: CustomOp>(
        _op: *const c_void, index: usize,
    ) -> i32 {
        <T as CustomOp>::inputs()
            .get(index)
            .map_or(sys::MemType::Default as i32, |s| s.memory_type as i32)
    }

    /// Execution-provider name from `T::execution_provider_type()`, as a C string.
    ///
    /// Returns null when the hook returns `None`, or when the name contains an
    /// interior NUL and so cannot be represented as a C string.
    pub unsafe extern "C" fn get_execution_provider_type<T: CustomOp>(
        _op: *const c_void,
    ) -> *const c_char {
        match <T as CustomOp>::execution_provider_type() {
            Some(name) => provider_name_cstr(name).map_or(ptr::null(), CStr::as_ptr),
            None => ptr::null(),
        }
    }

    /// Interns `name` as a NUL-terminated C string that lives for the rest of the process.
    ///
    /// A `&'static str` is not NUL-terminated, so each distinct name is copied once
    /// into a leaked `CStr` and reused on later calls. Returns `None` if `name`
    /// contains an interior NUL.
    fn provider_name_cstr(name: &'static str) -> Option<&'static CStr> {
        static CACHE: Mutex<Vec<(&'static str, &'static CStr)>> = Mutex::new(Vec::new());
        // A poisoned lock only means another caller panicked mid-update. The cached
        // entries are still valid, so keep using them: panicking here would abort
        // the process inside an `extern "C"` callback.
        let mut cache = CACHE.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Some(&(_, cached)) = cache.iter().find(|(key, _)| *key == name) {
            return Some(cached);
        }
        let leaked: &'static CStr = Box::leak(CString::new(name).ok()?.into_boxed_c_str());
        cache.push((name, leaked));
        Some(leaked)
    }

    /// Value of `T::variadic_input_min_arity()`.
    pub unsafe extern "C" fn get_variadic_input_min_arity<T: CustomOp>(
        _op: *const c_void,
    ) -> c_int {
        <T as CustomOp>::variadic_input_min_arity()
    }

    /// `T::variadic_input_homogeneity()` as 0/1.
    pub unsafe extern "C" fn get_variadic_input_homogeneity<T: CustomOp>(
        _op: *const c_void,
    ) -> c_int {
        <T as CustomOp>::variadic_input_homogeneity() as c_int
    }

    /// Value of `T::variadic_output_min_arity()`.
    pub unsafe extern "C" fn get_variadic_output_min_arity<T: CustomOp>(
        _op: *const c_void,
    ) -> c_int {
        <T as CustomOp>::variadic_output_min_arity()
    }

    /// `T::variadic_output_homogeneity()` as 0/1.
    pub unsafe extern "C" fn get_variadic_output_homogeneity<T: CustomOp>(
        _op: *const c_void,
    ) -> c_int {
        <T as CustomOp>::variadic_output_homogeneity() as c_int
    }

    /// Value of `T::SINCE_VERSION`.
    pub unsafe extern "C" fn get_start_version<T: CustomOp>(_op: *const c_void) -> c_int {
        <T as CustomOp>::SINCE_VERSION
    }

    /// Value of `T::END_VERSION`.
    pub unsafe extern "C" fn get_end_version<T: CustomOp>(_op: *const c_void) -> c_int {
        <T as CustomOp>::END_VERSION
    }
}
/// Defines the native vtable (`$crate::sys::OrtCustomOp`) for a type implementing
/// `CustomOp`.
///
/// `custom_op!(MyOp, "MyOp", as MY_OP_VTABLE)` does two things:
/// - adds a hidden `extern "C"` `__zrt_custom_op_get_name` function to `MyOp` that
///   returns `"MyOp"` with a NUL terminator appended;
/// - declares `pub static MY_OP_VTABLE`, whose callbacks are the `$crate::__priv`
///   adapters instantiated for `MyOp`.
///
/// Only the V2 kernel entry points (`create_kernel_v2`, `kernel_compute_v2`) are set.
/// The legacy `create_kernel`/`kernel_compute` slots and the in-place/alias hooks are
/// left `None`. The resulting static can be passed to `CustomOpDomain::add_op`.
///
/// NOTE(review): the `$name` literal is used verbatim for `get_name` and is not
/// compared with `<$T as CustomOp>::NAME`; keep the two in sync by hand.
#[macro_export]
macro_rules! custom_op {
    ($T:ty, $name:literal, as $vtable:ident $(,)?) => {
        impl $T {
            #[doc(hidden)]
            unsafe extern "C" fn __zrt_custom_op_get_name(
                _op: *const ::std::ffi::c_void,
            ) -> *const ::std::os::raw::c_char {
                // `concat!` yields a `'static` string, so the pointer never dangles.
                concat!($name, "\0").as_ptr() as *const ::std::os::raw::c_char
            }
        }
        #[doc = concat!(" `OrtCustomOp` vtable for `", stringify!($T), "`.")]
        pub static $vtable: $crate::sys::OrtCustomOp = $crate::sys::OrtCustomOp {
            version: $crate::sys::API_VERSION,
            // Legacy entry point unused; `create_kernel_v2` below is set instead.
            create_kernel: None,
            get_name: Some(<$T>::__zrt_custom_op_get_name),
            get_execution_provider_type: Some($crate::__priv::get_execution_provider_type::<$T>),
            get_input_type: Some($crate::__priv::get_input_type::<$T>),
            get_input_type_count: Some($crate::__priv::get_input_type_count::<$T>),
            get_output_type: Some($crate::__priv::get_output_type::<$T>),
            get_output_type_count: Some($crate::__priv::get_output_type_count::<$T>),
            // Legacy entry point unused; `kernel_compute_v2` below is set instead.
            kernel_compute: None,
            kernel_destroy: Some($crate::__priv::destroy::<$T>),
            get_input_characteristic: Some($crate::__priv::get_input_characteristic::<$T>),
            get_output_characteristic: Some($crate::__priv::get_output_characteristic::<$T>),
            get_input_memory_type: Some($crate::__priv::get_input_memory_type::<$T>),
            get_variadic_input_min_arity: Some($crate::__priv::get_variadic_input_min_arity::<$T>),
            get_variadic_input_homogeneity: Some(
                $crate::__priv::get_variadic_input_homogeneity::<$T>,
            ),
            get_variadic_output_min_arity: Some(
                $crate::__priv::get_variadic_output_min_arity::<$T>,
            ),
            get_variadic_output_homogeneity: Some(
                $crate::__priv::get_variadic_output_homogeneity::<$T>,
            ),
            create_kernel_v2: Some($crate::__priv::create_kernel::<$T>),
            kernel_compute_v2: Some($crate::__priv::compute::<$T>),
            infer_output_shape_fn: Some($crate::__priv::infer_output_shape::<$T>),
            get_start_version: Some($crate::__priv::get_start_version::<$T>),
            get_end_version: Some($crate::__priv::get_end_version::<$T>),
            // In-place and alias hooks are not supported by this wrapper.
            get_may_inplace: None,
            release_may_inplace: None,
            get_alias_map: None,
            release_alias_map: None,
        };
    };
}
#[cfg(test)]
mod tests {
    // These tests call into the native runtime. The `SessionOptions` tests also use
    // `with_custom_op_domain`, which is only compiled with the `custom-ops` feature.
    use super::*;
    // Creates each attribute kind, reads it back, and lets `Drop` release the handles.
    #[test]
    fn op_attr_round_trip() {
        let f = OpAttr::new_float("alpha", 1.5).expect("new_float");
        assert_eq!(f.ty().unwrap(), sys::OpAttrType::Float);
        assert_eq!(f.name().unwrap(), "alpha");
        let mut buf = [0u8; 4];
        let n = f.read_into(sys::OpAttrType::Float, &mut buf).unwrap();
        assert_eq!(n, 4);
        assert_eq!(f32::from_ne_bytes(buf), 1.5);
        let i = OpAttr::new_int("count", 42).expect("new_int");
        let mut b = [0u8; 8];
        i.read_into(sys::OpAttrType::Int, &mut b).unwrap();
        assert_eq!(i64::from_ne_bytes(b), 42);
        let s = OpAttr::new_string("mode", "fast").expect("new_string");
        assert_eq!(s.ty().unwrap(), sys::OpAttrType::String);
        let mut b = [0u8; 16];
        let n = s.read_into(sys::OpAttrType::String, &mut b).unwrap();
        assert_eq!(&b[..n], b"fast");
        let arr = OpAttr::new_ints("dims", &[3, 5, 7]).expect("new_ints");
        assert_eq!(arr.ty().unwrap(), sys::OpAttrType::Ints);
        eprintln!("op_attr_round_trip: float/int/string/ints all round-tripped + released");
    }
    // Attaches an empty domain to session options. The options handle is released
    // before the domain is dropped, so the stored raw domain pointer never dangles
    // while in use.
    #[test]
    fn custom_op_domain_lifecycle() {
        let domain = CustomOpDomain::new("com.example.foo").expect("new domain");
        let opts = SessionOptions::default().with_custom_op_domain(&domain);
        let h = opts.build_handle().expect("build_handle");
        unsafe {
            crate::api().release_session_options()(h);
        }
        drop(domain); eprintln!("custom_op_domain_lifecycle: create + attach + release clean");
    }
    // Minimal op with two inputs (required float, optional int64) and one float output,
    // used to check the schema reported by the generated vtable.
    struct TestOp;
    impl CustomOp for TestOp {
        const NAME: &'static str = "TestOp";
        const DOMAIN: &'static str = "com.example.test";
        const SINCE_VERSION: i32 = 7;
        fn create(_info: &KernelInfo<'_>) -> Result<Self> {
            Ok(Self)
        }
        fn compute(&mut self, _ctx: &KernelContext<'_>) -> Result<()> {
            Ok(())
        }
        fn inputs() -> &'static [OpIoSpec] {
            static INPUTS: [OpIoSpec; 2] = [
                OpIoSpec::required(sys::ElementType::Float),
                OpIoSpec::optional(sys::ElementType::Int64),
            ];
            &INPUTS
        }
        fn outputs() -> &'static [OpIoSpec] {
            static OUTPUTS: [OpIoSpec; 1] = [OpIoSpec::required(sys::ElementType::Float)];
            &OUTPUTS
        }
    }
    // Generates `TEST_OP_VTABLE` for `TestOp`; the name literal matches `TestOp::NAME`.
    crate::custom_op!(TestOp, "TestOp", as TEST_OP_VTABLE);
    // Calls the vtable callbacks directly (null `_op` is ignored by every callback used).
    #[test]
    fn custom_op_vtable_schema() {
        let v = &TEST_OP_VTABLE;
        unsafe {
            let name_ptr = (v.get_name.unwrap())(std::ptr::null());
            assert_eq!(
                std::ffi::CStr::from_ptr(name_ptr).to_bytes(),
                b"TestOp",
                "get_name"
            );
            assert_eq!(
                (v.get_input_type_count.unwrap())(std::ptr::null()),
                2,
                "input count"
            );
            assert_eq!(
                (v.get_input_type.unwrap())(std::ptr::null(), 0),
                sys::ElementType::Float as i32,
                "input[0] type"
            );
            assert_eq!(
                (v.get_input_type.unwrap())(std::ptr::null(), 1),
                sys::ElementType::Int64 as i32,
                "input[1] type"
            );
            assert_eq!(
                (v.get_output_type_count.unwrap())(std::ptr::null()),
                1,
                "output count"
            );
            assert_eq!(
                (v.get_output_type.unwrap())(std::ptr::null(), 0),
                sys::ElementType::Float as i32,
                "output[0] type"
            );
            assert_eq!(
                (v.get_input_characteristic.unwrap())(std::ptr::null(), 1),
                sys::CustomOpInputOutputCharacteristic::Optional as i32,
                "input[1] optional"
            );
            assert_eq!(
                (v.get_start_version.unwrap())(std::ptr::null()),
                7,
                "since version"
            );
            assert_eq!(
                (v.get_end_version.unwrap())(std::ptr::null()),
                sys::API_VERSION as i32,
                "end version"
            );
            // `TestOp` keeps the default `execution_provider_type()` (`None`).
            assert!(
                (v.get_execution_provider_type.unwrap())(std::ptr::null()).is_null(),
                "EP type null (CPU)"
            );
        }
        eprintln!(
            "custom_op_vtable_schema: name/schema/versions correct via direct callback calls"
        );
    }
    // Registers the static vtable in a domain, attaches it to options, then releases
    // the options before dropping the domain.
    #[test]
    fn custom_op_vtable_registration() {
        let domain = CustomOpDomain::new(TestOp::DOMAIN).expect("new domain");
        domain.add_op(&TEST_OP_VTABLE).expect("add_op");
        let opts = SessionOptions::default().with_custom_op_domain(&domain);
        let h = opts.build_handle().expect("build_handle");
        unsafe {
            crate::api().release_session_options()(h);
        }
        drop(domain);
        eprintln!("custom_op_vtable_registration: add_op + AddCustomOpDomain + release clean");
    }
}