use std::{ffi::CString, marker::PhantomData, ptr::null_mut, slice::from_raw_parts};
use ort2_sys as ffi;
use crate::{
api::{api, ok},
error::Result,
memory::MemoryInfo,
prelude::AllocatorTrait,
session::Session,
value::Value,
};
/// Glob-importable prelude for the I/O-binding module; it currently re-exports nothing.
pub mod prelude {}
/// Owning wrapper around an ONNX Runtime `OrtIoBinding` handle.
///
/// The lifetime `'a` is the borrow of the [`Session`] passed to `IoBinding::new`,
/// so a binding can never outlive the session it was created for.
#[derive(Debug)]
pub struct IoBinding<'a> {
    /// Raw handle obtained from `CreateIoBinding`; released in the `Drop` impl.
    inner: *mut ffi::OrtIoBinding,
    /// Zero-sized marker that carries the session borrow `'a`.
    marker: PhantomData<&'a ()>,
}
/// Hands the owned `OrtIoBinding` handle back to ONNX Runtime via `ReleaseIoBinding`.
impl Drop for IoBinding<'_> {
    fn drop(&mut self) {
        let handle = self.inner;
        api!(ReleaseIoBinding, handle);
    }
}
impl<'a> IoBinding<'a> {
    /// Creates a new I/O binding for `session` via `CreateIoBinding`.
    ///
    /// The returned binding borrows `session` for `'a` and therefore cannot outlive it.
    ///
    /// # Errors
    ///
    /// Returns an error if ONNX Runtime reports a failure from `CreateIoBinding`.
    pub fn new(session: &'a Session) -> Result<Self> {
        let mut inner = null_mut();
        ok!(CreateIoBinding, session.inner(), &mut inner)?;
        Ok(Self {
            inner,
            marker: PhantomData,
        })
    }

    /// Removes every input binding (`ClearBoundInputs`).
    pub fn clear_inputs(&mut self) {
        api!(ClearBoundInputs, self.inner)
    }

    /// Removes every output binding (`ClearBoundOutputs`).
    pub fn clear_outputs(&mut self) {
        api!(ClearBoundOutputs, self.inner)
    }

    /// Binds `value` to the model input named `name` (`BindInput`).
    ///
    /// # Errors
    ///
    /// Returns an error if ONNX Runtime rejects the binding.
    pub fn bind_input(&mut self, name: &CString, value: &Value) -> Result<()> {
        ok!(BindInput, self.inner, name.as_ptr(), value.inner())
    }

    /// Binds `value` to the model output named `name` (`BindOutput`).
    ///
    /// # Errors
    ///
    /// Returns an error if ONNX Runtime rejects the binding.
    pub fn bind_output(&mut self, name: &CString, value: &Value) -> Result<()> {
        ok!(BindOutput, self.inner, name.as_ptr(), value.inner())
    }

    /// Binds the output named `name` to the device described by `mem_info`
    /// (`BindOutputToDevice`) instead of to a caller-provided value.
    ///
    /// # Errors
    ///
    /// Returns an error if ONNX Runtime rejects the binding.
    pub fn bind_output_to_device(&mut self, name: &CString, mem_info: &MemoryInfo) -> Result<()> {
        ok!(
            BindOutputToDevice,
            self.inner,
            name.as_ptr(),
            mem_info.inner()
        )
    }

    /// Fetches the output values currently held by this binding (`GetBoundOutputValues`).
    ///
    /// ONNX Runtime allocates a temporary array of value handles with `allocator`.
    /// This method wraps each handle in a [`Value`] and then returns the array itself
    /// to `allocator`. When no outputs are bound, an empty vector is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if `GetBoundOutputValues` fails, or if freeing the temporary
    /// handle array with `AllocatorFree` fails.
    pub fn get_bound_outputs(&self, allocator: &impl AllocatorTrait) -> Result<Vec<Value<'_>>> {
        let mut count = 0usize;
        let mut raw_values = null_mut();
        ok!(
            GetBoundOutputValues,
            self.inner,
            allocator.inner(),
            &mut raw_values,
            &mut count
        )?;
        // ORT signals "nothing bound" with a null array. `from_raw_parts` requires a
        // non-null pointer even for a zero length, so bail out before touching it.
        if raw_values.is_null() {
            return Ok(Vec::new());
        }
        let values: Vec<Value<'_>> = {
            // SAFETY: the call above succeeded and produced a non-null array of `count`
            // handles allocated with `allocator`; it stays valid until freed below and
            // the slice does not escape this block.
            let handles = unsafe { from_raw_parts(raw_values, count) };
            handles
                .iter()
                .map(|&handle| Value::new(handle, allocator))
                .collect()
        };
        // The pointer array (not the values it points to) is owned by the caller and
        // must go back to the allocator that produced it, otherwise it leaks per call.
        ok!(AllocatorFree, allocator.inner(), raw_values.cast())?;
        Ok(values)
    }

    /// Returns the raw `OrtIoBinding` handle.
    ///
    /// Ownership stays with `self`, which releases the handle on drop. The caller must not
    /// release it, and it is only valid while `self` is alive.
    pub fn inner(&self) -> *mut ffi::OrtIoBinding {
        self.inner
    }
}