#![allow(clippy::tabs_in_doc_comments)]
#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
#[cfg(target_family = "windows")]
use std::os::windows::ffi::OsStrExt;
#[cfg(feature = "fetch-models")]
use std::{env, path::PathBuf, time::Duration};
use std::{
ffi::CString,
fmt::{self, Debug},
marker::PhantomData,
ops::Deref,
os::raw::c_char,
path::Path,
sync::Arc
};
use tracing::error;
use super::{
char_p_to_string,
environment::Environment,
error::{assert_non_null_pointer, assert_null_pointer, status_to_result, OrtApiError, OrtError, OrtResult},
execution_providers::{apply_execution_providers, ExecutionProvider},
extern_system_fn,
metadata::Metadata,
ort, ortsys, sys,
tensor::TensorElementDataType,
AllocatorType, GraphOptimizationLevel, MemType
};
#[cfg(feature = "fetch-models")]
use super::{download::ModelUrl, error::OrtDownloadError};
use crate::{io_binding::IoBinding, value::Value};
/// Builder for [`Session`]s.
///
/// Wraps a native `OrtSessionOptions` plus crate-level settings. It is consumed by one of the
/// `with_model_*` methods, which create the session from the accumulated options.
pub struct SessionBuilder {
/// Environment sessions are created in; a clone of this `Arc` is stored in each built `Session`.
env: Arc<Environment>,
/// Owned native session options (from `CreateSessionOptions`); released in `Drop`.
session_options_ptr: *mut sys::OrtSessionOptions,
/// Set by `with_allocator`. NOTE(review): only stored/printed — no session-creation path in
/// this file reads it.
allocator: AllocatorType,
/// Set by `with_memory_type`. NOTE(review): only stored/printed, like `allocator`.
memory_type: MemType,
/// Library handles returned by `RegisterCustomOpsLibrary`; closed in `Drop`.
custom_runtime_handles: Vec<*mut std::os::raw::c_void>,
/// Builder-level execution providers, applied ahead of the environment's providers when a
/// session is created.
execution_providers: Vec<ExecutionProvider>
}
impl Debug for SessionBuilder {
	/// Prints the environment name and the recorded allocator/memory settings.
	///
	/// Raw pointers, library handles and execution providers are intentionally left out.
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let mut out = f.debug_struct("SessionBuilder");
		out.field("env", &self.env.name());
		out.field("allocator", &self.allocator);
		out.field("memory_type", &self.memory_type);
		out.finish()
	}
}
impl Drop for SessionBuilder {
#[tracing::instrument]
fn drop(&mut self) {
for &handle in self.custom_runtime_handles.iter() {
close_lib_handle(handle);
}
if self.session_options_ptr.is_null() {
error!("Session options pointer is null, not dropping");
} else {
ortsys![unsafe ReleaseSessionOptions(self.session_options_ptr)];
}
}
}
impl SessionBuilder {
	/// Creates a builder for sessions in `env`, allocating a fresh set of ONNX Runtime session options.
	///
	/// The builder starts with [`AllocatorType::Device`], [`MemType::Default`], no custom-op
	/// libraries and no builder-level execution providers.
	///
	/// # Errors
	/// Returns [`OrtError::CreateSessionOptions`] if `CreateSessionOptions` fails.
	pub fn new(env: &Arc<Environment>) -> OrtResult<Self> {
		let mut session_options_ptr: *mut sys::OrtSessionOptions = std::ptr::null_mut();
		ortsys![unsafe CreateSessionOptions(&mut session_options_ptr) -> OrtError::CreateSessionOptions; nonNull(session_options_ptr)];
		Ok(Self {
			env: Arc::clone(env),
			session_options_ptr,
			allocator: AllocatorType::Device,
			memory_type: MemType::Default,
			custom_runtime_handles: Vec::new(),
			execution_providers: Vec::new()
		})
	}

