use alloc::{boxed::Box, ffi::CString, vec::Vec};
use core::ptr::{self, NonNull};
mod attribute;
pub(crate) mod bound;
mod io;
mod kernel;
#[cfg(test)]
mod tests;
use self::bound::BoundOperator;
pub use self::{
attribute::{Attribute, FromKernelAttributes, FromOpAttr, ToAttribute},
io::{InputOutputCharacteristic, OperatorInput, OperatorOutput},
kernel::{Kernel, KernelAttributes, KernelContext, ScratchBuffer}
};
use crate::{
AsPointer, Error,
error::Result,
ortsys,
util::with_cstr,
value::{ValueType, r#type::extract_data_type_from_tensor_info}
};
/// A custom operator implemented in Rust that can be registered in an [`OperatorDomain`].
///
/// Implementors describe the operator's signature and build the [`Kernel`] that runs it.
/// Every item after `create_kernel` has a default implementation.
pub trait Operator: Send {
    /// The name of this operator.
    fn name(&self) -> &str;

    /// Describes the inputs this operator accepts.
    fn inputs(&self) -> Vec<OperatorInput>;

    /// Describes the outputs this operator produces.
    fn outputs(&self) -> Vec<OperatorOutput>;

    /// Builds the [`Kernel`] that executes this operator, given the node's `attributes`.
    fn create_kernel(&self, attributes: &KernelAttributes) -> crate::Result<Box<dyn Kernel>>;

    /// Optionally names an execution provider for this operator; `None` unless overridden.
    fn execution_provider_type(&self) -> Option<&str> {
        None
    }

    /// Lowest version supported by this operator; defaults to `1`.
    fn min_version(&self) -> i32 {
        1
    }

    /// Highest version supported by this operator; defaults to `i32::MAX` (no upper bound).
    fn max_version(&self) -> i32 {
        i32::MAX
    }

    /// Shape-inference hook. The default sets no outputs and reports success.
    fn infer_shape(&self, ctx: &mut ShapeInferenceContext) -> crate::Result<()> {
        // Intentionally unused by the default implementation.
        let _ = ctx;
        Ok(())
    }
}
/// The context handed to [`Operator::infer_shape`].
///
/// Wraps a raw `OrtShapeInferContext` pointer. Nothing in this module releases it.
/// NOTE(review): presumably the pointer is owned by ONNX Runtime and only valid for the
/// duration of the `infer_shape` call — confirm against the code that constructs this type.
pub struct ShapeInferenceContext {
// Raw ONNX Runtime shape-inference context; exposed read-only through `AsPointer::ptr`.
ptr: *mut ort_sys::OrtShapeInferContext
}
impl ShapeInferenceContext {
    /// Returns the type of every input of the node being inferred, in input order.
    ///
    /// # Panics
    /// Panics if ONNX Runtime fails to report the input count or the type/shape of an input.
    pub fn inputs(&self) -> Vec<ValueType> {
        let mut count = 0;
        ortsys![unsafe ShapeInferContext_GetInputCount(self.ptr(), &mut count).expect("failed to get input count")];

        let mut tys = Vec::with_capacity(count);
        for i in 0..count {
            let mut ty_info = ptr::null_mut();
            // NOTE(review): `ty_info` is not released here; this assumes the shape-inference
            // context keeps ownership of it — confirm against the ORT C API.
            ortsys![unsafe ShapeInferContext_GetInputTypeShape(self.ptr(), i, &mut ty_info).expect("failed to get info type"); nonNull(ty_info)];
            tys.push(unsafe { extract_data_type_from_tensor_info(ty_info) });
        }
        tys
    }

    /// Reads the attribute `name` of the node being inferred and converts it to `T`.
    ///
    /// # Errors
    /// Returns an error if `name` cannot be passed as a C string, if the node has no such
    /// attribute, or if the attribute cannot be read as `T`.
    pub fn attr<T: FromOpAttr>(&self, name: impl AsRef<str>) -> Result<T> {
        let attr = with_cstr(name.as_ref().as_bytes(), &|name| {
            let mut attr = ptr::null();
            ortsys![unsafe ShapeInferContext_GetAttribute(self.ptr(), name.as_ptr(), &mut attr)?];
            Ok(attr)
        })?;

        // Size probe. With no buffer, ONNX Runtime reports an "insufficient buffer" error status
        // but still writes the number of bytes required to `len`. The probe's error is therefore
        // only fatal when no size came back. Genuine read errors surface in `from_op_attr`.
        let mut len = 0;
        let probe = (|| -> Result<()> {
            ortsys![unsafe ReadOpAttr(attr, T::attr_type(), ptr::null_mut(), 0, &mut len)?];
            Ok(())
        })();
        if let Err(e) = probe {
            if len == 0 {
                return Err(e);
            }
        }
        unsafe { T::from_op_attr(attr, len) }
    }

    /// Sets the type and shape of output `idx` from `ty`.
    ///
    /// # Errors
    /// Returns an error if `ty` has no tensor type info ("only tensors are supported"), or if
    /// ONNX Runtime rejects the output.
    pub fn set_output(&mut self, idx: usize, ty: &ValueType) -> Result<()> {
        match ty.to_tensor_type_info() {
            Some(ty_ptr) => {
                // Capture the result instead of returning early with `?`: `ty_ptr` must be
                // released on both the success and the failure path, otherwise it leaks.
                let result = (|| -> Result<()> {
                    ortsys![unsafe ShapeInferContext_SetOutputTypeShape(self.ptr(), idx, ty_ptr)?];
                    Ok(())
                })();
                ortsys![unsafe ReleaseTensorTypeAndShapeInfo(ty_ptr)];
                result
            }
            None => Err(Error::new("only tensors are supported"))
        }
    }
}
impl AsPointer for ShapeInferenceContext {
    type Sys = ort_sys::OrtShapeInferContext;

    /// Returns the wrapped `OrtShapeInferContext` pointer, viewed as `*const`.
    fn ptr(&self) -> *const Self::Sys {
        self.ptr.cast_const()
    }
}
/// A named custom operator domain to which [`Operator`]s can be added with
/// [`OperatorDomain::add`].
///
/// Owns the underlying `OrtCustomOpDomain`, which is released when this value is dropped.
pub struct OperatorDomain {
// Handle to the ONNX Runtime domain, released in `Drop::drop`.
ptr: NonNull<ort_sys::OrtCustomOpDomain>,
// The domain name passed to `CreateCustomOpDomain`, kept alive as long as the domain.
// NOTE(review): presumably ONNX Runtime may keep referring to this string — confirm before
// removing the field.
_name: CString,
// Each operator is boxed so its heap address stays fixed even when this Vec reallocates;
// `add` hands that address to `CustomOpDomain_Add`. Struct fields are dropped after
// `Drop::drop` has released the domain, so these outlive the release call.
#[allow(clippy::vec_box)]
operators: Vec<Box<BoundOperator>>
}
impl OperatorDomain {
    /// Creates a new, empty operator domain called `name`.
    ///
    /// # Errors
    /// Fails if `name` contains an interior NUL byte or if ONNX Runtime cannot create the domain.
    pub fn new(name: impl AsRef<str>) -> Result<Self> {
        let domain_name = CString::new(name.as_ref())?;

        let mut ptr: *mut ort_sys::OrtCustomOpDomain = ptr::null_mut();
        // After this call, `ptr` is usable as the `NonNull` stored in the struct below.
        ortsys![unsafe CreateCustomOpDomain(domain_name.as_ptr(), &mut ptr)?; nonNull(ptr)];
        crate::logging::create!(OperatorDomain, ptr);

        let operators = Vec::new();
        Ok(Self { ptr, _name: domain_name, operators })
    }

    /// Registers `operator` with this domain and returns the domain for further chaining.
    ///
    /// # Errors
    /// Fails if the operator cannot be bound or if ONNX Runtime rejects it.
    // Builder-style name; this is not meant to be `std::ops::Add`.
    #[allow(clippy::should_implement_trait)]
    pub fn add<O: Operator + 'static>(mut self, operator: O) -> Result<Self> {
        let bound = BoundOperator::new(operator)?;
        let bound = Box::new(bound);
        // ONNX Runtime receives the address of the boxed operator. Moving the `Box` into
        // `operators` afterwards does not move the heap allocation, so that address stays valid.
        ortsys![unsafe CustomOpDomain_Add(self.ptr.as_ptr(), (&*bound as *const BoundOperator) as *mut _)?];
        self.operators.push(bound);
        Ok(self)
    }
}
impl AsPointer for OperatorDomain {
    type Sys = ort_sys::OrtCustomOpDomain;

    /// Returns the owned `OrtCustomOpDomain` pointer, viewed as `*const`.
    fn ptr(&self) -> *const Self::Sys {
        let raw: *mut Self::Sys = self.ptr.as_ptr();
        raw.cast_const()
    }
}
impl Drop for OperatorDomain {
    /// Releases the ONNX Runtime domain and logs the drop.
    ///
    /// The boxed operators in `operators` are dropped after this method returns, together
    /// with the other fields.
    fn drop(&mut self) {
        let domain = self.ptr.as_ptr();
        ortsys![unsafe ReleaseCustomOpDomain(domain)];
        crate::logging::drop!(OperatorDomain, self.ptr);
    }
}