// ort2 0.1.2
//
// ONNX Runtime wrapper around the C/C++ API.
// Documentation
use std::{
    ffi::{c_void, CStr, CString},
    ptr::{null, null_mut},
};

use ort2_sys::{self as ffi, GraphOptimizationLevel};
use smart_default::SmartDefault;
use tracing::trace;

#[cfg(target_os = "windows")]
use widestring::Utf16String;

use crate::{
    allocator::{Allocator, AllocatorTrait, DefaultAllocator},
    api::{api, ok},
    environment::Environment,
    error::Result,
    iobinding::IoBinding,
    memory::MemoryInfo,
    value::{TypeInfo, Value},
};

/// Name and type description of a single model input or output.
///
/// Values of this type are produced by `Session::get_inputs` and
/// `Session::get_outputs`.
#[derive(Debug)]
pub struct IoTypeInfo {
    /// Type information wrapper built from the handle returned by
    /// `SessionGetInputTypeInfo` / `SessionGetOutputTypeInfo`.
    pub typ: TypeInfo,
    /// Owned copy of the input/output name as a C string.
    pub name: CString,
}

/// Owning wrapper around a raw `OrtSession` handle.
///
/// The handle is released with `ReleaseSession` in the `Drop` implementation.
pub struct Session {
    // Raw session handle; filled in by `CreateSession` / `CreateSessionFromArray`
    // in `SessionBuilder::build`.
    inner: *mut ffi::OrtSession,
    // Environment the session was created with. Stored here so it is only
    // dropped after `Drop::drop` for the session has run.
    env: Environment,
}

/// Debug output lists the raw handle, the environment and the model's
/// input/output descriptions.
///
/// Querying inputs and outputs goes through the runtime and may fail. A failure
/// is rendered inside the corresponding field instead of panicking, because this
/// impl is also reached from `Drop` (via `trace!(?self, ...)`), where a panic
/// during unwinding would abort the process.
impl std::fmt::Debug for Session {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Session");
        out.field("inner", &self.inner).field("env", &self.env);

        match self.get_inputs() {
            Ok(inputs) => out.field("inputs", &inputs),
            Err(err) => out.field("inputs", &format_args!("<unavailable: {:?}>", err)),
        };

        match self.get_outputs() {
            Ok(outputs) => out.field("outputs", &outputs),
            Err(err) => out.field("outputs", &format_args!("<unavailable: {:?}>", err)),
        };

        out.finish()
    }
}

/// Releases the underlying `OrtSession` when the wrapper goes out of scope.
impl Drop for Session {
    fn drop(&mut self) {
        // Log first: formatting `self` still needs the live handle.
        trace!(?self, "dropping");

        let handle = self.inner;
        api!(ReleaseSession, handle);
    }
}

impl Session {
    /// Returns the raw `OrtSession` handle wrapped by this session.
    ///
    /// The handle stays owned by `self` (it is released in `Drop`), so callers
    /// must not release it themselves.
    pub fn inner(&self) -> *mut ffi::OrtSession {
        self.inner
    }

    /// Creates an [`Allocator`] for this session and the given memory info.
    ///
    /// # Errors
    ///
    /// Propagates whatever `Allocator::new` reports.
    pub fn allocator(&self, mem_info: &MemoryInfo) -> Result<Allocator> {
        Allocator::new(self, mem_info)
    }