	/// Replaces the builder-level execution providers.
	///
	/// They are applied when a session is created, ahead of the environment's own providers.
	pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProvider]>) -> OrtResult<Self> {
		self.execution_providers = execution_providers.as_ref().to_vec();
		Ok(self)
	}

	/// Sets the intra-op thread count via `SetIntraOpNumThreads`.
	///
	/// # Errors
	/// Returns [`OrtError::CreateSessionOptions`] if ONNX Runtime rejects the value.
	pub fn with_intra_threads(self, num_threads: i16) -> OrtResult<Self> {
		let num_threads = num_threads as i32;
		ortsys![unsafe SetIntraOpNumThreads(self.session_options_ptr, num_threads) -> OrtError::CreateSessionOptions];
		Ok(self)
	}

	/// Calls `DisablePerSessionThreads` on the session options.
	///
	/// # Errors
	/// Returns [`OrtError::CreateSessionOptions`] if the call fails.
	pub fn with_disable_per_session_threads(self) -> OrtResult<Self> {
		ortsys![unsafe DisablePerSessionThreads(self.session_options_ptr) -> OrtError::CreateSessionOptions];
		Ok(self)
	}

	/// Sets the inter-op thread count via `SetInterOpNumThreads`.
	///
	/// # Errors
	/// Returns [`OrtError::CreateSessionOptions`] if ONNX Runtime rejects the value.
	pub fn with_inter_threads(self, num_threads: i16) -> OrtResult<Self> {
		let num_threads = num_threads as i32;
		ortsys![unsafe SetInterOpNumThreads(self.session_options_ptr, num_threads) -> OrtError::CreateSessionOptions];
		Ok(self)
	}

	/// Selects `ORT_PARALLEL` (`true`) or `ORT_SEQUENTIAL` (`false`) execution mode.
	///
	/// # Errors
	/// Returns [`OrtError::CreateSessionOptions`] if `SetSessionExecutionMode` fails.
	pub fn with_parallel_execution(self, parallel_execution: bool) -> OrtResult<Self> {
		let execution_mode = if parallel_execution {
			sys::ExecutionMode::ORT_PARALLEL
		} else {
			sys::ExecutionMode::ORT_SEQUENTIAL
		};
		ortsys![unsafe SetSessionExecutionMode(self.session_options_ptr, execution_mode) -> OrtError::CreateSessionOptions];
		Ok(self)
	}

	/// Sets the graph optimization level.
	///
	/// # Errors
	/// Returns [`OrtError::CreateSessionOptions`] if `SetSessionGraphOptimizationLevel` fails.
	pub fn with_optimization_level(self, opt_level: GraphOptimizationLevel) -> OrtResult<Self> {
		ortsys![unsafe SetSessionGraphOptimizationLevel(self.session_options_ptr, opt_level.into()) -> OrtError::CreateSessionOptions];
		Ok(self)
	}

	/// Enables profiling, using `profiling_file` as the profile file prefix passed to `EnableProfiling`.
	///
	/// # Errors
	/// Fails if the path cannot be converted to a C string, or with
	/// [`OrtError::CreateSessionOptions`] if `EnableProfiling` fails.
	#[cfg(feature = "profiling")]
	pub fn with_profiling<S: AsRef<str>>(self, profiling_file: S) -> OrtResult<Self> {
		#[cfg(windows)]
		let profiling_file = widestring::WideCString::from_str(profiling_file.as_ref())?;
		#[cfg(not(windows))]
		let profiling_file = CString::new(profiling_file.as_ref())?;
		ortsys![unsafe EnableProfiling(self.session_options_ptr, profiling_file.as_ptr()) -> OrtError::CreateSessionOptions];
		Ok(self)
	}

	/// Enables (`EnableMemPattern`) or disables (`DisableMemPattern`) the memory pattern optimization.
	///
	/// # Errors
	/// Returns [`OrtError::CreateSessionOptions`] if the call fails.
	pub fn with_memory_pattern(self, enable: bool) -> OrtResult<Self> {
		if enable {
			ortsys![unsafe EnableMemPattern(self.session_options_ptr) -> OrtError::CreateSessionOptions];
		} else {
			ortsys![unsafe DisableMemPattern(self.session_options_ptr) -> OrtError::CreateSessionOptions];
		}
		Ok(self)
	}

	/// Records the allocator type on the builder.
	///
	/// NOTE(review): the value is only stored; no session-creation path in this file reads it.
	pub fn with_allocator(mut self, allocator: AllocatorType) -> OrtResult<Self> {
		self.allocator = allocator;
		Ok(self)
	}

	/// Records the memory type on the builder.
	///
	/// NOTE(review): the value is only stored; no session-creation path in this file reads it.
	pub fn with_memory_type(mut self, memory_type: MemType) -> OrtResult<Self> {
		self.memory_type = memory_type;
		Ok(self)
	}

	/// Loads the shared library at `lib_path` and registers its custom operators.
	///
	/// The returned library handle is kept and closed when the builder is dropped.
	///
	/// # Errors
	/// Fails if `lib_path` contains a NUL byte, or with [`OrtError::CreateSessionOptions`] if
	/// `RegisterCustomOpsLibrary` fails.
	pub fn with_custom_op_lib(mut self, lib_path: &str) -> OrtResult<Self> {
		let path_cstr = CString::new(lib_path)?;
		let mut handle: *mut ::std::os::raw::c_void = std::ptr::null_mut();
		let status = ortsys![unsafe RegisterCustomOpsLibrary(self.session_options_ptr, path_cstr.as_ptr(), &mut handle)];
		if let Err(e) = status_to_result(status).map_err(OrtError::CreateSessionOptions) {
			// The library may have been opened even though registration failed; don't leak it.
			if !handle.is_null() {
				close_lib_handle(handle);
			}
			return Err(e);
		}
		self.custom_runtime_handles.push(handle);
		Ok(self)
	}

	/// Downloads `model` into the current working directory (unless already present) and builds a session from it.
	///
	/// # Errors
	/// Download/I-O errors (as [`OrtDownloadError`]) or any error from [`Self::with_model_from_file`].
	#[cfg(feature = "fetch-models")]
	pub fn with_model_downloaded<M>(self, model: M) -> OrtResult<Session>
	where
		M: ModelUrl
	{
		self.with_model_downloaded_monomorphized(model.fetch_url())
	}

	/// Non-generic body of [`Self::with_model_downloaded`], kept separate to limit monomorphization.
	#[cfg(feature = "fetch-models")]
	fn with_model_downloaded_monomorphized(self, model: &str) -> OrtResult<Session> {
		let download_dir = env::current_dir().map_err(OrtDownloadError::IoError)?;
		let downloaded_path = self.download_to(model, download_dir)?;
		self.with_model_from_file(downloaded_path)
	}

	/// Downloads `url` into `download_dir`, naming the file after the last `/`-separated URL segment.
	///
	/// An existing file with that name is reused without downloading. When the server sends a
	/// parsable `Content-Length`, the number of bytes written must match it. On any failure the
	/// partially written file is removed, so a later call does not mistake it for a complete model.
	///
	/// # Errors
	/// [`OrtDownloadError::FetchError`] for HTTP failures, [`OrtDownloadError::IoError`] for file
	/// errors, and [`OrtDownloadError::CopyError`] for a size mismatch.
	#[cfg(feature = "fetch-models")]
	#[tracing::instrument]
	fn download_to<P>(&self, url: &str, download_dir: P) -> OrtResult<PathBuf>
	where
		P: AsRef<Path> + std::fmt::Debug
	{
		// `rsplit` always yields at least one item; the fallback is unreachable but avoids a panic path.
		let model_filename = PathBuf::from(url.rsplit('/').next().unwrap_or(url));
		let model_filepath = download_dir.as_ref().join(model_filename);
		if model_filepath.exists() {
			tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
			return Ok(model_filepath);
		}

		tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{:?}", url).as_str(), "Downloading model");
		let resp = ureq::get(url)
			.timeout(Duration::from_secs(180))
			.call()
			.map_err(Box::new)
			.map_err(OrtDownloadError::FetchError)?;
		// `Content-Length` is optional (e.g. chunked transfer encoding): verify it only when present.
		let expected_len = resp.header("Content-Length").and_then(|s| s.parse::<u64>().ok());
		if let Some(len) = expected_len {
			tracing::info!(len, "Downloading {} bytes", len);
		}

		let mut reader = resp.into_reader();
		match Self::write_download(&mut reader, &model_filepath, expected_len) {
			Ok(_) => Ok(model_filepath),
			Err(e) => {
				// Best-effort cleanup: the `exists()` check above would otherwise reuse a truncated file.
				let _ = std::fs::remove_file(&model_filepath);
				Err(e.into())
			}
		}
	}

	/// Streams `reader` into a new file at `path`.
	///
	/// Returns the number of bytes written. The writer is flushed explicitly and the count is
	/// checked against `expected_len` when one is given.
	#[cfg(feature = "fetch-models")]
	fn write_download<R>(reader: &mut R, path: &Path, expected_len: Option<u64>) -> Result<u64, OrtDownloadError>
	where
		R: std::io::Read + ?Sized
	{
		let file = std::fs::File::create(path).map_err(OrtDownloadError::IoError)?;
		let mut writer = std::io::BufWriter::new(file);
		let written = std::io::copy(reader, &mut writer).map_err(OrtDownloadError::IoError)?;
		// `BufWriter`'s `Drop` also flushes but silently discards any error, so flush here.
		std::io::Write::flush(&mut writer).map_err(OrtDownloadError::IoError)?;
		match expected_len {
			Some(expected) if expected != written => Err(OrtDownloadError::CopyError { expected, io: written }),
			_ => Ok(written)
		}
	}

	/// Builds a session from the model file at `model_filepath_ref`.
	///
	/// # Errors
	/// * [`OrtError::FileDoesNotExist`] if the path does not exist.
	/// * [`OrtError::CreateSession`] if ONNX Runtime cannot create the session.
	/// * [`OrtError::CreateAllocator`] or a metadata error if reading inputs/outputs fails. The
	///   native session is released in that case.
	pub fn with_model_from_file<P>(self, model_filepath_ref: P) -> OrtResult<Session>
	where
		P: AsRef<Path>
	{
		let model_filepath = model_filepath_ref.as_ref();
		if !model_filepath.exists() {
			return Err(OrtError::FileDoesNotExist {
				filename: model_filepath.to_path_buf()
			});
		}
		let model_path = Self::encode_model_path(model_filepath);
		self.apply_all_execution_providers();

		let env_ptr: *const sys::OrtEnv = self.env.ptr();
		let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
		ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr, &mut session_ptr) -> OrtError::CreateSession; nonNull(session_ptr)];
		self.finish_session(session_ptr)
	}

	/// Like [`Self::with_model_from_file`], but first supplies `initializers` as external initializers.
	///
	/// The initializers are passed as `(name, value)` pairs to `AddExternalInitializers`.
	///
	/// NOTE(review): ONNX Runtime may reference the initializer `OrtValue`s beyond session creation.
	/// The lifetimes here only tie them to this call. Confirm against the `AddExternalInitializers`
	/// documentation.
	///
	/// # Errors
	/// Fails if an initializer name contains a NUL byte, plus every error of
	/// [`Self::with_model_from_file`].
	pub fn with_model_from_file_and_external_initializers<'v, 'i, P>(self, model_filepath_ref: P, initializers: &'i [(String, Value<'v>)]) -> OrtResult<Session>
	where
		'i: 'v,
		P: AsRef<Path>
	{
		let model_filepath = model_filepath_ref.as_ref();
		if !model_filepath.exists() {
			return Err(OrtError::FileDoesNotExist {
				filename: model_filepath.to_path_buf()
			});
		}
		let model_path = Self::encode_model_path(model_filepath);
		self.apply_all_execution_providers();

		// Names are caller-supplied, so reject interior NUL bytes with an error instead of panicking.
		let initializer_names: Vec<CString> = initializers.iter().map(|(name, _)| CString::new(name.as_str())).collect::<Result<_, _>>()?;
		let initializer_names_ptr: Vec<*const c_char> = initializer_names.iter().map(|n| n.as_ptr()).collect();
		let initializer_values: Vec<*const sys::OrtValue> = initializers.iter().map(|(_, value)| value.ptr() as *const _).collect();
		if !initializer_values.is_empty() {
			ortsys![unsafe AddExternalInitializers(self.session_options_ptr, initializer_names_ptr.as_ptr(), initializer_values.as_ptr(), initializer_values.len() as _) -> OrtError::CreateSession];
		}

		let env_ptr: *const sys::OrtEnv = self.env.ptr();
		let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
		ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr, &mut session_ptr) -> OrtError::CreateSession; nonNull(session_ptr)];
		self.finish_session(session_ptr)
	}

	/// Builds a session from an in-memory model.
	///
	/// The session config entries `session.use_ort_model_bytes_directly` and
	/// `session.use_ort_model_bytes_for_initializers` are set to `"1"`. That lets ONNX Runtime use
	/// `model_bytes` in place, which is why the returned [`InMemorySession`] borrows them.
	///
	/// # Errors
	/// * [`OrtError::CreateSessionOptions`] if a config entry cannot be set.
	/// * [`OrtError::CreateSession`] if session creation fails.
	/// * A metadata error if reading inputs/outputs fails; the native session is released in that case.
	pub fn with_model_from_memory(self, model_bytes: &[u8]) -> OrtResult<InMemorySession<'_>> {
		self.apply_all_execution_providers();
		self.add_config_entry("session.use_ort_model_bytes_directly", "1")?;
		self.add_config_entry("session.use_ort_model_bytes_for_initializers", "1")?;

		let env_ptr: *const sys::OrtEnv = self.env.ptr();
		let model_data = model_bytes.as_ptr() as *const std::ffi::c_void;
		let model_data_length = model_bytes.len();
		let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
		ortsys![
			unsafe CreateSessionFromArray(env_ptr, model_data, model_data_length as _, self.session_options_ptr, &mut session_ptr) -> OrtError::CreateSession;
			nonNull(session_ptr)
		];
		let session = self.finish_session(session_ptr)?;
		Ok(InMemorySession { session, phantom: PhantomData })
	}

	/// Applies the builder's execution providers followed by the environment's to the session options.
	fn apply_all_execution_providers(&self) {
		apply_execution_providers(
			self.session_options_ptr,
			self.execution_providers
				.iter()
				.chain(&self.env.execution_providers)
				.cloned()
				.collect::<Vec<_>>()
		);
	}

	/// Sets one session config entry, propagating (and releasing) any error status.
	fn add_config_entry(&self, key: &str, value: &str) -> OrtResult<()> {
		let key = CString::new(key)?;
		let value = CString::new(value)?;
		ortsys![unsafe AddSessionConfigEntry(self.session_options_ptr, key.as_ptr(), value.as_ptr()) -> OrtError::CreateSessionOptions];
		Ok(())
	}

	/// Encodes `path` as the NUL-terminated wide string ONNX Runtime expects on Windows.
	#[cfg(target_family = "windows")]
	fn encode_model_path(path: &Path) -> Vec<u16> {
		path.as_os_str().encode_wide().chain(std::iter::once(0)).collect()
	}

	/// Encodes `path` as the NUL-terminated byte string ONNX Runtime expects on non-Windows targets.
	#[cfg(not(target_family = "windows"))]
	fn encode_model_path(path: &Path) -> Vec<c_char> {
		path.as_os_str().as_bytes().iter().map(|&b| b as c_char).chain(std::iter::once(0)).collect()
	}

	/// Turns a freshly created native session into a [`Session`] by reading its input/output metadata.
	///
	/// `session_ptr` is moved into a [`SessionPointerHolder`] *before* any other fallible call.
	/// If fetching the allocator or the metadata fails, the holder's `Drop` releases the native
	/// session instead of leaking it.
	fn finish_session(&self, session_ptr: *mut sys::OrtSession) -> OrtResult<Session> {
		let session_holder = Arc::new(SessionPointerHolder { inner: session_ptr });

		let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
		ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr) -> OrtError::CreateAllocator; nonNull(allocator_ptr)];

		let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
		let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
		let inputs = (0..num_input_nodes)
			.map(|i| dangerous::extract_input(session_ptr, allocator_ptr, i))
			.collect::<OrtResult<Vec<Input>>>()?;
		let outputs = (0..num_output_nodes)
			.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
			.collect::<OrtResult<Vec<Output>>>()?;

		Ok(Session {
			env: Arc::clone(&self.env),
			session_ptr: session_holder,
			allocator_ptr,
			inputs,
			outputs
		})
	}
}
/// Owning wrapper around a native `OrtSession` pointer.
///
/// Shared through an `Arc` by a `Session` and the output values it produces. The native
/// session is released when the last clone is dropped.
#[derive(Debug)]
pub struct SessionPointerHolder {
	pub inner: *mut sys::OrtSession
}

