#![allow(clippy::tabs_in_doc_comments)]
#[cfg(not(target_family = "windows"))]
use std::os::unix::ffi::OsStrExt;
#[cfg(target_family = "windows")]
use std::os::windows::ffi::OsStrExt;
#[cfg(feature = "fetch-models")]
use std::{env, path::PathBuf, time::Duration};
use std::{
ffi::CString,
fmt::{self, Debug},
os::raw::c_char,
path::Path,
sync::Arc
};
use ndarray::IxDyn;
use tracing::{debug, error};
use super::{
char_p_to_string,
environment::Environment,
error::{assert_non_null_pointer, assert_null_pointer, status_to_result, NonMatchingDimensionsError, OrtApiError, OrtError, OrtResult},
execution_providers::{apply_execution_providers, ExecutionProvider},
extern_system_fn,
memory::MemoryInfo,
metadata::Metadata,
ort, ortsys, sys,
tensor::{
type_dynamic_tensor::{InputOrtTensor, InputTensor},
DynOrtTensor, TensorElementDataType
},
AllocatorType, GraphOptimizationLevel, MemType
};
#[cfg(feature = "fetch-models")]
use super::{download::ModelUrl, error::OrtDownloadError};
/// Builder that configures ONNX Runtime session options and then loads a model into a [`Session`].
///
/// Create one with [`SessionBuilder::new`], chain the `with_*` configuration methods, and finish
/// with `with_model_from_file`, `with_model_from_memory` or (with the `fetch-models` feature)
/// `with_model_downloaded`.
pub struct SessionBuilder {
// Shared environment; every session built from this builder receives its own `Arc` clone of it.
env: Arc<Environment>,
// Owned `OrtSessionOptions`, released in `Drop`.
session_options_ptr: *mut sys::OrtSessionOptions,
// Recorded by `with_allocator` and shown in `Debug`.
// NOTE(review): session creation in this file always builds its `MemoryInfo` with
// `AllocatorType::Arena` / `MemType::Default`, so this value is not read there — confirm intent.
allocator: AllocatorType,
// Recorded by `with_memory_type`; same caveat as `allocator`.
memory_type: MemType,
// Library handles produced by `RegisterCustomOpsLibrary` in `with_custom_op_lib`; closed in `Drop`.
custom_runtime_handles: Vec<*mut std::os::raw::c_void>,
// Builder-level execution providers; applied ahead of the environment's own providers.
execution_providers: Vec<ExecutionProvider>
}
impl Debug for SessionBuilder {
	/// Formats the builder's descriptive configuration: the environment's name, the allocator type
	/// and the memory type. Raw pointers, custom-op handles and execution providers are omitted.
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
		let mut repr = f.debug_struct("SessionBuilder");
		repr.field("env", &self.env.name());
		repr.field("allocator", &self.allocator);
		repr.field("memory_type", &self.memory_type);
		repr.finish()
	}
}
impl Drop for SessionBuilder {
	/// Releases the session options and then unloads any custom-op libraries registered on them.
	///
	/// The options are released *before* the libraries are unloaded. ONNX Runtime's
	/// `RegisterCustomOpsLibrary` contract leaves unloading to the caller, and the library must
	/// stay loaded while the options that reference it are alive. Null handles are skipped,
	/// because passing NULL to the platform unloader is not valid.
	#[tracing::instrument]
	fn drop(&mut self) {
		if self.session_options_ptr.is_null() {
			error!("Session options pointer is null, not dropping");
		} else {
			debug!("Dropping the session options.");
			ortsys![unsafe ReleaseSessionOptions(self.session_options_ptr)];
		}
		// NOTE(review): a `Session` built from this builder outlives it (the `with_model_*`
		// methods consume the builder), yet the libraries are unloaded here. Confirm that no
		// custom-op kernel from these libraries can run after the builder is dropped.
		for &handle in self.custom_runtime_handles.iter().filter(|handle| !handle.is_null()) {
			close_lib_handle(handle);
		}
	}
}
impl SessionBuilder {
	/// Creates a builder with fresh ONNX Runtime session options bound to `env`.
	///
	/// Defaults: `AllocatorType::Arena`, `MemType::Default`, no custom-op libraries and no
	/// builder-level execution providers.
	///
	/// # Errors
	/// `OrtError::CreateSessionOptions` if ONNX Runtime fails to create the options, or a
	/// null-pointer error if it hands back a null options pointer.
	pub fn new(env: &Arc<Environment>) -> OrtResult<Self> {
		let mut session_options_ptr: *mut sys::OrtSessionOptions = std::ptr::null_mut();
		ortsys![unsafe CreateSessionOptions(&mut session_options_ptr) -> OrtError::CreateSessionOptions; nonNull(session_options_ptr)];
		Ok(Self {
			env: Arc::clone(env),
			session_options_ptr,
			allocator: AllocatorType::Arena,
			memory_type: MemType::Default,
			custom_runtime_handles: Vec::new(),
			execution_providers: Vec::new()
		})
	}

	/// Replaces the builder-level execution providers.
	///
	/// These are applied when the model is loaded, ahead of the environment's own providers.
	pub fn with_execution_providers(mut self, execution_providers: impl AsRef<[ExecutionProvider]>) -> OrtResult<Self> {
		self.execution_providers = execution_providers.as_ref().to_vec();
		Ok(self)
	}