    /// Runs the model on `inputs` and returns every model output.
    ///
    /// Inputs are bound positionally: `inputs[i]` is passed under the name of the
    /// i-th entry of [`Session::get_inputs`]. All outputs listed by
    /// [`Session::get_outputs`] are requested, in that order.
    ///
    /// # Errors
    ///
    /// Returns an error if querying the input/output metadata or the `Run` call
    /// itself fails.
    ///
    /// # Panics
    ///
    /// Panics if more values are supplied than the model has inputs. Without this
    /// check the runtime would read past the end of the name array.
    pub fn run<'a, 'b, 'c>(&'c self, inputs: impl AsRef<[&'b Value<'a>]>) -> Result<Vec<Value<'c>>>
    where
        'a: 'b,
    {
        // `input_infos` owns the name strings; it must outlive the pointer vector.
        let input_infos = self.get_inputs()?;
        let input_names = input_infos
            .iter()
            .map(|t| t.name.as_ptr())
            .collect::<Vec<_>>();
        let inputs = inputs
            .as_ref()
            .iter()
            .map(|i| i.inner() as *const _)
            .collect::<Vec<_>>();

        // `Run` reads `inputs.len()` entries from BOTH arrays, so the name array
        // must be at least that long.
        assert!(
            inputs.len() <= input_names.len(),
            "session has {} input(s) but {} value(s) were supplied",
            input_names.len(),
            inputs.len()
        );

        let output_infos = self.get_outputs()?;
        let output_names = output_infos
            .iter()
            .map(|t| t.name.as_ptr())
            .collect::<Vec<_>>();

        // One slot per requested output; the runtime fills these in.
        let mut outputs = vec![null_mut(); output_names.len()];

        ok!(
            Run,
            self.inner,
            null(),
            input_names.as_ptr(),
            inputs.as_ptr(),
            inputs.len(),
            output_names.as_ptr(),
            output_names.len(),
            outputs.as_mut_ptr()
        )?;

        Ok(outputs
            .into_iter()
            .map(|inner| Value::new(inner, self))
            .collect())
    }

    /// Returns the name and type information of every model input, in index order.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the runtime while querying the count,
    /// a name, or a type info, or while freeing a name buffer.
    pub fn get_inputs(&self) -> Result<Vec<IoTypeInfo>> {
        let mut count = 0usize;
        ok!(SessionGetInputCount, self.inner, &mut count)?;
        let alloc = DefaultAllocator::default();
        (0..count)
            .map(|n| -> Result<_> {
                let mut raw_name = null_mut();
                ok!(
                    SessionGetInputName,
                    self.inner(),
                    n,
                    alloc.inner(),
                    &mut raw_name
                )?;
                // SAFETY: the call above succeeded, so per the ONNX Runtime C API
                // `raw_name` points to a NUL-terminated string allocated with `alloc`.
                let name = unsafe { CStr::from_ptr(raw_name) }.to_owned();
                // The string is owned by us and must be returned to the allocator
                // it came from; we keep only the owned copy made above.
                ok!(AllocatorFree, alloc.inner(), raw_name.cast::<c_void>())?;

                let mut inner = null_mut();
                ok!(SessionGetInputTypeInfo, self.inner, n, &mut inner)?;
                let typ = TypeInfo::new(inner);
                Ok(IoTypeInfo { name, typ })
            })
            .collect::<Result<Vec<_>>>()
    }

    /// Returns the name and type information of every model output, in index order.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the runtime while querying the count,
    /// a name, or a type info, or while freeing a name buffer.
    pub fn get_outputs(&self) -> Result<Vec<IoTypeInfo>> {
        let mut count = 0usize;
        ok!(SessionGetOutputCount, self.inner, &mut count)?;
        let alloc = DefaultAllocator::default();
        (0..count)
            .map(|n| -> Result<_> {
                let mut raw_name = null_mut();
                ok!(
                    SessionGetOutputName,
                    self.inner(),
                    n,
                    alloc.inner(),
                    &mut raw_name
                )?;
                // SAFETY: the call above succeeded, so per the ONNX Runtime C API
                // `raw_name` points to a NUL-terminated string allocated with `alloc`.
                let name = unsafe { CStr::from_ptr(raw_name) }.to_owned();
                // Hand the runtime-allocated buffer back now that it has been copied.
                ok!(AllocatorFree, alloc.inner(), raw_name.cast::<c_void>())?;

                let mut inner = null_mut();
                ok!(SessionGetOutputTypeInfo, self.inner, n, &mut inner)?;
                let typ = TypeInfo::new(inner);
                Ok(IoTypeInfo { name, typ })
            })
            .collect::<Result<Vec<_>>>()
    }

    /// Creates an [`IoBinding`] for this session.
    ///
    /// # Errors
    ///
    /// Propagates whatever `IoBinding::new` reports.
    pub fn iobinding(&self) -> Result<IoBinding> {
        IoBinding::new(self)
    }

    /// Runs the model using the inputs/outputs bound in `iobinding`.
    ///
    /// No run options are passed (a null pointer is supplied).
    ///
    /// # Errors
    ///
    /// Returns an error if `RunWithBinding` fails.
    pub fn run_with_iobinding(&self, iobinding: &mut IoBinding) -> Result<()> {
        ok!(RunWithBinding, self.inner, null_mut(), iobinding.inner())
    }
}