// SAFETY (NOTE(review)): these impls assume ONNX Runtime allows an `OrtSession` to be used and
// released from any thread, with concurrent `Run` calls permitted. Confirm against the
// ONNX Runtime threading documentation.
unsafe impl Send for SessionPointerHolder {}
unsafe impl Sync for SessionPointerHolder {}

impl Drop for SessionPointerHolder {
	fn drop(&mut self) {
		// Take the pointer out (leaving null behind) and hand it back to ONNX Runtime.
		let session = std::mem::replace(&mut self.inner, std::ptr::null_mut());
		ortsys![unsafe ReleaseSession(session)];
	}
}
/// A loaded ONNX Runtime inference session together with its input/output metadata.
#[derive(Debug)]
pub struct Session {
// Never read; holding the `Arc` presumably exists to keep the environment alive for as long as
// the session is.
#[allow(dead_code)]
env: Arc<Environment>,
/// Shared owner of the native session; output `Value`s created by `run` hold clones of this `Arc`.
pub(crate) session_ptr: Arc<SessionPointerHolder>,
/// Pointer from `GetAllocatorWithDefaultOptions`; not released by this type (its `Drop` only nulls it).
allocator_ptr: *mut sys::OrtAllocator,
/// Model inputs, in the order reported by ONNX Runtime.
pub inputs: Vec<Input>,
/// Model outputs, in the order reported by ONNX Runtime.
pub outputs: Vec<Output>
}
/// A [`Session`] created from an in-memory model buffer.
///
/// The `'s` lifetime ties it to the bytes it was built from. Those bytes were passed to ONNX
/// Runtime with `session.use_ort_model_bytes_directly` enabled, so they may still be read. All
/// session functionality is available through `Deref<Target = Session>`.
pub struct InMemorySession<'s> {
	session: Session,
	phantom: PhantomData<&'s ()>
}