	/// Sets the number of threads used to parallelize work *within* an operator.
	pub fn with_intra_threads(self, num_threads: i16) -> OrtResult<Self> {
		let num_threads = i32::from(num_threads);
		ortsys![unsafe SetIntraOpNumThreads(self.session_options_ptr, num_threads) -> OrtError::CreateSessionOptions];
		Ok(self)
	}

	/// Sets the number of threads used to run independent operators in parallel.
	pub fn with_inter_threads(self, num_threads: i16) -> OrtResult<Self> {
		let num_threads = i32::from(num_threads);
		ortsys![unsafe SetInterOpNumThreads(self.session_options_ptr, num_threads) -> OrtError::CreateSessionOptions];
		Ok(self)
	}

	/// Selects ONNX Runtime's parallel (`true`) or sequential (`false`) execution mode.
	pub fn with_parallel_execution(self, parallel_execution: bool) -> OrtResult<Self> {
		let execution_mode = if parallel_execution {
			sys::ExecutionMode_ORT_PARALLEL
		} else {
			sys::ExecutionMode_ORT_SEQUENTIAL
		};
		ortsys![unsafe SetSessionExecutionMode(self.session_options_ptr, execution_mode) -> OrtError::CreateSessionOptions];
		Ok(self)
	}

	/// Sets the graph optimization level applied when the model is loaded.
	pub fn with_optimization_level(self, opt_level: GraphOptimizationLevel) -> OrtResult<Self> {
		ortsys![unsafe SetSessionGraphOptimizationLevel(self.session_options_ptr, opt_level.into()) -> OrtError::CreateSessionOptions];
		Ok(self)
	}

	/// Enables profiling, writing the profile to `profiling_file`.
	///
	/// The path is passed as a wide string on Windows and as a NUL-terminated byte string elsewhere.
	#[cfg(feature = "profiling")]
	pub fn with_profiling<S: AsRef<str>>(self, profiling_file: S) -> OrtResult<Self> {
		#[cfg(windows)]
		let profiling_file = widestring::WideCString::from_str(profiling_file.as_ref())?;
		#[cfg(not(windows))]
		let profiling_file = CString::new(profiling_file.as_ref())?;
		ortsys![unsafe EnableProfiling(self.session_options_ptr, profiling_file.as_ptr()) -> OrtError::CreateSessionOptions];
		Ok(self)
	}

	/// Enables or disables ONNX Runtime's memory-pattern optimization.
	pub fn with_memory_pattern(self, enable: bool) -> OrtResult<Self> {
		if enable {
			ortsys![unsafe EnableMemPattern(self.session_options_ptr) -> OrtError::CreateSessionOptions];
		} else {
			ortsys![unsafe DisableMemPattern(self.session_options_ptr) -> OrtError::CreateSessionOptions];
		}
		Ok(self)
	}

	/// Records the allocator type on the builder.
	///
	/// NOTE(review): session creation below always uses `AllocatorType::Arena` for its memory
	/// info, so this setting is not read there. Confirm whether that is intended.
	pub fn with_allocator(mut self, allocator: AllocatorType) -> OrtResult<Self> {
		self.allocator = allocator;
		Ok(self)
	}

	/// Records the memory type on the builder. The same caveat as [`SessionBuilder::with_allocator`] applies.
	pub fn with_memory_type(mut self, memory_type: MemType) -> OrtResult<Self> {
		self.memory_type = memory_type;
		Ok(self)
	}

	/// Loads and registers a custom-operator shared library located at `lib_path`.
	///
	/// The returned library handle is kept so that `Drop` can unload it later.
	///
	/// # Errors
	/// - A NUL-byte error if `lib_path` contains an interior NUL.
	/// - `OrtError::CreateSessionOptions` if ONNX Runtime rejects the library. In that case any
	///   handle it produced is unloaded immediately.
	pub fn with_custom_op_lib(mut self, lib_path: &str) -> OrtResult<Self> {
		let path_cstr = CString::new(lib_path)?;
		let mut handle: *mut ::std::os::raw::c_void = std::ptr::null_mut();
		let status = ortsys![unsafe RegisterCustomOpsLibrary(self.session_options_ptr, path_cstr.as_ptr(), &mut handle)];
		if let Err(e) = status_to_result(status).map_err(OrtError::CreateSessionOptions) {
			// The library may have been loaded before registration failed; don't leak it.
			if !handle.is_null() {
				close_lib_handle(handle);
			}
			return Err(e);
		}
		// Only keep real handles: `Drop` passes every stored handle to the platform unloader.
		if !handle.is_null() {
			self.custom_runtime_handles.push(handle);
		}
		Ok(self)
	}

	/// Downloads `model` into the current working directory (unless already present) and loads it.
	#[cfg(feature = "fetch-models")]
	pub fn with_model_downloaded<M>(self, model: M) -> OrtResult<Session>
	where
		M: ModelUrl
	{
		self.with_model_downloaded_monomorphized(model.fetch_url())
	}