/// Builder collecting the options used to create a [`Session`].
///
/// Obtained through `Session::builder()`; defaults come from the `SmartDefault`
/// derive.
#[derive(SmartDefault)]
pub struct SessionBuilder {
    // Environment to create the session in; when `None`, `prepare` builds one
    // with `Environment::builder().build()`.
    env: Option<Environment>,
    // Whether to append the CUDA execution provider.
    use_cuda: bool,
    // Whether to append the TensorRT execution provider.
    use_tensor_rt: bool,
    // Device id passed as `device_id` to BOTH the CUDA and TensorRT providers;
    // `with_cuda` and `with_tensor_rt` write the same field.
    cuda_device: i32,
    // Graph optimization level; defaults to `ORT_DISABLE_ALL`.
    #[default(GraphOptimizationLevel::ORT_DISABLE_ALL)]
    graph_optimize_level: GraphOptimizationLevel,
    // Value handed to `SetIntraOpNumThreads`.
    intra_thread_num: i32,
}

/// Platform-specific owned string used for model file paths.
///
/// On Windows this is a UTF-16 string from the `widestring` crate.
#[cfg(target_os = "windows")]
pub type RawString = widestring::Utf16String;

/// Platform-specific owned string used for model file paths.
///
/// On non-Windows targets this is a NUL-terminated [`CString`].
#[cfg(not(target_os = "windows"))]
pub type RawString = CString;

/// Source of the model passed to `SessionBuilder::build`.
pub enum Model {
    /// Path of a model file, in the platform's [`RawString`] representation.
    File(RawString),
    /// In-memory model bytes described by a raw pointer and a length.
    ///
    /// NOTE(review): the pointer carries no lifetime, so nothing stops a `Model`
    /// from outliving the buffer it was created from. The buffer must stay alive
    /// and unmoved until `SessionBuilder::build` has returned.
    Binary { data: *const c_void, len: usize },
}

impl SessionBuilder {
    /// Consumes the builder and creates a [`Session`] for `model`.
    ///
    /// The session options created by [`SessionBuilder::prepare`] are released
    /// before returning, whether or not session creation succeeded.
    ///
    /// # Errors
    ///
    /// Returns an error if preparing the options fails or if
    /// `CreateSession` / `CreateSessionFromArray` fails.
    pub fn build(self, model: impl Into<Model>) -> Result<Session> {
        let (env, options) = self.prepare()?;
        let mut inner = null_mut();
        let model: Model = model.into();

        // Defer `?` until the options have been released so they do not leak when
        // session creation fails.
        let created = match model {
            Model::File(file) => ok!(CreateSession, env.inner, file.as_ptr(), options, &mut inner),
            Model::Binary { data, len } => ok!(
                CreateSessionFromArray,
                env.inner,
                data,
                len,
                options,
                &mut inner
            ),
        };
        api!(ReleaseSessionOptions, options);
        created?;

        Ok(Session { inner, env })
    }

