use alloc::{ffi::CString, string::String, vec, vec::Vec};
use core::{
ffi::c_char,
ptr::{self, NonNull}
};
use crate::{
AsPointer, Error, Result,
memory::Allocator,
ortsys,
util::{with_cstr, with_cstr_ptr_array},
value::{DowncastableTarget, DynValue, ValueRef}
};
#[derive(Debug)]
#[repr(transparent)] pub struct Attribute(NonNull<ort_sys::OrtOpAttr>);
impl Attribute {
pub fn new(name: impl AsRef<str>, value: impl ToAttribute) -> Result<Self> {
with_cstr(name.as_ref().as_bytes(), &|name| {
let ptr = unsafe { value.to_attribute(name.as_ptr()) }?;
crate::logging::create!(Attribute, ptr);
Ok(Self(ptr))
})
}
}
impl Drop for Attribute {
fn drop(&mut self) {
ortsys![unsafe ReleaseOpAttr(self.0.as_ptr())];
crate::logging::drop!(Attribute, self.0);
}
}
pub trait FromKernelAttributes<'s> {
#[doc(hidden)]
unsafe fn from_info(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Result<Self>
where
Self: Sized;
private_trait!();
}
pub trait FromOpAttr {
#[doc(hidden)]
fn attr_type() -> ort_sys::OrtOpAttrType;
#[doc(hidden)]
unsafe fn from_op_attr(attr: *const ort_sys::OrtOpAttr, len: usize) -> Result<Self>
where
Self: Sized;
private_trait!();
}
pub trait ToAttribute {
#[doc(hidden)]
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized;
private_trait!();
}
impl FromKernelAttributes<'_> for f32 {
unsafe fn from_info(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Result<Self>
where
Self: Sized
{
let mut value = Self::default();
ortsys![unsafe KernelInfoGetAttribute_float(info, name, &mut value)?];
Ok(value)
}
private_impl!();
}
impl FromOpAttr for f32 {
fn attr_type() -> ort_sys::OrtOpAttrType {
ort_sys::OrtOpAttrType::ORT_OP_ATTR_FLOAT
}
unsafe fn from_op_attr(attr: *const ort_sys::OrtOpAttr, mut len: usize) -> Result<Self>
where
Self: Sized
{
let mut out = 0.0_f32;
ortsys![unsafe ReadOpAttr(attr, ort_sys::OrtOpAttrType::ORT_OP_ATTR_FLOAT, (&mut out as *mut f32).cast(), size_of::<f32>(), &mut len)?];
debug_assert_eq!(len, size_of::<f32>(), "float attribute is smaller/larger than expected");
Ok(out)
}
private_impl!();
}
impl ToAttribute for f32 {
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized
{
let mut out = ptr::null_mut();
ortsys![
unsafe CreateOpAttr(
name,
(self as *const f32).cast(),
1,
ort_sys::OrtOpAttrType::ORT_OP_ATTR_FLOAT,
&mut out
)?;
nonNull(out)
];
Ok(out)
}
private_impl!();
}
impl FromKernelAttributes<'_> for i64 {
unsafe fn from_info(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Result<Self>
where
Self: Sized
{
let mut value = Self::default();
ortsys![unsafe KernelInfoGetAttribute_int64(info, name, &mut value)?];
Ok(value)
}
private_impl!();
}
impl FromOpAttr for i64 {
fn attr_type() -> ort_sys::OrtOpAttrType {
ort_sys::OrtOpAttrType::ORT_OP_ATTR_INT
}
unsafe fn from_op_attr(attr: *const ort_sys::OrtOpAttr, mut len: usize) -> Result<Self>
where
Self: Sized
{
let mut out = 0_i64;
ortsys![unsafe ReadOpAttr(attr, ort_sys::OrtOpAttrType::ORT_OP_ATTR_INT, (&mut out as *mut i64).cast(), size_of::<i64>(), &mut len)?];
debug_assert_eq!(len, size_of::<i64>(), "int attribute is smaller/larger than expected");
Ok(out)
}
private_impl!();
}
impl ToAttribute for i64 {
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized
{
let mut out = ptr::null_mut();
ortsys![
unsafe CreateOpAttr(
name,
(self as *const i64).cast(),
1,
ort_sys::OrtOpAttrType::ORT_OP_ATTR_INT,
&mut out
)?;
nonNull(out)
];
Ok(out)
}
private_impl!();
}
impl FromKernelAttributes<'_> for String {
unsafe fn from_info(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Result<Self>
where
Self: Sized
{
let mut size = 0;
ortsys![unsafe KernelInfoGetAttribute_string(info, name, ptr::null_mut(), &mut size)?];
let mut out = vec![0u8; size];
ortsys![unsafe KernelInfoGetAttribute_string(info, name, out.as_mut_ptr().cast::<c_char>(), &mut size)?];
let string = CString::from_vec_with_nul(out)?;
Ok(string.into_string()?)
}
private_impl!();
}
impl FromOpAttr for String {
fn attr_type() -> ort_sys::OrtOpAttrType {
ort_sys::OrtOpAttrType::ORT_OP_ATTR_STRING
}
unsafe fn from_op_attr(attr: *const ort_sys::OrtOpAttr, mut len: usize) -> Result<Self>
where
Self: Sized
{
let mut out = vec![0_u8; len];
ortsys![unsafe ReadOpAttr(attr, ort_sys::OrtOpAttrType::ORT_OP_ATTR_STRING, out.as_mut_ptr().cast(), len, &mut len)?];
debug_assert_eq!(out.len(), len, "int attribute is smaller/larger than expected");
CString::from_vec_with_nul(out)
.map_err(|_| Error::new("invalid string attribute contents"))
.and_then(|f| f.into_string().map_err(|_| Error::new("invalid string attribute contents")))
}
private_impl!();
}
impl ToAttribute for String {
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized
{
with_cstr(self.as_bytes(), &|contents| {
let mut out = ptr::null_mut();
ortsys![
unsafe CreateOpAttr(
name,
contents.as_ptr().cast(),
1,
ort_sys::OrtOpAttrType::ORT_OP_ATTR_STRING,
&mut out
)?;
nonNull(out)
];
Ok(out)
})
}
private_impl!();
}
impl FromKernelAttributes<'_> for Vec<f32> {
unsafe fn from_info(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Result<Self>
where
Self: Sized
{
let mut size = 0;
ortsys![unsafe KernelInfoGetAttributeArray_float(info, name, ptr::null_mut(), &mut size)?];
let mut out = vec![0f32; size];
ortsys![unsafe KernelInfoGetAttributeArray_float(info, name, out.as_mut_ptr(), &mut size)?];
Ok(out)
}
private_impl!();
}
impl FromOpAttr for Vec<f32> {
fn attr_type() -> ort_sys::OrtOpAttrType {
ort_sys::OrtOpAttrType::ORT_OP_ATTR_FLOATS
}
unsafe fn from_op_attr(attr: *const ort_sys::OrtOpAttr, mut len: usize) -> Result<Self>
where
Self: Sized
{
let num_els = len / size_of::<f32>();
let mut out = vec![0.0_f32; num_els];
ortsys![unsafe ReadOpAttr(attr, ort_sys::OrtOpAttrType::ORT_OP_ATTR_FLOATS, out.as_mut_ptr().cast(), len, &mut len)?];
debug_assert_eq!(out.len(), num_els, "float array attribute is smaller/larger than expected");
Ok(out)
}
private_impl!();
}
impl ToAttribute for &[f32] {
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized
{
let mut out = ptr::null_mut();
ortsys![
unsafe CreateOpAttr(
name,
self.as_ptr().cast(),
self.len() as _,
ort_sys::OrtOpAttrType::ORT_OP_ATTR_FLOATS,
&mut out
)?;
nonNull(out)
];
Ok(out)
}
private_impl!();
}
impl ToAttribute for Vec<f32> {
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized
{
unsafe { self.as_slice().to_attribute(name) }
}
private_impl!();
}
impl FromKernelAttributes<'_> for Vec<i64> {
unsafe fn from_info(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Result<Self>
where
Self: Sized
{
let mut size = 0;
ortsys![unsafe KernelInfoGetAttributeArray_int64(info, name, ptr::null_mut(), &mut size)?];
let mut out = vec![0i64; size];
ortsys![unsafe KernelInfoGetAttributeArray_int64(info, name, out.as_mut_ptr(), &mut size)?];
Ok(out)
}
private_impl!();
}
impl FromOpAttr for Vec<i64> {
fn attr_type() -> ort_sys::OrtOpAttrType {
ort_sys::OrtOpAttrType::ORT_OP_ATTR_INTS
}
unsafe fn from_op_attr(attr: *const ort_sys::OrtOpAttr, mut len: usize) -> Result<Self>
where
Self: Sized
{
let num_els = len / size_of::<i64>();
let mut out = vec![0_i64; num_els];
ortsys![unsafe ReadOpAttr(attr, ort_sys::OrtOpAttrType::ORT_OP_ATTR_INTS, out.as_mut_ptr().cast(), len, &mut len)?];
debug_assert_eq!(out.len(), num_els, "int array attribute is smaller/larger than expected");
Ok(out)
}
private_impl!();
}
impl ToAttribute for &[i64] {
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized
{
let mut out = ptr::null_mut();
ortsys![
unsafe CreateOpAttr(
name,
self.as_ptr().cast(),
self.len() as _,
ort_sys::OrtOpAttrType::ORT_OP_ATTR_INTS,
&mut out
)?;
nonNull(out)
];
Ok(out)
}
private_impl!();
}
impl ToAttribute for Vec<i64> {
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized
{
unsafe { self.as_slice().to_attribute(name) }
}
private_impl!();
}
impl ToAttribute for &[String] {
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized
{
with_cstr_ptr_array(self, &|strings| {
let mut out = ptr::null_mut();
ortsys![
unsafe CreateOpAttr(
name,
strings.as_ptr().cast(),
strings.len() as _,
ort_sys::OrtOpAttrType::ORT_OP_ATTR_STRINGS,
&mut out
)?;
nonNull(out)
];
Ok(out)
})
}
private_impl!();
}
impl ToAttribute for &[&str] {
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized
{
with_cstr_ptr_array(self, &|strings| {
let mut out = ptr::null_mut();
ortsys![
unsafe CreateOpAttr(
name,
strings.as_ptr().cast(),
strings.len() as _,
ort_sys::OrtOpAttrType::ORT_OP_ATTR_STRINGS,
&mut out
)?;
nonNull(out)
];
Ok(out)
})
}
private_impl!();
}
impl ToAttribute for Vec<String> {
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized
{
unsafe { self.as_slice().to_attribute(name) }
}
private_impl!();
}
impl ToAttribute for Vec<&str> {
unsafe fn to_attribute(&self, name: *const ort_sys::c_char) -> Result<NonNull<ort_sys::OrtOpAttr>>
where
Self: Sized
{
unsafe { self.as_slice().to_attribute(name) }
}
private_impl!();
}
impl<'s, T: DowncastableTarget> FromKernelAttributes<'s> for ValueRef<'s, T> {
unsafe fn from_info(info: *mut ort_sys::OrtKernelInfo, name: *const ort_sys::c_char) -> Result<Self>
where
Self: Sized
{
let allocator = Allocator::default();
let mut value_ptr: *mut ort_sys::OrtValue = ptr::null_mut();
ortsys![unsafe KernelInfoGetAttribute_tensor(info, name, allocator.ptr().cast_mut(), &mut value_ptr)?; nonNull(value_ptr)];
unsafe { ValueRef::new(DynValue::from_ptr(value_ptr, None)) }.downcast()
}
private_impl!();
}