use std::{
ffi::{c_void, CStr, CString},
ptr::{null, null_mut},
};
use ort2_sys::{self as ffi, GraphOptimizationLevel};
use smart_default::SmartDefault;
use tracing::trace;
#[cfg(target_os = "windows")]
use widestring::Utf16String;
use crate::{
allocator::{Allocator, AllocatorTrait, DefaultAllocator},
api::{api, ok},
environment::Environment,
error::Result,
iobinding::IoBinding,
memory::MemoryInfo,
value::{TypeInfo, Value},
};
/// Name and type description of a single model input or output, as reported
/// by [`Session::get_inputs`] / [`Session::get_outputs`].
#[derive(Debug)]
pub struct IoTypeInfo {
/// Type information for this input/output.
pub typ: TypeInfo,
/// Name of this input/output, stored as an owned C string so it can be
/// passed back to the runtime by pointer.
pub name: CString,
}
/// A loaded model session.
///
/// Wraps the raw `OrtSession` pointer; the pointer is released when the
/// `Session` is dropped.
pub struct Session {
// Raw session handle; released in `Drop`.
inner: *mut ffi::OrtSession,
// Environment the session was created with.
// NOTE(review): presumably held so the environment outlives the session — confirm.
env: Environment,
}
impl std::fmt::Debug for Session {
    /// Formats the session for diagnostics.
    ///
    /// Besides the raw pointer and environment, the input and output metadata
    /// are queried from the runtime and included. If a query fails, the error
    /// is rendered in place of the list instead of panicking: this formatter
    /// is also invoked from `Drop` (through `trace!(?self, ..)`), where a
    /// panic could abort the process.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Session");
        out.field("inner", &self.inner).field("env", &self.env);

        match self.get_inputs() {
            Ok(inputs) => {
                out.field("inputs", &inputs);
            }
            Err(err) => {
                out.field("inputs", &format_args!("<unavailable: {:?}>", err));
            }
        }

        match self.get_outputs() {
            Ok(outputs) => {
                out.field("outputs", &outputs);
            }
            Err(err) => {
                out.field("outputs", &format_args!("<unavailable: {:?}>", err));
            }
        }

        out.finish()
    }
}
impl Drop for Session {
    /// Logs the session at trace level, then releases the underlying
    /// `OrtSession` handle.
    fn drop(&mut self) {
        // Log first: the Debug output still needs the live handle.
        trace!(?self, "dropping");
        let raw = self.inner;
        api!(ReleaseSession, raw);
    }
}
impl Session {
    /// Returns the raw `OrtSession` pointer wrapped by this session.
    ///
    /// The pointer stays owned by the `Session`; it is released on drop.
    pub fn inner(&self) -> *mut ffi::OrtSession {
        self.inner
    }

    /// Creates an [`Allocator`] for this session using the given memory info.
    ///
    /// # Errors
    /// Propagates any error returned by `Allocator::new`.
    pub fn allocator(&self, mem_info: &MemoryInfo) -> Result<Allocator> {
        Allocator::new(self, mem_info)
    }