impl Deref for InMemorySession<'_> {
	type Target = Session;

	fn deref(&self) -> &Session {
		let inner: &Session = &self.session;
		inner
	}
}
/// Metadata for one model input, as reported by ONNX Runtime when the session was built.
#[derive(Debug)]
pub struct Input {
/// Input name.
pub name: String,
/// Element type of the input tensor.
pub input_type: TensorElementDataType,
/// Tensor shape; `None` marks a dimension ONNX Runtime reported as `-1` (symbolic/unknown).
pub dimensions: Vec<Option<u32>>
}
/// Metadata for one model output, as reported by ONNX Runtime when the session was built.
#[derive(Debug)]
pub struct Output {
/// Output name.
pub name: String,
/// Element type of the output tensor.
pub output_type: TensorElementDataType,
/// Tensor shape; `None` marks a dimension ONNX Runtime reported as `-1` (symbolic/unknown).
pub dimensions: Vec<Option<u32>>
}
impl Input {
	/// Iterates over the input's shape as `usize`, yielding `None` for non-fixed dimensions.
	pub fn dimensions(&self) -> impl Iterator<Item = Option<usize>> + '_ {
		self.dimensions.iter().map(|&dim| dim.map(|extent| extent as usize))
	}
}

impl Output {
	/// Iterates over the output's shape as `usize`, yielding `None` for non-fixed dimensions.
	pub fn dimensions(&self) -> impl Iterator<Item = Option<usize>> + '_ {
		self.dimensions.iter().map(|&dim| dim.map(|extent| extent as usize))
	}
}
impl Drop for Session {
	// Only forgets the allocator pointer; nothing is freed here. The native session is
	// released by `SessionPointerHolder` once the last `Arc` clone of it is dropped.
	#[tracing::instrument]
	fn drop(&mut self) {
		let _previous = std::mem::replace(&mut self.allocator_ptr, std::ptr::null_mut());
	}
}
impl Session {
pub fn allocator(&self) -> *mut sys::OrtAllocator {
self.allocator_ptr
}
pub fn bind(&self) -> OrtResult<IoBinding> {
IoBinding::new(self)
}
pub fn run<'s, 'm, 'v, 'i>(&'s self, input_values: Vec<Value<'v>>) -> OrtResult<Vec<Value<'static>>>
where
's: 'm, 'i: 'v
{
let input_names_ptr: Vec<*const c_char> = self
.inputs
.iter()
.map(|input| input.name.clone())
.map(|n| CString::new(n).unwrap())
.map(|n| n.into_raw() as *const c_char)
.collect();
let output_names_cstring: Vec<CString> = self
.outputs
.iter()
.map(|output| output.name.clone())
.map(|n| CString::new(n).unwrap())
.collect();
let output_names_ptr: Vec<*const c_char> = output_names_cstring.iter().map(|n| n.as_ptr() as *const c_char).collect();
let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()];
let input_ort_values: Vec<*const sys::OrtValue> = input_values.iter().map(|input_array_ort| input_array_ort.ptr() as *const _).collect();
let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null();
ortsys![
unsafe Run(
self.session_ptr.inner,
run_options_ptr,
input_names_ptr.as_ptr(),
input_ort_values.as_ptr(),
input_ort_values.len() as _,
output_names_ptr.as_ptr(),
output_names_ptr.len() as _,
output_tensor_ptrs.as_mut_ptr()
) -> OrtError::SessionRun
];
std::mem::drop(input_ort_values);
let outputs: Vec<Value> = output_tensor_ptrs
.into_iter()
.map(|tensor_ptr| Value::from_raw(tensor_ptr, Arc::clone(&self.session_ptr)))
.collect();
let cstrings: OrtResult<Vec<CString>> = input_names_ptr
.into_iter()
.map(|p| {
assert_non_null_pointer(p, "c_char for CString")?;
unsafe { Ok(CString::from_raw(p as *mut c_char)) }
})
.collect();
cstrings?;
Ok(outputs)
}
pub fn run_with_binding<'s, 'a: 's>(&'a self, binding: &IoBinding<'s>) -> OrtResult<()> {
let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null();
ortsys![unsafe RunWithBinding(self.session_ptr.inner, run_options_ptr, binding.ptr) -> OrtError::SessionRun];
Ok(())
}
pub fn metadata(&self) -> OrtResult<Metadata> {
let mut metadata_ptr: *mut sys::OrtModelMetadata = std::ptr::null_mut();
ortsys![unsafe SessionGetModelMetadata(self.session_ptr.inner, &mut metadata_ptr) -> OrtError::GetModelMetadata; nonNull(metadata_ptr)];
Ok(Metadata::new(metadata_ptr, self.allocator_ptr))
}
#[cfg(feature = "profiling")]
pub fn end_profiling(&self) -> OrtResult<String> {
let mut profiling_name: *mut c_char = std::ptr::null_mut();
ortsys![unsafe SessionEndProfiling(self.session_ptr.inner, self.allocator_ptr, &mut profiling_name)];
assert_non_null_pointer(profiling_name, "ProfilingName")?;
dangerous::raw_pointer_to_string(self.allocator_ptr, profiling_name)
}
}
// SAFETY (NOTE(review)): `Session` stores raw pointers (`allocator_ptr`, and the native session
// inside `SessionPointerHolder`), so thread-safety is asserted manually here. These impls
// presumably rely on ONNX Runtime permitting concurrent `Run` calls on one session and on the
// default allocator being usable from any thread. Confirm against the ONNX Runtime threading
// documentation.
unsafe impl Send for Session {}
unsafe impl Sync for Session {}
/// Reads the shape of a tensor type-and-shape info object, one `i64` per dimension.
///
/// # Safety
/// `tensor_info_ptr` must point to a valid `OrtTensorTypeAndShapeInfo` for the whole call.
unsafe fn get_tensor_dimensions(tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo) -> OrtResult<Vec<i64>> {
	let mut rank = 0;
	ortsys![GetDimensionsCount(tensor_info_ptr, &mut rank) -> OrtError::GetDimensionsCount];

	// Size the buffer from the reported rank so `GetDimensions` fills it exactly.
	let mut shape = vec![0_i64; rank as _];
	ortsys![GetDimensions(tensor_info_ptr, shape.as_mut_ptr(), rank as _) -> OrtError::GetDimensions];
	Ok(shape)
}
unsafe fn extract_data_type(tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo) -> OrtResult<TensorElementDataType> {
let mut type_sys = sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![GetTensorElementType(tensor_info_ptr, &mut type_sys) -> OrtError::GetTensorElementType];
assert_ne!(type_sys, sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
Ok(type_sys.into())
}
#[cfg(unix)]
fn close_lib_handle(handle: *mut std::os::raw::c_void) {
unsafe { libc::dlclose(handle) };
}
#[cfg(windows)]
fn close_lib_handle(handle: *mut std::os::raw::c_void) {
unsafe { winapi::um::libloaderapi::FreeLibrary(handle as winapi::shared::minwindef::HINSTANCE) };
}
mod dangerous {
use super::*;
use crate::{ortfree, sys::size_t, tensor::TensorElementDataType};
pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> OrtResult<usize> {
let f = ort().SessionGetInputCount.unwrap();
extract_io_count(f, session_ptr)
}
pub(super) fn extract_outputs_count(session_ptr: *mut sys::OrtSession) -> OrtResult<usize> {
let f = ort().SessionGetOutputCount.unwrap();
extract_io_count(f, session_ptr)
}
fn extract_io_count(
f: extern_system_fn! { unsafe fn(*const sys::OrtSession, *mut size_t) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession
) -> OrtResult<usize> {
let mut num_nodes = 0;
let status = unsafe { f(session_ptr, &mut num_nodes) };
status_to_result(status).map_err(OrtError::GetInOutCount)?;
assert_null_pointer(status, "SessionStatus")?;
(num_nodes != 0)
.then_some(())
.ok_or_else(|| OrtError::GetInOutCount(OrtApiError::Msg("No nodes in model".to_owned())))?;
Ok(num_nodes as _)
}
fn extract_input_name(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: size_t) -> OrtResult<String> {
let f = ort().SessionGetInputName.unwrap();
extract_io_name(f, session_ptr, allocator_ptr, i)
}
fn extract_output_name(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: size_t) -> OrtResult<String> {
let f = ort().SessionGetOutputName.unwrap();
extract_io_name(f, session_ptr, allocator_ptr, i)
}
pub(crate) fn raw_pointer_to_string(allocator_ptr: *mut sys::OrtAllocator, c_str: *mut c_char) -> OrtResult<String> {
let name = char_p_to_string(c_str)?;
ortfree!(unsafe allocator_ptr, c_str);
Ok(name)
}
fn extract_io_name(
f: extern_system_fn! { unsafe fn(
*const sys::OrtSession,
size_t,
*mut sys::OrtAllocator,
*mut *mut c_char,
) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession,
allocator_ptr: *mut sys::OrtAllocator,
i: size_t
) -> OrtResult<String> {
let mut name_bytes: *mut c_char = std::ptr::null_mut();
let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) };
status_to_result(status).map_err(OrtError::GetInputName)?;
assert_non_null_pointer(name_bytes, "InputName")?;
raw_pointer_to_string(allocator_ptr, name_bytes)
}
pub(super) fn extract_input(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: usize) -> OrtResult<Input> {
let input_name = extract_input_name(session_ptr, allocator_ptr, i as _)?;
let f = ort().SessionGetInputTypeInfo.unwrap();
let (input_type, dimensions) = extract_io(f, session_ptr, i as _)?;
Ok(Input {
name: input_name,
input_type,
dimensions
})
}
pub(super) fn extract_output(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: usize) -> OrtResult<Output> {
let output_name = extract_output_name(session_ptr, allocator_ptr, i as _)?;
let f = ort().SessionGetOutputTypeInfo.unwrap();
let (output_type, dimensions) = extract_io(f, session_ptr, i as _)?;
Ok(Output {
name: output_name,
output_type,
dimensions
})
}
fn extract_io(
f: extern_system_fn! { unsafe fn(
*const sys::OrtSession,
size_t,
*mut *mut sys::OrtTypeInfo,
) -> *mut sys::OrtStatus },
session_ptr: *mut sys::OrtSession,
i: size_t
) -> OrtResult<(TensorElementDataType, Vec<Option<u32>>)> {
let mut typeinfo_ptr: *mut sys::OrtTypeInfo = std::ptr::null_mut();
let status = unsafe { f(session_ptr, i, &mut typeinfo_ptr) };
status_to_result(status).map_err(OrtError::GetTypeInfo)?;
assert_non_null_pointer(typeinfo_ptr, "TypeInfo")?;
let mut tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut tensor_info_ptr) -> OrtError::CastTypeInfoToTensorInfo; nonNull(tensor_info_ptr)];
let io_type: TensorElementDataType = unsafe { extract_data_type(tensor_info_ptr)? };
let node_dims = unsafe { get_tensor_dimensions(tensor_info_ptr)? };
ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)];
Ok((io_type, node_dims.into_iter().map(|d| if d == -1 { None } else { Some(d as u32) }).collect()))
}
}