use crate::allocator::Allocator;
use crate::element::TensorElement;
use crate::memory::MemoryInfo;
use crate::session::Session;
use crate::tensor::{AllocatedTensor, OwnedValue, RunInput, TensorBuffer, tensor_memory_info};
use crate::type_info::checked_element_count;
use crate::{Error, Result, api, check, sys};
use std::ffi::{CString, c_void};
use std::marker::PhantomData;
use std::ptr;
/// A caller-owned Rust slice wrapped as a runtime tensor value, so it can be bound as a session
/// output (see `IoBinding::bind_output`).
///
/// The slice stays mutably borrowed for `'a`, so Rust code cannot touch it directly while this
/// value is alive; results are read back through [`OutputValue::as_slice`].
#[derive(Debug)]
pub struct OutputValue<'a> {
    /// Value handle owned by this struct: checked non-null on creation, released in `Drop`.
    value: *mut sys::ValueHandle,
    /// Element type the tensor was created with; `as_slice` rejects requests for other types.
    elem_type: sys::ElementType,
    /// Number of `elem_type` elements in the wrapped slice.
    count: usize,
    /// Ties this value to the mutable borrow of the wrapped slice.
    _life: PhantomData<&'a mut [u8]>,
}
impl<'a> OutputValue<'a> {
pub fn from_buffer<T: TensorElement>(
buf: &'a mut [T], shape: &[i64], mem: &MemoryInfo,
) -> Result<Self> {
validate_shape_len(shape, buf.len())?;
if !mem.is_host_accessible()? {
let info = mem.snapshot()?;
return Err(Error::new(
-1,
format!(
"OutputValue wraps a Rust slice and requires host-accessible memory, got {} device {} ({:?}/{:?})",
info.name, info.device_id, info.alloc_type, info.mem_type
),
));
}
let bytes = std::mem::size_of_val(buf);
let mut value: *mut sys::ValueHandle = ptr::null_mut();
check(unsafe {
api().create_tensor_with_data_as_ort_value()(
mem.info as *const sys::MemoryInfoHandle,
buf.as_mut_ptr() as *mut c_void,
bytes,
shape.as_ptr(),
shape.len(),
T::ELEM,
&mut value,
)
})?;
let value = crate::ensure_non_null(value, "output value")?;
Ok(Self {
value,
elem_type: T::ELEM,
count: buf.len(),
_life: PhantomData,
})
}
#[inline]
pub(crate) fn as_value_ptr(&self) -> *const sys::ValueHandle {
self.value as *const sys::ValueHandle
}
pub fn as_slice<T: TensorElement>(&self) -> Result<&[T]> {
if self.elem_type as i32 != T::ELEM as i32 {
return Err(Error::new(
-1,
format!(
"zrt: OutputValue::as_slice<{}> on a {:?} buffer",
std::any::type_name::<T>(),
self.elem_type
),
));
}
let info = tensor_memory_info(self.value as *const sys::ValueHandle)?;
if !info.is_host_accessible() {
return Err(Error::new(
-1,
format!(
"output tensor memory is not host-accessible: {} device {} ({:?}/{:?})",
info.name, info.device_id, info.alloc_type, info.mem_type
),
));
}
let mut data: *mut c_void = ptr::null_mut();
check(unsafe { api().get_tensor_mutable_data()(self.value, &mut data) })?;
let data = crate::slice_data_ptr(data as *mut T, self.count, "output tensor data")?;
Ok(unsafe { std::slice::from_raw_parts(data as *const T, self.count) })
}
}
impl Drop for OutputValue<'_> {
    /// Releases the value handle created by `OutputValue::from_buffer`. The wrapped slice is only
    /// borrowed, so it is not freed here.
    fn drop(&mut self) {
        let handle = self.value;
        // SAFETY: `handle` was checked non-null when created, is owned solely by `self`, and is
        // released exactly once, here.
        unsafe { api().release_value()(handle) }
    }
}
// SAFETY: `OutputValue` exclusively owns its value handle and only borrows the wrapped slice.
// Moving it to another thread moves that ownership with it.
// NOTE(review): assumes the runtime allows a value to be used and released off the thread that
// created it — confirm.
unsafe impl Send for OutputValue<'_> {}
// SAFETY: through `&OutputValue` callers can only obtain the raw handle (`as_value_ptr`) or a
// shared slice (`as_slice`), and both are read-only from Rust's side.
// NOTE(review): this relies on two assumptions — confirm both:
// - the runtime's read accessors are safe to call concurrently;
// - no session run writes into the buffer while a slice from `as_slice` is alive.
unsafe impl Sync for OutputValue<'_> {}
fn validate_shape_len(shape: &[i64], len: usize) -> Result<()> {
let expected = checked_element_count(shape)?;
if expected != len {
return Err(Error::new(
-1,
format!("output tensor shape expects {expected} elements, got {len}"),
));
}
Ok(())
}
/// Owned runtime I/O binding created for a `Session`, used to bind input and output values by
/// name.
#[derive(Debug)]
pub struct IoBinding {
    /// Binding handle owned by this struct: checked non-null on creation, released in `Drop`.
    binding: *mut sys::IoBindingHandle,
}
impl IoBinding {
pub fn new(sess: &Session) -> Result<Self> {
let mut binding: *mut sys::IoBindingHandle = ptr::null_mut();
check(unsafe { api().create_io_binding()(sess.as_ptr(), &mut binding) })?;
let binding = crate::ensure_non_null(binding, "I/O binding")?;
Ok(Self { binding })
}
#[inline]
pub(crate) fn as_ptr(&self) -> *const sys::IoBindingHandle {
self.binding as *const sys::IoBindingHandle
}
pub fn bind_input(&mut self, name: &str, input: &dyn RunInput) -> Result<()> {
let cname = CString::new(name).map_err(|_| Error::new(-1, "input name contains a NUL"))?;
check(unsafe { api().bind_input()(self.binding, cname.as_ptr(), input.as_value_ptr()) })
}
pub fn bind_output(&mut self, name: &str, value: &OutputValue<'_>) -> Result<()> {
let cname = CString::new(name).map_err(|_| Error::new(-1, "output name contains a NUL"))?;
check(unsafe { api().bind_output()(self.binding, cname.as_ptr(), value.as_value_ptr()) })
}
pub fn bind_output_buffer<T: TensorElement>(
&mut self, name: &str, value: &TensorBuffer<T>,
) -> Result<()> {
let cname = CString::new(name).map_err(|_| Error::new(-1, "output name contains a NUL"))?;
check(unsafe { api().bind_output()(self.binding, cname.as_ptr(), value.as_value_ptr()) })
}
pub fn bind_output_allocated<T: TensorElement>(
&mut self, name: &str, value: &AllocatedTensor<T>,
) -> Result<()> {
let cname = CString::new(name).map_err(|_| Error::new(-1, "output name contains a NUL"))?;
check(unsafe { api().bind_output()(self.binding, cname.as_ptr(), value.as_value_ptr()) })
}
pub fn bind_output_device(&mut self, name: &str, mem: &MemoryInfo) -> Result<()> {
let cname = CString::new(name).map_err(|_| Error::new(-1, "output name contains a NUL"))?;
check(unsafe {
api().bind_output_to_device()(
self.binding,
cname.as_ptr(),
mem.info as *const sys::MemoryInfoHandle,
)
})
}
pub fn synchronize_outputs(&self) -> Result<()> {
check(unsafe { api().synchronize_bound_outputs()(self.binding) })
}
pub fn synchronize_inputs(&self) -> Result<()> {
check(unsafe { api().synchronize_bound_inputs()(self.binding) })
}
pub fn clear_inputs(&mut self) {
unsafe { api().clear_bound_inputs()(self.binding) }
}
pub fn clear_outputs(&mut self) {
unsafe { api().clear_bound_outputs()(self.binding) }
}
pub fn output_values(&self) -> Result<Vec<OwnedValue>> {
let alloc = Allocator::get_default()?;
let mut out: *mut *mut sys::ValueHandle = ptr::null_mut();
let mut count: usize = 0;
check(unsafe {
api().get_bound_output_values()(
self.binding as *const sys::IoBindingHandle,
alloc.alloc,
&mut out,
&mut count,
)
})?;
let handles: &[*mut sys::ValueHandle] = if count == 0 {
&[]
} else {
unsafe { std::slice::from_raw_parts(out, count) }
};
let values = OwnedValue::collect_from_raw(handles);
let free = if out.is_null() {
Ok(())
} else {
unsafe { alloc.free(out as *mut c_void) }
};
match (values, free) {
(Ok(values), Ok(())) => Ok(values),
(Err(err), _) => Err(err),
(Ok(_), Err(err)) => Err(err),
}
}
}
impl Drop for IoBinding {
    /// Releases the binding handle obtained in `IoBinding::new`.
    fn drop(&mut self) {
        let handle = self.binding;
        // SAFETY: `handle` was checked non-null at construction, is owned solely by `self`, and
        // is released exactly once, here.
        unsafe { api().release_io_binding()(handle) }
    }
}
// SAFETY: `IoBinding` exclusively owns its handle, so moving it to another thread moves that
// ownership with it.
// NOTE(review): assumes the runtime permits using and releasing a binding off the thread that
// created it — confirm.
// Only `Send` is implemented; without `Sync`, `&IoBinding` cannot be shared across threads.
unsafe impl Send for IoBinding {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Dynamic (`-1`) dimensions and element-count mismatches are rejected; an exact fit is
    /// accepted.
    #[test]
    fn output_value_rejects_dynamic_and_mismatched_shapes() {
        let cpu = MemoryInfo::cpu().unwrap();
        let mut storage = [0.0f32; 4];
        let bad_shapes: [&[i64]; 2] = [&[-1, 4], &[5]];
        for shape in bad_shapes {
            assert!(OutputValue::from_buffer(&mut storage, shape, &cpu).is_err());
        }
        assert!(OutputValue::from_buffer(&mut storage, &[2, 2], &cpu).is_ok());
    }

    /// Device memory is refused because `OutputValue` wraps a plain Rust slice.
    #[test]
    fn output_value_rejects_cuda_device_memory() {
        let gpu = MemoryInfo::cuda(0).unwrap();
        let mut storage = [0.0f32; 4];
        let outcome = OutputValue::from_buffer(&mut storage, &[2, 2], &gpu);
        assert!(outcome.is_err());
    }
}