    /// Runs the model on `inputs` and returns every model output.
    ///
    /// Input values are matched to the model's inputs by position, using the
    /// order reported by [`Session::get_inputs`]. All outputs reported by
    /// [`Session::get_outputs`] are requested.
    ///
    /// # Errors
    /// Returns an error if querying the input/output metadata fails or if the
    /// runtime rejects the run.
    ///
    /// # Panics
    /// Panics if more values are supplied than the model has inputs. The
    /// runtime reads one name per supplied value, so an excess value would
    /// make it read past the end of the name array.
    pub fn run<'a, 'b, 'c>(&'c self, inputs: impl AsRef<[&'b Value<'a>]>) -> Result<Vec<Value<'c>>>
    where
        'a: 'b,
    {
        let input_infos = self.get_inputs()?;
        let input_names = input_infos
            .iter()
            .map(|t| t.name.as_ptr())
            .collect::<Vec<_>>();
        let inputs = inputs
            .as_ref()
            .iter()
            .map(|i| i.inner() as *const _)
            .collect::<Vec<_>>();
        assert!(
            inputs.len() <= input_names.len(),
            "got {} input values but the model only has {} inputs",
            inputs.len(),
            input_names.len()
        );

        let output_infos = self.get_outputs()?;
        let output_names = output_infos
            .iter()
            .map(|t| t.name.as_ptr())
            .collect::<Vec<_>>();
        // The runtime fills each null slot with a newly created output value.
        let mut outputs = vec![null_mut(); output_names.len()];
        ok!(
            Run,
            self.inner,
            null(),
            input_names.as_ptr(),
            inputs.as_ptr(),
            inputs.len(),
            output_names.as_ptr(),
            output_names.len(),
            outputs.as_mut_ptr()
        )?;
        Ok(outputs
            .into_iter()
            .map(|inner| Value::new(inner, self))
            .collect())
    }

    /// Returns the name and type information of every model input, in the
    /// runtime's index order.
    ///
    /// # Errors
    /// Returns an error if any runtime query fails.
    pub fn get_inputs(&self) -> Result<Vec<IoTypeInfo>> {
        let mut count = 0usize;
        ok!(SessionGetInputCount, self.inner, &mut count)?;
        let alloc = DefaultAllocator::default();
        (0..count)
            .map(|n| -> Result<_> {
                let mut name_ = null_mut();
                ok!(
                    SessionGetInputName,
                    self.inner(),
                    n,
                    alloc.inner(),
                    &mut name_
                )?;
                // SAFETY: the call above succeeded, so `name_` points to a
                // NUL-terminated string allocated with `alloc`.
                let name = unsafe { CStr::from_ptr(name_) }.to_owned();
                // The name buffer is owned by the caller; hand it back to the
                // allocator now that it has been copied, otherwise it leaks
                // on every call.
                ok!(AllocatorFree, alloc.inner(), name_ as *mut c_void)?;
                let mut inner = null_mut();
                ok!(SessionGetInputTypeInfo, self.inner, n, &mut inner)?;
                let typ = TypeInfo::new(inner);
                Ok(IoTypeInfo { name, typ })
            })
            .collect::<Result<Vec<_>>>()
    }

    /// Returns the name and type information of every model output, in the
    /// runtime's index order.
    ///
    /// # Errors
    /// Returns an error if any runtime query fails.
    pub fn get_outputs(&self) -> Result<Vec<IoTypeInfo>> {
        let mut count = 0usize;
        ok!(SessionGetOutputCount, self.inner, &mut count)?;
        let alloc = DefaultAllocator::default();
        (0..count)
            .map(|n| -> Result<_> {
                let mut name_ = null_mut();
                ok!(
                    SessionGetOutputName,
                    self.inner(),
                    n,
                    alloc.inner(),
                    &mut name_
                )?;
                // SAFETY: the call above succeeded, so `name_` points to a
                // NUL-terminated string allocated with `alloc`.
                let name = unsafe { CStr::from_ptr(name_) }.to_owned();
                // Copied above; release the runtime-allocated buffer so it
                // does not leak.
                ok!(AllocatorFree, alloc.inner(), name_ as *mut c_void)?;
                let mut inner = null_mut();
                ok!(SessionGetOutputTypeInfo, self.inner, n, &mut inner)?;
                let typ = TypeInfo::new(inner);
                Ok(IoTypeInfo { name, typ })
            })
            .collect::<Result<Vec<_>>>()
    }

    /// Creates a new [`IoBinding`] for this session.
    ///
    /// # Errors
    /// Propagates any error returned by `IoBinding::new`.
    pub fn iobinding(&self) -> Result<IoBinding> {
        IoBinding::new(self)
    }

    /// Runs the model using the inputs and outputs bound in `iobinding`,
    /// with no run options.
    ///
    /// # Errors
    /// Returns an error if the runtime rejects the run.
    pub fn run_with_iobinding(&self, iobinding: &mut IoBinding) -> Result<()> {
        ok!(RunWithBinding, self.inner, null_mut(), iobinding.inner())
    }
}
/// Builder collecting the options used to create a [`Session`].
///
/// Obtain one with `Session::builder()`, adjust it with the `with_*` methods
/// and finish with `build`.
#[derive(SmartDefault)]
pub struct SessionBuilder {
// Environment to create the session in; when `None`, `prepare` builds a
// fresh one with `Environment::builder()`.
env: Option<Environment>,
// Whether to append the CUDA execution provider.
use_cuda: bool,
// Whether to append the TensorRT execution provider.
use_tensor_rt: bool,
// Device id passed as `device_id` to both the CUDA and TensorRT providers.
cuda_device: i32,
// Graph optimization level; the builder default disables all optimizations.
#[default(GraphOptimizationLevel::ORT_DISABLE_ALL)]
graph_optimize_level: GraphOptimizationLevel,
// Value passed to `SetIntraOpNumThreads`.
// NOTE(review): presumably 0 lets the runtime pick a default — confirm.
intra_thread_num: i32,
}
/// Platform string type used for model file paths: a UTF-16 string on
/// Windows.
#[cfg(target_os = "windows")]
pub type RawString = widestring::Utf16String;
/// Platform string type used for model file paths: a C string everywhere
/// except Windows.
#[cfg(not(target_os = "windows"))]
pub type RawString = CString;
/// Source a session's model is loaded from.
pub enum Model {
/// Path to a model file; its pointer is passed to `CreateSession`.
File(RawString),
/// In-memory model bytes, passed to `CreateSessionFromArray`.
///
/// Only a raw pointer and length are stored, so the referenced buffer must
/// stay alive and unchanged until the session has been built.
Binary { data: *const c_void, len: usize },
}
impl SessionBuilder {
    /// Creates a [`Session`] for `model` using the options collected so far.
    ///
    /// The temporary session options are released whether or not session
    /// creation succeeds.
    ///
    /// # Errors
    /// Returns an error if preparing the options or creating the session
    /// fails.
    pub fn build(self, model: impl Into<Model>) -> Result<Session> {
        let (env, options) = self.prepare()?;
        let mut inner = null_mut();
        let created = match model.into() {
            Model::File(file) => {
                ok!(CreateSession, env.inner, file.as_ptr(), options, &mut inner)
            }
            Model::Binary { data, len } => ok!(
                CreateSessionFromArray,
                env.inner,
                data,
                len,
                options,
                &mut inner
            ),
        };
        // Release the options before propagating a failure so the error path
        // does not leak them.
        api!(ReleaseSessionOptions, options);
        created?;
        Ok(Session { inner, env })
    }