	/// Non-generic body of [`SessionBuilder::with_model_downloaded`].
	#[cfg(feature = "fetch-models")]
	fn with_model_downloaded_monomorphized(self, model: &str) -> OrtResult<Session> {
		let download_dir = env::current_dir().map_err(OrtDownloadError::IoError)?;
		let downloaded_path = self.download_to(model, download_dir)?;
		self.with_model_from_file(downloaded_path)
	}

	/// Downloads `url` into `download_dir`, naming the file after the URL's last path segment.
	///
	/// If a file with that name already exists it is returned as-is, without any download.
	/// When the server sends a parseable `Content-Length`, the number of bytes written is checked
	/// against it. If the copy, flush or size check fails, the partial file is removed so that a
	/// later call does not mistake it for a complete model.
	///
	/// # Errors
	/// - `OrtDownloadError::FetchError` if the HTTP request fails.
	/// - `OrtDownloadError::IoError` if the file cannot be created, written or flushed.
	/// - `OrtDownloadError::CopyError` if fewer or more bytes than advertised were written.
	#[cfg(feature = "fetch-models")]
	#[tracing::instrument]
	fn download_to<P>(&self, url: &str, download_dir: P) -> OrtResult<PathBuf>
	where
		P: AsRef<Path> + std::fmt::Debug
	{
		// `split` always yields at least one item, so `last()` cannot be `None`.
		let model_filename = PathBuf::from(url.split('/').last().unwrap());
		let model_filepath = download_dir.as_ref().join(model_filename);
		if model_filepath.exists() {
			tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), "Model already exists, skipping download");
			return Ok(model_filepath);
		}
		tracing::info!(model_filepath = format!("{}", model_filepath.display()).as_str(), url = format!("{:?}", url).as_str(), "Downloading model");
		let resp = ureq::get(url)
			.timeout(Duration::from_secs(180))
			.call()
			.map_err(Box::new)
			.map_err(OrtDownloadError::FetchError)?;
		// A missing or malformed Content-Length only disables the size check; it is not fatal.
		let expected_len = resp.header("Content-Length").and_then(|s| s.parse::<u64>().ok());
		match expected_len {
			Some(len) => tracing::info!(len, "Downloading {} bytes", len),
			None => tracing::info!("Downloading model of unknown size")
		}
		let mut reader = resp.into_reader();
		let file = std::fs::File::create(&model_filepath).map_err(OrtDownloadError::IoError)?;
		let mut writer = std::io::BufWriter::new(file);
		// Flush explicitly: dropping a `BufWriter` silently discards flush errors.
		let copied = std::io::copy(&mut reader, &mut writer).and_then(|count| std::io::Write::flush(&mut writer).map(|()| count));
		// Close the file before any removal (required on Windows).
		drop(writer);
		let bytes_io_count = match copied {
			Ok(count) => count,
			Err(e) => {
				let _ = std::fs::remove_file(&model_filepath);
				return Err(OrtDownloadError::IoError(e).into());
			}
		};
		if let Some(len) = expected_len {
			if bytes_io_count != len {
				let _ = std::fs::remove_file(&model_filepath);
				return Err(OrtDownloadError::CopyError { expected: len, io: bytes_io_count }.into());
			}
		}
		Ok(model_filepath)
	}

	/// Loads the ONNX model stored at `model_filepath_ref` and builds a [`Session`] from it.
	///
	/// # Errors
	/// - `OrtError::FileDoesNotExist` if the path does not exist.
	/// - `OrtError::GetAllocator` / memory-info errors while preparing the session.
	/// - `OrtError::CreateSession` if ONNX Runtime fails to load the model.
	/// - Any error raised while reading the model's inputs and outputs. The created session is
	///   released in that case.
	pub fn with_model_from_file<P>(self, model_filepath_ref: P) -> OrtResult<Session>
	where
		P: AsRef<Path>
	{
		let model_filepath = model_filepath_ref.as_ref();
		if !model_filepath.exists() {
			return Err(OrtError::FileDoesNotExist {
				filename: model_filepath.to_path_buf()
			});
		}
		let model_path = std::ffi::OsString::from(model_filepath);
		// ONNX Runtime takes a NUL-terminated path: UTF-16 code units on Windows, bytes elsewhere.
		#[cfg(target_family = "windows")]
		let model_path: Vec<u16> = model_path.encode_wide().chain(std::iter::once(0)).collect();
		#[cfg(not(target_family = "windows"))]
		let model_path: Vec<std::os::raw::c_char> = model_path
			.as_bytes()
			.iter()
			.chain(std::iter::once(&b'\0'))
			.map(|b| *b as std::os::raw::c_char)
			.collect();

		self.apply_all_execution_providers();
		// Acquire session-independent resources first, so a failure here cannot strand a session.
		let (allocator_ptr, memory_info) = Self::default_allocator_and_memory_info()?;

		let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
		let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
		ortsys![unsafe CreateSession(env_ptr, model_path.as_ptr(), self.session_options_ptr, &mut session_ptr) -> OrtError::CreateSession; nonNull(session_ptr)];
		self.assemble_session(session_ptr, allocator_ptr, memory_info)
	}

	/// Builds a [`Session`] from a model held in memory (the bytes of an `.onnx` file).
	pub fn with_model_from_memory<B>(self, model_bytes: B) -> OrtResult<Session>
	where
		B: AsRef<[u8]>
	{
		self.with_model_from_memory_monomorphized(model_bytes.as_ref())
	}

	/// Non-generic body of [`SessionBuilder::with_model_from_memory`].
	fn with_model_from_memory_monomorphized(self, model_bytes: &[u8]) -> OrtResult<Session> {
		self.apply_all_execution_providers();
		// Acquire session-independent resources first, so a failure here cannot strand a session.
		let (allocator_ptr, memory_info) = Self::default_allocator_and_memory_info()?;

		let env_ptr: *const sys::OrtEnv = self.env.env_ptr();
		let mut session_ptr: *mut sys::OrtSession = std::ptr::null_mut();
		let model_data = model_bytes.as_ptr() as *const std::ffi::c_void;
		let model_data_length = model_bytes.len();
		ortsys![
			unsafe CreateSessionFromArray(env_ptr, model_data, model_data_length as _, self.session_options_ptr, &mut session_ptr) -> OrtError::CreateSession;
			nonNull(session_ptr)
		];
		self.assemble_session(session_ptr, allocator_ptr, memory_info)
	}

	/// Applies the builder's execution providers followed by the environment's to the session options.
	fn apply_all_execution_providers(&self) {
		apply_execution_providers(
			self.session_options_ptr,
			self.execution_providers
				.iter()
				.chain(&self.env.execution_providers)
				.cloned()
				.collect::<Vec<_>>()
		);
	}

	/// Fetches ONNX Runtime's default allocator and the memory info used for session input tensors.
	fn default_allocator_and_memory_info() -> OrtResult<(*mut sys::OrtAllocator, MemoryInfo)> {
		let mut allocator_ptr: *mut sys::OrtAllocator = std::ptr::null_mut();
		ortsys![unsafe GetAllocatorWithDefaultOptions(&mut allocator_ptr) -> OrtError::GetAllocator; nonNull(allocator_ptr)];
		// NOTE(review): hard-coded rather than using `self.allocator` / `self.memory_type`; confirm intent.
		let memory_info = MemoryInfo::new(AllocatorType::Arena, MemType::Default)?;
		Ok((allocator_ptr, memory_info))
	}

	/// Takes ownership of a freshly created `session_ptr` and reads the model's inputs and outputs.
	///
	/// The pointer is wrapped in a [`Session`] *before* any fallible introspection runs, so that
	/// `Session`'s `Drop` releases it if reading the inputs or outputs fails.
	fn assemble_session(&self, session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, memory_info: MemoryInfo) -> OrtResult<Session> {
		let mut session = Session {
			env: Arc::clone(&self.env),
			session_ptr,
			allocator_ptr,
			memory_info,
			inputs: Vec::new(),
			outputs: Vec::new()
		};
		let num_input_nodes = dangerous::extract_inputs_count(session_ptr)?;
		let num_output_nodes = dangerous::extract_outputs_count(session_ptr)?;
		session.inputs = (0..num_input_nodes)
			.map(|i| dangerous::extract_input(session_ptr, allocator_ptr, i))
			.collect::<OrtResult<Vec<Input>>>()?;
		session.outputs = (0..num_output_nodes)
			.map(|i| dangerous::extract_output(session_ptr, allocator_ptr, i))
			.collect::<OrtResult<Vec<Output>>>()?;
		Ok(session)
	}
}
/// A loaded ONNX model ready for inference via [`Session::run`].
///
/// Created by [`SessionBuilder`]; the underlying `OrtSession` is released when this value is dropped.
#[derive(Debug)]
pub struct Session {
// Keeps the environment alive for as long as the session exists; not otherwise read.
#[allow(dead_code)]
env: Arc<Environment>,
// Owned `OrtSession`, released in `Drop`.
session_ptr: *mut sys::OrtSession,
// ONNX Runtime's default allocator (from `GetAllocatorWithDefaultOptions`); used for names,
// metadata and input tensors. Not released by this type.
allocator_ptr: *mut sys::OrtAllocator,
// Memory info used when converting `InputTensor`s and wrapping output tensors in `run`.
memory_info: MemoryInfo,
/// The model's declared inputs, in positional order.
pub inputs: Vec<Input>,
/// The model's declared outputs, in the order `run` returns them.
pub outputs: Vec<Output>
}
/// Description of one model input, as reported by ONNX Runtime when the session was built.
#[derive(Debug)]
pub struct Input {
/// Input name as declared in the model.
pub name: String,
/// Element type of the input tensor.
pub input_type: TensorElementDataType,
/// Declared dimensions; `None` marks a dimension ONNX Runtime reported as `-1` (dynamic).
pub dimensions: Vec<Option<u32>>
}
/// Description of one model output, as reported by ONNX Runtime when the session was built.
#[derive(Debug)]
pub struct Output {
/// Output name as declared in the model.
pub name: String,
/// Element type of the output tensor.
pub output_type: TensorElementDataType,
/// Declared dimensions; `None` marks a dimension ONNX Runtime reported as `-1` (dynamic).
pub dimensions: Vec<Option<u32>>
}
impl Input {
	/// Iterates over this input's declared dimensions as `usize`, yielding `None` for dynamic ones.
	pub fn dimensions(&self) -> impl Iterator<Item = Option<usize>> + '_ {
		self.dimensions.iter().copied().map(|dim| dim.map(|extent| extent as usize))
	}
}
impl Output {
	/// Iterates over this output's declared dimensions as `usize`, yielding `None` for dynamic ones.
	pub fn dimensions(&self) -> impl Iterator<Item = Option<usize>> + '_ {
		self.dimensions.iter().copied().map(|dim| dim.map(|extent| extent as usize))
	}
}
impl Drop for Session {
#[tracing::instrument]
fn drop(&mut self) {
debug!("Dropping the session.");
if self.session_ptr.is_null() {
error!("Session pointer is null, not dropping.");
} else {
ortsys![unsafe ReleaseSession(self.session_ptr)];
}
self.session_ptr = std::ptr::null_mut();
self.allocator_ptr = std::ptr::null_mut();
}
}
impl Session {
pub fn run<'s, 'm>(&'s self, input_arrays: impl AsRef<[InputTensor]>) -> OrtResult<Vec<DynOrtTensor<'m, IxDyn>>>
where
's: 'm {
let input_arrays = input_arrays.as_ref();
self.validate_input_shapes(input_arrays)?;
let input_names_ptr: Vec<*const c_char> = self
.inputs
.iter()
.map(|input| input.name.clone())
.map(|n| CString::new(n).unwrap())
.map(|n| n.into_raw() as *const c_char)
.collect();
let output_names_cstring: Vec<CString> = self
.outputs
.iter()
.map(|output| output.name.clone())
.map(|n| CString::new(n).unwrap())
.collect();
let output_names_ptr: Vec<*const c_char> = output_names_cstring.iter().map(|n| n.as_ptr() as *const c_char).collect();
let mut output_tensor_ptrs: Vec<*mut sys::OrtValue> = vec![std::ptr::null_mut(); self.outputs.len()];
let input_ort_tensors: Vec<InputOrtTensor> = input_arrays
.iter()
.map(|input_tensor| InputOrtTensor::from_input_tensor(&self.memory_info, self.allocator_ptr, input_tensor))
.collect::<OrtResult<Vec<InputOrtTensor>>>()?;
let input_ort_values: Vec<*const sys::OrtValue> = input_ort_tensors.iter().map(|input_array_ort| input_array_ort.c_ptr()).collect();
let run_options_ptr: *const sys::OrtRunOptions = std::ptr::null();
ortsys![
unsafe Run(
self.session_ptr,
run_options_ptr,
input_names_ptr.as_ptr(),
input_ort_values.as_ptr(),
input_ort_values.len() as _,
output_names_ptr.as_ptr(),
output_names_ptr.len() as _,
output_tensor_ptrs.as_mut_ptr()
) -> OrtError::SessionRun
];
let memory_info_ref = &self.memory_info;
let outputs: OrtResult<Vec<DynOrtTensor<ndarray::Dim<ndarray::IxDynImpl>>>> = output_tensor_ptrs
.into_iter()
.map(|tensor_ptr| {
let (dims, data_type, len) = unsafe {
call_with_tensor_info(tensor_ptr, |tensor_info_ptr| {
get_tensor_dimensions(tensor_info_ptr)
.map(|dims| dims.iter().map(|&n| n as usize).collect::<Vec<_>>())
.and_then(|dims| extract_data_type(tensor_info_ptr).map(|data_type| (dims, data_type)))
.and_then(|(dims, data_type)| {
let mut len = 0;
ortsys![GetTensorShapeElementCount(tensor_info_ptr, &mut len) -> OrtError::GetTensorShapeElementCount];
Ok((dims, data_type, len))
})
})
}?;
Ok(DynOrtTensor::new(tensor_ptr, memory_info_ref, ndarray::IxDyn(&dims), len as _, data_type))
})
.collect();
let cstrings: OrtResult<Vec<CString>> = input_names_ptr
.into_iter()
.map(|p| {
assert_non_null_pointer(p, "c_char for CString")?;
unsafe { Ok(CString::from_raw(p as *mut c_char)) }
})
.collect();
cstrings?;
outputs
}
fn validate_input_shapes(&self, input_arrays: impl AsRef<[InputTensor]>) -> OrtResult<()> {
let input_arrays = input_arrays.as_ref();
if input_arrays.len() != self.inputs.len() {
error!("Non-matching number of inputs: {} (inference) vs {} (model)", input_arrays.len(), self.inputs.len());
return Err(OrtError::NonMatchingDimensions(NonMatchingDimensionsError::InputsCount {
inference_input_count: 0,
model_input_count: 0,
inference_input: input_arrays.iter().map(|input_array| input_array.shape().to_vec()).collect(),
model_input: self.inputs.iter().map(|input| input.dimensions.clone()).collect()
}));
}
let inputs_different_length = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| match l {
InputTensor::FloatTensor(input) => input.shape().len() != r.dimensions.len(),
#[cfg(feature = "half")]
InputTensor::Float16Tensor(input) => input.shape().len() != r.dimensions.len(),
#[cfg(feature = "half")]
InputTensor::Bfloat16Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Uint8Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Int8Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Uint16Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Int16Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Int32Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Int64Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::DoubleTensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Uint32Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::Uint64Tensor(input) => input.shape().len() != r.dimensions.len(),
InputTensor::StringTensor(input) => input.shape().len() != r.dimensions.len()
});
if inputs_different_length {
error!("Different input lengths: {:?} vs {:?}", self.inputs, input_arrays);
return Err(OrtError::NonMatchingDimensions(NonMatchingDimensionsError::InputsLength {
inference_input: input_arrays.iter().map(|input_array| input_array.shape().to_vec()).collect(),
model_input: self.inputs.iter().map(|input| input.dimensions.clone()).collect()
}));
}
let inputs_different_shape = input_arrays.iter().zip(self.inputs.iter()).any(|(l, r)| {
let l_shape = l.shape();
let r_shape = r.dimensions.as_slice();
l_shape.iter().zip(r_shape.iter()).any(|(l2, r2)| match r2 {
Some(r3) => *r3 as usize != *l2,
None => false })
});
if inputs_different_shape {
error!("Different input lengths: {:?} vs {:?}", self.inputs, input_arrays);
return Err(OrtError::NonMatchingDimensions(NonMatchingDimensionsError::InputsLength {
inference_input: input_arrays.iter().map(|input_array| input_array.shape().to_vec()).collect(),
model_input: self.inputs.iter().map(|input| input.dimensions.clone()).collect()
}));
}
Ok(())
}
pub fn metadata(&self) -> OrtResult<Metadata> {
let mut metadata_ptr: *mut sys::OrtModelMetadata = std::ptr::null_mut();
ortsys![unsafe SessionGetModelMetadata(self.session_ptr, &mut metadata_ptr) -> OrtError::GetModelMetadata; nonNull(metadata_ptr)];
Ok(Metadata::new(metadata_ptr, self.allocator_ptr))
}
#[cfg(feature = "profiling")]
pub fn end_profiling(&self) -> OrtResult<String> {
let mut profiling_name: *mut c_char = std::ptr::null_mut();
ortsys![unsafe SessionEndProfiling(self.session_ptr, self.allocator_ptr, &mut profiling_name)];
assert_non_null_pointer(profiling_name, "ProfilingName")?;
dangerous::raw_pointer_to_string(self.allocator_ptr, profiling_name)
}
}
// SAFETY: `Session` holds raw ONNX Runtime pointers, so these impls are asserted by hand.
// NOTE(review): this relies on ONNX Runtime's documented guarantee that an `OrtSession` may be
// used for concurrent `Run` calls. It also assumes the default allocator referenced by
// `allocator_ptr` is safe to share across threads. Confirm both against the ORT version in use.
unsafe impl Send for Session {}
unsafe impl Sync for Session {}
/// Reads every dimension recorded in a tensor type/shape info object.
///
/// Dynamic dimensions are returned exactly as ONNX Runtime reports them (e.g. `-1`).
///
/// # Safety
/// `tensor_info_ptr` must point to a valid `OrtTensorTypeAndShapeInfo` for the whole call.
unsafe fn get_tensor_dimensions(tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo) -> OrtResult<Vec<i64>> {
	let mut rank = 0;
	ortsys![GetDimensionsCount(tensor_info_ptr, &mut rank) -> OrtError::GetDimensionsCount];
	let mut extents = vec![0_i64; rank as _];
	ortsys![GetDimensions(tensor_info_ptr, extents.as_mut_ptr(), rank) -> OrtError::GetDimensions];
	Ok(extents)
}
/// Reads the element type of a tensor type/shape info object.
///
/// # Panics
/// If ONNX Runtime reports `ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED`.
///
/// # Safety
/// `tensor_info_ptr` must point to a valid `OrtTensorTypeAndShapeInfo`.
/// NOTE(review): the `transmute` assumes `TensorElementDataType` has the same representation as
/// `sys::ONNXTensorElementDataType` and a variant for every value ONNX Runtime can return. An
/// unrecognised value would be undefined behavior. Confirm, or switch to a checked conversion.
unsafe fn extract_data_type(tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo) -> OrtResult<TensorElementDataType> {
let mut type_sys = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
ortsys![GetTensorElementType(tensor_info_ptr, &mut type_sys) -> OrtError::GetTensorElementType];
// UNDEFINED is not a usable element type; treat it as an invariant violation.
assert_ne!(type_sys, sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
Ok(std::mem::transmute(type_sys))
}
/// Fetches the type/shape info of `tensor_ptr`, lends it to `f`, then releases it.
///
/// The info object is released after `f` returns, whether `f` succeeded or failed. `f`'s result
/// is passed through unchanged.
///
/// # Safety
/// `tensor_ptr` must point to a valid tensor `OrtValue`. `f` must not keep the info pointer
/// beyond its own call.
unsafe fn call_with_tensor_info<F, T>(tensor_ptr: *const sys::OrtValue, mut f: F) -> OrtResult<T>
where
	F: FnMut(*const sys::OrtTensorTypeAndShapeInfo) -> OrtResult<T>
{
	let mut info: *mut sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
	ortsys![GetTensorTypeAndShape(tensor_ptr, &mut info) -> OrtError::GetTensorTypeAndShape];
	let outcome = f(info);
	ortsys![ReleaseTensorTypeAndShapeInfo(info)];
	outcome
}
/// Unloads a shared library previously opened for custom operators (`dlclose` on Unix).
///
/// The unloader's return code is ignored.
#[cfg(unix)]
fn close_lib_handle(handle: *mut std::os::raw::c_void) {
	let _ = unsafe { libc::dlclose(handle) };
}

/// Unloads a shared library previously opened for custom operators (`FreeLibrary` on Windows).
///
/// The unloader's return code is ignored.
#[cfg(windows)]
fn close_lib_handle(handle: *mut std::os::raw::c_void) {
	let module = handle as winapi::shared::minwindef::HINSTANCE;
	let _ = unsafe { winapi::um::libloaderapi::FreeLibrary(module) };
}
mod dangerous {
	//! Raw FFI helpers used while building a session. They query ONNX Runtime for the model's
	//! input/output counts, names, element types and dimensions.
	//!
	//! Several helpers have separate `x86_64` and `aarch64` variants because the `f` function
	//! pointer types differ (`usize` vs `u64` for the index/count parameters). This presumably
	//! mirrors the generated bindings on each architecture.
	use super::*;
	use crate::{ortfree, tensor::TensorElementDataType};

	/// Returns the number of model inputs. Errors if the count cannot be read or is zero.
	pub(super) fn extract_inputs_count(session_ptr: *mut sys::OrtSession) -> OrtResult<usize> {
		let f = ort().SessionGetInputCount.unwrap();
		extract_io_count(f, session_ptr)
	}

	/// Returns the number of model outputs. Errors if the count cannot be read or is zero.
	pub(super) fn extract_outputs_count(session_ptr: *mut sys::OrtSession) -> OrtResult<usize> {
		let f = ort().SessionGetOutputCount.unwrap();
		extract_io_count(f, session_ptr)
	}

	/// Calls the count getter `f` and rejects a zero count with `OrtError::GetInOutCount`.
	#[cfg(target_arch = "x86_64")]
	fn extract_io_count(
		f: extern_system_fn! { unsafe fn(*const sys::OrtSession, *mut usize) -> *mut sys::OrtStatus },
		session_ptr: *mut sys::OrtSession
	) -> OrtResult<usize> {
		let mut num_nodes = 0;
		let status = unsafe { f(session_ptr, &mut num_nodes) };
		status_to_result(status).map_err(OrtError::GetInOutCount)?;
		assert_null_pointer(status, "SessionStatus")?;
		(num_nodes != 0)
			.then_some(())
			.ok_or_else(|| OrtError::GetInOutCount(OrtApiError::Msg("No nodes in model".to_owned())))?;
		Ok(num_nodes)
	}

	/// Calls the count getter `f` and rejects a zero count with `OrtError::GetInOutCount`.
	#[cfg(target_arch = "aarch64")]
	fn extract_io_count(
		f: extern_system_fn! { unsafe fn(*const sys::OrtSession, *mut u64) -> *mut sys::OrtStatus },
		session_ptr: *mut sys::OrtSession
	) -> OrtResult<usize> {
		let mut num_nodes = 0;
		let status = unsafe { f(session_ptr, &mut num_nodes) };
		status_to_result(status).map_err(OrtError::GetInOutCount)?;
		assert_null_pointer(status, "SessionStatus")?;
		(num_nodes != 0)
			.then_some(())
			.ok_or_else(|| OrtError::GetInOutCount(OrtApiError::Msg("No nodes in model".to_owned())))?;
		Ok(num_nodes as _)
	}

	/// Returns the name of input `i`.
	fn extract_input_name(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: usize) -> OrtResult<String> {
		let f = ort().SessionGetInputName.unwrap();
		extract_io_name(f, session_ptr, allocator_ptr, i)
	}

	/// Returns the name of output `i`.
	fn extract_output_name(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: usize) -> OrtResult<String> {
		let f = ort().SessionGetOutputName.unwrap();
		extract_io_name(f, session_ptr, allocator_ptr, i)
	}

	/// Copies an ONNX Runtime-allocated C string into a `String`, then frees it with `allocator_ptr`.
	///
	/// The C string is freed even when the conversion fails, so an invalid name does not leak.
	pub(crate) fn raw_pointer_to_string(allocator_ptr: *mut sys::OrtAllocator, c_str: *mut c_char) -> OrtResult<String> {
		let name = char_p_to_string(c_str);
		ortfree!(unsafe allocator_ptr, c_str);
		Ok(name?)
	}

	/// Calls the name getter `f` for index `i` and converts the returned name.
	#[cfg(target_arch = "x86_64")]
	fn extract_io_name(
		f: extern_system_fn! { unsafe fn(
			*const sys::OrtSession,
			usize,
			*mut sys::OrtAllocator,
			*mut *mut c_char,
		) -> *mut sys::OrtStatus },
		session_ptr: *mut sys::OrtSession,
		allocator_ptr: *mut sys::OrtAllocator,
		i: usize
	) -> OrtResult<String> {
		let mut name_bytes: *mut c_char = std::ptr::null_mut();
		let status = unsafe { f(session_ptr, i, allocator_ptr, &mut name_bytes) };
		status_to_result(status).map_err(OrtError::GetInputName)?;
		assert_non_null_pointer(name_bytes, "InputName")?;
		raw_pointer_to_string(allocator_ptr, name_bytes)
	}

	/// Calls the name getter `f` for index `i` and converts the returned name.
	#[cfg(target_arch = "aarch64")]
	fn extract_io_name(
		f: extern_system_fn! { unsafe fn(
			*const sys::OrtSession,
			u64,
			*mut sys::OrtAllocator,
			*mut *mut c_char,
		) -> *mut sys::OrtStatus },
		session_ptr: *mut sys::OrtSession,
		allocator_ptr: *mut sys::OrtAllocator,
		i: usize
	) -> OrtResult<String> {
		let mut name_bytes: *mut c_char = std::ptr::null_mut();
		let status = unsafe { f(session_ptr, i as _, allocator_ptr, &mut name_bytes) };
		status_to_result(status).map_err(OrtError::GetInputName)?;
		assert_non_null_pointer(name_bytes, "InputName")?;
		raw_pointer_to_string(allocator_ptr, name_bytes)
	}

	/// Reads name, element type and dimensions of model input `i`.
	pub(super) fn extract_input(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: usize) -> OrtResult<Input> {
		let input_name = extract_input_name(session_ptr, allocator_ptr, i)?;
		let f = ort().SessionGetInputTypeInfo.unwrap();
		let (input_type, dimensions) = extract_io(f, session_ptr, i)?;
		Ok(Input {
			name: input_name,
			input_type,
			dimensions
		})
	}

	/// Reads name, element type and dimensions of model output `i`.
	pub(super) fn extract_output(session_ptr: *mut sys::OrtSession, allocator_ptr: *mut sys::OrtAllocator, i: usize) -> OrtResult<Output> {
		let output_name = extract_output_name(session_ptr, allocator_ptr, i)?;
		let f = ort().SessionGetOutputTypeInfo.unwrap();
		let (output_type, dimensions) = extract_io(f, session_ptr, i)?;
		Ok(Output {
			name: output_name,
			output_type,
			dimensions
		})
	}

	/// Reads element type and dimensions from `typeinfo_ptr` without releasing it.
	///
	/// Dimensions reported as `-1` become `None`. Fails with `OrtError::CastTypeInfoToTensorInfo`
	/// when the type info is not a tensor.
	fn read_tensor_type_info(typeinfo_ptr: *mut sys::OrtTypeInfo) -> OrtResult<(TensorElementDataType, Vec<Option<u32>>)> {
		let mut tensor_info_ptr: *const sys::OrtTensorTypeAndShapeInfo = std::ptr::null_mut();
		ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr, &mut tensor_info_ptr) -> OrtError::CastTypeInfoToTensorInfo; nonNull(tensor_info_ptr)];
		// SAFETY: `tensor_info_ptr` was just checked non-null and is owned by `typeinfo_ptr`,
		// which the caller keeps alive until after this function returns.
		let io_type: TensorElementDataType = unsafe { extract_data_type(tensor_info_ptr)? };
		let node_dims = unsafe { get_tensor_dimensions(tensor_info_ptr)? };
		Ok((io_type, node_dims.into_iter().map(|d| if d == -1 { None } else { Some(d as u32) }).collect()))
	}

	/// Fetches the type info for index `i` via `f` and reads its tensor element type and dimensions.
	///
	/// The type info is released on every path, including when reading it fails.
	#[cfg(target_arch = "x86_64")]
	fn extract_io(
		f: extern_system_fn! { unsafe fn(
			*const sys::OrtSession,
			usize,
			*mut *mut sys::OrtTypeInfo,
		) -> *mut sys::OrtStatus },
		session_ptr: *mut sys::OrtSession,
		i: usize
	) -> OrtResult<(TensorElementDataType, Vec<Option<u32>>)> {
		let mut typeinfo_ptr: *mut sys::OrtTypeInfo = std::ptr::null_mut();
		let status = unsafe { f(session_ptr, i, &mut typeinfo_ptr) };
		status_to_result(status).map_err(OrtError::GetTypeInfo)?;
		assert_non_null_pointer(typeinfo_ptr, "TypeInfo")?;
		let result = read_tensor_type_info(typeinfo_ptr);
		ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)];
		result
	}

	/// Fetches the type info for index `i` via `f` and reads its tensor element type and dimensions.
	///
	/// The type info is released on every path, including when reading it fails.
	#[cfg(target_arch = "aarch64")]
	fn extract_io(
		f: extern_system_fn! { unsafe fn(
			*const sys::OrtSession,
			u64,
			*mut *mut sys::OrtTypeInfo,
		) -> *mut sys::OrtStatus },
		session_ptr: *mut sys::OrtSession,
		i: usize
	) -> OrtResult<(TensorElementDataType, Vec<Option<u32>>)> {
		let mut typeinfo_ptr: *mut sys::OrtTypeInfo = std::ptr::null_mut();
		let status = unsafe { f(session_ptr, i as _, &mut typeinfo_ptr) };
		status_to_result(status).map_err(OrtError::GetTypeInfo)?;
		assert_non_null_pointer(typeinfo_ptr, "TypeInfo")?;
		let result = read_tensor_type_info(typeinfo_ptr);
		ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr)];
		result
	}
}