use alloc::{string::String, sync::Arc};
use core::{
fmt::Debug,
ptr::{self, NonNull}
};
use crate::{
AsPointer,
error::Result,
memory::MemoryInfo,
ortsys,
session::{Session, SharedSessionInner},
util::{MiniMap, with_cstr},
value::{DynValue, Value, ValueInner, ValueTypeMarker}
};
#[derive(Debug)]
pub struct IoBinding {
ptr: NonNull<ort_sys::OrtIoBinding>,
held_inputs: MiniMap<String, Arc<ValueInner>>,
pub(crate) output_values: MiniMap<String, Option<DynValue>>,
_session: Arc<SharedSessionInner>
}
impl IoBinding {
pub(crate) fn new(session: &Session) -> Result<Self> {
let mut ptr: *mut ort_sys::OrtIoBinding = ptr::null_mut();
ortsys![unsafe CreateIoBinding(session.ptr().cast_mut(), &mut ptr)?; nonNull(ptr)];
crate::logging::create!(IoBinding, ptr);
Ok(Self {
ptr,
held_inputs: MiniMap::new(),
output_values: MiniMap::new(),
_session: session.inner()
})
}
pub fn bind_input<T: ValueTypeMarker + ?Sized, S: Into<String>>(&mut self, name: S, ort_value: &Value<T>) -> Result<()> {
let name: String = name.into();
let ptr = self.ptr_mut();
with_cstr(name.as_bytes(), &|name| {
ortsys![unsafe BindInput(ptr, name.as_ptr(), ort_value.ptr())?];
Ok(())
})?;
self.held_inputs.insert(name, Arc::clone(&ort_value.inner));
Ok(())
}
pub fn bind_output<T: ValueTypeMarker + ?Sized, S: Into<String>>(&mut self, name: S, mut ort_value: Value<T>) -> Result<()> {
let name: String = name.into();
unsafe { self.bind_output_mut(name.as_bytes(), &mut ort_value) }?;
self.output_values.insert(name, Some(ort_value.into_dyn()));
Ok(())
}
pub(crate) unsafe fn bind_output_mut<T: ValueTypeMarker + ?Sized, S: AsRef<[u8]>>(&mut self, name: S, ort_value: &mut Value<T>) -> Result<()> {
let ptr = self.ptr_mut();
with_cstr(name.as_ref(), &|name| {
ortsys![unsafe BindOutput(ptr, name.as_ptr(), ort_value.ptr())?];
Ok(())
})?;
Ok(())
}
pub fn bind_output_to_device<S: Into<String>>(&mut self, name: S, mem_info: &MemoryInfo) -> Result<()> {
let name: String = name.into();
let ptr = self.ptr_mut();
with_cstr(name.as_bytes(), &|name| {
ortsys![unsafe BindOutputToDevice(ptr, name.as_ptr(), mem_info.ptr())?];
Ok(())
})?;
self.output_values.insert(name, None);
Ok(())
}
pub fn clear_inputs(&mut self) {
ortsys![unsafe ClearBoundInputs(self.ptr_mut())];
drop(self.held_inputs.drain());
}
pub fn clear_outputs(&mut self) {
ortsys![unsafe ClearBoundOutputs(self.ptr_mut())];
drop(self.output_values.drain());
}
pub fn clear(&mut self) {
self.clear_inputs();
self.clear_outputs();
}
pub fn synchronize_inputs(&self) -> Result<()> {
ortsys![unsafe SynchronizeBoundInputs(self.ptr().cast_mut())?];
Ok(())
}
pub fn synchronize_outputs(&self) -> Result<()> {
ortsys![unsafe SynchronizeBoundOutputs(self.ptr().cast_mut())?];
Ok(())
}
pub fn synchronize(&self) -> Result<()> {
self.synchronize_inputs()?;
self.synchronize_outputs()?;
Ok(())
}
}
unsafe impl Send for IoBinding {}
impl AsPointer for IoBinding {
type Sys = ort_sys::OrtIoBinding;
fn ptr(&self) -> *const Self::Sys {
self.ptr.as_ptr()
}
}
impl Drop for IoBinding {
fn drop(&mut self) {
ortsys![unsafe ReleaseIoBinding(self.ptr_mut())];
crate::logging::drop!(IoBinding, self.ptr());
}
}
#[cfg(test)]
mod tests {
#[cfg(feature = "ndarray")]
use ndarray::Array2;
#[cfg(feature = "ndarray")]
use crate::test_util::mnist;
use crate::{
Result,
memory::{AllocationDevice, AllocatorType, MemoryInfo, MemoryType},
session::Session,
value::Tensor
};
#[test]
#[cfg(all(feature = "ndarray", feature = "fetch-models"))]
fn test_mnist_input_bound() -> Result<()> {
let mut session = Session::builder()?.commit_from_url(mnist::MODEL_URL)?;
let array = mnist::get_image();
let mut binding = session.create_binding()?;
binding.bind_input(session.inputs()[0].name(), &Tensor::from_array(array)?)?;
binding
.bind_output_to_device(session.outputs()[0].name(), &MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::CPUOutput)?)?;
let outputs = session.run_binding(&binding)?;
let probabilities = mnist::extract_probabilities(&outputs[0])?;
assert_eq!(probabilities[0].0, 5);
Ok(())
}
#[test]
#[cfg(all(feature = "ndarray", feature = "fetch-models"))]
fn test_mnist_input_output_bound() -> Result<()> {
let mut session = Session::builder()?.commit_from_url(mnist::MODEL_URL)?;
let array = mnist::get_image();
let mut binding = session.create_binding()?;
binding.bind_input(session.inputs()[0].name(), &Tensor::from_array(array)?)?;
let output = Array2::from_shape_simple_fn((1, 10), || 0.0_f32);
binding.bind_output(session.outputs()[0].name(), Tensor::from_array(output)?)?;
let outputs = session.run_binding(&binding)?;
let probabilities = mnist::extract_probabilities(&outputs[0])?;
assert_eq!(probabilities[0].0, 5);
Ok(())
}
#[test]
#[cfg(all(feature = "ndarray", feature = "fetch-models"))]
fn test_send_iobinding() -> Result<()> {
let mut session = Session::builder()?.commit_from_url(mnist::MODEL_URL)?;
let array = mnist::get_image();
let mut binding = session.create_binding()?;
let output = Array2::from_shape_simple_fn((1, 10), || 0.0_f32);
binding.bind_output(session.outputs()[0].name(), Tensor::from_array(output)?)?;
let probabilities = std::thread::spawn(move || {
binding.bind_input(session.inputs()[0].name(), &Tensor::from_array(array)?)?;
let outputs = session.run_binding(&binding)?;
let probabilities = mnist::extract_probabilities(&outputs[0])?;
Ok::<Vec<(usize, f32)>, crate::Error>(probabilities)
})
.join()
.expect("")?;
assert_eq!(probabilities[0].0, 5);
Ok(())
}
#[test]
#[cfg(all(feature = "ndarray", feature = "fetch-models"))]
fn test_mnist_clear_binds() -> Result<()> {
let mut session = Session::builder()?.commit_from_url(mnist::MODEL_URL)?;
let array = mnist::get_image();
let mut binding = session.create_binding()?;
binding.bind_input(session.inputs()[0].name(), &Tensor::from_array(array)?)?;
let output = Array2::from_shape_simple_fn((1, 10), || 0.0_f32);
binding.bind_output(session.outputs()[0].name(), Tensor::from_array(output)?)?;
{
let outputs = session.run_binding(&binding)?;
let probabilities = mnist::extract_probabilities(&outputs[0])?;
assert_eq!(probabilities[0].0, 5);
}
binding.clear_outputs();
binding
.bind_output_to_device(session.outputs()[0].name(), &MemoryInfo::new(AllocationDevice::CPU, 0, AllocatorType::Device, MemoryType::CPUOutput)?)?;
{
let outputs = session.run_binding(&binding)?;
let probabilities = mnist::extract_probabilities(&outputs[0])?;
assert_eq!(probabilities[0].0, 5);
}
binding.clear_inputs();
assert!(session.run_binding(&binding).is_err());
Ok(())
}
}