    /// Turns the builder into an environment plus a configured raw
    /// `OrtSessionOptions` pointer.
    ///
    /// On success the caller owns the returned options pointer and must
    /// release it with `ReleaseSessionOptions`. On failure the options are
    /// released here before the error is returned.
    ///
    /// When both TensorRT and CUDA are enabled, TensorRT is registered first
    /// so it takes priority and CUDA serves as its fallback.
    ///
    /// # Errors
    /// Returns an error if any runtime call fails or if no environment was
    /// supplied and building a default one fails.
    pub fn prepare(self) -> Result<(Environment, *mut ffi::OrtSessionOptions)> {
        let Self {
            env,
            use_cuda,
            use_tensor_rt,
            cuda_device,
            graph_optimize_level,
            intra_thread_num,
        } = self;
        let mut options = null_mut();
        ok!(CreateSessionOptions, &mut options)?;

        // Run every fallible step inside one closure so a single cleanup
        // point below can release `options` on any failure.
        let configured = (|| -> Result<Environment> {
            if use_tensor_rt {
                Self::append_tensor_rt_provider(options, cuda_device)?;
            }
            if use_cuda {
                Self::append_cuda_provider(options, cuda_device)?;
            }
            ok!(SetIntraOpNumThreads, options, intra_thread_num)?;
            ok!(
                SetSessionGraphOptimizationLevel,
                options,
                graph_optimize_level
            )?;
            match env {
                Some(e) => Ok(e),
                None => Ok(Environment::builder().build()?),
            }
        })();

        match configured {
            Ok(env) => Ok((env, options)),
            Err(err) => {
                api!(ReleaseSessionOptions, options);
                Err(err)
            }
        }
    }

    /// Appends the CUDA execution provider for `device` to `options`,
    /// releasing the temporary provider options on every path.
    fn append_cuda_provider(options: *mut ffi::OrtSessionOptions, device: i32) -> Result<()> {
        let mut provider = null_mut();
        ok!(CreateCUDAProviderOptions, &mut provider)?;
        let appended = (|| -> Result<()> {
            let key = CString::new("device_id")?;
            let value = CString::new(device.to_string())?;
            let keys = [key.as_ptr()];
            let values = [value.as_ptr()];
            ok!(
                UpdateCUDAProviderOptions,
                provider,
                keys.as_ptr(),
                values.as_ptr(),
                keys.len()
            )?;
            ok!(
                SessionOptionsAppendExecutionProvider_CUDA_V2,
                options,
                provider
            )
        })();
        api!(ReleaseCUDAProviderOptions, provider);
        appended
    }

