#[cfg(feature = "std")]
use alloc::ffi::CString;
use alloc::{
boxed::Box,
format,
string::{String, ToString},
sync::Arc,
vec::Vec
};
use core::{
any::Any,
ffi::{CStr, c_char},
iter,
marker::PhantomData,
ops::{Deref, DerefMut},
ptr::{self, NonNull},
slice
};
use smallvec::SmallVec;
use crate::{
AsPointer,
environment::Environment,
error::{Error, ErrorCode, Result},
memory::Allocator,
ortsys,
util::{AllocatedString, STACK_SESSION_INPUTS, STACK_SESSION_OUTPUTS, with_cstr, with_cstr_ptr_array},
value::{DynValue, Outlet, Value, ValueType}
};
#[cfg(feature = "api-20")]
mod adapter;
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
mod r#async;
pub mod builder;
mod input;
mod io_binding;
mod metadata;
mod output;
mod run_options;
#[cfg(feature = "api-20")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-20")))]
pub use self::adapter::Adapter;
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
pub use self::r#async::InferenceFut;
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
use self::r#async::{AsyncInferenceContext, InferenceFutInner};
use self::{builder::SessionBuilder, run_options::UntypedRunOptions};
pub use self::{
input::{SessionInputValue, SessionInputs},
io_binding::IoBinding,
metadata::ModelMetadata,
output::SessionOutputs,
run_options::{HasSelectedOutputs, NoSelectedOutputs, OutputSelector, RunOptions, SelectedOutputMarker}
};
/// Reference-counted state behind a [`Session`].
///
/// Holds the raw `OrtSession` handle; the handle is released when this value is dropped.
#[derive(Debug)]
pub struct SharedSessionInner {
/// Raw, non-null handle to the underlying ONNX Runtime session.
session_ptr: NonNull<ort_sys::OrtSession>,
/// Allocator passed to ONNX Runtime calls made on behalf of this session (e.g. when fetching bound outputs or
/// the profiling file name).
pub(crate) allocator: Allocator,
// Never read in this file. NOTE(review): presumably keeps auxiliary objects alive for as long as the session
// exists — confirm against the session builder.
_extras: SmallVec<[Arc<dyn Any>; 4]>,
// Never read in this file. NOTE(review): presumably keeps the environment alive for as long as the session
// exists — confirm against the session builder.
_environment: Arc<Environment>
}
// NOTE(review): these impls assert that the raw `OrtSession` handle (and everything else stored in
// `SharedSessionInner`) may be moved to and shared between threads. Nothing in this file establishes that;
// confirm against ONNX Runtime's thread-safety guarantees.
unsafe impl Send for SharedSessionInner {}
unsafe impl Sync for SharedSessionInner {}
/// Exposes the raw `OrtSession` handle stored in `session_ptr`.
impl AsPointer for SharedSessionInner {
type Sys = ort_sys::OrtSession;
/// Returns the raw session pointer. The pointer is not null because it is stored as a `NonNull`.
fn ptr(&self) -> *const Self::Sys {
self.session_ptr.as_ptr()
}
}
/// Releases the underlying ONNX Runtime session.
impl Drop for SharedSessionInner {
fn drop(&mut self) {
// Hand the session back to ONNX Runtime first...
ortsys![unsafe ReleaseSession(self.session_ptr.as_ptr())];
// ...then record the drop through the crate's logging macro. The pointer is only passed along for
// logging here, not used for any further API call.
crate::logging::drop!(Session, self.session_ptr);
}
}
/// A loaded model session that inference can be run on.
///
/// Create one through [`Session::builder`].
#[derive(Debug)]
pub struct Session {
/// Shared state holding the raw session handle; cloned into values produced by this session.
pub(crate) inner: Arc<SharedSessionInner>,
/// Input descriptions returned by [`Session::inputs`]. Their names, in order, are used as the input names
/// when inputs are supplied positionally.
inputs: Vec<Outlet>,
/// Output descriptions returned by [`Session::outputs`]. Their names are requested from the run when no
/// output selection is supplied through run options.
outputs: Vec<Outlet>
}
/// A [`Session`] wrapper carrying an extra lifetime `'s`.
///
/// Dereferences to the wrapped [`Session`], so every `Session` method is available on it.
// NOTE(review): `'s` presumably ties the session to the buffer the model was loaded from — confirm against
// the builder that constructs this type.
pub struct InMemorySession<'s> {
/// The wrapped session.
session: Session,
/// Zero-sized marker that makes the struct carry the `'s` lifetime.
phantom: PhantomData<&'s ()>
}
/// Gives shared access to the wrapped [`Session`].
impl Deref for InMemorySession<'_> {
type Target = Session;
/// Returns a reference to the wrapped session.
fn deref(&self) -> &Self::Target {
&self.session
}
}
/// Gives exclusive access to the wrapped [`Session`], which the `run*` methods require.
impl DerefMut for InMemorySession<'_> {
/// Returns a mutable reference to the wrapped session.
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.session
}
}
impl Session {
/// Returns the session's input descriptions, in the order they are stored.
pub fn inputs(&self) -> &[Outlet] {
&self.inputs
}
/// Returns the session's output descriptions, in the order they are stored.
pub fn outputs(&self) -> &[Outlet] {
&self.outputs
}
/// Creates a new [`SessionBuilder`]; shorthand for `SessionBuilder::new()`.
///
/// # Errors
/// Propagates any error returned by `SessionBuilder::new()`.
pub fn builder() -> Result<SessionBuilder> {
SessionBuilder::new()
}
/// Returns the allocator stored alongside this session.
#[must_use]
pub fn allocator(&self) -> &Allocator {
&self.inner.allocator
}
/// Creates an [`IoBinding`] for this session; shorthand for `IoBinding::new(self)`.
///
/// # Errors
/// Propagates any error returned by `IoBinding::new`.
pub fn create_binding(&self) -> Result<IoBinding> {
IoBinding::new(self)
}
/// Returns a new `Arc` handle to the session's shared inner state (a reference-count bump, not a copy).
#[must_use]
pub fn inner(&self) -> Arc<SharedSessionInner> {
Arc::clone(&self.inner)
}
/// Returns the name and type of every overridable initializer reported by the session.
///
/// # Panics
/// Panics if any of the underlying ONNX Runtime queries reports an error, or if ONNX Runtime returns a null
/// name pointer.
#[must_use]
pub fn overridable_initializers(&self) -> Vec<OverridableInitializer> {
    let mut size = 0;
    ortsys![unsafe SessionGetOverridableInitializerCount(self.ptr(), &mut size).expect("infallible")];
    let allocator = Allocator::default();
    (0..size)
        .map(|i| {
            let mut name: *mut c_char = ptr::null_mut();
            ortsys![unsafe SessionGetOverridableInitializerName(self.ptr(), i, allocator.ptr().cast_mut(), &mut name).expect("infallible")];
            // `CStr::from_ptr` on a null pointer is undefined behaviour; fail loudly instead.
            assert!(!name.is_null(), "initializer name returned by ONNX Runtime should not be null");
            // Copy the name into an owned `String` (replacing invalid UTF-8 rather than failing)...
            let owned_name = unsafe { CStr::from_ptr(name) }.to_string_lossy().into_owned();
            // ...then give the buffer back to the allocator that produced it. Previously the buffer was never
            // freed, leaking one allocation per initializer on every call.
            unsafe { allocator.free(name) };
            let name = owned_name;
            let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
            ortsys![unsafe SessionGetOverridableInitializerTypeInfo(self.ptr(), i, &mut typeinfo_ptr).expect("infallible"); nonNull(typeinfo_ptr)];
            let dtype = unsafe { ValueType::from_type_info(typeinfo_ptr) };
            OverridableInitializer { name, dtype }
        })
        .collect()
}
#[cfg(not(target_arch = "wasm32"))]
pub fn run<'s, 'i, 'v: 'i, const N: usize>(&'s mut self, input_values: impl Into<SessionInputs<'i, 'v, N>>) -> Result<SessionOutputs<'s>> {
match input_values.into() {
SessionInputs::ValueSlice(input_values) => {
self.run_inner(self.inputs.iter().map(|input| input.name()).collect(), input_values.iter().collect(), None)
}
SessionInputs::ValueArray(input_values) => {
self.run_inner(self.inputs.iter().map(|input| input.name()).collect(), input_values.iter().collect(), None)
}
SessionInputs::ValueMap(input_values) => {
self.run_inner(input_values.iter().map(|(k, _)| k.as_ref()).collect(), input_values.iter().map(|(_, v)| v).collect(), None)
}
}
}
#[cfg(not(target_arch = "wasm32"))]
pub fn run_with_options<'r, 's: 'r, 'i, 'v: 'i, O: SelectedOutputMarker, const N: usize>(
&'s mut self,
input_values: impl Into<SessionInputs<'i, 'v, N>>,
run_options: &'r RunOptions<O>
) -> Result<SessionOutputs<'r>> {
match input_values.into() {
SessionInputs::ValueSlice(input_values) => {
self.run_inner(self.inputs.iter().map(|input| input.name()).collect(), input_values.iter().collect(), Some(&run_options.inner))
}
SessionInputs::ValueArray(input_values) => {
self.run_inner(self.inputs.iter().map(|input| input.name()).collect(), input_values.iter().collect(), Some(&run_options.inner))
}
SessionInputs::ValueMap(input_values) => {
self.run_inner(input_values.iter().map(|(k, _)| k.as_ref()).collect(), input_values.iter().map(|(_, v)| v).collect(), Some(&run_options.inner))
}
}
}
/// Shared implementation of the synchronous `run*` methods: calls ONNX Runtime's `Run` and collects the outputs.
///
/// `input_names[i]` is paired with `input_values[i]`; only as many names as there are values are consumed.
/// When `run_options` is `None`, every session output is requested and ONNX Runtime allocates all of them;
/// otherwise the options' output selection decides which outputs are requested and which are pre-allocated.
///
/// # Errors
/// Returns `InvalidArgument` if more values than names are supplied, and propagates any error from converting
/// the names to C strings or from `Run` itself.
///
/// # Panics
/// Panics if `Run` succeeds but leaves a null pointer in an output slot that was not pre-allocated.
#[cfg(not(target_arch = "wasm32"))]
fn run_inner<'i, 'r, 's: 'r, 'v: 'i>(
    &'s self,
    input_names: SmallVec<[&str; STACK_SESSION_INPUTS]>,
    input_values: SmallVec<[&'i SessionInputValue<'v>; STACK_SESSION_INPUTS]>,
    run_options: Option<&'r UntypedRunOptions>
) -> Result<SessionOutputs<'r>> {
    // `Run` is told there are `input_values.len()` inputs and reads that many entries from the name array, so
    // having more values than names would make it read past the end of the names.
    if input_values.len() > input_names.len() {
        return Err(Error::new_with_code(
            ErrorCode::InvalidArgument,
            format!("{} inputs were provided, but the model only accepts {}.", input_values.len(), input_names.len())
        ));
    }