    /// Creates the session options described by this builder.
    ///
    /// Returns the environment to use (the one supplied through
    /// [`SessionBuilder::with_environment`], or a freshly built one) together with
    /// a raw `OrtSessionOptions` handle. On success the caller owns the handle and
    /// must release it with `ReleaseSessionOptions`.
    ///
    /// # Errors
    ///
    /// Returns the first error reported while creating or configuring the
    /// options, or while building the default environment. The options handle is
    /// released before the error is returned.
    pub fn prepare(self) -> Result<(Environment, *mut ffi::OrtSessionOptions)> {
        let mut options = null_mut();
        ok!(CreateSessionOptions, &mut options)?;

        match self.apply_settings(options) {
            Ok(env) => Ok((env, options)),
            Err(err) => {
                // Nobody else has seen `options` yet, so release it here.
                api!(ReleaseSessionOptions, options);
                Err(err)
            }
        }
    }

    /// Applies every builder setting to `options` and resolves the environment.
    fn apply_settings(self, options: *mut ffi::OrtSessionOptions) -> Result<Environment> {
        let Self {
            env,
            use_cuda,
            use_tensor_rt,
            cuda_device,
            graph_optimize_level,
            intra_thread_num,
        } = self;

        if use_cuda {
            Self::append_cuda_provider(options, cuda_device)?;
        }

        if use_tensor_rt {
            Self::append_tensor_rt_provider(options, cuda_device)?;
        }

        ok!(SetIntraOpNumThreads, options, intra_thread_num)?;
        ok!(
            SetSessionGraphOptimizationLevel,
            options,
            graph_optimize_level
        )?;

        let env = match env {
            Some(e) => e,
            None => Environment::builder().build()?,
        };

        Ok(env)
    }

    /// Appends the CUDA execution provider for `device` to `options`.
    fn append_cuda_provider(options: *mut ffi::OrtSessionOptions, device: i32) -> Result<()> {
        let mut provider = null_mut();
        ok!(CreateCUDAProviderOptions, &mut provider)?;

        // Run the fallible steps in a closure so the provider options are
        // released on every path.
        let outcome = (|| -> Result<()> {
            let keys = [CString::new("device_id")?];
            let values = [CString::new(device.to_string())?];
            let key_ptrs = keys.iter().map(|k| k.as_ptr()).collect::<Vec<_>>();
            let value_ptrs = values.iter().map(|v| v.as_ptr()).collect::<Vec<_>>();

            ok!(
                UpdateCUDAProviderOptions,
                provider,
                key_ptrs.as_ptr(),
                value_ptrs.as_ptr(),
                key_ptrs.len()
            )?;

            ok!(
                SessionOptionsAppendExecutionProvider_CUDA_V2,
                options,
                provider
            )
        })();

        api!(ReleaseCUDAProviderOptions, provider);
        outcome
    }

    /// Appends the TensorRT execution provider for `device` to `options`.
    fn append_tensor_rt_provider(options: *mut ffi::OrtSessionOptions, device: i32) -> Result<()> {
        let mut provider = null_mut();
        ok!(CreateTensorRTProviderOptions, &mut provider)?;

        // Run the fallible steps in a closure so the provider options are
        // released on every path.
        let outcome = (|| -> Result<()> {
            let keys = [CString::new("device_id")?];
            let values = [CString::new(device.to_string())?];
            let key_ptrs = keys.iter().map(|k| k.as_ptr()).collect::<Vec<_>>();
            let value_ptrs = values.iter().map(|v| v.as_ptr()).collect::<Vec<_>>();

            ok!(
                UpdateTensorRTProviderOptions,
                provider,
                key_ptrs.as_ptr(),
                value_ptrs.as_ptr(),
                key_ptrs.len()
            )?;

            ok!(
                SessionOptionsAppendExecutionProvider_TensorRT_V2,
                options,
                provider
            )
        })();

        api!(ReleaseTensorRTProviderOptions, provider);
        outcome
    }