    /// Appends the TensorRT execution provider for `device` to `options`,
    /// releasing the temporary provider options on every path.
    fn append_tensor_rt_provider(
        options: *mut ffi::OrtSessionOptions,
        device: i32,
    ) -> Result<()> {
        let mut provider = null_mut();
        ok!(CreateTensorRTProviderOptions, &mut provider)?;
        let appended = (|| -> Result<()> {
            let key = CString::new("device_id")?;
            let value = CString::new(device.to_string())?;
            let keys = [key.as_ptr()];
            let values = [value.as_ptr()];
            ok!(
                UpdateTensorRTProviderOptions,
                provider,
                keys.as_ptr(),
                values.as_ptr(),
                keys.len()
            )?;
            ok!(
                SessionOptionsAppendExecutionProvider_TensorRT_V2,
                options,
                provider
            )
        })();
        api!(ReleaseTensorRTProviderOptions, provider);
        appended
    }

    /// Enables the CUDA execution provider on `device`.
    ///
    /// The device id is shared with the TensorRT provider; whichever of
    /// `with_cuda` / `with_tensor_rt` is called last decides it.
    pub fn with_cuda(mut self, device: i32) -> Self {
        self.use_cuda = true;
        self.cuda_device = device;
        self
    }

    /// Sets the graph optimization level applied when the session is created.
    pub fn with_graph_optimize_level(mut self, level: GraphOptimizationLevel) -> Self {
        self.graph_optimize_level = level;
        self
    }

    /// Sets the value passed to `SetIntraOpNumThreads`.
    pub fn with_intra_thread_num(mut self, n_threads: i32) -> Self {
        self.intra_thread_num = n_threads;
        self
    }

    /// Enables the TensorRT execution provider on `device`.
    ///
    /// The device id is shared with the CUDA provider; whichever of
    /// `with_cuda` / `with_tensor_rt` is called last decides it.
    pub fn with_tensor_rt(mut self, device: i32) -> Self {
        self.use_tensor_rt = true;
        self.cuda_device = device;
        self
    }

    /// Uses `env` for the session instead of building a default environment.
    pub fn with_environment(mut self, env: Environment) -> Self {
        self.env = Some(env);
        self
    }
}
impl From<&[u8]> for Model {
    /// Wraps in-memory model bytes as [`Model::Binary`].
    ///
    /// Only the slice's pointer and length are recorded, so the bytes must
    /// remain valid until the session has been built from this value.
    fn from(value: &[u8]) -> Self {
        let len = value.len();
        let data = value.as_ptr() as *const _;
        Model::Binary { data, len }
    }
}
/// Converts a model path into the platform string type handed to the runtime.
///
/// On Windows the UTF-16 string gets a trailing NUL appended: `Utf16String`
/// is not NUL-terminated by itself, and its pointer is later passed to the
/// runtime as a C wide string.
///
/// # Panics
/// On non-Windows targets, panics if `value` contains an interior NUL byte.
fn to_raw_string(value: &str) -> RawString {
    #[cfg(not(target_os = "windows"))]
    let raw = CString::new(value).expect("failed to get model filename");
    #[cfg(target_os = "windows")]
    let raw = {
        let mut wide = Utf16String::from_str(value);
        wide.push('\0');
        wide
    };
    raw
}
impl From<&str> for Model {
    /// Treats `value` as a model file path and wraps it as [`Model::File`].
    fn from(value: &str) -> Self {
        Model::File(to_raw_string(value))
    }
}
impl Session {
pub fn builder() -> SessionBuilder {
SessionBuilder::default()
}
}
/// Smoke test: loads the bundled MNIST model, runs it on an all-zero image
/// and logs the resulting output views.
#[cfg(test)]
#[test]
fn test_session_ok() -> Result<()> {
    use ffi::ONNXTensorElementDataType;

    std::env::set_var("RUST_LOG", "trace");
    // Ignore the error when a subscriber is already installed.
    let _ = tracing_subscriber::fmt::try_init();

    let env = Environment::builder()
        .with_level(ort2_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE)
        .build()?;
    let model_bytes: &[u8] = include_bytes!("../models/mnist-8.onnx");
    let session = Session::builder()
        .with_environment(env)
        .build(model_bytes)?;

    let pixels = vec![0.0f32; 28 * 28];
    let tensor = Value::tensor()
        .with_shape([1, 1, 28, 28])
        .with_typ(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
        .borrow(&pixels)?;

    let raw_outputs = session.run([&tensor])?;
    let mut outputs = Vec::with_capacity(raw_outputs.len());
    for value in &raw_outputs {
        outputs.push(value.view::<f32>()?);
    }
    tracing::info!(?outputs);
    Ok(())
}