    let (output_names, mut output_tensors) = match run_options {
        Some(r) => r.outputs.resolve_outputs(&self.outputs),
        None => (self.outputs.iter().map(|o| o.name()).collect(), iter::repeat_with(|| None).take(self.outputs.len()).collect())
    };
    // One slot per requested output: the pre-allocated value's pointer, or null to have `Run` fill it in.
    let mut output_value_ptrs: SmallVec<[*mut ort_sys::OrtValue; STACK_SESSION_OUTPUTS]> = output_tensors
        .iter_mut()
        .map(|c| match c {
            Some(v) => v.ptr_mut(),
            None => ptr::null_mut()
        })
        .collect();
    // `Run` writes into this array, so the pointer handed to it must come from a mutable borrow. It is taken
    // up front because the `Fn` closures below cannot borrow the array mutably; deriving a `*mut` from a shared
    // borrow (`as_ptr().cast_mut()`) and writing through it is undefined behaviour.
    let output_value_ptrs_raw = output_value_ptrs.as_mut_ptr();
    let input_value_ptrs: SmallVec<[*const ort_sys::OrtValue; STACK_SESSION_INPUTS]> = input_values.iter().map(|c| c.ptr()).collect();
    // A null options pointer is passed when the caller supplied no run options.
    let run_options_ptr = if let Some(run_options) = &run_options { run_options.ptr.as_ptr() } else { ptr::null() };