    /// Enables the CUDA execution provider on `device`.
    ///
    /// The device id is shared with [`SessionBuilder::with_tensor_rt`]; whichever
    /// is called last decides the id used by both providers.
    pub fn with_cuda(mut self, device: i32) -> Self {
        self.use_cuda = true;
        self.cuda_device = device;
        self
    }

    /// Sets the graph optimization level (defaults to `ORT_DISABLE_ALL`).
    pub fn with_graph_optimize_level(mut self, level: GraphOptimizationLevel) -> Self {
        self.graph_optimize_level = level;
        self
    }

    /// Sets the value passed to `SetIntraOpNumThreads`.
    pub fn with_intra_thread_num(mut self, n_threads: i32) -> Self {
        self.intra_thread_num = n_threads;
        self
    }

    /// Enables the TensorRT execution provider on `device`.
    ///
    /// The device id is shared with [`SessionBuilder::with_cuda`]; whichever is
    /// called last decides the id used by both providers.
    pub fn with_tensor_rt(mut self, device: i32) -> Self {
        self.use_tensor_rt = true;
        self.cuda_device = device;
        self
    }

    /// Uses `env` for the session instead of building a default environment.
    pub fn with_environment(mut self, env: Environment) -> Self {
        self.env = Some(env);
        self
    }
}

impl From<&[u8]> for Model {
    fn from(value: &[u8]) -> Self {
        Model::Binary {
            data: value.as_ptr() as *const _,
            len: value.len(),
        }
    }
}

/// Converts a model path into the platform's [`RawString`] representation.
///
/// # Panics
///
/// Panics if `value` contains an interior NUL character.
#[cfg(not(target_os = "windows"))]
fn to_raw_string(value: &str) -> RawString {
    CString::new(value).expect("failed to get model filename")
}

/// Converts a model path into the platform's [`RawString`] representation.
///
/// `Utf16String` is not NUL-terminated on its own, but the buffer is later passed
/// by pointer to the runtime as a C wide string, so a terminator is appended.
///
/// # Panics
///
/// Panics if `value` contains an interior NUL character, matching the behaviour
/// on non-Windows targets.
#[cfg(target_os = "windows")]
fn to_raw_string(value: &str) -> RawString {
    assert!(!value.contains('\0'), "failed to get model filename");
    let mut wide = Utf16String::from_str(value);
    wide.push('\0');
    wide
}

/// Treats `value` as a model file path and converts it with `to_raw_string`.
impl From<&str> for Model {
    fn from(value: &str) -> Self {
        Model::File(to_raw_string(value))
    }
}

impl Session {
    pub fn builder() -> SessionBuilder {
        SessionBuilder::default()
    }
}

// Smoke test: loads the bundled MNIST model from memory, feeds it an all-zero
// 1x1x28x28 float tensor and logs the resulting output views.
#[cfg(test)]
#[test]
fn test_session_ok() -> Result<()> {
    use ffi::ONNXTensorElementDataType;

    std::env::set_var("RUST_LOG", "trace");
    // Another test may already have installed a subscriber; ignore that case.
    let _ = tracing_subscriber::fmt::try_init();

    let env = Environment::builder()
        .with_level(ort2_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE)
        .build()?;
    let model_bytes: &[u8] = include_bytes!("../models/mnist-8.onnx");
    let session = Session::builder()
        // .with_cuda(0)
        .with_environment(env)
        .build(model_bytes)?;

    let pixels = vec![0.0f32; 28 * 28];
    let tensor = Value::tensor()
        .with_shape([1, 1, 28, 28])
        .with_typ(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
        .borrow(&pixels)?;

    let raw_outputs = session.run([&tensor])?;
    let outputs = raw_outputs
        .iter()
        .map(|o| o.view::<f32>())
        .collect::<Result<Vec<_>>>()?;
    tracing::info!(?outputs);

    Ok(())
}