ort2 0.1.2

ONNX Runtime wrapper for the C/C++ API
Documentation
use std::{ffi::CString, marker::PhantomData, ptr::null_mut, slice::from_raw_parts};

use ort2_sys as ffi;

use crate::{
    api::{api, ok},
    error::Result,
    memory::MemoryInfo,
    prelude::AllocatorTrait,
    session::Session,
    value::Value,
};

/// Prelude for the I/O-binding module.
pub mod prelude {
    // Currently empty: no items from this module are re-exported here yet.
}

/// Owning wrapper around an ONNX Runtime `OrtIoBinding` handle.
///
/// The `'a` lifetime records the borrow of the session the binding was created from,
/// so the binding cannot outlive that session.
#[derive(Debug)]
pub struct IoBinding<'a> {
    /// Raw ONNX Runtime handle owned by this wrapper and released when it is dropped.
    inner: *mut ffi::OrtIoBinding,
    /// Zero-sized marker carrying the `'a` session borrow.
    marker: PhantomData<&'a ()>,
}

/// Hands the owned `OrtIoBinding` handle back to ONNX Runtime via `ReleaseIoBinding`.
impl Drop for IoBinding<'_> {
    fn drop(&mut self) {
        // Within this file `inner` is only populated by `new`, after `CreateIoBinding`
        // has succeeded, so the handle being released here is a live one.
        let handle = self.inner;
        api!(ReleaseIoBinding, handle);
    }
}

impl<'a> IoBinding<'a> {
    /// Creates a new I/O binding for `session` via `CreateIoBinding`.
    ///
    /// The returned binding borrows `session` for `'a` and therefore cannot outlive it.
    ///
    /// # Errors
    /// Returns an error if ONNX Runtime fails to create the binding.
    pub fn new(session: &'a Session) -> Result<Self> {
        let mut inner = null_mut();
        ok!(CreateIoBinding, session.inner(), &mut inner)?;
        Ok(Self {
            inner,
            marker: PhantomData,
        })
    }

    /// Removes every input previously attached with [`bind_input`](Self::bind_input).
    pub fn clear_inputs(&mut self) {
        api!(ClearBoundInputs, self.inner)
    }

    /// Removes every output previously attached with [`bind_output`](Self::bind_output)
    /// or [`bind_output_to_device`](Self::bind_output_to_device).
    pub fn clear_outputs(&mut self) {
        api!(ClearBoundOutputs, self.inner)
    }

    /// Binds `value` to the model input called `name`.
    ///
    /// NOTE(review): nothing here ties `value`'s lifetime to the binding; if ONNX Runtime
    /// keeps referring to the value's data after this call, callers must keep it alive
    /// until the bound run has finished — confirm.
    ///
    /// # Errors
    /// Returns an error if ONNX Runtime rejects the binding.
    pub fn bind_input(&mut self, name: &CString, value: &Value) -> Result<()> {
        ok!(BindInput, self.inner, name.as_ptr(), value.inner())
    }

    /// Binds the pre-allocated `value` to the model output called `name`.
    ///
    /// NOTE(review): as with [`bind_input`](Self::bind_input), `value`'s lifetime is not
    /// tied to the binding — confirm callers keep it alive while it is bound.
    ///
    /// # Errors
    /// Returns an error if ONNX Runtime rejects the binding.
    pub fn bind_output(&mut self, name: &CString, value: &Value) -> Result<()> {
        ok!(BindOutput, self.inner, name.as_ptr(), value.inner())
    }

    /// Binds the model output called `name` to the device described by `mem_info`,
    /// letting ONNX Runtime allocate the output there (see `BindOutputToDevice`).
    ///
    /// # Errors
    /// Returns an error if ONNX Runtime rejects the binding.
    pub fn bind_output_to_device(&mut self, name: &CString, mem_info: &MemoryInfo) -> Result<()> {
        ok!(
            BindOutputToDevice,
            self.inner,
            name.as_ptr(),
            mem_info.inner()
        )
    }

    /// Fetches the output values currently held by the binding via `GetBoundOutputValues`.
    ///
    /// ONNX Runtime returns the outputs as a temporary array of `OrtValue` pointers
    /// allocated with `allocator`. Each pointer is wrapped in a [`Value`], and the array
    /// itself is then freed through the same allocator. An empty `Vec` is returned when
    /// nothing is bound.
    ///
    /// # Errors
    /// Returns an error if fetching the outputs or freeing the temporary array fails.
    pub fn get_bound_outputs(&self, allocator: &impl AllocatorTrait) -> Result<Vec<Value<'_>>> {
        let mut count = 0usize;
        let mut inners = null_mut();
        ok!(
            GetBoundOutputValues,
            self.inner,
            allocator.inner(),
            &mut inners,
            &mut count
        )?;
        // With no outputs ORT hands back a null array. `from_raw_parts` requires a
        // non-null pointer even for a zero-length slice, so return before touching it.
        if inners.is_null() {
            return Ok(Vec::new());
        }
        // SAFETY: on success ORT stores `count` valid `OrtValue` pointers at `inners`,
        // which was checked to be non-null above. The slice is only read within this
        // statement, before the array is freed below.
        let values: Vec<_> = unsafe { from_raw_parts(inners, count) }
            .iter()
            .map(|&ptr| Value::new(ptr, allocator))
            .collect();
        // The pointer array (not the values it points to) belongs to the caller and must
        // be released through the allocator that produced it; leaving it leaks memory.
        ok!(
            AllocatorFree,
            allocator.inner(),
            inners.cast::<std::ffi::c_void>()
        )?;
        Ok(values)
    }

    /// Returns the raw `OrtIoBinding` handle.
    ///
    /// The handle stays owned by this wrapper and is released when it is dropped, so
    /// callers must not release it themselves.
    pub fn inner(&self) -> *mut ffi::OrtIoBinding {
        self.inner
    }
}