    with_cstr_ptr_array(&input_names, &|input_name_ptrs| {
        with_cstr_ptr_array(&output_names, &|output_name_ptrs| {
            ortsys![
                unsafe Run(
                    self.inner.session_ptr.as_ptr(),
                    run_options_ptr,
                    input_name_ptrs.as_ptr(),
                    input_value_ptrs.as_ptr(),
                    input_value_ptrs.len(),
                    output_name_ptrs.as_ptr(),
                    output_name_ptrs.len(),
                    output_value_ptrs_raw
                )?
            ];
            Ok(())
        })
    })?;

    // Pre-allocated outputs are returned as-is; every other slot now holds a value created by `Run`, which is
    // wrapped together with a handle to the session's shared state.
    let outputs = output_tensors
        .into_iter()
        .enumerate()
        .map(|(i, v)| match v {
            Some(value) => value,
            None => unsafe {
                Value::from_ptr(
                    NonNull::new(output_value_ptrs[i]).expect("OrtValue ptr returned from session Run should not be null"),
                    Some(Arc::clone(&self.inner))
                )
            }
        })
        .collect();
    Ok(SessionOutputs::new(output_names, outputs))
}
/// Runs the session using the inputs and outputs bound in `binding`, without run options.
///
/// # Errors
/// Propagates any error from the underlying bound run.
#[cfg(not(target_arch = "wasm32"))]
pub fn run_binding<'b, 's: 'b>(&'s mut self, binding: &'b IoBinding) -> Result<SessionOutputs<'b>> {
self.run_binding_inner(binding, None)
}
/// Runs the session using the inputs and outputs bound in `binding`, with the given run options.
///
/// Only run options without an output selection (`NoSelectedOutputs`) are accepted.
///
/// # Errors
/// Propagates any error from the underlying bound run.
#[cfg(not(target_arch = "wasm32"))]
pub fn run_binding_with_options<'r, 'b, 's: 'b>(
&'s mut self,
binding: &'b IoBinding,
run_options: &'r RunOptions<NoSelectedOutputs>
) -> Result<SessionOutputs<'b>> {
self.run_binding_inner(binding, Some(run_options))
}
/// Shared implementation of the `run_binding*` methods: calls `RunWithBinding`, then collects the bound outputs.
///
/// Returns empty outputs when the binding has no output values registered. Otherwise the output values are
/// fetched with `GetBoundOutputValues` and named after the keys of `binding.output_values`.
///
/// # Errors
/// Propagates any error reported by `RunWithBinding` or `GetBoundOutputValues`.
///
/// # Panics
/// Panics if `GetBoundOutputValues` yields a null value pointer.
#[cfg(not(target_arch = "wasm32"))]
fn run_binding_inner<'r, 'b, 's: 'b>(
&'s self,
binding: &'b IoBinding,
run_options: Option<&'r RunOptions<NoSelectedOutputs>>
) -> Result<SessionOutputs<'b>> {
use crate::util::run_on_drop;
// A null options pointer is passed when the caller supplied no run options.
let run_options_ptr = if let Some(run_options) = run_options { run_options.ptr() } else { ptr::null() };
ortsys![unsafe RunWithBinding(self.inner.ptr().cast_mut(), run_options_ptr, binding.ptr())?];
// Starts as the number of outputs registered on the binding; `GetBoundOutputValues` overwrites it below.
let mut count = binding.output_values.len();
if count > 0 {
let mut output_values_ptr: *mut *mut ort_sys::OrtValue = ptr::null_mut();
ortsys![unsafe GetBoundOutputValues(binding.ptr(), self.inner.allocator.ptr().cast_mut(), &mut output_values_ptr, &mut count)?; nonNull(output_values_ptr)];
// The pointer array itself was allocated with the session allocator; free it when this scope ends,
// after the individual value pointers have been read out of it.
let _guard = run_on_drop(|| unsafe {
self.inner.allocator.free(output_values_ptr.as_ptr());
});
// Wrap every returned value pointer, attaching a handle to the session's shared state.
let output_values = unsafe { slice::from_raw_parts(output_values_ptr.as_ptr(), count) }
.iter()
.map(|ptr| unsafe {
DynValue::from_ptr(NonNull::new(*ptr).expect("OrtValue ptrs returned by GetBoundOutputValues should not be null"), Some(self.inner()))
})
.collect();
Ok(SessionOutputs::new(binding.output_values.iter().map(|(k, _)| k.as_str()).collect(), output_values))
} else {
Ok(SessionOutputs::new_empty())
}
}
/// Starts an asynchronous inference with the given inputs and run options, returning a future for the outputs.
///
/// Positional inputs (slice or array) are paired, in order, with the names of the session's inputs; map inputs
/// supply their own names through the map keys.
///
/// # Errors
/// Propagates any error raised while setting up or submitting the asynchronous run.
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub fn run_async<'r, 's: 'r, 'i, 'v: 'i + 's, O: SelectedOutputMarker, const N: usize>(
    &'s mut self,
    input_values: impl Into<SessionInputs<'i, 'v, N>>,
    run_options: &'r RunOptions<O>
) -> Result<InferenceFut<'r, 'v>> {
    let inputs: SessionInputs<'i, 'v, N> = input_values.into();
    match inputs {
        SessionInputs::ValueSlice(values) => {
            let names: SmallVec<[&str; STACK_SESSION_INPUTS]> = self.inputs.iter().map(|outlet| outlet.name()).collect();
            self.run_inner_async(names, values.iter().collect(), &run_options.inner)
        }
        SessionInputs::ValueArray(values) => {
            let names: SmallVec<[&str; STACK_SESSION_INPUTS]> = self.inputs.iter().map(|outlet| outlet.name()).collect();
            self.run_inner_async(names, values.iter().collect(), &run_options.inner)
        }
        SessionInputs::ValueMap(named_values) => {
            let names: SmallVec<[&str; STACK_SESSION_INPUTS]> = named_values.iter().map(|(name, _)| name.as_ref()).collect();
            self.run_inner_async(names, named_values.iter().map(|(_, value)| value).collect(), &run_options.inner)
        }
    }
}
/// Shared implementation of the non-wasm `run_async`: submits the run through `RunAsync` and returns a future
/// for its outputs.
///
/// `input_names[i]` is paired with `input_values[i]`. Everything ONNX Runtime needs while the run is in
/// flight (C-string names, value pointers, handles keeping the input values alive) is moved into a
/// heap-allocated `AsyncInferenceContext` that is leaked and passed to the callback as user data.
///
/// # Errors
/// Returns `InvalidArgument` if more values than names are supplied, an error if an input name cannot be
/// converted to a C string, and any error reported by `RunAsync`.
#[cfg(all(feature = "std", not(target_arch = "wasm32")))]
fn run_inner_async<'i, 'r, 's: 'r, 'v: 'i + 's>(
    &'s self,
    input_names: SmallVec<[&str; STACK_SESSION_INPUTS]>,
    input_values: SmallVec<[&SessionInputValue<'v>; STACK_SESSION_INPUTS]>,
    run_options: &'r Arc<UntypedRunOptions>
) -> Result<InferenceFut<'r, 'v>> {
    // `RunAsync` is told there are `input_values.len()` inputs and reads that many entries from the name
    // array; with more values than names it would read out of bounds. The synchronous path performs the same
    // check. It is done before any C string is created so that nothing has been allocated on this error path.
    if input_values.len() > input_names.len() {
        return Err(Error::new_with_code(
            ErrorCode::InvalidArgument,
            format!("{} inputs were provided, but the model only accepts {}.", input_values.len(), input_names.len())
        ));
    }

    // Names are turned into raw, owned C strings because they must outlive this function call.
    let input_name_ptrs = input_names
        .into_iter()
        .map(|name| CString::new(name.as_bytes()).map(|s| s.into_raw().cast_const()))
        .collect::<Result<SmallVec<[*const c_char; STACK_SESSION_INPUTS]>, _>>()?;

    // Alongside each raw value pointer, keep a clone of the value's inner `Arc` in the context.
    let mut input_inner_holders = SmallVec::with_capacity(input_values.len());
    let mut input_ort_values = SmallVec::with_capacity(input_values.len());
    for input in input_values {
        input_ort_values.push(input.ptr());
        input_inner_holders.push(Arc::clone(match input {
            SessionInputValue::ViewMut(v) => &(**v).inner,
            SessionInputValue::View(v) => &(**v).inner,
            SessionInputValue::Owned(v) => &v.inner
        }));
    }

    let (output_names, mut output_tensors) = run_options.outputs.resolve_outputs(&self.outputs);
    let output_name_ptrs = output_names
        .iter()
        .map(|n| CString::new(*n).unwrap_or_else(|_| unreachable!()))
        .map(|n| n.into_raw().cast_const())
        .collect();
    // One slot per requested output: the pre-allocated value's pointer, or null to have ONNX Runtime fill it in.
    let output_tensor_ptrs = output_tensors
        .iter_mut()
        .map(|c| match c {
            Some(v) => v.ptr_mut(),
            None => ptr::null_mut()
        })
        .collect();

    let async_inner = Arc::new(InferenceFutInner::new(Arc::clone(run_options)));

    // Initialise the context field by field, directly in its heap allocation. The fields are written through
    // raw pointers: creating a `&mut AsyncInferenceContext` (e.g. via `assume_init_mut`) before every field is
    // initialised would be undefined behaviour.
    let mut ctx = Box::<AsyncInferenceContext>::new_uninit();
    let ctx_ptr = ctx.as_mut_ptr();
    unsafe {
        ptr::addr_of_mut!((*ctx_ptr).inner).write(Arc::clone(&async_inner));
        ptr::addr_of_mut!((*ctx_ptr).input_ort_values).write(input_ort_values);
        ptr::addr_of_mut!((*ctx_ptr)._input_inner_holders).write(input_inner_holders);
        ptr::addr_of_mut!((*ctx_ptr).input_name_ptrs).write(input_name_ptrs);
        ptr::addr_of_mut!((*ctx_ptr).output_name_ptrs).write(output_name_ptrs);
        ptr::addr_of_mut!((*ctx_ptr).output_names).write(output_names);
        ptr::addr_of_mut!((*ctx_ptr).output_value_ptrs).write(output_tensor_ptrs);
        ptr::addr_of_mut!((*ctx_ptr).session_inner).write(&self.inner);
    };
    // NOTE(review): this assumes the eight fields written above are all of the context's fields — confirm
    // against the `AsyncInferenceContext` definition.
    let ctx = Box::leak(unsafe { ctx.assume_init() });
    crate::logging::create!(AsyncInferenceContext, ctx);

    // NOTE(review): the leaked context is presumably reclaimed by `async_callback`. If `RunAsync` itself
    // reports an error, nothing here frees the context or the C strings it holds — confirm whether the
    // callback is still invoked in that case.
    ortsys![
        unsafe RunAsync(
            self.inner.session_ptr.as_ptr(),
            run_options.ptr.as_ptr(),
            ctx.input_name_ptrs.as_ptr(),
            ctx.input_ort_values.as_ptr(),
            ctx.input_ort_values.len(),
            ctx.output_name_ptrs.as_ptr(),
            ctx.output_name_ptrs.len(),
            ctx.output_value_ptrs.as_mut_ptr(),
            Some(self::r#async::async_callback),
            ctx as *mut _ as *mut ort_sys::c_void
        )?
    ];

    Ok(InferenceFut::new(async_inner))
}
/// Runs inference asynchronously on wasm32 with the given inputs and run options.
///
/// Positional inputs (slice or array) are paired, in order, with the names of the session's inputs; map inputs
/// supply their own names through the map keys.
///
/// # Errors
/// Returns an error if more inputs are supplied than there are names, or if the underlying run fails.
#[cfg(target_arch = "wasm32")]
pub async fn run_async<'r, 's: 'r, 'i, 'v: 'i + 's, O: SelectedOutputMarker, const N: usize>(
    &'s mut self,
    input_values: impl Into<SessionInputs<'i, 'v, N>>,
    run_options: &'r RunOptions<O>
) -> Result<SessionOutputs<'r>> {
    let inputs: SessionInputs<'i, 'v, N> = input_values.into();
    match inputs {
        SessionInputs::ValueSlice(values) => {
            let names: SmallVec<[&str; STACK_SESSION_INPUTS]> = self.inputs.iter().map(|outlet| outlet.name()).collect();
            self.run_inner_async(names, values.iter().collect(), &run_options.inner).await
        }
        SessionInputs::ValueArray(values) => {
            let names: SmallVec<[&str; STACK_SESSION_INPUTS]> = self.inputs.iter().map(|outlet| outlet.name()).collect();
            self.run_inner_async(names, values.iter().collect(), &run_options.inner).await
        }
        SessionInputs::ValueMap(named_values) => {
            let names: SmallVec<[&str; STACK_SESSION_INPUTS]> = named_values.iter().map(|(name, _)| name.as_ref()).collect();
            self.run_inner_async(names, named_values.iter().map(|(_, value)| value).collect(), &run_options.inner)
                .await
        }
    }
}
/// Shared implementation of the wasm32 `run_async`: awaits `RunAsync` and collects the outputs.
///
/// `input_names[i]` is paired with `input_values[i]`. The run options' output selection decides which outputs
/// are requested and which are pre-allocated.
///
/// # Errors
/// Returns `InvalidArgument` if more values than names are supplied, and propagates the error decoded from the
/// status that `RunAsync` resolves to.
///
/// # Panics
/// Panics if the run succeeds but leaves a null pointer in an output slot that was not pre-allocated.
#[cfg(target_arch = "wasm32")]
async fn run_inner_async<'i, 'r, 's: 'r, 'v: 'i + 's>(
&'s self,
input_names: SmallVec<[&str; STACK_SESSION_INPUTS]>,
input_values: SmallVec<[&SessionInputValue<'v>; STACK_SESSION_INPUTS]>,
run_options: &'r UntypedRunOptions
) -> Result<SessionOutputs<'r>> {
// Reject calls that supply more values than there are names to pair them with.
if input_values.len() > input_names.len() {
return Err(Error::new_with_code(
ErrorCode::InvalidArgument,
format!("{} inputs were provided, but the model only accepts {}.", input_values.len(), input_names.len())
));
}
let (output_names, mut output_tensors) = run_options.outputs.resolve_outputs(&self.outputs);
// One slot per requested output: the pre-allocated value's pointer, or null for outputs the run must create.
let mut output_value_ptrs: SmallVec<[*mut ort_sys::OrtValue; STACK_SESSION_OUTPUTS]> = output_tensors
.iter_mut()
.map(|c| match c {
Some(v) => v.ptr_mut(),
None => ptr::null_mut()
})
.collect();
let input_value_ptrs: SmallVec<[*const ort_sys::OrtValue; STACK_SESSION_INPUTS]> = input_values.iter().map(|c| c.ptr()).collect();
// Unlike the native call, this variant is handed the name and pointer collections directly and resolves to
// a status once awaited.
let status = ortsys![
unsafe RunAsync(
self.inner.session_ptr.as_ptr(),
run_options.ptr.as_ptr(),
&input_names,
&input_value_ptrs,
&output_names,
&mut output_value_ptrs
)
]
.await;
unsafe { Error::result_from_status(status) }?;
// Pre-allocated outputs are returned as-is; every other slot is wrapped together with a handle to the
// session's shared state.
let outputs = output_tensors
.into_iter()
.enumerate()
.map(|(i, v)| match v {
Some(value) => value,
None => unsafe {
Value::from_ptr(
NonNull::new(output_value_ptrs[i]).expect("OrtValue ptr returned from session Run should not be null"),
Some(Arc::clone(&self.inner))
)
}
})
.collect();
Ok(SessionOutputs::new(output_names, outputs))
}
/// Fetches the model metadata for this session via `SessionGetModelMetadata`.
///
/// # Errors
/// Propagates any error reported by `SessionGetModelMetadata`, including the null-pointer check on its result.
pub fn metadata(&self) -> Result<ModelMetadata<'_>> {
let mut metadata_ptr: *mut ort_sys::OrtModelMetadata = ptr::null_mut();
ortsys![unsafe SessionGetModelMetadata(self.inner.session_ptr.as_ptr(), &mut metadata_ptr)?; nonNull(metadata_ptr)];
Ok(unsafe { ModelMetadata::new(metadata_ptr) })
}
/// Returns the value reported by `SessionGetProfilingStartTimeNs` for this session.
///
/// # Errors
/// Propagates any error reported by `SessionGetProfilingStartTimeNs`.
pub fn profiling_start_ns(&self) -> Result<u64> {
    let session = self.inner.session_ptr.as_ptr();
    let mut start_ns = 0;
    ortsys![unsafe SessionGetProfilingStartTimeNs(session, &mut start_ns)?];
    Ok(start_ns)
}
/// Ends profiling via `SessionEndProfiling` and returns the string it produces as an owned `String`.
///
/// # Errors
/// Propagates any error reported by `SessionEndProfiling` (including the null-pointer check on its result) or
/// raised while converting the returned C string.
// NOTE(review): the returned string is presumably the name/path of the written profile file — confirm against
// the ONNX Runtime documentation.
pub fn end_profiling(&mut self) -> Result<String> {
    let allocator = &self.inner.allocator;
    let mut profiling_name: *mut c_char = ptr::null_mut();
    ortsys![unsafe SessionEndProfiling(self.inner.session_ptr.as_ptr(), allocator.ptr().cast_mut(), &mut profiling_name)?; nonNull(profiling_name)];
    // Wrap the allocator-owned C string, then copy it into a `String`.
    let allocated = unsafe { AllocatedString::from_ptr(profiling_name.as_ptr(), allocator) }?;
    Ok(allocated.to_string())
}
/// Sets the `ep.dynamic.workload_type` dynamic option of this session to the value matching `workload_type`.
///
/// # Errors
/// Propagates any error reported while setting the option.
#[cfg(feature = "api-20")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-20")))]
pub fn set_workload_type(&mut self, workload_type: WorkloadType) -> Result<()> {
    const KEY: &CStr = c"ep.dynamic.workload_type";
    let value: &CStr = match workload_type {
        WorkloadType::Default => c"Default",
        WorkloadType::Efficient => c"Efficient"
    };
    self.set_dynamic_option(KEY.as_ptr(), value.as_ptr())
}
/// Sets a single dynamic option on the session via `SetEpDynamicOptions`.
///
/// `key` and `value` are passed on as one-element arrays (note the count of `1`).
// NOTE(review): both pointers are presumably required to be valid, NUL-terminated C strings for the duration
// of the call; this function is safe to call but does not check that — confirm at each call site.
///
/// # Errors
/// Propagates any error reported by `SetEpDynamicOptions`.
#[cfg(feature = "api-20")]
pub(crate) fn set_dynamic_option(&mut self, key: *const c_char, value: *const c_char) -> Result<()> {
ortsys![unsafe SetEpDynamicOptions(self.inner.session_ptr.as_ptr(), &key, &value, 1)?];
Ok(())
}
/// Returns the opset reported by `SessionGetOpsetForDomain` for the given domain, converted to `u32`.
///
/// # Errors
/// Propagates any error from converting `domain` to a C string or from `SessionGetOpsetForDomain`.
#[cfg(feature = "api-22")]
#[cfg_attr(docsrs, doc(cfg(feature = "api-22")))]
pub fn opset_for_domain(&self, domain: impl AsRef<str>) -> Result<u32> {
    let domain_bytes = domain.as_ref().as_bytes();
    with_cstr(domain_bytes, &|domain_cstr| {
        let mut version = 0;
        ortsys![@editor: unsafe SessionGetOpsetForDomain(self.inner.session_ptr.as_ptr(), domain_cstr.as_ptr(), &mut version)?];
        // Plain `as` conversion, as reported by the API.
        Ok(version as u32)
    })
}
}
/// Workload type that can be applied to a session with `Session::set_workload_type`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkloadType {
/// Sent to the runtime as the option value `"Default"`. This is the `Default::default()` variant.
#[default]
Default,
/// Sent to the runtime as the option value `"Efficient"`.
Efficient
}
// NOTE(review): these impls assert that a `Session` may be moved to and shared between threads. Nothing in
// this file establishes that; confirm against ONNX Runtime's thread-safety guarantees.
unsafe impl Send for Session {}
unsafe impl Sync for Session {}
/// Exposes the raw `OrtSession` handle by delegating to the shared inner state.
impl AsPointer for Session {
type Sys = ort_sys::OrtSession;
/// Returns the same raw session pointer as `SharedSessionInner::ptr`.
fn ptr(&self) -> *const Self::Sys {
self.inner.ptr()
}
}
/// Name and type of one overridable initializer, as returned by `Session::overridable_initializers`.
#[derive(Debug, Clone)]
pub struct OverridableInitializer {
/// The initializer's name.
name: String,
/// The initializer's value type.
dtype: ValueType
}
impl OverridableInitializer {
/// Returns the initializer's name.
pub fn name(&self) -> &str {
&self.name
}
/// Returns the initializer's value type.
pub fn dtype(&self) -> &ValueType {
&self.dtype
}
}
pub(crate) mod io {
use super::*;
pub(super) fn extract_io_count(
f: unsafe extern "system" fn(*const ort_sys::OrtSession, *mut usize) -> ort_sys::OrtStatusPtr,
session_ptr: NonNull<ort_sys::OrtSession>
) -> Result<usize> {
let mut num_nodes = 0;
let status = unsafe { f(session_ptr.as_ptr(), &mut num_nodes) };
unsafe { Error::result_from_status(status) }?;
Ok(num_nodes)
}
fn extract_io_name(
f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut ort_sys::OrtAllocator, *mut *mut c_char) -> ort_sys::OrtStatusPtr,
session_ptr: NonNull<ort_sys::OrtSession>,
allocator: &Allocator,
i: usize
) -> Result<String> {
let mut name_ptr: *mut c_char = ptr::null_mut();
let status = unsafe { f(session_ptr.as_ptr(), i, allocator.ptr().cast_mut(), &mut name_ptr) };
unsafe { Error::result_from_status(status) }?;
if name_ptr.is_null() {
crate::util::cold();
return Err(crate::Error::new("expected `name_ptr` to not be null"));
}
unsafe { AllocatedString::from_ptr(name_ptr, allocator) }.map(|x| x.to_string())
}
fn extract_io(
f: unsafe extern "system" fn(*const ort_sys::OrtSession, usize, *mut *mut ort_sys::OrtTypeInfo) -> ort_sys::OrtStatusPtr,
session_ptr: NonNull<ort_sys::OrtSession>,
i: usize
) -> Result<ValueType> {
let mut typeinfo_ptr: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
let status = unsafe { f(session_ptr.as_ptr(), i, &mut typeinfo_ptr) };
unsafe { Error::result_from_status(status) }?;
let Some(typeinfo_ptr) = NonNull::new(typeinfo_ptr) else {
crate::util::cold();
return Err(crate::Error::new("expected `typeinfo_ptr` to not be null"));
};
Ok(unsafe { ValueType::from_type_info(typeinfo_ptr) })
}
pub(super) fn extract_input(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<Outlet> {
let name = extract_io_name(ortsys![SessionGetInputName], session_ptr, allocator, i)?;
let dtype = extract_io(ortsys![SessionGetInputTypeInfo], session_ptr, i)?;
Ok(Outlet::new(name, dtype))
}
pub(super) fn extract_output(session_ptr: NonNull<ort_sys::OrtSession>, allocator: &Allocator, i: usize) -> Result<Outlet> {
let name = extract_io_name(ortsys![SessionGetOutputName], session_ptr, allocator, i)?;
let dtype = extract_io(ortsys![SessionGetOutputTypeInfo], session_ptr, i)?;
Ok(Outlet::new(name, dtype))
}
}